<a href="https://colab.research.google.com/github/akshaya-nagarajan/DeepLearningProjects/blob/master/Assignment_4/DLAssignment4_LeNet_PyTorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Import Required Libraries

In [0]:
from torch.nn import Module
from torch import nn
import numpy as np
import torch
from torchvision.datasets import mnist
from torch.nn import CrossEntropyLoss
from torch.optim import SGD
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor

## Create the LeNet 5 Model

In [0]:
class Model(Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(256, 120)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(120, 84)
        self.relu4 = nn.ReLU()
        self.fc3 = nn.Linear(84, 10)
        self.relu5 = nn.ReLU()

    def forward(self, x):
        y = self.conv1(x)
        y = self.relu1(y)
        y = self.pool1(y)
        y = self.conv2(y)
        y = self.relu2(y)
        y = self.pool2(y)
        y = y.view(y.shape[0], -1)
        y = self.fc1(y)
        y = self.relu3(y)
        y = self.fc2(y)
        y = self.relu4(y)
        y = self.fc3(y)
        y = self.relu5(y)
        return y

## Download the MNIST Dataset, Train the model, Predict with Test Dataset

In [0]:
model = Model()

In [0]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [0]:
use_gpu = torch.cuda.is_available()
if use_gpu:
    model = model.to(device)

In [0]:
batch_size = 256
train_dataset = mnist.MNIST(root='./train', train=True, transform=ToTensor(), download=True)
test_dataset = mnist.MNIST(root='./test', train=False, transform=ToTensor(), download=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

sgd = SGD(model.parameters(), lr=1e-1)
cross_error = CrossEntropyLoss()
epoch = 100

for _epoch in range(epoch):
    for idx, data in enumerate(train_loader):
        train_x, train_label = data[0].to(device), data[1].to(device)
        label_np = np.zeros((train_label.shape[0], 10))
        sgd.zero_grad()
        predict_y = model(train_x.float())
        _error = cross_error(predict_y, train_label.long())
        if idx % 100 == 0:
            print('idx: {}, _error: {}'.format(idx, _error))
        _error.backward()
        sgd.step()

    correct = 0
    _sum = 0

    for idx, data in enumerate(test_loader):
        test_x, test_label = data[0].to(device), data[1].to(device)
        predict_y = model(test_x.float()).detach()
        predict_ys = predict_y.cpu().data.numpy().argmax(axis = -1) #np.argmax(predict_y.cuda(), axis=-1) #predict.cpu().numpy() predict_y.cpu().data.numpy().argmax()
        #print('Predicted',predict_y)
        label_np = test_label.cpu().data.numpy() #test_label.numpy()
        #print('Actual Label', label_np.argmax())
        _ = predict_ys[0] == label_np
        correct += np.sum(_, axis=-1)
        _sum += _.shape[0]

    print('accuracy: {:.2f}'.format(correct / _sum))

idx: 0, _error: 2.3025848865509033
idx: 100, _error: 2.3025848865509033
idx: 200, _error: 2.3025848865509033
accuracy: 0.10
idx: 0, _error: 2.3025848865509033
idx: 100, _error: 2.3025848865509033
idx: 200, _error: 2.3025848865509033
accuracy: 0.10
idx: 0, _error: 2.3025848865509033
idx: 100, _error: 2.3025848865509033
idx: 200, _error: 2.3025848865509033
accuracy: 0.10
idx: 0, _error: 2.3025848865509033
idx: 100, _error: 2.3025848865509033
idx: 200, _error: 2.3025848865509033
accuracy: 0.10
idx: 0, _error: 2.3025848865509033
idx: 100, _error: 2.3025848865509033
idx: 200, _error: 2.3025848865509033
accuracy: 0.10
idx: 0, _error: 2.3025848865509033
idx: 100, _error: 2.3025848865509033
idx: 200, _error: 2.3025848865509033
accuracy: 0.10
idx: 0, _error: 2.3025848865509033
idx: 100, _error: 2.3025848865509033
idx: 200, _error: 2.3025848865509033
accuracy: 0.10
idx: 0, _error: 2.3025848865509033
idx: 100, _error: 2.3025848865509033
idx: 200, _error: 2.3025848865509033
accuracy: 0.10
idx: 0, 