In [None]:
import os, json, math, random
from pathlib import Path
from typing import List, Dict, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

ROOT = Path("/workspace/mnist_ascii_dataset")
TRAIN_MANIFEST = ROOT / "train_manifest.jsonl"
TEST_MANIFEST  = ROOT / "test_manifest.jsonl"

# Model params
W = 20
H = 10
L = W * H

EMB_DIM   = 64
N_LAYERS  = 4
N_HEADS   = 2
FF_DIM    = 256
BETAS     = (1e-4, 2e-2)
TIMESTEPS = 400

BATCH_SIZE   = 32          # consider vram
EPOCHS       = 20
LR           = 1e-3
GRAD_CLIP    = 1.0
GRAD_ACCUM   = 4           # effective batch = BATCH_SIZE * GRAD_ACCUM

CFG_P_DROP   = 0.1         # classifier-free guidance drop prob
CFG_NULL_ID  = 10          # null class id
CFG_SCALE    = 3.0         # guidance at sampling

LAMBDA_CE    = 0.5         # auxiliary CE loss on decoded tokens

SEED = 42
torch.manual_seed(SEED); random.seed(SEED)

torch.backends.cuda.matmul.allow_tf32 = True
torch.set_float32_matmul_precision("high")
try:
    torch.backends.cuda.enable_flash_sdp(True)
    torch.backends.cuda.enable_mem_efficient_sdp(True)
    torch.backends.cuda.enable_math_sdp(False)
except:
    pass
    


In [None]:
import unicodedata

def load_manifest(path: Path) -> List[Dict]:
    items = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if line.strip():
                items.append(json.loads(line))
    return items

# ...existing code...
# ...existing code...
def normalize_ascii_grid(s: str, W: int, H: int, shift_left: bool = False) -> str:
    """
    Normalize raw ASCII-art text into a fixed W*H grid (row-major string).
    """
    s = unicodedata.normalize("NFKC", s)

    subst = {
        '\u00A0': ' ',  '\u2002': ' ', '\u2003': ' ', '\u2004': ' ', '\u2005': ' ',
        '\u2006': ' ', '\u2007': ' ', '\u2008': ' ', '\u2009': ' ', '\u200A': ' ',
        '\u202F': ' ', '\u205F': ' ', '\u3000': ' ',
        '\u2212': '-', '\u2013': '-', '\u2014': '-', '\u2010': '-', '\u2011': '-',
        '\uFF0E': '.', '\u2024': '.', '\uFF0B': '+', '\uFF03': '#', '\uFF0D': '-',
        '\uFF1A': ':', '\uFF1D': '=', '\uFF20': '@', '\uFF0A': '*', '\uFF05': '%'
    }

    out_chars = []
    for ch in s:
        if ch == '\r':
            continue
        if ch == '\n':
            out_chars.append('\n')
            continue
        ch = subst.get(ch, ch)
        if 32 <= ord(ch) <= 126 and ch != '\t':
            out_chars.append(ch)
        else:
            out_chars.append(' ')
    s = "".join(out_chars)

    lines = s.splitlines()

    while lines and not lines[-1].strip():
        lines.pop()

    if shift_left:
        non_blank = [ln for ln in lines if ln.strip()]
        if non_blank:
            min_lead = min(len(ln) - len(ln.lstrip(' ')) for ln in non_blank)
            if min_lead > 0:
                lines = [ln[min_lead:] if len(ln) > min_lead else "" for ln in lines]

    fixed = [(ln[:W] + " " * max(0, W - len(ln))) for ln in lines]
    final_lines = fixed[:H] + [" " * W] * max(0, H - len(fixed))
    grid = "".join(final_lines)

    return grid
# ...existing code...
# ...existing code...



