# Generic Tree Class

In [17]:
class treeNode:
    """A node of a generic (n-ary) tree.

    ``data`` holds the node's value; ``children`` is the list of child
    nodes, which starts out empty and is filled in by the caller.
    """

    def __init__(self, data):
        self.data = data
        # Each node gets its own fresh list, so siblings never share children.
        self.children = []

# Print Tree (Recursively)

In [18]:
def printTree(root):
    """Print every node's data on its own line, in pre-order (node before its children).

    ``None`` (an empty tree, or a ``None`` entry in a children list) prints nothing.
    """
    pending = [root]
    while pending:
        node = pending.pop()
        if node is None:
            continue
        print(node.data)
        # Push children reversed so the first child is popped (printed) first,
        # reproducing the recursive pre-order traversal.
        pending.extend(reversed(node.children))

# Print tree (detailed)

In [19]:
def printTreeDetailed(root):
    """Print each node as ``data: child1,child2,...`` in pre-order.

    A leaf prints as ``data: `` (nothing after the colon-space). ``None``
    prints nothing. ``None`` entries in a children list are skipped rather
    than crashing on ``.data``.
    """
    if root is None:
        return

    children = [child for child in root.children if child is not None]
    print(root.data, end=": ")
    # join places commas only *between* entries, so the last child has none.
    print(",".join(str(child.data) for child in children))

    for child in children:
        printTreeDetailed(child)

# Take Tree input (Recursively)

In [20]:
def takeTreeInput():
    """Read a tree from stdin recursively (pre-order) and return its root.

    For each node the user enters its data (``-1`` means "no node") and then
    its number of children, followed by each child's subtree in turn.
    Returns ``None`` when the root data is ``-1``.
    """
    print("Enter root data")
    rootData = int(input())
    if rootData == -1:
        return None

    root = treeNode(rootData)
    print("Enter number of children for", rootData)
    childrenCount = int(input())
    for _ in range(childrenCount):
        child = takeTreeInput()
        # A child entered as -1 comes back as None; storing it would leave a
        # None placeholder that breaks code reading child.data.
        if child is not None:
            root.children.append(child)

    return root

# Number of nodes in tree

In [21]:
def numNodes(root):
    """Return the total number of nodes in the tree (0 for ``None``).

    ``None`` entries in a children list count as 0.
    """
    if root is None:
        return 0
    return 1 + sum(numNodes(child) for child in root.children)

# Sum of all nodes

In [22]:
def sumOfAllNodes(root):
    """Return the sum of ``data`` over every node in the tree (0 for ``None``).

    ``None`` entries in a children list contribute 0.
    """
    if root is None:
        return 0
    # Avoid a local named `sum`, which would shadow the builtin used here.
    return root.data + sum(sumOfAllNodes(child) for child in root.children)

# Node with largest data

In [23]:
def maxDataNode(root):
    """Return the node with the largest ``data`` in the tree, or ``None`` if empty.

    On ties the node found first in pre-order wins. ``None`` entries in a
    children list are ignored instead of raising AttributeError.
    """
    if root is None:
        return None
    maxNode = root
    for child in root.children:
        candidate = maxDataNode(child)
        # A None child yields None; it can never be the maximum.
        if candidate is not None and candidate.data > maxNode.data:
            maxNode = candidate
    return maxNode

# Height of Tree

In [24]:
def treeHeight(root):
    """Return the number of nodes on the longest root-to-leaf path.

    An empty tree (``None``) has height 0 and a single node has height 1.
    """
    if root is None:
        return 0
    tallest_child = max((treeHeight(child) for child in root.children), default=0)
    return tallest_child + 1

# Take tree input (level-wise)

In [25]:
import queue

def takeTreeInputLevelWise():
    """Read a tree from stdin in level (breadth-first) order and return its root.

    The user first enters the root data (``-1`` means an empty tree, so
    ``None`` is returned). Then, for each node in FIFO order, the user enters
    its child count followed by each child's data.
    """
    q = queue.Queue()
    print("Enter root")
    rootData = int(input())
    if rootData == -1:
        return None
    # The node class in this file is `treeNode`; `TreeNode` was an undefined
    # name and raised NameError on every call.
    root = treeNode(rootData)
    q.put(root)
    while not q.empty():
        current_node = q.get()
        print("Enter num of children for", current_node.data)
        numChildren = int(input())
        for _ in range(numChildren):
            print("Enter next child for", current_node.data)
            childData = int(input())
            child = treeNode(childData)
            current_node.children.append(child)
            # Enqueue so this child's own children are asked for on a later level.
            q.put(child)
    return root

# Print tree (level-wise)

