# Visualize module

## Import

In [12]:
from bokeh.io import output_notebook, export_svgs, export_png
from bokeh.models import (ColumnDataSource, Label, LabelSet, Range1d, BasicTicker, LogTicker, 
                          HoverTool, ColorBar, Panel, Tabs, LinearColorMapper, Slope, Plot, Toggle,
                         DataRange1d, LinearAxis, Grid, Legend, LegendItem, Span, BoxAnnotation, ToolbarPanel,
                          Toolbar, MultiLine, Circle, HoverTool, TapTool, BoxSelectTool, WheelZoomTool, Text,
                         ToolbarBox, PanTool, BoxZoomTool, ResetTool, FactorRange)

from bokeh.plotting import ColumnDataSource, figure, output_file, show, save, from_networkx
from bokeh.transform import linear_cmap
from bokeh.models.mappers import LinearColorMapper, LogColorMapper
from bokeh.layouts import gridplot
from bokeh.transform import linear_cmap
from bokeh.util.hex import hexbin
from bokeh.models.graphs import NodesAndLinkedEdges, EdgesAndLinkedNodes
from bokeh.palettes import Spectral4

from sklearn import datasets
from sklearn import tree
from sklearn.tree import _tree
from sklearn.linear_model import LinearRegression

import folium
import json
import numpy as np
import networkx as nx
import pandas as pd

In [4]:
# Configure Bokeh so that show() renders its output inline in the notebook.
output_notebook()

In [None]:
def hex_to_RGB(hex):
    ''' Convert a hex color string into its three RGB components.

    "#FFFFFF" -> [255, 255, 255]

    The leading character (the number sign) is skipped; the three
    two-character pairs that follow are parsed as base-16 integers.
    '''
    channels = []
    for offset in (1, 3, 5):
        # Base 16 turns each two-digit pair into an integer
        channels.append(int(hex[offset:offset + 2], 16))
    return channels


def RGB_to_hex(RGB):
    ''' Convert RGB components into a lowercase hex color string.

    [255, 255, 255] -> "#ffffff"

    Every component is truncated to an integer first; values below 16
    are left-padded with a "0" so each channel takes two characters.
    '''
    pieces = ["#"]
    for component in RGB:
        value = int(component)
        digits = format(value, "x")
        if value < 16:
            digits = "0" + digits
        pieces.append(digits)
    return "".join(pieces)

def color_dict(gradient):
    ''' Build the color dictionary used by the gradient helpers.

    gradient -- a list of [r, g, b] sub-lists

    Returns a dict with the keys "hex", "r", "g" and "b", each holding
    one entry per color of the gradient (hex strings come from
    RGB_to_hex, the channel lists are taken straight from the input).
    '''
    hex_codes = [RGB_to_hex(rgb) for rgb in gradient]
    channels = {
        channel: [rgb[position] for rgb in gradient]
        for position, channel in enumerate(("r", "g", "b"))
    }
    return {"hex": hex_codes, **channels}

def linear_gradient(start_hex, finish_hex="#FFFFFF", n=10):
    ''' Return a gradient of (n) colors between two hex colors.

    start_hex and finish_hex should be the full six-digit color
    string, including the number sign ("#FFFFFF"). The result is the
    dictionary produced by color_dict(); its first entry is always the
    start color.
    '''
    start_rgb = hex_to_RGB(start_hex)
    finish_rgb = hex_to_RGB(finish_hex)

    # The start color is always the first output color
    steps = [start_rgb]
    for step in range(1, n):
        # Fraction of the way from start to finish for this step
        fraction = float(step) / (n - 1)
        steps.append([
            int(lo + fraction * (hi - lo))
            for lo, hi in zip(start_rgb, finish_rgb)
        ])

    return color_dict(steps)

def polylinear_gradient(colors, n=256):
    ''' Return linear gradients chained between all sequential pairs of colors.

    colors -- list of hex color strings ("#FFFFFF")
    n      -- approximate number of output colors

    Each of the len(colors) - 1 segments is built with
    int(n / (len(colors) - 1)) colors, and every segment after the first
    drops its first color (a duplicate of the previous segment's last
    color), so the result can be slightly shorter than n.

    A single color yields n copies of that color instead of failing with
    a division by zero.

    Returns the dictionary defined by color_dict().
    '''
    if len(colors) == 1:
        # Nothing to blend with: produce a constant palette
        return color_dict([hex_to_RGB(colors[0]) for _ in range(n)])

    # The number of colors per individual linear gradient
    n_out = int(float(n) / (len(colors) - 1))
    gradient_dict = linear_gradient(colors[0], colors[1], n_out)

    for col in range(1, len(colors) - 1):
        segment = linear_gradient(colors[col], colors[col + 1], n_out)
        for k in ("hex", "r", "g", "b"):
            # Exclude first point to avoid duplicates
            gradient_dict[k] += segment[k][1:]

    return gradient_dict

