In [24]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from datetime import datetime
import os
import time
import faulthandler

# Assuming the necessary components from cirkit and your project are available
from src.circuit_types import CIRCUIT_BUILDERS
from cirkit.symbolic.circuit import Circuit
from cirkit.pipeline import PipelineContext, compile as compile_circuit
import cirkit.symbolic.functional as SF
from cirkit.backend.torch.queries import IntegrateQuery
from cirkit.utils.scope import Scope
from cirkit.backend.torch.layers import TorchSumLayer
from cirkit.backend.torch.layers import TorchCategoricalLayer
from src.nystromlayer import NystromSumLayer

# --- Helper Functions ---

def compile_symbolic(circuit: Circuit, *, device: str, rank: int | None = None, opt: bool = False):
    """Compile a symbolic circuit with optional Nyström optimization."""
    ctx = PipelineContext(
        backend="torch",
        semiring="complex-lse-sum",
        fold=False,
        optimize=opt,
        nystrom_rank=rank,
    )
    compiled = compile_circuit(circuit, ctx, nystrom_rank=rank).to(device).eval()
    print(f"Compiled circuit with rank {rank} on device {device}", flush=True)
    return compiled


def sync_sumlayer_weights(
    original: nn.Module, nystrom: nn.Module, *, pivot: str = "uniform", rank: int | None = None
) -> None:
    """Copy weights from ``original`` to ``nystrom`` for matching layers."""
    # Sync for TorchSumLayer and NystromSumLayer
    orig_sum_layers = [m for m in original.modules() if isinstance(m, TorchSumLayer)]
    nys_sum_layers = [m for m in nystrom.modules() if isinstance(m, NystromSumLayer)]
    if len(orig_sum_layers) != len(nys_sum_layers):
        print(f"{len(orig_sum_layers)},{len(nys_sum_layers)}")
        raise ValueError("Sum layer count mismatch when syncing weights")

    faulthandler.enable(all_threads=True)
    total_sum = len(orig_sum_layers)
    interval_sum = max(1, total_sum // 4)

    for i, (o, n) in enumerate(zip(orig_sum_layers, nys_sum_layers), start=1):
        start = time.perf_counter()
        faulthandler.dump_traceback_later(30, repeat=False)
        try:
            if rank is not None:
                n.rank = int(rank)
                n.rank_param.data.fill_(n.rank)
            n.pivot = pivot
            n._build_factors_from(o)
            if torch.cuda.is_available():
                torch.cuda.synchronize()
        except Exception as e:
            print(f"[{datetime.now()}] Exception on sum layer {i}/{total_sum}: {e}", flush=True)
            raise
        finally:
            faulthandler.cancel_dump_traceback_later()

        if i % interval_sum == 0 or i == total_sum:
            pct = int(100 * i / total_sum)
            print(f"[{datetime.now()}] Sum Layer Weight Sync Progress: {pct}% ({i}/{total_sum})", flush=True)

    # --- Step 2: Sync Categorical Layers (NEW & IMPROVED LOGIC) ---
    print("\n--- Syncing Categorical Layer (logits/probs) by Parameter Name ---", flush=True)

    # Create a dictionary of the original model's parameters for fast lookup.
    # This is the 'a' from your working snippet.
    original_params = dict(original.named_parameters())

    copied_params_count = 0
    with torch.no_grad():
        # Iterate through the Nystrom model's named parameters.
        for name, nys_param in nystrom.named_parameters():
            # Check if the parameter is a logit or probability tensor by its name.
            if 'logits' in name or 'probs' in name:
                if name in original_params:
                    # Use the proven method: copy the data directly.
                    # .copy_() is a safe, in-place operation.
                    orig_param = original_params[name]
                    nys_param.data.copy_(orig_param.data)
                    copied_params_count += 1
                else:
                    # This warning is helpful for debugging architecture mismatches.
                    print(f"Warning: Parameter '{name}' found in Nystrom model but not in original.")
    
    print(f"--- Finished syncing. Copied {copied_params_count} categorical parameters. ---\n",)
def complex_logsumexp(z, dim):
    """Numerically stable complex log-sum-exp, corrected for broadcasting."""
    # Keep dimension for stable subtraction
    m = z.real.max(dim=dim, keepdim=True).values
    
    # Calculate the log-sum-exp part
    log_sum_exp_part = (z - m).exp().sum(dim=dim).log()
    
    # Squeeze `m` to match the shape of `log_sum_exp_part` for correct addition
    return log_sum_exp_part + m.squeeze(dim)

# --- Setup and Data Loading ---

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
physical_batch_size = 64
n_input = 28 * 28

# Build the symbolic circuit
builder = CIRCUIT_BUILDERS["MNIST"]
builder_kwargs = {"num_input_units": 16, "num_sum_units": 16, "region_graph": "quad-tree-4"}
symbolic = builder(**builder_kwargs)
squared = SF.multiply(symbolic, symbolic)

# Prepare the MNIST test data
transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda x: (255 * x.view(-1)).long())])
data_test = datasets.MNIST(root="./.data", train=False, download=True, transform=transform)
test_dataloader = DataLoader(data_test, shuffle=False, batch_size=physical_batch_size)
batch, _ = next(iter(test_dataloader))
batch = batch.to(device)

