In [None]:
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
from aurora import Aurora
import importlib

import torch
from torch.utils.data import DataLoader

from bfm_finetune.aurora_mod import AuroraFlex, AuroraRaw
from bfm_finetune import plots_v2
from bfm_finetune.dataloaders.geolifeclef_species.dataloader import GeoLifeCLEFSpeciesDataset
from bfm_finetune.dataloaders.dataloader_utils import custom_collate_fn
from bfm_finetune.utils import load_checkpoint, seed_everything, load_config

# Pick the compute device. The notebook prefers the second GPU ("cuda:1"), but
# that ordinal only exists on multi-GPU hosts; fall back to the first GPU when
# only one is present, and to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")

# Fix the random seed up front so repeated runs of the notebook are comparable.
SEED = 42
seed_everything(SEED)

In [None]:
# Number of species channels; passed to the validation dataset further down.
num_species = 500

# Level values that are only referenced by the commented-out AuroraFlex
# configuration below (presumably atmospheric pressure levels -- TODO confirm).
atmos_levels = (50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000)

# Aurora backbone with LoRA enabled, initialised from the public pretrained
# checkpoint. strict=False is passed presumably so that keys missing from the
# pretrained file (e.g. the LoRA weights) do not raise -- verify if this changes.
base_model = Aurora(use_lora=True)  # stabilise_level_agg=True
base_model.load_checkpoint(
    "microsoft/aurora",
    "aurora-0.25-pretrained.ckpt",
    strict=False,
)
base_model.to(device)

In [None]:
# Output directory of the training run to evaluate. Previously used runs are
# kept commented out so they can be switched back in quickly:
# PATH = "/home/thanasis.trantas/github_projects/bfm-finetune/outputs/2025-04-29/16-59-52"
# PATH = "/home/martino.mensio/projects/bfm/bfm-finetune/outputs/2025-04-30/09-39-45"
# PATH = "/home/thanasis.trantas/github_projects/bfm-finetune/outputs/2025-05-06/10-13-15-good-to-t1-only"
PATH = "/home/thanasis.trantas/github_projects/bfm-finetune/outputs/2025-05-06/15-33-56-MAE-best"

# Checkpoints live in the "checkpoints" sub-directory of the run directory.
CHECKPOINT_PATH = Path(PATH).joinpath("checkpoints")

# Configuration loaded from the same run directory.
cfg = load_config(PATH)

In [None]:
# Validation split, configured with the longitude convention stored in the
# run's config.
val_dataset = GeoLifeCLEFSpeciesDataset(
    num_species=num_species,
    mode="val",
    negative_lon_mode=cfg.dataset.negative_lon_mode,
)

# One sample per step, in fixed order, collated by the project's collate function.
val_dataloader = DataLoader(
    dataset=val_dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=custom_collate_fn,
    num_workers=1,
)

# Grid coordinates exposed by the dataset (only consumed by the commented-out
# AuroraFlex configuration in this notebook).
lat_lon = val_dataset.get_lat_lon()

In [None]:
# Alternative wrapper, kept for reference:
# model = AuroraFlex(base_model=base_model, in_channels=num_species, hidden_channels=160,
#                     out_channels=num_species, atmos_levels=atmos_levels, lat_lon=lat_lon,
#                     supersampling_cfg=cfg.model.supersampling)

# Wrap the Aurora backbone and move the wrapper to the selected device.
model = AuroraRaw(base_model).to(device)

# The optimizer is handed to load_checkpoint together with the model; it is
# never stepped in the visible part of this notebook.
params_to_optimize = model.parameters()
optimizer = torch.optim.AdamW(params=params_to_optimize, lr=1.0)

In [None]:
# Load the saved state from CHECKPOINT_PATH into `model` and `optimizer`.
# load_checkpoint returns two values that are not needed for evaluation, so
# they are bound to throwaway names.
_ckpt_first, _ckpt_second = load_checkpoint(model, optimizer, CHECKPOINT_PATH)

In [None]:
# Re-import the plotting helpers so edits made to plots_v2 since the kernel
# started are picked up; rebinding to the returned module keeps the name current.
plots_v2 = importlib.reload(plots_v2)
import bfm_finetune.metrics as UM

