In [None]:

from pathlib import Path
from typing import List, Optional, Callable, Sequence, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.cuda.amp import GradScaler, autocast
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader

# For visualization
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA, FastICA


# Reached only when every import earlier in this cell succeeded.
print("Libraries imported successfully!")

In [None]:
from python.VQVAE_v2 import *

In [None]:
# --- Configuration ---
# The model built from these values has its weights replaced by the checkpoint
# loaded at the end of this cell, so the architecture arguments must be
# compatible with that checkpoint.
BATCH_SIZE = 8192
EPOCHS = 50
LR = 5e-4
IN_CHANNELS = 3
EMBEDDING_DIM = 128  # The dimensionality of the embeddings
NUM_EMBEDDINGS = 256  # The size of the codebook (the "dictionary")
COMMITMENT_COST = 0.25

# Fall back to the CPU so the notebook still runs on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
data_dir = "C:/Users/zphrfx/Desktop/hdk/VQVDB/data/vdb_cache/npy"


model = VQVAE(
    in_channels=IN_CHANNELS,
    embedding_dim=EMBEDDING_DIM,
    num_embeddings=NUM_EMBEDDINGS,
    commitment_cost=COMMITMENT_COST,
).to(device)

# glob() yields entries in no guaranteed order; sorting keeps dataset indices
# (e.g. the fixed example index used later) stable across runs and machines.
npy_files = sorted(Path(data_dir).glob("*.npy"))
if not npy_files:
    # Report the directory that was actually searched.
    raise ValueError(f"No .npy files found in {data_dir}")

print(f"Found {len(npy_files)} .npy files")

vdb_dataset = VDBLeafDataset(npy_files=npy_files, include_origins=False, in_channels=IN_CHANNELS)

# Path of the checkpoint to evaluate
model_path = "C:/Users/zphrfx/Desktop/hdk/VQVDB/python/models/vqvae.pth"

# Restore the weights and switch to inference mode. The loaded dict is kept in
# `save` because later cells also read the training curves stored in it.
# NOTE(review): torch.load unpickles the file - only load checkpoints you trust.
save = torch.load(model_path, map_location=device)
model.load_state_dict(save["state_dict"])
model.eval()



In [None]:
print("Visualizing Reconstruction Quality for a Single Example")


# Take one dataset item, give it a leading batch axis, and move it to the model's device.
example_item = vdb_dataset[78]
original_block = example_item.unsqueeze(0).to(device)

print("Performing reconstruction...")
with torch.no_grad():
    # Round trip: block -> code indices -> reconstructed block
    indices = model.encode(original_block)
    reconstructed_block = model.decode(indices)

# --- 4. Prepare Data for Plotting ---
# Drop the batch axis, move the leading (channel) axis to the end, and convert to NumPy:
# (C, D, H, W) becomes (D, H, W, C), which is easier to slice.
original_channels_last = original_block.squeeze(0).permute(1, 2, 3, 0)
original_np = original_channels_last.cpu().numpy()
reconstructed_channels_last = reconstructed_block.squeeze(0).permute(1, 2, 3, 0)
reconstructed_np = reconstructed_channels_last.detach().cpu().numpy()

# Length of the per-voxel difference vector, shown later as a heatmap
residual_np = original_np - reconstructed_np
error_magnitude_np = np.linalg.norm(residual_np, axis=-1)

# Helper: remap vector components from [-1, 1] onto [0, 1] so a slice can be shown as RGB
def vector_to_rgb(vector_slice):
    """Halve the input, shift it up by 0.5, and clamp the result to [0, 1]."""
    halved = vector_slice * 0.5
    shifted = halved + 0.5
    return np.clip(shifted, 0, 1)

# --- 5. Plotting ---
print("Generating plots...")
fig, axes = plt.subplots(3, 3, figsize=(10, 10), constrained_layout=True)
center_slice_idx = 4  # Center slice for visualization

