# Binary Search Tree

In [42]:
class Node:
    """A single binary-search-tree node: a value plus links to two children."""

    def __init__(self, item):
        self.val = item
        self.left = None   # subtree holding values <= val
        self.right = None  # subtree holding values > val


class BinaryTree:
    """Binary search tree whose traversals also record the visited values.

    Values less than or equal to a node's value go into its left subtree and
    larger values into its right subtree, so duplicates are allowed.
    An empty tree is a head node whose ``val`` is ``None``; consequently
    ``None`` itself cannot be stored as a value.
    """

    def __init__(self):
        self.head = Node(None)

        # Test-purpose lists: each *_traverse() call refills the matching list.
        self.preorder_list = []
        self.inorder_list = []
        self.postorder_list = []

    def __is_empty(self):
        """Return True when the tree holds no values."""
        return self.head is None or self.head.val is None

    # ---------------------------------------------------------------- Add
    def add(self, item):
        """Insert *item* into the tree."""
        if self.__is_empty():
            self.head = Node(item)
        else:
            self.__add_node(self.head, item)

    def __add_node(self, cur, item):
        """Walk down from *cur* and attach *item* as a new leaf."""
        # Iterative so degenerate (list-shaped) trees cannot exceed the
        # interpreter's recursion limit.
        while True:
            if cur.val >= item:
                if cur.left is None:
                    cur.left = Node(item)
                    return
                cur = cur.left
            else:
                if cur.right is None:
                    cur.right = Node(item)
                    return
                cur = cur.right

    # ------------------------------------------------------------- Search
    def search(self, item):
        """Return True if *item* is stored in the tree, otherwise False."""
        if self.__is_empty():
            return False
        return self.__search_node(self.head, item)

    def __search_node(self, cur, item):
        """Return True if *item* occurs in the subtree rooted at *cur*."""
        while cur is not None:
            if cur.val == item:
                return True
            # Equality was ruled out above, so smaller items live on the left.
            cur = cur.left if cur.val > item else cur.right
        return False

    # ------------------------------------------------------------- Remove
    def remove(self, item):
        """Remove one occurrence of *item*; print a message if it is absent."""
        if self.__is_empty():
            print("There is no item: in Binary Search Tree", item)
            return
        if self.head.val == item:
            replacement = self.__detach(self.head)
            # Removing the last node must leave an empty head node rather than
            # None, so that add()/search()/traversals keep working afterwards.
            self.head = replacement if replacement is not None else Node(None)
        elif self.head.val > item:
            self.__remove(self.head, self.head.left, item)
        else:
            self.__remove(self.head, self.head.right, item)

    def __remove(self, parent, cur, item):
        """Locate *item* starting at *cur* (a child of *parent*) and unlink it."""
        while cur is not None and cur.val != item:
            parent = cur
            cur = cur.left if cur.val > item else cur.right
        if cur is None:
            print("There is no item: ", item)
            return
        replacement = self.__detach(cur)
        if parent.left is cur:
            parent.left = replacement
        else:
            parent.right = replacement

    def __detach(self, node):
        """Drop *node*'s value and return the subtree that should replace it.

        - no children: returns None
        - one child: returns that child
        - two children: copies the in-order successor's value (the leftmost
          node of the right subtree) into *node*, unlinks the successor and
          returns *node* itself
        """
        if node.left is None:
            return node.right
        if node.right is None:
            return node.left
        succ_parent, succ = node, node.right
        while succ.left is not None:
            succ_parent, succ = succ, succ.left
        node.val = succ.val
        # The successor has no left child: splice its right subtree into its
        # place. Unlinking by identity (not by value) keeps duplicates intact.
        if succ_parent is node:
            succ_parent.right = succ.right
        else:
            succ_parent.left = succ.right
        return node

    # --------------------------------------------------------- Traversals
    # Each public traversal prints the visited values and rebuilds its list
    # from scratch. The helpers recurse, so traversal depth is bounded by
    # Python's recursion limit.
    def preorder_traverse(self):
        """Visit node, then left subtree, then right subtree."""
        self.preorder_list = []
        if not self.__is_empty():
            self.__preorder(self.head)

    def __preorder(self, cur):
        self.preorder_list.append(cur.val)
        print(cur.val)
        if cur.left is not None:
            self.__preorder(cur.left)
        if cur.right is not None:
            self.__preorder(cur.right)

    def inorder_traverse(self):
        """Visit left subtree, then node, then right subtree (sorted order)."""
        self.inorder_list = []
        if not self.__is_empty():
            self.__inorder(self.head)

    def __inorder(self, cur):
        if cur.left is not None:
            self.__inorder(cur.left)

        self.inorder_list.append(cur.val)
        print(cur.val)

        if cur.right is not None:
            self.__inorder(cur.right)

    def postorder_traverse(self):
        """Visit left subtree, then right subtree, then node."""
        self.postorder_list = []
        if not self.__is_empty():
            self.__postorder(self.head)

    def __postorder(self, cur):
        if cur.left is not None:
            self.__postorder(cur.left)

        if cur.right is not None:
            self.__postorder(cur.right)

        self.postorder_list.append(cur.val)
        print(cur.val)

