In [1]:
import numpy             as np
import pytorch_lightning as pl
import ML_library        as MLL
import matplotlib.pyplot as plt
import matgl
import os
import warnings
import torch
import json

from __future__                import annotations
from pytorch_lightning.loggers import CSVLogger
from matgl.ext.pymatgen        import Structure2Graph, get_element_list
from matgl.graph.data          import M3GNetDataset, MGLDataLoader, collate_fn_efs
from matgl.utils.training      import PotentialLightningModule

# To suppress warnings for clearer output
warnings.simplefilter('ignore')

# Check if GPU is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
device

device(type='cpu')

In [3]:
# Whether to include charge (which) or not
charged = 1

model_load_path = 'M3GNet-MP-2021.2.8-PES'
model_save_path = f'finetuned_model-charge{charged}'

# Make folder if not ready
if not os.path.exists(model_save_path):
    os.mkdir(model_save_path)

# 0: material, 1: charge state, 2: ionic step
depth = 1

# Define batch size
batch_size = 64

# Stress weight for training
stress_weight = 0.7  # 0.7

# Ratios for diving training data
test_ratio       = 0.2
validation_ratio = 0.2

# Number of epoch for re-training
max_epochs = 1000

# Learning-rate for re-training
lr = 1e-4

dpi = 100

# Version of training you specifically want to analyze
current_version = 0

# Each folder names a new column, and structure, energy, forces and stresses
# of each ionic step are loaded

# Path to dataset, structured as:
# path_to_dataset
#     material_i
#         defect_i
#             simulation_i (containing vasprun.xml)

path_to_dataset = '/home/cibran/Desktop/defects/no-pressure/HSE06+D3+SOC'
#path_to_dataset = '../../../Desktop/CeO2-data'

# Create and save as a dictionary
model_parameters = {
    'model_load_path':  model_load_path,
    'model_save_path':  model_save_path,
    'charged':          charged,
    'depth':            depth,
    'batch_size':       batch_size,
    'stress_weight':    stress_weight,
    'test_ratio':       test_ratio,
    'validation_ratio': validation_ratio,
    'max_epochs':       max_epochs,
    'lr':               lr,
    'path_to_dataset':  path_to_dataset,
}

# Write the dictionary to the file in JSON format
with open(f'{model_save_path}/model_parameters.json', 'w') as json_file:
    json.dump(model_parameters, json_file)

# Load simulation data

In [4]:
# Extract the data
m3gnet_dataset = MLL.extract_vaspruns_dataset(path_to_dataset, charged=charged)
#m3gnet_dataset = MLL.extract_OUTCAR_dataset(path_to_dataset)
m3gnet_dataset


BiSeI
	as_1_Bi_on_Se_-1
	as_1_Bi_on_Se_-2
	as_1_Bi_on_Se_0
	as_1_Bi_on_Se_1
	as_1_Bi_on_Se_2
	as_1_Bi_on_Se_3
	as_1_Bi_on_Se_4
	as_1_Bi_on_Se_5
	as_1_I_on_Bi_-1
	as_1_I_on_Bi_-2
	as_1_I_on_Bi_0
	as_1_I_on_Bi_1
	as_1_I_on_Bi_2
	as_1_I_on_Bi_3
	as_1_I_on_Bi_4
	as_1_I_on_Bi_5
	as_1_Se_on_Bi_-1
	as_1_Se_on_Bi_-2
	as_1_Se_on_Bi_0
	as_1_Se_on_Bi_1
	as_1_Se_on_Bi_2
	as_1_Se_on_Bi_3
	as_1_Se_on_Bi_4
	as_1_Se_on_Bi_5
	as_2_Bi_on_I_-1
	as_2_Bi_on_I_-2
	as_2_Bi_on_I_0
	as_2_Bi_on_I_1
	as_2_Bi_on_I_2
	as_2_Bi_on_I_3
	as_2_Bi_on_I_4
	as_2_Bi_on_I_5
	as_2_I_on_Se_-1
	as_2_I_on_Se_-2
	as_2_I_on_Se_0
	as_2_I_on_Se_1
	as_2_I_on_Se_2
	as_2_I_on_Se_3
	as_2_I_on_Se_4
	as_2_I_on_Se_5
	as_2_Se_on_I_-1
	as_2_Se_on_I_-2
	as_2_Se_on_I_0
	as_2_Se_on_I_1
	as_2_Se_on_I_2
	as_2_Se_on_I_3
	as_2_Se_on_I_4
	as_2_Se_on_I_5
	inter_10_Se_0
	inter_11_Se_0
	inter_12_Se_0
	inter_13_Se_0
	inter_14_Se_0
	inter_15_Se_0
	inter_16_Se_0
	inter_17_I_0
	inter_18_I_0
	inter_19_I_0
	inter_1_Bi_0
	inter_20_I_0
	inter_21_I_0
	inter_2

