In [None]:
import hydra
from hydra.core.global_hydra import GlobalHydra
GlobalHydra.instance().clear()
from omegaconf import OmegaConf

from safetensors.torch import save_model, load_file

from huggingface_hub import hf_hub_download

from bfm_model.bfm.dataloader_monthly import LargeClimateDataset
from bfm_model.bfm.model_helpers import get_trainer, setup_bfm_model
from bfm_model.bfm.dataloader_helpers import get_val_dataloader

# Model / trainer overrides applied on top of `train_config`.
# NOTE(review): these are presumably meant to match the architecture of the
# checkpoint downloaded below — confirm, since the override selects a "medium"
# Swin backbone while the checkpoint file name says "small".
cfg_overrides = [
    "model.embed_dim=256",
    "model.depth=3",
    "model.swin_backbone_size=medium",
    "model.num_heads=16",
    "training.devices=[0]",
]
hydra.initialize(config_path="../bfm_model/bfm/configs", version_base=None)
cfg = hydra.compose(config_name="train_config", overrides=cfg_overrides)

# Uncomment to inspect the fully resolved config:
# print(OmegaConf.to_yaml(cfg))

# Download the pretrained checkpoint from the Hugging Face Hub; `checkpoint_path`
# is the local file path handed to `load_file` in the next cell.
# (`dcheckpoint_name` keeps its original spelling in case later cells use it.)
checkpoint_repo = "BioDT/bfm-pretrained"
dcheckpoint_name = "bfm-pretrained-small.safetensors"
checkpoint_path = hf_hub_download(
    filename=dcheckpoint_name,
    repo_id=checkpoint_repo,
)

In [None]:
# Test dataset built from the batches under `cfg.data.test_data_path`
# (adapt that config value to the folder that contains your batches).
data_cfg = cfg.data
test_dataset = LargeClimateDataset(
    data_dir=data_cfg.test_data_path,
    scaling_settings=data_cfg.scaling,
    num_species=data_cfg.species_number,
    atmos_levels=data_cfg.atmos_levels,
    model_patch_size=cfg.model.patch_size,
)
# NOTE(review): `test_dataset` is not handed to `get_val_dataloader` below, which
# receives only `cfg`; confirm whether this dataset object is still needed.

# Evaluation loader; the batch size is overridden with `cfg.evaluation.batch_size`.
eval_batch_size = cfg.evaluation.batch_size
test_dataloader = get_val_dataloader(cfg, batch_size_override=eval_batch_size)


bfm_model = setup_bfm_model(cfg, mode="test")

# `load_file` returns the checkpoint's tensors as a dict (a state dict), not a path.
# By default `checkpoint_path` is the file downloaded from Hugging Face in the first
# cell. To use a local checkpoint instead, load it the same way, e.g.:
#   state_dict = load_file("path/to/your_checkpoint.safetensors")
state_dict = load_file(checkpoint_path)

# strict=False tolerates key mismatches between checkpoint and model. Any missing
# parameter is left at whatever value setup_bfm_model initialised it to, so report
# mismatches explicitly instead of silently evaluating a partially loaded model.
# (getattr guards against a custom load_state_dict that does not return the usual
# missing/unexpected-keys result.)
load_result = bfm_model.load_state_dict(state_dict, strict=False)
missing_keys = list(getattr(load_result, "missing_keys", []) or [])
unexpected_keys = list(getattr(load_result, "unexpected_keys", []) or [])
if missing_keys:
    print(
        f"WARNING: {len(missing_keys)} model parameter(s) not found in the checkpoint, "
        f"e.g. {missing_keys[:5]}"
    )
if unexpected_keys:
    print(
        f"WARNING: {len(unexpected_keys)} checkpoint entr(y/ies) not used by the model, "
        f"e.g. {unexpected_keys[:5]}"
    )

bfm_model.eval()
# Hard-coded to CUDA, in line with the `training.devices=[0]` override.
# NOTE(review): this line fails on CPU-only machines.
bfm_model.to("cuda")

# Inference-only trainer: no MLflow logging and no extra callbacks.
trainer = get_trainer(cfg, mlflow_logger=None, callbacks=[])

# Run the model over the whole test loader; `predictions` collects the outputs
# (presumably one entry per batch — TODO confirm against the model's predict step).
predictions = trainer.predict(
    dataloaders=test_dataloader,
    model=bfm_model,
)

In [None]:
# Pull the prediction and ground truth for the final entry of `predictions`.
last_batch_output = predictions[-1][0]
last_prediction = last_batch_output["pred"]
last_ground_truth = last_batch_output["gt"]

# Show the ground-truth timestamp. In the author's run this showed the prediction
# to be 1 month ahead -> 2020-07-01 00:00:00 (the value depends on your test data).
last_ground_truth.batch_metadata.timestamp