In [1]:
from pynq import Overlay
from pynq import MMIO
import numpy as np

# Program bitstream to FPGA
# The MMIO windows opened below are only meaningful for this bitstream.
# NOTE(review): every base address / range below presumably mirrors the
# address map of the block design behind this bitstream — confirm against the
# hardware project before reusing with another design.
overlay = Overlay('MarginSampling512_2_10Classes.bit')

# Access to memory map of the AXI GPIO 0
# Control/status port used by the test cells: offset 0 is pulsed 1 -> 0 to
# start a run, offset 8 is polled until it reads non-zero.
GPIO0_ADDR_BASE = 0x41200000
GPIO0_ADDR_RANGE = 0x10000
gpio0_obj = MMIO(GPIO0_ADDR_BASE, GPIO0_ADDR_RANGE)

# Access to memory map of the BRAM 0
# Input buffer: the test cells write the packed samples here, one 32-bit word
# per 4-byte offset, starting at offset 0.
BRAM0_ADDR_BASE = 0x40000000
BRAM0_ADDR_RANGE = 0x40000
bram0_obj = MMIO(BRAM0_ADDR_BASE, BRAM0_ADDR_RANGE)

# Access to memory map of the BRAM 1
# Output buffer: the test cells read 512 32-bit words back from here and
# compare them with the sample indices chosen by the software emulation.
BRAM1_ADDR_BASE = 0x42000000
BRAM1_ADDR_RANGE = 0x1000
bram1_obj = MMIO(BRAM1_ADDR_BASE, BRAM1_ADDR_RANGE)

### Functions

In [2]:
def margin_sw(probs, query_batch):
    """Software reference for margin sampling.

    The margin of a sample is the difference between its largest and its
    second-largest class score; the smaller the margin, the less certain
    the prediction.

    Parameters
    ----------
    probs : array-like, shape (n_samples, n_classes)
        Per-sample class scores. At least two classes are required.
    query_batch : int
        Number of lowest-margin samples to select.

    Returns
    -------
    margins : numpy.ndarray
        Margin of every sample, in input order.
    margins_batch : numpy.ndarray
        Margins of the selected samples, in ascending order.
    query_idx : numpy.ndarray
        All sample indices ordered by ascending margin.
    query_idx_batch : numpy.ndarray
        The first ``query_batch`` entries of ``query_idx``.
    """
    # Sort every row once (ascending along the last axis); the two right-most
    # columns are then the top-2 scores of each sample.
    sorted_probs = np.sort(probs)
    margins = sorted_probs[:, -1] - sorted_probs[:, -2]

    # A single argsort provides both the full ranking and the selected batch.
    query_idx = np.argsort(margins)
    query_idx_batch = query_idx[:query_batch]
    margins_batch = margins[query_idx_batch]

    return margins, margins_batch, query_idx, query_idx_batch

def margin_emul(probs, query_batch, n_rbanks, n_registers):
    """Emulate the register-bank margin selection in software.

    Samples are dealt round-robin over ``n_rbanks`` banks, the register slot
    advancing (and wrapping at ``n_registers``) after each full pass over the
    banks. The first ``query_batch`` samples are stored unconditionally; each
    later sample replaces the largest-margin entry of its bank when its own
    margin is strictly smaller.

    NOTE(review): the call sites use query_batch == n_rbanks * n_registers;
    other combinations are not guarded here (e.g. a bank left empty at the end
    of the fill phase makes ``max`` raise ValueError).

    Returns
    -------
    margins_batch : numpy.ndarray
        Margins of the retained samples, listed bank by bank.
    query_idx : numpy.ndarray
        Indices of the retained samples, in the same order.
    """
    banks = [{} for _ in range(n_rbanks)]
    all_margins = []
    slot = 0
    bank_no = 0

    def worst_entry(bank):
        # (slot, (sample_index, margin)) of the largest margin held in a bank.
        return max(bank.items(), key=lambda item: item[1][1])

    for sample_no, scores in enumerate(probs):
        ranked = np.sort(scores)
        margin = ranked[-1] - ranked[-2]
        all_margins.append(margin)

        if sample_no <= query_batch - 1:
            # Fill phase: store the sample in the current bank/slot.
            banks[bank_no][slot] = (sample_no, margin)
            if sample_no == query_batch - 1:
                # Fill complete: remember the worst entry of every bank.
                worst = [worst_entry(bank) for bank in banks]
        elif margin < worst[bank_no][1][1]:
            # Replacement phase: evict this bank's worst entry.
            evicted_slot = worst[bank_no][0]
            banks[bank_no][evicted_slot] = (sample_no, margin)
            worst[bank_no] = worst_entry(banks[bank_no])

        # Advance round-robin: next bank, next slot after a full pass.
        bank_no += 1
        if bank_no == n_rbanks:
            bank_no = 0
            slot += 1
        if slot == n_registers:
            slot = 0

    query_idx = [kept for bank in banks for kept, _ in bank.values()]
    margins_batch = np.array(all_margins)[query_idx]

    return margins_batch, np.array(query_idx)

