# Imports

### Standards

In [None]:
from pathlib import Path

### Externals

In [None]:
import matplotlib.pyplot as plt
import torch
import numpy as np

from pytorch_lightning import seed_everything
from torch.utils.data import DataLoader

### Internal modules

In [None]:
from asm_mapping.data.planetscope_dataset import PlanetScopeDataset
from asm_mapping.data.dataset_mode import DatasetMode

# Configs

In [None]:
# seeds
# Fixed seed handed to pytorch_lightning's seed_everything so that the random
# choices made later in this notebook (e.g. the torch.randint sampling in
# plot_ps_examples) are repeatable between runs.
RANDOM = 79
# NOTE(review): per the Lightning docs, workers=True also makes DataLoader
# workers derive their seeds from this value — confirm for the installed version.
seed_everything(RANDOM, workers=True)

In [None]:
# folders
# Directory passed as `data_dir` to PlanetScopeDataset below. This is an
# absolute, machine-specific path; adjust it when running the notebook elsewhere.
PS_DATA = "/mnt/guanabana/raid/home/pasan001/asm-mapping/data/split_0/planet/training_data"

In [None]:
# datasets
# Options forwarded to the PlanetScopeDataset constructor below.
PAD = False  # passed as `pad`
TRANSFORMS = False  # passed as `transforms`
IS_FUSION = False  # not referenced in the visible part of this notebook
IS_INFERENCE = False  # not referenced in the visible part of this notebook
MODE = DatasetMode.STANDALONE  # passed as `mode`

# Datasets

## PlanetScope

In [None]:
# Build the PlanetScope dataset from the notebook-level configuration values.
ps_dataset = PlanetScopeDataset(
    data_dir=PS_DATA,
    mode=MODE,
    pad=PAD,
    transforms=TRANSFORMS,
)

In [None]:
def plot_ps_band_hist(image, *, bins=256, xlim=(0, 1), ylim=(0, 8000)):
    """Plot one histogram per spectral band in a 2x2 grid and show the figure.

    Args:
        image: Band-first array-like. ``image[0]`` to ``image[3]`` are plotted
            and labelled Blue, Green, Red and NIR, in that order; any further
            bands are ignored. Each band must support ``ravel()``.
        bins: Number of histogram bins.
        xlim: x-axis limits applied to every subplot, or ``None`` to let
            matplotlib autoscale. The default presumes values scaled to
            [0, 1] — verify against the dataset's normalisation.
        ylim: y-axis limits applied to every subplot, or ``None`` to let
            matplotlib autoscale. The default of 8000 counts is tuned for a
            particular tile size; pass ``None`` for other image sizes.
    """
    bands = ('Blue', 'Green', 'Red', 'NIR')
    _, axes = plt.subplots(2, 2, figsize=(10, 8))

    for i, (ax, band) in enumerate(zip(axes.flat, bands)):
        # Index explicitly so an image with fewer than four bands still fails
        # loudly instead of silently leaving panels empty.
        ax.hist(image[i].ravel(), bins=bins, color='k', alpha=0.5)
        ax.set_title(f'{band} band histogram')
        if xlim is not None:
            ax.set_xlim(xlim)
        if ylim is not None:
            ax.set_ylim(ylim)

    plt.tight_layout()
    plt.show()

In [None]:
def plot_ps_examples(dataset, indices=None, num_examples=3):
    """Plot RGB composite, NDVI and ground truth side by side for some samples.

    One figure row is drawn per sample, with three columns: brightened RGB
    image, NDVI band and ground-truth mask.

    Args:
        dataset: Indexable dataset whose items unpack into an
            ``(image, ground_truth)`` pair of tensors, and which exposes a
            ``dataset`` attribute whose entries start with the image file
            name. The image is read as channel-first with channels 0-2 being
            Blue, Green, Red and channel 4 being NDVI — TODO confirm against
            PlanetScopeDataset.
        indices: Sample indices to plot. When ``None``, ``num_examples``
            indices are drawn at random with ``torch.randint``.
        num_examples: Number of random samples to draw; ignored when
            ``indices`` is given.

    Raises:
        ValueError: If there are no samples to plot.
    """
    if indices is None:
        indices = torch.randint(len(dataset), size=(num_examples,)).tolist()
    else:
        indices = list(indices)
        num_examples = len(indices)

    if num_examples < 1:
        # Fail before a figure is created; matplotlib cannot build a 0-row grid.
        raise ValueError("plot_ps_examples needs at least one example to plot")

    subplot_cols = 3
    brightness = 1.5  # display-only gain so dark scenes are easier to see

    # squeeze=False keeps `axs` two-dimensional even when there is one row;
    # otherwise a single example yields a 1-D array and axs[row, col] fails.
    _, axs = plt.subplots(
        num_examples,
        subplot_cols,
        figsize=(12, num_examples * 4),
        squeeze=False,
    )

    for row, idx in enumerate(indices):
        img_tensor, gt_tensor = dataset[idx]

        # the numeric suffix of the file name (text after the last '_',
        # before the first '.') is used to label the plots
        img_file_name = dataset.dataset[idx][0]
        img_index = img_file_name.split('_')[-1].split('.')[0]

        img = img_tensor.numpy()
        gt = gt_tensor.numpy()

        # reorder bands from BGR to RGB, then move channels last
        # (height, width, channels) as matplotlib expects
        img_rgb = np.transpose(img[[2, 1, 0], :, :], (1, 2, 0))

        # brighten for display and keep values inside imshow's [0, 1] range
        img_rgb = np.clip(img_rgb * brightness, 0, 1)

        # NDVI is read from the fifth channel of the image
        ndvi = img[4, :, :]

        rgb_ax, ndvi_ax, gt_ax = axs[row]

        rgb_ax.imshow(img_rgb)
        rgb_ax.set_title(f"Image {img_index} - RGB")
        rgb_ax.axis('off')

        ndvi_ax.imshow(ndvi, cmap='RdYlGn')
        ndvi_ax.set_title("NDVI")
        ndvi_ax.axis('off')

        gt_ax.imshow(gt, cmap='gray')
        gt_ax.set_title(f"Ground Truth {img_index}")
        gt_ax.axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
# Show the per-band histograms for a single sample. Element [0] of a dataset
# item is the image tensor (the second element is the ground truth), and the
# plotting helper expects a NumPy array.
index = 111
plot_ps_band_hist(ps_dataset[index][0].numpy())

In [None]:
# Plot RGB / NDVI / ground truth for three hand-picked samples instead of a random draw.
plot_ps_examples(ps_dataset, indices=[12, 99, 111])