# BindHack: Antibody-Antigen Binding Prediction

## The Problem

Predict **binding affinity** between antibodies and antigens from amino acid sequences.

---

**Input:**
- Antibody heavy + light chain sequences
- Antigen sequence(s)

**Output:**
- Binding score (higher = stronger binding)

**Approach:**
- Extract features from sequences
- Train baseline ML model
- Identify improvement opportunities

---

## 1. Setup and Data Loading

In [None]:
import polars as pl
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.linear_model import Ridge
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error
import matplotlib.pyplot as plt
from scipy.stats import spearmanr
from sklearn.ensemble import RandomForestRegressor

# Set random seed for reproducibility
np.random.seed(42)

In [3]:
# Load the training data
# The data has already been prepared from the AbiBench dataset
train_df = pl.read_csv("../data/train_data.csv")

print(f"Training data shape: {train_df.shape}")
print(f"\nColumns: {train_df.columns}")
print("\nFirst few rows:")
print(train_df.head())

Training data shape: (148065, 5)

Columns: ['id', 'heavy_chain_sequence', 'light_chain_sequence', 'antigen_sequences', 'binding_score']

First few rows:
shape: (5, 5)
┌────────────────────┬────────────────────┬────────────────────┬───────────────────┬───────────────┐
│ id                 ┆ heavy_chain_sequen ┆ light_chain_sequen ┆ antigen_sequences ┆ binding_score │
│ ---                ┆ ce                 ┆ ce                 ┆ ---               ┆ ---           │
│ str                ┆ ---                ┆ ---                ┆ str               ┆ f64           │
│                    ┆ str                ┆ str                ┆                   ┆               │
╞════════════════════╪════════════════════╪════════════════════╪═══════════════════╪═══════════════╡
│ 4FQI_HLAB_SH24A,DH ┆ QVQLVQSGAEVKKPGSSV ┆ SALTQPPAVSGTPGQRVT ┆ GLFGAIAGFIEGGWQGM ┆ 6.0           │
│ 46E,SH52I,SH…      ┆ KVSCKASGGTSN…      ┆ ISCSGSDSNIGR…      ┆ VDGWYGYHHSNEQ…    ┆               │
(output truncated)

---

## 2. The Data

**Dataset:** Pre-processed binding assay results from AbiBench (148,065 samples)

**Key Columns:**
- `heavy_chain_sequence` / `light_chain_sequence` - Antibody sequences
- `antigen_sequences` - Target protein(s)
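- `binding_score` - Our prediction target (labeled ΔG; see the sign note below)

**What matters:** Higher binding score = stronger interaction, and binding strength is a key criterion when ranking therapeutic candidates.

**Sign convention:** A raw binding ΔG is more negative for tighter binding, so "higher = stronger" only holds if the score is on a flipped scale (for example −ΔG). Check the dataset's convention before interpreting signs.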

In [None]:
# Basic statistics about the data
print("=== Dataset Statistics ===\n")
print(f"Number of samples: {len(train_df):,}")
print("\nBinding score statistics:")
print(train_df["binding_score"].describe())

# Check for missing values
print("\nMissing values per column:")
print(train_df.null_count())

In [None]:
# Visualize the distribution of binding scores
plt.figure(figsize=(12, 6))
plt.hist(
    train_df["binding_score"].to_numpy(),
    bins=50,
    edgecolor="black",
    alpha=0.7,
    color="steelblue",
)
plt.axvline(
    train_df["binding_score"].mean(),
    color="red",
    linestyle="--",
    linewidth=2,
    label=f"Mean: {train_df['binding_score'].mean():.2f}",
)
plt.xlabel("Binding Score (ΔG)", fontsize=12)
plt.ylabel("Frequency", fontsize=12)
plt.title("Distribution of Binding Scores", fontsize=14, fontweight="bold")
plt.legend(fontsize=11)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

---

## 3. Feature Engineering

**Challenge:** Convert amino acid sequences into numbers

**Our Features:**
- **Composition:** Percentage of each of the 20 standard amino acids (20 features × 3 sequences)
- **Physicochemical:** Hydrophobic, polar, charged, and aromatic percentages (4 features × 3 sequences)

**Result:** 72 features per sample (60 composition + 12 physicochemical)

In [None]:
def get_amino_acid_composition(sequence):
    """
    Calculate the percentage of each of the 20 standard amino acids in a sequence.
    Returns a dictionary mapping each amino acid letter to its percentage (0-100).
    Non-standard letters (e.g. X) count toward the sequence length but not toward
    any amino acid, so the percentages can sum to less than 100.
    """
    standard_aas = "ACDEFGHIKLMNPQRSTVWY"
    if not sequence:
        return {aa: 0 for aa in standard_aas}

    # Count occurrences of each standard amino acid
    aa_count = {aa: 0 for aa in standard_aas}
    for aa in sequence:
        if aa in aa_count:
            aa_count[aa] += 1

    # Convert counts to percentages of the full sequence length
    total = len(sequence)
    return {aa: count / total * 100 for aa, count in aa_count.items()}


def get_physicochemical_features(sequence):
    """
    Calculate basic physicochemical properties of a sequence.
    Returns the percentage of residues falling into each property group.
    """
    # Residue groups (one simple convention). F, W, Y count as both
    # hydrophobic and aromatic, so the percentages can overlap.
    groups = {
        "hydrophobic_percent": set("AILMFWYV"),
        "polar_percent": set("STNQ"),
        "charged_percent": set("DEKR"),  # can we improve this? (mixes + K/R with - D/E)
        "aromatic_percent": set("FWY"),
    }

    if not sequence:
        return {name: 0 for name in groups}

    total = len(sequence)
    return {
        name: sum(1 for aa in sequence if aa in residues) / total * 100
        for name, residues in groups.items()
    }
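The `# can we improve this?` comment on `charged_percent` points at a real limitation: D and E are negatively charged and K and R positively charged at neutral pH, so lumping them together hides the net charge. A minimal sketch of one alternative, assuming a simple count-based charge model (histidine treated as neutral, terminal charges ignored); `get_charge_features` is a hypothetical helper, not part of the notebook:

```python
def get_charge_features(sequence):
    """
    Split charged residues into positive (K, R) and negative (D, E) groups.

    Simplified model (an assumption, not a pKa-based calculation):
    histidine is treated as neutral and terminal charges are ignored.
    """
    if not sequence:
        return {"positive_percent": 0, "negative_percent": 0, "net_charge": 0}

    total = len(sequence)
    positive = sum(1 for aa in sequence if aa in "KR")
    negative = sum(1 for aa in sequence if aa in "DE")

    return {
        "positive_percent": positive / total * 100,
        "negative_percent": negative / total * 100,
        # Net charge as a raw count: +1 per K/R, -1 per D/E
        "net_charge": positive - negative,
    }


# Made-up peptide with 2 positive (K, R) and 3 negative (D, E, E) residues,
# so net_charge is -1
print(get_charge_features("KRDEEAG"))
```

These keys could sit alongside (or replace) `charged_percent` in the feature-extraction loop; the per-sample feature count would change accordingly.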

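### Alternative: ESM2-8M Embeddings

Now let's try a more sophisticated approach using protein language models. The comparison cells in this subsection use the composition-baseline results (`y_val`, `y_val_pred`, `val_r2`, and related metrics) from Sections 4-6, so run those sections first.

**ESM (Evolutionary Scale Modeling):**
- Pre-trained on millions of protein sequences
- Captures evolutionary and structural information
- Strong general-purpose baseline for many protein tasks (larger ESM-2 variants often perform better than the smallest)

We'll use the smallest model, ESM2-8M (`facebook/esm2_t6_8M_UR50D`: 6 layers, 320-dimensional embeddings), for speed.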

In [40]:
import torch
from tqdm.auto import tqdm
from transformers import AutoTokenizer, EsmModel

# Load ESM2-8M model (smallest, fastest)
model_name = "facebook/esm2_t6_8M_UR50D"
print(f"Loading {model_name}...")

tokenizer = AutoTokenizer.from_pretrained(model_name)
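# The "pooler ... newly initialized" warning below is harmless here: we mean-pool
# last_hidden_state and never use the pooler output
esm_model = EsmModel.from_pretrained(model_name)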

# Use GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
esm_model = esm_model.to(device)
esm_model.eval()

print(f"Model loaded on {device}")

Loading facebook/esm2_t6_8M_UR50D...


Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model loaded on cpu


In [None]:
def get_esm_embeddings_batch(sequences, model, tokenizer, device, batch_size=32):
    """
    Extract ESM embeddings for a batch of sequences with progress tracking.
    Returns a list of mean-pooled embeddings.
    """
    all_embeddings = []

    num_batches = (len(sequences) + batch_size - 1) // batch_size

    for i in tqdm(
        range(0, len(sequences), batch_size),
        total=num_batches,
        desc="Processing batches",
    ):
        batch = sequences[i : i + batch_size]

        # Tokenize batch with padding
        inputs = tokenizer(
            batch, return_tensors="pt", padding=True, truncation=True, max_length=1024
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Get embeddings
        with torch.no_grad():
            outputs = model(**inputs)

            # Mean pool over sequence length for each sequence in batch
            for j in range(len(batch)):
                # Get attention mask to exclude padding
                mask = inputs["attention_mask"][j]
                # Exclude the <cls> and <eos> special tokens (first and last real
                # positions; relies on right padding, the tokenizer default)
                seq_len = mask.sum().item()
                emb = outputs.last_hidden_state[j, 1 : seq_len - 1].mean(dim=0)
                all_embeddings.append(emb.cpu().numpy())

    return all_embeddings


# Test on sample sequences
sample_seqs = train_df["heavy_chain_sequence"][:5].to_list()
sample_embs = get_esm_embeddings_batch(
    sample_seqs, esm_model, tokenizer, device, batch_size=2
)
print(f"\nBatch processing works! Embedding dimension: {sample_embs[0].shape[0]}")

Batch processing works! Embedding dimension: 320
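The per-sequence loop in `get_esm_embeddings_batch` is correct for right-padded batches, but the same pooling can be done for the whole batch at once with a mask. A minimal NumPy sketch, assuming (as with the ESM tokenizer here) right padding and exactly one special token at the start (`<cls>`) and one at the end (`<eos>`) of each sequence; `masked_mean_pool` is a hypothetical helper and the arrays are toy data. The same arithmetic carries over to the `torch` tensors in the notebook.

```python
import numpy as np


def masked_mean_pool(hidden, attention_mask):
    """
    Mean-pool token embeddings, skipping padding and the first/last special tokens.

    hidden:         (batch, seq_len, dim) token embeddings
    attention_mask: (batch, seq_len) with 1 for real tokens, 0 for padding
    Assumes right padding, so each row's last real token sits at index length - 1.
    """
    mask = attention_mask.astype(hidden.dtype)  # astype returns a new array
    lengths = attention_mask.sum(axis=1).astype(int)

    # Drop <cls> (position 0) and <eos> (position length - 1) from the average
    mask[:, 0] = 0
    mask[np.arange(len(lengths)), lengths - 1] = 0

    summed = (hidden * mask[:, :, None]).sum(axis=1)
    counts = np.maximum(mask.sum(axis=1, keepdims=True), 1)  # avoid divide-by-zero
    return summed / counts


# Toy batch: 2 sequences, max length 5, embedding dim 2
hidden = np.arange(20, dtype=float).reshape(2, 5, 2)
attention_mask = np.array([[1, 1, 1, 1, 1],   # <cls> r1 r2 r3 <eos>
                           [1, 1, 1, 0, 0]])  # <cls> r1 <eos> pad pad
pooled = masked_mean_pool(hidden, attention_mask)
print(pooled)
```

One difference from the loop: a sequence with no residues between `<cls>` and `<eos>` yields a zero vector here instead of a NaN mean.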


In [None]:
print("Extracting ESM embeddings for all sequences...")
print("Using batched processing for efficiency...\n")

batch_size = 32

# Extract sequences
heavy_seqs = train_df["heavy_chain_sequence"].to_list()
light_seqs = train_df["light_chain_sequence"].to_list()
antigen_seqs = train_df["antigen_sequences"].to_list()

# Process each type of sequence in batches
print("Processing heavy chains:")
heavy_embs = get_esm_embeddings_batch(
    heavy_seqs, esm_model, tokenizer, device, batch_size
)

print("\nProcessing light chains:")
light_embs = get_esm_embeddings_batch(
    light_seqs, esm_model, tokenizer, device, batch_size
)

print("\nProcessing antigens:")
antigen_embs = get_esm_embeddings_batch(
    antigen_seqs, esm_model, tokenizer, device, batch_size
)

# Concatenate embeddings for each sample
print("\nCombining embeddings...")
X_esm = np.array(
    [np.concatenate([h, l, a]) for h, l, a in zip(heavy_embs, light_embs, antigen_embs)]
)
y_esm = train_df["binding_score"].to_numpy()

print(f"\nESM feature matrix shape: {X_esm.shape}")
print(f"Features per sequence: {heavy_embs[0].shape[0]}")
print(f"Total features: {X_esm.shape[1]} (320 × 3 sequences)")

Extracting ESM embeddings for all sequences...
Using batched processing for efficiency...

Processing heavy chains:


Processing light chains:


Processing antigens:


Combining embeddings...

ESM feature matrix shape: (148065, 960)
Features per sequence: 320
Total features: 960 (320 × 3 sequences)
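The extraction cell above embeds every row, including repeats. The `id` values in the first rows (e.g. `4FQI_HLAB_SH24A,...`) suggest many rows are mutants of the same complex, so antigen sequences in particular are likely repeated. A minimal sketch of embedding each unique sequence once and mapping results back, assuming an `embed_fn` that takes a list of sequences and returns one vector per sequence; `embed_unique` and the toy `fake_embed` are hypothetical:

```python
import numpy as np


def embed_unique(sequences, embed_fn):
    """
    Embed each distinct sequence once, then expand back to the original row order.

    sequences: list of strings (may contain duplicates)
    embed_fn:  callable mapping a list of unique strings to a list of 1-D arrays
    Returns a (len(sequences), dim) array.
    """
    # dict.fromkeys keeps first-seen order while dropping duplicates
    unique_seqs = list(dict.fromkeys(sequences))
    unique_embs = embed_fn(unique_seqs)
    lookup = dict(zip(unique_seqs, unique_embs))
    return np.stack([lookup[seq] for seq in sequences])


# Toy embedding that records how many sequences it was asked to embed
calls = []


def fake_embed(seqs):
    calls.append(len(seqs))
    return [np.array([len(s), s.count("A")], dtype=float) for s in seqs]


rows = ["ACDA", "GGGG", "ACDA", "ACDA", "GGGG"]
embs = embed_unique(rows, fake_embed)
print(embs.shape, calls)  # (5, 2) [2]
```

With the notebook's function, a call might look like `embed_unique(antigen_seqs, lambda s: get_esm_embeddings_batch(s, esm_model, tokenizer, device, batch_size))`.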


### Train Model with ESM Features

Let's see if ESM embeddings improve over our simple composition features.

In [None]:
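# Split ESM features with the same test_size and random_state as the composition
# split in Section 4; since X_esm has the same rows in the same order, both models
# are evaluated on the same validation rows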
X_esm_train, X_esm_val, y_esm_train, y_esm_val = train_test_split(
    X_esm, y_esm, test_size=0.2, random_state=42
)

print(f"Training set size: {len(X_esm_train)}")
print(f"Validation set size: {len(X_esm_val)}")

# Train Random Forest with ESM features
print("\nTraining Random Forest with ESM embeddings...")
model_esm = RandomForestRegressor(n_estimators=1000, random_state=42, n_jobs=-1)
model_esm.fit(X_esm_train, y_esm_train)

# Make predictions
y_esm_train_pred = model_esm.predict(X_esm_train)
y_esm_val_pred = model_esm.predict(X_esm_val)

print("Model trained!")

### Compare Results: Composition vs ESM

How much do protein language models help?

In [None]:
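# Calculate ESM model metrics
# (the comparison table also needs val_rmse, val_mae, val_r2, val_spearman
# from the composition baseline in Section 6)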
esm_train_rmse = np.sqrt(mean_squared_error(y_esm_train, y_esm_train_pred))
esm_val_rmse = np.sqrt(mean_squared_error(y_esm_val, y_esm_val_pred))

esm_train_mae = mean_absolute_error(y_esm_train, y_esm_train_pred)
esm_val_mae = mean_absolute_error(y_esm_val, y_esm_val_pred)

esm_train_r2 = r2_score(y_esm_train, y_esm_train_pred)
esm_val_r2 = r2_score(y_esm_val, y_esm_val_pred)

esm_train_spearman = spearmanr(y_esm_train, y_esm_train_pred).correlation
esm_val_spearman = spearmanr(y_esm_val, y_esm_val_pred).correlation

# Comparison table
print("=" * 70)
print(f"{'Metric':<20} {'Composition':<22} {'ESM-8M':<22}")
print("=" * 70)
print(f"{'Validation RMSE':<20} {val_rmse:<22.4f} {esm_val_rmse:<22.4f}")
print(f"{'Validation MAE':<20} {val_mae:<22.4f} {esm_val_mae:<22.4f}")
print(f"{'Validation R²':<20} {val_r2:<22.4f} {esm_val_r2:<22.4f}")
print(f"{'Validation Spearman':<20} {val_spearman:<22.4f} {esm_val_spearman:<22.4f}")
print("=" * 70)

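# Relative improvement of ESM over the composition baseline (positive = ESM is
# better; the R² figure is only meaningful when val_r2 > 0)
r2_improvement = ((esm_val_r2 - val_r2) / val_r2) * 100
print(f"\nR² improvement: {r2_improvement:+.1f}%")
print(f"MAE improvement: {((val_mae - esm_val_mae) / val_mae) * 100:+.1f}%")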

In [None]:
# Visualize ESM predictions
fig, axes = plt.subplots(1, 2, figsize=(15, 6))

# Composition model
axes[0].scatter(
    y_val,
    y_val_pred,
    alpha=0.2,
    s=15,
    color="coral",
    edgecolors="none",
    label="Composition",
)
axes[0].plot([y_val.min(), y_val.max()], [y_val.min(), y_val.max()], "r--", lw=2.5)
axes[0].set_xlabel("Actual Binding Score", fontsize=12)
axes[0].set_ylabel("Predicted Binding Score", fontsize=12)
axes[0].set_title(
    f"Composition Features (R² = {val_r2:.3f})", fontsize=13, fontweight="bold"
)
axes[0].grid(True, alpha=0.3)

# ESM model
axes[1].scatter(
    y_esm_val,
    y_esm_val_pred,
    alpha=0.2,
    s=15,
    color="steelblue",
    edgecolors="none",
    label="ESM-8M",
)
axes[1].plot(
    [y_esm_val.min(), y_esm_val.max()],
    [y_esm_val.min(), y_esm_val.max()],
    "r--",
    lw=2.5,
)
axes[1].set_xlabel("Actual Binding Score", fontsize=12)
axes[1].set_ylabel("Predicted Binding Score", fontsize=12)
axes[1].set_title(
    f"ESM-8M Embeddings (R² = {esm_val_r2:.3f})", fontsize=13, fontweight="bold"
)
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

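### Composition Features

The composition baseline in Sections 4-7 is trained on the feature matrix `X` built here from `get_amino_acid_composition` and `get_physicochemical_features`.

In [None]: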
print("Extracting features from sequences...")
print("This may take a few minutes for the full dataset...")

features_list = []

for row in train_df.iter_rows(named=True):
    heavy_seq = row["heavy_chain_sequence"]
    light_seq = row["light_chain_sequence"]
    antigen_seq = row["antigen_sequences"]

    feature_dict = {}

    # Amino acid composition (20 features × 3 sequences)
    for aa, pct in get_amino_acid_composition(heavy_seq).items():
        feature_dict[f"heavy_{aa}"] = pct
    for aa, pct in get_amino_acid_composition(light_seq).items():
        feature_dict[f"light_{aa}"] = pct
    for aa, pct in get_amino_acid_composition(antigen_seq).items():
        feature_dict[f"antigen_{aa}"] = pct

    # Physicochemical properties (4 features × 3 sequences)
    for prop, val in get_physicochemical_features(heavy_seq).items():
        feature_dict[f"heavy_{prop}"] = val
    for prop, val in get_physicochemical_features(light_seq).items():
        feature_dict[f"light_{prop}"] = val
    for prop, val in get_physicochemical_features(antigen_seq).items():
        feature_dict[f"antigen_{prop}"] = val

    features_list.append(feature_dict)

X = np.array([list(feat.values()) for feat in features_list])
y = train_df["binding_score"].to_numpy()

print(f"\nFeature matrix shape: {X.shape}")
print(f"Target vector shape: {y.shape}")

---

## 4. Train-Validation Split

**80% training** | **20% validation**

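Goal: Estimate generalization to antibody-antigen pairs not seen during training. Caveat: the split is random at the row level, and the `id` values suggest many rows are mutants of the same parent complex, so close variants of validation samples may appear in training and the validation scores may be optimistic.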

In [14]:
# Split the data (80% train, 20% validation)
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)

print(f"Training set size: {len(X_train)}")
print(f"Validation set size: {len(X_val)}")

Training set size: 118452
Validation set size: 29613
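If the `id` values do encode a parent complex followed by a list of mutations (e.g. `4FQI_HLAB_SH24A,DH46E,...`), a group-aware split gives a stricter estimate of generalization to new complexes. A minimal sketch using scikit-learn's `GroupShuffleSplit`, assuming (not confirmed by the dataset documentation) that the text before the first `_` in `id` identifies the parent complex; the `parent_complex` helper and the toy `ids` below are hypothetical:

```python
import numpy as np
from sklearn.model_selection import GroupShuffleSplit


def parent_complex(sample_id):
    # Assumption: "4FQI_HLAB_SH24A,..." -> "4FQI" (prefix before the first "_")
    return sample_id.split("_")[0]


# Toy ids standing in for train_df["id"].to_list()
ids = [
    "4FQI_HLAB_SH24A", "4FQI_HLAB_DH46E", "4FQI_HLAB_SH52I",
    "CPLXA_HL_SH30A", "CPLXA_HL_SL91Y",
    "CPLXB_HL_SH33W", "CPLXB_HL_SH50R",
    "CPLXC_HL_SH99A",
]
groups = np.array([parent_complex(i) for i in ids])
X_toy = np.zeros((len(ids), 1))  # feature values do not affect where the split falls

splitter = GroupShuffleSplit(n_splits=1, test_size=0.25, random_state=42)
train_idx, val_idx = next(splitter.split(X_toy, groups=groups))

# No parent complex appears on both sides of the split
print(set(groups[train_idx]) & set(groups[val_idx]))  # set()
```

Here `test_size` is a fraction of groups, not rows. On the real data the indices would select rows of `X`, `y`, and `X_esm`; validation scores may drop relative to the random split if close variants were leaking across it.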


---

## 5. Baseline Model: Random Forest

**Why Random Forest?**
- Non-linear and insensitive to feature scaling, with built-in feature importances
- 1000 trees for stable predictions

This is our baseline to beat.

In [24]:
print("Training Random Forest model...")

# Train the model
model = RandomForestRegressor(n_estimators=1000, random_state=42, n_jobs=-1)
model.fit(X_train, y_train)

# Make predictions
y_train_pred = model.predict(X_train)
y_val_pred = model.predict(X_val)

print("Model trained successfully!")

Training Random Forest model...
Model trained successfully!
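`Ridge` is imported in the setup cell but never used. A linear model is a cheap reference point: if it comes close to the Random Forest, the non-linearity is not buying much. A minimal sketch on synthetic data standing in for `X_train`/`X_val` (the `alpha=1.0` value is an untuned assumption); features are standardized first because Ridge's penalty is sensitive to feature scale:

```python
import numpy as np
from sklearn.linear_model import Ridge
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Synthetic stand-in for the 72 percentage-valued features and binding scores
rng = np.random.default_rng(0)
X_demo = rng.random((1000, 72)) * 100
weights = np.array([0.05, -0.03, 0.02, 0.04, -0.01])
y_demo = X_demo[:, :5] @ weights + rng.normal(0, 0.5, 1000)

X_tr, X_va, y_tr, y_va = train_test_split(
    X_demo, y_demo, test_size=0.2, random_state=42
)

# Scale features, then fit an L2-regularized linear model
ridge = make_pipeline(StandardScaler(), Ridge(alpha=1.0))
ridge.fit(X_tr, y_tr)
ridge_val_r2 = r2_score(y_va, ridge.predict(X_va))
print(f"Ridge validation R²: {ridge_val_r2:.3f}")
```

On the notebook's data, swap in `X_train`, `y_train`, `X_val`, `y_val` and compare the result against `val_r2`.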


---

## 6. Performance Metrics

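**RMSE:** Root mean squared error (same units as the score; penalizes large errors more)  
**MAE:** Mean absolute error (same units as the score)  
**R²:** Fraction of variance explained (1 = perfect; 0 = no better than predicting the mean)  
**Spearman's ρ:** Rank correlation between predicted and actual scores (what matters when ranking candidates)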
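Writing the four metrics out directly makes their behavior easier to see: RMSE squares errors before averaging, so a single large miss raises it more than MAE; R² compares the squared error with that of always predicting the mean; Spearman's ρ depends only on ranks. A small worked example on made-up numbers, checked against the same scikit-learn and SciPy functions the notebook uses:

```python
import numpy as np
from scipy.stats import spearmanr
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score

y_true = np.array([5.0, 6.0, 7.0, 8.0])
y_pred = np.array([5.5, 6.0, 6.0, 9.0])  # errors: -0.5, 0.0, 1.0, -1.0

errors = y_true - y_pred
rmse = np.sqrt(np.mean(errors**2))             # sqrt((0.25 + 0 + 1 + 1) / 4) = 0.75
mae = np.mean(np.abs(errors))                  # (0.5 + 0 + 1 + 1) / 4 = 0.625
ss_res = np.sum(errors**2)                     # 2.25
ss_tot = np.sum((y_true - y_true.mean())**2)   # 2.25 + 0.25 + 0.25 + 2.25 = 5.0
r2 = 1 - ss_res / ss_tot                       # 1 - 2.25 / 5.0 = 0.55

# Library versions agree with the hand-computed values
assert np.isclose(rmse, np.sqrt(mean_squared_error(y_true, y_pred)))
assert np.isclose(mae, mean_absolute_error(y_true, y_pred))
assert np.isclose(r2, r2_score(y_true, y_pred))

# Spearman: the tied predictions (6.0, 6.0) get averaged ranks
rho = spearmanr(y_true, y_pred).correlation
print(rmse, mae, r2, rho)
```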

In [None]:
# Calculate metrics
train_rmse = np.sqrt(mean_squared_error(y_train, y_train_pred))
val_rmse = np.sqrt(mean_squared_error(y_val, y_val_pred))

train_mae = mean_absolute_error(y_train, y_train_pred)
val_mae = mean_absolute_error(y_val, y_val_pred)

train_r2 = r2_score(y_train, y_train_pred)
val_r2 = r2_score(y_val, y_val_pred)

train_spearman = spearmanr(y_train, y_train_pred).correlation
val_spearman = spearmanr(y_val, y_val_pred).correlation

# Print RMSE, MAE, R², and Spearman's ρ for the training and validation sets
print("Training Set:")
print(f"  RMSE: {train_rmse:.4f}")
print(f"  MAE:  {train_mae:.4f}")
print(f"  R²:   {train_r2:.4f}")
print(f"  Spearman's ρ: {train_spearman:.4f}")
print("\nValidation Set:")
print(f"  RMSE: {val_rmse:.4f}")
print(f"  MAE:  {val_mae:.4f}")
print(f"  R²:   {val_r2:.4f}")
print(f"  Spearman's ρ: {val_spearman:.4f}")

In [None]:
# Visualize predictions vs actual values
fig, axes = plt.subplots(1, 2, figsize=(15, 6))

# Training set
axes[0].scatter(
    y_train, y_train_pred, alpha=0.2, s=15, color="steelblue", edgecolors="none"
)
axes[0].plot(
    [y_train.min(), y_train.max()],
    [y_train.min(), y_train.max()],
    "r--",
    lw=2.5,
    label="Perfect Prediction",
)
axes[0].set_xlabel("Actual Binding Score", fontsize=12)
axes[0].set_ylabel("Predicted Binding Score", fontsize=12)
axes[0].set_title(f"Training Set (R² = {train_r2:.3f})", fontsize=13, fontweight="bold")
axes[0].legend(fontsize=11)
axes[0].grid(True, alpha=0.3)

# Validation set
axes[1].scatter(y_val, y_val_pred, alpha=0.2, s=15, color="coral", edgecolors="none")
axes[1].plot(
    [y_val.min(), y_val.max()],
    [y_val.min(), y_val.max()],
    "r--",
    lw=2.5,
    label="Perfect Prediction",
)
axes[1].set_xlabel("Actual Binding Score", fontsize=12)
axes[1].set_ylabel("Predicted Binding Score", fontsize=12)
axes[1].set_title(f"Validation Set (R² = {val_r2:.3f})", fontsize=13, fontweight="bold")
axes[1].legend(fontsize=11)
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Analyze residuals (prediction errors)
train_residuals = y_train - y_train_pred
val_residuals = y_val - y_val_pred

fig, axes = plt.subplots(1, 2, figsize=(15, 6))

# Training residuals
axes[0].scatter(
    y_train_pred, train_residuals, alpha=0.2, s=15, color="steelblue", edgecolors="none"
)
axes[0].axhline(y=0, color="r", linestyle="--", lw=2.5)
axes[0].set_xlabel("Predicted Binding Score", fontsize=12)
axes[0].set_ylabel("Residual (Actual - Predicted)", fontsize=12)
axes[0].set_title("Training Set Residuals", fontsize=13, fontweight="bold")
axes[0].grid(True, alpha=0.3)

# Validation residuals
axes[1].scatter(
    y_val_pred, val_residuals, alpha=0.2, s=15, color="coral", edgecolors="none"
)
axes[1].axhline(y=0, color="r", linestyle="--", lw=2.5)
axes[1].set_xlabel("Predicted Binding Score", fontsize=12)
axes[1].set_ylabel("Residual (Actual - Predicted)", fontsize=12)
axes[1].set_title("Validation Set Residuals", fontsize=13, fontweight="bold")
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

---

## 7. Feature Importance

Which features drive predictions?
- Guides future feature engineering
- Identifies important amino acids/properties

In [31]:
# Get feature names (same order as the columns of X)
feature_names = list(features_list[0].keys())

# Impurity-based importances from the Random Forest (always non-negative)
importances = model.feature_importances_

# Sort features from most to least important
feature_importance = sorted(
    zip(feature_names, importances), key=lambda x: x[1], reverse=True
)
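`feature_importances_` on a Random Forest is impurity-based: it is computed from the training data and can favor features with many distinct values. Permutation importance on held-out data is a common cross-check. A minimal sketch with scikit-learn's `permutation_importance` on synthetic data where only feature 0 matters (on the notebook's data you would pass `model`, `X_val`, `y_val`; `n_repeats=5` is an arbitrary choice, and with a 1000-tree forest on ~30k validation rows this will be slow):

```python
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.inspection import permutation_importance
from sklearn.model_selection import train_test_split

# Synthetic data: only feature 0 carries signal
rng = np.random.default_rng(0)
X_demo = rng.random((500, 4))
y_demo = 3 * X_demo[:, 0] + rng.normal(0, 0.1, 500)
X_tr, X_va, y_tr, y_va = train_test_split(
    X_demo, y_demo, test_size=0.2, random_state=42
)

rf = RandomForestRegressor(n_estimators=100, random_state=42, n_jobs=-1)
rf.fit(X_tr, y_tr)

# Shuffle one column at a time on held-out data and measure the drop in R²
result = permutation_importance(rf, X_va, y_va, n_repeats=5, random_state=42)
for i in np.argsort(result.importances_mean)[::-1]:
    mean, std = result.importances_mean[i], result.importances_std[i]
    print(f"feature {i}: {mean:.3f} ± {std:.3f}")
```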

In [None]:
# Visualize top features
from matplotlib.patches import Patch

top_n = 20
top_features = feature_importance[:top_n]
names, values = zip(*top_features)

# Color bars by source sequence (feature names start with heavy_, light_, or antigen_)
color_map = {"heavy": "coral", "light": "lightgreen", "antigen": "steelblue"}
colors = [color_map[name.split("_")[0]] for name in names]

plt.figure(figsize=(12, 8))
plt.barh(range(len(names)), values, color=colors)
plt.yticks(range(len(names)), names, fontsize=10)
plt.xlabel("Feature Importance", fontsize=12)
plt.title(f"Top {top_n} Most Important Features", fontsize=13, fontweight="bold")
plt.legend(
    handles=[Patch(color=c, label=k.capitalize()) for k, c in color_map.items()],
    fontsize=10,
)
plt.gca().invert_yaxis()
plt.grid(True, alpha=0.3, axis="x")
plt.tight_layout()
plt.show()

---

## 8. Next Steps

### Our Baseline Results
- **Validation R²:** 0.34 (explains 34% of variance)
- **Validation MAE:** ~1.35 binding score units
- Room for significant improvement

### Key Limitations
- **Features too simple:** Composition loses sequence order
- **No structural info:** Missing 3D binding interfaces
- **No interaction modeling:** Treats proteins independently

### How to Improve

**Better Features:**
- Larger protein language models than the ESM2-8M tried above (bigger ESM-2 variants, ProtBERT) and antibody-specific models (AntiBERTy)
- k-mer features (di/tri-peptides)
- CDR region analysis
- Structural embeddings
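k-mer features recover some of the local order that single-residue composition discards. A minimal sketch of overlapping k-mer composition, assuming the same 20-letter alphabet as `get_amino_acid_composition` and skipping any window that contains a non-standard letter; `get_kmer_composition` is a hypothetical helper. The feature count grows as 20^k (400 dipeptides or 8,000 tripeptides per sequence), so tripeptides across three sequences would add 24,000 mostly sparse features.

```python
from itertools import product

AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"


def get_kmer_composition(sequence, k=2):
    """
    Percentage of each overlapping length-k window among all valid windows.
    Windows containing a non-standard letter are skipped.
    """
    kmers = ["".join(p) for p in product(AMINO_ACIDS, repeat=k)]
    counts = dict.fromkeys(kmers, 0)
    valid = 0
    # `sequence or ""` lets None/empty sequences fall through to all zeros
    for i in range(len(sequence or "") - k + 1):
        window = sequence[i : i + k]
        if window in counts:
            counts[window] += 1
            valid += 1
    return {km: (c / valid * 100 if valid else 0) for km, c in counts.items()}


# Windows of "ACACX": AC, CA, AC, CX (CX is skipped because X is non-standard)
features = get_kmer_composition("ACACX", k=2)
print(len(features), features["AC"], features["CA"])
```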

**Better Models:**
- Gradient boosting (XGBoost, LightGBM, CatBoost)
- Deep learning with attention mechanisms
- Graph neural networks
- Ensemble methods

**Resources:**
- [ESM](https://github.com/facebookresearch/esm) - Protein language models
- [SAbDab](http://opig.stats.ox.ac.uk/webapps/sabdab-sabpred/sabdab) - Antibody structure database
- [BioPython](https://biopython.org/) - Sequence analysis tools