// Source file: SupportVectorMachine.cs (403 lines, 13.9 KB). The upstream repository was archived by its owner on Nov 19, 2020 and is now read-only.
// Accord Machine Learning Library
// The Accord.NET Framework
// http://accord-framework.net
//
// Copyright © César Souza, 2009-2015
// cesarsouza at gmail.com
//
// This library is free software; you can redistribute it and/or
// modify it under the terms of the GNU Lesser General Public
// License as published by the Free Software Foundation; either
// version 2.1 of the License, or (at your option) any later version.
//
// This library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
// Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public
// License along with this library; if not, write to the Free Software
// Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
//
namespace Accord.MachineLearning.VectorMachines
{
using System;
using System.IO;
using System.Runtime.Serialization.Formatters.Binary;
using Accord.Statistics.Links;
using Accord.Statistics.Models.Regression;
/// <summary>
///   Linear Support Vector Machine (SVM).
/// </summary>
///
/// <remarks>
/// <para>
///   Support vector machines (SVMs) are a set of related supervised learning methods
///   used for classification and regression. In simple words, given a set of training
///   examples, each marked as belonging to one of two categories, an SVM training algorithm
///   builds a model that predicts whether a new example falls into one category or the
///   other.</para>
/// <para>
///   Intuitively, an SVM model is a representation of the examples as points in space,
///   mapped so that the examples of the separate categories are divided by a clear gap
///   that is as wide as possible. New examples are then mapped into that same space and
///   predicted to belong to a category based on which side of the gap they fall on.</para>
///
/// <para>
///   For the non-linear generalization of the Support Vector Machine using arbitrary
///   kernel functions, please see the <see cref="KernelSupportVectorMachine"/>.
/// </para>
///
/// <para>
///   References:
///   <list type="bullet">
///     <item><description><a href="http://en.wikipedia.org/wiki/Support_vector_machine">
///       http://en.wikipedia.org/wiki/Support_vector_machine </a></description></item>
///   </list></para>
/// </remarks>
///
/// <example>
/// <code>
///   // Example AND problem
///   double[][] inputs =
///   {
///       new double[] { 0, 0 }, // 0 and 0: 0 (label -1)
///       new double[] { 0, 1 }, // 0 and 1: 0 (label -1)
///       new double[] { 1, 0 }, // 1 and 0: 0 (label -1)
///       new double[] { 1, 1 }  // 1 and 1: 1 (label +1)
///   };
///
///   // Dichotomy SVM outputs should be given as [-1;+1]
///   int[] labels =
///   {
///       // 0,  0,  0, 1
///         -1, -1, -1, 1
///   };
///
///   // Create a Support Vector Machine for the given inputs
///   SupportVectorMachine machine = new SupportVectorMachine(inputs[0].Length);
///
///   // Instantiate a new learning algorithm for SVMs
///   SequentialMinimalOptimization smo = new SequentialMinimalOptimization(machine, inputs, labels);
///
///   // Set up the learning algorithm
///   smo.Complexity = 1.0;
///
///   // Run the learning algorithm
///   double error = smo.Run();
///
///   // Compute the decision output for one of the input vectors
///   int decision = System.Math.Sign(machine.Compute(inputs[0]));
/// </code>
/// </example>
///
/// <seealso cref="KernelSupportVectorMachine"/>
/// <seealso cref="MulticlassSupportVectorMachine"/>
/// <seealso cref="MultilabelSupportVectorMachine"/>
///
/// <seealso cref="Accord.MachineLearning.VectorMachines.Learning.SequentialMinimalOptimization"/>
///
[Serializable]
public class SupportVectorMachine : ISupportVectorMachine
{
    // NOTE: instances are persisted through BinaryFormatter (see Save/Load), which
    // records the private fields of [Serializable] types by name. Renaming, removing
    // or retyping these fields would break loading of previously saved machines.
    private int inputCount;
    private double[][] supportVectors;
    private double[] weights;
    private double threshold;
    private ILinkFunction linkFunction;


