In [12]:
import pandas as pd

df = pd.read_csv("../data/domains-and-seqs-merged.csv")

In [13]:
import pandas as pd

MIN_DOMAINS_PER_HOMOLOGY = 10
MAX_DOMAINS_PER_HOMOLOGY = 200

HOMOLOGY_GROUPS = 100
SAMPLES_PER_GROUP = 10

# Define hierarchy columns — this full path defines a unique homology group
hierarchy = ['class', 'architecture', 'topology', 'homology']

# Step 1 & 2: Filter groups where the number of domain_id entries is at least 100
filtered_df = df.groupby(hierarchy).filter(lambda x: MIN_DOMAINS_PER_HOMOLOGY <= len(x) <= MAX_DOMAINS_PER_HOMOLOGY)

# Step 3: Get unique full-path homology groups
unique_homology_paths = filtered_df[hierarchy].drop_duplicates()

# Randomly sample 100 unique homology groups (based on full path)
sampled_paths = unique_homology_paths.sample(n=HOMOLOGY_GROUPS, random_state=42)

# Step 4: Retain only rows that belong to the sampled groups
sampled_df = pd.merge(sampled_paths, filtered_df, on=hierarchy)

# Within each sampled group, randomly choose 100 domain_id entries
subset = sampled_df.groupby(hierarchy).apply(lambda x: x.sample(n=SAMPLES_PER_GROUP, random_state=42)).reset_index(drop=True)

# final_df is your result
subset.to_csv("../data/subset.csv", index=False)


  subset = sampled_df.groupby(hierarchy).apply(lambda x: x.sample(n=SAMPLES_PER_GROUP, random_state=42)).reset_index(drop=True)


In [14]:
import torch
from transformers import T5Tokenizer, T5EncoderModel

# Load ProtT5 model and tokenizer
tokenizer = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_uniref50", do_lower_case=False)
model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_uniref50")
model = model.eval()

# Check if CUDA is available and set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

In [18]:
# Dictionary to store embeddings
all_embeddings = {}

# Process each sequence
for index, row in subset.head(20).iterrows():
    sequence = row["sequence"]
    sequence = sequence.replace('U', 'X').replace('Z', 'X').replace('O', 'X')
    ids = tokenizer.batch_encode_plus([sequence], add_special_tokens=True, padding=True, return_tensors="pt")
    input_ids = ids['input_ids'].to(device)
    attention_mask = ids['attention_mask'].to(device)

    with torch.no_grad():
        embedding = model(input_ids=input_ids, attention_mask=attention_mask)

    # Average over tokens to get a single vector per sequence
    sequence_embedding = embedding.last_hidden_state.mean(dim=1).squeeze().cpu()

    # Store in dictionary
    all_embeddings[index] = sequence_embedding

    print(f"Processed: {index}")

# Save all embeddings to one file
torch.save(all_embeddings, "../data/all_embeddings.pt")
print("All embeddings saved to all_embeddings.pt")


Processed: 0
Processed: 1
Processed: 2
Processed: 3
Processed: 4
Processed: 5
Processed: 6
Processed: 7
Processed: 8
Processed: 9
Processed: 10
Processed: 11
Processed: 12
Processed: 13
Processed: 14
Processed: 15
Processed: 16
Processed: 17
Processed: 18
Processed: 19
All embeddings saved to all_embeddings.pt
