# In [None]:
import math
import pandas as pd
from tqdm import tqdm

class Node:
    """One node of the B-tree.

    ``keys`` and ``data`` are parallel lists: ``data[i]`` is the value stored
    under ``keys[i]``.  ``pointers`` holds the child nodes and is empty for a
    leaf.  In an internal node ``pointers[i]`` is the subtree searched for keys
    smaller than ``keys[i]`` and ``pointers[-1]`` the subtree searched for keys
    larger than the last key.
    """

    def __init__(self, r=False):
        self.keys = []
        self.pointers = []
        self.data = []
        self.r = r  # This variable indicates whether the node is root or not.


class BTree:
    """B-tree that maps unique, mutually comparable keys to values.

    Attributes:
        root: The current root ``Node``.
        t: Maximum degree (largest number of children of a node).
        min: Minimum number of keys a non-root node should hold.
        max: Maximum number of keys any node may hold.

    NOTE(review): with ``t < 3`` the formula below gives ``min == 0`` and
    ``split`` can create nodes without keys; deletion is not designed for that
    case, so presumably ``t >= 3`` is intended.
    """

    def __init__(self, t, r=True):  # t is maximum degree.
        """Create an empty tree of maximum degree ``t``.

        Args:
            t: Maximum degree.
            r: Root flag handed to the initial node.  With a falsy flag the
                initial node is treated like a non-root node, so an overflow
                in ``insert_key`` returns the split result instead of growing
                the tree.
        """
        self.root = Node(r)
        self.t = t
        self.min = math.ceil(t / 2) - 1  # Minimum number of keys in a node.
        self.max = t - 1  # Maximum number of keys in a node.

    # ------------------------------------------------------------------
    # lookup
    # ------------------------------------------------------------------
    @staticmethod
    def _locate(keys, key):
        """Binary-search the sorted list ``keys`` for ``key``.

        Returns:
            ``(index, True)`` when ``keys[index] == key``; otherwise
            ``(index, False)`` where ``index`` is both the insertion position
            and the index of the child to descend into.
        """
        lo, hi = 0, len(keys) - 1
        while lo <= hi:
            mid = (lo + hi) // 2
            if key < keys[mid]:
                hi = mid - 1
            elif key > keys[mid]:
                lo = mid + 1
            else:
                return mid, True
        return lo, False

    def _find(self, key, node):
        """Return ``(node, index)`` of ``key`` in the subtree of ``node``, else None."""
        while True:
            pos, found = self._locate(node.keys, key)
            if found:
                return node, pos
            if not node.pointers:
                return None
            node = node.pointers[pos]

    def search(self, key, curr=None):
        """Return the value stored under ``key``, or None when it is absent.

        Args:
            key: Key to look up.
            curr: Node to start from; defaults to the root.

        Note that a stored value of None cannot be told apart from a miss.
        """
        hit = self._find(key, self.root if curr is None else curr)
        if hit is None:
            return None  # This key is not included in this tree.
        node, pos = hit
        return node.data[pos]

    # ------------------------------------------------------------------
    # insertion
    # ------------------------------------------------------------------
    def insert_key(self, key, value, curr=None):
        """Insert ``key`` -> ``value`` below ``curr`` (default: the root).

        An already present key is left untouched.

        Returns:
            None in the normal case.  When ``curr`` is a node without the root
            flag and it overflowed, the ``split`` result
            ``(mid_key, left, right, mid_data)`` is returned so the caller
            (the recursion in the parent) can link the two halves.
        """
        node = self.root if curr is None else curr
        pos, found = self._locate(node.keys, key)
        if found:
            return None  # The tree already holds this key.

        if node.pointers:
            promoted = self.insert_key(key, value, node.pointers[pos])
            if not promoted:
                return None  # Nothing was split below, this node is unchanged.
            mid_key, left, right, mid_data = promoted
            node.keys.insert(pos, mid_key)
            node.data.insert(pos, mid_data)
            # The overflowed child is replaced by its two halves.
            node.pointers[pos:pos + 1] = [left, right]
        else:
            node.keys.insert(pos, key)
            node.data.insert(pos, value)

        if len(node.keys) <= self.max:
            return None
        if not node.r:
            return self.split(node)

        # The root overflowed: split it and grow the tree by one level.
        node.r = False
        mid_key, left, right, mid_data = self.split(node)
        new_root = Node(True)
        new_root.keys = [mid_key]
        new_root.data = [mid_data]
        new_root.pointers = [left, right]
        self.root = new_root
        return None

    def split(self, curr):
        """Split ``curr`` around the key at index ``self.min``.

        ``curr`` itself is not modified; two new nodes receive copies of its
        halves.

        Returns:
            ``(mid_key, left, right, mid_data)``.
        """
        cut = self.min
        left = Node()
        right = Node()
        left.keys = curr.keys[:cut]
        left.data = curr.data[:cut]
        right.keys = curr.keys[cut + 1:]
        right.data = curr.data[cut + 1:]
        if curr.pointers:  # move the pointers
            left.pointers = curr.pointers[:cut + 1]
            right.pointers = curr.pointers[cut + 1:]
        return curr.keys[cut], left, right, curr.data[cut]

    # ------------------------------------------------------------------
    # deletion
    # ------------------------------------------------------------------
    def delete_key(self, key, curr=None):
        """Delete ``key`` from the subtree of ``curr`` (default: the root).

        Deleting an absent key is a no-op.  Every node below ``curr`` is
        rebalanced (rotation with a sibling, else merge).  When ``curr``
        carries the root flag and ends up without keys, its single remaining
        child becomes the new root.

        Returns:
            None, except when ``curr`` is a node without the root flag that is
            left with fewer than ``self.min`` (but at least one) keys: then its
            first key is returned, which is the underflow marker this method
            has always handed to its caller.
        """
        node = self.root if curr is None else curr
        self._remove(key, node)

        if node.r:
            if not node.keys and node.pointers:
                # The last separator was merged into the only child.
                new_root = node.pointers[0]
                new_root.r = True
                self.root = new_root
            return None
        if node.keys and len(node.keys) < self.min:
            return node.keys[0]
        return None

    def _remove(self, key, node):
        """Remove ``key`` below ``node`` and repair underflowing children.

        ``node`` itself may be left underfull; that is the caller's job.

        Returns:
            True when the key was found and removed, else False.
        """
        pos, found = self._locate(node.keys, key)

        if not node.pointers:  # leaf
            if found:
                del node.keys[pos]
                del node.data[pos]
            return found

        if found:
            child_idx = self._replace_with_neighbour(node, pos)
        else:
            child_idx = pos
            if not self._remove(key, node.pointers[child_idx]):
                return False

        if len(node.pointers[child_idx].keys) < self.min:
            self._rebalance(node, child_idx)
        return True

    def _replace_with_neighbour(self, node, pos):
        """Overwrite ``node.keys[pos]`` with a neighbouring leaf key and delete that key.

        The in-order predecessor is used unless its leaf has no spare key
        while the successor's leaf has one.

        Returns:
            Index of the child the replacement key was taken from.
        """
        pred_leaf = node.pointers[pos]
        while pred_leaf.pointers:
            pred_leaf = pred_leaf.pointers[-1]
        succ_leaf = node.pointers[pos + 1]
        while succ_leaf.pointers:
            succ_leaf = succ_leaf.pointers[0]

        if len(pred_leaf.keys) <= self.min < len(succ_leaf.keys):
            child_idx, leaf, src = pos + 1, succ_leaf, 0
        else:
            child_idx, leaf, src = pos, pred_leaf, -1

        moved_key = leaf.keys[src]
        node.keys[pos] = moved_key
        node.data[pos] = leaf.data[src]
        # Deleting through the recursion rebalances every level on the way
        # down to the leaf, not only the direct children of ``node``.
        self._remove(moved_key, node.pointers[child_idx])
        return child_idx

    def _rebalance(self, node, idx):
        """Refill the underfull child ``node.pointers[idx]``.

        Order of preference: borrow from the left sibling, borrow from the
        right sibling, merge with the left sibling, merge with the right one.
        """
        children = node.pointers
        left = children[idx - 1] if idx > 0 else None
        right = children[idx + 1] if idx + 1 < len(children) else None

        if left is not None and len(left.keys) > self.min:
            self.rotate(idx, True, node)
        elif right is not None and len(right.keys) > self.min:
            self.rotate(idx, False, node)
        elif left is not None:
            self.merge(idx, True, node)
        elif right is not None:
            self.merge(idx, False, node)

    def find_siblings(self, key, curr):
        """Find the child of ``curr`` whose subtree holds ``key`` and its siblings.

        Returns:
            ``[left_sibling, right_sibling, index, curr]`` where ``index`` is
            the position of that child in ``curr.pointers`` and a missing
            sibling is None.  Both siblings are None when ``key`` is stored in
            ``curr`` itself, when ``curr`` is a leaf, or when ``key`` is not in
            the subtree at all.
        """
        pos, found = self._locate(curr.keys, key)
        if found or not curr.pointers:
            return [None, None, pos, curr]
        # Presence is decided by position, not by the truthiness of the stored
        # value, so values such as 0 or '' are handled correctly.
        if self._find(key, curr.pointers[pos]) is None:
            return [None, None, pos, curr]

        left = curr.pointers[pos - 1] if pos > 0 else None
        right = curr.pointers[pos + 1] if pos + 1 < len(curr.pointers) else None
        return [left, right, pos, curr]

    def rotate(self, idx, leftside, curr):
        """Move one key into child ``idx`` of ``curr`` from a sibling.

        Args:
            idx: Index of the receiving child in ``curr.pointers``.  With
                ``leftside`` a negative index (e.g. ``-1`` for the last child)
                is accepted as well.
            leftside: True to borrow from the left sibling (``idx - 1``),
                False to borrow from the right sibling (``idx + 1``).
            curr: Parent of both children.

        The separator key of ``curr`` moves down into the receiver and the
        donor's outermost key (and child pointer, if any) moves up/over.
        """
        if leftside:  # left -> right rotation
            if idx < 0:
                idx += len(curr.pointers)
            sep = idx - 1
            donor = curr.pointers[idx - 1]
            receiver = curr.pointers[idx]
            receiver.keys.insert(0, curr.keys[sep])
            receiver.data.insert(0, curr.data[sep])
            if donor.pointers:  # move the pointer
                receiver.pointers.insert(0, donor.pointers.pop())
            curr.keys[sep] = donor.keys.pop()
            curr.data[sep] = donor.data.pop()
        else:  # right -> left rotation
            sep = idx
            receiver = curr.pointers[idx]
            donor = curr.pointers[idx + 1]
            receiver.keys.append(curr.keys[sep])
            receiver.data.append(curr.data[sep])
            if donor.pointers:  # move the pointer
                receiver.pointers.append(donor.pointers.pop(0))
            curr.keys[sep] = donor.keys.pop(0)
            curr.data[sep] = donor.data.pop(0)

    def merge(self, idx, leftside, curr):
        """Merge child ``idx`` of ``curr`` with a sibling and the separator between them.

        Args:
            idx: Index of the child in ``curr.pointers``.  With ``leftside`` a
                negative index (e.g. ``-1`` for the last child) is accepted.
            leftside: True to merge with the left sibling (``idx - 1``),
                False to merge with the right sibling (``idx + 1``).
            curr: Parent of both children.

        The left node of the pair absorbs the separator and the right node;
        the right node is unlinked from ``curr``.
        """
        if leftside:
            if idx < 0:
                idx += len(curr.pointers)
            sep = idx - 1
        else:
            sep = idx

        left = curr.pointers[sep]
        right = curr.pointers[sep + 1]
        # Remove the separator by position: removing by value would hit the
        # wrong slot whenever two keys of ``curr`` store equal values.
        left.keys.append(curr.keys.pop(sep))
        left.data.append(curr.data.pop(sep))
        left.keys.extend(right.keys)
        left.data.extend(right.data)
        left.pointers.extend(right.pointers)
        del curr.pointers[sep + 1]
            


