# Analyse the autoencoder's latent space

## Initialize the notebook 
Import packages

In [None]:
from openmm.app import *
from openmm import *
from openmm.unit import *
import openmm
import numpy as np
import tensorflow as tf
from tensorflow import keras
from keras import layers
from keras import saving
import scipy.constants as sc
from matplotlib import pyplot as plt
import os
import pickle
from IPython.display import clear_output
from scipy.interpolate import RegularGridInterpolator
from sklearn.model_selection import train_test_split
import MDAnalysis as mda
from MDAnalysis.analysis import rms
import jupyter_beeper
beep = jupyter_beeper.Beeper()
import copy

Define system name and system-specific functions

In [None]:
# Name of the protein system; selects the scaling constants in
# Standardize/Unstandardize and the reference PDB file that is loaded.
system_name = "trpcage"
# Identifier of the trained autoencoder; used to build the encoder/decoder
# model paths and the names of the output trajectory files.
model_name = "n_5e-2_4"

In [None]:
def Standardize(x):
    """Rescale raw coordinates to the scale that is fed to the encoder.

    The transform is affine (add a shift, then divide by a scale) and its
    constants are selected by the module-level ``system_name``.

    Args:
        x: Array-like of raw values; it is converted to a NumPy array.

    Returns:
        A NumPy array of the same shape as ``x`` with the rescaled values.

    Raises:
        ValueError: If ``system_name`` is not a supported system.
    """
    x = np.asarray(x)
    if system_name == "trpcage":
        return (x + 1.0) / 6
    if system_name == "villin":
        return (x + 2.0) / 9
    if system_name == "pdz":
        return (x - 1.5) / 9
    # Fail loudly instead of hitting an UnboundLocalError on a missing result.
    raise ValueError(
        f"Unsupported system_name {system_name!r}; "
        "expected 'trpcage', 'villin' or 'pdz'"
    )

def Unstandardize(x):
    """Map standardized values back to raw coordinates.

    This is the inverse of ``Standardize``: multiply by the scale, then undo
    the shift. The constants are selected by the module-level
    ``system_name``.

    Args:
        x: Array-like of standardized values; it is converted to a NumPy
            array.

    Returns:
        A NumPy array of the same shape as ``x`` with the raw-scale values.

    Raises:
        ValueError: If ``system_name`` is not a supported system.
    """
    x = np.asarray(x)
    if system_name == "trpcage":
        # Standardize uses (x + 1.0) / 6, so the shift has to be subtracted
        # here; adding it would offset every coordinate by 2.0.
        return (x * 6) - 1.0
    if system_name == "villin":
        return (x * 9) - 2.0
    if system_name == "pdz":
        return (x * 9) + 1.5
    # Fail loudly instead of hitting an UnboundLocalError on a missing result.
    raise ValueError(
        f"Unsupported system_name {system_name!r}; "
        "expected 'trpcage', 'villin' or 'pdz'"
    )

@saving.register_keras_serializable()
class Sampling(keras.layers.Layer):
    """Layer that adds zero-mean Gaussian noise to its input.

    The output is ``inputs + epsilon * noise`` where ``epsilon`` is drawn
    from a standard normal distribution with the same (batch, dim) shape as
    the 2-D input. The class is registered as a Keras serializable so that
    saved models containing it can be loaded.

    Args:
        noise: Multiplier applied to the standard-normal sample.
        **kwargs: Standard ``keras.layers.Layer`` arguments (e.g. ``name``,
            ``dtype``, ``trainable``), forwarded to the base class.
    """

    def __init__(self, noise=0.05, **kwargs):
        # Forward the base-layer arguments so a config produced by
        # get_config() can be passed straight back to the constructor.
        super().__init__(**kwargs)
        self.noise = noise

    def call(self, inputs):
        """Return ``inputs`` perturbed by scaled standard-normal noise."""
        z_mean = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        epsilon = keras.random.normal(shape=(batch, dim))
        return z_mean + epsilon * self.noise

    def get_config(self):
        """Return the base layer config extended with ``noise``."""
        config = super().get_config()
        config.update({"noise": self.noise})
        return config

## Load the trained models and training datasets

In [None]:
# Trained encoder; the file name is derived from `model_name`.
encoder = keras.models.load_model(f"../ae/models/e_{model_name}.keras")

In [None]:
# Trained decoder; the file name is derived from `model_name`.
decoder = keras.models.load_model(f"../ae/models/d_{model_name}.keras")

