In [None]:
import os
os.environ["WANDB_MODE"] = "offline"
import wandb
%load_ext autoreload
%autoreload 2

# Utils

In [None]:
from typing import List, Callable, Tuple
import matplotlib.pyplot as plt
from sfett.tensors import TTSFTuckerTensor
import torch.backends.opt_einsum as opt_einsum
from tqdm.notebook import tqdm
from sfett.tools import (
    generate_laplace_TT,
    tt_scalar_prod,
    henon_heiles,
)
from IPython.display import clear_output
from sfett.manifolds import TTSFTuckerManifold
from sfett.tensors import tt_sf_tucker_scalar_prod, TTSFTuckerTangentVector
import random

from functools import partial
from tqdm.notebook import tqdm

opt_einsum.strategy = "greedy"
import torch
import tntorch as tn
import numpy as np
import time

import logging
from datetime import datetime
from typing import Optional, List, Union

# Tensors created without an explicit dtype are float64 from here on.
torch.set_default_dtype(torch.float64)
# Prefer the first CUDA device, but fall back to the CPU so the notebook still
# runs (slowly) on machines without a GPU instead of failing at the first
# tensor allocation.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
import scipy
import wandb

In [None]:
# Module-level logger; every kernel session writes to its own timestamped file.
logger = logging.getLogger(__name__)
logger_name = str(datetime.now())
print(logger_name)
# basicConfig opens the log file immediately and raises FileNotFoundError when
# the target directory is missing, so make sure it exists first.
os.makedirs(".logs", exist_ok=True)
logging.basicConfig(filename=f".logs/Hilbert_exp_{logger_name}.log", level=logging.INFO)

In [None]:
def laplace_quad(
    X: TTSFTuckerTensor,
    Y: Optional[TTSFTuckerTensor],
    laplace_mat: List[torch.Tensor],
    potenital: Optional[TTSFTuckerTensor] = None,
) -> torch.Tensor:
    """Evaluate the (bi)linear form built from ``laplace_mat`` and an optional potential.

    Each 4-way core of ``laplace_mat`` is contracted on its two middle indices
    with the matching Tucker factor of ``X`` (left) and of ``Y`` (right). The
    resulting cores are passed to ``tt_scalar_prod`` together with the TT cores
    of ``X`` and ``Y``. When ``potenital`` is given, the scalar product of ``X``
    with ``potenital.hadamard_product(Y)`` is added.

    Args:
        X: left argument of the form.
        Y: right argument; ``None`` means "use ``X``" (quadratic form).
        laplace_mat: one operator core per mode, indexed as ``abcd`` in the
            einsum below.
        potenital: optional potential tensor.

    Returns:
        Whatever ``tt_scalar_prod`` returns, plus the potential term (or ``0``).
    """
    mode_count = len(laplace_mat)
    unshared_count = mode_count - X.shared_factors_amount

    def factor_of(tensor, mode):
        # Modes past the unshared ones all use the last (shared) Tucker factor.
        return tensor.tucker_factors[mode if mode < unshared_count else -1]

    projected_cores = []
    for mode, operator_core in enumerate(laplace_mat):
        left_factor = factor_of(X, mode)
        right_factor = left_factor if Y is None else factor_of(Y, mode)
        projected_cores.append(
            torch.einsum("abcd,br,cy->aryd", operator_core, left_factor, right_factor)
        )

    right_cores = None if Y is None else Y.tt_cores
    operator_term = tt_scalar_prod(projected_cores, X.tt_cores, right_cores)

    if potenital is None:
        potential_term = 0
    else:
        right_arg = X if Y is None else Y
        potential_term = tt_sf_tucker_scalar_prod(
            X, potenital.hadamard_product(right_arg)
        )
    return operator_term + potential_term