# One figure row per slicing axis of the volume (row 0 -> Z, row 1 -> Y, row 2 -> X).
# Columns: original | reconstruction | error magnitude.
error_images = []
for row, axis_name in enumerate(('Z', 'Y', 'X')):
    # Index tuple that fixes `center_slice_idx` along volume axis `row`.
    pick = [slice(None)] * 3
    pick[row] = center_slice_idx
    pick = tuple(pick)

    axes[row, 0].imshow(vector_to_rgb(original_np[pick]))
    axes[row, 0].set_title(f'Original (Slice {axis_name}={center_slice_idx})')

    # The first heatmap autoscales; the later ones reuse its colour limits so
    # all three share the scale of the single colorbar drawn below.
    vmin, vmax = error_images[0].get_clim() if error_images else (None, None)
    error_images.append(
        axes[row, 2].imshow(error_magnitude_np[pick], cmap='magma', vmin=vmin, vmax=vmax)
    )
    axes[row, 2].set_title('Error Magnitude')

    axes[row, 1].imshow(vector_to_rgb(reconstructed_np[pick]))
    axes[row, 1].set_title(f'Reconstructed (Slice {axis_name}={center_slice_idx})')

im_err_z, im_err_y, im_err_x = error_images

# --- Final Touches ---
# Hide ticks and frames on every panel
for ax in axes.ravel():
    ax.axis('off')

# One colorbar, attached to the whole error column
fig.colorbar(
    im_err_z,
    ax=axes[:, 2],
    orientation='vertical',
    label='Error Vector Magnitude',
    shrink=0.8,
)

fig.suptitle('Vector Field Reconstruction (XYZ -> RGB)', fontsize=20)
plt.show()

In [None]:

# Training curves read from the loaded checkpoint dict
rl = save['recon_loss_l']
vq_loss_l = save['vq_loss_l']
perplexity_l = save['perplexity_l']

# Top panel: both losses. Bottom panel: perplexity.
curves_fig, (loss_ax, perplexity_ax) = plt.subplots(2, 1, figsize=(12, 6))

loss_ax.plot(rl, label='Reconstruction Loss', color='blue')
loss_ax.plot(vq_loss_l, label='VQ Loss', color='orange')
loss_ax.set_title('Training Losses')
loss_ax.set_xlabel('Training Steps')
loss_ax.set_ylabel('Loss')
loss_ax.legend()

perplexity_ax.plot(perplexity_l, label='Perplexity', color='green')
perplexity_ax.set_title('Perplexity Over Training Steps')
perplexity_ax.set_xlabel('Training Steps')
perplexity_ax.set_ylabel('Perplexity')
perplexity_ax.legend()

curves_fig.tight_layout()
plt.show()


In [None]:
plt.figure(figsize=(10, 6))
# Overlay the value distribution of the example block and of its reconstruction
for voxels, series_label in ((original_np, 'Original'), (reconstructed_np, 'Reconstructed')):
    plt.hist(voxels.flatten(), bins=50, alpha=0.7, label=series_label)
plt.title('Histogram of Voxel Values')
plt.xlabel('Value')
plt.ylabel('Frequency')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

from scipy.stats import entropy
def kl_divergence(p, q, bins=50, eps=1e-10):
    """Estimate KL(P || Q) between the value distributions of two sample arrays.

    The inputs are treated as raw samples (they may be negative), not as
    probability vectors: both are histogrammed on the same bin edges and the
    KL divergence is computed between the two normalized histograms. Dividing
    signed samples by their sum, as a plain normalization would, does not
    produce a probability distribution and makes the result inf/nan.

    Args:
        p: Array-like of samples from the reference distribution (any shape).
        q: Array-like of samples from the compared distribution (any shape).
        bins: Number of equal-width histogram bins spanning the joint range.
        eps: Small mass added to every bin so bins that are empty in one
            histogram do not make the divergence infinite.

    Returns:
        The divergence in nats as a Python float. 0.0 when both histograms
        coincide, including the degenerate case where every sample is equal.

    Raises:
        ValueError: If either input is empty or contains non-finite values.
    """
    p = np.asarray(p, dtype=np.float64).ravel()
    q = np.asarray(q, dtype=np.float64).ravel()
    if p.size == 0 or q.size == 0:
        raise ValueError("kl_divergence requires non-empty inputs")
    if not (np.isfinite(p).all() and np.isfinite(q).all()):
        raise ValueError("kl_divergence requires finite inputs")

    lo = min(p.min(), q.min())
    hi = max(p.max(), q.max())
    if lo == hi:
        # Every sample in both arrays has the same value: identical distributions.
        return 0.0

    # Shared edges so bin i means the same value interval for both inputs.
    edges = np.linspace(lo, hi, bins + 1)
    p_counts = np.histogram(p, bins=edges)[0]
    q_counts = np.histogram(q, bins=edges)[0]
    p_prob = p_counts / p_counts.sum() + eps
    q_prob = q_counts / q_counts.sum() + eps
    # entropy(pk, qk) renormalizes both arguments and returns sum(pk * log(pk / qk)).
    return float(entropy(p_prob, q_prob))
