|
2 | 2 |
|
3 | 3 | import net.zomis.machlearn.common.ClassifierFunction; |
4 | 4 | import net.zomis.machlearn.common.LearningDataSet; |
| 5 | +import net.zomis.machlearn.common.PartitionedDataSet; |
5 | 6 | import net.zomis.machlearn.common.PrecisionRecallF1; |
6 | 7 | import net.zomis.machlearn.images.MyGroovyUtils; |
7 | 8 | import net.zomis.machlearn.neural.LearningData; |
|
13 | 14 | import net.zomis.machlearn.text.TextFeatureMapper; |
14 | 15 | import org.junit.Test; |
15 | 16 |
|
16 | | -import java.util.ArrayList; |
17 | | -import java.util.Arrays; |
18 | | -import java.util.Comparator; |
19 | | -import java.util.List; |
| 17 | +import java.util.*; |
20 | 18 | import java.util.regex.Pattern; |
21 | 19 | import java.util.stream.Collectors; |
22 | 20 |
|
@@ -74,11 +72,15 @@ public void commentLearning() { |
74 | 72 | System.out.println("Data is:"); |
75 | 73 | data.getData().stream().forEach(System.out::println); |
76 | 74 |
|
| 75 | + PartitionedDataSet partitionedData = data.partition(0.6, 0.2, 0.2, new Random(42)); |
| 76 | + LearningDataSet trainingSet = partitionedData.getTrainingSet(); |
| 77 | + |
77 | 78 | double[] learnedTheta = GradientDescent.gradientDescent( |
78 | | - LogisticRegression.costFunction(data.getXs(), data.getY()), |
| 79 | + LogisticRegression.costFunction(trainingSet.getXs(), trainingSet.getY()), |
79 | 80 | new ConvergenceIterations(20000), |
80 | 81 | new double[data.numFeaturesWithZero()], 0.01); |
81 | | - double cost = LogisticRegression.costFunction(data.getXs(), data.getY()).apply(learnedTheta); |
| 82 | + |
| 83 | + double cost = LogisticRegression.costFunction(trainingSet.getXs(), trainingSet.getY()).apply(learnedTheta); |
82 | 84 | System.out.println("Cost: " + cost); |
83 | 85 |
|
84 | 86 | ClassifierFunction function = (theta, x) -> |
|
0 commit comments