In [1]:
from types import SimpleNamespace

import wandb
import torch

from cloudcasting.constants import NUM_CHANNELS, DATA_INTERVAL_SPACING_MINUTES

from cloud_diffusion.dataset import CloudcastingDataset
from cloud_diffusion.utils import set_seed
from cloud_diffusion.ddpm import noisify_ddpm, ddim_sampler
from cloud_diffusion.models import UNet2D
from cloud_diffusion.vae import VAEChannelAdapter


from torch.optim import AdamW
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import OneCycleLR
import torch.nn.functional as F

from fastprogress import progress_bar

import matplotlib.pyplot as plt
import numpy as np

channel_names = ['IR_016', 'IR_039', 'IR_087', 'IR_097', 'IR_108', 'IR_120', 'IR_134',
       'VIS006', 'VIS008', 'WV_062', 'WV_073']

PROJECT_NAME = "nathan-test"
MERGE_CHANNELS = False
DEBUG = True
LOCAL = False

config = SimpleNamespace(
    img_size=256,
    epochs=50,  # number of epochs
    model_name="uvit-test",  # model name to save [unet_small, unet_big]
    strategy="ddpm",  # strategy to use [ddpm, simple_diffusion]
    noise_steps=1000,  # number of noise steps on the diffusion process
    sampler_steps=300,  # number of sampler steps on the diffusion process
    seed=42,  # random seed
    batch_size=2,  # batch size
    device="mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu",
    num_workers=0 if DEBUG else 35,  # number of workers for dataloader
    num_frames=4,  # number of frames to use as input (includes noise frame)
    lr=5e-4,  # learning rate
    log_every_epoch=1,  # log every n epochs to wandb
    n_preds=8,  # number of predictions to make
    latent_dim=4,
    vae_lr=5e-5,
)

device = config.device


HISTORY_STEPS = config.num_frames - 1


config.model_params = dict(
    block_out_channels=(32, 64, 128, 256),  # number of channels for each block
    norm_num_groups=8,  # number of groups for the normalization layer
    in_channels=config.num_frames * config.latent_dim,  # number of input channels
    out_channels=config.latent_dim,  # number of output channels
)

set_seed(config.seed)

if LOCAL:
    TRAINING_DATA_PATH = VALIDATION_DATA_PATH = "/users/nsimpson/Code/climetrend/cloudcast/2020_training_nonhrv.zarr"
else:
    TRAINING_DATA_PATH = "/bask/projects/v/vjgo8416-climate/shared/data/eumetsat/training/2021_nonhrv.zarr"
    VALIDATION_DATA_PATH = "/bask/projects/v/vjgo8416-climate/shared/data/eumetsat/training/2022_training_nonhrv.zarr"
# Instantiate the torch dataset object
train_ds = CloudcastingDataset(
    config.img_size,
    valid=False,
    # strategy="resize",
    zarr_path=TRAINING_DATA_PATH,
    start_time=None,
    end_time=None,
    history_mins=(HISTORY_STEPS - 1) * DATA_INTERVAL_SPACING_MINUTES,
    forecast_mins=15,
    sample_freq_mins=15,
    nan_to_num=False,
    merge_channels=MERGE_CHANNELS,
)
# worth noting they do some sort of shuffling here; we don't for now
valid_ds = CloudcastingDataset(
    config.img_size,
    valid=True,
    # strategy="resize",
    zarr_path=VALIDATION_DATA_PATH,
    start_time=None,
    end_time=None,
    history_mins=(HISTORY_STEPS - 1) * DATA_INTERVAL_SPACING_MINUTES,
    forecast_mins=15,
    sample_freq_mins=15,
    nan_to_num=False,
    merge_channels=MERGE_CHANNELS,
)

train_dataloader = DataLoader(train_ds, config.batch_size, shuffle=True,  num_workers=config.num_workers, pin_memory=True)
valid_dataloader = DataLoader(valid_ds, config.batch_size, shuffle=False, num_workers=config.num_workers, pin_memory=True)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
X = next(iter(train_dataloader))
X.shape

torch.Size([2, 11, 4, 256, 256])

In [3]:
def encode_frames(imgs, vae):
    # imgs shape: [batch, channels, time, height, width]
    B, C, T, H, W = imgs.shape
    
    # Reshape to process all frames at once
    imgs_2d = imgs.permute(0, 2, 1, 3, 4).reshape(B*T, C, H, W)#.contiguous()
    
    latents = vae.encode(imgs_2d).latent_dist.sample()
    
    # Reshape back to include time dimension
    _, C_out, H_out, W_out = latents.shape
    return latents.reshape(B, T, C_out, H_out, W_out).permute(0, 2, 1, 3, 4)

