In [None]:
from math import sqrt

import numpy as np
import torch
from sklearn.manifold import TSNE
from tqdm import tqdm

from nlnas.separability import gdv, gr_dist, label_variation, mean_gr_dist

# --- Experiment configuration ------------------------------------------------
# n_samples points of dimension n_dims, split evenly over n_classes labels.
n_samples, n_dims = 512, 1024
# n_iter optimisation steps; n_save evenly spaced snapshot marks (see `r`).
n_iter, n_save = 70, 5
n_classes = 4
n_per_class = n_samples // n_classes
seed = 0

# Seed torch's global RNG so the initial point cloud (and therefore the whole
# run) is identical after a kernel restart.
torch.manual_seed(seed)

# Every class starts from the same N(0, 1) distribution; the blocks are drawn
# one after another and stacked in label order.
w = torch.concat(
    [torch.normal(0, 1, (n_per_class, n_dims)) for _ in range(n_classes)]
)
w.requires_grad = True  # `w` itself is the quantity being optimised

# Labels 0.0, 1.0, 2.0, 3.0 (float32), one constant block per class, aligned
# with the row blocks of `w`.
y = torch.concat(
    [torch.full((n_per_class,), float(label)) for label in range(n_classes)]
)

opt = torch.optim.Adam([w], lr=1e-2)

# Iteration indices at which a snapshot of `w` is stored. The grid includes
# the endpoint n_iter itself, which range(n_iter) never reaches.
r = np.linspace(0, n_iter, n_save, dtype=int)
# Snapshots keyed by iteration index; key -1 holds the state before any step.
vs: dict[int, np.ndarray] = {-1: w.detach().clone().numpy()}

# Optimise `w` directly: Adam minimises the negated `mean_gr_dist(w, y)`,
# i.e. it pushes that quantity upwards. `gdv` is imported as well but is not
# used as the objective here.
progress = tqdm(range(n_iter))
for i in progress:
    opt.zero_grad()
    loss = -mean_gr_dist(w, y)
    loss.backward()
    opt.step()

    # Live readout in the progress bar: current loss and gradient norm.
    postfix = {
        "sep": loss.item(),
        "grad": torch.linalg.norm(w.grad).item(),
    }
    progress.set_postfix(postfix)

    # Store a copy of the post-step state at the pre-selected iterations.
    if np.any(r == i):
        vs[i] = w.detach().clone().numpy()

# Always keep the final state, whether or not it fell on a snapshot mark.
vs[n_iter - 1] = w.detach().clone().numpy()

# Reduce every stored snapshot to 2-D so it can be drawn as a scatter plot.
# Each snapshot is embedded independently with its own t-SNE fit; a fixed
# random_state keeps the embeddings reproducible from run to run.
ws: dict[int, np.ndarray]
if n_dims > 2:
    ws = {
        epoch: TSNE(random_state=0).fit_transform(snapshot)
        for epoch, snapshot in tqdm(vs.items())
    }
else:
    # Already low-dimensional: use the raw snapshots as they are.
    ws = vs

In [None]:
import bokeh.io
import bokeh.plotting as bk

# Send Bokeh output to the notebook so that `bk.show` renders inline.
bokeh.io.output_notebook()

from nlnas.plotting import class_scatter

# One small scatter panel per stored snapshot, shown side by side in a row.
labels = y.numpy()  # identical for every panel, so convert once

figures = []
# The loop variable is deliberately not named `w`: that name holds the tensor
# being optimised in the previous cell and must not be overwritten here.
for epoch, embedding in ws.items():
    # Snapshot keys are -1 for the initial state and the loop index otherwise,
    # so epoch + 1 is the number of optimisation steps already applied.
    figure = bk.figure(title=f"Iteration {epoch+1}", toolbar_location=None)
    figure.height, figure.width = 200, 200
    figure.grid.visible, figure.axis.visible = False, False
    class_scatter(
        figure,
        embedding,
        labels,
        palette="viridis",
    )
    figures.append(figure)

bk.show(bk.row(figures))