In [62]:
import torch
import torchvision
import torchvision.transforms as transforms

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

import cv2
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
%matplotlib inline

# Paths and sizes used throughout the notebook.
MODEL_PATH = 'model.pt'   # file passed to torch.load to obtain the Net state_dict
IMG_PATH = 'tmp/'         # directory of cell images named '<row>_<col>.<ext>'
IMG_SHAPE = (64, 64)      # size every cell image is resized to before inference

# Largest row index and largest column index found in the image file names.
# Each axis is maximised independently: taking the last element of the sorted
# [row, col] pairs would only give the widest column of the *last* row, which
# undersizes the matrices whenever that row has fewer images than another one.
_cell_indices = [list(map(int, name.split('.')[0].split('_'))) for name in os.listdir(IMG_PATH)]
MATRIX_SIZE = [max(idx[0] for idx in _cell_indices), max(idx[1] for idx in _cell_indices)]

# Class index (argmax of the network output) -> category name.
CAT_MAPPING = {0: 'both', 1: 'double_text', 2: 'down', 3: 'inverse_arrow', 4: 'other', 5: 'right', 6: 'single_text'}
N_CLASSES = len(CAT_MAPPING)

In [63]:
class Net(nn.Module):
    """Pytorch CNN model class.

    Four convolutions (two of them followed by 2x2 max-pooling) feed a
    four-layer fully connected head that outputs ``N_CLASSES`` raw scores
    (no softmax is applied).

    The first fully connected layer expects ``64*11*11`` features, which is
    what the convolutional stack produces for 3-channel 64x64 inputs:
    64 -> 62 -> 60 -> pool 30 -> 26 -> 22 -> pool 11.

    Attribute names are part of the saved ``state_dict`` keys and must not be
    renamed.
    """
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 3)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 3)

        self.conv3 = nn.Conv2d(16, 32, 5)
        self.conv4 = nn.Conv2d(32, 64, 5)

        self.dropout = nn.Dropout(0.3)

        self.fc1 = nn.Linear(64*11*11, 512)
        self.bnorm1 = nn.BatchNorm1d(512)

        self.fc2 = nn.Linear(512, 128)
        self.bnorm2 = nn.BatchNorm1d(128)

        self.fc3 = nn.Linear(128, 64)
        self.bnorm3 = nn.BatchNorm1d(64)

        self.fc4 = nn.Linear(64, N_CLASSES)

    def forward(self, x):
        """Return raw class scores of shape (batch, N_CLASSES) for a (batch, 3, 64, 64) input."""
        x = F.relu(self.conv1(x))
        x = self.pool(F.relu(self.conv2(x)))

        x = F.relu(self.conv3(x))
        x = self.pool(F.relu(self.conv4(x)))
        # Flatten per sample. Using view(-1, 64*11*11) here would silently
        # regroup elements across the batch when the spatial size is wrong;
        # keeping the batch dimension fixed makes fc1 raise a shape error instead.
        x = x.view(x.size(0), -1)
        x = self.dropout(x)
        x = F.relu(self.bnorm1(self.fc1(x)))
        x = F.relu(self.bnorm2(self.fc2(x)))
        x = F.relu(self.bnorm3(self.fc3(x)))
        x = self.fc4(x)
        return x

In [64]:
# create Net instance and load the pretrained weights
net = Net()
# map_location='cpu' lets a checkpoint saved from a GPU load on a CPU-only
# machine; predictions are converted with .numpy(), which needs CPU tensors.
# NOTE(review): torch.load unpickles the file - only load checkpoints from a trusted source.
net.load_state_dict(torch.load(MODEL_PATH, map_location='cpu'))
# eval() switches dropout and batch norm to inference behaviour.
net.eval()

