In [1]:
import numpy as np
import random
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error
from tensorflow.keras.losses import BinaryCrossentropy
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from tensorflow.keras.models import load_model

import tensorflow as tf
from tensorflow.keras import layers, Model
from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from scipy.spatial import ConvexHull
import pandas as pd
from tensorflow import keras
from sklearn.manifold import TSNE
import os
from tensorflow.keras import layers, Model
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

In [2]:
# Load the pre-computed train / validation / test splits from the .npz archive.
# Each split is a 2-D array of flattened per-sequence features (see the shape
# printout below).
data = np.load('kinase_data_splits.npz')
X_train, X_val, X_test = (data[key] for key in ('X_train', 'X_val', 'X_test'))

# Report the shape of every split so the feature width can be sanity-checked.
for _caption, _split in (("Train shape:", X_train),
                         ("Validation shape:", X_val),
                         ("Test shape:", X_test)):
    print(_caption, _split.shape)


Train shape: (10460, 5523)
Validation shape: (2243, 5523)
Test shape: (2242, 5523)


In [11]:
# ------ Sampling Layer ------
class Sampling(layers.Layer):
    """Reparameterization trick layer.

    Called with ``[z_mean, z_log_var]`` (two tensors of the same shape), it returns
    ``z_mean + exp(0.5 * z_log_var) * epsilon`` with ``epsilon ~ N(0, I)``.
    Noise is drawn on every call; the layer has no training/inference switch.
    """

    def call(self, inputs):
        z_mean, z_log_var = inputs
        # Draw noise with the full dynamic shape AND dtype of z_mean. tf.random.normal
        # defaults to float32, which cannot be added to a float16 (mixed precision)
        # or float64 z_mean. Using the whole shape also supports latents of any rank
        # instead of assuming (batch, dim).
        epsilon = tf.random.normal(shape=tf.shape(z_mean), dtype=z_mean.dtype)
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon

    def get_config(self):
        # No constructor arguments beyond the base Layer's, so the base config suffices.
        return super().get_config()

