Steps for preparing & training a SchNet model

In [None]:
# ====================================== SchNet Model start here ======================================
# SChNet arch: https://schnetpack.readthedocs.io/en/stable/modules/representation.html#module-schnetpack.representation.schnet
# Prep data: https://schnetpack.readthedocs.io/en/stable/tutorials/tutorial_01_preparing_data.html
# Build model: 
from schnetpack.data.atoms import AtomsData
from ase.io import read
import numpy as np
import schnetpack as spk
import os
# load atoms from xyz file. Here, we only parse the first 10 molecules
# atoms = read(os.path.join("SchNetInput","S2_indole_CAM-B3LYP_6-31G_d.xyz"), index=':')
# print(len(atoms))

In [None]:
# Property columns held by the SchNetPack (ASE-backed) database.
properties = ['energy', 'oscil_strength', 'dipole_moment']
# properties = ['energy', 'dipole_moment']

# Wrap the database file that lives in the run directory. This cell adds no rows;
# rows were added by the (commented-out) parser below when the database was built.
db_path = os.path.join('.', 'SchNetModel2', 'SchNet_energy_oscil_dipole_S1.db')
new_dataset = AtomsData(db_path, available_properties=properties)

# Database construction, kept for provenance. It needs `atoms` from the commented-out
# `read(...)` in the first cell. ASE stores the xyz comment line as keys of `atoms.info`:
# the first key holds the energy, the second the oscillator strength.
# NOTE(review): this parser fills energy and oscil_strength only, not dipole_moment.
# for at in atoms:
#     # SchNetPack wants numpy arrays; scalars must have shape (1,), not ().
#     # float32 is used because GPUs work best with it.
#     energy = np.array([float(list(at.info.keys())[0].replace(',', ''))], dtype=np.float32)
#     oscil_strength = np.array([float(list(at.info.keys())[1])], dtype=np.float32)
#     new_dataset.add_system(at, energy=energy, oscil_strength=oscil_strength)

Visualize molecule

In [None]:
# Re-open the database and inspect its first molecule.
properties = ['energy', 'oscil_strength', 'dipole_moment']
SchNet_data = AtomsData(os.path.join('.','SchNetModel2','SchNet_energy_oscil_dipole_S1.db'), available_properties=properties)
# Read from the dataset opened in THIS cell (not `new_dataset` from an earlier cell), and
# store the per-sample dict under its own name so the `properties` list is not clobbered.
sample_atoms, sample_properties = SchNet_data.get_properties(0)
print('Loaded properties:\n', *['{:s}\n'.format(key) for key in sample_properties.keys()])

from ase.visualize import view
# Last expression, so the notebook renders the x3d widget.
view(sample_atoms, viewer='x3d')

Split data & check dataset statistics

In [None]:
import schnetpack as spk
import os

# Run directory for this model: database, split file, checkpoints and logs all live here.
SchNetModel2 = './SchNetModel2'
# exist_ok avoids the check-then-create race and makes the cell idempotent on re-runs.
os.makedirs(SchNetModel2, exist_ok=True)

In [None]:
# --------------------------------- split data ---------------------------------
# 1500 training and 300 validation molecules; the remainder forms the test set.
# The indices are persisted in split.npz, which the evaluation cell reuses so it
# scores exactly the held-out molecules.
properties = ['energy', 'oscil_strength', 'dipole_moment']
db_file = os.path.join('.', 'SchNetModel2', 'SchNet_energy_oscil_dipole_S1.db')
SchNet_data = AtomsData(db_file, available_properties=properties)

split_path = os.path.join(SchNetModel2, "split.npz")
train, val, test = spk.train_test_split(data=SchNet_data, num_train=1500, num_val=300, split_file=split_path)

# Only the training loader is shuffled; validation order does not affect the metrics.
train_loader = spk.AtomsLoader(train, shuffle=True, batch_size=100)
val_loader = spk.AtomsLoader(val, batch_size=100)

In [None]:
# --------------------------------- check statistical metrics ---------------------------------
# Mean / std. dev. of every target over the training split (train_loader comes from the
# split cell). get_statistics is called with schnetpack's default divide_by_atoms=False,
# so these are per-molecule values, not per-atom ones. The model cell uses them to
# initialise the output heads.
means, stddevs = train_loader.get_statistics(properties)

stat_units = {'energy': '[eV]', 'oscil_strength': '[]', 'dipole_moment': '[]'}
for prop in properties:
    print('Mean {:<24s}: {:12.4f} {}'.format(prop, float(means[prop][0]), stat_units[prop]))
    print('Std. dev. {:<19s}: {:12.4f} {}'.format(prop, float(stddevs[prop][0]), stat_units[prop]))

**Model parameters**

