<a href="https://colab.research.google.com/github/Chrisvanhoorn/BioAI/blob/main/DCNN_jax.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from google.colab import drive

# Mount Google Drive; an EBUSY error just means it is already mounted.
try:
    drive.mount("/content/drive")
except Exception as exc:
    if "EBUSY" not in str(exc):
        # Anything other than "already mounted" is a genuine failure
        raise
    print("Drive already mounted")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import os
from typing import Any

import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import optax
import pandas as pd
from flax.training import train_state
from jax.profiler import start_trace, stop_trace
from optax import adam
from sklearn.utils.class_weight import compute_class_weight

In [None]:
data_path = '/content/drive/My Drive/ColabNotebooks/random_split'


# combine separate CSVs per folder
def read_data(split, data_folder=data_path):
    """Read every file in ``<data_folder>/<split>`` as CSV and concatenate them.

    Files are read in sorted name order so the resulting row order (and hence any
    seeded shuffle downstream) is reproducible; ``os.listdir`` order is arbitrary.
    Each shard restarts its own index, so the result gets a fresh 0..N-1 index
    instead of repeated labels. Sub-directories are skipped.

    Args:
        split: sub-folder name, e.g. 'train', 'dev' or 'test'.
        data_folder: root folder containing the split sub-folders.

    Returns:
        One DataFrame with the rows of all files.

    Raises:
        FileNotFoundError: if the split folder is missing or contains no files.
    """
    split_dir = os.path.join(data_folder, split)
    filenames = sorted(
        name for name in os.listdir(split_dir)
        if os.path.isfile(os.path.join(split_dir, name))
    )
    if not filenames:
        raise FileNotFoundError(f"No data files found in {split_dir!r}")
    frames = [pd.read_csv(os.path.join(split_dir, name), index_col=None) for name in filenames]
    return pd.concat(frames, ignore_index=True)


# Load the three predefined splits, one folder each (read in train, dev, test order)
df_train, df_dev, df_test = (read_data(split) for split in ('train', 'dev', 'test'))

def remove_duplicate_sequences(df_train, df_dev, df_test):
    """Deduplicate the 'sequence' column within and across the three splits.

    Each split keeps the first occurrence of every sequence. Dev then drops any
    sequence present in train, and test drops any sequence present in train or in
    the (already filtered) dev split. Original index labels are preserved.

    Returns:
        Tuple ``(train, dev, test)`` of filtered DataFrames.
    """
    column = 'sequence'

    def dedupe(frame):
        # First occurrence wins inside a single split
        return frame.drop_duplicates(subset=column, keep='first')

    def exclude(frame, reference):
        # Drop rows whose sequence already appears in `reference`
        return frame[~frame[column].isin(reference[column])]

    train_clean = dedupe(df_train)
    dev_clean = exclude(dedupe(df_dev), train_clean)
    test_clean = exclude(exclude(dedupe(df_test), train_clean), dev_clean)
    return train_clean, dev_clean, test_clean

# Deduplicate within and across the splits, then rebind the split names to the cleaned frames
df_train_unique, df_dev_unique, df_test_unique = remove_duplicate_sequences(df_train, df_dev, df_test)
df_train, df_dev, df_test = df_train_unique, df_dev_unique, df_test_unique

# Preview the cleaned training frame
df_train.head()

Unnamed: 0,family_id,sequence_name,family_accession,aligned_sequence,sequence
0,MORN_2,Q8EI47_SHEON/428-449,PF07661.13,LHGEFRNQTSSGQLLELI.NFNH,LHGEFRNQTSSGQLLELINFNH
1,Plexin_cytopl,H2TB23_TAKRU/1240-1793,PF08337.12,.MPFLDYKTYTDCNFFLPSKDGAND......AMITRKLQIPE.......,MPFLDYKTYTDCNFFLPSKDGANDAMITRKLQIPEARRAIVAQALN...
2,RT_RNaseH,H3H8E9_PHYRM/405-501,PF17917.1,DYSRRFHVFADAS.GH.QIGGVIVQ........................,DYSRRFHVFADASGHQIGGVIVQGRRILACFSRSMTDTQKKYSTME...
3,Transposase_20,Q981X5_RHILO/224-313,PF02371.16,VEAYQAMRGASFLVAVIFAAEI.GDV.RR.FDTPPQLMAFLGLVPG...,VEAYQAMRGASFLVAVIFAAEIGDVRRFDTPPQLMAFLGLVPGERS...
4,Mycobact_memb,MMPS4_MYCLE/16-154,PF05423.13,LSRIWIPLVILVVLVVGGFVVYRVHSYFASEKRESYADSNLGSSKP...,LSRIWIPLVILVVLVVGGFVVYRVHSYFASEKRESYADSNLGSSKP...


