In [None]:
import itertools
import numpy as np
import pandas as pd
import tensorly as tl
from tensorly.cp_tensor import cp_to_tensor
import matplotlib.pyplot as plt

# ------------------------------------------------------------------
# 1) Parameters, Data Generation, and Alpha Schedule
# ------------------------------------------------------------------

# Seed NumPy's global RNG so that every run draws the same random numbers
np.random.seed(42)

# Sizes of the three tensor modes and the CP rank used for the factor matrices
I = 15
J = 20
K_tensor = 25
rank = 10

# Target tensor of shape (I, J, K_tensor) with uniformly distributed entries
tensor = tl.tensor(np.random.uniform(low=0, high=1, size=(I, J, K_tensor)))

# First and last value of the alpha schedule
alpha_0, alpha_final = 1e3, 1e15

# Absolute-error threshold used by the main loop for early stopping
# NOTE(review): the origin of this specific value is not visible in this file
error_bound = 22.923548

# Define a small function to generate alpha values (linearly spaced)
def alpha_sequence(alpha_start, alpha_end, K):
    """
    Build the linearly spaced alpha schedule.

    Parameters:
    - alpha_start: value of the schedule at step 0
    - alpha_end:   value of the schedule at step K
    - K:           index of the last step; the schedule has K + 1 entries

    Returns a NumPy array 'alphas' with
        alphas[k] = alpha_start + (k / K) * (alpha_end - alpha_start),  k = 0..K.
    For K == 0 the schedule is the single value alpha_start; for K < 0 the
    returned array is empty.
    """
    if K == 0:
        # k / K would be 0 / 0 here and turn the only entry into NaN
        return np.array([alpha_start], dtype=float)

    ks = np.arange(K + 1)
    return alpha_start + (ks / K) * (alpha_end - alpha_start)

# ------------------------------------------------------------------
# 2) Key Functions for the Particle Swarm
# ------------------------------------------------------------------

def objective_function(particle, tensor, rank):
    """
    Reconstruction error of a single particle.

    The factor matrices stored under 'A', 'B' and 'C' in 'particle' are
    combined into a CP tensor whose 'rank' component weights are all one.
    The value returned is tl.norm of the difference between 'tensor' and
    that reconstruction.
    """
    factors = [particle[name] for name in ('A', 'B', 'C')]
    unit_weights = np.ones(rank)
    approximation = cp_to_tensor((unit_weights, factors))
    residual = tensor - approximation
    return tl.norm(residual)

def compute_consensus_point(particles, alpha, tensor, rank):
    """
    Boltzmann-weighted consensus of the swarm.

    Parameters:
    - particles: iterable of dicts holding factor matrices under 'A', 'B', 'C'
    - alpha:     weighting exponent; larger values concentrate the weight on
                 the lowest-energy particles
    - tensor:    target tensor passed on to objective_function
    - rank:      CP rank passed on to objective_function

    Each particle p gets the weight exp(-alpha * (E_p - E_min)), where E_p is
    its objective value and E_min the smallest objective value in the swarm.
    Subtracting E_min gives the best particle weight 1, so the normalising
    sum never drops below 1 (as long as the energies are finite).

    Returns a dict with the weighted averages of 'A', 'B' and 'C'.
    Raises ValueError if 'particles' is empty.
    """
    particles = list(particles)

    # Evaluate the objective exactly once per particle; it is by far the most
    # expensive step and the values are needed both for E_min and the weights.
    energies = [objective_function(p, tensor, rank) for p in particles]
    if not energies:
        raise ValueError("compute_consensus_point requires at least one particle")

    best_index = min(range(len(energies)), key=energies.__getitem__)
    min_energy = energies[best_index]
    min_particle = particles[best_index]

    numerator_A = np.zeros_like(min_particle['A'])
    numerator_B = np.zeros_like(min_particle['B'])
    numerator_C = np.zeros_like(min_particle['C'])
    denominator = 0.0

    for p, energy_p in zip(particles, energies):
        weight = np.exp(-alpha * (energy_p - min_energy))
        numerator_A += weight * p['A']
        numerator_B += weight * p['B']
        numerator_C += weight * p['C']
        denominator += weight

    return {
        'A': numerator_A / denominator,
        'B': numerator_B / denominator,
        'C': numerator_C / denominator,
    }

def project_matrices_to_ball(particles, consensus_point, eta):
    """
    Pull every factor matrix of every particle towards the consensus.

    For each of 'A', 'B', 'C' the matrix is flattened and its Euclidean
    distance d to the flattened consensus matrix is measured. When
    eta * d < d, the matrix is moved onto the sphere of radius eta * d
    around the consensus, along its current direction; otherwise its values
    are kept. In both cases the dict entry is rebound to a new array of the
    original shape. The same 'particles' object is returned.
    """
    for particle in particles:
        for name in ('A', 'B', 'C'):
            original_shape = particle[name].shape
            flat = particle[name].flatten()
            center = consensus_point[name].flatten()

            offset = flat - center
            distance = np.linalg.norm(offset)
            radius = eta * distance

            # Only shrink when the target radius is strictly smaller
            if radius < distance:
                unit_direction = offset / distance
                flat = center + radius * unit_direction

            particle[name] = flat.reshape(original_shape)
    return particles

