# Plot data on a choropleth map

In [None]:
# --- Notebook configuration ------------------------------------------------

# Path to the Mongo document (exported as JSON) that will be plotted.
json_file = "demo/gfdl_cm3.json"

# Directory holding the <granularity>.shp shapefiles.
shapefile_dir = "../graphs/shapefiles/"

# Size of each output figure, in pixels.
plot_width, plot_height = 1200, 800

# EPSG code of the coordinate reference system used for plotting
# (4326 gives lat/lon hover labels; 3085 is another one worth trying).
projection = 4326

## Import packages

In [None]:
import os
import numpy as np
from shapely.geometry.polygon import Polygon
from shapely.geometry.multipolygon import MultiPolygon
from geopandas import read_file
import pandas as pd
import json

from bokeh.io import show
from bokeh.models import LogColorMapper
from bokeh.plotting import figure, output_file, save

In [None]:
# color palettes

from bokeh.palettes import Purples256
from bokeh.palettes import Blues256
from bokeh.palettes import Greens256
from bokeh.palettes import Oranges256
from bokeh.palettes import Reds256
from bokeh.palettes import Greys256
from bokeh.palettes import Blues256
from bokeh.palettes import Inferno256
from bokeh.palettes import Magma256
from bokeh.palettes import Plasma256
from bokeh.palettes import Viridis256
from bokeh.palettes import Cividis256
from bokeh.palettes import Turbo256

from bokeh.palettes import PuOr11
from bokeh.palettes import BrBG11
from bokeh.palettes import PRGn11
from bokeh.palettes import PiYG11
from bokeh.palettes import RdBu11
from bokeh.palettes import RdGy11
from bokeh.palettes import RdYlBu11
from bokeh.palettes import Spectral11
from bokeh.palettes import RdYlGn11

from bokeh.palettes import YlGn9
from bokeh.palettes import YlGnBu9
from bokeh.palettes import GnBu9
from bokeh.palettes import BuGn9
from bokeh.palettes import PuBuGn9
from bokeh.palettes import PuBu9
from bokeh.palettes import BuPu9
from bokeh.palettes import RdPu9
from bokeh.palettes import PuRd9
from bokeh.palettes import OrRd9
from bokeh.palettes import YlOrRd9
from bokeh.palettes import YlOrBr9

## GeoPandas helpers for extracting polygon coordinates
### References

https://automating-gis-processes.github.io/2016/Lesson5-interactive-map-bokeh.html

https://discourse.bokeh.org/t/mapping-europe-with-bokeh-using-geopandas-and-handling-multipolygons/2571

In [None]:
def get_xy_coords(geometry, coord_type):
    """
    Pull one axis out of a coordinate sequence as a plain Python list.

    ``geometry`` is any object exposing ``coords.xy`` (for example a polygon's
    exterior ring). ``coord_type`` picks the axis: ``'x'`` for the first
    component, ``'y'`` for the second. Any other value returns ``None``.
    """
    for axis_index, axis_name in enumerate(('x', 'y')):
        if coord_type == axis_name:
            return list(geometry.coords.xy[axis_index])
    return None


def get_poly_coords(geometry, coord_type):
    """
    Return the ``'x'`` or ``'y'`` coordinates of a polygon's exterior ring.

    Only the exterior is read; interior rings (holes) are ignored.
    """
    exterior_ring = geometry.exterior
    return get_xy_coords(exterior_ring, coord_type)


def multi_geom_handler(multi_geometry, coord_type):
    """
    Flatten a MultiPolygon into one coordinate array suitable for Bokeh.

    The exterior-ring coordinates (``coord_type`` is ``'x'`` or ``'y'``) of
    every member polygon are concatenated, with ``np.nan`` appended after each
    part. Bokeh treats the NaNs as separators between sub-patches. See the
    (open) Bokeh issue about multi-geometries:
    https://github.com/bokeh/bokeh/issues/2321

    Returns:
        numpy.ndarray: the concatenated coordinates, with one trailing NaN per
        part. It is empty when the multi-geometry has no parts.
    """
    # Shapely >= 2.0 no longer allows iterating a multi-part geometry
    # directly; ``.geoms`` works on both 1.x and 2.x. Fall back to the object
    # itself for duck-typed inputs that are plain iterables of polygons.
    parts = getattr(multi_geometry, "geoms", multi_geometry)
    all_poly_coords = [
        np.append(get_poly_coords(part, coord_type), np.nan) for part in parts
    ]
    if not all_poly_coords:
        # np.concatenate raises ValueError on an empty list.
        return np.array([], dtype=float)
    return np.concatenate(all_poly_coords)


def get_coords(row, coord_type):
    """
    Return the ``'x'`` or ``'y'`` exterior coordinates of ``row['geometry']``.

    Intended for use with ``GeoDataFrame.apply(get_coords, coord_type=..., axis=1)``.

    * ``Polygon`` -> list of the exterior ring's coordinates.
    * ``MultiPolygon`` -> NaN-separated array covering every part.
    * Any other geometry type (or a missing geometry) -> ``None``.
    """
    geometry = row['geometry']
    # isinstance rather than an exact type() comparison, so subclasses are accepted too.
    if isinstance(geometry, Polygon):
        return get_poly_coords(geometry, coord_type)
    if isinstance(geometry, MultiPolygon):
        return multi_geom_handler(geometry, coord_type)
    return None

## Merge the data with the shapefile and plot it
### References

https://docs.bokeh.org/en/latest/docs/gallery/texas.html

