<a href="https://colab.research.google.com/github/AvinaashAnandK/DeepLearningJourney/blob/main/%5BDeep_Learning_Assignments%5D_%5BCS198%5D_Assignment_1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Using the MNIST dataset, achieve 90% accuracy with a fully connected network.
1. Epochs: 1
2. Use Nesterov momentum in optim.SGD

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader

import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

In [2]:
# @title Util to display a sample image
def display_input_image(input,input_type="np"):
  if type not in ["np","tensor"]:
    img_as_np = plt.imread(input)
    processed_type = "np"
  else:
    processed_type = input_type

  print("Displaying image:")

  if processed_type == "np":
    plt.imshow(img_as_np)
  else:
    plt.imshow(input.permute(1,2,0))

  plt.show()

In [3]:
# @title Downloading the dataset - MNIST

mnist_train = datasets.MNIST(root = "./data", train = True, download = True)
mnist_test = datasets.MNIST(root = "./data", train = False, download = True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:02<00:00, 4206853.15it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 65690.10it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:06<00:00, 245979.30it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 10269826.83it/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



In [4]:
# Accessing dataset information
print(f"Shape of the dataset: {mnist_train.data.shape}")
print(f"Shape of the labels: {mnist_train.targets.shape}")
print(f"Classes: {mnist_train.classes}")

Shape of the dataset: torch.Size([60000, 28, 28])
Shape of the labels: torch.Size([60000])
Classes: ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four', '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']


In [5]:
# Creating loaders for train, validation (for hyperparam training) and testing
transform = T.Compose([
    T.ToTensor(),
    T.Normalize((0.5, ), (0.5, ))
])

mnist_train = datasets.MNIST(root = "./data", train = True, download = True, transform = transform)
mnist_test = datasets.MNIST(root = "./data", train = False, download = True, transform = transform)

In [6]:
model_params = {
    "hidden_layer_dim": 4000,
    "num_layers": 5,
    "classification_labels": 10,
    "epochs": 1,
    "learning_rate": 1e-2,
    "batch_size": 32,
    "test_val_split": 0.1,
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "train_every": 100,
    "nesterov": True,
    "momentum": 0.9
}

model_params['input_dims'] = np.prod(mnist_train.data.shape[1:])

# Change model_params
batch_size = 64
model_params['batch_size'] = batch_size

In [7]:
DATASET_SIZE = mnist_train.data.shape[0]
NUM_TRAIN = int(DATASET_SIZE *(1-model_params['test_val_split']))
NUM_VAL = DATASET_SIZE - NUM_TRAIN
print(f"MNIST Dataset has {DATASET_SIZE} images for training, Training Set Length: {NUM_TRAIN} after validation set of {NUM_VAL} images.")

MNIST Dataset has 60000 images for training, Training Set Length: 54000 after validation set of 6000 images.


In [8]:
train_set, val_set = torch.utils.data.random_split(mnist_train, [NUM_TRAIN, NUM_VAL])
train = DataLoader(train_set, batch_size = model_params['batch_size'], shuffle = True)
val = DataLoader(val_set, batch_size = model_params['batch_size'], shuffle = True)
test = DataLoader(mnist_test, batch_size = model_params['batch_size'], shuffle = True)

In [9]:
def flatten(x):
  batch = x.shape[0]
  return x.view(batch, -1)

class FiveLayerFC(nn.Module):
  def __init__(self, params):
    super().__init__()
    self.num_layers = params['num_layers']
    self.output_size = params['classification_labels']
    self.hidden_size = params['hidden_layer_dim']
    self.input_size = params['input_dims']

    self.layers = nn.ModuleList()

    self.layers.append(nn.Linear(self.input_size, self.hidden_size))
    self.layers.append(nn.ReLU())

    for _ in range(self.num_layers-1):
      self.layers.append(nn.Linear(self.hidden_size,self.hidden_size))
      self.layers.append(nn.ReLU())

    self.layers.append(nn.Linear(self.hidden_size, self.output_size))

  def forward(self, x):
    x = flatten(x)
    for layer in self.layers:
      x = layer(x)
    return x

def check_accuracy(loader, model, model_params):
  device = model_params['device']
  pred_correct, total_count = 0, 0

  model.eval()
  with torch.no_grad():
    for batch, (inputs, labels) in tqdm(enumerate(loader), total=len(loader), desc="Checking Accuracy"):
      inputs, labels = inputs.to(device), labels.to(device)
      preds = model(inputs)
      _, pred_labels = preds.max(1)
      pred_correct += (pred_labels == labels).sum()
      total_count += preds.shape[0]

  accuracy = float(pred_correct)/float(total_count)
  print(f"Accuracy: {accuracy*100}%")

  return accuracy

def train_model(loader_train, loader_val, model, optimizer, model_params):
  device = model_params['device']
  epochs = model_params['epochs']
  train_every = model_params['train_every']

  lossfn = nn.CrossEntropyLoss()

  train_acc, val_acc = [], []

  model.to(device)


  for epoch in tqdm(range(epochs),desc="Epochs"):
    model.train()
    for batch, (inputs, labels) in tqdm(enumerate(loader_train),total=len(loader_train),desc=f"Batches in epoch {epoch}"):
      inputs, labels = inputs.to(device), labels.to(device)
      preds = model(inputs)

      loss = lossfn(preds, labels)

      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      if batch % train_every == 0:
        train_acc.append(check_accuracy(loader_train, model, model_params))

    val_acc.append(check_accuracy(loader_val, model, model_params))
  return train_acc, val_acc

In [20]:
model_params['learning_rate']

0.01

In [28]:
if torch.cuda.is_available():
  model_params['device'] = torch.device("cuda")
else:
  model_params['device'] = torch.device("cpu")

print(f"Using device: {model_params['device']}")
model = FiveLayerFC(model_params)

optimizer = optim.SGD(model.parameters(), lr=model_params['learning_rate'], momentum=model_params['momentum'])
# optimizer = optim.Adam(model.parameters(), lr=0.001)

Using device: cuda


In [29]:
train_acc, val_acc = train_model(train, val, model, optimizer, model_params)

Epochs:   0%|          | 0/1 [00:00<?, ?it/s]
Batches in epoch 0:   0%|          | 0/844 [00:00<?, ?it/s][A

Checking Accuracy:   0%|          | 0/844 [00:00<?, ?it/s][A[A

Checking Accuracy:   1%|          | 8/844 [00:00<00:11, 73.52it/s][A[A

Checking Accuracy:   2%|▏         | 16/844 [00:00<00:11, 69.50it/s][A[A

Checking Accuracy:   3%|▎         | 23/844 [00:00<00:12, 66.34it/s][A[A

Checking Accuracy:   4%|▎         | 30/844 [00:00<00:12, 64.18it/s][A[A

Checking Accuracy:   4%|▍         | 37/844 [00:00<00:13, 60.46it/s][A[A

Checking Accuracy:   5%|▌         | 44/844 [00:00<00:12, 62.40it/s][A[A

Checking Accuracy:   6%|▌         | 51/844 [00:00<00:12, 64.07it/s][A[A

Checking Accuracy:   7%|▋         | 58/844 [00:00<00:11, 65.62it/s][A[A

Checking Accuracy:   8%|▊         | 65/844 [00:00<00:11, 65.99it/s][A[A

Checking Accuracy:   9%|▊         | 72/844 [00:01<00:11, 66.94it/s][A[A

Checking Accuracy:   9%|▉         | 79/844 [00:01<00:11, 66.85it/s][A[A



Accuracy: 8.88888888888889%



Batches in epoch 0:   2%|▏         | 13/844 [00:14<09:09,  1.51it/s][A
Batches in epoch 0:   2%|▏         | 19/844 [00:14<05:09,  2.67it/s][A
Batches in epoch 0:   3%|▎         | 25/844 [00:14<03:12,  4.24it/s][A
Batches in epoch 0:   4%|▎         | 31/844 [00:14<02:09,  6.29it/s][A
Batches in epoch 0:   4%|▍         | 37/844 [00:14<01:29,  8.99it/s][A
Batches in epoch 0:   5%|▌         | 43/844 [00:14<01:04, 12.42it/s][A
Batches in epoch 0:   6%|▌         | 50/844 [00:14<00:46, 17.24it/s][A
Batches in epoch 0:   7%|▋         | 56/844 [00:15<00:36, 21.75it/s][A
Batches in epoch 0:   7%|▋         | 63/844 [00:15<00:28, 27.73it/s][A
Batches in epoch 0:   8%|▊         | 69/844 [00:15<00:23, 32.68it/s][A
Batches in epoch 0:   9%|▉         | 75/844 [00:15<00:20, 37.51it/s][A
Batches in epoch 0:  10%|▉         | 81/844 [00:15<00:18, 41.63it/s][A
Batches in epoch 0:  10%|█         | 87/844 [00:15<00:17, 43.89it/s][A
Batches in epoch 0:  11%|█         | 93/844 [00:15<00:16, 46.67

Accuracy: 59.57222222222222%



Batches in epoch 0:  13%|█▎        | 111/844 [00:29<06:12,  1.97it/s][A
Batches in epoch 0:  14%|█▍        | 117/844 [00:29<04:21,  2.78it/s][A
Batches in epoch 0:  15%|█▍        | 123/844 [00:29<03:04,  3.92it/s][A
Batches in epoch 0:  15%|█▌        | 130/844 [00:29<02:05,  5.71it/s][A
Batches in epoch 0:  16%|█▌        | 137/844 [00:30<01:27,  8.08it/s][A
Batches in epoch 0:  17%|█▋        | 143/844 [00:30<01:05, 10.72it/s][A
Batches in epoch 0:  18%|█▊        | 149/844 [00:30<00:50, 13.71it/s][A
Batches in epoch 0:  18%|█▊        | 155/844 [00:30<00:41, 16.75it/s][A
Batches in epoch 0:  19%|█▉        | 160/844 [00:30<00:35, 19.24it/s][A
Batches in epoch 0:  20%|█▉        | 165/844 [00:30<00:30, 22.62it/s][A
Batches in epoch 0:  20%|██        | 170/844 [00:30<00:25, 26.33it/s][A
Batches in epoch 0:  21%|██        | 175/844 [00:30<00:22, 29.70it/s][A
Batches in epoch 0:  21%|██▏       | 180/844 [00:31<00:19, 33.27it/s][A
Batches in epoch 0:  22%|██▏       | 185/844 [00:3

Accuracy: 83.7462962962963%



Batches in epoch 0:  26%|██▌       | 216/844 [00:48<05:09,  2.03it/s][A
Batches in epoch 0:  26%|██▋       | 222/844 [00:49<03:30,  2.96it/s][A
Batches in epoch 0:  27%|██▋       | 229/844 [00:49<02:18,  4.45it/s][A
Batches in epoch 0:  28%|██▊       | 235/844 [00:49<01:39,  6.13it/s][A
Batches in epoch 0:  29%|██▊       | 241/844 [00:49<01:11,  8.40it/s][A
Batches in epoch 0:  29%|██▉       | 247/844 [00:49<00:52, 11.29it/s][A
Batches in epoch 0:  30%|██▉       | 253/844 [00:49<00:40, 14.63it/s][A
Batches in epoch 0:  31%|███       | 259/844 [00:49<00:31, 18.84it/s][A
Batches in epoch 0:  31%|███▏      | 265/844 [00:49<00:24, 23.73it/s][A
Batches in epoch 0:  32%|███▏      | 272/844 [00:49<00:19, 29.79it/s][A
Batches in epoch 0:  33%|███▎      | 278/844 [00:50<00:16, 34.82it/s][A
Batches in epoch 0:  34%|███▎      | 284/844 [00:50<00:14, 39.67it/s][A
Batches in epoch 0:  34%|███▍      | 290/844 [00:50<00:12, 43.10it/s][A
Batches in epoch 0:  35%|███▌      | 297/844 [00:5

Accuracy: 89.16481481481482%



Batches in epoch 0:  37%|███▋      | 312/844 [01:04<03:25,  2.59it/s][A
Batches in epoch 0:  38%|███▊      | 318/844 [01:04<02:22,  3.69it/s][A
Batches in epoch 0:  38%|███▊      | 324/844 [01:04<01:40,  5.19it/s][A
Batches in epoch 0:  39%|███▉      | 330/844 [01:04<01:11,  7.20it/s][A
Batches in epoch 0:  40%|███▉      | 336/844 [01:05<00:51,  9.83it/s][A
Batches in epoch 0:  41%|████      | 343/844 [01:05<00:36, 13.70it/s][A
Batches in epoch 0:  41%|████▏     | 349/844 [01:05<00:28, 17.68it/s][A
Batches in epoch 0:  42%|████▏     | 355/844 [01:05<00:21, 22.25it/s][A
Batches in epoch 0:  43%|████▎     | 361/844 [01:05<00:17, 26.92it/s][A
Batches in epoch 0:  43%|████▎     | 367/844 [01:05<00:15, 31.16it/s][A
Batches in epoch 0:  44%|████▍     | 373/844 [01:05<00:13, 35.83it/s][A
Batches in epoch 0:  45%|████▍     | 379/844 [01:05<00:11, 40.61it/s][A
Batches in epoch 0:  46%|████▌     | 385/844 [01:05<00:10, 44.43it/s][A
Batches in epoch 0:  46%|████▋     | 391/844 [01:0

Accuracy: 89.14814814814814%



Batches in epoch 0:  49%|████▉     | 414/844 [01:20<02:40,  2.67it/s][A
Batches in epoch 0:  50%|████▉     | 420/844 [01:20<01:52,  3.77it/s][A
Batches in epoch 0:  50%|█████     | 426/844 [01:20<01:19,  5.26it/s][A
Batches in epoch 0:  51%|█████     | 432/844 [01:20<00:56,  7.27it/s][A
Batches in epoch 0:  52%|█████▏    | 438/844 [01:20<00:42,  9.65it/s][A
Batches in epoch 0:  53%|█████▎    | 444/844 [01:21<00:31, 12.61it/s][A
Batches in epoch 0:  53%|█████▎    | 449/844 [01:21<00:25, 15.48it/s][A
Batches in epoch 0:  54%|█████▍    | 454/844 [01:21<00:20, 18.73it/s][A
Batches in epoch 0:  54%|█████▍    | 459/844 [01:21<00:17, 22.37it/s][A
Batches in epoch 0:  55%|█████▍    | 464/844 [01:21<00:14, 26.23it/s][A
Batches in epoch 0:  56%|█████▌    | 469/844 [01:21<00:12, 29.63it/s][A
Batches in epoch 0:  56%|█████▌    | 474/844 [01:21<00:11, 32.46it/s][A
Batches in epoch 0:  57%|█████▋    | 479/844 [01:21<00:10, 36.06it/s][A
Batches in epoch 0:  57%|█████▋    | 484/844 [01:2

Accuracy: 92.02222222222223%



Batches in epoch 0:  61%|██████    | 515/844 [01:36<02:19,  2.36it/s][A
Batches in epoch 0:  62%|██████▏   | 522/844 [01:37<01:28,  3.63it/s][A
Batches in epoch 0:  63%|██████▎   | 528/844 [01:37<01:01,  5.11it/s][A
Batches in epoch 0:  63%|██████▎   | 534/844 [01:37<00:43,  7.10it/s][A
Batches in epoch 0:  64%|██████▍   | 540/844 [01:37<00:31,  9.65it/s][A
Batches in epoch 0:  65%|██████▍   | 547/844 [01:37<00:22, 13.47it/s][A
Batches in epoch 0:  66%|██████▌   | 553/844 [01:37<00:17, 17.02it/s][A
Batches in epoch 0:  66%|██████▌   | 559/844 [01:37<00:13, 21.31it/s][A
Batches in epoch 0:  67%|██████▋   | 565/844 [01:37<00:10, 26.24it/s][A
Batches in epoch 0:  68%|██████▊   | 571/844 [01:37<00:08, 31.11it/s][A
Batches in epoch 0:  68%|██████▊   | 577/844 [01:38<00:07, 35.82it/s][A
Batches in epoch 0:  69%|██████▉   | 583/844 [01:38<00:06, 40.58it/s][A
Batches in epoch 0:  70%|██████▉   | 589/844 [01:38<00:05, 44.19it/s][A
Batches in epoch 0:  70%|███████   | 595/844 [01:3

Accuracy: 93.49074074074075%



Batches in epoch 0:  73%|███████▎  | 612/844 [01:52<01:26,  2.68it/s][A
Batches in epoch 0:  73%|███████▎  | 618/844 [01:52<00:59,  3.78it/s][A
Batches in epoch 0:  74%|███████▍  | 624/844 [01:52<00:41,  5.29it/s][A
Batches in epoch 0:  75%|███████▍  | 630/844 [01:52<00:29,  7.31it/s][A
Batches in epoch 0:  75%|███████▌  | 636/844 [01:52<00:20,  9.95it/s][A
Batches in epoch 0:  76%|███████▌  | 642/844 [01:53<00:15, 13.29it/s][A
Batches in epoch 0:  77%|███████▋  | 648/844 [01:53<00:11, 17.34it/s][A
Batches in epoch 0:  77%|███████▋  | 654/844 [01:53<00:08, 21.98it/s][A
Batches in epoch 0:  78%|███████▊  | 660/844 [01:53<00:06, 27.03it/s][A
Batches in epoch 0:  79%|███████▉  | 666/844 [01:53<00:05, 31.43it/s][A
Batches in epoch 0:  80%|███████▉  | 672/844 [01:53<00:04, 34.70it/s][A
Batches in epoch 0:  80%|████████  | 678/844 [01:53<00:04, 39.21it/s][A
Batches in epoch 0:  81%|████████  | 684/844 [01:53<00:03, 43.31it/s][A
Batches in epoch 0:  82%|████████▏ | 690/844 [01:5

Accuracy: 93.73518518518519%



Batches in epoch 0:  85%|████████▍ | 714/844 [02:08<00:47,  2.76it/s][A
Batches in epoch 0:  85%|████████▌ | 720/844 [02:08<00:32,  3.85it/s][A
Batches in epoch 0:  86%|████████▌ | 726/844 [02:08<00:22,  5.32it/s][A
Batches in epoch 0:  87%|████████▋ | 732/844 [02:08<00:15,  7.32it/s][A
Batches in epoch 0:  87%|████████▋ | 738/844 [02:08<00:10,  9.93it/s][A
Batches in epoch 0:  88%|████████▊ | 744/844 [02:08<00:07, 13.24it/s][A
Batches in epoch 0:  89%|████████▉ | 750/844 [02:09<00:05, 17.12it/s][A
Batches in epoch 0:  90%|████████▉ | 756/844 [02:09<00:04, 21.65it/s][A
Batches in epoch 0:  90%|█████████ | 762/844 [02:09<00:03, 26.49it/s][A
Batches in epoch 0:  91%|█████████ | 768/844 [02:09<00:02, 31.74it/s][A
Batches in epoch 0:  92%|█████████▏| 774/844 [02:09<00:01, 36.66it/s][A
Batches in epoch 0:  92%|█████████▏| 780/844 [02:09<00:01, 38.95it/s][A
Batches in epoch 0:  93%|█████████▎| 786/844 [02:09<00:01, 43.07it/s][A
Batches in epoch 0:  94%|█████████▍| 792/844 [02:0

Accuracy: 93.85555555555555%



Batches in epoch 0:  96%|█████████▌| 810/844 [02:24<00:15,  2.15it/s][A
Batches in epoch 0:  97%|█████████▋| 815/844 [02:24<00:09,  3.03it/s][A
Batches in epoch 0:  97%|█████████▋| 820/844 [02:25<00:05,  4.24it/s][A
Batches in epoch 0:  98%|█████████▊| 826/844 [02:25<00:02,  6.19it/s][A
Batches in epoch 0:  98%|█████████▊| 831/844 [02:25<00:01,  8.29it/s][A
Batches in epoch 0:  99%|█████████▉| 836/844 [02:25<00:00, 10.97it/s][A
Batches in epoch 0: 100%|██████████| 844/844 [02:25<00:00,  5.80it/s]

Checking Accuracy:   0%|          | 0/94 [00:00<?, ?it/s][A
Checking Accuracy:   4%|▍         | 4/94 [00:00<00:02, 31.98it/s][A
Checking Accuracy:   9%|▊         | 8/94 [00:00<00:02, 35.64it/s][A
Checking Accuracy:  14%|█▍        | 13/94 [00:00<00:02, 39.67it/s][A
Checking Accuracy:  19%|█▉        | 18/94 [00:00<00:01, 41.42it/s][A
Checking Accuracy:  24%|██▍       | 23/94 [00:00<00:01, 40.40it/s][A
Checking Accuracy:  30%|██▉       | 28/94 [00:00<00:01, 40.35it/s][A
Checking Ac

Accuracy: 94.38333333333333%





In [30]:
model_params

{'hidden_layer_dim': 4000,
 'num_layers': 5,
 'classification_labels': 10,
 'epochs': 1,
 'learning_rate': 0.01,
 'batch_size': 64,
 'test_val_split': 0.1,
 'device': device(type='cuda'),
 'train_every': 100,
 'nesterov': True,
 'momentum': 0.9,
 'input_dims': 784}

In [31]:
train_acc

[0.08888888888888889,
 0.5957222222222223,
 0.837462962962963,
 0.8916481481481482,
 0.8914814814814814,
 0.9202222222222223,
 0.9349074074074074,
 0.9373518518518519,
 0.9385555555555556]

In [32]:
val_acc

[0.9438333333333333]