In [7]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Constants
# Mask value marking a padded position: compared against `pep_mask` to decide
# whether a core window is valid, and written into the per-core masks for the
# positions beyond the core's length.
PAD_TOKEN = -2.0
# Mask value written into the per-core masks for positions inside the core.
NORM_TOKEN = 1.0
# Fill value used when right-padding a core's embeddings up to max_pep_len.
PAD_VALUE = 0.0


def likelihood_with_attention_vectorized(nn_model, peptide, mhc, label,
                                         pep_mask, mhc_mask, max_pep_len,
                                         temperature=1.0, variable_len=8,
                                         sample_weight=None, training=True):
    """
    Compute the marginal binding NLL by marginalizing over latent cores.

    Every contiguous window of every admissible length is cut out of the
    peptide, right-padded to ``max_pep_len``, and scored by ``nn_model`` in a
    single flattened batch of size ``B * N``.

    Probabilistic Model:
        - z_i ∈ C_i: latent variable for true binding core
        - q_ic = P(z_i = c | S_i, MHC) = softmax(g(c, S_i, MHC) / τ)
        - P(bind | z_i = c, MHC) = σ(f(c, MHC))
        - P_ij = Σ_c q_ic · σ(f(c, MHC))  [marginalization]

    Model outputs expected:
        - 'core_logits': g(c, S_i, MHC) - core selection scores, (B*N, 1)
        - 'binding_logits': f(c, MHC) - binding prediction scores, (B*N, 1)

    Args:
        nn_model: Keras model called as
            ``nn_model([cores, core_masks, mhc, mhc_mask, dummy_target])``
            and returning a dict with the two heads above.
        peptide: (B, P, D) peptide embeddings; P is expected to equal
            ``max_pep_len``.
        mhc: (B, M, D) MHC embeddings
        label: (B, 1) binary binding labels
        pep_mask: (B, P) peptide mask; positions equal to ``PAD_TOKEN`` are
            treated as padding.
        mhc_mask: (B, M) MHC mask
        max_pep_len: int, maximum peptide length
        temperature: positive float, softmax temperature for core selection
        variable_len: int minimum core length (cores of every length from
            ``variable_len`` to ``max_pep_len`` are enumerated), or ``None``
            to enumerate fixed-length 9-mer cores only.
        sample_weight: (B, 1) or (B,) optional sample weights
        training: bool, training mode flag forwarded to ``nn_model``

    Returns:
        loss: scalar, mean (over the batch) negative log-likelihood
        attention_weights: (B, N) core selection probabilities q_ic
        binding_probs_per_core: (B, N) per-core binding probabilities σ(f_c),
            zeroed for invalid cores
        cores_stack: (B, N, P, D) padded core embeddings

        If no core length can be enumerated at all (e.g. ``variable_len`` is
        larger than ``max_pep_len``), the model is not called and the result
        is ``(0.0, zeros(B, 0), zeros(B, 0), zeros(B, 0, max_pep_len, D))``.

    Raises:
        ValueError: if ``temperature`` is a Python number that is not
            strictly positive.

    Note:
        A row with no valid core at all gets a uniform ``attention_weights``
        row (softmax over identical masked logits) and a marginal probability
        of 0, which is then clipped to 1e-7 before the log.
    """
    # A zero temperature divides by zero and a negative one inverts the core
    # ranking; reject both early when the value is a plain Python number.
    if isinstance(temperature, (int, float)) and temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")

    B = tf.shape(peptide)[0]
    D = tf.shape(peptide)[2]
    M = tf.shape(mhc)[1]

    # Define core lengths: [variable_len, variable_len+1, ..., max_pep_len]
    if variable_len is None:
        ks = [9]
    else:
        ks = range(variable_len, max_pep_len + 1)

    cores_list = []
    masks_list = []
    valid_flags_list = []

    # Extract all possible cores
    for k in ks:
        num_windows = max_pep_len - k + 1
        # Skip lengths that do not fit, and non-positive lengths: an empty
        # core would make the validity lookup below wrap around to index -1.
        if k < 1 or num_windows <= 0:
            continue

        # Padding spec and mask depend only on the core length k, so build
        # them once per length instead of once per window offset.
        pad_len = max_pep_len - k
        paddings = [[0, 0], [0, pad_len], [0, 0]]
        mask_vec = tf.concat([
            tf.fill([k], NORM_TOKEN),
            tf.fill([pad_len], PAD_TOKEN)
        ], axis=0)
        mask_batch = tf.tile(tf.expand_dims(mask_vec, 0), [B, 1])  # (B, max_pep_len)

        for i in range(num_windows):
            # Extract core: (B, k, D)
            core = peptide[:, i:i + k, :]

            # Check validity: core is valid if its last position is not
            # padding (relies on padding being a contiguous right-hand tail).
            is_valid = tf.not_equal(pep_mask[:, i + k - 1], PAD_TOKEN)
            valid_flags_list.append(is_valid)

            # Pad core to max_pep_len: (B, max_pep_len, D)
            padded_core = tf.pad(core, paddings, "CONSTANT", constant_values=PAD_VALUE)
            cores_list.append(padded_core)
            masks_list.append(mask_batch)

    # Handle edge case: no core could be enumerated
    if not cores_list:
        return (tf.constant(0.0),
                tf.zeros((B, 0)),
                tf.zeros((B, 0)),
                tf.zeros((B, 0, max_pep_len, D)))

    # Stack all cores: (B, N, ...)
    cores_stack = tf.stack(cores_list, axis=1)  # (B, N, P, D)
    masks_stack = tf.stack(masks_list, axis=1)  # (B, N, P)
    valid_flags_stack = tf.stack(valid_flags_list, axis=1)  # (B, N)

    N = tf.shape(cores_stack)[1]

    # Flatten for batch processing through model
    cores_flat = tf.reshape(cores_stack, [B * N, max_pep_len, D])
    masks_flat = tf.reshape(masks_stack, [B * N, max_pep_len])

    # Tile MHC for each core
    mhc_tiled = tf.tile(tf.expand_dims(mhc, 1), [1, N, 1, 1])
    mhc_flat = tf.reshape(mhc_tiled, [B * N, M, D])

    mhc_mask_tiled = tf.tile(tf.expand_dims(mhc_mask, 1), [1, N, 1])
    mhc_mask_flat = tf.reshape(mhc_mask_tiled, [B * N, M])

    # Dummy target (if model requires it).
    # NOTE(review): the channel size 21 is hard-coded; presumably the model's
    # target alphabet size - confirm against the model definition.
    dummy_target = tf.zeros((B * N, max_pep_len, 21), dtype=tf.float32)

    # Forward pass
    inputs = [cores_flat, masks_flat, mhc_flat, mhc_mask_flat, dummy_target]
    outputs = nn_model(inputs, training=training)

    # Extract model outputs
    core_logits_flat = outputs['core_logits']  # g(c, S_i, MHC): (B*N, 1)
    binding_logits_flat = outputs['binding_logits']  # f(c, MHC): (B*N, 1)

    # Reshape to (B, N)
    core_logits = tf.reshape(core_logits_flat, [B, N])
    binding_logits = tf.reshape(binding_logits_flat, [B, N])

    # =========================================================
    # Core selection distribution: q_ic = softmax(g_c / τ)
    # =========================================================
    # Scale by temperature FIRST, then mask: masking before the division
    # would shrink the -1e9 sentinel by 1/τ and let invalid cores leak
    # probability mass at large temperatures.
    scaled_core_logits = core_logits / temperature
    core_logits_masked = tf.where(valid_flags_stack, scaled_core_logits, -1e9)
    attention_weights = tf.nn.softmax(core_logits_masked, axis=1)  # (B, N)

    # =========================================================
    # Per-core binding probability: σ(f(c, MHC))
    # =========================================================
    binding_probs_per_core = tf.nn.sigmoid(binding_logits)  # (B, N)
    # Mask invalid cores to 0 for clean output; cast the flags to the model's
    # output dtype so the product does not depend on it being float32.
    valid_flags_float = tf.cast(valid_flags_stack, binding_probs_per_core.dtype)
    binding_probs_per_core = binding_probs_per_core * valid_flags_float

    # =========================================================
    # Marginal binding probability: P_ij = Σ_c q_ic · σ(f_c)
    # =========================================================
    pred_prob = tf.reduce_sum(attention_weights * binding_probs_per_core, axis=1)  # (B,)

    # =========================================================
    # Negative log-likelihood loss
    # =========================================================
    label_flat = tf.cast(tf.reshape(label, [B]), dtype=pred_prob.dtype)

    # Clip for numerical stability (keeps both logs finite)
    pred_prob_clipped = tf.clip_by_value(pred_prob, 1e-7, 1.0 - 1e-7)

    # Binary cross-entropy: -[y·log(p) + (1-y)·log(1-p)]
    loss = (-label_flat * tf.math.log(pred_prob_clipped)
            - (1.0 - label_flat) * tf.math.log(1.0 - pred_prob_clipped))

    # Apply sample weights if provided
    if sample_weight is not None:
        sample_weight = tf.cast(sample_weight, loss.dtype)
        # Accept (B, 1) as well as (B,): drop the trailing axis so the
        # product stays (B,) instead of broadcasting to (B, B).
        if len(sample_weight.shape) == 2:
            sample_weight = tf.squeeze(sample_weight, axis=1)
        loss = loss * sample_weight

    return tf.reduce_mean(loss), attention_weights, binding_probs_per_core, cores_stack

