In [14]:
import torch
import schnetpack as spk
from schnetpack.datasets import QM9
from schnetpack import AtomsData
import numpy as np

def hook(self, inp_tensor, out_tensor):
    """Forward hook that stores a module's output in the module-level ``layer``.

    Register it with ``some_module.register_forward_hook(hook)``; PyTorch then
    calls it after every forward pass of that module.

    Args:
        self: the module the hook is registered on (PyTorch passes it as the
            first positional argument). It is that submodule, not the whole model.
        inp_tensor: the positional inputs given to the module's forward (unused).
        out_tensor: the module's output; it is saved to the global ``layer``.

    Returns:
        None, so PyTorch leaves the module's output unchanged.
    """
    # `global` makes the assignment rebind the module-level name ``layer`` (rather
    # than create a local), so the captured output is readable after the forward
    # pass returns.
    global layer
    layer = out_tensor

# Model hyperparameters. Set these to exactly the values the checkpoint below was
# trained with, otherwise the rebuilt network will not match the saved weights.
n_atom_basis = 128    # features describing each atomic environment
n_filters = 128       # filters in the continuous-filter convolutions
n_gaussians = 50      # Gaussian functions used to expand interatomic distances
n_interactions = 6    # number of SchNet interaction blocks
cutoff = 50.0         # cutoff radius passed to SchNet

# Load the QM9 database. With download=False nothing is fetched, so the .db file
# must already exist at this path.
qm9_filepath = 'data/datasets/QM9/qm9.db'
data = QM9(qm9_filepath, download=False, remove_uncharacterized=True)

# Alternatively, load your own ASE database (.db file) instead of QM9:
# data_filepath = 'data/datasets/<your_database>.db'
# available_properties = ['N/A']
# data = AtomsData(data_filepath, available_properties=available_properties)

# Atom reference values for the U0 property; the U0 entry is later handed to the
# output module via its `atomref` argument.
atomrefs = data.get_atomref(QM9.U0)

# SchNet representation. These arguments should match the ones the checkpoint
# was trained with, otherwise loading its state dict will not line up.
schnet = spk.representation.SchNet(
    n_atom_basis=n_atom_basis,
    n_filters=n_filters,
    n_gaussians=n_gaussians,
    n_interactions=n_interactions,
    cutoff=cutoff,
    cutoff_network=spk.nn.cutoff.CosineCutoff,
)

# Atom-wise output head for U0. Its input is SchNet's per-atom representation,
# whose width is n_atom_basis -- not n_filters. The two only happen to coincide
# in this configuration (both 128), which hid the mismatch.
output_U0 = spk.atomistic.Atomwise(n_in=n_atom_basis, atomref=atomrefs[QM9.U0])

# Full model: representation followed by the output head.
model = spk.AtomisticModel(representation=schnet, output_modules=output_U0)


# A trained SchNet model must already be saved as a .pth state dict before this
# runs -- train one first and point checkpoint_path at the saved file.
checkpoint_path = 'data/trainedmodels/model1/trainingcheckpoints/trained-950.pth'
# map_location remaps every stored tensor onto the CPU, so a checkpoint saved
# during a GPU run still loads on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
load_checkpoint = torch.load(checkpoint_path, map_location='cpu')

# Alternative checkpoint file name: qm9_i6_30f_20g-1000-500-4_300.pth
# Copy the saved parameters into the model (strict by default: missing or
# unexpected keys raise an error).
model.load_state_dict(state_dict=load_checkpoint)

# Print the name of every parameter/buffer. Copy a key from this listing to
# locate the specific SchNet layer you want to inspect.
print(model.state_dict().keys())

# Device used for the forward pass.
device = 'cpu'
# The converter turns a structure into the input the SchNet model expects.
converter = spk.data.AtomsConverter(device=device)
# First molecule (index 0) of the dataset, together with its stored properties.
at, props = data.get_properties(0)
inputs = converter(at)

# Capture the output of interaction block 5 -- the last one, given
# n_interactions = 6 -- via the forward hook, which writes it into `layer`.
layer = None
hook_handle = model.representation.interactions[5].register_forward_hook(hook)
try:
    # Only activations are needed, so skip building the autograd graph.
    with torch.no_grad():
        model(inputs)
finally:
    # Remove the hook so later forward passes do not keep overwriting `layer`.
    hook_handle.remove()

if layer is None:
    raise RuntimeError("forward hook on interactions[5] did not fire")

# Convert to a NumPy array. detach() drops any autograd link, and cpu() makes this
# work when `device` is a GPU as well (it is a no-op for CPU tensors).
int_layer_5 = layer.detach().cpu().numpy()
print(int_layer_5)


odict_keys(['representation.embedding.weight', 'representation.distance_expansion.width', 'representation.distance_expansion.offsets', 'representation.interactions.0.filter_network.0.weight', 'representation.interactions.0.filter_network.0.bias', 'representation.interactions.0.filter_network.1.weight', 'representation.interactions.0.filter_network.1.bias', 'representation.interactions.0.cutoff_network.cutoff', 'representation.interactions.0.cfconv.in2f.weight', 'representation.interactions.0.cfconv.f2out.weight', 'representation.interactions.0.cfconv.f2out.bias', 'representation.interactions.0.cfconv.filter_network.0.weight', 'representation.interactions.0.cfconv.filter_network.0.bias', 'representation.interactions.0.cfconv.filter_network.1.weight', 'representation.interactions.0.cfconv.filter_network.1.bias', 'representation.interactions.0.cfconv.cutoff_network.cutoff', 'representation.interactions.0.dense.weight', 'representation.interactions.0.dense.bias', 'representation.interactio