In [1]:
from __future__ import annotations

from pathlib import Path

from xdsl.ir import MLContext, Region, Dialect, SSAValue
from xdsl.dialects.builtin import ModuleOp
from xdsl.printer import Printer

from toy.dialect import Toy
from toy.helpers import parse as parse_toy, print_module

import riscv.riscv_ssa as riscv_d
from riscv.emulator_iop import run_riscv, print_riscv_ssa

### WIP

# Toy-language source used as the running example for this notebook.
# `var c<2, 3> = b;` declares `c` with shape <2, 3> from the 6-element `b`
# (the Toy parser emits this as toy.reshape ops), and `d = a + c` combines two
# <2, 3> tensors before printing the result.
example = """
def main() {
  var a<2, 3> = [[1, 2, 3], [4, 5, 6]];
  var b<6> = [1, 2, 3, 4, 5, 6];
  var c<2, 3> = b;
  var d = a + c;
  print(d);
}
"""

### WIP

from xdsl.ir import Operation
from xdsl.irdl import irdl_op_definition, Operand, OpResult

from typing import Annotated

from riscv.riscv_ssa import RegisterType

@irdl_op_definition
class PrintTensorOp(Operation):
    """Custom accelerator op ``riscv.toy.print``.

    Takes a single register operand and produces no results.
    """

    name = "riscv.toy.print"

    # Register operand; presumably holds the address of the tensor to print
    # (NOTE(review): confirm against ToyAccelerator).
    rs1: Annotated[Operand, RegisterType]

    @classmethod
    def get(cls, rs1: Operation | SSAValue) -> PrintTensorOp:
        """Build a ``riscv.toy.print`` op that reads register ``rs1``."""
        operand_list = [rs1]
        return cls.build(result_types=[], operands=operand_list)

@irdl_op_definition
class AddTensorOp(Operation):
    """Custom accelerator op ``riscv.toy.add``.

    Takes three register operands (lhs, rhs, heap) and yields one register
    result.
    """

    name = "riscv.toy.add"

    # Result register; presumably the address of the summed tensor.
    rd: Annotated[OpResult, RegisterType]
    # Left-hand tensor register.
    rs1: Annotated[Operand, RegisterType]
    # Right-hand tensor register.
    rs2: Annotated[Operand, RegisterType]
    # Heap register (passed as `heap_reg` by `get`); presumably where the
    # result is allocated — NOTE(review): confirm against ToyAccelerator.
    rs3: Annotated[Operand, RegisterType]

    @classmethod
    def get(
        cls,
        lhs_reg: Operation | SSAValue,
        rhs_reg: Operation | SSAValue,
        heap_reg: Operation | SSAValue,
    ) -> AddTensorOp:
        """Build a ``riscv.toy.add`` op from lhs, rhs and heap registers."""
        inputs = [lhs_reg, rhs_reg, heap_reg]
        return cls.build(operands=inputs, result_types=[RegisterType()])

@irdl_op_definition
class ReshapeTensorOp(Operation):
    """Custom accelerator op ``riscv.toy.reshape``.

    Takes three register operands (input, shape, heap) and yields one
    register result.
    """

    name = "riscv.toy.reshape"

    # Result register; presumably the address of the reshaped tensor.
    rd: Annotated[OpResult, RegisterType]
    # Input tensor register.
    rs1: Annotated[Operand, RegisterType]
    # Target-shape register (e.g. the `main.0` label loaded in the module below).
    rs2: Annotated[Operand, RegisterType]
    # Heap register (passed as `heap_reg` by `get`); presumably where the
    # result is allocated — NOTE(review): confirm against ToyAccelerator.
    rs3: Annotated[Operand, RegisterType]

    @classmethod
    def get(
        cls,
        input_reg: Operation | SSAValue,
        shape_reg: Operation | SSAValue,
        heap_reg: Operation | SSAValue,
    ) -> ReshapeTensorOp:
        """Build a ``riscv.toy.reshape`` op from input, shape and heap registers."""
        result_type = RegisterType()
        return cls.build(
            operands=[input_reg, shape_reg, heap_reg],
            result_types=[result_type],
        )

# Group the three custom accelerator ops into one dialect so they can be
# registered on the MLContext. The empty second list presumably means the
# dialect defines no custom attributes.
ToyRISCV = Dialect([PrintTensorOp, AddTensorOp, ReshapeTensorOp], [])

