Skip to content

Commit

Permalink
IGNITE-7451: Make Linear SVM for multi-classification
Browse files Browse the repository at this point in the history
this closes #3484
  • Loading branch information
zaleslaw authored and YuriBabak committed Feb 12, 2018
1 parent 6f6f8dd commit c661963
Show file tree
Hide file tree
Showing 15 changed files with 512 additions and 44 deletions.
Expand Up @@ -19,10 +19,7 @@


import java.io.File; import java.io.File;
import java.io.IOException; import java.io.IOException;
import java.net.URISyntaxException;
import java.net.URL;
import java.nio.file.Path; import java.nio.file.Path;
import java.nio.file.Paths;
import org.apache.ignite.Ignite; import org.apache.ignite.Ignite;
import org.apache.ignite.Ignition; import org.apache.ignite.Ignition;
import org.apache.ignite.examples.ExampleNodeStartup; import org.apache.ignite.examples.ExampleNodeStartup;
Expand All @@ -33,13 +30,13 @@
import org.apache.ignite.ml.structures.preprocessing.LabeledDatasetLoader; import org.apache.ignite.ml.structures.preprocessing.LabeledDatasetLoader;
import org.apache.ignite.ml.structures.preprocessing.LabellingMachine; import org.apache.ignite.ml.structures.preprocessing.LabellingMachine;
import org.apache.ignite.ml.structures.preprocessing.Normalizer; import org.apache.ignite.ml.structures.preprocessing.Normalizer;
import org.apache.ignite.ml.svm.SVMLinearBinaryClassificationModel;
import org.apache.ignite.ml.svm.SVMLinearBinaryClassificationTrainer; import org.apache.ignite.ml.svm.SVMLinearBinaryClassificationTrainer;
import org.apache.ignite.ml.svm.SVMLinearClassificationModel;
import org.apache.ignite.thread.IgniteThread; import org.apache.ignite.thread.IgniteThread;


/** /**
* <p> * <p>
* Example of using {@link org.apache.ignite.ml.svm.SVMLinearClassificationModel} with Titanic dataset.</p> * Example of using {@link org.apache.ignite.ml.svm.SVMLinearBinaryClassificationModel} with Titanic dataset.</p>
* <p> * <p>
* Note that in this example we cannot guarantee order in which nodes return results of intermediate * Note that in this example we cannot guarantee order in which nodes return results of intermediate
* computations and therefore algorithm can return different results.</p> * computations and therefore algorithm can return different results.</p>
Expand Down Expand Up @@ -95,10 +92,10 @@ public static void main(String[] args) throws InterruptedException {
LabeledDataset train = split.train(); LabeledDataset train = split.train();


System.out.println("\n>>> Create new linear binary SVM trainer object."); System.out.println("\n>>> Create new linear binary SVM trainer object.");
Trainer<SVMLinearClassificationModel, LabeledDataset> trainer = new SVMLinearBinaryClassificationTrainer(); Trainer<SVMLinearBinaryClassificationModel, LabeledDataset> trainer = new SVMLinearBinaryClassificationTrainer();


System.out.println("\n>>> Perform the training to get the model."); System.out.println("\n>>> Perform the training to get the model.");
SVMLinearClassificationModel mdl = trainer.train(train); SVMLinearBinaryClassificationModel mdl = trainer.train(train);


System.out.println("\n>>> SVM classification model: " + mdl); System.out.println("\n>>> SVM classification model: " + mdl);


Expand Down
Expand Up @@ -191,10 +191,20 @@ public void setLabel(int idx, double lb) {


/** */ /** */
public static Vector emptyVector(int size, boolean isDistributed) { public static Vector emptyVector(int size, boolean isDistributed) {

if(isDistributed) if(isDistributed)
return new SparseDistributedVector(size); return new SparseDistributedVector(size);
else else
return new DenseLocalOnHeapVector(size); return new DenseLocalOnHeapVector(size);
} }

