In [2]:
from types import SimpleNamespace

import wandb
import torch

from cloudcasting.constants import NUM_CHANNELS, DATA_INTERVAL_SPACING_MINUTES

from cloud_diffusion.dataset import CloudcastingDataset
from cloud_diffusion.utils import set_seed
from cloud_diffusion.ddpm import noisify_ddpm, ddim_sampler
from cloud_diffusion.models import UNet2D
from cloud_diffusion.vae import get_hacked_vae#, encode_frames, decode_frames


from torch.optim import AdamW
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import OneCycleLR
import torch.nn.functional as F

from fastprogress import progress_bar

import matplotlib.pyplot as plt
import numpy as np

channel_names = ['IR_016', 'IR_039', 'IR_087', 'IR_097', 'IR_108', 'IR_120', 'IR_134',
       'VIS006', 'VIS008', 'WV_062', 'WV_073']

PROJECT_NAME = "nathan-test"
MERGE_CHANNELS = False
DEBUG = True
LOCAL = False

config = SimpleNamespace(
    img_size=256,
    epochs=50,  # number of epochs
    model_name="uvit-test",  # model name to save [unet_small, unet_big]
    strategy="ddpm",  # strategy to use [ddpm, simple_diffusion]
    noise_steps=1000,  # number of noise steps on the diffusion process
    sampler_steps=300,  # number of sampler steps on the diffusion process
    seed=42,  # random seed
    batch_size=2,  # batch size
    device="mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu",
    num_workers=0 if DEBUG else 35,  # number of workers for dataloader
    num_frames=4,  # number of frames to use as input (includes noise frame)
    lr=5e-4,  # learning rate
    log_every_epoch=1,  # log every n epochs to wandb
    n_preds=8,  # number of predictions to make
    latent_dim=4,
    vae_lr=5e-5,
)

device = config.device


HISTORY_STEPS = config.num_frames - 1


config.model_params = dict(
    block_out_channels=(32, 64, 128, 256),  # number of channels for each block
    norm_num_groups=8,  # number of groups for the normalization layer
    in_channels=config.num_frames * config.latent_dim,  # number of input channels
    out_channels=config.latent_dim,  # number of output channels
)

set_seed(config.seed)

if LOCAL:
    TRAINING_DATA_PATH = VALIDATION_DATA_PATH = "/users/nsimpson/Code/climetrend/cloudcast/2020_training_nonhrv.zarr"
else:
    TRAINING_DATA_PATH = "/bask/projects/v/vjgo8416-climate/shared/data/eumetsat/training/2021_nonhrv.zarr"
    VALIDATION_DATA_PATH = "/bask/projects/v/vjgo8416-climate/shared/data/eumetsat/training/2022_training_nonhrv.zarr"
# Instantiate the torch dataset object
train_ds = CloudcastingDataset(
    config.img_size,
    valid=False,
    # strategy="resize",
    zarr_path=TRAINING_DATA_PATH,
    start_time=None,
    end_time=None,
    history_mins=(HISTORY_STEPS - 1) * DATA_INTERVAL_SPACING_MINUTES,
    forecast_mins=15,
    sample_freq_mins=15,
    nan_to_num=False,
    merge_channels=MERGE_CHANNELS,
)
# worth noting they do some sort of shuffling here; we don't for now
valid_ds = CloudcastingDataset(
    config.img_size,
    valid=True,
    # strategy="resize",
    zarr_path=VALIDATION_DATA_PATH,
    start_time=None,
    end_time=None,
    history_mins=(HISTORY_STEPS - 1) * DATA_INTERVAL_SPACING_MINUTES,
    forecast_mins=15,
    sample_freq_mins=15,
    nan_to_num=False,
    merge_channels=MERGE_CHANNELS,
)

train_dataloader = DataLoader(train_ds, config.batch_size, shuffle=True,  num_workers=config.num_workers, pin_memory=True)
valid_dataloader = DataLoader(valid_ds, config.batch_size, shuffle=False, num_workers=config.num_workers, pin_memory=True)

