# 0) Import packages

In [None]:
import os
import matplotlib.pyplot as plt
import torch
import numpy as np

from pytorch_lightning import seed_everything
from torch.utils.data import DataLoader

# import custom libraries
from data.s1_dataset import Sentinel1Dataset
from data.planet_dataset import PlanetDataset
from data.s1_dataset_normalization import linear_norm_global_percentile as linear_norm_percentile_s1
from data.planet_dataset_normalization import linear_norm_global_percentile as linear_norm_percentile_planet
from data.fusion_dataset import FusionDataset

In [None]:
# Seed the random number generators once, before any data is touched, so that
# randomly drawn example indices are repeatable from run to run.
seed_everything(seed=96, workers=True)

# 1) Initialize and inspect Sentinel-1 dataset

In [None]:
# Directory holding the Sentinel-1 training split.
training_dir = '/mnt/guanabana/raid/home/pasan001/thesis/dataset/asm_dataset_split_0/s1/training_data'

# Load the samples as stored: padding is switched off and no normalization
# function is attached to the dataset.
training_dataset = Sentinel1Dataset(
    training_dir,
    pad=False,
    normalization=None,
)

In [None]:
# Display one sample to inspect what the dataset returns for an index.
# NOTE(review): the dataset in this notebook is created with pad=False, so
# this shows the un-padded sample; switch pad on to check the padding result.
training_dataset[10]

