In [None]:
import argparse
import os
import struct
import time
from pathlib import Path
from typing import List, Optional, Callable, Sequence

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam

# For visualization
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt

# Reached only if every import above resolved without raising.
print("Libraries imported successfully!")

In [None]:
class VDBLeafDataset(Dataset):
    def __init__(
            self,
            npy_files: Sequence[str | Path],
            transform: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
            *,
            include_origins: bool = False,
            origins_root: str | Path | None = None,
            origins_suffix: str = "._origins.npy",
    ) -> None:
        super().__init__()

        self.transform = transform
        self.include_origins = include_origins

        # --- mmap all data files -------------------------------------------------
        self.arrays: List[np.memmap] = []
        self.origin_arrays: List[np.memmap] | None = [] if include_origins else None

        for f in npy_files:
            arr = np.load(f, mmap_mode="r")
            if arr.shape[1:] != (8, 8, 8):
                raise ValueError(f"{f}: expected (N, 8, 8, 8), got {arr.shape}")
            self.arrays.append(arr)

            if include_origins:
                if origins_root is not None:
                    origin_path = Path(origins_root) / (Path(f).stem + origins_suffix)
                else:
                    origin_path = Path(f).with_suffix(origins_suffix)
                if not origin_path.exists():
                    raise FileNotFoundError(origin_path)

                self.origin_arrays.append(np.load(origin_path, mmap_mode="r"))

        # --- pre-compute global index mapping ------------------------------------
        lengths = np.fromiter((a.shape[0] for a in self.arrays), dtype=np.int64)
        self.file_offsets = np.concatenate(([0], np.cumsum(lengths)))
        self.total_leaves: int = int(self.file_offsets[-1])

    # ---------------------------------------------------------------------------

    def __len__(self) -> int:
        return self.total_leaves

    def __getitem__(self, idx: int):
        if not (0 <= idx < self.total_leaves):
            raise IndexError(idx)

        # locate (file, local) in O(log n) inside highly-optimised C code
        file_idx = int(np.searchsorted(self.file_offsets, idx, side="right") - 1)
        local_idx = idx - int(self.file_offsets[file_idx])

        # zero-copy view from the mmap’d array
        leaf_np = self.arrays[file_idx][local_idx].astype(np.float32, copy=True)
        leaf = torch.from_numpy(leaf_np).to(torch.float32).unsqueeze(0)

        if self.transform is not None:
            leaf = self.transform(leaf)

        if self.include_origins:
            origin_np = self.origin_arrays[file_idx][local_idx].astype(np.int32, copy=False)
            origin = torch.from_numpy(origin_np)
            return leaf, origin

        return leaf


