In [17]:
import numpy as np
import random
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error
from tensorflow.keras.losses import BinaryCrossentropy
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from tensorflow.keras.models import load_model

import tensorflow as tf
from tensorflow.keras import layers, Model
from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from scipy.spatial import ConvexHull
import pandas as pd
from tensorflow import keras
from sklearn.manifold import TSNE
import os
from tensorflow.keras import layers, Model
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

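### Transformer VAE
- Encoder path: (263, 21) → Dense projection (263, 128) → Transformer blocks (263, 128) → average pooling (128) → latent (32)
- Decoder path: (32) → Dense + reshape (263, 128) → Transformer blocks (263, 128) → per-position softmax (263, 21)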

In [22]:

# ------ Custom Layers ------
class Sampling(layers.Layer):
    """Draw z from N(z_mean, exp(z_log_var)) via the reparameterization trick.

    Takes a pair ``[z_mean, z_log_var]`` (assumed 2-D, (batch, latent_dim)) and
    returns ``z_mean + exp(0.5 * z_log_var) * eps`` with standard-normal ``eps``,
    so gradients flow back into both inputs.
    """

    def call(self, inputs):
        mean, log_var = inputs
        # Noise has the (batch, latent_dim) runtime shape of the mean tensor.
        noise_shape = (tf.shape(mean)[0], tf.shape(mean)[1])
        eps = tf.random.normal(shape=noise_shape)
        sigma = tf.exp(0.5 * log_var)
        return mean + sigma * eps

    def get_config(self):
        # No constructor arguments beyond the base Layer ones.
        base_config = super().get_config()
        return base_config

class PositionalEncoding(layers.Layer):
    """Add fixed sinusoidal positional encodings to a (batch, seq, d_model) tensor.

    Channel pair (2k, 2k+1) shares the angle rate 1 / 10000^(2k / d_model).
    The sines of the even-channel angles come first, followed by the cosines
    of the odd-channel angles. The layout is ``[sin_0, sin_1, ..., cos_0, cos_1, ...]``
    (concatenated, NOT interleaved). This layout is kept deliberately so models
    trained with it keep producing the same activations.

    The encoding has no trainable weights. It is recomputed from the input's
    runtime sequence length on every call.

    Args:
        seq_length: Optional. Overwritten from the input shape in ``build``.
            Accepted so that ``from_config(get_config())`` round-trips.
        d_model: Optional. Overwritten from the input shape in ``build``.
    """

    def __init__(self, seq_length=None, d_model=None, **kwargs):
        super().__init__(**kwargs)
        # Defined up front so get_config() works even before build().
        self.seq_length = seq_length
        self.d_model = d_model

    def build(self, input_shape):
        # Both sizes come from the (batch, seq_length, d_model) input shape.
        _, self.seq_length, self.d_model = input_shape
        super().build(input_shape)

    def get_angles(self, pos, i, d_model):
        """Return ``pos / 10000^(2 * (i // 2) / d_model)``, broadcast over pos and i."""
        # A float base (10000.0) keeps tf.pow in float32.
        angle_rates = 1 / tf.pow(tf.constant(10000.0, dtype=tf.float32),
                                 (2 * (i // 2)) / tf.cast(d_model, tf.float32))
        return pos * angle_rates

    def call(self, x):
        # Read the length at runtime so variable-length inputs (or a None
        # sequence dimension at build time) are supported.
        seq_length = tf.shape(x)[1]
        positions = tf.cast(tf.range(seq_length)[:, tf.newaxis], tf.float32)  # (seq, 1)
        dims = tf.cast(tf.range(self.d_model)[tf.newaxis, :], tf.float32)      # (1, d_model)
        angle_rads = self.get_angles(positions, dims, self.d_model)            # (seq, d_model)

        sines = tf.sin(angle_rads[:, 0::2])    # ceil(d_model / 2) channels
        cosines = tf.cos(angle_rads[:, 1::2])  # floor(d_model / 2) channels
        # ceil + floor == d_model, so the concatenation already has full width.
        pos_encoding = tf.concat([sines, cosines], axis=-1)[tf.newaxis, ...]  # (1, seq, d_model)

        # Match the input dtype (e.g. float16 under mixed precision) before adding.
        return x + tf.cast(pos_encoding, x.dtype)

    def get_config(self):
        config = super().get_config()
        # Both values are None until build() has run. __init__ accepts them, so
        # this config can be passed straight back to from_config().
        config.update({
            "seq_length": self.seq_length,
            "d_model": self.d_model,
        })
        return config

class TransformerBlock(layers.Layer):
    """Post-norm Transformer block: self-attention followed by a position-wise FFN.

    Each sub-layer output goes through dropout, a residual add and layer
    normalization: ``LayerNorm(x + Dropout(SubLayer(x)))``. No attention mask
    is applied, so every position attends to every other position.

    Args:
        d_model: Feature width of the input and output.
        num_heads: Number of attention heads. Each uses ``d_model // num_heads`` key dims.
        dff: Hidden width of the feed-forward sub-layer.
        dropout_rate: Dropout applied to each sub-layer output.
    """
    def __init__(self, d_model, num_heads, dff, dropout_rate=0.1, **kwargs):
        super().__init__(**kwargs)
        # Keep the constructor arguments so get_config()/from_config() round-trip.
        self.d_model = d_model
        self.num_heads = num_heads
        self.dff = dff
        self.dropout_rate = dropout_rate

        self.mha = layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model//num_heads)
        self.ffn = tf.keras.Sequential([
            layers.Dense(dff, activation='relu'),
            layers.Dense(d_model)
        ])
        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(dropout_rate)
        self.dropout2 = layers.Dropout(dropout_rate)

    def call(self, x, training=False):
        # Self-attention sub-layer (query = key = value = x).
        attn_output = self.mha(query=x, key=x, value=x)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(x + attn_output)

        # Position-wise feed-forward sub-layer.
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.layernorm2(out1 + ffn_output)

    def get_config(self):
        config = super().get_config()
        # Every required constructor argument must be present, or from_config() fails.
        config.update({
            'd_model': self.d_model,
            'num_heads': self.num_heads,
            'dff': self.dff,
            'dropout_rate': self.dropout_rate,
        })
        return config

# ------ Custom Transformer Encoder/Decoder ------
class CustomTransformerEncoder(layers.Layer):
    """Stack of self-attention blocks preceded by positional encoding and dropout.

    Pipeline: ``PositionalEncoding -> Dropout -> num_layers x TransformerBlock``.
    Input and output shapes are both (batch, seq, d_model).

    Args:
        num_layers: Number of TransformerBlock layers.
        d_model: Feature width of the input and output.
        num_heads: Attention heads per block.
        dff: Feed-forward hidden width per block.
        seq_length: Stored for the config only. The positional encoding reads
            the length from its input.
        dropout_rate: Dropout rate for the input dropout and inside every block.
    """
    def __init__(self, num_layers, d_model, num_heads, dff, seq_length, dropout_rate=0.1, **kwargs):
        super().__init__(**kwargs)
        # Keep every constructor argument so get_config()/from_config() round-trip.
        self.num_layers = num_layers
        self.d_model = d_model
        self.num_heads = num_heads
        self.dff = dff
        self.seq_length = seq_length
        self.dropout_rate = dropout_rate

        self.pos_encoding = PositionalEncoding()
        self.enc_layers = [TransformerBlock(d_model, num_heads, dff, dropout_rate) for _ in range(num_layers)]
        self.dropout = layers.Dropout(dropout_rate)

    def call(self, x, training=False):
        x = self.pos_encoding(x)
        x = self.dropout(x, training=training)
        for layer in self.enc_layers:
            x = layer(x, training=training)
        return x

    def get_config(self):
        config = super().get_config()
        config.update({
            'num_layers': self.num_layers,
            'd_model': self.d_model,
            'num_heads': self.num_heads,
            'dff': self.dff,
            'seq_length': self.seq_length,
            'dropout_rate': self.dropout_rate,
        })
        return config

class CustomTransformerDecoder(layers.Layer):
    """Decoder-side stack of self-attention blocks.

    Pipeline: ``PositionalEncoding -> Dropout -> num_layers x TransformerBlock``.
    Only its input sequence is passed to each block. There is no
    cross-attention and no causal mask, so all positions are decoded in
    parallel from that input. Input and output shapes are both
    (batch, seq, d_model).

    Args:
        num_layers: Number of TransformerBlock layers.
        d_model: Feature width of the input and output.
        num_heads: Attention heads per block.
        dff: Feed-forward hidden width per block.
        seq_length: Stored for the config only. The positional encoding reads
            the length from its input.
        dropout_rate: Dropout rate for the input dropout and inside every block.
    """
    def __init__(self, num_layers, d_model, num_heads, dff, seq_length, dropout_rate=0.1, **kwargs):
        super().__init__(**kwargs)
        # Keep every constructor argument so get_config()/from_config() round-trip.
        self.num_layers = num_layers
        self.d_model = d_model
        self.num_heads = num_heads
        self.dff = dff
        self.seq_length = seq_length
        self.dropout_rate = dropout_rate

        self.pos_encoding = PositionalEncoding()
        self.dec_layers = [TransformerBlock(d_model, num_heads, dff, dropout_rate) for _ in range(num_layers)]
        self.dropout = layers.Dropout(dropout_rate)

    def call(self, x, training=False):
        x = self.pos_encoding(x)
        x = self.dropout(x, training=training)
        for layer in self.dec_layers:
            x = layer(x, training=training)
        return x

    def get_config(self):
        config = super().get_config()
        config.update({
            'num_layers': self.num_layers,
            'd_model': self.d_model,
            'num_heads': self.num_heads,
            'dff': self.dff,
            'seq_length': self.seq_length,
            'dropout_rate': self.dropout_rate,
        })
        return config

# ------ Transformer VAE Model ------
class TransformerVAE(Model):
    """Variational autoencoder built from Transformer self-attention stacks.

    Encoder:
        1. Dense projection (input_dim -> intermediate_dim).
        2. Positional encoding plus ``num_layers`` Transformer blocks.
        3. Global average pooling over the sequence.
        4. Two Dense heads producing ``z_mean`` and ``z_log_var``.

    Decoder:
        1. Dense (latent_dim -> seq_length * intermediate_dim, ReLU).
        2. Reshape to (seq_length, intermediate_dim).
        3. Positional encoding plus ``num_layers`` Transformer blocks.
        4. Per-position softmax over ``input_dim`` outputs.

    The KL divergence KL(q(z|x) || N(0, I)) is summed over latent dimensions,
    averaged over the batch and scaled by ``kl_weight``. It is registered with
    ``add_loss`` on every forward pass. The reconstruction loss comes from
    ``compile(loss=...)``.

    Args:
        seq_length: Number of positions per sequence.
        input_dim: Feature width per position; also the width of the output softmax.
        latent_dim: Size of the latent vector z.
        num_heads: Attention heads in every Transformer block.
        intermediate_dim: Model width (d_model) of both Transformer stacks. The
            feed-forward hidden width is twice this.
        num_layers: Transformer blocks in each of the encoder and decoder.
        dropout_rate: Dropout rate used inside both Transformer stacks.
        kl_weight: Multiplier on the KL term (1.0 = standard VAE objective).
    """
    def __init__(self, seq_length=263, input_dim=21, latent_dim=32, num_heads=4, intermediate_dim=128,
                 num_layers=2, dropout_rate=0.1, kl_weight=1.0, **kwargs):
        super().__init__(**kwargs)
        self.seq_length = seq_length
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.num_heads = num_heads
        self.intermediate_dim = intermediate_dim
        self.num_layers = num_layers
        self.dropout_rate = dropout_rate
        self.kl_weight = kl_weight

        # Encoder: project to the model width and run the Transformer stack,
        # which adds positional encoding itself. Then pool and map to q(z|x).
        self.encoder_projection = layers.Dense(intermediate_dim)
        self.transformer_encoder = CustomTransformerEncoder(num_layers=num_layers, d_model=intermediate_dim,
                                                            num_heads=num_heads, dff=intermediate_dim*2,
                                                            seq_length=seq_length, dropout_rate=dropout_rate)
        self.pool = layers.GlobalAveragePooling1D()
        self.z_mean_dense = layers.Dense(latent_dim)
        self.z_log_var_dense = layers.Dense(latent_dim)
        self.sampling = Sampling()

        # Decoder: expand z into a full sequence, run the Transformer stack,
        # then predict a softmax distribution at every position.
        self.decoder_projection = layers.Dense(seq_length * intermediate_dim, activation='relu')
        self.decoder_reshape = layers.Reshape((seq_length, intermediate_dim))
        self.transformer_decoder = CustomTransformerDecoder(num_layers=num_layers, d_model=intermediate_dim,
                                                            num_heads=num_heads, dff=intermediate_dim*2,
                                                            seq_length=seq_length, dropout_rate=dropout_rate)
        self.decoder_output = layers.Dense(input_dim, activation='softmax')

    def encode(self, inputs, training=False):
        """Map (batch, seq_length, input_dim) inputs to ``(z_mean, z_log_var)``, each (batch, latent_dim)."""
        x = self.encoder_projection(inputs)                  # (batch, seq_length, intermediate_dim)
        x = self.transformer_encoder(x, training=training)   # (batch, seq_length, intermediate_dim)
        x = self.pool(x)                                     # (batch, intermediate_dim)
        z_mean = self.z_mean_dense(x)                        # (batch, latent_dim)
        z_log_var = self.z_log_var_dense(x)                  # (batch, latent_dim)
        return z_mean, z_log_var

    def decode(self, z, training=False):
        """Map latent vectors (batch, latent_dim) to per-position softmax outputs (batch, seq_length, input_dim)."""
        x = self.decoder_projection(z)                       # (batch, seq_length * intermediate_dim)
        x = self.decoder_reshape(x)                          # (batch, seq_length, intermediate_dim)
        x = self.transformer_decoder(x, training=training)   # (batch, seq_length, intermediate_dim)
        return self.decoder_output(x)                        # (batch, seq_length, input_dim)

    def call(self, inputs, training=False):
        z_mean, z_log_var = self.encode(inputs, training=training)
        z = self.sampling([z_mean, z_log_var])
        reconstruction = self.decode(z, training=training)
        # KL(q(z|x) || N(0, I)): summed over latent dims, averaged over the batch.
        kl_loss = -0.5 * tf.reduce_mean(
            tf.reduce_sum(1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var), axis=1))
        # NOTE(review): the compiled reconstruction loss is probably a per-position
        # mean, while this KL term is a per-sample sum. Lower kl_weight if the
        # latent code collapses.
        self.add_loss(self.kl_weight * kl_loss)
        return reconstruction

    def get_config(self):
        config = super().get_config()
        # All architecture arguments are needed to rebuild a model whose saved
        # weights match.
        config.update({
            "seq_length": self.seq_length,
            "input_dim": self.input_dim,
            "latent_dim": self.latent_dim,
            "num_heads": self.num_heads,
            "intermediate_dim": self.intermediate_dim,
            "num_layers": self.num_layers,
            "dropout_rate": self.dropout_rate,
            "kl_weight": self.kl_weight,
        })
        return config


In [24]:
# ------ Training Setup ------
if __name__ == "__main__":
    # Seed Python, NumPy and TensorFlow RNGs so that weight init, dropout,
    # shuffling and the VAE's sampling noise repeat between runs.
    SEED = 42
    tf.keras.utils.set_random_seed(SEED)

    # Data / model / training hyper-parameters.
    DATA_PATH = 'kinase_data_splits.npz'
    sequence_length = 263
    input_dim = 21
    LATENT_DIM = 32
    NUM_HEADS = 4
    INTERMEDIATE_DIM = 128
    EPOCHS = 100
    BATCH_SIZE = 128
    EARLY_STOPPING_PATIENCE = 10

    # Load the pre-split data. `data` is kept bound (not closed via `with`)
    # in case later cells read further arrays from it.
    data = np.load(DATA_PATH)
    X_train = data['X_train'].astype(np.float32)
    X_val = data['X_val'].astype(np.float32)
    X_test = data['X_test'].astype(np.float32)

    # Reshape data for transformer input: (batch, seq_length, input_dim).
    X_train_t = X_train.reshape(-1, sequence_length, input_dim)
    X_val_t = X_val.reshape(-1, sequence_length, input_dim)
    X_test_t = X_test.reshape(-1, sequence_length, input_dim)

    # Create and compile the Transformer VAE. The KL term is added inside the
    # model via add_loss; the loss given here is the reconstruction term only.
    transformer_vae = TransformerVAE(seq_length=sequence_length, input_dim=input_dim, latent_dim=LATENT_DIM,
                                     num_heads=NUM_HEADS, intermediate_dim=INTERMEDIATE_DIM)
    transformer_vae.compile(optimizer='adam', loss='categorical_crossentropy')

    # Call once on a dummy batch so every sub-layer is built before summary().
    dummy_input = np.zeros((1, sequence_length, input_dim), dtype=np.float32)
    _ = transformer_vae(dummy_input)
    transformer_vae.summary()

    callbacks = [
        EarlyStopping(monitor='val_loss', patience=EARLY_STOPPING_PATIENCE, restore_best_weights=True),
        ModelCheckpoint('best_transformer_vae.keras', monitor='val_loss', save_best_only=True),
    ]

    # Train as an autoencoder: inputs are also the targets.
    history = transformer_vae.fit(
        X_train_t, X_train_t,
        validation_data=(X_val_t, X_val_t),
        epochs=EPOCHS, batch_size=BATCH_SIZE,
        callbacks=callbacks,
        shuffle=True,
    )

    # Final model save (weights are the restored best epoch if early stopping fired).
    transformer_vae.save('final_transformer_vae.keras')

Epoch 1/100
[1m82/82[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m176s[0m 2s/step - loss: 5.3254 - val_loss: 2.2247
Epoch 2/100
[1m82/82[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m165s[0m 2s/step - loss: 2.2528 - val_loss: 2.1317
Epoch 3/100
[1m82/82[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m162s[0m 2s/step - loss: 2.1635 - val_loss: 2.1103
Epoch 4/100
[1m82/82[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m162s[0m 2s/step - loss: 2.1371 - val_loss: 2.0940
Epoch 5/100
[1m82/82[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m162s[0m 2s/step - loss: 2.1205 - val_loss: 2.0851
Epoch 6/100
[1m82/82[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1778s[0m 22s/step - loss: 2.1094 - val_loss: 2.0767
Epoch 7/100
[1m82/82[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m182s[0m 2s/step - loss: 2.1001 - val_loss: 2.0737
Epoch 8/100
[1m82/82[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m173s[0m 2s/step - loss: 2.0951 - val_loss: 2.0739
Epoch 9/100
[1m82/82[0m [32m━━━━━━━