In [None]:
import os
import math
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report, f1_score, roc_auc_score
from torch.utils.data import TensorDataset, DataLoader, Dataset, WeightedRandomSampler
from typing import Sequence, Tuple


In [None]:
# Select the compute device: GPU when available, otherwise CPU.
# A torch.device object (rather than a bare string) is used so that attributes
# such as `device.type` are available to code that receives it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### The Mandelbrot set
The Mandelbrot set is a two-dimensional set that is defined in the complex plane as the complex numbers $c$ for which the function $f_c(z) = z^2 + c$ does not diverge to infinity when iterated starting at $z=0$.

Interesting properties:
- A point $c$ belongs to the Mandelbrot set iff $|z_n| \leq 2$ for all $n \geq 0$, where $z_0 = 0$ and $z_{n+1} = z_n^2 + c$


### Visualization

In [None]:
@torch.no_grad()
def model_grid_tiled(model, device, xlim, ylim, res, tile=(512, 512), amp=True):
    """
    Evaluate ``model`` on a regular 2-D grid, one rectangular tile at a time.

    Args:
        model: callable taking a ``(B, 2)`` tensor of ``(x, y)`` coordinates and
            returning one value per row (``(B, 1)`` or ``(B,)``). It is switched
            to eval mode; it is NOT moved to ``device`` by this function.
        device: ``torch.device`` or a device string such as ``"cpu"`` / ``"cuda"``.
        xlim, ylim: ``(min, max)`` bounds of the grid. The upper bound is excluded
            (samples are taken with ``endpoint=False``).
        res: ``(W, H)`` number of samples along x and y.
        tile: ``(tile_w, tile_h)`` maximum tile size, bounding the batch size of a
            single forward pass to ``tile_w * tile_h`` points.
        amp: when True *and* the device is CUDA, coordinates are cast to float16
            and the model runs under ``torch.autocast``; otherwise float32 is used.

    Returns:
        ``np.ndarray`` of shape ``(H, W)`` and dtype float32, where ``out[j, i]``
        is the model output at ``(xs[i], ys[j])``.

    Note:
        float16 keeps only about three significant decimal digits, so with
        ``amp=True`` neighbouring coordinates of a fine grid may collapse to the
        same value. Prefer ``amp=False`` when that matters.
    """
    # Accept plain strings as well as torch.device objects.
    device = torch.device(device)
    use_amp = bool(amp) and device.type == "cuda"
    in_dtype = torch.float16 if use_amp else torch.float32

    model.eval()
    W, H = res
    tw, th = tile

    xs = np.linspace(xlim[0], xlim[1], W, endpoint=False, dtype=np.float32)
    ys = np.linspace(ylim[0], ylim[1], H, endpoint=False, dtype=np.float32)

    out = np.empty((H, W), dtype=np.float32)

    for y0 in range(0, H, th):
        y1 = min(y0 + th, H)
        Y = ys[y0:y1]

        for x0 in range(0, W, tw):
            x1 = min(x0 + tw, W)
            X = xs[x0:x1]

            # Row-major flattening: row j of the tile holds all x for ys[y0 + j].
            XX, YY = np.meshgrid(X, Y)
            grid = np.stack([XX.reshape(-1), YY.reshape(-1)], axis=1)

            g = torch.from_numpy(grid).to(device, dtype=in_dtype, non_blocking=True)

            if use_amp:
                with torch.autocast(device_type="cuda", dtype=torch.float16):
                    v = model(g)
            else:
                v = model(g)

            # reshape() handles both (B, 1) and (B,) model outputs.
            out[y0:y1, x0:x1] = v.float().cpu().numpy().reshape((y1 - y0, x1 - x0))
            # Release tile buffers before allocating the next tile.
            del XX, YY, grid, g, v

    return out


