In [25]:
import sys
sys.path.append("..") #Parent 
sys.path.append("../..") #grandparent
import torch
from models.components.ldm.denoiser import UNetModel
import random
import matplotlib.pyplot as plt
import numpy as np
import torch.nn.functional as F
import json
import config
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import json



In [26]:
from models.components.unet import DownscalingUnetLightning 
from models.ae_module import AutoencoderKL
from models.components.ae import SimpleConvEncoder, SimpleConvDecoder
from models.components.ldm.denoiser.ddim import DDIMSampler
from models.ldm_module import LatentDiffusion
from DownscalingDataModule import DownscalingDataModule

Instantiating UNet

In [27]:
ckpt_unet = "trained_ckpts/Training_LDM.models.components.unet.DownscalingUnetLightning_checkpoint.ckpt"

model_UNet = DownscalingUnetLightning(
    in_ch=5,  # 4 vars + elevation
    out_ch=4,  # 4 output variables
    features=[64, 128, 256, 512],
    channel_names=["precip", "temp", "temp_min", "temp_max"]
)

In [28]:
unet_state_dict = torch.load(ckpt_unet, map_location="cpu")["state_dict"]
model_UNet.load_state_dict(unet_state_dict, strict=False)
model_UNet.eval()

DownscalingUnetLightning(
  (unet): DownscalingUnet(
    (e1): EncoderBlock(
      (conv): DoubleConv(
        (conv): Sequential(
          (0): Conv2d(5, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU()
          (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (5): ReLU()
        )
      )
      (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (e2): EncoderBlock(
      (conv): DoubleConv(
        (conv): Sequential(
          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU()
          (3): Conv2d(128, 128, kernel_size=(3,

Instantiating VAE for residual encoding 

In [31]:
ckpt_vae = "trained_ckpts/Training_LDM.models.ae_module.AutoencoderKL_checkpoint.ckpt"


In [34]:
# Debug: Load checkpoint to see the original configuration
checkpoint = torch.load(ckpt_vae, map_location="cpu")

In [36]:
encoder = SimpleConvEncoder(in_dim=4, levels=1, min_ch=64, ch_mult=1)
decoder = SimpleConvDecoder(in_dim=64, levels=1, min_ch=16)  # Changed from levels=2 to levels=1
model_VAE = AutoencoderKL.load_from_checkpoint(
    checkpoint,
    encoder=encoder,
    decoder=decoder,
    kl_weight=0.01,
    strict=False 
)
model_VAE.eval()

AttributeError: 'dict' object has no attribute 'seek'. You can only torch.load from a file that is seekable. Please pre-load the data into a buffer like io.BytesIO and try to load from it instead.

Latent denoising 

Denoising in latent space 

In [13]:
ckpt_ldm = "trained_ckpts/LDM_checkpoint.ckpt"

# Denoiser for LDM 
denoiser = UNetModel(
    in_channels=32,  # VAE latent dim
    out_channels=32,
    model_channels=64,
    num_res_blocks=2,
    attention_resolutions=[1, 2, 4],
    context_ch=None,
    channel_mult=[1, 2, 4, 4],
    conv_resample=True,
    dims=2,
    use_fp16=False,
    num_heads=4
)

model_LDM = LatentDiffusion(
    denoiser=denoiser,
    autoencoder=model_VAE,
    timesteps=1000,
    beta_schedule="linear",
    loss_type="l2",
    use_ema=True,
    lr=1e-4,
    lr_warmup=0,
    linear_start=1e-4,
    linear_end=2e-2,
    cosine_s=8e-3,
    parameterization="eps"
)

remapping of keys for overcoming the error 

In [16]:
ldm_ckpt = torch.load(ckpt_ldm, map_location="cpu")
remapped_ldm_state_dict = {}
for k, v in ldm_ckpt["state_dict"].items():
    if k.startswith("autoencoder.unet_regr.unet."):
        new_key = "autoencoder.unet." + k[len("autoencoder.unet_regr.unet."):]
    elif k.startswith("autoencoder.unet_regr."):
        new_key = "autoencoder.unet." + k[len("autoencoder.unet_regr."):]
    else:
        new_key = k
    remapped_ldm_state_dict[new_key] = v


In [17]:
model_LDM.load_state_dict(remapped_ldm_state_dict, strict=False)
model_LDM.eval()

LatentDiffusion(
  (denoiser): UNetModel(
    (time_embed): Sequential(
      (0): Linear(in_features=64, out_features=256, bias=True)
      (1): SiLU()
      (2): Linear(in_features=256, out_features=256, bias=True)
    )
    (input_blocks): ModuleList(
      (0): TimestepEmbedSequential(
        (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (1-2): 2 x TimestepEmbedSequential(
        (0): ResBlock(
          (in_layers): Sequential(
            (0): Identity()
            (1): SiLU()
            (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
          (emb_layers): Sequential(
            (0): SiLU()
            (1): Linear(in_features=256, out_features=64, bias=True)
          )
          (out_layers): Sequential(
            (0): Identity()
            (1): SiLU()
            (2): Dropout(p=0, inplace=False)
            (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
      

#Inference goes from Unet----VAE----denosiing within VAE

In [None]:
def pipeline(input_sample, target_sample=None):  
    with torch.no_grad():
        # Mean pred with unet
        unet_prediction = model_UNet(input_sample)
        
        #Residuals for denoising
        if target_sample is not None:
            residuals = target_sample - unet_prediction  # (1, 4, H, W)
        
        #VAE encoding
        mean, log_var = model_VAE.encode(residuals)
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        latent = mean + eps * std  # (1, 64, H', W')
        
        #Denoising using LDM
        sampler = DDIMSampler(model_LDM)
        shape = latent.shape[1:]
        
        denoised_latent, _ = sampler.sample(
            S=25,
            batch_size=1,
            shape=shape,
            x_T=latent,  # encoded residual latent
            eta=0.2,     # stochasticity
            verbose=False,
            progbar=True
        )
        
        # Decoding to pixel space
        refined_residuals = model_VAE.decode(denoised_latent)  # (1, 4, H, W)
        
        # Unet plus refined residuals
        final_prediction = unet_prediction + refined_residuals
        
        return {
            'unet_prediction': unet_prediction,
            'original_residuals': residuals,
            'latent_encoded': latent,
            'denoised_latent': denoised_latent,
            'refined_residuals': refined_residuals,
            'final_prediction': final_prediction
        }

Datasets for dtaamodule

In [None]:
train_input_paths = {
    'precip': f'{config.DATASETS_TRAINING_DIR}/RhiresD_input_test_chronological_scaled.nc',
    'temp': f'{config.DATASETS_TRAINING_DIR}/TabsD_input_test_chronological_scaled.nc',
    'temp_min': f'{config.DATASETS_TRAINING_DIR}/TminD_input_test_chronological_scaled.nc',
    'temp_max': f'{config.DATASETS_TRAINING_DIR}/TmaxD_input_test_chronological_scaled.nc'
}

train_target_paths = {
    'precip': f'{config.DATASETS_TRAINING_DIR}/RhiresD_target_test_chronological_scaled.nc',
    'temp': f'{config.DATASETS_TRAINING_DIR}/TabsD_target_test_chronological_scaled.nc',
    'temp_min': f'{config.DATASETS_TRAINING_DIR}/TminD_target_test_chronological_scaled.nc',
    'temp_max': f'{config.DATASETS_TRAINING_DIR}/TmaxD_target_test_chronological_scaled.nc'
}

elevation_path = f'{config.BASE_DIR}/sasthana/Downscaling/Downscaling_Models/elevation.tif'

dm = DownscalingDataModule(
    train_input=train_input_paths,
    train_target=train_target_paths,
    elevation=elevation_path,
    batch_size=32,
    num_workers=4,
    preprocessing={
        'variables': {
            'input': {
                'precip': 'RhiresD',
                'temp': 'TabsD', 
                'temp_min': 'TminD',
                'temp_max': 'TmaxD'
            },
            'target': {
                'precip': 'RhiresD',
                'temp': 'TabsD',
                'temp_min': 'TminD', 
                'temp_max': 'TmaxD'
            }
        },
        'preprocessing': {
            'nan_to_num': True,
            'nan_value': 0.0
        }
    }
)

# Setup the data module
dm.setup('fit')


In [21]:
train_loader = dm.train_dataloader()
train_batch = next(iter(train_loader))
train_inputs, train_targets = train_batch

In [23]:
idx = 20
input_sample = train_inputs[idx].unsqueeze(0)  # (1, 5, H, W)
target_sample = train_targets[idx].unsqueeze(0)  # (1, 4, H, W)

results = pipeline(input_sample, target_sample)

Data shape for DDIM sampling is (1, 32, 120, 185), eta 0.2
Running DDIM Sampling with 25 timesteps


DDIM Sampler: 100%|██████████| 25/25 [26:13<00:00, 62.95s/it]


RuntimeError: The size of tensor a (370) must match the size of tensor b (740) at non-singleton dimension 3

Denormalisation and plotting  : 


In [None]:
#Denorm function

In [24]:
with open(f'{config.DATASETS_TRAINING_DIR}/RhiresD_scaling_params_chronological.json', 'r') as f:
    pr_params = json.load(f)
with open(f'{config.DATASETS_TRAINING_DIR}/TabsD_scaling_params_chronological.json', 'r') as f:
    temp_params = json.load(f)
with open(f'{config.DATASETS_TRAINING_DIR}/TminD_scaling_params_chronological.json', 'r') as f:
    temp_min_params = json.load(f)
with open(f'{config.DATASETS_TRAINING_DIR}/TmaxD_scaling_params_chronological.json', 'r') as f:
    temp_max_params = json.load(f)

def denorm_pr(x):

    return x * (pr_params['max'] - pr_params['min']) + pr_params['min']

def denorm_temp(x, params):

    return x * params['std'] + params['mean']


denorm and plotting 

In [None]:
def denorm_plot(results, input_sample, target_sample):    
    input_np = input_sample[0, :4].cpu().numpy() #Elevation removed, was not needed for plotting 
    target_np = target_sample[0].cpu().numpy()
    unet_np = results['unet_prediction'][0].cpu().numpy()
    final_np = results['final_prediction'][0].cpu().numpy()
    residuals_np = results['refined_residuals'][0].cpu().numpy()
    
    input_denorm = np.empty_like(input_np)
    target_denorm = np.empty_like(target_np)
    unet_denorm = np.empty_like(unet_np)
    final_denorm = np.empty_like(final_np)
    residuals_denorm = np.empty_like(residuals_np)
    
    for i, (var, params) in enumerate([
        ("precip", pr_params),
        ("temp", temp_params), 
        ("temp_min", temp_min_params),
        ("temp_max", temp_max_params)
    ]):
        if var == "precip":
            input_denorm[i] = denorm_pr(input_np[i])
            target_denorm[i] = denorm_pr(target_np[i])
            unet_denorm[i] = denorm_pr(unet_np[i])
            final_denorm[i] = denorm_pr(final_np[i])
            residuals_denorm[i] = denorm_pr(residuals_np[i])
        else:
            input_denorm[i] = denorm_temp(input_np[i], params)
            target_denorm[i] = denorm_temp(target_np[i], params)
            unet_denorm[i] = denorm_temp(unet_np[i], params)
            final_denorm[i] = denorm_temp(final_np[i], params)
            residuals_denorm[i] = denorm_temp(residuals_np[i], params)

    channel_names = ["Precip", "Temp", "Min Temp", "Max Temp"]
    fig, axes = plt.subplots(4, 5, figsize=(25, 20))
    
    for i in range(4): #5 samples
        # bicubic IP
        axes[i, 0].imshow(input_denorm[i], cmap='coolwarm')
        axes[i, 0].set_title(f"Input - {channel_names[i]}")
        
        # UNet
        axes[i, 1].imshow(unet_denorm[i], cmap='coolwarm')
        axes[i, 1].set_title(f"UNet Mean - {channel_names[i]}")
        
        # Refined Residuals
        axes[i, 2].imshow(residuals_denorm[i], cmap='RdBu_r')
        axes[i, 2].set_title(f"Refined Residuals - {channel_names[i]}")
        
        # Final pred by Adding Unet mean
        axes[i, 3].imshow(final_denorm[i], cmap='coolwarm')
        axes[i, 3].set_title(f"Final Prediction - {channel_names[i]}")
        
        # ground truth
        axes[i, 4].imshow(target_denorm[i], cmap='coolwarm')
        axes[i, 4].set_title(f"Ground Truth - {channel_names[i]}")
    
    plt.tight_layout()
    plt.show()

denorm_plot(results, input_sample, target_sample)