# Training LeNet using MNIST and Devito

In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import joey as ml
import numpy as np
import matplotlib.pyplot as plt
from devito import logger

In [2]:
# NOTE(review): presumably lowers Devito's log verbosity so that performance
# messages are not printed in the notebook -- confirm against the Devito docs.
logger.set_log_noperf()

In [3]:
def create_lenet():
    """Build the LeNet-style Joey network used in this notebook.

    The batch dimension of every layer is read from the module-level
    ``batch_size``, so that name must be defined before this function is
    called. Every layer is created with ``generate_code=False``.

    Returns:
        A ``(net, layers)`` tuple: the ``ml.Net`` built from the layers and
        the list of layers in forward order (conv, pool, conv, pool,
        flatten, then three fully connected layers).
    """
    batch = batch_size
    no_codegen = {'generate_code': False}

    # Convolution: six 3x3 filters over a 1x32x32 input, ReLU activation.
    conv_a = ml.Conv(kernel_size=(6, 3, 3),
                     input_size=(batch, 1, 32, 32),
                     activation=ml.activation.ReLU(),
                     **no_codegen)

    # 2x2 max pooling with stride 2 over the 6x30x30 feature maps.
    pool_a = ml.MaxPooling(kernel_size=(2, 2),
                           input_size=(batch, 6, 30, 30),
                           stride=(2, 2),
                           **no_codegen)

    # Convolution: sixteen 3x3 filters over 6x15x15, ReLU activation.
    conv_b = ml.Conv(kernel_size=(16, 3, 3),
                     input_size=(batch, 6, 15, 15),
                     activation=ml.activation.ReLU(),
                     **no_codegen)

    # 2x2 max pooling with stride 2 over 16x13x13; the stride does not
    # divide 13 evenly, hence strict_stride_check=False.
    pool_b = ml.MaxPooling(kernel_size=(2, 2),
                           input_size=(batch, 16, 13, 13),
                           stride=(2, 2),
                           strict_stride_check=False,
                           **no_codegen)

    # Fully connected 576 (= 16 * 6 * 6) -> 120, ReLU activation.
    dense_a = ml.FullyConnected(weight_size=(120, 576),
                                input_size=(576, batch),
                                activation=ml.activation.ReLU(),
                                **no_codegen)

    # Fully connected 120 -> 84, ReLU activation.
    dense_b = ml.FullyConnected(weight_size=(84, 120),
                                input_size=(120, batch),
                                activation=ml.activation.ReLU(),
                                **no_codegen)

    # Output layer: fully connected 84 -> 10 with softmax.
    dense_out = ml.FullyConnectedSoftmax(weight_size=(10, 84),
                                         input_size=(84, batch),
                                         **no_codegen)

    # The flattening layer is created last but sits between the second
    # pooling layer and the first fully connected layer in the layer list.
    flatten = ml.Flat(input_size=(batch, 16, 6, 6), **no_codegen)

    layers = [conv_a, pool_a, conv_b, pool_b,
              flatten, dense_a, dense_b, dense_out]

    return ml.Net(layers), layers

def relu(x):
    """Return ``Max(0, x)``, the rectified-linear form of ``x``.

    NOTE(review): ``Max`` is not imported in any cell visible here; it
    presumably comes from SymPy or Devito -- confirm before calling.
    """
    rectified = Max(0, x)
    return rectified

def maximum(lst):
    """Return ``Max`` applied to the elements of ``lst`` as separate arguments.

    NOTE(review): ``Max`` is not imported in any cell visible here; it
    presumably comes from SymPy or Devito -- confirm before calling.
    """
    operands = tuple(lst)
    return Max(*operands)

