In [None]:
from ip import *
import struct
import socket

# ---------------------------------------------------------------------------
# Static description of the block pipeline.
#
# Blocks are numbered FIRST_ID..LAST_ID inclusive.  Every table below is a
# dict keyed by block id.  Naming used by the tables:
#   T* / C*   -> sequence-length axis / feature-width axis
#   *I / *O   -> input side / output side of a block
#   *P        -> divisor applied to that axis (presumably the hardware
#                parallelism factor -- TODO confirm against the RTL)
#   *T        -> axis size divided by its divisor (tile count)
# ---------------------------------------------------------------------------
FIRST_ID, LAST_ID = -1, 24

SEQ_LEN, PATCH_DIM, EMBED_DIM, CLASSES = 196, 768, 192, 1000

BLOCK_IDS = [*range(FIRST_ID, LAST_ID + 1)]

# Each table starts from the value shared by most blocks and then overrides
# the single boundary block (first or last) that differs.  Overriding an
# existing key keeps its position, so iteration order stays BLOCK_IDS order.
TIs  = dict.fromkeys(BLOCK_IDS, SEQ_LEN)
TOs  = {**dict.fromkeys(BLOCK_IDS, SEQ_LEN),   LAST_ID:  1}
CIs  = {**dict.fromkeys(BLOCK_IDS, EMBED_DIM), FIRST_ID: PATCH_DIM}
COs  = {**dict.fromkeys(BLOCK_IDS, EMBED_DIM), LAST_ID:  CLASSES}
TIPs = dict.fromkeys(BLOCK_IDS, 2)
TOPs = {**dict.fromkeys(BLOCK_IDS, 2),         LAST_ID:  1}
CIPs = {**dict.fromkeys(BLOCK_IDS, 1),         FIRST_ID: 2}
COPs = dict.fromkeys(BLOCK_IDS, 1)

