In [1]:
# Enable the ipywidgets front-end extension ("widgetsnbextension") for the classic Jupyter Notebook.
# NOTE(review): presumably needed so widget-based outputs (e.g. the `Output()` repr shown after the
# test cell) render as widgets instead of plain text — confirm. This only changes Jupyter's
# configuration; the browser page likely needs a reload before it takes effect.
!jupyter nbextension enable --py widgetsnbextension

Enabling notebook extension jupyter-js-widgets/extension...
      - Validating: OK


In [2]:
from src.climate_learn.utils.datetime import Year, Days, Hours
from src.climate_learn.data import IterDataModule

# Forecasting setup: inputs and targets are both read from the ERA5 (WeatherBench) 5.625 deg
# npz shards; the single target variable is 2 m temperature.
era5_5625deg_dir = "/data0/datasets/weatherbench/data/weatherbench/era5/5.625deg_npz/"

forecast_in_vars = [
    "2m_temperature",
    "10m_u_component_of_wind",
    "10m_v_component_of_wind",
    "geopotential_500",
    "temperature_850",
]

forecast_data_module = IterDataModule(
    task="forecasting",
    inp_root_dir=era5_5625deg_dir,
    out_root_dir=era5_5625deg_dir,  # same grid for inputs and targets
    in_vars=forecast_in_vars,
    out_vars=["2m_temperature"],
    pred_range=Days(3),   # presumably the forecast lead time — confirm in IterDataModule
    subsample=Hours(6),   # presumably the spacing between consecutive samples — confirm
    batch_size=128,
    num_workers=1,
)

In [3]:
from src.climate_learn.models import load_model

# ResNet sizing: one input channel per input variable and one output channel per target
# variable, both read back from the data module's hyper-parameters.
forecast_model_kwargs = dict(
    in_channels=len(forecast_data_module.hparams.in_vars),
    out_channels=len(forecast_data_module.hparams.out_vars),
    n_blocks=28,
)

# Optimizer settings handed to load_model; warmup_epochs / max_epochs presumably configure
# a learning-rate schedule — confirm in load_model.
forecast_optim_kwargs = dict(
    lr=1e-4,
    weight_decay=1e-5,
    warmup_epochs=0,
    max_epochs=1,
)

forecast_model_module = load_model(
    name="resnet",
    task="forecasting",
    model_kwargs=forecast_model_kwargs,
    optim_kwargs=forecast_optim_kwargs,
)

  if self.prob_type is not "categorical":
  if self.prob_type is not "categorical":


In [4]:
from src.climate_learn.models import set_climatology
# Attach climatology information from the forecasting data module to the forecasting model.
# NOTE(review): presumably used as the reference for climatology-based metrics during
# validation/testing — confirm in set_climatology.
set_climatology(forecast_model_module, forecast_data_module)

In [5]:
from src.climate_learn.training import Trainer, WandbLogger

# Trainer for the forecasting run.
# `max_epochs` is taken from `forecast_optim_kwargs` so the number of training epochs and the
# epoch count the optimizer was configured with (warmup_epochs / max_epochs) stay in agreement;
# typing the two values separately lets them drift apart.
forecast_trainer = Trainer(
    seed=0,
    accelerator="gpu",
    devices=[7],  # machine-specific GPU index; change to match the hardware you run on
    precision=16,  # presumably forwarded to PyTorch Lightning (16-bit / mixed precision) — confirm
    max_epochs=forecast_optim_kwargs["max_epochs"],
    # logger=WandbLogger(project="climate_tutorial", name="forecast-resnet"),
)

Global seed set to 0


In [8]:
# Train the forecasting model on the forecasting data module; epochs, device and precision
# come from the `forecast_trainer` configuration above.
forecast_trainer.fit(forecast_model_module, forecast_data_module)

