In [3]:
from pyspark import SparkContext
import numpy as np
import math
import matplotlib.pyplot as plt


def read_centroids(centroids_path):
    """Load centroids from a tab-separated text file.

    Every non-blank line is parsed as one centroid: the line is split on
    tab characters and each field is converted with ``float``.

    Args:
        centroids_path: Path of the text file to read.

    Returns:
        A list of 1-D float numpy arrays, one per non-blank line, in file
        order.

    Raises:
        ValueError: If a field of a non-blank line cannot be parsed as a
            float.
    """
    centroids = []
    with open(centroids_path, "r") as f:
        # Iterate the file lazily instead of materialising every line.
        for line in f:
            # A blank line (e.g. an empty line at the end of the file) would
            # make float() raise, so it is skipped rather than parsed.
            if not line.strip():
                continue
            centroids.append(np.array([float(x) for x in line.split("\t")]))

    return centroids

def l2_distance(p1, p2):
    """Return the Euclidean (L2) distance between ``p1`` and ``p2``.

    The element-wise difference is squared, summed with ``np.sum`` and the
    square root of that sum is returned as a Python float.
    """
    delta = p1 - p2
    squared_sum = np.sum(delta ** 2)
    return math.sqrt(squared_sum)

def data_clean(point):
    """Parse one tab-separated text record into a float numpy array.

    Args:
        point: A string whose fields are separated by tab characters.

    Returns:
        A 1-D numpy array of dtype float holding the parsed fields.
    """
    fields = point.split("\t")
    return np.array(fields, dtype=float)

def kmeans_map(point):
    """Assign ``point`` to its nearest centroid.

    Reads the module-level ``centroids`` list and measures distance with
    ``l2_distance``. On ties the centroid with the lowest index wins because
    only a strictly smaller distance replaces the current best.

    Returns:
        A tuple ``(index, (point, 1), distance)`` where ``index`` is the
        position of the nearest centroid (``None`` if no centroid had a
        distance below +inf) and ``distance`` is the distance to it.
    """
    best_index = None
    best_distance = float("+inf")
    for index, centroid in enumerate(centroids):
        candidate = l2_distance(point, centroid)
        if candidate < best_distance:
            best_index, best_distance = index, candidate

    return (best_index, (point, 1), best_distance)

def cumulate(pair):
    """Sum the points and counts collected for one cluster.

    Args:
        pair: A sequence whose element 0 is the cluster key and whose
            element 1 is an iterable of ``(point, count)`` tuples.

    Returns:
        A tuple ``(key, summed_point, summed_count)``. ``summed_point`` is a
        new float array with the same shape as the points; the input points
        are not modified.
    """
    cumulate_point = None
    cumulate_count = 0
    for point, count in pair[1]:
        values = np.asarray(point, dtype=float)
        if cumulate_point is None:
            # Take the dimensionality from the data instead of assuming 20
            # features; copy so the caller's array is never mutated.
            cumulate_point = values.copy()
        else:
            cumulate_point = cumulate_point + values
        cumulate_count += count

    if cumulate_point is None:
        # Empty group: keep the historical result (a 20-element zero vector).
        cumulate_point = np.zeros(20)

    return (pair[0], cumulate_point, cumulate_count)

def plotting(output_path, costs):
    """Plot the cost recorded at each iteration and save the figure.

    Args:
        output_path: Destination passed to ``plt.savefig``.
        costs: Sequence of cost values; entry ``i`` is drawn at x = i + 1.
    """
    plt.figure()

    # One x position per recorded cost, numbered from 1.
    iterations = np.arange(1, len(costs) + 1)
    plt.plot(iterations, costs, "r", linewidth=2)

    plt.xticks(iterations)
    plt.title(r"Error $\Phi$ with iterations.")
    plt.xlabel("Iteration")
    plt.ylabel(r"Error $\Phi$")

    plt.savefig(output_path, dpi=200)

