# CARMANIA Unguided LM Generation 

This notebook:

1. Loads the **CARMANIA** genomic language model.
2. Implements **LM scoring** (mean or summed per-token log-probability under the model).
3. Lets you **generate sequences from a prompt**, rank them by LM score,
   and **save the top sequences to a FASTA file**.

You only need to:
- Set your `prompt` and generation parameters in the last cell.
- Run the notebook **top to bottom** in a fresh session.


In [1]:
import torch
from transformers import AutoTokenizer, AutoModel
from typing import List, Tuple
import numpy as np
import torch.nn.functional as F
import os

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
DEVICE = "cpu" if not torch.cuda.is_available() else "cuda"
print("Using device:", DEVICE)

# Hugging Face Hub identifier of the checkpoint used throughout the notebook.
CARMANIA_MODEL_NAME = "MsAlEhR/carmania-160k-seqlen-human"

# Model weights: float16 on GPU, float32 on CPU.
# NOTE: trust_remote_code=True lets transformers run custom code that ships
# with the checkpoint repository - only use it with sources you trust.
carmania = AutoModel.from_pretrained(
    CARMANIA_MODEL_NAME,
    torch_dtype=torch.float32 if DEVICE != "cuda" else torch.float16,
    trust_remote_code=True,
).to(DEVICE)
carmania.eval()  # this notebook only runs inference

# Matching tokenizer, loaded from the same checkpoint.
carmania_tokenizer = AutoTokenizer.from_pretrained(
    CARMANIA_MODEL_NAME,
    model_max_length=2000,
    trust_remote_code=True,
)

# Characters kept when cleaning decoded model output.
DNA_ALPHABET = {"A", "C", "G", "T"}

print("CARMANIA and tokenizer loaded.")

Using device: cuda
CARMANIA and tokenizer loaded.


## Helper functions 

In [2]:
from typing import List

def prepare_batch_carmania(
    seqs: List[str],
    tokenizer,
    device: str = DEVICE,
    add_special_tokens: bool = True,
):
    """
    Tokenize a batch of sequences with padding and move the tensors to `device`.

    The tokenizer's own attention mask is used when it returns one. Otherwise
    the mask is derived from `tokenizer.pad_token_id`; when the tokenizer has
    no pad id, every position is treated as a real token.

    Args:
      seqs:               sequences to tokenize as one padded batch.
      tokenizer:          callable returning a mapping with "input_ids"
                          (and optionally "attention_mask").
      device:             target device for the returned tensors.
      add_special_tokens: forwarded unchanged to the tokenizer.

    Returns:
      input_ids      [B, L]
      attention_mask [B, L] (1 = real token, 0 = pad)
      token_lengths  list[int] (number of real tokens per seq)
    """
    encoded = tokenizer(
        seqs,
        return_tensors="pt",
        padding=True,
        truncation=False,
        add_special_tokens=add_special_tokens,
    )
    input_ids = encoded["input_ids"].to(device)

    if "attention_mask" not in encoded:
        # Rebuild the mask ourselves from the pad id (if there is one).
        pad_token = tokenizer.pad_token_id
        attention_mask = (
            torch.ones_like(input_ids, dtype=torch.long, device=device)
            if pad_token is None
            else (input_ids != pad_token).long().to(device)
        )
    else:
        attention_mask = encoded["attention_mask"].to(device)

    # Row sums of the mask give the number of non-pad tokens per sequence.
    token_lengths = attention_mask.sum(dim=1).tolist()
    return input_ids, attention_mask, token_lengths

In [3]:
def logits_to_logprobs_carmania(
    logits: torch.Tensor,
    input_ids: torch.Tensor,
    trim_bos: bool = True,
) -> torch.Tensor:
    """
    Pick, for every position, the log-probability assigned to the observed token.

    Args:
      logits:    [B, L, V] raw model outputs.
      input_ids: [B, L] token ids.
      trim_bos:  if True, the distribution at position t-1 scores the token at
                 position t (next-token log-likelihood), so the first token is
                 not scored. If False, the distribution at position t is
                 gathered at the token at the same position t.

    Returns:
      logprobs: [B, L'] where L' = L-1 if trim_bos else L
    """
    all_logprobs = logits.log_softmax(dim=-1)  # [B, L, V]

    if not trim_bos:
        targets = input_ids  # [B, L]
    else:
        # Drop the last prediction (nothing follows it) and the first target.
        all_logprobs = all_logprobs[:, :-1, :]  # [B, L-1, V]
        targets = input_ids[:, 1:]              # [B, L-1]

    # Select the vocabulary entry of each target token along dim 2.
    picked = all_logprobs.gather(2, targets[..., None])  # [B, L', 1]
    return picked.squeeze(-1)                             # [B, L']

