In [None]:
#"Hard" algorithm
#https://www.cs.princeton.edu/~chazelle/pubs/mstapprox.pdf 
#"Easy" algorithm
#http://www.wisdom.weizmann.ac.il/~oded/PTW/sublin.pdf

In [None]:
import networkx as nx
import random

# Parameters of the local test instance.
n     = 10    # vertices in each copy of the base graph
max_w = 50    # edge weights are drawn uniformly from 1..max_w
c     = 15    # number of disjoint copies of the base graph
eps   = 0.05  # accuracy parameter handed to the estimator in DEBUG runs

# Base-graph factory (cycle_graph / random_regular_graph were also tried).
generator = nx.generators.complete_graph

# G is the disjoint union of c copies of the base graph.
G = generator(n)
for _copy in range(c - 1):
    G = nx.disjoint_union(G, generator(n))

n = len(G)  # total vertex count across all copies

# Random integer weights, assigned in G.edges() order.
for u, v in G.edges():
    G.edges[u, v]['weight'] = random.randint(1, max_w)

# Exact minimum spanning tree (forest) used as ground truth.
mst_gt = nx.algorithms.tree.mst.minimum_spanning_tree(G)
print(list(mst_gt.edges)[:10])
mst_gt_w = sum(G.edges[a, b]['weight'] for a, b in mst_gt.edges)
print(mst_gt_w)


In [None]:

import random
import sys
import time
from queue import Queue
from math import log10, log2

# Set to 1 to answer queries from the local networkx graph `G`
# instead of the interactive judge on stdin/stdout.
DEBUG = 0

# Base sample count of approx_cc_simple grows as RANDOM_SAMPLE_FACTOR / eps**2.
RANDOM_SAMPLE_FACTOR = 6
SOME_BIG_CONSTANT = 1500  # NOTE(review): not referenced in this cell

# vertex -> [(neighbor, edge weight), ...]
memos = dict()
# Cache intended for (vertex, weight threshold) -> filtered neighbor list.
sg_memos = dict()


def query_node_local(n):
    """Answer a neighborhood query from the in-memory networkx graph ``G``.

    Returns ``[(neighbor, weight), ...]`` for every neighbor of ``n``, in the
    order ``G.neighbors`` yields them.  ``G`` is a module-level global.
    """
    edge_view = G.edges
    return [(nbr, edge_view[n, nbr]["weight"]) for nbr in G.neighbors(n)]


def query_node_in_subgraph_n(n, w):
    """Return the distinct neighbors of ``n`` joined by edges of weight <= ``w``.

    Self-loops are dropped and duplicate ``(neighbor, weight)`` entries are
    collapsed.  Results are cached per ``(n, w)`` in ``sg_memos``.  They are
    deliberately NOT stored in ``memos``: that dict holds raw adjacency lists
    keyed by vertex id, and ``approx_cc_simple`` samples start vertices from
    its keys, so tuple keys there would be picked as bogus "vertices".
    """
    key = (n, w)
    if key not in sg_memos:
        sg_memos[key] = list({nb for nb in query_node(n) if nb[0] != n and nb[1] <= w})
    return sg_memos[key]


def query_node(n):
    """Return vertex ``n``'s adjacency list as ``[(neighbor, weight), ...]``.

    With DEBUG set, the local graph answers directly (nothing is cached).
    Otherwise the vertex id is printed to stdout as a query and a single
    response line is read from stdin.  Its first token is skipped and the
    remaining tokens are consumed as neighbor/weight integer pairs.  Answers
    are memoised in the module-level ``memos`` dict.
    """
    if DEBUG:
        return query_node_local(n)

    if n in memos:
        return memos[n]

    print(n)
    sys.stdout.flush()
    tokens = input().split()[1:]

    adjacency = memos[n] = []
    if len(tokens) < 2:
        # A reply without pairs must have nothing after the first token.
        assert len(tokens) == 0
        return adjacency

    for pos in range(0, len(tokens), 2):
        adjacency.append((int(tokens[pos]), int(tokens[pos + 1])))
    return adjacency

def approx_avg_degree(eps, C, n):
    """Sample random vertices and return the largest degree observed.

    Despite the name, the result is a sampled MAXIMUM degree, not an average.

    Args:
        eps: accuracy parameter; smaller values mean more samples.
        C: constant scaling the number of samples.
        n: number of vertices; vertex ids are ``0..n-1``.

    Returns:
        ``max(len(query_node(v)))`` over ``50 + int(C / eps)`` uniformly
        random vertices ``v``, or 0 if no samples are taken.
    """
    num_samples = 50 + int(C / eps)
    max_deg = 0
    # Sample over the full id range 0..n-1.  ``n`` must not be overwritten
    # with a smaller bound, or only a prefix of the vertex ids gets sampled.
    for _ in range(num_samples):
        node = random.randint(0, n - 1)
        max_deg = max(max_deg, len(query_node(node)))
    return max_deg



