In [63]:
# testing the subtract model
import tensorflow as tf
from tensorflow import keras
import numpy as np
from tensorflow.keras import layers
import matplotlib.pyplot as plt


In [64]:
# generate dummy data


# generate a data that is pep one hot encoding with N=14, d=21. mhc_emb with N= 312 and d =1152. a contact map that is a soft weights over 312 mhc positions

# Create dummy data
num_samples = 1000

# Peptide One-Hot Encoding (N=14, d=21)
pep_len = 14
pep_alphabet_size = 21
pep_indices = np.random.randint(0, pep_alphabet_size, size=(num_samples, pep_len))
pep_OHE = tf.one_hot(pep_indices, depth=pep_alphabet_size, dtype=tf.float32)

# MHC Embedding (N=312, d=1152)
mhc_len = 312
mhc_embedding_dim = 1152
mhc_emb = tf.random.normal(shape=(num_samples, mhc_len, mhc_embedding_dim))

# Contact Map (soft weights over 312 MHC positions for each of the 14 peptide positions)
contact_map_logits = tf.random.normal(shape=(num_samples, pep_len, mhc_len))
contact_map = tf.nn.softmax(contact_map_logits, axis=-1)

# Define pad and mask tokens
pad_token = -2.0
mask_token = -1.0

# Generate peptide mask
# Let's assume a variable number of padded/masked tokens per sample
pep_mask = np.ones((num_samples, pep_len), dtype=np.float32)
for i in range(num_samples):
    # Add 0 to 3 padded tokens at the end
    num_padded = np.random.randint(0, 4)
    if num_padded > 0:
        pep_mask[i, -num_padded:] = pad_token
    # Add 0 to 2 masked tokens at random positions that are not padded
    num_masked = np.random.randint(0, 3)
    valid_indices = np.where(pep_mask[i] == 1.0)[0]
    if num_masked > 0 and len(valid_indices) >= num_masked:
        mask_indices = np.random.choice(valid_indices, num_masked, replace=False)
        pep_mask[i, mask_indices] = mask_token

pep_mask = tf.constant(pep_mask, dtype=tf.float32)

# Generate MHC mask (similar logic)
mhc_mask = np.ones((num_samples, mhc_len), dtype=np.float32)
for i in range(num_samples):
    # Add 0 to 10 padded tokens
    num_padded = np.random.randint(0, 11)
    if num_padded > 0:
        mhc_mask[i, -num_padded:] = pad_token
mhc_mask = tf.constant(mhc_mask, dtype=tf.float32)


# Zero out masked positions in the data
# Create boolean masks where True means the position is valid (not padded or masked)
pep_bool_mask = tf.cast(pep_mask > 0, dtype=tf.float32)
mhc_bool_mask = tf.cast(mhc_mask > 0, dtype=tf.float32)

# Apply masks to zero out data. Unsqueeze to allow broadcasting.
pep_OHE = pep_OHE * pep_bool_mask[:, :, tf.newaxis]
mhc_emb = mhc_emb * mhc_bool_mask[:, :, tf.newaxis]



In [65]:

class MaskedEmbedding(keras.layers.Layer):
    def __init__(self, mask_token=-1., pad_token=-2., name='masked_embedding'):
        super().__init__(name=name)
        self.mask_token = mask_token
        self.pad_token = pad_token

    def call(self, x, mask):
        """
        Args:
            x: Input tensor of shape (B, N, D)
            mask: Tensor of shape (B, N)
        Returns:
            Tensor with masked positions set to zero.
        """
        mask = tf.cast(mask, tf.float32)
        mask = tf.where((mask == self.pad_token) | (mask == self.mask_token), 0., 1.)
        return x * mask[:, :, tf.newaxis]  # Apply mask to zero out positions


