# CNN training example

This notebook illustrates how to use the EDIT pipeline to train a simple CNN model on the ERA5 low-resolution dataset.

Make sure to set the `ERA5LOWRES` environment variable so that the ERA5 low-resolution archive can be found on your system.
Modify the cell below to match your site:

- for NCI

```
%env ERA5LOWRES=/g/data/wb00/NCI-Weatherbench/5.625deg
```

- for NIWA

```
%env ERA5LOWRES=/nesi/nobackup/niwa00004/riom/weatherbench/5.625deg
```

In [None]:
# Location of the ERA5 low-resolution archive (NIWA path; change it for your site).
%env ERA5LOWRES=/nesi/nobackup/niwa00004/riom/weatherbench/5.625deg

In [None]:
from pathlib import Path

import numpy as np
import xarray as xr
import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from lightning import Trainer, LightningModule

import edit.data
import edit.tutorial  # NOQA
import edit.pipeline
import edit.training

In [None]:
# Folder in which the normalisation statistics are saved.
train_folder = "cnn_training"
# Optional folder for caching the normalised pipeline output; None disables caching.
cache_folder = None

# Training period (the trailing comment gives a full-year alternative).
train_start = "2015-01-01T00"
train_end = "2015-01-12T00"  # "2015-12-31T00"
# Validation period.
val_start = "2017-01-01T00"
val_end = "2017-01-12T00"

# Number of training samples used to estimate the mean and standard deviation.
n_samples = 200
# Batch size and number of workers handed to the Lightning data module.
batch_size = 1
n_workers = 4

In [None]:
# Turn the folder name into a Path and create the folder (with any missing
# parents); an already existing folder is not an error.
train_folder = Path(train_folder)
train_folder.mkdir(parents=True, exist_ok=True)

In [None]:
# Data preparation pipeline: load ERA5 low-resolution variables, reorder and
# reshape them, retrieve neighbouring time steps and convert to NumPy arrays.
data_preparation = edit.pipeline.Pipeline(
    edit.data.archive.era5lowres(["u", "v", "geopotential", "vorticity"]),
    # NOTE(review): this order lists "msl", "10u", "10v" and "2t", which are not
    # requested from the archive above, and omits "u" and "v" - confirm that
    # this is intentional.
    edit.pipeline.operations.xarray.Sort(
        ["msl", "10u", "10v", "2t", "geopotential", "vorticity"]
    ),
    # use the 0-360 longitude convention
    edit.data.transforms.coordinates.standard_longitude(type="0-360"),
    # presumably folds the "level" coordinate into the variables - TODO confirm
    edit.pipeline.operations.xarray.reshape.CoordinateFlatten("level"),
    # retrieve previous/next samples, dt = 1H
    # NOTE(review): check the `samples` specification against the
    # TemporalRetrieval documentation; its semantics are not visible here.
    edit.pipeline.modifications.TemporalRetrieval(
        concat=True, samples=((-1, 1), (1, 1, 1))
    ),
    edit.pipeline.operations.xarray.conversion.ToNumpy(),
    # move the time axis in front of the channel axis
    edit.pipeline.operations.numpy.reshape.Rearrange("c t h w -> t c h w"),
    # presumably removes the leading (time) axis - TODO confirm
    edit.pipeline.operations.numpy.reshape.Squish(axis=0),
)
# Display the pipeline.
data_preparation

In [None]:
# Fetch the sample for the first training timestamp and inspect it: the number
# of elements it holds and the array shapes of the first two.
sample = data_preparation[train_start]
print(len(sample))
print(sample[0].shape)
print(sample[1].shape)

In [None]:
# Hourly timestamps over the validation period (not shuffled) ...
val_split = edit.pipeline.iterators.DateRange(val_start, val_end, interval="1h")
# ... and over the training period, shuffled with a fixed seed.
train_split = edit.pipeline.iterators.DateRange(
    train_start, train_end, interval="1h"
).randomise(seed=42)

In [None]:
# Display the first five entries of the shuffled training split.
train_split[:5]

Let's precompute an approximate mean and standard deviation using only a few random samples, in order to rescale the input/output data to a reasonable range for model training.

In [None]:
# Stack the first element of the first `n_samples` entries of the shuffled
# training split along a new leading axis.
samples = np.stack(
    [data_preparation[train_split[idx]][0] for idx in range(n_samples)]
)

# Reduce over the sample axis only, so both statistics keep the shape of a
# single stacked array.
mean_approx = samples.mean(axis=0)
std_approx = samples.std(axis=0)

