# ProtT5 Activation Extraction

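This notebook loads the `Rostlab/prot_t5_xl_half_uniref50-enc` encoder model, fetches a UniRef50 sequence picked at random from one page of search results (filtered by modification date and a maximum length of 512), and extracts its per-layer hidden-state activations, which are saved to `random_protein_activations.pt`.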

In [None]:
import torch
from transformers import T5EncoderModel, T5Tokenizer
import requests
import random
import re
import numpy as np
import time

# 1. GPU Check
# CUDA is probed first; MPS is only queried when CUDA is unavailable.
cuda_ok = torch.cuda.is_available()
mps_ok = False if cuda_ok else torch.backends.mps.is_available()

# Fail fast when neither accelerator backend is present.
if not (cuda_ok or mps_ok):
    raise RuntimeError("No GPU (CUDA or MPS) found! This script requires a valid accelerator for efficient inference.")

# Exactly one of the two flags is set at this point; CUDA takes precedence.
if cuda_ok:
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
    device = torch.device("cuda")
else:
    print("Using Apple Silicon (MPS)")
    device = torch.device("mps")

In [None]:
# 2. Load Model and Tokenizer
model_name = 'Rostlab/prot_t5_xl_half_uniref50-enc'
print(f"Loading model: {model_name}...")

# The tokenizer is loaded with lower-casing disabled.
tokenizer = T5Tokenizer.from_pretrained(model_name, do_lower_case=False)

# Load the encoder weights as float16, move the module to the selected
# device, and switch it to evaluation mode in a single chained expression.
model = T5EncoderModel.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
).to(device).eval()
print("Model loaded successfully.")

In [None]:
# 3. Fetch Random UniRef50 Sequence
def get_random_uniref50_sequence(timeout=30):
    """Return the representative sequence of a randomly picked UniRef50 entry.

    A single page of up to 50 search results is requested from the UniRef
    search endpoint and one entry is drawn from it with ``random.choice``,
    so the draw is random only within that page of results.

    Args:
        timeout: Seconds ``requests`` waits on the server before giving up.
            Without a timeout a stalled connection would block forever.

    Returns:
        The value stored under ``representativeMember.sequence.value`` of
        the chosen entry.

    Raises:
        requests.exceptions.Timeout: The server did not answer in time.
        requests.exceptions.HTTPError: The server answered with an error
            status.
        ValueError: The response contains no results.
        KeyError: The chosen entry lacks the expected sequence/id fields.
    """
    # Filter: Modified before 2019, length <= 512
    # UniRef search API
    url = "https://rest.uniprot.org/uniref/search"
    # NOTE(review): there is no explicit AND before `identity:0.5`; this
    # presumably relies on the API treating whitespace as AND -- confirm.
    query = "date_modified:[* TO 2019-01-01] AND length:[1 TO 512] identity:0.5"

    params = {
        'query': query,
        'format': 'json',
        'size': 50,  # Fetch a batch
    }

    print("Fetching random sequence from UniRef50...")
    # Bound the wait so a hung connection cannot freeze the notebook.
    response = requests.get(url, params=params, timeout=timeout)
    response.raise_for_status()
    data = response.json()

    if 'results' not in data or not data['results']:
        raise ValueError("No results found for the query.")

    # Pick a random entry
    entry = random.choice(data['results'])

    # Extract sequence
    try:
        sequence = entry['representativeMember']['sequence']['value']
        accession = entry['id']
    except KeyError:
        # Fallback if structure is different: show what the entry does
        # contain, then let the original error propagate.
        print("Could not parse sequence from entry, dumping keys:", entry.keys())
        raise

    print(f"Selected UniRef50 Entry: {accession}")
    print(f"Sequence Length: {len(sequence)}")
    return sequence

# Run the fetch once at cell execution; the resulting module-level `sequence`
# is the input to the activation-extraction step.
sequence = get_random_uniref50_sequence()

In [None]:
# 4. Extract Activations

# Pre-processing: replace the residue codes U, Z, O and B with X, then put a
# space between every residue. A str is already iterable character by
# character, so no intermediate list is needed.
processed_seq = " ".join(re.sub(r"[UZOB]", "X", sequence))

# Tokenize (batch of one sequence)
# NOTE(review): add_special_tokens=True presumably appends an end-of-sequence
# token, so seq_len may be len(sequence) + 1 -- confirm before treating each
# row of the activations as one residue.
ids = tokenizer.batch_encode_plus([processed_seq], add_special_tokens=True, padding="longest")
input_ids = torch.tensor(ids['input_ids']).to(device)
attention_mask = torch.tensor(ids['attention_mask']).to(device)

print(f"Input IDs shape: {input_ids.shape}")

# Forward Pass (no gradients are needed for extraction)
print("Running inference...")
with torch.no_grad():
    output = model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)

# Extract Hidden States (excluding the last one)
# output.hidden_states is a tuple of (embedding_output, layer_1, ..., layer_N)
# We want all except the very last one (which is the final encoder output)
all_hidden_states = output.hidden_states[:-1]

# Stack them: (num_layers, batch_size, seq_len, hidden_dim)
stacked_activations = torch.stack(all_hidden_states)

# Remove batch dimension (since batch_size=1)
final_activations = stacked_activations.squeeze(1)

print(f"Extracted activations shape: {final_activations.shape}")
print("(Layers, Sequence_Length, Hidden_Dim)")
print("Note: Layers includes the initial embedding layer.")

# 5. Save
save_path = "random_protein_activations.pt"
# Save a CPU copy: a tensor serialized while on CUDA/MPS is restored to that
# same device by default, so the file could not be loaded with a plain
# torch.load() on a machine without that accelerator.
torch.save(final_activations.cpu(), save_path)
print(f"Saved activations to {save_path}")