In [1]:
from collections import deque
import pprint

In [2]:
class Node:
    ''' A binary-tree node holding ``value`` and optional left/right children. '''

    def __init__(self, value):
        self.value = value
        # Children start empty; the tree classes attach Node instances here.
        self.right_child = None
        self.left_child = None

    def __repr__(self):
        # Wrap the value in a 1-tuple so a tuple value is formatted as-is
        # instead of being unpacked into multiple %-format arguments.
        return '<Node: %s>' % (self.value,)

In [3]:
class BinaryTree:
    ''' Complete binary tree: each ``add`` fills the next free slot in level order. '''

    def __init__(self, values=None):
        ''' Create a tree, optionally adding each item of ``values`` in order. '''
        self.head = None

        # Compare with None rather than testing truthiness so array-like
        # inputs (whose truth value may be ambiguous) are accepted too.
        if values is not None:
            for value in values:
                self.add(value)

    def add(self, value):
        ''' Add item and maintain complete tree form.

        Breadth-first search from the root for the first node with a free
        child slot, filling its left slot before its right one. This costs
        O(n) per insertion because the search restarts from the root; keeping
        the frontier queue on the instance would avoid that.
        '''
        new_node = Node(value)
        if self.head is None:
            self.head = new_node
            return

        queue = deque([self.head])
        while queue:
            node = queue.popleft()
            if node.left_child is None:
                node.left_child = new_node
                return
            if node.right_child is None:
                node.right_child = new_node
                return
            # Both slots taken: look further down, level by level.
            queue.append(node.left_child)
            queue.append(node.right_child)

    def get_min_value(self):
        ''' Find and return the minimum value in the tree, or None if it is empty.

        The tree has no ordering invariant, so every node is visited.
        '''
        if self.head is None:
            return None
        return self._get_min_value(self.head)

    def _get_min_value(self, node):
        ''' Min value helper: smallest value in ``node``'s subtree.

        Returns ``float('inf')`` for an empty subtree (kept for compatibility).
        Missing children are skipped instead of compared against that
        sentinel, so any mutually comparable values (e.g. strings) work.
        '''
        if node is None:
            return float('inf')
        best = node.value
        for child in (node.left_child, node.right_child):
            if child is not None:
                best = min(best, self._get_min_value(child))
        return best

In [4]:
# Values fill the tree level by level, left to right: -1 (the 9th value)
# becomes the right child of 4.
tree = BinaryTree([1,2,3,4,5,6,7,8,-1])

In [5]:
# 9 takes the next free slot in level order: the left child of 5.
tree.add(9)

In [6]:
# root (1) -> 2 -> 4 -> right child, which holds -1.
tree.head.left_child.left_child.right_child

<Node: -1>

In [7]:
# A plain binary tree has no ordering, so the minimum needs a full traversal.
tree.get_min_value()

-1

### Binary Search Tree