In [4]:
def train(net, input_data, expected_results, pytorch_optimizer):
    """Run one forward and one backward pass of ``net`` on a single batch.

    Args:
        net: Network object exposing ``forward(input_data)`` and
            ``backward(loss_grad, optimizer)``.
        input_data: Batch handed unchanged to ``net.forward``.
        expected_results: Indexable by batch position; entry ``b`` is the
            expected class index of sample ``b``.
        pytorch_optimizer: Optimizer handed unchanged to ``net.backward``.
    """
    # The return value of the forward pass is not needed here: the gradient
    # callback below reads the output from ``layer.result`` instead.
    net.forward(input_data)

    def loss_grad(layer, b):
        """Return the ten per-class loss gradients for batch sample ``b``.

        Each entry is ``layer.result.data[i, b]``, with 1 subtracted at the
        expected class index. NOTE(review): this is the softmax +
        cross-entropy gradient if ``layer.result`` holds softmax
        probabilities -- presumably the case for the output layer; confirm.
        """
        target = expected_results[b]
        values = layer.result.data
        return [values[i, b] - 1 if i == target else values[i, b]
                for i in range(10)]

    net.backward(loss_grad, pytorch_optimizer)

In [5]:
# Number of samples per batch; read by create_lenet() for every layer's
# input size and passed to the DataLoader.
batch_size = 4
# Number of batches each training loop processes before it stops.
iterations = 100

In [6]:
# Preprocessing: resize the images to the 32x32 input expected by the first
# layer, convert them to tensors, then normalize with 0.5 / 0.5.
preprocessing_steps = [
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(0.5, 0.5),
]
transform = transforms.Compose(preprocessing_steps)

# MNIST training split, downloaded into ./mnist when not already present.
trainset = torchvision.datasets.MNIST(
    root='./mnist',
    train=True,
    download=True,
    transform=transform,
)

# Unshuffled loader, so both training loops see the batches in the same order.
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
)

# Class labels: the ten digits as strings.
classes = tuple(str(digit) for digit in range(10))

In [7]:
# Build the Joey network and keep its layer list for later weight inspection.
lenet_and_layers = create_lenet()
devito_net = lenet_and_layers[0]
devito_layers = lenet_and_layers[1]

# SGD optimizer over the parameters the Joey network exposes to PyTorch.
optimizer = optim.SGD(devito_net.pytorch_parameters, momentum=0.9, lr=0.001)

  spacing = (np.array(self.extent) / (np.array(self.shape) - 1)).astype(self.dtype)


In [8]:
def _snapshot(layer):
    """Return ``(kernel, bias)`` of ``layer`` as new torch tensors."""
    kernel_copy = torch.tensor(layer.kernel.data)
    bias_copy = torch.tensor(layer.bias.data)
    return kernel_copy, bias_copy


# Capture the initial weights of every trainable layer before training, so the
# same values can later be loaded into the PyTorch network. Indices 1, 3 and 4
# are the pooling and flattening layers and are skipped.
layer1_kernel, layer1_bias = _snapshot(devito_layers[0])
layer3_kernel, layer3_bias = _snapshot(devito_layers[2])
layer5_kernel, layer5_bias = _snapshot(devito_layers[5])
layer6_kernel, layer6_bias = _snapshot(devito_layers[6])
layer7_kernel, layer7_bias = _snapshot(devito_layers[7])

In [9]:
# Train the Joey network on the first `iterations` batches of the loader.
for i, data in enumerate(trainloader, 0):
    images, labels = data
    # Tensor.double() returns a new tensor rather than converting in place,
    # so the result must be rebound for the cast to take effect.
    images = images.double()

    train(devito_net, images, labels, optimizer)

    # Stop once `iterations` batches have been processed.
    if i == iterations - 1:
        break

The same network, with identical initial weights, trained in PyTorch for comparison:

