In [6]:
# A `from __future__` import has to be the first statement of a module.
# Keeping it first lets this cell also run unchanged when the notebook is
# exported to a plain Python script.
from __future__ import annotations

# Standard library
import glob
import os
import warnings

# Third-party
import matgl
import matplotlib.pyplot as plt
import numpy             as np
import pandas            as pd
import pytorch_lightning as pl

from matgl.ext.pymatgen        import Structure2Graph, get_element_list
from matgl.graph.data          import M3GNetDataset, MGLDataLoader, collate_fn_efs
from matgl.utils.training      import PotentialLightningModule
from pytorch_lightning.loggers import CSVLogger

# Helper library (presumably project-local -- confirm it is on the import path)
import ML_library        as MLL

# Suppress warnings for clearer output.
# Note: this silences every warning category for the whole session.
warnings.simplefilter('ignore')

In [7]:
# ---------------------------------------------------------------- paths
# Cached dataset (read when the file exists), pre-trained model to load,
# and destination of the re-trained model
data_train_path = 'm3gnet_dataset.xlsx'
model_load_path = 'M3GNet-MP-2021.2.8-PES'
model_save_path = 'finetuned_model'

# -------------------------------------------------------------- dataset
# Whether to include charge or not when extracting the dataset
charged = True

# Number of outer column levels dropped before splitting, i.e. the level
# used for the split -- 0: material, 1: charge state, 2: ionic step
depth = 1

# Ratios for dividing the data into test and validation sets
# (the remainder is used for training)
test_ratio       = 0.2
validation_ratio = 0.2

# ------------------------------------------------------------- training
# Weight of the stress term during training
stress_weight = 0

# Number of epochs for re-training
max_epochs = 10

# Learning rate for re-training
lr = 1e-4

# ------------------------------------------------------------- analysis
# Resolution of the saved figures
dpi = 100

# Version of the training run (logger folder) you specifically want to analyze
current_version = 1

# Load simulation data

In [8]:
# Load the dataset used for re-training.
# Each folder names a new column level, and structure, energy, forces and
# stresses of each ionic step are loaded.
# Both branches bind the result to `source_m3gnet_dataset`, which is the name
# every later cell reads.

if os.path.exists(data_train_path):
    # Load the cached data; the three header rows form the column levels
    source_m3gnet_dataset = pd.read_excel(data_train_path, index_col=0, header=[0, 1, 2])
    # NOTE(review): values read back from a spreadsheet are presumably plain
    # strings/numbers rather than structure objects -- confirm that the cache
    # can actually be used for training before relying on this branch.
else:
    # Path to dataset, structured as:
    # path_to_dataset
    #     material_i
    #         defect_i
    #             simulation_i (containing vasprun.xml)
    path_to_dataset = '../../../../Desktop/defects'

    # Extract the data directly from the simulations
    source_m3gnet_dataset = MLL.extract_vaspruns_dataset(path_to_dataset, charged=charged)

    # Writing the cache is left disabled, as in the original cell
    #source_m3gnet_dataset.to_excel(data_train_path)

source_m3gnet_dataset


BiSeBr
	as_1_Bi_on_Se_-1
	as_1_Bi_on_Se_0
	as_1_Bi_on_Se_1
	as_1_Bi_on_Se_2
	as_1_Bi_on_Se_3
	as_1_Bi_on_Se_5
	as_1_Br_on_Bi_-1
	as_1_Br_on_Bi_-2
	as_1_Br_on_Bi_0
	as_1_Br_on_Bi_1
	as_1_Br_on_Bi_2
	as_1_Br_on_Bi_3
	as_1_Br_on_Bi_4
	as_1_Br_on_Bi_5
	as_1_Se_on_Bi_-1
	as_1_Se_on_Bi_-2
	as_1_Se_on_Bi_0
	as_1_Se_on_Bi_1
	as_1_Se_on_Bi_2
	as_1_Se_on_Bi_3
	as_1_Se_on_Bi_4
	as_1_Se_on_Bi_5
	as_2_Bi_on_Br_-1
	as_2_Bi_on_Br_-2
	as_2_Bi_on_Br_0
	as_2_Bi_on_Br_1
	as_2_Bi_on_Br_2
	as_2_Bi_on_Br_3
	as_2_Bi_on_Br_4
	as_2_Bi_on_Br_5
	as_2_Br_on_Se_-1
	as_2_Br_on_Se_-2
	as_2_Br_on_Se_0
	as_2_Br_on_Se_1
	as_2_Br_on_Se_2
	as_2_Br_on_Se_3
	as_2_Br_on_Se_4
	as_2_Br_on_Se_5
	as_2_Se_on_Br_-1
	as_2_Se_on_Br_0
	as_2_Se_on_Br_1
	as_2_Se_on_Br_2
	as_2_Se_on_Br_3
	supercell
	vac_1_Bi_-1
	vac_1_Bi_-2
	vac_1_Bi_-3
	vac_1_Bi_0
