In [138]:
import networkx as nx
from typing import List
from collections import deque

def solve(transmissions: List[List[int]], num_servers: int) -> int:
    # Check if the number of servers is at least 2
    if num_servers < 2:
        print("The number of servers should be >= 2")
        return -2

    # Check if the number of directed transmissions is less than the number of servers
    if len(transmissions) >= num_servers:
        print("The number of directed transmissions should be < n")
        return -2
    
    G = nx.DiGraph()
    for i in range(num_servers):
        G.add_node(i)
    for u, v in transmissions:
        G.add_edge(u, v)
    
    # if the graph is not connected -> not feasible -> directly return -1
    if not nx.is_weakly_connected(G):
        return -1
    
    # Compute the nearest predecessors for each node
    nearest_predecessors = {}
    for node in range(num_servers):
        predecessors = list(G.predecessors(node))
        if predecessors:
            # 1. Compute the shortest path from each predecessor to the node
            predecessor_distances = [(predecessor, nx.shortest_path_length(G, source=predecessor, target=node)) 
                                     for predecessor in predecessors]
            # 2. Sort the predecessors by the path length
            predecessor_distances.sort(key=lambda x: x[1])
            # 3. Keep only the (max) two nearest predecessors (0, 1, 2)
            nearest_predecessors[node] = [predecessor for predecessor, distance in predecessor_distances 
                                          if distance == predecessor_distances[0][1]]
        else:
            nearest_predecessors[node] = []
    print(nearest_predecessors)
            
    # Convert the graph to an undirected graph and compute the depth of each node
    U = G.to_undirected()
    depths = nx.single_source_shortest_path_length(U, 0)

    # Create a list of nodes sorted by depth
    nodes_sorted_by_depth = sorted(range(1, num_servers), key=depths.get)  # Exclude the central server

    # if the graph is connected & num of edge is n-1 -> it is a tree -> do the BFS search 
    changes = 0
    for node in nodes_sorted_by_depth:  # Exclude the central server, starting from smaller depth to larger depth
        try: # Try to find a path from node to the central server
            _ = nx.shortest_path(G, node, 0)  
        except nx.NetworkXNoPath: # If there's no path, change the direction of an edge from a predecessor of node & smaller depth
            predecessors = nearest_predecessors[node]
            if len(predecessors)!=0: # If there's no predecessor -> go to next node
                # 1. Choose the exact "last" predecessors list of current node
                # 2. Choose the predecessor with the smallest depth 
                predecessor = min(predecessors, key=depths.get)
                print("node: "+str(node)+", pre:"+str(predecessor))
                G.remove_edge(predecessor, node)
                G.add_edge(node, predecessor)
                changes += 1
        
    # Final Check if all servers can now reach the central server
    for node in range(1, num_servers):  # Exclude the central server
        try:
            _ = nx.shortest_path(G, node, 0)  # Try to find a path from node to the central server
        except nx.NetworkXNoPath:
            return -1  # If a server cannot reach the central server, return -1

    return changes

In [139]:
# test case 1.1: n>2 & connected & changes at leaf nodes 
transmissions = [[1, 0], [3, 1], [1, 2], [4,0], [4,5]]
num_servers = 6
answer = 2
result = solve(transmissions, num_servers)
assert result == answer, f"Test case 1: expected {answer}, got {result}"
print('Passed test case 1(1)...')

{0: [1, 4], 1: [3], 2: [1], 3: [], 4: [], 5: [4]}
node: 2, pre:1
node: 5, pre:4
Passed test case 1(1)...


In [140]:
# test case 1.2: n>2 & connected & changes at inner nodes 
transmissions = [[0, 1], [3, 1], [2, 1], [4, 0], [5, 4]]
num_servers = 6
answer = 1
result = solve(transmissions, num_servers)
assert result == answer, f"Test case 1: expected {answer}, got {result}"
print('Passed test case 1(2)...')

{0: [4], 1: [0, 3, 2], 2: [], 3: [], 4: [5], 5: []}
node: 1, pre:0
Passed test case 1(2)...


In [141]:
# test case 1.3: n>2 & connected & changes at inner nodes 
transmissions = [[1, 0], [3, 1], [2, 1], [0, 4], [5, 4]]
num_servers = 6
answer = 1
result = solve(transmissions, num_servers)
assert result == answer, f"Test case 1: expected {answer}, got {result}"
print('Passed test case 1(3)...')

{0: [1], 1: [3, 2], 2: [], 3: [], 4: [0, 5], 5: []}
node: 4, pre:0
Passed test case 1(3)...


In [142]:
# test case 1.4: n>2 & connected & changes at inner nodes 
transmissions = [[0, 1], [3, 1], [2, 1], [0, 4], [5, 4]]
num_servers = 6
answer = 2
result = solve(transmissions, num_servers)
assert result == answer, f"Test case 1: expected {answer}, got {result}"
print('Passed test case 1(4)...')

