In [1]:
from __future__ import annotations

from pathlib import Path

from xdsl.ir import MLContext, Region, Dialect, SSAValue
from xdsl.dialects.builtin import ModuleOp
from xdsl.printer import Printer

from toy.dialect import Toy
from toy.helpers import parse as parse_toy, print_module

import riscv.riscv_ssa as riscv_d
from riscv.emulator_iop import run_riscv, print_riscv_ssa

### WIP

# Toy source program used by the cells below: two 2x3 tensors (`b` is written
# flat and given its 2x3 shape by the declaration), their sum, and a print.
example = """
def main() {
  var a<2, 3> = [[1, 2, 3], [4, 5, 6]];
  var b<2, 3> = [1, 2, 3, 4, 5, 6];
  var c = a + b;
  print(c);
}
"""

### WIP

from xdsl.ir import Operation
from xdsl.irdl import irdl_op_definition, Operand, OpResult

from typing import Annotated

from riscv.riscv_ssa import RegisterType

@irdl_op_definition
class PrintTensorOp(Operation):
    """
    Custom RISC-V op that prints a tensor.

    Takes one register operand holding the tensor's address and produces no
    results. Its emulator-side counterpart is
    ``ToyAccelerator.instruction_toy_print``.
    """

    name = "riscv.toy.print"

    # Register holding the address of the tensor to print.
    rs1: Annotated[Operand, RegisterType]

    @classmethod
    def get(cls, rs1: Operation | SSAValue) -> PrintTensorOp:
        """Build the op, reading the tensor address from ``rs1``."""
        inputs = [rs1]
        return cls.build(operands=inputs, result_types=[])

@irdl_op_definition
class AddTensorOp(Operation):
    """
    Custom RISC-V op computing the element-wise sum of two tensors.

    Operands, in order: ``rs1`` (lhs tensor address), ``rs2`` (rhs tensor
    address), ``rs3`` (heap pointer used to allocate the result). The single
    result ``rd`` is a register receiving the address of the new tensor. Its
    emulator-side counterpart is ``ToyAccelerator.instruction_toy_add``.
    """

    name = "riscv.toy.add"

    # Destination register: address of the freshly allocated sum.
    rd: Annotated[OpResult, RegisterType]
    # Left-hand tensor address.
    rs1: Annotated[Operand, RegisterType]
    # Right-hand tensor address.
    rs2: Annotated[Operand, RegisterType]
    # Heap pointer used for the result allocation.
    rs3: Annotated[Operand, RegisterType]

    @classmethod
    def get(cls, lhs_reg: Operation | SSAValue, rhs_reg: Operation | SSAValue, heap_reg: Operation | SSAValue) -> AddTensorOp:
        """Build the op from the lhs, rhs and heap-pointer registers."""
        inputs = [lhs_reg, rhs_reg, heap_reg]
        outputs = [RegisterType()]
        return cls.build(operands=inputs, result_types=outputs)

# Dialect grouping the two custom accelerator ops (it declares no attributes).
ToyRISCV = Dialect(
    [PrintTensorOp, AddTensorOp],
    [],
)

### WIP

# Context that can resolve the Toy frontend ops, the RISC-V SSA ops and the
# accelerator ops above; dialects are registered in exactly this order.
context = MLContext()
for _dialect in (Toy, riscv_d.RISCVSSA, ToyRISCV):
    context.register_dialect(_dialect)
del _dialect  # keep the loop variable out of the notebook namespace

# Printer configured for the MLIR target.
printer = Printer(target=Printer.Target.MLIR)

### WIP

from riscemu.instructions import InstructionSet, Instruction
from riscemu.MMU import MMU
from riscemu.types import Int32

from typing import cast

from functools import reduce

