Unlike the paper, the code calculates primes dynamically instead of using a lookup table.
This code generates Figure 3.

In [None]:
# Code contains the dynamic version of our method and its backward and forward capabilities and advantages.
import torch
import torch.nn as nn
import math
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from sklearn.decomposition import PCA
import random
import os


# Select the compute device: the GPU when PyTorch reports CUDA as available,
# otherwise the CPU. The choice is printed once at import time.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Running on: {device}")

def seed_everything(seed=42):
    """Seed every random number generator this script touches.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU and all CUDA devices),
    exports ``PYTHONHASHSEED``, and switches cuDNN to deterministic mode with
    autotuning (``benchmark``) turned off.

    Args:
        seed: Integer seed applied to all generators.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# ==========================================
# 1. HELPER FUNCTIONS
# ==========================================

def get_first_n_primes(n):
    """Return the first ``n`` prime numbers as a 1-D float32 tensor on ``device``.

    The primes are found with a sieve of Eratosthenes over odd numbers only;
    2 is prepended afterwards.

    Args:
        n: Number of primes requested. Values below 1 yield an empty tensor.

    Returns:
        A tensor of shape ``(n,)`` holding 2, 3, 5, 7, 11, ... in ascending
        order, placed on the module-level ``device``.
    """
    if n < 1:
        return torch.tensor([], device=device)
    if n == 1:
        return torch.tensor([2.0], device=device)

    # Upper bound for the n-th prime. For n >= 6 the classical estimate
    # p_n < n * (ln n + ln ln n) applies; for smaller n a limit of 13 is enough
    # to reach the fifth prime (11).
    if n < 6:
        limit = 13
    else:
        log_n = math.log(n)
        limit = int(n * (log_n + math.log(log_n))) + 3

    # sieve[i] represents the odd number 2*i + 1, so index 0 is the number 1.
    sieve = np.ones(limit // 2, dtype=bool)
    sieve[0] = False
    cross_limit = (math.isqrt(limit) - 1) // 2
    for i in range(1, cross_limit + 1):
        if sieve[i]:
            step = 2 * i + 1  # the prime p itself
            # p*p = 4i^2 + 4i + 1 lives at index 2*i*(i+1). Stepping the index
            # by p advances the value by 2p, i.e. over the odd multiples of p.
            # (Starting anywhere else strikes non-multiples such as 11 or 29.)
            sieve[2 * i * (i + 1)::step] = False

    odd_primes = 2 * np.nonzero(sieve)[0] + 1
    result = np.empty(len(odd_primes) + 1, dtype=np.float32)
    result[0] = 2.0
    result[1:] = odd_primes
    return torch.from_numpy(result[:n]).to(device)

def calculate_metrics(vectors, batch_size=2000):
    """Measure how correlated the rows of ``vectors`` are with one another.

    Rows are L2-normalised and every ordered pair of distinct rows is compared
    by dot product (cosine similarity).

    Args:
        vectors: 2-D tensor with one vector per row.
        batch_size: Number of rows compared against the full set at a time.
            Only ``batch_size x n`` similarities are held in memory at once
            instead of the full ``n x n`` Gram matrix. Values below 1 are
            treated as 1; ``None`` processes all rows in a single chunk.

    Returns:
        ``(rms, max_coh)``: the root-mean-square of the off-diagonal cosine
        similarities and the largest absolute off-diagonal cosine similarity.
        Both are ``0`` when there are fewer than two rows.
    """
    n, _ = vectors.shape
    pair_count = n * (n - 1)
    if pair_count <= 0:
        return 0.0, 0.0

    unit = torch.nn.functional.normalize(vectors, p=2, dim=1)
    chunk = n if batch_size is None else max(1, int(batch_size))

    sum_sq = 0.0
    max_coh = 0.0
    for start in range(0, n, chunk):
        stop = min(start + chunk, n)
        sims = torch.mm(unit[start:stop], unit.t())

        # Row r of this chunk is global row start + r; zero its self-similarity
        # so it contributes nothing to either statistic.
        rows = torch.arange(stop - start, device=sims.device)
        sims[rows, rows + start] = 0.0

        sum_sq += (sims ** 2).sum().item()
        chunk_max = sims.abs().max().item()
        # Keep NaN sticky so bad inputs are not silently reported as 0.
        if math.isnan(chunk_max) or chunk_max > max_coh:
            max_coh = chunk_max

    return math.sqrt(sum_sq / pair_count), max_coh

# ==========================================
# 2. DATA GENERATORS
# ==========================================

def generate_spiral(n_points=1000, noise=0.0):
    """Sample 2-D points along a spiral, optionally jittered with Gaussian noise.

    Each point's angle is ``sqrt(u) * 780`` degrees (converted to radians) for
    a uniform ``u`` in ``[0, 1)``; the same value is used as the radius.

    Args:
        n_points: Number of points to draw.
        noise: Scale applied to standard-normal jitter added independently to
            each coordinate.

    Returns:
        Tensor of shape ``(n_points, 2)`` on the module-level ``device``.
    """
    max_angle_deg = 780
    deg_to_rad = 2 * math.pi / 360

    angle = torch.sqrt(torch.rand(n_points)) * max_angle_deg * deg_to_rad

    # Jitter for x is drawn before jitter for y; the order fixes the RNG stream.
    xs = -torch.cos(angle) * angle + torch.randn(n_points) * noise
    ys = torch.sin(angle) * angle + torch.randn(n_points) * noise

    points = torch.stack((xs, ys), dim=1)
    return points.to(device)

def generate_circles(n_points=1000, noise=0.0):
    """Sample points on two concentric circles of radius 5 and 10.

    The inner circle receives ``n_points // 2`` points and the outer circle the
    remainder, so exactly ``n_points`` points are returned even when
    ``n_points`` is odd. Points on each circle are evenly spaced in angle over
    ``[0, 2*pi]`` (both endpoints included).

    Args:
        n_points: Total number of points across both circles.
        noise: Scale applied to standard-normal jitter added independently to
            each coordinate.

    Returns:
        Tensor of shape ``(n_points, 2)`` on the module-level ``device``; inner
        circle rows come first, then outer circle rows.
    """
    inner_count = n_points // 2
    outer_count = n_points - inner_count  # takes the extra point when odd

    rings = []
    for radius, count in ((5, inner_count), (10, outer_count)):
        angles = torch.linspace(0, 2 * math.pi, count)
        # x jitter is drawn before y jitter, inner ring before outer ring.
        xs = torch.cos(angles) * radius + torch.randn(count) * noise
        ys = torch.sin(angles) * radius + torch.randn(count) * noise
        rings.append(torch.stack([xs, ys], dim=1))

    return torch.cat(rings, dim=0).to(device)

# ==========================================
# 3. MODEL CLASS
# ==========================================

# class DynamicPrime(nn.Module):
#     def __init__(self, input_dim, output_dim, scaling_factor=0.01):
#         super().__init__()
#         if output_dim % 2 != 0: raise ValueError("Output dimension must be even.") 
#         self.input_dim = input_dim
#         self.output_dim = output_dim
#         self.half_dim = output_dim // 2
#         self.scaling_factor = scaling_factor
        
#         num_primes_needed = self.half_dim * input_dim
#         primes = get_first_n_primes(num_primes_needed)
#         self.register_buffer("weights", torch.sqrt(primes).reshape(self.half_dim, input_dim))

#     def forward(self, x):
#         self.projection = torch.mm(x, self.weights.t()) * 2 * math.pi * self.scaling_factor
#         return torch.cat([torch.cos(self.projection), torch.sin(self.projection)], dim=-1)

#     def reverse(self, y):
#         cos_part = y[:, :self.half_dim]
#         sin_part = y[:, self.half_dim:]
#         recovered_phases = torch.atan2(sin_part, cos_part) 
        
#         effective_W_T = self.weights.t() * 2 * math.pi * self.scaling_factor
#         W_inv = torch.linalg.pinv(effective_W_T)
        
#         x_hat = torch.mm(recovered_phases, W_inv) 
#         return x_hat

class DynamicPrime(nn.Module):
    """Fixed sinusoidal encoder whose frequencies are square roots of primes.

    An input of size ``input_dim`` is projected onto ``output_dim // 2`` phase
    channels by a constant weight matrix and emitted as the concatenation of
    the cosines and sines of those phases. ``reverse`` recovers the phases with
    ``atan2`` and maps them back through the pseudo-inverse of the projection.

    The weights are a buffer, not a parameter: nothing here is trainable.

    Args:
        input_dim: Size of the last dimension of the inputs.
        output_dim: Size of the encoding; must be even (half cosine, half sine).
        scaling_factor: Multiplier applied, together with ``2*pi``, to the
            projection before taking cos/sin.

    Raises:
        ValueError: If ``output_dim`` is odd.
    """

    def __init__(self, input_dim, output_dim, scaling_factor=0.01):
        super().__init__()
        if output_dim % 2 != 0:
            raise ValueError("Output dimension must be even.")
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.half_dim = output_dim // 2
        self.scaling_factor = scaling_factor

        # One distinct prime per (phase channel, input coordinate) entry.
        num_primes_needed = self.half_dim * input_dim
        primes = get_first_n_primes(num_primes_needed)

        # Stored in float32; forward/reverse upcast to float64 for the maths.
        weights = torch.sqrt(primes).reshape(self.half_dim, input_dim).float()
        self.register_buffer("weights", weights)

    def forward(self, x):
        """Encode ``x`` of shape ``(..., input_dim)`` into ``(..., output_dim)``.

        The projection is evaluated in float64 so that large phases keep their
        fractional part; the result is cast back to the dtype of ``x``.
        """
        scale = 2 * math.pi * self.scaling_factor
        # matmul (rather than mm) lets any number of leading batch dims through.
        projection = torch.matmul(x.double(), self.weights.double().t()) * scale
        encoded = torch.cat([torch.cos(projection), torch.sin(projection)], dim=-1)
        return encoded.type_as(x)

    def reverse(self, y):
        """Decode ``y`` of shape ``(..., output_dim)`` back to ``(..., input_dim)``.

        ``atan2`` only yields phases in ``[-pi, pi]``, so the reconstruction can
        match the original input only when the forward projection stayed inside
        that range. The result is cast to the dtype of ``y``, mirroring
        ``forward``.
        """
        y64 = y.double()
        cos_part = y64[..., :self.half_dim]
        sin_part = y64[..., self.half_dim:]
        recovered_phases = torch.atan2(sin_part, cos_part)

        # Same (input_dim, half_dim) map that forward applies, inverted in
        # float64 because the pseudo-inverse is sensitive to rounding.
        effective_w_t = self.weights.double().t() * (2 * math.pi * self.scaling_factor)
        w_inv = torch.linalg.pinv(effective_w_t)

        return torch.matmul(recovered_phases, w_inv).type_as(y)

# ==========================================
# 4. EXPERIMENT EXECUTION
# ==========================================

def _encode_and_reconstruct(model, data):
    """Encode ``data`` with ``model``, decode it again, and summarise both steps.

    Returns:
        ``(latent_2d, recon, rms, max_coh, mse)`` where ``latent_2d`` is the
        2-component PCA projection of the encoding (NumPy array), ``recon`` is
        the reconstruction (NumPy array), ``rms`` / ``max_coh`` are the values
        returned by ``calculate_metrics`` on the encoding, and ``mse`` is the
        mean squared error between the reconstruction and ``data``.
    """
    with torch.no_grad():
        latent = model(data)
        rms, max_coh = calculate_metrics(latent)
        recon = model.reverse(latent)
        mse = torch.nn.functional.mse_loss(recon, data).item()
    latent_2d = PCA(n_components=2, random_state=42).fit_transform(latent.cpu().numpy())
    return latent_2d, recon.cpu().numpy(), rms, max_coh, mse


def run_experiment():
    """Draw the encode/decode comparison grid for the spiral and circle datasets.

    One row is produced per (pattern, noise level) combination. Each row shows
    the input points followed by a latent (PCA) panel and a reconstruction
    panel for each of three ``DynamicPrime`` configurations: a low scaling
    factor with 4 outputs, a scaling factor of 1.0 with 4 outputs, and a
    128-output encoder. The figure is displayed with ``plt.show()``.
    """
    n_points = 1000
    noise_levels = [0.0, 0.5, 1.5]
    generators = {'Spiral': generate_spiral, 'Circles': generate_circles}
    l_scl = 0.007
    h_scl = 1.0
    special_scl = 0.02

    # Per encoder: (model, latent panel heading, recon panel heading, colour).
    variants = [
        (
            DynamicPrime(input_dim=2, output_dim=4, scaling_factor=l_scl).to(device),
            f"LATENT (Scale {l_scl})",
            f"RECON (Scale {l_scl})",
            'red',
        ),
        (
            DynamicPrime(input_dim=2, output_dim=4, scaling_factor=h_scl).to(device),
            "LATENT (High, Dim 4)",
            "RECON (High, Dim 4)",
            'purple',
        ),
        (
            DynamicPrime(input_dim=2, output_dim=128, scaling_factor=special_scl).to(device),
            "LATENT (High, Dim 128)",
            "RECON (High, Dim 128)",
            'magenta',
        ),
    ]

    cases = [(name, level) for name in generators for level in noise_levels]
    rows = len(cases)
    cols = 1 + 2 * len(variants)  # input panel + (latent, recon) per encoder

    _fig, axes = plt.subplots(rows, cols, figsize=(28, 4.0 * rows))
    plt.subplots_adjust(hspace=0.4, wspace=0.3)

    for ax, (pattern_name, noise) in zip(axes, cases):
        data = generators[pattern_name](n_points, noise)
        X_np = data.cpu().numpy()

        # Run all encoders before plotting, in the order they are listed.
        results = [_encode_and_reconstruct(model, data) for model, *_ in variants]

        for sub_ax in ax:
            sub_ax.set_box_aspect(1)

        # Column 0: the raw input.
        ax[0].scatter(X_np[:, 0], X_np[:, 1], s=2, c='blue', alpha=0.5)
        ax[0].set_title(f"INPUT\n{pattern_name} (Noise={noise})")
        # axis('equal') keeps the data square; set_box_aspect keeps the frame square.
        ax[0].axis('equal')

        # Columns 1.. : one (latent, reconstruction) pair per encoder.
        for k, (variant, result) in enumerate(zip(variants, results)):
            _, latent_heading, recon_heading, colour = variant
            latent_2d, recon, rms, max_coh, mse = result
            latent_ax = ax[1 + 2 * k]
            recon_ax = ax[2 + 2 * k]

            latent_ax.scatter(latent_2d[:, 0], latent_2d[:, 1], s=2, c=colour, alpha=0.5)
            latent_ax.set_title(f"{latent_heading}\nRMS: {rms:.4f}, MaxCoh: {max_coh:.4f}")
            latent_ax.set_xticks([])
            latent_ax.set_yticks([])

            recon_ax.scatter(recon[:, 0], recon[:, 1], s=2, c=colour, alpha=0.5)
            recon_ax.set_title(f"{recon_heading}\nMSE: {mse:.6f}")
            recon_ax.axis('equal')

    plt.tight_layout()
    plt.show()

# Script entry point: runs only when the file is executed directly.
if __name__ == "__main__":
    # Seed all random number generators before any data is drawn.
    seed_everything(42)
    # Apply seaborn's "whitegrid" theme before the figure is created.
    sns.set_theme(style="whitegrid")
    run_experiment()