# Binary Search Tree (BST) implementation

Classes to implement:
- `Node` (a single BST node)
- `BST`

Methods to implement:
- search (`get`)
- insert
- minimum (`min`)
- floor
- ceiling (`ceil`)
- delete

In [123]:
class Node:
    """A single BST node: an int key/value pair plus left/right child links.

    ``count`` starts at 1; ``BST.insert`` increments it when the same key is
    inserted again and ``BST.delete`` decrements it.
    """

    def __init__(self, key: int, value: int):
        self.key: int = key
        self.value: int = value
        # Number of times this key is currently held; a fresh node holds one.
        self.count: int = 1
        # A newly created node is always a leaf, so both links start empty.
        self.right: "Optional[Node]" = None
        self.left: "Optional[Node]" = None

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(key={self.key!r}, "
            f"value={self.value!r}, count={self.count!r})"
        )

In [124]:
from collections import deque
from typing import Optional

class BST:
    """Binary search tree with ``int`` keys, each mapped to an ``int`` value.

    Inserting a key that is already present overwrites its value and
    increments the node's ``count``; ``delete`` decrements ``count`` and only
    unlinks the node once it reaches zero. The tree is never rebalanced.
    """

    def __init__(self):
        self.root: "Optional[Node]" = None

    def get(self, key: int) -> "Optional[Node]":
        """Return the node holding ``key``, or ``None`` if the key is absent."""
        return self._get(self.root, key)

    def _get(self, x: "Optional[Node]", key: int) -> "Optional[Node]":
        """Search the subtree rooted at ``x`` for ``key``."""
        while x is not None:
            if key < x.key:
                # Smaller keys are stored in the left subtree ...
                x = x.left
            elif key > x.key:
                # ... and larger keys in the right subtree.
                x = x.right
            else:
                return x
        return None

    def insert(self, key: int, value: int):
        """Insert ``key`` with ``value``.

        If ``key`` already exists, its value is overwritten and its ``count``
        is incremented instead of creating a second node.
        """
        self.root = self._insert(self.root, key, value)

    def _insert(self, x: "Optional[Node]", key: int, value: int) -> "Node":
        """Insert into the subtree rooted at ``x`` and return that subtree's root."""
        if x is None:
            return Node(key, value)
        if key < x.key:
            x.left = self._insert(x.left, key, value)
        elif key > x.key:
            x.right = self._insert(x.right, key, value)
        else:
            x.count += 1
            x.value = value
        return x

    def inorder(self):
        """Print all keys in ascending order on one line, then a newline.

        An empty tree prints only the newline.
        """
        self._inorder(self.root)
        print()

    def _inorder(self, a: "Optional[Node]"):
        """Print the keys of the subtree rooted at ``a`` in ascending order."""
        if a is None:
            return
        self._inorder(a.left)
        print(a.key, end=" ")
        self._inorder(a.right)

    def isEmpty(self) -> bool:
        """Return ``True`` if the tree holds no nodes."""
        return self.root is None

    def min(self) -> "Optional[Node]":
        """Return the node with the smallest key, or ``None`` if the tree is empty."""
        if self.root is None:
            return None
        return self._min(self.root)

    def _min(self, a: "Node") -> "Node":
        """Return the leftmost node of the non-empty subtree rooted at ``a``."""
        while a.left is not None:
            a = a.left
        return a

    def max(self) -> "Optional[Node]":
        """Return the node with the largest key, or ``None`` if the tree is empty."""
        if self.root is None:
            return None
        return self._max(self.root)

    def _max(self, a: "Node") -> "Node":
        """Return the rightmost node of the non-empty subtree rooted at ``a``."""
        while a.right is not None:
            a = a.right
        return a

    def floor(self, key: int) -> "Optional[Node]":
        """Return the node with the largest key <= ``key``, or ``None`` if there is none."""
        best: "Optional[Node]" = None
        x = self.root
        while x is not None:
            if x.key > key:
                x = x.left
            elif x.key < key:
                # x is a candidate; any better candidate is larger, so it can
                # only be in x's right subtree.
                best = x
                x = x.right
            else:
                return x
        return best

    def _floor2(self, x: "Optional[Node]", key: int) -> "Optional[Node]":
        """Recursive floor search in the subtree rooted at ``x``."""
        if x is None:
            return None
        if x.key == key:
            return x
        if key < x.key:
            return self._floor2(x.left, key)
        # x.key < key: x is the floor unless the right subtree has a closer key.
        t: "Optional[Node]" = self._floor2(x.right, key)
        if t is not None:
            return t
        return x

    def floor2(self, key: int) -> "Node":
        """Recursive variant of ``floor``.

        Raises:
            Exception: if no key in the tree is <= ``key``.
        """
        x: "Optional[Node]" = self._floor2(self.root, key)
        if x is None:
            raise Exception("No floor found")
        return x

    def ceil(self, key: int) -> "Optional[Node]":
        """Return the node with the smallest key >= ``key``, or ``None`` if there is none."""
        return self._ceil(self.root, key)

    def _ceil(self, x: "Optional[Node]", key: int) -> "Optional[Node]":
        """Recursive ceiling search in the subtree rooted at ``x``."""
        if x is None:
            return None
        if x.key == key:
            return x
        if x.key < key:
            return self._ceil(x.right, key)
        # x.key > key: x is the ceiling unless the left subtree has a closer key.
        t: "Optional[Node]" = self._ceil(x.left, key)
        if t is not None:
            return t
        return x

    def deleteMin(self):
        """Unlink the node with the smallest key, regardless of its ``count``.

        Raises:
            Exception: if the tree is empty.
        """
        if self.isEmpty():
            raise Exception("Tree is empty")
        self.root = self._deleteMin(self.root)

    def _deleteMin(self, x: "Node") -> "Optional[Node]":
        """Remove the leftmost node of the subtree at ``x``; return the new subtree root."""
        if x.left is None:
            return x.right
        x.left = self._deleteMin(x.left)
        return x

    def delete(self, key: int):
        """Remove one occurrence of ``key``; absent keys are ignored."""
        self.root = self._delete(self.root, key)

    def _delete(self, x: "Optional[Node]", key: int) -> "Optional[Node]":
        """Delete one occurrence of ``key`` below ``x``; return the new subtree root."""
        if x is None:
            return None
        if key < x.key:
            x.left = self._delete(x.left, key)
        elif key > x.key:
            x.right = self._delete(x.right, key)
        else:
            # Found the key: drop one occurrence and keep the node if any remain.
            x.count -= 1
            if x.count > 0:
                return x
            if x.right is None:
                return x.left
            if x.left is None:
                return x.right
            # Two children: replace x with its in-order successor, which is the
            # minimum of the right subtree (Hibbard deletion).
            t: "Node" = x
            x = self._min(t.right)
            x.right = self._deleteMin(t.right)
            x.left = t.left
        return x

    def print(self):
        """Print the tree breadth-first, one line per depth, each node as ``(key, count)``.

        No newline is emitted after the last level.
        """
        print("The tree:")
        queue = deque()
        curLevel = 0
        queue.append((self.root, curLevel))
        while len(queue) > 0:
            x, level = queue.popleft()
            if x is not None:
                if level != curLevel:
                    curLevel = level
                    print()
                print(f"({x.key}, {x.count})", end=" ")
                queue.append((x.left, level + 1))
                queue.append((x.right, level + 1))

In [125]:
if __name__ == "__main__":
    bst = BST()

    # Build the sample tree (key == value); 50 appears twice, so its node
    # ends with count 2.
    for k in (41, 20, 65, 91, 99, 72, 50, 11, 29, 32, 35, 68, 50):
        bst.insert(k, k)

    # 29 has a single child when removed; 65 and 20 each have two.
    for k in (29, 65, 20):
        bst.delete(k)

    bst.insert(75, 75)
    # 68 has two children at this point, exercising successor replacement again.
    bst.delete(68)

    bst.print()
    print()


The tree:
(41, 1) 
(32, 1) (72, 1) 
(11, 1) (35, 1) (50, 2) (91, 1) 
(75, 1) (99, 1) 
