In [5]:
import os
import time
import random
import glob
import cv2

import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
import torch.optim as optim

In [47]:
# Single-directory setting; not referenced by the cells visible in this notebook
# (the loading code reads `data_dirs` instead).
data_dir = 'mnist_train_imgs'

# TODO: Use all images
# Directories searched for digit images. The MNIST folders are currently disabled.
# NOTE(review): the active entry is an absolute, machine-specific path.
data_dirs = [
    # 'mnist_test_imgs',
    # 'mnist_train_imgs',
    '/Users/nathanielyoungren/Desktop/code_projects/cbhs_numbers/saved_digits',
]

# Class names '0'..'9'; a label's position in this list is its integer target.
category_labels = [str(digit) for digit in range(10)]

# Relative sizes of the train / validation splits.
weights = [8, 2]  # [7, 2, 1] # TODO: Have a test set!

# train_dir = os.path.join('mnist_train_imgs')

# valid_dir = os.path.join('mnist_test_imgs')
# test_dir = os.path.join('saved_digits')

# Seed handed to the file shuffle so the split is repeatable.
seed = 0

# Expected image size as a 2-tuple; the model's first layer assumes 28*28 inputs.
img_size = (28, 28)

In [48]:
# Create train/valid sets
def shuffle_files(directories, weights, ext='**/*.png', seed=0):
    """Collect image paths from several directories and split them by weight.

    Args:
        directories: Iterable of directory paths to search.
        weights: Relative split sizes, e.g. ``[8, 2]`` for an 80/20 split.
            One list of files is returned per weight.
        ext: Glob pattern appended to each directory. ``glob.glob`` is called
            without ``recursive=True``, so ``**`` behaves like ``*`` and the
            default pattern matches PNG files exactly one folder below each
            directory.
        seed: Seed for the private ``random.Random`` used to shuffle.

    Returns:
        A list with ``len(weights)`` lists of file paths (a single list with
        every file when ``weights`` is empty). The last list receives whatever
        remains after the earlier ones have been cut.
    """
    files = []
    for directory in directories:
        files.extend(glob.glob(os.path.join(directory, ext)))

    # glob returns paths in arbitrary, filesystem-dependent order. Sorting
    # first makes the seeded shuffle (and therefore the split) reproducible
    # across machines and runs.
    files.sort()
    random.Random(seed).shuffle(files)

    weighted_split = []
    start = 0
    for w in weights[:-1]:
        size = int(w * len(files) / sum(weights))
        weighted_split.append(files[start:start + size])
        start += size
    # The final split absorbs the rounding remainder.
    weighted_split.append(files[start:])

    return weighted_split

# Shuffle every discovered image with the configured seed and cut the list
# into one chunk per entry of `weights` (index 0 = train, index 1 = validation).
split_data = shuffle_files(
    directories=data_dirs,
    weights=weights,
    seed=seed,
)


In [49]:

# TODO: Add horizontal/vertical smearing to the images.
# TODO: Add A TINY BIT of random pixel noise to the images.
# TODO: Identify the min / max image fill of the dataset.

def random_alignment(img):
    """Randomly translate the image content without wrapping it around an edge.

    The number of empty (all-zero) columns and rows on each side of the content
    is measured, and a random offset within those margins is drawn for each
    axis. The image is then rolled by those offsets.

    Args:
        img: 2-D array; entries greater than zero count as content.

    Returns:
        A rolled copy of ``img`` with the same shape.
    """
    def _margins(profile):
        # Leading / trailing count of empty entries in a 1-D sum profile.
        occupied = profile > 0
        return np.argmax(occupied), np.argmax(occupied[::-1])

    # Summing over axis 0 gives one value per column; over axis 1, one per row.
    empty_cols_before, empty_cols_after = _margins(np.sum(img, axis=0))
    empty_rows_before, empty_rows_after = _margins(np.sum(img, axis=1))

    # The column offset is drawn first, then the row offset, so a seeded
    # `random` yields the same sequence of shifts.
    col_offset = random.randint(-empty_cols_before, empty_cols_after)
    row_offset = random.randint(-empty_rows_before, empty_rows_after)

    return np.roll(img, (row_offset, col_offset), axis=(0, 1))

