In [None]:
import os
# Log W&B runs locally only. Set before wandb is imported/initialised so every
# wandb.init in this notebook picks it up. Offline runs have to be uploaded
# with `wandb sync` before wandb.Api() (used by the plotting cells) sees them.
os.environ["WANDB_MODE"] = "offline"
import wandb
# Reload edited project modules (e.g. sfett) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
from typing import Optional, List, Tuple, Callable
import torch.backends.opt_einsum as opt_einsum
import matplotlib.pyplot as plt
from sfett.tensors import TTSFTuckerTensor
from IPython.display import clear_output
from tqdm.notebook import tqdm
from sfett.manifolds import TTSFTuckerManifold
from sfett.tensors import tt_sf_tucker_scalar_prod

from functools import partial
from tqdm.notebook import tqdm

import torch
import json
import tntorch as tn
import numpy as np
import logging
import pandas as pd
from logging import Logger
from datetime import datetime

# Double precision everywhere: the Hilbert cross approximation below targets
# eps=1e-14, far beyond float32 resolution.
torch.set_default_dtype(torch.float64)
# Assigning `strategy` raises ValueError when the opt_einsum package is not
# installed, so only request the greedy contraction-path search if available.
if opt_einsum.is_available():
    opt_einsum.strategy = "greedy"
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [None]:
# Per-session log file for the experiments; the timestamp keeps separate
# sessions apart.
LOG_DIR = ".logs"
# logging.FileHandler raises FileNotFoundError if the directory is missing.
os.makedirs(LOG_DIR, exist_ok=True)
logger = logging.getLogger(__name__)
# Filesystem-safe timestamp: str(datetime.now()) contains a space and ':'
# characters, and ':' is not allowed in Windows file names.
logger_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S_%f")
print(logger_name)
# force=True replaces root handlers left over from a previous execution of this
# cell; without it basicConfig is a silent no-op on re-run and keeps writing to
# the old file even though a new name was printed.
logging.basicConfig(
    filename=os.path.join(LOG_DIR, f"Hilbert_exp_{logger_name}.log"),
    level=logging.INFO,
    force=True,
)

In [None]:
def tune_frob_norm(
    target: TTSFTuckerTensor,
    X: TTSFTuckerTensor,
    max_iters: int = 300,
) -> TTSFTuckerTensor:
    """Refine a fixed-rank approximation of ``target`` by projected gradient descent.

    Minimises ``||target - X||^2`` while keeping the TT and SF-Tucker ranks of
    the initial ``X``. Each step:

    1. takes the Euclidean gradient ``-2 (target - X)``;
    2. maps it through ``TTSFTuckerManifold().grad`` of the linear functional
       ``<., g>`` at ``X`` (presumably the tangent-space projection);
    3. does an exact line search along it; and
    4. rounds back to the initial ranks.

    Iteration stops when the relative error stops decreasing (or becomes NaN),
    when the projected gradient vanishes, or after ``max_iters`` steps. The
    relative-error and gradient-norm curves are plotted on exit.

    Args:
        target: Tensor being approximated.
        X: Initial approximation; its ranks are kept fixed.
        max_iters: Maximum number of descent steps.

    Returns:
        The last accepted iterate, i.e. the one with the lowest relative error
        seen. This is ``X`` itself if no step improved on it.
    """
    t_norm = target.norm()

    def rel_err(_X):
        return (target - _X).norm() / t_norm

    tt_ranks = X.tt_ranks
    tucker_ranks = X.tucker_ranks
    iters = tqdm(range(max_iters))
    fun_log = [rel_err(X).cpu()]  # relative errors of accepted iterates
    grad_log = []  # projected-gradient norms
    best_X = X
    for _ in iters:
        R = target - best_X
        # Euclidean gradient of ||target - X||^2 with respect to X.
        g = -2 * R
        projector = TTSFTuckerManifold().grad(partial(tt_sf_tucker_scalar_prod, B=g))
        g = projector(best_X).construct()
        g_norm = g.norm()
        grad_log.append(g_norm.cpu())
        if g_norm == 0:
            # Stationary point: the line search below would divide by zero and
            # turn every later iterate into NaN.
            break
        # Exact line search for X - lr * g on ||R + lr * g||^2:
        # lr = -<g, R> / ||g||^2. Both factors are pre-scaled by 1/||g||.
        lr = -1 * tt_sf_tucker_scalar_prod(g / g_norm, R / g_norm)
        new_X = best_X - lr * g
        # round() is used in place here (its return value is discarded).
        new_X.round(tt_ranks=tt_ranks, sf_tucker_ranks=tucker_ranks)
        new_err = rel_err(new_X).cpu()
        iters.set_description(f"{lr}" + f" {fun_log[-1]}" + f" {grad_log[-1]:.2f}")
        # Written as `not <=` so a NaN error also stops the descent; a plain
        # `<` test is False for NaN and would accept the NaN iterate.
        if not (new_err <= fun_log[-1]):
            break
        fun_log.append(new_err)
        best_X = new_X
    clear_output()
    plt.plot(fun_log, label="obj")
    plt.yscale("log")
    plt.legend()
    plt.show()
    plt.plot(grad_log, label="grad_norm")
    plt.yscale("log")
    plt.legend()
    plt.show()
    return best_X