class AsciiDigits(Dataset):
    def __init__(self, manifest_path: Path, W: int, H: int, build_vocab_from: Path):
        self.recs = load_manifest(manifest_path)
        self.vocab = self._build_vocab(build_vocab_from, W, H)
        self.stoi = {ch:i for i,ch in enumerate(self.vocab)}
        self.itos = {i:ch for ch,i in self.stoi.items()}
        
    def _build_vocab(self, manifest_path: Path, W: int, H: int) -> List[str]:
        chars = set([" "])
        for i, entry in enumerate(load_manifest(manifest_path)):
            txt = Path(entry["ascii_txt_path"]).read_text(encoding="utf-8", errors="ignore")
            grid = normalize_ascii_grid(txt, W, H, shift_left=True)
            # print(grid)
            chars.update(set(grid))
        return sorted(list(chars))

    def __len__(self): return len(self.recs)

    def __getitem__(self, idx):
        j = self.recs[idx]
        txt = Path(j["ascii_txt_path"]).read_text(encoding="utf-8", errors="ignore")
        # if idx == 0:
        #     print(txt)
        grid = normalize_ascii_grid(txt, W, H, shift_left=True)   # length L
        # if idx == 0:
        #     print(grid)
        # Map to token ids (unknowns -> space)
        # print(grid)
        ids = torch.tensor([ self.stoi.get(ch, self.stoi[" "]) for ch in grid], dtype=torch.long)
        label = torch.tensor(int(j["label"]), dtype=torch.long)
        return ids, label


train_ds = AsciiDigits(TRAIN_MANIFEST, W, H, build_vocab_from=TRAIN_MANIFEST)
test_ds  = AsciiDigits(TEST_MANIFEST,  W, H, build_vocab_from=TRAIN_MANIFEST)  # build vocab from train

VOCAB = train_ds.vocab
V = len(VOCAB)
print(f"Vocab size: {V}  (example: {VOCAB[:20]}...) | Sequence length L={L}")
tensor, label = train_ds[20]
print(tensor)
print(label)

ids, lbl = train_ds[20]
print("Label:", lbl)
print("Unique id values:", torch.unique(ids))
print("Decoded (first 200 chars):")
print("".join(train_ds.itos[i.item()] for i in ids[:200]).replace(" ", "·"))

raw_path = Path(train_ds.recs[20]["ascii_txt_path"])
raw = raw_path.read_text(encoding="utf-8", errors="ignore")
print("RAW UNIQUE CODEPOINTS (sample 20):")
print(sorted({(repr(c), hex(ord(c))) for c in raw if c not in '\n\r'}))


Vocab size: 10  (example: [' ', '#', '%', '*', '+', '-', '.', ':', '=', '@']...) | Sequence length L=200
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 6, 5, 0, 0, 0, 0, 0, 0, 0, 0, 8, 2, 5, 0, 0, 0, 0, 0,
        0, 0, 8, 9, 5, 0, 0, 0, 0, 0, 0, 0, 7, 2, 9, 3, 0, 0, 0, 0, 0, 0, 1, 9,
        4, 0, 0, 0, 0, 0, 0, 0, 0, 7, 9, 9, 4, 0, 6, 7, 4, 3, 9, 9, 5, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 5, 9, 9, 9, 9, 9, 9, 2, 9, 9, 4, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 3, 3, 8, 5, 6, 0, 0, 9, 9, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 4, 9, 9, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 6, 8, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0])
tensor(4)
Label: tensor(4)
Unique id values: tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
Decoded (first 200 chars):
······························.-········=%-·······=@-·······:%@*······#@+········:@@+·.:+*@@-·········-@@@@@@%@

In [3]:
def make_beta_schedule(timesteps: int, beta_start: float, beta_end: float):
    betas = torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float32)
    alphas = 1.0 - betas
    alpha_bar = torch.cumprod(alphas, dim=0)
    return betas, alphas, alpha_bar

BETAS_T, ALPHAS_T, ALPHAS_BAR_T = make_beta_schedule(TIMESTEPS, *BETAS)
BETAS_T  = BETAS_T.to(device)
ALPHAS_T = ALPHAS_T.to(device)
ALPHAS_BAR_T = ALPHAS_BAR_T.to(device)

def t_to_alpha_bar(t: torch.Tensor) -> torch.Tensor:
    # t in [0, T-1], integer
    return ALPHAS_BAR_T[t]

def sinusoidal_embedding(timesteps: torch.Tensor, dim: int) -> torch.Tensor:
    """
    Standard transformer-style time embeddings.
    timesteps: (B,)
    returns: (B, dim)
    """
    device = timesteps.device
    half = dim // 2
    freqs = torch.exp(
        torch.linspace(math.log(1e-4), math.log(1.0), half, device=device)
    )
    # Shape: (B, half)
    args = timesteps.float().unsqueeze(1) * freqs.unsqueeze(0)
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=1)
    if dim % 2 == 1:
        emb = F.pad(emb, (0,1))
    return emb