def transform(img, randomize=False):
    """Convert a loaded image into a binarised float tensor.

    Args:
        img: Image array as returned by ``cv2.imread`` (3-channel, BGR order),
            or an already single-channel 2-D array.
        randomize: Probability of applying ``random_alignment``. ``False``/0
            never shifts, ``True``/1 always shifts.

    Returns:
        A float tensor of shape ``(1, H, W)`` whose values are 0.0 or 1.0.

    Raises:
        ValueError: If ``img`` is ``None``, which is what ``cv2.imread``
            returns when a file is missing or cannot be decoded.
    """
    if img is None:
        raise ValueError(
            'transform() received None instead of an image; '
            'cv2.imread returns None when a file cannot be read.'
        )

    # cv2.imread yields BGR channel order, so BGR2GRAY is the matching
    # conversion (RGB2GRAY would swap the red and blue weights).
    # Images that are already single-channel are used as they are.
    if img.ndim == 3:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    # Binarise: pixels above 127 become 255, everything else 0.
    _, img = cv2.threshold(img, 127, 255, cv2.THRESH_BINARY)

    # random.random() is in [0, 1), so this fires with probability `randomize`.
    if random.random() < randomize:
        img = random_alignment(img)

    img = torch.tensor(img)
    img = img.unsqueeze(0)  # add the channel dimension
    img = img.float()
    img = img / 255  # map {0, 255} to {0.0, 1.0}
    return img

# TODO: Add up/down/left/right shift to the images.
# TODO: Flip images? Flip certain numbers?
# TODO: Rework to allow multiple data directories. (mnist + saved digits)
        
# Define a custom pytorch Dataset class
class CustomDataset(torch.utils.data.Dataset):
    """Image dataset whose class label is taken from each file's parent folder.

    The label of an image is the first character of the name of the directory
    that directly contains it (e.g. ``.../7/img.png`` -> ``'7'``).

    Args:
        files: Sequence of image file paths.
        img_size: Expected image size; stored but not used by this class.
        transform: Optional callable ``transform(image, randomize=...)`` applied
            to the array loaded by ``cv2.imread``. Falsy to skip.
        labels: Ordered list of valid labels; a label's index is its target.
        randomize: Value forwarded to ``transform`` as ``randomize``.
    """

    def __init__(self, files, img_size, transform, labels, randomize=0.0):
        self.image_files = files
        self.randomize = randomize
        self.transform = transform
        self.img_size = img_size
        self.labels = labels

    def __len__(self):
        """Return the number of image files."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load one image and return ``{'image', 'target', 'label'}``.

        Raises:
            ValueError: If the parent folder name does not start with one of
                ``self.labels``.
            OSError: If the image file cannot be read.
        """
        path = self.image_files[idx]

        # Name of the directory that directly contains the file (not the filename).
        class_dir = os.path.basename(os.path.dirname(path))
        label = class_dir[:1]
        if label not in self.labels:
            raise ValueError(
                f'Cannot derive a known label for {path!r}: parent folder '
                f'{class_dir!r} does not start with one of {list(self.labels)}'
            )

        # cv2.imread signals failure by returning None rather than raising.
        image = cv2.imread(path)
        if image is None:
            raise OSError(f'Could not read image file: {path!r}')
        if self.transform:
            image = self.transform(image, randomize=self.randomize)

        target = torch.tensor(self.labels.index(label))

        return {'image': image, 'target': target, 'label': label}

