diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVC.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVC.java
new file mode 100644
index 000000000..4169d488d
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVC.java
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.linearsvc;
+
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.common.lossfunc.HingeLoss;
+import org.apache.flink.ml.common.optimizer.Optimizer;
+import org.apache.flink.ml.common.optimizer.SGD;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the linear support vector classification.
+ *
+ *
See https://en.wikipedia.org/wiki/Support-vector_machine#Linear_SVM.
+ */
+public class LinearSVC implements Estimator, LinearSVCParams {
+
+ private final Map, Object> paramMap = new HashMap<>();
+
+ public LinearSVC() {
+ ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+ }
+
+ @Override
+ @SuppressWarnings({"rawTypes", "ConstantConditions"})
+ public LinearSVCModel fit(Table... inputs) {
+ Preconditions.checkArgument(inputs.length == 1);
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+ DataStream trainData =
+ tEnv.toDataStream(inputs[0])
+ .map(
+ dataPoint -> {
+ double weight =
+ getWeightCol() == null
+ ? 1.0
+ : (Double) dataPoint.getField(getWeightCol());
+ double label = (Double) dataPoint.getField(getLabelCol());
+ Preconditions.checkState(
+ Double.compare(0.0, label) == 0
+ || Double.compare(1.0, label) == 0,
+ "LinearSVC only supports binary classification. But detected label: %s.",
+ label);
+ DenseVector features =
+ (DenseVector) dataPoint.getField(getFeaturesCol());
+ return new LabeledPointWithWeight(features, label, weight);
+ });
+
+ DataStream initModelData =
+ DataStreamUtils.reduce(
+ trainData.map(x -> x.getFeatures().size()),
+ (ReduceFunction)
+ (t0, t1) -> {
+ Preconditions.checkState(
+ t0.equals(t1),
+ "The training data should all have same dimensions.");
+ return t0;
+ })
+ .map(DenseVector::new);
+
+ Optimizer optimizer =
+ new SGD(
+ getMaxIter(),
+ getLearningRate(),
+ getGlobalBatchSize(),
+ getTol(),
+ getReg(),
+ getElasticNet());
+ DataStream rawModelData =
+ optimizer.optimize(initModelData, trainData, HingeLoss.INSTANCE);
+
+ DataStream modelData = rawModelData.map(LinearSVCModelData::new);
+ LinearSVCModel model = new LinearSVCModel().setModelData(tEnv.fromDataStream(modelData));
+ ReadWriteUtils.updateExistingParams(model, paramMap);
+ return model;
+ }
+
+ @Override
+ public void save(String path) throws IOException {
+ ReadWriteUtils.saveMetadata(this, path);
+ }
+
+ public static LinearSVC load(StreamTableEnvironment tEnv, String path) throws IOException {
+ return ReadWriteUtils.loadStageParam(path);
+ }
+
+ @Override
+ public Map, Object> getParamMap() {
+ return paramMap;
+ }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVCModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVCModel.java
new file mode 100644
index 000000000..253bbccd0
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVCModel.java
@@ -0,0 +1,174 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.linearsvc;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/** A Model which classifies data using the model data computed by {@link LinearSVC}. */
+public class LinearSVCModel implements Model, LinearSVCModelParams {
+
+ private final Map, Object> paramMap = new HashMap<>();
+
+ private Table modelDataTable;
+
+ public LinearSVCModel() {
+ ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+ }
+
+ @Override
+ @SuppressWarnings("unchecked")
+ public Table[] transform(Table... inputs) {
+ Preconditions.checkArgument(inputs.length == 1);
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+ DataStream inputStream = tEnv.toDataStream(inputs[0]);
+
+ final String broadcastModelKey = "broadcastModelKey";
+ DataStream modelDataStream =
+ LinearSVCModelData.getModelDataStream(modelDataTable);
+
+ RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+ RowTypeInfo outputTypeInfo =
+ new RowTypeInfo(
+ ArrayUtils.addAll(
+ inputTypeInfo.getFieldTypes(),
+ BasicTypeInfo.DOUBLE_TYPE_INFO,
+ DenseVectorTypeInfo.INSTANCE),
+ ArrayUtils.addAll(
+ inputTypeInfo.getFieldNames(),
+ getPredictionCol(),
+ getRawPredictionCol()));
+
+ DataStream predictionResult =
+ BroadcastUtils.withBroadcastStream(
+ Collections.singletonList(inputStream),
+ Collections.singletonMap(broadcastModelKey, modelDataStream),
+ inputList -> {
+ DataStream inputData = inputList.get(0);
+ return inputData.map(
+ new PredictLabelFunction(
+ broadcastModelKey, getFeaturesCol(), getThreshold()),
+ outputTypeInfo);
+ });
+ return new Table[] {tEnv.fromDataStream(predictionResult)};
+ }
+
+ @Override
+ public LinearSVCModel setModelData(Table... inputs) {
+ modelDataTable = inputs[0];
+ return this;
+ }
+
+ @Override
+ public Table[] getModelData() {
+ return new Table[] {modelDataTable};
+ }
+
+ @Override
+ public void save(String path) throws IOException {
+ ReadWriteUtils.saveMetadata(this, path);
+ ReadWriteUtils.saveModelData(
+ LinearSVCModelData.getModelDataStream(modelDataTable),
+ path,
+ new LinearSVCModelData.ModelDataEncoder());
+ }
+
+ public static LinearSVCModel load(StreamTableEnvironment tEnv, String path) throws IOException {
+ LinearSVCModel model = ReadWriteUtils.loadStageParam(path);
+ Table modelDataTable =
+ ReadWriteUtils.loadModelData(tEnv, path, new LinearSVCModelData.ModelDataDecoder());
+ return model.setModelData(modelDataTable);
+ }
+
+ @Override
+ public Map, Object> getParamMap() {
+ return paramMap;
+ }
+
+ /** A utility function used for prediction. */
+ private static class PredictLabelFunction extends RichMapFunction {
+
+ private final String broadcastModelKey;
+
+ private final String featuresCol;
+
+ private final double threshold;
+
+ private DenseVector coefficient;
+
+ public PredictLabelFunction(
+ String broadcastModelKey, String featuresCol, double threshold) {
+ this.broadcastModelKey = broadcastModelKey;
+ this.featuresCol = featuresCol;
+ this.threshold = threshold;
+ }
+
+ @Override
+ public Row map(Row dataPoint) {
+ if (coefficient == null) {
+ LinearSVCModelData modelData =
+ (LinearSVCModelData)
+ getRuntimeContext().getBroadcastVariable(broadcastModelKey).get(0);
+ coefficient = modelData.coefficient;
+ }
+ DenseVector features = (DenseVector) dataPoint.getField(featuresCol);
+ Row predictionResult = predictOneDataPoint(features, coefficient, threshold);
+ return Row.join(dataPoint, predictionResult);
+ }
+ }
+
+ /**
+ * The main logic that predicts one input data point.
+ *
+ * @param feature The input feature.
+ * @param coefficient The model parameters.
+ * @param threshold The threshold for prediction.
+ * @return The prediction label and the raw predictions.
+ */
+ private static Row predictOneDataPoint(
+ DenseVector feature, DenseVector coefficient, double threshold) {
+ double dotValue = BLAS.dot(feature, coefficient);
+ return Row.of(dotValue >= threshold ? 1.0 : 0.0, Vectors.dense(dotValue, -dotValue));
+ }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVCModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVCModelData.java
new file mode 100644
index 000000000..96e8a27ae
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVCModelData.java
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.linearsvc;
+
+import org.apache.flink.api.common.serialization.Encoder;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+
+import java.io.EOFException;
+import java.io.IOException;
+import java.io.OutputStream;
+
+/**
+ * Model data of {@link LinearSVCModel}.
+ *
+ * This class also provides methods to convert model data from Table to Datastream, and classes
+ * to save/load model data.
+ */
+public class LinearSVCModelData {
+
+ public DenseVector coefficient;
+
+ public LinearSVCModelData(DenseVector coefficient) {
+ this.coefficient = coefficient;
+ }
+
+ public LinearSVCModelData() {}
+
+ /**
+ * Converts the table model to a data stream.
+ *
+ * @param modelData The table model data.
+ * @return The data stream model data.
+ */
+ public static DataStream getModelDataStream(Table modelData) {
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl) modelData).getTableEnvironment();
+ return tEnv.toDataStream(modelData)
+ .map(x -> new LinearSVCModelData((DenseVector) x.getField(0)));
+ }
+
+ /** Data encoder for {@link LinearSVCModel}. */
+ public static class ModelDataEncoder implements Encoder {
+
+ @Override
+ public void encode(LinearSVCModelData modelData, OutputStream outputStream)
+ throws IOException {
+ DenseVectorSerializer.INSTANCE.serialize(
+ modelData.coefficient, new DataOutputViewStreamWrapper(outputStream));
+ }
+ }
+
+ /** Data decoder for {@link LinearSVCModel}. */
+ public static class ModelDataDecoder extends SimpleStreamFormat {
+
+ @Override
+ public Reader createReader(
+ Configuration configuration, FSDataInputStream inputStream) {
+ return new Reader() {
+
+ @Override
+ public LinearSVCModelData read() throws IOException {
+ try {
+ DenseVector coefficient =
+ DenseVectorSerializer.INSTANCE.deserialize(
+ new DataInputViewStreamWrapper(inputStream));
+ return new LinearSVCModelData(coefficient);
+ } catch (EOFException e) {
+ return null;
+ }
+ }
+
+ @Override
+ public void close() throws IOException {
+ inputStream.close();
+ }
+ };
+ }
+
+ @Override
+ public TypeInformation getProducedType() {
+ return TypeInformation.of(LinearSVCModelData.class);
+ }
+ }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVCModelParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVCModelParams.java
new file mode 100644
index 000000000..9e02233cd
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVCModelParams.java
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.linearsvc;
+
+import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasPredictionCol;
+import org.apache.flink.ml.common.param.HasRawPredictionCol;
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+
+/**
+ * Params for {@link LinearSVCModel}.
+ *
+ * @param The class type of this instance.
+ */
+public interface LinearSVCModelParams
+ extends HasFeaturesCol, HasPredictionCol, HasRawPredictionCol {
+ /**
+ * Param for threshold in linear support vector classifier. It applies to the rawPrediction and
+ * can be any real number, where Inf makes all predictions 0.0 and -Inf makes all predictions
+ * 1.0.
+ */
+ Param THRESHOLD =
+ new DoubleParam(
+ "threshold",
+ "Threshold in binary classification prediction applied to rawPrediction.",
+ 0.0,
+ ParamValidators.notNull());
+
+ default Double getThreshold() {
+ return get(THRESHOLD);
+ }
+
+ default T setThreshold(Double value) {
+ set(THRESHOLD, value);
+ return (T) this;
+ }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVCParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVCParams.java
new file mode 100644
index 000000000..7754e89f5
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVCParams.java
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.linearsvc;
+
+import org.apache.flink.ml.common.param.HasElasticNet;
+import org.apache.flink.ml.common.param.HasGlobalBatchSize;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.common.param.HasLearningRate;
+import org.apache.flink.ml.common.param.HasMaxIter;
+import org.apache.flink.ml.common.param.HasReg;
+import org.apache.flink.ml.common.param.HasTol;
+import org.apache.flink.ml.common.param.HasWeightCol;
+
+/**
+ * Params for {@link LinearSVC}.
+ *
+ * @param The class type of this instance.
+ */
+public interface LinearSVCParams
+ extends HasLabelCol,
+ HasWeightCol,
+ HasMaxIter,
+ HasReg,
+ HasElasticNet,
+ HasLearningRate,
+ HasGlobalBatchSize,
+ HasTol,
+ LinearSVCModelParams {}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/HingeLoss.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/HingeLoss.java
new file mode 100644
index 000000000..eb0f3bf58
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/HingeLoss.java
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.lossfunc;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.ml.classification.linearsvc.LinearSVC;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+
+/**
+ * The loss function for hinge loss. See {@link LinearSVC} for example.
+ *
+ * See https://en.wikipedia.org/wiki/Hinge_loss.
+ */
+@Internal
+public class HingeLoss implements LossFunc {
+ public static final HingeLoss INSTANCE = new HingeLoss();
+
+ private HingeLoss() {}
+
+ @Override
+ public double computeLoss(LabeledPointWithWeight dataPoint, DenseVector coefficient) {
+ double dot = BLAS.dot(dataPoint.getFeatures(), coefficient);
+ double labelScaled = 2 * dataPoint.getLabel() - 1;
+ return dataPoint.getWeight() * Math.max(0, 1 - labelScaled * dot);
+ }
+
+ @Override
+ public void computeGradient(
+ LabeledPointWithWeight dataPoint, DenseVector coefficient, DenseVector cumGradient) {
+ double dot = BLAS.dot(dataPoint.getFeatures(), coefficient);
+ double labelScaled = 2 * dataPoint.getLabel() - 1;
+ if (1 - labelScaled * dot > 0) {
+ BLAS.axpy(
+ -labelScaled * dataPoint.getWeight(),
+ dataPoint.getFeatures(),
+ cumGradient,
+ dataPoint.getFeatures().size());
+ }
+ }
+}
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LinearSVCTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LinearSVCTest.java
new file mode 100644
index 000000000..156244e30
--- /dev/null
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LinearSVCTest.java
@@ -0,0 +1,304 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.classification.linearsvc.LinearSVC;
+import org.apache.flink.ml.classification.linearsvc.LinearSVCModel;
+import org.apache.flink.ml.classification.linearsvc.LinearSVCModelData;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.RandomUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+
+/** Tests {@link LinearSVC} and {@link LinearSVCModel}. */
+public class LinearSVCTest {
+
+ @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+ private StreamExecutionEnvironment env;
+
+ private StreamTableEnvironment tEnv;
+
+ private static final List trainData =
+ Arrays.asList(
+ Row.of(Vectors.dense(1, 2, 3, 4), 0., 1.),
+ Row.of(Vectors.dense(2, 2, 3, 4), 0., 2.),
+ Row.of(Vectors.dense(3, 2, 3, 4), 0., 3.),
+ Row.of(Vectors.dense(4, 2, 3, 4), 0., 4.),
+ Row.of(Vectors.dense(5, 2, 3, 4), 0., 5.),
+ Row.of(Vectors.dense(11, 2, 3, 4), 1., 1.),
+ Row.of(Vectors.dense(12, 2, 3, 4), 1., 2.),
+ Row.of(Vectors.dense(13, 2, 3, 4), 1., 3.),
+ Row.of(Vectors.dense(14, 2, 3, 4), 1., 4.),
+ Row.of(Vectors.dense(15, 2, 3, 4), 1., 5.));
+
+ private static final double[] expectedCoefficient =
+ new double[] {0.470, -0.273, -0.410, -0.546};
+
+ private static final double TOLERANCE = 1e-7;
+
+ private Table trainDataTable;
+
+ @Before
+ public void before() {
+ Configuration config = new Configuration();
+ config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+ env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+ env.setParallelism(4);
+ env.enableCheckpointing(100);
+ env.setRestartStrategy(RestartStrategies.noRestart());
+ tEnv = StreamTableEnvironment.create(env);
+ Collections.shuffle(trainData);
+ trainDataTable =
+ tEnv.fromDataStream(
+ env.fromCollection(
+ trainData,
+ new RowTypeInfo(
+ new TypeInformation[] {
+ DenseVectorTypeInfo.INSTANCE, Types.DOUBLE, Types.DOUBLE
+ },
+ new String[] {"features", "label", "weight"})));
+ }
+
+ @SuppressWarnings("ConstantConditions, unchecked")
+ private void verifyPredictionResult(
+ Table output, String featuresCol, String predictionCol, String rawPredictionCol)
+ throws Exception {
+ List predResult = IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+ for (Row predictionRow : predResult) {
+ DenseVector feature = (DenseVector) predictionRow.getField(featuresCol);
+ double prediction = (double) predictionRow.getField(predictionCol);
+ DenseVector rawPrediction = (DenseVector) predictionRow.getField(rawPredictionCol);
+ if (feature.get(0) <= 5) {
+ assertEquals(0, prediction, TOLERANCE);
+ assertTrue(rawPrediction.get(0) < 0);
+ } else {
+ assertEquals(1, prediction, TOLERANCE);
+ assertTrue(rawPrediction.get(0) > 0);
+ }
+ }
+ }
+
+ @Test
+ public void testParam() {
+ LinearSVC linearSVC = new LinearSVC();
+ assertEquals("features", linearSVC.getFeaturesCol());
+ assertEquals("label", linearSVC.getLabelCol());
+ assertNull(linearSVC.getWeightCol());
+ assertEquals(20, linearSVC.getMaxIter());
+ assertEquals(1e-6, linearSVC.getTol(), TOLERANCE);
+ assertEquals(0.1, linearSVC.getLearningRate(), TOLERANCE);
+ assertEquals(32, linearSVC.getGlobalBatchSize());
+ assertEquals(0, linearSVC.getReg(), TOLERANCE);
+ assertEquals(0, linearSVC.getElasticNet(), TOLERANCE);
+ assertEquals(0.0, linearSVC.getThreshold(), TOLERANCE);
+ assertEquals("prediction", linearSVC.getPredictionCol());
+ assertEquals("rawPrediction", linearSVC.getRawPredictionCol());
+
+ linearSVC
+ .setFeaturesCol("test_features")
+ .setLabelCol("test_label")
+ .setWeightCol("test_weight")
+ .setMaxIter(1000)
+ .setTol(0.001)
+ .setLearningRate(0.5)
+ .setGlobalBatchSize(1000)
+ .setReg(0.1)
+ .setElasticNet(0.5)
+ .setThreshold(0.5)
+ .setPredictionCol("test_predictionCol")
+ .setRawPredictionCol("test_rawPredictionCol");
+ assertEquals("test_features", linearSVC.getFeaturesCol());
+ assertEquals("test_label", linearSVC.getLabelCol());
+ assertEquals("test_weight", linearSVC.getWeightCol());
+ assertEquals(1000, linearSVC.getMaxIter());
+ assertEquals(0.001, linearSVC.getTol(), TOLERANCE);
+ assertEquals(0.5, linearSVC.getLearningRate(), TOLERANCE);
+ assertEquals(1000, linearSVC.getGlobalBatchSize());
+ assertEquals(0.1, linearSVC.getReg(), TOLERANCE);
+ assertEquals(0.5, linearSVC.getElasticNet(), TOLERANCE);
+ assertEquals(0.5, linearSVC.getThreshold(), TOLERANCE);
+ assertEquals("test_predictionCol", linearSVC.getPredictionCol());
+ assertEquals("test_rawPredictionCol", linearSVC.getRawPredictionCol());
+ }
+
+ @Test
+ public void testOutputSchema() {
+ Table tempTable = trainDataTable.as("test_features", "test_label", "test_weight");
+ LinearSVC linearSVC =
+ new LinearSVC()
+ .setFeaturesCol("test_features")
+ .setLabelCol("test_label")
+ .setWeightCol("test_weight")
+ .setPredictionCol("test_predictionCol")
+ .setRawPredictionCol("test_rawPredictionCol");
+ Table output = linearSVC.fit(trainDataTable).transform(tempTable)[0];
+ assertEquals(
+ Arrays.asList(
+ "test_features",
+ "test_label",
+ "test_weight",
+ "test_predictionCol",
+ "test_rawPredictionCol"),
+ output.getResolvedSchema().getColumnNames());
+ }
+
+ @Test
+ public void testFitAndPredict() throws Exception {
+ LinearSVC linearSVC = new LinearSVC().setWeightCol("weight");
+ Table output = linearSVC.fit(trainDataTable).transform(trainDataTable)[0];
+ verifyPredictionResult(
+ output,
+ linearSVC.getFeaturesCol(),
+ linearSVC.getPredictionCol(),
+ linearSVC.getRawPredictionCol());
+ }
+
+ @Test
+ public void testSaveLoadAndPredict() throws Exception {
+ LinearSVC linearSVC = new LinearSVC().setWeightCol("weight");
+ linearSVC =
+ StageTestUtils.saveAndReload(
+ tEnv, linearSVC, tempFolder.newFolder().getAbsolutePath());
+ LinearSVCModel model = linearSVC.fit(trainDataTable);
+ model = StageTestUtils.saveAndReload(tEnv, model, tempFolder.newFolder().getAbsolutePath());
+ assertEquals(
+ Collections.singletonList("coefficient"),
+ model.getModelData()[0].getResolvedSchema().getColumnNames());
+ Table output = model.transform(trainDataTable)[0];
+ verifyPredictionResult(
+ output,
+ linearSVC.getFeaturesCol(),
+ linearSVC.getPredictionCol(),
+ linearSVC.getRawPredictionCol());
+ }
+
+ @Test
+ @SuppressWarnings("unchecked")
+ public void testGetModelData() throws Exception {
+ LinearSVC linearSVC = new LinearSVC().setWeightCol("weight");
+ LinearSVCModel model = linearSVC.fit(trainDataTable);
+ List modelData =
+ IteratorUtils.toList(
+ LinearSVCModelData.getModelDataStream(model.getModelData()[0])
+ .executeAndCollect());
+ assertEquals(1, modelData.size());
+ assertArrayEquals(expectedCoefficient, modelData.get(0).coefficient.values, 0.1);
+ }
+
+ @Test
+ public void testSetModelData() throws Exception {
+ LinearSVC linearSVC = new LinearSVC().setWeightCol("weight");
+ LinearSVCModel model = linearSVC.fit(trainDataTable);
+
+ LinearSVCModel newModel = new LinearSVCModel();
+ ReadWriteUtils.updateExistingParams(newModel, model.getParamMap());
+ newModel.setModelData(model.getModelData());
+ Table output = newModel.transform(trainDataTable)[0];
+ verifyPredictionResult(
+ output,
+ linearSVC.getFeaturesCol(),
+ linearSVC.getPredictionCol(),
+ linearSVC.getRawPredictionCol());
+ }
+
+ @Test
+ public void testMoreSubtaskThanData() throws Exception {
+ env.setParallelism(12);
+ LinearSVC linearSVC = new LinearSVC().setWeightCol("weight").setGlobalBatchSize(128);
+ Table output = linearSVC.fit(trainDataTable).transform(trainDataTable)[0];
+ verifyPredictionResult(
+ output,
+ linearSVC.getFeaturesCol(),
+ linearSVC.getPredictionCol(),
+ linearSVC.getRawPredictionCol());
+ }
+
+ @Test
+ public void testRegularization() throws Exception {
+ checkRegularization(0, RandomUtils.nextDouble(0, 1), expectedCoefficient);
+ checkRegularization(0.1, 0, new double[] {0.437, -0.262, -0.393, -0.524});
+ checkRegularization(0.1, 1, new double[] {0.426, -0.197, -0.329, -0.463});
+ checkRegularization(0.1, 0.5, new double[] {0.419, -0.238, -0.372, -0.505});
+ }
+
+ @Test
+ public void testThreshold() throws Exception {
+ checkThreshold(-Double.MAX_VALUE, 1);
+ checkThreshold(Double.MAX_VALUE, 0);
+ }
+
+ @SuppressWarnings("unchecked")
+ private void checkRegularization(double reg, double elasticNet, double[] expectedCoefficient)
+ throws Exception {
+ LinearSVCModel model =
+ new LinearSVC()
+ .setWeightCol("weight")
+ .setReg(reg)
+ .setElasticNet(elasticNet)
+ .fit(trainDataTable);
+ List modelData =
+ IteratorUtils.toList(
+ LinearSVCModelData.getModelDataStream(model.getModelData()[0])
+ .executeAndCollect());
+ final double errorTol = 1e-3;
+ assertArrayEquals(expectedCoefficient, modelData.get(0).coefficient.values, errorTol);
+ }
+
+ @SuppressWarnings("unchecked")
+ private void checkThreshold(double threshold, double target) throws Exception {
+ LinearSVC linearSVC = new LinearSVC().setWeightCol("weight");
+
+ Table predictions =
+ linearSVC.setThreshold(threshold).fit(trainDataTable).transform(trainDataTable)[0];
+
+ List predResult =
+ IteratorUtils.toList(tEnv.toDataStream(predictions).executeAndCollect());
+ for (Row r : predResult) {
+ assertEquals(target, r.getField(linearSVC.getPredictionCol()));
+ }
+ }
+}
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java
index 8d0661328..4bdab6b99 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java
@@ -151,16 +151,16 @@ private void verifyPredictionResult(
@Test
public void testParam() {
LogisticRegression logisticRegression = new LogisticRegression();
+ assertEquals("features", logisticRegression.getFeaturesCol());
assertEquals("label", logisticRegression.getLabelCol());
assertNull(logisticRegression.getWeightCol());
assertEquals(20, logisticRegression.getMaxIter());
- assertEquals(0, logisticRegression.getReg(), TOLERANCE);
- assertEquals(0, logisticRegression.getElasticNet(), TOLERANCE);
+ assertEquals(1e-6, logisticRegression.getTol(), TOLERANCE);
assertEquals(0.1, logisticRegression.getLearningRate(), TOLERANCE);
assertEquals(32, logisticRegression.getGlobalBatchSize());
- assertEquals(1e-6, logisticRegression.getTol(), TOLERANCE);
+ assertEquals(0, logisticRegression.getReg(), TOLERANCE);
+ assertEquals(0, logisticRegression.getElasticNet(), TOLERANCE);
assertEquals("auto", logisticRegression.getMultiClass());
- assertEquals("features", logisticRegression.getFeaturesCol());
assertEquals("prediction", logisticRegression.getPredictionCol());
assertEquals("rawPrediction", logisticRegression.getRawPredictionCol());
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/lossfunc/HingeLossTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/lossfunc/HingeLossTest.java
new file mode 100644
index 000000000..1cd165ecf
--- /dev/null
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/lossfunc/HingeLossTest.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.lossfunc;
+
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+
+import org.junit.Test;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+
+/** Tests {@link HingeLoss}. */
+public class HingeLossTest {
+ private static final LabeledPointWithWeight dataPoint1 =
+ new LabeledPointWithWeight(Vectors.dense(1.0, -1.0, -1.0), 1.0, 2.0);
+ private static final LabeledPointWithWeight dataPoint2 =
+ new LabeledPointWithWeight(Vectors.dense(1.0, -1.0, 1.0), 1.0, 2.0);
+ private static final DenseVector coefficient = Vectors.dense(1.0, 1.0, 1.0);
+ private static final DenseVector cumGradient = Vectors.dense(0.0, 0.0, 0.0);
+ private static final double TOLERANCE = 1e-7;
+
+ @Test
+ public void computeLoss() {
+ double loss = HingeLoss.INSTANCE.computeLoss(dataPoint1, coefficient);
+ assertEquals(4.0, loss, TOLERANCE);
+
+ loss = HingeLoss.INSTANCE.computeLoss(dataPoint2, coefficient);
+ assertEquals(0.0, loss, TOLERANCE);
+ }
+
+ @Test
+ public void computeGradient() {
+ HingeLoss.INSTANCE.computeGradient(dataPoint1, coefficient, cumGradient);
+ assertArrayEquals(new double[] {-2.0, 2.0, 2.0}, cumGradient.values, TOLERANCE);
+
+ HingeLoss.INSTANCE.computeGradient(dataPoint2, coefficient, cumGradient);
+ assertArrayEquals(new double[] {-2.0, 2.0, 2.0}, cumGradient.values, TOLERANCE);
+ }
+}
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/regression/LinearRegressionTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/regression/LinearRegressionTest.java
index 58e89c454..3ea99d387 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/regression/LinearRegressionTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/regression/LinearRegressionTest.java
@@ -118,15 +118,15 @@ private void verifyPredictionResult(Table output, String labelCol, String predic
@Test
public void testParam() {
LinearRegression linearRegression = new LinearRegression();
+ assertEquals("features", linearRegression.getFeaturesCol());
assertEquals("label", linearRegression.getLabelCol());
assertNull(linearRegression.getWeightCol());
assertEquals(20, linearRegression.getMaxIter());
- assertEquals(0, linearRegression.getReg(), TOLERANCE);
- assertEquals(0, linearRegression.getElasticNet(), TOLERANCE);
+ assertEquals(1e-6, linearRegression.getTol(), TOLERANCE);
assertEquals(0.1, linearRegression.getLearningRate(), TOLERANCE);
assertEquals(32, linearRegression.getGlobalBatchSize());
- assertEquals(1e-6, linearRegression.getTol(), TOLERANCE);
- assertEquals("features", linearRegression.getFeaturesCol());
+ assertEquals(0, linearRegression.getReg(), TOLERANCE);
+ assertEquals(0, linearRegression.getElasticNet(), TOLERANCE);
assertEquals("prediction", linearRegression.getPredictionCol());
linearRegression