In [1]:
from pathlib import Path

from jmpfm.tasks.pretrain import PretrainConfig, PretrainModel
from jmpfm.tasks.pretrain.module import (
    NormalizationConfig,
    PretrainDatasetConfig,
    TaskConfig,
)
from jmpfm.configs.pretrain import jmp_l_pt_config_builder


# Let's make the config
def _pretrain_task(
    name,
    train_src,
    val_src,
    *,
    force_loss_scale,
    y_std,
    force_std,
    train_metadata_path=None,
    val_metadata_path=None,
):
    """Assemble one ``TaskConfig`` from dataset locations and scaling constants.

    Args:
        name: Task name passed straight through to ``TaskConfig``.
        train_src: Path (string) of the training dataset.
        val_src: Path (string) of the validation dataset.
        force_loss_scale: Value for ``TaskConfig.force_loss_scale``.
        y_std: ``std`` of the ``"y"`` normalization entry (mean is fixed at 0.0).
        force_std: ``std`` of the ``"force"`` normalization entry (mean is fixed at 0.0).
        train_metadata_path: Optional metadata file for the training dataset.
        val_metadata_path: Optional metadata file for the validation dataset.

    Returns:
        The constructed ``TaskConfig``. ``energy_loss_scale`` is always 1.0.
    """

    def dataset(src, metadata_path):
        # Only forward ``metadata_path`` when one was supplied, so datasets
        # without a metadata file keep PretrainDatasetConfig's own default.
        if metadata_path is None:
            return PretrainDatasetConfig(src=Path(src))
        return PretrainDatasetConfig(src=Path(src), metadata_path=Path(metadata_path))

    return TaskConfig(
        name=name,
        train_dataset=dataset(train_src, train_metadata_path),
        val_dataset=dataset(val_src, val_metadata_path),
        energy_loss_scale=1.0,
        force_loss_scale=force_loss_scale,
        normalization={
            "y": NormalizationConfig(mean=0.0, std=y_std),
            "force": NormalizationConfig(mean=0.0, std=force_std),
        },
    )


def jmp_l_config(batch_size: int = 4, num_workers: int = 0):
    """Build the pre-training config with the oc20, oc22, ani1x and transition1x tasks.

    Starts from the config yielded by ``jmp_l_pt_config_builder()`` and
    overrides only ``batch_size``, ``num_workers`` and ``tasks``.

    Args:
        batch_size: Value assigned to ``config.batch_size``.
        num_workers: Value assigned to ``config.num_workers``.

    Returns:
        Whatever ``builder(config)`` returns; the rest of this notebook treats
        it as a ``PretrainConfig``.
    """
    with jmp_l_pt_config_builder() as (builder, config):
        # Set data config
        config.batch_size = batch_size
        config.num_workers = num_workers

        # Set the tasks
        config.tasks = [
            # NOTE(review): oc20 is read from /datasets/s2ef and is the only
            # task with metadata files; the other three live under
            # /shared/pre-training-datasets. Confirm this split is intentional.
            _pretrain_task(
                "oc20",
                "/datasets/s2ef/2M/train/",
                "/datasets/s2ef/all/val_id/",
                train_metadata_path="/datasets/s2ef/2M/train_metadata.npz",
                val_metadata_path="/datasets/s2ef/all/val_id_metadata.npz",
                force_loss_scale=73.0,
                y_std=24.901469505465872,
                force_std=0.5111534595489502,
            ),
            _pretrain_task(
                "oc22",
                "/shared/pre-training-datasets/oc22/s2ef-total/train/",
                "/shared/pre-training-datasets/oc22/s2ef-total/val_id/",
                force_loss_scale=80.0,
                y_std=25.229595396538468,
                force_std=0.25678861141204834,
            ),
            _pretrain_task(
                "ani1x",
                "/shared/pre-training-datasets/ani1x/train/",
                "/shared/pre-training-datasets/ani1x/val/",
                force_loss_scale=15.0,
                y_std=2.8700712783472118,
                force_std=2.131422996520996,
            ),
            _pretrain_task(
                "transition1x",
                "/shared/pre-training-datasets/trans1x/train/",
                "/shared/pre-training-datasets/trans1x/val/",
                force_loss_scale=14.0,
                y_std=1.787466168382901,
                force_std=0.3591422140598297,
            ),
        ]

        return builder(config)