In [12]:
class Visualizer:
    
    def __init__(self, name=None, x=(None, None), y=(None, None), data=None, visualize=True, file_out='plots/',
                 colors=('#d0543d', '#ffffff', '#6b9ed6', '#808080', 'black', '#f09a3c'),
                tools='pan,wheel_zoom,box_zoom,reset', active_scroll='wheel_zoom', sizing_mode='stretch_both',
                x_range=None, y_range=None, x_axis_type=None, title=None):
        ''' Store the plot inputs and create the underlying Bokeh figure.

        name       -- plot name; used for the output file name and the HTML page title
        x, y       -- (values, name) pairs, unpacked into self.x / self.name_x
                      and self.y / self.name_y
        data       -- extra data; what it must contain differs per plotting method
        visualize  -- when True the plot is shown, otherwise it is only saved
        file_out   -- output directory; a "/" is always appended
        colors     -- exactly six colors; the first three build self.custom_gradient
        tools, active_scroll, sizing_mode -- forwarded to figure()
        x_range, y_range, x_axis_type     -- forwarded to figure() only when not None
        title      -- figure title
        '''
        self.tools = tools
        self.active_scroll = active_scroll
        self.sizing_mode = sizing_mode
        self.x_range = x_range
        self.y_range = y_range
        self.x_axis_type = x_axis_type
        self.title = title

        self.name = name
        self.x, self.name_x = x
        self.y, self.name_y = y
        self.data = data

        figure_kwargs = dict(title=self.title, margin=(5, 5, 5, 5), sizing_mode=self.sizing_mode,
                             tools=self.tools, active_scroll=self.active_scroll)
        # Only forward the options the caller actually set, so figure() keeps its
        # own defaults for the rest. Previously x_range was silently dropped
        # whenever y_range was given.
        optional = (('x_range', x_range), ('y_range', y_range), ('x_axis_type', x_axis_type))
        figure_kwargs.update({key: value for key, value in optional if value is not None})
        if x_axis_type is not None and y_range is not None:
            # NOTE(review): the original code created this combination without a
            # title; kept as-is because it is unclear whether that was deliberate.
            figure_kwargs['title'] = None
        self.plot = figure(**figure_kwargs)

        self.visualize = visualize
        self.file_out = file_out + '/'
        self.colors = colors

        # Unpacking six names enforces the expected palette size
        c1, c2, c3, _c4, _c5, _c6 = self.colors

        self.custom_gradient = tuple(polylinear_gradient([c1, c2, c3])['hex'])

        self.fonter = '20px'
        self.plot.axis.major_label_text_font_size = self.fonter
#         self.plot.axis.major_label_text_font_size='20pt'

    def get_colormap(self, colors):
        ''' Return the hex palette, as a tuple, of the polylinear gradient through colors. '''
        gradient = polylinear_gradient(colors)
        hex_palette = tuple(gradient['hex'])
        return hex_palette
        
    def plot_to_html(self, plot_alt=None):
        ''' Write the plot to "<file_out><name>.html" and wrap it in a folium IFrame.

        plot_alt -- optional layout to export instead of self.plot

        The file name is self.name lower-cased with spaces replaced by
        underscores. The plot is shown when self.visualize is True and only
        saved otherwise; the resulting HTML file is then read back, wrapped in
        a folium IFrame and written to the same path.
        '''
        simple_name = self.name.replace(' ', '_').lower()
        out_name = f'{self.file_out}{simple_name}.html'
        output_file(out_name, title=self.name)

        if plot_alt is None:
            layout = gridplot([self.plot], sizing_mode='stretch_both',
                              ncols=1, toolbar_options=dict(logo=None))
        else:
            layout = plot_alt

        if self.visualize:
            show(layout)
        else:
            save(layout)

        # Close the file deterministically and read it with an explicit encoding
        # (assumes Bokeh writes the HTML as UTF-8 - see its save() documentation)
        with open(out_name, 'r', encoding='utf-8') as html_file:
            html = html_file.read()

        iframe = folium.IFrame(html=html, width='100%', height='100%')
        iframe.save(out_name)
        
#         self.plot.output_backend = 'svg'
#         out_name = f'{self.file_out}{simple_name}.svg'
#         export_svgs(self.plot, filename=out_name)

