# Imports

In [1]:
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Callable, Union

import pygraphviz as pgv

# Classes

In [2]:
class Expression(ABC):
    """Abstract base class for nodes of a propositional-logic expression tree.

    Python operators are overloaded so that trees can be written inline:

    * ``~a``     -> ``Not(a)``
    * ``a & b``  -> ``And(a, b)``
    * ``a | b``  -> ``Or(a, b)``
    * ``a <= b`` -> ``Implication(a, b)``     (a implies b)
    * ``a >= b`` -> ``RevImplication(a, b)``  (b implies a)
    * ``a + b``  -> ``Equivalence(a, b)``
    * ``a ^ b``  -> ``Xor(a, b)``

    ``==`` is deliberately NOT overloaded, so it keeps meaning structural
    equality of trees.

    Concrete subclasses must implement :meth:`simplify` and :meth:`assume`.
    """

    def __invert__(self) -> 'Not':
        """Return the negation ``Not(self)``."""
        return Not(self)

    def __and__(self, other: 'Expression') -> 'And':
        """Return the conjunction ``And(self, other)``."""
        return And(self, other)

    def __or__(self, other: 'Expression') -> 'Or':
        """Return the disjunction ``Or(self, other)``."""
        return Or(self, other)

    def __le__(self, other: 'Expression') -> 'Implication':
        """Return the implication ``self -> other``."""
        return Implication(self, other)

    def __ge__(self, other: 'Expression') -> 'RevImplication':
        """Return the reverse implication ``other -> self``."""
        return RevImplication(self, other)

    def __add__(self, other: 'Expression') -> 'Equivalence':
        """To not override ==, + will be logical equivalence"""
        return Equivalence(self, other)

    def __xor__(self, other: 'Expression') -> 'Xor':
        """Return the exclusive or ``Xor(self, other)``."""
        return Xor(self, other)

    @abstractmethod
    def simplify(self, **kwargs: Any) -> 'Expression':
        """Return a logically equivalent, simplified expression."""
        ...

    @abstractmethod
    def assume(self,
               literal: 'Literal',
               value: 'Literal') -> 'Expression':
        """Return a copy with every occurrence of ``literal`` replaced.

        ``value`` is expected to be one of the constant literals ``_0`` or
        ``_1``. Those names are instances of ``Literal``, not types, so the
        annotation names their class.
        """
        ...


@dataclass(frozen=True)
class ExpressionOfZero(Expression, ABC):
    """Abstract base for leaf nodes, i.e. nodes without sub-expressions."""


@dataclass(frozen=True)
class ExpressionOfOne(Expression, ABC):
    """Abstract base for nodes that wrap exactly one sub-expression."""

    expr: Expression  # the single operand

    def assume(self, literal, value):
        """Substitute ``value`` for ``literal`` inside the operand.

        A new node of the same concrete type is returned; ``self`` is left
        untouched (instances are frozen).
        """
        node_type = type(self)
        new_operand = self.expr.assume(literal, value)
        return node_type(new_operand)


@dataclass(frozen=True)
class ExpressionOfTwo(Expression, ABC):
    """Abstract base for binary nodes holding a left and a right operand."""

    expr1: Expression  # left operand
    expr2: Expression  # right operand

    def assume(self, literal, value):
        """Substitute ``value`` for ``literal`` in both operands.

        Operands are rewritten left first, then right, and combined into a
        new node of the same concrete type.
        """
        substituted = [
            operand.assume(literal, value)
            for operand in (self.expr1, self.expr2)
        ]
        return type(self)(*substituted)


@dataclass(frozen=True)
class Literal(ExpressionOfZero):
    """A named propositional variable, i.e. a leaf of the expression tree.

    The module-level constants ``_0`` and ``_1`` are also ``Literal``
    instances, named ``'0'`` and ``'1'``.
    """

    name: str  # text printed by str(); dataclass equality compares it

    def __str__(self):
        return self.name

    def __repr__(self):
        # repr() of the name yields a valid constructor call such as
        # Literal('A'), instead of the ambiguous Literal(A).
        return f'Literal({self.name!r})'

    def simplify(self, **kwargs):
        """A literal is already as simple as possible; return it as is."""
        return self

    def assume(self, literal, value):
        """Return ``value`` if this literal equals ``literal``, else ``self``."""
        if self == literal:
            return value
        return self


