In [157]:
# ==== Run configuration ====

# --- Optimisation ---
BATCH_SIZE  = 128
EPOCHS      = 20
NETWORK_LR  = 3e-4   # lower LR, used for the "mean" parameters
VARIANCE_LR = 1e-3   # higher LR, used for the "variance" (logvar) parameters
KL_BETA     = 1e-10  # weight of the KL regularisation term in the total loss

# --- Architecture ---
GATE_OPTIMIZER = 'balenas'  # passed to every LogicLayer as gate_function
NETWORK_LAYERS = 2          # number of LogicLayers in the model
GATES          = 900        # output width of every LogicLayer
CONNECTIONS    = 'random'   # passed to every LogicLayer as connections
GUMBEL_TAU     = 0.25       # passed to every LogicLayer as gumbel_tau
ENTMAX_ALPHA   = 1.5        # passed to every LogicLayer as entmax_alpha
GROUP_SUM_TAU  = 30         # passed to GroupSum as tau

# --- Output artefacts ---
LOG_CSV          = "../logs/run_log_BALENAS_connections.csv"  # one row appended per epoch
FINAL_STATS_JSON = "final_stats.json"          # not referenced in the cells shown here
PLOT_PNG         = "acc_discrete_vs_eval.png"  # not referenced in the cells shown here

In [158]:
# ==== Imports ====
import os, csv, json, time
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np

from difflogic import LogicLayer, GroupSum
from difflogic.packbitstensor import PackBitsTensor


In [159]:
# Target device for the model and every batch. The LogicLayers in build_model
# are likewise constructed with device='cuda', so this notebook is CUDA-only.
device = torch.device("cuda")

In [160]:
# ToTensor is the only preprocessing step applied to the MNIST images.
transform = transforms.Compose([transforms.ToTensor()])

# Both splits share the same root directory, download flag and transform.
mnist_kwargs = dict(root='../data', download=True, transform=transform)
train_ds = datasets.MNIST(train=True, **mnist_kwargs)
test_ds = datasets.MNIST(train=False, **mnist_kwargs)

# Only the training loader shuffles; everything else is identical.
loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=4, pin_memory=True)
train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
test_loader = DataLoader(test_ds, shuffle=False, **loader_kwargs)


In [161]:
def build_model():
    """Assemble the logic-gate classifier and move it to ``device``.

    Layout: ``nn.Flatten`` -> one LogicLayer (784 -> GATES) ->
    ``NETWORK_LAYERS - 1`` further LogicLayers (GATES -> GATES) ->
    ``GroupSum(10, tau=GROUP_SUM_TAU)``.

    Every LogicLayer receives the same keyword configuration taken from the
    module-level constants.

    Returns:
        nn.Sequential: the assembled model, after ``.to(device)``.
    """
    layer_kwargs = dict(
        device='cuda',
        implementation='cuda',
        gate_function=GATE_OPTIMIZER,
        gumbel_tau=GUMBEL_TAU,
        connections=CONNECTIONS,
        entmax_alpha=ENTMAX_ALPHA,
    )

    # Input width of each LogicLayer in order: the flattened input first,
    # then GATES for every subsequent layer.
    in_widths = [784] + [GATES] * (NETWORK_LAYERS - 1)

    stack = [nn.Flatten()]
    stack.extend(LogicLayer(width, GATES, **layer_kwargs) for width in in_widths)
    stack.append(GroupSum(10, tau=GROUP_SUM_TAU))
    return nn.Sequential(*stack).to(device)


model = build_model()
criterion = nn.CrossEntropyLoss()


In [162]:
# Split the model parameters by name: anything whose name contains 'logvar'
# goes to the variance group, everything else to the mean group.
mean_params, variance_params = [], []
for param_name, param in model.named_parameters():
    bucket = variance_params if 'logvar' in param_name else mean_params
    bucket.append(param)

# One Adam optimizer, two parameter groups with separate learning rates.
optimizer = torch.optim.Adam([
    dict(params=mean_params, lr=NETWORK_LR),
    dict(params=variance_params, lr=VARIANCE_LR),
])


