## Problem: Flatten Binary Tree

Write a function that takes in a Binary Tree, flattens it, and returns its leftmost node.
A flattened Binary Tree is a structure that's nearly identical to a Doubly Linked List (except that nodes
have `left` and `right` pointers instead of `prev` and `next` pointers), where nodes follow the
original tree's left-to-right (in-order) order.
Note that if the input Binary Tree happens to be a valid Binary Search Tree, the nodes in the flattened
tree will be sorted.
The flattening should be done in place, meaning that the original data structure should be mutated (no
new structure should be created).
Each `BinaryTree` node has an integer `value`, a `left` child node, and a `right` child node.
Child nodes can either be `BinaryTree` nodes themselves or `None` / `null`.
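For a concrete picture of the target structure, here is a minimal sketch that flattens a three-node tree by hand, changing only `left`/`right` pointers and creating no new nodes. The `Node` class is a hypothetical stand-in for the `BinaryTree` class given later. Because this input is a valid BST, the resulting list comes out sorted.

```python
class Node:
    # Hypothetical stand-in for the BinaryTree class.
    def __init__(self, value, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right

# A three-node BST:   2
#                    / \
#                   1   3
one, three = Node(1), Node(3)
two = Node(2, one, three)

# Flatten in place: `left` now acts as `prev`, `right` as `next`.
one.left, one.right = None, two      # 1 is the head (leftmost node)
two.left, two.right = one, three     # already true, restated for clarity
three.left, three.right = two, None  # 3 is the tail

head = one
values = []
node = head
while node:
    values.append(node.value)
    node = node.right
print(values)  # [1, 2, 3]
```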

Example:

```
tree =
             1
           /  \
          2    3
         / \  /
        4   5 6
           / \
          7   8
```

Output:

```
4 <-> 2 <-> 7 <-> 5 <-> 8 <-> 1 <-> 6 <-> 3 // the leftmost node with value 4
```
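The expected output is exactly the in-order (left subtree, node, right subtree) traversal of the input tree. A short, non-destructive check of that claim, assuming a minimal hypothetical `Node` class in place of `BinaryTree`:

```python
class Node:
    # Hypothetical stand-in for the BinaryTree class.
    def __init__(self, value, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right

def in_order_values(node):
    # Visit the left subtree, then the node itself, then the right subtree.
    if node is None:
        return []
    return in_order_values(node.left) + [node.value] + in_order_values(node.right)

# The example tree from above.
tree = Node(1,
            Node(2, Node(4), Node(5, Node(7), Node(8))),
            Node(3, Node(6)))
print(in_order_values(tree))  # [4, 2, 7, 5, 8, 1, 6, 3]
```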


```python
# This is the class of the input root. Do not edit it.
class BinaryTree:
    def __init__(self, value, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right
```

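```python
## Approach 1: each call returns the (leftmost, rightmost) nodes of its flattened subtree
def flattenBinaryTree(root):
    if root is None:
        return None
    leftMost, _ = helper(root)
    return leftMost

def helper(root):
    # Flatten the subtree rooted at `root`; return its (leftmost, rightmost) nodes.
    if root.left is None:
        # No left subtree: root itself is the leftmost node.
        leftMost = root
    else:
        # Flatten the left subtree, then link its rightmost node to root.
        leftside, rightside = helper(root.left)
        rightside.right = root
        root.left = rightside
        leftMost = leftside
    if root.right is None:
        # No right subtree: root itself is the rightmost node.
        rightMost = root
    else:
        # Flatten the right subtree, then link root to its leftmost node.
        leftside, rightside = helper(root.right)
        leftside.left = root
        root.right = leftside
        rightMost = rightside
    return (leftMost, rightMost)
```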

```python
# Print the flattened list from head to tail by following `right` pointers.
def printDLL(head):
    while head:
        print(head.value)
        head = head.right
```
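`printDLL` only follows `right` pointers, so it would not catch a missing or wrong `left` link. A slightly stricter check (a sketch; `dll_values_both_ways` and `Node` are hypothetical names) walks forward to the tail, verifying that each `left` pointer mirrors the preceding `right` pointer, and then walks back to the head:

```python
class Node:
    # Hypothetical stand-in for the BinaryTree class.
    def __init__(self, value, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right

def dll_values_both_ways(head):
    """Return (forward, backward) value lists; raise if left/right links disagree."""
    if head is not None and head.left is not None:
        raise ValueError("head node has a left neighbour")
    forward = []
    node, tail = head, None
    while node:
        # The next node must point back at the current one.
        if node.right is not None and node.right.left is not node:
            raise ValueError(f"broken left link after node {node.value}")
        forward.append(node.value)
        tail = node
        node = node.right
    backward = []
    node = tail
    while node:
        backward.append(node.value)
        node = node.left
    return forward, backward

# Hand-built list 1 <-> 2 <-> 3 for demonstration.
a, b, c = Node(1), Node(2), Node(3)
a.right, b.left, b.right, c.left = b, a, c, b
print(dll_values_both_ways(a))  # ([1, 2, 3], [3, 2, 1])
```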
        

```python
root = BinaryTree(1,
                  BinaryTree(2, BinaryTree(4), BinaryTree(5, BinaryTree(7), BinaryTree(8))),
                  BinaryTree(3, BinaryTree(6)))
head = flattenBinaryTree(root)
printDLL(head)
```

```
4
2
7
5
8
1
6
3
```


```python
## Approach 2: in-order traversal that links each node to the previously visited one
def flattenBinaryTree1(root):
    if not root:
        return None
    # Single-element lists act as mutable references shared across recursive calls.
    prev = [None]  # last node visited so far
    head = [None]  # first (leftmost) node visited
    flattenTree(root, prev, head)
    return head[0]

# Helper function to perform in-order traversal and flatten the tree
def flattenTree(node, prev, head):
    if not node:
        return

    # Flatten the left subtree
    flattenTree(node.left, prev, head)

    # Link the previously visited node and the current node in both directions
    if prev[0] is not None:
        prev[0].right = node
        node.left = prev[0]
    else:
        # No node visited yet: this is the leftmost node, i.e. the head of the list
        head[0] = node

    prev[0] = node

    # Flatten the right subtree
    flattenTree(node.right, prev, head)
```
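Both approaches recurse once per tree level, so a very deep (for example, fully skewed) tree can exceed CPython's default recursion limit, which is typically 1000. As a hedged alternative that is not one of the original solutions, the same in-order linking as Approach 2 can be done with an explicit stack; `flattenBinaryTreeIterative` and `Node` are hypothetical names.

```python
class Node:
    # Hypothetical stand-in for the BinaryTree class.
    def __init__(self, value, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right

def flattenBinaryTreeIterative(root):
    """In-order traversal with an explicit stack; relinks nodes in place."""
    stack = []
    prev = None   # last node appended to the list
    head = None   # leftmost node, returned at the end
    node = root
    while stack or node:
        # Descend as far left as possible, remembering the path.
        while node:
            stack.append(node)
            node = node.left
        node = stack.pop()
        # Save the right child before this node becomes `prev` and is relinked.
        nxt = node.right
        if prev is None:
            head = node
        else:
            prev.right = node
        node.left = prev
        prev = node
        node = nxt
    return head

tree = Node(1, Node(2, Node(4), Node(5, Node(7), Node(8))), Node(3, Node(6)))
node, out = flattenBinaryTreeIterative(tree), []
while node:
    out.append(node.value)
    node = node.right
print(out)  # [4, 2, 7, 5, 8, 1, 6, 3]

# A left-skewed chain deeper than the default recursion limit.
deep = None
for v in range(5000):
    deep = Node(v, left=deep)
head = flattenBinaryTreeIterative(deep)
print(head.value)  # 0
```

The explicit stack still holds up to one entry per level, so memory use matches the recursive versions; it only avoids Python's call-stack limit.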


```python
root = BinaryTree(1,
                  BinaryTree(2, BinaryTree(4), BinaryTree(5, BinaryTree(7), BinaryTree(8))),
                  BinaryTree(3, BinaryTree(6)))
head = flattenBinaryTree1(root)
printDLL(head)
```

```
4
2
7
5
8
1
6
3
```
