In [None]:
import torch
import torch
from torch import Tensor
import math
import sys
sys.path.append('../')  
import numpy as np
import plotly.graph_objects as go
from dataclasses import dataclass
import einops
from plotly.subplots import make_subplots
from model import MLP_ELU_convex
import hydra
import plotly.figure_factory as ff
from metrics import RiemannianMetric, DiagonalRiemannianMetric, Method2Metric, Method3Metric
from utils.toy_dataset import GaussianMixture
from train_EBM_geodesic import get_metrics_dict


This notebook demonstrates how to plot metric-tensor quantities over a 2D domain. The quantities shown are the per-point norm of each metric tensor and the magnitude of its eigenvalues.

In [None]:

    # %%
# Prefer the second GPU (the original setting), but fall back to the first GPU or the CPU
# so the notebook still runs on machines that do not have a "cuda:1" device.
if torch.cuda.device_count() > 1:
    DEVICE: str = "cuda:1"
elif torch.cuda.is_available():
    DEVICE: str = "cuda:0"
else:
    DEVICE: str = "cpu"

NB_GAUSSIANS = 200
RADIUS = 8

# Component centres: NB_GAUSSIANS angles evenly spaced in [0, 180) degrees (endpoint dropped),
# converted to radians and placed on a circle of radius RADIUS.
mean_ = (torch.linspace(0, 180, NB_GAUSSIANS + 1)[0:-1] * math.pi / 180)
MEAN = RADIUS * torch.stack([torch.cos(mean_), torch.sin(mean_)], dim=1)
# One 2x2 identity covariance per component.
COVAR = torch.tensor([[1., 0], [0, 1.]]).unsqueeze(0).repeat(len(MEAN), 1, 1)

# Evaluation grid: 100 x-values in [-10, 10] and 62 y-values in [-2.5, 10], flattened to (N, 2).
x_p, y_p = torch.meshgrid(torch.linspace(-10, 10, 100), torch.linspace(-2.5, 10, 62), indexing='xy')
pos = torch.cat([x_p.flatten().unsqueeze(1), y_p.flatten().unsqueeze(1)], dim=1).to(DEVICE)

## Defining the mixture
weight_1 = (torch.ones(NB_GAUSSIANS) / NB_GAUSSIANS)  # uniform component weights
# A single .to(DEVICE) is enough; the original moved the mixture twice.
mixture_1 = GaussianMixture(center_data=MEAN, covar=COVAR, weight=weight_1).to(DEVICE)

## compute the energy landscape
energy_landscape_1: Tensor = mixture_1.energy(pos)
print(f"{energy_landscape_1.shape=}")

In [None]:
# %%
num_samples: int = 1000
# NOTE(review): sample_dataset is not read in the visible cells; it is kept so that the sequence
# of sampling calls stays the same as before.
sample_dataset = mixture_1.sample(num_samples).to(DEVICE)
# NOTE(review): unlike sample_dataset this is not moved to DEVICE explicitly; presumably
# get_metrics_dict handles placement - confirm.
reference_samples = mixture_1.sample(num_samples)

## ebm-based metric
# SECURITY: weights_only=False unpickles arbitrary Python objects (needed here because the
# checkpoint stores the model class under 'type'). Only load checkpoints from a trusted source.
# map_location places the stored tensors on DEVICE regardless of the device they were saved from.
loaded: dict = torch.load(
    "../tutorial/EBM_mixture1.pth",
    map_location=DEVICE,
    weights_only=False,
)

# Rebuild the model from the stored class and restore its parameters.
ebm = loaded['type']()
ebm.load_state_dict(loaded['weight'])
ebm.to(DEVICE)

# Hydra config for the extra metric; "_target_" doubles as its key in metric_dict.
target_metric_cfg: dict = {
    "_target_": "train_EBM_geodesic.Method2Metric",
    "a_num": 20.0,
    "b_num": 0.01,
    "eps": 1e-6,
    "eta": 0,
    "mu": 1.0,
    "alpha_fn_choice": "linear",
}

