In [None]:
from pathlib import Path
import pandas as pd
import ee
import geemap
from IPython.display import display, HTML, clear_output
import ipywidgets as widgets
from ipywidgets import GridspecLayout, Button, Layout, Label
from IPython.core.debugger import set_trace
from evaluate_results import load_histo_file

# Initialize the Earth Engine client before any of the ee.* objects built in later cells.
ee.Initialize()

In [None]:
from ipyleaflet import Marker, MarkerCluster

In [None]:
def get_comparison(pred_df, gold_df):
    """Join predictions onto the gold labels and give the compared columns short names.

    Every gold row is kept (right join of ``pred_df.plotID`` onto ``gold_df.pl_plotid``);
    gold plots without a prediction get NaN in the prediction columns.

    Args:
        pred_df: Predictions with a ``plotID`` column and the "deforestation ..." /
            "forest ..." columns listed in ``rename_mapper``.
        gold_df: Gold labels with ``pl_plotid``, the two "Sub-Categories ..." columns
            and the "% ..." columns listed in ``rename_mapper``. ``lon``/``lat`` must
            exist in exactly one of the two frames (if both had them, pandas would
            suffix them ``_x``/``_y`` and the column selection below would fail).

    Returns:
        DataFrame with the id/location/category columns followed by the renamed value
        columns; suffix " p" marks predictions and " g" gold labels.
    """
    rename_mapper = {
        "deforestation 2000-2018": "loss 2000-2018 p",
        "deforestation 2010-2018": "loss 2010-2018 p",
        "forest 2000": "forest 2000 p",
        "forest 2010": "forest 2010 p",
        "forest 2018": "forest 2018 p",
        "% of Forest": "forest 2018 g",
        # The gold column covers 2000-2010, so it must not share a period label with
        # the 2000-2018 prediction ("loss 2000-2018 p"); it is not the same quantity.
        "% Forest Loss 2000-2010": "loss 2000-2010 g",
        "% Forest Loss 2010-2018": "loss 2010-2018 g",
    }
    cols = [
        "plotID",
        "pl_plotid",
        "lon",
        "lat",
        "Sub-Categories if Naturally regenerated forest",
        "Sub-Categories if Planted forest",
        *rename_mapper,
    ]
    results_with_labels = pd.merge(
        pred_df, gold_df, how="right", left_on="plotID", right_on="pl_plotid"
    )[cols]
    return results_with_labels.rename(columns=rename_mapper)

def get_data(plot_id, df):
    """Return the location, categories and 2018 forest values of one plot.

    Args:
        plot_id: Value matched against ``df["pl_plotid"]``.
        df: Comparison frame as produced by ``get_comparison``.

    Returns:
        ``(lon, lat, cat1, cat2, forest2018g, forest2018p)`` from the first matching
        row, where ``cat1`` is the planted-forest sub-category and ``cat2`` the
        naturally-regenerated-forest sub-category.

    Raises:
        IndexError: if no row has ``pl_plotid == plot_id`` (the button handler
            relies on this to report unknown plots).
    """
    # Filter once; each column is then read separately so it keeps its own dtype.
    match = df[df["pl_plotid"] == plot_id]
    columns = (
        "lon",
        "lat",
        "Sub-Categories if Planted forest",
        "Sub-Categories if Naturally regenerated forest",
        "forest 2018 g",
        "forest 2018 p",
    )
    return tuple(match[column].iloc[0] for column in columns)


def get_map(point, zoom=15):
    """Create an interactive geemap map with the satellite basemap, centred on ``point``.

    Args:
        point: Earth Engine object handed to ``center_object``.
        zoom: Zoom level handed to ``center_object`` (default 15).

    Returns:
        The ``geemap.Map``, sized 500x500 px with scroll-wheel zoom turned off.
    """
    leaflet_map = geemap.Map(lite_mode=False)
    leaflet_map.add_basemap("SATELLITE")
    leaflet_map.center_object(point, zoom)
    leaflet_map.layout.width = leaflet_map.layout.height = "500px"
    leaflet_map.scroll_wheel_zoom = False
    return leaflet_map

