# Imports

### Externals

In [None]:
import matplotlib.pyplot as plt
import torch
import numpy as np
import random
import os
import rasterio

from pathlib import Path
from pytorch_lightning import seed_everything

### Internal modules

In [None]:
from asm_mapping.data.planetscope_dataset import PlanetScopeDataset
from asm_mapping.data.sentinel1_dataset import Sentinel1Dataset
from asm_mapping.data.fusion_dataset import FusionDataset
from asm_mapping.data.dataset_mode import DatasetMode
from asm_mapping.models.lit_model_standalone import LitModelStandalone
from asm_mapping.models.lit_model_lf import LitModelLateFusion

## Configs

In [None]:
# seeds
# A fixed seed keeps the run reproducible; in particular the plotting helpers
# fall back to `random.sample` when no explicit indices are given.
# `workers=True` asks Lightning to also derive seeds for dataloader workers.
RANDOM = 79
seed_everything(RANDOM, workers=True)

In [None]:
GPU_ID = 0

# Choose the compute device: the requested GPU when CUDA is usable,
# otherwise fall back to the CPU.
if not torch.cuda.is_available():
    device = torch.device("cpu")
    print("CUDA not available, using CPU")
else:
    device = torch.device(f"cuda:{GPU_ID}")
    print(f"Using GPU #{GPU_ID}: {torch.cuda.get_device_name(GPU_ID)}")

# Uncomment the line below if a GPU is available but already fully occupied
# and you prefer to generate the prediction examples on the CPU instead.
# device = torch.device("cpu")

In [None]:
# Project root that holds the `checkpoints`, `old_checkpoints` and `data`
# folders. Defaults to the original server location; set the
# ASM_MAPPING_BASE_DIR environment variable to run the notebook on another
# machine without editing this cell (an empty value falls back to the default).
BASE_DIR = (
    os.environ.get("ASM_MAPPING_BASE_DIR")
    or "/mnt/guanabana/raid/home/pasan001/asm-mapping"
)

In [None]:
# Checkpoints used for the three prediction galleries (PlanetScope standalone,
# Sentinel-1 standalone, late fusion). Each file name encodes the split and
# the epoch of the checkpoint.
# NOTE(review): every file name ends in "val_f1_score=0.000" - presumably the
# metric was not filled into the filename template when saving, rather than a
# real F1 of zero; confirm against the training logs.
PS_MODEL_PATH = f"{BASE_DIR}/old_checkpoints/split_0/ps_standalone_split_0_epoch=31_val_f1_score=0.000.ckpt"
S1_MODEL_PATH = f"{BASE_DIR}/checkpoints/split_2/s1_standalone_split_2_epoch=22_val_f1_score=0.000.ckpt"
# Alternative late-fusion checkpoint, kept for quick switching:
# LF_MODEL_PATH = f"{BASE_DIR}/checkpoints/split_4/lf_conc_up_split_4_epoch=63_val_f1_score=0.000.ckpt"
LF_MODEL_PATH = f"{BASE_DIR}/old_checkpoints/split_4/lf_sum_down_split_4_epoch=76_val_f1_score=0.000.ckpt"



In [None]:
# Test-set locations. The two standalone datasets point directly at the
# split_0 testing folders; the fusion dataset is given the data root instead.
# NOTE(review): the data come from split_0 while the Sentinel-1 and late-fusion
# checkpoints are taken from split_2 and split_4 - verify that these test tiles
# were not part of those models' training sets.
PS_DATA_PATH = f"{BASE_DIR}/data/ps_split/split_0/testing_set"
S1_DATA_PATH = f"{BASE_DIR}/data/s1_split/split_0/testing_set"
FUSION_DATA_PATH = f"{BASE_DIR}/data"
# Dataset indices of the examples to plot; the same list is passed to all
# three plotting helpers. The commented line is an earlier selection.
# fixed_indices = [87, 1, 111] # indices for 3 examples to generate
fixed_indices = [24, 76, 200]

