This notebook extracts embeddings from a trained SchNet model for all layers: the interaction outputs (the residual updates) and the embedding signals built up from them.

Requirements:
obabel 3.1.1
schnetpack 0.3


NOTE that our model was trained with schnetpack 0.3, so it must be loaded with that same version of schnetpack; otherwise loading might fail. In schnetpack 0.3, all tensors have to be converted to NumPy FIRST, using detach().numpy(), whereas in newer versions this is not necessary.

If you have your own model that was not trained with schnetpack 0.3, it can be used here too; you only need to delete the detach().numpy() steps, as they are no longer required for tensors in later versions of schnetpack.


1) First, load the trained model

In [None]:
import torch
import schnetpack as spk
from schnetpack.datasets import QM9
from schnetpack import AtomsData
import numpy as np

#Set the same hyperparameters that the schnet model was trained on
#Our model's parameters are set below
n_atom_basis = 128
n_filters = 128
n_gaussians = 50
n_interactions = 6
cutoff = 50.

#set up device used both for loading the checkpoint and for the forward pass
device = 'cpu'

#Load qm9 data
qm9_filepath = 'data/datasets/QM9/qm9.db'
qm9_data = QM9(qm9_filepath, download=False, remove_uncharacterized=True)

# Load atom reference data for U0 (passed to the output module as atomref)
atomrefs = qm9_data.get_atomref(QM9.U0)

# Define SchNet representation model
schnet = spk.representation.SchNet(
    n_atom_basis=n_atom_basis,
    n_filters=n_filters,
    n_gaussians=n_gaussians,
    n_interactions=n_interactions,
    cutoff=cutoff,
    cutoff_network=spk.nn.cutoff.CosineCutoff,
)

# Define SchNet output model and property to be predicted.
# The Atomwise head reads the per-atom representation produced by SchNet, whose
# width is n_atom_basis (not n_filters); the two only coincide for this model.
output_U0 = spk.atomistic.Atomwise(n_in=n_atom_basis, atomref=atomrefs[QM9.U0])

# Define atomistic model
model = spk.AtomisticModel(representation=schnet, output_modules=output_U0)

# Load saved checkpoint file (a state dict) onto the chosen device
# (an earlier checkpoint name was qm9_i6_30f_20g-1000-500-4_300.pth)
checkpoint_path = 'data/trainedmodels/model1/trainingcheckpoints/trained-1000.pth'
load_checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))

# load model's state dictionary from saved checkpoint
model.load_state_dict(load_checkpoint)

# The model is only used for inference from here on, so put it in eval mode
model.eval()

# load atoms converter
converter = spk.data.AtomsConverter(device=device)

#This will show you all the available layers
print(model.state_dict().keys())

Define hook function for model layer extraction

In [None]:
#Forward hook used to grab the output of any layer in a schnet model.
#Attach it with `module.register_forward_hook(hook)` on a layer of the loaded model;
#PyTorch then calls hook(module, inputs, output) after that layer's forward pass.
def hook(self, inp_tensor, out_tensor):
    """Store the hooked layer's output in the module-level variable ``layer``.

    ``self`` is the hooked module and ``inp_tensor`` its inputs; neither is used.
    Nothing is returned, so PyTorch leaves the layer's output unchanged.
    """
    globals()['layer'] = out_tensor


Extract embeddings from each layer in schnet 

In [None]:
import pandas as pd
from label.manuallabel2 import labeller
from label.manuallabel2.utils import utils

#Range of molecule indices to extract as [first, last); the last index is excluded.
#NOTE(review): the CSV paths used when saving say 'qm9_first1000', but this range
#covers only 100 molecules; confirm which one is intended.
molecule_range = [0, 100]

