In [1]:
from __future__                import annotations

import glob
import os
import shutil
import warnings
from os                        import path

import matplotlib.pyplot as plt
import numpy             as np
import pandas            as pd
import pytorch_lightning as pl
import matgl
from dgl.data.utils            import split_dataset
from mp_api.client             import MPRester
from pytorch_lightning.loggers import CSVLogger
from matgl.ext.pymatgen        import Structure2Graph, get_element_list
from matgl.graph.data          import M3GNetDataset, MGLDataLoader, collate_fn_efs
from matgl.models              import M3GNet
from matgl.utils.training      import PotentialLightningModule

import ML_library        as MLL

# To suppress warnings for clearer output
warnings.simplefilter('ignore')

In [2]:
# --- Input / output locations ---
data_train_path = 'm3gnet_dataset.xlsx'     # Excel export of the training data, read if it exists
model_load_path = 'M3GNet-MP-2021.2.8-PES'  # pre-trained model passed to matgl.load_model
model_save_path = 'finetuned_model'         # destination of the fine-tuned model

# --- Splitting ---
# Column level whose labels define the split groups
# (0: material, 1: charge state, 2: ionic step)
depth = 1

# Fractions of the group labels reserved for testing and for validation
test_ratio       = 0.2
validation_ratio = 0.2

# --- Training ---
# Number of epochs for fine-tuning
max_epochs = 30

# --- Figures ---
# Resolution used when saving plots
dpi = 50

# Load simulation data

In [3]:
# Load the simulation data into `source_m3gnet_dataset`: one column per ionic
# step (three column levels: material / defect / ionic step) and one row each
# for structure, energy, force and stress.
# Both branches bind the same name, so the cells below work either way.

if path.exists(data_train_path):
    # Reuse a previously exported copy of the data
    # NOTE(review): to_excel/read_excel store object cells (structures, force
    # and stress arrays) as text — confirm they are converted back to objects
    # before they reach Structure2Graph.
    source_m3gnet_dataset = pd.read_excel(data_train_path, index_col=0, header=[0, 1, 2])
else:
    # Path to dataset, structured as:
    # path_to_dataset
    #     material_i
    #         defect_i
    #             simulation_i (containing vasprun.xml)
    # NOTE: machine-specific relative path; adjust to your local layout.
    path_to_dataset = '../../../Desktop/defects/ssc'

    # Extract the data from the vasprun.xml files
    source_m3gnet_dataset = MLL.extract_vaspruns_dataset(path_to_dataset)
    # Optional export for later runs (disabled):
    #source_m3gnet_dataset.to_excel(data_train_path)

source_m3gnet_dataset


BiSBr
	vac_1_Bi_0
	vac_1_Bi_2
	vac_2_S_-2
	vac_2_S_1
	vac_3_Br_-1
	vac_3_Br_0

BiSeBr

BiSeI

BiSI
	as_1_Bi_on_S_-2
	as_1_Bi_on_S_0
	as_1_I_on_Bi_0
	as_1_S_on_Bi_0
	as_2_Bi_on_I_0
	as_2_I_on_S_0
	as_2_S_on_I_-1
	as_2_S_on_I_0
	as_2_S_on_I_4
	supercell
	vac_1_Bi_-1
	vac_1_Bi_-2
	vac_1_Bi_-3
	vac_1_Bi_0
	vac_1_Bi_1
	vac_1_Bi_2
	vac_1_Bi_3
	vac_2_S_-1
	vac_2_S_-2
	vac_2_S_0
	vac_2_S_1
	vac_2_S_2
	vac_3_I_-1
	vac_3_I_0
	vac_3_I_1