In [4]:
class BST(BinaryTree):
    ''' Binary search tree built on BinaryTree.

    Insertion rule (see ``_add_child``): a value smaller than a node goes into
    that node's left subtree; every other value, duplicates included, goes right.
    '''

    def __init__(self, values=None):
        ''' Create a BST, optionally inserting each item of ``values`` in order. '''
        BinaryTree.__init__(self, values)

        # Traversal orders accepted by display().
        self.displays = {'in': self._display_in_order,
                         'pre': self._display_pre_order,
                         'post': self._display_post_order}

    def _add_child(self, value, node):
        ''' Insert ``value`` below ``node``, going left for smaller values. '''
        if value < node.value:
            if not node.left_child:
                node.left_child = Node(value)
                return
            self._add_child(value, node.left_child)
        else:
            # Equal values go right; _is_bst checks the same convention.
            if not node.right_child:
                node.right_child = Node(value)
                return
            self._add_child(value, node.right_child)

    def add(self, value):
        ''' Insert ``value`` while keeping the binary-search ordering. '''
        if not self.head:
            self.head = Node(value)
            return

        self._add_child(value, self.head)

    def get_min_value_iteratively(self):
        ''' Iterative approach: follow left children; None if the tree is empty. '''
        if not self.head:
            return None

        current = self.head
        while current.left_child:
            current = current.left_child

        return current.value

    def get_min(self):
        ''' Find and return the minimum node (not value); None if the tree is empty. '''
        if not self.head:
            return None

        return self._get_min(self.head)

    def _get_min(self, node):
        ''' Return the leftmost (smallest) node of the subtree rooted at ``node``. '''
        if node.left_child:
            return self._get_min(node.left_child)

        return node

    def get_max_value(self):
        ''' Return the max value in the tree, or None if it is empty. '''
        if not self.head:
            return None

        return self._get_max_value(self.head)

    def _get_max_value(self, node):
        ''' Return the value of the rightmost node of the subtree at ``node``. '''
        if node.right_child:
            return self._get_max_value(node.right_child)

        return node.value

    def _get_x_min_value(self, node, x, count):
        ''' Recursive helper: x-th smallest value in ``node``'s subtree, or None.

        ``count`` must be a shared, in-place counter (e.g. ``Count``) that
        supports ``count += 1`` without rebinding and ``count == x``. It holds
        the number of nodes visited in order so far.
        '''
        if not node:
            return None

        found = self._get_x_min_value(node.left_child, x, count)
        # Compare against None rather than truthiness: the answer may be 0.
        if found is not None:
            return found

        count += 1
        if count == x:
            return node.value

        return self._get_x_min_value(node.right_child, x, count)

    def get_x_min_value(self, x):
        ''' Find and return the x-th smallest value (1-based).

        Returns None when ``x < 1`` or the tree holds fewer than ``x`` values.
        Uses an explicit-stack in-order walk that stops at the x-th node. It
        needs no external counter class, and a falsy answer such as 0 is
        returned correctly.
        '''
        if x < 1:
            return None

        stack = []
        node = self.head
        visited = 0
        while stack or node:
            # Go as far left as possible, remembering the path back up.
            while node:
                stack.append(node)
                node = node.left_child
            node = stack.pop()
            visited += 1
            if visited == x:
                return node.value
            node = node.right_child

        return None

    def get_height(self):
        ''' Return the height in edges (a lone root is 0), or None if the tree is empty. '''
        if not self.head:
            return None

        return self._get_height(self.head)

    def _get_height(self, node):
        ''' Height of ``node``'s subtree; an empty subtree counts as -1. '''
        if not node:
            return -1

        return max(self._get_height(node.left_child),
                   self._get_height(node.right_child)) + 1

    def level_order_traversal(self):
        ''' Visit each level from left to right.

        Returns a list with one list of values per depth (root first),
        or None if the tree is empty.
        '''
        if not self.head:
            return None

        level_lists = []
        node_queue = deque()
        node_queue.append((self.head, 0))

        while node_queue:
            node, level = node_queue.popleft()
            # BFS reaches depths in increasing order, so at most one new list is needed.
            if len(level_lists) <= level:
                level_lists.append([])

            level_lists[level].append(node.value)

            if node.left_child:
                node_queue.append((node.left_child, level + 1))

            if node.right_child:
                node_queue.append((node.right_child, level + 1))

        return level_lists

    def _get_in_order_list(self, node, in_order):
        ''' Append the values of ``node``'s subtree to ``in_order``, in order. '''
        if node:
            self._get_in_order_list(node.left_child, in_order)
            in_order.append(node.value)
            self._get_in_order_list(node.right_child, in_order)

    def get_in_order_list(self):
        ''' Return the contents of the tree as an in-order (sorted) list. '''
        in_order = []
        self._get_in_order_list(self.head, in_order)
        return in_order

    def display(self, display_type='in'):
        ''' Print every value, one per line, in the requested traversal order.

        ``display_type`` is 'in', 'pre' or 'post'.

        Raises:
            ValueError: for any other ``display_type``.
        '''
        display_func = self.displays.get(display_type)
        if display_func is None:
            raise ValueError('display_type must be one of %s, got %r'
                             % (sorted(self.displays), display_type))
        display_func(self.head)

    def _display_in_order(self, node):
        ''' Print left subtree, node, right subtree. '''
        if node:
            self._display_in_order(node.left_child)
            print(node.value)
            self._display_in_order(node.right_child)

    def _display_pre_order(self, node):
        ''' Print node, left subtree, right subtree. '''
        if node:
            print(node.value)
            self._display_pre_order(node.left_child)
            self._display_pre_order(node.right_child)

    def _display_post_order(self, node):
        ''' Print left subtree, right subtree, node. '''
        if node:
            self._display_post_order(node.left_child)
            self._display_post_order(node.right_child)
            print(node.value)

    def is_bst(self):
        ''' Check whether the tree satisfies the ordering that ``add`` maintains.

        Every value in a node's left subtree must be strictly smaller than the
        node, and every value in its right subtree greater than or equal to it
        (``add`` sends duplicates right). An empty tree is a valid BST.
        '''
        return self._is_bst(self.head, float('-inf'), float('inf'))

    def _is_bst(self, node, lower, upper):
        ''' Utility function: check that every value in the subtree lies in [lower, upper). '''
        if not node:
            return True

        if not (lower <= node.value < upper):
            return False

        # Left values must stay below this node; right values may equal it.
        return (self._is_bst(node.left_child, lower, node.value)
                and self._is_bst(node.right_child, node.value, upper))

    def is_bst_in_order_traversal(self):
        ''' Check ordering via in-order traversal: values must be non-decreasing.

        An empty tree counts as valid. Unlike ``is_bst``, this cannot see
        which side a duplicate sits on, so it also accepts a duplicate placed
        in a left subtree.
        '''
        values = self.get_in_order_list()
        return all(prev <= cur for prev, cur in zip(values, values[1:]))

    def contains(self, value):
        ''' Search the tree to see if it contains a specific value. '''
        if not self.head:
            return False

        return self._contains(self.head, value)

    def _contains(self, node, value):
        ''' Binary search for ``value`` in the subtree rooted at ``node``. '''
        if not node:
            return False

        if node.value == value:
            return True
        elif node.value > value:
            return self._contains(node.left_child, value)
        else:
            return self._contains(node.right_child, value)

    def delete(self, value):
        ''' Remove one node holding ``value``; a no-op if it is absent. '''
        if not self.head:
            return None

        self.head = self._delete(self.head, value)

    def _delete(self, node, value):
        ''' Delete ``value`` from ``node``'s subtree and return the new subtree root. '''
        if not node:
            return None

        if node.value > value:
            node.left_child = self._delete(node.left_child, value)
        elif node.value < value:
            node.right_child = self._delete(node.right_child, value)
        else:
            # Found the node to delete.
            # Case 1: no children -- simply drop it.
            if not node.left_child and not node.right_child:
                return None
            # Case 2: one child -- splice that child into this position.
            elif not node.left_child:
                return node.right_child
            elif not node.right_child:
                return node.left_child
            else:
                # Case 3: two children -- replace with the min of the right subtree.
                min_node = self._get_min(node.right_child)
                node.value = min_node.value
                # Delete that min value from the right subtree.
                node.right_child = self._delete(node.right_child, min_node.value)

        return node

    def _find_node(self, node, value):
        ''' Search and return the first node (from ``node`` down) holding ``value``, or None. '''
        if node:
            if value < node.value:
                return self._find_node(node.left_child, value)
            elif value > node.value:
                return self._find_node(node.right_child, value)
            else:
                return node

    def find_in_order_successor(self, value):
        ''' Return the node that follows ``value`` in in-order sequence.

        Returns None if the tree is empty, ``value`` is absent, or it has
        no successor.
        '''
        if not self.head:
            return None

        return self._find_in_order_successor(value)

    def _find_in_order_successor(self, value):
        ''' Successor lookup for a value known to be searched from the root. '''
        target_node = self._find_node(self.head, value)

        if not target_node:
            return None

        # With a right subtree, the successor is that subtree's minimum.
        if target_node.right_child:
            return self._get_min(target_node.right_child)

        # Otherwise: the nearest ancestor from which the search went left.
        ancestor = None
        current_node = self.head
        while current_node:
            if value < current_node.value:
                ancestor = current_node
                current_node = current_node.left_child
            elif value > current_node.value:
                current_node = current_node.right_child
            else:
                return ancestor
                return ancestor

