## 3 二叉树

### 3.1 一般二叉树

In [5]:
from typing import Generator


class BinaryTreeNode:
    val: str | int
    left_child: "BinaryTreeNode"
    right_child: "BinaryTreeNode"

    @classmethod
    def init(cls, tree_str: str):
        ret: "BinaryTreeNode | None" = None

        stack: list["BinaryTreeNode"] = []
        last_node = None

        is_right_child = False
        for ch in cls.element_iter(tree_str):
            if ch == '(':
                stack.append(last_node)
            elif ch == ',':
                is_right_child = True
            elif ch == ')':
                stack.pop()
            elif ch.isalpha() or ch.isdigit():
                last_node = BinaryTreeNode(ch)
                if not ret:
                    ret = last_node
                else:
                    if is_right_child:
                        stack[-1].right_child = last_node
                        is_right_child = False
                    else:
                        stack[-1].left_child = last_node
        return ret

    @staticmethod
    def element_iter(exp: str) -> Generator[str, None, None]:
        import re
        elements = re.findall(r"([a-zA-Z0-9]+|\(|,|\))", exp)
        for el in elements:
            yield el

    def __init__(self, val: str | int, left: "BinaryTreeNode" = None, right: "BinaryTreeNode" = None):
        self.val = val
        self.left_child = left
        self.right_child = right

    def preorder_traversal(self):
        print(self.val, end=" ")
        if self.left_child:
            self.left_child.preorder_traversal()
        if self.right_child:
            self.right_child.preorder_traversal()

    def inorder_traversal(self):
        if self.left_child:
            self.left_child.preorder_traversal()
        print(self.val, end=" ")
        if self.right_child:
            self.right_child.preorder_traversal()

    def postorder_traversal(self):
        if self.left_child:
            self.left_child.preorder_traversal()
        if self.right_child:
            self.right_child.preorder_traversal()
        print(self.val, end=" ")

In [None]:
# Demo: parse "A(B(,D),C(E))" (B has only a right child D, C only a left
# child E) and print its preorder, inorder and postorder traversals; each
# print("\n") ends the current line and adds a blank line between outputs.
bt_root = BinaryTreeNode.init("A(B(,D),C(E))")
bt_root.preorder_traversal()
print("\n")
bt_root.inorder_traversal()
print("\n")
bt_root.postorder_traversal()

In [None]:
# LeetCode 101
def is_symmetric(l: "BinaryTreeNode", r: "BinaryTreeNode"):
    """Return True if subtree ``l`` is the mirror image of subtree ``r``.

    For LeetCode 101, call it with the root's left and right children.
    - Two empty subtrees mirror each other.
    - Exactly one empty subtree does not.
    - Otherwise the two roots must hold equal values, ``l``'s left must mirror
      ``r``'s right (outer pair), and ``l``'s right must mirror ``r``'s left
      (inner pair).
    """
    if not (l and r):
        # At least one side is empty: they mirror only if both are empty.
        return not l and not r
    if l.val != r.val:
        return False
    outer_ok = is_symmetric(l.left_child, r.right_child)
    inner_ok = is_symmetric(l.right_child, r.left_child)
    return outer_ok and inner_ok


# Demo: a mirror-symmetric tree. Print its three traversals, then check
# symmetry of the root's two subtrees. That call is the cell's last
# expression, so its value is shown as the cell output.
root101 = BinaryTreeNode.init("1(2(3,4),2(4,3))")
root101.preorder_traversal()
print("\n")
root101.inorder_traversal()
print("\n")
root101.postorder_traversal()
is_symmetric(root101.left_child, root101.right_child)

In [None]:
# LeetCode 104 (maximum depth of a binary tree)
def max_indepth(root: "BinaryTreeNode") -> int:
    """Return the number of nodes on the longest root-to-leaf path.

    An empty tree (``root`` falsy) has depth 0; a single node has depth 1.
    """
    if not root:
        return 0

    def _deepest(node, depth):
        # Only existing children can extend the current path.
        children = [child for child in (node.left_child, node.right_child) if child]
        if not children:
            return depth
        return max(_deepest(child, depth + 1) for child in children)

    return _deepest(root, 1)


