In [None]:
import plotly.graph_objects as go
import plotly.subplots as sp
import plotly.figure_factory as ff
from torch import Tensor
import numpy as np
import torch
import torch
import torch.nn as nn
import sys
sys.path.append("../")
from model import MLP_ELU_convex
import matplotlib.pyplot as plt
from utils.toy_dataset import GaussianMixture
import math
import numpy as np
from sklearn.cluster import KMeans
import einops

In [None]:
# Empty 2x2 grid of subplots; traces are attached to it in a later cell.
figure: go.Figure = sp.make_subplots(
    rows=2, cols=2, subplot_titles=("Plot 1", "Plot 2", "Plot 3", "Plot 4")
)
# Fix the canvas to 800x800 px and set the overall figure title.
figure.update_layout(title_text="Multiple Subplots Example", width=800, height=800)

In [None]:
import numpy as np

# Regular grid over [-3, 3) x [-3, 3) with spacing 0.5, flattened to 1-D
# coordinate arrays (one entry per grid node).
xx, yy = np.mgrid[-3:3:0.5, -3:3:0.5]
x: np.ndarray = xx.flatten()
y: np.ndarray = yy.flatten()

# Scalar field drawn as the heatmap: exp(-(x^2 + y^2)).
u: np.ndarray = np.exp(-(x**2 + y**2))

# Vector field drawn as the quiver overlay; it is independent of `u`.
u_qvr: np.ndarray = np.cos(x)
v_qvr: np.ndarray = np.sin(y)

subfig: go.Figure = ff.create_quiver(x, y, u_qvr, v_qvr, scale=0.5)

# NOTE: `figure` is created in an earlier cell and mutated here, so re-running
# this cell alone appends the same traces a second time.
figure.add_trace(go.Heatmap(z=u, x=x, y=y), row=1, col=1)
# The axis title previously read "$X Axis"; the unmatched "$" was a typo that
# rendered literally in the label.
figure.update_xaxes(title_text="X Axis", row=1, col=1)
figure.update_yaxes(title_text="Y Axis", row=1, col=1)

# Overlay every trace produced by the quiver factory on the same subplot.
for trace in subfig.data:
    figure.add_trace(trace, row=1, col=1)

figure.show()


In [None]:
import plotly.express as px

# Small demo of LaTeX (MathJax) strings used as the title and axis labels of
# a plotly figure.
fig = px.line(
    x=[1, 2, 3, 4],
    y=[1, 4, 9, 16],
    title=r'$\alpha_{1c} = 352 \pm 11 \text{ km s}^{-1}$',
)
fig.update_layout(
    yaxis_title=r'$d, r \text{ (solar radius)}$',
    xaxis_title=r'$\sqrt{(n_\text{c}(t|{T_\text{early}}))}$',
)
# Export as standalone HTML; MathJax is pulled from the CDN so the LaTeX renders.
fig.write_html("latex_labels.html", include_mathjax='cdn')

In [None]:
N_steps: int = 100        # nodes per axis of the dense evaluation grid
subsmpl_steps: int = 10   # nodes per axis of the coarse grid

# indexing="ij" is the layout torch.meshgrid falls back to when the argument is
# omitted; passing it explicitly keeps the result unchanged and avoids the
# warning PyTorch emits for the implicit default.
xx, yy = torch.meshgrid(
    torch.linspace(-3, 3, steps=N_steps),
    torch.linspace(-3, 3, steps=N_steps),
    indexing="ij",
)

xx_subsmpl, yy_subsmpl = torch.meshgrid(
    torch.linspace(-3, 3, steps=subsmpl_steps),
    torch.linspace(-3, 3, steps=subsmpl_steps),
    indexing="ij",
)

# Flatten both grids to 1-D coordinate vectors.
x: Tensor = xx.flatten()
y: Tensor = yy.flatten()

x_subsmpl: Tensor = xx_subsmpl.flatten()
y_subsmpl: Tensor = yy_subsmpl.flatten()

# One (x, y) row per grid node; gradients are tracked so the energy can be
# differentiated with respect to the inputs afterwards.
inputs: Tensor = torch.stack([x, y], dim=-1)
inputs.requires_grad_(True)
# parse_shape returns a mapping from axis name to size, not a tuple.
shape: dict = einops.parse_shape(inputs, "b dim")

