In [21]:
from __future__ import annotations

from pathlib import Path

from xdsl.ir import MLContext, Region, Dialect
from xdsl.dialects.builtin import ModuleOp
from xdsl.printer import Printer

from toy.dialect import Toy
from toy.helpers import parse as parse_toy, print_module

from riscv.riscv_ssa import LabelOp, LIOp, MULOp, AddOp, ECALLOp, RISCVSSA, DirectiveOp, LWOp, PrintOp
from riscv.emulator_iop import run_riscv, print_riscv_ssa

### WIP

# Toy source program compiled below. The `#` lines are Toy comments, so only
# `a` is defined and printed (see the parsed IR further down).
example = (
    "\n"
    "def main() {\n"
    "  var a<2, 3> = [[1, 2, 3], [4, 5, 6]];\n"
    "  # var b<2, 3> = [1, 2, 3, 4, 5, 6];\n"
    "  # var c = a + b;\n"
    "  print(a);\n"
    "}\n"
)

### WIP

from xdsl.ir import Operation
from xdsl.irdl import irdl_op_definition, Operand

from typing import Annotated

from riscv.riscv_ssa import RegisterType

@irdl_op_definition
class PrintTensorOp(Operation):
    """
    RISC-V level operation that prints a toy tensor stored in memory.

    The single operand is the register holding the tensor's base address; the
    tensor is laid out as [rank, *shape, count, *data]. The op produces no
    results and is emitted as the `toy.print` instruction (presumably
    dispatched by riscemu to ToyAccelerator.instruction_toy_print — confirm).
    """
    name = "riscv.toy.print"

    # Register operand carrying the tensor's address.
    rs1: Annotated[Operand, RegisterType]

    @classmethod
    def get(cls, rs1: Operation) -> PrintTensorOp:
        """
        Build a print-tensor operation.

        :param rs1: the operation whose result is the tensor's address
                    (e.g. an `li` of the tensor's data label).
        :return: a new PrintTensorOp with no results.
        """
        return cls.build(operands=[rs1], result_types=[])

# Dialect bundling the RISC-V level toy operations (currently just the tensor print).
ToyRISCV = Dialect([PrintTensorOp], [])

### WIP

# Printer targeting MLIR textual syntax.
printer = Printer(target=Printer.Target.MLIR)

# Context that knows the Toy frontend dialect, RISC-V SSA, and the toy extension ops.
context = MLContext()
context.register_dialect(Toy)
context.register_dialect(RISCVSSA)
context.register_dialect(ToyRISCV)

### WIP

from riscemu.instructions import InstructionSet, Instruction
from riscemu.MMU import MMU
from riscemu.types import Int32

from typing import cast

from functools import reduce

def tensor_description(shape: list[int], data: list[int]) -> str:
    """
    Format a flat, row-major buffer of elements as a nested-list string.

    :param shape: dimensions of the tensor, outermost first.
    :param data: the tensor's elements in row-major order.
    :return: e.g. "[[1, 2, 3], [4, 5, 6]]" for shape [2, 3] and data [1..6].
             A rank-1 shape renders `data` as-is; an empty shape renders "[]".
    """
    if not shape:
        # NOTE(review): a rank-0 tensor renders as "[]" regardless of `data`,
        # as before — confirm whether scalars should print their value.
        return '[]'
    if len(shape) == 1:
        return str(data)
    # Number of elements in each sub-tensor along the outermost dimension.
    size = reduce(lambda acc, el: acc * el, shape[1:], 1)
    # Index the outer dimension directly instead of stepping `range` by `size`:
    # a zero-sized inner dimension would otherwise mean range(..., step=0),
    # which raises ValueError.
    rows = (
        tensor_description(shape[1:], data[i * size:(i + 1) * size])
        for i in range(shape[0])
    )
    return f'[{", ".join(rows)}]'

print(tensor_description([2, 3], [1,2,3,4,5,6]))