def compute_exp(
    ground_truth_tt: List[torch.Tensor],
    max_ranks: List[Tuple],
    logger: Logger,
    tuner: Optional[Callable] = None,
    step: int = 2,
    name: str = "Hilbert",
) -> Tuple[dict, dict]:
    """Sweep rank budgets for several numbers of shared factors.

    For every ``d_s`` in ``range(1, N + 1, step)``, where ``N`` is the number of
    TT cores, the procedure is:

    - wrap the TT cores as a ``TTSFTuckerTensor`` with identity Tucker factors
      and ``d_s`` shared factors;
    - round it to each ``(sf_tucker_rank, tt_rank)`` pair of ``max_ranks``,
      optionally refining with ``tuner``;
    - log the relative error and parameter count to ``logger`` and to one
      W&B run per ``d_s``.

    Args:
        ground_truth_tt: TT cores of the reference tensor. ``core.shape[1]`` is
            used as the mode size.
        max_ranks: ``(sf_tucker_ranks, tt_ranks)`` pairs passed to ``round``.
        logger: Destination for progress messages.
        tuner: Optional callable ``tuner(target=..., X=...)`` that returns a
            refined approximation (e.g. ``tune_frob_norm``).
        step: Stride over the number of shared factors.
        name: Experiment name stored in the W&B run name and config.

    Returns:
        ``(errs_to_ds, params_am_to_ds)``. Both dicts are keyed by ``d_s`` and
        hold, in ``max_ranks`` order, the relative errors (CPU tensors) and the
        parameter counts.
    """

    def rel_err(target, X):
        return ((target - X).norm() / target.norm()).cpu()

    logger.info("=========EXP_STARTED===========")
    errs_to_ds = {}
    params_am_to_ds = {}
    N = len(ground_truth_tt)
    shape = np.asarray([core.shape[1] for core in ground_truth_tt])
    for _d_s in tqdm(range(1, N + 1, step)):
        _d_t = N - _d_s
        # Identity factors for the d_t leading modes, plus a single factor for
        # the d_s trailing (shared) modes, sized by the first of them.
        tucker_factors = [torch.eye(n=i, device=device) for i in shape[:_d_t]]
        if _d_s:
            tucker_factors.append(torch.eye(shape[_d_t], device=device))
        ground_truth_sfett = TTSFTuckerTensor(
            shared_factors_amount=_d_s,
            tt_cores=ground_truth_tt,
            tucker_factors=tucker_factors,
        )
        errs_to_ds[_d_s] = []
        params_am_to_ds[_d_s] = []
        run = wandb.init(
            project=...,  # placeholder: set to the target W&B project
            # `tune=` must report whether a tuner is used (matches config["tuner"]).
            name=name + f"_tune={tuner is not None}_ds={_d_s}",
            config={"shared_factors": _d_s, "tuner": (tuner is not None), "name": name},
        )
        try:
            for max_r_t, max_r_tt in max_ranks:
                ground_truth_sfett_rounded = ground_truth_sfett.clone()
                ground_truth_sfett_rounded.round(
                    sf_tucker_ranks=max_r_t, tt_ranks=max_r_tt
                )

                logger.info(f"====={max_r_t, max_r_tt,_d_s}=====")
                logger.info(f"tt:{ground_truth_sfett_rounded.tt_ranks}")
                logger.info(f"t:{ground_truth_sfett_rounded.tucker_ranks}")
                if tuner is not None:
                    logger.info(
                        f"before_tune:{rel_err(ground_truth_sfett, ground_truth_sfett_rounded)}"
                    )
                    ground_truth_sfett_rounded = tuner(
                        target=ground_truth_sfett, X=ground_truth_sfett_rounded
                    )
                    logger.info(
                        f"after_tune:{rel_err(ground_truth_sfett, ground_truth_sfett_rounded)}"
                    )

                err = rel_err(ground_truth_sfett, ground_truth_sfett_rounded)
                params_am = ground_truth_sfett_rounded.params_amount
                errs_to_ds[_d_s].append(err)
                params_am_to_ds[_d_s].append(params_am)
                run.log({"params_am": params_am, "rel_err": err.item()})
        finally:
            # One run per d_s: close it here, even on error, so the next
            # wandb.init starts a fresh run instead of relying on notebook-only
            # re-init behaviour.
            run.finish()
    return errs_to_ds, params_am_to_ds


