In [None]:
# %%
from hydra import initialize, compose
from datetime import datetime
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
from pathlib import Path
from sklearn.cluster import KMeans
from torch import Tensor
from tqdm import tqdm
from typing import Tuple
import einops
import hydra
import ipdb
import math
from torch import Tensor
import matplotlib.pyplot as plt
import numpy as np
import sys
from dataclasses import dataclass
import torch
import torch.nn as nn
import plotly
import plotly.graph_objects as go
sys.path.append("../")

In [None]:
# Reference curve: 1000 points evenly spaced in angle on the unit circle, shape (1000, 2).
t_vals: Tensor = torch.linspace(0, 2 * math.pi, steps=1000)
rand_paths: Tensor = torch.stack((t_vals.cos(), t_vals.sin()), dim=1)

# Replicate the circle once per time step -> (T_TIME, 1000, 2), then jitter every copy
# independently with Gaussian noise of standard deviation 0.1.
T_TIME: int = 100
timed_rand_paths: Tensor = einops.repeat(rand_paths, "n d -> t_time n d", t_time=T_TIME)
timed_rand_paths = timed_rand_paths + 0.1 * torch.randn_like(timed_rand_paths)

In [None]:
# Sanity check (when run right after the cell above): expect (1000, 2) for the clean circle
# and (T_TIME, 1000, 2) for its noisy per-time-step copies.
print("rand_paths.shape=" + repr(rand_paths.shape))
print("timed_rand_paths.shape=" + repr(timed_rand_paths.shape))

In [None]:
# Animate the T_TIME noisy copies of the unit circle, one plotly frame per time step.


def _circle_trace(points: Tensor) -> go.Scatter:
    """Return a blue, marker-only scatter trace for a (P, 2) tensor of 2-D points."""
    return go.Scatter(
        x=points[:, 0].cpu().numpy(),
        y=points[:, 1].cpu().numpy(),
        mode='markers',
        marker=dict(size=5, color='blue', opacity=0.6),
    )


# One trace per time step; timed_rand_paths[t] is the (P, 2) point cloud at step t.
frame_data: list[go.Scatter] = [_circle_trace(timed_rand_paths[t]) for t in range(T_TIME)]

# Frames target trace 0, so the figure's single trace starts as the frame-0 snapshot; this keeps
# the static view consistent with the slider's initial "t = 0" position.
fig: go.Figure = go.Figure(data=[frame_data[0]])
fig.frames = [go.Frame(data=[frame], name=str(t)) for t, frame in enumerate(frame_data)]

# Slider steps (one per frame): jump straight to frame t with no transition.
slider_steps = [
    dict(
        method="animate",
        args=[
            [str(t)],
            dict(
                mode="immediate",
                frame=dict(duration=0, redraw=True),
                transition=dict(duration=0),
            ),
        ],
        label=str(t),
    )
    for t in range(T_TIME)
]

fig.update_layout(
    title="Random Paths on Unit Circle",
    xaxis_title="X-axis",
    yaxis_title="Y-axis",
    # Equal axis scaling so the circle is not drawn as an ellipse.
    xaxis=dict(scaleanchor="y", scaleratio=1),
    yaxis=dict(scaleanchor="x", scaleratio=1),

    sliders=[
        dict(
            active=0,
            currentvalue=dict(prefix="t = "),
            pad=dict(t=50),
            steps=slider_steps,
        )
    ],

    updatemenus=[
        dict(
            type="buttons",
            showactive=False,
            y=1,
            x=1.1,
            xanchor="right",
            yanchor="top",
            buttons=[
                dict(
                    label="Play",
                    method="animate",
                    # `None` as the frame list plays all frames, resuming from the current one.
                    args=[
                        None,
                        dict(
                            frame=dict(duration=100, redraw=True),
                            fromcurrent=True,
                            transition=dict(duration=0),
                        ),
                    ],
                ),
                dict(
                    label="Pause",
                    method="animate",
                    # `[None]` with mode="immediate" interrupts the running animation.
                    args=[
                        [None],
                        dict(frame=dict(duration=0, redraw=False), mode="immediate"),
                    ],
                ),
            ],
        )
    ],
)

fig.show()




# Simulate overlaying the time-lapse animations of several metrics in a single figure.


