In [None]:
%matplotlib inline

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

In [None]:
import random
import math

def distance(a, b):
    """Euclidean distance between the first three components of ``a`` and ``b``.

    Components are converted to floats before subtracting. Points here are
    rows of a ``uint8`` image (and the seed centroids are ``uint8`` scalars),
    and ``uint8`` subtraction wraps around (10 - 20 == 246) instead of going
    negative, which would make far-apart colours look close and vice versa.
    """
    return math.sqrt(sum((float(a[ch]) - float(b[ch])) ** 2 for ch in range(3)))

def meanPos(set, arr):
    """Floor-divided component-wise mean of the points ``arr[i]`` for ``i`` in ``set``.

    Parameters:
        set: indices of the cluster members (the name shadows the builtin but
            is kept for compatibility with keyword callers).
        arr: indexable collection of points with at least three components.

    Returns:
        A 3-tuple ``(r, g, b)`` of ``sum // len(set)`` per component.

    Raises:
        ZeroDivisionError: if ``set`` is empty.
    """
    count = len(set)
    totals = [0, 0, 0]
    for i in set:
        point = arr[i]
        for ch in range(3):
            value = point[ch]
            # Turn NumPy scalars into Python numbers: under NumPy 2 promotion
            # rules, 0 + np.uint8(x) stays uint8 and the running sum would overflow.
            totals[ch] += value.item() if hasattr(value, "item") else value
    return (totals[0] // count, totals[1] // count, totals[2] // count)

def makePair(a):
    """Return the first three components of ``a`` as a tuple, usable as a dict key."""
    first, second, third = a[0], a[1], a[2]
    return (first, second, third)


def randSeed(set, num):
    """Draw ``num`` distinct elements of ``set`` at random, without replacement."""
    return random.sample(set, num)

class KMeans:
    """Pure-Python k-means clustering of points with (at least) three components.

    Every strategy is injectable:

    * ``seedFunc(indices, k)`` -- returns ``k`` indices whose points become the
      initial centroids (default ``randSeed``: a uniform random sample).
    * ``distanceFunc(point, centroid)`` -- distance used to pick the nearest centroid.
    * ``centralPosition(indexSet, arr)`` -- recomputes a cluster centre from the
      indices of its members.
    * ``makeHashable(point)`` -- turns a point into a hashable centroid, because
      centroids are used as dictionary keys.

    Iteration stops once fewer than ``changeThreshold`` points change label in a
    pass, or when ``maxPass`` passes have been made.
    """

    def __init__(
            self,
            k=4,
            seedFunc=randSeed,
            distanceFunc=distance,
            centralPosition=meanPos,
            maxPass=100,
            changeThreshold=1,
            makeHashable=makePair
            ):
        self.k = k
        self.distanceFunc = distanceFunc
        self.centralPosition = centralPosition
        self.maxPass = maxPass
        self.changeThreshold = changeThreshold
        self.seedFunc = seedFunc
        self.makeHashable = makeHashable

    @staticmethod
    def _distinct(centroids):
        """Drop duplicate centroids, keeping the first occurrence of each.

        Centroids are dict keys: two equal centroids would share one member set,
        and every point in it would then be reported once per duplicate.
        """
        return list(dict.fromkeys(centroids))

    def _assign(self, arr, centroids):
        """Step 2: map each centroid to the set of indices whose nearest centroid it is."""
        clusters = {c: set() for c in centroids}
        for i in range(len(arr)):
            nearest = min(centroids, key=lambda c: self.distanceFunc(arr[i], c))
            clusters[nearest].add(i)
        return clusters

    def cluster(self, arr):
        """Cluster the points of ``arr`` into at most ``k`` groups.

        Fewer than ``k`` groups can result when seed points coincide or when a
        cluster loses all its members (it is then dropped).

        Returns:
            ``(centroids, result)``. ``centroids`` is the list of distinct
            centroids the final assignment was made against. ``result`` has one
            ``(index, centre)`` pair per point of ``arr``, sorted by index, where
            ``centre`` is the recomputed centre of that point's final cluster.

        Raises:
            ValueError: from ``random.sample`` when ``k`` exceeds ``len(arr)``
                (with the default ``seedFunc``).
        """
        passNum = 1
        print("pass: ", 1)

        # step 1: randomly choose k seed centroids (identical pixels collapse)
        seeds = self.seedFunc(range(0, len(arr)), self.k)
        centroids = self._distinct([self.makeHashable(arr[i]) for i in seeds])

        # step 2: assign every point into the set of its closest centroid
        clusters = self._assign(arr, centroids)

        changed = True
        while changed and passNum < self.maxPass:
            passNum += 1
            print("pass: ", passNum)

            # step 3: recompute centroids, remembering for each point the new
            # centre of the cluster it was in, so label changes can be counted
            previous = {}
            newCentroids = []
            for c in centroids:
                members = clusters[c]
                if not members:
                    # nothing to average (meanPos would divide by zero): drop it
                    continue
                newC = self.centralPosition(members, arr)
                newCentroids.append(newC)
                for index in members:
                    previous[index] = newC

            # step 2 again, against the new (distinct) centroids
            centroids = self._distinct(newCentroids)
            clusters = self._assign(arr, centroids)

            changeNum = sum(
                1
                for c, members in clusters.items()
                for index in members
                if c != previous[index]
            )
            print("{} point(s) got label changed.".format(changeNum))

            changed = changeNum >= self.changeThreshold

        if not changed:
            print("Iteration stops because change num did not exceed the threshold({}) in the last pass.".format(self.changeThreshold))
        elif passNum >= self.maxPass:
            print("Iteration stops because pass num exceeds max pass limit({}).".format(self.maxPass))

        # part of step 3: label every point with the centre of its final cluster
        result = []
        for c in centroids:
            members = clusters[c]
            if not members:
                continue
            finalCentroid = self.centralPosition(members, arr)
            result += [(index, finalCentroid) for index in members]

        # get flattened results in the original pixel order
        result.sort(key=lambda p: p[0])

        return centroids, result


# Seed the module-level RNG so the random seed-centroid choice (random.sample
# in randSeed) is reproducible across Restart & Run All.
SEED = 0
N_CLUSTERS = 10   # at most this many colours remain in the quantized image
MAX_PASS = 10000  # hard cap on k-means passes

random.seed(SEED)
kmeans = KMeans(k=N_CLUSTERS, maxPass=MAX_PASS)

In [None]:
# Load the input picture as an RGB uint8 array. convert('RGB') normalizes
# grayscale / CMYK / palette JPEGs to three channels, which the flattening
# below and distance/meanPos (components 0..2) rely on.
with Image.open('./输入.jpg') as source:
    img = np.asarray(source.convert('RGB'))
imgplot = plt.imshow(img)

In [None]:
print(img.shape)
# Flatten the picture into one row per pixel, keeping the last (channel) axis.
# Assumes img is (height, width, channels).
channels = img.shape[-1]
arr = img.reshape((-1, channels))

In [None]:
# Cluster all pixels; res is (centroids, [(pixel index, centre), ...]) sorted by index.
res = kmeans.cluster(arr=arr)

In [None]:
# Colour assigned to each pixel, in the original row-major pixel order.
pp = [e[1] for e in res[1]]
height, width = img.shape[0], img.shape[1]
# Each pixel must be labelled exactly once, otherwise the row slices below shift.
assert len(pp) == height * width, (len(pp), height * width)
# retrieve the original dimension (2D rgb picture): one list of colours per image row
ppp = [pp[row * width:(row + 1) * width] for row in range(height)]

In [None]:
# Stack the per-row colour lists back into an image array and display it.
a = np.asarray(ppp)
print(a.shape)
imgplot = plt.imshow(a)