def Hilbert_fun(x, y, z, t, w, a, b, c, x1, x2, x3, x4):
    """Hilbert-like kernel ``1 / (1 + 2*x + 3*y + ... + 13*x4)``.

    The arguments are vectors (one per mode) and are combined elementwise.
    The k-th argument (0-based) is weighted by ``k + 2``. Terms are accumulated
    left to right, starting from 1.
    """
    coords = (x, y, z, t, w, a, b, c, x1, x2, x3, x4)
    weighted_terms = (weight * coord for weight, coord in enumerate(coords, 2))
    denominator = sum(weighted_terms, 1)
    return 1 / denominator


def create_exponent(n=32, alpha=-1):
    """Build ``f(*digits) = exp(alpha * x**2)`` with ``x = sum_i digits[i] / n**(i + 1)``.

    The digits are read as a base-``n`` fraction, most significant first. With
    integer digits in ``[0, n)``, ``x`` therefore lies in ``[0, 1)``.
    """

    def my_exp(*args):
        # The first digit allocates a fresh tensor, so the in-place
        # accumulation below never touches the caller's inputs.
        x = args[0] / n
        for power, digit in enumerate(args[1:], start=2):
            x += digit / n**power
        return torch.exp(alpha * x**2)

    return my_exp

In [None]:
# exp(alpha * x**2) sampled on a base-n digit grid (see create_exponent) and
# compressed to TT format by cross approximation (rmax caps the rank).
l = 10  # number of base-n digits = tensor order
n = 32  # values per digit = mode size
alpha = -0.1
# torch.range is deprecated. arange(n) gives the same 0..n-1 grid; the dtype is
# set explicitly because arange with integer bounds would produce int64.
domain = [torch.arange(n, dtype=torch.get_default_dtype(), device=device)] * l
my_exp_tn = tn.cross(
    function=create_exponent(alpha=alpha, n=n),
    domain=domain,
    device=device,
    rmax=4,
)
my_exp_tn

In [None]:
# 12-way Hilbert-like tensor (see Hilbert_fun) on a 512-point uniform grid over
# [0, 1] per mode, built by TT cross approximation.
l = 12
grid = torch.linspace(0, 1, 512, device=device)
domain = [grid for _ in range(l)]
hilbert_like_tensor = tn.cross(
    function=Hilbert_fun,
    domain=domain,
    device=device,
    eps=1e-14,
)
hilbert_like_tensor, hilbert_like_tensor.norm()

In [None]:
# Rank budgets to sweep: each SF-Tucker rank r in [min_rank, max_rank) is
# paired with TT rank r and with TT rank r + 1.
min_rank = 1
max_rank = 16
max_ranks = [
    (tucker_rank, tt_rank)
    for tucker_rank in range(min_rank, max_rank)
    for tt_rank in (tucker_rank, tucker_rank + 1)
]
print(max_ranks, len(max_ranks))

In [None]:
# Hilbert-like tensor: sweep every other number of shared factors, no tuning.
errs_to_ds, params_am_to_ds = compute_exp(
    hilbert_like_tensor.clone().cores,
    max_ranks,
    logger,
    step=2,
    tuner=None,
    name="Hilbert",
)

