# Pytorch Tutorial

Pytorch is a popular deep learning framework and it's easy to get started.

In [6]:
import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm
import torch.optim as optim
import time
import torch.nn.functional as F
BATCH_SIZE = 128
NUM_EPOCHS = 15

First, we read the mnist data, preprocess them and encapsulate them into dataloader form.

In [7]:
# preprocessing
normalize = transforms.Normalize(mean=[.5], std=[.5])
transform = transforms.Compose([transforms.ToTensor(), normalize])

# download and load the data
train_dataset = torchvision.datasets.MNIST(root='./mnist/', train=True, transform=transform, download=True)
test_dataset = torchvision.datasets.MNIST(root='./mnist/', train=False, transform=transform, download=False)

# encapsulate them into dataloader form
train_loader = data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
test_loader = data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)

In [8]:
for images, labels in tqdm(train_loader):
    print(images.shape)


  0%|          | 0/468 [00:00<?, ?it/s][A
  1%|▏         | 6/468 [00:00<00:07, 57.83it/s][A


torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])


  3%|▎         | 13/468 [00:00<00:07, 58.89it/s][A
  4%|▍         | 20/468 [00:00<00:07, 59.37it/s][A

torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])



  6%|▌         | 26/468 [00:00<00:07, 59.24it/s][A
  7%|▋         | 33/468 [00:00<00:07, 59.92it/s][A

torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])


  9%|▊         | 40/468 [00:00<00:07, 60.25it/s][A
 10%|▉         | 46/468 [00:00<00:07, 60.12it/s][A


torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])


 11%|█         | 52/468 [00:00<00:06, 59.77it/s][A
 13%|█▎        | 59/468 [00:00<00:06, 60.30it/s][A


torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])


 14%|█▍        | 65/468 [00:01<00:06, 59.90it/s][A


torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])



 15%|█▌        | 72/468 [00:01<00:06, 59.93it/s][A
 17%|█▋        | 78/468 [00:01<00:06, 59.64it/s][A

torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])



 18%|█▊        | 84/468 [00:01<00:06, 59.64it/s][A
 19%|█▉        | 90/468 [00:01<00:06, 59.44it/s][A

torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])


 21%|██        | 96/468 [00:01<00:06, 59.48it/s][A
 22%|██▏       | 103/468 [00:01<00:06, 60.09it/s][A


torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])



 23%|██▎       | 109/468 [00:01<00:05, 60.01it/s][A
 25%|██▍       | 116/468 [00:01<00:05, 60.16it/s][A

torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])


 26%|██▋       | 123/468 [00:02<00:05, 59.95it/s][A
 28%|██▊       | 129/468 [00:02<00:05, 59.83it/s][A


torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])


 29%|██▉       | 136/468 [00:02<00:05, 60.25it/s][A


torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])



 31%|███       | 143/468 [00:02<00:05, 60.09it/s][A
 32%|███▏      | 150/468 [00:02<00:05, 59.82it/s][A

torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])



 33%|███▎      | 156/468 [00:02<00:05, 59.74it/s][A
 35%|███▍      | 163/468 [00:02<00:05, 60.28it/s][A

torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])



 36%|███▋      | 170/468 [00:02<00:04, 60.66it/s][A
 38%|███▊      | 177/468 [00:02<00:04, 61.09it/s][A

torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])


 39%|███▉      | 184/468 [00:03<00:04, 60.60it/s][A
 41%|████      | 191/468 [00:03<00:04, 60.73it/s][A


torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])



 42%|████▏     | 198/468 [00:03<00:04, 60.20it/s][A

torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])



 44%|████▍     | 205/468 [00:03<00:04, 59.59it/s][A
 45%|████▌     | 211/468 [00:03<00:04, 59.57it/s][A

torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])


 46%|████▋     | 217/468 [00:03<00:04, 59.22it/s][A
 48%|████▊     | 223/468 [00:03<00:04, 58.63it/s][A


torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])


 49%|████▉     | 229/468 [00:03<00:04, 58.39it/s][A
 50%|█████     | 235/468 [00:03<00:03, 58.30it/s][A


torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])


 52%|█████▏    | 242/468 [00:04<00:03, 59.04it/s][A



torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])


 53%|█████▎    | 249/468 [00:04<00:03, 59.61it/s][A
 54%|█████▍    | 255/468 [00:04<00:03, 59.58it/s][A

torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])



 56%|█████▌    | 262/468 [00:04<00:03, 59.85it/s][A
 57%|█████▋    | 269/468 [00:04<00:03, 60.36it/s][A

torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])



 59%|█████▉    | 276/468 [00:04<00:03, 60.56it/s][A
 60%|██████    | 283/468 [00:04<00:03, 60.23it/s][A

torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])


 62%|██████▏   | 290/468 [00:04<00:02, 60.46it/s][A


torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])


 63%|██████▎   | 297/468 [00:04<00:02, 60.39it/s][A
 65%|██████▍   | 304/468 [00:05<00:02, 60.42it/s][A


torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])



 66%|██████▋   | 311/468 [00:05<00:02, 60.29it/s][A
 68%|██████▊   | 318/468 [00:05<00:02, 60.36it/s][A

torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])


 69%|██████▉   | 325/468 [00:05<00:02, 59.78it/s][A



torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])


 71%|███████   | 332/468 [00:05<00:02, 59.99it/s][A
 72%|███████▏  | 339/468 [00:05<00:02, 59.76it/s][A

torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])



 74%|███████▎  | 345/468 [00:05<00:02, 59.25it/s][A
 75%|███████▌  | 351/468 [00:05<00:01, 58.99it/s][A

torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])



 76%|███████▋  | 357/468 [00:05<00:01, 58.47it/s][A
 78%|███████▊  | 364/468 [00:06<00:01, 59.37it/s][A

torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])



 79%|███████▉  | 371/468 [00:06<00:01, 59.86it/s][A
 81%|████████  | 378/468 [00:06<00:01, 59.90it/s][A

torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])


 82%|████████▏ | 384/468 [00:06<00:01, 59.79it/s][A
 83%|████████▎ | 390/468 [00:06<00:01, 59.73it/s][A


torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])



 85%|████████▍ | 397/468 [00:06<00:01, 59.80it/s][A

torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])



 86%|████████▋ | 404/468 [00:06<00:01, 60.01it/s][A
 88%|████████▊ | 411/468 [00:06<00:00, 59.54it/s][A

torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])


 89%|████████▉ | 418/468 [00:06<00:00, 59.74it/s][A
 91%|█████████ | 424/468 [00:07<00:00, 59.77it/s][A


torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])



 92%|█████████▏| 431/468 [00:07<00:00, 60.06it/s][A
 94%|█████████▎| 438/468 [00:07<00:00, 60.42it/s][A

torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])



 95%|█████████▌| 445/468 [00:07<00:00, 60.60it/s][A


torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])


 97%|█████████▋| 452/468 [00:07<00:00, 60.80it/s][A
 98%|█████████▊| 459/468 [00:07<00:00, 60.40it/s][A

torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])