In [31]:
# Sample tree for the traversal cells; insertion order determines its shape.
bt = BinaryTree()
for _item in (5, 3, 4, 1, 7, 6, 2):
    bt.add(_item)
del _item

- 그려보면 이해된다!! 

In [32]:
# Pre-order traversal of `bt` (node, left subtree, right subtree).
# The bare expression on the last line makes the notebook display the
# recorded list after the printed values.
print("pre order")
bt.preorder_traverse()
bt.preorder_list

pre order
5
3
1
2
4
7
6


[5, 3, 1, 2, 4, 7, 6]

In [33]:
# In-order traversal of `bt` (left subtree, node, right subtree), which
# lists the values in ascending order; the last line displays the list.
print("in order")
bt.inorder_traverse()
bt.inorder_list

in order
1
2
3
4
5
6
7


[1, 2, 3, 4, 5, 6, 7]

In [34]:
# Post-order traversal of `bt` (left subtree, right subtree, node);
# the last line displays the recorded list.
print("post order")
bt.postorder_traverse()
bt.postorder_list

post order
2
1
4
3
6
7
5


[2, 1, 4, 3, 6, 7, 5]

In [43]:
# NOTE(review): despite the label, this cell does not remove the root.
# With this insertion order, 8 is the left child of 10 and has a single
# right child (9), so remove(8) exercises the "one right child" case:
# 9 takes 8's place under 10.
print ("bt root remove")
bt_root_remove_test = BinaryTree()
bt_root_remove_test.add(5)
bt_root_remove_test.add(3)
bt_root_remove_test.add(4)
bt_root_remove_test.add(1)
bt_root_remove_test.add(7)
bt_root_remove_test.add(6)
bt_root_remove_test.add(2)
bt_root_remove_test.add(10)
bt_root_remove_test.add(8)
bt_root_remove_test.add(9)
bt_root_remove_test.add(0)
bt_root_remove_test.remove(8)
# Show the pre-order result; the last line displays the recorded list.
bt_root_remove_test.preorder_traverse()
bt_root_remove_test.preorder_list

bt root remove
5
3
1
0
2
4
7
6
10
9


[5, 3, 1, 0, 2, 4, 7, 6, 10, 9]

#### 이런식으로 test 할 수도 있다!! 

In [6]:
import unittest

