In [1]:
import collections

# Capacity of every per-node array below; node labels must be < MAXN.
MAXN = 1025

# Adjacency lists of the input tree: node -> list of neighbours.
tree = collections.defaultdict(list)
# Not populated by any code visible in this cell.
centroidTree = collections.defaultdict(list)
# centroidMarked[v] becomes True once v has been chosen as a centroid;
# marked nodes are treated as removed by the traversals.
centroidMarked = [False for _ in range(MAXN)]

# method to add edge between to nodes of the undirected tree 
def addEdge(u, v):
	tree[u].append(v)
	tree[v].append(u)

def DFS(src, visited, subtree_size, n):
    """Set up subtree sizes and count the nodes of the component at `src`.

    The component is rooted at `src`. Neighbours already flagged in
    `visited` or in the global `centroidMarked` are not entered. An explicit
    stack is used instead of recursion. Python's default recursion limit is
    about 1000 frames, so a recursive walk fails on deep components such as
    a path of ~1024 nodes, which fits within MAXN.

    Args:
        src: node to start from; it is the root for `subtree_size`.
        visited: per-node flags; every node reached is set to True.
        subtree_size: output array; for every node v reached,
            subtree_size[v] is set to the size of v's subtree (rooted at src).
        n: one-element list used as an out-parameter; n[0] is incremented
            once per node reached, including `src`.
    """
    visited[src] = True
    n[0] += 1
    subtree_size[src] = 1

    # (node, parent) pairs in discovery order; a parent always precedes
    # its children in this list.
    discovered = [(src, None)]
    stack = [src]
    while stack:
        node = stack.pop()
        for nbr in tree[node]:
            if visited[nbr] or centroidMarked[nbr]:
                continue
            # Mark on push so every node is entered exactly once.
            visited[nbr] = True
            n[0] += 1
            subtree_size[nbr] = 1
            discovered.append((nbr, node))
            stack.append(nbr)

    # Fold sizes bottom-up. Walking the discovery order backwards guarantees
    # that a node's size is final before it is added to its parent.
    for node, parent in reversed(discovered):
        if parent is not None:
            subtree_size[parent] += subtree_size[node]

def getCentroid(src, visited, subtree_size, n):
    """Return the centroid of the component rooted at `src`.

    A centroid is a node whose removal leaves no piece with more than n/2
    nodes. Starting at `src`, the walk repeatedly steps into the heaviest
    eligible child (not visited, not already a centroid). It stops at the
    first node where no child subtree exceeds n/2 and the part "above" the
    node (n - subtree_size[node]) is also at most n/2.

    Args:
        src: root of the component. `subtree_size` must have been computed
            with this node as the root (e.g. by DFS), because the part above
            a node is derived as n - subtree_size[node].
        visited: per-node flags. Every node on the walk is set to True, and
            flagged neighbours (e.g. the node the walk came from) are skipped.
        subtree_size: subtree sizes of the component rooted at `src`.
        n: number of nodes in the component.

    Returns:
        The centroid's node label.

    Raises:
        ValueError: if the walk reaches a node with no eligible child before
            finding a centroid. This means `n` / `subtree_size` do not
            describe the component rooted at `src`.
    """
    node = src
    while True:
        visited[node] = True

        is_centroid = True
        # None (not 0) means "no child seen yet", so node 0 is a valid child.
        heaviest_child = None
        for child in tree[node]:
            if visited[child] or centroidMarked[child]:
                continue
            # Integer form of subtree_size[child] > n/2.
            if 2 * subtree_size[child] > n:
                is_centroid = False
            if heaviest_child is None or subtree_size[child] > subtree_size[heaviest_child]:
                heaviest_child = child

        # The part of the tree above `node` must be small as well.
        if is_centroid and 2 * (n - subtree_size[node]) <= n:
            return node

        if heaviest_child is None:
            raise ValueError(
                "no centroid found: subtree_size/n are inconsistent with src=%r" % (src,)
            )
        # If some child exceeds n/2, it is the heaviest one (at most one
        # can), and the centroid lies in its subtree.
        node = heaviest_child

def getCentroidTree(src):
    """Find, mark and return the centroid of the component containing `src`.

    The component may be the original tree or one piece of the forest left
    after earlier centroids were marked. Despite the name, no centroid tree
    is built here. The component is sized with DFS (skipping nodes already in
    `centroidMarked`), its centroid is located with getCentroid, and that
    centroid is flagged in `centroidMarked`.
    """
    sizes = [0] * MAXN
    node_count = [0]  # one-element list so DFS can update it in place
    DFS(src, [False] * MAXN, sizes, node_count)

    # getCentroid gets its own fresh visited array.
    centroid = getCentroid(src, [False] * MAXN, sizes, node_count[0])
    centroidMarked[centroid] = True
    return centroid


