In [1]:
import numpy             as np
import pytorch_lightning as pl
import ML_library        as MLL
import matplotlib.pyplot as plt
import matgl
import os
import warnings
import torch
import json

from __future__                import annotations
from pytorch_lightning.loggers import CSVLogger
from matgl.ext.pymatgen        import Structure2Graph, get_element_list
from matgl.graph.data          import M3GNetDataset, MGLDataLoader, collate_fn_efs
from matgl.utils.training      import PotentialLightningModule

# To suppress warnings for clearer output
warnings.simplefilter('ignore')

# Check if GPU is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')

In [2]:
device

device(type='cpu')

In [3]:
# Whether to include charge (which) or not
charged = 1

model_load_path = 'M3GNet-MP-2021.2.8-PES'
model_save_path = f'finetuned_model-charge{charged}'

# 0: material, 1: charge state, 2: ionic step
depth = 1

# Define batch size
batch_size = 64

# Stress weight for training
stress_weight = 0.7  # 0.7

# Ratios for diving training data
test_ratio       = 0.2
validation_ratio = 0.2

# Number of epoch for re-training
max_epochs = 180

# Learning-rate for re-training
lr = 1e-4

dpi = 100

# Version of training you specifically want to analyze
current_version = 0

# Each folder names a new column, and structure, energy, forces and stresses
# of each ionic step are loaded

# Path to dataset, structured as:
# path_to_dataset
#     material_i
#         defect_i
#             simulation_i (containing vasprun.xml)

path_to_dataset = '/home/claudio/Desktop/BiSBr-example'
#path_to_dataset = '../../../Desktop/CeO2-data'

# Create and save as a dictionary
model_parameters = {
    'model_load_path':  model_load_path,
    'model_save_path':  model_save_path,
    'charged':          charged,
    'depth':            depth,
    'batch_size':       batch_size,
    'stress_weight':    stress_weight,
    'test_ratio':       test_ratio,
    'validation_ratio': validation_ratio,
    'max_epochs':       max_epochs,
    'lr':               lr,
    'path_to_dataset':  path_to_dataset,
}

# Write the dictionary to the file in JSON format
with open(f'{model_save_path}/model_parameters.json', 'w') as json_file:
    json.dump(model_parameters, json_file)

# Load simulation data

In [4]:
# Extract the data
source_m3gnet_dataset = MLL.extract_vaspruns_dataset(path_to_dataset, charged=charged)
#source_m3gnet_dataset = MLL.extract_OUTCAR_dataset(path_to_dataset)
source_m3gnet_dataset


BiSBr
	as_1_Bi_on_S_0
	as_1_Bi_on_S_1
	as_1_S_on_Bi_-1
	as_1_S_on_Bi_3
	as_1_S_on_Bi_5
Error: vasprun not correctly loaded.
	inter_11_S_0
	inter_20_Br_0
	inter_21_Br_0
	vac_3_Br_-1
	vac_3_Br_0
	vac_3_Br_1