In [None]:
# dataset parameters
# PAD and TRANSFORMS are forwarded unchanged as `pad=` / `transforms=` to every
# dataset constructor in this notebook.
# NOTE(review): TRANSFORMS = None presumably disables augmentation - confirm
# against the dataset classes.
PAD = False
TRANSFORMS = None
# Mode handed to the single-sensor (PlanetScope / Sentinel-1) datasets.
STANDALONE_MODE = DatasetMode.STANDALONE
# Defined for completeness; not referenced by the cells visible in this notebook.
FUSION_MODE = DatasetMode.FUSION

# Utility functions

In [None]:
def load_standalone_model(checkpoint_path, in_channels=6):
    """Restore a standalone (single-sensor) model from a Lightning checkpoint.

    Args:
        checkpoint_path: Path of the ``.ckpt`` file to load.
        in_channels: Number of input bands, forwarded to
            ``LitModelStandalone.load_from_checkpoint`` (this notebook uses
            6 for PlanetScope and 3 for Sentinel-1).

    Returns:
        The restored ``LitModelStandalone``, moved to the notebook-level
        ``device`` and switched to evaluation mode.
    """
    # Selecting a CUDA device is only valid when we actually run on the GPU;
    # calling it unconditionally raises on CPU-only machines and would make
    # the notebook's CPU fallback unusable.
    if device.type == "cuda":
        torch.cuda.set_device(device)

    # load_from_checkpoint reads the file itself, so a separate torch.load
    # (which only doubled the I/O and memory use) is not needed.
    model = LitModelStandalone.load_from_checkpoint(
        checkpoint_path,
        in_channels=in_channels,
        map_location=device,
    )
    # Guarantee the weights live on the device the prediction helpers target.
    model.to(device)
    model.eval()
    return model

def load_fusion_model(checkpoint_path):
    """Restore a late-fusion model from a Lightning checkpoint.

    Args:
        checkpoint_path: Path of the ``.ckpt`` file to load.

    Returns:
        The restored ``LitModelLateFusion``, moved to the notebook-level
        ``device`` and switched to evaluation mode.
    """
    # Selecting a CUDA device is only valid when we actually run on the GPU;
    # calling it unconditionally raises on CPU-only machines and would make
    # the notebook's CPU fallback unusable.
    if device.type == "cuda":
        torch.cuda.set_device(device)

    # load_from_checkpoint reads the file itself, so a separate torch.load
    # (which only doubled the I/O and memory use) is not needed.
    model = LitModelLateFusion.load_from_checkpoint(
        checkpoint_path,
        map_location=device,
    )
    # Guarantee the weights live on the device the prediction helpers target.
    model.to(device)
    model.eval()
    return model

def predict_standalone(model, img_tensor):
    """Run one image through a standalone model and return its binary mask.

    Args:
        model: Callable model that returns logits and exposes a ``threshold``
            attribute used to binarise the sigmoid probabilities.
        img_tensor: Tensor for a single image without a batch dimension; a
            batch axis is added in front before the forward pass.

    Returns:
        NumPy array of 0.0 / 1.0 floats with all singleton dimensions removed.
    """
    # Send the input to wherever the model's weights actually are instead of
    # relying on a notebook-level global that may have been changed since the
    # model was loaded. Parameter-free models keep the input where it is.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else img_tensor.device

    with torch.no_grad():
        batch = img_tensor.unsqueeze(0).to(target_device)
        probabilities = torch.sigmoid(model(batch))
        mask = (probabilities > model.threshold).float()
    return mask.squeeze().cpu().numpy()

def predict_fusion(model, planet_tensor, s1_tensor):
    """Run one PlanetScope / Sentinel-1 pair through a fusion model.

    Args:
        model: Callable model taking ``(planet_batch, s1_batch)``, returning
            logits and exposing a ``threshold`` attribute used to binarise
            the sigmoid probabilities.
        planet_tensor: PlanetScope tensor for one sample, no batch dimension.
        s1_tensor: Sentinel-1 tensor for the same sample, no batch dimension.

    Returns:
        NumPy array of 0.0 / 1.0 floats with all singleton dimensions removed.
    """
    # Send both inputs to wherever the model's weights actually are instead of
    # relying on a notebook-level global that may have been changed since the
    # model was loaded. Parameter-free models follow the PlanetScope input.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else planet_tensor.device

    with torch.no_grad():
        planet_batch = planet_tensor.unsqueeze(0).to(target_device)
        s1_batch = s1_tensor.unsqueeze(0).to(target_device)
        probabilities = torch.sigmoid(model(planet_batch, s1_batch))
        mask = (probabilities > model.threshold).float()
    return mask.squeeze().cpu().numpy()