In [4]:
import math, torch.nn as nn, torch.nn.functional as F
from torch.utils.checkpoint import checkpoint_sequential

class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(1))
    def forward(self, x):  # (L,B,D)
        L = x.size(0)
        return x + self.pe[:L]

def sinusoidal_embedding(timesteps: torch.Tensor, dim: int) -> torch.Tensor:
    half = dim // 2
    freqs = torch.exp(torch.linspace(math.log(1e-4), math.log(1.0), half, device=timesteps.device))
    args = timesteps.float().unsqueeze(1) * freqs.unsqueeze(0)
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=1)
    if dim % 2 == 1: emb = F.pad(emb, (0,1))
    return emb

class Denoiser(nn.Module):
    def __init__(self, vocab_size: int, emb_dim: int, n_layers: int, n_heads: int, ff_dim: int, seq_len: int, num_classes: int=11, use_ckpt=True):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, emb_dim)
        self.cls_emb   = nn.Embedding(num_classes, emb_dim)   # includes null class at id=CFG_NULL_ID
        self.time_mlp  = nn.Sequential(nn.Linear(emb_dim, ff_dim), nn.SiLU(), nn.Linear(ff_dim, emb_dim))
        self.pos = PositionalEncoding(emb_dim, seq_len)

        layer = nn.TransformerEncoderLayer(d_model=emb_dim, nhead=n_heads, dim_feedforward=ff_dim,
                                           batch_first=False, activation="gelu", norm_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers, enable_nested_tensor=False)
        self.out = nn.Linear(emb_dim, emb_dim)  # predict noise in embedding space
        self.use_ckpt = use_ckpt

        nn.init.normal_(self.token_emb.weight, std=0.02)
        nn.init.normal_(self.cls_emb.weight, std=0.02)

    def forward(self, x_t: torch.Tensor, t: torch.Tensor, y: torch.Tensor):
        B,L,D = x_t.shape
        t_emb = self.time_mlp(sinusoidal_embedding(t, D))   # (B,D)
        c_emb = self.cls_emb(y)                             # (B,D)
        h = x_t + (t_emb + c_emb).unsqueeze(1)              # (B,L,D)
        h = h.transpose(0,1)                                 # (L,B,D)
        h = self.pos(h)
        if self.training and self.use_ckpt:
            h = checkpoint_sequential(self.encoder.layers, len(self.encoder.layers), h)
        else:
            h = self.encoder(h)
        h = h.transpose(0,1)                                 # (B,L,D)
        return self.out(h)


In [5]:

def collate(batch):
    ids, labels = zip(*batch)
    ids = torch.stack(ids, dim=0)        # (B, L)
    labels = torch.stack(labels, dim=0)  # (B,)
    return ids, labels

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True, collate_fn=collate)
test_loader  = DataLoader(test_ds,  batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True, collate_fn=collate)


In [None]:
model = Denoiser(vocab_size=len(train_ds.vocab), emb_dim=EMB_DIM, n_layers=N_LAYERS,
                 n_heads=N_HEADS, ff_dim=FF_DIM, seq_len=L, num_classes=11, use_ckpt=True).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=LR)
scaler = torch.amp.GradScaler('cuda', enabled=(device.type == "cuda"))

SPACE_ID = train_ds.stoi[" "]
Wt = model.token_emb.weight.t()   # (D,V)

def position_weights(ids):
    w = torch.full_like(ids, 0.3, dtype=torch.float32, device=ids.device)
    return torch.where(ids != SPACE_ID, 1.0, w)

def get_token_embeddings(ids):     # (B,L) -> (B,L,D)
    return model.token_emb(ids.to(device))