### WIP

context = MLContext()

# Register, in order: the Toy dialect, the RISC-V SSA dialect, and the
# ToyRISCV accelerator dialect.
for _dialect in (Toy, riscv_d.RISCVSSA, ToyRISCV):
    context.register_dialect(_dialect)
del _dialect  # keep the notebook namespace free of the loop variable

# NOTE(review): `printer` is not used by any visible cell; kept for later use.
printer = Printer(target=Printer.Target.MLIR)

### WIP

from toy_to_riscv.accelerator import ToyAccelerator

### WIP

# Parse the Toy example and echo its Toy-dialect IR, followed by a blank line.
print_module(module := parse_toy(example))
print()

### WIP

# Hand-written RISC-V SSA lowering of `example`.
# NOTE(review): the `.word` layout appears to be
#   [remaining word count, rank, dims..., element count, elements...]
# for `main.a`/`main.b`, while `main.0` looks like a bare [rank, dims...]
# shape. Confirm against toy_to_riscv.accelerator.
data_section = riscv_d.DataSectionOp.from_ops([
    riscv_d.LabelOp.get("main.a"),
    riscv_d.DirectiveOp.get(".word", "0xA, 0x2, 0x2, 0x3, 0x6, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6"),
    riscv_d.LabelOp.get("main.b"),
    riscv_d.DirectiveOp.get(".word", "0x9, 0x1, 0x6, 0x6, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6"),
    riscv_d.LabelOp.get("main.0"),
    riscv_d.DirectiveOp.get(".word", "0x2, 0x2, 0x3"),
])

# Body of `main`, built op by op in program order.
# 'heap' is not declared in the data section above; presumably the
# emulator/accelerator provides it (NOTE(review): confirm).
heap = riscv_d.LIOp.get('heap')
a = riscv_d.LIOp.get('main.a')
b = riscv_d.LIOp.get('main.b')
c_shape = riscv_d.LIOp.get('main.0')
c = ReshapeTensorOp.get(b, c_shape, heap)  # var c<2, 3> = b;
d = AddTensorOp.get(a, c, heap)            # var d = a + c;
print_op = PrintTensorOp.get(d)            # print(d);
ret_op = riscv_d.ReturnOp.get()

main_func = riscv_d.FuncOp.from_ops(
    'main', [heap, a, b, c_shape, c, d, print_op, ret_op]
)

module = ModuleOp.from_region_or_ops([data_section, main_func])

print_module(module)
print()

### WIP

# Render the RISC-V SSA module to the program text later fed to run_riscv,
# and show it followed by a blank line.
code = print_riscv_ssa(module)
print(code, end="\n\n")

### WIP

import contextlib
import io

# Run the generated program with the Toy accelerator extension, capturing all
# stdout it produces, then echo the captured text. redirect_stdout's context
# value is the StringIO target itself, so `f` is the capture buffer.
with contextlib.redirect_stdout(io.StringIO()) as f:
    run_riscv(code, extensions=[ToyAccelerator], unlimited_regs=True)
output = f.getvalue()

print(output)

"builtin.module"() ({
  "toy.func"() ({
    %0 = "toy.constant"() {"value" = dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>} : () -> tensor<2x3xi32>
    %1 = "toy.reshape"(%0) : (tensor<2x3xi32>) -> tensor<2x3xi32>
    %2 = "toy.constant"() {"value" = dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32>} : () -> tensor<6xi32>
    %3 = "toy.reshape"(%2) : (tensor<6xi32>) -> tensor<6xi32>
    %4 = "toy.reshape"(%3) : (tensor<6xi32>) -> tensor<2x3xi32>
    %5 = "toy.add"(%1, %4) : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
    "toy.print"(%5) : (tensor<2x3xi32>) -> ()
    "toy.return"() : () -> ()
  }) {"sym_name" = "main", "function_type" = () -> ()} : () -> ()
}) : () -> ()

"builtin.module"() ({
  "riscv.data_section"() ({
    "riscv_ssa.label"() {"label" = #riscv.label<main.a>} : () -> ()
    "riscv_ssa.directive"() {"directive" = ".word", "value" = "0xA, 0x2, 0x2, 0x3, 0x6, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6"} : () -> ()
    "riscv_ssa.label"() {"label" = #riscv.label<main.b>} : () -> ()
