<a href="https://colab.research.google.com/github/Yangtze-flowing/hardware4ai/blob/main/BNN%20training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# ✅ Step 1: Import Libraries
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import numpy as np
import os
import matplotlib.pyplot as plt

In [2]:
# ✅ Step 2: Load MNIST Dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.view(-1))
])

train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)

100%|██████████| 9.91M/9.91M [00:00<00:00, 15.8MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 474kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 3.81MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 7.32MB/s]


In [3]:
# ✅ Step 3: Define BNN Model (with BN)
class BNN(nn.Module):
    def __init__(self):
        super(BNN, self).__init__()
        self.fc1 = nn.Linear(784, 256)
        self.bn1 = nn.BatchNorm1d(256)  # ✅ 保留BN层
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        x = torch.sign(self.bn1(self.fc1(x)))  # ✅ sign 作用于 BN 后
        x = self.fc2(x)
        return x

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BNN().to(device)


In [4]:
import torch.optim as optim
import time
import os

# ========== 目录设置 ==========
EXPORT_DIR = "exported_weights"
os.makedirs(EXPORT_DIR, exist_ok=True)

# ========== 工具函数 ==========
def binarize_and_export(weight_tensor, filename):
    with open(os.path.join(EXPORT_DIR, filename), 'w') as f:
        for row in weight_tensor:
            binary = ((row.sign() + 1) / 2).int()  # +1 → 1, -1 → 0
            bits = ''.join(str(b.item()) for b in binary)
            f.write(bits + '\n')

def fold_bn_into_fc(fc, bn):
    # 融合 FC + BN
    w = fc.weight.data
    b = fc.bias.data if fc.bias is not None else torch.zeros_like(bn.running_mean)
    gamma = bn.weight.data
    beta = bn.bias.data
    mean = bn.running_mean
    var = bn.running_var
    eps = bn.eps
    std = torch.sqrt(var + eps)
    w_folded = w * (gamma / std).unsqueeze(1)
    b_folded = (b - mean) / std * gamma + beta
    return w_folded, b_folded

# ========== 训练 ==========
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
best_accuracy = 0.0

for epoch in range(10):
    model.train()
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    # 验证阶段
    model.eval()
    correct = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            pred = outputs.argmax(dim=1)
            correct += (pred == labels).sum().item()

    accuracy = correct / len(test_dataset)
    print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}, Accuracy: {accuracy*100:.2f}%")

    # 如果是最优，保存模型并导出权重
    if accuracy > best_accuracy:
        best_accuracy = accuracy
        torch.save(model.state_dict(), "best_bnn_model.pth")
        print(f"✅ New best model saved with accuracy: {best_accuracy*100:.2f}%")

        # ✅ 融合 BN1 到 FC1
        fc1_folded_w, _ = fold_bn_into_fc(model.fc1, model.bn1)

        # ✅ 导出权重（注意：fc1 已融合）
        binarize_and_export(fc1_folded_w, "fc1_weights.mem")
        binarize_and_export(model.fc2.weight.data, "fc2_weights.mem")
        print("📦 Exported folded weights to .mem files.")


Epoch 1, Loss: 0.5188, Accuracy: 87.89%
✅ New best model saved with accuracy: 87.89%
📦 Exported folded weights to .mem files.
Epoch 2, Loss: 0.1509, Accuracy: 88.34%
✅ New best model saved with accuracy: 88.34%
📦 Exported folded weights to .mem files.
Epoch 3, Loss: 0.4888, Accuracy: 88.58%
✅ New best model saved with accuracy: 88.58%
📦 Exported folded weights to .mem files.
Epoch 4, Loss: 0.1509, Accuracy: 88.72%
✅ New best model saved with accuracy: 88.72%
📦 Exported folded weights to .mem files.
Epoch 5, Loss: 0.3994, Accuracy: 88.51%
Epoch 6, Loss: 0.4070, Accuracy: 88.57%
Epoch 7, Loss: 0.5393, Accuracy: 88.58%
Epoch 8, Loss: 0.2217, Accuracy: 88.72%
Epoch 9, Loss: 0.1700, Accuracy: 88.74%
✅ New best model saved with accuracy: 88.74%
📦 Exported folded weights to .mem files.
Epoch 10, Loss: 0.5949, Accuracy: 88.56%


In [5]:
import time  # ✅ 添加时间模块

# ✅ Step 5: Test Accuracy with Timing
model.eval()
correct = 0

start_time = time.time()  # ✅ 记录开始时间

for images, labels in test_loader:
    images, labels = images.to(device), labels.to(device)
    outputs = model(images)
    pred = outputs.argmax(dim=1)
    correct += (pred == labels).sum().item()

end_time = time.time()  # ✅ 记录结束时间
total_time = end_time - start_time
avg_time_per_image = total_time / len(test_dataset)

print(f"Test Accuracy: {correct / len(test_dataset) * 100:.2f}%")
print(f"Total Inference Time: {total_time:.4f} seconds")
print(f"Average Time per Image: {avg_time_per_image * 1000:.4f} ms")


Test Accuracy: 88.56%
Total Inference Time: 6.9825 seconds
Average Time per Image: 0.6982 ms
