# COMP 551 Assignment 3 - Digit Localization

In [4]:
# basic
import cv2
import numpy as np
import imutils
import matplotlib.pyplot as plt
import pandas as pd

# PyTorch 
import torch

# Read the label table (CSV) and the two image stacks (pickles) from the
# current working directory.
# NOTE(review): pd.read_pickle unpickles the file and can therefore execute
# arbitrary code -- only load pickles obtained from a trusted source.
train_labels = pd.read_csv('train_max_y.csv')
train_images = pd.read_pickle('train_max_x')
test_images = pd.read_pickle('test_max_x')

# Quick sanity check: first label rows, then the shape of each image stack,
# one item per output line.
print(train_labels.head(), train_images.shape, test_images.shape, sep='\n')

   Id  Label
0   0      6
1   1      7
2   2      2
3   3      9
4   4      7
(50000, 128, 128)
(10000, 128, 128)


In [5]:
# ---------------------------------------------------------------------------
# Digit extraction settings -- adjust these to tune what gets cropped out.
# ---------------------------------------------------------------------------

# Cut-off handed to cv2.threshold (THRESH_BINARY): pixels brighter than this
# become white (255), everything else becomes black (0).
threshold = 210

# Lower bounds, in pixels, that a contour's bounding box is compared against
# (width and height respectively) when deciding whether it is a digit.
digit_width = digit_height = 15

# ---------------------------------------------------------------------------
# Storage
# ---------------------------------------------------------------------------

# One entry per image: the list of black & white windows cropped from it.
allDigitWindows = list()

In [6]:
def extract_digit_windows(img, thresh, min_width, min_height):
    """Crop candidate digit windows out of a single image.

    The image is binarised with ``cv2.threshold`` (``THRESH_BINARY``: pixels
    above ``thresh`` become 255, the rest 0), external contours are located on
    the binarised image, and the bounding box of every contour that is wide
    enough *or* tall enough is cropped from the binarised image.

    Parameters
    ----------
    img : array-like
        One image, as yielded by iterating over ``train_images``.
    thresh : int or float
        Binarisation cut-off passed to ``cv2.threshold``.
    min_width, min_height : int
        A contour is kept when its bounding-box width is at least
        ``min_width`` or its height is at least ``min_height``.

    Returns
    -------
    list of numpy.ndarray
        The cropped windows (slices of the binarised image), in the order the
        contours are returned by OpenCV.
    """
    # cv2.threshold writes to a fresh output array, so no defensive copy of
    # the input is needed.
    _, t_img = cv2.threshold(img, thresh, 255, cv2.THRESH_BINARY)

    # findContours needs an 8-bit image; astype() already yields a new array.
    # imutils.grab_contours picks the contour list out of the returned tuple.
    found = cv2.findContours(t_img.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    windows = []
    for contour in imutils.grab_contours(found):
        (x, y, w, h) = cv2.boundingRect(contour)

        # The earlier if/elif pair had identical bodies, so its "<= 28" upper
        # bound never excluded anything; this single test is the equivalent
        # condition. Tighten it here if oversized blobs must be rejected.
        if w >= min_width or h >= min_height:
            windows.append(np.asarray(t_img[y:y+h, x:x+w]))
    return windows


# Rebuild the list from scratch so that re-running this cell does not append
# a second copy of every image's windows (which would skew the statistics).
allDigitWindows = [
    extract_digit_windows(img, threshold, digit_width, digit_height)
    for img in train_images
]

In [7]:
# Evaluate the digit identifier: how many windows were extracted per image,
# compared against a count of 3 (presumably the number of digits per image).
lengths = [len(windows) for windows in allDigitWindows]

more = sum(1 for n in lengths if n > 3)    # images with too many windows
less = sum(1 for n in lengths if n < 3)    # images with too few windows
none = lengths.count(0)                    # images with no window at all
count = more + less                        # images whose count is not 3
average = sum(lengths) / len(lengths)      # mean windows per image

report = [
    ("number of images with > 3 identified digits: ", more),
    ("number of images with < 3 identified digits: ", less),
    ("number of images with != 3 identified digits: ", count),
    ("average number of digits identified per image: ", average),
    ("number of images with no digits identified: ", none),
]
for label, value in report:
    print(label + str(value))

number of images with > 3 identified digits: 1590
number of images with < 3 identified digits: 871
number of images with != 3 identified digits: 2461
average number of digits identified per image: 3.0221
number of images with no digits identified: 0


### Format data to be fed into the model

In [None]:
# Turn every extracted window into a 28x28 array and persist the result.
# digitData mirrors allDigitWindows: one list per image, one array per window.
digitData = []

for img in allDigitWindows:
    imgData = []
    for digit in img:
        # Pad 5 px on every side with a constant border, then rescale.
        # The border colour must be given by keyword: the Python signature is
        # copyMakeBorder(src, top, bottom, left, right, borderType[, dst[, value]]),
        # so the trailing positional 255 used previously was taken as `dst`
        # and the border was filled with the default value 0 (black).
        # value=0 keeps that output unchanged and matches the black background
        # left by the binary threshold; use value=255 only if a white frame
        # is really wanted.
        padded = cv2.copyMakeBorder(digit, 5, 5, 5, 5, cv2.BORDER_CONSTANT, value=0)

        # cv2.resize stretches to exactly 28x28 (aspect ratio is not kept).
        resized = cv2.resize(padded, (28, 28))

        imgData.append(np.asarray(resized))
    digitData.append(imgData)

# Serialise the nested list of arrays to disk.
torch.save(digitData, 'win15Thresh210DigitData.pkl')