In [None]:
%load_ext autoreload
%autoreload 2

import os
import sys
from unittest.mock import AsyncMock, Mock, patch

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from dotenv import load_dotenv
from matplotlib.colors import ListedColormap

# Make sure your dotenv file defines every variable read below.
load_dotenv()
# os.getenv returns None for an unset variable (os.environ[...] would abort with a bare
# KeyError on the first one), so the check below can report every missing or empty
# variable at once.
GIT_FOLDER = os.getenv("GIT_FOLDER")
TITILER_URL = os.getenv("TITILER_URL")
TITILER_API_KEY = os.getenv("TITILER_API_KEY")
API_KEY = os.getenv("API_KEY")
MODEL_PATH_LOCAL = os.getenv("MODEL_PATH_LOCAL_UNET")
# MODEL_PATH_LOCAL = os.getenv("MODEL_PATH_LOCAL_MASK")

_required_env = {
    "GIT_FOLDER": GIT_FOLDER,
    "TITILER_URL": TITILER_URL,
    "TITILER_API_KEY": TITILER_API_KEY,
    "API_KEY": API_KEY,
    "MODEL_PATH_LOCAL": MODEL_PATH_LOCAL,
}
_missing_env = [name for name, value in _required_env.items() if not value]
if _missing_env:
    # Fail fast: later cells (sys.path setup, model selection) cannot work without these.
    # Note, you must restart the kernel if you want to load new environment variables.
    raise RuntimeError(
        f"ERROR: Failed to find all the necessary environment variables: {_missing_env}"
    )

if GIT_FOLDER not in sys.path:
    sys.path.append(GIT_FOLDER)
print(sys.path)

In [None]:
from cerulean_cloud.models import get_model, memfile_gtiff
from cerulean_cloud.tiling import TMS, offset_bounds_from_base_tiles
from cerulean_cloud.titiler_client import TitilerClient
from cerulean_cloud.cloud_run_orchestrator.clients import img_array_to_b64_image
from cerulean_cloud.cloud_run_orchestrator.schema import OrchestratorInput
from cerulean_cloud.cloud_run_orchestrator.handler import (
    _orchestrate,
    get_tiler,
    get_titiler_client,
    get_roda_sentinelhub_client,
    get_database_engine,
)
from cerulean_cloud.cloud_run_offset_tiles.schema import InferenceInput, PredictPayload
from cerulean_cloud.cloud_run_offset_tiles.handler import predict

In [None]:
# Model config for the FastAI UNet (type "FASTAIUNET"), handed to get_model later on.
fastaiunet = {
    "type": "FASTAIUNET",
    # Empty path; checkpoint kept for reference:
    # "experiments/2024_03_06_18_14_31_7cls_rn101_pr256_z9_fastai_baseline_noamb/tracing_cpu_model.pt"
    "file_path": "",
    "layers": ["VV"],
    # inference_idx maps to class table
    "cls_map": {0: "BACKGROUND", 1: "INFRA", 2: "NATURAL", 3: "VESSEL"},
    "name": "local test 46",
    "tile_width_m": 40844,  # Used to calculate zoom
    "tile_width_px": 512,  # Used to calculate scale
    "epochs": 500,
    "thresholds": {
        # Minimum IoU between instances that will keep the higher scoring multipolygon
        "poly_nms_thresh": 0.2,
        "pixel_nms_thresh": 0.0,
        # Smallest bridge value that will connect polygons into a multipolygon
        "bbox_score_thresh": 0.0001,
        # Determines the size of the outline of any given polygon
        "poly_score_thresh": 0.5,
        # Minimum pixel score that will be required to keep a multipolygon
        "pixel_score_thresh": 0.9,
        "groundtruth_dice_thresh": 0.0,
    },
    "backbone_size": 34,
    # "pixel_f1": 0.0,  # TODO CALCULATE
    # "instance_f1": 0.0,  # TODO CALCULATE
}

# Model config for the Mask R-CNN (type "MASKRCNN"), handed to get_model later on.
maskrcnn = {
    "type": "MASKRCNN",
    # Empty path; checkpoint kept for reference:
    # "experiments/2023_10_05_02_22_46_4cls_rnxt101_pr512_px1024_680min_maskrcnn_wd01/scripting_cpu_model.pt"
    "file_path": "",
    "layers": ["VV", "ALL_255", "VESSEL"],
    # inference_idx maps to class table
    "cls_map": {0: "BACKGROUND", 1: "INFRA", 2: "NATURAL", 3: "VESSEL"},
    "name": "ResNext 101 hires56",
    "tile_width_m": 40844,
    "tile_width_px": 512,
    "epochs": 122,
    "thresholds": {
        "poly_nms_thresh": 0.2,
        "pixel_nms_thresh": 0.4,
        "bbox_score_thresh": 0.3,
        "poly_score_thresh": 0.1,
        "pixel_score_thresh": 0.5,
        "groundtruth_dice_thresh": 0.0,
    },
    "backbone_size": 101,
    "pixel_f1": 0.461,
    "instance_f1": 0.47,
}

# Pick the config matching the checkpoint path: any path mentioning "maskrcnn"
# (case-insensitive, so e.g. ".../MaskRCNN/..." also matches) uses the Mask R-CNN
# config; every other path uses the FastAI UNet config.
model_dict_predefined = (
    maskrcnn if "maskrcnn" in MODEL_PATH_LOCAL.lower() else fastaiunet
)

In [None]:
async def get_titiler_client_and_offset_tiles(sentinel_scene, offset=0.33):
    """Create a TiTiler client for a scene and compute its offset tile bounds.

    Args:
        sentinel_scene: keyword arguments for ``OrchestratorInput``
            (this notebook passes ``sceneid``, ``zoom`` and ``scale``).
        offset: forwarded to ``offset_bounds_from_base_tiles`` as ``offset_amount``.

    Returns:
        Tuple ``(titiler_client, offset_tile_bounds)``.
    """
    orchestrator_input = OrchestratorInput(**sentinel_scene)
    # NOTE(review): TITILER_API_KEY is loaded in the config cell but not passed here;
    # confirm TitilerClient does not need it.
    client = TitilerClient(url=os.getenv("TITILER_URL"))
    scene_bounds = await client.get_bounds(orchestrator_input.sceneid)
    # Base tiles covering the scene at the requested zoom (truncate=False keeps
    # tiles that extend past the bounds).
    zoom_levels = [orchestrator_input.zoom]
    base_tiles = [*TMS.tiles(*scene_bounds, zoom_levels, truncate=False)]
    return client, offset_bounds_from_base_tiles(base_tiles, offset_amount=offset)

In [None]:
# EXPLORE TILES FROM A GIVEN SCENE_ID
offset = 0.33 * 2
titler_client, tile_bounds = await get_titiler_client_and_offset_tiles(
    test_scene, offset=offset
)
for i, tile in enumerate(tile_bounds):
    plt.imshow(
        (
            await titler_client.get_offset_tile(scene_id, *tile, height=512, width=512)
        ).transpose(2, 0, 1)[0],
        cmap="gray",
    )
    plt.title(str(i))
    plt.show()

In [None]:
# NOTE(review): originally annotated only as "breaking?". Every offset lists the single tile
# [[1]], which looks like a placeholder; confirm before running inference on this scene.
scene_id = "S1A_IW_GRDH_1SDV_20240814T171401_20240814T171426_055212_06BAD3_5001"
test_scene = dict(sceneid=scene_id, zoom=9, scale=2)

# Tiling offset -> list of stacks; each stack lists tile indices processed together.
slick_tiles = {0: [[1]], 0.33: [[1]], 0.66: [[1]]}

In [None]:
scene_id = "S1A_IW_GRDH_1SDV_20210523T005625_20210523T005651_038008_047C68_FE94"
test_scene = dict(sceneid=scene_id, zoom=9, scale=2)

# Tiling offset -> list of stacks; each stack lists tile indices processed together.
slick_tiles = {0: [[86, 92, 93]], 0.33: [[79, 86, 93]], 0.66: [[79, 86]]}

In [None]:
# fragmentation
scene_id = "S1A_IW_GRDH_1SDV_20200802T174315_20200802T174340_033731_03E8D7_B2B2"
test_scene = dict(sceneid=scene_id, zoom=9, scale=2)

# Tiling offset -> list of stacks; each stack lists tile indices processed together.
# NOTE(review): these tile indices are identical to the 20210523 scene's; possibly
# copy-pasted, so verify they were annotated for this scene.
slick_tiles = {0: [[86, 92, 93]], 0.33: [[79, 86, 93]], 0.66: [[79, 86]]}

In [None]:
# fragmentation
scene_id = "S1B_IW_GRDH_1SDV_20200807T041503_20200807T041528_022812_02B4BF_B5B6"
test_scene = dict(sceneid=scene_id, zoom=9, scale=2)

# Tiling offset -> list of stacks; each stack lists tile indices processed together.
slick_tiles = {0: [[29, 30]], 0.33: [[30]], 0.66: [[23, 30]]}

In [None]:
# fragmentation
scene_id = "S1A_IW_GRDH_1SDV_20210610T105641_20210610T105706_038277_048454_EF6F"
test_scene = dict(sceneid=scene_id, zoom=9, scale=2)

# Tiling offset -> list of stacks; each stack lists tile indices processed together.
slick_tiles = {0: [[33, 41]], 0.33: [[33, 41]], 0.66: [[33, 34]]}

In [None]:
scene_id = "S1A_IW_GRDH_1SDV_20230618T232014_20230618T232039_049047_05E5E0_718C"
test_scene = dict(sceneid=scene_id, zoom=9, scale=2)

# Tiling offset -> list of stacks; each stack lists tile indices processed together.
slick_tiles = {0: [[17, 25]], 0.33: [[18, 26]], 0.66: [[18, 19]]}

In [None]:
scene_id = "S1A_IW_GRDH_1SDV_20230523T224049_20230523T224114_048667_05DA7A_91D1"
test_scene = dict(sceneid=scene_id, zoom=9, scale=2)

# Tiling offset -> list of stacks; each stack lists tile indices processed together.
slick_tiles = {
    0: [[44, 45, 51, 52, 53]],
    0.33: [[37, 38, 44, 45, 46, 52, 53, 54]],
    0.66: [[37, 38, 45, 46, 47]],
}

In [None]:
scene_id = "S1A_IW_GRDH_1SDV_20200729T095401_20200729T095430_033668_03E6EE_2611"
test_scene = dict(sceneid=scene_id, zoom=9, scale=2)

# Tiling offset -> list of stacks; each stack lists tile indices processed together.
slick_tiles = {
    0: [[18, 19, 25, 26, 32, 33]],
    0.33: [[18, 19, 25, 26, 33]],
    0.66: [[11, 12, 18, 19, 25, 26]],
}

In [None]:
scene_id = "S1A_IW_GRDH_1SDV_20230318T175405_20230318T175430_047702_05BAED_22A9"
test_scene = dict(sceneid=scene_id, zoom=9, scale=2)

# Tiling offset -> list of stacks; each stack lists tile indices processed together.
# Only offset 0 is annotated; the 0.33 / 0.66 entries were left commented out with
# empty tile lists.
slick_tiles = {0: [[26, 28, 34, 35, 36, 43]]}

In [None]:
scene_id = "S1A_IW_GRDH_1SDV_20201114T034910_20201114T034935_035239_041D79_AA81"
test_scene = dict(sceneid=scene_id, zoom=9, scale=2)

# Tiling offset -> list of stacks; each stack lists tile indices processed together.
slick_tiles = {
    0: [[34, 42, 51, 59, 67, 75]],
    0.33: [[34, 35, 43, 51, 59, 60, 67, 68]],
    0.66: [[35, 43, 51, 60, 68]],
}

In [None]:
scene_id = "S1B_IW_GRDH_1SDV_20201023T170409_20201023T170433_023943_02D81C_C8C1"
test_scene = dict(sceneid=scene_id, zoom=9, scale=2)

# Tiling offset -> list of stacks; each stack lists tile indices processed together.
slick_tiles = {
    0: [[18]],  # tiles 25, 26 are commented out of this stack
    0.33: [[18, 19]],
    0.66: [[12, 19]],
}

In [None]:
scene_id = "S1A_IW_GRDH_1SDV_20240204T184243_20240204T184308_052413_0656A2_1B88"
test_scene = dict(sceneid=scene_id, zoom=9, scale=2)

# Tiling offset -> list of stacks; each stack lists tile indices processed together.
slick_tiles = {0: [[44, 52]], 0.33: [[37, 45]], 0.66: [[37, 45]]}

In [None]:
scene_id = "S1A_IW_GRDH_1SDV_20201009T201613_20201009T201642_034724_040B9A_67CD"
test_scene = dict(sceneid=scene_id, zoom=9, scale=2)

# Tiling offset -> list of stacks; each stack lists tile indices processed together.
# NOTE(review): the 0.33 / 0.66 entries match the 20240204 scene's exactly; possibly
# copy-pasted, so verify they were annotated for this scene.
slick_tiles = {0: [[44, 45]], 0.33: [[37, 45]], 0.66: [[37, 45]]}

In [None]:
# Edge Effect Scenes (alternatives kept for reference):
# scene_id = "S1A_IW_GRDH_1SDV_20240429T215331_20240429T215359_053654_068417_58C8"
# scene_id = "S1A_IW_GRDH_1SDV_20240920T215330_20240920T215358_055754_06CF90_3DCE"
scene_id = "S1A_IW_GRDH_1SDV_20240710T215329_20240710T215357_054704_06A8FA_6F62"
test_scene = dict(sceneid=scene_id, zoom=9, scale=2)

# Tiling offset -> list of stacks; each stack lists tile indices processed together.
slick_tiles = {0: [[13, 21, 29]], 0.33: [[13, 22, 30]], 0.66: [[14, 22, 30]]}

`ensemble = [tiling1, tiling2, tiling3, ...]`

`tiling = [stack1, stack2, stack3, ...]`

`stack = [tile1, tile2, tile3, ...]`

`ensemble_bounds[tiling_idx][stack_idx][tile_idx]` → the bounds of one tile (4 floats)

None of these lists has a fixed length.

In [None]:
model = get_model(model_dict_predefined, model_path_local=MODEL_PATH_LOCAL)

ensemble_out = []
ensemble_bounds = []
ensemble = slick_tiles

for tiling_offset in ensemble.keys():
    titler_client, tile_bounds = await get_titiler_client_and_offset_tiles(
        test_scene, offset=tiling_offset
    )
    tiling_out = []
    tiling_bounds = []
    for s, stack in enumerate(ensemble[tiling_offset]):
        stack_in = []
        vv_stack = []
        stack_bounds = []
        for tile_idx in stack:
            vv = (
                await titler_client.get_offset_tile(
                    scene_id, *tile_bounds[tile_idx], height=512, width=512
                )
            ).transpose(2, 0, 1)[0]
            vv_stack.append(vv)
            input = InferenceInput(
                image=img_array_to_b64_image(
                    np.array([vv] * len(model_dict_predefined["layers"]))
                )
            )
            stack_in.append(input)
            stack_bounds.append(tile_bounds[tile_idx])
        stack_out = model.predict(stack_in)
        tiling_out.append(stack_out)
        tiling_bounds.append(stack_bounds)
    ensemble_bounds.append(tiling_bounds)
    ensemble_out.append(tiling_out)


# ensemble_bounds[tiling_idx][stack_idx][tile_idx] >>> 4 float bound
# ensemble_out[tiling_idx][stack_idx].stack[tile_idx].json_data >>> str

In [None]:
# Per-class score maps for the first tile of the first stack of the first tiling.
# NOTE(review): assumes deserialize(...) yields one 2-D map per class, in cls_map order;
# confirm for the MASKRCNN output format.
class_maps = model.deserialize(ensemble_out[0][0].stack[0].json_data).detach().numpy()
cls_names = model_dict_predefined["cls_map"]
n_maps = len(class_maps)

# One panel per returned map (instead of a hard-coded 4); squeeze=False keeps `axs` 2-D
# even when there is a single panel.
fig, axs = plt.subplots(1, n_maps, figsize=(5 * n_maps, 5), squeeze=False)
for i, (ax, cls) in enumerate(zip(axs[0], class_maps)):
    ax.imshow(cls, cmap="jet", vmin=0, vmax=1)  # fixed 0-1 scale so panels are comparable
    ax.set_title(cls_names.get(i, str(i)))
    ax.axis("off")  # Turn off axis for cleaner look
plt.show()

In [None]:
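# Optional: uncomment to override the model thresholds, then re-run the next cell (which
# rebuilds the model from model_dict_predefined before postprocessing).
# NOTE(review): this override omits "pixel_nms_thresh" and "groundtruth_dice_thresh", which
# both predefined configs set; add them back if postprocessing expects them.
# model_dict_predefined["thresholds"] = {
#     "poly_nms_thresh": 0.2,  # Minimum IoU between instances that will keep the higher scoring multipolygon
#     "bbox_score_thresh": 0.0001,  # Smallest bridge value that will connect polygons into a multipolygon
#     "poly_score_thresh": 0.5,  # Determines the size of the outline of any given polygon
#     "pixel_score_thresh": 0.9,  # Minimum pixel score that will be required to keep a multipolygon
# }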

In [None]:
# Rebuild the model from model_dict_predefined so any edits to it (e.g. the thresholds)
# are picked up without re-running inference.
model = get_model(model_dict_predefined, model_path_local=MODEL_PATH_LOCAL)

# One FeatureCollection per tiling offset, built from that tiling's stacked predictions.
fc_stack = [
    model.postprocess_tileset(tiling_preds, tiling_tile_bounds)
    for tiling_preds, tiling_tile_bounds in zip(ensemble_out, ensemble_bounds)
]

# Ensemble the FCs together. The overlap requirement is derived from the FCs actually
# produced, not from the global `slick_tiles` (which a scene cell run after inference could
# have redefined). With a single tiling, min_overlaps_to_keep is 0.
fc_f = model.nms_feature_reduction(
    features=fc_stack, min_overlaps_to_keep=(1 if len(fc_stack) > 1 else 0)
)

In [None]:
import geopandas as gpd

# PLOT EVERYTHING: one colour per final feature, labelled with its class and confidence.
gdf_f = gpd.GeoDataFrame.from_features(fc_f["features"])

fig, ax = plt.subplots(figsize=(20, 20))
# plt.get_cmap replaces plt.cm.get_cmap, which was removed in Matplotlib 3.9.
# max(..., 1) avoids requesting a zero-entry colormap when there are no features.
colors = plt.get_cmap("jet", max(len(gdf_f), 1))
cls_names = model_dict_predefined["cls_map"]
for feature_idx, feature in gdf_f.iterrows():
    color = colors(feature_idx)
    gdf_f.loc[[feature_idx]].plot(ax=ax, color=color, alpha=0.75)
    # Zero-length line: exists only to create a legend entry in this feature's colour.
    ax.plot(
        [],
        [],
        color=color,
        label=f"{feature_idx} {cls_names[feature['inf_idx']]}: {round(feature['machine_confidence'],2)}",
    )
if len(gdf_f):  # skip the legend (and its "no artists" warning) when nothing was detected
    ax.legend(title="Machine Confidence", bbox_to_anchor=(1.05, 1), loc="upper left")
plt.show()

gdf_f

In [None]:
from rasterio.plot import show
from rasterio.merge import merge as rio_merge

# Stitch the VV tiles of the LAST stack of the LAST tiling from the inference loop.
# That loop re-creates `vv_stack` for every stack, so only this stack's tiles remain in
# memory. The final ensembled features are overlaid on top.
last_stack_bounds = tiling_bounds[-1]
# zip() would silently drop tiles if these ever got out of sync.
assert len(vv_stack) == len(last_stack_bounds), "vv_stack / bounds length mismatch"

ds_tiles = [
    memfile_gtiff(nparray=array, bounds=bounds).open()
    for array, bounds in zip(vv_stack, last_stack_bounds)
]
try:
    # Merge datasets
    scene_array, transform = rio_merge(ds_tiles)
finally:
    # The merged array is held in memory, so the opened datasets can be released.
    for ds in ds_tiles:
        ds.close()

# Plotting the merged image
fig, ax = plt.subplots(figsize=(20, 20))  # You can adjust the size as needed
show(scene_array[0], transform=transform, ax=ax)  # first band of the merged array
gdf_f.plot(ax=ax, alpha=0.3, edgecolor="red", facecolor="none")  # outlines only
ax.set_title("Stitched Image (last stack of last tiling)")
plt.show()