# Boolean constants: false and true are modelled as literals named '0'
# and '1', which the simplify() rules recognise by equality.
_0, _1 = Literal('0'), Literal('1')


@dataclass(frozen=True)
class Not(ExpressionOfOne):
    """Logical negation of one sub-expression (built with ``~expr``)."""

    def __str__(self):
        # Literals and other negations print bare; any compound operand is
        # parenthesised so the negation visibly covers the whole group.
        if isinstance(self.expr, (Literal, Not)):
            return f'~{self.expr}'
        return f'~({self.expr})'

    def __repr__(self):
        return f'Not({self.expr!r})'

    def simplify(self, **kwargs):
        """Return a simplified equivalent of this negation.

        The operand is simplified first. A double negation ``~~x`` cancels
        to ``x``, and a negated constant folds to the opposite constant.
        ``kwargs`` is accepted for interface compatibility and is unused.
        """
        inner = self.expr.simplify()
        negate = not isinstance(inner, Not)
        if not negate:
            # ~~x: strip the inner negation and drop the outer one.
            inner = inner.expr
        if inner == _0:
            return _1 if negate else _0
        if inner == _1:
            return _0 if negate else _1
        return Not(inner) if negate else inner


@dataclass(frozen=True)
class And(ExpressionOfTwo):
    """Logical conjunction ``expr1 & expr2``."""

    def __str__(self):
        pieces = []
        for operand in (self.expr1, self.expr2):
            # Only literals and negations are printed without parentheses.
            bare = isinstance(operand, (Literal, Not))
            pieces.append(str(operand) if bare else f'({operand})')
        return ' & '.join(pieces)

    def __repr__(self):
        return f'And({self.expr1!r}, {self.expr2!r})'

    def simplify(self, **kwargs):
        """Return a simplified equivalent of this conjunction.

        Both operands are simplified first. Then:

        * ``~a & ~b`` is rewritten to ``~(a | b)`` (De Morgan);
        * a ``0`` operand makes the result ``0``;
        * a ``1`` operand is dropped.

        ``kwargs`` is accepted for interface compatibility and is unused.
        """
        left, right = self.expr1.simplify(), self.expr2.simplify()
        if isinstance(left, Not) and isinstance(right, Not):
            return Not(Or(left.expr, right.expr))
        if left == _0 or right == _0:
            return _0
        if left == _1:
            return right
        if right == _1:
            return left
        return And(left, right)


@dataclass(frozen=True)
class Or(ExpressionOfTwo):
    """Logical disjunction ``expr1 | expr2``."""

    def __str__(self):
        # Literals and negations print bare, anything else in parentheses.
        left, right = (
            str(part) if isinstance(part, (Literal, Not)) else f'({part})'
            for part in (self.expr1, self.expr2)
        )
        return f'{left} | {right}'

    def __repr__(self):
        return f'Or({self.expr1!r}, {self.expr2!r})'

    def simplify(self, **kwargs):
        """Return a simplified equivalent of this disjunction.

        Both operands are simplified first. Then:

        * ``~a | ~b`` is rewritten to ``~(a & b)`` (De Morgan);
        * a ``1`` operand makes the result ``1``;
        * a ``0`` operand is dropped.

        ``kwargs`` is accepted for interface compatibility and is unused.
        """
        left = self.expr1.simplify()
        right = self.expr2.simplify()
        both_negated = isinstance(left, Not) and isinstance(right, Not)
        if both_negated:
            return Not(And(left.expr, right.expr))
        if left == _1 or right == _1:
            return _1
        if left == _0:
            return right
        if right == _0:
            return left
        return Or(left, right)