Error: vasprun not correctly loaded.
	vac_1_Bi_1
	vac_1_Bi_2
	vac_1_Bi_3
	vac_2_Se_-1
Error: vasprun not correctly loaded.
	vac_2_Se_-2
	vac_2_Se_0
	vac_2_Se_1
Error: vasprun not correctly 

Unnamed: 0_level_0,BiSeBr,BiSeBr,BiSeBr,BiSeBr,BiSeBr,BiSeBr,BiSeBr,BiSeBr,BiSeBr,BiSeBr,...,BiSBr,BiSBr,BiSBr,BiSBr,BiSBr,BiSBr,BiSBr,BiSBr,BiSBr,BiSBr
Unnamed: 0_level_1,BiSeBr_as_1_Bi_on_Se_-1,BiSeBr_as_1_Bi_on_Se_-1,BiSeBr_as_1_Bi_on_Se_-1,BiSeBr_as_1_Bi_on_Se_-1,BiSeBr_as_1_Bi_on_Se_-1,BiSeBr_as_1_Bi_on_Se_-1,BiSeBr_as_1_Bi_on_Se_-1,BiSeBr_as_1_Bi_on_Se_-1,BiSeBr_as_1_Bi_on_Se_-1,BiSeBr_as_1_Bi_on_Se_-1,...,BiSBr_vac_3_Br_1,BiSBr_vac_3_Br_1,BiSBr_vac_3_Br_1,BiSBr_vac_3_Br_1,BiSBr_vac_3_Br_1,BiSBr_vac_3_Br_1,BiSBr_vac_3_Br_1,BiSBr_vac_3_Br_1,BiSBr_vac_3_Br_1,BiSBr_vac_3_Br_1
Unnamed: 0_level_2,BiSeBr_as_1_Bi_on_Se_-1_0,BiSeBr_as_1_Bi_on_Se_-1_1,BiSeBr_as_1_Bi_on_Se_-1_2,BiSeBr_as_1_Bi_on_Se_-1_3,BiSeBr_as_1_Bi_on_Se_-1_4,BiSeBr_as_1_Bi_on_Se_-1_5,BiSeBr_as_1_Bi_on_Se_-1_6,BiSeBr_as_1_Bi_on_Se_-1_7,BiSeBr_as_1_Bi_on_Se_-1_8,BiSeBr_as_1_Bi_on_Se_-1_9,...,BiSBr_vac_3_Br_1_24,BiSBr_vac_3_Br_1_25,BiSBr_vac_3_Br_1_26,BiSBr_vac_3_Br_1_27,BiSBr_vac_3_Br_1_28,BiSBr_vac_3_Br_1_29,BiSBr_vac_3_Br_1_30,BiSBr_vac_3_Br_1_31,BiSBr_vac_3_Br_1_32,BiSBr_vac_3_Br_1_33
structure,"[[3.15313861 2.59447558 5.89513061] Bi3.01+, [...","[[3.15313534 2.58114513 5.89861221] Bi3.01+, [...","[[3.15313446 2.57774453 5.89950027] Bi3.01+, [...","[[3.15313253 2.57790839 5.88819456] Bi3.01+, [...","[[3.15312712 2.57836035 5.85701288] Bi3.01+, [...","[[3.15312367 2.57119505 5.84588567] Bi3.01+, [...","[[3.15312074 2.56532623 5.8367721 ] Bi3.01+, [...","[[3.15311544 2.55176861 5.83674393] Bi3.01+, [...","[[3.15311443 2.54933574 5.83673889] Bi3.01+, [...","[[3.1531085 2.54254466 5.83455571] Bi3.01+, [...",...,"[[1.09649871 3.33136373 3.76381899] Bi2.99+, [...","[[1.09648345 3.33145415 3.7638046 ] Bi2.99+, [...","[[1.09649467 3.33138827 3.76381519] Bi2.99+, [...","[[1.09648887 3.33142224 3.76380974] Bi2.99+, [...","[[1.09649177 3.33140525 3.76381241] Bi2.99+, [...","[[1.09649026 3.33141383 3.76381107] Bi2.99+, [...","[[1.09649101 3.33140988 3.76381169] Bi2.99+, [...","[[1.0958346 3.33199001 3.76284086] Bi2.99+, [...","[[1.09576499 3.33205143 3.76273804] Bi2.99+, [...","[[1.09508903 3.33194814 3.76179106] Bi2.99+, [..."
energy,-322.818221,-326.789146,-326.929496,-327.16552,-327.41749,-327.585568,-327.626605,-327.688237,-327.689669,-327.711614,...,-347.06941,-347.06941,-347.069411,-347.06941,-347.069411,-347.069411,-347.069411,-347.069592,-347.069593,-347.069728
force,"[[-6.872e-05, -0.27632283, 0.0721676], [0.4971...","[[-3.669e-05, -0.0440181, -0.1620237], [0.1171...","[[-3.582e-05, 0.01477313, -0.22599005], [0.034...","[[-1.688e-05, -0.0154473, -0.16731916], [0.000...","[[-1.921e-05, -0.09313938, 0.00081254], [-0.09...","[[-3.92e-05, -0.11631304, 0.01429217], [-0.067...","[[-5.123e-05, -0.13292314, 0.02322127], [-0.04...","[[-2.841e-05, -0.04953218, -0.01495643], [0.00...","[[-4.954e-05, -0.03669455, -0.02294653], [0.01...","[[-4.322e-05, -0.0178642, -0.0270642], [0.0179...",...,"[[-0.00476948, -0.00035319, -0.00736842], [-0....","[[-0.00497167, -0.0019363, -0.00785407], [-0.0...","[[-0.00514037, 0.00022598, -0.00803165], [-0.0...","[[-0.00477231, -0.00305617, -0.00765211], [-0....","[[-0.00526448, 0.00039504, -0.00818143], [-0.0...","[[-0.00495245, -0.00173867, -0.00788635], [-0....","[[-0.00517967, -4.506e-05, -0.00815423], [-0.0...","[[-0.00262996, -0.00362003, -0.00366045], [-0....","[[-0.00176465, -0.00481489, -0.00212935], [-0....","[[-2.587e-05, 0.00267735, 0.00092322], [0.0011..."
stress,"[[-1.2299856670000002, -4.5550000000000004e-06...","[[0.9128690100000001, 1.3227e-05, 8.3881e-05],...","[[1.3243452, 1.6388000000000003e-05, 7.469e-05...","[[1.315491035, 2.7492e-05, 6.8386e-05], [2.751...","[[1.141435605, 2.5294000000000002e-05, 4.77420...","[[0.6808166820000001, 2.6146e-05, 7.4944e-05],...","[[0.27998605, 2.8756000000000004e-05, 9.538e-0...","[[0.208272546, 3.0295000000000004e-05, 7.5592e...","[[0.195216261, 3.2144e-05, 7.3609e-05], [3.216...","[[0.24890751700000002, 3.253e-05, 6.7498000000...",...,"[[3.2947910190000003, 0.000152143, -0.00596234...","[[3.2939431730000006, 0.0002217, -0.005953362]...","[[3.294203065, 9.7631e-05, -0.00590036], [9.77...","[[3.293835067, 0.000125833, -0.005944812000000...","[[3.294353673, 0.00014495, -0.005944424], [0.0...","[[3.2939374960000003, 0.000133136, -0.00594735...","[[3.294100812, 0.000137383, -0.005949990000000...","[[3.2947419340000006, -0.000171238, -0.0040957...","[[3.294760212, -0.00020194100000000002, -0.003...","[[3.297168884, 0.0005292070000000001, -0.00134..."
charge_state,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,...,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0


