In [1]:
import os
import re
import torch
import numpy as np
import tifffile as tiff
from tqdm import tqdm
from torch.utils.data import Dataset, random_split, DataLoader

In [2]:
class SatImage_Dataloader_1v1(Dataset):
    def __init__(self, route, patch_size=128, stride=128, center_crop=False, transform=None):
        super().__init__()
        self.root_dir = route
        self.patch_size = patch_size
        self.center_crop = center_crop
        self.transform = transform
        self.stride = stride

        # List of samples:
        # If stride is None:       (cloudy_path, clean_path)
        # If stride is not None:   (cloudy_path, clean_path, top, left)
        self.samples = []

        cloudy_pattern    = r"(.+?)_(\d+)_(\d+)\.tif"       # matches n_k (cloud)
        cloudless_pattern = r"(.+?)_(\d+)\.tif"            # matches n (clean)

        for tile_name in sorted(os.listdir(route)):
            tile_path = os.path.join(route, tile_name)
            cloud_dir = os.path.join(tile_path, "cloud")
            clean_dir = os.path.join(tile_path, "cloudless")

            if not (os.path.isdir(cloud_dir) and os.path.isdir(clean_dir)):
                continue

            # 1 — Parse cloudy files grouped by time index n
            cloudy_by_n = {}
            for fname in sorted(os.listdir(cloud_dir)):
                if not fname.endswith(".tif"):
                    continue
                m = re.match(cloudy_pattern, fname)
                if not m:
                    continue

                n = int(m.group(2))
                path = os.path.join(cloud_dir, fname)
                cloudy_by_n.setdefault(n, []).append(path)

            # 2 — Parse cloudless (clean) files by time index n
            clean_by_n = {}
            for fname in sorted(os.listdir(clean_dir)):
                if not fname.endswith(".tif"):
                    continue
                m = re.match(cloudless_pattern, fname)
                if not m:
                    continue

                n = int(m.group(2))
                clean_by_n[n] = os.path.join(clean_dir, fname)

            # 3 — For each n, pick *one* cloudy and match with clean[n]
            for n in sorted(cloudy_by_n.keys()):
                if n not in clean_by_n:
                    print(f"[WARNING] Tile {tile_name}: time {n} has cloudy but no clean.")
                    continue

                cloudy_path = cloudy_by_n[n][0]   # mono-temporal choose first
                clean_path = clean_by_n[n]

                # If stride is not provided → old behavior: 1 sample per image
                if self.stride is None:
                    self.samples.append((cloudy_path, clean_path, None, None))
                    continue

                # Otherwise enumerate patches using stride
                tmp = tiff.imread(cloudy_path)  # (H, W, C)
                H, W, _ = tmp.shape

                ps = self.patch_size
                st = self.stride

                for top in range(0, H - ps + 1, st):
                    for left in range(0, W - ps + 1, st):
                        self.samples.append((cloudy_path, clean_path, top, left))

        print(f"[Sen2MTC loaded] Total samples (including patches): {len(self.samples)}")

    def __len__(self):
        return len(self.samples)

    def load_tif(self, path):
        arr = tiff.imread(path)  # (H, W, C)
        arr = np.array(arr, dtype=np.float32)
        return arr

    # Patch extraction helper
    def extract_patch(self, img, size):
        _, H, W = img.shape
        if size > H or size > W:
            raise ValueError(f"Patch size {size} > image size {(H,W)}")

        if self.center_crop:
            top = (H - size) // 2
            left = (W - size) // 2
        else:
            top = np.random.randint(0, H - size + 1)
            left = np.random.randint(0, W - size + 1)

        patch = img[:, top:top+size, left:left+size]
        return patch

    def __getitem__(self, idx):
        cloudy_path, clean_path, top, left = self.samples[idx]

        cloudy = self.load_tif(cloudy_path)
        clean  = self.load_tif(clean_path)

        cloudy = torch.from_numpy(cloudy.transpose(2,0,1))
        clean  = torch.from_numpy(clean.transpose(2,0,1))

        # Patch extraction
        if self.patch_size is not None:
            if self.stride is not None:
                # predetermined patch from (top, left)
                ps = self.patch_size
                cloudy = cloudy[:, top:top+ps, left:left+ps]
                clean  = clean[:, top:top+ps, left:left+ps]
            else:
                # center or random crop
                cloudy = self.extract_patch(cloudy, self.patch_size)
                clean  = self.extract_patch(clean, self.patch_size)

        sample = {
            "cloudy": cloudy,
            "clean": clean,
            "cloudy_path": cloudy_path,
            "clean_path": clean_path
        }

        if self.transform:
            sample = self.transform(sample)

        return sample