def get_jaxa_layer(collection=None):
    """Return the JAXA forest/non-forest band and matching visualisation parameters.

    Args:
        collection: Image collection to take the ``fnf`` band from. Defaults to the
            module-level ``jaxa`` collection, which is defined in a later cell, so the
            no-argument form only works after that cell has run.

    Returns:
        ``(fnf, vis)``: ``collection.select("fnf")`` and a vis dict mapping values
        1..3 to dark green / pale yellow / blue.
    """
    if collection is None:
        collection = jaxa
    fnf = collection.select("fnf")
    # NOTE(review): presumably 1=forest, 2=non-forest, 3=water for this FNF band - confirm.
    vis = {"min": 1.0, "max": 3.0, "palette": ["006400", "FEFF99", "0000FF"]}
    return fnf, vis

In [None]:
def get_composite_layer(point, year=2018, buffer_m=500):
    """Landsat 7 simple composite for one calendar year, clipped around ``point``.

    Args:
        point: ``ee.Geometry`` the composite is centred on.
        year: Calendar year whose scenes are composited (default 2018, as before).
        buffer_m: Buffer distance passed to ``point.buffer`` (metres unless a
            projection is given; default 500).

    Returns:
        ``ee.Image`` from ``ee.Algorithms.Landsat.simpleComposite`` clipped to the
        buffered region.
    """
    region_around_point = point.buffer(buffer_m)
    # filterDate's end date is exclusive: use Jan 1 of the next year so that
    # scenes acquired on Dec 31 are included.
    collection = (
        ee.ImageCollection("LANDSAT/LE07/C01/T1_RT")
        .filterDate(f"{year}-01-01", f"{year + 1}-01-01")
        .filterBounds(region_around_point)
    )
    # NOTE(review): Landsat Collection 1 assets may be deprecated in Earth Engine;
    # migrating to Collection 2 would need different compositing - confirm.
    composite = ee.Algorithms.Landsat.simpleComposite(collection)
    return composite.clip(region_around_point)

In [None]:
# Earth Engine inputs: the JAXA ALOS PALSAR forest / non-forest collection filtered
# to 2017 (read as a global by get_jaxa_layer) and the hexagon grid asset.
jaxa = ee.ImageCollection("JAXA/ALOS/PALSAR/YEARLY/FNF").filterDate("2017-01-01", "2017-12-31")
hexas = ee.FeatureCollection("users/jjaakko/Finland_776_Hexagons_")

# Gold labels.
gold_path = Path("results/training_complete.csv")
gold = pd.read_csv(gold_path)

# Predictions; plotID is cast to int, presumably to match the gold pl_plotid key
# that get_comparison merges on - confirm the gold dtype.
pred_file = "results/finland_subset1_samples_2.csv"
pred_df = load_histo_file(pred_file)
pred_df["plotID"] = pred_df["plotID"].astype(int)

# NOTE(review): `all` shadows the builtin all(); kept because later cells use this name.
all = get_comparison(pred_df, gold)

# Hexagon visualizer

Render the satellite basemap, the JAXA 2017 forest/non-forest map and a Landsat 7 composite from 2018 for the given hexagon.

Some plot IDs to try:
- 524872647
- 525122005
- 523291679 

In [None]:
# Controls for choosing a hexagon: a text box for its plot id and a trigger button.
button = widgets.Button(description="Show hexagon")
# Receives the detailed map (placed in the right-hand pane of the map grid).
output = widgets.Output()
plotID = widgets.Text(
    value="524872647",
    placeholder="",
    description="Enter hexagon's plotid:",
    style={"description_width": "initial"},
    disabled=False,
)
display(plotID, button)

In [None]:
# Panel for the selected hexagon's categories and 2018 forest values.
details_output = widgets.Output()
display(details_output)

In [None]:
# display(output)

In [None]:
# Receives the zoomed-out overview map (left-hand pane of the map grid).
output2 = widgets.Output()

# NOTE(review): output3 is not referenced by any other cell; looks like leftover scaffolding.
output3 = widgets.Output()
with output3:
    display("Hello")

In [None]:
# One row, two panes: overview map on the left, detailed map on the right.
grid = GridspecLayout(1, 2, height='500px')
grid[0, 0], grid[0, 1] = output2, output
grid

