# In [None]:  (notebook cell marker left over from the export)
def kmeans_fit(data, k, max_iter, q, init):
  """Fit k-means centroids on a Spark DataFrame, leaving out every q-th point.

  Every iteration assigns each point to its nearest centroid (Euclidean
  distance), ranks the points of each cluster by that distance, drops the
  rows whose rank is a multiple of `q`, and moves each centroid to the mean
  of the remaining points. Iteration stops after `max_iter` rounds or as soon
  as no centroid coordinate changes.

  Args:
    data: Spark DataFrame. All of its columns together form the coordinates
      of a point, in column order.
    k: Number of centroid labels (0 .. k-1) that points are assigned to.
    max_iter: Maximum number of iterations.
    q: Within each cluster, points whose distance rank is a multiple of `q`
      are excluded from the mean.
    init: Initial centroids, in any form accepted by
      `spark.createDataFrame`; one row per centroid. Labels are assigned
      with `row_number` over `monotonically_increasing_id` (intended to
      follow the input row order). Should contain `k` rows; rows beyond
      the first `k` are never updated.

  Returns:
    A Spark DataFrame with a single column 'centroid_val' (array<double>),
    one row per row of `init`, ordered by centroid label.

  Raises:
    IndexError: If `init` yields no centroid, or (once an iteration runs)
      fewer than `k` centroids.
  """
  from pyspark.sql import functions
  from pyspark.sql.window import Window
  from pyspark import StorageLevel
  from pyspark.sql import SparkSession
  from pyspark.sql.types import ArrayType, DoubleType, StructField, StructType

  spark = SparkSession.builder.getOrCreate()

  def _as_floats(values):
    # Normalise coordinates to Python floats (keeping nulls) so that every
    # centroid fits the array<double> schema of the returned DataFrame.
    return [None if v is None else float(v) for v in values]

  # Points pre-processing: combine the input columns into an ArrayType column
  # named 'point', and add the 'centroid' / 'distance' bookkeeping columns.
  # The distance starts at +infinity so the first centroid examined always
  # wins. The frame is persisted (MEMORY_ONLY) because every iteration
  # restarts from it.
  original_data = data.withColumn('point', functions.array(data.columns))\
                  .withColumn("distance", functions.lit(float('inf')))\
                  .withColumn("centroid", functions.lit(0))\
                  .persist(StorageLevel.MEMORY_ONLY)

  try:
    # Centroids pre-processing: label each initial centroid 0, 1, 2, ... and
    # combine its columns into an array.
    centroids = spark.createDataFrame(init)
    centroids = centroids.withColumn('centroid', functions.row_number()\
                          .over(Window.orderBy(functions.monotonically_increasing_id())) - 1)\
                          .withColumn('centroid_val', functions.array(centroids.columns))\
                          .select('centroid', 'centroid_val')

    # The centroids are tiny (k x d), so they are kept on the driver as a
    # list indexed by label. One collect here replaces the per-centroid
    # collect the distance loop would otherwise need in every iteration.
    centroid_vals = [_as_floats(row['centroid_val'])
                     for row in sorted(centroids.collect(), key=lambda row: row['centroid'])]

    # We can assume that all points are of dimension 'd'.
    d = len(centroid_vals[0])

    # Columns of the grouped mean that are bookkeeping, not coordinates.
    helper_cols = ('centroid', 'avg(centroid)', 'avg(row_number)', 'avg(distance)')
    nearest_first = Window.partitionBy('centroid').orderBy('distance')

    for _ in range(max_iter):
      # Restart from the cached, un-assigned points.
      new_data = original_data

      # For every centroid, keep it as the point's label unless the point is
      # strictly closer to the label it already has. Labels are visited in
      # increasing order, so on a tie the higher label wins.
      for i in range(k):
        cent = centroid_vals[i]
        closer_to_previous = functions.col('distance_new') > functions.col('distance')
        new_data = (
            new_data
            .withColumn("centroid_val", functions.array([functions.lit(x) for x in cent]))
            .withColumn("difference", functions.expr("zip_with(point, centroid_val, (x, y) -> (x - y)*(x-y))"))
            .select('*', functions.sqrt(sum([functions.col('difference').getItem(j) for j in range(d)])).alias('distance_new'))
            .withColumn('cent_new', functions.when(closer_to_previous, functions.col('centroid')).otherwise(i))
            .withColumn('distance_new', functions.when(closer_to_previous, functions.col('distance')).otherwise(functions.col('distance_new')))
            .drop('centroid', 'centroid_val', 'difference', 'distance')
            .withColumnRenamed('cent_new', 'centroid')
            .withColumnRenamed('distance_new', 'distance')
        )

      # The mean of a group minimises its squared error, so the new centroid
      # of a label is the mean of its points - after leaving out every q-th
      # point in order of distance to the centroid.
      means = (
          new_data
          .withColumn("row_number", functions.row_number().over(nearest_first))
          .filter(functions.col('row_number') % q != 0)
          .groupBy('centroid').mean()
      )
      coord_cols = [c for c in means.columns if c not in helper_cols]
      mean_rows = means.select('centroid', functions.array(coord_cols).alias('centroid_val_new')).collect()

      # A label with no remaining points has no row in `mean_rows`; it keeps
      # its previous centroid instead of disappearing (which would break the
      # next iteration's lookup).
      updated = list(centroid_vals)
      for row in mean_rows:
        updated[row['centroid']] = _as_floats(row['centroid_val_new'])

      # Converged when no coordinate of any centroid moved, i.e. the total
      # absolute difference between old and new centroids is exactly zero.
      if updated == centroid_vals:
        break
      centroid_vals = updated
  finally:
    # Release the cached points; the result below does not depend on them.
    original_data.unpersist()

  result_schema = StructType([StructField('centroid_val', ArrayType(DoubleType()), True)])
  return spark.createDataFrame([(vals,) for vals in centroid_vals], result_schema)