# ESM Encoding for Unique Proteins in scope_onside_common_v3

This notebook encodes amino acid sequences for unique proteins identified by `target_uniprot_id` using the Hugging Face ESM2 model and real pretrained weights. It follows the logic shown in `esm-hugginface.py`, using masked mean pooling over token embeddings, excluding special and padding tokens.

Outline:
- Set up environment and GPU
- Load dataset (CSV/Parquet/JSON)
- Deduplicate by `target_uniprot_id`
- Configure and load the ESM model from Hugging Face
- Clean and validate sequences
- Batch tokenization and masked mean pooling
- Batched inference with sliding-window chunking for long sequences
- Persist embeddings with metadata (Parquet/NPZ)
- Caching and resume logic
- Quick sanity tests

## 1) Set Up Environment and GPU

- Installs are optional here; if missing, run the commands in the next cell.
- Detect CUDA and choose dtype: float16 on GPU, float32 otherwise.
- Print device info and CUDA memory if available.

In [None]:
import torch 

if torch.cuda.is_available():
    device = torch.device('cuda')
    dtype = torch.float16
else:
    device = torch.device('cpu')
    dtype = torch.float32

print('Device:', device)

In [12]:
import os
import math
import json
import time
from pathlib import Path
from typing import List, Dict, Tuple, Optional

import numpy as np
import pandas as pd
import torch
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModel

# Device and dtype selection
if torch.cuda.is_available():
    device = torch.device('cuda')
    dtype = torch.float16
else:
    device = torch.device('cpu')
    dtype = torch.float32

print('Device:', device)
if device.type == 'cuda':
    print('CUDA device count:', torch.cuda.device_count())
    print('Current device:', torch.cuda.current_device())
    print('Device name:', torch.cuda.get_device_name(0))
    free, total = torch.cuda.mem_get_info()
    print('CUDA memory (free/total GB):', round(free/1024**3, 2), '/', round(total/1024**3, 2))

Device: cuda
CUDA device count: 1
Current device: 0
Device name: NVIDIA GeForce RTX 3050 Ti Laptop GPU
CUDA memory (free/total GB): 0.0 / 4.0


## 2) Load scope_onside_common_v3 Dataset

- Set input path and optionally override column names
- Auto-detect file format and load
- Validate columns, drop missing, and log counts

In [13]:
# Parameters
INPUT_PATH = os.getenv('SCOPE_ONSIDE_PATH', r'd:\BracU\Thesis\esm-encode\scope_onside_common_v3.parquet')
ID_COL = os.getenv('PROTEIN_ID_COL', 'target_uniprot_id')
SEQ_COL = os.getenv('PROTEIN_SEQ_COL', None)  # try to auto-detect if None

print('Input path:', INPUT_PATH)
print('ID column:', ID_COL)
print('Sequence column (preset):', SEQ_COL)

# Auto-detect file extension and load
path = Path(INPUT_PATH)
if not path.exists():
    raise FileNotFoundError(f'Input file not found: {path}')

ext = path.suffix.lower()
if ext in ['.parquet', '.pq']:
    df = pd.read_parquet(path)
elif ext in ['.csv', '.tsv']:
    df = pd.read_csv(path) if ext == '.csv' else pd.read_csv(path, sep='\t')
elif ext in ['.json', '.jsonl']:
    df = pd.read_json(path, lines=ext == '.jsonl')
else:
    raise ValueError(f'Unsupported file extension: {ext}')

print('Loaded rows:', len(df))
print('Columns:', list(df.columns))

# Auto-detect sequence column if not provided
if SEQ_COL is None:
    candidate_cols = [
        'amino_acid_sequence', 'sequence', 'seq', 'protein_sequence', 'aa_sequence', 'AA_sequence'
    ]
    for c in candidate_cols:
        if c in df.columns:
            SEQ_COL = c
            break

if SEQ_COL is None:
    raise KeyError('Could not find a sequence column automatically. Set PROTEIN_SEQ_COL or edit SEQ_COL above.')

missing_cols = [c for c in [ID_COL, SEQ_COL] if c not in df.columns]
if missing_cols:
    raise KeyError(f'Missing required columns: {missing_cols}')

