In [2]:
import os
import re
import torch
import numpy as np
import tifffile as tiff
from torch.utils.data import Dataset, random_split, DataLoader
from torchvision.transforms.v2.functional import normalize

In [3]:
class SatImage_Dataloader(Dataset):
    """
    Mono-temporal Sen2-MTC dataset with patch extraction.

    For each time index n:
        cloudy[n] has multiple *.tif files
        cloudless[n] has one  *.tif file

    One cloudy sample for each n -> mono-temporal.
    If patch_size is provided:
        If center_crop True : return crop from the center of image.
        If center_crop False: return a random crop on the image.

    If stride is provided:
        Enumerate all patches of size patch_size with given stride.
        (This multiplies dataset size with deterministic patches.)
    """

    def __init__(self, route, patch_size=128, stride=128, center_crop=False, transform=None):
        super().__init__()
        self.root_dir = route
        self.patch_size = patch_size
        self.center_crop = center_crop
        self.transform = transform
        self.stride = stride  # NEW

        # List of samples:
        # If stride is None:       (cloudy_path, clean_path)
        # If stride is not None:   (cloudy_path, clean_path, top, left)
        self.samples = []

        cloudy_pattern    = r"(.+?)_(\d+)_(\d+)\.tif"       # matches n_k (cloud)
        cloudless_pattern = r"(.+?)_(\d+)\.tif"            # matches n (clean)

        for tile_name in sorted(os.listdir(route)):
            tile_path = os.path.join(route, tile_name)
            cloud_dir = os.path.join(tile_path, "cloud")
            clean_dir = os.path.join(tile_path, "cloudless")

            if not (os.path.isdir(cloud_dir) and os.path.isdir(clean_dir)):
                continue

            # 1 — Parse cloudy files grouped by time index n
            cloudy_by_n = {}
            for fname in sorted(os.listdir(cloud_dir)):
                if not fname.endswith(".tif"):
                    continue
                m = re.match(cloudy_pattern, fname)
                if not m:
                    continue

                n = int(m.group(2))
                path = os.path.join(cloud_dir, fname)
                cloudy_by_n.setdefault(n, []).append(path)

            # 2 — Parse cloudless (clean) files by time index n
            clean_by_n = {}
            for fname in sorted(os.listdir(clean_dir)):
                if not fname.endswith(".tif"):
                    continue
                m = re.match(cloudless_pattern, fname)
                if not m:
                    continue

                n = int(m.group(2))
                clean_by_n[n] = os.path.join(clean_dir, fname)

            # 3 — For each n, pick *one* cloudy and match with clean[n]
            for n in sorted(cloudy_by_n.keys()):
                if n not in clean_by_n:
                    print(f"[WARNING] Tile {tile_name}: time {n} has cloudy but no clean.")
                    continue

                cloudy_path = cloudy_by_n[n][0]   # mono-temporal choose first
                clean_path = clean_by_n[n]

                # If stride is not provided → old behavior: 1 sample per image
                if self.stride is None:
                    self.samples.append((cloudy_path, clean_path, None, None))
                    continue

                # Otherwise enumerate patches using stride
                # We must load image shape
                tmp = tiff.imread(cloudy_path)  # (H, W, C)
                H, W, _ = tmp.shape

                ps = self.patch_size
                st = self.stride

                for top in range(0, H - ps + 1, st):
                    for left in range(0, W - ps + 1, st):
                        self.samples.append((cloudy_path, clean_path, top, left))

        print(f"[Sen2MTC loaded] Total samples (including patches): {len(self.samples)}")

    def __len__(self):
        return len(self.samples)

    def load_tif(self, path):
        arr = tiff.imread(path)  # (H, W, C)
        arr = np.array(arr, dtype=np.float32)
        return arr

    # Patch extraction helper
    def extract_patch(self, img, size):
        """
        img: numpy array shape (C,H,W)
        size: int, patch size
        returns: (C, size, size)
        """
        _, H, W = img.shape
        if size > H or size > W:
            raise ValueError(f"Patch size {size} > image size {(H,W)}")

        if self.center_crop:
            top = (H - size) // 2
            left = (W - size) // 2
        else:
            top = np.random.randint(0, H - size + 1)
            left = np.random.randint(0, W - size + 1)

        patch = img[:, top:top+size, left:left+size]
        return patch

    def __getitem__(self, idx):
        cloudy_path, clean_path, top, left = self.samples[idx]

        cloudy = self.load_tif(cloudy_path)
        clean  = self.load_tif(clean_path)

        cloudy = torch.from_numpy(cloudy.transpose(2,0,1))
        clean  = torch.from_numpy(clean.transpose(2,0,1))

        # Patch extraction
        if self.patch_size is not None:
            if self.stride is not None:
                # predetermined patch from (top, left)
                ps = self.patch_size
                cloudy = cloudy[:, top:top+ps, left:left+ps]
                clean  = clean[:, top:top+ps, left:left+ps]
            else:
                # center or random crop
                cloudy = self.extract_patch(cloudy, self.patch_size)
                clean  = self.extract_patch(clean, self.patch_size)

        sample = {
            "cloudy": cloudy,
            "clean": clean,
            "cloudy_path": cloudy_path,
            "clean_path": clean_path
        }

        if self.transform:
            sample = self.transform(sample)

        return sample


