# Imports

### Standard library

In [None]:
from pathlib import Path

### Third-party libraries

In [None]:
import matplotlib.pyplot as plt
import torch
import numpy as np

from pytorch_lightning import seed_everything

### Internal modules

In [None]:
from asm_mapping.data.planetscope_dataset import PlanetScopeDataset
from asm_mapping.data.sentinel1_dataset import Sentinel1Dataset
from asm_mapping.data.fusion_dataset import FusionDataset
from asm_mapping.data.dataset_mode import DatasetMode
from asm_mapping.data.fusion_dataset import ResampleStrategy

# Configs

In [None]:
# seeds
# Single global seed for this notebook; with workers=True Lightning also
# seeds DataLoader worker processes (per the Lightning docs).
RANDOM = 79
seed_everything(seed=RANDOM, workers=True)

In [None]:
# folders
# Every input lives under one data root; the standalone datasets read the
# training set of SPLIT_NAME. Change these two values to relocate / re-split.
DATA_ROOT = "/mnt/guanabana/raid/home/pasan001/asm-mapping/data/"
SPLIT_NAME = "split_0"
PS_DATA = f"{DATA_ROOT}ps_split/{SPLIT_NAME}/training_set"
S1_DATA = f"{DATA_ROOT}s1_split/{SPLIT_NAME}/training_set"
# the fusion dataset is pointed at the root itself (it takes the split separately)
FUSION_DATA = DATA_ROOT

In [None]:
# datasets
# Constructor options shared by the standalone PlanetScope / Sentinel-1 datasets.
PAD, TRANSFORMS = False, None
# Mode values passed to the dataset constructors.
STANDALONE_MODE, FUSION_MODE = DatasetMode.STANDALONE, DatasetMode.FUSION

# Datasets

## PlanetScope

In [None]:
# PlanetScope training tiles loaded on their own (no fusion partner).
ps_dataset = PlanetScopeDataset(
    data_dir=PS_DATA,
    mode=STANDALONE_MODE,
    pad=PAD,
    transforms=TRANSFORMS,
)

In [None]:
def plot_ps_band_hist(image):
    """Draw a 2x2 grid of pixel-value histograms for a PlanetScope image.

    ``image`` is indexed band-first; the first four bands are plotted and
    labelled Blue, Green, Red and NIR. Axis limits are fixed
    (x in [0, 1], y in [0, 8000]) so plots of different images line up.
    """
    plt.figure(figsize=(10, 8))
    band_names = ('Blue', 'Green', 'Red', 'NIR')
    for position, name in enumerate(band_names, start=1):
        plt.subplot(2, 2, position)
        values = image[position - 1].ravel()
        plt.hist(values, bins=256, color='k', alpha=0.5)
        plt.title(f'{name} band histogram')
        plt.xlim([0, 1])
        plt.ylim([0, 8000])
    plt.tight_layout()
    plt.show()

In [None]:
def plot_ps_examples(dataset, indices=None, num_examples=3, brightness=1.5):
    """Plot RGB composite, NDVI and ground truth for PlanetScope samples.

    Each selected sample is drawn on its own row with three panels: a
    brightened true-colour RGB image, the NDVI band (band index 4) and the
    ground-truth mask.

    Args:
        dataset: Dataset whose ``__getitem__`` returns ``(image, mask)``
            tensors and which exposes the image file name (a string) as
            ``dataset.dataset[idx][0]``.
        indices: Explicit sample indices to plot. When ``None``,
            ``num_examples`` distinct random indices are drawn.
        num_examples: Number of random samples to plot when ``indices`` is
            ``None``; ignored otherwise.
        brightness: Gain applied to the RGB composite before clipping it to
            ``[0, 1]`` for display.
    """
    if indices is None:
        # randperm yields distinct indices; randint could repeat a sample.
        indices = torch.randperm(len(dataset))[:num_examples].tolist()
    else:
        indices = list(indices)
    num_examples = len(indices)
    if num_examples == 0:
        # plt.subplots cannot create a grid with zero rows
        return

    subplot_cols = 3
    # squeeze=False keeps ``axs`` 2-D, so ``axs[row, col]`` also works when
    # only a single example (one row) is requested.
    _, axs = plt.subplots(num_examples, subplot_cols,
                          figsize=(12, num_examples * 4), squeeze=False)

    for row, idx in enumerate(indices):
        img_tensor, gt_tensor = dataset[idx]

        # numeric suffix of the file name, e.g. "name_123.ext" -> "123"
        img_file_name = dataset.dataset[idx][0]
        img_index = img_file_name.split('_')[-1].split('.')[0]

        img = img_tensor.numpy()

        # bands are stored B, G, R first: reorder to R, G, B and move the
        # channel axis last (H, W, C) as matplotlib expects
        img_rgb = np.transpose(img[[2, 1, 0], :, :], (1, 2, 0))
        img_rgb = np.clip(img_rgb * brightness, 0, 1)

        # band 4 of the dataset output is shown as NDVI
        ndvi = img[4, :, :]
        gt = gt_tensor.numpy()

        panels = (
            (img_rgb, None, f"Image {img_index} - RGB"),
            (ndvi, 'RdYlGn', "NDVI"),
            (gt, 'gray', f"Ground Truth {img_index}"),
        )
        for col, (data, cmap, title) in enumerate(panels):
            ax = axs[row, col]
            ax.imshow(data, cmap=cmap)
            ax.set_title(title)
            ax.axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
