In [6]:
import torch
import schnetpack as spk
import math
from schnetpack import AtomsData
from schnetpack.datasets import QM9
import os

import numpy as np
from numpy import savetxt

from DLChem import gendbdataset

## Example of how to generate a db file from xyz
# Base path of the input dataset (no extension); '.db' is appended when it is
# opened below and it is also passed to gendbdataset.generate().
db_file_path = '../../data/datasets/quintessentialH/quintH48'
# Name of the property stored in / read from the database.
available_properties = 'energy'

# Output directory and the pieces that make up the exported CSV file name.
save_path = '../../data/schnet/'
main_name = 'rep'
sub_name = 'quintH'
# convention dataset.energy.#training-#filters
model_name = 'qm9energy10000-30'
# Element label; part of the output file name.
element = 'H'
# Passed to gendbdataset.generate() and used as the number of dataset
# entries that are processed.
number_inputs = 48

# MUST define qm9data for splitting of checkpoint file
qm9_file = '../../data/datasets/QM9/qm9.db'
qm9data = QM9(qm9_file, download=False, remove_uncharacterized=True)
# NOTE(review): this opens '<db_file_path>.db' before gendbdataset.generate()
# (called in a later cell) writes it — on a fresh kernel confirm the file
# already exists or move this line after the generation step.
dataset = AtomsData(db_file_path + '.db', available_properties=[available_properties])




In [7]:
# Stem of the exported CSV file: main_name + model_name + element + count + sub_name.
name_data = ''.join((main_name, model_name, element, str(number_inputs), sub_name))

# Build the database from the raw input (writes the file as a side effect).
gendbdataset.generate(db_file_path, available_properties, number_inputs)

# Locations of the trained model's artifacts.
split_file = '../../data/trainedmodels/' + model_name + '/split.npz'
checkpoint_path = '../../data/trainedmodels/' + model_name + '/trained.pth'

# SchNet hyperparameters — these must be the values the checkpoint was
# trained with, otherwise its state dict will not fit the rebuilt model.
n_atom_basis, n_filters = 30, 30
n_gaussians = 20
n_interactions = 3
cutoff = 4.0

# Number of dataset entries the extraction loop walks over.
index = number_inputs

