In [17]:
import torch 
import torch.nn as nn
import torchvision.transforms as transform
import torchvision.datasets as datasets
import torch 
from torch.utils.data import DataLoader
from torch.autograd import Variable 

In [18]:
train_dataset = datasets.MNIST(root ='./data', download = True, transform = transform.ToTensor(), train = True)

In [19]:
test_dataset = datasets.MNIST(root='./data', train = False, transform = transform.ToTensor())

In [20]:
print(train_dataset.train_data.size())

torch.Size([60000, 28, 28])


In [21]:
print(train_dataset.train_labels.size())

torch.Size([60000])


In [22]:
print(test_dataset.test_data.size())

torch.Size([10000, 28, 28])


In [23]:
print(test_dataset.test_labels.size())

torch.Size([10000])


In [24]:
batch_size = 100
num_iter = 3000
epochs = int(num_iter/(len(train_dataset)/batch_size))


train_loader = DataLoader(dataset = train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset= test_dataset, batch_size= batch_size, shuffle= False)

In [25]:
class CNNModel(nn.Module):
    def __init__(self):
        super(CNNModel, self).__init__()
        
        #Convolution 1
        self.cnn1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2)
        self.relu1 = nn.ReLU()

        #Avg pool 1
        self.avgpool1 = nn.AvgPool2d(kernel_size=2)

        #Convolution 2
        self.cnn2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2)
        self.relu2 = nn.ReLU()

        #Max pool 1
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)

        #Fully Connected 1
        self.fcl = nn.Linear(32*7*7,10)
    
    
    def forward(self, x):
        
        #Convolution 1
        out = self.cnn1(x)
        out = self.relu1(out)
        
        #Avg Pooling 1
        out = self.avgpool1(out)
        
        #Convolution 2
        out = self.cnn2(out)
        out = self.relu2(out)
        
        #Max pooling 1 
        out = self.maxpool1(out)
        
        #Resize
        #Original size (100, 32,7,7)
        #where out.size(0) = 100
        #therefore we are multiplying 100 times so
        #new out size = (100, 32*7*7)
        out = out.view(out.size(0), -1)
        
        #linear function
        out = self.fcl(out)
        
        return out

In [26]:
#Instantiating our model

model = CNNModel()

In [27]:
#specifying loss function

criterion = nn.CrossEntropyLoss()

In [28]:
#Optimisation

learning_rate = 0.01

optimizer = torch.optim.SGD(model.parameters(), lr = learning_rate)

In [30]:
#Training the network

iter = 0

for epoch in range(epochs):
    for i , (images, labels) in enumerate (train_loader):
        images = Variable(images)
        labels = Variable(labels)
        
        #clearing gradients w.r.t parameters
        optimizer.zero_grad()
        
        outputs = model(images)
        
        loss = criterion (outputs, labels)
        
        #retrieving bgradients
        loss.backward()
        
        #updating parameters
        optimizer.step()
        
        iter += 1
        
        if (iter%500 == 0):
            #calculating accuracy
            correct = 0
            total = 0
            
            for images, labels in test_loader:
                images = Variable(images)
                
                outputs = model(images)
                _, predicted = torch.max(outputs.data, 1)
                
                #total num of labels
                total += labels.size(0)
                
                #total corrected predictions
                correct += (predicted == labels).sum()
                
            accuracy = 100* correct/total
            
            print('Iteration: {}, Loss: {}, Accuracy: {}'.format(iter, loss.data[0], accuracy))



Iteration: 500, Loss: 0.5567384958267212, Accuracy: 84
Iteration: 1000, Loss: 0.2217874974012375, Accuracy: 91
Iteration: 1500, Loss: 0.2946849465370178, Accuracy: 93
Iteration: 2000, Loss: 0.3738614618778229, Accuracy: 94
Iteration: 2500, Loss: 0.2307964712381363, Accuracy: 95
Iteration: 3000, Loss: 0.0759103074669838, Accuracy: 96