# Sum of exp(-coord^2) over the coordinate axis, one scalar per grid node.
# NOTE(review): this is exp(-x^2) + exp(-y^2), not the exp(-(x^2 + y^2)) used
# in the NumPy cell above - confirm which one is intended.
energy_output: Tensor = einops.einsum(torch.exp(-(inputs ** 2)), "b dim -> b")
assert energy_output.shape == (shape["b"], )

In [None]:
# Gradient of the per-node energy with respect to the stacked (x, y) inputs.
# A vector of ones as grad_outputs back-propagates every node's energy at once.
(grad_x,) = torch.autograd.grad(
    energy_output,
    inputs,
    grad_outputs=torch.ones_like(energy_output),
    retain_graph=True,
    create_graph=True,
)
grad_x.shape

In [None]:

figure: go.Figure = sp.make_subplots(
    rows=2,
    cols=2,
    subplot_titles=("Plot 1", "Plot 2", "Plot 3", "Plot 4"),
)
figure.update_layout(height=800, width=800, title_text="Multiple Subplots Example")

# Detach everything plotly needs from the autograd graph and convert it to
# NumPy once, so all arrays handed to plotly have the same type.
x_np: np.ndarray = x.detach().numpy()
y_np: np.ndarray = y.detach().numpy()
grad_np: np.ndarray = grad_x.detach().numpy()
energy_np: np.ndarray = energy_output.detach().numpy()

subsample_amt: int = 10


def _subsample_grid(flat: np.ndarray) -> np.ndarray:
    """Keep every `subsample_amt`-th node along BOTH axes of the flattened grid.

    A plain `flat[::subsample_amt]` on the flattened N_steps x N_steps grid only
    thins out one axis and leaves the other at full resolution, which produces a
    very dense, overlapping quiver. Reshaping back to 2-D first thins both axes.
    """
    return flat.reshape(N_steps, N_steps)[::subsample_amt, ::subsample_amt].flatten()


# Coarse node positions and the gradient components evaluated at those nodes.
subsampled_x: np.ndarray = _subsample_grid(x_np)
subsampled_y: np.ndarray = _subsample_grid(y_np)
u_qvr: np.ndarray = _subsample_grid(grad_np[:, 0])
v_qvr: np.ndarray = _subsample_grid(grad_np[:, 1])

subfig: go.Figure = ff.create_quiver(
    subsampled_x,
    subsampled_y,
    u_qvr,
    v_qvr,
    scale=0.5,
)

# Full-resolution energy as the heatmap background.
figure.add_trace(go.Heatmap(z=energy_np, x=x_np, y=y_np), row=1, col=1)
# The axis title previously read "$X Axis"; the unmatched "$" was a typo.
figure.update_xaxes(title_text="X Axis", row=1, col=1)
figure.update_yaxes(title_text="Y Axis", row=1, col=1)

# Overlay the quiver traces on the same subplot.
for trace in subfig.data:
    figure.add_trace(trace, row=1, col=1)

figure.show()

In [None]:
NB_GAUSSIANS = 200
RADIUS = 8
# Use the second GPU when the host has one (the original hard-coded choice);
# otherwise fall back to the first GPU or the CPU so the notebook still runs.
if torch.cuda.device_count() > 1:
    DEVICE = "cuda:1"
elif torch.cuda.is_available():
    DEVICE = "cuda:0"
else:
    DEVICE = "cpu"

# NB_GAUSSIANS angles evenly spaced over [0, 180) degrees, converted to radians.
mean_ = (torch.linspace(0, 180, NB_GAUSSIANS + 1)[0:-1] * math.pi / 180)
# Component centres on the upper half-circle of radius RADIUS.
MEAN = RADIUS * torch.stack([torch.cos(mean_), torch.sin(mean_)], dim=1)
# One 2x2 identity covariance per component.
COVAR = torch.tensor([[1., 0], [0, 1.]]).unsqueeze(0).repeat(len(MEAN), 1, 1)