# --- 1. Original Model NLL Calculation ---

print(f"\n[{datetime.now()}] --- Benchmarking ORIGINAL model ---", flush=True)
original_circuit = compile_symbolic(squared, device=device, rank=None, opt=False)

# Load the pretrained model from cache
units = 16
cache_path = f"./model_cache/checkpoints/mnist_{units}_{units}_epoch10.pt"
if os.path.exists(cache_path):
    original_circuit.load_state_dict(torch.load(cache_path, map_location=device)["model_state_dict"])
else:
    raise FileNotFoundError(f"Checkpoint not found at {cache_path}")

# Calculate Normalizer Z
iq_orig = IntegrateQuery(original_circuit)
Z_bok_orig = iq_orig(batch, integrate_vars=Scope(original_circuit.scope))


# Get circuit output and compute NLL
circuit_output_real = original_circuit(batch).real
nll_orig = -(circuit_output_real - Z_bok_orig[0][0].real)

print("\n--- NLL Calculation for Original Model ---")
print(f"Batch shape: {batch.shape}")
print(f"Normalizer Z_bok {Z_bok_orig[0][0]}")
print(f"Circuit Output shape: {circuit_output_real.shape}")
print(f"NLL shape: {nll_orig.shape}")
print(f"Average NLL for the Original batch: {nll_orig.mean()}")

# --- 2. Nyström Approximated Model NLL Calculation ---

print(f"\n[{datetime.now()}] --- Benchmarking NYSTRÖM model ---", flush=True)
nystrom_circuit = compile_symbolic(squared, device=device, rank=256, opt=True)


# Synchronize weights from the trained original model
sync_sumlayer_weights(original_circuit, nystrom_circuit, pivot="uniform", rank=77)

# Calculate Normalizer Z for the Nystrom model
iq_nys = IntegrateQuery(nystrom_circuit)
Z_bok_nys = iq_nys(batch, integrate_vars=Scope(nystrom_circuit.scope))





# Get circuit output and compute NLL, APPLYING THE SAME FIX
nystrom_output_real = nystrom_circuit(batch).real
nll_nys = -(nystrom_output_real - Z_bok_nys[0][0].real)

print("\n--- NLL Calculation for Nyström Approximated Model ---")
print(f"Normalizer Z_bok_nys: {Z_bok_nys[0][0]}")
print(f"Nyström Output shape: {nystrom_output_real.shape}")
print(f"NLL Nyström shape: {nll_nys.shape}")
print(f"Average NLL for the Nyström batch: {nll_nys.mean()}")
orig_bpd = (nll_orig.mean() / (784 )).item()
nystrom_bpd = (nll_nys.mean() / (784)).item()
bpd_diff = abs(orig_bpd - nystrom_bpd)
print(orig_bpd,nystrom_bpd,bpd_diff)


[2025-08-12 14:52:08.531394] --- Benchmarking ORIGINAL model ---
Compiled circuit with rank None on device cuda

--- NLL Calculation for Original Model ---
Batch shape: torch.Size([64, 784])
Normalizer Z_bok tensor([-78.2435+0.j], device='cuda:0', grad_fn=<SelectBackward0>)
Circuit Output shape: torch.Size([64, 1, 1])
NLL shape: torch.Size([64, 1, 1])
Average NLL for the Original batch: 1270.25927734375