In [None]:
# ------------------------------------ model ----------------------------------------
# One SchNet representation shared by the output heads, plus one head per target.
# `means` / `stddevs` come from the statistics cell (training split).
import torch
n_features = 128 # width of atom embeddings and filters (author's other candidates: 512 / 7 / 6)
# 25 Gaussian distance expansions, 10 interaction blocks, cosine cutoff at 5.0
# (in the dataset's position units, presumably Angstrom — confirm).
schnet = spk.representation.SchNet(
    n_atom_basis=n_features, 
    n_filters=n_features, 
    n_gaussians=25, 
    n_interactions=10,
    cutoff=5., 
    cutoff_network=spk.nn.cutoff.CosineCutoff
)

# Energy head.
# NOTE(review): mean/stddev here are per-molecule statistics (no divide_by_atoms);
# confirm that matches how Atomwise applies mean/stddev before aggregating over atoms.
energy_model = spk.atomistic.Atomwise(
    n_in=n_features,
    property='energy',
    mean=means['energy'][0],
    stddev=stddevs['energy'][0]
    # activation=torch.nn.functional.relu # alternative tried: torch.nn.ReLU
)

# Oscillator-strength head. NOTE: constructed but NOT included in `model` below; the
# evaluation cell derives the oscillator strength from predicted energy and dipole instead.
oscil_model = spk.atomistic.Atomwise( # unused by `model` (see note above)
    n_in=n_features,
    property='oscil_strength',
    mean=means['oscil_strength'][0],
    stddev=stddevs['oscil_strength'][0]
    # activation=torch.nn.functional.relu
)

# Dipole head; predict_magnitude=True requests the dipole magnitude rather than the vector.
dipole_model = spk.atomistic.DipoleMoment(
    n_in=n_features, 
    n_layers=2, 
    n_neurons=None, 
    # activation=torch.nn.functional.relu, # TODO(author): try relu; energy & dipole only; vary interaction blocks; generate model for S1
    property='dipole_moment', 
    predict_magnitude=True, 
    mean=means['dipole_moment'][0], 
    stddev=stddevs['dipole_moment'][0]
)

# QM9 tutorial example (with atom references), kept for reference:
# energy_model = spk.atomistic.Atomwise(n_in=30, atomref=atomrefs[QM9.U0], property=QM9.U0,
#                                    mean=means[QM9.U0], stddev=stddevs[QM9.U0])
# Only the energy and dipole heads are trained; the loss cell likewise uses just these two.
# model = spk.AtomisticModel(representation=schnet, output_modules=[energy_model,oscil_model,dipole_model])
model = spk.AtomisticModel(representation=schnet, output_modules=[energy_model,dipole_model])

**Loss function & optimizer**

In [None]:
# ------------------------------- Loss function -----------------------------------
import torch
from torch.optim import Adam
# Trade-off weight between the two trained targets:
#   loss = (1 - rho) * MSE(energy) + rho * MSE(dipole moment)
# The oscillator-strength head is not part of the model, so it does not enter the loss.
rho_tradeoff = 0.9
properties = ['energy', 'oscil_strength', 'dipole_moment']


def _mse(target, prediction):
    """Mean squared error between a target and a prediction tensor.

    The target is reshaped to the prediction's shape so that a (B,) target and a (B, 1)
    prediction cannot silently broadcast to a (B, B) difference matrix; a size mismatch
    raises instead.
    """
    return torch.mean((target.reshape_as(prediction) - prediction) ** 2)


def mse_loss(batch, result, verbose=False):
    """Weighted MSE loss over the energy and dipole-moment targets.

    Args:
        batch: dict of target tensors from the data loader; 'energy' and
            'dipole_moment' are read.
        result: dict of model predictions with the same keys.
        verbose: if True, print both loss components. Off by default because the
            Trainer calls this for every batch of every epoch, and printing device
            tensors floods the output and forces a host sync.

    Returns:
        0-d tensor ``(1 - rho_tradeoff) * MSE_energy + rho_tradeoff * MSE_dipole``.
    """
    err_sq_energy = _mse(batch[properties[0]], result[properties[0]])
    err_sq_dipole = _mse(batch[properties[2]], result[properties[2]])
    if verbose:
        print("L_e: {:.6f} and L_dipole: {:.6f}".format(err_sq_energy.item(), err_sq_dipole.item()))
    return (1 - rho_tradeoff) * err_sq_energy + rho_tradeoff * err_sq_dipole

# Adam over all parameters of `model`.
# NOTE: SGD was recommended as an alternative (by KietChu).
learning_rate = 5e-4
optimizer = Adam(params=model.parameters(), lr=learning_rate)

Set up logging & trainer

In [None]:
# Before setting up the trainer, remove checkpoints, logs, the best model and TensorBoard
# events of any previous run so training starts from scratch instead of from stale state.
# split.npz is deliberately kept so the train/val/test partition stays fixed across runs.
import shutil