Unnamed: 0_level_0,BiSeI,BiSeI,BiSeI,BiSeI,BiSeI,BiSeI,BiSeI,BiSeI,BiSeI,BiSeI,...,BiSI,BiSI,BiSI,BiSI,BiSI,BiSI,BiSI,BiSI,BiSI,BiSI
Unnamed: 0_level_1,BiSeI_as_1_Bi_on_Se_1,BiSeI_as_1_Bi_on_Se_1,BiSeI_as_1_Bi_on_Se_1,BiSeI_as_1_Bi_on_Se_1,BiSeI_as_1_Bi_on_Se_1,BiSeI_as_1_Bi_on_Se_1,BiSeI_as_1_Bi_on_Se_1,BiSeI_as_1_Bi_on_Se_1,BiSeI_as_1_Bi_on_Se_1,BiSeI_as_1_Bi_on_Se_1,...,BiSI_vac_1_Bi_1,BiSI_vac_1_Bi_1,BiSI_vac_1_Bi_1,BiSI_vac_1_Bi_1,BiSI_vac_1_Bi_1,BiSI_vac_1_Bi_1,BiSI_vac_1_Bi_1,BiSI_vac_1_Bi_1,BiSI_vac_2_S_1,BiSI_vac_3_I_1
Unnamed: 0_level_2,BiSeI_as_1_Bi_on_Se_1_0,BiSeI_as_1_Bi_on_Se_1_1,BiSeI_as_1_Bi_on_Se_1_2,BiSeI_as_1_Bi_on_Se_1_3,BiSeI_as_1_Bi_on_Se_1_4,BiSeI_as_1_Bi_on_Se_1_5,BiSeI_as_1_Bi_on_Se_1_6,BiSeI_as_1_Bi_on_Se_1_7,BiSeI_as_1_Bi_on_Se_1_8,BiSeI_as_1_Bi_on_Se_1_9,...,BiSI_vac_1_Bi_1_3,BiSI_vac_1_Bi_1_4,BiSI_vac_1_Bi_1_5,BiSI_vac_1_Bi_1_6,BiSI_vac_1_Bi_1_7,BiSI_vac_1_Bi_1_8,BiSI_vac_1_Bi_1_9,BiSI_vac_1_Bi_1_10,BiSI_vac_2_S_1_0,BiSI_vac_3_I_1_0
structure,"[[3.63682019 2.6176977 6.48340485] Bi, [0.877...","[[3.63660081 2.61764926 6.4834071 ] Bi, [0.877...","[[3.63594253 2.61750377 6.48341364] Bi, [0.877...","[[3.63462609 2.61721295 6.48342672] Bi, [0.877...","[[3.63435953 2.61726374 6.48373336] Bi, [0.878...","[[3.63355969 2.61741629 6.48465326] Bi, [0.878...","[[3.63386383 2.61771939 6.48574571] Bi, [0.878...","[[3.63463379 2.61848664 6.48851095] Bi, [0.878...","[[3.63509487 2.61912087 6.48991399] Bi, [0.878...","[[3.62875171 2.61727965 6.47986428] Bi, [0.880...",...,"[[ 1.06528977 11.81172278 3.84842205] Bi, [5....","[[ 1.06527388 11.81202734 3.84834442] Bi, [5....","[[ 1.06526468 11.81218623 3.84830524] Bi, [5....","[[ 1.065245 11.81249473 3.84824849] Bi, [5....","[[ 1.06529179 11.81175778 3.84836642] Bi, [5....","[[ 1.06521095 11.81299816 3.84820139] Bi, [5....","[[ 1.06517186 11.81356455 3.84817353] Bi, [5....","[[ 1.06475368 11.8154952 3.85020289] Bi, [5....","[[0.8136422 3.18635783 3.70505612] Bi, [ 1.05...","[[1.10053179 3.23330857 3.84593827] Bi, [ 1.08..."
energy,-339.39218,-339.392222,-339.392324,-339.392402,-339.39247,-339.392565,-339.392643,-339.392727,-339.392791,-339.390985,...,-348.707182,-348.707324,-348.707394,-348.707514,-348.707195,-348.707701,-348.707865,-348.70831,-349.834732,-353.004082
force,"[[-0.01136962, -0.00251233, 0.00011303], [0.00...","[[-0.00801564, -7.571e-05, 0.00073892], [0.003...","[[-0.0061068, 6.037e-05, 0.00327689], [0.00313...","[[0.00212885, 0.00302581, 0.00705598], [0.0028...","[[0.00330602, 0.00263392, 0.00832432], [0.0018...","[[0.00821323, 0.00368859, 0.01154164], [-0.002...","[[0.00701172, 0.00366979, 0.01088983], [-0.002...","[[0.00401553, 0.00753984, 0.00929694], [-0.004...","[[0.00219873, 0.00648131, 0.00462846], [-0.001...","[[0.00762129, -0.01721987, 0.00596853], [0.002...",...,"[[-0.00017897, 0.00277865, -0.0010999], [-0.00...","[[-0.00019793, 0.00291401, -0.00079171], [-0.0...","[[-0.00021495, 0.0028123, -0.00035399], [-0.00...","[[-0.00018897, 0.00213378, 0.0004285], [-0.002...","[[-0.00028409, 0.00499442, -0.00145247], [-0.0...","[[-0.00026392, 0.0030714, 0.00108771], [-0.003...","[[-0.00031494, 0.00277768, 0.00226798], [-0.00...","[[-4.304e-05, 0.00230257, 0.00234782], [0.0017...","[[0.00333455, 0.00310133, 0.00020627], [0.0006...","[[-0.00027407, -0.00315128, -0.00408567], [0.0..."
stress,"[[1.5709493380000001, 0.009596524, 0.033147923...","[[1.571162329, 0.009614053000000001, 0.0332547...","[[1.571299285, 0.009760056000000001, 0.0335096...","[[1.572271836, 0.009972819, 0.034018356], [0.0...","[[1.5719516860000002, 0.009686527, 0.034541271...","[[1.5661166990000002, 0.008816085, 0.036110992...","[[1.566884794, 0.008378809, 0.0360777650000000...","[[1.565592826, 0.007277780000000001, 0.0361695...","[[1.5647386900000002, 0.007094567, 0.035732678...","[[1.58261107, 0.012235853, 0.029106107], [0.01...",...,"[[1.07809726, 0.0013952130000000002, 0.0025432...","[[1.073682773, 0.0013660270000000001, 0.002495...","[[1.072408514, 0.0013437920000000001, 0.002477...","[[1.071850834, 0.0013084300000000002, 0.002425...","[[1.0730329660000002, 0.0014174590000000003, 0...","[[1.071797598, 0.001234211, 0.002375217], [0.0...","[[1.071568483, 0.001139871, 0.0023203560000000...","[[1.090053006, 0.000189752, 0.0014052780000000...","[[0.7749622810000001, -1.4175e-05, 0.000193823...","[[1.0742299210000001, 0.0015558500000000001, 0..."
charge_state,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,...,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0


