In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from torch.optim.lr_scheduler import StepLR

# Where the magic happens...
# torch.manual_seed(3456)

# Needed to download MNIST dataset without HTTP Error
from six.moves import urllib
opener = urllib.request.build_opener()
opener.addheaders = [('User-agent', 'Mozilla/5.0')]
urllib.request.install_opener(opener)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class MNIST_Net(nn.Module):
    def __init__(self, N=10):
        super(MNIST_Net, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1,  6, 5),
            nn.MaxPool2d(2, 2), # 6 24 24 -> 6 12 12
            nn.ReLU(True),
            nn.Conv2d(6, 16, 5), # 6 12 12 -> 16 8 8
            nn.MaxPool2d(2, 2), # 16 8 8 -> 16 4 4
            nn.ReLU(True)
        )
        self.classifier =  nn.Sequential(
            nn.Linear(16 * 4 * 4, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, N),
            nn.Softmax(1)
        )

    def forward(self, x):
        x = self.encoder(x)
        x = x.view(-1, 16 * 4 * 4)
        x = self.classifier(x)
        return x

In [3]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
mnist_train_data = torchvision.datasets.MNIST(root='./MNIST', train=True, download=True,transform=transform)
mnist_test_data = torchvision.datasets.MNIST(root='./MNIST', train=False, download=True,transform=transform)

kwargs = {'batch_size': 1}

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/MNIST/raw/train-images-idx3-ubyte.gz


9913344it [00:00, 10749537.98it/s]                             


Extracting ./MNIST/MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/MNIST/raw/train-labels-idx1-ubyte.gz


29696it [00:00, 6944360.59it/s]          

Extracting ./MNIST/MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz





Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/MNIST/raw/t10k-images-idx3-ubyte.gz


1649664it [00:00, 11369516.79it/s]                           


Extracting ./MNIST/MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/MNIST/raw/t10k-labels-idx1-ubyte.gz


5120it [00:00, 2736345.12it/s]          

Extracting ./MNIST/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/MNIST/raw






In [4]:
# ---------- train_data ----------
with open('train_data.txt') as f:
    train_data = f.readlines()
    
# Strip new lines
train_data = [d.strip() for d in train_data]

# Convert strings (e.g. "(datum_i, datum_j, sum)") to tuples of ints
train_data = [tuple(int(e) for e in d.strip("()").split(",")) for d in train_data]

# ---------- test data ----------
with open('test_data.txt') as f:
    test_data = f.readlines()
    
# Strip new lines
test_data = [d.strip() for d in test_data]

# Convert strings (e.g. "(datum_i, datum_j, sum)") to tuples of ints
test_data = [tuple(int(e) for e in d.strip("()").split(",")) for d in test_data]

# ---------- network and optimizer ----------
model = MNIST_Net()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [5]:
# Tensorize
train_data = torch.tensor(train_data)

# Batch
train_data = [train_data[i: i+ kwargs['batch_size']] for i in range(0, len(train_data), kwargs['batch_size'])]
train_data = torch.stack(train_data[:-1])

# Tensorize
test_data = torch.tensor(test_data)

# Batch
test_data = [test_data[i: i+kwargs['batch_size']] for i in range(0, len(test_data), kwargs['batch_size'])]
test_data = torch.stack(test_data[:-1])

In [6]:
def test():
    model.eval()

    total = 0
    correct = 0
    for j, test_batch in enumerate(test_data):
        idx1, idx2, summation = test_batch[0]
        X1 = mnist_test_data[idx1][0].unsqueeze(0)
        X2 = mnist_test_data[idx2][0].unsqueeze(0)

        output1 = model(X1)
        output2 = model(X2)

        pred1 = output1.argmax(dim=1, keepdim=False)
        pred2 = output2.argmax(dim=1, keepdim=False)
        correct += (summation == (pred1 + pred2)).sum()
        total += len(test_batch)

    print('Test Accuracy: {}/{} ({:.0f}%)\n'.format(correct, total, 100. * correct / total)) 

In [7]:
def brute_force(output1, output2, summation):
    combinations =  torch.cartesian_prod(torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]))
    probs = []
    for combination in combinations:
        probs += [output1[0][combination[0]] * output2[0][combination[1]]]
    probs = torch.stack(probs)
    indices = (combinations[:, 0] + combinations[:, 1] == summation).nonzero(as_tuple=True)[0]
    return -torch.log(probs[indices].sum())

In [8]:
from tqdm import tqdm

NUM_EPOCHS = 1

for epoch in range(NUM_EPOCHS):
      
    # train
    for i, batch in enumerate(tqdm(train_data)):

        model.train()
        optimizer.zero_grad()
        
        idx1, idx2, summation = batch[0]
        X1 = mnist_train_data[idx1][0].unsqueeze(0)
        X2 = mnist_train_data[idx2][0].unsqueeze(0)
        output1 = model(X1)
        output2 = model(X2)
        
        pred1 = output1.argmax(dim=1, keepdim=False)
        pred2 = output2.argmax(dim=1, keepdim=False)
        
        closs = brute_force(output1, output2, summation)
            
        closs.backward()
        optimizer.step()
        
        if i % 1000 == 0 and i != 0:
            test()

  return _VF.cartesian_prod(tensors)  # type: ignore[attr-defined]
  3%|▎         | 1011/29999 [00:20<1:16:40,  6.30it/s]

Test Accuracy: 2986/4999 (60%)



  7%|▋         | 2018/29999 [00:38<1:20:41,  5.78it/s]

Test Accuracy: 4467/4999 (89%)



 10%|█         | 3009/29999 [00:59<1:31:37,  4.91it/s]

Test Accuracy: 4458/4999 (89%)



 13%|█▎        | 4009/29999 [01:43<4:42:57,  1.53it/s]

Test Accuracy: 4547/4999 (91%)



 17%|█▋        | 5019/29999 [02:04<1:22:38,  5.04it/s]

Test Accuracy: 4435/4999 (89%)



 20%|██        | 6016/29999 [02:27<1:15:03,  5.32it/s]

Test Accuracy: 4694/4999 (94%)



 23%|██▎       | 7016/29999 [02:50<1:04:04,  5.98it/s]

Test Accuracy: 4671/4999 (93%)



 27%|██▋       | 8000/29999 [03:08<08:38, 42.45it/s]  


KeyboardInterrupt: 