|
3 | 3 | import net.zomis.machlearn.neural.LearningData; |
4 | 4 |
|
5 | 5 | import java.util.ArrayList; |
| 6 | +import java.util.Collections; |
6 | 7 | import java.util.List; |
| 8 | +import java.util.Random; |
7 | 9 | import java.util.function.Predicate; |
8 | 10 | import java.util.stream.Collectors; |
9 | 11 | import java.util.stream.Stream; |
@@ -54,4 +56,22 @@ public Stream<LearningData> stream() { |
54 | 56 | return data.stream(); |
55 | 57 | } |
56 | 58 |
|
| 59 | + public PartitionedDataSet partition(double trainingSetRatio, |
| 60 | + double crossValidationSetRatio, double testSetRatio, Random random) { |
| 61 | + List<LearningData> shuffledData = new ArrayList<>(this.data); |
| 62 | + Collections.shuffle(shuffledData, random); |
| 63 | + // Calculate the sum to support ratios like 0.1, 0.2, 0.3 |
| 64 | + double sum = trainingSetRatio + crossValidationSetRatio + testSetRatio; |
| 65 | + int size = shuffledData.size(); |
| 66 | + int indexSplit1 = (int) (trainingSetRatio / sum * size); |
| 67 | + int indexSplit2 = indexSplit1 + (int) (crossValidationSetRatio / sum * size); |
| 68 | + List<LearningData> trainingSet = |
| 69 | + new ArrayList<>(shuffledData.subList(0, indexSplit1)); |
| 70 | + List<LearningData> crossValidationSet = |
| 71 | + new ArrayList<>(shuffledData.subList(indexSplit1, indexSplit2)); |
| 72 | + List<LearningData> testSet = |
| 73 | + new ArrayList<>(shuffledData.subList(indexSplit2, size)); |
| 74 | + return new PartitionedDataSet(trainingSet, crossValidationSet, testSet); |
| 75 | + } |
| 76 | + |
57 | 77 | } |
0 commit comments