# Georeferencing sample notebook

This notebook demonstrates how to load and render figures and their corresponding .geojson files from the biodiversity georeferencing dataset.

## Imports, constants, and environment

In [None]:
import os
import json
import random
import warnings

import pandas as pd
import numpy as np
from PIL import Image

import geopandas as gpd
import contextily as cx
import xyzservices
from matplotlib import pyplot as plt
from pyproj import Transformer

# NOTE: this silences *all* Python warnings for the rest of the session.
warnings.filterwarnings('ignore')

# The base folder for the dataset; modify this if you are not running this notebook
# in the same folder where you checked out the repo.
dataset_folder = '.'

# The folder where the PDF files live
pdf_folder = os.path.join(dataset_folder,'pdfs')

# The folder where geojson files live
annotation_folder = os.path.join(dataset_folder,'annotations')

# The folder where images live
image_folder = os.path.join(dataset_folder,'images')

# The .csv file that indexes all the papers
index_file = os.path.join(dataset_folder,'index.csv')

# Fail early - and say which path is missing - if the dataset layout is not
# what the rest of the notebook expects.
assert os.path.isdir(pdf_folder), \
    'PDF folder not found: {}'.format(pdf_folder)
assert os.path.isdir(annotation_folder), \
    'Annotation folder not found: {}'.format(annotation_folder)
assert os.path.isdir(image_folder), \
    'Image folder not found: {}'.format(image_folder)
assert os.path.isfile(index_file), \
    'Index file not found: {}'.format(index_file)

## Validate dataset

This section confirms that every file that should be present is actually present.

### Find valid rows in the index file

By convention, these are records with a non-NaN value in the "difficulty" column.

In [None]:
# Read the complete index, including any rows that are not valid records
index_df = pd.read_csv(index_file)
print('Loaded {} rows from the .csv index file'.format(len(index_df)))

# By convention, valid records are the rows with a non-NaN "difficulty" value.
# Dropping the others here keeps the validation and search cells below from
# tripping over rows that were never meant to be used.
dataset_df = index_df[index_df['difficulty'].notna()].reset_index(drop=True)
dataset_records = dataset_df.to_dict('records')
print('Found {} valid rows (non-NaN "difficulty")'.format(len(dataset_records)))

### Enumerate data and annotation files

In [None]:
# PDF files: matched on the literal (case-sensitive) '.pdf' suffix
pdf_files_relative = [fn for fn in os.listdir(pdf_folder) if fn.endswith('.pdf')]

# Annotation files: keep the raw folder listing, then narrow it down to .geojson
annotation_files_relative = os.listdir(annotation_folder)
geojson_files_relative = list(
    filter(lambda fn: fn.endswith('.geojson'), annotation_files_relative))

# Image files: matched on the lower-cased extension, so '.PNG' counts too
image_extensions = ('.png','.jpg','.jpeg')
image_files_relative = [
    fn for fn in os.listdir(image_folder)
    if os.path.splitext(fn)[-1].lower() in image_extensions
]

print('Enumerated {} PDF files, {} .geojson files, and {} image files'.format(
    *(len(files) for files in
      (pdf_files_relative, geojson_files_relative, image_files_relative))))

### Validate data values

In [None]:
# Build sets of the enumerated filenames once, so that each membership test in
# the loop below is a hash lookup rather than a linear scan of a list.
pdf_file_set = set(pdf_files_relative)
geojson_file_set = set(geojson_files_relative)
image_file_set = set(image_files_relative)

for i_record,r in enumerate(dataset_records):

    # Every record needs a non-empty string ID
    record_id = r['record_id']
    assert isinstance(record_id,str) and len(record_id) > 0, \
        'Row {} has an invalid record_id: {!r}'.format(i_record,record_id)

    # Verify that the PDF file is present
    pdf_file_relative = r['pdf_filename']
    assert pdf_file_relative in pdf_file_set, \
        'Record {}: PDF file {!r} not found in {}'.format(
            record_id,pdf_file_relative,pdf_folder)

    # Verify that the .geojson file is present
    geojson_file_relative = r['geojson_filename']
    assert geojson_file_relative in geojson_file_set, \
        'Record {}: .geojson file {!r} not found in {}'.format(
            record_id,geojson_file_relative,annotation_folder)

    # Verify that the image file is present
    image_file_relative = r['image_filename']
    assert image_file_relative in image_file_set, \
        'Record {}: image file {!r} not found in {}'.format(
            record_id,image_file_relative,image_folder)

    # Difficulty must lie in [0,1]; a NaN value fails this comparison too
    assert 0 <= r['difficulty'] <= 1.0, \
        'Record {}: difficulty {!r} is outside [0,1]'.format(
            record_id,r['difficulty'])

