In [None]:
import torch
import torch
from torch import Tensor
import math
import sys
sys.path.append('../')  
import numpy as np
import plotly.graph_objects as go
from dataclasses import dataclass
import einops
from plotly.subplots import make_subplots
from model import MLP_ELU_convex
import hydra
from metrics import RiemannianMetric, DiagonalRiemannianMetric, Method2Metric, Method3Metric
from utils.toy_dataset import GaussianMixture
from train_EBM_geodesic import get_metrics_dict


This notebook demonstrates how to plot metric-tensor quantities over a 2D domain: for each metric, the norm of the metric tensor and the magnitude of its eigenvalues.

In [None]:

    # %%
# Compute device: keep the original "cuda:1" when a second GPU exists, otherwise
# fall back to the first GPU, otherwise the CPU, so the notebook runs anywhere.
DEVICE: str = (
    "cuda:1" if torch.cuda.device_count() > 1
    else "cuda:0" if torch.cuda.is_available()
    else "cpu"
)

# Seed the RNGs so the mixture samples drawn later are reproducible.
# NOTE(review): assumes GaussianMixture.sample draws from torch (or numpy) RNG — confirm.
SEED: int = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Toy data: NB_GAUSSIANS unit-covariance Gaussians whose means are spread evenly
# over the angles [0°, 180°) on a circle of radius RADIUS.
NB_GAUSSIANS = 200
RADIUS = 8
mean_ = torch.linspace(0, 180, NB_GAUSSIANS + 1)[:-1] * math.pi / 180  # angles in radians
MEAN = RADIUS * torch.stack([torch.cos(mean_), torch.sin(mean_)], dim=1)  # (NB_GAUSSIANS, 2)
COVAR = torch.eye(2).unsqueeze(0).repeat(len(MEAN), 1, 1)  # (NB_GAUSSIANS, 2, 2) identities

# Evaluation grid: 100 x 62 points over [-10, 10] x [-2.5, 10], flattened to (N, 2).
x_p, y_p = torch.meshgrid(torch.linspace(-10, 10, 100), torch.linspace(-2.5, 10, 62), indexing='xy')
pos = torch.stack([x_p.flatten(), y_p.flatten()], dim=1).to(DEVICE)

## Defining the mixture (uniform weights)
weight_1 = torch.ones(NB_GAUSSIANS) / NB_GAUSSIANS
mixture_1 = GaussianMixture(center_data=MEAN, covar=COVAR, weight=weight_1).to(DEVICE)

## compute the energy landscape
energy_landscape_1: Tensor = mixture_1.energy(pos)
print(f"{energy_landscape_1.shape=}")

In [None]:
# %%
num_samples: int = 1000
# NOTE(review): sample_dataset is not used by any cell shown here; kept because each
# draw advances the RNG and so affects reference_samples below.
sample_dataset = mixture_1.sample(num_samples).to(DEVICE)
# NOTE(review): left on its original device; presumably get_metrics_dict handles
# placement via its DEVICE argument — confirm.
reference_samples = mixture_1.sample(num_samples)

## ebm-based metric
# SECURITY: weights_only=False unpickles arbitrary Python objects (the checkpoint
# stores the model class under 'type'), so only load trusted checkpoint files.
# map_location remaps tensors saved on another device (e.g. a GPU absent here) to DEVICE.
loaded: dict = torch.load("../tutorial/EBM_mixture1.pth", map_location=DEVICE, weights_only=False)

ebm = loaded['type']()
ebm.load_state_dict(loaded['weight'])
ebm.to(DEVICE)

# Hydra config for the Method2 metric; "_target_" doubles as its key in metric_dict.
target_metric_cfg: dict = {
    "_target_": "train_EBM_geodesic.Method2Metric",
    "a_num": 20.0,
    "b_num": 0.01,
    "eps": 1e-6,
    "eta": 0,
    "mu": 1.0,
    "alpha_fn_choice": "linear",
}

metric_dict: dict[str, RiemannianMetric] = get_metrics_dict(
    mixture_1,
    pos,
    ebm,
    reference_samples,
    DEVICE,
)

metric_dict[target_metric_cfg["_target_"]] = hydra.utils.instantiate(target_metric_cfg, ebm=ebm)

In [None]:
# Display which metric names are available; these keys label every metric evaluated below.
metric_dict.keys()


In [None]:

# Track gradients w.r.t. the grid positions before evaluating the metrics.
# NOTE(review): presumably some metrics differentiate the EBM w.r.t. x — confirm.
pos.requires_grad_(True)

