In [None]:
# init from graph1.ipynb

## Minimum Spanning Tree (MST) Implementation Using Kruskal's Algorithm

In [None]:
import sys

In [None]:
# pylint: disable=too-few-public-methods

class Edge:
    """
    A weighted edge of an undirected graph.

    Attributes:
        source: Index of one endpoint of the edge.
        target: Index of the other endpoint of the edge.
        weight: Cost associated with the edge.
    """

    def __init__(self, source, target, weight):
        # Store the endpoints and the cost exactly as supplied.
        self.source, self.target, self.weight = source, target, weight


class DisjointSet:
    """
    Union-find structure over the integers ``0 .. size-1``.

    The forest is stored as a list ``parent`` where ``parent[i]`` is the
    parent of node ``i``; a node with ``parent[i] == i`` is the root (the
    representative) of its set. ``size[r]`` holds the number of nodes in the
    set rooted at ``r`` and is only meaningful for roots.

    Merges use union by size and lookups use path compression.
    """

    def __init__(self, size):
        """
        Create ``size`` singleton sets, one per vertex.

        Args:
            size (int): Number of vertices in the graph
        """
        # Every node starts as its own root, giving ``size`` disjoint sets.
        self.parent = list(range(size))
        # Number of nodes in the set rooted at index <i>; drives union by size.
        self.size = [1] * size

    def merge_set(self, node1, node2):
        """
        Merge the sets containing node1 and node2.

        Merging two nodes that already belong to the same set is a no-op.

        Args:
            node1, node2 (int): Indexes of nodes whose sets will be merged.
        """
        root1 = self.find_set(node1)
        root2 = self.find_set(node2)
        if root1 == root2:
            # Already one set. Without this guard the root's size would be
            # added to itself and the recorded size would double.
            return

        # Union by size: hang the smaller tree under the larger root (ties keep
        # root1 as the root) so trees stay shallow and find_set stays cheap.
        if self.size[root1] < self.size[root2]:
            root1, root2 = root2, root1
        self.parent[root2] = root1
        self.size[root1] += self.size[root2]

    def find_set(self, node):
        """
        Get the root element of the set containing <node>.

        Every node visited on the way to the root is re-pointed directly at
        the root (path compression), which speeds up later queries.

        Args:
            node (int): The index of the node to find the root for.

        Returns:
            int: Index of the root of the set containing ``node``.
        """
        parent = self.parent

        # First pass: walk up to the root.
        root = node
        while parent[root] != root:
            root = parent[root]

        # Second pass: path compression. Done iteratively rather than
        # recursively so a long parent chain cannot exceed Python's
        # recursion limit.
        while node != root:
            next_node = parent[node]
            parent[node] = root
            node = next_node

        return root


def kruskal(vertex_count, edges, forest):
    """
    Compute the minimum spanning tree (MST) weight using Kruskal's algorithm.

    Args:
        vertex_count (int): Number of vertices in the graph
        edges (list of Edge): Edges of the graph; the list is not modified.
        forest (DisjointSet): DisjointSet of the vertices; it is updated in
            place as edges are accepted.

    Returns:
        int: Sum of the weights of the accepted edges. If the graph is
        disconnected, fewer than ``vertex_count - 1`` edges can be accepted
        and the result is the weight of a minimum spanning forest.

    Procedure:
        Visit the edges in order of increasing weight.
        Accept an edge if its endpoints lie in different sets, then merge them.
        Stop once vertex_count - 1 edges have been accepted.
    """
    needed = vertex_count - 1  # a spanning tree on n vertices has n - 1 edges
    accepted = 0
    total = 0

    # sorted() instead of list.sort() so the caller's list is left untouched;
    # both are stable, so equal-weight edges are visited in the same order.
    for edge in sorted(edges, key=lambda e: e.weight):
        if accepted >= needed:
            break  # tree is complete
        root_u = forest.find_set(edge.source)
        root_v = forest.find_set(edge.target)
        if root_u == root_v:
            continue  # endpoints already connected: edge would form a cycle
        forest.merge_set(root_u, root_v)
        total += edge.weight
        accepted += 1

    return total


def main():
    """
    Test the program. Input format:
    First line: integers n (number of vertices) and m (number of edges)
    Next m lines: edges in the format -> node index u, node index v, integer weight
    Several test cases may follow one another; blank lines between test cases
    (or a trailing blank line) are ignored.

    Sample input:
    5 6
    1 2 3
    1 3 8
    2 4 5
    3 4 2
    3 5 4
    4 5 6

    Output: Sum of weights of the minimum spanning tree.
    """
    for header in sys.stdin:
        if not header.strip():
            # Skip blank lines (e.g. a trailing newline at EOF) instead of
            # failing to unpack two integers from an empty line.
            continue

        vertex_count, edge_count = map(int, header.split())
        forest = DisjointSet(vertex_count)  # one singleton set per vertex

        # Read <m> edges from input
        edges = []
        for _ in range(edge_count):
            source, target, weight = map(int, input().split())
            # Input vertices are 1-indexed; DisjointSet uses 0-based indexes.
            edges.append(Edge(source - 1, target - 1, weight))

        # Compute the MST using Kruskal's algorithm
        print("MST weights sum:", kruskal(vertex_count, edges, forest))


if __name__ == "__main__":
    # Script entry point: read test cases from stdin and print MST weights.
    main()
