In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from datetime import datetime
import os
import time
import faulthandler

# Assuming the necessary components from cirkit and your project are available
from src.circuit_types import CIRCUIT_BUILDERS
from cirkit.symbolic.circuit import Circuit
from cirkit.pipeline import PipelineContext, compile as compile_circuit
import cirkit.symbolic.functional as SF
from cirkit.backend.torch.queries import IntegrateQuery
from cirkit.utils.scope import Scope
from cirkit.backend.torch.layers import TorchSumLayer
from src.nystromlayer import NystromSumLayer

# --- Helper Functions ---

def compile_symbolic(circuit: Circuit, *, device: str, rank: int | None = None, opt: bool = False):
    """Compile ``circuit`` to a torch module, move it to ``device`` and put it in eval mode.

    Args:
        circuit: Symbolic circuit to compile.
        device: Target device handed to ``.to(...)`` on the compiled module.
        rank: Nyström rank. Forwarded both to the ``PipelineContext`` and to the
            compile call itself; ``None`` is forwarded unchanged.
        opt: Value of the pipeline's ``optimize`` flag.

    Returns:
        The compiled module, already on ``device`` and in evaluation mode.
    """
    # The pipeline always targets the torch backend with the complex
    # log-sum-exp semiring and folding switched off; only `optimize` and the
    # Nyström rank vary between callers.
    pipeline_options = {
        "backend": "torch",
        "semiring": "complex-lse-sum",
        "fold": False,
        "optimize": opt,
        "nystrom_rank": rank,
    }
    pipeline = PipelineContext(**pipeline_options)

    module = compile_circuit(circuit, pipeline, nystrom_rank=rank)
    module = module.to(device)
    module = module.eval()

    print(f"Compiled circuit with rank {rank} on device {device}", flush=True)
    return module


def sync_sumlayer_weights(
    original: nn.Module, nystrom: nn.Module, *, pivot: str = "uniform", rank: int | None = None
) -> None:
    """Rebuild every Nyström sum layer of ``nystrom`` from its counterpart in ``original``.

    Layers are paired purely by position: the i-th ``TorchSumLayer`` yielded by
    ``original.modules()`` is matched with the i-th ``NystromSumLayer`` yielded
    by ``nystrom.modules()``.

    Args:
        original: Module whose ``TorchSumLayer`` instances act as the source.
        nystrom: Module whose ``NystromSumLayer`` instances are updated in place.
        pivot: Value stored on each Nyström layer's ``pivot`` attribute before
            its factors are rebuilt. Accepted values are defined by
            ``NystromSumLayer`` (not visible here).
        rank: When not ``None``, overrides each Nyström layer's ``rank`` and
            fills its ``rank_param`` with the same value before rebuilding.

    Raises:
        ValueError: If the two modules do not contain the same number of
            matching layers. The message reports both counts.
        Exception: Anything raised while rebuilding a layer is logged with the
            layer index and re-raised unchanged.
    """
    orig_layers = [m for m in original.modules() if isinstance(m, TorchSumLayer)]
    nys_layers = [m for m in nystrom.modules() if isinstance(m, NystromSumLayer)]
    if len(orig_layers) != len(nys_layers):
        # Carry the counts in the exception itself so they are not lost when
        # stdout is not captured alongside the traceback.
        raise ValueError(
            "Layer count mismatch when syncing weights: "
            f"{len(orig_layers)} TorchSumLayer vs {len(nys_layers)} NystromSumLayer"
        )

    # Dump the Python stacks of all threads if the process crashes hard.
    faulthandler.enable(all_threads=True)
    total = len(orig_layers)
    # Report progress roughly every quarter of the way (and always at the end).
    interval = max(1, total // 4)

    for i, (o, n) in enumerate(zip(orig_layers, nys_layers), start=1):
        # Watchdog: if this layer takes longer than 30 s, dump tracebacks once
        # so a hang can be located. Cancelled in `finally` below.
        faulthandler.dump_traceback_later(30, repeat=False)
        try:
            if rank is not None:
                n.rank = int(rank)
                n.rank_param.data.fill_(n.rank)
            n.pivot = pivot
            n._build_factors_from(o)
            if torch.cuda.is_available():
                # Wait for queued CUDA work so that asynchronous failures and
                # stalls surface here, attributed to the right layer.
                torch.cuda.synchronize()
        except Exception as e:
            print(f"[{datetime.now()}] Exception on layer {i}/{total}: {e}", flush=True)
            raise
        finally:
            faulthandler.cancel_dump_traceback_later()

        if i % interval == 0 or i == total:
            pct = int(100 * i / total)
            print(f"[{datetime.now()}] Weight Sync Progress: {pct}% ({i}/{total})", flush=True)



def complex_logsumexp(z, dim):
    """Numerically stable log-sum-exp of ``z`` along ``dim``.

    Computes ``log(sum(exp(z), dim))`` by first subtracting the largest real
    part along ``dim`` so the exponentials cannot overflow, then adding it
    back to the result.

    Args:
        z: Input tensor. It is accessed through ``z.real``, so complex input
            is supported; the shift uses only the real part.
        dim: Single integer dimension to reduce over.

    Returns:
        Tensor with ``dim`` removed. A slice whose entries all have real part
        ``-inf`` (i.e. a sum of zeros in log-space) yields ``-inf`` rather
        than NaN.
    """
    # Keep the reduced dimension so the shift broadcasts against `z`.
    m = z.real.max(dim=dim, keepdim=True).values

    # If the maximum is not finite (-inf for an all-zero slice, +inf, or NaN),
    # `z - m` would evaluate inf - inf and poison the result with NaN. Shifting
    # by 0 instead leaves such slices untouched and lets exp/log produce the
    # correct limit on their own.
    m = torch.where(torch.isfinite(m), m, torch.zeros_like(m))

    log_sum_exp_part = (z - m).exp().sum(dim=dim).log()

    # Drop the kept dimension so `m` matches the reduced shape before adding.
    return log_sum_exp_part + m.squeeze(dim)

# --- Setup and Data Loading ---

# Run on the GPU when one is visible to torch, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
physical_batch_size = 16
n_input = 28 * 28

# Symbolic MNIST circuit, plus the product of that circuit with itself.
builder_kwargs = dict(
    num_input_units=16,
    num_sum_units=16,
    region_graph="quad-tree-4",
)
builder = CIRCUIT_BUILDERS["MNIST"]
symbolic = builder(**builder_kwargs)
squared = SF.multiply(symbolic, symbolic)

# MNIST test split: each image is converted to a tensor, flattened to one
# dimension, multiplied by 255 and cast to integers.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Lambda(lambda x: (255 * x.view(-1)).long()),
    ]
)
data_test = datasets.MNIST(
    root="./.data",
    train=False,
    download=True,
    transform=transform,
)
test_dataloader = DataLoader(data_test, batch_size=physical_batch_size, shuffle=False)