In [3]:
X = next(iter(train_dataloader))
X.shape

torch.Size([2, 11, 4, 256, 256])

In [None]:
def encode_frames(imgs, vae):
    # imgs shape: [batch, channels, time, height, width]
    B, C, T, H, W = imgs.shape
    
    # Reshape to process all frames at once
    imgs_2d = imgs.permute(0, 2, 1, 3, 4).reshape(B*T, C, H, W)#.contiguous()
    
    latents = vae.encode(imgs_2d).latent_dist.sample()
    
    # Reshape back to include time dimension
    _, C_out, H_out, W_out = latents.shape
    return latents.reshape(B, T, C_out, H_out, W_out).permute(0, 2, 1, 3, 4)

def decode_frames(latents, vae):
    # latents shape: [batch, channels, time, height, width]
    B, C, T, H, W = latents.shape
    
    # Reshape to process all frames at once
    latents_2d = latents.permute(0, 2, 1, 3, 4).reshape(B*T, C, H, W)#.contiguous()
    
    # Process through VAE
    decoded = vae.decode(latents_2d).sample
    
    # Reshape back to include time dimension
    _, C_out, H_out, W_out = decoded.shape
    decoded = decoded.reshape(B, T, C_out, H_out, W_out).permute(0, 2, 1, 3, 4)
    
    # Map channels
    return decoded

In [5]:
# model setup
unet = UNet2D(**config.model_params).to(device)
vae = get_hacked_vae().to(device)#.float()
from diffusers.models import AutoencoderKL
# vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae").to(device).float()
# vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae").to(device).float()
# for param in vae.parameters():
#     param.requires_grad = False

# sampler
sampler = ddim_sampler(steps=config.sampler_steps)

# configure training
# wandb.config.update(config)
config.total_train_steps = config.epochs * len(train_dataloader)

# Create parameter groups with different learning rates
param_groups = [
    {
        'params': [p for p in vae.parameters() if p.requires_grad],
        'lr': config.vae_lr,
        'eps': 1e-5
    },
    {
        'params': unet.parameters(),
        'lr': config.lr,
        'eps': 1e-5
    }
]
optimizer = AdamW(param_groups)
scheduler = OneCycleLR(optimizer, max_lr=[config.vae_lr, config.lr], total_steps=config.total_train_steps)
scaler = torch.amp.GradScaler("cuda")

# get a validation batch for logging
val_batch = next(iter(valid_dataloader))[0:2].to(device)  # log first 2 predictions
print(val_batch.shape)

Unfreezing conv_in.weight
Unfreezing conv_in.bias
Unfreezing conv_out.weight
Unfreezing conv_out.bias
torch.Size([2, 11, 4, 256, 256])


In [6]:
X.shape

torch.Size([2, 11, 4, 256, 256])

In [None]:
with torch.autocast(device):
    with torch.no_grad():
        X_latents = encode_frames(X.to(device), vae)

check_tensor(X_latents, "X_latents", print_stats=True)

In [None]:
with torch.autocast(device):
    with torch.no_grad():
        latents = vae.encode(X[:,:,0,:,:].to(device)).latent_dist.sample()

check_tensor(latents, "lat")

In [None]:
check_tensor(X, "x", print_stats=True)

In [None]:
with torch.no_grad():
    latents = vae.half().encode(X[:,:,0,:,:].to(device).half()).latent_dist.sample()

check_tensor(latents, "lat")

In [None]:
import torch
import torch.nn.functional as F

def check_tensor(tensor, name, print_stats=False):
    """Utility function to check tensor for NaN/Inf values and optionally print statistics"""
    if torch.isnan(tensor).any():
        print(f"NaN detected in {name}")
        return False
    if torch.isinf(tensor).any():
        print(f"Inf detected in {name}")
        return False
    if print_stats:
        print(f"{name} stats: min={tensor.min().item():.3f}, max={tensor.max().item():.3f}, "
              f"mean={tensor.mean().item():.3f}, std={tensor.std().item():.3f}")
    return True

