In [1]:
import numpy             as np
import pandas            as pd
import pytorch_lightning as pl
import ML_library        as MLL
import matplotlib.pyplot as plt
import matgl
import os
import warnings
import glob
import torch

from __future__                import annotations
from pytorch_lightning.loggers import CSVLogger
from matgl.ext.pymatgen        import Structure2Graph, get_element_list
from matgl.graph.data          import M3GNetDataset, MGLDataLoader, collate_fn_efs
from matgl.utils.training      import PotentialLightningModule

# Silence library warnings so that the cell outputs stay readable
warnings.simplefilter('ignore')

# Prefer the GPU whenever PyTorch can see one, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [2]:
# Display the device selected in the previous cell
device

device(type='cuda')

In [3]:
# Excel file from which a previously extracted dataset is read, if it exists
data_train_path = 'm3gnet_dataset.xlsx'
# Name of the pre-trained model handed to matgl.load_model
model_load_path = 'M3GNet-MP-2021.2.8-PES'
# Destination used when the model is saved after training
model_save_path = 'finetuned_model'

# Whether to include charge or not (forwarded to the dataset extraction)
charged = True

# Number of outer column levels removed before splitting; the split is then
# made on the outermost remaining level:
# 0: material, 1: charge state, 2: ionic step
depth = 1

# Define batch size (used by the data loaders)
batch_size = 128

# Stress weight for training; when it is 0 the stress labels are replaced by zeros
stress_weight = 0.7

# Ratios for dividing the data; they are applied to the number of unique
# labels, not to the number of individual samples
test_ratio       = 0.2
validation_ratio = 0.2

# Number of epochs for re-training
max_epochs = 10

# Learning rate for re-training
lr = 1e-4

# Resolution used when saving the loss figure
dpi = 100

# Version of training you specifically want to analyze
# (selects the folder logs/M3GNet_finetuning/version_<current_version>)
current_version = 1

# Load simulation data

In [4]:
# Each folder names a new column, and structure, energy, forces and stresses
# of each ionic step are loaded

if os.path.exists(data_train_path):
    # Load data for model training.
    # The result is bound to `source_m3gnet_dataset`, the name every later cell
    # reads; binding it to another name left that variable undefined whenever
    # the cached file existed.
    # NOTE(review): values read back from Excel are presumably plain text or
    # numbers rather than the objects produced by the extraction below --
    # confirm that the graph conversion accepts them.
    source_m3gnet_dataset = pd.read_excel(data_train_path, index_col=0, header=[0,1,2])
else:
    # Path to dataset, structured as:
    # path_to_dataset
    #     material_i
    #         defect_i
    #             simulation_i (containing vasprun.xml)
    path_to_dataset = '../../../../Desktop/defects/gamma'
    #path_to_dataset = '../../../Desktop/CeO2-data'

    # Extract the data
    source_m3gnet_dataset = MLL.extract_vaspruns_dataset(path_to_dataset, charged=charged)
    #source_m3gnet_dataset = MLL.extract_OUTCAR_dataset(path_to_dataset)
    #source_m3gnet_dataset.to_excel(data_train_path)

# Display the loaded table
source_m3gnet_dataset


BiSI
	as_1_Bi_on_S_-1_Bond_Distortion_-20.0%
Error: vasprun not correctly loaded.
	as_1_Bi_on_S_-1_Bond_Distortion_-40.0%
Error: vasprun not correctly loaded.
	as_1_Bi_on_S_-1_Bond_Distortion_-60.0%
Error: vasprun not correctly loaded.
	as_1_Bi_on_S_-1_Bond_Distortion_0.0%
Error: vasprun not correctly loaded.
	as_1_Bi_on_S_-1_Bond_Distortion_20.0%
	as_1_Bi_on_S_-1_Bond_Distortion_40.0%
Error: vasprun not correctly loaded.
	as_1_Bi_on_S_-1_Bond_Distortion_60.0%