# Create dataloaders for training and validation data
# Training images are always passed through the random re-alignment
# (randomize=1.0); validation images about half of the time (randomize=0.5).
train_data = CustomDataset(
    files=split_data[0],
    img_size=img_size,
    transform=transform,
    labels=category_labels,
    randomize=1.0,
)
valid_data = CustomDataset(
    files=split_data[1],
    img_size=img_size,
    transform=transform,
    labels=category_labels,
    randomize=0.5,
)
print(len(train_data), len(valid_data))

# One sample per batch, reshuffled on every pass over the data.
train_loader = DataLoader(dataset=train_data, batch_size=1, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=1, shuffle=True)
print(len(train_loader), len(valid_loader))

584 147
584 147


In [50]:

class Digit_OCR_CNN(nn.Module):
    """Fully connected digit classifier: 784 -> 128 -> 128 -> 10.

    Despite the class name, no convolutional layers are active; the input is
    flattened and passed through three linear layers with a single dropout
    after the first one. The output is a row of log-probabilities per sample.
    """

    def __init__(self):
        super().__init__()

        n_pixels, n_hidden, n_classes = 28 * 28, 128, 10

        # Layers are created in this order so that seeded initialisation and
        # the module listing stay the same.
        self.fc1 = nn.Linear(n_pixels, n_hidden)
        self.drop1 = nn.Dropout(0.2)
        self.fc2 = nn.Linear(n_hidden, n_hidden)
        self.fc3 = nn.Linear(n_hidden, n_classes)

    def forward(self, state):
        """Return log-softmax class scores of shape ``(batch, 10)``."""
        # Flatten everything into rows of 784 values.
        flat = state.view(-1, 28 * 28)

        hidden = self.drop1(F.relu(self.fc1(flat)))
        hidden = F.relu(self.fc2(hidden))

        return F.log_softmax(self.fc3(hidden), dim=1)

# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(device)

# Fresh, untrained network placed on the selected device.
net = Digit_OCR_CNN()
net = net.to(device)

# Best validation accuracy seen so far; a new model starts from zero.
best_acc = 0.0

cpu


In [58]:
# The network's forward pass already ends in log_softmax, so the negative
# log-likelihood loss is applied directly to its output.
criterion = nn.NLLLoss()
# Alternatives tried:
# criterion = nn.CrossEntropyLoss()
# criterion = nn.KLDivLoss() # Boolean values

# AdamW with a small learning rate.
optimizer = optim.AdamW(params=net.parameters(), lr=5e-5)
# optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

In [59]:

def check_best(best_acc):
    """Evaluate ``net`` on ``valid_loader`` and checkpoint it if it improved.

    Uses the notebook-level ``net``, ``valid_loader``, ``device`` and
    ``category_labels``. The network is switched to eval mode for the pass and
    put back into train mode afterwards, even if evaluation raises.

    Args:
        best_acc: Best validation accuracy observed so far (fraction in [0, 1]).

    Returns:
        The updated best accuracy: the new accuracy when it is strictly higher
        than ``best_acc`` (the weights are then saved to ``best_model.pt``),
        otherwise ``best_acc`` unchanged. An empty validation loader also
        returns ``best_acc`` unchanged.
    """
    best_name = 'best_model.pt'
    net.eval()
    try:
        seen = 0
        correct = 0
        with torch.no_grad():
            for batch in valid_loader:
                scores = net(batch['image'].to(device))
                _, predicted = torch.max(scores, 1)
                # Compare predicted class names against the string labels.
                for class_idx, true_label in zip(predicted.tolist(), batch['label']):
                    seen += 1
                    if category_labels[class_idx] == true_label:
                        correct += 1

        # Nothing to score: avoid dividing by zero and leave the record alone.
        if seen == 0:
            print('\tValidation set is empty; skipping best-model check.')
            return best_acc

        curr_acc = correct / seen
        if curr_acc > best_acc:
            best_acc = curr_acc
            torch.save(net.state_dict(), best_name)
            print(f'\tNew best accuracy: {best_acc:.5f}\n\t\t> Saving model as', best_name)

        return best_acc
    finally:
        net.train()

