In [1]:
import os

from torch.optim.lr_scheduler import ReduceLROnPlateau
import torch
import torchmetrics

import pytorch_lightning as pl

import schnetpack as spk
import schnetpack.representation as rep
import schnetpack.task as task
import schnetpack.transform as trn

import tqdm

from nablaDFT.dataset import NablaDFT
from nablaDFT.painn.train_painn import AtomisticTaskFixed, seed_everything

In [2]:
# ---------------------------------------------------------------------------
# Training configuration: every value a reader may want to tune lives here.
# ---------------------------------------------------------------------------

# Data and logging locations
dataset_name = 'dataset_train_2k'  # name of the training dataset
datapath = 'database'  # path to the selected dataset
logspath = 'logs'  # path to log files

# Run control
seed = 1799  # random seed, fixed for reproducibility
nepochs = 200  # number of epochs to train for
batch_size = 100  # size of each training batch
devices = 1  # number of GPU/TPU/CPU devices used for training

# PaiNN representation hyperparameters
cutoff = 5.0  # cutoff distance (in Angstroms) for computing interactions
n_rbf = 20  # number of radial basis functions in the representation
n_atom_basis = 128  # number of basis functions for atoms in the representation
n_interactions = 6  # number of interactions to consider between atoms

In [3]:
# Fix the random seed before the data module and the model are created.
seed_everything(seed)

# All artefacts of this run (split file, checkpoints, logs) go under `workpath`.
workpath = logspath

# `exist_ok=True` replaces the check-then-create pattern: it is race-free and
# keeps the cell idempotent when it is re-run.
os.makedirs(workpath, exist_ok=True)

# Training data module built by the nablaDFT "ASE" backend.
data = NablaDFT("ASE", dataset_name,
                datapath=datapath,
                data_workdir=workpath,
                batch_size=batch_size,
                num_workers=4,
                transforms=[
                    # Neighbour list limited to the interaction cutoff.
                    trn.ASENeighborList(cutoff=cutoff),
                    # Train on mean-removed energies; the model's AddOffsets
                    # postprocessor is the counterpart of this transform.
                    trn.RemoveOffsets("energy", remove_mean=True,
                                      remove_atomrefs=False),
                    trn.CastTo32()
                ],
                split_file=os.path.join(workpath, "split.npz"))

In [4]:
# ---------------------------------------------------------------------------
# Model: PaiNN representation + atom-wise energy head
# ---------------------------------------------------------------------------
pairwise_distance = spk.atomistic.PairwiseDistances()
radial_basis = spk.nn.radial.GaussianRBF(
    n_rbf=n_rbf,
    cutoff=cutoff
)
cutoff_fn = spk.nn.cutoff.CosineCutoff(cutoff)
representation = rep.PaiNN(
    n_interactions=n_interactions,
    n_atom_basis=n_atom_basis,
    radial_basis=radial_basis,
    cutoff_fn=cutoff_fn
)
pred_energy = spk.atomistic.Atomwise(
    n_in=representation.n_atom_basis,
    output_key="energy"
)
# Counterpart of the RemoveOffsets transform used on the training data:
# cast back to 64-bit and add the mean energy to the prediction.
postprocessors = [
    trn.CastTo64(),
    trn.AddOffsets("energy", add_mean=True)
]
nnpot = spk.model.NeuralNetworkPotential(
    representation=representation,
    input_modules=[pairwise_distance],
    output_modules=[pred_energy],
    postprocessors=postprocessors
)

# ---------------------------------------------------------------------------
# Training task: MSE loss on the energy, MAE logged as a metric
# ---------------------------------------------------------------------------
output_energy = spk.task.ModelOutput(
    name="energy",
    loss_fn=torch.nn.MSELoss(),
    loss_weight=1,
    metrics={"MAE": torchmetrics.MeanAbsoluteError()}
)

# Arguments forwarded to ReduceLROnPlateau (monitored quantity: "val_loss").
scheduler_args = {
    "factor": 0.8,
    "patience": 10,
    "min_lr": 1e-06
}

