<a href="https://colab.research.google.com/github/Vigneshthanga/258-Deep-Learning/blob/master/Assignment-4/LeNet_PyTorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Import Libraries

In [0]:
from torch.nn import Module
from torch import nn
import numpy as np
import torch
from torchvision.datasets import mnist
from torch.nn import CrossEntropyLoss
from torch.optim import SGD
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor

## Model Definition LeNET5



In [0]:
class Model(Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(256, 120)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(120, 84)
        self.relu4 = nn.ReLU()
        self.fc3 = nn.Linear(84, 10)
        self.relu5 = nn.ReLU()

    def forward(self, x):
        y = self.conv1(x)
        y = self.relu1(y)
        y = self.pool1(y)
        y = self.conv2(y)
        y = self.relu2(y)
        y = self.pool2(y)
        y = y.view(y.shape[0], -1)
        y = self.fc1(y)
        y = self.relu3(y)
        y = self.fc2(y)
        y = self.relu4(y)
        y = self.fc3(y)
        y = self.relu5(y)
        return y

In [0]:
model = Model()

## Logic to choose GPU if available else false back to CPU for model training

In [4]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [0]:
use_gpu = torch.cuda.is_available()
if use_gpu:
    model = model.to(device)

## Download the MNIST Dataset, Train the model, Predict with Test Dataset

In [6]:
batch_size = 256
train_dataset = mnist.MNIST(root='./train', train=True, transform=ToTensor(), download=True)
test_dataset = mnist.MNIST(root='./test', train=False, transform=ToTensor(), download=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

sgd = SGD(model.parameters(), lr=1e-1)
cross_error = CrossEntropyLoss()
epoch = 100

for _epoch in range(epoch):
    for idx, data in enumerate(train_loader):
        train_x, train_label = data[0].to(device), data[1].to(device)
        label_np = np.zeros((train_label.shape[0], 10))
        sgd.zero_grad()
        predict_y = model(train_x.float())
        _error = cross_error(predict_y, train_label.long())
        if idx % 100 == 0:
            print('idx: {}, _error: {}'.format(idx, _error))
        _error.backward()
        sgd.step()

    correct = 0
    _sum = 0

    for idx, data in enumerate(test_loader):
        test_x, test_label = data[0].to(device), data[1].to(device)
        predict_y = model(test_x.float()).detach()
        predict_ys = predict_y.cpu().data.numpy().argmax(axis = -1)
        label_np = test_label.cpu().data.numpy()
        _ = predict_ys[0] == label_np
        correct += np.sum(_, axis=-1)
        _sum += _.shape[0]
    print('accuracy: {:.2f}'.format(correct / _sum))

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./train/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./train/MNIST/raw/train-images-idx3-ubyte.gz to ./train/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./train/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./train/MNIST/raw/train-labels-idx1-ubyte.gz to ./train/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./train/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./train/MNIST/raw/t10k-images-idx3-ubyte.gz to ./train/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./train/MNIST/raw/t10k-labels-idx1-ubyte.gz




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./train/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./train/MNIST/raw
Processing...
Done!
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./test/MNIST/raw/train-images-idx3-ubyte.gz




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./test/MNIST/raw/train-images-idx3-ubyte.gz to ./test/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./test/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./test/MNIST/raw/train-labels-idx1-ubyte.gz to ./test/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./test/MNIST/raw/t10k-images-idx3-ubyte.gz




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./test/MNIST/raw/t10k-images-idx3-ubyte.gz to ./test/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./test/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./test/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./test/MNIST/raw
Processing...
Done!

idx: 0, _error: 2.3035762310028076
idx: 100, _error: 2.2939846515655518
idx: 200, _error: 2.0435593128204346
accuracy: 0.10
idx: 0, _error: 1.4628751277923584
idx: 100, _error: 0.2543770968914032
idx: 200, _error: 0.14520317316055298
accuracy: 0.11
idx: 0, _error: 0.22078575193881989
idx: 100, _error: 0.1387253999710083
idx: 200, _error: 0.11186842620372772
accuracy: 0.11
idx: 0, _error: 0.17354632914066315
idx: 100, _error: 0.10512129962444305
idx: 200, _error: 0.10242320597171783
accuracy: 0.11
idx: 0, _error: 0.12609338760375977
idx: 100, _error: 0.08930975198745728
idx: 200, _error: 0.09915008395910263
accuracy: 0.11
idx: 0, _error: 0.09991477429866791
idx: 100, _error: 0.08035293966531754
idx: 200, _error: 0.09596671909093857
accuracy: 0.11
idx: 0, _error: 0.07819026708602905
idx: 100, _error: 0.07314570248126984
idx: 200, _error: 0.09197083860635757
accuracy: 0.11
idx: 0, _error: 0.0687