In [None]:
def compute_ylim_from_x(xlim, res, ycenter=0.0):
    """
    Derive the vertical extent that gives square pixels for a given x-range.

    The size of one pixel is taken from the horizontal axis
    (``(xlim[1] - xlim[0]) / W``) and reused vertically, so the returned
    interval spans ``H`` pixels centred on ``ycenter``.

    Args:
        xlim: ``(xmin, xmax)`` horizontal bounds.
        res: ``(W, H)`` resolution in pixels.
        ycenter: centre of the vertical interval.

    Returns:
        ``(ymin, ymax)`` tuple.
    """
    width_px, height_px = res
    pixel_size = (xlim[1] - xlim[0]) / width_px
    half_extent = pixel_size * height_px / 2
    lower = ycenter - half_extent
    upper = ycenter + half_extent
    return (lower, upper)

In [None]:
def plot_model_heatmap_tiled(
    model, device,
    xlim=(-2.4, 1.0),
    ycenter=0.0,
    res=(3840, 2160),
    tile=(512, 512),
    fname="render.png",
    title="Model",
    amp=False,
    gamma=0.85,
    qlo=0.01,
    qhi=0.99,
    cmap="inferno",
):
    """
    Render the model output over a rectangle of the plane and save it as an image.

    Pipeline: evaluate the model tile by tile -> sigmoid -> quantile-based
    contrast stretch to [0, 1] -> gamma -> borderless ``imshow`` saved to disk.

    Args:
        model, device: forwarded to ``model_grid_tiled``.
        xlim: ``(xmin, xmax)``; the y-range is derived from it so pixels are square.
        ycenter: centre of the vertical range.
        res: ``(W, H)`` output resolution in pixels.
        tile: tile size forwarded to ``model_grid_tiled``.
        fname: output path (or file-like object) passed to ``Figure.savefig``.
            When it is a path, its parent directory is created if missing.
        title: currently unused (the axes are hidden); accepted for
            backward compatibility with existing callers.
        amp: forwarded to ``model_grid_tiled``.
        gamma: exponent applied after the contrast stretch.
        qlo, qhi: lower / upper quantiles mapped to 0 and 1 respectively.
        cmap: matplotlib colormap name.

    Returns:
        None. Writes the image and prints the destination.
    """
    ylim = compute_ylim_from_x(xlim, res, ycenter=ycenter)

    # Raw model outputs as a float32 (H, W) array.
    pred = np.asarray(
        model_grid_tiled(model, device, xlim, ylim, res, tile=tile, amp=amp),
        dtype=np.float32,
    )

    # Numerically stable sigmoid: exp() is only ever called on non-positive
    # values, so large-magnitude outputs cannot overflow.
    decay = np.exp(-np.abs(pred))
    pred = np.where(pred >= 0, 1.0 / (1.0 + decay), decay / (1.0 + decay))

    # Robust contrast stretch: avoids flattening without amplifying tiny noise.
    lo, hi = np.quantile(pred, [qlo, qhi])
    pred = (pred - lo) / (hi - lo + 1e-8)
    pred = np.clip(pred, 0.0, 1.0)

    # Mild gamma (too aggressive a gamma makes grain visible).
    pred = pred ** gamma

    # figsize * dpi == res, so one array element maps to one output pixel.
    dpi = 300
    figsize = (res[0] / dpi, res[1] / dpi)
    fig, ax = plt.subplots(figsize=figsize, dpi=dpi)

    ax.imshow(
        pred,
        extent=[xlim[0], xlim[1], ylim[0], ylim[1]],
        origin="lower",
        interpolation="none",
        aspect="equal",
        cmap=cmap,
    )
    ax.set_axis_off()
    # Operate on this figure explicitly instead of pyplot's "current figure".
    fig.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=0, hspace=0)

    # Make sure the destination directory exists when fname is a filesystem path.
    if isinstance(fname, (str, os.PathLike)):
        out_dir = os.path.dirname(os.fspath(fname))
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

    fig.savefig(fname, dpi=dpi, bbox_inches=None, pad_inches=0)
    plt.close(fig)
    print("Saved:", fname)

