In [22]:
import torch
import schnetpack as spk
import math


import numpy as np
from numpy import savetxt

# Define functions

In [23]:
def print_molecule(props):
    """Print one line per recognised atom and return the z coordinates.

    Parameters
    ----------
    props : mapping
        Must provide ``'_positions'`` (2-D, indexed as ``[:, 0..2]`` and
        converted with ``.numpy()``) and ``'_atomic_numbers'`` (indexable
        per atom).

    Returns
    -------
    numpy.ndarray
        The z column of ``props['_positions']``; its length equals the
        number of atoms.

    Notes
    -----
    Only H, C, N, O and F (atomic numbers 1, 6, 7, 8, 9) are printed; atoms
    of any other element are silently skipped. A blank line is printed after
    the molecule.
    """
    positions = props['_positions']
    # One numpy column per Cartesian axis.
    x, y, z = (positions[:, axis].numpy() for axis in range(3))

    # (atomic number, element symbol) pairs that get printed.
    known_elements = ((1, 'H'), (6, 'C'), (7, 'N'), (8, 'O'), (9, 'F'))

    for i in range(len(z)):
        atomic_number = props['_atomic_numbers'][i]
        for number, symbol in known_elements:
            if atomic_number == number:
                print(symbol, x[i], y[i], z[i])
    print('')
    return z

In [24]:
def hook_layer(self, inp_tensor, out_tensor):
    """Forward-hook callback that stores the hooked module's output.

    The arguments follow the ``(module, input, output)`` order in which the
    callback is invoked once registered with ``register_forward_hook``;
    ``self`` is therefore the hooked module and is not used here.

    The output is published as the module-level variable ``layer_output`` so
    that code outside the hook can read it after a forward pass. Nothing is
    returned.
    """
    # Rebind the module-level name rather than a local one.
    globals()['layer_output'] = out_tensor

In [25]:
def convert_2D(number_of_atoms, layer_output, n_features=None):
    """Copy the first batch entry of a layer output into a 2-D float array.

    Parameters
    ----------
    number_of_atoms : int
        Number of rows (atoms) to copy.
    layer_output : array-like
        Indexed as ``layer_output[0][atom][feature]``; only batch entry 0
        is read.
    n_features : int, optional
        Number of feature columns to copy. When omitted, the full width of
        ``layer_output[0]`` is used, so layers of any width work (previously
        the width was hard-coded to 30 and narrower layers raised
        ``IndexError``).

    Returns
    -------
    numpy.ndarray
        A new float64 array of shape ``(number_of_atoms, n_features)``.

    Raises
    ------
    ValueError
        If ``layer_output[0]`` is not two-dimensional.
    IndexError
        If ``layer_output[0]`` has fewer rows or columns than requested.
    """
    first_entry = np.asarray(layer_output[0], dtype=float)
    if first_entry.ndim != 2:
        raise ValueError(
            'layer_output[0] must be 2-D (atoms x features), got %d-D'
            % first_entry.ndim
        )

    if n_features is None:
        n_features = first_entry.shape[1]

    # Slicing would silently truncate, so check bounds explicitly to keep
    # the IndexError that element-wise access used to raise.
    if number_of_atoms > first_entry.shape[0] or n_features > first_entry.shape[1]:
        raise IndexError(
            'requested (%d, %d) values but layer_output[0] has shape %s'
            % (number_of_atoms, n_features, first_entry.shape)
        )

    # np.array copies, so the result never aliases the input buffer.
    return np.array(first_entry[:number_of_atoms, :n_features], dtype=float)

# Load model and run forward passes using molecules of QM9 dataset

In [26]:
# Load QM9 dataset
from schnetpack.datasets import QM9

qm9data = QM9('./qm9.db', download=True, remove_uncharacterized=True)
# Load split file
train, val, test = spk.data.train_test_split(qm9data,split_file='./trained_models/qm9_i6_30f-moredata/split.npz')

# Load atom ref data
atomrefs = qm9data.get_atomref(QM9.U0)

