# Step 2: Generate ESM-2 Embeddings (Google Colab - GPU Required)

**Purpose**: Generate mean-pooled ESM-2 (650M) embeddings, one 1280-dimensional vector per sequence, for UniProt terpene synthase sequences

**Runtime**: ~30-60 minutes for 5000 sequences on Colab T4 GPU

**Instructions**:
1. Upload `uniprot_tps_sequences.fasta` from Step 1
2. Run all cells
3. Download `uniprot_tps_embeddings.npy` (one embedding row per sequence, in FASTA order) for local prediction in Step 3

---


## Setup: Install Dependencies


In [None]:
# Install required packages (-q keeps pip's output quiet).
# The visible cells import torch, transformers, Bio.SeqIO (biopython) and tqdm.
# NOTE(review): fair-esm is not imported by any visible cell - the model is
# loaded through transformers - so it may be removable; confirm before dropping.
!pip install -q fair-esm torch transformers biopython tqdm


## Verify GPU Access


In [None]:
import torch

# Report the first CUDA device and its total memory, or explain how to
# enable a GPU runtime when none is visible to torch.
if not torch.cuda.is_available():
    print("✗ WARNING: No GPU detected. This will be very slow!")
    print("  Go to Runtime > Change runtime type > GPU")
else:
    gpu_name = torch.cuda.get_device_name(0)
    print(f"✓ GPU available: {gpu_name}")
    # total_memory is reported in bytes; show it in decimal gigabytes.
    gpu_memory_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"✓ GPU memory: {gpu_memory_gb:.2f} GB")


## Upload FASTA File

Click the file icon (📁) on the left sidebar and upload `uniprot_tps_sequences.fasta`


In [None]:
# Check if FASTA file is uploaded
import os

# Name of the FASTA file expected in the notebook's working directory.
fasta_file = 'uniprot_tps_sequences.fasta'

# Only report presence here; later cells read the file.
if not os.path.exists(fasta_file):
    print(f"✗ FASTA file not found: {fasta_file}")
    print("  Please upload the file using the file browser")
else:
    print(f"✓ Found FASTA file: {fasta_file}")
    # getsize returns bytes; show decimal megabytes.
    fasta_size_mb = os.path.getsize(fasta_file) / 1e6
    print(f"  File size: {fasta_size_mb:.2f} MB")


## Load ESM-2 Model


In [None]:
import torch
from transformers import AutoTokenizer, EsmModel

# Tell the user up front that the load is slow.
print("Loading ESM-2 model (650M parameters)...")
print("This may take 2-3 minutes...")

# Checkpoint identifier shared by the tokenizer and the model.
model_name = "facebook/esm2_t33_650M_UR50D"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Use CUDA when torch can see a GPU, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the weights, place them on the chosen device, and switch to eval mode.
model = EsmModel.from_pretrained(model_name).to(device)
model.eval()

print(f"✓ Model loaded on {device}")


## Parse FASTA File


In [None]:
from Bio import SeqIO
from typing import List, Tuple

def parse_fasta(fasta_file: str) -> List[Tuple[str, str]]:
    """Read every record of a FASTA file into memory.

    Args:
        fasta_file: Path of the FASTA file to read.

    Returns:
        One ``(record id, sequence string)`` tuple per record, in file order.
    """
    return [
        (entry.id, str(entry.seq))
        for entry in SeqIO.parse(fasta_file, 'fasta')
    ]

# Parse the uploaded FASTA file and print a short summary.
print(f"Parsing {fasta_file}...")
sequences = parse_fasta(fasta_file)
print(f"✓ Loaded {len(sequences)} sequences")

# Guard the average: an empty or malformed FASTA yields no records, and
# dividing by len(sequences) would raise ZeroDivisionError.
if sequences:
    average_length = sum(len(seq) for _, seq in sequences) / len(sequences)
    print(f"  Average length: {average_length:.0f} amino acids")
else:
    print("✗ WARNING: No sequences were parsed. Check that the FASTA file is not empty or malformed.")


## Generate Embeddings


In [None]:
import numpy as np
from tqdm.auto import tqdm

def _embed_batch(batch_sequences: List[str], model, tokenizer, device) -> np.ndarray:
    """Return one mean-pooled embedding row per sequence in *batch_sequences*.

    The sequences are tokenized together (padded to the longest one and
    truncated to 1024 tokens). Each row is the mean of the model's last hidden
    state over that sequence's attended (non-padding) positions, so padding
    never dilutes the average. For a single sequence this equals a plain mean
    over all token positions.
    """
    inputs = tokenizer(
        batch_sequences,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=1024
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}

    hidden = model(**inputs).last_hidden_state
    # (batch, seq_len) -> (batch, seq_len, 1) so it broadcasts over the hidden dim.
    mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
    pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
    return pooled.float().cpu().numpy()