### Fourier Features

In [None]:
class FourierFeatures(nn.Module):
    """
    Gaussian random Fourier feature embedding.

    Maps an input ``x`` of shape ``(..., in_dim)`` to
    ``[sin(2*pi*x @ B), cos(2*pi*x @ B)]`` of shape ``(..., 2 * num_feats)``,
    where ``B`` is a fixed random matrix with entries drawn from
    ``N(0, sigma**2)``.

    Args:
        in_dim: dimensionality of the input coordinates.
        num_feats: number of random frequencies (the output has twice as many
            channels: one sine and one cosine per frequency).
        sigma: standard deviation of the random frequencies; larger values
            produce higher-frequency features.
    """
    def __init__(self, in_dim=2, num_feats=256, sigma=10.0):
        super().__init__()
        # Scale the standard-normal draw by sigma (previously sigma was ignored).
        B = torch.randn(in_dim, num_feats) * sigma
        # Buffer, not Parameter: it is not trained, but it follows .to(device)
        # and is stored in the state_dict.
        self.register_buffer("B", B)

    def forward(self, x):
        proj = (2 * np.pi * x) @ self.B
        return torch.cat([proj.sin(), proj.cos()], dim=-1)

### Creating a dataset

In [None]:
def mandelbrot_grid_dataset(nx=750, ny=750, xlim=(-2.0, 1.0), ylim =(-1.5, 1.5), max_iter=1000):
    """
    Classify every point of a regular grid as inside / outside the Mandelbrot set.

    Args:
        nx, ny: number of grid samples along the real and imaginary axes.
        xlim, ylim: ``(min, max)`` bounds of the grid (both ends included).
        max_iter: number of iterations of ``z -> z*z + c`` to run.

    Returns:
        ``(X, Y, in_set)`` where ``X`` and ``Y`` are the ``(ny, nx)`` coordinate
        grids and ``in_set`` is a boolean array of the same shape that is True
        where ``|z|`` never exceeded 2 during ``max_iter`` iterations.
    """
    real_axis = np.linspace(xlim[0], xlim[1], nx)
    imag_axis = np.linspace(ylim[0], ylim[1], ny)
    X, Y = np.meshgrid(real_axis, imag_axis, indexing="xy")

    c_plane = X + 1j * Y
    z_plane = np.zeros_like(c_plane)
    # True while a point's orbit has stayed within radius 2.
    bounded = np.ones(c_plane.shape, dtype=bool)

    step = 0
    while step < max_iter:
        # Advance only the orbits that have not escaped yet.
        z_live = z_plane[bounded]
        z_live = z_live * z_live + c_plane[bounded]
        z_plane[bounded] = z_live
        # Retire the orbits that just left the radius-2 disc.
        bounded[bounded] = np.abs(z_live) <= 2.0
        step += 1

    # Points still bounded after max_iter iterations count as members.
    return X, Y, bounded

In [None]:
def smooth_escape(x: float, y: float, max_iter: int = 1000) -> float:
    """
    Normalised smooth escape-time value of the point ``c = x + iy``.

    Iterates ``z -> z*z + c`` from ``z = 0``. If ``|z|`` exceeds 2 at iteration
    ``n`` (0-based), the fractional iteration count
    ``mu = n + 1 - log2(log|z|)`` is computed, compressed with ``log1p`` and
    divided by ``log1p(max_iter)``.

    Args:
        x: real part of ``c``.
        y: imaginary part of ``c``.
        max_iter: maximum number of iterations.

    Returns:
        A float in ``[0, 1]``; exactly ``1.0`` when the orbit did not escape
        within ``max_iter`` iterations.
    """
    c = complex(x, y)
    z = 0j
    for n in range(max_iter):
        z = z*z + c
        r2 = z.real*z.real + z.imag*z.imag
        if r2 > 4.0:
            # r > 2 here, so log(r) > 0 and the nested logarithm is defined.
            r = math.sqrt(r2)
            mu = n + 1 - math.log(math.log(r)) / math.log(2.0)  # smooth
            # Points far from the origin escape with a huge |z| on the first
            # step, which drives mu below zero (and below -1 for |z| >= e^4,
            # where log1p would raise). Negative mu always mapped to 0 after
            # clipping, so clamp up front.
            mu = max(mu, 0.0)
            # log-scale to spread small mu
            v = math.log1p(mu) / math.log1p(max_iter)
            return min(max(v, 0.0), 1.0)
    return 1.0

