In [46]:
from copy import deepcopy
from itertools import combinations
from gurobipy import Model, GRB
from random import randint
import pickle

In [36]:
g0 = {
    0: {1, 2},
    1: {0, 3},
    2: {0, 3},
    3: {1, 2, 4, 5},
    4: {3, 5},
    5: {3, 4}
}

w0 = {0:2, 1:3, 2:1, 3:4, 4:5, 5:0}

In [2]:

def add_edge(g, x, y):
    g[x].add(y)
    g[y].add(x)
    
def triangulate_graph(g):
    g1 = deepcopy(g)
    g2 = deepcopy(g)
    clusters = []
    while g2:
        #TODO improve performance
        num_added = {i: len(list(filter(lambda x: not x[1] in g2[x[0]], 
                                        combinations(j, 2)))) 
                 for i,j in g2.items()}
        v = min(num_added, key=num_added.get)
        cluster = {v} | g2[v]
        if not any(map(lambda x : cluster < x, clusters)): clusters.append(cluster)
        for x, y in combinations(g2[v], 2):
            add_edge(g1, x, y)
            add_edge(g2, x, y)
        for u in g2[v]:
            g2[u].remove(v)
        del g2[v]
    return g1, clusters
            
    

In [3]:
def create_jointree(clusters):
    n_clusters = len(clusters)
    sepsets = {(i,j): len(clusters[i] & clusters[j]) for (i,j) in combinations(range(n_clusters), 2)}
    trees = [i for i in range(n_clusters)]
    jointree = []
    while len(jointree) < n_clusters - 1:
        x, y = max(sepsets, key=sepsets.get)
        if trees[x] == trees[y]: continue
        trees = [z if z!=y else x for z in trees]
        jointree.append((x, y))
        del sepsets[(x, y)]
    return jointree

In [4]:
# TODO create a rooted jointree right from the beginning
def make_rooted(jointree, root=0):
    n = max([max(i, j) for (i, j) in jointree])
    used = [False] * len(jointree)
    tree = dict()
    nodes = [0]
    while nodes:
        k = nodes.pop(0)
        tree[k] = []
        for z, (i, j) in enumerate(jointree):
            if used[z]: continue
            if i != k and j != k: continue
            if i == k: tree[k].append(j)
            elif j == k: tree[k].append(i)
            used[z] = True
        nodes += tree[k]
    return tree

In [32]:
def check_independence(sub, graph):
    for v in sub:
        if graph[v] & sub: return False
    return True
    
def mwis(graph, clusters, jointree, weights, root):
    child_mwis = dict()
    for child in jointree[root]:
        child_mwis[child] = mwis(graph, clusters, jointree, weights, child)
    root_wmis = dict()
    vertices = clusters[root]
    for n in range(1, len(vertices)+1):
        for sub_ in combinations(vertices, n):
            sub = set(sub_)
            if not check_independence(sub, graph): continue
            w = sum([weights[i] for i in sub_])
            w_ = None
            for child in jointree[root]:
                intersection = clusters[child] & sub
                for (child_set, child_weight) in child_mwis[child].items():
                    if not child_weight: continue
                    child_set = set(child_set)
                    if (child_set & vertices) == intersection:
                        w_common = sum([weights[i] for i in sub & child_set])
                        w_total = child_weight - w_common
                        if not w_ or w_total > w_:
                            w_ = w_total
            root_wmis[sub_] = w + w_ if w_ else w
    return root_wmis
                    
    

In [37]:
def get_mwis(graph, weights):
    _, clusters = triangulate_graph(graph)
    jointree = make_rooted(create_jointree(clusters))
    res = mwis(graph, clusters, jointree, weights, 0)
    return max(res.values())
    

In [65]:
def mip_mis(graph, weights):
    m = Model()
    x = dict()
    for v in graph:
        x[v] = m.addVar(vtype=GRB.BINARY, obj=weights[v])
    m.update()
    for v in graph:
        for u in graph[v]:
            if u < v:
                m.addConstr(x[u] + x[v] <= 1)
    m.ModelSense = -1
    m.params.OutputFlag = 0
    m.optimize()
    return m.ObjVal, [i for i in x if x[i].x]

In [55]:
fname = 'data/g10.pkl'
with open(fname, 'rb') as f:
    g = pickle.load(f)
    w = {x: randint(1, 10) for x in g}
    print(get_mwis(g, w))
    print(mip_mis(g, w))

24
28.0


In [56]:
get_mwis(g, w)

24

In [66]:
mip_mis(g, w)

(28.0, [0, 4, 7])