In [1]:
from __future__ import annotations

from xdsl.dialects.builtin import ModuleOp

from toy_to_riscv.helpers import (parse_toy, print_module, optimise_toy, lower_from_toy, 
                                  optimise_vir, lower_to_riscv, emulate_riscv)

from riscv.emulator_iop import run_riscv, print_riscv_ssa

import toy.dialect as td
import riscv.riscv_ssa as rd
import toy_to_riscv.dialect as trd

### WIP: Toy example — reshape `b<6>` into `c<2, 3>` and add it element-wise to `a`

# Toy program used throughout this notebook: `b` is declared as a flat <6>
# vector and re-bound as `c<2, 3>`, i.e. a reshape to the same 2x3 shape as
# `a`, followed by the element-wise addition `a + c`.
# One literal per Toy source line, so every line break is explicit; the
# leading "\n" and the trailing "}\n" match the original triple-quoted text.
example = (
    "\n"
    "def main() {\n"
    "  var a<2, 3> = [[1, 2, 3], [4, 5, 6]];\n"
    "  var b<6> = [1, 2, 3, 4, 5, 6];\n"
    "  var c<2, 3> = b;\n"
    "  var d = a + c;\n"
    "  print(d);\n"
    "}\n"
)

### WIP: parse the Toy source and run the Toy-level optimisations

# Front end: parse the Toy source text, then run the Toy-level optimiser.
# In the dump that follows, both operands of `toy.add` are 2x3 constants.
toy_0 = parse_toy(example)
toy_1: ModuleOp = optimise_toy(toy_0)

# Show the optimised module, then an empty separator line.
print_module(toy_1)
print("")

"builtin.module"() ({
  "toy.func"() ({
    %0 = "toy.constant"() {"value" = dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>} : () -> tensor<2x3xi32>
    %1 = "toy.constant"() {"value" = dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>} : () -> tensor<2x3xi32>
    %2 = "toy.add"(%0, %1) : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
    "toy.print"(%2) : (tensor<2x3xi32>) -> ()
    "toy.return"() : () -> ()
  }) {"sym_name" = "main", "function_type" = () -> ()} : () -> ()
}) : () -> ()



In [2]:
# Lower the optimised Toy module to the intermediate `riscv.toy.*` form and
# dump it, then dump it again after the VIR-level optimisations.
# Each dump happens before the next pass runs, in case a pass rewrites its
# input module in place.
vir_0: ModuleOp = lower_from_toy(toy_1)
print_module(vir_0)
print("")

vir_1: ModuleOp = optimise_vir(vir_0)
print_module(vir_1)
print("")