In [9]:
# Number of rows of the dataset, i.e. stored quantities (samples are columns)
len(source_m3gnet_dataset)

5

# Split data into train-validation-test sets

### Decide whether to split by material, defect state, or individual simulation

In [10]:
# Work on an independent copy, so the source DataFrame keeps all column levels
m3gnet_dataset = source_m3gnet_dataset.copy()

# Strip the `depth` outermost column levels in one call; the level that
# remains outermost is the one the data is split on
outer_levels = list(range(depth))
m3gnet_dataset.columns = m3gnet_dataset.columns.droplevel(outer_levels)

In [11]:
# Display the working copy of the dataset (outer column level(s) removed)
m3gnet_dataset

Unnamed: 0_level_0,BiSeBr_as_1_Bi_on_Se_-1,BiSeBr_as_1_Bi_on_Se_-1,BiSeBr_as_1_Bi_on_Se_-1,BiSeBr_as_1_Bi_on_Se_-1,BiSeBr_as_1_Bi_on_Se_-1,BiSeBr_as_1_Bi_on_Se_-1,BiSeBr_as_1_Bi_on_Se_-1,BiSeBr_as_1_Bi_on_Se_-1,BiSeBr_as_1_Bi_on_Se_-1,BiSeBr_as_1_Bi_on_Se_-1,...,BiSBr_vac_3_Br_1,BiSBr_vac_3_Br_1,BiSBr_vac_3_Br_1,BiSBr_vac_3_Br_1,BiSBr_vac_3_Br_1,BiSBr_vac_3_Br_1,BiSBr_vac_3_Br_1,BiSBr_vac_3_Br_1,BiSBr_vac_3_Br_1,BiSBr_vac_3_Br_1
Unnamed: 0_level_1,BiSeBr_as_1_Bi_on_Se_-1_0,BiSeBr_as_1_Bi_on_Se_-1_1,BiSeBr_as_1_Bi_on_Se_-1_2,BiSeBr_as_1_Bi_on_Se_-1_3,BiSeBr_as_1_Bi_on_Se_-1_4,BiSeBr_as_1_Bi_on_Se_-1_5,BiSeBr_as_1_Bi_on_Se_-1_6,BiSeBr_as_1_Bi_on_Se_-1_7,BiSeBr_as_1_Bi_on_Se_-1_8,BiSeBr_as_1_Bi_on_Se_-1_9,...,BiSBr_vac_3_Br_1_24,BiSBr_vac_3_Br_1_25,BiSBr_vac_3_Br_1_26,BiSBr_vac_3_Br_1_27,BiSBr_vac_3_Br_1_28,BiSBr_vac_3_Br_1_29,BiSBr_vac_3_Br_1_30,BiSBr_vac_3_Br_1_31,BiSBr_vac_3_Br_1_32,BiSBr_vac_3_Br_1_33
structure,"[[3.15313861 2.59447558 5.89513061] Bi3.01+, [...","[[3.15313534 2.58114513 5.89861221] Bi3.01+, [...","[[3.15313446 2.57774453 5.89950027] Bi3.01+, [...","[[3.15313253 2.57790839 5.88819456] Bi3.01+, [...","[[3.15312712 2.57836035 5.85701288] Bi3.01+, [...","[[3.15312367 2.57119505 5.84588567] Bi3.01+, [...","[[3.15312074 2.56532623 5.8367721 ] Bi3.01+, [...","[[3.15311544 2.55176861 5.83674393] Bi3.01+, [...","[[3.15311443 2.54933574 5.83673889] Bi3.01+, [...","[[3.1531085 2.54254466 5.83455571] Bi3.01+, [...",...,"[[1.09649871 3.33136373 3.76381899] Bi2.99+, [...","[[1.09648345 3.33145415 3.7638046 ] Bi2.99+, [...","[[1.09649467 3.33138827 3.76381519] Bi2.99+, [...","[[1.09648887 3.33142224 3.76380974] Bi2.99+, [...","[[1.09649177 3.33140525 3.76381241] Bi2.99+, [...","[[1.09649026 3.33141383 3.76381107] Bi2.99+, [...","[[1.09649101 3.33140988 3.76381169] Bi2.99+, [...","[[1.0958346 3.33199001 3.76284086] Bi2.99+, [...","[[1.09576499 3.33205143 3.76273804] Bi2.99+, [...","[[1.09508903 3.33194814 3.76179106] Bi2.99+, [..."
energy,-322.818221,-326.789146,-326.929496,-327.16552,-327.41749,-327.585568,-327.626605,-327.688237,-327.689669,-327.711614,...,-347.06941,-347.06941,-347.069411,-347.06941,-347.069411,-347.069411,-347.069411,-347.069592,-347.069593,-347.069728
force,"[[-6.872e-05, -0.27632283, 0.0721676], [0.4971...","[[-3.669e-05, -0.0440181, -0.1620237], [0.1171...","[[-3.582e-05, 0.01477313, -0.22599005], [0.034...","[[-1.688e-05, -0.0154473, -0.16731916], [0.000...","[[-1.921e-05, -0.09313938, 0.00081254], [-0.09...","[[-3.92e-05, -0.11631304, 0.01429217], [-0.067...","[[-5.123e-05, -0.13292314, 0.02322127], [-0.04...","[[-2.841e-05, -0.04953218, -0.01495643], [0.00...","[[-4.954e-05, -0.03669455, -0.02294653], [0.01...","[[-4.322e-05, -0.0178642, -0.0270642], [0.0179...",...,"[[-0.00476948, -0.00035319, -0.00736842], [-0....","[[-0.00497167, -0.0019363, -0.00785407], [-0.0...","[[-0.00514037, 0.00022598, -0.00803165], [-0.0...","[[-0.00477231, -0.00305617, -0.00765211], [-0....","[[-0.00526448, 0.00039504, -0.00818143], [-0.0...","[[-0.00495245, -0.00173867, -0.00788635], [-0....","[[-0.00517967, -4.506e-05, -0.00815423], [-0.0...","[[-0.00262996, -0.00362003, -0.00366045], [-0....","[[-0.00176465, -0.00481489, -0.00212935], [-0....","[[-2.587e-05, 0.00267735, 0.00092322], [0.0011..."
stress,"[[-1.2299856670000002, -4.5550000000000004e-06...","[[0.9128690100000001, 1.3227e-05, 8.3881e-05],...","[[1.3243452, 1.6388000000000003e-05, 7.469e-05...","[[1.315491035, 2.7492e-05, 6.8386e-05], [2.751...","[[1.141435605, 2.5294000000000002e-05, 4.77420...","[[0.6808166820000001, 2.6146e-05, 7.4944e-05],...","[[0.27998605, 2.8756000000000004e-05, 9.538e-0...","[[0.208272546, 3.0295000000000004e-05, 7.5592e...","[[0.195216261, 3.2144e-05, 7.3609e-05], [3.216...","[[0.24890751700000002, 3.253e-05, 6.7498000000...",...,"[[3.2947910190000003, 0.000152143, -0.00596234...","[[3.2939431730000006, 0.0002217, -0.005953362]...","[[3.294203065, 9.7631e-05, -0.00590036], [9.77...","[[3.293835067, 0.000125833, -0.005944812000000...","[[3.294353673, 0.00014495, -0.005944424], [0.0...","[[3.2939374960000003, 0.000133136, -0.00594735...","[[3.294100812, 0.000137383, -0.005949990000000...","[[3.2947419340000006, -0.000171238, -0.0040957...","[[3.294760212, -0.00020194100000000002, -0.003...","[[3.297168884, 0.0005292070000000001, -0.00134..."
charge_state,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,...,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0


