In [None]:
from pathlib import Path
import math
import pandas as pd
import ee
import geemap
from IPython.display import display, HTML, clear_output
import ipywidgets as widgets
from ipywidgets import GridspecLayout, Button, Layout, Label, Dropdown
from IPython.core.debugger import set_trace
from evaluate_results import load_histo_file
from ipyleaflet import Marker, MarkerCluster, CircleMarker, LayerGroup
from common import load_results, get_config

In [None]:
# ee.Authenticate()  # uncomment to run the Earth Engine sign-in flow when no credentials are stored yet
# This needs to be before ee.Initialize
# NOTE(review): this ``Map`` is never displayed and is reassigned in the
# hexa-visualizer cell further down; presumably it exists only for geemap's
# initialisation side effects -- confirm before removing.
Map = geemap.Map()
ee.Initialize()

In [None]:
def get_marker_color(pred, gold):
    """Return a marker colour for the forest-cover prediction error of one plot.

    Parameters
    ----------
    pred : number
        Predicted forest amount (same unit as ``gold``).
    gold : number
        Gold-label forest amount.

    Returns
    -------
    str
        "darkgray" if either value is missing (None/NaN); otherwise, by the
        absolute difference: "red" (>= 75), "pink" (>= 50), "coral" (>= 25),
        "blue" (< 25).
    """
    # A missing value on either side cannot be scored. Checking only ``pred``
    # let a NaN gold value fall through every comparison and show as "blue".
    if pd.isna(pred) or pd.isna(gold):
        return "darkgray"

    delta = abs(pred - gold)
    if delta >= 75:
        return "red"
    if delta >= 50:
        return "pink"
    if delta >= 25:
        return "coral"
    return "blue"

def get_marker_color_deforestation(deforest_p, deforest_g):
    """Return a marker colour comparing predicted and gold forest loss.

    Parameters
    ----------
    deforest_p : number
        Predicted forest loss.
    deforest_g : number
        Gold-label forest loss.

    Returns
    -------
    str
        "darkgray" if either value is missing (None/NaN), "blue" if both are 0,
        "green" if both are > 0, "coral" if only the prediction is > 0, and
        "red" otherwise (e.g. loss in gold but none predicted).
    """
    # Without this check on the gold value, a NaN gold label fell through to
    # the final branch and was reported as "red" (missed deforestation).
    if pd.isna(deforest_p) or pd.isna(deforest_g):
        return "darkgray"

    if deforest_p == 0 and deforest_g == 0:
        return "blue"
    if deforest_p > 0 and deforest_g > 0:
        return "green"
    if deforest_p > 0 and deforest_g == 0:
        return "coral"
    return "red"


def get_markers_as_layer_group(all, filtered_global_ids, on_map_show):
    """Build one coloured CircleMarker per selected plot and group them in a layer.

    Parameters
    ----------
    all : pandas.DataFrame
        Comparison table (as built by ``get_comparison``); needs ``pl_plotid``,
        ``lat``, ``lon`` and the renamed prediction/gold columns.
    filtered_global_ids : iterable
        Plot ids to draw; rows whose ``pl_plotid`` is not in it are skipped.
    on_map_show : str
        ``"Forest loss"`` colours markers by 2010-2018 forest-loss agreement;
        any other value colours them by the forest 2018 prediction error.

    Returns
    -------
    tuple
        ``(LayerGroup, DataFrame)``: the marker layer and the rows that were drawn.
    """
    bounded_area = all[all["pl_plotid"].isin(filtered_global_ids)]
    show_loss = on_map_show == "Forest loss"

    markers = []
    for _, item in bounded_area.iterrows():
        if show_loss:
            marker_color = get_marker_color_deforestation(item["loss 2010-2018 p"], item["loss 2010-2018 g"])
        else:
            marker_color = get_marker_color(item["forest 2018 p"], item["forest 2018 g"])
        markers.append(
            CircleMarker(location=[item["lat"], item["lon"]], radius=5, weight=2, color=marker_color)
        )

    return LayerGroup(layers=markers), bounded_area