"builtin.module"() ({
  "toy.func"() ({
    %0 = "riscv.toy.vector_constant"() {"data" = [#int<2>, #int<3>], "label" = "tensor_shape"} : () -> #riscv_ssa.reg
    %1 = "riscv.toy.vector_constant"() {"data" = [#int<1>, #int<2>, #int<3>, #int<4>, #int<5>, #int<6>], "label" = "tensor_data"} : () -> #riscv_ssa.reg
    %2 = "riscv.toy.tensor.make"(%0, %1) : (#riscv_ssa.reg, #riscv_ssa.reg) -> #riscv_ssa.reg
    %3 = "riscv.toy.vector_constant"() {"data" = [#int<2>, #int<3>], "label" = "tensor_shape"} : () -> #riscv_ssa.reg
    %4 = "riscv.toy.vector_constant"() {"data" = [#int<1>, #int<2>, #int<3>, #int<4>, #int<5>, #int<6>], "label" = "tensor_data"} : () -> #riscv_ssa.reg
    %5 = "riscv.toy.tensor.make"(%3, %4) : (#riscv_ssa.reg, #riscv_ssa.reg) -> #riscv_ssa.reg
    %6 = "riscv.toy.tensor.shape"(%2) : (#riscv_ssa.reg) -> #riscv_ssa.reg
    %7 = "riscv.toy.tensor.data"(%2) : (#riscv_ssa.reg) -> #riscv_ssa.reg
    %8 = "riscv.toy.tensor.data"(%5) : (#riscv_ssa.reg) -> #riscv_ssa.reg
    %9 

In [3]:
# Final lowering to RISC-V SSA: the resulting module holds `riscv.section`
# regions (e.g. `.bss` for the heap, data for the tensor constants) plus code.
riscv_0: ModuleOp = lower_to_riscv(vir_1)
print_module(riscv_0)
print("")

"builtin.module"() ({
  "riscv.section"() ({
    "riscv_ssa.label"() {"label" = #riscv.label<heap>} : () -> ()
    "riscv_ssa.directive"() {"directive" = ".space", "value" = "1024"} : () -> ()
  }) {"directive" = ".bss"} : () -> ()
  "riscv.section"() ({
    "riscv_ssa.label"() {"label" = #riscv.label<main.tensor_shape.0>} : () -> ()
    "riscv_ssa.directive"() {"directive" = ".word", "value" = "0x2, 0x2, 0x3"} : () -> ()
    "riscv_ssa.label"() {"label" = #riscv.label<main.tensor_data.0>} : () -> ()
    "riscv_ssa.directive"() {"directive" = ".word", "value" = "0x6, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6"} : () -> ()
    "riscv_ssa.label"() {"label" = #riscv.label<main.tensor_shape.1>} : () -> ()
    "riscv_ssa.directive"() {"directive" = ".word", "value" = "0x2, 0x2, 0x3"} : () -> ()
    "riscv_ssa.label"() {"label" = #riscv.label<main.tensor_data.1>} : () -> ()
    "riscv_ssa.directive"() {"directive" = ".word", "value" = "0x6, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6"} : () -> ()
  }) {"directive" = ".d

In [4]:
# Render the RISC-V SSA module as assembly text and display it, followed by
# an empty separator line. `code` is reused by the emulator cell below.
# NOTE(review): in the generated listing, `addi %14, %4, 4` is commented
# "lhs storage" although %4 is `main.tensor_data.1`, the data of the second
# (rhs) operand; the comment text presumably comes from the lowering — confirm
# and fix it there.
code = print_riscv_ssa(riscv_0)
print(code)
print("")

.bss 
heap:
.space 1024
.data 
main.tensor_shape.0:
.word 0x2, 0x2, 0x3
main.tensor_data.0:
.word 0x6, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6
main.tensor_shape.1:
.word 0x2, 0x2, 0x3
main.tensor_data.1:
.word 0x6, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6
.text 
main:
	li	%0, heap
	li	%1, main.tensor_shape.0
	li	%2, main.tensor_data.0
	li	%3, main.tensor_shape.1
	li	%4, main.tensor_data.1
	lw	%5, %2, 0		# Get input count
	addi	%6, %5, 4		# Input storage int32 count
	li	%7, 4
	mul	%8, %6, %7		# Alloc count bytes
	lw	%9, %0, 0		# Old heap count
	add	%10, %9, %8		# New heap count
	sw	%10, %0, 0		# Update heap
	addi	%11, %0, 4		# Heap storage start
	add	%12, %11, %9		# Allocated memory
	sw	%5, %12, 0		# Set result count
	addi	%13, %2, 4		# lhs storage
	addi	%14, %4, 4		# lhs storage
	addi	%15, %12, 4		# destination storage
	toy.buffer.add	%5, %13, %15
	toy.buffer.add	%5, %14, %15
	li	%16, 2
	li	%17, 4
	mul	%18, %16, %17		# Alloc count bytes
	lw	%19, %0, 0		# Old heap count
	add	%20, %19, %18		# New heap count
	s

In [5]:
# Execute the generated assembly text (`code`) in the RISC-V emulator; it logs
# every instruction it runs ("Running 0x...: ..."), as in the output below.
# Left as the cell's bare last expression so Jupyter's display is unchanged.
emulate_riscv(code)

[34m[1m[CPU] Started running from example.asm:.text at heap (0x100) + 0x450[0m
Program(name=example.asm,sections=set(),base=['.bss', '.data', '.text'])
[34m[1m   Running 0x00000550:[0m li %0, heap
[34m[1m   Running 0x00000554:[0m li %1, main.tensor_shape.0
[34m[1m   Running 0x00000558:[0m li %2, main.tensor_data.0
[34m[1m   Running 0x0000055C:[0m li %3, main.tensor_shape.1
[34m[1m   Running 0x00000560:[0m li %4, main.tensor_data.1
[34m[1m   Running 0x00000564:[0m lw %5, %2, 0
[34m[1m   Running 0x00000568:[0m addi %6, %5, 4
[34m[1m   Running 0x0000056C:[0m li %7, 4
[34m[1m   Running 0x00000570:[0m mul %8, %6, %7
[34m[1m   Running 0x00000574:[0m lw %9, %0, 0
[34m[1m   Running 0x00000578:[0m add %10, %9, %8
[34m[1m   Running 0x0000057C:[0m sw %10, %0, 0
[34m[1m   Running 0x00000580:[0m addi %11, %0, 4
[34m[1m   Running 0x00000584:[0m add %12, %11, %9
[34m[1m   Running 0x00000588:[0m sw %5, %12, 0
[34m[1m   Running 0x0000058C:[0m addi %13,