In [1]:
from __future__ import annotations

import glob
import os
import warnings

import matgl
import matplotlib.pyplot as plt
import numpy             as np
import pandas            as pd
import pytorch_lightning as pl

from matgl.ext.pymatgen        import Structure2Graph, get_element_list
from matgl.graph.data          import M3GNetDataset, MGLDataLoader, collate_fn_efs
from matgl.utils.training      import PotentialLightningModule
from pytorch_lightning.loggers import CSVLogger

import ML_library as MLL

# To suppress warnings for clearer output.
# NOTE(review): the 'ignore' filter is installed for every warning category and for
# the rest of the session, so numerical RuntimeWarnings are hidden as well.
warnings.simplefilter('ignore')

In [2]:
# --- Paths ---------------------------------------------------------------
# Spreadsheet with the training table, model to start from, and output folder
data_train_path, model_load_path, model_save_path = (
    'm3gnet_dataset.xlsx',
    'M3GNet-MP-2021.2.8-PES',
    'finetuned_model',
)

# --- Dataset options -----------------------------------------------------
# Whether to include charge or not
charged = True

# Column level used for splitting -- 0: material, 1: charge state, 2: ionic step
depth = 1

# Fractions of the data reserved for the test and validation sets
test_ratio = validation_ratio = 0.2

# --- Training options ----------------------------------------------------
# Weight of the stress term during training
stress_weight = 0

# Number of epochs for re-training
max_epochs = 10

# Learning rate for re-training
lr = 1e-4

# --- Analysis options ----------------------------------------------------
# Resolution used when saving figures
dpi = 100

# Version of training you specifically want to analyze
current_version = 1

# Load simulation data

In [3]:
# Build the table of reference calculations: each folder names a new column, and the
# structure, energy, forces and stresses of each ionic step are loaded.
# Both branches bind the result to `source_m3gnet_dataset`, the name this cell
# displays and the following cells read.

if os.path.exists(data_train_path):
    # Re-use the cached spreadsheet (three header rows form the column MultiIndex).
    # NOTE(review): values read back from Excel are plain strings/numbers, so the
    # structure and force rows may need re-parsing before training -- confirm.
    source_m3gnet_dataset = pd.read_excel(data_train_path, index_col=0, header=[0,1,2])
else:
    # Path to dataset, structured as:
    # path_to_dataset
    #     material_i
    #         defect_i
    #             simulation_i (containing vasprun.xml)
    path_to_dataset = '../../../Desktop/CeO2-data'

    # Extract the data
    source_m3gnet_dataset = MLL.extract_OUTCAR_dataset(path_to_dataset)
    # Writing the spreadsheet cache is left disabled:
    #source_m3gnet_dataset.to_excel(data_train_path)

source_m3gnet_dataset


data-265
36
36

data-266
36
36

data-82
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41
41

data-112
52
52
52
52
52
52
52
52

data-416
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95

data-197
48
48

data-306
20
20

data-704

data-646
96
96
96
96
96

data-474
95
95

data-514
96
96
96
96
96
96
96

data-109
49
49
49
49
49
49
49
49
49
49
49
49
49
49
49
49
49
49
49
49
49
49
49
49
49
49
49
49
49

data-614
96
96
96
9

