# Precompute Per-Residue Embeddings using ProtT5

This notebook generates per-residue protein embeddings with ProtT5 and is tuned to run on a **Google Colab T4 GPU**.

**Before starting:**
1. **Enable GPU**: Runtime → Change runtime type → T4 GPU
2. Run cells in order from top to bottom

**Memory optimization for T4 (15GB VRAM):**
- Uses half precision (fp16) to reduce memory usage
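- Processes one sequence at a time and truncates sequences longer than `MAX_LENGTH` residues
- Frees per-sequence tensors right after saving and clears the CUDA cache every 10 sequences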

## 1. Install Dependencies

In [None]:
!pip install -q transformers torch pandas numpy tqdm sentencepiece accelerate

## 2. Import Libraries and Configuration

In [None]:
import os
import gc
import shutil
from transformers import T5Tokenizer, T5EncoderModel
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from google.colab import files

# --- Settings (tuned for a Colab T4 GPU) ---
MODEL_NAME = "Rostlab/prot_t5_xl_half_uniref50-enc"  # half-precision ProtT5 encoder checkpoint
USE_FP16 = True  # run the model in half precision; the GPU check cell switches this off on CPU
MAX_LENGTH = 1000  # longer sequences are truncated to this many residues to limit memory use
OUTPUT_DIR = "embeddings"  # folder that receives one .npy file per sequence

print("✓ Configuration loaded")

## 3. Check GPU Availability

In [None]:
# Select the compute device: the first CUDA GPU when one is visible, otherwise the CPU.
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

if DEVICE.type != 'cuda':
    print("⚠ No GPU detected! This will be very slow. Please enable GPU in Runtime settings.")
    # Half precision is only used on the GPU path, so force it off here.
    USE_FP16 = False
else:
    print(f"✓ GPU detected: {torch.cuda.get_device_name(0)}")
    print(f"✓ GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

# Summarise the resulting settings for the reader.
print(f"\nUsing device: {DEVICE}")
print(f"Half precision (FP16): {USE_FP16}")

## 4. Define Helper Function

In [None]:
# The ProtT5 usage instructions map the rare/ambiguous residues U, Z, O and B to X
# before tokenization; this table performs that substitution.
_RARE_RESIDUE_MAP = str.maketrans("UZOB", "XXXX")


def seq_to_spaced(seq: str) -> str:
    """Convert a protein sequence to the space-separated form ProtT5 expects.

    The sequence is normalised before spacing:
      * all whitespace (leading, trailing and embedded, e.g. line breaks) is removed,
      * letters are upper-cased, since the tokenizer is loaded with
        ``do_lower_case=False`` and is fed one upper-case letter per residue,
      * the rare/ambiguous residues U, Z, O and B are replaced by X.

    Args:
        seq: Amino-acid sequence in one-letter code.

    Returns:
        The residues joined by single spaces (``"ACD"`` -> ``"A C D"``);
        an empty string when ``seq`` contains no residues.
    """
    cleaned = "".join(seq.split()).upper().translate(_RARE_RESIDUE_MAP)
    return " ".join(cleaned)

print("✓ Helper function defined")

## 5. Upload Your CSV File

Your CSV must have columns: `id` and `sequence`

In [None]:
print("Please upload your CSV file (must have 'id' and 'sequence' columns):")
uploaded = files.upload()  # Colab upload widget; the result is keyed by uploaded file name

# A cancelled upload yields no entries; fail with a clear message instead of an IndexError.
if not uploaded:
    raise ValueError("No file was uploaded. Re-run this cell and choose a CSV file.")

# Only the first uploaded file is used as the input CSV.
INPUT_CSV = next(iter(uploaded))
if len(uploaded) > 1:
    print(f"⚠ {len(uploaded)} files were uploaded; only the first one will be used.")
print(f"\n✓ Uploaded file: {INPUT_CSV}")

## 6. Load and Validate Data

In [None]:
# Create output directory
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Load data
df = pd.read_csv(INPUT_CSV)
print(f"✓ Loaded {len(df)} sequences")
print(f"✓ Columns: {df.columns.tolist()}")
print("\nFirst 3 rows:")
display(df.head(3))

# Validate columns
if 'id' not in df.columns or 'sequence' not in df.columns:
    raise ValueError("CSV must contain 'id' and 'sequence' columns!")

# Report rows without a sequence value (read_csv turns blank cells into NaN)
n_missing = int(df['sequence'].isna().sum())
if n_missing > 0:
    print(f"\n⚠ Warning: {n_missing} rows have no sequence value")

# Report duplicate ids: embeddings are written to <id>.npy, so rows sharing an id
# end up writing to the same file.
n_duplicate_ids = int(df['id'].astype(str).duplicated().sum())
if n_duplicate_ids > 0:
    print(f"\n⚠ Warning: {n_duplicate_ids} rows reuse an id that appeared earlier; "
          f"their embedding files will overwrite each other")

# Check sequence lengths (missing values excluded, surrounding whitespace ignored;
# astype(str) keeps the .str accessor usable even if pandas parsed the column as numbers)
seq_lengths = df['sequence'].dropna().astype(str).str.strip().str.len()
print(f"\nSequence length statistics:")
print(f"  Min: {seq_lengths.min()}")
print(f"  Max: {seq_lengths.max()}")
print(f"  Mean: {seq_lengths.mean():.1f}")
print(f"  Median: {seq_lengths.median():.1f}")
n_long = (seq_lengths > MAX_LENGTH).sum()
if n_long > 0:
    print(f"\n⚠ Warning: {n_long} sequences will be truncated to {MAX_LENGTH} residues")

## 7. Load ProtT5 Model

This may take 2-3 minutes on first run...

In [None]:
print(f"Loading model: {MODEL_NAME}")
print("This may take 2-3 minutes on first run...\n")

try:
    # Load tokenizer (do_lower_case=False keeps the residue letters as written)
    tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME, do_lower_case=False)
    print("✓ Tokenizer loaded")

    # Load the encoder weights (initially into CPU memory)
    model = T5EncoderModel.from_pretrained(MODEL_NAME)
    print("✓ Model downloaded")

    # Convert to half precision BEFORE moving to the GPU: the weights may have been
    # loaded in a wider dtype, and converting first means only the fp16 copy is ever
    # allocated on the device, keeping peak GPU memory low.
    if USE_FP16:
        model = model.half()
        print("✓ Using half precision (FP16)")

    # Move to the target device and switch to inference mode (disables dropout)
    model = model.to(DEVICE)
    model.eval()
    print("✓ Model moved to device")

    print(f"\n✓ Model loaded successfully!")
    print(f"✓ Model dtype: {next(model.parameters()).dtype}")

    # Check GPU memory after loading
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated(0) / 1e9
        reserved = torch.cuda.memory_reserved(0) / 1e9
        print(f"✓ GPU memory allocated: {allocated:.2f} GB")
        print(f"✓ GPU memory reserved: {reserved:.2f} GB")

except Exception as e:
    # Print troubleshooting hints, then re-raise so the notebook stops at this cell.
    print(f"\n✗ Error loading model: {e}")
    print("\nTroubleshooting:")
    print("1. Make sure GPU is enabled (Runtime → Change runtime type → T4 GPU)")
    print("2. Try restarting the runtime (Runtime → Restart runtime)")
    raise

## 8. Generate Embeddings

This is the main processing step. It may take several minutes depending on the number of sequences.

In [None]:
print(f"Processing {len(df)} sequences...\n")

# (id, reason) pairs for every row that did not produce an embedding file
failed_sequences = []

for idx, row in enumerate(tqdm(df.itertuples(index=False), total=len(df), desc="Processing")):
    # Placeholders so the error handler and the cleanup below never hit an unbound name.
    sid = f"row {idx}"
    input_ids = attention_mask = outputs = emb = None
    had_error = False

    try:
        sid = str(getattr(row, "id"))
        raw_seq = getattr(row, "sequence")

        # Missing CSV cells arrive as NaN/None; str() would turn them into the literal
        # text "nan"/"None", which would then be embedded as if it were a protein.
        seq = "" if pd.isna(raw_seq) else str(raw_seq).strip()

        # Skip empty sequences
        if not seq:
            print(f"\n⚠ Skipping empty sequence: {sid}")
            failed_sequences.append((sid, "empty sequence"))
            continue

        # Truncate if sequence is too long
        if len(seq) > MAX_LENGTH:
            seq = seq[:MAX_LENGTH]

        # Prepare input: one sequence per forward pass, so no padding is required.
        # NOTE(review): add_special_tokens=True presumably appends an end-of-sequence
        # token, so the saved array would have one more row than the sequence has
        # residues — confirm that downstream code expects this.
        spaced_seq = seq_to_spaced(seq)
        encoded = tokenizer(spaced_seq, add_special_tokens=True, return_tensors="pt")
        input_ids = encoded["input_ids"].to(DEVICE)
        attention_mask = encoded["attention_mask"].to(DEVICE)

        # Generate embeddings (no gradients needed for inference)
        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)

        # Drop the batch dimension and bring the result back to host memory
        emb = outputs.last_hidden_state.squeeze(0).cpu()

        # Convert back to float32 for saving if using fp16
        if USE_FP16:
            emb = emb.float()

        # Save embeddings as <id>.npy
        output_path = os.path.join(OUTPUT_DIR, f"{sid}.npy")
        np.save(output_path, emb.numpy())

    except Exception as e:
        # Best effort: record the failure and carry on with the remaining sequences.
        print(f"\n✗ Error processing sequence {sid}: {e}")
        failed_sequences.append((sid, str(e)))
        had_error = True

    # Release this iteration's tensors whether it succeeded or failed
    del input_ids, attention_mask, outputs, emb

    # Clear the CUDA cache every 10 sequences, and right after any failure
    # (e.g. an out-of-memory error) so the following sequences can still run.
    if torch.cuda.is_available() and (had_error or (idx + 1) % 10 == 0):
        gc.collect()
        torch.cuda.empty_cache()

# Final cleanup
if torch.cuda.is_available():
    gc.collect()
    torch.cuda.empty_cache()

print(f"\n✓ Embeddings saved to {OUTPUT_DIR}/")

# Report failures
if failed_sequences:
    print(f"\n⚠ Failed to process {len(failed_sequences)} sequences:")
    for sid, error in failed_sequences[:5]:  # Show first 5
        print(f"  - {sid}: {error}")
    if len(failed_sequences) > 5:
        print(f"  ... and {len(failed_sequences) - 5} more")

## 9. Verify Results

In [None]:
# List the saved embeddings (sorted so the sample shown below is deterministic)
embedding_files = sorted(f for f in os.listdir(OUTPUT_DIR) if f.endswith('.npy'))
print(f"✓ Generated {len(embedding_files)} embedding files")

# Guard against an input CSV with zero rows, which would otherwise divide by zero
n_expected = len(df)
if n_expected > 0:
    print(f"✓ Success rate: {len(embedding_files)}/{n_expected} ({len(embedding_files)/n_expected*100:.1f}%)")
else:
    print("⚠ Success rate unavailable: the input CSV contained no sequences")

# Load one file back from disk to confirm it is readable and show its layout
if embedding_files:
    sample_file = embedding_files[0]
    sample_emb = np.load(os.path.join(OUTPUT_DIR, sample_file))
    print(f"\nSample embedding: {sample_file}")
    print(f"  Shape: {sample_emb.shape}")
    print(f"  Dtype: {sample_emb.dtype}")
    print(f"  Size: {sample_emb.nbytes / 1024:.1f} KB")
    print(f"  Embedding dimension: {sample_emb.shape[-1]}")

## 10. Download Embeddings

In [None]:
# Pack everything in OUTPUT_DIR into a zip archive and hand it to the browser.
print("Creating zip file...")
zip_filename = "embeddings"
archive_name = f"{zip_filename}.zip"  # make_archive appends the ".zip" suffix to the base name

shutil.make_archive(zip_filename, 'zip', OUTPUT_DIR)

# Report the archive size in mebibytes before starting the download
zip_size = os.path.getsize(archive_name) / (1024 * 1024)
print(f"✓ Zip file created: {archive_name} ({zip_size:.1f} MB)")

print("\nDownloading...")
files.download(archive_name)
print("\n✓ DONE! All embeddings have been generated and downloaded.")