In [None]:
"""
ESM-2 650M Protein Embedding Tool – Four-column TSV Format
Designed for Google Colab

Reads a tab-separated protein file (index<TAB>UniProt<TAB>Name<TAB>Sequence),
computes one mean-pooled ESM-2 650M embedding per protein, saves the
embeddings as .npy files and offers them for download as a zip archive.
"""

#@title 1. Install required packages
# `!` is IPython/Colab shell syntax; this cell is not plain Python.
# NOTE(review): the shell magic does not raise on a non-zero pip exit status,
# so the message below is printed even if the installation failed.
# NOTE(review): biopython is not imported anywhere in the visible code —
# confirm it is still needed.
!pip install fair-esm biopython -q
print("Packages installed successfully.")

#@title 2. Import libraries
import torch
import esm  # provided by the fair-esm package installed above
import numpy as np
from google.colab import files  # Colab-only upload/download helpers
import os
import re
from typing import List, Tuple
import gc  # NOTE(review): not referenced in the visible code
from tqdm import tqdm

print("Libraries loaded successfully.")

#@title 3. Load ESM-2 650M model
print("Loading ESM-2 650M model...")
# Prefer the GPU when Colab provides one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Fetch the pretrained weights together with their tokenizer alphabet, move
# the network to the chosen device and switch it to inference (eval) mode.
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
model = model.to(device).eval()
# Callable that turns [(label, sequence), ...] into padded token tensors.
batch_converter = alphabet.get_batch_converter()

print("ESM-2 650M model loaded.")

#@title 4. Define helper functions
def parse_protein_file(file_path: str) -> List[Tuple[str, str, str]]:
    """
    Parse a tab-separated protein file.

    Supported format:
        - Four-column format: index<TAB>UniProt<TAB>Name<TAB>Sequence

    If the first line has at least two columns and its first column is not an
    integer, it is treated as a header and skipped. Every other non-empty line
    must have exactly four tab-separated columns; malformed lines are reported
    and skipped. The Name column is ignored. Sequences are upper-cased and
    stripped of every character outside the 20 standard amino-acid letters;
    entries whose sequence ends up empty are dropped.

    Args:
        file_path: Path to the input text file (UTF-8, optionally with a BOM).

    Returns:
        List of (index, protein_id, sequence) tuples, in file order.
    """
    # 'utf-8-sig' transparently removes a leading byte-order mark, which files
    # exported from Excel/Windows often carry. Without it the BOM is glued to
    # the first index ("\ufeff1"), int() rejects it, and the first data row is
    # wrongly skipped as a header.
    with open(file_path, 'r', encoding='utf-8-sig') as f:
        lines = f.readlines()

    # Skip header if detected: real data rows start with an integer index.
    start_idx = 0
    if lines and lines[0].strip():
        first_parts = lines[0].strip().split('\t')
        if len(first_parts) >= 2:
            try:
                int(first_parts[0])
            except ValueError:
                print(f"Skipping header: {lines[0].strip()[:100]}")
                start_idx = 1

    # Anything that is not one of the 20 standard amino acids is removed.
    non_standard = re.compile(r'[^ACDEFGHIKLMNPQRSTVWY]')

    proteins = []
    for line in lines[start_idx:]:
        line = line.strip()
        if not line:
            continue

        parts = line.split('\t')
        if len(parts) != 4:  # expected four columns
            print(f"Skipping malformed line: {line[:50]}")
            continue

        # Strip stray spaces around fields so IDs (later used as file names)
        # do not carry leading/trailing whitespace.
        idx, protein_id, _name, sequence = (part.strip() for part in parts)

        sequence = non_standard.sub('', sequence.upper())
        if sequence:
            proteins.append((idx, protein_id, sequence))

    return proteins

