In [5]:
import os
import re
import torch
import numpy as np
import tifffile as tiff
from tqdm import tqdm
from torch.utils.data import Dataset, random_split, DataLoader

In [2]:
class SatImage_Dataloader(Dataset):
    """
    Mono-temporal Sen2-MTC dataset of paired (cloudy, clean) images with
    optional patch extraction.

    Directory layout scanned by ``__init__``::

        route/<tile>/cloud/<name>_<n>_<k>.tif     cloudy files for time index n
        route/<tile>/cloudless/<name>_<n>.tif     clean file for time index n

    For each time index n only the first cloudy file (sorted filename order)
    is paired with the clean file of the same n -> mono-temporal.

    Patch modes:
        stride is not None (requires patch_size):
            every (top, left) origin on a ``stride`` grid is enumerated up
            front, so the dataset length is the number of patches and the
            patches are deterministic.
        stride is None, patch_size is not None:
            one sample per image pair; a single crop origin is chosen per
            item (centered if ``center_crop`` else random) and applied to
            BOTH images so the pair stays spatially aligned.
        patch_size is None (and stride is None):
            one sample per image pair, full images are returned.

    Each item is a dict with keys "cloudy", "clean" (float32 tensors laid out
    (C, H, W)) and "cloudy_path", "clean_path" (str). If ``transform`` is
    given it is called on that dict and its return value is returned instead.
    """

    # Group 2 of both patterns is the time index n.
    _CLOUDY_RE = re.compile(r"(.+?)_(\d+)_(\d+)\.tif")   # <name>_<n>_<k>.tif
    _CLEAN_RE = re.compile(r"(.+?)_(\d+)\.tif")          # <name>_<n>.tif

    def __init__(self, route, patch_size=128, stride=128, center_crop=False, transform=None):
        """
        Scan ``route`` and build the sample index.

        Args:
            route: root directory containing one sub-directory per tile.
            patch_size: side length of the square patch, or None for full images.
            stride: grid step for deterministic patch enumeration, or None
                for one sample per image pair.
            center_crop: only used when ``stride`` is None; center crop if
                True, random crop otherwise.
            transform: optional callable applied to the sample dict.

        Raises:
            ValueError: if ``patch_size``/``stride`` are not positive, or if
                ``stride`` is given without ``patch_size``.
        """
        super().__init__()

        # Validate before touching the filesystem: previously a bad combination
        # surfaced as a TypeError mid-scan or as a silently empty dataset.
        if patch_size is not None and patch_size <= 0:
            raise ValueError(f"patch_size must be a positive int or None, got {patch_size}")
        if stride is not None:
            if patch_size is None:
                raise ValueError("stride requires patch_size to be set")
            if stride <= 0:
                raise ValueError(f"stride must be a positive int or None, got {stride}")

        self.root_dir = route
        self.patch_size = patch_size
        self.center_crop = center_crop
        self.transform = transform
        self.stride = stride

        # Every entry is a 4-tuple (cloudy_path, clean_path, top, left).
        # top/left are None when stride is None (crop origin chosen per item).
        self.samples = []

        for tile_name in sorted(os.listdir(route)):
            tile_path = os.path.join(route, tile_name)
            cloud_dir = os.path.join(tile_path, "cloud")
            clean_dir = os.path.join(tile_path, "cloudless")

            # Skip anything that is not a complete tile directory.
            if not (os.path.isdir(cloud_dir) and os.path.isdir(clean_dir)):
                continue

            cloudy_by_n = self._index_by_time(cloud_dir, self._CLOUDY_RE)
            clean_by_n = self._index_by_time(clean_dir, self._CLEAN_RE)

            for n in sorted(cloudy_by_n.keys()):
                if n not in clean_by_n:
                    print(f"[WARNING] Tile {tile_name}: time {n} has cloudy but no clean.")
                    continue

                cloudy_path = cloudy_by_n[n][0]    # mono-temporal: first cloudy file
                clean_path = clean_by_n[n][-1]     # if several match, the last one wins

                # No stride -> one sample per image pair.
                if self.stride is None:
                    self.samples.append((cloudy_path, clean_path, None, None))
                    continue

                # The image has to be read to learn its spatial size.
                tmp = tiff.imread(cloudy_path)  # indexed below as (H, W, C)
                H, W, _ = tmp.shape

                ps = self.patch_size
                st = self.stride

                if H < ps or W < ps:
                    # Would otherwise vanish from the dataset without a trace.
                    print(f"[WARNING] Tile {tile_name}: time {n} image {(H, W)} "
                          f"is smaller than patch size {ps}; skipped.")
                    continue

                for top in range(0, H - ps + 1, st):
                    for left in range(0, W - ps + 1, st):
                        self.samples.append((cloudy_path, clean_path, top, left))

        print(f"[Sen2MTC loaded] Total samples (including patches): {len(self.samples)}")

    @staticmethod
    def _index_by_time(directory, pattern):
        """Return {n: [paths]} for the ``.tif`` files in ``directory`` matching ``pattern``.

        Paths inside each list follow sorted filename order.
        """
        by_n = {}
        for fname in sorted(os.listdir(directory)):
            if not fname.endswith(".tif"):
                continue
            m = pattern.match(fname)
            if not m:
                continue
            by_n.setdefault(int(m.group(2)), []).append(os.path.join(directory, fname))
        return by_n

    def __len__(self):
        """Number of samples (image pairs, or patches when stride is set)."""
        return len(self.samples)

    def load_tif(self, path):
        """Read a TIFF file and return it as a float32 numpy array."""
        arr = tiff.imread(path)  # used by __getitem__ as (H, W, C)
        arr = np.array(arr, dtype=np.float32)
        return arr

    def _crop_origin(self, H, W, size):
        """Pick the (top, left) corner of a ``size`` x ``size`` crop in an H x W image.

        Centered when ``self.center_crop`` is True, otherwise drawn uniformly
        with ``np.random.randint``.

        Raises:
            ValueError: if the patch does not fit.
        """
        if size > H or size > W:
            raise ValueError(f"Patch size {size} > image size {(H,W)}")

        if self.center_crop:
            return (H - size) // 2, (W - size) // 2

        top = np.random.randint(0, H - size + 1)
        left = np.random.randint(0, W - size + 1)
        return top, left

    # Patch extraction helper
    def extract_patch(self, img, size, top=None, left=None):
        """
        Crop a square patch from ``img``.

        img: array/tensor of shape (C,H,W)
        size: int, patch size
        top, left: optional crop origin. When either is None the origin is
            chosen by ``_crop_origin`` (center or random). Pass the same
            explicit origin for both images of a pair to keep them aligned.
        returns: (C, size, size)

        Raises:
            ValueError: if the patch does not fit inside the image.
        """
        _, H, W = img.shape
        if size > H or size > W:
            raise ValueError(f"Patch size {size} > image size {(H,W)}")

        if top is None or left is None:
            top, left = self._crop_origin(H, W, size)
        elif top < 0 or left < 0 or top + size > H or left + size > W:
            raise ValueError(
                f"Patch origin {(top, left)} with size {size} exceeds image size {(H,W)}"
            )

        return img[:, top:top+size, left:left+size]

    def __getitem__(self, idx):
        """Load the idx-th pair, crop it if requested, and apply ``transform``."""
        cloudy_path, clean_path, top, left = self.samples[idx]

        cloudy = self.load_tif(cloudy_path)
        clean  = self.load_tif(clean_path)

        # (H, W, C) -> (C, H, W)
        cloudy = torch.from_numpy(cloudy.transpose(2,0,1))
        clean  = torch.from_numpy(clean.transpose(2,0,1))

        # Patch extraction
        if self.patch_size is not None:
            ps = self.patch_size
            if top is None or left is None:
                # No precomputed origin: choose ONE origin and reuse it for both
                # images. Drawing it separately per image (as before) cropped
                # cloudy and clean at different locations.
                # The origin is computed on the extent common to both images;
                # for equally sized pairs this is simply the image size.
                H = min(cloudy.shape[1], clean.shape[1])
                W = min(cloudy.shape[2], clean.shape[2])
                top, left = self._crop_origin(H, W, ps)

            cloudy = self.extract_patch(cloudy, ps, top, left)
            clean  = self.extract_patch(clean, ps, top, left)

        sample = {
            "cloudy": cloudy,
            "clean": clean,
            "cloudy_path": cloudy_path,
            "clean_path": clean_path
        }

        if self.transform:
            sample = self.transform(sample)

        return sample