x_ranges: tuple = (-10, 10)
y_ranges: tuple = (-2.5, 10)
# With indexing='xy', 100 x-values and 62 y-values give grids of shape (62, 100);
# the plotting cells rely on that shape when they call `.view(62, 100)`.
xx, yy = torch.meshgrid(torch.linspace(x_ranges[0], x_ranges[1], 100), torch.linspace(y_ranges[0], y_ranges[1], 62), indexing='xy')
# One (x, y) row per grid node, moved to the compute device.
pos = torch.cat([xx.flatten().unsqueeze(1), yy.flatten().unsqueeze(1)], dim=1).to(DEVICE)

## Defining the mixture
# Uniform mixture weights.
weight_1 = (torch.ones(NB_GAUSSIANS) / NB_GAUSSIANS)
mixture_1 = GaussianMixture(center_data=MEAN, covar=COVAR, weight=weight_1).to(DEVICE)

## compute the energy landscape
energy_landscape_1 = mixture_1.energy(pos)

T_STEPS=50
# Step size such that T_STEPS points span a unit interval including both ends.
dt = 1.0/(T_STEPS-1)

In [None]:

fig, ax = plt.subplots(1, 1, figsize=(8,6), dpi=100)
# Filled contours of the true mixture energy on the (62, 100) grid.
# `levels` is passed once, by keyword; the original also passed 20 positionally,
# which was redundant.
im = ax.contourf(
    xx,
    yy,
    energy_landscape_1.view(62, 100).detach().cpu(),
    levels=20,
    cmap='Blues_r',
    alpha=0.8,
    zorder=0,
)
# Palette kept for later overlays; not used by this cell.
color = ['orange', 'black', 'grey', 'red', 'green','purple','pink','yellow']

# Mark two mixture centres. They carry labels so that the legend below has
# entries; without any labelled artist, legend() only emits a warning.
ax.scatter(MEAN[10][0].cpu().detach(), MEAN[10][1].cpu().detach(), s=70, color='red', alpha=1, label="MEAN[10]")
ax.scatter(MEAN[-10][0].cpu().detach(), MEAN[-10][1].cpu().detach(), s=70, color='green', alpha=1, label="MEAN[-10]")
ax.legend()
plt.show()

In [None]:
# Two independent draws of 1000 points from the mixture.
sample_dataset = mixture_1.sample(1000).to(DEVICE)
reference_samples = mixture_1.sample(1000)

## ebm-based metric
# SECURITY: weights_only=False unpickles arbitrary Python objects (needed here
# because the checkpoint stores the model class under 'type'). Only load
# checkpoints from a trusted source.
# map_location="cpu" makes the load independent of the device the checkpoint was
# saved from; the model is moved to DEVICE explicitly below.
loaded = torch.load("./tutorial/EBM_mixture1.pth", map_location="cpu", weights_only=False)
# Instantiate the stored model class with no arguments, then restore its weights.
ebm = loaded['type']()
ebm.load_state_dict(loaded['weight'])
ebm.to(DEVICE)

In [None]:
sample_steps: int = 10

# Coarse grid over the same x/y ranges as the dense grid, sample_steps nodes
# per axis.
xx_subs, yy_subs = torch.meshgrid(
    torch.linspace(*x_ranges, sample_steps),
    torch.linspace(*y_ranges, sample_steps),
    indexing='xy',
)
# Pack the two flattened coordinate vectors side by side, one row per node.
pos_subs, shape = einops.pack([xx_subs.flatten(), yy_subs.flatten()], 'b *')
print(f"{pos_subs.shape}")

# Move to the compute device and track gradients with respect to the positions.
pos_subs = pos_subs.to(DEVICE).requires_grad_(True)
print(f"{pos_subs.shape=}")

In [None]:
# Call the module itself rather than `.forward` so that any registered hooks
# run. The output already participates in the autograd graph because
# `pos_subs` requires grad, so no extra requires_grad_ call is needed on it.
energy_out = ebm(pos_subs)

# Gradient of the learned energy with respect to the coarse-grid positions.
# A tensor of ones as grad_outputs differentiates all outputs at once.
score: Tensor = torch.autograd.grad(
    outputs=energy_out,
    inputs=pos_subs,
    grad_outputs=torch.ones_like(energy_out),
    create_graph=True,
    retain_graph=True
    )[0]
print(f"{score.shape=}")