In [5]:
def compute_global_stats(dataset, batch_size=32):
    """
    Compute global per-channel mean and std over a dataset.
    Returns:
        mean, std
    """
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)

    channel_sum = None
    channel_sq_sum = None
    total_pixels = 0

    for batch in loader:
        x = batch["cloud"].double()   # (B,C,H,W)
        B, C, H, W = x.shape

        if channel_sum is None:
            channel_sum = torch.zeros(C, dtype=torch.float64)
            channel_sq_sum = torch.zeros(C, dtype=torch.float64)

        # Sum over batch and spatial dims
        channel_sum += x.sum(dim=[0, 2, 3])
        channel_sq_sum += (x ** 2).sum(dim=[0, 2, 3])

        total_pixels += B * H * W

    mean = channel_sum / total_pixels
    std = torch.sqrt(channel_sq_sum / total_pixels - mean**2)

    print("Global mean:", mean)
    print("Global std:", std)

    return mean.float(), std.float()

class Normalization_1v1:
    """
    Normalize cloudy and clean patches using precomputed mean/std.
    """

    def __init__(self, mean, std):
        self.mean = mean.reshape(-1, 1, 1)   # (C,1,1)
        self.std  = std.reshape(-1, 1, 1)    # (C,1,1)

    def __call__(self, sample):
        cloudy = sample["cloudy"]
        clean  = sample["clean"]

        cloudy_n = (cloudy - self.mean) / (self.std + 1e-6)
        clean_n  = (clean  - self.mean) / (self.std + 1e-6)

        return {
            "cloudy": cloudy_n,
            "clean": clean_n,
            "cloudy_path": sample["cloudy_path"],
            "clean_path": sample["clean_path"],
        }

In [2]:
class SatImage_Dataloader_3v1(Dataset):
    def __init__(self, route, patch_size=128, stride=128,
                 center_crop=False, transform=None):

        super().__init__()
        self.root_dir = route
        self.patch_size = patch_size
        self.center_crop = center_crop
        self.transform = transform
        self.stride = stride
        self.samples = []

        cloudy_pattern = r"(.+?)_(\d+)_(\d+)\.tif"
        clean_pattern  = r"(.+?)_(\d+)\.tif"

        for tile_name in sorted(os.listdir(route)):
            tile_path = os.path.join(route, tile_name)
            cloud_dir = os.path.join(tile_path, "cloud")
            clean_dir = os.path.join(tile_path, "cloudless")

            if not (os.path.isdir(cloud_dir) and os.path.isdir(clean_dir)):
                continue

            cloudy_by_n = {}
            for fname in sorted(os.listdir(cloud_dir)):
                if fname.endswith(".tif"):
                    m = re.match(cloudy_pattern, fname)
                    if m:
                        n = int(m.group(2))
                        cloudy_by_n.setdefault(n, []).append(
                            os.path.join(cloud_dir, fname)
                        )

            clean_by_n = {}
            for fname in sorted(os.listdir(clean_dir)):
                if fname.endswith(".tif"):
                    m = re.match(clean_pattern, fname)
                    if m:
                        n = int(m.group(2))
                        clean_by_n[n] = os.path.join(clean_dir, fname)

            for n in sorted(cloudy_by_n.keys()):
                if n not in clean_by_n: 
                    continue

                cloudy_paths = sorted(cloudy_by_n[n])
                if len(cloudy_paths) < 3:
                    continue
                if len(cloudy_paths) > 3:
                    cloudy_paths = cloudy_paths[:3]

                clean_path = clean_by_n[n]

                # patch enumeration
                if self.stride is None:
                    self.samples.append((cloudy_paths, clean_path, None, None))
                else:
                    tmp = tiff.imread(cloudy_paths[0])
                    H, W, _ = tmp.shape

                    ps = self.patch_size
                    st = self.stride

                    for top in range(0, H-ps+1, st):
                        for left in range(0, W-ps+1, st):
                            self.samples.append((cloudy_paths, clean_path, top, left))

        print(f"[Loaded minimal MT dataset] {len(self.samples)} samples")

    def __len__(self):
        return len(self.samples)

    def load_tif(self, path):
        arr = tiff.imread(path).astype(np.float32)
        return arr  # (H,W,C)

    def __getitem__(self, idx):
        cloudy_paths, clean_path, top, left = self.samples[idx]

        # clean
        clean_np = self.load_tif(clean_path).transpose(2,0,1)
        clean = torch.from_numpy(clean_np)  # (C,H,W)

        # cloudy sequence
        seq_list = []
        for p in cloudy_paths:
            c_np = self.load_tif(p).transpose(2,0,1)
            seq_list.append(torch.from_numpy(c_np))
        seq = torch.stack(seq_list, dim=0)  # (T,C,H,W)

        # patching
        if self.patch_size is not None:
            ps = self.patch_size

            if self.stride is not None:
                clean = clean[:, top:top+ps, left:left+ps]
                seq   = seq[:, :, top:top+ps, left:left+ps]
            else:
                _, H, W = clean.shape
                if self.center_crop:
                    top = (H-ps)//2
                    left = (W-ps)//2
                else:
                    top = np.random.randint(0, H-ps+1)
                    left = np.random.randint(0, W-ps+1)
                clean = clean[:, top:top+ps, left:left+ps]
                seq   = seq[:, :, top:top+ps, left:left+ps]

        sample = {"clean": clean, "cloudy_seq": seq}

        if self.transform:
            sample = self.transform(sample)

        return sample