def steepest_desc(
    X: TTSFTuckerTensor,
    direction: TTSFTuckerTensor,
    bilinear_A: Callable[[TTSFTuckerTensor, TTSFTuckerTensor], torch.Tensor],
) -> TTSFTuckerTensor:
    """Pick the best combination ``alpha * X + beta * direction`` for ``bilinear_A``.

    Solves the 2x2 generalized eigenproblem ``STAS c = lambda * STBS c`` on the
    span of ``X`` and ``direction`` and returns the combination given by the
    first eigenvector returned by ``scipy.linalg.eigh`` (column 0).

    Args:
        X: current iterate.
        direction: search direction.
        bilinear_A: bilinear form; called as ``bilinear_A(U, V)``.

    Returns:
        The new iterate ``alpha * X + beta * direction`` (not normalized and
        not rounded here).

    Raises:
        AssertionError: if the form increased by more than ``1e-8``.
    """
    quad_x = bilinear_A(X, X)  # x.T A x
    quad_dir = bilinear_A(direction, direction)  # grad.T A grad
    cross = bilinear_A(X, direction)  # x.T A grad
    # Convert the entries to Python floats so the small dense problem is built
    # on the host regardless of which device the scalars live on.
    STAS = [
        [float(quad_x), float(cross)],
        [float(cross), float(quad_dir)],
    ]
    gram_x = tt_sf_tucker_scalar_prod(X, X)
    gram_dir = tt_sf_tucker_scalar_prod(direction, direction)
    gram_cross = tt_sf_tucker_scalar_prod(X, direction)
    STBS = [
        [float(gram_x), float(gram_cross)],
        [float(gram_cross), float(gram_dir)],
    ]
    _, eig_vectors = scipy.linalg.eigh(STAS, STBS)
    # Column 0 is the eigenvector of the smallest eigenvalue (eigh sorts ascending).
    alpha, beta = float(eig_vectors[0][0]), float(eig_vectors[1][0])
    new_X = alpha * X + beta * direction
    # Reuse the value computed above instead of evaluating bilinear_A(X, X) again.
    delta = quad_x - bilinear_A(new_X, new_X)
    assert delta >= 0 or abs(delta) <= 1e-8, f"{delta = } the nex step is not optimal!"
    return new_X


def grad_desc(
    bilinear_A: Callable[[TTSFTuckerTensor, TTSFTuckerTensor], torch.Tensor],
    X: TTSFTuckerTensor,
    lb_opt: float,
    max_iters: int = 100,
    lr: Union[float, Callable[..., TTSFTuckerTensor]] = 1e-4 / 5,
    eps: float = 1e-6,
):
    """Minimize ``bilinear_A(X, X)`` over normalized tensors of fixed ranks.

    Each iteration forms ``h = grad(X).construct() - 2 * f(X) * X``, takes a
    step, rounds back to the initial ranks and renormalizes.

    Args:
        bilinear_A: bilinear form; must accept ``Y=None`` as a keyword (it is
            wrapped with ``partial(bilinear_A, Y=None)`` for the gradient).
        X: starting point; normalized on entry.
        lb_opt: reference optimal value used for the relative error.
        max_iters: maximum number of iterations.
        lr: either a numeric step size (``X <- X - lr * h``) or a callable
            invoked as ``lr(X, h, bilinear_A)`` that returns the next iterate
            (e.g. ``steepest_desc``).
        eps: stop once the relative error drops below this value.

    Returns:
        ``(fun_log, grad_log, rel_errs)``: objective values, norms of ``h`` and
        relative errors. ``fun_log`` and ``rel_errs`` include the initial point.
    """
    X = X / X.norm()
    tt_ranks = X.tt_ranks
    sf_tucker_ranks = X.tucker_ranks
    obj_fun_grad = TTSFTuckerManifold().grad(partial(bilinear_A, Y=None))
    fun_log = [bilinear_A(X, X).cpu()]
    grad_log = []
    rel_errs = [torch.abs(fun_log[-1] - lb_opt) / lb_opt]
    iters = tqdm(range(max_iters))
    for _ in iters:
        h = obj_fun_grad(X)
        h = h.construct() - 2 * fun_log[-1] * X
        grad_log.append(h.norm().cpu())
        # Dispatch on callability rather than on `float`, so integer or other
        # numeric step sizes are not mistaken for a step-selection callback.
        if callable(lr):
            try:
                X = lr(X, h, bilinear_A)
            except Exception as e:
                # Best effort: report the failure and return what was logged so far.
                print(e)
                break
        else:
            X = X - (lr * h)

        # Return value is ignored, as in the rest of this notebook -- presumably
        # round() truncates in place; TODO confirm.
        X.round(tt_ranks=tt_ranks, sf_tucker_ranks=sf_tucker_ranks)
        X = (1 / X.norm()) * X
        fun_log.append(bilinear_A(X, X).cpu())
        rel_errs.append(torch.abs(fun_log[-1] - lb_opt) / lb_opt)
        iters.set_description(
            f"{fun_log[-1]:.2f}" + f" {grad_log[-1]:.2f}" + f" {rel_errs[-1]:.2f}"
        )
        if rel_errs[-1] < eps:
            break
    return fun_log, grad_log, rel_errs