In [None]:
def plot_band_histograms(image, bands=('VV', 'VH'), xlim=(-30, 0), ylim=(0, 1000)):
    """Plot one 256-bin histogram per band of a band-first image.

    Args:
        image: Image indexable as ``image[i]`` per band. NumPy arrays and
            CPU torch tensors are both accepted; each band is converted with
            ``np.asarray`` before it is flattened.
        bands: Band names in channel order; one subplot is drawn per name.
            Defaults to the two Sentinel-1 polarisations used in this
            notebook.
        xlim: x-axis limits applied to every subplot.
        ylim: y-axis limits applied to every subplot.

    Returns:
        None. The figure is shown with ``plt.show()``.
    """
    n_cols = 2
    # keep the original 2x2 layout for up to four bands and add rows beyond that
    n_rows = max(2, -(-len(bands) // n_cols))

    plt.figure(figsize=(10, 8))

    for i, band in enumerate(bands):
        plt.subplot(n_rows, n_cols, i + 1)
        values = np.asarray(image[i]).ravel()
        plt.hist(values, bins=256, color='k', alpha=0.5)
        plt.title(f'{band} band histogram')
        plt.xlim(xlim)
        plt.ylim(ylim)
    plt.tight_layout()
    plt.show()

def plot_example(dataset, normalization, indices=None, num_examples=3):
    """Show channel 0 (VV), channel 1 (VH) and the ground truth of samples.

    One figure row is drawn per sample, with three grayscale panels.

    Args:
        dataset: Object where ``dataset[idx]`` yields an
            ``(image tensor, ground-truth tensor)`` pair and
            ``dataset.dataset[idx][0]`` is the image file name. The part of
            the name after the last ``_`` and before the following ``.`` is
            used as the label in the panel titles.
        normalization: Callable applied to the image array before plotting,
            or ``None`` to plot the values unchanged.
        indices: Sample indices to plot. When ``None``, ``num_examples``
            indices are drawn at random with ``torch.randint``.
        num_examples: Number of random samples to draw; ignored when
            ``indices`` is given (the number of rows is ``len(indices)``).

    Returns:
        None. The figure is shown with ``plt.show()``; nothing is drawn when
        there are no indices to plot.
    """
    if indices is None:
        indices = torch.randint(len(dataset), size=(num_examples,)).tolist()
    else:
        indices = list(indices)
        num_examples = len(indices)

    # plt.subplots rejects a grid with zero rows, so bail out early
    if num_examples == 0:
        return

    # squeeze=False keeps axs two-dimensional even when only one row is drawn
    _, axs = plt.subplots(num_examples, 3, figsize=(10, num_examples * 4),
                          squeeze=False)

    for row_axes, idx in zip(axs, indices):
        img_tensor, gt_tensor = dataset[idx]
        img_file_name = dataset.dataset[idx][0]
        img_index = img_file_name.split('_')[-1].split('.')[0]

        # convert to NumPy for visualisation
        img = img_tensor.numpy()
        if normalization is not None:
            img = normalization(img)
        gt = gt_tensor.numpy()

        panels = (
            (img[0], f"Image {img_index} - VV"),
            (img[1], f"Image {img_index} - VH"),
            (gt, f"Ground Truth {img_index}"),
        )
        for ax, (data, title) in zip(row_axes, panels):
            ax.imshow(data, cmap='gray')
            ax.set_title(title)
            ax.axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
# Histogram of the image part of the first sample (element 0 of the pair).
plot_band_histograms(image=training_dataset[0][0])

In [None]:
# Plot three fixed samples, stretched with the Sentinel-1 percentile
# normalization for display.
plot_example(
    training_dataset,
    normalization=linear_norm_percentile_s1,
    indices=[10, 765, 458],
)

# 2) Initialize and inspect Planet-NICFI dataset

In [None]:
# Directory holding the Planet training split (binary labels).
training_dir = '/mnt/guanabana/raid/home/pasan001/thesis/dataset/asm_dataset_split_0/planet/binary/training_data'

# Load the samples as stored: padding is switched off and no normalization
# function is attached to the dataset. This rebinds training_dataset.
training_dataset = PlanetDataset(
    training_dir,
    pad=False,
    normalization=None,
)

In [None]:
def plot_band_histograms(image, bands=('Blue', 'Green', 'Red', 'NIR'),
                         xlim=(0, 4500), ylim=(0, 8000)):
    """Plot one 256-bin histogram per band of a band-first image.

    Args:
        image: Image indexable as ``image[i]`` per band. NumPy arrays and
            CPU torch tensors are both accepted; each band is converted with
            ``np.asarray`` before it is flattened.
        bands: Band names in channel order; one subplot is drawn per name.
            Defaults to the four Planet bands used in this notebook.
        xlim: x-axis limits applied to every subplot.
        ylim: y-axis limits applied to every subplot.

    Returns:
        None. The figure is shown with ``plt.show()``.
    """
    n_cols = 2
    # keep the original 2x2 layout for up to four bands and add rows beyond that
    n_rows = max(2, -(-len(bands) // n_cols))

    plt.figure(figsize=(10, 8))

    for i, band in enumerate(bands):
        plt.subplot(n_rows, n_cols, i + 1)
        values = np.asarray(image[i]).ravel()
        plt.hist(values, bins=256, color='k', alpha=0.5)
        plt.title(f'{band} band histogram')
        plt.xlim(xlim)
        plt.ylim(ylim)
    plt.tight_layout()
    plt.show()

In [None]:
def plot_example(dataset, normalization, indices=None, num_examples=3, vi='ndvi'):
    """Show an RGB composite, one spectral index and the ground truth.

    One figure row is drawn per sample, with three panels: the RGB image, the
    index selected by ``vi`` and the ground truth.

    Args:
        dataset: Object where ``dataset[idx]`` yields an
            ``(image tensor, ground-truth tensor)`` pair and
            ``dataset.dataset[idx][0]`` is the image file name. The part of
            the name after the last ``_`` and before the following ``.`` is
            used as the label in the panel titles. The first three image
            channels are treated as B, G, R; NDVI, SAVI and NDWI are read
            from channels 4, 5 and 6.
        normalization: Callable applied to the 3-channel RGB array before
            plotting. When ``None``, the values are divided by 10000, scaled
            to 0-255, clipped to that range and cast to ``uint8``.
        indices: Sample indices to plot. When ``None``, ``num_examples``
            indices are drawn at random with ``torch.randint``.
        num_examples: Number of random samples to draw; ignored when
            ``indices`` is given (the number of rows is ``len(indices)``).
        vi: Which index to show in the middle panel: ``'ndvi'``, ``'savi'``
            or ``'ndwi'``. Any other value leaves the middle panel empty.

    Returns:
        None. The figure is shown with ``plt.show()``; nothing is drawn when
        there are no indices to plot.
    """
    # channel position, panel title and colormap of every supported index
    vi_panels = {
        'ndvi': (4, "NDVI", 'RdYlGn'),
        'savi': (5, "SAVI", 'RdYlGn'),
        'ndwi': (6, "NDWI", 'RdYlBu'),
    }

    if indices is None:
        indices = torch.randint(len(dataset), size=(num_examples,)).tolist()
    else:
        indices = list(indices)
        num_examples = len(indices)

    # plt.subplots rejects a grid with zero rows, so bail out early
    if num_examples == 0:
        return

    # squeeze=False keeps axs two-dimensional even when only one row is drawn
    _, axs = plt.subplots(num_examples, 3, figsize=(12, num_examples * 4),
                          squeeze=False)

    for (rgb_ax, vi_ax, gt_ax), idx in zip(axs, indices):
        img_tensor, gt_tensor = dataset[idx]
        # extract file name and index
        img_file_name = dataset.dataset[idx][0]
        img_index = img_file_name.split('_')[-1].split('.')[0]

        # reorder bands from BGR to RGB
        img = img_tensor.numpy()[[2, 1, 0], :, :]

        if normalization is not None:
            img = normalization(img)
        else:
            # scale for visualisation; clip first so that out-of-range values
            # saturate instead of wrapping around in the uint8 cast
            img = np.clip((img / 10000.0) * 255, 0, 255).astype(np.uint8)

        # reorder dimensions to (height, width, channels) as expected by matplotlib
        img = np.transpose(img, (1, 2, 0))
        gt = gt_tensor.numpy()

        rgb_ax.imshow(img)
        rgb_ax.set_title(f"Image {img_index} - RGB")
        rgb_ax.axis('off')

        # only the requested index channel is read from the image tensor
        if vi in vi_panels:
            channel, title, cmap = vi_panels[vi]
            vi_ax.imshow(img_tensor[channel, :, :].numpy(), cmap=cmap)
            vi_ax.set_title(title)
        vi_ax.axis('off')

        gt_ax.imshow(gt, cmap='gray')
        gt_ax.set_title(f"Ground Truth {img_index}")
        gt_ax.axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
# Histogram of the image part of sample 13, taken from the dataset created
# with normalization=None and converted to a NumPy array first.
plot_band_histograms(image=training_dataset[13][0].numpy())

In [None]:
# plot some examples choosing a normalization method
# wdvi is blank due to the range of this index [-1, 1]
plot_example(
    training_dataset,
    normalization=linear_norm_percentile_planet,
    indices=[11, 13, 1],
    vi='ndvi',
)

# 3) Initialize and inspect fusion dataset

In [None]:
# Root directory of the fusion dataset split; handed to FusionDataset as root_dir.
data_dir = '/mnt/guanabana/raid/home/pasan001/thesis/dataset/asm_dataset_split_0/fusion'

In [None]:
# Training part of the fusion dataset, served in shuffled batches of two.
late_fusion_dataset = FusionDataset(root_dir=data_dir, train=True)
data_loader = DataLoader(late_fusion_dataset, batch_size=2, shuffle=True)

# Pull a single batch to check the tensor shapes; each batch unpacks into
# (planet data, Sentinel-1 data, ground truth). Nothing is printed when the
# loader yields no batch at all.
first_batch = next(iter(data_loader), None)
if first_batch is not None:
    planet_data, s1_data, gt = first_batch
    print(s1_data.shape, planet_data.shape)


In [None]:
def plot_fusion_examples(dataset, normalization, indices=None, num_examples=3):
    """Show the Planet RGB, Sentinel-1 channel 0 and ground truth of samples.

    One figure row is drawn per sample, with three panels.

    Args:
        dataset: Object where ``dataset[idx]`` yields a
            ``(planet tensor, Sentinel-1 tensor, ground-truth tensor)``
            triple. The first three Planet channels are treated as B, G, R.
        normalization: Callable applied to the 3-channel Planet RGB array
            before plotting. When ``None``, the values are divided by 10000,
            scaled to 0-255, clipped to that range and cast to ``uint8``.
            The Sentinel-1 image is always plotted unchanged.
        indices: Sample indices to plot. When ``None``, ``num_examples``
            indices are drawn at random with ``np.random.randint``.
        num_examples: Number of random samples to draw; ignored when
            ``indices`` is given (the number of rows is ``len(indices)``).

    Returns:
        None. The figure is shown with ``plt.show()``; nothing is drawn when
        there are no indices to plot.
    """
    if indices is None:
        indices = np.random.randint(0, len(dataset), size=(num_examples,)).tolist()
    else:
        # size the grid from the indices actually given, so that a mismatch
        # with num_examples cannot cause an IndexError or leave empty rows
        indices = list(indices)
        num_examples = len(indices)

    # plt.subplots rejects a grid with zero rows, so bail out early
    if num_examples == 0:
        return

    # squeeze=False keeps axs two-dimensional even when only one row is drawn
    _, axs = plt.subplots(num_examples, 3, figsize=(15, num_examples * 5),
                          squeeze=False)

    for row_axes, idx in zip(axs, indices):
        planet_data, s1_data, gt = dataset[idx]

        # convert to NumPy and reorder the Planet bands from BGR to RGB
        planet_img = planet_data.numpy()[[2, 1, 0], :, :]

        # normalization is applied to the Planet image only, to improve visualization
        if normalization is not None:
            planet_img = normalization(planet_img)
        else:
            # clip first so that out-of-range values saturate instead of
            # wrapping around in the uint8 cast
            planet_img = np.clip((planet_img / 10000.0) * 255, 0, 255).astype(np.uint8)

        # reorder dimensions to (height, width, channels) as expected by matplotlib
        planet_img = np.transpose(planet_img, (1, 2, 0))

        s1_img = s1_data.numpy()
        gt_img = gt.numpy()

        panels = (
            (planet_img, None, f'Planet Image {idx}'),
            (s1_img[0], 'gray', f'Sentinel-1 Image {idx} - VV'),
            (gt_img, 'gray', f'Ground Truth {idx}'),
        )
        for ax, (data, cmap, title) in zip(row_axes, panels):
            ax.imshow(data, cmap=cmap)
            ax.set_title(title)
            ax.axis('off')

    plt.tight_layout()
    plt.show()


In [None]:
# Plot three fixed fusion samples; the Planet percentile normalization is used
# for the RGB panel.
plot_fusion_examples(
    late_fusion_dataset,
    normalization=linear_norm_percentile_planet,
    indices=[10, 144, 324],
    num_examples=3,
)