In [None]:
class ProteinHelper:
    """Encodes protein sequences / family labels and serves padded mini-batches.

    Keeps shallow copies of the ``sequence`` and ``family_accession`` columns of the
    three splits. Builds a residue vocabulary from the training sequences, with
    ``_PAD_`` at index 0, and a label vocabulary from the training families.

    Args:
        df_train, df_dev, df_test: DataFrames with ``sequence`` and
            ``family_accession`` columns.
        batch_size: default batch size for :meth:`data_generator`.
        shuffle: default shuffle flag for :meth:`data_generator`.
        pad: accepted for backward compatibility. Batches are always padded,
            because they must be rectangular arrays.
        max_seq_len: fixed encoded length. ``None`` means the longest sequence
            found across all three splits.
    """

    PAD_TOKEN = '_PAD_'

    def __init__(self, df_train, df_dev, df_test, batch_size=32, shuffle=True, pad=True, max_seq_len=None):
        # Keep only the two columns we need, to limit memory use
        columns = ['sequence', 'family_accession']
        self.df_train = df_train.loc[:, columns].copy(deep=False)
        self.df_dev = df_dev.loc[:, columns].copy(deep=False)
        self.df_test = df_test.loc[:, columns].copy(deep=False)

        self.original_index_train = self.df_train.index
        self.original_index_dev = self.df_dev.index
        self.original_index_test = self.df_test.index

        # Honour an explicit max_seq_len; otherwise use the longest sequence of any split
        if max_seq_len is None:
            max_seq_len = max(len(str(seq)) for df in (df_train, df_dev, df_test) for seq in df['sequence'])
        self.max_seq_len = max_seq_len

        # One vocabulary entry per distinct residue character (not per occurrence),
        # collected without building one giant concatenated string.
        residues = set()
        for seq in self.df_train['sequence']:
            residues.update(seq)
        self.vocab = [self.PAD_TOKEN] + sorted(residues)

        # PAD is index 0, which is also the fallback index for characters unseen in training
        self.char2idx = {char: idx for idx, char in enumerate(self.vocab)}
        self.idx2char = dict(enumerate(self.vocab))

        # Data generator defaults
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.num_classes = len(set(self.df_train['family_accession']))

        # Encode the labels of all splits (sequences are encoded lazily per batch)
        self.prepare_data()

    def _pad_sequence(self, seq, pad=True, max_seq_len=None):
        """Return ``seq`` as a list of characters of length exactly ``max_seq_len``.

        Sequences longer than ``max_seq_len`` are truncated at the end. Shorter ones
        are centre-padded with ``_PAD_``; any odd leftover pad token goes on the
        right. With ``pad=False`` the sequence is only truncated.

        Args:
            seq: the residue string.
            pad: whether to pad short sequences.
            max_seq_len: target length; defaults to ``self.max_seq_len``.
        """
        if max_seq_len is None:
            max_seq_len = self.max_seq_len

        tokens = list(seq)[:max_seq_len]
        if pad:
            missing = max_seq_len - len(tokens)
            left = missing // 2
            tokens = [self.PAD_TOKEN] * left + tokens + [self.PAD_TOKEN] * (missing - left)
        return tokens

    def data_generator(self, df, batch_size=None, shuffle=None, rng_key=None):
        """Yield ``(batch_x, batch_y, batch_sample_weights)`` mini-batches from ``df``.

        batch_x: int32 array (batch, max_seq_len) of residue indices (0 = PAD/unknown).
        batch_y: int32 array (batch,) of label indices from ``label2idx``.
        batch_sample_weights: float numpy array (batch,) of 'balanced' class weights
            computed within the batch, n_samples / (n_classes * count). When the
            batch holds a single class, every weight is 1.0.

        ``batch_size``/``shuffle`` default to the constructor values. ``rng_key``
        (a JAX PRNG key) is only used, and then required, when shuffling.
        Raises KeyError if ``df`` contains a family not seen in training.
        """
        if batch_size is None:
            batch_size = self.batch_size
        if shuffle is None:
            shuffle = self.shuffle

        num_samples = len(df)
        order = jnp.arange(num_samples)
        if shuffle:
            order = jax.random.permutation(rng_key, order, independent=True)
        # Move the ordering to host memory once instead of converting it every batch
        order = np.asarray(order)

        sequences = df['sequence'].to_numpy()
        families = df['family_accession'].to_numpy()

        for start in range(0, num_samples, batch_size):
            batch_indices = order[start:start + batch_size]

            batch_x = jnp.array(
                [[self.char2idx.get(char, 0) for char in self._pad_sequence(seq)]
                 for seq in sequences[batch_indices]],
                dtype=jnp.int32,
            )

            batch_y_np = np.array([self.label2idx[fam] for fam in families[batch_indices]], dtype=np.int32)

            # 'balanced' weighting (same formula as sklearn's compute_class_weight)
            classes, inverse, counts = np.unique(batch_y_np, return_inverse=True, return_counts=True)
            batch_sample_weights = len(batch_y_np) / (len(classes) * counts[inverse])

            yield batch_x, jnp.asarray(batch_y_np), batch_sample_weights

    def prepare_data(self, pad=True):
        """Build the label vocabulary from training families and encode every split.

        Sets ``unique_labels``, ``label2idx`` and ``y_train``/``y_dev``/``y_test``
        (int32 arrays). ``pad`` is accepted for backward compatibility and unused.
        Raises KeyError if dev/test contain a family absent from training.
        """
        self.unique_labels = sorted(set(self.df_train['family_accession']))
        self.label2idx = {label: i for i, label in enumerate(self.unique_labels)}
        self.y_train = self._encode_labels(self.df_train)
        self.y_dev = self._encode_labels(self.df_dev)
        self.y_test = self._encode_labels(self.df_test)

    def _encode_labels(self, df):
        """Map ``df['family_accession']`` to an int32 array of label indices."""
        return np.array([self.label2idx[label] for label in df['family_accession']], dtype=np.int32)