class PositionalEncoding(keras.layers.Layer):
    """
    Sinusoidal Positional Encoding layer that applies encodings
    only to non-masked tokens.

    Args:
        embed_dim (int): Dimension of embeddings (must match input last dim).
        max_len (int): Maximum sequence length expected (used to precompute encodings).
    """

    def __init__(self, embed_dim, pos_range=100, mask_token=-1., pad_token=-2., name='positional_encoding'):
        super().__init__(name=name)
        self.embed_dim = embed_dim
        self.pos_range = pos_range
        self.mask_token = mask_token
        self.pad_token = pad_token

    def build(self, x):
        # Create (1, pos_range, embed_dim) encoding matrix
        pos = tf.range(self.pos_range, dtype=tf.float32)[:, tf.newaxis]  # (pos_range, 1)
        i = tf.range(self.embed_dim, dtype=tf.float32)[tf.newaxis, :]  # (1, embed_dim)
        #angle_rates = 1 / tf.pow(300.0, (2 * (i // 2)) / tf.cast(self.embed_dim, tf.float32))
        angle_rates = tf.pow(300.0, -(2 * tf.floor(i / 2)) / tf.cast(self.embed_dim, tf.float32))
        angle_rads = pos * angle_rates  # (pos_range, embed_dim)

        # Apply sin to even indices, cos to odd indices
        sines = tf.sin(angle_rads[:, 0::2])
        cosines = tf.cos(angle_rads[:, 1::2])

        pos_encoding = tf.concat([sines, cosines], axis=-1)  # (max_len, embed_dim)
        pos_encoding = pos_encoding[tf.newaxis, ...]  # (1, max_len, embed_dim)
        self.pos_encoding = tf.cast(pos_encoding, dtype=tf.float32)

    def call(self, x, mask):
        """
        Args:
            x: Input tensor of shape (B, N, D)
            mask: Tensor of shape (B,N)
        Returns:
            Tensor with positional encodings added for masked and non padded tokens.
        """
        seq_len = tf.shape(x)[1]
        pe = self.pos_encoding[:, :seq_len, :]  # (1, N, D)
        mask = tf.cast(mask[:, :, tf.newaxis], tf.float32)  # (B, N, 1)
        mask = tf.where(mask == self.pad_token, 0., 1.)
        pe = pe * mask  # zero out positions where mask is 0

        return x + pe

class SubtractLayer(keras.layers.Layer):
    """
    Custom layer to subtract a tensor from another tensor.
    Tensor1: (B, P, D) -> (B, P*D) -> (B, M, P*D)
    Tensor2: (B, M, D) -> (B, M, P*D)
    Output: = Tensor2 - Tensor1
    """
    def __init__(self, mask_token=-1., pad_token=-2., **kwargs):
        """Initialize the layer."""
        super(SubtractLayer, self).__init__(**kwargs)
        self.mask_token = mask_token
        self.pad_token = pad_token

    def call(self, peptide, pep_mask, mhc, mhc_mask):
        B = tf.shape(peptide)[0]
        P = tf.shape(peptide)[1]
        D = tf.shape(peptide)[2]
        M = tf.shape(mhc)[1]
        P_D = P * D

        pep_mask = tf.cast(pep_mask, tf.float32)
        mhc_mask = tf.cast(mhc_mask, tf.float32)

        pep_mask = tf.where(pep_mask == self.pad_token, x=0., y=1.)  # (B,P)
        mhc_mask = tf.where(mhc_mask == self.pad_token, x=0., y=1.)

        # peptide  (B,P,D) -> (B,P*D) -> (B,M,P*D)
        peptide_flat = tf.reshape(peptide, (B, P_D))
        peptide_exp = tf.repeat(peptide_flat[:, tf.newaxis, :], repeats=M, axis=1)
        # mhc       (B,M,D) -> tile last axis P times -> (B,M,P*D)
        mhc_exp = tf.tile(mhc, [1, 1, P])
        result = mhc_exp - peptide_exp  # (B,M,P*D)
        # peptide mask  (B,P) -> (B,P,D) -> flatten -> (B,P*D) -> (B,M,P*D)
        pep_mask_PD = tf.tile(pep_mask[:, :, tf.newaxis], [1, 1, D])  # (B,P,D)
        pep_mask_PD = tf.reshape(pep_mask_PD, (B, P_D))  # (B,P*D)
        pep_mask_PD = tf.repeat(pep_mask_PD[:, tf.newaxis, :], repeats=M, axis=1)  # (B,M,P*D)
        # mhc mask      (B,M) -> (B,M,1) -> repeat P*D along last axis
        mhc_mask_PD = tf.repeat(mhc_mask[:, :, tf.newaxis], repeats=P_D, axis=2)  # (B,M,P*D)
        combined_mask = tf.logical_and(tf.cast(pep_mask_PD, tf.bool), tf.cast(mhc_mask_PD, tf.bool))
        masked_result = tf.where(combined_mask, result, tf.zeros_like(result))
        return masked_result

class AddGaussianNoise(layers.Layer):
    def __init__(self, std=0.1, **kw): super().__init__(**kw); self.std = std

    def call(self, x, training=None):
        if training: return x + tf.random.normal(tf.shape(x), stddev=self.std)
        return x