class binary_tree_test(unittest.TestCase):
    """End-to-end checks of BinaryTree traversals, remove() and search().

    Each remove scenario builds a fresh tree, removes one value and compares
    the resulting pre-order list.
    """
    def test(self):
        # Traversal orders on a small 5-node tree.
        bt = BinaryTree()
        bt.add(5)
        bt.add(3)
        bt.add(4)
        bt.add(1)
        bt.add(7)
        print("pre order")
        bt.preorder_traverse()
        self.assertEqual(bt.preorder_list, [5,3,1,4,7])

        print("in order")
        bt.inorder_traverse()
        self.assertEqual(bt.inorder_list, [1,3,4,5,7])

        print("post order")
        bt.postorder_traverse()
        self.assertEqual(bt.postorder_list, [1,4,3,7,5])

        # Root with two children: its in-order successor (70) becomes the root.
        print ("bt root remove")
        bt_root_remove_test = BinaryTree()
        bt_root_remove_test.add(60)
        bt_root_remove_test.add(50)
        bt_root_remove_test.add(70)
        bt_root_remove_test.remove(60)
        bt_root_remove_test.preorder_traverse()
        self.assertEqual(bt_root_remove_test.preorder_list, [70,50])

        # Root with only a left child: the child becomes the root.
        print ("bt root remove2")
        bt_root_remove_test = BinaryTree()
        bt_root_remove_test.add(60)
        bt_root_remove_test.add(50)
        bt_root_remove_test.remove(60)
        bt_root_remove_test.preorder_traverse()
        self.assertEqual(bt_root_remove_test.preorder_list, [50])

        # Root with only a right child: the child becomes the root.
        print ("bt root remove3")
        bt_root_remove_test = BinaryTree()
        bt_root_remove_test.add(60)
        bt_root_remove_test.add(70)
        bt_root_remove_test.remove(60)
        bt_root_remove_test.preorder_traverse()
        self.assertEqual(bt_root_remove_test.preorder_list, [70])

        # Leaf in the root's left subtree.
        print ("bt left remove 1")
        bt_left_remove_test_1 = BinaryTree()
        bt_left_remove_test_1.add(60)
        bt_left_remove_test_1.add(50)
        bt_left_remove_test_1.add(70)
        bt_left_remove_test_1.remove(50)
        bt_left_remove_test_1.preorder_traverse()
        self.assertEqual(bt_left_remove_test_1.preorder_list, [60,70])

        # Left-subtree node 50 with only a left child (40).
        print ("bt left remove 2")
        bt_left_remove_test_2_left = BinaryTree()
        bt_left_remove_test_2_left.add(60)
        bt_left_remove_test_2_left.add(50)
        bt_left_remove_test_2_left.add(70)
        bt_left_remove_test_2_left.add(40)
        bt_left_remove_test_2_left.remove(50)
        bt_left_remove_test_2_left.preorder_traverse()
        self.assertEqual(bt_left_remove_test_2_left.preorder_list, [60,40,70])

        # NOTE(review): the printed label says "right remove 2", but this case
        # removes 50 (left child of the root) that has only a right child (55).
        print ("bt right remove 2")
        bt_left_remove_test_2_right = BinaryTree()
        bt_left_remove_test_2_right.add(60)
        bt_left_remove_test_2_right.add(50)
        bt_left_remove_test_2_right.add(70)
        bt_left_remove_test_2_right.add(55)
        bt_left_remove_test_2_right.remove(50)
        bt_left_remove_test_2_right.preorder_traverse()
        self.assertEqual(bt_left_remove_test_2_right.preorder_list, [60,55,70])

        # Leaf in the root's right subtree.
        print ("bt right remove 1")
        bt_right_remove_test_1 = BinaryTree()
        bt_right_remove_test_1.add(60)
        bt_right_remove_test_1.add(50)
        bt_right_remove_test_1.add(70)
        bt_right_remove_test_1.remove(70)
        bt_right_remove_test_1.preorder_traverse()
        self.assertEqual(bt_right_remove_test_1.preorder_list, [60,50])

        # Right-subtree node 70 with only a left child (65).
        print ("bt right remove 2")
        bt_right_remove_test_2_left = BinaryTree()
        bt_right_remove_test_2_left.add(60)
        bt_right_remove_test_2_left.add(50)
        bt_right_remove_test_2_left.add(70)
        bt_right_remove_test_2_left.add(65)
        bt_right_remove_test_2_left.remove(70)
        bt_right_remove_test_2_left.preorder_traverse()
        self.assertEqual(bt_right_remove_test_2_left.preorder_list, [60,50,65])

        # NOTE(review): same printed label as the previous case; this one removes
        # 70 whose only child (75) is on the right.
        print ("bt right remove 2")
        bt_right_remove_test_2_right = BinaryTree()
        bt_right_remove_test_2_right.add(60)
        bt_right_remove_test_2_right.add(50)
        bt_right_remove_test_2_right.add(70)
        bt_right_remove_test_2_right.add(75)
        bt_right_remove_test_2_right.remove(70)
        bt_right_remove_test_2_right.preorder_traverse()
        self.assertEqual(bt_right_remove_test_2_right.preorder_list, [60,50,75])

        # Node 50 with two children: its in-order successor 52 (leftmost node
        # of 50's right subtree, under 55) takes its place.
        print ("bt left remove 3")
        bt_left_remove_test_3 = BinaryTree()
        bt_left_remove_test_3.add(60)
        bt_left_remove_test_3.add(50)
        bt_left_remove_test_3.add(70)
        bt_left_remove_test_3.add(40)
        bt_left_remove_test_3.add(55)
        bt_left_remove_test_3.add(52)
        bt_left_remove_test_3.remove(50)
        bt_left_remove_test_3.preorder_traverse()
        self.assertEqual(bt_left_remove_test_3.preorder_list, [60,52,40,55,70])

        # search(): every inserted value is found; absent values are not.
        print("BST search test")
        bt_search_test = BinaryTree()
        bt_search_test.add(60)
        bt_search_test.add(50)
        bt_search_test.add(70)
        bt_search_test.add(40)
        bt_search_test.add(55)
        bt_search_test.add(52)
        self.assertTrue(bt_search_test.search(60))
        self.assertTrue(bt_search_test.search(50))
        self.assertTrue(bt_search_test.search(70))
        self.assertTrue(bt_search_test.search(40))
        self.assertTrue(bt_search_test.search(55))
        self.assertTrue(bt_search_test.search(52))
        self.assertFalse(bt_search_test.search(61))
        self.assertFalse(bt_search_test.search(81))