In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn

In [None]:
import cace
from cace.representations.cace_representation import Cace

In [None]:
from ase.io import read,write

In [None]:
# Global hyperparameters shared by the data pipeline and the model.
cutoff = 5       # radial cutoff passed to the dataset, basis, cutoff fn and representation
batch_size = 10  # structures per mini-batch for both the train and valid loaders

# Seed the numpy and torch RNGs (e.g. for weight initialisation) so that
# "Restart & Run All" reproduces the same training trajectory.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)

In [None]:
# XYZ files for the training and held-out test sets.
# NOTE: despite the ``_dir`` suffix these are file paths, not directories.
dataset_root = '../lode-datasets'
train_xyz_dir = f'{dataset_root}/train-id0.xyz'
test_xyz_dir = f'{dataset_root}/test-id0.xyz'

In [None]:
# Load every frame (index ':') of each file; ase.io.read returns a list of Atoms.
train_ase_xyz, test_ase_xyz = [read(path, ':') for path in (train_xyz_dir, test_xyz_dir)]

In [None]:
# Atomic numbers found in the training set; passed to the representation as `zs`.
element_list = cace.tools.get_unique_atomic_number(train_ase_xyz)

In [None]:
# Build the train/valid dataset collection straight from the XYZ files.
# The test file doubles as the validation set here.
# `data_key` maps the names used downstream (losses/metrics use 'energy' and
# 'forces') to the keys stored in the XYZ frames, so the fitted energy is the
# frames' 'inter_energy' label.
collection = cace.tasks.get_dataset_from_xyz(
    train_path=train_xyz_dir,
    valid_path=test_xyz_dir,
    cutoff=cutoff,
    data_key={
        'energy': 'inter_energy',
        'forces': 'forces',
        'distance': 'distance',
    },
)

In [None]:
# Mini-batch loader over the training split of `collection`.
train_loader = cace.tasks.load_data_loader(
    collection=collection,
    data_type='train',
    batch_size=batch_size,
)

In [None]:
# Mini-batch loader over the validation split (built from the test file).
valid_loader = cace.tasks.load_data_loader(
    collection=collection,
    data_type='valid',
    batch_size=batch_size,
)

In [None]:
# Compute device for the model and batches (CPU here).
device = cace.tools.init_device('cpu')

In [None]:
# Take the first validation batch to smoke-test the model components below.
sampled_data = next(iter(valid_loader))

In [None]:
# Move the sample batch onto the compute device.
sampled_data = sampled_data.to(device)

In [None]:
# Display the sample batch to inspect its fields.
sampled_data

In [None]:
from cace.modules import CosineCutoff, MollifierCutoff, PolynomialCutoff
from cace.modules import BesselRBF, GaussianRBF, GaussianRBFCentered

In [None]:
# Radial features: 6 fixed (non-trainable) Bessel functions, combined with a
# polynomial cutoff envelope of order p=5; both share the global cutoff.
radial_basis = BesselRBF(n_rbf=6, trainable=False, cutoff=cutoff)
cutoff_fn = PolynomialCutoff(p=5, cutoff=cutoff)

In [None]:
# CACE representation that produces the node features shared by both the
# long-range and the short-range models below.
# NOTE(review): n_radial_basis (8) differs from the Bessel basis n_rbf (6);
# presumably Cace mixes the raw radial functions into 8 channels -- confirm.
cace_representation = Cace(
    # chemistry / geometry
    zs=element_list,
    cutoff=cutoff,
    cutoff_fn=cutoff_fn,
    radial_basis=radial_basis,
    # feature sizes
    n_atom_basis=3,
    n_radial_basis=8,
    max_l=2,
    max_nu=2,
    # architecture / runtime
    num_message_passing=1,
    forward_features=['atomic_charge'],
    device=device,
    timeit=False,
)

In [None]:
# Per-atom readout of 4 channels from 'node_feats', written under 'q'.
# The Ewald module below consumes 'q' as its input feature. A summary output
# is written under 'tot_q' (presumably the per-structure aggregate -- confirm).
q = cace.modules.Atomwise(
    feature_key=['node_feats'],
    per_atom_output_key='q',
    output_key='tot_q',
    n_layers=3,
    n_hidden=[24, 12],
    n_out=4,
    residual=False,
    add_linear_nn=False,
    bias=True,
)