100%|██████████| 468/468 [00:07<00:00, 59.96it/s][A

torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 28])





Then, we define the model, object function and optimizer that we use to classify.

In [9]:
class SimpleNet(nn.Module):
# TODO:define model
    def __init__(self, in_channel, channel_1, channel_2, num_classes):
        super().__init__()
        self.conv1=torch.nn.Conv2d(in_channel,channel_1 , kernel_size=5, stride=2, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros')
        nn.init.kaiming_normal_(self.conv1.weight)
        
        self.conv2=torch.nn.Conv2d(channel_1,channel_2 ,kernel_size=(3,3) , stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros')
        nn.init.kaiming_normal_(self.conv2.weight)
       
        self.fc=nn.Linear(channel_2*100, num_classes)
        nn.init.kaiming_normal_(self.fc.weight)

    def forward(self,x):
        scores=None
        out_1 = F.relu(self.conv1(x))
        out_2 = F.relu(self.conv2(out_1))
        
        N = out_2.shape[0]
        out_flat=out_2.view(N,-1)
        
        scores = self.fc(out_flat)
        return scores
        
model = SimpleNet(in_channel=1, channel_1=32, channel_2=16, num_classes=10)

# TODO:define loss function and optimiter
#criterion = F.cross_entropy()
optimizer = optim.SGD(model.parameters(), lr=3e-3)

Next, we can start to train and evaluate!

In [None]:
# train and evaluate
for epoch in range(NUM_EPOCHS):
    for images, labels in tqdm(train_loader):
        # TODO:forward + backward + optimize
  
        model.train()


        scores=model(images)
        #print(scores)
        loss= F.cross_entropy(scores, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        
    # evaluate
    # TODO:calculate the accuracy using traning and testing dataset
        



  0%|          | 0/468 [00:00<?, ?it/s][A
  0%|          | 2/468 [00:00<00:27, 17.14it/s][A
  1%|          | 4/468 [00:00<00:28, 16.47it/s][A
  1%|▏         | 6/468 [00:00<00:28, 16.14it/s][A
  2%|▏         | 8/468 [00:00<00:29, 15.60it/s][A
  2%|▏         | 10/468 [00:00<00:29, 15.51it/s][A
  3%|▎         | 12/468 [00:00<00:29, 15.26it/s][A
  3%|▎         | 14/468 [00:00<00:29, 15.23it/s][A
  3%|▎         | 16/468 [00:01<00:30, 14.93it/s][A
  4%|▍         | 18/468 [00:01<00:29, 15.07it/s][A
  4%|▍         | 20/468 [00:01<00:29, 15.16it/s][A
  5%|▍         | 22/468 [00:01<00:29, 14.95it/s][A
  5%|▌         | 24/468 [00:01<00:29, 14.87it/s][A
  6%|▌         | 26/468 [00:01<00:29, 15.03it/s][A
  6%|▌         | 28/468 [00:01<00:29, 14.89it/s][A
  6%|▋         | 30/468 [00:01<00:29, 14.80it/s][A
  7%|▋         | 32/468 [00:02<00:28, 15.10it/s][A
  7%|▋         | 34/468 [00:02<00:28, 14.99it/s][A
  8%|▊         | 36/468 [00:02<00:28, 15.07it/s][A
  8%|▊         | 38/468 

 66%|██████▌   | 310/468 [00:22<00:12, 12.90it/s][A
 67%|██████▋   | 312/468 [00:22<00:12, 13.00it/s][A
 67%|██████▋   | 314/468 [00:23<00:11, 12.89it/s][A
 68%|██████▊   | 316/468 [00:23<00:12, 12.46it/s][A
 68%|██████▊   | 318/468 [00:23<00:12, 12.22it/s][A
 68%|██████▊   | 320/468 [00:23<00:12, 11.82it/s][A
 69%|██████▉   | 322/468 [00:23<00:12, 12.05it/s][A
 69%|██████▉   | 324/468 [00:23<00:11, 12.43it/s][A
 70%|██████▉   | 326/468 [00:24<00:11, 12.64it/s][A
 70%|███████   | 328/468 [00:24<00:10, 12.79it/s][A
 71%|███████   | 330/468 [00:24<00:10, 12.90it/s][A
 71%|███████   | 332/468 [00:24<00:10, 12.88it/s][A
 71%|███████▏  | 334/468 [00:24<00:10, 12.93it/s][A
 72%|███████▏  | 336/468 [00:24<00:10, 12.80it/s][A
 72%|███████▏  | 338/468 [00:24<00:10, 12.78it/s][A
 73%|███████▎  | 340/468 [00:25<00:09, 12.87it/s][A
 73%|███████▎  | 342/468 [00:25<00:09, 12.96it/s][A
 74%|███████▎  | 344/468 [00:25<00:09, 13.09it/s][A
 74%|███████▍  | 346/468 [00:25<00:09, 13.29it

 32%|███▏      | 150/468 [00:11<00:23, 13.60it/s][A
 32%|███▏      | 152/468 [00:11<00:22, 13.74it/s][A
 33%|███▎      | 154/468 [00:11<00:22, 13.72it/s][A
 33%|███▎      | 156/468 [00:11<00:22, 13.74it/s][A
 34%|███▍      | 158/468 [00:11<00:22, 13.52it/s][A
 34%|███▍      | 160/468 [00:12<00:23, 13.38it/s][A
 35%|███▍      | 162/468 [00:12<00:22, 13.55it/s][A
 35%|███▌      | 164/468 [00:12<00:22, 13.51it/s][A
 35%|███▌      | 166/468 [00:12<00:22, 13.48it/s][A
 36%|███▌      | 168/468 [00:12<00:22, 13.51it/s][A
 36%|███▋      | 170/468 [00:12<00:22, 13.49it/s][A
 37%|███▋      | 172/468 [00:12<00:22, 13.26it/s][A
 37%|███▋      | 174/468 [00:13<00:22, 13.22it/s][A
 38%|███▊      | 176/468 [00:13<00:22, 13.14it/s][A
 38%|███▊      | 178/468 [00:13<00:22, 12.99it/s][A
 38%|███▊      | 180/468 [00:13<00:23, 12.32it/s][A
 39%|███▉      | 182/468 [00:13<00:23, 12.18it/s][A
 39%|███▉      | 184/468 [00:13<00:23, 12.06it/s][A
 40%|███▉      | 186/468 [00:14<00:23, 12.02it

 98%|█████████▊| 458/468 [00:35<00:00, 12.35it/s][A
 98%|█████████▊| 460/468 [00:35<00:00, 11.70it/s][A
 99%|█████████▊| 462/468 [00:35<00:00, 11.77it/s][A
 99%|█████████▉| 464/468 [00:35<00:00, 11.91it/s][A
100%|█████████▉| 466/468 [00:36<00:00, 11.89it/s][A
100%|██████████| 468/468 [00:36<00:00, 12.90it/s][A

  0%|          | 0/468 [00:00<?, ?it/s][A
  0%|          | 2/468 [00:00<00:34, 13.59it/s][A
  1%|          | 4/468 [00:00<00:34, 13.35it/s][A
  1%|▏         | 6/468 [00:00<00:35, 13.12it/s][A
  2%|▏         | 8/468 [00:00<00:35, 13.10it/s][A
  2%|▏         | 10/468 [00:00<00:34, 13.19it/s][A
  3%|▎         | 12/468 [00:00<00:34, 13.22it/s][A
  3%|▎         | 14/468 [00:01<00:35, 12.73it/s][A
  3%|▎         | 16/468 [00:01<00:36, 12.36it/s][A
  4%|▍         | 18/468 [00:01<00:36, 12.40it/s][A
  4%|▍         | 20/468 [00:01<00:35, 12.63it/s][A
  5%|▍         | 22/468 [00:01<00:35, 12.71it/s][A
  5%|▌         | 24/468 [00:01<00:34, 12.84it/s][A
  6%|▌         | 2

 63%|██████▎   | 297/468 [00:23<00:13, 12.40it/s][A
 64%|██████▍   | 299/468 [00:24<00:13, 12.67it/s][A
 64%|██████▍   | 301/468 [00:24<00:12, 13.04it/s][A
 65%|██████▍   | 303/468 [00:24<00:13, 12.65it/s][A
 65%|██████▌   | 305/468 [00:24<00:12, 12.78it/s][A
 66%|██████▌   | 307/468 [00:24<00:12, 12.81it/s][A
 66%|██████▌   | 309/468 [00:24<00:12, 12.78it/s][A
 66%|██████▋   | 311/468 [00:25<00:12, 12.73it/s][A
 67%|██████▋   | 313/468 [00:25<00:12, 12.56it/s][A
 67%|██████▋   | 315/468 [00:25<00:12, 12.73it/s][A
 68%|██████▊   | 317/468 [00:25<00:11, 12.96it/s][A
 68%|██████▊   | 319/468 [00:25<00:11, 13.31it/s][A
 69%|██████▊   | 321/468 [00:25<00:10, 13.48it/s][A
 69%|██████▉   | 323/468 [00:25<00:10, 13.19it/s][A
 69%|██████▉   | 325/468 [00:26<00:10, 13.24it/s][A
 70%|██████▉   | 327/468 [00:26<00:10, 13.42it/s][A
 70%|███████   | 329/468 [00:26<00:10, 13.67it/s][A
 71%|███████   | 331/468 [00:26<00:10, 13.56it/s][A
 71%|███████   | 333/468 [00:26<00:10, 12.82it

 29%|██▉       | 138/468 [00:10<00:25, 13.05it/s][A
 30%|██▉       | 140/468 [00:10<00:25, 12.78it/s][A
 30%|███       | 142/468 [00:10<00:25, 12.64it/s][A
 31%|███       | 144/468 [00:10<00:25, 12.60it/s][A
 31%|███       | 146/468 [00:10<00:25, 12.69it/s][A
 32%|███▏      | 148/468 [00:11<00:25, 12.74it/s][A
 32%|███▏      | 150/468 [00:11<00:24, 12.85it/s][A
 32%|███▏      | 152/468 [00:11<00:24, 13.02it/s][A
 33%|███▎      | 154/468 [00:11<00:24, 12.91it/s][A
 33%|███▎      | 156/468 [00:11<00:24, 12.96it/s][A
 34%|███▍      | 158/468 [00:11<00:23, 13.10it/s][A
 34%|███▍      | 160/468 [00:12<00:23, 13.22it/s][A
 35%|███▍      | 162/468 [00:12<00:23, 13.17it/s][A
 35%|███▌      | 164/468 [00:12<00:22, 13.44it/s][A
 35%|███▌      | 166/468 [00:12<00:22, 13.45it/s][A
 36%|███▌      | 168/468 [00:12<00:22, 13.48it/s][A
 36%|███▋      | 170/468 [00:12<00:21, 13.57it/s][A
 37%|███▋      | 172/468 [00:12<00:21, 13.58it/s][A
 37%|███▋      | 174/468 [00:13<00:21, 13.65it

 95%|█████████▌| 446/468 [00:34<00:01, 12.91it/s][A
 96%|█████████▌| 448/468 [00:34<00:01, 12.81it/s][A
 96%|█████████▌| 450/468 [00:34<00:01, 12.59it/s][A
 97%|█████████▋| 452/468 [00:34<00:01, 12.88it/s][A
 97%|█████████▋| 454/468 [00:34<00:01, 12.81it/s][A
 97%|█████████▋| 456/468 [00:35<00:00, 12.49it/s][A
 98%|█████████▊| 458/468 [00:35<00:00, 12.70it/s][A
 98%|█████████▊| 460/468 [00:35<00:00, 12.91it/s][A
 99%|█████████▊| 462/468 [00:35<00:00, 13.08it/s][A
 99%|█████████▉| 464/468 [00:35<00:00, 13.39it/s][A
100%|█████████▉| 466/468 [00:35<00:00, 13.30it/s][A
100%|██████████| 468/468 [00:36<00:00, 12.98it/s][A

  0%|          | 0/468 [00:00<?, ?it/s][A
  0%|          | 2/468 [00:00<00:32, 14.29it/s][A
  1%|          | 4/468 [00:00<00:32, 14.19it/s][A
  1%|▏         | 6/468 [00:00<00:33, 13.87it/s][A
  2%|▏         | 8/468 [00:00<00:33, 13.70it/s][A
  2%|▏         | 10/468 [00:00<00:33, 13.87it/s][A
  3%|▎         | 12/468 [00:00<00:33, 13.73it/s][A
  3%|▎      

 61%|██████    | 286/468 [00:23<00:15, 12.05it/s][A
 62%|██████▏   | 288/468 [00:23<00:14, 12.48it/s][A
 62%|██████▏   | 290/468 [00:23<00:14, 12.68it/s][A
 62%|██████▏   | 292/468 [00:23<00:14, 12.46it/s][A
 63%|██████▎   | 294/468 [00:23<00:13, 12.58it/s][A
 63%|██████▎   | 296/468 [00:23<00:13, 12.59it/s][A
 64%|██████▎   | 298/468 [00:24<00:13, 12.77it/s][A
 64%|██████▍   | 300/468 [00:24<00:12, 13.04it/s][A
 65%|██████▍   | 302/468 [00:24<00:12, 13.19it/s][A
 65%|██████▍   | 304/468 [00:24<00:12, 13.20it/s][A
 65%|██████▌   | 306/468 [00:24<00:12, 13.40it/s][A
 66%|██████▌   | 308/468 [00:24<00:11, 13.52it/s][A
 66%|██████▌   | 310/468 [00:24<00:11, 13.26it/s][A
 67%|██████▋   | 312/468 [00:25<00:11, 13.21it/s][A
 67%|██████▋   | 314/468 [00:25<00:11, 13.33it/s][A
 68%|██████▊   | 316/468 [00:25<00:11, 13.17it/s][A
 68%|██████▊   | 318/468 [00:25<00:11, 13.09it/s][A
 68%|██████▊   | 320/468 [00:25<00:11, 13.35it/s][A
 69%|██████▉   | 322/468 [00:25<00:10, 13.40it

 27%|██▋       | 126/468 [00:09<00:25, 13.44it/s][A
 27%|██▋       | 128/468 [00:09<00:25, 13.51it/s][A
 28%|██▊       | 130/468 [00:09<00:25, 13.51it/s][A
 28%|██▊       | 132/468 [00:09<00:25, 13.43it/s][A
 29%|██▊       | 134/468 [00:10<00:24, 13.57it/s][A
 29%|██▉       | 136/468 [00:10<00:24, 13.61it/s][A
 29%|██▉       | 138/468 [00:10<00:24, 13.61it/s][A
 30%|██▉       | 140/468 [00:10<00:24, 13.55it/s][A
 30%|███       | 142/468 [00:10<00:24, 13.54it/s][A
 31%|███       | 144/468 [00:10<00:23, 13.60it/s][A
 31%|███       | 146/468 [00:10<00:23, 13.62it/s][A
 32%|███▏      | 148/468 [00:11<00:23, 13.53it/s][A
 32%|███▏      | 150/468 [00:11<00:23, 13.55it/s][A
 32%|███▏      | 152/468 [00:11<00:23, 13.46it/s][A
 33%|███▎      | 154/468 [00:11<00:23, 13.42it/s][A
 33%|███▎      | 156/468 [00:11<00:23, 13.29it/s][A
 34%|███▍      | 158/468 [00:11<00:23, 13.24it/s][A
 34%|███▍      | 160/468 [00:12<00:23, 13.29it/s][A
 35%|███▍      | 162/468 [00:12<00:22, 13.50it

 93%|█████████▎| 434/468 [00:33<00:02, 12.15it/s][A
 93%|█████████▎| 436/468 [00:33<00:02, 12.39it/s][A
 94%|█████████▎| 438/468 [00:33<00:02, 12.52it/s][A
 94%|█████████▍| 440/468 [00:33<00:02, 12.33it/s][A
 94%|█████████▍| 442/468 [00:33<00:02, 12.15it/s][A
 95%|█████████▍| 444/468 [00:33<00:01, 12.12it/s][A
 95%|█████████▌| 446/468 [00:34<00:01, 12.17it/s][A
 96%|█████████▌| 448/468 [00:34<00:01, 11.60it/s][A
 96%|█████████▌| 450/468 [00:34<00:01, 11.57it/s][A
 97%|█████████▋| 452/468 [00:34<00:01, 11.73it/s][A
 97%|█████████▋| 454/468 [00:34<00:01, 11.84it/s][A
 97%|█████████▋| 456/468 [00:34<00:01, 11.97it/s][A
 98%|█████████▊| 458/468 [00:35<00:00, 12.14it/s][A
 98%|█████████▊| 460/468 [00:35<00:00, 12.05it/s][A
 99%|█████████▊| 462/468 [00:35<00:00, 12.10it/s][A
 99%|█████████▉| 464/468 [00:35<00:00, 12.08it/s][A
100%|█████████▉| 466/468 [00:35<00:00, 12.14it/s][A
100%|██████████| 468/468 [00:35<00:00, 13.01it/s][A

  0%|          | 0/468 [00:00<?, ?it/s][A
  

 59%|█████▊    | 274/468 [00:22<00:15, 12.91it/s][A
 59%|█████▉    | 276/468 [00:22<00:14, 12.95it/s][A
 59%|█████▉    | 278/468 [00:22<00:15, 12.42it/s][A
 60%|█████▉    | 280/468 [00:22<00:15, 11.98it/s][A
 60%|██████    | 282/468 [00:23<00:15, 11.94it/s][A
 61%|██████    | 284/468 [00:23<00:15, 12.06it/s][A
 61%|██████    | 286/468 [00:23<00:14, 12.32it/s][A
 62%|██████▏   | 288/468 [00:23<00:14, 12.63it/s][A
 62%|██████▏   | 290/468 [00:23<00:13, 12.91it/s][A
 62%|██████▏   | 292/468 [00:23<00:13, 12.61it/s][A
 63%|██████▎   | 294/468 [00:24<00:13, 12.54it/s][A
 63%|██████▎   | 296/468 [00:24<00:14, 12.06it/s][A
 64%|██████▎   | 298/468 [00:24<00:14, 11.54it/s][A
 64%|██████▍   | 300/468 [00:24<00:14, 11.33it/s][A
 65%|██████▍   | 302/468 [00:24<00:15, 10.82it/s][A
 65%|██████▍   | 304/468 [00:24<00:14, 10.98it/s][A
 65%|██████▌   | 306/468 [00:25<00:14, 11.32it/s][A
 66%|██████▌   | 308/468 [00:25<00:14, 11.41it/s][A
 66%|██████▌   | 310/468 [00:25<00:13, 11.51it

 24%|██▍       | 114/468 [00:08<00:28, 12.55it/s][A
 25%|██▍       | 116/468 [00:08<00:27, 12.61it/s][A
 25%|██▌       | 118/468 [00:09<00:27, 12.60it/s][A
 26%|██▌       | 120/468 [00:09<00:27, 12.54it/s][A
 26%|██▌       | 122/468 [00:09<00:28, 12.32it/s][A
 26%|██▋       | 124/468 [00:09<00:28, 12.02it/s][A
 27%|██▋       | 126/468 [00:09<00:28, 11.99it/s][A
 27%|██▋       | 128/468 [00:09<00:28, 12.09it/s][A
 28%|██▊       | 130/468 [00:10<00:27, 12.12it/s][A
 28%|██▊       | 132/468 [00:10<00:27, 12.14it/s][A
 29%|██▊       | 134/468 [00:10<00:27, 12.34it/s][A
 29%|██▉       | 136/468 [00:10<00:26, 12.53it/s][A
 29%|██▉       | 138/468 [00:10<00:26, 12.68it/s][A
 30%|██▉       | 140/468 [00:10<00:25, 12.85it/s][A
 30%|███       | 142/468 [00:11<00:25, 12.67it/s][A
 31%|███       | 144/468 [00:11<00:25, 12.71it/s][A
 31%|███       | 146/468 [00:11<00:25, 12.84it/s][A
 32%|███▏      | 148/468 [00:11<00:25, 12.71it/s][A
 32%|███▏      | 150/468 [00:11<00:24, 12.77it

 90%|█████████ | 422/468 [00:32<00:03, 13.59it/s][A
 91%|█████████ | 424/468 [00:32<00:03, 13.59it/s][A
 91%|█████████ | 426/468 [00:32<00:03, 13.60it/s][A
 91%|█████████▏| 428/468 [00:32<00:02, 13.54it/s][A
 92%|█████████▏| 430/468 [00:33<00:02, 13.41it/s][A
 92%|█████████▏| 432/468 [00:33<00:02, 13.48it/s][A
 93%|█████████▎| 434/468 [00:33<00:02, 13.17it/s][A
 93%|█████████▎| 436/468 [00:33<00:02, 12.74it/s][A
 94%|█████████▎| 438/468 [00:33<00:02, 12.76it/s][A
 94%|█████████▍| 440/468 [00:33<00:02, 12.93it/s][A
 94%|█████████▍| 442/468 [00:33<00:01, 13.15it/s][A
 95%|█████████▍| 444/468 [00:34<00:01, 13.13it/s][A
 95%|█████████▌| 446/468 [00:34<00:01, 12.86it/s][A
 96%|█████████▌| 448/468 [00:34<00:01, 12.65it/s][A
 96%|█████████▌| 450/468 [00:34<00:01, 12.51it/s][A
 97%|█████████▋| 452/468 [00:34<00:01, 12.51it/s][A
 97%|█████████▋| 454/468 [00:34<00:01, 12.25it/s][A
 97%|█████████▋| 456/468 [00:35<00:00, 12.10it/s][A
 98%|█████████▊| 458/468 [00:35<00:00, 12.26it

 56%|█████▌    | 262/468 [00:20<00:16, 12.77it/s][A
 56%|█████▋    | 264/468 [00:20<00:16, 12.37it/s][A
 57%|█████▋    | 266/468 [00:20<00:15, 12.77it/s][A
 57%|█████▋    | 268/468 [00:20<00:15, 12.98it/s][A
 58%|█████▊    | 270/468 [00:20<00:15, 13.14it/s][A
 58%|█████▊    | 272/468 [00:20<00:14, 13.30it/s][A
 59%|█████▊    | 274/468 [00:20<00:14, 13.26it/s][A
 59%|█████▉    | 276/468 [00:21<00:14, 13.04it/s][A
 59%|█████▉    | 278/468 [00:21<00:14, 13.26it/s][A
 60%|█████▉    | 280/468 [00:21<00:13, 13.58it/s][A
 60%|██████    | 282/468 [00:21<00:13, 13.64it/s][A
 61%|██████    | 284/468 [00:21<00:13, 13.66it/s][A
 61%|██████    | 286/468 [00:21<00:13, 13.37it/s][A
 62%|██████▏   | 288/468 [00:21<00:13, 13.31it/s][A
 62%|██████▏   | 290/468 [00:22<00:13, 13.13it/s][A
 62%|██████▏   | 292/468 [00:22<00:13, 12.89it/s][A
 63%|██████▎   | 294/468 [00:22<00:13, 12.87it/s][A
 63%|██████▎   | 296/468 [00:22<00:13, 12.95it/s][A
 64%|██████▎   | 298/468 [00:22<00:13, 13.04it

 22%|██▏       | 102/468 [00:08<00:27, 13.31it/s][A
 22%|██▏       | 104/468 [00:08<00:27, 13.42it/s][A
 23%|██▎       | 106/468 [00:08<00:26, 13.56it/s][A
 23%|██▎       | 108/468 [00:08<00:26, 13.60it/s][A
 24%|██▎       | 110/468 [00:08<00:26, 13.29it/s][A
 24%|██▍       | 112/468 [00:08<00:27, 13.03it/s][A
 24%|██▍       | 114/468 [00:08<00:27, 12.94it/s][A
 25%|██▍       | 116/468 [00:09<00:27, 12.98it/s][A
 25%|██▌       | 118/468 [00:09<00:26, 13.19it/s][A
 26%|██▌       | 120/468 [00:09<00:26, 13.13it/s][A
 26%|██▌       | 122/468 [00:09<00:26, 13.08it/s][A
 26%|██▋       | 124/468 [00:09<00:25, 13.31it/s][A
 27%|██▋       | 126/468 [00:09<00:25, 13.27it/s][A
 27%|██▋       | 128/468 [00:10<00:25, 13.16it/s][A
 28%|██▊       | 130/468 [00:10<00:25, 13.10it/s][A
 28%|██▊       | 132/468 [00:10<00:25, 12.95it/s][A
 29%|██▊       | 134/468 [00:10<00:25, 12.91it/s][A
 29%|██▉       | 136/468 [00:10<00:26, 12.71it/s][A
 29%|██▉       | 138/468 [00:10<00:26, 12.57it

 88%|████████▊ | 410/468 [00:31<00:04, 13.22it/s][A
 88%|████████▊ | 412/468 [00:32<00:04, 13.23it/s][A
 88%|████████▊ | 414/468 [00:32<00:04, 13.34it/s][A
 89%|████████▉ | 416/468 [00:32<00:03, 13.15it/s][A
 89%|████████▉ | 418/468 [00:32<00:03, 13.16it/s][A
 90%|████████▉ | 420/468 [00:32<00:03, 13.08it/s][A
 90%|█████████ | 422/468 [00:32<00:03, 13.21it/s][A
 91%|█████████ | 424/468 [00:32<00:03, 13.51it/s][A
 91%|█████████ | 426/468 [00:33<00:03, 13.32it/s][A
 91%|█████████▏| 428/468 [00:33<00:03, 13.25it/s][A
 92%|█████████▏| 430/468 [00:33<00:02, 13.25it/s][A
 92%|█████████▏| 432/468 [00:33<00:02, 13.30it/s][A
 93%|█████████▎| 434/468 [00:33<00:02, 13.50it/s][A
 93%|█████████▎| 436/468 [00:33<00:02, 13.23it/s][A
 94%|█████████▎| 438/468 [00:34<00:02, 13.05it/s][A
 94%|█████████▍| 440/468 [00:34<00:02, 12.98it/s][A
 94%|█████████▍| 442/468 [00:34<00:01, 13.28it/s][A
 95%|█████████▍| 444/468 [00:34<00:01, 13.25it/s][A
 95%|█████████▌| 446/468 [00:34<00:01, 13.28it

 53%|█████▎    | 250/468 [00:19<00:16, 13.32it/s][A
 54%|█████▍    | 252/468 [00:19<00:16, 13.27it/s][A
 54%|█████▍    | 254/468 [00:19<00:15, 13.40it/s][A
 55%|█████▍    | 256/468 [00:19<00:15, 13.52it/s][A
 55%|█████▌    | 258/468 [00:19<00:15, 13.60it/s][A
 56%|█████▌    | 260/468 [00:19<00:15, 13.66it/s][A
 56%|█████▌    | 262/468 [00:19<00:15, 13.73it/s][A
 56%|█████▋    | 264/468 [00:20<00:14, 13.63it/s][A
 57%|█████▋    | 266/468 [00:20<00:14, 13.54it/s][A
 57%|█████▋    | 268/468 [00:20<00:15, 12.98it/s][A
 58%|█████▊    | 270/468 [00:20<00:15, 12.57it/s][A
 58%|█████▊    | 272/468 [00:20<00:15, 12.80it/s][A
 59%|█████▊    | 274/468 [00:20<00:14, 13.12it/s][A
 59%|█████▉    | 276/468 [00:21<00:14, 13.19it/s][A

In [None]:

num_correct = 0
num_samples = 0
for images, labels in tqdm(train_loader):

        model.eval()  # set model to evaluation mode
        with torch.no_grad():
        
           
            scores = model(images)
            _, preds = scores.max(1)
            num_correct += (preds == labels).sum()
            num_samples += preds.size(0)
            acc = float(num_correct) / num_samples
print('TestsetGot %d / %d correct (%.2f)' % (num_correct, num_samples, 100 * acc))
num_correct_test = 0
num_samples_test = 0
for images, labels in tqdm(test_loader):

        model.eval()  # set model to evaluation mode
        with torch.no_grad():
        
           
            scores = model(images)
            _, preds = scores.max(1)
            num_correct_test += (preds == labels).sum()
            num_samples_test += preds.size(0)
            acc = float(num_correct) / num_samples
print('TestsetGot %d / %d correct (%.2f)' % (num_correct, num_samples, 100 * acc))
print('TrainsetGot %d / %d correct (%.2f)' % (num_correct_test, num_samples_test, 100 * acc))

#### Q5:
Please print the training and testing accuracy.