[2025-08-12 14:52:14.713304] --- Benchmarking NYSTRÖM model ---
Compiled circuit with rank 256 on device cuda
[2025-08-12 14:52:35.460575] Sum Layer Weight Sync Progress: 24% (262/1049)
[2025-08-12 14:52:39.892844] Sum Layer Weight Sync Progress: 49% (524/1049)
[2025-08-12 14:52:44.311997] Sum Layer Weight Sync Progress: 74% (786/1049)
[2025-08-12 14:52:48.841489] Sum Layer Weight Sync Progress: 99% (1048/1049)
[2025-08-12 14:52:48.844181] Sum Layer Weight Sync Progress: 100% (1049/1049)

--- Syncing Categorical Layer (logits/probs) by Parameter Name ---
--- Finished syncing. Copied 7

In [19]:
def verify_full_model_output(
    original_model: nn.Module,
    nystrom_model: nn.Module,
    verification_batch: torch.Tensor,
    *,
    use_allclose: bool = True,
    rtol: float = 1e-5,
    atol: float = 1e-6
) -> bool:
    """
    Verifies that two models produce the same output for a given input batch.

    This is the most reliable way to check if weight synchronization was successful.

    Args:
        original_model: The original, trained model.
        nystrom_model: The Nystrom-approximated model.
        verification_batch: A batch of real data for testing.
        use_allclose: If True, use torch.allclose for float comparison.
        rtol: Relative tolerance for torch.allclose.
        atol: Absolute tolerance for torch.allclose.

    Returns:
        True if the outputs are identical within the given tolerance.
    """
    print("\n--- Verifying Full Model Output Post-Sync ---")
    original_model.eval()
    nystrom_model.eval()

    with torch.no_grad():
        original_output = original_model(verification_batch)
        nystrom_output = nystrom_model(verification_batch)

    # The output is complex, so torch.allclose is the correct tool
    if use_allclose:
        are_outputs_equal = torch.allclose(original_output, nystrom_output, rtol=rtol, atol=atol)
    else:
        are_outputs_equal = torch.equal(original_output, nystrom_output)

    if not are_outputs_equal:
        print("Verification FAILED: Full model outputs do not match.")
        # Provide debug info if they don't match
        real_diff = torch.abs(original_output.real - nystrom_output.real).sum()
        imag_diff = torch.abs(original_output.imag - nystrom_output.imag).sum()
        print(f"Sum of absolute difference (real part): {real_diff.item()}")
        print(f"Sum of absolute difference (imag part): {imag_diff.item()}")
    else:
        print("Verification PASSED: Full model outputs are identical (within tolerance).")

    return are_outputs_equal

In [10]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from datetime import datetime
import os
import time
import faulthandler
from typing import Type

# Assuming the necessary components from cirkit and your project are available
from src.circuit_types import CIRCUIT_BUILDERS
from cirkit.symbolic.circuit import Circuit
from cirkit.pipeline import PipelineContext, compile as compile_circuit
import cirkit.symbolic.functional as SF
from cirkit.backend.torch.queries import IntegrateQuery
from cirkit.utils.scope import Scope
from cirkit.backend.torch.layers import TorchSumLayer
from cirkit.backend.torch.layers.input import TorchCategoricalLayer # Corrected import path
from src.nystromlayer import NystromSumLayer

# --- Helper Functions (Your existing functions) ---

def compile_symbolic(circuit: Circuit, *, device: str, rank: int | None = None, opt: bool = False):
    """Compile a symbolic circuit with optional Nyström optimization."""
    ctx = PipelineContext(
        backend="torch",
        semiring="complex-lse-sum",
        fold=False,
        optimize=opt,
        nystrom_rank=rank,
    )
    compiled = compile_circuit(circuit, ctx, nystrom_rank=rank).to(device).eval()
    print(f"Compiled circuit with rank {rank} on device {device}", flush=True)
    return compiled


