Unlike the paper, the code calculates primes dynamically instead of using a lookup table.
This code generates Figures 4 and 5.

In [None]:
# Representation and Classification abilities
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import random
import os

# ==========================================
# SETUP
# ==========================================
def seed_everything(seed=42):
    """Seed every random number generator used by this script.

    Seeds Python's ``random`` module, NumPy and PyTorch (the default
    generator plus the current and all CUDA devices), exports
    ``PYTHONHASHSEED`` into the process environment, and sets the cuDNN
    flags to deterministic / non-benchmarking.

    Args:
        seed: Integer seed shared by all generators.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Every seeding function takes the same single integer argument.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Run on the GPU when PyTorch reports CUDA as available; otherwise use the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Running on: {device}")

# ==========================================
# HELPER FUNCTIONS & GENERATORS
# ==========================================
def get_first_n_primes(n):
    """Return the first ``n`` prime numbers as a 1-D float32 tensor.

    The primes are found with a sieve of Eratosthenes over the odd numbers
    only; 2 is prepended separately. The sieve size comes from the bound
    ``p_n < n * (ln n + ln ln n)``, which holds for ``n >= 6``; smaller
    requests use a fixed limit that covers 11, the fifth prime.

    Args:
        n: Number of primes requested.

    Returns:
        A tensor on the module-level ``device`` holding the primes in
        ascending order. It is empty when ``n < 1``.
    """
    if n < 1:
        return torch.tensor([], device=device)
    if n == 1:
        return torch.tensor([2.0], device=device)

    if n < 6:
        limit = 13
    else:
        log_n = math.log(n)
        limit = int(n * (log_n + math.log(log_n))) + 3

    # sieve[i] answers "is the odd number 2*i + 1 prime?"; index 0 is 1.
    sieve = np.ones(limit // 2, dtype=bool)
    sieve[0] = False

    # Only primes p = 2*i + 1 with p <= sqrt(limit) need to be crossed out.
    cross_limit = (math.isqrt(limit) - 1) // 2
    for i in range(1, cross_limit + 1):
        if sieve[i]:
            step = 2 * i + 1
            # Start at p*p = (2*i + 1)**2, whose sieve index is
            # (p*p - 1) / 2 = 2*i*(i + 1). Stepping the index by p moves
            # the value by 2*p, i.e. to the next odd multiple of p.
            sieve[2 * i * (i + 1)::step] = False

    odd_primes = 2 * np.nonzero(sieve)[0] + 1
    result = np.empty(len(odd_primes) + 1, dtype=np.float32)
    result[0] = 2.0
    result[1:] = odd_primes
    return torch.from_numpy(result[:n]).to(device)

def generate_spiral(n_points=1000, noise=0.0):
    """Sample 2-D points along a spiral, optionally with Gaussian jitter.

    Each point's angle is the square root of a uniform random sample scaled
    to at most 780 degrees (expressed in radians). The same value is used as
    the point's radius, so the curve winds outward from the origin.

    Args:
        n_points: Number of points to draw.
        noise: Multiplier applied to standard-normal jitter on each coordinate.

    Returns:
        A tensor of shape ``(n_points, 2)`` on the module-level ``device``.
    """
    degrees_to_radians = 2 * math.pi / 360
    theta = torch.sqrt(torch.rand(n_points)) * 780 * degrees_to_radians

    # Jitter is drawn for x first and then for y, even when noise is zero,
    # so the random stream is consumed in a fixed order.
    x_jitter = torch.randn(n_points) * noise
    xs = -torch.cos(theta) * theta + x_jitter
    y_jitter = torch.randn(n_points) * noise
    ys = torch.sin(theta) * theta + y_jitter

    return torch.stack((xs, ys), dim=1).to(device)

def generate_circles(n_points=1000, noise=0.0):
    """Sample two concentric circles of radius 5 and 10 around the origin.

    Each circle receives ``n_points // 2`` points at evenly spaced angles
    from 0 to 2*pi (both endpoints included), optionally perturbed by
    Gaussian jitter.

    Args:
        n_points: Total number of points requested; an odd value yields
            ``n_points - 1`` points because each circle gets half.
        noise: Multiplier applied to standard-normal jitter on each coordinate.

    Returns:
        A tensor of shape ``(2 * (n_points // 2), 2)`` on the module-level
        ``device``, inner circle first.
    """
    per_circle = n_points // 2
    angles = torch.linspace(0, 2*math.pi, per_circle)
    cos_a = torch.cos(angles)
    sin_a = torch.sin(angles)

    rings = []
    for radius in (5, 10):
        # Jitter is drawn x-then-y for the inner ring, then x-then-y for the
        # outer ring, so the random stream is consumed in a fixed order.
        xs = cos_a * radius + torch.randn(per_circle) * noise
        ys = sin_a * radius + torch.randn(per_circle) * noise
        rings.append(torch.stack((xs, ys), dim=1))

    return torch.cat(rings, dim=0).to(device)

# ==========================================
# MODEL CLASS
# ==========================================
# class DynamicPrime(nn.Module):
#     def __init__(self, input_dim, output_dim, scaling_factor=0.01):
#         super().__init__()
#         if output_dim % 2 != 0: raise ValueError("Output dimension must be even.") 
#         self.input_dim = input_dim
#         self.output_dim = output_dim
#         self.half_dim = output_dim // 2
#         self.scaling_factor = scaling_factor
        
#         num_primes_needed = self.half_dim * input_dim
#         primes = get_first_n_primes(num_primes_needed)
#         self.register_buffer("weights", torch.sqrt(primes).reshape(self.half_dim, input_dim))

#     def forward(self, x):
#         self.projection = torch.mm(x, self.weights.t()) * 2 * math.pi * self.scaling_factor
#         return torch.cat([torch.cos(self.projection), torch.sin(self.projection)], dim=-1)

#     def reverse(self, y):
#         cos_part = y[:, :self.half_dim]
#         sin_part = y[:, self.half_dim:]
#         recovered_phases = torch.atan2(sin_part, cos_part) 
        
#         effective_W_T = self.weights.t() * 2 * math.pi * self.scaling_factor
#         W_inv = torch.linalg.pinv(effective_W_T)
        
#         x_hat = torch.mm(recovered_phases, W_inv) 
#         return x_hat

class DynamicPrime(nn.Module):
    """Fixed (non-trainable) sinusoidal embedding with prime-derived weights.

    The weight matrix holds the square roots of the first
    ``half_dim * input_dim`` primes, reshaped to ``(half_dim, input_dim)``.
    ``forward`` projects the input through it, scales the result to a phase
    and returns the cosine and sine of that phase. ``reverse`` reads the
    phase back with ``atan2`` and undoes the projection with a
    pseudo-inverse.

    Args:
        input_dim: Size of the last axis of the input.
        output_dim: Size of the last axis of the embedding; must be even
            because it is split into a cosine half and a sine half.
        scaling_factor: Multiplier applied (together with ``2 * pi``) to the
            projection before taking cos/sin.

    Raises:
        ValueError: If ``output_dim`` is odd.
    """

    def __init__(self, input_dim, output_dim, scaling_factor=0.01):
        super().__init__()
        if output_dim % 2 != 0:
            raise ValueError("Output dimension must be even.")
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.half_dim = output_dim // 2
        self.scaling_factor = scaling_factor

        num_primes_needed = self.half_dim * input_dim
        primes = get_first_n_primes(num_primes_needed)

        # Stored as float32; the methods below upcast to float64 only for
        # the duration of each computation.
        self.register_buffer("weights", torch.sqrt(primes).reshape(self.half_dim, input_dim).float())

    def forward(self, x):
        """Embed ``x`` as ``[cos(phase), sin(phase)]``.

        Args:
            x: Tensor whose last axis has size ``input_dim``; any number of
                leading batch axes (including none) is accepted.

        Returns:
            Tensor shaped like ``x`` with the last axis replaced by
            ``output_dim``, cast back to the dtype of ``x``.
        """
        # Project and scale in float64 so large phases keep their low-order
        # bits before cos/sin are evaluated.
        phase_scale = 2 * math.pi * self.scaling_factor
        projection = torch.matmul(x.double(), self.weights.double().t()) * phase_scale

        # cos/sin are computed on the float64 phases, then cast down.
        return torch.cat([torch.cos(projection), torch.sin(projection)], dim=-1).type_as(x)

    def reverse(self, y):
        """Map an embedding produced by ``forward`` back to input space.

        The phase is recovered with ``atan2``, which only returns values in
        ``[-pi, pi]``. The result therefore matches the original input only
        while every true phase stays inside that interval and the weight
        matrix has full column rank.

        Args:
            y: Tensor whose last axis has size ``output_dim`` (cosine half
                followed by sine half); leading batch axes are preserved.

        Returns:
            float32 tensor shaped like ``y`` with the last axis replaced by
            ``input_dim``.
        """
        y_double = y.double()
        cos_part = y_double[..., :self.half_dim]
        sin_part = y_double[..., self.half_dim:]
        recovered_phases = torch.atan2(sin_part, cos_part)

        # Rebuild the exact matrix applied in forward() and pseudo-invert it
        # in float64; inversion is sensitive to rounding error.
        effective_W_T = self.weights.double().t() * (2 * math.pi * self.scaling_factor)
        W_inv = torch.linalg.pinv(effective_W_T)

        x_hat = torch.matmul(recovered_phases, W_inv)
        return x_hat.float()
        
# ==========================================
# EXPERIMENT (3 SCALES)
# ==========================================

def _cosine_similarity_matrix(tensors):
    """Return the pairwise cosine similarity of ``tensors`` as an (n, n) array.

    Each tensor is flattened into one row vector before comparison, so all
    tensors must contain the same number of elements.
    """
    flat = [t.reshape(1, -1) for t in tensors]
    sim = np.zeros((len(flat), len(flat)))
    for i, vec_i in enumerate(flat):
        for j, vec_j in enumerate(flat):
            sim[i, j] = F.cosine_similarity(vec_i, vec_j).item()
    return sim

def _plot_overlay(ax, cloud_a, cloud_b, name_a, name_b, color_a, color_b, sim_val, title_prefix):
    """Scatter two 2-D point clouds on ``ax`` and title the plot with ``sim_val``."""
    d1 = cloud_a.cpu().numpy()
    d2 = cloud_b.cpu().numpy()

    # Low alpha keeps overlapping regions of the two clouds distinguishable.
    ax.scatter(d1[:, 0], d1[:, 1], s=2, c=color_a, alpha=0.4, label=name_a)
    ax.scatter(d2[:, 0], d2[:, 1], s=2, c=color_b, alpha=0.4, label=name_b)

    ax.set_title(f"{title_prefix}\nSim: {sim_val:.4f}")
    ax.legend(loc='upper right', fontsize='x-small', markerscale=3)
    ax.axis('equal')
    ax.set_xticks([])
    ax.set_yticks([])

def run_extended_experiment(scales=(0.007, 0.02, 1.0), dim=4, n_points=1000):
    """Plot reconstruction similarity between spiral and circle datasets.

    Four 2-D datasets are generated once (clean and noisy spiral, clean and
    noisy circles). For every scaling factor a ``DynamicPrime`` model embeds
    and then reconstructs each dataset. The flattened reconstructions are
    compared by cosine similarity, and one figure row is drawn per scale: a
    similarity heatmap followed by three overlay scatter plots.

    Args:
        scales: Scaling factors to evaluate, one figure row each.
        dim: Embedding (output) dimension passed to ``DynamicPrime``; must
            be even.
        n_points: Number of points requested from each dataset generator.

    Raises:
        ValueError: If ``scales`` is empty.
    """
    scales = list(scales)
    if not scales:
        raise ValueError("At least one scale is required.")

    scale_text = ", ".join(str(s) for s in scales)
    print(f"Running Extended Similarity Experiment (Scales: {scale_text})...")

    # Generate the data once, in a fixed order, so every scale is evaluated
    # on identical samples and the random stream is consumed reproducibly.
    d1_spiral_clean = generate_spiral(n_points, noise=0.0)
    d2_spiral_noisy = generate_spiral(n_points, noise=1.5)
    d3_circle_clean = generate_circles(n_points, noise=0.0)
    d4_circle_noisy = generate_circles(n_points, noise=1.5)

    datasets = [d1_spiral_clean, d2_spiral_noisy, d3_circle_clean, d4_circle_noisy]
    labels = ["Sp(C)", "Sp(N)", "Ci(C)", "Ci(N)"]
    full_labels = ["Spiral (Clean)", "Spiral (Noisy)", "Circles (Clean)", "Circles (Noisy)"]

    # (first dataset index, second dataset index, colors, title) per overlay:
    # same class, different class (clean), different class (noisy).
    overlays = [
        (0, 1, 'blue', 'cyan', "MATCH: Same Figure"),
        (0, 2, 'blue', 'red', "MISMATCH: Distinct Figures"),
        (1, 3, 'cyan', 'orange', "MISMATCH: Distinct (Noisy)"),
    ]

    # Six inches of height per row; squeeze=False keeps the result a 2-D
    # array even when only one scale is requested.
    fig = plt.figure(figsize=(20, 6 * len(scales)))
    subfigs = fig.subfigures(len(scales), 1, hspace=0.15, squeeze=False)

    for row, scl in enumerate(scales):
        subfig = subfigs[row, 0]
        model = DynamicPrime(input_dim=2, output_dim=dim, scaling_factor=scl).to(device)

        with torch.no_grad():
            reconstructions = [model.reverse(model(data)) for data in datasets]

        sim_matrix = _cosine_similarity_matrix(reconstructions)

        axs = subfig.subplots(1, 4, gridspec_kw={'width_ratios': [1, 1, 1, 1]})
        subfig.subplots_adjust(top=0.8)

        if scl < 0.1:
            desc = "Linearized Regime"
        else:
            desc = "Full Reconstruction Regime"

        subfig.suptitle(f"SCALE: {scl} ({desc}) - Dim: {dim}", fontsize=16, fontweight='bold', color='darkblue')

        # A. Heatmap of all pairwise similarities.
        sns.heatmap(sim_matrix, annot=True, fmt=".3f", cmap="Blues",
                    xticklabels=labels, yticklabels=labels,
                    ax=axs[0], cbar=False, vmin=0, vmax=1)
        axs[0].set_title("Cosine Similarity Matrix")

        # B-D. Overlay scatter plots of selected dataset pairs.
        for ax, (idx1, idx2, color1, color2, title_prefix) in zip(axs[1:], overlays):
            _plot_overlay(
                ax,
                reconstructions[idx1], reconstructions[idx2],
                full_labels[idx1], full_labels[idx2],
                color1, color2,
                sim_matrix[idx1, idx2],
                title_prefix,
            )

    plt.show()

# Script entry point: runs only when the file is executed directly.
if __name__ == "__main__":
    # Seed all RNGs first so the randomly generated datasets are reproducible.
    seed_everything(42)
    # Apply the seaborn theme before any figure is created.
    sns.set_theme(style="whitegrid")
    run_extended_experiment()