# Binary Neural Network MNIST Testbench for PYNQ

## 1. Import Libraries and Load Bitstream

In [None]:
from pynq import Overlay
from pynq import MMIO
from pynq import allocate
import numpy as np
import time
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

In [None]:
# Load the BNN bitstream onto the FPGA fabric.
# UPDATE THIS PATH so that it points at your own design before running.
# NOTE(review): PYNQ's Overlay normally also expects the matching hardware
# hand-off (.hwh) file next to the .bit file -- confirm both are present.
bitstream_path = './bnn.bit'  # Change to your actual .bit file path

ol = Overlay(bitstream_path)

In [None]:
# Map the BNN IP's control-register window so it can be read and written
# from Python.
# NOTE(review): base address and window size are hard-coded; confirm they
# match the address assigned to the IP in your block design.
BNN_BASE_ADDR = 0x40000000
BNN_ADDR_RANGE = 0x10000  # size of the mapped register window, in bytes
bnn_ip = MMIO(BNN_BASE_ADDR, BNN_ADDR_RANGE)

## 2. Load Golden Test Data

In [None]:
# Golden input data: one row per test sample, 25 uint32 words per row.
# bits_to_image() in this notebook reads each word MSB-first and uses the
# first 784 of the 800 bits as a 28x28 image, drawing a 1 bit as white
# background and a 0 bit as the black digit stroke.
golden_inputs = np.array([
    # Sample 0 - Label: 7
    [4294967295,4294967295,4294967295,4294967295,4294967295,4294967295,4290838527,4227859455,3221241855,3758358527,4232052735,2281701368,2147483407,4294963455,4294844415,4293132287,4232052735,2281701360,2147483407,4294959359,4294844415,4290904063,4229955583,3288334335,4294967295],
    # Sample 1 - Label: 2
    [4294967295,4294967295,4294967292,134217600,1073737731,4294837311,4292985855,4282138623,4286840831,4034920447,268435424,4294966303,4294951935,4294459391,4278714367,4034920447,268435440,67239680,2093056,33521671,4294967295,4294967295,4294967295,4294967295,4294967295],
    # Sample 2 - Label: 1
    [4294967295,4294967295,4294967295,4294967295,536870897,4294967103,4294960127,4294852607,4293132287,4236247039,3355443196,2147483527,4294965503,4294938623,4294049791,4280287231,4060086271,536870881,4294966815,4294951935,4294721535,4294967295,4294967295,4294967295,4294967295],
    # Sample 3 - Label: 0
    [4294967295,4294967295,4294967295,4294967280,2147483399,4294959231,4294705663,4290777087,4160815103,524256,404749831,3254771964,536756161,4293131295,4265542143,3824173054,1040711648,16776704,268427264,4294901887,4294707199,4294967295,4294967295,4294967295,4294967295],
    # Sample 4 - Label: 4
    [4294967295,4294967295,4294967295,4294967295,4294966526,2147471335,4294508095,4287620095,4043194367,532938723,4232052287,2281687032,4294737807,4291031295,4261416959,3758161919,2199912447,4043309055,536870897,4294967071,4294963711,4294918143,4294967295,4294967295,4294967295],
    # Sample 5 - Label: 1
    [4294967295,4294967295,4294967295,4294967295,4294967280,4294966799,4294959359,4294713343,4291035135,4232052735,3288334328,2147483527,4294965375,4294905855,4293984255,4279238655,3791650814,536870881,4294966815,4294951935,4294721535,4293132287,4294967295,4294967295,4294967295],
    # Sample 6 - Label: 4
    [4294967295,4294967295,4294967295,4294967295,4294965375,2415857648,4293000991,4263502335,3288088572,1070071747,4043308032,268419072,4294836255,4294959615,4294852607,4291035135,4236247039,3355443192,4294967177,4294965279,4294935039,4294459391,4294967295,4294967295,4294967295],
    # Sample 7 - Label: 9
    [4294967295,4294967295,4294967295,4294967295,4294967295,4294959615,4294713343,4286644223,4161011711,4194289,33554204,536867267,4294901791,4293919231,4286648319,4231069695,4287102975,4169138175,3288334334,536870881,4294967071,4294965503,4294938623,4294967295,4294967295],
    # Sample 8 - Label: 5
    [4294967295,4294967295,4294967295,4294967295,2155872224,134213632,2147418119,4279240703,3785883644,536870787,4294965375,4294938623,4294459391,4286579711,4160757759,2147549183,2148532222,16777184,268434944,4294963263,4294952959,4294967295,4294967295,4294967295,4294967295],
    # Sample 9 - Label: 9
    [4294967295,4294967295,4294967295,4294967295,4294967295,4294967295,4294934783,4292870655,4160757759,25231296,2014313502,16760832,536608771,4293918847,4294709247,4286840831,4030726142,134217696,4294966303,4294935039,4294459391,4287102975,4169138175,2415919103,4294967295]
], dtype=np.uint32)

