In [None]:
# %% [markdown]
# ## Unsupervised Adversarial Autoencoder (AAE) – MNIST
# Stand-alone version (no inheritance), evaluated by clustering accuracy.

# %%
import torch, torch.nn as nn
from pathlib import Path
import matplotlib.pyplot as plt

from unsupervised import (
    UAAEConfig,
    UnsupervisedAdversarialAutoencoder,
)
from dataloader import load_mnist_data          # your helper


In [None]:
# Batch size shared by every loader, plus the sizes of the two held-out splits.
BATCH_SIZE = 100
VAL_SAMPLES = 5000
TEST_SAMPLES = 5000

# The helper hands back ready-made loaders; only their underlying datasets are
# kept so that the loaders can be rebuilt below with explicit settings.
# (num_samples=-1 presumably means "no subsampling" — verify against the helper.)
train_loader_full, test_loader_full = load_mnist_data(
    batch_size=BATCH_SIZE,
    num_samples=-1,
)
full_train_ds = train_loader_full.dataset   # expected: 60,000 samples
full_test_ds = test_loader_full.dataset     # expected: 10,000 samples

# ── Held-out data: VAL_SAMPLES for validation, TEST_SAMPLES for testing ──
# A dedicated, seeded generator keeps the split identical across runs.
# random_split requires the two lengths to add up to len(full_test_ds).
val_ds, test_ds = torch.utils.data.random_split(
    dataset=full_test_ds,
    lengths=[VAL_SAMPLES, TEST_SAMPLES],
    generator=torch.Generator().manual_seed(42),
)

# Training uses the whole training set, reshuffled every epoch; a trailing
# incomplete batch is dropped so every training batch has BATCH_SIZE items.
train_loader = torch.utils.data.DataLoader(
    dataset=full_train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
)