    /// <summary>
    ///   Gets or sets the <see cref="ILinkFunction">link
    ///   function</see> used by this machine, if any.
    /// </summary>
    ///
    /// <value>The link function used to transform machine outputs,
    ///   or <c>null</c> when the raw output should be reported.</value>
    ///
    public ILinkFunction Link
    {
        get { return linkFunction; }
        set { linkFunction = value; }
    }

    /// <summary>
    ///   Gets a value indicating whether this machine produces probabilistic outputs.
    /// </summary>
    ///
    /// <value>
    ///   <c>true</c> if a <see cref="Link"/> function has been set, in which case
    ///   outputs are passed through its inverse; otherwise, <c>false</c>.
    /// </value>
    ///
    public bool IsProbabilistic
    {
        get { return linkFunction != null; }
    }

    /// <summary>
    ///   Creates a new Support Vector Machine.
    /// </summary>
    ///
    /// <remarks>
    ///   Only the number of inputs is stored; <see cref="Weights"/> and
    ///   <see cref="SupportVectors"/> remain <c>null</c> until they are assigned.
    /// </remarks>
    ///
    /// <param name="inputs">The number of inputs for the machine.</param>
    ///
    public SupportVectorMachine(int inputs)
    {
        this.inputCount = inputs;
    }

    /// <summary>
    ///   Gets the number of inputs accepted by this machine.
    /// </summary>
    ///
    /// <remarks>
    ///   If the number of inputs is zero, this means the machine
    ///   accepts an indefinite number of inputs. This is often the
    ///   case for kernel vector machines using a sequence kernel.
    /// </remarks>
    ///
    public int Inputs
    {
        get { return inputCount; }
    }

    /// <summary>
    ///   Gets or sets the collection of support vectors used by this machine.
    /// </summary>
    ///
    /// <remarks>
    ///   When this is <c>null</c> the machine is <see cref="IsCompact">compact</see>
    ///   and <see cref="Weights"/> is applied directly to the input vector.
    /// </remarks>
    ///
    public double[][] SupportVectors
    {
        get { return supportVectors; }
        set { supportVectors = value; }
    }

    /// <summary>
    ///   Gets whether this machine is in compact mode. Compact
    ///   machines do not need to keep storing their support vectors.
    /// </summary>
    ///
    public bool IsCompact
    {
        get { return supportVectors == null; }
    }

    /// <summary>
    ///   Gets or sets the collection of weights used by this machine.
    /// </summary>
    ///
    /// <remarks>
    ///   In compact mode there is one weight per input. Otherwise
    ///   there is one weight per support vector.
    /// </remarks>
    ///
    public double[] Weights
    {
        get { return weights; }
        set { weights = value; }
    }

    /// <summary>
    ///   Gets or sets the threshold (bias) term for this machine.
    /// </summary>
    ///
    public double Threshold
    {
        get { return threshold; }
        set { threshold = value; }
    }

    /// <summary>
    ///   Computes the given input to produce the corresponding output.
    /// </summary>
    ///
    /// <remarks>
    ///   For a binary decision problem, the decision for the negative
    ///   or positive class is typically computed by taking the sign of
    ///   the machine's output.
    /// </remarks>
    ///
    /// <param name="inputs">An input vector.</param>
    /// <param name="output">The output of the machine. If this is a
    ///   <see cref="IsProbabilistic">probabilistic</see> machine, the
    ///   output is the probability of the positive class. If this is
    ///   a standard machine, the output is the distance to the decision
    ///   hyperplane in feature space.</param>
    ///
    /// <returns>The decision label for the given input: <c>+1</c> when the
    ///   output is at least 0.5 (probabilistic machines) or at least 0
    ///   (standard machines); <c>-1</c> otherwise.</returns>
    ///
    /// <exception cref="ArgumentNullException"><paramref name="inputs"/> is <c>null</c>.</exception>
    /// <exception cref="InvalidOperationException">The machine is compact
    ///   but no <see cref="Weights"/> have been assigned.</exception>
    ///
    public virtual int Compute(double[] inputs, out double output)
    {
        if (inputs == null)
            throw new ArgumentNullException("inputs");

        output = threshold;

        if (supportVectors == null)
        {
            if (weights == null)
                throw new InvalidOperationException(
                    "The machine has no weights. Train the machine or assign its Weights before computing outputs.");

            // Compact mode: the weights act directly on the input vector.
            for (int i = 0; i < weights.Length; i++)
                output += weights[i] * inputs[i];
        }
        else
        {
            // One linear (dot product) kernel evaluation per support
            // vector, each scaled by the weight of that support vector.
            for (int i = 0; i < supportVectors.Length; i++)
            {
                double sum = 0;
                for (int j = 0; j < inputs.Length; j++)
                    sum += supportVectors[i][j] * inputs[j];
                output += weights[i] * sum;
            }
        }

        if (IsProbabilistic)
        {
            // Map the raw output through the inverse link function
            // and decide using 0.5 as the cut-off between classes.
            output = linkFunction.Inverse(output);
            return output >= 0.5 ? +1 : -1;
        }

        return output >= 0 ? +1 : -1;
    }