# Demo: maximum depth of 3(9,20(15,7)), which is 3. The call is the cell's
# last expression, so the result is shown as the cell output.
root103 = BinaryTreeNode.init("3(9,20(15,7))")
max_indepth(root103)

In [None]:
# LeetCode 111
def min_indepth(root: "BinaryTreeNode") -> int:
    """Return the minimum depth: node count on the shortest root-to-LEAF path.

    A leaf has no children. A node with a single child is not a leaf, so its
    missing side must not be treated as a depth-0 path. Returns 0 for an
    empty tree.
    """
    def _dive(node):
        if node is None:
            return 0
        left_min = _dive(node.left_child)
        right_min = _dive(node.right_child)
        # If one side is missing, only the existing side leads to a leaf
        # (when both are missing this yields 1, i.e. the node is a leaf).
        if left_min == 0 or right_min == 0:
            return left_min + right_min + 1
        return min(left_min, right_min) + 1

    return _dive(root)


# Demo: minimum depth of 3(9,20(15,7)); the shortest root-to-leaf path is
# 3 -> 9, so the expected value is 2 (shown as the cell output).
root111 = BinaryTreeNode.init("3(9,20(15,7))")
min_indepth(root111)

In [None]:
# LeetCode 226
def invert_tree(root: "BinaryTreeNode") -> "BinaryTreeNode":
    """Mirror the tree in place by swapping every node's children; return ``root``."""
    def _mirror(node):
        # Empty subtrees and leaves have nothing to swap.
        has_children = node and (node.left_child or node.right_child)
        if not has_children:
            return None
        node.left_child, node.right_child = node.right_child, node.left_child
        for child in (node.left_child, node.right_child):
            _mirror(child)

    _mirror(root)
    return root


# Demo: invert 3(9,20(15,7)) in place; the returned root object is the cell's
# last expression and is displayed as the output.
root226 = BinaryTreeNode.init("3(9,20(15,7))")
invert_tree(root226)

In [None]:
from typing import List, Optional


# LeetCode 105: rebuild a tree from its preorder and inorder traversals
def buildTree(self, preorder: List[int], inorder: List[int]) -> Optional["BinaryTreeNode"]:
    """Rebuild a binary tree from its preorder and inorder sequences.

    ``self`` is unused; it is kept so the signature matches the LeetCode
    method. Values are assumed distinct (``list.index`` picks the first match).
    Returns the root node, or ``None`` for empty input.
    """
    # Recurse through a local helper instead of the module-level name
    # ``buildTree``: the LeetCode 106 cell rebinds that name to a different
    # algorithm, which would silently hijack this recursion.
    def _build(pre, ino):
        if not pre:
            return None
        # 1. The first preorder value is the root of this (sub)tree.
        node = BinaryTreeNode(pre[0])
        # 2. Its inorder position splits the subtrees; the left one has
        #    exactly `split` nodes.
        split = ino.index(pre[0])
        # 3. Recurse on the matching preorder/inorder slices.
        node.left_child = _build(pre[1:split + 1], ino[:split])
        node.right_child = _build(pre[split + 1:], ino[split + 1:])
        return node

    return _build(preorder, inorder)


# Demo (LeetCode 105 variant): preorder [3, 9] + inorder [9, 3] gives root 3
# with left child 9. Runs before the next cell redefines buildTree.
root105 = buildTree(None, [3, 9], [9, 3])

In [None]:
from typing import List, Optional