/** Makes copy with new Label objects and old features and Metadata objects. */
public LabeledDataset copy(){
LabeledDataset res = new LabeledDataset(this.data, this.colSize);
res.isDistributed = this.isDistributed;
res.meta = this.meta;
for (int i = 0; i < rowSize; i++)
res.setLabel(i, this.label(i));

return res;
}
} }
Expand Up @@ -27,7 +27,7 @@
/** /**
* Base class for SVM linear classification model. * Base class for SVM linear classification model.
*/ */
public class SVMLinearClassificationModel implements Model<Vector, Double>, Exportable<SVMLinearClassificationModel>, Serializable { public class SVMLinearBinaryClassificationModel implements Model<Vector, Double>, Exportable<SVMLinearBinaryClassificationModel>, Serializable {
/** Output label format. -1 and +1 for false value and raw distances from the separating hyperplane otherwise. */ /** Output label format. -1 and +1 for false value and raw distances from the separating hyperplane otherwise. */
private boolean isKeepingRawLabels = false; private boolean isKeepingRawLabels = false;


Expand All @@ -41,47 +41,51 @@ public class SVMLinearClassificationModel implements Model<Vector, Double>, Expo
private double intercept; private double intercept;


/** */ /** */
public SVMLinearClassificationModel(Vector weights, double intercept) { public SVMLinearBinaryClassificationModel(Vector weights, double intercept) {
this.weights = weights; this.weights = weights;
this.intercept = intercept; this.intercept = intercept;
} }


/** /**
* Set up the output label format. * Set up the output label format.
*
* @param isKeepingRawLabels The parameter value. * @param isKeepingRawLabels The parameter value.
* @return Model with new isKeepingRawLabels parameter value. * @return Model with new isKeepingRawLabels parameter value.
*/ */
public SVMLinearClassificationModel withRawLabels(boolean isKeepingRawLabels) { public SVMLinearBinaryClassificationModel withRawLabels(boolean isKeepingRawLabels) {
this.isKeepingRawLabels = isKeepingRawLabels; this.isKeepingRawLabels = isKeepingRawLabels;
return this; return this;
} }


/** /**
* Set up the threshold. * Set up the threshold.
*
* @param threshold The parameter value. * @param threshold The parameter value.
* @return Model with new threshold parameter value. * @return Model with new threshold parameter value.
*/ */
public SVMLinearClassificationModel withThreshold(double threshold) { public SVMLinearBinaryClassificationModel withThreshold(double threshold) {
this.threshold = threshold; this.threshold = threshold;
return this; return this;
} }


/** /**
* Set up the weights. * Set up the weights.
*
* @param weights The parameter value. * @param weights The parameter value.
* @return Model with new weights parameter value. * @return Model with new weights parameter value.
*/ */
public SVMLinearClassificationModel withWeights(Vector weights) { public SVMLinearBinaryClassificationModel withWeights(Vector weights) {
this.weights = weights; this.weights = weights;
return this; return this;
} }


/** /**
* Set up the intercept. * Set up the intercept.
*
* @param intercept The parameter value. * @param intercept The parameter value.
* @return Model with new intercept parameter value. * @return Model with new intercept parameter value.
*/ */
public SVMLinearClassificationModel withIntercept(double intercept) { public SVMLinearBinaryClassificationModel withIntercept(double intercept) {
this.intercept = intercept; this.intercept = intercept;
return this; return this;
} }
Expand All @@ -97,6 +101,7 @@ public SVMLinearClassificationModel withIntercept(double intercept) {


/** /**
* Gets the output label format mode. * Gets the output label format mode.
*
* @return The parameter value. * @return The parameter value.
*/ */
public boolean isKeepingRawLabels() { public boolean isKeepingRawLabels() {
Expand All @@ -105,6 +110,7 @@ public boolean isKeepingRawLabels() {


/** /**
* Gets the threshold. * Gets the threshold.
*
* @return The parameter value. * @return The parameter value.
*/ */
public double threshold() { public double threshold() {
Expand All @@ -113,6 +119,7 @@ public double threshold() {


/** /**
* Gets the weights. * Gets the weights.
*
* @return The parameter value. * @return The parameter value.
*/ */
public Vector weights() { public Vector weights() {
Expand All @@ -121,14 +128,15 @@ public Vector weights() {


/** /**
* Gets the intercept. * Gets the intercept.
*
* @return The parameter value. * @return The parameter value.
*/ */
public double intercept() { public double intercept() {
return intercept; return intercept;
} }


/** {@inheritDoc} */ /** {@inheritDoc} */
@Override public <P> void saveModel(Exporter<SVMLinearClassificationModel, P> exporter, P path) { @Override public <P> void saveModel(Exporter<SVMLinearBinaryClassificationModel, P> exporter, P path) {
exporter.save(this, path); exporter.save(this, path);
} }


Expand All @@ -138,7 +146,7 @@ public double intercept() {
return true; return true;
if (o == null || getClass() != o.getClass()) if (o == null || getClass() != o.getClass())
return false; return false;
SVMLinearClassificationModel mdl = (SVMLinearClassificationModel)o; SVMLinearBinaryClassificationModel mdl = (SVMLinearBinaryClassificationModel)o;
return Double.compare(mdl.intercept, intercept) == 0 return Double.compare(mdl.intercept, intercept) == 0
&& Double.compare(mdl.threshold, threshold) == 0 && Double.compare(mdl.threshold, threshold) == 0
&& Boolean.compare(mdl.isKeepingRawLabels, isKeepingRawLabels) == 0 && Boolean.compare(mdl.isKeepingRawLabels, isKeepingRawLabels) == 0
Expand Down
Expand Up @@ -32,7 +32,7 @@
* and +1 labels for two classes and makes binary classification. </p> The paper about this algorithm could be found * and +1 labels for two classes and makes binary classification. </p> The paper about this algorithm could be found
* here https://arxiv.org/abs/1409.1458. * here https://arxiv.org/abs/1409.1458.
*/ */
public class SVMLinearBinaryClassificationTrainer implements Trainer<SVMLinearClassificationModel, LabeledDataset> { public class SVMLinearBinaryClassificationTrainer implements Trainer<SVMLinearBinaryClassificationModel, LabeledDataset> {
/** Amount of outer SDCA algorithm iterations. */ /** Amount of outer SDCA algorithm iterations. */
private int amountOfIterations = 20; private int amountOfIterations = 20;


Expand All @@ -51,7 +51,7 @@ public class SVMLinearBinaryClassificationTrainer implements Trainer<SVMLinearCl
* @param data data to build model * @param data data to build model
* @return model * @return model
*/ */
@Override public SVMLinearClassificationModel train(LabeledDataset data) { @Override public SVMLinearBinaryClassificationModel train(LabeledDataset data) {
isDistributed = data.isDistributed(); isDistributed = data.isDistributed();


final int weightVectorSizeWithIntercept = data.colSize() + 1; final int weightVectorSizeWithIntercept = data.colSize() + 1;
Expand All @@ -62,7 +62,7 @@ public class SVMLinearBinaryClassificationTrainer implements Trainer<SVMLinearCl
weights = weights.plus(deltaWeights); // creates new vector weights = weights.plus(deltaWeights); // creates new vector
} }


return new SVMLinearClassificationModel(weights.viewPart(1, weights.size() - 1), weights.get(0)); return new SVMLinearBinaryClassificationModel(weights.viewPart(1, weights.size() - 1), weights.get(0));
} }


/** */ /** */
Expand Down
@@ -0,0 +1,89 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.ignite.ml.svm;

import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.TreeMap;
import org.apache.ignite.ml.Exportable;
import org.apache.ignite.ml.Exporter;
import org.apache.ignite.ml.Model;
import org.apache.ignite.ml.math.Vector;

/** Base class for multi-classification model for set of SVM classifiers. */
public class SVMLinearMultiClassClassificationModel implements Model<Vector, Double>, Exportable<SVMLinearMultiClassClassificationModel>, Serializable {
/** List of models associated with each class. */
private Map<Double, SVMLinearBinaryClassificationModel> models;

/** */
public SVMLinearMultiClassClassificationModel() {
this.models = new HashMap<>();
}

/** {@inheritDoc} */
@Override public Double apply(Vector input) {
TreeMap<Double, Double> maxMargins = new TreeMap<>();

models.forEach((k, v) -> maxMargins.put(input.dot(v.weights()) + v.intercept(), k));

return maxMargins.lastEntry().getValue();
}

/** {@inheritDoc} */
@Override public <P> void saveModel(Exporter<SVMLinearMultiClassClassificationModel, P> exporter, P path) {
exporter.save(this, path);
}

/** {@inheritDoc} */
@Override public boolean equals(Object o) {
if (this == o)
return true;
if (o == null || getClass() != o.getClass())
return false;
SVMLinearMultiClassClassificationModel mdl = (SVMLinearMultiClassClassificationModel)o;
return Objects.equals(models, mdl.models);
}

/** {@inheritDoc} */
@Override public int hashCode() {
return Objects.hash(models);
}

/** {@inheritDoc} */
@Override public String toString() {
StringBuilder wholeStr = new StringBuilder();

models.forEach((clsLb, mdl) -> {
wholeStr.append("The class with label " + clsLb + " has classifier: " + mdl.toString() + System.lineSeparator());
});

return wholeStr.toString();
}

/**
* Adds a specific SVM binary classifier to the bunch of same classifiers.
*
* @param clsLb The class label for the added model.
* @param mdl The model.
*/
public void add(double clsLb, SVMLinearBinaryClassificationModel mdl) {
models.put(clsLb, mdl);
}
}

0 comments on commit c661963

Please sign in to comment.