class AttentionLayer(keras.layers.Layer):
    """
    Custom multi-head attention layer supporting self- and cross-attention.

    Args:
        query_dim (int): Input feature dimension for query.
        context_dim (int): Input feature dimension for context (key and value).
        output_dim (int): Output feature dimension.
        type (str): 'self' or 'cross'.
        heads (int): Number of attention heads.
        resnet (bool): Whether to use residual connection.
        return_att_weights (bool): Whether to return attention weights.
        name (str): Layer name.
        epsilon (float): Epsilon for layer normalization.
        gate (bool): Whether to use gating mechanism.
        mask_token (float): Value for masked tokens.
        pad_token (float): Value for padded tokens.
    """
    def __init__(self, query_dim, context_dim, output_dim, type, heads=4,
                 resnet=True, return_att_weights=False, name='attention',
                 epsilon=1e-6, gate=True, mask_token=-1., pad_token=-2.):
        super().__init__(name=name)
        assert isinstance(query_dim, int) and isinstance(context_dim, int) and isinstance(output_dim, int)
        assert type in ['self', 'cross']
        if resnet:
            assert query_dim == output_dim
        self.query_dim = query_dim
        self.context_dim = context_dim
        self.output_dim = output_dim
        self.type = type
        self.heads = heads
        self.resnet = resnet
        self.return_att_weights = return_att_weights
        self.epsilon = epsilon
        self.gate = gate
        self.mask_token = mask_token
        self.pad_token = pad_token
        self.att_dim = output_dim // heads  # Attention dimension per head

    def build(self, input_shape):
        # Projection weights
        self.q_proj = self.add_weight(shape=(self.heads, self.query_dim, self.att_dim),
                                      initializer='random_normal', trainable=True, name=f'q_proj_{self.name}')
        self.k_proj = self.add_weight(shape=(self.heads, self.context_dim, self.att_dim),
                                      initializer='random_normal', trainable=True, name=f'k_proj_{self.name}')
        self.v_proj = self.add_weight(shape=(self.heads, self.context_dim, self.att_dim),
                                      initializer='random_normal', trainable=True, name=f'v_proj_{self.name}')
        if self.gate:
            self.g = self.add_weight(shape=(self.heads, self.query_dim, self.att_dim),
                                     initializer='random_uniform', trainable=True, name=f'gate_{self.name}')
        self.norm = layers.LayerNormalization(epsilon=self.epsilon, name=f'ln_{self.name}')
        if self.type == 'cross':
            self.norm_context = layers.LayerNormalization(epsilon=self.epsilon, name=f'ln_context_{self.name}')
        self.norm_out = layers.LayerNormalization(epsilon=self.epsilon, name=f'ln_out_{self.name}')
        if self.resnet:
            self.norm_resnet = layers.LayerNormalization(epsilon=self.epsilon, name=f'ln_resnet_{self.name}')
        self.out_w = self.add_weight(shape=(self.heads * self.att_dim, self.output_dim),
                                     initializer='random_normal', trainable=True, name=f'outw_{self.name}')
        self.out_b = self.add_weight(shape=(self.output_dim,), initializer='zeros',
                                     trainable=True, name=f'outb_{self.name}')
        self.scale = 1.0 / tf.math.sqrt(tf.cast(self.att_dim, tf.float32))

    def call(self, x, mask, context=None, context_mask=None):
        """
        Args:
            x: Tensor of shape (B, N, query_dim) for query.
            mask: Tensor of shape (B, N).
            context: Tensor of shape (B, M, context_dim) for key/value in cross-attention.
            context_mask: Tensor of shape (B, M) for context.
        """
        mask = tf.cast(mask, tf.float32)
        if self.type == 'self':
            context = x
            context_mask = mask
            q_input = k_input = v_input = self.norm(x)
            mask_q = mask_k = tf.where(mask == self.pad_token, 0., 1.)
        else:
            assert context is not None and context_mask is not None
            q_input = self.norm(x)
            k_input = v_input = self.norm_context(context)
            mask_q = tf.where(mask == self.pad_token, 0., 1.)
            mask_k = tf.where(context_mask == self.pad_token, 0., 1.)

        # Project query, key, value
        q = tf.einsum('bnd,hde->bhne', q_input, self.q_proj)
        k = tf.einsum('bmd,hde->bhme', k_input, self.k_proj)
        v = tf.einsum('bmd,hde->bhme', v_input, self.v_proj)

        # Compute attention scores
        att = tf.einsum('bhne,bhme->bhnm', q, k) * self.scale
        mask_q_exp = tf.expand_dims(mask_q, axis=1)
        mask_k_exp = tf.expand_dims(mask_k, axis=1)
        attention_mask = tf.einsum('bqn,bkm->bqnm', mask_q_exp, mask_k_exp)
        attention_mask = tf.broadcast_to(attention_mask, tf.shape(att))
        att += (1.0 - attention_mask) * -1e9
        att = tf.nn.softmax(att, axis=-1) * attention_mask

        # Compute output
        out = tf.einsum('bhnm,bhme->bhne', att, v)
        if self.gate:
            g = tf.einsum('bnd,hde->bhne', q_input, self.g)
            g = tf.nn.sigmoid(g)
            out *= g

        out = tf.transpose(out, [0, 2, 1, 3])
        out = tf.reshape(out, [tf.shape(x)[0], tf.shape(x)[1], self.heads * self.att_dim])
        out = tf.matmul(out, self.out_w) + self.out_b

        if self.resnet:
            out += x
            out = self.norm_resnet(out)
        out = self.norm_out(out)
        mask_exp = tf.expand_dims(mask_q, axis=-1)
        out *= mask_exp

        return (out, att) if self.return_att_weights else out

