In [1]:
import torch
import schnetpack as spk
import math
from schnetpack.datasets import QM9

import numpy as np
from numpy import savetxt

In [2]:
def print_molecule(props):
    """Print one line per H, C, N, O or F atom and return the z coordinates.

    Every printed line holds the element symbol followed by the atom's
    x, y and z coordinates. Atoms with any other atomic number produce no
    line. An empty line is printed after the last atom.

    Args:
        props: mapping with a '_positions' tensor indexed as [atom, axis]
            and an '_atomic_numbers' sequence indexed by atom.

    Returns:
        NumPy array holding the z coordinate of every atom.
    """
    # (atomic number, symbol) pairs, checked in this order for every atom
    element_symbols = ((1, 'H'), (6, 'C'), (7, 'N'), (8, 'O'), (9, 'F'))

    # split the position tensor into one NumPy array per axis
    x, y, z = (props['_positions'][:, axis].numpy() for axis in range(3))

    for atom in range(len(z)):
        for number, symbol in element_symbols:
            if props['_atomic_numbers'][atom] == number:
                print(symbol, x[atom], y[atom], z[atom])
    print('')
    return z

def hook_v0(self, inp_tensor, out_tensor):
    """Forward-hook callback that keeps the hooked module's output in the global ``v0``.

    Args:
        self: first argument passed by the hook machinery (unused here).
        inp_tensor: input of the hooked module (unused here).
        out_tensor: output of the hooked module; stored so it can be read
            after the forward pass.
    """
    # Writing through globals() rebinds the module-level name v0.
    globals()['v0'] = out_tensor

def hook_v1(self, inp_tensor, out_tensor):
    """Forward-hook callback that keeps the hooked module's output in the global ``v1``.

    Args:
        self: first argument passed by the hook machinery (unused here).
        inp_tensor: input of the hooked module (unused here).
        out_tensor: output of the hooked module; stored so it can be read
            after the forward pass.
    """
    # Writing through globals() rebinds the module-level name v1.
    globals()['v1'] = out_tensor

def hook_v2(self, inp_tensor, out_tensor):
    """Forward-hook callback that keeps the hooked module's output in the global ``v2``.

    Args:
        self: first argument passed by the hook machinery (unused here).
        inp_tensor: input of the hooked module (unused here).
        out_tensor: output of the hooked module; stored so it can be read
            after the forward pass.
    """
    # Writing through globals() rebinds the module-level name v2.
    globals()['v2'] = out_tensor
    
def hook_emb(self, inp_tensor, out_tensor):
    """Forward-hook callback that keeps the hooked module's output in the global ``emb``.

    Args:
        self: first argument passed by the hook machinery (unused here).
        inp_tensor: input of the hooked module (unused here).
        out_tensor: output of the hooked module; stored so it can be read
            after the forward pass.
    """
    # Writing through globals() rebinds the module-level name emb.
    globals()['emb'] = out_tensor
#def convert_2D(number_of_atoms,rep):
#    layer = np.zeros((number_of_atoms,30))
#    for i in range(number_of_atoms):
#        for j in range(30):
#            layer[i][j] = rep[0][i][j]
#    return layer
def load_checkpoint(qm9data):
    """Rebuild the SchNet dipole-moment model and load its trained weights.

    Relies on the notebook-level names ``split_file`` and ``checkpoint_path``
    being defined before this function is called.

    Args:
        qm9data: QM9 dataset. It is split according to ``split_file`` and
            supplies the atom references and training-set statistics for
            ``QM9.mu``.

    Returns:
        The ``spk.AtomisticModel`` after the checkpoint's state dictionary
        has been loaded into it.
    """
    # Load split file; only the training split is used below
    train, _val, _test = spk.data.train_test_split(qm9data, split_file=split_file)

    # Load atom ref data
    atomrefs = qm9data.get_atomref(QM9.mu)

    # Define SchNet representation model
    schnet = spk.representation.SchNet(
        n_atom_basis=30, n_filters=30, n_gaussians=20, n_interactions=3,
        cutoff=4., cutoff_network=spk.nn.cutoff.CosineCutoff
    )

    # Statistics of the target property over the training split
    train_loader = spk.AtomsLoader(train, batch_size=100, shuffle=True)
    means, stddevs = train_loader.get_statistics(
        QM9.mu, divide_by_atoms=True, single_atom_ref=atomrefs
    )

    # Define SchNet output model and property to be predicted
    output_dip = spk.atomistic.DipoleMoment(
        n_in=30, property=QM9.mu,
        mean=means[QM9.mu], contributions=None, stddev=stddevs[QM9.mu]
    )

    # Define atomistic model
    model = spk.AtomisticModel(representation=schnet, output_modules=output_dip)

    # Load the saved state dictionary. The model built above is never moved
    # off the CPU, so map the stored tensors to CPU as well; without this a
    # checkpoint written from a GPU cannot be read on a CPU-only machine.
    state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)

    return model


In [3]:
# --- Configuration -----------------------------------------------------------
checkpoint_path = '../../../../data/trained_models/qm9_dipole_trained/trained.pth'
split_file ='../../../../data/trained_models/qm9_dipole_trained/split.npz'
# Not used below; kept for loading the model directly with
# torch.load(model_file, map_location=torch.device('cpu')) instead of load_checkpoint.
model_file = "../../../../data/trained_models/qm9_dipole_trained/best_model"
number_of_inputs = 5000