def train(best_acc, num_epochs=50, loss_interval=10, loss_end_thresh=0.0,
          consecutive_thresh=5):
    """Train ``net`` on ``train_loader``, periodically checkpointing the best model.

    Uses the notebook-level ``net``, ``train_loader``, ``optimizer``,
    ``criterion``, ``device`` and ``check_best``.

    Args:
        best_acc: Best validation accuracy so far; passed to ``check_best``.
        num_epochs: Number of passes over ``train_loader``.
        loss_interval: Number of batches averaged for each printed loss value.
        loss_end_thresh: A printed mean loss below this value counts towards
            early stopping. With the default of 0.0 early stopping never fires.
        consecutive_thresh: How many consecutive below-threshold mean losses
            trigger early stopping.

    Returns:
        The best validation accuracy after training.
    """
    print('Starting accuracy:', best_acc)

    # Validate roughly four times per epoch. Clamp to 1 so loaders with fewer
    # than four batches do not cause a modulo-by-zero below.
    best_interval = max(1, len(train_loader) // 4)

    # Consecutive mean-loss values seen below the threshold.
    thresh_track = 0

    for epoch in range(num_epochs):  # loop over the dataset multiple times
        net.train()

        running_loss = 0.0
        for i, data in enumerate(train_loader):
            # The network lives on `device`, so its inputs and targets must too.
            inputs = data['image'].to(device)
            targets = data['target'].to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            # print statistics every `loss_interval` batches
            running_loss += loss.item()
            if i % loss_interval == loss_interval - 1:
                mean_loss = abs(running_loss / loss_interval)
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, i + 1, mean_loss))
                running_loss = 0.0
                if mean_loss < loss_end_thresh:
                    thresh_track += 1
                    print(f'Thresh at {thresh_track} of {consecutive_thresh}')
                    # Stop once exactly `consecutive_thresh` values in a row
                    # have been below the threshold.
                    if thresh_track >= consecutive_thresh:
                        best_acc = check_best(best_acc=best_acc)
                        return best_acc
                else:
                    thresh_track = 0

            if i % best_interval == best_interval - 1:
                # Check for best
                best_acc = check_best(best_acc=best_acc)

    return best_acc

# # # 
# Remember the accuracy we started from so the improvement can be reported.
prev_acc = best_acc

# Wall-clock timing around the whole training run.
st = time.time()
best_acc = train(best_acc)
et = time.time()

print('Finished Training')
print(f'Prev accuracy: {prev_acc}')
print(f'Best accuracy: {best_acc}')
print(f'Change: {best_acc - prev_acc}')
print(f'Time: {et - st}')  # seconds (difference of two time.time() values)

# Re-running this cell continues from the accuracy just reached.
prev_acc = best_acc


Starting accuracy: 0.9251700680272109
[1,    10] loss: 0.372
[1,    20] loss: 0.304
[1,    30] loss: 0.194
[1,    40] loss: 0.497
[1,    50] loss: 0.397
[1,    60] loss: 0.047
[1,    70] loss: 0.291
[1,    80] loss: 0.509
[1,    90] loss: 0.153
[1,   100] loss: 0.122
[1,   110] loss: 1.300
[1,   120] loss: 0.387
[1,   130] loss: 0.162
[1,   140] loss: 0.265
[1,   150] loss: 0.575
[1,   160] loss: 0.315
[1,   170] loss: 0.581
[1,   180] loss: 0.324
[1,   190] loss: 0.121
[1,   200] loss: 0.357
[1,   210] loss: 0.501
[1,   220] loss: 0.779
[1,   230] loss: 0.327
[1,   240] loss: 0.178
[1,   250] loss: 0.323
[1,   260] loss: 1.071
[1,   270] loss: 0.066
[1,   280] loss: 0.162
[1,   290] loss: 0.482
[1,   300] loss: 0.071
[1,   310] loss: 0.989
[1,   320] loss: 0.445
[1,   330] loss: 0.550
[1,   340] loss: 0.593
[1,   350] loss: 0.228
[1,   360] loss: 0.105
[1,   370] loss: 0.651
[1,   380] loss: 0.344
[1,   390] loss: 0.481
[1,   400] loss: 0.317
[1,   410] loss: 0.334
[1,   420] loss: 0.

