In [None]:
############ Imports #########################
# Standard library imports
from time import time

# Data processing imports
import pandas as pd
from tqdm import tqdm

# PyTorch imports
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hugging Face Transformers imports
from transformers import AutoTokenizer

# Local import
from ablangpaired_model import AbLangPairedConfig, AbLangPaired

# Pick the compute device once: CUDA when PyTorch can see a GPU, CPU otherwise
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
############ Helper Methods (run me) ####################
def _encode_chain(tokenizer, spaced_sequences, chain_label):
    """Tokenize one chain's space-separated sequences and print how long it took."""
    print(f"About to encode {chain_label}")
    start = time()
    tokens = tokenizer.batch_encode_plus(
        spaced_sequences,
        add_special_tokens=True,
        padding='longest',
        return_tensors="pt",
        truncation=True,
        return_special_tokens_mask=True
    )
    print(f"That took {time() - start} seconds")
    return tokens


def _replace_token(tokens, old_id, new_id):
    """Rewrite `old_id` to `new_id` in-place and switch attention off at those positions."""
    hits = tokens['input_ids'] == old_id
    tokens['input_ids'][hits] = new_id
    tokens['attention_mask'][hits] = False


def tokenize_data(df: pd.DataFrame, model_config: AbLangPairedConfig) -> TensorDataset:
    """
    Prepare paired antibody sequences for input to the embedding model.

    Rows whose heavy or light chain is too long, or contains a stop codon ("*"),
    are dropped before tokenization, so the returned dataset can have fewer rows
    than `df`. The order of the remaining rows is preserved. `df` itself is not
    modified.

    Args:
        df: DataFrame containing antibody sequences with HC_AA and LC_AA columns
        model_config: AbLangPairedConfig that tells where to load the heavy and light tokenizers from

    Returns:
        TensorDataset of (heavy input_ids, light input_ids, heavy attention_mask,
        light attention_mask); ids are stored as int16 and masks as bool
    """
    # Configure sequence length limit: tokenizer window minus the tokens that are
    # always added, minus one (sequences must be strictly shorter than max_length)
    max_tokenizer_length = 160
    num_tokens_always_added = 2
    max_length = max_tokenizer_length - num_tokens_always_added - 1

    # Token ids taken from the original implementation's comments:
    # 24 is described there as [UNK] and 23 as [MASK] -- TODO confirm against the tokenizer vocab
    unk_token_id = 24
    mask_token_id = 23

    def _is_embeddable(aa):
        # Keep only sequences that are short enough and contain no stop codon
        return len(aa) < max_length and "*" not in aa

    # Filter out sequences that are too long or contain stop codons
    keep_rows = df["HC_AA"].apply(_is_embeddable) & df["LC_AA"].apply(_is_embeddable)
    df = df[keep_rows]

    # Load tokenizers for heavy and light chains
    heavy_tokenizer = AutoTokenizer.from_pretrained(model_config.heavy_model_id, revision=model_config.heavy_revision)
    light_tokenizer = AutoTokenizer.from_pretrained(model_config.light_model_id, revision=model_config.light_revision)

    # Format sequences for tokenization (add spaces between amino acids).
    # Built as plain lists rather than new columns: writing columns onto the
    # filtered frame triggers pandas' SettingWithCopyWarning.
    heavy_sequences = [" ".join(aa) for aa in df["HC_AA"]]
    light_sequences = [" ".join(aa) for aa in df["LC_AA"]]

    # Tokenize heavy and light chain sequences
    h_train_tokens = _encode_chain(heavy_tokenizer, heavy_sequences, "heavies")
    l_train_tokens = _encode_chain(light_tokenizer, light_sequences, "lights")

    # The original author observed errors caused by the [UNK] token; convert it to
    # a [MASK] token (which AbLang handles well) and exclude it from attention
    _replace_token(h_train_tokens, unk_token_id, mask_token_id)
    _replace_token(l_train_tokens, unk_token_id, mask_token_id)

    # Create TensorDataset for model input
    dataset = TensorDataset(
        h_train_tokens['input_ids'].to(torch.int16),
        l_train_tokens['input_ids'].to(torch.int16),
        h_train_tokens['attention_mask'].to(torch.bool),
        l_train_tokens['attention_mask'].to(torch.bool)
    )

    return dataset


