Unlike the paper, this code computes the primes on the fly (with a sieve) instead of reading them from a lookup table.
This notebook generates Figure 2 and all of the results underlying Figure 1.

In [None]:
# Code contains static version of our method represented in class StaticPrime and the experiments showing its advantages.
import math
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
from matplotlib import cm
from scipy.interpolate import griddata
from scipy.stats.qmc import Sobol

# Module-wide compute device: CUDA when available, otherwise CPU. Helpers and
# experiments below create their tensors on it.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Running on: {device}")

def seed_everything(seed=42):
    """Seed Python's ``random``, NumPy and PyTorch (and CUDA/cuDNN if present).

    Args:
        seed: Integer seed applied to every generator.

    Note:
        ``PYTHONHASHSEED`` is exported for interpreters started afterwards
        (e.g. subprocesses); the running interpreter's hash randomisation was
        already fixed at start-up and is not affected.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        # cuDNN: no autotuning, deterministic kernels only.
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    print(f"Global seed set to: {seed}")

# ==========================================
# 0. HELPER FUNCTIONS & METHODS
# ==========================================

def _primes_below(limit):
    """Sieve of Eratosthenes on the CPU: all primes ``< limit`` as an int64 tensor."""
    sieve = torch.ones(limit, dtype=torch.bool)
    sieve[:2] = False
    for i in range(2, math.isqrt(limit) + 1):
        # CPU tensor: this truth test is cheap (no host-device synchronisation).
        if sieve[i]:
            sieve[i * i::i] = False
    return torch.nonzero(sieve).flatten()


def get_first_n_primes(n):
    """Return the first ``n`` primes, ascending, as a float32 tensor on ``device``.

    The sieve runs on the CPU and only the final result is moved to ``device``.
    On CUDA, per-index checks in the sieve loop would each force a device sync.

    Args:
        n: Number of primes requested.

    Returns:
        Float32 tensor of length ``n`` (empty when ``n < 1``).
    """
    if n < 1:
        return torch.tensor([], device=device)
    # For n >= 6, p_n < n (ln n + ln ln n) (Rosser's bound), so one sieve
    # normally suffices; the doubling loop below is only a safety net.
    limit = int(n * (math.log(n) + math.log(math.log(n)))) + 20 if n > 5 else 20
    primes = _primes_below(limit)
    while primes.numel() < n:
        limit *= 2
        primes = _primes_below(limit)
    return primes[:n].float().to(device)

def get_welch_bound(N, d):
    """Welch lower bound on the maximum coherence of N unit vectors in R^d.

    Returns sqrt((N - d) / (d * (N - 1))) when N > d, and 0.0 otherwise.
    """
    if N > d:
        return math.sqrt((N - d) / (d * (N - 1)))
    return 0.0

class Method_RandomGaussian(nn.Module):
    """Baseline: i.i.d. Gaussian vectors projected onto the unit sphere."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, n_seq):
        """Draw an (n_seq, dim) standard-normal matrix on ``device`` and L2-normalise each row.

        Each call consumes ``n_seq * dim`` samples from torch's global RNG.
        """
        samples = torch.randn(n_seq, self.dim, device=device)
        return torch.nn.functional.normalize(samples, p=2, dim=1)

