In [None]:
# tiny_shakespeare_fractal_rg_lm.py
import os
import math
import time
import requests
from dataclasses import dataclass
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


# ================== 配置 ==================

@dataclass
class Config:
    """Hyper-parameters for the data source, training schedule and model size."""

    # --- data ---
    data_url: str = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
    data_path: str = "input.txt"

    # --- training schedule ---
    batch_size: int = 64
    block_size: int = 128  # context length (tokens per training window)
    max_iters: int = 20_000
    eval_interval: int = 500
    eval_iters: int = 50

    # --- model size ---
    n_layer: int = 4
    n_head: int = 4
    n_embd: int = 256
    d_ff: int = 1_024

    # --- optimisation / regularisation ---
    lr: float = 3e-4
    dropout: float = 0.1

    # Architecture switch: True = FractalRG distance-penalised attention,
    # False = plain Transformer attention.
    use_fractal: bool = True


cfg = Config()
# Prefer the GPU when one is visible to PyTorch.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using device:", device, " (FractalRG:", cfg.use_fractal, ")")


# ================== 数据加载：tiny-shakespeare ==================

def download_tiny_shakespeare(cfg: Config):
    """Download the tiny-shakespeare corpus to ``cfg.data_path`` unless it already exists.

    The text is first written to a sibling ``.part`` file and then moved into
    place with ``os.replace``. An interrupted download therefore never leaves a
    truncated ``cfg.data_path`` that the existence check would accept on the
    next run.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.Timeout: if connecting or reading stalls for more than 30 s.
    """
    if os.path.exists(cfg.data_path):
        print(f"Found {cfg.data_path}, skip download.")
        return
    print("Downloading tiny-shakespeare...")
    # Without a timeout, requests can block forever on a stalled connection.
    r = requests.get(cfg.data_url, timeout=30)
    r.raise_for_status()
    tmp_path = f"{cfg.data_path}.part"
    with open(tmp_path, "w", encoding="utf-8") as f:
        f.write(r.text)
    os.replace(tmp_path, cfg.data_path)
    print("Download finished.")


download_tiny_shakespeare(cfg)

with open(cfg.data_path, "r", encoding="utf-8") as f:
    text = f.read()

print("Data length:", len(text))

# Character-level vocabulary: the distinct characters of the corpus, sorted so
# that the id assignment is deterministic.
chars = sorted(set(text))
vocab_size = len(chars)
print("Vocab size:", vocab_size)

# Lookup tables between integer ids and characters.
itos = dict(enumerate(chars))
stoi = {ch: i for i, ch in itos.items()}

def encode(s: str):
    """Map each character of ``s`` to its vocabulary id (KeyError for unknown characters)."""
    return list(map(stoi.__getitem__, s))

def decode(ids):
    """Map a sequence of vocabulary ids back to a string.

    Accepts any iterable of ints, and also a ``torch.Tensor`` of ids such as a
    row of model output. A tensor is converted with ``tolist()`` first:
    iterating a tensor yields 0-d tensors, which hash by identity and would
    never be found in ``itos``.
    """
    if isinstance(ids, torch.Tensor):
        ids = ids.tolist()
    return "".join(itos[i] for i in ids)

# Encode the whole corpus once; the first 90% is training data, the rest validation.
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]


def get_batch(split: str, cfg: Config):
    """Sample a random batch of next-token-prediction windows.

    Args:
        split: ``"train"`` reads ``train_data``; any other value reads ``val_data``.
        cfg: supplies ``batch_size`` and ``block_size``.

    Returns:
        ``(x, y)``, both ``[batch_size, block_size]`` long tensors on ``device``,
        where ``y`` is ``x`` shifted one token to the right.

    Raises:
        ValueError: if the split has fewer than ``block_size + 1`` tokens.
    """
    d = train_data if split == "train" else val_data
    # A window starting at i needs tokens i .. i + block_size (x plus the shifted y),
    # so valid starts are 0 .. len(d) - block_size - 1; randint's bound is exclusive.
    max_start = len(d) - cfg.block_size
    if max_start <= 0:
        raise ValueError(
            f"split {split!r} has {len(d)} tokens, need at least {cfg.block_size + 1}"
        )
    ix = torch.randint(max_start, (cfg.batch_size,))
    offsets = ix.unsqueeze(1) + torch.arange(cfg.block_size)  # [B, T]
    x = d[offsets]
    y = d[offsets + 1]
    return x.to(device), y.to(device)


# ================== 一些 Attention 工具函数 ==================

def causal_mask(L: int, device: torch.device):
    """Return an ``[L, L]`` bool mask that is True strictly above the diagonal.

    Entry ``[i, j]`` is True when ``j > i``, i.e. key ``j`` lies in the future of
    query ``i`` and must be hidden.
    """
    rows = torch.arange(L, device=device).unsqueeze(1)
    cols = torch.arange(L, device=device).unsqueeze(0)
    return cols > rows


def coarse_grain_1d(x: torch.Tensor) -> torch.Tensor:
    """Average non-overlapping pairs along the sequence axis.

    ``[B, L, D] -> [B, L // 2, D]``: output position k is the mean of inputs
    2k and 2k + 1. When L is odd the last position is dropped.
    """
    batch, length, dim = x.shape
    n_pairs = length // 2
    paired = x[:, : 2 * n_pairs, :].view(batch, n_pairs, 2, dim)
    return paired.mean(dim=2)


def upsample_to_length(x: torch.Tensor, target_len: int) -> torch.Tensor:
    """Linearly interpolate ``[B, Ls, D]`` along the sequence axis to ``[B, target_len, D]``.

    The input is returned unchanged (same object) when its length already
    matches. Interpolation mixes neighbouring positions in both directions,
    so it is not causal.
    """
    _, src_len, _ = x.shape
    if src_len == target_len:
        return x
    channels_first = x.permute(0, 2, 1)  # [B, D, Ls]; interpolate works on the last axis
    resized = F.interpolate(channels_first, size=target_len, mode="linear", align_corners=False)
    return resized.permute(0, 2, 1)


# ================== Baseline: Causal Multihead Self-Attention ==================

class MultiHeadSelfAttentionCausal(nn.Module):
    """Baseline causal multi-head self-attention with separate Q/K/V/output projections."""

    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads

        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.o_proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor):
        """Attend over ``x`` of shape ``[B, T, D]``; returns ``[B, T, D]``."""
        batch, seq, width = x.shape
        heads, hd = self.n_heads, self.head_dim

        # Project, then split the model dimension into heads: [B, H, T, d_h].
        q, k, v = (
            proj(x).view(batch, seq, heads, hd).transpose(1, 2)
            for proj in (self.q_proj, self.k_proj, self.v_proj)
        )

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(hd)  # [B, H, T, T]

        # True strictly above the diagonal = future keys, hidden from each query.
        future = torch.triu(torch.ones(seq, seq, device=x.device, dtype=torch.bool), diagonal=1)
        scores = scores.masked_fill(future.view(1, 1, seq, seq), float("-inf"))

        weights = self.dropout(F.softmax(scores, dim=-1))
        context = torch.matmul(weights, v).transpose(1, 2).contiguous().view(batch, seq, width)
        return self.o_proj(context)


class VanillaBlock(nn.Module):
    """Pre-LayerNorm Transformer block: causal MHSA and a GELU feed-forward, each residual."""

    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = MultiHeadSelfAttentionCausal(d_model, n_heads, dropout)
        self.ln2 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor):
        """Return ``(x, None)``; the ``None`` mirrors FractalRGBlock's (output, attention) pair."""
        x = x + self.attn(self.ln1(x))
        x = x + self.ff(self.ln2(x))
        return x, None


# ================== Fractal-RG Multi-Scale Attention ==================

class FractalRGAttention(nn.Module):
    """Causal multi-head attention with a learnable per-head distance penalty.

    The sequence length is never changed (no down/up-sampling). For head h:

        logits[h, i, j] = q_i . k_j / sqrt(d_h) - lambda_dist * |i - j| / R_h
        R_h = base_radius * exp(log_R[h])      # log_R is learnable, zero-initialised

    followed by a causal mask, so each position only attends to itself and the past.
    ``num_scales`` is accepted for call-site compatibility but is not used.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        num_scales: int = 1,      # kept only for compatibility with existing callers
        base_radius: float = 16.0,
        lambda_dist: float = 1.0,
        dropout_p: float = 0.1,
    ):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.base_radius = base_radius
        self.lambda_dist = lambda_dist

        # Shared Q/K/V projections plus an output projection.
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

        self.attn_dropout = nn.Dropout(dropout_p)

        # One learnable log-radius per head.
        self.log_R = nn.Parameter(torch.zeros(n_heads))

    def forward(self, x: torch.Tensor):
        """Attend over ``x`` ``[B, T, D]``.

        Returns:
            ``(out, [attn])`` with ``out`` ``[B, T, D]`` and ``attn`` the
            post-dropout attention weights ``[B, H, T, T]``, wrapped in a list
            for compatibility with per-layer ``all_attn`` collection.
        """
        batch, seq, width = x.shape
        heads, hd = self.n_heads, self.head_dim

        def to_heads(t: torch.Tensor) -> torch.Tensor:
            # [B, T, D] -> [B, H, T, d_h]
            return t.view(batch, seq, heads, hd).transpose(1, 2)

        q = to_heads(self.q_proj(x))
        k = to_heads(self.k_proj(x))
        v = to_heads(self.v_proj(x))

        positions = torch.arange(seq, device=x.device)
        distance = (positions.unsqueeze(1) - positions.unsqueeze(0)).abs().float()  # [T, T]
        radius = (self.base_radius * self.log_R.exp()).view(1, heads, 1, 1)
        penalty = self.lambda_dist * distance.view(1, 1, seq, seq) / (radius + 1e-6)

        logits = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(hd) - penalty

        future = torch.triu(torch.ones(seq, seq, device=x.device, dtype=torch.bool), diagonal=1)
        logits = logits.masked_fill(future.view(1, 1, seq, seq), float("-inf"))

        attn = self.attn_dropout(F.softmax(logits, dim=-1))

        mixed = torch.matmul(attn, v).transpose(1, 2).contiguous().view(batch, seq, width)
        return self.out_proj(mixed), [attn]



class FractalRGBlock(nn.Module):
    """Pre-LayerNorm Transformer block whose attention sub-layer is FractalRGAttention.

    ``forward`` returns ``(x, attn_list)``, where ``attn_list`` is the second
    output of the attention module.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        d_ff: int,
        num_scales: int = 1,      # forwarded to FractalRGAttention for compatibility
        base_radius: float = 16.0,
        lambda_dist: float = 1.0,
        attn_dropout: float = 0.1,
        ff_dropout: float = 0.1,
    ):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = FractalRGAttention(
            d_model=d_model,
            n_heads=n_heads,
            num_scales=num_scales,
            base_radius=base_radius,
            lambda_dist=lambda_dist,
            dropout_p=attn_dropout,
        )
        self.ln2 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model),
            nn.Dropout(ff_dropout),
        )

    def forward(self, x: torch.Tensor):
        attn_out, attn_list = self.attn(self.ln1(x))
        x = x + attn_out                 # attention residual
        x = x + self.ff(self.ln2(x))     # feed-forward residual
        return x, attn_list



# ================== 语言模型封装 ==================

class FractalRGTransformerLM(nn.Module):
    """Decoder-only character LM.

    Structure: token embedding + learned position embedding, then a stack of
    blocks, a final LayerNorm and a linear head. ``cfg.use_fractal`` selects
    FractalRGBlock (distance-penalised attention) or VanillaBlock (plain causal
    attention) for every layer.
    """

    def __init__(self, cfg: Config, vocab_size: int):
        super().__init__()
        self.cfg = cfg
        self.token_emb = nn.Embedding(vocab_size, cfg.n_embd)
        # Learned absolute position embeddings, zero-initialised.
        self.pos_emb = nn.Parameter(torch.zeros(1, cfg.block_size, cfg.n_embd))

        blocks = []
        for layer_idx in range(cfg.n_layer):
            if cfg.use_fractal:
                # Initial attention radius halves every two layers
                # (block_size/2 for layers 0-1, block_size/4 for layers 2-3, ...).
                # A smaller radius means a stronger distance penalty, so deeper
                # layers start out MORE local than shallow ones; the learnable
                # log_R inside the attention can move it during training.
                # NOTE(review): for "shallow = local, deep = global" the exponent
                # would have to decrease with depth instead.
                base_R = cfg.block_size / (2 ** (layer_idx // 2 + 1))
                block = FractalRGBlock(
                    d_model=cfg.n_embd,
                    n_heads=cfg.n_head,
                    d_ff=cfg.d_ff,
                    num_scales=2,
                    base_radius=base_R,
                    lambda_dist=1.0,
                    attn_dropout=cfg.dropout,
                    ff_dropout=cfg.dropout,
                )
            else:
                block = VanillaBlock(
                    d_model=cfg.n_embd,
                    n_heads=cfg.n_head,
                    d_ff=cfg.d_ff,
                    dropout=cfg.dropout,
                )
            blocks.append(block)
        self.blocks = nn.ModuleList(blocks)
        self.ln_f = nn.LayerNorm(cfg.n_embd)
        self.lm_head = nn.Linear(cfg.n_embd, vocab_size)

    def forward(self, idx: torch.Tensor, targets: torch.Tensor | None = None):
        """Run the model on token ids.

        Args:
            idx: ``[B, T]`` token ids with ``T <= cfg.block_size``.
            targets: optional ``[B, T]`` next-token ids; when given, the mean
                cross-entropy over all positions is returned as ``loss``.

        Returns:
            ``(logits [B, T, vocab_size], loss or None, all_attn)``, where
            ``all_attn`` holds each block's second output (None for VanillaBlock).
        """
        B, T = idx.shape
        assert T <= self.cfg.block_size, (
            f"sequence length {T} exceeds block_size {self.cfg.block_size}"
        )

        tok_emb = self.token_emb(idx)                    # [B,T,C]
        pos_emb = self.pos_emb[:, :T, :]                 # [1,T,C]
        x = tok_emb + pos_emb                            # [B,T,C]

        all_attn = []
        for block in self.blocks:
            x, attn_scales = block(x)
            all_attn.append(attn_scales)

        x = self.ln_f(x)
        logits = self.lm_head(x)                         # [B,T,V]

        loss = None
        if targets is not None:
            B, T, V = logits.shape
            loss = F.cross_entropy(
                logits.view(B * T, V),
                targets.view(B * T),
            )

        return logits, loss, all_attn


# ================== 训练 / 评估 ==================

@torch.no_grad()
def estimate_loss(model: nn.Module, cfg: Config):
    """Average the loss over ``cfg.eval_iters`` random batches for each split.

    The model is put in eval mode for the measurement and switched back to
    train mode afterwards. Returns ``{"train": float, "val": float}``.
    """
    model.eval()
    averages = {}
    for split in ("train", "val"):
        total = 0.0
        for _ in range(cfg.eval_iters):
            xb, yb = get_batch(split, cfg)
            _, loss, _ = model(xb, yb)
            total += loss.item()
        averages[split] = total / cfg.eval_iters
    model.train()
    return averages


def main(cfg: Config):
    """Train the language model and save its ``state_dict``.

    Losses on both splits are estimated every ``cfg.eval_interval`` steps and
    once more after the last step, so the final report matches the saved
    weights. The running training loss is printed every 100 steps. The
    checkpoint is written to ``fractal_rg_lm.pt`` or ``vanilla_lm.pt``
    depending on ``cfg.use_fractal``.
    """
    model = FractalRGTransformerLM(cfg, vocab_size).to(device)
    print("Model parameters:", sum(p.numel() for p in model.parameters()) / 1e6, "M")

    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr)

    def report(step: int) -> None:
        # Evaluate both splits and print the result with the time it took.
        t0 = time.perf_counter()
        losses = estimate_loss(model, cfg)
        dt = time.perf_counter() - t0
        print(
            f"step {step}: train loss {losses['train']:.4f}, "
            f"val loss {losses['val']:.4f} (eval {dt:.1f}s)"
        )

    for step in range(cfg.max_iters):
        if step % cfg.eval_interval == 0:
            report(step)

        xb, yb = get_batch("train", cfg)
        _, loss, _ = model(xb, yb)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        if (step + 1) % 100 == 0:
            print(f"iter {step+1}, train loss {loss.item():.4f}")

    # The loop never evaluates after its last update; do it once here.
    report(cfg.max_iters)

    ckpt_name = "fractal_rg_lm.pt" if cfg.use_fractal else "vanilla_lm.pt"
    torch.save(model.state_dict(), ckpt_name)
    print("Saved model to", ckpt_name)


if __name__ == "__main__":
    main(cfg)


Using device: cuda  (FractalRG: True )
Found input.txt, skip download.
Data length: 1115394
Vocab size: 65
Model parameters: 3.225681 M
step 0: train loss 4.3167, val loss 4.3171 (eval 10.5s)
iter 100, train loss 2.4267
iter 200, train loss 2.3366
iter 300, train loss 2.2545
iter 400, train loss 2.1598
iter 500, train loss 2.1300
step 500: train loss 2.0630, val loss 2.1490 (eval 4.8s)
iter 600, train loss 2.0386
iter 700, train loss 1.8763
iter 800, train loss 1.8665
iter 900, train loss 1.7962
iter 1000, train loss 1.7907
step 1000: train loss 1.6669, val loss 1.8270 (eval 4.8s)
iter 1100, train loss 1.7019
iter 1200, train loss 1.6610
iter 1300, train loss 1.6496
iter 1400, train loss 1.5791
iter 1500, train loss 1.6068
step 1500: train loss 1.5278, val loss 1.7070 (eval 4.8s)
iter 1600, train loss 1.5660
iter 1700, train loss 1.5841
iter 1800, train loss 1.5246
iter 1900, train loss 1.4994
iter 2000, train loss 1.4929
step 2000: train loss 1.4369, val loss 1.6348 (eval 4.8s)
iter 2

In [4]:
# tiny_shakespeare_fractal_rg_lm.py
import os
import math
import time
import requests
from dataclasses import dataclass
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


# ================== 配置 ==================

@dataclass
class Config:
    """Hyper-parameters for the data source, training schedule and model size."""

    # --- data ---
    data_url: str = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
    data_path: str = "input.txt"

    # --- training schedule ---
    batch_size: int = 64
    block_size: int = 128  # context length (tokens per training window)
    max_iters: int = 20_000
    eval_interval: int = 500
    eval_iters: int = 50

    # --- model size ---
    n_layer: int = 4
    n_head: int = 4
    n_embd: int = 256
    d_ff: int = 1_024

    # --- optimisation / regularisation ---
    lr: float = 3e-4
    dropout: float = 0.1

    # Architecture switch: True = FractalRG multi-scale attention,
    # False = plain Transformer attention.
    use_fractal: bool = False


cfg = Config()
# Prefer the GPU when one is visible to PyTorch.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using device:", device, " (FractalRG:", cfg.use_fractal, ")")


# ================== 数据加载：tiny-shakespeare ==================

def download_tiny_shakespeare(cfg: Config):
    """Download the tiny-shakespeare corpus to ``cfg.data_path`` unless it already exists.

    The text is first written to a sibling ``.part`` file and then moved into
    place with ``os.replace``. An interrupted download therefore never leaves a
    truncated ``cfg.data_path`` that the existence check would accept on the
    next run.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.Timeout: if connecting or reading stalls for more than 30 s.
    """
    if os.path.exists(cfg.data_path):
        print(f"Found {cfg.data_path}, skip download.")
        return
    print("Downloading tiny-shakespeare...")
    # Without a timeout, requests can block forever on a stalled connection.
    r = requests.get(cfg.data_url, timeout=30)
    r.raise_for_status()
    tmp_path = f"{cfg.data_path}.part"
    with open(tmp_path, "w", encoding="utf-8") as f:
        f.write(r.text)
    os.replace(tmp_path, cfg.data_path)
    print("Download finished.")


download_tiny_shakespeare(cfg)

with open(cfg.data_path, "r", encoding="utf-8") as f:
    text = f.read()

print("Data length:", len(text))

# Character-level vocabulary: the distinct characters of the corpus, sorted so
# that the id assignment is deterministic.
chars = sorted(set(text))
vocab_size = len(chars)
print("Vocab size:", vocab_size)

# Lookup tables between integer ids and characters.
itos = dict(enumerate(chars))
stoi = {ch: i for i, ch in itos.items()}

def encode(s: str):
    """Map each character of ``s`` to its vocabulary id (KeyError for unknown characters)."""
    return list(map(stoi.__getitem__, s))

def decode(ids):
    """Map a sequence of vocabulary ids back to a string.

    Accepts any iterable of ints, and also a ``torch.Tensor`` of ids such as a
    row of model output. A tensor is converted with ``tolist()`` first:
    iterating a tensor yields 0-d tensors, which hash by identity and would
    never be found in ``itos``.
    """
    if isinstance(ids, torch.Tensor):
        ids = ids.tolist()
    return "".join(itos[i] for i in ids)

# Encode the whole corpus once; the first 90% is training data, the rest validation.
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]


def get_batch(split: str, cfg: Config):
    """Sample a random batch of next-token-prediction windows.

    Args:
        split: ``"train"`` reads ``train_data``; any other value reads ``val_data``.
        cfg: supplies ``batch_size`` and ``block_size``.

    Returns:
        ``(x, y)``, both ``[batch_size, block_size]`` long tensors on ``device``,
        where ``y`` is ``x`` shifted one token to the right.

    Raises:
        ValueError: if the split has fewer than ``block_size + 1`` tokens.
    """
    d = train_data if split == "train" else val_data
    # A window starting at i needs tokens i .. i + block_size (x plus the shifted y),
    # so valid starts are 0 .. len(d) - block_size - 1; randint's bound is exclusive.
    max_start = len(d) - cfg.block_size
    if max_start <= 0:
        raise ValueError(
            f"split {split!r} has {len(d)} tokens, need at least {cfg.block_size + 1}"
        )
    ix = torch.randint(max_start, (cfg.batch_size,))
    offsets = ix.unsqueeze(1) + torch.arange(cfg.block_size)  # [B, T]
    x = d[offsets]
    y = d[offsets + 1]
    return x.to(device), y.to(device)


# ================== 一些 Attention 工具函数 ==================

def causal_mask(L: int, device: torch.device):
    """Return an ``[L, L]`` bool mask that is True strictly above the diagonal.

    Entry ``[i, j]`` is True when ``j > i``, i.e. key ``j`` lies in the future of
    query ``i`` and must be hidden.
    """
    rows = torch.arange(L, device=device).unsqueeze(1)
    cols = torch.arange(L, device=device).unsqueeze(0)
    return cols > rows


def coarse_grain_1d(x: torch.Tensor) -> torch.Tensor:
    """Average non-overlapping pairs along the sequence axis.

    ``[B, L, D] -> [B, L // 2, D]``: output position k is the mean of inputs
    2k and 2k + 1. When L is odd the last position is dropped.
    """
    batch, length, dim = x.shape
    n_pairs = length // 2
    paired = x[:, : 2 * n_pairs, :].view(batch, n_pairs, 2, dim)
    return paired.mean(dim=2)


def upsample_to_length(x: torch.Tensor, target_len: int) -> torch.Tensor:
    """Linearly interpolate ``[B, Ls, D]`` along the sequence axis to ``[B, target_len, D]``.

    The input is returned unchanged (same object) when its length already
    matches. Interpolation mixes neighbouring positions in both directions,
    so it is not causal.
    """
    _, src_len, _ = x.shape
    if src_len == target_len:
        return x
    channels_first = x.permute(0, 2, 1)  # [B, D, Ls]; interpolate works on the last axis
    resized = F.interpolate(channels_first, size=target_len, mode="linear", align_corners=False)
    return resized.permute(0, 2, 1)


# ================== Baseline: Causal Multihead Self-Attention ==================

class MultiHeadSelfAttentionCausal(nn.Module):
    """Baseline causal multi-head self-attention with separate Q/K/V/output projections."""

    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads

        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.o_proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor):
        """Attend over ``x`` of shape ``[B, T, D]``; returns ``[B, T, D]``."""
        batch, seq, width = x.shape
        heads, hd = self.n_heads, self.head_dim

        # Project, then split the model dimension into heads: [B, H, T, d_h].
        q, k, v = (
            proj(x).view(batch, seq, heads, hd).transpose(1, 2)
            for proj in (self.q_proj, self.k_proj, self.v_proj)
        )

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(hd)  # [B, H, T, T]

        # True strictly above the diagonal = future keys, hidden from each query.
        future = torch.triu(torch.ones(seq, seq, device=x.device, dtype=torch.bool), diagonal=1)
        scores = scores.masked_fill(future.view(1, 1, seq, seq), float("-inf"))

        weights = self.dropout(F.softmax(scores, dim=-1))
        context = torch.matmul(weights, v).transpose(1, 2).contiguous().view(batch, seq, width)
        return self.o_proj(context)


class VanillaBlock(nn.Module):
    """Pre-LayerNorm Transformer block: causal MHSA and a GELU feed-forward, each residual."""

    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = MultiHeadSelfAttentionCausal(d_model, n_heads, dropout)
        self.ln2 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor):
        """Return ``(x, None)``; the ``None`` mirrors FractalRGBlock's (output, attention) pair."""
        x = x + self.attn(self.ln1(x))
        x = x + self.ff(self.ln2(x))
        return x, None


# ================== Fractal-RG Multi-Scale Attention ==================

class FractalRGAttention(nn.Module):
    def __init__(
        self,
        d_model: int,
        n_heads: int,
        num_scales: int = 2,
        base_radius: float = 16.0,
        lambda_dist: float = 1.0,
        dropout_p: float = 0.1,
    ):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.num_scales = num_scales
        self.base_radius = base_radius
        self.lambda_dist = lambda_dist

        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        self.attn_dropout = nn.Dropout(dropout_p)

        # 尺度参数 log_R_s，可学习
        self.log_R = nn.Parameter(torch.zeros(num_scales))
        # 多尺度融合权重（softmax 后为 g_s）
        self.scale_logits = nn.Parameter(torch.zeros(num_scales))

    def _scaled_dot_attn(
        self,
        x: torch.Tensor,
        R_s: float,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        在单个尺度上计算 causal self-attention
        x: [B, Ls, D]
        """
        B, Ls, D = x.shape
        H = self.n_heads
        d_h = self.head_dim

        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        def split(t):
            return t.view(B, Ls, H, d_h).transpose(1, 2)  # [B,H,Ls,d_h]

        q = split(q)
        k = split(k)
        v = split(v)

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_h)  # [B,H,Ls,Ls]

        # 距离惩罚项
        idx = torch.arange(Ls, device=x.device)
        dist = (idx[:, None] - idx[None, :]).abs().float()  # [Ls,Ls]
        scores = scores - self.lambda_dist * dist / (R_s + 1e-6)

        # 因果 mask
        mask = causal_mask(Ls, x.device)  # [Ls,Ls]
        scores = scores.masked_fill(mask.view(1, 1, Ls, Ls), float("-inf"))

        attn = F.softmax(scores, dim=-1)
        attn = self.attn_dropout(attn)

        h = torch.matmul(attn, v)  # [B,H,Ls,d_h]
        h = h.transpose(1, 2).contiguous().view(B, Ls, D)  # [B,Ls,D]
        return h, attn

    def forward(self, x: torch.Tensor):
        """
        x: [B, L, D]
        返回:
          out: [B, L, D]
          attn_scales: List[[B,H,Ls,Ls]]
        """
        B, L, D = x.shape

        # 构造多尺度表示
        xs: List[torch.Tensor] = [x]
        for _ in range(1, self.num_scales):
            xs.append(coarse_grain_1d(xs[-1]))

        scale_weights = F.softmax(self.scale_logits, dim=-1)  # g_s
        out = torch.zeros_like(x)
        attn_scales: List[torch.Tensor] = []

        for s, x_s in enumerate(xs):
            R_s = self.base_radius * torch.exp(self.log_R[s])  # 标量
            h_s, attn_s = self._scaled_dot_attn(x_s, R_s)
            attn_scales.append(attn_s)

            # 上采样回原长度并加权融合
            h_up = upsample_to_length(h_s, L)
            out = out + scale_weights[s] * h_up

        out = self.out_proj(out)
        return out, attn_scales


class FractalRGBlock(nn.Module):
    """Pre-LayerNorm Transformer block whose attention sub-layer is FractalRGAttention.

    ``forward`` returns ``(x, attn_scales)``, where ``attn_scales`` is the second
    output of the attention module.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        d_ff: int,
        num_scales: int = 2,
        base_radius: float = 16.0,
        lambda_dist: float = 1.0,
        attn_dropout: float = 0.1,
        ff_dropout: float = 0.1,
    ):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = FractalRGAttention(
            d_model=d_model,
            n_heads=n_heads,
            num_scales=num_scales,
            base_radius=base_radius,
            lambda_dist=lambda_dist,
            dropout_p=attn_dropout,
        )
        self.ln2 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model),
            nn.Dropout(ff_dropout),
        )

    def forward(self, x: torch.Tensor):
        attn_out, attn_scales = self.attn(self.ln1(x))
        x = x + attn_out                 # attention residual
        x = x + self.ff(self.ln2(x))     # feed-forward residual
        return x, attn_scales


# ================== 语言模型封装 ==================

class FractalRGTransformerLM(nn.Module):
    """Decoder-only character LM.

    Structure: token embedding + learned position embedding, then a stack of
    blocks, a final LayerNorm and a linear head. ``cfg.use_fractal`` selects
    FractalRGBlock (multi-scale distance-penalised attention) or VanillaBlock
    (plain causal attention) for every layer.
    """

    def __init__(self, cfg: Config, vocab_size: int):
        super().__init__()
        self.cfg = cfg
        self.token_emb = nn.Embedding(vocab_size, cfg.n_embd)
        # Learned absolute position embeddings, zero-initialised.
        self.pos_emb = nn.Parameter(torch.zeros(1, cfg.block_size, cfg.n_embd))

        blocks = []
        for layer_idx in range(cfg.n_layer):
            if cfg.use_fractal:
                # Initial attention radius halves every two layers
                # (block_size/2 for layers 0-1, block_size/4 for layers 2-3, ...).
                # A smaller radius means a stronger distance penalty, so deeper
                # layers start out MORE local than shallow ones; the learnable
                # log_R inside the attention can move it during training.
                # NOTE(review): for "shallow = local, deep = global" the exponent
                # would have to decrease with depth instead.
                base_R = cfg.block_size / (2 ** (layer_idx // 2 + 1))
                block = FractalRGBlock(
                    d_model=cfg.n_embd,
                    n_heads=cfg.n_head,
                    d_ff=cfg.d_ff,
                    num_scales=2,
                    base_radius=base_R,
                    lambda_dist=1.0,
                    attn_dropout=cfg.dropout,
                    ff_dropout=cfg.dropout,
                )
            else:
                block = VanillaBlock(
                    d_model=cfg.n_embd,
                    n_heads=cfg.n_head,
                    d_ff=cfg.d_ff,
                    dropout=cfg.dropout,
                )
            blocks.append(block)
        self.blocks = nn.ModuleList(blocks)
        self.ln_f = nn.LayerNorm(cfg.n_embd)
        self.lm_head = nn.Linear(cfg.n_embd, vocab_size)

    def forward(self, idx: torch.Tensor, targets: torch.Tensor | None = None):
        """Run the model on token ids.

        Args:
            idx: ``[B, T]`` token ids with ``T <= cfg.block_size``.
            targets: optional ``[B, T]`` next-token ids; when given, the mean
                cross-entropy over all positions is returned as ``loss``.

        Returns:
            ``(logits [B, T, vocab_size], loss or None, all_attn)``, where
            ``all_attn`` holds each block's second output (None for VanillaBlock).
        """
        B, T = idx.shape
        assert T <= self.cfg.block_size, (
            f"sequence length {T} exceeds block_size {self.cfg.block_size}"
        )

        tok_emb = self.token_emb(idx)                    # [B,T,C]
        pos_emb = self.pos_emb[:, :T, :]                 # [1,T,C]
        x = tok_emb + pos_emb                            # [B,T,C]

        all_attn = []
        for block in self.blocks:
            x, attn_scales = block(x)
            all_attn.append(attn_scales)

        x = self.ln_f(x)
        logits = self.lm_head(x)                         # [B,T,V]

        loss = None
        if targets is not None:
            B, T, V = logits.shape
            loss = F.cross_entropy(
                logits.view(B * T, V),
                targets.view(B * T),
            )

        return logits, loss, all_attn


# ================== 训练 / 评估 ==================

@torch.no_grad()
def estimate_loss(model: nn.Module, cfg: Config):
    """Average the loss over ``cfg.eval_iters`` random batches for each split.

    The model is put in eval mode for the measurement and switched back to
    train mode afterwards. Returns ``{"train": float, "val": float}``.
    """
    model.eval()
    averages = {}
    for split in ("train", "val"):
        total = 0.0
        for _ in range(cfg.eval_iters):
            xb, yb = get_batch(split, cfg)
            _, loss, _ = model(xb, yb)
            total += loss.item()
        averages[split] = total / cfg.eval_iters
    model.train()
    return averages


def main(cfg: Config):
    """Train the language model and save its ``state_dict``.

    Losses on both splits are estimated every ``cfg.eval_interval`` steps and
    once more after the last step, so the final report matches the saved
    weights. The running training loss is printed every 100 steps. The
    checkpoint is written to ``fractal_rg_lm.pt`` or ``vanilla_lm.pt``
    depending on ``cfg.use_fractal``.
    """
    model = FractalRGTransformerLM(cfg, vocab_size).to(device)
    print("Model parameters:", sum(p.numel() for p in model.parameters()) / 1e6, "M")

    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr)

    def report(step: int) -> None:
        # Evaluate both splits and print the result with the time it took.
        t0 = time.perf_counter()
        losses = estimate_loss(model, cfg)
        dt = time.perf_counter() - t0
        print(
            f"step {step}: train loss {losses['train']:.4f}, "
            f"val loss {losses['val']:.4f} (eval {dt:.1f}s)"
        )

    for step in range(cfg.max_iters):
        if step % cfg.eval_interval == 0:
            report(step)

        xb, yb = get_batch("train", cfg)
        _, loss, _ = model(xb, yb)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        if (step + 1) % 100 == 0:
            print(f"iter {step+1}, train loss {loss.item():.4f}")

    # The loop never evaluates after its last update; do it once here.
    report(cfg.max_iters)

    ckpt_name = "fractal_rg_lm.pt" if cfg.use_fractal else "vanilla_lm.pt"
    torch.save(model.state_dict(), ckpt_name)
    print("Saved model to", ckpt_name)


if __name__ == "__main__":
    main(cfg)


Using device: cuda  (FractalRG: False )
Found input.txt, skip download.
Data length: 1115394
Vocab size: 65
Model parameters: 3.225665 M
step 0: train loss 4.4069, val loss 4.4038 (eval 4.7s)
iter 100, train loss 2.4556
iter 200, train loss 2.3880
iter 300, train loss 2.3867
iter 400, train loss 2.3386
iter 500, train loss 2.2266
step 500: train loss 2.2075, val loss 2.2900 (eval 4.7s)
iter 600, train loss 2.0654
iter 700, train loss 1.9703
iter 800, train loss 1.9031
iter 900, train loss 1.8170
iter 1000, train loss 1.7880
step 1000: train loss 1.7059, val loss 1.8506 (eval 4.7s)
iter 1100, train loss 1.7424
iter 1200, train loss 1.7094
iter 1300, train loss 1.6904
iter 1400, train loss 1.6049
iter 1500, train loss 1.6468
step 1500: train loss 1.5313, val loss 1.7219 (eval 4.7s)
iter 1600, train loss 1.5886
iter 1700, train loss 1.5505
iter 1800, train loss 1.5407
iter 1900, train loss 1.5547
iter 2000, train loss 1.4825
step 2000: train loss 1.4502, val loss 1.6499 (eval 4.7s)
iter 2