@dataclass(frozen=True)
class Implication(ExpressionOfTwo):
    """Material implication ``expr1 <= expr2``, read "expr1 implies expr2".

    Truth-functionally equal to ``~expr1 | expr2``.
    """

    def __str__(self):
        left_template = '%s' \
            if isinstance(self.expr1, Literal | Not) \
            else '(%s)'
        right_template = '%s' \
            if isinstance(self.expr2, Literal | Not) \
            else '(%s)'
        template = left_template + ' <= ' + right_template
        return template % (self.expr1, self.expr2)

    def __repr__(self):
        return f'Implication({repr(self.expr1)}, {repr(self.expr2)})'

    def simplify(self, **kwargs):
        """Return a simplified equivalent of ``expr1 -> expr2``.

        Both operands are simplified first, then the constant rules apply:

        * ``0 -> b`` and ``a -> 1`` are ``1``;
        * ``1 -> b`` is ``b``;
        * ``a -> 0`` is ``~a``, with a double negation cancelled.

        ``kwargs`` is accepted for interface compatibility and is unused.
        """
        result1 = self.expr1.simplify()
        result2 = self.expr2.simplify()
        # Checked first so that e.g. 0 -> 0 folds to 1 instead of ~0.
        if result1 == _0 or result2 == _1:
            return _1
        if result1 == _1:
            return result2
        if result2 == _0:
            # result1 is simplified and not a constant here, so negating
            # it only requires cancelling an existing negation.
            return result1.expr if isinstance(result1, Not) else Not(result1)
        return Implication(result1, result2)


@dataclass(frozen=True)
class RevImplication(ExpressionOfTwo):
    """Reverse implication ``expr1 >= expr2``, read "expr2 implies expr1".

    Truth-functionally equal to ``expr1 | ~expr2``.
    """

    def __str__(self):
        left_template = '%s' \
            if isinstance(self.expr1, Literal | Not) \
            else '(%s)'
        right_template = '%s' \
            if isinstance(self.expr2, Literal | Not) \
            else '(%s)'
        template = left_template + ' >= ' + right_template
        return template % (self.expr1, self.expr2)

    def __repr__(self):
        return \
            f'RevImplication({repr(self.expr1)}, {repr(self.expr2)})'

    def simplify(self, **kwargs):
        """Return a simplified equivalent of ``expr1 >= expr2``.

        Both operands are simplified first, then the constant rules of
        ``a | ~b`` apply (mirroring ``Implication.simplify``):

        * ``1 >= b`` and ``a >= 0`` are ``1``;
        * ``a >= 1`` is ``a``;
        * ``0 >= b`` is ``~b``, with a double negation cancelled.

        ``kwargs`` is accepted for interface compatibility and is unused.
        """
        result1 = self.expr1.simplify()
        result2 = self.expr2.simplify()
        if result1 == _1 or result2 == _0:
            return _1
        if result2 == _1:
            return result1
        if result1 == _0:
            # result2 is simplified and not a constant here.
            return result2.expr if isinstance(result2, Not) else Not(result2)
        return RevImplication(result1, result2)


@dataclass(frozen=True)
class Equivalence(ExpressionOfTwo):
    """Logical equivalence, written ``expr1 + expr2`` (``==`` stays structural)."""

    def __str__(self):
        left_template = '%s' \
            if isinstance(self.expr1, Literal | Not) \
            else '(%s)'
        right_template = '%s' \
            if isinstance(self.expr2, Literal | Not) \
            else '(%s)'
        template = left_template + ' + ' + right_template
        return template % (self.expr1, self.expr2)

    def __repr__(self):
        return f'Equivalence({repr(self.expr1)}, {repr(self.expr2)})'

    def simplify(self, **kwargs):
        """Return a simplified equivalent of ``expr1 <-> expr2``.

        Both operands are simplified first. Then:

        * structurally equal operands give ``1``;
        * ``1 <-> x`` is ``x``;
        * ``0 <-> x`` is ``~x``, with a double negation cancelled.

        Otherwise an ``Equivalence`` of the simplified operands is returned
        (not ``self``, which would discard the operand simplification).
        ``kwargs`` is accepted for interface compatibility and is unused.
        """
        result1 = self.expr1.simplify()
        result2 = self.expr2.simplify()
        if result1 == result2:
            return _1
        if result1 == _1:
            return result2
        if result2 == _1:
            return result1
        # Below this point neither operand is 1, and they differ, so the
        # other operand of a 0 is never a constant.
        if result1 == _0:
            return result2.expr if isinstance(result2, Not) else Not(result2)
        if result2 == _0:
            return result1.expr if isinstance(result1, Not) else Not(result1)
        return Equivalence(result1, result2)