In [None]:
# Coordinate dataset of the selected system. The path follows `system_name`
# (as the reference PDB path does) so the data always matches the scaling
# constants chosen in Standardize/Unstandardize.
coord = np.load(f"../data/{system_name}_ds.npy")

In [None]:
# RMSD table of the selected system; column 1 is used later to colour the
# latent-space plots. The path follows `system_name` so it stays in sync
# with the coordinate dataset.
rmsd = np.loadtxt(f"../data/{system_name}_rmsd_ds.xvg")

In [None]:
# Keep every `data_density`-th frame and rescale it for the network.
data_density = 1
data = Standardize(coord[::data_density])

# The same array is passed as both inputs and targets of the split.
x_train, x_test, y_train, y_test = train_test_split(
    data,
    data,
    test_size=0.2,
)
x_train.shape

## Plot the data in the latent space

In [None]:
latent_space = 2
# Indices of the two latent dimensions drawn on the x and y axes of the
# plots below; those cells read `latent_values`, so it must be defined here.
latent_values = [0, 1]
# Only every `each`-th row of `data` is encoded.
each = 10
latents = encoder.predict(data[::each, :])

In [None]:
# Scatter the encoded frames in the two selected latent dimensions, coloured
# by column 1 of the RMSD table. `rmsd` is strided by each*data_density so
# that it lines up with the rows that were encoded into `latents`.
x_idx, y_idx = latent_values[0], latent_values[1]
frame_rmsd = rmsd[::each * data_density, 1]
plt.scatter(latents[:, x_idx], latents[:, y_idx], c=frame_rmsd, cmap="jet", s=2)

cbar = plt.colorbar()
cbar.set_label("RMSD (nm)")

plt.title("Latent space")
plt.xlabel("latent value 1")
plt.ylabel("latent value 2")

axes = plt.gca()
axes.set(xlim=[-1.2, 1.2], ylim=[-1.2, 1.2])
plt.savefig("ls.png", dpi=600)

Define points in the latent space to analyse

In [None]:
# 3x3 grid of latent-space points, listed row by row: y runs 1, 0, -1 and
# within each row x runs -1, 0, 1.
latent_spaces_4a = np.array(
    [[x, y] for y in (1.0, 0.0, -1.0) for x in (-1.0, 0.0, 1.0)]
)

... and plot them

In [None]:
plt.figure()
plt.scatter(
    latents[:, latent_values[0]],
    latents[:, latent_values[1]],
    c=rmsd[::each * data_density, 1],
    cmap="jet",
    s=1,
)
# Create the colorbar right after the RMSD-coloured scatter and before the
# black marker scatter is drawn, so it is built from the coloured points.
# It takes its colormap from that scatter, so no cmap argument is needed.
cbar = plt.colorbar()
cbar.set_label("RMSD (nm)")

plt.scatter(latent_spaces_4a[:, 0], latent_spaces_4a[:, 1], color="k", s=10)
plt.title("Latent space")
plt.xlabel("latent value 1")
plt.ylabel("latent value 2")
axes = plt.gca()
axes.set(xlim=[-1.2, 1.2], ylim=[-1.2, 1.2])

# Label the points "A", "B", ... in the order they appear in
# `latent_spaces_4a`, placing each letter 0.15 below its point. Deriving the
# positions from the array keeps labels and markers in sync.
for i, (point_x, point_y) in enumerate(latent_spaces_4a):
    plt.text(
        point_x,
        point_y - 0.15,
        chr(ord("A") + i),
        horizontalalignment='center',
        verticalalignment='center',
        c="k",
        weight="regular",
        size="x-large",
    )
plt.savefig("ls_dots.png", dpi=600)

Decode protein structures from the latent space and save them into a PDB file

In [None]:
output_trj = f'output_ls_{model_name}.pdb'

pdb = PDBFile(f'../data/{system_name}_reference.pdb')

forcefield = ForceField('amber14-all.xml', 'implicit/gbn2.xml')