### Splitting into train-validation-test sets

In [12]:
# Check if the data has already been split (all three label files exist);
# otherwise split it randomly over the outer column level and save the result.

path_to_test_labels       = 'test_labels.txt'
path_to_validation_labels = 'validation_labels.txt'
path_to_train_labels      = 'train_labels.txt'

# Seed of the random split, so that a freshly drawn split is reproducible
split_seed = 0

label_paths = [path_to_test_labels, path_to_validation_labels, path_to_train_labels]

if all(os.path.exists(path) for path in label_paths):
    # Read the saved labels (which are strings).
    # np.atleast_1d keeps a file holding a single label as a one-element list
    # instead of collapsing it into a bare string.
    test_labels, validation_labels, train_labels = [
        np.atleast_1d(np.genfromtxt(path, dtype='str')).tolist()
        for path in label_paths
    ]
else:
    # Define unique labels, wrt the outer column
    unique_labels = np.unique(m3gnet_dataset.columns.get_level_values(0))

    # Shuffle the unique labels in place with a seeded generator
    np.random.default_rng(split_seed).shuffle(unique_labels)

    # Define the sizes of every set, wrt the number of unique labels.
    # int() rounds down, so with very few labels a set can end up empty.
    test_size       = int(test_ratio       * len(unique_labels))
    validation_size = int(validation_ratio * len(unique_labels))

    # Plain lists, the same type the loading branch produces
    test_labels       = unique_labels[:test_size].tolist()
    validation_labels = unique_labels[test_size:test_size+validation_size].tolist()
    train_labels      = unique_labels[test_size+validation_size:].tolist()

    # Save this splitting for transfer-learning approaches
    np.savetxt(path_to_test_labels,       test_labels,       fmt='%s')
    np.savetxt(path_to_validation_labels, validation_labels, fmt='%s')
    np.savetxt(path_to_train_labels,      train_labels,      fmt='%s')