#         out_name = f'{self.file_out}{simple_name}.png'
#         if plot_alt is None:
#             export_png(gridded, filename=out_name)
#         else:
#             export_png(gridded, filename=plot_alt)
        
    def simple(self):
        ''' Draw self.y against self.x as one line and export the plot. '''
        line_style = dict(line_width=2)
        self.plot.line(self.x, self.y, **line_style)
        self.plot_to_html()
    
    def cyclic(self, z):
        ''' Scatter self.x against self.y with marker size and color driven by z.

        z -- (values, name) pair; the values set both the size and the color of
             every marker, the name titles the color bar.

        The plot is exported through plot_to_html().
        '''
        z, name_z = z

        # Blend from the third to the first color of the instance palette
        # (these used to be looked up as undefined free names)
        c1, c3 = self.colors[0], self.colors[2]
        custom_palette = tuple(polylinear_gradient([c3, c1], 256)['hex'])
        color_mapper = LinearColorMapper(palette=custom_palette)

        # Use the instance data; bare x / y are not defined in this scope
        data = dict(x=self.x, y=self.y, z=z)

        self.plot.match_aspect = True
        self.plot.scatter('x', 'y', source=data, size={'field': 'z'}, line_color=None,
                      color={'field': 'z', 'transform': color_mapper}, alpha=0.2)
        self.plot.toolbar.logo = None
        self.plot.xaxis.axis_label, self.plot.yaxis.axis_label = self.name_x, self.name_y
        self.plot.grid.grid_line_color = self.plot.axis.axis_line_color = None
        self.plot.axis.major_tick_line_color = self.plot.axis.minor_tick_line_color = 'Gray'

        # The color bar gets its own mapper with explicit bounds taken from z
        mapper = linear_cmap(field_name='z', palette=custom_palette, low=min(z), high=max(z))

        color_bar = ColorBar(color_mapper=mapper['transform'], height=12, location=(0,0), title=name_z, title_text_font_size=self.fonter,
                            major_tick_line_color=None, minor_tick_line_color=None, orientation="horizontal",
                            major_label_text_font_size=self.fonter)
        self.plot.add_layout(color_bar, 'above')

        self.plot.axis.axis_label_text_font_size = self.fonter

        self.plot_to_html()
    
    def corr_matrix(self, squared=False):
        ''' Draw the correlation matrix of self.data as an image and export it.

        squared -- when True the squared coefficients (r²) are drawn on a 0..1
                   scale; otherwise r is drawn on a -1..1 scale using
                   self.custom_gradient.

        Rows containing missing values are dropped before correlating. The
        hover tooltips are only applied to a HoverTool that is already part
        of the figure's tools.

        Returns the correlation DataFrame; with squared=True the squared
        coefficients are returned. NOTE(review): the original code produced
        this through in-place squaring of a buffer shared with the DataFrame;
        confirm callers expect r² here.
        '''
        corr = self.data.dropna().corr()
        # Label with the columns that actually made it into the correlation
        source_columns = corr.columns.tolist()
        n_cols = len(source_columns)

        # Keep r and r² in separate arrays: squaring in place corrupted both
        # the "r" tooltip and the "r²" tooltip (which became r⁴) in squared mode
        r_values = corr.to_numpy()
        r_squared = r_values ** 2

        if squared:
            c2, c3 = self.colors[1], self.colors[2]
            image = r_squared
            color_mapper = LinearColorMapper(palette=tuple(polylinear_gradient([c2, c3])['hex']), low=0, high=1)
        else:
            image = r_values
            color_mapper = LinearColorMapper(palette=self.custom_gradient, low=-1, high=1)

        # Flattened (row-major) names matching every cell of the image
        yname = [row for row in source_columns for _ in source_columns]
        xname = source_columns * n_cols

        data = dict(image=[image], r=[r_values], image_square=[r_squared], yname=[yname], xname=[xname])

        self.plot.image(source=data, image='image', x=0, y=0, dw=n_cols, dh=n_cols,
                        color_mapper=color_mapper)

        hover = self.plot.select(dict(type=HoverTool))
        hover.tooltips = [('r', "@r{1.111}"), ("r²", "@image_square{1.111}"), ("X", "@xname"), ("Y", "@yname")]

        self.plot.outline_line_color = None
        self.plot.x_range.range_padding = self.plot.y_range.range_padding = 0
        self.plot.grid.grid_line_color = self.plot.axis.axis_line_color = None
        self.plot.axis.major_tick_line_color = self.plot.axis.minor_tick_line_color = None

        # One tick at the center of every cell: 0.5, 1.5, 2.5, ...
        axlist = [position + 0.5 for position in range(n_cols)]
        axdict = dict(zip(axlist, source_columns))

        self.plot.axis.ticker = axlist
        self.plot.axis.major_label_overrides = axdict
        # numpy is already imported by the file; math is not
        self.plot.xaxis.major_label_orientation = np.pi / 4

        color_bar = ColorBar(title='r²' if squared else 'r', color_mapper=color_mapper, ticker=BasicTicker(), title_text_font_style='bold',
                             major_tick_line_color=None, minor_tick_line_color=None, title_text_font_size=self.fonter, major_label_text_font_size=self.fonter,
                             title_text_align='left', label_standoff=12, border_line_color=None, location=(0,0))
        self.plot.add_layout(color_bar, 'right')

        self.plot_to_html()

        return corr ** 2 if squared else corr

    def heatmap(self):
        ''' Draw self.data as a heatmap (or a row of heatmaps) and export it.

        When self.data is a single DataFrame it is drawn on self.plot with
        its columns on the x axis and its index on the y axis, plus a color
        bar. When self.data is a list of DataFrames, one axis-less panel per
        frame is drawn side by side; all panels share the ranges of the
        first one.

        Values are cast to int and colored on a fixed 50..150 scale.

        NOTE(review): `pollutant` is not defined in this method or on the
        instance; it is presumably a notebook-level global - confirm.
        '''
        # Instance palette: third color -> light gray -> first color
        c1, c3 = self.colors[0], self.colors[2]
        color_mapper = LinearColorMapper(palette=tuple(polylinear_gradient([c3, '#f7f7f7', c1])['hex']), low=50, high=150)

        if not isinstance(self.data, list):
            source_columns = [str(col) for col in self.data.columns.tolist()]
            source_rows = self.data.index.tolist()
            matrix = self.data.to_numpy().astype(int)

            # Flattened (row-major) names matching every cell of the image
            yname = [row for row in source_rows for _ in source_columns]
            xname = source_columns * len(source_rows)

            data = dict(image=[matrix], yname=[yname], xname=[xname])

            self.plot.image(source=data, image='image', x=0, y=0, dw=len(source_columns), dh=len(source_rows),
                            color_mapper=color_mapper)

            hover = self.plot.select(dict(type=HoverTool))
            hover.tooltips = [("Deviation", f"@image% of the average {pollutant}"), ("Month", "@xname"), ("Sensor", "@yname")]

            self.plot.outline_line_color = None
            self.plot.x_range.range_padding = self.plot.y_range.range_padding = 0
            self.plot.grid.grid_line_color = self.plot.axis.axis_line_color = None
            self.plot.axis.major_tick_line_color = self.plot.axis.minor_tick_line_color = None

            # One tick at the center of every cell: 0.5, 1.5, 2.5, ...
            for axis, labels in ((self.plot.xaxis, source_columns), (self.plot.yaxis, source_rows)):
                axlist = [position + 0.5 for position in range(len(labels))]
                axis.ticker = axlist
                axis.major_label_overrides = dict(zip(axlist, labels))

            color_bar = ColorBar(title=f'Δ% {pollutant}', color_mapper=color_mapper, ticker=BasicTicker(), title_text_font_style='bold',
                                 major_tick_line_color=None, minor_tick_line_color=None, title_text_font_size=self.fonter,
                                 title_text_align='left', label_standoff=8, border_line_color=None, location=(0,0),
                                major_label_text_font_size=self.fonter)
            self.plot.add_layout(color_bar, 'right')

            self.plot_to_html()
            return

        grid_dfs = []
        for n, df in enumerate(self.data):
            n_columns = len(df.columns.tolist())
            n_rows = len(df.index.tolist())
            matrix = df.to_numpy().astype(int)

            # NOTE(review): panel titles are hard-coded to count up from 2015
            panel_kwargs = dict(title=str(n + 2015), margin=(0, 3, 0, 3),
                                tools=self.tools, active_scroll=self.active_scroll,
                                x_axis_location=None, y_axis_location=None,
                                match_aspect=False, title_location='above')
            if n == 0:
                # Only the first panel carries a toolbar
                p = figure(toolbar_location='left', **panel_kwargs)
                p.toolbar.logo = None
            else:
                # Later panels pan/zoom together with the first one
                p = figure(toolbar_location=None, x_range=grid_dfs[0].x_range,
                           y_range=grid_dfs[0].y_range, **panel_kwargs)

            p.image(image=[matrix], x=0, y=0, dw=n_columns, dh=n_rows, color_mapper=color_mapper)

            p.outline_line_color = None
            p.x_range.range_padding = p.y_range.range_padding = 0
            p.grid.grid_line_color = p.axis.axis_line_color = None
            p.axis.major_tick_line_color = p.axis.minor_tick_line_color = None
            p.axis.major_label_text_color = None
            p.title.text_font_size = self.fonter

            grid_dfs.append(p)

        gp = gridplot(grid_dfs, ncols=len(grid_dfs), sizing_mode='stretch_both', merge_tools=True, toolbar_options=dict(logo=None))
        self.plot_to_html(gp)

    def vertical_bars(self):
        ''' Draw three overlaid vertical bar series and export the plot.

        The widest bars show self.y, the two narrower ones show
        self.data[0] and self.data[1], labelled with self.name_y[0] and
        self.name_y[1].

        NOTE(review): `pollutant` is not defined in this method or on the
        instance; it is presumably a notebook-level global - confirm.
        '''
        # Take the bar colors from the instance palette
        c1, c3 = self.colors[0], self.colors[2]

        self.plot.vbar(x=self.x, top=self.y, width=0.85, color=c3, alpha=1, line_color=None, legend_label=f'Average {pollutant} (µg/m³)')
        self.plot.vbar(x=self.x, top=self.data[0], width=0.5, color='#e3e3e3', alpha=1, line_color=None, legend_label=self.name_y[0])
        self.plot.vbar(x=self.x, top=self.data[1], width=0.25, color=c1, alpha=1, line_color=None, legend_label=self.name_y[1])

        self.plot.y_range.start = 0
        self.plot.x_range.range_padding = 0.1
        self.plot.xaxis.major_label_orientation = -np.pi/2
        self.plot.xgrid.grid_line_color = None
        self.plot.axis.axis_line_color = None
        # Minor ticks stay visible, major ticks are hidden
        self.plot.axis.minor_tick_line_color = 'lightgrey'
        self.plot.axis.major_tick_line_color = None
        self.plot.legend.label_text_font_size = self.fonter
        self.plot.axis.axis_label_text_font_size = self.fonter
        self.plot.axis.major_label_text_font_size = self.fonter

        self.plot_to_html()
        
    def sorted_bars(self, title=None, xs=None, ys=None, title2=None):
        ''' Draw horizontal bars for self.x / self.y and export the plot.

        title  -- caption for the bars; "$" separates lines. None adds no caption.
        xs, ys -- optional second series, drawn mirrored on the negative x side
                  (xs is negated, so it must support unary minus)
        title2 -- caption for the second series, same format as title

        With a second series the x axis gets fixed ticks from -1.0 to 0.8 in
        steps of 0.2, labelled with their absolute values.
        '''
        c1, c2, c3 = self.colors[:3]

        def add_caption(text, x, align):
            # Stack one screen-positioned label per "$"-separated line, top to bottom
            if text is None:
                return
            lines = text.split('$')
            start_high = len(lines) * 30 + 10
            for line in lines:
                annot = Label(x=x, y=start_high, x_units='data', y_units='screen', text_font_style='bold',
                              text=line, text_font_size=self.fonter, render_mode='css', background_fill_color='white',
                              background_fill_alpha=0.5, text_align=align)
                self.plot.add_layout(annot)
                start_high -= 30

        n_colors = len(self.y) + 2
        if xs is not None:
            custom_gradient = tuple(polylinear_gradient([c2, c3], n_colors)['hex'])[2:]
        else:
            custom_gradient = tuple(polylinear_gradient([c1, '#f2f2f2', c3], n_colors)['hex'])
            if len(self.x) % 2 == 0:
                custom_gradient = custom_gradient[1:]

        source = dict(x=self.x, y=self.y, color=list(custom_gradient))

        self.plot.plot_height = 500
        bar_width = self.plot.plot_height/len(self.y) * 0.7
        self.plot.hbar_stack(y='y', stackers='x', source=source, color='color', line_width=bar_width)

        add_caption(title, max(self.x), 'right')

        if xs is not None:
            custom_gradient = tuple(polylinear_gradient([c2, c1], len(ys) + 2)['hex'])[2:]
            source2 = dict(x=-xs, y=ys, color=list(custom_gradient))

            add_caption(title2, -max(xs) + 0.42, 'left')

            axlist = [(i / 100) for i in range(-100, 100, 20)]
            axdict = dict(zip(axlist, [str(abs(i)) for i in axlist]))

            self.plot.xaxis.ticker = axlist
            self.plot.xaxis.major_label_overrides = axdict

            self.plot.hbar_stack(y='y', stackers='x', source=source2, color='color', line_width=bar_width)

        self.plot.ygrid.grid_line_color = None
        self.plot.axis.axis_line_color = None
        self.plot.yaxis.major_tick_line_color = self.plot.yaxis.minor_tick_line_color = None
        self.plot.xaxis.major_tick_line_color = self.plot.xaxis.minor_tick_line_color = 'gray'

        self.plot_to_html()
        
    def linked_histograms(self, color1, color2, title=None):
        ''' Build a scatter plot of self.x vs self.y with linked marginal histograms.

        color1 -- color of the scatter markers and the top (x) histogram
        color2 -- color of the right (y) histogram
        title  -- title shown below the scatter plot

        Two lines are overlaid on the scatter: a dashed red least-squares fit
        of y on x, and a black line fitted through evenly spaced points
        between the minima and maxima of both variables. Both histograms use
        50 bins and share their data axis range with the scatter plot.

        Returns the gridplot layout; nothing is exported here.
        '''
        p = figure(tools='pan,wheel_zoom,box_zoom,reset', sizing_mode='stretch_both', min_border=0,
                   match_aspect=False, toolbar_location="left", x_axis_location=None, y_axis_location=None,
                   title=title, title_location='below', active_scroll='wheel_zoom')
        p.toolbar.logo = None

        p.scatter(self.x, self.y, size=6, color=color1, alpha=0.3, line_color=None, line_width=None)

        def fitted_slope(xs, ys, **line_props):
            # Fit y = slope * x + intercept and wrap the result in a Slope annotation
            model = LinearRegression().fit(np.array(xs).reshape(-1, 1), ys)
            return Slope(gradient=model.coef_[0], y_intercept=model.intercept_, **line_props)

        # Regression line through the data
        p.add_layout(fitted_slope(self.x, self.y, line_color="red", line_width=3,
                                  line_alpha=0.25, line_dash='dashed'))
        # Reference line through (min x, min y) .. (max x, max y)
        p.add_layout(fitted_slope(np.linspace(np.min(self.x), np.max(self.x)),
                                  np.linspace(np.min(self.y), np.max(self.y)),
                                  line_color='black', line_width=3))

        # create the horizontal histogram
        hhist, hedges = np.histogram(self.x, bins=50)
        hmax = max(hhist)*1.1

        ph = figure(toolbar_location=None, sizing_mode='stretch_width', plot_height=int(p.plot_height * 0.35), x_range=p.x_range,
                    y_range=(0, hmax), min_border=0, min_border_bottom=5, x_axis_location='above', y_axis_location="left", title=self.name_x)
        ph.ygrid.grid_line_color = None
        ph.yaxis.major_label_orientation = -np.pi/4
        ph.quad(bottom=0, left=hedges[:-1], right=hedges[1:], top=hhist, color=color1, line_color='white')

        # create the vertical histogram
        vhist, vedges = np.histogram(self.y, bins=50)
        vmax = max(vhist)*1.1

        pv = figure(toolbar_location=None, plot_width=int(p.plot_width * 0.35), sizing_mode='stretch_height', x_range=(0, vmax),
                    y_range=p.y_range, min_border=0, min_border_left=5, y_axis_location="right", title=self.name_y,
                   title_location='right')
        pv.xgrid.grid_line_color = None
        pv.xaxis.major_label_orientation = -np.pi/4
        pv.quad(left=0, bottom=vedges[:-1], top=vedges[1:], right=vhist, color=color2, line_color='white')

        # Both histograms share the same axis styling
        for hist_plot in (ph, pv):
            hist_plot.axis.axis_line_color = None
            hist_plot.axis.major_tick_line_color = 'gray'
            hist_plot.axis.minor_tick_line_color = None

        for i in [p, pv, ph]:
            i.title.text_font_size = self.fonter
            i.axis.axis_label_text_font_size = self.fonter
            i.axis.major_label_text_font_size = self.fonter

        return gridplot([[ph, None], [p, pv]], sizing_mode='stretch_both', merge_tools=False, toolbar_options=dict(logo=None))
        
    def time_series(self, category=None, date_range=None, pred_var='Predicted', satellite=False,
                    multiple=None, now_out=False, more=None):
        ''' Draw one or more time series on self.plot.

        category   -- 'aq' draws self.y plus the self.data[pred_var] series;
                      'sat' (together with satellite=True) adds the raw
                      satellite and ground-sensor series
        date_range -- when given, a dashed vertical line marks date_range[0]
        pred_var   -- key in self.data for the predicted series ('aq' only)
        satellite  -- see category
        multiple   -- sequence of keys in self.data; only the last one is drawn
                      (dashed), next to self.y
        now_out    -- when True the plot is exported and None is returned
        more       -- iterable of (values, legend label, color) triples, each
                      drawn as its own line

        Returns self.plot unless now_out is True.

        NOTE(review): `color` and `pollutant` are not defined in this method
        or on the instance; they are presumably notebook-level globals - confirm.
        '''
        # Fifth color of the instance palette marks the predicted series
        c5 = self.colors[4]
        line_style = dict(alpha=1, line_width=3, line_cap='round')

        self.plot.grid.grid_line_alpha = 0.3

        if category == 'aq':
            self.plot.line(self.x, self.y, color=color, legend_label=self.name_y, **line_style)
            self.plot.line(self.x, self.data[pred_var], color=c5, legend_label=self.name_x, **line_style)

            self.plot.legend.location = "bottom_right"
            self.plot.legend.click_policy = "hide"
        else:
            if multiple is not None:
                self.plot.line(self.x, self.data[multiple[-1]], color=color, line_dash='dashed',
                               legend_label=multiple[-1] + ' ' + pollutant + ' before ' + self.name_y, **line_style)
                self.plot.line(self.x, self.y, color=color,
                               legend_label=pollutant + ' during ' + self.name_y, **line_style)
                self.plot.xaxis.axis_label = self.name_x
                self.plot.axis.axis_label_text_font_size = self.fonter
                self.plot.legend.click_policy = "hide"
            elif more is not None:
                for datay, namey, colory in more:
                    self.plot.line(self.x, datay, color=colory, legend_label=namey, **line_style)
            else:
                self.plot.line(self.x, self.y, color=color, **line_style)

            if satellite and category == 'sat':
                # self.title doubles as the column key of the ground-sensor series
                self.plot.scatter(self.x, self.data[self.name_y + ' raw'], color=color, alpha=1, size=5,
                                  line_color=None, legend_label='Satellite sensor raw')
                self.plot.line(self.x, self.data[self.title], color=color, alpha=0.5, line_width=2,
                               line_cap='round', legend_label='Ground sensor')
                self.plot.scatter(self.x, self.data[self.title + ' raw'], color=color, alpha=0.25, size=3,
                                  line_color=None, legend_label='Ground sensor raw')

                self.plot.legend.click_policy = "hide"

        if date_range is not None:
            vline = Span(location=date_range[0], dimension='height', line_color='gray', line_width=2, line_alpha=0.25, line_dash='dashed')
            self.plot.renderers.extend([vline])

        if more is not None or multiple is not None:
            self.plot.legend.background_fill_alpha = 0.75
            self.plot.legend.location = "bottom_left"
            self.plot.legend.label_text_font_size = self.fonter
        self.plot.title.text_font_size = self.fonter

        self.plot.xgrid.band_hatch_pattern = "/"
        self.plot.xgrid.band_hatch_alpha = 0.6
        self.plot.xgrid.band_hatch_color = "lightgrey"
        self.plot.xgrid.band_hatch_weight = 0.5
        self.plot.xgrid.band_hatch_scale = 10
        self.plot.grid.grid_line_color = "white"
        self.plot.toolbar.logo = self.plot.axis.axis_line_color = None
        self.plot.axis.major_tick_line_color = self.plot.axis.minor_tick_line_color = 'lightgrey'

        if now_out:
            self.plot_to_html()
        else:
            return self.plot
    
    def ridgelines(self, raw=True):
        
        def ridge(category, data, scale=1):
            return list(zip([category]*len(data), scale*data))

        palette = tuple(polylinear_gradient([c3, c1], len(self.y))['hex'])
        
        start_date_1 = self.x[0] - timedelta(days=35*(len(self.y)-1))
        end_year = self.x[-1] - timedelta(days=35*(len(self.y)-1))
        for year in range(start_date_1.year, end_year.year + 1):
            end_date_1 = start_date_1 + timedelta(days=35*(len(self.y)-1))
            self.plot.segment(x0=[start_date_1], x1=[end_date_1], y0=0.5, y1=len(self.y) - 0.5, line_width=1, line_cap='round', color='gray', alpha=0.25)
            
            annot = Label(x=start_date_1, y=0, y_offset=-17.5,
                  x_units='data', y_units='data', text_font_style='bold', text_align='center',
                  text=str(year), text_font_size='10px', render_mode='css', text_alpha=0.8,
                  background_fill_color='white', background_fill_alpha=0.5)
            self.plot.add_layout(annot)
                        
            start_date_1 += timedelta(days=364 if (year % 4 == 0 and year % 100 != 0) or year % 400 == 0 else 365)
        
