In [3]:
from pathlib import Path

from xdsl.ir import MLContext, Region
from xdsl.dialects.builtin import ModuleOp
from xdsl.printer import Printer

from toy.dialect import Toy
from toy.helpers import parse as parse_toy, print_module

from riscv.riscv_ssa import LabelOp, LIOp, MULOp, AddOp, ECALLOp, RISCVSSA, DirectiveOp, LWOp, PrintOp
from riscv.emulator_iop import run_riscv, print_riscv_ssa

### WIP

example = """
def main() {
  var a<2, 3> = [[1, 2, 3], [4, 5, 6]];
  # var b<2, 3> = [1, 2, 3, 4, 5, 6];
  # var c = a + b;
  print(a);
}
"""

### WIP

context = MLContext()

context.register_dialect(Toy)
context.register_dialect(RISCVSSA)

printer = Printer(target=Printer.Target.MLIR)

### WIP

module = parse_toy(example)
print_module(module)
print()


### WIP

module = ModuleOp.from_region_or_ops([
    DirectiveOp.get(".bss", ""), # bss stands for block starting symbol
    LabelOp.get("heap"),
    DirectiveOp.get(".space", "100"),    
    DirectiveOp.get(".data", ""),
    LabelOp.get("main.a"),
    DirectiveOp.get(".word", "0x2"),
    DirectiveOp.get(".word", "0x2"),
    DirectiveOp.get(".word", "0x3"),
    DirectiveOp.get(".word", "0x6"),
    DirectiveOp.get(".word", "0x1"),
    DirectiveOp.get(".word", "0x2"),
    DirectiveOp.get(".word", "0x3"),
    DirectiveOp.get(".word", "0x4"),
    DirectiveOp.get(".word", "0x5"),
    DirectiveOp.get(".word", "0x6"),
    DirectiveOp.get(".text", ""),
    LabelOp.get('main'),
    a0  := LIOp.get(83),
    a1  := LIOp.get(5),
    mul := MULOp.get(a0, a1),
    a2  := LIOp.get(10),
    add := AddOp.get(mul, a2),
    PrintOp.get(add), # debug instruction to print register contents
    data_ptr := LIOp.get("main.a"),
    PrintOp.get(data_ptr),
    data := LWOp.get(data_ptr, 0),
    PrintOp.get(data),
    # perform the "exit" syscall, opcode 93
    ECALLOp.get(93)
])

### WIP

print(print_riscv_ssa(module))
print()

### WIP

run_riscv(print_riscv_ssa(module), unlimited_regs=True)


"builtin.module"() ({
  "toy.func"() ({
    %0 = "toy.constant"() {"value" = dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>} : () -> tensor<2x3xi32>
    %1 = "toy.reshape"(%0) : (tensor<2x3xi32>) -> tensor<2x3xi32>
    "toy.print"(%1) : (tensor<2x3xi32>) -> ()
    "toy.return"() : () -> ()
  }) {"sym_name" = "main", "function_type" = () -> ()} : () -> ()
}) : () -> ()

.text
.bss 
heap:
.space 100
.data 
main.a:
.word 0x2
.word 0x2
.word 0x3
.word 0x6
.word 0x1
.word 0x2
.word 0x3
.word 0x4
.word 0x5
.word 0x6
.text 
main:
	li	%0, 83
	li	%1, 5
	mul	%2, %0, %1
	li	%3, 10
	add	%4, %2, %3
	print	%4
	li	%5, main.a
	print	%5
	lw	%6, %5, 0
	print	%6
	li	a7, 93
	scall


[34m[1m[CPU] Started running from example.asm:.text at heap (0x100) + 0x8c[0m
Program(name=example.asm,sections=set(),base=['.text', '.bss', '.data', '.text'])
[34m[1m   Running 0x0000018C:[0m li %0, 83
[34m[1m   Running 0x00000190:[0m li %1, 5
[34m[1m   Running 0x00000194:[0m mul %2, %0, %1
[34m[1m   Running 0x

In [15]:
from xdsl.ir import Operation, SSAValue

def bla(ops: tuple[Operation, ...]) -> ModuleOp:
    return ModuleOp.from_region_or_ops(list(ops))

module = bla((
    LabelOp.get('main'),
    a0  := LIOp.get(83),
    a1  := LIOp.get(5),
    mul := MULOp.get(a0, a1),
    a2  := LIOp.get(10),
    add := AddOp.get(mul, a2),
    PrintOp.get(add), # debug instruction to print register contents
    # perform the "exit" syscall, opcode 93
    ECALLOp.get(93)
))

risc_v_module = bla((
    DirectiveOp.get(".bss", ""), # bss is standard name for heap
    LabelOp.get("heap"),
    DirectiveOp.get(".space", "16k"),
    DirectiveOp.get(".data", ""),
    LabelOp.get("main.a.data"),
    DirectiveOp.get(".words", "0x6, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6"),
    LabelOp.get("main.a.shape"),
    DirectiveOp.get(".words", "0x2, 0x2, 0x3"),
    LabelOp.get("main.b.data"),
    DirectiveOp.get(".words", "0x6, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6"),
    LabelOp.get("main.b.shape"),
    DirectiveOp.get(".text", ""),
    LabelOp.get('main'),
    a_shape := LIOp.get("main.a.shape"),
    a_data := LIOp.get("main.a.data"),
    a_rank := LWOp.get(a_shape, 0),
    a_count := LWOp.get(a_data, 0),
    a0  := LIOp.get(82),
    a1  := LIOp.get(5),
    mul := MULOp.get(a0, a1),
    a2  := LIOp.get(10),
    add := AddOp.get(mul, a2),
    ECALLOp.get(93, add)
))

printer.print(risc_v_module)

"builtin.module"() ({
  "riscv_ssa.directive"() {"label" = ".bss", "value" = ""} : () -> ()
  "riscv_ssa.label"() {"label" = #riscv.label<heap>} : () -> ()
  "riscv_ssa.directive"() {"label" = ".space", "value" = "16k"} : () -> ()
  "riscv_ssa.directive"() {"label" = ".data", "value" = ""} : () -> ()
  "riscv_ssa.label"() {"label" = #riscv.label<main.a.data>} : () -> ()
  "riscv_ssa.directive"() {"label" = ".words", "value" = "0x6, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6"} : () -> ()
  "riscv_ssa.label"() {"label" = #riscv.label<main.a.shape>} : () -> ()
  "riscv_ssa.directive"() {"label" = ".words", "value" = "0x2, 0x2, 0x3"} : () -> ()
  "riscv_ssa.label"() {"label" = #riscv.label<main.b.data>} : () -> ()
  "riscv_ssa.directive"() {"label" = ".words", "value" = "0x6, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6"} : () -> ()
  "riscv_ssa.label"() {"label" = #riscv.label<main.b.shape>} : () -> ()
  "riscv_ssa.directive"() {"label" = ".text", "value" = ""} : () -> ()
  "riscv_ssa.label"() {"label" = #riscv.label<