In [None]:
# ==== Hyperparameters (aligned with attached CIFAR notebook) ====
BATCH_SIZE       = 128
GATE_OPTIMIZER   = 'softmax'   # 'softmax' | 'gumbel_softmax' | 'sparsemax'
NETWORK_LR       = 0.01
GUMBEL_TAU       = 1
GROUP_SUM_TAU    = 100
NETWORK_LAYERS   = 8
GATES            = 128_000
EPOCHS           = 2000
NOISE_TEMP = 0.1
# Artifacts
LOG_CSV          = "cifar_run_log_new.csv"           # append-only
FINAL_STATS_JSON = "cifar_final_stats.json"
PLOT_PNG         = "cifar_acc_discrete_vs_eval.png"

In [14]:


# ==== Imports ====
import os, csv, json, time
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np

from difflogic import LogicLayer, GroupSum
from difflogic.packbitstensor import PackBitsTensor


In [15]:

# ==== Device ====
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [16]:

# Note: The attached notebook uses CIFAR-10 with Normalize(0.5, 0.5, 0.5), keep the same to avoid deviating. 
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)),
])

train_ds = datasets.CIFAR10(root='./data', train=True,  download=True, transform=transform)
test_ds  = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,  num_workers=4, pin_memory=True)
test_loader  = DataLoader(test_ds,  batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True)

Files already downloaded and verified
Files already downloaded and verified


In [17]:
# ==== Model: 3x32x32 -> 3072 input, LogicLayer stack, GroupSum(10) ====
def build_model():
    layers = [
        nn.Flatten(),
        LogicLayer(3072, GATES, device='cuda', implementation='cuda',
                   gate_function=GATE_OPTIMIZER, gumbel_tau=GUMBEL_TAU)
    ]
    for _ in range(NETWORK_LAYERS - 1):
        layers.append(
            LogicLayer(GATES, GATES, device='cuda', implementation='cuda',
                       gate_function=GATE_OPTIMIZER, gumbel_tau=GUMBEL_TAU)
        )
    layers.append(GroupSum(10, tau=GROUP_SUM_TAU))
    return nn.Sequential(*layers).to(device)

model = build_model()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=NETWORK_LR)


In [18]:
@torch.no_grad()
def eval_accuracy_float(model, loader, mode='eval'):
    orig = model.training
    model.train(mode == 'train')
    total, correct = 0, 0
    for x, y in loader:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)
        logits = model(x)
        pred = logits.argmax(dim=1)
        correct += (pred == y).sum().item()
        total += x.size(0)
    model.train(orig)
    return correct / max(1, total)

def packbits_eval(model, loader):
    orig_mode = model.training
    start = time.perf_counter()
    with torch.no_grad():
        model.eval()
        res = np.mean([
            (
                model(
                    PackBitsTensor(
                        x.to('cuda').reshape(x.shape[0], -1).round().bool()
                    )
                ).argmax(-1) == y.to('cuda')
            ).to(torch.float32).mean().item()
            for x, y in loader
        ])
    model.train(mode=orig_mode)
    elapsed = time.perf_counter() - start
    throughput = len(loader.dataset) / elapsed
    print(f"throughput : {throughput:.1f}/s")
    return float(res)


In [None]:
# ==== Append-safe CSV header ====
fieldnames = [
    "epoch", "train_loss", "train_acc",
    "float_eval_acc", "float_trainmode_acc", "discrete_acc",
    "BATCH_SIZE", "GATE_OPTIMIZER", "NETWORK_LR", "GUMBEL_TAU",
    "GROUP_SUM_TAU", "NETWORK_LAYERS", "GATES", "EPOCHS", "CONNECTIONS", "NOISE_TEMP"
]
header_needed = (not os.path.exists(LOG_CSV)) or (os.path.getsize(LOG_CSV) == 0)
with open(LOG_CSV, "a", newline="") as f:
    writer = csv.DictWriter(f, fieldnames=fieldnames)
    if header_needed:
        writer.writeheader()


In [20]:
def change_gate_optimizer(model, new_gate_optimizer):
    for module in model.modules():
        if isinstance(module, LogicLayer):
            module.gate_function = new_gate_optimizer


In [None]:
GATE_OPTIMIZER = "sp"
change_gate_optimizer(model, GATE_OPTIMIZER)


NameError: name 'new_opt' is not defined

In [None]:
# ==== Training loop with logging and evaluations ====
eval_acc_hist = []
disc_acc_hist = []

