diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVC.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVC.java new file mode 100644 index 000000000..4169d488d --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVC.java @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.classification.linearsvc; + +import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.ml.api.Estimator; +import org.apache.flink.ml.common.datastream.DataStreamUtils; +import org.apache.flink.ml.common.feature.LabeledPointWithWeight; +import org.apache.flink.ml.common.lossfunc.HingeLoss; +import org.apache.flink.ml.common.optimizer.Optimizer; +import org.apache.flink.ml.common.optimizer.SGD; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +/** + * An Estimator which implements the linear support vector classification. + * + *

See https://en.wikipedia.org/wiki/Support-vector_machine#Linear_SVM. + */ +public class LinearSVC implements Estimator, LinearSVCParams { + + private final Map, Object> paramMap = new HashMap<>(); + + public LinearSVC() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + @SuppressWarnings({"rawTypes", "ConstantConditions"}) + public LinearSVCModel fit(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + + DataStream trainData = + tEnv.toDataStream(inputs[0]) + .map( + dataPoint -> { + double weight = + getWeightCol() == null + ? 1.0 + : (Double) dataPoint.getField(getWeightCol()); + double label = (Double) dataPoint.getField(getLabelCol()); + Preconditions.checkState( + Double.compare(0.0, label) == 0 + || Double.compare(1.0, label) == 0, + "LinearSVC only supports binary classification. But detected label: %s.", + label); + DenseVector features = + (DenseVector) dataPoint.getField(getFeaturesCol()); + return new LabeledPointWithWeight(features, label, weight); + }); + + DataStream initModelData = + DataStreamUtils.reduce( + trainData.map(x -> x.getFeatures().size()), + (ReduceFunction) + (t0, t1) -> { + Preconditions.checkState( + t0.equals(t1), + "The training data should all have same dimensions."); + return t0; + }) + .map(DenseVector::new); + + Optimizer optimizer = + new SGD( + getMaxIter(), + getLearningRate(), + getGlobalBatchSize(), + getTol(), + getReg(), + getElasticNet()); + DataStream rawModelData = + optimizer.optimize(initModelData, trainData, HingeLoss.INSTANCE); + + DataStream modelData = rawModelData.map(LinearSVCModelData::new); + LinearSVCModel model = new LinearSVCModel().setModelData(tEnv.fromDataStream(modelData)); + ReadWriteUtils.updateExistingParams(model, paramMap); + return model; + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + } + + public static LinearSVC load(StreamTableEnvironment tEnv, String path) throws IOException { + return ReadWriteUtils.loadStageParam(path); + } + + @Override + public Map, Object> getParamMap() { + return paramMap; + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVCModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVCModel.java new file mode 100644 index 000000000..253bbccd0 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVCModel.java @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.classification.linearsvc; + +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.ml.api.Model; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.linalg.BLAS; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.lang3.ArrayUtils; + +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +/** A Model which classifies data using the model data computed by {@link LinearSVC}. */ +public class LinearSVCModel implements Model, LinearSVCModelParams { + + private final Map, Object> paramMap = new HashMap<>(); + + private Table modelDataTable; + + public LinearSVCModel() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + @SuppressWarnings("unchecked") + public Table[] transform(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + DataStream inputStream = tEnv.toDataStream(inputs[0]); + + final String broadcastModelKey = "broadcastModelKey"; + DataStream modelDataStream = + LinearSVCModelData.getModelDataStream(modelDataTable); + + RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema()); + RowTypeInfo outputTypeInfo = + new RowTypeInfo( + ArrayUtils.addAll( + inputTypeInfo.getFieldTypes(), + BasicTypeInfo.DOUBLE_TYPE_INFO, + DenseVectorTypeInfo.INSTANCE), + ArrayUtils.addAll( + inputTypeInfo.getFieldNames(), + getPredictionCol(), + getRawPredictionCol())); + + DataStream predictionResult = + BroadcastUtils.withBroadcastStream( + Collections.singletonList(inputStream), + Collections.singletonMap(broadcastModelKey, modelDataStream), + inputList -> { + DataStream inputData = inputList.get(0); + return inputData.map( + new PredictLabelFunction( + broadcastModelKey, getFeaturesCol(), getThreshold()), + outputTypeInfo); + }); + return new Table[] {tEnv.fromDataStream(predictionResult)}; + } + + @Override + public LinearSVCModel setModelData(Table... inputs) { + modelDataTable = inputs[0]; + return this; + } + + @Override + public Table[] getModelData() { + return new Table[] {modelDataTable}; + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + ReadWriteUtils.saveModelData( + LinearSVCModelData.getModelDataStream(modelDataTable), + path, + new LinearSVCModelData.ModelDataEncoder()); + } + + public static LinearSVCModel load(StreamTableEnvironment tEnv, String path) throws IOException { + LinearSVCModel model = ReadWriteUtils.loadStageParam(path); + Table modelDataTable = + ReadWriteUtils.loadModelData(tEnv, path, new LinearSVCModelData.ModelDataDecoder()); + return model.setModelData(modelDataTable); + } + + @Override + public Map, Object> getParamMap() { + return paramMap; + } + + /** A utility function used for prediction. */ + private static class PredictLabelFunction extends RichMapFunction { + + private final String broadcastModelKey; + + private final String featuresCol; + + private final double threshold; + + private DenseVector coefficient; + + public PredictLabelFunction( + String broadcastModelKey, String featuresCol, double threshold) { + this.broadcastModelKey = broadcastModelKey; + this.featuresCol = featuresCol; + this.threshold = threshold; + } + + @Override + public Row map(Row dataPoint) { + if (coefficient == null) { + LinearSVCModelData modelData = + (LinearSVCModelData) + getRuntimeContext().getBroadcastVariable(broadcastModelKey).get(0); + coefficient = modelData.coefficient; + } + DenseVector features = (DenseVector) dataPoint.getField(featuresCol); + Row predictionResult = predictOneDataPoint(features, coefficient, threshold); + return Row.join(dataPoint, predictionResult); + } + } + + /** + * The main logic that predicts one input data point. + * + * @param feature The input feature. + * @param coefficient The model parameters. + * @param threshold The threshold for prediction. + * @return The prediction label and the raw predictions. + */ + private static Row predictOneDataPoint( + DenseVector feature, DenseVector coefficient, double threshold) { + double dotValue = BLAS.dot(feature, coefficient); + return Row.of(dotValue >= threshold ? 1.0 : 0.0, Vectors.dense(dotValue, -dotValue)); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVCModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVCModelData.java new file mode 100644 index 000000000..96e8a27ae --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVCModelData.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.classification.linearsvc; + +import org.apache.flink.api.common.serialization.Encoder; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.connector.file.src.reader.SimpleStreamFormat; +import org.apache.flink.core.fs.FSDataInputStream; +import org.apache.flink.core.memory.DataInputViewStreamWrapper; +import org.apache.flink.core.memory.DataOutputViewStreamWrapper; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; + +import java.io.EOFException; +import java.io.IOException; +import java.io.OutputStream; + +/** + * Model data of {@link LinearSVCModel}. + * + *

This class also provides methods to convert model data from Table to Datastream, and classes + * to save/load model data. + */ +public class LinearSVCModelData { + + public DenseVector coefficient; + + public LinearSVCModelData(DenseVector coefficient) { + this.coefficient = coefficient; + } + + public LinearSVCModelData() {} + + /** + * Converts the table model to a data stream. + * + * @param modelData The table model data. + * @return The data stream model data. + */ + public static DataStream getModelDataStream(Table modelData) { + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) modelData).getTableEnvironment(); + return tEnv.toDataStream(modelData) + .map(x -> new LinearSVCModelData((DenseVector) x.getField(0))); + } + + /** Data encoder for {@link LinearSVCModel}. */ + public static class ModelDataEncoder implements Encoder { + + @Override + public void encode(LinearSVCModelData modelData, OutputStream outputStream) + throws IOException { + DenseVectorSerializer.INSTANCE.serialize( + modelData.coefficient, new DataOutputViewStreamWrapper(outputStream)); + } + } + + /** Data decoder for {@link LinearSVCModel}. */ + public static class ModelDataDecoder extends SimpleStreamFormat { + + @Override + public Reader createReader( + Configuration configuration, FSDataInputStream inputStream) { + return new Reader() { + + @Override + public LinearSVCModelData read() throws IOException { + try { + DenseVector coefficient = + DenseVectorSerializer.INSTANCE.deserialize( + new DataInputViewStreamWrapper(inputStream)); + return new LinearSVCModelData(coefficient); + } catch (EOFException e) { + return null; + } + } + + @Override + public void close() throws IOException { + inputStream.close(); + } + }; + } + + @Override + public TypeInformation getProducedType() { + return TypeInformation.of(LinearSVCModelData.class); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVCModelParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVCModelParams.java new file mode 100644 index 000000000..9e02233cd --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVCModelParams.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.classification.linearsvc; + +import org.apache.flink.ml.common.param.HasFeaturesCol; +import org.apache.flink.ml.common.param.HasPredictionCol; +import org.apache.flink.ml.common.param.HasRawPredictionCol; +import org.apache.flink.ml.param.DoubleParam; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; + +/** + * Params for {@link LinearSVCModel}. + * + * @param The class type of this instance. + */ +public interface LinearSVCModelParams + extends HasFeaturesCol, HasPredictionCol, HasRawPredictionCol { + /** + * Param for threshold in linear support vector classifier. It applies to the rawPrediction and + * can be any real number, where Inf makes all predictions 0.0 and -Inf makes all predictions + * 1.0. + */ + Param THRESHOLD = + new DoubleParam( + "threshold", + "Threshold in binary classification prediction applied to rawPrediction.", + 0.0, + ParamValidators.notNull()); + + default Double getThreshold() { + return get(THRESHOLD); + } + + default T setThreshold(Double value) { + set(THRESHOLD, value); + return (T) this; + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVCParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVCParams.java new file mode 100644 index 000000000..7754e89f5 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVCParams.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.classification.linearsvc; + +import org.apache.flink.ml.common.param.HasElasticNet; +import org.apache.flink.ml.common.param.HasGlobalBatchSize; +import org.apache.flink.ml.common.param.HasLabelCol; +import org.apache.flink.ml.common.param.HasLearningRate; +import org.apache.flink.ml.common.param.HasMaxIter; +import org.apache.flink.ml.common.param.HasReg; +import org.apache.flink.ml.common.param.HasTol; +import org.apache.flink.ml.common.param.HasWeightCol; + +/** + * Params for {@link LinearSVC}. + * + * @param The class type of this instance. + */ +public interface LinearSVCParams + extends HasLabelCol, + HasWeightCol, + HasMaxIter, + HasReg, + HasElasticNet, + HasLearningRate, + HasGlobalBatchSize, + HasTol, + LinearSVCModelParams {} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/HingeLoss.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/HingeLoss.java new file mode 100644 index 000000000..eb0f3bf58 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/HingeLoss.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.lossfunc; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.ml.classification.linearsvc.LinearSVC; +import org.apache.flink.ml.common.feature.LabeledPointWithWeight; +import org.apache.flink.ml.linalg.BLAS; +import org.apache.flink.ml.linalg.DenseVector; + +/** + * The loss function for hinge loss. See {@link LinearSVC} for example. + * + *

See https://en.wikipedia.org/wiki/Hinge_loss. + */ +@Internal +public class HingeLoss implements LossFunc { + public static final HingeLoss INSTANCE = new HingeLoss(); + + private HingeLoss() {} + + @Override + public double computeLoss(LabeledPointWithWeight dataPoint, DenseVector coefficient) { + double dot = BLAS.dot(dataPoint.getFeatures(), coefficient); + double labelScaled = 2 * dataPoint.getLabel() - 1; + return dataPoint.getWeight() * Math.max(0, 1 - labelScaled * dot); + } + + @Override + public void computeGradient( + LabeledPointWithWeight dataPoint, DenseVector coefficient, DenseVector cumGradient) { + double dot = BLAS.dot(dataPoint.getFeatures(), coefficient); + double labelScaled = 2 * dataPoint.getLabel() - 1; + if (1 - labelScaled * dot > 0) { + BLAS.axpy( + -labelScaled * dataPoint.getWeight(), + dataPoint.getFeatures(), + cumGradient, + dataPoint.getFeatures().size()); + } + } +} diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LinearSVCTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LinearSVCTest.java new file mode 100644 index 000000000..156244e30 --- /dev/null +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LinearSVCTest.java @@ -0,0 +1,304 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.classification; + +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.classification.linearsvc.LinearSVC; +import org.apache.flink.ml.classification.linearsvc.LinearSVCModel; +import org.apache.flink.ml.classification.linearsvc.LinearSVCModelData; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.ml.util.StageTestUtils; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.IteratorUtils; +import org.apache.commons.lang3.RandomUtils; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +/** Tests {@link LinearSVC} and {@link LinearSVCModel}. */ +public class LinearSVCTest { + + @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); + + private StreamExecutionEnvironment env; + + private StreamTableEnvironment tEnv; + + private static final List trainData = + Arrays.asList( + Row.of(Vectors.dense(1, 2, 3, 4), 0., 1.), + Row.of(Vectors.dense(2, 2, 3, 4), 0., 2.), + Row.of(Vectors.dense(3, 2, 3, 4), 0., 3.), + Row.of(Vectors.dense(4, 2, 3, 4), 0., 4.), + Row.of(Vectors.dense(5, 2, 3, 4), 0., 5.), + Row.of(Vectors.dense(11, 2, 3, 4), 1., 1.), + Row.of(Vectors.dense(12, 2, 3, 4), 1., 2.), + Row.of(Vectors.dense(13, 2, 3, 4), 1., 3.), + Row.of(Vectors.dense(14, 2, 3, 4), 1., 4.), + Row.of(Vectors.dense(15, 2, 3, 4), 1., 5.)); + + private static final double[] expectedCoefficient = + new double[] {0.470, -0.273, -0.410, -0.546}; + + private static final double TOLERANCE = 1e-7; + + private Table trainDataTable; + + @Before + public void before() { + Configuration config = new Configuration(); + config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); + env = StreamExecutionEnvironment.getExecutionEnvironment(config); + env.setParallelism(4); + env.enableCheckpointing(100); + env.setRestartStrategy(RestartStrategies.noRestart()); + tEnv = StreamTableEnvironment.create(env); + Collections.shuffle(trainData); + trainDataTable = + tEnv.fromDataStream( + env.fromCollection( + trainData, + new RowTypeInfo( + new TypeInformation[] { + DenseVectorTypeInfo.INSTANCE, Types.DOUBLE, Types.DOUBLE + }, + new String[] {"features", "label", "weight"}))); + } + + @SuppressWarnings("ConstantConditions, unchecked") + private void verifyPredictionResult( + Table output, String featuresCol, String predictionCol, String rawPredictionCol) + throws Exception { + List predResult = IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect()); + for (Row predictionRow : predResult) { + DenseVector feature = (DenseVector) predictionRow.getField(featuresCol); + double prediction = (double) predictionRow.getField(predictionCol); + DenseVector rawPrediction = (DenseVector) predictionRow.getField(rawPredictionCol); + if (feature.get(0) <= 5) { + assertEquals(0, prediction, TOLERANCE); + assertTrue(rawPrediction.get(0) < 0); + } else { + assertEquals(1, prediction, TOLERANCE); + assertTrue(rawPrediction.get(0) > 0); + } + } + } + + @Test + public void testParam() { + LinearSVC linearSVC = new LinearSVC(); + assertEquals("features", linearSVC.getFeaturesCol()); + assertEquals("label", linearSVC.getLabelCol()); + assertNull(linearSVC.getWeightCol()); + assertEquals(20, linearSVC.getMaxIter()); + assertEquals(1e-6, linearSVC.getTol(), TOLERANCE); + assertEquals(0.1, linearSVC.getLearningRate(), TOLERANCE); + assertEquals(32, linearSVC.getGlobalBatchSize()); + assertEquals(0, linearSVC.getReg(), TOLERANCE); + assertEquals(0, linearSVC.getElasticNet(), TOLERANCE); + assertEquals(0.0, linearSVC.getThreshold(), TOLERANCE); + assertEquals("prediction", linearSVC.getPredictionCol()); + assertEquals("rawPrediction", linearSVC.getRawPredictionCol()); + + linearSVC + .setFeaturesCol("test_features") + .setLabelCol("test_label") + .setWeightCol("test_weight") + .setMaxIter(1000) + .setTol(0.001) + .setLearningRate(0.5) + .setGlobalBatchSize(1000) + .setReg(0.1) + .setElasticNet(0.5) + .setThreshold(0.5) + .setPredictionCol("test_predictionCol") + .setRawPredictionCol("test_rawPredictionCol"); + assertEquals("test_features", linearSVC.getFeaturesCol()); + assertEquals("test_label", linearSVC.getLabelCol()); + assertEquals("test_weight", linearSVC.getWeightCol()); + assertEquals(1000, linearSVC.getMaxIter()); + assertEquals(0.001, linearSVC.getTol(), TOLERANCE); + assertEquals(0.5, linearSVC.getLearningRate(), TOLERANCE); + assertEquals(1000, linearSVC.getGlobalBatchSize()); + assertEquals(0.1, linearSVC.getReg(), TOLERANCE); + assertEquals(0.5, linearSVC.getElasticNet(), TOLERANCE); + assertEquals(0.5, linearSVC.getThreshold(), TOLERANCE); + assertEquals("test_predictionCol", linearSVC.getPredictionCol()); + assertEquals("test_rawPredictionCol", linearSVC.getRawPredictionCol()); + } + + @Test + public void testOutputSchema() { + Table tempTable = trainDataTable.as("test_features", "test_label", "test_weight"); + LinearSVC linearSVC = + new LinearSVC() + .setFeaturesCol("test_features") + .setLabelCol("test_label") + .setWeightCol("test_weight") + .setPredictionCol("test_predictionCol") + .setRawPredictionCol("test_rawPredictionCol"); + Table output = linearSVC.fit(trainDataTable).transform(tempTable)[0]; + assertEquals( + Arrays.asList( + "test_features", + "test_label", + "test_weight", + "test_predictionCol", + "test_rawPredictionCol"), + output.getResolvedSchema().getColumnNames()); + } + + @Test + public void testFitAndPredict() throws Exception { + LinearSVC linearSVC = new LinearSVC().setWeightCol("weight"); + Table output = linearSVC.fit(trainDataTable).transform(trainDataTable)[0]; + verifyPredictionResult( + output, + linearSVC.getFeaturesCol(), + linearSVC.getPredictionCol(), + linearSVC.getRawPredictionCol()); + } + + @Test + public void testSaveLoadAndPredict() throws Exception { + LinearSVC linearSVC = new LinearSVC().setWeightCol("weight"); + linearSVC = + StageTestUtils.saveAndReload( + tEnv, linearSVC, tempFolder.newFolder().getAbsolutePath()); + LinearSVCModel model = linearSVC.fit(trainDataTable); + model = StageTestUtils.saveAndReload(tEnv, model, tempFolder.newFolder().getAbsolutePath()); + assertEquals( + Collections.singletonList("coefficient"), + model.getModelData()[0].getResolvedSchema().getColumnNames()); + Table output = model.transform(trainDataTable)[0]; + verifyPredictionResult( + output, + linearSVC.getFeaturesCol(), + linearSVC.getPredictionCol(), + linearSVC.getRawPredictionCol()); + } + + @Test + @SuppressWarnings("unchecked") + public void testGetModelData() throws Exception { + LinearSVC linearSVC = new LinearSVC().setWeightCol("weight"); + LinearSVCModel model = linearSVC.fit(trainDataTable); + List modelData = + IteratorUtils.toList( + LinearSVCModelData.getModelDataStream(model.getModelData()[0]) + .executeAndCollect()); + assertEquals(1, modelData.size()); + assertArrayEquals(expectedCoefficient, modelData.get(0).coefficient.values, 0.1); + } + + @Test + public void testSetModelData() throws Exception { + LinearSVC linearSVC = new LinearSVC().setWeightCol("weight"); + LinearSVCModel model = linearSVC.fit(trainDataTable); + + LinearSVCModel newModel = new LinearSVCModel(); + ReadWriteUtils.updateExistingParams(newModel, model.getParamMap()); + newModel.setModelData(model.getModelData()); + Table output = newModel.transform(trainDataTable)[0]; + verifyPredictionResult( + output, + linearSVC.getFeaturesCol(), + linearSVC.getPredictionCol(), + linearSVC.getRawPredictionCol()); + } + + @Test + public void testMoreSubtaskThanData() throws Exception { + env.setParallelism(12); + LinearSVC linearSVC = new LinearSVC().setWeightCol("weight").setGlobalBatchSize(128); + Table output = linearSVC.fit(trainDataTable).transform(trainDataTable)[0]; + verifyPredictionResult( + output, + linearSVC.getFeaturesCol(), + linearSVC.getPredictionCol(), + linearSVC.getRawPredictionCol()); + } + + @Test + public void testRegularization() throws Exception { + checkRegularization(0, RandomUtils.nextDouble(0, 1), expectedCoefficient); + checkRegularization(0.1, 0, new double[] {0.437, -0.262, -0.393, -0.524}); + checkRegularization(0.1, 1, new double[] {0.426, -0.197, -0.329, -0.463}); + checkRegularization(0.1, 0.5, new double[] {0.419, -0.238, -0.372, -0.505}); + } + + @Test + public void testThreshold() throws Exception { + checkThreshold(-Double.MAX_VALUE, 1); + checkThreshold(Double.MAX_VALUE, 0); + } + + @SuppressWarnings("unchecked") + private void checkRegularization(double reg, double elasticNet, double[] expectedCoefficient) + throws Exception { + LinearSVCModel model = + new LinearSVC() + .setWeightCol("weight") + .setReg(reg) + .setElasticNet(elasticNet) + .fit(trainDataTable); + List modelData = + IteratorUtils.toList( + LinearSVCModelData.getModelDataStream(model.getModelData()[0]) + .executeAndCollect()); + final double errorTol = 1e-3; + assertArrayEquals(expectedCoefficient, modelData.get(0).coefficient.values, errorTol); + } + + @SuppressWarnings("unchecked") + private void checkThreshold(double threshold, double target) throws Exception { + LinearSVC linearSVC = new LinearSVC().setWeightCol("weight"); + + Table predictions = + linearSVC.setThreshold(threshold).fit(trainDataTable).transform(trainDataTable)[0]; + + List predResult = + IteratorUtils.toList(tEnv.toDataStream(predictions).executeAndCollect()); + for (Row r : predResult) { + assertEquals(target, r.getField(linearSVC.getPredictionCol())); + } + } +} diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java index 8d0661328..4bdab6b99 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java @@ -151,16 +151,16 @@ private void verifyPredictionResult( @Test public void testParam() { LogisticRegression logisticRegression = new LogisticRegression(); + assertEquals("features", logisticRegression.getFeaturesCol()); assertEquals("label", logisticRegression.getLabelCol()); assertNull(logisticRegression.getWeightCol()); assertEquals(20, logisticRegression.getMaxIter()); - assertEquals(0, logisticRegression.getReg(), TOLERANCE); - assertEquals(0, logisticRegression.getElasticNet(), TOLERANCE); + assertEquals(1e-6, logisticRegression.getTol(), TOLERANCE); assertEquals(0.1, logisticRegression.getLearningRate(), TOLERANCE); assertEquals(32, logisticRegression.getGlobalBatchSize()); - assertEquals(1e-6, logisticRegression.getTol(), TOLERANCE); + assertEquals(0, logisticRegression.getReg(), TOLERANCE); + assertEquals(0, logisticRegression.getElasticNet(), TOLERANCE); assertEquals("auto", logisticRegression.getMultiClass()); - assertEquals("features", logisticRegression.getFeaturesCol()); assertEquals("prediction", logisticRegression.getPredictionCol()); assertEquals("rawPrediction", logisticRegression.getRawPredictionCol()); diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/lossfunc/HingeLossTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/lossfunc/HingeLossTest.java new file mode 100644 index 000000000..1cd165ecf --- /dev/null +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/lossfunc/HingeLossTest.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.lossfunc; + +import org.apache.flink.ml.common.feature.LabeledPointWithWeight; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vectors; + +import org.junit.Test; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; + +/** Tests {@link HingeLoss}. */ +public class HingeLossTest { + private static final LabeledPointWithWeight dataPoint1 = + new LabeledPointWithWeight(Vectors.dense(1.0, -1.0, -1.0), 1.0, 2.0); + private static final LabeledPointWithWeight dataPoint2 = + new LabeledPointWithWeight(Vectors.dense(1.0, -1.0, 1.0), 1.0, 2.0); + private static final DenseVector coefficient = Vectors.dense(1.0, 1.0, 1.0); + private static final DenseVector cumGradient = Vectors.dense(0.0, 0.0, 0.0); + private static final double TOLERANCE = 1e-7; + + @Test + public void computeLoss() { + double loss = HingeLoss.INSTANCE.computeLoss(dataPoint1, coefficient); + assertEquals(4.0, loss, TOLERANCE); + + loss = HingeLoss.INSTANCE.computeLoss(dataPoint2, coefficient); + assertEquals(0.0, loss, TOLERANCE); + } + + @Test + public void computeGradient() { + HingeLoss.INSTANCE.computeGradient(dataPoint1, coefficient, cumGradient); + assertArrayEquals(new double[] {-2.0, 2.0, 2.0}, cumGradient.values, TOLERANCE); + + HingeLoss.INSTANCE.computeGradient(dataPoint2, coefficient, cumGradient); + assertArrayEquals(new double[] {-2.0, 2.0, 2.0}, cumGradient.values, TOLERANCE); + } +} diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/regression/LinearRegressionTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/regression/LinearRegressionTest.java index 58e89c454..3ea99d387 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/regression/LinearRegressionTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/regression/LinearRegressionTest.java @@ -118,15 +118,15 @@ private void verifyPredictionResult(Table output, String labelCol, String predic @Test public void testParam() { LinearRegression linearRegression = new LinearRegression(); + assertEquals("features", linearRegression.getFeaturesCol()); assertEquals("label", linearRegression.getLabelCol()); assertNull(linearRegression.getWeightCol()); assertEquals(20, linearRegression.getMaxIter()); - assertEquals(0, linearRegression.getReg(), TOLERANCE); - assertEquals(0, linearRegression.getElasticNet(), TOLERANCE); + assertEquals(1e-6, linearRegression.getTol(), TOLERANCE); assertEquals(0.1, linearRegression.getLearningRate(), TOLERANCE); assertEquals(32, linearRegression.getGlobalBatchSize()); - assertEquals(1e-6, linearRegression.getTol(), TOLERANCE); - assertEquals("features", linearRegression.getFeaturesCol()); + assertEquals(0, linearRegression.getReg(), TOLERANCE); + assertEquals(0, linearRegression.getElasticNet(), TOLERANCE); assertEquals("prediction", linearRegression.getPredictionCol()); linearRegression