# Embedding Examples


This notebook contains examples of how to use different foundation models as embedders. It is intended as a reference whose snippets can be plugged into existing pipelines.

## Common Code
This code can be ignored; it only loads part of the dataset so that the examples below have sequences to embed.


In [2]:
import pickle
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import copy
import os
import csv
from pathlib import Path
from tqdm import tqdm
from collections import Counter
from datasets import load_dataset
from itertools import combinations
from datetime import datetime

# Setup
# Use the parent directory as the project root when the notebook is launched
# from a directory named 'notebooks'; otherwise use the working directory.
project_root = Path.cwd()
if project_root.name == 'notebooks':
    project_root = project_root.parent

# Prefer CUDA when torch reports it as available, else fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"The device being used is: {device}")

# Constants
# Symbol -> integer id, followed by the inverse id -> symbol lookup.
NUCLEOTIDE_MAP = {symbol: index for index, symbol in enumerate("PATCG")}
NUCLEOTIDES = dict(zip(NUCLEOTIDE_MAP.values(), NUCLEOTIDE_MAP.keys()))
GAP_CHAR = '-'  # for internal use during column construction

# config file info
training_size = 1000  # passed as max_samples when converting the dataset
max_nt_num = 150
MAX_MSA_LEN = 50      # default max_len of filter_by_seq_length
MAX_N_SEQS = 3


  from .autonotebook import tqdm as notebook_tqdm


The device being used is: cpu


In [3]:
## -------- DIRECT COPY OF EXISTING CODE, FOR REFERENCE ONLY. DO NOT EDIT. -------- ##

def _count_inserted_gaps_from_sequences(start, solution):
    dash_start = sum(str(s).count('-') for s in start)
    dash_solution = sum(str(s).count('-') for s in solution)
    return max(0, dash_solution - dash_start)

def convert_column_major_solution(msa_string, n_seq):
    """
    Converts a column-major MSA string (down columns first) into
    a row-major list of aligned sequences.

    Character ``j * n_seq + i`` of ``msa_string`` is column ``j`` of row ``i``.

    Args:
        msa_string (str): e.g. "AAAC-CGG-TTT---"
        n_seq (int): number of sequences (rows)

    Returns:
        list[str]: e.g. ["ACGT-", "A-GT-", "AC-T-"]; an empty list when
        ``msa_string`` is empty or ``n_seq`` is not positive.

    Raises:
        ValueError: if ``len(msa_string)`` is not a multiple of ``n_seq``,
            i.e. the last column is incomplete.
    """
    if not msa_string or n_seq <= 0:
        return []

    # An incomplete final column would leave the rows with unequal lengths;
    # fail with a clear message instead of an IndexError deep in the transpose.
    if len(msa_string) % n_seq != 0:
        raise ValueError(
            f"MSA string length {len(msa_string)} is not a multiple of n_seq={n_seq}"
        )

    # Row i consists of every n_seq-th character starting at offset i.
    return [msa_string[row::n_seq] for row in range(n_seq)]

def convert_huggingface_to_samples(dataset, max_samples=None):
    """
    Convert dataset examples into the sample dicts used by this notebook.

    Each example is read through ``ex.get`` for the keys 'unaligned_seqs' and
    'MSA'; examples where either is missing or empty are skipped.

    Args:
        dataset: iterable of mapping-like examples.
        max_samples (int | None): stop once the example index reaches this
            value. The cap applies to the position in ``dataset``, so skipped
            examples count towards it. ``None`` means no limit; ``0`` yields
            an empty list.

    Returns:
        list[dict]: one dict per kept example with keys 'start' (unaligned
        sequences), 'solution' (row-major aligned sequences), 'n_gaps',
        'moves', 'n_sequences' and 'idx' (index of the example in ``dataset``).
    """
    samples = []
    for i, ex in enumerate(dataset):
        # Compare against None explicitly so max_samples=0 is an actual limit
        # of zero rather than being treated as "no limit".
        if max_samples is not None and i >= max_samples:
            break

        unaligned_seqs = ex.get('unaligned_seqs', {})
        MSA = ex.get('MSA', "")

        if not unaligned_seqs or not MSA:
            continue

        # filter_by_seq_length accepts both dict and plain-sequence forms of
        # 'unaligned_seqs'; accept the same here. Dict values are ordered by
        # sorted key, other containers keep their own order.
        if isinstance(unaligned_seqs, dict):
            start = [unaligned_seqs[k] for k in sorted(unaligned_seqs)]
        else:
            start = list(unaligned_seqs)
        n_seq = len(start)
        solution = convert_column_major_solution(MSA, n_seq)
        n_gaps = _count_inserted_gaps_from_sequences(start, solution)

        samples.append({
            'start': start,
            'solution': solution,
            'n_gaps': n_gaps,
            'moves': [-1] * n_gaps,  # keep list length equal to n_gaps as this is never actually used in the DQN
            'n_sequences': n_seq,
            'idx': i
        })
    return samples