class Normalization_3v1:
    def __init__(self, mean, std):
        self.mean = mean.reshape(-1,1,1)
        self.std  = std.reshape(-1,1,1)

    def __call__(self, sample):
        clean = sample["clean"]                # (C,H,W)
        seq   = sample["cloudy_seq"]           # (T,C,H,W)

        clean_n = (clean - self.mean) / (self.std + 1e-6)
        seq_n   = (seq   - self.mean) / (self.std + 1e-6)

        return {
            "clean": clean_n,
            "cloudy_seq": seq_n
        }

def compute_global_stats_3v1(root):
    """
    Compute global mean/std over ALL raw TIFF images inside:
        <tile>/cloud/*.tif
        <tile>/cloudless/*.tif
    
    Returns torch tensors:  mean (C,), std (C,)
    """

    sum_channels = None
    sum_sq_channels = None
    count_pixels = 0

    # list all tiles
    tile_names = sorted(os.listdir(root))

    all_images = []

    for tile in tile_names:
        tile_dir = os.path.join(root, tile)
        cloud_dir = os.path.join(tile_dir, "cloud")
        clean_dir = os.path.join(tile_dir, "cloudless")

        if not (os.path.isdir(cloud_dir) and os.path.isdir(clean_dir)):
            continue

        # collect all tif paths
        for subdir in [cloud_dir, clean_dir]:
            for fname in os.listdir(subdir):
                if fname.endswith(".tif"):
                    all_images.append(os.path.join(subdir, fname))

    print(f"[compute_global_stats] Found {len(all_images)} images.")

    # iterate through images
    for img_path in tqdm(all_images, desc="Computing global mean/std"):
        arr = tiff.imread(img_path).astype(np.float32)   # (H,W,C)
        C = arr.shape[2]

        if sum_channels is None:
            sum_channels = np.zeros(C, dtype=np.float64)
            sum_sq_channels = np.zeros(C, dtype=np.float64)

        # flatten H,W dims
        flat = arr.reshape(-1, C)   # (H*W, C)
        sum_channels += flat.sum(axis=0)
        sum_sq_channels += (flat ** 2).sum(axis=0)
        count_pixels += flat.shape[0]

    # compute mean/std
    mean = sum_channels / count_pixels
    var = (sum_sq_channels / count_pixels) - (mean ** 2)
    std = np.sqrt(var)

    mean = torch.tensor(mean, dtype=torch.float32)
    std  = torch.tensor(std, dtype=torch.float32)

    print("Global mean:", mean)
    print("Global std :", std)

    return mean, std


In [3]:
route = "./Sen2_MTC/dataset/Sen2_MTC"
ps = 128
st = 128
batch_size = 16

dataset_raw = SatImage_Dataloader_3v1(
    route=route,
    patch_size=ps,
    stride=st,
    transform=None
)

mean, std = compute_global_stats_3v1(route)

dataset_norm = SatImage_Dataloader_3v1(
    route=route,
    patch_size=ps,
    stride=st,
    transform=Normalization_3v1(mean, std)
)

train_ratio = 0.7
val_ratio   = 0.15
test_ratio  = 0.15

total       = len(dataset_norm)
train_len   = int(total * train_ratio)
val_len     = int(total * val_ratio)
test_len    = total - train_len - val_len