# Define a RISC-V ISA extension by subclassing InstructionSet
class ToyAccelerator(InstructionSet):
    # each method beginning with instruction_ will be available to the Emulator

    def read_ptr(self, ptr: int, /, offset: int = 0, *, signed: bool = True) -> int:
        """
        Read one 4-byte little-endian word from emulator memory.

        :param ptr: base address.
        :param offset: index of the word to read, counted in 4-byte words from `ptr`.
        :param signed: decode as a two's-complement int32 (default) or as unsigned.
        :return: the decoded integer.
        """
        mmu = cast(MMU, self.mmu)
        byte_array = mmu.read(ptr + offset * 4, 4)
        # Tensor elements are int32; decoding unsigned would turn e.g. -1 into
        # 4294967295 when printed.
        return int.from_bytes(byte_array, byteorder="little", signed=signed)

    def read_buffer(self, ptr: int, count: int, /, offset: int = 0, *, signed: bool = True) -> list[int]:
        """
        Read `count` consecutive 4-byte words starting `offset` words after `ptr`.

        :param ptr: base address.
        :param count: number of words to read.
        :param offset: index of the first word, in words from `ptr`.
        :param signed: forwarded to read_ptr.
        :return: the decoded words, in memory order.
        """
        return [
            self.read_ptr(ptr, index, signed=signed)
            for index in range(offset, offset + count)
        ]

    def instruction_toy_print(self, ins: Instruction):
        """
        The tensor is represented as a pointer to an array with the following layout
        [ 2,      2, 3,       6,   1, 2, 3, 4, 5, 6]
        [rank, ...shape..., count,    ...data...   ]

        Where rank is the length of the shape subarray, and count is the length of data.

        This instruction prints a formatted tensor
        [[1, 2, 3], [4, 5, 6]]
        """
        # The single operand register holds the tensor's base address.
        t_ptr_reg = ins.get_reg(0)
        t_ptr = int(self.regs.get(t_ptr_reg))

        # Header: rank, then `rank` shape entries, then the element count.
        t_rank = self.read_ptr(t_ptr)
        t_shape = self.read_buffer(t_ptr, t_rank, offset=1)
        t_len = self.read_ptr(t_ptr, offset=1 + t_rank)
        # Payload: `count` elements immediately after the header.
        t_data = self.read_buffer(t_ptr, t_len, offset=1 + t_rank + 1)

        print(tensor_description(t_shape, t_data))

        

### WIP

# Parse the Toy source into IR, display it, then emit a blank separator line.
print_module(module := parse_toy(example))
print()

### WIP

# Hand-written RISC-V SSA program: a small heap, the tensor `main.a` laid out as
# [rank, *shape, count, *data] (the layout read by instruction_toy_print), and a
# `main` routine exercising arithmetic, the toy print extension, and a load.
module = ModuleOp.from_region_or_ops([
    DirectiveOp.get(".bss", ""),  # bss stands for block starting symbol
    LabelOp.get("heap"),
    DirectiveOp.get(".space", "100"),
    DirectiveOp.get(".data", ""),
    LabelOp.get("main.a"),
    # rank 2, shape (2, 3), 6 elements, then the data 1..6 -- emitted as "0x.." words
    *(DirectiveOp.get(".word", hex(word)) for word in (2, 2, 3, 6, 1, 2, 3, 4, 5, 6)),
    DirectiveOp.get(".text", ""),
    LabelOp.get('main'),
    a0 := LIOp.get(83),
    a1 := LIOp.get(5),
    mul := MULOp.get(a0, a1),
    a2 := LIOp.get(10),
    add := AddOp.get(mul, a2),
    PrintOp.get(add),  # debug instruction to print register contents
    data_ptr := LIOp.get("main.a"),
    PrintTensorOp.get(data_ptr),
    PrintOp.get(data_ptr),
    data := LWOp.get(data_ptr, 0),
    PrintOp.get(data),
    # perform the "exit" syscall, opcode 93
    ECALLOp.get(93),
])

### WIP

code = print_riscv_ssa(module)
print(code)
print()