def write_bram_pynq_format(din, word_width):
    """Pack rows of 16-bit values into 8-digit hex words for the BRAM writes.

    Each value of a row is rendered as a 4-digit hex field. The row is
    zero-padded on the right up to ``word_width // 16`` fields, and adjacent
    fields are joined in pairs so that each output element is one hex string
    with the earlier field in the upper half. Fields beyond
    ``word_width // 16`` are dropped.

    Parameters
    ----------
    din : iterable of iterables of int
        One inner iterable per sample.
    word_width : int
        Width in bits of one padded row.

    Returns
    -------
    numpy.ndarray
        1-D array of hex strings, rows concatenated in input order.
    """
    n_fields = word_width // 16
    n_pairs = n_fields // 2
    zero_field = format(0, '04x')

    packed = []
    for row in din:
        fields = [format(value, '04x') for value in row]
        # Pad with a single list extension instead of growing a NumPy array
        # one element at a time (a negative count extends by nothing).
        fields.extend([zero_field] * (n_fields - len(fields)))
        packed.extend(fields[2 * i] + fields[2 * i + 1] for i in range(n_pairs))

    return np.array(packed)

def random_data(n_classes, data_length=10240, data_width=16):
    """Generate uniformly random unsigned test scores.

    Parameters
    ----------
    n_classes : int
        Number of values per sample (columns).
    data_length : int
        Number of samples (rows).
    data_width : int
        Bit width of each value; values are drawn from ``[0, 2**data_width)``.
        Must be between 0 and 16 because the result is stored as ``uint16``.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(data_length, n_classes)`` and dtype ``uint16``.

    Raises
    ------
    ValueError
        If ``data_width`` does not fit the ``uint16`` result.
    """
    if not 0 <= data_width <= 16:
        raise ValueError(
            f"data_width must be in [0, 16] to fit uint16, got {data_width}"
        )

    # One vectorised draw from the global NumPy RNG (np.random.seed applies)
    # instead of one Python-level randint call per element.
    return np.random.randint(
        0, 2**data_width, size=(data_length, n_classes), dtype=np.uint16
    )

### Tests on randomly generated data

In [3]:
N = 50

# Seed the global NumPy RNG so that a failing iteration can be reproduced.
np.random.seed(0)

for n in range(N):

    # random data: 5120 samples x 10 classes of 16-bit scores
    raw_data = random_data(n_classes=10, data_length=5120, data_width=16)

    # pynq bram format: every sample is padded to 256 bits and split into
    # 32-bit hex words
    input_data = write_bram_pynq_format(raw_data, word_width=256)

    # write array to bram0, one 32-bit word per 4-byte offset
    for word_no, hex_word in enumerate(input_data):
        bram0_obj.write(4 * word_no, int(hex_word, 16))

    # start processing: pulse GPIO offset 0, then poll offset 8 until non-zero
    # NOTE(review): there is no timeout, so a stalled design hangs this cell.
    gpio0_obj.write(0, 1)
    gpio0_obj.write(0, 0)
    while gpio0_obj.read(8) == 0:
        pass

    # read bram1 to output array (512 selected sample indices)
    output_data = np.zeros((512,), dtype=np.uint32)
    for word_no in range(output_data.shape[0]):
        output_data[word_no] = bram1_obj.read(4 * word_no)

    # emulation results: indices selected by the software model
    margin_hw_indx = margin_emul(
        raw_data, query_batch=512, n_rbanks=8, n_registers=64
    )[1]

    # check mismatches: emulated indices missing from the hardware output
    mismatches = [idx for idx in margin_hw_indx if idx not in output_data]

    print(f"{n}: {mismatches}")

    del input_data, output_data, raw_data

