# Quickstart

In [None]:
# Enable IPython's autoreload extension so edits to imported modules are
# picked up without restarting the kernel.
%load_ext autoreload
# Mode 2: reload modules automatically before each cell is executed.
%autoreload 2

In [None]:
import os

import climate_learn as cl
from climate_learn.data.climate_dataset.args import ERA5Args
from climate_learn.data.task.args import ForecastingArgs
from climate_learn.data.dataset.args import MapDatasetArgs

## Load the data

In [None]:
# Directory holding the WeatherBench ERA5 files (the "5.625deg" folder).
# Set the ERA5_ROOT environment variable to point somewhere else instead of
# editing this cell; the fallback is the path the notebook was written against.
root = os.environ.get(
    "ERA5_ROOT",
    "/data0/datasets/weatherbench/data/weatherbench/era5/5.625deg/",
)

# Variables to load; the same list serves as model input and as target.
# "temperature_850" and "2m_temperature" were listed as alternatives in the
# original version of this cell.
variables = ["geopotential_500"]
# Task-level variable names carry an "era5:" prefix in front of the raw name.
in_vars = out_vars = [f"era5:{v}" for v in variables]

# Year splits. range() excludes its end value, so these cover
# 1979-2015 (train), 2016 (validation) and 2017-2018 (test).
train_years = range(1979, 2016)
val_years = range(2016, 2017)
test_years = range(2017, 2019)

# One forecasting task definition shared by all three splits.
# NOTE(review): pred_range/subsample are presumably expressed in hours or
# time steps - confirm against the ForecastingArgs documentation.
forecasting_args = ForecastingArgs(
    in_vars,
    out_vars,
    pred_range=72,
    subsample=6,
)


def make_dataset_args(years):
    """Return MapDatasetArgs pairing ERA5 data for ``years`` with the shared task.

    Args:
        years: iterable of years passed straight through to ERA5Args.

    Returns:
        A MapDatasetArgs built from ``root``, ``variables`` and
        ``forecasting_args`` as defined in this cell.
    """
    return MapDatasetArgs(ERA5Args(root, variables, years), forecasting_args)


train_dataset_args = make_dataset_args(train_years)
val_dataset_args = make_dataset_args(val_years)
test_dataset_args = make_dataset_args(test_years)

# Data module consumed by every model-loading and trainer call below.
dm = cl.data.DataModule(
    train_dataset_args,
    val_dataset_args,
    test_dataset_args,
    batch_size=32,
    num_workers=8,
)

## Load the models

In [62]:
# Baselines that need no training. They must be created here because the
# "Test the models" section calls trainer.test() on both of them.

# climatology is the average value over the training period
climatology = cl.load_forecasting_module(data_module=dm, preset="climatology")

# persistence returns its input as its prediction
persistence = cl.load_forecasting_module(data_module=dm, preset="persistence")

# The "rasp-theurey-2020" preset is loaded in its own reference section of
# this notebook (as `rasp_theurey`), so it is not loaded again here.

# ViT preset, requesting pretrained backbone and embeddings with both frozen.
# NOTE(review): the recorded run of trainer.fit() on this module failed with
# "AttributeError: 'ViTModel' object has no attribute 'patch_embeddings'";
# presumably related to the embedding options below - confirm before relying
# on this configuration.
vit_pretrained = cl.load_forecasting_module(
    data_module=dm,
    preset="vit",
    use_pretrained_backbone=True,
    use_pretrained_embeddings=True,
    freeze_backbone=True,
    freeze_embeddings=True,
)

Loading preset: vit
Using preset optimizer
Using preset learning rate scheduler
Loading training loss: lat_mse
Loading validation loss: lat_rmse
Loading validation loss: lat_acc
Loading validation loss: lat_mse
Loading test loss: lat_rmse
Loading test loss: lat_acc
Loading validation transform: denormalize
Loading validation transform: denormalize
No validation transform
Loading test transform: denormalize
Loading test transform: denormalize


## Reference: Rasp and Thuerey (2020) training curves

In [None]:
# Reference model built from the "rasp-theurey-2020" preset; it is trained
# further down so its curves can be compared against the other models.
rasp_theurey = cl.load_forecasting_module(
    preset="rasp-theurey-2020",
    data_module=dm,
)

In [None]:
from datetime import datetime

import wandb
from pytorch_lightning.loggers import WandbLogger

# Index of the GPU to train on; change to whichever device you want to use.
gpu_num = 0

# Timestamp (hour-minute-second_day-month-year) appended to the run name.
now = datetime.now()
dt_string = now.strftime("%H-%M-%S_%d-%m-%Y")