def decomposeTree(root):
    """Print the centroid decomposition of the component containing `root`.

    The centroid of the current component is printed first (followed by a
    space, no newline). Then every neighbour of that centroid that is still
    unmarked is decomposed recursively. Each such neighbour roots one of the
    pieces left after removing the centroid.
    """
    centroid = getCentroidTree(root)
    print(centroid, end=" ")

    for neighbour in tree[centroid]:
        # Re-checked on every iteration: earlier recursive calls mark more nodes.
        if centroidMarked[neighbour]:
            continue
        decomposeTree(neighbour)

# driver code
if __name__ == "__main__":
    # number of nodes in the tree (labels 1..n)
    n = 16

    # undirected edges (u, v); node sequencing starts from 1
    edges = [
        (1, 4), (2, 4), (3, 4), (4, 5),
        (5, 6), (6, 7), (7, 8), (7, 9),
        (6, 10), (10, 11), (11, 12), (11, 13),
        (12, 14), (13, 15), (13, 16),
    ]
    for u, v in edges:
        addEdge(u, v)

    # print the centroid decomposition, starting from node 1
    decomposeTree(1)


6 4 1 2 3 5 7 8 9 11 10 12 14 13 15 16 

In [3]:
# Function to find the centroid of a tree

# Centroid: the node whose removal minimizes the size of the largest remaining component; equivalently, no remaining component has more than n/2 nodes (the test used in getCentroid).


# Useful variables

import collections

# Capacity of every per-node array; node labels must be < MAXN (can be changed).
MAXN = 1025

# Adjacency lists of the input tree: node -> list of neighbours.
tree = collections.defaultdict(list)
# Not populated by any code visible in this cell.
centroidTree = collections.defaultdict(list)
# centroidMarked[v] becomes True once v has been chosen as a centroid.
centroidMarked = [False for _ in range(MAXN)]

def DFS(src, visited, subtree_size, n):
    """Set up subtree sizes and count the nodes of the component at `src` (gfg).

    The component is rooted at `src`. Neighbours already flagged in
    `visited` or in the global `centroidMarked` are not entered. An explicit
    stack is used instead of recursion. Python's default recursion limit is
    about 1000 frames, so a recursive walk fails on deep components such as
    a path of ~1024 nodes, which fits within MAXN.

    Args:
        src: node to start from; it is the root for `subtree_size`.
        visited: per-node flags; every node reached is set to True.
        subtree_size: output array; for every node v reached,
            subtree_size[v] is set to the size of v's subtree (rooted at src).
        n: one-element list used as an out-parameter; n[0] is incremented
            once per node reached (the total number of nodes in the component).
    """
    visited[src] = True
    n[0] += 1
    subtree_size[src] = 1

    # (node, parent) pairs in discovery order; a parent always precedes
    # its children in this list.
    discovered = [(src, None)]
    stack = [src]
    while stack:
        node = stack.pop()
        for nbr in tree[node]:
            if visited[nbr] or centroidMarked[nbr]:
                continue
            # Mark on push so every node is entered exactly once.
            visited[nbr] = True
            n[0] += 1
            subtree_size[nbr] = 1
            discovered.append((nbr, node))
            stack.append(nbr)

    # Fold sizes bottom-up. Walking the discovery order backwards guarantees
    # that a node's size is final before it is added to its parent.
    for node, parent in reversed(discovered):
        if parent is not None:
            subtree_size[parent] += subtree_size[node]
               

def getCentroid(src, visited, subtree_size, n):
    """Return the centroid of the component rooted at `src` (gfg).

    A centroid is a node whose removal leaves no piece with more than n/2
    nodes. Starting at `src`, the walk repeatedly steps into the heaviest
    eligible child (not visited, not already a centroid). It stops at the
    first node where no child subtree exceeds n/2 and the part "above" the
    node (n - subtree_size[node]) is also at most n/2.

    Args:
        src: root of the component. `subtree_size` must have been computed
            with this node as the root (e.g. by DFS), because the part above
            a node is derived as n - subtree_size[node].
        visited: per-node flags. Every node on the walk is set to True, and
            flagged neighbours (e.g. the node the walk came from) are skipped.
        subtree_size: subtree sizes of the component rooted at `src`.
        n: total number of nodes in the component.

    Returns:
        The centroid's node label.

    Raises:
        ValueError: if the walk reaches a node with no eligible child before
            finding a centroid. This means `n` / `subtree_size` do not
            describe the component rooted at `src`.
    """
    node = src
    while True:
        visited[node] = True

        is_centroid = True
        # None (not 0) means "no child seen yet", so node 0 is a valid child.
        heaviest_child = None
        for child in tree[node]:
            if visited[child] or centroidMarked[child]:
                continue
            # Integer form of subtree_size[child] > n/2.
            if 2 * subtree_size[child] > n:
                is_centroid = False
            if heaviest_child is None or subtree_size[child] > subtree_size[heaviest_child]:
                heaviest_child = child

        # The part of the tree above `node` must be small as well.
        if is_centroid and 2 * (n - subtree_size[node]) <= n:
            return node

        if heaviest_child is None:
            raise ValueError(
                "no centroid found: subtree_size/n are inconsistent with src=%r" % (src,)
            )
        # If some child exceeds n/2, it is the heaviest one (at most one
        # can), and the centroid lies in its subtree.
        node = heaviest_child


