In [1]:
from __future__ import annotations

from oqd_compiler_infrastructure import TypeReflectBaseModel

In [2]:
class MyProgram(TypeReflectBaseModel):
    """Root node of the toy expression language, wrapping a single expression.

    ``Execution.map_MyProgram`` evaluates a program to the value of ``expr``.
    """

    # The expression this program evaluates. NOTE(review): MyExpr is defined in
    # a later cell; this forward reference relies on the lazy annotations enabled
    # by ``from __future__ import annotations`` — confirm the model resolves it.
    expr: MyExpr

In [3]:
def _lift(value):
    """Wrap a plain Python ``int`` as a ``MyInt`` leaf; return anything else unchanged.

    This lets the operator overloads on ``MyExpr`` accept mixed operands such as
    ``MyInt(value=1) + 2``. Values that are not ints are passed through as-is,
    so the node constructors still validate them exactly as before.
    """
    # MyInt is looked up at call time, so it may be defined after this function.
    if isinstance(value, int):
        return MyInt(value=value)
    return value


class MyExpr(TypeReflectBaseModel):
    """Base class for expression nodes.

    Overloads ``+``, ``*`` and ``**`` so expression trees can be built with
    ordinary Python syntax: each operator returns a new ``MyAdd``, ``MyMul``
    or ``MyPow`` node. Plain ``int`` operands, on either side of the operator,
    are wrapped in ``MyInt`` first.
    """

    def __add__(self, other):
        return MyAdd(left=self, right=_lift(other))

    def __radd__(self, other):
        # Reached for ``other + self`` when ``other`` (e.g. an int) can't handle it;
        # keep ``other`` on the left to preserve operand order.
        return MyAdd(left=_lift(other), right=self)

    def __mul__(self, other):
        return MyMul(left=self, right=_lift(other))

    def __rmul__(self, other):
        return MyMul(left=_lift(other), right=self)

    def __pow__(self, other):
        return MyPow(left=self, right=_lift(other))

    def __rpow__(self, other):
        return MyPow(left=_lift(other), right=self)

In [4]:
class MyInt(MyExpr):
    """Integer literal: the leaf node of an expression tree."""

    # The literal's value; ``Execution.map_MyInt`` returns it unchanged.
    value: int

In [5]:
class MyAdd(MyExpr):
    """Binary addition node, evaluated by ``Execution`` as ``left + right``."""

    # Operands; each is itself an expression node.
    left: MyExpr
    right: MyExpr

In [6]:
class MyMul(MyExpr):
    """Binary multiplication node, evaluated by ``Execution`` as ``left * right``."""

    # Operands; each is itself an expression node.
    left: MyExpr
    right: MyExpr

In [7]:
class MyPow(MyExpr):
    """Exponentiation node, evaluated by ``Execution`` as ``left ** right``.

    ``Execution`` raises ``ValueError`` when the evaluated exponent is negative.
    """

    # Base (``left``) and exponent (``right``); each is itself an expression node.
    left: MyExpr
    right: MyExpr

In [8]:
# ``+`` on MyExpr nodes builds a MyAdd tree instead of adding numbers.
prog = MyInt(value=1) + MyInt(value=2)

# Last expression of the cell, so the notebook displays the tree's repr.
prog

MyAdd(class_='MyAdd', left=MyInt(class_='MyInt', value=1), right=MyInt(class_='MyInt', value=2))

In [9]:
from oqd_compiler_infrastructure import (
    Chain,
    ConversionRule,
    FixedPoint,
    Post,
    PrettyPrint,
    RewriteRule,
)

In [10]:
class Associativity(RewriteRule):
    """Re-bracket right-nested sums and products to the left.

    ``a + (b + c)`` is rewritten to ``(a + b) + c`` and ``a * (b * c)`` to
    ``(a * b) * c``. Nodes that do not match the pattern yield ``None``.
    """

    @staticmethod
    def _rotate_left(model, node_type):
        # Only a right child of the same operator type can be rotated.
        inner = model.right
        if not isinstance(inner, node_type):
            return None
        hoisted = node_type(left=model.left, right=inner.left)
        return node_type(left=hoisted, right=inner.right)

    def map_MyAdd(self, model):
        return self._rotate_left(model, MyAdd)

    def map_MyMul(self, model):
        return self._rotate_left(model, MyMul)


