# In [None]:
#"Hard" algorithm
#https://www.cs.princeton.edu/~chazelle/pubs/mstapprox.pdf 
#"Easy" algorithm
#http://www.wisdom.weizmann.ac.il/~oded/PTW/sublin.pdf

import networkx as nx
import random

# Test instance: n vertices, integer edge weights in [1, max_w], accuracy eps.
n = 1000
max_w = 50
eps = 0.01

# Random 5-regular graph on n vertices.
G = nx.generators.random_graphs.random_regular_graph(5, n)

# Give every edge a uniformly random integer weight in [1, max_w].
for u, v in G.edges():
    G.edges[u, v]['weight'] = random.randint(1, max_w)

# Exact MST, used as the ground truth the sublinear estimate is compared to.
mst_gt = nx.algorithms.tree.mst.minimum_spanning_tree(G)

mst_gt_w = 0
print(list(mst_gt.edges)[:10])  # peek at a few tree edges
# Total tree weight, read from the weights stored on G.
for u, v in mst_gt.edges:
    mst_gt_w += G.edges[u, v]['weight']

print(mst_gt_w)

# In [None]:
#%%writefile mst.py

import random
import sys

# Truthy: answer neighbor queries from the in-memory graph G (query_node_local)
# instead of the stdin/stdout query protocol.
DEBUG = 1

# approx_cc / approx_cc_simple scale their sample counts with this factor.
RANDOM_SAMPLE_FACTOR = 100
# Constant C in approx_avg_degree's C / eps sample count (passed by approx_mst).
SOME_BIG_CONSTANT    = 5000

# Cache of answers from the interactive protocol:
# vertex -> [(neighbor, edge weight), ...]
memos = dict()

def query_node_local(n):
    """Answer a neighbor query for vertex ``n`` from the in-memory graph ``G``.

    Returns one ``(neighbor, weight)`` pair per neighbor of ``n``, where the
    weight is the ``"weight"`` attribute of the connecting edge in the global
    graph ``G`` built in the first notebook cell.
    """
    neighbours = G.neighbors(n)
    attrs_by_neighbour = G.adj[n]
    return [(m, attrs_by_neighbour[m]["weight"]) for m in neighbours]

def query_node(n):
    """Return the weighted neighbor list of vertex ``n``.

    In DEBUG mode the answer comes straight from the in-memory graph via
    ``query_node_local`` (nothing is cached).  Otherwise ``n`` is printed to
    stdout and one line is read back from stdin.  The first token of that line
    is discarded (presumably a count -- TODO confirm against the judge
    protocol), and the remaining tokens are read as alternating
    ``neighbor weight`` pairs.  Interactive answers are cached in ``memos``,
    so each vertex is asked about at most once.

    Returns:
        list of ``(neighbor, weight)`` tuples.
    """
    if DEBUG:
        return query_node_local(n)

    if n in memos:
        return memos[n]

    print(n)
    sys.stdout.flush()
    tokens = sys.stdin.readline().split()[1:]
    # Pair tokens (0,1), (2,3), ...; zip drops a dangling odd token instead of
    # raising IndexError.  Values are converted to int so callers can compare
    # weights with the integer threshold and ids with sampled vertices
    # (assumes the judge reports integer ids and weights -- verify).
    memos[n] = [(int(nbr), int(w)) for nbr, w in zip(tokens[0::2], tokens[1::2])]
    return memos[n]

def approx_avg_degree(eps, C, n):
    """Return the largest degree seen among ``int(C / eps)`` random vertices.

    Despite the name, this is not an average.  It samples ``int(C / eps)``
    vertices uniformly with replacement from ``0 .. n - 1``, queries each one
    with ``query_node``, and returns the maximum neighbor-list length observed
    (0 if nothing was sampled).  ``approx_mst`` uses the result as the degree
    bound ``d_bar``.
    """
    import random

    num_samples = int(C / eps)
    # No exception handling here: len() of a neighbor list cannot fail, so a
    # broad except would only hide real errors coming from query_node.
    degrees = (len(query_node(random.randint(0, n - 1))) for _ in range(num_samples))
    return max(degrees, default=0)


def _crt_beta(u, gi, max_w, d_bar):
    """Return the Chazelle-Rubinfeld-Trevisan estimator term for vertex ``u``.

    Works in G^(gi), the subgraph of edges with weight <= ``gi``.

    * If ``u`` is isolated in G^(gi), the term is 2.
    * Otherwise a BFS from ``u`` runs in phases.  The first phase explores
      ``u`` itself.  Before each further phase a fair coin is flipped, and the
      BFS continues (with its budget of examined adjacency entries doubled)
      only if:
        - the coin shows heads,
        - fewer than ``max_w`` vertices have been discovered, and
        - no explored vertex has degree > ``d_bar``.
      If any of these fails, the term is 0.
    * If the BFS exhausts u's component after ``k`` heads, the term is
      ``d_u * 2**k / m_C``.  Here ``d_u`` is u's degree and ``m_C`` is the
      number of edges in the component, both measured in G^(gi).

    Finishing a component takes a fixed number k of heads, which happens with
    probability 2**-k.  The expected term is therefore d_u / m_C, and these
    expectations sum to 2 over any component.
    """
    import random
    from collections import deque

    def light_neighbors(v):
        # Neighbors of v in G^(gi).  Self-loops are dropped so that degrees
        # and edge counts stay consistent.
        return [w for w, weight in query_node(v) if weight <= gi and w != v]

    first = light_neighbors(u)
    d_u = len(first)
    if d_u == 0:
        return 2

    seen = {u}
    queue = deque()
    for w in first:
        if w not in seen:
            seen.add(w)
            queue.append(w)

    explored = d_u      # adjacency entries examined so far (each edge counts twice)
    max_degree = d_u
    budget = d_u
    heads = 0
    while queue:
        if random.randint(0, 1) == 0 or len(seen) >= max_w or max_degree > d_bar:
            return 0
        heads += 1
        budget *= 2
        while queue and explored < budget:
            v = queue.popleft()
            nbrs = light_neighbors(v)
            explored += len(nbrs)
            max_degree = max(max_degree, len(nbrs))
            for w in nbrs:
                if w not in seen:
                    seen.add(w)
                    queue.append(w)

    # Queue empty: the whole component was explored, so it has explored / 2 edges.
    return d_u * 2 ** heads / (explored / 2)


