# Pytorch Tutorial

Pytorch is a popular deep learning framework and it's easy to get started.

In [1]:
import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm
import time

BATCH_SIZE = 128
NUM_EPOCHS = 10

First, we read the mnist data, preprocess them and encapsulate them into dataloader form.

In [2]:
# preprocessing
normalize = transforms.Normalize(mean=[.5], std=[.5])
transform = transforms.Compose([transforms.ToTensor(), normalize])

# download and load the data
train_dataset = torchvision.datasets.MNIST(root='./mnist/', train=True, transform=transform, download=True)
test_dataset = torchvision.datasets.MNIST(root='./mnist/', train=False, transform=transform, download=False)

# encapsulate them into dataloader form
train_loader = data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
test_loader = data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)

Then, we define the model, object function and optimizer that we use to classify.

In [30]:
import torch.optim as optim
import torch.nn.functional as F
class SimpleNet(nn.Module):
    def __init__(self):    
        super(SimpleNet, self).__init__()  
        self.conv1 = nn.Conv2d(1, 6, 5, 1, 2)       
        self.pool = nn.MaxPool2d(2, 2)       
        self.conv2 = nn.Conv2d(6, 16, 5, 1, 2)     
        self.fc1 = nn.Linear(16 * 7 * 7, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
 
    def forward(self, x):                  
        x = self.pool(F.relu(self.conv1(x)))  
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(x.size(0), -1)  
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x



# TODO:define model

    
model = SimpleNet()

# TODO:define loss function and optimiter
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

Next, we can start to train and evaluate!

In [34]:
# train and evaluate
for epoch in range(NUM_EPOCHS):
    for images, labels in tqdm(train_loader):
        # TODO:forward + backward + optimize
        if torch.cuda.is_available():
            images = images.cuda()
            labels = labels.cuda()
            model.cuda()
        else:
            images = Variable(images)
            labels = Variable(labels)
        out = model(images)
        loss = criterion(out, labels)
 
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(loss.data.item())

  
        
        
        
    # evaluate
    # TODO:calculate the accuracy using traning and testing dataset
    
    
    
    





  0%|                                                                                          | 0/468 [00:00<?, ?it/s]


  2%|█▍                                                                                | 8/468 [00:00<00:06, 75.67it/s]


  3%|██▊                                                                              | 16/468 [00:00<00:05, 76.31it/s]


  5%|████▏                                                                            | 24/468 [00:00<00:05, 77.00it/s]


  7%|█████▌                                                                           | 32/468 [00:00<00:05, 76.17it/s]


  9%|██████▉                                                                          | 40/468 [00:00<00:05, 76.01it/s]


 10%|████████▎                                                                        | 48/468 [00:00<00:05, 75.91it/s]


 12%|█████████▋                                                                       | 56/468 [00:00<00:05, 76.71it/s]


 14%|███████████     

0.04312323406338692





  0%|                                                                                          | 0/468 [00:00<?, ?it/s]


  2%|█▍                                                                                | 8/468 [00:00<00:05, 79.40it/s]


  3%|██▊                                                                              | 16/468 [00:00<00:05, 78.01it/s]


  5%|████▏                                                                            | 24/468 [00:00<00:05, 78.20it/s]


  7%|█████▌                                                                           | 32/468 [00:00<00:05, 77.65it/s]


  9%|███████                                                                          | 41/468 [00:00<00:05, 79.75it/s]


 11%|████████▋                                                                        | 50/468 [00:00<00:05, 80.21it/s]


 12%|██████████                                                                       | 58/468 [00:00<00:05, 78.56it/s]


 14%|███████████▌    

0.10062449425458908





  0%|                                                                                          | 0/468 [00:00<?, ?it/s]


  2%|█▍                                                                                | 8/468 [00:00<00:06, 76.37it/s]


  3%|██▊                                                                              | 16/468 [00:00<00:05, 76.60it/s]


  5%|████▏                                                                            | 24/468 [00:00<00:05, 76.98it/s]


  7%|█████▌                                                                           | 32/468 [00:00<00:05, 76.58it/s]


  9%|██████▉                                                                          | 40/468 [00:00<00:05, 76.31it/s]


 10%|████████▎                                                                        | 48/468 [00:00<00:05, 76.77it/s]


 12%|█████████▋                                                                       | 56/468 [00:00<00:05, 76.23it/s]


 14%|███████████     

0.12225930392742157





  0%|                                                                                          | 0/468 [00:00<?, ?it/s]


  2%|█▌                                                                                | 9/468 [00:00<00:05, 83.53it/s]


  4%|██▉                                                                              | 17/468 [00:00<00:05, 81.51it/s]


  6%|████▌                                                                            | 26/468 [00:00<00:05, 82.34it/s]


  7%|██████                                                                           | 35/468 [00:00<00:05, 83.15it/s]


  9%|███████▌                                                                         | 44/468 [00:00<00:05, 83.51it/s]


 11%|█████████                                                                        | 52/468 [00:00<00:05, 80.74it/s]


 13%|██████████▍                                                                      | 60/468 [00:00<00:05, 80.35it/s]


 15%|███████████▊    

0.027864933013916016





  0%|                                                                                          | 0/468 [00:00<?, ?it/s]


  2%|█▍                                                                                | 8/468 [00:00<00:05, 79.40it/s]


  3%|██▊                                                                              | 16/468 [00:00<00:05, 77.79it/s]


  5%|████▏                                                                            | 24/468 [00:00<00:05, 78.27it/s]


  7%|█████▌                                                                           | 32/468 [00:00<00:05, 78.61it/s]


  9%|██████▉                                                                          | 40/468 [00:00<00:05, 77.71it/s]


 10%|████████▎                                                                        | 48/468 [00:00<00:05, 77.98it/s]


 12%|█████████▋                                                                       | 56/468 [00:00<00:05, 76.83it/s]


 14%|███████████     

0.08942063897848129





  0%|                                                                                          | 0/468 [00:00<?, ?it/s]


  2%|█▍                                                                                | 8/468 [00:00<00:05, 77.86it/s]


  3%|██▊                                                                              | 16/468 [00:00<00:05, 78.10it/s]


  5%|████▏                                                                            | 24/468 [00:00<00:05, 78.03it/s]


  7%|█████▌                                                                           | 32/468 [00:00<00:05, 76.86it/s]


  9%|██████▉                                                                          | 40/468 [00:00<00:05, 76.95it/s]


 10%|████████▎                                                                        | 48/468 [00:00<00:05, 76.33it/s]


 12%|█████████▋                                                                       | 56/468 [00:00<00:05, 77.24it/s]


 14%|███████████     

0.05453300476074219





  0%|                                                                                          | 0/468 [00:00<?, ?it/s]


  2%|█▌                                                                                | 9/468 [00:00<00:05, 82.04it/s]


  4%|██▉                                                                              | 17/468 [00:00<00:05, 81.23it/s]


  5%|████▎                                                                            | 25/468 [00:00<00:05, 79.48it/s]


  7%|█████▉                                                                           | 34/468 [00:00<00:05, 80.44it/s]


  9%|███████▎                                                                         | 42/468 [00:00<00:05, 79.43it/s]


 11%|████████▋                                                                        | 50/468 [00:00<00:05, 78.26it/s]


 12%|██████████                                                                       | 58/468 [00:00<00:05, 77.24it/s]


 14%|███████████▌    

0.03168857842683792





  0%|                                                                                          | 0/468 [00:00<?, ?it/s]


  2%|█▍                                                                                | 8/468 [00:00<00:06, 76.39it/s]


  3%|██▊                                                                              | 16/468 [00:00<00:05, 76.61it/s]


  5%|████▏                                                                            | 24/468 [00:00<00:05, 76.77it/s]


  7%|█████▋                                                                           | 33/468 [00:00<00:05, 78.07it/s]


  9%|███████                                                                          | 41/468 [00:00<00:05, 76.89it/s]


 11%|████████▋                                                                        | 50/468 [00:00<00:05, 78.37it/s]


 12%|██████████                                                                       | 58/468 [00:00<00:05, 78.22it/s]


 14%|███████████▍    

0.013390175998210907





  0%|                                                                                          | 0/468 [00:00<?, ?it/s]


  2%|█▌                                                                                | 9/468 [00:00<00:05, 80.57it/s]


  4%|██▉                                                                              | 17/468 [00:00<00:05, 79.74it/s]


  5%|████▎                                                                            | 25/468 [00:00<00:05, 78.02it/s]


  7%|█████▋                                                                           | 33/468 [00:00<00:05, 77.97it/s]


  9%|███████                                                                          | 41/468 [00:00<00:05, 76.82it/s]


 10%|████████▍                                                                        | 49/468 [00:00<00:05, 77.14it/s]


 12%|█████████▊                                                                       | 57/468 [00:00<00:05, 76.04it/s]


 14%|███████████▍    

0.11409630626440048





  0%|                                                                                          | 0/468 [00:00<?, ?it/s]


  2%|█▍                                                                                | 8/468 [00:00<00:06, 74.97it/s]


  3%|██▊                                                                              | 16/468 [00:00<00:05, 76.03it/s]


  5%|████▏                                                                            | 24/468 [00:00<00:05, 77.01it/s]


  7%|█████▋                                                                           | 33/468 [00:00<00:05, 78.25it/s]


  9%|███████                                                                          | 41/468 [00:00<00:05, 77.69it/s]


 10%|████████▍                                                                        | 49/468 [00:00<00:05, 78.20it/s]


 12%|█████████▊                                                                       | 57/468 [00:00<00:05, 78.10it/s]


 14%|███████████▍    

0.051366135478019714


RuntimeError: Expected 4-dimensional input for 4-dimensional weight [6, 1, 5, 5], but got input of size [128, 784] instead

In [37]:
model.eval()
eval_acc = 0
train_acc = 0
for images, labels in tqdm(test_loader):
    if torch.cuda.is_available():
        images = images.cuda()
        labels = labels.cuda()
    out = model(images)
    _, pred = torch.max(out, 1)
    num_correct = (pred == labels).sum()
    eval_acc += num_correct.item()
    
for images, labels in tqdm(train_loader):
    if torch.cuda.is_available():
        images = images.cuda()
        labels = labels.cuda()
    out = model(images)
    _, pred = torch.max(out, 1)
    num_correct = (pred == labels).sum()
    train_acc += num_correct.item()
    
print('Train Acc: {:.6f}, Test Acc: {:.6f}'.format(
    train_acc / (len(train_dataset)),
    eval_acc / (len(test_dataset))
))




  0%|                                                                                           | 0/78 [00:00<?, ?it/s]


 14%|███████████▍                                                                     | 11/78 [00:00<00:00, 109.20it/s]


 29%|███████████████████████▉                                                         | 23/78 [00:00<00:00, 110.46it/s]


 44%|███████████████████████████████████▎                                             | 34/78 [00:00<00:00, 109.10it/s]


 59%|███████████████████████████████████████████████▊                                 | 46/78 [00:00<00:00, 111.00it/s]


 74%|████████████████████████████████████████████████████████████▏                    | 58/78 [00:00<00:00, 111.74it/s]


 90%|████████████████████████████████████████████████████████████████████████▋        | 70/78 [00:00<00:00, 111.95it/s]


100%|█████████████████████████████████████████████████████████████████████████████████| 78/78 [00:00<00:00, 111.73it/s]


  0%|                

Train Acc: 0.984900, Test Acc: 0.982100


#### Q5:
Please print the training and testing accuracy.