metric_dict: dict[str, RiemannianMetric] = get_metrics_dict(
    mixture_1,
    pos,
    ebm,
    reference_samples,
    DEVICE,
)

metric_dict[target_metric_cfg["_target_"]] = hydra.utils.instantiate(target_metric_cfg, ebm=ebm)

In [None]:
# Display the names of all metrics that will be evaluated below.
metric_dict.keys()


In [None]:

# Turn on gradient tracking for the grid points (in place) before evaluating the metrics.
pos.requires_grad_(True)

# Evaluate every metric's g(...) on the grid, keyed by metric name.
metric_gouts: dict[str, Tensor] = {}
for metric_str in metric_dict:
    metric_fn = metric_dict[metric_str]
    print(f"{metric_str}, {metric_fn}")
    metric_gouts[metric_str] = metric_fn.g(x_t=pos)

In [None]:
# m2_metric_is_diagonal: bool = torch.all(metric_gouts["train_EBM_geodesic.Method2Metric"][:, 1, 0] == 0) and torch.all(metric_gouts["train_EBM_geodesic.Method2Metric"][:, 0, 1] == 0)
# m2_metric_is_diagonal

In [None]:
# Inspect the raw output of the Method2Metric on the grid.
metric_gouts["train_EBM_geodesic.Method2Metric"]

In [None]:
# Report the output shape of every evaluated metric.
for metric_name in metric_gouts:
    print(f"Metric: {metric_name}, Tensor shape: {metric_gouts[metric_name].shape}")


In [None]:
@dataclass
class AnimationData:
    """Display settings attached to one metric: a colour, a size, a symbol and a LaTeX label."""
    color_tuple: tuple  # colour components as a tuple
    size: int  # presumably the marker size - confirm against the plotting code that consumes it
    symbol: str  # presumably the marker symbol - confirm against the plotting code that consumes it
    latex_label: str  # LaTeX string used to label the metric

# Colour triples for the two custom metrics, kept both as tensors and as plain int tuples.
normed_m2_color: Tensor = torch.tensor((147, 112, 219))
normed_m3_color: Tensor = torch.tensor((128, 0, 128))
m2_color: tuple = tuple(map(int, normed_m2_color))
m3_color: tuple = tuple(map(int, normed_m3_color))

In [None]:
# Display settings per metric name, built from (name, colour, symbol, LaTeX label) rows.
# Every entry uses the same size of 150; insertion order matches the row order below.
dico_color: dict[str, AnimationData] = {
    name: AnimationData(color_tuple=rgb, size=150, symbol=marker, latex_label=label)
    for name, rgb, marker, label in (
        ("conf_ebm_logp", (66, 133, 244), '.', r"$\mathbf{G}_{E_{\theta}}$"),
        ("conf_ebm_invp", (52, 168, 83), '.', r"$\mathbf{G}_{1/p_{\theta}}$"),
        ("diag_rbf_invp", (234, 67, 53), '.', r"$\mathbf{G}_{RBF}$"),
        ("diag_land_invp", (251, 188, 5), '.', r"$\mathbf{G}_{LAND}$"),
        ("conf_true_logp", (0, 0, 0), '+', r"$\mathbf{G}_{E_{\mathcal{M}}}$"),
        ("conf_true_invp", (0, 0, 0), 'x', r"$\mathbf{G}_{1/p_{\mathcal{M}}}$"),
        ("train_EBM_geodesic.Method2Metric", m2_color, '.', r"$\mathbf{G}_{M2}$"),
        ("train_EBM_geodesic.Method3Metric", m3_color, '.', r"$\mathbf{G}_{M3}$"),
    )
}

In [None]:
# Subplot titles are consumed in row-major order, and the heatmap traces are added in the
# iteration order of metric_gouts. The titles therefore have to be built from metric_gouts
# (not from dico_color, which can hold more or differently ordered entries), otherwise the
# titles end up on the wrong subplots.
metric_labels: list[str] = [
    # Fall back to the raw metric name when no display settings are registered for it.
    dico_color[metric].latex_label if metric in dico_color else metric
    for metric in metric_gouts.keys()
]

