In [1]:
from pathlib import Path

# Project directories.
# MODELS_DIR / EXPERIMENTS_DIR are relative to the current working directory
# (`Path().parent` is simply `Path(".")`, so the original spelling added nothing).
# BASE_DATA_DIR is absolute and anchored one level ABOVE the resolved working directory.
MODELS_DIR = Path("models")
EXPERIMENTS_DIR = Path("experiments")
BASE_DATA_DIR = Path(".").resolve().parent / ".cache" / "huggingface" / "datasets"

In [None]:
import sys

# Make the directory above the notebook importable so `dataset_loading` resolves.
# Guarded so re-running this cell does not keep appending duplicate sys.path entries.
_parent_dir = str(Path().resolve().parent)
if _parent_dir not in sys.path:
    sys.path.append(_parent_dir)

from dataset_loading import dataset

# NOTE(review): get_dataset appears to return a sequence whose first element maps
# split names ("train"/"test") to data — inferred from the indexing below; confirm.
mnist = dataset.get_dataset(name="mnist", preprocess=True, to_tensor=False, flatten=False, resize=28, class_limit=10)
train_mnist = mnist[0]["train"]
test_mnist = mnist[0]["test"]

### VAE LATENT SPACE ANALYSIS (UMAP)

In [None]:
import numpy as np
import pandas as pd
import torch

from utils.vae_net import VAE

# Choose the device once so the cell also runs on CPU-only machines.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Load the trained MNIST VAE (28x28 single-channel input, latent_dim=10).
ckpt = torch.load(MODELS_DIR / "5_epochs" / "mnist_head_10.pth", map_location=DEVICE)
vae_mnist = VAE(w=28, h=28, latent_dim=10, channels=1)

vae_mnist.load_state_dict(ckpt["hyper_state_dict"])
vae_mnist.to(DEVICE)
vae_mnist.eval()

# Encode a seeded random subset of every selected class from the TEST split:
# the goal is to inspect how the latent geometry generalises to unseen digits,
# not to check memorisation of the training set.
test_mnist_parquet = pd.DataFrame(test_mnist)
selected_classes = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
samples_per_class = 200  # or None for all

mus_mnist = []
labels_mnist = []

with torch.no_grad():
    for c in selected_classes:
        rows = test_mnist_parquet[test_mnist_parquet["label"] == c]
        # `None` means "use every row" (previously min(None, ...) raised TypeError).
        if samples_per_class is not None:
            rows = rows.sample(n=min(samples_per_class, len(rows)), random_state=42)
        if rows.empty:
            continue  # class not present in this split; np.stack([]) would raise

        imgs = np.stack(rows["image"].values)
        # Add a channel axis and scale to [0, 1].
        # NOTE(review): /255 assumes raw 0-255 pixel values after preprocessing — confirm.
        imgs = torch.tensor(imgs).unsqueeze(1).float().to(DEVICE) / 255.0

        # The VAE forward returns a 3-tuple; the second element is used as the latent mean.
        _, mu_mnist, _ = vae_mnist(imgs)
        mus_mnist.append(mu_mnist.cpu().numpy())
        labels_mnist.extend([c] * mus_mnist[-1].shape[0])

mus_mnist = np.concatenate(mus_mnist, axis=0)
labels_mnist = np.array(labels_mnist)



In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import umap

# Project the VAE latent means to 2-D with UMAP (seeded for reproducibility).
reducer = umap.UMAP(n_neighbors=20, min_dist=0.1, random_state=42)
embedding = reducer.fit_transform(mus_mnist)  # mus_mnist is already an ndarray

fig, ax = plt.subplots(figsize=(10, 8))
# legend="full": with a numeric hue seaborn otherwise shows only a subset of classes.
sns.scatterplot(
    x=embedding[:, 0],
    y=embedding[:, 1],
    hue=labels_mnist,
    palette="tab10",
    legend="full",
    ax=ax,
)
ax.set(title="VAE MNIST Latent Space UMAP Projection", xlabel="UMAP-1", ylabel="UMAP-2")
legend = ax.get_legend()
if legend is not None:
    legend.set_title("Class")
plt.show()

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import gaussian_kde

# Configuration
classes = np.unique(labels_mnist)
bandwidth = None  # let gaussian_kde choose (Scott's rule)
grid_res = 200
min_points_for_kde = 10  # KDE is not stable with very few points