def anisotropic_update(particles, consensus_point, lambda_, sigma, dt, tensor, rank):
    """
    One drift-diffusion step applied in place to every particle.

    - Drift: -lambda_ * (X - X_consensus) * dt, applied only to particles whose
      objective value is strictly larger than that of the consensus point.
    - Diffusion: sigma * (X - X_consensus) * N(0, 1) * sqrt(dt), element-wise,
      applied to every particle.

    X stands for each of the factor matrices 'A', 'B', 'C'. The matrices are
    modified in place and the same 'particles' object is returned.
    """
    factor_names = ('A', 'B', 'C')
    sqrt_dt = np.sqrt(dt)
    reference_loss = objective_function(consensus_point, tensor, rank)

    for particle in particles:
        # Decide on the drift before any factor of this particle is touched
        pull_toward_consensus = reference_loss < objective_function(particle, tensor, rank)

        increments = []
        for name in factor_names:
            factor = particle[name]
            gap = factor - consensus_point[name]

            if pull_toward_consensus:
                drift = -lambda_ * gap * dt
            else:
                drift = np.zeros_like(factor)

            # Standard-normal noise, one independent draw per matrix entry
            noise = np.random.normal(loc=0, scale=1, size=factor.shape)
            diffusion = sigma * gap * noise * sqrt_dt

            increments.append(drift + diffusion)

        # Apply all three increments only after they have all been computed
        for name, step in zip(factor_names, increments):
            particle[name] += step

    return particles

def resample_particles_around_consensus(particles, consensus_point, noise_scale=0.01,
                                        clip_bound=0.06):
    """
    Replace each particle's A, B, C with consensus + clipped Gaussian noise.

    Parameters:
    - particles:       list of dicts; their 'A', 'B', 'C' entries are overwritten
    - consensus_point: dict with the consensus matrices under 'A', 'B', 'C'
    - noise_scale:     standard deviation of the zero-mean normal noise
    - clip_bound:      noise is clipped to [-clip_bound, clip_bound]
                       (default 0.06, the value that was previously hard-coded)

    Returns the same 'particles' object. Raises ValueError if clip_bound is
    negative, because the clipping interval would then be inverted.
    """
    if clip_bound < 0:
        raise ValueError("clip_bound must be non-negative")

    for p in particles:
        for key in ['A', 'B', 'C']:
            center = consensus_point[key]
            # Zero-mean noise with the requested scale, limited to the clip interval
            noise = np.random.normal(loc=0.0, scale=noise_scale, size=center.shape)
            noise = np.clip(noise, -clip_bound, clip_bound)
            p[key] = center + noise
    return particles

# ------------------------------------------------------------------
# 3) Main Script With "While" Early Stopping
# ------------------------------------------------------------------

# Small parameter grid; extend any of these lists to sweep more combinations.
dt_list = [0.01]
nu_list = [25000]
lambda_sigma_list = [(0.5, 5)]
K_list = [100000]   # K is used as the *maximum* iteration count
eta_list = [0.9]

param_grid = list(itertools.product(dt_list, nu_list, lambda_sigma_list, K_list, eta_list))

# The target tensor never changes, so the mask of its non-zero entries and the
# norm that normalises the relative error are computed once instead of on
# every iteration.
nonzero_mask = (tensor != 0)
nonzero_tensor_values = tensor[nonzero_mask]
nonzero_tensor_norm = tl.norm(nonzero_tensor_values)

results = []