def filter_by_seq_length(example, max_len=MAX_MSA_LEN):
    """
    Predicate for dataset filtering based on unaligned sequence length.

    Returns False when `example` has no "unaligned_seqs" entry or when any
    of its sequences is longer than `max_len`; returns True otherwise.
    The entry may be a dict (its values are checked) or any other iterable
    of sequences.
    """
    if "unaligned_seqs" not in example:
        return False

    candidates = example["unaligned_seqs"]
    if isinstance(candidates, dict):
        candidates = candidates.values()

    for seq in candidates:
        if len(seq) > max_len:
            return False
    return True

# --- Load and filter datasets ---
# Fetch the train split, keep only the examples accepted by
# filter_by_seq_length, then convert them with training_size as max_samples.
ds = load_dataset("dotan1111/MSA-nuc-3-seq", split="train").filter(filter_by_seq_length)
train_samples = convert_huggingface_to_samples(ds, max_samples=training_size)
print(train_samples[0])


2025-11-23 14:28:27 | INFO | httpx | HTTP Request: HEAD https://huggingface.co/datasets/dotan1111/MSA-nuc-3-seq/resolve/main/README.md "HTTP/1.1 307 Temporary Redirect"
2025-11-23 14:28:27 | INFO | httpx | HTTP Request: HEAD https://huggingface.co/api/resolve-cache/datasets/dotan1111/MSA-nuc-3-seq/0acd8be3fd2b222e4d3f88f4fd370310c2878120/README.md "HTTP/1.1 200 OK"
2025-11-23 14:28:27 | INFO | httpx | HTTP Request: HEAD https://huggingface.co/datasets/dotan1111/MSA-nuc-3-seq/resolve/0acd8be3fd2b222e4d3f88f4fd370310c2878120/MSA-nuc-3-seq.py "HTTP/1.1 404 Not Found"
2025-11-23 14:28:27 | INFO | httpx | HTTP Request: HEAD https://s3.amazonaws.com/datasets.huggingface.co/datasets/datasets/dotan1111/MSA-nuc-3-seq/dotan1111/MSA-nuc-3-seq.py "HTTP/1.1 404 Not Found"
2025-11-23 14:28:27 | INFO | httpx | HTTP Request: GET https://huggingface.co/api/datasets/dotan1111/MSA-nuc-3-seq/revision/0acd8be3fd2b222e4d3f88f4fd370310c2878120 "HTTP/1.1 200 OK"
2025-11-23 14:28:27 | INFO | httpx | HTTP Reque

{'start': ['TACTACAGTTCTTAAAAATAATCTATTAAAATTTTTTTGCT', 'TAGTACGATTCGTGAAAATAATCTGTTAAAATTCTTTTTCT', 'TTATACAATTTTTTGAGGATTAATCTGTTGAAATTATTGTTCT'], 'solution': ['TACTACAGTTCTT--AAAAATAATCTATTAAAATTTTTTTGCT', 'TAGTACGATTCGT--GAAAATAATCTGTTAAAATTCTTTTTCT', 'TTATACAATTTTTTGAGGATTAATCTGTTGAAATTATTGTTCT'], 'n_gaps': 4, 'moves': [-1, -1, -1, -1], 'n_sequences': 3, 'idx': 0}


In [4]:
# Display only a small slice: echoing the whole list renders every converted
# sample and floods the notebook output.
train_samples[:3]