In [None]:
def plot_mongo_doc(data, shapefile_dir=".", palette=None, projection=4326, plot_width=1200, plot_height=800, show_fig=False, save_fig=True):
    """
    Draw one choropleth map per dataset contained in a Mongo document.

    For every dataset in ``data['payload']`` that specifies a ``granularity``:

    1. The dataset's ``data`` mapping becomes a ``<dataset>_value`` column, keyed
       by ``ID`` (the mapping's keys).
    2. That column is merged onto ``<shapefile_dir>/<granularity>.shp`` on the
       ``ID`` column.
    3. The result is drawn as filled patches coloured through a ``LogColorMapper``.

    Args:
        data: parsed document. ``data['year']`` is used in titles and file
            names, and ``data['payload'][dataset]`` provides ``granularity``
            and ``data``.
        shapefile_dir: directory containing the ``<granularity>.shp`` files.
        palette: colour sequence for the log colour mapper. ``None`` (the
            default) uses a reversed copy of ``Blues256``.
        projection: EPSG code the shapefile geometries are re-projected to.
        plot_width: figure width in pixels.
        plot_height: figure height in pixels.
        show_fig: display each figure with ``bokeh.io.show``.
        save_fig: save each figure to ``<year>_<dataset>.html`` in the current
            working directory.

    Returns:
        dict mapping each plotted dataset name to its Bokeh figure. Skipped
        datasets are omitted.
    """
    if palette is None:
        # The previous default ``Blues256.reverse()`` evaluated to None
        # (list.reverse works in place) and permanently flipped the shared
        # bokeh palette. Build a reversed copy instead.
        palette = list(reversed(Blues256))

    instance_col_name = 'ID'  # join key between the payload data and the shapefile
    tools = "pan,wheel_zoom,reset,hover,save,box_zoom"
    # For EPSG:4326 the hover shows ($y, $x) labelled as (Lat, Lon); other
    # projections show raw (x, y).
    coords_tooltip = (
        ("(Lat, Lon)", "($y, $x)")
        if projection == 4326
        else ("(x, y)", "($x, $y)")
    )

    figures = {}
    for dataset, dataset_doc in data['payload'].items():
        # .get() so an entry without the key is skipped instead of raising KeyError.
        granularity = dataset_doc.get('granularity')
        if not granularity:
            print(f"skipping {dataset} (does not have a granularity specified)")
            continue
        print(f"plotting {dataset} (granularity: {granularity})")
        year = data['year']
        value_col = f"{dataset}_value"

        # --- locate the shapefile ---------------------------------------------
        # os.path.join avoids a doubled "/" when shapefile_dir ends with one.
        shapefile_path = os.path.join(shapefile_dir, f"{granularity}.shp")
        if not os.path.exists(shapefile_path):
            print(f"{shapefile_path} not found, skipping")
            continue

        # --- merge data with the shapefile ------------------------------------
        values = pd.DataFrame.from_dict(
            dataset_doc['data'],
            orient='index',
            columns=[value_col],
        )
        values[instance_col_name] = values.index
        # NOTE(review): the payload keys come from JSON, so they are strings.
        # If the shapefile's ID column is numeric, this merge will fail or
        # match nothing — confirm the dtypes agree.
        geography = read_file(shapefile_path).to_crs(epsg=projection)
        geography = geography.merge(values, on=instance_col_name)
        if geography.empty:
            # Nothing to draw, and apply(axis=1) on an empty frame cannot be
            # assigned to a single column.
            print(f"no {instance_col_name} values of {dataset} match {shapefile_path}, skipping")
            continue
        geography['x'] = geography.apply(get_coords, coord_type='x', axis=1)
        geography['y'] = geography.apply(get_coords, coord_type='y', axis=1)

        # --- create figure ----------------------------------------------------
        plot_data = dict(
            x=geography['x'].tolist(),
            y=geography['y'].tolist(),
            name=geography[instance_col_name].tolist(),
            value=geography[value_col].tolist(),
        )

        fig = figure(
            title=f"USA {dataset} ({year})",
            tools=tools,
            # width/height work in Bokeh 2.x and 3.x; the plot_width /
            # plot_height keywords were removed in Bokeh 3.0.
            width=plot_width,
            height=plot_height,
            x_axis_location=None,
            y_axis_location=None,
            tooltips=[("Name", "@name"), ("Value", "@value"), coords_tooltip],
        )
        fig.grid.grid_line_color = None
        fig.hover.point_policy = "follow_mouse"

        # Fresh mapper per figure. No low/high is given, so the colour range
        # comes from this figure's own data.
        color_mapper = LogColorMapper(palette=palette)
        fig.patches(
            'x',
            'y',
            source=plot_data,
            fill_color={'field': 'value', 'transform': color_mapper},
            fill_alpha=0.7,
            line_color="white",
            line_width=0.5,
        )

        if save_fig:
            output_file(f"{year}_{dataset}.html")
            save(fig)
        if show_fig:
            show(fig)
        figures[dataset] = fig

    return figures

## Load and plot the data

In [None]:
# Load the Mongo document (JSON is UTF-8 by specification).
with open(json_file, encoding="utf-8") as f:
    data = json.load(f)

# Use a reversed *copy* of the palette. Calling .reverse() on the imported
# bokeh palette mutates it in place (flipping it again on every re-run of this
# cell) and would fail outright if the palette is an immutable tuple.
palette = list(reversed(PuBu9))

# Pass the configured figure size too; previously plot_width/plot_height
# were defined above but silently ignored.
plot_mongo_doc(
    data,
    palette=palette,
    shapefile_dir=shapefile_dir,
    projection=projection,
    plot_width=plot_width,
    plot_height=plot_height,
    show_fig=True,
)