In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn

In [None]:
import ase

In [None]:
import sys
# Put the parent directory on the import path — presumably so that `import cace`
# in the next cell resolves to a local checkout of the package (confirm).
sys.path.append('../')

In [None]:
import cace
from cace.representations.cace_representation import Cace

In [None]:
def polynomial_function(xsqr, rcut, derivative=False, n_pow=2, prefactor=1.):
    """Polynomial envelope in the *squared* distance.

    Evaluates ``prefactor * (1 - xsqr / rcut**2) ** n_pow``. When ``derivative``
    is true it evaluates instead the derivative of that expression with respect
    to ``xsqr`` (not with respect to the distance itself).

    Args:
        xsqr: squared distance(s); a scalar or a numpy array (broadcasts).
        rcut: cutoff radius; the envelope reaches zero at ``xsqr == rcut**2``.
        derivative: return d/d(xsqr) of the envelope instead of its value.
        n_pow: exponent of the polynomial.
        prefactor: overall scale factor.

    Returns:
        The envelope value or its derivative, with the shape of ``xsqr``.
        No masking is applied, so inputs beyond the cutoff are not clamped to 0.
    """
    rcut_sqr = rcut**2.
    base = 1 - xsqr / rcut_sqr
    if not derivative:
        return prefactor * base**n_pow
    # Chain rule: d(base)/d(xsqr) = -1 / rcut**2.
    dbase_dxsqr = -1. / rcut_sqr
    return prefactor * n_pow * base**(n_pow - 1) * dbase_dxsqr

In [None]:
from ase.io import read, write
from ase.neighborlist import primitive_neighbor_list
import copy

# Augment every QM7b equilibrium geometry with randomly displaced copies. Their
# synthetic targets combine a unit harmonic restoring force towards the
# equilibrium geometry, a short-ranged pairwise repulsion (polynomial_function),
# and additive uniform noise on forces and energy. The synthetic energy goes to
# info['ee'] and the forces to the 'forces' array.
equil_frames = read('../datasets/qm7b/qm7b.xyz', ':')
augmented_frames = []

step_size = 0.1        # displacement amplitude per unit of `step`
f_noise_level = 0.01   # force-noise amplitude, scaled by sqrt(step)
e_noise_level = 0.01   # energy-noise amplitude, scaled by sqrt(step)
repulsive_rcut = 0.7   # cutoff of the pairwise repulsion (same length unit as positions)
displacement_steps = [1, 2, 3, 4, 6, 8, 10, 12, 16, 32]
augment_seed = 0       # fixes displacements/noise so the written file is reproducible
rng = np.random.default_rng(augment_seed)


def _uniform_centered(scale, shape=None):
    """Draw ``scale * U(-0.5, 0.5)`` samples; returns a scalar when ``shape`` is None."""
    return scale * (rng.random(shape) - 0.5)


def _repulsive_energy_and_forces(positions, pbc, cell, rcut):
    """Energy and per-atom forces of the pair repulsion sum_pairs polynomial_function(d^2, rcut).

    Returns:
        (energy, forces) where forces has the same shape as ``positions``.
    """
    i, j, S = primitive_neighbor_list(
        quantities="ijS",
        pbc=pbc,
        cell=cell,
        positions=positions,
        cutoff=rcut,
        self_interaction=False,
        use_scaled_positions=False,  # positions are Cartesian, not scaled
    )
    D = positions[j] - positions[i] + S @ np.asarray(cell)
    D_sqr = np.sum(D**2., axis=1)
    # The full neighbour list holds both (i, j) and (j, i): halve so each pair counts once.
    energy = np.sum(polynomial_function(D_sqr, rcut)) / 2.
    # F_i = -dE/dr_i = sum_j 2 * phi'(d_ij^2) * (r_j - r_i).
    # np.add.at accumulates over repeated indices in `i`; the fancy-index form
    # `forces[i] += ...` would keep only one neighbour's contribution per atom.
    forces = np.zeros(positions.shape)
    dphi = polynomial_function(D_sqr, rcut, derivative=True)
    np.add.at(forces, i, 2. * dphi[:, None] * D)
    return energy, forces


