In [None]:
import sys
import random as rand

class COLOR:
    RED     = "RED"
    BLACK   = "BLACK"

class Node:
    def __init__(self, val):
        self.color = None
        self.p = None
        self.key = val
        self.left = None
        self.right = None         

class RedBlackTree: 
    # follow CLRS textbook: item < the parent is on the left subtree and > on the right subtree
    def __init__(self):
        NIL = Node("NIL") 
        self.T = NIL
        self.NIL = NIL
        self.blackHeight = 0
    
    

    # A simple tree recursive find method
    def find(self, x, i):
        if x == self.NIL:
            print(i, "is not found")
            return self.NIL
        if x.key == i:
            print(i, "is found")
            return x
        elif i < x.key:
            return self.find(x.left, i)
        else:
            return self.find(x.right, i)

    # CLRS textbook RB-INSERT pseudocode page: 315
    def insert(self, i):
        z = Node(i)
        y = self.NIL
        x = self.T
        while x != self.NIL:
            y = x
            if z.key < x.key:
                x = x.left
            else:
                x = x.right
        z.p = y
        if y == self.NIL: #empty tree at the beginning
            self.T = z
        else:
            if z.key < y.key:
                y.left = z
            else:
                y.right = z
        z.left = z.right = self.NIL
        z.color = COLOR.RED
        self.insert_fixup(z)

        return i

    # CLRS textbook RB-INSERT-FixUp pseudocode
    # page: 316
    def insert_fixup(self, z):
        while z.p != self.NIL and z.p.color == COLOR.RED:
            if z.p == z.p.p.left:
                y = z.p.p.right
                if y != self.NIL and y.color == COLOR.RED:
                    z.p.color = COLOR.BLACK
                    y.color = COLOR.BLACK
                    z.p.p.color = COLOR.RED
                    z = z.p.p
                else:
                    if z == z.p.right:
                        z = z.p
                        self.left_rotate(z)
                    z.p.color = COLOR.BLACK
                    z.p.p.color = COLOR.RED
                    self.right_rotate(z.p.p)
            else:
                y = z.p.p.left
                if y.color == COLOR.RED:
                    z.p.color = COLOR.BLACK
                    y.color = COLOR.BLACK
                    z.p.p.color = COLOR.RED
                    z = z.p.p
                else:
                    if z == z.p.left:
                        z = z.p
                        self.right_rotate(z)
                    z.p.color = COLOR.BLACK
                    z.p.p.color = COLOR.RED
                    self.left_rotate(z.p.p)
        self.T.color = COLOR.BLACK

    # CLRS textbook LEFT-ROTATE pseudocode
    # page: 313
    def left_rotate(self, x):
        y = x.right                   # set y
        x.right = y.left              # turn y's left subtree into x's right subtree
        if y.left != self.NIL:
            y.left.p = x
        y.p = x.p           # link x's p to y
        if x.p == self.NIL:
            self.T = y
        elif x == x.p.left:
            x.p.left = y
        else:
            x.p.right = y
        y.left = x                   # put x on y's left
        x.p = y

    # CLRS textbook RIGHT-ROTATE pseudocode
    # page: 313 Ex. 13.2-1 "The code for RIGHT-ROTATE is symmetric to LEFT-ROTATE"
    def right_rotate(self, x):
        y = x.left
        x.left = y.right
        if y.right != self.NIL:
            y.right.p = x
        y.p = x.p
        if x.p == self.NIL:
            self.T = y
        elif x == x.p.right:
            x.p.right = y
        else:
            x.p.left = y
        y.right = x
        x.p = y

    # CLRS textbook TREE-MINIMUM pseudocode
    # page: 291
    def tree_minimum(self, x):
        while x.left != self.NIL:
            x = x.left
        return x

    # CLRS textbook TREE-SUCCESSOR pseudocode
    # page: 292
    def tree_successor(self, x): 
        if x.right != self.NIL:
            return self.tree_minimum(x.right)
        y = x.p
        while y != self.NIL and x == y.right:
            x = y
            y = y.p
        return y

        
    def findBlackHeight(self):
        x = self.T
        bh = 0
        while(x != self.NIL):
            if x.color == COLOR.BLACK:
                bh += 1
            x = x.right
        self.blackHeight = bh
        return bh
        
    
    def __repr__(self): # overload print to print tree
        lines = []
        print_tree(self.T, lines)
        return '\n'.join(lines)
        