Error: vasprun not correctly loaded.
	as_1_Bi_on_S_-1_Unperturbed
	as_1_Bi_on_S_-2_Bond_Distortion_-20.0%
Error: vasprun not correctly loaded.
	as_1_Bi_on_S_-2_Bond_Distortion_-40.0%
Error: vasprun not correctly loaded.
Error: vasprun not correctly loaded.
	as_1_Bi_on_S_-2_Bond_Distortion_-60.0%
Error: vasprun not correctly loaded.
	as_1_Bi_on_S_-2_Bond_Distortion_0.0%
	as_1_Bi_on_S_-2_Bond_Distortion_20.0%
	as_1_Bi_on_S_-2_Bond_Distortion_40.0%
	as_1_Bi_on_S_-2_Bond_Distortion_60.0%
Error: vasprun not correctly loaded.
	as_1_Bi

Unnamed: 0_level_0,BiSI,BiSI,BiSI,BiSI,BiSI,BiSI,BiSI,BiSI,BiSI,BiSI,BiSI,BiSI,BiSI,BiSI,BiSI,BiSI,BiSI,BiSI,BiSI,BiSI,BiSI
Unnamed: 0_level_1,BiSI_as_1_Bi_on_S_-1_Bond_Distortion_-60.0%,BiSI_as_1_Bi_on_S_-1_Bond_Distortion_-60.0%,BiSI_as_1_Bi_on_S_-1_Bond_Distortion_-60.0%,BiSI_as_1_Bi_on_S_-1_Bond_Distortion_-60.0%,BiSI_as_1_Bi_on_S_-1_Bond_Distortion_-60.0%,BiSI_as_1_Bi_on_S_-1_Bond_Distortion_-60.0%,BiSI_as_1_Bi_on_S_-1_Bond_Distortion_-60.0%,BiSI_as_1_Bi_on_S_-1_Bond_Distortion_-60.0%,BiSI_as_1_Bi_on_S_-1_Bond_Distortion_-60.0%,BiSI_as_1_Bi_on_S_-1_Bond_Distortion_-60.0%,...,BiSI_vac_3_I_0_Unperturbed,BiSI_vac_3_I_0_Unperturbed,BiSI_vac_3_I_0_Unperturbed,BiSI_vac_3_I_0_Unperturbed,BiSI_vac_3_I_0_Unperturbed,BiSI_vac_3_I_0_Unperturbed,BiSI_vac_3_I_0_Unperturbed,BiSI_vac_3_I_0_Unperturbed,BiSI_vac_3_I_0_Unperturbed,BiSI_vac_3_I_0_Unperturbed
Unnamed: 0_level_2,BiSI_as_1_Bi_on_S_-1_Bond_Distortion_-60.0%_0,BiSI_as_1_Bi_on_S_-1_Bond_Distortion_-60.0%_1,BiSI_as_1_Bi_on_S_-1_Bond_Distortion_-60.0%_2,BiSI_as_1_Bi_on_S_-1_Bond_Distortion_-60.0%_3,BiSI_as_1_Bi_on_S_-1_Bond_Distortion_-60.0%_4,BiSI_as_1_Bi_on_S_-1_Bond_Distortion_-60.0%_5,BiSI_as_1_Bi_on_S_-1_Bond_Distortion_-60.0%_6,BiSI_as_1_Bi_on_S_-1_Bond_Distortion_-60.0%_7,BiSI_as_1_Bi_on_S_-1_Bond_Distortion_-60.0%_8,BiSI_as_1_Bi_on_S_-1_Bond_Distortion_-60.0%_9,...,BiSI_vac_3_I_0_Unperturbed_290,BiSI_vac_3_I_0_Unperturbed_291,BiSI_vac_3_I_0_Unperturbed_292,BiSI_vac_3_I_0_Unperturbed_293,BiSI_vac_3_I_0_Unperturbed_294,BiSI_vac_3_I_0_Unperturbed_295,BiSI_vac_3_I_0_Unperturbed_296,BiSI_vac_3_I_0_Unperturbed_297,BiSI_vac_3_I_0_Unperturbed_298,BiSI_vac_3_I_0_Unperturbed_299
structure,"[[3.15281399 2.93861474 5.6361836 ] Bi2.99+, [...",[[ 9.21776379 -15.41093792 0.74145347] Bi2....,[[ 1.31199571 -0.38017066 -1.53905415] Bi2.99+...,[[ 1.08225314 0.05662526 -1.60532598] Bi2.99+...,[[ 1.05513081 0.10819123 -1.61314978] Bi2.99+...,[[ 1.02611907 0.16334963 -1.62151854] Bi2.99+...,[[ 1.05887366 0.34036609 -1.02392546] Bi2.99+...,[[ 1.03679083 0.22102375 -1.4268156 ] Bi2.99+...,[[ 1.07817797 0.29046901 -1.47346213] Bi2.99+...,[[ 1.08404869 0.30031975 -1.4800788 ] Bi2.99+...,...,"[[1.05106711 3.14353984 3.83122885] Bi3+, [ 1....","[[1.05106724 3.14354053 3.83122885] Bi3+, [ 1....","[[1.05106761 3.14354156 3.83122895] Bi3+, [ 1....","[[1.05106774 3.14354208 3.83122895] Bi3+, [ 1....","[[1.05106799 3.14354293 3.83122895] Bi3+, [ 1....","[[1.05106824 3.14354413 3.83122905] Bi3+, [ 1....","[[1.05106837 3.14354465 3.83122916] Bi3+, [ 1....","[[1.05106875 3.14354585 3.83122916] Bi3+, [ 1....","[[1.05106887 3.14354619 3.83122916] Bi3+, [ 1....","[[1.051069 3.14354671 3.83122916] Bi3+, [ 1...."
energy,149.211553,877.41682,-268.153812,-278.289435,-278.495802,-278.582394,-219.075788,-300.079055,-309.013146,-309.278513,...,-332.607569,-332.607569,-332.607569,-332.607569,-332.607569,-332.60757,-332.60757,-332.60757,-332.60757,-332.60757
force,"[[387.1292446, -736.02583193, 111.67178315], [...","[[-1153.95790902, 4473.63317809, 1255.636439],...","[[-43.69594355, 26.84991863, -9.04688045], [-2...","[[-4.62140365, 8.11960207, 12.80144524], [-4.3...","[[-2.48935461, 7.13968974, 13.18109357], [-4.5...","[[-0.42931824, 6.25335171, 13.23984493], [-4.7...","[[3.34154402, -3.31261524, -8.32649688], [-1.8...","[[0.99440821, 1.3237105, -2.91348007], [-3.675...","[[1.14493218, 0.4267721, -1.78019771], [-2.295...","[[1.11468214, 0.32782807, -1.60578552], [-2.12...",...,"[[0.00109285, 0.00304389, 0.00057261], [-7.35e...","[[0.00120373, 0.00310365, 0.00051866], [2.976e...","[[0.0008785, 0.0029228, 0.00033936], [-0.00027...","[[0.00075588, 0.0029523, 0.00022033], [-0.0003...","[[0.00059159, 0.00308264, 0.000194], [-0.00056...","[[0.00098029, 0.00211477, -0.00046483], [-0.00...","[[0.00125204, 0.00205222, -0.0003113], [0.0001...","[[0.00124688, 0.00276692, -0.00022363], [0.000...","[[0.00125325, 0.00280768, -5.719e-05], [0.0001...","[[0.00094215, 0.00259488, 6.043e-05], [-0.0002..."
stress,"[[-33.94113915, 0.042448057000000004, -23.0759...","[[-19.111646372000003, 29.831243862999997, 11....","[[-13.478182424, 4.5043607020000005, -3.211689...","[[-8.5870026, 2.044275746, 0.373448098], [2.04...","[[-8.347976159, 1.8382724880000003, 0.69312253...","[[-8.153794291, 1.6371626, 1.0160740430000001]...","[[-17.456272801, 13.613489110000002, 0.1421639...","[[-6.068613831, 1.415885776, -0.75968938000000...","[[-4.043314921, 0.6384717310000001, -0.0959284...","[[-3.9506473389999996, 0.590466944, -0.0769174...",...,"[[0.5176018880000001, -8.1931e-05, 0.003965123...","[[0.5163720020000001, 0.001028801, 0.004027142...","[[0.5203135879999999, -0.000296831, 0.00213702...","[[0.5208997110000001, -0.00039039600000000005,...","[[0.520801838, -0.000372522, 0.00350896], [-0....","[[0.5203314490000001, -0.00020664400000000002,...","[[0.518850041, -0.0014548190000000002, 0.00319...","[[0.515232405, 0.001054742, 0.005106482], [0.0...","[[0.515235369, 0.0010547430000000001, 0.005107...","[[0.513089892, 0.000870935, 0.005695736], [0.0..."
charge_state,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [5]:
# Number of rows of the table (one per stored quantity), not the number of samples
len(source_m3gnet_dataset)

5

# Split data into train-validation-test sets

### Decide whether to split by material, defect state, or individual simulation

In [6]:
# Work on a copy so that the extracted table itself is left untouched
m3gnet_dataset = source_m3gnet_dataset.copy()

# Strip the outermost column level, `depth` times in a row
levels_left_to_drop = depth
while levels_left_to_drop > 0:
    m3gnet_dataset.columns = m3gnet_dataset.columns.droplevel(0)
    levels_left_to_drop -= 1

In [7]:
# Inspect the table after the outer column level(s) have been removed
m3gnet_dataset

Unnamed: 0_level_0,BiSI_as_1_Bi_on_S_-1_Bond_Distortion_-60.0%,BiSI_as_1_Bi_on_S_-1_Bond_Distortion_-60.0%,BiSI_as_1_Bi_on_S_-1_Bond_Distortion_-60.0%,BiSI_as_1_Bi_on_S_-1_Bond_Distortion_-60.0%,BiSI_as_1_Bi_on_S_-1_Bond_Distortion_-60.0%,BiSI_as_1_Bi_on_S_-1_Bond_Distortion_-60.0%,BiSI_as_1_Bi_on_S_-1_Bond_Distortion_-60.0%,BiSI_as_1_Bi_on_S_-1_Bond_Distortion_-60.0%,BiSI_as_1_Bi_on_S_-1_Bond_Distortion_-60.0%,BiSI_as_1_Bi_on_S_-1_Bond_Distortion_-60.0%,...,BiSI_vac_3_I_0_Unperturbed,BiSI_vac_3_I_0_Unperturbed,BiSI_vac_3_I_0_Unperturbed,BiSI_vac_3_I_0_Unperturbed,BiSI_vac_3_I_0_Unperturbed,BiSI_vac_3_I_0_Unperturbed,BiSI_vac_3_I_0_Unperturbed,BiSI_vac_3_I_0_Unperturbed,BiSI_vac_3_I_0_Unperturbed,BiSI_vac_3_I_0_Unperturbed
Unnamed: 0_level_1,BiSI_as_1_Bi_on_S_-1_Bond_Distortion_-60.0%_0,BiSI_as_1_Bi_on_S_-1_Bond_Distortion_-60.0%_1,BiSI_as_1_Bi_on_S_-1_Bond_Distortion_-60.0%_2,BiSI_as_1_Bi_on_S_-1_Bond_Distortion_-60.0%_3,BiSI_as_1_Bi_on_S_-1_Bond_Distortion_-60.0%_4,BiSI_as_1_Bi_on_S_-1_Bond_Distortion_-60.0%_5,BiSI_as_1_Bi_on_S_-1_Bond_Distortion_-60.0%_6,BiSI_as_1_Bi_on_S_-1_Bond_Distortion_-60.0%_7,BiSI_as_1_Bi_on_S_-1_Bond_Distortion_-60.0%_8,BiSI_as_1_Bi_on_S_-1_Bond_Distortion_-60.0%_9,...,BiSI_vac_3_I_0_Unperturbed_290,BiSI_vac_3_I_0_Unperturbed_291,BiSI_vac_3_I_0_Unperturbed_292,BiSI_vac_3_I_0_Unperturbed_293,BiSI_vac_3_I_0_Unperturbed_294,BiSI_vac_3_I_0_Unperturbed_295,BiSI_vac_3_I_0_Unperturbed_296,BiSI_vac_3_I_0_Unperturbed_297,BiSI_vac_3_I_0_Unperturbed_298,BiSI_vac_3_I_0_Unperturbed_299
structure,"[[3.15281399 2.93861474 5.6361836 ] Bi2.99+, [...",[[ 9.21776379 -15.41093792 0.74145347] Bi2....,[[ 1.31199571 -0.38017066 -1.53905415] Bi2.99+...,[[ 1.08225314 0.05662526 -1.60532598] Bi2.99+...,[[ 1.05513081 0.10819123 -1.61314978] Bi2.99+...,[[ 1.02611907 0.16334963 -1.62151854] Bi2.99+...,[[ 1.05887366 0.34036609 -1.02392546] Bi2.99+...,[[ 1.03679083 0.22102375 -1.4268156 ] Bi2.99+...,[[ 1.07817797 0.29046901 -1.47346213] Bi2.99+...,[[ 1.08404869 0.30031975 -1.4800788 ] Bi2.99+...,...,"[[1.05106711 3.14353984 3.83122885] Bi3+, [ 1....","[[1.05106724 3.14354053 3.83122885] Bi3+, [ 1....","[[1.05106761 3.14354156 3.83122895] Bi3+, [ 1....","[[1.05106774 3.14354208 3.83122895] Bi3+, [ 1....","[[1.05106799 3.14354293 3.83122895] Bi3+, [ 1....","[[1.05106824 3.14354413 3.83122905] Bi3+, [ 1....","[[1.05106837 3.14354465 3.83122916] Bi3+, [ 1....","[[1.05106875 3.14354585 3.83122916] Bi3+, [ 1....","[[1.05106887 3.14354619 3.83122916] Bi3+, [ 1....","[[1.051069 3.14354671 3.83122916] Bi3+, [ 1...."
energy,149.211553,877.41682,-268.153812,-278.289435,-278.495802,-278.582394,-219.075788,-300.079055,-309.013146,-309.278513,...,-332.607569,-332.607569,-332.607569,-332.607569,-332.607569,-332.60757,-332.60757,-332.60757,-332.60757,-332.60757
force,"[[387.1292446, -736.02583193, 111.67178315], [...","[[-1153.95790902, 4473.63317809, 1255.636439],...","[[-43.69594355, 26.84991863, -9.04688045], [-2...","[[-4.62140365, 8.11960207, 12.80144524], [-4.3...","[[-2.48935461, 7.13968974, 13.18109357], [-4.5...","[[-0.42931824, 6.25335171, 13.23984493], [-4.7...","[[3.34154402, -3.31261524, -8.32649688], [-1.8...","[[0.99440821, 1.3237105, -2.91348007], [-3.675...","[[1.14493218, 0.4267721, -1.78019771], [-2.295...","[[1.11468214, 0.32782807, -1.60578552], [-2.12...",...,"[[0.00109285, 0.00304389, 0.00057261], [-7.35e...","[[0.00120373, 0.00310365, 0.00051866], [2.976e...","[[0.0008785, 0.0029228, 0.00033936], [-0.00027...","[[0.00075588, 0.0029523, 0.00022033], [-0.0003...","[[0.00059159, 0.00308264, 0.000194], [-0.00056...","[[0.00098029, 0.00211477, -0.00046483], [-0.00...","[[0.00125204, 0.00205222, -0.0003113], [0.0001...","[[0.00124688, 0.00276692, -0.00022363], [0.000...","[[0.00125325, 0.00280768, -5.719e-05], [0.0001...","[[0.00094215, 0.00259488, 6.043e-05], [-0.0002..."
stress,"[[-33.94113915, 0.042448057000000004, -23.0759...","[[-19.111646372000003, 29.831243862999997, 11....","[[-13.478182424, 4.5043607020000005, -3.211689...","[[-8.5870026, 2.044275746, 0.373448098], [2.04...","[[-8.347976159, 1.8382724880000003, 0.69312253...","[[-8.153794291, 1.6371626, 1.0160740430000001]...","[[-17.456272801, 13.613489110000002, 0.1421639...","[[-6.068613831, 1.415885776, -0.75968938000000...","[[-4.043314921, 0.6384717310000001, -0.0959284...","[[-3.9506473389999996, 0.590466944, -0.0769174...",...,"[[0.5176018880000001, -8.1931e-05, 0.003965123...","[[0.5163720020000001, 0.001028801, 0.004027142...","[[0.5203135879999999, -0.000296831, 0.00213702...","[[0.5208997110000001, -0.00039039600000000005,...","[[0.520801838, -0.000372522, 0.00350896], [-0....","[[0.5203314490000001, -0.00020664400000000002,...","[[0.518850041, -0.0014548190000000002, 0.00319...","[[0.515232405, 0.001054742, 0.005106482], [0.0...","[[0.515235369, 0.0010547430000000001, 0.005107...","[[0.513089892, 0.000870935, 0.005695736], [0.0..."
charge_state,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


### Splitting into train-validation-test sets

In [8]:
# Check if data has been already split, else do it randomly

path_to_test_labels       = 'test_labels.txt'
path_to_validation_labels = 'validation_labels.txt'
path_to_train_labels      = 'train_labels.txt'


def _read_labels(path):
    """Return the labels stored in `path` as a list of strings.

    np.genfromtxt yields a 0-d array when the file holds a single label, and
    calling `.tolist()` on it returns a bare string instead of a list. Wrapping
    the array in np.atleast_1d keeps the result a list in every case, so the
    column selection below always receives a list of labels.
    """
    return np.atleast_1d(np.genfromtxt(path, dtype='str')).tolist()


label_paths = (path_to_test_labels, path_to_validation_labels, path_to_train_labels)

if all(os.path.exists(path) for path in label_paths):
    # Read labels splitting (which are strings)
    test_labels       = _read_labels(path_to_test_labels)
    validation_labels = _read_labels(path_to_validation_labels)
    train_labels      = _read_labels(path_to_train_labels)
else:
    # Define unique labels, wrt the outer column
    unique_labels = np.unique(m3gnet_dataset.columns.get_level_values(0))

    # Shuffle the list of unique labels.
    # The shuffle is unseeded: the split is only reproducible through the
    # label files written below.
    np.random.shuffle(unique_labels)

    # Define the sizes of every set
    # Corresponds to the size wrt the number of unique labels in the dataset
    test_size       = int(test_ratio       * len(unique_labels))
    validation_size = int(validation_ratio * len(unique_labels))

    # Plain lists, so both branches hand the same type to the selection below
    test_labels       = unique_labels[:test_size].tolist()
    validation_labels = unique_labels[test_size:test_size+validation_size].tolist()
    train_labels      = unique_labels[test_size+validation_size:].tolist()

    # Save this splitting for transfer-learning approaches
    np.savetxt(path_to_test_labels,       test_labels,       fmt='%s')
    np.savetxt(path_to_validation_labels, validation_labels, fmt='%s')
    np.savetxt(path_to_train_labels,      train_labels,      fmt='%s')

# Use the loaded/computed labels to generate split datasets
test_dataset       = m3gnet_dataset[test_labels]
validation_dataset = m3gnet_dataset[validation_labels]
train_dataset      = m3gnet_dataset[train_labels]

# Number of samples (columns) in each split
n_test       = np.shape(test_dataset)[1]
n_validation = np.shape(validation_dataset)[1]
n_train      = np.shape(train_dataset)[1]

print(f'Using {n_train} samples to train, {n_validation} to evaluate, and {n_test} to test')

Using 43222 samples to train, 12612 to evaluate, and 14769 to test


### Convert into graph database

In [9]:
# Build one graph dataset per split (train, validation, test).

# The converter receives the element list of the pre-trained model that is
# fine-tuned later, and the same converter is shared by the three splits.
# Previously the list was derived separately from each split, so it could
# differ between splits and from the list the pre-trained model was built with.
# NOTE(review): presumably the converter encodes each species as its position in
# `element_types` -- confirm against the matgl documentation.
element_types = matgl.load_model(model_load_path).model.element_types
converter     = Structure2Graph(element_types=element_types, cutoff=5.0)

all_data = []
for name, dataset in zip(['train', 'val', 'test'],
                         [train_dataset, validation_dataset, test_dataset]):
    # Extract data from dataset
    structures = dataset.loc['structure'].values.tolist()

    # Define data labels from dataset; with a zero stress weight the stress
    # labels are replaced by 3x3 zero matrices
    if stress_weight == 0:
        stresses = [np.zeros((3, 3)).tolist() for _ in structures]
    else:
        stresses = dataset.loc['stress'].values.tolist()

    labels = {
        'energies': dataset.loc['energy'].values.tolist(),
        'forces':   dataset.loc['force'].values.tolist(),
        'stresses': stresses,
    }

    # Generate dataset; every file name carries the split name
    data = M3GNetDataset(
        filename=f'dgl_graph-{name}.bin',
        filename_line_graph=f'dgl_line_graph-{name}.bin',
        filename_state_attr=f'state_attr-{name}.pt',
        filename_labels=f'labels-{name}.json',
        threebody_cutoff=4.0,
        structures=structures,
        converter=converter,
        labels=labels,
        name=f'M3GNetDataset-{name}',
    )
    all_data.append(data)

train_data, val_data, test_data = all_data

100%|██████████| 43222/43222 [02:00<00:00, 360.10it/s]
100%|██████████| 12612/12612 [00:32<00:00, 383.18it/s]
100%|██████████| 14769/14769 [00:38<00:00, 382.53it/s]


In [10]:
# Options shared by the three loaders
loader_options = dict(
    collate_fn=collate_fn_efs,
    batch_size=batch_size,
    num_workers=1,
    pin_memory=True,
)

# One loader per split, unpacked as (train, validation, test)
train_loader, val_loader, test_loader = MGLDataLoader(
    train_data=train_data, val_data=val_data, test_data=test_data, **loader_options
)

# Retrain model

In [11]:
# Load the pre-trained potential and keep a handle on its underlying model
m3gnet_nnp       = matgl.load_model(model_load_path)
model_pretrained = m3gnet_nnp.model

# Wrap the model for training: stress weight, MSE loss and learning rate
# are taken from the configuration cell
finetune_options = {
    'stress_weight': stress_weight,
    'loss':          'mse_loss',
    'lr':            lr,
}
lit_module_finetune = PotentialLightningModule(model=model_pretrained, **finetune_options)

In [None]:
# Metrics are logged as CSV under logs/M3GNet_finetuning
logger = CSVLogger('logs', name='M3GNet_finetuning')

# Training is pinned to the CPU through accelerator='cpu';
# accelerator='auto' selects the appropriate Accelerator instead
trainer_options = dict(
    max_epochs=max_epochs,
    logger=logger,
    accelerator='cpu',
    inference_mode=False,
)
trainer = pl.Trainer(**trainer_options)

# Fit on the training loader, validating on the validation loader
trainer.fit(model=lit_module_finetune, train_dataloaders=train_loader, val_dataloaders=val_loader)

# Save trained model
model_pretrained.save(model_save_path)

GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type              | Params
--------------------------------------------
0 | mae   | MeanAbsoluteError | 0     
1 | rmse  | MeanSquaredError  | 0     
2 | model | Potential         | 288 K 
--------------------------------------------
288 K     Trainable params
0         Non-trainable params
288 K     Total params
1.153     Total estimated model params size (MB)


Sanity Checking DataLoader 0:  50%|█████     | 1/2 [00:03<00:03,  0.28it/s]

# Analyze metrics

In [None]:
# Evaluate the fine-tuned module on the test loader
# Units as noted by the author: E_MAE = meV/atom, F_MAE = eV/A, S_MAE = GPa
trainer.test(model=lit_module_finetune, dataloaders=test_loader)

In [None]:
# Locate the log folder of the selected training version and load its metrics
path_to_csv  = f'logs/M3GNet_finetuning/version_{current_version}'
metrics_file = f'{path_to_csv}/metrics.csv'
df = pd.read_csv(metrics_file)
df.head()

In [None]:
# Rows 2k and 2k+1 share the group id k
pair_ids = df.index // 2

# Missing entries count as zero, then each pair of consecutive rows is summed.
# NOTE(review): this relies on the rows coming in pairs; an unpaired trailing
# row ends up in a group of its own -- confirm against the metrics file.
df = df.fillna(0).groupby(pair_ids).sum()
df.head()

In [None]:
# Get the list of training/validation metric column names
loss_columns = [col for col in df.columns if col.startswith(('val_', 'train_'))]

# plt.subplots returns a (figure, axes) pair; keep both instead of storing the
# whole tuple under the name `fig`
fig, ax = plt.subplots(figsize=(10, 6))

# Plot the natural logarithm of each metric
for loss_column in loss_columns:
    values = df[loss_column]
    # Entries that are not strictly positive (e.g. NaNs filled with 0) have no
    # logarithm: mask them so that they are simply not drawn
    ax.plot(df.index, np.log(values.where(values > 0)), label=loss_column)

ax.set_xlabel('Epoch')
# The curves show log values, so the axis is labelled accordingly
ax.set_ylabel('log(Loss)')
ax.legend(loc=(1.01, 0))
fig.savefig('m3gnet_loss.eps', dpi=dpi, bbox_inches='tight')
plt.show()

In [45]:
# Validation MAEs (energy, force, stress) of the second-to-last aggregated row
tuple(df[column].iloc[-2] for column in ('val_Energy_MAE', 'val_Force_MAE', 'val_Stress_MAE'))

(0.0135606033727526, 0.0874462649226188, 0.0)

In [46]:
# Validation MAEs (energy, force, stress) of the last aggregated row
tuple(df[column].iloc[-1] for column in ('val_Energy_MAE', 'val_Force_MAE', 'val_Stress_MAE'))

(0.0, 0.0, 0.0)

# Cleanup the notebook

In [14]:
# Remove the temporary files this notebook left in the working directory.
# NOTE(review): the split files written earlier are named '*_labels.txt', which
# 'labels*.txt' does not match -- confirm whether they are meant to be kept.

patterns = ['dgl_graph*.bin', 'dgl_line_graph*.bin', 'state_attr*.pt', 'labels*.json', 'labels*.txt']
for stale_file in (match for pattern in patterns for match in glob.glob(pattern)):
    try:
        os.remove(stale_file)
    except FileNotFoundError:
        # Already gone: nothing left to delete
        continue

#shutil.rmtree('logs')
#shutil.rmtree('trained_model')
#shutil.rmtree('finetuned_model')