# 1. 加载Overlay，定义DMA对象

In [1]:
from pynq import Overlay, allocate
import pynq.lib.dma
overlay = Overlay('/home/xilinx/jupyter_notebooks/lstm/vivado_mnist/lstm_mnist.bit') # 加载Overlay
dma = overlay.axi_dma_0  # 定义DMA对象

# 2. 解析MNIST测试集图片和标签数据

In [3]:
import read_mnist as reader
test_images = reader.load_images('data_mnist/t10k-images.idx3-ubyte')  # 读取测试集图片
test_labels = reader.load_labels('data_mnist/t10k-labels.idx1-ubyte')  # 读取测试集标签 
# print(test_images[1].shape)
# print(test_images[1])

magic number: 2051; image number: 10000; image size: 28px*28px
done
magic number:2049; image number: 10000
done


# 3. 从测试集中随机选取并显示若干图片

In [4]:
import random
import numpy as np
import matplotlib.pyplot as plt

IMG_NUM = 10

index = [0] * IMG_NUM
img = [0] * IMG_NUM
for i in range (IMG_NUM):
    index[i] = random.randint(0, 10000)  # 在测试集中随机选取8张图片
    img_dat = np.array(test_images[index[i]]).reshape(28, 28)

# 显示上面选取的8张测试图片
plt.figure()
for i in range (IMG_NUM):
    plt.subplot(1, IMG_NUM, i + 1)
    plt.imshow(np.array(test_images[index[i]]).reshape(28, 28),cmap='gray')
print('Image labels: ', test_labels[index])
plt.show()

Image labels:  [7 1 7 6 9 2 1 7 9 0]


<matplotlib.figure.Figure at 0xaec3c8d0>

# 4. 利用IP核加速RNN推导

In [5]:
import time
hw_time = [0] * IMG_NUM
out_buf = allocate(shape=(10000, 10), dtype = np.float32)
for i in range (IMG_NUM):
    t0 = time.time()
    dma.sendchannel.transfer(test_images[index[i]])  # 调用DMA将待预测图片数据传输到IP核
    dma.recvchannel.transfer(out_buf[index[i]])  # 调用DMA从IP核获取RNN的推导结果
    dma.sendchannel.wait()  # 等待DMA发送完成
    dma.recvchannel.wait()  # 等待DMA接收完成
    t1 = time.time()
    hw_time[i] = t1 - t0

    out_list = np.array(out_buf[index[i]]).tolist()
    max_indx = out_list.index(max(out_list))  # 推导结果向量的最大分量的下标即为预测结果
    print('Result: %d,' % max_indx, 'time: {:1.6f}s'.format(hw_time[i]))
    

# 统计平均推理时间（单位：秒）
avg_hw_time_sec = sum(hw_time) / IMG_NUM
# 转换为毫秒
avg_hw_time_ms = avg_hw_time_sec * 1000
print(f"\n 平均每张处理时间: {avg_hw_time_ms:.3f} ms")


Result: 7, time: 0.006333s
Result: 1, time: 0.006293s
Result: 7, time: 0.006261s
Result: 6, time: 0.006284s
Result: 5, time: 0.006218s
Result: 2, time: 0.006238s
Result: 1, time: 0.006222s
Result: 7, time: 0.006202s
Result: 9, time: 0.006204s
Result: 0, time: 0.006194s

 平均每张处理时间: 6.245 ms