def decode_frames(latents, vae):
    # latents shape: [batch, channels, time, height, width]
    B, C, T, H, W = latents.shape
    
    # Reshape to process all frames at once
    latents_2d = latents.permute(0, 2, 1, 3, 4).reshape(B*T, C, H, W)#.contiguous()
    
    # Process through VAE
    decoded = vae.decode(latents_2d).sample
    
    # Reshape back to include time dimension
    _, C_out, H_out, W_out = decoded.shape
    decoded = decoded.reshape(B, T, C_out, H_out, W_out).permute(0, 2, 1, 3, 4)
    
    # Map channels
    return decoded

In [4]:
# model setup
unet = UNet2D(**config.model_params).to(device)
# vae = get_hacked_vae().to(device)#.float()
from diffusers.models import AutoencoderKL
vae = VAEChannelAdapter(AutoencoderKL.from_pretrained("stabilityai/sdxl-vae").to(device)).to(device)

# vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae").to(device).float()
# for param in vae.parameters():
#     param.requires_grad = False

# sampler
sampler = ddim_sampler(steps=config.sampler_steps)

# configure training
# wandb.config.update(config)
config.total_train_steps = config.epochs * len(train_dataloader)

# Create parameter groups with different learning rates
param_groups = [
    {
        'params': [p for p in vae.parameters() if p.requires_grad],
        'lr': config.vae_lr,
        'eps': 1e-5
    },
    {
        'params': unet.parameters(),
        'lr': config.lr,
        'eps': 1e-5
    }
]
optimizer = AdamW(param_groups)
scheduler = OneCycleLR(optimizer, max_lr=[config.vae_lr, config.lr], total_steps=config.total_train_steps)
scaler = torch.amp.GradScaler("cuda")

# get a validation batch for logging
val_batch = next(iter(valid_dataloader))[0:2].to(device)  # log first 2 predictions
print(val_batch.shape)

torch.Size([2, 11, 4, 256, 256])


In [9]:
from torch import nn

class VAEChannelAdapter(nn.Module):
    """
    Complete VAE adapter that handles:
    1. Channel adaptation (many channels <-> 3 channels)
    2. VAE encoding/decoding
    3. Latent scaling
    
    All scaling factors are integrated into forward/backward passes
    for cleaner usage.
    """
    def __init__(self, vae, channels=11):
        super().__init__()
        self.vae = vae
        self.channels = channels
        # Store scaling factor from VAE config
        self.scaling_factor = vae.config.scaling_factor
        
        # Input adapter: channels -> 3
        self.in_adapter = nn.Sequential(
            nn.Conv2d(channels, 32, 3, padding=1),
            nn.GroupNorm(8, 32),
            nn.SiLU(),
            nn.Conv2d(32, 16, 3, padding=1),
            nn.GroupNorm(4, 16),
            nn.SiLU(),
            nn.Conv2d(16, 3, 3, padding=1),
            nn.Tanh()
        )
        
        # Output adapter: 3 -> channels
        self.out_adapter = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.GroupNorm(8, 32),
            nn.SiLU(),
            nn.Conv2d(32, 16, 3, padding=1),
            nn.GroupNorm(4, 16),
            nn.SiLU(),
            nn.Conv2d(16, channels, 3, padding=1),
            nn.Tanh()
        )
        
        self._init_weights()
        
        # Freeze VAE parameters
        for param in self.vae.parameters():
            param.requires_grad = False
    
    def _init_weights(self):
        """Initialize adapter weights for stable training"""
        for module in [self.in_adapter, self.out_adapter]:
            for m in module.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, a=0.1, mode='fan_out', nonlinearity='linear')
                    if m.bias is not None:
                        nn.init.zeros_(m.bias)
    
    def encode(self, x):
        """
        Encode and scale in one step
        
        Args:
            x: [B, C, H, W] input tensor where C = self.channels
        Returns:
            Scaled latents ready for training/inference
        """
        # Adapt channels and encode
        x = self.in_adapter(x)
        latents = self.vae.encode(x).latent_dist.sample()
        # Scale the latents
        return latents * self.scaling_factor
    
    def decode(self, z):
        """
        Descale and decode in one step
        
        Args:
            z: Scaled latent tensor from training/inference
        Returns:
            Multi-channel output tensor
        """
        # Descale the latents
        z = z / self.scaling_factor
        # Decode and adapt channels
        decoded = self.vae.decode(z).sample
        return self.out_adapter(decoded)

# Example usage is now much cleaner:
def test_vae_adapter(adapter, device='cuda'):
    """Test the adapter's full pipeline with integrated scaling"""
    x = torch.randn(1, adapter.channels, 64, 64).to(device)
    x = torch.tanh(x)  # Normalize input
    
    print(f"Input shape: {x.shape}, range: [{x.min():.3f}, {x.max():.3f}]")
    
    # Encode (scaling included)
    latents = adapter.encode(x)
    print(f"Latents shape: {latents.shape}, range: [{latents.min():.3f}, {latents.max():.3f}]")
    
    # Decode (descaling included)
    decoded = adapter.decode(latents)
    print(f"Output shape: {decoded.shape}, range: [{decoded.min():.3f}, {decoded.max():.3f}]")
    
    return decoded