# Golden output data: one row per test sample, ten int32 scores per row
# (one score per digit class). The test loop compares the hardware output
# against these rows element by element.
# NOTE: the row for sample index 8 peaks at class 6 (score 26) although that
# sample's label is 5, so a run that reproduces these values exactly still
# reports 9/10 classification accuracy.
golden_outputs = np.array([
    [-2,4,-8,14,-4,0,-42,48,-4,2],
    [2,0,48,10,-20,4,2,-16,4,-14],
    [-16,42,2,4,-2,-14,8,2,14,-8],
    [50,-8,-8,-10,-12,16,-2,0,-8,-2],
    [-6,-16,-12,-10,40,-12,10,8,-12,22],
    [-24,30,-2,4,-2,-18,0,6,14,12],
    [-12,-10,-26,0,38,6,-4,-6,14,8],
    [-20,-2,-18,4,10,2,-8,6,2,44],
    [-2,0,4,-22,4,16,26,-28,12,2],
    [-6,-12,-20,6,12,-24,-10,28,-4,34]
], dtype=np.int32)

# Expected labels (0-9), one entry per golden sample
golden_labels = np.array([7,2,1,0,4,1,4,9,5,9], dtype=np.int32)

# Derived from the label table rather than typed in by hand, so adding or
# removing a golden sample cannot leave the count out of step with the data.
NUM_SAMPLES = len(golden_labels)

## 3. Visualize the MNIST Test Images

In [None]:
def bits_to_image(binary_data):
    """Unpack 25 packed 32-bit words into a 28x28 grayscale image.

    The words are read in order and each word is read from its most
    significant bit to its least significant bit. That gives 800 bits, of
    which the first 784 (28 * 28) become pixels in row-major order; the
    last 16 bits of the final word are ignored.

    The encoding is inverted: a 1 bit is background (white, 255) and a
    0 bit is part of the digit (black, 0).

    Args:
        binary_data: sequence of at least 25 values that fit in uint32.
            Entries past the 25th are ignored.

    Returns:
        A (28, 28) uint8 array containing only 0 and 255.

    Raises:
        ValueError: if fewer than 25 words are supplied.
    """
    words = np.asarray(binary_data, dtype=np.uint32)[:25]
    if words.size < 25:
        raise ValueError(
            f"expected at least 25 32-bit words, got {words.size}")

    # Store the words big-endian before viewing them as bytes: unpackbits
    # emits each byte MSB-first, so this yields every word's bits from bit 31
    # down to bit 0 regardless of the host's native byte order.
    bits = np.unpackbits(words.astype('>u4').view(np.uint8))[:784]

    # bits holds 0/1 as uint8; scaling by 255 maps them to black/white.
    return (bits * np.uint8(255)).reshape(28, 28)

# Decode every golden input row into a 28x28 picture
mnist_images = [bits_to_image(golden_inputs[idx]) for idx in range(NUM_SAMPLES)]

# Lay the pictures out on a 2x5 grid, filled row by row
fig, axes = plt.subplots(2, 5, figsize=(15, 6))
fig.suptitle('MNIST Test Images (10 Samples)', fontsize=16, fontweight='bold')

# axes.flat walks the grid in row-major order, which matches the
# row = idx // 5, col = idx % 5 placement of each sample.
for idx, (panel, picture) in enumerate(zip(axes.flat, mnist_images)):
    panel.imshow(picture, cmap='gray', interpolation='nearest')
    panel.set_title(f'Sample {idx+1}: Label={golden_labels[idx]}',
                    fontsize=12, fontweight='bold')
    panel.axis('off')

plt.tight_layout()
plt.show()

## 4. DMA Transfer and FPGA Control

