Skip to content

Commit 0b10a2e

Browse files
committed
Make more use of the partitioned data in ProgrammersCommentTest
1 parent f127009 commit 0b10a2e

File tree

1 file changed

+13
-5
lines changed

1 file changed

+13
-5
lines changed

src/test/groovy/net/zomis/machlearn/text/duga/ProgrammersCommentTest.java

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,20 +74,28 @@ public void commentLearning() {
7474

7575
PartitionedDataSet partitionedData = data.partition(0.6, 0.2, 0.2, new Random(42));
7676
LearningDataSet trainingSet = partitionedData.getTrainingSet();
77+
LearningDataSet crossValidSet = partitionedData.getCrossValidationSet();
78+
LearningDataSet testSet = partitionedData.getTestSet();
7779

7880
double[] learnedTheta = GradientDescent.gradientDescent(
7981
LogisticRegression.costFunction(trainingSet.getXs(), trainingSet.getY()),
8082
new ConvergenceIterations(20000),
8183
new double[data.numFeaturesWithZero()], 0.01);
8284

83-
double cost = LogisticRegression.costFunction(trainingSet.getXs(), trainingSet.getY()).apply(learnedTheta);
84-
System.out.println("Cost: " + cost);
85+
double cost = LogisticRegression.costFunction(trainingSet.getXs(), trainingSet.getY())
86+
.apply(learnedTheta);
87+
System.out.println("Training Set Cost: " + cost);
88+
89+
double crossCost = LogisticRegression.costFunction(crossValidSet.getXs(), crossValidSet.getY())
90+
.apply(learnedTheta);
91+
System.out.println("CrossValidation Cost: " + crossCost);
8592

8693
ClassifierFunction function = (theta, x) ->
8794
LogisticRegression.hypothesis(theta, x) >= 0.3;
88-
89-
PrecisionRecallF1 score = data.precisionRecallF1(learnedTheta, function);
90-
System.out.println(score);
95+
System.out.println("ALL Score: " + data.precisionRecallF1(learnedTheta, function));
96+
System.out.println("Training Score: " + data.precisionRecallF1(learnedTheta, function));
97+
System.out.println("CrossVal Score: " + crossValidSet.precisionRecallF1(learnedTheta, function));
98+
System.out.println("TestSet Score: " + testSet.precisionRecallF1(learnedTheta, function));
9199

92100
System.out.println("False negatives:");
93101
data.stream()

0 commit comments

Comments
 (0)