In [4]:
import torch
import schnetpack as spk
import math
from schnetpack.datasets import QM9


import numpy as np
from numpy import savetxt

# Define Important Functions

In [2]:
def print_molecule(props):
    """Print one ``<symbol> x y z`` line per recognised atom and return the coordinates.

    Parameters
    ----------
    props : mapping
        Must provide ``'_positions'`` (indexed as ``[:, axis]`` and converted
        with ``.numpy()``) and ``'_atomic_numbers'`` (indexed per atom).

    Returns
    -------
    tuple
        ``(x, y, z)``: the three coordinate columns as NumPy arrays.

    Only atomic numbers 1, 6, 7, 8 and 9 (H, C, N, O, F) are printed; atoms
    with any other atomic number are skipped without a message.  A blank line
    is printed after the last atom.
    """
    symbol_for = {1: 'H', 6: 'C', 7: 'N', 8: 'O', 9: 'F'}

    # One NumPy column per Cartesian axis.
    x, y, z = (props['_positions'][:, axis].numpy() for axis in range(3))

    atomic_numbers = props['_atomic_numbers']
    for atom in range(len(z)):
        symbol = symbol_for.get(int(atomic_numbers[atom]))
        if symbol is not None:
            print(symbol, x[atom], y[atom], z[atom])

    print('')
    return x, y, z

def hook_v0(self, inp_tensor, out_tensor):
    """Forward hook: publish the hooked module's output as the module-level ``v0``.

    Called positionally as ``(module, input, output)``; ``self`` receives the
    hooked module and ``inp_tensor`` its input, and neither is used.  Returns
    ``None``.
    """
    globals()['v0'] = out_tensor

def hook_v1(self, inp_tensor, out_tensor):
    """Forward hook: publish the hooked module's output as the module-level ``v1``.

    Called positionally as ``(module, input, output)``; ``self`` receives the
    hooked module and ``inp_tensor`` its input, and neither is used.  Returns
    ``None``.
    """
    globals()['v1'] = out_tensor

def hook_v2(self, inp_tensor, out_tensor):
    """Forward hook: publish the hooked module's output as the module-level ``v2``.

    Called positionally as ``(module, input, output)``; ``self`` receives the
    hooked module and ``inp_tensor`` its input, and neither is used.  Returns
    ``None``.
    """
    globals()['v2'] = out_tensor
    
def hook_emb(self, inp_tensor, out_tensor):
    """Forward hook: publish the hooked module's output as the module-level ``emb``.

    Called positionally as ``(module, input, output)``; ``self`` receives the
    hooked module and ``inp_tensor`` its input, and neither is used.  Returns
    ``None``.
    """
    globals()['emb'] = out_tensor
#def convert_2D(number_of_atoms,rep):
#    layer = np.zeros((number_of_atoms,30))
#    for i in range(number_of_atoms):
#        for j in range(30):
#            layer[i][j] = rep[0][i][j]
#    return layer
            






In [3]:
# --- Configuration: dataset, trained model and number of molecules to process ---
qm9data = QM9('./qm9.db', download=True, remove_uncharacterized=True)
checkpoint_path = '../../../../data/trained_models/qm9_i3_30f_10000_5000/trained.pth'
split_file = '../../../../data/trained_models/qm9_i3_30f_10000_5000/split.npz'
number_of_inputs = 5000

# Device for loading the checkpoint and for the forward passes.  Defined before
# torch.load so that a checkpoint saved on a GPU can be mapped onto it.
device = 'cpu'

# Load split file
# NOTE(review): train/val/test are not referenced again in the visible code;
# the extraction loop indexes qm9data directly -- confirm that is intended.
train, val, test = spk.data.train_test_split(qm9data, split_file=split_file)

