# 0) Import packages

In [None]:
import os
import matplotlib.pyplot as plt
import torch
import numpy as np

from pytorch_lightning import seed_everything

# import custom libraries
from data.s1_dataset import Sentinel1Dataset
from data.s1_dataset_normalization import linear_norm_global_minmax, linear_norm_global_percentile, global_standardization

In [None]:
# Seed the global RNGs through Lightning so runs are reproducible;
# workers=True additionally asks Lightning to seed dataloader workers.
seed_everything(seed=96, workers=True)

# 1) Initialize and inspect dataset

In [None]:
# Location of the Sentinel-1 training split. The default is the original
# absolute path; set the S1_TRAINING_DIR environment variable to point the
# notebook at a copy of the data on another machine.
training_dir = os.environ.get(
    'S1_TRAINING_DIR',
    '/mnt/guanabana/raid/home/pasan001/thesis/dataset/asm_dataset_split_0/s1/training_data',
)

training_dataset = Sentinel1Dataset(
    training_dir,
    pad=False,
    normalization=linear_norm_global_minmax,
)

In [None]:
# Show the image tensor of the first training sample
# (dataset items are (image, ground_truth) pairs).
first_sample = training_dataset[0]
first_sample[0]

In [None]:
def plot_band_histograms(image, bands=('VV', 'VH'), xlim=(-30, 0), ylim=(0, 1000), bins=256):
    """Plot one intensity histogram per band of a multi-band image.

    Args:
        image: Array-like indexed as ``image[band_index]``; each band is
            flattened with ``.ravel()`` before histogramming (presumably a
            band-first (bands, H, W) tensor -- TODO confirm).
        bands: Band names, in the same order as the first axis of ``image``.
        xlim: x-axis limits. The default (-30, 0) suits backscatter in dB;
            pass ``None`` to autoscale, e.g. for already-normalised data.
        ylim: y-axis (count) limits, or ``None`` to autoscale.
        bins: Number of histogram bins.

    Raises:
        ValueError: If ``bands`` is empty.
    """
    if not bands:
        raise ValueError("bands must contain at least one band name")

    # One row with a panel per band (a 2x2 grid left the bottom row empty);
    # squeeze=False keeps axs 2-D even for a single band.
    fig, axs = plt.subplots(1, len(bands), figsize=(5 * len(bands), 4), squeeze=False)

    for i, (ax, band) in enumerate(zip(axs[0], bands)):
        ax.hist(image[i].ravel(), bins=bins, color='k', alpha=0.5)
        ax.set_title(f'{band} band histogram')
        if xlim is not None:
            ax.set_xlim(xlim)
        if ylim is not None:
            ax.set_ylim(ylim)

    fig.tight_layout()
    plt.show()

def plot_example(dataset, normalization, indices=None, num_examples=3):
    """Show VV, VH and ground-truth panels for a few dataset samples.

    Args:
        dataset: Indexable dataset whose items unpack to
            ``(image_tensor, gt_tensor)`` and which exposes
            ``dataset.dataset[idx][0]``, presumably the image file name
            (TODO confirm against Sentinel1Dataset).
        normalization: Optional callable applied to the image (as a NumPy
            array) before plotting; ``None`` plots it as returned.
        indices: Sample indices to plot. When ``None``, up to
            ``num_examples`` distinct indices are drawn at random.
        num_examples: Number of random samples to draw when ``indices`` is
            ``None``; ignored otherwise.

    Raises:
        ValueError: If there is nothing to plot (empty dataset, empty
            ``indices`` or non-positive ``num_examples``).
    """
    if indices is None:
        n_samples = len(dataset)
        if n_samples == 0:
            raise ValueError("dataset is empty; nothing to plot")
        # randperm yields distinct indices; randint could repeat a sample.
        indices = torch.randperm(n_samples)[:max(num_examples, 0)].tolist()
    # Materialise as a list so the emptiness check also works for arrays.
    indices = list(indices)
    if not indices:
        raise ValueError("no indices to plot")
    num_examples = len(indices)

    # squeeze=False keeps axs 2-D, so row indexing also works for one example.
    fig, axs = plt.subplots(num_examples, 3, figsize=(10, num_examples * 4), squeeze=False)

    for row, idx in zip(axs, indices):
        img_tensor, gt_tensor = dataset[idx]
        img_file_name = dataset.dataset[idx][0]
        # Text after the last '_' up to the first '.', e.g. 'x_12.tif' -> '12'.
        img_index = img_file_name.split('_')[-1].split('.')[0]

        # detach()/cpu() make .numpy() safe for grad-tracking or GPU tensors.
        img = img_tensor.detach().cpu().numpy()
        if normalization is not None:
            img = normalization(img)
        gt = gt_tensor.detach().cpu().numpy()

        panels = (
            (img[0], f"Image {img_index} - VV"),
            (img[1], f"Image {img_index} - VH"),
            (gt, f"Ground Truth {img_index}"),
        )
        for ax, (data, title) in zip(row, panels):
            ax.imshow(data, cmap='gray')
            ax.set_title(title)
            ax.axis('off')

    fig.tight_layout()
    plt.show()

In [None]:
# Histogram the bands of the first training image.
# NOTE(review): training_dataset was built with
# normalization=linear_norm_global_minmax, so its values may not fall within
# the [-30, 0] dB x-range this plot uses by default -- confirm the axes.
plot_band_histograms(image=training_dataset[0][0])

In [None]:
# Plot three fixed training samples.
# NOTE(review): training_dataset was constructed with
# normalization=linear_norm_global_minmax (presumably applied per item), so
# linear_norm_global_percentile here may normalise twice -- confirm intended.
plot_example(
    training_dataset,
    normalization=linear_norm_global_percentile,
    indices=[10, 765, 458],
)