In [None]:
from cace.modules import EwaldPotential

In [None]:
# Long-range Ewald term driven by the per-atom features stored under 'q'.
# Its result is written to 'ewald_potential' with aggregation_mode='sum'.
ep = EwaldPotential(
    feature_key='q',
    output_key='ewald_potential',
    aggregation_mode='sum',
    dl=3.,
    sigma=1.0,
)

In [None]:
# Long-range forces computed from 'ewald_potential' and stored as 'ewald_forces'.
forces_lr = cace.modules.Forces(
    energy_key='ewald_potential',
    forces_key='ewald_forces',
)

In [None]:
from cace.models.atomistic import NeuralNetworkPotential

In [None]:
# Long-range model: representation -> 'q' head -> Ewald energy -> forces.
# The output modules run in the listed order, each reading the previous keys.
nnp_lr = NeuralNetworkPotential(
    representation=cace_representation,
    input_modules=None,
    output_modules=[q, ep, forces_lr],
)

In [None]:
# Smoke-test the long-range model on the sample batch.
res = nnp_lr(sampled_data)

In [None]:
# Count only the parameters that will receive gradients.
trainable_params = sum(
    param.numel() for param in filter(lambda t: t.requires_grad, nnp_lr.parameters())
)
print("Number of trainable parameters: {}".format(trainable_params))

In [None]:
# Shape of the node features produced by the representation alone.
cace_representation(sampled_data)['node_feats'].shape

In [None]:
# Short-range energy head with a single output, written to 'CACE_energy_intra'.
# No feature_key is given; presumably Atomwise defaults to 'node_feats' -- confirm.
atomwise = cace.modules.Atomwise(
    n_out=1,
    output_key='CACE_energy_intra',
    n_layers=3,
    n_hidden=[24, 12],
    residual=False,
    add_linear_nn=True,
    bias=True,
)

In [None]:
# Short-range forces computed from 'CACE_energy_intra', stored as 'CACE_forces_intra'.
forces = cace.modules.Forces(
    energy_key='CACE_energy_intra',
    forces_key='CACE_forces_intra',
)

In [None]:
from cace.models.atomistic import NeuralNetworkPotential

In [None]:
# Short-range model. It is given the same `cace_representation` object as
# `nnp_lr`, so both potentials share one set of representation weights.
cace_nnp_intra = NeuralNetworkPotential(
    representation=cace_representation,
    input_modules=None,
    output_modules=[atomwise, forces],
)

In [None]:
# Smoke-test the short-range model on the sample batch.
res = cace_nnp_intra(sampled_data)

In [None]:
# Inspect the sample batch again (presumably to see any keys the forward passes added -- verify).
sampled_data

In [None]:
from cace.models import CombinePotential

In [None]:
# Key maps for CombinePotential: each maps a sub-model's output keys onto the
# shared names 'CACE_energy' / 'CACE_forces' that the losses and metrics read.
pot1 = dict(
    CACE_energy='ewald_potential',
    CACE_forces='ewald_forces',
    weight=1.,
)

# NOTE(review): pot2 has no 'weight'; presumably CombinePotential defaults it -- confirm.
pot2 = dict(
    CACE_energy='CACE_energy_intra',
    CACE_forces='CACE_forces_intra',
)

In [None]:
# Combine the long-range (pot1 keys) and short-range (pot2 keys) models into one potential.
combo_p = CombinePotential([nnp_lr, cace_nnp_intra], [pot1,pot2])

In [None]:
from cace.tasks import GetLoss

In [None]:
# Energy MSE with weight 1, used in the first training stage.
energy_loss = GetLoss(
    loss_fn=torch.nn.MSELoss(),
    loss_weight=1,
    target_name='energy',
    predict_name='CACE_energy',
)

In [None]:
# Energy MSE with weight 1000, used in the second (fine-tuning) training stage.
energy_loss_2 = GetLoss(
    loss_fn=torch.nn.MSELoss(),
    loss_weight=1000,
    target_name='energy',
    predict_name='CACE_energy',
)