In [3]:
def compute_global_stats(dataset, batch_size=32):
    """
    Compute global per-channel mean and (population) std of the "cloudy"
    images of a dataset.

    The statistics are accumulated in float64 as running sums of x and x**2
    over the batch and both spatial dimensions, so the whole dataset never
    has to be held in memory.

    Args:
        dataset: map-style dataset whose items are dicts holding a "cloudy"
            tensor of shape (C, H, W); all items must be batchable by the
            default DataLoader collate.
        batch_size: batch size used while iterating.

    Returns:
        mean, std: two float32 tensors of shape (C,).

    Raises:
        ValueError: if the dataset yields no samples.
    """
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)

    channel_sum = None
    channel_sq_sum = None
    total_pixels = 0

    for batch in loader:
        x = batch["cloudy"].double()   # (B,C,H,W)
        B, C, H, W = x.shape

        if channel_sum is None:
            # Allocate lazily: the channel count is only known from the data.
            channel_sum = torch.zeros(C, dtype=torch.float64, device=x.device)
            channel_sq_sum = torch.zeros(C, dtype=torch.float64, device=x.device)

        # Sum over batch and spatial dims
        channel_sum += x.sum(dim=(0, 2, 3))
        channel_sq_sum += (x * x).sum(dim=(0, 2, 3))

        total_pixels += B * H * W

    if channel_sum is None:
        # Previously this fell through to `None / 0` and raised a TypeError.
        raise ValueError("compute_global_stats: dataset yielded no samples")

    mean = channel_sum / total_pixels
    # E[x^2] - E[x]^2 can round to a tiny negative number for (near-)constant
    # channels, which would turn into NaN under sqrt; clamp at zero first.
    variance = torch.clamp(channel_sq_sum / total_pixels - mean**2, min=0.0)
    std = torch.sqrt(variance)

    print("Global mean:", mean)
    print("Global std:", std)

    return mean.float(), std.float()

