In [1]:
# sentencepiece ONLY WORKS WITH PYTHON 3.12 or smaller, not 3.13!
#!pip install torch transformers tqdm pandas

In [2]:
#!pip install sentencepiece
import sentencepiece

In [1]:
# Import necessary modules
import torch
from torch.utils.data import DataLoader
from transformers import T5Tokenizer, T5EncoderModel
from tqdm import tqdm
from pathlib import Path
import re
import os
import sys

  from .autonotebook import tqdm as notebook_tqdm


## Load Data

In [2]:
# Input table (accession + amino-acid sequence per row) and the directory
# that receives the per-protein embedding/label tensors. Both paths are
# relative to the notebook's working directory.
CSV_PATH = "../data/results_with_sequence.csv"
PROC_DIR = Path("../data") / "processed"
PROC_DIR.mkdir(exist_ok=True, parents=True)

In [3]:
from torch.nn.utils.rnn import pad_sequence

# Custom collate function for padding
def collate_fn(batch):
    """Collate a list of dataset items into one padded batch dict.

    Args:
        batch: list of dicts, each with 'accession', 'label', 'length' and
            'residue_labels' (a tensor; assumed 1-D with one entry per
            residue - TODO confirm against ProteinResidueDataset). If every
            item also has a 'sequence' entry (the amino-acid string), it is
            passed through so downstream code can embed the real sequence.

    Returns:
        dict with:
            'accession': list of accessions, in batch order.
            'label': LongTensor, one value per item.
            'residue_labels': residue labels right-padded with 0 to the
                longest item (batch-first). 0 may also be a real label
                value, so use 'length' to mask padded positions.
            'length': LongTensor, one value per item.
            'sequence': list of sequences - only present when every item
                has one.
    """
    collated = {
        'accession': [item['accession'] for item in batch],
        'label': torch.tensor([item['label'] for item in batch], dtype=torch.long),
        'residue_labels': pad_sequence(
            [item['residue_labels'] for item in batch],
            batch_first=True,
            padding_value=0,
        ),
        'length': torch.tensor([item['length'] for item in batch], dtype=torch.long),
    }
    # Carry the amino-acid sequence through when the dataset provides it:
    # the embedding step needs it, and residue labels are no substitute.
    if all('sequence' in item for item in batch):
        collated['sequence'] = [item['sequence'] for item in batch]
    return collated

In [4]:
# Make the project root (parent of the notebook's working directory)
# importable so the local `data` package can be found.
project_root = Path.cwd().parent
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))

from data.dataloader import ProteinResidueDataset

# Wrap the CSV in the project's dataset class. No shuffling, so batches
# follow the dataset's index order.
dataset = ProteinResidueDataset(CSV_PATH)
loader = DataLoader(
    dataset,
    batch_size=8,
    shuffle=False,
    collate_fn=collate_fn,
)

In [5]:
import pandas as pd

# Raw table for inspection. The 'Unnamed: 0' column is presumably an index
# written out together with the CSV.
df = pd.read_csv(filepath_or_buffer=CSV_PATH)
df.iloc[:10]

Unnamed: 0,accession,length,source_database,fragments,sequence
0,A0A003,340,unreviewed,"[{'start': 15, 'end': 249}]",MSSDTHGTDLADGDVLVTGAAGFIGSHLVTELRNSGRNVVAVDRRP...
1,A0A009GZV8,323,unreviewed,"[{'start': 3, 'end': 208}]",MNVLITGGTGFIGKQIAKEILKAGSLTLDDNKPQSIDKIILFDAFA...
2,A0A009H3J1,335,unreviewed,"[{'start': 2, 'end': 260}]",MILVTGGLGFIGSHIALSLMAQGQEVVIVDNLANSTLQTLERLEFI...
3,A0A009H7U9,338,unreviewed,"[{'start': 4, 'end': 263}]",MAKILVTGGAGYIGSHTCVELLNAGHEVIVFDNLSNSSEESLKRVQ...
4,A0A009HJQ2,301,unreviewed,"[{'start': 5, 'end': 220}]",MNKNVLITGASGFIGTHLIKFLLQKNYNVIAVTRQAGKASDHPALQ...
5,A0A009HLV6,216,unreviewed,"[{'start': 17, 'end': 193}]",MDNLNNAKKDNFSRKTILVTGAAGFIGSRLIVELLREGHQVIAALR...
6,A0A009HNL3,323,unreviewed,"[{'start': 3, 'end': 206}]",MNVLITGGTGFIGKQIAKEILKTGSLTLDGKQAKPIDKIILFDAFA...
7,A0A009HPX5,338,unreviewed,"[{'start': 4, 'end': 263}]",MAKILVTGGAGYIGSHTCVELLEAGHEVIVFDNLSNSSKESLNRVQ...
8,A0A009HQP5,301,unreviewed,"[{'start': 5, 'end': 220}]",MNKNVLITGASGFIGTHLIRFLLQKNYNVIAVTRQAGRESDHPALQ...
9,A0A009I037,271,unreviewed,"[{'start': 14, 'end': 195}]",MHILFIGYGKTSQRVAKQLFEKEHQITTISRSVKTDSYATHLVQDI...


