# CARMANIA + Enformer: Chromatin-Guided Sequence Generation

This notebook builds a simple pipeline where:

- **CARMANIA** is used as a genomic language model to generate candidate DNA sequences from a short input prompt.
- **Enformer** is used as an oracle to predict chromatin-related signals (e.g., accessibility and other regulatory tracks) along each candidate sequence.
- A **scalar score** is computed from Enformer’s output (by averaging the predicted signal of all human tracks over a few bins at the center of the input) and is used to rank the generated sequences.

In other words, CARMANIA proposes DNA sequences, and Enformer provides a regulatory-style score so we can preferentially keep sequences with a higher predicted central signal, used here as a rough proxy for chromatin “openness.”
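The propose-then-rank idea can be sketched without either model. In the sketch below, `rank_candidates`, `toy_propose` and `toy_score` are hypothetical names: the toy generator appends random bases (standing in for CARMANIA) and the toy oracle is GC content (standing in for the Enformer score). Only the control flow is meant to match the notebook.

```python
import random
from typing import Callable, List, Tuple


def rank_candidates(
    prompt: str,
    propose: Callable[[str, int], List[str]],
    score: Callable[[str], float],
    num_samples: int = 4,
) -> List[Tuple[str, float]]:
    """Propose num_samples continuations of prompt, score each, return them best-first."""
    candidates = propose(prompt, num_samples)
    scored = [(seq, score(seq)) for seq in candidates]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored


def toy_propose(prompt: str, n: int, new_len: int = 16, seed: int = 0) -> List[str]:
    """Hypothetical stand-in generator: append new_len random bases to the prompt."""
    rng = random.Random(seed)
    return [prompt + "".join(rng.choice("ACGT") for _ in range(new_len)) for _ in range(n)]


def toy_score(seq: str) -> float:
    """Hypothetical stand-in oracle: GC fraction of the sequence."""
    return (seq.count("G") + seq.count("C")) / len(seq)


ranked = rank_candidates("ACGTACGT", toy_propose, toy_score, num_samples=4)
for seq, s in ranked:
    print(f"{s:.3f}  {seq}")
```

In the notebook, the two stand-ins are replaced by `sample_carmania` and `enformer_chromatin_open_score`.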

> **Note:** Both models are large. Use a GPU runtime (e.g., Google Colab GPU).


In [None]:

!pip install -q enformer-pytorch


## Imports & Device

In [None]:

import torch
from transformers import AutoModel, AutoTokenizer
from enformer_pytorch import from_pretrained as enformer_from_pretrained, seq_indices_to_one_hot

from typing import List, Tuple

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", DEVICE)


Using device: cuda


## Load CARMANIA (DNA LM)

Model: `MsAlEhR/carmania-160k-seqlen-human`

In [2]:

CARMANIA_MODEL_NAME = "MsAlEhR/carmania-160k-seqlen-human"

# The CARMANIA README recommends AutoModel + trust_remote_code=True.
# The custom class inherits from PreTrainedModel, but sampling below uses a
# manual loop over the logits rather than relying on .generate().
carmania = AutoModel.from_pretrained(
    CARMANIA_MODEL_NAME,
    trust_remote_code=True,
    torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
).to(DEVICE)
carmania.eval()  # disable dropout for inference/sampling

carmania_tokenizer = AutoTokenizer.from_pretrained(
    CARMANIA_MODEL_NAME,
    trust_remote_code=True,
    model_max_length=160_000,
)

print("Loaded CARMANIA.")

# quick sanity check: tokenize a tiny sequence and run one forward pass
test_inputs = carmania_tokenizer("ACGTAGGCTA", return_tensors="pt").to(DEVICE)
with torch.no_grad():
    test_outputs = carmania(**test_inputs)
print("CARMANIA forward pass OK, type:", type(test_outputs))


## Load Enformer (oracle)

We use `EleutherAI/enformer-official-rough` via `enformer-pytorch`.

In [1]:

enformer = enformer_from_pretrained(
    "EleutherAI/enformer-official-rough"
).to(DEVICE)

enformer.eval()
print("Loaded Enformer.")

# quick sanity check on dummy indices
with torch.no_grad():
    dummy_seq = torch.randint(0, 5, (1, 196_608), device=DEVICE)  # A,C,G,T,N indices
    dummy_onehot = seq_indices_to_one_hot(dummy_seq)              # (1, 196608, 4); N -> all zeros
    dummy_out = enformer(dummy_onehot)
    print("Enformer output keys:", dummy_out.keys())
    print("Human head shape:", dummy_out["human"].shape)


## DNA ⇄ Enformer Utilities

In [None]:

DNA_ALPHABET = "ACGTN"
char_to_idx = {c: i for i, c in enumerate(DNA_ALPHABET)}  # A:0, C:1, G:2, T:3, N:4

def dna_to_indices(seq: str) -> torch.Tensor:
    """Convert a DNA string (A,C,G,T,N) to indices tensor of shape (1, L)."""
    seq = seq.upper()
    idxs = [char_to_idx.get(ch, char_to_idx["N"]) for ch in seq]
    return torch.tensor([idxs], dtype=torch.long)

def center_pad_to_enformer_context(seq: str, context_length: int = 196_608) -> torch.Tensor:
    """Embed seq into the center of a context_length window with 'N' padding (center-cropped if longer)."""
    L = len(seq)
    if L > context_length:
        # center crop
        start = (L - context_length) // 2
        seq = seq[start:start + context_length]
        L = len(seq)

    pad_total = context_length - L
    pad_left = pad_total // 2
    pad_right = pad_total - pad_left

    padded_seq = "N" * pad_left + seq + "N" * pad_right
    assert len(padded_seq) == context_length

    return dna_to_indices(padded_seq)

# quick check
s = "ACGT" * 10
idxs = center_pad_to_enformer_context(s)
print("Padded indices shape:", idxs.shape)


Padded indices shape: torch.Size([1, 196608])
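For reference, here is a pure-Python sketch of the index-to-one-hot mapping we assume `seq_indices_to_one_hot` applies to these indices: four channels in A, C, G, T order, with index 4 (`N`) becoming an all-zero row, so `N` padding carries no base information. This reflects our reading of the enformer-pytorch source (it keeps only the first four one-hot channels); `one_hot_acgt` is a hypothetical helper, not the library's code, so check it against your installed version.

```python
from typing import List

DNA_ALPHABET = "ACGTN"
CHAR_TO_IDX = {c: i for i, c in enumerate(DNA_ALPHABET)}  # A:0, C:1, G:2, T:3, N:4


def one_hot_acgt(seq: str) -> List[List[float]]:
    """Return an L x 4 one-hot matrix over A, C, G, T; 'N' and unknown characters become all zeros."""
    rows = []
    for ch in seq.upper():
        idx = CHAR_TO_IDX.get(ch, CHAR_TO_IDX["N"])
        row = [0.0] * 4
        if idx < 4:  # index 4 ('N') has no channel, so its row stays all zeros
            row[idx] = 1.0
        rows.append(row)
    return rows


print(one_hot_acgt("ACGTN"))
```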


## Enformer Chromatin Heuristic

In [None]:

@torch.no_grad()
def enformer_chromatin_open_score(seq: str, center_radius_bins: int = 8) -> float:
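    """
    A simple heuristic score:
    - Center seq in a 196,608 bp context (N-padded, or center-cropped if longer)
    - Run Enformer and keep the human head (896 bins x 5,313 tracks)
    - Take the center +/- center_radius_bins bins
    - Average across those bins and all tracks
    Returns: Python float score.
    """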
    idxs = center_pad_to_enformer_context(seq).to(DEVICE)   # (1, 196608)
    one_hot = seq_indices_to_one_hot(idxs)                   # (1, 196608, 4); 'N' -> all zeros

    outputs = enformer(one_hot)                       # dict with 'human', 'mouse'
    human_pred = outputs["human"]                     # (1, target_len, 5313)

    B = human_pred.shape[1]
    center = B // 2
    start = max(0, center - center_radius_bins)
    end = min(B, center + center_radius_bins + 1)

    center_slice = human_pred[:, start:end, :]        # (1, bins, tracks)
    score = center_slice.mean().item()
    return float(score)

# quick sanity check
test_score = enformer_chromatin_open_score("ACGT" * 50)
print("Test Enformer heuristic score:", test_score)


Test Enformer heuristic score: 0.5320137143135071
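It helps to know how much real sequence the heuristic actually looks at. Enformer maps its 196,608 bp input to 896 output bins of 128 bp that cover the central 114,688 bp, so `center_radius_bins=8` averages 17 bins, about 2.2 kb. The sketch below (`center_window_coverage` is a hypothetical helper, not part of enformer-pytorch) computes what fraction of that averaged window overlaps a sequence placed by `center_pad_to_enformer_context`; everything else in the window is `N` padding.

```python
INPUT_LEN = 196_608  # Enformer input length (bp)
BIN_SIZE = 128       # bp per output bin
NUM_BINS = 896       # output bins per head


def center_window_coverage(seq_len: int, center_radius_bins: int = 8) -> float:
    """Fraction of the averaged central bins that overlaps real (non-'N') sequence."""
    crop = (INPUT_LEN - NUM_BINS * BIN_SIZE) // 2              # 40,960 bp trimmed per side
    center = NUM_BINS // 2
    start_bin = max(0, center - center_radius_bins)
    end_bin = min(NUM_BINS, center + center_radius_bins + 1)   # exclusive, as in the slice above
    win_start = crop + start_bin * BIN_SIZE
    win_end = crop + end_bin * BIN_SIZE

    # Same placement as center_pad_to_enformer_context (crop first if too long)
    seq_len = min(seq_len, INPUT_LEN)
    seq_start = (INPUT_LEN - seq_len) // 2
    seq_end = seq_start + seq_len

    overlap = max(0, min(win_end, seq_end) - max(win_start, seq_start))
    return overlap / (win_end - win_start)


for n in (20, 84, 468):
    print(n, round(center_window_coverage(n), 3))
```

Under these assumptions, a few-hundred-bp candidate fills only a fraction of the window, so longer candidates are averaged over proportionally more real sequence.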


## CARMANIA Sampling Helper

In [None]:

@torch.no_grad()
def sample_carmania(
    prompt_seq: str,
    num_samples: int = 4,
    max_new_tokens: int = 64,
    temperature: float = 1.0,
    top_p: float = 0.95,
) -> List[str]:
    """
    Autoregressive sampler for CARMANIA using its logits (no .generate()).

    - Takes a DNA prompt string.
    - Encodes it with the CARMANIA tokenizer.
    - Repeats the prompt `num_samples` times as a batch.
    - Iteratively samples the next token from logits[:, -1, :] after
      temperature scaling and top-p (nucleus) filtering.
    - Re-runs the full forward pass at every step (no KV cache), so cost
      grows with sequence length.
    - Returns decoded sequences (prompt + continuation), restricted to A/C/G/T/N.
    """
    # 1) Encode the prompt once
    enc = carmania_tokenizer(
        prompt_seq,
        return_tensors="pt",
        add_special_tokens=True,
    )
    input_ids = enc["input_ids"].to(DEVICE)          # (1, L)
    # Repeat for batch size = num_samples
    input_ids = input_ids.repeat(num_samples, 1)     # (B, L)

    # 2) Autoregressive loop
    for _ in range(max_new_tokens):
        outputs = carmania(input_ids=input_ids)      # model output exposing .logits
        logits = outputs.logits                      # (B, seq_len, vocab_size)
        next_token_logits = logits[:, -1, :]         # (B, vocab_size)

        # Temperature
        if temperature is not None and temperature > 0.0:
            next_token_logits = next_token_logits / temperature

        # Top-p (nucleus) filtering
        if top_p is not None and 0.0 < top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
            probs = torch.softmax(sorted_logits, dim=-1)
            cumulative_probs = probs.cumsum(dim=-1)

            # filter tokens with cumulative prob above top_p
            sorted_indices_to_remove = cumulative_probs > top_p
            # shift right so we always keep at least the first token
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0

            # scatter back to original ordering
            indices_to_remove = torch.zeros_like(next_token_logits, dtype=torch.bool)
            indices_to_remove.scatter_(1, sorted_indices, sorted_indices_to_remove)
            next_token_logits = next_token_logits.masked_fill(indices_to_remove, float("-inf"))

        # Sample from the filtered distribution
        probs = torch.softmax(next_token_logits, dim=-1)
        next_tokens = torch.multinomial(probs, num_samples=1)   # (B, 1)

        # Append to input_ids
        input_ids = torch.cat([input_ids, next_tokens], dim=1)  # (B, L+1)

    # 3) Decode all sequences
    sequences = carmania_tokenizer.batch_decode(
        input_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )

    # 4) Restrict to DNA alphabet
    cleaned = []
    for s in sequences:
        s = s.upper()
        s = "".join(ch for ch in s if ch in DNA_ALPHABET)
        cleaned.append(s)

    return cleaned


# quick sanity check sampling
test_samples = sample_carmania("ACGTACGT", num_samples=2, max_new_tokens=16)
print("Sampled sequences:", test_samples)


Sampled sequences: ['ACGTACGTCGGACATCTTGAATAC', 'ACGTACGTTCTACCAATTACTAGT']
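The sort/shift/scatter indexing in the top-p step can be hard to follow. Below is a pure-Python sketch of the same rule on a single logits vector (`top_p_filter` is a hypothetical helper, not part of CARMANIA or transformers): sort tokens by probability, keep them until the cumulative probability first exceeds `top_p` (the token that crosses the threshold is kept, and the top token is always kept), and mask the rest to `-inf`.

```python
import math
from typing import List


def top_p_filter(logits: List[float], top_p: float) -> List[float]:
    """Return logits with tokens outside the nucleus set to -inf (always keeps the top token)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]  # numerically stable softmax
    total = sum(exps)
    probs = [e / total for e in exps]

    order = sorted(range(len(logits)), key=lambda i: probs[i], reverse=True)
    keep = set()
    cumulative = 0.0
    for i in order:
        keep.add(i)              # add first, so the token that crosses top_p is kept
        cumulative += probs[i]
        if cumulative > top_p:
            break
    return [x if i in keep else float("-inf") for i, x in enumerate(logits)]


print(top_p_filter([2.0, 1.0, 0.1], 0.9))
```

This keeps token k exactly when the cumulative probability of the tokens ranked above it is at most `top_p`, which is what the right-shift of `sorted_indices_to_remove` achieves in the batched PyTorch version.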


## Guided Beam Search

In [None]:

def guided_beam_search(
    prompt_seq: str,
    num_beams: int = 4,
    fan_out: int = 4,
    num_steps: int = 4,
    max_new_tokens_per_step: int = 64,
) -> List[Tuple[str, float]]:
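    """
    Enformer-guided beam search:
    - CARMANIA proposes `fan_out` continuations of each beam per step
    - Enformer scores each candidate with the chromatin heuristic
    - Parents and children are pooled and the top `num_beams` are kept
    """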
    init_score = enformer_chromatin_open_score(prompt_seq)
    beams = [(prompt_seq, init_score)]
    print(f"Initial prompt score: {init_score:.4f}")

    for step in range(num_steps):
        print(f"\n=== Step {step+1}/{num_steps} ===")
        candidates: List[Tuple[str, float]] = []

        for seq, _ in beams:
            samples = sample_carmania(
                seq,
                num_samples=fan_out,
                max_new_tokens=max_new_tokens_per_step,
            )

            for s in samples:
                try:
                    score = enformer_chromatin_open_score(s)
                except RuntimeError as e:
                    print("Enformer error for candidate, skipping:", e)
                    continue
                candidates.append((s, score))

        candidates.extend(beams)
        candidates.sort(key=lambda x: x[1], reverse=True)
        beams = candidates[:num_beams]

        print("Top beams this step:")
        for i, (seq, score) in enumerate(beams):
            print(f"  Beam {i+1}: score = {score:.4f}, len = {len(seq)}")

    return beams
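Because `candidates.extend(beams)` puts the parents back into the pool, a beam that no child outscores survives unchanged, and the best score can never decrease from one step to the next. The toy version below makes this explicit. `toy_guided_beam_search` and `gc_content` are hypothetical names: random base extension stands in for CARMANIA and GC content stands in for the Enformer score.

```python
import random
from typing import Callable, List, Tuple


def gc_content(seq: str) -> float:
    """Hypothetical stand-in oracle: GC fraction of the sequence."""
    return (seq.count("G") + seq.count("C")) / len(seq)


def toy_guided_beam_search(
    prompt: str,
    score: Callable[[str], float],
    num_beams: int = 3,
    fan_out: int = 3,
    num_steps: int = 4,
    new_len: int = 8,
    seed: int = 0,
) -> List[Tuple[str, float]]:
    rng = random.Random(seed)
    beams = [(prompt, score(prompt))]
    for _step in range(num_steps):
        candidates: List[Tuple[str, float]] = []
        for seq, _score in beams:
            for _child in range(fan_out):
                child = seq + "".join(rng.choice("ACGT") for _ in range(new_len))
                candidates.append((child, score(child)))
        candidates.extend(beams)  # parents compete with their children
        candidates.sort(key=lambda pair: pair[1], reverse=True)
        beams = candidates[:num_beams]
    return beams


for seq, s in toy_guided_beam_search("ACGTACGT", gc_content):
    print(f"{s:.3f}  len={len(seq)}  {seq}")
```

Keeping parents makes the search elitist; the trade-off is that a strong early beam can persist without being extended, which shows up as mixed beam lengths in the run log.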


## Run the Guided Search

In [None]:

prompt = "ACGT" * 5

final_beams = guided_beam_search(
    prompt_seq=prompt,
    num_beams=3,
    fan_out=3,
    num_steps=10,
    max_new_tokens_per_step=64,
)

print("\n=== Final beams ===")
for i, (seq, score) in enumerate(final_beams):
    print(f"\nBeam {i+1}")
    print(f"Score: {score:.4f}")
    print(f"Length: {len(seq)}")
    preview = seq[:300] + ("..." if len(seq) > 300 else "")
    print(preview)


Initial prompt score: 0.5407

=== Step 1/10 ===
Top beams this step:
  Beam 1: score = 0.5951, len = 84
  Beam 2: score = 0.5407, len = 20
  Beam 3: score = 0.5048, len = 84

=== Step 2/10 ===
Top beams this step:
  Beam 1: score = 0.6756, len = 148
  Beam 2: score = 0.6117, len = 148
  Beam 3: score = 0.5951, len = 84

=== Step 3/10 ===
Top beams this step:
  Beam 1: score = 0.6756, len = 148
  Beam 2: score = 0.6516, len = 212
  Beam 3: score = 0.6429, len = 212

=== Step 4/10 ===
Top beams this step:
  Beam 1: score = 0.8394, len = 276
  Beam 2: score = 0.8118, len = 276
  Beam 3: score = 0.7553, len = 276

=== Step 5/10 ===
Top beams this step:
  Beam 1: score = 1.2665, len = 340
  Beam 2: score = 1.1864, len = 340
  Beam 3: score = 1.1265, len = 340

=== Step 6/10 ===
Top beams this step:
  Beam 1: score = 1.5185, len = 404
  Beam 2: score = 1.3278, len = 404
  Beam 3: score = 1.2665, len = 340

=== Step 7/10 ===
Top beams this step:
  Beam 1: score = 1.6594, len = 468
  Beam 2: s

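The output above shows, for each search step, the top three sequences kept by the chromatin-guided search (the log shown here is cut off during step 7).
Each “Beam” corresponds to one candidate DNA sequence generated by CARMANIA and then scored by Enformer:

- **Score** is the Enformer-based heuristic: the mean predicted signal over all human tracks in the central bins. Higher values indicate stronger predicted regulatory signal near the center of the sequence.

- **Length** grows by 64 bp each time a beam is extended; a beam whose length is unchanged from the previous step is a parent that no new candidate outscored.

- The sequences share a common prefix from the initial prompt, while the later positions reflect CARMANIA’s sampled continuations that score highest under this heuristic.

Because the averaged window spans 17 bins × 128 bp ≈ 2.2 kb and is otherwise `N`-padded, each extension also adds real sequence to that window, so part of the score increase may reflect sequence length rather than sequence content.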
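To check how much of the final beams is inherited, the shared prefix can be measured directly. A small sketch using the standard library's `os.path.commonprefix` (which compares strings character by character, so it works on DNA strings); `shared_prefix_length` is a hypothetical helper.

```python
import os
from typing import List, Tuple


def shared_prefix_length(beams: List[Tuple[str, float]]) -> int:
    """Length of the longest prefix shared by every beam sequence."""
    return len(os.path.commonprefix([seq for seq, _score in beams]))


toy_beams = [("ACGTACGTAAAC", 1.2), ("ACGTACGTAAGG", 1.1), ("ACGTACGTTT", 0.9)]
print(shared_prefix_length(toy_beams))
```

In the notebook, `shared_prefix_length(final_beams)` is at least the 20 bp prompt, and is longer when the surviving beams descend from the same parent.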

