# In [91]:
import random
import math

def cop_kmeans(dataset, k, ml=None, cl=None):
    """Cluster ``dataset`` into ``k`` groups with COP-KMeans.

    Each point is assigned to the nearest center whose cluster does not
    violate a must-link or cannot-link constraint; centers are then
    recomputed, and the process repeats until the assignment stops changing.

    Args:
        dataset: Sequence of equal-length numeric points.
        k: Number of clusters.
        ml: Iterable of ``(i, j)`` index pairs that must share a cluster.
        cl: Iterable of ``(i, j)`` index pairs that must not share a cluster.

    Returns:
        A ``(clusters, centers)`` tuple, where ``clusters[i]`` is the cluster
        index of ``dataset[i]``; or ``None`` when some point cannot be placed
        in any cluster without violating a constraint.
    """
    ml = [] if ml is None else ml
    cl = [] if cl is None else cl
    ml, cl = transitive_closure(ml, cl, len(dataset))

    centers = initialize_centers(dataset, k)
    clusters = None
    clusters_ = []

    # Iterate until two consecutive passes produce the same assignment.
    while clusters != clusters_:
        clusters = clusters_
        clusters_ = [-1] * len(dataset)  # -1 marks "not yet assigned"
        for i, d in enumerate(dataset):
            # Try candidate clusters from nearest to farthest.
            for index in closest_clusters(centers, d):
                if not violate_constraints(i, index, clusters_, ml, cl):
                    clusters_[i] = index
                    break
            else:
                # Every cluster violates a constraint for this point.
                return None

        centers = compute_centers(clusters_, dataset, k)

    return clusters, centers

def euclidean_distance(point1, point2):
    """Return the Euclidean (L2) distance between two coordinate sequences.

    Coordinates are paired positionally with ``zip``, so trailing coordinates
    of the longer sequence are ignored.
    """
    squared_diffs = ((a - b) ** 2 for a, b in zip(point1, point2))
    return math.sqrt(sum(squared_diffs))

def closest_clusters(centers, datapoint):
    """Return the indices of ``centers`` ordered from nearest to farthest.

    Distance to ``datapoint`` is measured with ``euclidean_distance``; the
    sort is stable, so equally distant centers keep ascending index order.
    """
    gaps = []
    for center in centers:
        gaps.append(euclidean_distance(center, datapoint))

    ranking = list(range(len(gaps)))
    ranking.sort(key=gaps.__getitem__)
    return ranking

# under-specified in the paper
def initialize_centers(dataset, k):
    """Pick initial centers by drawing ``k`` distinct data points at random.

    Args:
        dataset: Sequence of data points.
        k: Number of centers wanted.

    Returns:
        A list of up to ``k`` points taken from ``dataset`` (the same
        objects, not copies); fewer than ``k`` when the dataset is smaller.
    """
    # random.shuffle needs a mutable sequence; a Python 3 range is immutable.
    indices = list(range(len(dataset)))
    random.shuffle(indices)
    return [dataset[index] for index in indices[:k]]

def violate_constraints(data_index, cluster_index, clusters, ml, cl):
    """Tell whether putting a point into a cluster breaks a constraint.

    Args:
        data_index: Index of the point being placed.
        cluster_index: Candidate cluster for that point.
        clusters: Current assignment per point; ``-1`` means unassigned.
        ml: Mapping from a point index to the indices it must be grouped with.
        cl: Mapping from a point index to the indices it must be kept apart
            from.

    Returns:
        ``True`` if an already-assigned must-link partner sits in a different
        cluster, or a cannot-link partner already sits in ``cluster_index``;
        ``False`` otherwise.
    """
    def partners(graph):
        # A point without constraints may have no entry in the graph.
        try:
            return graph[data_index]
        except (KeyError, IndexError):
            return ()

    for i in partners(ml):
        # Unassigned partners (-1) cannot conflict yet.
        if clusters[i] != -1 and clusters[i] != cluster_index:
            return True

    for i in partners(cl):
        if clusters[i] == cluster_index:
            return True

    return False

def compute_centers(clusters, dataset, k=None):
    """Return the mean point of every cluster.

    Args:
        clusters: Cluster index for each data point, parallel to ``dataset``.
            Entries equal to no index in ``range(k)`` (for example ``-1``)
            are ignored.
        dataset: Sequence of equal-length numeric points.
        k: Number of clusters; defaults to ``max(clusters) + 1``.

    Returns:
        A list of ``k`` centers, each a list of ``len(dataset[0])`` floats.
        A cluster without members gets the all-zero point, which avoids a
        division by zero.
    """
    if k is None:
        k = max(clusters) + 1
    dim = len(dataset[0])

    centers = []
    for i in range(k):
        members = [dataset[j] for j, c in enumerate(clusters) if c == i]
        if not members:
            centers.append([0.0] * dim)
            continue
        size = float(len(members))
        # Average coordinate by coordinate.
        centers.append([sum(p[d] for p in members) / size for d in range(dim)])
    return centers

def transitive_closure(ml, cl, n):
    """Expand pairwise constraints into their full transitive closure.

    Must-link is treated as an equivalence relation: every point in a
    connected must-link component is linked to every other point in it.
    A cannot-link pair ``(i, j)`` is then extended to every pair made of one
    point from ``i``'s must-link group and one from ``j``'s.

    Args:
        ml: Iterable of ``(i, j)`` must-link index pairs.
        cl: Iterable of ``(i, j)`` cannot-link index pairs.
        n: Number of data points; indices ``0 .. n-1`` always get an entry.

    Returns:
        A ``(ml_graph, cl_graph)`` tuple of dicts mapping each point index to
        the set of indices it must (respectively must not) share a cluster
        with. Unconstrained points map to an empty set.
    """
    def add_both(graph, i, j):
        graph.setdefault(i, set()).add(j)
        graph.setdefault(j, set()).add(i)

    # Seed every point so lookups never fail for unconstrained indices.
    ml_graph = {i: set() for i in range(n)}
    cl_graph = {i: set() for i in range(n)}

    for i, j in ml:
        add_both(ml_graph, i, j)

    # Find connected components with an explicit stack (no recursion limit),
    # then link every member of a component to all the others.
    visited = set()
    for start in list(ml_graph):
        if start in visited:
            continue
        visited.add(start)
        component = []
        stack = [start]
        while stack:
            node = stack.pop()
            component.append(node)
            for neighbor in ml_graph[node]:
                if neighbor not in visited:
                    visited.add(neighbor)
                    stack.append(neighbor)
        members = set(component)
        for node in component:
            ml_graph[node] = members - {node}

    for i, j in cl:
        group_i = ml_graph.get(i, set()) | {i}
        group_j = ml_graph.get(j, set()) | {j}
        for x in group_i:
            for y in group_j:
                add_both(cl_graph, x, y)

    return ml_graph, cl_graph