def get_rename_mapper():
    """Return the column-rename mapping used for the comparison table.

    Renamed columns end in " p" (model prediction) or " g" (gold label).
    Prediction columns come first; get_cols() relies on this insertion order.
    A new dict is returned on every call.
    """
    prediction_renames = {
        "deforestation 2000-2010": "loss 2000-2010 p",
        "deforestation 2010-2018": "loss 2010-2018 p",
        "forest 2000": "forest 2000 p",
        "forest 2010": "forest 2010 p",
        "forest 2018": "forest 2018 p",
    }
    gold_renames = {
        "% of Forest": "forest 2018 g",
        "% Forest Loss 2000-2010": "loss 2000-2010 g",
        "% Forest Loss 2010-2018": "loss 2010-2018 g",
    }
    return {**prediction_renames, **gold_renames}


def get_cols():
    """Return the columns kept from the merged prediction/gold table.

    Order: identifiers and coordinates, the two forest sub-category columns,
    then every source column listed in get_rename_mapper() (in its order).
    """
    id_and_location = ["plotID", "pl_plotid", "lon", "lat"]
    sub_categories = [
        "Sub-Categories if Naturally regenerated forest",
        "Sub-Categories if Planted forest",
    ]
    return id_and_location + sub_categories + list(get_rename_mapper())


def get_comparison(pred_df, gold_df):
    """Join predictions onto gold labels and return them under unified column names.

    The right join keeps every gold-labelled plot (matched on
    ``pred_df.plotID == gold_df.pl_plotid``). Only get_cols() columns are kept,
    and they are renamed with get_rename_mapper().
    """
    merged = pred_df.merge(gold_df, how="right", left_on="plotID", right_on="pl_plotid")
    selected = merged[get_cols()]
    return selected.rename(columns=get_rename_mapper())

def get_data(plot_id, df):
    """Return the location, sub-categories and forest-2018 values of one plot.

    Parameters
    ----------
    plot_id :
        Value matched against ``df["pl_plotid"]``.
    df : pandas.DataFrame
        Comparison table with the renamed columns from ``get_comparison``.

    Returns
    -------
    tuple
        ``(lon, lat, cat1, cat2, forest2018g, forest2018p)``, taken from the
        first matching row. ``cat1`` is "Sub-Categories if Planted forest" and
        ``cat2`` is "Sub-Categories if Naturally regenerated forest".

    Raises
    ------
    IndexError
        If no row matches ``plot_id`` (``button_clicked`` relies on this).
    KeyError
        If a required column is missing from ``df``.
    """
    # Evaluate the mask once instead of re-filtering the whole frame per field.
    match = df[df["pl_plotid"] == plot_id]
    columns = (
        "lon",
        "lat",
        "Sub-Categories if Planted forest",
        "Sub-Categories if Naturally regenerated forest",
        "forest 2018 g",
        "forest 2018 p",
    )
    # ``.iloc[0]`` on an empty selection raises IndexError.
    return tuple(match[column].iloc[0] for column in columns)


def get_map(point, width="500px", height="500px", zoom=15):
    """Create an interactive satellite map centred on ``point``.

    Parameters
    ----------
    point : ee.Geometry
        Geometry to centre on (via ``center_object``).
    width, height : str
        CSS sizes applied to the map widget's layout.
    zoom : int
        Initial zoom level.

    Returns
    -------
    geemap.Map
        Map with the "SATELLITE" basemap and scroll-wheel zoom disabled.
    """
    interactive_map = geemap.Map(lite_mode=False)
    interactive_map.add_basemap("SATELLITE")
    interactive_map.center_object(point, zoom)

    layout = interactive_map.layout
    layout.width = width
    layout.height = height

    interactive_map.scroll_wheel_zoom = False
    return interactive_map

def get_jaxa_layer():
    """Return the ``fnf`` band of the module-level ``jaxa`` collection and its vis params.

    ``jaxa`` is defined in a later cell, so this can only be called after that
    cell has run. The palette covers values 1..3 (presumably the FNF class
    codes -- confirm against the dataset docs).
    """
    palette = ["006400", "FEFF99", "0000FF"]
    return jaxa.select("fnf"), {"min": 1.0, "max": 3.0, "palette": palette}

