### [Trim a binary search tree](https://leetcode.com/problems/trim-a-binary-search-tree/description/)

Given a binary search tree and the lowest and highest boundaries as L and R, trim the tree so that all its elements lie in [L, R] (R >= L). You might need to change the root of the tree, so the function should return the new root of the trimmed binary search tree.

Example 1:
```
Input: 
    1
   / \
  0   2

  L = 1
  R = 2

Output: 
    1
      \
       2
```

Example 2:
    
```
Input: 
    3
   / \
  0   4
   \
    2
   /
  1

  L = 1
  R = 3

Output: 
      3
     / 
   2   
  /
 1
```

In [2]:
# Definition for a binary tree node.
class TreeNode(object):
    """A binary tree node: a value plus (initially empty) left/right child links."""

    def __init__(self, x):
        # A freshly created node is a leaf: both child links start out as None.
        self.val, self.left, self.right = x, None, None


In [1]:
class Solution(object):
    def trimBST(self, root, L, R):
        """Trim a binary search tree in place so that every value lies in [L, R].

        Out-of-range nodes are unlinked; surviving nodes keep their relative
        ancestry. The returned node may differ from ``root``, because the old
        root itself can fall outside the range.

        Iterative, O(h) time (h = tree height) and O(1) extra space: only the
        search path to the new root and the two boundary spines are walked,
        so deep, skewed trees do not hit Python's recursion limit.

        :type root: TreeNode
        :type L: int
        :type R: int
        :rtype: TreeNode
        """
        # Design notes. The first attempt (trimBSTNaive below) re-inserted
        # every in-range value into a brand-new tree: O(n*h) time, i.e.
        # O(n^2) for a skewed tree, plus O(n) extra space. It also did not
        # trim the existing tree, which is what the problem asks for.
        #
        # BST facts used here:
        # * node.val < L: the node and its whole left subtree (all smaller)
        #   are out of range; only its right subtree can survive.
        # * node.val > R: the node and its whole right subtree (all larger)
        #   are out of range; only its left subtree can survive.
        # * once a node inside [L, R] is found, everything in its left
        #   subtree is < node.val <= R, so that side only needs trimming
        #   against L. Symmetrically, its right side only needs trimming
        #   against R.

        # 1. Walk down to the new root: the first node in [L, R] on the path.
        while root is not None and not (L <= root.val <= R):
            root = root.left if root.val > R else root.right
        if root is None:
            return None

        # 2. Trim the left spine against L.
        node = root
        while node is not None:
            # node.left < L: drop it with its left subtree and promote its
            # right subtree, which may itself still need trimming.
            while node.left is not None and node.left.val < L:
                node.left = node.left.right
            node = node.left

        # 3. Trim the right spine against R (mirror image of step 2).
        node = root
        while node is not None:
            while node.right is not None and node.right.val > R:
                node.right = node.right.left
            node = node.right

        return root

    def trimBSTNaive(self, root, L, R):
        """Return a NEW tree holding the values of ``root`` that lie in [L, R].

        The input tree is only read, never modified. The result is built in
        ``self.newroot`` by re-inserting each in-range value in pre-order,
        which costs O(n*h) time and O(n) extra space. Kept for comparison
        with trimBST.

        :type root: TreeNode
        :type L: int
        :type R: int
        :rtype: TreeNode
        """

        # edge cases
        if not root:
            return None

        # pre-order traversal
        # insert each node into a new tree if it is within the range of L and R
        self.newroot = None
        self.preOrder(root, L, R)

        return self.newroot

    def preOrder(self, node, L, R):
        """Pre-order walk copying every value in [L, R] into ``self.newroot``."""
        if not node:
            return

        if L <= node.val <= R:
            if not self.newroot:
                self.newroot = TreeNode(node.val)
            else:
                self.insertNode(self.newroot, node.val)

        self.preOrder(node.left, L, R)
        self.preOrder(node.right, L, R)

    def insertNode(self, root, val):
        """Insert ``val`` into the BST at ``root`` and return the subtree root.

        Values equal to a node's value go to its right subtree.
        """

        if not root:
            return TreeNode(val)

        if val < root.val:
            root.left = self.insertNode(root.left, val)
        else:
            root.right = self.insertNode(root.right, val)

        return root

In [13]:
# Helper cell to generate test trees.
# This will come in handy in future tree-based problems.

import random

def generateTestTree(numNodes):
    """Build a random binary search tree holding the values 0..numNodes-1.

    The values are inserted in a shuffled order, so the tree shape varies
    from call to call. Inserting in ascending order would always yield two
    degenerate chains hanging off the root.

    :param numNodes: number of nodes; the tree holds 0, 1, ..., numNodes - 1.
    :return: the root TreeNode, or None when numNodes <= 0.
    """

    def insertNode(root, val):
        # Iterative BST insert, so skewed trees cannot exhaust the recursion
        # limit. Smaller values go left, larger go right, duplicates are ignored.
        node = root
        while True:
            if val < node.val:
                if node.left is None:
                    node.left = TreeNode(val)
                    return
                node = node.left
            elif val > node.val:
                if node.right is None:
                    node.right = TreeNode(val)
                    return
                node = node.right
            else:
                return

    if numNodes <= 0:
        # random.randrange(0, 0) would raise; an empty tree is the natural answer.
        return None

    values = list(range(numNodes))
    random.shuffle(values)

    # The first shuffled value becomes the (random) root.
    root = TreeNode(values[0])
    for val in values[1:]:
        insertNode(root, val)

    return root

def inOrder(node):
    """Yield the values of the tree rooted at ``node`` in in-order sequence.

    For a binary search tree this is ascending order. An explicit stack
    replaces recursion, so very deep (e.g. chain-shaped) trees do not hit
    Python's recursion limit. Yields nothing when ``node`` is None.
    """
    stack = []
    while stack or node:
        # Descend as far left as possible, remembering the path.
        while node:
            stack.append(node)
            node = node.left
        node = stack.pop()
        yield node.val
        # Continue with the right subtree of the node just emitted.
        node = node.right
         


In [11]:
s = Solution()

assert s.trimBST(None, 1, 10) is None, "Expected None"

minNodes = 1
maxNodes = 20
numTestRoot = 10

numNodes = [random.randrange(minNodes, maxNodes) for _ in range(numTestRoot)]

for n in numNodes:
    # generateTestTree(n) holds exactly the values 0..n-1.
    root = generateTestTree(n)

    # Order the two random bounds in one step. `L = min(L, R)` followed by
    # `R = max(L, R)` would read the already-updated L and collapse the
    # range to a single value whenever the first draw was the larger one.
    L, R = sorted((random.randrange(minNodes, maxNodes),
                   random.randrange(minNodes, maxNodes)))

    newroot = s.trimBST(root, L, R)

    # Every in-range value must survive, nothing out of range may remain,
    # and the in-order walk must stay sorted (i.e. still a valid BST).
    # An empty result yields [] here.
    expected = [val for val in range(n) if L <= val <= R]
    assert list(inOrder(newroot)) == expected, \
        "Invalid trimming for L={}, R={}".format(L, R)
    