In [None]:
# Run using docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:25.08-py3
!pip install transformers

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
import csv
import json
from typing import Any, Dict, Optional, Tuple, List, Set
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence
import numpy as np
from tqdm import tqdm

In [None]:
# Constants
# The 20 canonical amino-acid one-letter codes; any other residue letter is
# mapped to "X" when the data is loaded.
CANONICAL_AA: Set[str] = {*"ACDEFGHIKLMNPQRSTVWY"}
# Input CSV read by the data-loading cell.
DATA_FILE = "ps4_data.csv"
# NOTE(review): not referenced by the visible cells (embeddings go to .pt shards).
OUTPUT_FILE = "esm2_embeddings.csv"

In [None]:
# Load ESM2 model and tokenizer
model_name = "nvidia/esm2_t48_15B_UR50D"

# Use the GPU when one is visible, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained(model_name)

# trust_remote_code=True lets the model repository's own Python code run locally.
model = AutoModelForMaskedLM.from_pretrained(model_name, trust_remote_code=True)
model = model.to(device)

# Inference only: switch to eval mode and ask the forward pass to return
# the per-layer hidden states that the embedding loop reads.
model.eval()
model.config.output_hidden_states = True


In [34]:
class ProteinDataset(Dataset):
    """Map-style dataset that encodes sequence/label strings as index tensors.

    Each record is a dict holding an 'input' string and a 'dssp8' string;
    every character is looked up in the corresponding vocabulary.
    """

    def __init__(self, data: List[Dict], input_vocab: Dict[str, int], target_vocab: Dict[str, int]):
        """Keep references to the records and to both character-to-index maps."""
        self.data = data
        self.input_vocab = input_vocab
        self.target_vocab = target_vocab

    def __len__(self):
        """Number of records."""
        return len(self.data)

    @staticmethod
    def _encode(text, vocab):
        """Turn a string into a 1-D long tensor of vocabulary indices."""
        # A character missing from `vocab` raises KeyError.
        return torch.tensor(list(map(vocab.__getitem__, text)), dtype=torch.long)

    def __getitem__(self, idx):
        """Return the (input_ids, target_ids) tensor pair for record `idx`."""
        record = self.data[idx]
        encoded_input = self._encode(record['input'], self.input_vocab)
        encoded_target = self._encode(record['dssp8'], self.target_vocab)
        return encoded_input, encoded_target

In [None]:
# Read Data as a list of dicts
# newline='' is what the csv module documentation recommends, so that the
# reader itself handles line endings inside quoted fields.
with open(DATA_FILE, 'r', newline='') as f:
    reader = csv.DictReader(f)
    data = list(reader)

# Normalize every row in place and collect the DSSP8 label alphabet
target_chars = set()
for row in data:
    # Normalize input sequence: uppercase + map non-canonical to 'X'
    raw_seq = row["input"].upper()
    normalized_seq = [c if c in CANONICAL_AA else "X" for c in raw_seq]
    row["input"] = "".join(normalized_seq)
    # Store the uppercased labels back into the row: the target vocabulary is
    # built from these uppercased characters, so the per-row lookup must see
    # the same casing or a lowercase label would raise KeyError.
    row["dssp8"] = row["dssp8"].upper()
    target_chars.update(row["dssp8"])

# Map the letters to ESM vocabulary
pad_token_id = tokenizer.pad_token_id
unk_token_id = tokenizer.unk_token_id
if pad_token_id is None or tokenizer.cls_token_id is None or tokenizer.eos_token_id is None:
    raise ValueError("Tokenizer is missing required special tokens.")

# Input side: canonical residues plus the catch-all "X", each resolved to the
# tokenizer's id; a residue the tokenizer cannot resolve falls back to UNK.
input_chars = sorted(CANONICAL_AA | {"X"})
input_vocab = {"<PAD>": pad_token_id}
for residue in input_chars:
    resolved_id = tokenizer.convert_tokens_to_ids(residue)
    input_vocab[residue] = unk_token_id if resolved_id is None else resolved_id

# Target side: index 0 is reserved for padding, labels are numbered from 1
# in sorted order.
target_vocab = {"<PAD>": 0}
target_vocab.update(
    (label, position) for position, label in enumerate(sorted(target_chars), start=1)
)

# Get Dataset
dataset = ProteinDataset(data, input_vocab, target_vocab)

# Number of sequences per forward pass
BATCH_SIZE = 1

# Special-token and padding ids used when batches are assembled
cls_idx, eos_idx = tokenizer.cls_token_id, tokenizer.eos_token_id
input_pad_idx, target_pad_idx = input_vocab["<PAD>"], target_vocab["<PAD>"]

def collate_fn(batch):
    """Assemble (input, target) tensor pairs into padded batch tensors.

    Every input is wrapped as CLS + sequence + EOS before padding; targets are
    padded without special tokens.

    Returns:
        padded_inputs: (batch, max_len + 2) token ids, padded with input_pad_idx.
        padded_targets: (batch, max_len) label ids, padded with target_pad_idx.
        lengths: (batch,) original sequence lengths, excluding CLS/EOS.
    """
    inputs, targets = zip(*batch)
    lengths = torch.tensor([item.size(0) for item in inputs], dtype=torch.long)

    # torch.cat copies its operands, so one CLS and one EOS tensor can be shared.
    cls_token = torch.tensor([cls_idx], dtype=torch.long)
    eos_token = torch.tensor([eos_idx], dtype=torch.long)
    framed_inputs = [torch.cat([cls_token, item, eos_token]) for item in inputs]

    padded_inputs = pad_sequence(framed_inputs, batch_first=True, padding_value=input_pad_idx)
    padded_targets = pad_sequence(targets, batch_first=True, padding_value=target_pad_idx)
    return padded_inputs, padded_targets, lengths