In [None]:
def get_composite_layer(point, buffer_m=500, year=2018):
    """Build a Landsat 7 simple composite for one year, clipped around ``point``.

    Parameters
    ----------
    point : ee.Geometry
        Centre of the area of interest.
    buffer_m : number, optional
        Distance passed to ``point.buffer`` (default 500; presumably metres,
        Earth Engine's default unit -- confirm).
    year : int, optional
        Calendar year whose imagery is composited (default 2018).

    Returns
    -------
    ee.Image
        ``ee.Algorithms.Landsat.simpleComposite`` of the filtered collection,
        clipped to the buffered region.
    """
    region_around_point = point.buffer(buffer_m)
    # filterDate's end date is exclusive, so use Jan 1 of the following year
    # to include the whole calendar year.
    # NOTE(review): "C01" is Landsat Collection 1, which Earth Engine has
    # deprecated; a Collection 2 asset may be needed -- confirm before switching.
    collection = (
        ee.ImageCollection("LANDSAT/LE07/C01/T1_RT")
        .filterDate(f"{year}-01-01", f"{year + 1}-01-01")
        .filterBounds(region_around_point)
    )
    composite = ee.Algorithms.Landsat.simpleComposite(collection)
    return composite.clip(region_around_point)

def _filter_by_visible_area(map_widget, output_widget):
    """Display the rows of ``all`` whose hexagon lies within ``map_widget``'s visible bounds.

    Results are rendered into ``output_widget``, which is also returned.
    """
    with output_widget:
        try:
            # Get bounding box of the map. Unpacking fails with ValueError
            # while the map has not reported its bounds yet (not ready).
            ((min_y, min_x), (max_y, max_x)) = map_widget.bounds
        except ValueError:
            return output_widget

        bounding_region = ee.Geometry.Rectangle((min_x, min_y, max_x, max_y))

        # Get hexas residing inside the region.
        feature_collection = hexas.filterBounds(bounding_region).getInfo()
        filtered_global_ids = [int(feature["properties"]["global_id"]) for feature in feature_collection["features"]]

        # .copy() so the dtype conversion writes to a new frame rather than a
        # slice of ``all`` (avoids SettingWithCopyWarning).
        bounded_area = all[all["pl_plotid"].isin(filtered_global_ids)].copy()
        bounded_area["pl_plotid"] = bounded_area["pl_plotid"].astype(int)
        clear_output()
        display(bounded_area)
    return output_widget


def button_v_clicked(b):
    """Click handler for the Forest loss map: list hexagons inside ``Map``'s visible area."""
    return _filter_by_visible_area(Map, output_v)


def button_v_clicked_forest(b):
    """Click handler for the Forest map: list hexagons inside ``Map_forest``'s visible area."""
    return _filter_by_visible_area(Map_forest, output_v_forest)

In [None]:
# JAXA PALSAR yearly forest/non-forest collection for 2017 (read by get_jaxa_layer).
jaxa = ee.ImageCollection("JAXA/ALOS/PALSAR/YEARLY/FNF").filterDate(
    "2017-01-01", "2017-12-31"
)
# Hexagon polygons used to select plots by map bounds. Swap in the two
# commented lines to also include the Paraguay training hexagons.
#finland = ee.FeatureCollection("users/jjaakko/Finland_776_Hexagons_")
#hexas = finland.merge(ee.FeatureCollection('users/jjaakko/ParaguayHexagonsTrain'))
hexas = ee.FeatureCollection("users/jjaakko/Finland_776_Hexagons_")

# Get gold labels. ``dtype`` is a read_csv argument (Path ignores or rejects
# it), so plot ids are parsed as integers.
gold = pd.read_csv(Path("label_CSVs/validation_complete.csv"), dtype={"pl_plotid": int})

# Get predictions.
pred_df = load_results()

# Right join on the gold labels, so every labelled plot is kept.
# NOTE: ``all`` shadows the builtin, but later cells and callbacks use this name.
all = get_comparison(pred_df, gold)

# Hexa visualizer

This tool is intended for two purposes:  

- To visualize overall classification results together with their spatial context. Each hexagon is represented by a circle whose color indicates how accurate the classification result was compared to the visual interpretation data.
- To visually debug the classifier by looking at individual hexagons to figure out where the model does well and where it fails.

In [None]:
# Centre of the initial view; the Forest map cell further down reuses ``point``.
point = ee.Geometry.Point([29.0310810629922, 64.3901687169026])
Map = get_map(point, zoom=10)

# Use the map's visible bounds to decide which hexagons get markers.
try:
    ((min_y,min_x),(max_y,max_x)) = Map.bounds