# Define SchNet representation model
schnet = spk.representation.SchNet(
    n_atom_basis=30, n_filters=30, n_gaussians=20, n_interactions=6,
    cutoff=4. , cutoff_network=spk.nn.cutoff.CosineCutoff
)

# Define SchNet output model and property to be predicted
output_U0 = spk.atomistic.Atomwise(n_in=30, atomref=atomrefs[QM9.U0])

# Define atomistic model
model = spk.AtomisticModel(representation=schnet,output_modules=output_U0)

# Set up device for the forward passes. Defined before loading the checkpoint
# so the stored tensors can be mapped onto it.
device='cpu'

# Load saved checkpoint file
# (alternative checkpoint name noted by the author: qm9_i6_30f_20g-1000-500-4_300.pth)
checkpoint_path = './trained_models/qm9_i6_30f-moredata/qm9_benchmark.pth'
# map_location keeps loading independent of the device the checkpoint was saved on.
load_checkpoint = torch.load(checkpoint_path, map_location=device)

# Load model's state dictionary from saved checkpoint.
# NOTE(review): strict=False silently ignores missing/unexpected keys —
# confirm the checkpoint really matches this architecture.
model.load_state_dict(load_checkpoint,strict=False)


number_of_inputs = 5000

# Load atoms converter
converter = spk.data.AtomsConverter(device=device)

# Load spk calculator
calculator = spk.interfaces.SpkCalculator(model=model, device=device, energy=QM9.U0)

# Accumulators for the CSV files written at the end of the cell. Each starts
# with a single zero row; while the collection code below stays commented out,
# that zero row is all that gets saved.
data = np.zeros((1,30))
datahae = np.zeros((1))
dataoae = np.zeros((1))
datah = np.zeros((1,30))

# Register the forward hook ONCE. Registering inside the loop stacks a new
# hook on every iteration, so each forward pass would invoke every hook
# registered so far.
hook_handle = model.output_modules[0].out_net[1].out_net[0].register_forward_hook(hook_layer)
try:
    for idx in range(number_of_inputs):

        # Load data for molecule
        at, props = qm9data.get_properties(idx)

        # Print molecule for identification
        print(idx)
        z = print_molecule(props)
        number_of_atoms=len(z)

        # Set calculator on molecule
        at.set_calculator(calculator)

        # Convert qm9 data to machine-readable form
        inputs = converter(at)

        # Reset the captured output so a stale value from the previous
        # molecule can never be mistaken for this one's.
        layer_output=None

        # Forward pass the molecule through the model; hook_layer stores the
        # hooked layer's output in the global `layer_output`. No gradients
        # are needed for inspection.
        with torch.no_grad():
            model(inputs)

        if layer_output is None:
            raise RuntimeError('forward hook did not capture any output for molecule %d' % idx)

        layer_output = layer_output.detach().numpy()

        print(layer_output[0])

        # NOTE(review): `yi` is read from the schnetpack module on every
        # iteration — presumably a locally patched schnetpack publishes the
        # per-atom contributions there during the forward pass; confirm.
        # The import must stay inside the loop because `yi` is rebound to a
        # numpy array right below.
        from schnetpack.atomistic.output_modules import yi

        yi=yi.detach().numpy()

        # --- Collection code, currently disabled -------------------------
        # NOTE(review): the printed activations appear to have 15 columns per
        # atom while `data`/`datah` are seeded with 30 columns — reconcile the
        # widths before re-enabling.
        #
        # Convert layer tensor to 2D array
    #    rows = convert_2D(number_of_atoms,layer_output)

        # Save the vector of every oxygen atom encountered
    #    for i in range(number_of_atoms):
    #        if props['_atomic_numbers'][i] == 8:
    #            data = np.vstack((data,rows[i]))

    #    for i in range(number_of_atoms):
    #        if props['_atomic_numbers'][i] == 1:
    #            datah = np.vstack((datah,rows[i]))
    #    for i in range(number_of_atoms):
    #        if props['_atomic_numbers'][i] == 1:
    #            datahae = np.vstack((datahae,yi[0][i]))
    #    for i in range(number_of_atoms):
    #        if props['_atomic_numbers'][i] == 8:
    #            dataoae = np.vstack((dataoae,yi[0][i]))