# ------ Convolutional VAE Model ------
class ConvVAE(Model):
    """Convolutional variational autoencoder over one-hot encoded sequences.

    Encoder: three Conv1D layers (strides 1, 2, 2), a flatten, then two Dense heads
    producing the mean and log-variance of a ``latent_dim``-dimensional Gaussian.
    A latent sample is drawn with the reparameterization trick (``Sampling``).

    Decoder: a Dense projection back to the flattened conv feature map, two rounds
    of 2x upsampling + Conv1D, and a final Conv1D with a softmax over ``input_dim``
    channels. A crop then trims the upsampled length back to ``sequence_length``.

    ``call`` registers the KL divergence (summed over latent dims, averaged over the
    batch) via ``add_loss``, so it is added to whatever loss the model is compiled with.
    """

    def __init__(self, sequence_length=263, input_dim=21, latent_dim=32, **kwargs):
        """
        Parameters:
          sequence_length: Number of positions per sequence (e.g., 263). Any
            positive length is supported; the decoder output is cropped to match.
          input_dim: One-hot encoded dimension (e.g., 21).
          latent_dim: Dimension of the latent space (e.g., 32).
        """
        super().__init__(**kwargs)
        self.sequence_length = sequence_length
        self.input_dim = input_dim
        self.latent_dim = latent_dim

        # Channel count of the deepest encoder feature map. It is shared by conv3,
        # the decoder's Dense projection and its Reshape, so they must agree.
        top_filters = 128

        # ------ Encoder ------
        # Input shape: (batch, sequence_length, input_dim)
        self.conv1 = layers.Conv1D(filters=32, kernel_size=3, strides=1, padding='same', activation='relu')
        self.conv2 = layers.Conv1D(filters=64, kernel_size=3, strides=2, padding='same', activation='relu')
        self.conv3 = layers.Conv1D(filters=top_filters, kernel_size=3, strides=2, padding='same', activation='relu')
        self.flatten = layers.Flatten()

        # A stride-2 'same' conv maps length L to ceil(L / 2). After two of them the
        # length is ceil(ceil(L / 2) / 2), which equals ceil(L / 4) for integer L.
        self.reduced_seq_length = int(np.ceil(sequence_length / 2.0 / 2.0))  # 263 -> 66
        self.intermediate_dim = top_filters * self.reduced_seq_length        # 128 * 66 = 8448

        # Latent variable Dense layers.
        self.dense_z_mean = layers.Dense(latent_dim)
        self.dense_z_log_var = layers.Dense(latent_dim)
        self.sampling = Sampling()

        # ------ Decoder ------
        # Project latent vector back to flattened conv feature map.
        self.dense_decoder = layers.Dense(self.intermediate_dim, activation='relu')
        # Reshape to (reduced_seq_length, top_filters).
        self.reshape_decoder = layers.Reshape((self.reduced_seq_length, top_filters))

        # Upsampling and convolution to recover sequence length.
        self.upsample1 = layers.UpSampling1D(size=2)
        self.conv_dec1 = layers.Conv1D(filters=64, kernel_size=3, padding='same', activation='relu')
        self.upsample2 = layers.UpSampling1D(size=2)
        self.conv_dec2 = layers.Conv1D(filters=32, kernel_size=3, padding='same', activation='relu')
        # Final reconstruction layer: per-position probabilities over input_dim classes.
        self.conv_dec3 = layers.Conv1D(filters=input_dim, kernel_size=3, padding='same', activation='softmax')
        # Two 2x upsamplings yield 4 * reduced_seq_length positions. That exceeds
        # sequence_length by 0-3 whenever it is not a multiple of 4 (by 1 for 263).
        # Trim exactly that surplus from the end. A fixed crop of 1 would produce a
        # wrong output length for any other remainder.
        surplus = 4 * self.reduced_seq_length - sequence_length
        self.crop = layers.Cropping1D(cropping=(0, surplus))

    def encode(self, inputs, training=False):
        """Map inputs to the approximate posterior and a reparameterized sample.

        Args:
          inputs: batch shaped (batch, sequence_length, input_dim).
          training: accepted for API symmetry; currently unused (a latent sample
            is always drawn).

        Returns:
          (z_mean, z_log_var, z), each shaped (batch, latent_dim).
        """
        x = self.conv1(inputs)                  # (batch, sequence_length, 32)
        x = self.conv2(x)                       # (batch, ceil(sequence_length/2), 64)
        x = self.conv3(x)                       # (batch, reduced_seq_length, 128)
        x = self.flatten(x)                     # (batch, intermediate_dim)
        z_mean = self.dense_z_mean(x)
        z_log_var = self.dense_z_log_var(x)
        z = self.sampling([z_mean, z_log_var])
        return z_mean, z_log_var, z

    def decode(self, z, training=False):
        """Map latent vectors back to per-position class probabilities.

        Args:
          z: latent batch shaped (batch, latent_dim).
          training: accepted for API symmetry; currently unused.

        Returns:
          Softmax probabilities shaped (batch, sequence_length, input_dim).
        """
        x = self.dense_decoder(z)               # (batch, intermediate_dim)
        x = self.reshape_decoder(x)             # (batch, reduced_seq_length, 128)
        x = self.upsample1(x)                   # (batch, reduced_seq_length*2, 128)
        x = self.conv_dec1(x)                   # (batch, reduced_seq_length*2, 64)
        x = self.upsample2(x)                   # (batch, reduced_seq_length*4, 64)
        x = self.conv_dec2(x)                   # (batch, reduced_seq_length*4, 32)
        x = self.conv_dec3(x)                   # (batch, reduced_seq_length*4, input_dim)
        reconstruction = self.crop(x)           # (batch, sequence_length, input_dim)
        return reconstruction

    def call(self, inputs, training=False):
        """Encode, sample, decode, and register the KL divergence as an extra loss."""
        z_mean, z_log_var, z = self.encode(inputs, training=training)
        reconstruction = self.decode(z, training=training)
        # KL(q(z|x) || N(0, I)) for a diagonal Gaussian, summed over latent dims
        # and averaged over the batch.
        kl_loss = -0.5 * tf.reduce_mean(
            tf.reduce_sum(1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var), axis=1)
        )
        self.add_loss(kl_loss)
        return reconstruction

    def get_config(self):
        """Serialize constructor arguments so the model can be re-created from config."""
        config = super().get_config()
        config.update({
            'sequence_length': self.sequence_length,
            'input_dim': self.input_dim,
            'latent_dim': self.latent_dim
        })
        return config


