# PyTorch Tutorial

PyTorch is a popular deep learning framework, and it is easy to get started with.

In [1]:
import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm
import time

# Number of samples per mini-batch; passed as `batch_size` to both DataLoaders.
BATCH_SIZE = 128
# Number of iterations of the outer training loop (one full pass over train_loader each).
NUM_EPOCHS = 10

First, we read the MNIST data, preprocess it, and wrap it in `DataLoader` objects.

In [3]:
# Preprocessing: convert each image to a tensor, then normalise the single
# channel with mean 0.5 and std 0.5.
normalize = transforms.Normalize(mean=[.5], std=[.5])
transform = transforms.Compose([transforms.ToTensor(), normalize])

# Download (train split only requests the download) and load the data; both
# splits live under the same root directory and share the same transform.
train_dataset = torchvision.datasets.MNIST(root='./mnist/', train=True, transform=transform, download=True)
test_dataset = torchvision.datasets.MNIST(root='./mnist/', train=False, transform=transform, download=False)

# Wrap the datasets in DataLoaders.
# Training: reshuffle every epoch and drop the final partial batch so that each
# optimisation step sees exactly BATCH_SIZE samples.
train_loader = data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
# Evaluation: keep the final partial batch. Dropping it would silently exclude
# the last len(test_dataset) % BATCH_SIZE test images from the reported accuracy.
test_loader = data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)


Then, we define the model, the objective (loss) function, and the optimizer that we use for classification.

In [4]:
import torch.nn.functional as func
import torch.optim as optim


