In [5]:
g0 = {
    0: {1, 2},
    1: {0, 3},
    2: {0, 3},
    3: {1, 2, 4, 5},
    4: {3, 5},
    5: {3, 4}
}

In [6]:
from copy import deepcopy
from itertools import combinations
def add_edge(g, x, y):
    g[x].add(y)
    g[y].add(x)
    
def triangulate_graph(g):
    g1 = deepcopy(g)
    g2 = deepcopy(g)
    clusters = []
    while g2:
        #TODO improve performance
        num_added = {i: len(list(filter(lambda x: not x[1] in g2[x[0]], 
                                        combinations(j, 2)))) 
                 for i,j in g2.items()}
        v = min(num_added, key=num_added.get)
        cluster = {v} | g2[v]
        if not any(map(lambda x : cluster < x, clusters)): clusters.append(cluster)
        for x, y in combinations(g2[v], 2):
            add_edge(g1, x, y)
            add_edge(g2, x, y)
        for u in g2[v]:
            g2[u].remove(v)
        del g2[v]
    return g1, clusters
            
    

In [7]:
def create_jointree(clusters):
    n_clusters = len(clusters)
    sepsets = {(i,j): len(clusters[i] & clusters[j]) for (i,j) in combinations(range(n_clusters), 2)}
    trees = [i for i in range(n_clusters)]
    jointree = []
    while len(jointree) < n_clusters - 1:
        x, y = max(sepsets, key=sepsets.get)
        if trees[x] == trees[y]: continue
        trees = [z if z!=y else x for z in trees]
        jointree.append((x, y))
        del sepsets[(x, y)]
    return jointree

In [8]:
# TODO create a rooted jointree right from the beginning
def make_rooted(jointree, root=0):
    n = max([max(i, j) for (i, j) in jointree])
    used = [False] * len(jointree)
    tree = dict()
    nodes = [0]
    while nodes:
        k = nodes.pop(0)
        tree[k] = []
        for z, (i, j) in enumerate(jointree):
            if used[z]: continue
            if i != k and j != k: continue
            if i == k: tree[k].append(j)
            elif j == k: tree[k].append(i)
            used[z] = True
        nodes += tree[k]
    return tree

In [9]:
g1, clusters = triangulate_graph(g0)
jointree = make_rooted(create_jointree(clusters))
weights = [1] * len(g0)
print(clusters)
print(jointree)

[{3, 4, 5}, {0, 1, 2}, {1, 2, 3}]
{0: [2], 1: [], 2: [1]}


In [26]:
def check_independence(sub, graph):
    vertices = list(sub)
    for v in vertices:
        if graph[v] & sub - {v}: return False
    return True
    
def mwis(graph, clusters, jointree, weights, root):
    child_mwis = dict()
    for child in jointree[root]:
        child_mwis[child] = mwis(graph, clusters, jointree, weights, child)
    root_wmis = dict()
    vertices = clusters[root]
    for n in range(1, len(vertices)+1):
        for sub_ in combinations(vertices, n):
            sub = set(sub_)
            if not check_independence(sub, graph): continue
            w = sum([weights[i] for i in sub_])
            w_ = 0
            for child in jointree[root]:
                intersection = clusters[child] & sub
                for (child_set, child_weight) in child_mwis[child].items():
                    child_set = set(child_set)
                    if (child_set & vertices) == intersection:
                        w_common = sum([weights[i] for i in sub & child_set])
                        w_total = child_weight - w_common
                        if w_total > w_:
                            w_ = w_total
            root_wmis[sub_] = w_
    return root_wmis
                    
    

In [27]:
mwis(g0, clusters, jointree, weights, 0)

{(3,): 0, (4,): 0, (5,): 0}