# ESP32-friendly Binary Neural Net (BNN): **Train → Test → Export** (No Synthetic Fallback)

This notebook includes:
- **Training** (`train`) — GPU-aware, multilabel (3 outputs)
- **Testing** (`test_model`, `test_accuracy`) — macro/per-class F1, subset & hamming accuracy
- **Export** (`export_to_header`) — bit-packed `bnn_export.h` for ESP32 (XNOR+POPCOUNT inference)

**Important:** You must set `CSV_PATH` to a valid CSV file. The CSV must start with a header row and contain only numeric values: **features in all columns except the last 3**, which are binary labels (0/1).


In [1]:
# Environment & compatibility
import os, math, csv, random, json
from typing import Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split

import importlib
import inspect

SEED = 1337
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)


def _numpy_array_reconstructor():
    """Return NumPy's pickled-array reconstructor, or None if it cannot be located.

    NumPy 2 renamed ``numpy.core`` to ``numpy._core`` (the old path is deprecated),
    so the new location is tried first.
    """
    for module_name in ('numpy._core.multiarray', 'numpy.core.multiarray'):
        try:
            return importlib.import_module(module_name)._reconstruct
        except (ImportError, AttributeError):
            continue
    return None


# Best-effort allowlist of NumPy's array reconstructor for weights_only=True
# loads on PyTorch versions that provide add_safe_globals. safe_torch_load below
# does NOT depend on this succeeding.
try:
    from torch.serialization import add_safe_globals
except ImportError:  # older PyTorch without add_safe_globals
    add_safe_globals = None
_reconstruct = _numpy_array_reconstructor()
if add_safe_globals is not None and _reconstruct is not None:
    add_safe_globals([_reconstruct])

# torch.load only accepts `weights_only` on newer PyTorch, and on 2.6+ it
# defaults to True, which rejects the NumPy arrays (mean/std) stored in these
# checkpoints. Detect the keyword directly instead of inferring it from
# unrelated imports.
_TORCH_LOAD_ACCEPTS_WEIGHTS_ONLY = 'weights_only' in inspect.signature(torch.load).parameters


def safe_torch_load(path, map_location='cpu'):
    """Load a checkpoint (such as the one train() saves) on any PyTorch version.

    Passes ``weights_only=False`` explicitly whenever ``torch.load`` accepts the
    keyword, so the result does not depend on the installed version's default.

    WARNING: ``weights_only=False`` unpickles arbitrary Python objects, which can
    execute code. Only load checkpoints from a trusted source (e.g. your own run).

    Args:
        path: checkpoint file path.
        map_location: forwarded to ``torch.load``.

    Returns:
        Whatever object was saved (for train(): a dict).
    """
    if _TORCH_LOAD_ACCEPTS_WEIGHTS_ONLY:
        return torch.load(path, map_location=map_location, weights_only=False)
    return torch.load(path, map_location=map_location)


print('safe_torch_load ready')


Device: cuda
safe_torch_load ready


## Data loader (CSV)
- CSV with header; **last 3 columns** are labels (0/1)
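- Features are standardized per column with mean/std computed over the whole CSV; these statistics are stored in the checkpoint and exported for the ESP32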


In [2]:
class TabularMultiLabel(Dataset):
    """In-memory multilabel dataset loaded from a numeric CSV file.

    The first row is a header and is skipped; blank rows are ignored. Every other
    row must contain only numeric values. The last ``N_LABELS`` (3) columns are the
    labels and all remaining columns are features. Features are standardized
    per column.

    Args:
        csv_path: path to the CSV file.
        mean: optional per-feature means (``n_features`` values, any shape) to
            standardize with instead of this file's own means, e.g. the ``'mean'``
            stored in a training checkpoint.
        std: optional per-feature scales, used as-is (no epsilon is added),
            e.g. the checkpoint's ``'std'``.

    Attributes:
        X: float32 standardized features, shape (n_rows, n_features).
        y: float32 labels, shape (n_rows, 3).
        mean, std: float32 arrays of shape (1, n_features) used for scaling.
        header: column names read from the first row.

    Raises:
        FileNotFoundError: ``csv_path`` does not exist.
        ValueError: the file is empty, has no data rows, has rows of differing
            length, has too few columns, or ``mean``/``std`` have the wrong size.
    """

    N_LABELS = 3  # labels occupy the last N_LABELS columns

    def __init__(self, csv_path: str, mean=None, std=None):
        if not os.path.exists(csv_path):
            raise FileNotFoundError(f'CSV not found: {csv_path}')
        with open(csv_path, 'r', newline='') as f:
            reader = csv.reader(f)
            header = next(reader, None)
            if header is None:
                raise ValueError(f'CSV is empty (no header row): {csv_path}')
            rows = [[float(x) for x in row] for row in reader if row]
        if not rows:
            raise ValueError(f'CSV has no data rows: {csv_path}')
        width = len(rows[0])
        for i, row in enumerate(rows):
            if len(row) != width:
                raise ValueError(
                    f'CSV data row {i + 1} has {len(row)} columns, expected {width}: {csv_path}')
        if width <= self.N_LABELS:
            raise ValueError(
                f'CSV needs at least {self.N_LABELS + 1} columns '
                f'(features + {self.N_LABELS} labels), got {width}: {csv_path}')

        arr = np.array(rows, dtype=np.float32)
        X = arr[:, :-self.N_LABELS]
        self.y = arr[:, -self.N_LABELS:]
        self.header = header
        n_features = X.shape[1]
        if mean is None:
            self.mean = X.mean(axis=0, keepdims=True)
        else:
            self.mean = self._as_row(mean, n_features, 'mean')
        if std is None:
            # Epsilon keeps constant columns from dividing by zero.
            self.std = X.std(axis=0, keepdims=True) + 1e-6
        else:
            self.std = self._as_row(std, n_features, 'std')
        self.X = (X - self.mean) / self.std

    @staticmethod
    def _as_row(values, n_features, name):
        """Return ``values`` as a float32 (1, n_features) row, validating its size."""
        row = np.asarray(values, dtype=np.float32).reshape(1, -1)
        if row.shape[1] != n_features:
            raise ValueError(f'{name} has {row.shape[1]} entries, expected {n_features}')
        return row

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]


## Binary layers & model (hidden layers binarized with STE)
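- Hidden-layer weights and activations are sign-binarized (gradients via a straight-through estimator); the first layer's inputs (standardized features) and the final output layer are real-valued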
- Suggests hidden sizes to fit ~**20 KB** bit-packed parameters


In [3]:
def binarize(t: torch.Tensor):
    """Map every element to +1 if it is > 0, else -1 (same dtype/device as ``t``).

    ``torch.sign`` returns 0 for exact zeros, which is not a binary value and
    disagrees with ``pack_bits`` (which exports a zero weight as bit 0, i.e. -1).
    Zero is therefore mapped to -1 here so training and the exported model agree.
    The result carries no gradient; callers add a straight-through term.
    """
    ones = torch.ones_like(t)
    return torch.where(t > 0, ones, -ones)

class BinaryLinear(nn.Module):
    """Fully connected layer that multiplies by binarized weights in the forward pass.

    The optimizer updates real-valued latent weights (``self.weight``); the forward
    pass uses ``binarize(weight)`` plus the numerically-zero straight-through term
    ``weight - weight.detach()``, which routes gradients to the latent weights
    unchanged. The optional bias stays real-valued.
    """

    def __init__(self, in_f, out_f, bias=True):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_f, in_f))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_f))
        else:
            self.bias = None
        nn.init.kaiming_normal_(self.weight)

    def forward(self, x):
        latent = self.weight
        straight_through = latent - latent.detach()
        return F.linear(x, binarize(latent) + straight_through, self.bias)

class BinaryAct(nn.Module):
    """Binarizing activation with a straight-through (identity) gradient.

    Forward value is ``binarize(x)``; the numerically-zero term ``x - x.detach()``
    passes the incoming gradient to ``x`` unchanged.
    """

    def forward(self, x):
        straight_through = x - x.detach()
        return binarize(x) + straight_through

class BNN(nn.Module):
    """Multilabel MLP: one [BinaryLinear -> BatchNorm1d -> BinaryAct] block per hidden size,
    followed by a real-valued ``nn.Linear`` head producing ``out_f`` logits.

    The submodule names (``hidden.<i>``, ``fc_out``) determine the state_dict keys
    that export_to_header reads, so they must stay stable.
    """

    def __init__(self, in_f: int, hidden_sizes: Tuple[int, ...], out_f: int = 3):
        super().__init__()
        widths = [in_f, *hidden_sizes]
        blocks = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            blocks.extend((BinaryLinear(fan_in, fan_out), nn.BatchNorm1d(fan_out), BinaryAct()))
        self.hidden = nn.Sequential(*blocks)
        self.fc_out = nn.Linear(widths[-1], out_f)

    def forward(self, x):
        return self.fc_out(self.hidden(x))

def packed_bytes(bits: int):
    """Return the number of whole bytes needed to hold ``bits`` bits (ceil(bits / 8))."""
    whole_bytes, _ = divmod(bits + 7, 8)
    return whole_bytes

def suggest_hidden_sizes(in_f: int, out_f: int = 3, budget_bytes: int = 20*1024,
                         candidates=(256, 128, 64, 32), depth: int = 3):
    """Pick hidden-layer widths whose estimated parameter size fits ``budget_bytes``.

    Cost model (bits): 1 bit per binary hidden weight (``in_f*h1 + h1*h2 + ...``)
    plus 8 bits per output-layer weight (``h_last * out_f``), rounded up to whole
    bytes. This is only an estimate: it ignores per-row byte padding, biases,
    BatchNorm and feature statistics, and export_to_header writes the output
    layer as 32-bit floats, so the real header is larger.

    Among all ``depth``-layer combinations of ``candidates`` that fit, returns the
    one with the largest total width (ties: first in nested-loop order). If none
    fits, returns the widest single layer that fits, or ``[min(candidates)]`` if
    even that does not fit.

    Raises:
        ValueError: ``candidates`` is empty or ``depth`` < 1.
    """
    import itertools

    if not candidates:
        raise ValueError('candidates must be non-empty')
    if depth < 1:
        raise ValueError(f'depth must be >= 1, got {depth}')

    def fits(hs):
        widths = (in_f, *hs)
        bits = sum(a * b for a, b in zip(widths[:-1], widths[1:]))
        bits += 8 * (hs[-1] * out_f)
        return (bits + 7) // 8 <= budget_bytes

    best = None
    for hs in itertools.product(candidates, repeat=depth):
        if fits(hs) and (best is None or sum(hs) > sum(best)):
            best = list(hs)
    if best is not None:
        return best
    # Fallback: widest single layer that fits (the previous ascending scan
    # always returned the narrowest one).
    for h in sorted(candidates, reverse=True):
        if fits((h,)):
            return [h]
    return [min(candidates)]


## Train / Test functions
- `train(...)` saves `bnn_best.pt` with weights + metadata (in_f, hidden, mean, std)
- `test_model(...)` reloads and reports macro/per-class F1
- `test_accuracy(...)` adds subset & hamming accuracy, per-class accuracy


In [4]:
def train(csv_path: str, epochs=20, batch_size=2048, lr=5e-4, pos_weight=None,
          ckpt_path='bnn_best.pt'):
    """Train a BNN on a labelled CSV and checkpoint the best epoch by validation macro-F1.

    Rows are split (seeded with SEED) into train/validation: 300,000 training rows
    when the file has more than 320,000 rows, otherwise 80/20. Hidden sizes come
    from suggest_hidden_sizes with a 20 KB budget. After each epoch the validation
    split is scored at threshold 0.5; whenever macro-F1 improves, a dict with keys
    'state_dict', 'in_f', 'hidden', 'mean', 'std' is saved to ``ckpt_path``.

    Args:
        csv_path: dataset path (header row; last 3 columns are 0/1 labels).
        epochs, batch_size, lr: Adam + cosine-annealing schedule settings.
        pos_weight: positive-class weight for the BCE loss. None (default) means
            unweighted; a number or 3-element sequence is passed through to
            ``binary_cross_entropy_with_logits``; ``'auto'`` uses
            n_negative / n_positive per label from the training split, which can
            help when imbalanced labels drive the model to predict all negatives.
        ckpt_path: where the best checkpoint is written.

    Returns:
        Best validation macro-F1 (float).

    Raises:
        ValueError: fewer than 2 training rows, batch_size < 2, or invalid pos_weight.

    Note: mean/std are computed over the whole CSV, validation rows included.
    """
    if batch_size < 2:
        raise ValueError(f'batch_size must be >= 2 for BatchNorm, got {batch_size}')
    ds = TabularMultiLabel(csv_path)
    n = len(ds)
    in_f = ds.X.shape[1]
    out_f = 3
    train_n = 300000 if n > 320000 else int(0.8 * n)
    test_n = n - train_n
    if train_n < 2:
        raise ValueError(f'Need at least 2 training rows, got {train_n} (from {n} rows in {csv_path})')
    train_ds, test_ds = random_split(ds, [train_n, test_n], generator=torch.Generator().manual_seed(SEED))

    hs = suggest_hidden_sizes(in_f, out_f, budget_bytes=20*1024)
    print(f"Input={in_f}, Hidden={hs}, Output=3 (budget ~20KB)")
    model = BNN(in_f, tuple(hs), out_f).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
    # BatchNorm1d raises on a batch of one sample in training mode, so drop a
    # trailing size-1 batch.
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=0,
                              pin_memory=(device.type == 'cuda'),
                              drop_last=(train_n % batch_size == 1))
    test_loader = DataLoader(test_ds, batch_size=4096, shuffle=False, num_workers=0)

    if pos_weight is None:
        pw = None
    elif isinstance(pos_weight, str):
        if pos_weight != 'auto':
            raise ValueError(f"pos_weight must be None, a number/sequence, or 'auto'; got {pos_weight!r}")
        y_train = ds.y[np.asarray(train_ds.indices)]
        n_pos = y_train.sum(axis=0)
        pw = torch.as_tensor((len(y_train) - n_pos) / np.maximum(n_pos, 1.0),
                             dtype=torch.float32, device=device)
    else:
        pw = torch.as_tensor(pos_weight, dtype=torch.float32, device=device)

    # Start below any achievable F1 so epoch 1 always writes a checkpoint; otherwise
    # a run stuck at F1 == 0 would leave a stale file from an earlier run in place.
    best_f1 = -1.0
    for ep in range(1, epochs + 1):
        model.train()
        total = 0.0
        for xb, yb in train_loader:
            xb = xb.to(device); yb = yb.to(device)
            opt.zero_grad()
            logits = model(xb)
            loss = F.binary_cross_entropy_with_logits(logits, yb, pos_weight=pw)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            total += loss.item() * xb.size(0)
        sched.step()

        # Validation at threshold 0.5.
        model.eval()
        tp = np.zeros(out_f); fp = np.zeros(out_f); fn = np.zeros(out_f)
        with torch.no_grad():
            for xb, yb in test_loader:
                preds = (torch.sigmoid(model(xb.to(device))) >= 0.5).float()
                y = yb.numpy(); p = preds.cpu().numpy()
                tp += ((p == 1) & (y == 1)).sum(axis=0)
                fp += ((p == 1) & (y == 0)).sum(axis=0)
                fn += ((p == 0) & (y == 1)).sum(axis=0)
        f1 = []
        for i in range(out_f):
            precision = tp[i] / (tp[i] + fp[i] + 1e-9)
            recall = tp[i] / (tp[i] + fn[i] + 1e-9)
            f1.append(2 * precision * recall / (precision + recall + 1e-9))
        macro = float(np.mean(f1))
        print(f"Epoch {ep:02d} | loss={total/max(1,train_n):.4f} | macroF1={macro:.4f} | per-class F1={f1}")
        if macro > best_f1:
            best_f1 = macro
            torch.save({'state_dict': model.state_dict(), 'in_f': in_f, 'hidden': hs,
                        'mean': ds.mean, 'std': ds.std}, ckpt_path)
            print(f"Saved {ckpt_path} (macroF1={best_f1:.4f})")
    print('Best macroF1=', best_f1)
    return best_f1

def test_model(ckpt_path='bnn_best.pt', csv_path=None, threshold=0.5):
    """Evaluate a saved checkpoint on a CSV and print/return macro and per-class F1.

    Features are standardized with the checkpoint's training statistics
    (``ckpt['mean']`` / ``ckpt['std']``), not with statistics of ``csv_path``, so a
    held-out file is scaled exactly like the training data (and like the ESP32
    export, which ships those same statistics).

    Args:
        ckpt_path: checkpoint written by train().
        csv_path: CSV to evaluate (required).
        threshold: probability threshold for a positive prediction.

    Returns:
        (macro_f1, per_class_f1): float and list of per-label F1 values.

    Raises:
        ValueError: ``csv_path`` is None, or the CSV's feature/label counts do not
            match the checkpoint.
    """
    if csv_path is None:
        raise ValueError('csv_path is required')
    ckpt = safe_torch_load(ckpt_path, map_location=device)
    state, in_f, hs = ckpt['state_dict'], ckpt['in_f'], ckpt['hidden']
    out_f = int(state['fc_out.weight'].shape[0])
    model = BNN(in_f, tuple(hs), out_f=out_f).to(device)
    model.load_state_dict(state)
    model.eval()

    ds = TabularMultiLabel(csv_path)
    if ds.X.shape[1] != in_f or ds.y.shape[1] != out_f:
        raise ValueError(f'CSV has {ds.X.shape[1]} features / {ds.y.shape[1]} labels, '
                         f'checkpoint expects {in_f} / {out_f}')
    # Re-standardize with the training statistics; skipped when they already match
    # (e.g. evaluating the training CSV) so results there are bit-identical.
    train_mean, train_std = ckpt.get('mean'), ckpt.get('std')
    if (train_mean is not None and train_std is not None
            and not (np.array_equal(ds.mean, train_mean) and np.array_equal(ds.std, train_std))):
        raw = ds.X * ds.std + ds.mean
        ds.X = ((raw - train_mean) / train_std).astype(np.float32)

    loader = DataLoader(ds, batch_size=4096, shuffle=False)
    tp = np.zeros(out_f); fp = np.zeros(out_f); fn = np.zeros(out_f)
    with torch.no_grad():
        for xb, yb in loader:
            preds = (torch.sigmoid(model(xb.to(device))) >= threshold).float()
            y = yb.numpy(); p = preds.cpu().numpy()
            tp += ((p == 1) & (y == 1)).sum(axis=0)
            fp += ((p == 1) & (y == 0)).sum(axis=0)
            fn += ((p == 0) & (y == 1)).sum(axis=0)
    f1 = []
    for i in range(out_f):
        precision = tp[i] / (tp[i] + fp[i] + 1e-9)
        recall = tp[i] / (tp[i] + fn[i] + 1e-9)
        f1.append(2 * precision * recall / (precision + recall + 1e-9))
    macro = float(np.mean(f1))
    print(f"Test macroF1={macro:.4f} | per-class F1={f1}")
    return macro, f1

def test_accuracy(ckpt_path='bnn_best.pt', csv_path=None, threshold=0.5):
    """Evaluate a saved checkpoint on a CSV: subset, hamming and per-class accuracy plus F1.

    Features are standardized with the checkpoint's training statistics
    (``ckpt['mean']`` / ``ckpt['std']``) rather than statistics of ``csv_path``, so
    a held-out file is scaled exactly like the training data.

    Args:
        ckpt_path: checkpoint written by train().
        csv_path: CSV to evaluate (required).
        threshold: probability threshold for a positive prediction.

    Returns:
        dict with 'subset_accuracy' (all labels of a row correct), 'hamming_accuracy'
        (fraction of individual labels correct), 'per_class_accuracy', 'macro_f1'
        and 'per_class_f1'.

    Raises:
        ValueError: ``csv_path`` is None, or the CSV's feature/label counts do not
            match the checkpoint.
    """
    if csv_path is None:
        raise ValueError('csv_path is required')
    ckpt = safe_torch_load(ckpt_path, map_location=device)
    state, in_f, hs = ckpt['state_dict'], ckpt['in_f'], ckpt['hidden']
    out_f = int(state['fc_out.weight'].shape[0])
    model = BNN(in_f, tuple(hs), out_f=out_f).to(device)
    model.load_state_dict(state)
    model.eval()

    ds = TabularMultiLabel(csv_path)
    if ds.X.shape[1] != in_f or ds.y.shape[1] != out_f:
        raise ValueError(f'CSV has {ds.X.shape[1]} features / {ds.y.shape[1]} labels, '
                         f'checkpoint expects {in_f} / {out_f}')
    # Re-standardize with the training statistics; skipped when they already match.
    train_mean, train_std = ckpt.get('mean'), ckpt.get('std')
    if (train_mean is not None and train_std is not None
            and not (np.array_equal(ds.mean, train_mean) and np.array_equal(ds.std, train_std))):
        raw = ds.X * ds.std + ds.mean
        ds.X = ((raw - train_mean) / train_std).astype(np.float32)

    loader = DataLoader(ds, batch_size=4096, shuffle=False)
    tp = np.zeros(out_f); fp = np.zeros(out_f); fn = np.zeros(out_f); tn = np.zeros(out_f)
    total_samples = 0
    subset_correct = 0
    with torch.no_grad():
        for xb, yb in loader:
            probs = torch.sigmoid(model(xb.to(device)))
            p = (probs >= threshold).float().cpu().numpy()
            y = yb.numpy()
            tp += ((p == 1) & (y == 1)).sum(axis=0)
            fp += ((p == 1) & (y == 0)).sum(axis=0)
            fn += ((p == 0) & (y == 1)).sum(axis=0)
            tn += ((p == 0) & (y == 0)).sum(axis=0)
            subset_correct += int((p == y).all(axis=1).sum())
            total_samples += y.shape[0]
    f1 = []
    for i in range(out_f):
        precision = tp[i] / (tp[i] + fp[i] + 1e-9)
        recall = tp[i] / (tp[i] + fn[i] + 1e-9)
        f1.append(2 * precision * recall / (precision + recall + 1e-9))
    macro_f1 = float(np.mean(f1))
    per_class_acc = (tp + tn) / (tp + tn + fp + fn + 1e-9)
    hamming_acc = float((tp + tn).sum() / max(1, total_samples * out_f))
    subset_accuracy = float(subset_correct / max(1, total_samples))
    print(f"Test @ thr={threshold} | subset_acc={subset_accuracy:.4f} | hamming_acc={hamming_acc:.4f} | per-class_acc={per_class_acc.tolist()} | macroF1={macro_f1:.4f}")
    return {
        'subset_accuracy': subset_accuracy,
        'hamming_accuracy': hamming_acc,
        'per_class_accuracy': per_class_acc.tolist(),
        'macro_f1': macro_f1,
        'per_class_f1': [float(x) for x in f1]
    }


## Export to ESP32 header (`bnn_export.h`) — bit-packed binary weights
Packs only **2-D Linear weights** (skips 1-D BatchNorm parameters). Includes a safety check on the number of hidden layers found.


In [5]:
def pack_bits(W: np.ndarray):
    """Bit-pack the signs of a 2-D weight matrix, one byte-aligned row per output neuron.

    Bit ``i`` of row ``o`` is 1 when ``W[o, i] > 0`` and 0 otherwise (zero and
    negative values both become 0). Inside each byte the lowest input index is the
    least significant bit; each row is zero-padded to a whole number of bytes, so
    row ``o`` starts at byte ``o * ceil(in_f / 8)``.

    Returns:
        1-D uint8 array of length ``out_f * ceil(in_f / 8)``.

    Raises:
        ValueError: ``W`` is not 2-D.
    """
    positive = (np.sign(W) > 0).astype(np.uint8)
    if positive.ndim != 2:
        raise ValueError(f'pack_bits expects 2D, got shape {positive.shape}')
    # With bitorder='little', packbits fills each byte LSB-first and pads the
    # high bits of a row's last byte with zeros.
    return np.packbits(positive, axis=1, bitorder='little').reshape(-1)

def _c_float(x):
    """Format ``x`` as a C float literal (e.g. ``0.5f``); NaN/inf have no portable literal."""
    v = float(x)
    if not math.isfinite(v):
        raise ValueError(f'cannot export non-finite value {v!r} to a C header')
    return repr(v) + 'f'  # repr always contains '.' or an exponent, so '+f' is valid C


def _fold_bn_thresholds(lin_bias, gamma, beta, running_mean, running_var, eps=1e-5):
    """Fold a hidden layer's Linear bias + eval-mode BatchNorm + sign into per-neuron thresholds.

    For one neuron with weight-only dot product ``d`` the model computes
    ``y = gamma * (d + b - mu) / sqrt(var + eps) + beta`` and binarizes ``y``.
    With ``thr = mu - b - beta * sqrt(var + eps) / gamma``:
      * gamma > 0:  y > 0  <=>  d > thr
      * gamma < 0:  y > 0  <=>  d < thr
      * gamma == 0: y == beta for every input (constant output); encoded as
        sign +1 with thr = -3e38 (always +1) if beta > 0, else +3e38 (always -1).
    So in all cases the neuron outputs +1 iff ``sign * (d - thr) > 0``, else -1
    (y == 0 exactly is treated as -1, matching pack_bits' handling of zero).

    Returns:
        (thr, sign): float32 and int8 arrays with one entry per neuron.
    """
    big = 3.0e38  # below FLT_MAX, so it survives the float32 cast
    b = np.asarray(lin_bias, dtype=np.float64)
    g = np.asarray(gamma, dtype=np.float64)
    be = np.asarray(beta, dtype=np.float64)
    mu = np.asarray(running_mean, dtype=np.float64)
    scale = np.sqrt(np.asarray(running_var, dtype=np.float64) + eps)
    safe_g = np.where(g == 0, 1.0, g)  # avoid division by zero; replaced below
    thr = np.clip(mu - b - be * scale / safe_g, -big, big)
    thr = np.where(g == 0, np.where(be > 0, -big, big), thr)
    sign = np.where(g < 0, -1, 1).astype(np.int8)
    return thr.astype(np.float32), sign


def export_to_header(ckpt_path='bnn_best.pt', header_path='bnn_export.h', bn_eps=1e-5):
    """Write a checkpoint saved by train() as a self-contained C header for ESP32 inference.

    Emitted symbols (arrays/ints are ``static const`` so the header can be included
    from several translation units):
      * IN_F, OUT_F, N_HIDDEN, H1..Hn -- layer sizes
      * FEAT_MEAN / FEAT_STD [IN_F]   -- standardize raw x as (x - mean) / std
      * Wb_<i>                        -- hidden layer i's weight signs from pack_bits
        (row-major, one row per output neuron, Wb_<i>_ROW_BYTES bytes per row,
        bit k LSB-first; set => +1, clear => -1)
      * Wb_<i>_IN / _OUT / _ROW_BYTES -- layer i's shape
      * BN_THR_<i>, BN_SIGN_<i>       -- layer i's Linear bias, BatchNorm and sign
        folded into one comparison per neuron (see _fold_bn_thresholds)
      * Wout, Bout                    -- real-valued output layer (Wout row-major)

    The bias and BatchNorm statistics must be exported: without them the device
    cannot reproduce the model's binary activations.

    Args:
        ckpt_path: checkpoint with keys 'state_dict', 'in_f', 'hidden', 'mean', 'std'.
        header_path: output path.
        bn_eps: BatchNorm epsilon; 1e-5 is the nn.BatchNorm1d default BNN is built with.

    Raises:
        RuntimeError: number of binary layers != len(ckpt['hidden']), or a binary
            layer is not followed by a BatchNorm.
        ValueError: a parameter is NaN/inf.
    """
    ckpt = safe_torch_load(ckpt_path, map_location='cpu')
    state = ckpt['state_dict']
    in_f = ckpt['in_f']
    hs = list(ckpt['hidden'])
    mean = np.asarray(ckpt['mean'], dtype=np.float32).ravel()
    std = np.asarray(ckpt['std'], dtype=np.float32).ravel()

    # Binary layers are the 2-D 'hidden.<idx>.weight' tensors (BatchNorm weights are 1-D).
    layers = sorted(
        ((int(k.split('.')[1]), v.numpy()) for k, v in state.items()
         if k.startswith('hidden.') and k.endswith('.weight') and v.dim() == 2),
        key=lambda item: item[0],
    )
    if len(layers) != len(hs):
        keys = [k for k in state if k.startswith('hidden.') and k.endswith('.weight')]
        raise RuntimeError(f'Expected {len(hs)} binary layers, found {len(layers)}. Keys: {keys}')

    # BNN builds each hidden block as [BinaryLinear, BatchNorm1d, BinaryAct], so the
    # BatchNorm belonging to 'hidden.<idx>' is 'hidden.<idx + 1>'.
    folded = []
    for idx, W in layers:
        bn = f'hidden.{idx + 1}.'
        if bn + 'running_mean' not in state:
            raise RuntimeError(f'No BatchNorm found after hidden.{idx} (missing {bn}running_mean)')
        lin_bias = state.get(f'hidden.{idx}.bias')
        lin_bias = np.zeros(W.shape[0]) if lin_bias is None else lin_bias.numpy()
        folded.append(_fold_bn_thresholds(
            lin_bias, state[bn + 'weight'].numpy(), state[bn + 'bias'].numpy(),
            state[bn + 'running_mean'].numpy(), state[bn + 'running_var'].numpy(), eps=bn_eps))

    Wout = state['fc_out.weight'].numpy().astype(np.float32)
    Bout = state['fc_out.bias'].numpy().astype(np.float32)
    out_f = Wout.shape[0]

    def c_ints(values):
        return ','.join(str(int(v)) for v in np.asarray(values).ravel())

    def c_floats(values):
        return ','.join(_c_float(v) for v in np.asarray(values).ravel())

    lines = [
        '// Auto-generated: ESP32 BNN export',
        '#pragma once',
        '#include <stdint.h>',
        '',
        '// Hidden layer i, neuron j (inputs: standardized floats for layer 0, +/-1 afterwards):',
        '//   dot_j = sum_k w_jk * x_k, with w_jk = +1 if bit k of row j of Wb_i is set, else -1',
        '//   out_j = (BN_SIGN_i[j] * (dot_j - BN_THR_i[j]) > 0) ? +1 : -1',
        '// Output: logits = Wout ([OUT_F][last hidden size], row-major) * h + Bout;',
        '// sigmoid(logit) >= 0.5 is equivalent to logit >= 0.',
        '',
        f'#define IN_F {in_f}',
        f'#define OUT_F {out_f}',
        f'#define N_HIDDEN {len(hs)}',
    ]
    lines += [f'#define H{i + 1} {h}' for i, h in enumerate(hs)]
    lines += [
        f'static const float FEAT_MEAN[IN_F] = {{{c_floats(mean)}}};',
        f'static const float FEAT_STD[IN_F]  = {{{c_floats(std)}}};',
        '',
    ]
    for li, ((_, W), (thr, sgn)) in enumerate(zip(layers, folded)):
        n_out, n_in = W.shape
        pw = pack_bits(W)
        lines += [
            f'static const uint8_t Wb_{li}[{pw.size}] = {{{c_ints(pw)}}};',
            f'static const int Wb_{li}_IN = {n_in};',
            f'static const int Wb_{li}_OUT = {n_out};',
            f'static const int Wb_{li}_ROW_BYTES = {(n_in + 7) // 8};',
            f'static const float BN_THR_{li}[{n_out}] = {{{c_floats(thr)}}};',
            f'static const int8_t BN_SIGN_{li}[{n_out}] = {{{c_ints(sgn)}}};',
            '',
        ]
    lines += [
        f'static const float Wout[{Wout.size}] = {{{c_floats(Wout)}}};',
        f'static const float Bout[OUT_F] = {{{c_floats(Bout)}}};',
    ]
    with open(header_path, 'w') as f:
        f.write('\n'.join(lines) + '\n')
    print('Wrote', header_path)


## Configuration & End-to-End Run (No Fallback)
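**Set `CSV_PATH`** to your dataset path. The cell checks that the file exists and then runs **train → export → test**. Note that the test step reuses the same CSV, which includes the training rows, so the reported scores are optimistic; evaluate a held-out CSV for an unbiased estimate.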


In [7]:
# Dataset: header row, numeric values, last 3 columns are 0/1 labels.
CSV_PATH = 'processed-data.csv'  # e.g., '/content/your_dataset.csv' (labels must be last 3 columns)
# Explicit check rather than `assert`, which is stripped when Python runs with -O.
if not CSV_PATH or not os.path.exists(CSV_PATH):
    raise FileNotFoundError('Please set CSV_PATH to your dataset (last 3 cols = labels).')
print('Using dataset at', CSV_PATH)

# === Run end-to-end ===
train(CSV_PATH, epochs=20, batch_size=2048, lr=5e-4)
export_to_header('bnn_best.pt', 'bnn_export.h')
# NOTE: evaluation below reuses CSV_PATH, which contains the training rows, so the
# scores are optimistic; pass a held-out CSV for an unbiased estimate.
test_model('bnn_best.pt', CSV_PATH, threshold=0.5)
test_accuracy('bnn_best.pt', CSV_PATH, threshold=0.5)


Using dataset at processed-data.csv
Input=16, Hidden=[256, 256, 256], Output=3 (budget ~20KB)
Epoch 01 | loss=0.6696 | macroF1=0.1626 | per-class F1=[0.1490328902124288, 0.16828132040234176, 0.1705541153083121]
Saved bnn_best.pt (macroF1=0.1626)
Epoch 02 | loss=0.5938 | macroF1=0.0207 | per-class F1=[0.014700432309280979, 0.024028012472431675, 0.023403623303903227]
Epoch 03 | loss=0.5774 | macroF1=0.0045 | per-class F1=[0.004156694778129246, 0.006476119281688316, 0.002867830412657532]
Epoch 04 | loss=0.5739 | macroF1=0.0000 | per-class F1=[0.0, 0.0, 0.0]
Epoch 05 | loss=0.5712 | macroF1=0.0003 | per-class F1=[0.00025414575017760753, 0.0, 0.0006262525026330138]
Epoch 06 | loss=0.5696 | macroF1=0.0000 | per-class F1=[0.0, 0.0, 0.0]
Epoch 07 | loss=0.5688 | macroF1=0.0000 | per-class F1=[0.0, 0.0, 0.0]
Epoch 08 | loss=0.5680 | macroF1=0.0000 | per-class F1=[0.0, 0.0, 0.0]
Epoch 09 | loss=0.5674 | macroF1=0.0000 | per-class F1=[0.0, 0.0, 0.0]
Epoch 10 | loss=0.5671 | macroF1=0.0000 | per-c

{'subset_accuracy': 0.22642045454545454,
 'hamming_accuracy': 0.688510101010101,
 'per_class_accuracy': [0.6967803030303008,
  0.6844696969696948,
  0.6842803030303009],
 'macro_f1': 0.16452379247323165,
 'per_class_f1': [0.14930924506463453,
  0.17196819040357963,
  0.1722939419514807]}