class StaticPrime(nn.Module):
    """Deterministic "static prime" encoding.

    Row ``t`` of the output is
    ``[cos(2*pi*t*sqrt(p_1)), ..., cos(2*pi*t*sqrt(p_k)),
       sin(2*pi*t*sqrt(p_1)), ..., sin(2*pi*t*sqrt(p_k))]``
    where ``p_1..p_k`` are the first ``k = dim // 2`` primes. The output
    therefore has ``2 * (dim // 2)`` columns (``dim - 1`` when ``dim`` is odd).

    Precision:
        ``t * sqrt(p) * 2*pi`` grows to ~1e7 for the sequence lengths used in
        this file (N = 10_000, sqrt(p) ~ 1e2). There, float32 spacing is on the
        order of one radian, and float32 sqrt(p) is itself off by ~1e-5. So the
        frequencies are stored in float64 (``freqs`` buffer), and phases are
        formed in float64 and reduced modulo one cycle. Only the final cos/sin
        values are cast to float32.
    """

    def __init__(self, dim):
        super().__init__()
        primes = get_first_n_primes(dim // 2)
        # Primes are exact integers in float32 (< 2**24), so widening is
        # lossless; the sqrt is then taken in float64.
        self.register_buffer("freqs", torch.sqrt(primes.double()))

    def forward(self, n_seq):
        """Return the (n_seq, 2 * (dim // 2)) float32 encoding for positions 0..n_seq-1."""
        # .double() also tolerates a buffer that was downcast via .float()/.to().
        freqs = self.freqs.double()
        t = torch.arange(n_seq, device=freqs.device, dtype=torch.float64)
        cycles = torch.outer(t, freqs)
        # Keep only the fractional cycle (exact in floating point) so the
        # argument of cos/sin stays in [0, 2*pi).
        angles = (2 * math.pi) * torch.remainder(cycles, 1.0)
        return torch.cat([angles.cos(), angles.sin()], dim=-1).float()
        
# Dictionary of methods
def get_models(d):
    """Build one fresh instance of every compared method for dimension ``d``.

    Returns a dict from display name to module, baseline first. The experiments
    iterate it in this insertion order.
    """
    registry = (
        ("RandNorm Gaussian", Method_RandomGaussian),
        ("Static Prime", StaticPrime),
    )
    return {label: factory(d) for label, factory in registry}

# ==========================================
# 1. EXPERIMENT 1: DISTRIBUTION (HASHING QUALITY)
# ==========================================

def run_exp2_distribution_3d() -> None:
    """Experiment 1: log-histogram landscape of pairwise cosine similarities.

    For every dimension d in ``d_values`` and every method from
    ``get_models(d)``:
      * builds ``fixed_N`` vectors and L2-normalises them;
      * takes all off-diagonal Gram-matrix entries (pairwise cosine similarities);
      * records their RMS;
      * histograms them over ``hist_range``.

    It then prints each method's RMS averaged over d (lower = closer to
    orthogonal). Finally it draws one interpolated 3-D surface per method of
    log10(bin fraction) over (cosine similarity, d).

    Note: the function name says "exp2" but it is reported as EXPERIMENT 1.
    Methods are called in ``get_models`` order. With a seeded RNG, the random
    baseline's samples depend on that call order.

    Returns:
        None. Results are printed and displayed with ``plt.show()``.
    """
    print("\n" + "="*60)
    print("RUNNING EXPERIMENT 1: 3D Distribution Analysis")
    print("="*60)
    
    # We fix N to a large number to see the distribution clearly
    # varying d to see how the 'spike' sharpens
    fixed_N = 10000 
    d_values = [16, 32, 64, 128, 256, 512, 1024, 2048, 4096]
    
    bins = 100
    # We focus on the center -0.4 to 0.4 because that's where the "Needle" is
    hist_range = (-0.4, 0.4) 
    
    all_data = []           # one row per (method, d, histogram bin)
    tightness_metrics = []  # one row per (method, d) with the RMS

    for d in d_values:
        models = get_models(d)
        print(f"Processing Distribution for d={d}...")
        
        for name, model in models.items():
            model.to(device)
            with torch.no_grad():
                vecs = model(fixed_N)
                # Normalize
                vecs = torch.nn.functional.normalize(vecs, p=2, dim=1)
                
                # Full N x N Gram matrix in one shot (NOT chunked). Assuming
                # float32 output, N=10000 gives ~400MB here, plus a ~100MB bool
                # mask and a ~400MB copy of the off-diagonal values below.
                gram = torch.mm(vecs, vecs.t())
                
                # Mask diagonal (self-correlation is always 1.0)
                mask = ~torch.eye(fixed_N, device=device, dtype=torch.bool)
                vals = gram[mask]
                # Calculate Root Mean Square (RMS) of the correlations.
                # A lower RMS indicates the distribution is tighter around 0.
                rms_val = torch.sqrt(torch.mean(vals**2)).item()
                tightness_metrics.append({
                    "Method": name,
                    "Dimension": d,
                    "RMS": rms_val
                })                
                # Histogram on the same device as `vals`. torch.histc ignores
                # values outside [min, max], so out-of-range pairs are dropped.
                hist = torch.histc(vals, bins=bins, min=hist_range[0], max=hist_range[1])
                hist_cpu = hist.cpu().numpy()
                
                # Fraction of ALL off-diagonal pairs falling in each bin (the
                # divisor includes pairs outside hist_range, and it is not
                # divided by bin width, so this is a per-bin mass).
                hist_density = hist_cpu / vals.numel()
                
                # log10 with a 1e-9 floor: empty bins show up as -9.
                hist_log = np.log10(hist_density + 1e-9)
                
                # Generate Bin Centers
                bin_edges = np.linspace(hist_range[0], hist_range[1], bins + 1)
                bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
                
                for b_idx, val in enumerate(hist_log):
                    all_data.append({
                        "Method": name,
                        "Dimension": d,
                        "Similarity": bin_centers[b_idx],
                        "LogDensity": val
                    })
                    
    df = pd.DataFrame(all_data)

    print("\n" + "="*60)
    print("DISTRIBUTION TIGHTNESS ANALYSIS")
    print("Calculation Method: Root Mean Square (RMS) of cross-correlations.")
    print("Lower RMS values = Tighter distribution (better orthogonality).")
    print("-" * 60)

    stats_df = pd.DataFrame(tightness_metrics)
    
    # Calculate the average RMS across all dimensions for each method
    summary = stats_df.groupby("Method")["RMS"].mean().sort_values()
    print(summary)
    
    # detailed comparison if we have exactly 2 methods active
    methods_list = summary.index.tolist()
    if len(methods_list) == 2:
        best_method = methods_list[0] # Sorted ascending, so first is lowest (best)
        worst_method = methods_list[1]
        val_best = summary[best_method]
        val_worst = summary[worst_method]
        
        # Calculate percentage reduction
        reduction = ((val_worst - val_best) / val_worst) * 100
        
        print("\nCONCLUSION:")
        print(f"The values for '{best_method}' are {reduction:.2f}% lower")
        print(f"than '{worst_method}' on average, indicating a tighter distribution.")
    
    print("="*60 + "\n")
        
    # --- PLOTTING ---
    methods = df["Method"].unique()
    
    for method in methods:
        subset = df[df["Method"] == method]
        
        # Prepare Grid for Interpolation
        x = subset["Similarity"].values
        y = subset["Dimension"].values
        z = subset["LogDensity"].values
        
        # Create dense mesh
        xi = np.linspace(x.min(), x.max(), 100)
        # d is interpolated on a LINEAR axis, so d=16..512 occupies only a
        # small strip next to the d=4096 end.
        yi = np.linspace(y.min(), y.max(), 100) # Linear interpolation of D
        Xi, Yi = np.meshgrid(xi, yi)
        Zi = griddata((x, y), z, (Xi, Yi), method='cubic')
        
        # Plot
        fig = plt.figure(figsize=(12, 8))
        ax = fig.add_subplot(111, projection='3d')
        
        surf = ax.plot_surface(Xi, Yi, Zi, cmap='plasma', 
                               linewidth=0, antialiased=True, rcount=100, ccount=100, alpha=0.9)
        
        ax.set_title(f"EXP 1: Log-Distribution Landscape - {method}\n(N={fixed_N})", fontsize=15, weight='bold')
        ax.set_xlabel('\nCosine Similarity', fontsize=11)
        ax.set_ylabel('\nDimension (d)', fontsize=11)
        ax.set_zlabel('\nLog Density', fontsize=11)
        
        # Remove Panes
        ax.xaxis.pane.fill = False
        ax.yaxis.pane.fill = False
        ax.zaxis.pane.fill = False
        ax.grid(False)
        
        ax.view_init(elev=30, azim=-60)
        plt.tight_layout()
        plt.show()

# ==========================================
# 2. EXPERIMENT 2: WELCH OPTIMALITY (A & B)
# ==========================================

def run_exp3_welch_3d() -> None:
    """Experiment 2: maximum coherence of each method versus the Welch bound.

    Sweeps every (d, N) pair from ``d_values`` x ``n_values``. For each method
    from ``get_models(d)`` it builds N vectors, L2-normalises them, and records:
      * the largest absolute off-diagonal Gram entry (max coherence);
      * the Welch bound from ``get_welch_bound``;
      * their ratio, which is NaN where the bound is 0 (N <= d).

    Plot A (one per method): interpolated max-coherence surface with the
    Welch-bound surface overlaid as a wireframe.
    Plot B (one per method): interpolated optimality ratio
    (max coherence / Welch bound) with a reference plane at 1.0.

    Note: the function name says "exp3" but it is reported as EXPERIMENT 2.
    With a seeded RNG, the random baseline's samples depend on the loop and
    method order.

    Returns:
        None. Results are printed and displayed with ``plt.show()``.
    """
    print("\n" + "="*60)
    print("RUNNING EXPERIMENT 2: Welch Optimality Grid Search")
    print("="*60)
    
    n_values = [100, 1000, 10000]
    d_values = [16, 32, 64, 128, 256, 512, 1024, 2048, 4096]
    results = []
    
    total_steps = len(n_values) * len(d_values)
    step = 0
    
    for d in d_values:
        models = get_models(d)
        for n in n_values:
            wb = get_welch_bound(n, d)
            
            # For N <= d get_welch_bound returns 0.0. Plot A draws that zero
            # bound as-is; the ratio below becomes NaN and Plot B drops it.
            
            for name, model in models.items():
                model.to(device)
                with torch.no_grad():
                    vecs = model(n)
                    # Calculating Max Coherence
                    vecs = torch.nn.functional.normalize(vecs, p=2, dim=1)
                    gram = torch.mm(vecs, vecs.t())
                    mask = ~torch.eye(n, device=device, dtype=torch.bool)
                    vals = gram[mask]
                    
                    if vals.numel() > 0:
                        max_coh = vals.abs().max().item()
                    else:
                        max_coh = 0.0
                
                ratio = max_coh / wb if wb > 0 else np.nan # Avoid div by zero
                
                results.append({
                    "Method": name,
                    "N": n,
                    "Dimension": d,
                    "MaxCoh": max_coh,
                    "WelchBound": wb,
                    "Optimality": ratio
                })
            
            step += 1
            print(f"\rProcessing Grid: {step}/{total_steps}", end="")
            
    df = pd.DataFrame(results)
    
    methods = df["Method"].unique()
    
    # --- PLOT A: Raw Max Coherence + Welch Bound Surface ---
    print("\nGenerating Plot A (Raw Coherence vs Bound)...")
    for method in methods:
        subset = df[df["Method"] == method]
        
        # Data for Method
        x = subset["Dimension"].values
        y = subset["N"].values
        z_method = subset["MaxCoh"].values
        z_bound = subset["WelchBound"].values
        
        # Interpolation Grid
        xi = np.linspace(x.min(), x.max(), 100)
        yi = np.linspace(y.min(), y.max(), 100)
        Xi, Yi = np.meshgrid(xi, yi)
        
        # Interpolate both Method and Bound
        Zi_method = griddata((x, y), z_method, (Xi, Yi), method='cubic')
        Zi_bound = griddata((x, y), z_bound, (Xi, Yi), method='cubic')
        
        fig = plt.figure(figsize=(12, 9))
        ax = fig.add_subplot(111, projection='3d')
        
        # 1. Plot Method Surface (Solid, Colorful)
        surf = ax.plot_surface(Xi, Yi, Zi_method, cmap='viridis', 
                               linewidth=0, antialiased=True, alpha=0.9, label=method)
        
        # 2. Plot Welch Bound Surface (Wireframe, Black/Grey, Translucent)
        # We use wireframe so it doesn't obscure the method completely
        ax.plot_wireframe(Xi, Yi, Zi_bound, color='black', alpha=0.4, 
                          rstride=5, cstride=5, linewidth=0.8, label="Welch Bound")
        
        ax.set_title(f"EXP 2 - PLOT A: {method}\nMax Coherence vs Welch Bound (Lower is Better)", fontsize=15)
        ax.set_xlabel('\nDimension (d)', fontsize=11)
        ax.set_ylabel('\nSequence Length (N)', fontsize=11)
        ax.set_zlabel('\nMax Coherence', fontsize=11)
        
        # Clean look
        ax.xaxis.pane.fill = False
        ax.yaxis.pane.fill = False
        ax.zaxis.pane.fill = False
        ax.grid(False)
        
        # Add legend proxy (Wireframe doesn't auto-legend well in 3D)
        fake2Dline = plt.Line2D([0],[0], linestyle="none", c='black', marker='o')
        ax.legend([fake2Dline], ['Welch Bound Surface'], numpoints = 1)

        ax.view_init(elev=25, azim=-130)
        plt.show()

    # --- PLOT B: Optimality Ratio (Actual / Bound) ---
    print("\nGenerating Plot B (Optimality Ratio)...")
    for method in methods:
        subset = df[df["Method"] == method]
        
        # Filter out NaNs (rows with N <= d, where the bound is 0)
        subset = subset.dropna(subset=['Optimality'])
        
        x = subset["Dimension"].values
        y = subset["N"].values
        z = subset["Optimality"].values
        
        # Interpolation (points outside the convex hull of the kept data
        # come back as NaN from griddata and are left blank)
        xi = np.linspace(x.min(), x.max(), 100)
        yi = np.linspace(y.min(), y.max(), 100)
        Xi, Yi = np.meshgrid(xi, yi)
        Zi = griddata((x, y), z, (Xi, Yi), method='cubic')
        
        fig = plt.figure(figsize=(12, 9))
        ax = fig.add_subplot(111, projection='3d')
        
        # Plot Surface
        surf = ax.plot_surface(Xi, Yi, Zi, cmap='inferno', 
                               linewidth=0, antialiased=True, rcount=100, ccount=100, alpha=0.9)
        
        # Reference plane at ratio 1.0 (max coherence equal to the Welch bound)
        plane = np.ones_like(Zi)
        ax.plot_surface(Xi, Yi, plane, color='cyan', alpha=0.2)
        
        ax.set_title(f"EXP 2 - PLOT B: {method}\nOptimality Ratio (1.0 = Perfect/Theoretical Limit)", fontsize=15)
        ax.set_xlabel('\nDimension (d)', fontsize=11)
        ax.set_ylabel('\nSequence Length (N)', fontsize=11)
        ax.set_zlabel('\nRatio (Actual / Bound)', fontsize=11)
        
        ax.xaxis.pane.fill = False
        ax.yaxis.pane.fill = False
        ax.zaxis.pane.fill = False
        ax.grid(False)
        
        cbar = fig.colorbar(surf, ax=ax, shrink=0.5, aspect=12, pad=0.1)
        cbar.set_label('Optimality Ratio')
        
        ax.view_init(elev=25, azim=-130)
        plt.show()

# ==========================================
# 3. EXPERIMENT 3: RMS ERROR SURFACE
# ==========================================

def run_exp4_rms_error_3d():
    """Experiment 3: RMS of off-diagonal Gram entries over an (N, d) grid.

    For every d in ``d_values``, N in ``n_values`` and method from
    ``get_models(d)``, builds N vectors, L2-normalises them, and records the
    root-mean-square of the off-diagonal Gram-matrix entries (0 would mean the
    vectors are mutually orthogonal). It prints each method's overall average
    and draws one interpolated 3-D surface per method.

    All surfaces share one z-range AND one colour scale, so height and colour
    are comparable across methods. The colour bar is built from the surface
    itself, so it always matches what is drawn.

    Loops run d-outer, N-inner, with methods in ``get_models`` order. With a
    seeded RNG, the random baseline's samples depend on that order.

    Returns:
        None. Results are printed and displayed with ``plt.show()``.
    """
    print("\n" + "="*60)
    print("RUNNING EXPERIMENT 3: RMS Error Analysis")
    print("="*60)

    # Grid definition (ranges similar to the other experiments)
    n_values = [100, 500, 1000, 2500, 5000]
    d_values = [16, 32, 64, 128, 256, 512, 1024]

    results = []
    total_steps = len(n_values) * len(d_values)
    step = 0

    print("Calculating RMS Errors...")

    for d in d_values:
        models = get_models(d)
        for n in n_values:

            for name, model in models.items():
                model.to(device)
                with torch.no_grad():
                    vecs = model(n)
                    vecs = torch.nn.functional.normalize(vecs, p=2, dim=1)

                    # Gram matrix of pairwise cosine similarities
                    gram = torch.mm(vecs, vecs.t())

                    # Drop the diagonal (always 1.0): only cross-correlations count
                    mask = ~torch.eye(n, device=device, dtype=torch.bool)
                    off_diag = gram[mask]

                    # RMS = sqrt(mean(x^2))
                    rms_error = torch.sqrt(torch.mean(off_diag**2)).item()

                results.append({
                    "Method": name,
                    "N": n,
                    "Dimension": d,
                    "RMS": rms_error
                })

            step += 1
            print(f"\rProgress: {step}/{total_steps}", end="")

    df = pd.DataFrame(results)
    print("\nPlotting results...")

    # --- PLOTTING ---
    methods = df["Method"].unique()

    # Shared scale across methods: raw data range for colours, padded by 10%
    # for the z-axis limits.
    global_z_min = df["RMS"].min()
    global_z_max = df["RMS"].max()
    z_margin = (global_z_max - global_z_min) * 0.1
    z_lims = (max(0, global_z_min - z_margin), global_z_max + z_margin)

    print("\n" + "="*30)
    print("SUMMARY STATISTICS")
    print("="*30)

    for method in methods:
        subset = df[df["Method"] == method]

        overall_avg_rms = subset["RMS"].mean()
        print(f"Method: {method:15s} | Overall Avg RMS: {overall_avg_rms:.6f}")

        x = subset["Dimension"].values
        y = subset["N"].values
        z = subset["RMS"].values

        # Interpolation Grid
        xi = np.linspace(x.min(), x.max(), 100)
        yi = np.linspace(y.min(), y.max(), 100)
        Xi, Yi = np.meshgrid(xi, yi)
        Zi = griddata((x, y), z, (Xi, Yi), method='cubic')

        fig = plt.figure(figsize=(12, 9))
        ax = fig.add_subplot(111, projection='3d')

        # vmin/vmax pin the colour mapping to the shared range. Without them
        # each surface is coloured relative to its own min/max, and the
        # "shared scale" colour bar would not match the colours drawn.
        surf = ax.plot_surface(Xi, Yi, Zi, cmap='coolwarm',
                               vmin=global_z_min, vmax=global_z_max,
                               linewidth=0, antialiased=True, rcount=100, ccount=100, alpha=0.9)

        # Per-method score in the title
        ax.set_title(f"EXP 3: {method}\nGlobal Avg RMS: {overall_avg_rms:.5f} (Lower is Better)",
                     fontsize=15, weight='bold')

        ax.set_xlabel('\nDimension (d)', fontsize=11)
        ax.set_ylabel('\nSequence Length (N)', fontsize=11)
        ax.set_zlabel('\nRMS Error', fontsize=11)

        ax.set_zlim(z_lims)

        ax.xaxis.pane.fill = False
        ax.yaxis.pane.fill = False
        ax.zaxis.pane.fill = False
        ax.grid(False)

        # Colour bar taken from the surface, so it uses exactly its norm/cmap.
        cbar = fig.colorbar(surf, ax=ax, shrink=0.5, aspect=12, pad=0.1)
        cbar.set_label('RMS Error (Shared Scale)')

        ax.view_init(elev=30, azim=-130)
        plt.show()

# ==========================================
# THIS PART PROVIDES ADDITIONAL VISUALIZATIONS FOR EXP 2: TWO CHARTS, CORRESPONDING TO PARTS A AND B, COMPARING THE BASELINE WITH OUR METHOD.
# ==========================================

import matplotlib.patches as mpatches
import matplotlib.lines as mlines

def generate_aggregate_data():
    """Sweep an (N, d) grid and collect per-method coherence statistics.

    Pairs with N <= d are skipped, because ``get_welch_bound`` returns 0 there.
    For every remaining pair, each method builds N vectors of dimension d. The
    vectors are L2-normalised, and the largest absolute off-diagonal Gram entry
    (max coherence) is compared with the Welch bound.

    Returns:
        pandas.DataFrame with columns:
          * Method
          * N
          * Dimension
          * Optimality: max coherence / Welch bound (ideal 1.0)
          * Residual: max coherence - Welch bound (ideal 0.0)
    """
    print("\n" + "="*60)
    print("GENERATING POPULATION DATA (Grid Sweep)")
    print("="*60)

    # A diverse mix of N and d to create a robust statistical distribution
    n_values = [100, 250, 500, 1000, 2000, 5000, 10_000]
    d_values = [16, 32, 64, 128, 256, 512, 1024, 2048, 4096]

    data = []

    methods_map = {
        "RandNorm Gaussian": Method_RandomGaussian,
        "Static Prime": StaticPrime
    }

    # Only pairs with N > d are simulated (see the skip below), so only those
    # count toward the progress total; otherwise the counter never completes.
    n_simulated_pairs = sum(1 for d in d_values for n in n_values if n > d)
    total = n_simulated_pairs * len(methods_map)
    count = 0

    for d in d_values:
        # Fresh models for this dimension
        models = {name: cls(d).to(device) for name, cls in methods_map.items()}

        for n in n_values:
            # Welch bound is 0 for N <= d: ratio undefined, skip the pair
            if n <= d:
                continue

            wb = get_welch_bound(n, d)

            for name, model in models.items():
                with torch.no_grad():
                    vecs = model(n)
                    # Max coherence = largest |off-diagonal| cosine similarity
                    vecs = torch.nn.functional.normalize(vecs, p=2, dim=1)
                    gram = torch.mm(vecs, vecs.t())
                    mask = ~torch.eye(n, device=device, dtype=torch.bool)
                    vals = gram[mask]
                    max_coh = vals.abs().max().item() if vals.numel() > 0 else 0.0

                # Metric 1: Optimality Ratio (Ideal = 1.0)
                optimality = max_coh / wb if wb > 0 else 1.0

                # Metric 2: Residual / Distance (Ideal = 0.0)
                residual = max_coh - wb

                data.append({
                    "Method": name,
                    "N": n,
                    "Dimension": d,
                    "Optimality": optimality,
                    "Residual": residual
                })

                count += 1
                # Throttle updates, but always show the final count.
                if count % 10 == 0 or count == total:
                    print(f"\rSimulating: {count}/{total}", end="")

    print("\nData generation complete.")
    return pd.DataFrame(data)

def create_custom_legend(palette, target_label):
    """Build legend handles by hand so the legend is always populated.

    Args:
        palette: Mapping of label -> colour; each entry becomes a translucent patch.
        target_label: Label for the trailing dashed black reference line.

    Returns:
        List of handles: one patch per palette entry (in palette order), then the line.
    """
    patches = [mpatches.Patch(color=colour, label=label, alpha=0.6)
               for label, colour in palette.items()]
    reference = mlines.Line2D([], [], color='black', linestyle='--',
                              linewidth=2, label=target_label)
    return patches + [reference]

# Display name and colour for each method in the aggregate KDE plots.
# _AGG_NAME_MAP keys must match the raw method names produced by
# generate_aggregate_data(). Rows whose method is missing map to NaN and are
# expected to be dropped by seaborn (verify when adding methods).
_AGG_NAME_MAP = {
    "RandNorm Gaussian": "RandNorm Gaussian (Baseline)",
    "Static Prime": "Static Prime (Ours)",
}
_AGG_PALETTE = {
    "RandNorm Gaussian (Baseline)": "#eb4034",
    "Static Prime (Ours)": "#02f5f5",
}


def _plot_metric_kde(df, column, ref_x, ref_label, title, xlabel):
    """Draw one filled KDE per method for ``df[column]`` plus a dashed reference line.

    Args:
        df: Frame from ``generate_aggregate_data`` (raw names in ``"Method"``).
            It is copied, not modified.
        column: Metric column to plot on the x-axis.
        ref_x: x-position of the dashed black reference line.
        ref_label: Legend label for that reference line.
        title: Figure title.
        xlabel: x-axis label.
    """
    df_plot = df.copy()
    df_plot["Method"] = df_plot["Method"].map(_AGG_NAME_MAP)

    plt.figure(figsize=(12, 7))

    # common_norm=False: each method's curve integrates to 1 on its own.
    sns.kdeplot(
        data=df_plot,
        x=column,
        hue="Method",
        palette=_AGG_PALETTE,
        fill=True,
        alpha=0.3,
        linewidth=2.5,
        common_norm=False,
        legend=False,
    )

    plt.axvline(x=ref_x, color='black', linestyle='--', linewidth=2)

    # seaborn's legend is disabled above; build it explicitly instead.
    handles = create_custom_legend(_AGG_PALETTE, ref_label)
    plt.legend(handles=handles, title="Method / Reference", loc='upper right', frameon=True)

    plt.title(title, fontsize=16, weight='bold')
    plt.xlabel(xlabel, fontsize=12)
    plt.ylabel("Density", fontsize=12)

    sns.despine()
    plt.grid(True, axis='x', alpha=0.3)
    plt.tight_layout()
    plt.show()


def plot_final_optimality_distribution(df):
    """Plot 1: per-method KDE of Optimality Ratios, with the ideal 1.0 marked."""
    _plot_metric_kde(
        df,
        column="Optimality",
        ref_x=1.0,
        ref_label="Welch Bound (Ideal)",
        title="Distribution of Optimality Ratios",
        xlabel="Optimality Ratio (Actual / Welch Bound)\nCloser to 1.0 is Better",
    )


def plot_final_residual_distribution(df):
    """Plot 2: per-method KDE of Residuals (excess coherence), with the ideal 0.0 marked."""
    _plot_metric_kde(
        df,
        column="Residual",
        ref_x=0.0,
        ref_label="Zero Excess (Perfect)",
        title="Distribution of Excess Coherence",
        xlabel="Excess Coherence (Max Coherence - Welch Bound)\nCloser to 0.0 is Better",
    )


  


# ==========================================
#  MAIN EXECUTION
# ==========================================

if __name__ == "__main__":
    seed_everything(42)

    # 3-D experiments, in a fixed order (the seeded RNG stream depends on it):
    # distribution landscape, Welch optimality, RMS error surface.
    for experiment in (run_exp2_distribution_3d, run_exp3_welch_3d, run_exp4_rms_error_3d):
        experiment()

    # Aggregate KDE charts are drawn with the seaborn "talk" theme.
    sns.set_theme(style="white", context="talk")
    df_agg = generate_aggregate_data()
    for plot_fn in (plot_final_optimality_distribution, plot_final_residual_distribution):
        plot_fn(df_agg)