Unnamed: 0_level_0,data-265,data-266,data-82,data-82,data-82,data-82,data-82,data-82,data-82,data-82,...,data-181,data-181,data-181,data-181,data-181,data-181,data-181,data-181,data-181,data-677
Unnamed: 0_level_1,0,0,0,1,2,3,4,5,6,7,...,27,28,29,30,31,32,33,34,35,0
structure,"[[ 8.4260887 -15.00883609 25.99606715] Ce, ...","[[ 8.66869447 -14.69118861 25.44588559] Ce, ...","[[31.07260901 4.606796 1.52830052] Ce, [-1...","[[31.07383634 4.60613169 1.53167539] Ce, [-1...","[[31.07441165 4.6059324 1.53375224] Ce, [-1...","[[31.07590745 4.60546738 1.53712712] Ce, [-1...","[[31.07832375 4.6048695 1.54569412] Ce, [-1...","[[31.07529379 4.6057331 1.53738673] Ce, [-1...","[[31.07851552 4.60493593 1.54647294] Ce, [-1...","[[31.07751832 4.60520165 1.54413648] Ce, [-1...",...,"[[ 6.55608587 -19.29339764 33.41714539] Ce, ...","[[ 6.55654985 -19.29331556 33.41700323] Ce, ...","[[ 6.55729129 -19.29286414 33.41622134] Ce, ...","[[ 6.557526 -19.29298725 33.41643458] Ce, ...","[[ 6.55783477 -19.29286414 33.41622134] Ce, ...","[[ 6.55760966 -19.29306933 33.41657674] Ce, ...","[[ 6.55807816 -19.29294621 33.4163635 ] Ce, ...","[[ 6.55775057 -19.29323348 33.41686107] Ce, ...","[[ 6.55704278 -19.29417738 33.41849594] Ce, ...","[[129.42089212 100.66518652 100.37830195] Gd, ..."
energy,-310.40562,-309.712666,-338.222808,-338.22325,-338.223417,-338.223867,-338.224763,-338.22374,-338.224933,-338.224502,...,-304.002402,-304.002651,-304.002944,-304.003208,-304.003321,-304.00343,-304.0035,-304.003673,-304.003994,-804.483765
force,"[[-0.000923, -0.0, 0.0], [-0.006576, -0.0, 0.0...","[[-0.000695, -0.0, 0.0], [-0.006015, -0.0, 0.0...","[[0.002624, 0.002335, -0.001922], [0.001388, 0...","[[0.002842, 0.001887, -0.001008], [0.000174, 0...","[[0.001847, 0.002187, -0.000964], [0.000932, 0...","[[0.002877, 0.001515, -0.000186], [0.000588, 0...","[[0.002127, 0.002558, -0.000899], [0.00135, 0....","[[0.002158, 0.003015, -0.001437], [0.001113, 0...","[[0.001672, 0.002496, -0.001131], [0.001383, 0...","[[0.00168, 0.003124, -0.001362], [0.001648, 0....",...,"[[-0.000692, -0.0, -0.000421], [0.000465, 0.0,...","[[0.002161, 0.0, -0.00149], [-0.002283, -0.0, ...","[[-0.001264, -0.0, 0.00057], [0.000767, 0.0, 0...","[[0.002653, 0.0, -0.000657], [-0.00324, -0.0, ...","[[-0.001915, -0.0, 0.001426], [0.000841, 0.0, ...","[[0.0017, 0.0, -0.000906], [-0.002345, -0.0, -...","[[-0.000504, -0.0, 0.000679], [-0.000993, -0.0...","[[0.001228, 0.0, 0.000344], [-0.001735, -0.0, ...","[[0.00175, 0.0, 2e-06], [-0.00158, -0.0, -0.00...","[[-0.00302, 0.000127, 0.000921], [0.001227, 0...."


In [4]:
# Number of rows in the loaded table
len(source_m3gnet_dataset)

3

# Split data into train-validation-test sets

### Decide whether to split by material, defect state, or individual simulation

In [18]:
# Work on a copy so the table loaded from disk stays untouched
m3gnet_dataset = source_m3gnet_dataset.copy()

# Strip the outermost column level `depth` times, then install the shortened header
trimmed_columns = m3gnet_dataset.columns
for _ in range(depth):
    trimmed_columns = trimmed_columns.droplevel(0)
m3gnet_dataset.columns = trimmed_columns

In [20]:
# Display the working copy of the dataset
m3gnet_dataset

