This repo is based on DiffScaler, https://github.com/DSIP-FBK/DiffScaler/tree/main/src

By now you are already acquainted with PyTorch, a fantastic library for writing, modifying, testing and scaling up deep-learning code. In this notebook, let us learn about a handy wrapper library for PyTorch known as PyTorch Lightning.

A great resource for learning about the differences between PyTorch and PyTorch Lightning, and how PyTorch Lightning makes your life easier: https://www.geeksforgeeks.org/deep-learning/pytorch-vs-pytorch-lightning/

Please note that, in this TA's opinion, you can't really appreciate the usefulness of PyTorch Lightning without first learning the basics of PyTorch. Hence, take this tutorial as a 10,000-ft overview of what PyTorch and PyTorch Lightning actually do.

In [8]:
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset, random_split, TensorDataset
import torch.nn.functional as F
from lightning import LightningModule, LightningDataModule, Trainer
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
import matplotlib.pyplot as plt
import numpy as np
from omegaconf import OmegaConf
import hydra
import random
import zstandard
import os
import glob
import io
import xarray as xr
import rasterio
import rioxarray
import json
import yaml

In [2]:
from preprocessing import load_and_normalise, decompress_zst_pt, get_file_list, collate_fn

Training the hierarchy

In [3]:
# Experiment configuration (paths, data splits, ...) used by every cell below.
cfg = OmegaConf.load(
    "conf/config_experiments.yaml"
)

In [4]:
from short_exercise import DownscalingDataModule

In [5]:
# Build the data module with the experiment's split / worker settings.
# A batch size of 1 suffices: only a single test example is needed for the
# inference demo further down.
data_module = DownscalingDataModule(
    static_dir=cfg.paths.static_dir,
    val_frac=cfg.data.val_split,
    test_frac=cfg.data.test_split,
    num_workers=cfg.data.num_workers,
    batch_size=1,
    # Presumably the normalisation statistics; inference() reads stats.json
    # back from this same path.
    save_stats_json=os.path.join(cfg.paths.output_dir, "stats.json"),
)
data_module.setup()
test_loader = data_module.test_dataloader()

# Take the first test batch (input, ground truth) and drop the leading
# batch dimension from both tensors.
test_iter = iter(test_loader)
input_sample, ground_truth = next(test_iter)
input_sample, ground_truth = input_sample[0], ground_truth[0]

Using 48 samples from 2020.
Using 48 samples from 2020.


Writing the YAML file

In [6]:
def inference(cfg, input_sample, ground_truth=None):
    stats_path = os.path.join(cfg.paths.output_dir, "stats.json")
    with open(stats_path, "r") as f:
        stats = json.load(f)
    low_2mt_mean = stats["low_2mt_mean"]
    low_2mt_std = stats["low_2mt_std"]

    def denormalise(arr):
        return arr * low_2mt_std + low_2mt_mean

    # Load ckpts for UNet, VAE, LDM
    checkpoint_dir = cfg.paths.checkpoint_dir
    unet_module = UNetLitModule.load_from_checkpoint(
        os.path.join(checkpoint_dir, "unet", sorted(os.listdir(checkpoint_dir + "/unet"))[-1])
    )
    vae_module = VAELitModule.load_from_checkpoint(
        os.path.join(checkpoint_dir, "vae", sorted(os.listdir(checkpoint_dir + "/vae"))[-1])
    )
    ldm_module = LDMLitModule.load_from_checkpoint(
        os.path.join(checkpoint_dir, "ldm", sorted(os.listdir(checkpoint_dir + "/ldm"))[-1]),
        vae=vae_module
    )

    unet_module.eval()
    vae_module.eval()
    ldm_module.eval()

    with torch.no_grad():
        device = next(unet_module.parameters()).device
        fuzzy_input = input_sample.unsqueeze(0).to(device)
        unet_pred = unet_module.net(fuzzy_input)

        input_height = vae_module.hparams.input_height // 2
        input_width = vae_module.hparams.input_width // 2
        latent_shape = (1, 1, input_height, input_width)

        # Downsample unet_pred to latent spatial size
        unet_pred_ds = F.interpolate(unet_pred, size=(input_height, input_width), mode='bilinear', align_corners=False)

        # Generate 3 samples
        generated_latents = []
        for _ in range(3):
            z_sample = ldm_module.sample(latent_shape, unet_pred_ds)
            z_sample_flat = z_sample.view(z_sample.size(0), -1)
            generated_latents.append(z_sample_flat)

        generated_latents = torch.cat(generated_latents, dim=0)

        # Decode latents to residuals
        generated_residuals = vae_module.model.decode(generated_latents, unet_pred)

        final_reconstructions = unet_pred + generated_residuals

    # Prepare images for plotting
    all_imgs = [
        denormalise(fuzzy_input[0, 0].cpu().numpy()),  # ERA5 predictor
        denormalise(final_reconstructions[0, 0].cpu().numpy()),  # Sample 1
        denormalise(final_reconstructions[1, 0].cpu().numpy()),  # Sample 2
        denormalise(final_reconstructions[2, 0].cpu().numpy()),  # Sample 3
    ]
    titles = [
        "ERA5 2m input",
        "Sample 1",
        "Sample 2",
        "Sample 3",
    ]

    # Add ground truth if provided
    if ground_truth is not None:
        all_imgs.append(denormalise(ground_truth.cpu().numpy()))
        titles.append("Ground Truth")

    vmin = min(img.min() for img in all_imgs)
    vmax = max(img.max() for img in all_imgs)

    fig, axes = plt.subplots(1, len(all_imgs), figsize=(6 * len(all_imgs), 6), constrained_layout=True)
    images = []

    for i, ax in enumerate(axes):
        im = ax.imshow(all_imgs[i], cmap='coolwarm', vmin=vmin, vmax=vmax, origin='lower')
        ax.set_title(titles[i], fontsize=15, fontweight='bold')
        ax.set_xlabel("Longitude", fontsize=13)
        ax.set_ylabel("Latitude", fontsize=13)
        ax.tick_params(axis='both', which='major', labelsize=11)
        ax.grid(False, which='both')
        images.append(im)

    cbar = fig.colorbar(images[0], ax=axes, orientation='horizontal', fraction=0.04, pad=0.08)
    cbar.set_label("2m Temperature (C)", fontsize=15, fontweight='bold')
    cbar.ax.tick_params(labelsize=13)

    plt.suptitle("ERA5 low res + 3 samples + Ground Truth (conditional inference)", fontsize=20, y=1.05, fontweight='bold')
    plt.show()

In [9]:
# Now that the models are trained, how will you perform inference using the trained hierarchy?
# NOTE(review): inference() looks up UNetLitModule, VAELitModule and LDMLitModule by name,
# but no cell in this notebook imports them; the saved output below is the resulting
# NameError. Import those classes (from wherever they are defined) before running this cell.
inference(cfg, input_sample, ground_truth)

NameError: name 'UNetLitModule' is not defined