In [None]:
# RGB colours for the two method metrics, with 0-255 integer channels (despite the "normed"
# prefix): a light purple (CSS "mediumpurple") and a darker purple (CSS "purple").
normed_m2_color: Tensor = torch.tensor((147, 112, 219))
normed_m3_color: Tensor = torch.tensor((128, 0, 128))

# The same colours as plain tuples of Python ints, the form stored in AnimationData.color_tuple.
m2_color: tuple = tuple(map(int, normed_m2_color.tolist()))
m3_color: tuple = tuple(map(int, normed_m3_color.tolist()))

@dataclass
class AnimationData:
    """Plot styling attached to one metric (colour, marker and legend label).

    Values are stored exactly as given; no validation is performed.
    """

    # RGB triplet with 0-255 channels, e.g. (66, 133, 244).
    color_tuple: tuple
    # Marker size. NOTE(review): 150 looks like a matplotlib scatter `s` value — confirm intended use.
    size: int
    # Marker symbol code such as '.', '+' or 'x' (presumably matplotlib marker codes).
    symbol: str
    # LaTeX legend label, e.g. r"$\mathbf{G}_{RBF}$".
    latex_label: str



# Plot styling per metric name (insertion order is the plotting/legend order downstream).
# Google brand colours for the learned/baseline metrics, black for the two `conf_true_*` metrics
# (presumably ground truth; they differ only by symbol), and purples for the two method metrics.
dico_color: dict[str, AnimationData] = {
    "conf_ebm": AnimationData(color_tuple=(66, 133, 244), size=150, symbol='.', latex_label=r"$\mathbf{G}_{E_{\theta}}$"),  # Google blue
    "conf_ebm_invp": AnimationData(color_tuple=(52, 168, 83), size=150, symbol='.', latex_label=r"$\mathbf{G}_{1/p_{\theta}}$"),  # Google green
    "diag_rbf_invp": AnimationData(color_tuple=(234, 67, 53), size=150, symbol='.', latex_label=r"$\mathbf{G}_{RBF}$"),  # Google red
    "diag_land_invp": AnimationData(color_tuple=(251, 188, 5), size=150, symbol='.', latex_label=r"$\mathbf{G}_{LAND}$"),  # Google yellow
    "conf_true_logp": AnimationData(color_tuple=(0, 0, 0), size=150, symbol='+', latex_label=r"$\mathbf{G}_{E_{\mathcal{M}}}$"),  # black
    "conf_true_invp": AnimationData(color_tuple=(0, 0, 0), size=150, symbol='x', latex_label=r"$\mathbf{G}_{1/p_{\mathcal{M}}}$"),  # black
    "train_EBM_geodesic.Method2Metric_invp": AnimationData(color_tuple=m2_color, size=150, symbol='.', latex_label=r"$\mathbf{G}_{1/M2}$"),  # light purple (147, 112, 219)
    "train_EBM_geodesic.Method3Metric": AnimationData(color_tuple=m3_color, size=150, symbol='.', latex_label=r"$\mathbf{G}_{M3}$"),  # purple (128, 0, 128)
}


In [None]:
# Synthetic stand-in for per-metric path time-lapses: each metric gets one randomly chosen curve,
# replicated over T_TIME steps and jittered independently at every step.
SEED: int = 0
rng = np.random.default_rng(SEED)  # drives the curve choice per metric
torch.manual_seed(SEED)  # makes the torch.randn_like jitter below reproducible

all_metric_timed_paths: dict[str, Tensor] = {}

t_vals: Tensor = torch.linspace(0, 2 * math.pi, steps=100)

T_TIME: int = 100


def _decaying_spiral(t: Tensor, decay: float = 0.0, freq: float = 1.0) -> Tensor:
    """Return the (len(t), 2) curve exp(-decay * t) * (cos(freq * t), sin(freq * t)).

    decay=0, freq=1 gives the unit circle.
    """
    radius = torch.exp(-decay * t)
    return torch.stack([radius * torch.cos(freq * t), radius * torch.sin(freq * t)], dim=1)


# (decay, freq) pairs, each listed once so random selection is uniform over distinct curves.
_CURVE_PARAMS: list[tuple[float, float]] = [
    (0.0, 1.0),
    (1.0, 1.0),
    (0.5, 2.0),
    (0.5, 1.0),
    (2.0, 1.0),
    (0.1, 1.0),
    (0.01, 1.0),
]
different_fns: list[Tensor] = [_decaying_spiral(t_vals, decay, freq) for decay, freq in _CURVE_PARAMS]