In [None]:
# Dense-grid positions on the compute device, with gradient tracking enabled so
# the energy can be differentiated with respect to them in a later cell.
pos = pos.to(DEVICE)
pos.requires_grad_(True)
# Call the module itself rather than `.forward` so that any registered hooks
# run. The output is already part of the autograd graph through `pos`, so no
# requires_grad_ call is needed on it.
energy_on_pos: Tensor = ebm(pos)
print(f"{pos.shape=}")
print(f"{energy_on_pos.shape=}")


In [None]:
fig, ax = plt.subplots(1, 2, figsize=(8,6), dpi=100)

# Left panel: energy of the true mixture on the (62, 100) grid.
# `levels` is passed once, by keyword; the original also passed 20 positionally.
im = ax[0].contourf(
    xx,
    yy,
    energy_landscape_1.view(62, 100).detach().cpu(),
    levels=20,
    cmap='Blues_r',
    alpha=0.8,
    zorder=0,
)
ax[0].set_title("Energy Landscape from True Mixture")

# Right panel: energy predicted by the learned EBM on the same grid.
im2 = ax[1].contourf(
    xx,
    yy,
    energy_on_pos.view(62, 100).detach().cpu(),
    levels=20,
    cmap='Blues_r',
    alpha=0.8,
    zorder=0,
)
ax[1].set_title("Energy Landscape from Learned EBM")

pos_subs_dtch: np.ndarray = pos_subs.detach().cpu().numpy()
score_dtch: np.ndarray = score.detach().cpu().numpy()

# Arrows show the negated energy gradient on the coarse grid. The label gives
# the legend below an entry; without any labelled artist, legend() only warns.
ax[1].quiver(
    pos_subs_dtch[:,0],
    pos_subs_dtch[:,1],
    -score_dtch[:,0],
    -score_dtch[:,1],
    color='red',
    label=r"$-\nabla_x E(x)$",
)
# TODO: scale both panels equally

# Palette kept for later overlays; not used by this cell.
color = ['orange', 'black', 'grey', 'red', 'green','purple','pink','yellow']

# Attach the legend to the panel that actually holds the labelled quiver.
ax[1].legend()
plt.show()

In [None]:
# Coefficients of the two candidate scalings of the energy.
a: float = 1.0
b: float = 1.0
eps: float = 1e-6

# Candidate 1: affine in the energy. Candidate 2: reciprocal of the same affine
# expression, with eps added to the denominator.
alpha_fn_1: Tensor = a + b * energy_on_pos
alpha_fn_2: Tensor = 1 / (a + b * energy_on_pos + eps)
print(f"{alpha_fn_1.shape=}")

# Gradient of the learned energy with respect to the dense-grid positions.
(score_on_pos,) = torch.autograd.grad(
    energy_on_pos,
    pos,
    grad_outputs=torch.ones_like(energy_on_pos),
    retain_graph=True,
    create_graph=True,
)
print(f"{score_on_pos.shape=}")

# Per-position outer product of the gradient with itself: one 2x2 matrix per row.
grad_outer_prod: Tensor = einops.einsum(score_on_pos, score_on_pos, 'b i, b j -> b i j')
print(f"{grad_outer_prod.shape=}")

# The affine candidate is the one used downstream.
alpha_fn: Tensor = alpha_fn_1

# Identity with a leading singleton axis so it broadcasts against grad_outer_prod.
I_mat: Tensor = torch.eye(2, device=DEVICE).unsqueeze(0)
print(f"{I_mat.shape=}")
print(f"{grad_outer_prod.shape=}")
print(f"{alpha_fn.shape=}")

# Append a trailing singleton axis so alpha broadcasts over each 2x2 matrix.
alpha_fn = einops.rearrange(alpha_fn, 'b 1 -> b 1 1')


In [None]:
# Each matrix is the outer product of a vector with itself, so it must equal its
# own transpose; fail loudly if that ever stops holding (eigh below relies on it).
assert torch.allclose(grad_outer_prod, grad_outer_prod.transpose(-1, -2)), \
    "grad_outer_prod is expected to be symmetric"
grad_outer_prod[:10]

In [None]:

