Unlike the paper, the code calculates primes dynamically instead of using a lookup table.
This code generates Figure 1.

In [None]:
# log-distribution landscape and rms error combined
import torch
import torch.nn as nn
import math
import matplotlib.pyplot as plt
from matplotlib import cm
import pandas as pd
import numpy as np
from scipy.interpolate import griddata
import os
import random

# ==========================================
# 0. SETUP & SEEDING
# ==========================================
# Prefer the GPU when CUDA is usable; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

def seed_everything(seed=42):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random``, NumPy and PyTorch, and stores the seed in the
    ``PYTHONHASHSEED`` environment variable. When CUDA is available it also
    calls ``torch.cuda.manual_seed`` and switches cuDNN to deterministic,
    non-benchmarking mode.

    Args:
        seed: Integer seed handed to every generator.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # The individual seeding calls do not depend on one another.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Seed all RNGs once at import time so the randomly sampled baseline vectors
# are repeatable from run to run.
seed_everything(42)

# ==========================================
# 1. MODEL DEFINITIONS
# ==========================================

def get_first_n_primes(n):
    """Return the first ``n`` primes as a float32 tensor on the global ``device``.

    The primes are found with a sieve of Eratosthenes over ``[0, bound)``.
    ``bound`` starts from an estimate of where the n-th prime lies and is
    doubled (and the sieve rerun) until at least ``n`` primes are found.

    Args:
        n: Number of primes requested. Values below 1 yield an empty tensor.

    Returns:
        1-D tensor holding the ``n`` smallest primes in increasing order.
    """
    if n < 1:
        return torch.tensor([], device=device)

    def _primes_below(bound):
        # Boolean sieve: index i stays True only while i is still a prime candidate.
        is_prime = torch.ones(bound, dtype=torch.bool, device=device)
        is_prime[0:2] = False
        for candidate in range(2, int(math.isqrt(bound)) + 1):
            if is_prime[candidate]:
                # Multiples below candidate**2 were already struck out by
                # smaller factors, so start striking at the square.
                is_prime[candidate * candidate : bound : candidate] = False
        return torch.nonzero(is_prime).flatten()

    # Initial search bound; small n just uses a fixed window of 20.
    if n > 5:
        bound = int(n * (math.log(n) + math.log(math.log(n)))) + 20
    else:
        bound = 20

    found = _primes_below(bound)
    while len(found) < n:
        # Estimate was too small: widen the window and sieve again.
        bound *= 2
        found = _primes_below(bound)

    return found[:n].float()

class Method_RandomGaussian(nn.Module):
    """Baseline that draws fresh random unit vectors on every forward call."""

    def __init__(self, dim):
        """Store the embedding dimension ``dim``; the module has no parameters."""
        super().__init__()
        self.dim = dim

    def forward(self, n_seq):
        """Return an ``(n_seq, dim)`` tensor of L2-normalised Gaussian rows.

        The samples are created on the module-level ``device``.
        """
        samples = torch.randn(n_seq, self.dim, device=device)
        return torch.nn.functional.normalize(samples, p=2, dim=1)

# class StaticPrime(nn.Module):
#     def __init__(self, dim):
#         super().__init__()
#         primes = get_first_n_primes(dim // 2)
#         self.register_buffer("freqs", (torch.sqrt(primes))) 

#     def forward(self, n_seq):
#         t = torch.arange(n_seq, device=self.freqs.device).double()
#         freqs_double = self.freqs.double()
#         angles = torch.outer(t, freqs_double) * 2 * math.pi
#         return torch.cat([angles.cos().float(), angles.sin().float()], dim=-1)

def get_models(d):
    """Build one instance of each embedding method being compared.

    This factory is called by the data-generation functions in this file, so
    it must exist as live code (it was previously commented out, which made
    those calls fail with ``NameError``).

    Args:
        d: Embedding dimension passed to every model constructor.

    Returns:
        dict mapping each method's display name ("RandNorm Gaussian",
        "Static Prime") to a freshly constructed module.
    """
    return {
        "RandNorm Gaussian": Method_RandomGaussian(d),
        "Static Prime": StaticPrime(d),
    }