index = 111
# element 0 of a sample is the image tensor; plot its per-band histograms
ps_sample_image = ps_dataset[index][0]
plot_ps_band_hist(ps_sample_image.numpy())

In [None]:
# fixed samples; rows appear in the order listed
ps_example_indices = [87, 1, 111]
plot_ps_examples(ps_dataset, indices=ps_example_indices)

## Sentinel-1

In [None]:
# Sentinel-1 training set for split_0, loaded on its own (no fusion partner).
s1_dataset = Sentinel1Dataset(
    data_dir=S1_DATA,
    mode=STANDALONE_MODE,
    pad=PAD,
    transforms=TRANSFORMS,
    split="split_0",
)

In [None]:
# Number of Sentinel-1 samples (echoed as the cell output).
len(s1_dataset)

In [None]:
def plot_s1_band_hist(image):
    """Print summary statistics and plot histograms of three S1 bands.

    Bands 0, 1 and 2 of ``image`` are labelled VV, VH and VV/VH. For each
    one this prints min / max / mean / std, then draws a 256-bin histogram
    in a 1x3 row with fixed axes (x in [0, 1], y in [0, 1000]).
    """
    plt.figure(figsize=(15, 5))
    for slot, name in enumerate(['VV', 'VH', 'VV/VH'], start=1):
        plt.subplot(1, 3, slot)
        values = image[slot - 1].ravel()
        lo, hi = values.min(), values.max()
        avg, spread = values.mean(), values.std()
        print(f"{name} stats: min={lo:.4f}, max={hi:.4f}, mean={avg:.4f}, std={spread:.4f}")
        plt.hist(values, bins=256, color='k', alpha=0.5)
        plt.title(f'{name} band histogram')
        plt.xlim([0, 1])
        plt.ylim([0, 1000])
    plt.tight_layout()
    plt.show()

In [None]:
def plot_s1_examples(dataset, indices=None, num_examples=3):
    """Plot VV, VH, a VV/VH/ratio composite and ground truth for S1 samples.

    Each sample gets one row of four panels:
    - VV in grayscale.
    - VH in grayscale.
    - A false-colour composite of bands 0, 1, 2 (VV, VH, VV/VH), with each band min-max scaled to [0, 1].
    - The ground-truth mask.

    Args:
        dataset: Dataset whose ``__getitem__`` returns ``(image, mask)``
            tensors and which exposes the image file name (a string) as
            ``dataset.dataset[idx][0]``.
        indices: Explicit sample indices to plot. When ``None``,
            ``num_examples`` distinct random indices are drawn.
        num_examples: Number of random samples to plot when ``indices`` is
            ``None``; ignored otherwise.
    """
    if indices is None:
        # randperm yields distinct indices; randint could repeat a sample.
        indices = torch.randperm(len(dataset))[:num_examples].tolist()
    else:
        indices = list(indices)
    num_examples = len(indices)
    if num_examples == 0:
        # plt.subplots cannot create a grid with zero rows
        return

    def _rescale_bands(bands):
        """Min-max scale each band (axis 0) to [0, 1]; constant bands map to 0."""
        lo = bands.min(axis=(1, 2), keepdims=True)
        hi = bands.max(axis=(1, 2), keepdims=True)
        span = np.where(hi > lo, hi - lo, 1)  # avoid 0/0 -> NaN
        return (bands - lo) / span

    subplot_cols = 4
    # squeeze=False keeps ``axs`` 2-D so ``axs[row, col]`` also works for one row.
    _, axs = plt.subplots(num_examples, subplot_cols,
                          figsize=(16, num_examples * 4), squeeze=False)

    for row, idx in enumerate(indices):
        img_tensor, gt_tensor = dataset[idx]

        # numeric suffix of the file name, e.g. "name_123.ext" -> "123"
        img_file_name = dataset.dataset[idx][0]
        img_index = img_file_name.split('_')[-1].split('.')[0]

        img = img_tensor.numpy()
        vv, vh = img[0, :, :], img[1, :, :]

        # imshow ignores ``cmap`` for 3-channel input and clips float RGB to
        # [0, 1], so scale each band into that range before stacking
        # channels-last.
        img_rgb = np.transpose(_rescale_bands(img[:3]), (1, 2, 0))

        gt = gt_tensor.numpy()

        panels = (
            (vv, 'gray', f"Image {img_index} - VV"),
            (vh, 'gray', f"Image {img_index} - VH"),
            (img_rgb, None, f"Image {img_index} - RGB (VV, VH, VV/VH)"),
            (gt, 'gray', f"Ground Truth {img_index}"),
        )
        for col, (data, cmap, title) in enumerate(panels):
            ax = axs[row, col]
            ax.imshow(data, cmap=cmap)
            ax.set_title(title)
            ax.axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
