In [None]:
import numpy as np
import torch
from tqdm import tqdm
from nlnas.separability import gdv
from sklearn.manifold import TSNE

n_samples, n_dims = 1024, 32
n_iter, n_save = 70, 5
n_classes = 4  # number of classes; samples are split evenly between them
seed = 0  # fixes the initial point cloud so Restart & Run All is reproducible

torch.manual_seed(seed)

# Every class is an independent draw from the *same* standard normal, so the
# classes start out fully overlapping. The training loop then optimises
# gdv(v, y) directly on these points (presumably driving the classes apart).
n_per_class = n_samples // n_classes
v = torch.concat(
    [torch.normal(0, 1, (n_per_class, n_dims)) for _ in range(n_classes)]
)
v.requires_grad_(True)  # v itself is the parameter being optimised

# Labels 0, ..., 0, 1, ..., 1, ..., n_classes - 1, aligned with the rows of v.
# Kept as float32, matching the torch.zeros / torch.ones construction.
y = torch.concat(
    [torch.full((n_per_class,), float(c)) for c in range(n_classes)]
)

opt = torch.optim.Adam([v], lr=1e-3)

# 0-based iterations after which the embedding is recorded. linspace includes
# n_iter itself, which range(n_iter) never reaches; the final state is
# therefore recorded explicitly after the loop.
r = np.linspace(0, n_iter, n_save, dtype=int)
snapshot_iters = {int(i) for i in r}

# Key -1 is the embedding before any optimiser step; key i is the embedding
# after step i. Snapshots MUST be copies: Tensor.numpy() shares storage with
# the tensor, and opt.step() updates v in place, so un-copied snapshots would
# all silently equal the final embedding.
vs: dict[int, np.ndarray] = {-1: v.detach().clone().numpy()}

progress = tqdm(range(n_iter))
for i in progress:
    opt.zero_grad()
    loss = gdv(v, y)
    loss.backward()
    opt.step()
    progress.set_postfix(
        {
            "gdv": float(loss),
            "grad": float(torch.linalg.norm(v.grad)),
        }
    )
    if i in snapshot_iters:
        vs[i] = v.detach().clone().numpy()
vs[n_iter - 1] = v.detach().clone().numpy()

# Project every snapshot to 2-D for plotting.
# - t-SNE is stochastic, so random_state is fixed for reproducible figures.
# - The guard is per array (not on n_dims), so re-running this cell does not
#   re-project snapshots that are already 2-D.
# - Each snapshot is embedded independently: cluster positions/orientations
#   are NOT comparable across panels, only within-panel class separation is.
for epoch, emb in tqdm(list(vs.items())):
    if emb.shape[1] > 2:
        vs[epoch] = TSNE(random_state=0).fit_transform(emb)

In [None]:
import bokeh.io
import bokeh.plotting as bk

from nlnas.plotting import class_scatter

bokeh.io.output_notebook()

# One small panel per stored snapshot, in insertion order. Key -1 is the
# initial embedding and key i the embedding after optimiser step i, hence the
# "+1" in the title (so the initial state is shown as "Iteration 0").
labels = y.numpy()  # loop-invariant, converted once
figures = []
for epoch, emb in vs.items():
    # Dedicated loop name: reusing `v` here would overwrite the trainable
    # tensor from the previous cell with a NumPy array (hidden-state hazard).
    fig = bk.figure(
        title=f"Iteration {epoch+1}",
        toolbar_location=None,
        height=200,
        width=200,
    )
    fig.grid.visible, fig.axis.visible = False, False
    class_scatter(fig, emb, labels, palette="viridis")
    figures.append(fig)

bk.show(bk.row(figures))