In [1]:
'''
This script hooks a layer of a trained SchNet model loaded from a checkpoint so that a
forward pass also outputs the values of intermediate trained layers.
'''

import torch
import schnetpack as spk
import math

import numpy as np
from numpy import savetxt
import pandas as pd
from xlwt import Workbook

# Define important functions

def interactions_hook(self, inp_tensor, out_tensor):
    """Forward hook: store the hooked layer's output in the global ``interactions_output``.

    Meant to be passed to ``register_forward_hook``; PyTorch then calls it after the
    hooked module's forward pass as ``hook(module, input, output)``.

    Args:
        self: the module the hook is registered on (unused; named ``self`` because it is
            the hooked module instance).
        inp_tensor: the positional inputs of the module's forward call (unused).
        out_tensor: the module's output; stored as-is (no copy, no detach).

    Returns:
        None, so PyTorch passes the module's output on unchanged.
    """
    # `global` makes the assignment visible to notebook code outside this function,
    # which reads `interactions_output` after the forward pass.
    global interactions_output
    interactions_output = out_tensor
    # Returning anything other than None would replace the layer's output.
    return None

def print_atom_coordinates(idx,props):
    """Print a molecule's index followed by one "<symbol> x y z" line per atom.

    Only H, C, N, O and F (atomic numbers 1, 6, 7, 8, 9) are printed; atoms with any
    other atomic number are skipped silently. A blank line ends the listing.

    Args:
        idx: molecule index, printed first so the molecule can be identified.
        props: property dict with '_positions' (indexed as [atom, coordinate], converted
            with .numpy()) and '_atomic_numbers' (indexed per atom).
    """
    element_symbols = ((1, 'H'), (6, 'C'), (7, 'N'), (8, 'O'), (9, 'F'))

    positions = props['_positions']
    # one NumPy array per Cartesian axis
    xs, ys, zs = (positions[:, axis].numpy() for axis in range(3))

    print(idx)
    for atom in range(len(zs)):
        atomic_number = props['_atomic_numbers'][atom]
        for z_value, symbol in element_symbols:
            if atomic_number == z_value:
                print(symbol, xs[atom], ys[atom], zs[atom])
    print('')
    
def get_number_of_atoms(idx,props):
    """Return the number of atoms in a molecule.

    Args:
        idx: molecule index (not used).
        props: property dict whose '_positions' entry is indexed as [atom, coordinate].

    Returns:
        The length of the z-coordinate column of '_positions', i.e. one count per atom row.
    """
    positions = props['_positions']
    return len(positions[:, 2])
    

In [2]:
# Re-create, exactly as they were configured when the checkpoint was saved,
# a) the QM9 dataset and its train/validation/test split, and
# b) the SchNet model architecture,
# so that the checkpoint loaded in the next cell fits the model.

# Path and hyperparameters of the trained model -- they must match the checkpoint.
MODEL_DIR = '../data/trained_models/qm9_i6_30f'
N_ATOM_BASIS = 30    # per-atom feature size; also the input size (n_in) of the output module
N_FILTERS = 30
N_GAUSSIANS = 20
# NOTE(review): the directory name says "i6" but 5 interaction blocks are built here;
# the (strict) load_state_dict in the next cell only succeeds if this matches -- confirm.
N_INTERACTIONS = 5
CUTOFF = 4.

# a) Load QM9 dataset
from schnetpack.datasets import QM9
qm9data = QM9('./qm9.db', remove_uncharacterized=True)

# split the dataset exactly as during training, using the saved split file
train, val, test = spk.data.train_test_split(qm9data, split_file=MODEL_DIR + '/split.npz')
train_loader = spk.AtomsLoader(train, batch_size=100, shuffle=True)

# loader for the test split
test_loader = spk.AtomsLoader(test, batch_size=100)

# single-atom reference values for U0, as used in the initial training conditions
atomrefs = qm9data.get_atomref(QM9.U0)

# mean / standard deviation of U0 on the training set, computed per atom
# (divide_by_atoms=True) and using the single-atom references
means, stddevs = train_loader.get_statistics(
    QM9.U0, divide_by_atoms=True, single_atom_ref=atomrefs)

