In [1]:
import xdsl

from compiler import (parse_toy, optimise_toy, lower_from_toy, optimise_vir, lower_to_riscv, print_riscv_ssa, print_op)

# Toy-language source for the example program. It is fed through every stage
# of the compiler pipeline below.
program = """
def main() {
  # Define a variable `a` with shape <2, 3>, initialized with the literal value.
  # The shape is inferred from the supplied literal.
  var a = [[1, 2, 3], [4, 5, 6]];

  # b is identical to a, the literal tensor is implicitly reshaped: defining new
  # variables is the way to reshape tensors (element count must match).
  var b<3, 2> = [1, 2, 3, 4, 5, 6];

  # There is a built-in print instruction to display the contents of the tensor
  print(b);

  # Reshapes are implicit on assignment
  var c<2, 3> = b;

  # There are + and * operators for pointwise addition and multiplication
  var d = a + c;

  print(d);
}
"""

# Compiler pipeline: each stage consumes the previous stage's result. Every
# intermediate is kept in its own variable so later cells can inspect any stage.
toy_0 = parse_toy(program)  # parse the source into toy-dialect IR (see the next cell's output)
toy_1 = optimise_toy(toy_0)
vir_0 = lower_from_toy(toy_1)
vir_1 = optimise_vir(vir_0)
risc_0 = lower_to_riscv(vir_1)
# NOTE(review): presumably the textual RISC-V output; not used in the visible cells — confirm.
code = print_riscv_ssa(risc_0)

In [2]:
# Show the toy-dialect IR exactly as produced by the parser (before optimise_toy).
print_op(toy_0)

"builtin.module"() ({
  "toy.func"() ({
    %0 = "toy.constant"() {"value" = dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>} : () -> tensor<2x3xi32>
    %1 = "toy.constant"() {"value" = dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32>} : () -> tensor<6xi32>
    %2 = "toy.reshape"(%1) : (tensor<6xi32>) -> tensor<3x2xi32>
    "toy.print"(%2) : (tensor<3x2xi32>) -> ()
    %3 = "toy.reshape"(%2) : (tensor<3x2xi32>) -> tensor<2x3xi32>
    %4 = "toy.add"(%0, %3) : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
    "toy.print"(%4) : (tensor<2x3xi32>) -> ()
    "toy.return"() : () -> ()
  }) {"sym_name" = "main", "function_type" = () -> ()} : () -> ()
}) : () -> ()


In [3]:
from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Callable, Generic, TypeVar
from xdsl.dialects.builtin import ModuleOp
from xdsl.ir import SSAValue, Operation

import toy.dialect as toy

class EmulationError(Exception):
    """Raised when emulation fails.

    Causes include: no emulator registered for an op type, an op of the
    wrong type, an unresolvable SSA value, or a duplicate SSA binding.
    """

# Type variables for parameterising emulators by the Operation subclass they handle.
# NOTE(review): _OperationCovT is not referenced in the visible cells — confirm it is still needed.
_OperationCovT = TypeVar('_OperationCovT', bound=Operation, covariant=True)
# Invariant variant; used by OperationEmulator and FunctionTable.
_OperationInvT = TypeVar('_OperationInvT', bound=Operation)


class BaseOperationEmulator(ABC):
    """Abstract, type-erased interface for emulating one kind of operation."""

    @classmethod
    @abstractmethod
    def run_generic(cls, emulator: Emulator, op: Operation):
        """Emulate ``op`` inside ``emulator``; concrete subclasses must implement this."""


class OperationEmulator(Generic[_OperationInvT], BaseOperationEmulator, ABC):
    """Emulator for one concrete operation class, ``op_type``.

    Subclasses set ``op_type`` and implement ``run``. ``run`` receives the
    runtime values of the op's operands (in operand order) and returns one
    runtime value per op result. ``run_generic`` handles the plumbing
    between the op's SSA values and ``emulator.context``.
    """

    # The concrete Operation subclass handled by this emulator.
    op_type: type[_OperationInvT]

    @classmethod
    def run_generic(cls, emulator: Emulator, op: Operation):
        """Type-check ``op``, run it, and bind its results in ``emulator.context``.

        Raises:
            EmulationError: if ``op`` is not an instance of ``cls.op_type``,
                if an operand has no value in the context chain, or if ``run``
                returns a different number of values than ``op`` has results.
        """
        if not isinstance(op, cls.op_type):
            raise EmulationError(
                f'Unexpected operation type {type(op)}, expected {cls.op_type} for {cls.__name__}')

        # Resolve each operand's runtime value (the context also searches its parents).
        inputs = tuple(emulator.context[operand] for operand in op.operands)
        results = cls.run(emulator, op, inputs)

        # zip() would silently drop extras or leave results unbound, so check arity first.
        if len(results) != len(op.results):
            raise EmulationError(
                f'{cls.__name__}.run returned {len(results)} values, '
                f'but {type(op).__name__} has {len(op.results)} results')
        for result, value in zip(op.results, results):
            emulator.context[result] = value

    @classmethod
    @abstractmethod
    def run(cls, emulator: Emulator, op: _OperationInvT, args: tuple[Any, ...]) -> tuple[Any, ...]:
        """Compute the op's result values from its operand values ``args``."""
        ...

