In [1]:
import numpy             as np
import pytorch_lightning as pl
import ML_library        as MLL
import matplotlib.pyplot as plt
import matgl
import os
import warnings
import torch

from __future__                import annotations
from pytorch_lightning.loggers import CSVLogger
from matgl.ext.pymatgen        import Structure2Graph, get_element_list
from matgl.graph.data          import M3GNetDataset, MGLDataLoader, collate_fn_efs
from matgl.utils.training      import PotentialLightningModule

# To suppress warnings for clearer output
warnings.simplefilter('ignore')

# Check if GPU is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
device

device(type='cpu')

In [3]:
model_load_path = 'M3GNet-MP-2021.2.8-PES'
model_save_path = 'finetuned_model'

# Whether to include charge or not
charged = False

# 0: material, 1: charge state, 2: ionic step
depth = 1

# Define batch size
batch_size = 128

# Stress weight for training
stress_weight = 0.7

# Ratios for diving training data
test_ratio       = 0.2
validation_ratio = 0.2

# Number of epoch for re-training
max_epochs = 1000

# Learning-rate for re-training
lr = 1e-4

dpi = 100

# Version of training you specifically want to analyze
current_version = 1

# Load simulation data

In [4]:
# Each folder names a new column, and structure, energy, forces and stresses
# of each ionic step are loaded

# Path to dataset, structured as:
# path_to_dataset
#     material_i
#         defect_i
#             simulation_i (containing vasprun.xml)

path_to_dataset = 'data/database/SOC'
#path_to_dataset = '../../../Desktop/CeO2-data'

# Extract the data
source_m3gnet_dataset = MLL.extract_vaspruns_dataset(path_to_dataset, charged=charged)
#source_m3gnet_dataset = MLL.extract_OUTCAR_dataset(path_to_dataset)
source_m3gnet_dataset


BiSBr
	vac_1_Bi_-2
	vac_1_Bi_-3
	vac_1_Bi_0
	vac_1_Bi_2
	vac_2_S_-2
	vac_2_S_1
	vac_2_S_2
	vac_3_Br_-1
	vac_3_Br_0


