In [4]:
import numpy as np
from scipy.sparse.linalg import svds
import matplotlib.pyplot as plt

def tucker_sketch(tensor, sketch_size, target_rank, max_iter=50, tol=1e-5):
    # Obtain dimensions of the input tensor
    dims = tensor.shape

    # Randomly sample the columns of each mode using TensorSketch
    sketch_indices = [np.random.choice(dims[i], size=sketch_size) for i in range(len(dims))]
    sketches = [tensor.take(sketch_indices[i], axis=i) for i in range(len(dims))]

    # Compute the SVD of each mode using the sketches
    factors = []
    for i in range(len(dims)):
        U, S, V = svds(sketches[i], k=target_rank)
        factors.append(U)

    # Compute the core tensor from the full tensor and the factor matrices
    core = tensor.copy()
    for i in range(len(dims)):
        core = np.tensordot(core, factors[i], axes=(i, 1))
        core = np.moveaxis(core, -1, 0)

    # Initialize the Tucker factors
    for i in range(len(dims)):
        factors[i] = np.random.randn(dims[i], target_rank)

    # Initialize the fitness and convergence history
    fit_hist = []
    conv_hist = []

    # Tucker ALS iteration
    for it in range(max_iter):
        # Update each factor matrix
        for i in range(len(dims)):
            tensor_mode = np.rollaxis(core, i)
            tensor_mode = tensor_mode.reshape((target_rank, -1)).T
            kr_prod = np.ones((dims[i], target_rank))
            for j in range(len(dims)):
                if j != i:
                    kr_prod = np.multiply(kr_prod, factors[j][sketch_indices[j], :])
            mat = kr_prod.T @ kr_prod
            vec = kr_prod.T @ tensor_mode[:, :target_rank]
            factors[i][sketch_indices[i], :] = np.linalg.solve(mat, vec).T

        # Compute the new core tensor
        core = tensor.copy()
        for i in range(len(dims)):
            core = np.tensordot(core, factors[i], axes=(i, 1))
            core = np.moveaxis(core, -1, 0)

        # Compute the new fitness and convergence
        fit = 1 - np.linalg.norm(tensor - core) / np.linalg.norm(tensor)
        fit_hist.append(fit)
        conv = np.abs(fit - fit_hist[-2]) if it > 0 else np.inf
        conv_hist.append(conv)

        # Check for convergence
        if conv < tol:
            break
    # Compute the Tucker decomposition
    return core, factors, fit_hist


In [5]:
   # Plot the fitness vs iterations plot if requested
T=np.random.randn(2,2,2,2)
sketch=(1,1,2,2)
ranks=(1,2,2,2)
core, factors, fit_hist = tucker_sketch(T,sketch,ranks)
plt.plot(fit_hist)
plt.xlabel('Iterations')
plt.ylabel('Fitness')
plt.title('Fitness vs Iterations Plot')
plt.show()


ValueError: array must have ndim <= 2