In [None]:
# Automatically reload edited project modules (e.g. src.*) so changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Greedy tab-completion. NOTE(review): this may evaluate expressions on TAB to find completions — see IPython docs.
%config IPCompleter.greedy=True

In [None]:
from src.models import geograph, visualisation
from src.data_loading import test_data

# Load the "touching squares" test polygons and keep only the geometry plus a
# constant class label of 0. `assign` returns a new frame, so the object
# returned by the loader is never mutated and re-running the cell is safe.
test_gdf = (
    test_data.get_polygon_gdf("chernobyl_squares_touching")
    .assign(class_label=0)
    [["geometry", "class_label"]]
)

In [None]:
# Load the "apart squares" test polygons and keep only the geometry plus a
# constant class label of 0. `assign` returns a new frame, so the object
# returned by the loader is never mutated and re-running the cell is safe.
test_gdf2 = (
    test_data.get_polygon_gdf("chernobyl_squares_apart")
    .assign(class_label=0)
    [["geometry", "class_label"]]
)

In [None]:
# Build a GeoGraph from the touching-squares test polygons.
graph = geograph.GeoGraph(test_gdf)

In [None]:
# Build a second GeoGraph from the apart-squares test polygons.
graph2 = geograph.GeoGraph(test_gdf2)

In [None]:
#widgets.link((m, 'layers'), (viewer, 'layers'))

In [None]:
from traitlets import HasTraits, Integer, observe, Dict, Tuple
import ipywidgets as widgets
from ipywidgets import HTML
import ipyleaflet
from ipyleaflet import TileLayer, GeoData, basemaps
import copy
import markdown 
import pandas as pd

from src.constants import CHERNOBYL_COORDS_WGS84, WGS84