# Use the loaded/computed labels to generate split datasets
test_dataset       = m3gnet_dataset[test_labels]
validation_dataset = m3gnet_dataset[validation_labels]
train_dataset      = m3gnet_dataset[train_labels]

# Samples are stored as columns, so the sample count is the second dimension
n_test       = np.shape(test_dataset)[1]
n_validation = np.shape(validation_dataset)[1]
n_train      = np.shape(train_dataset)[1]

print(f'Using {n_train} samples to train, {n_validation} to evaluate, and {n_test} to test')

Using 904 samples to train, 285 to evaluate, and 227 to test


### Convert into graph database

In [13]:
# Build one M3GNetDataset per split (train, validation, test).

split_datasets = {
    'train': train_dataset,
    'val':   validation_dataset,
    'test':  test_dataset,
}

# Extract the structures of every split once
structures_per_split = {
    name: dataset.loc['structure'].values.tolist()
    for name, dataset in split_datasets.items()
}

# A single converter is shared by all splits: the element list is taken from
# the structures of the three splits together, so every split is converted
# with the same element list even if one of them lacks some element.
# NOTE(review): when re-training a pre-trained model the element list
# presumably has to match the one that model was trained with -- confirm.
element_types = get_element_list(
    [structure for structures in structures_per_split.values() for structure in structures]
)
converter = Structure2Graph(element_types=element_types, cutoff=5.0)

