66
77"""
88
9- import numpy as np
109import math
1110import matplotlib .pyplot as plt
1211import random
1312
13+ # k means parameters
14+ MAX_LOOP = 10
15+ DCOST_TH = 0.1
1416show_animation = True
1517
1618
17- class Clusters :
18-
19- def __init__ (self , x , y , nlabel ):
20- self .x = x
21- self .y = y
22- self .ndata = len (self .x )
23- self .nlabel = nlabel
24- self .labels = [random .randint (0 , nlabel - 1 )
25- for _ in range (self .ndata )]
26- self .cx = [0.0 for _ in range (nlabel )]
27- self .cy = [0.0 for _ in range (nlabel )]
28-
29-
3019def kmeans_clustering (rx , ry , nc ):
31-
3220 clusters = Clusters (rx , ry , nc )
33- clusters = calc_centroid (clusters )
21+ clusters . calc_centroid ()
3422
35- MAX_LOOP = 10
36- DCOST_TH = 0.1
37- pcost = 100.0
23+ pre_cost = float ("inf" )
3824 for loop in range (MAX_LOOP ):
39- # print("Loop :", loop)
40- clusters , cost = update_clusters (clusters )
41- clusters = calc_centroid (clusters )
25+ print ("loop :" , loop )
26+ cost = clusters . update_clusters ()
27+ clusters . calc_centroid ()
4228
43- dcost = abs (cost - pcost )
44- if dcost < DCOST_TH :
29+ d_cost = abs (cost - pre_cost )
30+ if d_cost < DCOST_TH :
4531 break
46- pcost = cost
32+ pre_cost = cost
4733
4834 return clusters
4935
5036
51- def calc_centroid (clusters ):
52-
53- for ic in range (clusters .nlabel ):
54- x , y = calc_labeled_points (ic , clusters )
55- ndata = len (x )
56- clusters .cx [ic ] = sum (x ) / ndata
57- clusters .cy [ic ] = sum (y ) / ndata
58-
59- return clusters
60-
61-
62- def update_clusters (clusters ):
63- cost = 0.0
64-
65- for ip in range (clusters .ndata ):
66- px = clusters .x [ip ]
67- py = clusters .y [ip ]
68-
69- dx = [icx - px for icx in clusters .cx ]
70- dy = [icy - py for icy in clusters .cy ]
71-
72- dlist = [math .sqrt (idx ** 2 + idy ** 2 ) for (idx , idy ) in zip (dx , dy )]
73- mind = min (dlist )
74- min_id = dlist .index (mind )
75- clusters .labels [ip ] = min_id
76- cost += mind
77-
78- return clusters , cost
79-
80-
81- def calc_labeled_points (ic , clusters ):
82-
83- inds = np .array ([i for i in range (clusters .ndata )
84- if clusters .labels [i ] == ic ])
85- tx = np .array (clusters .x )
86- ty = np .array (clusters .y )
87-
88- x = tx [inds ]
89- y = ty [inds ]
90-
91- return x , y
92-
93-
94- def calc_raw_data (cx , cy , npoints , rand_d ):
37+ class Clusters :
9538
39+ def __init__ (self , x , y , n_label ):
40+ self .x = x
41+ self .y = y
42+ self .n_data = len (self .x )
43+ self .n_label = n_label
44+ self .labels = [random .randint (0 , n_label - 1 )
45+ for _ in range (self .n_data )]
46+ self .center_x = [0.0 for _ in range (n_label )]
47+ self .center_y = [0.0 for _ in range (n_label )]
48+
49+ def plot_cluster (self ):
50+ for label in set (self .labels ):
51+ x , y = self ._get_labeled_x_y (label )
52+ plt .plot (x , y , "." )
53+
54+ def calc_centroid (self ):
55+ for label in set (self .labels ):
56+ x , y = self ._get_labeled_x_y (label )
57+ n_data = len (x )
58+ self .center_x [label ] = sum (x ) / n_data
59+ self .center_y [label ] = sum (y ) / n_data
60+
61+ def update_clusters (self ):
62+ cost = 0.0
63+
64+ for ip in range (self .n_data ):
65+ px = self .x [ip ]
66+ py = self .y [ip ]
67+
68+ dx = [icx - px for icx in self .center_x ]
69+ dy = [icy - py for icy in self .center_y ]
70+
71+ dist_list = [math .sqrt (idx ** 2 + idy ** 2 ) for (idx , idy ) in zip (dx , dy )]
72+ min_dist = min (dist_list )
73+ min_id = dist_list .index (min_dist )
74+ self .labels [ip ] = min_id
75+ cost += min_dist
76+
77+ return cost
78+
79+ def _get_labeled_x_y (self , target_label ):
80+ x = [self .x [i ] for i , label in enumerate (self .labels ) if label == target_label ]
81+ y = [self .y [i ] for i , label in enumerate (self .labels ) if label == target_label ]
82+ return x , y
83+
84+
85+ def calc_raw_data (cx , cy , n_points , rand_d ):
9686 rx , ry = [], []
9787
9888 for (icx , icy ) in zip (cx , cy ):
99- for _ in range (npoints ):
89+ for _ in range (n_points ):
10090 rx .append (icx + rand_d * (random .random () - 0.5 ))
10191 ry .append (icy + rand_d * (random .random () - 0.5 ))
10292
10393 return rx , ry
10494
10595
10696def update_positions (cx , cy ):
107-
97+ # object moving parameters
10898 DX1 = 0.4
10999 DY1 = 0.5
110-
111- cx [0 ] += DX1
112- cy [0 ] += DY1
113-
114100 DX2 = - 0.3
115101 DY2 = - 0.5
116102
103+ cx [0 ] += DX1
104+ cy [0 ] += DY1
117105 cx [1 ] += DX2
118106 cy [1 ] += DY2
119107
120108 return cx , cy
121109
122110
123- def calc_association (cx , cy , clusters ):
124-
125- inds = []
126-
127- for ic , _ in enumerate (cx ):
128- tcx = cx [ic ]
129- tcy = cy [ic ]
130-
131- dx = [icx - tcx for icx in clusters .cx ]
132- dy = [icy - tcy for icy in clusters .cy ]
133-
134- dlist = [math .sqrt (idx ** 2 + idy ** 2 ) for (idx , idy ) in zip (dx , dy )]
135- min_id = dlist .index (min (dlist ))
136- inds .append (min_id )
137-
138- return inds
139-
140-
141111def main ():
142112 print (__file__ + " start!!" )
143113
144114 cx = [0.0 , 8.0 ]
145115 cy = [0.0 , 8.0 ]
146- npoints = 10
116+ n_points = 10
147117 rand_d = 3.0
148- ncluster = 2
118+ n_cluster = 2
149119 sim_time = 15.0
150120 dt = 1.0
151121 time = 0.0
@@ -154,20 +124,18 @@ def main():
154124 print ("Time:" , time )
155125 time += dt
156126
157- # simulate objects
127+ # objects moving simulation
158128 cx , cy = update_positions (cx , cy )
159- rx , ry = calc_raw_data (cx , cy , npoints , rand_d )
129+ raw_x , raw_y = calc_raw_data (cx , cy , n_points , rand_d )
160130
161- clusters = kmeans_clustering (rx , ry , ncluster )
131+ clusters = kmeans_clustering (raw_x , raw_y , n_cluster )
162132
163133 # for animation
164134 if show_animation : # pragma: no cover
165135 plt .cla ()
166- inds = calc_association (cx , cy , clusters )
167- for ic in inds :
168- x , y = calc_labeled_points (ic , clusters )
169- plt .plot (x , y , "x" )
170- plt .plot (cx , cy , "o" )
136+ clusters .plot_cluster ()
137+
138+ plt .plot (cx , cy , "or" )
171139 plt .xlim (- 2.0 , 10.0 )
172140 plt .ylim (- 2.0 , 10.0 )
173141 plt .pause (dt )
0 commit comments