# Start a Weights & Biases run, then create the Lightning logger for it.
wandb.init(name=f"Rasp Theurey {dt_string}", project="Climate")
logger = WandbLogger()

trainer = cl.Trainer(
    # Early stopping: watch the latitude-weighted RMSE validation metric and
    # stop after 1 epoch without improvement.
    early_stopping="lat_rmse:aggregate [val]",
    patience=1,
    # Upper bound on training length.
    max_epochs=50,
    # Train on the selected GPU.
    accelerator="gpu",
    devices=[gpu_num],
    # Send metrics to wandb.
    logger=logger,
    # Model summary printing is turned off.
    enable_model_summary=False,
)

In [44]:
# Train the Rasp-Theurey reference model on the data module using the
# trainer configured in the previous cell.
trainer.fit(rasp_theurey, dm)

## Train the models

Climatology and persistence don't require training.

In [63]:
# change this to whatever gpu device you want to use
gpu_num = 0

import wandb
from pytorch_lightning.loggers import WandbLogger
from datetime import datetime

# Set to True to log this run to Weights & Biases. With the default (False)
# no wandb run is started and no `logger` argument is passed to the trainer,
# which matches the previous behaviour of this cell.
use_wandb = False

trainer_kwargs = dict(
    # stop when latitude-weighted RMSE, a validation metric, stops improving
    early_stopping="lat_rmse:aggregate [val]",
    # wait for 1 epoch of no improvement
    patience=1,
    # train on the selected GPU
    accelerator="gpu",
    devices=[gpu_num],
    # max epochs
    max_epochs=50,
    # model summary printing is turned off
    enable_model_summary=False,
)

if use_wandb:
    # Timestamp (hour-minute-second_day-month-year) appended to the run name.
    dt_string = datetime.now().strftime("%H-%M-%S_%d-%m-%Y")
    wandb.init(project="Climate", name=f'VIT Pretrained {dt_string}')
    logger = WandbLogger()
    # Only add the logger when it is wanted, so the trainer's own default
    # applies otherwise.
    trainer_kwargs["logger"] = logger

trainer = cl.Trainer(**trainer_kwargs)

VBox(children=(Label(value='0.008 MB of 0.008 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▅▅▅▅▅▅▅▅▅▅▅▅▅▅█████████████
lat_acc:aggregate [train]_epoch,▁▇█
lat_acc:aggregate [train]_step,▁▆▇▇▇▇▇▇▇▇█▇████████████████████████████
lat_acc:aggregate [val],▁██
lat_acc:era5:geopotential_500 [train]_epoch,▁▇█
lat_acc:era5:geopotential_500 [train]_step,▁▆▇▇▇▇▇▇▇▇█▇████████████████████████████
lat_acc:era5:geopotential_500 [val],▁██
lat_mse:aggregate [train]_epoch,█▂▁
lat_mse:aggregate [train]_step,█▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lat_mse:aggregate [val],█▁▁

0,1
epoch,2.0
lat_acc:aggregate [train]_epoch,0.98427
lat_acc:aggregate [train]_step,0.98631
lat_acc:aggregate [val],0.97901
lat_acc:era5:geopotential_500 [train]_epoch,0.98427
lat_acc:era5:geopotential_500 [train]_step,0.98631
lat_acc:era5:geopotential_500 [val],0.97901
lat_mse:aggregate [train]_epoch,3318860999436.9854
lat_mse:aggregate [train]_step,2984399360284.914
lat_mse:aggregate [val],4533218007517.774


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016668450948782266, max=1.0…

  rank_zero_warn(
Global seed set to 0
  warn("In interactive environment: cannot use DDP spawn strategy")


In [64]:

# Train the pretrained ViT module on the data module.
# NOTE(review): the recorded output of this cell is
# "AttributeError: 'ViTModel' object has no attribute 'patch_embeddings'",
# so the run did not complete as saved - resolve before trusting later results.
trainer.fit(vit_pretrained, dm)

Output()

AttributeError: 'ViTModel' object has no attribute 'patch_embeddings'

## Test the models

In [None]:
# Evaluate the climatology baseline on the test split.
# Requires `climatology` to have been created in the "Load the models" section.
trainer.test(climatology, dm)

In [None]:
# Evaluate the persistence baseline on the test split.
# Requires `persistence` to have been created in the "Load the models" section.
trainer.test(persistence, dm)

In [None]:
# Evaluate the pretrained ViT module on the test split, using the trainer
# configured in the "Train the models" section.
trainer.test(vit_pretrained, dm)