# b) SchNet representation model, exactly as saved in the checkpoint file
schnet = spk.representation.SchNet(
    n_atom_basis=N_ATOM_BASIS, n_filters=N_FILTERS, n_gaussians=N_GAUSSIANS,
    n_interactions=N_INTERACTIONS, cutoff=CUTOFF,
    cutoff_network=spk.nn.cutoff.CosineCutoff
)
# atom-wise output module predicting U0; its input size must equal the representation size
output_U0 = spk.atomistic.Atomwise(n_in=N_ATOM_BASIS, atomref=atomrefs[QM9.U0], property=QM9.U0,
                                   mean=means[QM9.U0], contributions=True, stddev=stddevs[QM9.U0])


# full atomistic model: representation + output module
model = spk.AtomisticModel(representation=schnet, output_modules=output_U0)


  properties[pname] = torch.FloatTensor(prop)


In [3]:
# Load the saved checkpoint (a state dict) of the trained model.
checkpoint_path = '../data/trained_models/qm9_i6_30f/qm9_i6_30f_20g-1000-500-4_300.pth'
# map_location='cpu' lets a checkpoint saved from a GPU run load on a CPU-only machine.
# torch.load unpickles the file, so only load checkpoints from trusted sources.
load_checkpoint = torch.load(checkpoint_path, map_location='cpu')


# load state dictionary from loaded checkpoint (strict: keys must match the architecture)
model.load_state_dict(load_checkpoint)
# inference only: put the model in evaluation mode before any forward pass
model.eval()


# Hook our function onto the FIRST interaction block so its output is captured during
# the forward pass. Keep the handle and drop the hook from a previous run of this cell,
# so re-running the cell does not stack duplicate hooks.
if 'interactions_hook_handle' in globals():
    interactions_hook_handle.remove()
interactions_hook_handle = model.representation.interactions[0].register_forward_hook(interactions_hook)


<torch.utils.hooks.RemovableHandle at 0x7f86af343e10>

In [4]:

# Device used for the forward pass (everything runs on the CPU).
device = 'cpu'

# Note: `test_loader` is already built, with identical arguments, in the
# data-loading cell, so it is not re-created here.

# atoms converter: turns the atoms object returned by qm9data.get_properties()
# into the inputs passed to model(...) in the evaluation loop
converter = spk.data.AtomsConverter(device=device)

# SchNetPack calculator wrapping the trained model; energy=QM9.U0 names the model
# output used as the energy
calculator = spk.interfaces.SpkCalculator(model=model, device=device, energy=QM9.U0)


In [6]:
# Number of molecules to evaluate. They are taken from the full QM9 dataset in index
# order (0, 1, ..., n-1) -- not from the `test` split created earlier.
print('how many molecules do you want to evaluate?')
number_of_inputs = int(input())

# width of the hooked layer's per-atom output (n_atom_basis of the SchNet model)
N_FEATURES = 30

# One row of N_FEATURES values per evaluated molecule that contains an oxygen atom.
# Rows are collected in a list and stacked once at the end (repeated np.vstack is quadratic).
oxygen_rows = []

# for each molecule, run a forward pass and capture the hooked interaction layer's output
for idx in range(number_of_inputs):

    # load atomistic data and properties from qm9data
    at, props = qm9data.get_properties(idx)

    # print molecule number in dataset and atom coordinates
    print_atom_coordinates(idx, props)

    # attach the calculator (the forward pass below calls the model directly)
    at.calc = calculator

    # convert atomistic data to inputs for the model
    inputs = converter(at)

    # Reset before the pass so that, if the hook does not fire, we fail loudly instead
    # of silently reusing the previous molecule's output.
    interactions_output = None

    # Forward pass; inference only, so no autograd graph is needed. The hook registered
    # on interactions[0] stores that block's output in `interactions_output`.
    with torch.no_grad():
        model(inputs)
    if interactions_output is None:
        raise RuntimeError('interactions_hook did not fire: register it on the model before running this cell')

    number_of_atoms = get_number_of_atoms(idx, props)

    # per-atom features as a 2D (number_of_atoms, N_FEATURES) tensor
    atom_features = interactions_output.reshape(number_of_atoms, N_FEATURES)

    # Store the features of the FIRST oxygen atom (Z == 8) of the molecule only;
    # molecules without oxygen contribute no row.
    for i in range(number_of_atoms):
        if props['_atomic_numbers'][i] == 8:
            oxygen_rows.append(atom_features[i].detach().cpu().numpy())
            break

