In [None]:
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ESM2 Protein Embedding Analysis\n",
    "\n",
    "**Author:** Soumith Paritala  \n",
    "**Purpose:** Generate protein embeddings using ESM2-650M with GPU acceleration  \n",
    "**Platform:** Google Colab with a GPU runtime (the upload and download cells use `google.colab`; replace them to run in other Jupyter environments)\n",
    "\n",
    "---\n",
    "\n",
    "## Overview\n",
    "\n",
    "This notebook demonstrates:\n",
    "1. Loading protein sequences from FASTA/text files\n",
    "2. Generating embeddings with ESM2 on GPU\n",
    "3. Analyzing protein similarities\n",
    "4. Visualizing relationships with PCA and t-SNE\n",
    "\n",
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 1: Check GPU Availability"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!nvidia-smi --query-gpu=name,memory.total --format=csv,noheader\n",
    "\n",
    "import torch\n",
    "\n",
    "# Report the PyTorch version and whether PyTorch can see a CUDA device.\n",
    "cuda_ok = torch.cuda.is_available()\n",
    "print(f\"\\nPyTorch: {torch.__version__}\")\n",
    "print(f\"CUDA available: {cuda_ok}\")\n",
    "\n",
    "if cuda_ok:\n",
    "    # Name and total memory of the first visible GPU.\n",
    "    print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n",
    "    vram_gb = torch.cuda.get_device_properties(0).total_memory / 1e9\n",
    "    print(f\"VRAM: {vram_gb:.1f} GB\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 2: Upload Protein Sequence File"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from google.colab import files\n",
    "import os\n",
    "\n",
    "print(\"=\" * 60)\n",
    "print(\"UPLOAD YOUR PROTEIN SEQUENCE FILE\")\n",
    "print(\"=\" * 60)\n",
    "print(\"\\nSupported formats:\")\n",
    "print(\"  • FASTA (.fasta, .fa, .faa)\")\n",
    "print(\"  • Plain text (.txt)\")\n",
    "print(\"\\nClick 'Choose Files' button below...\\n\")\n",
    "\n",
    "uploaded = files.upload()\n",
    "# Fail with a clear message if no file came back, instead of a bare IndexError below.\n",
    "if not uploaded:\n",
    "    raise RuntimeError(\"No file was uploaded - re-run this cell and choose a sequence file.\")\n",
    "if len(uploaded) > 1:\n",
    "    print(f\"⚠ {len(uploaded)} files uploaded; only the first one is analyzed.\")\n",
    "\n",
    "input_file = next(iter(uploaded))\n",
    "file_ext = os.path.splitext(input_file)[1].lower()\n",
    "\n",
    "print(f\"\\n✓ File uploaded: {input_file}\")\n",
    "print(f\"  Size: {os.path.getsize(input_file) / 1024:.1f} KB\")\n",
    "print(\"\\nFirst 15 lines:\")\n",
    "# Quote the path so file names containing spaces survive the shell.\n",
    "!head -n 15 \"{input_file}\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 3: Smart File Parser"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def detect_file_format(filepath):\n",
    "    \"\"\"Return 'fasta' if the first non-blank line starts with '>', else 'plain'.\"\"\"\n",
    "    with open(filepath, 'r') as f:\n",
    "        for line in f:\n",
    "            line = line.strip()\n",
    "            if line:\n",
    "                return 'fasta' if line.startswith('>') else 'plain'\n",
    "    return 'plain'\n",
    "\n",
    "def _header_to_name(header, fallback):\n",
    "    \"\"\"Pick a short protein name from a FASTA header (the text after '>').\n",
    "\n",
    "    UniProt-style headers ('sp|P69905|HBA_HUMAN ...') give the third '|' field\n",
    "    (the entry name); other headers give their first whitespace-separated token.\n",
    "    Returns `fallback` when the chosen field is empty.\n",
    "    \"\"\"\n",
    "    parts = header.split('|')\n",
    "    field = parts[2] if len(parts) >= 3 else parts[0]\n",
    "    tokens = field.split()\n",
    "    return tokens[0] if tokens else fallback\n",
    "\n",
    "def read_fasta_file(filepath):\n",
    "    \"\"\"Parse a FASTA file into a list of (name, sequence) tuples.\n",
    "\n",
    "    Multi-line sequences are joined and spaces removed. Records with no sequence\n",
    "    lines are dropped, as are sequence lines that appear before the first header.\n",
    "    \"\"\"\n",
    "    sequences = []\n",
    "    current_name = None\n",
    "    current_seq = []\n",
    "    n_headers = 0\n",
    "\n",
    "    with open(filepath, 'r') as f:\n",
    "        for line in f:\n",
    "            line = line.strip()\n",
    "            if not line:\n",
    "                continue\n",
    "\n",
    "            if line.startswith('>'):\n",
    "                if current_name and current_seq:\n",
    "                    sequences.append((current_name, ''.join(current_seq).replace(' ', '')))\n",
    "                n_headers += 1\n",
    "                # A bare '>' or an empty UniProt field would otherwise raise IndexError.\n",
    "                current_name = _header_to_name(line[1:], f\"seq_{n_headers}\")\n",
    "                current_seq = []\n",
    "            else:\n",
    "                current_seq.append(line)\n",
    "\n",
    "    if current_name and current_seq:\n",
    "        sequences.append((current_name, ''.join(current_seq).replace(' ', '')))\n",
    "\n",
    "    return sequences\n",
    "\n",
    "def read_plain_text_file(filepath):\n",
    "    \"\"\"Parse a header-less text file, assuming one sequence per non-blank line.\n",
    "\n",
    "    Sequences are named seq_1, seq_2, ... in file order.\n",
    "    \"\"\"\n",
    "    sequences = []\n",
    "    with open(filepath, 'r') as f:\n",
    "        for line in f:\n",
    "            line = line.strip()\n",
    "            if line:\n",
    "                sequences.append((f\"seq_{len(sequences) + 1}\", line.replace(' ', '')))\n",
    "    return sequences\n",
    "\n",
    "def validate_protein_sequence(sequence):\n",
    "    \"\"\"Upper-case a sequence, strip spaces and check that it is usable.\n",
    "\n",
    "    Returns (is_valid, cleaned_sequence, error_message_or_None). Only the 20\n",
    "    standard amino-acid letters are accepted, and the length must be 10-50000.\n",
    "    \"\"\"\n",
    "    valid_aa = set('ACDEFGHIKLMNPQRSTVWY')\n",
    "    sequence = sequence.upper().replace(' ', '')\n",
    "    invalid_chars = set(sequence) - valid_aa\n",
    "\n",
    "    if invalid_chars:\n",
    "        return False, sequence, f\"Invalid characters: {sorted(invalid_chars)}\"\n",
    "    if len(sequence) < 10:\n",
    "        return False, sequence, f\"Too short ({len(sequence)} aa)\"\n",
    "    if len(sequence) > 50000:\n",
    "        return False, sequence, f\"Too long ({len(sequence)} aa)\"\n",
    "\n",
    "    return True, sequence, None\n",
    "\n",
    "# Parse file\n",
    "print(\"=\" * 60)\n",
    "print(\"PARSING FILE\")\n",
    "print(\"=\" * 60)\n",
    "\n",
    "file_format = detect_file_format(input_file)\n",
    "print(f\"Format: {file_format.upper()}\\n\")\n",
    "\n",
    "# Plain-text files have no '>' headers, so the FASTA reader would silently\n",
    "# return nothing for them; dispatch on the detected format instead.\n",
    "readers = {'fasta': read_fasta_file, 'plain': read_plain_text_file}\n",
    "raw_sequences = readers[file_format](input_file)\n",
    "\n",
    "validated_sequences = []\n",
    "for name, seq in raw_sequences:\n",
    "    is_valid, cleaned_seq, error = validate_protein_sequence(seq)\n",
    "    if is_valid:\n",
    "        validated_sequences.append((name, cleaned_seq))\n",
    "        print(f\"✓ {name}: {len(cleaned_seq)} aa\")\n",
    "    else:\n",
    "        print(f\"✗ {name}: {error}\")\n",
    "\n",
    "sequences = validated_sequences\n",
    "print(f\"\\n✓ Loaded {len(sequences)} valid sequences\")\n",
    "if not sequences:\n",
    "    raise ValueError(\"No valid sequences found - check the file format and contents above.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 4: Load ESM2 Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install -q fair-esm==2.0.0 biopython\n",
    "\n",
    "import esm\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "print(\"Loading ESM2-650M model...\")\n",
    "model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "model = model.to(device)\n",
    "model.eval()  # inference mode: disables dropout\n",
    "\n",
    "print(f\"✓ Model loaded on {device}\")\n",
    "# Report the architecture from the model itself rather than hard-coded numbers.\n",
    "print(f\"  Layers: {model.num_layers}\")\n",
    "print(f\"  Embedding dim: {model.embed_dim}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 5: Generate Embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "BATCH_SIZE = 4  # sequences per forward pass; lower this if the GPU runs out of memory\n",
    "repr_layer = model.num_layers  # final transformer layer (33 for ESM2-650M)\n",
    "\n",
    "print(\"=\" * 60)\n",
    "print(\"GENERATING EMBEDDINGS\")\n",
    "print(\"=\" * 60)\n",
    "\n",
    "if not sequences:\n",
    "    raise ValueError(\"No valid sequences to embed - check the parsing output above.\")\n",
    "\n",
    "batch_converter = alphabet.get_batch_converter()\n",
    "\n",
    "# Same per-token layout as one batch_converter call over all sequences: ESM2 puts a\n",
    "# BOS token at index 0 and EOS after the last residue, so the residue embeddings of\n",
    "# sequence i are embeddings[i, 1:len(seq) + 1]; trailing rows are zero padding.\n",
    "max_tokens = max(len(seq) for _, seq in sequences) + 2\n",
    "embeddings = np.zeros((len(sequences), max_tokens, model.embed_dim), dtype=np.float32)\n",
    "\n",
    "start_time = time.time()\n",
    "with torch.no_grad():\n",
    "    # Small batches keep one huge padded batch from exhausting GPU memory.\n",
    "    for start in range(0, len(sequences), BATCH_SIZE):\n",
    "        batch = sequences[start:start + BATCH_SIZE]\n",
    "        _, _, batch_tokens = batch_converter(batch)\n",
    "        results = model(batch_tokens.to(device), repr_layers=[repr_layer], return_contacts=False)\n",
    "        batch_reps = results[\"representations\"][repr_layer].float().cpu().numpy()\n",
    "        embeddings[start:start + len(batch), :batch_reps.shape[1]] = batch_reps\n",
    "elapsed = time.time() - start_time\n",
    "\n",
    "print(f\"✓ Complete in {elapsed:.1f}s\")\n",
    "print(f\"\\nShape: {embeddings.shape}\")\n",
    "print(f\"  {embeddings.shape[0]} proteins\")\n",
    "print(f\"  {embeddings.shape[2]} dimensions\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 6: Save Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import shutil\n",
    "\n",
    "os.makedirs(\"results\", exist_ok=True)\n",
    "\n",
    "# Keep a copy of the raw input next to the results (no shell quoting pitfalls).\n",
    "shutil.copy(input_file, f\"results/input_sequences{file_ext}\")\n",
    "\n",
    "used_names = set()\n",
    "for i, (name, seq) in enumerate(sequences):\n",
    "    clean_name = name.replace(\"|\", \"_\").replace(\"/\", \"_\").replace(\":\", \"_\")\n",
    "    # Disambiguate duplicate protein names so later files don't overwrite earlier ones.\n",
    "    base_name, suffix = clean_name, 2\n",
    "    while clean_name in used_names:\n",
    "        clean_name = f\"{base_name}_{suffix}\"\n",
    "        suffix += 1\n",
    "    used_names.add(clean_name)\n",
    "\n",
    "    # Keep only residue positions: drop the BOS token at index 0, the EOS token\n",
    "    # and any trailing padding rows that pad shorter sequences to a common length.\n",
    "    residue_emb = embeddings[i, 1:len(seq) + 1]\n",
    "    np.save(f\"results/{clean_name}_embedding.npy\", residue_emb)\n",
    "\n",
    "    with open(f\"results/{clean_name}_info.txt\", \"w\") as f:\n",
    "        f.write(f\"Protein: {name}\\n\")\n",
    "        f.write(f\"Length: {len(seq)} aa\\n\\n\")\n",
    "        f.write(\"Sequence:\\n\")\n",
    "        for j in range(0, len(seq), 60):\n",
    "            f.write(seq[j:j+60] + \"\\n\")\n",
    "        f.write(f\"\\nEmbedding shape: {residue_emb.shape}\\n\")\n",
    "        f.write(f\"Mean: {residue_emb.mean():.6f}\\n\")\n",
    "        f.write(f\"Std: {residue_emb.std():.6f}\\n\")\n",
    "\n",
    "    print(f\"✓ {name}\")\n",
    "\n",
    "print(\"\\n✓ All saved to results/\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 7: Analyze Similarities"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.spatial.distance import cosine\n",
    "\n",
    "if len(sequences) >= 2:\n",
    "    print(\"=\" * 60)\n",
    "    print(\"PAIRWISE SIMILARITIES\")\n",
    "    print(\"=\" * 60)\n",
    "\n",
    "    # Mean-pool over residue positions only. Each row of `embeddings` also holds\n",
    "    # the BOS token (index 0), the EOS token and padding up to the longest\n",
    "    # sequence; averaging all of axis 1 would bias the result, most strongly for\n",
    "    # the shortest sequences.\n",
    "    avg_embeddings = np.stack([\n",
    "        embeddings[i, 1:len(seq) + 1].mean(axis=0)\n",
    "        for i, (_, seq) in enumerate(sequences)\n",
    "    ])\n",
    "\n",
    "    for i in range(len(sequences)):\n",
    "        for j in range(i + 1, len(sequences)):\n",
    "            name1, name2 = sequences[i][0], sequences[j][0]\n",
    "            similarity = 1 - cosine(avg_embeddings[i], avg_embeddings[j])\n",
    "            print(f\"{name1[:25]:25s} ↔ {name2[:25]:25s}: {similarity:.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 8: Visualize"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.manifold import TSNE\n",
    "\n",
    "\n",
    "def _scatter_proteins(ax, coords, names, colors):\n",
    "    \"\"\"Draw one labelled marker per protein at the given 2-D coordinates.\"\"\"\n",
    "    for i, name in enumerate(names):\n",
    "        ax.scatter(coords[i, 0], coords[i, 1], s=400, alpha=0.7,\n",
    "                   color=colors[i], edgecolor='black', linewidth=1.5)\n",
    "        ax.annotate(name[:20], (coords[i, 0], coords[i, 1]),\n",
    "                    xytext=(5, 5), textcoords='offset points', fontweight='bold')\n",
    "    ax.grid(True, alpha=0.3)\n",
    "\n",
    "\n",
    "if len(sequences) >= 2:\n",
    "    names = [name for name, _ in sequences]\n",
    "    # Mean-pool over residue positions only (skip BOS at index 0, EOS and padding).\n",
    "    avg_embeddings = np.stack([\n",
    "        embeddings[i, 1:len(seq) + 1].mean(axis=0)\n",
    "        for i, (_, seq) in enumerate(sequences)\n",
    "    ])\n",
    "\n",
    "    # tab10 has only 10 colors; integer lookups past that all map to the same\n",
    "    # color, so sample a continuous colormap for larger sets.\n",
    "    if len(sequences) <= 10:\n",
    "        colors = plt.cm.tab10(range(len(sequences)))\n",
    "    else:\n",
    "        colors = plt.cm.viridis(np.linspace(0, 1, len(sequences)))\n",
    "\n",
    "    if len(sequences) >= 3:\n",
    "        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 7))\n",
    "    else:\n",
    "        fig, ax1 = plt.subplots(1, 1, figsize=(10, 7))\n",
    "        ax2 = None\n",
    "\n",
    "    # PCA\n",
    "    pca = PCA(n_components=2)\n",
    "    emb_pca = pca.fit_transform(avg_embeddings)\n",
    "    _scatter_proteins(ax1, emb_pca, names, colors)\n",
    "    ax1.set_xlabel(f\"PC1 ({pca.explained_variance_ratio_[0]*100:.1f}%)\", fontweight='bold')\n",
    "    ax1.set_ylabel(f\"PC2 ({pca.explained_variance_ratio_[1]*100:.1f}%)\", fontweight='bold')\n",
    "    ax1.set_title(\"ESM2 Protein Embeddings - PCA\", fontweight='bold', fontsize=15)\n",
    "\n",
    "    # t-SNE (scikit-learn requires perplexity < number of samples)\n",
    "    if ax2 is not None:\n",
    "        tsne = TSNE(n_components=2, random_state=42, perplexity=min(len(sequences)-1, 5))\n",
    "        emb_tsne = tsne.fit_transform(avg_embeddings)\n",
    "        _scatter_proteins(ax2, emb_tsne, names, colors)\n",
    "        ax2.set_xlabel(\"t-SNE Dimension 1\", fontweight='bold')\n",
    "        ax2.set_ylabel(\"t-SNE Dimension 2\", fontweight='bold')\n",
    "        ax2.set_title(\"ESM2 Protein Embeddings - t-SNE\", fontweight='bold', fontsize=15)\n",
    "\n",
    "    plt.tight_layout()\n",
    "    os.makedirs(\"results\", exist_ok=True)\n",
    "    plt.savefig(\"results/protein_embeddings_visualization.png\", dpi=300, bbox_inches='tight')\n",
    "    plt.show()\n",
    "    print(\"✓ Visualization saved\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 9: Create Summary Report"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.spatial.distance import cosine\n",
    "\n",
    "os.makedirs(\"results\", exist_ok=True)\n",
    "\n",
    "# Explicit UTF-8: the report contains non-ASCII characters (\"↔\").\n",
    "with open(\"results/analysis_summary.txt\", \"w\", encoding=\"utf-8\") as f:\n",
    "    f.write(\"=\" * 70 + \"\\n\")\n",
    "    f.write(\"ESM2 PROTEIN EMBEDDING ANALYSIS\\n\")\n",
    "    f.write(\"=\" * 70 + \"\\n\\n\")\n",
    "    f.write(f\"Input: {input_file}\\n\")\n",
    "    f.write(f\"Proteins: {len(sequences)}\\n\")\n",
    "    f.write(f\"Model: ESM2-650M\\n\")\n",
    "    f.write(f\"Device: {device}\\n\\n\")\n",
    "\n",
    "    for i, (name, seq) in enumerate(sequences, 1):\n",
    "        f.write(f\"{i}. {name}: {len(seq)} aa\\n\")\n",
    "\n",
    "    if len(sequences) >= 2:\n",
    "        # Recompute residue-only mean embeddings (skipping BOS at index 0, EOS and\n",
    "        # padding) so this cell does not depend on state left by earlier cells.\n",
    "        avg_embeddings = np.stack([\n",
    "            embeddings[i, 1:len(seq) + 1].mean(axis=0)\n",
    "            for i, (_, seq) in enumerate(sequences)\n",
    "        ])\n",
    "        f.write(\"\\n\" + \"=\" * 70 + \"\\n\")\n",
    "        f.write(\"SIMILARITIES\\n\")\n",
    "        f.write(\"=\" * 70 + \"\\n\")\n",
    "        for i in range(len(sequences)):\n",
    "            for j in range(i + 1, len(sequences)):\n",
    "                similarity = 1 - cosine(avg_embeddings[i], avg_embeddings[j])\n",
    "                f.write(f\"{sequences[i][0][:30]:30s} ↔ {sequences[j][0][:30]:30s}: {similarity:.4f}\\n\")\n",
    "\n",
    "print(\"✓ Summary created\")\n",
    "!head -n 40 results/analysis_summary.txt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 10: Download Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import shutil\n",
    "from google.colab import files\n",
    "\n",
    "# make_archive writes the archive from scratch; `zip -r` would update an existing\n",
    "# esm2_results.zip in place and keep stale entries from earlier runs.\n",
    "shutil.make_archive(\"esm2_results\", \"zip\", root_dir=\".\", base_dir=\"results\")\n",
    "\n",
    "print(\"=\" * 60)\n",
    "print(\"DOWNLOAD READY\")\n",
    "print(\"=\" * 60)\n",
    "print(f\"Proteins: {len(sequences)}\")\n",
    "!ls -lh esm2_results.zip\n",
    "\n",
    "# files.download triggers a browser download; it does not wait for it to finish.\n",
    "files.download('esm2_results.zip')\n",
    "\n",
    "print(\"\\n✓ Download started!\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
```

---

## **FILE 4: test_data/proteins.fasta**

Location: `esm2/test_data/proteins.fasta`
```
>sp|P69905|HBA_HUMAN Hemoglobin subunit alpha OS=Homo sapiens OX=9606 GN=HBA1 PE=1 SV=2
MVLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPTTKTYFPHFDLSHGSAQVKGH
GKKVADALTNAVAHVDDMPNALSALSDLHAHKLRVDPVNFKLLSHCLLVTLAAHLPAEF
TPAVHASLDKFLASVSTVLTSKYR
>sp|P68871|HBB_HUMAN Hemoglobin subunit beta OS=Homo sapiens OX=9606 GN=HBB PE=1 SV=2
MVHLTPEEKSAVTALWGKVNVDEVGGEALGRLLVVYPWTQRFFESFGDLSTPDAVMGNP
KVKAHGKKVLGAFSDGLAHLDNLKGTFATLSELHCDKLHVDPENFRLLGNVLVCVLAHH
FGKEFTPPVQAAYQKVVAGVANALAHKYH
>sp|P01308|INS_HUMAN Insulin OS=Homo sapiens OX=9606 GN=INS PE=1 SV=1
MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAE
DLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN
>sp|P04406|G3P_HUMAN Glyceraldehyde-3-phosphate dehydrogenase OS=Homo sapiens OX=9606 GN=GAPDH PE=1 SV=3
MVKVGVNGFGRIGRLVTRAAFNSGKVDIVAINDPFIDLNYMVYMFQYDSTHGKFHGTVK
AENGKLVINGKAITIFQERDPANIKWGDAGAEYVVESTGVFTTMEKAGAHLQGGAKRVI
ISAPSADAPMFVMGVNHEKYDNSLKIVSNASCTTNCLAPLAKVIHDHFGIVEGLMTTVH
AITATQKTVDGPSGKLWRDGRGAAQNIIPASTGAAKAVGKVIPELNGKLTGMAFRVPTA
NVSVVDLTCRLEKPAKYDDIKKVVKQASEGPLKGILGYTEHQVVSSDFNSDTHSSTFDA
GAGIALNDHFVKLISWYDNEFGYSNRVVDLMAHMASKE
>sp|P61626|LYSC_HUMAN Lysozyme C OS=Homo sapiens OX=9606 GN=LYZ PE=1 SV=1
MRSLLILVLCFLPLAALGKVFGRCELAAAMKRHGLDNYRGYSLGNWVCAAKFESNFNTQ
ATNRNTDGSTDYGILQINSRWWCNDGRTPGSRNLCNIPCSALLSSDITASVNCAKKIVS
DGNGMNAWVAWRNRCKGTDVQAWIRGCRL