Unnamed: 0_level_0,data-265,data-266,data-82,data-82,data-82,data-82,data-82,data-82,data-82,data-82,...,data-181,data-181,data-181,data-181,data-181,data-181,data-181,data-181,data-181,data-677
Unnamed: 0_level_1,0,0,0,1,2,3,4,5,6,7,...,27,28,29,30,31,32,33,34,35,0
structure,"[[ 8.4260887 -15.00883609 25.99606715] Ce, ...","[[ 8.66869447 -14.69118861 25.44588559] Ce, ...","[[31.07260901 4.606796 1.52830052] Ce, [-1...","[[31.07383634 4.60613169 1.53167539] Ce, [-1...","[[31.07441165 4.6059324 1.53375224] Ce, [-1...","[[31.07590745 4.60546738 1.53712712] Ce, [-1...","[[31.07832375 4.6048695 1.54569412] Ce, [-1...","[[31.07529379 4.6057331 1.53738673] Ce, [-1...","[[31.07851552 4.60493593 1.54647294] Ce, [-1...","[[31.07751832 4.60520165 1.54413648] Ce, [-1...",...,"[[ 6.55608587 -19.29339764 33.41714539] Ce, ...","[[ 6.55654985 -19.29331556 33.41700323] Ce, ...","[[ 6.55729129 -19.29286414 33.41622134] Ce, ...","[[ 6.557526 -19.29298725 33.41643458] Ce, ...","[[ 6.55783477 -19.29286414 33.41622134] Ce, ...","[[ 6.55760966 -19.29306933 33.41657674] Ce, ...","[[ 6.55807816 -19.29294621 33.4163635 ] Ce, ...","[[ 6.55775057 -19.29323348 33.41686107] Ce, ...","[[ 6.55704278 -19.29417738 33.41849594] Ce, ...","[[129.42089212 100.66518652 100.37830195] Gd, ..."
energy,-310.40562,-309.712666,-338.222808,-338.22325,-338.223417,-338.223867,-338.224763,-338.22374,-338.224933,-338.224502,...,-304.002402,-304.002651,-304.002944,-304.003208,-304.003321,-304.00343,-304.0035,-304.003673,-304.003994,-804.483765
force,"[[-0.000923, -0.0, 0.0], [-0.006576, -0.0, 0.0...","[[-0.000695, -0.0, 0.0], [-0.006015, -0.0, 0.0...","[[0.002624, 0.002335, -0.001922], [0.001388, 0...","[[0.002842, 0.001887, -0.001008], [0.000174, 0...","[[0.001847, 0.002187, -0.000964], [0.000932, 0...","[[0.002877, 0.001515, -0.000186], [0.000588, 0...","[[0.002127, 0.002558, -0.000899], [0.00135, 0....","[[0.002158, 0.003015, -0.001437], [0.001113, 0...","[[0.001672, 0.002496, -0.001131], [0.001383, 0...","[[0.00168, 0.003124, -0.001362], [0.001648, 0....",...,"[[-0.000692, -0.0, -0.000421], [0.000465, 0.0,...","[[0.002161, 0.0, -0.00149], [-0.002283, -0.0, ...","[[-0.001264, -0.0, 0.00057], [0.000767, 0.0, 0...","[[0.002653, 0.0, -0.000657], [-0.00324, -0.0, ...","[[-0.001915, -0.0, 0.001426], [0.000841, 0.0, ...","[[0.0017, 0.0, -0.000906], [-0.002345, -0.0, -...","[[-0.000504, -0.0, 0.000679], [-0.000993, -0.0...","[[0.001228, 0.0, 0.000344], [-0.001735, -0.0, ...","[[0.00175, 0.0, 2e-06], [-0.00158, -0.0, -0.00...","[[-0.00302, 0.000127, 0.000921], [0.001227, 0...."


### Splitting into train-validation-test sets

In [7]:
# Re-use a previously saved train/validation/test split when all three label files
# exist; otherwise draw a new random split and save it for later runs.

path_to_test_labels       = 'test_labels.txt'
path_to_validation_labels = 'validation_labels.txt'
path_to_train_labels      = 'train_labels.txt'

label_paths = [path_to_test_labels, path_to_validation_labels, path_to_train_labels]

if all(os.path.exists(path) for path in label_paths):
    # Read labels splitting (which are strings).
    # np.genfromtxt returns a 0-d array for a one-line file, whose .tolist() is a bare
    # string; np.atleast_1d keeps every split a list so the selection below behaves
    # the same for one label as for many.
    test_labels, validation_labels, train_labels = [
        np.atleast_1d(np.genfromtxt(path, dtype='str')).tolist()
        for path in label_paths
    ]
else:
    # Define unique labels, wrt the outer column
    unique_labels = np.unique(m3gnet_dataset.columns.get_level_values(0))

    # Shuffle the list of unique labels in place (uses NumPy's global random state)
    np.random.shuffle(unique_labels)

    # Define the sizes of every set, counted in unique outer-column labels
    # (not in individual samples)
    test_size       = int(test_ratio       * len(unique_labels))
    validation_size = int(validation_ratio * len(unique_labels))

    test_labels       = unique_labels[:test_size]
    validation_labels = unique_labels[test_size:test_size+validation_size]
    train_labels      = unique_labels[test_size+validation_size:]

    # Save this splitting for transfer-learning approaches
    np.savetxt(path_to_test_labels,       test_labels,       fmt='%s')
    np.savetxt(path_to_validation_labels, validation_labels, fmt='%s')
    np.savetxt(path_to_train_labels,      train_labels,      fmt='%s')

# Use the loaded/computed labels to generate split datasets
test_dataset       = m3gnet_dataset[test_labels]
validation_dataset = m3gnet_dataset[validation_labels]
train_dataset      = m3gnet_dataset[train_labels]

# Samples are columns, so the sample count is the second dimension
n_test       = np.shape(test_dataset)[1]
n_validation = np.shape(validation_dataset)[1]
n_train      = np.shape(train_dataset)[1]

print(f'Using {n_train} samples to train, {n_validation} to evaluate, and {n_test} to test')