In [None]:
class ProteinDilatedCNN(nn.Module):
    """Dilated 1-D CNN that classifies integer-encoded protein sequences.

    Input ``x``: integer token ids, presumably (batch, seq_len) as produced by
    ``ProteinHelper.data_generator``. Output: logits of shape (batch, num_classes).
    Each of the five conv blocks halves the length via max-pooling, so seq_len
    should be at least 32. Otherwise the global average pooling sees an empty axis.

    Attributes:
        vocab_size: number of token ids the embedding table holds.
        embedding_dim: embedding width.
        num_classes: number of output logits.
        dropout_rate: dropout rate after the hidden dense layer (default 0.5).
    """
    vocab_size: int
    embedding_dim: int
    num_classes: int
    dropout_rate: float = 0.5

    @nn.compact
    def __call__(self, x, training: bool = False):
        """Forward pass. ``training`` selects batch statistics and enables dropout."""

        def dilated_conv(x, features, kernel_size, dilation_rate, name):
            # A size-k kernel with dilation d spans (k - 1) * d + 1 positions, so
            # symmetric padding of (k - 1) * d // 2 keeps the sequence length
            # unchanged (exactly, since (k - 1) * d is even for every block below).
            padding = (kernel_size - 1) * dilation_rate // 2
            x = nn.Conv(features=features, kernel_size=(kernel_size,),
                        kernel_dilation=(dilation_rate,),
                        padding=((padding, padding),), name=name)(x)
            x = nn.BatchNorm(name=f"batchnorm_{name}", axis=-1, momentum=0.9, epsilon=1e-5,
                             use_running_average=not training)(x)
            return nn.relu(x)

        # Embedding layer
        x = nn.Embed(num_embeddings=self.vocab_size, features=self.embedding_dim)(x)

        # (features, dilation) per block; layer names conv1d_1..conv1d_5 stay stable
        conv_blocks = ((128, 2), (256, 2), (512, 4), (1024, 8), (2048, 16))
        for block_idx, (features, dilation_rate) in enumerate(conv_blocks, start=1):
            x = dilated_conv(x, features=features, kernel_size=5, dilation_rate=dilation_rate,
                             name=f"conv1d_{block_idx}")
            x = nn.max_pool(x, window_shape=(2,), strides=(2,))

        # Global average pooling over the sequence axis
        x = jnp.mean(x, axis=1)

        # Fully connected head. The ReLU keeps dense_1 from collapsing into dense_2
        # as a single linear map.
        x = nn.Dense(features=1024, name="dense_1")(x)
        x = nn.relu(x)
        x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not training)
        x = nn.Dense(features=self.num_classes, name="dense_2")(x)

        return x