Unnamed: 0_level_0,BiSBr,BiSBr,BiSBr,BiSBr,BiSBr,BiSBr,BiSBr,BiSBr,BiSBr,BiSBr,BiSBr,BiSBr,BiSBr,BiSBr,BiSBr,BiSBr,BiSBr,BiSBr,BiSBr,BiSBr,BiSBr
Unnamed: 0_level_1,BiSBr_as_1_Bi_on_S_0,BiSBr_as_1_Bi_on_S_0,BiSBr_as_1_Bi_on_S_0,BiSBr_as_1_Bi_on_S_0,BiSBr_as_1_Bi_on_S_0,BiSBr_as_1_Bi_on_S_0,BiSBr_as_1_Bi_on_S_0,BiSBr_as_1_Bi_on_S_0,BiSBr_as_1_Bi_on_S_0,BiSBr_as_1_Bi_on_S_0,...,BiSBr_vac_3_Br_0,BiSBr_vac_3_Br_0,BiSBr_vac_3_Br_1,BiSBr_vac_3_Br_1,BiSBr_vac_3_Br_1,BiSBr_vac_3_Br_1,BiSBr_vac_3_Br_1,BiSBr_vac_3_Br_1,BiSBr_vac_3_Br_1,BiSBr_vac_3_Br_1
Unnamed: 0_level_2,BiSBr_as_1_Bi_on_S_0_0,BiSBr_as_1_Bi_on_S_0_1,BiSBr_as_1_Bi_on_S_0_2,BiSBr_as_1_Bi_on_S_0_3,BiSBr_as_1_Bi_on_S_0_4,BiSBr_as_1_Bi_on_S_0_5,BiSBr_as_1_Bi_on_S_0_6,BiSBr_as_1_Bi_on_S_0_7,BiSBr_as_1_Bi_on_S_0_8,BiSBr_as_1_Bi_on_S_0_9,...,BiSBr_vac_3_Br_0_21,BiSBr_vac_3_Br_0_22,BiSBr_vac_3_Br_1_0,BiSBr_vac_3_Br_1_1,BiSBr_vac_3_Br_1_2,BiSBr_vac_3_Br_1_3,BiSBr_vac_3_Br_1_4,BiSBr_vac_3_Br_1_5,BiSBr_vac_3_Br_1_6,BiSBr_vac_3_Br_1_7
structure,"[[2.97958563 2.32358318 5.42154828] Bi3+, [0.8...","[[2.98097652 2.32418305 5.42250786] Bi3+, [0.8...","[[2.9851492 2.32598249 5.4253864 ] Bi3+, [0.8...","[[2.99349468 2.32958154 5.43114348] Bi3+, [0.8...","[[2.98802546 2.32722299 5.42737065] Bi3+, [0.8...","[[2.98996737 2.33020912 5.42707484] Bi3+, [0.8...","[[2.9957932 2.33916767 5.4261875 ] Bi3+, [0.8...","[[2.99709059 2.34116287 5.42598992] Bi3+, [0.8...","[[2.99910495 2.34513122 5.42645495] Bi3+, [0.9...","[[3.00127234 2.34940087 5.4269553 ] Bi3+, [0.9...",...,"[[1.07415003 2.87169336 3.46336824] Bi3+, [ 1....","[[1.07431926 2.87205477 3.4636948 ] Bi3+, [ 1....","[[1.04687337 2.90783883 3.45503348] Bi2.99+, [...","[[1.04684978 2.9082431 3.45532443] Bi2.99+, [...","[[1.04677901 2.90945572 3.45619711] Bi2.99+, [...","[[1.04674732 2.9099992 3.45658829] Bi2.99+, [...","[[1.04670643 2.91015694 3.45653947] Bi2.99+, [...","[[1.04658389 2.91063 3.45639327] Bi2.99+, [...","[[1.04659514 2.9105865 3.45640669] Bi2.99+, [...","[[1.04654542 2.91023122 3.45631628] Bi2.99+, [..."
energy,-360.760545,-360.763442,-360.768975,-360.766246,-360.770153,-360.772518,-360.776215,-360.776379,-360.778366,-360.779611,...,-357.810327,-357.810449,-363.449331,-363.450317,-363.452037,-363.452212,-363.452635,-363.453106,-363.453114,-363.453248
force,"[[0.07207828, 0.0310842, 0.04972352], [-0.0055...","[[0.06486033, 0.03654152, 0.0371568], [0.00147...","[[0.03532649, 0.05397151, -0.00280999], [0.019...","[[-0.01159648, 0.08372168, -0.07333601], [0.05...","[[0.02034077, 0.06407088, -0.02805642], [0.032...","[[0.01561428, 0.05459537, -0.01997754], [0.030...","[[0.0070201, 0.02717547, 0.00706469], [0.02906...","[[0.00456369, 0.01994466, 0.01076315], [0.0294...","[[0.01311377, 0.0108779, 0.02011294], [0.02765...","[[0.02714164, -0.01028083, 0.03362521], [0.025...",...,"[[0.00035296, -0.00034773, 0.00183927], [0.000...","[[0.00258558, 0.00496723, 0.00346246], [-0.001...","[[-0.0012218, 0.02094739, 0.01507523], [-0.000...","[[-0.00113164, 0.01738979, 0.01086564], [-0.00...","[[-0.00093064, 0.00168066, 0.00237061], [-0.00...","[[-0.00079548, -0.00142413, -0.00552381], [-0....","[[-0.00087554, -0.00078263, -0.00148098], [-0....","[[-0.0004788, -0.00936268, -0.0014493], [-0.00...","[[-0.00066526, -0.00725528, -0.0013522], [-0.0...","[[-0.00032142, -0.00213019, -0.00111052], [-0...."
stress,"[[-2.1208515980000002, 0.010806431, -0.0143243...","[[-2.1110432200000004, 0.010463909, -0.0126552...","[[-2.0844480300000003, 0.009275477, -0.0078117...","[[-2.035588583, 0.007574436, 0.001602463000000...","[[-2.06438167, 0.008902324000000001, -0.004132...","[[-2.040863056, 0.008602743000000001, -0.00327...","[[-1.9787528900000002, 0.008207387, -0.0001403...","[[-1.965631174, 0.008343637, 0.000813288000000...","[[-1.9737114, 0.009291921, -0.0013492830000000...","[[-1.975341412, 0.010532467, -0.00380604800000...",...,"[[-1.501046807, 0.061856696, 0.017054780000000...","[[-1.497520691, 0.061590981, 0.01633381], [0.0...","[[-0.664278816, 0.0024805480000000004, 0.00107...","[[-0.6629477220000001, 0.0024513, 0.00109285],...","[[-0.6497951940000001, 0.002381179, 0.00115705...","[[-0.649715746, 0.002347, 0.001180398], [0.002...","[[-0.643747186, 0.002356341, 0.001179922], [0....","[[-0.578247272, 0.0023764560000000003, 0.00116...","[[-0.5848293400000001, 0.002371498, 0.00116914...","[[-0.586727318, 0.002406426, 0.001171773], [0...."
charge_state,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0