In [5]:
bst = BST()
# Insertion order fixes the shape: 5 becomes the root, 6 and 7 chain to the
# right, and 2 -> 3 -> 4 chain down the left side.
for value in (5, 6, 7, 2, 3, 4):
    bst.add(value)

In [6]:
# The default 'in' order prints the stored values in ascending order.
bst.display()

2
3
4
5
6
7


In [135]:
# Find the in-order successor of 3 (the node holding the next larger value).
bst.find_in_order_successor(3)

<Node: 4>

In [92]:
# 5 is the root and has two children, so its value is replaced by the
# minimum of its right subtree (6), which is then removed from there.
bst.delete(5)

In [93]:
# In-order listing after the delete: 5 is gone.
bst.display()

2
3
4
6
7


In [221]:
# NOTE(review): on Restart & Run All this raises AttributeError. The right
# spine here is only 6 -> 7 (8 and 9 are added in a later cell), so the third
# .right_child is None. TODO: move this below the add(8)/add(9) cell (it
# would then show 9); the recorded 8 presumably comes from an older kernel state.
bst.head.right_child.right_child.right_child.value

8

In [206]:
# Values as a sorted Python list.
# NOTE(review): the recorded output still contains 5, which was deleted above;
# re-run to refresh it.
bst.get_in_order_list()

[2, 3, 4, 5, 6, 7]

In [17]:
# get_min returns the Node object itself; read .value for the number.
# (The recorded output is stale: an assignment displays nothing.)
node = bst.get_min()

<__main__.Node object at 0x1084fdfd0>


In [20]:
# Smallest value stored in the tree.
node.value

2

In [208]:
# Follows right children down to the rightmost node.
bst.get_max_value()

7

#### Get the x-th smallest element (e.g. the second smallest)