# Single jitted optimisation step
@jax.jit
def train_step(state, batch, dropout_rng):
    """Run one gradient step on ``batch``.

    Args:
        state: a Flax ``TrainState``. If it has a ``batch_stats`` field, the
            BatchNorm running statistics are read from it and the updated ones
            are written back.
        batch: ``(x, y, sample_weights)``; ``y`` holds integer class labels.
            Sample weights are currently unused.
        dropout_rng: PRNG key for dropout in this step.

    Returns:
        ``(new_state, loss)`` with the mean softmax cross-entropy of the batch.
    """
    x, y, _ = batch
    batch_stats = getattr(state, 'batch_stats', None)

    def loss_fn(params):
        variables = {'params': params}
        if batch_stats is not None:
            variables['batch_stats'] = batch_stats
        # With mutable=[...] apply returns (outputs, updated_collections)
        logits, updates = state.apply_fn(variables, x, training=True,
                                         rngs={'dropout': dropout_rng}, mutable=['batch_stats'])
        # Integer labels directly; no one-hot or global num_classes needed
        loss = jnp.mean(optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=y))
        return loss, updates

    (loss, updates), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
    state = state.apply_gradients(grads=grads)
    if batch_stats is not None:
        state = state.replace(batch_stats=updates['batch_stats'])
    return state, loss

# Single jitted evaluation step
@jax.jit
def eval_step(state, batch):
    """Compute mean loss and accuracy of ``state`` on ``batch`` in inference mode.

    Args:
        state: a Flax ``TrainState``. Its ``batch_stats`` field, when present,
            supplies the BatchNorm running averages that inference mode needs.
        batch: ``(x, y, sample_weights)``; ``y`` holds integer class labels.

    Returns:
        ``(loss, accuracy)`` as scalar arrays.
    """
    x, y, _ = batch
    variables = {'params': state.params}
    batch_stats = getattr(state, 'batch_stats', None)
    if batch_stats is not None:
        variables['batch_stats'] = batch_stats
    logits = state.apply_fn(variables, x, training=False)
    # Labels are integer indices, so use the integer-label variant of the loss
    loss = jnp.mean(optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=y))
    accuracy = jnp.mean(jnp.argmax(logits, -1) == y)
    return loss, accuracy

In [None]:
# Hyperparameters shared by the data helper, model and optimiser cells
batch_size, embedding_dim = 32, 100
learning_rate = 1e-3  # Adam step size
num_epochs = 10

In [None]:
# Profile construction of the data helper; the trace is stopped even if construction fails
output_dir = '/content/drive/My Drive/ColabNotebooks/data_helper'
os.makedirs(output_dir, exist_ok=True)
start_trace(output_dir)
try:
    # Data Helper Initialization
    data_helper = ProteinHelper(df_train, df_dev, df_test, batch_size=batch_size)