In [163]:
@torch.no_grad()
def eval_accuracy_float(model, loader, mode='eval'):
    """Compute classification accuracy over ``loader`` with gradients disabled.

    Batches are moved to ``device`` and fed to the model unchanged, so the
    model sees the loader's float inputs.

    Args:
        model: module to evaluate; called as ``model(x)`` and expected to
            return per-class scores with the class axis at dim 1.
        loader: iterable yielding ``(x, y)`` batches.
        mode: ``'train'`` runs the forward passes with ``model.train()``;
            any other value runs them with ``model.eval()``.

    Returns:
        float: fraction of correctly classified samples, or ``0.0`` when the
        loader yields no samples.

    The model's original train/eval flag is restored on exit, including when
    iteration is interrupted (for example by ``KeyboardInterrupt``), so an
    aborted evaluation cannot leave the model in the wrong mode.
    """
    orig = model.training
    # train(True) is train mode, train(False) is what eval() does.
    model.train(mode == 'train')
    total, correct = 0, 0
    try:
        for x, y in loader:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            logits = model(x)
            pred = logits.argmax(dim=1)
            correct += (pred == y).sum().item()
            total += x.size(0)
    finally:
        model.train(orig)
    # max(1, total) guards the division for an empty loader.
    return correct / max(1, total)

In [164]:
def packbits_eval(model, loader):
    """Evaluate accuracy on binarised inputs wrapped in ``PackBitsTensor``.

    Each batch is moved to CUDA, flattened to ``(batch, -1)``, rounded and
    cast to bool, wrapped in ``PackBitsTensor`` and passed through the model
    in eval mode with gradients disabled.

    Args:
        model: module to evaluate; must accept a ``PackBitsTensor`` input.
        loader: iterable yielding ``(x, y)`` batches.

    Returns:
        float: fraction of correctly classified samples, or ``0.0`` when the
        loader yields no samples.

    Side effects:
        Prints the measured throughput in samples per second.

    Accuracy is computed per sample (total correct / total seen) rather than
    as a mean of per-batch accuracies, so a short final batch is not
    over-weighted. Throughput uses the number of samples actually processed.
    The model's original train/eval flag is restored even if evaluation is
    interrupted.
    """
    orig_mode = model.training
    correct, total = 0, 0
    start = time.perf_counter()
    try:
        with torch.no_grad():
            model.eval()
            for x, y in loader:
                bits = x.to('cuda').reshape(x.shape[0], -1).round().bool()
                pred = model(PackBitsTensor(bits)).argmax(-1)
                correct += (pred == y.to('cuda')).sum().item()
                total += x.shape[0]
    finally:
        model.train(mode=orig_mode)
    elapsed = time.perf_counter() - start
    # Guard against a zero-length interval on an empty loader.
    throughput = total / elapsed if elapsed > 0 else 0.0
    print(f"throughput : {throughput:.1f}/s")
    return correct / max(1, total)

In [165]:
# Column order of the CSV log. Per-epoch metrics come first, followed by the
# hyperparameters, which are repeated on every row.
fieldnames = [
    "epoch", "train_loss", "train_acc", "kl_loss",
    "float_eval_acc", "float_trainmode_acc", "discrete_acc",
    "BATCH_SIZE", "GATE_OPTIMIZER", "NETWORK_LR", "VARIANCE_LR", "KL_BETA",
    "GUMBEL_TAU", "GROUP_SUM_TAU", "NETWORK_LAYERS", "GATES", "EPOCHS", "CONNECTIONS"
]

# open(..., "a") creates the file but not its parent directory, so create the
# directory first; otherwise a fresh checkout fails with FileNotFoundError.
log_dir = os.path.dirname(LOG_CSV)
if log_dir:
    os.makedirs(log_dir, exist_ok=True)

# Write the header only when the log is new or empty, so repeated runs keep
# appending rows to the same file without duplicating the header.
header_needed = (not os.path.exists(LOG_CSV)) or (os.path.getsize(LOG_CSV) == 0)
with open(LOG_CSV, "a", newline="") as f:
    writer = csv.DictWriter(f, fieldnames=fieldnames)
    if header_needed:
        writer.writeheader()

In [166]:
# Per-epoch accuracy histories, kept for plotting.
# NOTE(review): eval_acc_hist receives the train-mode accuracy and
# disc_acc_hist the eval-mode accuracy; confirm this matches the intended plot.
eval_acc_hist = []
disc_acc_hist = []

# The set of LogicLayers does not change during training, so collect it once
# instead of walking model.modules() on every batch.
logic_layers = [m for m in model.modules() if isinstance(m, LogicLayer)]