def approx_cc(n, gi, eps, max_w, d_bar):
    """Estimate the number of connected components of G^(gi).

    G^(gi) keeps only edges of weight <= ``gi``.  This follows the
    approx-number-connected-components procedure of Chazelle, Rubinfeld and
    Trevisan:
      1. Draw ``r = int(1/eps) * RANDOM_SAMPLE_FACTOR`` vertices (at least
         one) uniformly with replacement.
      2. Score each one with ``_crt_beta``.
      3. Return ``n / (2 r) * sum(scores)``.

    Args:
        n: number of vertices; vertices are 0 .. n - 1.
        gi: weight threshold that defines G^(gi).
        eps: accuracy parameter; sets the sample size.
        max_w: cap W on the number of vertices a single BFS may discover.
        d_bar: degree cap; a BFS stops once it explores a vertex of larger degree.
    """
    import random

    # Sampling with replacement and dividing by the number of draws keeps the
    # estimator unbiased.  Deduplicating the sample while still dividing by r
    # would bias it downward.
    r = max(1, int(1 / eps) * RANDOM_SAMPLE_FACTOR)
    total = sum(
        _crt_beta(random.randint(0, n - 1), gi, max_w, d_bar) for _ in range(r)
    )
    return n / (2 * r) * total

def _component_at_most(u, gi, limit):
    """Return True if u's component in G^(gi) has at most ``limit`` vertices.

    Runs a breadth-first search over edges of weight <= ``gi``.  Vertices are
    marked when they are enqueued, so none is counted twice.  The search stops
    as soon as more than ``limit`` distinct vertices have been discovered, so
    at most ``limit`` vertices are ever queried.
    """
    from collections import deque

    if limit < 1:
        return False
    seen = {u}
    queue = deque([u])
    while queue:
        v = queue.popleft()
        for w, weight in query_node(v):
            if weight <= gi and w not in seen:
                seen.add(w)
                if len(seen) > limit:
                    return False
                queue.append(w)
    return True


def approx_cc_simple(n, gi, eps, max_w, d_bar):
    """Estimate the number of connected components of G^(gi) by sampling.

    Relies on the identity c = sum_v 1 / n_v, where n_v is the size of v's
    component.  For each sampled vertex u:
      - draw an independent threshold X with P(X >= k) = 1/k;
      - score a hit iff u's component has at most X vertices.
    Each score therefore has expectation 1 / n_u.

    ``r`` distinct vertices are drawn uniformly without replacement (r <= n),
    and the function returns ``n / r * hits``.  It returns 0.0 for an empty
    graph.

    ``max_w`` and ``d_bar`` are accepted for signature compatibility with
    ``approx_cc`` but are unused.
    """
    import random

    if n <= 0:
        return 0.0

    r = min(
        max(1, int(1 / eps) * min(RANDOM_SAMPLE_FACTOR, 1 + int(n / RANDOM_SAMPLE_FACTOR))),
        n,
    )

    hits = 0
    for u in random.sample(range(n), r):
        # Draw a fresh X for every vertex.  Sharing one X across samples
        # correlates all scores.  Since 1 - random() lies in (0, 1], there is
        # no division by zero, and P(X >= k) = P(1 - U <= 1/k) = 1/k.
        threshold = int(1 / (1.0 - random.random()))
        if _component_at_most(u, gi, threshold):
            hits += 1
    return n / r * hits


def approx_mst(n, eps, max_w):
    """Estimate the MST weight of a graph whose edge weights are integers in [1, max_w].

    Uses the identity from the Chazelle-Rubinfeld-Trevisan paper (for a
    connected graph):

        MST = n - max_w + sum_{i=1}^{max_w - 1} c_i

    Here c_i is the number of connected components of G^(i), the subgraph of
    edges with weight at most i.  Each c_i is estimated with
    ``approx_cc_simple``, using a BFS cap of ``4 * max_w / eps`` and the
    degree bound from ``approx_avg_degree``.
    """
    d_bar = approx_avg_degree(eps, SOME_BIG_CONSTANT, n)
    bfs_cap = 4 * max_w / eps
    component_estimates = [
        approx_cc_simple(n, i, eps, bfs_cap, d_bar)  # G^i: edges of weight <= i
        for i in range(1, max_w)
    ]
    return n - max_w + sum(component_estimates)


# Parameters: in interactive mode, read n, eps and max_w from stdin, one per
# line.  In DEBUG mode, use the values (and the graph G) defined in the
# earlier notebook cell instead, which is why this cell does not work outside
# IPython.
if not DEBUG:
    n = int(input())
    eps = float(input())
    max_w = int(input())

mst = approx_mst(n, eps, max_w)
# Report the estimate on a line that starts with "end".
print("end", mst)