In [12]:
# --- Training Setup ---
# This cell runs directly in the notebook kernel, where __name__ is always
# "__main__", so no script-style entry-point guard is needed.

# Hyperparameters and artifact paths, collected here so runs are easy to tune.
SEED = 42
LATENT_DIM = 32
LEARNING_RATE = 1e-4
EPOCHS = 100
BATCH_SIZE = 128
EARLY_STOPPING_PATIENCE = 5
BEST_MODEL_PATH = 'best_conv_vae.keras'
FINAL_MODEL_PATH = 'final_conv_vae.keras'

# Seed Python's `random`, NumPy and TensorFlow in one call. This makes weight
# init, shuffling and the VAE's latent sampling reproducible across runs.
tf.keras.utils.set_random_seed(SEED)

# Reshape flattened data to (batch, sequence_length, input_dim).
sequence_length = 263
input_dim = 21
# reshape(-1, ...) would silently regroup rows if the per-sample feature count
# were some other multiple of sequence_length * input_dim, so verify it first.
expected_features = sequence_length * input_dim
assert np.prod(X_train.shape[1:]) == expected_features, (
    f"X_train has shape {X_train.shape}; expected {expected_features} features per sample")
assert np.prod(X_val.shape[1:]) == expected_features, (
    f"X_val has shape {X_val.shape}; expected {expected_features} features per sample")
X_train_2D = X_train.reshape(-1, sequence_length, input_dim)
X_val_2D   = X_val.reshape(-1, sequence_length, input_dim)

# Instantiate and compile the ConvVAE. The model adds its KL term itself via
# add_loss, so the compiled loss only needs to cover reconstruction.
conv_vae = ConvVAE(sequence_length=sequence_length, input_dim=input_dim, latent_dim=LATENT_DIM)
conv_vae.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),
                 loss='categorical_crossentropy')

# Build model with a dummy input to initialize layers before printing the summary.
dummy_input = np.zeros((1, sequence_length, input_dim), dtype=np.float32)
_ = conv_vae(dummy_input)
conv_vae.summary()

# Stop once validation loss stalls (restoring the best weights) and keep an
# on-disk checkpoint of the best epoch.
callbacks = [
    EarlyStopping(monitor='val_loss', patience=EARLY_STOPPING_PATIENCE, restore_best_weights=True),
    ModelCheckpoint(BEST_MODEL_PATH, monitor='val_loss', save_best_only=True)
]

# Train the ConvVAE as an autoencoder: inputs are also the targets.
history = conv_vae.fit(
    X_train_2D, X_train_2D,
    validation_data=(X_val_2D, X_val_2D),
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    callbacks=callbacks,
    shuffle=True
)

# Save final model in native Keras format.
conv_vae.save(FINAL_MODEL_PATH)

Epoch 1/100
[1m82/82[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 106ms/step - loss: 3.0678 - val_loss: 2.9740
Epoch 2/100
[1m82/82[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 84ms/step - loss: 2.9480 - val_loss: 2.8823
Epoch 3/100
[1m82/82[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 82ms/step - loss: 2.8534 - val_loss: 2.7723
Epoch 4/100
[1m82/82[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 88ms/step - loss: 2.7458 - val_loss: 2.6741
Epoch 5/100
[1m82/82[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 86ms/step - loss: 2.6514 - val_loss: 2.5955
Epoch 6/100
[1m82/82[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 82ms/step - loss: 2.5776 - val_loss: 2.5320
Epoch 7/100
[1m82/82[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 82ms/step - loss: 2.5156 - val_loss: 2.4778
Epoch 8/100
[1m82/82[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 84ms/step - loss: 2.4626 - val_loss: 2.4298
Epoch 9/100
[1m82/82[0m [32m━━━━━━━