# NOTE(review): the name `task` shadows the `schnetpack.task as task` import
# from the first cell; it is kept here so that any code relying on the
# variable keeps working. This cell itself only uses `spk.task`.
# NOTE(review): `AtomisticTaskFixed` is imported at the top of the notebook
# but not used here -- confirm which task class is intended.
task = spk.task.AtomisticTask(
    model=nnpot,
    outputs=[output_energy],
    optimizer_cls=torch.optim.AdamW,
    optimizer_args={"lr": 1e-4},
    scheduler_cls=ReduceLROnPlateau,
    scheduler_args=scheduler_args,
    scheduler_monitor="val_loss"
)

# ---------------------------------------------------------------------------
# Trainer
# ---------------------------------------------------------------------------
logger = pl.loggers.TensorBoardLogger(save_dir=workpath)
callbacks = [
    # Keeps only the best model (lowest "val_loss") at `best_inference_model`.
    spk.train.ModelCheckpoint(
        model_path=os.path.join(workpath, "best_inference_model"),
        save_top_k=1,
        monitor="val_loss"
    )
]

trainer = pl.Trainer(
    # Use the GPU when one is present, otherwise fall back to the CPU instead
    # of failing on machines without CUDA.
    accelerator='gpu' if torch.cuda.is_available() else 'cpu',
    devices=devices,
    callbacks=callbacks,
    logger=logger,
    default_root_dir=workpath,
    max_epochs=nepochs,
)

trainer.fit(task, datamodule=data.dataset)

  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
100%|███████████████████████████████████████████| 52/52 [00:01<00:00, 31.70it/s]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type                   | Params
---------------------------------------------------
0 | model   | NeuralNetworkPotential | 1.2 M 
1 | outputs | ModuleList             | 0     
---------------------------------------------------
1.2 M     Trainable params
0         Non-trainable params
1.2 M     Total params
4.628     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]



Training: 0it [00:00, ?it/s]

  rank_zero_deprecation(


Validation: 0it [00:00, ?it/s]



Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=200` reached.


In [6]:
# ---------------------------------------------------------------------------
# Evaluation configuration
# ---------------------------------------------------------------------------
dataset_name = 'dataset_test_conformations_2k'
datapath = 'database_test'
logspath = 'logs'
# Defined here as well so the cell does not depend on state left behind by
# the training cell (same value: the work directory is the log directory).
workpath = logspath
batch_size = 100
cutoff = 5.0
gpu = -1  # CUDA device index; -1 selects the CPU

# `torch.cuda.device(...)` is a context manager, not a device object, so a
# real `torch.device` is built and used for both the model and the batches.
if gpu == -1:
    device = torch.device("cpu")
else:
    device = torch.device(f"cuda:{gpu}")

# NOTE(review): `distance_unit="Bohr"` is passed here but not when the
# training data module is created -- confirm that both use the same units.
data = NablaDFT("ASE", dataset_name,
                datapath=datapath,
                data_workdir=workpath,
                distance_unit="Bohr",
                batch_size=batch_size,
                train_ratio=0,
                num_workers=1,
                transforms=[
                    trn.ASENeighborList(cutoff=cutoff),
                    trn.CastTo32()
                ],
                split_file=os.path.join(workpath, "split_test.npz"))

data.dataset.prepare_data()
data.dataset.setup()

# `map_location` lets a checkpoint written on a GPU be restored on whichever
# device was selected above. `torch.load` unpickles the file, so only load
# checkpoints produced by a trusted training run.
best_model = torch.load(os.path.join(workpath, 'best_inference_model'),
                        map_location=device)
best_model = best_model.to(device)
best_model = best_model.eval()

# Accumulates the absolute error over all batches (kept on the CPU).
metric = torchmetrics.MeanAbsoluteError()

with torch.no_grad():
    for x in tqdm.tqdm(data.dataset.val_dataloader()):

        # Down-cast float64 tensors and move every entry to the target device.
        for k in x:
            if x[k].dtype == torch.float64:
                x[k] = x[k].float()
            x[k] = x[k].to(device)

        # Copy the reference energy before the forward pass so that it cannot
        # be affected by anything the model does to the batch dictionary.
        target = x['energy'].cpu().clone()
        prediction = best_model(x)['energy'].cpu()
        metric.update(prediction, target)

print(metric.compute())

100%|███████████████████████████████████████████| 18/18 [00:10<00:00,  1.72it/s]

tensor(0.5854)