def rr_scipy(
    S: List[TTSFTuckerTensor],
    bilinear_A: Callable[[TTSFTuckerTensor, TTSFTuckerTensor], torch.Tensor],
):
    """Rayleigh-Ritz step on the subspace spanned by ``S``, solved with SciPy.

    Builds the Gram blocks ``S[i] @ S[j]`` and the projected-operator blocks
    ``bilinear_A(S[i], S[j])``, flattens them into square matrices and solves
    the generalized symmetric eigenproblem on the CPU.

    Args:
        S: basis tensors; ``S[0]`` provides ``device`` and ``batch_size``.
        bilinear_A: bilinear form; it receives clones of the basis tensors.

    Returns:
        ``(eig_vecs, eig_vals)`` for the first ``batch_size`` eigenpairs, moved
        back to the device of ``S[0]``.

    Raises:
        Exception: whatever ``scipy.linalg.eigh`` raises, after printing the
            spectra of both matrices.
    """
    basis_size = len(S)
    device = S[0].device
    p = S[0].batch_size
    block_shape = (basis_size, basis_size, p, p)
    STBS = torch.zeros(block_shape, device=device)
    STAS = torch.zeros(block_shape, device=device)

    # Fill the upper triangle of blocks and mirror it into the lower one.
    for i, left in enumerate(S):
        for j in range(i, basis_size):
            right = S[j]
            STBS[i, j] = left @ right
            STAS[i, j] = bilinear_A(left.clone(), right.clone())
            if i == j:
                continue
            STBS[j, i] = STBS[i, j].T
            STAS[j, i] = STAS[i, j].T

    # (i, j, p, p) block layout -> one (basis_size * p) square matrix.
    side = basis_size * p
    STBS = STBS.transpose(1, 2).reshape(side, side)
    STAS = STAS.transpose(1, 2).reshape(side, side)

    try:
        eig_vals, eig_vectors = scipy.linalg.eigh(STAS.to("cpu"), STBS.to("cpu"))
    except Exception as e:
        # Show both spectra to help diagnose the failure before re-raising.
        print(torch.linalg.eigh(STAS).eigenvalues, torch.linalg.eigh(STBS).eigenvalues)
        raise e

    eig_vecs = torch.Tensor(eig_vectors[:, :p]).to(device)
    eig_vals = torch.Tensor(eig_vals[:p]).to(device)
    if p == 1:
        eig_vals = eig_vals.unsqueeze(-1)
    return eig_vecs, eig_vals


def lobpcg(
    bilinear_A: Callable[[TTSFTuckerTensor, TTSFTuckerTensor], torch.Tensor],
    X: TTSFTuckerTensor,
    rayleigh_ritz: Callable[[List[TTSFTuckerTensor], Callable], List[torch.Tensor]],
    lb_opt: float,
    max_iters: int = 100,
    eps: float = 1e-6,
):
    """LOBPCG-style minimization of ``bilinear_A(X, X)`` at fixed ranks.

    Every iteration runs Rayleigh-Ritz on ``[X, R]`` (plus the previous
    direction ``P`` once available), combines the basis with the returned
    coefficients, rounds back to the initial ranks, renormalizes and recomputes
    the residual ``R`` and the direction ``P`` at the new point.

    Args:
        bilinear_A: bilinear form; must accept ``Y=None`` as a keyword.
        X: starting point; normalized on entry.
        rayleigh_ritz: solver called as ``rayleigh_ritz(S, bilinear_A)`` and
            returning ``(coefficients, theta)``. ``None`` selects ``rr_scipy``.
        lb_opt: reference optimal value used for the relative error.
        max_iters: maximum number of iterations.
        eps: stop once the relative error drops below this value.

    Returns:
        ``(fun_log, grad_log, rel_errs)``: objective values, residual norms and
        relative errors, each including the initial point.
    """
    # Honour the caller's solver; the original unconditionally replaced it.
    if rayleigh_ritz is None:
        rayleigh_ritz = rr_scipy
    obj_fun_grad = TTSFTuckerManifold().grad(partial(bilinear_A, Y=None))
    X = (1 / X.norm()) * X
    _, theta = rayleigh_ritz([X], bilinear_A)
    r_tt, r_t = X.tt_ranks, X.tucker_ranks
    fun_log = [bilinear_A(X, X).cpu()]
    rel_errs = [torch.abs(fun_log[-1] - lb_opt) / lb_opt]
    R: TTSFTuckerTensor = 0.5 * obj_fun_grad(X).construct() - theta * X
    grad_log = [R.norm().cpu()]
    P: Optional[TTSFTuckerTensor] = None
    iters = tqdm(range(max_iters))
    for _ in iters:
        start = time.time()
        H = R
        S = [X, H] + ([] if P is None else [P])
        try:
            C, theta = rayleigh_ritz(S, bilinear_A)
        except Exception as e:
            # Best effort: report the failure and return what was logged so far.
            print(e)
            return fun_log, grad_log, rel_errs
        direction: TTSFTuckerTensor = C[1] * H
        if P is not None:
            direction += C[2] * P
        X = C[0] * X + direction
        P = direction

        # Return value is ignored -- presumably round() truncates in place;
        # TODO confirm.
        X.round(tt_ranks=r_tt, sf_tucker_ranks=r_t)
        X = (1 / X.norm()) * X

        R = 0.5 * obj_fun_grad(X).construct() - theta * X
        # Differentiate <., P> at the new X, then subtract the component along X.
        projector = TTSFTuckerManifold().grad(partial(tt_sf_tucker_scalar_prod, B=P))
        P = projector(X)
        P = P - tt_sf_tucker_scalar_prod(X, P.construct()) * TTSFTuckerTangentVector(X)
        P = P.construct()

        # Timing covers the update only, not the bookkeeping below.
        t = time.time() - start
        fun_log.append(bilinear_A(X, X).cpu())
        grad_log.append(R.norm().cpu())
        rel_errs.append(torch.abs(fun_log[-1] - lb_opt) / lb_opt)
        # wandb.log fails without an active run; skip it so the solver can also
        # be called outside the experiment drivers.
        if wandb.run is not None:
            wandb.log(
                {
                    "obj": fun_log[-1],
                    "grad_norm": grad_log[-1],
                    "rel_err": rel_errs[-1],
                    "iter_time": t,
                }
            )
        iters.set_description(
            f"{fun_log[-1]:.2f}" + f" {grad_log[-1]:.2f}" + f" {rel_errs[-1]:.2f}"
        )
        if (rel_errs[-1]) < eps:
            break
    return fun_log, grad_log, rel_errs