In [16]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
route = 'D:\Desktop\Duke 2025 Fall\CS 372\Final Project\372 Data\CTGAN\Sen2_MTC\dataset\Sen2_MTC'    # route to data
size = 128              # if size = patch_size = stride -> no overlapping sampling
batch_size = 16

In [17]:
def compute_global_stats(dataset, batch_size=32):
    """
    Compute global per-channel mean and std over a dataset.
    Returns:
        mean, std
    """
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)

    channel_sum = None
    channel_sq_sum = None
    total_pixels = 0

    for batch in loader:
        x = batch["cloudy"].double()   # (B,C,H,W)
        B, C, H, W = x.shape

        if channel_sum is None:
            channel_sum = torch.zeros(C, dtype=torch.float64)
            channel_sq_sum = torch.zeros(C, dtype=torch.float64)

        # Sum over batch and spatial dims
        channel_sum += x.sum(dim=[0, 2, 3])
        channel_sq_sum += (x ** 2).sum(dim=[0, 2, 3])

        total_pixels += B * H * W

    mean = channel_sum / total_pixels
    std = torch.sqrt(channel_sq_sum / total_pixels - mean**2)

    print("Global mean:", mean)
    print("Global std:", std)

    return mean.float(), std.float()

In [18]:
class Normalization:
    """
    Normalize cloudy and clean patches using precomputed mean/std.
    """

    def __init__(self, mean, std):
        self.mean = mean.reshape(-1, 1, 1)   # (C,1,1)
        self.std  = std.reshape(-1, 1, 1)    # (C,1,1)

    def __call__(self, sample):
        cloudy = sample["cloudy"]
        clean  = sample["clean"]

        cloudy_n = (cloudy - self.mean) / (self.std + 1e-6)
        clean_n  = (clean  - self.mean) / (self.std + 1e-6)

        return {
            "cloudy": cloudy_n,
            "clean": clean_n,
            "cloudy_path": sample["cloudy_path"],
            "clean_path": sample["clean_path"],
        }

In [19]:
dataset_raw = SatImage_Dataloader(route=route, patch_size=size, stride=size)
mean, std = compute_global_stats(dataset_raw, batch_size=batch_size)
normalized = Normalization(mean, std)

dataset = SatImage_Dataloader(route=route, patch_size=size, stride=size, transform=normalized)
train_ratio=.7; val_ratio=.15; test_ratio=.15

total       = len(dataset)
train_len   = int(train_ratio*total)
val_len     = int(val_ratio*total)+1
test_len    = int(test_ratio*total)

generator = torch.Generator().manual_seed(2025)
train_set, val_set, test_set = random_split(dataset, 
                                            [train_len, val_len, test_len],
                                            generator=generator) #reproducability

