Skip to content

Commit 916012d

Browse files
committed
Added more training sets in MinesweeperScan to try to get better results
1 parent 74f1fdd commit 916012d

File tree

1 file changed

+21
-10
lines changed

1 file changed

+21
-10
lines changed

src/test/groovy/net/zomis/machlearn/images/MinesweeperScan.java

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -48,16 +48,27 @@ private static char[][] scanGrid(BufferedImage runImage, ZRect[][] gridLocations
4848
String fileName = "challenge-flags-16x16.png";
4949
BufferedImage image = MyImageUtil.resource(fileName);
5050
ImageAnalysis analyze = new ImageAnalysis(36, 36, true);
51-
ImageNetwork network = analyze.neuralNetwork(40)
52-
.classify('_', analyze.imagePart(image, 622, 200))
53-
.classify('1', analyze.imagePart(image, 793, 287))
54-
.classify('2', analyze.imagePart(image, 665, 200))
55-
.classify('3', analyze.imagePart(image, 793, 244))
56-
.classify('4', analyze.imagePart(image, 750, 416))
57-
.classify('5', analyze.imagePart(image, 664, 502))
58-
.classify('6', analyze.imagePart(image, 707, 502))
59-
.classify('a', analyze.imagePart(image, 793, 200))
60-
.classifyNone(analyze.imagePart(image, 0, 0))
51+
Map<Character, ZPoint> trainingSet = new HashMap<>();
52+
trainingSet.put('_', new ZPoint(622, 200));
53+
trainingSet.put('1', new ZPoint(793, 287));
54+
trainingSet.put('2', new ZPoint(665, 200));
55+
trainingSet.put('3', new ZPoint(793, 244));
56+
trainingSet.put('4', new ZPoint(750, 416));
57+
trainingSet.put('5', new ZPoint(664, 502));
58+
trainingSet.put('6', new ZPoint(707, 502));
59+
trainingSet.put('a', new ZPoint(793, 200));
60+
61+
ImageNetworkBuilder networkBuilder = analyze.neuralNetwork(40);
62+
for (Map.Entry<Character, ZPoint> ee : trainingSet.entrySet()) {
63+
int yy = ee.getValue().getY();
64+
int xx = ee.getValue().getX();
65+
for (int y = 4; y <= 4; y += 2) {
66+
for (int x = 4; x <= 4; x += 2) {
67+
networkBuilder = networkBuilder.classify(ee.getKey(), analyze.imagePart(image, xx + x, yy + y));
68+
}
69+
}
70+
}
71+
ImageNetwork network = networkBuilder.classifyNone(analyze.imagePart(image, 0, 0))
6172
.classifyNone(analyze.imagePart(image, 878, 456))
6273
.classifyNone(analyze.imagePart(image, 903, 456))
6374
.classifyNone(analyze.imagePart(image, 948, 456))

0 commit comments

Comments
 (0)