# Trees

## Binary Trees
- Each node has at most 2 children: left and right
- Perfect binary tree: every internal node has exactly 2 children and all leaves are at the same depth
- Full binary tree: every node has either 0 or 2 children

## Binary Search Trees
- Every value in a node's left subtree is less than the node's value
- Every value in a node's right subtree is greater than the node's value
- No duplicate values
- Left and right subtrees are also binary search trees

### Big O
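- Insert: O(log n) average, O(n) worst case
- Search: O(log n) average, O(n) worst case
- Delete: O(log n) average, O(n) worst case

The O(log n) bounds assume the tree stays roughly balanced (height ≈ log n). Without rebalancing, inserting values in sorted order produces a linked-list-shaped tree of height n.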

### Traversal

- In-order traversal: left, root, right (visits a BST's values in sorted order)
- Pre-order traversal: root, left, right
- Post-order traversal: left, right, root

### Removal
1. Find the node to remove
2. Handle based on the number of children:
    1. No children: Simply remove the node by updating its parent's reference
    2. One child: Replace the node with its child
    3. Two children:
        - Find the inorder successor (minimum node in right subtree)
        - Replace the node's value with the successor's value
        - Remove the successor node (which will have at most one child)

In [48]:
from typing import Any, Optional
class BinaryTreeNode:
    """A node of a binary tree with a back-reference to its parent.

    Comparison operators (``==``, ``<``, ``>``) compare node *values*, and a
    missing node (``None``) orders below every real node. Because ``__eq__``
    is value-based (and no ``__hash__`` is defined), nodes are unhashable.

    The structural helpers (``remove_child``, ``replace``, ``succeed``) locate
    children by identity, so two distinct nodes holding equal values are never
    confused with each other.
    """

    def __init__(self, value: Any):
        self.value = value
        self.parent: Optional['BinaryTreeNode'] = None
        self.left: Optional['BinaryTreeNode'] = None
        self.right: Optional['BinaryTreeNode'] = None

    def __eq__(self, other: object) -> bool:
        """Return True when *other* is a node holding an equal value."""
        if other is None:
            return False
        if not isinstance(other, BinaryTreeNode):
            # Defer to Python's fallback instead of failing on a missing ``.value``.
            return NotImplemented
        return self.value == other.value

    def __gt__(self, other: Optional['BinaryTreeNode']) -> bool:
        """Compare by value; every node is greater than ``None``."""
        if other is None:
            return True
        if not isinstance(other, BinaryTreeNode):
            return NotImplemented
        return self.value > other.value

    def __lt__(self, other: Optional['BinaryTreeNode']) -> bool:
        """Compare by value; no node is less than ``None``."""
        if other is None:
            return False
        if not isinstance(other, BinaryTreeNode):
            return NotImplemented
        return self.value < other.value

    def remove_child(self, child: 'BinaryTreeNode') -> None:
        """Detach *child* (matched by identity) from this node's left/right slot.

        ``child.parent`` is not modified.
        """
        if self.left is child:
            self.left = None
        if self.right is child:
            self.right = None

    def has_children(self) -> bool:
        """Return True if this node has at least one child."""
        return self.left is not None or self.right is not None

    def num_children(self) -> int:
        """Return how many of the two child slots are occupied (0, 1 or 2)."""
        return (self.left is not None) + (self.right is not None)

    def delink(self) -> None:
        """Clear all relationships (parent, left and right)."""
        self.parent = None
        self.right = None
        self.left = None

    def replace(self, target: 'BinaryTreeNode', value: Optional['BinaryTreeNode']) -> None:
        """Put *value* into whichever child slot currently holds *target*.

        *target* is matched by identity. Only this node's child pointer is
        changed; ``value.parent`` is left for the caller to set.

        Raises:
            ValueError: if *target* is neither the left nor the right child.
        """
        if self.left is target:
            self.left = value
        elif self.right is target:
            self.right = value
        else:
            raise ValueError(f"{target} doesn't exist on parent {self}")

    def succeed(self, successor: 'BinaryTreeNode') -> None:
        """Install *successor* in this node's position in the tree.

        The successor takes over this node's parent and left child. It takes
        over the right child too, unless the successor *is* the right child, in
        which case it keeps its own right subtree. The adopted children's
        parent pointers and the parent's child pointer are redirected to the
        successor. This node's own links are NOT cleared; call ``delink``
        afterwards.
        """
        successor.parent = self.parent
        if self.right is not successor:
            successor.right = self.right
        successor.left = self.left

        # Update the parent pointers of the adopted children.
        for child in (successor.right, successor.left):
            if child is not None:
                child.parent = successor

        # Update the parent's reference from self to the successor.
        if self.parent is not None:
            self.parent.replace(self, successor)

    def __str__(self):
        return str(self.value)

    def __repr__(self):
        return f"{type(self).__name__}({self.value!r})"

class BinarySearchTree:
    """An unbalanced binary search tree of unique, mutually comparable values.

    Smaller values go to the left subtree and larger values to the right;
    inserting a value that is already present does nothing. Nodes keep parent
    back-references, which ``remove`` maintains. No rebalancing is done, so
    each operation costs O(height): O(log n) for a balanced tree, O(n) worst
    case.
    """

    def __init__(self, root: Optional[BinaryTreeNode] = None):
        self.root: Optional[BinaryTreeNode] = root

    def insert(self, value: Any) -> None:
        """Insert a unique value as a new leaf.

        Values less than a node's value go into its left subtree, everything
        else into its right subtree. Duplicates are ignored.
        """
        if self.root is None:
            self.root = BinaryTreeNode(value)
            return

        # Walk down to the attachment point, bailing out on a duplicate.
        parent = self.root
        while True:
            if parent.value == value:
                return
            go_left = value < parent.value
            child = parent.left if go_left else parent.right
            if child is None:
                break
            parent = child

        node = BinaryTreeNode(value)
        node.parent = parent
        if go_left:
            parent.left = node
        else:
            parent.right = node

    def _transplant(self, node: BinaryTreeNode, replacement: Optional[BinaryTreeNode]) -> None:
        """Make *replacement* occupy *node*'s slot under node's parent, or as the root."""
        parent = node.parent
        if parent is None:
            self.root = replacement
        elif parent.left is node:
            parent.left = replacement
        else:
            parent.right = replacement
        if replacement is not None:
            replacement.parent = parent

    def remove(self, value: Any) -> BinaryTreeNode:
        """Remove the node holding *value* and return it with all links cleared.

        Alg:
            1. Find the node to remove (raise if it is absent).
            2. Handle based on the number of children:
                1. No children: detach it from its parent (or empty the tree).
                2. One child: splice that child into the node's position.
                3. Two children:
                    - Find the inorder successor (minimum node in the right
                      subtree); it never has a left child.
                    - Splice the successor's right child into the successor's
                      old position.
                    - Move the successor *node* into the removed node's
                      position, adopting its left and right subtrees.
            Every case also works when the node being removed is the root.

        Raises:
            ValueError: if *value* is not in the tree (including an empty tree).
        """
        node = self.search(value)
        if node is None:
            raise ValueError(f"{value} not found in tree")

        if node.left is None or node.right is None:
            # Cases 1 and 2: lift the only child (or None) into the node's slot.
            child = node.left if node.left is not None else node.right
            self._transplant(node, child)
        else:
            # Case 3: two children.
            successor = node.right
            while successor.left is not None:
                successor = successor.left

            # The successor has no left child, so its right subtree can take its
            # place directly (this also covers successor being node.right).
            self._transplant(successor, successor.right)

            successor.left = node.left
            successor.right = node.right
            successor.left.parent = successor
            if successor.right is not None:
                successor.right.parent = successor
            self._transplant(node, successor)

        node.delink()
        return node

    def search(self, value: Any) -> Optional[BinaryTreeNode]:
        """Return the node holding *value*, or None if it is not in the tree."""
        cur_node = self.root
        while cur_node is not None:
            if cur_node.value == value:
                return cur_node
            cur_node = cur_node.left if value < cur_node.value else cur_node.right
        return None

    def __repr__(self):
        """Return a rough tab-indented drawing of the tree ('' when empty)."""

        def draw(node) -> list[str]:
            # Each subtree renders as a list of lines: value, branch marks, then
            # the left and right subtrees' lines merged side by side.
            if node is None:
                return []

            left_edge = "/" if node.left is not None else ""
            right_edge = "\\" if node.right is not None else ""
            lines = [f"{node.value}", f"{left_edge}\t {right_edge}"]

            left_lines = draw(node.left)
            right_lines = draw(node.right)
            for i in range(max(len(left_lines), len(right_lines))):
                left = left_lines[i] if i < len(left_lines) else ""
                right = right_lines[i] if i < len(right_lines) else ""
                lines.append(f"{left} \t {right}")
            return lines

        lines = draw(self.root)
        # Indent earlier (shallower) lines more so the drawing is roughly centered.
        return "\n".join(
            "\t" * ((len(lines) - i) // 2) + line for i, line in enumerate(lines)
        )

In [49]:
# Tree built by the inserts below
#        9
#      /   \
#     4     15
#    / \   /  \
#   1   6 11  20
#               \
#                21

tree = BinarySearchTree()
tree.insert(9)
tree.insert(4)
tree.insert(6)
tree.insert(15)
tree.insert(20)
tree.insert(11)
tree.insert(1)
tree.insert(21)

print(tree)

# Test removal of a leaf node (1)
tree.remove(1)
print("\nAfter removing 1:")
print(tree)
assert tree.root.left.left is None

# Test removal of a node with two children (15 -> 11 and 20); its inorder
# successor 20 is its direct right child, which keeps its own right child 21
tree.remove(15)
print("\nAfter removing 15:")
print(tree)
assert tree.root.right.value == 20
assert tree.root.right.left.value == 11

# Test removal of a node with one child (4 -> 6, since 1 was removed above)
tree.remove(4)
print("\nAfter removing 4:")
print(tree)
assert tree.root.left.value == 6

# Test removing the root (9), which has two children; its inorder successor
# is 11, the leftmost node of the right subtree
tree.remove(9)
print("\nAfter removing root 9:")
print(tree)
assert tree.root.value == 11
assert tree.root.left.value == 6
assert tree.root.right.value == 20

# Test removing non-existent node
try:
    tree.remove(100)
    assert False, "Should have raised ValueError"
except ValueError:
    print("\nCorrectly raised ValueError when removing non-existent node")

# Test empty tree
empty_tree = BinarySearchTree()
try:
    empty_tree.remove(1)
    assert False, "Should have raised ValueError"
except ValueError:
    print("Correctly raised ValueError when removing from empty tree")


				9
			/	 \
			4 	 15
		/	 \ 	 /	 \
		1 	 6 	 11 	 20
		  	 	  	 	  	 	 \
	 	  	  	 21
 	  	  	 	 

After removing 1:
				9
			/	 \
			4 	 15
			 \ 	 /	 \
		 	 6 	 11 	 20
	 	 	  	 	  	 	 \
	 	  	  	 21
 	  	  	 	 

After removing 15:
			9
		/	 \
		4 	 20
		 \ 	 /	 \
	 	 6 	 11 	 21
 	 	  	 	  	 	 

After removing 4:
			9
		/	 \
		6 	 20
		  	 /	 \
	 	 11 	 21
 	 	  	 	 

After removing root 9:
			11
		/	 \
		6 	 20
		  	 	 \
	 	  	 21
 	  	 	 

Correctly raised ValueError when removing non-existent node
Correctly raised ValueError when removing from empty tree
