In [1]:
# sentencepiece ONLY WORKS WITH PYTHON 3.12 or smaller, not 3.13!
!pip install torch transformers tqdm pandas

Collecting torch
  Downloading torch-2.7.0-cp312-cp312-win_amd64.whl.metadata (29 kB)
Collecting transformers
  Using cached transformers-4.51.3-py3-none-any.whl.metadata (38 kB)
Collecting tqdm
  Using cached tqdm-4.67.1-py3-none-any.whl.metadata (57 kB)
Collecting pandas
  Using cached pandas-2.2.3-cp312-cp312-win_amd64.whl.metadata (19 kB)
Collecting filelock (from torch)
  Using cached filelock-3.18.0-py3-none-any.whl.metadata (2.9 kB)
Collecting sympy>=1.13.3 (from torch)
  Using cached sympy-1.14.0-py3-none-any.whl.metadata (12 kB)
Collecting networkx (from torch)
  Using cached networkx-3.4.2-py3-none-any.whl.metadata (6.3 kB)
Collecting jinja2 (from torch)
  Using cached jinja2-3.1.6-py3-none-any.whl.metadata (2.9 kB)
Collecting fsspec (from torch)
  Using cached fsspec-2025.3.2-py3-none-any.whl.metadata (11 kB)
Collecting huggingface-hub<1.0,>=0.30.0 (from transformers)
  Using cached huggingface_hub-0.31.4-py3-none-any.whl.metadata (13 kB)
