In [70]:
import random
import numpy as np
import torch
import rasterio as rs
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import matplotlib.pyplot as plt
import copy
import pandas as pd
import os
import glob
from torchvision.transforms import GaussianBlur
import re

class CogDataset_v41(Dataset):
    """Dataset that samples random temporal windows from per-patch ``.cog`` rasters.

    Every item is drawn at random: ``__getitem__`` ignores its index, picks a
    random metadata row (ROI / patch) and a random run of ``num_s2_frames``
    consecutive frames, and reads that run from ``{roi}_patch{patch_id}.cog``.
    The raster is read as 17 bands and treated as 256x256 frames stacked along
    the row axis.

    Args:
        num_s2_frames: Number of consecutive frames returned per sample.
        dataset_path: Directory holding the ``roi*.csv`` metadata files and the
            ``.cog`` rasters. Defaults to ``DEFAULT_DATASET_PATH``.
        verbose: When True (default), print the sampled ROI / patch / window on
            every ``__getitem__`` call.

    Raises:
        ValueError: If ``num_s2_frames`` is smaller than 1.
    """

    DEFAULT_DATASET_PATH = "/share/hariharan/cloud_removal/MultiSensor/dataset_temp_preprocessed/spatio_temporal"
    NUM_BANDS = 17       # bands read from every raster (rasterio band indices 1..17)
    PATCH_SIZE = 256     # height and width of a single frame, in pixels
    EPOCH_LENGTH = 2048  # nominal number of random samples per epoch

    def __init__(self, num_s2_frames = 10, dataset_path = None, verbose = True):
        if num_s2_frames < 1:
            raise ValueError(f"num_s2_frames must be >= 1, got {num_s2_frames}")
        self.dataset_path = Path(self.DEFAULT_DATASET_PATH if dataset_path is None else dataset_path)
        self.num_s2_frames = num_s2_frames
        self.verbose = verbose
        self.load_spatio_temporal_info()
        self.mode = "MSI"

    def __len__(self):
        """Return the nominal epoch length (items are random, so this is not a file count)."""
        return self.EPOCH_LENGTH

    def __getitem__(self, idx):
        """Read one random window of ``num_s2_frames`` consecutive frames.

        Args:
            idx: Ignored; every call draws a fresh random metadata row and start frame.

        Returns:
            Tensor of shape ``(17, num_s2_frames, 256, 256)``.

        Raises:
            RuntimeError: If no metadata rows were loaded.
            ValueError: If the sampled patch has fewer than ``num_s2_frames`` frames.
            AssertionError: If the raster read does not have the expected shape.
        """
        info = self.roi_spatio_temporal_info
        if len(info) == 0:
            raise RuntimeError(f"No spatio-temporal metadata (roi*.csv) loaded from {self.dataset_path}")

        # Randomly select a metadata row.
        row = info.iloc[random.randint(0, len(info) - 1)]
        roi = row["roi_id"]
        patch_id = row["patch_id"]
        day_counts = row["day_count"]

        num_frames = self.num_s2_frames
        if len(day_counts) < num_frames:
            raise ValueError(
                f"{roi} patch {patch_id} has {len(day_counts)} frames, fewer than num_s2_frames={num_frames}"
            )

        # random.randint is inclusive on both ends, so the largest valid start is
        # len - num_frames and the last frame of the window is first + num_frames - 1.
        first = random.randint(0, len(day_counts) - num_frames)
        last = first + num_frames - 1

        size = self.PATCH_SIZE
        file_path = os.path.join(self.dataset_path, f"{roi}_patch{patch_id}.cog")
        # Window(col_off, row_off, width, height): a vertical strip of num_frames frames.
        window = rs.windows.Window(0, first * size, size, size * num_frames)

        if self.verbose:
            print(f"""{roi} | patch {patch_id} | day {first} | latitude {row["latitude"]:.3f} | longtitude {row["longtitude"]:.3f}""")
            # The end date is the LAST frame inside the window; indexing at
            # first + num_frames would point one past it and can overrun the list.
            print(f"start date: {day_counts[first]} | end date: {day_counts[last]} ")

        with rs.open(file_path) as src:
            msi = torch.from_numpy(src.read(list(range(1, self.NUM_BANDS + 1)), window=window))

        expected_shape = (self.NUM_BANDS, size * num_frames, size)
        assert msi.shape == expected_shape, f"unexpected raster shape {tuple(msi.shape)}, expected {expected_shape}"
        # Split the stacked rows into separate frames: (bands, frames, height, width).
        return msi.reshape(self.NUM_BANDS, num_frames, size, size)

    @staticmethod
    def _parse_day_count(text):
        """Extract every run of digits in ``text`` as a list of ints."""
        return [int(token) for token in re.findall(r'\d+', text)]

    def load_spatio_temporal_info(self):
        """Load every ``roi*.csv`` under ``dataset_path`` into ``roi_spatio_temporal_info``.

        The ``day_count`` column of each CSV is parsed from its string form into
        a list of ints. Files are read in sorted order so the row order is
        deterministic. When no CSV is found the attribute is an empty DataFrame.
        """
        csv_paths = sorted(glob.glob(os.path.join(self.dataset_path, "roi*.csv")))
        frames = []
        for csv_path in csv_paths:
            df = pd.read_csv(csv_path)
            df["day_count"] = df["day_count"].apply(self._parse_day_count)
            frames.append(df)

        # Concatenate once instead of growing the frame inside the loop.
        if frames:
            self.roi_spatio_temporal_info = pd.concat(frames, ignore_index=True, axis=0)
        else:
            self.roi_spatio_temporal_info = pd.DataFrame()