In [209]:
class Count:
    ''' Mutable integer counter that can be shared across recursive calls.

    ``counter += n`` updates the counter in place, so every holder of the same
    object sees the new total. ``counter + n`` returns a new counter and
    leaves the original untouched. ``counter == n`` compares the total.
    '''

    def __init__(self, count=0):
        self.count = count

    def __iadd__(self, value):
        # In place: `+=` keeps the same object, so callers sharing it see the update.
        self.count += value
        return self

    def __add__(self, value):
        # Non-mutating, like ordinary `+`.
        return type(self)(self.count + value)

    def __eq__(self, other):
        if isinstance(other, Count):
            return self.count == other.count
        return self.count == other

    # Defining __eq__ already disables hashing; state it explicitly because
    # the counter is mutable and must not be used as a dict key or set member.
    __hash__ = None

    def __repr__(self):
        return 'Count(%r)' % (self.count,)

In [210]:
# 2nd smallest value (ranks are 1-based).
bst.get_x_min_value(2)

3

In [211]:
# 3rd smallest value.
bst.get_x_min_value(3)

4

In [212]:
# Same answer as get_min().value, found with a loop over left children.
bst.get_min_value_iteratively()

2

In [213]:
# Height counts edges: a lone root has height 0.
bst.get_height()

3

In [214]:
# Both are larger than every stored value, so each one extends the right spine.
bst.add(8)
bst.add(9)

In [224]:
# NOTE(review): the recorded 4 looks stale. From a fresh run, both paths
# 6 -> 7 -> 8 -> 9 and 6 -> 2 -> 3 -> 4 have 3 edges, so this returns 3.
bst.get_height()

4

In [266]:
# One list of values per depth, root first, each ordered left to right.
# NOTE(review): the levels printed by the next cell omit 8 and 9 and still show
# the deleted 5; a fresh run gives [6], [2, 7], [3, 8], [4, 9].
levels = bst.level_order_traversal()

In [270]:
# Print one line per tree level, root first.
for level_values in levels:
    print(level_values)

[5]
[2, 6]
[3, 7]
[4]


#### Check whether a tree is a valid BST

In [311]:
# This tree was built only through add(), so the ordering invariant holds.
bst.is_bst()

True

In [312]:
not_bst = BST()
# Shape: 10 at the root, 9 on its left, 11 on its right, and 12 to the right of 11.
for value in (10, 11, 9, 12):
    not_bst.add(value)

In [313]:
# Corrupt the tree on purpose: the node holding 12 (right child of 11) now
# holds 8, which is smaller than its ancestors 10 and 11.
not_bst.head.right_child.right_child.value = 8

In [314]:
# The in-order listing is no longer sorted: 8 appears after 11.
not_bst.display()

9
10
11
8


In [315]:
# Expected False: 8 sits in the right subtree of 10.
not_bst.is_bst()

False

In [343]:
# Binary search: goes left or right at each node until the value is found.
bst.contains(2)

True

In [346]:
# The search runs off the right spine into an empty slot, so this is False.
bst.contains(100)

False

### Create a height-balanced BST from a sorted array
https://leetcode.com/explore/featured/card/top-interview-questions-easy/94/trees/631/

In [27]:
def sorted_array_to_bst(nums):
    ''' Iterative approach: build a height-balanced BST from sorted ``nums``.

    Works breadth-first over half-open index ranges [low, high). The middle
    element of each range is inserted before anything from its two halves,
    so every value is added after its ancestors in the balanced shape.
    '''
    tree = BST()
    pending = deque([(len(nums), 0)])  # (high, low) pairs; high is exclusive

    while pending:
        high, low = pending.popleft()
        if high <= low:
            continue  # empty range, nothing to insert

        middle = (low + high) // 2
        tree.add(nums[middle])
        # Queue the right half first, then the left half.
        pending.extend([(high, middle + 1), (middle, low)])

    return tree

In [30]:
# Rebinds `bst`: the tree from the previous section is discarded here.
bst = sorted_array_to_bst([-10, -3, 0, 3, 5, 6])

In [31]:
# In-order output reproduces the sorted input.
bst.display()

-10
-3
0
3
5
6


In [18]:
def sorted_array_to_bst(bst, nums, high, low):
    if low >= high:
        return
    
    mid = (high+low) // 2
    bst.add(nums[mid])
    sorted_array_to_bst(bst, nums, mid, low)
    sorted_array_to_bst(bst, nums, high, mid+1)

In [23]:
# Recursive version: the caller supplies the tree and the half-open index
# range [low, high), passed as (high, low).
bst = BST()
nums = list(range(9))
sorted_array_to_bst(bst, nums, len(nums), 0)

In [24]:
# In-order output reproduces 0..8.
bst.display()

0
1
2
3
4
5
6
7
8
