In [None]:
# init
import unittest
from bst import Node
from bst import bst

# Finding the k-th Smallest Element in a BST


In [None]:
class Node:
    """A binary tree node: a payload ``val`` plus optional ``left``/``right`` children.

    Note: this definition shadows the ``Node`` imported from ``bst`` at the
    top of the file for all code that follows it.
    """

    def __init__(self, val, left=None, right=None):
        # Children default to None, so a node built from a value alone is a leaf.
        self.val, self.left, self.right = val, left, right


def kth_smallest(root, k):
    """Return the k-th smallest value (1-based) stored in a BST.

    Performs an iterative in-order traversal and stops as soon as the k-th
    node has been visited, so only the first k nodes in sorted order (plus
    the left spine leading to them) are touched.

    :param root: root node (object with ``val``, ``left``, ``right``) or None
    :param k: 1-based rank of the value to return
    :return: the ``val`` of the k-th smallest node
    :raises IndexError: if ``k < 1`` or the tree has fewer than ``k`` nodes
    """
    # Previously an invalid k fell through the loop with the cursor at None
    # and crashed with an AttributeError on ``.val``.
    if k < 1:
        raise IndexError("k must be >= 1")
    stack = []
    node = root
    remaining = k
    while node is not None or stack:
        # Descend as far left as possible, remembering the path.
        while node is not None:
            stack.append(node)
            node = node.left
        node = stack.pop()
        remaining -= 1
        if remaining == 0:
            return node.val
        node = node.right
    raise IndexError("k exceeds the number of nodes in the tree")


class Solution(object):
    """Recursive variant of ``kth_smallest`` that materializes the in-order list."""

    def kth_smallest(self, root, k):
        """
        Return the k-th smallest value (1-based) in a BST.

        Collects every value with a recursive in-order walk, then indexes
        into the resulting sorted list.

        :type root: node with ``val``/``left``/``right`` attributes, or None
        :type k: int
        :rtype: the type of the values stored in the tree
        :raises IndexError: if k < 1 or k exceeds the number of nodes
        """
        # Without this guard, k <= 0 becomes a negative list index and
        # silently returns one of the *largest* values instead of failing.
        if k < 1:
            raise IndexError("k must be >= 1")
        count = []
        self.helper(root, count)
        # List indexing raises IndexError by itself when k > len(count).
        return count[k - 1]

    def helper(self, node, count):
        """Append the values of the subtree rooted at ``node`` to ``count`` in order."""
        if not node:
            return

        self.helper(node.left, count)
        count.append(node.val)
        self.helper(node.right, count)

if __name__ == '__main__':
    # Sample BST (in-order: 25 50 75 100 125 150 175):
    #
    #            100
    #          /     \
    #        50       150
    #       /  \     /   \
    #      25   75  125   175
    sample = Node(
        100,
        Node(50, Node(25), Node(75)),
        Node(150, Node(125), Node(175)),
    )
    # Both implementations print the 2nd smallest value.
    for finder in (kth_smallest, Solution().kth_smallest):
        print(finder(sample, 2))

# Finding the Lowest Common Ancestor (LCA) in a BST


In [None]:
"""
Given a binary search tree (BST),
find the lowest common ancestor (LCA) of two given nodes in the BST.

According to the definition of LCA on Wikipedia:
    “The lowest common ancestor is defined between two
    nodes v and w as the lowest node in T that has both v and w
    as descendants (where we allow a node to be a descendant of itself).”

        _______6______
       /              \
    ___2__          ___8__
   /      \        /      \
   0      _4       7       9
         /  \
         3   5

For example, the lowest common ancestor (LCA) of nodes 2 and 8 is 6.
Another example is LCA of nodes 2 and 4 is 2,
since a node can be a descendant of itself according to the LCA definition.
"""

In [None]:
def lowest_common_ancestor(root, p, q):
    """
    Return the lowest common ancestor of ``p`` and ``q`` in a BST.

    Walks down from ``root``: while both target values lie strictly on the
    same side of the current node, move to that side; the first node where
    they split (or where one of them equals the node) is the LCA. Returns
    None if the walk falls off the tree.

    :type root: Node
    :type p: Node
    :type q: Node
    :rtype: Node
    """
    node = root
    while node:
        if p.val > node.val and q.val > node.val:
            # Both targets are larger: the LCA is in the right subtree.
            node = node.right
            continue
        if p.val < node.val and q.val < node.val:
            # Both targets are smaller: the LCA is in the left subtree.
            node = node.left
            continue
        return node
    return None


