In [1]:
import torch
import schnetpack as spk
import math

import numpy as np
from numpy import savetxt

# Define functions

In [2]:
def print_molecule(props):
    
    # load x,y,z coordinates tensors
    x = props['_positions'][ :,0]
    y = props['_positions'][ :,1]
    z = props['_positions'][ :,2]
    x = x.numpy()
    y = y.numpy()
    z = z.numpy()
    
    for i in range(len(z)):
        if props['_atomic_numbers'][i] == 1:
            print('H',x[i],y[i],z[i])
        if props['_atomic_numbers'][i] == 6:
            print('C',x[i],y[i],z[i])
        if props['_atomic_numbers'][i] == 7:
            print('N',x[i],y[i],z[i])     
        if props['_atomic_numbers'][i] == 8:
            print('O',x[i],y[i],z[i])
        if props['_atomic_numbers'][i] == 9:
            print('F',x[i],y[i],z[i]) 
    print('')  
    return z

In [3]:
def hook_layer(self, inp_tensor, out_tensor):
    # Self is included and refers to the model class
    # Global allows us to utilize embedding_output outside the current function scope
    global layer_output
    #Update the embedding_output variable to be equal to our output tensor
    layer_output=out_tensor 

In [4]:
def convert_2D(number_of_atoms,layer_output):
    layer = np.zeros((number_of_atoms,30))
    for i in range(number_of_atoms):
        for j in range(30):
            layer[i,j] = layer_output[0][i][j]
    return layer

# Load model and run forward passes using molecules of QM9 dataset

In [7]:
# Load QM9 dataset
from schnetpack.datasets import QM9

qm9data = QM9('./qm9.db', download=True, remove_uncharacterized=True)

# Load split file 
train, val, test = spk.data.train_test_split(qm9data,split_file='../../../data/trained_models/qm9_i6_30f/split.npz')

# Load atom ref data 
atomrefs = qm9data.get_atomref(QM9.U0)
# Define SchNet representation model

schnet = spk.representation.SchNet(
    n_atom_basis=30, n_filters=30, n_gaussians=20, n_interactions=5,
    cutoff=4. , cutoff_network=spk.nn.cutoff.CosineCutoff
)

# Define SchNet output model and property to be predicted

output_U0 = spk.atomistic.Atomwise(n_in=30, atomref=atomrefs[QM9.U0])

# Define atomistic model

model = spk.AtomisticModel(representation=schnet,output_modules=output_U0)

# Load saved checkpoint file
checkpoint_path = '../../../data/trained_models/qm9_i6_30f/checkpoints/checkpoint-330.pth.tar'
load_checkpoint = torch.load(checkpoint_path)

#qm9_i6_30f_20g-1000-500-4_300.pth
# load model's state dictionary from saved checkpoint
model.load_state_dict(load_checkpoint)


number_of_inputs = 5000

#set up device for forward pass
device='cpu'

# load atoms converter 
converter = spk.data.AtomsConverter(device=device)

# load spk calculator
calculator = spk.interfaces.SpkCalculator(model=model, device=device, energy=QM9.U0)

data = np.zeros((1,30))
datay = np.zeros((1))
for idx in range(number_of_inputs):
        
    # load data for molecule
    at, props = qm9data.get_properties(idx)
        
    # print molecule for identification
    print(idx)
    z = print_molecule(props)
    number_of_atoms=len(z)
        
    # set calculator on molecule
    at.set_calculator(calculator)
        
    # convert qm9 data to machine-readable form
    inputs = converter(at)
        
    #Instatiate layer output
    layer_output=None
        
    # Forward hook the model's interaction layer 
    model.representation.interactions[0].register_forward_hook(hook_layer)
        
    # Forward pass molecules through the model
    model(inputs)

#    from schnetpack.nn.blocks import ret
#    print(ret)
    
    #yi=yi.detach().numpy()
    #print(yi)
    #convert layer tensor to 2D array
    rows = convert_2D(number_of_atoms,layer_output)

    #save the vector of every oxygen atom encountered
    for i in range(number_of_atoms):
        if props['_atomic_numbers'][i] == 8:
            data = np.vstack((data,rows[i]))
    
#    for i in range(number_of_atoms):
#        if props['_atomic_numbers'][i] == 1:
#            datay = np.vstack((datay,yi[0][i]))
            
    
            
savetxt('../../../data/test.csv',data,delimiter=',')
savetxt('../../../data/ae.csv',datay,delimiter=',')

None