for ef in equil_frames:
    # Equilibrium geometry: zero synthetic energy and zero forces (mutates the read frame).
    ef.info['ee'] = 0.0
    ef.set_array('forces', np.zeros(ef.positions.shape))
    augmented_frames.append(ef)
    for step in displacement_steps:
        # NOTE(review): deepcopy carries over every info entry of the equilibrium
        # frame (presumably including the reference energy 'ae_pbe0'), so displaced
        # frames keep the equilibrium value under that key — confirm this is intended.
        ef_1 = copy.deepcopy(ef)
        d_pos = _uniform_centered(step * step_size, ef.positions.shape)
        ef_1.positions += d_pos

        e_repulsive, f_repulsive = _repulsive_energy_and_forces(
            ef_1.get_positions(), ef_1.pbc, ef_1.cell, repulsive_rcut)

        # Unit harmonic restoring force towards the equilibrium geometry, plus noise.
        f = -1. * d_pos + _uniform_centered(f_noise_level * step**0.5, ef_1.positions.shape)

        # NOTE(review): the harmonic energy is built from the *noisy* force f rather
        # than from d_pos alone (kept as originally written) — confirm intent.
        # The energy noise is a scalar draw, so info['ee'] is a float, not a 1-element array.
        ef_1.info['ee'] = (0.5 * np.sum(f**2.)
                           + e_repulsive
                           + _uniform_centered(e_noise_level * step**0.5))

        ef_1.set_array('forces', f + f_repulsive)
        augmented_frames.append(ef_1)

write('qm7b-augmented.xyz', augmented_frames)

In [None]:
# Train/validation frames come from the augmented file written above; the test set
# is the untouched QM7b file.
# NOTE(review): displaced frames are deep copies of the equilibrium frames and so
# presumably still carry the equilibrium 'ae_pbe0'; only a force loss is trained
# below, but switch energy_key to 'ee' if energies are ever fitted — confirm.
augmented_xyz_path = 'qm7b-augmented.xyz'
reference_xyz_path = '../datasets/qm7b/qm7b.xyz'
collection = cace.tasks.get_dataset_from_xyz(
    train_path=augmented_xyz_path,
    valid_fraction=0.1,
    test_path=reference_xyz_path,
    energy_key='ae_pbe0',
    forces_key='forces',
)

In [None]:
# Radial cutoff shared by the data loaders and the CACE representation; training batch size.
cutoff, batch_size = 4.5, 20

In [None]:
from cace.tools import torch_geometric

In [None]:
from cace.tools.torch_geometric import dataloader

In [None]:
# Training split, batched with the configured training batch size.
train_loader = cace.tasks.load_data_loader(
    data_type='train',
    collection=collection,
    cutoff=cutoff,
    batch_size=batch_size,
)

In [None]:
# Validation split; evaluation uses larger batches (100) than training.
valid_loader = cace.tasks.load_data_loader(
    data_type='valid',
    collection=collection,
    cutoff=cutoff,
    batch_size=100,
)

In [None]:
# Test split (the original QM7b frames), batched like the validation split.
test_loader = cace.tasks.load_data_loader(
    data_type='test',
    collection=collection,
    cutoff=cutoff,
    batch_size=100,
)

In [None]:
# All model and data tensors in this notebook are placed on the CPU.
device = cace.tools.init_device('cpu')

In [None]:
# Grab a single validation batch to smoke-test the representation before training.
sampled_data = next(iter(valid_loader))

In [None]:
# Move the sample batch onto the compute device.
sampled_data = sampled_data.to(device)

In [None]:
from cace.modules import CosineCutoff, MollifierCutoff, PolynomialCutoff
from cace.modules import BesselRBF, GaussianRBF, GaussianRBFCentered

In [None]:
# Raw radial features: 6 trainable Bessel functions on [0, cutoff].
radial_basis = BesselRBF(n_rbf=6, trainable=True, cutoff=cutoff)
# Smooth polynomial envelope (p=2) that takes neighbour contributions to zero at the cutoff.
cutoff_fn = PolynomialCutoff(p=2, cutoff=cutoff)