In [4]:
def score_sequences_carmania(
    seqs: List[str],
    model,
    tokenizer,
    reduce_method: str = "mean",
    device: str = DEVICE,
) -> List[float]:
    """
    LM log-likelihood scoring for CARMANIA (Evo-style).

    Each sequence is tokenized with special tokens, run through `model` in a
    single batched forward pass, and scored with log p(x_t | x_<t) for every
    token after the first one.

    Args:
      seqs:          sequences to score (one batch, one forward pass).
      model:         callable accepting `input_ids` / `attention_mask` and
                     returning an object with a `.logits` tensor [B, L, V].
      tokenizer:     tokenizer passed on to `prepare_batch_carmania`.
      reduce_method: 'mean' -> average log p per scored token
                     'sum'  -> total log p
      device:        device on which the batch is placed.

    Returns:
      One Python float per input sequence. A sequence with no scorable token
      (fewer than two real tokens) yields NaN for 'mean' and 0.0 for 'sum'.

    Raises:
      ValueError: if `reduce_method` is neither 'mean' nor 'sum'.

    NOTE(review): per-sequence slicing keeps the first (n_tokens - 1) positions,
    i.e. it assumes the tokenizer pads on the right - confirm `padding_side`.
    """
    # Fail fast on a bad argument instead of after the expensive forward pass.
    if reduce_method not in ("mean", "sum"):
        raise ValueError(f"Invalid reduce_method {reduce_method}")
    reduce_fn = np.mean if reduce_method == "mean" else np.sum

    input_ids, attention_mask, token_lengths = prepare_batch_carmania(
        seqs, tokenizer, device=device, add_special_tokens=True
    )

    with torch.inference_mode():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        # Upcast before log-softmax: the model may run in float16 on GPU, and
        # normalising in half precision loses accuracy. The sampler upcasts its
        # logits to float32 as well, so this keeps the two paths consistent.
        logits = outputs.logits.float()  # [B, L, V]

    # per-position log p(x_t | x_<t)
    logprobs = logits_to_logprobs_carmania(
        logits,
        input_ids,
        trim_bos=True,
    )  # [B, L-1]
    logprobs = logprobs.cpu().numpy()

    scores: List[float] = []
    for row, n_tokens in zip(logprobs, token_lengths):
        # One fewer scored position than tokens (the first token is only
        # context). Clamp at 0: a negative length would turn the slice below
        # into [:-1] and silently score padding positions.
        effective_len = max(int(n_tokens) - 1, 0)
        if effective_len == 0:
            # Nothing to score: same values NumPy would give for an empty
            # array, without the empty-mean RuntimeWarning.
            scores.append(float("nan") if reduce_method == "mean" else 0.0)
            continue
        scores.append(float(reduce_fn(row[:effective_len])))

    return scores

## Sampling function

