In [13]:
import schnetpack as spk
import torch
import schnetpack.nn 
import schnetpack.data
import pandas as pd
import scipy.linalg as la
from schnetpack.datasets import QM9

# Load best model where it was stored after training

best_model = torch.load("./trained_models/qm911/best_model", map_location=torch.device('cpu'))

# Download QM9 dataset to use in evaluating model
qm9data = QM9('./qm9.db', download=True, remove_uncharacterized=True)

# Load split file 
train, val, test = spk.data.train_test_split(qm9data, split_file='./trained_models/qm911/split.npz')

#set up device and atoms converter for input

device = 'cpu'

converter = spk.data.AtomsConverter(device=device)

test_loader = spk.AtomsLoader(test, batch_size=100)
converter = spk.data.AtomsConverter(device=device)

at, props = qm9data.get_properties(idx=0)

calculator = spk.interfaces.SpkCalculator(model=best_model, device=device, energy=QM9.U0)
at.set_calculator(calculator)


print('Prediction:', at.get_total_energy())


Prediction: -1103.222900390625


In [14]:
#Extract embedding from schnetpack
from schnetpack.representation import SchNet

model = SchNet()

embedding_output=None

def embedding_hook(self, inp_tensor, out_tensor):
   # Global allows us to utilize embedding_output outside the current function scope
   global embedding_output 
   embedding_output=out_tensor

model.embedding.register_forward_hook(embedding_hook)

inputs = converter(at)
model(inputs)

print(embedding_output)

tensor([[[-0.2131, -0.1541,  0.7155,  2.2288, -1.0402,  0.5182, -1.4842,
           0.6280,  2.3356, -1.6524,  0.2371,  0.3101, -0.2030, -0.2110,
          -0.5656,  1.4761, -0.2728, -0.7652, -0.0495, -1.1851, -0.3630,
           0.9301, -1.0113,  0.1016,  1.5456,  0.0256,  1.1504,  0.4426,
           0.8667, -0.1562, -0.4621,  0.3061,  0.9261,  0.2881, -1.5597,
           1.2288,  0.9443,  1.6847, -1.1722, -1.7441, -0.2525,  0.2136,
           1.6990, -1.3576, -0.4959, -0.0181,  0.9218,  0.2518,  0.7842,
          -0.1016, -1.0702,  0.1549, -0.3374, -0.3830,  1.2244, -0.0261,
          -0.6059,  0.3950, -1.9123, -0.4071,  0.6244,  1.0629, -0.5307,
           0.6747, -0.7225,  0.0946,  1.1752, -0.4311,  0.1213,  1.4363,
           0.2324, -1.2943, -0.2397,  0.9655,  0.7573, -0.2368,  1.8782,
          -0.8502, -1.3292, -0.1816, -0.9571,  0.0477, -0.3671, -0.1935,
           2.8333, -0.7572, -0.1475, -1.1664,  0.9843, -0.5542,  0.4030,
          -0.2721, -0.3766,  1.0638,  0.2676, -0.48

In [None]:


### ANOTHER WAY OF DOING IT, given in issues

activation = {}

inputs = converter(at)

def get_activation(name):
    def hook(module, input, output):
        activation[name] = output.detach()
    return hook

model = SchNet()
model.embedding.register_forward_hook(get_activation('Embedding'))
output = model(inputs)

print(activation['Embedding'])
    