In [9]:
# Evaluate the trained forecasting model (presumably on the data module's test split — confirm).
forecast_trainer.test(forecast_model_module, forecast_data_module)

  rank_zero_warn(


Output()

In [10]:
# NOTE(review): this whole cell is commented out, so it currently does nothing.
# It sketches a `visualize` helper that, for each selected sample, plots four panels:
# the input, the ground truth, the model prediction, and the bias (prediction - ground truth),
# all passed through `model_module.denormalization` first. Samples are either `samples`
# random indices or the indices found for "YYYY-MM-DD:HH" timestamps via `dataset.time`.
# Known problems to resolve before re-enabling it:
#   - The call at the bottom passes `model_module` / `data_module`, which are not defined in
#     this notebook; the objects created above are `forecast_model_module` and
#     `forecast_data_module` (or the `downscale_*` objects in the downscaling section).
#   - `eval(f"data_module.{split}_dataset")` can be written as
#     `getattr(data_module, f"{split}_dataset")` without evaluating a string.
#   - It relies on random access (`len(dataset)`, `dataset[idx]`, `dataset.time`); confirm the
#     datasets built by `IterDataModule` support indexing before relying on this.
#   - Five input variables are configured above, so `x` likely has more than one channel and
#     `imshow(x.squeeze())` would not receive a 2-D array; the "# 1, 1, 32, 64" shape comment
#     below presumably predates that — confirm.
#   - All panel titles say Kelvin, which only holds for temperature variables.
#   - `random.sample` is unseeded here, so the random picks are not reproducible.
# import os
# import random
# import numpy as np
# import matplotlib.pyplot as plt
# from datetime import datetime

# def visualize(model_module, data_module, split = "test", samples = 2, save_dir = None):
#     if save_dir is not None:
#         os.makedirs(save_dir, exist_ok = True)

#     # dataset.setup()
#     dataset = eval(f"data_module.{split}_dataset")

#     if(type(samples) == int):
#         idxs = random.sample(range(0, len(dataset)), samples)
#     elif(type(samples) == list):
#         idxs = [np.searchsorted(dataset.time, np.datetime64(datetime.strptime(dt, "%Y-%m-%d:%H"))) for dt in samples]
#     else:
#         raise Exception("Invalid type for samples; Allowed int or list[datetime.datetime or np.datetime64]")

#     fig, axes = plt.subplots(len(idxs), 4, figsize=(30, 3 * len(idxs)), squeeze = False)

#     for index, idx in enumerate(idxs):
#         x, y, _, _ = dataset[idx] # 1, 1, 32, 64
#         pred = model_module.forward(x.unsqueeze(0)) # 1, 1, 32, 64

#         inv_normalize = model_module.denormalization
#         init_condition, gt = inv_normalize(x), inv_normalize(y)
#         pred = inv_normalize(pred)
#         bias = pred - gt

#         for i, tensor in enumerate([init_condition, gt, pred, bias]):
#             ax = axes[index][i]
#             im = ax.imshow(tensor.detach().squeeze().cpu().numpy())
#             im.set_cmap(cmap=plt.cm.RdBu)
#             fig.colorbar(im, ax=ax)

#         if(data_module.hparams.task == "forecasting"):
#             axes[index][0].set_title("Initial condition [Kelvin]")
#             axes[index][1].set_title("Ground truth [Kelvin]")
#             axes[index][2].set_title("Prediction [Kelvin]")
#             axes[index][3].set_title("Bias [Kelvin]")
#         elif(data_module.hparams.task == "downscaling"):
#             axes[index][0].set_title("Low resolution data [Kelvin]")
#             axes[index][1].set_title("High resolution data [Kelvin]")
#             axes[index][2].set_title("Downscaled [Kelvin]")
#             axes[index][3].set_title("Bias [Kelvin]")
#         else:
#             raise NotImplementedError

#     fig.tight_layout()
    
#     if save_dir is not None:
#         plt.savefig(os.path.join(save_dir, 'visualize.png'))
#     else:
#         plt.show()

# # if samples = 2, we randomly pick 2 initial conditions in the test set
# visualize(model_module, data_module, samples = ["2017-06-01:12", "2017-08-01:18"])

In [1]:
from src.climate_learn.utils.datetime import Year, Days, Hours
from src.climate_learn.data import IterDataModule

# Downscaling setup: inputs are read from the coarse 5.625 deg ERA5 (WeatherBench) shards and
# targets from the finer 2.8125 deg shards; the single target variable is 2 m temperature.
era5_root = "/data0/datasets/weatherbench/data/weatherbench/era5/"

downscale_data_module = IterDataModule(
    task="downscaling",
    inp_root_dir=era5_root + "5.625deg_npz/",   # coarse-resolution inputs
    out_root_dir=era5_root + "2.8125deg_npz/",  # finer-resolution targets
    in_vars=[
        "2m_temperature",
        "10m_u_component_of_wind",
        "10m_v_component_of_wind",
        "geopotential_500",
        "temperature_850",
    ],
    out_vars=["2m_temperature"],
    batch_size=128,
    num_workers=1,
)

In [2]:
from src.climate_learn.models import load_model

# ResNet sizing for downscaling: channel counts follow the data module's variable lists.
downscale_model_kwargs = dict(
    in_channels=len(downscale_data_module.hparams.in_vars),
    out_channels=len(downscale_data_module.hparams.out_vars),
    n_blocks=4,
)

# Optimizer settings handed to load_model; warmup_epochs / max_epochs presumably configure a
# learning-rate schedule — confirm in load_model.
# NOTE(review): keep `max_epochs` in sync with the downscaling Trainer's `max_epochs`; if the
# trainer runs fewer epochs than configured here, the schedule never completes (and with
# warmup_epochs=1, a 1-epoch run would be spent entirely in warmup).
downscale_optim_kwargs = dict(
    optimizer="adamw",
    lr=1e-4,
    weight_decay=1e-5,
    warmup_epochs=1,
    max_epochs=5,
)

downscale_model_module = load_model(
    name="resnet",
    task="downscaling",
    model_kwargs=downscale_model_kwargs,
    optim_kwargs=downscale_optim_kwargs,
)

In [3]:
from src.climate_learn.models import set_climatology
# Attach climatology information from the downscaling data module to the downscaling model.
# NOTE(review): presumably used as the reference for climatology-based metrics during
# validation/testing — confirm in set_climatology.
set_climatology(downscale_model_module, downscale_data_module)

In [4]:
from src.climate_learn.training import Trainer, WandbLogger

# Trainer for the downscaling run.
# `max_epochs` is read from `downscale_optim_kwargs` so the trainer runs exactly as many epochs
# as the optimizer was configured for (warmup_epochs=1, max_epochs there). A hard-coded
# `max_epochs = 1` would stop after the first epoch, i.e. before the configured schedule
# (presumably warmup followed by decay — confirm in load_model) gets past warmup.
# Lower `max_epochs` in `downscale_optim_kwargs` for a quicker run.
downscale_trainer = Trainer(
    seed=0,
    accelerator="gpu",
    devices=[7],  # machine-specific GPU index; change to match the hardware you run on
    precision=16,  # presumably forwarded to PyTorch Lightning (16-bit / mixed precision) — confirm
    max_epochs=downscale_optim_kwargs["max_epochs"],
    # logger=WandbLogger(project="climate_tutorial", name="downscale-resnet"),
)

Global seed set to 0


In [5]:
# Train the downscaling model; epochs, device and precision come from `downscale_trainer` above.
# NOTE(review): the forecasting section passes the data module positionally while this call uses
# the `data_module=` keyword; pick one form for consistency.
downscale_trainer.fit(downscale_model_module, data_module=downscale_data_module)