In [26]:
import schnetpack as spk
import torch
import schnetpack.nn 
import schnetpack.data
from schnetpack.datasets import QM9

# Load best model where it was stored after training

model = torch.load("./trained_models/qm911/best_model", map_location=torch.device('cpu'))

# Download QM9 dataset to use in evaluating model
qm9data = QM9('./qm9.db', download=True, remove_uncharacterized=True)

# Load split file 
train, val, test = spk.data.train_test_split(qm9data, split_file='./trained_models/qm911/split.npz')

#set up device and atoms converter for input

device = 'cpu'

converter = spk.data.AtomsConverter(device=device)

test_loader = spk.AtomsLoader(test, batch_size=100)
converter = spk.data.AtomsConverter(device=device)

at, props = qm9data.get_properties(idx=0)

calculator = spk.interfaces.SpkCalculator(model=model, device=device, energy=QM9.U0)
at.set_calculator(calculator)


print('Prediction:', at.get_total_energy())


Prediction: -1103.222900390625


In [27]:
#Extract embedding from schnetpack
from schnetpack.representation import SchNet

activation = {}

inputs = converter(at)

def get_activation(name):
    def hook(module, input, output):
        activation[name] = output.detach()
    return hook

model = SchNet()
model.embedding.register_forward_hook(get_activation('Embedding'))
output = model(inputs)

print(activation['Embedding'])

tensor([[[ 1.1473e+00, -1.4091e+00,  2.7091e-02,  1.0048e-01,  5.3784e-01,
           2.8431e-01,  1.8038e+00,  3.0531e-01,  1.8624e+00,  1.5431e+00,
           6.0254e-01, -1.7790e+00,  3.3343e-01,  8.6157e-01, -2.6138e-01,
          -2.2289e-01, -5.0458e-01, -3.6986e-01,  1.7041e-01,  6.3513e-03,
           1.2486e+00,  1.1366e+00,  1.5765e+00, -3.1541e-01,  3.1032e-01,
           7.3961e-01,  5.9152e-01,  6.4857e-01,  9.1421e-01, -9.5122e-02,
           5.3959e-01,  8.8846e-02, -2.9820e-01, -4.0632e-01,  2.1832e-02,
          -3.9870e-01,  1.4926e+00, -7.2099e-01,  7.9599e-01, -2.6449e-01,
          -1.1147e+00, -8.0482e-01,  7.0179e-01, -1.5013e+00,  6.6699e-01,
          -7.5498e-02, -3.1712e-01,  1.9212e-01,  1.1185e+00, -5.9226e-02,
          -2.3270e+00,  1.5636e+00,  1.0057e+00,  9.5719e-02, -5.5641e-01,
          -1.7229e+00,  7.9184e-01,  4.0852e-01,  9.5650e-01, -9.6671e-01,
          -6.9639e-01,  7.3960e-01,  5.6517e-01, -7.3105e-01,  9.9112e-02,
           6.4526e-01,  1