# Drop missing
before = len(df)
after_drop = df[[ID_COL, SEQ_COL]].dropna().shape[0]
df = df.dropna(subset=[ID_COL, SEQ_COL])
print(f'Dropped {before - len(df)} rows with null {ID_COL} or {SEQ_COL}. Remaining: {len(df)}')

Input path: d:\BracU\Thesis\esm-encode\scope_onside_common_v3.parquet
ID column: target_uniprot_id
Sequence column (preset): None
Loaded rows: 34741
Columns: ['drug_chembl_id', 'target_uniprot_id', 'label', 'smiles', 'sequence', 'molfile_3d', 'rxcui']
Dropped 0 rows with null target_uniprot_id or sequence. Remaining: 34741


## 3) Deduplicate Proteins by target_uniprot_id

- Group by ID
- Check for multiple sequences per ID and resolve conflicts deterministically
- Prepare unique DataFrame with ID, sequence, and length

In [14]:
# Resolve duplicates: keep the most frequent sequence per ID; if tie, keep first appearance

def resolve_duplicates(df: pd.DataFrame, id_col: str, seq_col: str) -> pd.DataFrame:
    # Count sequences per ID
    grp = df.groupby([id_col, seq_col], as_index=False).size().rename(columns={'size': 'count'})
    # For each ID, select sequence with max count; tie broken by first occurrence order in original df
    grp['rank'] = grp.groupby(id_col)['count'].rank(method='first', ascending=False)
    best = grp[grp['rank'] == 1].drop(columns=['rank']).copy()
    # Merge back to ensure stable order based on first occurrence in df
    first_order = df.drop_duplicates(subset=[id_col]).set_index(id_col).reset_index()[[id_col]]
    uniq = first_order.merge(best, on=id_col, how='left')
    uniq = uniq.rename(columns={seq_col: 'sequence'})
    uniq['seq_len'] = uniq['sequence'].str.len()
    return uniq

uniq_df = resolve_duplicates(df, ID_COL, SEQ_COL)
print('Unique proteins:', len(uniq_df))
print(uniq_df.head())

Unique proteins: 2385
  target_uniprot_id                                           sequence  count  \
0            O15245  MPTVDDILEQVGESGWFQKQAFLILCLLSAAFAPICVGIVFLGFTP...    150   
1            P08183  MDLEGDRNGGAKKKNFFKLNNKSEKDKKEKKPTVSVFSMFRYSNWL...    328   
2            P35367  MSLPNSSCLLEDKMCEGNKTTMASPQLMPLVVVLSTICLVTVGLNL...     67   
3            Q02763  MDSLASLVLCGVSLLLSGTVEGAMDLILINSLPLVSDAETSLTCIA...     31   
4            Q12809  MPVRRGHVAPQNTFLDTIIRKFEGQSRKFIIANARVENCAVIYCND...    161   

   seq_len  
0      554  
1     1280  
2      487  
3     1124  
4     1159  


## 4) Configure and Load ESM Model from Hugging Face

- Parameterize model name (default: facebook/esm2_t33_650M_UR50D)
- Load tokenizer and model with pretrained weights
- Move to device, set eval, and record max positions and hidden size

In [15]:
MODEL_NAME = os.getenv('ESM_MODEL_NAME', 'facebook/esm2_t33_650M_UR50D')
BATCH_SIZE = int(os.getenv('ESM_BATCH_SIZE', '4'))  # keep small for memory; you can tune
CHUNK_OVERLAP = int(os.getenv('ESM_CHUNK_OVERLAP', '50'))  # tokens overlap for long seq chunking

print('Model:', MODEL_NAME)
print('Batch size:', BATCH_SIZE)

# Load tokenizer and model from pretrained weights
# Note: we use AutoModel to keep it generic and load true repo weights without hardcoding local files

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, do_lower_case=False)
model = AutoModel.from_pretrained(MODEL_NAME)
model.eval()
model.to(device)

hidden_size = getattr(model.config, 'hidden_size', None) or getattr(model.config, 'd_model', None)
max_positions = getattr(model.config, 'max_position_embeddings', None)

print('Hidden size:', hidden_size)
print('Max positions:', max_positions)

Model: facebook/esm2_t33_650M_UR50D
Batch size: 4


Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t33_650M_UR50D and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Hidden size: 1280
Max positions: 1026


## 5) Sequence Cleaning and Validation

