In [1]:
"""
Credit: 
    Min-set-cover solver: https://gist.github.com/marekyggdrasil/a8e63be8e34e000f2507bdb5e0755dda
"""

# from util import read_lines, write_lines
from nltk import FreqDist
from collections import defaultdict
from dlx import DLX
import random

In [2]:
def read_lines(filename, encoding="utf-8"):
    """
    Load a text file into a list with one element per line.

    Each line has its leading and trailing whitespace (including the
    newline) removed, so blank lines come back as empty strings.

    Args:
        filename: path of the file to read.
        encoding: text encoding used to decode the file. Defaults to
            UTF-8 so the result does not depend on the platform locale
            and matches what ``write_lines`` produces.

    Returns:
        A list of stripped lines, in file order.
    """
    with open(filename, 'r', encoding=encoding) as fp:
        # Iterating the handle yields the same lines as readlines().
        stripped = [line.strip() for line in fp]
    print("Done reading file", filename)

    return stripped

def write_lines(filename, lines):
    """
    Dump an iterable to a UTF-8 text file, one item per line.

    Every item is converted with ``str`` and followed by a newline.
    Any existing file at ``filename`` is overwritten.
    """
    with open(filename, 'w', encoding="utf-8") as out:
        out.writelines(str(item) + "\n" for item in lines)
    print("Done writing to file %s." % filename)

In [3]:
lines = read_lines("word_freq_list_with_roots.txt")

# entries   : kept rows, each a list of fields whose last item is a list of roots
# all_roots : every root of every kept row (repeats included), used for counting
# roots_set : root tuples already seen, used to drop rows with an identical root list
entries, all_roots, roots_set = [], [], set()

# For a quick experiment, loop over random.sample(lines, 150) instead of lines.
for line in lines:
    split = line.split('\t')
    if len(split) != 4:
        # Only rows with exactly four tab-separated fields carry a roots column.
        continue

    roots = split[-1].split(', ')
    signature = tuple(roots)
    if signature in roots_set:
        # Same roots in the same order were already kept; for the root
        # cover one representative is enough.
        continue
    roots_set.add(signature)

    split[-1] = roots  # replace the raw string by the parsed list
    entries.append(split)
    all_roots += roots

print('Total number of lines left:', len(entries))

# Roots ordered by how often they occur (most_common order), then numbered
# in that order.
dist = dict(FreqDist(all_roots).most_common())
vocab_roots = list(dist)
root2idx = dict(zip(vocab_roots, range(len(vocab_roots))))

# Reverse index: root -> every kept entry that contains it.
root2entries = defaultdict(list)
for entry in entries:
    for root in entry[-1]:
        root2entries[root].append(entry)

# for root in dist:
#     print(root)
#     print("")
#     print(root2entries[root])
#     input("wait")

Done reading file word_freq_list_with_roots.txt
Total number of lines left: 549


In [None]:
def to_str(entry):
    """
    Serialise an entry back into a single tab-separated line.

    All fields except the last are joined with tabs. The last field is a
    collection of roots: when it equals the empty list nothing more is
    appended, otherwise the roots are joined with ``', '`` and added as
    a final tab-separated column.
    """
    *fields, roots = entry
    head = '\t'.join(fields)
    if roots == []:
        return head
    return head + '\t' + ', '.join(roots)

In [None]:
def solve(X, Y, solution=None):
    """
    Lazily enumerate exact covers of the columns in ``X`` by rows of ``Y``.

    The search repeatedly picks the column with the fewest candidate
    rows, tries each of those rows in turn, and recurses on the reduced
    problem produced by ``select``. A solution is reached when no column
    is left in ``X``.

    Args:
        X: dict mapping each still-uncovered column to the set of rows
            that contain it. It is mutated during the search and put
            back to its incoming state when the generator finishes or
            is closed.
        Y: dict mapping each row to the sequence of columns it contains.
            Not modified.
        solution: optional list used as the accumulator of chosen rows.
            A fresh list is created when omitted, so separate calls
            never share state.

    Yields:
        A new list with the chosen rows, one list per solution found.
    """
    # A mutable default would be shared by every call; an abandoned
    # search would then leave stale rows behind for the next one.
    if solution is None:
        solution = []

    if not X:
        yield list(solution)
        return

    # Branch on the most constrained column to keep the search tree narrow.
    column = min(X, key=lambda col: len(X[col]))
    for row in list(X[column]):
        solution.append(row)
        cols = select(X, Y, row)
        try:
            yield from solve(X, Y, solution)
        finally:
            # Undo the choice even if the consumer stops iterating early
            # (e.g. takes only the first solution), so X and the
            # accumulator are not left half-reduced.
            deselect(X, Y, row, cols)
            solution.pop()