def _reference_simulation(trajectory_file, platform=None, minimize=False):
    """Build a simulation of the reference structure and report its first frame.

    Reads the module-level ``forcefield`` and ``pdb``. A new system and a new
    Langevin integrator are created for every call, a PDB reporter writing to
    ``trajectory_file`` is attached, the reference positions are set, the
    potential energy is printed and the state is written through the reporter.

    Args:
        trajectory_file: Path of the PDB file the reporter writes to.
        platform: Optional OpenMM platform passed to ``Simulation``.
        minimize: When true, run the energy minimizer (200 iterations,
            tolerance 4.0e2) before the state is read.

    Returns:
        Tuple ``(system, integrator, simulation, state)``.
    """
    sim_system = forcefield.createSystem(pdb.topology)
    sim_integrator = LangevinMiddleIntegrator(300*kelvin, 1/picosecond, 0.002*picoseconds)
    sim = Simulation(pdb.topology, sim_system, sim_integrator, platform)
    sim.reporters.append(PDBReporter(trajectory_file, 500))

    sim.context.setPositions(pdb.positions)
    if minimize:
        sim.minimizeEnergy(maxIterations=200, tolerance=4.0e2)
    sim_state = sim.context.getState(getEnergy=True, getPositions=True)
    print(sim_state.getPotentialEnergy())
    # The reporter is invoked by hand; no integration steps are taken here.
    sim.reporters[0].report(sim, sim_state)
    return sim_system, sim_integrator, sim, sim_state


# initial structure with topology from reference
system, integrator, simulation, state = _reference_simulation(output_trj)

# encoded
e_system, e_integrator, e_simulation, e_state = _reference_simulation("e_" + output_trj)

# decoded
d_system, d_integrator, d_simulation, d_state = _reference_simulation("d_" + output_trj)

# decoded with minimization
d_min_system, d_min_integrator, d_min_simulation, d_min_state = _reference_simulation(
    "d_min_" + output_trj,
    platform=Platform.getPlatformByName('CPU'),
    minimize=True,
)

Find the structures from the training dataset which are closest to the selected points

In [None]:
# For every selected point, the row index in `latents` of the encoded frame
# with the smallest squared Euclidean distance in latent dimensions 0 and 1.
# np.argmin returns the first minimum directly, so no sort is needed and the
# indices are integers from the start.
closest_indexes = np.array(
    [
        np.argmin(
            (latents[:, 0] - point[0]) ** 2 + (latents[:, 1] - point[1]) ** 2
        )
        for point in latent_spaces_4a
    ],
    dtype=int,
)
closest_indexes

Calculate potential energies of the closest structures from training dataset

In [None]:
# `closest_indexes` are row indices into `latents`, and `latents` was
# computed from every `each`-th row of `data`. Row k of `latents` therefore
# corresponds to row k * each of `data`; indexing `data` with k directly
# would pick a different frame.
n_atoms = data.shape[1] // 3
for latent_row in closest_indexes:
    frame = Unstandardize(data[latent_row * each, :])
    e_simulation.context.setPositions(frame.reshape((n_atoms, 3)))
    e_state = e_simulation.context.getState(getEnergy=True, getPositions=True)
    print(e_state.getPotentialEnergy())
    e_simulation.reporters[0].report(e_simulation, e_state)

In [None]:
# One label per selected latent-space point, in the same order as the points.
letters = list("ABCDEFGHI")

Compare potential energies of decoded structures before and after energy minimization

In [None]:
# Fresh CPU simulation whose reporter collects the minimized conformations.
d_min_system = forcefield.createSystem(pdb.topology)
d_min_integrator = LangevinMiddleIntegrator(300*kelvin, 1/picosecond, 0.002*picoseconds)
cpu_platform = Platform.getPlatformByName('CPU')
d_min_simulation = Simulation(pdb.topology, d_min_system, d_min_integrator, cpu_platform)
d_min_simulation.reporters.append(PDBReporter("conformations_A-I_min.pdb", 500))

n_atoms = data.shape[1] // 3
for i, point in enumerate(latent_spaces_4a):
    # Decode the latent point and bring the result back to raw coordinates;
    # both simulations below receive the same decoded positions.
    pos = decoder(point.reshape((1, 2))).numpy().transpose()
    decoded_xyz = Unstandardize(pos.reshape((n_atoms, 3)))

    # Energy of the structure exactly as decoded.
    d_simulation.context.setPositions(decoded_xyz)
    d_state = d_simulation.context.getState(getEnergy=True, getPositions=True)
    print(f"Point: {letters[i]}: Potential energy after decoding: {d_state.getPotentialEnergy()}", end=", ")
    d_simulation.reporters[0].report(d_simulation, d_state)

    # Energy of the same structure after a short minimization.
    d_min_simulation.context.setPositions(decoded_xyz)
    d_min_simulation.minimizeEnergy(maxIterations=200, tolerance=1.0e2)
    d_min_state = d_min_simulation.context.getState(getEnergy=True, getPositions=True)
    print(f"Potential energy after minimization: {d_min_state.getPotentialEnergy()}")
    d_min_simulation.reporters[0].report(d_min_simulation, d_min_state)