# Only the first batch is evaluated; its labels are discarded.
batch, _ = next(iter(test_dataloader))
batch = batch.to(device)

# --- 1. Original Model NLL Calculation ---

print(f"\n[{datetime.now()}] --- Benchmarking ORIGINAL model ---", flush=True)
original_circuit = compile_symbolic(squared, device=device, rank=None, opt=False)

# Load the pretrained weights from the checkpoint cache; fail fast if absent.
units = 16
cache_path = f"./model_cache/checkpoints/mnist_{units}_{units}_epoch10.pt"
if not os.path.exists(cache_path):
    raise FileNotFoundError(f"Checkpoint not found at {cache_path}")
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (or pass weights_only=True if the checkpoint allows it).
original_circuit.load_state_dict(torch.load(cache_path, map_location=device)["model_state_dict"])

# Evaluation only: disable autograd so the integrate and forward passes do not
# retain every intermediate activation for a backward pass that never happens.
with torch.no_grad():
    # Normalizer Z: integrate the circuit over all of its variables.
    iq_orig = IntegrateQuery(original_circuit)
    Z_bok_orig = iq_orig(batch, integrate_vars=Scope(original_circuit.scope))

    # NLL = -(circuit output - normalizer), using real parts only. The
    # normalizer is read from entry [0][0] of the integrate result.
    circuit_output_real = original_circuit(batch).real
    nll_orig = -(circuit_output_real - Z_bok_orig[0][0].real)

print("\n--- NLL Calculation for Original Model ---")
print(f"Batch shape: {batch.shape}")
print(f"Normalizer Z_bok shape: {Z_bok_orig.shape}")
print(f"Circuit Output shape: {circuit_output_real.shape}")
print(f"NLL shape: {nll_orig.shape}")
print(f"Average NLL for the Original batch: {nll_orig.mean().item()}")

# --- 2. Nyström Approximated Model NLL Calculation ---

print(f"\n[{datetime.now()}] --- Benchmarking NYSTRÖM model ---", flush=True)
# NOTE(review): the circuit is compiled with rank=254, but the sync call below
# passes rank=63, which overwrites `rank` on every NystromSumLayer. Confirm
# which of the two values is intended.
nystrom_circuit = compile_symbolic(squared, device=device, rank=254, opt=True)

# Rebuild the Nyström layers from the trained original model's sum layers.
sync_sumlayer_weights(original_circuit, nystrom_circuit, pivot="l2", rank=63)

# Evaluation only: disable autograd so the integrate and forward passes do not
# retain every intermediate activation for a backward pass that never happens.
with torch.no_grad():
    # Normalizer Z for the Nyström model, integrated over all of its variables.
    iq_nys = IntegrateQuery(nystrom_circuit)
    Z_bok_nys = iq_nys(batch, integrate_vars=Scope(nystrom_circuit.scope))

    # Same NLL formula as for the original model: real parts only, with the
    # normalizer read from entry [0][0] of the integrate result.
    nystrom_output_real = nystrom_circuit(batch).real
    nll_nys = -(nystrom_output_real - Z_bok_nys[0][0].real)

print("\n--- NLL Calculation for Nyström Approximated Model ---")
print(f"Normalizer Z_bok_nys shape: {Z_bok_nys.shape}")
print(f"Nyström Output shape: {nystrom_output_real.shape}")
print(f"NLL Nyström shape: {nll_nys.shape}")
print(f"Average NLL for the Nyström batch: {nll_nys.mean().item()}")


[2025-08-12 09:16:53.841536] --- Benchmarking ORIGINAL model ---
Compiled circuit with rank None on device cuda
logits shape: torch.Size([1, 256, 256])


IndexError: Dimension out of range (expected to be in range of [-3, 2], but got 3)