In [1]:
from types import SimpleNamespace

import wandb
import torch

from cloudcasting.constants import NUM_CHANNELS, DATA_INTERVAL_SPACING_MINUTES

from cloud_diffusion.dataset import CloudcastingDataset
from cloud_diffusion.utils import set_seed
from cloud_diffusion.ddpm import noisify_ddpm, ddim_sampler
from cloud_diffusion.models import UNet2D
from cloud_diffusion.vae import get_hacked_vae, encode_frames, decode_frames


from torch.optim import AdamW
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import OneCycleLR
import torch.nn.functional as F

from fastprogress import progress_bar

import matplotlib.pyplot as plt
import numpy as np

# Satellite channel names used as plot labels.
# NOTE(review): presumably in the same order as the dataset's channel axis
# (11 entries, matching the 11 channels of a sample batch) - confirm against the zarr store.
channel_names = [
    'IR_016', 'IR_039', 'IR_087', 'IR_097', 'IR_108', 'IR_120', 'IR_134',
    'VIS006', 'VIS008', 'WV_062', 'WV_073',
]

# --- run switches -------------------------------------------------------------
PROJECT_NAME = "nathan-test"
MERGE_CHANNELS = False  # forwarded to CloudcastingDataset(merge_channels=...)
DEBUG = True            # True -> single-process data loading (num_workers=0)
LOCAL = True            # True -> read the local zarr store instead of the cluster copies

# Prefer Apple's MPS backend, then CUDA, and fall back to the CPU.
if torch.backends.mps.is_available():
    _default_device = "mps"
elif torch.cuda.is_available():
    _default_device = "cuda"
else:
    _default_device = "cpu"

config = SimpleNamespace(
    img_size=256,
    epochs=50,               # number of training epochs
    model_name="uvit-test",  # model name to save
    strategy="ddpm",         # diffusion strategy [ddpm, simple_diffusion]
    noise_steps=1000,        # number of noise steps in the diffusion process
    sampler_steps=300,       # number of sampler steps
    seed=42,                 # random seed
    batch_size=2,
    device=_default_device,
    num_workers=0 if DEBUG else 8,  # dataloader worker processes
    num_frames=4,            # frames per sample, including the noised frame to denoise
    lr=5e-4,                 # UNet learning rate
    log_every_epoch=1,       # log to wandb every n epochs
    n_preds=8,               # number of predictions to make
    latent_dim=4,            # VAE latent channels per frame
    vae_lr=5e-5,             # learning rate for the VAE's trainable layers
)

device = config.device

# Conditioning (past) frames per sample; the remaining frame is the one predicted.
HISTORY_STEPS = config.num_frames - 1

# UNet input = latents of all num_frames frames stacked on the channel axis
# (past frames + noised target); output = one frame of latent noise.
config.model_params = {
    "block_out_channels": (32, 64, 128, 256),  # channels for each block
    "norm_num_groups": 8,                      # groups for the normalization layers
    "in_channels": config.num_frames * config.latent_dim,
    "out_channels": config.latent_dim,
}

set_seed(config.seed)

# Zarr stores for each split. Locally there is only one store, so it doubles as
# the validation set.
# NOTE(review): absolute, user-specific paths - consider reading them from an
# environment variable or a configurable data directory.
if not LOCAL:
    TRAINING_DATA_PATH = "/bask/projects/v/vjgo8416-climate/shared/data/eumetsat/training/2021_nonhrv.zarr"
    VALIDATION_DATA_PATH = "/bask/projects/v/vjgo8416-climate/shared/data/eumetsat/training/2022_training_nonhrv.zarr"
else:
    TRAINING_DATA_PATH = "/users/nsimpson/Code/climetrend/cloudcast/2020_training_nonhrv.zarr"
    VALIDATION_DATA_PATH = TRAINING_DATA_PATH

# Sampling settings shared by both splits; they differ only in `valid` and the store path.
# history_mins presumably makes the history window hold HISTORY_STEPS frames (both ends
# inclusive) - consistent with the 4-step time axis of a sample batch; confirm.
_shared_ds_kwargs = dict(
    # strategy="resize",
    start_time=None,
    end_time=None,
    history_mins=(HISTORY_STEPS - 1) * DATA_INTERVAL_SPACING_MINUTES,
    forecast_mins=15,
    sample_freq_mins=15,
    nan_to_num=False,  # keep NaNs; the training loop zero-fills missing pixels itself
    merge_channels=MERGE_CHANNELS,
)

