# Quickstart

In [None]:
import os

import climate_learn as cl
from climate_learn.data.climate_dataset.args import ERA5Args
from climate_learn.data.task.args import ForecastingArgs
from climate_learn.data.dataset.args import MapDatasetArgs

## Load the data

In [None]:
# Location of the 5.625-degree ERA5 data. Set the CLIMATE_LEARN_ERA5_ROOT
# environment variable to point at your own copy; when it is unset, the
# original default path below is used unchanged.
root = os.environ.get(
    "CLIMATE_LEARN_ERA5_ROOT",
    "/data0/datasets/weatherbench/data/weatherbench/era5/5.625deg",
)
variables = ["2m_temperature"]
# Inputs and outputs are the same "era5:"-prefixed variable names; both
# names are bound to one shared list object.
in_vars = out_vars = [f"era5:{v}" for v in variables]

# Year splits (range end is exclusive): train 2010-2015, validate 2016,
# test 2018.
# NOTE(review): 2017 is in none of the splits — confirm the gap is
# intentional.
train_years = range(2010, 2016)
val_years = range(2016, 2017)
test_years = range(2018, 2019)

# pred_range is 3 * 24 steps; presumably hourly data, i.e. a 3-day lead
# time — TODO confirm the dataset's time resolution.
forecasting_args = ForecastingArgs(
    in_vars,
    out_vars,
    pred_range=3 * 24,
    subsample=6,
)

# One dataset spec per split, built in train/val/test order; all three
# share the same variables and forecasting task.
train_dataset_args, val_dataset_args, test_dataset_args = (
    MapDatasetArgs(ERA5Args(root, variables, split_years), forecasting_args)
    for split_years in (train_years, val_years, test_years)
)

dm = cl.data.DataModule(
    train_dataset_args,
    val_dataset_args,
    test_dataset_args,
    batch_size=32,
    num_workers=8,
)

## Load the models

In [None]:
# Load three forecasting modules configured from the data module:
#   "climatology"       - the average value over the training period
#   "persistence"       - returns its input as its prediction
#   "rasp-theurey-2020" - the ResNet preset after Rasp & Thuerey (2020)
climatology, persistence, resnet = (
    cl.load_forecasting_module(data_module=dm, preset=preset_name)
    for preset_name in ("climatology", "persistence", "rasp-theurey-2020")
)

## Train the models

Only the ResNet is trained below; the climatology and persistence baselines don't require training.

In [None]:
# Index of the GPU to train on. It only takes effect once the GPU options in
# the Trainer call below are uncommented.
gpu_num = 0

trainer = cl.Trainer(
    early_stopping="lat_rmse:aggregate",  # latitude-weighted RMSE (validation)
    patience=1,  # stop after 1 epoch with no improvement
    # For GPU acceleration, uncomment both of the following lines:
    # accelerator="gpu",
    # devices=[gpu_num],
    max_epochs=2,  # upper bound on training length
)

In [None]:
# Train the ResNet on the data module built above; the climatology and
# persistence baselines are not trained.
trainer.fit(resnet, dm)

## Test the models

In [None]:
# Evaluate the climatology baseline on the data module's test split.
trainer.test(climatology, dm)

In [None]:
# Evaluate the persistence baseline on the data module's test split.
trainer.test(persistence, dm)

In [None]:
# Evaluate the trained ResNet on the data module's test split.
trainer.test(resnet, dm)