# Evaluate every metric's tensor G(x) on the whole grid, keyed by metric name.
metric_gouts: dict[str, Tensor] = {}
for name, metric in metric_dict.items():
    print(f"{name}, {metric}")
    metric_gouts[name] = metric.g(x_t=pos)

In [None]:
# m2_metric_is_diagonal: bool = torch.all(metric_gouts["train_EBM_geodesic.Method2Metric"][:, 1, 0] == 0) and torch.all(metric_gouts["train_EBM_geodesic.Method2Metric"][:, 0, 1] == 0)
# m2_metric_is_diagonal

In [None]:
# Inspect the raw Method2 metric output evaluated on the grid.
metric_gouts["train_EBM_geodesic.Method2Metric"]

In [None]:
# Report, per metric, the shape of its output on the evaluation grid.
for name, values in metric_gouts.items():
    print(f"Metric: {name}, Tensor shape: {values.shape}")


In [None]:
@dataclass
class AnimationData:
    """Per-metric plot styling: color, size, symbol and a LaTeX label.

    In this notebook only ``latex_label`` is read (for subplot titles);
    the other fields are presumably consumed by other plotting code — confirm.
    """

    color_tuple: tuple[int, int, int]  # RGB triple with 0-255 components, e.g. (66, 133, 244)
    size: int  # presumably a marker size — usage not shown here
    symbol: str  # presumably a marker symbol such as '.', '+' or 'x'
    latex_label: str  # LaTeX string identifying the metric, e.g. r"$\mathbf{G}_{RBF}$"

# RGB colors (0-255 ints, despite the "normed" prefix) for the Method2 / Method3 metrics,
# kept both as tensors and as plain tuples of Python ints.
normed_m2_color: Tensor = torch.tensor([147, 112, 219])
normed_m3_color: Tensor = torch.tensor([128, 0, 128])
m2_color: tuple
m3_color: tuple
m2_color, m3_color = (tuple(rgb.tolist()) for rgb in (normed_m2_color, normed_m3_color))

In [None]:
# Styling per metric name: (key, RGB color, symbol, LaTeX label); every entry uses size 150.
dico_color: dict[str, AnimationData] = {
    metric_key: AnimationData(color_tuple=rgb, size=150, symbol=marker, latex_label=label)
    for metric_key, rgb, marker, label in (
        ("conf_ebm_logp", (66, 133, 244), '.', r"$\mathbf{G}_{E_{\theta}}$"),
        ("conf_ebm_invp", (52, 168, 83), '.', r"$\mathbf{G}_{1/p_{\theta}}$"),
        ("diag_rbf_invp", (234, 67, 53), '.', r"$\mathbf{G}_{RBF}$"),
        ("diag_land_invp", (251, 188, 5), '.', r"$\mathbf{G}_{LAND}$"),
        ("conf_true_logp", (0, 0, 0), '+', r"$\mathbf{G}_{E_{\mathcal{M}}}$"),
        ("conf_true_invp", (0, 0, 0), 'x', r"$\mathbf{G}_{1/p_{\mathcal{M}}}$"),
        ("train_EBM_geodesic.Method2Metric", m2_color, '.', r"$\mathbf{G}_{M2}$"),
        ("train_EBM_geodesic.Method3Metric", m3_color, '.', r"$\mathbf{G}_{M3}$"),
    )
}

In [None]:
def _metric_label(metric_name: str) -> str:
    """Return the LaTeX label for a metric, or its raw name if it has no styling entry."""
    entry = dico_color.get(metric_name)
    return entry.latex_label if entry is not None else metric_name


# Titles must follow the column order of the traces, which are added by iterating
# over `metric_gouts` (not `dico_color`). make_subplots assigns titles row-major,
# so all row-1 titles come first, then all row-2 titles.
raw_value_titles: list[str] = [f"{_metric_label(metric)} Norm" for metric in metric_gouts]
eigvalue_titles: list[str] = [f"Metric: {_metric_label(metric)} - Eigenvalues" for metric in metric_gouts]
all_subplot_titles: list[str] = [*raw_value_titles, *eigvalue_titles]

fig: go.Figure = make_subplots(rows=2, cols=len(metric_gouts), subplot_titles=all_subplot_titles)


In [None]:
# Heatmap grid, one column per metric in `metric_gouts`:
#   row 1: a per-point norm of the metric output (Frobenius norm for matrices),
#   row 2: |mean eigenvalue| of the metric tensor (mean over dim 1 for 2-D outputs).
x_tnsr_detach: np.ndarray = pos[:, 0].detach().cpu().numpy()
y_tnsr_detach: np.ndarray = pos[:, 1].detach().cpu().numpy()
assert x_tnsr_detach.ndim == 1 and x_tnsr_detach.shape == y_tnsr_detach.shape, (
    f"Expected matching 1-D coordinate arrays, got {x_tnsr_detach.shape} and {y_tnsr_detach.shape}"
)
B: int = int(x_tnsr_detach.shape[0])


