Skip to content

Commit 48e0678

Browse files
committed
added script
1 parent 4e6b08d commit 48e0678

File tree

1 file changed

+179
-0
lines changed

1 file changed

+179
-0
lines changed
Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
"""
2+
3+
Object clustering with k-means algorithm
4+
5+
author: Atsushi Sakai (@Atsushi_twi)
6+
7+
"""
8+
9+
import numpy as np
10+
import math
11+
import matplotlib.pyplot as plt
12+
import random
13+
14+
# When True, main() redraws a matplotlib plot after every simulation step.
show_animation = True
15+
16+
17+
class Clusters:
    """Hold 2-D sample points together with their labels and centroids.

    Attributes (all assigned in ``__init__``):
        x, y: point coordinates, stored exactly as passed in.
        ndata: number of points, taken from ``len(x)``.
        nlabel: number of clusters.
        labels: one label per point, drawn at random from
            ``0 .. nlabel - 1``.
        cx, cy: per-label centroid coordinates, initialised to ``0.0``.
    """

    def __init__(self, x, y, nlabel):
        self.x, self.y = x, y
        self.ndata = len(x)
        self.nlabel = nlabel

        # Start from a random assignment of every point to a cluster.
        highest_label = nlabel - 1
        self.labels = [
            random.randint(0, highest_label) for _ in range(self.ndata)
        ]

        # Placeholder centroids; nothing is computed from the labels here.
        self.cx = [0.0] * nlabel
        self.cy = [0.0] * nlabel
28+
29+
30+
def kmeans_clustering(rx, ry, nc, max_loop=10, dcost_th=0.1):
    """Cluster 2-D points into ``nc`` groups with the k-means algorithm.

    Args:
        rx: x coordinates of the points.
        ry: y coordinates of the points (indexed in step with ``rx``).
        nc: number of clusters.
        max_loop: maximum number of assign/re-centre iterations.
        dcost_th: iteration stops once the cost changes by less than this
            amount between two consecutive iterations.

    Returns:
        The ``Clusters`` object holding the final labels and centroids.
    """
    clusters = Clusters(rx, ry, nc)
    clusters = calc_centroid(clusters)

    # There is no previous cost before the first iteration. Infinity makes
    # the first convergence check always fail, instead of comparing against
    # an arbitrary number that a real cost could accidentally match.
    pcost = float("inf")
    for _ in range(max_loop):
        clusters, cost = update_clusters(clusters)
        clusters = calc_centroid(clusters)

        if abs(cost - pcost) < dcost_th:
            break
        pcost = cost

    return clusters
49+
50+
51+
def calc_centroid(clusters):
    """Recompute each cluster centroid as the mean of its labelled points.

    Args:
        clusters: object exposing ``x``, ``y``, ``labels``, ``nlabel`` and
            mutable ``cx`` / ``cy`` lists.

    Returns:
        The same ``clusters`` object, with ``cx`` / ``cy`` updated in place.
        A cluster that currently owns no points keeps its previous centroid
        rather than triggering a division by zero.
    """
    nlabel = clusters.nlabel
    sum_x = [0.0] * nlabel
    sum_y = [0.0] * nlabel
    counts = [0] * nlabel

    # Single pass over the points: accumulate per-label sums and counts.
    for label, px, py in zip(clusters.labels, clusters.x, clusters.y):
        if not 0 <= label < nlabel:
            # Labels outside the known range belong to no centroid.
            continue
        sum_x[label] += px
        sum_y[label] += py
        counts[label] += 1

    for ic, count in enumerate(counts):
        if count == 0:
            # Empty cluster: there is no mean to compute, leave it as is.
            continue
        clusters.cx[ic] = sum_x[ic] / count
        clusters.cy[ic] = sum_y[ic] / count

    return clusters
60+
61+
62+
def update_clusters(clusters):
    """Reassign every point to its nearest centroid.

    Args:
        clusters: object exposing ``x``, ``y``, ``ndata``, ``cx``, ``cy``
            and a mutable ``labels`` list.

    Returns:
        Tuple ``(clusters, cost)``: the same object with ``labels`` updated
        in place, and the sum over all points of the Euclidean distance to
        the centroid each point was assigned to.
    """
    cost = 0.0

    for ip in range(clusters.ndata):
        px = clusters.x[ip]
        py = clusters.y[ip]

        dlist = [math.hypot(icx - px, icy - py)
                 for icx, icy in zip(clusters.cx, clusters.cy)]
        mind = min(dlist)
        # index() returns the first minimum, so ties go to the lowest label.
        clusters.labels[ip] = dlist.index(mind)

        # The cost must accumulate the distance, not the centroid index;
        # otherwise the caller's convergence check compares label sums.
        cost += mind

    return clusters, cost