# Modified training loop with checks
use_channel_mixer = False
for epoch in progress_bar(range(config.epochs), total=config.epochs, leave=True):
    vae.train() if not use_channel_mixer else vae.eval()
    unet.train()
    pbar = progress_bar(train_dataloader, leave=False)
    for batch in pbar:
        if torch.isnan(batch).all():
            continue
            
        # Check input batch
        batch = torch.nan_to_num(batch, nan=0)
        img_batch = batch.to(device)
        check_tensor(img_batch, "img_batch", print_stats=True)
        
        with torch.autocast(device):
            # VAE encoding check
            latents = encode_frames(img_batch, vae, use_channel_mixer) * 0.18215
            if not check_tensor(latents, "latents", print_stats=True):
                print("NaN detected after VAE encoding")
                break
                
            # Diffusion preparation checks
            past_frames = latents[:, :, :-1]
            last_frame = latents[:, :, -1]
            check_tensor(past_frames, "past_frames")
            check_tensor(last_frame, "last_frame")
            
            # Noise addition check
            noised_img, t, noise = noisify_ddpm(last_frame)
            check_tensor(noised_img, "noised_img")
            check_tensor(noise, "noise")
            
            # Check reshape operations
            past_frames = past_frames.permute(0, 2, 1, 3, 4)
            past_frames = past_frames.reshape(latents.shape[0], -1, latents.shape[3], latents.shape[4])
            check_tensor(past_frames, "reshaped_past_frames")
            
            # Check concatenated input
            diffusion_input = torch.cat([past_frames, noised_img], dim=1)
            check_tensor(diffusion_input, "diffusion_input")
            
            # Check UNet output
            predicted_noise = unet(diffusion_input, t)
            if not check_tensor(predicted_noise, "predicted_noise", print_stats=True):
                print("NaN detected in UNet output")
                break
            
            # Loss calculation checks
            diffusion_loss = F.mse_loss(predicted_noise, noise)
            check_tensor(diffusion_loss, "diffusion_loss")
            
            img_batch_hat = decode_frames(latents * (1 / 0.18215), vae, use_channel_mixer)
            check_tensor(img_batch_hat, "decoded_frames")
            
            vae_loss = F.mse_loss(img_batch_hat, img_batch)
            check_tensor(vae_loss, "vae_loss")
            
            loss = diffusion_loss  # +0.1*vae_loss
            if not check_tensor(loss, "total_loss"):
                print("NaN detected in final loss")
                break
        
        # Gradient checking
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        
        # Check gradients before optimizer step
        for name, param in unet.named_parameters():
            if param.grad is not None:
                if not check_tensor(param.grad, f"gradient_{name}"):
                    print(f"NaN gradient detected in {name}")
                    break
        
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()
        
        pbar.comment = f"epoch={epoch}, vae_loss={vae_loss.item():2.3f}, diffusion_loss={diffusion_loss.item():2.3f}"

    # with torch.no_grad():
    #     val_latents = encode_frames(val_batch, vae)
    #     past_val_frames = val_latents[:, :, :-1]
    #     last_val_frame = val_latents[:, :, -1]  # intentionally remove time dim
    #     past_val_frames = past_val_frames.permute(0, 2, 1, 3, 4).reshape(val_latents.shape[0], -1, val_latents.shape[3], val_latents.shape[4]).float()
    #     samples = sampler(unet, past_frames=past_val_frames, num_channels=4)




    # # log predictions
    # if epoch % config.log_every_epoch == 0:
    #     samples = ...
    #     log_images(val_batch, samples)

# save_model(vae, config.model_name)