Collecting numpy>=1.17 (from transfo

In [None]:
!pip install sentencepiece
import sentencepiece

Collecting sentencepiece
  Downloading sentencepiece-0.2.0-cp312-cp312-win_amd64.whl.metadata (8.3 kB)
Downloading sentencepiece-0.2.0-cp312-cp312-win_amd64.whl (991 kB)
   ---------------------------------------- 0.0/992.0 kB ? eta -:--:--
   ---------------------------------------- 992.0/992.0 kB 5.2 MB/s eta 0:00:00
Installing collected packages: sentencepiece
Successfully installed sentencepiece-0.2.0


In [4]:
# Import necessary modules
import torch
from torch.utils.data import DataLoader
from transformers import T5Tokenizer, T5EncoderModel
from tqdm import tqdm
from pathlib import Path
import re
import os
import sys

  from .autonotebook import tqdm as notebook_tqdm


## Load Data

In [5]:
# Input CSV with one row per protein, and the directory that receives
# the per-protein embedding / label tensors (created if missing).
CSV_PATH = "../data/results_with_sequence.csv"
PROC_DIR = Path("../data") / "processed"
PROC_DIR.mkdir(exist_ok=True, parents=True)

In [6]:
# Make the project root importable so the local `data` package can be found.
# Insert at the front (not append) so it wins over any installed package that is
# also called `data`.
project_root = Path.cwd().parent
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

from data.dataloader import ProteinResidueDataset

# Build the per-protein dataset from the CSV and serve it in fixed order.
# NOTE(review): the default collate_fn stacks tensor fields, so if per-residue
# tensors (e.g. `residue_labels`) differ in length between proteins, batching 8
# of them will fail when iterating — confirm against ProteinResidueDataset.
dataset = ProteinResidueDataset(CSV_PATH)
loader = DataLoader(dataset, batch_size=8, shuffle=False)

In [7]:
import pandas as pd

# Quick look at the raw CSV the dataset is built from: the first ten proteins.
df = pd.read_csv(CSV_PATH)
df.head(n=10)

Unnamed: 0,accession,length,source_database,fragments,sequence
0,A0A003,340,unreviewed,"[{'start': 15, 'end': 249}]",MSSDTHGTDLADGDVLVTGAAGFIGSHLVTELRNSGRNVVAVDRRP...
1,A0A009GZV8,323,unreviewed,"[{'start': 3, 'end': 208}]",MNVLITGGTGFIGKQIAKEILKAGSLTLDDNKPQSIDKIILFDAFA...
2,A0A009H3J1,335,unreviewed,"[{'start': 2, 'end': 260}]",MILVTGGLGFIGSHIALSLMAQGQEVVIVDNLANSTLQTLERLEFI...
3,A0A009H7U9,338,unreviewed,"[{'start': 4, 'end': 263}]",MAKILVTGGAGYIGSHTCVELLNAGHEVIVFDNLSNSSEESLKRVQ...
4,A0A009HJQ2,301,unreviewed,"[{'start': 5, 'end': 220}]",MNKNVLITGASGFIGTHLIKFLLQKNYNVIAVTRQAGKASDHPALQ...
5,A0A009HLV6,216,unreviewed,"[{'start': 17, 'end': 193}]",MDNLNNAKKDNFSRKTILVTGAAGFIGSRLIVELLREGHQVIAALR...
6,A0A009HNL3,323,unreviewed,"[{'start': 3, 'end': 206}]",MNVLITGGTGFIGKQIAKEILKTGSLTLDGKQAKPIDKIILFDAFA...
7,A0A009HPX5,338,unreviewed,"[{'start': 4, 'end': 263}]",MAKILVTGGAGYIGSHTCVELLEAGHEVIVFDNLSNSSKESLNRVQ...
8,A0A009HQP5,301,unreviewed,"[{'start': 5, 'end': 220}]",MNKNVLITGASGFIGTHLIRFLLQKNYNVIAVTRQAGRESDHPALQ...
9,A0A009I037,271,unreviewed,"[{'start': 14, 'end': 195}]",MHILFIGYGKTSQRVAKQLFEKEHQITTISRSVKTDSYATHLVQDI...


## Load embedding model (ProtT5)

In [13]:
# Load the ProtT5 model and tokenizer
def load_prott5(model_name="Rostlab/prot_t5_xl_half_uniref50-enc"):
    """Load the ProtT5 encoder and its tokenizer, ready for inference.

    On CUDA the model is converted to half precision; on CPU it is kept in
    float32.

    Args:
        model_name: Hugging Face model id (or local path) of the ProtT5 encoder
            checkpoint. Defaults to the half-precision UniRef50 encoder.

    Returns:
        Tuple ``(tokenizer, model, device)``: the ``T5Tokenizer``, the
        ``T5EncoderModel`` in eval mode on ``device``, and the ``torch.device``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = T5Tokenizer.from_pretrained(model_name, do_lower_case=False, legacy=True)
    model = T5EncoderModel.from_pretrained(model_name)
    if device.type == "cuda":
        print("Moving model to GPU")
        model = model.half()
    else:
        print("Moving model to CPU - not using half precision")
        # Explicit full precision: fp16 inference is poorly supported on CPU,
        # and this guarantees float32 regardless of the checkpoint's stored dtype.
        model = model.float()
    model = model.to(device)
    model.eval()  # inference only: disables dropout
    return tokenizer, model, device

## Generate Embeddings

In [11]:
# Generate embeddings and save them
def generate_embeddings(loader, tokenizer, model, device, proc_dir=None, max_len=1022):
    """Embed every protein served by ``loader`` with ProtT5 and save the results.

    For each protein two files are written to the output directory:
    ``{accession}_embedding.pt`` (float32 tensor, one row per residue) and
    ``{accession}_labels.pt`` (the batch's ``residue_labels`` entry for that
    protein, saved as-is).

    Args:
        loader: Iterable of dict-like batches with indexable ``'accession'``,
            ``'sequence'`` and ``'residue_labels'`` entries.
        tokenizer: ProtT5 ``T5Tokenizer``.
        model: ProtT5 ``T5EncoderModel`` already on ``device`` and in eval mode.
        device: ``torch.device`` the model lives on.
        proc_dir: Output directory; defaults to the notebook-level ``PROC_DIR``.
        max_len: Sequences longer than this many residues are skipped.

    Raises:
        KeyError: if a batch does not provide the amino-acid ``'sequence'``.
    """
    out_dir = PROC_DIR if proc_dir is None else Path(proc_dir)
    for batch in tqdm(loader, desc="Generating embeddings"):
        # The original read the sequence from 'residue_labels', i.e. it embedded
        # the labels. Fail loudly if the dataset does not expose the sequence.
        if "sequence" not in batch:
            raise KeyError(
                "Batch has no 'sequence' entry; the dataset must return the "
                "amino-acid sequence to embed (not only 'residue_labels')."
            )
        for i, accession in enumerate(batch["accession"]):
            # Map rare/ambiguous amino acids (U, Z, O, B) to X, as done for ProtT5.
            raw_seq = re.sub(r"[UZOB]", "X", batch["sequence"][i])
            seq_len = len(raw_seq)
            if seq_len > max_len:
                tqdm.write(f"Skipping {accession}: too long")  # keeps the progress bar intact
                continue

            # ProtT5 takes space-separated residues with no task prefix
            # ("<AA2fold>" is a ProstT5 token; ProtT5 would split it into
            # several tokens and shift every residue position).
            tokens = tokenizer(
                " ".join(raw_seq), return_tensors="pt", add_special_tokens=True
            ).to(device)

            with torch.no_grad():
                output = model(**tokens).last_hidden_state.float().cpu()

            # One row per residue: the first seq_len positions; this drops the
            # trailing </s> token. clone() so torch.save writes only this slice
            # rather than the whole underlying output storage.
            emb = output[0, :seq_len].clone()
            labels = batch["residue_labels"][i]
            if isinstance(labels, torch.Tensor):
                # A row of a collated batch tensor is a view; saving it as-is
                # would serialize the entire batch's storage.
                labels = labels.clone()
            torch.save(emb, out_dir / f"{accession}_embedding.pt")
            torch.save(labels, out_dir / f"{accession}_labels.pt")

In [None]:
# Main execution: load ProtT5 once, then embed every protein served by `loader`
# and write the resulting tensors to PROC_DIR.
tokenizer, model, device = load_prott5()
generate_embeddings(
    loader,
    tokenizer,
    model,
    device,
)
print("Embedding generation complete.")