In [10]:
class Net(nn.Module):
    """PyTorch LeNet-style network: two conv/pool stages and three linear layers.

    Expects single-channel input; with 32x32 images the flattened feature
    vector fed to ``fc1`` has 16 * 6 * 6 = 576 entries.
    """

    def __init__(self):
        super().__init__()
        # Two 3x3 convolutions: 1 -> 6 channels, then 6 -> 16 channels.
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        # Three fully connected layers: 576 -> 120 -> 84 -> 10.
        self.fc1 = nn.Linear(16 * 6 * 6, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return the raw (pre-softmax) class scores for the batch ``x``."""
        pooled1 = F.max_pool2d(F.relu(self.conv1(x)), 2)
        pooled2 = F.max_pool2d(F.relu(self.conv2(pooled1)), 2)
        # Collapse every non-batch dimension into one feature axis.
        flat = pooled2.view(-1, self.num_flat_features(pooled2))
        hidden1 = F.relu(self.fc1(flat))
        hidden2 = F.relu(self.fc2(hidden1))
        return self.fc3(hidden2)

    def num_flat_features(self, x):
        """Return the product of all dimensions of ``x`` except the first."""
        return x.size()[1:].numel()

In [11]:
# PyTorch reference network, switched to double precision.
net = Net()
net.double()

# Each PyTorch module paired with the initial kernel/bias captured from the
# corresponding Joey layer.
initial_weights = [
    (net.conv1, layer1_kernel, layer1_bias),
    (net.conv2, layer3_kernel, layer3_bias),
    (net.fc1, layer5_kernel, layer5_bias),
    (net.fc2, layer6_kernel, layer6_bias),
    (net.fc3, layer7_kernel, layer7_bias),
]

# Overwrite the parameters in place without recording the writes in autograd.
with torch.no_grad():
    for module, kernel_values, bias_values in initial_weights:
        module.weight[:] = kernel_values
        module.bias[:] = bias_values

In [12]:
# Same optimizer settings as for the Joey network, now over the PyTorch
# parameters, with cross-entropy as the loss.
optimizer = optim.SGD(net.parameters(), momentum=0.9, lr=0.001)
criterion = nn.CrossEntropyLoss()

for i, data in enumerate(trainloader):
    # Clear gradients left over from the previous batch.
    optimizer.zero_grad()

    images, labels = data
    outputs = net(images.double())
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()

    # Stop once `iterations` batches have been processed.
    if i + 1 == iterations:
        break

In [13]:
# Trainable Joey layers and their PyTorch counterparts, in matching order.
layers = [devito_layers[k] for k in (0, 2, 5, 6, 7)]
pytorch_layers = [net.conv1, net.conv2, net.fc1, net.fc2, net.fc3]

# Largest relative error seen so far and the position of its layer pair.
max_error = 0
index = -1

for i, (devito_layer, torch_layer) in enumerate(zip(layers, pytorch_layers)):
    # Element-wise relative error of the kernel against the PyTorch weights.
    kernel = devito_layer.kernel.data
    pytorch_kernel = torch_layer.weight.detach().numpy()
    kernel_error = abs(kernel - pytorch_kernel) / abs(pytorch_kernel)

    # Element-wise relative error of the bias.
    bias = devito_layer.bias.data
    pytorch_bias = torch_layer.bias.detach().numpy()
    bias_error = abs(bias - pytorch_bias) / abs(pytorch_bias)

    # nanmax skips NaN entries when picking the largest error.
    error = max(np.nanmax(kernel_error), np.nanmax(bias_error))
    print(f'layers[{i!s}] maximum relative error: {error!s}')

    if error > max_error:
        max_error, index = error, i

print()
print(f'Maximum relative error is in layers[{index!s}]: {max_error!s}')

layers[0] maximum relative error: 1.3407192418942228e-14
layers[1] maximum relative error: 2.5473953487847755e-12
layers[2] maximum relative error: 1.0280702754555613e-12
layers[3] maximum relative error: 8.103937510501394e-13
layers[4] maximum relative error: 1.3993901259386298e-13

Maximum relative error is in layers[1]: 2.5473953487847755e-12