In [None]:
def build_multiscale_dataset(
    n_points=1000000,
    xlim=(-2.0, 1.0),
    ylim=(-1.5, 1.5),
    zoom_centers: Sequence[Tuple[float, float]] = (
        (-0.743643887037151, 0.131825904205330),  # seahorse valley
        (-0.7435669,          0.1314023),         # nearby spiral
    ),
    zoom_scale_range=(0.01, 1.0),  # fraction of full window size, log-uniform
    p_global=0.3,                  # probability of sampling from full domain
    max_iter=1000,
    random_state=42,
):
    """
    Sample 2-D coordinates from a mixture of a global and several zoomed windows.

    Each point is, with probability ``p_global``, drawn uniformly from the full
    ``xlim`` x ``ylim`` rectangle; otherwise it is drawn uniformly from a window
    centred on a randomly chosen entry of ``zoom_centers`` whose size is the
    full window scaled by a log-uniform factor from ``zoom_scale_range``.

    Only coordinates are produced; no labels are computed here
    (``max_iter`` is accepted but not used by this function).

    Returns:
        X: (N, 2) float32 array of coordinates
    """
    rng = np.random.default_rng(random_state)
    total = n_points

    span_x = xlim[1] - xlim[0]
    span_y = ylim[1] - ylim[0]

    # One Bernoulli draw per point: True -> full domain, False -> zoom window.
    from_full_domain = rng.random(total) < p_global
    global_rows = np.flatnonzero(from_full_domain)
    zoom_rows = np.flatnonzero(~from_full_domain)
    n_global = global_rows.size
    n_zoom = zoom_rows.size

    X = np.empty((total, 2), dtype=np.float32)

    # Uniform samples over the whole rectangle.
    if n_global > 0:
        X[global_rows, 0] = rng.uniform(xlim[0], xlim[1], size=n_global)
        X[global_rows, 1] = rng.uniform(ylim[0], ylim[1], size=n_global)

    # Samples from randomly sized windows around the zoom centres.
    if n_zoom > 0:
        which_center = rng.integers(0, len(zoom_centers), size=n_zoom)
        anchors = np.array(zoom_centers)[which_center]  # (M, 2)

        # Window size as a fraction of the full domain, log-uniform.
        frac_lo, frac_hi = zoom_scale_range
        fractions = 10 ** rng.uniform(np.log10(frac_lo), np.log10(frac_hi), size=n_zoom)

        window_w = span_x * fractions
        window_h = span_y * fractions

        # Uniform offset in [-0.5, 0.5) of the window size around each anchor.
        X[zoom_rows, 0] = anchors[:, 0] + (rng.random(n_zoom) - 0.5) * window_w
        X[zoom_rows, 1] = anchors[:, 1] + (rng.random(n_zoom) - 0.5) * window_h

    print(f"Multiscale dataset: {total} points "
          f"({n_global} global, {n_zoom} zoomed)")

    return X


