# Data Preparation :

In [379]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

print(torch.__version__)

mnist_trainset = torchvision.datasets.MNIST(
                              root='./data', 
                              train=True, 
                              download=True, 
                              transform=transforms.Compose([
                                  transforms.ToTensor()
                              ])
                  )

1.9.0+cu111


## Logic to get MNIST dataset into custom tensors :

In [380]:
def create_one_hotencoding_inp(x: int):
  assert isinstance(x, int) or isinstance(x, np.int64) or isinstance(x, np.int32), "input should be of datatype 'int'"
  assert x>=0 and x<=9, "input should be between 0-9"
  base_arr = np.zeros(10)
  base_arr[x] = 1
  return base_arr

# create_one_hotencoding_inp("2")
# create_one_hotencoding_inp(10)

In [381]:
i=0
i_max=10000
lst_imgs = []
lst_labels = []
for img, label in mnist_trainset:
  lst_imgs.append(img)
  lst_labels.append(label)
  i += 1
  if i==i_max:
    break;

In [382]:
# Checking class balance
pd.Series(lst_labels).value_counts().sort_index()

0    1001
1    1127
2     991
3    1032
4     980
5     863
6    1014
7    1070
8     944
9     978
dtype: int64

In [383]:
inp_images = torch.stack(lst_imgs)
inp_images.shape

torch.Size([10000, 1, 28, 28])

In [384]:
out_img_labels = torch.tensor(lst_labels)
out_img_labels.shape

torch.Size([10000])

In [385]:
inp_rand_np = np.random.randint(0,10,len(out_img_labels))
print(f"Number of unique values in my random set '{pd.Series(inp_rand_np).nunique()}'")

Number of unique values in my random set '10'


In [386]:
print(inp_images.shape)

inp_rand = torch.tensor(inp_rand_np)
print(inp_rand.shape)

print(out_img_labels.shape)

out_sum = inp_rand + out_img_labels
print(out_sum.shape)

torch.Size([10000, 1, 28, 28])
torch.Size([10000])
torch.Size([10000])
torch.Size([10000])


In [387]:
inp_rand_ohe = torch.tensor([create_one_hotencoding_inp(x) for x in inp_rand_np])
inp_rand_ohe.shape

torch.Size([10000, 10])

In [388]:
from torch.utils.data import Dataset

class CustomMnistDataset(Dataset):
  def __init__(self):
    self.data = []
    for i in range(0,len(out_sum)):
      self.data.append([inp_images[i], inp_rand_ohe[i], out_img_labels[i], out_sum[i]])

  def __getitem__(self, index):
    return self.data[index]
    
  def __len__(self):
    return len(self.data)

In [389]:
myDataset = CustomMnistDataset()

In [390]:
len(myDataset)

10000

In [391]:
train_loader = torch.utils.data.DataLoader(myDataset
    ,batch_size=100
    ,shuffle=True
)

In [392]:
len(train_loader)

100

In [400]:
i=0
for tr in train_loader:
  print(tr[0].shape, tr[1].shape, tr[2].shape, tr[3].shape)
  i+=1
  if i==3:
    break;

torch.Size([100, 1, 28, 28]) torch.Size([100, 10]) torch.Size([100]) torch.Size([100])
torch.Size([100, 1, 28, 28]) torch.Size([100, 10]) torch.Size([100]) torch.Size([100])
torch.Size([100, 1, 28, 28]) torch.Size([100, 10]) torch.Size([100]) torch.Size([100])


# Model Building :

In [394]:
import torch.nn.functional as F
import torch.nn as nn

In [395]:
class Network(nn.Module):
  def __init__(self):
    super().__init__()
    self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5) 
    self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5)
    self.fc1 = nn.Linear(in_features=(12 * 4 * 4)+10, out_features=120)
    self.fc2 = nn.Linear(in_features=120, out_features=60)
    self.out1 = nn.Linear(in_features=60, out_features=10)
    self.out2 = nn.Linear(in_features=60, out_features=19)
  
  def forward(self, t1, t2):
    # input layer
    x, y = t1, t2

    # conv1 layer
    x = self.conv1(x)
    x = F.relu(x)
    x = F.max_pool2d(x, kernel_size=2, stride=2) # 28 | 24 | 12

    # conv2 layer
    x = self.conv2(x)
    x = F.relu(x)
    x = F.max_pool2d(x, kernel_size=2, stride=2) # 12 | 8 | 4 >> 12x4x4

    # reshapre
    x = x.reshape(-1, 12 * 4 * 4)
    z = torch.cat((x,y.reshape(-1,10)), dim=1)
    
    # fc1 layer
    z = self.fc1(z)
    z = F.relu(z)

    # fc2 layer
    z = self.fc2(z)
    z = F.relu(z)

    # output layer
    o1 = self.out1(z)
    o1 = F.softmax(o1, dim=1)

    # output layer
    o2 = self.out2(z)
    o2 = F.softmax(o2, dim=1)
    return o1, o2

In [396]:
network = Network()

print(network)

Network(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 12, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=202, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=60, bias=True)
  (out1): Linear(in_features=60, out_features=10, bias=True)
  (out2): Linear(in_features=60, out_features=19, bias=True)
)


In [397]:
import torch.optim as optim

torch.set_grad_enabled(True)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f12951c6ed0>

In [398]:
def get_num_correct(preds, labels):
  return preds.argmax(dim=1).eq(labels).sum().item()

In [399]:
optimizer = optim.Adam(network.parameters(), lr=0.01)

for epoch in range(10):

    total_loss = 0
    total_correct_labels = 0
    total_correct_sum = 0

    for batch in train_loader: # Get Batch
        i_images, i_rand, o_labels, o_sum = batch 

        preds_labels, preds_sum = network(i_images, i_rand.float()) # Pass Batch
        loss1 = F.cross_entropy(preds_labels, o_labels) # Calculate Loss
        loss2 = F.cross_entropy(preds_sum, o_sum) # Calculate Loss
        loss = loss1 + loss2

        optimizer.zero_grad()
        loss.backward() # Calculate Gradients
        optimizer.step() # Update Weights

        total_loss += loss.item()
        total_correct_labels += get_num_correct(preds_labels, o_labels)
        total_correct_sum += get_num_correct(preds_sum, o_sum)

    print(
        "epoch", epoch, 
        "total_correct_labels:", total_correct_labels, 
        "total_correct_sum:", total_correct_sum, 
        "loss:", total_loss
    )

epoch 0 total_correct_labels: 5551 total_correct_sum: 926 loss: 483.2493739128113
epoch 1 total_correct_labels: 8133 total_correct_sum: 916 loss: 458.5680203437805
epoch 2 total_correct_labels: 9146 total_correct_sum: 899 loss: 448.76365900039673
epoch 3 total_correct_labels: 9052 total_correct_sum: 881 loss: 449.88876485824585
epoch 4 total_correct_labels: 9279 total_correct_sum: 909 loss: 447.25404691696167
epoch 5 total_correct_labels: 9196 total_correct_sum: 915 loss: 448.053747177124
epoch 6 total_correct_labels: 9285 total_correct_sum: 903 loss: 447.3347878456116
epoch 7 total_correct_labels: 9325 total_correct_sum: 883 loss: 447.1517286300659
epoch 8 total_correct_labels: 9044 total_correct_sum: 899 loss: 449.80251693725586
epoch 9 total_correct_labels: 8951 total_correct_sum: 902 loss: 450.6666383743286
