In [3]:
import numpy             as np
import pandas            as pd
import pytorch_lightning as pl
import ML_library        as MLL
import matplotlib.pyplot as plt
import matgl
import os
import warnings
import glob
import torch

from __future__                import annotations
from pytorch_lightning.loggers import CSVLogger
from matgl.ext.pymatgen        import Structure2Graph, get_element_list
from matgl.graph.data          import M3GNetDataset, MGLDataLoader, collate_fn_efs
from matgl.utils.training      import PotentialLightningModule

# To suppress warnings for clearer output
warnings.simplefilter('ignore')

# Check if GPU is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
# Display the torch device selected in the setup cell
device

device(type='cpu')

In [5]:
# ---------------------------------------------------------------------------
# Paths
# ---------------------------------------------------------------------------
data_train_path = 'm3gnet_dataset.xlsx'     # spreadsheet holding the training dataset
model_load_path = 'M3GNet-MP-2021.2.8-PES'  # pre-trained model that gets loaded
model_save_path = 'finetuned_model'         # where the re-trained model is saved

# ---------------------------------------------------------------------------
# Dataset options
# ---------------------------------------------------------------------------
# Whether to include charge or not
charged = True

# Number of outer column levels to drop before splitting
# 0: material, 1: charge state, 2: ionic step
depth = 0

# Ratios for dividing the data into test and validation sets
# (the remainder is used for training)
test_ratio       = 0.2
validation_ratio = 0.2

# ---------------------------------------------------------------------------
# Training options
# ---------------------------------------------------------------------------
# Batch size of the data loaders
batch_size = 128

# Weight of the stress term during training
stress_weight = 0

# Number of epochs for re-training
max_epochs = 10

# Learning rate for re-training
lr = 1e-4

# ---------------------------------------------------------------------------
# Analysis options
# ---------------------------------------------------------------------------
# Resolution used when saving figures
dpi = 100

# Version of training you specifically want to analyze
current_version = 1

# Load simulation data

In [4]:
# Load the simulation data: read the spreadsheet at `data_train_path` when it
# exists, otherwise extract the data from the raw simulation folders.
# Each folder names a new column, and structure, energy, forces and stresses
# of each ionic step are loaded

if os.path.exists(data_train_path):
    # Load data for model training.
    # Both branches bind `source_m3gnet_dataset`, the name read at the end of
    # this cell and by the following cells.
    # NOTE(review): header=[0,1,2] expects three column levels; confirm that it
    # matches the layout written by whichever extractor produced the file.
    source_m3gnet_dataset = pd.read_excel(data_train_path, index_col=0, header=[0,1,2])
else:
    # Path to dataset, structured as:
    # path_to_dataset
    #     material_i
    #         defect_i
    #             simulation_i (containing vasprun.xml)
    #path_to_dataset = '../../../Desktop/defects/gamma'
    path_to_dataset = '../../../Desktop/CeO2-data'

    # Extract the data
    #source_m3gnet_dataset = MLL.extract_vaspruns_dataset(path_to_dataset, charged=charged)
    source_m3gnet_dataset = MLL.extract_OUTCAR_dataset(path_to_dataset)
    #source_m3gnet_dataset.to_excel(data_train_path)

source_m3gnet_dataset


data-244

data-620

data-412

data-415

data-627

data-243

data-618

data-288

data-275

data-423

data-611

data-281

data-286

data-616

data-424

data-272

data-629

data-689

data-219

data-484

data-470

data-642

data-226

data-448

data-221

data-645

data-477

data-483

data-228

data-673

data-441

data-217

data-687

data-479

data-680

data-210

data-446

data-674

data-425

data-617

data-273

data-287

data-628

data-280

data-274

data-610

data-422

data-626

data-414

data-242

data-289

data-619

data-245

data-413

data-621

data-478

data-211

data-675

data-447

data-681

data-229

data-686

data-440

data-672

data-216

data-449

data-482

data-220

data-476

data-644

data-218

data-688

data-643

data-471

data-227

data-485

data-193

data-167

data-355

data-731

data-503

data-31

data-158

data-504

data-36

data-352

data-160

data-194

data-709

data-399

data-364

data-156

data-532

data-700

data-390

data-169

data-397

data-707

data-535

data-151

d

