In [1]:
import numpy as np
import random
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error
from tensorflow.keras.losses import BinaryCrossentropy
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from tensorflow.keras.models import load_model

import tensorflow as tf
from tensorflow.keras import layers, Model
from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from scipy.spatial import ConvexHull
import pandas as pd
from tensorflow import keras
from sklearn.manifold import TSNE
import os
from tensorflow.keras import layers, Model
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

In [2]:
data = np.load('kinase_data_splits.npz')
X_train = data['X_train']
X_val = data['X_val']
X_test = data['X_test']
print("Train shape:", X_train.shape)
print("Validation shape:", X_val.shape)
print("Test shape:", X_test.shape)

Train shape: (10460, 5523)
Validation shape: (2243, 5523)
Test shape: (2242, 5523)


In [3]:
# --- Sampling Layer ---
class Sampling(layers.Layer):
    """Reparameterization trick layer."""
    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        epsilon = tf.random.normal(shape=(batch, dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon

    def get_config(self):
        return super().get_config()

# --- Recurrent VAE Model ---
class RecurrentVAE(Model):
    def __init__(self, sequence_length=263, input_dim=21, latent_dim=32, lstm_units=128, **kwargs):
        """
        Parameters:
          sequence_length: Number of amino acids per sequence (e.g., 263)
          input_dim: Number of features per amino acid (e.g., 21 for one-hot encoding)
          latent_dim: Dimensionality of the latent space (e.g., 32)
          lstm_units: Number of units in the LSTM layers (e.g., 128)
        """
        super().__init__(**kwargs)
        self.sequence_length = sequence_length
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.lstm_units = lstm_units

        # ---- Encoder ----
        # LSTM encoder: returns the final hidden state only.
        self.encoder_lstm = layers.LSTM(lstm_units, return_state=True, return_sequences=False)
        self.z_mean_dense = layers.Dense(latent_dim)
        self.z_log_var_dense = layers.Dense(latent_dim)
        self.sampling = Sampling()

        # ---- Decoder ----
        # Repeat latent vector over the time dimension:
        self.repeat_vector = layers.RepeatVector(sequence_length)
        # LSTM decoder that returns sequences.
        self.decoder_lstm = layers.LSTM(lstm_units, return_sequences=True)
        # Dense output layer to get back probability distribution over 21 channels.
        self.output_dense = layers.TimeDistributed(layers.Dense(input_dim, activation='softmax'))

    def encode(self, inputs, training=False):
        # inputs shape: (batch, sequence_length, input_dim)
        # Run the LSTM; we ignore the cell state.
        _, state_h, _ = self.encoder_lstm(inputs)
        z_mean = self.z_mean_dense(state_h)
        z_log_var = self.z_log_var_dense(state_h)
        z = self.sampling([z_mean, z_log_var])
        return z_mean, z_log_var, z

    def decode(self, z, training=False):
        # Repeat z over time steps to form a sequence.
        x = self.repeat_vector(z)  # shape: (batch, sequence_length, latent_dim)
        # Decode using LSTM.
        x = self.decoder_lstm(x)     # shape: (batch, sequence_length, lstm_units)
        reconstruction = self.output_dense(x)  # shape: (batch, sequence_length, input_dim)
        return reconstruction

    def call(self, inputs, training=False):
        z_mean, z_log_var, z = self.encode(inputs, training=training)
        reconstruction = self.decode(z, training=training)
        # Compute KL divergence loss.
        kl_loss = -0.5 * tf.reduce_mean(
            tf.reduce_sum(1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var), axis=1)
        )
        self.add_loss(kl_loss)
        return reconstruction

    def get_config(self):
        config = super(RecurrentVAE, self).get_config()
        config.update({
            'sequence_length': self.sequence_length,
            'input_dim': self.input_dim,
            'latent_dim': self.latent_dim,
            'lstm_units': self.lstm_units
        })
        return config

In [4]:
# --- Usage Example ---
if __name__ == "__main__":


    sequence_length = 263
    input_dim = 21
    # Reshape flattened data to (batch, sequence_length, input_dim)
    X_train_seq = X_train.reshape(-1, sequence_length, input_dim)
    X_val_seq = X_val.reshape(-1, sequence_length, input_dim)
    
    # Instantiate and compile the Recurrent VAE model.
    rec_vae = RecurrentVAE(sequence_length=sequence_length, input_dim=input_dim, latent_dim=32, lstm_units=128)
    rec_vae.compile(optimizer=tf.keras.optimizers.Adam(1e-4), loss='categorical_crossentropy')
    
    # Build the model with a dummy input to initialize layers.
    dummy_input = np.zeros((1, sequence_length, input_dim), dtype=np.float32)
    _ = rec_vae(dummy_input)
    rec_vae.summary()
    
    # Define callbacks.
    callbacks = [
        EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True),
        ModelCheckpoint('best_recurrent_vae.keras', monitor='val_loss', save_best_only=True)
    ]
    
    # Train the model.
    history = rec_vae.fit(
        X_train_seq, X_train_seq,
        validation_data=(X_val_seq, X_val_seq),
        epochs=100,
        batch_size=128,
        callbacks=callbacks,
        shuffle=True
    )
    
    # Save final model in native Keras format.
    rec_vae.save('final_recurrent_vae.keras')


Epoch 1/100
[1m82/82[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m119s[0m 1s/step - loss: 3.1134 - val_loss: 3.0404
Epoch 2/100
[1m82/82[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m109s[0m 1s/step - loss: 3.0255 - val_loss: 2.9751
Epoch 3/100
[1m82/82[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m110s[0m 1s/step - loss: 2.9687 - val_loss: 2.9596
Epoch 4/100
[1m82/82[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m116s[0m 1s/step - loss: 2.9582 - val_loss: 2.9553
Epoch 5/100
[1m82/82[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m110s[0m 1s/step - loss: 2.9550 - val_loss: 2.9533
Epoch 6/100
[1m82/82[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m115s[0m 1s/step - loss: 2.9535 - val_loss: 2.9517
Epoch 7/100
[1m82/82[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m110s[0m 1s/step - loss: 2.9517 - val_loss: 2.9505
Epoch 8/100
[1m82/82[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m108s[0m 1s/step - loss: 2.9503 - val_loss: 2.9496
Epoch 9/100
[1m82/82[0m [32m━━━━━━━━━

In [9]:
print(tf.config.list_physical_devices())

[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU')]