#         NL_measures = datetime(year=2020, month=3, day=9) - timedelta(hours=1000)
#         CHINA_measures = datetime(year=2020, month=1, day=23) - timedelta(hours=1000)
#         start_date_1 = NL_measures
#         end_date_1 = start_date_1 + timedelta(days=35*(len(self.y)-1))
#         self.plot.segment(x0=[start_date_1], x1=[end_date_1], y0=0.5, y1=len(self.y) - 0.5, line_width=2, line_cap='round', color='gray', alpha=0.5)

#         annot = Label(x=start_date_1, y=0, y_offset=-30,
#               x_units='data', y_units='data', text_font_style='bold', text_align='center',
#               text='First COVID-19 measures (2020-03-09)', text_font_size='10px', render_mode='css', text_alpha=0.8,
#               background_fill_color='white', background_fill_alpha=0.5)
#         self.plot.add_layout(annot)
            

        if raw:
            yvals = len(self.y) + 0.5
        else:
            yvals = len(self.y)
            
        for i, station in enumerate(reversed(self.y)):
            rolling, harmonic = self.data[i]
            new_df = pd.DataFrame({'datetime': self.x, 'rolling': rolling.tolist(),
                                   'harmonic': harmonic.tolist()})

            new_df[['rolling']] = new_df[['rolling']].rolling(min_periods=1, window=7 * 4 * 3 * 20, center=True).mean()

            if pollutant in station and 'S5P' not in station:
                new_df['rolling'] *= 1/120
                new_df['harmonic'] *= 1/120