In [9]:
def debug_vae_distribution(vae, input_tensor):
    with torch.no_grad():
        # Get encoder output
        encoder_out = vae.encoder(input_tensor)
        print("\nEncoder output stats:")
        print(f"Shape: {encoder_out.shape}")
        print(f"Range: [{encoder_out.min():.3f}, {encoder_out.max():.3f}]")
        print(f"Mean: {encoder_out.mean():.3f}")
        
        # Debug the moment computation
        # In most VAEs, this is where mean and logvar are computed
        try:
            # Assuming the VAE has these components - adjust based on your modification
            moments = vae.encoder.conv_out(encoder_out)
            print("\nMoments output stats:")
            print(f"Shape: {moments.shape}")
            print(f"Range: [{moments.min():.3f}, {moments.max():.3f}]")
            print(f"Mean: {moments.mean():.3f}")
            
            # Split into mean and logvar
            mean, logvar = torch.chunk(moments, 2, dim=1)
            print("\nMean stats:")
            print(f"Range: [{mean.min():.3f}, {mean.max():.3f}]")
            print(f"Has NaN: {torch.isnan(mean).any()}")
            
            print("\nLogvar stats:")
            print(f"Range: [{logvar.min():.3f}, {logvar.max():.3f}]")
            print(f"Has NaN: {torch.isnan(logvar).any()}")
            
            # Check exp(logvar/2) operation
            std = torch.exp(0.5 * logvar)
            print("\nStd stats:")
            print(f"Range: [{std.min():.3f}, {std.max():.3f}]")
            print(f"Has NaN: {torch.isnan(std).any()}")
            
        except Exception as e:
            print(f"Error in moment computation: {str(e)}")
            import traceback
            print(traceback.format_exc())

vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae").to(device)
# Test with normalized input
test_input = torch.randn(1, 11, 64, 64).to(device)
test_input = torch.tanh(test_input)  # Ensure [-1, 1] range
with torch.autocast(device):
    debug_vae_distribution(vae, test_input)

RuntimeError: Given groups=1, weight of size [128, 3, 3, 3], expected input[1, 11, 64, 64] to have 3 channels, but got 11 channels instead

In [14]:
from torch import nn

class VAEChannelAdapter(nn.Module):
    """
    Adapter module to convert arbitrary channel inputs to work with SDXL VAE.
    
    Key design principles:
    1. Gradual channel reduction to preserve information
    2. Careful normalization for training stability
    3. Bounded outputs to match VAE expectations
    4. Separation of concerns between adaptation and encoding
    """
    def __init__(self, vae, in_channels=11):
        super().__init__()
        self.vae = vae
        
        # Input adapter network: converts in_channels -> 3 channels
        # Architecture designed for stability and information preservation
        self.in_adapter = nn.Sequential(
            # Layer 1: Initial dimension expansion and processing
            # - Expand to 32 channels to preserve information capacity
            # - 3x3 conv maintains spatial context
            # - Padding=1 preserves spatial dimensions
            nn.Conv2d(in_channels, 32, 3, padding=1),
            # GroupNorm with 8 groups (4 channels per group)
            # - Batch-size independent normalization
            # - More stable than BatchNorm or LayerNorm for image data
            nn.GroupNorm(8, 32),
            # SiLU activation
            # - Smooth gradients
            # - No vanishing gradient issues
            # - Better performance than ReLU for vision tasks
            nn.SiLU(),
            
            # Layer 2: Intermediate processing
            # - Reduce channels gradually (32 -> 16)
            # - Maintain spatial dimensions
            nn.Conv2d(32, 16, 3, padding=1),
            # GroupNorm with 4 groups (4 channels per group)
            # - Groups reduced to maintain consistent channels per group
            nn.GroupNorm(4, 16),
            nn.SiLU(),
            
            # Layer 3: Final mapping to RGB
            # - Convert to 3 channels for VAE input
            # - Maintain spatial dimensions
            nn.Conv2d(16, 3, 3, padding=1),
            # Tanh activation
            # - Forces output to [-1, 1] range
            # - Matches VAE's expected input distribution
            # - Prevents extreme values
            nn.Tanh()
        )
        
        # Freeze VAE parameters
        # - Prevents modification of pretrained weights
        # - Ensures stability of latent space
        # - Reduces training complexity
        for param in self.vae.parameters():
            param.requires_grad = False
    
    def encode(self, x):
        """
        Convert input to 3 channels and encode with VAE.
        
        Process:
        1. Transform input channels to RGB-like space
        2. Use pretrained VAE encoder
        
        Returns VAE's latent distribution for sampling
        """
        x = self.in_adapter(x)  # Convert to 3 channels
        return self.vae.encode(x)  # Use original VAE encoder
    
    def decode(self, z):
        """
        Decode latents using VAE decoder.
        
        Note: No modification needed here since:
        - Decoder operates in latent space
        - Output is already in desired format
        """
        return self.vae.decode(z)
    
    @staticmethod
    def scale_latents(latents, vae, encode=True):
        """
        Scale latents by VAE's scaling factor.
        
        Args:
            latents: Tensor to scale
            vae: VAE model containing scaling factor
            encode: If True, scale for encoding (multiply)
                   If False, scale for decoding (divide)
        """
        scaling_factor = vae.config.scaling_factor
        if encode:
            return latents * scaling_factor
        return latents / scaling_factor

