In [None]:
import os
from tqdm import tqdm
from functools import partial
from pymatgen.io.ase import AseAtomsAdaptor
from matgl.graph.data import MGLDataset, MGLDataLoader
from matgl.ext.pymatgen import Structure2Graph, get_element_list
from matgl.graph.data import collate_fn_graph
from matgl.utils.training import ModelLightningModule
from matgl.models import M3GNet
import torch

In [None]:
# Create the download directory. Unlike `!mkdir data`, this is a no-op when the
# directory already exists, so the cell can be re-run without an error message.
os.makedirs('data', exist_ok=True)

In [None]:
from litraj.data import download_dataset, load_data

# Name of the dataset; reused below when loading the data and naming the graph datasets.
dataset_name = "nebDFT2k"

# Download the dataset through litraj, passing 'data' as the destination folder.
download_dataset(
    dataset_name,
    "data",
)

In [None]:
# `index` is filtered by its `_split` column; the `centroid` column of each subset
# supplies the structures used for that split.
index = load_data(dataset_name, 'data')
atoms_list_train, atoms_list_val, atoms_list_test = (
    index[index._split == split_name].centroid
    for split_name in ('train', 'val', 'test')
)

# Per-split list of the `edge_id` stored in every structure's `info` dict.
edge_ids_train, edge_ids_val, edge_ids_test = (
    [atoms.info['edge_id'] for atoms in split_atoms]
    for split_atoms in (atoms_list_train, atoms_list_val, atoms_list_test)
)

# For nebBVSE122k the splits are returned directly instead of being filtered from `index`:
# atoms_list_train, atoms_list_val, atoms_list_test, index = load_dataset('nebBVSE122k', 'data')
# NOTE(review): only `load_data` is imported in this notebook; `load_dataset` above is
# presumably meant to be `load_data` -- confirm against the litraj API.


In [None]:
def _atoms_to_structures_and_targets(atoms_list):
    """Convert one split of ASE-style configurations into pymatgen structures and targets.

    Parameters
    ----------
    atoms_list : iterable
        Objects exposing ``copy()``, ``numbers`` and ``info`` (ASE Atoms-like), as
        accepted by ``AseAtomsAdaptor.get_structure``.

    Returns
    -------
    tuple(list, list)
        The converted pymatgen structures and the matching ``info['em']`` values,
        in the same order as ``atoms_list``.
    """
    structures, targets = [], []
    for atoms in tqdm(atoms_list):
        # Work on a copy so the objects stored in `index` keep their original species.
        atoms = atoms.copy()
        # M3GNet cannot process element 'X'. We replace it with 'H' because it does not
        # exist in the data. This assumes the 'X' placeholder is always the last atom.
        atoms.numbers[-1] = 1
        structures.append(AseAtomsAdaptor.get_structure(atoms))
        targets.append(atoms.info['em'])
    return structures, targets


# The same conversion is applied to every split.
sts_pmg_train, targets_train = _atoms_to_structures_and_targets(atoms_list_train)
sts_pmg_val, targets_val = _atoms_to_structures_and_targets(atoms_list_val)
sts_pmg_test, targets_test = _atoms_to_structures_and_targets(atoms_list_test)

In [None]:
# Targets are handed to MGLDataset as a {label name: values} mapping; the per-split
# `info['em']` values are stored under the "energies" key.
labels_train = {"energies": targets_train}
labels_val = {"energies": targets_val}
labels_test = {"energies": targets_test}

In [None]:
# Build the element vocabulary from every split, not just the training set. The same
# list is given to the graph converter and to the model, so an element occurring only
# in the validation or test structures would otherwise be missing from it.
# NOTE(review): when the training split already covers all elements this yields the
# same list as before.
elem_list = get_element_list(sts_pmg_train + sts_pmg_val + sts_pmg_test)

# Structure -> graph converter using the shared element list and a cutoff of 5.0.
converter = Structure2Graph(element_types=elem_list, cutoff=5.0)

In [None]:
# Root folder for the stored graph datasets. It is derived from `dataset_name` (and is
# identical to the previous hard-coded path for 'nebDFT2k') so that switching datasets
# does not point at another dataset's directory.
folder = f'./m3gnet_centroids/{dataset_name}'


def _build_mgl_dataset(structures, labels, split):
    """Create the MGLDataset of one split inside ``{folder}/{split}``.

    Parameters
    ----------
    structures : list
        pymatgen structures of the split.
    labels : dict
        Mapping from label name to the list of target values for ``structures``.
    split : str
        Split name ('train', 'val' or 'test'); used for the sub-directory and the
        dataset name.

    Returns
    -------
    MGLDataset
        Dataset built with the shared `converter`, a three-body cutoff of 4.0 and
        line graphs included.
    """
    save_dir = f'{folder}/{split}'
    os.makedirs(save_dir, exist_ok=True)
    # NOTE(review): MGLDataset presumably reuses files already present in `save_dir`
    # instead of rebuilding the graphs -- confirm, and delete the folder after changing
    # the structures, the converter or the cutoffs.
    return MGLDataset(
        threebody_cutoff=4.0,
        structures=structures,
        converter=converter,
        labels=labels,
        include_line_graph=True,
        filename='dgl_graph.bin',
        filename_lattice='lattice.pt',
        filename_line_graph='dgl_line_graph.bin',
        filename_state_attr='state_attr.pt',
        filename_labels='labels.json',
        name=f'MGLDataset_{dataset_name}_{split}',
        save_dir=save_dir,
    )


