Skip to content

Commit f127009

Browse files
committed
Make use of PartitionedDataSet in ProgrammersCommentTest
1 parent a83abe7 commit f127009

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

src/test/groovy/net/zomis/machlearn/text/duga/ProgrammersCommentTest.java

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import net.zomis.machlearn.common.ClassifierFunction;
44
import net.zomis.machlearn.common.LearningDataSet;
5+
import net.zomis.machlearn.common.PartitionedDataSet;
56
import net.zomis.machlearn.common.PrecisionRecallF1;
67
import net.zomis.machlearn.images.MyGroovyUtils;
78
import net.zomis.machlearn.neural.LearningData;
@@ -13,10 +14,7 @@
1314
import net.zomis.machlearn.text.TextFeatureMapper;
1415
import org.junit.Test;
1516

16-
import java.util.ArrayList;
17-
import java.util.Arrays;
18-
import java.util.Comparator;
19-
import java.util.List;
17+
import java.util.*;
2018
import java.util.regex.Pattern;
2119
import java.util.stream.Collectors;
2220

@@ -74,11 +72,15 @@ public void commentLearning() {
7472
System.out.println("Data is:");
7573
data.getData().stream().forEach(System.out::println);
7674

75+
PartitionedDataSet partitionedData = data.partition(0.6, 0.2, 0.2, new Random(42));
76+
LearningDataSet trainingSet = partitionedData.getTrainingSet();
77+
7778
double[] learnedTheta = GradientDescent.gradientDescent(
78-
LogisticRegression.costFunction(data.getXs(), data.getY()),
79+
LogisticRegression.costFunction(trainingSet.getXs(), trainingSet.getY()),
7980
new ConvergenceIterations(20000),
8081
new double[data.numFeaturesWithZero()], 0.01);
81-
double cost = LogisticRegression.costFunction(data.getXs(), data.getY()).apply(learnedTheta);
82+
83+
double cost = LogisticRegression.costFunction(trainingSet.getXs(), trainingSet.getY()).apply(learnedTheta);
8284
System.out.println("Cost: " + cost);
8385

8486
ClassifierFunction function = (theta, x) ->

0 commit comments

Comments
 (0)