Skip to content

Commit

Permalink
Add LearningDataSet class
Browse files Browse the repository at this point in the history
  • Loading branch information
Zomis committed Feb 21, 2016
1 parent 2510026 commit 5d6a279
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package net.zomis.machlearn.common;

public interface ClassifierFunction {

boolean classify(double[] theta, double[] x);

}
51 changes: 51 additions & 0 deletions src/main/java/net/zomis/machlearn/common/LearningDataSet.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package net.zomis.machlearn.common;

import net.zomis.machlearn.neural.LearningData;
import net.zomis.machlearn.text.duga.PrecisionRecallF1;

import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;

public class LearningDataSet {

private final List<LearningData> data = new ArrayList<>();

public void add(Object representation, double[] x, double y) {
this.add(representation, x, new double[]{y});
}

public void add(Object representation, double[] x, double[] y) {
this.data.add(new LearningData(representation, x, y));
}

public double[][] getXs() {
return data.stream()
.map(LearningData::getInputs)
.collect(Collectors.toList()).toArray(new double[data.size()][]);
}

public double[] getY() {
return data.stream()
.map(LearningData::getOutputs)
.mapToDouble(d -> d[0]).toArray();
}

public int numFeaturesWithZero() {
return data.get(0).getInputs().length + 1;
}

public PrecisionRecallF1 precisionRecallF1(double[] theta, ClassifierFunction hypothesis) {
PrecisionRecallF1 score = new PrecisionRecallF1();
for (LearningData ld : data) {
boolean prediction = hypothesis.classify(theta, ld.getInputs());
boolean actual = ld.getOutputs()[0] >= 0.5;
score.add(actual, prediction);
}
return score;
}

public List<LearningData> getData() {
return data;
}
}

0 comments on commit 5d6a279

Please sign in to comment.