## Rendering functions

In [None]:
# Build a flat "Provider.TileSet" -> tile provider lookup covering everything
# contextily exposes.  Only one basemap is used in this notebook, but having
# the full table around makes it easy to experiment with others.
providers = cx.providers.keys()

# Walk provider -> tile set; entries that are not TileProvider objects are skipped.
tile_set_name_to_tile_set_object = {
    '{}.{}'.format(provider,tile_set): tile_set_object
    for provider in providers if provider != ''
    for tile_set,tile_set_object in cx.providers[provider].items()
    if isinstance(tile_set_object,xyzservices.lib.TileProvider)
}

print('Enumerated {} tile sets'.format(len(tile_set_name_to_tile_set_object)))
    
# Basemaps we like; render_map() draws from this list when asked to choose
# a tile set at random.
preferred_tile_set_names = [
    'OpenStreetMap.Mapnik',
    'OpenStreetMap.HOT',
    'USGS.USTopo',
    'USGS.USImagery',
    'USGS.USImageryTopo',
    'TopPlusOpen.Color',
    'TopPlusOpen.Grey',
]

# Every preferred name has to be one of the tile sets enumerated above
assert all(name in tile_set_name_to_tile_set_object
           for name in preferred_tile_set_names)

preferred_tile_set_name_to_tile_set_object = {
    name: tile_set_name_to_tile_set_object[name]
    for name in preferred_tile_set_names
}

# Palette of n_colors entries sampled at evenly spaced positions along the
# viridis colormap; colors[0] is the default outline color in render_map().
n_colors = 20
colors = plt.cm.viridis(np.linspace(0.0, 1.0, num=n_colors))

# (name, matplotlib linestyle) pairs.  The first four are matplotlib's named
# styles; the rest are (offset, (on, off, ...)) dash specifications.
linestyle_str = [
    ('solid',                 'solid'),      # Same as (0, ()) or '-'
    ('dotted',                'dotted'),     # Same as (0, (1, 1)) or ':'
    ('dashed',                'dashed'),     # Same as '--'
    ('dashdot',               'dashdot'),    # Same as '-.'
    ('densely dotted',        (0, (1, 1))),
    ('long dash with offset', (5, (10, 3))),
    ('densely dashed',        (0, (5, 1))),
    ('dashdotted',            (0, (3, 5, 1, 5))),
    ('densely dashdotted',    (0, (3, 1, 1, 1))),
    ('dashdotdotted',         (0, (3, 5, 1, 5, 1, 5))),
    ('densely dashdotdotted', (0, (3, 1, 1, 1, 1, 1))),
]

# Name -> style lookup built directly from the pairs above
linestyle_name_to_style = dict(linestyle_str)

def _expand_limits(limits, expansion):
    """
    Scale the axis interval [limits] about its center by the factor [expansion],
    returning the new interval as [center - half_size, center + half_size].
    """

    center = (limits[0] + limits[1]) / 2.0
    half_size = abs(limits[0] - limits[1]) * expansion / 2.0
    return [center - half_size, center + half_size]