Unnamed: 0_level_0,data-244,data-620,data-620,data-620,data-620,data-620,data-620,data-412,data-412,data-412,...,data-469,data-469,data-469,data-469,data-469,data-469,data-469,data-469,data-469,data-469
Unnamed: 0_level_1,0,0,1,2,3,4,5,0,1,2,...,5,6,7,8,9,10,11,12,13,14
structure,[[3.33615302e+01 5.84467677e+00 2.59605999e-03...,"[[0. 0. 0.] Gd, [ 0. 0. 58.7528] Gd, ...","[[0. 0.05303286 0. ] Gd, [ 0. ...","[[0. 0.09083842 0. ] Gd, [ 0. ...","[[0. 0.10042071 0. ] Gd, [ 0. ...","[[0. 0.10310386 0. ] Gd, [ 0. ...","[[0. 0.10636262 0. ] Gd, [ 0. ...","[[0. 0.01605117 0.01605117] Gd, [58.81...","[[-0.14904472 0.18096735 -0.17713168] Gd, [58...","[[-0.04738322 0.07378836 -0.03270408] Gd, [59...",...,"[[0.07713968 0.46225566 0.46176667] Gd, [88.21...","[[0.08250283 0.46694071 0.46628983] Gd, [88.20...","[[0.08576356 0.47023434 0.47007156] Gd, [88.21...","[[0.08950012 0.47313377 0.47313371] Gd, [88.20...","[[0.09308381 0.47608524 0.47608518] Gd, [88.21...","[[0.09665897 0.47882985 0.47882978] Gd, [88.20...","[[0.09926665 0.48080404 0.48080397] Gd, [88.21...","[[0.1021922 0.48274136 0.48290401] Gd, [88.21...","[[0.10479837 0.48470722 0.48486988] Gd, [88.21...","[[0.10756158 0.48664665 0.48664659] Gd, [88.21..."
energy,-412.577999,-826.155132,-826.405749,-826.414187,-826.414585,-826.415257,-826.415375,-830.225833,-831.447328,-832.020734,...,-2832.759931,-2832.761295,-2832.762525,-2832.763084,-2832.763854,-2832.76423,-2832.764701,-2832.76491,-2832.765254,-2832.765394
force,"[[1.7e-05, -3.3e-05, 0.003824], [1.3e-05, -7.4...","[[0.0, 0.10137, -0.0], [0.0, 0.113333, -0.0], ...","[[-0.0, 0.075332, -0.0], [-0.0, 0.052073, -0.0...","[[0.0, 0.018956, -0.0], [0.0, 0.02052, -0.0], ...","[[0.0, 0.008732, 0.0], [0.0, 0.00968, 0.0], [0...","[[-0.0, 0.006542, 0.0], [-0.0, 0.006302, 0.0],...","[[0.0, 0.003934, 0.0], [0.0, 0.00258, 0.0], [0...","[[-0.285926, 0.316654, -0.370853], [0.222586, ...","[[0.33792, -0.357173, 0.479929], [-0.115185, -...","[[-0.048002, 0.058973, 0.051325], [0.096117, 0...",...,"[[0.007117, 0.006306, 0.006293], [0.001023, 0....","[[0.00509, 0.004901, 0.005553], [0.000624, 0.0...","[[0.00619, 0.005033, 0.005036], [-9.4e-05, 0.0...","[[0.004955, 0.003974, 0.004099], [0.000324, 0....","[[0.005073, 0.003932, 0.00398], [4.2e-05, 0.00...","[[0.004202, 0.003046, 0.003086], [0.000261, 0....","[[0.004312, 0.003148, 0.003168], [0.000142, 0....","[[0.003697, 0.002579, 0.002596], [0.000236, 0....","[[0.00362, 0.002505, 0.002514], [0.000233, 0.0...","[[0.003025, 0.001997, 0.002004], [0.000303, 0...."


In [5]:
# Number of rows in the source DataFrame
len(source_m3gnet_dataset)

3

# Split data into train-validation-test sets

### Decide whether to split by material, defect state, or individual simulation

In [6]:
# Work on a copy so that the source DataFrame is left untouched
m3gnet_dataset = source_m3gnet_dataset.copy()

# Strip the outermost column level once per unit of `depth`
# (nothing is removed when depth is 0)
for _level in range(depth):
    m3gnet_dataset.columns = m3gnet_dataset.columns.droplevel(0)

In [7]:
# Display the working copy of the dataset
m3gnet_dataset

Unnamed: 0_level_0,data-244,data-620,data-620,data-620,data-620,data-620,data-620,data-412,data-412,data-412,...,data-469,data-469,data-469,data-469,data-469,data-469,data-469,data-469,data-469,data-469
Unnamed: 0_level_1,0,0,1,2,3,4,5,0,1,2,...,5,6,7,8,9,10,11,12,13,14
structure,[[3.33615302e+01 5.84467677e+00 2.59605999e-03...,"[[0. 0. 0.] Gd, [ 0. 0. 58.7528] Gd, ...","[[0. 0.05303286 0. ] Gd, [ 0. ...","[[0. 0.09083842 0. ] Gd, [ 0. ...","[[0. 0.10042071 0. ] Gd, [ 0. ...","[[0. 0.10310386 0. ] Gd, [ 0. ...","[[0. 0.10636262 0. ] Gd, [ 0. ...","[[0. 0.01605117 0.01605117] Gd, [58.81...","[[-0.14904472 0.18096735 -0.17713168] Gd, [58...","[[-0.04738322 0.07378836 -0.03270408] Gd, [59...",...,"[[0.07713968 0.46225566 0.46176667] Gd, [88.21...","[[0.08250283 0.46694071 0.46628983] Gd, [88.20...","[[0.08576356 0.47023434 0.47007156] Gd, [88.21...","[[0.08950012 0.47313377 0.47313371] Gd, [88.20...","[[0.09308381 0.47608524 0.47608518] Gd, [88.21...","[[0.09665897 0.47882985 0.47882978] Gd, [88.20...","[[0.09926665 0.48080404 0.48080397] Gd, [88.21...","[[0.1021922 0.48274136 0.48290401] Gd, [88.21...","[[0.10479837 0.48470722 0.48486988] Gd, [88.21...","[[0.10756158 0.48664665 0.48664659] Gd, [88.21..."
energy,-412.577999,-826.155132,-826.405749,-826.414187,-826.414585,-826.415257,-826.415375,-830.225833,-831.447328,-832.020734,...,-2832.759931,-2832.761295,-2832.762525,-2832.763084,-2832.763854,-2832.76423,-2832.764701,-2832.76491,-2832.765254,-2832.765394
force,"[[1.7e-05, -3.3e-05, 0.003824], [1.3e-05, -7.4...","[[0.0, 0.10137, -0.0], [0.0, 0.113333, -0.0], ...","[[-0.0, 0.075332, -0.0], [-0.0, 0.052073, -0.0...","[[0.0, 0.018956, -0.0], [0.0, 0.02052, -0.0], ...","[[0.0, 0.008732, 0.0], [0.0, 0.00968, 0.0], [0...","[[-0.0, 0.006542, 0.0], [-0.0, 0.006302, 0.0],...","[[0.0, 0.003934, 0.0], [0.0, 0.00258, 0.0], [0...","[[-0.285926, 0.316654, -0.370853], [0.222586, ...","[[0.33792, -0.357173, 0.479929], [-0.115185, -...","[[-0.048002, 0.058973, 0.051325], [0.096117, 0...",...,"[[0.007117, 0.006306, 0.006293], [0.001023, 0....","[[0.00509, 0.004901, 0.005553], [0.000624, 0.0...","[[0.00619, 0.005033, 0.005036], [-9.4e-05, 0.0...","[[0.004955, 0.003974, 0.004099], [0.000324, 0....","[[0.005073, 0.003932, 0.00398], [4.2e-05, 0.00...","[[0.004202, 0.003046, 0.003086], [0.000261, 0....","[[0.004312, 0.003148, 0.003168], [0.000142, 0....","[[0.003697, 0.002579, 0.002596], [0.000236, 0....","[[0.00362, 0.002505, 0.002514], [0.000233, 0.0...","[[0.003025, 0.001997, 0.002004], [0.000303, 0...."


### Splitting into train-validation-test sets

In [23]:
# Reuse a previously saved train/validation/test split when all three label
# files exist; otherwise draw a new random split and save it.

path_to_test_labels       = 'test_labels.txt'
path_to_validation_labels = 'validation_labels.txt'
path_to_train_labels      = 'train_labels.txt'

label_paths = [path_to_test_labels, path_to_validation_labels, path_to_train_labels]

if all(os.path.exists(label_path) for label_path in label_paths):
    # Read labels splitting (which are strings).
    # np.atleast_1d covers a file holding a single label: genfromtxt then gives
    # a 0-d array whose .tolist() is a bare string rather than a list of labels.
    test_labels       = np.atleast_1d(np.genfromtxt(path_to_test_labels,       dtype='str')).tolist()
    validation_labels = np.atleast_1d(np.genfromtxt(path_to_validation_labels, dtype='str')).tolist()
    train_labels      = np.atleast_1d(np.genfromtxt(path_to_train_labels,      dtype='str')).tolist()
else:
    # Define unique labels, wrt the outer column
    unique_labels = np.unique(m3gnet_dataset.columns.get_level_values(0))

    # Shuffle the list of unique labels.
    # NOTE(review): no random seed is set here; the split is only reproducible
    # through the label files saved below.
    np.random.shuffle(unique_labels)

    # Define the sizes of every set
    # Corresponds to the size wrt the number of unique labels in the dataset
    test_size       = int(test_ratio       * len(unique_labels))
    validation_size = int(validation_ratio * len(unique_labels))

    # Plain lists, the same type as produced when the labels are read from file
    test_labels       = unique_labels[:test_size].tolist()
    validation_labels = unique_labels[test_size:test_size+validation_size].tolist()
    train_labels      = unique_labels[test_size+validation_size:].tolist()

    # Save this splitting for transfer-learning approaches
    np.savetxt(path_to_test_labels,       test_labels,       fmt='%s')
    np.savetxt(path_to_validation_labels, validation_labels, fmt='%s')
    np.savetxt(path_to_train_labels,      train_labels,      fmt='%s')

# Use the loaded/computed labels to generate split datasets
test_dataset       = m3gnet_dataset[test_labels]
validation_dataset = m3gnet_dataset[validation_labels]
train_dataset      = m3gnet_dataset[train_labels]

# Number of samples (columns) in every set
n_test       = np.shape(test_dataset)[1]
n_validation = np.shape(validation_dataset)[1]
n_train      = np.shape(train_dataset)[1]

print(f'Using {n_train} samples to train, {n_validation} to evaluate, and {n_test} to test')

Using 9682 samples to train, 2402 to evaluate, and 2720 to test


### Convert into graph database

In [6]:
# Build one M3GNetDataset per split (train / val / test).

split_datasets = {
    'train': train_dataset,
    'val':   validation_dataset,
    'test':  test_dataset,
}

# Derive the element list once, from the structures of every split together,
# so that all three datasets are converted with the same element list even
# when a split happens to lack one of the elements.
# NOTE(review): the pre-trained model presumably carries its own element list;
# confirm that the list derived here is compatible with it.
all_structures = [
    structure
    for dataset in split_datasets.values()
    for structure in dataset.loc['structure'].values.tolist()
]
element_types = get_element_list(all_structures)
converter     = Structure2Graph(element_types=element_types, cutoff=5.0)

all_data = []
for name, dataset in split_datasets.items():
    # Extract data from dataset
    structures = dataset.loc['structure'].values.tolist()

    # Forces as nested plain lists, built in one pass without writing back into
    # `dataset`. np.asarray accepts both arrays and lists, so the cell can be
    # re-run safely.
    forces = [np.asarray(force).tolist() for force in dataset.loc['force'].values]

    # Define data labels from dataset
    if stress_weight == 0:
        # Stresses do not contribute to training: use zero placeholders
        stresses = [np.zeros((3, 3)).tolist() for s in structures]
    else:
        stresses = dataset.loc['stress'].values.tolist()

    labels = {
        'energies': dataset.loc['energy'].values.tolist(),
        'forces':   forces,
        'stresses': stresses,
    }

    # Generate dataset
    data = M3GNetDataset(
        filename=f'dgl_graph-{name}.bin',
        filename_line_graph=f'dgl_line_graph-{name}.bin',
        filename_state_attr=f'state_attr-{name}.pt',
        filename_labels=f'labels-{name}.json',
        threebody_cutoff=4.0,
        structures=structures,
        converter=converter,
        labels=labels,
        name=f'M3GNetDataset-{name}',
    )
    all_data.append(data)

train_data, val_data, test_data = all_data

NameError: name 'train_dataset' is not defined

In [None]:
# Options shared by the three data loaders
loader_options = dict(
    collate_fn=collate_fn_efs,
    batch_size=batch_size,
    num_workers=1,
    pin_memory=True,
)

# Wrap the train / validation / test datasets into data loaders
train_loader, val_loader, test_loader = MGLDataLoader(
    train_data=train_data,
    val_data=val_data,
    test_data=test_data,
    **loader_options,
)

# Retrain model

In [26]:
# Load the pre-trained potential named by `model_load_path` and take its `.model`
m3gnet_nnp       = matgl.load_model(model_load_path)
model_pretrained = m3gnet_nnp.model

# Lightning module used for fine-tuning: MSE loss, learning rate `lr`, and the
# stress term weighted by `stress_weight`
lit_module_finetune = PotentialLightningModule(
    model=model_pretrained,
    stress_weight=stress_weight,
    loss='mse_loss',
    lr=lr,
)

In [2]:
# Log the training metrics as CSV under logs/M3GNet_finetuning
logger = CSVLogger('logs', name='M3GNet_finetuning')

# accelerator='auto' selects the appropriate Accelerator; use the
# accelerator='cpu' kwarg to disable GPU or MPS (M1 mac) training.
trainer = pl.Trainer(
    max_epochs=max_epochs,
    accelerator='auto',
    logger=logger,
    inference_mode=False,
)

# Fine-tune on the training loader, validating on the validation loader
trainer.fit(
    model=lit_module_finetune,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

# Save trained model
model_pretrained.save(model_save_path)

NameError: name 'CSVLogger' is not defined

# Analyze metrics

In [1]:
# Evaluate the fine-tuned module on the test loader
# E_MAE = meV/atom, F_MAE = eV/A, S_MAE = GPa
trainer.test(model=lit_module_finetune, dataloaders=test_loader)

NameError: name 'trainer' is not defined

In [None]:
# Folder of the logged run selected by `current_version`
path_to_csv = f'logs/M3GNet_finetuning/version_{current_version}'

# Load its metrics file and preview the first rows
metrics_file = f'{path_to_csv}/metrics.csv'
df = pd.read_csv(metrics_file)
df.head()

In [None]:
# Group key that pairs up every two consecutive rows (0,1 -> 0; 2,3 -> 1; ...)
row_pairs = df.index // 2

# Replace NaN by zero, then add up the rows of each pair.
# Note: `df` is overwritten, so re-running this cell merges the rows again.
df = df.fillna(0).groupby(row_pairs).sum()
df.head()

In [None]:
# Plot the logarithm of every logged train/validation metric and save the figure.

# Get the list of loss column names (those starting with 'val_' or 'train_')
loss_columns = [col for col in df.columns if col.startswith(('val_', 'train_'))]

# plt.subplots returns a (figure, axes) pair: unpack both
fig, ax = plt.subplots(figsize=(10, 6))

# Plot each loss
for loss_column in loss_columns:
    # Mask non-positive entries (such as zeros that replaced NaN) so that the
    # logarithm is only taken where it is defined; masked points are not drawn
    positive_values = df[loss_column].where(df[loss_column] > 0)
    ax.plot(df.index, np.log(positive_values), label=loss_column)

ax.set_xlabel('Epoch')
# The plotted quantity is the natural logarithm of each column
ax.set_ylabel('log(Loss)')
ax.legend(loc=(1.01, 0))
fig.savefig('m3gnet_loss.eps', dpi=dpi, bbox_inches='tight')
plt.show()

In [45]:
# Validation MAEs (energy, force, stress) taken from the second-to-last row
tuple(df[column].iloc[-2] for column in ('val_Energy_MAE', 'val_Force_MAE', 'val_Stress_MAE'))

(0.0135606033727526, 0.0874462649226188, 0.0)

In [46]:
# Validation MAEs (energy, force, stress) taken from the last row
tuple(df[column].iloc[-1] for column in ('val_Energy_MAE', 'val_Force_MAE', 'val_Stress_MAE'))

(0.0, 0.0, 0.0)

# Clean up the notebook

In [14]:
# Remove the temporary files left in the working directory by this notebook

patterns = ['dgl_graph*.bin', 'dgl_line_graph*.bin', 'state_attr*.pt', 'labels*.json', 'labels*.txt']
for pattern in patterns:
    for leftover in glob.glob(pattern):
        try:
            os.remove(leftover)
        except FileNotFoundError:
            # The file is already gone: nothing to do
            continue

# Optional removal of whole folders (needs `import shutil`, which the imports
# cell does not provide)
#shutil.rmtree('logs')
#shutil.rmtree('trained_model')
#shutil.rmtree('finetuned_model')