except ValueError:
    # Map wasn't ready yet.
    # NOTE(review): the map has only just been created here, so its bounds are
    # presumably still empty on a fresh run and this wide fallback box (roughly
    # lat -55.5..72.9, lon -117.1..58.7) is what actually gets used.
    #((min_y,min_x),(max_y,max_x)) = ((54.772375404880265, 5.30193328857422), (73.22521345877382, 49.247245788574226))
    
    #((min_y,min_x),(max_y,max_x)) = ((62, 22), (64, 24))
    ((min_y,min_x),(max_y,max_x)) = ((-55.522411831398216, -117.0867919921875), (72.94865294642922, 58.69445800781251))
    
bounding_region = ee.Geometry.Rectangle((min_x, min_y, max_x, max_y))

# Get hexas residing inside the region.
# ``filtered_global_ids`` is also reused by the Forest map cell further down.
feature_collection = hexas.filterBounds(bounding_region).getInfo()
filtered_global_ids = [int(feature["properties"]["global_id"]) for feature in feature_collection["features"]]

# Colour markers by forest-loss agreement (legend in the markdown below).
layer_group, bounded_area = get_markers_as_layer_group(all, filtered_global_ids, "Forest loss")

Map.add_layer(layer_group)

# Same size as get_map's defaults; kept explicit.
Map.layout.width="500px"
Map.layout.height="500px"

# Filter button and its output pane; the pane initially lists every hexagon found above.
button_v = widgets.Button(description="Filter based on visible area")
output_v = widgets.Output()
with output_v:
    display(bounded_area)

### Forest loss

On the map below, each hexagon is shown as a circle. The color coding is as follows:  