# Score the example block against its reconstruction and print the value to four decimals.
kl_div = kl_divergence(original_np, reconstructed_np)
print(f"KL Divergence between original and reconstructed blocks: {kl_div:.4f}")

In [None]:
print("PCA of the learned codebook vectors:")
# Project the codebook rows onto two principal components for a 2-D scatter
codebook = model.quantizer.embedding.data.cpu()
pca = PCA(n_components=2)
codebook_2d = pca.fit_transform(codebook)

first_component, second_component = codebook_2d[:, 0], codebook_2d[:, 1]
codebook_fig, codebook_ax = plt.subplots(figsize=(8, 8))
codebook_ax.scatter(first_component, second_component, s=15, alpha=0.7)
codebook_ax.set_title('VQ-VAE Codebook (PCA Projection)')
codebook_ax.grid(True)
plt.show()

# --- Plot 2: Codebook Usage Histogram ---
# This is a powerful diagnostic. It requires running the encoder on the whole dataset.
print("\nCalculating codebook usage across the entire dataset...")
model.eval()
all_indices = []
# Create a dataloader without shuffling to iterate through the dataset
full_loader = DataLoader(vdb_dataset, batch_size=BATCH_SIZE, shuffle=False)

with torch.no_grad():
    for data_batch in full_loader:
        data_batch = data_batch.to(device) # Move data to the same device as the model
        indices = model.encode(data_batch)
        # Flatten each batch's indices; they are joined into one 1-D array below.
        all_indices.append(indices.cpu().numpy().flatten())

all_indices = np.concatenate(all_indices)

plt.figure(figsize=(12, 6))
# Half-integer edges give exactly one unit-wide bin centred on each code index
# 0..NUM_EMBEDDINGS-1. With range=(0, NUM_EMBEDDINGS-1) the bins would be
# (NUM_EMBEDDINGS-1)/NUM_EMBEDDINGS wide and drift away from the integer indices.
plt.hist(all_indices, bins=NUM_EMBEDDINGS, range=(-0.5, NUM_EMBEDDINGS - 0.5))
plt.title('Codebook Usage Frequency')
plt.xlabel('Codebook Index')
plt.ylabel('Number of Times Used')
plt.show()

# A code is "dead" when no encoded position in the dataset selected it.
num_dead_codes = NUM_EMBEDDINGS - len(np.unique(all_indices))
print(f"Number of 'dead' (unused) codes: {num_dead_codes} out of {NUM_EMBEDDINGS}")

In [None]:
# --- Codebook Perplexity + Active-Code Ratio ---
# Empirical usage distribution over all codes (never-used codes get probability 0)
counts = np.bincount(all_indices, minlength=NUM_EMBEDDINGS).astype(np.float64)
probs = counts / counts.sum()
nonzero = probs > 0

# Perplexity = exp(entropy); zero-probability codes are left out to avoid log(0)
used_probs = probs[nonzero]
perplexity = np.exp(-np.sum(used_probs * np.log(used_probs)))
# Fraction of codes used at least once
active_ratio = nonzero.mean()

print(f"Codebook perplexity: {perplexity:.2f}")
print(f"Active-code ratio  : {active_ratio*100:.1f}%")


In [None]:
from math import log10

