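# Given a binary tree, flatten it in place into a doubly linked list and return the leftmost node of the flattened tree (the solutions below return that node's value). After flattening, each node's right pointer should point to the next node in in-order traversal, and its left pointer to the previous node.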

In [1]:
class BinaryTree:
    """Binary tree node: a payload ``value`` plus ``left``/``right`` child links."""

    def __init__(self, value):
        self.value = value
        # A new node starts as a leaf; children are attached by assigning
        # to ``left`` / ``right`` afterwards.
        self.left = self.right = None

In [2]:
# Build the example tree level by level (see the diagram below).
a = BinaryTree(1)
a.left, a.right = BinaryTree(2), BinaryTree(3)
a.left.left, a.left.right = BinaryTree(4), BinaryTree(5)
a.right.left = BinaryTree(6)
a.left.right.left, a.left.right.right = BinaryTree(7), BinaryTree(8)

In [None]:
"""
        1
       / \
      2   3
     / \ / 
    4  5 6  
      / \  
      7  8   
"""

In [3]:
# Naive solution: O(n) time / O(n) space.

def flattenBinaryTree_1(root):
    """Flatten ``root`` in place into an in-order doubly linked list.

    Every node's ``right`` pointer is rewired to its in-order successor and
    its ``left`` pointer to its in-order predecessor.

    Args:
        root: Root node (with ``value``, ``left`` and ``right`` attributes),
            or ``None`` for an empty tree.

    Returns:
        The ``value`` of the leftmost node (head of the flattened list), or
        ``None`` when ``root`` is ``None``.
    """
    # Collect all nodes before rewiring anything, so the traversal only ever
    # reads the original child links.
    inOrderNodes = getNodesInOrder(root, [])
    if not inOrderNodes:
        # Empty tree: nothing to link and no leftmost node.
        return None

    # Link each adjacent pair in both directions. The first node's ``left``
    # and the last node's ``right`` are already None: the first in-order node
    # has no left child and the last one has no right child.
    for leftNode, rightNode in zip(inOrderNodes, inOrderNodes[1:]):
        leftNode.right = rightNode
        rightNode.left = leftNode

    return inOrderNodes[0].value

def getNodesInOrder(tree, array):
    """Append the nodes of ``tree`` to ``array`` in in-order sequence.

    Uses an explicit stack instead of recursion so that degenerate (very
    deep) trees do not hit Python's recursion limit.

    Args:
        tree: Root node of the (sub)tree to traverse, or ``None``.
        array: List that receives the nodes; it is mutated and returned.

    Returns:
        ``array``, with the in-order nodes appended.
    """
    stack = []
    current = tree
    while stack or current is not None:
        # Descend as far left as possible, remembering the path back up.
        while current is not None:
            stack.append(current)
            current = current.left
        current = stack.pop()
        array.append(current)
        current = current.right
    return array

In [4]:
flattenBinaryTree_1(root=a)  # displays the leftmost value of the flattened tree

4

In [5]:
class BinaryTree:
    """Binary tree node: a payload ``value`` plus ``left``/``right`` child links."""

    def __init__(self, value):
        self.value = value
        # A new node starts as a leaf; children are attached by assigning
        # to ``left`` / ``right`` afterwards.
        self.left = self.right = None

In [6]:
# Rebuild a fresh copy of the example tree (the earlier cell flattened the
# previous one in place).
a = BinaryTree(1)
a.left, a.right = BinaryTree(2), BinaryTree(3)
a.left.left, a.left.right = BinaryTree(4), BinaryTree(5)
a.right.left = BinaryTree(6)
a.left.right.left, a.left.right.right = BinaryTree(7), BinaryTree(8)

In [None]:
"""
        1
       / \
      2   3
     / \ / 
    4  5 6  
      / \  
      7  8   
"""

In [7]:
# Optimal solution: O(n) time / O(d) space, where d is the depth of the tree (recursion stack).

def flattenBinaryTree_2(root):
    """Flatten ``root`` in place into an in-order doubly linked list.

    Every node's ``right`` pointer is rewired to its in-order successor and
    its ``left`` pointer to its in-order predecessor, without building an
    auxiliary list of nodes.

    Args:
        root: Root node (with ``value``, ``left`` and ``right`` attributes),
            or ``None`` for an empty tree.

    Returns:
        The ``value`` of the leftmost node (head of the flattened list), or
        ``None`` when ``root`` is ``None``.
    """
    if root is None:
        # flattenTree requires a real node; an empty tree has no head.
        return None
    leftMost, _ = flattenTree(root)

    return leftMost.value

def flattenTree(node):
    """Flatten the subtree rooted at ``node`` (must not be ``None``) in place.

    Returns:
        ``[leftMost, rightMost]``: the first and last nodes of the flattened
        subtree in in-order sequence.
    """
    if node.left is None:
        leftMost = node
    else:
        leftSubtreeLeftMost, leftSubtreeRightMost = flattenTree(node.left)
        # The left subtree's last node is this node's in-order predecessor.
        connectNodes(leftSubtreeRightMost, node)
        leftMost = leftSubtreeLeftMost

    # node.right is untouched by the work above (connectNodes only rewrote
    # node.left), so it still refers to the original right child.
    if node.right is None:
        rightMost = node
    else:
        rightSubtreeLeftMost, rightSubtreeRightMost = flattenTree(node.right)
        # The right subtree's first node is this node's in-order successor.
        connectNodes(node, rightSubtreeLeftMost)
        rightMost = rightSubtreeRightMost

    return [leftMost, rightMost]

def connectNodes(left, right):
    """Link ``left`` -> ``right`` forward and ``right`` -> ``left`` backward."""
    left.right = right
    right.left = left

In [8]:
flattenBinaryTree_2(root=a)  # displays the leftmost value of the flattened tree

4