In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the path of every file Kaggle mounted under /kaggle/input, one per line.
input_paths = [
    os.path.join(root, fname)
    for root, _, fnames in os.walk('/kaggle/input')
    for fname in fnames
]
for path in input_paths:
    print(path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

# 2. Explore

## To activate the GPU and work around the "unrecognized arguments" argument-parser error in Jupyter, see https://stackoverflow.com/questions/48796169/how-to-fix-ipykernel-launcher-py-error-unrecognized-arguments-in-jupyter

In [2]:
import xarray

In [3]:
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
from matplotlib import colors
from mpl_toolkits.axes_grid1 import make_axes_locatable


def visualize(data, year, month, variable, save=False):
    if variable == "t2m":
        img = data.t2m[12*(year - 1950) + month - 1]
        name = "Surface Temperature"
        cmap = "RdYlBu_r"
    elif variable == "tp":
        img = data.tp[12*(year - 1950) + month - 1]
        name = "Total Precipitation"
        cmap = "BrBG"
    else:
        raise ValueError("Invalid variable.")
    fig, ax = plt.subplots(figsize=(8, 3.5))
    divider = make_axes_locatable(ax)
    cax = divider.append_axes('right', size='5%', pad=0.05)
    divnorm = colors.TwoSlopeNorm(vmin=img.min(), vcenter=img.mean(), vmax=img.max())
    mapping = ax.imshow(img, cmap=cmap, norm=divnorm)
    fig.colorbar(mapping, cax=cax)
    ax.set_title(f"ERA5 {name} ({month}/{year})")
    plt.tight_layout()
    if save:
        plt.savefig(f"figs/{variable}_{year}_{month}.png")
    plt.show()


def visualize_patches(data, year, month, variable, size=64, save=False):
    if variable == "t2m":
        img = data.t2m[12*(year - 1950) + month - 1]
        name = "Surface Temperature"
        cmap = "RdYlBu_r"
    elif variable == "tp":
        img = data.tp[12*(year - 1950) + month - 1]
        name = "Total Precipitation"
        cmap = "BrBG"
    else:
        raise ValueError("Invalid variable.")
    n_vertical, n_horizontal = img.shape[0] // size, img.shape[1] // size
    fig = plt.figure(figsize=(8, 4))
    gs = fig.add_gridspec(n_vertical, n_horizontal + 1, width_ratios=[1 for i in range(n_horizontal)] + [0.1])
    min_, mean_, max_ = img.min(), img.mean(), img.max()
    divnorm = colors.TwoSlopeNorm(vmin=min_, vcenter=mean_, vmax=max_)
    for i in range(n_vertical):
        for j in range(n_horizontal):
            ax = fig.add_subplot(gs[i, j])
            patch = img[i*size:(i+1)*size, j*size:(j+1)*size]
            mapping = ax.imshow(patch, norm=divnorm, cmap=cmap)
            ax.axis("off")
            ax.set_title(f"{i*size}:{(i+1)*size}, {j*size}:{(j+1)*size}", fontsize=6)
    cax = fig.add_subplot(gs[:, -1])
    fig.colorbar(mapping, cax=cax)
    fig.suptitle(f"ERA5 {name} ({month}/{year}) Patches {size}x{size}")
    plt.tight_layout()
    if save:
        plt.savefig(f"figs/{variable}_{year}_{month}_patches.png")
    plt.show()


def visualize_pool(data, year, month, variable, row, col, pool=4, size=64, save=False):

    if variable == "t2m":
        img = data.t2m[12*(year - 1950) + month - 1]
        name = "Surface Temperature"
        cmap = "RdYlBu_r"
    elif variable == "tp":
        img = data.tp[12*(year - 1950) + month - 1]
        name = "Total Precipitation"
        cmap = "BrBG"
    else:
        raise ValueError("Invalid variable.")

    # extract patch and set up figure
    img = img[row:row+size, col:col+size]
    fig = plt.figure(figsize=(8, 4))
    gs = fig.add_gridspec(1, 3, width_ratios=[1, 1, 0.1])
    min_, mean_, max_ = img.min(), img.mean(), img.max()
    divnorm = colors.TwoSlopeNorm(vmin=min_, vcenter=mean_, vmax=max_)

    # plot original patch
    ax = fig.add_subplot(gs[0])
    ax.axis("off")
    mapping = ax.imshow(img, norm=divnorm, cmap=cmap)

    # create downsampled patch of same dimension
    downsampled = img.values.reshape(size//pool, pool, size//pool, pool).mean(axis=(1, 3))
    upsampled = np.repeat(np.repeat(downsampled, pool, axis=0), pool, axis=1)

    # plot downsampled patch
    ax = fig.add_subplot(gs[1])
    ax.axis("off")
    ax.imshow(downsampled, norm=divnorm, cmap=cmap)

    # plot colorbar
    cax = fig.add_subplot(gs[2])
    fig.colorbar(mapping, cax=cax)
    fig.suptitle(f"ERA5 {name} ({month}/{year}) Patch {size}x{size} at ({row},{col}) Pooled {pool}x")
    plt.tight_layout()
    if save:
        plt.savefig(f"figs/{variable}_{year}_{month}_pool_{pool}_{row}_{col}.png")
    plt.show()



def visualize_nan(data, year, month, variable, row, col, pool=4, size=64, save=False):
    if variable == "t2m":
        img = data.t2m[12 * (year - 1950) + month - 1]
        name = "Surface Temperature"
        cmap = "RdYlBu_r"
    elif variable == "tp":
        img = data.tp[12 * (year - 1950) + month - 1]
        name = "Total Precipitation"
        cmap = "BrBG"
    else:
        raise ValueError("Invalid variable.")

        # extract patch and set up figure
    img = img[row:row + size, col:col + size]
    fig = plt.figure(figsize=(8, 4))
    gs = fig.add_gridspec(1, 3, width_ratios=[1, 1, 0.1])
    min_, mean_, max_ = img.min(), img.mean(), img.max()
    divnorm = colors.TwoSlopeNorm(vmin=min_, vcenter=mean_, vmax=max_)

    # plot original patch
    ax = fig.add_subplot(gs[0])
    ax.axis("off")
    mapping = ax.imshow(img, norm=divnorm, cmap=cmap)

    # create nan-replaced patch of same dimension
    replacement = np.nanmean(img)
    img = np.nan_to_num(img, nan=replacement)

    # plot nan-fixed patch
    ax = fig.add_subplot(gs[1])
    ax.axis("off")
    ax.imshow(img, norm=divnorm, cmap=cmap)

    # plot colorbar
    cax = fig.add_subplot(gs[2])
    fig.colorbar(mapping, cax=cax)
    fig.suptitle(f"ERA5 {name} ({month}/{year}) Patch {size}x{size} at ({row},{col}) Pooled {pool}x")
    plt.tight_layout()
    if save:
        if not os.path.exists(f"./figs"):
            os.makedirs(f"./figs")
        plt.savefig(f"./figs/nan_{variable}_{year}_{month}_pool_{pool}_{row}_{col}.png")
    plt.show()


if __name__ == "__main__":
    # Render the NaN-filling comparison for the patch at (row=0, col=0) of April 2020.
    # The disabled calls show the other plot types. The dataset is opened in a `with`
    # block so the NetCDF file handle is released once the figures are drawn.
    with xr.open_dataset("../input/era5-dataset/era5.nc") as data:
        # visualize(data, 2020, 4, "t2m", save=True)
        # visualize(data, 2020, 4, "tp", save=True)
        # visualize_patches(data, 2020, 4, "t2m", size=64, save=True)
        # visualize_patches(data, 2020, 4, "tp", size=64, save=True)
        # visualize_pool(data, 2020, 4, "t2m", row=64, col=128, size=64, pool=4, save=True)
        # visualize_pool(data, 2020, 4, "tp", row=64, col=128, size=64, pool=4, save=True)
        visualize_nan(data, 2020, 4, "t2m", row=0, col=0, size=64, pool=4, save=True)
        visualize_nan(data, 2020, 4, "tp", row=0, col=0, size=64, pool=4, save=True)

# 3. Neural-Network Models


In [4]:
import numpy as np
import pytorch_lightning as pl
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
import torch.nn as nn
import torch.nn.functional as F
import torch


class BaseModel(pl.LightningModule):
    """Common Lightning plumbing for the super-resolution models in this notebook.

    Subclasses only implement ``forward``. Batches are dicts whose ``'x'`` (low
    resolution, already upsampled to the target size) and ``'y'`` (high resolution)
    tensors are indexed ``[batch, channel, height, width]``. ``input_channels`` and
    ``output_channels`` choose the channels fed in and predicted (0=t2m, 1=tp).
    Training and validation minimise pixel MSE, and testing logs MSE, SSIM and PSNR.
    The optimiser is Adam with ``ExponentialLR(gamma=decayRate)``; the default of 1
    keeps the learning rate constant.
    """
    def __init__(
            self,
            input_channels: list = [0, 1],      # indices of tensor input channels to consider (0=t2m, 1=tp)
            output_channels: list = [0, 1],     # indices of tensor target channels to predict (0=t2m, 1=tp)
            lr: float = 1e-3,
            decayRate: float = 1
    ):
        super().__init__()
        self.save_hyperparameters()
        self.learning_rate = lr
        # copy so instances never share (and cannot mutate) the default list objects
        self.input_channels = list(input_channels)
        self.output_channels = list(output_channels)
        self.input_dim = len(self.input_channels)
        self.output_dim = len(self.output_channels)
        self.exp_decay_rate = decayRate

    def _split_batch(self, batch):
        """Return ``(x, y)`` restricted to the configured input/output channels."""
        x = batch['x'][:, self.input_channels, :, :]
        y = batch['y'][:, self.output_channels, :, :]
        return x, y

    def training_step(self, batch, batch_idx):
        # A collate function may return None when it filtered out every sample;
        # returning None asks Lightning to skip the batch.
        if batch is None:
            return None
        x, y = self._split_batch(batch)
        loss = F.mse_loss(self(x), y)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        if batch is None:
            return None
        x, y = self._split_batch(batch)
        loss = F.mse_loss(self(x), y)
        self.log('val_loss', loss)
        return loss

    def test_step(self, batch, batch_idx):
        if batch is None:
            return None
        x, y = self._split_batch(batch)
        y_hat = self(x)

        def _ssim_trans(t):
            # (batch, channel, h, w) tensor -> (batch, h, w, channel) numpy array for skimage
            return t.detach().cpu().permute(0, 2, 3, 1).numpy()

        def _psnr_trans(t):
            # Min-max scale to [0, 1] so PSNR can use data_range=1. A constant tensor maps
            # to zeros instead of producing 0/0 = NaN.
            # NOTE(review): y and y_hat are each scaled by their own range, so PSNR does
            # not see a global offset/scale error in the prediction.
            arr = t.detach().cpu().numpy()
            lo, hi = np.amin(arr), np.amax(arr)
            if hi > lo:
                return (arr - lo) / (hi - lo)
            return np.zeros_like(arr)

        self.log_dict({
            'MSE': F.mse_loss(y_hat, y),
            # NOTE(review): newer scikit-image releases deprecate `multichannel=` in favour
            # of `channel_axis=-1` and may require an explicit `data_range` for float
            # inputs. Confirm against the installed version.
            'SSIM': ssim(_ssim_trans(y_hat), _ssim_trans(y), multichannel=True),
            'PSNR': psnr(image_true=_psnr_trans(y), image_test=_psnr_trans(y_hat), data_range=1)
        })

    def configure_optimizers(self):
        opt = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return {
            'optimizer': opt,
            'lr_scheduler': torch.optim.lr_scheduler.ExponentialLR(opt, self.exp_decay_rate)
        }



class SRCNN(BaseModel):
    """
    Image Super-Resolution Using Deep Convolutional Networks
    Chao Dong, Chen Change Loy, Kaiming He, and Xiaoou Tang
    https://arxiv.org/pdf/1501.00092.pdf

    The model is three convolutions with ReLU between them:
    input_dim -> hidden_1 -> hidden_2 -> output_dim. With ``padding=True`` every
    convolution uses "same" padding in replicate mode, so the output has the input's
    spatial size. Training, validation and test steps are inherited unchanged from
    BaseModel.

    NOTE(review): with ``padding=False`` each convolution shrinks height and width
    by ``kernel - 1``, which is 12 pixels in total with the default kernels. The
    output then no longer matches ``y`` in BaseModel's MSE loss, so the target would
    need to be centre-cropped.
    """
    def __init__(
            self,
            hidden_1: int = 64,     # n_1 in the paper
            hidden_2: int = 32,     # n_2 in the paper
            kernel_1: int = 9,      # f_1
            kernel_2: int = 1,      # f_2
            kernel_3: int = 5,      # f_3
            padding: bool = True,
            **kwargs
    ):
        super().__init__(**kwargs)
        self.hidden_1 = hidden_1
        self.hidden_2 = hidden_2
        self.kernel_1 = kernel_1
        self.kernel_2 = kernel_2
        self.kernel_3 = kernel_3
        self.padding = padding

        extra_args = {}
        if self.padding:
            extra_args["padding"] = "same"
            extra_args["padding_mode"] = "replicate"

        self.layers = nn.Sequential(
            nn.Conv2d(in_channels=self.input_dim, out_channels=self.hidden_1, kernel_size=self.kernel_1, **extra_args),
            nn.ReLU(),
            nn.Conv2d(in_channels=self.hidden_1, out_channels=self.hidden_2, kernel_size=self.kernel_2, **extra_args),
            nn.ReLU(),
            nn.Conv2d(in_channels=self.hidden_2, out_channels=self.output_dim, kernel_size=self.kernel_3, **extra_args),
        )

    def forward(self, x):
        return self.layers(x)


class VDSR(BaseModel):
    """
    Accurate Image Super-Resolution Using Very Deep Convolutional Networks
    Jiwon Kim, Jung Kwon Lee and Kyoung Mu Lee
    https://arxiv.org/pdf/1511.04587.pdf

    The network has d convolution layers in total:
    - an entry conv (input_dim -> hidden_dim) followed by ReLU;
    - d-2 hidden [conv + ReLU] layers, wrapped by a single skip connection;
    - an exit conv (hidden_dim -> output_dim).
    Every conv uses "same" padding in replicate mode, so spatial size is preserved.

    NOTE(review): the paper adds the learned residual to the interpolated input
    image. Here the skip connection spans only the hidden layers.
    """
    def __init__(
            self,
            d: int = 20,                # d=20 in the paper
            kernel: int = 3,            # k=3 in the paper
            hidden_dim: int = 64,       # hidden_dim=64 in the paper
            pre_post_kernel: int = 9,   # not in paper, but we draw inspiration from SRCNN and SRResNet
            **kwargs
    ):
        super().__init__(**kwargs)
        self.d, self.kernel = d, kernel
        self.hidden_dim, self.pre_post_kernel = hidden_dim, pre_post_kernel

        def same_conv(c_in, c_out, k):
            return nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=k,
                             padding="same", padding_mode="replicate")

        # entry: lift the input channels to hidden_dim feature maps
        self.pre_layers = nn.Sequential(same_conv(self.input_dim, hidden_dim, pre_post_kernel), nn.ReLU())

        # hidden stack: (d - 2) x [conv, ReLU], created in the same order as before
        hidden_layers = []
        for _ in range(d - 2):
            hidden_layers += [same_conv(hidden_dim, hidden_dim, kernel), nn.ReLU()]
        self.blocks = nn.Sequential(*hidden_layers)

        # exit: project the features back to the output channels
        self.post_layers = nn.Sequential(same_conv(hidden_dim, self.output_dim, pre_post_kernel))

    def forward(self, x):
        features = self.pre_layers(x)
        features = features + self.blocks(features)   # skip connection around the hidden stack
        return self.post_layers(features)


class SRResNet(BaseModel):
    """
    Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network
    Christian Ledig, Lucas Theis, Ferenc Huszar, Jose Caballero, Andrew Cunningham,
    Alejandro Acosta, Andrew Aitken, Alykhan Tejani, Johannes Totz, Zehan Wang, Wenzhe Shi
    https://openaccess.thecvf.com/content_cvpr_2017/papers/Ledig_Photo-Realistic_Single_Image_CVPR_2017_paper.pdf

    The pipeline is:
    1. An entry conv with PReLU.
    2. ``n_blocks`` residual blocks of [conv, BN, PReLU, conv, BN], each with its
       own skip connection.
    3. A conv + BN, whose output is added back to the entry features.
    4. Two [conv, PReLU] layers (presumably in place of the paper's upsampling
       stages, since inputs already arrive at full size).
    5. A final conv to output_dim.

    All convs use "same" padding in replicate mode.
    """
    def __init__(
            self,
            n_blocks: int = 16,         # n_blocks=16 in the paper
            kernel: int = 3,            # k=3 in the paper
            pre_post_kernel: int = 9,   # pre_post_kernel=9 in the paper
            hidden_dim: int = 64,       # hidden_dim=64 in the paper
            **kwargs
    ):
        super().__init__(**kwargs)
        self.n_blocks, self.kernel = n_blocks, kernel
        self.pre_post_kernel, self.hidden_dim = pre_post_kernel, hidden_dim

        def same_conv(c_in, c_out, k):
            return nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=k,
                             padding="same", padding_mode="replicate")

        def batch_norm():
            return nn.BatchNorm2d(num_features=hidden_dim)

        # entry layers
        self.pre_layers = nn.Sequential(same_conv(self.input_dim, hidden_dim, pre_post_kernel), nn.PReLU())

        # residual trunk; each element is applied as x + block(x) in forward
        self.blocks = nn.ModuleList(
            nn.Sequential(
                same_conv(hidden_dim, hidden_dim, kernel),
                batch_norm(),
                nn.PReLU(),
                same_conv(hidden_dim, hidden_dim, kernel),
                batch_norm(),
            )
            for _ in range(n_blocks)
        )

        # closes the trunk before the long skip connection
        self.post_layers_1 = nn.Sequential(same_conv(hidden_dim, hidden_dim, kernel), batch_norm())

        # output head after the long skip connection
        self.post_layers_2 = nn.Sequential(
            same_conv(hidden_dim, hidden_dim, kernel),
            nn.PReLU(),
            same_conv(hidden_dim, hidden_dim, kernel),
            nn.PReLU(),
            same_conv(hidden_dim, self.output_dim, pre_post_kernel),
        )

    def forward(self, x):
        entry = self.pre_layers(x)
        trunk = entry
        for residual_block in self.blocks:
            trunk = trunk + residual_block(trunk)
        return self.post_layers_2(entry + self.post_layers_1(trunk))


class Nearest(BaseModel):
    """
    Identity baseline. The dataloaders already deliver low-resolution inputs
    nearest-neighbour upsampled to the target size, so "nearest-neighbour
    super-resolution" amounts to returning the input unchanged.
    """
    def __init__(self, **kwargs):
        # no layers of its own; all configuration lives in BaseModel
        super(Nearest, self).__init__(**kwargs)

    def forward(self, x):
        prediction = x
        return prediction


class Bilinear(BaseModel):
    """
    Baseline: apply bilinear interpolation-based upscaling from LR to HR. Because our dataloaders automatically apply
    nearest-neighbor upscaling to ensure LR and HR are of same dimension, we must first downscale then upscale
    to use prebuilt PyTorch code.

    NOTE(review): this model has no trainable parameters, so it is only meant for
    evaluation. BaseModel.configure_optimizers would build Adam over an empty
    parameter list if it were trained.
    """
    def __init__(self, pool_size, **kwargs):
        super().__init__(**kwargs)
        self.pool_size = pool_size
        self.pool = nn.AvgPool2d(kernel_size=self.pool_size, stride=self.pool_size)

    def forward(self, x):
        # Restore the full (height, width). Using only x.shape[-1] would force a square output.
        size = tuple(x.shape[-2:])
        x = self.pool(x)    # note that x is already pooled and repeated; this changes size but does not change content
        # align_corners=False is PyTorch's default; it is spelled out to make the sampling grid explicit
        return F.interpolate(x, size=size, mode="bilinear", align_corners=False)


class Bicubic(BaseModel):
    """
    Baseline: apply bicubic interpolation-based upscaling from LR to HR. Because our dataloaders automatically apply
    nearest-neighbor upscaling to ensure LR and HR are of same dimension, we must first downscale then upscale
    to use prebuilt PyTorch code.

    NOTE(review): this model has no trainable parameters, so it is only meant for
    evaluation. BaseModel.configure_optimizers would build Adam over an empty
    parameter list if it were trained.
    """
    def __init__(self, pool_size, **kwargs):
        super().__init__(**kwargs)
        self.pool_size = pool_size
        self.pool = nn.AvgPool2d(kernel_size=self.pool_size, stride=self.pool_size)

    def forward(self, x):
        # Restore the full (height, width). Using only x.shape[-1] would force a square output.
        size = tuple(x.shape[-2:])
        x = self.pool(x)  # note that x is already pooled and repeated; this changes size but does not change content
        # align_corners=False is PyTorch's default; it is spelled out to make the sampling grid explicit
        return F.interpolate(x, size=size, mode="bicubic", align_corners=False)



# SRGAN Model (Keras reference architecture)

In [None]:
# Network Architecture is same as given in Paper https://arxiv.org/pdf/1609.04802.pdf
# PyTorch port of a Keras SRGAN draft. That draft did not parse: it had an empty
# __init__ and broken indentation, and it mixed un-imported Keras layers with torch.
# Keras -> PyTorch mapping used below:
#   - BatchNormalization(momentum=0.5) -> BatchNorm2d(momentum=0.5). PyTorch weights
#     the new batch statistic by `momentum`, Keras weights the old running value,
#     and 0.5 is symmetric.
#   - PReLU(alpha_initializer='zeros', shared_axes=[1, 2]) -> one learnable slope
#     per channel, initialised to 0.
#   - UpSampling2D(size=2) -> nn.Upsample(scale_factor=2, mode="nearest").


class _Residual(nn.Module):
    """Wrap ``body`` with an identity skip connection: ``forward(x) = x + body(x)``."""

    def __init__(self, body: nn.Module):
        super().__init__()
        self.body = body

    def forward(self, x):
        return x + self.body(x)


def res_block_gen(filters: int, kernel_size: int = 3) -> nn.Module:
    """Generator residual block: Conv-BN-PReLU-Conv-BN plus an identity skip.

    Channel count and spatial size are preserved ("same" padding, stride 1), so the
    skip connection lines up.
    """
    return _Residual(nn.Sequential(
        nn.Conv2d(filters, filters, kernel_size, padding="same", padding_mode="replicate"),
        nn.BatchNorm2d(filters, momentum=0.5),
        nn.PReLU(num_parameters=filters, init=0.0),
        nn.Conv2d(filters, filters, kernel_size, padding="same", padding_mode="replicate"),
        nn.BatchNorm2d(filters, momentum=0.5),
    ))


def up_sampling_block(in_channels: int, filters: int, kernel_size: int = 3, scale_factor: int = 2) -> nn.Module:
    """Conv -> nearest-neighbour upsampling by ``scale_factor`` -> LeakyReLU(0.2)."""
    return nn.Sequential(
        nn.Conv2d(in_channels, filters, kernel_size, padding="same", padding_mode="replicate"),
        nn.Upsample(scale_factor=scale_factor, mode="nearest"),
        nn.LeakyReLU(0.2),
    )


def discriminator_block(in_channels: int, filters: int, kernel_size: int, stride: int) -> nn.Module:
    """Conv(stride) -> BN -> LeakyReLU(0.2).

    Padding is ``kernel_size // 2`` because PyTorch rejects padding="same" for
    stride > 1. With odd kernels each spatial dimension becomes ceil(n / stride).
    """
    return nn.Sequential(
        nn.Conv2d(in_channels, filters, kernel_size, stride=stride, padding=kernel_size // 2),
        nn.BatchNorm2d(filters, momentum=0.5),
        nn.LeakyReLU(0.2),
    )


class Generator(nn.Module):
    """SRGAN generator.

    Pipeline: entry conv + PReLU -> ``n_blocks`` residual blocks -> conv + BN with a
    long skip back to the entry features -> ``n_upsample`` x2 upsampling blocks ->
    output conv, optionally followed by tanh as in the paper.

    Args:
        in_channels, out_channels: image channels of the input and the output.
        n_blocks: residual blocks (16 in the paper).
        hidden_dim: feature channels in the residual trunk (64 in the paper).
        kernel: kernel size inside the trunk and upsampling blocks (3 in the paper).
        pre_post_kernel: kernel size of the first and last conv (9 in the paper).
        n_upsample: number of x2 upsampling blocks (2 in the paper, i.e. x4 overall).
            Use 0 when the input already has the target resolution.
        upsample_dim: channels produced by each upsampling block (256 in the draft).
        tanh_output: squash the output into [-1, 1] with tanh.
    """

    def __init__(self, in_channels: int = 3, out_channels: int = 3, n_blocks: int = 16,
                 hidden_dim: int = 64, kernel: int = 3, pre_post_kernel: int = 9,
                 n_upsample: int = 2, upsample_dim: int = 256, tanh_output: bool = True):
        super().__init__()
        self.head = nn.Sequential(
            nn.Conv2d(in_channels, hidden_dim, pre_post_kernel, padding="same", padding_mode="replicate"),
            nn.PReLU(num_parameters=hidden_dim, init=0.0),
        )
        self.blocks = nn.Sequential(*[res_block_gen(hidden_dim, kernel) for _ in range(n_blocks)])
        self.trunk_out = nn.Sequential(
            nn.Conv2d(hidden_dim, hidden_dim, kernel, padding="same", padding_mode="replicate"),
            nn.BatchNorm2d(hidden_dim, momentum=0.5),
        )
        upsampling, channels = [], hidden_dim
        for _ in range(n_upsample):
            upsampling.append(up_sampling_block(channels, upsample_dim, kernel))
            channels = upsample_dim
        self.upsample = nn.Sequential(*upsampling)
        tail = [nn.Conv2d(channels, out_channels, pre_post_kernel, padding="same", padding_mode="replicate")]
        if tanh_output:
            tail.append(nn.Tanh())
        self.tail = nn.Sequential(*tail)

    def forward(self, x):
        head = self.head(x)
        x = head + self.trunk_out(self.blocks(head))   # long skip connection over the trunk
        return self.tail(self.upsample(x))


# Network Architecture is same as given in Paper https://arxiv.org/pdf/1609.04802.pdf
class Discriminator(nn.Module):
    """SRGAN discriminator: strided conv feature extractor + dense head with sigmoid.

    Args:
        image_shape: ``(channels, height, width)`` of the images to score, in PyTorch
            channel-first order (the Keras draft used ``(h, w, c)``). Height and width
            fix the input size of the first dense layer.
    """

    # (filters, stride) of the seven discriminator blocks, as in the paper / draft
    _BLOCKS = ((64, 2), (128, 1), (128, 2), (256, 1), (256, 2), (512, 1), (512, 2))

    def __init__(self, image_shape):
        super().__init__()
        self.image_shape = tuple(image_shape)
        channels, height, width = self.image_shape
        layers = [nn.Conv2d(channels, 64, 3, padding=1), nn.LeakyReLU(0.2)]
        in_ch = 64
        for filters, stride in self._BLOCKS:
            layers.append(discriminator_block(in_ch, filters, 3, stride))
            in_ch = filters
            if stride == 2:
                # kernel 3, padding 1, stride 2 -> ceil(n / 2) outputs per dimension
                height, width = (height + 1) // 2, (width + 1) // 2
        self.features = nn.Sequential(*layers)
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_ch * height * width, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.classifier(self.features(x))


class SRGAN(BaseModel):
    """
    Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network
    (Ledig et al., https://arxiv.org/pdf/1609.04802.pdf). This class wraps only the
    generator.

    The dataloaders already upsample LR inputs to HR size, so the generator runs
    with no upsampling blocks. Its output is linear because the targets are
    standardised, not scaled to [-1, 1].

    Training goes through BaseModel, i.e. pixel-wise MSE only. That corresponds to
    the paper's MSE pre-training of the generator.

    TODO: adversarial training with ``Discriminator`` is not implemented yet.
    """
    def __init__(
            self,
            n_blocks: int = 16,         # n_blocks=16 in the paper
            kernel: int = 3,            # k=3 in the paper
            pre_post_kernel: int = 9,   # pre_post_kernel=9 in the paper
            hidden_dim: int = 64,       # hidden_dim=64 in the paper
            **kwargs
    ):
        super().__init__(**kwargs)
        self.n_blocks = n_blocks
        self.kernel = kernel
        self.pre_post_kernel = pre_post_kernel
        self.hidden_dim = hidden_dim
        self.generator = Generator(
            in_channels=self.input_dim,
            out_channels=self.output_dim,
            n_blocks=n_blocks,
            hidden_dim=hidden_dim,
            kernel=kernel,
            pre_post_kernel=pre_post_kernel,
            n_upsample=0,
            tanh_output=False,
        )

    def forward(self, x):
        return self.generator(x)


# 1. Data Acquisition

In [None]:
from typing import List
import numpy as np
import matplotlib.pyplot as plt
import torch
import xarray as xr
import pytorch_lightning as pl
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataloader import default_collate
from matplotlib import cm


class ERA5Data(Dataset):
    """Patch-level super-resolution dataset over a stack of ERA5 variables.

    Each element of ``datasets`` holds one physical variable and acts as one image
    channel. Elements are indexed ``[month, row, col]``; in this notebook they are
    the t2m and tp DataArrays. Every channel is standardised in place with its own
    NaN-aware mean and std, which are kept in ``mus`` and ``stds`` with shape
    ``(1, c, 1, 1)``.

    Item ``i`` is one ``patch_size x patch_size`` window of one month, returned as
    ``{"x": input, "y": target}`` tensors of shape ``(c, ps, ps)``:
    - ``y`` holds the full-resolution channels;
    - ``x`` holds the same window pooled by ``pool_size`` (``pool_type`` "mean" or
      "max") and nearest-neighbour repeated back to full size.

    If any channel of a window is more than half NaN, the item is ``None``, for the
    collate function to drop. Otherwise remaining NaNs are replaced by the window
    mean.
    """

    def __init__(self, datasets: "List[xr.DataArray]", patch_size: int, pool_size: int, pool_type: str):
        # each element of datasets is an xr.DataArray holding a different physical variable
        # in our case, datasets = [t2m, tp] = [temp @ 2 meters, total precipitation]
        # we can think of each element as a different image channel
        # we merge these channels into a single c x h x w tensor in __getitem__
        self.datasets = datasets
        self.patch_size = patch_size
        self.pool_size = pool_size
        self.pool_type = pool_type
        self.n_channels = len(self.datasets)
        self.n_months = datasets[0].shape[0]
        self.n_vertical = datasets[0].shape[1] // patch_size
        self.n_horizontal = datasets[0].shape[2] // patch_size
        self.mus, self.stds = [], []
        for dataset in self.datasets:
            mu = np.nanmean(dataset.values)
            s = np.nanstd(dataset.values)
            dataset.values = (dataset.values - mu) / s     # note: mutates the caller's arrays
            self.mus.append(mu)
            self.stds.append(s)
        self.mus = np.array(self.mus).reshape((1, -1, 1, 1))
        self.stds = np.array(self.stds).reshape((1, -1, 1, 1))

    def __len__(self):
        return self.n_months * self.n_vertical * self.n_horizontal

    def _pool(self, channel):
        """Pool a ``(ps, ps)`` window by ``pool_size`` and repeat it back to ``(ps, ps)``."""
        # https://stackoverflow.com/a/42463514
        k = self.pool_size
        n = self.patch_size // k
        blocks = channel.reshape(n, k, n, k)
        if self.pool_type == "mean":
            downsampled = blocks.mean(axis=(1, 3))
        elif self.pool_type == "max":
            downsampled = blocks.max(axis=(1, 3))
        else:
            raise ValueError("Invalid pooling type.")
        return np.repeat(np.repeat(downsampled, k, axis=0), k, axis=1)

    def _extract_patch(self, month, row, col):
        """Return ``(input, target)`` arrays of shape ``(c, ps, ps)`` for one window.

        Returns ``None`` when any channel of the window is more than half NaN.
        """
        ps = self.patch_size
        input_channels, target_channels = [], []
        for dataset in self.datasets:
            channel = dataset[month, row*ps:(row+1)*ps, col*ps:(col+1)*ps].values
            # skip regions that are mostly missing (e.g. NaN-masked areas)
            if np.sum(np.isnan(channel)) > ps ** 2 / 2:
                return None
            # otherwise replace remaining nans with mean of region
            channel = np.nan_to_num(channel, nan=np.nanmean(channel))
            target_channels.append(channel)             # full resolution -> target
            input_channels.append(self._pool(channel))  # low resolution -> input
        return np.array(input_channels), np.array(target_channels)

    def __getitem__(self, i: int):
        # e.g., with 100 patches per month, i=527 is month 5 (527 // 100) and patch 27 (527 % 100)
        month, patch = divmod(i, self.n_vertical * self.n_horizontal)
        # e.g., with 20 patches per row, patch 27 is row 1 (27 // 20) and col 7 (27 % 20)
        row, col = divmod(patch, self.n_horizontal)

        extracted = self._extract_patch(month, row, col)
        if extracted is None:
            # None will be handled by the collate_fn of the DataLoader
            return None
        x, y = extracted
        return {"x": torch.from_numpy(x), "y": torch.from_numpy(y)}

    def _get_batch(self, years_per_batch, start_month=0, stop_month=None):
        """Deprecated: yield whole-grid batches, each covering ``years_per_batch`` years.

        Yields ``{"x": (n, c, ps, ps), "y": (n, c, ps, ps)}`` tensors. A batch ends
        after every ``12 * years_per_batch`` months counted from ``start_month``, and
        a final partial batch holds any leftover months. Windows rejected by the NaN
        rule are skipped as a whole, so every sample keeps all channels. Empty
        batches are not yielded.
        """
        if stop_month is None:
            stop_month = self.n_months
        months_per_batch = 12 * years_per_batch
        inputs, targets = [], []
        for month in range(start_month, stop_month):
            for row in range(self.n_vertical):
                for col in range(self.n_horizontal):
                    extracted = self._extract_patch(month, row, col)
                    if extracted is not None:
                        inputs.append(extracted[0])
                        targets.append(extracted[1])

            batch_complete = (month - start_month + 1) % months_per_batch == 0
            if (batch_complete or month == stop_month - 1) and inputs:
                yield {"x": torch.from_numpy(np.array(inputs)), "y": torch.from_numpy(np.array(targets))}
                inputs, targets = [], []


class ERA5DataModule(pl.LightningDataModule):
    """LightningDataModule serving ERA5 t2m/tp super-resolution patches.

    Opens ``../input/era5-dataset/era5.nc`` once, slices its time axis into
    train / validation / test year ranges and wraps each slice in an
    ``ERA5Data`` dataset.

    Args:
        args: dict of optional settings (defaults in parentheses):
            ``patch_size`` (64), ``pool_size`` (2), ``pool_type`` ("mean"),
            ``train_start``/``train_end`` (1950/2000),
            ``val_start``/``val_end`` (2000/2010),
            ``test_start``/``test_end`` (2010/2020), ``batch_size`` (32).
            Ranges are half-open ``[start, end)`` in years, assuming a
            monthly time axis whose index 0 is January of ``BASE_YEAR``.
    """

    # Year that time index 0 is assumed to fall in (12 indices per year).
    BASE_YEAR = 1950
    # Data variables extracted for every split, in channel order.
    VARIABLES = ("t2m", "tp")

    def __init__(self, args):
        # Let LightningDataModule initialise its own internal state first.
        super().__init__()

        # setup construction parameters
        self.patch_size = args.get("patch_size", 64)
        self.pool_size = args.get("pool_size", 2)
        self.pool_type = args.get("pool_type", "mean")

        # setup data
        self.data = xr.open_dataset("../input/era5-dataset/era5.nc")
        self.train_start = args.get("train_start", 1950)
        self.train_end = args.get("train_end", 2000)
        self.val_start = args.get("val_start", 2000)
        self.val_end = args.get("val_end", 2010)
        self.test_start = args.get("test_start", 2010)
        self.test_end = args.get("test_end", 2020)
        self.train_data = self._make_dataset(self.train_start, self.train_end, self.pool_size)
        self.val_data = self._make_dataset(self.val_start, self.val_end, self.pool_size)
        self.test_data = self._make_dataset(self.test_start, self.test_end, self.pool_size)

        # setup loader parameters
        self.batch_size = args.get("batch_size", 32)

    def _slice_years(self, start, end):
        """Return ``[t2m, tp]`` arrays restricted to time indices of years ``[start, end)``."""
        first = 12 * (start - self.BASE_YEAR)
        last = 12 * (end - self.BASE_YEAR)
        return [getattr(self.data, name)[first:last] for name in self.VARIABLES]

    def _make_dataset(self, start, end, pool_size):
        """Build an ``ERA5Data`` over years ``[start, end)`` downsampled by ``pool_size``."""
        return ERA5Data(self._slice_years(start, end), self.patch_size, pool_size, self.pool_type)

    def _make_loader(self, dataset):
        """Wrap ``dataset`` in a DataLoader using this module's batch size and collate_fn."""
        return DataLoader(dataset, batch_size=self.batch_size, collate_fn=self.collate_fn)

    def collate_fn(self, batch):
        """Collate a minibatch after discarding ``None`` samples.

        Samples may be ``None`` when they come from the edges of the dataset.
        """
        # https://discuss.pytorch.org/t/questions-about-dataloader-and-dataset/806/7
        batch = [sample for sample in batch if sample is not None]
        return default_collate(batch)

    def train_dataloader(self):
        return self._make_loader(self.train_data)

    def val_dataloader(self):
        return self._make_loader(self.val_data)

    def test_dataloader(self):
        return self._make_loader(self.test_data)

    def eval_dataloader(self, pool_size):
        """Test-split loader downsampled by ``pool_size``, which may differ from the training factor."""
        # experiment with performance on non-native pool size
        return self._make_loader(self._make_dataset(self.test_start, self.test_end, pool_size))


if __name__ == "__main__":

    # Smoke-test the whole DataModule (indexed access through the DataLoader):
    # grab one training batch and show LR inputs beside HR targets for the
    # first four samples of both variables.
    datamodule = ERA5DataModule(args={"pool_size": 4, "batch_size": 32})
    train_loader = datamodule.train_dataloader()
    fig, ax = plt.subplots(4, 4, figsize=(12, 12))
    first_batch = next(iter(train_loader))
    lr_inputs = first_batch["x"].detach().cpu().numpy()
    hr_targets = first_batch["y"].detach().cpu().numpy()

    # One entry per figure column: (source array, channel, colormap, title).
    panel_columns = (
        (lr_inputs, 0, cm.RdYlBu_r, "T2M @ LR"),
        (hr_targets, 0, cm.RdYlBu_r, "T2M @ HR"),
        (lr_inputs, 1, cm.BrBG, "TP @ LR"),
        (hr_targets, 1, cm.BrBG, "TP @ HR"),
    )
    for row in range(4):
        for col, (source, channel, cmap, title) in enumerate(panel_columns):
            ax[row, col].imshow(source[row][channel], cmap=cmap)
            ax[row, col].set_title(title)
    plt.tight_layout()
    plt.show()

    # The Dataset's deprecated _get_batch generator can be previewed the same
    # way via next(datamodule.train_data._get_batch(1)).

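# 4. Train (train.py)

Note: `main` below passes an `ImageVisCallback`, which is defined in the *Utils* cell further down; run that cell before this one.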

In [None]:
import os
from argparse import ArgumentParser
# from models import SRCNN, VDSR, SRResNet
# from data import ERA5DataModule
import pytorch_lightning as pl
import wandb

# from utils import ImageVisCallback

# Start a Weights & Biases run as soon as this cell executes; project and
# entity are hard-coded here.
# NOTE(review): main() below also creates a WandbLogger for the same project —
# presumably it attaches to this already-active run; confirm.
wandb.init(project='cv-proj', entity="cv803f21-superres")


def main(args):
    """Train one super-resolution model on ERA5 and log the run to W&B.

    Args:
        args: argparse.Namespace holding the Lightning Trainer flags plus
            ``model`` (SRCNN | VDSR | SRResNet, case-insensitive; SRCNN if the
            attribute is absent), ``pool_size``, ``batch_size``,
            ``patch_size``, ``lr`` and ``decay_Rate`` (passed to VDSR only).

    Raises:
        ValueError: if ``args.model`` names an unknown architecture.

    NOTE: uses ``ImageVisCallback`` from the utils cell, which appears later in
    the notebook; that cell must have been run first.
    """
    # configure data module
    e = ERA5DataModule(args={
        "pool_size": args.pool_size,
        "batch_size": args.batch_size,
        "patch_size": args.patch_size
    })
    train_dl, val_dl = e.train_dataloader(), e.val_dataloader()

    # input channels controls which channels we use as predictors
    # output channels controls which channels we use as targets, i.e., loss signal
    # channel 0 corresponds to t2m and channel 1 corresponds to tp
    # e.g., input_channels=[0, 1], output_channels=[1] predicts tp @ HR using t2m AND tp @ LR
    # e.g., input_channels=[1],    output_channels=[1] predicts tp @ HR using ONLY tp @ LR
    # ...etc.
    args.model = args.model if hasattr(args, "model") else "SRCNN"
    model_name = args.model.lower()
    if model_name == "vdsr":
        print("Constructing VDSR")
        model = VDSR(input_channels=[0, 1], output_channels=[0, 1], lr=args.lr, decayRate=args.decay_Rate)
    elif model_name == "srresnet":
        print("Constructing SRResNet")
        model = SRResNet(input_channels=[0, 1], output_channels=[0, 1], lr=args.lr)
    elif model_name == "srcnn":
        print("Constructing SRCNN")
        model = SRCNN(input_channels=[0, 1], output_channels=[0, 1], lr=args.lr)
    else:
        raise ValueError("Invalid model architecture.")

    # Wandb logging
    wandb_logger = pl.loggers.WandbLogger(project='cv-proj')
    wandb_logger.watch(model, log_freq=500)

    # Hand the logger and callback to the constructor (keyword arguments
    # override the parsed flags) so Lightning wires them up during init,
    # rather than patching trainer.logger / trainer.callbacks afterwards.
    trainer: pl.Trainer = pl.Trainer.from_argparse_args(
        args,
        logger=wandb_logger,
        callbacks=[ImageVisCallback(val_dl)],
    )

    trainer.fit(model, train_dl, val_dl)


if __name__ == "__main__":
    # Lightning's own Trainer flags first, then the options specific to training.
    parser = pl.Trainer.add_argparse_args(ArgumentParser())
    train_cli_options = [
        ('--model', dict(default="SRCNN", type=str, help="Model to train")),
        ('--batch_size', dict(default=128, type=int, help="Batch size to train with")),
        ('--pool_size', dict(default=4, type=int, help="Super-resolution factor")),
        ('--patch_size', dict(default=64, type=int, help="Image patch size to super-resolve")),
        ('--lr', dict(default=1e-3, type=float, help="Learning rate")),
        ("--decay_Rate", dict(default=1, type=float, help="Exponential decay rate")),
    ]
    for flag, options in train_cli_options:
        parser.add_argument(flag, **options)

    # parse_known_args tolerates the extra arguments the Jupyter kernel injects.
    args, _ = parser.parse_known_args()
    # Always request a single GPU, regardless of any --gpus value.
    args.gpus = 1
    main(args)

# 5. Test

In [None]:
import argparse
import pytorch_lightning as pl
# from data import ERA5DataModule
# from models import SRCNN, VDSR, SRResNet, Nearest, Bilinear, Bicubic

def main(args):
    """Evaluate one model on the ERA5 test split and log the results to W&B.

    Args:
        args: argparse.Namespace holding the Lightning Trainer flags plus
            ``model`` (SRCNN | SRResNet | VDSR | Nearest | Bilinear | Bicubic,
            case-insensitive), ``checkpoint`` (.ckpt path; required for the
            learned models), ``pool_size``, ``batch_size`` and ``patch_size``.

    Raises:
        ValueError: if ``args.model`` is unknown, or a learned model is
            requested without a checkpoint.
    """
    model_name = args.model.lower()
    # Fail fast, before opening the dataset, when a learned model has nothing to load.
    if model_name in ('srcnn', 'srresnet', 'vdsr') and not getattr(args, "checkpoint", None):
        raise ValueError(f"A --checkpoint is required to test {args.model}.")

    e = ERA5DataModule(args={
        "pool_size": args.pool_size,
        "batch_size": args.batch_size,
        "patch_size": args.patch_size
    })
    test_dl = e.test_dataloader()

    if model_name == 'srcnn':
        print("Testing SRCNN")
        model = SRCNN.load_from_checkpoint(args.checkpoint)
    elif model_name == 'srresnet':
        print("Testing SRResNet")
        model = SRResNet.load_from_checkpoint(args.checkpoint)
    elif model_name == 'vdsr':
        print("Testing VDSR")
        model = VDSR.load_from_checkpoint(args.checkpoint)
    elif model_name == 'nearest':
        print("Testing Nearest")
        model = Nearest()
    elif model_name == 'bilinear':
        print("Testing Bilinear")
        model = Bilinear(pool_size=args.pool_size)
    elif model_name == 'bicubic':
        print("Testing Bicubic")
        model = Bicubic(pool_size=args.pool_size)
    else:
        # Without this branch `model` would be unbound and the next line would crash.
        raise ValueError("Invalid model architecture.")
    print(f"Number of parameters: {sum(p.numel() for p in model.parameters())}")

    # Wandb logging
    wandb_logger = pl.loggers.WandbLogger(project='cv-proj')
    wandb_logger.watch(model, log_freq=500)

    # Pass the logger at construction (keyword arguments override parsed flags).
    trainer: pl.Trainer = pl.Trainer.from_argparse_args(args, logger=wandb_logger)

    trainer.test(model, test_dl)


if __name__ == "__main__":
    # Lightning's own Trainer flags first, then the options specific to testing.
    parser = pl.Trainer.add_argparse_args(argparse.ArgumentParser())
    test_cli_options = [
        ('--model', dict(default="SRCNN", type=str, help="Model to test")),
        ('--checkpoint', dict(type=str, help="Checkpoint file (.ckpt)")),
        ('--batch_size', dict(default=16, type=int, help="Batch size to train with")),
        ('--pool_size', dict(default=4, type=int, help="Super-resolution factor")),
        ('--patch_size', dict(default=64, type=int, help="Image patch size to super-resolve")),
    ]
    for flag, options in test_cli_options:
        parser.add_argument(flag, **options)

    # parse_known_args tolerates the extra arguments the Jupyter kernel injects.
    args, _ = parser.parse_known_args()
    # Always request a single GPU, regardless of any --gpus value.
    args.gpus = 1

    # Previously evaluated runs (assign args.model / args.checkpoint to reuse):
    #   SRCNN    -> "cv-proj/SRCNN-lr4-lyric-tree-103-epoch=433-step=73345.ckpt"
    #   VDSR     -> "cv-proj/VDSR-lr4-balmy-sound-102-epoch=86-step=14702.ckpt"
    #   SRResNet -> "cv-proj/SRResNet-lr4-robust-capybara-101-epoch=45-step=15547.ckpt"

    print(f"Loading checkpoint {args.checkpoint}")
    main(args)

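# Utils

`ImageVisCallback`: logs low-res / predicted / high-res image mosaics to W&B after each validation run (used by the training cell).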

In [None]:

import os
import torch
import wandb
import pytorch_lightning as pl


class ImageVisCallback(pl.Callback):
    """Log LR input / prediction / HR truth mosaics to W&B after validation.

    Args:
        val_Dataloader: loader yielding dicts with "x" (low-res input) and
            "y" (high-res target) tensors; only the first sample of each
            batch is visualised.
        max_samples: maximum number of batches visualised per validation run.
    """

    def __init__(self, val_Dataloader, max_samples=10):
        super().__init__()

        self.valLoader = val_Dataloader
        self.max_samples = max_samples

    def on_validation_end(self, trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
        """Run the model on up to ``max_samples`` validation batches and log mosaics."""
        # zip against range(...) stops after max_samples batches, or earlier
        # when the loader is shorter, so StopIteration never escapes this hook.
        # (DataLoader iterators have no .next() method in current PyTorch.)
        for i, test in zip(range(self.max_samples), self.valLoader):
            # first sample of the batch, re-batched to size 1, restricted to
            # the channels the model predicts
            imgs = test['x'][0].to(device=model.device).unsqueeze(0)
            imgs = imgs[:, model.output_channels, :, :]

            imgsY = test['y'][0].to(device=model.device).unsqueeze(0)
            imgsY = imgsY[:, model.output_channels, :, :]

            # inference only: no autograd graph needed
            with torch.no_grad():
                upresed = model(imgs)

            # stack LR input, prediction and HR truth along the height axis
            mosaics = torch.cat([imgs, upresed, imgsY], dim=-2)
            caption = "Image {}: Top: Low Res, Middle: High Res Prediction, Bottom: High Res Truth".format(i)

            # raw string: same value as before, without the invalid "\e" escape
            logname = "val/examples{}".format(i) if os.name != "nt" else r"val\examples{}".format(i)
            trainer.logger.experiment.log({
                # caption must be a keyword: wandb.Image's 2nd positional parameter is `mode`
                logname: [wandb.Image(mosaic, caption=caption) for mosaic in mosaics],
            })

        trainer.logger.experiment.log({
            "global_step": trainer.global_step  # This will make sure wandb gets the epoch/step right
        })

# 6. Evaluate

In [None]:

import argparse
import numpy as np
import pytorch_lightning as pl
import matplotlib.pyplot as plt
from matplotlib import colors
# from data import ERA5DataModule
# from models import SRCNN, VDSR, SRResNet, Nearest, Bilinear, Bicubic


def visualize_preds(dl, models, suptitle="Predictions", file_name="preds", save=True):
    """Plot LR input, each model's prediction and HR truth for five fixed patches.

    Each patch ("WA", "MI", "UT", "PA", "TX") gets a (t2m, tp) column pair.
    Rows are: LR input, one row per entry of ``models`` (dict order), HR truth.
    All arrays are mapped back with the dataset's ``stds``/``mus``, and all t2m
    (resp. tp) panels share one colour normalisation and colourbar.

    Args:
        dl: DataLoader whose ``dataset`` provides ``_get_batch``, ``stds`` and ``mus``.
        models: mapping of row label -> callable model applied to ``batch["x"]``.
        suptitle: figure title.
        file_name: output name under ``figs/``; a trailing ".png" is optional.
        save: write the figure to ``figs/`` before showing it.

    Raises:
        ValueError: if the dataset yields fewer than four batches from month 108.
    """
    import itertools
    import os

    dataset = dl.dataset

    def denormalize(arr):
        # undo the dataset's (presumed) standardisation: arr * std + mean
        return arr * dataset.stds + dataset.mus

    # NOTE(review): assumes the 4th batch yielded from start_month=108 with
    # years_per_batch=1 is the month named in the callers' "(4/2020)" titles;
    # with the default 2010-2020 test split month 108 falls in 2019 — confirm
    # against ERA5Data._get_batch.
    # islice stops the generator after that batch instead of materialising all of them.
    batches = dataset._get_batch(years_per_batch=1, start_month=108)
    batch = next(itertools.islice(batches, 3, None), None)
    if batch is None:
        raise ValueError("Dataset yielded fewer than 4 batches starting at month 108.")
    x = denormalize(batch["x"].detach().cpu().numpy())
    y = denormalize(batch["y"].detach().cpu().numpy())
    y_hats = {
        model_name: denormalize(model(batch["x"]).detach().cpu().numpy())
        for model_name, model in models.items()
    }

    # extract patch indices we care about
    patches = {
        "WA": 0,
        "MI": 6,
        "UT": 10,
        "PA": 15,
        "TX": 25
    }

    fig = plt.figure(figsize=(20, 2*len(y_hats) + 4))
    # 10 image columns (5 patches x 2 variables) + 2 narrow colourbar columns
    gs = fig.add_gridspec(len(y_hats) + 2, 12, width_ratios=[1]*10 + [0.1]*2)

    # shared colour scale per variable across inputs, truths and all predictions
    t2m_mins, t2m_means, t2m_maxs = [], [], []
    tp_mins, tp_means, tp_maxs = [], [], []
    for z in [x, y] + list(y_hats.values()):
        t2m_mins.append(z[:, 0, :, :].min())
        t2m_means.append(z[:, 0, :, :].mean())
        t2m_maxs.append(z[:, 0, :, :].max())
        tp_mins.append(z[:, 1, :, :].min())
        tp_means.append(z[:, 1, :, :].mean())
        tp_maxs.append(z[:, 1, :, :].max())
    divnorm_t2m = colors.TwoSlopeNorm(vmin=min(t2m_mins), vcenter=np.mean(t2m_means), vmax=max(t2m_maxs))
    divnorm_tp = colors.TwoSlopeNorm(vmin=min(tp_mins), vcenter=np.mean(tp_means), vmax=max(tp_maxs))

    # plot images
    all_axes = []
    for i, (state, patch) in enumerate(patches.items()):
        x_t2m, y_t2m = x[patch, 0], y[patch, 0]
        y_hats_t2m = [y_hat[patch, 0] for y_hat in y_hats.values()]
        x_tp, y_tp = x[patch, 1], y[patch, 1]
        y_hats_tp = [y_hat[patch, 1] for y_hat in y_hats.values()]

        # t2m
        imgs = [x_t2m] + y_hats_t2m + [y_t2m]
        axes = []
        for j in range(len(imgs)):
            ax = fig.add_subplot(gs[j, 2*i])
            ax.set_xticks([])
            ax.set_yticks([])
            mapping_t2m = ax.imshow(imgs[j], norm=divnorm_t2m, cmap="RdYlBu_r")
            axes.append(ax)
        all_axes.append(axes)

        # tp
        imgs = [x_tp] + y_hats_tp + [y_tp]
        axes = []
        for j in range(len(imgs)):
            ax = fig.add_subplot(gs[j, 2*i + 1])
            ax.set_xticks([])
            ax.set_yticks([])
            mapping_tp = ax.imshow(imgs[j], norm=divnorm_tp, cmap="BrBG")
            axes.append(ax)
        all_axes.append(axes)

    # set up titles (transpose so all_axes[row, column])
    all_axes = np.array(all_axes).T
    all_axes[0, 0].set_ylabel("LR Input")
    for i, model_name in enumerate(models.keys()):
        all_axes[i+1, 0].set_ylabel(model_name)
    all_axes[-1, 0].set_ylabel("HR Truth")

    variables = ["T2M", "TP"]
    for i, (state, patch) in enumerate(patches.items()):
        all_axes[0, 2*i].set_title(f"{state} ({variables[0]})")
        all_axes[0, 2*i+1].set_title(f"{state} ({variables[1]})")

    # set up colorbars
    t2m_cax = fig.add_subplot(gs[:, -2])
    tp_cax = fig.add_subplot(gs[:, -1])
    fig.colorbar(mapping_t2m, cax=t2m_cax)
    fig.colorbar(mapping_tp, cax=tp_cax)

    plt.suptitle(suptitle)
    plt.tight_layout()
    if save:
        # accept names with or without ".png" so callers can't produce "x.png.png"
        stem = file_name[:-len(".png")] if file_name.endswith(".png") else file_name
        # create the output folder on first use instead of failing inside savefig
        os.makedirs("figs", exist_ok=True)
        plt.savefig(os.path.join("figs", f"{stem}.png"))
    plt.show()



if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser = pl.Trainer.add_argparse_args(parser)
    parser.add_argument('--batch_size', default=128, type=int, help="Batch size to train with")
    parser.add_argument('--pool_size', default=4, type=int, help="Super-resolution factor")
    parser.add_argument('--patch_size', default=64, type=int, help="Image patch size to super-resolve")
    # parse_known_args (as in the train/test cells): parse_args aborts on the
    # extra arguments the Jupyter kernel passes to the process.
    args, _ = parser.parse_known_args()
    args.gpus = 1

    # Interpolation baselines (matched to the loader's factor) plus the learned models.
    models = {
        "Nearest": Nearest(),
        "Bilinear": Bilinear(pool_size=args.pool_size),
        "Bicubic": Bicubic(pool_size=args.pool_size),
        "SRCNN": SRCNN.load_from_checkpoint("cv-proj/SRCNN-lr4-lyric-tree-103-epoch=433-step=73345.ckpt"),
        "VDSR": VDSR.load_from_checkpoint("cv-proj/VDSR-lr4-balmy-sound-102-epoch=86-step=14702.ckpt"),
        "SRResNet": SRResNet.load_from_checkpoint("cv-proj/SRResNet-lr4-robust-capybara-101-epoch=45-step=15547.ckpt")
    }
    e = ERA5DataModule(args={
        "pool_size": args.pool_size,
        "batch_size": args.batch_size,
        "patch_size": args.patch_size
    })
    test_dl = e.test_dataloader()

    # test on the native factor (4x by default) for April 2020
    # (file names carry no extension: visualize_preds appends ".png")
    visualize_preds(test_dl, models, suptitle=f"Predictions {args.pool_size}x (4/2020)",
                    file_name=f"preds_2020_4_pool_{args.pool_size}")

    # test also on coarser 8x and 16x factors; the interpolation baselines
    # are rebuilt for each factor, the learned models are reused as-is
    for eval_pool_size in (8, 16):
        eval_dl = e.eval_dataloader(pool_size=eval_pool_size)
        models["Bilinear"] = Bilinear(pool_size=eval_pool_size)
        # previously Bilinear by mistake, so the "Bicubic" row duplicated Bilinear
        models["Bicubic"] = Bicubic(pool_size=eval_pool_size)
        visualize_preds(eval_dl, models, suptitle=f"Predictions {eval_pool_size}x (4/2020)",
                        file_name=f"preds_2020_4_pool_{eval_pool_size}")