In [1]:
import os
from glob import glob
from multiprocessing import Manager

import ipywidgets as widgets
import matplotlib.pyplot as plt
import torch
import torchvision.transforms as T
from ipywidgets import interact_manual, Layout
from PIL import Image
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
from tqdm import tqdm

from src.datasets import load_dataset_JIF, SatelliteDataset, DictDataset, make_transforms_JIF
from src.datasources import S2_ALL_12BANDS
from src.plot import showtensor
# Wildcard import kept last (as originally) so the names it exports keep their
# original precedence; presumably this is where SPOT_RGB_BANDS comes from.
from src.datasources import *
# Dataset layout: <dataset_root>/<lr_dataset_folder>/<folder>/... for the
# low-resolution data and <dataset_root>/<hr_dataset_folder>/<folder>/... for
# the high-resolution data.
dataset_root = 'dataset/'
lr_dataset_folder = 'lr/'
hr_dataset_folder = 'hr/'

# Per-source transforms, built for the true-colour LR bands and a radiometry depth of 12.
transforms = make_transforms_JIF(radiometry_depth=12, lr_bands_to_use='true_color')
# One manager instance, shared by every SatelliteDataset created below.
multiprocessing_manager = Manager()

In [2]:
def ensure_tensor_has_four_dimensions(tensor):
    """Ensure that the tensor has four dimensions.

    Missing leading dimensions are added as singleton (size-1) axes, so a
    ``(C, H, W)`` input becomes ``(1, C, H, W)``, ``(H, W)`` becomes
    ``(1, 1, H, W)``, and so on down to 0-d inputs.

    Parameters
    ----------
    tensor : Tensor
        The tensor to pad. Anything exposing ``.ndim`` and supporting
        ``None``-indexing (e.g. ``torch.Tensor`` or ``numpy.ndarray``) works.

    Returns
    -------
    Tensor
        ``tensor`` with singleton axes prepended until it has four dimensions.
        Inputs that already have four or more dimensions are returned unchanged.
    """
    missing_dimensions = 4 - tensor.ndim
    if missing_dimensions <= 0:
        return tensor
    # Each ``None`` in the index inserts a new size-1 axis at the front.
    return tensor[(None,) * missing_dimensions]


In [None]:
# Export one PNG preview per scene folder:
#   <dataset_root>/simple_hr/<folder>_hr.png  - the first HR RGB image
#   <dataset_root>/simple_lr/<folder>_lr.png  - a grid built from the LR data


def _scene_dataset(folder):
    """Build the DictDataset combining every LR/HR source for one scene folder."""
    lr_root = os.path.join(dataset_root, lr_dataset_folder, folder, "L2A", "")
    hr_root = os.path.join(dataset_root, hr_dataset_folder, folder)
    # NOTE(review): only "lr" and "hr" are exported below; "lrc", "hr_pan" and
    # "hr_pansharpened" are still read for every scene. Dropping them would
    # speed this loop up if they are not needed.
    return DictDataset(
        lr=SatelliteDataset(
            root=lr_root,
            file_postfix="-L2A_data.tiff",
            transform=transforms["lr"],
            number_of_revisits=8,
            bands_to_read=S2_ALL_12BANDS["true_color"],
            multiprocessing_manager=multiprocessing_manager,
        ),
        lrc=SatelliteDataset(
            root=lr_root,
            file_postfix="-CLM.tiff",
            transform=transforms["lrc"],
            number_of_revisits=8,
            multiprocessing_manager=multiprocessing_manager,
        ),
        hr=SatelliteDataset(
            root=hr_root,
            file_postfix="_rgbn.tiff",
            transform=transforms["hr"],
            bands_to_read=SPOT_RGB_BANDS,
            number_of_revisits=1,
            multiprocessing_manager=multiprocessing_manager,
        ),
        hr_pan=SatelliteDataset(
            root=hr_root,
            file_postfix="_pan.tiff",
            transform=transforms["hr_pan"],
            number_of_revisits=1,
            multiprocessing_manager=multiprocessing_manager,
        ),
        hr_pansharpened=SatelliteDataset(
            root=hr_root,
            file_postfix="_ps.tiff",
            transform=transforms["hr"],
            bands_to_read=SPOT_RGB_BANDS,
            number_of_revisits=1,
            multiprocessing_manager=multiprocessing_manager,
        ),
    )


def _save_grid_as_png(grid, path):
    """Save the first image of ``grid`` (e.g. a ``make_grid`` result) as a PNG.

    Parameters
    ----------
    grid : Tensor
        Image tensor; missing leading dimensions are added with
        ``ensure_tensor_has_four_dimensions`` and the first image is saved.
    path : str
        Destination file path.

    Raises
    ------
    TypeError
        If the image is neither floating point nor ``uint8`` (previously such
        images were skipped silently).
    """
    image = ensure_tensor_has_four_dimensions(grid)[0]
    if image.dtype == torch.uint8:
        # Rescale 8-bit data to [0, 1] so it follows the same float path.
        image = image / 255.0
    if not image.is_floating_point():
        raise TypeError(f"Cannot export tensor of dtype {image.dtype} to {path}")
    if not (image.ndim == 3 and image.shape[0] >= 3):
        # Fewer than three channels: save the first channel as a single-band image.
        image = image[0]
    T.ToPILImage()(image).save(path)


lr_scenes_dir = os.path.join(dataset_root, lr_dataset_folder)
hr_output_dir = os.path.join(dataset_root, "simple_hr")
lr_output_dir = os.path.join(dataset_root, "simple_lr")
os.makedirs(hr_output_dir, exist_ok=True)
os.makedirs(lr_output_dir, exist_ok=True)

# Only directories are scenes; sorting makes the export order reproducible.
scene_folders = sorted(
    entry
    for entry in os.listdir(lr_scenes_dir)
    if os.path.isdir(os.path.join(lr_scenes_dir, entry))
)

for folder in tqdm(scene_folders):
    dataset = _scene_dataset(folder)
    item = dataset[0]
    lr, lrc, hr_rgbn, hr_pan, ps = (
        item["lr"], item["lrc"], item["hr"], item["hr_pan"], item["hr_pansharpened"]
    )

    # First HR image on its own.
    _save_grid_as_png(
        make_grid([hr_rgbn[0]], normalize=True, scale_each=True),
        os.path.join(hr_output_dir, folder + "_hr.png"),
    )
    # lr[0] tiled 4 images per row (presumably one tile per revisit - confirm shape).
    _save_grid_as_png(
        make_grid(lr[0], normalize=True, scale_each=True, nrow=4),
        os.path.join(lr_output_dir, folder + "_lr.png"),
    )

 34%|██████████████████████████████████████████████████████████████████▉                                                                                                                                 | 1341/3927 [16:50<32:47,  1.31it/s]