# We also need to consider the subtrees resulting from this centroid.
# Firstly, find the nodes of each one:
def collectSubtreeNodes(node, parent, subtree_nodes, visited):
    """Append the nodes of the piece containing `node` to `subtree_nodes`.

    The walk is a recursive pre-order traversal in adjacency order. It never
    steps back to `parent`, into already-visited nodes, or into nodes flagged
    in `centroidMarked`. `node` itself is always recorded and every recorded
    node is flagged in `visited`.
    """
    visited[node] = True
    subtree_nodes.append(node)
    for neighbour in tree[node]:
        # Checked lazily, after earlier recursive calls have updated `visited`.
        if neighbour == parent or visited[neighbour] or centroidMarked[neighbour]:
            continue
        collectSubtreeNodes(neighbour, node, subtree_nodes, visited)


# Function to get the subtrees resulting from the centroid
def getSubtrees(src):
    """Return one node list per piece hanging off `src`.

    A piece is produced for each neighbour of `src` not flagged in
    `centroidMarked`, in adjacency order. Its nodes are gathered by
    collectSubtreeNodes in pre-order. All pieces share a single
    `visited` array.
    """
    seen = [False] * MAXN

    def _nodes_from(start):
        # Gather the piece rooted at `start`, never stepping back into src.
        bucket = []
        collectSubtreeNodes(start, src, bucket, seen)
        return bucket

    return [_nodes_from(nbr) for nbr in tree[src] if not centroidMarked[nbr]]

# Main function to get centroid and its subtrees
def findCentroidAndSubtrees(src):
    """Locate the centroid of the component containing `src` and split around it.

    Side effect: the centroid is flagged in the global `centroidMarked`.
    Later traversals (DFS, getSubtrees) therefore skip it.

    Returns:
        (centroid, subtrees), where `subtrees` is getSubtrees(centroid):
        one node list per unmarked neighbour of the centroid.
    """
    sizes = [0] * MAXN
    total = [0]  # using a list to simulate pass by reference
    DFS(src, [False] * MAXN, sizes, total)

    centroid = getCentroid(src, [False] * MAXN, sizes, total[0])
    centroidMarked[centroid] = True

    return centroid, getSubtrees(centroid)

#### Balanced Tree example

In [4]:
# Start from an empty tree and clear centroid marks left by earlier cells.
tree = collections.defaultdict(list)
centroidMarked = [False]*MAXN

# Constructing a balanced tree
#       1
#      / \
#     2   3
#    /|   |\
#   4 5   6 7
# Appending both directions in edge order reproduces the adjacency order
# tree[1] = [2, 3], tree[2] = [1, 4, 5], tree[3] = [1, 6, 7], ...
for u, v in [(1, 2), (1, 3), (2, 4), (2, 5), (3, 6), (3, 7)]:
    tree[u].append(v)
    tree[v].append(u)

centroid, subtrees = findCentroidAndSubtrees(1)
print("Example 1 - Balanced Tree:")
print("Centroid:", centroid)
print("Subtrees:", subtrees)
print("\n")
print("\n")


Example 1 - Balanced Tree:
Centroid: 1
Subtrees: [[2, 4, 5], [3, 6, 7]]




#### Skewed Tree

In [5]:
# Start from an empty tree and clear centroid marks left by earlier cells.
tree = collections.defaultdict(list)
centroidMarked = [False]*MAXN

# Constructing a skewed tree
#   1
#    \
#     2
#      \
#       3
#        \
#         4
#          \
#           5
for u, v in [(1, 2), (2, 3), (3, 4), (4, 5)]:
    tree[u].append(v)
    tree[v].append(u)

centroid, subtrees = findCentroidAndSubtrees(1)
print("Example 2 - Skewed Tree:")
print("Centroid:", centroid)
print("Subtrees:", subtrees)
print("\n")


Example 2 - Skewed Tree:
Centroid: 3
Subtrees: [[2, 1], [4, 5]]




#### Unbalanced Tree

In [6]:
# Start from an empty tree and clear centroid marks left by earlier cells.
tree = collections.defaultdict(list)
centroidMarked = [False]*MAXN

# Constructing an unbalanced tree
#       1
#      / \
#     2   3
#    /   / \
#   4   5   6
#      / \
#     7   8
#    /
#   9
# Edge order chosen so each adjacency list matches the hand-written one,
# e.g. tree[3] = [1, 5, 6] and tree[5] = [3, 7, 8].
for u, v in [(1, 2), (1, 3), (2, 4), (3, 5), (3, 6), (5, 7), (5, 8), (7, 9)]:
    tree[u].append(v)
    tree[v].append(u)

centroid, subtrees = findCentroidAndSubtrees(1)
print("Example 3 - Unbalanced Tree:")
print("Centroid:", centroid)
print("Subtrees:", subtrees)
print("\n")


Example 3 - Unbalanced Tree:
Centroid: 3
Subtrees: [[1, 2, 4], [5, 7, 9, 8], [6]]


