In [1]:
from typing import List

import oqd_compiler_infrastructure as ci

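## Language Definition

A toy arithmetic language: integer literals plus n-ary addition and n-ary multiplication, each node type a subclass of `MyMath`.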


In [2]:
class MyMath(ci.TypeReflectBaseModel):
    """Common base class for every node of the toy arithmetic language.

    Subclasses declare their fields as annotated class attributes. Node
    reprs also carry a ``class_`` field naming the concrete node type; none
    of the classes here declare it, so it comes from
    ``ci.TypeReflectBaseModel``.
    """

    pass


class MyInteger(MyMath):
    """Integer literal; the leaves of an expression tree."""

    value: int


class MyAdd(MyMath):
    """N-ary sum of ``operands``."""

    # Summands; any MyMath node, so sums and products can nest.
    operands: List[MyMath]


class MyMul(MyMath):
    """N-ary product of ``operands``."""

    # Factors; any MyMath node, so sums and products can nest.
    operands: List[MyMath]

In [3]:
# Example tree for (1 + 2 + 3) * 4 * (5 * 6), assembled from named subtrees.
sum_term = MyAdd(operands=[MyInteger(value=v) for v in (1, 2, 3)])
product_term = MyMul(operands=[MyInteger(value=v) for v in (5, 6)])
expr = MyMul(operands=[sum_term, MyInteger(value=4), product_term])


# Wrap the library's PrettyPrint rule in a post-order walk over the tree.
printer = ci.Post(ci.PrettyPrint())

print(printer(expr))

MyMul
  - operands: list
    - 0: MyAdd
      - operands: list
        - 0: MyInteger
          - value: int(1)
        - 1: MyInteger
          - value: int(2)
        - 2: MyInteger
          - value: int(3)
    - 1: MyInteger
      - value: int(4)
    - 2: MyMul
      - operands: list
        - 0: MyInteger
          - value: int(5)
        - 1: MyInteger
          - value: int(6)


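## Compiler Infrastructure

`oqd_compiler_infrastructure` splits a compiler pass into three pieces: rules (what happens at a node), walks (the order in which nodes are visited) and rewriters (how passes are combined).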

### Rules


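#### Rewrite Rule

A rewrite rule maps a node to a replacement node. `SimplifyMyMath` flattens an addition whose operand is itself an addition into a single addition, and does the same for multiplication.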


In [4]:
class SimplifyMyMath(ci.RewriteRule):
    """Flatten directly nested additions and multiplications.

    ``MyAdd([MyAdd([a, b]), c])`` is rewritten to ``MyAdd([a, b, c])``, and
    likewise for ``MyMul``. Only one level of nesting is removed each time
    a node is visited.
    """

    @staticmethod
    def _splice(model, kind):
        """Rebuild ``model`` with each operand of type ``kind`` inlined."""
        flat = []
        for child in model.operands:
            # Inline a same-kind child's operands; keep any other child as-is.
            flat += child.operands if isinstance(child, kind) else [child]
        return type(model)(operands=flat)

    def map_MyAdd(self, model):
        return self._splice(model, MyAdd)

    def map_MyMul(self, model):
        return self._splice(model, MyMul)

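#### Conversion Rule

A conversion rule maps each node to an arbitrary value. Its `map_*` methods also receive `operands`, the node's fields after they have already been converted. `PrintMyMath` uses this to render the tree as a string.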


In [5]:
class PrintMyMath(ci.ConversionRule):
    """Render a MyMath tree as an infix arithmetic string.

    With ``verbose=True`` every sum and product is wrapped in parentheses.
    Otherwise only sums appearing directly inside a product are
    parenthesized, which is enough to keep the rendered precedence right.
    """

    def __init__(self, *, verbose=False):
        super().__init__()
        # When True, parenthesize every MyAdd / MyMul node.
        self.verbose = verbose

    def map_int(self, model, operands):
        # Raw integer field values become their decimal text.
        return str(model)

    def map_MyInteger(self, model, operands):
        # The converted ``value`` field, i.e. the string produced by map_int.
        return operands["value"]

    def map_MyAdd(self, model, operands):
        body = " + ".join(operands["operands"])
        return f"({body})" if self.verbose else body

    def map_MyMul(self, model, operands):
        if self.verbose:
            return "(" + " * ".join(operands["operands"]) + ")"

        # Pair each original child with its rendered text: a sum nested in a
        # product needs parentheses, while a nested product does not.
        pieces = []
        for child, rendered in zip(model.operands, operands["operands"]):
            pieces.append(f"({rendered})" if isinstance(child, MyAdd) else rendered)
        return " * ".join(pieces)

In [6]:
# Example tree: (1 + 2 + 3) * 4 * (5 * 6).
addend = MyAdd(operands=[MyInteger(value=1), MyInteger(value=2), MyInteger(value=3)])
factor = MyMul(operands=[MyInteger(value=5), MyInteger(value=6)])
expr = MyMul(operands=[addend, MyInteger(value=4), factor])

# verbose=True parenthesizes every sum and product in the rendered string.
printer = ci.Post(PrintMyMath(verbose=True))

rendered = printer(expr)
print(rendered)

((1 + 2 + 3) * 4 * (5 * 6))


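### Walks

A walk fixes the order in which a rule visits the tree. `PrintWalkOrder` only logs each visit, so the cells below show each walk's visit order. Field values such as lists and raw `int`s are visited as well as the model nodes themselves.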


In [7]:
class PrintWalkOrder(ci.RewriteRule):
    """Log the order in which a walk visits values, without rewriting them.

    Every visit appends ``"\\n<index>: <value>"`` to ``self.string``. After a
    walk, ``string`` therefore lists every visited value (model nodes, lists
    and raw field values alike) in visitation order.
    """

    def __init__(self):
        # Initialise the base rule first, as PrintMyMath does; skipping it
        # would leave any state set up by ci.RewriteRule.__init__ missing.
        super().__init__()
        self.current_index = 0  # index assigned to the next visited value
        self.string = ""  # accumulated, newline-prefixed visit log

    def generic_map(self, model):
        self.string += f"\n{self.current_index}: {model}"
        self.current_index += 1
        # NOTE(review): returning None is assumed to mean "leave the node
        # unchanged" (the walks in this notebook complete normally with it);
        # confirm against ci.RewriteRule.
        return None

#### Post


In [8]:
# Example tree (1 + 2 + 3) * 4 * (5 * 6), built bottom-up.
sum_leaves = [MyInteger(value=n) for n in (1, 2, 3)]
product_leaves = [MyInteger(value=n) for n in (5, 6)]
expr = MyMul(
    operands=[MyAdd(operands=sum_leaves), MyInteger(value=4), MyMul(operands=product_leaves)]
)

# Post visits every child before its parent (post-order).
printer = ci.Post(PrintWalkOrder())
printer(expr)

# children[0] is the wrapped PrintWalkOrder rule, which holds the visit log.
print(printer.children[0].string)


0: 1
1: class_='MyInteger' value=1
2: 2
3: class_='MyInteger' value=2
4: 3
5: class_='MyInteger' value=3
6: [MyInteger(class_='MyInteger', value=1), MyInteger(class_='MyInteger', value=2), MyInteger(class_='MyInteger', value=3)]
7: class_='MyAdd' operands=[MyInteger(class_='MyInteger', value=1), MyInteger(class_='MyInteger', value=2), MyInteger(class_='MyInteger', value=3)]
8: 4
9: class_='MyInteger' value=4
10: 5
11: class_='MyInteger' value=5
12: 6
13: class_='MyInteger' value=6
14: [MyInteger(class_='MyInteger', value=5), MyInteger(class_='MyInteger', value=6)]
15: class_='MyMul' operands=[MyInteger(class_='MyInteger', value=5), MyInteger(class_='MyInteger', value=6)]
16: [MyAdd(class_='MyAdd', operands=[MyInteger(class_='MyInteger', value=1), MyInteger(class_='MyInteger', value=2), MyInteger(class_='MyInteger', value=3)]), MyInteger(class_='MyInteger', value=4), MyMul(class_='MyMul', operands=[MyInteger(class_='MyInteger', value=5), MyInteger(class_='MyInteger', value=6)])]
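17: class_='MyMul' operands=[MyAdd(class_='MyAdd', operands=[MyInteger(class_='MyInteger', value=1), MyInteger(class_='MyInteger', value=2), MyInteger(class_='MyInteger', value=3)]), MyInteger(class_='MyInteger', value=4), MyMul(class_='MyMul', operands=[MyInteger(class_='MyInteger', value=5), MyInteger(class_='MyInteger', value=6)])]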

#### Pre


In [9]:
# Example tree (1 + 2 + 3) * 4 * (5 * 6).
first = MyAdd(operands=[MyInteger(value=1), MyInteger(value=2), MyInteger(value=3)])
middle = MyInteger(value=4)
last = MyMul(operands=[MyInteger(value=5), MyInteger(value=6)])
expr = MyMul(operands=[first, middle, last])

# Pre visits every parent before its children (pre-order).
printer = ci.Pre(PrintWalkOrder())
printer(expr)

walk_log = printer.children[0].string
print(walk_log)


0: class_='MyMul' operands=[MyAdd(class_='MyAdd', operands=[MyInteger(class_='MyInteger', value=1), MyInteger(class_='MyInteger', value=2), MyInteger(class_='MyInteger', value=3)]), MyInteger(class_='MyInteger', value=4), MyMul(class_='MyMul', operands=[MyInteger(class_='MyInteger', value=5), MyInteger(class_='MyInteger', value=6)])]
1: [MyAdd(class_='MyAdd', operands=[MyInteger(class_='MyInteger', value=1), MyInteger(class_='MyInteger', value=2), MyInteger(class_='MyInteger', value=3)]), MyInteger(class_='MyInteger', value=4), MyMul(class_='MyMul', operands=[MyInteger(class_='MyInteger', value=5), MyInteger(class_='MyInteger', value=6)])]
2: class_='MyAdd' operands=[MyInteger(class_='MyInteger', value=1), MyInteger(class_='MyInteger', value=2), MyInteger(class_='MyInteger', value=3)]
3: [MyInteger(class_='MyInteger', value=1), MyInteger(class_='MyInteger', value=2), MyInteger(class_='MyInteger', value=3)]
4: class_='MyInteger' value=1
5: 1
6: class_='MyInteger' value=2
7: 2
8: class_

#### In


In [10]:
# Example tree (1 + 2 + 3) * 4 * (5 * 6).
operand_list = [
    MyAdd(operands=[MyInteger(value=k) for k in (1, 2, 3)]),
    MyInteger(value=4),
    MyMul(operands=[MyInteger(value=k) for k in (5, 6)]),
]
expr = MyMul(operands=operand_list)

# Log the visit order of the ci.In walk for comparison with Pre and Post.
printer = ci.In(PrintWalkOrder())
printer(expr)

walk_log = printer.children[0].string
print(walk_log)


0: class_='MyMul' operands=[MyAdd(class_='MyAdd', operands=[MyInteger(class_='MyInteger', value=1), MyInteger(class_='MyInteger', value=2), MyInteger(class_='MyInteger', value=3)]), MyInteger(class_='MyInteger', value=4), MyMul(class_='MyMul', operands=[MyInteger(class_='MyInteger', value=5), MyInteger(class_='MyInteger', value=6)])]
1: class_='MyAdd' operands=[MyInteger(class_='MyInteger', value=1), MyInteger(class_='MyInteger', value=2), MyInteger(class_='MyInteger', value=3)]
2: class_='MyInteger' value=1
3: 1
4: class_='MyInteger' value=2
5: 2
6: [MyInteger(class_='MyInteger', value=1), MyInteger(class_='MyInteger', value=2), MyInteger(class_='MyInteger', value=3)]
7: class_='MyInteger' value=3
8: 3
9: class_='MyInteger' value=4
10: 4
11: [MyAdd(class_='MyAdd', operands=[MyInteger(class_='MyInteger', value=1), MyInteger(class_='MyInteger', value=2), MyInteger(class_='MyInteger', value=3)]), MyInteger(class_='MyInteger', value=4), MyMul(class_='MyMul', operands=[MyInteger(class_='M

#### Level


In [11]:
# Example tree (1 + 2 + 3) * 4 * (5 * 6).
lhs = MyAdd(operands=[MyInteger(value=1), MyInteger(value=2), MyInteger(value=3)])
rhs = MyMul(operands=[MyInteger(value=5), MyInteger(value=6)])
expr = MyMul(operands=[lhs, MyInteger(value=4), rhs])

# Level visits the tree level by level: all nodes at one depth before the next.
printer = ci.Level(PrintWalkOrder())
printer(expr)

walk_log = printer.children[0].string
print(walk_log)


0: class_='MyMul' operands=[MyAdd(class_='MyAdd', operands=[MyInteger(class_='MyInteger', value=1), MyInteger(class_='MyInteger', value=2), MyInteger(class_='MyInteger', value=3)]), MyInteger(class_='MyInteger', value=4), MyMul(class_='MyMul', operands=[MyInteger(class_='MyInteger', value=5), MyInteger(class_='MyInteger', value=6)])]
1: [MyAdd(class_='MyAdd', operands=[MyInteger(class_='MyInteger', value=1), MyInteger(class_='MyInteger', value=2), MyInteger(class_='MyInteger', value=3)]), MyInteger(class_='MyInteger', value=4), MyMul(class_='MyMul', operands=[MyInteger(class_='MyInteger', value=5), MyInteger(class_='MyInteger', value=6)])]
2: class_='MyAdd' operands=[MyInteger(class_='MyInteger', value=1), MyInteger(class_='MyInteger', value=2), MyInteger(class_='MyInteger', value=3)]
3: class_='MyInteger' value=4
4: class_='MyMul' operands=[MyInteger(class_='MyInteger', value=5), MyInteger(class_='MyInteger', value=6)]
5: [MyInteger(class_='MyInteger', value=1), MyInteger(class_='MyI

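### Rewriters

Rewriters combine passes: `Chain` runs passes one after another, and `FixedPoint` repeats a pass until its output stops changing.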


#### Chain


In [12]:
# (1 + 2 + 3) * 4 * (5 * 6 * (7 * 8)): products nested two levels deep.
innermost = MyMul(operands=[MyInteger(value=7), MyInteger(value=8)])
nested_product = MyMul(operands=[MyInteger(value=5), MyInteger(value=6), innermost])
expr = MyMul(
    operands=[
        MyAdd(operands=[MyInteger(value=k) for k in (1, 2, 3)]),
        MyInteger(value=4),
        nested_product,
    ]
)

# Chain runs its passes in sequence: flatten bottom-up, then render the result.
printer = ci.Chain(
    ci.Post(SimplifyMyMath()),
    ci.Post(PrintMyMath(verbose=True)),
)

print(printer(expr))

((1 + 2 + 3) * 4 * 5 * 6 * 7 * 8)


#### FixedPoint


In [13]:
# (1 + 2 + 3) * 4 * (5 * 6 * (7 * 8)): products nested two levels deep.
inner = MyMul(operands=[MyInteger(value=7), MyInteger(value=8)])
expr = MyMul(
    operands=[
        MyAdd(operands=[MyInteger(value=1), MyInteger(value=2), MyInteger(value=3)]),
        MyInteger(value=4),
        MyMul(operands=[MyInteger(value=5), MyInteger(value=6), inner]),
    ]
)

# FixedPoint repeatedly applies the pre-order simplification (presumably
# until the tree stops changing), so every level of nesting is flattened.
simplifier = ci.FixedPoint(ci.Pre(SimplifyMyMath()))

printer = ci.Post(PrintMyMath(verbose=True))

flattened = simplifier(expr)
print(printer(flattened))

((1 + 2 + 3) * 4 * 5 * 6 * 7 * 8)