# LeetCode 106: rebuild a tree from its inorder and postorder traversals
def buildTree(self, inorder: List[int], postorder: List[int]) -> Optional["BinaryTreeNode"]:
    """Rebuild a binary tree from its inorder and postorder sequences.

    ``self`` is unused; it is kept so the signature matches the LeetCode
    method. Values are assumed distinct (``list.index`` picks the first match).
    Returns the root node, or ``None`` for empty input.
    """
    # Recurse through a local helper instead of the module-level name
    # ``buildTree``, which the LeetCode 105 cell also defines; re-running that
    # cell would otherwise redirect this recursion to the other algorithm.
    def _build(ino, post):
        if not ino:
            return None
        # 1. The last postorder value is the root of this (sub)tree.
        root_val = post[-1]
        node = BinaryTreeNode(root_val)
        # 2. Its inorder position splits the subtrees; the left subtree has
        #    `split` nodes and occupies post[:split].
        split = ino.index(root_val)
        node.left_child = _build(ino[:split], post[:split])
        # 3. The right subtree sits between the left subtree and the root.
        node.right_child = _build(ino[split + 1:], post[split:-1])
        return node

    return _build(inorder, postorder)


# Demo (LeetCode 106 variant): rebuilds the tree 3(9,20(15,7)) from its inorder
# and postorder sequences.
root106 = buildTree(None, [9, 3, 15, 20, 7], [9, 15, 7, 20, 3])

In [None]:
# LeetCode 889: build one valid tree from its preorder and postorder sequences
def constructFromPrePost(self, preorder: List[int], postorder: List[int]) -> Optional["BinaryTreeNode"]:
    """Return one binary tree whose traversals match ``preorder`` and ``postorder``.

    ``self`` is unused (kept for the LeetCode method signature). The answer is
    not unique in general: when a node has a single child subtree, this always
    places it on the left. Returns ``None`` for empty input.
    """
    if not preorder:
        return None
    node = BinaryTreeNode(preorder[0])
    if len(preorder) == 1:
        return node

    # preorder[1] roots the left subtree and is that subtree's LAST postorder
    # element, so its postorder position gives the subtree's size.
    left_size = postorder.index(preorder[1]) + 1
    node.left_child = constructFromPrePost(
        self, preorder[1:1 + left_size], postorder[:left_size]
    )
    # The rest (excluding the root at postorder[-1]) is the right subtree.
    node.right_child = constructFromPrePost(
        self, preorder[1 + left_size:], postorder[left_size:-1]
    )
    return node


# Demo: preorder/postorder of the tree 3(9(10),20(15,7)); 10 is the only child
# of 9 and is rebuilt as its left child.
root889 = constructFromPrePost(None, [3, 9, 10, 20, 15, 7], [10, 9, 15, 7, 20, 3])

In [8]:
from typing import Generic, Optional, List, Protocol, TypeVar


class Comparable(Protocol):
    """Structural interface: any type supporting ``<``, ``>`` and ``==``."""

    def __lt__(self, other) -> bool: ...

    def __gt__(self, other) -> bool: ...

    def __eq__(self, other) -> bool: ...


# Element type stored in BSTNode / BSTree; must satisfy Comparable.
_ComparableType = TypeVar("_ComparableType", bound=Comparable)


class BSTNode(Generic[_ComparableType]):  # noqa
    """Binary-search-tree node. Nodes compare with each other by ``val``."""

    val: _ComparableType
    left: Optional["BSTNode[_ComparableType]"]
    right: Optional["BSTNode[_ComparableType]"]

    def __init__(
            self,
            val: _ComparableType,
            left: Optional["BSTNode[_ComparableType]"] = None,
            right: Optional["BSTNode[_ComparableType]"] = None
    ):
        self.val = val
        self.left = left
        self.right = right

    def __repr__(self):
        return f"{type(self).__name__}({self.val!r})"

    def insert(self, node: "BSTNode[_ComparableType]"):
        """Insert ``node`` into the subtree rooted here.

        Smaller values go left and larger values go right. A value equal to an
        existing one overwrites that node's ``val`` instead of adding a
        duplicate. The walk is iterative, so chain-shaped trees (e.g. built from
        sorted input) do not hit the recursion limit.
        """
        current = self
        while True:
            if node.val < current.val:
                if current.left is None:
                    current.left = node
                    return
                current = current.left
            elif node.val > current.val:
                if current.right is None:
                    current.right = node
                    return
                current = current.right
            else:
                current.val = node.val
                return

    # Comparisons are defined only between nodes; returning NotImplemented lets
    # Python fall back (e.g. `node == 1` is False) instead of raising AttributeError.
    def __lt__(self, other):
        if not isinstance(other, BSTNode):
            return NotImplemented
        return self.val < other.val

    def __gt__(self, other):
        if not isinstance(other, BSTNode):
            return NotImplemented
        return self.val > other.val

    def __eq__(self, other):
        if not isinstance(other, BSTNode):
            return NotImplemented
        return self.val == other.val