- Normalize: strip, uppercase, map non-standard to 'X'
- Validate against tokenizer vocab where possible
- Record a cleaned flag and track sequences exceeding model max length

In [16]:
# Allowed amino acids per UniProt standard + X, B, Z, U, O
STANDARD_AA = set(list('ACDEFGHIKLMNPQRSTVWY'))
ALLOWED_AA = set(list('ACDEFGHIKLMNPQRSTVWYBXZOU'))


def clean_sequence(seq: str) -> Tuple[str, bool, int]:
    s = (seq or '').strip().upper()
    cleaned = False
    out_chars = []
    for ch in s:
        if ch in ALLOWED_AA:
            if ch not in STANDARD_AA:
                cleaned = True
                out_chars.append('X') if ch in {'B', 'Z'} else out_chars.append(ch if ch in STANDARD_AA else 'X')
            else:
                out_chars.append(ch)
        else:
            cleaned = True
            out_chars.append('X')
    return ''.join(out_chars), cleaned, len(s)

uniq_df['cleaned_seq'], uniq_df['was_cleaned'], uniq_df['orig_len'] = zip(*uniq_df['sequence'].map(clean_sequence))

# Length validation against model max positions (tokenizer adds special tokens)
if max_positions is None:
    # Fallback: most ESM2 variants support ~1022 tokens including specials
    max_positions = 1022

# We will mark sequences that likely need chunking if their raw length exceeds (max_positions - specials)
# HF tokenizer typically adds 2 specials (CLS/EOS)
specials = 2
len_limit = max_positions - specials
uniq_df['needs_chunk'] = uniq_df['cleaned_seq'].str.len() > len_limit

print('Sequences needing chunking:', int(uniq_df['needs_chunk'].sum()))
uniq_df.head()

Sequences needing chunking: 293


Unnamed: 0,target_uniprot_id,sequence,count,seq_len,cleaned_seq,was_cleaned,orig_len,needs_chunk
0,O15245,MPTVDDILEQVGESGWFQKQAFLILCLLSAAFAPICVGIVFLGFTP...,150,554,MPTVDDILEQVGESGWFQKQAFLILCLLSAAFAPICVGIVFLGFTP...,False,554,False
1,P08183,MDLEGDRNGGAKKKNFFKLNNKSEKDKKEKKPTVSVFSMFRYSNWL...,328,1280,MDLEGDRNGGAKKKNFFKLNNKSEKDKKEKKPTVSVFSMFRYSNWL...,False,1280,True
2,P35367,MSLPNSSCLLEDKMCEGNKTTMASPQLMPLVVVLSTICLVTVGLNL...,67,487,MSLPNSSCLLEDKMCEGNKTTMASPQLMPLVVVLSTICLVTVGLNL...,False,487,False
3,Q02763,MDSLASLVLCGVSLLLSGTVEGAMDLILINSLPLVSDAETSLTCIA...,31,1124,MDSLASLVLCGVSLLLSGTVEGAMDLILINSLPLVSDAETSLTCIA...,False,1124,True
4,Q12809,MPVRRGHVAPQNTFLDTIIRKFEGQSRKFIIANARVENCAVIYCND...,161,1159,MPVRRGHVAPQNTFLDTIIRKFEGQSRKFIIANARVENCAVIYCND...,False,1159,True


## 6) Batch Tokenization and Masked Mean Pooling

Implement helper functions to:
- tokenize a list of sequences with padding/truncation
- forward pass and masked mean pool over last_hidden_state, excluding padding and special tokens using attention_mask
- return pooled vectors and flags indicating any truncation

In [17]:
def masked_mean_pool(last_hidden_state: torch.Tensor, input_ids: torch.Tensor, attention_mask: torch.Tensor, tokenizer) -> Tuple[torch.Tensor, torch.Tensor]:
    # last_hidden_state: [B, T, H]
    # input_ids: [B, T]
    # attention_mask: [B, T]
    # Exclude CLS/EOS tokens using input_ids and tokenizer special IDs
    cls_id = tokenizer.cls_token_id
    eos_id = tokenizer.eos_token_id
    special_mask = (input_ids != cls_id) & (input_ids != eos_id)
    valid_mask = attention_mask.bool() & special_mask
    # Expand to match hidden dim
    valid_mask_f = valid_mask.unsqueeze(-1)  # [B, T, 1]
    masked = last_hidden_state * valid_mask_f
    # Sum and divide by count
    sums = masked.sum(dim=1)  # [B, H]
    counts = valid_mask_f.sum(dim=1).clamp(min=1)  # [B, 1]
    pooled = sums / counts
    # A truncation flag can be inferred from tokenizer outputs if available
    return pooled, valid_mask.sum(dim=1)


