In [None]:
# Run using docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:25.08-py3
!pip install transformers

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
import csv
import json
from typing import Any, Dict, Optional, Tuple, List, Set
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence
import numpy as np
from tqdm import tqdm
from collections import defaultdict
import os

In [None]:
# Constants
# Input TSV handed to parse_test_tsv(); rows hold residue IDs such as 3JRN_LYS_8.
DATA_FILE = "test.tsv"
# Path the final torch.save() call writes the embedding payload to.
OUTPUT_FILE = "esm2_embeddings_test.pt"
# Model identifier passed to AutoTokenizer / AutoModelForMaskedLM.from_pretrained().
MODEL_NAME = "nvidia/esm2_t48_15B_UR50D"

In [None]:
def parse_test_tsv(file_path):
    """Rebuild one-letter protein sequences from a TSV of per-residue IDs.

    The first column of every data row is expected to look like
    ``<pdb_id>_<3-letter residue>_<position>`` (e.g. ``3JRN_LYS_8``). Rows are
    grouped by ``pdb_id``, ordered by integer position, and the residues are
    concatenated into a single string. Residue names that are not one of the
    20 canonical amino acids become ``'X'``.

    Blank rows are ignored. Rows whose ID does not split into exactly three
    ``_``-separated parts, or whose position is not an integer, are skipped
    and counted; the count is printed when non-zero. An empty file (no header)
    yields an empty list.

    Args:
        file_path: Path to the tab-separated file. The first row is treated
            as a header and discarded.

    Returns:
        A list of ``{"chain_id": pdb_id, "input": sequence}`` dicts, one per
        distinct ``pdb_id``, in order of first appearance in the file.
    """
    # 3-letter to 1-letter map
    aa_map = {
        'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F',
        'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L',
        'MET': 'M', 'ASN': 'N', 'PRO': 'P', 'GLN': 'Q', 'ARG': 'R',
        'SER': 'S', 'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y',
    }

    sequences = defaultdict(list)
    skipped_rows = 0

    print(f"Reading {file_path}...")
    # newline='' lets the csv module do its own line-ending handling.
    with open(file_path, 'r', newline='') as f:
        reader = csv.reader(f, delimiter='\t')
        # Skip header "id"; the default keeps an empty file from raising StopIteration.
        next(reader, None)

        for row in reader:
            if not row:
                continue
            # Format: 3JRN_LYS_8
            parts = row[0].split('_')
            if len(parts) != 3:
                skipped_rows += 1
                continue

            pdb_id, aa_3, pos_text = parts
            try:
                pos = int(pos_text)
            except ValueError:
                # Treat a non-numeric position like any other malformed row.
                skipped_rows += 1
                continue

            sequences[pdb_id].append((pos, aa_map.get(aa_3, 'X')))

    if skipped_rows:
        print(f"Skipped {skipped_rows} malformed rows.")

    # Reconstruct sequences
    final_data = []
    print(f"Reconstructing sequences for {len(sequences)} proteins...")
    for pdb_id, residue_list in sequences.items():
        # Sort by position only; the stable sort keeps file order for repeated positions.
        residue_list.sort(key=lambda x: x[0])
        seq = "".join(aa_1 for _, aa_1 in residue_list)
        final_data.append({"chain_id": pdb_id, "input": seq})

    return final_data

In [None]:
class ProteinTestDataset(Dataset):
    """Dataset that turns sequence records into token-ID tensors.

    Each record is a dict with an ``'input'`` string and a ``'chain_id'``.
    Characters of ``'input'`` are looked up in ``input_vocab``; characters
    absent from the vocabulary fall back to the ``"<UNK>"`` entry.
    """

    def __init__(self, data: List[Dict], input_vocab: Dict[str, int]):
        self.input_vocab = input_vocab
        self.data = data

    def __len__(self):
        """Number of records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(token_ids, chain_id)`` for record ``idx``; IDs are a 1-D long tensor."""
        record = self.data[idx]
        vocab = self.input_vocab
        fallback_id = vocab.get("<UNK>")

        token_ids = []
        for residue in record['input']:
            token_ids.append(vocab.get(residue, fallback_id))

        return torch.tensor(token_ids, dtype=torch.long), record['chain_id']

In [None]:
# Load the tokenizer first, then pick the device and load the model onto it.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Prefer the GPU when one is visible to torch; otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = model.to(device)

# Inference only: switch to eval mode and ask the config for per-layer hidden states.
model.eval()
model.config.output_hidden_states = True

In [None]:
# Prepare Vocabulary
CANONICAL_AA = set("ACDEFGHIKLMNPQRSTVWY")
pad_token_id = tokenizer.pad_token_id
unk_token_id = tokenizer.unk_token_id
cls_token_id = tokenizer.cls_token_id
eos_token_id = tokenizer.eos_token_id

