This is an example of a training and testing pipeline for the GemNet-OC model.  
The same task can be performed with the pre-defined configs from the repository root:
```bash
python run.py --config-name gemnet-oc.yaml
python run.py --config-name gemnet-oc_test.yaml
```
For a detailed description, please refer to the [README](../nablaDFT/README.md).


In [None]:
from functools import partial

import torch
import torchmetrics
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from omegaconf import OmegaConf
from torch_ema import ExponentialMovingAverage

import nablaDFT
from nablaDFT.dataset import PyGNablaDFTDataModule
from nablaDFT.gemnet_oc import GemNetOCLightning, GemNetOC
from nablaDFT.utils import seed_everything

In [None]:
# --- Data selection ---------------------------------------------------------
# Name of the training dataset
dataset_name = "dataset_train_2k"
# Path to the selected dataset
datapath = "database"
# Parts of the dataset used for training and for validation
train_ratio, val_ratio = 0.9, 0.1

# --- Data loading -----------------------------------------------------------
# Size of each batch for training
batch_size = 8
# Dataloader's num_workers
num_workers = 2

# --- Training run -----------------------------------------------------------
# Number of epochs to train for
nepochs = 200
# Random seed number for reproducibility
seed = 1799
# Number of GPU/TPU/CPU devices to use for training
devices = 1

## Downloading dataset

In [None]:
# Fix the random seed before anything else is created.
seed_everything(seed)

# Split fractions and dataloader options, forwarded to the datamodule by keyword.
datamodule_options = dict(
    train_size=train_ratio,
    val_size=val_ratio,
    seed=seed,
    batch_size=batch_size,
    num_workers=num_workers,
)
datamodule = PyGNablaDFTDataModule(datapath, dataset_name, **datamodule_options)

# Prepare the "fit" stage up front.
# NOTE(review): presumably this is where the dataset is downloaded if it is
# missing under `datapath` - confirm against PyGNablaDFTDataModule.
datamodule.setup(stage="fit")

## Initialize model

In [None]:
# GemNet-OC params
model_cfg = {
    'num_targets': 1, 
    'num_spherical': 7, 
    'num_radial': 128, 
    'num_blocks': 4, 
    'emb_size_atom': 256, 
    'emb_size_edge': 512, 
    'emb_size_trip_in': 64, 
    'emb_size_trip_out': 64, 
    'emb_size_quad_in': 32, 
    'emb_size_quad_out': 32, 
    'emb_size_aint_in': 64, 
    'emb_size_aint_out': 64, 
    'emb_size_rbf': 16, 
    'emb_size_cbf': 16, 
    'emb_size_sbf': 32, 
    'num_before_skip': 2, 
    'num_after_skip': 2, 
    'num_concat': 1, 
    'num_atom': 3, 
    'num_output_afteratom': 3, 
    'num_atom_emb_layers': 0, 
    'num_global_out_layers': 2, 
    'regress_forces': True, 
    'direct_forces': True, 
    'use_pbc': False, 
    'scale_backprop_forces': False, 
    'cutoff': 12.0, 
    'cutoff_qint': 12.0, 
    'cutoff_aeaint': 12.0, 
    'cutoff_aint': 12.0, 
    'max_neighbors': 30, 
    'max_neighbors_qint': 8, 
    'max_neighbors_aeaint': 20, 
    'max_neighbors_aint': 1000, 
    'enforce_max_neighbors_strictly': True, 
    'rbf': {'name': 'gaussian'}, 
    'rbf_spherical': None, 
    'envelope': {'name': 'polynomial', 'exponent': 5}, 
    'cbf': {'name': 'spherical_harmonics'}, 
    'sbf': {'name': 'legendre_outer'},
    'extensive': True, 
    'forces_coupled': True, 
    'output_init': 'HeOrthogonal', 
    'activation': 'silu', 
    'scale_file': None, 
    'quad_interaction': True, 
    'atom_edge_interaction': True, 
    'edge_atom_interaction': True, 
    'atom_interaction': True, 
    'scale_basis': True
}

