In [1]:
import numpy as np
import random
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error
from tensorflow.keras.losses import BinaryCrossentropy
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from tensorflow.keras.models import load_model

import tensorflow as tf
from tensorflow.keras import layers, Model
from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from scipy.spatial import ConvexHull
import pandas as pd
from tensorflow import keras
from sklearn.manifold import TSNE
import os
from tensorflow.keras import layers, Model
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

In [2]:
def read_fasta(input_f):
    """Read a FASTA file and return its sequences as a list of strings.

    Header lines (those starting with '>' after stripping whitespace) only act
    as record separators; their text is discarded. All non-header lines between
    two headers are stripped and concatenated into one sequence. Records with no
    sequence text are skipped.

    Parameters:
      input_f: Path of the FASTA file to read.

    Returns:
      A list of sequence strings in file order.
    """
    records = []
    pieces = []  # stripped lines of the record currently being read

    with open(input_f, 'r') as handle:
        for raw_line in handle:
            text = raw_line.strip()
            if not text.startswith('>'):
                pieces.append(text)
                continue
            # A header closes the previous record (if it had any residues).
            finished = "".join(pieces)
            if finished:
                records.append(finished)
            pieces = []

    # The last record has no trailing header to close it.
    trailing = "".join(pieces)
    if trailing:
        records.append(trailing)
    return records

def select_random_sequences(sequences, num_samples=10000, seed=None):
    """Draw ``num_samples`` distinct entries from ``sequences`` without replacement.

    Parameters:
      sequences: Population to sample from (a list of sequences).
      num_samples: Number of entries to draw.
      seed: Optional seed. When given, a private ``random.Random(seed)`` is used so
        the same subset is returned on every call and the global RNG state is left
        untouched. When ``None`` (default) the module-level ``random`` generator
        is used, exactly as before.

    Returns:
      A new list with ``num_samples`` entries in random order.

    Raises:
      ValueError: If ``sequences`` holds fewer than ``num_samples`` entries.
    """
    if len(sequences) < num_samples :
        raise ValueError("The number of requested sequences exceeds the initial list size.")

    if seed is None:
        return random.sample(sequences, num_samples)
    return random.Random(seed).sample(sequences, num_samples)


def encode_sequences_one_hot_with_gap(sequences, max_length=None) :
    """One-hot encode sequences over the 20 amino-acid letters plus the gap '-'.

    Sequences containing any character outside ``'-ACDEFGHIKLMNPQRSTVWY'`` are
    dropped entirely, so the output can have fewer rows than ``sequences``.
    Sequences longer than ``max_length`` are truncated; shorter ones are padded
    with all-zero rows (not with the gap symbol).

    Parameters:
      sequences: Iterable of sequence strings.
      max_length: Number of positions to encode. ``None`` uses the length of the
        longest valid sequence (0 if none are valid).

    Returns:
      A float32 array of shape ``(n_valid, max_length, 21)``. It is always
      3-dimensional, including when no sequence is valid (``n_valid == 0``), so
      callers can unpack ``.shape`` unconditionally.
    """
    amino_acids = '-ACDEFGHIKLMNPQRSTVWY'
    aa_to_idx = {aa: idx for idx, aa in enumerate(amino_acids)}
    valid_sequences = [seq for seq in sequences if all(aa in aa_to_idx for aa in seq)]

    # Set maximum length; default=0 keeps this safe when nothing is valid.
    if max_length is None :
        max_length = max((len(seq) for seq in valid_sequences), default=0)

    # Allocating from the (possibly zero) row count yields a correctly shaped
    # empty array instead of a 1-D placeholder when no valid sequences remain.
    encoded_matrix = np.zeros(
        (len(valid_sequences), max_length, len(amino_acids)), dtype=np.float32
    )

    # Encode each sequence
    for i, seq in enumerate(valid_sequences) :
        for j, aa in enumerate(seq[:max_length]) :  # Truncate sequences longer than max_length
            encoded_matrix[i, j, aa_to_idx[aa]] = 1.0  # One-hot encode valid amino acids and gaps
    return encoded_matrix