def encode_batch(seq_list: List[str]) -> Tuple[torch.Tensor, np.ndarray]:
    # Tokenize with padding/truncation; rely on tokenizer/model limits
    enc = tokenizer(
        seq_list,
        return_tensors='pt',
        padding=True,
        truncation=True,
        add_special_tokens=True,
    ).to(device)

    # Choose autocast context based on device
    from contextlib import nullcontext
    autocast_ctx = torch.cuda.amp.autocast(dtype=dtype) if device.type == 'cuda' else nullcontext()

    with torch.no_grad():
        with autocast_ctx:
            outputs = model(**enc)
            hs = outputs.last_hidden_state  # [B, T, H]
    pooled, valid_counts = masked_mean_pool(hs, enc['input_ids'], enc['attention_mask'], tokenizer)
    pooled = pooled.detach().float().cpu()

    # Detect truncation: if any sequence length equals the model max (via tokenizer) and original > len_limit
    was_truncated = []
    for i, s in enumerate(seq_list):
        # Token length includes specials; check if attention_mask sum equals tokenizer.model_max_length or close to it
        attn_len = int(enc['attention_mask'][i].sum().item())
        trunc = attn_len >= (max_positions or attn_len) and len(s) > len_limit
        was_truncated.append(trunc)
    return pooled, np.array(was_truncated, dtype=bool)

## 7) Batched Inference with Chunking for Long Sequences

- Iterate in batches with a progress bar
- For sequences exceeding model max, use sliding-window chunking and average pooled embeddings
- Collect results as we go

In [18]:
def sliding_windows(seq: str, max_len: int, overlap: int) -> List[str]:
    if len(seq) <= max_len:
        return [seq]
    windows = []
    start = 0
    while start < len(seq):
        end = min(start + max_len, len(seq))
        windows.append(seq[start:end])
        if end == len(seq):
            break
        start = end - overlap
        if start < 0:
            start = 0
    return windows


def encode_sequences(ids: List[str], seqs: List[str], batch_size: int = BATCH_SIZE) -> Dict[str, Dict]:
    results = {}
    for i in tqdm(range(0, len(seqs), batch_size), desc='Encoding batches'):
        batch_ids = ids[i:i+batch_size]
        batch_seqs = seqs[i:i+batch_size]

        to_encode = []
        chunk_map = []  # (orig_index, num_chunks)
        for j, s in enumerate(batch_seqs):
            if len(s) > len_limit:
                chunks = sliding_windows(s, len_limit, CHUNK_OVERLAP)
                to_encode.extend(chunks)
                chunk_map.append((j, len(chunks)))
            else:
                to_encode.append(s)
                chunk_map.append((j, 1))

        # Encode all pieces in this batch (may be more than batch_size due to chunking)
        pooled, was_trunc = encode_batch(to_encode)

        # Aggregate back to original sequences
        idx = 0
        agg_vectors = []
        agg_trunc = []
        for (j, n_chunks) in chunk_map:
            vecs = pooled[idx:idx+n_chunks]
            flags = was_trunc[idx:idx+n_chunks]
            idx += n_chunks
            # Average chunk embeddings
            agg_vec = vecs.mean(dim=0)
            agg_vectors.append(agg_vec)
            agg_trunc.append(bool(flags.any()) or (n_chunks > 1))

        for bid, vec, trunc_flag, seq in zip(batch_ids, agg_vectors, agg_trunc, batch_seqs):
            results[bid] = {
                'embedding': vec.numpy(),
                'seq_len': len(seq),
                'was_truncated': trunc_flag,
            }
    return results

In [24]:
OUT_DIR = Path(os.getenv('ESM_OUT_DIR', r'd:\BracU\Thesis\esm-encode\esm_outputs'))
OUT_DIR.mkdir(parents=True, exist_ok=True)
PARQUET_PATH = OUT_DIR / 'esm2_embeddings.parquet'
NPZ_PATH = OUT_DIR / 'esm2_embeddings.npz'
NPZ_META_CSV = OUT_DIR / 'esm2_embeddings_meta.csv'
CHECKPOINT_PATH = OUT_DIR / 'checkpoint_ids.json'

