In [None]:
import os
import pickle
import re

import pandas as pd
import torch
from tqdm import tqdm
from transformers import T5Tokenizer, T5EncoderModel

In [None]:
# Run configuration: sequences per batch, the CSV dataset path, and the torch
# device string (GPU when one is visible, CPU otherwise).
BATCH_SIZE = 8
DATAFRAME_PATH = "../data/custom_fragments2/datasets/random_equal_distribution/random_equal_val.csv"
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(DEVICE)

In [None]:
# Load the table to embed. `fragment_df` would otherwise only exist after running
# the fragment-loading section at the end of this notebook, which breaks
# Restart & Run All, so the same table is built here.
# Set USE_TRUE_FRAGMENTS = False to embed the CSV at DATAFRAME_PATH instead.
USE_TRUE_FRAGMENTS = True
FRAGMENT_TSV_PATH = "./uniprotkb_reviewed_true_AND_fragment_tr_2025_12_01.tsv"

if USE_TRUE_FRAGMENTS:
    fragment_df = pd.read_csv(FRAGMENT_TSV_PATH, sep="\t").rename(
        columns={"entry": "source_accession"}
    )
    fragment_df["fragment_length"] = fragment_df["sequence"].map(len)
    fragment_df["is_fragment"] = True
    df = fragment_df
else:
    df = pd.read_csv(DATAFRAME_PATH)
df.head()

In [None]:
# Number of rows in the input table.
df.shape[0]

In [None]:
# Load the ProtT5-XL (UniRef50) tokenizer and encoder-only model, move the model
# to DEVICE, and put it in eval mode (dropout disabled) for inference.
tokenizer = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_uniref50", do_lower_case=False)
model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_uniref50").to(DEVICE)
model.eval()

In [None]:
# Keep sequences of at most MAX_FRAGMENT_LENGTH residues (presumably to bound
# encoder memory use — confirm). .copy() avoids SettingWithCopyWarning when a
# column is added below. reset_index gives a contiguous 0..n-1 index, so the
# positional batch offsets used later agree with the index labels.
MAX_FRAGMENT_LENGTH = 1024
length_df = (
    df[df["fragment_length"] <= MAX_FRAGMENT_LENGTH]
    .copy()
    .reset_index(drop=True)
)
length_df.head()

In [None]:
# Prepare tokenizer input: replace U, Z, O and B with X, then separate every
# residue with a single space.
length_df["sequence_split"] = (
    length_df["sequence"]
    .str.replace(r"[UZOB]", "X", regex=True)
    .map(" ".join)
)
length_df.head()

In [None]:
# Keep only the columns needed for tokenization and for labelling the outputs.
tokenized_df = length_df.loc[:, ["source_accession", "is_fragment", "sequence_split"]]
tokenized_df.head()

In [None]:
# Filled by the tokenization loop: one list of per-sequence tuples per batch.
tokenized_input = list()

In [None]:
# Tokenize in fixed-size batches. Each batch is padded to its own longest
# sequence, so all entries of a batch share one length and can be re-stacked
# by the embedding loop. Result: one list of
# (source_accession, is_fragment, input_ids, attention_mask) tuples per batch.
tokenized_input = []  # reset so re-running this cell does not append duplicates

for start in tqdm(range(0, len(tokenized_df), BATCH_SIZE)):
    # Positional slicing: `start` is a row offset, not an index label. Label-based
    # .loc[start:start + BATCH_SIZE - 1] skips or mis-slices rows whenever the
    # index is not exactly 0..n-1 (e.g. after filtering).
    batch = tokenized_df.iloc[start:start + BATCH_SIZE]

    encoded = tokenizer.batch_encode_plus(
        batch["sequence_split"].tolist(),
        add_special_tokens=True,
        padding="longest",
        return_tensors="pt",
    )

    tokenized_input.append(list(zip(
        batch["source_accession"],
        batch["is_fragment"],
        encoded["input_ids"],
        encoded["attention_mask"],
    )))

In [None]:
# Peek at the first tokenized batch: a list of
# (source_accession, is_fragment, input_ids, attention_mask) tuples.
tokenized_input[0]

In [None]:
# Filled by the embedding loop with (accession, is_fragment, embedding) tuples,
# appended in processing order (hence a list rather than a dict).
all_embeddings = list()

In [None]:
# Create the output directory for embedding pickles (no-op if it already exists).
# `os` is imported in the top import cell rather than mid-notebook.
os.makedirs("./true_fragment_embeddings", exist_ok=True)
print("Embeddings directory ready")

In [None]:
# Smoke test before the full run: embed only the first sequence of the first
# batch and print the resulting shapes.
with torch.no_grad():
    for batch_num, entries in tqdm(enumerate(tokenized_input), total=len(tokenized_input)):
        # entries[:1] limits the inner loop to a single sequence.
        for acc_id, is_fragment, input_ids, attention_mask in entries[:1]:
            single_inputs = {
                # Add a batch dimension: [seq_len] -> [1, seq_len]
                "input_ids": input_ids.unsqueeze(0).to(DEVICE),
                "attention_mask": attention_mask.unsqueeze(0).to(DEVICE),
            }
            embedding = model(**single_inputs).last_hidden_state
            print(f"Accession ID: {acc_id}")
            print(f"Fragment: {is_fragment}")
            print(f"Embedding shape: {embedding.shape}")  # Should be [1, seq_len, 1024]
            print(f"Attention mask shape: {attention_mask.shape}")  # Should be [seq_len]
        # Only the first batch is inspected.
        break

In [None]:
# Main embedding generation with masked mean pooling.
# Each tokenized batch is re-stacked, passed through the encoder once, and reduced
# to one vector per sequence by averaging the hidden states at positions whose
# attention mask is 1. Checkpoints and the final result are pickled as lists of
# (accession, is_fragment, embedding) tuples.
CHECKPOINT_EVERY = 1000  # batches between checkpoint pickles

all_embeddings = []  # reset here so re-running this cell does not duplicate entries

with torch.no_grad():
    for batch_num, entries in tqdm(enumerate(tokenized_input), total=len(tokenized_input)):
        if not entries:
            # torch.stack cannot stack an empty list, and there is nothing to embed.
            continue

        acc_ids, fragment_status, input_ids_seq, attention_mask_seq = zip(*entries)

        # Entries of one tokenized batch share the batch's padded length,
        # so they stack into [batch_size, seq_len].
        batch_input_ids = torch.stack(input_ids_seq).to(DEVICE)
        batch_attention_masks = torch.stack(attention_mask_seq).to(DEVICE)

        # [batch_size, seq_len, embedding_dim]
        embeddings = model(
            input_ids=batch_input_ids,
            attention_mask=batch_attention_masks,
        ).last_hidden_state

        # Masked mean over the sequence axis; the [batch_size, seq_len, 1] mask
        # broadcasts across embedding_dim.
        # NOTE(review): with add_special_tokens=True the T5 tokenizer appends an
        # end-of-sequence token. The mask also covers it, so that token is part of
        # the average. Kept as-is so results stay comparable with earlier runs —
        # confirm before changing.
        mask = batch_attention_masks.unsqueeze(-1).float()
        summed = (embeddings * mask).sum(dim=1)
        lengths = mask.sum(dim=1)
        prot_embeds = summed / lengths  # [batch_size, embedding_dim]

        # .clone() gives every stored vector its own storage. On CPU, .cpu()
        # returns the same view into the whole batch tensor, and pickling a view
        # writes out the full underlying storage.
        for acc_id, is_fragment, prot_embed in zip(acc_ids, fragment_status, prot_embeds):
            all_embeddings.append((acc_id, is_fragment, prot_embed.cpu().clone()))

        # Periodic checkpoint of everything embedded so far. tqdm.write keeps the
        # progress bar intact; the bar itself reports throughput.
        if batch_num > 0 and batch_num % CHECKPOINT_EVERY == 0:
            with open(f"./true_fragment_embeddings/embeddings_{batch_num}.pkl", "wb") as p:
                pickle.dump(all_embeddings, p)
            tqdm.write(f"Saved checkpoint at batch {batch_num}")

# Save final embeddings
with open("./true_fragment_embeddings/final_embeddings.pkl", "wb") as p:
    pickle.dump(all_embeddings, p)

print(f"\nComplete! Generated {len(all_embeddings)} embeddings.")

In [None]:
# Verification: reload the final pickle written by the embedding loop (same
# directory it writes to) and inspect the first entry, which is an
# (accession_id, is_fragment, embedding) tuple.
# Only unpickle files this notebook produced — pickle.load can execute arbitrary code.
with open("./true_fragment_embeddings/final_embeddings.pkl", "rb") as p:
    loaded_embeddings = pickle.load(p)

print(f"Total embeddings: {len(loaded_embeddings)}")
if loaded_embeddings:
    first_acc_id, first_is_fragment, first_embedding = loaded_embeddings[0]
    print("\nFirst embedding:")
    print(f"  Accession ID: {first_acc_id}")
    print(f"  Embedding shape: {first_embedding.shape}")
    print(f"  Embedding type: {type(first_embedding)}")
    print(f"  Is Fragment: {first_is_fragment}")
print("\nExpected: torch.Size([1024]) for ProtT5-XL")

In [None]:
# Look up every row for a single accession in the input table.
df.loc[df["source_accession"].eq("A3Q8S8")]

In [None]:
# Fragment label of the row with index label 0, as a plain bool.
bool(tokenized_df.loc[0, "is_fragment"])

In [None]:
# Spot-check the fragment label of the row with index LABEL 60607. This is a label,
# not a position: which row it refers to depends on how tokenized_df's index was built.
tokenized_df.loc[60607]["is_fragment"]

In [None]:
# Write the embeddings to a headerless CSV: one row per protein, with the accession
# followed by the embedding values. Entries in loaded_embeddings are
# (acc_id, is_fragment, embedding) tuples; the fragment label is not written.
# Built with str.join rather than a nested same-quote f-string, which is a
# SyntaxError before Python 3.12.
csv_test = "\n".join(
    ",".join([str(acc_id), *map(str, embedding.tolist())])
    for acc_id, _is_fragment, embedding in loaded_embeddings
)

with open("test.csv", "w") as f:
    f.write(csv_test)

In [None]:
# Size of the generated CSV text, in characters.
len(csv_test)

In [None]:
# CSV header: the accession column followed by one column per embedding
# dimension, named 0..1023.
header = ",".join(["acc_id", *map(str, range(1024))])
header

In [None]:
# Number of rows that were tokenized.
tokenized_df.shape[0]

In [None]:
# Build (acc_id, embedding, is_fragment) triples with a plain-bool label.
# Each loaded entry is an (acc_id, is_fragment, embedding) tuple, so the label is
# taken from the entry itself. Looking it up in tokenized_df with a positional
# counter can mis-align whenever the index is not exactly 0..n-1.
# TODO: the intent was to also store this list as a pickle; nothing is written yet.
annotated_embeddings = [
    (acc_id, embedding, bool(is_fragment))
    for acc_id, is_fragment, embedding in loaded_embeddings
]

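## Loading real fragments

The cells below read the UniProtKB export `uniprotkb_reviewed_true_AND_fragment_tr_2025_12_01.tsv` into `fragment_df`. They add the `fragment_length`, `source_accession` and `is_fragment` columns that the embedding pipeline above expects.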

In [None]:
# Read the tab-separated UniProtKB fragment export (read_table defaults to sep="\t").
fragment_df = pd.read_table("./uniprotkb_reviewed_true_AND_fragment_tr_2025_12_01.tsv")
fragment_df

In [None]:
# Sequence length in characters, used by the length filter of the embedding pipeline.
fragment_df["fragment_length"] = fragment_df["sequence"].map(len)
fragment_df

In [None]:
# Align the accession column name with the one used by the embedding pipeline.
fragment_df = fragment_df.rename(mapper={"entry": "source_accession"}, axis="columns")
fragment_df

In [None]:
# Label every row of this table as a fragment (boolean column).
fragment_df["is_fragment"] = pd.Series(True, index=fragment_df.index, dtype=bool)
fragment_df