In [None]:
# Atomic numbers handled by the embedding: H, C, N, O, S, Cl.
qm7b_elements = [1, 6, 7, 8, 16, 17]
# Options for the 'Bchi' message-passing type. With num_message_passing=0 they are
# presumably unused — confirm before relying on them.
bchi_options = {'shared_channels': False, 'shared_l': False}

cace_representation = Cace(
    zs=qm7b_elements,
    n_atom_basis=3,
    cutoff=cutoff,
    cutoff_fn=cutoff_fn,
    radial_basis=radial_basis,
    n_radial_basis=8,
    max_l=3,
    max_nu=3,
    num_message_passing=0,
    type_message_passing=["Bchi"],
    args_message_passing={'Bchi': bchi_options},
    device=device,
    timeit=False,
)

In [None]:
# Move the representation's parameters to the compute device; the returned module is
# left as the cell's last expression so its repr is displayed.
cace_representation.to(device)

In [None]:
%%time
# Smoke test: one timed forward pass of the representation on the sample batch.
# sampled_data was already moved to `device` in an earlier cell, so this extra
# .to(device) is presumably a no-op.
sampled_reps = cace_representation(sampled_data.to(device))

In [None]:
# Atomwise readout head: maps the CACE descriptors to the prediction stored under
# 'CACE_energy', which the Forces module reads.
atomwise_hidden_sizes = [24, 12]
atomwise = cace.modules.atomwise.Atomwise(
    output_key='CACE_energy',
    descriptor_output_key='desc',
    n_layers=3,
    n_hidden=atomwise_hidden_sizes,
    residual=False,
    add_linear_nn=True,
)

In [None]:
# Forces head: reads the predicted 'CACE_energy' (presumably differentiating it with
# respect to positions) and writes its prediction under 'CACE_forces'.
forces = cace.modules.forces.Forces(
    forces_key='CACE_forces',
    energy_key='CACE_energy',
)

In [None]:
from cace.models.atomistic import NeuralNetworkPotential

In [None]:
# Full potential: the CACE representation followed by the energy head and then the
# forces head (output modules presumably run in list order).
cace_nnp = NeuralNetworkPotential(
    representation=cace_representation,
    output_modules=[atomwise, forces],
    input_modules=None,
)

In [None]:
# Move the assembled potential to the compute device; the returned module is left as
# the cell's last expression so its repr is displayed.
cace_nnp.to(device)

In [None]:
from cace.tasks import GetLoss

In [None]:
# Mean-squared error between reference 'forces' and predicted 'CACE_forces',
# scaled by a weight of 10000. This is the only loss term trained below.
force_loss = GetLoss(
    loss_weight=10000,
    loss_fn=nn.MSELoss(),
    predict_name='CACE_forces',
    target_name='forces',
)

In [None]:
from cace.tools import Metrics

In [None]:
# Force-error metric, reported under the short name 'f'.
f_metric = Metrics(
    name='f',
    predict_name='CACE_forces',
    target_name='forces',
)

In [None]:
from cace.tasks.train import TrainingTask

In [None]:
# Training setup: forces-only objective, ReduceLROnPlateau scheduling, gradient
# clipping at norm 10, and an EMA of the weights starting after 10 steps.
# 'amsgrad' suggests an Adam-family optimizer — presumably TrainingTask's default (confirm).
optimizer_args = dict(lr=1e-2, amsgrad=True)
scheduler_args = dict(mode='min', factor=0.8, patience=10)

task = TrainingTask(
    model=cace_nnp,
    device=device,
    losses=[force_loss],
    metrics=[f_metric],
    optimizer_args=optimizer_args,
    scheduler_cls=torch.optim.lr_scheduler.ReduceLROnPlateau,
    scheduler_args=scheduler_args,
    max_grad_norm=10,
    warmup_steps=10,
    ema=True,
    ema_start=10,
)


In [None]:
# Train on the augmented split, monitoring the validation split; NaN screening is disabled.
n_epochs = 100
task.fit(train_loader, valid_loader, epochs=n_epochs, screen_nan=False)