# Split data into train-validation-test sets

### Decide if we split in terms of mateiral, defect state or simulation directly

In [5]:
# Clone (copy) the DataFrame
#m3gnet_dataset = source_m3gnet_dataset.copy()

# Remove the outer (top-level) column index up to depth-1 level
for i in range(depth):
    m3gnet_dataset.columns = m3gnet_dataset.columns.droplevel(0)
m3gnet_dataset

Unnamed: 0_level_0,BiSeI_as_1_Bi_on_Se_1,BiSeI_as_1_Bi_on_Se_1,BiSeI_as_1_Bi_on_Se_1,BiSeI_as_1_Bi_on_Se_1,BiSeI_as_1_Bi_on_Se_1,BiSeI_as_1_Bi_on_Se_1,BiSeI_as_1_Bi_on_Se_1,BiSeI_as_1_Bi_on_Se_1,BiSeI_as_1_Bi_on_Se_1,BiSeI_as_1_Bi_on_Se_1,...,BiSI_vac_1_Bi_1,BiSI_vac_1_Bi_1,BiSI_vac_1_Bi_1,BiSI_vac_1_Bi_1,BiSI_vac_1_Bi_1,BiSI_vac_1_Bi_1,BiSI_vac_1_Bi_1,BiSI_vac_1_Bi_1,BiSI_vac_2_S_1,BiSI_vac_3_I_1
Unnamed: 0_level_1,BiSeI_as_1_Bi_on_Se_1_0,BiSeI_as_1_Bi_on_Se_1_1,BiSeI_as_1_Bi_on_Se_1_2,BiSeI_as_1_Bi_on_Se_1_3,BiSeI_as_1_Bi_on_Se_1_4,BiSeI_as_1_Bi_on_Se_1_5,BiSeI_as_1_Bi_on_Se_1_6,BiSeI_as_1_Bi_on_Se_1_7,BiSeI_as_1_Bi_on_Se_1_8,BiSeI_as_1_Bi_on_Se_1_9,...,BiSI_vac_1_Bi_1_3,BiSI_vac_1_Bi_1_4,BiSI_vac_1_Bi_1_5,BiSI_vac_1_Bi_1_6,BiSI_vac_1_Bi_1_7,BiSI_vac_1_Bi_1_8,BiSI_vac_1_Bi_1_9,BiSI_vac_1_Bi_1_10,BiSI_vac_2_S_1_0,BiSI_vac_3_I_1_0
structure,"[[3.63682019 2.6176977 6.48340485] Bi, [0.877...","[[3.63660081 2.61764926 6.4834071 ] Bi, [0.877...","[[3.63594253 2.61750377 6.48341364] Bi, [0.877...","[[3.63462609 2.61721295 6.48342672] Bi, [0.877...","[[3.63435953 2.61726374 6.48373336] Bi, [0.878...","[[3.63355969 2.61741629 6.48465326] Bi, [0.878...","[[3.63386383 2.61771939 6.48574571] Bi, [0.878...","[[3.63463379 2.61848664 6.48851095] Bi, [0.878...","[[3.63509487 2.61912087 6.48991399] Bi, [0.878...","[[3.62875171 2.61727965 6.47986428] Bi, [0.880...",...,"[[ 1.06528977 11.81172278 3.84842205] Bi, [5....","[[ 1.06527388 11.81202734 3.84834442] Bi, [5....","[[ 1.06526468 11.81218623 3.84830524] Bi, [5....","[[ 1.065245 11.81249473 3.84824849] Bi, [5....","[[ 1.06529179 11.81175778 3.84836642] Bi, [5....","[[ 1.06521095 11.81299816 3.84820139] Bi, [5....","[[ 1.06517186 11.81356455 3.84817353] Bi, [5....","[[ 1.06475368 11.8154952 3.85020289] Bi, [5....","[[0.8136422 3.18635783 3.70505612] Bi, [ 1.05...","[[1.10053179 3.23330857 3.84593827] Bi, [ 1.08..."
energy,-339.39218,-339.392222,-339.392324,-339.392402,-339.39247,-339.392565,-339.392643,-339.392727,-339.392791,-339.390985,...,-348.707182,-348.707324,-348.707394,-348.707514,-348.707195,-348.707701,-348.707865,-348.70831,-349.834732,-353.004082
force,"[[-0.01136962, -0.00251233, 0.00011303], [0.00...","[[-0.00801564, -7.571e-05, 0.00073892], [0.003...","[[-0.0061068, 6.037e-05, 0.00327689], [0.00313...","[[0.00212885, 0.00302581, 0.00705598], [0.0028...","[[0.00330602, 0.00263392, 0.00832432], [0.0018...","[[0.00821323, 0.00368859, 0.01154164], [-0.002...","[[0.00701172, 0.00366979, 0.01088983], [-0.002...","[[0.00401553, 0.00753984, 0.00929694], [-0.004...","[[0.00219873, 0.00648131, 0.00462846], [-0.001...","[[0.00762129, -0.01721987, 0.00596853], [0.002...",...,"[[-0.00017897, 0.00277865, -0.0010999], [-0.00...","[[-0.00019793, 0.00291401, -0.00079171], [-0.0...","[[-0.00021495, 0.0028123, -0.00035399], [-0.00...","[[-0.00018897, 0.00213378, 0.0004285], [-0.002...","[[-0.00028409, 0.00499442, -0.00145247], [-0.0...","[[-0.00026392, 0.0030714, 0.00108771], [-0.003...","[[-0.00031494, 0.00277768, 0.00226798], [-0.00...","[[-4.304e-05, 0.00230257, 0.00234782], [0.0017...","[[0.00333455, 0.00310133, 0.00020627], [0.0006...","[[-0.00027407, -0.00315128, -0.00408567], [0.0..."
stress,"[[1.5709493380000001, 0.009596524, 0.033147923...","[[1.571162329, 0.009614053000000001, 0.0332547...","[[1.571299285, 0.009760056000000001, 0.0335096...","[[1.572271836, 0.009972819, 0.034018356], [0.0...","[[1.5719516860000002, 0.009686527, 0.034541271...","[[1.5661166990000002, 0.008816085, 0.036110992...","[[1.566884794, 0.008378809, 0.0360777650000000...","[[1.565592826, 0.007277780000000001, 0.0361695...","[[1.5647386900000002, 0.007094567, 0.035732678...","[[1.58261107, 0.012235853, 0.029106107], [0.01...",...,"[[1.07809726, 0.0013952130000000002, 0.0025432...","[[1.073682773, 0.0013660270000000001, 0.002495...","[[1.072408514, 0.0013437920000000001, 0.002477...","[[1.071850834, 0.0013084300000000002, 0.002425...","[[1.0730329660000002, 0.0014174590000000003, 0...","[[1.071797598, 0.001234211, 0.002375217], [0.0...","[[1.071568483, 0.001139871, 0.0023203560000000...","[[1.090053006, 0.000189752, 0.0014052780000000...","[[0.7749622810000001, -1.4175e-05, 0.000193823...","[[1.0742299210000001, 0.0015558500000000001, 0..."
charge_state,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,...,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0