class BSTree(Generic[_ComparableType]):  # noqa
    """Binary search tree of distinct values; inserting an equal value replaces it."""

    root: Optional[BSTNode[_ComparableType]]

    def __init__(self, root: Optional[BSTNode[_ComparableType]]):
        self.root = root

    def __repr__(self):
        # Show the stored values in ascending order, e.g. BSTree([1, 2, 3]).
        return f"{type(self).__name__}({[n.val for n in self._inorder_nodes()]!r})"

    def _inorder_nodes(self):
        """Yield nodes in ascending (left, node, right) order, iteratively."""
        stack = []
        node = self.root
        while node is not None or stack:
            if node is not None:
                stack.append(node)
                node = node.left
            else:
                node = stack.pop()
                yield node
                node = node.right

    @classmethod
    def init(cls, arr: List[_ComparableType]):
        """Build a tree from ``arr`` without modifying it.

        The element at index ``len(arr) // 2`` becomes the root and the others
        are inserted in their original order; duplicates collapse into one node.
        An empty list gives an empty tree.
        """
        if not arr:
            return cls(None)
        values = list(arr)  # copy: the caller's list must not be mutated
        root = BSTNode(values.pop(len(values) // 2))
        for value in values:
            root.insert(BSTNode(value))
        return cls(root)

    def get(self, get_val: _ComparableType):
        """Alias for :meth:`search`."""
        return self.search(get_val)

    # LeetCode 1008
    @classmethod
    def from_preorder(cls, arr: List[_ComparableType]):
        """Rebuild a tree from its preorder sequence (``arr`` is not modified).

        The first value is the root. The values before the first one greater
        than the root form the left subtree and the rest form the right subtree.
        """
        def _build(seq):
            if not seq:
                return None
            root_val = seq[0]
            # Index of the first value greater than the root, len(seq) if none.
            split = next(
                (index for index, item in enumerate(seq) if item > root_val),
                len(seq),
            )
            return BSTNode(root_val, _build(seq[1:split]), _build(seq[split:]))

        return cls(_build(arr))

    # LeetCode 701
    def put(self, put_val: _ComparableType):
        """Insert ``put_val``, or overwrite the node holding an equal value."""
        parent, node = self.search_with_parent(put_val)
        if node is not None:
            # Equal value already present (possibly at the root): replace in
            # place, keeping the node's subtrees.
            node.val = put_val
        elif parent is None:
            # Not found and no parent: the tree is empty.
            self.root = BSTNode(put_val)
        elif parent.val < put_val:
            parent.right = BSTNode(put_val)
        else:
            parent.left = BSTNode(put_val)

    # LeetCode 450
    def delete(self, del_val: _ComparableType):
        """Remove the node holding ``del_val``; do nothing if it is absent."""
        parent, node = self.search_with_parent(del_val)
        if node is None:
            return

        # Case 1: at most one child -> splice that child (or None) into place.
        if node.left is None:
            self._shift(parent, node, node.right)
        elif node.right is None:
            self._shift(parent, node, node.left)
        # Case 2: two children -> unlink the in-order predecessor (the largest
        # node of the left subtree, which has no right child) and move its
        # value into `node`.
        else:
            sub_parent, biggest = self.find_biggest_with_parent(node.left)
            if sub_parent is None:
                # node.left itself is the predecessor.
                node.left = biggest.left
            else:
                sub_parent.right = biggest.left
            node.val = biggest.val

    def _shift(self, parent, deleted, orphan):
        """Put ``orphan`` where ``deleted`` hung under ``parent`` (the root if no parent)."""
        if parent is None:
            self.root = orphan
        elif parent.left is deleted:
            parent.left = orphan
        else:
            parent.right = orphan

    def predecessor(self, aim: _ComparableType):
        """Return the node holding the largest value smaller than ``aim``.

        Returns None if ``aim`` is not in the tree or is its smallest value.
        """
        node: Optional[BSTNode[_ComparableType]] = self.root
        # Last ancestor whose right branch we followed (it is smaller than aim).
        ancestor_from_left: Optional[BSTNode[_ComparableType]] = None
        while node is not None:
            if node.val < aim:
                ancestor_from_left = node
                node = node.right
            elif aim < node.val:
                node = node.left
            else:
                break
        # Case 1: aim not found.
        if node is None:
            return None
        # Case 2: found with a left subtree -> largest node of that subtree.
        if node.left is not None:
            return self.find_biggest(node.left)
        # Case 3: no left subtree -> the in-order predecessor is the nearest
        # ancestor from which we went right.
        return ancestor_from_left

    def successor(self, aim: _ComparableType):
        """Return the node holding the smallest value greater than ``aim``.

        Returns None if ``aim`` is not in the tree or is its largest value.
        """
        node: Optional[BSTNode[_ComparableType]] = self.root
        # Last ancestor whose left branch we followed (it is larger than aim).
        ancestor_from_right: Optional[BSTNode[_ComparableType]] = None
        while node is not None:
            if node.val < aim:
                node = node.right
            elif aim < node.val:
                ancestor_from_right = node
                node = node.left
            else:
                break
        # Case 1: aim not found.
        if node is None:
            return None
        # Case 2: found with a right subtree -> smallest node of that subtree.
        if node.right is not None:
            return self.find_smallest(node.right)
        # Case 3: no right subtree -> the in-order successor is the nearest
        # ancestor from which we went left.
        return ancestor_from_right

    # LeetCode 700
    def search(self, aim: _ComparableType):
        """Return the node whose value equals ``aim``, or None (iterative)."""
        return self.search_with_parent(aim)[1]

    def search_with_parent(self, aim: _ComparableType):
        """Return ``(parent, node)`` for ``aim``.

        ``node`` is None when ``aim`` is absent; ``parent`` is then the node
        under which ``aim`` would be attached (None for an empty tree).
        ``parent`` is also None when ``aim`` is at the root.
        """
        parent = None
        node = self.root
        while node is not None:
            if node.val < aim:
                parent = node
                node = node.right
            elif aim < node.val:
                parent = node
                node = node.left
            else:
                break

        return parent, node

    def find_biggest(self, start: Optional[BSTNode[_ComparableType]] = None):
        """Return the largest (rightmost) node under ``start``, defaulting to the root."""
        node = self.root if start is None else start
        if node is None:
            return None
        while node.right is not None:
            node = node.right
        return node

    def find_biggest_with_parent(self, start: Optional[BSTNode[_ComparableType]] = None):
        """Return ``(parent, node)`` for the largest node under ``start`` (default: root).

        ``parent`` is None when that node is ``start`` itself; both are None for
        an empty subtree.
        """
        node = self.root if start is None else start
        if node is None:
            return None, None
        parent = None
        while node.right is not None:
            parent, node = node, node.right
        return parent, node

    def find_smallest(self, start: Optional[BSTNode[_ComparableType]] = None):
        """Return the smallest (leftmost) node under ``start``, defaulting to the root."""
        node = self.root if start is None else start
        if node is None:
            return None
        while node.left is not None:
            node = node.left
        return node

    def find_smallest_with_parent(self, start: Optional[BSTNode[_ComparableType]] = None):
        """Return ``(parent, node)`` for the smallest node under ``start`` (default: root).

        ``parent`` is None when that node is ``start`` itself; both are None for
        an empty subtree.
        """
        node = self.root if start is None else start
        if node is None:
            return None, None
        parent = None
        while node.left is not None:
            parent, node = node, node.left
        return parent, node

    def get_lt(self, aim: _ComparableType):
        """Return all values ``< aim`` in ascending order (in-order walk, LNR)."""
        ret = []
        for node in self._inorder_nodes():
            if not node.val < aim:
                break  # ascending order: no later value can be smaller
            ret.append(node.val)
        return ret

    def get_gt(self, aim: _ComparableType):
        """Return all values ``> aim`` in descending order (reverse in-order walk, RNL)."""
        stack = []
        ret = []
        node = self.root
        while node is not None or stack:
            if node is not None:
                # Walk to the rightmost unvisited node first.
                stack.append(node)
                node = node.right
            else:
                popped = stack.pop()
                if not popped.val > aim:
                    break  # descending order: no later value can be larger
                ret.append(popped.val)
                node = popped.left
        return ret

    def get_between(self, gt: _ComparableType, lt: _ComparableType):
        """Return values ``v`` with ``gt < v < lt`` in ascending order."""
        ret = []
        for node in self._inorder_nodes():
            if not node.val < lt:
                break  # ascending order: nothing later can be < lt
            if gt < node.val:
                ret.append(node.val)
        return ret

    # LeetCode 98
    def is_valid(self):
        """Return True if every value lies strictly between the bounds set by
        ALL of its ancestors, not just its direct parent.

        Iterative DFS over (node, lower, upper); None means "unbounded", so no
        numeric sentinel is needed for non-numeric values.
        """
        stack = [(self.root, None, None)]
        while stack:
            node, lower, upper = stack.pop()
            if node is None:
                continue
            if lower is not None and not lower < node.val:
                return False
            if upper is not None and not node.val < upper:
                return False
            stack.append((node.left, lower, node.val))
            stack.append((node.right, node.val, upper))
        return True

    # LeetCode 235
    def common_ancestor(self, p, q):
        """Return the lowest common ancestor of nodes ``p`` and ``q``.

        Walks down from the root while both values lie on the same side; the
        first node where they split (or that equals one of them) is the answer.
        Returns None if the walk falls off the tree (e.g. an empty tree).
        """
        node = self.root
        while node is not None:
            if node.val < p.val and node.val < q.val:
                node = node.right
            elif node.val > p.val and node.val > q.val:
                node = node.left
            else:
                return node
        return None


### 3.2 二叉排序树

In [11]:
# Demo: build a BST from a list containing duplicates; inserting an equal value
# overwrites the existing node's value, so duplicates collapse into one node.
bst = BSTree.init([1, 5, 8, 2, 8, 2, 8, 9, 3, 3, 7, 5])

In [None]:
# search returns the matching node (or None). search_with_parent returns a
# (parent, node) pair whose node is None when the value is absent; 4 was never
# inserted above.
bst_search1 = bst.search(3)
bst_search2 = bst.search(2)
bst_search3 = bst.search_with_parent(4)
bst_search4 = bst.search_with_parent(5)

In [None]:
# Largest node in the whole tree and in the root's left subtree.
bst_biggest1 = bst.find_biggest()
bst_biggest2 = bst.find_biggest(bst.root.left)
# NOTE(review): these are named "smallest" but call find_biggest_with_parent,
# which returns a (parent, node) pair; confirm which lookup was intended.
bst_smallest1 = bst.find_biggest_with_parent()
bst_smallest2 = bst.find_biggest_with_parent(bst.root.left)

In [None]:
# NOTE(review): 0 was never inserted, so this exercises delete() on a missing
# value; the tree is then displayed as the cell output.
bst.delete(0)
bst

In [None]:
# Delete a value that is present (3), then display the tree.
bst.delete(3)
bst

In [None]:
# Values smaller than 5, in ascending order.
bst.get_lt(5)

In [None]:
# Values greater than 1, in descending order (reverse in-order walk).
bst.get_gt(1)

In [None]:
# Values strictly between 1 and 3.
bst.get_between(1, 3)

In [14]:
# Check the binary-search-tree ordering invariant (LeetCode 98).
bst.is_valid()

True

In [10]:
# Rebuild a BST from the preorder sequence [2, 4] (LeetCode 1008): root 2 with
# right child 4.
bst_preorder = BSTree.from_preorder([2, 4])

In [18]:
# Sanity check: negative infinity compares less than an int.
print(float('-inf') < 1)

True
