In [1]:
import pandas as pd
import numpy as np
from multiprocessing import Pool as ThreadPool
from tqdm import tqdm
import dask.config
import xarray as xr
import xbatcher as xb
import numpy as np
import dask
import torch
import random
import os

from data.era5 import gen_bgen
from metrics.metrics import WeightedRMSE
from models.latent_umbrella_net import LatentUmbrellaNet
from models.autoencoder import Autoencoder

# Number of workers in the pool used to load ERA5 batches in parallel.
# Note: `ThreadPool` in this notebook is an alias for `multiprocessing.Pool`,
# so these are worker processes, not threads.
NUM_WORKERS = 6

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
def eval_lun_unet(
    forecast_steps: int = 4, rounds: int = 1, save_to_csv: bool = True
) -> pd.DataFrame:
    """Score the U-Net based LatentUmbrellaNet on randomly chosen ERA5 sequences.

    For every round a random start offset is drawn, ``forecast_steps + 2``
    frames spaced 6 indices apart are loaded, the first two frames are fed to
    the model, and each forecast is compared against the matching later frame
    with ``WeightedRMSE``. The per-round tables are averaged element-wise.

    The start offset comes from the global ``random`` state; call
    ``random.seed(...)`` beforehand for a reproducible run.

    Args:
        forecast_steps: Number of forecasts to produce and score (>= 1).
        rounds: Number of independent random sequences to average over (>= 1).
        save_to_csv: If true, write the averaged table to
            ``evaluation/lun_unet.csv`` (the directory is created if missing).

    Returns:
        DataFrame with one row per forecast step and the columns
        ``z500``, ``t850``, ``h700``, ``t2m``, ``u10`` and ``u850``.

    Raises:
        ValueError: If ``forecast_steps`` or ``rounds`` is smaller than 1.
    """
    if forecast_steps < 1:
        raise ValueError(f"forecast_steps must be >= 1, got {forecast_steps}")
    if rounds < 1:
        raise ValueError(f"rounds must be >= 1, got {rounds}")

    lun = LatentUmbrellaNet(
        vae_ckpt_path="checkpoints/vae-kl-f8-rmse-disc-2-step=5000-z500=93.ckpt",
        vae_config_path="configs/autoencoder/kl-f8-disc.yaml",
        prediction_net_ckpt_path="checkpoints/prediction-model-val_loss=0.01241.ckpt",
        device="cuda",
        prediction_net_type="unet",
    )

    # Output column name -> index along the stacked "channel" axis.
    channel_index = {
        "z500": 50,
        "t850": 14,
        "h700": 65,
        "t2m": 0,
        "u10": 1,
        "u850": 27,
    }

    dfs: list[pd.DataFrame] = []

    for _ in range(rounds):
        # The `with` block shuts the worker pool down even if loading fails,
        # and releases the workers before the model runs.
        with ThreadPool(NUM_WORKERS) as pool:
            # NOTE(review): batches come from gen_bgen(train=True) — confirm that
            # evaluating on this split is intended.
            bgen = gen_bgen(train=True)

            # Random start offset into the dataset (not an RNG seed).
            start = random.randint(0, 1000)

            # forecast_steps + 2 indices spaced 6 apart: two model inputs
            # followed by one target per forecast step.
            indexes = (np.arange(0, 6 * (forecast_steps + 2), 6) + start).tolist()

            # Fetch all frames in parallel.
            batches: list[xr.Dataset] = pool.map(bgen.__getitem__, indexes)

        # Convert each dataset to a (channel, longitude, latitude) tensor.
        frames = []
        for batch in batches:
            stacked = batch.to_stacked_array(
                new_dim="channel", sample_dims=["latitude", "longitude"]
            ).transpose("channel", "longitude", "latitude")
            # Drop the last entry along the latitude axis
            # (presumably 721 -> 720 to match num_latitudes=720 — TODO confirm).
            frames.append(torch.tensor(stacked.data)[:, :, :-1])

        # e.g. [x_0, x_6, x_12, x_18, x_24, x_30] for forecast_steps = 4
        data = torch.stack(frames, dim=0)

        # Inference only: no autograd graph is needed for the forecasts.
        with torch.no_grad():
            forecasts = torch.cat(
                [
                    # Third argument runs 1..forecast_steps
                    # (presumably the lead in model steps — TODO confirm).
                    lun.forward(data[0].unsqueeze(0), data[1].unsqueeze(0), step)
                    for step in range(1, forecast_steps + 1)
                ],
                dim=0,
            )  # e.g. [y_12, y_18, y_24, y_30]

        targets = data[2:]  # e.g. [x_12, x_18, x_24, x_30]

        wrmse = WeightedRMSE(num_latitudes=720)

        # detach().cpu() makes the conversion safe if the model returns a CUDA
        # or grad-tracking tensor; it is a no-op for plain CPU tensors.
        scores = np.array(
            [
                wrmse(targets[i].numpy(), forecasts[i].detach().cpu().numpy())
                for i in range(forecast_steps)
            ]
        )

        dfs.append(
            pd.DataFrame({name: scores[:, idx] for name, idx in channel_index.items()})
        )

    # Element-wise mean over rounds (all frames share index and columns).
    res_df = sum(dfs) / len(dfs)

    if save_to_csv:
        os.makedirs("./evaluation", exist_ok=True)
        res_df.to_csv("evaluation/lun_unet.csv", index=False, header=True, mode="w")

    return res_df