### Splitting into train-validation-test sets

In [6]:
# Check if data has been already split, else do it randomly

path_to_test_labels       = 'test_labels.txt'
path_to_validation_labels = 'validation_labels.txt'
path_to_train_labels      = 'train_labels.txt'

if os.path.exists(path_to_test_labels) and os.path.exists(path_to_validation_labels) and os.path.exists(path_to_train_labels):
    # Read labels splitting (which are strings)
    test_labels       = np.genfromtxt(path_to_test_labels,       dtype='str').tolist()
    validation_labels = np.genfromtxt(path_to_validation_labels, dtype='str').tolist()
    train_labels      = np.genfromtxt(path_to_train_labels,      dtype='str').tolist()
else:
    # Define unique labels, wrt the outer column
    unique_labels = np.unique(m3gnet_dataset.columns.get_level_values(0))

    # Shuffle the list of unique labels
    np.random.shuffle(unique_labels)

    # Define the sizes of every set
    # Corresponds to the size wrt the number of unique materials in the dataset
    test_size       = int(test_ratio       * len(unique_labels))
    validation_size = int(validation_ratio * len(unique_labels))

    test_labels       = unique_labels[:test_size]
    validation_labels = unique_labels[test_size:test_size+validation_size]
    train_labels      = unique_labels[test_size+validation_size:]
    
    # Save this splitting for transfer-learning approaches
    np.savetxt(path_to_test_labels,       test_labels,       fmt='%s')
    np.savetxt(path_to_validation_labels, validation_labels, fmt='%s')
    np.savetxt(path_to_train_labels,      train_labels,      fmt='%s')