#                 new_df['trend'] *= 1/25
            else:
                adder = 0 if 'S5P' in station else 0.05
                new_df['normalizer'] = new_df['rolling']
                new_df['rolling'] = (new_df['rolling']-new_df['rolling'].min())/(new_df['rolling'].max()-new_df['rolling'].min())
                new_df['harmonic'] = (new_df['harmonic']-new_df['normalizer'].min())/(new_df['normalizer'].max()-new_df['normalizer'].min())
                new_df['rolling'] = new_df['rolling'] / 4 + adder
                new_df['harmonic'] = new_df['harmonic'] / 4 + adder
#                 new_df['trend'] *= 1/new_df['trend'].max()


            new_df['trend'] = new_df['rolling'] - new_df['harmonic']
            
            new_df = new_df.iloc[::100] # 1060
            
            new_df['datetime'] = new_df['datetime'] - timedelta(days=i*35)
                                    
            yvals -= 1
            
            color_set = palette[i] if i > 4 else '#abacae'

            if raw:
                y = new_df['rolling'].multiply(6).add(yvals)
                y0 = new_df['rolling'].subtract(new_df['rolling']).multiply(6).add(yvals)
                self.plot.varea(new_df['datetime'], y1=y0, y2=y, alpha=0.2, fill_color=color_set)
                self.plot.line(new_df['datetime'], y, alpha=0.8, line_color=color_set,
                               line_width=2, line_cap='round', legend_label=self.name_x)
                
                y2 = new_df['harmonic'].multiply(6).add(yvals)
                self.plot.line(new_df['datetime'], y2, alpha=1, line_color=color_set,
                               line_width=2, line_dash='dashed', line_cap='round', legend_label=self.name_y)
                
            else:
                trend_vals = new_df['trend'].multiply(10).add(0.5).add(yvals)
                
                mean_trend = trend_vals.mean()
                y0 = [mean_trend for i in trend_vals]
                y = trend_vals

                self.plot.varea(new_df['datetime'].tolist(), y1=y0, y2=y, alpha=0.4, fill_color=color_set)