# Save the statistics to the training folder; the normalisation step is given
# these file paths.
mean_path, std_path = train_folder / "mean.npy", train_folder / "std.npy"
np.save(mean_path, mean_approx)
np.save(std_path, std_approx)

In [None]:
# Normalisation step configured with the saved mean and standard deviation files.
normaliser = edit.pipeline.operations.numpy.normalisation.Deviation(
    mean=mean_path, deviation=std_path
)
# Append the normalisation step to the data preparation pipeline.
data_preparation_normed = edit.pipeline.Pipeline(data_preparation, normaliser)

# Only when a cache folder is configured: wrap the normalised pipeline with a
# cache step (files use the "npy" extension).
if cache_folder is not None:
    data_preparation_normed = edit.pipeline.Pipeline(
        data_preparation_normed,
        edit.pipeline.modifications.Cache(
            cache_folder, pattern_kwargs={'extension': 'npy'}
        ),
    )

In [None]:
# Display the normalised pipeline.
data_preparation_normed

In [None]:
# Lightning data module built on the normalised pipeline, using the training
# and validation timestamp splits and the configured batch size and number of
# workers.
data_module = edit.training.data.lightning.PipelineLightningDataModule(
    data_preparation_normed,
    train_split=train_split,
    valid_split=val_split,
    batch_size=batch_size,
    num_workers=n_workers,
)

In [None]:
# Display the data module.
data_module

In [None]:
class CNN(LightningModule):
    """Fully convolutional network whose output has as many channels as its input.

    Each hidden layer is a ``Conv2d`` (kernel size 3, stride 1, padding 1)
    followed by ``ReLU`` and ``Dropout``. A final ``Conv2d`` with the same
    kernel settings maps back to ``n_features`` channels. Training uses the L1
    loss and the Adam optimiser.

    Args:
        n_features: Number of channels of the input and of the output.
        layer_sizes: Number of output channels of each hidden convolution.
        dropout: Probability used by the dropout layer of each hidden block.
        learning_rate: Learning rate passed to Adam.
    """

    def __init__(
        self,
        *,
        n_features: int,
        layer_sizes: list[int],
        dropout: float,
        learning_rate: float,
    ):
        super().__init__()

        # Channel count at every stage, starting with the input channels.
        widths = [n_features, *layer_sizes]

        modules = []
        for width_in, width_out in zip(widths, widths[1:]):
            modules += [
                nn.Conv2d(width_in, width_out, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.Dropout(p=dropout),
            ]
        # Output projection: no activation or dropout after the last convolution.
        modules.append(
            nn.Conv2d(widths[-1], n_features, kernel_size=3, stride=1, padding=1)
        )
        self.cnn = nn.Sequential(*modules)

        self.learning_rate = learning_rate
        self.loss_function = F.l1_loss

    def forward(self, x):
        """Apply the convolutional stack to ``x`` and return the result."""
        return self.cnn(x)

    def _loss_and_log(self, batch, log_name):
        """Compute the loss of an ``(inputs, targets)`` batch and log it as ``log_name``."""
        inputs, targets = batch
        loss = self.loss_function(self(inputs), targets)
        self.log(log_name, loss)
        return loss

    def training_step(self, batch, batch_idx):
        """Return the loss of the batch, logged as ``"train_loss"``."""
        return self._loss_and_log(batch, "train_loss")

    def validation_step(self, batch, batch_idx):
        """Log the loss of the batch as ``"val_loss"``; nothing is returned."""
        self._loss_and_log(batch, "val_loss")

    def configure_optimizers(self):
        """Return a dictionary holding an Adam optimiser over all parameters."""
        return {
            "optimizer": torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        }

In [None]:
# Take the number of features from one normalised sample: the size of the
# third axis from the end of its first element.
n_features = data_preparation_normed[train_start][0].shape[-3]
# Two hidden convolution layers with 64 channels each.
model = CNN(
    n_features=n_features, layer_sizes=[64, 64], dropout=0.6, learning_rate=1e-5
)

In [None]:
# Display the model architecture.
model

In [None]:
# Presumably hides all CUDA devices so that the training below runs on the CPU.
# NOTE(review): %env may store the quote characters literally - confirm the
# resulting value has the intended effect.
%env CUDA_VISIBLE_DEVICES=""

In [None]:
# Train the model for a single epoch with the data provided by the data module.
trainer = Trainer(max_epochs=1)
trainer.fit(model, datamodule=data_module)