# MNIST Addition

The task considered in this notebook is reminiscent of the classical MNIST digit-classification task. However, instead of providing labels for single digits, we train on pairs of images labeled only with the sum of their two digits. The task was first introduced by Manhaeve et al. (2018).

In [1]:
import torch
import torchvision
import torch.nn as nn
import torchvision.transforms as transforms
torch.manual_seed(1234)

We begin by defining our model, taken from the PyTorch MNIST tutorial.

In [2]:
class MNIST_Net(nn.Module):
    def __init__(self, N=10):
        super(MNIST_Net, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 6, 5), # 1 28 28 -> 6 24 24
            nn.MaxPool2d(2, 2), # 6 24 24 -> 6 12 12
            nn.ReLU(True),
            nn.Conv2d(6, 16, 5), # 6 12 12 -> 16 8 8
            nn.MaxPool2d(2, 2), # 16 8 8 -> 16 4 4
            nn.ReLU(True)
        )
        self.classifier = nn.Sequential(
            nn.Linear(16 * 4 * 4, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, N)
        )

    def forward(self, x):
        x = self.encoder(x)
        x = x.view(-1, 16 * 4 * 4)
        x = self.classifier(x)
        return x
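The shape comments in the encoder, and the `16 * 4 * 4` used by the first linear layer, follow from the usual output-size formulas: an unpadded, stride-1 convolution with kernel size k maps a side of length n to n - k + 1, and a 2x2 max pool with stride 2 maps n to floor((n - 2) / 2) + 1. The plain-Python sketch below simply replays that arithmetic for a 28x28 MNIST image; `conv_out`, `pool_out`, and `encoder_output_shape` are hypothetical helpers, not part of this notebook or of PyTorch.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a 2-D convolution (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel=2, stride=2):
    """Spatial output size of an unpadded max pool."""
    return (size - kernel) // stride + 1


def encoder_output_shape(size=28):
    """Trace a square MNIST image through the MNIST_Net encoder."""
    size = conv_out(size, 5)   # Conv2d(1, 6, 5):  28 -> 24
    size = pool_out(size)      # MaxPool2d(2, 2):  24 -> 12
    size = conv_out(size, 5)   # Conv2d(6, 16, 5): 12 -> 8
    size = pool_out(size)      # MaxPool2d(2, 2):   8 -> 4
    return (16, size, size)    # 16 output channels from the second conv


channels, height, width = encoder_output_shape()
print(channels, height, width, channels * height * width)  # 16 4 4 256
```

Since the final feature map is 16 x 4 x 4, flattening it yields the 256 inputs expected by `nn.Linear(16 * 4 * 4, 120)`.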

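We load the standard MNIST image data. `ToTensor` scales pixel values to [0, 1], and `Normalize((0.5,), (0.5,))` then maps them to [-1, 1].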

In [3]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
mnist_train_data = torchvision.datasets.MNIST(root='./MNIST', train=True, download=True, transform=transform)
mnist_test_data = torchvision.datasets.MNIST(root='./MNIST', train=False, download=True, transform=transform)

test_kwargs = {'batch_size': 256}
test_loader = torch.utils.data.DataLoader(mnist_test_data, **test_kwargs)
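`transforms.ToTensor()` scales 8-bit pixel values to [0, 1], and `transforms.Normalize(mean, std)` then computes `(x - mean) / std` for each channel. With mean and std both set to 0.5, every pixel ends up in [-1, 1]. The plain-Python sketch below only reproduces that arithmetic on a few example pixel values; it does not call torchvision, and `to_unit_range` and `normalize` are hypothetical helper names.

```python
MEAN, STD = 0.5, 0.5  # the values passed to transforms.Normalize above


def to_unit_range(pixel):
    """Mimic ToTensor's scaling of an 8-bit pixel (0..255) to [0, 1]."""
    return pixel / 255.0


def normalize(x, mean=MEAN, std=STD):
    """Per-channel normalization (x - mean) / std, as in transforms.Normalize."""
    return (x - mean) / std


for pixel in (0, 128, 255):
    print(pixel, round(normalize(to_unit_range(pixel)), 3))
```

Black pixels (0) map to -1.0 and white pixels (255) map to 1.0, so the inputs are centered around zero.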

Next, we load the MNIST addition dataset. It was generated by pairing random MNIST digits and labeling each pair with the sum of its digits, so each datum has the form `(idx1, idx2, summation)`: `idx1` is the index of the first image, `idx2` is the index of the second image, and `summation` is the sum of their ground-truth labels.

In [4]:
def load_pairs(path):
    """Read a file of lines like "(idx1, idx2, sum)" into a list of int tuples."""
    with open(path) as f:
        lines = [line.strip() for line in f]
    # Skip blank lines (e.g. a trailing newline at the end of the file)
    return [tuple(int(e) for e in line.strip("()").split(",")) for line in lines if line]

# Tensorize; only the first 9,000 training pairs are used in this notebook
train_data = torch.tensor(load_pairs('train_data.txt'))[:9000]
test_data = torch.tensor(load_pairs('test_data.txt'))
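The script that produced `train_data.txt` and `test_data.txt` is not part of this notebook. The sketch below shows one way such pairs could be generated from a list of digit labels and serialized in the `(idx1, idx2, summation)` line format parsed above. `make_addition_pairs`, `format_pair`, and `parse_pair` are hypothetical helpers, the labels are made up, and the original generator may pair images differently (for example, without reusing images).

```python
import random


def make_addition_pairs(labels, num_pairs, seed=0):
    """Pair random image indices and label each pair with the sum of its digits.

    labels[i] is assumed to be the ground-truth digit of image i.
    """
    rng = random.Random(seed)
    pairs = []
    for _ in range(num_pairs):
        idx1 = rng.randrange(len(labels))
        idx2 = rng.randrange(len(labels))
        pairs.append((idx1, idx2, labels[idx1] + labels[idx2]))
    return pairs


def format_pair(pair):
    """Serialize a pair as one line of the form "(idx1, idx2, summation)"."""
    return "({}, {}, {})".format(*pair)


def parse_pair(line):
    """Inverse of format_pair, using the same parsing as the notebook."""
    return tuple(int(e) for e in line.strip().strip("()").split(","))


# Made-up labels standing in for MNIST targets
toy_labels = [3, 8, 1, 0, 6, 2, 9, 4, 5, 7]
pairs = make_addition_pairs(toy_labels, num_pairs=5)
lines = [format_pair(p) for p in pairs]
print(lines)
```

Only the sums are kept as supervision; the individual digit labels are used to build the pairs but are never shown to the model during training.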

Next, we create the model and an Adam optimizer with a learning rate of 0.001.

In [5]:
model = MNIST_Net()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

Although we train only on pairs of images labeled with their sum, we evaluate in the classic setting, i.e., predicting the label of a single digit on the MNIST test set.

In [6]:
def test():
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            pred = output.argmax(dim=1, keepdim=True)  # index of the largest logit = predicted digit
            correct += pred.eq(target.view_as(pred)).sum().item()

    print('Test set: Accuracy: {}/{} ({:.0f}%)\n'.format(correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


Since only the sums are labeled, the traditional per-digit cross-entropy loss cannot be applied. Instead, we require that the sum of the predicted labels match the ground truth, enforcing this as a constraint at training time. To do so, we import *constraint* from the *pylon.constraint* module. The last line of the cell wraps *enforce_sum* with *constraint*, declaring it as a constraint to be enforced at training time. Note that our constraint function, *enforce_sum*, is a vanilla Python function and does not use any foreign syntax.
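To make the idea of a constraint loss more concrete, here is a minimal sketch of one common way to turn such a constraint into a differentiable loss, the semantic loss: treat the two softmax distributions as independent and take the negative log-probability that the predicted digits satisfy the constraint. For two digits this probability can be computed exactly by enumerating all 100 label pairs. This is an illustration under that assumption, not necessarily how pylon computes its loss, and `softmax` and `sum_constraint_loss` are hypothetical helpers written in plain Python.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sum_constraint_loss(logits1, logits2, summation):
    """Negative log-probability that digit1 + digit2 == summation.

    Assumes the two digits are independent, with distributions given by the
    softmax of each network output (brute-force semantic loss).
    """
    p1, p2 = softmax(logits1), softmax(logits2)
    prob = sum(
        p1[a] * p2[b]
        for a in range(len(p1))
        for b in range(len(p2))
        if a + b == summation  # same check as enforce_sum
    )
    if prob == 0.0:
        return math.inf  # the constraint cannot be satisfied (e.g. summation > 18)
    return -math.log(prob)


uniform = [0.0] * 10
print(sum_constraint_loss(uniform, uniform, 9))  # -log(10 / 100), about 2.303
```

Minimizing this loss pushes probability mass toward every digit pair consistent with the observed sum, which is how a model can learn individual digits without ever seeing single-digit labels.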

In [14]:
# ---------- Set up the constraints ----------
import sys
sys.path.append("../")  # make the local pylon package importable

from pylon.constraint import constraint

def enforce_sum(img1, img2, **kwargs):
    # img1 and img2 stand for the digit labels assigned to the two images;
    # the ground-truth sum is passed in as the keyword argument `summation`
    return img1 + img2 == kwargs['summation']


enforce_sum_constraint = constraint(enforce_sum)

Finally, we proceed to our usual training loop. Instead of a cross-entropy loss, we compute the constraint loss `closs` by calling `enforce_sum_constraint` on the outputs for both images, and minimize it with the optimizer.

In [15]:
from tqdm import tqdm

NUM_EPOCHS = 1

for epoch in range(NUM_EPOCHS):

    # train
    for i, batch in enumerate(tqdm(train_data)):
        model.train()  # test() switches the model to eval mode, so switch back every step
        optimizer.zero_grad()
        idx1, idx2, summation = batch
        # Fetch both images and add a batch dimension: (1, 28, 28) -> (1, 1, 28, 28)
        X1 = mnist_train_data[idx1][0].unsqueeze(0)
        X2 = mnist_train_data[idx2][0].unsqueeze(0)

        output1 = model(X1)
        output2 = model(X2)

        # The constraint loss is computed directly from the model outputs
        closs = enforce_sum_constraint(output1, output2, summation=summation)

        closs.backward()
        optimizer.step()

        # Evaluate on the single-digit test set every 1,000 pairs
        if i % 1000 == 0 and i != 0:
            test()

    test()

 11%|█         | 1009/9000 [00:40<08:34, 15.54it/s] 

Test set: Accuracy: 9599/10000 (96%)



 22%|██▏       | 2011/9000 [01:14<07:53, 14.75it/s]

Test set: Accuracy: 9725/10000 (97%)



 33%|███▎      | 3007/9000 [01:32<06:40, 14.96it/s]

Test set: Accuracy: 9695/10000 (97%)



 45%|████▍     | 4007/9000 [02:01<08:24,  9.90it/s]

Test set: Accuracy: 9765/10000 (98%)



 56%|█████▌    | 5008/9000 [02:22<04:48, 13.86it/s]

Test set: Accuracy: 9766/10000 (98%)



 67%|██████▋   | 6011/9000 [02:46<04:13, 11.78it/s]

Test set: Accuracy: 9750/10000 (98%)



 78%|███████▊  | 7008/9000 [03:08<02:36, 12.69it/s]

Test set: Accuracy: 9760/10000 (98%)



 89%|████████▉ | 8006/9000 [04:07<01:41,  9.79it/s]

Test set: Accuracy: 9746/10000 (97%)



100%|██████████| 9000/9000 [04:26<00:00, 33.77it/s]


Test set: Accuracy: 9755/10000 (98%)