def compute_exp_laplace_no_pot(
    laplace_mat: List[torch.Tensor],
    max_tt_rank: int,
    max_sf_tucker_rank: int,
    opt_val: float,
    step: int = 2,
):
    """Run LOBPCG on the potential-free problem for each shared-factor count.

    For every ``_d_s`` in ``range(1, N + 1, step)`` an all-ones tensor is
    built, normalized and orthogonalized, and ``lobpcg`` is run for at most
    700 iterations inside its own wandb run.

    NOTE(review): ``max_tt_rank`` and ``max_sf_tucker_rank`` are not used; the
    ranks come from the hard-coded list ``[(1, 1)]``. Confirm this is intended.

    Args:
        laplace_mat: operator cores; device and grid size are read from the first.
        max_tt_rank: ignored (see note).
        max_sf_tucker_rank: ignored (see note).
        opt_val: reference optimum passed to ``lobpcg`` as ``lb_opt``.
        step: stride over the number of shared factors.

    Returns:
        Dict mapping each rank pair to ``(rel_errs_by_shared, seconds_by_shared)``.
    """
    N = len(laplace_mat)
    device = laplace_mat[0].device
    n = laplace_mat[0].shape[1]
    quad_form = partial(laplace_quad, laplace_mat=laplace_mat)
    tolerance = 1e-6
    history = {}
    for rank in tqdm([(1, 1)]):
        tt_r, tucker_r = rank
        core_shapes = (
            [[1, tucker_r, tt_r]]
            + [[tt_r, tucker_r, tt_r]] * (N - 2)
            + [[tt_r, tucker_r, 1]]
        )
        ones_cores = [torch.ones(shape, device=device) for shape in core_shapes]
        ones_factor = torch.ones(n, tucker_r, device=device)

        seconds_by_shared = {}
        rel_errs_by_shared = {}
        clear_output()
        for shared in range(1, N + 1, step):
            unshared = N - shared
            # One factor per unshared mode plus a single shared factor.
            start_point = TTSFTuckerTensor(
                tt_cores=[core.clone() for core in ones_cores],
                device=device,
                tucker_factors=[ones_factor.clone() for _ in range(unshared + 1)],
                shared_factors_amount=shared,
            )
            start_point /= start_point.norm()
            start_point.orthogonalize(-1)

            wandb.init(
                project=...,
                name=f"laplace_no_pot_sh={shared}_r={rank}",
                config={
                    "shared_factors": shared,
                    "rank": rank,
                    "name": "laplace_no_potenital",
                },
            )
            started = time.time()
            _, _, rel_errs = lobpcg(
                quad_form,
                start_point.clone(),
                rr_scipy,
                lb_opt=opt_val,
                max_iters=700,
                eps=tolerance,
            )
            finished = time.time()
            wandb.finish()
            seconds_by_shared[shared] = finished - started
            rel_errs_by_shared[shared] = rel_errs
        history[rank] = (rel_errs_by_shared, seconds_by_shared)
    return history


