In [30]:
import numpy as np
import tensorly as tl
from tensorly.cp_tensor import cp_to_tensor

# Inverse-temperature schedule: alpha goes from alpha_0 (iteration 0) to
# alpha_K (iteration K) over K CBO iterations.
alpha_0, alpha_K = 1e3, 1e15
K = 2_000  # number of CBO iterations

def alpha_sequence(alpha_0, alpha_K, K, spacing="linear"):
    """Return the K + 1 schedule values alpha_0, ..., alpha_K (k = 0..K).

    Args:
        alpha_0: value at step k = 0.
        alpha_K: value at step k = K.
        K: number of steps; must be >= 1.
        spacing: how intermediate values are placed.
            "linear" (default): alpha_0 + (k / K) * (alpha_K - alpha_0).
                When alpha_K / alpha_0 spans many orders of magnitude this
                jumps to roughly alpha_K / K already at k = 1 (e.g. about 5e11
                for 1e3 -> 1e15 with K = 2000), so alpha_0 only matters for
                the first step.
            "geometric": alpha_0 * (alpha_K / alpha_0) ** (k / K), i.e. a
                constant ratio between consecutive values. Both endpoints
                must be non-zero and have the same sign.

    Returns:
        1-D numpy array of length K + 1.

    Raises:
        ValueError: if K < 1, if spacing is unknown, or if geometric spacing
            is requested with endpoints that are zero or of opposite sign.
    """
    if K < 1:
        # K = 0 would otherwise divide 0 by 0 and silently yield NaN.
        raise ValueError(f"K must be >= 1, got {K!r}")

    if spacing == "linear":
        ks = np.arange(K + 1)  # k = 0, 1, ..., K
        return alpha_0 + (ks / K) * (alpha_K - alpha_0)

    if spacing == "geometric":
        if alpha_0 * alpha_K <= 0:
            raise ValueError(
                "geometric spacing needs non-zero endpoints of the same sign, "
                f"got alpha_0={alpha_0!r}, alpha_K={alpha_K!r}"
            )
        return np.geomspace(alpha_0, alpha_K, K + 1)

    raise ValueError(f"unknown spacing {spacing!r}; expected 'linear' or 'geometric'")

# Precompute the whole schedule; alpha_values[k] is the alpha used at iteration k.
alpha_values = alpha_sequence(alpha_0=alpha_0, alpha_K=alpha_K, K=K)

# Define the objective function, consensus, updates, etc.
def objective_function(particle, tensor, rank):
    """Return the reconstruction error of one particle's CP factors.

    The factor matrices particle['A'], particle['B'] and particle['C'] are
    combined by `cp_to_tensor` with all `rank` component weights set to 1.
    The result is the `tl.norm` (default order) of `tensor` minus that
    reconstruction. Lower is better; this is the energy the consensus
    weighting and the drift condition are based on.
    """
    factor_matrices = [particle[name] for name in ('A', 'B', 'C')]
    approximation = cp_to_tensor((np.ones(rank), factor_matrices))
    residual = tensor - approximation
    return tl.norm(residual)

def compute_consensus_point(particles, alpha, tensor, rank):
    """Return the Gibbs-weighted consensus point of a particle ensemble.

    Particle i gets weight exp(-alpha * (E_i - E_min)), where E_i is its
    `objective_function` value and E_min is the smallest value in the
    ensemble. Shifting by E_min keeps the best particle's weight at exactly 1.
    The denominator is therefore >= 1 and the exponentials cannot overflow,
    even for very large alpha. For large alpha the consensus collapses onto
    the best particle.

    Args:
        particles: non-empty sequence of dicts holding factor arrays under
            'A', 'B' and 'C'.
        alpha: weighting parameter; larger values concentrate the weight on
            the lowest-energy particles.
        tensor: target tensor, passed through to `objective_function`.
        rank: CP rank, passed through to `objective_function`.

    Returns:
        dict with keys 'A', 'B', 'C' holding the weighted-average factors.
        These are new arrays; the particles are not modified.

    Raises:
        ValueError: if `particles` is empty.
    """
    if not particles:
        raise ValueError("compute_consensus_point() requires at least one particle")

    # Each objective evaluation is a full CP reconstruction, so evaluate every
    # particle exactly once and reuse the energies for the minimum and weights.
    energies = [objective_function(p, tensor, rank) for p in particles]
    min_energy = min(energies)

    keys = ('A', 'B', 'C')
    numerators = {key: np.zeros_like(particles[0][key]) for key in keys}
    denominator = 0.0

    for particle, energy in zip(particles, energies):
        weight = np.exp(-alpha * (energy - min_energy))
        for key in keys:
            numerators[key] += weight * particle[key]
        denominator += weight

    return {key: numerators[key] / denominator for key in keys}