# Column headings of the output tables. There is one 'embN' column per embedding
# feature (n_atom_basis of them). Then come the per-atom details:
#   - the molecule it belongs to (molecule_index)
#   - its atomic number (element)
#   - its xyz coordinates
#   - its functional-group label, gnuplot marker and decimal color
# The last column is the layer the row was taken from.
column_headings = [f'emb{feature_index}' for feature_index in range(n_atom_basis)] + [
    'molecule_index',
    'element',
    'x_coord',
    'y_coord',
    'z_coord',
    'fg_label',
    'fg_gnuplotmarker',
    'fg_decimalcolor',
    'layer',
]


In [None]:
#Extract, for every molecule in molecule_range and every atom in it:
#  - embs rows: layer 0 is the output of the embedding layer, and layer k is the
#    embedding after k interaction blocks. Every interaction block in
#    model.representation.interactions is included, the last one too.
#  - ints rows: layer k is the output of interaction block k, i.e. the update
#    added to the embedding.
#Each row is built as a dict, and the DataFrames are created once at the end.
#This is far cheaper than concatenating one single-row DataFrame per atom and layer.
atom_emb_rows = []
atom_int_rows = []

# Final column order: the headings defined above plus the two string columns from
# the labeller. Later cells index columns by position, so this order is kept.
output_columns = column_headings + ['fg_key', 'fg_hexcolor']

# Modules whose outputs are captured: the initial embedding, then each interaction block
hooked_modules = [model.representation.embedding, *model.representation.interactions]

# Inference only
model.eval()


def _atom_row(features, atom_info, layer_index, fg_key, hex_color):
    """Build one output row (column name -> value) for one atom at one layer."""
    values = [*features[:n_atom_basis].tolist(), *atom_info, layer_index]
    row = dict(zip(column_headings, values))
    row['fg_key'] = fg_key
    row['fg_hexcolor'] = hex_color
    return row


def _run_and_capture(inputs):
    """Run one forward pass and return the NumPy output of every hooked module.

    Outputs are returned in the order of ``hooked_modules``. The hooks exist only
    for this pass and are always removed afterwards. If they stayed attached,
    they would fire again on every later forward pass and overwrite the
    captured outputs with the wrong layer.
    """
    captured = {}

    def make_hook(position):
        def capture(module, inp_tensor, out_tensor):
            captured[position] = out_tensor.detach().numpy()
        return capture

    handles = [module.register_forward_hook(make_hook(position))
               for position, module in enumerate(hooked_modules)]
    try:
        with torch.no_grad():
            model(inputs)
    finally:
        for handle in handles:
            handle.remove()
    return [captured[position] for position in range(len(hooked_modules))]


#Run through each molecule defined in the range
for molecule_index in range(molecule_range[0], molecule_range[1]):

    #simple load bar, shows progress every 100 molecules passed
    if molecule_index % 100 == 0:
        print('molecule_index', molecule_index)

    #Load molecule's properties
    at, props = qm9_data.get_properties(molecule_index)
    xyz_positions = props['_positions'].detach().numpy()
    atomic_numbers = props['_atomic_numbers'].detach().numpy()
    number_atoms = len(atomic_numbers)

    #Convert to schnet-ready input
    inputs = converter(at)

    #write xyz and mol file in a temp file (important for labelling the functional groups)
    mol_filename = utils.xyz2mol(props)

    #A single forward pass gives the embedding output and every interaction
    #block's output. The embedding after block k is the running sum
    #emb_k = emb_(k-1) + int_k.
    layer_outputs = _run_and_capture(inputs)
    ints_all_layers = layer_outputs[1:]
    embs_all_layers = [layer_outputs[0]]
    for interaction_out in ints_all_layers:
        embs_all_layers.append(embs_all_layers[-1] + interaction_out)

    #FOR EACH ATOM: label its functional group once. Then write one row per
    #embedding layer and one row per interaction block, all with that atom's labels.
    for atom_index in range(number_atoms):
        #run labeller code to output LDA label, gnuplot color label, gnuplot marker label...
        fg_key, fg_label, decimal_color, gnu_marker, hex_color = labeller.label(
            mol_filename, number_atoms, atom_index, atomic_numbers[atom_index], molecule_index)

        #Per-atom information appended after the embedding features, in the order
        #of column_headings (molecule_index ... fg_decimalcolor)
        atom_info = [molecule_index, atomic_numbers[atom_index],
                     *xyz_positions[atom_index][:3],
                     fg_label, gnu_marker, decimal_color]

        #Index [0] selects the first (only) molecule of the converted batch
        for layer_index, layer_embs in enumerate(embs_all_layers):
            atom_emb_rows.append(
                _atom_row(layer_embs[0][atom_index], atom_info, layer_index, fg_key, hex_color))

        #Interaction block k is written with layer index k
        for layer_index, layer_ints in enumerate(ints_all_layers):
            atom_int_rows.append(
                _atom_row(layer_ints[0][atom_index], atom_info, layer_index, fg_key, hex_color))

