Skip to content

Commit

Permalink
Add ProgrammersCommentTest
Browse files: browse the repository at this point in the history
  • Loading branch information
Zomis committed Feb 21, 2016
1 parent 5d6a279 commit 1ce9b18
Showing 1 changed file with 39 additions and 1 deletion.
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,18 @@
package net.zomis.machlearn.text.duga;

import net.zomis.machlearn.common.LearningDataSet;
import net.zomis.machlearn.common.PrecisionRecallF1;
import net.zomis.machlearn.images.MyGroovyUtils;
import net.zomis.machlearn.regression.ConvergenceIterations;
import net.zomis.machlearn.regression.GradientDescent;
import net.zomis.machlearn.regression.LogisticRegression;
import net.zomis.machlearn.text.BagOfWords;
import net.zomis.machlearn.text.TextFeatureBuilder;
import net.zomis.machlearn.text.TextFeatureMapper;
import org.junit.Test;

import java.util.stream.Collectors;
import java.util.ArrayList;
import java.util.List;

public class ProgrammersCommentTest {

Expand All @@ -16,17 +24,43 @@ public void commentLearning() {
BagOfWords bowYes = new BagOfWords();
BagOfWords bowNo = new BagOfWords();
BagOfWords bowAll = new BagOfWords();
TextFeatureBuilder textFeatures = new TextFeatureBuilder();

LearningDataSet data = new LearningDataSet();
List<String> processedStrings = new ArrayList<>();
for (String str : lines) {
if (!str.startsWith("0 ") && !str.startsWith("1 ")) {
continue;
}
boolean expected = str.startsWith("1");
String text = str.substring(2);
String processed = preprocess(text);
char expectedChar = expected ? '1' : '0';
processedStrings.add(expectedChar + processed);
textFeatures.add(processed);
BagOfWords bow = expected ? bowYes : bowNo;
bow.addText(text);
bowAll.addText(text);
// println text
}

TextFeatureMapper mapper = textFeatures.mapper();

for (String str : processedStrings) {
boolean expectTrue = str.charAt(0) == '1';
data.add(str, mapper.toFeatures(str), expectTrue ? 1 : 0);
}

data.getData().stream().forEach(System.out::println);

double[] learnedTheta = GradientDescent.gradientDescent(
LogisticRegression.costFunction(data.getXs(), data.getY()),
new ConvergenceIterations(10000),
new double[data.numFeaturesWithZero()], 0.01);
PrecisionRecallF1 score = data.precisionRecallF1(learnedTheta, (theta, x) ->
LogisticRegression.hypothesis(theta, x) >= 0.5);
System.out.println(score);

System.out.println(bowAll.getData());
System.out.println("-------------");
System.out.println(bowYes.getData());
Expand All @@ -41,4 +75,8 @@ public void commentLearning() {
// println 'Count ' + stream.size()
}

/**
 * Lower-cases a comment's text so the result does not depend on letter case.
 *
 * <p>The conversion uses {@link java.util.Locale#ROOT} instead of the JVM
 * default locale. With the default-locale overload, a machine configured for
 * Turkish or Azeri would turn {@code "I"} into the dotless {@code "ı"}, so the
 * same input would produce different strings on different machines.
 *
 * @param text the raw comment text; must not be {@code null}
 * @return {@code text} converted to lower case with locale-independent rules
 * @throws NullPointerException if {@code text} is {@code null}
 */
private String preprocess(String text) {
    // Fully-qualified name keeps this method independent of the file's import list.
    return text.toLowerCase(java.util.Locale.ROOT);
}

}

0 comments on commit 1ce9b18

Please sign in to comment.