def compute_exp_laplace_pot(
    laplace_mat: List[torch.Tensor],
    pot_mat: List[torch.Tensor],
    max_tt_rank: int,
    max_sf_tucker_rank: int,
    opt_val: float,
    step: int = 2,
):
    """Run LOBPCG on the problem with a potential for each shared-factor count.

    For every ``_d_s`` in ``range(1, N + 1, step)`` an all-ones start tensor
    and a potential tensor (TT cores ``pot_mat``, identity Tucker factors) are
    built, and ``lobpcg`` is run for at most 700 iterations inside its own
    wandb run.

    NOTE(review): ``max_tt_rank`` and ``max_sf_tucker_rank`` are not used; the
    ranks come from the hard-coded list ``[(1, 1)]``. Confirm this is intended.

    Args:
        laplace_mat: operator cores; device and grid size are read from the first.
        pot_mat: TT cores of the potential.
        max_tt_rank: ignored (see note).
        max_sf_tucker_rank: ignored (see note).
        opt_val: reference optimum passed to ``lobpcg`` as ``lb_opt``.
        step: stride over the number of shared factors.

    Returns:
        Dict mapping each rank pair to ``(rel_errs_by_shared, seconds_by_shared)``.
    """
    N = len(laplace_mat)
    device = laplace_mat[0].device
    n = laplace_mat[0].shape[1]
    tolerance = 1e-6
    history = {}
    for rank in tqdm([(1, 1)]):
        print(rank)
        tt_r, tucker_r = rank
        core_shapes = (
            [[1, tucker_r, tt_r]]
            + [[tt_r, tucker_r, tt_r]] * (N - 2)
            + [[tt_r, tucker_r, 1]]
        )
        ones_cores = [torch.ones(shape, device=device) for shape in core_shapes]
        ones_factor = torch.ones(n, tucker_r, device=device)

        seconds_by_shared = {}
        rel_errs_by_shared = {}
        clear_output()
        for shared in range(1, N + 1, step):
            unshared = N - shared
            # One factor per unshared mode plus a single shared factor.
            start_point = TTSFTuckerTensor(
                tt_cores=[core.clone() for core in ones_cores],
                device=device,
                tucker_factors=[ones_factor.clone() for _ in range(unshared + 1)],
                shared_factors_amount=shared,
            )
            start_point /= start_point.norm()
            start_point.orthogonalize(-1)

            # Potential with the same sharing pattern and identity Tucker factors.
            potential = TTSFTuckerTensor(
                shared_factors_amount=shared,
                tt_cores=[core.clone() for core in pot_mat],
                tucker_factors=[torch.eye(n, device=device) for _ in range(unshared + 1)],
            )
            potential.round()
            quad_form = partial(
                laplace_quad,
                laplace_mat=[core.clone() for core in laplace_mat],
                potenital=potential,
            )

            wandb.init(
                project=...,
                name=f"laplace_hh_sh={shared}_r={rank}",
                config={"shared_factors": shared, "rank": rank, "name": "laplace_hh"},
            )
            started = time.time()
            _, _, rel_errs = lobpcg(
                quad_form,
                start_point.clone(),
                rr_scipy,
                lb_opt=opt_val,
                max_iters=700,
                eps=tolerance,
            )
            finished = time.time()
            wandb.finish()
            seconds_by_shared[shared] = finished - started
            rel_errs_by_shared[shared] = rel_errs
        history[rank] = (rel_errs_by_shared, seconds_by_shared)
    return history