In [None]:
# Allocate physically contiguous memory for input/output.
# The input buffer holds one sample (25 uint32 words, the width of a
# golden_inputs row) and the output buffer holds the ten int32 class scores.
# Their physical addresses are what gets handed to the BNN IP in the test loop.
input_buffer = allocate(shape=(25,), dtype=np.uint32)
output_buffer = allocate(shape=(10,), dtype=np.int32)

def start_bnn(input_addr, output_addr):
    """Program both buffer addresses into the BNN IP, then start it.

    Each address is written as two 32-bit halves: the input pointer at
    offsets 0x10 (low) / 0x14 (high) and the output pointer at offsets
    0x1C (low) / 0x20 (high). Writing 1 to offset 0x00 sets AP_START.

    Args:
        input_addr: address of the input buffer, as an integer.
        output_addr: address of the output buffer, as an integer.
    """
    word_mask = 0xFFFFFFFF

    # (low-half offset, high-half offset, address), written in this order
    pointer_registers = (
        (0x10, 0x14, input_addr),
        (0x1C, 0x20, output_addr),
    )
    for low_offset, high_offset, address in pointer_registers:
        bnn_ip.write(low_offset, address & word_mask)
        bnn_ip.write(high_offset, (address >> 32) & word_mask)

    # Start the IP (AP_START = bit 0)
    bnn_ip.write(0x00, 0x01)

def wait_for_completion(timeout_ms=1000):
    """Poll the BNN control register until AP_DONE is set or time runs out.

    The status is always read at least once, and it is read again after the
    last sleep before giving up, so a run that finishes right at the deadline
    is still reported as complete.

    Args:
        timeout_ms: maximum time to wait, in milliseconds.

    Returns:
        True if the AP_DONE bit (bit 1 of the register at offset 0x00) was
        seen set, False if the timeout expired first.
    """
    # Use the monotonic clock: the wall clock can be stepped (for example by
    # a time sync), which would stretch or cut short the timeout.
    deadline = time.monotonic() + timeout_ms / 1000.0

    while True:
        if bnn_ip.read(0x00) & 0x02:  # AP_DONE bit
            return True
        if time.monotonic() >= deadline:
            return False
        time.sleep(0.0001)  # 100 us between polls

## 5. Run BNN Inference Tests

In [None]:
# Initialize result arrays
predicted_outputs = np.zeros((NUM_SAMPLES, 10), dtype=np.int32)
# -1 is not a valid digit, so a sample that never produced a result cannot be
# mistaken for a (possibly correct) prediction of digit 0 later on.
predicted_labels = np.full(NUM_SAMPLES, -1, dtype=np.int32)
# Stays 0 for any sample that times out
execution_times = np.zeros(NUM_SAMPLES)
exact_matches = np.zeros(NUM_SAMPLES, dtype=bool)

print("="*80)
print("BNN Multi-Layer Testbench")
print(f"Testing {NUM_SAMPLES} samples")
print("="*80)
print()

total_passed = 0
total_failed = 0

for i in range(NUM_SAMPLES):
    print("-" * 80)
    print(f"Sample {i+1} - Label: {golden_labels[i]}")

    # Copy input to buffer and flush it so the IP sees the new data
    input_buffer[:] = golden_inputs[i]
    input_buffer.flush()

    # Start BNN and measure time. perf_counter is the clock intended for
    # measuring short intervals; the figure still includes register writes
    # and polling overhead, not just the IP's own run time.
    start_time = time.perf_counter()
    start_bnn(input_buffer.physical_address, output_buffer.physical_address)

    if not wait_for_completion():
        print("  TIMEOUT!")
        total_failed += 1
        continue

    execution_times[i] = (time.perf_counter() - start_time) * 1e6  # Convert to microseconds

    # Read output; invalidate first so the CPU does not read stale cache lines
    output_buffer.invalidate()
    predicted_outputs[i] = output_buffer[:]
    predicted_labels[i] = np.argmax(predicted_outputs[i])

    # Verify output (Layer 3) against the golden scores, element by element
    mismatched = np.flatnonzero(predicted_outputs[i] != golden_outputs[i])
    exact_match = mismatched.size == 0
    exact_matches[i] = exact_match

    if exact_match:
        print("  Layer 3 (Output): PASSED")
        print("Result: ALL LAYERS PASSED")
        total_passed += 1
    else:
        # Show at most the first five mismatches to keep the log readable
        for digit in mismatched[:5]:
            print(f"    Mismatch at [{digit}]: Expected {golden_outputs[i][digit]}, "
                  f"Got {predicted_outputs[i][digit]}")
        print(f"  Layer 3 (Output): FAILED ({mismatched.size}/10 errors)")
        print("Result: FAILED")
        total_failed += 1