    /// <summary>
    ///   Computes the given input to produce the corresponding output.
    /// </summary>
    ///
    /// <remarks>
    ///   For a binary decision problem, the decision for the negative
    ///   or positive class is typically computed by taking the sign of
    ///   the machine's output.
    /// </remarks>
    ///
    /// <param name="inputs">An input vector.</param>
    ///
    /// <returns>The output for the given input. In a typical classification
    ///   problem, the sign of this value should be considered as the class label.</returns>
    ///
    public double Compute(double[] inputs)
    {
        double output;
        Compute(inputs, out output);
        return output;
    }

    /// <summary>
    ///   Creates a new <see cref="SupportVectorMachine"/> that is
    ///   completely equivalent to a <see cref="LogisticRegression"/>.
    /// </summary>
    ///
    /// <param name="regression">The <see cref="LogisticRegression"/> to be converted.</param>
    ///
    /// <returns>
    ///   A <see cref="SupportVectorMachine"/> whose linear weights are
    ///   equivalent to the given <see cref="LogisticRegression"/>'s
    ///   <see cref="GeneralizedLinearRegression.Coefficients"> linear
    ///   coefficients</see>, properly configured with a <see cref="LogitLinkFunction"/>.
    /// </returns>
    ///
    /// <exception cref="ArgumentNullException"><paramref name="regression"/> is <c>null</c>.</exception>
    ///
    public static SupportVectorMachine FromLogisticRegression(LogisticRegression regression)
    {
        if (regression == null)
            throw new ArgumentNullException("regression");

        double[] coefficients = regression.Coefficients;

        var svm = new SupportVectorMachine(regression.Inputs);

        // The constructor does not allocate the weight vector, so it has to
        // be created here before the regression coefficients can be copied.
        svm.Weights = new double[svm.Inputs];

        // The first coefficient is skipped and the threshold is read from
        // Intercept instead -- presumably Coefficients[0] holds that same
        // intercept term (TODO confirm against LogisticRegression).
        for (int i = 0; i < svm.Weights.Length; i++)
            svm.Weights[i] = coefficients[i + 1];

        svm.Threshold = regression.Intercept;
        svm.Link = new LogitLinkFunction(1, 0);
        return svm;
    }

    /// <summary>
    ///   Creates a new linear <see cref="SupportVectorMachine"/>
    ///   with the given set of linear <paramref name="weights"/>.
    /// </summary>
    ///
    /// <param name="weights">The machine's linear coefficients. The first
    ///   position holds the <see cref="Threshold"/>; the remaining positions
    ///   hold one coefficient per input.</param>
    ///
    /// <returns>
    ///   A <see cref="SupportVectorMachine"/> whose linear coefficients
    ///   are defined by the given <paramref name="weights"/> vector.
    /// </returns>
    ///
    /// <exception cref="ArgumentNullException"><paramref name="weights"/> is <c>null</c>.</exception>
    /// <exception cref="ArgumentException"><paramref name="weights"/> is empty.</exception>
    ///
    public static SupportVectorMachine FromWeights(double[] weights)
    {
        if (weights == null)
            throw new ArgumentNullException("weights");

        if (weights.Length == 0)
            throw new ArgumentException(
                "The weights vector must contain at least the threshold term.", "weights");

        var svm = new SupportVectorMachine(weights.Length - 1);
        svm.Weights = new double[svm.Inputs];

        // Everything after the leading threshold term is a per-input weight.
        Array.Copy(weights, 1, svm.Weights, 0, svm.Inputs);

        svm.Threshold = weights[0];
        return svm;
    }