Using 8812 samples to train, 2818 to evaluate, and 3174 to test


### Convert into graph database

In [8]:
# Store the forces of every training sample as plain nested lists.
# np.asarray(...) makes the cell safe to re-run: an entry that is already a list is
# turned into an equal list instead of failing on a missing `.tolist()`.
# NOTE(review): writing through `.loc[...].values` only reaches the DataFrame when
# pandas returns a view of its data -- confirm for the pandas version in use.
train_force_row = train_dataset.loc['force'].values
for i, forces in enumerate(train_force_row):
    train_force_row[i] = np.asarray(forces).tolist()

In [9]:
# Store the forces of every validation sample as plain nested lists.
# np.asarray(...) makes the cell safe to re-run: an entry that is already a list is
# turned into an equal list instead of failing on a missing `.tolist()`.
# NOTE(review): writing through `.loc[...].values` only reaches the DataFrame when
# pandas returns a view of its data -- confirm for the pandas version in use.
validation_force_row = validation_dataset.loc['force'].values
for i, forces in enumerate(validation_force_row):
    validation_force_row[i] = np.asarray(forces).tolist()

In [10]:
# Store the forces of every test sample as plain nested lists.
# np.asarray(...) makes the cell safe to re-run: an entry that is already a list is
# turned into an equal list instead of failing on a missing `.tolist()`.
# NOTE(review): writing through `.loc[...].values` only reaches the DataFrame when
# pandas returns a view of its data -- confirm for the pandas version in use.
test_force_row = test_dataset.loc['force'].values
for i, forces in enumerate(test_force_row):
    test_force_row[i] = np.asarray(forces).tolist()

In [11]:
# Turn the three DataFrame splits into M3GNet graph datasets.
splits = {'train': train_dataset, 'val': validation_dataset, 'test': test_dataset}

# Build ONE converter from the structures of all splits, so the three datasets are
# encoded with the same element list even if a split lacks some element.
# NOTE(review): a pre-trained model may expect its own element list/order --
# confirm whether the converter should use the model's element types instead.
all_structures = [
    structure
    for dataset in splits.values()
    for structure in dataset.loc['structure'].values.tolist()
]
element_types = get_element_list(all_structures)
converter     = Structure2Graph(element_types=element_types, cutoff=5.0)

all_data = []
for name, dataset in splits.items():  # Iterate over train-validation-test sets
    # Extract data from dataset
    structures = dataset.loc['structure'].values.tolist()

    # Define data labels from dataset. The 'stress' row is only read when it is
    # actually used, so a table without it still works when stress_weight == 0.
    if stress_weight == 0:
        stresses = [np.zeros((3, 3)).tolist() for _ in structures]
    else:
        stresses = dataset.loc['stress'].values.tolist()

    labels = {
        'energies': dataset.loc['energy'].values.tolist(),
        # Convert here as well, so array-valued and list-valued entries both end
        # up as plain nested lists
        'forces':   [np.asarray(forces).tolist() for forces in dataset.loc['force'].values],
        'stresses': stresses,
    }

    # Generate dataset
    data = M3GNetDataset(
        filename=f'dgl_graph-{name}.bin',
        filename_line_graph=f'dgl_line_graph-{name}.bin',
        filename_state_attr=f'state_attr-{name}.pt',
        filename_labels=f'labels-{name}.json',
        threebody_cutoff=4.0,
        structures=structures,
        converter=converter,
        labels=labels,
        name=f'M3GNetDataset-{name}',
    )
    all_data.append(data)

train_data, val_data, test_data = all_data

100%|██████████| 8812/8812 [01:41<00:00, 87.19it/s] 
100%|██████████| 2818/2818 [00:32<00:00, 86.39it/s] 
100%|██████████| 3174/3174 [00:36<00:00, 87.10it/s] 


In [12]:
# Options shared by the three loaders
loader_options = {
    'collate_fn':  collate_fn_efs,
    'batch_size':  2,
    'num_workers': 1,
}

# One loader per split, returned in train / validation / test order
train_loader, val_loader, test_loader = MGLDataLoader(
    train_data=train_data, val_data=val_data, test_data=test_data, **loader_options
)

# Retrain model

In [13]:
# Load the pre-trained potential and take its underlying model as starting point
m3gnet_nnp       = matgl.load_model(model_load_path)
model_pretrained = m3gnet_nnp.model

# Wrap the model in the Lightning training module; the stress weight and the
# learning rate come from the configuration cell
finetune_options = {
    'stress_weight': stress_weight,
    'loss':          'mse_loss',
    'lr':            lr,
}
lit_module_finetune = PotentialLightningModule(model=model_pretrained, **finetune_options)

