# VQ-VAE Evaluation Notebook


This notebook contains the evaluation of the trained VQ-VAE model.

In [None]:
import wandb
import numpy as np
import pandas as pd

## Validation

In [None]:
# W&B history keys of the per-epoch validation losses to fetch.
# NOTE(review): the key really is spelled "reconstructon" here; presumably it
# mirrors the name the metric was logged under, so confirm against the run
# before changing it.
keys = [
    'val_epoch_codebook_loss',
    'val_epoch_reconstructon_loss',
    'val_epoch_lpips_loss',
]

# Human-readable panel titles, index-aligned with `keys`.
metric_names = [
    'Codebook loss',
    'Reconstruction loss',
    'LPIPS loss',
]

api = wandb.Api(timeout=60)

run = api.run("simonluder/MSE_P9_LDM/uedmx2jh")

# Request only the selected validation metrics from the run history.
history = run.history(keys=keys)

In [None]:
# Uncomment to list the validation metrics available in the run summary:
# [metric for metric in run.summary.keys() if "val_" in metric]

In [None]:
import matplotlib.pyplot as plt

# One panel per validation loss; the panels share the step axis.
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(15, 4), sharex=True)

steps = history["_step"]
for ax, key, name in zip(axes, keys, metric_names):
    ax.plot(steps, history[key])
    ax.set(title=name, xlabel="Step", ylabel="Value")

fig.tight_layout()
# plt.suptitle("Validation losses")
plt.show()



## Testing

In [None]:
# Locations of the two dataset tables (basic test set and isolated set).
basic_test_csv = r"Z:\simon_luder\Data_Setup\Pollen_Datasets\data\final\poleno\basic_test.csv"
isolated_all_csv = r"Z:\simon_luder\Data_Setup\Pollen_Datasets\data\final\poleno\isolated_all.csv"

df_basic_test = pd.read_csv(basic_test_csv)
df_isolated_all = pd.read_csv(isolated_all_csv)

# Row counts of both tables as a quick sanity check.
print(len(df_basic_test), len(df_isolated_all))

In [None]:
# Test logs (JSON) for the two test sets.
scores_basic_test = pd.read_json(
    r"C:\Users\simon\Documents\GitHub\LDM_for_Holographic_Images\checkpoints\vqvae_8_512\test\basic_test_20251215_181623\test_logs.json"
)
scores_isolated_all = pd.read_json(
    r"C:\Users\simon\Documents\GitHub\LDM_for_Holographic_Images\checkpoints\vqvae_8_512\test\isolated_all_20251216_123210\test_logs.json"
)

# Inner join: keep only rows whose rec_path matches a logged filename.
# Note: the merged result overwrites the frame it was built from, so this cell
# should be run once per fresh load of the CSV tables.
merge_kwargs = dict(how="inner", left_on="rec_path", right_on="filenames")
df_basic_test = pd.merge(df_basic_test, scores_basic_test, **merge_kwargs)
df_isolated_all = pd.merge(df_isolated_all, scores_isolated_all, **merge_kwargs)

In [None]:
# Tag every row with the group of the table it came from.
for frame, label in ((df_basic_test, "seen"), (df_isolated_all, "unseen")):
    frame["group"] = label

In [None]:
# Score columns used by the (currently disabled) per-dataset aggregation below.
# NOTE(review): this list spells the first column "test_reconstructon_loss",
# while the plotting and permutation-test cells read "test_reconstruction_loss".
# Confirm which spelling test_logs.json actually uses and align them.
test_columns = ["test_reconstructon_loss", "test_codebook_loss", "test_lpips_loss"]

# Stack the seen and unseen tables into one frame. ignore_index=True yields a
# fresh 0..n-1 index without carrying each table's old row labels along as an
# extra "index" column, which a bare reset_index() would add.
df = pd.concat([df_basic_test, df_isolated_all], ignore_index=True)
# df = df.groupby(["dataset_id", "group"])[test_columns].mean().reset_index()
df.head(3)

In [None]:
import seaborn as sns

# NOTE(review): the column is spelled "test_reconstruction_loss" here, but the
# test_columns list in the concat cell spells it "test_reconstructon_loss";
# confirm which one the merged score columns actually use.
metric = "test_reconstruction_loss"
x = "species"

# Order the x-axis by ascending median of the metric (both groups pooled).
median_per_category = df.groupby(x)[metric].median()
order = median_per_category.sort_values().index

plt.figure(figsize=(16, 4))
sns.boxplot(
    data=df,
    x=x,
    y=metric,
    hue="group",
    order=order,
    fill=False,
    fliersize=1,
    flierprops={"marker": "."},
)
# plt.ylim(0, 0.002)
plt.title("Test reconstruction loss per species")
plt.xticks(rotation=90)
plt.yscale("log")
plt.show()

In [None]:
category = "unseen"

# Rows belonging to the chosen group versus every other row.
in_category = df["group"] == category
sample = df.loc[in_category]
rest = df.loc[~in_category]

