In [1]:
from types import SimpleNamespace

import wandb
import torch

from cloudcasting.constants import NUM_CHANNELS, DATA_INTERVAL_SPACING_MINUTES

from cloud_diffusion.dataset import CloudcastingDataset
from cloud_diffusion.utils import set_seed
from cloud_diffusion.ddpm import noisify_ddpm, ddim_sampler
from cloud_diffusion.models import UNet2D
from cloud_diffusion.vae import TemporalVAEAdapter


from torch.optim import AdamW
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import OneCycleLR
import torch.nn.functional as F

from fastprogress import progress_bar

import matplotlib.pyplot as plt
import numpy as np

# Labels for the data channels; the plotting helper indexes this list by
# channel position to label each row.
channel_names = [
    'IR_016', 'IR_039', 'IR_087', 'IR_097', 'IR_108', 'IR_120', 'IR_134',
    'VIS006', 'VIS008', 'WV_062', 'WV_073',
]

# Notebook-level switches.
PROJECT_NAME = "nathan-test"
MERGE_CHANNELS = False  # forwarded to CloudcastingDataset(merge_channels=...)
DEBUG = True  # when True the dataloaders run without worker processes
LOCAL = False  # chooses between the local and the shared zarr paths

# Compute device: MPS is preferred, then CUDA, with CPU as the fallback.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# All run hyper-parameters live on a single namespace object.
config = SimpleNamespace(
    img_size=256,  # first positional argument of CloudcastingDataset
    epochs=50,  # passes over the training dataloader
    model_name="latent-diffusion",  # prefix of the names handed to save_model
    strategy="ddpm",  # strategy to use [ddpm, simple_diffusion]
    noise_steps=1000,  # noise steps of the diffusion process
    sampler_steps=300,  # steps given to ddim_sampler
    seed=42,  # value passed to set_seed
    batch_size=2,  # samples per dataloader batch
    device=device,
    num_workers=0 if DEBUG else 35,  # dataloader worker processes
    num_frames=4,  # frames fed to the UNet (history frames + the noised frame)
    lr=5e-4,  # UNet learning rate
    log_every_epoch=1,  # run the validation logging every n epochs
    n_preds=8,  # number of predictions to make
    latent_dim=4,  # channels per latent frame (UNet output channels)
    vae_lr=5e-5,  # VAE learning rate
)

# Every frame except the last one is history.
HISTORY_STEPS = config.num_frames - 1

# Keyword arguments for UNet2D. The input stacks `num_frames` latent frames
# along the channel axis; the output is a single latent frame.
config.model_params = {
    "block_out_channels": (32, 64, 128, 256),  # channels of each block
    "norm_num_groups": 8,  # groups of the normalization layers
    "in_channels": config.num_frames * config.latent_dim,
    "out_channels": config.latent_dim,
}

set_seed(config.seed)

# Select the zarr stores to read. A LOCAL run uses one store for both splits.
if not LOCAL:
    TRAINING_DATA_PATH = "/bask/projects/v/vjgo8416-climate/shared/data/eumetsat/training/2021_nonhrv.zarr"
    VALIDATION_DATA_PATH = "/bask/projects/v/vjgo8416-climate/shared/data/eumetsat/training/2022_training_nonhrv.zarr"
else:
    TRAINING_DATA_PATH = VALIDATION_DATA_PATH = "/users/nsimpson/Code/climetrend/cloudcast/2020_training_nonhrv.zarr"

# Arguments that are identical for the training and validation datasets.
# NOTE(review): history_mins presumably spans first-to-last history frame,
# hence (HISTORY_STEPS - 1) intervals -- confirm against CloudcastingDataset.
_shared_dataset_kwargs = dict(
    # strategy="resize",
    start_time=None,
    end_time=None,
    history_mins=(HISTORY_STEPS - 1) * DATA_INTERVAL_SPACING_MINUTES,
    forecast_mins=15,
    sample_freq_mins=15,
    nan_to_num=False,
    merge_channels=MERGE_CHANNELS,
)

# Instantiate the torch dataset objects
train_ds = CloudcastingDataset(
    config.img_size,
    valid=False,
    zarr_path=TRAINING_DATA_PATH,
    **_shared_dataset_kwargs,
)
# worth noting they do some sort of shuffling here; we don't for now
valid_ds = CloudcastingDataset(
    config.img_size,
    valid=True,
    zarr_path=VALIDATION_DATA_PATH,
    **_shared_dataset_kwargs,
)

# Both loaders shuffle and use the same batch size / worker settings.
train_dataloader = DataLoader(
    train_ds,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=config.num_workers,
    pin_memory=True,
)
valid_dataloader = DataLoader(
    valid_ds,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=config.num_workers,
    pin_memory=True,
)

ModuleNotFoundError: No module named 'fastprogress'

In [None]:
# Build the VAE adapter around the pretrained SDXL VAE and load fine-tuned
# weights into it.
# The import is local to this cell because this cell comes before the one
# that imports AutoencoderKL.
from diffusers.models import AutoencoderKL