def anisotropic_update(particles, consensus_point, lambda_, sigma, dt, tensor, rank):
    """Advance every particle in place by one explicit anisotropic CBO step.

    For each factor X in ('A', 'B', 'C') with consensus value v:
        X <- X + drift + sigma * (X - v) * Z * sqrt(dt),   Z ~ N(0, 1) elementwise
    where drift = -lambda_ * (X - v) * dt, applied only if the consensus point
    has a strictly lower objective than the particle (otherwise the drift is
    zero). Because the noise is scaled by (X - v), a particle sitting exactly
    on the consensus point does not move.

    Args:
        particles: list of dicts with factor arrays under 'A', 'B', 'C'; the
            arrays are updated in place.
        consensus_point: dict with the same keys (e.g. from
            compute_consensus_point).
        lambda_: drift rate toward the consensus point.
        sigma: diffusion strength.
        dt: step size, either a number or a list whose first entry is used.
        tensor, rank: passed through to `objective_function`.

    Returns:
        The same `particles` list object, after the in-place update.
    """
    step = dt[0] if isinstance(dt, list) else dt
    threshold = objective_function(consensus_point, tensor, rank)

    for particle in particles:
        # Drift only pulls particles that are strictly worse than the consensus.
        pulled = threshold < objective_function(particle, tensor, rank)

        increments = {}
        for key in ('A', 'B', 'C'):
            offset = particle[key] - consensus_point[key]
            if pulled:
                drift = (-lambda_) * offset * step
            else:
                drift = np.zeros_like(particle[key])
            # One standard-normal draw per entry, in A, B, C order per particle.
            noise = np.random.normal(loc=0, scale=1, size=particle[key].shape)
            increments[key] = drift + sigma * offset * noise * np.sqrt(step)

        for key in ('A', 'B', 'C'):
            particle[key] += increments[key]

    return particles

# Example setup
import warnings

np.random.seed(42)
I, J, K_tensor = 20, 25, 30
rank = 10
tensor = tl.tensor(np.random.random((I, J, K_tensor)))

# CBO hyperparameters.
lambda_ = 180  # drift rate toward the consensus point
sigma = 10     # anisotropic diffusion strength
dt = 0.01      # time step
nu = 1000      # number of particles
n_iterations = K  # same K (= 2000) that sized alpha_values, so alpha_values[iteration] always exists

# Mean-square stability of the explicit step in anisotropic_update. While the
# consensus point v stays fixed, each coordinate of (X - v) of a particle that
# feels drift is multiplied per step by (1 - lambda_*dt) + sigma*sqrt(dt)*Z,
# with Z ~ N(0, 1). Its expected square therefore grows by the factor below.
# For lambda_=180, sigma=10, dt=0.01 this is 0.64 + 1.0 = 1.64 > 1, so
# particles spread away from the consensus instead of concentrating on it.
# Separately, once alpha is large the consensus equals the best particle, and
# that particle gets neither drift (strict `<` test) nor diffusion (zero
# offset). The consensus error can then only improve if another particle lands
# on a better point, which the diverging ensemble does not do. This is
# consistent with the constant error in the recorded run.
ms_growth = (1 - lambda_ * dt) ** 2 + sigma ** 2 * dt
if ms_growth >= 1:
    warnings.warn(
        "CBO step is not mean-square contracting: "
        f"(1 - lambda_*dt)**2 + sigma**2*dt = {ms_growth:.3g} >= 1; "
        "choose e.g. lambda_*dt close to 1 and sigma**2*dt < 1.",
        RuntimeWarning,
    )

particles = []
for _ in range(nu):
    # Draw order A, B, C per particle fixes the random stream after seeding.
    A = np.random.randn(I, rank)
    B = np.random.randn(J, rank)
    C = np.random.randn(K_tensor, rank)
    particles.append({'A': A, 'B': B, 'C': C})

# Main iteration loop. The error printed at iteration k belongs to the
# ensemble after k update steps; no step is taken after the final report.
for iteration in range(n_iterations + 1):
    alpha_current = alpha_values[iteration]
    consensus_point = compute_consensus_point(particles, alpha=alpha_current, tensor=tensor, rank=rank)
    consensus_tensor = cp_to_tensor((np.ones(rank), [consensus_point['A'], consensus_point['B'], consensus_point['C']]))
    consensus_error = tl.norm(tensor - consensus_tensor)
    print(f"Iteration {iteration}/{n_iterations}, alpha={alpha_current}, Consensus Error={consensus_error}")
    if iteration < n_iterations:
        particles = anisotropic_update(particles, consensus_point, lambda_=lambda_, sigma=sigma, dt=dt, tensor=tensor, rank=rank)


Iteration 0/2000, alpha=1000.0, Consensus Error=278.98220617942013
Iteration 1/2000, alpha=500000000999.5, Consensus Error=278.98220617942013
Iteration 2/2000, alpha=1000000000999.0, Consensus Error=278.98220617942013
Iteration 3/2000, alpha=1500000000998.5, Consensus Error=278.98220617942013
Iteration 4/2000, alpha=2000000000998.0, Consensus Error=278.98220617942013
Iteration 5/2000, alpha=2500000000997.5, Consensus Error=278.98220617942013
Iteration 6/2000, alpha=3000000000997.0, Consensus Error=278.98220617942013
Iteration 7/2000, alpha=3500000000996.5, Consensus Error=278.98220617942013
Iteration 8/2000, alpha=4000000000996.0, Consensus Error=278.98220617942013
Iteration 9/2000, alpha=4500000000995.5, Consensus Error=278.98220617942013
Iteration 10/2000, alpha=5000000000995.0, Consensus Error=278.98220617942013
Iteration 11/2000, alpha=5500000000994.5, Consensus Error=278.98220617942013
Iteration 12/2000, alpha=6000000000994.0, Consensus Error=278.98220617942013
Iteration 13/2000, 

KeyboardInterrupt: 