# Counting the Number of Empty Branches in a BT


In [None]:
"""
Write a function num_empty that returns the number of empty branches in a
tree. The function should count the total number of empty branches among the
nodes of the tree. A leaf node has two empty branches. If root is None, the
tree is considered to have exactly 1 empty branch.
For example, the following tree has 10 empty branches (* is an empty branch):

                    9 __
                 /      \___
               6            12
              / \          /   \
            3     8       10      15
          /  \   / \     /  \    /   \
         *    * 7   *   *    *  *    18
               / \                   /  \
              *   *                 *    *

    empty_branch = 10

"""

In [None]:
def num_empty(root):
    """Count the empty (None) child branches in a binary tree.

    Every missing child of a node counts as one empty branch, so a leaf
    contributes two. An empty tree (``root is None``) counts as one.

    An explicit stack is used instead of recursion so that deep, degenerate
    trees (e.g. a BST built from sorted input) do not exceed Python's
    recursion limit.

    :param root: root node (object with ``left``/``right``) or None
    :return: number of empty branches
    """
    empty = 0
    stack = [root]
    while stack:
        node = stack.pop()
        if node is None:
            # Each None reached is exactly one empty branch (or the empty tree).
            empty += 1
        else:
            stack.append(node.left)
            stack.append(node.right)
    return empty

"""
    The tree is created for testing:

                    9
                 /      \
               6         12
              / \       /   \
            3     8   10      15
                 /              \
                7                18

    num_empty = 10

"""

class TestSuite(unittest.TestCase):
    """Checks ``num_empty`` on a BST built from a fixed insertion sequence.

    Assuming ``bst.insert`` is a standard BST insert, inserting
    9, 6, 12, 3, 8, 10, 15, 7, 18 in that order produces the tree sketched
    in the string above, which has 10 empty branches.
    """

    def setUp(self):
        # Insertion order determines the tree's shape, so keep it fixed.
        self.tree = bst()
        for value in (9, 6, 12, 3, 8, 10, 15, 7, 18):
            self.tree.insert(value)

    def test_num_empty(self):
        expected = 10
        self.assertEqual(expected, num_empty(self.tree.root))

if __name__ == '__main__':
    # This file is a notebook export (see the "In [...]" cell markers). In a
    # Jupyter kernel, sys.argv typically holds the kernel's own launch
    # arguments, which unittest.main() would try to parse. It would also call
    # sys.exit() when done. Supplying a dummy argv and exit=False lets the
    # suite run in place without aborting the cell.
    unittest.main(argv=['first-arg-is-ignored'], exit=False)

# Predecessor 

In [None]:
# https://en.wikipedia.org/wiki/Predecessor_problem

In [None]:
def predecessor(root, node):
    """Return the in-order predecessor of ``node`` in the BST rooted at ``root``.

    The predecessor is the node holding the largest value strictly less than
    ``node.val``; None is returned when no such node exists. Only
    ``node.val`` is consulted, so ``node`` does not have to be in the tree.
    """
    best = None
    current = root
    while current:
        if node.val > current.val:
            # current is a candidate; a closer one can only be to its right.
            best, current = current, current.right
        else:
            current = current.left
    return best

# Serialize and Deserialize a Binary Tree

In [None]:
class TreeNode(object):
    """Binary tree node used by ``deserialize``: a value ``val`` and two child links."""

    def __init__(self, x):
        # A new node starts as a leaf; callers attach children afterwards.
        self.val, self.left, self.right = x, None, None

def serialize(root):
    """
    Serialize a binary tree to a string.

    Nodes are written in pre-order (node, left subtree, right subtree) as
    space-separated tokens, and each missing child is written as ``"#"`` so
    the shape can be recovered. For example, a root 1 with leaf children 2
    and 3 becomes ``"1 2 # # 3 # #"``.

    An explicit stack is used instead of recursion so that deep, degenerate
    trees (e.g. a BST built from sorted input) do not exceed Python's
    recursion limit.

    :param root: The root node of the binary tree (or None)
    :return: A string representing the serialized binary tree
    """
    vals = []
    stack = [root]
    while stack:
        node = stack.pop()
        if node is None:
            # Placeholder for an absent child (or an empty tree).
            vals.append("#")
            continue
        vals.append(str(node.val))
        # Push right before left so the left subtree is emitted first (LIFO).
        stack.append(node.right)
        stack.append(node.left)
    return " ".join(vals)