raw_value_titles: list[str] = [f"{label} Norm" for label in metric_labels]
eigvalue_titles: list[str] = [f"Metric: {label} - Eigenvalues" for label in metric_labels]
# Row 1 titles first, then row 2 titles.
all_subplot_titles: list[str] = raw_value_titles + eigvalue_titles

fig: go.Figure = make_subplots(rows=2, cols=len(metric_gouts), subplot_titles=all_subplot_titles)


In [None]:
# Flattened grid coordinates shared by every heatmap (NumPy arrays, not tensors).
assert pos.dim() == 2 and pos.shape[1] == 2, f"Error! Expected pos of shape (N, 2), but found {tuple(pos.shape)}"
x_tnsr_detach: np.ndarray = pos[:, 0].detach().cpu().numpy()
y_tnsr_detach: np.ndarray = pos[:, 1].detach().cpu().numpy()
assert x_tnsr_detach.shape == y_tnsr_detach.shape, "Error! x and y grid coordinates differ in length"
B: int = int(x_tnsr_detach.shape[0])


def _frobenius_norm(tnsr: Tensor) -> Tensor:
    """Reduce a (B, k) or (B, i, j) tensor to one norm per grid point, taken over all non-batch dims."""
    if tnsr.dim() == 2:
        return torch.linalg.norm(tnsr, dim=1)
    return torch.linalg.norm(tnsr, dim=(1, 2))


def _eigen_summary(tnsr: Tensor) -> Tensor:
    """Reduce a (B, k) or (B, i, j) tensor to one scalar per grid point.

    2-D input: mean over the second dimension.
    3-D input: magnitude of the mean eigenvalue of each matrix.
    """
    if tnsr.dim() == 2:
        return tnsr.mean(dim=1)
    # torch.linalg.eig returns (eigenvalues, eigenvectors); only the eigenvalues are needed.
    eigvals, _ = torch.linalg.eig(tnsr)
    print(f"{eigvals.shape=}")
    return torch.abs(eigvals.mean(dim=1))


def _add_heatmap_row(row: int, reducer) -> None:
    """Add one heatmap per metric to the given subplot row of `fig`.

    `reducer` maps a 2-D or 3-D metric output to a (B,) tensor. Outputs that already have
    shape (B,) are plotted unchanged. Any other shape (including 1-D tensors of the wrong
    length) raises ValueError, as in the original cell.
    """
    for col, (metric_str, tnsr) in enumerate(metric_gouts.items(), start=1):
        print(f"{metric_str=}")
        print(f"{tnsr.shape=}")

        if tnsr.shape == (B,):
            z_tnsr: Tensor = tnsr
        elif tnsr.dim() in (2, 3):
            z_tnsr = reducer(tnsr)
        else:
            raise ValueError(f"Warning! Found anomalous shape of {tnsr.shape}")

        assert z_tnsr.shape == (B,), f"Error! Expected tensor of shape {(B,)}, but found one of shape {z_tnsr.shape}"

        fig.add_trace(
            go.Heatmap(
                x=x_tnsr_detach,
                y=y_tnsr_detach,
                z=z_tnsr.detach().cpu().numpy(),
            ),
            row=row,
            col=col,
        )


# Row 1: per-point norm of each metric output.
_add_heatmap_row(1, _frobenius_norm)
# Row 2: per-point eigenvalue aggregate of each metric output.
_add_heatmap_row(2, _eigen_summary)

fig.show()
fig.write_html("./all_metrics.html", include_mathjax="cdn")

In [None]:
# Hydra config for a stand-alone Method3Metric used in the step-by-step checks below.
target_metric_cfg: dict = {
    "_target_": "train_EBM_geodesic.Method3Metric",
    "eps": 1e-6,
    "eta": 5,
    "mu": 1.0,
    "beta": 0.004,
    "alpha_fn_choice": "linear",
}