def compute_time_laplace_no_pot(
    laplace_mat: List[torch.Tensor],
    max_tt_rank: int,
    max_sf_tucker_rank: int,
    opt_val: float,
    step: int = 2,
):
    """Time LOBPCG on the potential-free problem for each shared-factor count.

    For every ``_d_s`` in ``range(1, N + 1, step)`` a random start tensor of
    the requested ranks is built, normalized and orthogonalized, and ``lobpcg``
    is run (up to 70000 iterations) inside its own wandb run.

    The wall-clock duration of each ``lobpcg`` call is now recorded, as the
    ``compute_exp_*`` functions do; previously a constant ``0`` was stored.

    NOTE(review): ``[torch.randn(...)] * N`` repeats one tensor, so every
    Tucker factor starts from the same random matrix. Confirm this is intended.

    Args:
        laplace_mat: operator cores; device and grid size are read from the first.
        max_tt_rank: TT rank of the start tensor.
        max_sf_tucker_rank: Tucker rank of the start tensor.
        opt_val: reference optimum passed to ``lobpcg`` as ``lb_opt``.
        step: stride over the number of shared factors.

    Returns:
        ``{(max_tt_rank, max_sf_tucker_rank): (rel_errs_by_ds, seconds_by_ds)}``.
    """
    device = laplace_mat[0].device
    n = laplace_mat[0].shape[1]
    vanilla_laplace_bilinear = partial(laplace_quad, laplace_mat=laplace_mat)
    N = len(laplace_mat)
    rank_to_his = {}
    rel_err_tol = 1e-6

    tt_cores_shapes = (
        [[1, max_sf_tucker_rank, max_tt_rank]]
        + [[max_tt_rank, max_sf_tucker_rank, max_tt_rank]] * (N - 2)
        + [[max_tt_rank, max_sf_tucker_rank, 1]]
    )
    tt_cores = [torch.randn(sh, device=device) for sh in tt_cores_shapes]
    tucker_factors = [torch.randn(n, max_sf_tucker_rank, device=device)] * N

    lobpcg_time_to_ds = {}
    lobpcg_rel_err_to_ds = {}
    for _d_s in range(1, N + 1, step):
        _d_t = N - _d_s
        t = TTSFTuckerTensor(
            tt_cores=[c.clone() for c in tt_cores],
            device=device,
            tucker_factors=[f.clone() for f in tucker_factors[:_d_t]]
            + [tucker_factors[-1].clone()],
            shared_factors_amount=_d_s,
        )
        t /= t.norm()
        t.orthogonalize(-1)

        wandb.init(
            project=...,
            name=f"laplace_no_pot_sh={_d_s}_r={max(max_tt_rank, max_sf_tucker_rank)}_n={n}",
            config={
                "shared_factors": _d_s,
                "rank": [max_tt_rank, max_sf_tucker_rank],
                "name": "laplace_no_potenital",
                "n": n,
            },
        )
        # Time only the solver call, excluding wandb start-up and shutdown.
        start = time.time()
        _, _, lobpcg_rel_errs = lobpcg(
            vanilla_laplace_bilinear,
            t.clone(),
            rr_scipy,
            lb_opt=opt_val,
            max_iters=70000,
            eps=rel_err_tol,
        )
        end = time.time()
        wandb.finish()
        lobpcg_time_to_ds[_d_s] = end - start
        lobpcg_rel_err_to_ds[_d_s] = lobpcg_rel_errs
    rank_to_his[(max_tt_rank, max_sf_tucker_rank)] = (
        lobpcg_rel_err_to_ds,
        lobpcg_time_to_ds,
    )
    return rank_to_his


