This section mainly tests the peak video memory required to input data of corresponding size under the same batch size for the local weather forecasting model and the global weather forecasting model, and the time required for a forward-backward training.

In [1]:
# Imports
import h5py
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import gc
from numba import cuda
from torch.utils.data import Dataset

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f'Device: {device}')

Device: cuda:0


In [3]:
# Local weather forecasting model with pooling layers
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()


        if isinstance(kernel_size, int):
            padding = kernel_size // 2
        else:
            padding = (kernel_size[0] // 2, kernel_size[1] // 2) # kernel_size[0] is height, kernel_size[1] is width

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, padding=padding)

        self.norm1 = nn.GroupNorm(1, out_channels)
        self.norm2 = nn.GroupNorm(1, out_channels)
        self.activation = nn.ReLU()

    def forward(self, x):
        skip = x
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.activation(x)
        x = self.conv2(x)

        x = self.norm2(x)
        x = self.activation(x)
        x = x + skip
        return x

class ResNetPatch(nn.Module):
    def __init__(self, in_channels=2, out_channels=2, hidden_channels=64, kernel_size=(3, 3), patch_size=68, depth=4):
        super().__init__()
        self.lift = nn.Conv2d(in_channels, hidden_channels, kernel_size=1)

        layers = []
        for i in range(depth):
            layers.append(ResidualBlock(hidden_channels, hidden_channels, kernel_size))
            if i < depth - 1:
                 layers.append(nn.MaxPool2d(kernel_size=2, stride=2))  # Maxpooling
        self.layers = nn.Sequential(*layers)

        final_size = patch_size// (2 ** (depth-1))
        self.flatten = nn.Flatten()
        self.linear = nn.Linear(hidden_channels * final_size * final_size, out_channels * 4 * 4)

        self.out_channels = out_channels

    def forward(self, x):
        x = self.lift(x)
        x = self.layers(x)
        x = self.flatten(x)
        x = self.linear(x)
        x = x.view(x.size(0), self.out_channels, 4, 4)
        return x

In [4]:
model = ResNetPatch(kernel_size=(3, 3), depth=4)
x = torch.randn(8, 2, 68, 68)
y = model(x)
print(y.shape)

torch.Size([8, 2, 4, 4])


In [5]:
# Global weather forecasting model
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()


        if isinstance(kernel_size, int):
            padding = kernel_size // 2
        else:
            padding = (kernel_size[0] // 2, kernel_size[1] // 2) # kernel_size[0] is height, kernel_size[1] is width

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, padding=padding)

        self.norm1 = nn.GroupNorm(1, out_channels)
        self.norm2 = nn.GroupNorm(1, out_channels)
        self.activation = nn.ReLU()

    def forward(self, x):
        skip = x
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.activation(x)
        x = self.conv2(x)
        x = self.norm2(x)
        x = self.activation(x)
        x = x + skip
        return x

class ResNet(nn.Module):
    def __init__(
        self,
        in_channels=2,
        out_channels=2,
        hidden_channels=64,
        kernel_size=(3, 3),
        depth=4
    ):
        super().__init__()


        if isinstance(kernel_size, int):
            pad = kernel_size // 2
        else:
            pad = (kernel_size[0] // 2, kernel_size[1] // 2)

        self.lift = nn.Conv2d(in_channels, hidden_channels, kernel_size=1)

        layers = []
        for _ in range(depth):
            layers.append(ResidualBlock(hidden_channels, hidden_channels, kernel_size))
        self.layers = nn.ModuleList(layers)

        self.proj = nn.Conv2d(hidden_channels, out_channels, kernel_size=1)

    def forward(self, x):
        x = self.lift(x)

        for layer in self.layers:
          x = layer(x)

        x = self.proj(x)
        return x

In [6]:
resnet = ResNet().to(device)

In [7]:
import time

def measure_performance(model, batch_size, input_shape, device='cuda'):
    model = model.to(device)
    model.train() # Generate random tensors, simulate input
    x = torch.randn(batch_size, *input_shape).to(device)

    torch.cuda.reset_peak_memory_stats(device) # Reset memory
    start = time.time()

    # Forward + Backward
    output = model(x)
    loss = output.sum()
    loss.backward()

    elapsed = time.time() - start
    mem = torch.cuda.max_memory_allocated(device) / 1024**2  # MB

    return mem, elapsed


In [8]:
batch_size = 16
full_shape = (2, 256, 512)
patch_shape = (2, 68, 68)

# Full-map model
full_mem, full_time = measure_performance(resnet, batch_size, full_shape, device = 'cuda')

# Local model
patch_mem, patch_time = measure_performance(model, batch_size, patch_shape, device = 'cuda')

# Per-pixel cost
full_pixels = batch_size * full_shape[1] * full_shape[2]
patch_pixels = batch_size * patch_shape[1] * patch_shape[2]

print(f"Full Model:  BS={batch_size}, Mem={full_mem:.2f}MB, Time={full_time:.4f}s, "
      f"Mem/px={(full_mem/full_pixels):.6f}MB, Time/px={(full_time/full_pixels):.6e}s")
print(f"Patch Model: BS={batch_size}, Mem={patch_mem:.2f}MB, Time={patch_time:.4f}s, "
      f"Mem/px={(patch_mem/patch_pixels):.6f}MB, Time/px={(patch_time/patch_pixels):.6e}s")

Full Model:  BS=16, Mem=13089.85MB, Time=1.2967s, Mem/px=0.006242MB, Time/px=6.183335e-07s
Patch Model: BS=16, Mem=212.45MB, Time=0.1014s, Mem/px=0.002872MB, Time/px=1.370491e-06s