# One colour per class, shared by the scatter points and the contours so the two can
# be matched visually (plain contour() would colour by density level instead).
class_cmap = plt.get_cmap("tab10")
class_colors = {c: class_cmap(i % class_cmap.N) for i, c in enumerate(classes)}

# Evaluate every class's KDE on one common grid spanning the data extent of the
# 2-D embedding, so the per-class densities share a coordinate frame.
x_min, x_max = embedding[:, 0].min(), embedding[:, 0].max()
y_min, y_max = embedding[:, 1].min(), embedding[:, 1].max()

xx, yy = np.meshgrid(
    np.linspace(x_min, x_max, grid_res),
    np.linspace(y_min, y_max, grid_res),
)
grid = np.vstack([xx.ravel(), yy.ravel()])

fig, ax = plt.subplots(figsize=(10, 8))

# Scatter points (faint, for reference)
ax.scatter(
    embedding[:, 0],
    embedding[:, 1],
    c=[class_colors[lab] for lab in labels_mnist],
    s=6,
    alpha=0.25,
)

# KDE contours per class
for c in classes:
    points = embedding[labels_mnist == c].T  # shape (2, Nc)

    if points.shape[1] < min_points_for_kde:
        continue

    try:
        kde = gaussian_kde(points, bw_method=bandwidth)
    except np.linalg.LinAlgError:
        # Singular covariance (e.g. collinear points): no 2-D density to draw.
        continue
    density = kde(grid).reshape(xx.shape)

    ax.contour(
        xx, yy, density,
        levels=5,
        colors=[class_colors[c]],
        linewidths=1,
        alpha=0.9,
    )

ax.set_title("Class-conditional KDE over VAE latent UMAP")
ax.set_xlabel("UMAP-1")
ax.set_ylabel("UMAP-2")
ax.grid(True)
plt.show()


In [None]:
from pathlib import Path

import numpy as np
import pandas as pd
import torch

# Per-dataset inputs for the multi-panel UMAP comparison: the data split, the VAE
# object to load weights into, the checkpoint file name and the class count.
# NOTE(review): test_kmnist/vae_kmnist, test_fashion_mnist/vae_fashion_mnist,
# train_hebrew_chars/vae_hebrew_chars and test_math_shapes/vae_math_shapes are not
# created in any cell shown here — on a fresh kernel this raises NameError unless they
# are defined elsewhere. TODO confirm.
# NOTE(review): hebrew_chars uses its TRAIN split while all other entries use TEST.
# NOTE(review): "num_classes" is not read by the plotting code that follows.
datasets_cfg = dict(
    mnist=dict(
        data=test_mnist,
        vae=vae_mnist,
        ckpt="mnist_head_10.pth",
        num_classes=10,
    ),
    kmnist=dict(
        data=test_kmnist,
        vae=vae_kmnist,
        ckpt="kmnist_head_10.pth",
        num_classes=10,
    ),
    fashion_mnist=dict(
        data=test_fashion_mnist,
        vae=vae_fashion_mnist,
        ckpt="fashion_mnist_head_10.pth",
        num_classes=10,
    ),
    hebrew_chars=dict(
        data=train_hebrew_chars,
        vae=vae_hebrew_chars,
        ckpt="hebrew_chars_head_10.pth",
        num_classes=20,
    ),
    math_shapes=dict(
        data=test_math_shapes,
        vae=vae_math_shapes,
        ckpt="math_shapes_head_10.pth",
        num_classes=8,
    ),
)


