In [None]:
# Re-import edited project modules (e.g. climate_learn) before each cell runs,
# so source changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Enable the ipywidgets extension for the classic Notebook UI (`jupyter nbextension`
# applies to Notebook < 7; JupyterLab / Notebook 7 get widget support from the
# jupyterlab_widgets package instead). Presumably needed for widget-based
# progress bars during training — TODO confirm.
!jupyter nbextension enable --py widgetsnbextension

In [None]:
from climate_learn.data import convert_nc2npz

In [None]:
# Convert the 5.625° ERA5 NetCDF files for 2m temperature into sharded .npz
# files, written to a sibling directory with an "_npz" suffix. Year boundaries
# are passed straight to convert_nc2npz (inclusive/exclusive semantics are
# defined there — TODO confirm).
era5_5625_nc_dir = "/home/data/datasets/ClimateLearn/data/weatherbench/era5/5.625"
convert_nc2npz(
    root_dir=era5_5625_nc_dir,
    save_dir=era5_5625_nc_dir + "_npz",
    variables=["2m_temperature"],
    start_train_year=1979,
    start_val_year=2016,
    start_test_year=2017,
    end_year=2018,
    num_shards=16,
)

In [None]:
# Same conversion for the finer 2.8125° ERA5 grid; the output lands in the
# matching "_npz" sibling directory used as the downscaling target below.
era5_28125_nc_dir = "/home/data/datasets/ClimateLearn/data/weatherbench/era5/2.8125"
convert_nc2npz(
    root_dir=era5_28125_nc_dir,
    save_dir=era5_28125_nc_dir + "_npz",
    variables=["2m_temperature"],
    start_train_year=1979,
    start_val_year=2016,
    start_test_year=2017,
    end_year=2018,
    num_shards=16,
)

In [None]:
# Import from the same `src.climate_learn` root as the model, climatology and
# trainer cells. Importing the package under two names (`climate_learn` and
# `src.climate_learn`) loads two independent copies of every module, so this
# IterDataModule would be a different class object from the one the rest of
# the training pipeline is built against.
from src.climate_learn.utils.datetime import Year, Days, Hours
from src.climate_learn.data import IterDataModule

# Forecasting: inputs and targets both come from the 5.625° .npz data.
forecast_data_module = IterDataModule(
    task="forecasting",
    inp_root_dir="/home/data/datasets/ClimateLearn/data/weatherbench/era5/5.625_npz",
    out_root_dir="/home/data/datasets/ClimateLearn/data/weatherbench/era5/5.625_npz",
    in_vars=["2m_temperature"],
    out_vars=["2m_temperature"],
    pred_range=Days(3),  # presumably the forecast lead time — TODO confirm
    subsample=Hours(6),  # presumably the spacing between samples — TODO confirm
    batch_size=128,
    num_workers=1,
)

In [None]:
from src.climate_learn.models import load_model

# Channel counts come from the data module's variable lists, so the network
# always matches whatever variables were configured above.
forecast_model_kwargs = dict(
    in_channels=len(forecast_data_module.hparams.in_vars),
    out_channels=len(forecast_data_module.hparams.out_vars),
    n_blocks=28,  # presumably the number of residual blocks — TODO confirm
)

# Optimizer / LR-schedule settings handed to load_model.
forecast_optim_kwargs = dict(
    lr=1e-4,
    weight_decay=1e-5,
    warmup_epochs=0,
    max_epochs=1,
)

forecast_model_module = load_model(
    name="resnet",
    task="forecasting",
    model_kwargs=forecast_model_kwargs,
    optim_kwargs=forecast_optim_kwargs,
)

In [None]:
# Attach climatology from the forecasting data module to the model module
# (presumably used as a baseline / for climatology-based metrics — TODO confirm).
from src.climate_learn.models import set_climatology
set_climatology(forecast_model_module, forecast_data_module)

In [None]:
from src.climate_learn.training import Trainer, WandbLogger  # WandbLogger is not passed to the Trainer here

# Take the epoch count from forecast_optim_kwargs so the trainer and the
# optimizer/LR schedule configured in load_model always agree on run length.
forecast_trainer = Trainer(
    seed=0,
    accelerator="gpu",
    devices=[7],  # NOTE(review): hard-coded GPU index; adjust for your machine
    precision=16,  # presumably fp16 mixed precision — TODO confirm
    max_epochs=forecast_optim_kwargs["max_epochs"],
)

In [None]:
# Train the forecasting model on the forecasting data module.
forecast_trainer.fit(forecast_model_module, forecast_data_module)

In [None]:
# Evaluate the trained forecasting model with the same data module.
forecast_trainer.test(forecast_model_module, forecast_data_module)

In [None]:
from src.climate_learn.utils.datetime import Year, Days, Hours
from src.climate_learn.data import IterDataModule

# Downscaling: inputs are read from the coarse 5.625° .npz data and targets
# from the finer 2.8125° .npz data.
era5_base_dir = "/home/data/datasets/ClimateLearn/data/weatherbench/era5"
downscale_data_module = IterDataModule(
    task="downscaling",
    inp_root_dir=f"{era5_base_dir}/5.625_npz",
    out_root_dir=f"{era5_base_dir}/2.8125_npz",
    in_vars=["2m_temperature"],
    out_vars=["2m_temperature"],
    batch_size=128,
    num_workers=1,
)

In [None]:
from src.climate_learn.models import load_model

# Network width follows the downscaling data module's variable lists.
downscale_model_kwargs = dict(
    in_channels=len(downscale_data_module.hparams.in_vars),
    out_channels=len(downscale_data_module.hparams.out_vars),
    n_blocks=4,  # presumably the number of residual blocks — TODO confirm
)

# Optimizer / LR-schedule settings; unlike the forecasting run, the optimizer
# is named explicitly here.
downscale_optim_kwargs = dict(
    optimizer="adamw",
    lr=1e-4,
    weight_decay=1e-5,
    warmup_epochs=1,
    max_epochs=5,
)

downscale_model_module = load_model(
    name="resnet",
    task="downscaling",
    model_kwargs=downscale_model_kwargs,
    optim_kwargs=downscale_optim_kwargs,
)

In [None]:
# Attach climatology from the downscaling data module to the downscaling model
# module (presumably for climatology-based baselines/metrics — TODO confirm).
from src.climate_learn.models import set_climatology
set_climatology(downscale_model_module, downscale_data_module)

In [None]:
from src.climate_learn.training import Trainer, WandbLogger  # WandbLogger is not passed to the Trainer here

# Take the epoch count from downscale_optim_kwargs so the trainer runs the full
# schedule the optimizer was configured for. A separately hard-coded value that
# is smaller than the schedule's max_epochs stops training right after warmup.
downscale_trainer = Trainer(
    seed=0,
    accelerator="gpu",
    devices=[7],  # NOTE(review): hard-coded GPU index; adjust for your machine
    precision=16,  # presumably fp16 mixed precision — TODO confirm
    max_epochs=downscale_optim_kwargs["max_epochs"],
)

In [None]:
# Train the downscaling model. The data module is passed positionally, exactly
# like the forecasting `forecast_trainer.fit(...)` call, so the call does not
# depend on the name of Trainer.fit's second parameter.
downscale_trainer.fit(downscale_model_module, downscale_data_module)