@dataclass(frozen=True)
class Xor(ExpressionOfTwo):
    """Exclusive or ``expr1 ^ expr2``."""

    def __str__(self):
        left_template = '%s' \
            if isinstance(self.expr1, Literal | Not) \
            else '(%s)'
        right_template = '%s' \
            if isinstance(self.expr2, Literal | Not) \
            else '(%s)'
        template = left_template + ' ^ ' + right_template
        return template % (self.expr1, self.expr2)

    def __repr__(self):
        return f'Xor({repr(self.expr1)}, {repr(self.expr2)})'

    def simplify(self, **kwargs):
        """Return a simplified equivalent of ``expr1 ^ expr2``.

        Both operands are simplified first. Then:

        * structurally equal operands give ``0``;
        * ``0 ^ x`` is ``x``;
        * ``1 ^ x`` is ``~x``, with a double negation cancelled, so that
          ``1 ^ 1`` is ``0`` rather than an unsimplified ``~1``.

        Otherwise a ``Xor`` of the simplified operands is returned (not
        ``self``, which would discard the operand simplification).
        ``kwargs`` is accepted for interface compatibility and is unused.
        """
        result1 = self.expr1.simplify()
        result2 = self.expr2.simplify()
        if result1 == result2:
            return _0
        if result1 == _0:
            return result2
        if result2 == _0:
            return result1
        # Below this point neither operand is 0, and they differ, so the
        # other operand of a 1 is never a constant.
        if result1 == _1:
            return result2.expr if isinstance(result2, Not) else Not(result2)
        if result2 == _1:
            return result1.expr if isinstance(result1, Not) else Not(result1)
        return Xor(result1, result2)

In [3]:
class Infix:
    """Wrap a binary function so it can be written as ``a <op> b``.

    ``a <op> b`` is a chained comparison, ``(a < op) and (op > b)``. When
    ``a`` does not handle ``<`` with this object, Python falls back to the
    reflected ``op.__gt__(a)``, which stores ``a``. The second comparison,
    ``op > b``, then calls ``func(a, b)`` and clears the stored value.

    NOTE(review): the pending operand lives on the instance itself, so one
    instance cannot be nested (``a <op> (b <op> c)``). Also, if evaluation
    fails between the two halves, a stale operand is left behind.
    """

    def __init__(self, func: Callable,
                 value: Union['Expression', None] = None):
        self.func = func    # binary function applied to (left, right)
        self.value = value  # pending left operand, or None when idle

    def __gt__(self, other: 'Expression'):
        pending = self.value
        if pending is None:
            # First half of the chain: remember the left operand. Returning
            # self (which is truthy) lets the chained comparison continue.
            self.value = other
            return self
        # Second half: consume the stored operand and apply the function.
        self.value = None
        return self.func(pending, other)


# Pseudo-infix equivalence operator: ``a <equiv> b`` builds Equivalence(a, b).
# The lambda (rather than passing the class directly) looks Equivalence up
# at call time, so re-running the class cell is picked up automatically.
equiv = Infix(lambda left, right: Equivalence(left, right))

In [4]:
class TreeMethod:
    """Placeholder class; no behaviour has been implemented yet."""

# Tests