    /// <summary>
    ///   Converts a <see cref="Accord.Statistics.Kernels.Linear"/>-kernel
    ///   machine into an array of linear coefficients. The first position
    ///   in the array is the <see cref="Threshold"/> value.
    /// </summary>
    ///
    /// <returns>
    ///   An array of linear coefficients representing this machine.
    /// </returns>
    ///
    /// <exception cref="InvalidOperationException">No
    ///   <see cref="Weights"/> have been assigned to the machine.</exception>
    ///
    public virtual double[] ToWeights()
    {
        if (weights == null)
            throw new InvalidOperationException(
                "The machine has no weights. Train the machine or assign its Weights before converting it.");

        // NOTE(review): the weight vector is copied verbatim. When the machine is
        // not compact, Compute treats Weights as one coefficient per support vector,
        // so the result would not be per-input coefficients in that case -- confirm
        // that callers only rely on this method for compact machines.
        double[] w = new double[weights.Length + 1];
        w[0] = threshold;
        Array.Copy(weights, 0, w, 1, weights.Length);
        return w;
    }

    /// <summary>
    ///   Saves the machine to a stream.
    /// </summary>
    ///
    /// <remarks>
    ///   The machine is written using <see cref="BinaryFormatter"/>, so
    ///   the result can only be read back through <see cref="Load(Stream)"/>.
    /// </remarks>
    ///
    /// <param name="stream">The stream to which the machine is to be serialized.</param>
    ///
    public virtual void Save(Stream stream)
    {
        // SECURITY: BinaryFormatter is kept only to stay compatible with
        // machines that have already been saved in this format.
        BinaryFormatter b = new BinaryFormatter();
        b.Serialize(stream, this);
    }

    /// <summary>
    ///   Saves the machine to a file.
    /// </summary>
    ///
    /// <remarks>
    ///   The file is created if it does not exist and overwritten if it does.
    /// </remarks>
    ///
    /// <param name="path">The path to the file to which the machine is to be serialized.</param>
    ///
    public void Save(string path)
    {
        // Saving only ever writes, so write access is all that is requested.
        using (FileStream fs = new FileStream(path, FileMode.Create, FileAccess.Write))
        {
            Save(fs);
        }
    }

    /// <summary>
    ///   Loads a machine from a stream.
    /// </summary>
    ///
    /// <remarks>
    ///   The stream is read using <see cref="BinaryFormatter"/>, which is not
    ///   safe to use on untrusted data. Only load machines from trusted sources.
    /// </remarks>
    ///
    /// <param name="stream">The stream from which the machine is to be deserialized.</param>
    ///
    /// <returns>The deserialized machine.</returns>
    ///
    public static SupportVectorMachine Load(Stream stream)
    {
        // SECURITY: deserializing with BinaryFormatter can execute code embedded
        // in a crafted payload; the stream contents must come from a trusted source.
        BinaryFormatter b = new BinaryFormatter();
        return (SupportVectorMachine)b.Deserialize(stream);
    }

    /// <summary>
    ///   Loads a machine from a file.
    /// </summary>
    ///
    /// <remarks>
    ///   The file is read using <see cref="BinaryFormatter"/>, which is not
    ///   safe to use on untrusted data. Only load machines from trusted sources.
    /// </remarks>
    ///
    /// <param name="path">The path to the file from which the machine is to be deserialized.</param>
    ///
    /// <returns>The deserialized machine.</returns>
    ///
    public static SupportVectorMachine Load(string path)
    {
        // Request read access only: the two-argument FileStream constructor asks
        // for read/write access, which fails on read-only files and media.
        using (FileStream fs = new FileStream(path, FileMode.Open, FileAccess.Read, FileShare.Read))
        {
            return Load(fs);
        }
    }
}
}