# Pin the emitted assembly so changes to the lowering or printer are caught on
# re-run. Written with explicit escapes because the section directives carry a
# trailing space and instruction fields are tab-separated.
expected_code = (
    ".bss \n"
    "heap:\n"
    ".space 100\n"
    ".data \n"
    "main.a:\n"
    ".word 0x2\n"
    ".word 0x2\n"
    ".word 0x3\n"
    ".word 0x6\n"
    ".word 0x1\n"
    ".word 0x2\n"
    ".word 0x3\n"
    ".word 0x4\n"
    ".word 0x5\n"
    ".word 0x6\n"
    ".text \n"
    "main:\n"
    "\tli\t%0, 83\n"
    "\tli\t%1, 5\n"
    "\tmul\t%2, %0, %1\n"
    "\tli\t%3, 10\n"
    "\tadd\t%4, %2, %3\n"
    "\tprint\t%4\n"
    "\tli\t%5, main.a\n"
    "\ttoy.print\t%5\n"
    "\tprint\t%5\n"
    "\tlw\t%6, %5, 0\n"
    "\tprint\t%6\n"
    "\tli\ta7, 93\n"
    "\tscall\n"
)
assert code == expected_code, "emitted RISC-V assembly changed"

### WIP

# Assemble and execute the program in riscemu, with ToyAccelerator providing the
# `toy.print` instruction. `unlimited_regs=True` presumably lets the emulator
# accept the unallocated %N registers in the listing — confirm.
# NOTE(review): the captured output ends right after `mul`; check whether other
# extensions (e.g. RV32M or the debug `print`) must also be listed when
# `extensions` is passed.
run_riscv(
    print_riscv_ssa(module),
    extensions=[ToyAccelerator],
    unlimited_regs=True,
)

[[1, 2, 3], [4, 5, 6]]
"builtin.module"() ({
  "toy.func"() ({
    %0 = "toy.constant"() {"value" = dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>} : () -> tensor<2x3xi32>
    %1 = "toy.reshape"(%0) : (tensor<2x3xi32>) -> tensor<2x3xi32>
    "toy.print"(%1) : (tensor<2x3xi32>) -> ()
    "toy.return"() : () -> ()
  }) {"sym_name" = "main", "function_type" = () -> ()} : () -> ()
}) : () -> ()

.bss 
heap:
.space 100
.data 
main.a:
.word 0x2
.word 0x2
.word 0x3
.word 0x6
.word 0x1
.word 0x2
.word 0x3
.word 0x4
.word 0x5
.word 0x6
.text 
main:
	li	%0, 83
	li	%1, 5
	mul	%2, %0, %1
	li	%3, 10
	add	%4, %2, %3
	print	%4
	li	%5, main.a
	toy.print	%5
	print	%5
	lw	%6, %5, 0
	print	%6
	li	a7, 93
	scall


[34m[1m[CPU] Started running from example.asm:.text at heap (0x100) + 0x8c[0m
Program(name=example.asm,sections=set(),base=['.bss', '.data', '.text'])
[34m[1m   Running 0x0000018C:[0m li %0, 83
[34m[1m   Running 0x00000190:[0m li %1, 5
[34m[1m   Running 0x00000194:[0m mul %2, %0, %1


In [7]:
# Display the raw assembly string; its repr makes the tabs and trailing spaces visible.
code

'.bss \nheap:\n.space 100\n.data \nmain.a:\n.word 0x2\n.word 0x2\n.word 0x3\n.word 0x6\n.word 0x1\n.word 0x2\n.word 0x3\n.word 0x4\n.word 0x5\n.word 0x6\n.text \nmain:\n\tli\t%0, 83\n\tli\t%1, 5\n\tmul\t%2, %0, %1\n\tli\t%3, 10\n\tadd\t%4, %2, %3\n\tprint\t%4\n\tli\t%5, main.a\n\ttoy.print\t%5\n\tprint\t%5\n\tlw\t%6, %5, 0\n\tprint\t%6\n\tli\ta7, 93\n\tscall\n'