for epoch in range(1, EPOCHS + 1):
    model.train()
    running_loss = 0.0
    running_kl = 0.0
    correct = 0
    seen = 0

    for x, y in train_loader:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        logits = model(x)

        # Task loss
        task_loss = criterion(logits, y)

        # BaLeNAS KL regularization, summed over all LogicLayers
        kl_loss = sum(layer.get_kl_loss() for layer in logic_layers)

        # Total loss
        total_loss = task_loss + KL_BETA * kl_loss

        total_loss.backward()
        optimizer.step()

        # task_loss is weighted by batch size (averaged per sample below);
        # kl_loss is accumulated once per batch (averaged per batch below).
        running_loss += task_loss.item() * x.size(0)
        running_kl += kl_loss.item()
        correct += (logits.argmax(1) == y).sum().item()
        seen += x.size(0)

    train_loss = running_loss / max(1, seen)
    train_acc = correct / max(1, seen)
    avg_kl = running_kl / len(train_loader)

    # Float eval in eval-mode vs train-mode
    float_eval_acc = eval_accuracy_float(model, test_loader, mode='eval')
    float_trainmode_acc = eval_accuracy_float(model, test_loader, mode='train')

    # NOTE(review): this reuses the float eval-mode accuracy; packbits_eval is
    # not called, so "discrete_acc" is not a PackBits measurement.
    discrete_acc = float_eval_acc

    # Print concise progress
    print(f"Epoch {epoch}: "
          f"train_loss={train_loss:.4f} train_acc={train_acc:.4f} "
          f"kl_loss={avg_kl:.4f} "
          f"float_train_acc={float_trainmode_acc:.4f} "
          f"discrete_acc={discrete_acc:.4f}")

    # Track for plotting
    eval_acc_hist.append(float_trainmode_acc)
    # Append discrete_acc itself so the history follows whatever definition
    # of discrete accuracy is used above.
    disc_acc_hist.append(discrete_acc)

    # Log row with hyperparams repeated
    with open(LOG_CSV, "a", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writerow(dict(
            epoch=epoch, train_loss=train_loss, train_acc=train_acc, kl_loss=avg_kl,
            float_eval_acc=float_eval_acc, float_trainmode_acc=float_trainmode_acc,
            discrete_acc=discrete_acc,
            BATCH_SIZE=BATCH_SIZE, GATE_OPTIMIZER=GATE_OPTIMIZER,
            NETWORK_LR=NETWORK_LR, VARIANCE_LR=VARIANCE_LR, KL_BETA=KL_BETA,
            GUMBEL_TAU=GUMBEL_TAU, GROUP_SUM_TAU=GROUP_SUM_TAU,
            NETWORK_LAYERS=NETWORK_LAYERS, GATES=GATES, EPOCHS=EPOCHS,
            CONNECTIONS=CONNECTIONS
        ))

Epoch 1: train_loss=2.2987 train_acc=0.1155 kl_loss=144216.0907 float_train_acc=0.1851 discrete_acc=0.1155
Epoch 2: train_loss=2.2909 train_acc=0.1982 kl_loss=144422.1060 float_train_acc=0.1889 discrete_acc=0.1828
Epoch 3: train_loss=2.2803 train_acc=0.1919 kl_loss=144859.1979 float_train_acc=0.2000 discrete_acc=0.2521
Epoch 4: train_loss=2.2660 train_acc=0.2168 kl_loss=145542.2660 float_train_acc=0.2450 discrete_acc=0.3251
Epoch 5: train_loss=2.2480 train_acc=0.2711 kl_loss=146473.2450 float_train_acc=0.3112 discrete_acc=0.3882
Epoch 6: train_loss=2.2270 train_acc=0.3339 kl_loss=147646.5529 float_train_acc=0.3902 discrete_acc=0.4673
Epoch 7: train_loss=2.2038 train_acc=0.4097 kl_loss=149035.7175 float_train_acc=0.4556 discrete_acc=0.5136
Epoch 8: train_loss=2.1793 train_acc=0.4619 kl_loss=150631.9480 float_train_acc=0.5006 discrete_acc=0.5472
Epoch 9: train_loss=2.1545 train_acc=0.5042 kl_loss=152401.6394 float_train_acc=0.5351 discrete_acc=0.5586
Epoch 10: train_loss=2.1300 train_acc

In [None]:
# Set gate_function to 'softmax' on every LogicLayer; other modules are skipped.
# NOTE(review): only the attribute is overwritten. Whether LogicLayer also needs
# related internal parameters or settings re-initialised after the switch is not
# visible from this notebook; confirm against the LogicLayer implementation.
for module in model.modules():
    if not isinstance(module, LogicLayer):
        continue
    module.gate_function = 'softmax'


In [None]:
# Second training phase: same loop as before, but without the KL term.
# The accuracy histories are reset, so they describe this phase only.
eval_acc_hist = []
disc_acc_hist = []

# Log the gate function the layers actually report, not the GATE_OPTIMIZER
# constant: the layers' gate_function attribute may have been changed after
# the model was built, and the constant would then mislabel these rows.
# Falls back to GATE_OPTIMIZER if no LogicLayer exposes the attribute.
active_gate_function = next(
    (getattr(m, 'gate_function', GATE_OPTIMIZER)
     for m in model.modules() if isinstance(m, LogicLayer)),
    GATE_OPTIMIZER,
)

for epoch in range(1, EPOCHS + 1):
    model.train()
    running_loss = 0.0
    correct = 0
    seen = 0

    for x, y in train_loader:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

        # Weighted by batch size so the epoch figure is a per-sample average.
        running_loss += loss.item() * x.size(0)
        correct += (logits.argmax(1) == y).sum().item()
        seen += x.size(0)

    train_loss = running_loss / max(1, seen)
    train_acc  = correct / max(1, seen)

    # Float eval in eval-mode vs train-mode
    float_eval_acc       = eval_accuracy_float(model, test_loader, mode='eval')
    float_trainmode_acc  = eval_accuracy_float(model, test_loader, mode='train')

    # NOTE(review): this reuses the float eval-mode accuracy; packbits_eval is
    # not called, so "discrete_acc" is not a PackBits measurement.
    discrete_acc = float_eval_acc

    # Print concise progress
    print(f"Epoch {epoch}: "
          f"train_loss={train_loss:.4f} train_acc={train_acc:.4f} "
          f"float_train_acc={float_trainmode_acc:.4f} "
          f"discrete_acc={discrete_acc:.4f}")

    # Track for plotting
    eval_acc_hist.append(float_trainmode_acc)
    # Append discrete_acc itself so the history follows whatever definition
    # of discrete accuracy is used above.
    disc_acc_hist.append(discrete_acc)

    # Log row with hyperparams repeated (simple single-file log).
    # kl_loss, VARIANCE_LR and KL_BETA are not supplied in this phase;
    # DictWriter leaves those columns empty.
    with open(LOG_CSV, "a", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writerow(dict(
            epoch=epoch, train_loss=train_loss, train_acc=train_acc,
            float_eval_acc=float_eval_acc, float_trainmode_acc=float_trainmode_acc,
            discrete_acc=discrete_acc,
            BATCH_SIZE=BATCH_SIZE, GATE_OPTIMIZER=active_gate_function, NETWORK_LR=NETWORK_LR,
            GUMBEL_TAU=GUMBEL_TAU, GROUP_SUM_TAU=GROUP_SUM_TAU,
            NETWORK_LAYERS=NETWORK_LAYERS, GATES=GATES, EPOCHS=EPOCHS,
            CONNECTIONS=CONNECTIONS
        ))


Epoch 1: train_loss=0.9871 train_acc=0.8215 float_train_acc=0.8318 discrete_acc=0.8328
Epoch 2: train_loss=0.9631 train_acc=0.8241 float_train_acc=0.8339 discrete_acc=0.8331
Epoch 3: train_loss=0.9415 train_acc=0.8263 float_train_acc=0.8361 discrete_acc=0.8332
Epoch 4: train_loss=0.9220 train_acc=0.8283 float_train_acc=0.8389 discrete_acc=0.8378
Epoch 5: train_loss=0.9043 train_acc=0.8305 float_train_acc=0.8406 discrete_acc=0.8393
Epoch 6: train_loss=0.8883 train_acc=0.8320 float_train_acc=0.8430 discrete_acc=0.8426
Epoch 7: train_loss=0.8737 train_acc=0.8341 float_train_acc=0.8445 discrete_acc=0.8432
Epoch 8: train_loss=0.8604 train_acc=0.8354 float_train_acc=0.8455 discrete_acc=0.8444
Epoch 9: train_loss=0.8483 train_acc=0.8366 float_train_acc=0.8467 discrete_acc=0.8444
Epoch 10: train_loss=0.8372 train_acc=0.8379 float_train_acc=0.8486 discrete_acc=0.8483
Epoch 11: train_loss=0.8271 train_acc=0.8398 float_train_acc=0.8496 discrete_acc=0.8507
Epoch 12: train_loss=0.8177 train_acc=0.8

KeyboardInterrupt: 