# PlanetScope standalone model prediction examples

In [None]:
def get_raw_rgb(dataset, idx):
    """Build a display composite from the raw raster behind a dataset entry.

    The image path is taken from ``dataset.dataset[idx][0]`` and read with
    rasterio. Source bands 2, 1 and 0 fill output channels 0, 1 and 2
    (treated as R, G, B), each min-max scaled to [0, 1] on its own.

    Args:
        dataset: Dataset object whose ``dataset`` attribute lists the samples,
            with the image path as first element of each entry.
        idx: Index of the sample to visualise.

    Returns:
        Array of shape (height, width, 3) with values in [0, 1]; a band
        without any value spread is rendered as zeros.
    """
    tile_path = dataset.dataset[idx][0]

    with rasterio.open(tile_path, 'r') as reader:
        raster = reader.read().astype(np.float32)

    composite = np.zeros((*raster.shape[1:3], 3))
    for channel, source_band in enumerate((2, 1, 0)):
        values = raster[source_band]
        lo, hi = values.min(), values.max()
        if not hi > lo:
            # Flat band: nothing to stretch, keep the channel at zero.
            composite[..., channel] = 0
            continue
        composite[..., channel] = np.clip((values - lo) / (hi - lo), 0, 1)

    return composite

def plot_ps_predictions(model, dataset, indices=None, num_examples=3):
    """Plot PlanetScope RGB, model prediction and ground truth side by side.

    Args:
        model: Standalone model passed to ``predict_standalone``.
        dataset: Dataset yielding ``(image_tensor, ground_truth_tensor)`` per
            index and usable with ``get_raw_rgb``.
        indices: Dataset indices to plot. When ``None``, ``num_examples``
            indices are drawn at random.
        num_examples: Number of random examples; only used when ``indices``
            is ``None``.

    Returns:
        None. The figure is shown with ``plt.show()``.
    """
    if indices is None:
        indices = random.sample(range(len(dataset)), num_examples)
    indices = list(indices)
    if not indices:
        # Nothing to draw; a zero-row subplot grid would raise.
        return

    # Size the grid from the examples actually plotted (not from
    # `num_examples`, which may disagree with an explicit `indices` list) and
    # keep `axs` two-dimensional even when there is only one row.
    n_rows = len(indices)
    _, axs = plt.subplots(n_rows, 3, figsize=(15, n_rows * 4), squeeze=False)

    for example_no, (row_axes, idx) in enumerate(zip(axs, indices), start=1):
        img_tensor, gt_tensor = dataset[idx]
        pred = predict_standalone(model, img_tensor)
        rgb = get_raw_rgb(dataset, idx)

        panels = (
            (rgb, f"PlanetScope RGB - Example {example_no}", None),
            (pred, "Model Prediction", 'gray'),
            (gt_tensor.numpy(), "Ground Truth", 'gray'),
        )
        for ax, (image, title, cmap) in zip(row_axes, panels):
            ax.imshow(image, cmap=cmap)
            ax.set_title(title)
            ax.axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
# PlanetScope test tiles of split_0, served in standalone mode.
ps_dataset = PlanetScopeDataset(
    data_dir=PS_DATA_PATH,
    split="split_0",
    mode=STANDALONE_MODE,
    transforms=TRANSFORMS,
    pad=PAD,
)

# Restore the 6-band PlanetScope checkpoint. A failure is reported rather than
# raised so later cells can still run; the plotting cell checks for None.
try:
    ps_model = load_standalone_model(PS_MODEL_PATH, in_channels=6)
except Exception as load_error:
    print(f"Error loading PlanetScope model: {load_error}")
    ps_model = None
else:
    print("PlanetScope model loaded successfully")