In [5]:
len(source_m3gnet_dataset)

5

# Split data into train-validation-test sets

### Decide if we split in terms of mateiral, defect state or simulation directly

In [6]:
# Clone (copy) the DataFrame
m3gnet_dataset = source_m3gnet_dataset.copy()

# Remove the outer (top-level) column index up to depth-1 level
for i in range(depth):
    m3gnet_dataset.columns = m3gnet_dataset.columns.droplevel(0)
m3gnet_dataset

Unnamed: 0_level_0,BiSBr_as_1_Bi_on_S_0,BiSBr_as_1_Bi_on_S_0,BiSBr_as_1_Bi_on_S_0,BiSBr_as_1_Bi_on_S_0,BiSBr_as_1_Bi_on_S_0,BiSBr_as_1_Bi_on_S_0,BiSBr_as_1_Bi_on_S_0,BiSBr_as_1_Bi_on_S_0,BiSBr_as_1_Bi_on_S_0,BiSBr_as_1_Bi_on_S_0,...,BiSBr_vac_3_Br_0,BiSBr_vac_3_Br_0,BiSBr_vac_3_Br_1,BiSBr_vac_3_Br_1,BiSBr_vac_3_Br_1,BiSBr_vac_3_Br_1,BiSBr_vac_3_Br_1,BiSBr_vac_3_Br_1,BiSBr_vac_3_Br_1,BiSBr_vac_3_Br_1
Unnamed: 0_level_1,BiSBr_as_1_Bi_on_S_0_0,BiSBr_as_1_Bi_on_S_0_1,BiSBr_as_1_Bi_on_S_0_2,BiSBr_as_1_Bi_on_S_0_3,BiSBr_as_1_Bi_on_S_0_4,BiSBr_as_1_Bi_on_S_0_5,BiSBr_as_1_Bi_on_S_0_6,BiSBr_as_1_Bi_on_S_0_7,BiSBr_as_1_Bi_on_S_0_8,BiSBr_as_1_Bi_on_S_0_9,...,BiSBr_vac_3_Br_0_21,BiSBr_vac_3_Br_0_22,BiSBr_vac_3_Br_1_0,BiSBr_vac_3_Br_1_1,BiSBr_vac_3_Br_1_2,BiSBr_vac_3_Br_1_3,BiSBr_vac_3_Br_1_4,BiSBr_vac_3_Br_1_5,BiSBr_vac_3_Br_1_6,BiSBr_vac_3_Br_1_7
structure,"[[2.97958563 2.32358318 5.42154828] Bi3+, [0.8...","[[2.98097652 2.32418305 5.42250786] Bi3+, [0.8...","[[2.9851492 2.32598249 5.4253864 ] Bi3+, [0.8...","[[2.99349468 2.32958154 5.43114348] Bi3+, [0.8...","[[2.98802546 2.32722299 5.42737065] Bi3+, [0.8...","[[2.98996737 2.33020912 5.42707484] Bi3+, [0.8...","[[2.9957932 2.33916767 5.4261875 ] Bi3+, [0.8...","[[2.99709059 2.34116287 5.42598992] Bi3+, [0.8...","[[2.99910495 2.34513122 5.42645495] Bi3+, [0.9...","[[3.00127234 2.34940087 5.4269553 ] Bi3+, [0.9...",...,"[[1.07415003 2.87169336 3.46336824] Bi3+, [ 1....","[[1.07431926 2.87205477 3.4636948 ] Bi3+, [ 1....","[[1.04687337 2.90783883 3.45503348] Bi2.99+, [...","[[1.04684978 2.9082431 3.45532443] Bi2.99+, [...","[[1.04677901 2.90945572 3.45619711] Bi2.99+, [...","[[1.04674732 2.9099992 3.45658829] Bi2.99+, [...","[[1.04670643 2.91015694 3.45653947] Bi2.99+, [...","[[1.04658389 2.91063 3.45639327] Bi2.99+, [...","[[1.04659514 2.9105865 3.45640669] Bi2.99+, [...","[[1.04654542 2.91023122 3.45631628] Bi2.99+, [..."
energy,-360.760545,-360.763442,-360.768975,-360.766246,-360.770153,-360.772518,-360.776215,-360.776379,-360.778366,-360.779611,...,-357.810327,-357.810449,-363.449331,-363.450317,-363.452037,-363.452212,-363.452635,-363.453106,-363.453114,-363.453248
force,"[[0.07207828, 0.0310842, 0.04972352], [-0.0055...","[[0.06486033, 0.03654152, 0.0371568], [0.00147...","[[0.03532649, 0.05397151, -0.00280999], [0.019...","[[-0.01159648, 0.08372168, -0.07333601], [0.05...","[[0.02034077, 0.06407088, -0.02805642], [0.032...","[[0.01561428, 0.05459537, -0.01997754], [0.030...","[[0.0070201, 0.02717547, 0.00706469], [0.02906...","[[0.00456369, 0.01994466, 0.01076315], [0.0294...","[[0.01311377, 0.0108779, 0.02011294], [0.02765...","[[0.02714164, -0.01028083, 0.03362521], [0.025...",...,"[[0.00035296, -0.00034773, 0.00183927], [0.000...","[[0.00258558, 0.00496723, 0.00346246], [-0.001...","[[-0.0012218, 0.02094739, 0.01507523], [-0.000...","[[-0.00113164, 0.01738979, 0.01086564], [-0.00...","[[-0.00093064, 0.00168066, 0.00237061], [-0.00...","[[-0.00079548, -0.00142413, -0.00552381], [-0....","[[-0.00087554, -0.00078263, -0.00148098], [-0....","[[-0.0004788, -0.00936268, -0.0014493], [-0.00...","[[-0.00066526, -0.00725528, -0.0013522], [-0.0...","[[-0.00032142, -0.00213019, -0.00111052], [-0...."
stress,"[[-2.1208515980000002, 0.010806431, -0.0143243...","[[-2.1110432200000004, 0.010463909, -0.0126552...","[[-2.0844480300000003, 0.009275477, -0.0078117...","[[-2.035588583, 0.007574436, 0.001602463000000...","[[-2.06438167, 0.008902324000000001, -0.004132...","[[-2.040863056, 0.008602743000000001, -0.00327...","[[-1.9787528900000002, 0.008207387, -0.0001403...","[[-1.965631174, 0.008343637, 0.000813288000000...","[[-1.9737114, 0.009291921, -0.0013492830000000...","[[-1.975341412, 0.010532467, -0.00380604800000...",...,"[[-1.501046807, 0.061856696, 0.017054780000000...","[[-1.497520691, 0.061590981, 0.01633381], [0.0...","[[-0.664278816, 0.0024805480000000004, 0.00107...","[[-0.6629477220000001, 0.0024513, 0.00109285],...","[[-0.6497951940000001, 0.002381179, 0.00115705...","[[-0.649715746, 0.002347, 0.001180398], [0.002...","[[-0.643747186, 0.002356341, 0.001179922], [0....","[[-0.578247272, 0.0023764560000000003, 0.00116...","[[-0.5848293400000001, 0.002371498, 0.00116914...","[[-0.586727318, 0.002406426, 0.001171773], [0...."
charge_state,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0


