# Imports

### Externals

In [None]:
import matplotlib.pyplot as plt
import torch
import numpy as np
import random
import os
import rasterio

from pathlib import Path
from pytorch_lightning import seed_everything

### Internal modules

In [None]:
from asm_mapping.data.planetscope_dataset import PlanetScopeDataset
from asm_mapping.data.sentinel1_dataset import Sentinel1Dataset
from asm_mapping.data.fusion_dataset import FusionDataset
from asm_mapping.data.dataset_mode import DatasetMode
from asm_mapping.models.lit_model_standalone import LitModelStandalone
from asm_mapping.models.lit_model_lf import LitModelLateFusion

## Configs

In [None]:
# Global seed for the notebook. Lightning's seed_everything seeds Python's
# `random`, NumPy and torch; `workers=True` also derives DataLoader worker seeds.
# This keeps e.g. `random.sample` in plot_ps_predictions reproducible.
RANDOM = 79
seed_everything(seed=RANDOM, workers=True)

In [None]:
GPU_ID = 3

# Select the compute device. GPU_ID is a fixed index on a multi-GPU host; if
# CUDA is missing, or that index does not exist on this machine, fall back to
# the CPU instead of crashing in torch.cuda.get_device_name(GPU_ID).
if torch.cuda.is_available() and 0 <= GPU_ID < torch.cuda.device_count():
    device = torch.device(f"cuda:{GPU_ID}")
    print(f"Using GPU #{GPU_ID}: {torch.cuda.get_device_name(GPU_ID)}")
elif torch.cuda.is_available():
    device = torch.device("cpu")
    print(f"GPU #{GPU_ID} not found ({torch.cuda.device_count()} visible), using CPU")
else:
    device = torch.device("cpu")
    print("CUDA not available, using CPU")

In [None]:
# Project root that every data/checkpoint path below is built from. Defaults
# to a hard-coded location on the shared server; set ASM_MAPPING_BASE_DIR to
# run the notebook from a different checkout.
BASE_DIR = os.environ.get(
    "ASM_MAPPING_BASE_DIR", "/mnt/guanabana/raid/home/pasan001/asm-mapping"
)

In [None]:
# Lightning checkpoint of the PlanetScope standalone model for split_0.
# NOTE(review): the file name records val_f1_score=0.000 — confirm this is the
# checkpoint you intend to visualise and not a degenerate/early save.
PS_MODEL_PATH = (
    f"{BASE_DIR}/checkpoints/split_0/"
    "ps_standalone_split_0_epoch=31_val_f1_score=0.000.ckpt"
)

In [None]:
# Directory of PlanetScope testing tiles for split_0.
PS_DATA_PATH = (
    f"{BASE_DIR}/data/ps_split/"
    "split_0/testing_set"
)

In [None]:
# Dataset construction options: no padding, no transforms, and the two
# DatasetMode values used by the standalone and fusion datasets.
PAD, TRANSFORMS = False, None
STANDALONE_MODE, FUSION_MODE = DatasetMode.STANDALONE, DatasetMode.FUSION

In [None]:
def load_standalone_model(checkpoint_path, in_channels=6):
    """Load a ``LitModelStandalone`` from a Lightning checkpoint.

    Args:
        checkpoint_path: Path to the ``.ckpt`` file.
        in_channels: Number of input bands, passed as a keyword to
            ``load_from_checkpoint`` (Lightning uses such keywords to
            override the saved hyper-parameters).

    Returns:
        The model mapped to the module-level ``device``, in eval mode.
    """
    # Only pin a CUDA device when one was actually selected; calling
    # torch.cuda.set_device on a CPU-only host raises.
    if device.type == "cuda":
        torch.cuda.set_device(device)
    # load_from_checkpoint reads the file itself, so no separate torch.load
    # of the checkpoint is needed.
    model = LitModelStandalone.load_from_checkpoint(
        checkpoint_path,
        in_channels=in_channels,
        map_location=device,
    )
    model.eval()
    return model

def load_fusion_model(checkpoint_path):
    """Load a ``LitModelLateFusion`` from a Lightning checkpoint.

    Args:
        checkpoint_path: Path to the ``.ckpt`` file.

    Returns:
        The model mapped to the module-level ``device``, in eval mode.
    """
    # Only pin a CUDA device when one was actually selected; calling
    # torch.cuda.set_device on a CPU-only host raises.
    if device.type == "cuda":
        torch.cuda.set_device(device)
    # load_from_checkpoint reads the file itself, so no separate torch.load
    # of the checkpoint is needed.
    model = LitModelLateFusion.load_from_checkpoint(
        checkpoint_path,
        map_location=device,
    )
    model.eval()
    return model

def predict_standalone(model, img_tensor):
    """Predict a binary mask for a single un-batched input.

    A leading batch axis is added and the input is moved to the global
    ``device``. The sigmoid of the model output is compared against
    ``model.threshold``. All size-1 axes are squeezed out and the mask is
    returned as a float NumPy array of 0.0/1.0 on the CPU.
    """
    with torch.no_grad():
        batch = img_tensor.unsqueeze(0).to(device)
        mask = torch.sigmoid(model(batch)) > model.threshold
        return mask.float().squeeze().cpu().numpy()