class StaticPrime(nn.Module):
    """Deterministic embedding made of sinusoids at sqrt(prime) frequencies.

    Position ``t`` is mapped to ``[cos(2*pi*t*f_k), sin(2*pi*t*f_k)]`` for
    ``f_k = sqrt(p_k)``, where ``p_k`` are the first ``dim // 2`` primes.
    """

    def __init__(self, dim):
        """Precompute the frequency table for a ``dim``-dimensional embedding.

        Args:
            dim: Target embedding size. ``dim // 2`` frequencies are used, so
                the output has ``2 * (dim // 2)`` columns.
        """
        super().__init__()
        primes = get_first_n_primes(dim // 2)
        # Registered as a buffer so the table follows the module across devices.
        self.register_buffer("freqs", torch.sqrt(primes))

    def forward(self, n_seq):
        """Return the embeddings of positions ``0 .. n_seq - 1``.

        Args:
            n_seq: Number of consecutive positions to embed.

        Returns:
            float32 tensor of shape ``(n_seq, 2 * len(freqs))``: all cosines
            first, then all sines.
        """
        # The phases are evaluated in float64. With thousands of positions and
        # frequencies above 100 they reach millions of radians, where float32
        # (24-bit significand) can only resolve a sizeable fraction of a
        # radian, which would make cos/sin essentially rounding noise.
        positions = torch.arange(
            n_seq, device=self.freqs.device, dtype=torch.float64
        )
        angles = torch.outer(positions, self.freqs.double()) * 2 * math.pi
        # Cast back so callers keep receiving float32, as before.
        return torch.cat([angles.cos().float(), angles.sin().float()], dim=-1)
        
# ==========================================
# 2. DATA GENERATION & METRIC CALCULATION
# ==========================================

def calculate_tightness_score(n_seq=10000, d_vals=(16, 64, 256, 1024, 4096)):
    """Compare how tightly each method's pairwise similarities cluster at zero.

    For every method and every dimension in ``d_vals`` the function embeds
    ``n_seq`` positions, L2-normalises the rows, and takes the RMS of the
    off-diagonal entries of the Gram matrix. The per-dimension RMS values are
    averaged into one score per method. Lower RMS is better.

    Args:
        n_seq: Number of vectors generated per model.
        d_vals: Iterable of embedding dimensions to average over. The default
            is a tuple so that no mutable object is shared between calls.

    Returns:
        Tuple ``(scores, improvement)`` where ``scores`` maps the method name
        to its mean RMS and ``improvement`` is the percentage by which the
        "Static Prime" score is lower than the "RandNorm Gaussian" score.
    """
    print("Calculating Comparative Scores...")

    # The off-diagonal selector depends only on n_seq, so build it once
    # instead of once per (method, dimension) pair.
    off_diagonal = ~torch.eye(n_seq, device=device, dtype=torch.bool)

    scores = {}
    # Aggregate across several dimensions to get a more robust average.
    for name in ("RandNorm Gaussian", "Static Prime"):
        method_scores = []
        for d in d_vals:
            model = get_models(d)[name].to(device)
            with torch.no_grad():
                vecs = torch.nn.functional.normalize(model(n_seq), p=2, dim=1)
                gram = torch.mm(vecs, vecs.t())
                vals = gram[off_diagonal]
                # RMS of all cross-correlations (self-similarities excluded).
                method_scores.append(torch.sqrt(torch.mean(vals**2)).item())
        scores[name] = np.mean(method_scores)

    baseline = scores["RandNorm Gaussian"]
    ours = scores["Static Prime"]

    # Relative reduction in RMS, expressed in percent of the baseline.
    improvement = ((baseline - ours) / baseline) * 100
    return scores, improvement

def generate_distribution_data():
    """Row 1 data: log-density histograms of pairwise cosine similarity.

    For each dimension and each model from ``get_models`` the function embeds
    10000 positions, normalises the rows, histograms the off-diagonal Gram
    entries into 100 bins over [-0.4, 0.4], divides the counts by the total
    number of off-diagonal entries and takes ``log10`` (with a 1e-9 offset so
    empty bins stay finite).

    Returns:
        DataFrame with one row per (method, dimension, bin) and the columns
        "Method", "Dimension", "Similarity" (bin centre) and "LogDensity".
    """
    print("Generating Distribution Data...")
    n_points = 10000
    dims = [16, 64, 256, 1024, 4096]
    n_bins = 100
    lo, hi = -0.4, 0.4

    # Bin geometry is identical for every histogram, so compute it up front.
    edges = np.linspace(lo, hi, n_bins + 1)
    centers = (edges[:-1] + edges[1:]) / 2

    rows = []
    for d in dims:
        for name, model in get_models(d).items():
            model.to(device)
            with torch.no_grad():
                unit = torch.nn.functional.normalize(model(n_points), p=2, dim=1)
                gram = torch.mm(unit, unit.t())
                keep = ~torch.eye(n_points, device=device, dtype=torch.bool)
                sims = gram[keep]

                counts = torch.histc(sims, bins=n_bins, min=lo, max=hi)
                # Normalise by every off-diagonal entry, including any that
                # fall outside the histogram range.
                density = counts.cpu().numpy() / sims.numel()
                log_density = np.log10(density + 1e-9)

                rows.extend(
                    {
                        "Method": name,
                        "Dimension": d,
                        "Similarity": center,
                        "LogDensity": value,
                    }
                    for center, value in zip(centers, log_density)
                )
    return pd.DataFrame(rows)

def generate_rms_surface_data():
    """Row 2 data: RMS of off-diagonal similarities over a (dimension, N) grid.

    For every dimension, sequence length and model from ``get_models`` the
    function embeds ``N`` positions, normalises the rows and records the RMS
    of the off-diagonal Gram-matrix entries.

    Returns:
        DataFrame with the columns "Method", "N", "Dimension" and "RMS".
    """
    print("Generating RMS Surface Data...")
    seq_lengths = [100, 1000, 5000]
    dims = [16, 64, 128, 256, 512, 1024]

    records = []
    for d in dims:
        # One set of models per dimension, reused for every sequence length.
        models = get_models(d)
        for n in seq_lengths:
            for name, model in models.items():
                model.to(device)
                with torch.no_grad():
                    unit = torch.nn.functional.normalize(model(n), p=2, dim=1)
                    gram = torch.mm(unit, unit.t())
                    not_self = ~torch.eye(n, device=device, dtype=torch.bool)
                    cross = gram[not_self]
                    rms_error = torch.sqrt(torch.mean(cross**2)).item()

                record = {
                    "Method": name,
                    "N": n,
                    "Dimension": d,
                    "RMS": rms_error,
                }
                records.append(record)
    return pd.DataFrame(records)

# ==========================================
# 3. VISUALIZATION PIPELINE
# ==========================================

def plot_comparative_dashboard():
    """Render the 2x2 comparison figure (baseline on the left, ours on the right).

    Row 1 shows the log-density of pairwise similarities as a surface over
    (similarity, dimension); row 2 shows the RMS error as a surface over
    (dimension, sequence length). A text box between the panels of each row
    reports the relative improvement. The figure is displayed with
    ``plt.show()``; nothing is returned.
    """

    def _surface_grid(x, y, z):
        # Resample scattered samples onto a regular 60x60 grid (cubic).
        grid_x, grid_y = np.meshgrid(
            np.linspace(x.min(), x.max(), 60),
            np.linspace(y.min(), y.max(), 60),
        )
        grid_z = griddata((x, y), z, (grid_x, grid_y), method='cubic')
        return grid_x, grid_y, grid_z

    def _strip_decorations(ax):
        # Drop the grid and the filled background panes of a 3D axes.
        ax.grid(False)
        for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
            axis.pane.fill = False

    # 1. Headline metric (the per-method scores themselves are not displayed).
    _, tightness_improvement = calculate_tightness_score()

    # 2. Data behind the two rows of panels.
    df_dist = generate_distribution_data()
    df_rms = generate_rms_surface_data()

    # 3. Figure with two rows and two columns of 3D axes.
    fig = plt.figure(figsize=(18, 12))

    panels = list(zip(
        ["RandNorm Gaussian", "Static Prime"],
        ["Baseline (Gaussian)", "Ours (StaticPrime)"],
    ))

    # --- ROW 1: DISTRIBUTION (3D) ---
    # Both panels share one z-range so their heights are directly comparable.
    dist_lo = df_dist["LogDensity"].min()
    dist_hi = df_dist["LogDensity"].max()

    for col, (method, title) in enumerate(panels, start=1):
        ax = fig.add_subplot(2, 2, col, projection='3d')
        rows = df_dist[df_dist["Method"] == method]

        Xi, Yi, Zi = _surface_grid(
            rows["Similarity"].values,
            rows["Dimension"].values,
            rows["LogDensity"].values,
        )
        ax.plot_surface(Xi, Yi, Zi, cmap='plasma', linewidth=0, antialiased=True, alpha=0.9)

        ax.set_title(title, fontsize=15, weight='bold')
        ax.set_xlabel('Cosine Similarity')
        ax.set_ylabel('Dimension (d)')
        ax.set_zlabel('Log Density')
        ax.set_zlim(dist_lo, dist_hi)
        ax.view_init(elev=30, azim=-60)
        _strip_decorations(ax)

    # --- CENTRAL ANNOTATION FOR ROW 1 ---
    fig.text(0.5, 0.75,
             f"DISTRIBUTION TIGHTNESS\n(RMS Reduction)\n\nOurs is\n{tightness_improvement:.1f}%\nBetter\n➜",
             ha='center', va='center', fontsize=14, weight='bold', color='darkred',
             bbox=dict(boxstyle="rarrow,pad=0.3", fc="white", ec="darkred", lw=2))

    # --- ROW 2: RMS ERROR SURFACE (3D) ---
    rms_lo = df_rms["RMS"].min()
    rms_hi = df_rms["RMS"].max()

    avg_rms_by_method = {}  # feeds the row-2 improvement figure below

    for col, (method, title) in enumerate(panels, start=3):
        ax = fig.add_subplot(2, 2, col, projection='3d')
        rows = df_rms[df_rms["Method"] == method]

        Xi, Yi, Zi = _surface_grid(
            rows["Dimension"].values,
            rows["N"].values,
            rows["RMS"].values,
        )

        # Shared colour limits so both panels use the same colour scale.
        norm = plt.Normalize(rms_lo, rms_hi)
        ax.plot_surface(Xi, Yi, Zi, cmap='coolwarm', norm=norm, linewidth=0, antialiased=True, alpha=0.9)

        avg_rms = rows["RMS"].mean()
        avg_rms_by_method[method] = avg_rms

        ax.set_title(f"{title}\nAvg RMS Error: {avg_rms:.4f}", fontsize=15, weight='bold')
        ax.set_xlabel('Dimension (d)')
        ax.set_ylabel('Seq Length (N)')
        ax.set_zlabel('RMS Error')
        ax.set_zlim(rms_lo, rms_hi)
        ax.view_init(elev=30, azim=-130)
        _strip_decorations(ax)

    base_rms = avg_rms_by_method["RandNorm Gaussian"]
    our_rms = avg_rms_by_method["Static Prime"]
    rms_imp = ((base_rms - our_rms) / base_rms) * 100

    fig.text(0.5, 0.25,
             f"GLOBAL ERROR\n(Lower is Better)\n\nOurs is\n{rms_imp:.1f}%\nLower\n➜",
             ha='center', va='center', fontsize=14, weight='bold', color='navy',
             bbox=dict(boxstyle="rarrow,pad=0.3", fc="white", ec="navy", lw=2))

    plt.subplots_adjust(top=0.88, wspace=0.3, hspace=0.3)
    plt.suptitle(f"Comparison: Gaussian Baseline vs. Static Prime (Ours)\nAnalyzed over N=10,000 samples",
                 fontsize=22, weight='bold', y=0.98)

    print(f"Rendering Plot... (Tightness Improvement: {tightness_improvement:.2f}%)")
    plt.show()

# Script entry point: compute the data and show the comparison figure only
# when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    plot_comparative_dashboard()