def tensor_description(shape: list[int], data: list[int]) -> str:
    """
    Format flat row-major ``data`` as nested lists according to ``shape``.

    ``tensor_description([2, 3], [1, 2, 3, 4, 5, 6])`` returns
    ``'[[1, 2, 3], [4, 5, 6]]'``. A rank-0 shape (``[]``) is rendered as
    ``'[]'``, matching the empty-tensor encoding ``[2, 0, 0]`` used by the
    accelerator. Zero-sized dimensions are supported:
    ``tensor_description([2, 0], [])`` returns ``'[[], []]'``.

    No consistency check between ``shape`` and ``len(data)`` is performed.
    """
    if not shape:
        return '[]'
    if len(shape) == 1:
        return str(data)
    # Number of flat elements in one slice along the leading dimension.
    size = reduce(lambda acc, el: acc * el, shape[1:], 1)
    # Index by slice number rather than stepping by `size`: a zero-sized
    # inner dimension makes `size` 0, and range(..., ..., 0) would raise.
    rows = (
        tensor_description(shape[1:], data[i * size:(i + 1) * size])
        for i in range(shape[0])
    )
    return f'[{", ".join(rows)}]'

# Define a RISC-V ISA extension by subclassing riscemu's InstructionSet
class ToyAccelerator(InstructionSet):
    """
    RISC-V ISA extension implementing the Toy tensor instructions.

    Each method whose name begins with ``instruction_`` is made available to the
    emulator as a custom instruction (``instruction_toy_print`` and
    ``instruction_toy_add`` below). All other methods are helpers that read and
    write 32-bit little-endian words of emulator memory through ``self.mmu``;
    ``offset`` arguments are word indices (4 bytes each).
    """

    # Raw memory helpers

    def ptr_read(self, ptr: int, /, offset: int = 0) -> int:
        """Return the signed 32-bit word stored at ``ptr + 4 * offset``."""
        mmu = cast(MMU, self.mmu)
        byte_array = mmu.read(ptr + offset * 4, 4)
        # Tensor elements are i32, so decode as two's complement rather than
        # unsigned (which would turn -1 into 4294967295).
        return int.from_bytes(byte_array, byteorder="little", signed=True)

    def ptr_write(self, ptr: int, /, value: int, offset: int = 0) -> None:
        """
        Store ``value`` as a 32-bit word at ``ptr + 4 * offset``.

        ``value`` is truncated to its low 32 bits (two's complement), so negative
        numbers and overflowing sums wrap like i32 arithmetic instead of making
        ``int.to_bytes`` raise ``OverflowError``.
        """
        mmu = cast(MMU, self.mmu)
        byte_array = bytearray((value & 0xFFFFFFFF).to_bytes(4, byteorder="little"))
        mmu.write(ptr + offset * 4, 4, byte_array)

    def buffer_read(self, ptr: int, count: int, /, offset: int = 0) -> list[int]:
        """Return ``count`` consecutive words starting at word index ``offset``."""
        return [self.ptr_read(ptr, i) for i in range(offset, offset + count)]

    def buffer_write(self, ptr: int, /, data: list[int], offset: int = 0) -> None:
        """Write ``data`` as consecutive words starting at word index ``offset``."""
        for i, value in enumerate(data):
            self.ptr_write(ptr, value=value, offset=offset + i)

    def buffer_copy(self, /, source: int, destination: int, count: int) -> None:
        """Copy ``count`` words from ``source`` to ``destination`` as raw bytes."""
        mmu = cast(MMU, self.mmu)
        mmu.write(destination, count * 4, mmu.read(source, count * 4))

    # Vector helpers

    # A vector is represented as an array of ints, where the first int is the count:
    # [] -> [0]
    # [1] -> [1, 1]
    # [1, 2, 3] -> [3, 1, 2, 3]

    def vector_count(self, ptr: int) -> int:
        """Return the element count stored in the vector's first word."""
        return self.ptr_read(ptr)

    def vector_data(self, ptr: int) -> list[int]:
        """Return the vector's elements (the words following the count)."""
        count = self.vector_count(ptr)
        return self.buffer_read(ptr, count, offset=1)

    def vector_end(self, ptr: int) -> int:
        """Return the address one past the vector's last element."""
        return ptr + (1 + self.vector_count(ptr)) * 4

    def vector_add(self, lhs: int, rhs: int) -> None:
        """
        In-place element-wise ``lhs += rhs``.

        Lengths are not checked; only the first ``min(count(lhs), count(rhs))``
        elements of ``lhs`` are updated (``zip`` semantics).
        """
        data = [x + y for x, y in zip(self.vector_data(lhs), self.vector_data(rhs))]
        self.buffer_write(lhs, data=data, offset=1)

    # Heap helpers

    # The heap pointer is the address of the start of the heap; its first word
    # holds the number of words allocated so far (expected to be 0 initially).
    # New allocations are placed right after the previous ones, so the heap can
    # be used as an append-only vector.

    def alloc(self, heap_ptr: int, /, count: int) -> int:
        """Reserve ``count`` words on the heap and return the address of the first."""
        result = self.vector_end(heap_ptr)
        heap_size = self.vector_count(heap_ptr)
        # The header is a word count (vector_end multiplies it by 4), so advance
        # it by `count` words; adding `count * 4` over-reserved by a factor of 4.
        self.ptr_write(heap_ptr, value=heap_size + count)
        return result

    def vector_copy(self, ptr: int, /, heap_ptr: int) -> int:
        """Copy the vector (count word included) onto the heap; return the copy's address."""
        storage_len = self.vector_count(ptr) + 1
        new = self.alloc(heap_ptr, count=storage_len)
        self.buffer_copy(source=ptr, destination=new, count=storage_len)
        return new

    # Tensor helpers

    # The tensor is represented as a vector, containing two concatenated vectors:
    # the shape, followed by the data
    # [] -> [2, 0, 0] (rank: 0, shape: [], count: 0, data: [])
    # [1, 2] -> [5, 1, 2, 2, 1, 2] (rank: 1, shape: [2], count: 2, data: [1, 2])
    # [[1, 2, 3], [4, 5, 6]]
    #   -> [10, 2, 2, 3, 6, 1, 2, 3, 4, 5, 6] (
    #       rank: 2,
    #       shape: [2, 3],
    #       count: 6,
    #       data: [1, 2, 3, 4, 5, 6]
    #   )

    # Where rank is the length of the shape subarray, and count is the length of data.

    def tensor_shape_array(self, ptr: int) -> int:
        """Return the address of the shape vector (right after the outer count)."""
        return ptr + 4

    def tensor_rank(self, ptr: int) -> int:
        """Return the tensor's rank (length of its shape vector)."""
        return self.vector_count(self.tensor_shape_array(ptr))

    def tensor_shape(self, ptr: int) -> list[int]:
        """Return the tensor's shape as a list of dimension sizes."""
        return self.vector_data(self.tensor_shape_array(ptr))

    def tensor_data_array(self, ptr: int) -> int:
        """Return the address of the data vector (right after the shape vector)."""
        return self.vector_end(self.tensor_shape_array(ptr))

    def tensor_count(self, ptr: int) -> int:
        """Return the number of data elements in the tensor."""
        return self.vector_count(self.tensor_data_array(ptr))

    def tensor_data(self, ptr: int) -> list[int]:
        """Return the tensor's flat data."""
        return self.vector_data(self.tensor_data_array(ptr))

    def tensor_storage_len(self, ptr: int) -> int:
        '''
        rank + count + 2
        '''
        return self.vector_count(ptr)

    def tensor_copy(self, ptr: int, /, heap_ptr: int) -> int:
        """Copy the whole tensor onto the heap; return the copy's address."""
        return self.vector_copy(ptr, heap_ptr=heap_ptr)

    def tensor_add(self, lhs: int, rhs: int) -> None:
        '''lhs += rhs (element-wise on the data vectors; shapes are not checked)'''
        self.vector_add(self.tensor_data_array(lhs), self.tensor_data_array(rhs))

    # Custom instructions

    def instruction_toy_print(self, ins: Instruction) -> None:
        """
        Print the tensor whose address is in the instruction's first register,
        formatted as nested lists, e.g.
        [[1, 2, 3], [4, 5, 6]]
        """
        t_ptr_reg = ins.get_reg(0)
        t_ptr = int(self.regs.get(t_ptr_reg))

        shape = self.tensor_shape(t_ptr)
        data = self.tensor_data(t_ptr)

        print(tensor_description(shape, data))

    def instruction_toy_add(self, ins: Instruction) -> None:
        """
        Allocate on the heap a tensor with the same shape as the inputs holding
        their element-wise sum, and store its address in the destination register.

        Register operands: 0 = destination, 1 = lhs tensor, 2 = rhs tensor,
        3 = heap pointer. The only validity check is an ``assert`` that the two
        shapes are equal (skipped under ``python -O``).
        """
        destination_ptr_reg = ins.get_reg(0)
        lhs_ptr_reg = ins.get_reg(1)
        rhs_ptr_reg = ins.get_reg(2)
        heap_ptr_reg = ins.get_reg(3)

        l_ptr = int(self.regs.get(lhs_ptr_reg))
        r_ptr = int(self.regs.get(rhs_ptr_reg))
        h_ptr = int(self.regs.get(heap_ptr_reg))

        l_shape = self.tensor_shape(l_ptr)
        r_shape = self.tensor_shape(r_ptr)

        assert l_shape == r_shape, f"toy.add shape mismatch: {l_shape} vs {r_shape}"

        # Copy lhs so the inputs are left untouched, then accumulate rhs into it.
        d_ptr = self.tensor_copy(l_ptr, heap_ptr=h_ptr)
        self.tensor_add(d_ptr, r_ptr)

        self.regs.set(destination_ptr_reg, Int32(d_ptr))


