Skip to content

Commit 1fa87fd

Browse files
committed
Added detection of all the squares in the grid in MinesweeperScan
1 parent 048f875 commit 1fa87fd

File tree

1 file changed

+247
-3
lines changed

1 file changed

+247
-3
lines changed

src/test/groovy/net/zomis/machlearn/images/MinesweeperScan.java

Lines changed: 247 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,21 @@
11
package net.zomis.machlearn.images;
22

33
import net.zomis.machlearn.neural.Backpropagation;
4+
import org.imgscalr.Scalr;
45

56
import java.awt.*;
67
import java.awt.image.BufferedImage;
7-
import java.util.Random;
8+
import java.io.File;
9+
import java.util.*;
10+
import java.util.List;
811

912
public class MinesweeperScan {
1013

14+
private static String LEARN_IMAGE = "challenge-flags-16x16.png";
15+
private static BufferedImage img = ImageUtil.resource(LEARN_IMAGE);
16+
1117
public static void scan() {
1218
ImageAnalysis analysis = new ImageAnalysis(1, 100, true);
13-
String fileName = "challenge-flags-16x16.png";
14-
BufferedImage img = ImageUtil.resource(fileName);
1519
ImageNetwork network = analysis.neuralNetwork(40)
1620
.classifyNone(analysis.imagePart(img, 0, 540))
1721
.classifyNone(analysis.imagePart(img, 100, 540))
@@ -29,6 +33,246 @@ public static void scan() {
2933
BufferedImage runImage = ImageUtil.resource("challenge-press-26x14.png");
3034
ZRect rect = findEdges(network, analysis, runImage);
3135
System.out.println("Edges: " + rect);
36+
// also try find separations by scanning lines and finding the line with the lowest delta diff
37+
38+
ZRect[][] gridLocations = findGrid(runImage, rect);
39+
char[][] gridValues = scanGrid(runImage, gridLocations);
40+
for (int y = 0; y < gridValues.length; y++) {
41+
for (int x = 0; x < gridValues[y].length; x++) {
42+
System.out.print(gridValues[y][x]);
43+
}
44+
System.out.println();
45+
}
46+
}
47+
48+
private static char[][] scanGrid(BufferedImage runImage, ZRect[][] gridLocations) {
49+
String fileName = "challenge-flags-16x16.png";
50+
// String fileName = "different-colors.png";
51+
BufferedImage image = ImageUtil.resource(fileName);
52+
ImageAnalysis analyze = new ImageAnalysis(36, 36, true);
53+
ImageNetwork network = analyze.neuralNetwork(40)
54+
.classify('_', analyze.imagePart(image, 622, 200))
55+
.classify('1', analyze.imagePart(image, 793, 287))
56+
.classify('2', analyze.imagePart(image, 665, 200))
57+
.classify('3', analyze.imagePart(image, 793, 244))
58+
.classify('4', analyze.imagePart(image, 750, 416))
59+
.classify('5', analyze.imagePart(image, 664, 502))
60+
.classify('6', analyze.imagePart(image, 707, 502))
61+
.classify('a', analyze.imagePart(image, 793, 200))
62+
.classifyNone(analyze.imagePart(image, 0, 0))
63+
.classifyNone(analyze.imagePart(image, 878, 456))
64+
.classifyNone(analyze.imagePart(image, 903, 456))
65+
.classifyNone(analyze.imagePart(image, 948, 456))
66+
.classifyNone(analyze.imagePart(image, 1004, 558))
67+
.classifyNone(analyze.imagePart(image, 921, 496))
68+
.classifyNone(analyze.imagePart(image, 921, 536))
69+
.classifyNone(analyze.imagePart(image, 963, 536))
70+
.learn(new Backpropagation(0.1, 4000), new Random(42));
71+
72+
char[][] result = new char[gridLocations.length][gridLocations[0].length];
73+
for (int y = 0; y < gridLocations.length; y++) {
74+
for (int x = 0; x < gridLocations[y].length; x++) {
75+
ZRect rect = gridLocations[y][x];
76+
Map<Object, Double> output = scanSquare(analyze, network, runImage, rect);
77+
char ch = charForOutput(output);
78+
result[y][x] = ch;
79+
}
80+
}
81+
return result;
82+
}
83+
84+
private static char charForOutput(Map<Object, Double> output) {
85+
if (output == null) {
86+
return '%';
87+
}
88+
Map.Entry<Object, Double> max = output.entrySet().stream()
89+
.max(Comparator.comparingDouble(e -> e.getValue())).get();
90+
if (max.getValue() < 0.5) {
91+
return '#';
92+
}
93+
return (Character) max.getKey();
94+
}
95+
96+
private static Map<Object, Double> scanSquare(ImageAnalysis analyze, ImageNetwork network, BufferedImage runImage, ZRect rect) {
97+
if (rect == null) {
98+
return null;
99+
}
100+
int min = Math.min(analyze.getWidth(), analyze.getHeight());
101+
int minRect = Math.min(rect.width(), rect.height());
102+
BufferedImage image = Scalr.crop(runImage, rect.left, rect.top, minRect, minRect);
103+
BufferedImage run = Scalr.resize(image, min, min);
104+
System.out.printf("Running on %s with target size %d, %d run image is %d, %d%n", rect,
105+
analyze.getWidth(), analyze.getHeight(), run.getWidth(), run.getHeight());
106+
return network.run(analyze.imagePart(run, 0, 0));
107+
}
108+
109+
private static ZRect[][] findGrid(BufferedImage runImage, ZRect rect) {
110+
// Classify the line separator as true
111+
ImageAnalysis horizontalAnalysis = new ImageAnalysis(50, 2, true);
112+
ImageNetwork horizontal = horizontalAnalysis.neuralNetwork(20)
113+
.classify(true, horizontalAnalysis.imagePart(img, 600, 235))
114+
.classify(true, horizontalAnalysis.imagePart(img, 700, 235))
115+
.classifyNone(horizontalAnalysis.imagePart(img, 600, 249))
116+
.classifyNone(horizontalAnalysis.imagePart(img, 664, 249))
117+
.learn(new Backpropagation(0.1, 10000), new Random(42));
118+
119+
ImageAnalysis verticalAnalysis = new ImageAnalysis(2, 50, true);
120+
ImageNetwork vertical = verticalAnalysis.neuralNetwork(20)
121+
.classify(true, verticalAnalysis.imagePart(img, 700, 300))
122+
.classify(true, verticalAnalysis.imagePart(img, 700, 400))
123+
.classifyNone(verticalAnalysis.imagePart(img, 682, 279))
124+
.classifyNone(verticalAnalysis.imagePart(img, 765, 279))
125+
.classifyNone(verticalAnalysis.imagePart(img, 630, 249))
126+
.classifyNone(verticalAnalysis.imagePart(img, 795, 290))
127+
.classifyNone(verticalAnalysis.imagePart(img, 795, 365))
128+
.classifyNone(verticalAnalysis.imagePart(img, 795, 465))
129+
.classifyNone(verticalAnalysis.imagePart(img, 722, 497))
130+
.classifyNone(verticalAnalysis.imagePart(img, 770, 249))
131+
.classifyNone(verticalAnalysis.imagePart(img, 719, 497))
132+
.learn(new Backpropagation(0.1, 10000), new Random(42));
133+
134+
List<Integer> horizontalLines = new ArrayList<>();
135+
for (int y = rect.top; y + horizontalAnalysis.getHeight() < rect.bottom; y++) {
136+
double[] input = horizontalAnalysis.imagePart(runImage, rect.left + 10, y);
137+
double[] output = horizontal.getNetwork().run(input);
138+
double result = output[0];
139+
if (result > 0.7) {
140+
horizontalLines.add(y);
141+
}
142+
}
143+
144+
List<Integer> verticalLines = new ArrayList<>();
145+
for (int x = rect.left; x + verticalAnalysis.getWidth() < rect.right; x++) {
146+
double[] input = verticalAnalysis.imagePart(runImage, x, rect.top + 10);
147+
double[] output = vertical.getNetwork().run(input);
148+
double result = output[0];
149+
if (result > 0.7) {
150+
verticalLines.add(x);
151+
}
152+
}
153+
154+
// runAndSave(verticalAnalysis, vertical, runImage);
155+
156+
System.out.println("Edges: " + rect);
157+
System.out.println("Horizontal: " + horizontalLines);
158+
System.out.println("Vertical : " + verticalLines);
159+
160+
horizontalLines = removeCloseValues(horizontalLines, 15);
161+
verticalLines = removeCloseValues(verticalLines, 15);
162+
System.out.println("------------");
163+
System.out.println("Horizontal " + horizontalLines.size() + ": " + horizontalLines);
164+
System.out.println("Vertical " + verticalLines.size() + ": " + verticalLines);
165+
166+
// Remove outliers
167+
int squareWidth = verticalLines.get(1) - verticalLines.get(0);
168+
int squareHeight = horizontalLines.get(1) - horizontalLines.get(0);
169+
verticalLines = removeCloseValues(verticalLines, (int) (squareWidth * 0.75));
170+
horizontalLines = removeCloseValues(horizontalLines, (int) (squareHeight * 0.75));
171+
172+
System.out.println("------------");
173+
System.out.println("Horizontal " + horizontalLines.size() + ": " + horizontalLines);
174+
System.out.println("Vertical " + verticalLines.size() + ": " + verticalLines);
175+
176+
ZRect[][] gridLocations = grabRects(runImage, rect, horizontalLines, verticalLines, squareWidth, squareHeight);
177+
System.out.println("Square size = " + squareWidth + " x " + squareHeight);
178+
System.out.println("Squares found: " + gridLocations[0].length + " x " + gridLocations.length);
179+
return gridLocations;
180+
}
181+
182+
private static ZRect[][] grabRects(BufferedImage image, ZRect rect, List<Integer> horizontalLines, List<Integer> verticalLines,
183+
int squareWidth, int squareHeight) {
184+
horizontalLines = new ArrayList<>(horizontalLines);
185+
verticalLines = new ArrayList<>(verticalLines);
186+
horizontalLines.add(rect.top);
187+
horizontalLines.add(rect.bottom);
188+
verticalLines.add(rect.left);
189+
verticalLines.add(rect.right);
190+
Collections.sort(horizontalLines);
191+
Collections.sort(verticalLines);
192+
193+
horizontalLines = removeCloseValues(horizontalLines, (int) (squareHeight * 0.75));
194+
verticalLines = removeCloseValues(verticalLines, (int) (squareWidth * 0.75));
195+
196+
// int beforeFirstX = verticalLines.get(0) - rect.left;
197+
// int afterLastX = rect.right - verticalLines.get(verticalLines.size() - 1);
198+
// int beforeFirstY = horizontalLines.get(0) - rect.top;
199+
// int afterLastY = rect.bottom - horizontalLines.get(horizontalLines.size() - 1);
200+
201+
System.out.println("Horizontal " + horizontalLines.size() + ": " + horizontalLines);
202+
System.out.println("Vertical " + verticalLines.size() + ": " + verticalLines);
203+
204+
ZRect[][] results = new ZRect[horizontalLines.size() + 1][verticalLines.size() + 1];
205+
int x = 0;
206+
for (Integer left : verticalLines) {
207+
int y = 0;
208+
for (Integer top : horizontalLines) {
209+
ZRect r = new ZRect();
210+
r.left = left;
211+
r.top = top;
212+
r.right = left + squareWidth;
213+
r.bottom = top + squareHeight;
214+
if (r.right >= image.getWidth()) {
215+
continue;
216+
}
217+
if (r.bottom >= image.getHeight()) {
218+
continue;
219+
}
220+
results[y][x] = r;
221+
y++;
222+
}
223+
x++;
224+
}
225+
226+
return results;
227+
}
228+
229+
private static List<Integer> removeCloseValues(List<Integer> values, int closeRange) {
230+
List<Integer> result = new ArrayList<>();
231+
Integer last = null;
232+
for (Integer i : values) {
233+
if (last == null || last + closeRange < i) {
234+
last = i;
235+
result.add(i);
236+
}
237+
}
238+
return result;
239+
}
240+
241+
private static void runAndSave(ImageAnalysis analysis, ImageNetwork network, BufferedImage image) {
242+
BufferedImage[] networkResult = runOnImage(analysis, network, image);
243+
for (int i = 0; i < networkResult.length; i++) {
244+
ImageUtil.save(networkResult[i], new File("network-result-" + i + ".png"));
245+
}
246+
}
247+
248+
private static BufferedImage[] runOnImage(ImageAnalysis analysis, ImageNetwork network, BufferedImage runImage) {
249+
int maxY = runImage.getHeight() - analysis.getHeight();
250+
int maxX = runImage.getWidth() - analysis.getWidth();
251+
BufferedImage[] images = new BufferedImage[network.getNetwork().getOutputLayer().size()];
252+
for (int i = 0; i < images.length; i++) {
253+
images[i] = new BufferedImage(runImage.getWidth(), runImage.getHeight(), BufferedImage.TYPE_INT_ARGB);
254+
Graphics2D graphics = images[i].createGraphics();
255+
graphics.setColor(Color.MAGENTA);
256+
graphics.fillRect(0, 0, runImage.getWidth(), runImage.getHeight());
257+
}
258+
259+
for (int y = 0; y < maxY; y++) {
260+
if (y % 20 == 0) {
261+
System.out.println("process y " + y);
262+
}
263+
for (int x = 0; x < maxX; x++) {
264+
double[] input = analysis.imagePart(runImage, x, y);
265+
double[] output = network.getNetwork().run(input);
266+
for (int i = 0; i < output.length; i++) {
267+
double value = output[i];
268+
int grayscaleValue = (int) (value * 255);
269+
// System.out.println(x + ", " + y + ": " + grayscaleValue + " -- " + value);
270+
int rgb = 0xff << 24 | grayscaleValue << 16 | grayscaleValue << 8 | grayscaleValue;
271+
images[i].setRGB(x, y, rgb);
272+
}
273+
}
274+
}
275+
return images;
32276
}
33277

34278
private static ZRect findEdges(ImageNetwork network, ImageAnalysis analysis, BufferedImage runImage) {

0 commit comments

Comments
 (0)