In [None]:
# Keyword arguments for the optimizer.
opt_args = dict(amsgrad=True, betas=[0.9, 0.95], lr=1e-3, weight_decay=0)

# Keyword arguments for the learning-rate scheduler.
lr_args = dict(factor=0.8, patience=10)

In [None]:
# Bare GemNet-OC network built from the hyperparameters in `model_cfg`.
net = GemNetOC(**model_cfg)

# Factories only: the optimizer, scheduler and EMA are instantiated during
# Lightning module creation, so `partial` pre-binds their settings here.
optimizer = partial(torch.optim.AdamW, **opt_args)
lr_scheduler = partial(torch.optim.lr_scheduler.ReduceLROnPlateau, **lr_args)
ema = partial(ExponentialMovingAverage, decay=0.999)

# One loss per target, keyed by target name.
losses = dict(
    energy=torch.nn.L1Loss(),
    forces=nablaDFT.gemnet_oc.loss.L2Loss(),
)
# Per-target coefficients (presumably the weights of each loss term in the
# total loss - confirm against GemNetOCLightning).
losses_coefs = dict(energy=1, forces=100)

# Mean absolute error tracked separately for each target.
task_metrics = dict(
    energy=torchmetrics.MeanAbsoluteError(),
    forces=torchmetrics.MeanAbsoluteError(),
)
metric = torchmetrics.MultitaskWrapper(task_metrics=task_metrics)

# Arguments are passed positionally; keep this order in sync with the
# GemNetOCLightning signature.
model = GemNetOCLightning(
    "GemNet-OC", net, optimizer, lr_scheduler, losses, ema, metric, losses_coefs
)

In [None]:
# Keep the single best checkpoint (lowest "val/loss") plus the latest one.
model_checkpoint = pl.callbacks.ModelCheckpoint(
    monitor="val/loss",
    mode="min",
    save_top_k=1,
    save_last=True,
    dirpath="./checkpoints",
    # Reference the monitored key ("val/loss") directly: a name that was not
    # logged is rendered as 0 in the checkpoint filename. The metric name is
    # written by hand and auto-insertion is disabled so that the "/" in the
    # key does not end up in the file path. ".4f" gives four decimal places.
    filename="GemNet-OC_epoch={epoch:03d}_val_loss={val/loss:.4f}",
    auto_insert_metric_name=False,
)
# Stop training when "val/loss" has not improved by at least `min_delta`
# for `patience` consecutive validation checks.
early_stopping = pl.callbacks.EarlyStopping(
    monitor="val/loss",
    min_delta=1e-4,
    patience=50,
    mode="min",
    check_on_train_epoch_end=False,
)
callbacks = [model_checkpoint, early_stopping]
logger = pl.loggers.TensorBoardLogger(save_dir="./tensorboard_logs")

trainer = pl.Trainer(
    callbacks=callbacks,
    logger=logger,
    accelerator="gpu",
    # Honor the `devices` setting from the parameters cell; without it the
    # trainer falls back to its own default device selection.
    devices=devices,
    max_epochs=nepochs,
    gradient_clip_algorithm="norm",
    gradient_clip_val=5.0,
)

trainer.fit(model=model, datamodule=datamodule)

In [None]:
# Path of the best checkpoint recorded by the trainer's checkpoint callback;
# used below to restore weights for testing.
checkpoint_callback = trainer.checkpoint_callback
ckpt_path = checkpoint_callback.best_model_path

## Initializing the testing procedure and computing the test metrics

In [None]:
# Separate datamodule for the held-out test split; it reuses the loader
# settings from training.
test_dataset_name = "dataset_test_conformations_2k"
datamodule_test = PyGNablaDFTDataModule(
    datapath, test_dataset_name, batch_size=batch_size, num_workers=num_workers
)

# Evaluate using the weights stored at `ckpt_path`.
trainer.test(model=model, datamodule=datamodule_test, ckpt_path=ckpt_path)