def _check_per_point(z_tnsr: Tensor, num_points: int) -> Tensor:
    """Assert that a reduced quantity holds exactly one value per grid point, then return it."""
    assert z_tnsr.shape == (num_points,), f"Error! Expected tensor of shape {(num_points,)}, but found one of shape {z_tnsr.shape}"
    return z_tnsr


def _norm_per_point(tnsr: Tensor, num_points: int) -> Tensor:
    """Reduce a metric output to one norm value per grid point.

    Outputs already shaped (num_points,) are returned unchanged. Otherwise 1-D
    outputs are squared, 2-D outputs take the vector 2-norm over dim 1, and 3-D
    outputs take the Frobenius norm over dims (1, 2).

    Raises:
        ValueError: for outputs with more than 3 dimensions.
    """
    if tnsr.shape == (num_points,):
        return tnsr
    # A proper if/elif chain: the original `if dims == 1` / `if dims == 2` split
    # made every 1-D input fall through to the `else: raise`.
    if tnsr.dim() == 1:
        return _check_per_point(tnsr**2, num_points)
    elif tnsr.dim() == 2:
        return _check_per_point(torch.linalg.norm(tnsr, dim=1), num_points)
    elif tnsr.dim() == 3:
        return _check_per_point(torch.linalg.norm(tnsr, dim=(1, 2)), num_points)
    raise ValueError(f"Warning! Found anomalous shape of {tnsr.shape}")


def _mean_eigenvalue_magnitude(tnsr: Tensor, num_points: int) -> Tensor:
    """Reduce a metric output to one aggregate value per grid point.

    Outputs already shaped (num_points,) are returned unchanged. Other 1-D outputs
    are returned as-is, 2-D outputs are averaged over dim 1, and 3-D outputs
    (batched square matrices) give |mean of their eigenvalues|. The eigenvalues may
    be complex if the matrices are not symmetric.

    Raises:
        ValueError: for outputs with more than 3 dimensions.
    """
    if tnsr.shape == (num_points,):
        return tnsr
    if tnsr.dim() == 1:
        return _check_per_point(tnsr, num_points)
    elif tnsr.dim() == 2:
        return _check_per_point(tnsr.mean(dim=1), num_points)
    elif tnsr.dim() == 3:
        # Only eigenvalues are needed, so use eigvals rather than eig (no eigenvectors).
        return _check_per_point(torch.linalg.eigvals(tnsr).mean(dim=1).abs(), num_points)
    raise ValueError(f"Warning! Found anomalous shape of {tnsr.shape}")


def _add_heatmap(figure: go.Figure, z_tnsr: Tensor, row: int, col: int) -> None:
    """Add a heatmap of per-point values `z_tnsr` over the (x, y) grid at subplot (row, col)."""
    figure.add_trace(
        go.Heatmap(x=x_tnsr_detach, y=y_tnsr_detach, z=z_tnsr.detach().cpu().numpy()),
        row=row,
        col=col,
    )


# Fill row 1 completely, then row 2, preserving the original trace order.
for row, reduce_fn in ((1, _norm_per_point), (2, _mean_eigenvalue_magnitude)):
    for col, (metric_str, tnsr) in enumerate(metric_gouts.items(), start=1):
        print(f"{row=} {col=} {metric_str=} {tnsr.shape=}")
        _add_heatmap(fig, reduce_fn(tnsr, B), row=row, col=col)

fig.show()
fig.write_html("./all_metrics.html", include_mathjax="cdn")

# Visualizing the 1-form

In [None]:
# Hydra config for Method3Metric: the same fields as the Method2 config,
# but with eta = 0.1 and an additional `beta` entry.
target_metric_cfg: dict = dict(
    _target_="train_EBM_geodesic.Method3Metric",
    a_num=20.0,
    b_num=0.01,
    eps=1e-6,
    eta=0.1,
    mu=1.0,
    beta=0.1,
    alpha_fn_choice="linear",
)
method3_metric: Method3Metric = hydra.utils.instantiate(target_metric_cfg, ebm=ebm)

# Evaluate the metric's one-form on every grid point, with gradient tracking on pos.
pos.requires_grad_(True)
one_form_all: Tensor = method3_metric.one_form(pos)
print(f"{one_form_all.shape=}")

In [None]:
# Visualizing the one-form along with the landscape of A(x) for the Method3Metric

In [None]:
# A fresh, empty figure (no traces yet) for the one-form / A(x) visualization.
fig: go.Figure = go.Figure(data=None)

