In [57]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.datasets import mnist
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor

########################################## Lenet-5 Network
class Net(Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(256, 120)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(120, 84)
        self.relu4 = nn.ReLU()
        self.fc3 = nn.Linear(84, 10)
        self.relu5 = nn.ReLU()

    def forward(self, x):
        out = self.conv1(x)
        out = self.relu1(out)
        out = self.pool1(out)
        out = self.conv2(out)
        out = self.relu2(out)
        out = self.pool2(out)
        out = out.view(out.shape[0], -1)
        out = self.fc1(out)
        out = self.relu3(out)
        out = self.fc2(out)
        out = self.relu4(out)
        out = self.fc3(out)
        out = self.relu5(out)
        return out
###########################################
if __name__ == '__main__':
    batch_size = 64
    mnist_train = mnist.MNIST('./train', train=True, download=True, transform=ToTensor())
    mnist_test = mnist.MNIST('./test', train=False, download=True, transform=ToTensor())
    train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=False)    
    net = Net()
    #The cost function we used for logistic regression
    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.01)
    n_epochs = 150
    for epoch in range(1, n_epochs +1):
        train_loss = 0.0
        for imgs, labels in train_loader:
            optimizer.zero_grad()
            outputs = net(imgs)
            train_loss = loss_fn(outputs, labels)
            train_loss.backward()
            optimizer.step()
        if epoch == 1 or epoch % 10 == 0:
            print("Epoch{}, Training loss{}".format(epoch, train_loss))

Epoch1, Training loss2.273669719696045
Epoch10, Training loss0.8017953634262085
Epoch20, Training loss0.39260053634643555
Epoch30, Training loss0.276013046503067
Epoch40, Training loss0.14580915868282318
Epoch50, Training loss0.46126383543014526
Epoch60, Training loss0.36156022548675537
Epoch70, Training loss0.07252158224582672
Epoch80, Training loss0.2878243327140808
Epoch90, Training loss0.14413675665855408
Epoch100, Training loss0.14416943490505219
Epoch110, Training loss0.5065944790840149
Epoch120, Training loss8.552281360607594e-05
Epoch130, Training loss0.14483048021793365
Epoch140, Training loss0.21641796827316284
Epoch150, Training loss0.0719677284359932


In [56]:
# Test Section 
from sys
from PIL import Image
mnist_sample = mnist.MNIST('./sample',train=False,download=True, transform=None)
for j in range(100):
    for i in range(10):
        if mnist_sample[j][1] == 0:
            mnist_sample[j][0].save('zero.png')
        elif mnist_sample[j][1] == 1:
            mnist_sample[j][0].save('one.png')
        elif mnist_sample[j][1] == 2:
            mnist_sample[j][0].save('two.png')
        elif mnist_sample[j][1] == 3:
            mnist_sample[j][0].save('three.png')
        elif mnist_sample[j][1] == 4:
            mnist_sample[j][0].save('four.png')
        elif mnist_sample[j][1] == 5:
            mnist_sample[j][0].save('five.png')
        elif mnist_sample[j][1] == 6:
            mnist_sample[j][0].save('six.png')
        elif mnist_sample[j][1] == 7:
            mnist_sample[j][0].save('seven.png')
        elif mnist_sample[j][1] == 8:
            mnist_sample[j][0].save('eight.png')
        elif mnist_sample[j][1] == 9:
            mnist_sample[j][0].save('nine.png')

In [67]:
if len(sys.argv) != 1:
        print(
            "<Usage>: Prediction value is not coming ex)'one.png' "
        )  # program이나py파일이 하나라도 안들어왔을 경우 exception 처리
        exit(1)
image_ = Image.open('four.png')
image = torch.zeros(16,1,28,28)
transform = transforms.Compose([
                                 transforms.ToTensor(), # image to Tensor
                             ])
image = image + transform(image_)
optimizer = optim.SGD(net.parameters(), lr=0.01)
optimizer.zero_grad()
# forward propagation
model_output = net(image)
print(f"Assume : {model_output.argmax(dim=1)[0]}")

<Usage>: Prediction value is not coming ex)'one.png' 
Assume : 4