def render_map(geometry,
               tile_set_name='OpenStreetMap.HOT',
               linewidth=2,
               edgecolor=colors[0],
               linestyle='solid',
               expansion=1.5):
    """
    Render the GeoDataFrame [geometry] onto a contextily basemap.

    Args:
        geometry: the GeoDataFrame to draw (outlines only, no fill).  If it has
            no CRS, EPSG:4326 is assumed and a warning is printed.
        tile_set_name: a key into tile_set_name_to_tile_set_object, or None to
            pick one at random from preferred_tile_set_names.
        linewidth: passed through to GeoDataFrame.plot()
        edgecolor: passed through to GeoDataFrame.plot()
        linestyle: passed through to GeoDataFrame.plot()
        expansion: factor by which the axis limits chosen by the plot are
            scaled about their center before the basemap is added.

    Returns:
        A dict with the keys:

        'fig', 'ax': the matplotlib figure and axes that were drawn
        'tile_set_name': the name of the tile set that was actually used
        'min_rendered_lat_lon', 'max_rendered_lat_lon': the minimum and maximum
            corners of the final axis limits, passed through
            plot_to_geometry_transformer
        'plot_to_geometry_transformer', 'geometry_to_plot_transformer': pyproj
            transformers between the plot CRS and the CRS [geometry] had on input

        NOTE(review): the corner tuples are stored exactly as returned by
        transformer.transform(x, y); whether they are really ordered (lat, lon)
        depends on the axis order pyproj applies for the two CRSs - confirm
        before relying on the key names.

    Raises:
        KeyError: if [tile_set_name] is not a known tile set
    """

    # plot_crs = 'EPSG:3857'
    plot_crs = 'EPSG:4326'

    # We'll assume this CRS if [geometry] has no specified CRS.
    default_geometry_crs = 'EPSG:4326'

    if geometry.crs is None:
        print('Warning: no CRS specified with geometry, assuming {}'.format(default_geometry_crs))
        geometry = geometry.set_crs(default_geometry_crs, allow_override=True)

    # Remember the CRS the geometry arrived in *before* reprojecting it, so the
    # transformers returned below really map between plot coordinates and the
    # caller's coordinates.
    geometry_crs = geometry.crs

    # Compare CRS objects rather than their string representations: the same
    # CRS may be spelled in more than one way.
    if geometry.crs != plot_crs:
        geometry = geometry.to_crs(plot_crs)

    # Resolve the basemap before creating a figure, so that an unknown name
    # doesn't leave an empty figure behind.
    if tile_set_name is None:
        tile_set_name = random.choice(preferred_tile_set_names)

    if tile_set_name not in tile_set_name_to_tile_set_object:
        raise KeyError('Unknown tile set name: {}'.format(tile_set_name))
    tile_set_object = tile_set_name_to_tile_set_object[tile_set_name]

    fig, ax = plt.subplots(figsize=(10,10), dpi=100, facecolor='w', edgecolor='k')
    geometry.plot(edgecolor=edgecolor, facecolor='none', linewidth=linewidth, linestyle=linestyle, ax=ax)

    # Grow the view around the geometry so some surrounding basemap is visible.
    # Both new intervals are computed before either one is applied.
    xlim_new = _expand_limits(ax.get_xlim(), expansion)
    ylim_new = _expand_limits(ax.get_ylim(), expansion)
    ax.set_xlim(xlim_new)
    ax.set_ylim(ylim_new)

    cx.add_basemap(ax, crs=geometry.crs.to_string(), source=tile_set_object, attribution=False)

    # These are in the plot coordinate system; they are read back after adding
    # the basemap in case that call adjusted the limits.
    xlims_rendered = ax.get_xlim()
    ylims_rendered = ax.get_ylim()

    min_coord_rendered = (min(xlims_rendered),min(ylims_rendered))
    max_coord_rendered = (max(xlims_rendered),max(ylims_rendered))

    plot_to_geometry_transformer = Transformer.from_crs(plot_crs,geometry_crs)
    geometry_to_plot_transformer = Transformer.from_crs(geometry_crs,plot_crs)

    min_rendered_lat_lon = plot_to_geometry_transformer.transform(*min_coord_rendered)
    max_rendered_lat_lon = plot_to_geometry_transformer.transform(*max_coord_rendered)

    return {
        'fig': fig,
        'ax': ax,
        'tile_set_name': tile_set_name,
        'min_rendered_lat_lon': min_rendered_lat_lon,
        'max_rendered_lat_lon': max_rendered_lat_lon,
        'plot_to_geometry_transformer': plot_to_geometry_transformer,
        'geometry_to_plot_transformer': geometry_to_plot_transformer,
    }

## Render a sample file

In [None]:
# Pick a fixed index into dataset_records.  Note that when all cells are run in
# order, the title search in the next cell reassigns i_record.
i_record = 1

In [None]:
# ...or search by title
title_token = 'hawaiian hoary bat'

# Index of the first record whose title contains [title_token] (the title is
# lower-cased before matching).  Non-string titles, e.g. the NaN that pandas
# produces for a blank cell, are skipped rather than crashing on .lower().
matching_i_record = next(
    (i for i,record in enumerate(dataset_records)
     if isinstance(record['paper_title'],str)
     and title_token in record['paper_title'].lower()),
    None)

assert matching_i_record is not None, \
    'No paper title contains "{}"'.format(title_token)
i_record = matching_i_record

In [None]:
# Look up the selected record and report which figure is about to be rendered
r = dataset_records[i_record]
print('Paper title:\n{}\n\nFigure number: {}'.format(
    r['paper_title'],r['figure_number']))

# Absolute paths to this record's annotation and image; both must exist
geojson_file_abs = os.path.join(annotation_folder,r['geojson_filename'])
image_file_abs = os.path.join(image_folder,r['image_filename'])
assert all(os.path.isfile(fn) for fn in (geojson_file_abs,image_file_abs))

# Interactive mode on, then load the annotation and draw it over the default basemap
plt.ion()
geometry = gpd.read_file(geojson_file_abs, driver='GeoJSON')
plot_result = render_map(geometry)