def approx_cc_simple(n, gi, eps, max_w, d_bar, time_limit=None, start_time=None, read_new_verts=True, sample_size=0.5):
    """Estimate, for each threshold w = 1..max_w, the number of connected
    components of the subgraph that keeps only edges of weight <= w.

    Each sample picks a start vertex ``u``.  For every threshold from
    ``max_w`` down to 1 it runs a BFS from ``u`` over edges of weight <= ``w``,
    capped at ``X`` steps/vertices.  An isolated ``u`` adds 1 to
    ``betas[w - 1]``.  A BFS that exhausts its component adds ``2**heads``
    and then flips a coin; a win doubles ``X`` for the remaining thresholds.
    The tallies are finally scaled by ``n / samples``.

    Args:
        n: number of vertices; ids are ``0..n-1``.
        gi: unused; kept for call compatibility.
        eps: accuracy parameter; the base sample count grows as
            ``RANDOM_SAMPLE_FACTOR / eps**2``.
        max_w: largest edge weight and length of the returned list.
        d_bar: unused; kept for call compatibility.
        time_limit: wall-clock budget in seconds from ``start_time``;
            ``None`` means no limit.
        start_time: ``time.time()`` reference for ``time_limit``; ``None``
            means "now".
        read_new_verts: if True draw ``u`` uniformly from ``0..n-1``,
            otherwise from vertices already cached in ``memos``.
        sample_size: fraction of the initial budget after which the budget
            is rescaled once by ``log2`` of the top-threshold tally.

    Returns:
        ``max_w`` numbers; entry ``w - 1`` estimates the component count for
        threshold ``w``.  All zeros if at most one sample was taken.
    """
    from collections import deque

    if start_time is None:
        start_time = time.time()

    def out_of_time():
        return time_limit is not None and (time.time() - start_time) > time_limit

    # If eps > 1, still sample some constant or fraction.
    r = 5 * min(1000, log10(n)) + int(RANDOM_SAMPLE_FACTOR / eps**2)
    # Sample count at which the one-off budget rescale happens.  Writing this
    # as ``i % r*sample_size == 0`` parses as ``(i % r) * sample_size == 0``,
    # which is never true for 0 < i < r.  max(1, ...) avoids a zero step.
    rescale_at = max(1, int(r * sample_size))

    betas = [0 for _ in range(max_w)]
    rescaled = False
    i = 0

    while i < r:
        if out_of_time():
            break

        if not rescaled and i >= rescale_at:
            rescaled = True
            # log2 is undefined at 0; leave the budget unchanged then.
            if betas and betas[-1] > 0:
                r *= max(1, log2(betas[-1]))

        i += 1

        if read_new_verts:
            u = random.randint(0, n - 1)
        else:
            u = random.choice(list(memos.keys()))

        # Randomized BFS cap int(1/U).  Using 1 - random() keeps U in (0, 1],
        # avoiding ZeroDivisionError if random() returns exactly 0.0.
        X = int(1 / (1.0 - random.random()))

        for w in range(max_w, 0, -1):
            # Distinct neighbors of u over edges of weight <= w, no self-loops.
            u_nbors = [v for v in query_node(u) if v[0] != u and v[1] <= w]
            u_nbors = list(set(u_nbors))

            if DEBUG:
                for v in u_nbors:
                    assert(u_nbors.count(v) == 1)

            if not u_nbors:
                # u is isolated at this threshold: a singleton component.
                betas[w - 1] += 1
                continue

            j = -1
            heads = 0
            visited_nodes = {u}
            # u's neighbors are already queued; mark them so a later vertex
            # does not enqueue them a second time.
            added_to_queue = {u} | {nb[0] for nb in u_nbors}
            u_nbors_queue = deque(u_nbors)

            while u_nbors_queue:
                if out_of_time():
                    break

                v = u_nbors_queue.popleft()
                j += 1
                visited_nodes.add(v[0])
                # Expand neighbors that are already cached first.
                v_nbors = sorted(query_node(v[0]), key=lambda x: x[0] in memos, reverse=True)
                for t in v_nbors:
                    if t[0] not in visited_nodes and t[0] not in added_to_queue and t[1] <= w:
                        u_nbors_queue.append(t)
                        added_to_queue.add(t[0])

                # Abandon components larger than the current cap.
                if j > X or len(visited_nodes) > X or len(added_to_queue) > X:
                    break

                if not u_nbors_queue:
                    # Component fully explored within the cap.
                    betas[w - 1] += 2**heads

                    toss = random.randint(0, 1)
                    if toss == 0:
                        break
                    X *= 2
                    heads += 1

    if i <= 1:
        return [0 for _ in range(max_w)]

    return [n / i * b for b in betas]


