In [None]:
import torch
import schnetpack as spk
import math


import numpy as np
from numpy import savetxt

In [None]:
def print_molecule(props):
    """Print the H/C/N/O/F atoms of a molecule and return its z coordinates.

    Parameters
    ----------
    props : mapping
        Must provide ``'_positions'`` (indexed as ``[:, k]`` and converted with
        ``.numpy()``) and ``'_atomic_numbers'`` (indexed per atom).

    Returns
    -------
    The z-coordinate column as produced by ``.numpy()``.

    Atoms whose atomic number is not 1, 6, 7, 8 or 9 are not printed. A blank
    line is printed after the last atom.
    """
    # Split the position tensor into one NumPy column per Cartesian axis.
    positions = props['_positions']
    x, y, z = (positions[:, axis].numpy() for axis in range(3))

    # (atomic number, element symbol) pairs for the elements that get printed.
    known_elements = ((1, 'H'), (6, 'C'), (7, 'N'), (8, 'O'), (9, 'F'))

    for atom in range(len(z)):
        for number, symbol in known_elements:
            if props['_atomic_numbers'][atom] == number:
                print(symbol, x[atom], y[atom], z[atom])
    print('')
    return z

def hook_v0(self, inp_tensor, out_tensor):
    """Forward hook that stores a layer's output in the notebook-global ``v0``.

    The signature follows PyTorch's ``hook(module, input, output)`` convention:
    ``self`` is the module the hook is attached to and ``inp_tensor`` holds the
    inputs it received; neither is used here. In this notebook the hook is
    attached to ``model.representation.interactions[0]``.
    """
    # Rebind the module-level name so the captured output stays readable after
    # the forward pass has returned.
    globals()['v0'] = out_tensor

def hook_v1(self, inp_tensor, out_tensor):
    """Forward hook that stores a layer's output in the notebook-global ``v1``.

    The signature follows PyTorch's ``hook(module, input, output)`` convention:
    ``self`` is the module the hook is attached to and ``inp_tensor`` holds the
    inputs it received; neither is used here. In this notebook the hook is
    attached to ``model.representation.interactions[1]``.
    """
    # Rebind the module-level name so the captured output stays readable after
    # the forward pass has returned.
    globals()['v1'] = out_tensor

def hook_v2(self, inp_tensor, out_tensor):
    """Forward hook that stores a layer's output in the notebook-global ``v2``.

    The signature follows PyTorch's ``hook(module, input, output)`` convention:
    ``self`` is the module the hook is attached to and ``inp_tensor`` holds the
    inputs it received; neither is used here. In this notebook the hook is
    attached to ``model.representation.interactions[2]``.
    """
    # Rebind the module-level name so the captured output stays readable after
    # the forward pass has returned.
    globals()['v2'] = out_tensor
    
def hook_emb(self, inp_tensor, out_tensor):
    """Forward hook that stores a layer's output in the notebook-global ``emb``.

    The signature follows PyTorch's ``hook(module, input, output)`` convention:
    ``self`` is the module the hook is attached to and ``inp_tensor`` holds the
    inputs it received; neither is used here. In this notebook the hook is
    attached to ``model.representation.embedding``.
    """
    # Rebind the module-level name so the captured output stays readable after
    # the forward pass has returned.
    globals()['emb'] = out_tensor
#def convert_2D(number_of_atoms,rep):
#    layer = np.zeros((number_of_atoms,30))
#    for i in range(number_of_atoms):
#        for j in range(30):
#            layer[i][j] = rep[0][i][j]
#    return layer

In [None]:
# Load the QM9 dataset and rebuild the trained SchNet model used by the cells below.
from schnetpack.datasets import QM9

# Directory holding both the split file and the checkpoint of the trained model.
model_dir = '../../../data/trained_models/qm9_i3_30f_10000_5000'

# All inference in this notebook runs on the CPU. The device is defined before
# the checkpoint is read so that torch.load can remap the stored tensors onto it.
device = 'cpu'

qm9data = QM9('./qm9.db', download=True, remove_uncharacterized=True)