# pad is used for batch padding and the attention mask; cls/eos are wrapped
# around every sequence in collate_fn. A None here would otherwise only fail
# later inside torch.tensor(), so check all three up front.
missing_special_tokens = [
    token_name
    for token_name, special_id in (
        ("pad", pad_token_id),
        ("cls", cls_token_id),
        ("eos", eos_token_id),
    )
    if special_id is None
]
if missing_special_tokens:
    raise ValueError(
        "Tokenizer is missing required special tokens: "
        + ", ".join(missing_special_tokens)
    )

# Map each one-letter residue (plus the "X" placeholder) to its tokenizer ID.
input_chars = sorted(CANONICAL_AA.union({"X"}))
input_vocab = {"<PAD>": pad_token_id, "<UNK>": unk_token_id}
for char in input_chars:
    token_id = tokenizer.convert_tokens_to_ids(char)
    if token_id is None:
        # No ID available for this residue: fall back to the unknown token.
        token_id = unk_token_id
    input_vocab[char] = token_id

# Load Data
data = parse_test_tsv(DATA_FILE)
dataset = ProteinTestDataset(data, input_vocab)

BATCH_SIZE = 1 # Keep small for safety on big model
input_pad_idx = input_vocab["<PAD>"]

def collate_fn(batch):
    """Collate ``(token_ids, chain_id)`` pairs into one padded batch.

    Every sequence is wrapped as ``[cls] + tokens + [eos]`` and the wrapped
    sequences are right-padded with ``input_pad_idx`` to a common length.

    Returns:
        ``(padded_inputs, lengths, chain_ids)`` where ``padded_inputs`` is a
        batch-first long tensor, ``lengths`` holds the original lengths
        (special tokens not counted), and ``chain_ids`` is a tuple of the
        chain IDs in batch order.
    """
    input_seqs, chain_ids = zip(*batch)

    # The special-token tensors are identical for every sequence; build them once.
    cls_tok = torch.tensor([cls_token_id], dtype=torch.long)
    eos_tok = torch.tensor([eos_token_id], dtype=torch.long)
    framed = [torch.cat([cls_tok, residues, eos_tok]) for residues in input_seqs]

    lengths = torch.tensor([residues.size(0) for residues in input_seqs], dtype=torch.long)
    padded_inputs = pad_sequence(framed, batch_first=True, padding_value=input_pad_idx)
    return padded_inputs, lengths, chain_ids

# shuffle=False: batches are produced in dataset order.
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=collate_fn,
)

In [None]:
# Parallel lists: entry i of each list describes the same chain.
all_chain_ids = []
all_lengths = []
all_embeddings = []

print("Starting embedding generation...")
with torch.no_grad():
    for batch_inputs, batch_lengths, batch_chain_ids in tqdm(dataloader, desc="Generating embeddings"):
        # Padding positions are masked out of attention.
        attention_mask = (batch_inputs != input_pad_idx).long().to(device)
        batch_inputs = batch_inputs.to(device)

        # Request hidden states on the call itself so this cell does not
        # depend on a config flag mutated in another cell.
        outputs = model(
            input_ids=batch_inputs,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        hidden_states = getattr(outputs, "hidden_states", None)
        if hidden_states is None:
            raise RuntimeError(
                "Model did not return hidden states; cannot extract embeddings."
            )
        # Last layer: (batch_size, seq_len_with_specials, hidden_dim)
        hidden = hidden_states[-1]

        batch_size = hidden.size(0)

        for b in range(batch_size):
            seq_len = batch_lengths[b].item()

            # Strip CLS (index 0) and EOS (index seq_len + 1); trailing padding
            # falls outside the slice as well. Result is (seq_len, hidden_dim).
            seq_hidden = hidden[b, 1:seq_len + 1, :]
            # Copy to CPU so only the per-residue rows are kept in host memory.
            seq_hidden_cpu = seq_hidden.detach().cpu().contiguous()

            all_chain_ids.append(batch_chain_ids[b])
            all_lengths.append(seq_len)
            all_embeddings.append(seq_hidden_cpu)

# Persist everything in one torch.save() payload: chain IDs, per-chain
# lengths (int32 tensor) and the list of per-chain embedding tensors.
print(f"Saving {len(all_chain_ids)} embeddings to {OUTPUT_FILE}...")
lengths_tensor = torch.tensor(all_lengths, dtype=torch.int32)
obj = dict(
    chain_ids=all_chain_ids,
    lengths=lengths_tensor,
    embeddings=all_embeddings,
)
torch.save(obj, OUTPUT_FILE)
print("Done!")