In [51]:
# Rebuild the network and restore the best checkpoint written by check_best().
net = Digit_OCR_CNN()
# load_name = 'models/ability_icon_1_2.pt'
load_name = 'best_model.pt'
# map_location remaps the stored tensors onto the current `device`, so a
# checkpoint saved on a CUDA machine still loads on a CPU-only one.
net.load_state_dict(torch.load(load_name, map_location=device))
net.to(device)

Digit_OCR_CNN(
  (fc1): Linear(in_features=784, out_features=128, bias=True)
  (drop1): Dropout(p=0.2, inplace=False)
  (fc2): Linear(in_features=128, out_features=128, bias=True)
  (fc3): Linear(in_features=128, out_features=10, bias=True)
)

In [45]:
# Write the current weights of `net` to a separate file, so this snapshot is
# not overwritten when check_best() later saves to 'best_model.pt'.
torch.save(net.state_dict(), 'test_model2.pt')

In [69]:
# Per-class tallies for the validation set:
#   correct_pred   - number of correct predictions per true label
#   total_pred     - number of samples seen per true label
#   incorrect_pred - the wrong labels predicted for each true label
correct_pred = {classname: 0 for classname in category_labels}
total_pred = {classname: 0 for classname in category_labels}
incorrect_pred = {classname: [] for classname in category_labels}


# again no gradients needed
with torch.no_grad():
    net.eval()
    for data in valid_loader:
        images, targets, labels = data['image'], data['target'], data['label']
        outputs = net(images.to(device))
        _, predictions = torch.max(outputs, 1)

        # Compare as plain ints so the device the tensors live on is irrelevant.
        for i, (label, target, prediction) in enumerate(
                zip(labels, targets.tolist(), predictions.tolist())):
            total_pred[label] += 1
            if target == prediction:
                correct_pred[label] += 1
                continue

            predicted_label = category_labels[prediction]
            incorrect_pred[label].append(predicted_label)

            # Show the misclassified image: replicate the single channel three
            # times and scale 0..1 back to 0..255 for display.
            new_img = np.repeat(images[i].numpy(), 3)
            new_img = np.reshape(new_img, (28, 28, 3))

            new_img *= 255
            new_img = new_img.astype(np.uint8)
            # NOTE: blocks until a key is pressed in the OpenCV window.
            cv2.imshow(f'{label} =/= {predicted_label}', new_img)
            cv2.waitKey(0)
            cv2.destroyAllWindows()

# print accuracy for each class
for classname, correct_count in correct_pred.items():
    # A class with no validation samples reports nan instead of dividing by zero.
    if total_pred[classname]:
        accuracy = 100 * float(correct_count) / total_pred[classname]
    else:
        accuracy = float('nan')
    print(f'Accuracy for class: {classname:10s} is {accuracy:.1f} %')
    print(f'\tIncorrect predictions: {incorrect_pred[classname]}')
if sum(total_pred.values()):
    print(f'Overall Accuracy: {(100 * float(sum(correct_pred.values()) / sum(total_pred.values()))):.3f}')
else:
    print('Overall Accuracy: nan')
del correct_pred
del total_pred
del correct_count

Accuracy for class: 0          is 88.9 %
	Incorrect predictions: ['5', '3']
Accuracy for class: 1          is 94.1 %
	Incorrect predictions: ['3']
Accuracy for class: 2          is 86.4 %
	Incorrect predictions: ['6', '3', '6']
Accuracy for class: 3          is 87.5 %
	Incorrect predictions: ['9', '5']
Accuracy for class: 4          is 100.0 %
	Incorrect predictions: []