def generate_embeddings(sequences: List[Tuple[str, str]], 
                       model, 
                       tokenizer, 
                       device,
                       batch_size: int = 1) -> np.ndarray:
    """
    Generate ESM-2 embeddings for sequences.

    Sequences are processed in groups of ``batch_size``. If a whole batch
    fails, its sequences are retried one at a time so that a single bad
    sequence does not take its batch mates down with it. A sequence that
    still fails is reported and its row is left as zeros.

    Args:
        sequences: ``(header, sequence)`` tuples; output rows follow this order.
        model: Model whose output exposes ``last_hidden_state``.
        tokenizer: Tokenizer matching ``model``; must return ``attention_mask``.
        device: Torch device the tokenized inputs are moved to.
        batch_size: Number of sequences per forward pass (must be >= 1).

    Returns:
        embeddings: float32 numpy array of shape (n_sequences, hidden_size),
        where hidden_size is ``model.config.hidden_size`` (1280 if the model
        does not expose it).

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Take the width from the model instead of hard-coding it; 1280 is the fallback.
    hidden_size = getattr(getattr(model, "config", None), "hidden_size", 1280)

    # Pre-allocating as float32 keeps the dtype and the 2-D shape stable even
    # when the input is empty or some rows stay at the zero fallback.
    embeddings = np.zeros((len(sequences), hidden_size), dtype=np.float32)
    failed_headers = []

    print(f"Generating embeddings for {len(sequences)} sequences...")
    print(f"Using batch size: {batch_size}")

    with torch.no_grad(), tqdm(total=len(sequences), desc="Processing sequences") as progress:
        for start in range(0, len(sequences), batch_size):
            batch = sequences[start:start + batch_size]
            try:
                embeddings[start:start + len(batch)] = _embed_batch(
                    [sequence for _, sequence in batch], model, tokenizer, device
                )
            except Exception as batch_error:
                # Deliberate best-effort: keep going and leave failed rows as zeros.
                for offset, (header, sequence) in enumerate(batch):
                    error = batch_error
                    if len(batch) > 1:
                        # Retry alone to isolate which sequence actually fails.
                        try:
                            embeddings[start + offset] = _embed_batch(
                                [sequence], model, tokenizer, device
                            )[0]
                            continue
                        except Exception as single_error:
                            error = single_error
                    print(f"\nError processing {header}: {error}")
                    print(f"Sequence length: {len(sequence)}")
                    failed_headers.append(header)
            progress.update(len(batch))

    print(f"✓ Generated embeddings shape: {embeddings.shape}")
    if failed_headers:
        preview = ", ".join(failed_headers[:5])
        suffix = ", ..." if len(failed_headers) > 5 else ""
        print(f"⚠️  {len(failed_headers)} sequence(s) failed and were left as zero vectors: {preview}{suffix}")

    return embeddings

# Embed every parsed sequence with the loaded model and tokenizer.
embeddings = generate_embeddings(
    sequences=sequences,
    model=model,
    tokenizer=tokenizer,
    device=device,
    batch_size=1,
)


In [None]:
# Name of the .npy file written to the notebook's working directory.
output_file = 'uniprot_tps_embeddings.npy'

print(f"Saving embeddings to {output_file}...")
np.save(output_file, embeddings)

# Confirm what was written: name, on-disk size in decimal MB, array shape.
saved_size_mb = os.path.getsize(output_file) / 1e6
for summary_line in (
    "✓ Embeddings saved!",
    f"  File: {output_file}",
    f"  Size: {saved_size_mb:.2f} MB",
    f"  Shape: {embeddings.shape}",
):
    print(summary_line)

# Closing banner with the hand-off instructions for the next step.
banner = "=" * 60
for closing_line in (
    "\n" + banner,
    "EMBEDDING GENERATION COMPLETE!",
    banner,
    "Next steps:",
    "1. Download 'uniprot_tps_embeddings.npy' (right-click > Download)",
    "2. Move it to your local 'data/' directory",
    "3. Run step3_predict_germacrene.py locally",
    banner,
):
    print(closing_line)


## (Optional) Verify Embeddings


In [None]:
# Quick sanity check: summary statistics plus checks for NaN/Inf values and
# for all-zero rows, which are the fallback written for failed sequences.
print("Embedding statistics:")

if embeddings.size == 0:
    # min()/max() raise on an empty array, so report and stop here.
    print("\n⚠️  WARNING: Embeddings array is empty - nothing to check!")
else:
    has_nan = bool(np.isnan(embeddings).any())
    has_inf = bool(np.isinf(embeddings).any())
    # A row with no non-zero entry never received a real embedding.
    zero_rows = int((~embeddings.any(axis=1)).sum()) if embeddings.ndim == 2 else 0

    print(f"  Mean: {embeddings.mean():.4f}")
    print(f"  Std: {embeddings.std():.4f}")
    print(f"  Min: {embeddings.min():.4f}")
    print(f"  Max: {embeddings.max():.4f}")
    print(f"  Contains NaN: {has_nan}")
    print(f"  Contains Inf: {has_inf}")
    print(f"  All-zero rows: {zero_rows}")

    if has_nan or has_inf:
        print("\n⚠️  WARNING: Embeddings contain NaN or Inf values!")
    if zero_rows:
        print(f"\n⚠️  WARNING: {zero_rows} embedding(s) are all zeros - those sequences failed to embed!")
    if not (has_nan or has_inf or zero_rows):
        print("\n✓ Embeddings look good!")