class Distribution(RewriteRule):
    """Distribute multiplication over addition.

    ``(a + b) * c`` becomes ``a*c + b*c`` and ``a * (b + c)`` becomes
    ``a*b + a*c``. Products with no sum operand yield ``None``.
    """

    def map_MyMul(self, model):
        lhs, rhs = model.left, model.right

        # A sum on the left is distributed first; a sum on the right is only
        # handled here when the left operand is not a sum. Repeated application
        # (FixedPoint in the canonicalization pass) picks up whatever remains.
        if isinstance(lhs, MyAdd):
            first = MyMul(left=lhs.left, right=rhs)
            second = MyMul(left=lhs.right, right=rhs)
            return MyAdd(left=first, right=second)

        if isinstance(rhs, MyAdd):
            first = MyMul(left=lhs, right=rhs.left)
            second = MyMul(left=lhs, right=rhs.right)
            return MyAdd(left=first, right=second)

        return None


# Canonicalize in two phases, each run to a fixed point:
#   1. Distribution pushes every ``*`` below ``+`` (x * (y + z) -> x*y + x*z).
#   2. Associativity then re-nests all ``+`` and ``*`` chains to the left.
# Distribution must run first: distributing creates new sums that may be
# right-nested (e.g. ``w + x * (y + z)`` becomes ``w + (x*y + x*z)``). An
# Associativity pass that ran earlier would never see them. Associativity
# only re-brackets sums and products. It therefore cannot re-introduce a
# ``*`` directly over a ``+``, so the distributed form survives phase 2.
canonicalization_pass = Chain(
    FixedPoint(Post(Distribution())),
    FixedPoint(Post(Associativity())),
)

In [11]:
class Execution(ConversionRule):
    """Evaluate an expression tree to a Python ``int``.

    Every ``map_*`` method receives the node (``model``) and ``operands``,
    which, as used below, maps a child field name (``"left"``, ``"right"``,
    ``"expr"``) to that child's already-evaluated value.
    """

    def map_MyInt(self, model, operands):
        # Leaves evaluate to their literal value.
        return model.value

    def map_MyAdd(self, model, operands):
        lhs, rhs = operands["left"], operands["right"]
        return lhs + rhs

    def map_MyMul(self, model, operands):
        lhs, rhs = operands["left"], operands["right"]
        return lhs * rhs

    def map_MyPow(self, model, operands):
        base, exponent = operands["left"], operands["right"]
        # int ** negative int would yield a float, so negative exponents are rejected.
        if exponent < 0:
            raise ValueError("Negative exponents are not supported")
        return base**exponent

    def map_MyProgram(self, model, operands):
        # A program evaluates to the value of its single expression.
        return operands["expr"]


# Evaluation pass. ``Post`` presumably applies Execution children-first
# (post-order), which Execution relies on by reading evaluated ``operands``.
execution_pass = Post(Execution())

In [12]:
# Full interpreter: canonicalize the tree, then evaluate it (Chain is assumed
# to apply its passes in the order given).
interpreter = Chain(canonicalization_pass, execution_pass)

In [13]:
# Pretty-printing pass, used here to render the interpreter's result.
printer = Post(PrettyPrint())

# Python precedence binds ``**`` tighter than ``*``, and ``*`` tighter than ``+``.
# So this builds MyAdd(1, MyMul(2, MyPow(3, 4))), i.e. 1 + 2 * 3**4 == 163.
program = MyProgram(
    expr=MyInt(value=1) + MyInt(value=2) * MyInt(value=3) ** MyInt(value=4)
)

# Canonicalize and evaluate the program.
result = interpreter(program)

print(printer(result))

int(163)
