# For reference, I have looked at their NeuroChem sublibrary of Torchani in their documentaion here https://aiqm.github.io/torchani/api.html#module-torchani.neurochem

In [68]:
import torch
from torch import Tensor
import numpy
import torchani
from torchani import ANIModel
#from torchviz import make_dot
from torchani import neurochem
from torchani.utils import ChemicalSymbolsToInts as CSTI
import os

#os.environ["PATH"] += os.pathsep + 'C:\Program Files (x86)\Graphviz-2.38\bin'


In [2]:
local = os.path.abspath(os.getcwd())

## I have found the information they have used to transfer their NeuroChem Networks to pyTorch. This is all stored in the 'ani-model-zoo-master' directory accessible here https://github.com/aiqm/ani-model-zoo



#### The Model compromises of 3 parts: AEVcomputer, ANI-Network, and the EnergyShifter(to account for atom's self energy at the end)

In [3]:


#We want to use the ani-2x_8x model so ive set the path to there
model_path = os.path.join(local,r"ani-model-zoo-master\resources\ani-2x_8x")

#Extracting constants used for their AEV computer
const = neurochem.Constants(os.path.join(model_path, 'rHCNOSFCl-5.1R_16-3.5A_a8-4.params'))
aev_computer = torchani.AEVComputer(Rcr=const.Rcr, Rca=const.Rca, EtaR=const.EtaR, ShfR=const.ShfR, 
                                    EtaA=const.EtaA, Zeta=const.Zeta, ShfA=const.ShfA, ShfZ=const.ShfZ,
                                   num_species = const.num_species)



#Exracting the self-atomic energies for the energy shifter
EShifter = neurochem.load_sae(os.path.join(model_path, 'sae_linfit.dat'))

#### Now we can either load a single atomic network or a full model (or an ensemble of models too but I am not sure that will work with transfer learning)

In [77]:
#I do not know what number next to train does in each folder, I've assumed train7 would be the best


#Here we load an H network
H_Network = neurochem.load_atomic_network(os.path.join(model_path, r'train7\networks\ANN-H.nnf'))

#Loading a full model incorprating atoms F O C for example we would instead
FOC_Network = neurochem.load_model(['F', 'O','C'], os.path.join(model_path, r'train7\networks'))

#In our case ofcourse we would like to to load the ani2 model with al the atoms
ANI2_Network = neurochem.load_model(const.species, os.path.join(model_path,r'train7\networks'))

In [65]:
#Testing things on a simple example CO2
species_to_tensor = CSTI(['H','O',"F", 'C'])

coordinates = torch.tensor([[[-1.0,0.0,0.0],
                             [0.0,0.0,0.0],
                             [1.0,0.0,0.0]]])
species = species_to_tensor(['O', 'C', 'O']).unsqueeze(0)

species2 = const.species_to_tensor(['O','C','O']).unsqueeze(0)


aev = aev_computer((species,coordinates))

#This network does the sum
y = FOC_Network(aev)

y = EShifter(y)

species2

tensor([[3, 1, 3]])

### Running the AEV through the H network didn't seem to work and looked like a pain to fix. Instead I overrided the ANIModel (in this case FOC- Model) to return the individual energies

In [76]:
#Testing with H2 Molecule

coordinates1 = torch.tensor([[[0.75,0,0],
                            [0,0,0]]])
species1 = species_to_tensor(['H']).unsqueeze(0)

aev1 = aev_computer((species1, coordinates1))

#z = H_Network(aev1)
#z = EShifter(z)
aev1

#Overriding the FOC-Model
from typing import NamedTuple, Tuple

class SpeciesEnergies(NamedTuple):
    species: Tensor
    energies: Tensor



def overforward(self, species_aev: Tuple[Tensor, Tensor]) -> SpeciesEnergies:
    species, aev = species_aev
    species_ = species.flatten()
    aev = aev.flatten(0, 1)

    output = aev.new_zeros(species_.shape)

    for i, (_, m) in enumerate(self.items()):
        mask = (species_ == i)
        midx = mask.nonzero().flatten()
        if midx.shape[0] > 0:
            input_ = aev.index_select(0, midx)
            output.masked_scatter_(mask, m(input_).flatten())
    output = output.view_as(species)
    return SpeciesEnergies(species, output)


funcType = type(FOC_Network.forward)

FOC_Network.forward = funcType(overforward, FOC_Network)


FOC_Network(aev)
EShifter(FOC_Network(aev))
    

SpeciesEnergies(species=tensor([[1, 3, 1]]), energies=tensor([[-151.2266, -151.3697, -151.2266]], dtype=torch.float64,
       grad_fn=<AddBackward0>))

# Now we have the individual energies :)

## See the single torchani_update.py script I made for the code to do this for ani2 :)