In [1]:
import os, time, json, math
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

from PIL import Image
import matplotlib.cm as cm
from matplotlib.colors import LinearSegmentedColormap

In [2]:
# Preferred compute device: CUDA when available, otherwise CPU. Stored as a torch.device
# (not a bare string) so helpers that inspect `device.type` work when handed this value.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### The Mandelbrot set
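The Mandelbrot set is the set of complex numbers $c$ for which the function $f_c(z) = z^2 + c$ does not diverge to infinity when iterated starting from $z = 0$. Equivalently, it is the set of $c$ for which the sequence $z_0 = 0,\; z_{n+1} = z_n^2 + c$ stays bounded. Drawn in the complex plane, it is a two-dimensional set.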

Interesting properties:
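- A point $c$ belongs to the Mandelbrot set if and only if $|z_n| \leq 2$ for all $n \geq 0$, where $z_0 = 0$ and $z_{n+1} = z_n^2 + c$. Consequently, as soon as some $|z_n| > 2$ (equivalently $|z_n|^2 > 4$), the orbit is guaranteed to escape to infinity.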


### Rendering

In [3]:
@torch.no_grad()
def model_grid_tiled(model, device, xlim, ylim, res, tile=(512, 512), amp=True):
    """Evaluate ``model`` on a regular W x H grid of (x, y) points, one tile at a time.

    Args:
        model: module mapping an (N, 2) tensor of (x, y) coordinates to an (N, 1) output.
        device: torch.device or device string ("cpu", "cuda", ...) to run the model on.
        xlim, ylim: (min, max) extent of the window along x / y.
        res: (W, H) output resolution in pixels.
        tile: (tile_w, tile_h) size of the block evaluated per forward pass (bounds memory).
        amp: run the forward pass under float16 autocast (only takes effect on CUDA).

    Returns:
        float32 array of shape (H, W); row 0 corresponds to ``ylim[0]``.

    Points are sampled at pixel centres, so the grid lines up with
    ``imshow(out, extent=[*xlim, *ylim], origin="lower")``.
    """
    model.eval()
    device = torch.device(device)  # accept strings as well as torch.device
    use_amp = bool(amp) and device.type == "cuda"
    W, H = res
    tw, th = tile

    # Pixel-centre coordinates. They are computed in float64 and stored as float32.
    # Coordinates are deliberately NOT cast to float16: near |x| = 2 the fp16 spacing
    # (~0.002) is coarser than a 4K pixel step, so neighbouring pixels would collapse
    # onto identical inputs. Autocast lowers precision per-op where it is safe.
    dx = (xlim[1] - xlim[0]) / W
    dy = (ylim[1] - ylim[0]) / H
    xs = (xlim[0] + (np.arange(W, dtype=np.float64) + 0.5) * dx).astype(np.float32)
    ys = (ylim[0] + (np.arange(H, dtype=np.float64) + 0.5) * dy).astype(np.float32)

    out = np.empty((H, W), dtype=np.float32)

    for y0 in range(0, H, th):
        y1 = min(y0 + th, H)
        Y = ys[y0:y1]

        for x0 in range(0, W, tw):
            x1 = min(x0 + tw, W)
            X = xs[x0:x1]

            XX, YY = np.meshgrid(X, Y)
            grid = np.stack([XX.reshape(-1), YY.reshape(-1)], axis=1)
            g = torch.from_numpy(grid).to(device, dtype=torch.float32, non_blocking=True)

            if use_amp:
                with torch.autocast(device_type="cuda", dtype=torch.float16):
                    v = model(g).squeeze(1)
            else:
                v = model(g).squeeze(1)

            out[y0:y1, x0:x1] = v.float().cpu().numpy().reshape((y1 - y0, x1 - x0))
            del XX, YY, grid, g, v  # release tile buffers before the next tile

    return out


In [4]:
def compute_ylim_from_x(xlim, res, ycenter=0.0):
    """Return the (ymin, ymax) range, centred on ``ycenter``, that gives square pixels.

    With ``res = (W, H)`` pixels spanning ``xlim`` horizontally, each pixel is
    ``(xlim[1] - xlim[0]) / W`` wide. The y-span is set to H such pixels so that
    x and y step sizes match in the complex plane.
    """
    width_px, height_px = res
    half_span = (xlim[1] - xlim[0]) / width_px * height_px / 2
    return ycenter - half_span, ycenter + half_span

In [5]:
def fractal_palette(name, N=2048):
    """Build a matplotlib colormap from one of the named colour palettes below.

    Args:
        name: palette key: "fire", "cosmic", "ocean", "synthwave" or "ink".
        N: number of entries in the colormap's lookup table (default 2048).

    Returns:
        ``LinearSegmentedColormap`` named ``f"fract_{name}"`` that interpolates linearly
        between the palette's RGB stops. Each palette is listed darkest to brightest.

    Raises:
        KeyError: if ``name`` is not a known palette (the message lists the valid names).
    """
    palettes = {
        # Cinematic fire (classic, hard to beat)
        "fire": [
            (0.02, 0.02, 0.05),
            (0.10, 0.02, 0.20),
            (0.40, 0.05, 0.30),
            (0.80, 0.20, 0.10),
            (0.98, 0.80, 0.30),
        ],

        # Cosmic / nebula
        "cosmic": [
            (0.01, 0.01, 0.04),
            (0.05, 0.02, 0.20),
            (0.20, 0.10, 0.60),
            (0.60, 0.40, 0.90),
            (0.95, 0.85, 0.98),
        ],

        # Deep ocean
        "ocean": [
            (0.01, 0.02, 0.05),
            (0.02, 0.10, 0.20),
            (0.05, 0.40, 0.50),
            (0.30, 0.80, 0.70),
            (0.90, 0.95, 0.85),
        ],

        # Synthwave / neon
        "synthwave": [
            (0.02, 0.00, 0.08),
            (0.20, 0.00, 0.40),
            (0.60, 0.10, 0.80),
            (0.90, 0.30, 0.60),
            (1.00, 0.90, 0.30),
        ],

        # Ink / poster (greyscale)
        "ink": [
            (0.00, 0.00, 0.00),
            (0.10, 0.10, 0.10),
            (0.40, 0.40, 0.40),
            (0.85, 0.85, 0.85),
        ],
    }

    if name not in palettes:
        raise KeyError(f"unknown palette {name!r}; choose one of {sorted(palettes)}")

    return LinearSegmentedColormap.from_list(
        f"fract_{name}", palettes[name], N=N
    )



def glow(img, strength=0.25, radius=3, threshold=0.6):
    """Add a soft halo around the bright regions of an image.

    Args:
        img: array with values in [0, 1] (float32 expected), at least 2-D. The blur acts
            on the first two axes.
        strength: how much of the blurred halo is added back onto ``img``.
        radius: number of 5-point box-blur passes; more passes give a wider halo.
        threshold: only values above this emit glow. Emission ramps linearly from 0 at
            ``threshold`` to full at 1.0.

    Returns:
        ``clip(img + strength * halo, 0, 1)``, with the same shape as ``img``.

    Borders are handled by edge replication rather than wrap-around. This stops a
    bright pixel on one edge from lighting up the opposite edge.
    """
    # Guard threshold >= 1 (original formula divides by zero there and yields NaN).
    denom = max(1.0 - threshold, 1e-8)
    mask = np.clip((img - threshold) / denom, 0.0, 1.0)
    halo = img * mask

    # Pad only the two blurred axes; trailing axes (e.g. channels) are left alone.
    pad_width = [(1, 1), (1, 1)] + [(0, 0)] * (img.ndim - 2)
    for _ in range(radius):
        p = np.pad(halo, pad_width, mode="edge")
        halo = (
            p[:-2, 1:-1] + p[2:, 1:-1] +      # up / down neighbours
            p[1:-1, :-2] + p[1:-1, 2:] +      # left / right neighbours
            p[1:-1, 1:-1]                     # centre
        ) / 5.0

    return np.clip(img + strength * halo, 0.0, 1.0)

In [6]:
def plot_model_heatmap_tiled(
    model, device,
    xlim=(-2.4, 1.0),
    ycenter=0.0,
    res=(3840, 2160),
    tile=(512, 512),
    fname="render.png",
    title="Model",
    amp=False,
    gamma=0.85,
    qlo=0.01,
    qhi=0.99,
    cmap_custom="synthwave",
):
    """Render ``model`` over a window of the complex plane and save it as a colour-mapped image.

    Pipeline:
      1. Evaluate on a ``res``-sized grid with square pixels (y-range derived from
         ``xlim`` / ``ycenter``).
      2. Apply a sigmoid.
      3. Stretch contrast so the [qlo, qhi] quantiles map to [0, 1].
      4. Apply gamma.
      5. Add glow on bright regions.
      6. Colour with ``fractal_palette(cmap_custom)``.

    The image fills the whole canvas (no axes, no margins); figure size is res / dpi
    at dpi=300. ``title`` is not drawn; for ``.png`` outputs it is stored as the PNG
    "Title" metadata entry.
    """
    ylim = compute_ylim_from_x(xlim, res, ycenter=ycenter)

    # Raw model output (logits or regression values) in float32.
    raw = model_grid_tiled(model, device, xlim, ylim, res, tile=tile, amp=amp).astype(np.float32)

    # Sigmoid -> [0, 1], written as 0.5 * (1 + tanh(x / 2)) so large-magnitude inputs
    # cannot overflow np.exp.
    img = 0.5 * (1.0 + np.tanh(0.5 * raw))

    # Robust contrast: map the [qlo, qhi] quantiles to [0, 1] so a few outliers don't
    # flatten the image. The epsilon guards a constant image.
    lo, hi = np.quantile(img, [qlo, qhi])
    img = np.clip((img - lo) / (hi - lo + 1e-8), 0.0, 1.0)

    # Mild gamma (too aggressive a gamma makes grain visible).
    img = img ** gamma

    img = glow(img, strength=0.30, radius=4, threshold=0.5)

    dpi = 300
    fig, ax = plt.subplots(figsize=(res[0] / dpi, res[1] / dpi), dpi=dpi)
    try:
        ax.imshow(
            img,
            extent=[xlim[0], xlim[1], ylim[0], ylim[1]],
            origin="lower",
            interpolation="none",
            aspect="equal",
            cmap=fractal_palette(cmap_custom),
        )
        ax.set_axis_off()
        fig.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=0, hspace=0)

        save_kwargs = {}
        if str(fname).lower().endswith(".png"):
            save_kwargs["metadata"] = {"Title": str(title)}
        fig.savefig(fname, dpi=dpi, bbox_inches=None, pad_inches=0, **save_kwargs)
    finally:
        # Always release the figure, even if rendering fails, so repeated previews
        # during training don't accumulate open figures.
        plt.close(fig)
    print("Saved:", fname)

### Fourier Features

In [7]:
class FourierFeatures(nn.Module):
    """Gaussian random Fourier features at a single scale.

    ``B`` is an (in_dim, num_feats) buffer drawn from N(0, sigma^2) using the global
    torch RNG. ``forward`` maps (..., in_dim) inputs to (..., 2 * num_feats) features
    ``[sin(2*pi*x@B), cos(2*pi*x@B)]``.
    """
    def __init__(self, in_dim=2, num_feats=256, sigma=10.0):
        super().__init__()
        B = torch.randn(in_dim, num_feats) * sigma
        self.register_buffer("B", B)

    def forward(self, x):
        # Compute the projection in float32 even under autocast. With sigma ~ 10 the
        # phase reaches O(100) rad, where float16/bfloat16 spacing (>= 0.06 rad) would
        # turn the sin/cos features into noise. Later layers may still run in low precision.
        dev = x.device.type
        if dev in ("cuda", "cpu"):
            with torch.autocast(device_type=dev, enabled=False):
                return self._encode(x)
        return self._encode(x)

    def _encode(self, x):
        # Float32 projection + sin/cos concatenation (sin half first).
        proj = (2 * torch.pi) * (x.float() @ self.B.float())
        return torch.cat([proj.sin(), proj.cos()], dim=-1)

In [8]:
class MultiScaleGaussianFourierFeatures(nn.Module):
    """Gaussian random Fourier features drawn at several frequency scales.

    The ``num_feats`` frequency columns are split as evenly as possible across
    ``sigmas``, and the first scale absorbs any remainder. The columns for scale ``s``
    are drawn from N(0, s^2) with a dedicated generator seeded by ``seed``, so two
    instances built with the same arguments have identical ``B``. ``forward`` maps
    (..., in_dim) inputs to (..., 2 * num_feats) features
    ``[sin(2*pi*x@B), cos(2*pi*x@B)]``.
    """
    def __init__(self, in_dim=2, num_feats=512, sigmas=(2.0, 6.0, 10.0), seed=0):
        super().__init__()
        # split features across scales
        k = len(sigmas)
        per = [num_feats // k] * k
        per[0] += num_feats - sum(per)

        # Dedicated generator: B depends only on `seed`, not on the global RNG state.
        Bs = []
        g = torch.Generator()
        g.manual_seed(seed)
        for s, m in zip(sigmas, per):
            B = torch.randn(in_dim, m, generator=g) * s
            Bs.append(B)

        self.register_buffer("B", torch.cat(Bs, dim=1))

    def forward(self, x):
        # Compute the projection in float32 even under autocast. With sigma ~ 10 the
        # phase reaches O(100) rad, where float16/bfloat16 spacing (>= 0.06 rad) would
        # turn the sin/cos features into noise. Later layers may still run in low precision.
        dev = x.device.type
        if dev in ("cuda", "cpu"):
            with torch.autocast(device_type=dev, enabled=False):
                return self._encode(x)
        return self._encode(x)

    def _encode(self, x):
        # Float32 projection + sin/cos concatenation (sin half first).
        proj = (2 * torch.pi) * (x.float() @ self.B.float())
        return torch.cat([torch.sin(proj), torch.cos(proj)], dim=-1)

### Creating a dataset

In [9]:
def smooth_escape(x: float, y: float, max_iter: int = 1000) -> float:
    """Smoothed, log-scaled escape value of c = x + iy under z <- z^2 + c, starting at z = 0.

    If |z|^2 first exceeds 4 at (0-based) iteration n, the continuous escape count
    ``mu = n + 1 - log2(log|z|)`` is mapped to ``log1p(mu) / log1p(max_iter)`` and
    clipped to [0, 1]. Points that do not escape within ``max_iter`` iterations
    return 1.0.

    ``mu`` is clamped at 0 before the log. For large |c| (e.g. c = 100) the first
    iterate is so big that mu < -1, where ``math.log1p`` would raise ValueError. For
    mu in (-1, 0] the result was already clipped to 0, so this only extends that case.
    """
    c = complex(x, y)
    z = 0j
    for n in range(max_iter):
        z = z * z + c
        r2 = z.real * z.real + z.imag * z.imag
        if r2 > 4.0:
            r = math.sqrt(r2)
            mu = max(n + 1 - math.log(math.log(r)) / math.log(2.0), 0.0)  # smooth count
            # log-scale to spread small mu
            v = math.log1p(mu) / math.log1p(max_iter)
            return float(np.clip(v, 0.0, 1.0))
    return 1.0

In [10]:
def sample_uniform(n, xlim, ylim, seed=0):
    """Draw ``n`` points uniformly from the rectangle ``xlim`` x ``ylim``.

    Uses a fresh ``np.random.default_rng(seed)``: all x-coordinates are drawn first,
    then all y-coordinates. Returns a float32 array of shape (n, 2) holding (x, y) rows.
    """
    gen = np.random.default_rng(seed)
    x_lo, x_hi = xlim[0], xlim[1]
    y_lo, y_hi = ylim[0], ylim[1]
    col_x = gen.uniform(x_lo, x_hi, n)
    col_y = gen.uniform(y_lo, y_hi, n)
    return np.column_stack((col_x, col_y)).astype(np.float32)

In [11]:
def _smooth_escape_many(points, max_iter=1000, chunk=1 << 18):
    """Vectorised smooth-escape labels for an (N, 2) array of (x, y) points; returns float32 (N,).

    Uses the same recurrence, bailout (|z|^2 > 4) and log-scaled smooth count as the
    scalar ``smooth_escape``, evaluated in NumPy complex128, so labels agree up to
    floating-point rounding. Points are processed ``chunk`` at a time to bound memory,
    and each point leaves the working set as soon as it escapes.
    """
    pts = np.asarray(points)
    out = np.ones(pts.shape[0], dtype=np.float32)  # points that never escape keep 1.0
    log2 = math.log(2.0)
    norm = math.log1p(max_iter)
    for start in range(0, pts.shape[0], chunk):
        block = pts[start:start + chunk].astype(np.float64)
        c = block[:, 0] + 1j * block[:, 1]
        z = np.zeros_like(c)
        idx = np.arange(start, start + block.shape[0])  # output rows still iterating
        for n in range(max_iter):
            if idx.size == 0:
                break
            z = z * z + c
            r2 = z.real * z.real + z.imag * z.imag
            esc = r2 > 4.0
            if esc.any():
                r = np.sqrt(r2[esc])
                # Clamp at 0 so log1p stays defined for very large |c| (mu < -1 otherwise).
                mu = np.maximum(n + 1 - np.log(np.log(r)) / log2, 0.0)
                out[idx[esc]] = np.clip(np.log1p(mu) / norm, 0.0, 1.0)
                live = ~esc
                idx, z, c = idx[live], z[live], c[live]
    return out


def build_boundary_biased_dataset(
    n_total=800_000,
    frac_boundary=0.7,
    xlim=(-2.4, 1.0),
    res_for_ylim=(3840, 2160),
    ycenter=0.0,
    max_iter=1000,
    band=(0.05, 0.98),
    seed=0,
):
    """Build ``n_total`` labelled points, mixing uniform samples with boundary-band samples.

    This is a stable alternative to loss re-weighting:
    - ``frac_boundary`` of the points come from a uniformly drawn pool (``seed + 1``),
      keeping only those whose smooth-escape target lies strictly inside ``band``.
      Such points tend to concentrate near the set's boundary.
    - The rest are plain uniform samples (``seed``) over the same window.

    If the pool yields fewer in-band points than requested, all of them are kept and
    the shortfall is filled with extra uniform samples. The band itself is not widened.

    Returns:
        ``(X, y, ylim)``: float32 points (n_total, 2), float32 targets (n_total,), and
        the y-range derived from ``xlim`` / ``res_for_ylim`` / ``ycenter``. The rows are
        shuffled once with ``seed``.

    Labels are computed with a vectorised NumPy kernel. A per-point Python loop over
    the ~``20 * n_boundary`` pool points was far too slow to finish.
    """
    if not 0.0 <= frac_boundary <= 1.0:
        raise ValueError(f"frac_boundary must be in [0, 1], got {frac_boundary}")

    rng = np.random.default_rng(seed)
    ylim = compute_ylim_from_x(xlim, res_for_ylim, ycenter=ycenter)

    n_boundary = int(n_total * frac_boundary)

    # Boundary pool: oversample uniformly, label, keep points whose target is in the band.
    pool_factor = 20
    pool = sample_uniform(n_boundary * pool_factor, xlim, ylim, seed=seed + 1)
    yp = _smooth_escape_many(pool, max_iter=max_iter)

    mask = (yp > band[0]) & (yp < band[1])
    Xb = pool[mask]
    yb = yp[mask]

    if len(Xb) < n_boundary:
        # Too few in-band points: keep all of them; uniform samples make up the rest.
        keep = min(len(Xb), n_boundary)
        print(f"[warn] Boundary band too strict; got {len(Xb)} boundary points, using {keep}.")
        n_boundary = keep
    Xb = Xb[:n_boundary]
    yb = yb[:n_boundary]

    # Uniform set (its own RNG stream, so drawing it after the pool changes nothing).
    n_uniform = n_total - n_boundary
    Xu = sample_uniform(n_uniform, xlim, ylim, seed=seed)
    yu = _smooth_escape_many(Xu, max_iter=max_iter)

    X = np.concatenate([Xu, Xb], axis=0).astype(np.float32)
    y = np.concatenate([yu, yb], axis=0).astype(np.float32)

    # Shuffle once
    perm = rng.permutation(X.shape[0])
    return X[perm], y[perm], ylim

In [12]:
class IndexedTensorDataset(Dataset):
    """In-memory dataset of ``(point, target, index)`` triples backed by NumPy arrays.

    Both arrays are converted once to float32 torch tensors. ``__getitem__`` also
    returns the requested index, so a batch can be traced back to its source rows.
    """

    def __init__(self, X, y):
        # X: numpy (N, 2) coordinates, y: numpy (N,) targets
        self.X, self.y = (torch.from_numpy(arr.astype(np.float32)) for arr in (X, y))

    def __len__(self):
        return self.X.size(0)

    def __getitem__(self, idx):
        point, target = self.X[idx], self.y[idx]
        return point, target, idx


### Neural Network

In [13]:
class ResidualBlock(nn.Module):
    """Pre-norm residual MLP block: ``x + fc2(ln2(drop(act(fc1(ln1(x))))))``.

    ``act="relu"`` (case-insensitive) selects ReLU; any other value selects SiLU.
    Dropout is used only when ``dropout`` is a positive number.

    ``fc2`` is zero-initialised, so a freshly built block is exactly the identity map.
    The residual branch starts contributing only once training moves ``fc2`` off zero.
    """

    def __init__(self, dim: int, act: str = "silu", dropout: float = 0.0):
        super().__init__()
        act_cls = nn.ReLU if act.lower() == "relu" else nn.SiLU

        # Layer creation order is kept fixed: it determines the random-init draws.
        self.ln1 = nn.LayerNorm(dim)
        self.fc1 = nn.Linear(dim, dim)
        self.ln2 = nn.LayerNorm(dim)
        self.fc2 = nn.Linear(dim, dim)

        self.act = act_cls()
        if dropout and dropout > 0:
            self.drop = nn.Dropout(dropout)
        else:
            self.drop = nn.Identity()

        # Zero the output projection so the block starts as an exact identity.
        for param in (self.fc2.weight, self.fc2.bias):
            nn.init.zeros_(param)

    def forward(self, x):
        branch = self.fc1(self.ln1(x))
        branch = self.drop(self.act(branch))
        branch = self.fc2(self.ln2(branch))
        return x + branch

In [14]:
class MLPFourierRes(nn.Module):
    """Coordinate network with the layout: Fourier features -> linear stem ->
    ``num_blocks`` residual blocks -> LayerNorm + activation -> linear head.

    Input: (N, 2) points. Output: (N, out_dim) raw values; no sigmoid is applied
    (training uses this raw output). The Fourier encoder uses three scales
    ``(2.0, 6.0, sigma)``, so only the last one is configurable here.
    """

    def __init__(
        self,
        num_feats=256,
        sigma=5.0,
        hidden_dim=256,
        num_blocks=8,
        act="silu",
        dropout=0.0,
        out_dim=1,
    ):
        super().__init__()
        # Submodules are created in a fixed order, which determines the random-init draws.
        self.ff = MultiScaleGaussianFourierFeatures(
            2, num_feats=num_feats, sigmas=(2.0, 6.0, sigma), seed=0
        )
        self.in_proj = nn.Linear(2 * num_feats, hidden_dim)  # sin + cos halves -> hidden
        self.blocks = nn.Sequential(
            *(ResidualBlock(hidden_dim, act=act, dropout=dropout) for _ in range(num_blocks))
        )
        self.out_ln = nn.LayerNorm(hidden_dim)
        self.out_act = (nn.ReLU if act.lower() == "relu" else nn.SiLU)()
        self.out_proj = nn.Linear(hidden_dim, out_dim)

    def forward(self, x):
        hidden = self.blocks(self.in_proj(self.ff(x)))
        head_in = self.out_act(self.out_ln(hidden))
        return self.out_proj(head_in)  # raw output, no sigmoid

### Training

In [15]:
@torch.no_grad()
def eval_loss(model, loader, device, criterion):
    """Return the sample-weighted mean loss of ``model`` over all batches in ``loader``.

    Args:
        model: module mapping a batch of inputs to predictions of shape (B, 1) or (B,).
        loader: iterable of ``(inputs, targets, index)`` batches (the index is ignored).
        device: torch.device or device string that batches are moved to.
        criterion: element-wise loss (``reduction="none"``). Each batch's loss is
            averaged, then weighted by batch size.

    Returns:
        Mean loss as a Python float (0.0 for an empty loader).

    The model runs in eval mode, and its previous train/eval mode is restored afterwards.
    """
    was_training = model.training
    model.eval()
    tot = 0.0
    n = 0
    try:
        for Xb, yb, _ in loader:
            Xb = Xb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True)
            pred = model(Xb)
            # A (B, 1) prediction against (B,) targets would broadcast to a (B, B)
            # loss matrix that compares every prediction with every target.
            if pred.dim() == yb.dim() + 1 and pred.size(-1) == 1:
                pred = pred.squeeze(-1)
            loss = criterion(pred, yb).mean()
            tot += float(loss.item()) * Xb.size(0)
            n += Xb.size(0)
    finally:
        model.train(was_training)
    return tot / max(1, n)


In [16]:
def train_model(
    model,
    train_dataset,
    val_dataset,
    run_dir,
    epochs=50,
    batch_size=4096,
    lr=3e-4,
    weight_decay=1e-6,
    grad_clip=1.0,
    amp=True,
    render_every=10,
    render_res=(1920, 1080),
    xlim=(-2.4, 1.0),
    ycenter=0.0,
    tile=(512, 512),
):
    """Train ``model`` with SmoothL1 loss, AdamW and a cosine learning-rate schedule.

    Each epoch appends train/val loss to ``<run_dir>/metrics.csv``. Every
    ``render_every`` epochs, and always after the last epoch, the function writes a
    state-dict checkpoint to ``<run_dir>/ckpt/`` and a preview render to
    ``<run_dir>/images/``.

    Args:
        model: module mapping (B, 2) points to (B, 1) or (B,) raw predictions.
        train_dataset, val_dataset: datasets yielding ``(inputs, target, index)``.
        run_dir: output directory (str or Path); ``ckpt/`` and ``images/`` are created
            if missing.
        epochs, batch_size, lr, weight_decay: optimisation settings.
        grad_clip: max global gradient norm, or None to disable clipping.
        amp: float16 autocast + GradScaler; only takes effect on CUDA.
        render_every: checkpoint/preview period in epochs; 0 or None means only
            after the final epoch.
        render_res, xlim, ycenter, tile: forwarded to the preview renderer.

    Returns:
        The trained model (same object, moved to CUDA if available, else CPU).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    use_amp = bool(amp) and device.type == "cuda"

    run_dir = Path(run_dir)
    ckpt_dir = run_dir / "ckpt"
    img_dir = run_dir / "images"
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    img_dir.mkdir(parents=True, exist_ok=True)

    # Pinned host memory only matters for host -> GPU copies.
    pin = device.type == "cuda"
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=pin)
    val_loader   = DataLoader(val_dataset,   batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=pin)

    # SmoothL1 is very stable for this.
    criterion = nn.SmoothL1Loss(reduction="none")

    opt = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    metrics_path = run_dir / "metrics.csv"
    with open(metrics_path, "w") as f:
        f.write("epoch,train_loss,val_loss\n")

    for epoch in range(1, epochs + 1):
        model.train()
        tot = 0.0
        n = 0

        for Xb, yb, _ in train_loader:
            Xb = Xb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True)

            opt.zero_grad(set_to_none=True)

            with torch.autocast(device_type=device.type, enabled=use_amp):
                pred = model(Xb)
                # The model returns (B, 1) while targets are (B,). Without squeezing,
                # the element-wise loss broadcasts to a (B, B) matrix that compares
                # every prediction with every target, pushing all predictions toward
                # one batch-wide value.
                if pred.dim() == yb.dim() + 1 and pred.size(-1) == 1:
                    pred = pred.squeeze(-1)
                loss = criterion(pred, yb).mean()

            scaler.scale(loss).backward()

            if grad_clip is not None:
                scaler.unscale_(opt)  # clip the true (unscaled) gradients
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

            scaler.step(opt)
            scaler.update()

            tot += float(loss.item()) * Xb.size(0)
            n += Xb.size(0)

        train_loss = tot / max(1, n)
        val_loss = eval_loss(model, val_loader, device, criterion)

        scheduler.step()
        with open(metrics_path, "a") as f:
            f.write(f"{epoch},{train_loss:.8f},{val_loss:.8f}\n")

        print(f"Epoch {epoch:03d} | train {train_loss:.6f} | val {val_loss:.6f}")

        # Checkpoint + preview; render_every <= 0 / None disables the periodic ones.
        periodic = bool(render_every) and render_every > 0 and epoch % render_every == 0
        if epoch == epochs or periodic:
            torch.save(model.state_dict(), ckpt_dir / f"model_epoch_{epoch:03d}.pt")

            out_img = img_dir / f"render_epoch_{epoch:03d}.png"
            plot_model_heatmap_tiled(
                model, device,
                xlim=xlim, ycenter=ycenter, res=render_res, tile=tile,
                fname=str(out_img),
                title=f"Model (epoch {epoch})",
                amp=amp
            )

    return model

### Running

In [17]:
def make_run_dir(base="runs", tag=""):
    """Create a fresh experiment directory ``<base>/<timestamp>[_<tag>]``.

    The directory must not already exist (FileExistsError otherwise). Empty
    ``images/`` and ``ckpt/`` subfolders are created inside it. Returns the new
    directory as a Path.
    """
    stamp = time.strftime("%Y-%m-%d_%H-%M-%S")
    dirname = f"{stamp}_{tag}" if tag else stamp
    root = Path(base).joinpath(dirname)
    root.mkdir(parents=True, exist_ok=False)
    for sub in ("images", "ckpt"):
        (root / sub).mkdir(exist_ok=True)
    return root

In [18]:
# Experiment configuration: every setting the run depends on. The whole dict is written
# to <run_dir>/config.json below (json serialises the tuples as lists).
cfg = {
    # Window: x-range, the resolution used to derive a square-pixel y-range around
    # `ycenter`, and the iteration cap for the smooth-escape labels.
    "xlim": (-2.4, 1.0),
    "res_for_ylim": (3840, 2160),
    "ycenter": 0.0,
    "max_iter_labels": 1000,

    # Dataset: total points, the share drawn from the boundary band, the open
    # (low, high) interval of target values that counts as "boundary", and the seed
    # (also used as random_state for the train/val split).
    "dataset_n_total": 1000000,
    "dataset_frac_boundary": 0.7,
    "boundary_band": (0.35, 0.95),
    "seed": 0,

    # Model: total Fourier frequencies, the third Fourier scale (after 2.0 and 6.0),
    # hidden width, number of residual blocks (passed as `num_blocks`), activation.
    "model_num_feats": 512,
    "model_sigma": 10.0,
    "model_hidden_dim": 512,
    "model_hidden_layers": 20,
    "model_act": "silu",

    # Optimisation, forwarded to train_model (train_amp is also used for the final render).
    "train_epochs": 100,
    "train_batch_size": 4096,
    "train_lr": 3e-4,
    "train_weight_decay": 1e-6,
    "train_grad_clip": 1.0,
    "train_amp": True,

    # Checkpoint + preview render every `preview_every` epochs, at this resolution / tile size.
    "preview_every": 1,
    "preview_res": (1920, 1080),
    "preview_tile": (512, 512),

    # Final render after training.
    "final_res": (3840, 2160),
    "final_tile": (512, 512),
}

# Timestamped output directory (make_run_dir refuses to reuse an existing one).
run_dir = make_run_dir(tag="mandelbrot_competitive")
with open(run_dir / "config.json", "w") as f:
    json.dump(cfg, f, indent=2)

print("Run dir:", run_dir)

# Build the boundary-biased dataset from the config above.
dataset_kwargs = {
    "n_total": cfg["dataset_n_total"],
    "frac_boundary": cfg["dataset_frac_boundary"],
    "xlim": cfg["xlim"],
    "res_for_ylim": cfg["res_for_ylim"],
    "ycenter": cfg["ycenter"],
    "max_iter": cfg["max_iter_labels"],
    "band": cfg["boundary_band"],
    "seed": cfg["seed"],
}
X, y, ylim = build_boundary_biased_dataset(**dataset_kwargs)

# Quick sanity summary of the target distribution.
y_stats = (
    ("min", float(y.min())),
    ("max", float(y.max())),
    ("mean", float(y.mean())),
    ("p50", float(np.quantile(y, 0.50))),
    ("p99", float(np.quantile(y, 0.99))),
)
print("y stats:", *(item for pair in y_stats for item in pair))

# 85 / 15 train / validation split, shuffled with the config seed.
X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.15, random_state=cfg["seed"], shuffle=True
)

train_ds, val_ds = IndexedTensorDataset(X_train, y_train), IndexedTensorDataset(X_val, y_val)

# Model
# Seed torch's global RNG first, so the Linear-layer initialisation and the DataLoader
# shuffle order inside train_model are reproducible across runs. The Fourier matrix and
# the dataset already use their own explicit seeds.
torch.manual_seed(cfg["seed"])
model = MLPFourierRes(
    num_feats=cfg["model_num_feats"],
    sigma=cfg["model_sigma"],
    hidden_dim=cfg["model_hidden_dim"],
    num_blocks=cfg["model_hidden_layers"],
    act=cfg["model_act"],
    dropout=0.0,
)

# Train
model = train_model(
    model,
    train_ds,
    val_ds,
    run_dir=run_dir,
    epochs=cfg["train_epochs"],
    batch_size=cfg["train_batch_size"],
    lr=cfg["train_lr"],
    weight_decay=cfg["train_weight_decay"],
    grad_clip=cfg["train_grad_clip"],
    amp=cfg["train_amp"],
    render_every=cfg["preview_every"],
    render_res=cfg["preview_res"],
    xlim=cfg["xlim"],
    ycenter=cfg["ycenter"],
    tile=cfg["preview_tile"],
)

# Final 4K render with the trained weights.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
final_path = run_dir / "images" / "final_4k.png"

final_render_kwargs = {
    "xlim": cfg["xlim"],
    "ycenter": cfg["ycenter"],
    "res": cfg["final_res"],
    "tile": cfg["final_tile"],
    "fname": str(final_path),
    "title": "Final 4K Model Render",
    "amp": cfg["train_amp"],
}
plot_model_heatmap_tiled(model, device, **final_render_kwargs)

Run dir: runs/2025-12-31_17-40-22_mandelbrot_competitive


KeyboardInterrupt: 

In [None]:
# Save the final weights to models/ (in addition to the run's own checkpoints).
models_dir = Path("models")
models_dir.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), models_dir / "fourier_mlp_final.pt")