# Two decompositions of the same batch of 2x2 matrices: the general solver
# (complex-valued output) and the symmetric solver (real-valued output).
eigvals, eigvecs = torch.linalg.eig(grad_outer_prod)
eigvals_alt, eigvecs_alt = torch.linalg.eigh(grad_outer_prod)
print(f"{eigvals.shape=}")
print(f"{eigvecs.shape=}")
print(f"{eigvals_alt.shape=}")

print(f"{eigvals[:5]=}")
print(f"{eigvals_alt[:5]=}")
print(f"{torch.abs(eigvals[:5]).shape=}")
print(f"{torch.abs(eigvals[:5]) >= 0=}")

# Definiteness is decided on the real eigenvalues from eigh. The previous check
# was vacuous for two reasons: it took abs() first (always >= 0), and it applied
# the comparison to the result of torch.all(...) instead of inside it.
# The tolerance is relative to each matrix's largest eigenvalue magnitude, so
# round-off around an exact zero eigenvalue is not misclassified.
real_eigvals: Tensor = eigvals_alt.detach()
eig_tol: Tensor = 1e-5 * real_eigvals.abs().amax(dim=-1, keepdim=True)

is_pos_semidef: bool = bool(torch.all(real_eigvals >= -eig_tol))
is_pos_def: bool = bool(torch.all(real_eigvals > eig_tol))

print(f"{is_pos_semidef=}")
print(f"{is_pos_def=}")



In [None]:
# Weights of the identity term (mu) and of the gradient outer-product term (eta).
eta: float = 0.01
mu: float = 1

# Unscaled matrix field: mu * I - eta * (grad grad^T), one 2x2 matrix per position.
A_pre: Tensor = mu * I_mat - eta * grad_outer_prod

# Scale every matrix by the position-dependent factor alpha.
A_mat: Tensor = alpha_fn * A_pre
print(f"{A_mat.shape=}")

# Norm taken over the two matrix axes, giving one scalar per position.
A_mat_norms: Tensor = torch.linalg.norm(A_mat, dim=(1, 2))
print(f"{A_mat_norms.shape=}")

In [None]:
fig, ax = plt.subplots(1, 3, figsize=(15,10), dpi=100)

# Panel 0: energy of the true mixture on the (62, 100) grid.
# `levels` is passed once, by keyword; the original also passed 20 positionally.
im = ax[0].contourf(
    xx,
    yy,
    energy_landscape_1.view(62, 100).detach().cpu(),
    levels=20,
    cmap='Blues_r',
    alpha=0.8,
    zorder=0,
)
ax[0].set_title("Energy Landscape from True Mixture")

# Panel 1: energy predicted by the learned EBM.
im2 = ax[1].contourf(
    xx,
    yy,
    energy_on_pos.view(62, 100).detach().cpu(),
    levels=20,
    cmap='Blues_r',
    alpha=0.8,
    zorder=0,
)
ax[1].set_title("Energy Landscape from Learned EBM")

# Panel 2: per-position norm of the matrix field A(x).
im3 = ax[2].contourf(
    xx,
    yy,
    A_mat_norms.view(62, 100).detach().cpu(),
    levels=20,
    cmap='Blues_r',
    alpha=0.8,
    zorder=0,
)
ax[2].set_title("$A(x)$ Norm Landscape")

pos_subs_dtch: np.ndarray = pos_subs.detach().cpu().numpy()
score_dtch: np.ndarray = score.detach().cpu().numpy()

# Arrows show the negated energy gradient on the coarse grid. The label gives
# the legend below an entry.
ax[1].quiver(
    pos_subs_dtch[:,0],
    pos_subs_dtch[:,1],
    -score_dtch[:,0],
    -score_dtch[:,1],
    color='red',
    label=r"$-\nabla_x E(x)$",
)
# TODO: scale all panels equally

# Palette kept for later overlays; not used by this cell.
color = ['orange', 'black', 'grey', 'red', 'green','purple','pink','yellow']

# plt.legend() acts on the current axes (the last panel), which has no labelled
# artists and therefore only produced a warning. Attach the legend to the panel
# that holds the labelled quiver instead.
ax[1].legend()
plt.show()

# Defining classes for both 

In [None]:
class M2_ScoreStructureTensorMetric(nn.Module)