def approx_mst(n, eps, max_w):
    """Estimate the MST weight from per-threshold component counts.

    Computes  n + sum_{w=1}^{max_w-1} c_w - max_w * c_{max_w},  where c_w is
    the (rounded, clamped to [1, n]) estimated component count for edges of
    weight <= w.  eps is first shrunk by a factor of (1 + max_w / 2), then
    divided by 2.5 (or by 300 when max_w <= 1) before being passed to
    approx_cc_simple under a 2.9 s budget.
    """
    d_bar = 0
    shrunk_eps = eps / (1 + 0.5 * max_w)
    divisor = 2.5 if max_w > 1 else 300
    estimates = approx_cc_simple(n, max_w, shrunk_eps / divisor, max_w, d_bar, time_limit=2.9, start_time=time.time(), read_new_verts=True)

    c_bars = [min(n, max(1, round(est))) for est in estimates]
    c_bars[-1] = max(c_bars[-1], 1)

    if DEBUG == 1:
        print("est components", c_bars[-1])
        print("actual components", c)
        print("raw score", n + sum(c_bars[1:]))

    *below_top, top = c_bars
    return n + sum(below_top) - top * max_w


if not DEBUG:
    n = int(input())
    eps = float(input())
    max_w = int(input())


# Degenerate input (fewer than two vertices, or max_w < 1) is answered
# with 0.  Emit exactly one "end" line and skip the estimator, instead of
# printing "end 0" per failed check and then a second estimate.
# Group 6 has max_w > 50.
if n < 2 or max_w < 1:
    mst = 0
else:
    mst = approx_mst(n, eps, max_w)
    # Clamp the estimate into [n / 2, (n - 1) * max_w].
    mst = max(n/2, mst)
    mst = min((n - 1) * max_w, mst)

print("end " + str(mst))
sys.stdout.flush()


if DEBUG == 1:
    print("gt end", mst_gt_w)

In [None]:

import random
import sys
import time
from queue import Queue
from math import log10, log2

# Set to 1 to answer queries from the local networkx graph `G`
# instead of the interactive judge on stdin/stdout.
DEBUG = 0

# Base sample count of approx_cc_simple grows as RANDOM_SAMPLE_FACTOR / eps**2.
RANDOM_SAMPLE_FACTOR = 125
# Constant passed as C to approx_avg_degree (sample count 50 + C / eps).
SOME_BIG_CONSTANT = 10

# vertex -> [(neighbor, edge weight), ...]
memos = dict()
# Cache intended for (vertex, weight threshold) -> filtered neighbor list.
sg_memos = dict()


def query_node_local(n):
    """Answer a neighborhood query from the in-memory networkx graph ``G``.

    Returns ``[(neighbor, weight), ...]`` for every neighbor of ``n``, in the
    order ``G.neighbors`` yields them.  ``G`` is a module-level global.
    """
    edge_view = G.edges
    return [(nbr, edge_view[n, nbr]["weight"]) for nbr in G.neighbors(n)]


def query_node_in_subgraph_n(n, w):
    """Return the distinct neighbors of ``n`` joined by edges of weight <= ``w``.

    Self-loops are dropped and duplicate ``(neighbor, weight)`` entries are
    collapsed.  Results are cached per ``(n, w)`` in ``sg_memos``.  They are
    deliberately NOT stored in ``memos``: that dict holds raw adjacency lists
    keyed by vertex id, and ``approx_cc_simple`` samples start vertices from
    its keys, so tuple keys there would be picked as bogus "vertices".
    """
    key = (n, w)
    if key not in sg_memos:
        sg_memos[key] = list({nb for nb in query_node(n) if nb[0] != n and nb[1] <= w})
    return sg_memos[key]


def query_node(n):
    """Return vertex ``n``'s adjacency list as ``[(neighbor, weight), ...]``.

    With DEBUG set, the local graph answers directly (nothing is cached).
    Otherwise the vertex id is printed to stdout as a query and a single
    response line is read from stdin.  Its first token is skipped and the
    remaining tokens are consumed as neighbor/weight integer pairs.  Answers
    are memoised in the module-level ``memos`` dict.
    """
    if DEBUG:
        return query_node_local(n)

    if n in memos:
        return memos[n]

    print(n)
    sys.stdout.flush()
    tokens = input().split()[1:]

    adjacency = memos[n] = []
    if len(tokens) < 2:
        # A reply without pairs must have nothing after the first token.
        assert len(tokens) == 0
        return adjacency

    for pos in range(0, len(tokens), 2):
        adjacency.append((int(tokens[pos]), int(tokens[pos + 1])))
    return adjacency