class SimpleNet(nn.Module):
    """Small convolutional classifier: two conv/pool stages followed by two linear layers.

    ``forward`` returns log-probabilities over 10 classes (``log_softmax`` along
    dim 1). The flatten step expects 50 feature maps of size 4 x 4, which is what
    a 1 x 28 x 28 input yields after two unpadded 5 x 5 convolutions, each
    followed by 2 x 2 max-pooling (28 -> 24 -> 12 -> 8 -> 4).
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: 1 -> 20 -> 50 channels, 5 x 5 kernels.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=20, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=20, out_channels=50, kernel_size=5)
        # Classifier head: 800 flattened features -> 600 hidden units -> 10 classes.
        self.fc1 = nn.Linear(in_features=4 * 4 * 50, out_features=600)
        self.fc2 = nn.Linear(in_features=600, out_features=10)

    def forward(self, x):
        """Map a batch of images to per-class log-probabilities of shape (batch, 10)."""
        features = func.max_pool2d(func.relu(self.conv1(x)), kernel_size=2, stride=2)
        features = func.max_pool2d(func.relu(self.conv2(features)), kernel_size=2, stride=2)
        # Flatten to (batch, 800); the width equals fc1's input size.
        flat = features.view(-1, self.fc1.in_features)
        hidden = func.relu(self.fc1(flat))
        logits = self.fc2(hidden)
        return func.log_softmax(logits, dim=1)
    

# Hyper-parameters for SGD with momentum.
learning_rate = 0.01
momentum = 0.5

# Freshly initialised network.
model = SimpleNet()

# Loss function and optimizer. SimpleNet.forward ends in log_softmax, so the
# negative log-likelihood loss is applied directly to its output.
criterion = func.nll_loss
optimizer = optim.SGD(params=model.parameters(), momentum=momentum, lr=learning_rate)


Next, we can start to train and evaluate!

In [5]:
import numpy as np
import logging

from source import TRAIN_LOGGER

# Write INFO-level (and above) records to the file named by TRAIN_LOGGER,
# each prefixed with a timestamp.
logging.basicConfig(
    filename=TRAIN_LOGGER,
    level=logging.INFO,
    format='%(asctime)s %(message)s',
)

# Device that the batches are moved to during training and evaluation.
device = torch.device('cpu')

def test(data_loader, model_, device_):
    accuracy = 0
    with torch.no_grad():
        for images_, labels_ in tqdm(data_loader):
            images_, labels_ = images_.to(device_), labels_.to(device_)
            output_ = model_(images_)
            pred_ = output_.argmax(dim=1, keepdim=True)
            accuracy += pred_.eq(labels_.view_as(pred_)).sum().item()
    
    accuracy = 100. * accuracy / len(data_loader.dataset)
    return accuracy


# Per-epoch accuracy history, averaged in the summary record after training.
train_accuracy_list = []
test_accuracy_list = []

for epoch in range(NUM_EPOCHS):
    # ---- train: forward + backward + optimize on every mini-batch ----
    model.train()
    for images, labels in tqdm(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()  # clear gradients left over from the previous step
        output = model(images)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()

    # ---- evaluate: accuracy on the training and testing datasets ----
    model.eval()
    train_accuracy = test(train_loader, model, device)
    train_accuracy_list.append(train_accuracy)
    test_accuracy = test(test_loader, model, device)
    test_accuracy_list.append(test_accuracy)

    # One log record per epoch (epochs are reported 1-based).
    logging.info('''\n
    ============================
    Epoch {}
    train accuracy {:.2f}%
    test accuracy {:.2f}%
    ============================'''.format(epoch + 1, train_accuracy, test_accuracy))

# Summary record: mean of the per-epoch accuracies over all epochs.
logging.info('''\n
    ============================
    Average train accuracy {:.2f}% 
    Average test accuracy {:.2f}%
    ============================'''.format(
    np.mean(train_accuracy_list),
    np.mean(test_accuracy_list),
))


  0%|          | 0/468 [00:00<?, ?it/s]

  0%|          | 1/468 [00:00<01:51,  4.20it/s]

  0%|          | 2/468 [00:00<01:41,  4.57it/s]

  1%|          | 3/468 [00:00<01:32,  5.02it/s]

  1%|          | 4/468 [00:00<01:21,  5.70it/s]

  1%|          | 5/468 [00:00<01:14,  6.25it/s]

  1%|▏         | 6/468 [00:00<01:14,  6.21it/s]

  1%|▏         | 7/468 [00:01<01:21,  5.67it/s]

  2%|▏         | 8/468 [00:01<01:29,  5.12it/s]

  2%|▏         | 9/468 [00:01<01:33,  4.89it/s]

  2%|▏         | 10/468 [00:01<01:22,  5.55it/s]

  2%|▏         | 11/468 [00:01<01:18,  5.85it/s]

  3%|▎         | 12/468 [00:02<01:12,  6.33it/s]

  3%|▎         | 13/468 [00:02<01:05,  6.90it/s]

  3%|▎         | 15/468 [00:02<01:01,  7.37it/s]

  4%|▎         | 17/468 [00:02<00:54,  8.22it/s]

  4%|▍         | 19/468 [00:02<00:49,  9.04it/s]

  4%|▍         | 21/468 [00:02<00:45,  9.85it/s]

  5%|▍         | 23/468 [00:03<00:42, 10.44it/s]

  5%|▌         | 25/468 [00:03<00:41, 10.60it/s]

  6%|▌         | 27/468 [00:03<00:47,  9.35it/s]

  6%|▌         | 28/468 [00:03<00:58,  7.49it/s]

  6%|▌         | 29/468 [00:03<01:06,  6.63it/s]

  6%|▋         | 30/468 [00:04<01:20,  5.42it/s]

  7%|▋         | 31/468 [00:04<01:09,  6.24it/s]

  7%|▋         | 32/468 [00:04<01:03,  6.86it/s]

  7%|▋         | 33/468 [00:04<00:57,  7.56it/s]

  7%|▋         | 35/468 [00:04<00:51,  8.47it/s]

  8%|▊         | 37/468 [00:04<00:46,  9.31it/s]

  8%|▊         | 39/468 [00:05<00:44,  9.63it/s]

  9%|▉         | 41/468 [00:05<00:41, 10.28it/s]

  9%|▉         | 43/468 [00:05<00:41, 10.35it/s]

 10%|▉         | 45/468 [00:05<00:39, 10.75it/s]

 10%|█         | 47/468 [00:05<00:39, 10.53it/s]

 10%|█         | 49/468 [00:05<00:39, 10.63it/s]

 11%|█         | 51/468 [00:06<00:38, 10.87it/s]

 11%|█▏        | 53/468 [00:06<00:36, 11.24it/s]

 12%|█▏        | 55/468 [00:06<00:36, 11.21it/s]

 12%|█▏        | 57/468 [00:06<00:36, 11.34it/s]

 13%|█▎        | 59/468 [00:06<00:35, 11.49it/s]

 13%|█▎        | 61/468 [00:06<00:35, 11.57it/s]

 13%|█▎        | 63/468 [00:07<00:34, 11.62it/s]

 14%|█▍        | 65/468 [00:07<00:35, 11.47it/s]

 14%|█▍        | 67/468 [00:07<00:39, 10.18it/s]

 15%|█▍        | 69/468 [00:07<00:39, 10.08it/s]

 15%|█▌        | 71/468 [00:07<00:38, 10.43it/s]

 16%|█▌        | 73/468 [00:08<00:37, 10.60it/s]

 16%|█▌        | 75/468 [00:08<00:37, 10.59it/s]

 16%|█▋        | 77/468 [00:08<00:37, 10.51it/s]

 17%|█▋        | 79/468 [00:08<00:35, 10.88it/s]

 17%|█▋        | 81/468 [00:08<00:37, 10.19it/s]

 18%|█▊        | 83/468 [00:09<00:36, 10.57it/s]

 18%|█▊        | 85/468 [00:09<00:35, 10.81it/s]

 19%|█▊        | 87/468 [00:09<00:35, 10.63it/s]

 19%|█▉        | 89/468 [00:09<00:35, 10.74it/s]

 19%|█▉        | 91/468 [00:09<00:34, 11.07it/s]

 20%|█▉        | 93/468 [00:09<00:33, 11.15it/s]

 20%|██        | 95/468 [00:10<00:36, 10.29it/s]

 21%|██        | 97/468 [00:10<00:42,  8.82it/s]

 21%|██        | 98/468 [00:10<00:45,  8.09it/s]

 21%|██        | 99/468 [00:10<00:44,  8.33it/s]

 21%|██▏       | 100/468 [00:10<00:44,  8.36it/s]

 22%|██▏       | 102/468 [00:11<00:41,  8.85it/s]

 22%|██▏       | 103/468 [00:11<00:39,  9.16it/s]

#### Q5:
Please print the training and testing accuracy.