# Usage in training loop is now simpler:
adapted_vae = VAEChannelAdapter(AutoencoderKL.from_pretrained("stabilityai/sdxl-vae").to(device), channels=11).to(device)
def analyze_latent_distribution(latents):
    """Check if latents match expected distribution"""
    print(f"Latent stats:")
    print(f"Mean: {latents.mean():.3f}")
    print(f"Std: {latents.std():.3f}")
    print(f"Range: [{latents.min():.3f}, {latents.max():.3f}]")
    return latents.std().item()
# During training
latents = adapted_vae.encode(X[:,:,0,:,:].to(device))  # Includes scaling
analyze_latent_distribution(latents)
decoded = adapted_vae.decode(latents)  # Includes descaling

# No need for separate scaling calls!

test_vae_adapter(adapted_vae)

Latent stats:
Mean: -0.014
Std: 1.105
Range: [-2.958, 2.487]
Input shape: torch.Size([1, 11, 64, 64]), range: [-1.000, 1.000]
Latents shape: torch.Size([1, 4, 8, 8]), range: [-1.519, 1.915]
Output shape: torch.Size([1, 11, 64, 64]), range: [-0.995, 0.992]


tensor([[[[-0.3131,  0.2267,  0.1221,  ..., -0.1308,  0.0740,  0.3229],
          [-0.2024, -0.0539, -0.1153,  ...,  0.0662,  0.2216,  0.3749],
          [-0.3863,  0.1248,  0.1053,  ...,  0.4131,  0.4177,  0.3517],
          ...,
          [ 0.2698, -0.3137, -0.5546,  ..., -0.6310, -0.4948,  0.3405],
          [ 0.1188, -0.4660, -0.7697,  ...,  0.1032,  0.2107,  0.3803],
          [ 0.2109,  0.6104,  0.0839,  ...,  0.4946,  0.2015,  0.3172]],

         [[ 0.1906,  0.0820,  0.1424,  ...,  0.0575,  0.0832, -0.0485],
          [ 0.0566, -0.3328, -0.1398,  ...,  0.2047,  0.2020, -0.1119],
          [-0.0758, -0.2198, -0.4947,  ...,  0.2811,  0.0979, -0.0565],
          ...,
          [ 0.3398, -0.2138,  0.1189,  ..., -0.1691,  0.5578,  0.0621],
          [-0.0920, -0.0780,  0.0608,  ...,  0.0527, -0.0773, -0.0125],
          [-0.3634, -0.0598, -0.2975,  ..., -0.4819, -0.0839,  0.2312]],

         [[-0.2784, -0.3097,  0.1206,  ..., -0.5178, -0.4380, -0.3055],
          [-0.5114, -0.6199, -

In [5]:
import torch
import torch.nn.functional as F

def check_tensor(tensor, name, print_stats=False):
    """Utility function to check tensor for NaN/Inf values and optionally print statistics"""
    if torch.isnan(tensor).any():
        print(f"NaN detected in {name}")
        return False
    if torch.isinf(tensor).any():
        print(f"Inf detected in {name}")
        return False
    if print_stats:
        print(f"{name} stats: min={tensor.min().item():.3f}, max={tensor.max().item():.3f}, "
              f"mean={tensor.mean().item():.3f}, std={tensor.std().item():.3f}")
    return True

# Modified training loop with checks
for epoch in progress_bar(range(config.epochs), total=config.epochs, leave=True):
    unet.train()
    pbar = progress_bar(train_dataloader, leave=False)
    for batch in pbar:
        if torch.isnan(batch).all():
            continue
            
        # Check input batch
        batch = torch.nan_to_num(batch, nan=0)
        img_batch = batch.to(device)
        check_tensor(img_batch, "img_batch", print_stats=True)
        
        with torch.autocast(device):
            # VAE encoding check
            latents = encode_frames(img_batch, vae) 
            latents = vae.scale_latents(latents, encode=True)
            if not check_tensor(latents, "latents", print_stats=True):
                print("NaN detected after VAE encoding")
                break
                
            # Diffusion preparation checks
            past_frames = latents[:, :, :-1]
            last_frame = latents[:, :, -1]
            check_tensor(past_frames, "past_frames")
            check_tensor(last_frame, "last_frame")
            
            # Noise addition check
            noised_img, t, noise = noisify_ddpm(last_frame)
            check_tensor(noised_img, "noised_img")
            check_tensor(noise, "noise")
            
            # Check reshape operations
            past_frames = past_frames.permute(0, 2, 1, 3, 4)
            past_frames = past_frames.reshape(latents.shape[0], -1, latents.shape[3], latents.shape[4])
            check_tensor(past_frames, "reshaped_past_frames")
            
            # Check concatenated input
            diffusion_input = torch.cat([past_frames, noised_img], dim=1)
            check_tensor(diffusion_input, "diffusion_input")
            
            # Check UNet output
            predicted_noise = unet(diffusion_input, t)
            if not check_tensor(predicted_noise, "predicted_noise", print_stats=True):
                print("NaN detected in UNet output")
                break
            
            # Loss calculation checks
            diffusion_loss = F.mse_loss(predicted_noise, noise)
            check_tensor(diffusion_loss, "diffusion_loss")
            
            img_batch_hat = decode_frames(latents * (1 / 0.18215), vae, use_channel_mixer)
            check_tensor(img_batch_hat, "decoded_frames")
            
            vae_loss = F.mse_loss(img_batch_hat, img_batch)
            check_tensor(vae_loss, "vae_loss")
            
            loss = diffusion_loss  # +0.1*vae_loss
            if not check_tensor(loss, "total_loss"):
                print("NaN detected in final loss")
                break
        
        # Gradient checking
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        
        # Check gradients before optimizer step
        for name, param in unet.named_parameters():
            if param.grad is not None:
                if not check_tensor(param.grad, f"gradient_{name}"):
                    print(f"NaN gradient detected in {name}")
                    break
        
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()
        
        pbar.comment = f"epoch={epoch}, vae_loss={vae_loss.item():2.3f}, diffusion_loss={diffusion_loss.item():2.3f}"

    # with torch.no_grad():
    #     val_latents = encode_frames(val_batch, vae)
    #     past_val_frames = val_latents[:, :, :-1]
    #     last_val_frame = val_latents[:, :, -1]  # intentionally remove time dim
    #     past_val_frames = past_val_frames.permute(0, 2, 1, 3, 4).reshape(val_latents.shape[0], -1, val_latents.shape[3], val_latents.shape[4]).float()
    #     samples = sampler(unet, past_frames=past_val_frames, num_channels=4)




    # # log predictions
    # if epoch % config.log_every_epoch == 0:
    #     samples = ...
    #     log_images(val_batch, samples)

# save_model(vae, config.model_name)




img_batch stats: min=-0.940, max=0.930, mean=0.281, std=0.457
NaN detected in latents
NaN detected after VAE encoding


img_batch stats: min=-0.957, max=0.930, mean=0.190, std=0.533
NaN detected in latents
NaN detected after VAE encoding


img_batch stats: min=-0.978, max=1.000, mean=0.168, std=0.572
NaN detected in latents
NaN detected after VAE encoding


img_batch stats: min=-0.977, max=0.916, mean=0.090, std=0.594
NaN detected in latents
NaN detected after VAE encoding


img_batch stats: min=-0.977, max=0.893, mean=0.081, std=0.603
NaN detected in latents
NaN detected after VAE encoding


img_batch stats: min=-0.978, max=0.925, mean=0.139, std=0.594
NaN detected in latents
NaN detected after VAE encoding


img_batch stats: min=-0.977, max=0.927, mean=0.124, std=0.617
NaN detected in latents
NaN detected after VAE encoding


img_batch stats: min=-0.978, max=0.934, mean=0.073, std=0.678
NaN detected in latents
NaN detected after VAE encoding


img_batch stats: min=-0.978, max=0.893, mean=0.059, std=0.626
NaN detected in latents
NaN detected after VAE encoding


img_batch stats: min=-0.977, max=0.934, mean=0.094, std=0.595
NaN detected in latents
NaN detected after VAE encoding


img_batch stats: min=-0.992, max=0.904, mean=0.115, std=0.607
NaN detected in latents
NaN detected after VAE encoding


img_batch stats: min=-0.977, max=0.904, mean=0.090, std=0.637
NaN detected in latents
NaN detected after VAE encoding


img_batch stats: min=-0.971, max=0.905, mean=0.079, std=0.565
NaN detected in latents
NaN detected after VAE encoding


img_batch stats: min=-0.977, max=0.915, mean=0.049, std=0.667
NaN detected in latents
NaN detected after VAE encoding


img_batch stats: min=-0.977, max=0.935, mean=0.144, std=0.574
NaN detected in latents
NaN detected after VAE encoding


img_batch stats: min=-0.978, max=0.921, mean=0.098, std=0.616
NaN detected in latents
NaN detected after VAE encoding


img_batch stats: min=-0.941, max=0.938, mean=0.178, std=0.579
NaN detected in latents
NaN detected after VAE encoding


img_batch stats: min=-0.977, max=0.919, mean=0.117, std=0.671
NaN detected in latents
NaN detected after VAE encoding


img_batch stats: min=-0.978, max=0.902, mean=0.066, std=0.660
NaN detected in latents
NaN detected after VAE encoding


img_batch stats: min=-0.942, max=0.935, mean=0.147, std=0.526
NaN detected in latents
NaN detected after VAE encoding


img_batch stats: min=-0.977, max=1.000, mean=0.124, std=0.637
NaN detected in latents
NaN detected after VAE encoding


img_batch stats: min=-0.940, max=0.932, mean=0.128, std=0.519
NaN detected in latents
NaN detected after VAE encoding


img_batch stats: min=-0.977, max=0.885, mean=0.002, std=0.648
NaN detected in latents
NaN detected after VAE encoding


img_batch stats: min=-0.971, max=0.935, mean=0.114, std=0.549
NaN detected in latents
NaN detected after VAE encoding


img_batch stats: min=-0.955, max=0.951, mean=0.180, std=0.584
NaN detected in latents
NaN detected after VAE encoding


img_batch stats: min=-0.962, max=0.912, mean=0.142, std=0.605
NaN detected in latents
NaN detected after VAE encoding


img_batch stats: min=-0.961, max=0.974, mean=0.182, std=0.550
NaN detected in latents
NaN detected after VAE encoding


img_batch stats: min=-0.978, max=0.914, mean=0.064, std=0.653
NaN detected in latents
NaN detected after VAE encoding


img_batch stats: min=-0.943, max=0.962, mean=0.103, std=0.548
NaN detected in latents
NaN detected after VAE encoding


img_batch stats: min=-0.936, max=0.930, mean=0.218, std=0.505
NaN detected in latents
NaN detected after VAE encoding


img_batch stats: min=-0.948, max=0.935, mean=0.221, std=0.491
NaN detected in latents
NaN detected after VAE encoding


img_batch stats: min=-0.977, max=0.909, mean=0.056, std=0.644
NaN detected in latents
NaN detected after VAE encoding


img_batch stats: min=-0.977, max=0.897, mean=0.040, std=0.659
NaN detected in latents
NaN detected after VAE encoding


img_batch stats: min=-0.977, max=0.919, mean=0.076, std=0.673
NaN detected in latents
NaN detected after VAE encoding


img_batch stats: min=-0.977, max=0.923, mean=0.088, std=0.591
NaN detected in latents
NaN detected after VAE encoding


img_batch stats: min=-0.978, max=0.929, mean=0.148, std=0.609
NaN detected in latents
NaN detected after VAE encoding


img_batch stats: min=-0.943, max=0.942, mean=0.217, std=0.540
NaN detected in latents
NaN detected after VAE encoding


img_batch stats: min=-0.978, max=0.922, mean=0.157, std=0.583
NaN detected in latents
NaN detected after VAE encoding


img_batch stats: min=-0.977, max=0.914, mean=0.079, std=0.639
NaN detected in latents
NaN detected after VAE encoding


img_batch stats: min=-0.944, max=0.955, mean=0.175, std=0.553
NaN detected in latents
NaN detected after VAE encoding


img_batch stats: min=-0.935, max=0.947, mean=0.249, std=0.483
NaN detected in latents
NaN detected after VAE encoding


img_batch stats: min=-0.977, max=0.940, mean=0.136, std=0.570
NaN detected in latents
NaN detected after VAE encoding


img_batch stats: min=-0.973, max=0.934, mean=0.223, std=0.556
NaN detected in latents
NaN detected after VAE encoding


img_batch stats: min=-0.977, max=0.916, mean=0.110, std=0.591
NaN detected in latents
NaN detected after VAE encoding


img_batch stats: min=-0.977, max=0.913, mean=0.047, std=0.662
NaN detected in latents
NaN detected after VAE encoding


img_batch stats: min=-0.977, max=0.930, mean=0.102, std=0.643
NaN detected in latents
NaN detected after VAE encoding


img_batch stats: min=-0.977, max=0.915, mean=0.090, std=0.653
NaN detected in latents
NaN detected after VAE encoding


img_batch stats: min=-0.977, max=0.916, mean=0.049, std=0.646
NaN detected in latents
NaN detected after VAE encoding


img_batch stats: min=-0.977, max=0.910, mean=0.099, std=0.621
NaN detected in latents
NaN detected after VAE encoding


img_batch stats: min=-0.979, max=0.941, mean=0.127, std=0.645
NaN detected in latents
NaN detected after VAE encoding


In [6]:
def debug_vae_distribution(vae, input_tensor):
    with torch.no_grad():
        # Get encoder output
        encoder_out = vae.encoder(input_tensor)
        print("\nEncoder output stats:")
        print(f"Shape: {encoder_out.shape}")
        print(f"Range: [{encoder_out.min():.3f}, {encoder_out.max():.3f}]")
        print(f"Mean: {encoder_out.mean():.3f}")
        
        # Debug the moment computation
        # In most VAEs, this is where mean and logvar are computed
        try:
            # Assuming the VAE has these components - adjust based on your modification
            moments = vae.encoder.conv_out(encoder_out)
            print("\nMoments output stats:")
            print(f"Shape: {moments.shape}")
            print(f"Range: [{moments.min():.3f}, {moments.max():.3f}]")
            print(f"Mean: {moments.mean():.3f}")
            
            # Split into mean and logvar
            mean, logvar = torch.chunk(moments, 2, dim=1)
            print("\nMean stats:")
            print(f"Range: [{mean.min():.3f}, {mean.max():.3f}]")
            print(f"Has NaN: {torch.isnan(mean).any()}")
            
            print("\nLogvar stats:")
            print(f"Range: [{logvar.min():.3f}, {logvar.max():.3f}]")
            print(f"Has NaN: {torch.isnan(logvar).any()}")
            
            # Check exp(logvar/2) operation
            std = torch.exp(0.5 * logvar)
            print("\nStd stats:")
            print(f"Range: [{std.min():.3f}, {std.max():.3f}]")
            print(f"Has NaN: {torch.isnan(std).any()}")
            
        except Exception as e:
            print(f"Error in moment computation: {str(e)}")
            import traceback
            print(traceback.format_exc())

vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae").to(device)
# Test with normalized input
test_input = torch.randn(1, 11, 64, 64).to(device)
test_input = torch.tanh(test_input)  # Ensure [-1, 1] range
with torch.autocast(device):
    debug_vae_distribution(vae, test_input)

RuntimeError: Given groups=1, weight of size [128, 3, 3, 3], expected input[1, 11, 64, 64] to have 3 channels, but got 11 channels instead

In [None]:
from torch import nn

class VAEChannelAdapter(nn.Module):
    """
    Adapter module to convert arbitrary channel inputs to work with SDXL VAE.
    
    Key design principles:
    1. Gradual channel reduction to preserve information
    2. Careful normalization for training stability
    3. Bounded outputs to match VAE expectations
    4. Separation of concerns between adaptation and encoding
    """
    def __init__(self, vae, in_channels=11):
        super().__init__()
        self.vae = vae
        
        # Input adapter network: converts in_channels -> 3 channels
        # Architecture designed for stability and information preservation
        self.in_adapter = nn.Sequential(
            # Layer 1: Initial dimension expansion and processing
            # - Expand to 32 channels to preserve information capacity
            # - 3x3 conv maintains spatial context
            # - Padding=1 preserves spatial dimensions
            nn.Conv2d(in_channels, 32, 3, padding=1),
            # GroupNorm with 8 groups (4 channels per group)
            # - Batch-size independent normalization
            # - More stable than BatchNorm or LayerNorm for image data
            nn.GroupNorm(8, 32),
            # SiLU activation
            # - Smooth gradients
            # - No vanishing gradient issues
            # - Better performance than ReLU for vision tasks
            nn.SiLU(),
            
            # Layer 2: Intermediate processing
            # - Reduce channels gradually (32 -> 16)
            # - Maintain spatial dimensions
            nn.Conv2d(32, 16, 3, padding=1),
            # GroupNorm with 4 groups (4 channels per group)
            # - Groups reduced to maintain consistent channels per group
            nn.GroupNorm(4, 16),
            nn.SiLU(),
            
            # Layer 3: Final mapping to RGB
            # - Convert to 3 channels for VAE input
            # - Maintain spatial dimensions
            nn.Conv2d(16, 3, 3, padding=1),
            # Tanh activation
            # - Forces output to [-1, 1] range
            # - Matches VAE's expected input distribution
            # - Prevents extreme values
            nn.Tanh()
        )
        
        # Freeze VAE parameters
        # - Prevents modification of pretrained weights
        # - Ensures stability of latent space
        # - Reduces training complexity
        for param in self.vae.parameters():
            param.requires_grad = False
    
    def encode(self, x):
        """
        Convert input to 3 channels and encode with VAE.
        
        Process:
        1. Transform input channels to RGB-like space
        2. Use pretrained VAE encoder
        
        Returns VAE's latent distribution for sampling
        """
        x = self.in_adapter(x)  # Convert to 3 channels
        return self.vae.encode(x)  # Use original VAE encoder
    
    def decode(self, z):
        """
        Decode latents using VAE decoder.
        
        Note: No modification needed here since:
        - Decoder operates in latent space
        - Output is already in desired format
        """
        return self.vae.decode(z)
    
    @staticmethod
    def scale_latents(latents, vae, encode=True):
        """
        Scale latents by VAE's scaling factor.
        
        Args:
            latents: Tensor to scale
            vae: VAE model containing scaling factor
            encode: If True, scale for encoding (multiply)
                   If False, scale for decoding (divide)
        """
        scaling_factor = vae.config.scaling_factor
        if encode:
            return latents * scaling_factor
        return latents / scaling_factor

# Example usage:
"""
# Create and move to device
adapted_vae = VAEChannelAdapter(original_vae, in_channels=11).to(device)

# Encoding
x = torch.randn(1, 11, 64, 64).to(device)  # 11-channel input
encoded = adapted_vae.encode(x)
latents = encoded.latent_dist.sample()
latents = VAEChannelAdapter.scale_latents(latents, original_vae, encode=True)

# Decoding
scaled_latents = VAEChannelAdapter.scale_latents(latents, original_vae, encode=False)
decoded = adapted_vae.decode(scaled_latents).sample
"""
vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae").to(device)
# Create the adapted VAE
adapted_vae = VAEChannelAdapter(vae, in_channels=11).to(device)

# Encoding
x = torch.randn(1, 11, 64, 64).to(device)  # 11-channel input
encoded = adapted_vae.encode(x)
latents = encoded.latent_dist.sample()
latents = VAEChannelAdapter.scale_latents(latents, vae, encode=True)

# Decoding
scaled_latents = VAEChannelAdapter.scale_latents(latents, vae, encode=False)
decoded = adapted_vae.decode(scaled_latents).sample

In [None]:
decoded

In [None]:

# Test function
def test_adapted_vae(adapted_vae, input_tensor):
    print(f"\nInput shape: {input_tensor.shape}")
    print(f"Input range: [{input_tensor.min():.3f}, {input_tensor.max():.3f}]")
    
    with torch.no_grad():
        try:
            # Test the whole encode path
            encoded = adapted_vae.encode(input_tensor)
            latents = encoded.latent_dist.sample() * vae.config.scaling_factor
            
            print("\nLatents:")
            print(f"Shape: {latents.shape}")
            print(f"Range: [{latents.min():.3f}, {latents.max():.3f}]")
            print(f"Has NaN: {torch.isnan(latents).any()}")
            
            # Test decoding too
            decoded = adapted_vae.decode(latents * 1 / vae.config.scaling_factor).sample
            print("\nDecoded:")
            print(f"Shape: {decoded.shape}")
            print(f"Range: [{decoded.min():.3f}, {decoded.max():.3f}]")
            print(f"Has NaN: {torch.isnan(decoded).any()}")
            
        except Exception as e:
            print(f"Error: {str(e)}")
            import traceback
            print(traceback.format_exc())

# Test it
test_input = torch.randn(1, 11, 64, 64).to(device)
test_input = torch.tanh(test_input)  # Ensure reasonable input range
test_adapted_vae(adapted_vae, test_input)

In [None]:


visualize_channels_over_time(X_latents.cpu())

In [None]:
def visualize_channels_over_time(images,  batch_idx=0, figsize=(12, 8), cmap='viridis'):
    """
    Visualize multi-channel images over time, handling both single and multiple timesteps.
    
    Args:
        images: Tensor of shape (batch, channels, time, height, width)
        channel_names: List of names for each channel
        batch_idx: Which batch element to visualize
        figsize: Size of the figure
        cmap: Colormap to use for visualization
    """
    n_channels = images.shape[1]
    n_timesteps = images.shape[2]
    

    
    # Create a grid of subplots
    if n_timesteps == 1:
        fig, axes = plt.subplots(n_channels, 1, figsize=figsize)
        axes = axes.reshape(-1, 1)  # Reshape to 2D array for consistent indexing
    else:
        fig, axes = plt.subplots(n_channels, n_timesteps, figsize=figsize)
    
    # Set the spacing between subplots
    plt.subplots_adjust(hspace=0.2, wspace=-0.5)
    
    # Iterate through channels and timesteps
    for channel in range(n_channels):
        for timestep in range(n_timesteps):
            # Get the current image
            img = images[batch_idx, channel, timestep].numpy()
            
            # Normalize the image for better visualization
            # img_min = img.min()
            # img_max = img.max()
            # if img_max > img_min:
            #     img = (img - img_min) / (img_max - img_min)
            
            # Plot the image
            im = axes[channel, timestep].imshow(img, cmap=cmap, origin='lower')
            axes[channel, timestep].axis('off')
            
            # Add colorbar
            plt.colorbar(im, ax=axes[channel, timestep], fraction=0.046, pad=0.04)
            
            # Add titles
            if channel == 0:
                axes[channel, timestep].set_title(f'frame {timestep-3}')
            if timestep == 0:
                axes[channel, timestep].text(-10, 32, channel_names[channel], 
                                          rotation=0, ha='right', va='center')

    return fig
visualize_channels_over_time(X);

In [None]:
def match_sizes(ground_truth, prediction, method='crop'):
    """
    Match the sizes of ground truth and prediction arrays.
    
    Args:
        ground_truth: Tensor of shape (batch, channels, height1, width1)
        prediction: Tensor of shape (batch, channels, height2, width2)
        method: One of ['crop', 'pad', 'resize']
    
    Returns:
        Tuple of tensors with matched sizes
    """
    import torch.nn.functional as F
    
    if method == 'crop':
        # Crop ground truth to prediction size
        h_diff = ground_truth.shape[2] - prediction.shape[2]
        w_diff = ground_truth.shape[3] - prediction.shape[3]
        
        h_start = h_diff // 2
        w_start = w_diff // 2
        
        gt_matched = ground_truth[:, :, 
                                h_start:h_start + prediction.shape[2],
                                w_start:w_start + prediction.shape[3]]
        pred_matched = prediction
        
    elif method == 'pad':
        # Pad prediction to ground truth size
        h_diff = ground_truth.shape[2] - prediction.shape[2]
        w_diff = ground_truth.shape[3] - prediction.shape[3]
        
        h_pad = (h_diff // 2, h_diff - h_diff // 2)
        w_pad = (w_diff // 2, w_diff - w_diff // 2)
        
        pred_matched = F.pad(prediction, (w_pad[0], w_pad[1], h_pad[0], h_pad[1]))
        gt_matched = ground_truth
        
    elif method == 'resize':
        # Resize prediction to ground truth size using bilinear interpolation
        pred_matched = F.interpolate(prediction, 
                                   size=(ground_truth.shape[2], ground_truth.shape[3]),
                                   mode='bilinear',
                                   align_corners=False)
        gt_matched = ground_truth
        
    else:
        raise ValueError("method must be one of ['crop', 'pad', 'resize']")
        
    return gt_matched, pred_matched

def visualize_prediction_comparison(ground_truth, prediction, batch_idx=0, figsize=(20, 8), 
                                  cmap='viridis', diff_cmap='RdBu_r', size_match='crop'):
    """
    Visualize ground truth, prediction, and their differences side by side for each channel.
    Maintains original data ranges for proper visualization.
    """
    # Match sizes
    gt_matched, pred_matched = match_sizes(ground_truth, prediction, method=size_match)
    n_channels = ground_truth.shape[1]

    pred_matched[torch.isnan(gt_matched)] = torch.nan

    # Create a grid of subplots
    fig, axes = plt.subplots(n_channels, 3, figsize=figsize)

    # Set the spacing between subplots
    plt.subplots_adjust(hspace=0.2, wspace=-0.8)
    
    # Iterate through channels
    for channel in range(n_channels):
        # Get images
        gt_img = gt_matched[batch_idx, channel].numpy()
        pred_img = pred_matched[batch_idx, channel].numpy()
        diff_img = pred_img - gt_img
        
        # For difference plot, use symmetric limits around zero
        diff_bound = max(abs(diff_img.min()), abs(diff_img.max()))
        diff_bound = max(diff_bound, 0.1)  # Ensure some minimum range for visualization
        
        # Plot ground truth with original range
        im_gt = axes[channel, 0].imshow(gt_img, cmap=cmap, origin='lower')
        axes[channel, 0].axis('off')
        plt.colorbar(im_gt, ax=axes[channel, 0], fraction=0.046, pad=0.04)
        
        # Plot prediction with same range as ground truth
        im_pred = axes[channel, 1].imshow(pred_img, cmap=cmap, origin='lower')
        axes[channel, 1].axis('off')
        plt.colorbar(im_pred, ax=axes[channel, 1], fraction=0.046, pad=0.04)
        
        # Plot difference with symmetric range
        im_diff = axes[channel, 2].imshow(diff_img, cmap=diff_cmap, origin='lower',
                                        vmin=-diff_bound, vmax=diff_bound)
        axes[channel, 2].axis('off')
        plt.colorbar(im_diff, ax=axes[channel, 2], fraction=0.046, pad=0.04)
        
        # Add channel name on the left
        axes[channel, 0].text(-10, 32, channel_names[channel],
                            rotation=0, ha='right', va='center')
    
    # Add column headers
    axes[0, 0].set_title(f'Ground Truth ({gt_matched.shape[2]}x{gt_matched.shape[3]})')
    axes[0, 1].set_title(f'Prediction ({pred_matched.shape[2]}x{pred_matched.shape[3]})')
    axes[0, 2].set_title('Difference (Pred - GT)')
    return fig

visualize_prediction_comparison(img_batch[:, :,  :, :], res[:, :, :, :], size_match="resize");