In [1]:
from itertools import chain

##### Declare the TreeNode class

In [2]:
class TreeNode:
    """
    A node of a rooted tree with an arbitrary, ordered list of children.

    :param id: label for the node; it is what :meth:`display` prints.
    """

    def __init__(self, id):
        self.id = id
        self.children = []  # child TreeNode instances, in insertion order

    def add_child(self, child_node):
        """
        Appends ``child_node`` to the end of this node's children.
        :param child_node: TreeNode instance to be added as a child.
        """
        self.children.append(child_node)

    @property
    def grandchildren(self):
        """
        Collects the nodes exactly two levels below this one.
        :return: New list of TreeNode instances, ordered by child, then by grandchild.
        """
        return [grandchild for child in self.children for grandchild in child.children]

    def display(self):
        """
        Prints this node and all of its descendants as a text tree, one node per line.
        """
        self.__display(prefix="", is_last=True, is_root=True)

    def __display(self, prefix="", is_last=True, is_root=False):
        # The root is printed without a connector. Any other node gets a branch
        # connector, and its children inherit an indent that keeps the vertical
        # bar only while more siblings follow this node.
        if is_root:
            connector, child_prefix = "", ""
        elif is_last:
            connector, child_prefix = "└── ", prefix + "    "
        else:
            connector, child_prefix = "├── ", prefix + "│   "
        print(f"{prefix}{connector}{self.id}")

        last_index = len(self.children) - 1
        for index, child in enumerate(self.children):
            child.__display(prefix=child_prefix, is_last=index == last_index)

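#### Finding a Maximum Independent Set in a tree

For each node $v$, the best independent set of $v$'s subtree either excludes $v$ or includes it. Excluding $v$ gives the union of the best sets of $v$'s children. Including $v$ gives $v$ together with the best sets of its grandchildren. The larger of the two wins, and ties go to the set without $v$.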

In [3]:
def mis_tree(node):
    """
    Returns a maximum independent set (MIS) of the tree rooted at ``node``.

    It uses the recurrence
        MIS(v) = larger of  MIS(children of v)  and  MIS(grandchildren of v) + [v]
    where the set without ``v`` wins ties. Each node is evaluated once,
    bottom-up, so the run time is O(n) instead of the exponential time of
    re-solving every grandchild from both the node and its child. The
    traversal is iterative, so deep trees do not hit Python's recursion limit.

    :param node: Root TreeNode of the (sub)tree.
    :return: List of TreeNode instances forming a maximum independent set.
             The list holds the children's sets concatenated in child order.
             When ``v`` is included, it holds the grandchildren's sets in
             grandchild order, followed by ``v``.
    """
    # Pre-order traversal: every node is listed before all of its descendants,
    # so walking the list backwards visits descendants first.
    order = []
    stack = [node]
    while stack:
        current = stack.pop()
        order.append(current)
        stack.extend(current.children)

    # Bottom-up pass: best set size per subtree, and whether that set takes the subtree root.
    size = {}
    take = {}
    for current in reversed(order):
        key = id(current)
        without_current = sum(size[id(child)] for child in current.children)
        with_current = 1 + sum(size[id(gc)] for gc in current.grandchildren)
        # Strict comparison: on a tie keep the set without `current`.
        take[key] = with_current > without_current
        size[key] = max(without_current, with_current)

    # Top-down reconstruction. An ("emit", True) entry is queued beneath the
    # grandchildren, so the node is appended after their sets.
    result = []
    stack = [(node, False)]
    while stack:
        current, emit = stack.pop()
        if emit:
            result.append(current)
        elif take[id(current)]:
            stack.append((current, True))
            stack.extend((gc, False) for gc in reversed(current.grandchildren))
        else:
            stack.extend((child, False) for child in reversed(current.children))
    return result

##### First example

In [4]:
# First example: the root has three children of different shapes.
root = TreeNode(1)
node_2, node_3, node_4, node_5, node_6, node_7 = map(TreeNode, range(2, 8))

# Edges as (parent, child), attached in display order.
for parent, child in (
    (root, node_2),
    (root, node_3),
    (root, node_4),
    (node_3, node_5),
    (node_4, node_6),
    (node_4, node_7),
):
    parent.add_child(child)

root.display()

1
├── 2
├── 3
│   └── 5
└── 4
    ├── 6
    └── 7


In [5]:
# Sort by id so the listing does not depend on traversal order.
mis = list(mis_tree(root))
mis.sort(key=lambda node: node.id)
print(f"MIS size: {len(mis)}")
print(f"MIS nodes: {[node.id for node in mis]}")

MIS size: 4
MIS nodes: [2, 5, 6, 7]


##### Second example

In [6]:
# Second example: a deeper tree in which including the root pays off.
root = TreeNode(1)
node_2, node_3, node_4, node_5, node_6, node_7, node_8 = map(TreeNode, range(2, 9))

# Edges as (parent, child), attached in display order.
for parent, child in (
    (root, node_2),
    (root, node_3),
    (node_2, node_4),
    (node_2, node_5),
    (node_3, node_6),
    (node_3, node_7),
    (node_6, node_8),
):
    parent.add_child(child)

root.display()

1
├── 2
│   ├── 4
│   └── 5
└── 3
    ├── 6
    │   └── 8
    └── 7


In [7]:
# Sort by id so the listing does not depend on traversal order.
mis = list(mis_tree(root))
mis.sort(key=lambda node: node.id)
print(f"MIS size: {len(mis)}")
print(f"MIS nodes: {[node.id for node in mis]}")

MIS size: 5
MIS nodes: [1, 4, 5, 7, 8]
