In [1]:
import os
from pathlib import Path


from jmpfm.tasks.config import AdamWConfig
from jmpfm.tasks.finetune import (
    MatbenchConfig,
    MatbenchModel,
    PDBBindConfig,
    PDBBindModel,
    RMD17Config,
    RMD17Model,
)
from jmpfm.tasks.finetune.base import (
    PrimaryMetricConfig,
    FinetuneConfigBase,
    FinetuneModelBase,
)
from jmpfm.tasks.finetune import dataset_config as DC
from jmpfm.configs.finetune import jmp_l_ft_config_builder

# Pretrained JMP-L checkpoint handed to every `jmp_l_ft_config_builder` call
# below. Set the JMP_CKPT_PATH environment variable to point at a different
# checkpoint (e.g. on another machine) instead of editing the notebook; an
# unset or empty variable falls back to the shared-storage default.
CKPT_PATH = Path(
    os.environ.get("JMP_CKPT_PATH")
    or "/mnt/shared/checkpoints/fm_gnoc_large_2_epoch.ckpt"
)


def jmp_l_matbench_config(
    dataset: DC.MatbenchDataset,
    fold: DC.MatbenchFold,
    base_path: Path = Path("/mnt/shared/datasets/matbench/"),
):
    """Build a JMP-L fine-tuning config for one Matbench dataset and fold.

    Args:
        dataset: The Matbench dataset to fine-tune on.
        fold: The fold passed to ``DC.matbench_config`` for every split.
        base_path: Root directory passed to ``DC.matbench_config``.

    Returns:
        A ``(config, MatbenchModel)`` pair: the config produced by the
        builder, and the model class to train with it.
    """
    # NOTE(review): the builder context presumably seeds JMP-L defaults and
    # records CKPT_PATH in the config (run() later reads meta["ckpt_path"]).
    with jmp_l_ft_config_builder(MatbenchConfig, CKPT_PATH) as (build, cfg):
        # AdamW with a small learning rate for fine-tuning.
        adamw = AdamWConfig(
            lr=5.0e-6,
            amsgrad=False,
            betas=(0.9, 0.95),
            weight_decay=0.1,
        )
        cfg.optimizer = adamw

        cfg.batch_size = 1

        # One dataset config per split, all sharing the same fold.
        for split in ("train", "val", "test"):
            setattr(
                cfg,
                f"{split}_dataset",
                DC.matbench_config(dataset, base_path, split, fold),
            )

        # Select checkpoints / report results by minimum y MAE.
        cfg.primary_metric = PrimaryMetricConfig(name="y_mae", mode="min")

        return build(cfg), MatbenchModel


def jmp_l_rmd17_config(
    molecule: DC.RMD17Molecule,
    base_path: Path = Path("/mnt/shared/datasets/rmd17/"),
):
    """Build a JMP-L fine-tuning config for one rMD17 molecule.

    Args:
        molecule: The rMD17 molecule whose train/val/test splits are used.
        base_path: Root directory passed to ``DC.rmd17_config``.

    Returns:
        A ``(config, RMD17Model)`` pair: the config produced by the builder,
        and the model class to train with it.
    """
    # NOTE(review): the builder context presumably seeds JMP-L defaults and
    # records CKPT_PATH in the config (run() later reads meta["ckpt_path"]).
    with jmp_l_ft_config_builder(RMD17Config, CKPT_PATH) as (build, cfg):
        # AdamW with a small learning rate for fine-tuning.
        adamw = AdamWConfig(
            lr=5.0e-6,
            amsgrad=False,
            betas=(0.9, 0.95),
            weight_decay=0.1,
        )
        cfg.optimizer = adamw

        cfg.batch_size = 4

        # One dataset config per split for the chosen molecule.
        for split in ("train", "val", "test"):
            setattr(
                cfg,
                f"{split}_dataset",
                DC.rmd17_config(molecule, base_path, split),
            )

        # Select checkpoints / report results by minimum force MAE.
        cfg.primary_metric = PrimaryMetricConfig(name="force_mae", mode="min")

        return build(cfg), RMD17Model


def jmp_l_pdbbind_config():
    """Build a JMP-L fine-tuning config for the PDBBind "-logKd/Ki" task.

    Dataset locations are whatever ``DC.pdbbind_config`` resolves for each
    split; no base path is passed here.

    Returns:
        A ``(config, PDBBindModel)`` pair: the config produced by the builder,
        and the model class to train with it.
    """
    # NOTE(review): the builder context presumably seeds JMP-L defaults and
    # records CKPT_PATH in the config (run() later reads meta["ckpt_path"]).
    with jmp_l_ft_config_builder(PDBBindConfig, CKPT_PATH) as (build, cfg):
        # AdamW with a small learning rate for fine-tuning.
        adamw = AdamWConfig(
            lr=5.0e-6,
            amsgrad=False,
            betas=(0.9, 0.95),
            weight_decay=0.1,
        )
        cfg.optimizer = adamw

        cfg.batch_size = 1

        # One dataset config per split.
        for split in ("train", "val", "test"):
            setattr(cfg, f"{split}_dataset", DC.pdbbind_config(split))

        # NOTE(review): `pbdbind_task` reads like a transposition of "pdbbind",
        # but it has to match the attribute PDBBindConfig defines - confirm
        # against that class before "fixing" the spelling here.
        cfg.pbdbind_task = "-logKd/Ki"
        # Report RMSE in addition to the default metrics.
        cfg.metrics.report_rmse = True
        # Select checkpoints / report results by minimum y MAE.
        cfg.primary_metric = PrimaryMetricConfig(name="y_mae", mode="min")

        return build(cfg), PDBBindModel