import schnetpack.train as trn

for stale_name in ("checkpoints", "log.csv", "best_model", "tensorboard"):
    stale_path = os.path.join(SchNetModel2, stale_name)
    if os.path.isdir(stale_path) and not os.path.islink(stale_path):
        shutil.rmtree(stale_path)
    elif os.path.lexists(stale_path):
        os.remove(stale_path)

# TensorBoard event files live in their own sub-directory of the run directory.
# Create that directory itself (not just its parent).
tensorboard_dir = os.path.join(SchNetModel2, "tensorboard")
os.makedirs(tensorboard_dir, exist_ok=True)

# Validation MAE of the two trained targets (energy and dipole moment).
metrics = [spk.metrics.MeanAbsoluteError(properties[0]),
           spk.metrics.MeanAbsoluteError(properties[2])]
hooks = [
    trn.TensorboardHook(log_path=tensorboard_dir, metrics=metrics),
    trn.CSVHook(log_path=SchNetModel2, metrics=metrics),
    # Multiply the learning rate by 0.8 after 5 epochs without improvement, down to 1e-4.
    # stop_after_min=False: keep training at the minimum rate instead of stopping there.
    trn.ReduceLROnPlateauHook(
        optimizer,
        patience=5, factor=0.8, min_lr=1e-4,
        stop_after_min=False,
    ),
]

trainer = trn.Trainer(
    model_path=SchNetModel2,
    model=model,
    hooks=hooks,
    loss_fn=mse_loss,
    optimizer=optimizer,
    train_loader=train_loader,
    validation_loader=val_loader,
)

In [None]:
# Load (or reload, if it is already loaded) the TensorBoard notebook extension that
# provides the %tensorboard magic used after training.
%reload_ext tensorboard

**Training & tensorboard logging**

In [None]:
# Train on the GPU when torch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Upper bound on the number of training epochs.
n_epochs = 1000
trainer.train(device=device, n_epochs=n_epochs)

In [None]:
# Presumably used to kill a stale TensorBoard server by PID before restarting it:
# !kill 2022
# %reload_ext tensorboard
# Show the event files written by the TensorboardHook configured in the trainer cell.
%tensorboard --logdir SchNetModel2/tensorboard/

Zip & Download model

In [None]:
# ---------------------- Zip & Download model --------------------
# Archive the whole ./SchNetModel2 run directory (database, split.npz, checkpoints,
# best_model, logs). NOTE(review): the archive name says "SchNetModel1" although it holds
# SchNetModel2/; the evaluation cell unzips this same name, so rename both together.
!zip -r SchNetModel1_energy_oscil_dipole_128_10_5p.zip SchNetModel2/ 

In [None]:
# Restore a previously downloaded run archive.
# NOTE(review): no cell here creates "SchNetModel2_energy_oscil_dipole_10_5p.zip" (the zip
# cell writes SchNetModel1_energy_oscil_dipole_128_10_5p.zip) — confirm which archive is meant.
!unzip SchNetModel2_energy_oscil_dipole_10_5p.zip

**Calculate performance metrics & visualize**

In [None]:
# ------------------------- Performance metrics MAPE --------------------
# !unzip ./SchNetModel2_energy_oscil_dipole_128_10_5p.zip

!unzip ./SchNetModel1_energy_oscil_dipole_128_10_5p.zip

import torch
import matplotlib.pyplot as plt
import numpy as np
import schnetpack as spk
import os
from schnetpack.data.atoms import AtomsData

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
properties = ['energy', 'oscil_strength', 'dipole_moment']
SchNetModel = './SchNetModel2'
SchNet_data = AtomsData(os.path.join('.','SchNetModel2','SchNet_energy_oscil_dipole_S1.db'), available_properties=properties)    
best_model = torch.load(os.path.join(SchNetModel, 'best_model'))

train, val, test = spk.train_test_split(
    data=SchNet_data,
    num_train=1500,
    num_val=300,
    split_file=os.path.join(SchNetModel, "split.npz"),
)
train_loader = spk.AtomsLoader(train, batch_size=100, shuffle=True)
val_loader = spk.AtomsLoader(val, batch_size=100)
test_loader = spk.AtomsLoader(test, batch_size=100)

means, stddevs = test_loader.get_statistics(
    properties
)

energy_error = 0.0
oscil_error = 0.0
dipole_error = 0.0