#                 self.plot.line(new_df['datetime'].tolist(), y, alpha=0.5, line_color=color_set, line_dash='dashed',
#                                line_width=1, line_cap='round')
                
                self.plot.line(new_df['datetime'].tolist(), y, alpha=0.8, line_color=color_set,
                               line_width=2, line_cap='round', legend_label=self.name_x)
                

           
            annot = Label(x=new_df['datetime'].tolist()[-1], y=yvals, x_offset=10,
                  x_units='data', y_units='data', text_font_style='bold',
                  text=station, text_font_size='9px', render_mode='css', text_alpha=0.8 if i > 4 else 0.5,
                  background_fill_color='white', background_fill_alpha=0.5)
            self.plot.add_layout(annot)


        if raw:
            self.plot.legend.click_policy="hide"
        self.plot.legend.location="bottom_left"
        self.plot.legend.background_fill_alpha=0.75
        self.plot.legend.orientation = "horizontal"
        self.plot.legend.margin = 2
        self.plot.legend.padding = 5
        
        self.plot.outline_line_color = None
        self.plot.background_fill_color = 'white'

        self.plot.ygrid.grid_line_color = None
        self.plot.xgrid.grid_line_color = None
        self.plot.xgrid.ticker = self.plot.xaxis.ticker

        self.plot.axis.minor_tick_line_color = None
        self.plot.axis.major_tick_line_color = None
        self.plot.axis.axis_line_color = None
        
        self.plot.yaxis.visible = False
        self.plot.xaxis.visible = False

        self.plot.y_range.range_padding = 0.23