### Splitting into train-validation-test sets

In [7]:
# Check if data has been already split, else do it randomly

path_to_test_labels       = 'test_labels.txt'
path_to_validation_labels = 'validation_labels.txt'
path_to_train_labels      = 'train_labels.txt'

if os.path.exists(path_to_test_labels) and os.path.exists(path_to_validation_labels) and os.path.exists(path_to_train_labels):
    # Read labels splitting (which are strings)
    test_labels       = np.genfromtxt(path_to_test_labels,       dtype='str').tolist()
    validation_labels = np.genfromtxt(path_to_validation_labels, dtype='str').tolist()
    train_labels      = np.genfromtxt(path_to_train_labels,      dtype='str').tolist()
else:
    # Define unique labels, wrt the outer column
    unique_labels = np.unique(m3gnet_dataset.columns.get_level_values(0))

    # Shuffle the list of unique labels
    np.random.shuffle(unique_labels)

    # Define the sizes of every set
    # Corresponds to the size wrt the number of unique materials in the dataset
    test_size       = int(test_ratio       * len(unique_labels))
    validation_size = int(validation_ratio * len(unique_labels))

    test_labels       = unique_labels[:test_size]
    validation_labels = unique_labels[test_size:test_size+validation_size]
    train_labels      = unique_labels[test_size+validation_size:]
    
    # Save this splitting for transfer-learning approaches
    np.savetxt(path_to_test_labels,       test_labels,       fmt='%s')
    np.savetxt(path_to_validation_labels, validation_labels, fmt='%s')
    np.savetxt(path_to_train_labels,      train_labels,      fmt='%s')