def deserialize(data):
    """
    Deserialize a string back to a binary tree.

    ``data`` is whitespace-separated pre-order tokens (the format written by
    ``serialize``). ``"#"`` marks an absent child, and every other token is
    parsed with ``int``. Tokens left over after the tree is complete are
    ignored.

    The tree is rebuilt with an explicit stack instead of recursion so that
    deep, degenerate trees do not exceed Python's recursion limit.

    :param data: A string representing the serialized binary tree
    :return: The root node of the deserialized binary tree (None for "#")
    :raises ValueError: if ``data`` is empty, ends before the tree is
        complete, or contains a token that is neither "#" nor an integer
    """
    tokens = iter(data.split())

    def make_node(token):
        # "#" is the placeholder for a missing child.
        return None if token == "#" else TreeNode(int(token))

    try:
        root = make_node(next(tokens))
    except StopIteration:
        raise ValueError("cannot deserialize an empty string") from None

    # Each entry is [node, side]: side 0 means the node's left child is read
    # next, side 1 means its right child is. In pre-order, the top of the
    # stack is always the node whose child the next token describes.
    pending = [] if root is None else [[root, 0]]
    while pending:
        try:
            token = next(tokens)
        except StopIteration:
            raise ValueError("serialized tree is truncated") from None
        child = make_node(token)
        entry = pending[-1]
        if entry[1] == 0:
            entry[0].left = child
            entry[1] = 1
        else:
            entry[0].right = child
            pending.pop()  # both children of this node are now assigned
        if child is not None:
            pending.append([child, 0])
    return root


# Find the in-order successor of a given node in a BST

In [None]:
# Predecessor problem
# https://en.wikipedia.org/wiki/Predecessor_problem

In [None]:
# https://www.geeksforgeeks.org/inorder-successor-in-binary-search-tree/

In [None]:
# Iterative in-order successor search: a single walk down from the root.

def successor(root, node):
    """Return the in-order successor of ``node`` in the BST rooted at ``root``.

    The successor is the node holding the smallest value strictly greater
    than ``node.val``; None is returned when no such node exists. Only
    ``node.val`` is consulted, so ``node`` does not have to be in the tree.
    """
    best = None
    current = root
    while current:
        if node.val < current.val:
            # current is a candidate; a closer one can only be to its left.
            best, current = current, current.left
        else:
            current = current.right
    return best

## Calculate the number of structurally unique BSTs that can be formed with values from 1 to n


In [None]:
"""
Given n, how many structurally unique BSTs
(binary search trees) can store the values 1...n?

For example,
given n = 3, there are a total of 5 unique BSTs.

   1         3     3      2      1
    \       /     /      / \      \
     3     2     1      1   3      2
    /     /       \                 \
   2     1         2                 3
"""

In [None]:
"""
Taking 1~n as root respectively:
1 as root: # of trees = F(0) * F(n-1)  // F(0) == 1
2 as root: # of trees = F(1) * F(n-2)
3 as root: # of trees = F(2) * F(n-3)
...
n-1 as root: # of trees = F(n-2) * F(1)
n as root:   # of trees = F(n-1) * F(0)

So, the formulation is:
F(n) = F(0) * F(n-1) + F(1) * F(n-2) + F(2) * F(n-3) + ... + F(n-2) * F(1) + F(n-1) * F(0)
"""

In [None]:
def num_trees(n):
    """
    Count the structurally unique BSTs that store the values 1..n.

    Uses the recurrence F(m) = sum over r = 1..m of F(r-1) * F(m-r): choosing
    r as the root leaves r-1 values for the left subtree and m-r for the
    right. F(0) = 1 counts the single empty tree. The result is the n-th
    Catalan number.

    :type n: int
    :rtype: int
    :raises ValueError: if n is negative
    """
    if n < 0:
        raise ValueError("n must be non-negative")
    # dp[m] = number of unique BSTs on m distinct values; dp[0] = 1 (empty tree).
    dp = [1] + [0] * n
    for size in range(1, n + 1):
        dp[size] = sum(dp[root - 1] * dp[size - root] for root in range(1, size + 1))
    return dp[n]