In [89]:
class Expert(layers.Layer):
    """A binary prediction expert with added complexity."""

    def __init__(self, input_dim, hidden_dim, output_dim=1, dropout_rate=0.2):
        super().__init__()
        self.fc1 = layers.Dense(hidden_dim, activation='relu', input_shape=(input_dim,))
        self.dropout1 = layers.Dropout(dropout_rate)
        self.fc2 = layers.Dense(hidden_dim // 2, activation='relu')
        self.dropout2 = layers.Dropout(dropout_rate)
        self.fc3 = layers.Dense(output_dim)

    def call(self, x, training=False):
        x = self.fc1(x)
        x = self.dropout1(x, training=training)
        x = self.fc2(x)
        x = self.dropout2(x, training=training)
        x = self.fc3(x)
        return x

class EnhancedMixtureOfExperts(layers.Layer):
    """
    Enhanced Mixture of Experts layer that uses cluster assignments.

    This implementation eliminates the need for a SparseDispatcher by:
    - During training: Using hard clustering to train specific experts
    - During inference: Using soft clustering to mix the experts' weights
    """

    def __init__(self, input_dim, hidden_dim, num_experts, output_dim=1, dropout_rate=0.2):
        super().__init__()
        self.num_experts = num_experts
        self.output_dim = output_dim

        # Create n experts
        self.experts = [
            Expert(input_dim, hidden_dim, output_dim, dropout_rate)
            for _ in range(num_experts)
        ]

    def convert_to_hard_clustering(self, soft_clusters):
        """Convert soft clustering values to hard clustering (one-hot encoding)"""
        # Get the index of the maximum value for each sample
        hard_indices = tf.argmax(soft_clusters, axis=1)
        # Convert to one-hot encoding
        return tf.one_hot(hard_indices, depth=self.num_experts)

    def call(self, inputs, training=False):
        # Unpack inputs
        if isinstance(inputs, tuple) and len(inputs) == 2:
            x, soft_cluster_probs = inputs
        else:
            raise ValueError("Inputs must include both features and clustering values")

        batch_size = tf.shape(x)[0]

        # Convert to hard clustering during training if requested
        if training:
            clustering = self.convert_to_hard_clustering(soft_cluster_probs)
        else:
            clustering = soft_cluster_probs

        # Initialize output tensor
        combined_output = tf.zeros([batch_size, self.output_dim])

        # Process each expert
        for i, expert in enumerate(self.experts):
            # Get the weight for this expert for each sample in the batch
            expert_weights = clustering[:, i:i + 1]  # Shape: [batch_size, 1]

            # Only compute outputs for samples with non-zero weights
            # to save computation during training with hard clustering
            if training:
                # Find samples assigned to this expert
                assigned_indices = tf.where(expert_weights[:, 0] > 0)[:, 0]

                def expert_computation():
                    """The computation to run if there are assigned samples."""
                    assigned_x = tf.gather(x, assigned_indices)
                    expert_output = expert(assigned_x, training=training)
                    indices = tf.expand_dims(assigned_indices, axis=1)
                    return tf.scatter_nd(indices, expert_output, [batch_size, self.output_dim])

                def no_computation():
                    """Return zeros if no samples are assigned."""
                    return tf.zeros([batch_size, self.output_dim])

                # Use tf.cond to handle control flow in graph mode
                update = tf.cond(
                    tf.size(assigned_indices) > 0,
                    expert_computation,
                    no_computation
                )
                combined_output += update
            else:
                # During inference or when using soft clustering:
                # Compute expert output for all samples
                expert_output = expert(x, training=training)

                # Weight the output by the clustering values
                weighted_output = expert_output * expert_weights

                # Add to combined output
                combined_output += weighted_output

        return combined_output

In [104]:
# def masked_categorical_crossentropy(y_true_y_pred, mask, pad_token=-2.0):
#     """
#     Compute masked categorical cross-entropy loss.
#
#     Args:
#         y_true: True labels (tensor).
#         y_pred: Predicted probabilities (tensor).
#         mask: Mask tensor indicating positions to include in the loss.
#
#     Returns:
#         Mean masked loss (tensor).
#     """
#     y_true, y_pred = tf.split(y_true_y_pred, num_or_size_splits=2, axis=-1)
#     # loss = tf.keras.losses.categorical_crossentropy(y_true, y_pred)
#     loss = -tf.reduce_sum(y_true * tf.math.log(tf.clip_by_value(y_pred, 1e-7, 1.0)), axis=-1) #(B,N)
#     mask = tf.cast(mask, tf.float32)  # Ensure mask is float
#     mask = tf.where(mask == pad_token, 0.0, 1.0)  # Convert pad token to 0.0 and others to 1.0
#     if tf.rank(mask) > tf.rank(loss):
#         mask = tf.squeeze(mask, axis=-1)
#     loss = loss * mask
#     loss = tf.reduce_sum(loss) / tf.reduce_sum(mask)
#     return loss

def masked_categorical_crossentropy(y_true_and_pred, mask, pad_token=-2.0):
    """
    Compute masked categorical cross-entropy loss.

    Args:
        y_true_and_pred: Concatenated tensor of true labels and predictions.
        mask: Mask tensor indicating positions to include in the loss.
        pad_token: Value of the padding token in the mask.

    Returns:
        Mean masked loss (tensor).
    """
    y_true, y_pred = tf.split(y_true_and_pred, num_or_size_splits=2, axis=-1)
    loss = -tf.reduce_sum(y_true * tf.math.log(tf.clip_by_value(y_pred, 1e-9, 1.0)), axis=-1)

    mask = tf.cast(mask, tf.float32)
    mask = tf.where(mask == pad_token, 0.0, 1.0)

    # Squeeze the mask only if its rank is greater than the loss's rank
    # and its last dimension is 1, which is a common scenario for masks.
    if tf.rank(mask) > tf.rank(loss):
        if mask.shape[-1] == 1:
            mask = tf.squeeze(mask, axis=-1)

    loss = loss * mask

    # Avoid division by zero if mask is all zeros
    total_loss = tf.reduce_sum(loss)
    total_mask_elements = tf.reduce_sum(mask)
    mean_loss = tf.math.divide_no_nan(total_loss, total_mask_elements)

    return mean_loss


In [105]:
# model
from tensorflow.keras import layers
MASK_TOKEN = -1.0
PAD_TOKEN = -2.0


def pmbind_subtract_moe_auto(max_pep_len: int,
                               max_mhc_len: int,
                               emb_dim: int = 96,
                               heads: int = 4,
                               noise_std: float = 0.1,
                               num_experts: int = 30,
                               mask_token: float = MASK_TOKEN,
                               pad_token: float = PAD_TOKEN):
    """
    Builds a pMHC autoencoder model with a Mixture-of-Experts (MoE) classifier head.

    This model performs two tasks:
    1. Autoencoding: Reconstructs peptide and MHC sequences from a latent representation.
    2. Classification: Predicts a binary label using an MoE head, where experts are
       selected based on an internally generated clustering of the latent space.
    """
    # -------------------------------------------------------------------
    # INPUTS
    # -------------------------------------------------------------------
    pep_OHE_in = keras.Input((max_pep_len, 21), name="pep_onehot")
    pep_mask_in = keras.Input((max_pep_len,), name="pep_mask")
    mhc_emb_in = keras.Input((max_mhc_len, 1152), name="mhc_emb")
    mhc_mask_in = keras.Input((max_mhc_len,), name="mhc_mask")
    mhc_OHE_in = keras.Input((max_mhc_len, 21), name="mhc_onehot")
    # -------------------------------------------------------------------
    # MASKED  EMBEDDING  +  PE
    # -------------------------------------------------------------------
    pep = MaskedEmbedding(mask_token, pad_token, name="pep_mask2")(pep_OHE_in, pep_mask_in)
    pep = PositionalEncoding(21, int(max_pep_len * 3), name="pep_pos1")(pep, pep_mask_in)
    pep = layers.Dense(emb_dim, name="pep_Dense1")(pep)
    pep = layers.Dropout(0.1, name="pep_Dropout1")(pep)

    mhc = MaskedEmbedding(mask_token, pad_token, name="mhc_mask2")(mhc_emb_in, mhc_mask_in)
    mhc = PositionalEncoding(1152, int(max_mhc_len * 3), name="mhc_pos1")(mhc, mhc_mask_in)
    mhc = layers.Dense(emb_dim, name="mhc_dense1")(mhc)
    mhc = layers.Dropout(0.1, name="mhc_Dropout1")(mhc)
    # -------------------------------------------------------------------
    # Subtract Layer
    # -------------------------------------------------------------------
    mhc_subtracted_p = SubtractLayer(name="pmhc_subtract")(pep, pep_mask_in, mhc, mhc_mask_in) # (B, M, P*D) = mhc_expanded – peptide_expanded
    #tf.print("mhc_subtracted_p shape:", mhc_subtracted_p.shape)
    # -------------------------------------------------------------------
    # Add Gaussian Noise
    # -------------------------------------------------------------------
    mhc_subtracted_p = AddGaussianNoise(noise_std, name="pmhc_gaussian_noise")(mhc_subtracted_p)
    query_dim = int(emb_dim*max_pep_len)
    # # -------------------------------------------------------------------
    # Normal Self-Attention Layer
    # # -------------------------------------------------------------------
    mhc_subtracted_p_attn, mhc_subtracted_p_attn_scores = AttentionLayer(
        query_dim=query_dim, context_dim=query_dim, output_dim=query_dim,
        type="self", heads=heads, resnet=True,
        return_att_weights=True, name='mhc_subtracted_p_attn',
        mask_token=mask_token,
        pad_token=pad_token
    )(mhc_subtracted_p, mhc_mask_in)
    peptide_cross_att, peptide_cross_attn_scores = AttentionLayer(
        query_dim=int(emb_dim), context_dim=query_dim, output_dim=int(emb_dim),
        type="cross", heads=heads, resnet=False,
        return_att_weights=True, name='peptide_cross_att',
        mask_token=mask_token,
        pad_token=pad_token
    )(pep, pep_mask_in, mhc_subtracted_p_attn, mhc_mask_in)

    # --- Encoder ---
    latent_sequence = layers.Dense(emb_dim*max_pep_len * 2, activation='relu', name='latent_mhc_dense1')(mhc_subtracted_p_attn)
    latent_sequence = layers.Dropout(0.2, name='latent_mhc_dropout1')(latent_sequence)
    latent_sequence = layers.Dense(emb_dim, activation='relu', name='cross_latent')(latent_sequence) # Shape: (B, M, D)

    # --- Latent Vector for Clustering (pooled) ---
    latent_vector = layers.GlobalAveragePooling1D(name="gap_latent")(latent_sequence) # Shape: (B, D)
    latent_vector = layers.Dense(emb_dim * 2, activation='relu', name='latent_dense2')(latent_vector)
    latent_vector = layers.Dropout(0.2, name='latent_vector_dropout')(latent_vector)
    latent_vector = layers.Dense(emb_dim, activation='relu', name='latent_vector_output')(latent_vector) # Shape: (B, D)

    # --- Reconstruction Heads ---
    mhc_recon_head = layers.Dropout(0.2, name='latent_mhc_dropout2')(latent_sequence)
    mhc_recon = layers.Dense(21, activation='softmax', name='mhc_reconstruction_pred')(mhc_recon_head)
    pep_recon = layers.Dense(emb_dim, activation='relu', name='pep_latent')(peptide_cross_att)
    pep_recon = layers.Dense(21, activation='softmax', name='pep_reconstruction_pred')(pep_recon)

    pep_out = layers.Concatenate(name='pep_ytrue_ypred', axis=-1)([pep_OHE_in, pep_recon]) #(B,P,42)
    mhc_out = layers.Concatenate(name='mhc_ytrue_ypred', axis=-1)([mhc_OHE_in, mhc_recon]) #(B,M,42)

    # -------------------------------------------------------------------
    # CLASSIFIER HEAD (MIXTURE OF EXPERTS)
    # -------------------------------------------------------------------
    # 1. Gating network: Generate soft cluster assignments from the latent vector
    bigger_probs = layers.Dense(num_experts * 2, activation='relu', name='gating_network_dense1')(latent_vector)
    bigger_probs = layers.Dropout(0.2, name='gating_network_dropout1')(bigger_probs)
    soft_cluster_probs = layers.Dense(num_experts, activation='softmax', name='gating_network_softmax')(bigger_probs)

    # 2. MoE layer: Get weighted prediction from experts
    moe_layer = EnhancedMixtureOfExperts(
        input_dim=emb_dim,
        hidden_dim=emb_dim // 2,
        num_experts=num_experts,
        output_dim=1,
        dropout_rate=0.2
    )
    y_pred = moe_layer((latent_vector, soft_cluster_probs))
    y_pred = layers.Activation('sigmoid', name='cls_ypred')(y_pred)

    # -------------------------------------------------------------------
    # MODEL DEFINITION
    # -------------------------------------------------------------------
    model = keras.Model(
        inputs=[pep_OHE_in, pep_mask_in, mhc_emb_in, mhc_mask_in, mhc_OHE_in],
        outputs={
            "pep_ytrue_ypred": pep_out,
            "mhc_ytrue_ypred": mhc_out,
            "cls_ypred": y_pred,
        },
        name="pmbind_subtract_moe_autoencoder"
    )

    return model

In [106]:
# ==============================================================================
# 3. DUMMY DATA GENERATION (for a runnable example)
# ==============================================================================
num_samples = 1000
pep_len = 14
mhc_len = 312
pep_alphabet_size = 21
mhc_embedding_dim = 1152

# Peptide One-Hot Encoding
pep_indices = np.random.randint(0, pep_alphabet_size, size=(num_samples, pep_len))
pep_OHE = tf.one_hot(pep_indices, depth=pep_alphabet_size, dtype=tf.float32)

# MHC Embedding
mhc_emb = tf.random.normal(shape=(num_samples, mhc_len, mhc_embedding_dim))
mhc_indices = np.random.randint(0, pep_alphabet_size, size=(num_samples, mhc_len))
mhc_OHE = tf.one_hot(mhc_indices, depth=pep_alphabet_size, dtype=tf.float32)


# Peptide mask
# Define pad and mask tokens
pad_token = -2.0
mask_token = -1.0

# Generate peptide mask
# Let's assume a variable number of padded/masked tokens per sample
pep_mask = np.ones((num_samples, pep_len), dtype=np.float32)
for i in range(num_samples):
    # Add 0 to 3 padded tokens at the end
    num_padded = np.random.randint(0, 4)
    if num_padded > 0:
        pep_mask[i, -num_padded:] = pad_token
    # Add 0 to 2 masked tokens at random positions that are not padded
    num_masked = np.random.randint(0, 3)
    valid_indices = np.where(pep_mask[i] == 1.0)[0]
    if num_masked > 0 and len(valid_indices) >= num_masked:
        mask_indices = np.random.choice(valid_indices, num_masked, replace=False)
        pep_mask[i, mask_indices] = mask_token

pep_mask = tf.constant(pep_mask, dtype=tf.float32)

# Generate MHC mask (similar logic)
mhc_mask = np.ones((num_samples, mhc_len), dtype=np.float32)
for i in range(num_samples):
    # Add 0 to 10 padded tokens
    num_padded = np.random.randint(0, 11)
    if num_padded > 0:
        mhc_mask[i, -num_padded:] = pad_token
mhc_mask = tf.constant(mhc_mask, dtype=tf.float32)


# Zero out masked positions in the data
# Create boolean masks where True means the position is valid (not padded or masked)
pep_bool_mask = tf.cast(pep_mask > 0, dtype=tf.float32)
mhc_bool_mask = tf.cast(mhc_mask > 0, dtype=tf.float32)

# Apply masks to zero out data. Unsqueeze to allow broadcasting.
pep_OHE = pep_OHE * pep_bool_mask[:, :, tf.newaxis]
mhc_emb = mhc_emb * mhc_bool_mask[:, :, tf.newaxis]

# Classification labels
y_true = tf.cast(tf.random.uniform((num_samples, 1)) > 0.5, tf.float32)

# Group inputs and outputs into dictionaries for tf.data.Dataset
inputs = {
    "pep_onehot": pep_OHE,
    "pep_mask": pep_mask,
    "mhc_emb": mhc_emb,
    "mhc_mask": mhc_mask,
    "mhc_onehot": mhc_OHE
}
targets = {
    "pep_ytrue_ypred": pep_OHE,
    "mhc_ytrue_ypred": mhc_OHE,
    "cls_ypred": y_true
}

In [109]:

# ==============================================================================
# 4. CUSTOM TRAINING LOOP SETUP
# ==============================================================================
# --- Hyperparameters and Setup ---
epochs = 10
batch_size = 32
learning_rate = 1e-3
emb_dim = 96
heads = 4
noise_std = 0.1
num_experts = 8

# --- Instantiate Model, Optimizer, and Loss ---
model = pmbind_subtract_moe_auto(
    max_pep_len=pep_len,
    max_mhc_len=mhc_len,
    emb_dim=emb_dim,
    heads=heads,
    noise_std=noise_std,
    num_experts=num_experts,
    mask_token=MASK_TOKEN,
    pad_token=PAD_TOKEN
)
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
binary_loss_fn = tf.keras.losses.BinaryCrossentropy()

# compile the model
# model.compile(
#     optimizer=optimizer,
#     loss={
#         "pep_ytrue_ypred": masked_categorical_crossentropy,
#         "mhc_ytrue_ypred": masked_categorical_crossentropy,
#         "cls_ypred": binary_loss_fn
#     },
# )

# --- Metrics for Tracking ---
metrics_names = ['loss', 'pep_recon_loss', 'mhc_recon_loss', 'class_loss', 'auc']
train_metrics = {name: tf.keras.metrics.Mean(name=f"train_{name}") for name in metrics_names}
val_metrics = {name: tf.keras.metrics.Mean(name=f"val_{name}") for name in metrics_names}
train_metrics['auc'] = tf.keras.metrics.AUC(name='train_auc')
val_metrics['auc'] = tf.keras.metrics.AUC(name='val_auc')

# --- Prepare tf.data.Dataset for efficient training ---
dataset = tf.data.Dataset.from_tensor_slices((inputs, targets)).shuffle(num_samples)
val_size = int(0.2 * num_samples)
train_dataset = dataset.skip(val_size).batch(batch_size, drop_remainder=True)
val_dataset = dataset.take(val_size).batch(batch_size, drop_remainder=True)

# ==============================================================================
# 5. TRAIN AND VALIDATION STEPS (using tf.function for performance)
# ==============================================================================
@tf.function
def train_step(x_batch, y_batch):
    x_batch_list = [x_batch['pep_onehot'], x_batch['pep_mask'], x_batch['mhc_emb'], x_batch['mhc_mask'], x_batch['mhc_onehot']]

    with tf.GradientTape() as tape:
        predictions = model(x_batch_list, training=True)
        # Calculate individual losses
        pep_loss = masked_categorical_crossentropy(predictions['pep_ytrue_ypred'], x_batch['pep_mask'], PAD_TOKEN)
        mhc_loss = masked_categorical_crossentropy(predictions['mhc_ytrue_ypred'], x_batch['mhc_mask'], PAD_TOKEN)
        class_loss = binary_loss_fn(y_batch['cls_ypred'], predictions['cls_ypred'])

        # Combine losses (you can apply weights here, e.g., total_loss = 0.5*pep_loss + ...)
        total_loss = pep_loss + mhc_loss + class_loss

    # Apply gradients
    grads = tape.gradient(total_loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))

    # Update metrics
    train_metrics['loss'](total_loss)
    train_metrics['pep_recon_loss'](pep_loss)
    train_metrics['mhc_recon_loss'](mhc_loss)
    train_metrics['class_loss'](class_loss)
    train_metrics['auc'](y_batch['cls_ypred'], predictions['cls_ypred'])