# Use the loaded/computed labels to generate split datasets
test_dataset       = m3gnet_dataset[test_labels]
validation_dataset = m3gnet_dataset[validation_labels]
train_dataset      = m3gnet_dataset[train_labels]

n_test       = np.shape(test_dataset)[1]
n_validation = np.shape(validation_dataset)[1]
n_train      = np.shape(train_dataset)[1]

print(f'Using {n_train} samples to train, {n_validation} to evaluate, and {n_test} to test')

Using 78 samples to train, 43 to evaluate, and 43 to test


### Convert into graph database

In [8]:
all_data = []
for i in range(3):  # Iterate over train-validation-test sets
    name    = ['train', 'val', 'test'][i]
    dataset = [train_dataset, validation_dataset, test_dataset][i]

    # Extract data from dataset
    structures    = dataset.loc['structure'].values.tolist()
    element_types = get_element_list(structures)
    converter     = Structure2Graph(element_types=element_types, cutoff=5.0)
    
    # Define data labels from dataset
    if stress_weight == 0:
        stresses = [np.zeros((3, 3)).tolist() for s in structures]
    else:
        stresses = dataset.loc['stress'].values.tolist()

    labels = {
        'energies': dataset.loc['energy'].values.tolist(),
        'forces':   dataset.loc['force'].values.tolist(),
        'stresses': stresses,
    }
    
    # Generate dataset
    data = M3GNetDataset(
        filename=f'dgl_graph-{name}.bin',
        filename_line_graph=f'dgl_line_graph-{name}.bin',
        filename_state_attr=f'state_attr-{name}.pt',
        filename_labels=f'labels-{name}.json',
        threebody_cutoff=4.0,
        structures=structures,
        converter=converter,
        labels=labels,
        name=f'M3GNetDataset-{name}',
    )
    all_data.append(data)