79+
80+
81+
def calc_labeled_points(ic, clusters):
    """Return the coordinates of all points currently labelled ``ic``.

    Args:
        ic: cluster label to select.
        clusters: object exposing ``x``, ``y``, ``labels`` and ``ndata``.

    Returns:
        Tuple ``(x, y)`` of NumPy arrays with the selected coordinates.
        Both arrays are empty when no point carries the label.
    """
    # Force an integer dtype: an empty list would otherwise become a float64
    # array, which NumPy refuses to use as an index.
    inds = np.array([i for i in range(clusters.ndata)
                     if clusters.labels[i] == ic], dtype=np.intp)
    tx = np.array(clusters.x)
    ty = np.array(clusters.y)

    return tx[inds], ty[inds]
92+
93+
94+
def calc_raw_data(cx, cy, npoints, rand_d):
    """Generate noisy samples around each centre.

    Args:
        cx: x coordinates of the centres.
        cy: y coordinates of the centres (paired with ``cx``).
        npoints: number of samples drawn per centre.
        rand_d: full width of the uniform noise added on each axis.

    Returns:
        Tuple ``(rx, ry)`` of lists with the sample x and y coordinates,
        grouped centre by centre.
    """
    def scatter(center):
        # Uniform offset in [-rand_d / 2, rand_d / 2).
        return center + rand_d * (random.random() - 0.5)

    # Tuple elements evaluate left to right, so x is drawn before y.
    samples = [
        (scatter(center_x), scatter(center_y))
        for center_x, center_y in zip(cx, cy)
        for _ in range(npoints)
    ]

    rx = [sample[0] for sample in samples]
    ry = [sample[1] for sample in samples]
    return rx, ry
104+
105+
106+
def update_positions(cx, cy):
    """Advance the first two object positions by one fixed step.

    Args:
        cx: mutable sequence of x positions (at least two entries).
        cy: mutable sequence of y positions (at least two entries).

    Returns:
        The same ``cx`` and ``cy`` objects, modified in place.
    """
    # Per-step (dx, dy) displacement of object 0 and object 1.
    steps = ((0.4, 0.5), (-0.3, -0.5))

    for index, (step_x, step_y) in enumerate(steps):
        cx[index] += step_x
        cy[index] += step_y

    return cx, cy
121+
122+
123+
def calc_association(cx, cy, clusters):
    """Match each object position to its nearest cluster centroid.

    Args:
        cx: x coordinates of the object positions.
        cy: y coordinates of the object positions.
        clusters: object exposing centroid lists ``cx`` and ``cy``.

    Returns:
        List with one entry per object: the index of the closest centroid
        (the lowest index wins on ties).
    """
    def nearest_centroid(tcx, tcy):
        dists = [
            math.sqrt((kx - tcx) ** 2 + (ky - tcy) ** 2)
            for kx, ky in zip(clusters.cx, clusters.cy)
        ]
        return dists.index(min(dists))

    return [nearest_centroid(tcx, cy[ic]) for ic, tcx in enumerate(cx)]
139+
140+
141+
def main():
    """Run the demo: move two objects, sample noisy points, cluster them.

    Each step prints the current time, advances the object positions,
    generates samples around them, runs k-means on the samples and, when
    ``show_animation`` is set, redraws the result with matplotlib.
    """
    print(__file__ + " start!!")

    # Simulation settings.
    sim_time = 15.0
    dt = 1.0
    ncluster = 2
    npoints = 10  # samples drawn around each object per step
    rand_d = 3.0  # noise width handed to calc_raw_data

    # True object positions: object i sits at (cx[i], cy[i]).
    cx = [0.0, 8.0]
    cy = [0.0, 8.0]

    elapsed = 0.0
    while elapsed <= sim_time:
        print("Time:", elapsed)
        elapsed += dt

        # Simulate the objects and their measurements.
        cx, cy = update_positions(cx, cy)
        rx, ry = calc_raw_data(cx, cy, npoints, rand_d)

        clusters = kmeans_clustering(rx, ry, ncluster)

        if not show_animation:
            continue

        # Redraw: clustered samples as crosses, true positions as circles.
        plt.cla()
        for ic in calc_association(cx, cy, clusters):
            x, y = calc_labeled_points(ic, clusters)
            plt.plot(x, y, "x")
        plt.plot(cx, cy, "o")
        plt.xlim(-2.0, 10.0)
        plt.ylim(-2.0, 10.0)
        plt.pause(dt)

    print("Done")
176+
177+
178+
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)