Skip to content

Commit 6dd975f

Browse files
committed
Add initial support for separating data into trainingset, cross-validationset and testset
1 parent 11cfe8a commit 6dd975f

File tree

2 files changed

+56
-0
lines changed

2 files changed

+56
-0
lines changed

src/main/java/net/zomis/machlearn/common/LearningDataSet.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
import net.zomis.machlearn.neural.LearningData;
44

55
import java.util.ArrayList;
6+
import java.util.Collections;
67
import java.util.List;
8+
import java.util.Random;
79
import java.util.function.Predicate;
810
import java.util.stream.Collectors;
911
import java.util.stream.Stream;
@@ -54,4 +56,22 @@ public Stream<LearningData> stream() {
5456
return data.stream();
5557
}
5658

59+
public PartitionedDataSet partition(double trainingSetRatio,
60+
double crossValidationSetRatio, double testSetRatio, Random random) {
61+
List<LearningData> shuffledData = new ArrayList<>(this.data);
62+
Collections.shuffle(shuffledData, random);
63+
// Calculate the sum to support ratios like 0.1, 0.2, 0.3
64+
double sum = trainingSetRatio + crossValidationSetRatio + testSetRatio;
65+
int size = shuffledData.size();
66+
int indexSplit1 = (int) (trainingSetRatio / sum * size);
67+
int indexSplit2 = indexSplit1 + (int) (crossValidationSetRatio / sum * size);
68+
List<LearningData> trainingSet =
69+
new ArrayList<>(shuffledData.subList(0, indexSplit1));
70+
List<LearningData> crossValidationSet =
71+
new ArrayList<>(shuffledData.subList(indexSplit1, indexSplit2));
72+
List<LearningData> testSet =
73+
new ArrayList<>(shuffledData.subList(indexSplit2, size));
74+
return new PartitionedDataSet(trainingSet, crossValidationSet, testSet);
75+
}
76+
5777
}
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
package net.zomis.machlearn.common;
2+
3+
import net.zomis.machlearn.neural.LearningData;
4+
5+
import java.util.ArrayList;
6+
import java.util.List;
7+
8+
/**
9+
* Represents a LearningDataSet that has been partitioned into training set, cross-validation set, and/or test-set.
10+
*/
11+
public class PartitionedDataSet {
12+
13+
private final List<LearningData> trainingSet;
14+
private final List<LearningData> crossValidationSet;
15+
private final List<LearningData> testSet;
16+
17+
public PartitionedDataSet(List<LearningData> trainingSet, List<LearningData> crossValidationSet,
18+
List<LearningData> testSet) {
19+
this.trainingSet = new ArrayList<>(trainingSet);
20+
this.crossValidationSet = new ArrayList<>(crossValidationSet);
21+
this.testSet = new ArrayList<>(testSet);
22+
}
23+
24+
public List<LearningData> getCrossValidationSet() {
25+
return new ArrayList<>(crossValidationSet);
26+
}
27+
28+
public List<LearningData> getTestSet() {
29+
return new ArrayList<>(testSet);
30+
}
31+
32+
public List<LearningData> getTrainingSet() {
33+
return new ArrayList<>(trainingSet);
34+
}
35+
36+
}

0 commit comments

Comments
 (0)