In [1]:
from __future__ import annotations

from oqd_compiler_infrastructure import TypeReflectBaseModel

In [2]:
class MyProgram(TypeReflectBaseModel):
    """Root node of the toy language: a program wraps exactly one expression."""

    # The expression that makes up the program. MyExpr is declared further
    # down, so this annotation is a forward reference.
    expr: MyExpr

In [None]:
class MyExpr(TypeReflectBaseModel):
    """Base class shared by every expression node.

    The operators below are overloaded so that combining expressions builds a
    new syntax-tree node instead of computing a value.
    """

    def __add__(self, other):
        """Return the ``MyAdd`` node representing ``self + other``."""
        return MyAdd(left=self, right=other)

    def __mul__(self, other):
        """Return the ``MyMul`` node representing ``self * other``."""
        return MyMul(left=self, right=other)

    def __pow__(self, other):
        """Return the ``MyPow`` node representing ``self ** other``."""
        return MyPow(left=self, right=other)

    def __eq__(self, other):
        """Return the ``MyEq`` node representing ``self == other``.

        Anything that is not a ``MyExpr`` is declined with ``NotImplemented``
        so Python can fall back to its default comparison handling.

        NOTE(review): with this overload, ``a == b`` on two expressions yields
        a ``MyEq`` node rather than a bool. Confirm that nothing (including
        library passes) relies on value equality of expression nodes.
        """
        if not isinstance(other, MyExpr):
            return NotImplemented
        return MyEq(left=self, right=other)

In [4]:
class MyInt(MyExpr):
    """Integer literal expression."""

    # The literal's value.
    value: int

class MyBool(MyExpr):
    """Boolean literal expression."""

    # The literal's value.
    value: bool

In [5]:
class MyAdd(MyExpr):
    """Addition node (``left + right``); this is what ``MyExpr.__add__`` builds."""

    # Left-hand operand.
    left: MyExpr
    # Right-hand operand.
    right: MyExpr

In [6]:
class MyMul(MyExpr):
    """Multiplication node (``left * right``); this is what ``MyExpr.__mul__`` builds."""

    # Left-hand factor.
    left: MyExpr
    # Right-hand factor.
    right: MyExpr

In [7]:
class MyPow(MyExpr):
    """Exponentiation node (``left ** right``); this is what ``MyExpr.__pow__`` builds."""

    # Base of the power.
    left: MyExpr
    # Exponent of the power.
    right: MyExpr


class MyVar(MyExpr):
    """Reference to a variable by name."""

    # Name of the variable being read.
    name: str


class MyAssign(MyExpr):
    """Assignment node; being a ``MyExpr`` it can itself appear inside expressions."""

    # Name of the variable being assigned.
    name: str
    # Expression whose value is bound to ``name``.
    value: MyExpr


class MyEq(MyExpr):
    """Equality-comparison node (``left == right``); this is what ``MyExpr.__eq__`` builds."""

    # Left-hand side of the comparison.
    left: MyExpr
    # Right-hand side of the comparison.
    right: MyExpr

In [8]:
# `+` on two expressions goes through MyExpr.__add__, so this builds a MyAdd
# node instead of computing 3.
prog = MyInt(value=1) + MyInt(value=2)

# Bare last expression so the notebook displays the node.
prog

MyAdd(class_='MyAdd', left=MyInt(class_='MyInt', value=1), right=MyInt(class_='MyInt', value=2))

In [9]:
from oqd_compiler_infrastructure import (
    Chain,
    ConversionRule,
    FixedPoint,
    Post,
    PrettyPrint,
    RewriteRule,
)