# Use the loaded/computed labels to generate split datasets
test_dataset       = m3gnet_dataset[test_labels]
validation_dataset = m3gnet_dataset[validation_labels]
train_dataset      = m3gnet_dataset[train_labels]
del m3gnet_dataset

n_test       = np.shape(test_dataset)[1]
n_validation = np.shape(validation_dataset)[1]
n_train      = np.shape(train_dataset)[1]

print(f'Using {n_train} samples to train, {n_validation} to evaluate, and {n_test} to test')

Using 315 samples to train, 170 to evaluate, and 108 to test


### Convert into graph database

In [7]:
all_data = []
for i in range(3):  # Iterate over train-validation-test sets
    name    = ['train', 'val', 'test'][i]
    dataset = [train_dataset, validation_dataset, test_dataset][i]

    # Extract data from dataset
    structures    = dataset.loc['structure'].values.tolist()
    element_types = get_element_list(structures)
    converter     = Structure2Graph(element_types=element_types, cutoff=5.0)
    
    # Define data labels from dataset
    if stress_weight == 0:
        stresses = [np.zeros((3, 3)).tolist() for s in structures]
    else:
        stresses = dataset.loc['stress'].values.tolist()

    labels = {
        'energies': dataset.loc['energy'].values.tolist(),
        'forces':   dataset.loc['force'].values.tolist(),
        'stresses': stresses,
    }
    
    # Generate dataset
    data = M3GNetDataset(
        filename=f'dgl_graph-{name}.bin',
        filename_line_graph=f'dgl_line_graph-{name}.bin',
        filename_state_attr=f'state_attr-{name}.pt',
        filename_labels=f'labels-{name}.json',
        threebody_cutoff=4.0,
        structures=structures,
        converter=converter,
        labels=labels,
        name=f'M3GNetDataset-{name}',
    )
    all_data.append(data)