### WIP

# Parse the Toy example into a Toy-dialect module and print it.
module = parse_toy(example)
print_module(module)
print()

### WIP

# Hand-written RISC-V SSA version of `example`. Each tensor is laid out in the
# data section with the ToyAccelerator tensor encoding
# [rank + count + 2, rank, *shape, count, *data]; `main.b` holds the same words
# as `main.a` because b's flat literal is declared with shape 2x3.
# NOTE(review): `heap` is loaded by label, but no `heap` label is defined in the
# data section below — confirm that the printer/emulator provides one and that
# its first word starts at 0 (the ToyAccelerator heap header).
module = ModuleOp.from_region_or_ops([
    riscv_d.DataSectionOp.from_ops([
        riscv_d.LabelOp.get("main.a"),
        riscv_d.DirectiveOp.get(".word", "0xA, 0x2, 0x2, 0x3, 0x6, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6"),
        riscv_d.LabelOp.get("main.b"),
        riscv_d.DirectiveOp.get(".word", "0xA, 0x2, 0x2, 0x3, 0x6, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6"),
    ]),
    riscv_d.FuncOp.from_ops('main', [
        # The walrus targets also become module-level names (heap, a, b, c).
        heap := riscv_d.LIOp.get('heap'),
        a := riscv_d.LIOp.get('main.a'),
        b := riscv_d.LIOp.get('main.b'),
        c := AddTensorOp.get(a, b, heap),
        PrintTensorOp.get(c),
        riscv_d.ReturnOp.get(),
    ])
])