def get_esm_embeddings(sequences: List[Tuple[str, str]], batch_size: int = 1, repr_layer=None):
    """
    Generate mean-pooled ESM embeddings for the given sequences.

    Relies on the module-level ``model``, ``batch_converter`` and ``device``.

    Args:
        sequences: List of (label, sequence) pairs. Labels become dict keys,
            so a repeated label keeps only the embedding of its last sequence.
        batch_size: Number of sequences per forward pass.
        repr_layer: Layer whose per-token representations are pooled. Defaults
            to the model's last layer (``model.num_layers``, i.e. 33 for
            ESM-2 650M), so swapping in another ESM-2 checkpoint still works.

    Returns:
        Dict mapping label -> numpy array: the mean of the chosen layer's
        representations over the sequence's residue positions.
    """
    if repr_layer is None:
        repr_layer = model.num_layers

    embeddings = {}

    for i in tqdm(range(0, len(sequences), batch_size), desc="Generating embeddings"):
        batch = sequences[i:i + batch_size]

        _labels, _strs, batch_tokens = batch_converter(list(batch))
        batch_tokens = batch_tokens.to(device)

        with torch.no_grad():
            results = model(batch_tokens, repr_layers=[repr_layer], return_contacts=False)
            token_representations = results["representations"][repr_layer]

            for j, (label, seq) in enumerate(batch):
                # Token 0 is the beginning-of-sequence token, so residues sit
                # at positions 1..len(seq). Slicing by len(seq) also excludes
                # the end token and any padding added for shorter sequences in
                # the same batch.
                seq_embedding = token_representations[j, 1:len(seq) + 1].mean(0)
                embeddings[label] = seq_embedding.cpu().numpy()

        # Release every reference to this batch's tensors *before* emptying
        # the CUDA cache: memory still referenced cannot be returned, and would
        # otherwise stay resident through the next forward pass.
        del batch_tokens, results, token_representations
        torch.cuda.empty_cache()

    return embeddings

def save_embeddings(embeddings: dict, output_dir: str = "protein_embeddings"):
    """
    Save embeddings as .npy files.

    Writes into ``output_dir`` (created if missing):
      * ``<protein_id>.npy`` — one file per protein. Path separators ('/' or
        '\\') in the ID are replaced by '_' so each file lands directly in
        ``output_dir`` instead of a (non-existent) subdirectory.
      * ``all_embeddings.npy`` — all embeddings stacked, in dict order.
      * ``protein_ids.txt`` — the original IDs, one per line, in the same
        order as the rows of ``all_embeddings.npy``.

    Args:
        embeddings: Mapping protein_id -> embedding array; all arrays must
            have the same shape so they can be stacked.
        output_dir: Destination directory.

    Returns:
        ``output_dir``.

    Raises:
        ValueError: if ``embeddings`` is empty (there is nothing to stack).
    """
    if not embeddings:
        raise ValueError("No embeddings to save.")

    os.makedirs(output_dir, exist_ok=True)

    # Save each protein embedding under a filesystem-safe name.
    for protein_id, embedding in embeddings.items():
        safe_name = re.sub(r'[\\/]', '_', str(protein_id))
        np.save(os.path.join(output_dir, f"{safe_name}.npy"), embedding)

    # Save all embeddings as one matrix plus the matching row labels.
    all_ids = list(embeddings)
    all_embeddings = np.stack([embeddings[pid] for pid in all_ids])
    np.save(os.path.join(output_dir, "all_embeddings.npy"), all_embeddings)

    with open(os.path.join(output_dir, "protein_ids.txt"), 'w', encoding='utf-8') as f:
        f.writelines(f"{pid}\n" for pid in all_ids)

    return output_dir

#@title 5. Main workflow

def process_protein_file(file_path: str, batch_size: int = 1):
    """
    Process the uploaded protein file and generate embeddings.

    Pipeline: parse the four-column file, embed every sequence, save the
    embeddings with save_embeddings(), then pack the output directory into
    ``protein_embeddings.zip`` in the current working directory.

    Args:
        file_path: Path to the tab-separated protein file.
        batch_size: Sequences per forward pass (forwarded to get_esm_embeddings).

    Returns:
        ``(zip_filename, embeddings)``, where ``embeddings`` maps protein_id to
        its embedding array, or ``(None, {})`` if no valid sequence was found.
    """
    import shutil
    from collections import Counter

    print(f"\nProcessing file: {file_path}")

    # 1. Parse file
    print("Parsing protein sequences...")
    proteins = parse_protein_file(file_path)

    if not proteins:
        print("No valid protein sequences found.")
        return None, {}

    print(f"Found {len(proteins)} protein sequences.")

    # Prepare (label, sequence) pairs for the model.
    sequences = [(protein_id, seq) for _idx, protein_id, seq in proteins]

    # Embeddings are stored in a dict keyed by protein_id, so a repeated ID
    # silently keeps only its last sequence. Make that visible to the user.
    id_counts = Counter(protein_id for protein_id, _seq in sequences)
    duplicates = [pid for pid, count in id_counts.items() if count > 1]
    if duplicates:
        shown = ", ".join(duplicates[:5]) + (" ..." if len(duplicates) > 5 else "")
        print(f"Warning: {len(duplicates)} protein ID(s) appear more than once; "
              f"only the last sequence of each is kept: {shown}")

    # Preview
    print("\nSequence samples:")
    for idx, protein_id, seq in proteins[:5]:
        preview = seq[:30] + "..." if len(seq) > 30 else seq
        print(f"  Index: {idx} | ID: {protein_id} | Length: {len(seq)} | Sequence: {preview}")
    if len(proteins) > 5:
        print(f"  ... {len(proteins)-5} more sequences")

    # 2. Generate embeddings
    print(f"\nGenerating embeddings (batch size: {batch_size})...")
    embeddings = get_esm_embeddings(sequences, batch_size)

    # 3. Save embeddings
    print("\nSaving embeddings...")
    output_dir = save_embeddings(embeddings)

    # 4. Zip output
    print("\nCreating archive...")
    zip_filename = "protein_embeddings.zip"
    # Archive with the standard library rather than the `zip` shell command:
    # make_archive always writes a fresh archive (`zip -r` merges into one
    # left over from a previous run), raises on failure, and keeps the same
    # "protein_embeddings/..." entry layout.
    # NOTE(review): per-protein .npy files left in output_dir by an earlier
    # run are still included; clear the directory first if that matters.
    abs_output_dir = os.path.abspath(output_dir)
    shutil.make_archive(
        os.path.splitext(zip_filename)[0],
        "zip",
        root_dir=os.path.dirname(abs_output_dir),
        base_dir=os.path.basename(abs_output_dir),
    )

    print("\nProcessing complete.")
    print(f"Embedding dimension: {next(iter(embeddings.values())).shape}")
    print(f"Output directory: {output_dir}/")

    return zip_filename, embeddings