all_data = []
for name, dataset in split_datasets.items():
    structures = structures_per_split[name]

    # Define data labels from dataset
    labels = {
        'energies': dataset.loc['energy'].values.tolist(),
        'forces':   dataset.loc['force'].values.tolist(),
        'stresses': dataset.loc['stress'].values.tolist(),
    }

    # Generate dataset; the file names carry the split name so the three
    # splits do not overwrite each other
    data = M3GNetDataset(
        filename=f'dgl_graph-{name}.bin',
        filename_line_graph=f'dgl_line_graph-{name}.bin',
        filename_state_attr=f'state_attr-{name}.pt',
        filename_labels=f'labels-{name}.json',
        threebody_cutoff=4.0,
        structures=structures,
        converter=converter,
        labels=labels,
        name=f'M3GNetDataset-{name}',
    )
    all_data.append(data)

# Same order as the keys of `split_datasets`
train_data, val_data, test_data = all_data

100%|██████████| 904/904 [00:02<00:00, 402.58it/s]
100%|██████████| 285/285 [00:00<00:00, 406.23it/s]
100%|██████████| 227/227 [00:00<00:00, 410.03it/s]


In [14]:
# Options common to the three loaders
loader_options = {
    'collate_fn':  collate_fn_efs,
    'batch_size':  2,
    'num_workers': 1,
}

# One loader per split, unpacked in train / validation / test order
data_loaders = MGLDataLoader(
    train_data=train_data,
    val_data=val_data,
    test_data=test_data,
    **loader_options,
)
train_loader, val_loader, test_loader = data_loaders

# Retrain model

In [15]:
# Load the pre-trained M3GNet named by `model_load_path` and keep a handle
# on its `model` attribute, which is the object that gets re-trained
m3gnet_nnp       = matgl.load_model(model_load_path)
model_pretrained = m3gnet_nnp.model