# Instantiate the torch dataset objects
train_ds = CloudcastingDataset(config.img_size, valid=False, zarr_path=TRAINING_DATA_PATH, **_shared_ds_kwargs)
# worth noting they (presumably the upstream reference) do some sort of shuffling here; we don't for now
valid_ds = CloudcastingDataset(config.img_size, valid=True, zarr_path=VALIDATION_DATA_PATH, **_shared_ds_kwargs)

# Only the training loader shuffles, so validation batches come out in a fixed order.
train_dataloader = DataLoader(train_ds, config.batch_size, shuffle=True, num_workers=config.num_workers)
valid_dataloader = DataLoader(valid_ds, config.batch_size, shuffle=False, num_workers=config.num_workers)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Pull one training batch to inspect its layout: (batch, channels, time, height, width)
train_batches = iter(train_dataloader)
X = next(train_batches)
X.shape

torch.Size([2, 11, 4, 256, 256])

In [3]:
# model setup
unet = UNet2D(**config.model_params).to(device)
vae = get_hacked_vae().to(device)

# sampler used to generate validation predictions
sampler = ddim_sampler(steps=config.sampler_steps)

# configure training
# wandb.config.update(config)
config.total_train_steps = config.epochs * len(train_dataloader)

# Two parameter groups so the VAE's trainable layers get a smaller learning rate than the UNet.
param_groups = [
    {
        'params': [p for p in vae.parameters() if p.requires_grad],
        'lr': config.vae_lr,
        'eps': 1e-5
    },
    {
        'params': unet.parameters(),
        'lr': config.lr,
        'eps': 1e-5
    }
]
optimizer = AdamW(param_groups)
# max_lr is given per parameter group, in the same order as param_groups
scheduler = OneCycleLR(optimizer, max_lr=[config.vae_lr, config.lr], total_steps=config.total_train_steps)
# Loss scaling is only meaningful for CUDA float16 autocast. Enable it explicitly on CUDA
# rather than relying on GradScaler's "CUDA is not available ... Disabling" fallback on mps/cpu.
scaler = torch.amp.GradScaler("cuda", enabled=(device == "cuda"))

# get a validation batch for logging (log first 2 predictions); NaNs are kept because the
# dataset is built with nan_to_num=False, so plots can mask missing pixels
val_batch = next(iter(valid_dataloader))[0:2].to(device)
print(val_batch.shape)

Unfreezing conv_in.weight
Unfreezing conv_in.bias
Unfreezing conv_out.weight
Unfreezing conv_out.bias
torch.Size([2, 11, 4, 256, 256])




In [4]:

# Joint training: the UNet learns to predict the DDPM noise added to the latent of the
# last frame, conditioned on the latents of the earlier frames, while the VAE's trainable
# layers get an auxiliary reconstruction loss. After every epoch the sampler is run on
# the fixed validation batch.
#
# Keep the VAE weights in float32: its trainable layers are optimised through `scaler`,
# and GradScaler refuses to unscale float16 gradients ("Attempting to unscale FP16
# gradients."); AdamW updates on fp16 master weights also tend to underflow.
# Reduced precision comes from autocast instead.
vae.float()