def psnr(x, y, vmax=1.0):
    """Peak signal-to-noise ratio in dB between tensors `x` and `y`.

    `vmax` is the peak value used in the formula; 1e-12 is added to the mean
    squared error so identical inputs give a finite result.
    """
    squared_error = (x - y) ** 2
    mse = squared_error.mean().item()
    peak_term = 20 * log10(vmax)
    noise_term = 10 * log10(mse + 1e-12)
    return peak_term - noise_term

# Collect one MSE and one PSNR value per dataset block.
model.eval()
psnr_list, mse_list = [], []

with torch.no_grad():
    for batch in DataLoader(vdb_dataset, batch_size=BATCH_SIZE, shuffle=False):
        batch = batch.to(device)
        rec = model.decode(model.encode(batch))
        # Per-block MSE: flatten everything but the batch axis, then average.
        mse = ((batch - rec) ** 2).reshape(len(batch), -1).mean(dim=1)
        mse_np = mse.cpu().numpy()
        mse_list.extend(mse_np)
        # PSNR for the whole batch in one shot, from the MSE computed above.
        # Same formula as psnr() with its default vmax=1.0; calling psnr() once
        # per block would force a device->host sync (.item()) for every block.
        peak = 1.0
        batch_psnr = 20.0 * np.log10(peak) - 10.0 * np.log10(mse_np.astype(np.float64) + 1e-12)
        psnr_list.extend(batch_psnr.tolist())

avg_psnr = np.mean(psnr_list)
avg_mse = np.mean(mse_list)

# Create publication-ready plots
plt.style.use('default')
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

# One entry per panel:
# (axes, values, bar colour, mean, mean label, x label, y label, title, y scale or None)
panels = [
    (ax1, psnr_list, 'steelblue', avg_psnr, f'Mean: {avg_psnr:.1f} dB',
     'PSNR (dB)', 'Number of Blocks', 'PSNR Distribution', None),
    (ax2, mse_list, 'forestgreen', avg_mse, f'Mean: {avg_mse:.2e}',
     'MSE', 'Number of Blocks (log scale)', 'MSE Distribution', 'log'),
]

for panel, values, bar_color, mean_value, mean_label, x_label, y_label, panel_title, y_scale in panels:
    panel.hist(values, bins=40, alpha=0.7, color=bar_color, edgecolor='black', linewidth=0.5)
    # Dashed vertical line marking the mean
    panel.axvline(mean_value, color='crimson', linestyle='--', linewidth=2, label=mean_label)
    panel.set_xlabel(x_label, fontsize=12)
    panel.set_ylabel(y_label, fontsize=12)
    panel.set_title(panel_title, fontsize=13, fontweight='bold')
    if y_scale is not None:
        panel.set_yscale(y_scale)
    panel.legend(fontsize=11)
    panel.grid(True, alpha=0.3)
    panel.tick_params(labelsize=10)

plt.tight_layout()
plt.show()

# Text summary of the same statistics
psnr_std = np.std(psnr_list)
mse_std = np.std(mse_list)
print("Reconstruction Quality Metrics:")
print(f"Average PSNR: {avg_psnr:.2f} dB")
print(f"Average MSE: {avg_mse:.2e}")
print(f"PSNR std: {psnr_std:.2f} dB")
print(f"MSE std: {mse_std:.2e}")


In [None]:
n_points = 100000
orig_sample = original_np.flatten()
recon_sample = reconstructed_np.flatten()
# Subsample only when there are more than n_points voxels; the same random
# positions are taken from both arrays so the pairs stay matched.
if len(orig_sample) > n_points:
    idx = np.random.choice(len(orig_sample), n_points, replace=False)
    orig_sample = orig_sample[idx]
    recon_sample = recon_sample[idx]

plt.figure(figsize=(8, 8))
plt.scatter(orig_sample, recon_sample, s=2, alpha=.5)

# Dashed identity line across the joint value range of both samples
lowest = min(orig_sample.min(), recon_sample.min())
highest = max(orig_sample.max(), recon_sample.max())
lims = [lowest, highest]
plt.plot(lims, lims, 'k--', linewidth=1)