# NOTE(review): the original literal began with a stray "f" inside the quotes
# ('f/bask/...'); it is read here as the absolute path below -- confirm.
VAE_CHECKPOINT_PATH = '/bask/projects/v/vjgo8416-climate/users/gmmg6904/cloud_diffusion/models/omf5mrig_cloud-finetune--vae.pth'

vae = TemporalVAEAdapter(AutoencoderKL.from_pretrained("stabilityai/sdxl-vae"))
# load_state_dict expects the state dict itself, not a file path, so the
# checkpoint is read from disk first. Loading onto CPU keeps this independent
# of the device the checkpoint was saved from.
# NOTE(review): assumes the .pth file stores a state dict for the adapter -- confirm.
vae.load_state_dict(torch.load(VAE_CHECKPOINT_PATH, map_location="cpu"))

In [2]:
# model setup
from diffusers.models import AutoencoderKL

unet = UNet2D(**config.model_params).to(device)
# vae = get_hacked_vae().to(device)#.float()
# NOTE(review): this builds a fresh pretrained VAE and rebinds `vae`; any `vae`
# created earlier in the session is replaced.
vae = TemporalVAEAdapter(AutoencoderKL.from_pretrained("stabilityai/sdxl-vae").to(device)).to(device)

# vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae").to(device).float()
# for param in vae.parameters():
#     param.requires_grad = False

# sampler
sampler = ddim_sampler(steps=config.sampler_steps)

# configure training
# wandb.config.update(config)
# One scheduler step is taken per optimizer step, so the schedule length is
# epochs * batches-per-epoch.
config.total_train_steps = config.epochs * len(train_dataloader)

# Create parameter groups with different learning rates:
# group 0 is the trainable VAE parameters, group 1 is the UNet.
param_groups = [
    {
        'params': [p for p in vae.parameters() if p.requires_grad],
        'lr': config.vae_lr,
        'eps': 1e-5
    },
    {
        'params': unet.parameters(),
        'lr': config.lr,
        'eps': 1e-5
    }
]
optimizer = AdamW(param_groups)
# max_lr is given per parameter group, in the same order as param_groups.
scheduler = OneCycleLR(optimizer, max_lr=[config.vae_lr, config.lr], total_steps=config.total_train_steps)
# Loss scaling is a CUDA feature here; `device` may also be "mps" or "cpu", in
# which case the scaler is created disabled and simply passes calls through.
scaler = torch.amp.GradScaler("cuda", enabled=(device == "cuda"))

# get a validation batch for logging
val_batch = next(iter(valid_dataloader))[0:2].to(device)  # log first 2 predictions
print(val_batch.shape)

torch.Size([2, 11, 4, 256, 256])


In [3]:
def check_tensor(tensor, name, print_stats=False):
    """Report whether `tensor` is free of NaN and Inf values.

    Args:
        tensor: Tensor to inspect.
        name: Label used in the printed messages.
        print_stats: When True and the tensor is clean, also print its
            min, max, mean and std.

    Returns:
        False if a NaN or an Inf was found (a message is printed), True otherwise.
    """
    # NaN is tested before Inf, so a tensor containing both reports NaN only.
    for label, find_bad in (("NaN", torch.isnan), ("Inf", torch.isinf)):
        if find_bad(tensor).any():
            print(f"{label} detected in {name}")
            return False

    if not print_stats:
        return True

    lowest = tensor.min().item()
    highest = tensor.max().item()
    average = tensor.mean().item()
    spread = tensor.std().item()
    print(f"{name} stats: min={lowest:.3f}, max={highest:.3f}, mean={average:.3f}, std={spread:.3f}")
    return True

In [4]:
# Sanity check: the validation batch used for logging must not contain NaNs.
assert not torch.isnan(val_batch).any()

In [10]:
def visualize_channels_over_time(images,  batch_idx=0, figsize=(12, 8), cmap='viridis', labels=None):
    """
    Visualize multi-channel images over time as a (channels x timesteps) grid.

    Each cell shows one channel at one timestep with its own colorbar, using
    the raw value range (no per-image normalization). Columns are titled with
    the frame offset relative to the last timestep, which is titled
    "frame 0 (predicted)".

    Args:
        images: Tensor of shape (batch, channels, time, height, width). It may
            live on any device and may track gradients.
        batch_idx: Which batch element to visualize
        figsize: Size of the figure
        cmap: Colormap to use for visualization
        labels: Optional sequence of row labels, one per channel. Defaults to
            the module-level `channel_names`. Channels beyond the end of the
            sequence are labelled "channel <index>".

    Returns:
        The matplotlib Figure. The caller owns it and should close it
        (`plt.close(fig)`) when it is no longer needed.
    """
    n_channels = images.shape[1]
    n_timesteps = images.shape[2]

    if labels is None:
        labels = channel_names

    # Offsets are counted back from the final frame, so the titles are right
    # for any number of timesteps (for 4 timesteps: -3, -2, -1, 0).
    last_timestep = n_timesteps - 1

    # squeeze=False always yields a 2D axes array, so axes[channel, timestep]
    # works even when there is a single channel and/or a single timestep.
    fig, axes = plt.subplots(n_channels, n_timesteps, figsize=figsize, squeeze=False)

    # Set the spacing between subplots
    fig.subplots_adjust(hspace=0.2, wspace=-0.5)

    # Iterate through channels and timesteps
    for channel in range(n_channels):
        for timestep in range(n_timesteps):
            ax = axes[channel, timestep]

            # detach + cpu make the conversion safe for GPU tensors and for
            # tensors that are part of an autograd graph.
            img = images[batch_idx, channel, timestep].detach().cpu().numpy()

            # Plot the image
            im = ax.imshow(img, cmap=cmap, origin='lower')
            ax.axis('off')

            # Add colorbar
            fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

            # Add titles (top row only)
            if channel == 0:
                offset = timestep - last_timestep
                extra = ' (predicted)' if offset == 0 else ''
                ax.set_title(f'frame {offset}' + extra)

            # Add the channel label to the left of the first column.
            # NOTE(review): (-10, 32) is in data (pixel) coordinates, so the
            # vertical position does not track the image height -- confirm intent.
            if timestep == 0:
                label = labels[channel] if channel < len(labels) else f'channel {channel}'
                ax.text(-10, 32, label,
                        rotation=0, ha='right', va='center')

    return fig