- Blue (there was no deforestation and the model predicted none)
- Green (there was deforestation and the model predicted deforestation)
- Orange (there was no deforestation but the model predicted deforestation)
- Red (there was deforestation but the model didn't predict deforestation)
- Dark gray (no prediction is available for the hexagon)

You can pan the map and zoom in on a specific area. When you click the `Filter based on visible area` button, the data for the hexagons within the visible area is shown in the pane to the right of the map.

In [None]:
# Hook the filter callback up first, then show the button above the map grid.
button_v.on_click(button_v_clicked)
display(button_v)

In [None]:
# Lay out the Forest loss map (left) next to its scrollable filter output (right).
Map.layout.margin = "0px 20px 0px 0px"
output_v.layout.overflow="scroll"
grid = GridspecLayout(1, 2, height='500px')
grid[0, 0] = Map
grid[0, 1] = output_v
grid

In [None]:
# Second map with the same centre and zoom; markers are coloured by the
# forest 2018 prediction error (any value other than "Forest loss").
# NOTE(review): reuses ``filtered_global_ids`` computed from ``Map``'s bounds
# in the Forest loss cell, and overwrites the global ``bounded_area``.
Map_forest = get_map(point, zoom=10)

layer_group_forest, bounded_area = get_markers_as_layer_group(all, filtered_global_ids, "Forest")

Map_forest.add_layer(layer_group_forest)

# Same size as get_map's defaults; kept explicit.
Map_forest.layout.width="500px"
Map_forest.layout.height="500px"

# Filter button and its output pane; the pane initially lists the hexagons drawn above.
button_v_forest = widgets.Button(description="Filter based on visible area")
output_v_forest = widgets.Output()
with output_v_forest:
    display(bounded_area)

### Forest

On the map below, each hexagon is shown as a circle. The color coding is as follows:  

- Blue (predicted and actual forest amounts differ by less than 25%)
- Orange (predicted and actual forest amounts differ by 25%–50%)
- Pink (predicted and actual forest amounts differ by 50%–75%)
- Red (predicted and actual forest amounts differ by 75% or more)
- Dark gray (no prediction is available for the hexagon)

You can pan the map and zoom in on a specific area. When you click the `Filter based on visible area` button, the data for the hexagons within the visible area is shown in the pane to the right of the map.

In [None]:
# Show the Forest map's filter button and attach its click handler.
display(button_v_forest)
button_v_forest.on_click(button_v_clicked_forest)

# Lay out the Forest map (left) next to its scrollable filter output (right).
Map_forest.layout.margin = "0px 20px 0px 0px"
output_v_forest.layout.overflow="scroll"
grid_forest = GridspecLayout(1, 2, height='500px')
grid_forest[0, 0] = Map_forest
grid_forest[0, 1] = output_v_forest
grid_forest

Clicking the `Show hexagon` button loads two maps. The one on the left is a small overview map showing the general location of the selected hexagon. The one on the right shows a detailed view of the hexagon with multiple layers (JAXA forest/non-forest, a 2018 Landsat composite, and the hexagons).

In [None]:
# "Show hexagon" button, the output pane for the detailed map, and a text box
# for the plot id to look up. The click handler is attached in a later cell.
button = widgets.Button(description="Show hexagon")
output = widgets.Output()
plotID = widgets.Text(
    description="Enter hexagon's plotid:",
    value="524872647",
    placeholder="",
    disabled=False,
    style={"description_width": "initial"},
)
display(plotID, button)

In [None]:
# Text details (sub-categories, gold vs predicted forest 2018) of the selected
# hexagon; filled by button_clicked.
details_output = widgets.Output()
display(details_output)

In [None]:
# display(output)  -- intentionally disabled: ``output`` is shown inside grid2 in a later cell.

In [None]:
# Output pane for the small overview map; placed in grid2 and filled by button_clicked.
output2 = widgets.Output()

In [None]:
# Overview map in the first of five columns; the detailed map spans the other four.
grid2 = GridspecLayout(1, 5, height='800px')
grid2[0, 0] = output2
grid2[0, 1:] = output
grid2

In [None]:
def _show_message(widget, message):
    """Replace the contents of an Output widget with a single HTML message."""
    with widget:
        clear_output()
        display(HTML(message))


def _show_detail_map(plot_id_text, lon, lat):
    """Render the large map (JAXA FNF, 2018 composite, hexagons) for a plot into ``output``."""
    with output:
        point = ee.Geometry.Point([lon, lat])
        detail_map = get_map(point, "840px", "600px", zoom=14)

        fnf, vis = get_jaxa_layer()
        # Show jaxa forest vs non-forest.
        detail_map.addLayer(fnf, vis, "fnf", True)

        composite = get_composite_layer(point)
        detail_map.addLayer(
            composite,
            {"bands": ["B4", "B5", "B3"], "max": 258},
            "real 2018",
        )
        # Show hexagons on top.
        detail_map.addLayer(hexas, {"color": "blue"}, "hexas", True, 0.3)
        # Clear only once the new map is built so the previous one stays visible meanwhile.
        clear_output()
        display(HTML(f"Displaying {plot_id_text} !!"))
        display(detail_map)


def _show_overview_map(plot_id_text, lon, lat):
    """Render the small zoomed-out locator map, with a marker on the plot, into ``output2``."""
    with output2:
        point = ee.Geometry.Point([lon, lat])
        overview_map = get_map(point, "200px", "200px", zoom=4)
        # Show hexagons on top.
        overview_map.addLayer(hexas, {"color": "blue"}, "hexas", True, 0.3)
        marker_cluster = MarkerCluster(markers=[Marker(location=[lat, lon])], name="Markers")
        overview_map.add_layer(marker_cluster)
        clear_output()
        display(HTML(f"Displaying {plot_id_text}"))
        display(overview_map)


def button_clicked(b):
    """Click handler for "Show hexagon": show details and maps for the entered plot id.

    Invalid or unknown ids produce a message in ``details_output`` and return
    None. On success, the ``(output, output2)`` widgets are returned.
    """
    plot_id_text = plotID.value
    # Parse and look up *before* entering any ``with <Output>:`` block. The
    # ipywidgets Output context manager swallows exceptions raised inside it
    # (it shows a traceback instead), so handlers placed outside it never run.
    try:
        plot_id = int(plot_id_text)
    except ValueError:
        _show_message(details_output, "PlotID has to be an integer value.")
        return None
    try:
        lon, lat, cat1, cat2, forest2018g, forest2018p = get_data(plot_id, all)
    except IndexError:
        _show_message(details_output, f"Hexagon with plotID of {plot_id_text} not found.")
        return None

    with details_output:
        clear_output()
        display(HTML(f"Cat1: {cat1}"))
        display(HTML(f"Cat2: {cat2}"))
        display(HTML(f"Forest 2018 (gold): {forest2018g}"))
        display(HTML(f"Forest 2018 (pred): {forest2018p}"))

    _show_detail_map(plot_id_text, lon, lat)
    _show_overview_map(plot_id_text, lon, lat)
    return output, output2


button.on_click(button_clicked)