In [1]:
import xarray as xr
import numpy as np
import os

In [2]:
import lightning as L

In [3]:
from torch.utils.data import Dataset, DataLoader

In [4]:
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear
from torch.nn import functional as F
import torch

Contents
---

- Dataset: initialize with lazy loading, load data only when needed.
- DataLoader: collects samples from the dataset into batches; tune the `num_workers` keyword for optimal performance
- Model

In [5]:
class DemoDataset(Dataset):
    '''
    Lazily loaded ERA5 2m temperature dataset yielding (field, month) pairs.

    The GRIB files are only opened in __init__; the values of a time step are
    read from disk when that sample is requested through __getitem__.
    '''

    # Row length used to fold each flat t2m field into a 2D array.
    # NOTE: the input size of DemoModel.fc1 depends on the resulting 2D shape.
    GRID_WIDTH = 640

    def __init__(self, resolution='1H', start_time='1960-01-01', end_time='2000-01-01'):
        '''
        Initializes the dataset.

        Note that data is not yet loaded into memory here.

        Arguments:
        resolution (str): Temporal resolution, used as sub-directory of the data root, e.g. 1D or 1H (default: 1H)
        start_time (str): Beginning of time slice (default: 1960-01-01)
        end_time (str): End of time slice (default: 2000-01-01)
        '''
        self.resolution = resolution
        self.data_root = f'/pool/data/ERA5/E5/sf/an/{self.resolution}/167/'

        # lazy loading: all GRIB files are combined into one dataset;
        # an empty indexpath keeps cfgrib from writing index files next to the data
        grib_pattern = os.path.join(self.data_root, '*.grb')
        self.dataset = xr.open_mfdataset(grib_pattern, engine='cfgrib', backend_kwargs={'indexpath':''})

        # time slice
        time_window = slice(np.datetime64(start_time), np.datetime64(end_time))
        self.dataset = self.dataset.sel(time=time_window)

        # 4 bytes per value, i.e. the size estimate assumes float32 data
        size_gb = np.prod(self.dataset["t2m"].shape) * 4 / 1024**3
        print(f'Size of variable t2m (float32) with resolution {self.resolution}: {size_gb:.2f} GB')
        print(f'Total number of samples: {len(self.dataset.time)}')

    def __len__(self):
        '''Returns number of samples (time steps) in the dataset'''
        return len(self.dataset.time)

    def __getitem__(self, idx):
        '''
        Retrieve a (sample, label) pair from the dataset.

        This is where the data is actually loaded into memory.

        Arguments:
        idx (int): index in dataset

        Returns:
        torch.Tensor: sample of shape (1, rows, GRID_WIDTH) (2m temperature, reshaped to lat/lon grid)
        int: label (month index 0 ... 11)
        '''
        # accessing .values actually loads the data of this time step
        field = self.dataset['t2m'][idx].values
        feature = torch.from_numpy(field.reshape(-1, self.GRID_WIDTH))
        feature = feature.unsqueeze(0) # add channel dimension

        # ISO timestamp starts with YYYY-MM, so characters 5:7 hold the month
        timestamp = np.datetime_as_string(self.dataset['valid_time'][idx].values)
        label = int(timestamp[5:7]) - 1 # month 0 ... 11

        return feature, label

In [6]:
class DemoModel(L.LightningModule):
    '''
    Small CNN classifier mapping a single-channel 2D field to 12 logits.

    Four stages of 5x5 convolution + ReLU + 2x2 max pooling are followed by
    three fully connected layers; training uses cross-entropy and Adam.
    '''

    def __init__(self):
        super().__init__()
        # convolutional feature extractor (layers created in forward order)
        self.conv1 = Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.maxpool1 = MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = Conv2d(in_channels=6, out_channels=6, kernel_size=5)
        self.maxpool2 = MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = Conv2d(in_channels=6, out_channels=6, kernel_size=5)
        self.maxpool3 = MaxPool2d(kernel_size=2, stride=2)
        self.conv4 = Conv2d(in_channels=6, out_channels=6, kernel_size=5)
        self.maxpool4 = MaxPool2d(kernel_size=2, stride=2)
        self.flatten = Flatten()
        # classifier head; 10584 = 6 channels * 49 * 36
        # (presumably the conv output for an 847 x 640 input - confirm against the data grid)
        self.fc1 = Linear(in_features=10584, out_features=128)
        self.fc2 = Linear(in_features=128, out_features=64)
        self.fc3 = Linear(in_features=64, out_features=12)

    def forward(self, x):
        '''Return the 12 unnormalized class scores (logits) for a batch x.'''
        conv_stages = (
            (self.conv1, self.maxpool1),
            (self.conv2, self.maxpool2),
            (self.conv3, self.maxpool3),
            (self.conv4, self.maxpool4),
        )
        for conv, pool in conv_stages:
            x = pool(F.relu(conv(x)))

        x = self.flatten(x)
        for hidden in (self.fc1, self.fc2):
            x = F.relu(hidden(x))

        # no activation on the last layer: cross_entropy expects raw logits
        return self.fc3(x)

    def training_step(self, batch, batchidx):
        '''Compute the cross-entropy loss for one batch of (inputs, targets).'''
        inputs, targets = batch[0], batch[1]
        logits = self(inputs)
        return F.cross_entropy(logits, targets)

    def configure_optimizers(self):
        '''Use Adam with learning rate 1e-3 on all model parameters.'''
        return torch.optim.Adam(self.parameters(), lr=1e-3)

        