for dt, nu, (lambda_, sigma), K, eta in param_grid:
    print(f"\nRunning with dt={dt}, nu={nu}, lambda={lambda_}, sigma={sigma}, "
          f"K(max)={K}, eta={eta}, error_bound={error_bound}")

    # 1) Initialize the swarm: nu particles with uniformly random factor matrices
    particles = []
    for _ in range(nu):
        A = np.random.uniform(0, 1, (I, rank))
        B = np.random.uniform(0, 1, (J, rank))
        C = np.random.uniform(0, 1, (K_tensor, rank))
        particles.append({'A': A, 'B': B, 'C': C})

    # 2) Create alpha schedule
    alpha_values = alpha_sequence(alpha_0, alpha_final, K)  # length = K+1

    # 3) While loop for early stopping
    iteration = 0
    max_iterations = K + 1  # iterations 0..K inclusive, i.e. K+1 consensus evaluations
    final_abs_error = None
    final_rel_error = None

    # -- Errors of the consensus point at each iteration, kept for plotting --
    abs_errors = []
    rel_errors = []

    while iteration < max_iterations:
        # a) current alpha
        alpha_current = alpha_values[iteration]

        # b) consensus point and the tensor it reconstructs (unit CP weights)
        consensus_point = compute_consensus_point(particles, alpha_current, tensor, rank)
        consensus_tensor = cp_to_tensor((np.ones(rank),
                                         [consensus_point['A'],
                                          consensus_point['B'],
                                          consensus_point['C']]))

        # c) compute errors; the relative error only looks at non-zero target entries
        consensus_abs_error = tl.norm(tensor - consensus_tensor)
        consensus_rel_error = (
            tl.norm(nonzero_tensor_values - consensus_tensor[nonzero_mask]) /
            nonzero_tensor_norm
        )

        # -- Store these errors for plotting --
        abs_errors.append(consensus_abs_error)
        rel_errors.append(consensus_rel_error)

        # d) check if we've reached our error bound
        if consensus_abs_error < error_bound:
            print(f"Early stopping at iteration={iteration}, "
                  f"abs_error={consensus_abs_error:.4e} < {error_bound}")
            final_abs_error = consensus_abs_error
            final_rel_error = consensus_rel_error
            break

        # e) If not done and not the last iteration, update the swarm
        if iteration < max_iterations - 1:
            # 1) drift-diffusion
            particles = anisotropic_update(
                particles,
                consensus_point,
                lambda_=lambda_,
                sigma=sigma,
                dt=dt,
                tensor=tensor,
                rank=rank
            )
            # 2) resample
            # NOTE(review): this overwrites every factor matrix with
            # consensus + noise, so the positions produced by the
            # drift-diffusion step above are discarded -- confirm this is intended.
            particles = resample_particles_around_consensus(
                particles,
                consensus_point,
                noise_scale=0.01
            )
            # 3) project
            particles = project_matrices_to_ball(particles, consensus_point, eta)

            print(f"iteration: {iteration}, abs error: {consensus_abs_error}, rel error: {consensus_rel_error}")
        iteration += 1

    # If the loop ran to exhaustion, the particles were not modified after the
    # last consensus evaluation (the final iteration skips the update), so the
    # last recorded errors already are the final errors; recomputing the
    # consensus would only repeat the same work.
    if final_abs_error is None:
        final_abs_error = abs_errors[-1]
        final_rel_error = rel_errors[-1]

    # Record final results.
    # 'iterations_used' is the iteration index at which early stopping
    # triggered, or K+1 when the loop ran to exhaustion.
    results.append({
        'dt': dt,
        'nu': nu,
        'lambda': lambda_,
        'sigma': sigma,
        'eta': eta,
        'K_max': K,
        'iterations_used': iteration,
        'consensus_abs_error': final_abs_error,
        'consensus_rel_error': final_rel_error
    })

    # ------------------------------------------------------------------
    # 4) Plot the error vs iteration for this parameter setting
    # ------------------------------------------------------------------
    plt.figure(figsize=(8, 5))
    plt.plot(abs_errors, label='Absolute Error')
    plt.plot(rel_errors, label='Relative Error')
    plt.title(f'Error over Iterations (dt={dt}, nu={nu}, λ={lambda_}, σ={sigma}, η={eta})')
    plt.xlabel('Iteration')
    plt.ylabel('Error')
    plt.legend()
    plt.grid(True)
    plt.show()

# Convert results to a DataFrame and display
df_results = pd.DataFrame(results)
print("\nFinal Results:")
print(df_results)



Running with dt=0.01, nu=25000, lambda=0.5, sigma=5, K(max)=100000, eta=0.9, error_bound=22.923548
iteration: 0, abs error: 53.38088116953912, rel error: 1.0741605627539985
iteration: 1, abs error: 52.691603204639534, rel error: 1.0602905180779103
iteration: 2, abs error: 51.841275447751, rel error: 1.0431797375540035
iteration: 3, abs error: 51.1364452517995, rel error: 1.0289967420068766
iteration: 4, abs error: 50.54143391243752, rel error: 1.0170235841808724
iteration: 5, abs error: 49.92680929166027, rel error: 1.0046557567102143
iteration: 6, abs error: 49.26801738246593, rel error: 0.9913991698496404
iteration: 7, abs error: 48.65215399029614, rel error: 0.9790064151138834
iteration: 8, abs error: 47.99910133657148, rel error: 0.9658653168280581
iteration: 9, abs error: 47.391964580429104, rel error: 0.9536481644439443
iteration: 10, abs error: 46.72923509230901, rel error: 0.9403123433725048
iteration: 11, abs error: 46.14545381915276, rel error: 0.928565163349246
iteration: 1