In [1]:
import hydra
from hydra.core.global_hydra import GlobalHydra
GlobalHydra.instance().clear()
from omegaconf import OmegaConf

from safetensors.torch import save_model, load_file

from huggingface_hub import hf_hub_download

from bfm_model.bfm.dataloader_monthly import LargeClimateDataset
from bfm_model.bfm.model_helpers import get_mlflow_logger, get_trainer, setup_bfm_model
from bfm_model.bfm.dataloader_helpers import get_val_dataloader
from bfm_model.bfm.utils import plot_europe_timesteps_and_difference

# Overrides applied on top of train_config; they select the "medium" model
# variant that is intended to match the pretrained checkpoint fetched below.
BFM_OVERRIDES = [
    "model.embed_dim=512",
    "model.depth=6",
    "model.swin_backbone_size=medium",
    "training.devices=[0]",
]

hydra.initialize(config_path="../bfm_model/bfm/configs", version_base=None)
cfg = hydra.compose(config_name="train_config", overrides=BFM_OVERRIDES)

# Uncomment to inspect the fully resolved configuration.
# print(OmegaConf.to_yaml(cfg))

# Pretrained weights on the Hugging Face Hub; checkpoint_path is the local
# file path returned by the download and is read in the next cell.
checkpoint_repo = "BioDT/bfm-pretrained"
dcheckpoint_name = "bfm-pretrain-medium.safetensors"
checkpoint_path = hf_hub_download(repo_id=checkpoint_repo, filename=dcheckpoint_name)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Test dataset built from the configured test split. Adapt
# cfg.data.test_data_path to the folder that contains your batches.
# NOTE(review): test_dataset is not passed to get_val_dataloader below (which
# only receives cfg); keep it only if later cells use it.
test_dataset = LargeClimateDataset(
    data_dir=cfg.data.test_data_path,
    scaling_settings=cfg.data.scaling,
    num_species=cfg.data.species_number,
    atmos_levels=cfg.data.atmos_levels,
    model_patch_size=cfg.model.patch_size,
)
# Evaluation dataloader; batch size overridden from cfg.evaluation.
test_dataloader = get_val_dataloader(cfg, batch_size_override=cfg.evaluation.batch_size)

bfm_model = setup_bfm_model(cfg, mode="test")

# Load the pretrained weights. load_state_dict expects a mapping of tensors,
# not a file path, so the file must be read first:
#   - local safetensors checkpoint: state_dict = load_file(local_path)
#   - Hugging Face Hub (default here): the file downloaded in the first cell
state_dict = load_file(checkpoint_path)
load_result = bfm_model.load_state_dict(state_dict, strict=False)

# strict=False tolerates key mismatches, so a wrong config/checkpoint pairing
# would otherwise leave the model silently (partially) untrained.
if len(state_dict) - len(load_result.unexpected_keys) == 0:
    raise RuntimeError(
        "No checkpoint tensors matched the model's parameter names; "
        "check the config overrides against the checkpoint."
    )
if load_result.missing_keys:
    print(f"Missing keys ({len(load_result.missing_keys)}): {load_result.missing_keys[:10]}")
if load_result.unexpected_keys:
    print(f"Unexpected keys ({len(load_result.unexpected_keys)}): {load_result.unexpected_keys[:10]}")

bfm_model.eval()
bfm_model.to("cuda")

trainer = get_trainer(
    cfg,
    mlflow_logger=None,
    callbacks=[],
)

predictions = trainer.predict(model=bfm_model, dataloaders=test_dataloader)


We scale the dataset True with normalize
We scale the dataset True with normalize
Validation train: 13
Land-sea mask file not found at . Loss will be calculated over all pixels.
Num of patches in Encoder: 2800
Total latens 64400
BuiltinAttention q_dim 512 | context dim 641 | num q heads 16 | head dim 64 | kv_heads 8
BuiltinAttention q_dim 512 | context dim 512 | num q heads 16 | head dim 64 | kv_heads 8
BuiltinAttention q_dim 512 | context dim 512 | num q heads 16 | head dim 64 | kv_heads 8
BuiltinAttention q_dim 512 | context dim 512 | num q heads 16 | head dim 64 | kv_heads 8
BuiltinAttention q_dim 512 | context dim 512 | num q heads 16 | head dim 64 | kv_heads 8
BuiltinAttention q_dim 512 | context dim 512 | num q heads 16 | head dim 64 | kv_heads 8
BuiltinAttention q_dim 512 | context dim 512 | num q heads 16 | head dim 64 | kv_heads 8
BuiltinAttention q_dim 512 | context dim 512 | num q heads 16 | head dim 64 | kv_heads 8
Total query tokens for Decoder:  124
BuiltinAttention q_dim

/home/atrantas/production/bfm-model/venv/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/atrantas/production/bfm-model/venv/lib/python3 ...
Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/atrantas/production/bfm-model/venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install

Predicting DataLoader 0: 100%|██████████| 13/13 [00:07<00:00,  1.75it/s]


In [3]:
# Output of the final predicted batch; its first entry holds the model
# prediction under "pred" and the matching ground truth under "gt".
final_batch_output = predictions[-1][0]
last_prediction = final_batch_output["pred"]
last_ground_truth = final_batch_output["gt"]

# Input timestamps of that batch. The prediction is one month past the latest
# of them -> 2020-07-01 00:00:00
last_ground_truth.batch_metadata.timestamp

[('2020-05-01 00:00:00',), ('2020-06-01 00:00:00',)]