In [43]:
import glob
import re

import cv2
import numpy as np
import torch
from matplotlib import pyplot as plt
from PIL import Image
from torchvision import transforms as T

In [44]:
# Preprocessing for a single cell image (PIL): force one gray channel, resize to
# the 28x28 input size, convert to a float tensor in [0, 1] (ToTensor), then
# normalize with mean 0.5 / std 0.5, i.e. (x - 0.5) / 0.5, mapping values to [-1, 1].
Transform = T.Compose([
    T.Grayscale(),
    T.Resize((28, 28)),
    T.ToTensor(),
    T.Normalize(mean=(0.5,), std=(0.5,)),
])

In [45]:
import torch.nn as nn

class CNN(nn.Module):
    """Small convolutional classifier: one input channel in, 10 class logits out.

    Four identical stages (5x5 conv, stride 1, padding 2 -> ReLU -> 2x2 max-pool)
    widen the channels 1 -> 16 -> 32 -> 48 -> 64. For a 28x28 input the spatial
    size shrinks 28 -> 14 -> 7 -> 3 -> 1, leaving 64 features for the final
    linear layer.

    The attribute names (conv1..conv4, out) and the Sequential layout
    (conv at index 0) are part of the checkpoint format: state-dict keys such as
    ``conv1.0.weight`` depend on them, so they must not change.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = self._stage(1, 16)
        self.conv2 = self._stage(16, 32)
        self.conv3 = self._stage(32, 48)
        self.conv4 = self._stage(48, 64)
        # Fully connected head: 64 flattened features -> 10 class logits.
        self.out = nn.Linear(64, 10)

    @staticmethod
    def _stage(in_channels, out_channels):
        """Build one Conv2d(5x5, stride 1, pad 2) -> ReLU -> MaxPool2d(2) stage."""
        return nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=5,
                stride=1,
                padding=2,
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )

    def forward(self, x):
        """Run the network.

        Returns:
            A tuple ``(logits, features)``. ``logits`` has shape (N, 10).
            ``features`` is the flattened output of the last conv stage and is
            returned for visualization.
        """
        features = x
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4):
            features = stage(features)
        features = features.view(features.size(0), -1)
        return self.out(features), features

In [46]:
MODEL_PATH = 'model.pth'

model = CNN()
# map_location='cpu': a checkpoint saved from a GPU would otherwise try to
# restore its tensors onto CUDA and fail on a CPU-only machine. The model here
# lives on the CPU, so load_state_dict copies the weights there anyway.
# NOTE(review): torch.load unpickles the file. Only load trusted checkpoints
# (or pass weights_only=True if the installed torch supports it).
model.load_state_dict(torch.load(MODEL_PATH, map_location='cpu'))
# Inference only. CNN currently has no dropout/batch-norm, so this is a no-op
# today, but it keeps predictions correct if such layers are added later.
model.eval()
cells_path = 'cells'

In [47]:
# Classify every cell image under `cells_path` and lay the results out, in
# filename order, as rows of GRID_SIZE digits. Uses `model`, `Transform` and
# `cells_path` from the earlier cells.
GRID_SIZE = 9         # cells per board row
INK_THRESHOLD = 80    # with THRESH_BINARY_INV, gray levels <= this become 255 (ink)
MIN_INK_PIXELS = 20   # a cell with fewer ink pixels than this is treated as blank


def _natural_key(path):
    """Sort key that compares digit runs numerically, so 'cell2' sorts before 'cell10'."""
    return [int(tok) if tok.isdigit() else tok for tok in re.split(r'(\d+)', path)]


def _classify_cell(path):
    """Return the digit predicted for one cell image, or 0 if the cell looks blank.

    Raises:
        ValueError: if OpenCV cannot read the file (cv2.imread returns None).
            This replaces an opaque cvtColor error later on.
    """
    image = cv2.imread(path)
    if image is None:
        raise ValueError(f"could not read cell image: {path}")
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    _, binary = cv2.threshold(gray, INK_THRESHOLD, 255, cv2.THRESH_BINARY_INV)
    if cv2.countNonZero(binary) < MIN_INK_PIXELS:
        return 0  # too little ink: skip preprocessing and the model entirely
    # Transform yields a (1, 28, 28) tensor; add the batch dim -> (1, 1, 28, 28).
    batch = Transform(Image.fromarray(binary)).unsqueeze(0)
    with torch.no_grad():
        logits, _ = model(batch)
    return logits.argmax(dim=1).item()


# Plain lexicographic sorting would put 'cell10' before 'cell2' for unpadded
# names. Natural order agrees with it for equal-width zero-padded names.
paths = sorted(glob.glob(cells_path + '/*'), key=_natural_key)
if len(paths) % GRID_SIZE:
    # Otherwise the last row would be incomplete; fail loudly rather than
    # return a truncated board.
    raise ValueError(
        f"expected a multiple of {GRID_SIZE} cell images in {cells_path!r}, found {len(paths)}"
    )
digits = [_classify_cell(p) for p in paths]
board = [digits[i:i + GRID_SIZE] for i in range(0, len(digits), GRID_SIZE)]

In [48]:
# Recognized grid: one list of 9 entries per row, in sorted-filename order.
# 0 is recorded for cells judged blank (the model could also predict class 0).
board

[[8, 1, 1, 1, 1, 1, 1, 1, 1],
 [1, 5, 0, 8, 1, 1, 1, 1, 1],
 [1, 1, 1, 1, 4, 1, 1, 0, 1],
 [0, 6, 0, 1, 4, 1, 0, 8, 0],
 [5, 1, 8, 1, 6, 0, 1, 0, 0],
 [0, 1, 0, 5, 1, 2, 1, 4, 1],
 [1, 1, 7, 1, 4, 1, 6, 1, 0],
 [1, 8, 1, 3, 1, 9, 0, 4, 1],
 [3, 1, 0, 1, 5, 1, 1, 1, 1]]