In [3]:
# Load every sequence from the FASTA file; read_fasta discards the header lines
# and returns only the sequence strings.
seq_list = read_fasta('PF00069_noinserts_gaps_noduplicates.fasta')

In [4]:
# Seed Python's RNG first so the same subsample is drawn on every
# Restart & Run All; otherwise the fixed random_state used for the later
# train/validation/test split operates on different data each run.
random.seed(42)
# Draw 200,000 sequences without replacement (raises ValueError if fewer exist).
small_batch_seq_list = select_random_sequences(seq_list, 200000)

In [5]:
# One-hot encode the sampled sequences. The encoder drops any sequence containing a
# character outside its 21-symbol alphabet, so M can be smaller than the sample size.
encoded_matrix = encode_sequences_one_hot_with_gap(small_batch_seq_list, max_length=None)
M, L, A = encoded_matrix.shape  # M: number of sequences, L: sequence length, A: alphabet size (21)
flattened_matrix = encoded_matrix.reshape(M, L*A)  # Shape: (M, L * 21)
print(flattened_matrix.shape)
# NOTE(review): a comment here previously announced saving the flattened matrix in a
# compressed format, but this cell contains no save call.

(199154, 5523)


In [6]:
from sklearn.model_selection import train_test_split
# Stage 1: hold out 15% of all rows as the test set; the remaining 85% is train+validation.
X_train_val, X_test = train_test_split(flattened_matrix, test_size=0.15, random_state=42)

# Stage 2: take 17.65% of the remaining 85% as validation (0.1765 * 0.85 is about 0.15),
# so the overall proportions are roughly 70% train / 15% validation / 15% test.
X_train, X_val = train_test_split(X_train_val, test_size=0.1765, random_state=42)

# Report the row/column counts of each split as a sanity check.
print("Train shape:", X_train.shape)
print("Validation shape:", X_val.shape)
print("Test shape:", X_test.shape)


Train shape: (139402, 5523)
Validation shape: (29878, 5523)
Test shape: (29874, 5523)


In [7]:
# ------ Sampling Layer ------
class Sampling(layers.Layer):
    """Reparameterization-trick layer.

    Given ``[z_mean, z_log_var]`` it returns ``z_mean + exp(0.5 * z_log_var) * eps``
    where ``eps`` is standard-normal noise with one value per (row, column) of
    ``z_mean``.
    """

    def call(self, inputs):
        mean, log_var = inputs
        # Read the shape dynamically so the batch size need not be known up front.
        mean_shape = tf.shape(mean)
        noise = tf.random.normal(shape=(mean_shape[0], mean_shape[1]))
        # exp(0.5 * log-variance) is the standard deviation.
        std = tf.exp(0.5 * log_var)
        return mean + std * noise

    def get_config(self):
        # The layer takes no constructor arguments of its own, so the base
        # configuration is already complete.
        return super().get_config()