In [None]:
def _render_details(plot_text):
    """Show the details of the plot whose id is ``plot_text`` in ``details_output``.

    Parsing and lookup both run inside ``details_output``. Under IPython, the ipywidgets
    ``Output`` context manager displays and then suppresses exceptions raised inside it.
    An outer try/except would therefore never see them. Handling them here keeps every
    message visible in the widget.

    Returns:
        ``(lon, lat)`` of the plot, or ``None`` if ``plot_text`` is not an integer,
        no plot has that id, or another error was shown in the widget.
    """
    with details_output:
        clear_output()
        try:
            plot_id = int(plot_text)
        except ValueError:
            display(HTML("PlotID has to be an integer value."))
            return None
        try:
            lon, lat, cat1, cat2, forest2018g, forest2018p = get_data(plot_id, all)
        except IndexError:
            # get_data raises IndexError when no row matches the id.
            display(HTML(f"Hexagon with plotID of {plot_text} not found."))
            return None
        display(HTML(f"Cat1: {cat1}"))
        display(HTML(f"Cat2: {cat2}"))
        display(HTML(f"Forest 2018 (gold): {forest2018g}"))
        display(HTML(f"Forest 2018 (pred): {forest2018p}"))
        return lon, lat
    # Reached only when an unexpected error was shown (and suppressed) by the widget.
    return None


def _render_detail_map(plot_text, lon, lat):
    """Show a zoom-15 map of the plot in ``output``.

    Layers: JAXA forest/non-forest, the 2018 Landsat composite (B7/B5/B3), and the
    hexagon grid on top.
    """
    with output:
        point = ee.Geometry.Point([lon, lat])
        detail_map = get_map(point)
        fnf, vis = get_jaxa_layer()
        detail_map.addLayer(fnf, vis, "fnf", True)
        detail_map.addLayer(
            get_composite_layer(point),
            {"bands": ["B7", "B5", "B3"], "max": 128},
            "real 2018-753",
        )
        detail_map.addLayer(hexas, {"color": "blue"}, "hexas", True, 0.3)
        # Clear only once the new map is built, so a failure leaves the old map in place.
        clear_output()
        display(HTML(f"Displaying {plot_text} !!"))
        display(detail_map)


def _render_overview_map(plot_text, lon, lat):
    """Show a zoom-6 map in ``output2``.

    Layers: hexagon grid, a hidden 2018 composite, and a marker on the plot.
    """
    with output2:
        point = ee.Geometry.Point([lon, lat])
        overview_map = get_map(point, zoom=6)
        overview_map.addLayer(
            get_composite_layer(point),
            {"bands": ["B7", "B5", "B3"], "max": 128},
            "real 2018-753",
            False,
        )
        overview_map.addLayer(hexas, {"color": "blue"}, "hexas", True, 0.3)
        # ipyleaflet expects (lat, lon) order.
        overview_map.add_layer(
            MarkerCluster(markers=[Marker(location=[lat, lon])], name="Markers")
        )
        clear_output()
        display(HTML(f"Displaying {plot_text}"))
        display(overview_map)


def button_clicked(b):
    """Click handler for the "Show hexagon" button.

    Reads the plot id typed into ``plotID`` and shows its details in
    ``details_output``. If the plot exists, it also draws the detailed map in
    ``output`` and the overview map in ``output2``. Invalid or unknown ids produce
    a message in ``details_output`` instead of tracebacks in every pane.

    Args:
        b: The clicked button (unused).

    Returns:
        ``(output, output2)`` after drawing the maps, otherwise ``None``. The value
        is ignored by ``Button.on_click``.
    """
    plot_text = plotID.value
    location = _render_details(plot_text)
    if location is None:
        return None
    lon, lat = location
    _render_detail_map(plot_text, lon, lat)
    _render_overview_map(plot_text, lon, lat)
    return output, output2


# `button` is created in an earlier cell, so re-running this cell would stack a second
# handler on it and every click would render everything twice; drop the previous one.
if "_registered_show_handler" in globals():
    button.on_click(_registered_show_handler, remove=True)
_registered_show_handler = button_clicked
button.on_click(button_clicked)

