In [1]:
import xdsl

from compiler import (parse_toy, optimise_toy, print_op)

program = """

def multiply_transpose(a, b) {
  return transpose(a) * transpose(b);
}

def main() {
  var a = [[1, 2, 3], [4, 5, 6]];
  var b<3, 2> = [1, 2, 3, 4, 5, 6];
  print(b);
  var c<2, 3> = b;
  var d = multiply_transpose(a, c);
  print(d);
}
"""

toy_0 = parse_toy(program)
toy_1 = optimise_toy(toy_0)


In [2]:
print_op(toy_0)

"builtin.module"() ({
  "toy.func"() ({
  ^0(%0 : tensor<*xi32>, %1 : tensor<*xi32>):
    %2 = "toy.transpose"(%0) : (tensor<*xi32>) -> tensor<*xi32>
    %3 = "toy.transpose"(%1) : (tensor<*xi32>) -> tensor<*xi32>
    %4 = "toy.mul"(%2, %3) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>
    "toy.return"(%4) : (tensor<*xi32>) -> ()
  }) {"sym_name" = "multiply_transpose", "function_type" = (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>, "sym_visibility" = "private"} : () -> ()
  "toy.func"() ({
    %5 = "toy.constant"() {"value" = dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>} : () -> tensor<2x3xi32>
    %6 = "toy.constant"() {"value" = dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32>} : () -> tensor<6xi32>
    %7 = "toy.reshape"(%6) : (tensor<6xi32>) -> tensor<3x2xi32>
    "toy.print"(%7) : (tensor<3x2xi32>) -> ()
    %8 = "toy.reshape"(%7) : (tensor<3x2xi32>) -> tensor<2x3xi32>
    %9 = "toy.generic_call"(%5, %8) {"callee" = @multiply_transpose} : (tensor<2x3xi32>, tensor<2x3xi32>) -> 

In [3]:
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Callable, Generator, TypeVar
from xdsl.dialects.builtin import ModuleOp
from xdsl.ir import SSAValue, Operation

import toy.dialect as toy

_OperationInvT = TypeVar('_OperationInvT', bound=Operation)

class EmulationError(Exception): pass

@dataclass
class FunctionTable:
    functions: dict[type[Operation], Callable[[Emulator, Operation, tuple[Any, ...]], tuple[Any, ...]]] = field(default_factory=dict)

    def register_op(self, op_type: type[_OperationInvT], func: Callable[[Emulator, _OperationInvT, tuple[Any, ...]], tuple[Any, ...]]):
        self.functions[op_type] = func # type: ignore

    def op_types(self) -> set[type[Operation]]:
        return set(self.functions.keys())

    def register(self, op_type: type[_OperationInvT]) -> Callable[[Callable[[Emulator, _OperationInvT, tuple[Any, ...]], tuple[Any, ...]]], Callable[[Emulator, _OperationInvT, tuple[Any, ...]], tuple[Any, ...]]]:
        def wrapper(func: Callable[[Emulator, _OperationInvT, tuple[Any, ...]], tuple[Any, ...]]):
            self.register_op(op_type, func)
            return func
        return wrapper

    def run(self, emulator: Emulator, op: Operation, args: tuple[Any, ...]) -> tuple[Any, ...]:
        return self.functions[type(op)](emulator, op, args)

    def register_from(self, other: FunctionTable):
        '''If there are duplicate definitions, the `other` will override `self`'''
        self.functions.update(other.functions)


@dataclass
class EmulationContext:
    name: str = field(default="unknown")
    parent: EmulationContext | None = None
    env: dict[SSAValue, Any] = field(default_factory=dict)
    
    def __getitem__(self, key: SSAValue) -> SSAValue:
        if key in self.env:
            return self.env[key]
        if self.parent is not None:
            return self.parent[key]
        raise EmulationError(f'Could not find value for {key}')

    def __setitem__(self, key: SSAValue, value: Any):
        if key in self.env:
            raise EmulationError(
                f'Attempting to register SSAValue {value} for name {key}'
                ', but value with that name already exists')
        self.env[key] = value

    def stack(self) -> Generator[EmulationContext, None, None]:
        if self.parent is not None:
            yield from self.parent.stack()
        yield self

    def stack_dscription(self) -> str:
        return '/'.join(c.name for c in self.stack())
        


@dataclass
class Emulator:

    module: ModuleOp
    _function_table: FunctionTable = field(default_factory=FunctionTable)
    _context: EmulationContext = field(default_factory=lambda: EmulationContext(name='root'))

    def get_values(self, values: tuple[SSAValue, ...]) -> tuple[Any, ...]:
        return tuple(self._context[value] for value in values)

    def set_values(self, ssa_values: tuple[SSAValue, ...], result_values: tuple[Any, ...]):
        self._assert(len(ssa_values) == len(result_values), f'{[f"{ssa_value}" for ssa_value in ssa_values]}, {result_values}')
        for ssa_value, result_value in zip(ssa_values, result_values):
            self._context[ssa_value] = result_value

    
    def push_context(self, name: str='child') -> None:
        self._context = EmulationContext(name, self._context)

    
    def pop_context(self) -> None:
        if self._context.parent is None:
            raise EmulationError('Attempting to pop root env')
        
        self._context = self._context.parent


    def register_functions(self, funcs: FunctionTable) -> None:
        self._function_table.register_from(funcs)

    def run(self, op: Operation):
        op_type = type(op)
        if op_type not in self._function_table.functions:
            raise EmulationError(f'Could not find OperationEmulator for op {op.name}')
        
        inputs = self.get_values(op.operands)
        results = self._function_table.run(self, op, inputs)
        self.set_values(tuple(op.results), results)
    
    def _assert(self, condition: bool, message: str | None = None):
        assert condition, f'({self._context.stack_dscription()})({message})'



In [4]:
from xdsl.dialects.builtin import TensorType, VectorType

def run_toy_func(emulator: Emulator, name: str, args: tuple[Any, ...]) -> tuple[Any, ...]:
    for op in emulator.module.regions[0].blocks[0].ops:
        if isinstance(op, toy.FuncOp) and op.sym_name.data == name:
            return run_func(emulator, op, args)
    
    raise EmulationError(f'Could not find toy function with name: {name}')


toy_ft = FunctionTable()

@dataclass
class Tensor:
    data: list[int]
    shape: list[int]

@toy_ft.register(toy.PrintOp)
def run_print(emulator: Emulator, op: toy.PrintOp, args: tuple[Any, ...]) -> tuple[Any, ...]:
    print(args[0])
    return ()


@toy_ft.register(toy.FuncOp)
def run_func(emulator: Emulator, op: toy.FuncOp, args: tuple[Any, ...]) -> tuple[Any, ...]:
    emulator.push_context(f'ctx_{op.sym_name.data}')
    block = op.body.blocks[0]
    emulator.set_values(block.args, args)
    for body_op in block.ops:
        emulator.run(body_op)
    assert isinstance(block.ops[-1], toy.ReturnOp)
    results = emulator.get_values(tuple(block.ops[-1].operands))
    emulator.pop_context()
    return results

@toy_ft.register(toy.ConstantOp)
def run_const(emulator: Emulator, op: toy.ConstantOp, args: tuple[Any, ...]) -> tuple[Any, ...]:
    assert not len(args)
    data = op.get_data()
    shape = op.get_shape()
    result = Tensor(data, shape)
    return result, 

@toy_ft.register(toy.ReshapeOp)
def run_reshape(emulator: Emulator, op: toy.ReshapeOp, args: tuple[Any, ...]) -> tuple[Any, ...]:
    arg, = args
    assert isinstance(arg, Tensor)
    result_typ = op.results[0].typ
    assert isinstance(result_typ, VectorType | TensorType)
    new_shape = result_typ.get_shape()

    return Tensor(arg.data, new_shape), 

@toy_ft.register(toy.AddOp)
def run_add(emulator: Emulator, op: toy.AddOp, args: tuple[Any, ...]) -> tuple[Any, ...]:
    lhs, rhs = args
    assert isinstance(lhs, Tensor)
    assert isinstance(rhs, Tensor)
    assert lhs.shape == rhs.shape

    return Tensor([l + r for l, r in zip(lhs.data, rhs.data)], lhs.shape),

@toy_ft.register(toy.MulOp)
def run_mul(emulator: Emulator, op: toy.MulOp, args: tuple[Any, ...]) -> tuple[Any, ...]:
    lhs, rhs = args
    assert isinstance(lhs, Tensor)
    assert isinstance(rhs, Tensor)
    assert lhs.shape == rhs.shape

    return Tensor([l * r for l, r in zip(lhs.data, rhs.data)], lhs.shape),

@toy_ft.register(toy.ReturnOp)
def run_return(emulator: Emulator, op: toy.ReturnOp, args: tuple[Any, ...]) -> tuple[Any, ...]:
    assert len(args) < 2
    return ()

@toy_ft.register(toy.GenericCallOp)
def run_generic_call(emulator: Emulator, op: toy.GenericCallOp, args: tuple[Any, ...]) -> tuple[Any, ...]:
    return run_toy_func(emulator, op.callee.data.data, args)
    

@toy_ft.register(toy.TransposeOp)
def run_transpose(emulator: Emulator, op: toy.TransposeOp, args: tuple[Any, ...]) -> tuple[Any, ...]:
    arg, = args
    assert isinstance(arg, Tensor)
    assert len(arg.shape) == 2

    cols = arg.shape[0]
    rows = arg.shape[1]

    new_data = [
        arg.data[row * cols + col]
        for col in range(cols)
        for row in range(rows)
    ]

    result = Tensor(new_data, arg.shape[::-1])

    return result,



def emulate_toy(module: ModuleOp):
    emulator = Emulator(module)
    emulator.register_functions(toy_ft)
    run_toy_func(emulator, 'main', ())


print()

emulate_toy(toy_0)

print()

# emulate_toy(toy_1)

# print()



Tensor(data=[1, 2, 3, 4, 5, 6], shape=[3, 2])
Tensor(data=[1, 9, 25, 4, 16, 36], shape=[3, 2])