0: []
1: []
2: []
3: []
4: []
5: []
6: []
7: []
8: []
9: []
10: []
11: []
12: []
13: []
14: []
15: []
16: []
17: []
18: []
19: []
20: []
21: []
22: []
23: []
24: []
25: []
26: []
27: []
28: []
29: []
30: []
31: []
32: []
33: []
34: []
35: []
36: []
37: []
38: []
39: []
40: []
41: []
42: []
43: []
44: []
45: []
46: []
47: []
48: []
49: []


### Tests on predicted data

In [3]:
#load data
def _load_predicted_probs(tag):
    """Load ``../predicted_probs/predicted_probs_<tag>.npy`` as an array."""
    return np.load(f"../predicted_probs/predicted_probs_{tag}.npy")


# The individual names are kept at module level for any cell that refers to
# one of the arrays directly.
FashionMNIST_MobileNetV1 = _load_predicted_probs("FashionMNISTSubset512_MobileNetV1")
FashionMNIST_EfficientNetB0 = _load_predicted_probs("FashionMNISTSubset512_EfficientNetB0")
FashionMNIST_ResNet50 = _load_predicted_probs("FashionMNISTSubset512_ResNet50")
CIFAR10_MobileNetV1 = _load_predicted_probs("CIFAR10Subset512_MobileNetV1")
CIFAR10_EfficientNetB0 = _load_predicted_probs("CIFAR10Subset512_EfficientNetB0")
CIFAR10_ResNet50 = _load_predicted_probs("CIFAR10Subset512_ResNet50")

# Insertion order fixes the order in which the test cell visits the data sets.
predicted_probs = {
    "FashionMNIST_MobileNetV1": FashionMNIST_MobileNetV1,
    "FashionMNIST_EfficientNetB0": FashionMNIST_EfficientNetB0,
    "FashionMNIST_ResNet50": FashionMNIST_ResNet50,
    "CIFAR10_MobileNetV1": CIFAR10_MobileNetV1,
    "CIFAR10_EfficientNetB0": CIFAR10_EfficientNetB0,
    "CIFAR10_ResNet50": CIFAR10_ResNet50,
}

for name, probs in predicted_probs.items():
    print(f"{name} shape:", probs.shape)

FashionMNIST_MobileNetV1 shape: (5120, 10)
FashionMNIST_EfficientNetB0 shape: (5120, 10)
FashionMNIST_ResNet50 shape: (5120, 10)
CIFAR10_MobileNetV1 shape: (5120, 10)
CIFAR10_EfficientNetB0 shape: (5120, 10)
CIFAR10_ResNet50 shape: (5120, 10)


In [4]:
for k, probs in predicted_probs.items():

    # Quantise to unsigned 16-bit fixed point. A probability of exactly 1.0
    # scales to 2**16, which does not fit in uint16, so saturate at 0xFFFF
    # before the cast instead of letting the value overflow.
    raw_data = np.clip(np.asarray(probs) * (2**16), 0, 2**16 - 1).astype(np.uint16)

    # pynq bram format: every sample is padded to 256 bits and split into
    # 32-bit hex words
    input_data = write_bram_pynq_format(raw_data, word_width=256)

    # write array to bram0, one 32-bit word per 4-byte offset
    for word_no, hex_word in enumerate(input_data):
        bram0_obj.write(4 * word_no, int(hex_word, 16))

    # start processing: pulse GPIO offset 0, then poll offset 8 until non-zero
    # NOTE(review): there is no timeout, so a stalled design hangs this cell.
    gpio0_obj.write(0, 1)
    gpio0_obj.write(0, 0)
    while gpio0_obj.read(8) == 0:
        pass

    # read bram1 to output array (512 selected sample indices)
    output_data = np.zeros((512,), dtype=np.uint32)
    for word_no in range(output_data.shape[0]):
        output_data[word_no] = bram1_obj.read(4 * word_no)

    # emulation results: indices selected by the software model
    margin_hw_indx = margin_emul(
        raw_data, query_batch=512, n_rbanks=8, n_registers=64
    )[1]

    # check mismatches: emulated indices missing from the hardware output
    mismatches = [idx for idx in margin_hw_indx if idx not in output_data]

    print(f"{k} Mismatches: {mismatches}")

    del input_data, output_data, raw_data

FashionMNIST_MobileNetV1 Mismatches: []
FashionMNIST_EfficientNetB0 Mismatches: []
FashionMNIST_ResNet50 Mismatches: []
CIFAR10_MobileNetV1 Mismatches: []
CIFAR10_EfficientNetB0 Mismatches: []
CIFAR10_ResNet50 Mismatches: []