def print_tree(node, lines, level=0):
    
    if node.key != "NIL":
        print_tree(node.left, lines, level + 1)
        lines.append('-' * 2 * level + '> ' +
                     str(node.key) + ' ' + ('r' if node.color == COLOR.RED else 'b'))
        print_tree(node.right, lines, level + 1)

In [None]:
def joinRight(T1, x, T2):
    t1bh = T1.findBlackHeight()
    t2bh = T2.findBlackHeight()
    t1 = t1.T
    t2 = t2.T
    if (t1.color == COLOR.BLACK) and (t1bh == t2bh):
        node = Node(x)
        Node.left = t1
        Node.right = t2
        Node.color = COLOR.RED
    _T = RedBlackTree()

    node = Node(T1.key)
    node = t1.left
    node = t1.color
    node = joinRight(t1.right, x, t2)
    
    _T.T = node
    
    if (t1.color == COLOR.BLACK) and (_T.T.right.color == _T.T.right.right.color == COLOR.RED):
        _T.right.right.color = COLOR.BLACK
        return _T.left_rotate(_T.T)
    return _T

In [None]:
def joinLeft(t1, x, t2):
    t1bh = t1.findBlackHeight()
    t2bh = t2.findBlackHeight()
    t1 = t1.T
    t2 = t2.T
    if (T1.color == COLOR.BLACK) and (t1bh == t2bh):
        node = Node(x)
        Node.right = t1
        Node.left = t2
        Node.color = COLOR.RED
    _T = RedBlackTree()
    
    node = Node(t1.key)
    node.right = t1.right
    node.color = t1.color
    node.left = joinLeft(t1.left, x, t2)
    
    _T.T = node
    
    if (t1.color == COLOR.BLACK) and (_T.T.left.color == _T.T.left.left.color == COLOR.RED):
        _T.left.left.color = COLOR.BLACK
        return _T.right_rotate(_T.T)
    return _T

In [None]:
def join(T1, x, T2):
    t1bh = t1.findBlackHeight()
    t2bh = t2.findBlackHeight()
    print("Black height of given trees is", t1bh, "and", t2bh)
    t1 = t1.T
    t2 = t2.T
    if t1bh > t2bh:
        _T = joinRight(t1, x, t2)
        if (_T.T.color == COLOR.RED) and (_T.T.right.color == COLOR.RED):
            _T.T.color = COLOR.BLACK
        return _T
    if T2blackHeight > T1blackHeight:
        _T = joinLeft(t1, x, t2)
        if (_T.T.color == COLOR.RED) and (_T.T.left.color == COLOR.RED):
            _T.T.color = COLOR.BLACK
        return _T
    if (t1.color == COLOR.BLACK) and (t2.color == COLOR.BLACK):
        node = Node(x)
        node.left = t1
        node.color = COLOR.RED
        node.right = t2
        return Node
    
    node = Node(x)
    node.left = t1
    node.color = COLOR.BLACK
    node.right = t2
    
    return Node

In [None]:
def inorder(node): # traversing joint tree
    NIL = Node("NIL")
    if node != NIL:
        inorder(node.left)
        print(f"{node.key} ")
        self.inorder(node.right)

In [16]:
t1 = RedBlackTree()

for _ in range(10):
    t1.insert(rand.randrange(0, 30))
print("Printing T1 ------ ")
t1.inorder()
print("\n------------------")
t2 = RedBlackTree()
for i in range(10):
    t2.insert(rand.randrange(50, 80))
print("Printing T2 ------ ")
t2.inorder()
print("\n------------------")

x = 15 

t3 = join(t1, x, t2) # destroys t1 and t2 and returns t1 U {x} U t2

Printing T1 ------
0 1 2 5 7 8 15 17 23 29
------------------
Printing T2 ------ 
51 54 60 61 62 64 68 71 75 76
------------------
Black height of given trees is 3 and 3
