In [1]:
import networkx as nx
from typing import List
from collections import deque

def solve(transmissions: List[List[int]], num_servers: int) -> int:
    # Check if the number of servers is at least 2
    if num_servers < 2:
        print("The number of servers should be >= 2")
        return -2

    # Check if the number of directed transmissions is less than the number of servers
    if len(transmissions) >= num_servers:
        print("The number of directed transmissions should be < n")
        return -2
    
    G = nx.DiGraph()
    for i in range(num_servers):
        G.add_node(i)
    for u, v in transmissions:
        G.add_edge(u, v)
    
    # if the graph is not connected -> not feasible -> directly return -1
    if not nx.is_weakly_connected(G):
        return -1
    
    # Convert the graph to an undirected graph and compute the depth of each node
    U = G.to_undirected()
    depths = nx.single_source_shortest_path_length(U, 0)

    # if the graph is connected & num of edge is n-1 -> it is a tree -> do the BFS search 
    changes = 0
    for node in range(1, num_servers):  # Exclude the central server
        try: # Try to find a path from node to the central server
            _ = nx.shortest_path(G, node, 0)  
        except nx.NetworkXNoPath: # If there's no path, change the direction of an edge from a predecessor of node & smaller depth
            predecessors = list(G.predecessors(node))
            # If there's no predecessor -> go to next node
            if predecessors: 
                # Choose the predecessor with the smallest depth (1 predecessor vs. 2 predecessors)
                predecessor = min(predecessors, key=depths.get)
                G.remove_edge(predecessor, node)
                G.add_edge(node, predecessor)
                changes += 1

    # Check if all servers can now reach the central server
    for node in range(1, num_servers):  # Exclude the central server
        try:
            _ = nx.shortest_path(G, node, 0)  # Try to find a path from node to the central server
        except nx.NetworkXNoPath:
            return -1  # If a server cannot reach the central server, return -1

    return changes

In [2]:
# test case 1.1: n>2 & connected & changes at leaf nodes 
transmissions = [[1, 0], [3, 1], [1, 2], [4,0], [4,5]]
num_servers = 6
answer = 2
result = solve(transmissions, num_servers)
assert result == answer, f"Test case 1: expected {answer}, got {result}"
print('Passed test case 1(1)...')

Passed test case 1(1)...


In [3]:
# test case 1.2: n>2 & connected & changes at inner nodes 
transmissions = [[0, 1], [3, 1], [2, 1], [4, 0], [5, 4]]
num_servers = 6
answer = 1
result = solve(transmissions, num_servers)
assert result == answer, f"Test case 1: expected {answer}, got {result}"
print('Passed test case 1(2)...')

Passed test case 1(2)...


In [4]:
# test case 1.3: n>2 & connected & changes at inner nodes 
transmissions = [[1, 0], [3, 1], [2, 1], [0, 4], [5, 4]]
num_servers = 6
answer = 1
result = solve(transmissions, num_servers)
assert result == answer, f"Test case 1: expected {answer}, got {result}"
print('Passed test case 1(3)...')

Passed test case 1(3)...


In [5]:
# test case 1.4: n>2 & connected & changes at inner nodes 
transmissions = [[0, 1], [3, 1], [2, 1], [0, 4], [5, 4]]
num_servers = 6
answer = 2
result = solve(transmissions, num_servers)
assert result == answer, f"Test case 1: expected {answer}, got {result}"
print('Passed test case 1(4)...')

Passed test case 1(4)...


In [6]:
# test case 1.5: n>2 & connected & changes at both inner nodes and leaf nodes
transmissions = [[0, 1], [3, 1], [1, 2], [4, 0], [5, 4]]
num_servers = 6
answer = 2
result = solve(transmissions, num_servers)
assert result == answer, f"Test case 1: expected {answer}, got {result}"
print('Passed test case 1(5)...')

Passed test case 1(5)...


In [7]:
# test case 1.6: n>2 & connected & changes at both inner nodes and leaf nodes
transmissions = [[0, 1], [3, 1], [1, 2], [4, 0], [4, 5]]
num_servers = 6
answer = 3
result = solve(transmissions, num_servers)
assert result == answer, f"Test case 1: expected {answer}, got {result}"
print('Passed test case 1(6)...')

Passed test case 1(6)...


In [8]:
# test case 2: n>2 & connected & no change
transmissions = [[1, 0], [3, 1], [2, 3]]
num_servers = 4
answer = 0
result = solve(transmissions, num_servers)
assert result == answer, f"Test case 2: expected {answer}, got {result}"
print('Passed test case 2...')

Passed test case 2...


In [9]:
# Test case 3: n>2 & connected & no change
transmissions = [[1, 0], [2, 0], [3, 0]]
num_servers = 4
answer = 0
result = solve(transmissions, num_servers)
assert result == answer, f"Test case 3: expected {answer}, got {result}"
print('Passed test case 3...')

Passed test case 3...


In [10]:
# Test case 4: n>2 & not connected
transmissions = [[1, 2], [2, 3], [3, 4]]
num_servers = 5
answer = -1
result = solve(transmissions, num_servers)
assert result == answer, f"Test case 4: expected {answer}, got {result}"
print('Passed test case 4...')

Passed test case 4...


In [11]:
# Test case 5: n=2 & connected & changes
transmissions = [[0, 1]]
num_servers = 2
answer = 1
result = solve(transmissions, num_servers)
assert result == answer, f"Test case 5: expected {answer}, got {result}"
print('Passed test case 5...')

Passed test case 5...


In [12]:
# Test case 6: n=2 & connected & no change
transmissions = [[1, 0]]
num_servers = 2
answer = 0
result = solve(transmissions, num_servers)
assert result == answer, f"Test case 6: expected {answer}, got {result}"
print('Passed test case 6...')

Passed test case 6...


In [13]:
# Test case 7: n=2 & not connected
transmissions = []
num_servers = 2
answer = -1
result = solve(transmissions, num_servers)
assert result == answer, f"Test case 7: expected {answer}, got {result}"
print('Passed test case 7...')

Passed test case 7...


In [14]:
# Test case 8: n=1
transmissions = []
num_servers = 1
answer = -2
result = solve(transmissions, num_servers)
assert result == answer, f"Test case 8: expected {answer}, got {result}"
print('Passed test case 8...')

The number of servers should be >= 2
Passed test case 8...


In [15]:
# Test case 9: transmission>=n
transmissions = [[0,1],[1,0]]
num_servers = 2
answer = -2
result = solve(transmissions, num_servers)
assert result == answer, f"Test case 9: expected {answer}, got {result}"
print('Passed test case 9...')

The number of directed transmissions should be < n
Passed test case 9...


In [16]:
# Test case 10: n=4 & connected & no change
# 0 -> 3 <- 2 -> 1
transmissions = [[0, 3], [2, 3], [2, 1]]
num_servers = 4
answer = 2
result = solve(transmissions, num_servers)
assert result == answer, f"Test case 17: expected {answer}, got {result}"
print('Passed test case 17...')

AssertionError: Test case 17: expected 2, got -1