In [None]:
# Energy MSE with weight 10000.
# NOTE(review): not referenced by any training cell in this notebook.
energy_loss_3 = GetLoss(
    loss_fn=torch.nn.MSELoss(),
    loss_weight=10000,
    target_name='energy',
    predict_name='CACE_energy',
)

In [None]:
# Force MSE with weight 1000, used in both training stages.
force_loss = GetLoss(
    loss_fn=torch.nn.MSELoss(),
    loss_weight=1000,
    target_name='forces',
    predict_name= 'CACE_forces',
)

In [None]:
# Force MSE with weight 100.
# NOTE(review): not referenced by any training cell in this notebook.
force_loss_2 = GetLoss(
    loss_fn=torch.nn.MSELoss(),
    loss_weight=100,
    target_name='forces',
    predict_name= 'CACE_forces',
)

In [None]:
from cace.tools import Metrics

In [None]:
# Energy metric, logged as 'e'. With per_atom=False it is presumably reported
# per structure rather than normalised by atom count -- confirm.
e_metric = Metrics(
    name='e',
    target_name='energy',
    predict_name='CACE_energy',
    per_atom=False,
)

In [None]:
# Force metric, logged as 'f'.
f_metric = Metrics(
    name='f',
    target_name='forces',
    predict_name='CACE_forces',
)

In [None]:
# Dict view of the sample batch, which is what is passed to combo_p below.
sampled_dict = sampled_data.to_dict()

In [None]:
# Forward pass of the combined model in training mode on the sample batch.
sampled_data_result = combo_p(sampled_dict, training=True)

In [None]:
# Sanity-check the energy loss on the sample predictions.
energy_loss(sampled_data_result, sampled_data)

In [None]:
# Sanity-check the force loss on the sample predictions.
force_loss(sampled_data_result, sampled_data)

In [None]:
from cace.tasks.train import TrainingTask

In [None]:
# Stage 1: high learning rate with a light energy weight
# (energy_loss weight 1, force_loss weight 1000).
#
# A new TrainingTask is built on every pass. Presumably each one re-creates its
# optimizer, StepLR scheduler and EMA state (warm restarts). The model weights
# persist because the same `combo_p` object is trained every time.
#
# Alternatives that were tried: optimizer_cls=torch.optim.SGD, or
# scheduler_cls=torch.optim.lr_scheduler.ReduceLROnPlateau with
# scheduler_args={'mode': 'min', 'factor': 0.8, 'patience': 10}.
stage1_restarts = 12
stage1_epochs = 300

optimizer_args = {'lr': 1e-2, 'amsgrad': True}
scheduler_args = {'step_size': 10, 'gamma': 0.9}

for _ in range(stage1_restarts):
    task = TrainingTask(
        model=combo_p,
        losses=[energy_loss, force_loss],
        metrics=[e_metric, f_metric],
        device=device,
        optimizer_args=optimizer_args,
        scheduler_cls=torch.optim.lr_scheduler.StepLR,
        scheduler_args=scheduler_args,
        max_grad_norm=10,
        ema=True,
        ema_start=10,
        warmup_steps=10,
    )
    task.fit(train_loader, valid_loader, epochs=stage1_epochs, screen_nan=False, val_stride=10)

In [None]:
# Stage 2 (fine-tuning): 10x lower learning rate, slower StepLR decay
# (step_size 20), and a 1000x heavier energy weight (energy_loss_2) alongside
# force_loss. Training continues from the `combo_p` weights left by stage 1.
#
# As in stage 1, a new TrainingTask is built on every pass, so the schedule
# presumably warm-restarts while the model weights carry over.
# Alternatives that were tried: optimizer_cls=torch.optim.SGD, or
# scheduler_cls=torch.optim.lr_scheduler.ReduceLROnPlateau with
# scheduler_args={'mode': 'min', 'factor': 0.8, 'patience': 10}.
stage2_restarts = 8
stage2_epochs = 400

optimizer_args = {'lr': 1e-3, 'amsgrad': True}
scheduler_args = {'step_size': 20, 'gamma': 0.9}