train_set, val_set, test_set = torch.utils.data.random_split(
    dataset_norm,
    [train_len, val_len, test_len],
    generator=torch.Generator().manual_seed(42)
)

def pack_split(split):
    """Convert a Subset into a list of {cloudy_seq, clean} dicts."""
    out = []
    for i in range(len(split)):
        item = split[i]
        out.append({
            "cloudy_seq": item["cloudy_seq"],
            "clean": item["clean"]
        })
    return out


data = {
    "train": pack_split(train_set),
    "val":   pack_split(val_set),
    "test":  pack_split(test_set)
}

torch.save(data, "./Ckpts/Sen2MTC_FULL_3v1_norm.pt")
print("Saved → Sen2MTC_all.pt")


[Loaded minimal MT dataset] 13668 samples
[compute_global_stats] Found 13669 images.


Computing global mean/std: 100%|██████████| 13669/13669 [01:07<00:00, 202.09it/s]


Global mean: tensor([1711.6575, 1720.1365, 1585.7429, 3117.7529])
Global std : tensor([1862.0779, 1880.4310, 2019.5642, 1677.4390])
[Loaded minimal MT dataset] 13668 samples
Saved → Sen2MTC_all.pt


In [3]:
data = torch.load("./Ckpts/Sen2MTC_Mini_3v1_norm.pt", map_location="cpu")

In [4]:
class PrecomputedSen2MTC(Dataset):
    def __init__(self, data_list):
        # data_list is a list of dicts:
        # [{"cloudy_seq": tensor, "clean": tensor}, ...]
        self.data = data_list

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        return {
            "cloudy_seq": item["cloudy_seq"],
            "clean": item["clean"]
        }

def collate_mt(batch):
    # batch: list of {"cloudy_seq":..., "clean":...}

    cloudy = [b["cloudy_seq"] for b in batch]   # list of (T,C,H,W)
    clean  = [b["clean"]      for b in batch]   # list of (C,H,W)

    cloudy = torch.stack(cloudy, dim=0)    # (B,T,C,H,W)
    clean  = torch.stack(clean,  dim=0)    # (B,C,H,W)

    return {
        "cloudy_seq": cloudy,
        "clean": clean
    }


In [5]:
train_set = PrecomputedSen2MTC(data["train"])
val_set   = PrecomputedSen2MTC(data["val"])
test_set  = PrecomputedSen2MTC(data["test"])

train_loader = DataLoader(train_set, batch_size=16,
                          shuffle=True, collate_fn=collate_mt)

val_loader   = DataLoader(val_set,   batch_size=16,
                          shuffle=False, collate_fn=collate_mt)

test_loader  = DataLoader(test_set,  batch_size=16,
                          shuffle=False, collate_fn=collate_mt)


In [6]:
batch = next(iter(train_loader))

print(batch["cloudy_seq"].shape)   # (B,T,C,H,W)
print(batch["clean"].shape)        # (B,C,H,W)


torch.Size([16, 3, 4, 128, 128])
torch.Size([16, 4, 128, 128])


In [7]:
print(train_set[0])

{'cloudy_seq': tensor([[[[ 1.7398e+00,  1.5795e+00,  1.3616e+00,  ...,  1.6607e+00,
            1.5304e+00,  1.3295e+00],
          [ 1.6714e+00,  1.5710e+00,  1.4855e+00,  ...,  1.6223e+00,
            1.4812e+00,  1.3413e+00],
          [ 1.5432e+00,  1.5005e+00,  1.5689e+00,  ...,  1.6479e+00,
            1.5176e+00,  1.3680e+00],
          ...,
          [-3.6172e-01, -3.8950e-01, -3.5264e-01,  ...,  3.0068e-01,
            3.3166e-01,  3.4555e-01],
          [-3.2647e-01, -3.0029e-01, -2.2817e-01,  ...,  2.4833e-01,
            2.7076e-01,  3.4555e-01],
          [-2.7412e-01, -2.2924e-01, -1.8544e-01,  ...,  2.7183e-01,
            3.1670e-01,  4.1927e-01]],

         [[ 1.8813e+00,  1.8304e+00,  1.6671e+00,  ...,  1.8049e+00,
            1.6056e+00,  1.3722e+00],
          [ 1.8198e+00,  1.6628e+00,  1.5080e+00,  ...,  1.8071e+00,
            1.5737e+00,  1.3128e+00],
          [ 1.7286e+00,  1.5037e+00,  1.4040e+00,  ...,  1.7498e+00,
            1.6034e+00,  1.3552e+00],
     