# Matrix

<br>

![image](https://user-images.githubusercontent.com/50367487/84627043-26ebf180-af21-11ea-9850-c5ee42556362.png)

In [None]:
#!/bin/python3

import math
import os
import random
import re
import sys

# Disjoint-set (union-find) helpers used by minTime, which is defined below.
class Node:
    """A single element of a disjoint-set forest.

    Attributes:
        data: the key this node stands for.
        rank: rank used for union by rank (coerced to ``int``).
        parent: parent Node; a root node is its own parent.
        is_machine: flag marking the node (and, on a root, its set) as
            containing a machine.
    """

    def __init__(self, data, rank=0, parent=None, is_machine=False):
        self.data = data
        self.rank = int(rank)
        # A node created without an explicit parent is a root, i.e. its own parent.
        self.parent = parent if parent is not None else self
        self.is_machine = is_machine

    def __repr__(self):
        # parent is deliberately omitted: it may be self, and printing it
        # recursively would be noisy or cyclic.
        return (f"{type(self).__name__}(data={self.data!r}, rank={self.rank}, "
                f"is_machine={self.is_machine!r})")


class DisjointSet:
    """Union-find over hashable keys that tracks which sets contain a machine.

    Each key added with ``make_set`` maps to a ``Node``. ``union`` merges sets
    by rank, and lookups compress paths. The root of each set carries an
    ``is_machine`` flag, which is the OR of the flags of the sets merged into it.
    """

    def __init__(self):
        # Maps each key passed to make_set to its Node.
        self._graph = {}

    def make_set(self, data, is_machine=False):
        """Create a singleton set for ``data``; an existing entry is overwritten."""
        node = Node(data, is_machine=is_machine)
        self._graph[data] = node

    def union(self, data1, data2):
        """Merge the sets containing ``data1`` and ``data2``.

        The surviving root's ``is_machine`` flag becomes the OR of both roots'
        flags.

        Returns:
            True, including when both keys are already in the same set.

        Raises:
            KeyError: if either key was never added with ``make_set``.
        """
        # Start the find from the nodes themselves so they get compressed too.
        root1 = self._find_set_node(self._graph[data1])
        root2 = self._find_set_node(self._graph[data2])

        if root1 is root2:
            return True

        # Union by rank: make root1 the higher-ranked root (ties keep root1 and
        # bump its rank), then hang root2 beneath it.
        if root1.rank < root2.rank:
            root1, root2 = root2, root1
        elif root1.rank == root2.rank:
            root1.rank += 1
        root2.parent = root1
        root1.is_machine = root1.is_machine or root2.is_machine
        return True

    def contains(self, data):
        """Return True if ``data`` has been added with ``make_set``."""
        return data in self._graph

    def is_machine_set(self, data):
        """Return the machine flag of the set holding ``data`` (False if unknown)."""
        node = self._find_set_node(self._graph.get(data))
        return node.is_machine if node else False

    def find_set(self, data):
        """Return the key of the root of ``data``'s set, or None if unknown."""
        node = self._find_set_node(self._graph.get(data))
        return node.data if node else None

    def _find_set_node(self, node):
        """Return the root Node of ``node``'s set, or None when ``node`` is None.

        Iterative, so long parent chains cannot exceed the recursion limit.
        Every node on the walked path is re-pointed straight at the root.
        """
        if node is None:
            return None

        root = node
        while root.parent is not root:
            root = root.parent

        # Second pass: path compression.
        while node is not root:
            next_node = node.parent
            node.parent = root
            node = next_node
        return root


def minTime(roads, machines):
    """Return the total time of the roads destroyed so no two machine cities stay joined.

    Roads are processed from the most to the least time-consuming. A road is
    kept (its two endpoints' sets are merged) unless both endpoints already
    belong to sets that contain a machine; then it is destroyed and its time
    is added to the total. Presumably the roads form a tree (the script below
    reads n - 1 of them).

    Args:
        roads: sequence of ``(city_a, city_b, time)`` triples.
        machines: iterable of city ids that hold a machine; duplicates are ignored.

    Returns:
        The summed time of the destroyed roads; 0 if either input is empty.
    """
    if not roads or not machines:
        return 0

    forest = DisjointSet()
    for city in set(machines):
        forest.make_set(city, is_machine=True)

    destroyed = 0
    for a, b, cost in sorted(roads, key=lambda road: road[2], reverse=True):
        # Lazily register cities that appear only on roads.
        for city in (a, b):
            if not forest.contains(city):
                forest.make_set(city, is_machine=False)

        if forest.is_machine_set(a) and forest.is_machine_set(b):
            destroyed += cost
        else:
            forest.union(a, b)

    return destroyed

if __name__ == '__main__':
    # Reads "n k", then n - 1 road lines "a b time", then k machine ids from
    # stdin, and writes minTime's result to the file named by OUTPUT_PATH.
    with open(os.environ['OUTPUT_PATH'], 'w') as fptr:
        nk = input().split()

        n = int(nk[0])

        k = int(nk[1])

        roads = []

        for _ in range(n - 1):
            try:
                roads.append(list(map(int, input().rstrip().split())))
            except ValueError:
                # A line with a non-integer token is skipped; the next line is used.
                roads.append(list(map(int, input().rstrip().split())))

        machines = []
        for _ in range(k):
            try:
                machines.append(int(input()))
            except (ValueError, EOFError):
                # Best effort: ignore a blank/non-integer line or missing input.
                pass

        result = minTime(roads, machines)

        fptr.write(str(result) + '\n')