From f85581902f3377a9935daaa95260f5f6fc115188 Mon Sep 17 00:00:00 2001 From: wilsoncao <275239608@qq.com> Date: Thu, 12 Jun 2014 16:54:59 +0800 Subject: [PATCH 1/2] rebase ml API example --- .../example/java/ml/LinearRegression.java | 314 ++++++++++++++++++ .../java/ml/util/LinearRegressionData.java | 62 ++++ .../util/LinearRegressionDataGenerator.java | 109 ++++++ 3 files changed, 485 insertions(+) create mode 100644 stratosphere-examples/stratosphere-java-examples/src/main/java/eu/stratosphere/example/java/ml/LinearRegression.java create mode 100644 stratosphere-examples/stratosphere-java-examples/src/main/java/eu/stratosphere/example/java/ml/util/LinearRegressionData.java create mode 100644 stratosphere-examples/stratosphere-java-examples/src/main/java/eu/stratosphere/example/java/ml/util/LinearRegressionDataGenerator.java diff --git a/stratosphere-examples/stratosphere-java-examples/src/main/java/eu/stratosphere/example/java/ml/LinearRegression.java b/stratosphere-examples/stratosphere-java-examples/src/main/java/eu/stratosphere/example/java/ml/LinearRegression.java new file mode 100644 index 0000000000000..8f7c098871327 --- /dev/null +++ b/stratosphere-examples/stratosphere-java-examples/src/main/java/eu/stratosphere/example/java/ml/LinearRegression.java @@ -0,0 +1,314 @@ +/*********************************************************************************************************************** + * + * Copyright (C) 2010-2013 by the Stratosphere project (http://stratosphere.eu) + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations under the License. + * + **********************************************************************************************************************/ + +package eu.stratosphere.example.java.ml; + +import java.io.Serializable; +import java.util.Collection; +import eu.stratosphere.api.java.DataSet; +import eu.stratosphere.api.java.ExecutionEnvironment; +import eu.stratosphere.api.java.IterativeDataSet; +import eu.stratosphere.api.java.functions.MapFunction; +import eu.stratosphere.api.java.functions.ReduceFunction; +import eu.stratosphere.api.java.tuple.Tuple2; +import eu.stratosphere.configuration.Configuration; +import eu.stratosphere.example.java.ml.util.LinearRegressionData; + +/** + * This example implements a basic Linear Regression using batch gradient descent algorithm. + * + *

+ * Linear Regression with the BGD (batch gradient descent) algorithm is an iterative optimization algorithm and works as follows:
+ * Given a data set and a target set, BGD tries to find the parameters that make the data set best fit the target set.
+ * In each iteration, the algorithm computes the gradient of the cost function and uses it to update all the parameters.
+ * The algorithm terminates after a fixed number of iterations (as in this implementation).
+ * With enough iterations, the algorithm can minimize the cost function and find the best parameters.
+ * See the Wikipedia entries on linear regression and gradient descent for background.
+ *
+ *

+ * This implementation works on one-dimensional data and finds the two-dimensional parameter vector theta.
+ * It finds the theta parameters that best fit the target.
+ *
+ *

+ * Input files are plain text files and must be formatted as follows: + *

+ * + *

+ * This example shows how to use: + *

+ */
+
+/**
+ * A linear regression example that fits the model y = theta0 + theta1*x with batch gradient descent (BGD).
+ */
+@SuppressWarnings("serial")
+public class LinearRegression {
+
+	// *************************************************************************
+	//     PROGRAM
+	// *************************************************************************
+
+	/**
+	 * Builds and executes the iterative data flow: every sample proposes updated parameters,
+	 * the proposals are summed up and averaged, and the average is fed into the next iteration.
+	 *
+	 * @param args either no arguments (built-in data is used) or the three arguments parsed by parseParameters
+	 * @throws Exception if setting up or executing the program fails
+	 */
+	public static void main(String[] args) throws Exception {
+
+		if (!parseParameters(args)) {
+			return;
+		}
+
+		// set up execution environment
+		final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+
+		// get the (x, y) samples
+		DataSet<Data> data = getDataSet(env);
+
+		// get the initial parameters
+		DataSet<Params> parameters = getParamsDataSet(env);
+
+		// set number of bulk iterations for BGD linear regression
+		IterativeDataSet<Params> loop = parameters.iterate(numIterations);
+
+		DataSet<Params> new_parameters = data
+				// compute a single step using every sample
+				.map(new SubUpdate()).withBroadcastSet(loop, "parameters")
+				// sum up all the steps
+				.reduce(new UpdateAccumulator())
+				// average the steps and update all parameters
+				.map(new Update());
+
+		// feed new parameters back into next iteration
+		DataSet<Params> result = loop.closeWith(new_parameters);
+
+		// emit result
+		// NOTE(review): result holds Params objects, not tuples -- confirm that writeAsCsv accepts
+		// non-tuple data sets in this API version; otherwise writeAsText is the likely replacement.
+		if (fileOutput) {
+			result.writeAsCsv(outputPath, "\n", " ");
+		} else {
+			result.print();
+		}
+
+		// execute program
+		env.execute("Linear Regression example");
+	}
+
+	// *************************************************************************
+	//     DATA TYPES
+	// *************************************************************************
+
+	/**
+	 * A simple data sample, x means the input, and y means the target.
+	 */
+	public static class Data implements Serializable {
+		public double x, y;
+
+		public Data() {}
+
+		public Data(double x, double y) {
+			this.x = x;
+			this.y = y;
+		}
+
+		@Override
+		public String toString() {
+			return "(" + x + "|" + y + ")";
+		}
+
+	}
+
+	/**
+	 * A set of parameters -- theta0, theta1.
+	 */
+	public static class Params implements Serializable {
+
+		private double theta0, theta1;
+
+		public Params() {}
+
+		public Params(double x0, double x1) {
+			this.theta0 = x0;
+			this.theta1 = x1;
+		}
+
+		@Override
+		public String toString() {
+			return "(" + theta0 + "|" + theta1 + ")";
+		}
+
+		public double getTheta0() {
+			return theta0;
+		}
+
+		public double getTheta1() {
+			return theta1;
+		}
+
+		public void setTheta0(double theta0) {
+			this.theta0 = theta0;
+		}
+
+		public void setTheta1(double theta1) {
+			this.theta1 = theta1;
+		}
+
+		/**
+		 * Divides both parameters by the given value. The division happens in place.
+		 *
+		 * @param a the divisor
+		 * @return this (modified) instance
+		 */
+		public Params div(Integer a) {
+			this.theta0 = theta0 / a;
+			this.theta1 = theta1 / a;
+			return this;
+		}
+
+	}
+
+	// *************************************************************************
+	//     USER FUNCTIONS
+	// *************************************************************************
+
+	/** Converts a {@code Tuple2<Double, Double>} into a Data. */
+	public static final class TupleDataConverter extends MapFunction<Tuple2<Double, Double>, Data> {
+
+		@Override
+		public Data map(Tuple2<Double, Double> t) throws Exception {
+			return new Data(t.f0, t.f1);
+		}
+	}
+
+	/** Converts a {@code Tuple2<Double, Double>} into a Params. */
+	public static final class TupleParamsConverter extends MapFunction<Tuple2<Double, Double>, Params> {
+
+		@Override
+		public Params map(Tuple2<Double, Double> t) throws Exception {
+			return new Params(t.f0, t.f1);
+		}
+	}
+
+	/**
+	 * Computes a single BGD-style update of all parameters for one sample.
+	 * The emitted count (always 1) lets the following steps average the updates.
+	 */
+	public static class SubUpdate extends MapFunction<Data, Tuple2<Params, Integer>> {
+
+		private Collection<Params> parameters;
+
+		private Params parameter;
+
+		private int count = 1;
+
+		/** Reads the parameters from a broadcast variable into a collection.
*/
+		@Override
+		public void open(Configuration parameters) throws Exception {
+			this.parameters = getRuntimeContext().getBroadcastVariable("parameters");
+		}
+
+		/**
+		 * Performs one gradient step on the broadcast parameters using a single sample.
+		 *
+		 * @param in the sample (x, y)
+		 * @return the updated parameters together with a count of 1
+		 */
+		@Override
+		public Tuple2<Params, Integer> map(Data in) throws Exception {
+
+			// the whole broadcast collection is traversed; its last element is the one used
+			for (Params p : parameters) {
+				this.parameter = p;
+			}
+
+			// prediction error of the current model on this sample
+			final double error = (parameter.theta0 + (parameter.theta1 * in.x)) - in.y;
+
+			// gradient step with the fixed step size 0.01
+			double theta_0 = parameter.theta0 - 0.01 * error;
+			double theta_1 = parameter.theta1 - 0.01 * (error * in.x);
+
+			return new Tuple2<Params, Integer>(new Params(theta_0, theta_1), count);
+		}
+	}
+
+	/**
+	 * Accumulates all the updates: sums the parameters component-wise and adds up the counts.
+	 */
+	public static class UpdateAccumulator extends ReduceFunction<Tuple2<Params, Integer>> {
+
+		@Override
+		public Tuple2<Params, Integer> reduce(Tuple2<Params, Integer> val1, Tuple2<Params, Integer> val2) {
+
+			double new_theta0 = val1.f0.theta0 + val2.f0.theta0;
+			double new_theta1 = val1.f0.theta1 + val2.f0.theta1;
+			Params result = new Params(new_theta0, new_theta1);
+			return new Tuple2<Params, Integer>(result, val1.f1 + val2.f1);
+
+		}
+	}
+
+	/**
+	 * Compute the final update by average them.
+	 */
+	public static class Update extends MapFunction<Tuple2<Params, Integer>, Params> {
+
+		/**
+		 * Divides the summed parameters (f0) by the number of accumulated updates (f1).
+		 */
+		@Override
+		public Params map(Tuple2<Params, Integer> arg0) throws Exception {
+
+			return arg0.f0.div(arg0.f1);
+
+		}
+
+	}
+	// *************************************************************************
+	//     UTIL METHODS
+	// *************************************************************************
+
+	// true as soon as any program argument is given; also selects file input in getDataSet
+	private static boolean fileOutput = false;
+	private static String dataPath = null;
+	private static String outputPath = null;
+	private static int numIterations = 10;
+
+	/**
+	 * Parses the program arguments into the static configuration fields.
+	 *
+	 * @param programArguments either empty, or exactly: data path, result path, number of iterations
+	 * @return false if arguments were given but their number is not three, true otherwise
+	 */
+	private static boolean parseParameters(String[] programArguments) {
+
+		if (programArguments.length > 0) {
+			// parse input arguments
+			fileOutput = true;
+			if (programArguments.length == 3) {
+				dataPath = programArguments[0];
+				outputPath = programArguments[1];
+				numIterations = Integer.parseInt(programArguments[2]);
+			} else {
+				System.err.println("Usage: LinearRegression <data path> <result path> <num iterations>");
+				return false;
+			}
+		} else {
+			System.out.println("Executing Linear Regression example with default parameters and built-in default data.");
+			System.out.println(" Provide parameters to read input data from files.");
+			System.out.println(" See the documentation for the correct format of input files.");
+			System.out.println(" We provide a data generator to create synthetic input files for this program.");
+			System.out.println(" Usage: LinearRegression <data path> <result path> <num iterations>");
+		}
+		return true;
+	}
+
+	/**
+	 * Returns the samples: read from the space-delimited file at dataPath when program arguments
+	 * were given, the built-in default samples otherwise.
+	 */
+	private static DataSet<Data> getDataSet(ExecutionEnvironment env) {
+		if (fileOutput) {
+			// read data from CSV file
+			return env.readCsvFile(dataPath)
+					.fieldDelimiter(' ')
+					.includeFields(true, true)
+					.types(Double.class, Double.class)
+					.map(new TupleDataConverter());
+		} else {
+			return LinearRegressionData.getDefaultDataDataSet(env);
+		}
+	}
+
+	/**
+	 * Returns the initial parameters; these always come from the built-in defaults.
+	 */
+	private static DataSet<Params> getParamsDataSet(ExecutionEnvironment env) {
+
+		return LinearRegressionData.getDefaultParamsDataSet(env);
+
+	}
+
+}
+
+diff --git 
a/stratosphere-examples/stratosphere-java-examples/src/main/java/eu/stratosphere/example/java/ml/util/LinearRegressionData.java b/stratosphere-examples/stratosphere-java-examples/src/main/java/eu/stratosphere/example/java/ml/util/LinearRegressionData.java new file mode 100644 index 0000000000000..39d86ec811554 --- /dev/null +++ b/stratosphere-examples/stratosphere-java-examples/src/main/java/eu/stratosphere/example/java/ml/util/LinearRegressionData.java @@ -0,0 +1,62 @@ +/*********************************************************************************************************************** + * Copyright (C) 2010-2013 by the Stratosphere project (http://stratosphere.eu) + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + **********************************************************************************************************************/ + +package eu.stratosphere.example.java.ml.util; + +import eu.stratosphere.api.java.DataSet; +import eu.stratosphere.api.java.ExecutionEnvironment; +import eu.stratosphere.example.java.ml.LinearRegression.Data; +import eu.stratosphere.example.java.ml.LinearRegression.Params; + +/** + * Provides the default data sets used for the Linear Regression example program. + * The default data sets are used, if no parameters are given to the program. 
+ *
+ */
+public class LinearRegressionData{
+
+	/**
+	 * Creates the default initial parameters: a single element with theta0 = theta1 = 0.
+	 */
+	public static DataSet<Params> getDefaultParamsDataSet(ExecutionEnvironment env){
+
+		return env.fromElements(
+			new Params(0.0,0.0)
+			);
+	}
+
+	/**
+	 * Creates the default samples (x, y), which lie close to the line y = 2x.
+	 */
+	public static DataSet<Data> getDefaultDataDataSet(ExecutionEnvironment env){
+
+		return env.fromElements(
+			new Data(0.5,1.0),
+			new Data(1.0,2.0),
+			new Data(2.0,4.0),
+			new Data(3.0,6.0),
+			new Data(4.0,8.0),
+			new Data(5.0,10.0),
+			new Data(6.0,12.0),
+			new Data(7.0,14.0),
+			new Data(8.0,16.0),
+			new Data(9.0,18.0),
+			new Data(10.0,20.0),
+			new Data(-0.08,-0.16),
+			new Data(0.13,0.26),
+			new Data(-1.17,-2.35),
+			new Data(1.72,3.45),
+			new Data(1.70,3.41),
+			new Data(1.20,2.41),
+			new Data(-0.59,-1.18),
+			new Data(0.28,0.57),
+			new Data(1.65,3.30),
+			new Data(-0.55,-1.08)
+			);
+	}
+
+}
\ No newline at end of file
diff --git a/stratosphere-examples/stratosphere-java-examples/src/main/java/eu/stratosphere/example/java/ml/util/LinearRegressionDataGenerator.java b/stratosphere-examples/stratosphere-java-examples/src/main/java/eu/stratosphere/example/java/ml/util/LinearRegressionDataGenerator.java
new file mode 100644
index 0000000000000..fe346814bcf2b
--- /dev/null
+++ b/stratosphere-examples/stratosphere-java-examples/src/main/java/eu/stratosphere/example/java/ml/util/LinearRegressionDataGenerator.java
@@ -0,0 +1,109 @@
+/***********************************************************************************************************************
+ * Copyright (C) 2010-2013 by the Stratosphere project (http://stratosphere.eu)
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+ * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations under the License. + **********************************************************************************************************************/ + +package eu.stratosphere.example.java.ml.util; + +import java.io.BufferedWriter; +import java.io.File; +import java.io.FileWriter; +import java.io.IOException; +import java.text.DecimalFormat; +import java.util.Locale; +import java.util.Random; + +/** + * Generates data for the {@link LinearRegression} example program. + */ +public class LinearRegressionDataGenerator { + + static { + Locale.setDefault(Locale.US); + } + + private static final String POINTS_FILE = "data"; + private static final long DEFAULT_SEED = 4650285087650871364L; + private static final int DIMENSIONALITY = 1; + private static final DecimalFormat FORMAT = new DecimalFormat("#0.00"); + private static final char DELIMITER = ' '; + + /** + * Main method to generate data for the {@link LinearRegression} example program. + *

+ * The generator creates one file:
+ *

    + *
  • {tmp.dir}/data for the data points + *
+ * + * @param args + *
    + *
  1. Int: Number of data points + *
  2. Optional Long: Random seed + *
+	 */
+	public static void main(String[] args) throws IOException {
+
+		// check parameter count
+		if (args.length < 1) {
+			System.out.println("LinearRegressionDataGenerator <numberOfDataPoints> [<seed>]");
+			System.exit(1);
+		}
+
+		// parse parameters
+		final int numDataPoints = Integer.parseInt(args[0]);
+		// the optional seed is the second argument
+		final long firstSeed = args.length > 1 ? Long.parseLong(args[1]) : DEFAULT_SEED;
+		final Random random = new Random(firstSeed);
+		final String tmpDir = System.getProperty("java.io.tmpdir");
+
+		// write the points out, into the same location that is reported below
+		BufferedWriter pointsOut = null;
+		try {
+			pointsOut = new BufferedWriter(new FileWriter(new File(tmpDir+"/"+POINTS_FILE)));
+			StringBuilder buffer = new StringBuilder();
+
+			// DIMENSIONALITY + 1 means that the number of x(dimensionality) and target y
+			double[] point = new double[DIMENSIONALITY+1];
+
+			for (int i = 1; i <= numDataPoints; i++) {
+				// x is standard normal, y = 2x plus a small Gaussian noise term
+				point[0] = random.nextGaussian();
+				point[1] = 2 * point[0] + 0.01*random.nextGaussian();
+				writePoint(point, buffer, pointsOut);
+			}
+
+		}
+		finally {
+			if (pointsOut != null) {
+				pointsOut.close();
+			}
+		}
+
+		System.out.println("Wrote "+numDataPoints+" data points to "+tmpDir+"/"+POINTS_FILE);
+	}
+
+
+	/**
+	 * Writes one data point as a single line, with the values formatted by FORMAT and separated by DELIMITER.
+	 *
+	 * @param data the values of the point
+	 * @param buffer a reusable buffer; its previous content is discarded
+	 * @param out the writer that receives the line
+	 * @throws IOException if writing fails
+	 */
+	private static void writePoint(double[] data, StringBuilder buffer, BufferedWriter out) throws IOException {
+		buffer.setLength(0);
+
+		// write coordinates
+		for (int j = 0; j < data.length; j++) {
+			buffer.append(FORMAT.format(data[j]));
+			if(j < data.length - 1) {
+				buffer.append(DELIMITER);
+			}
+		}
+
+		out.write(buffer.toString());
+		out.newLine();
+	}
+}
From b0a485b60eaf9859d1492ec1482a24dcca4a2468 Mon Sep 17 00:00:00 2001
From: wilsoncao <275239608@qq.com>
Date: Mon, 16 Jun 2014 20:59:41 +0800
Subject: [PATCH 2/2] add to the contributor list.
--- CONTRIBUTORS | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CONTRIBUTORS b/CONTRIBUTORS index a810e919ff9bf..b842c4d71a746 100644 --- a/CONTRIBUTORS +++ b/CONTRIBUTORS @@ -19,4 +19,4 @@ Christian Richter Sebastian Schelter Chesnay Schepler Kostas Tzoumas - +Zihong Cao