plt.xlabel('Original voxel')
plt.ylabel('Reconstructed voxel')
plt.title('Voxel-wise Scatter (diag = perfect)')
plt.grid(True, alpha=.3)
plt.show()


In [None]:
# --- L2 norm of each embedding vector ---
# One norm per codebook row (dim=1 reduces across each row's components)
codebook_weights = model.quantizer.embedding.data
embed_norm = torch.linalg.norm(codebook_weights, dim=1).cpu().numpy()

code_positions = range(NUM_EMBEDDINGS)
plt.figure(figsize=(10, 2))
plt.bar(code_positions, embed_norm, width=1.0)
plt.title('Codebook Embedding L2 Norms')
plt.xlabel('Code Index')
plt.ylabel('Norm')
plt.tight_layout()
plt.show()


In [None]:
def mip(vol, axis):
    """Collapse `vol` along `axis`, keeping the largest value at each remaining position."""
    projection = vol.max(axis=axis)
    return projection

fig, axes = plt.subplots(2, 3, figsize=(15, 7))
views = [(0, 'XY MIP'),   # collapse Z
         (1, 'XZ MIP'),   # collapse Y
         (2, 'YZ MIP')]   # collapse X

# Row 0 shows the original block, row 1 its reconstruction; one column per view.
# Each volume is converted to RGB once and then projected along every axis.
rows = (('Original', vector_to_rgb(original_np)),
        ('Reconstructed', vector_to_rgb(reconstructed_np)))

for row, (row_name, rgb_volume) in enumerate(rows):
    for col, (axis_to_collapse, title) in enumerate(views):
        panel = axes[row, col]
        panel.imshow(mip(rgb_volume, axis=axis_to_collapse), cmap='viridis')
        panel.set_title(f'{row_name} {title}')
        panel.axis('off')

plt.suptitle('Maximum-Intensity Projections (3-view)', fontsize=16)
plt.tight_layout()
plt.show()

In [None]:
# ---------- 1. Build a per-block latent vector ----------
model.eval()
latents, errs = [], []          # errs = optional colouring

latent_loader = DataLoader(vdb_dataset, batch_size=BATCH_SIZE, shuffle=False)
with torch.no_grad():
    for batch in latent_loader:
        batch = batch.to(device)

        # Look up the codebook vector for every index, then average them per block
        idx = model.encode(batch).long()               # (B, Z, Y, X) indices
        flat_idx = idx.view(-1)
        emb = model.quantizer.embedding[flat_idx]      # (B*Z*Y*X, C)
        emb = emb.view(*idx.shape, -1)                 # (B, Z, Y, X, C)
        mean_emb = emb.mean(dim=(1, 2, 3))             # (B, C)
        latents.append(mean_emb.cpu())

        # Optional: per-block MSE for coloured scatter
        rec = model.decode(idx)
        squared_error = (batch - rec) ** 2
        block_mse = squared_error.view(len(batch), -1).mean(dim=1)
        errs.append(block_mse.cpu())

# Stitch the per-batch pieces together and convert to NumPy
latents = torch.cat(latents, dim=0).numpy()   # (N, C)
errs    = torch.cat(errs, dim=0).numpy()      # (N,)

# ---------- 2. ICA to 2-D ----------
# The projection is FastICA, not PCA, so the title and axis labels below say
# ICA / IC-n. The variable keeps its earlier name `pca2`.
pca2 = FastICA(n_components=2, random_state=0)
latents_2d = pca2.fit_transform(latents)      # (N, 2)


sc = plt.scatter(latents_2d[:, 0],
                 latents_2d[:, 1],
                 c=errs,                 # <- set to None for uniform colour
                 cmap='viridis',
                 s=4,
                 alpha=0.8)
if sc.get_array() is not None:           # only if colouring by a value
    plt.colorbar(sc, label='Block MSE')

plt.title('Latent Space Sampling (ICA-2D, viridis)')
plt.xlabel('IC-1'); plt.ylabel('IC-2')
plt.grid(True, alpha=.3)
plt.tight_layout()
plt.show()
