In [1]:
!pip install antiberty
!pip install pandas torch

Looking in links: /path/to/your/local/wheel/directory
Looking in links: /path/to/your/local/wheel/directory


In [1]:
import os
import pandas as pd
import torch
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
from antiberty import AntiBERTyRunner
from tqdm.notebook import tqdm  # Classic terminal/CLI progress bar
import pickle
import numpy as np
input_dir = "../cdr3_outputs/h/"
output_dir = "../cdr3_outputs/embedded/"

os.makedirs(output_dir, exist_ok=True)


In [None]:

runner = AntiBERTyRunner()
runner.model.eval()
if torch.cuda.is_available():
    runner.model.cuda()


def embed_unique_cdr3s(cdr3_list):
    seen = {}
    #gpu mem usage ~batch_size^2 KB
    batch_size = 1500
    total_batches = (len(cdr3_list) + batch_size - 1) // batch_size

    with tqdm(total=total_batches, desc="🧬 Embedding unique CDR3s") as pbar:
        for i in range(0, len(cdr3_list), batch_size):
            batch = cdr3_list[i:i+batch_size]
            to_embed = [cdr3 for cdr3 in batch if cdr3 not in seen]
            if not to_embed:
                pbar.update(1)
                continue
            
            with torch.no_grad():
                reps = runner.embed(to_embed)
            for cdr3, rep in zip(to_embed, reps):
                pooled = rep[1:-1].mean(dim=0).cpu().numpy()
                seen[cdr3] = pooled
            pbar.update(1)
    return seen


In [None]:
# Loop 1: embed and save embeddings as .pkl
tsv_files = [f for f in os.listdir(input_dir) if f.endswith(".csv")] 
for filename in tqdm(tsv_files, total=len(tsv_files), desc="📄 Embedding + Pickle Save"):
    file_path = os.path.join(input_dir, filename)
    try:
        df = pd.read_csv(file_path)
    except Exception as e:
        print(f"Failed to read {filename}: {e}")
        continue

    if 'cdr3_aa' not in df.columns or 'stage' not in df.columns:
        print(f"[!] Skipping {filename} - missing required columns.")
        continue

    unique_cdr3s_in_file = list(set(df['cdr3_aa'].dropna()))
    embeddings = embed_unique_cdr3s(unique_cdr3s_in_file)

    # Save the embeddings as a pickle file
    pkl_name = filename.replace(".csv", "_embeddings.pkl")
    pkl_path = os.path.join(output_dir, pkl_name)
    with open(pkl_path, "wb") as f:
        pickle.dump(embeddings, f)

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    import gc

📄 Embedding + Pickle Save:   0%|          | 0/1 [00:00<?, ?it/s]

🧬 Embedding unique CDR3s:   0%|          | 0/99 [00:00<?, ?it/s]

In [3]:
# Loop 2: load pickles and write full CSVs
tsv_files = [f for f in os.listdir(input_dir) if f.endswith(".csv")] 
for filename in tqdm(tsv_files, total=len(tsv_files), desc="📄 Writing CSVs"):
    file_path = os.path.join(input_dir, filename)
    try:
        print("1")
        df = pd.read_csv(file_path)
    except Exception as e:
        print(f"Failed to read {filename}: {e}")
        continue
        
    if 'cdr3_aa' not in df.columns or 'stage' not in df.columns:
        print(f"[!] Skipping {filename} - missing required columns.")
        continue
    print("1")
    pkl_name = filename.replace(".csv", "_embeddings.pkl")
    pkl_path = os.path.join(output_dir, pkl_name)
    if not os.path.exists(pkl_path):
        print(f"[!] Missing pickle for {filename}, skipping")
        continue
    print("1")
    with open(pkl_path, "rb") as f:
        embeddings = pickle.load(f)
    print("1")
    emb_rows = []
    for idx, cdr3 in enumerate(set(df['cdr3_aa'].dropna())):
        if cdr3 not in embeddings:
            continue
        emb = embeddings[cdr3]

        if not isinstance(emb, np.ndarray):
            print(f"[!] {cdr3} embedding is not ndarray, skipping")
            continue
        if emb.shape != (512,):
            print(f"[!] {cdr3} has wrong shape: {emb.shape}, skipping")
            continue

        emb_rows.append([cdr3] + emb.tolist())

        if idx % 100 == 0:
            print(f"  processed {idx} CDR3s")

    print("1")
    output_path = os.path.join(output_dir, filename.replace(".csv", "_with_antiberty.csv"))
    chunk_size = 10000
    for i in range(0, len(emb_rows), chunk_size):
        chunk = emb_rows[i:i + chunk_size]
        emb_df_chunk = pd.DataFrame(chunk, columns=['cdr3_aa'] + [f'dim{i}' for i in range(512)])
        merged_chunk = pd.merge(df[['cdr3_aa', 'stage']], emb_df_chunk, on='cdr3_aa', how='inner')

        mode = 'w' if i == 0 else 'a'
        header = (i == 0)
        merged_chunk.to_csv(output_path, mode=mode, header=header, index=False)
        print(f"✅ Wrote rows {i} to {i+chunk_size}")


    print("1")



📄 Writing CSVs:   0%|          | 0/1 [00:00<?, ?it/s]

1
1
1
1
  processed 0 CDR3s
  processed 100 CDR3s
  processed 200 CDR3s
  processed 300 CDR3s
  processed 400 CDR3s
  processed 500 CDR3s
  processed 600 CDR3s
  processed 700 CDR3s
  processed 800 CDR3s
  processed 900 CDR3s
  processed 1000 CDR3s
  processed 1100 CDR3s
  processed 1200 CDR3s
  processed 1300 CDR3s
  processed 1400 CDR3s
  processed 1500 CDR3s
  processed 1600 CDR3s
  processed 1700 CDR3s
  processed 1800 CDR3s
  processed 1900 CDR3s
  processed 2000 CDR3s
  processed 2100 CDR3s
  processed 2200 CDR3s
  processed 2300 CDR3s
  processed 2400 CDR3s
  processed 2500 CDR3s
  processed 2600 CDR3s
  processed 2700 CDR3s
  processed 2800 CDR3s
  processed 2900 CDR3s
  processed 3000 CDR3s
  processed 3100 CDR3s
  processed 3200 CDR3s
  processed 3300 CDR3s
  processed 3400 CDR3s
  processed 3500 CDR3s
  processed 3600 CDR3s
  processed 3700 CDR3s
  processed 3800 CDR3s
  processed 3900 CDR3s
  processed 4000 CDR3s
  processed 4100 CDR3s
  processed 4200 CDR3s
  processed 430

In [19]:
import gc
import torch
gc.collect()
torch.cuda.empty_cache()