embs_df = pd.DataFrame(atom_emb_rows, columns=output_columns)
ints_df = pd.DataFrame(atom_int_rows, columns=output_columns)




Save embeddings/interactions of every layer (and every atom, labelled) to a file 

In [None]:
import os

# Both tables go to the same folder. Create it first, because to_csv fails on a
# missing directory.
output_dir = 'data/embs/model1/qm9_first1000'
os.makedirs(output_dir, exist_ok=True)
embs_df.to_csv(os.path.join(output_dir, 'embs_all_layers&atoms.csv'), index=False)
ints_df.to_csv(os.path.join(output_dir, 'ints_all_layers&atoms.csv'), index=False)

Isolate the embeddings of one atom type (element) at one layer

In [None]:
import os

#choose the layer and the element (atomic number) to analyze with dimension reduction and lda.
#8 = oxygen, which matches the 'layer5_elementO' folder this cell writes to and the
#later cells read from. With 1 (hydrogen), the file would hold hydrogen rows under an oxygen name.
layer = 5
element = 8

# Keep only the rows of the chosen layer and element
filtered_df = embs_df[(embs_df['layer'] == layer) & (embs_df['element'] == element)]

# Save the filtered rows to a new CSV file (creating its folder if needed)
filtered_path = 'data/embs/model1/qm9_first1000/layer5_elementO/embs.csv'
os.makedirs(os.path.dirname(filtered_path), exist_ok=True)
filtered_df.to_csv(filtered_path, index=False)

Dimensionality reduction using PCA and t-SNE


Run PCA on embedding/interaction data

In [None]:
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import  PCA
import scipy.linalg as la

filtered_df = pd.read_csv('data/embs/model1/qm9_first1000/layer5_elementO/embs.csv')

#PCA parameters
n_features = 128
n_components = 128
scale_data = False

# The embedding features are the first n_features columns (emb0, emb1, ...)
X = filtered_df.iloc[:, 0:n_features]
print(X)

if scale_data:
    # StandardScaler.transform takes no random_state argument. It also returns a
    # NumPy array, so wrap the result back into a DataFrame with the same columns.
    scaler = StandardScaler()
    X = pd.DataFrame(scaler.fit_transform(X), columns=X.columns, index=X.index)

# PCA cannot produce more components than min(n_samples, n_features)
n_components = min(n_components, *X.shape)

#perform PCA decomposition of the data and project the data onto the components
pca = PCA(random_state=100, n_components=n_components)
pca.fit(X)
x_pca = pca.transform(X)

#get the eigenvalues and eigenvectors of the covariance matrix for analysis.
#The covariance matrix is symmetric, so eigh applies: it returns real eigenvalues,
#whereas eig returns complex ones in no particular order. They are sorted in
#descending order to line up with the PCA components.
cov = pca.get_covariance()
eig, ev = la.eigh(cov)
order = np.argsort(eig)[::-1]
eig, ev = eig[order], ev[:, order]


Plot the PCA dimension reduction with hexadecimal colors