print('Output dir:', OUT_DIR)

Output dir: d:\BracU\Thesis\esm-encode\esm_outputs


## 9) Caching and Resume Logic

- Load IDs from previous Parquet or checkpoint (if any)
- Skip already-encoded IDs
- Periodically checkpoint during long runs

In [25]:
# Simple checkpointing utilities

def load_completed_ids(parquet_path: Path, checkpoint_path: Path) -> set:
    completed = set()
    if parquet_path.exists():
        try:
            tmp = pd.read_parquet(parquet_path, columns=[ID_COL])
            completed.update(tmp[ID_COL].astype(str).tolist())
            print('Loaded completed IDs from Parquet:', len(completed))
        except Exception as e:
            print('Parquet exists but could not be read for checkpoints:', e)
    if checkpoint_path.exists():
        try:
            with open(checkpoint_path, 'r', encoding='utf-8') as f:
                completed.update(json.load(f))
            print('Loaded completed IDs from checkpoint:', len(completed))
        except Exception as e:
            print('Checkpoint exists but could not be read:', e)
    return completed


def save_checkpoint_ids(ids_done: set, checkpoint_path: Path):
    with open(checkpoint_path, 'w', encoding='utf-8') as f:
        json.dump(sorted(list(ids_done)), f)
    print('Checkpoint saved with', len(ids_done), 'IDs ->', checkpoint_path)

In [26]:
# Example of how to use checkpoints (for long runs, integrate into encode_sequences loop if desired)
completed_ids = load_completed_ids(PARQUET_PATH, CHECKPOINT_PATH)
if completed_ids:
    mask = ~uniq_df[ID_COL].astype(str).isin(completed_ids)
    to_process = uniq_df[mask].copy()
    print('Skipping', (~mask).sum(), 'already encoded IDs. To process:', len(to_process))
else:
    to_process = uniq_df.copy()

# If you want to use incremental saving, you can iterate batches manually
# For simplicity above we encoded all at once; for very large sets, adapt below pattern:
# ids = to_process[ID_COL].tolist()
# seqs = to_process['cleaned_seq'].tolist()
# batch_done = set(completed_ids)
# for i in range(0, len(seqs), BATCH_SIZE):
#     ... encode and append to Parquet (use pyarrow to append) ...
#     batch_done.update(batch_ids)
#     save_checkpoint_ids(batch_done, CHECKPOINT_PATH)

## 8) Persist Embeddings with Metadata

- Save to Parquet and optionally NPZ
- Parameterize output paths
- Log shapes and sizes

In [27]:
def save_parquet(ids: List[str], seq_lens: List[int], was_cleaned: List[bool], was_trunc: List[bool], embeds: np.ndarray, path: Path):
    table = pd.DataFrame({
        ID_COL: ids,
        'seq_len': seq_lens,
        'was_cleaned': was_cleaned,
        'was_truncated': was_trunc,
        'embedding': list(embeds.astype('float32')),
    })
    table.to_parquet(path, index=False)
    print('Saved Parquet:', path, 'rows:', len(table))


def save_npz_from_parquet(parquet_path: Path, npz_path: Path, meta_csv: Path):
    df_all = pd.read_parquet(parquet_path)
    # Convert list column to array
    embeds = np.stack(df_all['embedding'].to_list(), axis=0).astype('float32')
    np.savez_compressed(npz_path, embeddings=embeds)
    meta = df_all[[ID_COL, 'seq_len', 'was_cleaned', 'was_truncated']].copy()
    meta.to_csv(meta_csv, index=False)
    print('Saved NPZ:', npz_path, 'and meta CSV:', meta_csv)

# Run encoding for all sequences that still need processing
source_df = to_process if 'to_process' in globals() else uniq_df
ids = source_df[ID_COL].astype(str).tolist()
seqs = source_df['cleaned_seq'].tolist()
was_cleaned = source_df['was_cleaned'].astype(bool).tolist()
seq_lens = source_df['cleaned_seq'].str.len().tolist()