# ------ Convolutional VAE Model ------
class ConvVAE(Model):
    """Convolutional variational autoencoder for one-hot encoded sequences.

    The encoder applies three Conv1D layers (the last two with stride 2), flattens
    the result and projects it to the mean and log-variance of the latent
    distribution. The decoder mirrors this with a Dense projection, two
    UpSampling1D + Conv1D stages and a final softmax Conv1D over ``input_dim``
    classes, then crops the output back to ``sequence_length`` positions.

    Calling the model returns the reconstruction and registers the KL divergence
    (scaled by ``kl_weight``) through ``add_loss``.
    """

    def __init__(self, sequence_length=263, input_dim=21, latent_dim=32, kl_weight=1.0, **kwargs):
        """
        Parameters:
          sequence_length: Number of amino acids per sequence (e.g., 263).
          input_dim: One-hot encoded dimension (e.g., 21).
          latent_dim: Dimension of the latent space (e.g., 32).
          kl_weight: Multiplier for the KL term added in ``call``. The default 1.0
            keeps the original, unweighted loss.
            NOTE(review): the KL term is summed over latent dimensions; if the
            reconstruction loss is averaged over sequence positions the KL term
            may dominate. A value near 1 / sequence_length puts both on a
            per-position scale - confirm against the training objective.
        """
        super().__init__(**kwargs)
        self.sequence_length = sequence_length
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.kl_weight = kl_weight

        # ------ Encoder ------
        # Input shape: (batch, sequence_length, input_dim)
        self.conv1 = layers.Conv1D(filters=32, kernel_size=3, strides=1, padding='same', activation='relu')
        self.conv2 = layers.Conv1D(filters=64, kernel_size=3, strides=2, padding='same', activation='relu')
        self.conv3 = layers.Conv1D(filters=128, kernel_size=3, strides=2, padding='same', activation='relu')
        self.flatten = layers.Flatten()

        # Two stride-2 'same' convolutions shrink the length to
        # ceil(ceil(n / 2) / 2), which equals ceil(n / 4) for positive integers.
        self.reduced_seq_length = int(np.ceil(sequence_length / 2.0 / 2.0))  # 66 for n = 263
        self.intermediate_dim = 128 * self.reduced_seq_length  # 8448 for n = 263

        # Latent variable Dense layers.
        self.dense_z_mean = layers.Dense(latent_dim)
        self.dense_z_log_var = layers.Dense(latent_dim)
        self.sampling = Sampling()

        # ------ Decoder ------
        # Project latent vector back to flattened conv feature map.
        self.dense_decoder = layers.Dense(self.intermediate_dim, activation='relu')
        # Reshape to (reduced_seq_length, 128)
        self.reshape_decoder = layers.Reshape((self.reduced_seq_length, 128))

        # Upsampling and convolution to recover sequence length.
        self.upsample1 = layers.UpSampling1D(size=2)
        self.conv_dec1 = layers.Conv1D(filters=64, kernel_size=3, padding='same', activation='relu')
        self.upsample2 = layers.UpSampling1D(size=2)
        self.conv_dec2 = layers.Conv1D(filters=32, kernel_size=3, padding='same', activation='relu')
        # Final reconstruction layer: output probabilities over input_dim classes.
        self.conv_dec3 = layers.Conv1D(filters=input_dim, kernel_size=3, padding='same', activation='softmax')
        # Two x2 upsamplings produce 4 * reduced_seq_length steps, which overshoots
        # sequence_length by 0-3 steps depending on sequence_length % 4. Crop exactly
        # that overshoot from the end (1 step for 263) instead of assuming it is 1.
        self.extra_steps = 4 * self.reduced_seq_length - sequence_length
        self.crop = layers.Cropping1D(cropping=(0, self.extra_steps))

    def encode(self, inputs, training=False):
        """Map inputs to ``(z_mean, z_log_var, z)`` where ``z`` is a sampled latent."""
        x = self.conv1(inputs)                  # (batch, n, 32)
        x = self.conv2(x)                       # (batch, ceil(n/2), 64)
        x = self.conv3(x)                       # (batch, reduced_seq_length, 128)
        x = self.flatten(x)                     # (batch, intermediate_dim)
        z_mean = self.dense_z_mean(x)
        z_log_var = self.dense_z_log_var(x)
        z = self.sampling([z_mean, z_log_var])
        return z_mean, z_log_var, z

    def decode(self, z, training=False):
        """Map latent vectors to per-position class probabilities of length ``sequence_length``."""
        x = self.dense_decoder(z)               # (batch, intermediate_dim)
        x = self.reshape_decoder(x)             # (batch, reduced_seq_length, 128)
        x = self.upsample1(x)                   # (batch, reduced_seq_length*2, 128)
        x = self.conv_dec1(x)                   # (batch, reduced_seq_length*2, 64)
        x = self.upsample2(x)                   # (batch, reduced_seq_length*4, 64)
        x = self.conv_dec2(x)                   # (batch, reduced_seq_length*4, 32)
        x = self.conv_dec3(x)                   # (batch, reduced_seq_length*4, input_dim)
        reconstruction = self.crop(x)           # (batch, sequence_length, input_dim)
        return reconstruction

    def call(self, inputs, training=False):
        """Return the reconstruction and register the weighted KL divergence as a loss."""
        z_mean, z_log_var, z = self.encode(inputs, training=training)
        reconstruction = self.decode(z, training=training)
        # KL divergence to a standard normal: summed over latent dims, averaged over the batch.
        kl_loss = -0.5 * tf.reduce_mean(
            tf.reduce_sum(1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var), axis=1)
        )
        self.add_loss(self.kl_weight * kl_loss)
        return reconstruction

    def get_config(self):
        """Return the constructor arguments so the model can be re-created on load."""
        config = super(ConvVAE, self).get_config()
        config.update({
            'sequence_length': self.sequence_length,
            'input_dim': self.input_dim,
            'latent_dim': self.latent_dim,
            'kl_weight': self.kl_weight
        })
        return config