def decode_from_embeddings(e: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
    # e has shape (B, L, D)
    logits = e @ model.token_emb.weight.t()

    # apply temperature to the logits
    # higher temperature makes the distribution flatter
    # lower temperature makes it sharper (like argmax)
    if temperature > 0:
        logits /= temperature
        probs = F.softmax(logits, dim=-1)
        B, L, V = probs.shape
        sampled_ids = torch.multinomial(probs.view(B * L, V), num_samples=1)
        return sampled_ids.view(B, L)
    else:
        return logits.argmax(dim=-1)

def train_one_epoch(epoch: int):
    model.train()
    total, n = 0.0, 0
    opt.zero_grad(set_to_none=True)

    for step, (ids, labels) in enumerate(train_loader):
        ids = ids.to(device); labels = labels.to(device)
        e0 = get_token_embeddings(ids)
        B = e0.size(0)
        t = torch.randint(0, TIMESTEPS, (B,), device=device, dtype=torch.long)
        eps = torch.randn_like(e0)
        x_t = torch.sqrt(ALPHAS_BAR_T[t]).view(-1,1,1)*e0 + torch.sqrt(1-ALPHAS_BAR_T[t]).view(-1,1,1)*eps

        # classifier-free: drop some labels to null class
        y_in = labels.clone()
        drop = torch.rand(B, device=device) < CFG_P_DROP
        y_in[drop] = CFG_NULL_ID

        with torch.amp.autocast('cuda', enabled=(device.type == "cuda")):
            eps_hat = model(x_t, t, y_in)
            loss_mse = F.mse_loss(eps_hat, eps)

            # x0_hat for CE aux loss
            a_bar = ALPHAS_BAR_T[t].view(-1,1,1)
            x0_hat = (x_t - torch.sqrt(1 - a_bar) * eps_hat) / torch.sqrt(a_bar + 1e-8)

            logits = (x0_hat @ Wt).reshape(-1, len(train_ds.vocab))   # (B*L, V)
            ce = F.cross_entropy(logits, ids.reshape(-1), ignore_index=SPACE_ID, reduction='none')
            ce = ce.view(B, -1)
            # w = position_weights(ids)
            # loss_ce = (ce * w).mean()
            loss_ce = ce.mean() 

            loss = loss_mse + LAMBDA_CE * loss_ce

        scaler.scale(loss).backward()
        nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)

        if (step + 1) % GRAD_ACCUM == 0:
            scaler.step(opt); scaler.update()
            opt.zero_grad(set_to_none=True)

        total += float(loss.detach()) * B
        n += B

    return total / max(1, n)

@torch.no_grad()
def evaluate(n_batches: int = 50):
    model.eval()
    total, n = 0.0, 0
    for i, (ids, labels) in enumerate(test_loader):
        if i >= n_batches: break
        ids = ids.to(device); labels = labels.to(device)
        e0 = get_token_embeddings(ids)
        B = e0.size(0)
        t = torch.randint(0, TIMESTEPS, (B,), device=device, dtype=torch.long)
        eps = torch.randn_like(e0)
        x_t = torch.sqrt(ALPHAS_BAR_T[t]).view(-1,1,1)*e0 + torch.sqrt(1-ALPHAS_BAR_T[t]).view(-1,1,1)*eps
        with torch.amp.autocast('cuda', enabled=(device.type == "cuda")):
            eps_hat = model(x_t, t, labels)
            loss = F.mse_loss(eps_hat, eps)
        total += float(loss.detach()) * B
        n += B
    return total / max(1, n)


In [None]:
for epoch in range(1, EPOCHS+1):
    tr = train_one_epoch(epoch)
    va = evaluate()
    print(f"Epoch {epoch}/{EPOCHS} | train {tr:.4f} | val {va:.4f}")

CKPT = ROOT / f"ascii_diffusion_e{EPOCHS}_w{W}_h{H}.pt"
torch.save({
    "model": model.state_dict(),
    "vocab": VOCAB,
    "config": dict(W=W, H=H, L=L, emb_dim=EMB_DIM, steps=TIMESTEPS, betas=BETAS),
}, CKPT)
print("Saved:", CKPT)





Epoch 1/20 | train 1.1055 | val 0.5174
Epoch 2/20 | train 0.7558 | val 0.4023
Epoch 3/20 | train 0.6711 | val 0.3350
Epoch 4/20 | train 0.6289 | val 0.3012
Epoch 5/20 | train 0.5976 | val 0.2734
Epoch 6/20 | train 0.5716 | val 0.2534
Epoch 7/20 | train 0.5540 | val 0.2398
Epoch 8/20 | train 0.5384 | val 0.2252
Epoch 9/20 | train 0.5248 | val 0.2133
Epoch 10/20 | train 0.5125 | val 0.2079
Epoch 11/20 | train 0.5022 | val 0.2111
Epoch 12/20 | train 0.4956 | val 0.1978
Epoch 13/20 | train 0.4876 | val 0.1953
Epoch 14/20 | train 0.4803 | val 0.1938
Epoch 15/20 | train 0.4748 | val 0.1894
Epoch 16/20 | train 0.4668 | val 0.1831
Epoch 17/20 | train 0.4610 | val 0.1819
Epoch 18/20 | train 0.4578 | val 0.1862
Epoch 19/20 | train 0.4509 | val 0.1832
Epoch 20/20 | train 0.4494 | val 0.1806
Saved: /workspace/mnist_ascii_dataset/ascii_diffusion_e20_w20_h10.pt