if len(ids) == 0:
    print('Nothing to process. All IDs appear to be encoded already.')
else:
    results = encode_sequences(ids, seqs, batch_size=BATCH_SIZE)

    # Assemble arrays in original order
    embeddings = np.stack([results[i]['embedding'] for i in ids], axis=0)
    was_truncated = [results[i]['was_truncated'] for i in ids]

    print('Embeddings shape (new batch):', embeddings.shape)

    # Prepare new table
    df_new = pd.DataFrame({
        ID_COL: ids,
        'seq_len': seq_lens,
        'was_cleaned': was_cleaned,
        'was_truncated': was_truncated,
        'embedding': list(embeddings.astype('float32')),
    })

    # Merge with existing Parquet if present
    if PARQUET_PATH.exists():
        try:
            df_old = pd.read_parquet(PARQUET_PATH)
            df_all = pd.concat([df_old, df_new], ignore_index=True)
            df_all = df_all.drop_duplicates(subset=[ID_COL], keep='last')
            df_all.to_parquet(PARQUET_PATH, index=False)
            print('Updated Parquet with merged results:', PARQUET_PATH, 'rows:', len(df_all))
        except Exception as e:
            print('Failed to merge with existing Parquet, writing new only:', e)
            df_new.to_parquet(PARQUET_PATH, index=False)
    else:
        df_new.to_parquet(PARQUET_PATH, index=False)
        print('Wrote new Parquet:', PARQUET_PATH, 'rows:', len(df_new))

    # Refresh NPZ + CSV metadata from the Parquet
    save_npz_from_parquet(PARQUET_PATH, NPZ_PATH, NPZ_META_CSV)

    # Update checkpoint with IDs present in Parquet
    try:
        all_ids = pd.read_parquet(PARQUET_PATH, columns=[ID_COL])[ID_COL].astype(str)
        save_checkpoint_ids(set(all_ids), CHECKPOINT_PATH)
    except Exception as e:
        print('Could not update checkpoint from Parquet:', e)

Encoding batches:   0%|          | 0/597 [00:00<?, ?it/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
  autocast_ctx = torch.cuda.amp.autocast(dtype=dtype) if device.type == 'cuda' else nullcontext()


Embeddings shape (new batch): (2385, 1280)
Wrote new Parquet: d:\BracU\Thesis\esm-encode\esm_outputs\esm2_embeddings.parquet rows: 2385
Saved NPZ: d:\BracU\Thesis\esm-encode\esm_outputs\esm2_embeddings.npz and meta CSV: d:\BracU\Thesis\esm-encode\esm_outputs\esm2_embeddings_meta.csv
Checkpoint saved with 2385 IDs -> d:\BracU\Thesis\esm-encode\esm_outputs\checkpoint_ids.json


## 10) Quick Sanity Tests

Run a few assertions on a small subset to verify embedding shape, non-zero norms, and similarity between identical sequences.

In [28]:
# Take 2-3 sequences (or repeat one) and test shapes
subset = uniq_df.head(2).copy()
if len(subset) == 1:
    subset = pd.concat([subset, subset], ignore_index=True)

ids_s = subset[ID_COL].tolist()
seqs_s = subset['cleaned_seq'].tolist()

embs = encode_sequences(ids_s, seqs_s, batch_size=2)
arr = np.stack([embs[i]['embedding'] for i in ids_s])
print('Subset embeddings shape:', arr.shape)

# Check hidden size
assert arr.shape[1] == hidden_size, f"Expected hidden size {hidden_size}, got {arr.shape[1]}"

# Non-zero norms
norms = np.linalg.norm(arr, axis=1)
assert np.all(norms > 0), 'Found zero-norm embedding'

# Similarity test for identical sequences (if any)
if seqs_s[0] == seqs_s[1]:
    from numpy.linalg import norm
    cos = float(np.dot(arr[0], arr[1]) / (norm(arr[0]) * norm(arr[1]) + 1e-9))
    print('Cosine similarity between identical sequences:', cos)
    assert cos > 0.99, 'Identical sequences should produce near-identical embeddings'

Encoding batches:   0%|          | 0/1 [00:00<?, ?it/s]

  autocast_ctx = torch.cuda.amp.autocast(dtype=dtype) if device.type == 'cuda' else nullcontext()


Subset embeddings shape: (2, 1280)