# (n_molecules_with_oxygen, N_FEATURES) float64 array; stays 2D even if no oxygen was found
data = np.array(oxygen_rows, dtype=float).reshape(-1, N_FEATURES)

# save to csv
savetxt('../data/data.csv', data, delimiter=',')

            



how many molecules do you want to evaluate?
300
0
C -2.8340169e-06 2.3049886e-06 -1.4378233e-07
H 0.014845718 -1.0918331 -0.0060250196
H 1.0244261 0.3779494 -0.007724565
H -0.52811974 0.36172476 -0.88464487
H -0.5111183 0.3521308 0.89839613

1
N -0.03160822 0.021159047 0.060332563
H 0.026075294 -0.9904034 -0.029608395
H 0.9246072 0.35579658 -0.030988995
H -0.5114599 0.34058344 -0.77777386

2
O -0.056248687 0.035458494 0.0007260183
H 0.042878296 -0.9215089 -0.00534094
H 0.84990215 0.35871136 -0.0061824406

3
C 0.5995395 0.0 0.0
C -0.5995395 0.0 0.0
H -1.6616386 0.0 0.0
H 1.6616386 0.0 0.0

4
C -0.0075630867 0.5570853 0.003070365
N 0.008071566 -0.59453905 -0.0032767907
H -0.022041854 1.6235689 0.008948269

5
C -0.008137181 0.5992856 0.002790173
O 0.0081545 -0.60058945 -0.0028012523
H 0.9208008 1.2085854 -0.0010057696
H -0.9532692 1.1830921 0.012220519

6
C -0.010403622 0.7647314 0.004216852
C 0.010404125 -0.7647325 -0.004217135
H 1.0031731 1.1788927 -0.0032747546
H -0.53377575 1.1627601 

43
C -0.5051477 1.4701715 0.23591569
C -0.5161165 -0.03155866 0.11930383
C 0.2842813 -0.74945444 -0.8804299
O 0.6834387 -0.71516526 0.49098954
H 0.42997557 1.8810065 -0.15465757
H -1.3402644 1.9064294 -0.32289335
H -0.6035613 1.777829 1.282179
H -1.4245102 -0.5216416 0.47502917
H -0.044245012 -1.7201506 -1.249905
H 0.91669863 -0.18413898 -1.5645137

44
C -0.5746446 1.3396131 -0.24876487
N -0.502484 -0.11268311 -0.19540729
C 0.8032212 -0.6990511 -0.47296318
C 0.25862837 -0.6751625 0.91367173
H 0.3459528 1.8371781 0.10421691
H -0.76034075 1.6563674 -1.2801846
H -1.4087237 1.684807 0.370748
H 0.7980399 -1.6164504 -1.0547763
H 1.6301389 -0.024380542 -0.6916084
H 0.72374654 0.0153784 1.6162366
H -0.15175563 -1.5747871 1.3635777

45
O 0.58268833 1.3023893 -0.39443073
C 0.3831935 -0.07980648 -0.3166366
C -0.023902148 -0.7664494 0.95567703
C -0.98426884 -0.69428307 -0.22474378
H -0.06951088 1.7391492 0.16258794
H 1.1167241 -0.5755198 -0.9437273
H 0.42650628 -1.7198842 1.2098128
H -0.24678102 -

66
C 0.47486082 1.4811767 -0.30994534
C 0.60074914 -0.014396568 -0.02700048
O 1.1117369 -0.3123183 1.2625037
C -0.7467008 -0.722982 -0.108747035
O -1.5888656 -0.47558346 -0.930893
H 1.4612727 1.9517622 -0.37279093
H -0.047447965 1.6337577 -1.2576145
H -0.094751164 1.9686095 0.48647782
H 1.2265673 -0.47986794 -0.8108536
H 1.9857252 0.08151487 1.339327
H -0.8775402 -1.513038 0.66285294