Unnamed: 0_level_0,BiSBr,BiSBr,BiSBr,BiSBr,BiSBr,BiSBr,BiSBr,BiSBr,BiSBr,BiSBr,...,BiSI,BiSI,BiSI,BiSI,BiSI,BiSI,BiSI,BiSI,BiSI,BiSI
Unnamed: 0_level_1,BiSBr_vac_1_Bi_0,BiSBr_vac_1_Bi_0,BiSBr_vac_1_Bi_0,BiSBr_vac_1_Bi_0,BiSBr_vac_1_Bi_0,BiSBr_vac_1_Bi_0,BiSBr_vac_1_Bi_0,BiSBr_vac_1_Bi_0,BiSBr_vac_1_Bi_0,BiSBr_vac_1_Bi_0,...,BiSI_vac_1_Bi_1,BiSI_vac_1_Bi_2,BiSI_vac_1_Bi_3,BiSI_vac_2_S_-1,BiSI_vac_2_S_1,BiSI_vac_3_I_-1,BiSI_vac_3_I_-1,BiSI_vac_3_I_-1,BiSI_vac_3_I_-1,BiSI_vac_3_I_0
Unnamed: 0_level_2,BiSBr_vac_1_Bi_0_0,BiSBr_vac_1_Bi_0_1,BiSBr_vac_1_Bi_0_2,BiSBr_vac_1_Bi_0_3,BiSBr_vac_1_Bi_0_4,BiSBr_vac_1_Bi_0_5,BiSBr_vac_1_Bi_0_6,BiSBr_vac_1_Bi_0_7,BiSBr_vac_1_Bi_0_8,BiSBr_vac_1_Bi_0_9,...,BiSI_vac_1_Bi_1_9,BiSI_vac_1_Bi_2_0,BiSI_vac_1_Bi_3_0,BiSI_vac_2_S_-1_0,BiSI_vac_2_S_1_0,BiSI_vac_3_I_-1_0,BiSI_vac_3_I_-1_1,BiSI_vac_3_I_-1_2,BiSI_vac_3_I_-1_3,BiSI_vac_3_I_0_0
structure,"[[ 1.02035922 11.92158884 3.76766581] Bi, [5....","[[ 1.02034623 11.92166811 3.76780143] Bi, [5....","[[ 1.02032769 11.92178153 3.76799504] Bi, [5....","[[ 1.02033337 11.92174653 3.7679352 ] Bi, [5....","[[ 1.02031243 11.92179732 3.76800429] Bi, [5....","[[ 1.02026312 11.92191691 3.76816706] Bi, [5....","[[ 1.02016438 11.92215592 3.76849259] Bi, [5....","[[ 1.02023412 11.92198709 3.76826278] Bi, [5....","[[ 1.02015479 11.92193252 3.76797232] Bi, [5....","[[ 1.0200249 11.92184279 3.76749595] Bi, [5....",...,"[[ 1.06517186 11.81356455 3.84817353] Bi, [5....","[[ 1.05028994 11.82619682 3.84784363] Bi, [5....","[[ 1.03702209 11.7900414 3.8066726 ] Bi, [5....","[[1.08868358 3.20838168 3.97978954] Bi, [ 1.06...","[[0.8136422 3.18635783 3.70505612] Bi, [ 1.05...","[[1.08842518 3.20807541 3.91183347] Bi, [ 1.08...","[[1.08842001 3.20799151 3.91179091] Bi, [ 1.08...","[[1.08840349 3.20795101 3.9117762 ] Bi, [ 1.08...","[[1.08829289 3.20766018 3.9116413 ] Bi, [ 1.08...","[[1.07846683 3.19807917 3.85369671] Bi, [ 1.10..."
energy,-353.154052,-353.154319,-353.15437,-353.154402,-353.154494,-353.154635,-353.154529,-353.154652,-353.154752,-353.154845,...,-348.707865,-352.757038,-356.425515,-339.554443,-349.834732,-342.527911,-342.527942,-342.527974,-342.528156,-347.593826
force,"[[-0.00067363, 0.00411333, 0.007025], [0.00901...","[[-0.0014098, 0.0028294, 0.00484769], [0.00768...","[[-0.00109682, 0.0004379, 0.00034179], [-0.000...","[[-0.00074512, 0.00120851, 0.00130262], [0.003...","[[-0.00085168, -0.00039352, -0.00323306], [-4....","[[-0.00114135, -0.00171104, -0.00523857], [-0....","[[-0.00119504, -0.00523211, -0.0165078], [-0.0...","[[-0.00137029, -0.00278464, -0.00913819], [-0....","[[-0.00091695, 0.00018231, -0.0051009], [-0.00...","[[-0.00146048, 0.00247828, -0.00577224], [0.00...",...,"[[-0.00031494, 0.00277768, 0.00226798], [-0.00...","[[-0.00153442, 0.00152999, 0.00179698], [0.005...","[[-0.00138497, 0.00501111, 0.00277365], [0.009...","[[-0.01036026, -0.00882522, -0.01443069], [-0....","[[0.00333455, 0.00310133, 0.00020627], [0.0006...","[[-0.00026855, -0.00434635, -0.00220466], [-0....","[[-0.00088233, -0.00217825, -0.00078951], [-0....","[[-0.00076344, -0.00201738, -0.00095335], [-0....","[[-0.00061493, -4.22e-06, -0.00026272], [-0.00...","[[-0.00664046, 0.00150838, 0.00379567], [0.002..."
stress,"[[2.219643409, -0.014113724000000001, 0.029521...","[[2.217231167, -0.016863247, 0.032264316], [-0...","[[2.221431893, -0.020900573000000002, 0.036586...","[[2.219480204, -0.019606924, 0.035106075], [-0...","[[2.2173241860000004, -0.019437214, 0.03511274...","[[2.2114738920000003, -0.019020978, 0.03521111...","[[2.202701121, -0.018337836, 0.035449988], [-0...","[[2.209423786, -0.019026292, 0.035465137], [-0...","[[2.2065778970000003, -0.018549379, 0.03504115...","[[2.2065881600000004, -0.017884964, 0.03453286...",...,"[[1.071568483, 0.001139871, 0.0023203560000000...","[[1.7619330670000002, 0.052789061000000005, 0....","[[2.5636920080000003, 0.046901413, -0.09893155...","[[-0.5415527489999999, -0.004684149, -0.005580...","[[0.7749622810000001, -1.4175e-05, 0.000193823...","[[-0.26752265000000003, 0.0012726500000000002,...","[[-0.26929120100000004, 0.0012495120000000001,...","[[-0.270723587, 0.001221951, 0.001847412], [0....","[[-0.27955009000000003, 0.00105737, 0.00181845...","[[0.25633715100000004, -0.027647141, 0.0029843..."