def sync_sumlayer_weights(
    original: nn.Module, nystrom: nn.Module, *, pivot: str = "uniform", rank: int | None = None
) -> None:
    """Copy weights from ``original`` to ``nystrom`` for matching layers."""
    # Sync for TorchSumLayer and NystromSumLayer
    orig_sum_layers = [m for m in original.modules() if isinstance(m, TorchSumLayer)]
    nys_sum_layers = [m for m in nystrom.modules() if isinstance(m, NystromSumLayer)]
    if len(orig_sum_layers) != len(nys_sum_layers):
        print(f"{len(orig_sum_layers)},{len(nys_sum_layers)}")
        raise ValueError("Sum layer count mismatch when syncing weights")

    faulthandler.enable(all_threads=True)
    total_sum = len(orig_sum_layers)
    interval_sum = max(1, total_sum // 4)

    for i, (o, n) in enumerate(zip(orig_sum_layers, nys_sum_layers), start=1):
        start = time.perf_counter()
        faulthandler.dump_traceback_later(30, repeat=False)
        try:
            if rank is not None:
                n.rank = int(rank)
                n.rank_param.data.fill_(n.rank)
            n.pivot = pivot
            n._build_factors_from(o)
            if torch.cuda.is_available():
                torch.cuda.synchronize()
        except Exception as e:
            print(f"[{datetime.now()}] Exception on sum layer {i}/{total_sum}: {e}", flush=True)
            raise
        finally:
            faulthandler.cancel_dump_traceback_later()

        if i % interval_sum == 0 or i == total_sum:
            pct = int(100 * i / total_sum)
            print(f"[{datetime.now()}] Sum Layer Weight Sync Progress: {pct}% ({i}/{total_sum})", flush=True)

    # Sync for TorchCategoricalLayer
    orig_cat_layers = [m for m in original.modules() if isinstance(m, TorchCategoricalLayer)]
    nys_cat_layers = [m for m in nystrom.modules() if isinstance(m, TorchCategoricalLayer)]

    if len(orig_cat_layers) != len(nys_cat_layers):
        raise ValueError("Categorical layer count mismatch when syncing weights")

    total_cat = len(orig_cat_layers)
    interval_cat = max(1, total_cat // 4)

    for i, (o, n) in enumerate(zip(orig_cat_layers, nys_cat_layers), start=1):
        # In TorchCategoricalLayer, exactly one of logits or probs is not None.
        # The TorchParameter object must be called to return the underlying tensor.
        if o.logits is not None and n.logits is not None:
            n.logits().data.copy_(o.logits().data)
        elif o.probs is not None and n.probs is not None:
            n.probs().data.copy_(o.probs().data)

        if i % interval_cat == 0 or i == total_cat:
            pct = int(100 * i / total_cat)
            print(f"[{datetime.now()}] Categorical Layer Sync Progress: {pct}% ({i}/{total_cat})", flush=True)

# --- NEW: VERIFICATION FUNCTION ---
# (Paste the verify_first_layer_output function definition from above here)


# --- Setup and Data Loading ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
physical_batch_size = 64
n_input = 28 * 28
builder_kwargs = {"num_input_units": 4, "num_sum_units": 4, "region_graph": "quad-tree-4"}

# --- 1. Original Model Setup ---
print(f"\n[{datetime.now()}] --- Step 1: Setting up ORIGINAL model ---", flush=True)
builder = CIRCUIT_BUILDERS["MNIST"]
symbolic = builder(**builder_kwargs)
squared = SF.multiply(symbolic, symbolic)
original_circuit = compile_symbolic(squared, device=device, rank=None, opt=False)

# Load pretrained weights
units = builder_kwargs["num_input_units"]
cache_path = f"./model_cache/checkpoints/mnist_{units}_{units}_epoch10.pt"
if os.path.exists(cache_path):
    original_circuit.load_state_dict(torch.load(cache_path, map_location=device)["model_state_dict"])
else:
    raise FileNotFoundError(f"Checkpoint not found at {cache_path}. Cannot proceed.")

# --- 2. Nyström Model Setup and Weight Sync ---
print(f"\n[{datetime.now()}] --- Step 2: Setting up NYSTRÖM model and syncing weights ---", flush=True)
nystrom_circuit = compile_symbolic(squared, device=device, rank=256, opt=True)
sync_sumlayer_weights(original_circuit, nystrom_circuit, pivot="uniform", rank=256)


# --- 3. Verification of Layer Outputs (NEW STEP) ---
print(f"\n[{datetime.now()}] --- Step 3: Verifying layer outputs post-sync ---", flush=True)




[2025-08-12 14:20:09.670681] --- Step 1: Setting up ORIGINAL model ---
Compiled circuit with rank None on device cuda

[2025-08-12 14:20:14.016409] --- Step 2: Setting up NYSTRÖM model and syncing weights ---
Compiled circuit with rank 256 on device cuda
[2025-08-12 14:20:17.542480] Sum Layer Weight Sync Progress: 24% (262/1049)
[2025-08-12 14:20:18.194550] Sum Layer Weight Sync Progress: 49% (524/1049)
[2025-08-12 14:20:18.843847] Sum Layer Weight Sync Progress: 74% (786/1049)
[2025-08-12 14:20:19.496502] Sum Layer Weight Sync Progress: 99% (1048/1049)
[2025-08-12 14:20:19.498726] Sum Layer Weight Sync Progress: 100% (1049/1049)
[2025-08-12 14:20:19.589956] Categorical Layer Sync Progress: 25% (196/784)
[2025-08-12 14:20:19.651380] Categorical Layer Sync Progress: 50% (392/784)
[2025-08-12 14:20:19.712486] Categorical Layer Sync Progress: 75% (588/784)
[2025-08-12 14:20:19.773617] Categorical Layer Sync Progress: 100% (784/784)

[2025-08-12 14:20:19.773983] --- Step 3: Verifying laye

In [20]:
# --- 3. Load Real Data and Verify Model Outputs ---
print(f"\n[{datetime.now()}] --- Step 3: Loading data for verification ---", flush=True)
transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda x: (255 * x.view(-1)).long())])
data_test = datasets.MNIST(root="./.data", train=False, download=True, transform=transform)
test_dataloader = DataLoader(data_test, shuffle=False, batch_size=physical_batch_size)
verification_batch, _ = next(iter(test_dataloader))
verification_batch = verification_batch.to(device)