In [None]:
#plot PCA dimension reduction 
import matplotlib.pyplot as plt
import numpy as np

#Colour every point with its functional group's hexadecimal colour. The column is
#selected by name ('fg_hexcolor') rather than by position 138, which is only
#correct while there are exactly 128 embedding columns.
point_colors = filtered_df['fg_hexcolor']
print(point_colors)

# Create a scatter plot of the first two principal components
plt.scatter(x_pca[:, 0], x_pca[:, 1], color=point_colors.tolist(), marker='o')

# Axis labels
plt.xlabel('PC 1')
plt.ylabel('PC 2')

# There is no plt.legend() call: the scatter has no labelled artists, so a legend
# would draw nothing and only print a "No artists with labels found" warning.

# Show the plot
plt.show()



In [None]:
#t-SNE
from sklearn.manifold import TSNE

# t-SNE perplexity. sklearn requires it to be smaller than the number of samples,
# so it is capped for small filtered sets.
perp = 50
X = filtered_df.iloc[:, 0:n_features]

# A fixed random_state makes the embedding reproducible between runs
X_tsne = TSNE(n_components=2, perplexity=min(perp, len(X) - 1),
              random_state=100).fit_transform(X)



In [None]:
#plot t-SNE dimension reduction 
import matplotlib.pyplot as plt
import numpy as np

#Colour every point with its functional group's hexadecimal colour. The column is
#selected by name ('fg_hexcolor') rather than by position 138, which is only
#correct while there are exactly 128 embedding columns.
point_colors = filtered_df['fg_hexcolor']
print(point_colors)

# Create a scatter plot of the two t-SNE dimensions
plt.scatter(X_tsne[:, 0], X_tsne[:, 1], color=point_colors.tolist(), marker='o')

# Axis labels (these are t-SNE coordinates, not principal components)
plt.xlabel('t-SNE 1')
plt.ylabel('t-SNE 2')

# There is no plt.legend() call: the scatter has no labelled artists, so a legend
# would draw nothing and only print a "No artists with labels found" warning.

# Show the plot
plt.show()

Linear Discriminant Analysis

In [None]:
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import train_test_split

import pandas as pd
import numpy as np

filtered_df = pd.read_csv('data/embs/model1/qm9_first1000/layer5_elementO/embs.csv')
n_features = 128

#X holds the embedding features (the first n_features columns). y is the
#functional-group label, selected by name rather than by its position (column 133).
X = filtered_df.iloc[:, 0:n_features]
y = filtered_df['fg_label'].values

#First clean up the data. A class with only one data point cannot be split in a
#stratified way, so its rows are removed. np.unique, unlike np.bincount, also
#handles negative or non-integer label values.
labels, label_counts = np.unique(y, return_counts=True)
keep_mask = ~np.isin(y, labels[label_counts == 1])
X = X[keep_mask]
y = y[keep_mask]

#Renumber the remaining classes consecutively as 0 ... N-1, in sorted label order
unique_classes, y = np.unique(y, return_inverse=True)
class_mapping = {old_class: new_class for new_class, old_class in enumerate(unique_classes)}

#split data, while ensuring each split gets all classes (Stratify)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, stratify=y, random_state=30)

LDA = LinearDiscriminantAnalysis(store_covariance=True)
# fit() returns the fitted estimator itself, so X_fit is the same object as LDA
X_fit = LDA.fit(X_train, y_train)

y_pred = LDA.predict(X_test)

print('LDA accuracy: ', LDA.score(X_test, y_test))


In [None]:
# Predicted labels first, then the true labels of the held-out test split
print(y_pred, y_test, sep=' ')

In [None]:
from sklearn.metrics import confusion_matrix

# Confusion matrix of the test split: entry [i, j] counts samples whose true class
# is i and whose predicted class is j
confusion = confusion_matrix(y_true=y_test, y_pred=y_pred)
print(confusion)
