This is an example of a training and testing pipeline for the PaiNN model from the schnetpack library.  
The same task can be performed with the pre-defined config by running the following command from the repository root:
```bash
python run.py --config-name painn.yaml
```
For a detailed description, please refer to the [README](../nablaDFT/README.md).


# Train/test cycles example using PaiNN

In [None]:
# Based on https://github.com/atomistic-machine-learning/schnetpack/blob/master/examples/tutorials/tutorial_02_qm9.ipynb
import os
from tqdm import tqdm

import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

import schnetpack as spk
import schnetpack.representation as rep
import schnetpack.task as task
import schnetpack.transform as trn

import torch
import torchmetrics
from torch.optim.lr_scheduler import ReduceLROnPlateau

from nablaDFT.dataset import ASENablaDFT
from nablaDFT.dataset.split import TestSplit
from nablaDFT.ase_model import AtomisticTaskFixed
from nablaDFT.utils import seed_everything

In [None]:
# --- Data and logging locations ---------------------------------------------
# Name of the training dataset
dataset_name = "dataset_train_2k"
# Path to the selected dataset
datapath = "database"
# Path to log files
logspath = "logs"

# --- Training schedule ------------------------------------------------------
# Number of epochs to train for
nepochs = 200
# Random seed number for reproducibility
seed = 1799
# Size of each batch for training
batch_size = 32
# Parts of the dataset used for training and for validation
train_ratio = 0.9
val_ratio = 0.1

# --- Model size -------------------------------------------------------------
# Number of interactions to consider between atoms
n_interactions = 6
# Number of basis functions for atoms in the representation
n_atom_basis = 128
# Number of radial basis functions in the representation
n_rbf = 20
# Cutoff distance (in Bohr) for computing interactions
cutoff = 5.0

# --- Hardware ---------------------------------------------------------------
# Number of GPU/TPU/CPU devices to use for training
devices = 1

# --- Data transforms used for training and validation -----------------------
# Applied in list order: neighbor list within `cutoff`, removal of the mean
# energy offset (atom references are kept), then the cast done by CastTo32.
transforms = [
    trn.ASENeighborList(cutoff=cutoff),
    trn.RemoveOffsets("energy", remove_mean=True, remove_atomrefs=False),
    trn.CastTo32(),
]

## Downloading dataset

In [None]:
# Seed the random number generators before any data or model object is created.
seed_everything(seed)

# All run artifacts (logs, checkpoints, data working files) live under the
# log directory; create it up front. `exist_ok=True` makes re-running the
# cell safe without a separate existence check.
workpath = logspath
os.makedirs(workpath, exist_ok=True)

# Training data module. The dataset name and location come from the
# hyper-parameter cell (`dataset_name`, `datapath`) instead of being repeated
# as literals, so changing them there actually takes effect here.
# The same transform list is used for the training and validation splits.
datamodule = ASENablaDFT(
    "train",
    dataset_name=dataset_name,
    datapath=datapath,
    data_workdir=logspath,
    batch_size=batch_size,
    train_ratio=train_ratio,
    val_ratio=val_ratio,
    num_workers=4,
    train_transforms=transforms,
    val_transforms=transforms,
)

## Initializing training procedure and starting training

In [None]:
# --- Model --------------------------------------------------------------------
# PaiNN representation fed by pairwise distances. The radial basis and the
# cutoff function share the same `cutoff` radius as the neighbor-list transform.
pairwise_distance = spk.atomistic.PairwiseDistances()
radial_basis = spk.nn.radial.GaussianRBF(n_rbf=n_rbf, cutoff=cutoff)
cutoff_fn = spk.nn.cutoff.CosineCutoff(cutoff)
representation = rep.PaiNN(
    n_interactions=n_interactions,
    n_atom_basis=n_atom_basis,
    radial_basis=radial_basis,
    cutoff_fn=cutoff_fn,
)

# Output heads: an atom-wise energy prediction and a forces module.
pred_energy = spk.atomistic.Atomwise(
    n_in=representation.n_atom_basis,
    output_key="energy",
)
pred_forces = spk.atomistic.Forces()

# Counterpart of the `RemoveOffsets("energy", remove_mean=True, ...)` training
# transform: the mean energy is added back to the model output.
postprocessors = [
    trn.AddOffsets("energy", add_mean=True)
]

nnpot = spk.model.NeuralNetworkPotential(
    representation=representation,
    input_modules=[pairwise_distance],
    output_modules=[pred_energy, pred_forces],
    postprocessors=postprocessors,
)

