# Esoteric Language Models

This Colab demonstrates sample generation with our Eso-LM (B) models, which were trained on OpenWebText for 250K steps and are released on Hugging Face.

Our codebase also contains the sampler; use it instead of this notebook if you need to generate a large number of samples.

Currently, neither this notebook nor the codebase supports sampling of our Eso-LM (A) models.

📖 Paper: https://arxiv.org/abs/2506.01928

🏕 Code: https://github.com/s-sahoo/Eso-LMs

📑 Blog: https://s-sahoo.com/Eso-LMs/

🤗 Hugging Face: [Eso-LMs](https://huggingface.co/collections/sahoo-diffusion/eso-lms-6838e86cb2c49f45302f0092)

In [None]:
# try running the cells below first before running this cell
# if cells below run successfully, there's no need to run this cell
! pip install numpy==2.0.2
! pip install torch==2.6.0+cu124
! pip install transformers==4.52.2

In [None]:
import numpy as np
import torch
import torch.nn.functional as F
import transformers
from transformers import AutoModelForMaskedLM, AutoTokenizer

In [None]:
hf_model = AutoModelForMaskedLM.from_pretrained(
    'sahoo-diffusion/Eso-LM-B-alpha-0_25', trust_remote_code=True)
# hf_model = AutoModelForMaskedLM.from_pretrained(
#     'sahoo-diffusion/Eso-LM-B-alpha-1', trust_remote_code=True)

**NOTE:** Sampling from the model is slow in this Colab because the T4 GPU doesn't support `bfloat16`. For faster sampling, please use our [GitHub code](https://github.com/s-sahoo/Eso-LMs).

## Sampler implementation

The sampler is implemented below, along with the helper classes and functions it relies on.

In [None]:
#@title Log-linear noise schedule class
# Copied from https://github.com/jdeschena/sdtt/blob/bbc54d5b3c5fcffd79602cff17ed34dde1f3eff6/src/sdtt/core/sampling/utils.py#L10
class LogLinear(torch.nn.Module):
  def __init__(self, alpha_0=1):
    super().__init__()
    self.eps = 1e-3  # To be consistent with SEDD: https://github.com/louaaron/Score-Entropy-Discrete-Diffusion/blob/0605786da5ccb5747545e26d66fdf477187598b6/noise_lib.py#L56
    self.alpha_0 = alpha_0

  def forward(self, t):
    t = (1 - self.eps) * t
    alpha_t = self.alpha_0 * (1 - t)
    dalpha_t = - self.alpha_0 * (1 - self.eps)
    return dalpha_t, alpha_t
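
As a quick sanity check on the schedule above, the following standalone sketch re-implements the `alpha_t` formula in plain Python and evaluates it at a few time points. The function name `log_linear_alpha` is ours and not part of the codebase; it assumes the same `eps = 1e-3` as the class. Judging from how the sampler uses it, `alpha_t` plays the role of the probability that a token is already unmasked at time `t`.

```python
def log_linear_alpha(t, alpha_0=1.0, eps=1e-3):
    """alpha_t under the log-linear schedule: alpha_0 * (1 - (1 - eps) * t)."""
    return alpha_0 * (1 - (1 - eps) * t)


# alpha_t decreases linearly in t, from alpha_0 at t = 0
# down to alpha_0 * eps at t = 1.
for t in (0.0, 0.5, 1.0):
    print(t, log_linear_alpha(t, alpha_0=0.25))
```

Sampling runs from `t = 1` (almost everything masked) towards `t = 0`, so `alpha_t` grows towards `alpha_0` as generation proceeds.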

In [None]:
#@title Helper functions

def top_k_top_p_filtering(
  logits,
  top_k=0,
  top_p=0.0,
  filter_value=-float("Inf"),
  dim=-1):
  """Filter a distribution of logits using top-k/top-p (nucleus) filtering.
  Adapted from https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317

  Args:
    logits (Tensor): Tensor of logits
    top_k (int, optional): Number of top values to keep.
        Deactivated if k is 0. Defaults to 0.
    top_p (float, optional): Cumulative mass to retain.
        Deactivated if p = 0. Defaults to 0.0.
    filter_value (float, optional): Fill value to replace
        the entries removed by top-k/top-p filtering.
        Defaults to -float('Inf').
    dim (int, optional): Dimension of the filtering. Defaults to -1.

  Returns:
      logits: Tensor whose axis `dim` was filtered.
  """
  if dim != -1:
    logits = torch.transpose(logits, dim, -1)

  # After the optional transpose, the filtered axis is always the last one.
  assert top_k < logits.size(-1)
  if top_k > 0:
    # Remove all tokens with a probability less than
    # the last token of the top-k
    values, _ = torch.topk(logits, k=top_k, dim=-1)
    to_remove_mask = (
        logits < torch.min(values, dim=-1, keepdim=True)[0]
    )  # min returns a tuple (values, indices)
    logits[to_remove_mask] = filter_value

  if top_p > 0.0:
    sorted_logits, sorted_indices = torch.sort(
      logits, descending=True, dim=-1)
    cum_probs = torch.cumsum(
      torch.softmax(sorted_logits, dim=-1), dim=-1)

    sorted_indices_to_remove = cum_probs > top_p
    # Ensures at least one token is kept
    sorted_indices_to_remove[..., 1:] = \
      sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0

    mask_to_remove = torch.empty_like(sorted_indices_to_remove)
    mask_to_remove.scatter_(dim=-1,
                            index=sorted_indices,
                            src=sorted_indices_to_remove)
    logits[mask_to_remove] = filter_value

  if dim != -1:
    logits = torch.transpose(logits, dim, -1)

  return logits

def get_reverse_indices(indices):
  """
  indices: LongTensor of shape [B, N] representing permutations
  returns: LongTensor of shape [B, N] representing the inverse permutations
  """
  B, N = indices.shape
  reverse_indices = torch.empty_like(indices)
  arange = torch.arange(N, device=indices.device).unsqueeze(0).expand(B, -1)
  reverse_indices.scatter_(1, indices, arange)
  return reverse_indices
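
To make the nucleus (top-p) rule in `top_k_top_p_filtering` concrete, here is a small standard-library sketch on a four-token distribution. It is an illustration under our own naming (`top_p_filter` is hypothetical) and handles a single list of logits rather than a batched tensor, but it follows the same rule: a token is dropped once the tokens ranked above it already hold more than `top_p` of the probability mass.

```python
import math


def top_p_filter(logits, top_p):
    """Set to -inf every token that falls outside the top-p nucleus."""
    # Softmax with max-subtraction for numerical stability.
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Visit tokens from most to least likely.
    order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    filtered = list(logits)
    mass_before = 0.0  # cumulative probability of higher-ranked tokens
    for i in order:
        # The top-ranked token sees mass_before == 0, so it is always kept
        # for any top_p >= 0.
        if mass_before > top_p:
            filtered[i] = -math.inf
        mass_before += probs[i]
    return filtered


logits = [math.log(p) for p in (0.5, 0.3, 0.15, 0.05)]
print(top_p_filter(logits, top_p=0.7))
```

With `top_p=0.7` the first two tokens survive, because 0.5 + 0.3 is the shortest prefix whose mass exceeds 0.7, and the last two are set to `-inf`.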

In [None]:
#@title EsoLMBSampler class
class EsoLMBSampler:
  def __init__(self, backbone_wrapper, device):
    self.backbone = backbone_wrapper.backbone.to(device)
    self.tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')
    self.device = device
    self.vocab_size = backbone_wrapper.config.vocab_size
    self.mask_index = backbone_wrapper.config.mask_index
    self.num_tokens = backbone_wrapper.config.model_length
    self.rotary_dim = (
        backbone_wrapper.config.hidden_size //
        backbone_wrapper.config.n_heads)
    self.neg_infinity = -1000000.0
    self.noise = None

  def _tokens_unmasked_per_step(self, num_steps):
    remaining_tokens = self.num_tokens
    num_tokens_to_unmask = []
    dt = 1 / num_steps
    # Assumes a log-linear schedule.
    for t in np.linspace(1, dt, num_steps):
      _, alpha_t = self.noise(t)
      _, alpha_s = self.noise(t - dt)
      n_unmask = np.random.binomial(
        remaining_tokens, (alpha_s - alpha_t) / (1 - alpha_t))
      if n_unmask != 0:
        num_tokens_to_unmask.append(n_unmask)
        remaining_tokens -= n_unmask
    if remaining_tokens != 0 and self.noise.alpha_0 == 1:
      num_tokens_to_unmask.append(remaining_tokens)
    return num_tokens_to_unmask

  def prior_sample(self, *batch_dims):
    return self.mask_index * torch.ones(
      *batch_dims, dtype=torch.int64, device=self.device)

  def _sort_rotary_cos_sin(self, rotary_cos_sin, sort_idx):
    # example cos shape: (1, 128, 3, 1, 32)
    # 128 for seq_len, 3 for qkv, 32 for head dim
    cos, sin = rotary_cos_sin
    bs = sort_idx.shape[0]
    cos = cos.expand(bs, -1, -1, -1, -1)
    sin = sin.expand(bs, -1, -1, -1, -1)
    cos = torch.gather(
      cos, dim=1,
      index=sort_idx[:, :, None, None, None].expand(
        -1, -1, 3, -1, self.rotary_dim)).contiguous()
    sin = torch.gather(
      sin, dim=1,
      index=sort_idx[:, :, None, None, None].expand(
        -1, -1, 3, -1, self.rotary_dim)).contiguous()
    return cos, sin

  def _diffusion_features(self, zt, sort_idx):
    x = self.backbone.vocab_embed(zt)
    rotary_cos_sin = self.backbone.rotary_emb(x)
    rotary_cos_sin = self._sort_rotary_cos_sin(
      rotary_cos_sin, sort_idx)
    return {'x': x, 'rotary': rotary_cos_sin}

  def _forward_sample(self, zt, sort_idx,
                      last_k_start, curr_k_start, curr_k_end):
    features = self._diffusion_features(zt=zt, sort_idx=sort_idx)
    # The conditioning input passed to sigma_map is zero for every sample.
    zeros = torch.zeros(zt.shape[0], device=zt.device)
    t_cond = F.silu(self.backbone.sigma_map(zeros))

    x = features['x']
    rotary = features['rotary']

    x = x[:, last_k_start:curr_k_end, :]
    cos, sin = rotary
    rotary = (cos[:, :curr_k_end], sin[:, :curr_k_end])
    num_clean = curr_k_start - last_k_start
    num_clean_and_mask = curr_k_end - last_k_start

    with torch.amp.autocast('cuda', enabled=False):
      for i in range(len(self.backbone.blocks)):
        x = self.backbone.blocks[i](
          x, rotary, c=t_cond,
          attn_mask=None,
          kv_cache=True,
          num_clean=num_clean,
          num_clean_and_mask=num_clean_and_mask)
      x = self.backbone.output_layer(x, c=t_cond)

    x = x[:, num_clean:, :]

    return x

  @torch.no_grad()
  def generate_samples(self, num_samples,
                       alpha_0=0, num_diffusion_steps=1000,
                       p_nucleus=0.9, use_float64=True):
    """Generate samples from the model with KV caching enabled."""
    self.noise = LogLinear(alpha_0=alpha_0)

    unmask_k_tokens = self._tokens_unmasked_per_step(
        num_diffusion_steps)
    num_diffusion_tokens = sum(unmask_k_tokens)

    # Draw a random generation order. The first num_diffusion_tokens
    # positions (generated by diffusion) stay shuffled; the remaining
    # positions are sorted so they are generated sequentially, left to right.
    sort_idx = torch.rand(
      num_samples, self.num_tokens).argsort(
        descending=False).to(self.device)
    sort_idx[:, num_diffusion_tokens:] = (
      sort_idx[:, num_diffusion_tokens:].sort().values)

    x = self.prior_sample(num_samples, self.num_tokens)
    x = torch.gather(x, dim=1, index=sort_idx)

    if len(unmask_k_tokens) != 0:
      unmask_k_tokens = unmask_k_tokens + [1] * (
        self.num_tokens - num_diffusion_tokens)
    else:
      unmask_k_tokens = [1] * self.num_tokens
    assert sum(unmask_k_tokens) == self.num_tokens
    noise = torch.distributions.Gumbel(0, 1).sample(
      (num_samples, self.num_tokens,
       self.vocab_size)).to(self.device)
    unmasked_tokens = 0
    self.backbone.reset_kv_cache()
    for i, k in enumerate(unmask_k_tokens):
      if i == 0:
        last_k_start = 0
      else:
        last_k_start = unmasked_tokens - unmask_k_tokens[i-1]
      log_p_x0 = self._forward_sample(
        zt=x,  # shape[1] is model.length
        sort_idx=sort_idx,  # shape[1] is model.length
        last_k_start=last_k_start,
        curr_k_start=unmasked_tokens,  # also last_k_end
        curr_k_end=unmasked_tokens+k)
      if use_float64:
        log_p_x0 = log_p_x0.to(torch.float64)
      log_p_x0[:, :, self.mask_index] = self.neg_infinity
      if p_nucleus < 1:
        # top_k_top_p_filtering takes in logits (normalized or
        # unnormalized) and returns logits (unnormalized)
        log_p_x0 = top_k_top_p_filtering(log_p_x0, top_p=p_nucleus)
      indices = slice(unmasked_tokens, unmasked_tokens + k)
      y = (log_p_x0 + noise[:, indices, :]).argmax(-1)
      x[:, indices] = y
      unmasked_tokens += k
    self.backbone.reset_kv_cache()
    sort_idx_reversed = get_reverse_indices(sort_idx)
    x = torch.gather(x, dim=1, index=sort_idx_reversed)
    return self.tokenizer.batch_decode(x)
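
The line `y = (log_p_x0 + noise).argmax(-1)` in `generate_samples` draws tokens with the Gumbel-max trick: adding independent Gumbel(0, 1) noise to log-probabilities (normalized or not) and taking the argmax yields a sample from the corresponding softmax distribution. The standard-library sketch below checks this empirically on a three-token distribution; the helper names `sample_gumbel` and `gumbel_max_sample` are illustrative and not part of the codebase.

```python
import math
import random


def sample_gumbel(rng):
    """Draw one Gumbel(0, 1) sample via the inverse CDF: -log(-log(U))."""
    u = rng.random()
    while u == 0.0:  # random() returns values in [0, 1); avoid log(0)
        u = rng.random()
    return -math.log(-math.log(u))


def gumbel_max_sample(logits, rng):
    """Return argmax_i(logits[i] + g_i) with g_i ~ Gumbel(0, 1)."""
    noisy = [l + sample_gumbel(rng) for l in logits]
    return max(range(len(logits)), key=noisy.__getitem__)


rng = random.Random(0)
probs = [0.6, 0.3, 0.1]
logits = [math.log(p) for p in probs]
num_draws = 20_000
counts = [0] * len(probs)
for _ in range(num_draws):
    counts[gumbel_max_sample(logits, rng)] += 1
freqs = [c / num_draws for c in counts]
print(freqs)  # empirical frequencies, which should be close to probs
```

Because nucleus filtering sets removed logits to `-inf`, those tokens can never win the argmax, which is why the sampler can apply the noise after filtering without renormalizing.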

In [None]:
sampler = EsoLMBSampler(hf_model, device='cuda')

## Generate samples

$\alpha_0$ is the expected proportion of tokens generated by diffusion; the remaining tokens are generated sequentially, one token per step.
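
To see where this proportion comes from, note that `_tokens_unmasked_per_step` unmasks each still-masked token with probability $(\alpha_s - \alpha_t) / (1 - \alpha_t)$ at every diffusion step. The sketch below (plain Python, with function names of our own choosing) multiplies the per-step probabilities of staying masked. Under the log-linear schedule with `eps = 1e-3` used in this notebook, the product telescopes, so the expected fraction of tokens unmasked by diffusion is close to $\alpha_0$ and does not depend on the number of steps.

```python
def alpha(t, alpha_0, eps=1e-3):
    """Log-linear schedule, matching LogLinear.forward in this notebook."""
    return alpha_0 * (1 - (1 - eps) * t)


def expected_diffusion_fraction(alpha_0, num_steps, eps=1e-3):
    """Expected fraction of tokens unmasked during the diffusion phase."""
    dt = 1 / num_steps
    stay_masked = 1.0
    for i in range(num_steps):
        t = 1 - i * dt  # same grid as np.linspace(1, dt, num_steps)
        a_t = alpha(t, alpha_0, eps)
        a_s = alpha(t - dt, alpha_0, eps)
        # Probability that a masked token is still masked after this step.
        stay_masked *= 1 - (a_s - a_t) / (1 - a_t)
    return 1 - stay_masked


for a0 in (0.0625, 0.25, 1.0):
    for steps in (16, 1024):
        print(a0, steps, expected_diffusion_fraction(a0, steps))
```

The actual count in the sampler is random, since it is drawn step by step from a binomial distribution, so individual runs fluctuate around this expectation.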

$T$ is the number of diffusion steps. $T$ matters less when $\alpha_0^\text{eval}$ is small.

**NOTE**: $\alpha_0^\text{eval}$ used for generating samples can be different from $\alpha_0^\text{train}$ used for training.

$\alpha_0^\text{eval}=1$ with $T=1024$:

In [None]:
# if alpha_0=0.25 for training, gen ppl = 72.36
# if alpha_0=1 for training, gen ppl = 49.4
samples = sampler.generate_samples(
    num_samples=2, alpha_0=1, num_diffusion_steps=1024)
for sample in samples:
  print(sample)
  print('\n' * 5)

 The image in the image was blurry in the form of physical detail that is in between two objects to be blurred.”

The image says the story of what took place after the killing, and Lutmret, whom she described in court, said: “She said it was her God-given Instagram, and she said that it was a little more.“This was a long, die, the meaning: from day one like on a who, from and who is carrying a weapon, [the person carries the weapon], Haidulall told the court, who sought her death by calling for the picture back, “she believed she was actually posting an image of any object that could be the weapon.”

Online reaction

Abeh was shocked when asked about what she was saying, to hear the images in the comments. It seemed as though her friend Uday Saleh was unfamiliar with the question at all.

As the new article about the image, been posted by @Police is using social media media to contact the victim’s family. But that is clearly wrong, but I am now being interviewed by an expert at psychol

$\alpha_0^\text{eval}=0.0625$ with $T=16$ ($T$ matters less when $\alpha_0^\text{eval}$ is small):

In [None]:
# if alpha_0=0.25 for training, gen ppl = 23.95
# if alpha_0=1 for training, gen ppl = 31.33
samples = sampler.generate_samples(
    num_samples=2, alpha_0=0.0625, num_diffusion_steps=16)
for sample in samples:
  print(sample)
  print('\n' * 5)

We will seek congressional action to change his mind and follow the American people's guidance and their verdict."

Sen. John McCain (R-Ariz.), who traveled to Saudi Arabia earlier this week, told reporters he believes that Obama may jeopardize his presidential bid.

"Our people have come to tell us they are ready to go to war," he said. "If politicians and leaders of both our parties in Congress show they are willing to meet with their constituents, then that war has never been and will never be an option."

McCain, who is among a small group of senators who are rallying behind the president, was an early supporter of the Saudi Arabia effort.

Here's why:

(CBS News) Critics warned Obama that Riyadh's young generation may vote Democratic. White House spokesman Josh Moiffer said the administration had never considered the Saudi announcement to sway voters. "There is no question that Saudi Arabia is sending very young Saudi people to Iraq and Yemen and on to terrorist training camps," M