# Evaluation loaders keep a fixed order.
val_loader = torch.utils.data.DataLoader(
    dataset=val_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

print(f"Train samples: {len(full_train_ds)} | Val samples: {len(val_ds)} | Test samples: {len(test_ds)}")


In [None]:
# %%
# Build the model configuration and instantiate the autoencoder.
# input_dim 784 matches the flattened 28x28 images that later cells reshape
# decoder output into; latent_dim_categorical is the number of clusters (the
# helper functions below read it as K), latent_dim_style the width of the
# continuous style code.
cfg = UAAEConfig(
    input_dim   = 784,
    ae_hidden   = 3000,
    disc_hidden = 3000,
    latent_dim_categorical = 16,
    latent_dim_style = 5,
    use_decoder_sigmoid = True,
)
# The device is exposed by the config object itself.
print(f'Device found: {cfg.device}')
model = UnsupervisedAdversarialAutoencoder(cfg)
# Optional: uncomment to start from a saved checkpoint instead of from scratch.
# model.load_weights('runs/unsup_aae/weights_epoch_50/weights')
# Print the model's repr for a quick architecture check.
print(model)


In [None]:
# Train the model. Logs and checkpoints presumably go under result_folder —
# the plotting cell further down reads runs/unsup_aae/batch_log.csv.
# NOTE(review): no global random seed is set before training, so runs are not
# bit-for-bit reproducible.
model.fit(
    train_loader  = train_loader,
    val_loader  = val_loader,     # gives val accuracy each epoch
    test_loader = test_loader,
    epochs      = 500, # 1500
    prior_std   = 1.0,
    result_folder     = Path("runs/unsup_aae"),
)

In [None]:
# %%
# Final clustering accuracy. The result is formatted as a percentage below, so
# it is expected to be a fraction in [0, 1].
# NOTE(review): presumably val_loader is used to map clusters to digit labels
# and test_loader to score them — confirm in evaluate_clustering.
test_acc = model.evaluate_clustering(val_loader, test_loader)
print(f"Test clustering accuracy: {test_acc:.2%}")


In [None]:
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.ticker import ScalarFormatter

# ── 1. Read CSV ───────────────────────────────────────────────────────────────
# Per-batch training log; the frame is re-read on every run of this cell, so
# the in-place column edits below are safe to repeat.
df = pd.read_csv("runs/unsup_aae/batch_log.csv")
df["batch"] = df["batch"].astype(int)
df["epoch"] = df["epoch"].astype(int)

# ── 2. Create Iteration Axis ──────────────────────────────────────────────────
# Batches-per-epoch is read from every epoch except the last one, because the
# last epoch may have been logged only partially (e.g. interrupted training).
if df["epoch"].nunique() > 1:
    num_batches_per_epoch = df.loc[df["epoch"] < df["epoch"].max(), "batch"].max()
else:
    num_batches_per_epoch = df["batch"].max()

# NOTE(review): this formula assumes both "epoch" and "batch" are 1-based and
# that every batch is logged — confirm against the code writing batch_log.csv.
df["iteration"] = (df["epoch"] - 1) * num_batches_per_epoch + df["batch"]
df = df.sort_values("iteration").reset_index(drop=True)

# grab the final iteration value
max_iter = df["iteration"].max()

# ── 3. Define loss groups and colors ─────────────────────────────────────────
# One figure per group; discriminator/generator pairs share a figure.
groups = {
    "Reconstruction Loss"       : ["recon_loss"],
    "Categorical Disc/Gen Loss" : ["disc_cat_loss", "gen_cat_loss"],
    "Style Disc/Gen Loss"       : ["disc_style_loss", "gen_style_loss"],
}

colors = {
    "recon_loss"      : "green",
    "disc_cat_loss"   : "blue",
    "gen_cat_loss"    : "orange",
    "disc_style_loss" : "blue",
    "gen_style_loss"  : "orange",
}

# ── 4. Plot each group in its own figure ──────────────────────────────────────
for title, cols in groups.items():
    fig, ax = plt.subplots(figsize=(10, 4))

    # Columns missing from the log are skipped rather than raising.
    plotted = [col for col in cols if col in df.columns]
    for col in plotted:
        ax.plot(df["iteration"], df[col],
                label=col,
                color=colors[col],
                linewidth=1.5,
                alpha=0.8)

    # labels and title
    ax.set_title(title)
    ax.set_ylabel("Loss")
    ax.set_xlabel("Iteration")

    # single tick at the final iteration; its label comes from the formatter
    # installed just below (explicit tick labels would be replaced by that
    # formatter anyway, so none are set here)
    ax.set_xticks([max_iter])

    # x axis: plain numbers (no offset, no scientific notation)
    fmt = ScalarFormatter(useOffset=False)
    fmt.set_scientific(False)
    ax.xaxis.set_major_formatter(fmt)
    # y axis: no offset
    ax.yaxis.set_major_formatter(ScalarFormatter(useOffset=False))

    ax.grid(True, linestyle="--", linewidth=0.5)
    # A legend without any labelled line only triggers a matplotlib warning.
    if plotted:
        ax.legend()
    fig.tight_layout()

plt.show()


In [None]:
# %%
import torch.nn.functional as F

@torch.no_grad()
def generate_heads(self, n: int, prior_std: float = 1.0):
    """Decode ``n`` randomly chosen cluster codes with an all-zero style code.

    Args:
        self: model-like object providing ``cfg.latent_dim_categorical``,
            ``cfg.latent_dim_style``, ``device`` and a callable ``decoder``.
        n: number of samples to decode.
        prior_std: accepted for signature parity with ``generate`` but not
            used here, because the style part is fixed at zero.

    Returns:
        Whatever ``self.decoder`` returns for the ``n`` latent vectors
        ``[one_hot(cluster) | zeros]``.
    """
    num_clusters = self.cfg.latent_dim_categorical
    # Uniformly random cluster index per sample, turned into a one-hot row.
    cluster_ids = torch.randint(0, num_clusters, (n,), device=self.device)
    one_hot_codes = F.one_hot(cluster_ids, num_classes=num_clusters).float()
    # Zero style isolates what the categorical code alone decodes to.
    neutral_style = torch.zeros(n, self.cfg.latent_dim_style, device=self.device)
    return self.decoder(torch.cat([one_hot_codes, neutral_style], dim=1))
        
# Decode 16 zero-style samples and reshape them into 28x28 images.
# generate_heads already runs under torch.no_grad(), so no extra context
# manager is needed around the call.
samples = generate_heads(model, 16)
samples = samples.cpu().view(-1, 28, 28)

# Show the samples on a 4x4 grid without axis decorations.
fig, axes = plt.subplots(nrows=4, ncols=4, figsize=(4, 4))
for ax, img in zip(axes.flat, samples):
    ax.imshow(img, cmap="gray")
    ax.axis("off")
plt.tight_layout()


In [None]:
def print_cluster_counts(model, test_loader):
    """Print how many samples of ``test_loader`` fall into each cluster.

    Args:
        model: object providing ``eval()``, ``device``,
            ``predict_clusters(x)`` (one cluster id per sample) and
            ``cfg.latent_dim_categorical`` (the number of clusters).
        test_loader: iterable of ``(inputs, labels)`` batches; the labels are
            ignored.

    The model is switched to eval mode as a side effect. Nothing is returned.
    """
    from collections import Counter

    model.eval()
    tally = Counter()

    # Accumulate the assignments batch by batch, without tracking gradients.
    with torch.no_grad():
        for batch, _ in test_loader:
            cluster_ids = model.predict_clusters(batch.to(model.device))
            tally.update(cluster_ids.cpu().tolist())

    print("Cluster assignment counts (on test set):")
    # Every cluster is listed, including those that received no sample
    # (a Counter yields 0 for missing keys).
    for cluster_id in range(model.cfg.latent_dim_categorical):
        print(f"  Cluster {cluster_id}: {tally[cluster_id]} samples")

# The function labels its output "on test set", so feed it the held-out test
# split rather than the (shuffled) training loader.
print_cluster_counts(model, test_loader)

In [None]:
import torch
import matplotlib.pyplot as plt


# ------------------------------------------------------------
# 1.  map each categorical cluster → the digit label of the validation
#     sample that is assigned to it with the highest probability
# ------------------------------------------------------------
@torch.no_grad()
def assign_cluster_labels(model, val_loader):
    """Give each cluster the label of its single most confident sample.

    For every sample the cluster with the highest categorical score is taken
    as its assignment. Each cluster then receives the ground-truth label of
    the sample that was assigned to it with the highest score. On ties the
    sample seen first wins.

    Args:
        model: object providing ``eval()``, ``device``,
            ``cfg.latent_dim_categorical`` and ``forward_encoder(x)``, whose
            first return value holds one score per cluster for each sample.
            The scores are expected to be strictly positive (e.g. softmax
            probabilities): a score that is not above 0 never claims a cluster.
        val_loader: iterable of ``(inputs, labels)`` batches with integer
            labels.

    Returns:
        dict mapping cluster id -> label. Clusters that no sample was
        assigned to are left out.

    The model is switched to eval mode as a side effect.
    """
    model.eval()
    K = model.cfg.latent_dim_categorical
    # Running best label / best score per cluster, kept on the host so the
    # inner loop needs no per-sample device synchronisation.
    best_label_for = [-1] * K
    best_prob_for = [0.0] * K

    for x, y in val_loader:                  # y are ground-truth digits
        z_cat, _ = model.forward_encoder(x.to(model.device))
        probs, preds = z_cat.max(dim=1)      # highest score + its cluster id
        # One host transfer per tensor and batch instead of three per sample.
        for cid, prob, label in zip(preds.tolist(), probs.tolist(), y.tolist()):
            if prob > best_prob_for[cid]:
                best_prob_for[cid] = prob
                best_label_for[cid] = int(label)

    # Plain Python dict so it is easy to use from plotting code.
    return {cid: label for cid, label in enumerate(best_label_for) if label >= 0}


# ------------------------------------------------------------
# 2.  plot a small grid of images per cluster with the inferred label
# ------------------------------------------------------------
def plot_clusters_with_labels(model, test_loader, val_loader, num_per_cluster=5):
    """Show example images for each cluster, titled with its inferred label.

    Args:
        model: object providing ``eval()``, ``device``,
            ``cfg.latent_dim_categorical`` and ``predict_clusters(x)``
            (one cluster id per sample).
        test_loader: iterable of ``(inputs, labels)`` batches from which the
            example images are taken; the labels are ignored.
        val_loader: labelled batches passed to ``assign_cluster_labels`` to
            infer one label per cluster.
        num_per_cluster: maximum number of example images per cluster
            (= number of columns in the figure).

    One figure row is drawn per non-empty cluster. Clusters without an
    inferred label are titled ``label ?``. Nothing is returned.
    """
    cluster_labels = assign_cluster_labels(model, val_loader)

    model.eval()
    K = model.cfg.latent_dim_categorical
    picked = {k: [] for k in range(K)}

    with torch.no_grad():
        for x, _ in test_loader:
            x = x.to(model.device)
            # One host transfer per batch instead of one .item() per sample.
            preds = model.predict_clusters(x).tolist()
            for img, cid in zip(x, preds):
                if len(picked[cid]) < num_per_cluster:
                    picked[cid].append(img.cpu())
            # Stop early once every cluster has enough examples.
            if all(len(p) >= num_per_cluster for p in picked.values()):
                break

    # keep only non-empty clusters
    clusters = [cid for cid, imgs in picked.items() if imgs]
    if not clusters:
        # plt.subplots cannot create a grid with zero rows.
        print("No samples were assigned to any cluster; nothing to plot.")
        return

    rows = len(clusters)
    # squeeze=False keeps `axes` two-dimensional even for a single row or a
    # single column, so axes[r][c] is always valid.
    fig, axes = plt.subplots(
        rows,
        num_per_cluster,
        figsize=(num_per_cluster * 2, rows * 2),
        squeeze=False,
    )

    for r, cid in enumerate(clusters):
        label = cluster_labels.get(cid, "?")
        for c in range(num_per_cluster):
            ax = axes[r][c]
            # Cells beyond the number of collected images stay blank.
            if c < len(picked[cid]):
                img = picked[cid][c].reshape(28, 28)
                ax.imshow(img, cmap='gray')
            ax.axis('off')
            if c == 0:
                ax.set_title(f"label {label}", fontsize=10)

    plt.tight_layout()
    plt.show()


# NOTE(review): the shuffled train_loader is passed as the `test_loader`
# argument, so the example images come from training data and change from run
# to run — confirm this is intended (test_loader would give a fixed, held-out
# selection).
plot_clusters_with_labels(model, train_loader, val_loader)



In [None]:
# %%
import torch.nn.functional as F

@torch.no_grad()
def generate(self, n: int, prior_std: float = 1.0):
    """Decode ``n`` latent vectors sampled from the prior.

    Each latent vector is a one-hot code for a uniformly random cluster
    followed by a Gaussian style code scaled by ``prior_std``.

    Args:
        self: model-like object providing ``cfg.latent_dim_categorical``,
            ``cfg.latent_dim_style``, ``device`` and a callable ``decoder``.
        n: number of samples to decode.
        prior_std: standard deviation applied to the style noise.

    Returns:
        Whatever ``self.decoder`` returns for the ``n`` sampled latents.
    """
    num_clusters = self.cfg.latent_dim_categorical
    # The cluster indices are drawn first, the style noise second.
    cluster_ids = torch.randint(0, num_clusters, (n,), device=self.device)
    latent_parts = [
        F.one_hot(cluster_ids, num_classes=num_clusters).float(),
        torch.randn(n, self.cfg.latent_dim_style, device=self.device) * prior_std,
    ]
    return self.decoder(torch.cat(latent_parts, dim=1))
        

# Decode 16 samples drawn from the prior and reshape them into 28x28 images.
samples = generate(model, 16)
samples = samples.cpu().view(-1, 28, 28)

# Show the samples on a 4x4 grid without axis decorations.
fig, axes = plt.subplots(nrows=4, ncols=4, figsize=(4, 4))
for ax, img in zip(axes.flat, samples):
    ax.imshow(img, cmap="gray")
    ax.axis("off")
plt.tight_layout()