In [None]:
def umap_for_dataset(
    data,
    vae,
    ckpt_path,
    samples_per_class=500,
    device="cuda",
    umap_kwargs=None,
):
    """Encode a labelled image dataset with a VAE and embed the latent means with UMAP.

    Side effect: loads ``ckpt["hyper_state_dict"]`` from ``ckpt_path`` INTO ``vae``
    (overwriting its current weights), moves it to ``device`` and puts it in eval mode.

    Args:
        data: anything ``pd.DataFrame`` accepts, with "image" and "label" columns.
        vae: model whose forward returns a 3-tuple; the second element is used as the
            latent mean (presumably ``(recon, mu, logvar)`` — TODO confirm).
        ckpt_path: path to a checkpoint dict containing "hyper_state_dict".
        samples_per_class: maximum images drawn per class (seeded sample);
            ``None`` uses every row.
        device: torch device for the checkpoint tensors and the model.
        umap_kwargs: extra ``umap.UMAP`` arguments. They override the defaults
            ``n_neighbors=20, min_dist=0.1, random_state=42``.

    Returns:
        ``(embedding, labels)``: the UMAP embedding (``n_components`` columns, 2 by
        default) and a 1-D array of the matching per-row class labels.

    Raises:
        ValueError: if ``data`` has no rows.
    """
    # Local import so the function does not depend on another cell having imported umap.
    import umap

    df = pd.DataFrame(data)
    if df.empty:
        raise ValueError("umap_for_dataset: `data` contains no rows")

    ckpt = torch.load(ckpt_path, map_location=device)
    vae.load_state_dict(ckpt["hyper_state_dict"])
    vae.to(device)
    vae.eval()

    mus = []
    labels = []

    with torch.no_grad():
        for c in sorted(df["label"].unique()):
            rows = df[df["label"] == c]

            if samples_per_class is not None:
                rows = rows.sample(
                    n=min(samples_per_class, len(rows)),
                    random_state=42,
                )

            # Add a channel axis (assumes each image is 2-D) and scale to [0, 1].
            # NOTE(review): /255 assumes raw 0-255 pixels for every dataset — confirm.
            imgs = np.stack(rows["image"].values)
            imgs = (
                torch.tensor(imgs)
                .unsqueeze(1)
                .float()
                .to(device) / 255.0
            )

            _, mu, _ = vae(imgs)
            mu = mu.cpu().numpy()

            mus.append(mu)
            labels.extend([c] * mu.shape[0])

    mus = np.concatenate(mus, axis=0)
    labels = np.array(labels)

    # Merge instead of passing both: caller kwargs such as n_neighbors previously
    # raised "got multiple values for keyword argument".
    umap_params = {"n_neighbors": 20, "min_dist": 0.1, "random_state": 42}
    umap_params.update(umap_kwargs or {})

    embedding = umap.UMAP(**umap_params).fit_transform(mus)

    return embedding, labels


In [None]:
import numpy as np
from matplotlib.lines import Line2D

# One panel per configured dataset (previously hard-coded to 5 columns).
n_panels = len(datasets_cfg)
fig, axes = plt.subplots(1, n_panels, figsize=(4 * n_panels, 4), squeeze=False)

# NOTE(review): checkpoints here come from <parent of cwd>/models, while the single-MNIST
# analysis above loads MODELS_DIR / "5_epochs" — confirm which directory is intended.
# umap_for_dataset loads these weights INTO the VAE objects in datasets_cfg (e.g. vae_mnist),
# overwriting whatever weights they held before this cell ran.
base_ckpt_dir = Path().resolve().parent / "models"

cmap = plt.get_cmap("tab20")

for ax, (name, cfg) in zip(axes[0], datasets_cfg.items()):
    emb, labels = umap_for_dataset(
        data=cfg["data"],
        vae=cfg["vae"],
        ckpt_path=base_ckpt_dir / cfg["ckpt"],
        samples_per_class=200,
    )

    # One normalisation shared by the scatter and the legend, so legend swatches match
    # point colours even when labels do not start at 0. With a single class, Normalize
    # maps everything to 0 instead of dividing by zero.
    norm = plt.Normalize(vmin=labels.min(), vmax=labels.max())

    ax.scatter(
        emb[:, 0],
        emb[:, 1],
        c=labels,
        s=5,
        cmap=cmap,
        norm=norm,
    )

    ax.set_title(name)
    ax.set_xticks([])
    ax.set_yticks([])

    # ---- per-axis legend ----
    unique_labels = np.unique(labels)

    legend_handles = [
        Line2D(
            [0], [0],
            marker="o",
            linestyle="",
            markersize=6,
            markerfacecolor=cmap(norm(c)),
            markeredgecolor="none",
            label=str(c),
        )
        for c in unique_labels
    ]

    ax.legend(
        handles=legend_handles,
        title="Class",
        loc="best",
        fontsize=8,
        title_fontsize=9,
        frameon=False,
    )