def insertion_test(file, n):
    """Round-trip every record of ``file`` through a BTree and report the outcome.

    ``file`` is read as a tab-separated table without header whose first
    column is the key (used as index) and whose second column is named
    ``value``.  All records are inserted into a tree of degree 128, looked up
    again, and the looked-up values are written to ``insertion_output<n>.csv``.
    A confirmation is printed when every looked-up value equals its input.

    Args:
        file: Path of the input table.
        n: Suffix used in the name of the output file.
    """
    frame = pd.read_csv(file, sep='\t', header=None, index_col=0, names=['value'])
    tree = BTree(128, True)

    for key, value in tqdm(zip(frame.index, frame.value)):
        tree.insert_key(key, value)

    print('Finish the insertion step')

    # Look every key up again, in input order.
    looked_up = []
    for key in tqdm(frame.index):
        looked_up.append(tree.search(key))
    result = pd.DataFrame(looked_up, index=frame.index, columns=['output'])

    result.to_csv('insertion_output{}.csv'.format(n))

    print('Finish the search step')

    matching = sum(frame.value == result.output)
    if matching == len(frame.index):
        print('The result is the same as the input')

def deletion_test(input_file, delete_file, output_file, n):
    """Insert, delete, then compare the remaining tree content with a reference.

    All three files are read as tab-separated tables without header whose
    first column is the key (used as index) and whose second column is named
    ``value``.  Every record of ``input_file`` is inserted into a BTree of
    degree 128, every key of ``delete_file`` is deleted, and every key of
    ``output_file`` is looked up.  The looked-up values are written to
    ``deletion_output<n>.csv`` (missing values as ``N/A``), and a confirmation
    is printed when they agree with ``output_file``.

    Args:
        input_file: Records to insert.
        delete_file: Keys to delete.
        output_file: Expected value (or missing value) per key after deletion.
        n: Suffix used in the name of the output file.
    """
    idf = pd.read_csv(input_file, sep='\t', header=None, index_col=0, names=['value'])
    ddf = pd.read_csv(delete_file, sep='\t', header=None, index_col=0, names=['value'])
    odf = pd.read_csv(output_file, sep='\t', header=None, index_col=0, names=['value'])

    A = BTree(128, True)

    for k, v in tqdm(zip(idf.index, idf.value)):
        A.insert_key(k, v)

    print('Finish the insertion step')

    for i in tqdm(ddf.index):
        A.delete_key(i)

    print('Finish the deletion step')

    result = pd.DataFrame([A.search(i) for i in tqdm(odf.index)], index=odf.index, columns=['output'])

    result.to_csv('deletion_output{}.csv'.format(n), na_rep='N/A')

    # A record matches when both sides hold the same non-NaN value or both
    # sides are NaN.  The comparison is done element-wise on the two
    # identically indexed columns, so a NaN on only one side counts as a
    # mismatch instead of making pandas raise on differently shaped operands.
    actual = result['output']
    expected = odf['value']
    matches = (actual == expected) | (actual.isna() & expected.isna())
    if matches.all():
        print('The result is the same as the output.')

if __name__ == "__main__":
    # Run both insertion round-trips first, then both deletion scenarios;
    # the run number becomes the suffix of the generated output files.
    for run, source in enumerate(('input.csv', 'input2.csv'), start=1):
        insertion_test(source, run)

    deletion_cases = (
        ('input.csv', 'delete.csv', 'delete_result.csv'),
        ('input2.csv', 'delete2.csv', 'delete_result2.csv'),
    )
    for run, (source, removals, expected) in enumerate(deletion_cases, start=1):
        deletion_test(source, removals, expected, run)