In [None]:
# Overview map around (lon 29.03, lat 64.39) with one marker per labelled hexagon in the
# initial view. `Map`, `cols` and `bounded_area` are reused by later cells.
cols = ["pl_plotid", "forest 2018 g", "forest 2018 p", "Sub-Categories if Naturally regenerated forest", "Sub-Categories if Planted forest"]
point = ee.Geometry.Point([29.0310810629922, 64.3901687169026])
Map = get_map(point, zoom=10)  # get_map already sizes the map to 500x500 px

# Used while Map.bounds cannot be unpacked yet (the map has not reported its view).
# Same ((min_lat, min_lon), (max_lat, max_lon)) layout as Map.bounds.
FALLBACK_BOUNDS = ((54.772375404880265, 5.30193328857422), (73.22521345877382, 49.247245788574226))
try:
    ((min_y, min_x), (max_y, max_x)) = Map.bounds
except ValueError:
    # Map wasn't ready yet.
    ((min_y, min_x), (max_y, max_x)) = FALLBACK_BOUNDS

bounding_region = ee.Geometry.Rectangle((min_x, min_y, max_x, max_y))

# Ids of the hexagons intersecting the region (getInfo() fetches them synchronously).
feature_collection = hexas.filterBounds(bounding_region).getInfo()
filtered_global_ids = [int(feature["properties"]["global_id"]) for feature in feature_collection["features"]]

bounded_area = all[all["pl_plotid"].isin(filtered_global_ids)]

# One marker per plot, titled with its id. Zipping the columns keeps each column's dtype
# (iterrows would coerce every row to a single common dtype).
markers = [
    Marker(location=[plot_lat, plot_lon], title=str(plot_id))
    for plot_id, plot_lat, plot_lon in zip(bounded_area["pl_plotid"], bounded_area["lat"], bounded_area["lon"])
]
# A single cluster layer for all markers.
marker_cluster = MarkerCluster(markers=markers, name="Markers")
Map.add_layer(marker_cluster)

In [None]:
def button_v_clicked(b):
    """Click handler for the "Filter based on visible area" button.

    Replaces the contents of ``output_v`` with the rows of ``all`` (columns ``cols``)
    whose hexagon intersects the current view of ``Map``. If ``Map.bounds`` cannot
    be unpacked yet (the map has not reported its view), a message is shown instead.

    Args:
        b: The clicked button (unused).

    Returns:
        ``output_v``; ignored by ``Button.on_click``.
    """
    with output_v:
        # Only the bounds unpacking is guarded. A ValueError from elsewhere (e.g. a
        # non-numeric global_id) is a real error and is left to show as a traceback.
        try:
            ((min_y, min_x), (max_y, max_x)) = Map.bounds
        except ValueError:
            clear_output()
            display(HTML("Map is not ready yet; try again once it is displayed."))
            return output_v

        bounding_region = ee.Geometry.Rectangle((min_x, min_y, max_x, max_y))

        # Hexagons intersecting the visible region (getInfo() fetches them synchronously).
        feature_collection = hexas.filterBounds(bounding_region).getInfo()
        filtered_global_ids = [
            int(feature["properties"]["global_id"])
            for feature in feature_collection["features"]
        ]

        bounded_area = all[all["pl_plotid"].isin(filtered_global_ids)]
        clear_output()
        display(bounded_area[cols])
    return output_v

In [None]:
# Table of plots in view, initially filled with the hexagons found for the start-up view.
output_v = widgets.Output()
button_v = widgets.Button(description="Filter based on visible area")
button_v.on_click(button_v_clicked)
with output_v:
    display(bounded_area[cols])
display(button_v)

In [None]:
from ipywidgets import GridspecLayout, Button, Layout

def create_expanded_button(description, button_style):
    """Return a ``Button`` with ``height``/``width`` set to 'auto'.

    This lets the button expand inside a layout cell such as a ``GridspecLayout`` slot.
    """
    fill_layout = Layout(height='auto', width='auto')
    return Button(
        description=description,
        button_style=button_style,
        layout=fill_layout,
    )

# Two rows: a (placeholder) header on top, then the map next to its filterable table.
grid = GridspecLayout(2, 2, height='500px')
Map.layout.margin = "0px 20px 0px 0px"  # gap between the map and the table
output_v.layout.overflow = "scroll"     # let the table scroll inside its cell
grid[0, :] = Label("Hello")
grid[1, 0], grid[1, 1] = Map, output_v
grid