train_data, val_data, test_data = all_data
del all_data, test_dataset, validation_dataset, train_dataset

100%|████████████████████████████████████████| 315/315 [00:01<00:00, 216.32it/s]
100%|████████████████████████████████████████| 170/170 [00:00<00:00, 217.46it/s]
100%|████████████████████████████████████████| 108/108 [00:00<00:00, 222.74it/s]


In [8]:
train_loader, val_loader, test_loader = MGLDataLoader(
    train_data=train_data,
    val_data=val_data,
    test_data=test_data,
    collate_fn=collate_fn_efs,
    batch_size=batch_size,
    num_workers=16,
    pin_memory=False  # True for more rapid data transfer to GPU, pinning memory to RAM
)
del train_data, val_data, test_data

# Retrain model

In [9]:
# Download a pre-trained M3GNet
m3gnet_nnp       = matgl.load_model(model_load_path)
model_pretrained = m3gnet_nnp.model

# Stress and site-wise are added to training loss
# Stresses are being computed (calc_stress=True)
lit_module_finetune = PotentialLightningModule(model=model_pretrained,
                                               stress_weight=stress_weight,
                                               loss='mse_loss',
                                               lr=lr)

In [None]:
# If you wish to disable GPU or MPS (M1 mac) training, use the accelerator='cpu' kwarg.
# accelerator='auto' selects the appropriate Accelerator
logger  = CSVLogger('logs',
                    name='M3GNet_finetuning')

trainer = pl.Trainer(max_epochs=max_epochs,
                     accelerator='auto',
                     logger=logger,
                     inference_mode=False)

trainer.fit(model=lit_module_finetune,
            train_dataloaders=train_loader,
            val_dataloaders=val_loader
           )

# Save trained model
lit_module_finetune.model.save(model_save_path)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type              | Params
--------------------------------------------
0 | mae   | MeanAbsoluteError | 0     
1 | rmse  | MeanSquaredError  | 0     
2 | model | Potential         | 288 K 
--------------------------------------------
288 K     Trainable params
0         Non-trainable params
288 K     Total params
1.153     Total estimated model params size (MB)


Sanity Checking: |                                        | 0/? [00:00<?, ?it/s]

Training: |                                               | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

# Analyze metrics

In [None]:
# E_MAE = meV/atom, F_MAE = eV/A, S_MAE = GPa
trainer.test(model=lit_module_finetune,
            dataloaders=test_loader
           )

In [None]:
import pandas as pd

In [None]:
current_version = 0
# Read the CSV file
path_to_csv = f'logs/M3GNet_finetuning/version_{current_version}'
df = pd.read_csv(f'{path_to_csv}/metrics.csv')
df.head()

In [None]:
# NaN to zero
df = df.fillna(0)

# Calculate the sum of every two consecutive rows
df = df.groupby(df.index // 2).sum()
df.head()

In [None]:
# Get the list of loss column names
loss_columns = [col for col in df.columns if col.startswith('val_') or col.startswith('train_')]

# Create a figure and axis
fig = plt.subplots(figsize=(10, 6))

# Plot each loss
for loss_column in loss_columns:
    plt.plot(df.index, np.log(df[loss_column]), label=loss_column)

plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend(loc=(1.01, 0))
plt.savefig(f'm3gnet_loss.eps', dpi=dpi, bbox_inches='tight')
plt.show()

In [None]:
df['val_Energy_MAE'].iloc[-2], df['val_Force_MAE'].iloc[-2], df['val_Stress_MAE'].iloc[-2]

In [None]:
df['val_Energy_MAE'].iloc[-1], df['val_Force_MAE'].iloc[-1], df['val_Stress_MAE'].iloc[-1]

# Cleanup the notebook

In [19]:
# This code just performs cleanup for this notebook from temporal files

patterns = ['dgl_graph*.bin', 'dgl_line_graph*.bin', 'state_attr*.pt', 'labels*.json', '*labels.txt']
for pattern in patterns:
    files = glob.glob(pattern)
    for file in files:
        try:
            os.remove(file)
        except FileNotFoundError:
            pass

#shutil.rmtree('logs')
#shutil.rmtree('trained_model')
#shutil.rmtree('finetuned_model')

NameError: name 'glob' is not defined