In [26]:
import queue

def printLevelWiseTree(root):
    """Print each node as ``data:child1,child2,...`` in level (breadth-first) order.

    A leaf prints as ``data:``. ``None`` prints nothing. ``None`` entries in a
    children list are skipped rather than crashing on ``.data``.
    """
    if root is None:
        return

    q = queue.Queue()
    q.put(root)
    while not q.empty():
        currentNode = q.get()
        children = [child for child in currentNode.children if child is not None]
        print(currentNode.data, end=':')
        # join places commas only between entries, never after the last child.
        print(','.join(str(child.data) for child in children))
        for child in children:
            q.put(child)

# Contains x

In [27]:
def containsX(root, x):
    """Return True if any node in the tree has ``data == x``, else False.

    The search is depth-first and stops at the first match.
    """
    if root is None:
        return False
    if root.data == x:
        return True
    return any(containsX(subtree, x) for subtree in root.children)

# Count leaf nodes

In [28]:
def leafNodeCount(root):
    """Return the number of leaves (nodes with an empty children list).

    Returns 0 for ``None``.
    """
    if root is None:
        return 0
    if not root.children:
        return 1
    return sum(leafNodeCount(subtree) for subtree in root.children)

# Node with maximum child sum

In [29]:
def maxSumNode(root):
    """Find the node whose own data plus its direct children's data is largest.

    Returns ``None`` for an empty tree, otherwise a ``(node, sum)`` tuple.
    On ties the node found first in pre-order wins. ``None`` entries in a
    children list are ignored. The original raised on them, both when
    reading ``child.data`` and when unpacking the recursive ``None`` result.
    """
    if root is None:
        return None

    present = [child for child in root.children if child is not None]
    bestNode = root
    bestSum = root.data + sum(child.data for child in present)

    for child in present:
        childBestNode, childBestSum = maxSumNode(child)
        if childBestSum > bestSum:
            bestNode, bestSum = childBestNode, childBestSum

    return bestNode, bestSum

# Structurally identical

In [30]:
def isIdentical(tree1, tree2):
    """Return True if both trees have the same shape and the same data at every node.

    Two ``None`` trees are identical. Exactly one ``None`` tree is not (the
    original raised AttributeError when only ``tree2`` was ``None``).
    """
    if tree1 is None or tree2 is None:
        return tree1 is None and tree2 is None
    if tree1.data != tree2.data or len(tree1.children) != len(tree2.children):
        return False
    # Children counts are equal here, so zip pairs every child exactly once.
    return all(isIdentical(child1, child2)
               for child1, child2 in zip(tree1.children, tree2.children))

# Next larger element

In [31]:
def nextLargest(root, n):
    """Return the node with the smallest ``data`` strictly greater than ``n``.

    Returns ``None`` if no such node exists (or the tree is empty). On ties
    the node found first in pre-order is returned.
    """
    if root is None:
        return None

    candidates = [root] if root.data > n else []
    for subtree in root.children:
        best_in_subtree = nextLargest(subtree, n)
        if best_in_subtree:
            candidates.append(best_in_subtree)
    # min() keeps the first of equal keys, matching the pre-order tie rule.
    return min(candidates, key=lambda node: node.data, default=None)

# Replace with depth

In [32]:
def replacewithDepthHelper(root, depth=0):
    """Overwrite each node's ``data`` with its depth, starting at ``depth`` for ``root``.

    Nodes are visited in pre-order and ``root`` is returned. ``root`` must
    not be ``None``.
    """
    stack = [(root, depth)]
    while stack:
        node, level = stack.pop()
        node.data = level
        # Reversed push so the first child is processed next (pre-order).
        stack.extend((child, level + 1) for child in reversed(node.children))
    return root


def replacewithDepth(root):
    """Replace every node's data with its depth (root = 0); returns the root, or None if empty."""
    return None if root is None else replacewithDepthHelper(root, depth=0)

# Remove leaf nodes

In [33]:
def removeLeafNodes(root):
    """Remove every leaf from the tree in place and return the new root.

    If ``root`` itself is a leaf (or ``None``), the whole tree disappears and
    ``None`` is returned. Otherwise each node's ``children`` list is mutated
    in place, and the same ``root`` object is returned. ``None`` entries in a
    children list are dropped instead of raising AttributeError.
    """
    if root is None:
        return None

    if not root.children:
        return None

    # Keep only non-leaf children and strip their leaves recursively. A kept
    # child had children, so the recursive call never returns None for it.
    # Slice assignment preserves the identity of the existing children list.
    root.children[:] = [
        removeLeafNodes(child)
        for child in root.children
        if child is not None and child.children
    ]
    return root