# Recreate the train/val/test split stored with the trained model.
train, val, test = spk.data.train_test_split(qm9data, split_file=model_dir + '/split.npz')

# Atom reference data for U0.
atomrefs = qm9data.get_atomref(QM9.U0)

# SchNet representation model.
# NOTE(review): these hyper-parameters presumably have to match the ones the
# checkpoint was trained with, otherwise load_state_dict fails -- confirm.
schnet = spk.representation.SchNet(
    n_atom_basis=30, n_filters=30, n_gaussians=20, n_interactions=3,
    cutoff=4. , cutoff_network=spk.nn.cutoff.CosineCutoff
)

# Output module predicting U0 from the 30 atom-wise features.
output_U0 = spk.atomistic.Atomwise(n_in=30, atomref=atomrefs[QM9.U0])

# Full atomistic model: representation plus output module.
model = spk.AtomisticModel(representation=schnet, output_modules=output_U0)

# Load the saved checkpoint. map_location makes a checkpoint written on a GPU
# loadable on a CPU-only machine; without it torch.load tries to restore the
# tensors onto the device they were saved from.
checkpoint_path = model_dir + '/trained.pth'
load_checkpoint = torch.load(checkpoint_path, map_location=device)

#qm9_i6_30f_20g-1000-500-4_300.pth
# Load the model's state dictionary from the saved checkpoint.
model.load_state_dict(load_checkpoint)

In [None]:
# This cell only works after the cells above have run: it uses `model`,
# `device`, `QM9` and the hook_* functions defined there.
# Now trying distorted molecules.
from ase.io import read

# Load the spk calculator and the ASE-atoms -> model-input converter.
calculator = spk.interfaces.SpkCalculator(model=model, device=device, energy=QM9.U0)
converter = spk.data.AtomsConverter(device=device)

atoms = read('../../../data/molecule-non-molecule/meoh-dis1.xyz')
print(atoms)

inputs = converter(atoms)

# Instantiate the layer-output globals that the hooks fill in.
v0=None
v1=None
v2=None
emb=None

# Forward-hook the three interaction layers and the embedding layer.
# NOTE: re-running this cell registers each hook again. The duplicates are
# redundant but harmless, because every hook only overwrites its own global.
# The hooks are left registered on purpose in case later cells rely on them.
model.representation.interactions[0].register_forward_hook(hook_v0)
model.representation.interactions[1].register_forward_hook(hook_v1)
model.representation.interactions[2].register_forward_hook(hook_v2)
model.representation.embedding.register_forward_hook(hook_emb)

# Forward pass: the hooks fill emb, v0, v1 and v2 as a side effect.
model(inputs)

# Sum of the embedding output and the three interaction-layer outputs.
rep = emb + v0 + v1 + v2

# Per-atom feature rows of the first (only) batch entry as a float64 array.
# detach() drops the autograd graph so the conversion works for any number of
# atoms and features, instead of copying a hard-coded 6 x 30 block element by
# element. Assumes rep is (batch, n_atoms, n_features), as the original
# rep[0][i][j] indexing implies -- TODO confirm.
rows = rep[0].detach().cpu().numpy().astype(np.float64)

# NOTE(review): `yi` (and `x`, `v`, `x0` below) do not appear to be part of the
# public schnetpack API; presumably the local schnetpack install was patched to
# expose them as module-level globals set during the forward pass -- confirm.
# These imports are therefore kept AFTER model(inputs).
from schnetpack.atomistic.output_modules import yi

yi=yi.detach().numpy()

# Output table: one leading all-zero row followed by the per-atom rows. The
# zero row is kept so the CSV layout stays the same as before.
dis = np.vstack((np.zeros((1, rows.shape[1])), rows))

print(rep)

from schnetpack.representation.schnet import x
from schnetpack.representation.schnet import v
from schnetpack.representation.schnet import x0

print(x)

# Write the table first and only then report completion.
savetxt('../../../data/data-dis.csv',dis,delimiter=',')
print('DONE')