for epoch in range(1, EPOCHS + 1):
    model.train()
    running_loss = 0.0
    correct = 0
    seen = 0

    for x, y in train_loader:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * x.size(0)
        correct += (logits.argmax(1) == y).sum().item()
        seen += x.size(0)

    train_loss = running_loss / max(1, seen)
    train_acc  = correct / max(1, seen)

    # Float eval vs train-mode
    float_eval_acc      = eval_accuracy_float(model, test_loader, mode='eval')
    float_trainmode_acc = eval_accuracy_float(model, test_loader, mode='train')

    # Discrete-style evaluation (PackBitsTensor)
    discrete_acc = float_eval_acc

    print(f"Epoch {epoch}: "
          f"train_loss={train_loss:.4f} train_acc={train_acc:.4f} "
          f"float_eval_acc={float_eval_acc:.4f} float_train_acc={float_trainmode_acc:.4f} "
          f"discrete_acc={discrete_acc:.4f}")

    eval_acc_hist.append(float_eval_acc)
    disc_acc_hist.append(discrete_acc)

    # Append log row
    with open(LOG_CSV, "a", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writerow(dict(
            epoch=epoch, train_loss=train_loss, train_acc=train_acc,
            float_eval_acc=float_eval_acc, float_trainmode_acc=float_trainmode_acc,
            discrete_acc=discrete_acc, 
            BATCH_SIZE=BATCH_SIZE, GATE_OPTIMIZER=GATE_OPTIMIZER, NETWORK_LR=NETWORK_LR,
            GUMBEL_TAU=GUMBEL_TAU, GROUP_SUM_TAU=GROUP_SUM_TAU,
            NETWORK_LAYERS=NETWORK_LAYERS, GATES=GATES, EPOCHS=EPOCHS
        ))


Epoch 1: train_loss=2.2751 train_acc=0.1337 float_eval_acc=0.1573 float_train_acc=0.2316 discrete_acc=0.1573
Epoch 2: train_loss=2.0051 train_acc=0.2858 float_eval_acc=0.1600 float_train_acc=0.3261 discrete_acc=0.1600
Epoch 3: train_loss=1.8439 train_acc=0.3515 float_eval_acc=0.1813 float_train_acc=0.3719 discrete_acc=0.1813
Epoch 4: train_loss=1.7475 train_acc=0.3913 float_eval_acc=0.1852 float_train_acc=0.3931 discrete_acc=0.1852
Epoch 5: train_loss=1.6810 train_acc=0.4174 float_eval_acc=0.1830 float_train_acc=0.4061 discrete_acc=0.1830
Epoch 6: train_loss=1.6326 train_acc=0.4337 float_eval_acc=0.1928 float_train_acc=0.4191 discrete_acc=0.1928
Epoch 7: train_loss=1.5908 train_acc=0.4522 float_eval_acc=0.2034 float_train_acc=0.4364 discrete_acc=0.2034
Epoch 8: train_loss=1.5544 train_acc=0.4654 float_eval_acc=0.2039 float_train_acc=0.4445 discrete_acc=0.2039
Epoch 9: train_loss=1.5224 train_acc=0.4784 float_eval_acc=0.1995 float_train_acc=0.4485 discrete_acc=0.1995
Epoch 10: train_los

In [None]:
# ==== Plot: discrete vs float-eval accuracy ====
plt.figure(figsize=(7, 4))
steps = np.arange(1, len(eval_acc_hist) + 1)
plt.plot(steps, eval_acc_hist, marker='o', linewidth=1.8, markersize=4, label='float eval')
plt.plot(steps, disc_acc_hist, marker='s', linewidth=1.8, markersize=4, label='discrete (PackBits)')
plt.title('CIFAR-10 Accuracy per epoch (discrete vs eval)')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.grid(True, linestyle='--', alpha=0.4)
plt.legend()
plt.tight_layout()
plt.savefig(PLOT_PNG, dpi=150)
plt.show()


In [None]:


# ==== Data: CIFAR-10 (as in the notebook) ====






# ==== Save final stats ====
final_stats = dict(
    final_epoch=EPOCHS,
    final_train_loss=train_loss,
    final_train_acc=train_acc,
    final_float_eval_acc=eval_acc_hist[-1],
    final_discrete_acc=disc_acc_hist[-1],
    hyperparams=dict(
        BATCH_SIZE=BATCH_SIZE,
        GATE_OPTIMIZER=GATE_OPTIMIZER,
        NETWORK_LR=NETWORK_LR,
        GUMBEL_TAU=GUMBEL_TAU,
        GROUP_SUM_TAU=GROUP_SUM_TAU,
        NETWORK_LAYERS=NETWORK_LAYERS,
        GATES=GATES,
        EPOCHS=EPOCHS,
    ),
    log_csv=LOG_CSV,
    acc_plot=PLOT_PNG,
)
with open(FINAL_STATS_JSON, "w") as f:
    json.dump(final_stats, f, indent=2)
print(f"Saved final stats to {FINAL_STATS_JSON}")