train_data = _build_mgl_dataset(sts_pmg_train, labels_train, 'train')
val_data = _build_mgl_dataset(sts_pmg_val, labels_val, 'val')
test_data = _build_mgl_dataset(sts_pmg_test, labels_test, 'test')


In [None]:
# The datasets were built with line graphs, so the collate function is bound with
# include_line_graph=True before being handed to the loaders.
l_g_collate_fn = partial(collate_fn_graph, include_line_graph=True)

# One loader per split, all sharing the collate function and a batch size of 64.
train_loader, val_loader, test_loader = MGLDataLoader(
    train_data=train_data, val_data=val_data, test_data=test_data,
    collate_fn=l_g_collate_fn, batch_size=64,
)

In [None]:
# Untrained M3GNet sharing the element list that was given to the graph converter.
# NOTE(review): check in the matgl docs whether `readout_type` has any effect when
# `is_intensive=False`.
model = M3GNet(
    element_types=elem_list,
    is_intensive=False,
    readout_type="set2set",
)

In [None]:
# Lightning wrapper around the model: MSE loss, initial learning rate 1e-3 and the
# decay settings below; line graphs are forwarded to the model.
lit_module = ModelLightningModule(
    model=model,
    include_line_graph=True,
    loss="mse_loss",
    lr=1e-3,
    decay_steps=250,
    decay_alpha=0.01,
)

In [None]:
from lightning.pytorch import Trainer
from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch.callbacks import ModelCheckpoint

# Keep the three checkpoints with the lowest `val_RMSE` (mode="min") plus the most
# recent one; files are written to ./checkpoints.
checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints",
    filename="best-{epoch}-{val_RMSE:.4f}",
    monitor="val_RMSE",
    mode="min",
    save_top_k=3,
    save_last=True,
)

# CSV logs go to ./logs; the run name follows `dataset_name` instead of a hard-coded
# dataset so logs of different datasets are kept apart.
logger = CSVLogger("logs", name=f"M3GNet_centroids_{dataset_name}")

trainer = Trainer(
    max_epochs=30,
    # Use the GPU when CUDA is available, otherwise fall back to the CPU instead of
    # failing at Trainer construction on a machine without a GPU.
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    logger=logger,
    # Evaluation loops run without torch.inference_mode() (see the Lightning Trainer docs).
    inference_mode=False,
    callbacks=[checkpoint_callback],
)
trainer.fit(
    model=lit_module,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

In [None]:
# Show which checkpoint files the training run left in ./checkpoints.
checkpoint_files = os.listdir("checkpoints")
checkpoint_files

In [None]:
# Run inference on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Architecture must match the trained model so the checkpoint weights can be loaded.
inference_model = M3GNet(
    element_types=elem_list,
    is_intensive=False,
    readout_type="set2set",
)

# Take the best checkpoint recorded by the callback during training. A hard-coded file
# name embeds the epoch and metric value of one particular run and no longer exists
# after retraining. To evaluate a different checkpoint, assign its path here instead.
model_path = checkpoint_callback.best_model_path
if not model_path:
    raise RuntimeError(
        "No best checkpoint has been recorded yet; run the training cell first "
        "or set `model_path` to an existing .ckpt file."
    )

inference_lit_module = ModelLightningModule.load_from_checkpoint(
    model_path,
    model=inference_model,
    map_location=device,
)
inference_lit_module.to(device)

In [None]:
import numpy as np

# Predict the target for every test batch and collect predictions next to the
# reference values.
inference_lit_module.eval()
energy_pred = []
energy_true = []
sizes = []  # number of nodes of each batched graph; not used further in this notebook

# Gradients are not needed for evaluation, so the whole loop runs under no_grad.
with torch.no_grad():
    for g, lat, l_g, state_attr, e in tqdm(test_loader):
        e_pred = inference_lit_module(
            g=g.to(device),
            lat=lat.to(device),
            l_g=l_g.to(device),
            state_attr=state_attr.to(device),
        )
        sizes.append(g.num_nodes())
        # atleast_1d: a batch holding a single structure can yield a 0-d result,
        # which list.extend() cannot iterate over.
        energy_true.extend(np.atleast_1d(e.detach().cpu().numpy()))
        energy_pred.extend(np.atleast_1d(e_pred.detach().cpu().numpy()))


energy_true = np.array(energy_true)
energy_pred = np.array(energy_pred)

In [None]:
from litraj.metrics import get_metrics
# Compare the test-set predictions with the reference values (reference first).
metrics = get_metrics(energy_true, energy_pred)
metrics