Skip to content

Commit dacff1c

Browse files
committed
Fixed kmeans clustering bug AtsushiSakai#251
1 parent 0b07425 commit dacff1c

File tree

1 file changed

+71
-103
lines changed

1 file changed

+71
-103
lines changed

Mapping/kmeans_clustering/kmeans_clustering.py

Lines changed: 71 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -6,146 +6,116 @@
66
77
"""
88

9-
import numpy as np
109
import math
1110
import matplotlib.pyplot as plt
1211
import random
1312

13+
# k means parameters
14+
MAX_LOOP = 10
15+
DCOST_TH = 0.1
1416
show_animation = True
1517

1618

17-
class Clusters:
18-
19-
def __init__(self, x, y, nlabel):
20-
self.x = x
21-
self.y = y
22-
self.ndata = len(self.x)
23-
self.nlabel = nlabel
24-
self.labels = [random.randint(0, nlabel - 1)
25-
for _ in range(self.ndata)]
26-
self.cx = [0.0 for _ in range(nlabel)]
27-
self.cy = [0.0 for _ in range(nlabel)]
28-
29-
3019
def kmeans_clustering(rx, ry, nc):
    """Cluster 2-D points into ``nc`` groups with k-means.

    Args:
        rx: x coordinates of the points.
        ry: y coordinates of the points, indexed in step with ``rx``.
        nc: number of clusters (labels) to use.

    Returns:
        The ``Clusters`` object after the last iteration; its ``labels``,
        ``center_x`` and ``center_y`` hold the result.
    """
    clusters = Clusters(rx, ry, nc)
    clusters.calc_centroid()

    # Start from infinity so the first cost difference is also infinite and
    # can never satisfy the convergence test.
    pre_cost = float("inf")
    for loop in range(MAX_LOOP):
        print("loop:", loop)
        cost = clusters.update_clusters()
        clusters.calc_centroid()

        # Stop once the total assignment cost has settled.
        d_cost = abs(cost - pre_cost)
        if d_cost < DCOST_TH:
            break
        pre_cost = cost

    return clusters
4935

5036

51-
def calc_centroid(clusters):
52-
53-
for ic in range(clusters.nlabel):
54-
x, y = calc_labeled_points(ic, clusters)
55-
ndata = len(x)
56-
clusters.cx[ic] = sum(x) / ndata
57-
clusters.cy[ic] = sum(y) / ndata
58-
59-
return clusters
60-
61-
62-
def update_clusters(clusters):
63-
cost = 0.0
64-
65-
for ip in range(clusters.ndata):
66-
px = clusters.x[ip]
67-
py = clusters.y[ip]
68-
69-
dx = [icx - px for icx in clusters.cx]
70-
dy = [icy - py for icy in clusters.cy]
71-
72-
dlist = [math.sqrt(idx**2 + idy**2) for (idx, idy) in zip(dx, dy)]
73-
mind = min(dlist)
74-
min_id = dlist.index(mind)
75-
clusters.labels[ip] = min_id
76-
cost += mind
77-
78-
return clusters, cost
79-
80-
81-
def calc_labeled_points(ic, clusters):
82-
83-
inds = np.array([i for i in range(clusters.ndata)
84-
if clusters.labels[i] == ic])
85-
tx = np.array(clusters.x)
86-
ty = np.array(clusters.y)
87-
88-
x = tx[inds]
89-
y = ty[inds]
90-
91-
return x, y
92-
93-
94-
def calc_raw_data(cx, cy, npoints, rand_d):
37+
class Clusters:
    """2-D points with one label per point and one centroid per label."""

    def __init__(self, x, y, n_label):
        """Store the points and give every point a random initial label.

        Args:
            x: x coordinates of the points.
            y: y coordinates of the points, indexed in step with ``x``.
            n_label: number of labels; initial labels are drawn from
                ``0 .. n_label - 1`` and every centroid starts at (0.0, 0.0).
        """
        self.x = x
        self.y = y
        self.n_data = len(self.x)
        self.n_label = n_label
        self.labels = [random.randint(0, n_label - 1)
                       for _ in range(self.n_data)]
        self.center_x = [0.0 for _ in range(n_label)]
        self.center_y = [0.0 for _ in range(n_label)]

    def plot_cluster(self):
        """Plot the points of each label in use with one ``plt.plot`` call."""
        for label in set(self.labels):
            x, y = self._get_labeled_x_y(label)
            plt.plot(x, y, ".")

    def calc_centroid(self):
        """Set each centroid to the mean of the points carrying its label.

        Only labels that currently have at least one point are visited, so
        no division by zero can occur; a label without points keeps its
        previous centroid.
        """
        for label in set(self.labels):
            x, y = self._get_labeled_x_y(label)
            n_data = len(x)
            self.center_x[label] = sum(x) / n_data
            self.center_y[label] = sum(y) / n_data

    def update_clusters(self):
        """Relabel every point with its nearest centroid.

        Returns:
            The sum over all points of the Euclidean distance to the
            centroid each point was assigned to.
        """
        cost = 0.0

        for ip in range(self.n_data):
            px = self.x[ip]
            py = self.y[ip]

            dist_list = [math.hypot(icx - px, icy - py)
                         for icx, icy in zip(self.center_x, self.center_y)]
            min_dist = min(dist_list)
            # index() returns the first match, so ties go to the lowest label.
            self.labels[ip] = dist_list.index(min_dist)
            cost += min_dist

        return cost

    def _get_labeled_x_y(self, target_label):
        """Return ``(x, y)`` lists of the points labelled ``target_label``."""
        x, y = [], []
        for px, py, label in zip(self.x, self.y, self.labels):
            if label == target_label:
                x.append(px)
                y.append(py)
        return x, y
83+
84+
85+
def calc_raw_data(cx, cy, n_points, rand_d):
    """Generate noisy sample points around each given centre.

    Args:
        cx: x coordinates of the centres.
        cy: y coordinates of the centres, paired with ``cx`` by position.
        n_points: number of samples to draw per centre.
        rand_d: width of the uniform noise; each coordinate is offset by
            ``rand_d * (random.random() - 0.5)``.

    Returns:
        ``(rx, ry)``: two lists of sample coordinates, grouped by centre in
        the order the centres were given.
    """
    rx, ry = [], []

    for (icx, icy) in zip(cx, cy):
        for _ in range(n_points):
            # x is drawn before y for every point; keeping that order keeps
            # seeded runs reproducible.
            rx.append(icx + rand_d * (random.random() - 0.5))
            ry.append(icy + rand_d * (random.random() - 0.5))

    return rx, ry
10494

10595

10696
def update_positions(cx, cy):
    """Advance the first two objects by one fixed step, in place.

    Args:
        cx: x coordinates of the objects; needs at least two entries.
        cy: y coordinates of the objects; needs at least two entries.

    Returns:
        The same ``cx`` and ``cy`` lists that were passed in, after the
        update. Entries beyond the first two are left untouched.
    """
    # object moving parameters
    DX1 = 0.4
    DY1 = 0.5
    DX2 = -0.3
    DY2 = -0.5

    # Each object moves exactly once per call.
    cx[0] += DX1
    cy[0] += DY1
    cx[1] += DX2
    cy[1] += DY2

    return cx, cy
121109

122110

123-
def calc_association(cx, cy, clusters):
124-
125-
inds = []
126-
127-
for ic, _ in enumerate(cx):
128-
tcx = cx[ic]
129-
tcy = cy[ic]
130-
131-
dx = [icx - tcx for icx in clusters.cx]
132-
dy = [icy - tcy for icy in clusters.cy]
133-
134-
dlist = [math.sqrt(idx**2 + idy**2) for (idx, idy) in zip(dx, dy)]
135-
min_id = dlist.index(min(dlist))
136-
inds.append(min_id)
137-
138-
return inds
139-
140-
141111
def main():
142112
print(__file__ + " start!!")
143113

144114
cx = [0.0, 8.0]
145115
cy = [0.0, 8.0]
146-
npoints = 10
116+
n_points = 10
147117
rand_d = 3.0
148-
ncluster = 2
118+
n_cluster = 2
149119
sim_time = 15.0
150120
dt = 1.0
151121
time = 0.0
@@ -154,20 +124,18 @@ def main():
154124
print("Time:", time)
155125
time += dt
156126

157-
# simulate objects
127+
# objects moving simulation
158128
cx, cy = update_positions(cx, cy)
159-
rx, ry = calc_raw_data(cx, cy, npoints, rand_d)
129+
raw_x, raw_y = calc_raw_data(cx, cy, n_points, rand_d)
160130

161-
clusters = kmeans_clustering(rx, ry, ncluster)
131+
clusters = kmeans_clustering(raw_x, raw_y, n_cluster)
162132

163133
# for animation
164134
if show_animation: # pragma: no cover
165135
plt.cla()
166-
inds = calc_association(cx, cy, clusters)
167-
for ic in inds:
168-
x, y = calc_labeled_points(ic, clusters)
169-
plt.plot(x, y, "x")
170-
plt.plot(cx, cy, "o")
136+
clusters.plot_cluster()
137+
138+
plt.plot(cx, cy, "or")
171139
plt.xlim(-2.0, 10.0)
172140
plt.ylim(-2.0, 10.0)
173141
plt.pause(dt)

0 commit comments

Comments
 (0)