class GeoGraphViewer(ipyleaflet.Map):
    """Interactive ipyleaflet map for viewing GeoGraph objects.

    All layers are tracked in ``self.layer_dict``, which maps a layer name to
    a dict with keys ``'layer'`` (the ipyleaflet layer), ``'active'`` (whether
    it should be shown) and ``'layer_type'`` (``'map'`` or ``'graph'``).
    ``layer_update`` rebuilds the map's ``layers`` from the active entries, so
    ``layer_dict`` is the source of truth for what is displayed.
    """

    # Output widget capturing prints and tracebacks from the decorated methods
    # below. It must be a class attribute because the ``@log_tab.capture()``
    # decorators are evaluated at class-definition time; as a consequence it
    # is shared by every GeoGraphViewer instance.
    log_tab = widgets.Output(layout={'border': '1px solid black'})

    def __init__(self):
        super().__init__(center=CHERNOBYL_COORDS_WGS84, zoom=7, scroll_wheel_zoom=True)
        self.layer_dict = dict(
            Basemap=dict(
                layer=TileLayer(base=True, max_zoom=19, min_zoom=4),
                active=True,
                layer_type='map',
            )
        )
        # Template style for graph layers. It is deep-copied whenever a layer
        # is built, so editing it later (see set_graph_style) never mutates
        # the style traits of layers that already exist.
        self.custom_style = dict(
            style={'color': 'black', 'fillColor': '#3366cc'},
            hover_style={'fillColor': 'red', 'fillOpacity': 0.2},
            point_style={'radius': 10, 'color': 'red', 'fillOpacity': 0.8,
                         'fillColor': 'blue', 'weight': 3},
        )
        # Visibility checkboxes from the most recent create_habitat_tab call,
        # keyed by layer name, so programmatic changes can keep them in sync.
        self.layer_checkboxes = {}

    def _make_graph_layer(self, geo_dataframe, name):
        """Build a GeoData layer for ``geo_dataframe`` styled with a copy of ``custom_style``."""
        return ipyleaflet.GeoData(
            geo_dataframe=geo_dataframe,
            name=name,
            **copy.deepcopy(self.custom_style),
        )

    @log_tab.capture()
    def add_graph(self, graph, name='Graph'):
        """Add ``graph`` to the map as one layer of node and edge geometries.

        Args:
            graph: object exposing a ``.graph`` attribute accepted by
                ``visualisation.create_node_edge_geometries``.
            name: layer name; an existing entry with the same name is replaced.
        """
        nodes, edges = visualisation.create_node_edge_geometries(graph.graph)
        graph_geometries = pd.concat([nodes, edges]).reset_index()
        graph_geo_data = self._make_graph_layer(graph_geometries.to_crs(WGS84), name)
        self.layer_dict[name] = dict(layer=graph_geo_data, active=True, layer_type='graph')
        self.layer_update()

    @log_tab.capture()
    def layer_update(self):
        """Set the map's layers to exactly the active entries of ``layer_dict``."""
        self.layers = tuple(
            entry['layer'] for entry in self.layer_dict.values() if entry['active']
        )

    @log_tab.capture()
    def set_graph_style(self, radius):
        """Set the node radius of every graph layer and redraw them.

        The radius is stored in ``custom_style``, so graphs added later use it
        too. Each graph layer is rebuilt rather than edited, because mutating
        ``layer.point_style`` in place is not observed as a traitlets change
        and so never reaches the front end.

        Args:
            radius: new point radius for graph nodes.
        """
        self.custom_style['point_style']['radius'] = radius
        for entry in self.layer_dict.values():
            if entry['layer_type'] == 'graph':
                old_layer = entry['layer']
                entry['layer'] = self._make_graph_layer(old_layer.geo_dataframe, old_layer.name)
        self.layer_update()

    @log_tab.capture()
    def remove_graphs(self, button):
        """Hide every graph layer; ``on_click`` handler of the remove button.

        Graph entries are marked inactive in ``layer_dict`` instead of only
        being removed from the map. Otherwise the next ``layer_update`` (e.g.
        a checkbox toggle or a radius change) would silently re-add them.

        Args:
            button: the clicked Button (unused; required by ``on_click``).
        """
        for name, entry in self.layer_dict.items():
            if entry['layer_type'] != 'graph':
                continue
            entry['active'] = False
            # Keep the habitat-tab checkbox consistent with the hidden layer.
            checkbox = self.layer_checkboxes.get(name)
            if checkbox is not None:
                checkbox.value = False
        self.layer_update()

    @log_tab.capture()
    def checkbox_layer_switch(self, change_dict):
        """Show or hide the layer named by the checkbox that emitted ``change_dict``.

        The checkbox ``description`` is used as the ``layer_dict`` key. Changes
        to traits other than ``value`` are ignored.
        """
        if change_dict['name'] != 'value':
            return
        layer_name = change_dict['owner'].description
        self.layer_dict[layer_name]['active'] = change_dict['new']
        self.layer_update()

    @log_tab.capture()
    def create_habitat_tab(self):
        """Build the habitats tab: one visibility checkbox per layer plus a remove button.

        Returns:
            widgets.VBox holding the checkboxes followed by an accordion that
            contains the remove-graphs button.
        """
        checkboxes = []
        self.layer_checkboxes = {}
        for name, entry in self.layer_dict.items():
            checkbox = widgets.Checkbox(
                value=entry['active'],  # reflect the layer's current visibility
                description=name,
                disabled=False,
                indent=False,
            )
            # Only react to the checked state, not to other trait changes.
            checkbox.observe(self.checkbox_layer_switch, names='value')
            checkboxes.append(checkbox)
            self.layer_checkboxes[name] = checkbox

        # NOTE: despite the attribute name this button removes graphs; the
        # name is kept so existing references keep working.
        self.add_habitat_button = widgets.Button(
            description='Remove graph',
            disabled=False,
            button_style='',  # 'success', 'info', 'warning', 'danger' or ''
            tooltip='Remove all graph layers from the map',
            icon='trash',  # FontAwesome name without the `fa-` prefix
        )
        self.add_habitat_button.on_click(self.remove_graphs)

        habitat_accordion = widgets.Accordion(children=[self.add_habitat_button])
        habitat_accordion.set_title(0, 'Habitats')

        return widgets.VBox(checkboxes + [habitat_accordion])

    @log_tab.capture()
    def create_metrics_tab(self):
        """Build the metrics tab (placeholder: no metrics are implemented yet)."""
        # TODO: populate with graph metrics; an empty box is shown for now.
        return widgets.VBox([])

    @log_tab.capture()
    def create_diff_tab(self):
        """Build the diff tab: start/end year sliders and two compute buttons.

        NOTE(review): the buttons have no click handlers yet, so pressing them
        does nothing.
        """
        time_slider1 = widgets.IntSlider(min=1960, max=2021, step=1, value=1990, description="Start time:")
        time_slider2 = widgets.IntSlider(min=1960, max=2021, step=1, value=2010, description="End time:")

        compute_node_button = widgets.Button(
            description='Compute node diff',
            disabled=False,
            button_style='',  # 'success', 'info', 'warning', 'danger' or ''
            tooltip='This computes the differences between the nodes in the graph at start time and the graph at end time.',
            icon='',  # FontAwesome name without the `fa-` prefix
        )

        compute_pgon_button = widgets.Button(
            description='Compute polygon diff',
            disabled=False,
            button_style='',  # 'success', 'info', 'warning', 'danger' or ''
            tooltip='This computes the differences between the polygons in the graph at start time and the graph at end time.',
            icon='',  # FontAwesome name without the `fa-` prefix
        )

        return widgets.VBox([time_slider1, time_slider2, compute_node_button, compute_pgon_button])

    @log_tab.capture()
    def create_settings_tab(self):
        """Build the settings tab with a map-zoom slider and a node-radius slider."""
        zoom_slider = widgets.FloatSlider(description='Zoom level:', min=0, max=15, value=7)
        # Front-end link: slider and map zoom stay in sync without a kernel round trip.
        widgets.jslink((zoom_slider, 'value'), (self, 'zoom'))

        # Start at the radius actually in use so slider and map agree.
        radius_slider = widgets.FloatSlider(
            min=0.01,
            max=100.0,
            step=0.005,
            value=self.custom_style['point_style']['radius'],
            description="Node radius:",
        )
        radius_slider.observe(lambda change: self.set_graph_style(change['new']), names='value')

        return widgets.VBox([zoom_slider, radius_slider])

    @log_tab.capture()
    def add_widgets(self):
        """Attach the tabbed control panel, header, layers control and full-screen button."""
        self.add_settings_widget()
        self.add_control(ipyleaflet.FullScreenControl())

    @log_tab.capture()
    def add_settings_widget(self):
        """Add the tabbed control panel, a "GeoGraph" header and a LayersControl to the map."""
        tabs = [
            ("Habitats", self.create_habitat_tab()),
            ("Metrics", self.create_metrics_tab()),
            ("Diff", self.create_diff_tab()),
            ("Settings", self.create_settings_tab()),
            ("Log", self.log_tab),
        ]
        tab_nest = widgets.Tab()
        tab_nest.children = [content for _, content in tabs]
        for index, (title, _) in enumerate(tabs):
            tab_nest.set_title(index, title)

        header = widgets.HTML(markdown.markdown("""&nbsp;&nbsp;GeoGraph&nbsp;&nbsp;"""))
        widget_control2 = ipyleaflet.WidgetControl(widget=header, position='bottomright')
        self.add_control(widget_control2)

        self.control = ipyleaflet.LayersControl(position='topleft')
        self.add_control(self.control)

        widget_panel = widgets.VBox([tab_nest])
        widget_control = ipyleaflet.WidgetControl(widget=widget_panel, position='topright')
        self.add_control(widget_control)
        


        
# Build the viewer, add both test graphs as separate layers, then attach the
# control widgets. add_widgets() builds the habitat tab itself, so a separate
# create_habitat_tab() call (whose result was discarded) is not needed.
viewer = GeoGraphViewer()
viewer.add_graph(graph)
viewer.add_graph(graph2, name='Graph 2')
viewer.add_widgets()
viewer

In [None]:
# Display the viewer's log Output widget, which captures output and errors from the viewer's decorated methods.
viewer.log_tab