# README

Tree 구현을 위한 노트북

아래는 트리를 구현하기 전에, 트리의 구조를 중첩 딕셔너리로 간단히 시각화한 것이다.


In [1]:

# Hand-built sketch of a binary search tree as nested dicts.
# Every node has the keys 'value', 'left' and 'right'; a missing child is
# None. Smaller values sit in the left subtree, larger ones in the right.
#
#         5
#       /   \
#      3     8
#     / \   /
#    1   4 6
tree = {
    'root': {
        'value': 5,
        'left': {
            'value': 3,
            'left': {
                'value': 1,
                'left': None,
                'right': None,
            },
            'right': {
                'value': 4,
                'left': None,
                'right': None,
            },
        },
        'right': {
            'value': 8,
            'left': {
                'value': 6,
                'left': None,
                'right': None,
            },
            # node 8 has no right child; spell it out so every node has the
            # same keys and looking up its 'right' never raises KeyError
            'right': None,
        },
    },
}

# root -> left (3) -> right (4)
print(tree['root']['left']['right']['value'])


4


이번에는 클래스를 사용하여 이진 탐색 트리(BST)를 구현해 보자.


In [2]:
from collections import deque
from dataclasses import dataclass
from enum import Enum, auto
from typing import Generic, Self, TypeVar


class Cmp(Enum):
    """Outcome of a three-way comparison, as produced by ``compare``."""

    Less = 1
    Greater = 2
    Equal = 3


def compare(lhs, rhs) -> Cmp:
    """Three-way compare *lhs* with *rhs* using only the ``<`` operator.

    Returns ``Cmp.Less`` if ``lhs < rhs``, otherwise ``Cmp.Greater`` if
    ``rhs < lhs``, and ``Cmp.Equal`` when neither holds. Because ``==`` is
    never used, types that only define ordering via ``<`` are supported.
    """
    return (
        Cmp.Less if lhs < rhs
        else Cmp.Greater if rhs < lhs
        else Cmp.Equal
    )


T = TypeVar('T')


@dataclass
class Node:
    data: T
    left: Self | None = None
    right: Self | None = None


class Tree:
    """binary search tree
    """

    def __init__(self, data):
        init = Node(data)
        self.root = init
        self.count = 1

    def __len__(self):
        return self.count

    def insert(self, data):
        new_node = Node(data)
        current_node = self.root

        while current_node:
            match (compare(data, current_node.data),
                   current_node.left,
                   current_node.right):
                case (Cmp.Equal, _, _):
                    # 같은 값이면 추가하지 않음.
                    return

                case (Cmp.Less, None, _):
                    # append left
                    current_node.left = new_node
                    self.count += 1
                    return

                case (Cmp.Greater, _, None):
                    # append right
                    current_node.right = new_node
                    self.count += 1
                    return

                case (Cmp.Less, _, _):
                    # move left
                    current_node = current_node.left

                case (Cmp.Greater, _, _):
                    # move right
                    current_node = current_node.right

                case _:
                    raise RuntimeError("cannot happen")

            # end match

    def dfs(self):
        """depth first search
        """
        ret = []
        stack = [self.root]

        while len(stack) > 0:
            cur = stack.pop()
            if cur.right is not None:
                stack.append(cur.right)
            if cur.left is not None:
                stack.append(cur.left)
            ret.append(cur.data)

        return ret

    def bfs(self):
        """breadth first search
        """
        ret = []
        queue = deque([self.root])
        while len(queue) > 0:
            cur = queue.popleft()
            if cur.left:
                queue.append(cur.left)
            if cur.right:
                queue.append(cur.right)
            ret.append(cur.data)

        return ret

    def contains(self, data) -> bool:
        """data를 가지고 있는지 여부를 판별한다.
        """
        cur = self.root
        while cur:
            match compare(data, cur.data):
                case Cmp.Less:
                    cur = cur.left

                case Cmp.Greater:
                    cur = cur.right

                case Cmp.Equal:
                    return True

        return False

    def depth(self) -> int:
        """return current tree's max depth. root is depth 0"""

        # (depth, node)
        stack = [(0, self.root)]
        max_depth = 0

        while len(stack) > 0:

            depth, cur = stack.pop()
            max_depth = max(max_depth, depth)

            if cur.right is not None:
                stack.append((depth + 1, cur.right))
            if cur.left is not None:
                stack.append((depth + 1, cur.left))

        return max_depth


# --- integer values ------------------------------------------------------
ls = [5, 3, 8, 1, 4, 6, 9]
t = Tree(ls[0])
for e in ls[1:]:
    t.insert(e)

print(t.dfs(), t.bfs(), sep='\n')  # pre-order, then level order

assert all(t.contains(e) for e in ls)
assert not t.contains(100)

# --- any <-comparable type works, e.g. single characters -----------------
# '\0' sorts before every other character, so it serves as the root.
my_string = "hello, world!"
t = Tree('\0')
for e in my_string:
    t.insert(e)

print(t.dfs(), t.bfs(), sep='\n')

assert all(t.contains(e) for e in my_string)
assert not t.contains('k')

# --- a skewed tree: 1..9 form a right-leaning chain, so depth is 8 -------
ls = [*range(1, 10), 0]
t = Tree(ls[0])
for e in ls[1:]:
    t.insert(e)
assert t.depth() == 8


[5, 3, 1, 4, 8, 6, 9]
[5, 3, 8, 1, 4, 6, 9]
['\x00', 'h', 'e', ',', ' ', '!', 'd', 'l', 'o', 'w', 'r']
['\x00', 'h', 'e', 'l', ',', 'o', ' ', 'd', 'w', '!', 'r']