class Normalization:
    """
    Sample transform that standardizes the "cloudy" and "clean" tensors of a
    sample dict with one shared set of per-channel statistics.

    Both tensors are mapped to (x - mean) / (std + 1e-6); the two path
    entries are copied through untouched.
    """

    def __init__(self, mean, std):
        # Reshape to (C,1,1) so the statistics broadcast over (C,H,W) images.
        self.mean = mean.reshape(-1, 1, 1)
        self.std  = std.reshape(-1, 1, 1)

    def _standardize(self, image):
        """Apply the per-channel standardization to one image."""
        return (image - self.mean) / (self.std + 1e-6)

    def __call__(self, sample):
        """Return a new dict with normalized images and the original paths."""
        result = {
            image_key: self._standardize(sample[image_key])
            for image_key in ("cloudy", "clean")
        }
        for path_key in ("cloudy_path", "clean_path"):
            result[path_key] = sample[path_key]
        return result

In [6]:


############################################################
# 1. Dataset + compute_global_stats + Normalization
############################################################

# SatImage_Dataloader, compute_global_stats and Normalization are defined in
# the cells above; those cells must have been run before this one.

###############################################
# 2. Precompute everything
###############################################
# Not used anywhere in this cell.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Configuration for the precompute run.
route = "./Sen2_MTC_Mini"
patch_size = 128
stride = 128       # equal to patch_size -> patch origins one patch width apart
batch_size = 16    # only used for the statistics pass below

print("STEP 1: Create RAW dataset...")
# Un-normalized view of the data, used only to measure mean/std.
dataset_raw = SatImage_Dataloader(route=route,
                                  patch_size=patch_size,
                                  stride=stride,
                                  transform=None)

print("STEP 2: Compute global mean/std...")
# NOTE(review): the statistics are computed over ALL patches, before the
# train/val/test split further down, so val/test pixels contribute to the
# normalization constants. Confirm this is acceptable for the experiment.
mean, std = compute_global_stats(dataset_raw, batch_size=batch_size)

print("STEP 3: Create normalized dataset...")
normalized_tf = Normalization(mean, std)
# Second dataset over the same files; this repeats the directory scan done
# for dataset_raw, now with the normalization transform attached.
dataset_norm = SatImage_Dataloader(route=route,
                                   patch_size=patch_size,
                                   stride=stride,
                                   transform=normalized_tf)

############################################################
# 3. MATERIALIZE THE ENTIRE DATASET INTO MEMORY
############################################################

print("STEP 4: Materialize (load all samples into CPU RAM) ...")

# Parallel lists: position i in each list refers to the same patch.
all_cloudy = []
all_clean  = []
all_cloudy_path = []
all_clean_path  = []

# batch_size=1 and shuffle=False: one patch per step, in dataset order.
loader = DataLoader(dataset_norm, batch_size=1, shuffle=False)

for sample in tqdm(loader):
    # sample["cloudy"] has shape (1,C,H,W) → squeeze B dim
    all_cloudy.append(sample["cloudy"].squeeze(0))
    all_clean.append(sample["clean"].squeeze(0))
    # Collated string fields arrive as length-1 sequences; take the element.
    all_cloudy_path.append(sample["cloudy_path"][0])
    all_clean_path.append(sample["clean_path"][0])

# Stack into tensors or keep list (list is safer for large dataset)
# all_cloudy = torch.stack(all_cloudy)
# all_clean  = torch.stack(all_clean)

print("TOTAL PATCHES:", len(all_cloudy))

############################################################
# 4. Train/Val/Test Split  (deterministic)
############################################################

train_ratio = 0.7
val_ratio = 0.15
test_ratio = 0.15   # not read below: the test split is whatever remains