RuntimeError: Error(s) in loading state_dict for AtomisticModel:
	Missing key(s) in state_dict: "representation.embedding.weight", "representation.distance_expansion.width", "representation.distance_expansion.offsets", "representation.interactions.0.filter_network.0.weight", "representation.interactions.0.filter_network.0.bias", "representation.interactions.0.filter_network.1.weight", "representation.interactions.0.filter_network.1.bias", "representation.interactions.0.cutoff_network.cutoff", "representation.interactions.0.cfconv.in2f.weight", "representation.interactions.0.cfconv.f2out.weight", "representation.interactions.0.cfconv.f2out.bias", "representation.interactions.0.cfconv.filter_network.0.weight", "representation.interactions.0.cfconv.filter_network.0.bias", "representation.interactions.0.cfconv.filter_network.1.weight", "representation.interactions.0.cfconv.filter_network.1.bias", "representation.interactions.0.cfconv.cutoff_network.cutoff", "representation.interactions.0.dense.weight", "representation.interactions.0.dense.bias", "representation.interactions.1.filter_network.0.weight", "representation.interactions.1.filter_network.0.bias", "representation.interactions.1.filter_network.1.weight", "representation.interactions.1.filter_network.1.bias", "representation.interactions.1.cutoff_network.cutoff", "representation.interactions.1.cfconv.in2f.weight", "representation.interactions.1.cfconv.f2out.weight", "representation.interactions.1.cfconv.f2out.bias", "representation.interactions.1.cfconv.filter_network.0.weight", "representation.interactions.1.cfconv.filter_network.0.bias", "representation.interactions.1.cfconv.filter_network.1.weight", "representation.interactions.1.cfconv.filter_network.1.bias", "representation.interactions.1.cfconv.cutoff_network.cutoff", "representation.interactions.1.dense.weight", "representation.interactions.1.dense.bias", "representation.interactions.2.filter_network.0.weight", "representation.interactions.2.filter_network.0.bias", "representation.interactions.2.filter_network.1.weight", "representation.interactions.2.filter_network.1.bias", "representation.interactions.2.cutoff_network.cutoff", "representation.interactions.2.cfconv.in2f.weight", "representation.interactions.2.cfconv.f2out.weight", "representation.interactions.2.cfconv.f2out.bias", "representation.interactions.2.cfconv.filter_network.0.weight", "representation.interactions.2.cfconv.filter_network.0.bias", "representation.interactions.2.cfconv.filter_network.1.weight", "representation.interactions.2.cfconv.filter_network.1.bias", "representation.interactions.2.cfconv.cutoff_network.cutoff", "representation.interactions.2.dense.weight", "representation.interactions.2.dense.bias", "representation.interactions.3.filter_network.0.weight", "representation.interactions.3.filter_network.0.bias", "representation.interactions.3.filter_network.1.weight", "representation.interactions.3.filter_network.1.bias", "representation.interactions.3.cutoff_network.cutoff", "representation.interactions.3.cfconv.in2f.weight", "representation.interactions.3.cfconv.f2out.weight", "representation.interactions.3.cfconv.f2out.bias", "representation.interactions.3.cfconv.filter_network.0.weight", "representation.interactions.3.cfconv.filter_network.0.bias", "representation.interactions.3.cfconv.filter_network.1.weight", "representation.interactions.3.cfconv.filter_network.1.bias", "representation.interactions.3.cfconv.cutoff_network.cutoff", "representation.interactions.3.dense.weight", "representation.interactions.3.dense.bias", "representation.interactions.4.filter_network.0.weight", "representation.interactions.4.filter_network.0.bias", "representation.interactions.4.filter_network.1.weight", "representation.interactions.4.filter_network.1.bias", "representation.interactions.4.cutoff_network.cutoff", "representation.interactions.4.cfconv.in2f.weight", "representation.interactions.4.cfconv.f2out.weight", "representation.interactions.4.cfconv.f2out.bias", "representation.interactions.4.cfconv.filter_network.0.weight", "representation.interactions.4.cfconv.filter_network.0.bias", "representation.interactions.4.cfconv.filter_network.1.weight", "representation.interactions.4.cfconv.filter_network.1.bias", "representation.interactions.4.cfconv.cutoff_network.cutoff", "representation.interactions.4.dense.weight", "representation.interactions.4.dense.bias", "output_modules.0.atomref.weight", "output_modules.0.out_net.1.out_net.0.weight", "output_modules.0.out_net.1.out_net.0.bias", "output_modules.0.out_net.1.out_net.1.weight", "output_modules.0.out_net.1.out_net.1.bias", "output_modules.0.standardize.mean", "output_modules.0.standardize.stddev". 
	Unexpected key(s) in state_dict: "epoch", "step", "best_loss", "optimizer", "hooks", "model". 

In [None]:
model_load_state_