def approx_avg_degree(eps, C, n):
    """Sample ``50 + int(C / eps)`` random vertices and return the largest
    degree seen.

    Despite the name this is a sampled MAXIMUM, not an average.  approx_mst
    passes it to approx_cc_simple as ``d_bar``, where a BFS stops once a
    vertex has more than ``d_bar`` eligible neighbors.
    """
    sample_count = 50 + int(C / eps)
    degrees = (len(query_node(random.randint(0, n - 1))) for _ in range(sample_count))
    return max(degrees, default=0)



def approx_cc_simple(n, gi, eps, max_w, d_bar, time_limit=None, start_time=None, read_new_verts=True, sample_size=0.5, retry=False):
    """Estimate, for each threshold w = 1..max_w, the number of connected
    components of the subgraph that keeps only edges of weight <= w.

    Each sample picks a start vertex ``u``.  For every threshold from
    ``max_w`` down to 1 it explores u's component with a BFS, which is
    abandoned when:

    * a visited vertex has more than ``d_bar`` eligible neighbors;
    * more than ``max_e`` distinct edges have been seen; or
    * a coin flip is lost, which is checked each time the edge count
      exceeds the doubling cap ``X``.

    An isolated ``u`` adds 2 to ``betas[w - 1]``.  A fully explored
    component adds ``2**heads * deg(u) / |E(component)|``.  Summed over a
    component's vertices, deg/|E| gives 2, hence the final division by 2.

    Args:
        n: number of vertices; ids are ``0..n-1``.
        gi: unused; kept for call compatibility.
        eps: accuracy parameter; the base sample count grows as
            ``RANDOM_SAMPLE_FACTOR / eps**2`` and the edge budget as
            ``4 * max_w / eps``.
        max_w: largest edge weight and length of the returned list.
        d_bar: per-vertex degree cap for the BFS.
        time_limit: wall-clock budget in seconds from ``start_time``;
            ``None`` means no limit.
        start_time: ``time.time()`` reference for ``time_limit``; ``None``
            means "now".
        read_new_verts: if True draw ``u`` uniformly from ``0..n-1``,
            otherwise from vertices already cached in ``memos``.
        sample_size: fraction of the initial budget after which a score
            snapshot is taken.
        retry: if True, on the last scheduled sample the current score is
            compared with the snapshot.  When they differ by more than
            ``0.1 * n * eps``, the budget doubles and the snapshot is updated.

    Returns:
        ``max_w`` numbers; entry ``w - 1`` estimates the component count for
        threshold ``w``.  All zeros if at most one sample was taken.
    """
    from collections import deque

    if start_time is None:
        start_time = time.time()

    def out_of_time():
        return time_limit is not None and (time.time() - start_time) > time_limit

    betas = [0 for _ in range(max_w)]

    def tally_so_far(samples):
        # Current per-threshold estimates after `samples` samples.
        return [n / samples * b / 2 for b in betas]

    def score(tally):
        # Same combination approx_mst applies to the component counts.
        if not tally:
            return n
        return n + sum(tally[:-1]) - tally[-1] * max_w

    # If eps > 1, still sample some constant or fraction.
    r = 5 * min(1000, log10(n)) + int(RANDOM_SAMPLE_FACTOR / eps**2)
    max_e = 4 * max_w / eps

    # Sample count at which the first snapshot is taken.  max(1, ...) avoids
    # a modulo/zero-step crash when r * sample_size < 1.
    snapshot_at = max(1, int(r * sample_size))
    sampled = False
    first_half_score = None

    i = 0
    while i < r:
        if out_of_time():
            break

        # Convergence check on the last scheduled sample.  r is usually a
        # float (it contains 5*log10(n)), so ``i == r - 1`` would almost
        # never hold; ``i + 1 >= r`` selects the last integer i below r.
        if retry and sampled and i + 1 >= r:
            full_score = score(tally_so_far(i))
            if abs(full_score - first_half_score) > 0.1 * n * eps:
                r *= 2
                first_half_score = full_score

        if not sampled and i >= snapshot_at:
            sampled = True
            first_half_score = score(tally_so_far(i))

        i += 1

        if read_new_verts:
            u = random.randint(0, n - 1)
        else:
            u = random.choice(list(memos.keys()))

        for w in range(max_w, 0, -1):
            # Distinct neighbors of u over edges of weight <= w, no self-loops.
            u_nbors = [v for v in query_node(u) if v[0] != u and v[1] <= w]
            u_nbors = list(set(u_nbors))

            if DEBUG:
                for v in u_nbors:
                    assert(u_nbors.count(v) == 1)

            if not u_nbors:
                # Isolated at this threshold; counted as 2, halved at the end.
                betas[w - 1] += 2
                continue

            heads = 0
            du = len(u_nbors)
            X = du  # edge-count cap; doubles on each won coin flip
            visited_nodes = {u}
            # u's neighbors are already queued; mark them so a later vertex
            # does not enqueue them a second time.
            added_to_queue = {u} | {nb[0] for nb in u_nbors}
            seen_edges = set([(min(u, e[0]), max(u, e[0])) for e in u_nbors])
            u_nbors_queue = deque(u_nbors)

            while u_nbors_queue:
                if out_of_time():
                    break

                v = u_nbors_queue.popleft()
                visited_nodes.add(v[0])
                # Eligible edges out of v (weight <= w, self-loops dropped),
                # with already-cached endpoints first.
                v_nbors = sorted(set(query_node(v[0])), key=lambda x: x[0] in memos, reverse=True)
                v_nbors = [t for t in v_nbors if t[1] <= w and t[0] != v[0]]

                if len(v_nbors) > d_bar:
                    break  # degree above the sampled bound: give up

                if len(seen_edges) > max_e:
                    break  # component too large to explore

                for t in v_nbors:
                    seen_edges.add((min(v[0], t[0]), max(v[0], t[0])))
                    if t[0] not in visited_nodes and t[0] not in added_to_queue:
                        u_nbors_queue.append(t)
                        added_to_queue.add(t[0])

                if not u_nbors_queue:
                    # Whole component explored: credit deg(u) / |E(component)|.
                    assert(du <= len(seen_edges))
                    betas[w - 1] += 2**heads * du / len(seen_edges)
                    if DEBUG:
                        print("heads", heads, "len seen", len(seen_edges), "w", w)
                        print("seen edges", seen_edges)
                        print("est avg degree", d_bar)
                        print("du / seen dg", du, "/", len(seen_edges))
                    break

                if len(seen_edges) > X:
                    toss = random.randint(0, 1)
                    if toss == 0:
                        break
                    X *= 2
                    heads += 1

    if i <= 1:
        return [0 for _ in range(max_w)]

    if DEBUG:
        print([n / i * b for b in betas])

    return [n / i * b / 2 for b in betas]