@tf.function
def val_step(x_batch, y_batch):
    x_batch_list = [x_batch['pep_onehot'], x_batch['pep_mask'], x_batch['mhc_emb'], x_batch['mhc_mask'], x_batch['mhc_onehot']]

    predictions = model(x_batch_list, training=False)
    # Calculate losses
    pep_loss = masked_categorical_crossentropy(predictions['pep_ytrue_ypred'], x_batch['pep_mask'], PAD_TOKEN)
    mhc_loss = masked_categorical_crossentropy(predictions['mhc_ytrue_ypred'], x_batch['mhc_mask'], PAD_TOKEN)
    class_loss = binary_loss_fn(y_batch['cls_ypred'], predictions['cls_ypred'])
    total_loss = pep_loss + mhc_loss + class_loss

    # Update metrics
    val_metrics['loss'](total_loss)
    val_metrics['pep_recon_loss'](pep_loss)
    val_metrics['mhc_recon_loss'](mhc_loss)
    val_metrics['class_loss'](class_loss)
    val_metrics['auc'](y_batch['cls_ypred'], predictions['cls_ypred'])

# ==============================================================================
# 6. THE MAIN TRAINING LOOP
# ==============================================================================
history = {f"{key}": [] for key in train_metrics.keys()}
history.update({f"val_{key}": [] for key in val_metrics.keys()})

