## 🚀 RTX 4090 Resource Estimates

### **Hardware Specs:**
- RTX 4090: 24 GB VRAM, 16,384 CUDA cores
- CUDA Compute Capability: 8.9
- Memory Bandwidth: 1,008 GB/s

### **Computational Requirements:**

#### **1. Model Loading**
- AntiBERTy (IgBert) model size: **~200 MB**
- GPU memory for model: **~500 MB** (with overhead)

#### **2. Embedding Generation (92,620 samples)**

**Memory Usage:**
- Batch size 16: **~2-3 GB VRAM**
- Batch size 32: **~4-5 GB VRAM**
- Batch size 64: **~8-10 GB VRAM** ⚠️ (within 24 GB, but watch for OOM on long sequences)
- **Recommended: batch_size=32** (safe with 24GB VRAM)

**Processing Time Estimates:**
- **Heavy + Light chains**: 92,620 antibody pairs
  - Speed: ~200-300 sequences/second on RTX 4090
  - Time: **~5-8 minutes** (with batch_size=16-32)

- **Antigen sequences**: 92,620 antigens
  - Speed: ~200-300 sequences/second
  - Time: **~5-8 minutes** (with batch_size=16-32)

- **Total embedding extraction: ~10-16 minutes**
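
The time estimates above come from dividing the sample count by throughput. A minimal sketch of that arithmetic, assuming the 200-300 sequences/second range quoted above holds for the whole run (the helper name `minutes_needed` is illustrative):

```python
# Back-of-envelope runtime from the throughput figures quoted above.
N_SAMPLES = 92_620


def minutes_needed(n_sequences: int, seqs_per_second: float) -> float:
    """Return the wall-clock minutes needed at a constant throughput."""
    return n_sequences / seqs_per_second / 60


fast = minutes_needed(N_SAMPLES, 300)  # optimistic throughput
slow = minutes_needed(N_SAMPLES, 200)  # conservative throughput
print(f"One pass: {fast:.1f}-{slow:.1f} min")
print(f"Antibody + antigen passes: {2 * fast:.1f}-{2 * slow:.1f} min")
```

Real throughput depends on sequence length and padding, so treat the result as a rough bound rather than a measurement.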

**Storage Requirements:**
- Embeddings shape: (92,620, 1,024), i.e. 512 antibody + 512 antigen dimensions per sample
- Data type: float32 (4 bytes per value)
- Size: 92,620 × 1,024 × 4 bytes = **~380 MB**
- With overhead: **~500 MB disk space**
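
The storage figure can be checked with a few lines of arithmetic. This sketch uses only the shape and dtype quoted above and reports both decimal and binary units:

```python
# Size of a dense float32 embedding matrix, using the shape quoted above.
n_samples, n_features, bytes_per_value = 92_620, 1_024, 4

size_bytes = n_samples * n_features * bytes_per_value
print(f"{size_bytes:,} bytes")
print(f"{size_bytes / 1e6:.0f} MB (decimal)")
print(f"{size_bytes / 2**20:.0f} MiB (binary)")
```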

#### **3. Model Training**

**Random Forest (100 trees):**
- Training time: **~2-5 minutes** (CPU-based, uses all cores)
- Memory: **~2-4 GB RAM**

**PCA + LinearSVR:**
- Training time: **~1-2 minutes**
- Memory: **~1-2 GB RAM**

#### **4. Cross-Validation (5-fold, subset=5000)**
- Time per fold: **~30-60 seconds**
- Total: **~3-5 minutes**

#### **5. Hyperparameter Optimization (20 trials)**
- Time per trial: **~30-60 seconds**
- Total: **~10-20 minutes**

---

### **📊 TOTAL ESTIMATES:**

| Metric | Estimate |
|--------|----------|
| **Total Runtime** | **25-45 minutes** (full pipeline) |
| **Peak GPU Memory** | **4-6 GB VRAM** (batch_size=32) |
| **Peak RAM** | **8-12 GB** |
| **Disk Space (embeddings)** | **~500 MB** |
| **Disk Space (models)** | **~100 MB** |
| **Total Disk Space** | **~600 MB** |

---

### **⚡ Optimization Tips:**

1. **Increase batch size to 32 or 64** - you have plenty of VRAM
   ```python
   batch_size=64  # Fast, safe on RTX 4090
   ```

2. **Use mixed precision (FP16)** for faster inference:
   ```python
   antiberty_model.half()  # FP16 weights take half the memory of FP32
   ```

3. **Process full dataset** (no subsets needed):
   - CV on full 92k samples: add ~10-15 min
   - HPO with more trials (50-100): add ~20-40 min

4. **Save embeddings** to avoid recomputation:
   - First run: 25-45 min
   - Subsequent runs: **<5 min** (load embeddings from disk)
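
Tip 4 can be sketched as a small load-or-compute helper. The function name `load_or_compute`, the cache path, and the stand-in `fake_embeddings` function are all hypothetical; in the notebook, the compute step would be the embedding extraction.

```python
import os
import tempfile

import numpy as np


def load_or_compute(path, compute_fn):
    """Load an array from `path` if it exists; otherwise compute and save it."""
    if os.path.exists(path):
        return np.load(path)
    array = compute_fn()
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    np.save(path, array)
    return array


# Demo with a stand-in for the expensive embedding step (hypothetical data).
calls = []


def fake_embeddings():
    calls.append(1)
    return np.ones((4, 8), dtype=np.float32)


with tempfile.TemporaryDirectory() as tmp:
    cache_path = os.path.join(tmp, "embeddings.npy")
    first = load_or_compute(cache_path, fake_embeddings)   # computes and saves
    second = load_or_compute(cache_path, fake_embeddings)  # loads from disk

print(len(calls), first.shape, np.array_equal(first, second))
```

The cache is keyed only by file path here, so it should be deleted whenever the model, pooling method, or input data changes.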

---

### **🎯 Expected with RTX 4090:**
- **Embedding extraction: 10-12 minutes** (batch_size=64)
- **Full pipeline: 30-35 minutes**
- **No OOM issues** (24GB is plenty)
- **Can run full dataset** (no need for subsets)

---

# BindHack: AntiBERTy Embeddings for Binding Prediction

**Antibody-specific language model for binding affinity prediction**

---

## The Plan

- Load antibody-antigen **data**
- Extract **AntiBERTy embeddings** (antibody-specific model)
- Train predictive **models**
- Evaluate and **validate** performance

---
```
Data → AntiBERTy Embeddings → Machine Learning → Binding Predictions
```
---

## Why AntiBERTy?

**AntiBERTy is specifically designed for antibodies:**
- Pre-trained on 558M antibody sequences (vs general proteins)
- Understands CDR regions and antibody-specific patterns
- Smaller and faster than general protein models
- Better for antibody binding prediction tasks

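**Model**: `Exscientia/IgBert` (the Hugging Face checkpoint this notebook loads as its antibody-specific model)
- Embedding dimension: 512 per sequence (as assumed throughout this notebook; confirm with `model.config.hidden_size`)
- Max sequence length: 512 tokens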

---

## Setup & Imports

In [None]:
# Data manipulation
import polars as pl
import numpy as np

# Visualization
import seaborn as sns
import matplotlib.pyplot as plt

# Machine learning
from sklearn.model_selection import train_test_split, cross_val_score, KFold
from sklearn.ensemble import RandomForestRegressor
from sklearn.svm import LinearSVR
from sklearn.decomposition import PCA
from sklearn.pipeline import make_pipeline
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error
from scipy.stats import spearmanr

# Deep learning for embeddings
import torch
from transformers import AutoTokenizer, EsmModel
from tqdm.auto import tqdm

# Hyperparameter optimization
import optuna

# Set random seed
np.random.seed(42)
torch.manual_seed(42)

print("All imports loaded")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

## Load Data

In [None]:
# Load training data from AbiBench dataset
train_df = pl.read_csv("../data/train_data.csv.gz")

print(f"Dataset shape: {train_df.shape}")
print(f"\nColumns: {train_df.columns}")
print("\nFirst few rows:")
train_df.head()

In [None]:
# Check binding score distribution
print("Binding Score Statistics:")
print(train_df["binding_score"].describe())

# Visualize distribution
plt.figure(figsize=(10, 5))
plt.hist(
    train_df["binding_score"].to_numpy(),
    bins=50,
    edgecolor="black",
    alpha=0.7,
    color="steelblue",
)
plt.axvline(
    train_df["binding_score"].mean(),
    color="red",
    linestyle="--",
    linewidth=2,
    label=f"Mean: {train_df['binding_score'].mean():.2f}",
)
plt.xlabel("Binding Score (-ΔG)", fontsize=12)
plt.ylabel("Frequency", fontsize=12)
plt.title("Distribution of Binding Scores", fontsize=14, fontweight="bold")
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

## Load AntiBERTy Model

Loading the antibody-specific transformer model.

In [None]:
# Load AntiBERTy model (IgBert implementation)
from transformers import AutoModel

antiberty_model_name = "Exscientia/IgBert"
print(f"Loading {antiberty_model_name}...")

antiberty_tokenizer = AutoTokenizer.from_pretrained(antiberty_model_name)
# AutoModel resolves the model class from the checkpoint's config,
# so the architecture does not have to be hard-coded here
antiberty_model = AutoModel.from_pretrained(antiberty_model_name)

# Use GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
antiberty_model = antiberty_model.to(device)
antiberty_model.eval()

print(f"\n✓ AntiBERTy loaded on {device}")
# Report the values from the loaded config rather than hard-coding them
print(f"  Embedding dimension: {antiberty_model.config.hidden_size} (per sequence)")
print(f"  Max sequence length: {antiberty_model.config.max_position_embeddings} tokens")
print("\n⚠️  Note: Antibody H+L chains combined can exceed the token limit and will be truncated")

# Print GPU memory info if available
if torch.cuda.is_available():
    print("\n📊 GPU Memory Status:")
    print(f"   Device: {torch.cuda.get_device_name(0)}")
    print(f"   Total Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    print(f"   Allocated: {torch.cuda.memory_allocated(0) / 1e9:.2f} GB")
    print(f"   Reserved: {torch.cuda.memory_reserved(0) / 1e9:.2f} GB")

## GPU Memory Monitoring Utilities

Helper functions to track GPU memory usage during embedding extraction.

In [None]:
import gc


def print_gpu_memory():
    """Print current GPU memory usage and return (allocated, reserved, total) in GB."""
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated(0) / 1e9
        reserved = torch.cuda.memory_reserved(0) / 1e9
        total = torch.cuda.get_device_properties(0).total_memory / 1e9
        print(f"GPU Memory: {allocated:.2f}GB allocated / {reserved:.2f}GB reserved / {total:.2f}GB total")
        return allocated, reserved, total
    return 0, 0, 0


def clear_gpu_cache():
    """Run garbage collection, then release cached GPU memory."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()


# Test the function
print("Initial GPU state:")
print_gpu_memory()

## Embedding Extraction Function

This function extracts AntiBERTy embeddings with batching for efficiency.

In [None]:
def get_antiberty_embeddings_batch(
    sequences, model, tokenizer, device, batch_size=32, max_length=512
):
    """
    Extract AntiBERTy embeddings for sequences with batching.
    Returns mean-pooled embeddings (512-dim per sequence).

    Args:
        sequences: List of protein sequences (strings)
        model: AntiBERTy model
        tokenizer: AntiBERTy tokenizer
        device: torch device (cuda/cpu)
        batch_size: Number of sequences to process at once
        max_length: Maximum sequence length (512 for AntiBERTy)

    Returns:
        List of numpy arrays, each of shape (512,)
    """
    all_embeddings = []
    num_batches = (len(sequences) + batch_size - 1) // batch_size

    # Track truncation statistics
    truncated_count = 0
    max_seq_len = 0

    for i in tqdm(
        range(0, len(sequences), batch_size),
        total=num_batches,
        desc="AntiBERTy batches",
    ):
        batch = sequences[i : i + batch_size]

        # Tokenize with padding and truncation
        inputs = tokenizer(
            batch,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Check for truncation
        for j, seq in enumerate(batch):
            seq_len = len(tokenizer.encode(seq, add_special_tokens=False))
            max_seq_len = max(max_seq_len, seq_len)
            if seq_len > max_length - 2:  # Account for special tokens
                truncated_count += 1

        # Get embeddings
        with torch.no_grad():
            outputs = model(**inputs)

            # Mean pool over sequence length (exclude [CLS] and [SEP])
            for j in range(len(batch)):
                mask = inputs["attention_mask"][j]
                seq_len = mask.sum().item()
                emb = outputs.last_hidden_state[j, 1 : seq_len - 1].mean(dim=0)
                all_embeddings.append(emb.cpu().numpy())
        
        # Clear GPU cache to prevent memory buildup
        del outputs, inputs
        if device.type == "cuda":
            torch.cuda.empty_cache()

    if truncated_count > 0:
        print(
            f"\n⚠️  Warning: {truncated_count}/{len(sequences)} sequences were truncated to {max_length} tokens"
        )
        print(f"   Maximum sequence length encountered: {max_seq_len} tokens")

    return all_embeddings
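
The per-sequence loop above averages the token vectors that lie between the first and last non-padding tokens. The same pooling can be sketched in vectorised form on toy NumPy arrays. This assumes right padding and exactly one special token at each end of every sequence; `masked_mean_pool` is a hypothetical helper, not part of the notebook.

```python
import numpy as np


def masked_mean_pool(hidden, attention_mask):
    """Mean-pool token vectors, skipping padding and the first/last real token.

    hidden: (batch, seq_len, dim) array of token embeddings.
    attention_mask: (batch, seq_len) array, 1 for real tokens, 0 for padding.
    Assumes right padding and one special token at each end of every sequence.
    """
    mask = attention_mask.astype(bool).copy()
    lengths = attention_mask.sum(axis=1)
    rows = np.arange(mask.shape[0])
    mask[:, 0] = False               # drop the leading special token
    mask[rows, lengths - 1] = False  # drop the trailing special token
    weights = mask[..., None].astype(hidden.dtype)
    return (hidden * weights).sum(axis=1) / weights.sum(axis=1)


# Toy batch: 2 sequences, 5 token positions, 3-dim embeddings.
hidden = np.arange(2 * 5 * 3, dtype=np.float64).reshape(2, 5, 3)
attention_mask = np.array([[1, 1, 1, 1, 1],   # 3 residues + 2 special tokens
                           [1, 1, 1, 0, 0]])  # 1 residue + 2 special + 2 padding
pooled = masked_mean_pool(hidden, attention_mask)
print(pooled)
```

A sequence with no residues between its special tokens would divide by zero here, so empty inputs should be filtered out first.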

## Extract AntiBERTy Embeddings

**Strategy**: Concatenate heavy + light chains for complete antibody representation
- Heavy + Light chains form the binding site together
- Process antigens separately
- Concatenate all embeddings into final feature vector
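
How the two chains are joined matters for tokenization. BERT-style protein tokenizers commonly expect residues separated by spaces, with a `[SEP]` token between chains. Whether that applies to the checkpoint loaded here is an assumption to verify against its model card. The helper names below are hypothetical.

```python
def space_residues(sequence: str) -> str:
    """Insert spaces between residues: 'EVQL' -> 'E V Q L'."""
    return " ".join(sequence.strip())


def format_paired(heavy: str, light: str, sep_token: str = "[SEP]") -> str:
    """Join heavy and light chains with a separator token between them."""
    return f"{space_residues(heavy)} {sep_token} {space_residues(light)}"


example = format_paired("EVQLVE", "DIQMT")  # toy fragments, not real chains
print(example)  # E V Q L V E [SEP] D I Q M T
```

A quick check is to call `tokenizer.tokenize(...)` on one formatted sequence and confirm the output contains individual residues rather than unknown tokens.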

In [None]:
print("Extracting AntiBERTy embeddings...")
print("\n📌 Strategy: Concatenate heavy + light chains for each antibody")
print("   (This captures the complete binding site structure)\n")

# Combine heavy and light chains with a space separator
antibody_seqs = [
    f"{h} {l}"
    for h, l in zip(
        train_df["heavy_chain_sequence"].to_list(),
        train_df["light_chain_sequence"].to_list(),
    )
]

# Process antibody pairs (heavy+light)
# RTX 4090 optimization: batch_size=32 is safe, but reduce to 16 if you see OOM errors
# Adjust based on your available GPU memory
BATCH_SIZE = 32  # Set to 16 if running out of memory

print("Processing antibody sequences (heavy + light chains):")
print(f"Using batch_size={BATCH_SIZE}")
print_gpu_memory()

antibody_embs = get_antiberty_embeddings_batch(
    antibody_seqs, antiberty_model, antiberty_tokenizer, device, batch_size=BATCH_SIZE
)

print("\n✓ Antibody embeddings extracted")
print_gpu_memory()
clear_gpu_cache()

# Process antigens separately
print("\n" + "="*60)
print("Processing antigen sequences:")
print(f"Using batch_size={BATCH_SIZE}")
print_gpu_memory()

antigen_seqs = train_df["antigen_sequences"].to_list()
antigen_embs = get_antiberty_embeddings_batch(
    antigen_seqs, antiberty_model, antiberty_tokenizer, device, batch_size=BATCH_SIZE
)

print("\n✓ Antigen embeddings extracted")
print_gpu_memory()
clear_gpu_cache()

# Concatenate antibody and antigen embeddings
print("\n" + "="*60)
print("Combining embeddings...")
X_antiberty = np.array(
    [np.concatenate([ab, ag]) for ab, ag in zip(antibody_embs, antigen_embs)]
)
y_antiberty = train_df["binding_score"].to_numpy()

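# Report the per-part dimensions from the arrays instead of hard-coding them
print(
    f"\n✓ AntiBERTy features: {X_antiberty.shape[1]} dimensions "
    f"({len(antibody_embs[0])} antibody + {len(antigen_embs[0])} antigen)"
)
print(f"  Total samples: {X_antiberty.shape[0]}")
print("\n📊 Final GPU state:")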
print_gpu_memory()

## Train/Validation Split

In [None]:
# Split into train/validation sets
X_train, X_val, y_train, y_val = train_test_split(
    X_antiberty, y_antiberty, test_size=0.2, random_state=42
)

print(f"Training set: {len(X_train)} samples")
print(f"Validation set: {len(X_val)} samples")
print(f"Feature dimensions: {X_train.shape[1]}")

## Model Training

Training multiple models to compare performance:
1. **Random Forest** - Ensemble tree-based model
2. **Linear SVR with PCA** - Linear model with dimensionality reduction

**Evaluation Metrics**:
- **R²**: Coefficient of determination (higher is better)
- **MAE**: Mean Absolute Error (lower is better)
- **Spearman ρ**: Rank correlation (higher is better, more robust to outliers)
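
To make the three metrics concrete, here is a minimal pure-Python sketch on toy values. It is for illustration only; the notebook itself uses scikit-learn and SciPy. The Spearman helper assumes there are no tied values, which `scipy.stats.spearmanr` handles and this sketch does not.

```python
def r2(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean) ** 2 for t in y_true)
    return 1 - ss_res / ss_tot


def mae(y_true, y_pred):
    """Mean absolute error."""
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


def ranks(values):
    """Rank values from 1..n (assumes no ties)."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    result = [0] * len(values)
    for rank, idx in enumerate(order, start=1):
        result[idx] = rank
    return result


def spearman(y_true, y_pred):
    """Spearman rho via the rank-difference formula (valid without ties)."""
    n = len(y_true)
    d2 = sum((a - b) ** 2 for a, b in zip(ranks(y_true), ranks(y_pred)))
    return 1 - 6 * d2 / (n * (n * n - 1))


y_true = [1.0, 2.0, 3.0, 4.0]
y_pred = [1.5, 1.8, 3.9, 3.6]  # toy values, not real predictions
print(r2(y_true, y_pred), mae(y_true, y_pred), spearman(y_true, y_pred))
```

Spearman depends only on the ordering of the predictions, which is why it is less sensitive to a few large errors than R² or MAE.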

In [None]:
print("Training Random Forest model...\n")

# Model 1: Random Forest
print("[1/2] Random Forest")
rf_model = RandomForestRegressor(n_estimators=100, random_state=42, n_jobs=-1)
rf_model.fit(X_train, y_train)
y_val_pred_rf = rf_model.predict(X_val)

r2_rf = r2_score(y_val, y_val_pred_rf)
mae_rf = mean_absolute_error(y_val, y_val_pred_rf)
spearman_rf = spearmanr(y_val, y_val_pred_rf).correlation

print(f"  Validation R²: {r2_rf:.3f}")
print(f"  Validation MAE: {mae_rf:.3f}")
print(f"  Spearman ρ: {spearman_rf:.3f}")

In [None]:
print("\nTraining Linear SVR with PCA...\n")

# Model 2: PCA + Linear SVR
print("[2/2] PCA + Linear SVR")
svr_model = make_pipeline(PCA(n_components=100), LinearSVR(max_iter=2000))
svr_model.fit(X_train, y_train)
y_val_pred_svr = svr_model.predict(X_val)

r2_svr = r2_score(y_val, y_val_pred_svr)
mae_svr = mean_absolute_error(y_val, y_val_pred_svr)
spearman_svr = spearmanr(y_val, y_val_pred_svr).correlation

print(f"  Validation R²: {r2_svr:.3f}")
print(f"  Validation MAE: {mae_svr:.3f}")
print(f"  Spearman ρ: {spearman_svr:.3f}")

In [None]:
# Comparison table
print("\n" + "=" * 60)
print(f"{'Metric':<20} {'Random Forest':<20} {'PCA + SVR':<20}")
print("=" * 60)
print(f"{'R²':<20} {r2_rf:<20.3f} {r2_svr:<20.3f}")
print(f"{'MAE':<20} {mae_rf:<20.3f} {mae_svr:<20.3f}")
print(f"{'Spearman ρ':<20} {spearman_rf:<20.3f} {spearman_svr:<20.3f}")
print("=" * 60)

# Highlight best model
best_spearman = max(spearman_rf, spearman_svr)
if spearman_rf == best_spearman:
    print("\n🏆 Random Forest gives the best Spearman correlation!")
    best_model = "Random Forest"
else:
    print("\n🏆 PCA + SVR gives the best Spearman correlation!")
    best_model = "PCA + SVR"

## Visualize Predictions

In [None]:
# Visualize prediction quality
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Random Forest
axes[0].scatter(y_val, y_val_pred_rf, alpha=0.3, s=10, color="mediumseagreen")
axes[0].plot([y_val.min(), y_val.max()], [y_val.min(), y_val.max()], "r--", lw=2)
axes[0].set_xlabel("Actual Binding Score", fontsize=12)
axes[0].set_ylabel("Predicted Binding Score", fontsize=12)
axes[0].set_title(
    f"Random Forest (R² = {r2_rf:.3f}, ρ = {spearman_rf:.3f})",
    fontsize=13,
    fontweight="bold",
)
axes[0].grid(True, alpha=0.3)

# PCA + SVR
axes[1].scatter(y_val, y_val_pred_svr, alpha=0.3, s=10, color="steelblue")
axes[1].plot([y_val.min(), y_val.max()], [y_val.min(), y_val.max()], "r--", lw=2)
axes[1].set_xlabel("Actual Binding Score", fontsize=12)
axes[1].set_ylabel("Predicted Binding Score", fontsize=12)
axes[1].set_title(
    f"PCA + SVR (R² = {r2_svr:.3f}, ρ = {spearman_svr:.3f})",
    fontsize=13,
    fontweight="bold",
)
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\nPerfect predictions would fall on the red diagonal line.")
print("Scatter around the line indicates prediction errors.")

## Cross-Validation

5-fold cross-validation for more robust performance estimation.

In [None]:
print("Running 5-fold cross-validation on AntiBERTy features...\n")

# Use a subset for faster CV (optional - remove for full dataset)
# For production, use the full dataset
use_subset = len(X_antiberty) > 10000
if use_subset:
    cv_size = 5000
    X_cv = X_antiberty[:cv_size]
    y_cv = y_antiberty[:cv_size]
    print(f"Using subset of {cv_size} samples for faster CV")
else:
    X_cv = X_antiberty
    y_cv = y_antiberty
    print(f"Using full dataset ({len(X_cv)} samples)")

kfold = KFold(n_splits=5, shuffle=True, random_state=42)
rf_cv = RandomForestRegressor(
    n_estimators=50, random_state=42, n_jobs=-1
)  # Fewer trees for speed

# Cross-validate with multiple metrics
cv_r2 = cross_val_score(rf_cv, X_cv, y_cv, cv=kfold, scoring="r2", n_jobs=-1)
cv_mae = -cross_val_score(
    rf_cv, X_cv, y_cv, cv=kfold, scoring="neg_mean_absolute_error", n_jobs=-1
)

print("\nCross-validation results (5 folds):")
print("\nR¬≤ scores per fold:")
for i, score in enumerate(cv_r2, 1):
    print(f"  Fold {i}: {score:.3f}")
print(f"\nMean R²: {cv_r2.mean():.3f} ± {cv_r2.std():.3f}")

print("\nMAE scores per fold:")
for i, score in enumerate(cv_mae, 1):
    print(f"  Fold {i}: {score:.3f}")
print(f"\nMean MAE: {cv_mae.mean():.3f} ± {cv_mae.std():.3f}")

print("\nThe ± shows how stable our estimates are across different data splits")
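
The subset above takes the first 5,000 rows, which is representative only if the file is already in random order. A hedged alternative that assumes nothing about row order is to draw the subset at random. The helper `random_subset` and its seed are illustrative, and the toy arrays stand in for `X_antiberty` and `y_antiberty`.

```python
import numpy as np


def random_subset(X, y, size, seed=42):
    """Return `size` rows of X and y drawn without replacement."""
    rng = np.random.default_rng(seed)
    idx = rng.choice(len(X), size=min(size, len(X)), replace=False)
    return X[idx], y[idx]


# Toy data: row i of X_demo is [2i, 2i + 1], and y_demo[i] is i.
X_demo = np.arange(40, dtype=np.float32).reshape(20, 2)
y_demo = np.arange(20, dtype=np.float32)
X_sub, y_sub = random_subset(X_demo, y_demo, size=5)
print(X_sub.shape, y_sub.shape)
```

The same helper could supply the 5,000-sample subset used for hyperparameter optimization.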

## Hyperparameter Optimization with Optuna

Finding the best Random Forest hyperparameters using Bayesian optimization.

In [None]:
# Suppress Optuna's verbose output
optuna.logging.set_verbosity(optuna.logging.WARNING)

# Use smaller dataset for HPO (optional)
if len(X_antiberty) > 10000:
    hpo_size = 5000
    X_hpo_train, X_hpo_val, y_hpo_train, y_hpo_val = train_test_split(
        X_antiberty[:hpo_size], y_antiberty[:hpo_size], test_size=0.2, random_state=42
    )
    print(f"Using {hpo_size} samples for HPO")
else:
    X_hpo_train, X_hpo_val, y_hpo_train, y_hpo_val = train_test_split(
        X_antiberty, y_antiberty, test_size=0.2, random_state=42
    )
    print(f"Using full dataset for HPO")


def objective(trial):
    """
    Optuna objective function.
    Optuna will call this function many times with different hyperparameters.
    """
    # Suggest hyperparameters
    params = {
        "n_estimators": trial.suggest_int("n_estimators", 50, 200),
        "max_depth": trial.suggest_int("max_depth", 5, 30),
        "min_samples_split": trial.suggest_int("min_samples_split", 2, 20),
        "min_samples_leaf": trial.suggest_int("min_samples_leaf", 1, 10),
        "random_state": 42,
        "n_jobs": -1,
    }

    # Train model with these hyperparameters
    model = RandomForestRegressor(**params)
    model.fit(X_hpo_train, y_hpo_train)

    # Evaluate
    y_pred = model.predict(X_hpo_val)
    mae = mean_absolute_error(y_hpo_val, y_pred)

    # The study is created with direction="minimize", so return MAE directly
    return mae


# Run optimization
print("\nRunning Optuna hyperparameter optimization...")
print("(20 trials - in production, use 100+ trials)\n")

study = optuna.create_study(direction="minimize", study_name="antiberty_rf_hpo")
study.optimize(objective, n_trials=20, show_progress_bar=True)

print("\n✓ Optimization complete!")
print(f"\nBest MAE: {study.best_value:.3f}")
print("\nBest hyperparameters:")
for param, value in study.best_params.items():
    print(f"  {param}: {value}")

In [None]:
# Visualize optimization history
trial_numbers = [trial.number for trial in study.trials]
trial_values = [trial.value for trial in study.trials]
best_values = [min(trial_values[: i + 1]) for i in range(len(trial_values))]

plt.figure(figsize=(10, 5))
plt.plot(
    trial_numbers, trial_values, "o-", alpha=0.5, label="Trial MAE", color="steelblue"
)
plt.plot(trial_numbers, best_values, "r-", linewidth=2, label="Best MAE so far")
plt.xlabel("Trial Number", fontsize=12)
plt.ylabel("MAE", fontsize=12)
plt.title("Optuna Optimization Progress", fontsize=14, fontweight="bold")
plt.legend(fontsize=11)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

print("The red line shows the best result found so far")
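
The "best so far" curve is a running minimum over the trial values. A compact way to compute it is `itertools.accumulate`, shown here on hypothetical MAE values rather than real study output:

```python
from itertools import accumulate

trial_values = [0.92, 0.88, 0.95, 0.81, 0.84]  # hypothetical MAE per trial
best_values = list(accumulate(trial_values, min))
print(best_values)  # [0.92, 0.88, 0.88, 0.81, 0.81]
```

If any trial fails, its value is `None` and would need to be filtered out before taking the running minimum.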

In [None]:
# Train final model with best hyperparameters
print("Training final model with optimized hyperparameters...\n")

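# best_params holds only the tuned values, so restore the fixed settings used during the search
best_rf = RandomForestRegressor(**study.best_params, random_state=42, n_jobs=-1)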
best_rf.fit(X_hpo_train, y_hpo_train)
y_hpo_pred = best_rf.predict(X_hpo_val)

final_r2 = r2_score(y_hpo_val, y_hpo_pred)
final_mae = mean_absolute_error(y_hpo_val, y_hpo_pred)
final_spearman = spearmanr(y_hpo_val, y_hpo_pred).correlation

print("Final optimized model performance:")
print(f"  R²: {final_r2:.3f}")
print(f"  MAE: {final_mae:.3f}")
print(f"  Spearman ρ: {final_spearman:.3f}")

## Save Embeddings (Optional)

Save the extracted embeddings for future use without recomputing.

In [None]:
# Save embeddings to disk
import os

output_dir = "../data/embeddings"
os.makedirs(output_dir, exist_ok=True)

# Save as numpy arrays
np.save(os.path.join(output_dir, "antiberty_embeddings.npy"), X_antiberty)
np.save(os.path.join(output_dir, "binding_scores.npy"), y_antiberty)

print(f"✓ Embeddings saved to {output_dir}/")
print(f"  - antiberty_embeddings.npy: {X_antiberty.shape}")
print(f"  - binding_scores.npy: {y_antiberty.shape}")

## Summary

**What we accomplished:**
1. ✅ Loaded antibody-antigen binding data
2. ✅ Extracted AntiBERTy embeddings (antibody-specific)
3. ✅ Trained and compared multiple ML models
4. ✅ Performed cross-validation for robust estimates
5. ✅ Optimized hyperparameters with Optuna
6. ✅ Saved embeddings for future use

**Key Results:**
- AntiBERTy embeddings: 1024 dimensions (512 antibody + 512 antigen)
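- Best model: whichever of Random Forest and PCA + SVR has the higher validation Spearman ρ (printed above and stored in `best_model`)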
- Performance metrics available above

**Next Steps:**
- Try different aggregation methods (CLS token vs mean pooling)
- Extract CDR-specific embeddings
- Combine with structural features
- Fine-tune AntiBERTy on binding data
- Ensemble with other models (ESM, structure-based)

---

## Notes on AntiBERTy for Server Deployment

**Memory Requirements:**
- Model size: ~200 MB
- GPU memory: ~2-4 GB for batch inference
- Recommended GPU: NVIDIA T4 or better

**Speed Considerations:**
- GPU: ~200-500 sequences/second
- CPU: ~10-50 sequences/second
- Adjust `batch_size` based on available GPU memory

**512 Token Limit:**
- Heavy chain: ~450 amino acids
- Light chain: ~220 amino acids
- Combined: ~670 amino acids → **will be truncated**
- Consider processing chains separately if truncation is a concern

**Alternative Strategies:**
1. Process H and L chains separately, concatenate embeddings
2. Extract only CDR regions (most relevant for binding)
3. Use sliding window approach for long sequences
4. Compare with ESM (has 1024 token limit)
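
Strategy 3 can be sketched as a generator of overlapping windows. The window of 510 assumes one token per residue plus two special tokens within the 512-token limit; the stride of 255 and the name `sliding_windows` are illustrative choices, not values from the notebook.

```python
def sliding_windows(sequence: str, window: int = 510, stride: int = 255):
    """Yield overlapping chunks that each fit within `window` residues."""
    if len(sequence) <= window:
        yield sequence
        return
    start = 0
    while True:
        yield sequence[start:start + window]
        if start + window >= len(sequence):
            break
        start += stride


combined = "A" * 670  # stand-in for a ~670-residue heavy + light pair
chunks = list(sliding_windows(combined))
print([len(c) for c in chunks])
```

Each chunk would then be embedded separately, and the per-chunk embeddings combined, for example by averaging. How best to combine them is a design choice this sketch leaves open.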

---