for epoch in progress_bar(range(config.epochs), total=config.epochs, leave=True):
    vae.train()
    unet.train()
    pbar = progress_bar(train_dataloader, leave=False)
    for batch in pbar:
        # a batch without a single valid pixel carries no signal; skip it
        if torch.isnan(batch).all():
            continue

        # (batch, channels, time, height, width); missing (NaN) pixels are zero-filled
        img_batch = torch.nan_to_num(batch.float(), nan=0).to(device)

        with torch.autocast(device):
            # vae math
            latents = encode_frames(img_batch, vae)

            # diffusion math: condition on all but the last frame, denoise the last one
            past_frames = latents[:, :, :-1]
            last_frame = latents[:, :, -1]  # intentionally remove time dim

            noised_img, t, noise = noisify_ddpm(last_frame)

            # flip the time and channel dimensions before merging them.
            # for time = t, channel = c, the order of the flattened input is now:
            # [t1c1, t1c2, ..., t2c1, t2c2, ...]
            past_frames = past_frames.permute(0, 2, 1, 3, 4).reshape(
                latents.shape[0], -1, latents.shape[3], latents.shape[4]
            )
            # concatenate on channel dim
            diffusion_input = torch.cat([past_frames, noised_img], dim=1)

            # diffusion loss calc
            predicted_noise = unet(diffusion_input, t)
            diffusion_loss = F.mse_loss(predicted_noise, noise)

            # also calculate a reconstruction loss for the vae on our encoded frames
            img_batch_hat = decode_frames(latents, vae)
            vae_loss = F.mse_loss(img_batch_hat, img_batch)

            # total loss (arbitrary weighting)
            loss = diffusion_loss + 0.1 * vae_loss

        # backprop outside the autocast region, as recommended for torch.amp
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        # wandb.log({"train_mse": loss.item(), "learning_rate": scheduler.get_last_lr()[0], "vae_loss": vae_loss.item(), "diffusion_loss": diffusion_loss.item()})
        pbar.comment = f"epoch={epoch}, vae_loss={vae_loss.item():2.3f}, diffusion_loss={diffusion_loss.item():2.3f}"

        if DEBUG:
            # quick-iteration mode: a single optimisation step per epoch
            break

    # sample a prediction of the last frame of the fixed validation batch
    vae.eval()
    unet.eval()
    with torch.no_grad():
        # val_batch still contains NaNs (dataset built with nan_to_num=False); zero-fill them
        # like the training batches, otherwise they propagate through the VAE into the sampler
        val_imgs = torch.nan_to_num(val_batch.float(), nan=0)
        val_latents = encode_frames(val_imgs, vae)
        past_val_frames = val_latents[:, :, :-1]
        last_val_frame = val_latents[:, :, -1]  # intentionally remove time dim
        past_val_frames = past_val_frames.permute(0, 2, 1, 3, 4).reshape(
            val_latents.shape[0], -1, val_latents.shape[3], val_latents.shape[4]
        ).float()
        samples = sampler(unet, past_frames=past_val_frames, num_channels=config.latent_dim)

    # TODO: log predictions every config.log_every_epoch epochs, e.g.
    # if epoch % config.log_every_epoch == 0:
    #     log_images(val_batch, samples)