In [None]:
# Plot only when the checkpoint loaded; ps_model is None after a failed load.
if ps_model is not None:
    print("\n## PlanetScope Standalone model predictions")
    plot_ps_predictions(ps_model, ps_dataset, indices=fixed_indices)

# Sentinel-1 standalone model prediction examples

In [None]:
def get_s1_rgb(dataset, idx):
    """Build a false-colour composite from the raw Sentinel-1 raster.

    The image path is taken from ``dataset.dataset[idx][0]`` and read with
    rasterio. Band 0 (VV), band 1 (VH) and their difference VV - VH fill
    output channels 0, 1 and 2, each min-max scaled to [0, 1] on its own.
    The difference corresponds to the VV/VH ratio if the bands are stored
    in dB (assumed, not checked here).

    Args:
        dataset: Dataset object whose ``dataset`` attribute lists the samples,
            with the image path as first element of each entry.
        idx: Index of the sample to visualise.

    Returns:
        Array of shape (height, width, 3) with values in [0, 1]; a layer
        without any value spread is rendered as zeros.
    """
    tile_path = dataset.dataset[idx][0]

    with rasterio.open(tile_path, 'r') as reader:
        raster = reader.read().astype(np.float32)

    co_pol, cross_pol = raster[0], raster[1]
    layers = (co_pol, cross_pol, co_pol - cross_pol)

    composite = np.zeros((*raster.shape[1:3], 3))
    for channel, layer in enumerate(layers):
        lo, hi = layer.min(), layer.max()
        if not hi > lo:
            # Flat layer: nothing to stretch, keep the channel at zero.
            composite[..., channel] = 0
            continue
        composite[..., channel] = np.clip((layer - lo) / (hi - lo), 0, 1)

    return composite

def plot_s1_predictions(model, dataset, indices=None, num_examples=3):
    """Plot Sentinel-1 composite, model prediction and ground truth per example.

    Args:
        model: Standalone model passed to ``predict_standalone``.
        dataset: Dataset yielding ``(image_tensor, ground_truth_tensor)`` per
            index and usable with ``get_s1_rgb``.
        indices: Dataset indices to plot. When ``None``, ``num_examples``
            indices are drawn at random.
        num_examples: Number of random examples; only used when ``indices``
            is ``None``.

    Returns:
        None. The figure is shown with ``plt.show()``.
    """
    if indices is None:
        indices = random.sample(range(len(dataset)), num_examples)
    indices = list(indices)
    if not indices:
        # Nothing to draw; a zero-row subplot grid would raise.
        return

    # Size the grid from the examples actually plotted (not from
    # `num_examples`, which may disagree with an explicit `indices` list) and
    # keep `axs` two-dimensional even when there is only one row.
    n_rows = len(indices)
    _, axs = plt.subplots(n_rows, 3, figsize=(15, n_rows * 4), squeeze=False)

    for example_no, (row_axes, idx) in enumerate(zip(axs, indices), start=1):
        img_tensor, gt_tensor = dataset[idx]
        pred = predict_standalone(model, img_tensor)
        rgb = get_s1_rgb(dataset, idx)

        panels = (
            (rgb, f"Sentinel-1 RGB - Example {example_no}", None),
            (pred, "Model prediction", 'gray'),
            (gt_tensor.numpy(), "Ground truth", 'gray'),
        )
        for ax, (image, title, cmap) in zip(row_axes, panels):
            ax.imshow(image, cmap=cmap)
            ax.set_title(title)
            ax.axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
# Sentinel-1 test tiles of split_0, served in standalone mode.
s1_dataset = Sentinel1Dataset(
    data_dir=S1_DATA_PATH,
    split="split_0",
    mode=STANDALONE_MODE,
    transforms=TRANSFORMS,
    pad=PAD,
)

# Restore the 3-band Sentinel-1 checkpoint. A failure is reported rather than
# raised so later cells can still run; the plotting cell checks for None.
try:
    s1_model = load_standalone_model(S1_MODEL_PATH, in_channels=3)
except Exception as load_error:
    print(f"Error loading Sentinel-1 model: {load_error}")
    s1_model = None
else:
    print("Sentinel-1 model loaded successfully")