Net(
  (conv1): Conv2d(3, 6, kernel_size=(3, 3), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1))
  (conv3): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1))
  (conv4): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1))
  (dropout): Dropout(p=0.3)
  (fc1): Linear(in_features=7744, out_features=512, bias=True)
  (bnorm1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc2): Linear(in_features=512, out_features=128, bias=True)
  (bnorm2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc3): Linear(in_features=128, out_features=64, bias=True)
  (bnorm3): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc4): Linear(in_features=64, out_features=7, bias=True)
)

In [65]:
# Grid of predicted class indices, one entry per (row, col) cell image.
propertyMatrix = np.zeros((MATRIX_SIZE[0]+1, MATRIX_SIZE[1] + 1))
# Same grid, holding the category names. Created with object dtype because
# string labels are written into it; assigning strings into float64 columns is
# a deprecated implicit upcast in recent pandas. Unfilled cells stay 0.0.
textualPropertyMatrix = pd.DataFrame(np.zeros((MATRIX_SIZE[0]+1, MATRIX_SIZE[1] + 1)), dtype=object)

In [66]:
# define basic image transforms for preprocessing
# NOTE(review): mean/std are 1-tuples applied to a 3-channel image; presumably
# this matches how the model was trained - confirm before changing it.
transform = transforms.Compose(
[
    transforms.ToTensor(),
    transforms.Normalize(mean = (0.5,), std = (0.5, ))
])

# Classify every cell image and store the result at the (row, col) position
# encoded in its file name ('<row>_<col>.<ext>').
for im in os.listdir(IMG_PATH):
    img_path = os.path.join(IMG_PATH, im)
    image = cv2.imread(img_path)
    if image is None:
        # cv2.imread returns None instead of raising on a missing/undecodable file.
        raise ValueError(f'Could not read image file: {img_path}')

    idxs = list(map(int, im.split('.')[0].split('_')))

    image = cv2.resize(image, IMG_SHAPE)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # unsqueeze adds the batch dimension without assuming the image is square
    # (cv2.resize takes (width, height), so a hard-coded reshape would mix axes).
    image = transform(image).unsqueeze(0)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        pred = net(image).numpy()
    label = int(np.argmax(pred))
    propertyMatrix[idxs[0], idxs[1]] = label
    textualPropertyMatrix.iloc[idxs[0], idxs[1]] = CAT_MAPPING[label]

In [67]:
# Display the grid of predicted category names (one label per cell image).
textualPropertyMatrix

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13
0,other,other,other,other,other,double_text,inverse_arrow,single_text,inverse_arrow,double_text,inverse_arrow,double_text,inverse_arrow,single_text
1,other,other,other,other,other,inverse_arrow,other,other,other,down,other,down,other,double_text
2,other,other,other,other,other,other,other,double_text,right,other,other,other,other,down
3,other,other,other,other,other,inverse_arrow,other,down,other,other,other,double_text,right,other
4,other,other,other,other,other,single_text,double_text,right,other,double_text,right,down,other,double_text
5,other,other,other,other,other,inverse_arrow,down,other,other,down,double_text,right,other,down
6,single_text,inverse_arrow,single_text,inverse_arrow,single_text,double_text,right,other,other,other,down,other,double_text,right
7,inverse_arrow,other,other,other,down,down,other,other,double_text,right,other,double_text,both,other
8,single_text,other,double_text,right,other,other,other,other,down,double_text,right,down,other,other
9,inverse_arrow,other,down,other,other,other,other,double_text,right,down,double_text,right,other,other


In [68]:
try:
    from PIL import Image
except ImportError:
    import Image
import pytesseract

def ocr_core(filename, lang='fra', config=r'--psm 1'):
    """
    Run Tesseract OCR on an image file and return the recognised text.

    Parameters
    ----------
    filename : str
        Path of the image to read with OpenCV.
    lang : str
        Tesseract language code passed to pytesseract (default 'fra').
    config : str
        Extra Tesseract command-line options (default '--psm 1').

    Returns
    -------
    str
        The text returned by ``pytesseract.image_to_string``.

    Raises
    ------
    OSError
        If OpenCV cannot read the file.
    """
    image = cv2.imread(filename)
    if image is None:
        # cv2.imread returns None instead of raising on a missing/undecodable file.
        raise OSError(f'Could not read image file: {filename}')

    # OpenCV loads images as BGR while pytesseract assumes RGB arrays.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    text = pytesseract.image_to_string(image, lang=lang, config=config)
    return text

