# 03 — ESM-1b Embedding Generation

In this notebook we compute high-dimensional protein embeddings using  
**ESM-1b (esm1b_t33_650M_UR50S)**, a 33-layer, 650M-parameter transformer protein language model trained on UniRef50 sequences.

We start from the cleaned dataset  
**data/processed/protein_families_small_clean.csv**.


### Goals of this notebook

1. **Load the cleaned dataset** and extract `(uniprot_id, sequence)` pairs, the (label, sequence) format expected by the ESM batch converter.
2. **Load ESM-1b model and alphabet** using the `fair-esm` library.
3. **Tokenize sequences and compute embeddings**:
   - Representation layer: **33**
   - Mean pooling over amino-acid positions (excluding BOS/EOS and padding tokens)
   - Sequences limited to 1000 aa, below the ESM-1b context of 1024 tokens (at most 1022 residues plus BOS/EOS)
4. **Batch inference on GPU** with configurable batch size.
5. **Save results as artifacts**:
   - `artifacts/embeddings/esm1b_embeddings_small_maxlen1000.npy`
   - `artifacts/embeddings/metadata_small_maxlen1000.csv`
6. Ensure embeddings are reproducible and ready for downstream ML.
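The pooling in step 3 reduces the layer-33 token representations of one protein to a single vector. Write $\mathbf{h}^{(33)}_t \in \mathbb{R}^{1280}$ for the representation at token position $t$ and $L$ for the number of residues. The formula assumes the layout produced by the ESM-1b batch converter: BOS at position $0$, residues at $1,\dots,L$, EOS at $L+1$, then padding. Under that layout, the sequence embedding is the plain average over residue positions:

$$
\mathbf{e} = \frac{1}{L} \sum_{t=1}^{L} \mathbf{h}^{(33)}_t, \qquad \mathbf{e} \in \mathbb{R}^{1280}
$$

Special tokens and padding never enter the sum, so two proteins of different lengths are still averaged over their own residues only.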

### Output of this notebook

After running this notebook, you will have:

- `X`: dense embedding matrix of shape `(N_proteins, 1280)`
- `metadata_small_maxlen1000.csv`: UniProt IDs, protein names, organisms, lengths and family labels, one row per embedding
- These artifacts are consumed in `04_train_and_eval.ipynb` for classifier training, evaluation, MLflow logging and baselines.

This notebook includes **only embedding computation**; no ML models are trained here.

In [2]:
```python
import sys
from pathlib import Path

import esm
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

# Make the project root importable so that `src.config` can be found
PROJECT_ROOT = Path("..").resolve()
if str(PROJECT_ROOT) not in sys.path:
    sys.path.append(str(PROJECT_ROOT))

from src.config import PROCESSED_DIR, EMB_DIR, MIN_SEQ_LEN, MAX_SEQ_LEN

print("PROJECT_ROOT:", PROJECT_ROOT)
print("PROCESSED_DIR:", PROCESSED_DIR)
print("EMBEDDINGS_DIR:", EMB_DIR)

# Load the cleaned dataset
csv_path = PROCESSED_DIR / "protein_families_small_clean.csv"
df = pd.read_csv(csv_path)

print(df.head())
print(df.shape)
print(df["family"].value_counts())
```

```text
PROJECT_ROOT: D:\ML\BioML\ESM
PROCESSED_DIR: D:\ML\BioML\ESM\data\processed
EMBEDDINGS_DIR: D:\ML\BioML\ESM\artifacts\embeddings
  uniprot_id                                       protein_name  \
0     O00444  Serine/threonine-protein kinase PLK4 (EC 2.7.1...   
1     O00506  Serine/threonine-protein kinase 25 (EC 2.7.11....   
2     O00746  Nucleoside diphosphate kinase, mitochondrial (...   
3     O14757  Serine/threonine-protein kinase Chk1 (EC 2.7.1...   
4     O15111  Inhibitor of nuclear factor kappa-B kinase sub...   

               organism  length  \
0  Homo sapiens (Human)     970   
1  Homo sapiens (Human)     426   
2  Homo sapiens (Human)     187   
3  Homo sapiens (Human)     476   
4  Homo sapiens (Human)     745   

                                            sequence  family  
0  MATCIGEKIEDFKVGNLLGKGSFAGVYRAESIHTGLEVAIKMIDKK...  kinase  
1  MAHLRGFANQHSRVDPEELFTKLDRIGKGSFGEVYKGIDNHTKEVV...  kinase  
2  MGGLFWRSALRGLRCGPRAPGPSLLVRHGSGGPSWTRERTLVAVKP...  kinase  
3  MA
```
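Before tokenizing, it can help to run two checks on the data:

- the `length` column agrees with the actual sequence strings;
- which characters, if any, fall outside the 20 standard amino acids, so they can be inspected before they reach the tokenizer.

This is a minimal pure-Python sketch. `audit_sequences` is a hypothetical helper, not part of this project's `src` package.

```python
from collections import Counter

# The 20 standard amino-acid one-letter codes
STANDARD_AA = set("ACDEFGHIKLMNPQRSTVWY")


def audit_sequences(records):
    """records: iterable of (uniprot_id, length, sequence) triples.

    Returns (ids whose `length` differs from len(sequence),
             Counter of characters outside STANDARD_AA).
    """
    length_mismatch = []
    nonstandard = Counter()
    for uid, length, seq in records:
        if len(seq) != length:
            length_mismatch.append(uid)
        nonstandard.update(ch for ch in seq if ch not in STANDARD_AA)
    return length_mismatch, nonstandard


toy = [("P1", 3, "ACD"), ("P2", 4, "ACX"), ("P3", 2, "UB")]
mismatched, unusual = audit_sequences(toy)
print(mismatched, dict(unusual))
```

On the real dataframe it could be called as `audit_sequences(df[["uniprot_id", "length", "sequence"]].itertuples(index=False))`. Each yielded row unpacks into the three fields.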

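In [4]:
```python
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

# Load the pretrained model and its alphabet (weights are downloaded on first use)
model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
model = model.to(device)
model.eval()  # disables dropout for deterministic inference

batch_converter = alphabet.get_batch_converter()
padding_idx = alphabet.padding_idx
# Size of the learned positional-embedding table. It includes an offset of
# padding_idx + 1 slots, so the usable context is smaller than this value.
max_positions = model.embed_positions.num_embeddings
max_positions
```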

```text
Using device: cuda

1026
```
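The printed 1026 is the size of the positional-embedding table, not the number of residues the model accepts. The sketch below works through the arithmetic under two stated assumptions:

- positions start at `padding_idx + 1`, the convention of fair-esm's learned positional embedding, with `padding_idx = 1` for ESM-1b;
- the alphabet prepends BOS and appends EOS.

The helper names are hypothetical.

```python
def usable_context(num_embeddings: int, padding_idx: int) -> int:
    """Largest token count a learned positional embedding can index when
    positions start at padding_idx + 1 (assumed ESM-1b convention)."""
    return num_embeddings - padding_idx - 1


def max_residues(context_tokens: int, prepend_bos: bool = True,
                 append_eos: bool = True) -> int:
    """Residues that fit once the BOS/EOS special tokens are accounted for."""
    return context_tokens - int(prepend_bos) - int(append_eos)


ctx = usable_context(1026, padding_idx=1)
print(ctx, max_residues(ctx))  # prints: 1024 1022
```

The result shows why a 1000 aa cap (at most 1002 tokens) leaves headroom. It also shows why a guard that compares token length against 1026 would let through inputs the model cannot embed.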

In [5]:
```python
def mean_pool_token_reprs(token_representations: torch.Tensor,
                          tokens: torch.Tensor,
                          padding_idx: int) -> torch.Tensor:
    """
    Average over amino-acid positions, excluding BOS/EOS and padding.

    Assumes each row is laid out as [BOS] residues... [EOS] [PAD]...,
    which is what the ESM-1b batch converter produces.

    token_representations: (B, L, D)
    tokens: (B, L)
    return: (B, D)
    """
    pooled = []
    for i in range(tokens.size(0)):
        row = token_representations[i]   # (L, D)
        row_tokens = tokens[i]           # (L,)

        # Positions of all non-padding tokens (BOS, residues, EOS)
        non_pad_idx = (row_tokens != padding_idx).nonzero(as_tuple=True)[0]

        if len(non_pad_idx) <= 2:
            # Empty sequence: only BOS/EOS present, so average those
            vec = row[non_pad_idx].mean(dim=0)
        else:
            # [BOS] seq... [EOS]
            start = non_pad_idx[1]        # skip BOS
            end = non_pad_idx[-2] + 1     # skip EOS (half-open slice [start, end))
            vec = row[start:end].mean(dim=0)

        pooled.append(vec)

    return torch.stack(pooled, dim=0)    # (B, D)
```
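The loop above pools one row at a time. The same residue-only mean can be computed for the whole batch with a boolean mask. This sketch is written in NumPy so it is self-contained; in PyTorch the equivalent would roughly be `(token_reprs * mask.unsqueeze(-1)).sum(dim=1) / mask.sum(dim=1, keepdim=True)`.

The token IDs in the example are assumptions: BOS = 0, PAD = 1, EOS = 2, mirroring the ESM-1b `<cls>`, `<pad>`, `<eos>` order. In a real run they would come from `alphabet.cls_idx`, `alphabet.padding_idx` and `alphabet.eos_idx`. One behavioural difference from the loop: a row with no residues gets a zero vector instead of the BOS/EOS average.

```python
import numpy as np


def masked_mean_pool(token_reprs: np.ndarray, tokens: np.ndarray,
                     padding_idx: int, bos_idx: int, eos_idx: int) -> np.ndarray:
    """Mean over residue positions only.

    token_reprs: (B, L, D) float array; tokens: (B, L) int array.
    Returns (B, D). Rows with no residues get a zero vector.
    """
    # True where the token is a residue (not padding, BOS or EOS)
    mask = (tokens != padding_idx) & (tokens != bos_idx) & (tokens != eos_idx)
    counts = mask.sum(axis=1, keepdims=True)               # (B, 1) residues per row
    summed = (token_reprs * mask[..., None]).sum(axis=1)   # (B, D) masked sum
    return summed / np.maximum(counts, 1)                  # avoid division by zero


# Toy batch: row 0 has 3 residues, row 1 has 1 residue followed by padding
rng = np.random.default_rng(0)
toy_tokens = np.array([[0, 5, 6, 7, 2],
                       [0, 8, 2, 1, 1]])
toy_reprs = rng.normal(size=(2, 5, 3))
toy_pooled = masked_mean_pool(toy_reprs, toy_tokens, padding_idx=1, bos_idx=0, eos_idx=2)
print(toy_pooled.shape)
```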

In [6]:
```python
# Keep sequences within the configured length range (MIN_SEQ_LEN..MAX_SEQ_LEN, inclusive)
df_short = df[df["sequence"].str.len().between(MIN_SEQ_LEN, MAX_SEQ_LEN)].reset_index(drop=True)

ids = df_short["uniprot_id"].tolist()
seqs = df_short["sequence"].tolist()

max_len_seq = max(len(s) for s in seqs)
print(f"Max sequence length in df_short: {max_len_seq}")

batch_size = 4

# Usable context in tokens: the positional table is offset by padding_idx + 1
max_tokens = max_positions - padding_idx - 1   # 1026 - 1 - 1 = 1024 for ESM-1b

all_embeddings = []

for i in tqdm(range(0, len(df_short), batch_size), desc="Embedding sequences"):
    batch_ids = ids[i:i + batch_size]
    batch_seqs = seqs[i:i + batch_size]
    batch_data = list(zip(batch_ids, batch_seqs))

    labels, strs, tokens = batch_converter(batch_data)  # tokens: (B, L)
    if tokens.size(1) > max_tokens:
        # Guard against exceeding the model context
        raise ValueError(
            f"Token sequence length {tokens.size(1)} exceeds model context of {max_tokens} tokens"
        )

    tokens = tokens.to(device)

    with torch.no_grad():
        out = model(tokens, repr_layers=[33], return_contacts=False)
        token_reprs = out["representations"][33]  # (B, L, D)

    pooled = mean_pool_token_reprs(token_reprs, tokens, padding_idx)  # (B, D)
    all_embeddings.append(pooled.cpu().numpy())

# Stack all batches into a single (N, D) matrix
X = np.vstack(all_embeddings)
X.shape
```

```text
Max sequence length in df_short: 999

Embedding sequences: 100%|██████████| 1066/1066 [13:00<00:00,  1.37it/s]

(4264, 1280)
```
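With `batch_size = 4` and sequences up to 999 aa, one long protein in a batch pads the others to its length. A common alternative, sketched here as an option rather than what this notebook does, works in three steps:

1. Sort sequences by length.
2. Batch neighbours together.
3. Write each vector back to its original row, so the matrix stays aligned with the metadata.

All names are hypothetical. `embed_batch` stands in for the tokenize, forward and pool steps above.

```python
def length_sorted_batches(seqs, batch_size):
    """Yield lists of original indices, grouping sequences of similar length."""
    order = sorted(range(len(seqs)), key=lambda i: len(seqs[i]))
    for start in range(0, len(order), batch_size):
        yield order[start:start + batch_size]


def padded_tokens(seqs, batches):
    """Residue slots (including padding) that a batching plan occupies."""
    return sum(len(b) * max(len(seqs[i]) for i in b) for b in batches)


def embed_in_length_order(seqs, batch_size, embed_batch):
    """embed_batch: callable mapping a list of sequences to a list of vectors."""
    results = [None] * len(seqs)
    for batch_idx in length_sorted_batches(seqs, batch_size):
        vectors = embed_batch([seqs[i] for i in batch_idx])
        for i, vec in zip(batch_idx, vectors):
            results[i] = vec  # write back to the original row position
    return results


toy_seqs = ["A" * 10, "A" * 500, "A" * 12, "A" * 480]
naive = padded_tokens(toy_seqs, [[0, 1], [2, 3]])
sorted_plan = padded_tokens(toy_seqs, list(length_sorted_batches(toy_seqs, 2)))
print(naive, sorted_plan)  # prints: 1960 1024
```

Because results are written back by index, the final order matches `df_short` regardless of the processing order, which is what keeps the saved metadata aligned.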

In [9]:
```python
# Check that the embedding matrix and the metadata have the same number of rows
assert X.shape[0] == len(df_short), "Mismatch between embeddings and dataframe length"

EMB_DIR.mkdir(parents=True, exist_ok=True)

emb_path = EMB_DIR / "esm1b_embeddings_small_maxlen1000.npy"
meta_path = EMB_DIR / "metadata_small_maxlen1000.csv"

np.save(emb_path, X)

# Save metadata in the same row order so it can be joined back to the embeddings later
meta_cols = ["uniprot_id", "protein_name", "organism", "length", "family"]
df_short[meta_cols].to_csv(meta_path, index=False)

emb_path, meta_path, X.shape
```

```text
(WindowsPath('D:/ML/BioML/ESM/artifacts/embeddings/esm1b_embeddings_small_maxlen1000.npy'),
 WindowsPath('D:/ML/BioML/ESM/artifacts/embeddings/metadata_small_maxlen1000.csv'),
 (4264, 1280))
```
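The `.npy` / `.csv` pair links embeddings and metadata by row order alone. One option, sketched here as an assumption rather than what this project does, is to store the UniProt IDs in the same file with `np.savez_compressed`. A later notebook can then check alignment by ID. The helper names and the `.npz` file are hypothetical.

```python
import tempfile
from pathlib import Path

import numpy as np


def save_embeddings_with_ids(path, X, ids):
    """Store the matrix and its row IDs together so alignment cannot drift."""
    ids = np.asarray(ids, dtype=str)  # fixed-width unicode array, no pickling needed
    if X.shape[0] != ids.shape[0]:
        raise ValueError(f"{X.shape[0]} embeddings vs {ids.shape[0]} ids")
    np.savez_compressed(path, X=X, ids=ids)


def load_embeddings_with_ids(path):
    """Return (X, list of IDs) from a file written by save_embeddings_with_ids."""
    with np.load(path) as data:
        return data["X"], data["ids"].tolist()


# Round-trip demo on toy data in a temporary directory
with tempfile.TemporaryDirectory() as tmp:
    demo_path = Path(tmp) / "demo_embeddings.npz"
    demo_X = np.arange(6, dtype=np.float32).reshape(3, 2)
    save_embeddings_with_ids(demo_path, demo_X, ["O00444", "O00506", "O00746"])
    loaded_X, loaded_ids = load_embeddings_with_ids(demo_path)

print(loaded_X.shape, loaded_ids)
```

In this notebook the call could look like `save_embeddings_with_ids(EMB_DIR / "esm1b_embeddings_small_maxlen1000.npz", X, df_short["uniprot_id"])`. That file name is hypothetical. Downstream code could then compare `loaded_ids` against the metadata's `uniprot_id` column before training.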

In [10]:
```python
# Quick sanity check that both artifacts can be read back
X_loaded = np.load(emb_path)
meta_loaded = pd.read_csv(meta_path)

X_loaded.shape, meta_loaded.shape, meta_loaded["family"].value_counts()
```

```text
((4264, 1280),
 (4264, 5),
 family
 kinase           500
 transporter      499
 ligase           495
 chaperone        490
 transcription    484
 hydrolase        445
 ion_channel      420
 receptor         418
 protease         356
 dna_binding      157
 Name: count, dtype: int64)
```

## Summary of this notebook

In this notebook we:

1. **Loaded the cleaned protein family dataset**
   - Source: `data/processed/protein_families_small_clean.csv`
   - Final size: 4264 proteins across 10 functional families  
     (`kinase`, `transporter`, `ligase`, `chaperone`, `transcription`,
     `hydrolase`, `ion_channel`, `receptor`, `protease`, `dna_binding`).

2. **Loaded the ESM-1b model**
   - Model: `esm1b_t33_650M_UR50S` (from `fair-esm`)
   - Device: GPU (if available) / CPU fallback
   - Representation layer: **33**
   - Context limit respected by filtering sequences with  
     `MIN_SEQ_LEN ≤ length ≤ MAX_SEQ_LEN` (here: `50–1000 aa`).

3. **Computed sequence-level embeddings**
   - Tokenized sequences with the ESM alphabet and batch converter.
   - Extracted per-token representations from layer 33.
   - Applied **mean pooling over amino acids**, excluding:
     - padding tokens,
     - BOS/EOS special tokens.
   - Obtained a dense embedding matrix:
     - shape: `(N_proteins, 1280)`, here `(4264, 1280)`.

4. **Saved reusable artifacts**
   - Embeddings:
     - `artifacts/embeddings/esm1b_embeddings_small_maxlen1000.npy`
   - Metadata (linked 1:1 to embeddings):
     - `artifacts/embeddings/metadata_small_maxlen1000.csv`
     - columns: `uniprot_id, protein_name, organism, length, family`

These artifacts are now the **fixed input** for downstream ML steps:
train/validation/test splits, model training, evaluation and interpretation.

---

### Next step

Proceed to:

- `04_train_and_eval.ipynb`  

where we will:

- build train/val/test splits,
- train baseline and stronger classifiers on top of ESM embeddings,
- log results to MLflow for reproducibility.