# Test suite
def run_tests():
    """Run smoke tests on ``likelihood_with_attention_vectorized``.

    Builds a tiny two-headed Keras model, calls the likelihood function under
    several configurations (default cores, longer peptides, sample weights,
    different temperatures, gradient tape) and verifies basic invariants of
    the returned loss, attention weights, and binding probabilities.

    Every check is printed with a ✓/✗ marker and recorded; the final
    "ALL TESTS PASSED" banner is printed only when none of them failed.

    Raises:
        AssertionError: if at least one check failed; the message lists the
            failed checks.
    """
    print("="*70)
    print("TESTING LIKELIHOOD FUNCTION")
    print("="*70)

    failures = []
    current_test = [""]  # one-element list so `check` can read the latest label

    def start_test(title):
        """Print a test header and remember it for failure reports."""
        current_test[0] = title
        print(f"\n{title}")

    def check(description, condition):
        """Print one check result and record it if it failed."""
        ok = bool(condition)
        marker = "✓" if ok else "✗"
        print(f"  {marker} {description}: {ok}")
        if not ok:
            failures.append(f"{current_test[0]} {description}")

    # Define a test model with correct output heads
    def create_test_model(max_pep_len, max_mhc_len, embed_dim):
        """Build a minimal model exposing 'core_logits' and 'binding_logits'."""
        peptide_input = keras.Input(shape=(max_pep_len, embed_dim), name='peptide')
        pep_mask_input = keras.Input(shape=(max_pep_len,), name='pep_mask')
        mhc_input = keras.Input(shape=(max_mhc_len, embed_dim), name='mhc')
        mhc_mask_input = keras.Input(shape=(max_mhc_len,), name='mhc_mask')
        dummy_target = keras.Input(shape=(max_pep_len, 21), name='dummy_target')

        pep_avg = layers.GlobalAveragePooling1D()(peptide_input)
        mhc_avg = layers.GlobalAveragePooling1D()(mhc_input)
        combined = layers.Concatenate()([pep_avg, mhc_avg])

        # Use all mask/dummy inputs (multiply by 0 so they don't affect output)
        dummy_effect = layers.Lambda(lambda x: tf.expand_dims(
            tf.reduce_mean(x[0]) * 0.0 + tf.reduce_mean(x[1]) * 0.0 + tf.reduce_mean(x[2]) * 0.0,
            -1
        ))([dummy_target, pep_mask_input, mhc_mask_input])

        # Two output heads as required
        core_logits = layers.Dense(1, name='core_logits')(combined)
        core_logits = layers.Add()([core_logits, dummy_effect])

        binding_logits = layers.Dense(1, name='binding_logits')(combined)

        return keras.Model(
            inputs=[peptide_input, pep_mask_input, mhc_input, mhc_mask_input, dummy_target],
            outputs={'core_logits': core_logits, 'binding_logits': binding_logits}
        )

    batch_size = 4
    max_pep_len = 15
    max_mhc_len = 50
    embed_dim = 128
    model = create_test_model(max_pep_len, max_mhc_len, embed_dim)

    # Test 1: Basic functionality with 9-mers
    start_test("[Test 1] Basic functionality (9-mers)")
    peptide = tf.random.normal((batch_size, max_pep_len, embed_dim))
    mhc = tf.random.normal((batch_size, max_mhc_len, embed_dim))
    label = tf.constant([[1.0], [0.0], [1.0], [0.0]])
    pep_mask = tf.concat([tf.fill([batch_size, 9], NORM_TOKEN),
                          tf.fill([batch_size, max_pep_len - 9], PAD_TOKEN)], axis=1)
    mhc_mask = tf.fill([batch_size, max_mhc_len], NORM_TOKEN)

    loss, attn, bind_probs, cores = likelihood_with_attention_vectorized(
        model, peptide, mhc, label, pep_mask, mhc_mask, max_pep_len, training=False)

    print(f"  Loss: {loss.numpy():.4f}")
    print(f"  Shapes - attn: {attn.shape}, bind_probs: {bind_probs.shape}, cores: {cores.shape}")
    check("Loss is finite", tf.math.is_finite(loss).numpy())
    check("Loss >= 0", loss.numpy() >= 0)

    # Test 2: Attention weights sum to 1
    start_test("[Test 2] Attention weights properties")
    attn_sums = tf.reduce_sum(attn, axis=1).numpy()
    print(f"  Attention sums: {attn_sums}")
    check("Sum to 1", tf.reduce_all(tf.abs(attn_sums - 1.0) < 1e-5).numpy())
    check("All in [0,1]", tf.reduce_all((attn >= 0) & (attn <= 1)).numpy())

    # Test 3: Binding probabilities in [0, 1]
    start_test("[Test 3] Binding probabilities")
    print(f"  Min: {tf.reduce_min(bind_probs).numpy():.4f}, Max: {tf.reduce_max(bind_probs).numpy():.4f}")
    check("All in [0,1]", tf.reduce_all((bind_probs >= 0) & (bind_probs <= 1)).numpy())

    # Test 4: Variable length cores (8-11 mers)
    start_test("[Test 4] Variable length cores (8-11 mers)")
    pep_mask_11 = tf.concat([tf.fill([batch_size, 11], NORM_TOKEN),
                             tf.fill([batch_size, max_pep_len - 11], PAD_TOKEN)], axis=1)
    loss_var, attn_var, _, _ = likelihood_with_attention_vectorized(
        model, peptide, mhc, label, pep_mask_11, mhc_mask, max_pep_len,
        variable_len=8, training=False)

    # One window per start offset for every core length from 8 to max_pep_len;
    # invalid windows are still enumerated, only masked.
    expected_cores = sum(max_pep_len - k + 1 for k in range(8, max_pep_len + 1))
    print(f"  Expected cores: {expected_cores}, Actual: {attn_var.shape[1]}")
    check("Correct number of cores", attn_var.shape[1] == expected_cores)
    check("Attention sums to 1",
          tf.reduce_all(tf.abs(tf.reduce_sum(attn_var, axis=1) - 1.0) < 1e-5).numpy())

    # Test 5: Sample weights
    start_test("[Test 5] Sample weights")
    loss_no_wt, _, _, _ = likelihood_with_attention_vectorized(
        model, peptide, mhc, label, pep_mask, mhc_mask, max_pep_len,
        sample_weight=None, training=False)
    loss_zero_wt, _, _, _ = likelihood_with_attention_vectorized(
        model, peptide, mhc, label, pep_mask, mhc_mask, max_pep_len,
        sample_weight=tf.zeros((batch_size, 1)), training=False)

    print(f"  Loss (no weight): {loss_no_wt.numpy():.4f}")
    print(f"  Loss (zero weight): {loss_zero_wt.numpy():.4f}")
    check("Zero weights give zero loss", abs(loss_zero_wt.numpy()) < 1e-6)

    # Test 6: Temperature effect
    start_test("[Test 6] Temperature effect on attention")
    _, attn_low, _, _ = likelihood_with_attention_vectorized(
        model, peptide, mhc, label, pep_mask, mhc_mask, max_pep_len,
        temperature=0.1, training=False)
    _, attn_high, _, _ = likelihood_with_attention_vectorized(
        model, peptide, mhc, label, pep_mask, mhc_mask, max_pep_len,
        temperature=10.0, training=False)

    max_low = tf.reduce_max(attn_low, axis=1).numpy()
    max_high = tf.reduce_max(attn_high, axis=1).numpy()
    print(f"  Max attention (T=0.1): {max_low}")
    print(f"  Max attention (T=10.0): {max_high}")
    check("Low T more peaked", tf.reduce_all(max_low > max_high).numpy())

    # Test 7: Gradient computation
    start_test("[Test 7] Gradient computation")
    peptide_var = tf.Variable(peptide)
    with tf.GradientTape() as tape:
        loss_grad, _, _, _ = likelihood_with_attention_vectorized(
            model, peptide_var, mhc, label, pep_mask, mhc_mask, max_pep_len, training=True)
    grads = tape.gradient(loss_grad, peptide_var)
    # tape.gradient returns None when the loss does not depend on the
    # variable, so test for that before touching grads.shape.
    if grads is None:
        check("Gradients exist", False)
    else:
        print(f"  Gradient shape: {grads.shape}")
        check("Gradients exist", True)
        check("Gradients finite", tf.reduce_all(tf.math.is_finite(grads)).numpy())

    print("\n" + "="*70)
    if failures:
        print(f"{len(failures)} CHECK(S) FAILED ✗")
        print("="*70)
        raise AssertionError("Failed checks: " + "; ".join(failures))
    print("ALL TESTS PASSED ✓")
    print("="*70)