# Training options: weight of the stress term, loss function, learning rate.
# NOTE(review): whether stresses are evaluated at all presumably depends on
# `stress_weight` being non-zero -- confirm against the matgl documentation.
finetune_options = dict(
    stress_weight=stress_weight,
    loss='mse_loss',
    lr=lr,
)

# Lightning module wrapping the pre-trained model for re-training
lit_module_finetune = PotentialLightningModule(model=model_pretrained, **finetune_options)

In [17]:
# Metrics are logged as CSV files under logs/M3GNet_finetuning
logger = CSVLogger('logs', name='M3GNet_finetuning')

# accelerator='auto' selects the appropriate Accelerator; use
# accelerator='cpu' to disable GPU or MPS (M1 mac) training.
# NOTE(review): the recorded run of this cell stopped with
# "DGLError: ... Device API cuda is not enabled" -- the GPU was selected
# while the installed dgl apparently lacks CUDA support; confirm the
# environment or switch to accelerator='cpu'.
# NOTE(review): inference_mode=False is presumably needed because forces
# are obtained through gradients during validation/testing -- confirm.
trainer_options = {
    'max_epochs':     max_epochs,
    'accelerator':    'auto',
    'logger':         logger,
    'inference_mode': False,
}
trainer = pl.Trainer(**trainer_options)

# Re-train on the training loader, validating on the validation loader
trainer.fit(
    model=lit_module_finetune,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

# Save trained model
# NOTE(review): this saves `model_pretrained`, not the object held by
# `lit_module_finetune` -- confirm this is the intended one.
model_pretrained.save(model_save_path)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type              | Params
--------------------------------------------
0 | mae   | MeanAbsoluteError | 0     
1 | rmse  | MeanSquaredError  | 0     
2 | model | Potential         | 288 K 
--------------------------------------------
288 K     Trainable params
0         Non-trainable params
288 K     Total params
1.153     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

DGLError: [17:35:42] /opt/dgl/src/runtime/c_runtime_api.cc:82: Check failed: allow_missing: Device API cuda is not enabled. Please install the cuda version of dgl.
Stack trace:
  [bt] (0) /home/claudio/.local/lib/python3.10/site-packages/dgl/libdgl.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x75) [0x7f8a7a93e8f5]
  [bt] (1) /home/claudio/.local/lib/python3.10/site-packages/dgl/libdgl.so(dgl::runtime::DeviceAPIManager::GetAPI(std::string, bool)+0x202) [0x7f8a7acada92]
  [bt] (2) /home/claudio/.local/lib/python3.10/site-packages/dgl/libdgl.so(dgl::runtime::DeviceAPI::Get(DGLContext, bool)+0x1e1) [0x7f8a7acaa071]
  [bt] (3) /home/claudio/.local/lib/python3.10/site-packages/dgl/libdgl.so(dgl::runtime::NDArray::Empty(std::vector<long, std::allocator<long> >, DGLDataType, DGLContext)+0x13b) [0x7f8a7acc554b]
  [bt] (4) /home/claudio/.local/lib/python3.10/site-packages/dgl/libdgl.so(dgl::runtime::NDArray::CopyTo(DGLContext const&) const+0xc3) [0x7f8a7acffd53]
  [bt] (5) /home/claudio/.local/lib/python3.10/site-packages/dgl/libdgl.so(dgl::UnitGraph::CopyTo(std::shared_ptr<dgl::BaseHeteroGraph>, DGLContext const&)+0x3ff) [0x7f8a7ae0d24f]
  [bt] (6) /home/claudio/.local/lib/python3.10/site-packages/dgl/libdgl.so(dgl::HeteroGraph::CopyTo(std::shared_ptr<dgl::BaseHeteroGraph>, DGLContext const&)+0xf6) [0x7f8a7ad0c5d6]
  [bt] (7) /home/claudio/.local/lib/python3.10/site-packages/dgl/libdgl.so(+0x51b396) [0x7f8a7ad1b396]
  [bt] (8) /home/claudio/.local/lib/python3.10/site-packages/dgl/libdgl.so(DGLFuncCall+0x48) [0x7f8a7aca92a8]



# Analyze metrics

In [13]:
# Evaluate the re-trained module on the test loader.
# Units as stated by the author: E_MAE = meV/atom, F_MAE eV/A, S_MAE GPa
# NOTE(review): confirm the energy unit (meV/atom vs eV/atom) against matgl.
trainer.test(model=lit_module_finetune,
            dataloaders=test_loader
           )

