A detailed description of the run configuration can be found [here](../nablaDFT/README.md).

## Test example

In [None]:
# Print the model test example config (IPython `!` runs the line as a shell command).
!cat ../config/gemnet-oc_test.yaml

In [None]:
# Launch run.py with the gemnet-oc_test.yaml config (executed as a shell command via IPython `!`).
!python ../run.py --config-name gemnet-oc_test.yaml

## Inference on another dataset

To run from the command line, use the following example from the root of the repository:
```bash
python run.py --config-name gemnet-oc_predict.yaml
```

A detailed description can be found in the [README](../nablaDFT/README.md).

In [None]:
import os

import hydra
from omegaconf import OmegaConf
import torch
import pytorch_lightning as pl

from nablaDFT.gemnet_oc import GemNetOCLightning
from nablaDFT.dataset import PyGNablaDFTDataModule

### Paths and args

In [None]:
# Output locations used later in the notebook.
predictions_dir = "./predictions"
tb_logs = "./logs"

# Model definition (Hydra config) and the pretrained weights to load into it.
model_cfg_path = "../config/model/gemnet-oc.yaml"
ckpt_path = "../checkpoints/GemNet-OC/GemNet-OC_100k.ckpt"

# Number of devices handed to pl.Trainer.
devices = 1

# Keyword arguments forwarded to PyGNablaDFTDataModule.
data_args = dict(
    root="../datasets/nablaDFT/test",
    dataset_name="test_4k_mff_traj_part",
    batch_size=4,
    num_workers=2,
)

### Instantiate dataset and load model

In [None]:
# Build the GemNet-OC model from its Hydra config and load the pretrained weights.
cfg = OmegaConf.load(model_cfg_path)
model = hydra.utils.instantiate(cfg)
# map_location="cpu" lets a checkpoint saved on a GPU deserialize on any host;
# device placement is left to the Trainer. The checkpoint dict is not kept around
# so its non-weight contents can be freed right away.
model.load_state_dict(torch.load(ckpt_path, map_location="cpu")["state_dict"])
# Datamodule over the dataset described by `data_args`.
datamodule = PyGNablaDFTDataModule(**data_args)

In [None]:
# File that will hold the serialized predictions; its parent directory is
# created up front (no error if it already exists).
pred_path = os.path.join(predictions_dir, "gemnet_preds.pt")
os.makedirs(name=predictions_dir, exist_ok=True)

In [None]:
# Inference-only trainer on `devices` GPU(s). Logging is disabled through the
# constructor rather than by overwriting `trainer.logger` afterwards, so no
# default logger is ever instantiated.
trainer = pl.Trainer(
    accelerator="gpu",
    devices=devices,
    logger=False,
)

In [None]:
# Run inference. Passing `ckpt_path` makes Lightning restore the checkpoint
# weights before predicting; the result is a list with one entry per batch.
predictions = trainer.predict(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
# Persist the raw per-batch predictions to the prepared output path so they can
# be reused without rerunning inference.
torch.save(predictions, pred_path)

In [None]:
# Inspect the predictions of the first batch.
batch_preds = predictions[0]
energy = batch_preds[0]  # [bs]
forces = batch_preds[1]  # [natoms in batch, 3]