In [None]:
# Evaluate the restored model on the validation set and report the per-sample
# mean of four metrics (CRPS, Poisson deviance, explained deviance, TSS).

# Switch to evaluation mode so that train-only behaviour (dropout and similar
# layers) is disabled during inference; inference_mode alone does not do this.
model.eval()

crps_sum = dev_sum = d2_sum = tss_sum = 0.0
n = 0
for sample in val_dataloader:
    batch = sample["batch"]# .to(device)
    batch["species_distribution"] = batch["species_distribution"].to(device)
    # NOTE(review): `target` is read but never used below; the metrics compare
    # against batch["species_distribution"][:, 1] instead -- confirm that is intended.
    target = sample["target"]
    with torch.inference_mode():
        prediction = model.forward(batch)
        # Map the model output back to the dataset's original scale.
        unnormalized_preds = val_dataset.scale_species_distribution(prediction.clone(), unnormalize=True)
    # plot_channel_time_slices(batch["species_distribution"], channel_idx=0, cmap='plasma')
    # plot_channel_time_slices(prediction, channel_idx=0, cmap='plasma')
    # Replace NaNs with 0 and clip negatives so the predictions are non-negative.
    un_preds = torch.nan_to_num(unnormalized_preds.float(), nan=0.0).clamp(min=0)

    plots_v2.plot_eval(batch, un_preds, Path("."), n_species_to_plot=1, save=False)
    print("Preds shape", unnormalized_preds.shape)
    # ---- metrics ------------------------------------------------------ #
    un_preds = un_preds.squeeze(0) # Remove the batch dim
    # NOTE(review): predictions are unnormalized above while this target is taken
    # straight from the batch -- confirm both are on the same scale.
    target_yr2 = batch["species_distribution"][:, 1]  # second year
    crps_sum += UM.crps(un_preds, target_yr2).item()
    dev = UM.poisson_deviance(un_preds, target_yr2)
    d2_sum += UM.explained_deviance(un_preds, target_yr2).item()
    dev_sum += dev.item()
    tss_sum += UM.tss(un_preds, target_yr2).item()
    n += 1

# Fail with a clear message instead of a bare ZeroDivisionError when the
# validation loader is empty.
if n == 0:
    raise RuntimeError("Validation dataloader yielded no samples; cannot compute metrics.")

metrics = {
    "CRPS": crps_sum / n,
    "PoissonDev": dev_sum / n,
    "D2": d2_sum / n,
    "TSS": tss_sum / n,
}
metrics

In [None]:
from bfm_finetune.vis_tools import (
    plot_change_map,
    plot_confusion_map,
    plot_taylor_single,
    plot_hexbin,
    plot_error_violin,
)

# Diagnostic plots. NOTE: `batch` and `un_preds` are whatever the evaluation
# loop left behind, i.e. they describe only the LAST validation sample.
lon = batch["metadata"]["lon"].cpu().numpy()
lat = batch["metadata"]["lat"].cpu().numpy()
species_i = 42  # NOTE(review): not referenced anywhere in this cell

species_subset = [0, 1, 2] # select species
for s in species_subset:
    # First-year state, observed second-year state and predicted state for species `s`.
    y0       = batch["species_distribution"][0, 0, s].cpu()
    y1_true  = batch["species_distribution"][0, 1, s].cpu()
    y1_pred  = un_preds[0, s].cpu()
    plot_change_map(lat, lon, y0, y1_true, y1_pred, s)
    # plot_taylor(lat, lon, y1_true.numpy(), y1_pred.numpy())
    # plot_taylor_single(y0.numpy().ravel(), y1_pred.numpy().ravel(), title="Year‑1 spatial skill")
    plot_confusion_map(lat, lon, y1_true.numpy(), y1_pred.numpy())

# global calibration
obs_all  = batch["species_distribution"][:, 1].reshape(-1).cpu()
pred_all = un_preds.reshape(-1).cpu()
plot_hexbin(pred_all, obs_all)

# violin of absolute errors  [species, cells]
abs_err = torch.abs(un_preds - batch["species_distribution"][:, 1]).cpu()
# Use the notebook-level species count rather than a hard-coded 500 so the
# reshape stays correct if `num_species` is changed.
plot_error_violin(abs_err.reshape(num_species, -1).numpy())