# --- Losses and metrics ---------------------------------------------------------
# Energy and forces are both trained with an equally weighted MSE loss and
# monitored with the mean absolute error.
output_energy = spk.task.ModelOutput(
    name="energy",
    loss_fn=torch.nn.MSELoss(),
    loss_weight=1,
    metrics={"MAE": torchmetrics.MeanAbsoluteError()},
)
output_forces = spk.task.ModelOutput(
    name="forces",
    loss_fn=torch.nn.MSELoss(),
    loss_weight=1,
    metrics={"MAE": torchmetrics.MeanAbsoluteError()},
)

# --- Optimisation -----------------------------------------------------------------
# ReduceLROnPlateau settings: multiply the learning rate by `factor` after
# `patience` epochs without improvement of the monitored value, never going
# below `min_lr`.
scheduler_args = {
    "factor": 0.8,
    "patience": 10,
    "min_lr": 1e-06,
}

# NOTE: from this point on the name `task` refers to the Lightning task object,
# not to the `schnetpack.task` module imported under the same alias; the module
# stays reachable as `spk.task`. The name is kept because the test cell uses it.
task = AtomisticTaskFixed(
    model_name="PaiNN",
    model=nnpot,
    outputs=[output_energy, output_forces],
    optimizer_cls=torch.optim.AdamW,
    optimizer_args={"lr": 1e-4},
    scheduler_cls=ReduceLROnPlateau,
    scheduler_args=scheduler_args,
    scheduler_monitor="val_loss",
)

# --- Trainer ------------------------------------------------------------------------
logger = pl.loggers.TensorBoardLogger(save_dir=workpath)
lr_monitor = LearningRateMonitor(logging_interval='step')

# Keep only the single best checkpoint according to the validation loss.
# The loss is written into the file name with four decimals (`.4f`); the
# previous `4f` spec only set a minimum field width and left the precision at
# the default.
checkpoint_callback = ModelCheckpoint(
    save_top_k=1,
    monitor="val_loss",
    mode="min",
    dirpath=f"{workpath}/checkpoints",
    filename="Painn-{epoch:03d}_{val_loss:.4f}",
)
callbacks = [
    lr_monitor,
    checkpoint_callback,
]

trainer = pl.Trainer(
    accelerator='gpu',
    devices=devices,
    callbacks=callbacks,
    logger=logger,
    default_root_dir=workpath,
    max_epochs=nepochs,
)

# Run the training loop on the train/validation splits of the data module.
trainer.fit(task, datamodule=datamodule)

In [None]:
# Remember where the trainer's checkpoint callback stored its best model so
# that it can be passed on as `ckpt_path` later.
ckpt_callback = trainer.checkpoint_callback
ckpt_path = ckpt_callback.best_model_path

## Initializing the testing procedure and computing the test metrics

In [None]:
# Evaluation settings. `batch_size` and `cutoff` are re-assigned here for the
# test run; `gpu` is the CUDA device index, with -1 meaning "run on CPU".
batch_size = 100
cutoff = 5.0
gpu = 0

# Derive both the torch device and the Lightning accelerator from the same
# switch, so that setting `gpu = -1` really moves the test run to the CPU
# (previously the Trainer below always requested a GPU).
# The number of devices is still controlled by `devices`; the GPU index itself
# is not forwarded to the Trainer.
if gpu == -1:
    device = torch.device("cpu")
    accelerator = "cpu"
else:
    device = torch.device(f"cuda:{gpu}")
    accelerator = "gpu"

# Test data module: the whole dataset is used as the test split
# (train/val ratios are zero, splitting is delegated to TestSplit).
# No offset removal is applied here, only the neighbor list and CastTo32.
datamodule_test = ASENablaDFT(
    "test",
    dataset_name="dataset_test_conformations_2k",
    datapath="database_test",
    data_workdir=logspath,
    batch_size=batch_size,
    train_ratio=0.0,
    val_ratio=0.0,
    test_ratio=1.0,
    num_workers=4,
    test_transforms=[
        trn.ASENeighborList(cutoff=cutoff),
        trn.CastTo32(),
    ],
    splitting=TestSplit(),
)

# A fresh Trainer without training callbacks or logger.
# NOTE(review): `inference_mode=False` is presumably required because the
# model's forces output needs autograd during evaluation - confirm.
trainer = pl.Trainer(
    accelerator=accelerator,
    devices=devices,
    default_root_dir=workpath,
    inference_mode=False,
)

# Evaluate the task on the test data, loading weights from the checkpoint path
# stored in `ckpt_path`.
trainer.test(model=task, datamodule=datamodule_test, ckpt_path=ckpt_path)