if __name__ == '__main__':
    # Driver: run a fixed number of k-means iterations on Spark over the
    # points in ``data_path``, starting from the centroids stored in
    # ``centroids_path``, printing the centroids after every iteration.
    data_path = "data.txt"
    centroids_path = "centroid.txt"
    # output_path = "/Users/yimingzhao/Desktop/Postgraduate/EachSemester/2020_SPRING/CS6665/CS6665_Assignments/Assignment_3/result2a.png"

    # Module-level list read by kmeans_map and updated in place below.
    centroids = read_centroids(centroids_path)

    max_iteration = 5

    # NOTE(review): ``sc`` is not created in this file (the line below is
    # commented out); presumably the notebook / pyspark shell provides it —
    # confirm before running this as a standalone script.
    # sc = SparkContext("local", "kmeans app")

    # Cache the *parsed* points: every iteration runs several actions over
    # this RDD, and caching only the raw text would re-parse each line on
    # every one of them.
    data = sc.textFile(data_path).map(data_clean).cache()

    costs = []
    for i in range(max_iteration):
        # (nearest centroid index, (point, 1), distance) for every point.
        kmeans_map_result = data.map(kmeans_map)

        # NOTE(review): this sums plain distances; an earlier variant summed
        # squared distances (x[2] ** 2) — confirm which cost is intended.
        # cost = kmeans_map_result.map(lambda x: x[2] ** 2).reduce(lambda a, b: a + b)
        cost = kmeans_map_result.map(lambda x: x[2]).reduce(lambda a, b: a + b)

        costs.append(cost)

        # Group the (point, 1) tuples per centroid index by list
        # concatenation, sum them with cumulate, then divide the summed point
        # by the summed count to get the new centroid.
        kmeans_reduce = (
            kmeans_map_result.map(lambda x: (x[0], [x[1]]))
            .reduceByKey(lambda x, y: x + y)
            .map(cumulate)
            .map(lambda x: (x[0], x[1] / x[2]))
        )

        # Only centroids that received at least one point are replaced.
        for index, centroid in kmeans_reduce.collect():
            centroids[index] = centroid
        print("k=",i)
        print(centroids)

    for centroid in centroids:
        print(centroid)
    # plotting(output_path, costs)
    



k= 0
[array([0.03371429, 0.36295238, 0.49419048, 0.01142857, 0.54533333,
       0.08314286, 0.30828571, 0.16104762, 0.15657143, 0.26209524,
       0.1167619 , 0.46161905, 0.12190476, 0.004     , 0.08866667,
       0.72980952, 0.17714286, 1.44790476, 1.86552381, 0.0632381 ]), array([0.30341176, 0.29064706, 0.46164706, 0.        , 0.22023529,
       0.16688235, 0.30811765, 0.14964706, 0.20158824, 0.72758824,
       0.19611765, 0.979     , 0.24235294, 0.25035294, 0.09035294,
       0.31041176, 0.38517647, 0.20711765, 3.90270588, 0.12717647]), array([0.12076923, 0.20512821, 0.73051282, 0.02435897, 0.64692308,
       0.46435897, 0.19846154, 0.24051282, 0.44487179, 0.38410256,
       0.12435897, 0.80897436, 0.11948718, 0.        , 1.81487179,
       0.10615385, 0.25794872, 0.95512821, 1.82897436, 0.10641026]), array([0.07036232, 0.10724638, 0.23507246, 0.01673913, 0.57492754,
       0.17934783, 0.39108696, 0.37282609, 0.13173913, 0.29434783,
       0.1442029 , 0.28753623, 0.13572464, 0.02760

k= 4
[array([0.074   , 0.382375, 0.546875, 0.00525 , 0.635375, 0.087125,
       0.445625, 0.133125, 0.124875, 0.25975 , 0.125   , 0.5415  ,
       0.168125, 0.      , 0.126   , 0.477125, 0.230875, 1.687125,
       2.290375, 0.064375]), array([2.33098592e-01, 9.56338028e-02, 3.80211268e-01, 2.39436620e-03,
       5.17253521e-01, 1.18521127e-01, 2.97535211e-01, 1.79788732e-01,
       6.17605634e-02, 2.88450704e-01, 1.35774648e-01, 8.99366197e-01,
       1.47112676e-01, 1.23239437e-02, 1.56338028e-02, 2.69929577e-01,
       4.25563380e-01, 1.81338028e-01, 4.78211268e+00, 1.24507042e-01]), array([0.12923077, 0.20615385, 0.69897436, 0.02435897, 0.59641026,
       0.48153846, 0.18871795, 0.24897436, 0.51307692, 0.44      ,
       0.12846154, 0.81897436, 0.13666667, 0.00846154, 1.89641026,
       0.09      , 0.27512821, 0.87025641, 1.99435897, 0.09666667]), array([0.11688372, 0.08576744, 0.37623256, 0.00539535, 0.69772093,
       0.29065116, 0.32223256, 0.11493023, 0.05604651, 0.13483721,
   