In [None]:
%load_ext autoreload
%autoreload 2

import os
import sys
from unittest.mock import AsyncMock, Mock, patch

import geopandas as gpd
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from dotenv import load_dotenv
from matplotlib.colors import ListedColormap

# Required configuration, read from the environment after load_dotenv() populates it
# from your .env file. Restart the kernel to pick up edits to .env.
load_dotenv()

REQUIRED_ENV_VARS = ("GIT_FOLDER", "TITILER_URL", "TITILER_API_KEY", "API_KEY", "MODEL_PATH_LOCAL")
missing_env_vars = [name for name in REQUIRED_ENV_VARS if not os.getenv(name)]
if missing_env_vars:
    # Check everything up front: indexing os.environ directly raises a bare KeyError on
    # the first missing name, so a later "are they all set?" check never gets to report it.
    raise RuntimeError(
        "ERROR: Failed to find all the necessary environment variables: "
        f"{', '.join(missing_env_vars)}. "
        "Note, you must restart the kernel if you want to load new environment variables."
    )

GIT_FOLDER = os.environ["GIT_FOLDER"]
TITILER_URL = os.environ["TITILER_URL"]
TITILER_API_KEY = os.environ["TITILER_API_KEY"]
API_KEY = os.environ["API_KEY"]
MODEL_PATH_LOCAL = os.environ["MODEL_PATH_LOCAL"]

# Put the repository checkout on sys.path, presumably so the `cerulean_cloud` imports
# in the next cell resolve.
if GIT_FOLDER not in sys.path:
    sys.path.append(GIT_FOLDER)
print(sys.path)

In [None]:
from cerulean_cloud.models import get_model
from cerulean_cloud.tiling import TMS, offset_bounds_from_base_tiles
from cerulean_cloud.titiler_client import TitilerClient
from cerulean_cloud.cloud_run_orchestrator.clients import img_array_to_b64_image
from cerulean_cloud.cloud_run_orchestrator.schema import OrchestratorInput
from cerulean_cloud.cloud_run_orchestrator.handler import _orchestrate, get_tiler, get_titiler_client, get_roda_sentinelhub_client, get_database_engine
from cerulean_cloud.cloud_run_offset_tiles.schema import InferenceInput, PredictPayload
from cerulean_cloud.cloud_run_offset_tiles.handler import predict

In [None]:
# Model configurations passed to get_model() further down. Exactly one of them is
# selected as `model_dict_predefined` at the bottom of this cell.
#
# NOTE(review): FASTAIUNET's cls_map is keyed by int while MASKRCNN's is keyed by str;
# any code that indexes cls_map with an integer inference index only works for the former.

fastaiunet = dict(
    type="FASTAIUNET",
    # Previously used checkpoint:
    # "experiments/2024_03_06_18_14_31_7cls_rn101_pr256_z9_fastai_baseline_noamb/tracing_cpu_model.pt"
    file_path="",
    layers=["VV"],
    # inference_idx -> class table name
    cls_map={
        0: "BACKGROUND",
        1: "INFRA",
        2: "NATURAL",
        3: "COIN_VESSEL",
        4: "REC_VESSEL",
        5: "OLD_VESSEL",
        6: "BACKGROUND",  # HITL AMBIGUOUS, should never be output by inference_idx
    },
    name="ResNet101 Baseline Noamb",
    tile_width_m=40844,  # used to calculate zoom
    tile_width_px=256,  # used to calculate scale
    epochs=80,
    thresholds=dict(
        pixel_nms_thresh=0.4,
        bbox_score_thresh=0.1,
        poly_score_thresh=0.01,  # TODO(review): confirm this threshold behaves as intended
        pixel_score_thresh=0.35,
        groundtruth_dice_thresh=0.0,
    ),
    backbone_size=101,
    zoom_level=9,
    scale=2,
    # pixel_f1=0.0,  # TODO calculate
    # instance_f1=0.0,  # TODO calculate
)

maskrcnn = dict(
    type="MASKRCNN",
    # Previously used checkpoint:
    # "experiments/2023_10_05_02_22_46_4cls_rnxt101_pr512_px1024_680min_maskrcnn_wd01/scripting_cpu_model.pt"
    file_path="",
    layers=["VV", "INFRA", "VESSEL"],
    # inference_idx -> class table name (string keys, unlike fastaiunet above)
    cls_map={"0": "BACKGROUND", "1": "INFRA", "2": "NATURAL", "3": "VESSEL"},
    name="ResNext 101 hires56",
    tile_width_m=40844,  # used to calculate zoom
    tile_width_px=512,  # used to calculate scale
    epochs=122,
    thresholds=dict(
        pixel_nms_thresh=0.4,
        bbox_score_thresh=0.2,
        poly_score_thresh=0.2,
        pixel_score_thresh=0.2,
        groundtruth_dice_thresh=0.0,
    ),
    backbone_size=101,
    zoom_level=9,
    scale=2,
    # pixel_f1=0.461,  # TODO calculate
    # instance_f1=0.47,  # TODO calculate
)