67
C -0.11099364 1.4434184 0.47809538
N -0.002798127 -0.002016144 0.43176952
C 0.040001623 -0.7200755 1.6851245
C 0.0540585 -0.6464266 -0.76813227
O 0.023135666 -0.11581948 -1.8586168
H 0.74327797 1.8817167 1.0089498
H -0.12971443 1.8105667 -0.5480162
H -1.0296397 1.7494612 0.99398965
H -0.86776423 -0.5423311 2.2763772
H 0.11974856 -1.7927177 1.4909891
H 0.90314484 -0.4131328 2.290089
H 0.13439374 -1.7437677 -0.63557583

68
C 0.018523093 1.929746 0.379988
C -0.019834245 0.42132896 0.32397303
O -0.13818796 -0.2750277 1.3099643
C 0.09629224 -0.24921714 -1.0405195
O 0.044388887 -1.6388092 -0.9297619
H 0.9562

H -1.6964426 0.6352329 -0.8339263
H -1.8139421 0.20165703 0.8795269
H -0.03536181 2.3933592 -0.25162515
H -0.0054291748 1.831497 1.4278737
H 1.8307295 0.2503917 0.8308742
H 1.6559936 0.68006015 -0.87867796

87
O 0.056591343 -1.7825387 -0.019702515
C 0.0003466606 -0.5835165 -0.006592764
C -1.1206632 0.4892062 0.0019281352
C 0.034038052 1.5315369 0.0173767
N 0.96318835 0.39859927 0.0073430813
H -1.7455553 0.48987544 -0.89248526
H -1.7515781 0.46973193 0.8918785
H 0.09645564 2.1707134 -0.86892277
H 0.09097603 2.1503446 0.91837865
H 1.9709184 0.34743536 0.010358569

88
O 0.061836094 -1.6717062 -0.5053475
C -0.030948734 -0.5368763 -0.1705357
C -1.1245205 0.52507997 -0.021252844
C 0.04092143 1.4125316 0.44217977
O 0.97423416 0.30086258 0.251401
H -1.6081415 0.8037834 -0.9586475
H -1.8824961 0.28449526 0.72565013
H 0.30127683 2.250689 -0.20692933
H 0.025418557 1.7283881 1.4869908

89
O 0.021096226 -1.7956383 -0.2329538
C 0.006837415 -0.6101543 -0.10412394
C -1.0861295 0.46382973 0.036955833
N

129
C -0.6239806 2.2652826 1.0627819
C -0.58159614 0.7355562 1.0160365
C 0.12677556 0.1890561 -0.22814913
C 0.17512423 -1.3416008 -0.28471738
C 0.88449264 -1.8769001 -1.5314317
H 0.38654172 2.6891649 1.0629447
H -1.1511468 2.673115 0.19294737
H -1.1357703 2.6259146 1.9606858
H -0.07832432 0.35471094 1.9147799
H -1.6049542 0.3387432 1.0526093
H -0.37609452 0.5697205 -1.1285543
H 1.1511855 0.5861045 -0.2662534
H 0.67813075 -1.7206153 0.61494565
H -0.84873635 -1.7369944 -0.24679668
H 0.3827926 -1.542168 -2.4463573
H 1.9214188 -1.525545 -1.5780299
H 0.90355206 -2.9712896 -1.5442541

130
C 1.1611094 2.325234 0.0014388984
C 1.1576172 0.79406464 -0.006777013
C -0.2546311 0.20148793 0.003528478
C -0.2636177 -1.3208859 -0.004598919
O -1.6161516 -1.7545445 0.005965832
H 2.1807752 2.7228246 -0.006167193
H 0.63985544 2.726048 -0.8748564
H 0.6567447 2.7165306 0.8918063
H 1.6993815 0.43245208 -0.89097977
H 1.716145 0.42300364 0.86296266
H -0.80204517 0.547377 0.8892867
H -0.81884295 0.556854 -0.8678

167
O -0.88570654 0.6970377 0.012255961
C 0.38539004 1.0800102 0.0021207342
C 1.1785717 -0.02061156 -0.011417538
N 0.33333373 -1.1075952 -0.009211886
N -0.86664766 -0.7420511 0.0043596397
H 0.5836886 2.139367 0.005927715
H 2.2494488 -0.12383694 -0.022250913