# Split data into train-validation-test sets

### Decide whether to split by material, by defect (charge) state, or by individual simulation

In [None]:
# Work on a copy so that the loaded source frame is never modified
m3gnet_dataset = source_m3gnet_dataset.copy()

# Strip the `depth` outermost column levels; afterwards column level 0
# carries the labels the split is made on
for _ in range(depth):
    m3gnet_dataset.columns = m3gnet_dataset.columns.droplevel(0)

### Splitting into train-validation-test sets

In [None]:
# Unique group labels at the outermost remaining column level. The split is
# made on these labels, so all columns sharing a label end up in the same set.
unique_labels = np.unique(m3gnet_dataset.columns.get_level_values(0))

# Shuffle with a seeded generator so the split is reproducible across runs
split_seed = 42
np.random.default_rng(split_seed).shuffle(unique_labels)

# Number of labels per set. Depending on `depth`, the labels are materials,
# defects/charge states or single simulations.
n_labels        = len(unique_labels)
test_size       = int(test_ratio       * n_labels)
validation_size = int(validation_ratio * n_labels)

# A requested non-zero ratio always gets at least one label: truncation can
# otherwise give 0 (e.g. int(0.2 * 4)), leaving an empty set that the
# dataset-building cell below has nothing to build from.
if test_ratio > 0:
    test_size = max(test_size, 1)
if validation_ratio > 0:
    validation_size = max(validation_size, 1)
if test_size + validation_size >= n_labels:
    raise ValueError(
        f'Cannot split {n_labels} labels into non-empty train/validation/test sets '
        f'(test_ratio={test_ratio}, validation_ratio={validation_ratio})'
    )