In [14]:
# Metrics are written as CSV under logs/M3GNet_finetuning/version_<n>
logger = CSVLogger('logs', name='M3GNet_finetuning')

# accelerator='auto' selects the appropriate Accelerator; pass accelerator='cpu'
# instead to disable GPU or MPS (M1 mac) training.
trainer_options = {
    'max_epochs':     max_epochs,
    'accelerator':    'auto',
    'logger':         logger,
    'inference_mode': False,
}
trainer = pl.Trainer(**trainer_options)

# Fine-tune on the training loader, validating on the validation loader
trainer.fit(
    model=lit_module_finetune,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

# Save trained model
model_pretrained.save(model_save_path)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: logs/M3GNet_finetuning

  | Name  | Type              | Params
--------------------------------------------
0 | mae   | MeanAbsoluteError | 0     
1 | rmse  | MeanSquaredError  | 0     
2 | model | Potential         | 288 K 
--------------------------------------------
288 K     Trainable params
0         Non-trainable params
288 K     Total params
1.153     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.


# Analyze metrics

In [15]:
# Evaluate the fine-tuned module on the held-out test loader.
# Units as noted by the author: E_MAE = meV/atom, F_MAE eV/A, S_MAE GPa
trainer.test(model=lit_module_finetune, dataloaders=test_loader)

Testing: |          | 0/? [00:00<?, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     test_Energy_MAE                nan
    test_Energy_RMSE                nan
     test_Force_MAE                 nan
     test_Force_RMSE                nan
   test_Site_Wise_MAE               0.0
   test_Site_Wise_RMSE              0.0
     test_Stress_MAE                0.0
    test_Stress_RMSE                0.0
     test_Total_Loss                nan
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_Total_Loss': nan,
  'test_Energy_MAE': nan,
  'test_Force_MAE': nan,
  'test_Stress_MAE': 0.0,
  'test_Site_Wise_MAE': 0.0,
  'test_Energy_RMSE': nan,
  'test_Force_RMSE': nan,
  'test_Stress_RMSE': 0.0,
  'test_Site_Wise_RMSE': 0.0}]

In [None]:
# Folder of the selected training run, and the metrics table stored inside it
path_to_csv  = f'logs/M3GNet_finetuning/version_{current_version}'
metrics_file = f'{path_to_csv}/metrics.csv'

df = pd.read_csv(metrics_file)
df.head()

In [None]:
# Replace missing entries by zero, then add up the rows pairwise (0+1, 2+3, ...).
# NOTE(review): presumably each pair is the validation and training row of one
# epoch -- confirm against the layout of metrics.csv.
pair_ids = df.index // 2
df = df.fillna(0).groupby(pair_ids).sum()
df.head()

In [None]:
# Plot the natural logarithm of every logged training/validation metric.

# Columns holding training or validation metrics
loss_columns = [col for col in df.columns if col.startswith(('val_', 'train_'))]

# plt.subplots returns (figure, axes); unpack both and draw through the axes
fig, ax = plt.subplots(figsize=(10, 6))

# One curve per metric; entries equal to zero become -inf and are not drawn
for loss_column in loss_columns:
    ax.plot(df.index, np.log(df[loss_column]), label=loss_column)

ax.set_xlabel('Epoch')
# The curves show log-transformed values, so the axis label says so
ax.set_ylabel('log(Loss)')
# Legend placed just outside the right edge of the axes
ax.legend(loc=(1.01, 0))
fig.savefig('m3gnet_loss.eps', dpi=dpi, bbox_inches='tight')
plt.show()

In [45]:
# Validation energy / force / stress MAE in the second-to-last row of the table
tuple(df[metric].iloc[-2] for metric in ('val_Energy_MAE', 'val_Force_MAE', 'val_Stress_MAE'))

(0.0135606033727526, 0.0874462649226188, 0.0)

In [46]:
# Validation energy / force / stress MAE in the last row of the table
tuple(df[metric].iloc[-1] for metric in ('val_Energy_MAE', 'val_Force_MAE', 'val_Stress_MAE'))

(0.0, 0.0, 0.0)

# Cleanup the notebook

In [26]:
# Delete the temporary graph, state-attribute and label files from the working
# directory. The 'logs' and model folders are not removed by this cell.

patterns = ['dgl_graph*.bin', 'dgl_line_graph*.bin', 'state_attr*.pt', 'labels*.json']
for pattern in patterns:
    for file in glob.glob(pattern):
        try:
            os.remove(file)
        except FileNotFoundError:
            # Already gone: nothing left to clean up for this entry
            continue