# Load atom ref data (indexed below by atomic number)
atomrefs = qm9data.get_atomref(QM9.U0)
print('U0 of hydrogen:', '{:.2f}'.format(atomrefs[QM9.U0][1][0]), 'eV')
print('U0 of carbon:', '{:.2f}'.format(atomrefs[QM9.U0][6][0]), 'eV')
print('U0 of oxygen:', '{:.2f}'.format(atomrefs[QM9.U0][8][0]), 'eV')
# Atomic number 7 is nitrogen (this line used to be labelled 'oxygen' as well).
print('U0 of nitrogen:', '{:.2f}'.format(atomrefs[QM9.U0][7][0]), 'eV')

# Define SchNet representation model
schnet = spk.representation.SchNet(
    n_atom_basis=30, n_filters=30, n_gaussians=20, n_interactions=3,
    cutoff=4., cutoff_network=spk.nn.cutoff.CosineCutoff
)

# Define SchNet output model and property to be predicted
output_U0 = spk.atomistic.Atomwise(n_in=30, atomref=atomrefs[QM9.U0])

# Define atomistic model
model = spk.AtomisticModel(representation=schnet, output_modules=output_U0)

# Load the saved checkpoint (used directly as a state dictionary below).
# map_location keeps the load working on machines without the device the
# checkpoint was saved from.
load_checkpoint = torch.load(checkpoint_path, map_location=device)

#qm9_i6_30f_20g-1000-500-4_300.pth
# load model's state dictionary from saved checkpoint
model.load_state_dict(load_checkpoint)

# load atoms converter
converter = spk.data.AtomsConverter(device=device)

# Result accumulators.  Each starts with a single all-zero placeholder entry
# that the extraction loop stacks real rows underneath.
datao = np.zeros((1, 30))    # 30-feature vectors of oxygen atoms
datahae = np.zeros((1))      # per-atom output values of hydrogen atoms
dataoae = np.zeros((1))      # per-atom output values of oxygen atoms
datah = np.zeros((1, 30))    # 30-feature vectors of hydrogen atoms
data = np.zeros((1, 4))      # four neighbour values per hydrogen atom

print(data)
# Attach the capture hooks exactly once, before the loop.  Registering them
# inside the loop adds four more hooks per molecule, and every forward pass
# then calls every hook registered so far.
hook_handles = [
    model.representation.interactions[0].register_forward_hook(hook_v0),
    model.representation.interactions[1].register_forward_hook(hook_v1),
    model.representation.interactions[2].register_forward_hook(hook_v2),
    model.representation.embedding.register_forward_hook(hook_emb),
]

# Per-atom results are gathered in plain lists and stacked once after the loop
# (calling np.vstack per atom copies the whole array every time).
oxygen_rows, hydrogen_rows = [], []
oxygen_energies, hydrogen_energies = [], []
neighbor_rows = []