In [10]:
class Associativity(RewriteRule):
    """Regroup right-nested sums and products into left-nested form.

    ``a + (b + c)`` becomes ``(a + b) + c``, and likewise for ``*``. A handler
    returns ``None`` when the node does not have that shape (presumably the
    rewrite machinery then keeps the node as-is; confirm against the library).
    """

    def map_MyAdd(self, model):
        """Rewrite ``a + (b + c)`` as ``(a + b) + c``; otherwise return ``None``."""
        nested = model.right
        if not isinstance(nested, MyAdd):
            return None

        regrouped = MyAdd(left=model.left, right=nested.left)
        return MyAdd(left=regrouped, right=nested.right)

    def map_MyMul(self, model):
        """Rewrite ``a * (b * c)`` as ``(a * b) * c``; otherwise return ``None``."""
        nested = model.right
        if not isinstance(nested, MyMul):
            return None

        regrouped = MyMul(left=model.left, right=nested.left)
        return MyMul(left=regrouped, right=nested.right)


class Distribution(RewriteRule):
    """Distribute multiplication over addition."""

    def map_MyMul(self, model):
        """Expand a product in which one factor is a sum.

        The left factor is examined first. If it is a ``MyAdd`` the product is
        expanded over it and the right factor is not inspected in this step.
        ``None`` is returned when neither factor is a ``MyAdd``.
        """
        lhs, rhs = model.left, model.right

        if isinstance(lhs, MyAdd):
            # (a + b) * c  ->  a * c + b * c
            first = MyMul(left=lhs.left, right=rhs)
            second = MyMul(left=lhs.right, right=rhs)
            return MyAdd(left=first, right=second)

        if isinstance(rhs, MyAdd):
            # a * (b + c)  ->  a * b + a * c
            first = MyMul(left=lhs, right=rhs.left)
            second = MyMul(left=lhs, right=rhs.right)
            return MyAdd(left=first, right=second)

        return None


# Two-stage canonicalization: first Associativity, then Distribution, each
# wrapped as FixedPoint(Post(rule)). Going by the library's naming, Post walks
# the tree bottom-up and FixedPoint repeats until nothing changes (TODO confirm
# against the oqd_compiler_infrastructure docs).
# NOTE(review): Distribution can produce a right-nested MyAdd (for example from
# x + y * (a + b)) after the Associativity stage has already finished, and no
# later stage regroups it. Confirm whether that is acceptable.
canonicalization_pass = Chain(
    *(
        FixedPoint(Post(rule_cls()))
        for rule_cls in (Associativity, Distribution)
    )
)

In [11]:
class Execution(ConversionRule):
    """Evaluate an expression tree down to plain Python values.

    Each ``map_<NodeType>`` handler receives the node (``model``) and
    ``operands``, which the handlers index by field name to obtain the
    already-evaluated children.
    """

    def __init__(self):
        super().__init__()
        # Variable table: written by MyAssign, read by MyVar. It lives as long
        # as this rule instance, so assignments persist across invocations.
        self.variables = dict()

    def map_MyInt(self, model, operands):
        """An integer literal evaluates to its stored value."""
        return model.value

    def map_MyBool(self, model, operands):
        """A boolean literal evaluates to its stored value."""
        return model.value

    def map_MyAdd(self, model, operands):
        """Add the evaluated operands."""
        lhs, rhs = operands["left"], operands["right"]
        return lhs + rhs

    def map_MyMul(self, model, operands):
        """Multiply the evaluated operands."""
        lhs, rhs = operands["left"], operands["right"]
        return lhs * rhs

    def map_MyPow(self, model, operands):
        """Raise the evaluated base to the evaluated exponent.

        Raises:
            ValueError: if the exponent is negative.
        """
        exponent = operands["right"]
        if exponent < 0:
            raise ValueError("Negative exponents are not supported")

        return operands["left"] ** exponent

    def map_MyVar(self, model, operands):
        """Look up a variable in the variable table.

        Raises:
            ValueError: if the variable has not been assigned.
        """
        name = model.name
        if name in self.variables:
            return self.variables[name]
        raise ValueError(f"Variable '{name}' is not defined")

    def map_MyAssign(self, model, operands):
        """Store the evaluated value under the variable's name and return it."""
        assigned = operands["value"]
        self.variables[model.name] = assigned
        return assigned

    def map_MyEq(self, model, operands):
        """Compare the evaluated operands with Python's ``==``."""
        lhs, rhs = operands["left"], operands["right"]
        return lhs == rhs

    def map_MyProgram(self, model, operands):
        """A program evaluates to the value of its single expression."""
        return operands["expr"]