In [7]:
%%time
# Lazy open of the daily ('1D') files; no t2m values are read here.
dataset = DemoDataset(resolution='1D')

ERROR 1: PROJ: proj_create_from_database: Open of /sw/spack-levante/mambaforge-22.9.0-2-Linux-x86_64-kptncg/share/proj failed


Size of variable t2m (float32) with resolution 1D: 29.50 GB
Total number of samples: 14610
CPU times: user 1min 5s, sys: 29.7 s, total: 1min 35s
Wall time: 3min 22s


Load a single sample directly.

In [8]:
%%time
# Index the dataset directly; this reads a single time step into memory.
dataset[100]

CPU times: user 62.7 ms, sys: 5.42 ms, total: 68.1 ms
Wall time: 67 ms


(tensor([[[244.8626, 244.8450, 244.8274,  ..., 248.8665, 249.1946, 249.0052],
          [248.9583, 248.8606, 248.7669,  ..., 251.5423, 251.3353, 251.3157],
          [251.3470, 250.6204, 250.4720,  ..., 245.1224, 244.7532, 244.3567],
          ...,
          [223.8880, 224.1438, 224.2591,  ..., 236.8587, 236.0130, 235.1341],
          [234.1341, 232.8919, 231.6009,  ..., 231.1907, 229.8196, 228.3802],
          [226.9603, 225.5169, 224.6263,  ..., 226.7220, 226.7239, 226.7395]]]),
 3)

Create a PyTorch DataLoader. It yields batches from the dataset, loading the data on the fly; these batches are then fed into the model.
Try increasing / decreasing the num_workers parameter for optimal performance.

In [9]:
# num_workers=0 loads every batch in the main process; raise it to load batches in parallel worker processes.
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=0)

Create the model and trainer

In [10]:
# Instantiate the CNN classifier with freshly initialized weights.
model = DemoModel()

In [11]:
# Take the feature of the first sample and add a leading batch dimension.
x = dataset[0][0].unsqueeze(0)

In [12]:
%%time
# Forward pass on the single-sample batch as a quick check before training.
model(x)

CPU times: user 22.3 s, sys: 84.2 ms, total: 22.3 s
Wall time: 199 ms


tensor([[ 0.3440,  0.5120,  0.0274, -0.7017, -0.1596, -0.8801, -0.2594, -1.2573,
          0.8689, -0.7828, -0.3442,  0.5506]], grad_fn=<AddmmBackward0>)

In [13]:
# Limit training to a single epoch.
trainer = L.Trainer(max_epochs=1)

/home/k/k202141/.local/lib/python3.10/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /sw/spack-levante/mambaforge-22.9.0-2-Linux-x86_64-k ...
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [14]:
# Start training; batches are loaded from disk on the fly by the dataloader.
trainer.fit(model, train_dataloaders=dataloader)

/home/k/k202141/.local/lib/python3.10/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory /work/ka1176/caroline/gitlab/tutorial-large-datasets/lightning_logs/version_12243349/checkpoints exists and is not empty.

   | Name     | Type      | Params | Mode 
------------------------------------------------
0  | conv1    | Conv2d    | 156    | train
1  | maxpool1 | MaxPool2d | 0      | train
2  | conv2    | Conv2d    | 906    | train
3  | maxpool2 | MaxPool2d | 0      | train
4  | conv3    | Conv2d    | 906    | train
5  | maxpool3 | MaxPool2d | 0      | train
6  | conv4    | Conv2d    | 906    | train
7  | maxpool4 | MaxPool2d | 0      | train
8  | flatten  | Flatten   | 0      | train
9  | fc1      | Linear    | 1.4 M  | train
10 | fc2      | Linear    | 8.3 K  | train
11 | fc3      | Linear    | 780    | train
------------------------------------------------
1.4 M     Trainable params
0         Non-trainable params
1.4 M     Total params
5.467     Tot

Training: |          | 0/? [00:00<?, ?it/s]


Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined