# SANL — Spiral-Atlas Number Learner (Atlas Number Learning Model)

This notebook trains a neural model to learn **Dirichlet-character-like waves** over the
multiplicative group \((\mathbb{Z}/p\mathbb{Z})^\times\) using a **spiral modular atlas**.

Core ideas:

- Integers mod `p` are embedded as **harmonic points on a spiral/cylinder**.
- The network learns complex-valued “characters” \(\chi(n) \in S^1\) such that:
  - **Multiplicativity:**  \(\chi(ab) \approx \chi(a)\chi(b)\)
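  - **Scale invariance:** \(\chi(sn) \approx \chi(n)\) for \(s\) coprime to \(p\). (Note: a genuine Dirichlet character satisfies \(\chi(sn) = \chi(s)\chi(n)\), so this invariance holds for every such \(s\) only for the principal character, \(\chi \equiv 1\) on units; this objective therefore pulls all heads toward the trivial character and works against orthogonality.)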
  - **Orthogonality:** different heads behave like distinct Dirichlet characters.

This is the **Atlas Number Learning** model: the SANL network operates directly inside
a number-theoretic geometry where arithmetic becomes wave mechanics instead of
symbolic manipulation.

In [None]:
# SANL — Spiral-Atlas Number Learner
# Dependencies: torch>=2.1, numpy

import math, time, random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# ===========================
# Config
# ===========================
class Cfg:
    """Hyper-parameters for the SANL notebook, read through the module-level ``cfg``."""

    # --- hardware -------------------------------------------------------
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # --- arithmetic setting and atlas ----------------------------------
    p = 97                    # prime modulus of the group (Z/pZ)^x
    K_harm = 4                # harmonics per modulus (one cos + one sin column each)
    M_mods = [3, 5, 7, 11, 13, 17, 19, 23,
              29, 31, 37, 41, 43, 47, 53, 59,
              61, 67, 71, 73, 79, 83, 89, 97]   # odd primes up to 97; p itself is included
    use_only_p = True         # True -> the atlas uses modulus p alone (faster)

    # --- network --------------------------------------------------------
    dims_hidden = 256         # width of the hidden MLP layers
    n_chars = 6               # how many character heads the model outputs

    # --- optimisation ---------------------------------------------------
    epochs = 120
    batch_pairs = 8192        # (a, b) pairs sampled per step
    batch_scale = 4096        # (n, s) pairs sampled per step for the scale term
    lr = 3e-3
    wd = 1e-4
    grad_accum = 2
    amp = True                # mixed precision, only effective on CUDA
    cosine_lr = True
    torch_compile = False     # may be switched on with PyTorch 2.x for speed
    seed = 42

    # --- loss weights ---------------------------------------------------
    w_mult = 1.0              # chi(ab) ~ chi(a) chi(b)
    w_scale = 0.5             # chi(sn) ~ chi(n), s coprime to p
    w_unit = 1e-3             # |chi| ~ 1
    w_ortho = 1e-2            # heads ~ mutually orthogonal
    w_anchor = 5e-3           # pin the phase at residue 1 to remove global drift

    # --- supervision (analytic discrete-log characters are an option) ---
    supervision = "self"      # "self" or "dlog"
    dlog_r_index = 1          # which discrete-log character to anchor in "dlog" mode

cfg = Cfg()

# Reproducibility: seed NumPy, Python's `random` and PyTorch from one value.
np.random.seed(cfg.seed)
random.seed(cfg.seed)
torch.manual_seed(cfg.seed)

<torch._C.Generator at 0x7f4b113a9010>

In [None]:
# ===========================
# Number-theory helpers
# ===========================

def is_coprime(a: int, b: int) -> bool:
    """Return True when the only positive common divisor of a and b is 1."""
    common = math.gcd(a, b)
    return common == 1

def primitive_root(p: int) -> int:
    """
    Return the smallest primitive root modulo the prime p.

    g is a primitive root when its powers g^0, ..., g^(p-2) run through every
    unit mod p; equivalently g^((p-1)/q) != 1 (mod p) for every prime q | p-1.

    Args:
        p: a prime int. For p = 2 the answer is 1, since (Z/2Z)^x = {1}.

    Returns:
        The smallest g in [1, p) that generates (Z/pZ)^x.

    Raises:
        ValueError: if p is not an int or is not prime. Explicit raises are
            used instead of ``assert`` so validation survives ``python -O``.
            Composite moduli are rejected because the test below would
            otherwise silently return a non-generator (e.g. 2 for p = 15).
    """
    if isinstance(p, bool) or not isinstance(p, int) or p < 2:
        raise ValueError(f"primitive_root expects a prime int, got {p!r}")
    if any(p % d == 0 for d in range(2, math.isqrt(p) + 1)):
        raise ValueError(f"primitive_root expects a prime modulus, got {p}")
    if p == 2:
        return 1

    # Collect the distinct prime factors of the group order phi = p - 1.
    phi = p - 1
    prime_factors, m, d = set(), phi, 2
    while d * d <= m:
        while m % d == 0:
            prime_factors.add(d)
            m //= d
        d += 1
    if m > 1:
        prime_factors.add(m)

    for g in range(2, p):
        # g generates the group iff no proper power g^(phi/q) collapses to 1.
        if all(pow(g, phi // q, p) != 1 for q in prime_factors):
            return g
    raise RuntimeError("No primitive root?")  # unreachable for prime p

def dlog_table(p: int, g: int = None):
    """
    Build the discrete-log table of g modulo p.

    Args:
        p: modulus.
        g: base; when None, ``primitive_root(p)`` is used.

    Returns:
        (g, table) where table maps g^k mod p -> k for k = 0 .. p-2. If g is not
        a primitive root, residues repeat and each keeps its largest exponent.
    """
    base = primitive_root(p) if g is None else g
    table = {1: 0}
    table.update({pow(base, k, p): k for k in range(1, p - 1)})
    return base, table

def dl_character_factory(p: int, r: int, g: int = None, tab=None):
    """
    Analytic Dirichlet-like character via discrete log:
    chi_r(g^k) = exp(2*pi*i * r * k / (p-1)), and chi_r(n) = 0 when p | n.

    Args:
        p: prime modulus.
        r: character index.
        g: base of the discrete log; only used when ``tab`` is None
            (then ``dlog_table(p, g)`` is computed, which picks a primitive root
            when g is None).
        tab: optional precomputed discrete-log table {residue: k}. When given
            it is used as-is. Previously this argument was accepted but
            silently ignored and the table was always recomputed.

    Returns:
        A function chi(n: int) -> complex.
    """
    if tab is None:
        g, tab = dlog_table(p, g)
    tw = 2 * math.pi * r / (p - 1)

    def chi(n: int) -> complex:
        nm = n % p
        if nm == 0:
            return complex(0.0, 0.0)
        k = tab[nm]
        return complex(math.cos(tw * k), math.sin(tw * k))

    return chi

In [None]:
# ===========================
# Spiral Atlas
# ===========================

class SpiralAtlas:
    """
    Spiral/harmonic embedding over moduli.

    For each atlas modulus m, residue r in 1..m-1 is embedded as the 2K values
    [cos(2*pi*k*r/m) for k=1..K] + [sin(2*pi*k*r/m) for k=1..K]; residue 0 maps
    to a zero vector. If use_only_p=True, only modulus p is used (fast + clean).
    """

    def __init__(self, p, mods, K, use_only_p, device):
        if use_only_p:
            mods = [p]
        self.mods = mods
        self.p = p
        self.K = K
        self.device = device

        # One harmonic bank per modulus; row r-1 holds the features of residue r.
        self.blocks = [(m, self._harmonic_bank(m, K, device)) for m in mods]

        # For dlog-based supervision (analytic characters):
        # residue_to_k[n] = k with g^k == n (mod p); -1 for residue 0.
        self.g = primitive_root(p)
        _, tab = dlog_table(p, self.g)
        lookup = torch.full((p,), -1, dtype=torch.long)  # filled on CPU, moved once
        for res, k in tab.items():
            lookup[res] = k
        self.residue_to_k = lookup.to(device)

    @staticmethod
    def _harmonic_bank(m, K, device):
        """Return the [m-1, 2K] cos/sin feature table for residues 1..m-1 of modulus m."""
        r = torch.arange(1, m, device=device).float()           # residues 1..m-1
        theta = 2 * math.pi * r / m                              # base phase
        ks = torch.arange(1, K + 1, device=device).float().view(1, -1)
        TH = theta.view(-1, 1) * ks                              # [m-1, K]
        return torch.cat([torch.cos(TH), torch.sin(TH)], dim=1)  # [m-1, 2K]

    def features(self, n: torch.Tensor) -> torch.Tensor:
        """
        n: [B] integer tensor on device.
        Returns: [B, D] harmonic features, D = len(self.mods) * 2K.

        Inputs are first reduced mod p, so features depend only on the residue
        class of n mod p. Previously the reduction used ``self.mods[-1]``, which
        corrupts the residues for every other modulus whenever the atlas list
        does not end with p.
        """
        n = torch.remainder(n, self.p)
        feats = []
        for m, H in self.blocks:
            nm = torch.remainder(n, m)                 # [B]
            mask = nm != 0
            block = torch.zeros((n.shape[0], H.shape[1]), device=self.device, dtype=H.dtype)
            if mask.any():
                block[mask] = H[nm[mask] - 1]          # residues 1..m-1 -> rows 0..m-2
            feats.append(block)
        return torch.cat(feats, dim=1)

In [None]:
# ===========================
# Model & helpers
# ===========================

class SANL(nn.Module):
    """
    Spiral-Atlas Number Learner.

    A three-layer GELU MLP mapping atlas features to ``n_chars`` complex heads,
    each stored as a (Re, Im) pair.
    """
    def __init__(self, in_dim: int, hidden: int, n_chars: int):
        super().__init__()
        layers = [
            nn.Linear(in_dim, hidden), nn.GELU(),
            nn.Linear(hidden, hidden), nn.GELU(),
            nn.Linear(hidden, 2 * n_chars),   # one (Re, Im) pair per head
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map features [B, D] to heads [B, n_chars, 2] (last axis = Re, Im)."""
        batch = x.shape[0]
        flat = self.net(x)
        return flat.view(batch, -1, 2)

def unit_project(z: torch.Tensor) -> torch.Tensor:
    """Scale each trailing (Re, Im) pair of z to unit length (norms below 1e-8 are clamped to 1e-8)."""
    magnitude = z.norm(dim=-1, keepdim=True).clamp(min=1e-8)
    return z / magnitude

def complex_mul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Multiply complex numbers stored as (Re, Im) in the last axis: (ar + i*ai)(br + i*bi)."""
    ar, ai = a[..., 0], a[..., 1]
    br, bi = b[..., 0], b[..., 1]
    real = ar * br - ai * bi
    imag = ar * bi + ai * br
    return torch.stack([real, imag], dim=-1)

def complex_conj(a: torch.Tensor) -> torch.Tensor:
    """Conjugate complex numbers stored as (Re, Im) in the last axis by negating Im."""
    real, imag = a[..., 0], a[..., 1]
    return torch.stack((real, -imag), dim=-1)

In [None]:
# ===========================
# Losses
# ===========================

def multiplicativity_loss(ha, hb, hab):
    """
    Mean-squared error between chi(a)*chi(b) and chi(ab).

    ha, hb, hab: heads of shape [B, H, 2] with (Re, Im) in the last axis.
    """
    predicted = complex_mul(ha, hb)
    return F.mse_loss(predicted, hab)

def scale_loss(hn, hsn):
    """
    Mean-squared error between chi(n) and chi(sn) for the sampled scales s.

    hn, hsn: heads of shape [B, H, 2].
    """
    return F.mse_loss(input=hn, target=hsn)

def unit_penalty(head):
    """
    Mean squared deviation of each head's modulus from 1.

    head: [B, H, 2] with (Re, Im) in the last axis.
    """
    deviation = head.norm(dim=-1) - 1.0
    return deviation.pow(2).mean()

def ortho_penalty_over_group(model, atlas, p: int, sample: int = 2048):
    """
    Encourage heads to behave like mutually orthogonal characters on (Z/pZ)^x.

    Draws `sample` units n uniformly from 1..p-1 and forms the Monte-Carlo Gram
    matrix
        G[i, j] = mean_n  chi_i(n) * conj(chi_j(n))      (complex, [H, H]),
    whose exact group average is the identity for distinct Dirichlet characters.

    Returns the mean over all H*H entries of the complex modulus |G - I|.
    Both the real and imaginary parts of any overlap are penalised; the
    previous version used only Re(G), which accepted e.g. chi_i = i*chi_j as
    "orthogonal".

    Args:
        model: maps atlas features [S, D] to heads [S, H, 2].
        atlas: provides ``features(n)`` and ``device``.
        p: modulus; units are sampled from 1..p-1.
        sample: number of Monte-Carlo samples.
    """
    x = torch.randint(1, p, (sample,), device=atlas.device)        # units
    out = unit_project(model(atlas.features(x))).float()          # [S, H, 2]
    chi = torch.complex(out[..., 0], out[..., 1])                  # [S, H]
    gram = (chi.unsqueeze(2) * chi.conj().unsqueeze(1)).mean(0)    # [H, H]
    eye = torch.eye(chi.shape[1], dtype=chi.dtype, device=chi.device)
    return (gram - eye).abs().mean()

def anchor_phase_loss(model, atlas, p: int, residue: int = 1):
    """
    Pin every head's value at `residue` to 1 + 0i ([1, 0] in R^2).

    Any multiplicative character has chi(1) = 1, so anchoring residue 1 removes
    the heads' global phase freedom.

    Args:
        model: maps atlas features [B, D] to heads [B, H, 2].
        atlas: provides ``features(n)`` and ``device``.
        p: unused; kept for call-site compatibility.
        residue: the residue whose heads are anchored.

    Returns:
        MSE between the unit-projected heads at `residue` and [1, 0].
    """
    # Every copy of the same residue yields identical features and outputs, so
    # one evaluation replaces the previous mean over 256 identical copies.
    x = torch.full((1,), residue, dtype=torch.long, device=atlas.device)
    out = unit_project(model(atlas.features(x)))[0]   # [H, 2]
    target = torch.zeros_like(out)
    target[..., 0] = 1.0
    return F.mse_loss(out, target)

In [None]:
# ===========================
# Training batch builder
# ===========================

def build_batches(p: int, n_ab: int, n_ns: int, device: str):
    """
    Build one training batch.

    Returns (a, b, ab, n, s, sn), all int64 tensors on `device`:
      - a, b: n_ab units drawn from 1..p-1; ab = a*b mod p (multiplicativity).
      - n: n_ns units drawn from 1..p-1.
      - s: n_ns scales drawn from 2..p-1 *inclusive*, resampled until coprime
        to p. The previous ``random.randrange(2, p-1)`` never produced
        s = p-1 (= -1 mod p).
      - sn = s*n mod p (scale term).
    """
    # (a, b) pairs
    a = torch.randint(1, p, (n_ab,), device=device)
    b = torch.randint(1, p, (n_ab,), device=device)
    ab = (a * b) % p
    ab[ab == 0] = 1  # only reachable for composite p (zero divisors)

    # (n, sn) pairs: vectorised rejection sampling of coprime scales
    n = torch.randint(1, p, (n_ns,), device=device)
    s = torch.randint(2, p, (n_ns,), device=device)
    p_t = torch.full_like(s, p)
    bad = torch.gcd(s, p_t) != 1
    while bad.any():  # never loops for prime p
        s[bad] = torch.randint(2, p, (int(bad.sum()),), device=device)
        bad = torch.gcd(s, p_t) != 1
    sn = (s * n) % p
    sn[sn == 0] = 1
    return a, b, ab, n, s, sn

# ===========================
# Model construction & training
# ===========================

def build():
    """
    Construct the atlas, model, optimizer and LR scheduler from ``cfg``.

    - The atlas uses modulus p alone when ``cfg.use_only_p``, else ``cfg.M_mods``.
    - The model is wrapped with ``torch.compile`` when ``cfg.torch_compile`` is
      set. This flag was previously ignored.
    - ``cfg.cosine_lr`` selects cosine annealing from ``cfg.lr`` down to
      ``0.1 * cfg.lr`` over ``cfg.epochs`` scheduler steps. Otherwise a
      constant-LR scheduler is returned, so callers may call ``sched.step()``
      unconditionally. This flag was also previously ignored.

    Returns:
        (atlas, model, opt, sched)
    """
    mods = [cfg.p] if cfg.use_only_p else cfg.M_mods
    atlas = SpiralAtlas(cfg.p, mods, cfg.K_harm, cfg.use_only_p, cfg.device)
    # Each atlas modulus contributes K cos + K sin feature columns.
    D = len(mods) * (2 * cfg.K_harm)
    model = SANL(D, cfg.dims_hidden, cfg.n_chars).to(cfg.device)
    if cfg.torch_compile:
        model = torch.compile(model)

    opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.wd)
    if cfg.cosine_lr:
        sched = torch.optim.lr_scheduler.CosineAnnealingLR(
            opt, T_max=cfg.epochs, eta_min=cfg.lr * 0.1
        )
    else:
        sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lambda _epoch: 1.0)
    return atlas, model, opt, sched

def _make_grad_scaler(enabled: bool):
    """Create a CUDA GradScaler, preferring torch.amp (torch>=2.3) over the deprecated torch.cuda.amp."""
    if hasattr(torch.amp, "GradScaler"):
        return torch.amp.GradScaler("cuda", enabled=enabled)
    return torch.cuda.amp.GradScaler(enabled=enabled)


def _training_losses(atlas, model):
    """Draw one fresh batch; return (weighted total loss, dict of unweighted component losses)."""
    a, b, ab, n, _s, sn = build_batches(
        cfg.p, cfg.batch_pairs, cfg.batch_scale, cfg.device
    )

    # Single forward pass over every sampled residue. The MLP acts row-wise,
    # so splitting its output matches five separate passes.
    inputs = (a, b, ab, n, sn)
    raw = model(atlas.features(torch.cat(inputs)))                # [sum B, H, 2]
    raw_parts = torch.split(raw, [t.shape[0] for t in inputs])
    ha, hb, hab, hn, hsn = (unit_project(r) for r in raw_parts)

    parts = {
        "mult": multiplicativity_loss(ha, hb, hab),
        "scale": scale_loss(hn, hsn),
        # Penalise the *raw* magnitudes. After unit_project every |chi| is 1
        # by construction, so penalising the projected heads was always ~0.
        "unit": unit_penalty(raw),
        "ortho": ortho_penalty_over_group(model, atlas, cfg.p, sample=1024),
        "anchor": anchor_phase_loss(model, atlas, cfg.p, residue=1),
    }
    total = (cfg.w_mult * parts["mult"] +
             cfg.w_scale * parts["scale"] +
             cfg.w_unit * parts["unit"] +
             cfg.w_ortho * parts["ortho"] +
             cfg.w_anchor * parts["anchor"])
    return total, parts


def train():
    """
    Build the atlas/model/optimizer and run ``cfg.epochs`` optimisation steps.

    Each epoch accumulates gradients over ``cfg.grad_accum`` freshly sampled
    batches, each loss scaled by 1/grad_accum, then takes one optimizer step
    and one scheduler step. (Previously every micro-batch stepped the optimizer,
    so no accumulation happened.) Mixed precision is used when ``cfg.amp`` is
    set and ``cfg.device == "cuda"``. Progress is printed on epoch 1 and every
    10th epoch, using the losses of the epoch's last micro-batch.

    Returns:
        (atlas, model) after training.
    """
    atlas, model, opt, sched = build()
    use_amp = cfg.amp and cfg.device == "cuda"
    scaler = _make_grad_scaler(use_amp)
    accum = max(1, int(cfg.grad_accum))
    device_type = torch.device(cfg.device).type

    for ep in range(1, cfg.epochs + 1):
        t0 = time.time()
        opt.zero_grad(set_to_none=True)
        for _ in range(accum):
            with torch.autocast(device_type=device_type, enabled=use_amp):
                L, parts = _training_losses(atlas, model)
            scaler.scale(L / accum).backward()
        scaler.step(opt)
        scaler.update()
        sched.step()

        if ep % 10 == 0 or ep == 1:
            print(
                f"epoch {ep:03d} | loss={L.item():.6f} | "
                f"mult={parts['mult'].item():.4e} | scale={parts['scale'].item():.4e} | "
                f"unit={parts['unit'].item():.4e} | ortho={parts['ortho'].item():.4e} | "
                f"anchor={parts['anchor'].item():.4e} | dt={time.time()-t0:.2f}s"
            )

    print("\nDONE.")
    return atlas, model

In [None]:
# Training takes a few minutes on a Colab GPU; the exact time depends on epochs and batch sizes.
torch.set_grad_enabled(mode=True)  # make sure autograd is on (the demo cell below turns it off)

atlas, model = train()

  scaler = torch.cuda.amp.GradScaler(enabled=(cfg.amp and cfg.device == "cuda"))
  with torch.cuda.amp.autocast(enabled=(cfg.amp and cfg.device == "cuda")):


epoch 001 | loss=1.474181 | mult=1.2551e+00 | scale=4.1702e-01 | unit=1.5731e-15 | ortho=2.9069e-01 | anchor=1.5314e+00 | dt=0.60s
epoch 010 | loss=0.029065 | mult=1.9592e-02 | scale=2.7009e-03 | unit=1.4801e-15 | ortho=8.0426e-01 | anchor=1.5925e-02 | dt=0.33s
epoch 020 | loss=0.011463 | mult=2.6450e-03 | scale=1.0138e-03 | unit=1.5456e-15 | ortho=8.3023e-01 | anchor=1.6837e-03 | dt=0.39s
epoch 030 | loss=0.010793 | mult=2.2317e-03 | scale=4.9357e-04 | unit=1.6432e-15 | ortho=8.3028e-01 | anchor=2.3140e-03 | dt=0.38s
epoch 040 | loss=0.008866 | mult=4.3645e-04 | scale=1.9859e-04 | unit=1.6821e-15 | ortho=8.3288e-01 | anchor=3.7048e-04 | dt=0.39s
epoch 050 | loss=0.008689 | mult=3.0775e-04 | scale=1.0147e-04 | unit=1.6056e-15 | ortho=8.3297e-01 | anchor=1.2500e-04 | dt=0.47s
epoch 060 | loss=0.008469 | mult=1.0979e-04 | scale=5.3550e-05 | unit=1.3874e-15 | ortho=8.3325e-01 | anchor=3.8761e-05 | dt=0.33s
epoch 070 | loss=0.008405 | mult=5.6395e-05 | scale=3.0260e-05 | unit=1.5394e-15 | 

In [None]:
# === SANL Post-Training Demo ===
# Inspect what the trained model has learned; evaluation needs no gradients.

torch.set_grad_enabled(mode=False)

def to_phase_deg(z: torch.Tensor):
    """Return the angle of each (Re, Im) pair in z, in degrees (atan2 range, [-180, 180])."""
    radians = torch.atan2(z[..., 1], z[..., 0])
    return torch.rad2deg(radians)

def mult_score(model, atlas, p: int, trials: int = 2000) -> float:
    """
    Summarise how multiplicative the trained heads are.

    Samples `trials` random units a and b. It computes the RMS, over samples
    and heads, of the 2-vector error chi(a)*chi(b) - chi(ab), then maps it
    through exp(-10 * rms), so 1.0 means exact multiplicativity.
    """
    a = torch.randint(1, p, (trials,), device=cfg.device)
    b = torch.randint(1, p, (trials,), device=cfg.device)
    ab = (a * b) % p

    ha, hb, hab = (unit_project(model(atlas.features(t))) for t in (a, b, ab))  # each [T, H, 2]

    err = complex_mul(ha, hb) - hab
    rms = err.pow(2).sum(dim=-1).mean().sqrt()
    return float(torch.exp(-10.0 * rms).mean().item())

def scale_probe(model, atlas, p: int, scales=(2,3,5,7,11,13,17,19,23), trials: int = 1024):
    """
    Mean cosine similarity between chi(n) and chi(s*n) for each s in `scales`.

    All heads of a sample are flattened into one vector before the cosine is
    taken. Coprimality of s with p is expected but not checked here. A constant
    (principal) character gives +1 for every s.
    """
    n = torch.randint(1, p, (trials,), device=cfg.device)
    base = unit_project(model(atlas.features(n))).reshape(trials, -1)   # [T, H*2]
    base_norm = torch.norm(base, dim=-1)                                 # loop-invariant

    means = []
    for s in scales:
        moved = unit_project(model(atlas.features((s * n) % p))).reshape(trials, -1)
        dot = (base * moved).sum(dim=-1)
        den = base_norm * torch.norm(moved, dim=-1) + 1e-9
        means.append((dot / den).mean().item())
    return means

final_mult = mult_score(model, atlas, cfg.p, trials=5000)
print("\n[final] quick multiplicativity score:", final_mult)

print("[final] scale probe (mean cos per s):")
scales = (2,3,5,7,11,13,17,19,23)
cos_vals = scale_probe(model, atlas, cfg.p, scales=scales, trials=4096)
for s, c in zip(scales, cos_vals):
    print(f"  s={s:2d} -> cos={c:+.3f}")


[final] quick multiplicativity score: 0.9427040815353394
[final] scale probe (mean cos per s):
  s= 2 -> cos=+1.000
  s= 3 -> cos=+1.000
  s= 5 -> cos=+1.000
  s= 7 -> cos=+1.000
  s=11 -> cos=+1.000
  s=13 -> cos=+1.000
  s=17 -> cos=+1.000
  s=19 -> cos=+1.000
  s=23 -> cos=+1.000