# A single Execution instance is created here, so its variable table is shared
# by every program run through this pass.
execution_pass = Post(Execution())

In [12]:
# Full pipeline: canonicalize the tree, then evaluate it.
interpreter = Chain(canonicalization_pass, execution_pass)

In [13]:
# Pretty-printer used to render results in this and the following cells.
printer = Post(PrettyPrint())

# Python's operator precedence shapes the tree: ** binds tighter than *, which
# binds tighter than +, so this is 1 + (2 * (3 ** 4)).
program = MyProgram(
    expr=MyInt(value=1) + MyInt(value=2) * MyInt(value=3) ** MyInt(value=4)
)

# Canonicalize and evaluate.
result = interpreter(program)

print(printer(result))

int(163)


In [14]:
# Example: Using bool type and equality operator
# `==` on two expressions goes through MyExpr.__eq__ and builds a MyEq node;
# the comparison itself only happens when the interpreter evaluates the tree.
prog2 = MyProgram(expr=MyInt(value=5) == MyInt(value=5))

result2 = interpreter(prog2)
print(printer(result2))

bool(True)


In [15]:
# Example: Using assignment operator
# An assignment is itself an expression: evaluating it stores the value and
# also returns it.
prog3 = MyProgram(
    expr=MyAssign(
        name="a",
        value=MyInt(value=5)
    )
)

# Runs on the shared `interpreter`, so `a` stays defined in its Execution
# instance after this cell.
result3 = interpreter(prog3)
print(printer(result3))

int(5)


In [16]:
# Example: Assigning variables and checking equality
# Both sides of the comparison are assignments; each evaluates to the value it
# stores, so the comparison sees 5 and 10.
prog4 = MyProgram(
    expr=(
        MyAssign(name="a", value=MyInt(value=5))
        == MyAssign(name="b", value=MyInt(value=10))
    )
)

# A fresh Execution instance starts with an empty variable table.
interpreter2 = Chain(canonicalization_pass, Post(Execution()))
result4 = interpreter2(prog4)
print(f"Equality check (5 == 10): {printer(result4)}")

Equality check (5 == 10): bool(False)


In [17]:
# Example: Using bool literals
# A program consisting of a single boolean literal evaluates to that literal.
prog5 = MyProgram(
    expr=MyBool(value=True)
)

result5 = interpreter(prog5)
print(f"Bool literal: {printer(result5)}")

Bool literal: bool(True)


In [None]:
# Example: Assigning variables and using them in expressions
# One Execution instance backs every run below, so variables assigned by an
# earlier program are still defined when a later program reads them.
interpreter3 = Chain(canonicalization_pass, Post(Execution()))

# Assign a=5
prog6 = MyProgram(expr=MyAssign(name="a", value=MyInt(value=5)))
result6 = interpreter3(prog6)
print(f"Assigned a=5, result: {printer(result6)}")

# Assign b=10
prog7 = MyProgram(expr=MyAssign(name="b", value=MyInt(value=10)))
result7 = interpreter3(prog7)
print(f"Assigned b=10, result: {printer(result7)}")

# Read both variables back and add them
prog8 = MyProgram(expr=MyVar(name="a") + MyVar(name="b"))
result8 = interpreter3(prog8)
print(f"a + b = {printer(result8)}")

# Check equality (`==` on two expressions builds a MyEq node via MyExpr.__eq__)
prog9 = MyProgram(expr=MyVar(name="a") == MyVar(name="b"))
result9 = interpreter3(prog9)
print(f"a == b: {printer(result9)}")

Assigned a=5, result: int(5)
Assigned b=10, result: int(10)
a + b = int(15)
a == b: bool(False)