In [69]:
def findArrows(data, i, j):
    '''
    Finds arrows adjacent to the text cell at position (i, j).

    Only the cell directly below (i+1, j) and the cell directly to the right
    (i, j+1) are inspected; positions outside the grid count as 'other'.

    Parameters
    ----------
    data : pandas.DataFrame
        Grid of category names, indexed positionally.
    i, j : int
        Row and column of the text cell.

    Returns
    -------
    list
        One ``[direction, row, col]`` entry per arrow found, the cell below
        first, then the cell to the right. ``direction`` is 'right' or 'down'.
    '''
    def cell_or_other(row, col):
        # Only an out-of-bounds lookup is expected here; any other error
        # (e.g. a wrong argument type) should surface instead of being hidden.
        try:
            return data.iloc[row, col]
        except IndexError:
            return 'other'

    # Category of the neighbouring cell -> direction the word is written in.
    # An 'inverse_arrow' below the text means the word runs right; to the
    # right of the text it means the word runs down.
    below_directions = {'inverse_arrow': 'right', 'down': 'down'}
    right_directions = {'inverse_arrow': 'down', 'right': 'right'}

    arrows = []

    # arrow under the text field
    below = cell_or_other(i+1, j)
    if below in below_directions:
        arrows.append([below_directions[below], i+1, j])

    # arrow to the right of the text field
    beside = cell_or_other(i, j+1)
    if beside in right_directions:
        arrows.append([right_directions[beside], i, j+1])

    return arrows


def extractJson(data):
    """
    Build the definitions dictionary from the grid of category names.

    For every 'single_text' / 'double_text' cell, the adjacent arrows are
    looked up with ``findArrows`` and the cell image ``'{i}_{j}.png'`` under
    ``IMG_PATH`` is read with ``ocr_core``. One definition is emitted per arrow.

    Parameters
    ----------
    data : pandas.DataFrame
        Grid of category names, indexed positionally.

    Returns
    -------
    dict
        ``{'definitions': [...]}`` where each entry has the keys 'label'
        (OCR text), 'position' ([row, col] of the text cell) and 'solution'
        ('startPosition' of the arrow cell and its 'direction').
    """
    # Not named `json`, to avoid shadowing the standard-library module name.
    result = {'definitions': []}
    n_rows, n_cols = data.shape
    for i in range(n_rows):
        for j in range(n_cols):
            if data.iloc[i, j] not in ('single_text', 'double_text'):
                continue

            arrows = findArrows(data, i, j)
            if not arrows:
                # No arrow means no definition is emitted, so skip the costly OCR call.
                continue

            text = ocr_core(os.path.join(IMG_PATH, f'{i}_{j}.png'))
            for direction, start_row, start_col in arrows:
                result['definitions'].append({
                    'label': str(text),
                    'position': [i, j],
                    'solution': {
                        'startPosition': [start_row, start_col],
                        'direction': direction
                    }
                })
    return result

In [70]:
%%time

# Build the definitions dictionary from the predicted category grid.
# extractJson runs OCR on the text cells, so this cell is timed.
jsonData = extractJson(textualPropertyMatrix)

Wall time: 10 s


In [76]:
import codecs
import json

# Write the extracted definitions as UTF-8 with a BOM ('utf_8_sig'), keeping
# non-ASCII characters unescaped in the output file.
with open('p3_1.json', mode='w', encoding='utf_8_sig') as out_file:
    json.dump(jsonData, out_file, ensure_ascii=False)