In [10]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import os

# ======= Step 1: Set up environment =======
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
EXPORT_DIR = "exported_weights"
os.makedirs(EXPORT_DIR, exist_ok=True)

# ======= Step 2: Define data loaders =======
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.view(-1))
])
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)

# ======= Step 3: Define BNN model =======
class BNN(nn.Module):
    def __init__(self):
        super(BNN, self).__init__()
        self.fc1 = nn.Linear(784, 256)
        self.bn1 = nn.BatchNorm1d(256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        x = torch.sign(self.bn1(self.fc1(x)))
        x = self.fc2(x)
        return x

model = BNN().to(device)

# ======= Step 4: Helper functions =======
def binarize_and_export(weight_tensor, filename):
    with open(os.path.join(EXPORT_DIR, filename), 'w') as f:
        for row in weight_tensor:
            binary = ((row.sign() + 1) / 2).int()
            bits = ''.join(str(b.item()) for b in binary)
            f.write(bits + '\n')

def fold_bn_into_fc(fc, bn):
    w = fc.weight.data
    b = fc.bias.data if fc.bias is not None else torch.zeros_like(bn.running_mean)
    gamma = bn.weight.data
    beta = bn.bias.data
    mean = bn.running_mean
    var = bn.running_var
    eps = bn.eps
    std = torch.sqrt(var + eps)
    w_folded = w * (gamma / std).unsqueeze(1)
    b_folded = (b - mean) / std * gamma + beta
    return w_folded, b_folded

import time

def measure_inference_time(model, test_loader, label=""):
    model.eval()
    correct = 0
    start_time = time.time()
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            pred = outputs.argmax(dim=1)
            correct += (pred == labels).sum().item()
    end_time = time.time()
    total_time = end_time - start_time
    avg_time = total_time / len(test_loader.dataset)
    accuracy = correct / len(test_loader.dataset) * 100

    print(f"\n🕒 {label} Inference Results")
    print(f"Accuracy: {accuracy:.2f}%")
    print(f"Total Inference Time: {total_time:.4f} seconds")
    print(f"Average Time per Image: {avg_time * 1000:.4f} ms")
    return total_time, avg_time

# ======= Step 5: Train with flipping frequency tracking =======
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
best_accuracy = 0.0

flip_count = torch.zeros_like(model.fc1.weight.data).int().to(device)
prev_sign = torch.sign(model.fc1.weight.data.clone().detach())

for epoch in range(10):
    model.train()
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        with torch.no_grad():
            current_sign = torch.sign(model.fc1.weight.data)
            flip = (current_sign != prev_sign).int()
            flip_count += flip
            prev_sign = current_sign.clone()

    # Evaluate accuracy
    model.eval()
    correct = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            pred = outputs.argmax(dim=1)
            correct += (pred == labels).sum().item()

    accuracy = correct / len(test_dataset)
    print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}, Accuracy: {accuracy*100:.2f}%")
    if epoch == 9:  # If this is the last turn
      measure_inference_time(model, test_loader, label="Before Pruning")

# ======= Step 6: Prune based on flipping frequency =======
flattened = flip_count.view(-1).float()
num_to_prune = int(flattened.numel() * 0.2)
_, prune_indices = torch.topk(flattened, num_to_prune, largest=True)
mask = torch.ones_like(flattened)
mask[prune_indices] = 0
mask = mask.view_as(model.fc1.weight.data)

with torch.no_grad():
    model.fc1.weight.data *= mask
print("✂️ Pruned unstable weights based on flipping frequency.")

# ======= Step 7: Fine-tune after pruning =======
print("🔁 Fine-tuning pruned model...")
for epoch in range(3):
    model.train()
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

# Final test
model.eval()
correct = 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        pred = outputs.argmax(dim=1)
        correct += (pred == labels).sum().item()

final_acc = correct / len(test_dataset)
print(f"✅ Final Accuracy After Pruning and Fine-Tuning: {final_acc*100:.2f}%")
measure_inference_time(model, test_loader, label="After Pruning")


# ======= Step 8: Export weights =======
fc1_folded_w, _ = fold_bn_into_fc(model.fc1, model.bn1)
binarize_and_export(fc1_folded_w, "fc1_weights.mem")
binarize_and_export(model.fc2.weight.data, "fc2_weights.mem")
print("📦 Exported .mem weights.")



Epoch 1, Loss: 0.4975, Accuracy: 88.05%
Epoch 2, Loss: 0.5157, Accuracy: 88.62%
Epoch 3, Loss: 0.3596, Accuracy: 89.01%
Epoch 4, Loss: 0.3083, Accuracy: 88.89%
Epoch 5, Loss: 0.3963, Accuracy: 88.73%
Epoch 6, Loss: 0.1814, Accuracy: 88.49%
Epoch 7, Loss: 0.5201, Accuracy: 88.93%
Epoch 8, Loss: 0.5741, Accuracy: 89.02%
Epoch 9, Loss: 0.5776, Accuracy: 89.18%
Epoch 10, Loss: 0.3556, Accuracy: 88.90%

🕒 Before Pruning Inference Results
Accuracy: 88.90%
Total Inference Time: 6.0309 seconds
Average Time per Image: 0.6031 ms
✂️ Pruned unstable weights based on flipping frequency.
🔁 Fine-tuning pruned model...
✅ Final Accuracy After Pruning and Fine-Tuning: 85.11%

🕒 After Pruning Inference Results
Accuracy: 85.11%
Total Inference Time: 5.9040 seconds
Average Time per Image: 0.5904 ms
📦 Exported .mem weights.
