In [1]:
from itertools import chain

In [2]:
class TreeNode:
    """A node of an ordered n-ary tree, labelled by ``id``."""

    def __init__(self, id):
        """
        Creates a node with no children.
        :param id: Label for this node; display() renders it via an f-string.
        """
        self.id = id
        self.children = []  # child nodes, kept in insertion order

    def add_child(self, child_node):
        """
        Attaches a node as the new right-most child of this node.
        :param child_node: TreeNode instance to attach.
        """
        self.children.append(child_node)

    @property
    def grandchildren(self):
        """
        Gathers every child's children, scanning children left to right.
        :return: Fresh list of the TreeNode instances two levels below this node.
        """
        return [grandchild
                for child in self.children
                for grandchild in child.children]

    def display(self):
        """
        Prints the subtree rooted at this node, one node per line.
        The starting node is printed bare. Every descendant is preceded by a
        box-drawing branch ("├── " or "└── " for the last sibling) and by the
        vertical guides inherited from its ancestors.
        """
        # Each stack entry: (node, text printed before its id, guide prefix
        # its own children inherit). An explicit stack replaces recursion.
        # Children are pushed right-to-left so they are popped left-to-right.
        stack = [(self, "", "")]
        while stack:
            node, lead, inherited = stack.pop()
            print(lead + f"{node.id}")
            kids = node.children
            final = len(kids) - 1
            for position in range(final, -1, -1):
                if position == final:
                    branch, extension = "└── ", "    "
                else:
                    branch, extension = "├── ", "│   "
                stack.append((kids[position], inherited + branch, inherited + extension))

In [3]:
def create_sample_tree():
    # Raw docstring: in a normal string the trailing "\" of a diagram row is a
    # backslash-newline line continuation, which silently merges two rows.
    r"""
    Creates a sample tree for demonstration purposes.

    Tree Structure:
              1
           /  |  \
          2   3   4
              |  / \
              5 6   7

    :return: Root TreeNode of the sample tree.
    """
    root = TreeNode(1)
    child_2 = TreeNode(2)
    child_3 = TreeNode(3)
    child_4 = TreeNode(4)

    child_5 = TreeNode(5)
    child_6 = TreeNode(6)
    child_7 = TreeNode(7)

    # Level 1: the root's children, left to right.
    root.add_child(child_2)
    root.add_child(child_3)
    root.add_child(child_4)

    # Level 2: node 2 stays a leaf; 3 has one child, 4 has two.
    child_3.add_child(child_5)
    child_4.add_child(child_6)
    child_4.add_child(child_7)

    return root

In [4]:
# Build the demonstration tree and print its ASCII rendering.
root = create_sample_tree()
root.display()

1
├── 2
├── 3
│   └── 5
└── 4
    ├── 6
    └── 7


In [5]:
def mis_tree(node):
    """
    Computes a maximum independent set (MIS) of the tree rooted at `node`.

    For every node v the recurrence is
        MIS(v) = larger of  concat(MIS(c) for each child c of v)
                       and  concat(MIS(g) for each grandchild g of v) + [v]
    with ties resolved in favour of the option that excludes v.

    Each subtree's MIS is computed exactly once, in an iterative post-order
    traversal. A naive recursive evaluation re-solves every grandchild subtree,
    which is exponential on path-like trees, and it overflows Python's
    recursion limit on deep trees.

    :param node: Root of the tree; any object with an ordered ``children`` list.
    :return: List of nodes forming a maximum independent set. Child solutions
             are concatenated left to right; when v is chosen it comes last.
    :raises ValueError: If the structure reachable from `node` contains a cycle.
    """
    best = {}        # id(subtree root) -> its MIS list (ids stay valid: the tree keeps nodes alive)
    pending = set()  # ids whose children are still being solved, used for cycle detection
    stack = [(node, False)]
    while stack:
        current, children_done = stack.pop()
        key = id(current)
        if not children_done:
            if key in best:
                continue  # shared subtree already solved; reuse its answer
            if key in pending:
                raise ValueError("mis_tree requires an acyclic tree; found a cycle")
            pending.add(key)
            # Revisit `current` after every descendant has been solved.
            stack.append((current, True))
            stack.extend((child, False) for child in current.children)
            continue

        pending.discard(key)
        # Option 1: leave `current` out -> union of the children's solutions.
        without_current = list(chain.from_iterable(
            best[id(child)] for child in current.children))
        # Option 2: take `current` -> union of the grandchildren's solutions, then `current`.
        with_current = list(chain.from_iterable(
            best[id(grandchild)]
            for child in current.children
            for grandchild in child.children))
        with_current.append(current)

        # max() returns its first argument on ties, so the exclude-option wins ties.
        best[key] = max(without_current, with_current, key=len)
    return best[id(node)]

In [6]:
# Solve the sample tree, then list the ids of the selected nodes.
mis = mis_tree(root)
print([chosen.id for chosen in mis])

[2, 5, 6, 7]