In [None]:
# Reload the "Hilbert" sweep from W&B and plot relative error vs. parameter count.
# NOTE: runs were logged with WANDB_MODE=offline, so they appear in the API only
# after being uploaded with `wandb sync`.
api = wandb.Api()
runs = ...  # placeholder, e.g. api.runs("<entity>/<project>")

curves = []
for run in runs:
    config = json.loads(run.json_config)
    if "name" not in config or config["name"]["value"] != "Hilbert":
        continue
    curves.append((config["shared_factors"]["value"], run.history(pandas=True)))

fig, ax = plt.subplots()
# Sort by d_s so the legend order does not depend on the API's run order.
for _d_s, hist in sorted(curves, key=lambda curve: curve[0]):
    ax.plot(
        hist["params_am"],
        hist["rel_err"],
        label=f"{('no' if _d_s == 1 or _d_s == 0 else _d_s)} shared factors",
    )

ax.set_yscale("log")
ax.grid()
ax.set_title("Hilbert-like tensor: relative error vs. parameters (W&B)")
ax.set_ylabel("relative l2 error")
ax.set_xlabel("params amount")
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
plt.show()

In [None]:
# Relative error vs. parameter count for the in-memory Hilbert sweep.
fig, ax = plt.subplots()

for _d_s, errs in errs_to_ds.items():
    ax.plot(
        params_am_to_ds[_d_s],
        # compute_exp stores the errors as CPU tensors; plot plain floats.
        [float(err) for err in errs],
        label=f"{('no' if _d_s == 1 or _d_s == 0 else _d_s)} shared factors",
    )

ax.set_yscale("log")
ax.grid()
ax.set_title("Hilbert-like tensor: relative error vs. parameters")
ax.set_ylabel("relative l2 error")
ax.set_xlabel("params amount")
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
plt.show()

In [None]:
# Exponential tensor: same sweep as for the Hilbert-like tensor, no tuning.
exp_errs_to_ds, exp_params_am_to_ds = compute_exp(
    my_exp_tn.clone().cores,
    max_ranks,
    logger,
    step=2,
    tuner=None,
    name="exp",
)

In [None]:
# Reload the "exp" sweep from W&B and plot relative error vs. parameter count.
# NOTE: runs were logged with WANDB_MODE=offline, so they appear in the API only
# after being uploaded with `wandb sync`.
api = wandb.Api()
runs = ...  # placeholder, e.g. api.runs("<entity>/<project>")

curves = []
for run in runs:
    config = json.loads(run.json_config)
    if "name" not in config or config["name"]["value"] != "exp":
        continue
    curves.append((config["shared_factors"]["value"], run.history(pandas=True)))

fig, ax = plt.subplots()
# Sort by d_s so the legend order does not depend on the API's run order.
for _d_s, hist in sorted(curves, key=lambda curve: curve[0]):
    ax.plot(
        hist["params_am"],
        hist["rel_err"],
        label=f"{('no' if _d_s == 1 or _d_s == 0 else _d_s)} shared factors",
    )

ax.set_yscale("log")
ax.grid()
ax.set_title("Exponential tensor: relative error vs. parameters (W&B)")
ax.set_ylabel("relative l2 error")
ax.set_xlabel("params amount")
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
plt.show()

In [None]:
# Relative error vs. parameter count for the in-memory exponential sweep.
fig, ax = plt.subplots()

for _d_s, errs in exp_errs_to_ds.items():
    ax.plot(
        exp_params_am_to_ds[_d_s],
        # compute_exp stores the errors as CPU tensors; plot plain floats.
        [float(err) for err in errs],
        # Same labelling as the other plots: d_s <= 1 means no sharing.
        label=f"{('no' if _d_s == 1 or _d_s == 0 else _d_s)} shared factors",
    )

ax.set_yscale("log")
ax.grid()
# create_exponent evaluates exp(alpha * x**2) with x a base-n fraction in [0, 1).
ax.set_title(
    r"Low rank approximation of $\exp(\alpha x^2)$, $x \in [0, 1)$, "
    + rf"$\alpha = {alpha}$"
)
ax.set_ylabel("relative l2 error")
ax.set_xlabel("params amount")
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
plt.show()