In [6]:
class Tree():
    def __init__(self, root=None):
        self.root = root
        self.size = 0
        
    def insert(self, node, start_node=None):
        # Create root if no root
        if start_node == None:
            start_node = self.root
        if not self.root:
            self.root = node
        else:
            if node.key < start_node.key:
                if start_node.has_left_child(): # recurse
                    self.insert(node, start_node.left_child)
                else:
                    start_node.left_child = node
                    node.parent = start_node
            else: 
                if start_node.has_right_child():
                    self.insert(node, start_node.right_child)
                else:
                    start_node.right_child = node
                    node.parent = start_node
        self.size += 1
        
    def search(self, key, start_node=None):
        if start_node == None:
            start_node = self.root
        if key == start_node.key:
            return start_node
        if key < start_node.key:
            return self.search(key, start_node.left_child)
        else:
            return self.search(key, start_node.right_child)
        
    def traverse_in_order(self, start_node = None):
        if start_node == None:
            start_node = self.root
        if start_node.has_left_child():
            self.traverse_in_order(start_node.left_child)
        print(start_node.key)
        if start_node.has_right_child():
            self.traverse_in_order(start_node.right_child)        
    
    def predecessor(self, node_key):
        start_node = self.search(node_key)
        
        # Case 1: we have a left child
        if start_node.has_left_child():
            new_node = start_node.left_child
            while new_node.has_right_child():
                new_node = new_node.right_child
            return new_node
                
        # Case 2: we do not have a left child
        else:
            parent_node = start_node.parent
            while parent_node.key > start_node.key:
                parent_node = parent_node.parent
            return parent_node

    

    def delete(self, node_key):
        node = self.search(node_key)
        
        if not node.has_left_child() and not node.has_right_child():
            if node == node.parent.left:
                node.parent.left = None
            else:
                node.parent.right = None
            del node
        elif node.has_left_child() and not node.has_right_child():
            # Node's left child takes position of node
            if node.parent:
                if node == node.parent.left_child:
                    node.parent.left_child = node.left_child
                else:
                    node.parent.right_child = node.right_child
                node.left_child.parent = node.parent
            del node
        elif not node.has_left_child() and node.has_right_child():
            # Node's right child takes position of node
            if node.parent:
                if node == node.parent.left_child:
                    node.parent.left = node.right_child
                else:
                    node.parent.right = node.right_child
                node.right_child.parent = node.parent
            del node
        # TODO: case with two keys
        else:
            predecessor_node = self.predecessor(node_key)
            predecessor_node.key, node_key = node_key, predecessor_node.key 
            del node
    
    def traverse_in_order(self, start_node = None):
        if start_node == None:
            start_node = self.root
        if start_node.has_left_child():
            self.traverse_in_order(start_node.left_child)
        print(start_node.key)
        if start_node.has_right_child():
            self.traverse_in_order(start_node.right_child)        
                        

class Node():

    def __init__(self, key, payload, parent=None, left_child=None, right_child=None):
        '''
        Keys used for sorting, value used for returning.
        '''
        self.key = key
        self.payload = payload
        self.parent = parent
        self.left_child = left_child
        self.right_child = right_child   
        
    def has_left_child(self):
        return True if self.left_child else False
        
    def has_right_child(self):
        return True if self.right_child else False

In [7]:
t = Tree()
t.insert(Node(1, 'a'))
t.insert(Node(2, 'b'))
t.insert(Node(3, 'c'))
t.insert(Node(4, 'd'))
t.insert(Node(5, 'd'))
t.traverse_in_order()
# print(t.search(3).key)
# print(t.predecessor(3).key)
t.delete(3)
t.traverse_in_order()

1
2
3
4
5
1
2
3
4
5


In [49]:
print(t.root.key)
print(t.root.left_child)
print(t.root.right_child.key)
print(t.root.right_child.parent.key)

3
<__main__.Node object at 0x10742fb38>
4
3


In [None]:
print(t.root.right_child)