train_data, val_data, test_data = all_data

100%|██████████████████████████████████████████| 78/78 [00:00<00:00, 393.94it/s]
100%|██████████████████████████████████████████| 43/43 [00:00<00:00, 402.99it/s]
100%|██████████████████████████████████████████| 43/43 [00:00<00:00, 401.08it/s]


In [9]:
train_loader, val_loader, test_loader = MGLDataLoader(
    train_data=train_data,
    val_data=val_data,
    test_data=test_data,
    collate_fn=collate_fn_efs,
    batch_size=batch_size,
    num_workers=1,
    pin_memory=True,
)

# Retrain model

In [10]:
# Download a pre-trained M3GNet
m3gnet_nnp       = matgl.load_model(model_load_path)
model_pretrained = m3gnet_nnp.model

# Stress and site-wise are added to training loss
# Stresses are being computed (calc_stress=True)
lit_module_finetune = PotentialLightningModule(model=model_pretrained,
                                               stress_weight=stress_weight,
                                               loss='mse_loss',
                                               lr=lr)

In [11]:
# If you wish to disable GPU or MPS (M1 mac) training, use the accelerator='cpu' kwarg.
# accelerator='auto' selects the appropriate Accelerator
logger  = CSVLogger('logs',
                    name='M3GNet_finetuning')

trainer = pl.Trainer(max_epochs=max_epochs,
                     accelerator='cpu',
                     logger=logger,
                     inference_mode=False)

trainer.fit(model=lit_module_finetune,
            train_dataloaders=train_loader,
            val_dataloaders=val_loader
           )

# Save trained model
lit_module_finetune.model.save(model_save_path)

GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type              | Params
--------------------------------------------
0 | mae   | MeanAbsoluteError | 0     
1 | rmse  | MeanSquaredError  | 0     
2 | model | Potential         | 288 K 
--------------------------------------------
288 K     Trainable params
0         Non-trainable params
288 K     Total params
1.153     Total estimated model params size (MB)


