In [None]:
import numpy as np
from numpy import linalg as LA
from collections import defaultdict

class KMeans(object):
    """K-means clustering (Lloyd's algorithm) seeded with explicit centroids.

    Args:
        K: Requested number of clusters. Stored for reference only; the number
            of clusters actually used is ``len(init)``.
        init: Initial centroids, one row per cluster (any dimension).
        max_iter: Upper bound on assign/update iterations performed by ``fit``.
    """

    def __init__(self, K, init, max_iter=300):
        self.K = K
        self.init = np.array(init)
        self.max_iter = max_iter

    def get_distance(self, point_1, point_2):
        """Return the Euclidean distance between two points of equal dimension."""
        # Take the norm of the difference vector. Subtracting the two norms
        # would give 0 for distinct points equidistant from the origin.
        diff = np.asarray(point_1, dtype=float) - np.asarray(point_2, dtype=float)
        return float(LA.norm(diff))

    def step(self, X, centroids):
        """Assign every point of ``X`` to its nearest centroid.

        Args:
            X: Sequence of points.
            centroids: Mapping from cluster id to centroid coordinates.

        Returns:
            ``(clusters, clusters_dict)``. ``clusters[i]`` is the cluster id of
            ``X[i]``. ``clusters_dict`` maps each non-empty cluster id to the
            list of points assigned to it.
        """
        keys = list(centroids.keys())
        clusters = []
        clusters_dict = {}
        for point in X:
            distances = [self.get_distance(point, centroids[k]) for k in keys]
            # argmin yields a position in `keys`; translate it back to the
            # cluster id so labels stay correct whatever the dict's key order.
            cluster = keys[int(np.argmin(distances))]
            clusters.append(cluster)
            clusters_dict.setdefault(cluster, []).append(point)
        return clusters, clusters_dict

    def update_centroids(self, X, clusters_dict):
        """Return the per-coordinate mean of each cluster's points.

        Works for points of any dimension. ``X`` is unused and kept only for
        interface compatibility.

        Returns:
            Dict mapping each cluster id in ``clusters_dict`` to a 1-D ndarray.
        """
        return {
            key: np.mean(np.asarray(points, dtype=float), axis=0)
            for key, points in clusters_dict.items()
        }

    def should_stop(self, old_centroids, new_centroids, th=0.001):
        """Return True when no shared centroid moved farther than ``th``."""
        max_dif = max(
            (
                self.get_distance(old_centroids[key], new_centroids[key])
                for key in old_centroids
                if key in new_centroids
            ),
            default=0,
        )
        return max_dif <= th

    def fit(self, X):
        """Run Lloyd's iterations from ``self.init`` and store ``self.centroids``.

        Stops once ``should_stop`` reports convergence or after
        ``self.max_iter`` iterations. Returns ``self``.
        """
        centroids = {i: c for i, c in enumerate(self.init)}
        for _ in range(self.max_iter):
            _, clusters_dict = self.step(X, centroids)
            updated = self.update_centroids(X, clusters_dict)
            # Keep every cluster, in a stable key order. A cluster that
            # received no points retains its previous centroid.
            new_centroids = {k: updated.get(k, c) for k, c in centroids.items()}
            converged = self.should_stop(centroids, new_centroids)
            centroids = new_centroids
            if converged:
                break
        self.centroids = centroids
        return self

    def predict(self, X):
        """Return an array with the cluster id of each point in ``X``.

        Uses ``self.centroids``, which is only set by ``fit``.
        """
        clusters, _ = self.step(X, self.centroids)
        return np.array(clusters)