In [1]:
import torch
import schnetpack as spk
import math


import numpy as np
from numpy import savetxt

In [2]:
def print_molecule(props):
    """Print the H/C/N/O/F atoms of a molecule and return its z coordinates.

    Parameters
    ----------
    props : mapping
        Must provide ``'_positions'``, indexed as ``[:, k]`` for k in 0..2
        with each column exposing ``.numpy()``, and ``'_atomic_numbers'``,
        indexed per atom and compared against integer atomic numbers.

    Returns
    -------
    The z-coordinate column as returned by ``.numpy()``; callers use its
    length as the atom count.

    Notes
    -----
    One line ``<symbol> x y z`` is printed per atom whose atomic number is
    1, 6, 7, 8 or 9. Atoms with any other atomic number are skipped
    silently. A blank line is printed after the molecule.
    """
    coords = props['_positions']
    # One NumPy column per Cartesian axis.
    x, y, z = (coords[:, axis].numpy() for axis in range(3))

    # Atomic number -> element symbol for the elements that get printed.
    element_symbols = ((1, 'H'), (6, 'C'), (7, 'N'), (8, 'O'), (9, 'F'))

    for i in range(len(z)):
        number = props['_atomic_numbers'][i]
        for atomic_number, symbol in element_symbols:
            if number == atomic_number:
                print(symbol, x[i], y[i], z[i])
    print('')
    return z

In [3]:
def hook_layer(self, inp_tensor, out_tensor):
    """Forward-hook callback that stashes a layer's output in a global.

    Intended for ``register_forward_hook``: ``self`` is the module the hook
    is attached to, ``inp_tensor`` is that module's input and ``out_tensor``
    its output. Only ``out_tensor`` is used.

    The output is stored in the module-level name ``layer_output`` so the
    code that triggered the forward pass can read it afterwards. Nothing is
    returned, so the layer's output is passed on unmodified.
    """
    # Rebind the notebook-level variable rather than creating a local one.
    global layer_output
    layer_output = out_tensor

In [4]:
def convert_2D(number_of_atoms, layer_output, n_features=30):
    """Copy the first batch entry of ``layer_output`` into a 2-D float array.

    Parameters
    ----------
    number_of_atoms : int
        Number of rows (atoms) to copy.
    layer_output : array-like
        Indexed as ``layer_output[0][atom][feature]``; only batch entry 0
        is read.
    n_features : int, optional
        Number of leading feature columns to copy. Defaults to 30, the
        width that used to be hard-coded here.

    Returns
    -------
    numpy.ndarray
        A new float64 array of shape ``(number_of_atoms, n_features)``.

    Raises
    ------
    ValueError
        If ``layer_output[0]`` is not 2-D (atoms x features).
    IndexError
        If ``layer_output[0]`` has fewer rows or columns than requested.
    """
    layer = np.zeros((number_of_atoms, n_features))
    if number_of_atoms == 0 or n_features == 0:
        # Nothing to copy; do not touch layer_output at all.
        return layer

    per_atom = np.asarray(layer_output[0])
    if per_atom.ndim != 2:
        raise ValueError(
            'layer_output[0] must be 2-D (atoms x features), got shape '
            + str(per_atom.shape)
        )
    if per_atom.shape[0] < number_of_atoms or per_atom.shape[1] < n_features:
        # Checked explicitly so a too-small input can never be broadcast.
        raise IndexError(
            'layer_output[0] has shape ' + str(per_atom.shape)
            + ', need at least ' + str((number_of_atoms, n_features))
        )

    # Single vectorised copy instead of an element-by-element double loop.
    layer[:, :] = per_atom[:number_of_atoms, :n_features]
    return layer

In [None]:
# Extract, for the first `number_of_inputs` QM9 molecules, the hooked layer's
# per-atom rows and the per-atom values `yi` for hydrogen and oxygen atoms,
# and write them to four CSV files.

# Load QM9 dataset
from schnetpack.datasets import QM9

qm9data = QM9('./qm9.db', download=True, remove_uncharacterized=True)
# Load split file
train, val, test = spk.data.train_test_split(qm9data,split_file='../../../data/trained_models/qm9_i3_30f_10000_5000/split.npz')

# Load atom ref data
atomrefs = qm9data.get_atomref(QM9.U0)