print("Starting training...")
for epoch in range(epochs):
    # Reset metrics at the start of each epoch
    for metric in train_metrics.values(): metric.reset_state()
    for metric in val_metrics.values(): metric.reset_state()

    # Training loop
    for x_batch, y_batch in train_dataset:
        train_step(x_batch, y_batch)

    # Validation loop
    for x_val_batch, y_val_batch in val_dataset:
        val_step(x_val_batch, y_val_batch)

    # Log results
    train_results = {key: value.result().numpy() for key, value in train_metrics.items()}
    val_results = {key: value.result().numpy() for key, value in val_metrics.items()}

    # Store history
    for key, value in train_results.items(): history[key].append(value)
    for key, value in val_results.items(): history[f"val_{key}"].append(value)

    print(f"Epoch {epoch+1}/{epochs} - "
          f"Loss: {train_results['loss']:.4f} - "
          f"AUC: {train_results['auc']:.4f} - "
          f"Val Loss: {val_results['loss']:.4f} - "
          f"Val AUC: {val_results['auc']:.4f}")

print("Training finished.")

# ==============================================================================
# 7. VISUALIZATION
# ==============================================================================
plt.style.use('seaborn-v0_8-whitegrid')
plt.figure(figsize=(20, 5))

# Plot total loss
plt.subplot(1, 4, 1)
plt.plot(history['loss'], label='Training Loss')
plt.plot(history['val_loss'], label='Validation Loss')
plt.title('Total Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

# Plot peptide reconstruction loss
plt.subplot(1, 4, 2)
plt.plot(history['pep_recon_loss'], label='Peptide Recon Loss')
plt.plot(history['val_pep_recon_loss'], label='Val Peptide Recon Loss')
plt.title('Peptide Reconstruction Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

# Plot MHC reconstruction loss
plt.subplot(1, 4, 3)
plt.plot(history['mhc_recon_loss'], label='MHC Recon Loss')
plt.plot(history['val_mhc_recon_loss'], label='Val MHC Recon Loss')
plt.title('MHC Reconstruction Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

# Plot AUC
plt.subplot(1, 4, 4)
plt.plot(history['auc'], label='Training AUC')
plt.plot(history['val_auc'], label='Validation AUC')
plt.title('Classification AUC')
plt.xlabel('Epoch')
plt.ylabel('AUC')
plt.legend()

plt.tight_layout()
plt.show()

Starting training...


2025-08-19 13:28:14.233361: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2025-08-19 13:28:19.190200: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 1/10 - Loss: 6.5618 - AUC: 0.5175 - Val Loss: 6.2262 - Val AUC: 0.5000


2025-08-19 13:29:21.681634: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 2/10 - Loss: 5.3287 - AUC: 0.5093 - Val Loss: 4.1891 - Val AUC: 0.5000


KeyboardInterrupt: 