for _ in range(stage2_restarts):
    task = TrainingTask(
        model=combo_p,
        losses=[energy_loss_2, force_loss],
        metrics=[e_metric, f_metric],
        device=device,
        optimizer_args=optimizer_args,
        scheduler_cls=torch.optim.lr_scheduler.StepLR,
        scheduler_args=scheduler_args,
        max_grad_norm=10,
        ema=True,
        ema_start=10,
        warmup_steps=10,
    )
    task.fit(train_loader, valid_loader, epochs=stage2_epochs, screen_nan=False, val_stride=10)

In [None]:
# Save via the TrainingTask from the last loop iteration above (it wraps `combo_p`).
task.save_model('model.pth')

In [None]:
# Reload the saved model for inference on CPU. Predictions are read from the
# combined 'CACE_energy' / 'CACE_forces' keys.
evaluator = cace.tasks.EvaluateTask(
    model_path='model.pth',
    device='cpu',
    energy_key='CACE_energy',
    forces_key='CACE_forces',
)

In [None]:
# Predict energies and forces for every frame of both datasets.
pred_train, pred_test = evaluator(train_ase_xyz), evaluator(test_ase_xyz)

In [None]:
# Reference forces of all frames, stacked into one (total_atoms, 3) array per dataset.
# np.concatenate also handles frames with different atom counts. A ragged
# np.array(...) raises in NumPy >= 1.24, or silently builds an object array on
# older versions, which breaks .reshape(-1, 3).
train_f_true = np.concatenate([atoms.get_array('forces') for atoms in train_ase_xyz], axis=0)
test_f_true = np.concatenate([atoms.get_array('forces') for atoms in test_ase_xyz], axis=0)

In [None]:
# Parity plot of the x-component of the forces: reference (x axis) vs model (y axis).
# Assumes pred_*['forces'] rows align with the stacked reference forces -- TODO confirm.
fig, ax1 = plt.subplots(1, figsize=(3, 3))

ax1.plot(train_f_true[:, 0], pred_train['forces'][:, 0], '.', color='blue', label='Train')
ax1.plot(test_f_true[:, 0], pred_test['forces'][:, 0], '.', color='red', label='Test')

# Raw strings: '\m' and '\A' are invalid escape sequences in ordinary string
# literals (SyntaxWarning on Python 3.12+). The rendered label is unchanged.
ax1.set_xlabel(r'Forces [$eV/\mathrm{\AA}$]')
ax1.set_ylabel(r'MLP-LR Forces [$eV/\mathrm{\AA}$]')
ax1.legend()

plt.tight_layout()
plt.show()

In [None]:
def get_property(atoms, info_name):
    """Gather ``frame.info[info_name]`` from every frame into a NumPy array.

    Parameters
    ----------
    atoms : iterable of objects exposing an ``info`` mapping (e.g. ase.Atoms)
    info_name : key looked up in each frame's ``info`` dict

    Returns
    -------
    numpy.ndarray with one entry per frame, in iteration order.

    Raises
    ------
    KeyError
        If a frame's ``info`` lacks ``info_name``.
    """
    values = list(map(lambda frame: frame.info[info_name], atoms))
    return np.array(values)

In [None]:
# Energy vs. distance: reference energies (open circles) and model predictions
# (crosses); blue = train, red = test. Only the test series carry labels, so
# the legend shows one entry per marker style.
train_distance = get_property(train_ase_xyz, 'distance')
test_distance = get_property(test_ase_xyz, 'distance')

fig, ax1 = plt.subplots(1, figsize=(3, 2))

# Reference values: the 'inter_energy' label stored in each frame's info.
ax1.plot(train_distance, get_property(train_ase_xyz, 'inter_energy'),
         'o', color='blue', markerfacecolor='white')
ax1.plot(test_distance, get_property(test_ase_xyz, 'inter_energy'),
         'o', color='red', markerfacecolor='white', label='True')

# Model predictions.
ax1.plot(train_distance, pred_train['energy'], 'x', color='blue')
ax1.plot(test_distance, pred_test['energy'], 'x', color='red', label='MLP-LR')

# Raw string: '\m' and '\A' are invalid escape sequences in an ordinary string
# literal (SyntaxWarning on Python 3.12+). The rendered label is unchanged.
ax1.set_xlabel(r'Distance [$\mathrm{\AA}$]')
ax1.set_ylabel('Energy [eV]')
ax1.legend()

plt.tight_layout()
plt.show()