def embed_dataloader(dataloader, model, device) -> torch.Tensor:
    """
    Generate embeddings for all antibodies in the dataloader.

    The embedding width is taken from the model's first output batch, so models
    with any output dimension are supported. If the dataloader yields no batch,
    a zero tensor with the historical width of 1536 is returned.

    Args:
        dataloader: DataLoader yielding (heavy ids, light ids, heavy mask, light mask) batches
        model: Trained Embedder model, called with h_input_ids / h_attention_mask /
            l_input_ids / l_attention_mask keyword arguments
        device: Device to run inference on (CPU or GPU)

    Returns:
        float32 CPU tensor with one row per dataset item, in dataloader order
    """
    model.to(device)
    model.eval()

    num_embeddings = len(dataloader.dataset)
    fallback_embedding_dim = 1536  # only used when no batch is produced
    all_embeds = None  # allocated once the first batch reveals the embedding width

    # Generate embeddings batch by batch
    next_row = 0
    print("Now Embedding Antibodies")
    with torch.no_grad():
        for htoks, ltoks, hmasks, lmasks in tqdm(dataloader):
            # Move tensors to device
            htoks = htoks.to(device)
            hmasks = hmasks.to(device)
            ltoks = ltoks.to(device)
            lmasks = lmasks.to(device)

            # Forward pass to get embeddings
            embeds = model(
                h_input_ids=htoks,
                h_attention_mask=hmasks,
                l_input_ids=ltoks,
                l_attention_mask=lmasks
            )

            if all_embeds is None:
                # Preallocate the full result now that the embedding width is known
                all_embeds = torch.zeros((num_embeddings, embeds.size(1)), dtype=torch.float32)

            # Store embeddings in preallocated tensor
            batch_size = embeds.size(0)
            all_embeds[next_row:next_row + batch_size] = embeds.detach().cpu()
            next_row += batch_size

            # Clean up GPU memory
            del htoks, hmasks, ltoks, lmasks, embeds
            torch.cuda.empty_cache()

    if all_embeds is None:
        # No batch was produced: keep the shape the fixed-width implementation returned
        all_embeds = torch.zeros((num_embeddings, fallback_embedding_dim), dtype=torch.float32)

    return all_embeds

# AbLangPDB Inference

In [2]:
# Load the reference dataset; its EMBEDDING column holds the stored embeddings
# that the comparison cells below check the freshly computed embeddings against.
# NOTE(review): unpickling can execute arbitrary code -- only load this file from a trusted source.
df_old = pd.read_pickle("paper_pdb_dataset.pd")
df_old.head()

Unnamed: 0,Column1,AG_CLUSTER,AG_AA,CLAN,PFAM_PLUS,NAME_x,HV,HJ,CDRH3,HC_AA,...,NN_GUESS_PFAM,NN_GUESS_COS_SIM,NN_GUESS_ACTUAL_LABEL,NN_GUESS_RANK,CORRECT_NN_EP_GUESS,CORRECT_NN_PFAM_GUESS,PLOTTING_AG,t-SNE_1,t-SNE_2,EMBEDDING
0,1,0551_0777_1026_1171,MDLTVEPNLHSLITSTTHKWIFVGGKGGVGKTTSSCSIAIQMALSQ...,CL0023,ArsA_ATPase,4XWO_C_D_A,IGHV3-74,IGHJ4,ARGRWYRRALDY,EVQLVESGGGLVQPGGSLRLSCAASGFNLYYYSIHWVRQAPGKGLE...,...,Paxillin,0.83395,-1.0,539.0,False,False,Other,-24.314289,-52.9277,"[0.030508561059832573, 0.004053947050124407, -..."
1,9,0519_0739_0984_1128,MQPTGREGSRALSRRYLRRLLLLLLLLLLRQPVTRAETTPGAPRAL...,CL0036,Peptidase_M19,6VGR_C_D_A,IGHV1-69,IGHJ4,ARRAAAYYSNPEWFAY,QVQLVQSGAEVKKPGSSVKVSCKASGGTFSSYWIEWVRQAPGQGLE...,...,GP41,0.552818,-1.0,861.0,True,False,Other,-26.720898,15.49514,"[-0.030367737635970116, -0.004938304889947176,..."
2,11,0478_0695_0936_1076,ENAIKKTKNQENQLTLLPIKSTEEEKDDIKNGKDIKKEIDNDKENI...,,Rh5,6RCU_B_C_A,IGHV1-69,IGHJ3,ARDKHSWSYAFDI,EVQLVQSGAEVKKPGSSVKVSCKASGGTFSNYAINWVRQAPGQGLE...,...,HRM,0.549503,-1.0,388.0,False,False,Other,-17.266788,18.418531,"[-0.04985538497567177, -0.01655494049191475, 0..."
3,17,0527_0747_0992_1136,ASWSHPQFEKSGGGGGLVPRGSGIQDLSDNYENLSKLLTRYSTLNT...,,SabA_adhesion,7ZQT_C_D_A,IGHV3-48,IGHJ1,ARLNGWAGSGLDH,QVQLVQSGGGIGQPGGSLRLACEASGFTFNLFEMAWVRQAPGQSLE...,...,Bet_v_1,0.355318,-1.0,225.0,True,False,Other,3.680409,-11.97096,"[-0.038855042308568954, -0.007398952264338732,..."
4,19,0092_0170_0301_0355,ARGTNVTRECCLEYFKGAIPLRKLKTWYQTSEDCSRDAIVFVTVQG...,CL0730,IL8,5WK3_Q_P_A,IGHV5-51,IGHJ4,ARVGPADVWDSFDY,EVQLVQSGAEVKKPGESLKISCKGSGYSFTSYWIGWVRQMPGKGLE...,...,Tubulin-binding,0.437684,-1.0,674.0,False,False,Other,14.008639,6.846615,"[-0.041530925780534744, -0.031332969665527344,..."