# Build the config once and print it so the resolved settings show in the cell output.
config = jmp_l_config()
print(config)

# (config, model class) pairs to run.
configs: list[tuple[PretrainConfig, type[PretrainModel]]] = [(config, PretrainModel)]

id='nwbzqfcs' name=None project=None tags=[] notes=[] debug=False environment=EnvironmentConfig(cwd=None, python_executable=None, python_path=None, python_version=None, config=None, model=None, data=None, slurm=None, log_dir=None, seed=None, seed_workers=None, sweep_id=None, sweep_config=None) trainer=TrainerConfig(python_logging=PythonLogging(log_level=None, rich=True, rich_tracebacks=True, lovely_tensors=True, lovely_numpy=False), logging=LoggingConfig(enabled=True, log_lr=True, log_epoch=True, wandb=WandbLoggingConfig(enabled=True, log_model=False, watch=WandbWatchConfig(enabled=True, log=None, log_graph=True, log_freq=100)), csv=CSVLoggingConfig(enabled=True), tensorboard=TensorboardLoggingConfig(enabled=False)), optimizer=OptimizerConfig(grad_finite_checks=False, grad_none_checks=False, log_grad_norm=True, log_grad_norm_per_param=False, log_param_norm=False, log_param_norm_per_param=False, gradient_clipping=GradientClippingConfig(enabled=True, value=1.0, algorithm='norm'), gradien



In [2]:
from ll import Runner, Trainer


def run(config: PretrainConfig, model_cls: type[PretrainModel]) -> None:
    """Instantiate ``model_cls`` from ``config`` and fit it with a ``Trainer`` built from the same config."""
    # The model is constructed first, then the trainer.
    module = model_cls(config)
    fitter = Trainer(config)
    fitter.fit(module)

# Number of batches the fast dev run executes before stopping.
FAST_DEV_RUN_BATCHES = 16

runner = Runner(run)
# Trailing semicolon keeps the notebook from echoing the call's return value
# (a bare `[None]` for this single config), which carries no information.
runner.fast_dev_run(configs, n_batches=FAST_DEV_RUN_BATCHES);

Fast dev run:   0%|          | 0/1 [00:00<?, ?it/s]

Seed set to 0


Unrecognized arguments:  dict_keys(['learnable_rbf', 'learnable_rbf_stds', 'unique_basis_per_layer', 'dropout', 'edge_dropout'])


Trainer will use only 1 of 2 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=2)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Running in `fast_dev_run` mode: will run the requested loop using 16 batch(es). Logging and checkpointing is suppressed.


/opt/conda/envs/fm/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py:186: .fit(ckpt_path="last") is set, but there is no last checkpoint available. No checkpoint will be loaded. HINT: Set `ModelCheckpoint(..., save_last=True)`.


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Loading `train_dataloader` to estimate number of stepping batches.
/opt/conda/envs/fm/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=31` in the `DataLoader` to improve performance.



  | Name          | Type             | Params
---------------------------------------------------
0 | embedding     | Embedding        | 30.7 K
1 | backbone      | GemNetOCBackbone | 38.8 M
2 | output        | Output           | 5.3 M 
3 | train_metrics | FMMetrics        | 0     
4 | val_metrics   | FMMetrics        | 0     
5 | task_steps    | TypedModuleDict  | 0     
---------------------------------------------------
44.1 M    Trainable params
0         Non-trainable params
44.1 M    Total params
176.525   Total estimated model params size (MB)


/opt/conda/envs/fm/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=31` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_steps=16` reached.


Seed set to 0


[None]