Testing: |          | 0/? [00:00<?, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     test_Energy_MAE        0.1421414613723755
    test_Energy_RMSE        0.1428816020488739
     test_Force_MAE         0.04972851276397705
     test_Force_RMSE        0.08818497508764267
   test_Site_Wise_MAE               0.0
   test_Site_Wise_RMSE              0.0
     test_Stress_MAE                0.0
    test_Stress_RMSE                0.0
     test_Total_Loss         0.313906729221344
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_Total_Loss': 0.313906729221344,
  'test_Energy_MAE': 0.1421414613723755,
  'test_Force_MAE': 0.04972851276397705,
  'test_Stress_MAE': 0.0,
  'test_Site_Wise_MAE': 0.0,
  'test_Energy_RMSE': 0.1428816020488739,
  'test_Force_RMSE': 0.08818497508764267,
  'test_Stress_RMSE': 0.0,
  'test_Site_Wise_RMSE': 0.0}]

In [None]:
# Folder of the training run selected with `current_version`
path_to_csv = f'logs/M3GNet_finetuning/version_{current_version}'

# Load the metrics logged during that run and preview the first rows
metrics_file = f'{path_to_csv}/metrics.csv'
df = pd.read_csv(metrics_file)
df.head()

In [None]:
# Missing entries become zero, so that they do not contribute to the sums
df = df.fillna(0)

# Merge the rows pairwise by position (rows 0-1 -> 0, rows 2-3 -> 1, ...)
# by summing them.
# NOTE(review): this assumes exactly two rows per epoch; any extra row in
# the file shifts or breaks the pairing. Every column is summed, not only
# the metrics. Re-running this cell aggregates the already aggregated
# table again.
row_pair = df.index // 2
df = df.groupby(row_pair).sum()
df.head()

In [None]:
# Plot the logarithm of every training / validation metric.

# Get the list of metric column names (training and validation only)
loss_columns = [col for col in df.columns if col.startswith(('val_', 'train_'))]

# Create a figure and axis (plt.subplots returns both)
fig, ax = plt.subplots(figsize=(10, 6))

# Plot each metric
for loss_column in loss_columns:
    values = df[loss_column]
    # The logarithm is undefined for non-positive values (e.g. zeros that
    # stand for missing entries); mask them so they are simply not drawn
    ax.plot(df.index, np.log(values.where(values > 0)), label=loss_column)

ax.set_xlabel('Epoch')
# The curves show the natural logarithm of the metrics, not the raw values
ax.set_ylabel('log(Loss)')
ax.legend(loc=(1.01, 0))
fig.savefig('m3gnet_loss.eps', dpi=dpi, bbox_inches='tight')
plt.show()

In [45]:
# Validation MAEs (energy, force, stress) in the second-to-last aggregated row
df['val_Energy_MAE'].iloc[-2], df['val_Force_MAE'].iloc[-2], df['val_Stress_MAE'].iloc[-2]

(0.0135606033727526, 0.0874462649226188, 0.0)

In [46]:
# Validation MAEs (energy, force, stress) in the last aggregated row
# NOTE(review): the recorded output is all zeros, so the last row presumably
# holds no validation metrics -- confirm which row is the final epoch.
df['val_Energy_MAE'].iloc[-1], df['val_Force_MAE'].iloc[-1], df['val_Stress_MAE'].iloc[-1]

(0.0, 0.0, 0.0)

# Cleanup the notebook

In [26]:
# Delete the temporary files matching the graph-dataset file name patterns
# (graphs, line graphs, state attributes and labels)

patterns = ['dgl_graph*.bin', 'dgl_line_graph*.bin', 'state_attr*.pt', 'labels*.json']

# Walk lazily through every file matching any of the patterns
temporary_files = (path for pattern in patterns for path in glob.glob(pattern))
for path in temporary_files:
    try:
        os.remove(path)
    except FileNotFoundError:
        # The file is already gone: nothing to clean up
        continue

# The 'logs', 'trained_model' and 'finetuned_model' folders are kept on
# purpose; remove them by hand if needed (e.g. shutil.rmtree, which would
# require `import shutil`).