In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from diffusers import AutoencoderKL, UNet2DModel, DDPMScheduler
#from torchvision.transforms import Compose, ToTensor, Normalize
from tqdm import tqdm
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch
from torch.utils.data import Dataset

class SSTDatasetWithMask(Dataset):
    """Sliding-window SST dataset yielding ``(frames, mask)`` pairs.

    Sample ``i`` is the window ``sst_data[i : i + sequence_length + 1]``:
    ``sequence_length`` conditioning frames followed by one extra frame (the
    frame to predict). The returned mask is the mask at that last frame's
    time index, with a leading singleton dimension added.
    """

    def __init__(self, sst_data, masks, sequence_length=6):
        """
        Args:
            sst_data (torch.Tensor): SST data of shape (time, lat, lon).
            masks (torch.Tensor): Masks of the same shape as sst_data.
            sequence_length (int): Number of frames to condition on.
        """
        self.sst_data = sst_data
        self.masks = masks
        self.sequence_length = sequence_length

    def __len__(self):
        """Return the number of complete windows (0 if the series is too short)."""
        # Clamp at zero: a negative result would make ``len()`` raise ValueError
        # instead of simply reporting an empty dataset.
        return max(0, self.sst_data.size(0) - self.sequence_length)

    def __getitem__(self, idx):
        """Return ``(frames, mask)`` for window ``idx``.

        Returns:
            frames: tensor of shape (sequence_length + 1, lat, lon).
            mask: tensor of shape (1, lat, lon) taken at the last frame's index.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        num_windows = len(self)
        position = int(idx)
        # Support Python-style negative indices; without this a negative start
        # would slice an empty/mismatched window rather than fail.
        if position < 0:
            position += num_windows
        if not 0 <= position < num_windows:
            raise IndexError(
                f"index {idx} is out of range for a dataset of {num_windows} windows"
            )

        last = position + self.sequence_length
        frames = self.sst_data[position:last + 1]
        mask = self.masks[last].unsqueeze(0)
        return frames, mask

In [None]:
class LatentDiffusionTrainer:
    """Trains a UNet to predict the noise added to the latent of a target frame.

    Each batch is a stack of frames whose last frame is the prediction target
    and whose preceding frames are the conditioning context. Both parts are
    encoded with the (frozen, eval-mode) VAE; the noised target latent is
    concatenated with the context latent along the channel axis and passed to
    the UNet, which is optimised with a masked MSE against the added noise.
    """

    def __init__(self, vae, unet, scheduler, optimizer):
        """
        Args:
            vae: autoencoder exposing ``encode(x).latent_dist.sample()``.
            unet: model called as ``unet(x, timesteps).sample``.
            scheduler: noise scheduler exposing ``config.num_train_timesteps``
                and ``add_noise(sample, noise, timesteps)``.
            optimizer: optimizer that updates the UNet parameters.
        """
        self.vae = vae
        self.unet = unet
        self.scheduler = scheduler
        self.optimizer = optimizer

    def train(self, dataloader, num_epochs=5, device="cuda"):
        """Run the training loop.

        Args:
            dataloader: iterable of ``(frames, mask)`` batches, where ``frames``
                is (B, context + 1, H, W) and ``mask`` is (B, 1, H, W).
            num_epochs (int): number of passes over ``dataloader``.
            device: device on which the models and batches are placed.

        Raises:
            ValueError: if the UNet output shape differs from the noise shape
                (e.g. ``out_channels`` does not equal the latent channel count).
        """
        print("Start train", flush=True)
        self.vae.to(device).eval()  # VAE stays in eval mode
        self.unet.to(device).train()

        # Read the value from the scheduler config; direct attribute access on
        # diffusers schedulers is deprecated.
        num_timesteps = self.scheduler.config.num_train_timesteps

        for epoch in range(num_epochs):
            for frames, mask in dataloader:
                frames = frames.to(device)

                # Last frame is the target, everything before it is context.
                context_frames = frames[:, :-1]
                target_frame = frames[:, -1:]

                # The same encoder handles both inputs, so the single target
                # frame is replicated to the channel count of the context stack.
                # assumes the VAE's input channels equal the number of context
                # frames — TODO confirm against the VAE config
                target_input = target_frame.repeat(1, context_frames.size(1), 1, 1)

                # Encode images into latent space
                with torch.no_grad():
                    latent_contexts = self.vae.encode(context_frames).latent_dist.sample()
                    latent_target = self.vae.encode(target_input).latent_dist.sample()

                # Add noise to target latent
                noise = torch.randn_like(latent_target)
                timesteps = torch.randint(
                    0, num_timesteps, (latent_target.size(0),), device=device
                )
                noisy_latent_target = self.scheduler.add_noise(latent_target, noise, timesteps)

                input_val = torch.cat([latent_contexts, noisy_latent_target], dim=1)

                # Predict noise using the UNet
                outputs = self.unet(input_val, timesteps).sample
                if outputs.shape != noise.shape:
                    # Fail loudly instead of letting the loss broadcast silently.
                    raise ValueError(
                        f"UNet output shape {tuple(outputs.shape)} does not match "
                        f"noise shape {tuple(noise.shape)}"
                    )

                # The mask arrives on the CPU at pixel resolution, while the loss
                # lives on `device` at latent resolution: move and resize it.
                mask = mask.to(device=device, dtype=outputs.dtype)
                if mask.shape[-2:] != outputs.shape[-2:]:
                    mask = nn.functional.interpolate(
                        mask, size=outputs.shape[-2:], mode="nearest"
                    )

                # Compute loss
                loss = (nn.functional.mse_loss(outputs, noise, reduction="none") * mask).mean()

                # Backpropagation
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item()}")


In [None]:
def main():
    """Load the SST arrays, build the dataset and models, and run training.

    Reads ``processed_sst_data.npy`` and ``sst_masks.npy`` from the working
    directory, wraps them in an ``SSTDatasetWithMask``, loads the pretrained
    VAE, builds a UNet sized to the VAE's latent channels, and trains it with
    ``LatentDiffusionTrainer``.

    Raises:
        AssertionError: if the data and mask arrays differ in shape.
        ValueError: if the series is too short to form a single window.
    """
    # Load processed SST data and masks
    sst_data = np.load('processed_sst_data.npy')  # Shape: (time, lat, lon)
    masks = np.load('sst_masks.npy')  # Shape: (time, lat, lon)

    # Ensure the data and masks match
    assert sst_data.shape == masks.shape, "Data and masks must have the same shape."

    sst_data_tensor = torch.tensor(sst_data, dtype=torch.float32)
    masks_tensor = torch.tensor(masks, dtype=torch.float32)

    # Dataset and DataLoader
    sequence_length = 3  # Conditioning on the previous 3 frames
    if sst_data_tensor.size(0) <= sequence_length:
        raise ValueError(
            f"Need more than {sequence_length} time steps to build a window, "
            f"got {sst_data_tensor.size(0)}."
        )
    dataset = SSTDatasetWithMask(sst_data_tensor, masks_tensor, sequence_length=sequence_length)

    dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=4)

    # Load modeling components
    vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")
    vae.requires_grad_(False)  # Only the UNet is optimised

    # The trainer concatenates the context latent with the noisy target latent
    # along the channel axis and compares the UNet output with noise of latent
    # shape, so the UNet takes two latents in and emits one latent out.
    latent_channels = vae.config.latent_channels
    unet = UNet2DModel(
        sample_size=64,                   # Latent space resolution
        in_channels=2 * latent_channels,  # Context latent + noisy target latent
        out_channels=latent_channels      # Predicted noise, same shape as the target latent
    )

    # Diffusion Scheduler
    scheduler = DDPMScheduler(num_train_timesteps=1000)

    # Optimizer
    optimizer = torch.optim.AdamW(unet.parameters(), lr=1e-4)

    # Trainer
    trainer = LatentDiffusionTrainer(vae, unet, scheduler, optimizer)

    # Train the model
    trainer.train(dataloader, num_epochs=10, device="cuda" if torch.cuda.is_available() else "cpu")


# Notebook entry point: announce the run, then execute the full pipeline.
print("Start", flush=True)
main()

Start


: 