In [37]:
import itertools

def allPairSplits(lst):
    '''
    Generate every way of splitting `lst` into pairs.

    Each yielded value is a list of 2-tuples; together the tuples form one
    partition of `lst` (considered as an indexed set) into blocks of size 2.

    How the enumeration works. Starting with [1, 2, 3, 4, 5, 6], take off the
    first element and choose its partner from any of the remaining 5, e.g.
    (1, 4). Then take off the next remaining element, 2, and choose its partner
    among the 3 that are left, and so on: 5 * 3 * 1 = 15 partitions in total.
    Instead of tracking *which* element comes next (which would need nested
    loops / recursion), we only track the *index* of the chosen partner inside
    the list of remaining elements. The index ranges are then fixed in advance
    (range(N-1), range(N-3), ..., range(1)) and `itertools.product` enumerates
    all combinations of them.

    Parameters
    ----------
    lst : iterable
        The elements to pair up. Any iterable is accepted (list, tuple, range,
        generator, ...); it is copied once and never mutated.

    Yields
    ------
    list of tuple
        One pairing per iteration. Within a pair, the first entry is always the
        earliest not-yet-paired element of `lst`.

    Notes
    -----
    * An empty or single-element input yields exactly one empty pairing `[]`.
    * For an odd number of elements, each yielded pairing leaves one element
      unpaired (there is nobody left to pair it with).

    From selfgatoatigrado: https://stackoverflow.com/a/13020502
    '''
    # Snapshot the input so that non-list iterables work and the caller's
    # object is never touched by the pops below.
    items = list(lst)
    choice_indices = itertools.product(
        *[range(k) for k in range(len(items) - 1, 0, -2)]
    )

    for choice in choice_indices:
        # Rebuild the pairing that corresponds to this tuple of partner indices.
        remaining = items[:]
        result = []
        for index in choice:
            first = remaining.pop(0)  # earliest unpaired element
            result.append((first, remaining.pop(index)))
        yield result  # yielding lazily is cheaper than accumulating a big list here

In [38]:
# Sanity check: four elements can be split into pairs in 3 * 1 = 3 ways.
aa = allPairSplits(list(range(4)))
print([*aa])

[[(0, 1), (2, 3)], [(0, 2), (1, 3)], [(0, 3), (1, 2)]]


In [15]:
import numpy as np

In [44]:
# dimensions[n] is the number of colors node n may take.
dimensions = [2,2,1,1]
# Edges that must not appear in any pairing (treated as undirected below).
removed_connections = [(0,1),(0,2)]

num_nodes = len(dimensions)
nodes = list(range(num_nodes))
print(nodes)
color_dict = {}
# color_nodes holds every possible coloring of the nodes, one entry per
# element of the Cartesian product of the per-node color ranges.
# We distinguish between nodes but not between repeated nodes.
# [(node1,color1),(node1,color2)] = [(node1,color2),(node1,color1)]
# Each coloring is a sorted list of (node, color) *tuples*: tuples are
# hashable, so a coloring can be frozen with tuple(...) and used as a dict key.
color_nodes = [
    sorted(zip(nodes, coloring))
    for coloring in itertools.product(*(range(dim) for dim in dimensions))
]

all_uncolored_pms = list(allPairSplits(nodes))
print(len(all_uncolored_pms))

# Drop every pairing that uses a removed connection. The connections are
# normalized to (small, large) tuples, which is how allPairSplits emits edges
# for the sorted `nodes` list, so (1, 0) or [0, 1] match the edge (0, 1) too.
removed_edges = {tuple(sorted(connection)) for connection in removed_connections}
all_uncolored_pms = [
    pm for pm in all_uncolored_pms if removed_edges.isdisjoint(pm)
]

def color_pm(pm, coloring):
    '''
    Attach node colors to every edge of a pairing.

    Parameters
    ----------
    pm : iterable of (node, node)
        The pairing, e.g. [(node1, node2), (node3, node4)].
    coloring : iterable of (node, color)
        One 2-element entry per node, e.g.
        [(node1, color1), (node2, color2), (node3, color3), (node4, color4)].
        Entries may be tuples or lists and may come in any order. If a node
        appears more than once, its last entry wins.

    Returns
    -------
    list of tuple
        [(node1, node2, color1, color2), (node3, node4, color3, color4)]

    Raises
    ------
    KeyError
        If an edge refers to a node that has no entry in `coloring`.
    '''
    # Look colors up by node label. Indexing `coloring` positionally would
    # return the whole (node, color) entry instead of the color itself.
    color_of = dict(coloring)
    return [
        (edge[0], edge[1], color_of[edge[0]], color_of[edge[1]])
        for edge in pm
    ]


# Map each node coloring to the list of colored pairings it induces.
state_catalog = {}
for coloring in color_nodes:
    # Dict keys must be hashable, so freeze the coloring all the way down:
    # a tuple that still contains [node, color] lists would raise TypeError.
    coloring_key = tuple(tuple(node_color) for node_color in coloring)
    state_catalog[coloring_key] = [color_pm(pm, coloring) for pm in all_uncolored_pms]

print(len(all_uncolored_pms))


[0, 1, 2, 3]
3
[(0, 1), (2, 3)]
[(0, 2), (1, 3)]
1


In [30]:
# Scratch check: list.remove deletes the first element equal to its argument, in place.
aa = list(range(1, 4))
aa.remove(2)
print(aa)

[1, 3]