[{'start': ['TACTACAGTTCTTAAAAATAATCTATTAAAATTTTTTTGCT',
   'TAGTACGATTCGTGAAAATAATCTGTTAAAATTCTTTTTCT',
   'TTATACAATTTTTTGAGGATTAATCTGTTGAAATTATTGTTCT'],
  'solution': ['TACTACAGTTCTT--AAAAATAATCTATTAAAATTTTTTTGCT',
   'TAGTACGATTCGT--GAAAATAATCTGTTAAAATTCTTTTTCT',
   'TTATACAATTTTTTGAGGATTAATCTGTTGAAATTATTGTTCT'],
  'n_gaps': 4,
  'moves': [-1, -1, -1, -1],
  'n_sequences': 3,
  'idx': 0},
 {'start': ['AGAAAAAAACGTAGTATTTTTTATATGATTACCCTGAT',
   'AGAATAGGATGCAATATTTACTCCATGATTACCTTCCAACT',
   'AGTATAGGATGGAATATTTACGCCATAATCACCTTCAGAT'],
  'solution': ['AGAAAAAAACGTAGTATTTTTTATATGATTACCCT---GAT',
   'AGAATAGGATGCAATATTTACTCCATGATTACCTTCCAACT',
   'AGTATAGGATGGAATATTTACGCCATAATCACCTT-CAGAT'],
  'n_gaps': 4,
  'moves': [-1, -1, -1, -1],
  'n_sequences': 3,
  'idx': 1},
 {'start': ['TTACGTTTACAATTTCATACAGCTTAATAGCGTGGGGGA',
   'TTACGTTTGTCAATTCATACAATTTAGTAGCATTTTAGGA',
   'TTACGTTTATCGATTCACACTGTTTAGTAGCATTTAAAA'],
  'solution': ['TTACGTTTACAATTTCATACAGCTTAATAGC-GTGGGGGA',
   'TTACGTTT

# ERNIE-RNA

In [14]:
# Print the first five converted samples, one dict per line.
preview = train_samples[:5]
for sample in preview:
    print(sample)

{'start': ['TACTACAGTTCTTAAAAATAATCTATTAAAATTTTTTTGCT', 'TAGTACGATTCGTGAAAATAATCTGTTAAAATTCTTTTTCT', 'TTATACAATTTTTTGAGGATTAATCTGTTGAAATTATTGTTCT'], 'solution': ['TACTACAGTTCTT--AAAAATAATCTATTAAAATTTTTTTGCT', 'TAGTACGATTCGT--GAAAATAATCTGTTAAAATTCTTTTTCT', 'TTATACAATTTTTTGAGGATTAATCTGTTGAAATTATTGTTCT'], 'n_gaps': 4, 'moves': [-1, -1, -1, -1], 'n_sequences': 3, 'idx': 0}
{'start': ['AGAAAAAAACGTAGTATTTTTTATATGATTACCCTGAT', 'AGAATAGGATGCAATATTTACTCCATGATTACCTTCCAACT', 'AGTATAGGATGGAATATTTACGCCATAATCACCTTCAGAT'], 'solution': ['AGAAAAAAACGTAGTATTTTTTATATGATTACCCT---GAT', 'AGAATAGGATGCAATATTTACTCCATGATTACCTTCCAACT', 'AGTATAGGATGGAATATTTACGCCATAATCACCTT-CAGAT'], 'n_gaps': 4, 'moves': [-1, -1, -1, -1], 'n_sequences': 3, 'idx': 1}
{'start': ['TTACGTTTACAATTTCATACAGCTTAATAGCGTGGGGGA', 'TTACGTTTGTCAATTCATACAATTTAGTAGCATTTTAGGA', 'TTACGTTTATCGATTCACACTGTTTAGTAGCATTTAAAA'], 'solution': ['TTACGTTTACAATTTCATACAGCTTAATAGC-GTGGGGGA', 'TTACGTTTGTCAATTCATACAATTTAGTAGCATTTTAGGA', 'TTACGTTTATCGATTCACACTGTT

In [None]:
import argparse
import sys

# `Path`, `torch`, `device` and `train_samples` come from the cells above.

# Make the local ERNIE_RNA checkout importable. Guard the insert so that
# re-running this cell does not keep prepending duplicate sys.path entries.
# NOTE(review): this uses <cwd>/ERNIE_RNA while the dictionary path below uses
# ../ERNIE_RNA — confirm which location is the intended one.
ernie_rna_path = Path.cwd() / "ERNIE_RNA"
if str(ernie_rna_path) not in sys.path:
    sys.path.insert(0, str(ernie_rna_path))

from ERNIE_RNA.extract_embedding import extract_embedding_of_ernierna, load_pretrained_ernierna, ErnieRNAOnestage

# Allow-list argparse.Namespace for torch's safe checkpoint loading.
torch.serialization.add_safe_globals([argparse.Namespace])

# Single source of truth for the paths that were previously repeated per call.
ERNIE_CHECKPOINT_PATH = './pretrained/ERNIE-RNA_pretrain.pt'
ERNIE_ARG_OVERRIDES = { "data": '../ERNIE_RNA/src/dict/' }

# One list of unaligned sequences per sample, for the first five samples.
starts = [sample['start'] for sample in train_samples[:5]]

model_pretrained = load_pretrained_ernierna(ERNIE_CHECKPOINT_PATH, dict(ERNIE_ARG_OVERRIDES))
my_model = ErnieRNAOnestage(model_pretrained.encoder).to(device)


def _embed_with_ernie(seqs, if_cls):
    """Run ERNIE-RNA on `seqs`; `if_cls` selects CLS (True) or token (False) embeddings."""
    # NOTE(review): the recorded output shows "Model Loading Done!!!" on every
    # call, which suggests the helper may reload the checkpoint even though
    # `model` is passed — confirm against extract_embedding.py.
    return extract_embedding_of_ernierna(
        seqs,
        arg_overrides=dict(ERNIE_ARG_OVERRIDES),  # fresh dict per call, as before
        pretrained_model_path=ERNIE_CHECKPOINT_PATH,
        device=str(device),
        model=my_model,
        if_cls=if_cls,
    )


embeddings_list = []    # CLS embeddings, one entry per sample
embeddings_2_list = []  # token embeddings, one entry per sample

for s in starts:
    # `s` is the list of sequences of one sample; len(s) is the sequence count.
    print("Processing sequence:", s, len(s))
    embeddings = _embed_with_ernie(s, if_cls=True)
    embeddings_2 = _embed_with_ernie(s, if_cls=False)

    embeddings_list.append(embeddings)
    embeddings_2_list.append(embeddings_2)

Processing sequence: ['TACTACAGTTCTTAAAAATAATCTATTAAAATTTTTTTGCT', 'TAGTACGATTCGTGAAAATAATCTGTTAAAATTCTTTTTCT', 'TTATACAATTTTTTGAGGATTAATCTGTTGAAATTATTGTTCT'] 3
Model Loading Done!!!
Model Loading Done!!!
Processing sequence: ['AGAAAAAAACGTAGTATTTTTTATATGATTACCCTGAT', 'AGAATAGGATGCAATATTTACTCCATGATTACCTTCCAACT', 'AGTATAGGATGGAATATTTACGCCATAATCACCTTCAGAT'] 3
Model Loading Done!!!
Model Loading Done!!!
Processing sequence: ['TTACGTTTACAATTTCATACAGCTTAATAGCGTGGGGGA', 'TTACGTTTGTCAATTCATACAATTTAGTAGCATTTTAGGA', 'TTACGTTTATCGATTCACACTGTTTAGTAGCATTTAAAA'] 3
Model Loading Done!!!
Model Loading Done!!!
Processing sequence: ['ATCTCCTTTAAATAATAGAAAATCACGGATCTGTCG', 'ATTTCCCTTGAATAATAAAGCCTCAGATCTG', 'ATTTCCCTTAAATAGTAGAAGATCACGGATCTGTCCG'] 3
Model Loading Done!!!
Model Loading Done!!!
Processing sequence: ['CTCCATTGCCAAAATGTTCCAAAAAATTTAAGGATTGTCG', 'CTTCGTTAACAAAATGTACCAAAGAATATAGAGATTTTCA', 'CTACGTTTACATAATATACCGAAAAGTTTAAAGAACTGCA'] 3
Model Loading Done!!!
Model Loading Done!!!


**RESULTS**: With `if_cls=True` we get the CLS embedding, i.e. a single embedding that should represent the full sequence. If we need the individual token embeddings instead, we get those with `if_cls=False`.

In [25]:
# Walk the CLS and token results in lockstep (per sample, then per sequence)
# and print the shape of each pair.
for cls_group, token_group in zip(embeddings_list, embeddings_2_list):
    paired = zip(cls_group, token_group)
    for cls_emb, token_emb in paired:
        print(cls_emb.shape, token_emb.shape)

(12, 768) (12, 45, 768)
(12, 768) (12, 45, 768)
(12, 768) (12, 45, 768)
(12, 768) (12, 43, 768)
(12, 768) (12, 43, 768)
(12, 768) (12, 43, 768)
(12, 768) (12, 42, 768)
(12, 768) (12, 42, 768)
(12, 768) (12, 42, 768)
(12, 768) (12, 39, 768)
(12, 768) (12, 39, 768)
(12, 768) (12, 39, 768)
(12, 768) (12, 42, 768)
(12, 768) (12, 42, 768)
(12, 768) (12, 42, 768)