final_batch = None
final_pred = None
final_batchD = None
final_predD = None
final_batchOs = None
final_predOs = None
for count, batch in enumerate(test_loader):
    # move batch to GPU, if necessary
    batch = {k: v.to(device) for k, v in batch.items()}

    # apply model
    pred = best_model(batch)
    temp0 = batch 
    temp_pred0 = pred
    temp = temp0['energy'].detach().cpu().numpy()
    temp_pred = temp_pred0['energy'].detach().cpu().numpy()
    #--------------------------- dipole --------------------------
    tempD = temp0['dipole_moment'].detach().cpu().numpy()
    temp_predD = temp_pred0['dipole_moment'].detach().cpu().numpy()

    # calculate absolute error of energies
    tmp_energy = torch.sum(torch.abs(pred['energy'] - batch['energy']))
    tmp_energy = tmp_energy.detach().cpu().numpy() # detach from graph & convert to numpy
    energy_error += tmp_energy

    # absolute error for dipole
    tmp_dipole = torch.sum(torch.abs(pred['dipole_moment'] - batch['dipole_moment']))
    tmp_dipole = tmp_dipole.detach().cpu().numpy()
    dipole_error += tmp_dipole

    # ---- log to CSV file -------
    if final_batch is None:
        final_batch = temp.copy()
    else:
        final_batch = np.concatenate((final_batch,temp),axis=None)

    if final_pred is None:
        final_pred = temp_pred.copy()
    else:
        final_pred = np.concatenate((final_pred,temp_pred),axis=None)
    
    #------------------------------------------------------------------
    if final_batchD is None:
        final_batchD = tempD.copy()
    else:
        final_batchD = np.concatenate((final_batchD,tempD),axis=None)

    if final_predD is None:
        final_predD = temp_predD.copy()
    else:
        final_predD = np.concatenate((final_predD,temp_predD),axis=None)

    #------------------------------ Oscil strength ---------------------------
    tmpO = temp0['oscil_strength'].detach().cpu().numpy()
    tmp_oscil = np.sum(np.abs(getOscil(temp_pred,temp_predD)-tmpO))
    oscil_error += tmp_oscil

    if final_batchOs is None:
        final_batchOs = tmpO.copy()
    else:
        final_batchOs = np.concatenate((final_batchOs,tmpO),axis=None)

final_predOs = getOscil(final_pred,final_predD)

plt.scatter(final_batchOs, final_predOs, c="r", alpha=0.5)
plt.xlabel("Ground truth Oscil")
plt.ylabel("Predicted Oscil")
# plt.legend(loc='upper left')

m, b = np.polyfit(final_batchOs, final_predOs, 1)
plt.plot(final_batchOs, m*final_batchOs + b)
plt.show()



plt.scatter(final_batch, final_pred, c="g", alpha=0.5)
plt.xlabel("Ground truth [eV]")
plt.ylabel("Predicted [eV]")
# plt.legend(loc='upper left')

m, b = np.polyfit(final_batch, final_pred, 1)
plt.plot(final_batch, m*final_batch + b)
plt.show()


plt.scatter(final_batchD, final_predD, c="b", alpha=0.5)
plt.xlabel("Ground truth dipole")
plt.ylabel("Predicted dipole")
plt.legend(loc='upper left')

mD, bD = np.polyfit(final_batchD, final_predD, 1)
plt.plot(final_batchD, mD*final_batchD + bD)
plt.show()


np.savetxt('E1_batch_1500_300.csv', final_batch[100:150])
np.savetxt('E1_pred_1500_300.csv', final_pred[100:150])
np.savetxt('oscil1_batch_1500_300.csv', final_batchOs[100:150])
np.savetxt('oscil1_pred_1500_300.csv', final_predOs[100:150])

energy_error /= len(test)
energy_error_percentage = energy_error *100 / means['energy'][0]

oscil_error /= len(test)
oscil_error_percentage = oscil_error *100 / means['oscil_strength'][0]

dipole_error /= len(test)
dipole_error_percentage = dipole_error *100 / means['dipole_moment'][0]

print('Mean S2 energy / atom     : {:12.4f} [eV]'.format(means['energy'][0]))
print('Std. dev. S2 energy / atom: {:12.4f} [ev]'.format(stddevs['energy'][0]))
print('Mean dipole / atom        : {:12.4f} []'.format(means['oscil_strength'][0]))
print('Std. dev.  dipole / atom  : {:12.4f} []'.format(stddevs['oscil_strength'][0]))
print('Mean dipole / atom        : {:12.4f} []'.format(means['dipole_moment'][0]))
print('Std. dev.  dipole / atom  : {:12.4f} []'.format(stddevs['dipole_moment'][0]))

print('\nTest MAE:')
print('    energy     : {:10.3f} eV'.format(energy_error))
print('    WMAPError  : {:10.3f} %'.format(energy_error_percentage))
print('    oscil      : {:10.3f}'.format(oscil_error))
print('    WMAPError  : {:10.3f} %'.format(oscil_error_percentage))
print('    dipole     : {:10.3f}'.format(dipole_error))
print('    WMAPError  : {:10.3f} %'.format(dipole_error_percentage))