In [1]:
from riscv.riscv_ssa import *
from riscv.emulator_iop import *

In [2]:
from riscemu.IO import IOModule
from riscemu.types import T_RelativeAddress, T_AbsoluteAddress, UInt32
from riscemu.MMU import MMU

In [3]:
import numpy as np
from dataclasses import dataclass

In [8]:
class MatrixMulAcc(IOModule):
    """
    This represents a MatriMultiplication Accelerator interface that is present in shared
    memory. 
    
    You use it by writing:
        - pointer to matrix1, matrix2
        - rows, cols in matrix1
        - cols of matrix 2
        - a pointer to dest_matrix (address space can overlap with matrix1 and 2)
    
    In this order as words (4 byte segments), 
    beginning at the base address of the Module (we'll get to that).
    
    Then you write a nonzero value to the next address, and wait for a zero to be written
    """
    data = [
        0, # ptr_mtrx_1
        0, # ptr_mtrx_2
        0, # rows
        0, # cols
        0, # cols of m2
        0, # ptr_dest_mtrx
        0, # result_bit
    ]
    
    # make sure we're using numpys 32 bit little endian integers
    dtype = '<i32'
    
    def __init__(self, base: T_AbsoluteAddress, mmu: MMU):
        super(MatrixMulAcc, self).__init__(
            'MatrixMulAcc', 
            MemoryFlags(False, False), 
            size=7 * 4, 
            base=base
        )
        # reset backing buffer
        self.data = [0,0,0,0,0,0,0]
        self.mmu = mmu
        
        
    
    def write(addr: T_RelativeAddress, data: bytearray, size: int):
        assert size == 4
        assert addr % 4 == 0
        
        data_index = addr // 4
        
        if data_index == 7:
            # this is our "go" address, where we run the calculation
            self._run()
            self.data[7] == 1
            return
        
        # interpret bytes as uint32 and then get the pyhton int value
        data_as_uint32 = UInt32(data).value
        # save in data array
        self.data[data_index] = data_as_uint32
    
    def read(self, addr: T_RelativeAddress, size: int):
        assert size == 4
        index = addr // 4
        return self.data[index]
    
    def _run(self):
        # get all data from our backing array
        ptr_mtrx_1, ptr_mtrx_2, rows, cols, cols_m2, ptr_dest_mtrx = self.data

        mtrx1 = self._load_matrix_from_memory(ptr_mtrx_1, rows, cols)
        mtrx2 = self._load_matrix_from_memory(ptr_mtrx_2, cols_m2, rows)
        
        res = mtrx1 * mtrx2
        
        print(res)
        
        res_bytes = res.tobytes()
        assert len(res_bytes) == rows * cols * 4
        
        self.mmu.write(ptr_dest_mtrx, rows * cols * 4, bytearray(res_bytes))
        print("done!")
        
        
    def _load_matrix_from_memory(self, addr: Int32, rows: int, cols: int):
        buff = self.mmu.read(addr, rows * cols * 4)
        return np.frombuffer(buff, dtype=self.dtype).reshape((rows, cols))
        
        

TypeError: module() takes at most 2 arguments (3 given)