Unnamed: 0_level_0,BiSBr,BiSBr,BiSBr,BiSBr,BiSBr,BiSBr,BiSBr,BiSBr,BiSBr,BiSBr,BiSBr,BiSBr,BiSBr,BiSBr,BiSBr,BiSBr,BiSBr,BiSBr,BiSBr,BiSBr,BiSBr
Unnamed: 0_level_1,BiSBr_vac_1_Bi_-2,BiSBr_vac_1_Bi_-2,BiSBr_vac_1_Bi_-2,BiSBr_vac_1_Bi_-2,BiSBr_vac_1_Bi_-2,BiSBr_vac_1_Bi_-2,BiSBr_vac_1_Bi_-2,BiSBr_vac_1_Bi_-2,BiSBr_vac_1_Bi_-2,BiSBr_vac_1_Bi_-2,...,BiSBr_vac_3_Br_0,BiSBr_vac_3_Br_0,BiSBr_vac_3_Br_0,BiSBr_vac_3_Br_0,BiSBr_vac_3_Br_0,BiSBr_vac_3_Br_0,BiSBr_vac_3_Br_0,BiSBr_vac_3_Br_0,BiSBr_vac_3_Br_0,BiSBr_vac_3_Br_0
Unnamed: 0_level_2,BiSBr_vac_1_Bi_-2_0,BiSBr_vac_1_Bi_-2_1,BiSBr_vac_1_Bi_-2_2,BiSBr_vac_1_Bi_-2_3,BiSBr_vac_1_Bi_-2_4,BiSBr_vac_1_Bi_-2_5,BiSBr_vac_1_Bi_-2_6,BiSBr_vac_1_Bi_-2_7,BiSBr_vac_1_Bi_-2_8,BiSBr_vac_1_Bi_-2_9,...,BiSBr_vac_3_Br_0_19,BiSBr_vac_3_Br_0_20,BiSBr_vac_3_Br_0_21,BiSBr_vac_3_Br_0_22,BiSBr_vac_3_Br_0_23,BiSBr_vac_3_Br_0_24,BiSBr_vac_3_Br_0_25,BiSBr_vac_3_Br_0_26,BiSBr_vac_3_Br_0_27,BiSBr_vac_3_Br_0_28
structure,[[ 1.12288994 11.84915635 3.76785106] Bi2.985...,[[ 1.12305124 11.84916887 3.76785702] Bi2.985...,[[ 1.12345278 11.84917762 3.76786401] Bi2.985...,[[ 1.12356514 11.84918963 3.76786329] Bi2.985...,[[ 1.12378798 11.8491977 3.76786864] Bi2.985...,[[ 1.12373388 11.8491965 3.76786854] Bi2.985...,[[ 1.12351003 11.84918929 3.76784839] Bi2.985...,[[ 1.12387285 11.84919478 3.7678562 ] Bi2.985...,[[ 1.12460783 11.84923424 3.76782196] Bi2.985...,[[ 1.12569264 11.84926118 3.7678011 ] Bi2.985...,...,"[[0.99954117 3.32938725 3.78335491] Bi3+, [ 1....","[[0.99915678 3.32910449 3.78321147] Bi3+, [ 1....","[[0.99868096 3.32875429 3.7830338 ] Bi3+, [ 1....","[[0.99817627 3.32839636 3.78334267] Bi3+, [ 1....","[[0.99760687 3.32799246 3.78369102] Bi3+, [ 1....","[[0.99801396 3.32788676 3.78359324] Bi3+, [ 1....","[[0.99828472 3.32781641 3.78352826] Bi3+, [ 1....","[[0.99855574 3.3283104 3.78301724] Bi3+, [ 1....","[[0.99871817 3.3286069 3.78271074] Bi3+, [ 1....","[[0.99836253 3.32816507 3.78243137] Bi3+, [ 1...."
energy,-346.683161,-346.683238,-346.683418,-346.68347,-346.683566,-346.683546,-346.683451,-346.683604,-346.683912,-346.68434,...,-357.012702,-357.01315,-357.013298,-357.013663,-357.013789,-357.014135,-357.014183,-357.014373,-357.014399,-357.014542
force,"[[0.00835869, 0.00064517, 0.00031], [-0.007161...","[[0.00825922, 0.00017896, 0.00014339], [-0.006...","[[0.00843988, 0.00190129, -0.00032288], [-0.00...","[[0.00845843, 0.0007909, 0.00012468], [-0.0089...","[[0.00914372, 0.0006242, -0.00029772], [-0.007...","[[0.00970161, 0.00084707, -9.275e-05], [-0.007...","[[0.00897098, 0.00166177, -1.475e-05], [-0.006...","[[0.0087766, 0.00174576, -0.00052072], [-0.006...","[[0.00971229, 0.00088492, 0.00024059], [-0.005...","[[0.01202756, 0.00091644, 0.00026965], [-0.005...",...,"[[-0.00848698, -0.00337636, 0.00832999], [0.00...","[[-0.005821, -0.00269221, 0.00724404], [0.0032...","[[-0.00205848, -0.0013421, 0.00560257], [-0.00...","[[0.00239626, -0.00056658, 0.00185966], [-0.00...","[[0.0085215, 0.00191985, -0.00363173], [-0.006...","[[0.00376318, 0.00238075, -0.00449801], [-0.00...","[[0.00028265, 0.00513489, -0.00418222], [0.004...","[[-0.00315621, -0.00314707, -0.00117365], [0.0...","[[-0.00486497, -0.00716668, 0.00134633], [0.00...","[[0.00145237, 0.00149947, 0.00234208], [0.0021..."
stress,"[[1.18883182, -0.045450646000000004, 0.0572963...","[[1.188422523, -0.04534139300000001, 0.0570914...","[[1.187356265, -0.045072649, 0.056845249], [-0...","[[1.1868099900000002, -0.045100825000000004, 0...","[[1.185687683, -0.044955313000000004, 0.056849...","[[1.1855293310000001, -0.045067341000000004, 0...","[[1.1874243260000001, -0.045151206, 0.05688580...","[[1.186097223, -0.044836004000000006, 0.056903...","[[1.183644189, -0.04446411, 0.0571164360000000...","[[1.182371493, -0.043553064, 0.057390124], [-0...",...,"[[2.188154276, 0.002678751, -0.005328923], [0....","[[2.184244166, 0.0014387890000000002, -0.00394...","[[2.180521461, -4.3988000000000006e-05, -0.002...","[[2.170404267, -0.000531909, -0.00198525], [-0...","[[2.1597340970000003, -0.0010707029999999999, ...","[[2.152978238, 0.000176565, -0.001665943], [0....","[[2.148861257, 0.001051386, -0.001601949], [0....","[[2.147951402, 5.776000000000001e-05, -0.00185...","[[2.145998127, -0.000543678, -0.00204980700000...","[[2.14501845, -0.0017708430000000002, -0.00241..."
charge_state,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [5]:
len(source_m3gnet_dataset)

5

# Split data into train-validation-test sets

### Decide if we split in terms of mateiral, defect state or simulation directly

In [6]:
# Clone (copy) the DataFrame
m3gnet_dataset = source_m3gnet_dataset.copy()

# Remove the outer (top-level) column index up to depth-1 level
for i in range(depth):
    m3gnet_dataset.columns = m3gnet_dataset.columns.droplevel(0)

In [7]:
m3gnet_dataset

Unnamed: 0_level_0,BiSBr_vac_1_Bi_-2,BiSBr_vac_1_Bi_-2,BiSBr_vac_1_Bi_-2,BiSBr_vac_1_Bi_-2,BiSBr_vac_1_Bi_-2,BiSBr_vac_1_Bi_-2,BiSBr_vac_1_Bi_-2,BiSBr_vac_1_Bi_-2,BiSBr_vac_1_Bi_-2,BiSBr_vac_1_Bi_-2,...,BiSBr_vac_3_Br_0,BiSBr_vac_3_Br_0,BiSBr_vac_3_Br_0,BiSBr_vac_3_Br_0,BiSBr_vac_3_Br_0,BiSBr_vac_3_Br_0,BiSBr_vac_3_Br_0,BiSBr_vac_3_Br_0,BiSBr_vac_3_Br_0,BiSBr_vac_3_Br_0
Unnamed: 0_level_1,BiSBr_vac_1_Bi_-2_0,BiSBr_vac_1_Bi_-2_1,BiSBr_vac_1_Bi_-2_2,BiSBr_vac_1_Bi_-2_3,BiSBr_vac_1_Bi_-2_4,BiSBr_vac_1_Bi_-2_5,BiSBr_vac_1_Bi_-2_6,BiSBr_vac_1_Bi_-2_7,BiSBr_vac_1_Bi_-2_8,BiSBr_vac_1_Bi_-2_9,...,BiSBr_vac_3_Br_0_19,BiSBr_vac_3_Br_0_20,BiSBr_vac_3_Br_0_21,BiSBr_vac_3_Br_0_22,BiSBr_vac_3_Br_0_23,BiSBr_vac_3_Br_0_24,BiSBr_vac_3_Br_0_25,BiSBr_vac_3_Br_0_26,BiSBr_vac_3_Br_0_27,BiSBr_vac_3_Br_0_28
structure,[[ 1.12288994 11.84915635 3.76785106] Bi2.985...,[[ 1.12305124 11.84916887 3.76785702] Bi2.985...,[[ 1.12345278 11.84917762 3.76786401] Bi2.985...,[[ 1.12356514 11.84918963 3.76786329] Bi2.985...,[[ 1.12378798 11.8491977 3.76786864] Bi2.985...,[[ 1.12373388 11.8491965 3.76786854] Bi2.985...,[[ 1.12351003 11.84918929 3.76784839] Bi2.985...,[[ 1.12387285 11.84919478 3.7678562 ] Bi2.985...,[[ 1.12460783 11.84923424 3.76782196] Bi2.985...,[[ 1.12569264 11.84926118 3.7678011 ] Bi2.985...,...,"[[0.99954117 3.32938725 3.78335491] Bi3+, [ 1....","[[0.99915678 3.32910449 3.78321147] Bi3+, [ 1....","[[0.99868096 3.32875429 3.7830338 ] Bi3+, [ 1....","[[0.99817627 3.32839636 3.78334267] Bi3+, [ 1....","[[0.99760687 3.32799246 3.78369102] Bi3+, [ 1....","[[0.99801396 3.32788676 3.78359324] Bi3+, [ 1....","[[0.99828472 3.32781641 3.78352826] Bi3+, [ 1....","[[0.99855574 3.3283104 3.78301724] Bi3+, [ 1....","[[0.99871817 3.3286069 3.78271074] Bi3+, [ 1....","[[0.99836253 3.32816507 3.78243137] Bi3+, [ 1...."
energy,-346.683161,-346.683238,-346.683418,-346.68347,-346.683566,-346.683546,-346.683451,-346.683604,-346.683912,-346.68434,...,-357.012702,-357.01315,-357.013298,-357.013663,-357.013789,-357.014135,-357.014183,-357.014373,-357.014399,-357.014542
force,"[[0.00835869, 0.00064517, 0.00031], [-0.007161...","[[0.00825922, 0.00017896, 0.00014339], [-0.006...","[[0.00843988, 0.00190129, -0.00032288], [-0.00...","[[0.00845843, 0.0007909, 0.00012468], [-0.0089...","[[0.00914372, 0.0006242, -0.00029772], [-0.007...","[[0.00970161, 0.00084707, -9.275e-05], [-0.007...","[[0.00897098, 0.00166177, -1.475e-05], [-0.006...","[[0.0087766, 0.00174576, -0.00052072], [-0.006...","[[0.00971229, 0.00088492, 0.00024059], [-0.005...","[[0.01202756, 0.00091644, 0.00026965], [-0.005...",...,"[[-0.00848698, -0.00337636, 0.00832999], [0.00...","[[-0.005821, -0.00269221, 0.00724404], [0.0032...","[[-0.00205848, -0.0013421, 0.00560257], [-0.00...","[[0.00239626, -0.00056658, 0.00185966], [-0.00...","[[0.0085215, 0.00191985, -0.00363173], [-0.006...","[[0.00376318, 0.00238075, -0.00449801], [-0.00...","[[0.00028265, 0.00513489, -0.00418222], [0.004...","[[-0.00315621, -0.00314707, -0.00117365], [0.0...","[[-0.00486497, -0.00716668, 0.00134633], [0.00...","[[0.00145237, 0.00149947, 0.00234208], [0.0021..."
stress,"[[1.18883182, -0.045450646000000004, 0.0572963...","[[1.188422523, -0.04534139300000001, 0.0570914...","[[1.187356265, -0.045072649, 0.056845249], [-0...","[[1.1868099900000002, -0.045100825000000004, 0...","[[1.185687683, -0.044955313000000004, 0.056849...","[[1.1855293310000001, -0.045067341000000004, 0...","[[1.1874243260000001, -0.045151206, 0.05688580...","[[1.186097223, -0.044836004000000006, 0.056903...","[[1.183644189, -0.04446411, 0.0571164360000000...","[[1.182371493, -0.043553064, 0.057390124], [-0...",...,"[[2.188154276, 0.002678751, -0.005328923], [0....","[[2.184244166, 0.0014387890000000002, -0.00394...","[[2.180521461, -4.3988000000000006e-05, -0.002...","[[2.170404267, -0.000531909, -0.00198525], [-0...","[[2.1597340970000003, -0.0010707029999999999, ...","[[2.152978238, 0.000176565, -0.001665943], [0....","[[2.148861257, 0.001051386, -0.001601949], [0....","[[2.147951402, 5.776000000000001e-05, -0.00185...","[[2.145998127, -0.000543678, -0.00204980700000...","[[2.14501845, -0.0017708430000000002, -0.00241..."
charge_state,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


### Splitting into train-validation-test sets

In [8]:
# Check if data has been already split, else do it randomly

path_to_test_labels       = 'test_labels.txt'
path_to_validation_labels = 'validation_labels.txt'
path_to_train_labels      = 'train_labels.txt'

if os.path.exists(path_to_test_labels) and os.path.exists(path_to_validation_labels) and os.path.exists(path_to_train_labels):
    # Read labels splitting (which are strings)
    test_labels       = np.genfromtxt(path_to_test_labels,       dtype='str').tolist()
    validation_labels = np.genfromtxt(path_to_validation_labels, dtype='str').tolist()
    train_labels      = np.genfromtxt(path_to_train_labels,      dtype='str').tolist()
else:
    # Define unique labels, wrt the outer column
    unique_labels = np.unique(m3gnet_dataset.columns.get_level_values(0))

    # Shuffle the list of unique labels
    np.random.shuffle(unique_labels)

    # Define the sizes of every set
    # Corresponds to the size wrt the number of unique materials in the dataset
    test_size       = int(test_ratio       * len(unique_labels))
    validation_size = int(validation_ratio * len(unique_labels))

    test_labels       = unique_labels[:test_size]
    validation_labels = unique_labels[test_size:test_size+validation_size]
    train_labels      = unique_labels[test_size+validation_size:]
    
    # Save this splitting for transfer-learning approaches
    np.savetxt(path_to_test_labels,       test_labels,       fmt='%s')
    np.savetxt(path_to_validation_labels, validation_labels, fmt='%s')
    np.savetxt(path_to_train_labels,      train_labels,      fmt='%s')

# Use the loaded/computed labels to generate split datasets
test_dataset       = m3gnet_dataset[test_labels]
validation_dataset = m3gnet_dataset[validation_labels]
train_dataset      = m3gnet_dataset[train_labels]

n_test       = np.shape(test_dataset)[1]
n_validation = np.shape(validation_dataset)[1]
n_train      = np.shape(train_dataset)[1]

print(f'Using {n_train} samples to train, {n_validation} to evaluate, and {n_test} to test')

Using 205 samples to train, 26 to evaluate, and 42 to test


### Convert into graph database

In [ ]:
all_data = []
for i in range(3):  # Iterate over train-validation-test sets
    name    = ['train', 'val', 'test'][i]
    dataset = [train_dataset, validation_dataset, test_dataset][i]
    
    # Extract data from dataset
    structures    = dataset.loc['structure'].values.tolist()
    element_types = get_element_list(structures)
    converter     = Structure2Graph(element_types=element_types, cutoff=5.0)
    
    # Define data labels from dataset
    if stress_weight == 0:
        stresses = [np.zeros((3, 3)).tolist() for s in structures]
    else:
        stresses = dataset.loc['stress'].values.tolist()

    labels = {
        'energies': dataset.loc['energy'].values.tolist(),
        'forces':   dataset.loc['force'].values.tolist(),
        'stresses': stresses,
    }
    
    # Generate dataset
    data = M3GNetDataset(
        filename=f'dgl_graph-{name}.bin',
        filename_line_graph=f'dgl_line_graph-{name}.bin',
        filename_state_attr=f'state_attr-{name}.pt',
        filename_labels=f'labels-{name}.json',
        threebody_cutoff=4.0,
        structures=structures,
        converter=converter,
        labels=labels,
        name=f'M3GNetDataset-{name}',
    )
    all_data.append(data)

train_data, val_data, test_data = all_data

In [10]:
train_loader, val_loader, test_loader = MGLDataLoader(
    train_data=train_data,
    val_data=val_data,
    test_data=test_data,
    collate_fn=collate_fn_efs,
    batch_size=batch_size,
    num_workers=1,
    pin_memory=True,
)

# Retrain model

In [11]:
# Download a pre-trained M3GNet
m3gnet_nnp       = matgl.load_model(model_load_path)
model_pretrained = m3gnet_nnp.model

# Stress and site-wise are added to training loss
# Stresses are being computed (calc_stress=True)
lit_module_finetune = PotentialLightningModule(model=model_pretrained,
                                               stress_weight=stress_weight,
                                               loss='mse_loss',
                                               lr=lr)

In [12]:
# If you wish to disable GPU or MPS (M1 mac) training, use the accelerator='cpu' kwarg.
# accelerator='auto' selects the appropriate Accelerator
max_epochs = 100
logger  = CSVLogger('logs',
                    name='M3GNet_finetuning')

trainer = pl.Trainer(max_epochs=max_epochs,
                     accelerator='auto',
                     logger=logger,
                     inference_mode=False)

trainer.fit(model=lit_module_finetune,
            train_dataloaders=train_loader,
            val_dataloaders=val_loader
           )

# Save trained model
model_pretrained.save(model_save_path)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type              | Params
--------------------------------------------
0 | mae   | MeanAbsoluteError | 0     
1 | rmse  | MeanSquaredError  | 0     
2 | model | Potential         | 288 K 
--------------------------------------------
288 K     Trainable params
0         Non-trainable params
288 K     Total params
1.153     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=100` reached.


# Analyze metrics

In [13]:
# E_MAE = meV/atom, F_MAE = eV/A, S_MAE = GPa
trainer.test(model=lit_module_finetune,
            dataloaders=test_loader
           )

Testing: |          | 0/? [00:00<?, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     test_Energy_MAE        0.10218613594770432
    test_Energy_RMSE        0.10229505598545074
     test_Force_MAE         0.21951214969158173
     test_Force_RMSE        0.33297672867774963
   test_Site_Wise_MAE               0.0
   test_Site_Wise_RMSE              0.0
     test_Stress_MAE                0.0
    test_Stress_RMSE                0.0
     test_Total_Loss        0.12133777886629105
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_Total_Loss': 0.12133777886629105,
  'test_Energy_MAE': 0.10218613594770432,
  'test_Force_MAE': 0.21951214969158173,
  'test_Stress_MAE': 0.0,
  'test_Site_Wise_MAE': 0.0,
  'test_Energy_RMSE': 0.10229505598545074,
  'test_Force_RMSE': 0.33297672867774963,
  'test_Stress_RMSE': 0.0,
  'test_Site_Wise_RMSE': 0.0}]

In [22]:
# E_MAE = meV/atom, F_MAE = eV/A, S_MAE = GPa
trainer.test(model=lit_module_finetune,
            dataloaders=test_loader

Testing: |          | 0/? [00:00<?, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     test_Energy_MAE        0.1976138800382614
    test_Energy_RMSE         0.197726771235466
     test_Force_MAE         0.2671034038066864
     test_Force_RMSE        0.43352213501930237
   test_Site_Wise_MAE               0.0
   test_Site_Wise_RMSE              0.0
     test_Stress_MAE                0.0
    test_Stress_RMSE                0.0
     test_Total_Loss        0.22703731060028076
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_Total_Loss': 0.22703731060028076,
  'test_Energy_MAE': 0.1976138800382614,
  'test_Force_MAE': 0.2671034038066864,
  'test_Stress_MAE': 0.0,
  'test_Site_Wise_MAE': 0.0,
  'test_Energy_RMSE': 0.197726771235466,
  'test_Force_RMSE': 0.43352213501930237,
  'test_Stress_RMSE': 0.0,
  'test_Site_Wise_RMSE': 0.0}]

In [None]:
# Read the CSV file
path_to_csv = f'logs/M3GNet_finetuning/version_{current_version}'
df = pd.read_csv(f'{path_to_csv}/metrics.csv')
df.head()

In [None]:
# NaN to zero
df = df.fillna(0)

# Calculate the sum of every two consecutive rows
df = df.groupby(df.index // 2).sum()
df.head()

In [None]:
# Get the list of loss column names
loss_columns = [col for col in df.columns if col.startswith('val_') or col.startswith('train_')]

# Create a figure and axis
fig = plt.subplots(figsize=(10, 6))

# Plot each loss
for loss_column in loss_columns:
    plt.plot(df.index, np.log(df[loss_column]), label=loss_column)

plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend(loc=(1.01, 0))
plt.savefig(f'm3gnet_loss.eps', dpi=dpi, bbox_inches='tight')
plt.show()

In [45]:
df['val_Energy_MAE'].iloc[-2], df['val_Force_MAE'].iloc[-2], df['val_Stress_MAE'].iloc[-2]

(0.0135606033727526, 0.0874462649226188, 0.0)

In [46]:
df['val_Energy_MAE'].iloc[-1], df['val_Force_MAE'].iloc[-1], df['val_Stress_MAE'].iloc[-1]

(0.0, 0.0, 0.0)

# Cleanup the notebook

In [34]:
# This code just performs cleanup for this notebook from temporal files

patterns = ['dgl_graph*.bin', 'dgl_line_graph*.bin', 'state_attr*.pt', 'labels*.json', '*labels.txt']
for pattern in patterns:
    files = glob.glob(pattern)
    for file in files:
        try:
            os.remove(file)
        except FileNotFoundError:
            pass

#shutil.rmtree('logs')
#shutil.rmtree('trained_model')
#shutil.rmtree('finetuned_model')