# Run the test suite only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    run_tests()


TESTING LIKELIHOOD FUNCTION

[Test 1] Basic functionality (9-mers)
  Loss: 0.6567
  Shapes - attn: (4, 36), bind_probs: (4, 36), cores: (4, 36, 15, 128)
  ✓ Loss is finite: True
  ✓ Loss >= 0: True

[Test 2] Attention weights properties
  Attention sums: [1. 1. 1. 1.]
  ✓ Sum to 1: True
  ✓ All in [0,1]: True

[Test 3] Binding probabilities
  Min: 0.0000, Max: 0.6471
  ✓ All in [0,1]: True

[Test 4] Variable length cores (8-11 mers)
  Expected cores: 36, Actual: 36
  ✓ Correct number of cores: True
  ✓ Attention sums to 1: True

[Test 5] Sample weights
  Loss (no weight): 0.6567
  Loss (zero weight): 0.0000
  ✓ Zero weights give zero loss: True

[Test 6] Temperature effect on attention
  Max attention (T=0.1): [0.7064812  0.40670326 0.5280339  0.62597156]
  Max attention (T=10.0): [0.3368348  0.33414948 0.3353532  0.336384  ]
  ✓ Low T more peaked: True

[Test 7] Gradient computation
  Gradient shape: (4, 15, 128)
  ✓ Gradients exist: True
  ✓ Gradients finite: True

ALL TESTS PASSED ✓