# TODO: save_model(vae, config.model_name)




  hidden_states = F.scaled_dot_product_attention(


new_frame.shape=torch.Size([2, 4, 32, 32])
past_frames.shape=torch.Size([2, 12, 32, 32])


new_frame.shape=torch.Size([2, 4, 32, 32])
past_frames.shape=torch.Size([2, 12, 32, 32])


KeyboardInterrupt: 

In [None]:
def visualize_channels_over_time(images, batch_idx=0, figsize=(12, 8), cmap='viridis', channel_labels=None):
    """
    Plot every channel of one batch element at every timestep on a channels x time grid.

    Args:
        images: Tensor of shape (batch, channels, time, height, width). It may live on any
            device and may require grad; it is detached and copied to CPU for plotting.
        batch_idx: Which batch element to visualize.
        figsize: Size of the figure.
        cmap: Colormap to use for visualization.
        channel_labels: Row labels, one per channel. Defaults to the module-level
            ``channel_names`` list.

    Returns:
        The matplotlib Figure.
    """
    labels = channel_names if channel_labels is None else channel_labels
    n_channels = images.shape[1]
    n_timesteps = images.shape[2]

    # squeeze=False always yields a 2-D array of Axes, so axes[channel, timestep] works for
    # any grid, including a single channel and/or a single timestep.
    fig, axes = plt.subplots(n_channels, n_timesteps, figsize=figsize, squeeze=False)
    fig.subplots_adjust(hspace=0.2, wspace=-0.5)

    frames = images[batch_idx].detach().cpu()
    for channel in range(n_channels):
        for timestep in range(n_timesteps):
            ax = axes[channel, timestep]
            im = ax.imshow(frames[channel, timestep].numpy(), cmap=cmap, origin='lower')
            ax.axis('off')
            fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

            if channel == 0:
                # frame offsets relative to the last frame: ..., -2, -1, 0
                ax.set_title(f'frame {timestep - (n_timesteps - 1)}')
            if timestep == 0:
                # axes-relative coordinates keep the label vertically centred for any image size
                ax.text(-0.05, 0.5, labels[channel], transform=ax.transAxes,
                        rotation=0, ha='right', va='center')

    return fig


visualize_channels_over_time(X);

In [None]:
def match_sizes(ground_truth, prediction, method='crop'):
    """
    Match the spatial sizes of ground truth and prediction tensors.

    Args:
        ground_truth: Tensor of shape (batch, channels, height1, width1)
        prediction: Tensor of shape (batch, channels, height2, width2)
        method: One of ['crop', 'pad', 'resize']:
            'crop'   - centre-crop ground_truth down to the prediction size
                       (requires the prediction to be no larger than the ground truth);
            'pad'    - centre zero-pad prediction to the ground-truth size
                       (negative padding, i.e. a crop, if the prediction is larger);
            'resize' - bilinearly resize prediction to the ground-truth size.

    Returns:
        Tuple (gt_matched, pred_matched) with equal spatial sizes. The tensor that is not
        resized is returned as-is (same object, not a copy).

    Raises:
        ValueError: if method is unknown, or if method is 'crop' and the prediction is
            larger than the ground truth in either spatial dimension.
    """
    gt_h, gt_w = ground_truth.shape[2], ground_truth.shape[3]
    pred_h, pred_w = prediction.shape[2], prediction.shape[3]
    h_diff = gt_h - pred_h
    w_diff = gt_w - pred_w

    if method == 'crop':
        # A negative difference would produce negative slice starts and silently return a
        # ground truth of the wrong size, so reject it explicitly.
        if h_diff < 0 or w_diff < 0:
            raise ValueError(
                f"cannot crop ground truth ({gt_h}x{gt_w}) to a larger prediction ({pred_h}x{pred_w})"
            )
        h_start = h_diff // 2
        w_start = w_diff // 2
        gt_matched = ground_truth[:, :, h_start:h_start + pred_h, w_start:w_start + pred_w]
        pred_matched = prediction

    elif method == 'pad':
        # F.pad takes (left, right, top, bottom) for the last two dims; any odd remainder
        # goes to the right/bottom.
        h_pad = (h_diff // 2, h_diff - h_diff // 2)
        w_pad = (w_diff // 2, w_diff - w_diff // 2)
        pred_matched = F.pad(prediction, (w_pad[0], w_pad[1], h_pad[0], h_pad[1]))
        gt_matched = ground_truth

    elif method == 'resize':
        pred_matched = F.interpolate(prediction,
                                     size=(gt_h, gt_w),
                                     mode='bilinear',
                                     align_corners=False)
        gt_matched = ground_truth

    else:
        raise ValueError("method must be one of ['crop', 'pad', 'resize']")

    return gt_matched, pred_matched

def visualize_prediction_comparison(ground_truth, prediction, batch_idx=0, figsize=(20, 8),
                                    cmap='viridis', diff_cmap='RdBu_r', size_match='crop',
                                    channel_labels=None):
    """
    Visualize ground truth, prediction, and their difference side by side for each channel.

    The prediction panel uses the ground truth's colour limits so the two are directly
    comparable; the difference panel uses limits symmetric around zero. Pixels where the
    ground truth is NaN (missing data) are blanked in the prediction too.

    Args:
        ground_truth: Tensor of shape (batch, channels, height1, width1). It may live on any
            device; it is detached and copied to CPU as float32 for plotting.
        prediction: Tensor of shape (batch, channels, height2, width2), handled likewise.
        batch_idx: Which batch element to visualize.
        figsize: Size of the figure.
        cmap: Colormap for the ground-truth and prediction panels.
        diff_cmap: Diverging colormap for the difference panel.
        size_match: How to reconcile spatial sizes; passed to ``match_sizes`` as ``method``.
        channel_labels: Row labels, one per channel. Defaults to the module-level
            ``channel_names`` list.

    Returns:
        The matplotlib Figure. The input tensors are not modified.
    """
    labels = channel_names if channel_labels is None else channel_labels
    gt_cpu = ground_truth.detach().cpu().float()
    pred_cpu = prediction.detach().cpu().float()
    gt_matched, pred_matched = match_sizes(gt_cpu, pred_cpu, method=size_match)
    n_channels = ground_truth.shape[1]

    # Mask out-of-place: match_sizes can return `prediction` itself (size_match='crop'), and
    # .detach()/.cpu()/.float() share storage for CPU float32 inputs, so an in-place write
    # would corrupt the caller's tensor.
    pred_matched = pred_matched.masked_fill(torch.isnan(gt_matched), torch.nan)

    fig, axes = plt.subplots(n_channels, 3, figsize=figsize, squeeze=False)
    fig.subplots_adjust(hspace=0.2, wspace=-0.8)

    for channel in range(n_channels):
        ax_gt, ax_pred, ax_diff = axes[channel]
        gt_img = gt_matched[batch_idx, channel].numpy()
        pred_img = pred_matched[batch_idx, channel].numpy()
        diff_img = pred_img - gt_img

        # Colour limits from the finite ground-truth pixels (None -> matplotlib autoscale).
        finite_gt = gt_img[np.isfinite(gt_img)]
        vmin, vmax = (finite_gt.min(), finite_gt.max()) if finite_gt.size else (None, None)

        # Symmetric limits around zero for the difference. NaNs must be excluded, otherwise
        # min()/max() return NaN and the colour scale breaks.
        finite_diff = np.abs(diff_img[np.isfinite(diff_img)])
        diff_bound = float(finite_diff.max()) if finite_diff.size else 0.0
        diff_bound = max(diff_bound, 0.1)  # Ensure some minimum range for visualization

        # Ground truth with its own data range
        im_gt = ax_gt.imshow(gt_img, cmap=cmap, origin='lower', vmin=vmin, vmax=vmax)
        ax_gt.axis('off')
        fig.colorbar(im_gt, ax=ax_gt, fraction=0.046, pad=0.04)

        # Prediction with the same range as the ground truth
        im_pred = ax_pred.imshow(pred_img, cmap=cmap, origin='lower', vmin=vmin, vmax=vmax)
        ax_pred.axis('off')
        fig.colorbar(im_pred, ax=ax_pred, fraction=0.046, pad=0.04)

        # Difference with symmetric range
        im_diff = ax_diff.imshow(diff_img, cmap=diff_cmap, origin='lower',
                                 vmin=-diff_bound, vmax=diff_bound)
        ax_diff.axis('off')
        fig.colorbar(im_diff, ax=ax_diff, fraction=0.046, pad=0.04)

        # Channel name on the left; axes-relative coords keep it centred for any image size
        ax_gt.text(-0.05, 0.5, labels[channel], transform=ax_gt.transAxes,
                   rotation=0, ha='right', va='center')

    # Column headers
    axes[0, 0].set_title(f'Ground Truth ({gt_matched.shape[2]}x{gt_matched.shape[3]})')
    axes[0, 1].set_title(f'Prediction ({pred_matched.shape[2]}x{pred_matched.shape[3]})')
    axes[0, 2].set_title('Difference (Pred - GT)')
    return fig

# Compare the last frame of the validation batch with the sampler's prediction of it.
# `samples` (latents, shape (batch, latent_dim, H', W')) comes from the training cell's
# validation step; `res` was previously undefined in this notebook (hidden kernel state).
# NOTE(review): assumes decode_frames accepts (batch, latent_dim, time, H', W') latents,
# as in the training cell, and returns (batch, channels, time, H, W) images - confirm.
vae.eval()
with torch.no_grad():
    vae_dtype = next(vae.parameters()).dtype
    res = decode_frames(samples.unsqueeze(2).to(vae_dtype), vae)[:, :, 0]
visualize_prediction_comparison(val_batch[:, :, -1].float().cpu(), res.float().cpu(), size_match="resize");