In [14]:
from typing import Optional, Tuple
import time

class Node:
    def __init__(self, key: int):
        self.key = key
        self.left: Optional[Node] = None
        self.right: Optional[Node] = None
        self.parent: Optional[Node] = None  # Add parent pointer

class SplayTree:
    def __init__(self):
        self.root: Optional[Node] = None
        self.step_count = 0  # Track splay steps

    def rotate(self, u: Node):
        p = u.parent
        if not p:
            return  # Cannot rotate, u is root
        g = p.parent  # Grandparent
        if p.left == u:
            # Right rotation
            p.left = u.right
            if u.right:
                u.right.parent = p
            u.right = p
        else:
            # Left rotation
            p.right = u.left
            if u.left:
                u.left.parent = p
            u.left = p
        p.parent = u
        u.parent = g
        if g:
            if g.left == p:
                g.left = u
            else:
                g.right = u
        else:
            self.root = u  # u is new root

    def splayStep(self, u: Node, visualizer=None):
        p = u.parent
        if not p:
            return  # u is root
        g = p.parent
        if not g:
            # Zig case
            if visualizer:
                self.step_count += 1
                visualizer.print_tree(self.root, f"Splay Step {self.step_count}: Zig rotation at node {u.key}")
                time.sleep(0.5)
            self.rotate(u)
        elif (g.left == p and p.left == u) or (g.right == p and p.right == u):
            # Zig-zig case
            if visualizer:
                self.step_count += 1
                visualizer.print_tree(self.root, f"Splay Step {self.step_count}: Zig-zig rotation at node {u.key}")
                time.sleep(0.5)
            self.rotate(p)
            self.rotate(u)
        else:
            # Zig-zag case
            if visualizer:
                self.step_count += 1
                visualizer.print_tree(self.root, f"Splay Step {self.step_count}: Zig-zag rotation at node {u.key}")
                time.sleep(0.5)
            self.rotate(u)
            self.rotate(u)

    def splay(self, u: Node, visualizer=None):
        while u.parent:
            self.splayStep(u, visualizer)

    def insert(self, key: int, visualizer=None):
        if visualizer:
            visualizer.print_tree(self.root, f"Before inserting {key}")
            self.step_count = 0

        if not self.root:
            self.root = Node(key)
            if visualizer:
                visualizer.print_tree(self.root, f"After inserting {key}")
            return

        # Standard BST insertion with parent pointers
        current = self.root
        while True:
            if key == current.key:
                break  # Key already exists
            elif key < current.key:
                if current.left:
                    current = current.left
                else:
                    current.left = Node(key)
                    current.left.parent = current
                    current = current.left
                    break
            else:
                if current.right:
                    current = current.right
                else:
                    current.right = Node(key)
                    current.right.parent = current
                    current = current.right
                    break

        # Splay the node
        self.splay(current, visualizer)

        if visualizer:
            visualizer.print_tree(self.root, f"After inserting {key}")

    def lookup(self, key: int, visualizer=None) -> bool:
        if visualizer:
            visualizer.print_tree(self.root, f"Before looking up {key}")
            self.step_count = 0

        current = self.root
        last = None
        while current:
            last = current
            if key == current.key:
                break
            elif key < current.key:
                current = current.left
            else:
                current = current.right

        if current:
            self.splay(current, visualizer)
            found = True
        else:
            self.splay(last, visualizer)
            found = False

        if visualizer:
            visualizer.print_tree(self.root, f"After looking up {key} (Found: {found})")
        return found

    def delete(self, key: int, visualizer=None):
        if not self.root:
            return

        if visualizer:
            visualizer.print_tree(self.root, f"Before deleting {key}")
            self.step_count = 0

        # Splay the node with key to the root
        current = self.root
        last = None
        while current:
            last = current
            if key == current.key:
                break
            elif key < current.key:
                current = current.left
            else:
                current = current.right

        if current:
            self.splay(current, visualizer)
        else:
            self.splay(last, visualizer)
            # Key not found
            if visualizer:
                visualizer.print_tree(self.root, f"After deleting {key} (Key not found)")
            return

        # Now the node with key is at root
        # Remove the root
        if not self.root.left:
            self.root = self.root.right
            if self.root:
                self.root.parent = None
        else:
            left_subtree = self.root.left
            left_subtree.parent = None
            right_subtree = self.root.right
            # Splay the largest node in left subtree
            max_node = left_subtree
            while max_node.right:
                max_node = max_node.right
            self.splay(max_node)
            # Now max_node is root, and its right child is None
            self.root.right = right_subtree
            if right_subtree:
                right_subtree.parent = self.root

        if visualizer:
            visualizer.print_tree(self.root, f"After deleting {key}")

    def split(self, key: int, visualizer=None) -> Tuple[Optional[Node], Optional[Node]]:
        if not self.root:
            return None, None

        if visualizer:
            visualizer.print_tree(self.root, f"Before splitting at {key}")
            self.step_count = 0

        # Splay the key
        current = self.root
        last = None
        while current:
            last = current
            if key == current.key:
                break
            elif key < current.key:
                current = current.left
            else:
                current = current.right

        if current:
            self.splay(current, visualizer)
        else:
            self.splay(last, visualizer)
            current = self.root

        if current.key < key:
            left_tree = self.root
            right_tree = self.root.right
            if right_tree:
                right_tree.parent = None
            left_tree.right = None
        else:
            right_tree = self.root
            left_tree = self.root.left
            if left_tree:
                left_tree.parent = None
            right_tree.left = None

        if visualizer:
            visualizer.print_tree(left_tree, "Left subtree after split")
            visualizer.print_tree(right_tree, "Right subtree after split")

        return (left_tree, right_tree)