{0: [], 1: [0, 3, 2], 2: [], 3: [], 4: [0, 5], 5: []}
node: 1, pre:0
node: 4, pre:0
Passed test case 1(4)...


In [143]:
# test case 1.5: n>2 & connected & changes at both inner nodes and leaf nodes
transmissions = [[0, 1], [3, 1], [1, 2], [4, 0], [5, 4]]
num_servers = 6
answer = 2
result = solve(transmissions, num_servers)
assert result == answer, f"Test case 1: expected {answer}, got {result}"
print('Passed test case 1(5)...')

{0: [4], 1: [0, 3], 2: [1], 3: [], 4: [5], 5: []}
node: 1, pre:0
node: 2, pre:1
Passed test case 1(5)...


In [144]:
# test case 1.6: n>2 & connected & changes at both inner nodes and leaf nodes
transmissions = [[0, 1], [3, 1], [1, 2], [4, 0], [4, 5]]
num_servers = 6
answer = 3
result = solve(transmissions, num_servers)
assert result == answer, f"Test case 1: expected {answer}, got {result}"
print('Passed test case 1(6)...')

{0: [4], 1: [0, 3], 2: [1], 3: [], 4: [], 5: [4]}
node: 1, pre:0
node: 2, pre:1
node: 5, pre:4
Passed test case 1(6)...


In [145]:
# test case 1.7: n>2 & connected & changes at both inner nodes and leaf nodes
transmissions = [[0, 3], [2, 3], [2, 1]]
num_servers = 4
answer = 2
result = solve(transmissions, num_servers)
assert result == answer, f"Test case 1: expected {answer}, got {result}"
print('Passed test case 1(7)...')

{0: [], 1: [2], 2: [], 3: [0, 2]}
node: 3, pre:0
node: 1, pre:2
Passed test case 1(7)...


In [146]:
# test case 1.8: n>2 & connected & changes at both inner nodes and leaf nodes
transmissions = [[0, 3], [2, 3], [1, 2]]
num_servers = 4
answer = 1
result = solve(transmissions, num_servers)
assert result == answer, f"Test case 1: expected {answer}, got {result}"
print('Passed test case 1(8)...')

{0: [], 1: [], 2: [1], 3: [0, 2]}
node: 3, pre:0
Passed test case 1(8)...


In [147]:
# test case 1.9: n>2 & connected & changes at both inner nodes and leaf nodes
transmissions = [[0, 3], [3, 2], [2, 1]]
num_servers = 4
answer = 3
result = solve(transmissions, num_servers)
assert result == answer, f"Test case 1: expected {answer}, got {result}"
print('Passed test case 1(9)...')

{0: [], 1: [2], 2: [3], 3: [0]}
node: 3, pre:0
node: 2, pre:3
node: 1, pre:2
Passed test case 1(9)...


In [148]:
# test case 1.10: n>2 & connected & changes at both inner nodes and leaf nodes
transmissions = [[0, 3], [3, 2], [1, 2]]
num_servers = 4
answer = 2
result = solve(transmissions, num_servers)
assert result == answer, f"Test case 1: expected {answer}, got {result}"
print('Passed test case 1(10)...')

{0: [], 1: [], 2: [3, 1], 3: [0]}
node: 3, pre:0
node: 2, pre:3
Passed test case 1(10)...


In [149]:
# test case 1.11: n>2 & connected & changes at both inner nodes and leaf nodes
transmissions = [[3, 0], [3, 2], [1, 2]]
num_servers = 4
answer = 1
result = solve(transmissions, num_servers)
assert result == answer, f"Test case 1: expected {answer}, got {result}"
print('Passed test case 1(11)...')

{0: [3], 1: [], 2: [3, 1], 3: []}
node: 2, pre:3
Passed test case 1(11)...


In [150]:
# test case 1.12: n>2 & connected & changes at both inner nodes and leaf nodes
transmissions = [[3, 0], [2, 3], [2, 1]]
num_servers = 4
answer = 1
result = solve(transmissions, num_servers)
assert result == answer, f"Test case 1: expected {answer}, got {result}"
print('Passed test case 1(12)...')

{0: [3], 1: [2], 2: [], 3: [2]}
node: 1, pre:2
Passed test case 1(12)...


In [151]:
# test case 1.13: n>2 & connected & changes at both inner nodes and leaf nodes
transmissions = [[3, 0], [3, 2], [2, 1]]
num_servers = 4
answer = 2
result = solve(transmissions, num_servers)
assert result == answer, f"Test case 1: expected {answer}, got {result}"
print('Passed test case 1(13)...')

{0: [3], 1: [2], 2: [3], 3: []}
node: 2, pre:3
node: 1, pre:2
Passed test case 1(13)...


