In [12]:
import os

import hydra
import pytorch_lightning as pl
import torch
from ase.db import connect
from nablaDFT import model_registry
from nablaDFT.dataset import PyGNablaDFTDataModule
from nablaDFT.pipelines import predict
from omegaconf import OmegaConf

In [13]:
# Datamodule configuration for the tiny nablaDFT test split of conformations.
data_args = {
    "root": "./datasets/nablaDFT/test",
    "dataset_name": "dataset_test_conformations_tiny",
    "batch_size": 4,
    "num_workers": 2,
}

# Run on a single GPU when CUDA is available, otherwise on the CPU.
# `devices` is 1 in both cases: `None` is not a supported value for
# `pl.Trainer(devices=...)` in Lightning 2.x (its default is "auto"), so the
# CPU fallback would otherwise fail when the Trainer is constructed.
accelerator = "gpu" if torch.cuda.is_available() else "cpu"
devices = 1

In [14]:
# Pretrained GemNet-OC (large) Lightning model from the nablaDFT registry;
# the first call downloads the checkpoint (~578 MB per the cell output).
model = model_registry.get_pretrained_model("lightning", "GemNet-OC_train_large")
datamodule = PyGNablaDFTDataModule(**data_args)
# Inference only, so no logger is needed. Disable it at construction time
# rather than building the default logger and patching `trainer.logger` afterwards.
trainer = pl.Trainer(accelerator=accelerator, devices=devices, logger=False)


Downloading GemNet-OC_train_large: 578MB [00:14, 40.5MB/s]                                                                                 
/home/kostanew/anaconda3/envs/p4env/lib/python3.10/site-packages/pytorch_lightning/utilities/parsing.py:198: Attribute 'metric' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['metric'])`.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [15]:
# Run inference over the test split. Per the cell output, predictions are written
# to an ASE database at <output_dir>/<model_name>_<dataset_name>.db.
predict(
    trainer,
    model,
    datamodule,
    ckpt_path=None,  # no checkpoint file; presumably uses the weights already loaded in `model`
    model_name="GemNet-OC-large",
    output_dir="./predictions",
)

You are using a CUDA device ('NVIDIA GeForce RTX 3080 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
Downloading split: dataset_test_conformations_tiny: 100%|█████████████████████████████████████████████| 11.1M/11.1M [00:00<00:00, 30.8MB/s]
Processing...
100%|████████████████████████████████████████████████████████████████████████████████████████████████| 2774/2774 [00:01<00:00, 2467.94it/s]
INFO:nablaDFT.dataset.pyg_datasets:Saved processed dataset: datasets/nablaDFT/test/processed/dataset_test_conformations_tiny_predict.pt
Done!
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: |                                                                                                 …

INFO:root:Write predictions to predictions/GemNet-OC-large_dataset_test_conformations_tiny.db


In [16]:
# Open the ASE database written by `predict` and show the first stored prediction.
# The file name mirrors the one reported in the `predict` log:
# <output_dir>/<model_name>_<dataset_name>.db
pred_db_path = os.path.join("./predictions", f"GemNet-OC-large_{data_args['dataset_name']}.db")
# ase.db.connect does not require the file to exist (for SQLite it creates a new,
# empty database), so a wrong path would only surface later as a confusing
# KeyError from db.get(). Fail early with a clear message instead.
if not os.path.exists(pred_db_path):
    raise FileNotFoundError(
        f"Prediction database not found: {pred_db_path} (run the predict cell first)"
    )
db = connect(pred_db_path)
row = db.get(1)

energy, forces = row.data["energy_pred"], row.data["forces_pred"]

print(f"Predicted energy: {energy}")
print("Predicted interatomic forces:")
print(forces)

Predicted energy: [-6.055543422698975]
Predicted interatomic forces:
[[ 3.40643935e-02  1.61777381e-02 -6.22326694e-02]
 [-2.85620289e-03 -3.46308611e-02 -1.75479930e-02]
 [-8.64476264e-02  2.33703889e-02  6.05947115e-02]
 [ 8.89855027e-02  1.12606687e-02 -2.05886275e-01]
 [ 1.15613164e-02 -2.96420865e-02  1.33667454e-01]
 [ 3.51779200e-02  2.05350779e-02  2.25494280e-02]
 [-2.69800629e-02 -3.53164971e-02 -1.77055355e-02]
 [ 4.11478356e-02 -8.74247402e-02  5.56267537e-02]
 [ 1.87740941e-02  7.37474672e-03 -1.03036799e-02]
 [ 1.24148568e-02 -4.48866524e-02 -1.14919133e-02]
 [ 1.70258526e-03 -7.12430850e-03 -5.78673673e-04]
 [-2.35788897e-02 -3.81501131e-02  5.07842051e-03]
 [-4.07069698e-02  4.83627357e-02  7.44050462e-03]
 [ 2.16332860e-02 -2.79590711e-02 -1.79409736e-03]
 [ 1.22112278e-02  7.30230063e-02 -2.21254100e-04]
 [-1.38332024e-01 -7.06819966e-02 -1.08784549e-01]
 [ 4.83015142e-02  1.12306774e-01 -7.58751063e-04]
 [ 9.95967910e-02 -5.26936986e-02  1.81389198e-01]
 [-5.93704395