In [22]:
import schnetpack as spk
import torch
import schnetpack.nn 
import schnetpack.data
import pandas as pd
import scipy.linalg as la
from schnetpack.datasets import QM9

# Load best model where it was stored after training

model = torch.load("./trained_models/qm911/best_model", map_location=torch.device('cpu'))

# Download QM9 dataset to use in evaluating model
qm9data = QM9('./qm9.db', download=True, remove_uncharacterized=True)

# Load split file 
train, val, test = spk.data.train_test_split(qm9data, split_file='./trained_models/qm911/split.npz')

#set up device and atoms converter for input

device = 'cpu'

converter = spk.data.AtomsConverter(device=device)

test_loader = spk.AtomsLoader(test, batch_size=100)
converter = spk.data.AtomsConverter(device=device)

at, props = qm9data.get_properties(idx=0)

calculator = spk.interfaces.SpkCalculator(model=model, device=device, energy=QM9.U0)
at.set_calculator(calculator)


print('Prediction:', at.get_total_energy())


Prediction: -1103.222900390625


In [23]:
#Extract embedding from schnetpack
from schnetpack.representation import SchNet

activation = {}

inputs = converter(at)

def get_activation(name):
    def hook(module, input, output):
        activation[name] = output.detach()
    return hook

model = SchNet()
model.embedding.register_forward_hook(get_activation('Embedding'))
output = model(inputs)

print(activation['Embedding'])

tensor([[[ 1.0020e+00, -4.5615e-01,  9.3046e-01, -1.0296e+00,  4.3175e-01,
          -1.0405e+00, -1.5797e-01,  1.8883e-01, -2.1637e+00, -6.1232e-01,
          -4.4904e-01,  5.8696e-01, -2.9584e-02,  1.8830e-01,  5.9463e-01,
          -3.4088e-01, -3.7521e-01, -1.4308e+00, -5.3438e-01, -1.8368e+00,
          -1.1038e+00, -8.7094e-01,  5.9205e-01,  1.2761e+00, -1.6461e+00,
          -1.5379e-02,  7.2449e-01,  6.5014e-01,  6.9771e-01, -1.4652e-01,
           6.9538e-01,  8.3891e-01,  2.2921e+00, -6.9208e-01, -4.9742e-01,
           4.6223e-01, -9.9598e-01, -1.4892e-03, -4.4677e-01,  5.6274e-01,
          -5.2279e-01, -1.0060e+00,  7.2956e-01,  1.1741e+00, -4.9470e-01,
           1.2095e-01, -1.2147e+00,  1.7405e+00, -4.4640e-01, -3.9764e-01,
          -1.6745e+00,  1.9847e+00,  1.1028e+00,  1.4822e+00, -6.2252e-01,
          -5.8495e-02, -6.4114e-01,  2.7588e-01, -5.9091e-01, -3.4031e-02,
           7.2127e-01, -1.8101e-02,  1.1446e+00,  9.7523e-01, -6.9461e-01,
           8.5732e-01, -8