In [152]:
# test case 1.14: n>2 & connected & no change
transmissions = [[3, 0], [2, 3], [1, 2]]
num_servers = 4
answer = 0
result = solve(transmissions, num_servers)
assert result == answer, f"Test case 1: expected {answer}, got {result}"
print('Passed test case 1(14)...')

{0: [3], 1: [], 2: [1], 3: [2]}
Passed test case 1(14)...


In [153]:
# test case 2: n>2 & connected & no change
transmissions = [[1, 0], [3, 1], [2, 3]]
num_servers = 4
answer = 0
result = solve(transmissions, num_servers)
assert result == answer, f"Test case 2: expected {answer}, got {result}"
print('Passed test case 2...')

{0: [1], 1: [3], 2: [], 3: [2]}
Passed test case 2...


In [154]:
# Test case 3: n>2 & connected & no change
transmissions = [[1, 0], [2, 0], [3, 0]]
num_servers = 4
answer = 0
result = solve(transmissions, num_servers)
assert result == answer, f"Test case 3: expected {answer}, got {result}"
print('Passed test case 3...')

{0: [1, 2, 3], 1: [], 2: [], 3: []}
Passed test case 3...


In [155]:
# Test case 4: n>2 & not connected
transmissions = [[1, 2], [2, 3], [3, 4]]
num_servers = 5
answer = -1
result = solve(transmissions, num_servers)
assert result == answer, f"Test case 4: expected {answer}, got {result}"
print('Passed test case 4...')

Passed test case 4...


In [156]:
# Test case 5: n=2 & connected & changes
transmissions = [[0, 1]]
num_servers = 2
answer = 1
result = solve(transmissions, num_servers)
assert result == answer, f"Test case 5: expected {answer}, got {result}"
print('Passed test case 5...')

{0: [], 1: [0]}
node: 1, pre:0
Passed test case 5...


In [157]:
# Test case 6: n=2 & connected & no change
transmissions = [[1, 0]]
num_servers = 2
answer = 0
result = solve(transmissions, num_servers)
assert result == answer, f"Test case 6: expected {answer}, got {result}"
print('Passed test case 6...')

{0: [1], 1: []}
Passed test case 6...


In [158]:
# Test case 7: n=2 & not connected
transmissions = []
num_servers = 2
answer = -1
result = solve(transmissions, num_servers)
assert result == answer, f"Test case 7: expected {answer}, got {result}"
print('Passed test case 7...')

Passed test case 7...


In [159]:
# Test case 8: n=1
transmissions = []
num_servers = 1
answer = -2
result = solve(transmissions, num_servers)
assert result == answer, f"Test case 8: expected {answer}, got {result}"
print('Passed test case 8...')

The number of servers should be >= 2
Passed test case 8...


In [160]:
# Test case 9: transmission>=n
transmissions = [[0,1],[1,0]]
num_servers = 2
answer = -2
result = solve(transmissions, num_servers)
assert result == answer, f"Test case 9: expected {answer}, got {result}"
print('Passed test case 9...')

The number of directed transmissions should be < n
Passed test case 9...


In [161]:
# Test case 10: n=3 & connected & no change
transmissions = [[0, 2], [2, 1]]
num_servers = 3
answer = 2
result = solve(transmissions, num_servers)
assert result == answer, f"Test case 10: expected {answer}, got {result}"
print('Passed test case 10...')

{0: [], 1: [2], 2: [0]}
node: 2, pre:0
node: 1, pre:2
Passed test case 10...


In [162]:
# Test case 11: n=4 & connected & no change
transmissions = [[0, 2], [2, 1], [1, 3], [2, 3]]
num_servers = 4
answer = -2
result = solve(transmissions, num_servers)
assert result == answer, f"Test case 11: expected {answer}, got {result}"
print('Passed test case 11...')

The number of directed transmissions should be < n
Passed test case 11...


In [163]:
# Test case 12: n=4 & connected & no change
transmissions = [[1, 0], [2, 1], [2, 3]]
num_servers = 4
answer = 1
result = solve(transmissions, num_servers)
assert result == answer, f"Test case 12: expected {answer}, got {result}"
print('Passed test case 12...')

{0: [1], 1: [2], 2: [], 3: [2]}
node: 3, pre:2
Passed test case 12...


In [164]:
# Test case 13: n=7 & connected & no change
transmissions = [[0, 1], [1, 2], [2, 3],[3, 4], [1,4]]
num_servers = 7
answer = -1
result = solve(transmissions, num_servers)
assert result == answer, f"Test case 13: expected {answer}, got {result}"
print('Passed test case 13...')

Passed test case 13...


In [165]:
# Test case 14: n=7 & connected & no change
transmissions = [[0, 1], [1, 0]]
num_servers = 3
answer = -1
result = solve(transmissions, num_servers)
assert result == answer, f"Test case 14: expected {answer}, got {result}"
print('Passed test case 14...')

Passed test case 14...