In [None]:
import torch
import torch.nn.functional as F
from cloud_diffusion.wandb import save_model

vae_loss_scale = 1  # weight of the VAE reconstruction loss in the total loss
print_stats = False  # NOTE(review): not referenced below; check_tensor is called with print_stats=True

# wandb.log (used for the validation figure) requires an active run. Start one
# only if none exists, so a run created elsewhere in the session is reused.
# NOTE(review): save_model may also depend on an active run -- confirm.
if wandb.run is None:
    wandb.init(project=PROJECT_NAME, config=vars(config))

# Joint training of the VAE and the latent-diffusion UNet, with NaN checks.
for epoch in progress_bar(range(config.epochs), total=config.epochs, leave=True):
    unet.train()
    pbar = progress_bar(train_dataloader, leave=False)
    for batch in pbar:
        # A batch that is entirely NaN carries no signal; skip it.
        if torch.isnan(batch).all():
            continue

        # Remaining NaNs are zero-filled before encoding.
        batch = torch.nan_to_num(batch, nan=0)
        img_batch = batch.to(device)

        # with torch.autocast(device):   # we want this but NaNs happen in the encoder :(
        latents = vae.encode_frames(img_batch)

        # Split along the third axis: all but the last frame are conditioning
        # history, the last frame is the denoising target.
        # NOTE(review): assumes latents are (batch, latent_channels, time, h, w) -- confirm.
        past_frames = latents[:, :, :-1]
        last_frame = latents[:, :, -1]

        noised_img, t, noise = noisify_ddpm(last_frame)

        # Fold the time axis into the channel axis so the history frames can
        # be concatenated with the noised frame as UNet input channels.
        past_frames = past_frames.permute(0, 2, 1, 3, 4)
        past_frames = past_frames.reshape(latents.shape[0], -1, latents.shape[3], latents.shape[4])

        diffusion_input = torch.cat([past_frames, noised_img], dim=1)
        predicted_noise = unet(diffusion_input, t)
        diffusion_loss = F.mse_loss(predicted_noise, noise)

        # Reconstruction loss trains the VAE alongside the UNet.
        img_batch_hat = vae.decode_frames(latents)
        vae_loss = F.mse_loss(img_batch_hat, img_batch)

        loss = diffusion_loss + vae_loss * vae_loss_scale

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        pbar.comment = f"epoch={epoch}, vae_loss={vae_loss.item():2.3f}, diffusion_loss={diffusion_loss.item():2.3f}"

    if epoch % config.log_every_epoch == 0:
        # Sample in eval mode; unet.train() at the top of the next epoch
        # switches it back.
        unet.eval()
        with torch.no_grad():
            val_latents = vae.encode_frames(val_batch)
            check_tensor(val_latents, 'val_latents', print_stats=True)
            past_val_frames = val_latents[:, :, :-1]
            past_val_frames = past_val_frames.permute(0, 2, 1, 3, 4).reshape(val_latents.shape[0], -1, val_latents.shape[3], val_latents.shape[4])
            # num_channels follows the configured latent size instead of a literal 4.
            samples = sampler(unet, past_frames=past_val_frames, num_channels=config.latent_dim)
            # Re-insert the time axis so the sample can be decoded as a single frame.
            samples = samples.unsqueeze(dim=2)
            check_tensor(samples, 'samples', print_stats=True)
            decoded = vae.decode_frames(samples).cpu()

        # Show the real history frames followed by the predicted last frame.
        valid_plot = visualize_channels_over_time(torch.cat((val_batch[:,:,:-1].detach().cpu(), decoded), dim=2))
        wandb.log({"all-channels": valid_plot})
        # Close the figure once logged; otherwise one open figure per logging
        # epoch accumulates in memory.
        plt.close(valid_plot)

# Save each model under its own suffix (the two names were previously swapped).
save_model(vae, config.model_name + '-vae')
save_model(unet, config.model_name + '-unet')


