Skip to content

Commit a83abe7

Browse files
committed
Make PartionedDataSet store LearningDataSets to make use of the LearningDataSet APIs
1 parent 6dd975f commit a83abe7

File tree

2 files changed

+21
-14
lines changed

2 files changed

+21
-14
lines changed

src/main/java/net/zomis/machlearn/common/LearningDataSet.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,15 @@
1212

1313
public class LearningDataSet {
1414

15-
private final List<LearningData> data = new ArrayList<>();
15+
private final List<LearningData> data;
16+
17+
public LearningDataSet() {
18+
this(new ArrayList<>());
19+
}
20+
21+
public LearningDataSet(List<LearningData> data) {
22+
this.data = new ArrayList<>(data);
23+
}
1624

1725
public void add(Object representation, double[] x, double y) {
1826
this.add(representation, x, new double[]{y});

src/main/java/net/zomis/machlearn/common/PartitionedDataSet.java

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,35 +2,34 @@
22

33
import net.zomis.machlearn.neural.LearningData;
44

5-
import java.util.ArrayList;
65
import java.util.List;
76

87
/**
98
* Represents a LearningDataSet that has been partitioned into training set, cross-validation set, and/or test-set.
109
*/
1110
public class PartitionedDataSet {
1211

13-
private final List<LearningData> trainingSet;
14-
private final List<LearningData> crossValidationSet;
15-
private final List<LearningData> testSet;
12+
private final LearningDataSet trainingSet;
13+
private final LearningDataSet crossValidationSet;
14+
private final LearningDataSet testSet;
1615

1716
public PartitionedDataSet(List<LearningData> trainingSet, List<LearningData> crossValidationSet,
1817
List<LearningData> testSet) {
19-
this.trainingSet = new ArrayList<>(trainingSet);
20-
this.crossValidationSet = new ArrayList<>(crossValidationSet);
21-
this.testSet = new ArrayList<>(testSet);
18+
this.trainingSet = new LearningDataSet(trainingSet);
19+
this.crossValidationSet = new LearningDataSet(crossValidationSet);
20+
this.testSet = new LearningDataSet(testSet);
2221
}
2322

24-
public List<LearningData> getCrossValidationSet() {
25-
return new ArrayList<>(crossValidationSet);
23+
public LearningDataSet getCrossValidationSet() {
24+
return crossValidationSet;
2625
}
2726

28-
public List<LearningData> getTestSet() {
29-
return new ArrayList<>(testSet);
27+
public LearningDataSet getTestSet() {
28+
return testSet;
3029
}
3130

32-
public List<LearningData> getTrainingSet() {
33-
return new ArrayList<>(trainingSet);
31+
public LearningDataSet getTrainingSet() {
32+
return trainingSet;
3433
}
3534

3635
}

0 commit comments

Comments
 (0)