class TreeVisualizer:
    # ANSI color codes
    COLORS = {
        'reset': '\033[0m',
        'bold': '\033[1m',
        'red': '\033[91m',
        'green': '\033[92m',
        'yellow': '\033[93m',
        'blue': '\033[94m',
        'magenta': '\033[95m',
        'cyan': '\033[96m',
    }

    @staticmethod
    def get_tree_structure(root: Optional[Node], prefix: str = "", is_right: bool = False) -> list:
        if not root:
            return []

        lines = []
        new_prefix = prefix + ("    " if not is_right else "│   ")

        # Process right child first for better visualization
        if root.right:
            lines.extend(TreeVisualizer.get_tree_structure(root.right, new_prefix, True))

        # Current node
        lines.append(prefix + ("└── " if not is_right else "┌── ") + str(root.key))

        # Process left child
        if root.left:
            lines.extend(TreeVisualizer.get_tree_structure(root.left, new_prefix, False))

        return lines

    @staticmethod
    def print_tree(root: Optional[Node], title: str = "Tree Structure"):
        print("\n" + "="*60)
        print(f"{TreeVisualizer.COLORS['bold']}{TreeVisualizer.COLORS['blue']}{title}{TreeVisualizer.COLORS['reset']}")
        print("="*60)

        if not root:
            print(f"{TreeVisualizer.COLORS['red']}Empty Tree{TreeVisualizer.COLORS['reset']}")
            return

        lines = TreeVisualizer.get_tree_structure(root)
        max_value_length = max(len(str(line.split('── ')[-1])) for line in lines)
        
        for line in lines:
            prefix, value = line.rsplit('── ', 1)
            colored_value = TreeVisualizer.COLORS['cyan'] + value.ljust(max_value_length) + TreeVisualizer.COLORS['reset']
            print(prefix + '── ' + colored_value)

def demonstrate_operations():
    tree = SplayTree()
    visualizer = TreeVisualizer()

    # Helper function to pause between operations
    def pause():
        time.sleep(1)

    operations = [
        ("Initial empty tree", lambda: None),
        ("Insert 3", lambda: tree.insert(3, visualizer)),
        ("Insert 2", lambda: tree.insert(2, visualizer)),
        ("Insert 1", lambda: tree.insert(1, visualizer)),
        ("Insert 6", lambda: tree.insert(6, visualizer)),
        ("Insert 5", lambda: tree.insert(5, visualizer)),
        ("Insert 4", lambda: tree.insert(4, visualizer)),
        ("Insert 9", lambda: tree.insert(9, visualizer)),
        ("Insert 8", lambda: tree.insert(8, visualizer)),
        ("Insert 7", lambda: tree.insert(7, visualizer)),
        ("Lookup 3", lambda: tree.lookup(3, visualizer)),
        ("Delete 7", lambda: tree.delete(7, visualizer)),
        ("Insert 12", lambda: tree.insert(12, visualizer)),
        ("Insert 15", lambda: tree.insert(15, visualizer)),
        ("Insert 14", lambda: tree.insert(14, visualizer)),
        ("Insert 13", lambda: tree.insert(13, visualizer))
    ]

    # Perform all operations except split
    for description, operation in operations:
        operation()
        pause()

    # Handle split operation separately
    print(f"\n{TreeVisualizer.COLORS['bold']}{TreeVisualizer.COLORS['magenta']}Performing Split Operation at 8{TreeVisualizer.COLORS['reset']}")
    left, right = tree.split(8, visualizer)

if __name__ == "__main__":
    demonstrate_operations()



[1m[94mBefore inserting 3[0m
[91mEmpty Tree[0m

[1m[94mAfter inserting 3[0m
└── [96m3[0m

[1m[94mBefore inserting 2[0m
└── [96m3[0m

[1m[94mSplay Step 1: Zig rotation at node 2[0m
└── [96m3[0m
    └── [96m2[0m

[1m[94mAfter inserting 2[0m
    ┌── [96m3[0m
└── [96m2[0m

[1m[94mBefore inserting 1[0m
    ┌── [96m3[0m
└── [96m2[0m

[1m[94mSplay Step 1: Zig rotation at node 1[0m
    ┌── [96m3[0m
└── [96m2[0m
    └── [96m1[0m

[1m[94mAfter inserting 1[0m
    │   ┌── [96m3[0m
    ┌── [96m2[0m
└── [96m1[0m

[1m[94mBefore inserting 6[0m
    │   ┌── [96m3[0m
    ┌── [96m2[0m
└── [96m1[0m

[1m[94mSplay Step 1: Zig-zig rotation at node 6[0m
    │   │   ┌── [96m6[0m
    │   ┌── [96m3[0m
    ┌── [96m2[0m
└── [96m1[0m

[1m[94mSplay Step 2: Zig rotation at node 6[0m
    ┌── [96m6[0m
    │   └── [96m3[0m
    │       └── [96m2[0m
└── [96m1[0m

[1m[94mAfter inserting 6[0m
└── [96m6[0m
        ┌── [96m3[0m
       