In [15]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as dsets

### LOADING DATASET

In [16]:
train_dataset = dsets.MNIST(root='./data',train=True, transform=transforms.ToTensor(),download=True)

test_dataset = dsets.MNIST(root='./data', train=False, transform=transforms.ToTensor())

### MAKING DATASET ITERABLE

In [17]:
batch_size = 100
n_iters = 3000
num_epochs = n_iters / (len(train_dataset) / batch_size)
num_epochs = int(num_epochs)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset, 
                                          batch_size=batch_size, 
                                          shuffle=False)


### Model

In [24]:
class CNNModel(nn.Module):
    def __init__(self):
        super(CNNModel, self).__init__()

        # Convolution 1
        self.cnn1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2)
        self.relu1 = nn.ReLU()

        # Average pool 1
        #self.avgpool1 = nn.AvgPool2d(kernel_size=2)
        self.maxpool1=nn.MaxPool2d(kernel_size=2)

        # Convolution 2
        self.cnn2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2)
        self.relu2 = nn.ReLU()

        # Average pool 2
        #self.avgpool2 = nn.AvgPool2d(kernel_size=2)
        self.maxpool2=nn.MaxPool2d(kernel_size=2)

        # Fully connected 1 (readout)
        self.fc1 = nn.Linear(32 * 7 * 7, 10) 

    def forward(self, x):
        # Convolution 1
        out = self.cnn1(x)
        out = self.relu1(out)

        # Average pool 1
       # out = self.avgpool1(out)
        out=self.maxpool1(out)

        # Convolution 2 
        out = self.cnn2(out)
        out = self.relu2(out)

        # Max pool 2 
        #out = self.avgpool2(out)
        out=self.maxpool2(out)

        # Resize
        # Original size: (100, 32, 7, 7)
        # out.size(0): 100
        # New out size: (100, 32*7*7)
        out = out.view(out.size(0), -1)

        # Linear function (readout)
        out = self.fc1(out)

        return out

In [25]:
model = CNNModel()
criterion = nn.CrossEntropyLoss()
learning_rate = 0.01
#optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
optimizer=torch.optim.Adam(model.parameters(),lr=learning_rate)

### train the model

In [26]:
iter = 0
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
       
        images = images.requires_grad_()# Load images as tensors with gradient accumulation abilities

       
        optimizer.zero_grad() # Clear gradients w.r.t. parameters

        
        outputs = model(images)

        
        loss = criterion(outputs, labels)

        
        loss.backward()# Getting gradients w.r.t. parameters

        # Updating parameters
        optimizer.step()

        iter += 1
# Calculate Accuracy       
        if iter % 500 == 0:
              
            correct = 0
            total = 0
            # Iterate through test dataset
            for images, labels in test_loader:
                # Load images to tensors with gradient accumulation abilities
                images = images.requires_grad_()

                # Forward pass only to get logits/output
                outputs = model(images)

                # Get predictions from the maximum value
                _, predicted = torch.max(outputs.data, 1)

                # Total number of labels
                total += labels.size(0)

                # Total correct predictions
                correct += (predicted == labels).sum()

            accuracy = 100 * correct / total

            # Print Loss
            print('Iteration: {}. Loss: {}. Accuracy: {}'.format(iter, loss.item(), accuracy))

Iteration: 500. Loss: 0.06614162772893906. Accuracy: 98
Iteration: 1000. Loss: 0.04435869678854942. Accuracy: 98
Iteration: 1500. Loss: 0.04804736003279686. Accuracy: 98
Iteration: 2000. Loss: 0.015441480092704296. Accuracy: 98
Iteration: 2500. Loss: 0.06950191408395767. Accuracy: 98
Iteration: 3000. Loss: 0.06579634547233582. Accuracy: 98