# The configuration the rest of the notebook runs with.
model_dict_predefined = fastaiunet

In [None]:
# Disabled end-to-end dry run of the orchestrator (_orchestrate) for one scene.
# Several DatabaseClient methods are patched with in-memory stubs, and
# httpx.AsyncClient.post is patched so every POST is answered by calling the local
# `predict` handler directly instead of a remote service. Uncomment the block to run it.
# sceneid = "S1A_IW_GRDH_1SDV_20240204T184243_20240204T184308_052413_0656A2_1B88"
# payload = OrchestratorInput(sceneid=sceneid)

# async def mock_post(_, url: str, **kwargs) -> AsyncMock:
#     response = Mock()
#     response.status_code = 200
#     response.json = Mock(return_value=predict(request=None, payload=PredictPayload(**kwargs['json']))[0].dict())
#     return response

# async def mock_get_db_model(_, model_name: str):
#     class MockModel:
#         __table__ = type('MockTable', (), {'columns': [type('MockColumn', (), {'name': name}) for name in model_dict_predefined.keys()]})
#         def __init__(self, model_dict):
#             for key, value in model_dict.items():
#                 setattr(self, key, value)
#     return MockModel(model_dict_predefined)

# async def mock_get_trigger(*args, **kwargs):
#     return "mock_trigger"

# class MockLayer:
#     def __init__(self, short_name):
#         self.short_name = short_name

# async def mock_get_layer(self, layer, **kwargs):
#     return MockLayer(short_name=layer)

# async def mock_get_sentinel1_grd(*args, **kwargs):
#     return "mock_sentinel1_grd"

# async def mock_deactivate_stale_slicks_from_scene_id(*args, **kwargs):
#     return 0

# class MockOrchestratorRun:
#     def __init__(self):
#         self.success = True

# async def mock_add_orchestrator(*args, **kwargs):
#     return MockOrchestratorRun()

# with\
#     patch('cerulean_cloud.database_client.DatabaseClient.get_db_model', new=mock_get_db_model), \
#     patch('httpx.AsyncClient.post', new=mock_post), \
#     patch('cerulean_cloud.database_client.DatabaseClient.get_trigger', new=mock_get_trigger), \
#     patch('cerulean_cloud.database_client.DatabaseClient.get_layer', new=mock_get_layer), \
#     patch('cerulean_cloud.database_client.DatabaseClient.get_sentinel1_grd', new=mock_get_sentinel1_grd), \
#     patch('cerulean_cloud.database_client.DatabaseClient.deactivate_stale_slicks_from_scene_id', new=mock_deactivate_stale_slicks_from_scene_id), \
#     patch('cerulean_cloud.database_client.DatabaseClient.add_orchestrator', new=mock_add_orchestrator):

#     response = await _orchestrate(
#         payload, 
#         get_tiler(), 
#         get_titiler_client(), 
#         get_roda_sentinelhub_client(), 
#         get_database_engine()
#     )
# print(response)


In [None]:
async def get_titiler_client_and_offset_tiles(sentinel_scene, offset=.33):
    """Create a TiTiler client and compute offset tile bounds for one scene.

    Args:
        sentinel_scene: mapping unpacked into ``OrchestratorInput`` (e.g. with
            ``sceneid``, ``zoom`` and ``scale`` keys).
        offset: forwarded to ``offset_bounds_from_base_tiles`` as ``offset_amount``.

    Returns:
        A ``(titiler_client, offset_tile_bounds)`` tuple: the client pointed at the
        ``TITILER_URL`` environment variable, and the result of
        ``offset_bounds_from_base_tiles`` for the scene's base tiles at ``zoom``.
    """
    scene_input = OrchestratorInput(**sentinel_scene)
    client = TitilerClient(url=os.getenv('TITILER_URL'))
    scene_bounds = await client.get_bounds(scene_input.sceneid)

    # Every TMS tile intersecting the scene bounds at the requested zoom level.
    base_tiles = [*TMS.tiles(*scene_bounds, [scene_input.zoom], truncate=False)]
    return client, offset_bounds_from_base_tiles(base_tiles, offset_amount=offset)

In [None]:
scene_id = "S1A_IW_GRDH_1SDV_20240204T184243_20240204T184308_052413_0656A2_1B88"
test_scene = {"sceneid": scene_id , "zoom":9, "scale":2}

titler_client , tile_bounds =  await get_titiler_client_and_offset_tiles(test_scene,offset=.66)

example_tile_37 = tile_bounds[37] # 37 and 45
example_tile_45 = tile_bounds[45] # 37 and 45

vv_37 = (await titler_client.get_offset_tile(scene_id, *example_tile_37,height=512,width=512)).transpose(2,0,1)[0]
vv_45 = (await titler_client.get_offset_tile(scene_id, *example_tile_45,height=512,width=512)).transpose(2,0,1)[0]

plt.imshow(vv_37, cmap="gray")