#         self.plot.x_range.range_padding = 0.1
        
        self.plot.x_range=Range1d(datetime(year=2014, month=12, day=1), datetime(year=2024, month=3, day=1))
        
        
#         .ray(x=[1, 2, 3], y=[1, 2, 3], length=45, angle=[30, 45, 60],
#       angle_units="deg", color="#FB8072", line_width=2)

#         self.plot.padding_right = 150
#         self.plot.padding_top = 100

        self.plot_to_html()
    
    def decision_tree(self):
        """Fit a shallow regression tree on ``self.x``/``self.y`` and plot it as a graph.

        A ``DecisionTreeRegressor`` (``max_depth=4``, ``random_state=1``) is
        fitted and walked once to build a ``networkx.DiGraph``. Split nodes are
        labelled with the feature name above "≤ threshold"; leaves are labelled
        with the first column name of ``self.y`` above "≈ value". The edge to
        the child taken when the test holds carries ``name='yes'``, the other
        edge ``name='no'``. The graph is drawn with bokeh as text labels joined
        by grey lines and the finished figure is passed to
        ``self.plot_to_html``.

        Reads ``self.x`` and ``self.y``; both must expose ``.columns``.
        """
        feature_names = self.x.columns.to_list()
        target_name = self.y.columns.to_list()[0]

        dt = tree.DecisionTreeRegressor(random_state=1, max_depth=4)
        dt.fit(self.x, self.y)
        fitted = dt.tree_

        g = nx.DiGraph()
        labels = {}   # node key -> text drawn at that node
        raw_pos = {}  # node key -> un-normalised (x, y) coordinate

        def add_subtree(node, x, y, dx):
            """Add ``node`` and all of its descendants to ``g``; return the node's key.

            ``(x, y)`` is the raw position of ``node``; its children are placed
            one unit lower and ``dx`` to either side, and ``dx`` halves with
            every level so sibling subtrees do not overlap.
            """
            key = 'NODE{}'.format(node)
            g.add_node(key)
            raw_pos[key] = (x, y)

            feature = fitted.feature[node]
            if feature == _tree.TREE_UNDEFINED:
                # Leaf: show the predicted value of the target.
                labels[key] = "{}\n≈ {:.2f}".format(target_name, fitted.value[node][0][0])
                return key

            labels[key] = "{}\n≤ {:.2f}".format(feature_names[feature], fitted.threshold[node])
            # The "yes" (test holds) branch is drawn towards +x, the "no" branch towards -x.
            yes_child = add_subtree(fitted.children_left[node], x + dx, y - 1, dx * 0.5)
            g.add_edge(key, yes_child, name='yes')
            no_child = add_subtree(fitted.children_right[node], x - dx, y - 1, dx * 0.5)
            g.add_edge(key, no_child, name='no')
            return key

        # Node 0 is the root of a fitted sklearn tree; it is placed at (0, 1).
        add_subtree(0, 0, 1, 25)

        def tree_layout(graph, scale=1, center=(0, 0)):
            """Layout callable for ``from_networkx``: map node key -> scaled coordinate.

            Coordinates are shifted by their mean (plus ``center``) and each
            axis is divided by the largest absolute raw value on that axis.
            """
            keys = list(raw_pos)
            xy = np.array([raw_pos[key] for key in keys], dtype=float)
            span = np.abs(xy).max(axis=0)
            span[span == 0] = 1  # e.g. a single-node tree: avoid 0/0 -> NaN coordinates
            xy = (xy - xy.mean(axis=0) + center) * scale / span
            return dict(zip(keys, xy))

        p = figure(tools='pan,wheel_zoom,box_zoom,reset', active_scroll='wheel_zoom', sizing_mode='stretch_width', plot_height=500,
                   x_range=Range1d(-1.1, 1.1), y_range=Range1d(-0.5, 1.2), x_axis_location=None, y_axis_location=None,
                   margin=(5, 5, 5, 5))

        # Style the new figure only; self.plot is a different figure and must not be touched here.
        p.grid.grid_line_color = None
        p.axis.axis_line_color = None
        p.axis.major_tick_line_color = None
        p.axis.minor_tick_line_color = None
        p.toolbar.logo = None

        graph_renderer = from_networkx(g, tree_layout, scale=1, center=(0, 0))

        # Nodes are rendered as their text label instead of a marker.
        graph_renderer.node_renderer.glyph = Text(text='name', text_font_size='10px', text_font_style='bold',
                                                  text_align='center', y_offset=12)
        graph_renderer.node_renderer.data_source.data['name'] = [labels[key] for key in g.nodes]

        graph_renderer.edge_renderer.glyph = MultiLine(line_color="#CCCCCC", line_alpha=0.6, line_width=2)
        graph_renderer.inspection_policy = NodesAndLinkedEdges()

        p.renderers.append(graph_renderer)

        self.plot_to_html(p)

