@@ -74,20 +74,28 @@ public void commentLearning() {
 
     PartitionedDataSet partitionedData = data.partition(0.6, 0.2, 0.2, new Random(42));
     LearningDataSet trainingSet = partitionedData.getTrainingSet();
+    LearningDataSet crossValidSet = partitionedData.getCrossValidationSet();
+    LearningDataSet testSet = partitionedData.getTestSet();
 
     double[] learnedTheta = GradientDescent.gradientDescent(
         LogisticRegression.costFunction(trainingSet.getXs(), trainingSet.getY()),
         new ConvergenceIterations(20000),
         new double[data.numFeaturesWithZero()], 0.01);
 
-    double cost = LogisticRegression.costFunction(trainingSet.getXs(), trainingSet.getY()).apply(learnedTheta);
-    System.out.println("Cost: " + cost);
+    double cost = LogisticRegression.costFunction(trainingSet.getXs(), trainingSet.getY())
+        .apply(learnedTheta);
+    System.out.println("Training Set Cost: " + cost);
+
+    double crossCost = LogisticRegression.costFunction(crossValidSet.getXs(), crossValidSet.getY())
+        .apply(learnedTheta);
+    System.out.println("CrossValidation Cost: " + crossCost);
 
     ClassifierFunction function = (theta, x) ->
         LogisticRegression.hypothesis(theta, x) >= 0.3;
-
-    PrecisionRecallF1 score = data.precisionRecallF1(learnedTheta, function);
-    System.out.println(score);
+    System.out.println("ALL Score: " + data.precisionRecallF1(learnedTheta, function));
+    System.out.println("Training Score: " + trainingSet.precisionRecallF1(learnedTheta, function));
+    System.out.println("CrossVal Score: " + crossValidSet.precisionRecallF1(learnedTheta, function));
+    System.out.println("TestSet Score: " + testSet.precisionRecallF1(learnedTheta, function));
 
     System.out.println("False negatives:");
     data.stream()
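
Note: the new ClassifierFunction lowers the decision threshold from the usual 0.5 to 0.3, trading precision for recall. If that value needs tuning, one option is to sweep candidate thresholds and keep the one that scores best on the cross-validation set, leaving the test set untouched until the final check. The sketch below would continue the method body above; it assumes the ClassifierFunction, LearningDataSet, and precisionRecallF1 API shown in the diff, and getF1() is a hypothetical accessor on PrecisionRecallF1, not confirmed by this commit.

// Sketch: pick the decision threshold on the cross-validation set.
// Assumes the API shown in the diff above; getF1() is a hypothetical
// accessor on PrecisionRecallF1 and may be named differently.
double bestThreshold = 0.5;
double bestF1 = Double.NEGATIVE_INFINITY;
for (int i = 1; i <= 9; i++) {
    double threshold = i / 10.0;  // candidate thresholds 0.1 .. 0.9
    ClassifierFunction candidate = (theta, x) ->
        LogisticRegression.hypothesis(theta, x) >= threshold;
    double f1 = crossValidSet.precisionRecallF1(learnedTheta, candidate).getF1();
    if (f1 > bestF1) {  // keep the best-scoring threshold
        bestF1 = f1;
        bestThreshold = threshold;
    }
}
System.out.println("Best threshold on cross-validation: " + bestThreshold);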