# Set up Model

In [3]:
modelf = "ablangpdb_model.safetensors" # Path to model checkpoint
# By default, AbLangPairedConfig pulls AbLang_heavy and AbLang_light from Hugging Face (downloading them).
# You could instead specify a local path
model_config = AbLangPairedConfig(checkpoint_filename=modelf)
model = AbLangPaired(model_config, device)



# Embed 1 Antibody

## Tokenize the 1 sequence

In [5]:
# Define an example antibody
# One example antibody: a paired heavy chain (HC_AA) and light chain (LC_AA)
data = {
    'HC_AA': ["EVQLVESGGGLVQPGGSLRLSCAASGFNLYYYSIHWVRQAPGKGLEWVASISPYSSSTSYADSVKGRFTISADTSKNTAYLQMNSLRAEDTAVYYCARGRWYRRALDYWGQGTLVTVSS"],
    'LC_AA': ["DIQMTQSPSSLSASVGDRVTITCRASQSVSSAVAWYQQKPGKAPKLLIYSASSLYSGVPSRFSGSRSGTDFTLTISSLQPEDFATYYCQQYPYYSSLITFGQGTKVEIK"]
}

# The tokenizers expect residues separated by spaces, so add one prepared column per chain
df = pd.DataFrame(data).assign(
    PREPARED_HC_SEQ=lambda frame: frame["HC_AA"].map(" ".join),
    PREPARED_LC_SEQ=lambda frame: frame["LC_AA"].map(" ".join),
)

# Each chain has its own tokenizer, located through the model config
heavy_tokenizer = AutoTokenizer.from_pretrained(
    model_config.heavy_model_id,
    revision=model_config.heavy_revision,
)
light_tokenizer = AutoTokenizer.from_pretrained(
    model_config.light_model_id,
    revision=model_config.light_revision,
)

# Tokenize to PyTorch tensors, padding to the longest sequence in the batch
h_tokens = heavy_tokenizer(list(df["PREPARED_HC_SEQ"]), return_tensors="pt", padding='longest')
l_tokens = light_tokenizer(list(df["PREPARED_LC_SEQ"]), return_tensors="pt", padding='longest')



## Embed it

In [7]:
# Inference only: run the forward pass without tracking gradients
with torch.no_grad():
    embedding = model(
        l_attention_mask=l_tokens['attention_mask'].to(device),
        l_input_ids=l_tokens['input_ids'].to(device),
        h_attention_mask=h_tokens['attention_mask'].to(device),
        h_input_ids=h_tokens['input_ids'].to(device),
    )

# Stored embedding for the row labelled 0 in the reference dataset
# (presumably the same antibody as the example above -- verify if the dataset changes)
correct_embedding = torch.tensor(df_old.loc[0, "EMBEDDING"])

# Compare within an absolute tolerance of 1e-6, with both tensors on the CPU
if not torch.allclose(correct_embedding.cpu(), embedding[0].cpu(), atol=1e-6):
    print("\n🚨 FAILURE! Embeddings do not match.")
else:
    print("\n🎉 SUCCESS! Embeddings are identical. The model was loaded correctly.")


🎉 SUCCESS! Embeddings are identical. The model was loaded correctly.


# Embed many antibodies

In [5]:
# Gather every stored reference embedding into one tensor
# (the bulk-comparison cell further down rebuilds this from the filtered rows only)
all_old_embeddings = torch.tensor(df_old["EMBEDDING"].tolist())

In [13]:
df = pd.read_pickle("paper_pdb_dataset.pd")
batch_size = 256

# Tokenize (rows with over-long chains or stop codons are dropped inside tokenize_data),
# batch in the original order, then embed
tokenized_dataset = tokenize_data(df, model_config)
dataloader = DataLoader(tokenized_dataset, shuffle=False, batch_size=batch_size)
all_embeds = embed_dataloader(dataloader, model, device)

print("\n🔎 Comparing embeddings...")
# Apply the same row filter as `tokenize_data` so the reference rows line up
# one-to-one with the freshly computed embeddings
df_old2 = df_old[
    df_old["HC_AA"].apply(lambda aa: len(aa) < 157 and "*" not in aa)
    & df_old["LC_AA"].apply(lambda aa: len(aa) < 157 and "*" not in aa)
]
all_old_embeddings = torch.tensor(df_old2["EMBEDDING"].tolist())

# Compare within an absolute tolerance of 1e-6
if not torch.allclose(all_old_embeddings.cpu(), all_embeds.cpu(), atol=1e-6):
    print("\n🚨 FAILURE! Embeddings do not match.")
else:
    print("\n🎉 SUCCESS! Embeddings are identical. The model was loaded correctly.")



About to encode heavies
That took 0.31325197219848633 seconds
About to encode lights
That took 0.21178865432739258 seconds
Now Embedding Antibodies


100%|██████████| 8/8 [00:26<00:00,  3.26s/it]




🔎 Comparing embeddings...

🎉 SUCCESS! Embeddings are identical. The model was loaded correctly.