# The config targets Method3Metric, so the annotation names that class (it previously said Method2Metric).
toy_metric: Method3Metric = hydra.utils.instantiate(target_metric_cfg, ebm=ebm)

In [None]:
# Evaluate the metric's score and energy on the grid points.
# NOTE(review): the cells below treat met_score as a (batch, dim) tensor (see the einsum 'bi,bj->bij') - confirm against get_score_n_nrg.
met_score, met_energy =  toy_metric.get_score_n_nrg(pos)

In [None]:
# print(f"{score_on_pos.shape=}")
# Per-point outer product of the score with itself: one (i, j) matrix per grid point.
grad_outer_prod: Tensor = torch.einsum('bi,bj->bij', met_score, met_score)

# torch.linalg.eig returns (eigenvalues, eigenvectors), both complex-valued.
eigvals, eigvecs = torch.linalg.eig(grad_outer_prod)
print(f"{eigvals.shape=}")
print(f"{met_score.shape=}")

# An outer product s s^T has the single non-zero eigenvalue |s|^2, so the squared score norm
# is printed next to the eigenvalues for a visual comparison.
score_sq_norm: Tensor = (met_score ** 2).sum(dim=1)
print(f"{score_sq_norm[:10]=}")
print(f"{eigvals[:10]=}")

assert torch.all(eigvals.imag.abs() < 1e-6), "Complex eigenvalues found in grad_outer_prod"
# Zero eigenvalues are expected here; only values below the small negative tolerance are an error.
assert torch.all(eigvals.real >= -1.0e-5), f"Negative eigenvalues found in grad_outer_prod with min {eigvals.real.min().item()}. {eigvals=}"

# print(f"{grad_outer_prod.shape=}")
# Energy range over the grid, used to set the parameters of the alpha function.
min_en: float = met_energy.real.min().item()
print(f"{min_en=}")
max_en: float = met_energy.real.max().item()
print(f"{max_en=}")

mult_factor: float = 10.0
# eps keeps the division finite when the energy is constant over the grid.
alpha_a: float = mult_factor * (1 / (max_en - min_en + toy_metric.eps))
alpha_b: float = min_en
alpha_lb: float = 5.0

toy_metric.alpha_a = alpha_a
toy_metric.alpha_b = alpha_b
toy_metric.alpha_lb = alpha_lb

# alpha is evaluated once, after the parameters above are set. The original also evaluated it
# before setting them and immediately discarded that result.
alpha_val: Tensor = toy_metric.alpha_fn(met_energy).to(pos.device)

assert torch.min(alpha_val).item() >= 0, f"Negative alpha values found with min {alpha_val.min().item()}. {alpha_val=}"

print(f"{alpha_val.max().item()=}")
print(f"{alpha_val.min().item()=}")

# (b, 1) -> (b, 1, 1) so alpha broadcasts against the per-point (b, i, j) matrices.
alpha_val = einops.rearrange(alpha_val, 'b 1 -> b 1 1')

# print(f"{alpha_val.shape=}")

# give it a 'batch' dimension to add with grad_outer_prod
I_mat: Tensor = einops.rearrange(torch.eye(2).to(pos.device), 'i j -> 1 i j')

print(f"{toy_metric.mu=}")
print(f"{toy_metric.eta=}")
# A(x) = alpha(x) * (mu * I + eta * s s^T), one matrix per grid point.
A_pre: Tensor = (toy_metric.mu * I_mat + toy_metric.eta * grad_outer_prod)
A_mat: Tensor = alpha_val * A_pre
print(f"{A_mat.shape=}")

# torch.linalg.eig returns (eigenvalues, eigenvectors), both complex-valued.
eigvals, eigvecs = torch.linalg.eig(A_mat)
print(f"{eigvecs.shape=}")