def compute_time_laplace_pot(
    laplace_mat: List[torch.Tensor],
    pot_mat: List[torch.Tensor],
    max_tt_rank: int,
    max_sf_tucker_rank: int,
    opt_val: float,
    step: int = 1,
):
    """Time LOBPCG on the problem with a potential for each shared-factor count.

    For every ``_d_s`` in ``range(1, N + 1, step)`` a random start tensor and a
    potential tensor (TT cores ``pot_mat``, identity Tucker factors) are built
    and ``lobpcg`` is run (up to 70000 iterations) inside its own wandb run.

    The returned dict used to be always empty because nothing was ever stored
    in it; it now holds the wall-clock duration of each ``lobpcg`` call. The
    relative errors are still logged to wandb per iteration by ``lobpcg``.

    NOTE(review): ``[torch.randn(...)] * N`` repeats one tensor, so every
    Tucker factor starts from the same random matrix. Confirm this is intended.

    Args:
        laplace_mat: operator cores; device and grid size are read from the first.
        pot_mat: TT cores of the potential.
        max_tt_rank: TT rank of the start tensor.
        max_sf_tucker_rank: Tucker rank of the start tensor.
        opt_val: reference optimum passed to ``lobpcg`` as ``lb_opt``.
        step: stride over the number of shared factors.

    Returns:
        Dict mapping each shared-factor count to the seconds spent in ``lobpcg``.
    """
    device = laplace_mat[0].device
    n = laplace_mat[0].shape[1]
    N = len(laplace_mat)
    rel_err_tol = 1e-6
    tt_cores_shapes = (
        [[1, max_sf_tucker_rank, max_tt_rank]]
        + [[max_tt_rank, max_sf_tucker_rank, max_tt_rank]] * (N - 2)
        + [[max_tt_rank, max_sf_tucker_rank, 1]]
    )
    tt_cores = [torch.randn(sh, device=device) for sh in tt_cores_shapes]
    tucker_factors = [torch.randn(n, max_sf_tucker_rank, device=device)] * N

    lobpcg_time_to_ds = {}
    for _d_s in range(1, N + 1, step):
        _d_t = N - _d_s
        t = TTSFTuckerTensor(
            tt_cores=[c.clone() for c in tt_cores],
            device=device,
            tucker_factors=[f.clone() for f in tucker_factors[:_d_t]]
            + [tucker_factors[-1].clone()],
            shared_factors_amount=_d_s,
        )
        logger.info(f"{ t.tt_ranks }, {t.tucker_ranks}")
        logger.info(f"{ t.params_amount }")
        # Grid size as stored in the constructed tensor's first Tucker factor.
        n = t.tucker_factors[0].shape[0]
        t /= t.norm()
        t.orthogonalize(-1)
        # Potential with the same sharing pattern and identity Tucker factors.
        henon_heiles_sfett = TTSFTuckerTensor(
            shared_factors_amount=_d_s,
            tt_cores=[c.clone() for c in pot_mat],
            tucker_factors=[torch.eye(n, device=device) for _ in range(_d_t + 1)],
        )
        henon_heiles_sfett.round()
        hh_laplace_bilinear = partial(
            laplace_quad,
            laplace_mat=[c.clone() for c in laplace_mat],
            potenital=henon_heiles_sfett,
        )

        wandb.init(
            project=...,
            name=f"laplace_hh_sh={_d_s}_r={max(max_tt_rank, max_sf_tucker_rank)}_n={n}",
            config={
                "shared_factors": _d_s,
                "rank": [max_tt_rank, max_sf_tucker_rank],
                "name": "laplace_hh_potenital",
                "n": n,
            },
        )
        # Time only the solver call, excluding wandb start-up and shutdown.
        start = time.time()
        lobpcg(
            hh_laplace_bilinear,
            t.clone(),
            rr_scipy,
            lb_opt=opt_val,
            max_iters=70000,
            eps=rel_err_tol,
        )
        end = time.time()
        wandb.finish()
        lobpcg_time_to_ds[_d_s] = end - start

    return lobpcg_time_to_ds

# Problem setup: seeds, grid, Laplacian and Henon-Heiles potential

In [None]:
# Seed every RNG used by the notebook (Python, NumPy, PyTorch) with one value.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Problem dimensions: number of TT cores (modes) and grid points per mode.
core_am, n = 8, 128
# TT rank and Tucker rank passed to the experiment drivers.
rank, tucker_rank = 2, 2
# Every mode uses the same uniform grid on [-A, A]; the list repeats one tensor.
A = 2
domain = core_am * [torch.linspace(-A, A, n, device=device)]

In [None]:
# Build the operator cores for `core_am` modes with `n` grid points each.
# alpha equals ((n + 1) / (2 * A)) ** 2 -- presumably the 1 / h**2 scaling for
# grid spacing h = 2 * A / (n + 1); TODO confirm against generate_laplace_TT.
laplace_TT = generate_laplace_TT(
    n=n, d=core_am, device=device, alpha=(n + 1) ** 2 / ((2 * A) ** 2)
)

In [None]:
# Approximate `henon_heiles` on the tensor grid `domain` with tntorch's cross
# approximation. rmax=3 is presumably the TT-rank cap and function_arg="matrix"
# the calling convention for `henon_heiles` -- see the tntorch docs to confirm.
henon_heiles_tn = tn.cross(
    function=henon_heiles,
    domain=domain,
    function_arg="matrix",
    device=device,
    rmax=3,
)
# Keep only the raw TT cores; they are passed to the experiment drivers as `pot_mat`.
henon_heiles_TT = henon_heiles_tn.cores
print(henon_heiles_tn)

In [None]:
# LaTeX descriptions of the objective for each problem name (raw strings, so
# backslashes reach the renderer unchanged). Presumably plot labels.
name_to_desc = {
    "Laplace": r"${\langle Lx, x \rangle}$",
    "HH": r"${\langle Lx, x \rangle + \langle x, V \odot x \rangle}$",
}
# LaTeX descriptions of the gradient/residual norm for each problem or method.
name_to_grad_desc = {
    "Laplace": r"$\|P_{T_x M}\left(\nabla_x\langle Lx, x \rangle\right)  - 2\langle Lx, x \rangle \cdot x \|$",
    "HH": r"$\|P_{T_x M}\left(\nabla_x[\langle Lx, x \rangle  + \langle x, V \odot x\rangle]\right)  - 2\langle Lx, x \rangle \cdot x \|$",
    "LOBPCG": r"$\|P_{T_x M}\left(R\right)\|$",
}
# Reference optimal values passed as `opt_val` to the experiment drivers.
# NOTE(review): the origin of these constants is not shown in this notebook.
name_to_opt = {"Laplace": 78.8974, "HH": 80.10747094}