# Load QM9 dataset
qm9data = QM9('./qm9.db', download=True, remove_uncharacterized=True)

model = load_checkpoint(qm9data)

# set up device for forward pass
device = 'cpu'

# load atoms converter
converter = spk.data.AtomsConverter(device=device)

# Register the forward hooks once, before the loop. Registering them inside
# the loop would add another copy of every hook on each iteration.
model.representation.interactions[0].register_forward_hook(hook_v0)
model.representation.interactions[1].register_forward_hook(hook_v1)
model.representation.interactions[2].register_forward_hook(hook_v2)
model.representation.embedding.register_forward_hook(hook_emb)

# Width of one atom's representation vector (n_atom_basis in load_checkpoint)
n_features = 30

# Per-molecule results are collected in lists and stacked once after the loop.
# Each list starts with a zero placeholder so that the final arrays keep the
# leading all-zero row the saved files have always contained.
oxygen_rows = [np.zeros((1, n_features))]
hydrogen_rows = [np.zeros((1, n_features))]
hydrogen_dipoles = [np.zeros((1, 1))]
oxygen_dipoles = [np.zeros((1, 1))]

for idx in range(0, 1):

    # load data for molecule
    at, props = qm9data.get_properties(idx)

    # print molecule for identification
    print(idx)
    z = print_molecule(props)
    number_of_atoms = len(z)

    # convert qm9 data to machine-readable form
    inputs = converter(at)

    # Reset the hook outputs so a tensor left over from an earlier forward
    # pass can never be mistaken for the current molecule's result.
    v0 = None
    v1 = None
    v2 = None
    emb = None

    # Forward pass molecules through the model (fills v0, v1, v2 and emb)
    pred = model(inputs)

    print('Prediction:', pred[QM9.mu].detach().cpu().numpy()[0,0])
    print('Keys:', list(inputs.keys()))
    print('Truth:', props[QM9.mu].cpu().numpy()[0])

    # Sum of the embedding and the three interaction outputs, one row per atom.
    # detach() is required because the tensors still carry gradient history.
    rep = emb + v0 + v1 + v2
    rows = rep[0, :number_of_atoms, :n_features].detach().cpu().numpy().astype(np.float64)

    # NOTE(review): these names are read from module-level variables of
    # schnetpack.atomistic.output_modules. They appear to be filled in by the
    # forward pass above (presumably a locally modified schnetpack -- confirm),
    # so this import has to stay after model(inputs).
    from schnetpack.atomistic.output_modules import y, yi, result, charges

    charges = charges.detach().numpy()
    yi = yi.detach().numpy()

    print(charges[0][0])
    print(yi)
    print(y)
    print(result)

    # Euclidean norm of each atom's three-component vector in yi
    dip = np.array([
        math.sqrt(vec[0]**2 + vec[1]**2 + vec[2]**2)
        for vec in yi[0][:number_of_atoms]
    ])

    print(dip)

    # Select the oxygen (Z = 8) and hydrogen (Z = 1) atoms of this molecule
    atomic_numbers = np.asarray(props['_atomic_numbers'])[:number_of_atoms]
    is_oxygen = atomic_numbers == 8
    is_hydrogen = atomic_numbers == 1

    # save the vector and the norm of every oxygen / hydrogen atom encountered
    oxygen_rows.append(rows[is_oxygen])
    hydrogen_rows.append(rows[is_hydrogen])
    hydrogen_dipoles.append(dip[is_hydrogen].reshape(-1, 1))
    oxygen_dipoles.append(dip[is_oxygen].reshape(-1, 1))

# Stack everything once; row 0 of each array is the zero placeholder
datao = np.vstack(oxygen_rows)
datah = np.vstack(hydrogen_rows)
datahae = np.vstack(hydrogen_dipoles)
dataoae = np.vstack(oxygen_dipoles)



  properties[pname] = torch.FloatTensor(prop)


0
C -2.8340169e-06 2.3049886e-06 -1.4378233e-07
H 0.014845718 -1.0918331 -0.0060250196
H 1.0244261 0.3779494 -0.007724565
H -0.52811974 0.36172476 -0.88464487
H -0.5111183 0.3521308 0.89839613

Prediction: -0.0033174977
Keys: ['_atomic_numbers', '_positions', '_cell', '_neighbors', '_cell_offset', '_atom_mask', '_neighbor_mask', 'representation']
Truth: 0.0
[-0.12051509]
[[[ 1.5303171e-03 -1.3085578e-01 -9.6424075e-04]
  [ 2.0532962e-04 -5.7589239e-04  1.8868723e-04]
  [ 9.6604437e-02  1.3976531e-01  2.6408561e-05]
  [-5.1640589e-02  1.3821939e-01 -8.3707727e-02]
  [-5.0016992e-02  1.3730277e-01  8.6548470e-02]]]
tensor([[-0.0033,  0.2839,  0.0021]], grad_fn=<SumBackward1>)
{'dipole_moment': tensor([[-0.0033,  0.2839,  0.0021]], grad_fn=<SumBackward1>)}
[0.13086828 0.00063986 0.1699022  0.16964178 0.16983636]


In [None]:
# Report completion, then write each collected array to its CSV file.
print('DONE')
for csv_path, collected in (
    ('../../../../data/dataO.csv', datao),
    ('../../../../data/dataH.csv', datah),
    ('../../../../data/hae.csv', datahae),
    ('../../../../data/oae.csv', dataoae),
):
    savetxt(csv_path, collected, delimiter=',')