Properties: [{'energy': array([0.], dtype=float32)}, {'energy': array([0.], dtype=float32)}, {'energy': array([0.], dtype=float32)}, {'energy': array([0.], dtype=float32)}, {'energy': array([0.], dtype=float32)}, {'energy': array([0.], dtype=float32)}, {'energy': array([0.], dtype=float32)}, {'energy': array([0.], dtype=float32)}, {'energy': array([0.], dtype=float32)}, {'energy': array([0.], dtype=float32)}, {'energy': array([0.], dtype=float32)}, {'energy': array([0.], dtype=float32)}, {'energy': array([0.], dtype=float32)}, {'energy': array([0.], dtype=float32)}, {'energy': array([0.], dtype=float32)}, {'energy': array([0.], dtype=float32)}, {'energy': array([0.], dtype=float32)}, {'energy': array([0.], dtype=float32)}, {'energy': array([0.], dtype=float32)}, {'energy': array([0.], dtype=float32)}, {'energy': array([0.], dtype=float32)}, {'energy': array([0.], dtype=float32)}, {'energy': array([0.], dtype=float32)}, {'energy': array([0.], dtype=float32)}, {'energy': array([0.], dtyp

In [8]:
def hook(self, inp_tensor, out_tensor):
    """Forward hook that stores a module's output in the global ``layer``.

    Meant for ``module.register_forward_hook(hook)``: PyTorch calls it as
    ``hook(module, input, output)``, so ``self`` is the hooked module here,
    not a model class instance. ``inp_tensor`` is ignored.

    Returns None, so the module's output is left unchanged.
    """
    # Assign directly into this module's global namespace; equivalent to a
    # ``global layer`` declaration followed by ``layer = out_tensor``.
    globals()['layer'] = out_tensor

In [10]:
def _capture_layer_outputs(model, layers, inputs):
    """Run a single forward pass and return the outputs of the given submodules.

    A temporary forward hook is attached to every module in ``layers``; the
    hooks are removed again afterwards (also if the forward pass raises), so
    repeated calls never accumulate hooks.

    Args:
        model: the full model to run as ``model(inputs)``.
        layers: submodules of ``model`` whose outputs should be captured.
        inputs: the batch passed to ``model``.

    Returns:
        list with one detached numpy array per entry of ``layers``, in the same
        order (``None`` for a module that did not run).
    """
    captured = [None] * len(layers)
    handles = []
    for position, submodule in enumerate(layers):
        # ``position=position`` binds the current index (avoids late binding).
        def _store(module, inp, out, position=position):
            captured[position] = out.detach().cpu().numpy()
        handles.append(submodule.register_forward_hook(_store))
    try:
        with torch.no_grad():
            model(inputs)
    finally:
        for handle in handles:
            handle.remove()
    return captured


# --- Everything below up to the loop is loop-invariant: build it once. ---
device = 'cpu'

# Load split file (train/val/test are not used further in this cell).
train, val, test = spk.data.train_test_split(qm9data, split_file=split_file)

# Atom reference energies for U0, handed to the output module.
atomrefs = qm9data.get_atomref(QM9.U0)

# SchNet representation model; hyperparameters must match the checkpoint.
schnet = spk.representation.SchNet(
    n_atom_basis=n_atom_basis, n_filters=n_filters, n_gaussians=n_gaussians,
    n_interactions=n_interactions, cutoff=cutoff,
    cutoff_network=spk.nn.cutoff.CosineCutoff,
)
# Atomwise consumes the per-atom representation, whose width is n_atom_basis
# (n_filters only happens to have the same value in this configuration).
output_U0 = spk.atomistic.Atomwise(n_in=n_atom_basis, atomref=atomrefs[QM9.U0])
model = spk.AtomisticModel(representation=schnet, output_modules=output_U0)

# map_location allows a checkpoint saved on a GPU to be loaded on a CPU-only host.
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model.to(device)
model.eval()

converter = spk.data.AtomsConverter(device=device)

# Width of one exported row: n_atom_basis representation features + yi.
row_width = n_atom_basis + 1
selected_rows = []

for idx in range(index):
    at, props = dataset.get_properties(idx)
    inputs = converter(at)
    number_atoms = len(props['_atomic_numbers'])

    print(at)

    # One forward pass captures the embedding and every interaction output.
    emb, *interaction_outputs = _capture_layer_outputs(
        model,
        [model.representation.embedding, *model.representation.interactions],
        inputs,
    )
    # Sum of the embedding and all interaction-block outputs, added in order.
    rep = emb
    for interaction_output in interaction_outputs:
        rep = rep + interaction_output

    # NOTE(review): `yi` is presumably provided by a locally patched
    # schnetpack output_modules that keeps the per-atom outputs of the most
    # recent forward pass; it must therefore be imported *after* the pass.
    from schnetpack.atomistic.output_modules import yi
    yi = yi.detach().cpu().numpy()

    rows = np.zeros((number_atoms, row_width))
    rows[:, :n_atom_basis] = rep[0, :number_atoms, :n_atom_basis]
    # assumes yi[0] holds one value per atom (e.g. shape (n_atoms, 1)) — TODO confirm
    rows[:, n_atom_basis] = np.reshape(yi[0], -1)[:number_atoms]

    # Keep the atoms whose chemical symbol equals `element` ('H' selects the
    # same atoms as an atomic-number == 1 test). Other variants (e.g. C atoms
    # only in molecules without N/O) need an extra molecule-level condition here.
    is_target = np.array(
        [symbol == element for symbol in at.get_chemical_symbols()], dtype=bool
    )
    selected_rows.append(rows[is_target])

# Stack once at the end (growing an array with vstack per atom is quadratic).
vecs = np.vstack(selected_rows) if selected_rows else np.empty((0, row_width))
savetxt(save_path + name_data + '.csv', vecs, delimiter=',')

Atoms(symbols='H2', pbc=False)
Atoms(symbols='OH2', pbc=False)
Atoms(symbols='NOH3', pbc=False)
Atoms(symbols='O2H2', pbc=False)
Atoms(symbols='COH4', pbc=False)
Atoms(symbols='NH3', pbc=False)
Atoms(symbols='N2H4', pbc=False)
Atoms(symbols='CNH5', pbc=False)
Atoms(symbols='NOHOH2', pbc=False)
Atoms(symbols='N2H2OH2', pbc=False)
Atoms(symbols='CNOH5', pbc=False)
Atoms(symbols='N2H2NH3', pbc=False)
Atoms(symbols='CN2H6', pbc=False)
Atoms(symbols='NHC2H6', pbc=False)
Atoms(symbols='HNO', pbc=False)
Atoms(symbols='CNH2', pbc=False)
Atoms(symbols='CH4', pbc=False)
Atoms(symbols='C2H6', pbc=False)
Atoms(symbols='CO2H4', pbc=False)
Atoms(symbols='NCOH5', pbc=False)
Atoms(symbols='C2OH6', pbc=False)
Atoms(symbols='CN2H6', pbc=False)
Atoms(symbols='NC2H7', pbc=False)
Atoms(symbols='C3H8', pbc=False)
Atoms(symbols='COHOHOH2', pbc=False)
Atoms(symbols='CNH2OHOH2', pbc=False)
Atoms(symbols='COHOHCH4', pbc=False)
Atoms(symbols='CNH2NH2OH2', pbc=False)
Atoms(symbols='C2H3NH2OH2', pbc=False)
Atoms(s