In [None]:
# One InferenceInput per example tile. The single VV band is repeated once per model
# input layer. NOTE(review): for configs with extra layers (e.g. INFRA/VESSEL) those
# channels are therefore VV copies — a placeholder, not real auxiliary data.
n_layers = len(model_dict_predefined["layers"])
inf_stack = [
    InferenceInput(image=img_array_to_b64_image(np.array([band] * n_layers)), bounds=bounds)
    for band, bounds in ((vv_37, example_tile_37), (vv_45, example_tile_45))
]

model = get_model(model_dict_predefined, model_path_local=MODEL_PATH_LOCAL)

In [None]:
# Run the model on both example tiles; `out` feeds postprocess_tileset() and stitch() below.
out = model.predict(inf_stack)

In [None]:
# Disabled debug view of the raw per-pixel class predictions (argmax over channels).
# Only applies if models.py has been edited so predict() also returns raw_preds.
# NOTE(review): inf_stack holds two tiles, so `len(out)==2` may also be true for an
#   unmodified predict() that returns one result per tile — confirm before relying on it.
# NOTE(review): `colors[i]` indexes a numpy array with cls_map keys, which only works
#   for int-keyed cls_maps (it would fail for a str-keyed config such as MASKRCNN).
# if len(out)==2:
#     # i.e. If you have edited models.py to return inference_result_stack, raw_preds

#     inference_result_stack, raw_preds = out

#     # Take softmax along the channel dimension (dim=1)
#     softmax_tensor = F.softmax(raw_preds, dim=1)
#     # Take the argmax along the channel dimension to get the predicted classes
#     argmax_tensor = torch.argmax(softmax_tensor, dim=1)
#     # Convert tensor to numpy array for visualization
#     argmax_numpy = argmax_tensor.squeeze().numpy()

#     cls_map = model_dict_predefined["cls_map"]
#     # Generate colormap and norm for visualization
#     colors = plt.cm.tab20(np.linspace(0, 1, len(cls_map)))
#     cmap, norm = ListedColormap(colors), plt.Normalize(vmin=-0.5, vmax=len(cls_map)-0.5)

#     # Plot the argmax results as an image
#     plt.figure(figsize=(10, 10))
#     plt.imshow(argmax_numpy, cmap=cmap, norm=norm)

#     # Create and display custom legend
#     handles = [mpatches.Patch(color=colors[i], label=cls_map[i]) for i in cls_map]
#     plt.legend(handles=handles, bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0., fontsize='large')
#     plt.title("Class Prediction", fontsize=16)
#     plt.show()


In [None]:
# Re-create `model` with the same arguments as the earlier model-loading cell.
# NOTE(review): presumably kept so edits to cerulean_cloud/models.py (picked up via
# %autoreload 2) take effect before post-processing; otherwise this is redundant.
model = get_model(model_dict_predefined,model_path_local=MODEL_PATH_LOCAL)

In [None]:
# Turn the tile predictions into a GeoJSON FeatureCollection and plot the polygons.
# (geopandas is imported with the other libraries at the top of the notebook.)
gj = model.postprocess_tileset([out])
gdf = gpd.GeoDataFrame.from_features(gj["features"])
ax = gdf.plot()
ax.set_title(f"Post-processed polygons ({len(gdf)} features)");

In [None]:
# Overlay each predicted polygon on the VV image of tile 37, one figure per polygon.
# imshow gets an `extent` so image pixels and polygon coordinates share one coordinate
# system (plain imshow draws in pixel units, which georeferenced polygons never match).
# NOTE(review): assumes example_tile_37 is (minx, miny, maxx, maxy) in the same CRS as
# the GeoJSON features — confirm against offset_bounds_from_base_tiles.
minx, miny, maxx, maxy = example_tile_37
cls_map = model_dict_predefined["cls_map"]


def class_name(inf_idx):
    """Return the class label for `inf_idx`, whether cls_map is keyed by int or by str."""
    if inf_idx in cls_map:
        return cls_map[inf_idx]
    return cls_map.get(str(inf_idx), f"unknown ({inf_idx})")


for index, row in gdf.iterrows():
    # New figure per polygon: a single figure created before the loop only gave the
    # first plt.show() the requested size; later ones fell back to the default.
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(vv_37, cmap="gray", extent=(minx, maxx, miny, maxy))
    # .loc with the iterrows label (rather than .iloc) stays correct for any index.
    gdf.loc[[index]].plot(ax=ax, alpha=0.5, edgecolor="red", facecolor="none")
    ax.set_title(f'Class: {class_name(row["inf_idx"])}')
    plt.show()

In [None]:
# Stitch the tile predictions together and show the highest-scoring index per pixel.
# NOTE(review): assumes the first value from stitch() has the class/channel axis first
# and the second is georeferencing metadata (unused here) — confirm in models.py.
stitched_scores, stitched_meta = model.stitch(out)
argmax_indices = np.argmax(stitched_scores, axis=0)

fig, ax = plt.subplots()
im = ax.imshow(argmax_indices, cmap="viridis")
fig.colorbar(im, ax=ax, label="argmax index")
ax.set_title("Stitched per-pixel argmax");