In [None]:
from dataclasses import dataclass, field
from typing import Iterator

from src.utils import get_vocabulary

In [None]:
def is_palindrome(words: list[str]) -> bool:
    """Return True when the words, concatenated without separators, read the same both ways."""
    joined = "".join(words)
    # Compare each character with its mirror; the middle character (if any) needs no check.
    return all(joined[i] == joined[-1 - i] for i in range(len(joined) // 2))


def find_palindrome_centers(words: list[str]) -> list[tuple[str, int]]:
    """
    Locate off-center characters that a palindrome inside a word can pivot on.

    For every word, each index ``pos`` with at least two characters on both
    sides is considered, except the word's own midpoint (``len(word) // 2``).
    The index qualifies when the characters immediately before it mirror the
    characters immediately after it, over a window as long as the shorter of
    the two sides.

    Words shorter than five characters have no candidate index and therefore
    never contribute.

    Returns a list of ``(word, pos)`` tuples in input order, with positions
    ascending within each word.
    """
    centers = []

    for candidate in words:
        size = len(candidate)
        midpoint = size // 2

        # range() is empty when size < 5, so short words fall through naturally.
        for idx in range(2, size - 2):
            if idx == midpoint:
                continue

            # The comparison window is bounded by whichever side is shorter.
            window = min(idx, size - idx - 1)
            before = candidate[idx - window : idx]
            after = candidate[idx + 1 : idx + 1 + window]

            if before == after[::-1]:
                centers.append((candidate, idx))

    return centers


@dataclass
class PalindromeFinder:
    """
    Grow multi-word palindromes outward from a fixed center character.

    The finder indexes every non-empty prefix and suffix of every vocabulary
    word so that "which words start (or end) with this exact string" is a
    single dictionary lookup.
    """

    # Words that may be added to either end of a growing palindrome.
    vocabulary: set[str] = field(default_factory=set)
    # prefix -> all vocabulary words starting with that prefix.
    prefix_cache: dict[str, set[str]] = field(default_factory=dict)
    # suffix -> all vocabulary words ending with that suffix.
    suffix_cache: dict[str, set[str]] = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Populate the prefix/suffix caches from the vocabulary."""
        self._build_caches()

    def _build_caches(self) -> None:
        """Register every word under each of its non-empty prefixes and suffixes."""
        for word in self.vocabulary:
            for i in range(1, len(word) + 1):
                prefix, suffix = word[:i], word[-i:]
                self.prefix_cache.setdefault(prefix, set()).add(word)
                self.suffix_cache.setdefault(suffix, set()).add(word)

    @staticmethod
    def is_palindrome(words: list[str]) -> bool:
        """Return True when the words, joined without separators, form a palindrome."""
        s = "".join(words)
        return s == s[::-1]

    def find_matches(self, pattern: str, match_start: bool = True) -> set[str]:
        """
        Return the vocabulary words that start with ``pattern`` (``match_start=True``)
        or end with it (``match_start=False``).

        For a known pattern the returned set is the cached set itself, so
        callers should treat it as read-only. Unknown patterns give an empty set.
        """
        return (
            self.prefix_cache.get(pattern, set())
            if match_start
            else self.suffix_cache.get(pattern, set())
        )

    def find_mismatch(self, words: list[str], center_pos: int) -> tuple[str, bool]:
        """
        Work out which side of the center still lacks a mirror image.

        ``center_pos`` is the index, in the concatenated words, of the pivot
        character. Characters are compared outward from the pivot until they
        differ or one side runs out.

        Returns ``(unmatched_portion, needs_right_match)``: the longer leftover
        piece, and True when that piece is on the left (so new text must be
        appended on the right). Ties resolve to the left piece.
        """
        s = "".join(words)
        left, right = s[:center_pos], s[center_pos + 1 :]

        # Count how many characters mirror each other around the pivot.
        match_len = 0
        for i in range(min(len(left), len(right))):
            if left[-(i + 1)] != right[i]:
                break
            match_len = i + 1

        # A slice with -0 would be empty, hence the explicit zero case.
        left_unmatched = left[:-match_len] if match_len else left
        right_unmatched = right[match_len:] if match_len else right

        return (
            (left_unmatched, True)
            if len(left_unmatched) >= len(right_unmatched)
            else (right_unmatched, False)
        )

    def grow_palindromes(
        self, words: list[str], center_pos: int, depth: int = 5
    ) -> Iterator[str]:
        """
        Recursively extend ``words`` with vocabulary words until palindromes appear.

        ``center_pos`` is the pivot index in the concatenated words and
        ``depth`` bounds the number of recursion levels; nothing is yielded
        once it reaches zero. Each completed palindrome is yielded exactly
        once as a space-separated string, and a completed palindrome is not
        extended further.
        """
        if depth <= 0:
            return

        if self.is_palindrome(words):
            yield " ".join(words)
            return

        mismatch, needs_right = self.find_mismatch(words, center_pos)
        if not mismatch:
            return

        # A word mirrors the leftover text when it starts (or ends) with its reverse.
        pattern = mismatch[::-1]
        matches = self.find_matches(pattern, match_start=needs_right)

        for word in matches:
            new_words = words + [word] if needs_right else [word] + words

            if self.is_palindrome(new_words):
                # Yield here and stop: recursing into a finished palindrome
                # would only yield the same string a second time.
                yield " ".join(new_words)
                continue

            # Prepending shifts the pivot right by the length of the new word.
            new_center = center_pos if needs_right else center_pos + len(word)
            yield from self.grow_palindromes(new_words, new_center, depth - 1)


def filter_palindrome(palindrome: str, min_avg_length: int = 5) -> bool:
    """
    Decide whether a space-separated palindrome is worth keeping.

    Returns True only when all of the following hold:
    - the string contains at least one word,
    - no word occurs more than once,
    - the mean word length is at least ``min_avg_length``,
    - none of the first word's prefixes of length 2 through 6 is itself a
      palindrome. Slices are clipped to the word, so a one-character first
      word is always rejected.
    """
    words = palindrome.split()

    # An empty or whitespace-only string has no words; reject it instead of
    # dividing by zero below.
    if not words:
        return False

    # Check for repeated words
    if len(set(words)) != len(words):
        return False

    # Check average word length
    if sum(len(word) for word in words) / len(words) < min_avg_length:
        return False

    # Check first word for palindrome prefix
    first_word = words[0]
    prefixes = (first_word[:i] for i in range(2, 7))
    return not any(prefix == prefix[::-1] for prefix in prefixes)

In [None]:
# Load the vocabulary (top_n=50000; presumably the most frequent words — confirm
# against src.utils.get_vocabulary) and dump it for inspection.
vocabulary = get_vocabulary(top_n=50000)
print(vocabulary)

# Report every in-word palindrome pivot, first in full and then as a short summary.
results = find_palindrome_centers(vocabulary)
print(results)
print(f"{len(results)} potential palindrome centers: {results[:5]}...")

finder = PalindromeFinder(vocabulary)
print("Finding palindromes...")

# Grow palindromes from the seed word "be" and print only those passing the filter.
for palindrome in filter(
    filter_palindrome,
    finder.grow_palindromes(["be"], center_pos=0, depth=6),
):
    print(palindrome)

## Diffusion Language Models


https://colab.research.google.com/drive/1506OVHAzm7jIleB7oHGoSbuQuytaFWvu#scrollTo=ewd0Ky7VOhqq
https://github.com/ML-GSAI/LLaDA/tree/main
https://ml-gsai.github.io/LLaDA-demo/
https://huggingface.co/GSAI-ML/LLaDA-8B-Base
https://github.com/hamishivi/tess-2


In [2]:
import torch
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM
import matplotlib.pyplot as plt

# import F
import torch.nn.functional as F

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

Using device: cpu


In [3]:
model_name = "GSAI-ML/LLaDA-8B-Base"

# NOTE(review): trust_remote_code=True runs Python files shipped with the
# checkpoint (configuration_llada.py / modeling_llada.py per the download
# warning). Consider pinning a `revision` so that code cannot change silently.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# Load the weights in bfloat16 and place the model on the selected device.
model = AutoModel.from_pretrained(
    model_name,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
).to(device)

print(f"Model loaded: {model.__class__.__name__}")
print(f"Tokenizer loaded: {tokenizer.__class__.__name__}")

print("Model config:")
print(model.config)

A new version of the following files was downloaded from https://huggingface.co/GSAI-ML/LLaDA-8B-Base:
- configuration_llada.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
A new version of the following files was downloaded from https://huggingface.co/GSAI-ML/LLaDA-8B-Base:
- modeling_llada.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
Downloading shards: 100%|██████████| 6/6 [03:16<00:00, 32.71s/it]
Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]

In [None]:
prompt = "Hi!"
num_steps = 1
print(f"\nGenerating from prompt: '{prompt}'")

# Encode the prompt into a (1, prompt_len) tensor of token ids on the target device.
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
print(f"Input shape: {input_ids.shape}")

# Simplex representation, following convert_to_simplex from sdlm/utils.py:
# +simplex_value at the token's vocabulary index and -simplex_value elsewhere.
simplex_value = 5.0  # From SDLM code examples
vocab_size = model.config.vocab_size
simplex = 2 * simplex_value * F.one_hot(input_ids, vocab_size).float() - simplex_value
simplex = simplex.to(device)
print(f"Simplex shape: {simplex.shape}")

# Span mask convention: True = position to generate, False = fixed prompt token.
# Every prompt position is fixed.
span_mask = torch.zeros_like(input_ids, dtype=torch.bool)
prefix_len = input_ids.shape[1]  # Use input as prefix

# Reserve room for the tokens to be generated, using id 0 as a placeholder.
max_new_tokens = 10
padding = input_ids.new_zeros((1, max_new_tokens))
full_input_ids = torch.cat([input_ids, padding], dim=1)

# The new positions start from Gaussian noise scaled to the simplex magnitude.
noise_shape = (1, max_new_tokens, vocab_size)
random_simplex = simplex_value * torch.randn(noise_shape).to(device)
full_simplex = torch.cat([simplex, random_simplex], dim=1)

# Full mask: the prompt part stays False, every appended position is True.
full_span_mask = torch.cat(
    [span_mask, torch.ones_like(padding, dtype=torch.bool)], dim=1
)

print(f"Full input shape: {full_input_ids.shape}")
print(f"Full simplex shape: {full_simplex.shape}")
print(f"Full span mask shape: {full_span_mask.shape}")
print(f"Positions to generate (span_mask=True): {full_span_mask.sum().item()}")

In [None]:
results = []

# The embedding layer does not change between steps, so fetch it once.
embedding_layer = model.get_input_embeddings()
embedding_weight = embedding_layer.weight
# The simplex/timestep tensors are float32 while the model may be loaded in a
# lower precision (bfloat16 here); everything fed to the model is cast to this.
embed_dtype = embedding_weight.dtype

# Start diffusion process - from high noise to low noise
for step in range(num_steps, -1, -1):
    # Noise level in [0, 1]; guard against num_steps == 0 to avoid dividing by zero.
    noise_level = step / num_steps if num_steps else 0.0
    timestep = torch.ones_like(full_input_ids).float() * noise_level
    timestep = timestep.to(device)
    print(f"\nStep {num_steps-step}/{num_steps}: Timestep={timestep[0, 0].item():.2f}")

    # In SDLM, there would be a self-conditioning step here using previous predictions
    # For simplicity, we'll skip that

    # Forward pass through model
    with torch.no_grad():
        # Similar to warp_timesteps in sdlm models
        print("  Running model inference...")

        # We'd need the actual timestep embedding for full implementation
        # For now, we'll assume the model knows what to do with the timestep

        # Turn simplex values into probabilities with softmax (kept in float32)
        probs = torch.nn.functional.softmax(full_simplex, dim=-1)

        # Probability-weighted mix of token embeddings, similar to
        # model.vocab_to_hidden_dim_embed in SDLM. matmul needs both operands
        # in the same dtype, so cast the probabilities to the weight dtype.
        inputs_embeds = torch.matmul(probs.to(embed_dtype), embedding_weight)

        # For actual TESS2, we'd need proper timestep embedding
        # Here we do a simple approach; cast so the sum keeps the model dtype
        timestep_embed = (timestep.unsqueeze(-1) * 0.1).to(embed_dtype)

        # Apply span mask - use original word embeddings for input tokens
        original_embeds = embedding_layer(full_input_ids)

        # Combine embeddings based on span mask
        combined_embeds = torch.where(
            full_span_mask.unsqueeze(-1),
            inputs_embeds + timestep_embed,
            original_embeds,
        )

        # Run model
        # NOTE(review): assumes the remote-code model output exposes
        # `last_hidden_state` and that the input embedding matrix doubles as
        # the output projection — confirm against modeling_llada.py.
        outputs = model(inputs_embeds=combined_embeds)
        hidden_states = outputs.last_hidden_state

        # Project back to vocabulary space
        logits = torch.matmul(hidden_states, embedding_weight.transpose(0, 1))

        # Greedy selection (from sdlm/inference/inference_utils.py)
        selected_tokens = torch.argmax(logits, dim=-1)

        # Simplex encoding of the selected tokens
        new_simplex = 2 * simplex_value * torch.nn.functional.one_hot(
            selected_tokens, vocab_size
        ).float() - simplex_value

        # Move generated positions 10% of the way toward the new prediction;
        # prompt positions keep their simplex unchanged.
        full_simplex = torch.where(
            full_span_mask.unsqueeze(-1),
            (1 - 0.1) * full_simplex + 0.1 * new_simplex,  # Simple update rule
            full_simplex,
        )

        # Update the input_ids with newly selected tokens
        full_input_ids = torch.where(
            full_span_mask,
            selected_tokens,
            full_input_ids,
        )

        # Print current state
        generated_text = tokenizer.decode(full_input_ids[0], skip_special_tokens=True)
        print(f"  Generated so far: {generated_text}")
        results.append(generated_text)

print("\nFinal generation:", results[-1])