In [5]:
import unittest


class SimplyficationTests(unittest.TestCase):
    """Constant-folding checks for ``simplify`` on ``~``, ``|`` and ``&``."""

    def setUp(self):
        # Free variable shared by the single-variable cases.
        self.A = Literal('A')

    # Constants

    def test_not_0_is_1(self):
        self.assertEqual((~_0).simplify(), _1)

    def test_not_1_is_0(self):
        self.assertEqual((~_1).simplify(), _0)

    def test_0_or_0_is_0(self):
        self.assertEqual((_0 | _0).simplify(), _0)

    def test_0_or_1_is_1(self):
        self.assertEqual((_0 | _1).simplify(), _1)

    def test_1_or_0_is_1(self):
        self.assertEqual((_1 | _0).simplify(), _1)

    def test_1_or_1_is_1(self):
        self.assertEqual((_1 | _1).simplify(), _1)

    def test_0_and_0_is_0(self):
        self.assertEqual((_0 & _0).simplify(), _0)

    def test_0_and_1_is_0(self):
        self.assertEqual((_0 & _1).simplify(), _0)

    def test_1_and_0_is_0(self):
        self.assertEqual((_1 & _0).simplify(), _0)

    def test_1_and_1_is_1(self):
        self.assertEqual((_1 & _1).simplify(), _1)

    # One variable

    def test_1_or_A_is_1(self):
        self.assertEqual((_1 | self.A).simplify(), _1)

    def test_0_or_A_is_A(self):
        self.assertEqual((_0 | self.A).simplify(), self.A)

    def test_A_or_1_is_1(self):
        self.assertEqual((self.A | _1).simplify(), _1)

    def test_A_or_0_is_A(self):
        self.assertEqual((self.A | _0).simplify(), self.A)

    def test_1_and_A_is_A(self):
        self.assertEqual((_1 & self.A).simplify(), self.A)

    def test_0_and_A_is_0(self):
        self.assertEqual((_0 & self.A).simplify(), _0)

    def test_A_and_1_is_A(self):
        self.assertEqual((self.A & _1).simplify(), self.A)

    def test_A_and_0_is_0(self):
        self.assertEqual((self.A & _0).simplify(), _0)


# argv=[''] keeps unittest from parsing the host process's command-line
# arguments, and exit=False stops it from calling sys.exit() afterwards,
# so the suite can run inside an interactive session.
unittest.main(argv=[''], exit=False, verbosity=3)

test_0_and_0_is_0 (__main__.SimplyficationTests.test_0_and_0_is_0) ... ok
test_0_and_1_is_0 (__main__.SimplyficationTests.test_0_and_1_is_0) ... ok
test_0_and_A_is_0 (__main__.SimplyficationTests.test_0_and_A_is_0) ... ok
test_0_or_0_is_0 (__main__.SimplyficationTests.test_0_or_0_is_0) ... ok
test_0_or_1_is_1 (__main__.SimplyficationTests.test_0_or_1_is_1) ... ok
test_0_or_A_is_A (__main__.SimplyficationTests.test_0_or_A_is_A) ... ok
test_1_and_0_is_0 (__main__.SimplyficationTests.test_1_and_0_is_0) ... ok
test_1_and_1_is_1 (__main__.SimplyficationTests.test_1_and_1_is_1) ... ok
test_1_and_A_is_A (__main__.SimplyficationTests.test_1_and_A_is_A) ... ok
test_1_or_0_is_1 (__main__.SimplyficationTests.test_1_or_0_is_1) ... ok
test_1_or_1_is_1 (__main__.SimplyficationTests.test_1_or_1_is_1) ... ok
test_1_or_A_is_1 (__main__.SimplyficationTests.test_1_or_A_is_1) ... ok
test_A_and_0_is_0 (__main__.SimplyficationTests.test_A_and_0_is_0) ... ok
test_A_and_1_is_A (__main__.SimplyficationTests.te

<unittest.main.TestProgram at 0x1bd45bbff50>