# ADT Metaclass

In [3]:
from dataclasses import make_dataclass
from typing import Any
from copy import deepcopy


class ADTMeta(type):
    """Build an algebraic data type from an annotated class body.

    Each annotation in the class body declares one data constructor: the
    annotated name becomes the constructor's class name, and the annotation's
    value is the ``fields`` spec passed to ``dataclasses.make_dataclass``::

        class List(metaclass=ADTMeta):
            Nil: []
            Cons: [("x", Any), ("xs", "List")]

    ``__new__`` does not return an instance of this metaclass. It returns an
    empty dataclass named after the class, and every constructor is created
    as a dataclass subclass of it. Constructors are published by updating
    ``globals()`` of the module that defines ``ADTMeta``, which may differ
    from the module the ADT is declared in. Everything else in the class
    body (methods, docstring) and the declared ``bases`` are ignored.
    """

    def __new__(cls, name, bases, clsdict):
        base = make_dataclass(name, fields=[], bases=(object,))

        for data_cons, fields in cls._constructor_specs(clsdict).items():
            # Inject the constructor by name so it can be used as e.g.
            # `Cons(1, Nil())` and in `match` class patterns.
            globals().update(
                {data_cons: make_dataclass(data_cons, fields, bases=(base,))}
            )

        return base

    @staticmethod
    def _constructor_specs(clsdict):
        """Return the class body's annotations as a ``{name: fields}`` dict.

        A body with no annotations has no ``"__annotations__"`` entry, so this
        returns ``{}`` instead of raising ``KeyError``. On Python 3.14+
        (PEP 649/749) annotations are evaluated lazily and the namespace holds
        an annotate function instead of the dict; that function is called
        here to get the evaluated values.
        """
        specs = clsdict.get("__annotations__")
        if specs is not None:
            return specs
        try:
            import annotationlib  # only exists on Python 3.14+
        except ImportError:
            return {}
        annotate = annotationlib.get_annotate_from_class_namespace(clsdict)
        if annotate is None:
            return {}
        return annotationlib.call_annotate_function(
            annotate, annotationlib.Format.VALUE
        )

In [4]:
def impl_for(cls):
    """Return a decorator that installs a function as a method of ``cls``.

    The decorated function is stored, under its own ``__name__``, on ``cls``
    and then on each *direct* subclass of ``cls`` that exists at decoration
    time. The function itself is returned unchanged, so the module-level name
    stays bound to it as well.
    """

    def install(method):
        method_name = method.__name__
        setattr(cls, method_name, method)
        # Also stamp the method directly onto every existing direct subclass.
        for subclass in cls.__subclasses__():
            setattr(subclass, method_name, method)

        # TODO: mangle the module-level name so that method(...) is no longer
        # callable directly, only x.method(...).
        return method

    return install

## List

In [None]:
"""
data List a = Null | Cons a (List a)
"""


class List(metaclass=ADTMeta):
    Nil: []
    Cons: [("x", Any), ("xs", "List")]

## Tree

In [None]:
"""
data Tree a = Null | Leaf a | Node a [Tree a]
"""


class Tree(metaclass=ADTMeta):
    Null: []
    Leaf: [("val", Any)]
    Node: [("val", Any), ("children", list["Tree"])]

# Test

In [None]:
@impl_for(List)
def show(self) -> str:
    match self:
        case Nil():
            return "<END>"
        case Cons(head, tail):
            return f"{head}:{tail.show()}"


@impl_for(Tree)
def show(self, depth=0) -> str:
    match self:
        case Null():
            return ""
        case Leaf(x):
            return f"{x}|{depth}"
        case Node(x, xs):
            rest = ":".join([r.show(depth + 1) for r in xs])
            return f"({x}|{depth}):{rest}"

In [None]:
# Build the list 1, 2, 3 as nested Cons cells ending in Nil, then render it.
lst = Cons(x=1, xs=Cons(x=2, xs=Cons(x=3, xs=Nil())))
lst.show()

In [None]:
# A root node holding 1 with three leaf children (2, 3, 4), then render it.
tree = Node(val=1, children=[Leaf(leaf_val) for leaf_val in (2, 3, 4)])
tree.show()