In [6]:
import heapq
from collections import defaultdict

def find_minimum_spanning_tree(total_locations, connections, start=0):
    """Build a minimum spanning tree with a lazy, heap-based Prim's algorithm.

    Args:
        total_locations: Number of locations. Location ids are used directly as
            list indices, so every id must lie in ``0..total_locations``. The
            extra slot lets 1-indexed ids ``1..total_locations`` work as well.
        connections: Iterable of undirected edges ``(location_a, location_b, cost)``.
        start: Location the tree is grown from. Defaults to 0. Pass ``start=1``
            for 1-indexed data; otherwise the unused id 0 has no edges and the
            result is an empty tree.

    Returns:
        ``(overall_cost, selected_routes)``. ``selected_routes`` lists the chosen
        edges as ``(from_location, to_location, cost)`` in the order they were
        added to the tree.

    Raises:
        ValueError: if ``start`` is outside ``0..total_locations``.

    Note:
        Only the component containing ``start`` is spanned. For a disconnected
        graph, ``selected_routes`` will have fewer than (locations - 1) edges,
        and callers should check for this.
    """
    if not 0 <= start <= total_locations:
        raise ValueError(f"start={start} is outside 0..{total_locations}")

    # Adjacency list: location -> [(cost, neighbour), ...]; edges are undirected.
    city_map = defaultdict(list)
    for location_a, location_b, cost in connections:
        city_map[location_a].append((cost, location_b))
        city_map[location_b].append((cost, location_a))

    # One slot per id in 0..total_locations (see the docstring on indexing).
    connected = [False] * (total_locations + 1)

    overall_cost = 0
    selected_routes = []

    # Min-heap of candidate edges: (cost, to_location, from_location).
    route_options = []

    def _attach(location):
        """Mark `location` as in the tree and offer its edges to unconnected neighbours."""
        connected[location] = True
        for next_cost, neighbor_location in city_map[location]:
            if not connected[neighbor_location]:
                heapq.heappush(route_options, (next_cost, neighbor_location, location))

    # Seed the tree with the start location; it contributes no edge and no cost.
    _attach(start)

    while route_options:
        cost, current_location, from_location = heapq.heappop(route_options)
        # Lazy deletion: entries for locations connected after being pushed are stale.
        if connected[current_location]:
            continue
        overall_cost += cost
        selected_routes.append((from_location, current_location, cost))
        _attach(current_location)

    return overall_cost, selected_routes


In [8]:
# Define number of locations (nodes)
total_locations = 5

# Define available connections as (location_a, location_b, cost)
connections = [
    (0, 1, 2),
    (0, 3, 6),
    (1, 2, 3),
    (1, 3, 8),
    (1, 4, 5),
    (2, 4, 7),
    (3, 4, 9)
]

# Run Prim's algorithm
cost, routes = find_minimum_spanning_tree(total_locations, connections)

# A spanning tree over n connected locations has exactly n - 1 edges.
# Fewer edges means some locations were unreachable from the start location.
assert len(routes) == total_locations - 1, "not all locations are connected"

# Print the total cost
print("Minimum cost to connect all locations:", cost)

# Print each route in the MST. Use a distinct loop variable so the total
# `cost` above is not overwritten by the last edge's cost.
print("Selected connections:")
for from_location, to_location, route_cost in routes:
    print(f"{from_location} -> {to_location} (cost: {route_cost})")


Minimum cost to connect all locations: 16
Selected connections:
0 -> 1 (cost: 2)
1 -> 2 (cost: 3)
1 -> 4 (cost: 5)
0 -> 3 (cost: 6)