In [None]:
import numpy as np

def perm_test_large_imbalance(species, others, *, stat="logmean", alternative="two-sided",
                             n_perm=100000, eps=1e-12, seed=0, verbose=True):
    """Monte Carlo permutation test for a difference in location between two samples.

    The statistic is ``loc(species) - loc(others)``, where ``loc`` is the mean,
    the median, or the mean of ``log(x + eps)`` depending on ``stat``. The null
    distribution is approximated by drawing ``len(species)`` indices without
    replacement from the pooled sample ``n_perm`` times and recomputing the
    statistic for each draw.

    Parameters
    ----------
    species, others : array-like of float
        The two samples to compare. Both must be non-empty and finite.
    stat : {"logmean", "mean", "median"}
        Location measure used in the statistic.
    alternative : {"two-sided", "greater", "less"}
        "greater" tests species > others, "less" tests species < others.
    n_perm : int
        Number of random relabelings (must be >= 0).
    eps : float
        Offset added before the logarithm when ``stat="logmean"``.
    seed : int
        Seed for ``numpy.random.default_rng``; same seed, same result.
    verbose : bool
        When true (default), print the sizes of species, others and the pool.

    Returns
    -------
    obs : numpy.float64
        Observed statistic.
    p : numpy.float64
        ``(k + 1) / (n_perm + 1)``, where ``k`` counts permuted statistics at
        least as extreme as ``obs``; the +1 terms keep p strictly above zero.

    Raises
    ------
    ValueError
        For an unknown ``stat`` or ``alternative``, a negative ``n_perm``, an
        empty sample, non-finite values, or (for "logmean") values for which
        ``x + eps`` is not strictly positive.
    """
    # Validate options up front so a bad argument fails immediately, even when
    # n_perm is 0 and the permutation loop never runs.
    if stat not in ("mean", "median", "logmean"):
        raise ValueError("stat must be 'mean', 'median', or 'logmean'")
    if alternative not in ("two-sided", "greater", "less"):
        raise ValueError("alternative must be 'two-sided', 'greater', or 'less'")
    if n_perm < 0:
        raise ValueError("n_perm must be non-negative")

    rng = np.random.default_rng(seed)

    species = np.asarray(species, dtype=float)
    others = np.asarray(others, dtype=float)

    # An empty group or a NaN makes the observed statistic NaN; every comparison
    # against it is then False and p collapses to 1 / (n_perm + 1), which looks
    # like a highly significant result. Refuse such input instead.
    if species.size == 0 or others.size == 0:
        raise ValueError("species and others must both be non-empty")

    pooled = np.concatenate([species, others])
    nA = len(species)
    N = len(pooled)

    if not np.isfinite(pooled).all():
        raise ValueError("species and others must contain only finite values")

    if verbose:
        print(len(species), len(others), len(pooled))

    # Transform once; only "logmean" changes the values.
    if stat == "logmean":
        shifted = pooled + eps
        if not (shifted > 0).all():
            raise ValueError("stat='logmean' requires x + eps > 0 for every value")
        pooled_t = np.log(shifted)
    else:
        pooled_t = pooled

    reducer = np.median if stat == "median" else np.mean

    def statistic(in_a):
        """Difference of the location measure between the masked rows and the rest."""
        return reducer(pooled_t[in_a]) - reducer(pooled_t[~in_a])

    # Observed statistic: the first nA pooled entries are the species sample.
    mask = np.zeros(N, dtype=bool)
    mask[:nA] = True
    obs = statistic(mask)

    # Null distribution: relabel by drawing nA pooled indices per permutation.
    # The mask buffer is reused instead of being reallocated every iteration.
    idx = np.arange(N)
    perm_stats = np.empty(n_perm, dtype=float)
    for i in range(n_perm):
        mask[:] = False
        mask[rng.choice(idx, size=nA, replace=False)] = True
        perm_stats[i] = statistic(mask)

    if alternative == "two-sided":
        more_extreme = np.count_nonzero(np.abs(perm_stats) >= abs(obs))
    elif alternative == "greater":  # species > others
        more_extreme = np.count_nonzero(perm_stats >= obs)
    else:  # "less": species < others
        more_extreme = np.count_nonzero(perm_stats <= obs)

    # Always a NumPy scalar so callers can use p.round(...), even for n_perm == 0.
    p = np.float64((more_extreme + 1) / (n_perm + 1))  # avoids p=0
    return obs, p


# Reconstruction losses of the selected group versus all remaining rows.
loss_column = "test_reconstruction_loss"
mse_species = sample[loss_column].values
mse_others = rest[loss_column].values

# Permutation test on the difference of raw means between the two groups.
obs, p = perm_test_large_imbalance(
    mse_species,
    mse_others,
    stat="mean",
    alternative="two-sided",
    n_perm=1000,
)
print("obs:", obs.round(6), "p:", p.round(3))