In [68]:
class Mapper:
    """Draw sensor locations on an interactive folium map and save it as HTML.

    Attributes:
        name: base name (without extension) of the HTML file that is written.
        file_out: directory the HTML file is written to.
    """

    def __init__(self, name, file_out='plots/'):
        """Store the output file base name and the target directory."""
        self.name = name
        self.file_out = file_out

    @staticmethod
    def _aq_tooltip(station, row):
        """Return the HTML tooltip for one air-quality station.

        ``row`` must provide ``number``, ``lat``, ``lon`` and ``components``;
        ``components`` is a ``|``-separated string shown comma-separated.
        """
        return '''<h6>Air quality sensor</h6><h5><b>{}</b></h5>
                              <table style="width:100%">
                                  <tr>
                                      <td><b>number</b></td>
                                      <td>{}</td>
                                  </tr>
                                  <tr>
                                      <td><b>lat, lon</b></td>
                                      <td>{}, {}</td>
                                  </tr>
                                  <tr>
                                      <td><b>components</b></td>
                                      <td>{}</td>
                                  </tr>
                              </table>
                              '''.format(
            station,
            row['number'],
            row['lat'],
            row['lon'],
            row['components'].replace('|', ', '))

    def map_sensors(self, station_focus=None, df_aq_stations=None, weather_stations=None, traffic_points=None, aq_station_plot='time_series_of_aq'):
        """Build the sensor map, write it to ``<file_out>/<name>.html`` and return it.

        Args:
            station_focus: sequence of air-quality station labels to highlight.
                The map is centred on the first one; when omitted or empty the
                map is centred on the mean coordinate of all stations.
            df_aq_stations: DataFrame indexed by station label with columns
                ``lat``, ``lon``, ``number`` and ``components``. Required.
            weather_stations: optional DataFrame with ``lat``/``lon`` columns,
                drawn as a separate "Weather stations" layer.
            traffic_points: optional DataFrame with ``lat``/``lon`` columns,
                drawn as a separate "Road traffic sensors" layer.
            aq_station_plot: currently unused; kept so existing calls keep working.

        Returns:
            The ``folium.Map`` that was saved.

        Raises:
            TypeError: if ``df_aq_stations`` is not supplied.

        Relies on the module-level names ``gee_user``, ``c1``, ``c3`` and ``c5``.
        """
        if df_aq_stations is None:
            raise TypeError('map_sensors() requires df_aq_stations')
        focus = () if station_focus is None else station_focus

        if len(focus) > 0:
            anchor = df_aq_stations.loc[focus[0]]
            center = [anchor['lat'], anchor['lon']]
        else:
            center = [df_aq_stations['lat'].mean(), df_aq_stations['lon'].mean()]

        f = folium.Figure(width=1000, height=500)
        m = folium.Map(
            location=center,
            zoom_start=10, attributionControl=False,
            control_scale=True
        ).add_to(f)

        folium.WmsTileLayer(url = 'https://geodata.nationaalgeoregister.nl/luchtfoto/rgb/wms?',
                                    layers='Actueel_ortho25', name='Aerial photograph', fmt='image/png',
                                    attr='Luchtfoto', transparent=True, show=False, version='1.3.0').add_to(m)

        if gee_user:
            folium.raster_layers.ImageOverlay(
                name='S5P',
                image='images/gee/png_out/Air quality_ NO2 S5P 2020-03-13 12_59_49.png',
                bounds=[[50.243, 2.395], [54.068, 8.042]], # [[lat_min, lon_min], [lat_max, lon_max]]
                opacity=0.9,
                interactive=True,
                cross_origin=False,
                zindex=100,
                overlay=True,
                show=False
            ).add_to(m)

        # Markers are attached to their feature group only, so each layer can be
        # toggled as a whole from the layer control.
        aq_layer = folium.FeatureGroup(name='Air quality sensors').add_to(m)
        for station in df_aq_stations.index.to_list():
            row = df_aq_stations.loc[station]
            folium.CircleMarker([row['lat'], row['lon']],
                                radius=5,
                                color=c5 if station in focus else '#A0A0A0',
                                weight=3,
                                fill=True,
                                fill_color='#A0A0A0',
                                fill_opacity=0.5,
                                tooltip=self._aq_tooltip(station, row)).add_to(aq_layer)

        if weather_stations is not None:
            weather_layer = folium.FeatureGroup(name='Weather stations').add_to(m)
            for station in weather_stations.index.to_list():
                row = weather_stations.loc[station]
                folium.RegularPolygonMarker([row['lat'], row['lon']],
                                            radius=7,
                                            color='blue',
                                            weight=3,
                                            fill=True,
                                            fill_color=c3,
                                            fill_opacity=0.5,
                                            tooltip=f'<h6>Weather station</h6><h5><b>{station}</b></h5>').add_to(weather_layer)

        if traffic_points is not None:
            traffic_layer = folium.FeatureGroup(name='Road traffic sensors').add_to(m)
            for point in traffic_points.index.to_list():
                row = traffic_points.loc[point]
                folium.CircleMarker([row['lat'], row['lon']],
                                    radius=3,
                                    color='red',
                                    weight=3,
                                    fill=True,
                                    fill_color=c1,
                                    fill_opacity=0.5,
                                    tooltip=f'<h6>Road traffic sensor</h6><h5><b>{point}</b></h5>').add_to(traffic_layer)

        folium.LayerControl().add_to(m)

        out_path = f'{self.file_out}/{self.name}.html'
        m.save(out_path)

        # Re-wrap the saved page in an IFrame and overwrite the file with it.
        # folium writes UTF-8, so read it back as UTF-8 and close the handle.
        with open(out_path, 'r', encoding='utf-8') as fh:
            page = fh.read()
        folium.IFrame(html=page, width='100%', height='100%').save(out_path)

        return m