plt.tight_layout()
plt.show()


### TARGET NETWORK EMBEDDING OVER TIME

In [None]:
# Load the per-epoch snapshots of the target-network embeddings. Each value is
# indexed below by "embeddings" and "labels"; the keys are treated as epochs.
# NOTE(review): torch.load unpickles the file — only load files you trust.
data = torch.load("target_embeddings_over_time.pt", map_location="cpu")

# Chronological order of the recorded epochs.
epochs = list(data.keys())
epochs.sort()

print(f"Loaded embeddings for {len(epochs)} epochs:")
print(epochs[:5], "...", epochs[-5:])


In [None]:

# Stack the embeddings of every epoch so the animation can use ONE fixed set of axis
# limits; otherwise each frame would autoscale and hide the real motion.
per_epoch_embs = [data[epoch]["embeddings"].numpy() for epoch in epochs]
all_embs = np.concatenate(per_epoch_embs, axis=0)

# Bounding box of the first two embedding dimensions, padded on each side by
# `margin` (a fraction of the span).
margin = 0.05
x_min, x_max = all_embs[:, 0].min(), all_embs[:, 0].max()
y_min, y_max = all_embs[:, 1].min(), all_embs[:, 1].max()

x_pad, y_pad = (x_max - x_min) * margin, (y_max - y_min) * margin

X_LIM = (x_min - x_pad, x_max + x_pad)
Y_LIM = (y_min - y_pad, y_max + y_pad)


In [None]:
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.lines import Line2D

# Collect labels across ALL epochs: the animation loops over unique_labels, so a class
# missing from the first epoch would otherwise never be drawn.
unique_labels = np.unique(
    np.concatenate([data[e]["labels"].numpy() for e in epochs])
)

# tab10 for up to 10 classes, tab20 beyond. Colours are assigned by each label's rank,
# so labels >= 10 no longer all collapse onto tab10's last colour, and float labels
# are not misread as [0, 1] colormap positions. For labels 0..9 the result equals cmap(label).
cmap = plt.get_cmap("tab10" if len(unique_labels) <= 10 else "tab20")
label_to_color = {lab: cmap(i % cmap.N) for i, lab in enumerate(unique_labels)}

# Static legend handles, re-added to the axes on every animation frame.
legend_elements = [
    Line2D(
        [0], [0],
        marker='o',
        color='w',
        label=f"Class {lab}",
        markerfacecolor=label_to_color[lab],
        markersize=6
    )
    for lab in unique_labels
]


In [None]:
from matplotlib import animation

fig, ax = plt.subplots(figsize=(6, 5))

def update(i):
    """Redraw animation frame ``i``: the embedding snapshot of ``epochs[i]``, coloured by class."""
    ax.clear()

    epoch = epochs[i]
    emb = data[epoch]["embeddings"].numpy()
    labels = data[epoch]["labels"].numpy()

    for lab in unique_labels:
        mask = labels == lab
        ax.scatter(
            emb[mask, 0],
            emb[mask, 1],
            s=8,
            color=label_to_color[lab],
            alpha=0.8
        )

    # Fixed global limits so frame-to-frame motion reflects geometry, not rescaling.
    ax.set_xlim(*X_LIM)
    ax.set_ylim(*Y_LIM)
    ax.set_aspect("equal", adjustable="box")

    # En dash (the original literal was mis-encoded as "â€“").
    ax.set_title(f"Target embedding geometry – epoch {epoch}")
    ax.set_xlabel("dim 1")
    ax.set_ylabel("dim 2")
    ax.legend(handles=legend_elements, loc="best", fontsize=8)
    ax.grid(True)

ani = animation.FuncAnimation(
    fig,
    update,
    frames=len(epochs),
    interval=100,
    blit=False
)

try:
    # Without ffmpeg, matplotlib silently falls back to Pillow, which then fails on .mp4;
    # fail early with a clear message instead.
    if not animation.writers.is_available("ffmpeg"):
        raise RuntimeError("ffmpeg is not available; install it to save the .mp4 animation")
    ani.save(
        "geometry_evolution_labeled.mp4",
        fps=10,
        dpi=150,
        writer="ffmpeg"
    )
finally:
    # Close the figure even if saving fails, so it doesn't leak or render inline.
    plt.close(fig)
