In [None]:
# task 3 – PS-PL GPIO calculator
#
# hw:
#   - ZYNQ PS + PL
#   - EMIO GPIO, 52-bit wide
#   - RTL: task3.v
#
# GPIO bit layout (LSB → MSB):
#   [23:0]   : operand inputs, PS → PL (in0 = [7:0], in1 = [15:8], in2 = [23:16])
#   [47:24]  : result, PL → PS (24-bit signed)
#   [48]     : done flag, PL → PS
#   [51:49]  : opcode, PS → PL
#
# author: Alp Bolukbasi
#

from pynq import Overlay, MMIO
from dataclasses import dataclass
import time

# -----------------------------------------------------------------------------
# load overlay
# -----------------------------------------------------------------------------
# Load the task 3 design onto the PL; the GPIO/MMIO accesses below assume
# this bitstream is active. NOTE(review): relies on pynq's Overlay
# downloading the bitstream by default — confirm against the pynq version in
# use. `overlay` is not referenced again in the visible code; it is kept for
# the side effect of loading the design.
overlay = Overlay("task3.bit")

# -----------------------------------------------------------------------------
# GPIO register definitions
# -----------------------------------------------------------------------------
# Base address of the PS GPIO controller (Zynq-7000 TRM, UG585); every
# offset below is relative to it.
GPIO_BASE = 0xE000A000

# bank 2 registers (EMIO bits 0–31)
DATA2     = 0x048  # written to drive PS → PL bits
DATA2_RO  = 0x068  # read to sample PL → PS bits
DIRM2     = 0x284  # direction mode: 1 = output from PS
OEN2      = 0x288  # output enable: 1 = driver enabled

# bank 3 registers (EMIO bits 32–63)
DATA3     = 0x04C  # written to drive PS → PL bits
DATA3_RO  = 0x06C  # read to sample PL → PS bits
DIRM3     = 0x2C4  # direction mode: 1 = output from PS
OEN3      = 0x2C8  # output enable: 1 = driver enabled

# -----------------------------------------------------------------------------
# protocol constants
# -----------------------------------------------------------------------------
# Field positions inside the 52-bit EMIO word; each field starts right after
# the previous one ends.
RESULT_LSB = 24                # result field [47:24], just above the operands
DONE_BIT   = RESULT_LSB + 24   # single done flag at bit 48
OPCODE_LSB = DONE_BIT + 1      # opcode field [51:49]

# Field-width masks.
MASK_24 = 0xFFFFFF   # 24-bit operand / result field
MASK_3  = 0b111      # 3-bit opcode field

# -----------------------------------------------------------------------------
# MMIO initialization
# -----------------------------------------------------------------------------
mmio = MMIO(GPIO_BASE, 0x1000)

# Setting a bit to 1 in DIRM/OEN makes that EMIO bit a PS-driven output.
# Every bit left at 0 is an input, so the PL can drive it. The masks are
# derived from the protocol constants so the direction setup cannot drift
# from the packing code in execute_calculation.

# bank 2 (EMIO 0–31)
# bits 0–23  : PS → PL inputs
# bits 24–31 : PL → PS result
mmio.write(DIRM2, MASK_24)
mmio.write(OEN2,  MASK_24)

# bank 3 (EMIO 32–63)
# bits 49–51 : PS → PL opcode
# bits 32–48 : PL → PS result + done
# Bank 3 bit n corresponds to EMIO bit 32 + n, so rebase the opcode position.
opcode_mask = MASK_3 << (OPCODE_LSB - 32)
mmio.write(DIRM3, opcode_mask)
mmio.write(OEN3,  opcode_mask)

# -----------------------------------------------------------------------------
# signed conversion helper
# -----------------------------------------------------------------------------
def to_signed(value, bits):
    """Interpret the low *bits* bits of *value* as a two's-complement integer.

    Bits at or above position ``bits`` are ignored, so an unmasked register
    word (or a negative Python int) still yields the field's signed value.

    Args:
        value: integer holding the raw bit pattern.
        bits: width of the signed field; must be at least 1.

    Returns:
        The signed value, in ``[-2**(bits - 1), 2**(bits - 1) - 1]``.

    Raises:
        ValueError: if *bits* is less than 1.
    """
    if bits < 1:
        raise ValueError(f"bits must be >= 1, got {bits}")
    sign_bit = 1 << (bits - 1)
    # Keep only the field itself; otherwise stray high bits leak into the result.
    field = value & ((sign_bit << 1) - 1)
    return (field ^ sign_bit) - sign_bit

# -----------------------------------------------------------------------------
# structured return type
# -----------------------------------------------------------------------------
@dataclass()
class CalcResult:
    """Outcome of a single PL calculation request.

    Fields, in positional order:
        opcode -- opcode value supplied by the caller
        result -- sign-extended 24-bit value read back from the PL
                  (0 when the request timed out)
        done   -- whether the done bit was observed before giving up
    """

    opcode: int
    result: int
    done: bool