try:
    for idx in range(number_of_inputs):

        # load data for molecule
        at, props = qm9data.get_properties(idx)

        # print molecule for identification
        print(idx)
        x, y, z = print_molecule(props)
        number_of_atoms = len(z)

        # convert qm9 data to machine-readable form
        inputs = converter(at)

        # Reset every hook target so a value captured for the previous
        # molecule can never be silently reused.
        v0 = None
        v1 = None
        v2 = None
        emb = None

        # Forward pass; gradients are not needed for feature extraction.
        with torch.no_grad():
            model(inputs)

        # Sum of the embedding output and the three interaction outputs,
        # reduced to a (number_of_atoms, 30) float64 array for batch entry 0.
        rep = emb + v0 + v1 + v2
        rows = rep[0].detach().cpu().numpy()[:number_of_atoms, :30].astype(np.float64)

        # NOTE(review): presumably relies on a locally modified SchNetPack that
        # publishes the per-atom outputs as a module-level ``yi`` -- confirm.
        # The import stays inside the loop because ``from ... import`` copies
        # the module attribute's *current* value each time it executes.
        from schnetpack.atomistic.output_modules import yi

        yi = yi.detach().cpu().numpy()

        atomic_numbers = props['_atomic_numbers'].numpy()

        # Save the feature vector and per-atom value of every oxygen (Z=8)
        # and hydrogen (Z=1) atom encountered, in atom-index order.
        for i in range(number_of_atoms):
            if atomic_numbers[i] == 8:
                oxygen_rows.append(rows[i])
                oxygen_energies.append(yi[0][i])
            elif atomic_numbers[i] == 1:
                hydrogen_rows.append(rows[i])
                hydrogen_energies.append(yi[0][i])

        positions = np.column_stack((x, y, z))

        # For every hydrogen: find its neighbours by Euclidean distance.
        for i in range(number_of_atoms):
            if atomic_numbers[i] != 1:
                continue

            delta = (positions - positions[i]).astype(np.float64)
            distance = np.sqrt((delta ** 2).sum(axis=1))

            # Other atoms ordered nearest-first; a distance of exactly zero
            # marks the hydrogen itself and is excluded.
            by_distance = [k for k in np.argsort(distance, kind='stable')
                           if distance[k] != 0]

            # First neighbour: the nearest atom of any element.
            neighbor = by_distance[0] if by_distance else 0

            # Second to fourth neighbours: the next-nearest non-hydrogen atoms.
            # ASSUMPTION(review): this is the intent inferred from the variable
            # names and the Z != 1 filter.  The previous searches started from
            # the current minimum distance and looked for something *smaller*,
            # so they could never match and always produced index 0.  Index 0
            # is kept as the fallback when a molecule has too few heavy atoms.
            heavy = [k for k in by_distance
                     if k != neighbor and atomic_numbers[k] != 1]
            neighbor2, neighbor3, neighbor4 = (heavy + [0, 0, 0])[:3]

            first = float(yi[0][neighbor])
            second = float(yi[0][neighbor2])
            third = float(yi[0][neighbor3])
            fourth = float(yi[0][neighbor4])
            neighbor_rows.append((first, second, third, fourth))
finally:
    # Detach the hooks so re-running this cell does not stack duplicates.
    for handle in hook_handles:
        handle.remove()

# Stack the collected rows underneath the existing accumulator contents
# (including their leading placeholder row, which the saved files keep).
if oxygen_rows:
    datao = np.vstack([datao, *oxygen_rows])
if hydrogen_rows:
    datah = np.vstack([datah, *hydrogen_rows])
if hydrogen_energies:
    datahae = np.vstack([datahae, *hydrogen_energies])
if oxygen_energies:
    dataoae = np.vstack([dataoae, *oxygen_energies])
if neighbor_rows:
    data = np.vstack([data, *neighbor_rows])
print(data)

U0 of hyrogen: -13.61 eV
U0 of carbon: -1029.86 eV
U0 of oxygen: -2042.61 eV
U0 of oxygen: -1485.30 eV
[[0. 0. 0.]]
0
C -2.8340169e-06 2.3049886e-06 -1.4378233e-07
H 0.014845718 -1.0918331 -0.0060250196
H 1.0244261 0.3779494 -0.007724565
H -0.52811974 0.36172476 -0.88464487
H -0.5111183 0.3521308 0.89839613

[]
[2. 3. 4. 1. 3. 1. 2. 1.]
[[0. 0. 0.]]


  properties[pname] = torch.FloatTensor(prop)


In [None]:
print('DONE')

# Write every accumulated array to its own comma-separated file.
# NOTE(review): data.csv goes four directory levels up while the other files
# go three levels up -- confirm the differing depth is intentional.
for csv_path, table in (
    ('../../../../data/data.csv', data),
    ('../../../data/dataO.csv', datao),
    ('../../../data/dataH.csv', datah),
    ('../../../data/hae.csv', datahae),
    ('../../../data/oae.csv', dataoae),
):
    savetxt(csv_path, table, delimiter=',')
        