test_labels       = unique_labels[:test_size]
validation_labels = unique_labels[test_size:test_size + validation_size]
train_labels      = unique_labels[test_size + validation_size:]

# Select every column whose level-0 label belongs to each set
test_dataset       = m3gnet_dataset[test_labels]
validation_dataset = m3gnet_dataset[validation_labels]
train_dataset      = m3gnet_dataset[train_labels]

# Samples are columns (one per ionic step)
n_test       = test_dataset.shape[1]
n_validation = validation_dataset.shape[1]
n_train      = train_dataset.shape[1]

print(f'Using {n_train} samples to train, {n_validation} to evaluate, and {n_test} to test')

In [None]:
# Build one M3GNetDataset (graphs + energy/force/stress labels) per split
split_names    = ['train', 'val', 'test']
split_datasets = [train_dataset, validation_dataset, test_dataset]

# One element list shared by all splits. Structure2Graph encodes each atom by
# its position in `element_types`, so per-split lists could give the same
# element different node types in train/val/test. The model cell below also
# reads `element_types`, so it must cover every split, not just the last one.
all_structures = [
    structure
    for dataset in split_datasets
    for structure in dataset.loc['structure'].values.tolist()
]
element_types = get_element_list(all_structures)
converter     = Structure2Graph(element_types=element_types, cutoff=5.0)
# NOTE(review): the pre-trained model fine-tuned later may index node types
# against its own element list (e.g. model_pretrained.element_types) rather
# than this dataset-derived one — confirm which list the converter should use.

all_data = []
for name, dataset in zip(split_names, split_datasets):
    # Extract targets and structures for this split
    labels = {
        "energies": dataset.loc['energy'].values.tolist(),
        "forces":   dataset.loc['force'].values.tolist(),
        "stresses": dataset.loc['stress'].values.tolist(),
    }
    structures = dataset.loc['structure'].values.tolist()
    print(f'{name}: {len(structures)} structures')

    # Per-split file names keep the three datasets' files apart
    data = M3GNetDataset(
        filename=f'dgl_graph-{name}.bin',
        filename_line_graph=f'dgl_line_graph-{name}.bin',
        filename_state_attr=f'state_attr-{name}.pt',
        filename_labels=f'labels-{name}.json',
        threebody_cutoff=4.0,
        structures=structures,
        converter=converter,
        labels=labels,
        name=f'M3GNetDataset-{name}',
    )
    all_data.append(data)

train_data, val_data, test_data = all_data

In [None]:
# Batching settings shared by the three loaders
loader_kwargs = dict(
    collate_fn=collate_fn_efs,
    batch_size=2,
    num_workers=1,
)
train_loader, val_loader, test_loader = MGLDataLoader(
    train_data=train_data,
    val_data=val_data,
    test_data=test_data,
    **loader_kwargs,
)

# Freshly initialised M3GNet wrapped for Lightning (training from scratch).
# NOTE(review): in this notebook only `lit_module_finetune` is passed to
# trainer.fit; `model` / `lit_module` are built but never trained.
model = M3GNet(
    is_intensive=False,
    element_types=element_types,
)
lit_module = PotentialLightningModule(model=model)

# Fine-tune the pre-trained model

In [None]:
# Load the pre-trained M3GNet potential by name and take its underlying network
m3gnet_nnp       = matgl.load_model(model_load_path)
model_pretrained = m3gnet_nnp.model

# Lightning module for fine-tuning the pre-trained network with a small learning rate.
# NOTE(review): matgl's fine-tuning example also passes
# element_refs=m3gnet_nnp.element_refs.property_offset — verify whether the
# pre-trained per-element energy offsets should be kept here.
lit_module_finetune = PotentialLightningModule(lr=1e-4, model=model_pretrained)

In [None]:
# accelerator='auto' picks the available backend (GPU, MPS on M1 Macs, or CPU);
# pass accelerator='cpu' to force CPU training.
# Metrics go to logs/M3GNet_finetuning/version_<n>/metrics.csv.
logger = CSVLogger('logs', name='M3GNet_finetuning')