@dataclass
class EmulationContext:
    """A scope that maps SSA values to their runtime values.

    Lookups that miss in this scope fall back to ``parent``, so nested
    contexts can see values bound in enclosing ones.
    """

    # Human-readable label for this scope; defaults to "unknown".
    name: str = field(default="unknown")
    # Enclosing scope consulted when a lookup misses here; None for the root.
    parent: EmulationContext | None = None
    # Bindings made directly in this scope.
    env: dict[SSAValue, Any] = field(default_factory=dict)

    def __getitem__(self, key: SSAValue) -> Any:
        """Return the value bound to ``key`` here or in the nearest ancestor.

        Raises:
            EmulationError: if no scope in the parent chain binds ``key``.
        """
        # Walk the chain iteratively so deep nesting cannot hit the recursion limit.
        scope: EmulationContext | None = self
        while scope is not None:
            if key in scope.env:
                return scope.env[key]
            scope = scope.parent
        raise EmulationError(f'Could not find value for {key}')

    def __setitem__(self, key: SSAValue, value: Any):
        """Bind ``key`` to ``value`` in this scope.

        A key may be bound at most once per scope. Shadowing a binding from a
        parent scope is allowed.

        Raises:
            EmulationError: if ``key`` is already bound in this scope.
        """
        if key in self.env:
            raise EmulationError(
                f'Attempting to register value {value} for SSAValue {key}'
                ', but a value for that SSAValue already exists')
        self.env[key] = value


@dataclass
class Emulator:
    """Dispatches each operation to the emulator registered for its exact type."""

    # The module being emulated.
    module: ModuleOp
    # Concrete Operation class -> emulator responsible for it.
    _registered_ops: dict[type[Operation], OperationEmulator[Operation]] = field(default_factory=dict)
    # Root scope mapping SSA values to runtime values.
    context: EmulationContext = field(default_factory=lambda: EmulationContext(name='root'))

    def run(self, op: Operation):
        """Run ``op`` using the emulator registered for ``type(op)``.

        Raises:
            EmulationError: if no emulator is registered for ``type(op)``.
        """
        registry = self._registered_ops
        kind = type(op)
        if kind in registry:
            op_emulator = registry[kind]
            type(op_emulator).run_generic(self, op)
            return
        raise EmulationError(f'Could not find OperationEmulator for op {op}')


def emulate_toy(module: ModuleOp):
    """Emulate the toy-dialect ``main`` function of ``module``.

    Scans the top-level operations of the module's first block for a
    ``toy.FuncOp`` whose ``sym_name`` is ``main`` and runs it with a fresh
    ``Emulator``.

    Raises:
        EmulationError: if the module has no ``main`` function, or if the
            emulator has no OperationEmulator registered for it.
    """
    emulator = Emulator(module)

    for op in module.regions[0].blocks[0].ops:
        if isinstance(op, toy.FuncOp) and op.sym_name.data == 'main':
            emulator.run(op)
            return

    # Fail loudly: silently returning would make a module without an entry
    # point look like a successful (empty) emulation.
    raise EmulationError("Could not find a toy.FuncOp named 'main' in module")




In [4]:

@dataclass
class FunctionTable:
    """Registry mapping operation classes to Python functions that emulate them.

    Per the type annotations, registered functions take
    ``(emulator, op, args)`` and return a tuple of result values.
    """

    # Operation class -> emulation function for that class.
    functions: dict[type[Operation], Callable[[Emulator, Operation, tuple[Any, ...]], tuple[Any, ...]]] = field(default_factory=dict)

    def register_op(self, op_type: type[_OperationInvT], func: Callable[[Emulator, _OperationInvT, tuple[Any, ...]], tuple[Any, ...]]):
        """Register ``func`` for ``op_type``, replacing any earlier registration."""
        # The dict is typed by the Operation base class; the per-key narrowing
        # cannot be expressed in its type, hence the ignore.
        self.functions.update({op_type: func})  # type: ignore

    def op_types(self) -> set[type[Operation]]:
        """Return a fresh set of every operation class that has a function registered."""
        return {*self.functions}

    def register(self, op_type: type[_OperationInvT]) -> Callable[[Callable[[Emulator, _OperationInvT, tuple[Any, ...]], tuple[Any, ...]]], Callable[[Emulator, _OperationInvT, tuple[Any, ...]], tuple[Any, ...]]]:
        """Decorator form of ``register_op``; the decorated function is returned unchanged."""
        def decorator(func: Callable[[Emulator, _OperationInvT, tuple[Any, ...]], tuple[Any, ...]]):
            self.register_op(op_type, func)
            return func
        return decorator



# Function table holding the emulation functions for toy-dialect operations.
toy_ft = FunctionTable()


@toy_ft.register(toy.PrintOp)
def run_print(emulator: Emulator, op: toy.PrintOp, args: tuple[Any, ...]) -> tuple[Any, ...]:
    """Emulate ``toy.print``: write the first argument to stdout; produces no results."""
    printed_value = args[0]
    print(printed_value)
    return tuple()


from xdsl.ir import Block
from xdsl.dialects.builtin import i32

# Smoke-test run_print on a standalone toy.print op. Its operand is the
# argument of a fresh i32 block, and the emulator wraps an empty module.
block = Block.from_arg_types([i32])
# Deliberately not named `print_op`: that name is the IR printer imported in the
# first cell, and rebinding it breaks `print_op(toy_0)` on re-run.
demo_print_op = toy.PrintOp.from_input(block.args[0])

module_op = ModuleOp.from_region_or_ops([])

emulator = Emulator(module_op)

# Prints 2; a print op produces no result values.
assert run_print(emulator, demo_print_op, (2, )) == ()



2


()