In [1]:
from __future__ import print_function

import argparse
import itertools
import time

import torch
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

from cnn_scalar_form import CNNScalarForm
from cnn_vector_form import CNNVectorForm

In [2]:
def train(log_interval, model, device, train_loader, optimizer, epoch, num_iter=None):
    """Train ``model`` for one epoch, or only its first ``num_iter`` batches.

    Each batch is moved to ``device``, passed through ``model`` and optimised
    with ``F.nll_loss``. NLL loss expects log-probabilities, so ``model``
    should end in something like ``log_softmax``.

    Args:
        log_interval: print a progress line every ``log_interval`` batches.
        model: module being trained; switched to training mode first.
        device: device each ``(data, target)`` batch is moved to.
        train_loader: DataLoader yielding ``(data, target)`` batches.
        optimizer: ``zero_grad()`` / ``step()`` are called once per batch.
        epoch: epoch number, used only in the progress line.
        num_iter: if not None, stop after this many batches
            (``0`` trains nothing).

    Raises:
        ValueError: if ``num_iter`` is negative (raised by ``itertools.islice``).
    """
    model.train()
    # islice checks the limit *before* pulling the next batch, so exactly
    # ``num_iter`` batches are loaded and trained (none when num_iter == 0);
    # ``None`` means "the whole loader".
    limited_batches = itertools.islice(enumerate(train_loader), num_iter)
    for batch_idx, (data, target) in limited_batches:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


#### Parameters

In [3]:
# Everything runs on the CPU; set to True to put batches on a CUDA device.
use_cuda = False
device = torch.device("cuda" if use_cuda else "cpu")

# Seed PyTorch's global RNG so random operations (e.g. shuffling) repeat.
torch.manual_seed(1)

# Extra DataLoader options; left empty when not using CUDA.
kwargs = {}
if use_cuda:
    kwargs = {'num_workers': 1, 'pin_memory': True}

# Default mini-batch size (the visible check cells call the loaders with 1).
batch_size = 64

#### Load dataset

In [4]:
def get_train_loader(batch_size):
    """Return a shuffled DataLoader over the MNIST training split.

    The data lives under ``../data`` and is downloaded there if missing.
    Images are converted to tensors and normalised with mean 0.1307 and
    std 0.3081 (presumably the MNIST training-set statistics). Extra
    DataLoader options come from the notebook-global ``kwargs``.
    """
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    mnist_train = datasets.MNIST('../data', train=True, download=True,
                                 transform=preprocess)
    return torch.utils.data.DataLoader(mnist_train, batch_size=batch_size,
                                       shuffle=True, **kwargs)



In [5]:
def get_test_loader(batch_size, shuffle=False):
    """Return a DataLoader over the MNIST test split.

    Args:
        batch_size: samples per batch.
        shuffle: reshuffle the test set on every pass. Defaults to False:
            evaluation metrics such as the summed loss and correct-count do
            not depend on sample order, and a fixed order makes evaluation
            runs easier to reproduce and compare.

    Returns:
        ``torch.utils.data.DataLoader`` over ``datasets.MNIST('../data',
        train=False)``, using the notebook-global ``kwargs`` for extra
        DataLoader options.
    """
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        # Same normalisation constants as the training loader.
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    # download=True lets this work on a fresh machine even when the training
    # loader has not been built yet; torchvision skips it if files exist.
    mnist_test = datasets.MNIST('../data', train=False, download=True,
                                transform=preprocess)
    return torch.utils.data.DataLoader(mnist_test, batch_size=batch_size,
                                       shuffle=shuffle, **kwargs)


### CNN scalar form model

Let's check the intermediate results after each step of the forward pass of the CNN implemented in scalar form; each step is compared with the corresponding PyTorch computation. To avoid wasting time, we train on a single batch (`num_iter=1`) with `batch_size = 1` instead of a full epoch.

In [6]:
# NOTE(review): the positional True presumably enables the per-layer
# "Check ... MSE" comparison printed during training — confirm in cnn_scalar_form.
model = CNNScalarForm(True)
model = model.to(device)
optimizer = optim.SGD(model.parameters(), momentum=0.5, lr=0.01)

In [7]:
# Batch size 1 for both splits, matching the single-sample checks described above.
train_loader, test_loader = get_train_loader(1), get_test_loader(1)

In [8]:
# perf_counter is a monotonic, high-resolution clock intended for measuring
# intervals; time.time() follows the wall clock and can jump.
start = time.perf_counter()
train(log_interval=1, model=model, device=device, train_loader=train_loader,
      optimizer=optimizer, epoch=1, num_iter=1)  # a single training batch
end = time.perf_counter()
print(f"Time executing: {end - start} s")

Check conv. MSE: 6.838811031958502e-15
Check max pool. MSE: 6.992468737821063e-15
Check reshape. MSE: 6.992468737821063e-15
Check fc1. MSE: 2.224150497023153e-13
Check relu. MSE: 1.0452203610001459e-13
Check fc2. MSE: 6.384476280985041e-14
Check softmax. MSE: 8.881784197001252e-16
Time executing: 36.4209041595459 s


### CNN vector form model

In [9]:
# NOTE(review): the positional True presumably enables the per-layer
# "Check ... MSE" comparison printed during training — confirm in cnn_vector_form.
model = CNNVectorForm(True)
model = model.to(device)
optimizer = optim.SGD(model.parameters(), momentum=0.5, lr=0.01)

Compare the results after each step of the forward pass with PyTorch, again training on a single batch with `batch_size = 1`.

In [11]:
# perf_counter is a monotonic, high-resolution clock intended for measuring
# intervals; time.time() follows the wall clock and can jump.
start = time.perf_counter()
train(log_interval=1, model=model, device=device, train_loader=train_loader,
      optimizer=optimizer, epoch=1, num_iter=1)  # a single training batch
end = time.perf_counter()
print(f"Time executing: {end - start} s")

Check conv. MSE: 5.287909375645402e-15
Check max pool. MSE: 4.804774911542232e-15
Check reshape. MSE: 4.804774911542232e-15
Check fc1. MSE: 5.5813840149025046e-14
Check relu. MSE: 3.0691263061804336e-14
Check fc2. MSE: 1.289246487345963e-14
Check softmax. MSE: 1.3877787807814457e-16
Time executing: 0.6581439971923828 s


As expected, the vector-form implementation is significantly faster than the scalar form for `batch_size = 1`: 0.66 s vs. 36.4 s, roughly 55× faster.