print_module(module)
print()

### WIP

# Render the RISC-V SSA module as assembly text and show it.
code = print_riscv_ssa(module)
print(code)
print()

### WIP

import contextlib
import io

# Run the assembly in the emulator with the ToyAccelerator extension enabled,
# buffering everything it prints so it is shown in one piece afterwards.
# redirect_stdout returns its target from __enter__, so `f` is the StringIO.
with contextlib.redirect_stdout(io.StringIO()) as f:
    run_riscv(code, extensions=[ToyAccelerator], unlimited_regs=True)
output = f.getvalue()

print(output)

"builtin.module"() ({
  "toy.func"() ({
    %0 = "toy.constant"() {"value" = dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>} : () -> tensor<2x3xi32>
    %1 = "toy.reshape"(%0) : (tensor<2x3xi32>) -> tensor<2x3xi32>
    %2 = "toy.constant"() {"value" = dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32>} : () -> tensor<6xi32>
    %3 = "toy.reshape"(%2) : (tensor<6xi32>) -> tensor<2x3xi32>
    %4 = "toy.add"(%1, %3) : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
    "toy.print"(%4) : (tensor<2x3xi32>) -> ()
    "toy.return"() : () -> ()
  }) {"sym_name" = "main", "function_type" = () -> ()} : () -> ()
}) : () -> ()

"builtin.module"() ({
  "riscv.data_section"() ({
    "riscv_ssa.label"() {"label" = #riscv.label<main.a>} : () -> ()
    "riscv_ssa.directive"() {"directive" = ".word", "value" = "0xA, 0x2, 0x2, 0x3, 0x6, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6"} : () -> ()
    "riscv_ssa.label"() {"label" = #riscv.label<main.b>} : () -> ()
    "riscv_ssa.directive"() {"directive" = ".word", "value" = 