#n + sum(c_bars[:-1]) - c_bars[-1] * max_w 


def approx_mst(n, eps, max_w):
    """Estimate the MST weight from per-threshold component counts.

    Computes  n + sum_{w=1}^{max_w-1} c_w - max_w * c_{max_w},  where c_w is
    the (rounded, clamped to [1, n]) estimated component count for edges of
    weight <= w.  A sampled degree bound from approx_avg_degree is used as
    the BFS degree cap, and approx_cc_simple runs under a 2.9 s budget with
    retry enabled.
    """
    d_bar = approx_avg_degree(eps, SOME_BIG_CONSTANT, n)
    # The max_w > 1 and max_w <= 1 cases make the same call.
    estimates = approx_cc_simple(n, max_w, eps, max_w, d_bar, time_limit=2.9, start_time=time.time(), read_new_verts=True, retry=True)

    c_bars = [min(n, max(1, round(est))) for est in estimates]
    c_bars[-1] = max(c_bars[-1], 1)

    if DEBUG == 1:
        print("est components", c_bars[-1])
        print("actual components", c)
        print("raw score", n + sum(c_bars[1:]))

    *below_top, top = c_bars
    return n + sum(below_top) - top * max_w


if not DEBUG:
    n = int(input())
    eps = float(input())
    max_w = int(input())


# Degenerate input (fewer than two vertices, or max_w < 1) is answered
# with 0.  Emit exactly one "end" line and skip the estimator, instead of
# printing "end 0" per failed check and then a second estimate.
# Group 6 has max_w > 50.
if n < 2 or max_w < 1:
    mst = 0
else:
    mst = approx_mst(n, eps, max_w)
    # Clamp the estimate into [n / 2, (n - 1) * max_w].
    mst = max(n/2, mst)
    mst = min((n - 1) * max_w, mst)

print("end " + str(mst))
sys.stdout.flush()

if DEBUG == 1:
    print("gt end", mst_gt_w)

