Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import java.io.Serializable;
import java.util.Arrays;
import java.util.List;

import org.junit.After;
import org.junit.Before;
Expand Down Expand Up @@ -75,21 +76,20 @@ public void naiveBayesDefaultParams() {

@Test
public void testNaiveBayes() {
JavaRDD<Row> jrdd = jsc.parallelize(Arrays.asList(
List<Row> data = Arrays.asList(
RowFactory.create(0.0, Vectors.dense(1.0, 0.0, 0.0)),
RowFactory.create(0.0, Vectors.dense(2.0, 0.0, 0.0)),
RowFactory.create(1.0, Vectors.dense(0.0, 1.0, 0.0)),
RowFactory.create(1.0, Vectors.dense(0.0, 2.0, 0.0)),
RowFactory.create(2.0, Vectors.dense(0.0, 0.0, 1.0)),
RowFactory.create(2.0, Vectors.dense(0.0, 0.0, 2.0))
));
RowFactory.create(2.0, Vectors.dense(0.0, 0.0, 2.0)));

StructType schema = new StructType(new StructField[]{
new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
new StructField("features", new VectorUDT(), false, Metadata.empty())
});

DataFrame dataset = jsql.createDataFrame(jrdd, schema);
DataFrame dataset = jsql.createDataFrame(data, schema);
NaiveBayes nb = new NaiveBayes().setSmoothing(0.5).setModelType("multinomial");
NaiveBayesModel model = nb.fit(dataset);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,16 +55,16 @@ public void tearDown() {
public void bucketizerTest() {
double[] splits = {-0.5, 0.0, 0.5};

JavaRDD<Row> data = jsc.parallelize(Arrays.asList(
RowFactory.create(-0.5),
RowFactory.create(-0.3),
RowFactory.create(0.0),
RowFactory.create(0.2)
));
StructType schema = new StructType(new StructField[] {
new StructField("feature", DataTypes.DoubleType, false, Metadata.empty())
});
DataFrame dataset = jsql.createDataFrame(data, schema);
DataFrame dataset = jsql.createDataFrame(
Arrays.asList(
RowFactory.create(-0.5),
RowFactory.create(-0.3),
RowFactory.create(0.0),
RowFactory.create(0.2)),
schema);

Bucketizer bucketizer = new Bucketizer()
.setInputCol("feature")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,11 @@ public void tearDown() {
@Test
public void javaCompatibilityTest() {
double[] input = new double[] {1D, 2D, 3D, 4D};
JavaRDD<Row> data = jsc.parallelize(Arrays.asList(
RowFactory.create(Vectors.dense(input))
));
DataFrame dataset = jsql.createDataFrame(data, new StructType(new StructField[]{
new StructField("vec", (new VectorUDT()), false, Metadata.empty())
}));
DataFrame dataset = jsql.createDataFrame(
Arrays.asList(RowFactory.create(Vectors.dense(input))),
new StructType(new StructField[]{
new StructField("vec", (new VectorUDT()), false, Metadata.empty())
}));

double[] expectedResult = input.clone();
(new DoubleDCT_1D(input.length)).forward(expectedResult, true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.ml.feature;

import java.util.Arrays;
import java.util.List;

import org.junit.After;
import org.junit.Assert;
Expand Down Expand Up @@ -55,17 +56,17 @@ public void tearDown() {

@Test
public void hashingTF() {
JavaRDD<Row> jrdd = jsc.parallelize(Arrays.asList(
List<Row> data = Arrays.asList(
RowFactory.create(0.0, "Hi I heard about Spark"),
RowFactory.create(0.0, "I wish Java could use case classes"),
RowFactory.create(1.0, "Logistic regression models are neat")
));
);
StructType schema = new StructType(new StructField[]{
new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
new StructField("sentence", DataTypes.StringType, false, Metadata.empty())
});

DataFrame sentenceData = jsql.createDataFrame(jrdd, schema);
DataFrame sentenceData = jsql.createDataFrame(data, schema);
Tokenizer tokenizer = new Tokenizer()
.setInputCol("sentence")
.setOutputCol("words");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.ml.feature;

import java.util.Arrays;
import java.util.List;

import org.junit.After;
import org.junit.Assert;
Expand Down Expand Up @@ -60,7 +61,7 @@ public void polynomialExpansionTest() {
.setOutputCol("polyFeatures")
.setDegree(3);

JavaRDD<Row> data = jsc.parallelize(Arrays.asList(
List<Row> data = Arrays.asList(
RowFactory.create(
Vectors.dense(-2.0, 2.3),
Vectors.dense(-2.0, 4.0, -8.0, 2.3, -4.6, 9.2, 5.29, -10.58, 12.17)
Expand All @@ -70,7 +71,7 @@ public void polynomialExpansionTest() {
Vectors.dense(0.6, -1.1),
Vectors.dense(0.6, 0.36, 0.216, -1.1, -0.66, -0.396, 1.21, 0.726, -1.331)
)
));
);

StructType schema = new StructType(new StructField[] {
new StructField("features", new VectorUDT(), false, Metadata.empty()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.ml.feature;

import java.util.Arrays;
import java.util.List;

import org.junit.After;
import org.junit.Before;
Expand Down Expand Up @@ -58,14 +59,14 @@ public void javaCompatibilityTest() {
.setInputCol("raw")
.setOutputCol("filtered");

JavaRDD<Row> rdd = jsc.parallelize(Arrays.asList(
List<Row> data = Arrays.asList(
RowFactory.create(Arrays.asList("I", "saw", "the", "red", "baloon")),
RowFactory.create(Arrays.asList("Mary", "had", "a", "little", "lamb"))
));
);
StructType schema = new StructType(new StructField[] {
new StructField("raw", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty())
});
DataFrame dataset = jsql.createDataFrame(rdd, schema);
DataFrame dataset = jsql.createDataFrame(data, schema);

remover.transform(dataset).collect();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.ml.feature;

import java.util.Arrays;
import java.util.List;

import org.junit.After;
import org.junit.Assert;
Expand Down Expand Up @@ -56,9 +57,9 @@ public void testStringIndexer() {
createStructField("id", IntegerType, false),
createStructField("label", StringType, false)
});
JavaRDD<Row> rdd = jsc.parallelize(
Arrays.asList(c(0, "a"), c(1, "b"), c(2, "c"), c(3, "a"), c(4, "a"), c(5, "c")));
DataFrame dataset = sqlContext.createDataFrame(rdd, schema);
List<Row> data = Arrays.asList(
c(0, "a"), c(1, "b"), c(2, "c"), c(3, "a"), c(4, "a"), c(5, "c"));
DataFrame dataset = sqlContext.createDataFrame(data, schema);

StringIndexer indexer = new StringIndexer()
.setInputCol("label")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,7 @@ public void testVectorAssembler() {
Row row = RowFactory.create(
0, 0.0, Vectors.dense(1.0, 2.0), "a",
Vectors.sparse(2, new int[] {1}, new double[] {3.0}), 10L);
JavaRDD<Row> rdd = jsc.parallelize(Arrays.asList(row));
DataFrame dataset = sqlContext.createDataFrame(rdd, schema);
DataFrame dataset = sqlContext.createDataFrame(Arrays.asList(row), schema);
VectorAssembler assembler = new VectorAssembler()
.setInputCols(new String[] {"x", "y", "z", "n"})
.setOutputCol("features");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.ml.feature;

import java.util.Arrays;
import java.util.List;

import org.junit.After;
import org.junit.Assert;
Expand Down Expand Up @@ -63,12 +64,12 @@ public void vectorSlice() {
};
AttributeGroup group = new AttributeGroup("userFeatures", attrs);

JavaRDD<Row> jrdd = jsc.parallelize(Arrays.asList(
List<Row> data = Arrays.asList(
RowFactory.create(Vectors.sparse(3, new int[]{0, 1}, new double[]{-2.0, 2.3})),
RowFactory.create(Vectors.dense(-2.0, 2.3, 0.0))
));
);

DataFrame dataset = jsql.createDataFrame(jrdd, (new StructType()).add(group.toStructField()));
DataFrame dataset = jsql.createDataFrame(data, (new StructType()).add(group.toStructField()));

VectorSlicer vectorSlicer = new VectorSlicer()
.setInputCol("userFeatures").setOutputCol("features");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,15 @@ public void tearDown() {

@Test
public void testJavaWord2Vec() {
JavaRDD<Row> jrdd = jsc.parallelize(Arrays.asList(
RowFactory.create(Arrays.asList("Hi I heard about Spark".split(" "))),
RowFactory.create(Arrays.asList("I wish Java could use case classes".split(" "))),
RowFactory.create(Arrays.asList("Logistic regression models are neat".split(" ")))
));
StructType schema = new StructType(new StructField[]{
new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty())
});
DataFrame documentDF = sqlContext.createDataFrame(jrdd, schema);
DataFrame documentDF = sqlContext.createDataFrame(
Arrays.asList(
RowFactory.create(Arrays.asList("Hi I heard about Spark".split(" "))),
RowFactory.create(Arrays.asList("I wish Java could use case classes".split(" "))),
RowFactory.create(Arrays.asList("Logistic regression models are neat".split(" ")))),
schema);

Word2Vec word2Vec = new Word2Vec()
.setInputCol("text")
Expand Down