train_loader    = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0)
val_loader      = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=0)
test_loader     = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=0)
#expect ~40s run

FileNotFoundError: [WinError 3] 系统找不到指定的路径。: 'D:\\Desktop\\Duke 2025 Fall\\CS 372\\Final Projectú Data\\CTGAN\\Sen2_MTC\\dataset\\Sen2_MTC'

In [22]:
route = "D:/Desktop/Duke 2025 Fall/CS 372/Final Project Data/CTGAN/Sen2_MTC/dataset/Sen2_MTC"

print(repr(route))
print(os.path.exists(route))

'D:/Desktop/Duke 2025 Fall/CS 372/Final Project Data/CTGAN/Sen2_MTC/dataset/Sen2_MTC'
False


In [7]:
for batch in train_loader:
    print(batch['cloudy'].shape)
    print(batch['clean'].shape)
    x = batch['cloudy']
    print(x.mean(), x.std())
    break
#[batch_size, channels, height, width] -> [batch_size, channels, patch_size, patch_size]

torch.Size([16, 4, 128, 128])
torch.Size([16, 4, 128, 128])
tensor(-0.0600) tensor(0.9697)


In [8]:
import torch.nn as nn
import torch.nn.functional as F

# structure ok
class ResidualBlock(nn.Module):
    """
    Small residual conv block:
    in -> Conv -> GN -> SiLU -> Conv -> GN -> +skip
    """
    def __init__(self, in_channels, out_channels, num_groups=8):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.gn1   = nn.GroupNorm(num_groups=min(num_groups, out_channels), num_channels=out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.gn2   = nn.GroupNorm(num_groups=min(num_groups, out_channels), num_channels=out_channels)
        self.silu   = nn.SiLU(inplace=True)

        # if channel dims change, use 1*1 conv for skip
        if in_channels != out_channels:
            self.skip = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        else:
            self.skip = nn.Identity()

    def forward(self, x):
        identity = self.skip(x)
        out = self.conv1(x)
        out = self.gn1(out)
        out = self.silu(out)
        out = self.conv2(out)
        out = self.gn2(out)
        out = out + identity
        out = self.silu(out)
        return out

class CloudEncoder(nn.Module):
    """
    Cloud encoder based on CNN.

    Input:
        x: (B, in_channels, H, W), normalize required

    Output:
        z: (B, latent_dim) cloud embeddings

    Structure:
        - conv
        - #num_stages residual + downsample stages
        - Final residual block
        - Global average pooling
        - Linear projection to latent_dim
    """
    def __init__(
        self,
        in_channels = 4,
        base_channels = 32,
        num_stages = 3,
        latent_dim = 128,
        num_groups = 8,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.base_channels = base_channels
        self.num_stages = num_stages
        self.latent_dim = latent_dim

        layers = []

        # Initial conv to get to base_channels
        layers.append(
            nn.Sequential(
                nn.Conv2d(in_channels, base_channels, kernel_size=3, padding=1),
                nn.GroupNorm(num_groups=min(num_groups, base_channels), num_channels=base_channels),
                nn.SiLU(inplace=True),
            )
        )

        in_ch = base_channels
        channels = [base_channels * (2 ** i) for i in range(num_stages)]

        # Residual + downsample stages
        self.down_blocks = nn.ModuleList()
        self.downsamples = nn.ModuleList()

        for out_ch in channels:
            self.down_blocks.append(ResidualBlock(in_ch, out_ch, num_groups=num_groups))
            # stride-2 conv for downsampling
            self.downsamples.append(
                nn.Conv2d(out_ch, out_ch, kernel_size=4, stride=2, padding=1)
            )
            in_ch = out_ch

        # A final residual block at lowest resolution
        self.final_block = ResidualBlock(in_ch, in_ch, num_groups=num_groups)

        # Register first stem as single module for clarity
        self.stem = layers[0]
        
        # Projection to latent space
        self.proj = nn.Linear(in_ch, latent_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (B, C, H, W)
        returns:
            z: (B, latent_dim)
        """
        # Stem
        h = self.stem(x)  # (B, base_channels, H, W)

        # Downsampling stages
        for block, down in zip(self.down_blocks, self.downsamples):
            h = block(h)
            h = down(h)  # spatial size halves each time

        # Final block
        h = self.final_block(h)

        # Global average pooling over spatial dims: (B, C, 1, 1) -> (B, C)
        h = F.adaptive_avg_pool2d(h, output_size=1).squeeze(-1).squeeze(-1)

        # Project to latent_dim
        z = self.proj(h)  # (B, latent_dim)

        return z

In [9]:
def make_noise_schedule(T=500, sigma_min=1e-4, sigma_max=.02):
    sigmas = torch.linspace(sigma_min, sigma_max, T)
    return sigmas

def sample_timesteps(batch_size, T, device):
    out = torch.randint(low=1, high=T, size=(batch_size,)).to(device)   
    return out

def forward_diffusion(clean, cloudy, t, sigmas):
    """
    Mean-Reverting Diffusion (MRDM) forward process:
        x_t = cloudy + sigma_t * eps
        cloudy as mean, drive noise schedule to approximate cloudy mean

    Args:
        clean:   (B, C, H, W) clean images (only needed for generating eps)
        cloudy:  (B, C, H, W) cloudy images (used as drift center μ_t)
        t:       (B,) integer timesteps
        sigmas:  (T,) schedule of σ_t

    Returns:
        x_t:   noisy image
        eps:   noise used in generation (target for diffusion loss)
        mu_t:  drift mean (cloudy)
    """

    B = clean.shape[0]
    # Sample Gaussian noise
    eps = torch.randn_like(clean)

    sigma_t = sigmas[t].view(B, 1, 1, 1)
    mu_t = cloudy
    x_t = mu_t + sigma_t * eps

    return x_t, eps, mu_t

class ForwardDiffusion(nn.Module):
    """
    Wrapper.

    Usage:
        forwarder = ForwardDiffusionMRDM(T=1000)
        x_t, eps, mu_t = forwarder(clean, cloudy, t)
    """
    def __init__(self, T=1000, sigma_min=0.0001, sigma_max=0.02):
        super().__init__()
        self.T = T

        # register sigmas as buffer so they move with .to(device)
        sigmas = make_noise_schedule(T=T, sigma_min=sigma_min, sigma_max=sigma_max)
        self.register_buffer("sigmas", sigmas)

    def sample_t(self, batch_size, device):
        return sample_timesteps(batch_size, self.T, device)

    def forward(self, clean, cloudy, t):
        return forward_diffusion(clean, cloudy, t, self.sigmas)

In [10]:
from diffusers import UNet2DModel

class CloudConditionedUNet(nn.Module):
    """
    Wrapper around Hugging Face UNet2DModel with:
      - in_channels = 4, out_channels = 4
      - conditioning on cloud embedding z_cloud via a per-channel bias
      - internal timestep embedding (no t_emb needed externally)

    Forward signature:
        eps_pred = model(x_t, t, z_cloud)
    """

    def __init__(
        self,
        base_model_name: str = "google/ddpm-cifar10-32",
        in_channels: int = 4,
        out_channels: int = 4,
        latent_dim: int = 128,
    ):
        super().__init__()

        # 1. Load pretrained UNet (3->3 by default)
        base_unet = UNet2DModel.from_pretrained(base_model_name)

        # 2. Copy config and modify channels
        config = base_unet.config
        config.in_channels = in_channels
        config.out_channels = out_channels

        # 3. Rebuild UNet with same architecture, but new channel counts
        self.unet = UNet2DModel.from_config(config)

        # 4. Optionally, copy overlapping weights for conv_in/conv_out (3 -> 4 channels)
        with torch.no_grad():
            # conv_in: (C_out, C_in, k, k)
            old_w = base_unet.conv_in.weight    # (C_out, 3, k, k)
            new_w = self.unet.conv_in.weight    # (C_out, 4, k, k)
            new_w[:, :3, :, :] = old_w
            # last channel (index 3) stays random-initialized

            self.unet.conv_in.bias.data.copy_(base_unet.conv_in.bias.data)

            # conv_out: (4, C_in, k, k) vs old (3, C_in, k, k)
            old_w_out = base_unet.conv_out.weight   # (3, C_in, k, k)
            new_w_out = self.unet.conv_out.weight   # (4, C_in, k, k)
            new_w_out[:3, :, :, :] = old_w_out
            # 4th channel random-init
            self.unet.conv_out.bias.data[:3] = base_unet.conv_out.bias.data

        # 5. Conditioning MLP: z_cloud -> per-channel bias on input
        self.cond_proj = nn.Linear(latent_dim, in_channels)

    def forward(self, x_t: torch.Tensor, t: torch.Tensor, z_cloud: torch.Tensor):
        """
        x_t:      (B, 4, H, W)  noisy sample at timestep t
        t:        (B,) or scalar int/float timestep; UNet2DModel embeds this internally
        z_cloud:  (B, latent_dim) from your CloudEncoder

        Returns:
            eps_pred: (B, 4, H, W) predicted noise
        """

        # project cloud embedding to per-channel bias
        cond = self.cond_proj(z_cloud)         # (B, 4)
        cond = cond.unsqueeze(-1).unsqueeze(-1)  # (B, 4, 1, 1)

        # inject conditioning as bias on input channels
        x_cond = x_t + cond

        # HuggingFace UNet2DModel handles timestep embedding internally
        out = self.unet(sample=x_cond, timestep=t)

        # out is a UNet2DOutput; main tensor is .sample
        eps_pred = out.sample
        return eps_pred
    


  from .autonotebook import tqdm as notebook_tqdm


In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F


# ---------------------------------------------------------------------
# 1. Simple sinusoidal timestep embedding
# ---------------------------------------------------------------------
def timestep_embedding(t, dim):
    half = dim // 2
    freqs = torch.exp(
        torch.arange(half, dtype=torch.float32, device=t.device)
        * (-torch.log(torch.tensor(10000.0)) / (half - 1))
    )
    args = t[:, None].float() * freqs[None]
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
    return emb


# ---------------------------------------------------------------------
# 2. Basic ResNet block (Conv-Norm-Act-Conv-Norm-Act)
# ---------------------------------------------------------------------
class ResBlock(nn.Module):
    def __init__(self, in_ch, out_ch, time_emb_dim):
        super().__init__()
        self.time_dense = nn.Linear(time_emb_dim, out_ch)

        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.norm1 = nn.GroupNorm(8, out_ch)
        self.norm2 = nn.GroupNorm(8, out_ch)
        self.act = nn.SiLU()

        # shortcut if needed
        self.shortcut = (
            nn.Conv2d(in_ch, out_ch, 1)
            if in_ch != out_ch
            else nn.Identity()
        )

    def forward(self, x, t_emb):
        """
        x:     (B, C, H, W)
        t_emb: (B, time_emb_dim)
        """
        h = self.conv1(x)
        h = self.norm1(h)
        h = self.act(h)

        # add time embedding
        t_added = self.time_dense(t_emb)[:, :, None, None]
        h = h + t_added

        h = self.conv2(h)
        h = self.norm2(h)
        h = self.act(h)

        return h + self.shortcut(x)


# ---------------------------------------------------------------------
# 3. Minimal UNet for 4 channel diffusion
# ---------------------------------------------------------------------
class SmallUNet4C(nn.Module):
    def __init__(
        self,
        in_channels=4,
        out_channels=4,
        base_channels=64,
        time_emb_dim=256,
    ):
        super().__init__()

        # time embedding MLP
        self.time_mlp = nn.Sequential(
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.SiLU(),
            nn.Linear(time_emb_dim, time_emb_dim),
        )

        # ---------- Downsampling ----------
        self.conv_in = nn.Conv2d(in_channels, base_channels, 3, padding=1)

        self.down1 = ResBlock(base_channels, base_channels * 2, time_emb_dim)
        self.down2 = ResBlock(base_channels * 2, base_channels * 4, time_emb_dim)

        self.pool = nn.MaxPool2d(2)

        # ---------- Middle ----------
        self.mid = ResBlock(base_channels * 4, base_channels * 4, time_emb_dim)

        # ---------- Upsampling ----------
        self.up1 = ResBlock(base_channels * 4 + base_channels * 4, base_channels * 2, time_emb_dim)
        self.up2 = ResBlock(base_channels * 2 + base_channels * 2, base_channels, time_emb_dim)

        self.conv_out = nn.Conv2d(base_channels, out_channels, 3, padding=1)

    def forward(self, x, t):
        # t: integer or float timestep
        if len(t.shape) == 0:
            t = t.unsqueeze(0)
        t_emb = timestep_embedding(t, self.time_mlp[0].in_features)
        t_emb = self.time_mlp(t_emb)

        # Down
        x1 = self.conv_in(x)            # (B, 64, H, W)
        x2 = self.down1(x1, t_emb)      # (B, 128, H, W)
        x3 = self.pool(x2)
        x3 = self.down2(x3, t_emb)      # (B, 256, H/2, W/2)

        # Middle
        xm = self.mid(x3, t_emb)

        # Up
        u1 = torch.cat([xm, x3], dim=1)
        u1 = self.up1(u1, t_emb)        # (B, 128, H/2, W/2)

        u2 = F.interpolate(u1, scale_factor=2, mode="nearest")
        u2 = torch.cat([u2, x2], dim=1)
        u2 = self.up2(u2, t_emb)        # (B, 64, H, W)

        out = self.conv_out(u2)
        return out

class CloudConditionedUNet_4C(nn.Module):
    def __init__(self, in_channels=4, out_channels=4, latent_dim=128):
        super().__init__()

        self.unet = SmallUNet4C(
            in_channels=in_channels,
            out_channels=out_channels,
        )

        self.cond_proj = nn.Linear(latent_dim, in_channels)

    def forward(self, x_t, t, z_cloud):
        cond = self.cond_proj(z_cloud).unsqueeze(-1).unsqueeze(-1)
        x_cond = x_t + cond
        eps_pred = self.unet(x_cond, t)
        return eps_pred


In [27]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

forwarder = ForwardDiffusion(T=500).to(device)
cloud_encoder = CloudEncoder(
    in_channels=4,
    base_channels=32,
    num_stages=2,
    latent_dim=128
).to(device)
"""
denoiser = CloudConditionedUNet(
    base_model_name="google/ddpm-cifar10-32",
    in_channels=4,
    out_channels=4,
    latent_dim=128
).to(device)
"""
denoiser = CloudConditionedUNet_4C(
    in_channels=4,
    out_channels=4,
    latent_dim=128
).to(device)


params = list(cloud_encoder.parameters()) + list(denoiser.parameters())
optimizer = torch.optim.Adam(params, lr=1e-4)
epochs = 1

def forward_trainer(epochs, train_loader, optimizer, cloud_encoder, denoiser, device):
    for epoch in range(epochs):
        epoch_loss = 0.0
        num_batches = 0
        
        for batch in train_loader:
            cloudy = batch['cloudy'].to(device)
            clean  = batch['clean'].to(device)
            B = cloudy.shape[0]

            t = forwarder.sample_t(B, device=device)
            x_t, eps, mu_t = forwarder(clean, cloudy, t)
            z_cloud = cloud_encoder(cloudy)
            eps_pred = denoiser(x_t, t, z_cloud)
        
            loss = torch.mean((eps_pred - eps)**2)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            num_batches += 1

        avg_loss = epoch_loss / num_batches
        print(f"[Epoch {epoch}] loss = {avg_loss:.6f}")

temp_train = DataLoader(test_set, batch_size=32, shuffle=True, num_workers=0)
forward_trainer(epochs, temp_train, optimizer, cloud_encoder, denoiser, device)

batch complete
batch complete
batch complete
batch complete
batch complete
batch complete
batch complete
batch complete
batch complete
batch complete
batch complete
batch complete
batch complete
batch complete
batch complete
batch complete
batch complete


KeyboardInterrupt: 