# Build the PDBBind fine-tuning setup and dump the resolved config for review.
config, model_cls = jmp_l_pdbbind_config()
print(config)

# (config, model class) pairs consumed by the runner in the next cell.
configs: list[tuple[FinetuneConfigBase, type[FinetuneModelBase]]] = [
    (config, model_cls),
]



id='gmfv7hg3' name=None project=None tags=[] notes=[] debug=False environment=EnvironmentConfig(cwd=None, python_executable=None, python_path=None, python_version=None, config=None, model=None, data=None, slurm=None, log_dir=None, seed=None, seed_workers=None, sweep_id=None, sweep_config=None) trainer=TrainerConfig(python_logging=PythonLogging(log_level=None, rich=True, rich_tracebacks=True, lovely_tensors=True, lovely_numpy=False), logging=LoggingConfig(enabled=True, log_lr=True, log_epoch=True, wandb=WandbLoggingConfig(enabled=True, log_model=False, watch=WandbWatchConfig(enabled=True, log=None, log_graph=True, log_freq=100)), csv=CSVLoggingConfig(enabled=True), tensorboard=TensorboardLoggingConfig(enabled=False)), optimizer=OptimizerConfig(grad_finite_checks=False, grad_none_checks=False, log_grad_norm=True, log_grad_norm_per_param=False, log_param_norm=False, log_param_norm_per_param=False, gradient_clipping=GradientClippingConfig(enabled=True, value=1.0, algorithm='value'), gradie



In [2]:
from ll import Runner, Trainer

from jmpfm.utils.finetune_state_dict import (
    filter_state_dict,
    retreive_state_dict_for_finetuning,
)


def run(config: FinetuneConfigBase, model_cls: type[FinetuneModelBase]) -> None:
    """Fine-tune ``model_cls`` on ``config`` from a pretrained checkpoint.

    The checkpoint path is read from ``config.meta["ckpt_path"]``;
    ``config.meta.get("ema_backbone", False)`` is forwarded to the checkpoint
    loader as ``load_emas``. The embedding and backbone weights are loaded
    into the freshly built model (strictly), then the model is fit with an
    ``ll.Trainer`` built from the same config.

    Raises:
        ValueError: If ``config.meta`` has no ``"ckpt_path"`` entry (or it is
            ``None``).
    """
    ckpt_path = config.meta.get("ckpt_path")
    if ckpt_path is None:
        raise ValueError("No checkpoint path provided")

    model = model_cls(config)

    use_emas = config.meta.get("ema_backbone", False)
    pretrained = retreive_state_dict_for_finetuning(ckpt_path, load_emas=use_emas)

    # Split the checkpoint by key prefix (filter_state_dict presumably keeps
    # matching keys and strips the prefix - confirm) into the two pieces the
    # fine-tuning model reuses.
    model.load_backbone_state_dict(
        embedding=filter_state_dict(pretrained, "embedding.atom_embedding."),
        backbone=filter_state_dict(pretrained, "backbone."),
        strict=True,
    )

    Trainer(config).fit(model)


# Wrap `run` in an ll Runner and execute it once per entry in `configs` in
# fast-dev-run mode (per the log below: one batch per loop, logging and
# checkpointing suppressed). The returned list - one result per config - is
# this cell's displayed output.
runner = Runner(run)
runner.fast_dev_run(configs)

Fast dev run:   0%|          | 0/1 [00:00<?, ?it/s]

Seed set to 0


Unrecognized arguments:  dict_keys(['learnable_rbf', 'learnable_rbf_stds', 'unique_basis_per_layer', 'dropout', 'edge_dropout'])


Trainer will use only 1 of 2 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=2)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Running in `fast_dev_run` mode: will run the requested loop using 1 batch(es). Logging and checkpointing is suppressed.


/opt/conda/envs/fm/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py:186: .fit(ckpt_path="last") is set, but there is no last checkpoint available. No checkpoint will be loaded. HINT: Set `ModelCheckpoint(..., save_last=True)`.


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Loading `train_dataloader` to estimate number of stepping batches.
/opt/conda/envs/fm/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=31` in the `DataLoader` to improve performance.



  | Name                         | Type             | Params
------------------------------------------------------------------
0 | embedding                    | Embedding        | 30.7 K
1 | backbone                     | GemNetOCBackbone | 216 M 
2 | graph_outputs                | TypedModuleDict  | 263 K 
3 | graph_classification_outputs | TypedModuleDict  | 0     
4 | node_outputs                 | TypedModuleDict  | 0     
5 | train_metrics                | FinetuneMetrics  | 0     
6 | val_metrics                  | FinetuneMetrics  | 0     
7 | test_metrics                 | FinetuneMetrics  | 0     
------------------------------------------------------------------
216 M     Trainable params
0         Non-trainable params
216 M     Total params
866.719   Total estimated model params size (MB)
/opt/conda/envs/fm/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Con

Training: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_steps=1` reached.


Seed set to 0


[None]