In [1]:
%reload_ext autoreload
%autoreload 2

In [None]:
import pandas as pd
from pathlib import Path
import geopandas as gpd
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from gelos.gelosdataset import GELOSDataSet
from gelos.gelosdatamodule import GELOSDataModule


In [6]:
data_root = Path('/app/data/raw/v0.40')

In [None]:
gelos_datamodule = GELOSDataModule(data_root=data_root, batch_size=8, num_workers=16)

In [None]:
datamodule = gelos_datamodule
datamodule.setup()
loader = datamodule.predict_dataloader()
dataset = loader.dataset
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

modalities = dataset.bands.keys()
band_names = dataset.all_band_names


Using device: cuda


In [13]:
# initialize accumulators
sums = {m: 0 for m in modalities}
sum_squares = {m: 0 for m in modalities}
pixel_counts = {m: 0 for m in modalities}
first_batch = True

In [14]:
for batch in tqdm(loader, total=len(loader)):
    # The dataset returns a dict with an 'image' key, which itself is a dict of modalities
    image_dict = batch["image"]
    
    # Process each modality in the batch
    for modality, tensor in image_dict.items():
        # Move tensor to the designated device (GPU/CPU)
        tensor = tensor.to(device)

        # On the first batch, initialize accumulators with correct tensor shapes
        if first_batch:
            # Shape is (B, T, C, H, W). We want to sum over B, T, H, W.
            # Resulting shape should be (C,)
            num_channels = tensor.shape[2]
            sums[modality] = torch.zeros(num_channels, device=device)
            sum_squares[modality] = torch.zeros(num_channels, device=device)
            
            # Get shape for pixel count calculation
            _, t, _, h, w = tensor.shape
            pixel_counts[modality] = len(dataset) * t * h * w

        # Sum over all dimensions except the channel dimension (C)
        sums[modality] += torch.sum(tensor, dim=(0, 1, 3, 4))
        sum_squares[modality] += torch.sum(tensor.pow(2), dim=(0, 1, 3, 4))

    # After processing the first batch, set the flag to false
    if first_batch:
        first_batch = False


100%|██████████| 9603/9603 [06:06<00:00, 26.18it/s]


In [15]:
means = {m: sums[m] / pixel_counts[m] for m in modalities}
stds = {m: torch.sqrt(sum_squares[m] / pixel_counts[m] - means[m].pow(2)) for m in modalities}

print("Means:", means)
print("Stds:", stds)

Means: {'S1RTC': tensor([0.1445, 0.0290], device='cuda:0'), 'S2L2A': tensor([1852.9951, 2046.7385, 2346.2803, 2593.0386, 2900.8289, 3365.5979,
        3576.1414, 3657.3047, 3703.0908, 3709.9336, 3543.1648, 3048.2400],
       device='cuda:0'), 'landsat': tensor([0.0817, 0.0960, 0.1316, 0.1531, 0.2622, 0.2377, 0.1811],
       device='cuda:0'), 'dem': tensor([642.7003], device='cuda:0')}
Stds: {'S1RTC': tensor([2.6007, 0.2677], device='cuda:0'), 'S2L2A': tensor([1201.8008, 1267.0759, 1316.0233, 1520.8367, 1518.5592, 1419.7736,
        1442.8787, 1476.5182, 1437.5333, 1440.6731, 1588.9490, 1524.4882],
       device='cuda:0'), 'landsat': tensor([0.1597, 0.1609, 0.1554, 0.1681, 0.1539, 0.1463, 0.1311],
       device='cuda:0'), 'dem': tensor([783.0748], device='cuda:0')}


In [16]:
import json

# Get the band names for each modality from the dataset class
bands_per_modality = GELOSDataSet.all_band_names

# Create new dictionaries to store the formatted results
formatted_means = {}
formatted_stds = {}

# Iterate through each modality
for modality, mean_tensor in means.items():
    # Detach tensor from GPU, move to CPU, and convert to a list of floats
    mean_values = mean_tensor.cpu().tolist()
    std_values = stds[modality].cpu().tolist()
    
    # Get the list of band names for the current modality
    band_names = bands_per_modality[modality]
    
    # Create a dictionary mapping band names to their calculated mean and std
    formatted_means[modality] = {band: mean for band, mean in zip(band_names, mean_values)}
    formatted_stds[modality] = {band: std for band, std in zip(band_names, std_values)}

# Print the formatted dictionaries as pretty JSON
print("MEANS = ", json.dumps(formatted_means, indent=4))
print("\nSTDS = ", json.dumps(formatted_stds, indent=4))


MEANS =  {
    "S1RTC": {
        "VV": 0.14450763165950775,
        "VH": 0.029020152986049652
    },
    "S2L2A": {
        "COASTAL_AEROSOL": 1852.9951171875,
        "BLUE": 2046.738525390625,
        "GREEN": 2346.2802734375,
        "RED": 2593.03857421875,
        "RED_EDGE_1": 2900.828857421875,
        "RED_EDGE_2": 3365.597900390625,
        "RED_EDGE_3": 3576.141357421875,
        "NIR_BROAD": 3657.3046875,
        "NIR_NARROW": 3703.0908203125,
        "SWIR_1": 3709.93359375,
        "SWIR_2": 3543.164794921875,
        "CIRRUS": 3048.239990234375
    },
    "landsat": {
        "coastal": 0.08165209740400314,
        "blue": 0.09596806019544601,
        "green": 0.1315794140100479,
        "red": 0.1531316637992859,
        "nir08": 0.2621993124485016,
        "swir16": 0.23768098652362823,
        "swir22": 0.18106447160243988
    },
    "dem": {
        "dem": 642.7003173828125
    }
}

STDS =  {
    "S1RTC": {
        "VV": 2.600670576095581,
        "VH": 0.2677262127