Problem Statement. <br/>

You are given the root of a binary tree in which every node has a distinct value, and a list `to_delete` of node values. <br/>
After deleting all nodes whose value appears in `to_delete`, we are left with a forest (a disjoint union of trees). <br/>
Return the roots of the trees in the remaining forest. You may return the result in any order. <br/>

Example 1: <br/>
Input: root = [1,2,3,4,5,6,7], to_delete = [3,5] <br/>
Output: [[1,2,null,4],[6],[7]]

# BFS - O(N) runtime, O(N + D) space, where N is the number of nodes and D = len(to_delete)

In [2]:
from collections import deque
from typing import List

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right

class Solution:
    def delNodes(self, root: TreeNode, to_delete: List[int]) -> List[TreeNode]:
        """Delete every node whose value is in ``to_delete``; return the forest roots.

        Breadth-first traversal. When a node's child is scheduled for deletion,
        that child's own children become candidate forest roots and the
        parent's link to the child is cut. The input tree is modified in place.

        Args:
            root: Root of the binary tree; ``None`` yields an empty forest.
            to_delete: Values of the nodes to remove.

        Returns:
            The roots of the remaining trees, in no particular order.
        """
        if root is None:
            return []

        to_delete_set = set(to_delete)
        res = set()

        # The original root has no parent to inspect it, so decide its fate here.
        if root.val in to_delete_set:
            if root.left:
                res.add(root.left)
            if root.right:
                res.add(root.right)
        else:
            res.add(root)

        queue = deque([root])
        while queue:
            node = queue.popleft()
            if node.left:
                # Enqueue before detaching so the deleted child's subtree is still visited.
                queue.append(node.left)
                if node.left.val in to_delete_set:
                    if node.left.left:
                        res.add(node.left.left)
                    if node.left.right:
                        res.add(node.left.right)
                    node.left = None
            if node.right:
                queue.append(node.right)
                if node.right.val in to_delete_set:
                    if node.right.left:
                        res.add(node.right.left)
                    if node.right.right:
                        res.add(node.right.right)
                    node.right = None

            # A deleted node may have been promoted above because its parent was
            # deleted too. That promotion happens while its grandparent is
            # processed, i.e. before this node is dequeued, so drop it now.
            if node.val in to_delete_set:
                res.discard(node)

        return list(res)

# DFS Stack - O(N) runtime, O(N + D) space, where N is the number of nodes and D = len(to_delete)

In [4]:
from typing import List

#Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right

class Solution:
    def delNodes(self, root: TreeNode, to_delete: List[int]) -> List[TreeNode]:
        """Delete every node whose value is in ``to_delete``; return the forest roots.

        Depth-first traversal driven by an explicit LIFO stack (a plain list).
        When a node's child is scheduled for deletion, that child's own
        children become candidate forest roots and the parent's link to the
        child is cut. The input tree is modified in place.

        Args:
            root: Root of the binary tree; ``None`` yields an empty forest.
            to_delete: Values of the nodes to remove.

        Returns:
            The roots of the remaining trees, in no particular order.
        """
        if root is None:
            return []

        to_delete_set = set(to_delete)
        res = set()

        # The original root has no parent to inspect it, so decide its fate here.
        if root.val in to_delete_set:
            if root.left:
                res.add(root.left)
            if root.right:
                res.add(root.right)
        else:
            res.add(root)

        # A list is a perfectly good stack (append/pop at the end are O(1)).
        stack = [root]
        while stack:
            node = stack.pop()
            if node.left:
                # Push before detaching so the deleted child's subtree is still visited.
                stack.append(node.left)
                if node.left.val in to_delete_set:
                    if node.left.left:
                        res.add(node.left.left)
                    if node.left.right:
                        res.add(node.left.right)
                    node.left = None
            if node.right:
                stack.append(node.right)
                if node.right.val in to_delete_set:
                    if node.right.left:
                        res.add(node.right.left)
                    if node.right.right:
                        res.add(node.right.right)
                    node.right = None

            # A deleted node may have been promoted above because its parent was
            # deleted too. That promotion happens while its grandparent is
            # processed, i.e. before this node is popped, so drop it now.
            if node.val in to_delete_set:
                res.discard(node)

        return list(res)

# DFS Recursive - O(N) runtime, O(N + D) space, where N is the number of nodes and D = len(to_delete)

In [6]:
from typing import List, Set

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right

class Solution:
    def delNodes(self, root: TreeNode, to_delete: List[int]) -> List[TreeNode]:
        """Delete every node whose value is in ``to_delete``; return the forest roots.

        The input tree is modified in place.

        Args:
            root: Root of the binary tree; ``None`` yields an empty forest.
            to_delete: Values of the nodes to remove.

        Returns:
            The roots of the remaining trees, in no particular order.
        """
        if root is None:
            return []

        res = set()
        to_delete_set = set(to_delete)

        # The original root has no parent to inspect it, so decide its fate here.
        if root.val in to_delete_set:
            if root.left:
                res.add(root.left)
            if root.right:
                res.add(root.right)
        else:
            res.add(root)

        self.delNodesRecursive(root, res, to_delete_set)

        return list(res)

    def delNodesRecursive(self, node: TreeNode, res: Set[TreeNode], to_delete_set: Set[int]) -> None:
        """Post-order pass that cuts deleted children and promotes their kids.

        Mutates ``res`` (the set of forest roots) and the tree in place;
        returns nothing.

        Args:
            node: Non-None node whose subtree is processed.
            res: Accumulated forest roots.
            to_delete_set: Values of the nodes to remove.
        """
        if node.left:
            # Recurse first: by the time we inspect node.left, its own deleted
            # children are already detached, so only surviving grandchildren
            # get promoted below.
            self.delNodesRecursive(node.left, res, to_delete_set)
            if node.left.val in to_delete_set:
                if node.left.left:
                    res.add(node.left.left)
                if node.left.right:
                    res.add(node.left.right)
                node.left = None

        if node.right:
            self.delNodesRecursive(node.right, res, to_delete_set)
            if node.right.val in to_delete_set:
                if node.right.left:
                    res.add(node.right.left)
                if node.right.right:
                    res.add(node.right.right)
                node.right = None

        # A deleted node can still be in res if it was promoted up front as a
        # child of a deleted original root; remove it.
        if node.val in to_delete_set:
            res.discard(node)

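# Cleaner DFS Recursive - O(N) runtime, O(N + D) space for the output and the lookup set plus O(H) recursion stack, where N is the number of nodes, D = len(to_delete) and H is the height of the tree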

In [7]:
from typing import List

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right

class Solution:
    def delNodes(self, root: TreeNode, to_delete: List[int]) -> List[TreeNode]:
        """Remove the nodes whose values are listed in ``to_delete``.

        One recursive pre-order walk decides, for every node, whether it
        survives and whether it starts a new tree (it survives and its parent
        did not, or it is the original root). Child pointers are rewritten in
        place so that deleted nodes are unlinked from their surviving parents.

        Returns:
            Surviving tree roots, in pre-order of discovery.
        """
        forest = []
        doomed = set(to_delete)

        def prune(node, parent_gone):
            # The return value is what the caller should store as its child pointer.
            if not node:
                return None
            gone = node.val in doomed
            if parent_gone and not gone:
                forest.append(node)
            node.left = prune(node.left, gone)
            node.right = prune(node.right, gone)
            if gone:
                return None
            return node

        # Treat the root as if it hung under a deleted parent: if it survives, it is a root.
        prune(root, True)
        return forest