# -----------------------------------------------------------------------------
# main execution function
# -----------------------------------------------------------------------------
# Python reference model for each 3-bit opcode the PL can receive.
# Operand naming follows the check printout: a = in2, b = in1, c = in0.
_REFERENCE_MODELS = {
    0: lambda a, b, c: b + c,
    1: lambda a, b, c: c - b,
    2: lambda a, b, c: c * b,
    3: lambda a, b, c: c >> b,  # Python raises ValueError if b (in1) < 0
    4: lambda a, b, c: b * b,
    5: lambda a, b, c: c * c * c,
    6: lambda a, b, c: a + b + c,
    7: lambda a, b, c: (5 * a * a) + (8 * a) - (4 * b * b) + (3 * b)
    + (6 * c * c) - (2 * c) + 13,
}


def _reference_result(opcode, in0, in1, in2):
    """Return the Python-model result for *opcode*, or 0 if it is unknown."""
    model = _REFERENCE_MODELS.get(opcode)
    if model is None:
        return 0
    return model(in2, in1, in0)


def _pack_request(opcode, in0, in1, in2):
    """Build the PS → PL word: operands in [23:0], opcode in [51:49]."""
    operands = ((in2 & 0xff) << 16) | ((in1 & 0xff) << 8) | (in0 & 0xff)
    return (operands & MASK_24) | ((opcode & MASK_3) << OPCODE_LSB)


def _read_status_word():
    """Read DATA2_RO then DATA3_RO and combine them into one 64-bit int."""
    low_bits = mmio.read(DATA2_RO)
    high_bits = mmio.read(DATA3_RO)
    return (high_bits << 32) | (low_bits & 0xffffffff)


def _print_report(opcode, in0, in1, in2, expected, actual):
    """Print the expected-vs-actual comparison for one completed request."""
    delta = actual - expected
    print(f"--- opcode {opcode} check ---")
    print(f"inputs   : in0(c)={in0:<3} in1(b)={in1:<3} in2(a)={in2:<3}")
    print(f"expected : {expected}")
    print(f"actual   : {actual}")
    print(f"delta    : {delta}")
    if delta != 0:
        print(">> ERROR: MISMATCH DETECTED <<")
    print("-" * 30)


def execute_calculation(opcode, in0, in1, in2, *, timeout=2.0):
    """Run one calculation on the PL and check it against the Python model.

    Packs the low 8 bits of each operand and the low 3 bits of *opcode*,
    writes them to EMIO banks 2/3, then polls until the done bit is seen or
    *timeout* seconds elapse. A comparison report is printed in either case.

    Args:
        opcode: operation selector; only the low 3 bits reach the hardware,
            and the reference model uses that same masked value.
        in0, in1, in2: operands; only the low 8 bits of each are sent.
            NOTE(review): the reference model uses the unmasked values, so
            operands outside the 8-bit range will report a mismatch.
        timeout: seconds to wait for the done bit (default 2.0).

    Returns:
        CalcResult(opcode, result, True) with the sign-extended 24-bit
        hardware result, or CalcResult(opcode, 0, False) on timeout.

    Raises:
        ValueError: for opcode 3 when in1 is negative (Python shift rule).
    """
    # Model the opcode actually driven onto the bus, so an out-of-range
    # opcode is checked against what the hardware really computes.
    expected = _reference_result(opcode & MASK_3, in0, in1, in2)

    word_out = _pack_request(opcode, in0, in1, in2)
    mmio.write(DATA2, word_out & 0xffffffff)
    mmio.write(DATA3, (word_out >> 32) & 0xffffffff)

    # NOTE(review): nothing here clears `done` before polling. If task3.v
    # holds done high from the previous request, the first read could
    # return a stale result — confirm against the RTL.
    # monotonic() is immune to wall-clock adjustments, unlike time.time().
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        full_word = _read_status_word()
        if (full_word >> DONE_BIT) & 0x1:
            actual = to_signed((full_word >> RESULT_LSB) & MASK_24, 24)
            _print_report(opcode, in0, in1, in2, expected, actual)
            return CalcResult(opcode, actual, True)

    print(f"--- opcode {opcode} check ---")
    print(">> TIMEOUT ERROR <<\n")
    return CalcResult(opcode, 0, False)


# -----------------------------------------------------------------------------
# smoke tests
# -----------------------------------------------------------------------------
# Each entry is (opcode, in0, in1, in2). Pass/fail is decided by the
# reference model inside execute_calculation, so no expected values are
# stored here.
test_cases = [
    # op  in0  in1  in2
    (0,    3,   5,   7),
    (1,   10,   4,   5),
    (2,    8,   1,   6),
    (3,   15,   3,   1),
    (4,    9,   0,   0),
    (5,    6,   2,   4),
    (6,   12,   3,   2),
    (7,    3,   2,   1),
    (7,    5,  10,   3),
]

print(f"starting task 3 verification...\n{'=' * 30}")

for case in test_cases:
    execute_calculation(*case)