# Run experiments

In [None]:
# Potential-free experiment over every shared-factor count (step=1).
# NOTE(review): compute_exp_laplace_no_pot overrides max_tt_rank and
# max_sf_tucker_rank with its internal rank list [(1, 1)], so the values passed
# here currently have no effect.
lob_his = compute_exp_laplace_no_pot(
    laplace_TT,
    max_tt_rank=rank,
    max_sf_tucker_rank=tucker_rank,
    opt_val=name_to_opt["Laplace"],
    step=1,
)

In [None]:
# Experiment with the Henon-Heiles potential over every shared-factor count (step=1).
# NOTE(review): compute_exp_laplace_pot overrides max_tt_rank and
# max_sf_tucker_rank with its internal rank list [(1, 1)], so the values passed
# here currently have no effect.
lob_his2 = compute_exp_laplace_pot(
    laplace_TT,
    henon_heiles_TT,
    max_tt_rank=rank,
    max_sf_tucker_rank=tucker_rank,
    opt_val=name_to_opt["HH"],
    step=1,
)

In [None]:
# Timing experiment with the Henon-Heiles potential for several grid sizes.
logger.info("=================NEW_EXP================")
logger.info(f"{datetime.now().timestamp()}")

# Reference optimal value per grid size n, used as `opt_val`.
# NOTE(review): the entries for 900 and 1000 are identical -- verify.
dicts_laplace_hh_pot = {
    100: 80.14891053,
    200: 80.15103274,
    500: 80.15082747,
    700: 80.15066734,
    800: 80.15061884,
    900: 80.15059055,
    1000: 80.15059055,
}
# NOTE(review): this loop rebinds the notebook-level names n, laplace_TT,
# domain, henon_heiles_tn and henon_heiles_TT set by earlier cells; re-run those
# cells before re-running the experiments above.
for n in [900, 100]:

    logger.info(n)
    # TT rank and Tucker rank of the random starting tensor.
    r_tt, r_sft = 5, 5
    # Unlike the setup cells, this uses the grid [0, 1] and alpha = (n + 1) ** 2.
    laplace_TT = generate_laplace_TT(n=n, d=core_am, device=device, alpha=(n + 1) ** 2)
    domain = [torch.linspace(0, 1, n, device=device)] * core_am
    henon_heiles_tn = tn.cross(
        function=henon_heiles,
        domain=domain,
        function_arg="matrix",
        device=device,
        eps=1e-14,
        verbose=False,
    )
    henon_heiles_TT = henon_heiles_tn.cores
    # `step` is left at the function's default of 1.
    res = compute_time_laplace_pot(
        laplace_TT,
        henon_heiles_TT,
        max_tt_rank=r_tt,
        max_sf_tucker_rank=r_sft,
        opt_val=dicts_laplace_hh_pot[n],
    )
    logger.info(res)

In [None]:
# Timing experiment without a potential for several grid sizes.
logger.info("=================NEW_EXP================")
logger.info(f"{datetime.now().timestamp()}")
# Closed-form reference value per grid size i in 100, 200, ..., 1000:
# core_am * 4 * (i + 1) ** 2 * sin(pi / (2 * (i + 1))) ** 2, used as `opt_val`.
dicts_laplace_no_pot = {
    i: 4 * np.sin(np.pi / (2 * (i + 1))) ** 2 * core_am * (i + 1) ** 2
    for i in range(100, 1100, 100)
}
print(dicts_laplace_no_pot)
# NOTE(review): this loop rebinds the notebook-level names n and laplace_TT set
# by earlier cells.
for n in [900, 500, 100]:
    logger.info(n)
    # TT rank and Tucker rank of the random starting tensor.
    r_tt, r_sft = 5, 5
    laplace_TT = generate_laplace_TT(n=n, d=core_am, device=device, alpha=(n + 1) ** 2)
    res = compute_time_laplace_no_pot(
        laplace_TT,
        max_tt_rank=r_tt,
        max_sf_tucker_rank=r_sft,
        opt_val=dicts_laplace_no_pot[n],
        step=1,
    )
    logger.info(res)