In [None]:
ckpt = torch.load(CKPT, map_location=device)
model.load_state_dict(ckpt["model"])
VOCAB = ckpt["vocab"]
W = ckpt["config"]["W"]; H = ckpt["config"]["H"]; L = ckpt["config"]["L"]

@torch.no_grad()
def p_sample_loop(digit: int, steps: int = TIMESTEPS) -> torch.Tensor:
    model.eval()
    B = 1
    x = torch.randn(B, L, EMB_DIM, device=device)
    y_cond = torch.tensor([digit], device=device, dtype=torch.long)
    y_null = torch.tensor([CFG_NULL_ID], device=device, dtype=torch.long)

    for t_int in reversed(range(steps)):
        t = torch.tensor([t_int], device=device, dtype=torch.long)

        eps_uncond = model(x, t, y_null)
        eps_cond   = model(x, t, y_cond)
        eps_hat = eps_uncond + CFG_SCALE * (eps_cond - eps_uncond)

        alpha_t = ALPHAS_T[t_int]
        alpha_bar_t = ALPHAS_BAR_T[t_int]
        beta_t = BETAS_T[t_int]
        alpha_bar_tm1 = ALPHAS_BAR_T[t_int-1] if t_int > 0 else torch.tensor(1.0, device=device)

        # posterior mean
        mean = (1.0 / torch.sqrt(alpha_t)) * (x - ((1 - alpha_t) / torch.sqrt(1 - alpha_bar_t + 1e-8)) * eps_hat)
        # posterior variance
        posterior_var = beta_t * (1 - alpha_bar_tm1) / (1 - alpha_bar_t + 1e-8)
        if t_int > 0:
            x = mean + torch.sqrt(posterior_var.clamp(min=1e-20)) * torch.randn_like(x)
        else:
            x = mean
    return x

def ids_to_ascii(ids_1d: torch.Tensor) -> str:
    s = "".join(VOCAB[i] for i in ids_1d.tolist())
    return "\n".join(s[i:i+W] for i in range(0, len(s), W))

def sample_ascii(digit: int, steps: int = TIMESTEPS, temperature: float = 1.0) -> str:
    e = p_sample_loop(digit, steps=steps)      # (1, L, D)
    ids = decode_from_embeddings(e, temperature=temperature).squeeze(0)  # (L,)
    return ids_to_ascii(ids)



In [9]:
print(sample_ascii(5))

 :.@.:.+.-..-- .  .:
: -==..+-...-.+: .::
:.#..++.:.@.#@@@-.-:
..-. .=.-+:%..#%@-::
 .-.+#: #+ : =:=.=..
.@- :=.:: %@+...-.%=
:..:-=..@::%@::.*..%
%:@.-%#:#=:.::-.#..-
-+-+--+.:+..+.. *+:#
# -*.%.#:.. :.@+:...


In [10]:
tensor, label = train_ds[50]
print(tensor)
print(label)


tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 7, 5,
        8, 8, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 5, 5, 7, 6, 1, 1,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 1, 7, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 3, 3, 1, 4, 4, 8, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 6, 2, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 2, 4, 5,
        5, 8, 3, 1, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 8, 4, 4, 5, 7, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0])
tensor(3)


In [11]:
from collections import Counter
w_counts, h_counts = Counter(), Counter()
for rec in train_ds.recs:  # or slice for a sample
    s = Path(rec["ascii_txt_path"]).read_text(encoding="utf-8", errors="ignore")
    ls = s.splitlines()
    h_counts[len(ls)] += 1
    w_counts[max((len(ln) for ln in ls), default=0)] += 1

print("Top widths:", w_counts.most_common(5))
print("Top heights:", h_counts.most_common(5))


Top widths: [(20, 10000)]
Top heights: [(10, 10000)]