# Example usage:
"""
# Create and move to device
adapted_vae = VAEChannelAdapter(original_vae, in_channels=11).to(device)

# Encoding
x = torch.randn(1, 11, 64, 64).to(device)  # 11-channel input
encoded = adapted_vae.encode(x)
latents = encoded.latent_dist.sample()
latents = VAEChannelAdapter.scale_latents(latents, original_vae, encode=True)

# Decoding
scaled_latents = VAEChannelAdapter.scale_latents(latents, original_vae, encode=False)
decoded = adapted_vae.decode(scaled_latents).sample
"""
vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae").to(device)
# Create the adapted VAE
adapted_vae = VAEChannelAdapter(vae, in_channels=11).to(device)

# Encoding
x = torch.randn(1, 11, 64, 64).to(device)  # 11-channel input
encoded = adapted_vae.encode(x)
latents = encoded.latent_dist.sample()
latents = VAEChannelAdapter.scale_latents(latents, vae, encode=True)

# Decoding
scaled_latents = VAEChannelAdapter.scale_latents(latents, vae, encode=False)
decoded = adapted_vae.decode(scaled_latents).sample

In [15]:
decoded

tensor([[[[ 0.0332,  0.0755,  0.0393,  ..., -0.1426, -0.0624,  0.0515],
          [-0.0526,  0.0987,  0.1303,  ..., -0.0270, -0.0626, -0.0603],
          [ 0.0886, -0.1302,  0.1758,  ...,  0.0319,  0.2005,  0.0095],
          ...,
          [-0.2033, -0.2689, -0.0618,  ..., -0.1438,  0.1000, -0.3867],
          [-0.1533, -0.0489,  0.0122,  ...,  0.6241, -0.0550, -0.2087],
          [-0.2078, -0.1929,  0.0056,  ...,  0.2452,  0.0038, -0.1768]],

         [[-0.0942,  0.0204, -0.0234,  ..., -0.0860, -0.0347, -0.1457],
          [ 0.0750,  0.1678,  0.1047,  ...,  0.2891,  0.1999, -0.0823],
          [ 0.2409,  0.0357,  0.1196,  ...,  0.2031,  0.4529,  0.0879],
          ...,
          [ 0.0755,  0.3906,  0.0666,  ...,  0.2908,  0.3753, -0.0581],
          [ 0.1437,  0.3112,  0.1630,  ...,  0.5809, -0.0872, -0.0296],
          [-0.0011, -0.0323, -0.0610,  ...,  0.4977,  0.1834,  0.0057]],

         [[-0.3682, -0.2437, -0.3088,  ..., -0.2186, -0.2511, -0.3449],
          [-0.2915, -0.2141, -

In [None]:

# Test function
def test_adapted_vae(adapted_vae, input_tensor):
    print(f"\nInput shape: {input_tensor.shape}")
    print(f"Input range: [{input_tensor.min():.3f}, {input_tensor.max():.3f}]")
    
    with torch.no_grad():
        try:
            # Test the whole encode path
            encoded = adapted_vae.encode(input_tensor)
            latents = encoded.latent_dist.sample() * vae.config.scaling_factor
            
            print("\nLatents:")
            print(f"Shape: {latents.shape}")
            print(f"Range: [{latents.min():.3f}, {latents.max():.3f}]")
            print(f"Has NaN: {torch.isnan(latents).any()}")
            
            # Test decoding too
            decoded = adapted_vae.decode(latents * 1 / vae.config.scaling_factor).sample
            print("\nDecoded:")
            print(f"Shape: {decoded.shape}")
            print(f"Range: [{decoded.min():.3f}, {decoded.max():.3f}]")
            print(f"Has NaN: {torch.isnan(decoded).any()}")
            
        except Exception as e:
            print(f"Error: {str(e)}")
            import traceback
            print(traceback.format_exc())

# Test it
test_input = torch.randn(1, 11, 64, 64).to(device)
test_input = torch.tanh(test_input)  # Ensure reasonable input range
test_adapted_vae(adapted_vae, test_input)

In [None]:


visualize_channels_over_time(X_latents.cpu())

In [None]:
def visualize_channels_over_time(images,  batch_idx=0, figsize=(12, 8), cmap='viridis'):
    """
    Visualize multi-channel images over time, handling both single and multiple timesteps.
    
    Args:
        images: Tensor of shape (batch, channels, time, height, width)
        channel_names: List of names for each channel
        batch_idx: Which batch element to visualize
        figsize: Size of the figure
        cmap: Colormap to use for visualization
    """
    n_channels = images.shape[1]
    n_timesteps = images.shape[2]
    

    
    # Create a grid of subplots
    if n_timesteps == 1:
        fig, axes = plt.subplots(n_channels, 1, figsize=figsize)
        axes = axes.reshape(-1, 1)  # Reshape to 2D array for consistent indexing
    else:
        fig, axes = plt.subplots(n_channels, n_timesteps, figsize=figsize)
    
    # Set the spacing between subplots
    plt.subplots_adjust(hspace=0.2, wspace=-0.5)
    
    # Iterate through channels and timesteps
    for channel in range(n_channels):
        for timestep in range(n_timesteps):
            # Get the current image
            img = images[batch_idx, channel, timestep].numpy()
            
            # Normalize the image for better visualization
            # img_min = img.min()
            # img_max = img.max()
            # if img_max > img_min:
            #     img = (img - img_min) / (img_max - img_min)
            
            # Plot the image
            im = axes[channel, timestep].imshow(img, cmap=cmap, origin='lower')
            axes[channel, timestep].axis('off')
            
            # Add colorbar
            plt.colorbar(im, ax=axes[channel, timestep], fraction=0.046, pad=0.04)
            
            # Add titles
            if channel == 0:
                axes[channel, timestep].set_title(f'frame {timestep-3}')
            if timestep == 0:
                axes[channel, timestep].text(-10, 32, channel_names[channel], 
                                          rotation=0, ha='right', va='center')

    return fig
visualize_channels_over_time(X);

In [None]:
def match_sizes(ground_truth, prediction, method='crop'):
    """
    Match the sizes of ground truth and prediction arrays.
    
    Args:
        ground_truth: Tensor of shape (batch, channels, height1, width1)
        prediction: Tensor of shape (batch, channels, height2, width2)
        method: One of ['crop', 'pad', 'resize']
    
    Returns:
        Tuple of tensors with matched sizes
    """
    import torch.nn.functional as F
    
    if method == 'crop':
        # Crop ground truth to prediction size
        h_diff = ground_truth.shape[2] - prediction.shape[2]
        w_diff = ground_truth.shape[3] - prediction.shape[3]
        
        h_start = h_diff // 2
        w_start = w_diff // 2
        
        gt_matched = ground_truth[:, :, 
                                h_start:h_start + prediction.shape[2],
                                w_start:w_start + prediction.shape[3]]
        pred_matched = prediction
        
    elif method == 'pad':
        # Pad prediction to ground truth size
        h_diff = ground_truth.shape[2] - prediction.shape[2]
        w_diff = ground_truth.shape[3] - prediction.shape[3]
        
        h_pad = (h_diff // 2, h_diff - h_diff // 2)
        w_pad = (w_diff // 2, w_diff - w_diff // 2)
        
        pred_matched = F.pad(prediction, (w_pad[0], w_pad[1], h_pad[0], h_pad[1]))
        gt_matched = ground_truth
        
    elif method == 'resize':
        # Resize prediction to ground truth size using bilinear interpolation
        pred_matched = F.interpolate(prediction, 
                                   size=(ground_truth.shape[2], ground_truth.shape[3]),
                                   mode='bilinear',
                                   align_corners=False)
        gt_matched = ground_truth
        
    else:
        raise ValueError("method must be one of ['crop', 'pad', 'resize']")
        
    return gt_matched, pred_matched

def visualize_prediction_comparison(ground_truth, prediction, batch_idx=0, figsize=(20, 8), 
                                  cmap='viridis', diff_cmap='RdBu_r', size_match='crop'):
    """
    Visualize ground truth, prediction, and their differences side by side for each channel.
    Maintains original data ranges for proper visualization.
    """
    # Match sizes
    gt_matched, pred_matched = match_sizes(ground_truth, prediction, method=size_match)
    n_channels = ground_truth.shape[1]

    pred_matched[torch.isnan(gt_matched)] = torch.nan

    # Create a grid of subplots
    fig, axes = plt.subplots(n_channels, 3, figsize=figsize)

    # Set the spacing between subplots
    plt.subplots_adjust(hspace=0.2, wspace=-0.8)
    
    # Iterate through channels
    for channel in range(n_channels):
        # Get images
        gt_img = gt_matched[batch_idx, channel].numpy()
        pred_img = pred_matched[batch_idx, channel].numpy()
        diff_img = pred_img - gt_img
        
        # For difference plot, use symmetric limits around zero
        diff_bound = max(abs(diff_img.min()), abs(diff_img.max()))
        diff_bound = max(diff_bound, 0.1)  # Ensure some minimum range for visualization
        
        # Plot ground truth with original range
        im_gt = axes[channel, 0].imshow(gt_img, cmap=cmap, origin='lower')
        axes[channel, 0].axis('off')
        plt.colorbar(im_gt, ax=axes[channel, 0], fraction=0.046, pad=0.04)
        
        # Plot prediction with same range as ground truth
        im_pred = axes[channel, 1].imshow(pred_img, cmap=cmap, origin='lower')
        axes[channel, 1].axis('off')
        plt.colorbar(im_pred, ax=axes[channel, 1], fraction=0.046, pad=0.04)
        
        # Plot difference with symmetric range
        im_diff = axes[channel, 2].imshow(diff_img, cmap=diff_cmap, origin='lower',
                                        vmin=-diff_bound, vmax=diff_bound)
        axes[channel, 2].axis('off')
        plt.colorbar(im_diff, ax=axes[channel, 2], fraction=0.046, pad=0.04)
        
        # Add channel name on the left
        axes[channel, 0].text(-10, 32, channel_names[channel],
                            rotation=0, ha='right', va='center')
    
    # Add column headers
    axes[0, 0].set_title(f'Ground Truth ({gt_matched.shape[2]}x{gt_matched.shape[3]})')
    axes[0, 1].set_title(f'Prediction ({pred_matched.shape[2]}x{pred_matched.shape[3]})')
    axes[0, 2].set_title('Difference (Pred - GT)')
    return fig

visualize_prediction_comparison(img_batch[:, :,  :, :], res[:, :, :, :], size_match="resize");