Epoch 0: 100%|██████████████████████████| 2/2 [00:04<00:00,  0.42it/s, v_num=35]
Validation: |                                             | 0/? [00:00<?, ?it/s][A
Validation:   0%|                                         | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                            | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0: 100%|████████████████████| 1/1 [00:02<00:00,  0.42it/s][A
Epoch 1: 100%|█| 2/2 [00:06<00:00,  0.32it/s, v_num=35, val_Total_Loss=92.60, va[A
Validation: |                                             | 0/? [00:00<?, ?it/s][A
Validation:   0%|                                         | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                            | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0: 100%|████████████████████| 1/1 [00:02<00:00,  0.44it/s][A
Epoch 2: 100%|█| 2/2 [00:06<00:00,  0.30it/s, v_num=35, val_Total_Loss=84.20, va[A
Validation: |                                             | 0/? [00:00<?, ?it/s

`Trainer.fit` stopped: `max_epochs=180` reached.


Epoch 179: 100%|█| 2/2 [00:09<00:00,  0.21it/s, v_num=35, val_Total_Loss=0.175, 


# Analyze metrics

In [12]:
# E_MAE = meV/atom, F_MAE = eV/A, S_MAE = GPa
trainer.test(model=lit_module_finetune,
            dataloaders=test_loader
           )

Testing DataLoader 0: 100%|███████████████████████| 1/1 [00:02<00:00,  0.46it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     test_Energy_MAE        0.0325804203748703
    test_Energy_RMSE       0.037385132163763046
     test_Force_MAE         0.22744520008563995
     test_Force_RMSE        0.35987406969070435
   test_Site_Wise_MAE               0.0
   test_Site_Wise_RMSE              0.0
     test_Stress_MAE        0.5429314970970154
    test_Stress_RMSE        0.9654031991958618
     test_Total_Loss        0.7833092212677002
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_Total_Loss': 0.7833092212677002,
  'test_Energy_MAE': 0.0325804203748703,
  'test_Force_MAE': 0.22744520008563995,
  'test_Stress_MAE': 0.5429314970970154,
  'test_Site_Wise_MAE': 0.0,
  'test_Energy_RMSE': 0.037385132163763046,
  'test_Force_RMSE': 0.35987406969070435,
  'test_Stress_RMSE': 0.9654031991958618,
  'test_Site_Wise_RMSE': 0.0}]

In [12]:
# E_MAE = meV/atom, F_MAE = eV/A, S_MAE = GPa
trainer.test(model=lit_module_finetune,
            dataloaders=test_loader
           )

Testing: |                                        | 1/? [00:02<00:00,  0.39it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     test_Energy_MAE       0.018371904268860817
    test_Energy_RMSE        0.02523837611079216
     test_Force_MAE         0.2208978682756424
     test_Force_RMSE        0.35096800327301025
   test_Site_Wise_MAE               0.0
   test_Site_Wise_RMSE              0.0
     test_Stress_MAE        0.4628913700580597
    test_Stress_RMSE        0.8336446285247803
     test_Total_Loss        0.6102897524833679
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_Total_Loss': 0.6102897524833679,
  'test_Energy_MAE': 0.018371904268860817,
  'test_Force_MAE': 0.2208978682756424,
  'test_Stress_MAE': 0.4628913700580597,
  'test_Site_Wise_MAE': 0.0,
  'test_Energy_RMSE': 0.02523837611079216,
  'test_Force_RMSE': 0.35096800327301025,
  'test_Stress_RMSE': 0.8336446285247803,
  'test_Site_Wise_RMSE': 0.0}]

In [13]:
import pandas as pd

In [14]:
current_version = 0
# Read the CSV file
path_to_csv = f'logs/M3GNet_finetuning/version_{current_version}'
df = pd.read_csv(f'{path_to_csv}/metrics.csv')
df.head()

FileNotFoundError: [Errno 2] No such file or directory: 'logs/M3GNet_finetuning/version_0/metrics.csv'

In [None]:
# NaN to zero
df = df.fillna(0)

# Calculate the sum of every two consecutive rows
df = df.groupby(df.index // 2).sum()
df.head()

In [None]:
# Get the list of loss column names
loss_columns = [col for col in df.columns if col.startswith('val_') or col.startswith('train_')]

# Create a figure and axis
fig = plt.subplots(figsize=(10, 6))

# Plot each loss
for loss_column in loss_columns:
    plt.plot(df.index, np.log(df[loss_column]), label=loss_column)

plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend(loc=(1.01, 0))
plt.savefig(f'm3gnet_loss.eps', dpi=dpi, bbox_inches='tight')
plt.show()

In [None]:
df['val_Energy_MAE'].iloc[-2], df['val_Force_MAE'].iloc[-2], df['val_Stress_MAE'].iloc[-2]

In [None]:
df['val_Energy_MAE'].iloc[-1], df['val_Force_MAE'].iloc[-1], df['val_Stress_MAE'].iloc[-1]

# Cleanup the notebook

In [None]:
# This code just performs cleanup for this notebook from temporal files

patterns = ['dgl_graph*.bin', 'dgl_line_graph*.bin', 'state_attr*.pt', 'labels*.json', '*labels.txt']
for pattern in patterns:
    files = glob.glob(pattern)
    for file in files:
        try:
            os.remove(file)
        except FileNotFoundError:
            pass

#shutil.rmtree('logs')
#shutil.rmtree('trained_model')
#shutil.rmtree('finetuned_model')