# Use the loaded batch to verify the full models are equivalent
verify_full_model_output(original_circuit, nystrom_circuit, verification_batch)

# --- 4. NLL Calculation and Comparison ---
print(f"\n[{datetime.now()}] --- Step 4: Calculating NLL using the same batch ---", flush=True)
# NLL for Original Model
iq_orig = IntegrateQuery(original_circuit)
Z_bok_orig = iq_orig(verification_batch, integrate_vars=Scope(original_circuit.scope))
circuit_output_real = original_circuit(verification_batch).real
nll_orig = -(circuit_output_real - Z_bok_orig[0][0].real)

# NLL for Nystrom Model
iq_nys = IntegrateQuery(nystrom_circuit)
Z_bok_nys = iq_nys(verification_batch, integrate_vars=Scope(nystrom_circuit.scope))
nystrom_output_real = nystrom_circuit(verification_batch).real
nll_nys = -(nystrom_output_real - Z_bok_nys[0][0].real)

# --- Final Results ---
print("\n--- Final NLL and BPD Results ---")
print(f"Average NLL for the Original batch: {nll_orig.mean():.4f}")
print(f"Average NLL for the Nyström batch: {nll_nys.mean():.4f}")

orig_bpd = (nll_orig.mean() / (n_input * 2)).item() # Squared model has 2x input vars
nystrom_bpd = (nll_nys.mean() / (n_input * 2)).item()
bpd_diff = abs(orig_bpd - nystrom_bpd)

print(f"Original BPD: {orig_bpd:.4f}")
print(f"Nystrom BPD:  {nystrom_bpd:.4f}")
print(f"BPD Difference: {bpd_diff:.6f}")


[2025-08-12 14:25:39.262236] --- Step 3: Loading data for verification ---

--- Verifying Full Model Output Post-Sync ---
Verification FAILED: Full model outputs do not match.
Sum of absolute difference (real part): 501823.9375
Sum of absolute difference (imag part): 3.6545700493606503e-11

[2025-08-12 14:25:40.570812] --- Step 4: Calculating NLL using the same batch ---

--- Final NLL and BPD Results ---
Average NLL for the Original batch: 1347.3545
Average NLL for the Nyström batch: 5656.5347
Original BPD: 0.8593
Nystrom BPD:  3.6075
BPD Difference: 2.748202


In [31]:
img, label = next(iter(DataLoader(data_test, batch_size=1)))
img.shape

torch.Size([1, 784])