In [6]:
# Peek at the first amino-acid sequence.
df["sequence"].iat[0]

'MSSDTHGTDLADGDVLVTGAAGFIGSHLVTELRNSGRNVVAVDRRPLPDDLESTSPPFTGSLREIRGDLNSLNLVDCLKNISTVFHLAALPGVRPSWTQFPEYLRCNVLATQRLMEACVQAGVERVVVASSSSVYGGADGVMSEDDLPRPLSPYGVTKLAAERLALAFAARGDAELSVGALRFFTVYGPGQRPDMFISRLIRATLRGEPVEIYGDGTQLRDFTHVSDVVRALMLTASVRDRGSAVLNIGTGSAVSVNEVVSMTAELTGLRPCTAYGSARIGDVRSTTADVRQAQSVLGFTARTGLREGLATQIEWTRRSLSGAEQDTVPVGGSSVSVPRL'

## Load the ProtT5 embedding model

In [9]:
def load_prott5(model_name="Rostlab/prot_t5_xl_half_uniref50-enc", device=None):
    """Load a ProtT5 encoder and its tokenizer, ready for inference.

    Args:
        model_name: Hugging Face checkpoint id (or local path) passed to
            ``from_pretrained`` for both tokenizer and model.
        device: ``torch.device`` or device string to run on; when ``None``,
            CUDA is used if available, otherwise CPU.

    Returns:
        ``(tokenizer, model, device)``: the model is in eval mode on
        ``device``, cast to half precision on CUDA; on CPU no precision
        cast is applied.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(device)

    tokenizer = T5Tokenizer.from_pretrained(model_name, do_lower_case=False, legacy=True)
    model = T5EncoderModel.from_pretrained(model_name)
    if device.type == "cuda":
        print("Moving model to GPU")
        model = model.half()
    else:
        # Half precision is not used on CPU.
        print("Moving model to CPU - not using half precision")
    model = model.to(device)
    model.eval()
    return tokenizer, model, device

## Generate Embeddings

In [None]:
def generate_embeddings(loader, tokenizer, model, device, sequences=None,
                        proc_dir=None, max_length=1022, skip_existing=False):
    """Embed every protein yielded by ``loader`` with ProtT5 and cache results.

    For each protein two tensors are written to the output directory:
      - ``{accession}_embedding.pt``: encoder states, one row per residue.
      - ``{accession}_labels.pt``: the first ``length`` residue labels.

    The amino-acid sequence is taken from ``batch['sequence']`` when the
    batches carry it, otherwise from ``sequences[accession]``.
    ``residue_labels`` are labels, not amino acids, and are never embedded.

    Args:
        loader: iterable of batch dicts with 'accession', 'length',
            'residue_labels' and optionally 'sequence'.
        tokenizer, model, device: as returned by ``load_prott5``.
        sequences: optional mapping accession -> amino-acid sequence, used
            when a batch has no 'sequence' key.
        proc_dir: output directory (``str`` or ``Path``); defaults to
            ``PROC_DIR``.
        max_length: proteins with more residues than this are skipped.
        skip_existing: if True, proteins whose two output files already
            exist are skipped, so an interrupted run can be resumed.

    Raises:
        ValueError: if no sequence is available for an accession.
    """
    out_dir = PROC_DIR if proc_dir is None else Path(proc_dir)
    for batch in tqdm(loader, desc="Generating embeddings"):
        batch_sequences = batch.get('sequence')
        for i, accession in enumerate(batch['accession']):
            length = int(batch['length'][i])
            emb_path = out_dir / f"{accession}_embedding.pt"
            labels_path = out_dir / f"{accession}_labels.pt"
            if skip_existing and emb_path.exists() and labels_path.exists():
                continue

            # The model must see the protein's actual residues, never its labels.
            if batch_sequences is not None:
                raw_seq = batch_sequences[i]
            elif sequences is not None and accession in sequences:
                raw_seq = sequences[accession]
            else:
                raise ValueError(
                    f"No amino-acid sequence for {accession}: provide batches "
                    "with a 'sequence' key or pass a `sequences` mapping "
                    "(accession -> sequence)."
                )
            # ProtT5 convention: upper-case residues, rare/ambiguous U, Z, O, B -> X.
            raw_seq = re.sub(r"[UZOB]", "X", str(raw_seq).upper())
            n_residues = len(raw_seq)

            if n_residues > max_length:
                print(f"Skipping {accession}: too long")
                continue
            if n_residues != length:
                # Embedding rows and labels must align one-to-one per residue.
                print(f"Skipping {accession}: sequence has {n_residues} residues "
                      f"but length is {length}")
                continue

            # ProtT5 takes space-separated residues with no task prefix
            # ("<AA2fold>" is a ProstT5 convention, not a ProtT5 token).
            tokens = tokenizer.batch_encode_plus(
                [" ".join(raw_seq)], return_tensors="pt", padding=True, add_special_tokens=True
            ).to(device)

            with torch.no_grad():
                hidden = model(**tokens).last_hidden_state

            # Position j is residue j; the trailing </s> token is dropped.
            # clone(): torch.save writes a view's entire storage, so save
            # compact copies instead of slices of the model output / padded batch.
            emb = hidden[0, :n_residues].float().cpu().clone()
            labels = batch['residue_labels'][i][:length].clone()
            torch.save(emb, emb_path)
            torch.save(labels, labels_path)

In [10]:
# Load ProtT5 once, then embed every protein yielded by `loader`.
tokenizer, model, device = load_prott5()
generate_embeddings(loader=loader, tokenizer=tokenizer, model=model, device=device)
print("Embedding generation complete.")

Moving model to CPU - not using half precision


Generating embeddings:   2%|▏         | 31/1250 [1:45:47<31:42:57, 93.66s/it]   

Skipping A0A015KJ60: too long


Generating embeddings:   3%|▎         | 41/1250 [1:59:29<27:24:09, 81.60s/it]

Skipping A0A016VF32: too long


Generating embeddings:  12%|█▏        | 153/1250 [5:32:58<14:50:22, 48.70s/it]  

Skipping A0A060SI51: too long


Generating embeddings:  18%|█▊        | 231/1250 [7:13:01<22:28:35, 79.41s/it]

Skipping A0A067P674: too long


Generating embeddings:  19%|█▉        | 237/1250 [7:21:05<23:10:25, 82.35s/it]

Skipping A0A067TQY0: too long


Generating embeddings:  31%|███▏      | 392/1250 [10:52:09<20:00:21, 83.94s/it]

Skipping A0A081ACZ4: too long
Skipping A0A081ACZ5: too long


Generating embeddings:  32%|███▏      | 397/1250 [10:58:57<19:56:08, 84.14s/it]

Skipping A0A081CJY1: too long


Generating embeddings:  32%|███▏      | 404/1250 [11:10:22<23:23:49, 99.56s/it] 


KeyboardInterrupt: 

## Load Embeddings

In [7]:
import torch

def load_embedding_and_labels(accession, proc_dir=None):
    """Load the cached embedding and label tensors for one protein.

    Args:
        accession: protein accession used in the cache file names
            ``{accession}_embedding.pt`` and ``{accession}_labels.pt``.
        proc_dir: directory holding the files (``str`` or ``Path``);
            defaults to ``PROC_DIR``, resolved at call time.

    Returns:
        ``(embedding, labels)`` tensors as saved.

    Raises:
        FileNotFoundError: if either file is missing.
    """
    directory = PROC_DIR if proc_dir is None else Path(proc_dir)
    # weights_only=True: the cache holds plain tensors, so refuse arbitrary
    # pickled objects. map_location keeps loading working on CPU-only machines.
    embedding = torch.load(directory / f"{accession}_embedding.pt",
                           map_location="cpu", weights_only=True)
    labels = torch.load(directory / f"{accession}_labels.pt",
                        map_location="cpu", weights_only=True)
    return embedding, labels

# Example: reload the cached tensors for the first accession in df.
# Per-residue data should give one embedding row per label; differing first
# dimensions mean the cached files are stale and should be regenerated.
example_accession = df["accession"].iat[0]
embedding, labels = load_embedding_and_labels(example_accession)
print("Embedding shape:", embedding.shape)
print("Labels shape:", labels.shape)

Embedding shape: torch.Size([1, 1024])
Labels shape: torch.Size([340])
