# Train the autoencoder

Use this notebook to train an autoencoder that models the conformational flexibility of the protein of interest.

## Import necessary packages

In [None]:
# load packages
import os
import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt
from tensorflow import keras
from keras import layers
from keras import saving
from sklearn.model_selection import train_test_split
from IPython.display import clear_output

from openmm.app import *
from openmm import *
from openmm.unit import *
import openmm

## Initialize system properties

Set the name of the protein system and of the model to train.

In [None]:
# Protein system (selects the data files and the scaling used by Standardize)
# and the tag used in the names of the saved model files.
system_name, model_name = "trpcage", "model_1"

Define the scaling functions:
- `Standardize` takes the input data (Cartesian coordinates of the protein) and scales it to lie between 0.0 and 1.0.
- `Unstandardize` is the inverse of `Standardize`; it will be useful later for the MCMC simulation.

In [None]:
# Per-system affine scaling as (offset, scale): standardized = (x + offset) / scale.
# Both directions read the same table, so Unstandardize is always the exact inverse
# of Standardize (the previous trpcage inverse used +1.0 instead of -1.0).
_STANDARDIZATION = {
    "trpcage": (1.0, 6),
    "villin": (2.0, 9),
    "pdz": (-1.5, 9),
}


def _standardization_params(system=None):
    """Return the (offset, scale) pair for `system`, defaulting to the global `system_name`.

    Raises:
        ValueError: if no scaling is defined for the system.
    """
    if system is None:
        system = system_name
    try:
        return _STANDARDIZATION[system]
    except KeyError:
        raise ValueError(
            f"No standardization defined for system {system!r}; "
            f"expected one of {sorted(_STANDARDIZATION)}"
        ) from None


def Standardize(x, system=None):
    """Scale Cartesian coordinates into the network's standardized range.

    Args:
        x: coordinates (array or tensor); the operation is element-wise.
        system: system name; defaults to the notebook-level `system_name`.

    Returns:
        (x + offset) / scale for the selected system.
    """
    offset, scale = _standardization_params(system)
    return (x + offset) / scale


def Unstandardize(x, system=None):
    """Inverse of Standardize: map standardized values back to Cartesian coordinates.

    Args:
        x: standardized coordinates (array or tensor); the operation is element-wise.
        system: system name; defaults to the notebook-level `system_name`.

    Returns:
        x * scale - offset for the selected system.
    """
    offset, scale = _standardization_params(system)
    return (x * scale) - offset

Check if GPU is available for acceleration

In [None]:
# GPUs visible to TensorFlow; 0 means TensorFlow will run on the CPU.
gpu_devices = tf.config.list_physical_devices('GPU')
print("Num GPUs Available: ", len(gpu_devices))

## Load and process the training dataset

In [None]:
# Training trajectory. Presumably one row per frame with flattened xyz coordinates
# (potential_energy reshapes a row to (n_atoms, 3)); confirm against the data file.
coords = np.load(f"../data/{system_name}_ds.npy")

Standardize the loaded data
- `data_density` is the subsampling stride: keep it at 1 for real training and use a larger integer only for quick tests.

In [None]:
# Subsampling stride: 1 keeps every frame; larger values thin the data for quick tests.
data_density = 1
subsampled_coords = coords[::data_density]
data = Standardize(subsampled_coords)
# Only the standardized copy is needed from here on; release the raw coordinates.
del coords, subsampled_coords

Split the dataset into training and test sets.

In [None]:
# Autoencoder: the targets are the inputs themselves, so both come from `data`.
# A fixed random_state keeps the split reproducible between runs.
x_train, x_test, y_train, y_test = train_test_split(data, data, test_size=0.2, random_state=24)
# First frame of the (subsampled) data: the encoder subtracts it from its input
# and the decoder adds it back to its output.
ref_str = data[0, :]
# Last expression so the notebook actually displays the training-set shape
# (it was previously a no-op in the middle of the cell).
x_train.shape

Create dataset objects for TensorFlow

In [None]:
# Wrap the in-memory arrays as (input, target) pairs for tf.data.
xy_train, xy_test = (
    tf.data.Dataset.from_tensor_slices(pair)
    for pair in ((x_train, y_train), (x_test, y_test))
)

Specify the batch size and shuffle buffer, then shuffle and batch the datasets for training.

In [None]:
# Mini-batch size, shared by training and evaluation.
BATCH_SIZE = 100
# NOTE(review): a buffer this small only shuffles locally. train_test_split already
# shuffled the rows once (its default), so batch composition changes little between epochs.
SHUFFLE_BUFFER_SIZE = 100

# Not idempotent: re-running this cell would batch already-batched datasets.
xy_train = (
    xy_train
    .shuffle(SHUFFLE_BUFFER_SIZE, seed=24)
    .batch(BATCH_SIZE)
)
xy_test = xy_test.batch(BATCH_SIZE)

## Model architecture

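`PlotLearning` is a Keras callback that visualizes the training progress. Whenever the validation loss reaches a new minimum, it also saves the model to `models/model_cb.keras`, and the encoder and decoder to `models/encoder_cb.keras` and `models/decoder_cb.keras`. These files serve as a backup in case of a technical failure.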

In [None]:
class PlotLearning(tf.keras.callbacks.Callback):
    """
    Callback to plot the learning curves of the model during training.

    At the end of every epoch it does three things:
    - It appends every value in ``logs`` to ``self.metrics`` (one list per key).
    - It checkpoints the model, encoder and decoder to ``models/*_cb.keras`` whenever
      ``val_loss`` reaches a new minimum. This happens only once more than two
      epochs of the current ``fit`` call have been recorded.
    - It redraws one panel per training metric (with its ``val_`` counterpart when
      logged), plus a latent-space scatter of every 10th frame of the global
      ``data``, coloured by column 1 of ``../data/<system_name>_rmsd_ds.xvg``.

    The encoder and decoder are taken from the model being trained
    (``self.model.layers[-2]`` and ``[-1]``). This assumes the model was built as
    ``keras.Model(inputs, decoder(encoder(inputs)))``. Unlike the notebook-level
    ``encoder``/``decoder`` globals, these stay correct after the model has been
    reloaded from disk.
    """

    def on_train_begin(self, logs=None):
        # Fresh history for every fit() call; logs may be None here.
        self.metrics = {metric: [] for metric in (logs or {})}
        # Colour values for the latent-space scatter (column 1 is plotted).
        self.rmsd = np.loadtxt("../data/" + system_name + "_rmsd_ds.xvg")

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        keys = list(logs.keys())

        # Storing metrics
        for metric in keys:
            self.metrics.setdefault(metric, []).append(logs.get(metric))

        encoder = self.model.layers[-2]
        decoder = self.model.layers[-1]
        latents = encoder.predict(data[::10, :], verbose=0)

        # Checkpoint on a new best validation loss (not during the first epochs of a fit).
        val_losses = self.metrics.get('val_loss', [])
        if len(val_losses) > 2 and val_losses[-1] == np.min(val_losses):
            self.model.save("models/model_cb.keras")
            encoder.save("models/encoder_cb.keras")
            decoder.save("models/decoder_cb.keras")

        # Plotting: one panel per training metric, last panel for the latent space
        metrics = [x for x in keys if 'val' not in x]

        fig, axs = plt.subplots(1, len(metrics) + 1, figsize=(15, 5))
        clear_output(wait=True)

        # The first recorded epoch is left out of the curves.
        epochs_shown = range(1, epoch + 2)[1:]
        for ax, metric in zip(axs, metrics):
            ax.plot(epochs_shown, self.metrics[metric][1:], label=metric)
            val_metric = 'val_' + metric
            # Membership test, not truthiness: a logged 0.0 must still be plotted
            # and a missing validation metric must not raise KeyError.
            if val_metric in logs:
                ax.plot(epochs_shown, self.metrics[val_metric][1:], label=val_metric, color="orange")

            ax.legend()
            ax.grid()
            ax.ticklabel_format(useOffset=False, style='plain')

        # The rmsd file has one row per frame before subsampling, hence the extra data_density stride.
        axs[-1].scatter(latents[:, 0], latents[:, 1], c=self.rmsd[::10 * data_density, 1], cmap="jet", s=1)

        # NOTE(review): this resets Keras global state in the middle of fit().
        # Presumably it is here to release memory from predict(); confirm it is still needed.
        keras.backend.clear_session()

        plt.tight_layout()
        #plt.savefig("loss.png")
        plt.show()
        plt.close(fig)

This cell defines:
- custom loss functions
- the `Sampling` layer, which adds random noise to the encoder outputs
- a metric that estimates the potential energy of the decoded structures, to visualize the learning process

In [None]:
# OpenMM system used only to evaluate potential energies of (decoded) structures:
# Amber14 force field with the GBn2 implicit-solvent model.
pdb = PDBFile(f"../data/{system_name}_reference.pdb")
forcefield = ForceField('amber14-all.xml', 'implicit/gbn2.xml')
E_eval_system = forcefield.createSystem(pdb.topology)

# A Simulation needs an integrator. This notebook only uses the simulation through
# context.setPositions / context.getState and never steps it.
E_eval_integrator = LangevinMiddleIntegrator(300 * kelvin, 1 / picosecond, 0.002 * picoseconds)
E_eval_simulation = Simulation(pdb.topology, E_eval_system, E_eval_integrator)

# Shared MSE instance: the stage-1 loss and part of the distance-based losses.
mse = keras.losses.MeanSquaredError()

@saving.register_keras_serializable()
def potential_energy(pos):
    """Potential energy of one standardized structure, in kJ/mol.

    Args:
        pos: standardized coordinates of one structure, either flat
            (length 3 * n_atoms) or already shaped (n_atoms, 3).

    Returns:
        float: potential energy reported by ``E_eval_simulation`` in kJ/mol.
    """
    # (-1, 3) instead of the global data.shape: no dependency on the training array.
    xyz = np.reshape(pos, (-1, 3))
    xyz = Unstandardize(xyz)
    # OpenMM interprets unit-less positions in its default length unit (nm).
    E_eval_simulation.context.setPositions(xyz)
    state = E_eval_simulation.context.getState(getEnergy=True)
    return state.getPotentialEnergy().value_in_unit(kilojoule / mole)
    
@saving.register_keras_serializable()
def pe_editlog_metric(pos_true, pos_pred):
    """Signed-log potential energy of the first decoded structure in a batch.

    Only ``pos_pred[0]`` is evaluated, so there is one OpenMM call per batch. The
    structure is unstandardized and passed to ``E_eval_simulation``. The energy
    E (kJ/mol) is compressed as ``sign(E) * log10(|E| + 1)``.

    Args:
        pos_true: ground-truth batch (unused; required by the Keras metric signature).
        pos_pred: predicted batch of flat standardized coordinates, one structure per row.

    Returns:
        float: the signed log energy, or 20.0 when the energy cannot be evaluated.
        This presumably happens for symbolic tensors when the model is not run
        eagerly, or for positions that OpenMM rejects.
    """
    pos = keras.ops.reshape(pos_pred[0, :], (-1, 3))
    try:
        # Concrete float64 (n_atoms, 3) array, the same form potential_energy() passes
        # to OpenMM. convert_to_numpy raises for non-concrete tensors (handled below).
        xyz = np.asarray(keras.ops.convert_to_numpy(Unstandardize(pos)), dtype=np.float64)
        E_eval_simulation.context.setPositions(xyz)
        E_eval_state = E_eval_simulation.context.getState(getEnergy=True)
        E = E_eval_state.getPotentialEnergy().value_in_unit(kilojoule / mole)
    except Exception:
        # Deliberate best-effort sentinel so the metric never aborts training.
        # `Exception` (not a bare except) lets KeyboardInterrupt/SystemExit through.
        return 20.0
    # sign(0) == 0, so E == 0 still maps to 0.0.
    return float(np.sign(E) * np.log10(abs(E) + 1.0))
        
def _pairwise_distances(y, eps=1e-12):
    """All inter-atomic distances for a batch of flat coordinate vectors.

    Args:
        y: tensor of shape (batch, 3 * n_atoms).
        eps: added to the squared distances before the square root. The i == j
            diagonal is exactly zero, where sqrt's derivative is infinite; on some
            backends that gives 0 * inf = NaN gradients.

    Returns:
        tensor of shape (batch, n_atoms, n_atoms).
    """
    n_atoms = keras.ops.cast(keras.ops.shape(y)[1] / 3, "int32")
    xyz = keras.ops.reshape(y, (-1, n_atoms, 3))
    # (batch, 1, n, 3) - (batch, n, 1, 3) broadcasts to every atom pair: (batch, n, n, 3)
    row = keras.ops.expand_dims(xyz, 1)
    col = keras.ops.transpose(row, axes=[0, 2, 1, 3])
    squared = keras.ops.sum(keras.ops.square(row - col), axis=-1)
    return keras.ops.sqrt(squared + eps)


def _msle_term(y_true, y_pred):
    """Element-wise squared difference of log(1 + y) between targets and predictions."""
    return keras.ops.square(keras.ops.log(y_true + 1.) - keras.ops.log(y_pred + 1.))


@saving.register_keras_serializable()
def msle_and_dist(y_true, y_pred):
    """Pairwise-distance MSE plus a small element-wise MSLE term (weight 1e-2).

    Args:
        y_true, y_pred: (batch, 3 * n_atoms) flat standardized coordinates.

    Returns:
        the scalar distance MSE added (by broadcasting) to the (batch, 3 * n_atoms)
        MSLE term; Keras reduces the result.
    """
    dist_true = _pairwise_distances(y_true)
    dist_pred = _pairwise_distances(y_pred)
    dist_loss = mse(dist_pred, dist_true)
    return dist_loss + _msle_term(y_true, y_pred) * 1e-2


@saving.register_keras_serializable()
def msle_and_inv_dist(y_true, y_pred):
    """Inverse-pairwise-distance MSE plus a small element-wise MSLE term (weight 1e-2).

    Inverse distances (offset by 1e-3 to stay finite) weight short contacts more
    than long ones.

    Args:
        y_true, y_pred: (batch, 3 * n_atoms) flat standardized coordinates.

    Returns:
        the scalar inverse-distance MSE added (by broadcasting) to the
        (batch, 3 * n_atoms) MSLE term; Keras reduces the result.
    """
    dist_true_inv = keras.ops.reciprocal(_pairwise_distances(y_true) + 1e-3)
    dist_pred_inv = keras.ops.reciprocal(_pairwise_distances(y_pred) + 1e-3)
    dist_inv_loss = mse(dist_pred_inv, dist_true_inv)
    return dist_inv_loss + _msle_term(y_true, y_pred) * 1e-2

@saving.register_keras_serializable()
class Sampling(keras.layers.Layer):
    """Adds Gaussian noise with standard deviation ``noise`` to its input.

    NOTE(review): noise is added on every call, including predict()/inference,
    so encodings are stochastic. Confirm this is intended.
    """

    def __init__(self, noise=0.05, **kwargs):
        # Forward base-layer arguments (name, trainable, dtype, ...) so the layer
        # round-trips through get_config()/from_config().
        super().__init__(**kwargs)
        self.noise = noise

    def call(self, inputs):
        z_mean = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        epsilon = keras.random.normal(shape=(batch, dim))
        return z_mean + epsilon * self.noise

    def get_config(self):
        # Extend, rather than replace, the base config so name/trainable/dtype are kept.
        # Configs saved by the old version ({"noise": ...} only) still load.
        config = super().get_config()
        config.update({"noise": self.noise})
        return config

The architecture of the autoencoder is defined here

In [None]:
# Number of latent dimensions (the latent-space plots use the first two)
latent_space = 2
# Activation shared by all hidden Dense layers
activation = "elu"

#encoder
# Input: one flat standardized structure per row
encoder_inputs = keras.Input(shape=(x_train.shape[1],))
# Encode the displacement from the reference structure ref_str, not absolute coordinates
encoder_diff_ref = keras.ops.subtract(encoder_inputs, ref_str)

#ehl1 = layers.Dense(1024*3, activation=activation)(encoder_diff_ref)
#ebn1 = layers.BatchNormalization()(ehl1)
#
#ehl2 = layers.Dense(1024, activation=activation)(ebn1)
#ebn2 = layers.BatchNormalization()(ehl2)

# Encoder funnel 512 -> 128 -> 32 -> 8 units; BatchNormalization after the two widest layers
ehl3 = layers.Dense(512, activation=activation)(encoder_diff_ref)
ebn3 = layers.BatchNormalization()(ehl3)

ehl4 = layers.Dense(128, activation=activation)(ebn3)
ebn4 = layers.BatchNormalization()(ehl4)

ehl5 = layers.Dense(32, activation=activation)(ebn4)

ehl6 = layers.Dense(8, activation=activation)(ehl5)

# Latent mean bounded to (-1, 1) by tanh; the layer is explicitly named "z_mean"
ls_mean = layers.Dense(latent_space, activation="tanh", name="z_mean")(ehl6)
# Sampling adds Gaussian noise (std 0.05) to the latent mean on every call
ls = Sampling(noise=0.05)(ls_mean)

encoder = keras.Model(encoder_inputs, ls)

#decoder
decoder_inputs = keras.Input(shape=(latent_space,))

# Decoder mirrors the encoder: 8 -> 32 -> 128 -> 512 units
dhl1 = layers.Dense(8, activation=activation)(decoder_inputs)

dhl2 = layers.Dense(32, activation=activation)(dhl1)

dhl3 = layers.Dense(128, activation=activation)(dhl2)
dbn3 = layers.BatchNormalization()(dhl3)

dhl4 = layers.Dense(512, activation=activation)(dbn3)
dbn4 = layers.BatchNormalization()(dhl4)

#dhl5 = layers.Dense(1024, activation=activation)(dbn4)
#dbn5 = layers.BatchNormalization()(dhl5)
#
#dhl6 = layers.Dense(1024*3, activation=activation)(dbn5)
#dbn6 = layers.BatchNormalization()(dhl6)

# Linear output of the input size is a displacement; adding ref_str back gives
# standardized absolute coordinates
decoder_outputs = layers.Dense(y_train.shape[1], activation="linear")(dbn4)
decoded_structure = keras.ops.add(decoder_outputs, ref_str)

decoder = keras.Model(decoder_inputs, decoded_structure)


# Full autoencoder; later cells access its halves as model.layers[-2] (encoder)
# and model.layers[-1] (decoder)
model = keras.Model(encoder_inputs,decoder(encoder(encoder_inputs)))

In [None]:
# Stage 1: plain MSE reconstruction loss. "mse" is also tracked as a metric,
# together with the signed-log potential energy of one decoded structure per batch.
metrics = ["mse", pe_editlog_metric]
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
model.compile(optimizer=optimizer, loss=mse, metrics=metrics)

Summary of the created models:

In [None]:
# Layer-by-layer summaries of the two halves of the autoencoder
for sub_model in (encoder, decoder):
    sub_model.summary()

In [None]:
# Summary of the full autoencoder (input -> encoder -> decoder)
model.summary()

In [None]:
# Callbacks passed to every fit() call below: live plots plus best-val_loss checkpoints
callbacks_list = [PlotLearning()]

## Training of the model

Training with the MSE loss function

In [None]:
# Stage 1: train encoder and decoder jointly on the MSE loss.
# The batch size comes from the batched tf.data dataset.
model.fit(
    xy_train,
    validation_data=xy_test,
    epochs=100,
    callbacks=callbacks_list,
)

In [None]:
# Checkpoint after stage 1 (MSE training)
model.save(f"./models/m_{model_name}_unfinished_1.keras")

In [None]:
# Reload the stage-1 checkpoint, then point the `encoder`/`decoder` globals at the
# reloaded sub-models. Otherwise they keep referring to the models built in memory
# above, and anything that saves or uses them would see stale weights.
model = saving.load_model("./models/m_" + model_name + "_unfinished_1.keras")
encoder = model.layers[-2]
decoder = model.layers[-1]

Fine tuning of the decoder with the loss function minimizing the error in pairwise distances between atoms

In [None]:
# Stage 2 fine-tunes only the decoder. Freeze the encoder *before* compiling:
# Keras fixes the set of trainable weights at compile time, so changing
# `trainable` afterwards is not guaranteed to take effect.
model.layers[-2].trainable = False
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    loss=msle_and_dist,
    metrics=metrics,
    # Presumably needed so pe_editlog_metric receives concrete tensors it can pass to OpenMM.
    run_eagerly=True,
)

In [None]:
loaded_encoder = model.layers[-2]
# Sanity check: Sampling noise level of the model's encoder vs. the `encoder` global
print(loaded_encoder.layers[-1].noise, encoder.layers[-1].noise)
# Keep the encoder fixed; only the decoder is fine-tuned in this stage
loaded_encoder.trainable = False

In [None]:
# Stage 2: decoder fine-tuning on the pairwise-distance loss
model.fit(
    xy_train,
    epochs=5,
    validation_data=xy_test,
    callbacks=callbacks_list,
)

In [None]:
# Checkpoint after stage 2 (pairwise-distance fine-tuning)
model.save(f"./models/m_{model_name}_unfinished_2.keras")

Fine tuning of the decoder with the loss function minimizing the error in reciprocal pairwise distances between atoms

In [None]:
model = saving.load_model("./models/m_" + model_name + "_unfinished_2.keras")

# Sanity check: Sampling noise level of the reloaded encoder vs. the `encoder` global
print(model.layers[-2].layers[-1].noise, encoder.layers[-1].noise)

# Rebind the globals to the reloaded sub-models, so later cells and any
# checkpointing use the weights that are actually being trained.
encoder = model.layers[-2]
decoder = model.layers[-1]

# Freeze the encoder BEFORE compiling; Keras fixes the trainable weights at compile time.
encoder.trainable = False

# Stage 3: decoder fine-tuning on inverse pairwise distances with a small learning rate
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-5),
    loss=msle_and_inv_dist,
    metrics=metrics,
    run_eagerly=True,
)

In [None]:
# Stage 3: decoder fine-tuning on the inverse-distance loss
model.fit(
    xy_train,
    epochs=20,
    validation_data=xy_test,
    callbacks=callbacks_list,
)

Saving of the trained model

In [None]:
# Save the fine-tuned autoencoder and its two halves. The halves are taken from
# `model` itself, so the files always match the weights that were just trained.
# (Previously this cell only loaded these files, which nothing in the notebook wrote.)
final_encoder = model.layers[-2]
final_decoder = model.layers[-1]
final_decoder.save("./models/d_" + model_name + ".keras")
final_encoder.save("./models/e_" + model_name + ".keras")
model.save("./models/m_" + model_name + ".keras")

# Reload from disk so the analysis below uses exactly the saved artifacts.
decoder = saving.load_model("./models/d_" + model_name + ".keras")
encoder = saving.load_model("./models/e_" + model_name + ".keras")
model = saving.load_model("./models/m_" + model_name + ".keras")

## Latent space visualization

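Calculate the latent-space coordinates of selected encoded protein structures.
The `colvar` array holds per-frame properties of the training structures, used below as follows:
- column 1: RMSD
- column 2: radius of gyration
- column 3: alpha-helix RMSD
- column 4: antiparallel-beta RMSD
- column 5: parallel-beta RMSD

These properties have to be calculated separately.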

In [None]:
# Encode every `each`-th frame of the training data.
each = 5
latents = encoder.predict(data[::each, :])

# Per-frame collective variables computed outside this notebook. Rows follow the
# frames before subsampling, hence the each * data_density stride below.
colvar = np.loadtxt("../data/" + system_name + "_rmsd_ds.xvg")

plt.scatter(latents[:, 0], latents[:, 1], c=colvar[::each * data_density, 1], cmap="jet", s=1)
cbar = plt.colorbar()
cbar.set_label("RMSD (nm)")
plt.xlabel("latent value 1")
plt.ylabel("latent value 2")
# Separator added: previously produced e.g. "trpcagels_rmsd.png".
plt.savefig(system_name + "_ls_rmsd.png")

In [None]:
# Encode the first frame (the reference structure) as a batch of one and save its
# latent coordinates.
# NOTE(review): the Sampling layer adds noise on every call, so this encoding is stochastic.
reference_frame = data[:1]
encoded = encoder(reference_frame)
np.save(f"models/encoded_reference_{model_name}.npy", encoded.numpy())

Plot the latent coordinates of the encoded structures, coloured by selected properties.

In [None]:
# Latent space coloured by RMSD (colvar column 1)
fig, ax = plt.subplots()
colour = colvar[::each * data_density, 1]
points = ax.scatter(latents[:, 0], latents[:, 1], c=colour, cmap="jet", s=1)
cbar = fig.colorbar(points, ax=ax)
cbar.set_label("RMSD (nm)")
ax.set(xlabel="latent value 0", ylabel="latent value 1", title="RMSD")
fig.savefig("ls_rmsd.png")

In [None]:
# Latent space coloured by radius of gyration (colvar column 2)
fig, ax = plt.subplots()
colour = colvar[::each * data_density, 2]
points = ax.scatter(latents[:, 0], latents[:, 1], c=colour, cmap="jet", s=1)
cbar = fig.colorbar(points, ax=ax)
cbar.set_label("radius of gyration (nm)")
ax.set(xlabel="latent value 0", ylabel="latent value 1", title="Radius of gyration")
fig.savefig("ls_rg.png")

In [None]:
# Latent space coloured by alpha-helix RMSD (colvar column 3)
fig, ax = plt.subplots()
colour = colvar[::each * data_density, 3]
points = ax.scatter(latents[:, 0], latents[:, 1], c=colour, cmap="jet", s=1)
cbar = fig.colorbar(points, ax=ax)
cbar.set_label("Alpha RMSD (nm)")
ax.set(xlabel="latent value 0", ylabel="latent value 1", title="Alpha helix content")
fig.savefig("ls_alpha.png")

In [None]:
# Latent space coloured by antiparallel-beta RMSD (colvar column 4)
fig, ax = plt.subplots()
colour = colvar[::each * data_density, 4]
points = ax.scatter(latents[:, 0], latents[:, 1], c=colour, cmap="jet", s=1)
cbar = fig.colorbar(points, ax=ax)
cbar.set_label("Anti-beta RMSD (nm)")
ax.set(xlabel="latent value 0", ylabel="latent value 1", title="Antiparallel beta sheet content")
fig.savefig("ls_antibeta.png")

In [None]:
# Latent space coloured by parallel-beta RMSD (colvar column 5)
fig, ax = plt.subplots()
colour = colvar[::each * data_density, 5]
points = ax.scatter(latents[:, 0], latents[:, 1], c=colour, cmap="jet", s=1)
cbar = fig.colorbar(points, ax=ax)
cbar.set_label("Para-beta RMSD (nm)")
ax.set(xlabel="latent value 0", ylabel="latent value 1", title="Parallel beta sheet content")
fig.savefig("ls_parabeta.png")