In [5]:
def _nucleus_filter_logits(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    """Return a copy of `logits` [B, V] with tokens outside the top-p nucleus set to -1e9."""
    sorted_logits, sorted_idx = torch.sort(logits, descending=True)
    cdf = torch.softmax(sorted_logits, dim=-1).cumsum(-1)

    drop_sorted = cdf > top_p
    # Shift the mask right by one so the token that crosses the threshold is
    # kept. The clone() matters: source and destination slices overlap in
    # memory, and an in-place copy between overlapping views is not safe.
    drop_sorted[..., 1:] = drop_sorted[..., :-1].clone()
    # Always keep the most likely token.
    drop_sorted[..., 0] = False

    # Map the mask from sorted order back to vocabulary order.
    drop = torch.zeros_like(logits, dtype=torch.bool)
    drop.scatter_(1, sorted_idx, drop_sorted)
    return logits.masked_fill(drop, -1e9)


@torch.no_grad()
def sample_carmania(
    prompt_seq: str,
    num_samples: int = 4,
    max_new_tokens: int = 64,
    temperature: float = 1.0,
    top_p: float = 0.7,
):
    """
    Autoregressive sampler for CARMANIA using logits directly (no .generate()).

    The prompt is tokenized once, repeated `num_samples` times as a batch, and
    extended one token at a time by sampling from logits[:, -1, :]. Every step
    re-runs the model on the full sequence so far.

    Args:
      prompt_seq:     DNA prompt string.
      num_samples:    number of independent continuations (batch size).
      max_new_tokens: number of tokens to append to the prompt.
      temperature:    logits are divided by this value when it is > 0.
                      None or values <= 0 leave the logits unscaled
                      (this is NOT greedy decoding).
      top_p:          nucleus threshold; applied only when 0 < top_p < 1.

    Returns:
      sequences: decoded strings (prompt + continuation), upper-cased and
                 restricted to characters in DNA_ALPHABET; if nothing is left
                 after filtering, the upper-cased raw string is returned.
      scores:    per-sequence mean log-prob of the sampled tokens under the
                 temperature / top-p adjusted distribution they were drawn
                 from. NaN when no token was generated (max_new_tokens <= 0).

    NOTE(review): the prompt is tokenized with add_special_tokens=True; if the
    tokenizer appends a trailing special token, generation continues after
    it - confirm against the tokenizer.
    """
    enc = carmania_tokenizer(prompt_seq, return_tensors="pt", add_special_tokens=True)
    input_ids = enc["input_ids"].to(DEVICE)
    input_ids = input_ids.repeat(num_samples, 1)

    step_logprobs = []  # list of [B]

    for _ in range(max_new_tokens):
        outputs = carmania(input_ids=input_ids)
        logits = outputs.logits              # (B, L, V)

        # work in float32 to avoid fp16 overflow issues
        next_logits = logits[:, -1, :].float()   # (B, V)

        # temperature
        if temperature is not None and temperature > 0.0:
            next_logits = next_logits / temperature

        # top-p (nucleus) sampling
        if top_p is not None and 0.0 < top_p < 1.0:
            next_logits = _nucleus_filter_logits(next_logits, top_p)

        probs = torch.softmax(next_logits, dim=-1)
        next_token = torch.multinomial(probs, 1)  # (B, 1)

        # log-prob of chosen token (epsilon guards against log(0))
        step_lp = torch.log(probs.gather(1, next_token).squeeze(1) + 1e-12)  # (B,)
        step_logprobs.append(step_lp)

        input_ids = torch.cat([input_ids, next_token], dim=1)

    # decode sequences
    raw_seqs = carmania_tokenizer.batch_decode(input_ids, skip_special_tokens=True)
    cleaned = []
    for s in raw_seqs:
        s_up = s.upper()
        only_dna = ''.join(c for c in s_up if c in DNA_ALPHABET)
        cleaned.append(only_dna if only_dna else s_up)

    # mean log-prob per sequence
    if step_logprobs:
        logprob_tensor = torch.stack(step_logprobs, dim=0)  # (T, B)
        scores = logprob_tensor.mean(dim=0).cpu().tolist()
    else:
        # No step was taken: torch.stack would raise on an empty list, and the
        # mean over zero sampled tokens is undefined.
        scores = [float("nan")] * input_ids.shape[0]

    return cleaned, scores

## Generate sequences, rank them, and save top ones to FASTA

In [6]:
def generate_and_save_fasta(
    prompt: str,
    num_samples: int = 32,
    max_new_tokens: int = 128,
    temperature: float = 1.0,
    top_p: float = 0.9,
    top_k_to_save: int = 10,
    fasta_path: str = "carmania_top_sequences.fasta",
):
    """
    Sample, score, rank and export sequences in one call.

    Steps:
      1) Draw `num_samples` continuations of `prompt` with `sample_carmania`.
      2) Re-score every full sequence with `score_sequences_carmania`
         (mean log-likelihood); the sampler's own scores are not used.
      3) Sort by that score, best first.
      4) Write the best `top_k_to_save` entries to `fasta_path`, one header
         line and one sequence line per record. The file is overwritten.

    Returns:
      ranked: list of (sequence, score) sorted by score desc.
      fasta_path: path to the written FASTA file.
    """
    print(f"Sampling {num_samples} sequences from prompt: {prompt!r}")
    sampling_kwargs = dict(
        prompt_seq=prompt,
        num_samples=num_samples,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
    )
    seqs, _ = sample_carmania(**sampling_kwargs)

    print("Scoring sequences with CARMANIA LM...")
    lm_scores = score_sequences_carmania(
        seqs,
        model=carmania,
        tokenizer=carmania_tokenizer,
        reduce_method="mean",
    )

    # Highest score first; the sort is stable, so ties keep sampling order.
    ranked = list(zip(seqs, lm_scores))
    ranked.sort(key=lambda pair: pair[1], reverse=True)

    # Never try to save more records than were generated.
    top_k = min(top_k_to_save, len(ranked))
    with open(fasta_path, "w") as f:
        # max(..., 0) keeps a negative request from turning into a tail slice.
        for rank, (seq, score) in enumerate(ranked[:max(top_k, 0)], start=1):
            f.write(f">carmania_seq_{rank}_score_{score:.4f}\n")
            f.write(seq + "\n")

    print(f"Saved top {top_k} sequences to {os.path.abspath(fasta_path)}")
    return ranked, fasta_path

## Example: run generation + save FASTA for this session

In [7]:
# --- User settings: change the values below, then run this cell ---
prompt = "CTTTCTGTCCCGCCCTTCCTCTGACTGTGTCTTGATT"
num_samples = 32
max_new_tokens = 128
temperature = 1.0
top_p = 0.9
top_k_to_save = 10
fasta_path = "carmania_top_sequences.fasta"

# Sample, score, rank and write the best sequences to `fasta_path`.
ranked, fasta_file = generate_and_save_fasta(
    prompt=prompt,
    num_samples=num_samples,
    max_new_tokens=max_new_tokens,
    temperature=temperature,
    top_p=top_p,
    top_k_to_save=top_k_to_save,
    fasta_path=fasta_path,
)

# Preview the three best-ranked sequences, truncated to 200 characters.
for i, (seq, score) in enumerate(ranked[:3]):
    print(f"Rank {i+1} | score={score:.4f} | length={len(seq)}")
    print(seq if len(seq) <= 200 else seq[:200] + '...')
    print()

Sampling 32 sequences from prompt: 'CTTTCTGTCCCGCCCTTCCTCTGACTGTGTCTTGATT'
Scoring sequences with CARMANIA LM...
Saved top 10 sequences to /scratch/home/sr3622/Firm-DTI/Firm-DTI2/carmania_top_sequences.fasta
Rank 1 | score=-0.5021 | length=165
CTTTCTGTCCCGCCCTTCCTCTGACTGTGTCTTGATTTTCTTTTTTCTTCTCTTCTCTTCTCTTCTCTTCTCTTCTCTTCTCTTCTCTTCTCTTCTCTTCTCTTCTCTTCTCTTCTCTTCCCTTCCCTTCTCTTCTCTTCCCTTCTCTTCCCTTCCCTTCTCTTC

Rank 2 | score=-0.5134 | length=165
CTTTCTGTCCCGCCCTTCCTCTGACTGTGTCTTGATTCTTTTTCCCTCCCTCCTTCTCTCTCTCTCTCCCTCCCTCCCTCCCTCCCTCCCTCCCTCCCTCCCTCCCTCCCTCCCTCCCTCCCTCCCTCCCTCCCTCCCTCCCTCCCTCCCTCCCTCCCTCCCTCC

Rank 3 | score=-0.5654 | length=165
CTTTCTGTCCCGCCCTTCCTCTGACTGTGTCTTGATTTTTATTTTTATTTTTTTTTGAGATAAAGTCTTGCTCTGTCACCCAGGCTGGAGTGCAGTGGCACGATCTCAGCTCACTGCAACCTCCACCTCCCAGGTTCAAGTGATTCTTGTGCCTCAGCCTCCTGA

