# Introduction

This notebook demonstrates how to fit an M3GNet potential using PyTorch Lightning with MatGL.

In [5]:
from __future__ import annotations

import os
import shutil
import warnings

import matgl
import numpy as np
import pandas as pd
import pytorch_lightning as pl
from dgl.data.utils import split_dataset
from matgl.ext.pymatgen import Structure2Graph, get_element_list
from matgl.graph.data import M3GNetDataset, MGLDataLoader, collate_fn_efs
from matgl.models import M3GNet
from matgl.utils.training import PotentialLightningModule
from mp_api.client import MPRester
from pytorch_lightning.loggers import CSVLogger

# To suppress warnings for clearer output
warnings.simplefilter("ignore")

In [6]:
# Paths used throughout the notebook:
#   data_train_path - spreadsheet holding the training data
#   model_load_path - name of the pre-trained model to start from
#   model_save_path - directory the fine-tuned model is written to
model_save_path = 'finetuned_model'
model_load_path = 'M3GNet-MP-2021.2.8-PES'
data_train_path = 'm3gnet_dataset.xlsx'

# Read the training spreadsheet into a DataFrame
m3gnet_dataset = pd.read_excel(io=data_train_path)

For the purposes of demonstration, we will download all Si-O compounds in the Materials Project via the MPRester. The forces and stresses are set to zero, though in a real context, these would be non-zero and obtained from DFT calculations.

In [None]:
# Obtain your API key here: https://next-gen.materialsproject.org/api
# Export it as the MP_API_KEY environment variable rather than pasting it into the
# notebook, so the secret never ends up in a saved or shared copy. If the variable
# is unset, None is passed and MPRester falls back to its own key lookup.
# Using the client as a context manager releases its HTTP session when the
# download is finished.
with MPRester(api_key=os.environ.get("MP_API_KEY")) as mpr:
    entries = mpr.get_entries_in_chemsys(["Si", "O"])

structures = [e.structure for e in entries]
energies = [e.energy for e in entries]

# Placeholder labels for the demonstration: one zero 3-vector per site for the
# forces and one zero 3x3 matrix per structure for the stresses.
forces = [np.zeros((len(s), 3)).tolist() for s in structures]
stresses = [np.zeros((3, 3)).tolist() for _ in structures]

labels = {
    "energies": energies,
    "forces": forces,
    "stresses": stresses,
}

print(f"{len(structures)} downloaded from MP.")

Retrieving ThermoDoc documents: 100%|██████████| 407/407 [00:00<00:00, 4962446.88it/s]


407 downloaded from MP.


We will first set up the M3GNet model and the LightningModule.

In [None]:
# Convert the downloaded structures to graphs; the element list comes from the data itself.
element_types = get_element_list(structures)
converter = Structure2Graph(element_types=element_types, cutoff=5.0)
dataset = M3GNetDataset(threebody_cutoff=4.0, structures=structures, converter=converter, labels=labels)

# 80/10/10 train/validation/test split, shuffled with a fixed seed.
split_fractions = [0.8, 0.1, 0.1]
train_data, val_data, test_data = split_dataset(dataset, frac_list=split_fractions, shuffle=True, random_state=42)

# One loader per subset; batches are assembled by collate_fn_efs.
loaders = MGLDataLoader(
    train_data=train_data, val_data=val_data, test_data=test_data,
    collate_fn=collate_fn_efs, batch_size=2, num_workers=1,
)
train_loader, val_loader, test_loader = loaders

# A fresh M3GNet over the same element list, wrapped for training with Lightning.
model = M3GNet(element_types=element_types, is_intensive=False)
lit_module = PotentialLightningModule(model=model)

## Finetuning a pre-trained M3GNet
In the previous cells, we set up the dataset, the data loaders, and an M3GNet model for training from scratch. Next, let's see how to perform additional training on an M3GNet that has already been trained using Materials Project data.

In [None]:
# Fetch the pre-trained M3GNet named by model_load_path.
m3gnet_nnp = matgl.load_model(path=model_load_path)

# Fine-tune the network held by the loaded object, using a learning rate of 1e-4.
model_pretrained = m3gnet_nnp.model
lit_module_finetune = PotentialLightningModule(
    model=model_pretrained,
    lr=1e-4,
)

In [None]:
# Metrics are logged through a CSVLogger rooted at "logs" under the name "M3GNet_finetuning".
logger = CSVLogger("logs", name="M3GNet_finetuning")

# accelerator="cpu" is the kwarg that disables GPU or MPS (M1 mac) training.
trainer = pl.Trainer(
    max_epochs=1,
    accelerator="cpu",
    logger=logger,
    inference_mode=False,
)

# Fine-tune for a single epoch on the train/validation loaders.
trainer.fit(
    model=lit_module_finetune,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type              | Params
--------------------------------------------
0 | mae   | MeanAbsoluteError | 0     
1 | rmse  | MeanSquaredError  | 0     
2 | model | Potential         | 288 K 
--------------------------------------------
288 K     Trainable params
0         Non-trainable params
288 K     Total params
1.153     Total estimated model params size (MB)


Epoch 0: 100%|██████████| 163/163 [00:37<00:00,  4.34it/s, v_num=1, val_Total_Loss=5.420, val_Energy_MAE=1.070, val_Force_MAE=0.407, val_Stress_MAE=0.000, val_Site_Wise_MAE=0.000, val_Energy_RMSE=1.330, val_Force_RMSE=0.615, val_Stress_RMSE=0.000, val_Site_Wise_RMSE=0.000, train_Total_Loss=21.80, train_Energy_MAE=3.270, train_Force_MAE=0.572, train_Stress_MAE=0.000, train_Site_Wise_MAE=0.000, train_Energy_RMSE=3.430, train_Force_RMSE=0.870, train_Stress_RMSE=0.000, train_Site_Wise_RMSE=0.000]

`Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████| 163/163 [00:39<00:00,  4.17it/s, v_num=1, val_Total_Loss=5.420, val_Energy_MAE=1.070, val_Force_MAE=0.407, val_Stress_MAE=0.000, val_Site_Wise_MAE=0.000, val_Energy_RMSE=1.330, val_Force_RMSE=0.615, val_Stress_RMSE=0.000, val_Site_Wise_RMSE=0.000, train_Total_Loss=21.80, train_Energy_MAE=3.270, train_Force_MAE=0.572, train_Stress_MAE=0.000, train_Site_Wise_MAE=0.000, train_Energy_RMSE=3.430, train_Force_RMSE=0.870, train_Stress_RMSE=0.000, train_Site_Wise_RMSE=0.000]


In [None]:
# Write the fine-tuned network to model_save_path ...
model_pretrained.save(model_save_path)

# ... and read it back from the same location.
trained_model = matgl.load_model(model_save_path)

In [None]:
# This code just performs cleanup for this notebook.
# Cleanup is best-effort: any file or directory that is already absent is skipped,
# so the cell can be re-run and does not fail when an earlier cell was not run.

for fn in ("dgl_graph.bin", "dgl_line_graph.bin", "state_attr.pt", "labels.json"):
    try:
        os.remove(fn)
    except FileNotFoundError:
        pass

# "finetuned_model" is where this notebook saves the model; "logs" and
# "trained_model" only exist if the corresponding steps produced them.
for dn in ("logs", "trained_model", "finetuned_model"):
    try:
        shutil.rmtree(dn)
    except FileNotFoundError:
        pass