# AbLangPDB1 Inference Examples

This notebook shows how to use AbLangPDB1 to generate embeddings for paired heavy- and light-chain antibody sequences. AbLangPDB1 produces 1536-dimensional embeddings in which antibodies that target similar epitopes cluster together.

## Prerequisites

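Make sure you have:
1. Downloaded the model weights: `ablangpdb_model.safetensors`
2. Installed the required packages: `torch`, `pandas`, `transformers`, `safetensors`, `tqdm`
3. Placed the local module `ablangpaired_model.py` (which defines `AbLangPairedConfig` and `AbLangPaired`) where Python can import it, e.g. the notebook's working directory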

```bash
# Download model weights
curl -L "https://huggingface.co/clint-holt/AbLangPDB1/resolve/main/ablangpdb_model.safetensors?download=true" -o ablangpdb_model.safetensors

# Install dependencies
pip install torch pandas transformers safetensors tqdm
```
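
Before running the notebook, you can confirm that everything is importable. The following is a minimal sketch using only the standard library; `missing_packages` is a hypothetical helper. It checks the import names listed above plus the local `ablangpaired_model` module, which is only found if it sits in the working directory or elsewhere on `sys.path`.

```python
import importlib.util

# Top-level import names used in this notebook, plus the local model module
REQUIRED = ["torch", "pandas", "transformers", "safetensors", "tqdm", "ablangpaired_model"]

def missing_packages(names):
    """Return the names that cannot be located on the current import path (without importing them)."""
    return [name for name in names if importlib.util.find_spec(name) is None]

missing = missing_packages(REQUIRED)
if missing:
    print("Missing:", ", ".join(missing))
else:
    print("All required modules are importable.")
```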

## 1. Setup and Imports

In [None]:
# Standard library imports
from time import time
import os

# Data processing imports
import pandas as pd
from tqdm import tqdm

# PyTorch imports
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hugging Face Transformers imports
from transformers import AutoTokenizer

# Local module (ablangpaired_model.py) defining the model and its config
from ablangpaired_model import AbLangPairedConfig, AbLangPaired

# Set device (GPU if available, otherwise CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

## 2. Helper Functions

`tokenize_data` filters and tokenizes paired heavy/light chains; `embed_dataloader` runs the model over a `DataLoader` and collects the embeddings.

In [None]:
def tokenize_data(df: pd.DataFrame, model_config: AbLangPairedConfig) -> TensorDataset:
    """
    Prepare antibody sequences for input to the AbLangPDB1 model.
    
    Args:
        df: DataFrame containing antibody sequences with HC_AA and LC_AA columns
        model_config: AbLangPairedConfig specifying which tokenizers (model IDs and revisions) to load
        
    Returns:
        TensorDataset of (heavy input IDs, light input IDs, heavy attention mask, light attention mask)
    """   
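    # Filter out pairs where either chain is too long or contains a stop codon ("*").
    # AbLang tokenizers work best with sequences under 157 amino acids.
    n_input = len(df)
    df = df[
        df["HC_AA"].apply(lambda aa: len(aa) < 157 and "*" not in aa)
        & df["LC_AA"].apply(lambda aa: len(aa) < 157 and "*" not in aa)
    ].copy()  # copy so adding columns below doesn't trigger pandas' SettingWithCopyWarning
    if len(df) < n_input:
        print(f"Dropped {n_input - len(df)} of {n_input} antibodies (too long or containing '*')")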
    
    if len(df) == 0:
        raise ValueError("No valid sequences found after filtering")
    
    # Load tokenizers for heavy and light chains
    print("Loading tokenizers...")
    heavy_tokenizer = AutoTokenizer.from_pretrained(model_config.heavy_model_id, revision=model_config.heavy_revision)
    light_tokenizer = AutoTokenizer.from_pretrained(model_config.light_model_id, revision=model_config.light_revision)

    # Format sequences for tokenization (add spaces between amino acids)
    df.loc[:, "PREPARED_HC_SEQ"] = df["HC_AA"].apply(lambda x: " ".join(list(x)))
    df.loc[:, "PREPARED_LC_SEQ"] = df["LC_AA"].apply(lambda x: " ".join(list(x)))

    # Tokenize heavy chain sequences
    print("Tokenizing heavy chain sequences...")
    h_train_tokens = heavy_tokenizer.batch_encode_plus(
        df["PREPARED_HC_SEQ"].tolist(), 
        add_special_tokens=True, 
        padding='longest', 
        return_tensors="pt",
        truncation=True,
        return_special_tokens_mask=True
    )
    
    # Tokenize light chain sequences
    print("Tokenizing light chain sequences...")
    l_train_tokens = light_tokenizer.batch_encode_plus(
        df["PREPARED_LC_SEQ"].tolist(), 
        add_special_tokens=True, 
        padding='longest', 
        return_tensors="pt",
        truncation=True,
        return_special_tokens_mask=True
    )
    
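    # Replace unknown tokens with the mask token and exclude them from attention.
    # The IDs (24 = UNK, 23 = MASK) are hardcoded for the AbLang vocabulary; check them against
    # heavy_tokenizer.unk_token_id and heavy_tokenizer.mask_token_id if you swap tokenizers.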
    for tokens_dict in [h_train_tokens, l_train_tokens]:
        matches = torch.where(tokens_dict['input_ids'] == 24)  # UNK token
        if len(matches[0]) > 0:
            tokens_dict['input_ids'][matches] = 23  # MASK token
            tokens_dict['attention_mask'][matches] = False
    
    # Create TensorDataset for model input
    dataset = TensorDataset(
        h_train_tokens['input_ids'].to(torch.int16), 
        l_train_tokens['input_ids'].to(torch.int16),
        h_train_tokens['attention_mask'].to(torch.bool),
        l_train_tokens['attention_mask'].to(torch.bool)
    )
    
    print(f"Created dataset with {len(dataset)} sequences")
    return dataset


def embed_dataloader(dataloader, model, device) -> torch.Tensor:
    """
    Generate embeddings for all antibodies in the dataloader.
    
    Args:
        dataloader: DataLoader containing tokenized antibody sequences
        model: Trained AbLangPDB1 model
        device: Device to run inference on (CPU or GPU)
        
    Returns:
        Tensor containing embeddings for all antibodies (shape: N x 1536)
    """
    model.to(device)
    model.eval()
    
    # Preallocate tensor for all embeddings
    num_embeddings = len(dataloader.dataset)
    embedding_dim = 1536
    all_embeds = torch.zeros((num_embeddings, embedding_dim), dtype=torch.float32)
    
    # Generate embeddings batch by batch
    current_batch_index = 0
    print("Generating embeddings...")
    
    with torch.no_grad():
        for htoks, ltoks, hmasks, lmasks in tqdm(dataloader, desc="Processing batches"):
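            # Move tensors to device; tokenize_data stores token IDs as int16,
            # so cast them back to int64 (the standard dtype for embedding lookups)
            htoks = htoks.to(device).long()
            hmasks = hmasks.to(device)
            ltoks = ltoks.to(device).long()
            lmasks = lmasks.to(device)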
            
            # Forward pass to get embeddings
            embeds = model(
                h_input_ids=htoks, 
                h_attention_mask=hmasks, 
                l_input_ids=ltoks, 
                l_attention_mask=lmasks
            )
            
            # Store embeddings in preallocated tensor
            batch_size = embeds.size(0)
            all_embeds[current_batch_index:current_batch_index + batch_size] = embeds.detach().cpu()
            current_batch_index += batch_size
            
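            # Drop references to this batch's tensors before the next iteration
            del htoks, hmasks, ltoks, lmasks, embeds

    # Release cached GPU memory back to the driver once, after all batches are done
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    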
    print(f"Generated {all_embeds.shape[0]} embeddings of dimension {all_embeds.shape[1]}")
    return all_embeds
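
Because `tokenize_data` silently drops pairs that fail its filter, row *i* of the returned embeddings matches row *i* of the input DataFrame only when nothing was dropped. The sketch below mirrors that filter in plain Python so you can record which rows survive; `split_valid` and `is_valid_chain` are hypothetical helpers, and the toy sequences are placeholders rather than real antibodies.

```python
MAX_LEN = 157  # same cutoff tokenize_data uses

def is_valid_chain(seq, max_len=MAX_LEN):
    """Mirror of the tokenize_data filter: shorter than max_len and no stop codon."""
    return len(seq) < max_len and "*" not in seq

def split_valid(heavy, light, max_len=MAX_LEN):
    """Return (kept_indices, dropped_indices) for paired heavy/light chains."""
    kept, dropped = [], []
    for i, (hc, lc) in enumerate(zip(heavy, light)):
        if is_valid_chain(hc, max_len) and is_valid_chain(lc, max_len):
            kept.append(i)
        else:
            dropped.append(i)
    return kept, dropped

heavy = ["EVQLVESGG", "QVQL*VQSG", "E" * 160]  # toy chains: ok, stop codon, too long
light = ["DIQMTQSPS", "EIVLTQSPG", "DIQMTQSPS"]
kept, dropped = split_valid(heavy, light)
print("kept:", kept, "dropped:", dropped)
```

With a DataFrame, the same idea amounts to selecting `df.iloc[kept]` before calling `tokenize_data`, so nothing is dropped inside it and the embeddings line up with the rows you keep.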

## 3. Load Model

Load the AbLangPDB1 model weights. (The tokenizers are loaded separately, inside `tokenize_data` and in the single-antibody example.)

In [None]:
# Check that the model weights exist before trying to load them
model_path = "ablangpdb_model.safetensors"
if not os.path.exists(model_path):
    raise FileNotFoundError(
        f"❌ Model weights not found: {model_path}\n"
        "📥 Please download the model weights first:\n"
        'curl -L "https://huggingface.co/clint-holt/AbLangPDB1/resolve/main/ablangpdb_model.safetensors?download=true" -o ablangpdb_model.safetensors'
    )
print(f"✅ Found model weights: {model_path}")

# Load model configuration and model
print("\n🔄 Loading model...")
model_config = AbLangPairedConfig(checkpoint_filename=model_path)
model = AbLangPaired(model_config, device)
model.eval()
print("✅ Model loaded successfully!")
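
If loading fails, it can help to confirm that the downloaded file really is a safetensors checkpoint and not, for example, an HTML error page saved by `curl`. A safetensors file starts with an 8-byte little-endian header length, followed by a JSON header listing each tensor's dtype and shape. The following is a minimal standard-library sketch of reading that header; `read_safetensors_header` is a hypothetical helper, and the `safetensors` package itself provides `safetensors.safe_open` for richer inspection.

```python
import json
import struct

def read_safetensors_header(path):
    """Return {tensor_name: (dtype, shape)} from a .safetensors file's JSON header.

    Layout per the safetensors format: an 8-byte little-endian unsigned integer N,
    then N bytes of UTF-8 JSON; the tensor data that follows is not read here.
    """
    with open(path, "rb") as fh:
        (header_len,) = struct.unpack("<Q", fh.read(8))
        header = json.loads(fh.read(header_len).decode("utf-8"))
    header.pop("__metadata__", None)  # optional string-to-string metadata entry
    return {name: (info["dtype"], info["shape"]) for name, info in header.items()}

# Usage on the downloaded checkpoint:
# summary = read_safetensors_header("ablangpdb_model.safetensors")
# print(f"{len(summary)} tensors, e.g. {next(iter(summary.items()))}")
```

A file that is not a valid checkpoint will typically fail here with a decoding error instead of a less obvious failure inside model construction.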

## 4. Single Antibody Example

Generate an embedding for a single antibody sequence.

In [None]:
# Example antibody sequences (SARS-CoV-2 neutralizing antibody)
example_data = {
    'HC_AA': ["EVQLVESGGGLVQPGGSLRLSCAASGFNLYYYSIHWVRQAPGKGLEWVASISPYSSSTSYADSVKGRFTISADTSKNTAYLQMNSLRAEDTAVYYCARGRWYRRALDYWGQGTLVTVSS"],
    'LC_AA': ["DIQMTQSPSSLSASVGDRVTITCRASQSVSSAVAWYQQKPGKAPKLLIYSASSLYSGVPSRFSGSRSGTDFTLTISSLQPEDFATYYCQQYPYYSSLITFGQGTKVEIK"]
}

df_single = pd.DataFrame(example_data)
print("Example antibody:")
print(f"Heavy chain: {df_single['HC_AA'][0][:50]}...")
print(f"Light chain: {df_single['LC_AA'][0][:50]}...")
print(f"Heavy chain length: {len(df_single['HC_AA'][0])} amino acids")
print(f"Light chain length: {len(df_single['LC_AA'][0])} amino acids")

In [None]:
# Load tokenizers
print("Loading tokenizers for single sequence example...")
heavy_tokenizer = AutoTokenizer.from_pretrained(model_config.heavy_model_id, revision=model_config.heavy_revision)
light_tokenizer = AutoTokenizer.from_pretrained(model_config.light_model_id, revision=model_config.light_revision)

# Preprocess sequences (add spaces between amino acids)
df_single["PREPARED_HC_SEQ"] = df_single["HC_AA"].apply(lambda x: " ".join(list(x)))
df_single["PREPARED_LC_SEQ"] = df_single["LC_AA"].apply(lambda x: " ".join(list(x)))

# Tokenize sequences
h_tokens = heavy_tokenizer(df_single["PREPARED_HC_SEQ"].tolist(), padding='longest', return_tensors="pt")
l_tokens = light_tokenizer(df_single["PREPARED_LC_SEQ"].tolist(), padding='longest', return_tensors="pt")

print(f"Heavy chain tokens shape: {h_tokens['input_ids'].shape}")
print(f"Light chain tokens shape: {l_tokens['input_ids'].shape}")
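
The tokenizers receive space-separated residues. `tokenize_data` remaps unknown-token IDs to the mask token, but this single-antibody cell does not, so a quick pre-check for characters outside the 20 canonical amino acids tells you whether that matters for your input. This sketch assumes such characters are the ones the AbLang tokenizers treat as unknown; confirm against `heavy_tokenizer.get_vocab()` before relying on it. The helper names are hypothetical.

```python
CANONICAL_AA = set("ACDEFGHIKLMNPQRSTVWY")

def space_separate(seq):
    """Format a sequence the way this notebook does before tokenization."""
    return " ".join(seq)

def non_canonical_positions(seq):
    """Return (position, residue) pairs for characters outside the 20 canonical amino acids."""
    return [(i, aa) for i, aa in enumerate(seq) if aa not in CANONICAL_AA]

print(space_separate("EVQLV"))            # E V Q L V
print(non_canonical_positions("EVXLV*"))  # [(2, 'X'), (5, '*')]
```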

In [None]:
# Generate embedding
print("Generating embedding for single antibody...")
with torch.no_grad():
    embedding = model(
        h_input_ids=h_tokens['input_ids'].to(device),
        h_attention_mask=h_tokens['attention_mask'].to(device),
        l_input_ids=l_tokens['input_ids'].to(device),
        l_attention_mask=l_tokens['attention_mask'].to(device)
    )

print(f"\n✅ Generated embedding shape: {embedding.shape}")
print(f"📊 Embedding dimension: {embedding.shape[1]}")
print(f"🔢 First 5 embedding values: {embedding[0][:5].tolist()}")
print(f"📈 Embedding norm: {torch.norm(embedding[0]).item():.4f}")
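
The printed norm is the L2 norm of the embedding. Cosine similarity, used in section 6, divides the dot product of two vectors by both of their L2 norms, so it ignores overall magnitude; after L2-normalizing the vectors, a plain dot product gives the same value. A small standard-library sketch with toy two-dimensional vectors (not real embeddings):

```python
import math

def l2_norm(vec):
    """Euclidean (L2) norm: square root of the sum of squares."""
    return math.sqrt(sum(x * x for x in vec))

def l2_normalize(vec):
    """Scale a vector to unit L2 norm."""
    n = l2_norm(vec)
    return [x / n for x in vec]

def cosine(a, b):
    """Cosine similarity: dot(a, b) / (||a|| * ||b||)."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (l2_norm(a) * l2_norm(b))

a, b = [3.0, 4.0], [4.0, 3.0]
print(l2_norm(a))    # 5.0
print(cosine(a, b))  # 0.96
ua, ub = l2_normalize(a), l2_normalize(b)
# Dot product of the unit vectors matches cosine(a, b) up to floating-point rounding
print(sum(x * y for x, y in zip(ua, ub)))
```

In PyTorch, `torch.nn.functional.normalize(embedding, dim=1)` performs the same normalization on a batch of embeddings.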

## 5. Multiple Antibodies Example

Process multiple antibody sequences efficiently using batch processing.

In [None]:
# Create a dataset with multiple example antibodies
multi_antibody_data = {
    'HC_AA': [
        # Example 1: SARS-CoV-2 antibody
        "EVQLVESGGGLVQPGGSLRLSCAASGFNLYYYSIHWVRQAPGKGLEWVASISPYSSSTSYADSVKGRFTISADTSKNTAYLQMNSLRAEDTAVYYCARGRWYRRALDYWGQGTLVTVSS",
        # Example 2: Another antibody
        "QVQLVQSGAEVKKPGSSVKVSCKASGGTFSSYWIEWVRQAPGQGLEWMGIIYPILSEGSTKYYNEKFKDRATLSADTSTSTAYMELSSLTSEDTAVYYCARGGAYYGSGYYAMDYWGQGTLVTVSS",
        # Example 3: Third antibody
        "EVQLLESGGGLVQPGGSLRLSCAASGFTFSSYAMSWVRQAPGKGLEWVSAISGSGGSTYYADSVKGRFTISADTSKNTAYLQMNSLRAEDTAVYYCARYHGGDAMDYWGQGTLVTVSS",
        # Example 4: Fourth antibody
        "QVQLQQSGPGLVKPSQTLSLTCAISGDSVSSNSAAWNWIRQSPSRGLEWLGRTYYRSKWYNDYAVSVKSRITINPDTSKNQFSLQLNSVTPEDTAVYYCARYDILTGYCTNGVCYAMDYWGQGTLVTVSS"
    ],
    'LC_AA': [
        # Light chains corresponding to heavy chains above
        "DIQMTQSPSSLSASVGDRVTITCRASQSVSSAVAWYQQKPGKAPKLLIYSASSLYSGVPSRFSGSRSGTDFTLTISSLQPEDFATYYCQQYPYYSSLITFGQGTKVEIK",
        "EIVLTQSPGTLSLSPGERATLSCRASQSVSSSYLAWYQQKPGQAPRLLIYGASSRATGIPDRFSGSGSGTDFTLTISRLEPEDFAVYYCQQYGSSPLTFGAGTKVEIK",
        "DIQMTQSPSSLSASVGDRVTITCRASQSISSWLAWYQQKPGKAPKLLIYKASSLESGVPSRFSGSGSGTEFTLTISSLQPDDFATYYCQQYNSYSYTFGQGTKVEIK",
        "DIVMTQTPKFLLVSAGDRVTITCRASQGISSALAWYQQKPGQAPRLLIYDASSRATGIPARFSGSGSGTDFTLTISRLEPEDFAVYYCQQFNSYPLTFGAGTKLELK"
    ],
    'Antibody_ID': ['AB001', 'AB002', 'AB003', 'AB004'],
    'Description': [
        'SARS-CoV-2 neutralizing antibody',
        'Example antibody 2', 
        'Example antibody 3',
        'Example antibody 4'
    ]
}

df_multi = pd.DataFrame(multi_antibody_data)
print(f"Created dataset with {len(df_multi)} antibodies:")
for i, row in df_multi.iterrows():
    print(f"  {row['Antibody_ID']}: {row['Description']} (HC: {len(row['HC_AA'])} aa, LC: {len(row['LC_AA'])} aa)")

In [None]:
# Process multiple antibodies using batch processing
batch_size = 2  # Small batch size for this example

# Tokenize the dataset
tokenized_dataset = tokenize_data(df_multi, model_config)

# Create dataloader
dataloader = DataLoader(tokenized_dataset, batch_size=batch_size, shuffle=False)

# Generate embeddings
embeddings_batch = embed_dataloader(dataloader, model, device)

print(f"\n✅ Generated batch embeddings shape: {embeddings_batch.shape}")
print(f"📊 Number of antibodies processed: {embeddings_batch.shape[0]}")
print(f"📏 Embedding dimension: {embeddings_batch.shape[1]}")
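
`embed_dataloader` writes each batch into a preallocated tensor and advances its write index by the number of rows actually returned. That matters because `DataLoader` keeps a smaller final batch by default (`drop_last=False`): with 5 antibodies and `batch_size=2`, the batches cover rows 0-1, 2-3, and 4. The sketch below reproduces that bookkeeping with plain lists; `batch_slices` and `fill_preallocated` are hypothetical helpers.

```python
def batch_slices(num_items, batch_size):
    """Yield (start, stop) row ranges in the order an unshuffled DataLoader visits them."""
    for start in range(0, num_items, batch_size):
        yield start, min(start + batch_size, num_items)

def fill_preallocated(batches, num_items, dim):
    """Copy batch rows into one preallocated buffer, advancing by each batch's real size."""
    out = [[0.0] * dim for _ in range(num_items)]
    cursor = 0
    for batch in batches:
        for offset, row in enumerate(batch):
            out[cursor + offset] = row
        cursor += len(batch)  # not batch_size: the last batch may be smaller
    if cursor != num_items:
        raise ValueError("batches did not cover every row")
    return out

rows = [[float(i)] * 3 for i in range(5)]  # 5 toy "embeddings" of dimension 3
batches = [rows[s:e] for s, e in batch_slices(len(rows), 2)]
print(list(batch_slices(5, 2)))  # [(0, 2), (2, 4), (4, 5)]
filled = fill_preallocated(batches, len(rows), 3)
print(filled == rows)  # True
```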

## 6. Analyze Embeddings

Calculate similarities between antibody embeddings to see which ones might target similar epitopes.

In [None]:
# Calculate pairwise cosine similarities
print("Calculating pairwise similarities...")
similarities = torch.cosine_similarity(embeddings_batch.unsqueeze(1), embeddings_batch.unsqueeze(0), dim=2)

print("\n📊 Cosine similarity matrix:")
ids = df_multi['Antibody_ID'].tolist()
print(" " * 8, end="")  # blank corner, same width as the row labels below
for antibody_id in ids:
    print(f"{antibody_id:>8}", end="")
print()

for i in range(similarities.shape[0]):
    print(f"{ids[i]:>6}: ", end="")
    for j in range(similarities.shape[1]):
        print(f"{similarities[i][j].item():>8.3f}", end="")
    print()

# Rank antibody pairs by similarity (excluding self-similarity)
print("\n🔍 Antibody pairs, most similar first:")
pairs = []
for i in range(len(df_multi)):
    for j in range(i + 1, len(df_multi)):
        id1, id2 = df_multi.iloc[i]['Antibody_ID'], df_multi.iloc[j]['Antibody_ID']
        pairs.append((similarities[i][j].item(), id1, id2))
for sim, id1, id2 in sorted(pairs, reverse=True):
    print(f"  {id1} ↔ {id2}: {sim:.3f}")
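
To turn the matrix into a lookup, you might ask which other antibody is closest to each one. A minimal sketch on nested lists, with toy similarity values rather than real AbLangPDB1 output; `nearest_neighbors` is a hypothetical helper and needs at least two antibodies. On the real results, `similarities.tolist()` produces the nested-list input.

```python
def nearest_neighbors(sim, ids):
    """For each id, return (most similar other id, similarity score)."""
    result = {}
    for i, row in enumerate(sim):
        # Skip the diagonal so an antibody is never its own neighbor
        best_j = max((j for j in range(len(row)) if j != i), key=lambda j: row[j])
        result[ids[i]] = (ids[best_j], row[best_j])
    return result

ids = ["AB001", "AB002", "AB003"]
sim = [
    [1.00, 0.72, 0.91],
    [0.72, 1.00, 0.65],
    [0.91, 0.65, 1.00],
]  # toy values, not real AbLangPDB1 output
print(nearest_neighbors(sim, ids))
```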

## 7. Save Results

Save the embeddings and results for further analysis.

In [None]:
# Add embeddings to the dataframe (assumes tokenize_data kept every row, so rows line up)
df_multi['EMBEDDING'] = [emb.numpy() for emb in embeddings_batch]

# Save to various formats
print("Saving results...")

# Save as CSV, with each embedding serialized as a comma-separated string
df_results = df_multi.copy()
df_results['EMBEDDING_STR'] = df_results['EMBEDDING'].apply(lambda x: ','.join(map(str, x)))
df_results.drop('EMBEDDING', axis=1).to_csv('antibody_results.csv', index=False)
print("✅ Saved results to antibody_results.csv")

# Save as pickle (preserves numpy arrays)
df_multi.to_pickle('antibody_results.pkl')
print("✅ Saved results to antibody_results.pkl")

# Save similarity matrix
similarity_df = pd.DataFrame(
    similarities.numpy(), 
    index=df_multi['Antibody_ID'], 
    columns=df_multi['Antibody_ID']
)
similarity_df.to_csv('antibody_similarities.csv')
print("✅ Saved similarity matrix to antibody_similarities.csv")

print("\n📁 Generated files:")
print("  • antibody_results.csv - Sequences, metadata, and embeddings (as comma-separated strings)")
print("  • antibody_results.pkl - Complete results with embeddings")
print("  • antibody_similarities.csv - Similarity matrix")
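
To use the embeddings from `antibody_results.csv` later, parse the `EMBEDDING_STR` column back into numbers. The sketch below shows the round trip with plain Python floats; `embedding_to_str` and `str_to_embedding` are hypothetical helpers. With pandas this might look like `pd.read_csv('antibody_results.csv')['EMBEDDING_STR'].apply(str_to_embedding)`. The pickle file avoids the parsing step entirely.

```python
def embedding_to_str(vec):
    """Serialize like this notebook does: comma-joined str() of each value."""
    return ",".join(map(str, vec))

def str_to_embedding(text):
    """Inverse of embedding_to_str, returning a list of Python floats."""
    return [float(x) for x in text.split(",")]

vec = [0.125, -1.5, 3.0e-5]
text = embedding_to_str(vec)
print(text)  # 0.125,-1.5,3e-05
# str() of a Python float is its shortest round-tripping repr, so parsing recovers it exactly
print(str_to_embedding(text) == vec)  # True
```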

## 8. Usage Tips

### Performance Optimization
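- Use larger batch sizes (e.g., 256-512) for better GPU utilization, as long as they fit in GPU memory
- Group sequences of similar length to reduce padding. `tokenize_data` pads every sequence to the longest one in the DataFrame it receives, so this only helps if you tokenize each length group separately
- `torch.cuda.empty_cache()` returns cached, unused GPU memory to the driver; it does not shrink memory held by live tensors, so reducing the batch size is usually the more effective fix for out-of-memory errors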
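
Because `tokenize_data` uses `padding='longest'` over the whole DataFrame it receives, length grouping pays off only if you split the data yourself and call `tokenize_data` once per group. The sketch below, under that assumption, counts how many padded positions one chain type would occupy with a single global pad versus per-group pads. `length_buckets` and `padded_positions` are hypothetical helpers and the sequences are placeholders. Since sorting reorders the antibodies, keep the index lists so you can put the embeddings back in their original order.

```python
def length_buckets(seqs, chunk_size):
    """Group sequence indices into chunks of similar length (sorted by length)."""
    order = sorted(range(len(seqs)), key=lambda i: len(seqs[i]))
    return [order[i:i + chunk_size] for i in range(0, len(order), chunk_size)]

def padded_positions(seqs, chunks):
    """Total positions after padding each chunk to its own longest sequence."""
    return sum(len(chunk) * max(len(seqs[i]) for i in chunk) for chunk in chunks)

seqs = ["A" * 110, "A" * 130, "A" * 112, "A" * 128]  # placeholder chains
global_pad = len(seqs) * max(map(len, seqs))  # everything padded to the single longest chain
bucketed = padded_positions(seqs, length_buckets(seqs, 2))
print(global_pad, bucketed)  # 520 484
```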

### Interpretation
Treat these cosine-similarity cutoffs as rough guides:
- **High similarity (>0.8)**: likely to target similar epitopes
- **Medium similarity (0.6-0.8)**: may share some epitope characteristics
- **Low similarity (<0.6)**: likely to target different epitopes
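
As a small illustration, the bands above can be written as a function. The cutoffs are the notebook's rules of thumb, and placing exactly 0.8 and exactly 0.6 in the medium band is this sketch's own choice; `similarity_band` is a hypothetical helper.

```python
def similarity_band(sim):
    """Map a cosine similarity onto the rule-of-thumb interpretation bands."""
    if sim > 0.8:
        return "high: likely to target similar epitopes"
    if sim >= 0.6:
        return "medium: may share some epitope characteristics"
    return "low: likely to target different epitopes"

for score in (0.91, 0.8, 0.65, 0.42):
    print(f"{score:.2f} -> {similarity_band(score)}")
```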

### Common Issues
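- **Sequence length**: AbLang works best with chains shorter than 157 amino acids, and `tokenize_data` drops any pair in which either chain has 157 or more
- **Stop codons**: `tokenize_data` also drops pairs containing '*' characters; filter these out beforehand if you need the embeddings to line up with your original rows
- **Memory**: Reduce the batch size if you encounter GPU out-of-memory errors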

### Next Steps
- Explore the benchmarking suite in the `/benchmarking/` directory
- See the main README.md for more applications
- Check out the paper for technical details: https://doi.org/10.1101/2025.02.25.640114