In [None]:
class VectorQuantizerEMA(nn.Module):
    """Vector quantizer whose codebook is maintained by exponential moving averages.

    The codebook lives in buffers (not parameters), so it is updated in
    ``forward`` while the module is in training mode rather than by the optimizer.

    Args:
        num_embeddings: Number of codebook vectors (K).
        embedding_dim: Dimensionality of each codebook vector (D); must equal the
            channel dimension of the input.
        commitment_cost: Weight of the commitment term in the returned loss.
        decay: EMA decay for the cluster-size and embedding-sum statistics.
        eps: Small constant added to the cluster sizes before dividing.

    Buffers:
        embedding: ``(K, D)`` current codebook.
        cluster_size: ``(K,)`` EMA of how many inputs were assigned to each code.
        embed_avg: ``(K, D)`` EMA of the sum of inputs assigned to each code.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int,
                 commitment_cost: float, decay: float = 0.99, eps: float = 1e-5):
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.commitment_cost = commitment_cost
        self.decay = decay
        self.eps = eps

        # Initialize embeddings
        embed = torch.randn(num_embeddings, embedding_dim)
        embed = F.normalize(embed, dim=1)  # Normalize initial embeddings

        self.register_buffer('embedding', embed)
        # Start the EMA state consistent with the codebook: embed_avg / cluster_size
        # must equal the initial embedding. With cluster_size initialised to zero,
        # every code that is unused in the first batch becomes
        # decay * embed / eps (~1e5x larger) and can never be selected again.
        self.register_buffer('cluster_size', torch.ones(num_embeddings))
        self.register_buffer('embed_avg', embed.clone().detach())

    def forward(self, x):
        """Quantize ``x`` of shape ``(B, D, *spatial)``.

        Returns:
            quantized: Tensor shaped like ``x``; straight-through, so gradients
                flow to ``x`` unchanged.
            loss: Scalar ``q_latent_loss + commitment_cost * e_latent_loss``.
            perplexity: Scalar ``exp(entropy)`` of the code usage in this batch.
        """
        B, D, *spatial = x.shape
        # Channels-last, then one row per spatial location: (B * prod(spatial), D)
        flat = x.permute(0, *range(2, 2 + len(spatial)), 1).contiguous().view(-1, D)

        # The nearest-code search is not differentiable; keep it out of autograd
        # so the (N, K) distance matrix is not retained for backward.
        with torch.no_grad():
            distances = (
                flat.pow(2).sum(1, keepdim=True)
                + self.embedding.pow(2).sum(1)
                - 2 * flat @ self.embedding.t()
            )
            encoding_indices = torch.argmin(distances, dim=1)
            # Per-code assignment counts; equals the column sums of the one-hot
            # assignment matrix without materialising that (N, K) matrix.
            counts = torch.bincount(
                encoding_indices, minlength=self.num_embeddings
            ).to(self.cluster_size.dtype)

        # Quantize: gather the selected code for every row (uses the codebook
        # as it was before this step's EMA update).
        quantized = F.embedding(encoding_indices, self.embedding)
        quantized = quantized.view(B, *spatial, D).permute(0, -1, *range(1, 1 + len(spatial)))

        # EMA updates (simplified)
        if self.training:
            with torch.no_grad():
                self.cluster_size.mul_(self.decay).add_(counts, alpha=1 - self.decay)

                # Sum of the inputs assigned to each code (== one_hot.T @ flat).
                dw = torch.zeros_like(self.embed_avg)
                dw.index_add_(0, encoding_indices, flat.detach().to(dw.dtype))
                self.embed_avg.mul_(self.decay).add_(dw, alpha=1 - self.decay)

                # Normalize
                n = self.cluster_size + self.eps
                self.embedding.copy_(self.embed_avg / n.unsqueeze(1))

        # Losses
        e_latent_loss = F.mse_loss(quantized.detach(), x)
        # NOTE: the codebook is a buffer, so this term carries no gradient; it only
        # contributes to the reported loss value. Kept so the returned loss is
        # unchanged for callers that log it.
        q_latent_loss = F.mse_loss(quantized, x.detach())
        loss = q_latent_loss + self.commitment_cost * e_latent_loss

        # Straight-through estimator
        quantized = x + (quantized - x).detach()

        # Perplexity
        avg_probs = counts / flat.shape[0]
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        return quantized, loss, perplexity


# --- Encoder ---
class Encoder(nn.Module):
    """Two-stage 3-D convolutional encoder.

    Stage one is a stride-2 convolution (kernel 4, padding 1) to 32 channels;
    stage two is a stride-1 convolution (kernel 3, padding 1) to
    ``embedding_dim`` channels. Each convolution is followed by batch
    normalisation and ReLU, so the output is non-negative.
    """

    def __init__(self, in_channels, embedding_dim):
        super().__init__()
        hidden_channels = 32

        downsample_stage = [
            nn.Conv3d(in_channels, hidden_channels, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm3d(hidden_channels),
            nn.ReLU(),
        ]
        projection_stage = [
            nn.Conv3d(hidden_channels, embedding_dim, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm3d(embedding_dim),
            nn.ReLU(),
        ]
        # Layer order (and therefore state-dict keys net.0 .. net.5) is preserved.
        self.net = nn.Sequential(*downsample_stage, *projection_stage)

    def forward(self, x):
        """Run ``x`` through the convolutional stack and return the latent grid."""
        latent = self.net(x)
        return latent

# --- Decoder ---
class Decoder(nn.Module):
    """Two-stage 3-D convolutional decoder.

    A stride-2 transposed convolution (kernel 4, padding 1) to 64 channels with
    batch normalisation and ReLU doubles the spatial size (4^3 -> 8^3), then a
    stride-1 convolution (kernel 3, padding 1) maps to ``out_channels`` and a
    sigmoid squashes the result into (0, 1).
    """

    def __init__(self, embedding_dim, out_channels):
        super().__init__()
        hidden_channels = 64

        # 4^3 -> 8^3
        upsample_stage = [
            nn.ConvTranspose3d(embedding_dim, hidden_channels, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm3d(hidden_channels),
            nn.ReLU(),
        ]
        # refine at 8^3
        output_stage = [
            nn.Conv3d(hidden_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid(),
        ]
        # Layer order (and therefore state-dict keys net.0 .. net.4) is preserved.
        self.net = nn.Sequential(*upsample_stage, *output_stage)

    def forward(self, x):
        """Decode a latent grid ``x`` into a reconstruction."""
        reconstruction = self.net(x)
        return reconstruction

# --- Full VQ-VAE Model ---
class VQVAE(nn.Module):
    """Encoder -> EMA vector quantizer -> decoder.

    Args:
        in_channels: Channel count of the input (and of the reconstruction).
        embedding_dim: Channel count of the latent grid / codebook vector size.
        num_embeddings: Number of codebook vectors.
        commitment_cost: Weight of the commitment term, forwarded to the quantizer.
    """

    def __init__(self, in_channels, embedding_dim, num_embeddings, commitment_cost):
        super(VQVAE, self).__init__()
        self.encoder = Encoder(in_channels, embedding_dim)
        # Forward the caller's commitment cost; it was previously ignored in
        # favour of a hard-coded 0.25.
        self.quantizer = VectorQuantizerEMA(num_embeddings, embedding_dim,
                                            commitment_cost=commitment_cost)
        self.decoder = Decoder(embedding_dim, in_channels)

    def forward(self, x):
        """Return ``(x_recon, vq_loss, perplexity)`` for a batch ``x``."""
        z = self.encoder(x)
        quantized, vq_loss, perplexity = self.quantizer(z)
        x_recon = self.decoder(quantized)
        return x_recon, vq_loss, perplexity

    def encode(self, x):
        """Return the nearest-code index for every latent location.

        The result is an integer tensor of shape ``(B, *latent_spatial)``. Indices
        carry no gradient, so the whole computation runs without autograd.
        """
        with torch.no_grad():
            z = self.encoder(x)
            # Channels-last, one row per latent location: (B * prod(spatial), D)
            flat_z = z.permute(0, *range(2, z.ndim), 1).contiguous().view(-1, self.quantizer.embedding_dim)
            distances = (torch.sum(flat_z**2, dim=1, keepdim=True)
                         + torch.sum(self.quantizer.embedding**2, dim=1)
                         - 2 * torch.matmul(flat_z, self.quantizer.embedding.t()))
            indices = torch.argmin(distances, dim=1)
            return indices.view(z.shape[0], *z.shape[2:])

    def decode(self, indices):
        """Reconstruct from code indices of shape ``(B, *latent_spatial)``."""
        # Look up codes -> (B, *spatial, D), then move D to the channel axis.
        quantized_vectors = F.embedding(indices, self.quantizer.embedding)
        quantized_for_decoder = quantized_vectors.permute(0, quantized_vectors.ndim - 1, *range(1, quantized_vectors.ndim - 1))
        x_recon = self.decoder(quantized_for_decoder)
        return x_recon
        

In [None]:
# Hyperparameters
BATCH_SIZE = 8192
EPOCHS = 50
LR = 1e-3
IN_CHANNELS = 1
EMBEDDING_DIM = 64 # The dimensionality of the embeddings
NUM_EMBEDDINGS = 512 # The size of the codebook (the "dictionary")
COMMITMENT_COST = 0.25

# Data / reproducibility configuration
DATA_DIR = Path("data/npy")
VAL_FRACTION = 0.1  # share of the dataset held out for validation
SEED = 42

# Seed before anything random happens (weight init, shuffling).
torch.manual_seed(SEED)
np.random.seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# glob() order is filesystem-dependent; sort so the global leaf index (and the
# train/validation split derived from it) is the same on every run.
npy_files = sorted(DATA_DIR.glob("*.npy"))
if not npy_files:
    raise ValueError(f"No .npy files found in {DATA_DIR}")

print(f"Found {len(npy_files)} .npy files")

vdb_dataset = VDBLeafDataset(npy_files=npy_files, include_origins=False)
print(f"Dataset created with {len(vdb_dataset)} total blocks.")

# keep 10% of the dataset for validation: the first 90% of the indices train,
# the trailing 10% validate (previously the two were swapped).
num_val = int(len(vdb_dataset) * VAL_FRACTION)
split_idx = len(vdb_dataset) - num_val
vdb_dataset_train = torch.utils.data.Subset(vdb_dataset, range(split_idx))
vdb_dataset_val = torch.utils.data.Subset(vdb_dataset, range(split_idx, len(vdb_dataset)))
print(f"Training dataset size: {len(vdb_dataset_train)}")
print(f"Validation dataset size: {len(vdb_dataset_val)}")


train_loader = DataLoader(
    vdb_dataset_train,
    batch_size=BATCH_SIZE,
    # Leaves are stored file by file; without shuffling every batch would come
    # from one region of one file, which biases SGD and the EMA codebook.
    shuffle=True,
    num_workers=0,
    # Pinned host memory is what makes non_blocking host->GPU copies asynchronous.
    pin_memory=(device.type == "cuda"),
)


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

model = VQVAE(IN_CHANNELS, EMBEDDING_DIM, NUM_EMBEDDINGS, COMMITMENT_COST).to(device)
optimizer = Adam(model.parameters(), lr=LR)

torch.backends.cudnn.benchmark = True

# NOTE(review): the loop below runs in full precision and never uses this scaler.
# It is kept for any later cell that expects the name, and only enabled on CUDA
# so that CPU-only runs do not emit a "CUDA is not available" warning.
scaler = torch.amp.GradScaler(enabled=(device.type == "cuda"))

# Fail early and clearly instead of dividing by zero / hitting an undefined
# `perplexity` after an epoch that saw no batches.
num_batches = len(train_loader)
if num_batches == 0:
    raise ValueError("train_loader yields no batches; nothing to train on")

print("Starting training with data from DataLoader...")
# Explicit train mode: BatchNorm uses batch statistics and the quantizer
# performs its EMA codebook updates only while training.
model.train()
for epoch in range(EPOCHS):

    total_recon_loss = 0.0
    total_vq_loss = 0.0

    for data_batch in train_loader:
        leaves = data_batch.to(device, non_blocking=True)

        x_recon, vq_loss, perplexity = model(leaves)
        recon_error = F.mse_loss(x_recon, leaves)
        loss = recon_error + vq_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_recon_loss += recon_error.item()
        total_vq_loss += vq_loss.item()

    # Log progress at the end of each epoch
    avg_recon_loss = total_recon_loss / num_batches
    avg_vq_loss = total_vq_loss / num_batches
    print(f"Epoch [{epoch+1}/{EPOCHS}] | "
          f"Avg Recon Loss: {avg_recon_loss:.5f} | "
          f"Avg VQ Loss: {avg_vq_loss:.5f} | "
          f"Last Perplexity: {perplexity.item():.4f}") # Perplexity from last batch

print("Training finished.")

In [None]:
print("Visualizing Reconstruction Quality for a Single Example")
model.eval()

# Use a fixed example (not a random one) so the figure is reproducible; clamp
# the index so small datasets do not raise IndexError.
sample_idx = min(76, len(vdb_dataset) - 1)
original_block = vdb_dataset[sample_idx].unsqueeze(0).to(device)

# Perform the full compression/decompression cycle. Pure inference: no
# autograd graph is needed.
with torch.no_grad():
    indices = model.encode(original_block)
    reconstructed_block = model.decode(indices)

# Detach from GPU and convert to numpy for plotting
original_np = original_block.squeeze().cpu().numpy()
reconstructed_np = reconstructed_block.squeeze().detach().cpu().numpy()
error_np = np.abs(original_np - reconstructed_np)

# Get consistent color limits for fair comparison
vmin = min(original_np.min(), reconstructed_np.min())
vmax = max(original_np.max(), reconstructed_np.max())

# --- Plot 1: Slice-by-Slice Comparison ---
fig, axes = plt.subplots(3, 3, figsize=(13, 10))
# Slices are taken along the first axis of the 8x8x8 cube.
# NOTE(review): the titles call this axis "Z" - confirm against the data layout.
slices_to_show = [1, 4, 7]

for i, slice_idx in enumerate(slices_to_show):
    # Original
    im1 = axes[i, 0].imshow(original_np[slice_idx, :, :], vmin=vmin, vmax=vmax, cmap='viridis')
    axes[i, 0].set_title(f'Original (Slice Z={slice_idx})')
    axes[i, 0].axis('off')

    # Reconstructed
    im2 = axes[i, 1].imshow(reconstructed_np[slice_idx, :, :], vmin=vmin, vmax=vmax, cmap='viridis')
    axes[i, 1].set_title(f'Reconstructed (Slice Z={slice_idx})')
    axes[i, 1].axis('off')

    # Error Map (own colour scale per row; only the last row's scale is shown
    # in the colorbar below)
    im3 = axes[i, 2].imshow(error_np[slice_idx, :, :], cmap='magma')
    axes[i, 2].set_title('Absolute Error')
    axes[i, 2].axis('off')

# im1/im2 share vmin/vmax across all rows, so one colorbar covers both columns.
fig.colorbar(im1, ax=axes[:,:2], orientation='vertical', fraction=.1)
fig.colorbar(im3, ax=axes[:,2], orientation='vertical', fraction=.1)
plt.suptitle('Qualitative Reconstruction Analysis', fontsize=16)
plt.show()

In [None]:
# Overlay the voxel-value distributions of the original and reconstructed block.
hist_fig, hist_ax = plt.subplots(figsize=(10, 6))
for voxel_values, series_label in (
    (original_np, 'Original'),
    (reconstructed_np, 'Reconstructed'),
):
    hist_ax.hist(voxel_values.ravel(), bins=50, alpha=0.7, label=series_label)
hist_ax.set_title('Histogram of Voxel Values')
hist_ax.set_xlabel('Value')
hist_ax.set_ylabel('Frequency')
hist_ax.legend()
hist_ax.grid(True, alpha=0.3)
plt.show()

In [None]:
print("PCA of the learned codebook vectors:")
# detach() instead of the legacy .data attribute; the buffer is moved to CPU
# and handed to scikit-learn as a NumPy array.
codebook = model.quantizer.embedding.detach().cpu()
pca = PCA(n_components=2)
codebook_2d = pca.fit_transform(codebook.numpy())
plt.figure(figsize=(8, 8))
plt.scatter(codebook_2d[:, 0], codebook_2d[:, 1], s=15, alpha=0.7)
plt.title('VQ-VAE Codebook (PCA Projection)')
plt.grid(True)
plt.show()

# --- Plot 2: Codebook Usage Histogram ---
# This is a powerful diagnostic. It requires running the encoder on the whole dataset.
print("\nCalculating codebook usage across the entire dataset...")
model.eval()
index_chunks = []
# Create a dataloader without shuffling to iterate through the dataset
full_loader = DataLoader(vdb_dataset, batch_size=BATCH_SIZE, shuffle=False)

with torch.no_grad():
    for data_batch in full_loader:
        data_batch = data_batch.to(device) # Move data to the same device as the model
        indices = model.encode(data_batch)
        index_chunks.append(indices.cpu().numpy().flatten())

# np.concatenate raises on an empty list, so handle an empty dataset explicitly.
all_indices = np.concatenate(index_chunks) if index_chunks else np.empty(0, dtype=np.int64)

# Exact per-code counts. A histogram with NUM_EMBEDDINGS bins over
# (0, NUM_EMBEDDINGS - 1) has bin width (K - 1) / K, so its bars do not line up
# with the integer code indices.
usage_counts = np.bincount(all_indices, minlength=NUM_EMBEDDINGS)

plt.figure(figsize=(12, 6))
plt.bar(np.arange(len(usage_counts)), usage_counts, width=1.0)
plt.title('Codebook Usage Frequency')
plt.xlabel('Codebook Index')
plt.ylabel('Number of Times Used')
plt.show()

num_dead_codes = int(np.count_nonzero(usage_counts[:NUM_EMBEDDINGS] == 0))
print(f"Number of 'dead' (unused) codes: {num_dead_codes} out of {NUM_EMBEDDINGS}")