N = len(all_cloudy)
train_len = int(train_ratio * N)
val_len   = int(val_ratio   * N)
test_len  = N - train_len - val_len

# Fixed seed -> the same permutation on every run.
generator = torch.Generator().manual_seed(2025)
indices = torch.randperm(N, generator=generator)

# NOTE(review): the permutation is over individual patches, so patches cut
# from the same source image can end up in different splits. Confirm whether
# the split should instead be done per image/tile.
train_idx = indices[:train_len]
val_idx   = indices[train_len:train_len+val_len]
test_idx  = indices[train_len+val_len:]

def subset(idx_list):
    """
    Gather the materialized patches selected by ``idx_list``.

    Reads the notebook-level lists ``all_cloudy``, ``all_clean``,
    ``all_cloudy_path`` and ``all_clean_path`` and returns a dict with the
    keys "cloudy", "clean", "cloudy_path", "clean_path", each holding the
    selected entries in ``idx_list`` order.
    """
    def gather(pool):
        # One pass over idx_list per pool, like the four original comprehensions.
        return [pool[position] for position in idx_list]

    selected = {}
    selected["cloudy"] = gather(all_cloudy)
    selected["clean"] = gather(all_clean)
    selected["cloudy_path"] = gather(all_cloudy_path)
    selected["clean_path"] = gather(all_clean_path)
    return selected

# Build the three splits in train / val / test order.
train_set, val_set, test_set = (
    subset(split_indices) for split_indices in (train_idx, val_idx, test_idx)
)

############################################################
# 5. Save everything to .pt
############################################################

# The normalization constants are stored next to the splits so that the
# saved file carries everything needed to interpret the normalized patches.
save_dict = dict(
    mean=mean,
    std=std,
    train_set=train_set,
    val_set=val_set,
    test_set=test_set,
)

torch.save(save_dict, "sen2mtc_precomputed.pt")
print("Saved sen2mtc_precomputed.pt")


STEP 1: Create RAW dataset...
[Sen2MTC loaded] Total samples (including patches): 4200
STEP 2: Compute global mean/std...
Global mean: tensor([1950.8348, 2058.9039, 1980.1758, 3509.9704], dtype=torch.float64)
Global std: tensor([2049.4984, 2072.6882, 2218.3611, 1820.3944], dtype=torch.float64)
STEP 3: Create normalized dataset...
[Sen2MTC loaded] Total samples (including patches): 4200
STEP 4: Materialize (load all samples into CPU RAM) ...


100%|██████████| 4200/4200 [00:12<00:00, 348.46it/s]


TOTAL PATCHES: 4200
Saved sen2mtc_precomputed.pt


In [7]:
import torch
from torch.utils.data import Dataset, DataLoader

##############################################
# Minimal Dataset wrapper for precomputed data
##############################################
class PrecomputedDataset(Dataset):
    """
    Thin map-style wrapper around one precomputed split.

    ``bag`` is a dict whose "cloudy" and "clean" entries are index-aligned
    sequences of tensors. Items are dicts with those two keys, each tensor
    cast to float32; any other entries of ``bag`` are ignored.
    """

    _IMAGE_KEYS = ("cloudy", "clean")

    def __init__(self, bag):
        # Keep references to the stored sequences; nothing is copied here.
        self.cloudy = bag["cloudy"]
        self.clean  = bag["clean"]

    def __len__(self):
        """Number of patches in the split."""
        return len(self.cloudy)

    def __getitem__(self, i):
        """Return the i-th pair as {"cloudy": ..., "clean": ...} in float32."""
        return {key: getattr(self, key)[i].float() for key in self._IMAGE_KEYS}

##############################################
# LOAD the precomputed splits and build loaders
##############################################
data = torch.load("sen2mtc_precomputed.pt")

# Wrap each stored split; the names now refer to Dataset objects.
train_set, val_set, test_set = (
    PrecomputedDataset(data[split_key])
    for split_key in ("train_set", "val_set", "test_set")
)

batch_size = 16

# Only the training loader shuffles; val/test keep their stored order.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
val_loader, test_loader = (
    DataLoader(eval_split, batch_size=batch_size, shuffle=False)
    for eval_split in (val_set, test_set)
)


In [8]:
# Sanity check on the first training batch only: shapes of both tensors and
# the overall mean/std of the cloudy batch.
# Expected layout: [batch_size, channels, patch_size, patch_size]
for batch in train_loader:
    x = batch['cloudy']
    print(x.shape)
    print(batch['clean'].shape)
    print(x.mean(), x.std())
    break

torch.Size([16, 4, 128, 128])
torch.Size([16, 4, 128, 128])
tensor(-0.0722) tensor(0.7882)