Accuracy for class: 5          is 83.3 %
	Incorrect predictions: ['6', '8']
Accuracy for class: 6          is 100.0 %
	Incorrect predictions: []
Accuracy for class: 7          is 100.0 %
	Incorrect predictions: []
Accuracy for class: 8          is 80.0 %
	Incorrect predictions: ['0', '9']
Accuracy for class: 9          is 90.9 %
	Incorrect predictions: ['4']
Overall Accuracy: 91.156


In [58]:
# Create the dataset and dataloader for the images under 'saved_digits'.
# CustomDataset takes an explicit list of files (it has no `root_dir`
# argument), so gather saved_digits/<label folder>/*.png here. The list is
# sorted because glob returns paths in arbitrary order.
gui_files = sorted(glob.glob(os.path.join('saved_digits', '*', '*.png')))
gui_data = CustomDataset(files=gui_files,
                         img_size=img_size,
                         transform=transform,
                         labels=category_labels,
                         randomize=True)
gui_loader = DataLoader(gui_data, batch_size=32, shuffle=True)

# Per-class tallies: correct count, total count, and the wrong labels predicted.
correct_pred = {classname: 0 for classname in category_labels}
total_pred = {classname: 0 for classname in category_labels}
incorrect_pred = {classname: [] for classname in category_labels}

# again no gradients needed
with torch.no_grad():
    net.eval()
    for data in gui_loader:
        images, targets, labels = data['image'], data['target'], data['label']
        outputs = net(images.to(device))
        _, predictions = torch.max(outputs, 1)

        # Compare as plain ints so the device the tensors live on is irrelevant.
        for label, target, prediction in zip(labels, targets.tolist(), predictions.tolist()):
            if target == prediction:
                correct_pred[label] += 1
            else:
                incorrect_pred[label].append(category_labels[prediction])

            total_pred[label] += 1

# print accuracy for each class
for classname, correct_count in correct_pred.items():
    # A class with no samples reports nan instead of dividing by zero.
    if total_pred[classname]:
        accuracy = 100 * float(correct_count) / total_pred[classname]
    else:
        accuracy = float('nan')
    print(f'Accuracy for class: {classname:10s} is {accuracy:.1f} %')
    print(f'\tIncorrect predictions: {incorrect_pred[classname]}')
if sum(total_pred.values()):
    print(f'Overall Accuracy: {(100 * float(sum(correct_pred.values()) / sum(total_pred.values()))):.3f}')
else:
    print('Overall Accuracy: nan')
del correct_pred
del total_pred
del correct_count

Accuracy for class: 0          is 19.0 %
	Incorrect predictions: ['3', '6', '3', '2', '2', '8', '2', '2', '2', '7', '3', '9', '3', '3', '2', '3', '7']
Accuracy for class: 1          is 4.8 %
	Incorrect predictions: ['6', '2', '5', '4', '0', '0', '2', '0', '7', '7', '6', '7', '4', '7', '5', '7', '7', '6', '5', '2']
Accuracy for class: 2          is 42.9 %
	Incorrect predictions: ['3', '6', '4', '3', '3', '6', '6', '6', '0', '7', '0', '8']
Accuracy for class: 3          is 57.1 %
	Incorrect predictions: ['2', '9', '6', '2', '5', '9', '6', '2', '2']
Accuracy for class: 4          is 23.8 %
	Incorrect predictions: ['0', '2', '0', '0', '6', '5', '6', '2', '9', '2', '7', '1', '2', '5', '8', '0']
Accuracy for class: 5          is 19.0 %
	Incorrect predictions: ['6', '9', '6', '3', '3', '4', '9', '7', '2', '3', '7', '0', '6', '6', '3', '3', '8']
Accuracy for class: 6          is 50.0 %
	Incorrect predictions: ['9', '8', '2', '5', '5', '3', '8', '5', '3', '5']
Accuracy for class: 7          is 