def predict_fusion(model, planet_tensor, s1_tensor):
    """Predict a binary mask from one PlanetScope and one Sentinel-1 input.

    Both un-batched inputs get a leading batch axis and are moved to the
    global ``device``. They are passed to ``model`` as
    ``(planet, s1)``, and the sigmoid of the output is compared against
    ``model.threshold``. All size-1 axes are squeezed out and the mask is
    returned as a float NumPy array of 0.0/1.0 on the CPU.
    """
    with torch.no_grad():
        planet_batch, s1_batch = (
            t.unsqueeze(0).to(device) for t in (planet_tensor, s1_tensor)
        )
        mask = torch.sigmoid(model(planet_batch, s1_batch)) > model.threshold
        return mask.float().squeeze().cpu().numpy()

In [None]:
def get_raw_rgb(dataset, idx):
    """Build a displayable RGB image for sample ``idx`` directly from disk.

    The raster path is ``dataset.dataset[idx][0]``. All of its bands are
    read with rasterio as float32. Bands 2, 1 and 0 become the red, green
    and blue channels, and each is min-max scaled to [0, 1] on its own.
    A band whose max is not greater than its min (including NaN extrema)
    becomes all zeros. Returns an (H, W, 3) float64 array.
    """
    tile_path = dataset.dataset[idx][0]
    with rasterio.open(tile_path, 'r') as src:
        bands = src.read().astype(np.float32)

    # (3, H, W) stack of the display bands in R, G, B order.
    rgb_bands = bands[[2, 1, 0]]
    lo = rgb_bands.min(axis=(1, 2), keepdims=True)
    hi = rgb_bands.max(axis=(1, 2), keepdims=True)
    has_range = hi > lo
    # Flat bands get a dummy span of 1 to avoid 0/0; they are zeroed below.
    span = np.where(has_range, hi - lo, 1)
    scaled = np.clip((rgb_bands - lo) / span, 0, 1)
    scaled = np.where(has_range, scaled, 0)
    return np.transpose(scaled, (1, 2, 0)).astype(np.float64)

def plot_ps_predictions(model, dataset, indices=None, num_examples=3):
    """Show input RGB, predicted mask and ground truth side by side.

    Draws one row per sample with three panels:
    - an RGB rendering of the raw tile (``get_raw_rgb``);
    - the thresholded prediction of ``model`` (``predict_standalone``);
    - the ground-truth mask from ``dataset``.

    Args:
        model: Standalone model; passed to ``predict_standalone``.
        dataset: Indexable dataset whose items unpack to
            ``(image_tensor, gt_tensor)``; also passed to ``get_raw_rgb``.
        indices: Sample indices to plot. When ``None``, ``num_examples``
            distinct indices are drawn with ``random.sample``.
        num_examples: How many random samples to draw when ``indices`` is
            ``None``. When ``indices`` is given, the grid has one row per
            index and this argument is not used.
    """
    if indices is None:
        indices = random.sample(range(len(dataset)), num_examples)
    indices = list(indices)
    if not indices:
        # plt.subplots rejects a 0-row grid; there is nothing to draw.
        return

    # Size the grid from the indices actually plotted, not num_examples,
    # so explicit index lists of any length work.
    n_rows = len(indices)
    # squeeze=False keeps `axs` 2-D even for a single row.
    _, axs = plt.subplots(n_rows, 3, figsize=(15, n_rows * 4), squeeze=False)

    for i, idx in enumerate(indices):
        img_tensor, gt_tensor = dataset[idx]
        pred = predict_standalone(model, img_tensor)
        rgb = get_raw_rgb(dataset, idx)
        # Squeeze size-1 axes as predict_standalone does for the prediction,
        # so a mask with a singleton channel axis is still drawable by imshow.
        gt = gt_tensor.squeeze().cpu().numpy()

        panels = (
            (rgb, None, f"PlanetScope RGB - Example {i+1}"),
            (pred, "gray", "Model Prediction"),
            (gt, "gray", "Ground Truth"),
        )
        for ax, (image, cmap, title) in zip(axs[i], panels):
            ax.imshow(image, cmap=cmap)
            ax.set_title(title)
            ax.axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
# PlanetScope testing dataset and the standalone model evaluated on it.
# A failed model load is reported and leaves ps_model as None, which makes
# the plotting cell skip itself.
ps_dataset = PlanetScopeDataset(
    data_dir=PS_DATA_PATH,
    split="split_0",
    mode=STANDALONE_MODE,
    transforms=TRANSFORMS,
    pad=PAD,
)

try:
    ps_model = load_standalone_model(PS_MODEL_PATH, in_channels=6)
    print("PlanetScope model loaded successfully")
except Exception as err:
    print(f"Error loading PlanetScope model: {err}")
    ps_model = None

In [None]:
if ps_model is not None:
    # Fixed sample indices so the same tiles are shown on every run.
    fixed_indices = [87, 1, 111]
    print("\n## PlanetScope Standalone model predictions")
    plot_ps_predictions(ps_model, ps_dataset, fixed_indices)