# Binary Gibbs Sampler Validation

Minimal notebook that validates the correctness of the direct binary Gibbs sampler.


In [None]:
# Setup cell: make the project package importable, load the sampler and metric under
# test, and seed torch's global RNG.
import sys
# '..' is resolved against the current working directory, not this notebook's location —
# assumes the kernel is started from a direct subdirectory of the repo root; TODO confirm.
# Must run before the qllm_pbits imports below.
sys.path.insert(0, '..')

import torch
import numpy as np
from qllm_pbits.pbits.gibbs_binary import BinaryGibbsSampler
from qllm_pbits.pbits.metrics import invalid_onehot_rate

# Seeds torch's global RNG (used by torch.randn in later cells). Whether the sampler
# itself draws from this RNG depends on its implementation — confirm for full reproducibility.
torch.manual_seed(42)
print("✓ Imports successful")


## Test 1: Sampler produces binary outputs


In [None]:
# Test 1: draw one state from the sampler and verify every entry is exactly 0 or 1.
V = 8
logits = torch.randn(V)
sampler = BinaryGibbsSampler(beta=1.0, device='cpu')

# lam presumably weights the one-hot constraint (see Test 2) — confirm against the sampler.
y = sampler.sample(logits, lam=20.0, n_steps=100, burn_in=50)

# Compute the check once so the printed value and the assertion cannot disagree.
is_binary = bool(((y == 0) | (y == 1)).all())

print(f"Sampled state: {y}")
# The sum is informational only: a finite lam does not strictly forbid non-one-hot states.
print(f"Sum: {y.sum().item():.2f} (target: 1.0)")
print(f"All binary: {is_binary}")

# Actually enforce the property this test is named after instead of always reporting success.
assert is_binary, f"Sampler produced non-binary values: {y}"
print("✓ Test passed")


## Test 2: Invalid one-hot rate decreases as lambda increases


In [None]:
# Test 2: sweep lam and estimate the invalid one-hot rate at each value.
# Reuses `sampler` and `logits` from the Test 1 cell, which must run first.
lambda_values = [5.0, 10.0, 20.0, 50.0]
n_samples = 50
results = []

for lam in lambda_values:
    samples = [
        sampler.sample(logits, lam=lam, n_steps=100, burn_in=50)
        for _ in range(n_samples)
    ]
    # float() accepts either a Python number or a 0-dim tensor from the metric.
    invalid_rate = float(invalid_onehot_rate(torch.stack(samples)))
    results.append((lam, invalid_rate))
    print(f"λ={lam:5.1f}: invalid rate = {invalid_rate:.4f}")

# Verify the claim rather than printing it unconditionally. Each rate is a finite-sample
# estimate, so adjacent lambdas may fluctuate. Compare only the extremes of the sweep.
(low_lam, low_lam_rate), (high_lam, high_lam_rate) = results[0], results[-1]
assert high_lam_rate <= low_lam_rate, (
    f"Invalid rate did not decrease: λ={low_lam} -> {low_lam_rate:.4f}, "
    f"λ={high_lam} -> {high_lam_rate:.4f}"
)

print("\n✓ Higher lambda reduces invalid rate (as expected)")