# The checks are about the eigenvalues, so test eigvals (the original tested eigvecs here).
assert torch.all(eigvals.imag.abs() < 1e-6), "Complex eigenvalues found in A_mat"
assert torch.all(eigvals.real >= 0), f"Negative eigenvalues found in A_mat with min {eigvals.real.min().item()}. {eigvals=}"

print(f"Min eigenvalue of A_mat: {eigvals.real.min().item()}")
print(f"Max eigenvalue of A_mat: {eigvals.real.max().item()}")

# Same sanity checks on the explicit inverse of A(x).
A_inv: Tensor = torch.linalg.inv(A_mat)

eigvals_inv, eigvecs_inv = torch.linalg.eig(A_inv)
assert torch.all(eigvals_inv.imag.abs() < 1e-6), "Complex eigenvalues found in A_inv"
assert torch.all(eigvals_inv.real >= 0), f"Negative eigenvalues found in A_inv with min {eigvals_inv.real.min().item()}. {eigvals_inv=}"


# Method 3 Stuff

# Quadratic form s^T A(x) s of the score s under A(x), one value per grid point.
inner_prods_pre: Tensor = torch.einsum('bij, bj -> bi', A_mat, met_score)
inner_prods: Tensor = torch.einsum('bi,bi->b', inner_prods_pre, met_score)

# Explicit inverse of A(x). A closed form via the Sherman-Morrison formula would also work,
# but the direct inverse is kept because it is known to behave correctly here.
inv_met: Tensor = torch.linalg.inv(A_mat)

# Check for NaNs before anything consumes the inverse.
print(f"{torch.isnan(inv_met).any()=}")
assert not torch.isnan(inv_met).any(), "NaN values found in inv_met, which could lead to instability. Consider increasing eps or checking the energy landscape for very low values."

# torch.linalg.eig returns (eigenvalues, eigenvectors) in that order. The original unpacked
# them swapped, which overwrote `eigvecs` with eigenvalues.
eigvals_inv_met, eigvecs_inv_met = torch.linalg.eig(inv_met)

# Quadratic form s^T A(x)^{-1} s, one value per grid point.
inv_inner_prods_pre: Tensor = torch.einsum('bij,bj->bi', inv_met, met_score)
inv_inner_prod: Tensor = torch.einsum('bi,bi->b', inv_inner_prods_pre, met_score)

# NaN checks come first: a NaN also fails a "> 0" comparison and would otherwise be
# reported with the wrong message.
assert not torch.isnan(inner_prods).any(), "NaN values found in inner_prods, which could lead to instability. Consider increasing eps or checking the energy landscape for very low values."
assert not torch.isnan(inv_inner_prod).any(), "NaN values found in inv_inner_prod, which could lead to instability. Consider increasing eps or checking the energy landscape for very low values."

assert torch.all(inner_prods > 0), "Non-positive values found in inner_prods, which could lead to instability. Consider increasing eps or checking the energy landscape for very low values."
assert torch.all(inv_inner_prod > 0), "Non-positive values found in inv_inner_prod, which could lead to instability. Consider increasing eps or checking the energy landscape for very low values."

print(f"{torch.isnan(inv_inner_prod).any()=}")
print(f"{torch.linalg.norm(inv_inner_prod)=}")

# Per-point bound that beta is compared against.
beta_comp: Tensor = torch.sqrt(inv_inner_prod / inner_prods)
beta_bound: float = beta_comp.min().item()

print(f"beta should be less than {beta_bound} to fulfill the condition for a Randers metric")
if toy_metric.beta >= beta_bound:
    print("Warning! beta is greater than the minimum value required for a Randers metric, which could lead to instability. Consider reducing beta or checking the energy landscape for very low values.")

print(f"{inv_inner_prod.max().item()=}")
print(f"{inner_prods.max().item()=}")
print(f"{beta_comp.max().item()=}")