# DataLoader
# shuffle=False keeps batches in the same order as `data`.
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=collate_fn,
)

In [None]:
import os

# Output directory for the shard files; created if it does not exist yet
SHARD_DIR = "../data/esm2_token_embeddings_sharded"
os.makedirs(SHARD_DIR, exist_ok=True)

# Device that holds the model parameters; batches are moved here
model_device = next(model.parameters()).device

# Sharding config: 4 GiB scaled down by a safety margin so a shard stays
# below the full 4 GiB
SAFETY_MARGIN = 0.9
MAX_SHARD_BYTES = int(4 * 1024**3 * SAFETY_MARGIN)

# Running state for the embedding loop
seq_ptr = 0                 # index into `data` of the next sequence
shard_idx = 0               # index of the shard currently being filled
current_shard_bytes = 0     # approximate size of the current shard
shard_chain_ids: List[str] = []
shard_lengths: List[int] = []
shard_embeddings: List[torch.Tensor] = []


def save_shard(shard_idx: int,
               shard_chain_ids: List[str],
               shard_lengths: List[int],
               shard_embeddings: List[torch.Tensor],
               approx_bytes: int) -> int:
    """Write the accumulated shard to SHARD_DIR and empty the buffers.

    Args:
        shard_idx: Index used in the shard's file name.
        shard_chain_ids: Chain ids of the buffered sequences.
        shard_lengths: Sequence length for each buffered sequence.
        shard_embeddings: One embedding tensor per buffered sequence.
        approx_bytes: Estimated shard size; only used in the log message.

    Returns:
        shard_idx + 1 after a write, or shard_idx unchanged when there was
        nothing to save.
    """
    if shard_chain_ids:
        file_name = f"esm2_token_embeddings_shard_{shard_idx:03d}.pt"
        out_path = os.path.join(SHARD_DIR, file_name)

        payload = dict(
            chain_ids=shard_chain_ids,
            lengths=torch.tensor(shard_lengths, dtype=torch.int32),
            embeddings=shard_embeddings,
        )
        torch.save(payload, out_path)

        n_sequences = len(shard_chain_ids)
        approx_mib = approx_bytes / 1024**2
        print(
            f"Saved shard {shard_idx} with {n_sequences} sequences "
            f"to {out_path}, approx {approx_mib:.1f} MiB"
        )

        # Empty the caller's lists in place so they can be reused for the
        # next shard.
        for buffer in (shard_chain_ids, shard_lengths, shard_embeddings):
            buffer.clear()

        return shard_idx + 1

    return shard_idx


# Run the model over every batch without autograd and stream the per-residue
# embeddings of each sequence into size-bounded shard files.
with torch.no_grad():
    for batch_inputs, _, batch_lengths in tqdm(dataloader, desc="Generating embeddings"):
        # Mask is 1 on real tokens (incl. CLS/EOS) and 0 on padding.
        attention_mask = (batch_inputs != input_pad_idx).long().to(model_device)
        batch_inputs = batch_inputs.to(model_device)

        outputs = model(
            input_ids=batch_inputs,
            attention_mask=attention_mask,
        )
        # Last entry of the returned hidden states; indexed below as
        # (batch, seq_len_with_specials, hidden_dim).
        hidden = outputs.hidden_states[-1]

        batch_size = hidden.size(0)

        for b in range(batch_size):
            seq_len = batch_lengths[b].item()

            # Rows 1..seq_len are the residues: drop CLS at position 0 and
            # EOS/padding after the sequence.
            seq_hidden = hidden[b, 1:seq_len + 1, :]
            # copy=True forces a fresh, compact CPU tensor. A plain .cpu() is a
            # no-op when the model already runs on CPU, leaving a view into the
            # full batch tensor; torch.save would then serialize the entire
            # underlying storage and the batch tensor would stay alive in RAM.
            seq_hidden_cpu = seq_hidden.detach().to("cpu", copy=True).contiguous()

            # Relies on shuffle=False: the n-th emitted sequence is data[n].
            chain_id = data[seq_ptr]["chain_id"]

            # Approximate bytes this entry will add; use the tensor's real
            # element size so fp16/bf16 outputs are not over-counted.
            seq_bytes = seq_hidden_cpu.numel() * seq_hidden_cpu.element_size()
            meta_bytes = len(chain_id) + 16         # tiny overhead fudge
            entry_bytes = seq_bytes + meta_bytes

            # If adding this would exceed shard budget, flush current shard first
            if current_shard_bytes + entry_bytes > MAX_SHARD_BYTES and shard_chain_ids:
                shard_idx = save_shard(
                    shard_idx,
                    shard_chain_ids,
                    shard_lengths,
                    shard_embeddings,
                    current_shard_bytes,
                )
                current_shard_bytes = 0

            # Add this sequence to current shard
            shard_chain_ids.append(chain_id)
            shard_lengths.append(seq_len)
            shard_embeddings.append(seq_hidden_cpu)
            current_shard_bytes += entry_bytes

            seq_ptr += 1

# Save any remaining sequences in the final shard
if shard_chain_ids:
    shard_idx = save_shard(
        shard_idx,
        shard_chain_ids,
        shard_lengths,
        shard_embeddings,
        current_shard_bytes,
    )

print(f"Saved embeddings for {seq_ptr} proteins into {shard_idx} shard file(s) in {SHARD_DIR}.")