# Define SchNet representation model
schnet = spk.representation.SchNet(
    n_atom_basis=30, n_filters=30, n_gaussians=20, n_interactions=3,
    cutoff=4. , cutoff_network=spk.nn.cutoff.CosineCutoff
)

# Define SchNet output model and property to be predicted
output_U0 = spk.atomistic.Atomwise(n_in=30, atomref=atomrefs[QM9.U0])

# Define atomistic model
model = spk.AtomisticModel(representation=schnet,output_modules=output_U0)

# Set up device for the forward pass (defined before loading the checkpoint
# so the weights can be mapped onto it).
device='cpu'

# Load saved checkpoint file
checkpoint_path = '../../../data/trained_models/qm9_i3_30f_10000_5000/trained.pth'
# map_location lets a checkpoint saved on another device load on `device`.
load_checkpoint = torch.load(checkpoint_path, map_location=device)

#qm9_i6_30f_20g-1000-500-4_300.pth
# load model's state dictionary from saved checkpoint
model.load_state_dict(load_checkpoint)
# Inference only: switch the model to evaluation mode.
model.eval()

number_of_inputs = 5000

# Width of the rows returned by convert_2D (its default feature count).
FEATURE_WIDTH = 30

# load atoms converter
converter = spk.data.AtomsConverter(device=device)

# load spk calculator
calculator = spk.interfaces.SpkCalculator(model=model, device=device, energy=QM9.U0)

# Rows are collected in lists and stacked once after the loop; calling
# np.vstack inside the loop re-copies the whole array on every append.
# NOTE(review): each list starts with the same all-zero placeholder row the
# original arrays were initialised with, so the CSV files keep a leading zero
# row. Kept for compatibility with existing readers - confirm whether
# downstream code skips it.
oxygen_rows = [np.zeros((1, FEATURE_WIDTH))]
hydrogen_rows = [np.zeros((1, FEATURE_WIDTH))]
hydrogen_yi = [np.zeros((1))]
oxygen_yi = [np.zeros((1))]

# Register the forward hook ONCE. Registering it inside the loop adds a new
# hook per molecule, so every forward pass would run all hooks added so far.
# NOTE(review): convert_2D reads the captured tensor as [0][atom][feature];
# confirm that the output of interactions[0].filter_network[0] has that layout.
layer_output = None
hook_handle = model.representation.interactions[0].filter_network[0].register_forward_hook(hook_layer)

try:
    for idx in range(number_of_inputs):

        # load data for molecule
        at, props = qm9data.get_properties(idx)

        # print molecule for identification
        print(idx)
        z = print_molecule(props)
        number_of_atoms = len(z)

        # set calculator on molecule
        at.set_calculator(calculator)

        # convert qm9 data to machine-readable form
        inputs = converter(at)

        # Reset so a stale result from the previous molecule is never reused.
        layer_output = None

        # Forward pass; no gradients are needed for extraction.
        with torch.no_grad():
            model(inputs)

        if layer_output is None:
            raise RuntimeError('forward hook did not fire for molecule ' + str(idx))
        layer_output = layer_output.detach().numpy()

        # NOTE(review): `yi` is read from the schnetpack module after every
        # forward pass; presumably a locally patched module-level global
        # holding per-atom outputs - confirm. The import stays inside the
        # loop so the current value is re-read for each molecule.
        from schnetpack.atomistic.output_modules import yi

        yi = yi.detach().numpy()

        # convert layer tensor to 2D array
        rows = convert_2D(number_of_atoms, layer_output)

        # Single pass over the atoms; each output keeps its original order.
        for i in range(number_of_atoms):
            atomic_number = props['_atomic_numbers'][i]
            if atomic_number == 8:
                oxygen_rows.append(rows[i])
                oxygen_yi.append(yi[0][i])
            if atomic_number == 1:
                hydrogen_rows.append(rows[i])
                hydrogen_yi.append(yi[0][i])
finally:
    # Always detach the hook so re-running this cell does not stack hooks.
    hook_handle.remove()

data = np.vstack(oxygen_rows)
datah = np.vstack(hydrogen_rows)
datahae = np.vstack(hydrogen_yi)
dataoae = np.vstack(oxygen_yi)

print('DONE')
savetxt('../../../data/dataO.csv',data,delimiter=',')
savetxt('../../../data/dataH.csv',datah,delimiter=',')
savetxt('../../../data/hae.csv',datahae,delimiter=',')
savetxt('../../../data/oae.csv',dataoae,delimiter=',')