In [None]:
# Plot only when the checkpoint loaded; s1_model is None after a failed load.
if s1_model is not None:
    print("\n## Sentinel-1 Standalone model predictions")
    plot_s1_predictions(s1_model, s1_dataset, indices=fixed_indices)

# Late Fusion model prediction examples

In [None]:
def _minmax_composite(bands, band_order):
    """Stack the chosen bands into an (H, W, len(band_order)) image, each min-max scaled to [0, 1]."""
    composite = np.zeros((bands.shape[1], bands.shape[2], len(band_order)))
    for channel, band_idx in enumerate(band_order):
        band = bands[band_idx]
        lo, hi = band.min(), band.max()
        # A constant band has no range to stretch; leaving it at zero avoids
        # the division by zero (NaN image) and matches the standalone helpers.
        if hi > lo:
            composite[:, :, channel] = np.clip((band - lo) / (hi - lo), 0, 1)
    return composite


def plot_fusion_predictions(model, dataset, indices=None, num_examples=3):
    """Plot both inputs, the late-fusion prediction and the ground truth.

    Each row shows the PlanetScope composite, the Sentinel-1 composite, the
    model prediction and the ground-truth mask for one example. Both
    composites are built from the tensors returned by the dataset.

    Args:
        model: Fusion model passed to ``predict_fusion``.
        dataset: Dataset yielding ``(planet_tensor, s1_tensor, gt_tensor)``
            per index.
        indices: Dataset indices to plot. When ``None``, ``num_examples``
            indices are drawn at random.
        num_examples: Number of random examples; only used when ``indices``
            is ``None``.

    Returns:
        None. The figure is shown with ``plt.show()``.
    """
    if indices is None:
        indices = random.sample(range(len(dataset)), num_examples)
    indices = list(indices)
    if not indices:
        # Nothing to draw; a zero-row subplot grid would raise.
        return

    # Size the grid from the examples actually plotted (not from
    # `num_examples`, which may disagree with an explicit `indices` list) and
    # keep `axs` two-dimensional even when there is only one row.
    n_rows = len(indices)
    _, axs = plt.subplots(n_rows, 4, figsize=(20, n_rows * 4), squeeze=False)

    for example_no, (row_axes, idx) in enumerate(zip(axs, indices), start=1):
        planet_tensor, s1_tensor, gt_tensor = dataset[idx]
        pred = predict_fusion(model, planet_tensor, s1_tensor)

        # Bands 2, 1, 0 of the PlanetScope tensor are shown as R, G, B.
        planet_rgb = _minmax_composite(planet_tensor.numpy(), (2, 1, 0))
        # Sentinel-1 bands 0, 1, 2 (VV, VH, ratio) are shown as R, G, B.
        s1_composite = _minmax_composite(s1_tensor.numpy(), (0, 1, 2))

        panels = (
            (planet_rgb, f"PlanetScope RGB - Example {example_no}", None),
            (s1_composite, "Sentinel-1 RGB", None),
            (pred, "Model prediction", 'gray'),
            (gt_tensor.numpy(), "Ground truth", 'gray'),
        )
        for ax, (image, title, cmap) in zip(row_axes, panels):
            ax.imshow(image, cmap=cmap)
            ax.set_title(title)
            ax.axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
# Paired PlanetScope / Sentinel-1 test samples of split 0.
fusion_dataset = FusionDataset(
    data_dir=FUSION_DATA_PATH,
    split=0,
    is_test=True,
    transforms=TRANSFORMS,
    pad=PAD,
)

# Restore the late-fusion checkpoint. A failure is reported rather than
# raised so later cells can still run; the plotting cell checks for None.
try:
    lf_model = load_fusion_model(LF_MODEL_PATH)
except Exception as load_error:
    print(f"Error loading Late Fusion model: {load_error}")
    lf_model = None
else:
    print("Late Fusion model loaded successfully")


In [None]:
# Plot only when the checkpoint loaded; lf_model is None after a failed load.
if lf_model is not None:
    print("\n## Late Fusion model predictions")
    plot_fusion_predictions(lf_model, fusion_dataset, indices=fixed_indices)