168
O -0.9004568 0.65601474 0.012166034
C 0.3831212 1.015315 0.0017471646
N 1.2172903 0.023664247 -0.011512083
C 0.35464114 -1.0484617 -0.008951138
N -0.9136649 -0.7442856 0.0046770107
H 0.59869295 2.0735366 0.0053549297
H 0.68331486 -2.077219 -0.017635273

169
O -0.9254665 0.6734253 0.012426243
C 0.38508758 1.0192392 0.0017217763
N 1.1820408 0.006856728 -0.011373053
N 0.35711917 -1.1268427 -0.009369129
C -0.8513255 -0.6799761 0.004618929
H 0.63063556 2.068328 0.0048841974
H -1.77398 -1.2363613 0.010562299

170
O -0.8930561 0.6964472 0.01239512
C 0.38589993 1.0195645 0.0017412509
N 1.1878803 0.0005344148 -0.011306899
N 0.32652298 -1.0809288 -0.008971691
N -0.87256485 -0.7378602 0.004280037
H 0.6574695 2.063337 0.004829375

171
O -0

212
C -1.1270889 0.63044655 0.014579304
N -0.018633228 1.3722496 0.0076402826
C 1.109553 0.660903 -0.0073827957
N 1.1977264 -0.66994673 -0.015491148
C 0.017533202 -1.2913334 -0.0071896757
N -1.1790891 -0.7023209 0.007844814
H -2.076441 1.1614954 0.02685377
H 2.0441365 1.2176301 -0.013603471
H 0.032277603 -2.3790653 -0.013247637

213
C -0.18469812 1.6604778 0.29684693
C -0.166823 0.11517793 0.29344603
C 0.5847124 -0.3995723 1.5419301
C -1.6157839 -0.42274737 0.29909807
C 0.51985365 -0.3581485 -0.9141149
C 1.0824183 -0.74648625 -1.903469
H 0.8322073 2.0626292 0.2894881
H -0.70966786 2.0472996 -0.58101314
H -0.6947716 2.0276537 1.193858
H 0.080908276 -0.054983836 2.451431
H 0.6144539 -1.4926202 1.5573354
H 1.6143392 -0.031382766 1.5577376
H -2.1661525 -0.072723635 -0.57870704
H -1.6271929 -1.5162842 0.2932877
H -2.1409266 -0.077116154 1.196023
H 1.5791768 -1.0896446 -2.7768614

214
C -0.18502106 1.664808 0.29691824
C -0.17505741 0.12041812 0.30761725
C 0.5867131 -0.40136307 1.5457551
C -1

228
C 0.08036236 2.9976254 0.12178206
C 0.079849005 1.5430411 0.18126272
C 0.09050402 0.34002477 0.2447114
C 0.0884185 -1.1263555 0.26616716
C -1.3318263 -1.6922961 0.12158684
O 0.96843016 -1.6716083 -0.71692324
H 0.91577727 3.368351 -0.4823125
H -0.8456874 3.382235 -0.3194757
H 0.17790012 3.4338155 1.1220255
H 0.50518066 -1.4697646 1.2210897
H -1.9741023 -1.3463966 0.9368345
H -1.7780621 -1.3673269 -0.8239179
H -1.2920516 -2.7847133 0.13512275
H 0.74869937 -1.2549838 -1.5575789

229
C 0.98832166 1.3326411 0.43708977
C -0.4141936 0.76813895 0.45810932
O -1.2823974 1.1874572 1.18213
C -0.7363391 -0.36790487 -0.53935206
C 0.34427404 -1.3060999 -0.7997951
C 1.2414796 -2.0802298 -1.0006558
H 1.6983427 0.56856734 0.7694572
H 1.2798138 1.6022574 -0.5839898
H 1.0437311 2.204231 1.0895425
H -1.6291453 -0.878017 -0.1654198
H -1.0306201 0.12622719 -1.4774117
H 2.0296488 -2.7685916 -1.1815431

230
C 0.23931547 2.3464477 0.010852831
C 0.40333015 0.84032685 0.0008868977
O 1.471841 0.28993654 -0.012

H -0.026073365 2.1432781 -1.0701194
H -0.07194969 2.5527952 0.6552363
H 1.0960821 -0.016249089 -0.56616944
H 1.397896 -0.7932973 1.7751057
H 2.3300712 0.69324404 1.5130659
H 0.7988986 0.7569708 2.3952205
H -1.413277 0.24651828 0.95983624
H -2.19865 -1.4003334 -0.34419122