for metric_name in dico_color.keys():
    rand_paths = different_fns[rng.integers(len(different_fns))]
    # (P, D) -> (T_TIME, P, D): one copy of the curve per time step.
    timed_rand_paths: Tensor = einops.repeat(rand_paths, "n d -> t_time n d", t_time=T_TIME)
    timed_paths: Tensor = timed_rand_paths + 0.1 * torch.randn_like(timed_rand_paths)
    all_metric_timed_paths[metric_name] = timed_paths

In [None]:
# Overlay every metric's noisy path in one figure and animate it over the T_TIME steps.
metric_names: list[str] = list(dico_color.keys())
for metric in metric_names:
    print(f"Adding metric {metric} with color {dico_color[metric].color_tuple} and label {dico_color[metric].latex_label}")


def _metric_trace(metric: str, t: int) -> go.Scatter:
    """Return the lines+markers trace of `metric`'s path at time step `t`, styled from dico_color."""
    style = dico_color[metric]
    path_t = all_metric_timed_paths[metric][t]  # the per-step slice, indexed [:, 0] / [:, 1] below
    return go.Scatter(
        name=style.latex_label,
        x=path_t[:, 0].cpu().numpy(),
        y=path_t[:, 1].cpu().numpy(),
        mode='lines+markers',
        # NOTE(review): style.size / style.symbol are not used here; the marker size is fixed at 5.
        marker=dict(size=5, color=plotly.colors.label_rgb(style.color_tuple), opacity=0.6),
    )


# all_frames_data[t] holds one trace per metric, in dico_color order, so frame traces line up
# index-by-index with the figure's initial traces.
all_frames_data: list[list[go.Scatter]] = [
    [_metric_trace(metric, t) for metric in metric_names] for t in range(T_TIME)
]

# Start from frame 0 so the static view matches the slider's initial position.
frame_zero_data: list[go.Scatter] = all_frames_data[0]
fig: go.Figure = go.Figure(data=frame_zero_data)

all_frames: list[go.Frame] = [
    go.Frame(data=frame_data, name=str(t)) for t, frame_data in enumerate(all_frames_data)
]
fig.frames = all_frames

# Slider steps (one per frame): jump straight to frame t with no transition.
slider_steps = [
    dict(
        method="animate",
        args=[
            [str(t)],
            dict(
                mode="immediate",
                frame=dict(duration=0, redraw=True),
                transition=dict(duration=0),
            ),
        ],
        label=str(t),
    )
    for t in range(T_TIME)
]

fig.update_layout(
    # The paths are noisy decaying spirals (one per metric), not points on the unit circle.
    title="Noisy Paths per Metric",
    showlegend=True,
    xaxis_title="X-axis",
    yaxis_title="Y-axis",
    xaxis=dict(scaleanchor="y", scaleratio=1),
    yaxis=dict(scaleanchor="x", scaleratio=1),

    sliders=[
        dict(
            active=0,
            currentvalue=dict(prefix="t = "),
            pad=dict(t=50),
            steps=slider_steps,
        )
    ],

    updatemenus=[
        dict(
            type="buttons",
            showactive=False,
            y=1,
            x=1.1,
            xanchor="right",
            yanchor="top",
            buttons=[
                dict(
                    label="Play",
                    method="animate",
                    # `None` as the frame list plays all frames, resuming from the current one.
                    args=[
                        None,
                        dict(
                            frame=dict(duration=100, redraw=True),
                            fromcurrent=True,
                            transition=dict(duration=0),
                        ),
                    ],
                ),
                dict(
                    label="Pause",
                    method="animate",
                    # `[None]` with mode="immediate" interrupts the running animation.
                    args=[
                        [None],
                        dict(frame=dict(duration=0, redraw=False), mode="immediate"),
                    ],
                ),
            ],
        )
    ],
)

fig.show()




In [None]:
import torch.nn as nn
import sys
sys.path.append("../")
import matplotlib.pyplot as plt
from utils.toy_dataset import GaussianMixture
import math
import numpy as np
from sklearn.cluster import KMeans

In [None]:
# Toy dataset: NB_GAUSSIANS identity-covariance Gaussians whose centres are evenly spaced in angle
# over [0°, 180°) on a circle of radius RADIUS, shifted down by 0.5 along y.
NB_GAUSSIANS = 200
RADIUS = 8