# Seed both RNGs that drive sampling so that Restart & Run All draws the same batches:
# `random` picks the metadata row and start frame inside CogDataset_v41.__getitem__,
# and torch drives the DataLoader's shuffling.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

batch_size = 4
dataset = CogDataset_v41()
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0, drop_last=True)

# Show the loaded metadata table as the cell output.
dataset.roi_spatio_temporal_info

In [71]:
# Grab a single batch. Unlike `for ... : break`, next() raises StopIteration if the
# loader yields nothing, instead of silently leaving `msi_buffer` undefined or stale
# from an earlier run of this cell.
msi_buffer = next(iter(dataloader))

In [72]:
# msi_buffer.shape: torch.Size([4, 17, 10, 256, 256])  # [batch_size, num_bands, num_frames, height, width]
# Intended plot layout:
# loop over the four samples in the batch and, for each sample,
# build a grid of subplots with 5 rows and 10 columns (one column per frame, i.e. num_frames).
# The first row is RGB, which corresponds to band indices 0, 1, 2.
# The second row is NIR, which corresponds to band indices 3, 4, 5.
# The third row is SWIR1, which corresponds to band indices 6, 7, 8.
# The fourth row is SAR, which corresponds to band indices 15, 15, 16.
# The fifth row is the cloud mask, which corresponds to band indices 10, 11, 13.
# NOTE(review): the plotting loop below does not follow this plan exactly (for example, its
# second row reads bands 4:7 and its "SAR" row reads bands 10 and 11) - confirm which mapping is intended.

def pp(img):
    """Move the leading axis of a 3-D tensor to the end: (C, H, W) -> (H, W, C).

    Turns a channel-first band stack into the channel-last layout that is
    handed to ``imshow`` further down.
    """
    channels_last_order = (1, 2, 0)
    return img.permute(*channels_last_order)

# One figure per sample in the batch: 5 rows of band composites x one column per frame.
# Sizes come from the tensor itself (layout: [batch, bands, frames, height, width]) so the
# cell keeps working when batch_size or num_s2_frames change.
num_samples, _, num_frames = msi_buffer.shape[:3]

# Row captions state exactly which band indices / divisors each row plots.
row_labels = [
    "bands 0-2 / 3000",
    "bands 4-6 / 5000",
    "band 0 (gray)",
    "bands 10, 11",
    "bands 11, 13, 15",
]

for batch_idx in range(num_samples):

    # squeeze=False keeps `axs` two-dimensional even when there is a single frame.
    fig, axs = plt.subplots(5, num_frames, figsize=(num_frames*2, 10), sharex=True, sharey=True, squeeze=False)
    for frame_idx in range(num_frames):
        frame = msi_buffer[batch_idx, :, frame_idx]  # every band of this frame: (bands, H, W)

        # Row 0 - RGB composite from bands 0..2, scaled by 1/3000.
        # Clipping to [0, 1] is what imshow does to float RGB anyway; doing it here avoids the warning.
        rgb = pp(frame[0:3]).numpy() / 3000
        axs[0, frame_idx].imshow(np.clip(rgb, 0, 1))
        band0 = frame[0].numpy()
        print(f"Batch {batch_idx} at frame {frame_idx}: Max_value {band0.max()} | Min_value {band0.min()}")

        # Row 1 - composite from bands 4..6, scaled by 1/5000.
        nir = pp(frame[4:7]).numpy() / 5000
        axs[1, frame_idx].imshow(np.clip(nir, 0, 1))

        # Row 2 - band 0 alone in grayscale (this row was labelled "FIR").
        # NOTE(review): band 0 is also the first RGB channel - confirm the intended band.
        axs[2, frame_idx].imshow(band0, cmap='gray', vmin=0, vmax=10000)

        # Row 3 - bands 10 and 11 in the red / green channels (this row was labelled "SAR").
        # vmin/vmax are not passed: imshow ignores them for RGB input.
        # NOTE(review): confirm whether a 0..32 scaling was intended here.
        sar = np.zeros_like(nir)
        sar[:, :, 0] = frame[10].numpy() / 25
        sar[:, :, 1] = frame[11].numpy() / 32.5
        axs[3, frame_idx].imshow(np.clip(sar, 0, 1))

        # Row 4 - bands 11 (scaled by 1/100), 13 and 15 as the three colour channels.
        # A fresh array is used so the row-3 image is never mutated after being plotted.
        mask = np.zeros_like(nir)
        mask[:, :, 0] = frame[11].numpy() / 100
        mask[:, :, 1] = frame[13].numpy()
        mask[:, :, 2] = frame[15].numpy()
        axs[4, frame_idx].imshow(np.clip(mask, 0, 1))

    for row_idx, label in enumerate(row_labels):
        axs[row_idx, 0].set_ylabel(label)

    fig.tight_layout()
    plt.pause(0.1)
    # Close this specific figure so figures do not pile up in memory across the loop.
    plt.close(fig)

In [56]:
# Largest value found in band index 16 over all samples, frames and pixels of the batch.
np.max(msi_buffer[:, 16, :].numpy())