# Model Training and Evaluation

ClimateLearn provides a variety of baseline models for forecasting and [downscaling](https://uaf-snap.org/how-do-we-do-it/downscaling). In this tutorial, we'll train a [ResNet model](https://en.wikipedia.org/wiki/Residual_neural_network) to do both. This tutorial is intended to be run in Google Colab; before starting, make sure you are connected to a GPU runtime.

## Google Colab setup
You may need to restart the runtime after installing ClimateLearn so that Colab picks up the correct package versions.

In [None]:
!pip install climate-learn

In [None]:
from google.colab import drive
drive.mount("/content/drive")

## Forecasting

### Data preparation
The second cell of this section downloads the data; skip it if the data is already in your Drive. See the "Data Processing" notebook for more details.

In [1]:
root = "/content/drive/MyDrive/ClimateLearn"
source = "weatherbench"
dataset = "era5"
resolution = "5.625"
variable = "2m_temperature"
years = range(1979, 2019)  # 1979-2018 inclusive; range() excludes its end point
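The data-module cell below splits the data by year using Python `range` objects, whose end point is excluded. This plain-Python sketch (not part of ClimateLearn; `check_splits` is a hypothetical helper) confirms that the train, validation, and test year ranges used there are disjoint and leave no gaps.

```python
# Year splits used by the data-module cell below; range() excludes its end point.
train_years = range(1979, 2015)  # 1979-2014
val_years = range(2015, 2017)    # 2015-2016
test_years = range(2017, 2019)   # 2017-2018


def check_splits(*splits):
    """Return True if the year ranges are pairwise disjoint and leave no gaps."""
    all_years = [y for split in splits for y in split]
    disjoint = len(all_years) == len(set(all_years))
    contiguous = sorted(all_years) == list(range(min(all_years), max(all_years) + 1))
    return disjoint and contiguous


print(len(train_years), len(val_years), len(test_years))  # 36 2 2
print(check_splits(train_years, val_years, test_years))   # True
```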

In [None]:
from climate_learn.data import download
download(root=root, source=source, dataset=dataset, resolution=resolution, variable=variable)
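Before building the data module, it can help to confirm the data is where the next cell expects it. The sketch below assumes that `download` writes to the same `{root}/data/{source}/{dataset}/{resolution}/` directory that the next cell passes to `ERA5Args` as `root_dir`; `describe` is a hypothetical helper, and the exact file layout inside that directory is not assumed.

```python
from pathlib import Path

# Same values as the configuration cell above.
root = "/content/drive/MyDrive/ClimateLearn"
source = "weatherbench"
dataset = "era5"
resolution = "5.625"

# Directory that ERA5Args is pointed at in the next cell (assumed download target).
data_dir = Path(root) / "data" / source / dataset / resolution


def describe(path: Path) -> str:
    """Report whether the expected data directory exists and list a few entries."""
    if not path.is_dir():
        return f"missing: {path}"
    entries = sorted(p.name for p in path.iterdir())
    return f"found {len(entries)} entries in {path}: {entries[:5]}"


print(describe(data_dir))
```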

In [2]:
from climate_learn.data import DataModule
from climate_learn.data.climate_dataset.args import ERA5Args
from climate_learn.data.dataset.args import MapDatasetArgs
from climate_learn.data.task.args import ForecastingArgs

data_args = ERA5Args(
    root_dir=f"{root}/data/{source}/{dataset}/{resolution}/",
    variables=[variable],
    years=years
)

forecasting_args = ForecastingArgs(
    in_vars=[variable],
    out_vars=[variable],
    pred_range=3*24,
    subsample=6
)

map_dataset_args = MapDatasetArgs(
    climate_dataset_args=data_args,
    task_args=forecasting_args
)

modified_args_for_train_dataset = {
    "climate_dataset_args": {
        "years": range(1979, 2015), "split": "train"
    }
}
train_dataset_args = map_dataset_args.create_copy(modified_args_for_train_dataset)

modified_args_for_val_dataset = {
    "climate_dataset_args": {
        "years": range(2015, 2017), "split": "val"
    }
}
val_dataset_args = map_dataset_args.create_copy(modified_args_for_val_dataset)

modified_args_for_test_dataset = {
    "climate_dataset_args": {
        "years": range(2017, 2019), "split": "test"
    }
}
test_dataset_args = map_dataset_args.create_copy(modified_args_for_test_dataset)

data_module = DataModule(
    train_dataset_args,
    val_dataset_args,
    test_dataset_args,
    batch_size=128,
    num_workers=1
)
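To make `pred_range=3*24` and `subsample=6` concrete, here is an illustrative NumPy sketch of how input/target pairs could be built from an hourly time series: each input at time step `t` is paired with the target at `t + pred_range` (72 hours, i.e. a 3-day lead time), and inputs are taken every `subsample` steps. This is an assumption about the semantics of these arguments, not ClimateLearn's actual implementation; `make_forecast_pairs` is a hypothetical helper.

```python
import numpy as np


def make_forecast_pairs(series, pred_range, subsample):
    """Pair each input time step t with the target at t + pred_range.

    series: array of shape (time, lat, lon), assumed to be hourly.
    Inputs are taken every `subsample` steps; targets lie pred_range steps ahead.
    """
    n_time = series.shape[0]
    input_idx = np.arange(0, n_time - pred_range, subsample)
    target_idx = input_idx + pred_range
    return series[input_idx], series[target_idx], input_idx, target_idx


# Toy hourly data: 10 days on a 32 x 64 grid (the size of the 5.625-degree grid).
hours = 10 * 24
toy = np.random.default_rng(0).standard_normal((hours, 32, 64)).astype(np.float32)

x, y, in_idx, tgt_idx = make_forecast_pairs(toy, pred_range=3 * 24, subsample=6)
print(x.shape, y.shape)  # (28, 32, 64) (28, 32, 64)
print(in_idx[:3], tgt_idx[:3])
```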

### Model initialization
Let's first load a few simple baseline presets to serve as points of comparison.

In [8]:
import climate_learn as cl

climatology = cl.load_forecasting_module(preset="climatology")
persistence = cl.load_forecasting_module(preset="persistence")
linreg = cl.load_forecasting_module(preset="linear-regression")

The linear regression model needs training. Climatology and persistence do not require training.
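Why no training? Under their standard definitions, climatology predicts the time-mean of the training data at each grid point, and persistence predicts that the field at the lead time equals the field at the input time. Neither has parameters to optimize. The NumPy sketch below illustrates these definitions on random data; ClimateLearn's presets may differ in details such as normalization or per-variable handling, and the function names are hypothetical.

```python
import numpy as np


def climatology_forecast(train_targets, n):
    """Predict the per-grid-point mean of the training targets for all n samples."""
    clim = train_targets.mean(axis=0)  # shape (lat, lon)
    return np.broadcast_to(clim, (n, *clim.shape))


def persistence_forecast(inputs):
    """Predict that the field at the lead time equals the field at the input time."""
    return inputs.copy()


rng = np.random.default_rng(0)
train_y = rng.standard_normal((100, 32, 64))  # toy training targets
test_x = rng.standard_normal((10, 32, 64))    # toy test inputs

clim_pred = climatology_forecast(train_y, len(test_x))
pers_pred = persistence_forecast(test_x)
print(clim_pred.shape, pers_pred.shape)  # (10, 32, 64) (10, 32, 64)
```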

In [None]:
trainer = cl.Trainer()
trainer.fit(linreg, data_module)
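Conceptually, a linear-regression baseline maps each flattened input grid to a flattened output grid with one weight matrix and a bias. `trainer.fit` presumably learns these with gradient-based optimization; as a stand-in, the sketch below fits the same kind of model with closed-form least squares on a tiny synthetic grid. It is an illustration under that assumption, not ClimateLearn's implementation, and `fit_linear`/`predict_linear` are hypothetical helpers.

```python
import numpy as np


def fit_linear(x, y):
    """Least-squares fit of y ~ x_flat @ W + b, with each sample flattened to a vector."""
    n = x.shape[0]
    x_flat = x.reshape(n, -1)
    y_flat = y.reshape(n, -1)
    design = np.hstack([x_flat, np.ones((n, 1))])  # append a bias column
    coef, *_ = np.linalg.lstsq(design, y_flat, rcond=None)
    return coef[:-1], coef[-1]  # weights W, bias b


def predict_linear(x, weights, bias, out_shape):
    """Apply the fitted linear map and restore the (lat, lon) grid shape."""
    n = x.shape[0]
    return (x.reshape(n, -1) @ weights + bias).reshape(n, *out_shape)


rng = np.random.default_rng(0)
lat, lon = 4, 8  # tiny toy grid so the fit is fast
n_features = lat * lon
true_w = rng.standard_normal((n_features, n_features)) * 0.1
true_b = rng.standard_normal(n_features)
x = rng.standard_normal((200, lat, lon))
y = (x.reshape(200, -1) @ true_w + true_b).reshape(200, lat, lon)

w, b = fit_linear(x, y)
pred = predict_linear(x, w, b, (lat, lon))
print(np.abs(pred - y).max() < 1e-8)  # True
```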

Now let's see how these do on the test data.

In [None]:
trainer.test(climatology, data_module)

In [None]:
trainer.test(persistence, data_module)

In [None]:
trainer.test(linreg, data_module)

Finally, let's load a more complex model: the ResNet architecture that [Rasp and Theurey (2020)](https://arxiv.org/abs/2008.08626) used for their state-of-the-art result on [WeatherBench](https://github.com/pangeo-data/WeatherBench).

In [None]:
rasp_theurey = cl.load_forecasting_module(preset="rasp-theurey-2020")
trainer.fit(rasp_theurey, data_module)

In [None]:
trainer.test(rasp_theurey, data_module)

Ideally, the model's predictions are strongly correlated with the ground truth, as indicated by a high [anomaly correlation coefficient](https://climatelearn.readthedocs.io/en/latest/user-guide/metrics.html#anomaly-correlation-coefficient) (ACC). We also want the model to achieve a lower [latitude-weighted root mean square error](https://climatelearn.readthedocs.io/en/latest/user-guide/metrics.html) (RMSE) than the climatology baseline.
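For intuition, here is a NumPy sketch of both metrics following the common WeatherBench-style definitions: grid cells are weighted by cos(latitude) normalized to mean 1, RMSE is computed per sample over the grid and then averaged, and ACC correlates forecast and truth anomalies relative to a climatology. ClimateLearn's exact aggregation (for example, the order of averaging over samples) may differ; the function names are hypothetical.

```python
import numpy as np


def lat_weights(lats_deg):
    """cos(latitude) weights normalized to mean 1, shaped (lat, 1) for broadcasting."""
    w = np.cos(np.deg2rad(lats_deg))
    return (w / w.mean())[:, None]


def lat_weighted_rmse(pred, truth, lats_deg):
    """Latitude-weighted RMSE over (lat, lon) per sample, averaged over samples."""
    w = lat_weights(lats_deg)
    per_sample = np.sqrt((w * (pred - truth) ** 2).mean(axis=(-2, -1)))
    return per_sample.mean()


def lat_weighted_acc(pred, truth, clim, lats_deg):
    """Latitude-weighted anomaly correlation coefficient, averaged over samples."""
    w = lat_weights(lats_deg)
    fa, ta = pred - clim, truth - clim  # forecast and truth anomalies
    num = (w * fa * ta).sum(axis=(-2, -1))
    den = np.sqrt((w * fa**2).sum(axis=(-2, -1)) * (w * ta**2).sum(axis=(-2, -1)))
    return (num / den).mean()


lats = np.linspace(-87.1875, 87.1875, 32)  # latitudes of the 5.625-degree grid
rng = np.random.default_rng(0)
truth = rng.standard_normal((8, 32, 64))
clim = np.zeros((32, 64))
good = truth + 0.1 * rng.standard_normal(truth.shape)  # a skillful toy forecast

print(lat_weighted_rmse(good, truth, lats))
print(lat_weighted_acc(good, truth, clim, lats))  # close to 1
print(lat_weighted_rmse(truth, truth, lats))      # 0.0
```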

ClimateLearn also supports more advanced ways of loading forecasting models; see our docs to learn more.