index = 111
# element 0 of a sample is the image tensor; print stats and plot its histograms
s1_sample_image = s1_dataset[index][0]
plot_s1_band_hist(s1_sample_image.numpy())

In [None]:
# same fixed samples as the PlanetScope preview; rows follow the list order
s1_example_indices = [87, 1, 111]
plot_s1_examples(s1_dataset, indices=s1_example_indices)

## Fusion

In [None]:
# Paired PlanetScope + Sentinel-1 samples; "upsample_s1" presumably resamples
# S1 up to the PlanetScope grid — confirm in ResampleStrategy.
# NOTE(review): unlike the standalone datasets this passes pad=True and
# transforms=False (not PAD / TRANSFORMS), and the split as the int 0 rather
# than "split_0" — confirm these differences are intentional.
s1_resampling = ResampleStrategy("upsample_s1")
fusion_dataset = FusionDataset(
    data_dir=FUSION_DATA,
    split=0,
    transforms=False,
    pad=True,
    resample_strategy=s1_resampling,
)

In [None]:
def plot_fusion_examples(dataset, indices=None, num_examples=3):
    """Plot PlanetScope RGB, Sentinel-1 composite and ground truth side by side.

    Each sample gets one row of three panels:
    - The PlanetScope true-colour image, min-max scaled jointly over its three channels.
    - An S1 false-colour composite of bands 0, 1, 2 (VV, VH, VV/VH), each band scaled separately.
    - The ground-truth mask.

    Args:
        dataset: Dataset whose ``__getitem__`` returns
            ``(planet_image, s1_image, mask)`` tensors.
        indices: Explicit sample indices to plot. When ``None``,
            ``num_examples`` distinct random indices are drawn.
        num_examples: Number of random samples to plot when ``indices`` is
            ``None``; ignored otherwise.
    """
    if indices is None:
        # randperm yields distinct indices; randint could repeat a sample.
        indices = torch.randperm(len(dataset))[:num_examples].tolist()
    else:
        indices = list(indices)
    num_examples = len(indices)
    if num_examples == 0:
        # plt.subplots cannot create a grid with zero rows
        return

    def _rescale(arr, axis=None):
        """Min-max scale ``arr`` into [0, 1] over ``axis`` (all elements if None).

        A constant slice (max == min) maps to 0 instead of producing NaN
        from a 0/0 division.
        """
        lo = arr.min(axis=axis, keepdims=True)
        hi = arr.max(axis=axis, keepdims=True)
        span = np.where(hi > lo, hi - lo, 1)
        return np.clip((arr - lo) / span, 0, 1)

    subplot_cols = 3
    # squeeze=False keeps ``axs`` 2-D so ``axs[row, col]`` also works for one row.
    _, axs = plt.subplots(num_examples, subplot_cols,
                          figsize=(15, num_examples * 4), squeeze=False)

    for row, idx in enumerate(indices):
        planet_tensor, s1_tensor, gt_tensor = dataset[idx]

        # Planet: reorder bands 2, 1, 0 to R, G, B and move channels last.
        # Scale all three channels jointly to keep their relative balance.
        planet_img = planet_tensor.numpy()
        planet_rgb = _rescale(np.transpose(planet_img[[2, 1, 0], :, :], (1, 2, 0)))

        # S1: VV, VH and VV/VH ratio (bands 0-2), each band scaled on its own.
        s1_img = s1_tensor.numpy()
        s1_composite = np.transpose(_rescale(s1_img[:3], axis=(1, 2)), (1, 2, 0))

        panels = (
            (planet_rgb, None, f"Planet RGB {idx}"),
            (s1_composite, None, f"S1 RGB composite {idx}"),
            (gt_tensor.numpy(), 'gray', f"Ground Truth {idx}"),
        )
        for col, (data, cmap, title) in enumerate(panels):
            ax = axs[row, col]
            ax.imshow(data, cmap=cmap)
            ax.set_title(title)
            ax.axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
# same fixed samples as the standalone previews; rows follow the list order
fusion_example_indices = [87, 1, 111]
plot_fusion_examples(fusion_dataset, indices=fusion_example_indices)