# Quick evaluation: two forecast steps, a single round; the averaged scores
# are displayed as the cell output and also written to CSV.
eval_lun_unet(forecast_steps=2, rounds=1, save_to_csv=True)

c:\Users\hendr\Desktop\3d-vae\venv\Lib\site-packages\lightning\pytorch\utilities\migration\utils.py:56: The loaded checkpoint was produced with Lightning v2.5.1, which is newer than your current Lightning version: v2.5.0.post0


In [2]:
forecast_steps = 1

# NOTE(review): this prediction-net checkpoint (val_loss=0.01221) differs from
# the one used in eval_lun_unet (val_loss=0.01241) — confirm this is intended.
lun = LatentUmbrellaNet(
    vae_ckpt_path="checkpoints/vae-kl-f8-rmse-disc-2-step=5000-z500=93.ckpt",
    vae_config_path="configs/autoencoder/kl-f8-disc.yaml",
    prediction_net_ckpt_path="checkpoints/prediction-model-val_loss=0.01221.ckpt",
    device="cuda",
    prediction_net_type="unet",
)

# Worker pool for parallel loading. The `with` block shuts the workers down
# even if loading raises, instead of leaving them behind.
with ThreadPool(NUM_WORKERS) as pool:
    # create a batch generator for the era5 data
    bgen = gen_bgen(train=True)

    # random start offset into the dataset (drawn from the global `random` state)
    s = random.randint(0, 1000)

    # forecast_steps + 2 indices spaced 6 apart: two model inputs plus one
    # target per forecast step
    indexes = np.arange(0, 6 * (forecast_steps + 2), 6) + s
    indexes = indexes.tolist()

    # load the batches in parallel
    batches: list[xr.Dataset] = pool.map(bgen.__getitem__, indexes)

# convert the batches to (channel, longitude, latitude) torch tensors
frames = []
for batch in batches:
    stacked = batch.to_stacked_array(
        new_dim="channel", sample_dims=["latitude", "longitude"]
    ).transpose("channel", "longitude", "latitude")

    item = torch.tensor(stacked.data)
    item = item[:, :, :-1]  # drop the last entry along the latitude axis

    frames.append(item)

data = torch.stack(frames, dim=0)  # [x_0, x_6, x_12] for forecast_steps = 1

c:\Users\hendr\Desktop\3d-vae\venv\Lib\site-packages\lightning\pytorch\utilities\migration\utils.py:56: The loaded checkpoint was produced with Lightning v2.5.1, which is newer than your current Lightning version: v2.5.0.post0


In [3]:
# Sanity check of the loaded tensor: dim 0 indexes the loaded frames
# (forecast_steps + 2 of them), followed by channel, longitude and latitude.
data.shape

torch.Size([3, 69, 1440, 720])

In [None]:
forecastst = []

# Inference only: disable autograd so no computation graph is built or kept
# alive for the forecasts.
with torch.no_grad():
    for i in range(forecast_steps):
        # The first two frames are the model inputs (each given a leading batch
        # dimension); the third argument runs 1..forecast_steps
        # (presumably the lead in model steps — TODO confirm).
        forecast = lun.forward(data[0].unsqueeze(0), data[1].unsqueeze(0), i + 1)
        forecastst.append(forecast)

In [10]:
# Keep only the target frames. Slicing from the end gives the same result as
# `data[2:]` on the first run (data holds forecast_steps + 2 frames) and is a
# no-op when this cell is re-run, so repeated execution no longer empties `data`.
# NOTE: after this cell `data` no longer contains the two input frames; re-run
# the loading cell before re-running the forecast cell.
data = data[-forecast_steps:]  # e.g. [x_12] for forecast_steps = 1

# Concatenate the per-step forecasts exactly once. On a re-run `forecastst` is
# already a tensor, and calling torch.cat on it again would silently change
# its shape.
if isinstance(forecastst, list):
    forecastst = torch.cat(forecastst, dim=0)  # e.g. [y_12]

wrmse = WeightedRMSE(num_latitudes=720)

# detach().cpu() makes the conversion safe if the model returned a CUDA or
# grad-tracking tensor; it is a no-op for plain CPU tensors.
lun_unet = np.array(
    [
        wrmse(data[i].numpy(), forecastst[i].detach().cpu().numpy())
        for i in range(forecast_steps)
    ]
)

# Output column name -> index along the stacked "channel" axis.
channel_index = {
    "z500": 50,
    "t850": 14,
    "h700": 65,
    "t2m": 0,
    "u10": 1,
    "u850": 27,
}

lun_unet_dict = {name: lun_unet[:, idx] for name, idx in channel_index.items()}

df = pd.DataFrame(lun_unet_dict)