def select(X, Y, r):
    """
    Commit to row ``r`` and shrink the problem accordingly.

    For every column that ``r`` contains, each row sharing that column is
    removed from all the *other* columns it appears in, and then the
    column itself is taken out of ``X``.

    Args:
        X: dict column -> set of rows; modified in place.
        Y: dict row -> sequence of columns.
        r: the row being chosen.

    Returns:
        The removed column sets, in the order of ``Y[r]``, so that
        ``deselect`` can restore them.
    """
    removed = []
    for col in Y[r]:
        for clashing_row in X[col]:
            for other_col in Y[clashing_row]:
                if other_col == col:
                    # The column itself is dropped wholesale below.
                    continue
                X[other_col].remove(clashing_row)
        removed.append(X.pop(col))
    return removed

def deselect(X, Y, r, cols):
    """
    Undo a previous ``select`` of row ``r``.

    Walks the columns of ``r`` in reverse order, puts each saved column
    set back into ``X`` and re-adds every row of that column to the other
    columns it belongs to.

    Args:
        X: dict column -> set of rows; modified in place.
        Y: dict row -> sequence of columns.
        r: the row whose selection is being reverted.
        cols: the list returned by the matching ``select`` call; it is
            consumed (popped from the end) while restoring.
    """
    for col in reversed(Y[r]):
        restored = cols.pop()
        X[col] = restored
        for row in restored:
            for other_col in Y[row]:
                if other_col == col:
                    continue
                X[other_col].add(row)

In [None]:
# Column ids: one per distinct root, numbered as in root2idx.
indices = list(range(len(vocab_roots)))

# One row per entry, holding the column ids of its roots.
# Duplicates inside a row are dropped (order kept): a repeated column makes
# select() remove the same item twice and fail with KeyError.
# A row with no columns is harmless: it is never added to X, so the solver
# cannot pick it.
rows = [
    list(dict.fromkeys(root2idx[root] for root in entry[-1]))
    for entry in entries]

# Y: row id -> columns it covers;  X: column id -> set of row ids covering it.
Y = {i: row for i, row in enumerate(rows)}
X = {j: set() for j in indices}
for i in Y:
    for j in Y[i]:
        X[j].add(i)

# Only the first solution is needed, so stop the generator there instead of
# enumerating every cover. A fresh accumulator list is passed explicitly so
# the result never depends on state left over from an earlier call.
selected = next(solve(X, Y, []), None)
if selected is None:
    raise ValueError("solve() found no selection of entries covering every root")

In [None]:
print(f'Number of selected entries for min-set-cover: {len(selected)}/{len(entries)}')

# Order the chosen entries by their first field, read as an integer
# (presumably a rank or id column -- confirm against the input file).
selected_entries = sorted(
    (entries[i] for i in selected),
    key=lambda chosen: int(chosen[0]),
)

# Turn each entry back into a tab-separated line and save the result.
str_selected_entries = [to_str(chosen) for chosen in selected_entries]
write_lines("selected_entries.txt", str_selected_entries)

In [None]:
# Show the contents of selected_entries.txt; needs a shell that provides `cat`.
!cat selected_entries.txt

In [None]:
# def genInstance(labels, rows) :
#     columns = []
#     indices_l = {}
#     for i in range(len(labels)) :
#         label = labels[i]
#         indices_l[label] = i
#         columns.append(tuple([label,0]))
#     return labels, rows, columns, indices_l

# def solveInstance(instance) :
#     labels, rows, columns, indices_l = instance
#     instance = DLX(columns)
#     indices = {}
#     for l, i in zip(rows, range(len(rows))) :
#         h = instance.appendRow(l, 'r'+str(i))
#         indices[str(hash(tuple(sorted(l))))] = i
#     sol = instance.solve()
#     lst = list(sol)
#     selected = []
#     for i in lst[0] :
#         l = instance.getRowList(i)
#         l2 = [indices_l[label] for label in l]
#         idx = indices[str(hash(tuple(sorted(l2))))]
#         selected.append(idx)
#     return selected

# def printColumnsPerRow(instance, selected) :
#     labels, rows, columns, indices_l = instance
#     print('covered columns per selected row')
#     for s in selected :
#         A = []
#         for z in rows[s-1] :
#             c, _ = columns[z]
#             A.append(c)
#         print(s, A)

# def printInstance(instance) :
#     labels, rows, columns, indices_l = instance
#     print('columns')
#     print(labels)
#     print('rows')
#     print(rows)

In [None]:
# labels = ['A', 'B', 'C', 'D', 'E', 'F', 'G']
# rows = [[],[0,3,6],[0,3],[3,4,6],[2,4,5],[1,2,5,6],[1,6]]
# instance = genInstance(labels, rows)
# selected = solveInstance(instance)
# printInstance(instance)
# printColumnsPerRow(instance, selected)