250
C 0.97945106 1.8619744 0.2776861
C 0.9254399 0.34188372 0.17995349
C 1.5941969 -0.36501822 1.3530612
O -0.46406925 -0.055093344 0.15236567
C -0.9299842 -0.7034941 -0.9270798
O -2.0638669 -1.0659494 -1.0298145
H 2.0184174 2.2039135 0.30476978
H 0.48487094 2.3250544 -0.5804722
H 0.47842288 2.2031734 1.1883858
H 1.3942621 0.017550142 -0.759683
H 1.1101925 -0.084252864 2.2931292
H 1.5268229 -1.450751 1.242958
H 2.6509824 -0.08759621 1.4109092
H -0.15311097 -0.862283 -1.702048

251
C -0.5427133 2.2461572 -0.42692298
C -0.52146536 0.71909356 -0.4536726
O -0.25602046 0.21887927 -1.7611507
C 0.4710687 0.13648511 0.5719813
C 0.42918 -1.3763388 0.6297505
O 0.40406886 -2.0158904 1.6508129
H 0.45397422 2.6516583 -0.64302766
H

266
C 1.0671483 2.636078 0.28079134
C 1.0982454 1.1079654 0.18121542
C -0.30262643 0.4848633 0.17112583
C -0.3402794 -1.0543104 0.16371079
C -1.7769314 -1.5576874 0.36016494
C 0.26792848 -1.6523954 -1.1132302
H 2.0771258 3.0581205 0.286715
H 0.5275047 3.07686 -0.5650061
H 0.5640135 2.9634848 1.1976414
H 1.6445446 0.8182063 -0.72425056
H 1.6693957 0.6982945 1.0253797
H -0.84943134 0.84243214 1.054674
H -0.8614499 0.8602632 -0.6990819
H 0.25767404 -1.4049951 1.0181042
H -2.4211926 -1.234197 -0.46645796
H -1.814308 -2.6515388 0.39914566
H -2.211751 -1.1731912 1.2889273
H 1.3188163 -1.3763007 -1.2389427
H 0.21471412 -2.7461674 -1.0969689
H -0.27633816 -1.3084215 -2.0015244

267
C 1.1138065 2.6260688 0.20076706
C 1.093152 1.0956389 0.20655648
C -0.32212654 0.5118986 0.2555211
C -0.35663202 -1.0204629 0.1879417
C -1.7742789 -1.5790719 0.31584442
O 0.28105596 -1.5076588 -0.99328864
H 2.1379519 3.011109 0.1697528
H 0.582271 3.0270567 -0.6698729
H 0.63186646 3.035312 1.0960472
H 1.6011951 0.706

C -0.02199721 0.23552921 0.83303225
O -1.3943117 -0.16036472 0.78649557
C 0.69916433 -0.2601673 -0.43976086
C 0.8485837 -1.7789257 -0.4645719
O 0.0022761826 0.21100542 -1.5949553
H 1.1561235 2.0277576 1.1953933
H -0.23148678 2.2672613 0.12393762
H -0.48216024 2.076447 1.8778484
H 0.3964437 -0.28624654 1.7021883
H -1.8304284 0.4469871 0.1782824
H 1.6884509 0.21038988 -0.49135464
H 1.2783315 -2.1054163 -1.4154559
H 1.498258 -2.1243324 0.3461822
H -0.1300747 -2.254111 -0.33391994
H -0.77410305 -0.35208535 -1.6906849

284
C -0.7535421 1.7136879 -0.112333685
C -0.35279915 0.24305102 -0.059888024
C -1.5667529 -0.6610507 -0.28865868
C 0.862387 -0.17471917 -0.9524816
C 1.4517901 -0.9636141 0.24691969
C 0.47989786 -0.20263116 1.1874754
H 0.11037136 2.3670497 0.050601505
H -1.1888041 1.9718366 -1.0853764
H -1.5000386 1.9491243 0.65592307
H -2.0214562 -0.46752238 -1.2672526
H -1.293422 -1.7211709 -0.25363174
H -2.333118 -0.49026567 0.47643724
H 0.63533634 -0.7357681 -1.8637669
H 1.4817307 0.68577