In [1]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

# Define Model

In [33]:
class SimpleCNN(nn.Module):
    """LeNet-style CNN for single-channel 28x28 MNIST digits.

    Two conv -> ReLU -> 2x2 max-pool stages followed by two fully connected
    layers. ``forward`` returns per-class log-probabilities
    (``log_softmax`` over dim 1), which is what ``F.nll_loss`` expects.
    """

    def __init__(self):
        super().__init__()
        # 1x28x28 -> conv(5, stride 1) -> 20x24x24 -> pool(2) -> 20x12x12
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        # 20x12x12 -> conv(5, stride 1) -> 50x8x8 -> pool(2) -> 50x4x4
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        # 50 * 4 * 4 = 800 flattened features per sample
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        """Map a (N, 1, 28, 28) batch to (N, 10) log-probabilities."""
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        # Keep the batch dimension explicit: view(-1, 800) would silently
        # regroup samples for non-28x28 inputs instead of failing in fc1.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

# Download MNIST and convert the images to tensors

In [3]:
# Training split with only ToTensor applied (no normalization yet), so the
# raw pixel statistics can be measured in the following cells.
mnist_data = datasets.MNIST(
    root="./mnist_data",
    train=True,
    download=True,
    transform=transforms.Compose([transforms.ToTensor()]),
)

# Get mean and std for normalizing data

In [4]:
# Every training pixel as float32 scaled to [0, 1]: the same values ToTensor
# produces (uint8 / 255). Reading torchvision's raw uint8 `mnist_data.data`
# tensor avoids decoding each sample through PIL + the transform one by one.
# np.mean / np.std below reduce over all elements, so the statistics match
# the per-sample computation.
data = mnist_data.data.numpy().astype(np.float32) / 255.0

In [5]:
# Global mean pixel intensity of the training set (rounded to 0.1307 in Normalize below).
np.asarray(data).mean()

0.13066062

In [6]:
# Global pixel standard deviation of the training set (rounded to 0.3081 in Normalize below).
np.asarray(data).std()

0.30810776

In [7]:
# Shape of a single transformed training image.
mnist_data[0][0].size()

torch.Size([1, 28, 28])

# Training DataLoader and training loop

In [8]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 32
train_dataloader = torch.utils.data.DataLoader(
    datasets.MNIST("./mnist_data", train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       # Training-set pixel mean/std measured above, rounded.
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=batch_size, shuffle=True,
    # Pinned host memory only helps host->GPU copies; skip it on CPU.
    num_workers=1, pin_memory=(device.type == "cuda")
)

In [14]:
def train(model, device, train_loader, optimizer, epoch, log_interval=100):
    """Run one training epoch of ``model`` over ``train_loader``.

    Each batch is moved to ``device``, scored with ``F.nll_loss`` (so the
    model must return log-probabilities), and used for one optimizer step.

    Args:
        model: network to train; switched to training mode first.
        device: torch.device the batches are moved to.
        train_loader: iterable of ``(data, target)`` batches.
        optimizer: optimizer over ``model``'s parameters.
        epoch: epoch index, used only in the progress message.
        log_interval: print the current batch loss every ``log_interval``
            batches (default 100). ``0`` or ``None`` disables logging.
    """
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        output = model(data)  # log-probabilities
        loss = F.nll_loss(output, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if log_interval and batch_idx % log_interval == 0:
            print("Train Epoch:{},iteration:{},loss:{}".format(
                epoch, batch_idx, loss.item()))

# Test DataLoader and evaluation loop

In [10]:
test_dataloader = torch.utils.data.DataLoader(
    datasets.MNIST("./mnist_data", train=False, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       # Normalize with the *training* statistics, as for training.
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    # Evaluation aggregates over the whole split, so shuffling adds nothing;
    # a fixed order keeps test runs reproducible.
    batch_size=batch_size, shuffle=False,
    num_workers=1, pin_memory=(device.type == "cuda")
)

In [16]:
def test(model, device, test_loader):
    model.eval()
    total_loss = 0.
    correct = 0.
    with torch.no_grad():
        for idx, (data, target) in enumerate(test_loader):
            data, target= data.to(device), target.to(device)
            
            output = model(data)
            total_loss += F.nll_loss(output, target, reduction="sum").item()
            pred = output.argmax(dim=1)
            correct += pred.eq(target.view_as(pred)).sum().item()
    
    total_loss /= len(test_loader.dataset)
    correct /= len(test_loader.dataset)
    correct *=100
    print("Test Loss:{},Accuracy:{}".format(total_loss, correct))

# Train the model, evaluating on the test set after each epoch

In [34]:
# Hyper-parameters for SGD with momentum.
lr = 0.01
momentum = 0.5
num_epochs = 2
seed = 1

# Seed before building the model so weight initialization and the shuffle
# order of train_dataloader are reproducible on re-run (CUDA kernels may
# still introduce some nondeterminism).
torch.manual_seed(seed)
model = SimpleCNN().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)
for epoch in range(num_epochs):
    train(model, device, train_dataloader, optimizer, epoch)
    test(model, device, test_dataloader)

# Save only the learned parameters (state_dict), not the whole module object.
torch.save(model.state_dict(), "mnist_cnn.pt")

Train Epoch:0,iteration:0,loss:2.3114089965820312
Train Epoch:0,iteration:100,loss:0.3460025191307068
Train Epoch:0,iteration:200,loss:0.17810732126235962
Train Epoch:0,iteration:300,loss:0.25878429412841797
Train Epoch:0,iteration:400,loss:0.21613113582134247
Train Epoch:0,iteration:500,loss:0.04941364377737045
Train Epoch:0,iteration:600,loss:0.22189849615097046
Train Epoch:0,iteration:700,loss:0.10240968316793442
Train Epoch:0,iteration:800,loss:0.40712541341781616
Train Epoch:0,iteration:900,loss:0.27638140320777893
Train Epoch:0,iteration:1000,loss:0.08502998948097229
Train Epoch:0,iteration:1100,loss:0.09348072111606598
Train Epoch:0,iteration:1200,loss:0.09116894751787186
Train Epoch:0,iteration:1300,loss:0.04040426015853882
Train Epoch:0,iteration:1400,loss:0.22217920422554016
Train Epoch:0,iteration:1500,loss:0.09791646152734756
Train Epoch:0,iteration:1600,loss:0.007152751088142395
Train Epoch:0,iteration:1700,loss:0.007190421223640442
Train Epoch:0,iteration:1800,loss:0.0631