finally:
    # Always detach the hook, including on KeyboardInterrupt, so re-running
    # cells against the same model does not leave stale hooks behind.
    hook_handle.remove()


savetxt('../../../data/data.csv',data,delimiter=',')
savetxt('../../../data/datah.csv',datah,delimiter=',')
savetxt('../../../data/hae.csv',datahae,delimiter=',')
savetxt('../../../data/oae.csv',dataoae,delimiter=',')

None
0
C -2.8340169e-06 2.3049886e-06 -1.4378233e-07
H 0.014845718 -1.0918331 -0.0060250196
H 1.0244261 0.3779494 -0.007724565
H -0.52811974 0.36172476 -0.88464487
H -0.5111183 0.3521308 0.89839613

[[ 9.7616017e-02 -1.2053251e-03 -6.9095653e-01 -4.4941849e-01
  -2.3959714e-01 -6.9309735e-01  1.6046619e+01 -6.9309264e-01
  -6.9311059e-01  9.9842596e-01 -6.9041473e-01 -5.8983636e-01
  -6.9314468e-01  1.3854115e+00 -6.9047785e-01]
 [ 1.2093080e+00  1.9985193e-01 -4.6824938e-01  4.1875629e+00
   8.7090671e-01 -6.2404108e-01  4.8403955e+00 -6.1168873e-01
  -6.5266472e-01 -6.8837756e-01 -4.1238487e-01  1.1816982e+00
  -6.8977404e-01 -6.8532079e-01 -6.8048668e-01]
 [ 1.2092948e+00  1.9984829e-01 -4.6824929e-01  4.1875606e+00
   8.7088680e-01 -6.2404293e-01  4.8404016e+00 -6.1168849e-01
  -6.5266412e-01 -6.8837756e-01 -4.1238403e-01  1.1816961e+00
  -6.8977410e-01 -6.8532068e-01 -6.8048668e-01]
 [ 1.2091664e+00  1.9981706e-01 -4.6822920e-01  4.1875725e+00
   8.7079811e-01 -6.2405694e-01  4.84



22
C 0.6809802 0.0 0.0
C -0.6809802 0.0 0.0
C -1.887666 0.0 0.0
C 1.887666 0.0 0.0
H -2.9496 0.0 0.0
H 2.9496 0.0 0.0

[[-0.1624192   4.8704157  -0.5512613  -0.55628437 -0.64472365  2.5482266
  -0.6910157   6.5321536   3.1657214   0.23737246 -0.14212161 -0.31702772
   6.588573    0.3871416   7.8035474 ]
 [-0.16241914  4.8704157  -0.55126125 -0.55628425 -0.6447236   2.5482264
  -0.6910157   6.5321546   3.1657214   0.23737198 -0.14212167 -0.31702766
   6.5885725   0.38714135  7.8035464 ]
 [-0.08215797  2.6024208  -0.64507365  0.02578586 -0.6667526   1.1588303
  -0.6823726   3.3255386   2.2138715   0.6548282  -0.4970417  -0.4078414
   2.5495436  -0.28397828  2.7311156 ]
 [-0.08215773  2.6024194  -0.6450736   0.02578604 -0.6667526   1.1588303
  -0.6823726   3.325539    2.2138717   0.65482795 -0.49704182 -0.40784138
   2.5495436  -0.2839783   2.7311149 ]
 [-0.31271973  0.89525163 -0.5153811   2.759934   -0.5710903  -0.65037405
  -0.64883274 -0.4531309   1.0619805  -0.5696076  -0.43547952 -0

[[ 1.23435020e-01 -3.52860749e-01  1.61016166e-01 -1.41006708e-03
  -2.66960919e-01 -6.27913237e-01  4.49116850e+00 -6.83367908e-01
  -6.65890634e-01  9.68042016e-02  3.09682012e-01  6.68881536e-01
  -6.82227433e-01 -3.30701232e-01 -5.57415664e-01]
 [-2.54440337e-01  3.72209787e-01 -5.78793168e-01  6.06904030e-02
   3.59609127e-01  3.65906119e-01  1.59746265e+00 -6.39483333e-01
  -4.87089574e-01 -3.48549306e-01 -4.46982265e-01  6.26545429e-01
  -6.44998252e-01 -7.96321630e-02  2.51838863e-01]
 [-3.63050342e-01  2.53015637e-01 -3.30360860e-01  2.54391611e-01
   1.57325745e-01 -5.00222862e-01  2.19448161e+00 -6.52704000e-01
  -5.32466292e-01 -1.83020532e-01 -3.06430310e-01  9.37516570e-01
  -6.13198400e-01  1.46441996e-01  2.27694809e-01]
 [-3.63025993e-01  2.53055394e-01 -3.30371261e-01  2.54395366e-01
   1.57084525e-01 -5.00173986e-01  2.19425678e+00 -6.52706206e-01
  -5.32458603e-01 -1.82952166e-01 -3.06406528e-01  9.37567353e-01
  -6.13189280e-01  1.46455824e-01  2.27613151e-01]
 [-1

[[ 8.7133288e-02 -1.8740511e-01  8.5218251e-02  3.4794068e-01
  -3.6644161e-01 -3.2377607e-01  2.9802737e+00 -6.9082326e-01
  -6.5929347e-01  1.2333058e+00  7.9564011e-01  6.7145944e-01
  -6.5642011e-01 -4.7376537e-01 -2.9282135e-01]
 [ 1.1406268e+00  3.7719929e-01 -6.4177090e-01  2.8769964e-01
  -4.4121000e-01  8.7804437e-02  9.0806329e-01 -6.2295300e-01
  -6.0961783e-01  2.1263342e+00  9.8416555e-01 -3.7566423e-02
  -6.8158323e-01 -4.2151669e-01 -5.3578615e-04]
 [ 8.7149382e-02 -1.8738776e-01  8.5276544e-02  3.4786010e-01
  -3.6647338e-01 -3.2377863e-01  2.9804356e+00 -6.9082338e-01
  -6.5929401e-01  1.2332031e+00  7.9562044e-01  6.7149997e-01
  -6.5642089e-01 -4.7374368e-01 -2.9285887e-01]
 [ 1.9207078e-01 -2.7695578e-01  1.2583542e-01  2.1031499e-02
  -2.5925174e-01 -3.1465214e-01  3.0705473e+00 -6.9000423e-01
  -6.3733995e-01  5.8691049e-01  7.7084959e-01  8.3621526e-01
  -6.7347437e-01 -4.6995807e-01 -3.5787204e-01]
 [-6.3225543e-01 -4.7843683e-01  2.6091105e-01 -6.0107517e-01
  

H -0.9655583 2.0404017 -0.09316861
H -1.2196443 1.8142298 1.6702896
H -0.9142307 -2.1782663 -0.37406856
H 0.726341 -2.2863443 -1.0582728
H 0.47658333 -2.2356112 0.70416516
H 1.3551111 -0.14688806 -1.755677

[[-0.34157485 -0.61556995 -0.33311176 -0.1532011   0.89428127 -0.6433973
   8.872276   -0.6915285  -0.69135535  5.45422    -0.49350935 -0.2555234
  -0.69300765 -0.58413875 -0.6726928 ]
 [-0.4270698   0.7864562   0.8585322  -0.21652731  0.09079921  0.13212699
  -0.6037868  -0.43232715 -0.06923902 -0.18306571 -0.2106485  -0.4984771
   4.289598    0.29463857  2.5227768 ]
 [-0.46228647  0.1709277   0.5558295   1.4757376  -0.5052199   2.872164
  -0.28545317  2.6875343   1.3515322   0.75682724  1.6297767   1.131064
  -0.45173144 -0.6923444   1.9519722 ]
 [ 0.83072424 -0.14868224 -0.38358772  0.2138868  -0.10700727 -0.50829244
   4.1764894  -0.68901825 -0.68766713  0.5066372  -0.35141864 -0.13513833
  -0.68150145  0.52723885 -0.26778248]
 [-0.6407158  -0.6479744   3.22986     0.19586504 -0

H -1.0120881 1.9578694 0.60887945
H -1.5794762 -0.24373159 -1.6352811
H -1.297035 -1.6578658 -0.60985535
H -2.2659814 -0.3059959 -0.00600338
H 1.2002164 -1.4826398 -0.92556787
H 2.0582016 0.066421434 -0.5048103
H 1.0757284 -0.0227479 1.8269489
H 0.2177659 -1.5718249 1.4062023

[[ 0.25079805 -0.23190594  0.2682109  -0.35456875 -0.41271704 -0.5859635
   3.8887472  -0.68408847 -0.6658174   0.07654727  0.89942575  0.7401079
  -0.67845905 -0.17872733 -0.5293784 ]
 [ 0.27337295  0.41594148 -0.62367153 -0.03934336  0.33827245  1.7347503
   0.6227498  -0.5712788  -0.5551797  -0.45979056 -0.20714447  0.23816025
  -0.6060019   0.19974548  0.8763343 ]
 [ 0.25080395 -0.23190892  0.26821172 -0.35456413 -0.41271526 -0.58596313
   3.8887205  -0.68408865 -0.6658175   0.07655829  0.8994435   0.7401067
  -0.67845905 -0.17873865 -0.52937835]
 [-0.39251637  0.47433734 -0.36725894 -0.09403747  0.08027571 -0.46464592
   1.6613586  -0.65440756 -0.5432514  -0.08590072 -0.25483328  0.7275847
  -0.57254666  0.4

H 1.6579394 0.4257211 1.5425892
H 2.478353 -0.58090097 0.33958152

[[ 3.77616525e-01 -4.42890465e-01  3.68345261e-01  2.44561434e-02
  -2.32737780e-01 -6.59640968e-01  5.11086893e+00 -6.83109224e-01
  -6.53779924e-01  4.13277149e-01  2.02998340e-01  5.66385865e-01
  -6.79146409e-01 -2.54985899e-01 -5.84951997e-01]
 [-4.81252760e-01  2.00898767e-01 -4.91943747e-01  3.28157187e-01
  -4.68295842e-01  3.33052278e-01  1.71939087e+00 -6.58117533e-01
  -5.25066018e-01  2.10170126e+00 -2.35066116e-01  9.27878141e-01
  -6.16402924e-01 -4.28174227e-01  1.03623629e-01]
 [-5.43738246e-01 -7.26640224e-04  4.21673059e-03  3.60038042e-01
  -3.58832985e-01 -4.42621708e-01  2.92422986e+00 -6.72951400e-01
  -5.92509687e-01  3.60499573e+00  1.74238324e-01  1.13851428e+00
  -6.05694413e-01 -3.89005661e-01  2.96676755e-01]
 [-1.26956701e-01 -2.27385938e-01  1.66965413e+00 -5.02986610e-01
   7.35612273e-01 -5.26168704e-01  2.16701031e+00 -6.67778313e-01
  -2.56937593e-01  4.86301303e-01 -4.77962673e-01 -1.7

C -1.2217767 -0.4254112 0.009063386
O -1.5203862 0.73662335 0.018326586
H 2.4145567 1.1987293 -0.016082177
H 2.1115346 -1.0306176 -0.024377037
H 0.10537184 -1.9650236 -0.011645644
H -1.9669752 -1.2483422 0.011400136

[[-0.61659014 -0.6816253   4.0713363  -0.08879328  1.5682399  -0.69118994
   6.7627177  -0.66918784 -0.31597018 -0.39746314 -0.6675099  -0.6626728
  -0.36881068  2.2135282   0.0237487 ]
 [-0.4038068   0.9736097  -0.4841168   0.8362864  -0.359879    0.75379074
   0.7135564   2.541642    0.24314326  2.256995   -0.27395838  1.6414652
  -0.19149345 -0.6906159   0.11778402]
 [ 0.3977021  -0.0596202   2.9576871  -0.5184517   6.7390013  -0.6521154
   3.2918365  -0.6905151  -0.67467296  0.5911155  -0.5464582  -0.5994819
   1.4408345   5.9645042   1.6729236 ]
 [-0.11988533 -0.45449513 -0.52060586  1.0015697  -0.40092352  3.5611548
   1.9311423   3.783648   -0.23477572  2.2097883  -0.04596633  1.0696492
  -0.6619516  -0.6929505  -0.57778674]
 [-0.29106992  0.18288815  0.4926064   0.

125
C 0.7109723 2.0540879 -0.039460648
C 0.78988546 0.6404294 0.5396623
N -0.50956607 -0.01078898 0.6203147
C -0.85790944 -1.0802972 -0.144141
O -0.13905254 -1.635817 -0.947671
H 1.70647 2.5077415 -0.073256
H 0.30594346 2.0348496 -1.0551943
H 0.070603155 2.7007704 0.5707422
H 1.406924 -0.0024695427 -0.09142478
H 1.2453208 0.66388947 1.537388
H -1.2089765 0.37682232 1.2338046
H -1.8995581 -1.4024749 0.05690479

[[ 7.3117244e-01 -4.5730376e-01 -2.9501688e-01  5.9038246e-01
  -2.4785557e-01 -6.6118777e-01  5.7299757e+00 -6.8259811e-01
  -6.7825359e-01  7.1264732e-01 -1.7748940e-01 -1.4507908e-01
  -6.8618315e-01 -3.5346997e-01 -6.7333752e-01]
 [-4.6940190e-01 -5.6667048e-01 -4.2988983e-01  9.6139026e-01
  -1.2112528e-01 -4.4250873e-01  3.1994836e+00 -6.6173124e-01
  -6.0781479e-01  3.1976547e+00 -1.7372531e-01  7.7684760e-01
  -6.8665665e-01 -5.9962136e-01 -5.3678346e-01]
 [ 2.5685847e-02 -9.4895959e-02  2.2952080e+00 -1.2809914e-01
   3.1712470e+00 -5.8871835e-01  2.0894005e+00 -6.899488

142
O -1.1219449 -1.2776604 0.09214229
C -0.90592533 -0.3789788 -0.68857545
C 0.26388115 0.53164035 -0.5725005
C 0.48125643 1.154802 0.7790999
N 1.36842 0.108615465 0.30177027
H -1.5848378 -0.18405849 -1.5452439
H 0.54290754 1.0860189 -1.4636819
H 0.9102511 2.151471 0.8249108
H -0.24609812 0.91533905 1.5498927
H 1.0858709 -0.778325 0.7213586

[[-0.40555632  0.05630815  1.0014228   0.41742766 -0.49250787 -0.5631926
  -0.2762266   0.28198433  0.69058573  0.01633561  0.19659734 -0.61557674
   4.670937    0.79359686  3.9810987 ]
 [ 0.32842195 -0.19472766  0.4503944   0.15041697 -0.31735203  3.300183
   0.33420074  2.6085093   0.08001035  0.13582784  0.80667305  1.6332486
  -0.5072377  -0.6885098   0.0418123 ]
 [-0.05544186  0.27239144 -0.6184193  -0.26118395 -0.32522494  0.63785195
   1.9164364  -0.6679435  -0.6530812   3.4761634  -0.48965704  0.20411867
  -0.56042135  0.7611382   0.6866646 ]
 [-0.35930154 -0.17586035 -0.51675624 -0.2575292  -0.2745327  -0.6296115
   3.8058963  -0.6839872 

158
C -0.687976 1.1508685 -0.4133764
C 0.62554616 0.43969205 -0.5816147
C 1.0441158 -0.2147326 0.7409022
C -0.30141667 -1.0160704 0.7496904
C -0.69029945 -0.3437845 -0.5729453
H -0.95100754 1.6031911 0.5414027
H -1.0277855 1.7120854 -1.2781218
H 1.2861955 0.65737355 -1.4123385
H 1.9107535 -0.8721382 0.64121306
H 1.2220726 0.46874723 1.5782135
H -0.13717736 -2.0917826 0.65453327
H -0.9789847 -0.84216726 1.5926342
H -1.2045507 -0.825638 -1.3959323



KeyboardInterrupt: 

# Multi-Dim Linear Regression on Last Layer