# inference_mode=False keeps autograd usable outside training steps —
# presumably required because forces/stresses are energy gradients (TODO confirm)
trainer = pl.Trainer(
    max_epochs=max_epochs,
    accelerator='auto',
    logger=logger,
    inference_mode=False,
)

trainer.fit(
    model=lit_module_finetune,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

In [None]:
# Save the fine-tuned potential.
# NOTE(review): PotentialLightningModule wraps a bare model in a matgl
# Potential (available as lit_module_finetune.model, sharing the weights of
# model_pretrained). Saving that wrapper rather than the bare M3GNet keeps the
# potential-level settings, so matgl.load_model(model_save_path) can restore a
# usable potential.
lit_module_finetune.model.save(model_save_path)

In [None]:
# Metrics the fitting trainer exposes to callbacks (last logged train/val values)
training_metrics = trainer.callback_metrics

# Evaluate the held-out test split, which fitting never touches.
# A separate trainer with logger=False keeps the test rows out of the
# metrics.csv analysed below. inference_mode=False mirrors the fitting trainer.
test_trainer = pl.Trainer(accelerator='auto', logger=False, inference_mode=False)
test_results = test_trainer.test(model=lit_module_finetune, dataloaders=test_loader)

# Trainer.test returns one dict per test dataloader; only one is used here
test_metrics = test_results[0] if test_results else {}

In [None]:
# Training run to analyze. Defaults to the run just logged by `logger`;
# CSVLogger creates a new version_<n> directory on every run, so a fixed 0
# would read a stale run after re-executing the notebook. Set an integer to
# inspect a specific earlier run instead.
current_version = logger.version

# Read the metrics CSVLogger wrote for that run
path_to_csv = f'logs/M3GNet_finetuning/version_{current_version}'
df = pd.read_csv(f'{path_to_csv}/metrics.csv')
df

In [None]:
# Validation and training metrics of an epoch presumably land on separate rows
# of metrics.csv, with NaN in the other row's columns. Replace NaN with zero
# and collapse to one row per epoch.
df = df.fillna(0)

if 'epoch' in df.columns:
    # Grouping on the logged epoch does not depend on the exact row layout, and
    # is idempotent: re-running the cell leaves the frame unchanged
    df = df.groupby('epoch', as_index=False).sum()
else:
    # Fallback: assume exactly two rows (validation + training) per epoch
    df = df.groupby(df.index // 2).sum()

df

In [None]:
# Training and validation metric columns (losses and MAEs)
loss_columns = [col for col in df.columns if col.startswith(('val_', 'train_'))]

# Explicit figure/axes instead of an unused (fig, ax) tuple
fig, ax = plt.subplots(figsize=(10, 6))

# Plot the natural log of each metric so curves of very different scale fit together
for loss_column in loss_columns:
    ax.plot(df.index, np.log(df[loss_column]), label=loss_column)

ax.set_xlabel('Epoch')
ax.set_ylabel('ln(metric value)')
ax.set_title('M3GNet fine-tuning metrics')
ax.legend(loc=(1.01, 0))
fig.savefig('m3gnet_loss.eps', dpi=dpi, bbox_inches='tight')
plt.show()

In [None]:
tuple(df[metric].iloc[-1] for metric in ('val_Energy_MAE', 'val_Force_MAE', 'val_Stress_MAE'))  # last-epoch validation MAEs

# Clean up the notebook

In [None]:
# Remove the temporary files created in the working directory under the
# file names given to M3GNetDataset (graphs, line graphs, state attributes, labels)
patterns = ['dgl_graph*.bin', 'dgl_line_graph*.bin', 'state_attr*.pt', 'labels*.json']
for pattern in patterns:
    for file in glob.glob(pattern):
        try:
            os.remove(file)
        except FileNotFoundError:
            # Already gone; nothing left to clean
            pass

# Optional, heavier cleanup of training logs and saved models (disabled).
# NOTE(review): 'trained_model' is not written by any visible cell.
#shutil.rmtree('logs')
#shutil.rmtree('trained_model')
#shutil.rmtree('finetuned_model')