# Prefer the second GPU as originally configured, but fall back so the notebook still runs on
# single-GPU or CPU-only machines instead of failing on an invalid device ordinal.
if torch.cuda.device_count() > 1:
    DEVICE = "cuda:1"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

mean_ = torch.linspace(0, 180, NB_GAUSSIANS + 1)[:-1] * math.pi / 180  # centre angles in radians
MEAN = RADIUS * torch.stack([torch.cos(mean_), torch.sin(mean_)], dim=1) - torch.tensor([0.0, 0.5])
COVAR = torch.tensor([[1.0, 0], [0, 1.0]]).unsqueeze(0).repeat(len(MEAN), 1, 1)

# Evaluation grid: 100 x-values in [-10, 10] and 62 y-values in [-2.5, 10]. With indexing='xy',
# x_p and y_p both have shape (62, 100), and `pos` flattens them row-major into (6200, 2) points.
x_p, y_p = torch.meshgrid(torch.linspace(-10, 10, 100), torch.linspace(-2.5, 10, 62), indexing='xy')
pos = torch.cat([x_p.flatten().unsqueeze(1), y_p.flatten().unsqueeze(1)], dim=1).to(DEVICE)

## Gaussian Mixture Uniformly distributed
weight_1 = torch.ones(NB_GAUSSIANS) / NB_GAUSSIANS
mixture_1 = GaussianMixture(center_data=MEAN, covar=COVAR, weight=weight_1).to(DEVICE)

sample_1 = mixture_1.sample(1000).cpu().detach()

energy_landscape_1 = mixture_1.energy(pos)

In [None]:

# Bring the energy and grid coordinates to host numpy arrays for plotting.
# detach() so .numpy() also works if the mixture's parameters require grad.
en_landscape: np.ndarray = energy_landscape_1.detach().cpu().numpy()
x_p_np: np.ndarray = x_p.cpu().numpy()
y_p_np: np.ndarray = y_p.cpu().numpy()
# `pos` is the row-major flattening of the 'xy'-indexed meshgrid, whose shape is (n_y, n_x), so the
# flat index is iy * n_x + ix. Reshaping back to the meshgrid's own shape restores the grid exactly;
# reshaping to (n_x, n_y) and transposing would scramble it.
en_landscape = en_landscape.reshape(x_p_np.shape)



In [None]:
tuple(arr.shape for arr in (en_landscape, x_p_np, y_p_np))  # should be three identical (n_y, n_x) shapes

In [None]:

# Contour plot of the mixture's energy over the evaluation grid.
assert en_landscape.shape == x_p_np.shape == y_p_np.shape, "Shapes of z, x, and y must match for contour plot."
# go.Contour takes 1-D coordinate vectors: x labels the columns of z and y labels its rows. With the
# 'xy'-indexed meshgrid, every row of x_p_np is identical and every column of y_p_np is identical,
# so the first row / first column carry the full axes.
fig: go.Figure = go.Figure(
    data=go.Contour(
        z=en_landscape,
        x=x_p_np[0, :],
        y=y_p_np[:, 0],
        colorscale='Viridis',
        contours=dict(
            coloring='heatmap',
            showlabels=True,
            labelfont=dict(size=12, color='white')
        )
    )
)
fig.update_layout(
    title="Energy Landscape of Gaussian Mixture",
    xaxis_title="X-axis",
    yaxis_title="Y-axis",
    xaxis=dict(scaleanchor="y", scaleratio=1),
    yaxis=dict(scaleanchor="x", scaleratio=1)
)
fig.show()

In [None]:
import plotly.graph_objects as go

# Stand-alone go.Contour example on a small hand-written 5x5 grid with non-uniformly spaced
# coordinates: x gives one value per column of z, y one value per row.
fig = go.Figure(
    go.Contour(
        x=[-9, -6, -5, -3, -1],  # horizontal axis
        y=[0, 1, 4, 5, 7],  # vertical axis
        z=[
            [10, 10.625, 12.5, 15.625, 20],
            [5.625, 6.25, 8.125, 11.25, 15.625],
            [2.5, 3.125, 5., 8.125, 12.5],
            [0.625, 1.25, 3.125, 6.25, 10.625],
            [0, 0.625, 2.5, 5.625, 10],
        ],
    )
)
fig.show()