In [None]:
class IndexedTensorDataset(Dataset):
    """
    In-memory dataset that also returns the index of each sample.

    ``__getitem__`` yields ``(features, target, idx)``, which lets a training
    loop keep per-sample statistics keyed by ``idx``.
    """
    def __init__(self, X, y):
        # Both inputs are NumPy arrays; they are copied to float32 before
        # being wrapped as tensors.
        features = X.astype(np.float32)
        targets = y.astype(np.float32)
        self.X = torch.from_numpy(features)
        self.y = torch.from_numpy(targets)

    def __len__(self):
        return self.X.size(0)

    def __getitem__(self, idx):
        sample = (self.X[idx], self.y[idx], idx)
        return sample


### Neural Network

In [None]:
class NeuralNetFourierFeatures(nn.Module):
    """
    ReLU MLP operating on a Fourier-feature embedding of 2-D inputs.

    Layout: FourierFeatures(2 -> 2*num_feats) -> Linear(2*num_feats, hidden_dim)
    + ReLU -> ``num_hidden_layers`` x [Linear(hidden_dim, hidden_dim) + ReLU]
    -> Linear(hidden_dim, 1). The output is a raw (unactivated) scalar per input.
    """
    def __init__(self, num_hidden_layers=3, num_feats=256, hidden_dim=32):
        super().__init__()
        self.ff = FourierFeatures(in_dim=2, num_feats=num_feats)

        # Input projection from the embedding (sin and cos halves) to hidden width.
        stack = [nn.Linear(2 * num_feats, hidden_dim), nn.ReLU()]

        # Equal-width hidden blocks.
        for _ in range(num_hidden_layers):
            stack.extend((nn.Linear(hidden_dim, hidden_dim), nn.ReLU()))

        # Scalar output head, no activation.
        stack.append(nn.Linear(hidden_dim, 1))

        self.network = nn.Sequential(*stack)

    def forward(self, x):
        return self.network(self.ff(x))

### Training Loop

In [None]:
def _evaluate_mean_l1(model, loader, device):
    """Return the mean absolute error of ``model`` over all samples in ``loader``."""
    model.eval()
    total = 0.0
    count = 0
    with torch.no_grad():
        for Xv, yv, _ in loader:
            Xv = Xv.to(device)
            yv = yv.to(device).float()

            predv = model(Xv).float()

            # Align shapes to (B, 1) so the subtraction cannot broadcast to (B, B).
            if predv.dim() == 1:
                predv = predv.unsqueeze(1)
            if yv.dim() == 1:
                yv = yv.unsqueeze(1)

            batch_mean = torch.abs(predv - yv).mean().item()
            total += batch_mean * Xv.size(0)
            count += Xv.size(0)
    return total / count


