In [0]:
# !pip3 install torch torchvision

In [0]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as dsets
from torch.autograd import  Variable

In [0]:
# Load dataset
# Make dataset iterable
# create model class
# Instantiate model class
# Instantiate loss class
# Instantiate optimizer class
# Train model

In [5]:
# Load dataset
train_dataset = dsets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = dsets.MNIST(root='./data', train=False, transform=transforms.ToTensor(), download=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Processing...
Done!


In [0]:
# Make dataset iterable


batch_size = 100
n_iters = 3000
num_epochs = int(n_iters / (len(train_dataset)/batch_size))

train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

In [0]:
# Conv -> Max pool -> Conv -> Max pool -> FC

In [0]:
# Create Model Class
class CNN(nn.Module):
  def __init__(self):
    super(CNN, self).__init__()
    
    # Convolution 1
    self.cnn1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2)
    self.relu1 = nn.ReLU()
    
    # Max pool 1
    self.maxpool1 = nn.AdaptiveMaxPool2d(output_size=14)
    
    # Convolution 2
    self.cnn2 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5, stride=1, padding=2)
    self.relu2 = nn.ReLU()
    
    # Max pool 2
    self.maxpool2 = nn.AdaptiveMaxPool2d(output_size=7)
    
    # FC (readout)
    self.fc1 = nn.Linear(32*7*7, 10)
   
    
  def forward(self, x):
    # Convolution 1
    out = self.cnn1(x)
    out = self.relu1(out)
    
    # Max pool 1
    out = self.maxpool1(out)
    
    # Convolution 2
    out = self.cnn2(x)
    out = self.relu2(out)
    
    # Max pool 2
    out = self.maxpool2(out)
    
    # Resize
    # 100 is the batch size...
    # (100, 32, 7, 7) -> (100, 32*7*7)
    out = out.view(out.size(0), -1)
    
    # Linear (readout)
    out = self.fc1(out)
    
    return out

In [0]:
model = CNN()

In [0]:
criterion = nn.CrossEntropyLoss()
learning_rate = 0.01
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

In [27]:
# Train the model

# Convert input/labels to Variables
# Clear gradient buffers
# Get output given the inputs
# Get loss
# Get gradients w.r.t parameters
# Update parameters using the gradients
# Repeat

import time
tick = time.time()

iter = 0
for epoch in range(num_epochs):
  for i, (images, labels) in enumerate(train_loader):
    
    
    images = Variable(images) # No need to resize
    labels = Variable(labels)
    
    optimizer.zero_grad()
    
    outputs = model(images)
    
    loss = criterion(outputs, labels)
    
    loss.backward()
    
    optimizer.step()
    
    iter += 1
    
    if iter % 500 == 0:
      correct = 0
      total = 0
      
      for images, labels in test_loader:
        
        images = Variable(images)
        
        outputs = model(images)
        
        _, predicted = torch.max(outputs.data, 1)
        
        total += labels.size(0)
        
        correct += (predicted == labels).sum().item()
        
      accuracy = 100 * correct / total
      
      print('Iteration: {}, Loss: {}, Accuracy:{}'.format(iter, loss.data[0], accuracy))
    
tock = time.time()    



Iteration: 500, Loss: 0.6978496313095093, Accuracy:89.11
Iteration: 1000, Loss: 0.26803529262542725, Accuracy:91.94
Iteration: 1500, Loss: 0.21909640729427338, Accuracy:93.26
Iteration: 2000, Loss: 0.3459615409374237, Accuracy:93.95
Iteration: 2500, Loss: 0.19628138840198517, Accuracy:94.63
Iteration: 3000, Loss: 0.15083910524845123, Accuracy:94.74


In [28]:
print('Time required is', tock - tick)

Time required is 259.9495882987976