In [8]:
# --- Training Setup ---
if __name__ == "__main__":

    # Seed the Python, NumPy and TensorFlow RNGs so weight initialisation,
    # batch shuffling and latent sampling are repeatable between runs
    # (as far as the backend's ops are deterministic).
    tf.keras.utils.set_random_seed(42)

    # Reshape flattened data to (batch, sequence_length, input_dim).
    sequence_length = 263
    input_dim = 21
    # Pass the explicit row count instead of -1: if the flattened width is not
    # sequence_length * input_dim, reshape then raises rather than silently
    # re-chunking rows into misaligned sequences.
    X_train_2D = X_train.reshape(len(X_train), sequence_length, input_dim)
    X_val_2D   = X_val.reshape(len(X_val), sequence_length, input_dim)
    
    # Instantiate and compile the ConvVAE.
    conv_vae = ConvVAE(sequence_length=sequence_length, input_dim=input_dim, latent_dim=32)
    conv_vae.compile(optimizer=tf.keras.optimizers.Adam(1e-4), loss='categorical_crossentropy')
    
    # Build model with a dummy input to initialize layers.
    dummy_input = np.zeros((1, sequence_length, input_dim), dtype=np.float32)
    _ = conv_vae(dummy_input)
    conv_vae.summary()
    
    # Define callbacks: stop after 5 epochs without val_loss improvement (restoring
    # the best weights) and keep the best model seen so far on disk.
    callbacks = [
        EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True),
        ModelCheckpoint('best_conv_vae.keras', monitor='val_loss', save_best_only=True)
    ]
    
    # Train the ConvVAE model as an autoencoder (inputs are also the targets).
    history = conv_vae.fit(
        X_train_2D, X_train_2D,
        validation_data=(X_val_2D, X_val_2D),
        epochs=100,
        batch_size=128,
        callbacks=callbacks,
        shuffle=True
    )
    
    # Save final model in native Keras format.
    conv_vae.save('final_conv_vae.keras')

Epoch 1/100
[1m1090/1090[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m90s[0m 78ms/step - loss: 2.7755 - val_loss: 2.2832
Epoch 2/100
[1m1090/1090[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m76s[0m 70ms/step - loss: 2.2349 - val_loss: 2.1561
Epoch 3/100
[1m1090/1090[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m81s[0m 75ms/step - loss: 2.1432 - val_loss: 2.1167
Epoch 4/100
[1m1090/1090[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m91s[0m 84ms/step - loss: 2.1098 - val_loss: 2.0973
Epoch 5/100
[1m1090/1090[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m90s[0m 82ms/step - loss: 2.0929 - val_loss: 2.0864
Epoch 6/100
[1m1090/1090[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m92s[0m 84ms/step - loss: 2.0833 - val_loss: 2.0795
Epoch 7/100
[1m1090/1090[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m88s[0m 80ms/step - loss: 2.0762 - val_loss: 2.0746
Epoch 8/100
[1m1090/1090[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m91s[0m 84ms/step - loss: 2.0712 - val_loss: 2.0710