#@title 6. Upload and process file

# Banner and input-format instructions shown before the upload widget.
print("=" * 50)
print("ESM-2 650M Protein Embedding Tool")
print("=" * 50)
print("\n".join([
    "\nFile requirements:",
    "• Four-column format: index<TAB>UniProt<TAB>Name<TAB>Sequence",
    "\nExamples:",
    "1\tP12345\tAlpha\tMKLLIALSLGALV...",
    "2\tP67890\tBeta\tMASNFTQFVLVDNG...",
    "-" * 50,
    "\nUpload your protein sequence file (TXT):",
]))

# Opens Colab's upload widget; returns {saved_filename: content_bytes}.
uploaded = files.upload()

if not uploaded:
    print("\nNo file uploaded.")
else:
    # Only the first uploaded file is processed.
    filename = next(iter(uploaded))
    print(f"\nFile uploaded: {filename}")

    batch_size = 1

    try:
        zip_file, embeddings = process_protein_file(filename, batch_size)

        if zip_file:
            print("\nPreparing download...")
            files.download(zip_file)
            print("Download started.")

            print("\nSummary:")
            print(f"  Proteins processed: {len(embeddings)}")
            print(f"  Embedding dimension: {next(iter(embeddings.values())).shape[0]}")
            print("  Output files:")
            for output_item in ("all_embeddings.npy", "protein_ids.txt", "Individual .npy files per protein"):
                print(f"    • {output_item}")

    except Exception as e:
        # Top-level notebook boundary: report the failure with a full trace
        # instead of aborting the cell.
        import traceback
        print(f"\nError encountered: {e!s}")
        print("Please check input file format.")
        traceback.print_exc()

#@title 7. Optional: View embedding sample
# Runs only when the previous cell produced a non-empty `embeddings` dict.
if 'embeddings' in locals() and embeddings:
    print("\nEmbedding sample:")
    first_key = next(iter(embeddings))
    first_embedding = embeddings[first_key]
    print(
        f"Protein ID: {first_key}\n"
        f"Embedding shape: {first_embedding.shape}\n"
        f"First 10 values: {first_embedding[:10]}"
    )

    import matplotlib.pyplot as plt

    fig, (dist_ax, dims_ax) = plt.subplots(1, 2, figsize=(12, 4))

    # Left panel: histogram of every value in the embedding.
    dist_ax.hist(first_embedding, bins=50, edgecolor='black', alpha=0.7)
    dist_ax.set_title(f'Value Distribution (ID: {first_key})')
    dist_ax.set_xlabel('Value')
    dist_ax.set_ylabel('Frequency')

    # Right panel: raw values of the first 100 dimensions.
    dims_ax.plot(first_embedding[:100], alpha=0.7)
    dims_ax.set_title('First 100 Dimensions')
    dims_ax.set_xlabel('Dimension')
    dims_ax.set_ylabel('Value')
    dims_ax.grid(True, alpha=0.3)

    fig.tight_layout()
    plt.show()

print("\nDone.")