def train_regression_with_reweighting(
    model,
    train_dataset,
    val_dataset,
    num_epoch=100,
    batch_size=4096,
    lr=1e-3,
    alpha=1.0,      # how strongly to emphasize hard examples
    eps=1e-4,       # avoid zero weights
    visualize=False,
    visualize_epochs=5,
    save_checkpoints=True,
):
    """
    Train ``model`` with an L1 loss, resampling hard examples more often.

    Every epoch draws ``len(train_dataset)`` samples with replacement using a
    ``WeightedRandomSampler``. After the epoch, each sample's weight is set to
    ``(mean_loss_this_epoch + eps) ** alpha`` (normalised to sum to 1), so
    samples with a larger error are drawn more often in the next epoch.

    Args:
        model: network mapping a ``(B, 2)`` batch to ``(B, 1)`` or ``(B,)``.
        train_dataset, val_dataset: datasets yielding ``(x, y, idx)`` where
            ``idx`` is the sample's position in the dataset.
        num_epoch: number of epochs.
        batch_size: batch size for both training and validation.
        lr: Adam learning rate (weight decay is fixed at 1e-5).
        alpha: exponent applied to the per-sample loss when computing weights.
        eps: additive floor that keeps every weight strictly positive.
        visualize: if True, render a heatmap every ``visualize_epochs`` epochs.
        visualize_epochs: period (in epochs) of visualisation AND checkpointing.
        save_checkpoints: if True, save the state_dict to ``images/checkpoints``
            every ``visualize_epochs`` epochs and after the final epoch.

    Returns:
        The same ``model`` instance, trained in place. Learning curves are
        plotted as a side effect.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    N = len(train_dataset)
    weights = torch.ones(N, dtype=torch.float32)
    criterion = nn.L1Loss(reduction="none")
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)

    # The validation data never changes, so its loader is built once.
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # torch.save does not create missing parent directories.
    ckpt_dir = "images/checkpoints"
    if save_checkpoints:
        os.makedirs(ckpt_dir, exist_ok=True)

    train_losses = []
    val_losses = []

    for epoch in range(num_epoch):
        # The sampler must be rebuilt each epoch because the weights change.
        sampler = WeightedRandomSampler(weights=weights, num_samples=N, replacement=True)
        loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler)
        model.train()

        loss_per_sample = torch.zeros(N, dtype=torch.float32)
        count_per_sample = torch.zeros(N, dtype=torch.float32)
        total_loss = 0
        total_count = 0

        for Xb, yb, idxb in loader:
            Xb = Xb.to(device)
            yb = yb.to(device).float()

            optimizer.zero_grad()
            preds = model(Xb).float()

            # Align shapes to (B, 1) so the loss cannot broadcast to (B, B).
            if preds.dim() == 1:
                preds = preds.unsqueeze(1)
            if yb.dim() == 1:
                yb = yb.unsqueeze(1)

            per_loss = criterion(preds, yb).view(-1)

            loss = per_loss.mean()
            loss.backward()
            # Very tight clipping threshold (max gradient norm 0.01).
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.01)
            optimizer.step()

            # Sampling with replacement can put the same index in a batch more
            # than once. `t[idx] += v` keeps only one contribution per
            # duplicate, so accumulate with index_add_ instead.
            idx_cpu = torch.as_tensor(idxb, dtype=torch.long).view(-1).cpu()
            batch_losses = per_loss.detach().cpu()
            loss_per_sample.index_add_(0, idx_cpu, batch_losses)
            count_per_sample.index_add_(0, idx_cpu, torch.ones_like(batch_losses))

            total_loss  += loss.item() * Xb.size(0)
            total_count += Xb.size(0)

        avg_train_loss = total_loss / total_count

        # Turn accumulated sums into per-sample means for the samples seen.
        seen = count_per_sample > 0
        loss_per_sample[seen] /= count_per_sample[seen]

        # NOTE(review): samples that were not drawn this epoch keep a loss of 0
        # and therefore receive the minimum weight eps ** alpha for the next
        # epoch. Confirm this is intended rather than carrying over their
        # previous weight.
        new_weights = (loss_per_sample + eps) ** alpha
        new_weights /= new_weights.sum()
        weights = new_weights

        avg_val_loss = _evaluate_mean_l1(model, val_loader, device)

        print(f"Epoch: {epoch+1}, Train Loss: {avg_train_loss:.6f}, Val Loss: {avg_val_loss:.6f}")
        train_losses.append(avg_train_loss)
        val_losses.append(avg_val_loss)

        # every visualize_epochs epochs:
        if visualize and (epoch + 1) % visualize_epochs == 0:
            plot_model_heatmap_tiled(model, device, res=(1920, 840), title=f"Model epoch {epoch}", amp=True)
        if save_checkpoints and ((epoch + 1) % visualize_epochs == 0 or (epoch + 1) == num_epoch):
            ckpt_path = os.path.join(ckpt_dir, f"ckpt_epoch_{epoch+1:03d}.pt")
            torch.save(model.state_dict(), ckpt_path)

    # Plot the learning curves
    plt.figure(figsize=(6, 4))
    plt.plot(train_losses, label="Training loss")
    plt.plot(val_losses, label="Validation loss")
    plt.title("Learning Curves")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.show()

    return model

### Testing Loop

In [None]:
def test(model, X_test, y_test, threshold=0.5):
    """
    Score ``model`` on a held-out set with binary-classification metrics.

    The model outputs are used directly as scores for the AUC and are
    thresholded with ``> threshold`` to obtain hard 0/1 predictions for the
    remaining metrics. The module-level ``device`` is used for inference.

    NOTE(review): the scikit-learn metrics used here expect binary labels;
    presumably ``y_test`` holds values in {0, 1} - confirm against the caller.

    Returns:
        dict with accuracy, positive-class F1, AUC and the four confusion
        matrix counts.
    """
    # Inputs as float32 tensors; labels get a trailing axis -> (N, 1).
    features = torch.tensor(X_test, dtype=torch.float32)
    targets = torch.tensor(y_test, dtype=torch.float32).unsqueeze(1)

    model.to(device)
    model.eval()

    with torch.no_grad():
        scores = model(features.to(device))
        # Hard decisions from the raw scores.
        decisions = (scores > threshold).float()

    y_true_np = targets.cpu().numpy()
    y_pred_np = decisions.cpu().numpy()
    y_prob_np = scores.cpu().numpy()

    auc = roc_auc_score(y_true_np, y_prob_np)
    print(f"Test AUC: {auc:.4f}\n")

    accuracy = accuracy_score(y_true_np, y_pred_np)
    print(f"Test Accuracy: {accuracy:.4f}\n")

    print("Classification Report:")
    print(classification_report(y_true_np, y_pred_np))

    # Computed once and reused for both the printout and the counts.
    matrix = confusion_matrix(y_true_np, y_pred_np)
    print("Confusion Matrix:")
    print(matrix)

    f1 = f1_score(y_true_np, y_pred_np, pos_label=1)
    tn, fp, fn, tp = matrix.ravel()

    return {
        "accuracy": accuracy,
        "f1_score_positive": f1,
        "auc": auc,
        "true_positives": tp,
        "false_positives": fp,
        "true_negatives": tn,
        "false_negatives": fn
    }

### Visualization

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

xlim, ylim = (-2.0, 1.0), (-1.5, 1.5)

# Coordinates: 30% uniform over the full window, 70% from zoomed windows.
X = build_multiscale_dataset(
    n_points=500000,
    xlim=xlim,
    ylim=ylim,
    zoom_scale_range=(0.01, 1.0),  # from 1% to full window
    p_global=0.3,
)

# Continuous regression targets in [0, 1]; exactly 1.0 for points that never escape.
y = np.array(
    [smooth_escape(px, py, max_iter=200) for px, py in X],
    dtype=np.float32
)

# Stratified split.
# The targets are continuous, so they cannot be used as strata directly
# (train_test_split would see almost every value as its own one-member class
# and raise). Stratify on the binary "did not escape" indicator instead.
inside_set = y >= 1.0
X_train, X_tmp, y_train, y_tmp = train_test_split(
    X, y, test_size=0.20, stratify=inside_set, random_state=42, shuffle=True
)
X_val, X_test, y_val, y_test = train_test_split(
    X_tmp, y_tmp, test_size=0.50, stratify=(y_tmp >= 1.0), random_state=42, shuffle=True
)

# Create dataset
train_dataset = IndexedTensorDataset(X_train, y_train)
val_dataset   = IndexedTensorDataset(X_val,   y_val)
test_dataset  = IndexedTensorDataset(X_test,  y_test)

# Model
model = NeuralNetFourierFeatures(num_hidden_layers=4, hidden_dim=32)

model = train_regression_with_reweighting(
    model,
    train_dataset,
    val_dataset,
    num_epoch=25,      # e.g. 100-200
    batch_size=4096,
    lr=1e-3,
    alpha=1.0,          # emphasis on hard examples
    eps=1e-4,
)

In [None]:
# Persist the trained weights (state_dict only), creating the target folder if needed
os.makedirs(name="models", exist_ok=True)
torch.save(obj=model.state_dict(), f="models/fourier_mlp_final.pt")