# eps guards the square root / division against values close to zero.
one_ov_sq: Tensor = 1 / torch.sqrt(inv_inner_prod + toy_metric.eps)
assert not torch.any(torch.isnan(one_ov_sq)), "NaN values found in one_ov_sq, which could lead to instability. Consider increasing eps or checking the energy landscape for very low values."

# if beta is less than the above value, we should not run into the assert below...

# parse_shape raises if the number of dimensions does not match the pattern, so these
# calls act as rank checks; their return values are intentionally unused.
einops.parse_shape(inv_met, 'b i j')
einops.parse_shape(met_score, 'b i')
einops.parse_shape(one_ov_sq, 'b')

# One-form: the score scaled by beta / sqrt(s^T A^{-1} s + eps) at every grid point.
one_form: Tensor = toy_metric.beta * one_ov_sq[:, None] * met_score
einops.parse_shape(one_form, 'b i')

# Quadratic form of the one-form under A(x). No square root is taken, so this is the squared
# norm; comparing it with 1 is equivalent to comparing the norm with 1 because it is non-negative.
pre_out: Tensor = torch.einsum('bij,bj->bi', A_mat, one_form)
norm_one_form: Tensor = torch.einsum('bi,bi->b', pre_out, one_form)
assert torch.all(norm_one_form < 1.0), f"The norm of the one-form under A(x) should be less than 1 for stability reasons. Got an average of {norm_one_form.mean().item()} and a max of {norm_one_form.max().item()}"

print(f"{norm_one_form.max().item()=}")


# Visualizing the 1-form

In [None]:
# Hydra config for the Method3Metric whose one-form is visualised below.
target_metric_cfg: dict = dict(
    _target_="train_EBM_geodesic.Method3Metric",
    a_num=20.0,
    b_num=0.01,
    eps=1e-6,
    eta=0.1,
    mu=1.0,
    beta=0.1,
    alpha_fn_choice="linear",
)

method3_metric: Method3Metric = hydra.utils.instantiate(target_metric_cfg, ebm=ebm)

# Coarser N_samples x N_samples grid over the same x/y ranges, flattened to (N_samples**2, 2).
N_samples: int = 50
x_p_subs, y_p_subs = torch.meshgrid(
    torch.linspace(-10, 10, N_samples),
    torch.linspace(-2.5, 10, N_samples),
    indexing='xy',
)
pos_subs = torch.stack([x_p_subs.flatten(), y_p_subs.flatten()], dim=1).to(DEVICE)

# Gradient tracking is enabled in place before evaluating the one-form on the grid.
pos_subs.requires_grad_(True)
one_form_all: Tensor = method3_metric.one_form(pos_subs)
print(f"{one_form_all.shape=}")

In [None]:
# Visualizing the one-form along with the landscape of A(x) for the Method3Metric

In [None]:
# Fresh empty figure. It is not used by the quiver plot in this cell and replaces the earlier `fig`.
fig: go.Figure = go.Figure()

# Detached NumPy copies of the one-form and of the grid coordinates it was evaluated on.
one_form_dtch: np.ndarray = one_form_all.detach().cpu().numpy()
x_tnsr_subs: np.ndarray = pos_subs[:, 0].detach().cpu().numpy()
y_tnsr_subs: np.ndarray = pos_subs[:, 1].detach().cpu().numpy()

# Arrow-length scaling passed to create_quiver.
scale_factor: float = 5.0
fig_quiver = ff.create_quiver(
    x=x_tnsr_subs,
    y=y_tnsr_subs,
    u=one_form_dtch[:, 0],
    v=one_form_dtch[:, 1],
    scale=scale_factor,
)

# Equal aspect ratio: anchoring y to x with ratio 1 is sufficient. The original also anchored
# x to y, which forms a circular constraint that Plotly treats as redundant.
fig_quiver.update_layout(
    title=f"One-form Visualization for Method3Metric (Scale: {scale_factor})",
    xaxis_title="x",
    yaxis_title="y",
    yaxis=dict(scaleanchor="x", scaleratio=1),
)
fig_quiver.show()

# fig.add_trace(


# )