In [1]:
#importing necessary packages
import torch
import numpy as np
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import Dataset
import torch.optim as optim

In [2]:
class AddDataset(Dataset):
  """
    The dataset will return a tuple ((image, label), random number)
  """
  def __init__(self, mnist_set):
    self.data = mnist_set

  def __getitem__(self, index):
    r = self.data[index]
    image, label = r
    n = np.random.randint(10) #function used to create a random integer in [0,9]

    return (image, label) , n

  def __len__(self):
    return len(self.data)


In [3]:
class Network(nn.Module):
  def __init__(self):
    super().__init__()

    self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
    self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5)
    self.fc1 = nn.Linear(in_features=12*4*4, out_features=120)
    self.fc2 = nn.Linear(in_features=120, out_features=60)
    self.fc3 = nn.Linear(in_features=60, out_features=10)

    self.fc4 = nn.Linear(in_features=2, out_features=30)
    self.fc5 = nn.Linear(in_features=30, out_features=120)
    self.out = nn.Linear(in_features=120, out_features=19)

  def forward(self, t, t2):

    x = t

    # MNIST block starts
    x = self.conv1(x)
    x = F.relu(x)
    x = F.max_pool2d(x, kernel_size=2, stride=2)

    x = self.conv2(x)
    x = F.relu(x)
    x = F.max_pool2d(x, kernel_size=2, stride=2)
    
    x = x.reshape(-1, 12*4*4)

    x = self.fc1(x)
    x = F.relu(x)

    x = self.fc2(x)
    x = F.relu(x)

    x = self.fc3(x)

    y = x
    #MNIST block ends

    x = F.softmax(x, dim=1)

    x = x.argmax(dim=1) #converting to prediction integer to use in sum block

    t3 = torch.stack((x, t2), dim = 1)
    t3 = t3.float()
    # the input vector for the sum block will be tuple of (predicted label, random number)
    # i.e we are combining the first input and the second input here

    # sum predictor block starts
    t3 = self.fc4(t3)
    t3 = F.relu(t3)

    t3 = self.fc5(t3)
    t3 = F.relu(t3)

    t3 = self.out(t3)
    # sum predictor block ends

    return y, t3 #returning the prediction tensors for MNIST and sum respectively

In [4]:
def get_num_correct(preds, labels):
  return preds.argmax(dim=1).eq(labels).sum().item()

In [5]:
network = Network()
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') #setting device as GPU if available

network.to(device) #moving the neural network to device(GPU if available)

Network(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 12, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=192, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=60, bias=True)
  (fc3): Linear(in_features=60, out_features=10, bias=True)
  (fc4): Linear(in_features=2, out_features=30, bias=True)
  (fc5): Linear(in_features=30, out_features=120, bias=True)
  (out): Linear(in_features=120, out_features=19, bias=True)
)

In [6]:
mnist_set = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor()
    ])
)

train_set = AddDataset(mnist_set) #converting the MNIST dataset into our custom dataset

In [7]:
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=100,
    shuffle=True
) #initialising a dataLoader with batch_size = 100

optimiser = optim.Adam(network.parameters(), lr=0.01)

for epoch in range(15):

  total_loss = 0
  total_image_correct = 0
  total_sum_correct = 0


  for batch in train_loader:
    (images, labels), ns = batch
    
    images = images.to(device)
    labels = labels.to(device)
    ns = ns.to(device)


    preds = network(images, ns)

    #different losses for the two predictions
    image_loss = F.cross_entropy(preds[0], labels)
    add_loss = F.cross_entropy(preds[1], labels+ns)

    optimiser.zero_grad()
    #different backward props since they have different tensors and different losses
    image_loss.backward()
    add_loss.backward()

    optimiser.step()

    total_loss += (image_loss + add_loss).item()
    total_image_correct += get_num_correct(preds[0], labels)
    total_sum_correct += get_num_correct(preds[1], labels+ns)

  print("epoch", epoch)
  print(
      "image loss:", total_loss,
      "add loss:", total_loss,
      "image_accuracy:", (total_image_correct/60000),
      "add_accuracy:", (total_sum_correct/60000)
  )


epoch 0
image loss: 827.0129678547382 add loss: 827.0129678547382 image_accuracy: 0.9443666666666667 add_accuracy: 0.6745
epoch 1
image loss: 383.8723166882992 add loss: 383.8723166882992 image_accuracy: 0.9779833333333333 add_accuracy: 0.9458
epoch 2
image loss: 244.88867956399918 add loss: 244.88867956399918 image_accuracy: 0.9804333333333334 add_accuracy: 0.9771833333333333
epoch 3
image loss: 180.79763338714838 add loss: 180.79763338714838 image_accuracy: 0.9830833333333333 add_accuracy: 0.9814333333333334
epoch 4
image loss: 159.52213974110782 add loss: 159.52213974110782 image_accuracy: 0.9838666666666667 add_accuracy: 0.9795333333333334
epoch 5
image loss: 136.27398498170078 add loss: 136.27398498170078 image_accuracy: 0.9851 add_accuracy: 0.9821333333333333
epoch 6
image loss: 135.65634261630476 add loss: 135.65634261630476 image_accuracy: 0.9851166666666666 add_accuracy: 0.9828333333333333
epoch 7
image loss: 124.36413899995387 add loss: 124.36413899995387 image_accuracy: 0.98

In [8]:
torch.set_grad_enabled(False)
sample = train_set[10] 
(image, label), n = sample
# image.shape, image.unsqueeze(0).shape

prred = network(image.to(device), torch.tensor([n]).to(device))
torch.set_grad_enabled(True)
print(prred, "tt:", label, n)

(tensor([[-18.7109,  -3.2683,  -8.1701,  25.6209, -15.6780,  13.8173, -12.4034,
           1.3868,   0.9079,  11.7303]], device='cuda:0'), tensor([[ -20.6820,  -12.2311,   -9.4469,   -9.2068,   -6.7497,   -6.5078,
           -4.6843,    1.2212,   -1.9230,   -6.0582,   -2.5299,   -4.9681,
           -2.9078,   -3.8296,   -7.3303,  -10.1213,  -10.1486,  -30.7801,
         -104.2772]], device='cuda:0')) tt: 3 4
