In [None]:
#from statistics import mean, stdev
from typing import Iterable,Tuple,Sequence, Dict,List
from collections import defaultdict
from random import sample
from pprint import pprint
from math import fsum,hypot,sqrt
from functools import partial

Point = Tuple[int,...]
Centroid = Point

def transpose(data):
    """Swap the rows and columns of a 2-D array.

    Returns a list with one tuple per column of ``data``.  Because the
    columns come from ``zip``, rows of unequal length are truncated to
    the shortest row.
    """
    columns = zip(*data)
    return [*columns]

def mean(data: Iterable[float]) -> float:
    """Return the arithmetic mean of ``data``, summed accurately with ``fsum``.

    The input is materialized first so that one-shot iterators can be
    both summed and counted.  Empty input raises ``ZeroDivisionError``.
    """
    values = [*data]
    total = fsum(values)
    return total / len(values)

def dist(p: Point, q: Point, fsum=fsum, sqrt=sqrt, zip=zip) -> float:
    """Return the Euclidean distance between points ``p`` and ``q``.

    Works for any number of dimensions; coordinates are paired with
    ``zip``, so extra coordinates on the longer point are ignored.
    ``fsum``, ``sqrt`` and ``zip`` are bound as default arguments so the
    body refers to them as local names.
    """
    squared_diffs = [(a - b) ** 2 for a, b in zip(p, q)]
    return sqrt(fsum(squared_diffs))

def assign_data(centroids: Sequence[Centroid], data: Iterable[Point]) -> Dict[Centroid, List[Point]]:
    """Group every point in ``data`` under its closest centroid.

    Returns a plain dict mapping a centroid to the list of points that
    are nearest to it (by ``dist``).  Ties go to the centroid that comes
    first in ``centroids``.  A centroid that attracts no points does not
    appear as a key in the result.
    """
    clusters: Dict[Centroid, List[Point]] = {}
    for point in data:
        nearest = min(centroids, key=lambda centroid: dist(point, centroid))
        clusters.setdefault(nearest, []).append(point)
    return clusters

def compute_centroids(groups: Iterable[Sequence[Point]]) -> List[Centroid]:
    """Return one centroid per group, in the order the groups are given.

    Each centroid is a tuple holding the mean of every coordinate axis
    across the points of that group.
    """
    centroids = []
    for members in groups:
        per_axis = transpose(members)  # one tuple of values per dimension
        centroids.append(tuple(mean(axis) for axis in per_axis))
    return centroids

def k_means(data: Iterable[Point], k: int = 2, iteration: int = 50) -> List[Centroid]:
    """Cluster ``data`` into ``k`` groups and return the ``k`` centroids.

    Args:
        data: The points to cluster; materialized into a list up front.
        k: Number of centroids.  The seeds are ``k`` entries drawn from
            ``data`` with ``random.sample``.
        iteration: Maximum number of assign/recompute rounds.

    Returns:
        A list of exactly ``k`` centroids, in the order of their seeds.

    Raises:
        ValueError: From ``random.sample`` when ``k`` is larger than the
            number of points (or negative).
    """
    data = list(data)
    centroids = sample(data, k)
    for _ in range(iteration):
        labeled = assign_data(centroids, data)
        updated = []
        for centroid in centroids:
            # A centroid that attracted no points is absent from `labeled`.
            # Keep it where it is instead of dropping it, otherwise the
            # result would silently shrink below k.  `pop` also makes sure
            # that, among seeds with equal coordinates, only the first one
            # claims the shared group.
            members = labeled.pop(centroid, None)
            if members:
                updated.append(compute_centroids([members])[0])
            else:
                updated.append(centroid)
        converged = updated == centroids
        centroids = updated
        if converged:
            # Another round would produce the same assignment again.
            break
    return centroids


if __name__ == '__main__':
    # Demo: cluster six 3-D sample points into three groups and show
    # which points ended up with which centroid.
    points = [
        (10, 41, 23), (22, 30, 29), (11, 42, 5),
        (20, 32, 4), (12, 40, 12), (21, 36, 23),
    ]

    centroids = k_means(points, k=3)
    d = assign_data(centroids, points)
    pprint(d)
      

In [None]:
# Print each sample point next to its distance from the probe (9, 39, 20).
points = [
    (10, 41, 23),
    (22, 30, 29),
    (11, 42, 5),
    (20, 32, 4),
    (12, 40, 12),
    (21, 36, 23),
]
for point in points:
    print(point, dist(point, (9, 39, 20)))

In [None]:
# Disassemble dist to inspect the bytecode generated for it.
from dis import dis
dis(dist)

In [None]:
# Two candidate centroids and one probe point: show every distance, the
# smallest distance, and finally the centroid that is closest to the point.
centroids = [(9, 39, 20), (12, 36, 25)]
point = (11, 42, 5)
print(list(map(partial(dist, point), centroids)))
print(min(map(partial(dist, point), centroids)))
min(centroids, key=lambda centroid: dist(point, centroid))

In [None]:
# Show which points fall under each of the two centroids (narrow output).
pprint(assign_data(centroids=centroids, data=points), width=45)

In [None]:
# Hand-written clusters used to try out the centroid computation:
# `groups` holds both clusters, `group` is a separate copy of the first.
groups = [
    [(10, 41, 23), (11, 42, 5), (20, 32, 4), (12, 40, 12)],
    [(22, 30, 29), (21, 36, 23)],
]
group = [
    (10, 41, 23),
    (11, 42, 5),
    (20, 32, 4),
    (12, 40, 12),
]

In [None]:
# Build the centroid computation up step by step: columns of one group,
# their means as a list, the same as a tuple, then one tuple per group.
[*zip(*group)]
[mean(axis) for axis in zip(*group)]
tuple(mean(axis) for axis in zip(*group))
[tuple(mean(axis) for axis in zip(*members)) for members in groups]

In [None]:
# Mean of every coordinate axis of `group`, as a list.
[mean(axis) for axis in zip(*group)]

In [None]:
# The same per-axis means, packed into a tuple (the centroid of `group`).
tuple(mean(axis) for axis in zip(*group))

In [None]:
# One centroid tuple for every group in `groups`.
[tuple(mean(axis) for axis in zip(*members)) for members in groups]

In [None]:
# Run the full algorithm on the sample points with two clusters.
k_means(data=points, k=2)