finally:
    stop_trace()  # Stop profiling

In [None]:
#stop_trace()  # Stop profiling

In [None]:
# Model initialisation and train state


class TrainStateWithBatchStats(train_state.TrainState):
    """TrainState that also carries the BatchNorm running statistics."""
    batch_stats: Any


rng = jax.random.PRNGKey(0)
# Dedicated key for parameter init, so `rng` is still unused when the training loop splits it
rng, init_rng = jax.random.split(rng)

# Size the embedding from the helper's token mapping (assumes indices 0..len-1) instead of hard-coding
model = ProteinDilatedCNN(vocab_size=len(data_helper.char2idx), embedding_dim=embedding_dim,
                          num_classes=data_helper.num_classes)

# Initialise in inference mode: this creates both 'params' and 'batch_stats' and needs no dropout key
initial_params = model.init(init_rng, jnp.ones((1, data_helper.max_seq_len), jnp.int32), training=False)
params = initial_params['params']

tx = adam(learning_rate=learning_rate)

state = TrainStateWithBatchStats.create(apply_fn=model.apply, params=params, tx=tx,
                                        batch_stats=initial_params['batch_stats'])

In [None]:
# Main Training Loop
def evaluate(state, df):
    """Return ``(mean loss, accuracy)`` of ``state`` over every row of ``df``.

    Per-batch metrics from ``eval_step`` are weighted by batch size, so a short
    final batch does not skew the averages. An empty ``df`` yields ``(nan, nan)``.
    """
    loss_sum, correct_sum, seen = 0.0, 0.0, 0
    for batch in data_helper.data_generator(df, batch_size, shuffle=False, rng_key=None):
        batch_loss, batch_accuracy = eval_step(state, batch)
        n = len(batch[1])
        loss_sum += float(batch_loss) * n
        correct_sum += float(batch_accuracy) * n
        seen += n
    if seen == 0:
        return float('nan'), float('nan')
    return loss_sum / seen, correct_sum / seen


for epoch in range(num_epochs):
    print(f"\n\n--- Starting Epoch {epoch+1} ---\n")

    train_losses = []
    # Independent keys for this epoch's shuffle and dropout; never reuse a key that was already consumed
    rng, shuffle_rng, epoch_dropout_rng = jax.random.split(rng, 3)
    train_generator = data_helper.data_generator(data_helper.df_train, batch_size, shuffle=True, rng_key=shuffle_rng)

    for batch_idx, batch in enumerate(train_generator):
        # Distinct key per step, otherwise every batch reuses the same dropout mask
        dropout_rng = jax.random.fold_in(epoch_dropout_rng, batch_idx)
        state, loss = train_step(state, batch, dropout_rng)
        train_losses.append(loss)

        if batch_idx % 50 == 0:
            print(f"  Batch {batch_idx + 1}, Current Loss: {loss}")

    avg_train_loss = jnp.mean(jnp.array(train_losses))
    print(f"\nEpoch {epoch + 1} Summary:")
    print(f"  Average Training Loss: {avg_train_loss}")

    # Evaluation on Validation Set
    avg_val_loss, avg_val_accuracy = evaluate(state, data_helper.df_dev)
    print(f"Epoch {epoch + 1}/{num_epochs}, Avg Eval Loss: {avg_val_loss}, Avg Eval Accuracy: {avg_val_accuracy}")

# Testing
print("\n--- Final Evaluation on Test Set ---\n")
avg_test_loss, avg_test_accuracy = evaluate(state, data_helper.df_test)
print(f"Test Loss: {avg_test_loss}")
print(f"Test Accuracy: {avg_test_accuracy}")




--- Starting Epoch 1 ---



ValueError: invalid shape in fixed-type tuple.

In [None]:
# Debugging toggles (IPython magics). `%pdb on` enables automatic post-mortem
# debugging on uncaught exceptions; this cell leaves it switched off.
# The ipdb install line is kept commented out.
#!pip install ipdb
#%pdb on
%pdb off