print()
print("="*80)
print("SUMMARY:")
print(f"  Total samples:    {NUM_SAMPLES}")
print(f"  All layers pass:  {total_passed}")
print(f"  Any layer fail:   {total_failed}")
print(f"  Overall accuracy: {100.0 * total_passed / NUM_SAMPLES:.1f}%")
print("="*80)

## 6. Performance Summary

In [None]:
# Calculate metrics.
# A sample that timed out in the test loop never gets a time recorded and
# keeps its initial value of 0, so a positive time marks a completed run.
completed = execution_times > 0
completed_times = execution_times[completed]

# Only completed runs can count as correct; otherwise a timed-out sample
# whose label slot was left at its initial value could match by accident.
correct_predictions = np.sum((predicted_labels == golden_labels) & completed)
accuracy = 100.0 * correct_predictions / NUM_SAMPLES
exact_match_count = np.sum(exact_matches)
exact_match_rate = 100.0 * exact_match_count / NUM_SAMPLES

# Timing statistics cover completed runs only, so timeouts cannot drag the
# average and minimum towards zero or inflate the throughput.
if completed_times.size:
    avg_time = np.mean(completed_times)
    min_time = np.min(completed_times)
    max_time = np.max(completed_times)
    std_time = np.std(completed_times)
    throughput = 1e6 / avg_time  # inferences/second
else:
    # Nothing completed: report zeros instead of dividing by zero
    avg_time = min_time = max_time = std_time = throughput = 0.0

print("\n" + "="*80)
print("PERFORMANCE METRICS")
print("="*80)
print(f"Classification Accuracy:  {correct_predictions}/{NUM_SAMPLES} ({accuracy:.1f}%)")
print(f"Exact Output Matches:     {exact_match_count}/{NUM_SAMPLES} ({exact_match_rate:.1f}%)")
print()
print("Execution Time Statistics:")
print(f"  Samples timed: {completed_times.size}/{NUM_SAMPLES}")
print(f"  Average:       {avg_time:.2f} µs")
print(f"  Min:           {min_time:.2f} µs")
print(f"  Max:           {max_time:.2f} µs")
print(f"  Std Dev:       {std_time:.2f} µs")
print(f"  Throughput:    {throughput:.0f} inferences/sec")
print("="*80)

# Overall verdict. Counts are compared directly rather than through the
# derived floating-point percentages.
if exact_match_count == NUM_SAMPLES:
    print("\nALL TESTS PASSED!")
elif correct_predictions == NUM_SAMPLES:
    print("\nAccuracy surpassed golden reference")
else:
    print(f"\nTESTS FAILED - {NUM_SAMPLES - correct_predictions} incorrect predictions")

## 7. Execution Time Chart

In [None]:
plt.figure(figsize=(12, 4))

# Color bars by correctness: green where the output matched the golden
# scores exactly, red otherwise
colors = ['green' if matched else 'red' for matched in exact_matches]

# Number the bars from 1 so they line up with the "Sample N" headings printed
# by the test loop and shown above the images.
sample_numbers = range(1, NUM_SAMPLES + 1)

plt.bar(sample_numbers, execution_times, color=colors,
        edgecolor='black', alpha=0.7)
plt.axhline(y=avg_time, color='blue', linestyle='--', linewidth=2,
            label=f'Average: {avg_time:.2f} µs')
plt.title('BNN Execution Time per Sample (Green=Pass, Red=Fail)',
          fontsize=14, fontweight='bold')
plt.xlabel('Sample Number', fontsize=12)
plt.ylabel('Time (µs)', fontsize=12)
plt.xticks(sample_numbers)
plt.legend(fontsize=11)
plt.grid(True, alpha=0.3, linestyle=':', axis='y')
plt.tight_layout()
plt.show()

print(f"Average: {avg_time:.2f} µs | Std Dev: {std_time:.2f} µs")