# Tile counts: axis size floor-divided by the matching divisor.
TITs = {blk: size // TIPs[blk] for blk, size in TIs.items()}
TOTs = {blk: size // TOPs[blk] for blk, size in TOs.items()}
CITs = {blk: size // CIPs[blk] for blk, size in CIs.items()}
COTs = {blk: size // COPs[blk] for blk, size in COs.items()}

# Block id -> base name used for the reference-data files.
# Even ids 0..22 are attention blocks, odd ids 1..23 are MLP blocks; both are
# numbered by layer index (id // 2).
MODULE_NAMEs = {-1: "patch_embed", 24: "head"}
MODULE_NAMEs.update((blk, f"attn_{blk//2}") for blk in range(0, 24, 2))
MODULE_NAMEs.update((blk, f"mlp_{blk//2}") for blk in range(1, 24, 2))

# Element type and element size (bytes) on each side of a block: int16
# everywhere, except int8 into the first block and int32 out of the last.
I_TYPEs = {**dict.fromkeys(BLOCK_IDS, np.int16), FIRST_ID: np.int8}
I_BYTEs = {**dict.fromkeys(BLOCK_IDS, 2),        FIRST_ID: 1}
O_TYPEs = {**dict.fromkeys(BLOCK_IDS, np.int16), LAST_ID:  np.int32}
O_BYTEs = {**dict.fromkeys(BLOCK_IDS, 2),        LAST_ID:  4}

In [None]:
# ---------------------------------------------------------------------------
# Select the range of blocks under test, load the matching reference data,
# stage the input in lpddr0 and derive the DMA burst sizes.
#
# Input-side parameters come from block START_ID, output-side parameters from
# block CLOSE_ID, so the span START_ID..CLOSE_ID is treated as one unit.
# ---------------------------------------------------------------------------
START_ID  = -1
CLOSE_ID  = 24

TI   = TIs[START_ID]
CI   = CIs[START_ID]
TO   = TOs[CLOSE_ID]
CO   = COs[CLOSE_ID]

CIP  = CIPs[START_ID]
TIP  = TIPs[START_ID]
TOP  = TOPs[CLOSE_ID]
COP  = COPs[CLOSE_ID]
print(f"CIP is {CIP}")
print(f"TIP is {TIP}")

I_TYPE = I_TYPEs[START_ID]
I_BYTE = I_BYTEs[START_ID]
O_TYPE = O_TYPEs[CLOSE_ID]
O_BYTE = O_BYTEs[CLOSE_ID]

# Two typed windows onto device memory: lpddr0 holds the stimulus, lpddr1
# receives the results.  (AXI_MEM comes from the project `ip` module; whether
# `length` is counted in bytes or elements is not visible here -- TODO confirm.)
lpddr0     = AXI_MEM(0x500_0000_0000, dtype=I_TYPE, length=0x1_0000_0000)
lpddr1     = AXI_MEM(0x501_0000_0000, dtype=O_TYPE, length=0x1_0000_0000)

TIT  = TI // TIP
TOT  = TO // TOP

CIT  = CI // CIP
COT  = CO // COP

# Elements per frame on the input and output side.
I_PIXELS = TI * CI
O_PIXELS = TO * CO

LOAD_N = 1        # frames contained in the reference files
REPEAT_N = 2000   # how many times the loaded frames are replicated in lpddr0

# Directory holding the golden reference dumps.
REF_DIR = "/home/xilinx/datasets/refs"

# NOTE: "i8" is NumPy shorthand for an 8-byte signed integer (int64), NOT
# int8.  The files are read as int64 and then narrowed to the I/O types.
iArray = np.fromfile(f"{REF_DIR}/{MODULE_NAMEs[START_ID]}_input.bin",  dtype="i8").astype(I_TYPE)
oArray = np.fromfile(f"{REF_DIR}/{MODULE_NAMEs[CLOSE_ID]}_output.bin", dtype="i8").astype(O_TYPE)
print(iArray[:10])
print(oArray[:10])

# Re-order each frame from (T, C) row-major into tile-major order:
# view as (frame, T-tile, T-in-tile, C-tile, C-in-tile), then swap the two
# middle axes so the layout becomes (frame, T-tile, C-tile, T-in-tile, C-in-tile).
iArray = iArray.reshape(LOAD_N, TIT, TIP, CIT, CIP).transpose((0, 1, 3, 2, 4)).reshape(-1)
oArray = oArray.reshape(LOAD_N, TOT, TOP, COT, COP).transpose((0, 1, 3, 2, 4)).reshape(-1)
print(iArray[:10])
print(oArray[:10])
print(f"iarray size is {iArray.size}")

# put iArray into lpddr0
for i in range(REPEAT_N):
    lpddr0[iArray.size * i : iArray.size * (i+1)] = iArray
print("Put done")


# Calculate transfer length
TEST_N = LOAD_N * REPEAT_N 
PER_TRANSFER = 400 # limited by DMA: maximum (1<<26) bytes
DMA_MAX_BYTES = 1 << 26
I_TRANSFER_BYTES = PER_TRANSFER * I_PIXELS * I_BYTE
O_TRANSFER_BYTES = PER_TRANSFER * O_PIXELS * O_BYTE

# The run loop issues TEST_N // PER_TRANSFER bursts; a remainder would leave
# trailing frames unprocessed while the checker still compares all TEST_N.
assert TEST_N % PER_TRANSFER == 0, (
    f"TEST_N ({TEST_N}) must be a multiple of PER_TRANSFER ({PER_TRANSFER})"
)
# Enforce the DMA burst limit instead of relying on the comment above.
assert max(I_TRANSFER_BYTES, O_TRANSFER_BYTES) <= DMA_MAX_BYTES, (
    f"burst of {max(I_TRANSFER_BYTES, O_TRANSFER_BYTES)} bytes exceeds "
    f"the DMA limit of {DMA_MAX_BYTES} bytes; lower PER_TRANSFER"
)
print(f"I_BYTE is {I_BYTE}, I_TYPE is {I_TYPE}")
print(f"O_BYTE is {O_BYTE}, O_TYPE is {O_TYPE}")

CIP is 2
TIP is 2
[-80 -79 -76 -79 -80 -80 -80 -78 -78 -62]
[102624  15038   3417  15584   8737    788    556   -316  -9018 -18218]
[-80 -79 -69 -85 -76 -79 -81 -79 -80 -80]
[102624  15038   3417  15584   8737    788    556   -316  -9018 -18218]
iarray size is 150528
Put done
I_BYTE is 1, I_TYPE is <class 'numpy.int8'>
O_BYTE is 4, O_TYPE is <class 'numpy.int32'>


In [None]:
# Handles for the memory-mapped peripherals, each bound to its base address.
dma      = AXI_DMA(0xAF00_0000)
gpio     = AXI_REGISTER(0xAB00_0000, np.uint32)
clock    = AXI_CLOCK(0xA400_0000)
pl_reset = PL_RESET()

# Accelerator control registers, described as (name, byte offset, dtype).
# Judging by how they are written later in this notebook: N receives the
# number of frames per burst, T is set before every burst, and R is toggled
# 0 -> 1 once before a run (presumably count / trigger / reset -- TODO confirm).
vit = AXI_IP(
    0xA500_0000,
    [(reg, offset, np.uint32) for reg, offset in (("N", 0x00), ("T", 0x10), ("R", 0x20))],
)

In [4]:
pl_reset.reset()
# GPIO smoke test: drive the register low then high (0x0, 0xf) four times,
# holding each level for 0.2 s.
for i in range(4):
    for level in (0x0, 0xf):
        gpio.write(level)
        sleep(0.2)

In [5]:
# ---------------------------------------------------------------------------
# Timed run: stream the staged input through the accelerator in bursts of
# PER_TRANSFER frames and report the throughput.
# ---------------------------------------------------------------------------
clock.refresh(MHz425)
pl_reset.reset()
dma.reset()
dma.mm2s.enable()
dma.s2mm.enable()

# Toggle R 0 -> 1, then program the number of frames per burst.
vit.R = 0
vit.R = 1
vit.N = PER_TRANSFER

# Only whole bursts are issued, so throughput must be computed from the
# frames actually streamed, not from TEST_N (they differ whenever TEST_N is
# not a multiple of PER_TRANSFER).
N_TRANSFERS = TEST_N // PER_TRANSFER
FRAMES_RUN = N_TRANSFERS * PER_TRANSFER

# perf_counter() is monotonic and high-resolution; time.time() is a wall
# clock that can jump (e.g. NTP adjustment) in the middle of a measurement.
start = time.perf_counter()
for i in range(N_TRANSFERS):
    vit.T = 1  # written before every burst
    # Memory -> accelerator (input) and accelerator -> memory (output) for burst i.
    dma.mm2s.transfer(lpddr0.base_addr + i * I_TRANSFER_BYTES, I_TRANSFER_BYTES)
    dma.s2mm.transfer(lpddr1.base_addr + i * O_TRANSFER_BYTES, O_TRANSFER_BYTES)
    # Block until the output side of this burst has finished.
    dma.s2mm.wait()
end = time.perf_counter()
elapsed = end - start
print(f"elapsed time: {elapsed: .4f}")
print(f"FPS: {FRAMES_RUN / elapsed}")

elapsed time:  0.2817
FPS: 7099.291476389881


In [None]:
# ---------------------------------------------------------------------------
# Verification: compare every output word in lpddr1 with the golden
# reference, frame by frame.
#
# A frame is reported as "pass" only if it had no mismatch at all, and `flag`
# becomes False on the first mismatch.  (Previously a frame was reported as
# passing unless the cumulative error count had already reached the abort
# threshold, so runs with fewer mismatches than that were never flagged.)
# ---------------------------------------------------------------------------
MAX_REPORTED_ERRORS = 20  # stop checking after this many mismatches in total

flag = True    # True while no mismatch has been seen
err_cnt = 0    # total mismatches across all frames checked

try:
    for test_n in range(TEST_N):
        if test_n % 1000 == 0:
            print(test_n)
        frame_errs = 0  # mismatches within this frame only
        for o_pixels in range(O_PIXELS):
            # The reference holds LOAD_N distinct frames, repeated cyclically.
            golden = oArray[(test_n % LOAD_N) * O_PIXELS + o_pixels]
            actual = lpddr1[test_n * O_PIXELS + o_pixels]
            # Re-interpret the stored word as a 19-bit two's-complement value:
            # anything at or above 2**18 wraps negative by subtracting 2**19.
            actual = actual - (1<<19) if actual >= (1<<18) else actual # convert to signed
            if golden != actual:
                print(f"test_n: {test_n}, o_pixels: {o_pixels}, golden: {golden}, actual: {actual}")
                err_cnt += 1
                frame_errs += 1
                flag = False
                if err_cnt >= MAX_REPORTED_ERRORS:
                    break
        if frame_errs:
            print(f"test{test_n} fail")
            if err_cnt >= MAX_REPORTED_ERRORS:
                # Enough evidence collected; stop to avoid flooding the output.
                break
        elif test_n % 1000 == 0:
            print(f"test{test_n} pass")
except KeyboardInterrupt:
    print("interrupted")
else:
    # Reached when the loop finished or stopped at the error limit.
    if flag:
        print(f"all {TEST_N} tests passed")
    else:
        print(f"verification failed: {err_cnt} mismatch(es), last test checked: {test_n}")
                