#### Imports

In [3]:
import torch
import torchvision
from torchvision import transforms, datasets
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from NeuralNet import NeuralNet
import torch.optim as optim
import cv2

##### Loading the MNIST training dataset

In [4]:
# MNIST training split; downloaded on first run, with '' passed as the
# dataset root. Each sample goes through ToTensor before it reaches the model.
train = datasets.MNIST(
    root='',
    train=True,
    download=True,
    transform=transforms.Compose([transforms.ToTensor()]),
)
# Reshuffled mini-batches of 10 samples each.
train_set = torch.utils.data.DataLoader(train, batch_size=10, shuffle=True)

##### Setting up the neural network, loss function, and optimizer

In [5]:
# Model to train; its architecture is defined in the local NeuralNet module.
net = NeuralNet()
# Loss: cross-entropy between the network outputs and the target labels.
criterion = nn.CrossEntropyLoss()
# SGD with momentum, updating every parameter exposed by the network.
optimizer = optim.SGD(params=net.parameters(), lr=0.001, momentum=0.9)

In [6]:
EPOCH = 5
for epoch in range(EPOCH):
    # Loss accumulated since the last progress report.
    running_loss = 0
    for i, (inputs, labels) in enumerate(train_set):
        # Drop the gradients left over from the previous step.
        optimizer.zero_grad()

        # Forward pass and loss for this mini-batch.
        outputs = net(inputs)
        loss = criterion(outputs, labels)

        # Backpropagate, then let the optimizer update the parameters.
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        # Report the mean loss once every 2000 mini-batches.
        if (i + 1) % 2000 == 0:
            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0
print("Training is done!")

[1,  2000] loss: 1.449
[1,  4000] loss: 0.266
[1,  6000] loss: 0.163
[2,  2000] loss: 0.121
[2,  4000] loss: 0.102
[2,  6000] loss: 0.089
[3,  2000] loss: 0.076
[3,  4000] loss: 0.070
[3,  6000] loss: 0.066
[4,  2000] loss: 0.054
[4,  4000] loss: 0.051
[4,  6000] loss: 0.059
[5,  2000] loss: 0.049
[5,  4000] loss: 0.043
[5,  6000] loss: 0.042
Training is done!


##### Making a drawing canvas to collect the user's handwritten digit

In [8]:
# 600x600 white (255) image whose central 400x400 square is black (0);
# that black square is the region later cropped out and fed to the network.
canvas = np.full((600, 600), 255, dtype="uint8")
canvas[100:500, 100:500] = 0

# Stroke state shared between the key-handling loop and the mouse callback.
start_point = end_point = None
is_drawing = False

def draw(img, start_at, end_at):
    """Paint one stroke segment onto ``img``.

    The segment runs from ``start_at`` to ``end_at`` (the mouse callback
    passes ``(x, y)`` tuples) and is drawn by ``cv2.line`` with colour 255
    and a thickness of 15.
    """
    stroke_colour, stroke_thickness = 255, 15
    cv2.line(img, start_at, end_at, stroke_colour, stroke_thickness)

def mouse_move(event, x, y, flags, params):
    """OpenCV mouse callback that turns cursor movement into strokes on ``canvas``.

    Drawing is only active while the module-level ``is_drawing`` flag is set
    (the key loop sets it when 's' is pressed):

    * left button down  -> anchor the stroke at the cursor position;
    * mouse move        -> draw a segment from the previous anchor to the
      cursor and move the anchor there;
    * left button up    -> stop drawing and forget the anchor.

    Args:
        event: OpenCV mouse event code (compared against ``cv2.EVENT_*``).
        x, y: cursor position reported with the event.
        flags, params: required by the callback signature; unused here.
    """
    global start_point
    global end_point
    global canvas
    global is_drawing
    if event == cv2.EVENT_LBUTTONDOWN:
        if is_drawing:
            start_point = (x, y)
    elif event == cv2.EVENT_MOUSEMOVE:
        if is_drawing:
            end_point = (x, y)
            # Right after drawing is enabled there is no anchor yet; drawing
            # from None would make cv2.line fail, so the first movement only
            # establishes the anchor.
            if start_point is not None:
                draw(canvas, start_point, end_point)
            start_point = end_point
    elif event == cv2.EVENT_LBUTTONUP:
        is_drawing = False
        # Forget the anchor so the next stroke does not begin with a line
        # connecting it to the end of this one.
        start_point = None
cv2.namedWindow("Draw a Number!")
cv2.setMouseCallback("Draw a Number!", mouse_move)

transform = transforms.Compose([transforms.ToTensor()])

# From here on the network is only used for prediction, so switch it to
# evaluation mode. Call net.train() before re-running the training cell.
net.eval()

# Key bindings: 's' enables drawing, 'c' clears the drawing area,
# 'p' predicts the drawn digit, 'q' quits.
try:
    while True:
        cv2.imshow("Draw a Number!", canvas)
        key = cv2.waitKey(1) & 0xFF
        if key == ord('q'):
            break
        elif key == ord('s'):
            is_drawing = True
        elif key == ord('c'):
            canvas[100:500, 100:500] = 0
        elif key == ord('p'):
            # Crop the black drawing area and shrink it to the 28x28 size the
            # network was trained on. INTER_AREA averages the source pixels,
            # which avoids the aliasing bilinear sampling produces when
            # shrinking by a large factor.
            image = canvas[100:500, 100:500]
            image = cv2.resize(image, (28, 28), interpolation=cv2.INTER_AREA)
            image = transform(image)
            # Add the leading batch dimension expected by the network.
            image = image.unsqueeze(0)
            # No gradients are needed for a prediction.
            with torch.no_grad():
                outputs = net(image)
            _, predicted = torch.max(outputs, 1)
            print(predicted)
finally:
    # Close the window even if the loop exits through an exception.
    cv2.destroyAllWindows()

tensor([7])
tensor([4])
tensor([7])
tensor([5])
tensor([8])
tensor([1])
tensor([3])
tensor([3])
tensor([3])
tensor([2])
