# Model Training and Evaluation for Weather Forecasting

ClimateLearn provides a variety of baseline models to perform forecasting and [downscaling](https://uaf-snap.org/how-do-we-do-it/downscaling). In this tutorial, we'll see how to train a [ResNet model](https://en.wikipedia.org/wiki/Residual_neural_network) to do both. This tutorial is intended for use in Google Colab. Before starting, ensure that you are on a GPU runtime.

## Google Colab setup
You might need to restart the runtime after installing ClimateLearn so that your Colab environment uses the correct package versions.

In [None]:
!pip install climate-learn

In [None]:
from google.colab import drive
drive.mount("/content/drive")
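
The introduction asks for a GPU runtime. One minimal way to confirm this before going further is sketched below, assuming PyTorch is importable (it is preinstalled on Colab). `gpu_available` is a hypothetical helper name for illustration, not part of ClimateLearn.

```python
def gpu_available() -> bool:
    """Return True if PyTorch can see a CUDA device, False otherwise."""
    try:
        # Imported lazily so the check still works if PyTorch is missing.
        import torch
    except ImportError:
        return False
    return bool(torch.cuda.is_available())


if gpu_available():
    print("GPU runtime detected.")
else:
    print("No GPU detected. In Colab: Runtime > Change runtime type.")
```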

## Data preparation
The second cell of this section, which downloads the data, can be skipped if the data is already in your Drive. See the "1a-Data Processing" notebook for more details.

In [None]:
root = "/content/drive/MyDrive/ClimateLearn"
source = "weatherbench"
dataset = "era5"
resolution = "5.625"
variable = "2m_temperature"
years = range(1979, 2018)

In [None]:
from climate_learn.data import download
download(root=root, source=source, dataset=dataset, resolution=resolution, variable=variable)

In [None]:
from climate_learn.data import DataModuleArgs, DataModule
from climate_learn.data.climate_dataset.args import ERA5Args
from climate_learn.data.tasks.args import ForecastingArgs

data_args = ERA5Args(
    root_dir=f"{root}/data/{source}/{dataset}/{resolution}/",
    variables=[variable],
    years=years
)

forecasting_args = ForecastingArgs(
    dataset_args=data_args,
    in_vars=[variable],
    out_vars=[variable],
    pred_range=3*24,
    subsample=6
)

data_module_args = DataModuleArgs(
    task_args=forecasting_args,
    train_start_year=1979,
    val_start_year=2015,
    test_start_year=2017,
    end_year=2018
)

data_module = DataModule(
    data_module_args=data_module_args,
    batch_size=128,
    num_workers=1
)
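
To make the arguments above more concrete, here is a rough sketch of how the year boundaries and the forecasting parameters might translate into data splits and input/target pairs. It rests on several assumptions that the cell above does not confirm: the year boundaries are half-open ranges, the data is hourly so that `pred_range=3*24` means a 3-day lead time, and `subsample=6` keeps every sixth time step. `split_years` and `forecast_pairs` are hypothetical helpers, not ClimateLearn functions.

```python
def split_years(train_start, val_start, test_start, end):
    """Split years into train/val/test, assuming half-open [start, next_start) ranges."""
    return {
        "train": list(range(train_start, val_start)),
        "val": list(range(val_start, test_start)),
        "test": list(range(test_start, end)),
    }


def forecast_pairs(n_steps, pred_range, subsample):
    """Pair each kept input index t with the target index t + pred_range.

    Assumes hourly time steps, a lead time of `pred_range` steps, and that
    `subsample` keeps every `subsample`-th input step.
    """
    return [(t, t + pred_range) for t in range(0, n_steps - pred_range, subsample)]


splits = split_years(1979, 2015, 2017, 2018)
print({name: (yrs[0], yrs[-1], len(yrs)) for name, yrs in splits.items()})

# 100 hourly steps is an arbitrary toy length chosen for illustration.
pairs = forecast_pairs(n_steps=100, pred_range=3 * 24, subsample=6)
print(pairs)
```

Under these assumptions, 1979 through 2014 would be used for training, 2015 and 2016 for validation, and 2017 for testing.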

## Model initialization
First, we load the ResNet model. Both the input and output channel counts are 1 because the only variable we use is `2m_temperature`.

In [None]:
from climate_learn.models import load_model
from torch.optim import AdamW

model_kwargs = {
    "in_channels": 1,
    "out_channels": 1,
    "n_blocks": 4
}

optim_kwargs = {
    "lr": 1e-4,
    "weight_decay": 1e-5,
    "warmup_epochs": 1,
    "max_epochs": 5,
    "optimizer": AdamW
}

model_module = load_model(
    name="resnet",
    task="forecasting",
    model_kwargs=model_kwargs,
    optim_kwargs=optim_kwargs
)

Instead of hard-coding the channel counts, we could have derived them from the data module, although the attribute path is tedious to write out:
```python
"in_channels": len(data_module.hparams.data_module_args.train_task_args.in_vars),
"out_channels": len(data_module.hparams.data_module_args.train_task_args.out_vars),
```
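
The `warmup_epochs` and `max_epochs` entries in `optim_kwargs` suggest a learning rate schedule with a warmup phase. A common schedule with these two parameters is linear warmup followed by cosine decay. Whether ClimateLearn uses exactly this schedule is an assumption, and `lr_at_epoch` is a hypothetical function written only to illustrate the idea.

```python
import math


def lr_at_epoch(epoch, base_lr=1e-4, warmup_epochs=1, max_epochs=5, min_lr=0.0):
    """Learning rate under linear warmup followed by cosine decay (assumed schedule)."""
    if epoch < warmup_epochs:
        # Ramp linearly from base_lr / warmup_epochs up to base_lr.
        return base_lr * (epoch + 1) / warmup_epochs
    # Fraction of the post-warmup phase that has elapsed, in [0, 1].
    progress = (epoch - warmup_epochs) / max(1, max_epochs - warmup_epochs)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


schedule = [lr_at_epoch(epoch) for epoch in range(5)]
for epoch, lr in enumerate(schedule):
    print(f"epoch {epoch}: lr = {lr:.2e}")
```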

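Next, we set the [climatology](https://www.sciencedirect.com/topics/earth-and-planetary-sciences/climatology), which is the average value of each variable over the data period. It serves as the reference for the climatological forecast and for anomaly-based metrics during evaluation.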

In [None]:
from climate_learn.models import set_climatology
set_climatology(model_module, data_module)
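
As a rough illustration of what a climatology is, the sketch below averages a toy sequence of 2D fields over time at each grid point. It assumes the average is taken along the time axis only. Which data split `set_climatology` averages over is not shown in this tutorial, and the `climatology` function and the toy values here are made up for illustration.

```python
def climatology(fields):
    """Average a sequence of 2D fields over time, grid point by grid point.

    `fields` is a list of time steps; each time step is a list of rows
    (latitudes), and each row is a list of values (longitudes).
    """
    n_times = len(fields)
    n_lat, n_lon = len(fields[0]), len(fields[0][0])
    return [
        [sum(fields[t][i][j] for t in range(n_times)) / n_times for j in range(n_lon)]
        for i in range(n_lat)
    ]


# Toy 2m_temperature data in kelvin: 3 time steps on a 2 x 2 grid (made-up values).
toy = [
    [[280.0, 281.0], [290.0, 291.0]],
    [[282.0, 283.0], [292.0, 293.0]],
    [[284.0, 285.0], [294.0, 295.0]],
]
clim = climatology(toy)
print(clim)  # [[282.0, 283.0], [292.0, 293.0]]
```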

## Training

First, we fit a linear regression baseline, which the ResNet model is compared against during evaluation.

In [None]:
from climate_learn.models import fit_lin_reg_baseline
fit_lin_reg_baseline(model_module, data_module, reg_hparam=0.0)
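
The sketch below illustrates the idea behind such a baseline on a single feature. It assumes that `reg_hparam` acts as an L2 (ridge) penalty, so that `reg_hparam=0.0` reduces to ordinary least squares. That interpretation, the `fit_ridge_1d` helper, and the toy numbers are assumptions for illustration and do not reflect how ClimateLearn fits its baseline internally.

```python
def fit_ridge_1d(x, y, reg=0.0):
    """Fit y ~ w * x + b with an L2 penalty `reg` on w, using the closed form."""
    n = len(x)
    mean_x, mean_y = sum(x) / n, sum(y) / n
    cov_xy = sum((xi - mean_x) * (yi - mean_y) for xi, yi in zip(x, y))
    var_x = sum((xi - mean_x) ** 2 for xi in x)
    w = cov_xy / (var_x + reg)  # the penalty shrinks the slope toward zero
    b = mean_y - w * mean_x     # the intercept is not penalized
    return w, b


# Made-up pairs: temperature now (x) and temperature 3 days later (y), in kelvin.
x = [280.0, 282.0, 284.0, 286.0]
y = [281.0, 282.0, 283.0, 284.0]

w_ols, b_ols = fit_ridge_1d(x, y, reg=0.0)
w_ridge, b_ridge = fit_ridge_1d(x, y, reg=10.0)
print(f"OLS:   w = {w_ols:.3f}, b = {b_ols:.3f}")
print(f"Ridge: w = {w_ridge:.3f}, b = {b_ridge:.3f}")
```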

Then, we fit the ResNet model. To use the GPU accelerator on Google Colab, you must be on a GPU runtime.

In [None]:
from climate_learn.training import Trainer

trainer = Trainer(
    seed=0,
    accelerator="gpu",
    precision=16,
    max_epochs=1
)

trainer.fit(model_module, data_module)

## Evaluation 

In [None]:
trainer.test(model_module, data_module)

Ideally, the model's predictions correlate strongly with the ground truth, which is indicated by an [anomaly correlation coefficient](https://climatelearn.readthedocs.io/en/latest/user-guide/metrics.html#anomaly-correlation-coefficient) (ACC) close to 1. We also want the model to achieve a lower [latitude-weighted root mean square error](https://climatelearn.readthedocs.io/en/latest/user-guide/metrics.html) (RMSE) than the climatological forecast and the linear regression baseline.
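
For intuition, both metrics can be sketched in plain Python on a single 2D field. This is a simplified illustration under stated assumptions: the weights are proportional to the cosine of latitude and normalized to have mean 1, and anomalies are taken relative to the climatology without any further mean removal. The function names and toy values are hypothetical, and ClimateLearn's own implementation may differ in details.

```python
import math


def lat_weights(lats_deg):
    """cos(latitude) weights, normalized to have mean 1."""
    cos_lat = [math.cos(math.radians(lat)) for lat in lats_deg]
    mean_cos = sum(cos_lat) / len(cos_lat)
    return [c / mean_cos for c in cos_lat]


def lat_weighted_rmse(pred, truth, lats_deg):
    """Square root of the latitude-weighted mean squared error over one 2D field."""
    w = lat_weights(lats_deg)
    total, count = 0.0, 0
    for i, (p_row, t_row) in enumerate(zip(pred, truth)):
        for p, t in zip(p_row, t_row):
            total += w[i] * (p - t) ** 2
            count += 1
    return math.sqrt(total / count)


def lat_weighted_acc(pred, truth, clim, lats_deg):
    """Latitude-weighted correlation between predicted and true anomalies."""
    w = lat_weights(lats_deg)
    num = pred_sq = truth_sq = 0.0
    for i in range(len(pred)):
        for j in range(len(pred[i])):
            pa = pred[i][j] - clim[i][j]   # predicted anomaly
            ta = truth[i][j] - clim[i][j]  # true anomaly
            num += w[i] * pa * ta
            pred_sq += w[i] * pa ** 2
            truth_sq += w[i] * ta ** 2
    return num / math.sqrt(pred_sq * truth_sq)


# Toy 2 x 2 fields in kelvin on two latitude rows (made-up values).
lats = [-45.0, 45.0]
clim = [[280.0, 280.0], [290.0, 290.0]]
truth = [[282.0, 278.0], [291.0, 289.0]]
pred = [[281.0, 279.0], [290.5, 289.5]]  # same anomaly pattern, half the size

rmse_model = lat_weighted_rmse(pred, truth, lats)
rmse_clim = lat_weighted_rmse(clim, truth, lats)  # climatological forecast
acc_model = lat_weighted_acc(pred, truth, clim, lats)
print(f"model RMSE = {rmse_model:.3f}, climatology RMSE = {rmse_clim:.3f}")
print(f"model ACC  = {acc_model:.3f}")
```

In this toy case the forecast has the right anomaly pattern at half the amplitude, so its ACC is at the maximum while its RMSE is half that of the climatological forecast. The ACC of the climatological forecast itself is undefined in this sketch because its anomalies are all zero.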