This repository has been archived by the owner on Nov 19, 2020. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 2k
/
ClusterCollection`3.cs
456 lines (414 loc) · 15.9 KB
/
ClusterCollection`3.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
// Accord Machine Learning Library
// The Accord.NET Framework
// http://accord-framework.net
//
// Copyright © César Souza, 2009-2016
// cesarsouza at gmail.com
//
// This library is free software; you can redistribute it and/or
// modify it under the terms of the GNU Lesser General Public
// License as published by the Free Software Foundation; either
// version 2.1 of the License, or (at your option) any later version.
//
// This library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
// Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public
// License along with this library; if not, write to the Free Software
// Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
//
namespace Accord.MachineLearning
{
using Accord.Math;
using Accord.Math.Distances;
using Accord.Statistics.Distributions.Univariate;
using System;
using System.Collections;
using System.Collections.Generic;
using System.Threading.Tasks;
/// <summary>
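/// Collection of clusters described by centroids, in which the score of a point for
/// a given cluster is the negated distance between the point and that cluster's centroid.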
/// </summary>
///
/// <seealso cref="KMeans"/>
/// <seealso cref="KModes"/>
/// <seealso cref="MeanShift"/>
/// <seealso cref="GaussianMixtureModel"/>
///
[Serializable]
public class ClusterCollection<TData, TCentroids, TCluster>
: MulticlassScoreClassifierBase<TData>, IClusterCollection<TData, TCluster>
{
/// <summary>
/// A single cluster inside a <see cref="ClusterCollection{TData, TCentroids, TCluster}"/>.
/// </summary>
///
/// <remarks>
/// A cluster only keeps a reference to the collection that owns it and its
/// own position in that collection. The centroid and proportion exposed here
/// are read from, and written to, the owner's arrays at that position.
/// </remarks>
///
/// <typeparam name="TCollection">The type of the owning collection.</typeparam>
///
[Serializable]
public class Cluster<TCollection>
    where TCollection : ClusterCollection<TData, TCentroids, TCluster>
{
    private TCollection owner;
    private int index;

    /// <summary>
    /// Initializes a new instance of the <see cref="Cluster{TCollection}"/> class.
    /// </summary>
    ///
    /// <param name="owner">The collection this cluster belongs to.</param>
    /// <param name="index">The position of this cluster inside the owner collection.</param>
    ///
    public Cluster(TCollection owner, int index)
    {
        this.index = index;
        this.owner = owner;
    }

    /// <summary>
    /// Gets the collection that owns this cluster.
    /// </summary>
    ///
    public TCollection Owner
    {
        get { return this.owner; }
    }

    /// <summary>
    /// Gets the index (label) of this cluster inside its owner collection.
    /// </summary>
    ///
    public int Index
    {
        get { return index; }
    }

    /// <summary>
    /// Gets or sets the centroid of this cluster. The value is stored
    /// in the owner's centroid array at this cluster's index.
    /// </summary>
    ///
    public TCentroids Centroid
    {
        get
        {
            TCentroids[] all = owner.Centroids;
            return all[index];
        }
        set
        {
            TCentroids[] all = owner.Centroids;
            all[index] = value;
        }
    }

    /// <summary>
    /// Gets or sets the proportion of samples in this cluster. The value is
    /// stored in the owner's proportions array at this cluster's index.
    /// </summary>
    ///
    public double Proportion
    {
        get
        {
            double[] all = owner.Proportions;
            return all[index];
        }
        set
        {
            double[] all = owner.Proportions;
            all[index] = value;
        }
    }

    /// <summary>
    /// Computes the distortion of this cluster for the given points by
    /// delegating to the owner's <c>Distortion</c> method with every
    /// point labelled as belonging to this cluster.
    /// </summary>
    ///
    /// <param name="data">The input points, all treated as members of this cluster.</param>
    ///
    /// <returns>The average distance between the given points
    /// and this cluster's centroid, as computed by the owner.</returns>
    ///
    public double Distortion(TData[] data)
    {
        // One label per point, all equal to this cluster's index.
        int[] labels = Vector.Create(data.Length, index);
        return owner.Distortion(data, labels);
    }
}
private IDistance<TData, TCentroids> distance;
private double[] proportions;
private TCentroids[] centroids;
private TCluster[] clusters;
/// <summary>
/// Initializes a new instance of the
/// <see cref="ClusterCollection{TData, TCentroids, TCluster}"/> class.
/// </summary>
///
/// <remarks>
/// The centroid, proportion and cluster arrays are allocated with
/// <paramref name="k"/> entries each; this constructor leaves their
/// elements at their default values.
/// </remarks>
///
/// <param name="k">The number of clusters K.</param>
/// <param name="distance">The distance function used to compare
/// data points against cluster centroids.</param>
///
public ClusterCollection(int k, IDistance<TData, TCentroids> distance)
{
    // One slot per cluster in each of the parallel arrays.
    this.clusters = new TCluster[k];
    this.centroids = new TCentroids[k];
    this.proportions = new double[k];

    this.Distance = distance;

    // The classifier exposes one output per cluster.
    this.NumberOfOutputs = k;
}
/// <summary>
/// Gets or sets the distance function applied between
/// a data point and a cluster centroid in this collection.
/// </summary>
///
public IDistance<TData, TCentroids> Distance
{
    get
    {
        return this.distance;
    }
    set
    {
        this.distance = value;
    }
}
/// <summary>
/// Gets or sets the clusters' centroids.
/// </summary>
///
/// <remarks>
/// The getter returns the internal array itself. The setter does not
/// replace that array: it copies the given elements into it, so the
/// array reference handed out by the getter stays valid.
/// </remarks>
///
/// <value>The clusters' centroids.</value>
///
/// <exception cref="ArgumentNullException">The assigned array is null.</exception>
/// <exception cref="ArgumentException">The assigned array does not contain
/// exactly <see cref="Count"/> elements.</exception>
///
public TCentroids[] Centroids
{
    get
    {
        return this.centroids;
    }
    set
    {
        if (value == null)
            throw new ArgumentNullException("value");

        if (value.Length != this.Count)
            throw new ArgumentException("The number of centroids should be equal to K.", "value");

        // Element-wise copy into the existing storage.
        Array.Copy(value, this.centroids, this.centroids.Length);
    }
}
/// <summary>
/// Gets the array holding the proportion of samples in each cluster,
/// indexed by cluster. The internal array itself is returned.
/// </summary>
///
public virtual double[] Proportions
{
    get
    {
        return this.proportions;
    }
}
/// <summary>
/// Gets the array of cluster definitions held by this collection.
/// The internal array itself is returned.
/// </summary>
///
public TCluster[] Clusters
{
    get
    {
        return this.clusters;
    }
}
/// <summary>
/// Returns the cluster selected for an input point. This is an
/// obsolete alias that forwards to <c>Decide</c>.
/// </summary>
///
/// <param name="point">The input point.</param>
///
/// <returns>
/// The cluster index that <c>Decide</c> returns for the given point.
/// </returns>
///
[Obsolete("Please use Decide() instead.")]
public virtual int Nearest(TData point)
{
    int decision = Decide(point);
    return decision;
}
/// <summary>
/// Returns the cluster selected for an input point together with the
/// scores computed for it. This is an obsolete alias that forwards
/// to <c>Scores</c>.
/// </summary>
///
/// <param name="point">The input point.</param>
/// <param name="responses">Receives the array returned by <c>Scores</c>
/// for the given point.</param>
///
/// <returns>
/// The cluster index reported by <c>Scores</c> for the given point.
/// </returns>
///
[Obsolete("Please use Scores() instead.")]
public virtual int Nearest(TData point, out double[] responses)
{
    int nearest;
    responses = Scores(point, out nearest);
    return nearest;
}
/// <summary>
/// Returns the cluster selected for an input point together with its
/// score. This is an obsolete alias that forwards to <c>Score</c>.
/// </summary>
///
/// <param name="point">The input point.</param>
/// <param name="response">Receives the value returned by <c>Score</c> for
/// the given point. NOTE(review): the per-cluster score defined in this
/// class is a negated distance, so this value is presumably not confined
/// to the [0, 1] range; confirm against the base class.</param>
///
/// <returns>
/// The cluster index reported by <c>Score</c> for the given point.
/// </returns>
///
[Obsolete("Please use Score() instead.")]
public virtual int Nearest(TData point, out double response)
{
    int nearest;
    response = Score(point, out nearest);
    return nearest;
}
/// <summary>
/// Returns the clusters selected for an array of input points. This
/// is an obsolete alias that forwards to <c>Decide</c>.
/// </summary>
///
/// <param name="points">The input points.</param>
///
/// <returns>
/// The array of cluster indices that <c>Decide</c> returns
/// for the given points.
/// </returns>
///
[Obsolete("Please use Decide() instead.")]
public virtual int[] Nearest(TData[] points)
{
    int[] decisions = Decide(points);
    return decisions;
}
/// <summary>
/// Returns the clusters selected for an array of input points together
/// with the scores computed for them. This is an obsolete alias that
/// forwards to <c>Scores</c>.
/// </summary>
///
/// <param name="points">The input points.</param>
/// <param name="responses">Receives the jagged array returned by
/// <c>Scores</c> for the given points.</param>
///
/// <returns>
/// An array with one entry per input point, passed by reference to
/// <c>Scores</c> so that it can store the selected cluster indices.
/// </returns>
///
[Obsolete("Please use Scores() instead.")]
public virtual int[] Nearest(TData[] points, out double[][] responses)
{
    // Pre-allocated so that Scores() can fill in one decision per point.
    int[] nearest = new int[points.Length];
    responses = Scores(points, ref nearest);
    return nearest;
}
/// <summary>
/// Calculates the (optionally weighted) average distance from the data
/// points to the centroids of the clusters they are assigned to.
/// </summary>
///
/// <remarks>
/// <para>
/// The average distance from centroids can be used as a measure
/// of the "goodness" of the clustering. The more the data are
/// aggregated around the centroids, the less the average distance.</para>
/// <para>
/// Distances are measured with the configured <see cref="Distance"/>
/// function exactly as it returns them; no additional squaring is
/// applied by this method.</para>
/// <para>
/// Without weights the sum of distances is divided by the number of
/// points, so an empty <paramref name="data"/> array yields
/// <see cref="Double.NaN"/> (0 / 0). With weights each distance is
/// multiplied by its weight and the total is divided by the sum of the
/// weights.</para>
/// </remarks>
///
/// <param name="data">The input points.</param>
/// <param name="labels">The cluster index assigned to each point. When null,
/// the labels are obtained by calling <c>Decide</c> on the data.</param>
/// <param name="weights">An optional weight for each point. When null,
/// all points contribute equally.</param>
///
/// <returns>
/// The average distance from the data points to the centroids
/// of their assigned clusters.
/// </returns>
///
/// <exception cref="ArgumentNullException"><paramref name="data"/> is null.</exception>
/// <exception cref="ArgumentException"><paramref name="labels"/> or
/// <paramref name="weights"/> was given but does not have one entry
/// per data point.</exception>
///
public virtual double Distortion(TData[] data, int[] labels = null, double[] weights = null)
{
    if (data == null)
        throw new ArgumentNullException("data");

    if (labels == null)
        labels = Decide(data);
    else if (labels.Length != data.Length)
        throw new ArgumentException("The number of labels should be equal to the number of data points.", "labels");

    // A longer weights array would silently inflate the denominator below,
    // and a shorter one would fail mid-loop; reject both up front.
    if (weights != null && weights.Length != data.Length)
        throw new ArgumentException("The number of weights should be equal to the number of data points.", "weights");

    double error = 0.0;

    if (weights == null)
    {
        for (int i = 0; i < data.Length; i++)
            error += Distance.Distance(data[i], centroids[labels[i]]);

        return error / (double)data.Length;
    }

    for (int i = 0; i < data.Length; i++)
        error += weights[i] * Distance.Distance(data[i], centroids[labels[i]]);

    return error / weights.Sum();
}
/// <summary>
/// Computes, for each input point, the distance between the point and the
/// centroid of the cluster given by its label, optionally multiplied by
/// the point's weight.
/// </summary>
///
/// <remarks>
/// The number of entries computed is the length of the output vector.
/// When <paramref name="result"/> is supplied, the other arrays are
/// indexed up to its length.
/// </remarks>
///
/// <param name="points">The input points.</param>
/// <param name="labels">The cluster index of each input point.</param>
/// <param name="weights">An optional weight for each point. When null,
/// distances are stored unweighted.</param>
/// <param name="result">An optional vector in which to store the computed
/// values. When null, a new vector with one entry per point is allocated.</param>
///
/// <returns>The vector holding one (optionally weighted) distance per entry.</returns>
///
public virtual double[] Transform(TData[] points, int[] labels, double[] weights = null, double[] result = null)
{
    double[] output = result;
    if (output == null)
        output = new double[points.Length];

    if (weights != null)
    {
        for (int i = 0; i < output.Length; i++)
            output[i] = weights[i] * Distance.Distance(points[i], centroids[labels[i]]);

        return output;
    }

    for (int i = 0; i < output.Length; i++)
        output[i] = Distance.Distance(points[i], centroids[labels[i]]);

    return output;
}
/// <summary>
/// Transforms data points into feature vectors containing the distance
/// between each point and each of the cluster centroids, optionally
/// multiplied by the point's weight.
/// </summary>
///
/// <remarks>
/// The number of rows computed is the length of the output matrix.
/// When <paramref name="result"/> is supplied, its rows are used as
/// given and are expected to hold one entry per centroid.
/// </remarks>
///
/// <param name="points">The input points.</param>
/// <param name="weights">An optional weight for each point. When null,
/// distances are stored unweighted.</param>
/// <param name="result">An optional jagged matrix in which to store the
/// computed values. When null, a new matrix with one row per point and
/// one column per centroid is allocated.</param>
///
/// <returns>The jagged matrix in which entry [i][j] is the (optionally
/// weighted) distance between point i and centroid j.</returns>
///
public virtual double[][] Transform(TData[] points, double[] weights = null, double[][] result = null)
{
    int k = centroids.Length;

    if (result == null)
    {
        result = new double[points.Length][];
        for (int row = 0; row < result.Length; row++)
            result[row] = new double[k];
    }

    if (weights != null)
    {
        for (int row = 0; row < result.Length; row++)
        {
            for (int col = 0; col < k; col++)
                result[row][col] = weights[row] * Distance.Distance(points[row], centroids[col]);
        }

        return result;
    }

    for (int row = 0; row < result.Length; row++)
    {
        for (int col = 0; col < k; col++)
            result[row][col] = Distance.Distance(points[row], centroids[col]);
    }

    return result;
}
/// <summary>
/// Gets the number of clusters in the collection, which is
/// the length of the <see cref="Clusters"/> array.
/// </summary>
///
public int Count
{
    get
    {
        TCluster[] all = this.Clusters;
        return all.Length;
    }
}
/// <summary>
/// Gets the cluster stored at the given index.
/// </summary>
///
/// <param name="index">The index of the cluster. This should also be the class label of the cluster.</param>
///
/// <returns>The element of the cluster array at the given position.</returns>
///
public TCluster this[int index]
{
    get
    {
        return this.clusters[index];
    }
}
/// <summary>
/// Returns an enumerator that iterates through the clusters
/// of this collection in index order.
/// </summary>
///
/// <returns>
/// An <see cref="T:System.Collections.Generic.IEnumerator`1"/> object that
/// yields each cluster in turn.
/// </returns>
///
public IEnumerator<TCluster> GetEnumerator()
{
    // Capture the array once, as a foreach over the field would.
    TCluster[] items = clusters;

    for (int i = 0; i < items.Length; i++)
        yield return items[i];
}
/// <summary>
/// Returns a non-generic enumerator that iterates through the clusters
/// of this collection.
/// </summary>
///
/// <returns>
/// The <see cref="T:System.Collections.IEnumerator"/> obtained from the
/// underlying cluster array.
/// </returns>
///
IEnumerator IEnumerable.GetEnumerator()
{
    IEnumerator enumerator = clusters.GetEnumerator();
    return enumerator;
}
/// <summary>
/// Computes a numerical score measuring the association between
/// the given <paramref name="input" /> vector and a given
/// <paramref name="classIndex" />.
/// </summary>
/// <param name="input">The input vector.</param>
/// <param name="classIndex">The index of the cluster whose score will be computed.</param>
/// <returns>The negated distance between the input and the centroid of the
/// given cluster, so that a smaller distance gives a larger score.</returns>
public override double Score(TData input, int classIndex)
{
    TCentroids centroid = centroids[classIndex];
    double d = Distance.Distance(input, centroid);
    return -d;
}
}
}