<a href="https://colab.research.google.com/github/Pavithra777/Assignment2.5/blob/main/Assignment2_5.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset

In [None]:
# Load the MNIST training split, downloading it into ./data if it is not
# already there. Each image is converted to a tensor by ToTensor().
to_tensor = transforms.Compose([transforms.ToTensor()])
train_set = torchvision.datasets.MNIST(
    root='./data', train=True, download=True, transform=to_tensor
)

In [None]:
# Display the dataset summary (size, root, split, transform) as the cell output.
train_set

Dataset MNIST
    Number of datapoints: 60000
    Root location: ./data
    Split: Train
    StandardTransform
Transform: Compose(
               ToTensor()
           )

In [None]:
import torch
from torchvision import datasets, transforms

# This class pairs every MNIST training sample with an extra input value
class MNISTWithExtraInput(torch.utils.data.Dataset):
    """MNIST training set whose samples carry one additional input each.

    Args:
        root: Directory the MNIST files are stored in (downloaded if absent).
        extra_input_data: Indexable collection; element ``i`` is returned
            alongside MNIST sample ``i``.
        transform: Optional transform applied to each image by the
            underlying ``datasets.MNIST``.
    """

    def __init__(self, root, extra_input_data, transform=None):
        self.extra_input_data = extra_input_data
        self.transform = transform
        self.root = root

        # The wrapped dataset is always the training split.
        self.mnist_dataset = datasets.MNIST(
            root=self.root,
            train=True,
            download=True,
            transform=self.transform,
        )

    def __len__(self):
        """Number of samples, taken from the wrapped MNIST dataset."""
        return len(self.mnist_dataset)

    def __getitem__(self, idx):
        """Return ``(image, label, extra_input)`` for position ``idx``."""
        image, label = self.mnist_dataset[idx]
        return image, label, self.extra_input_data[idx]


In [None]:
import torch
from torchvision import datasets, transforms

# Combining List of MNISTWithExtraInput dataset 
class CombinedDataset(torch.utils.data.Dataset):
    """Concatenation of several map-style datasets into one flat index space.

    Index ``0`` is the first element of the first dataset; indices continue
    through each dataset in the order given.

    Args:
        datasets: Sequence of objects supporting ``len()`` and integer
            indexing.
    """

    def __init__(self, datasets):
        self.datasets = datasets

    def __len__(self):
        """Total number of elements across all wrapped datasets."""
        return sum(len(dataset) for dataset in self.datasets)

    def __getitem__(self, idx):
        """Return the element at flat position ``idx``.

        Negative indices count from the end, as for a list.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
                Raising (rather than returning ``None``) is what lets plain
                ``for item in dataset`` iteration terminate.
        """
        total = len(self)
        position = idx + total if idx < 0 else idx
        if not 0 <= position < total:
            raise IndexError(
                f"index {idx} is out of range for CombinedDataset of length {total}"
            )

        # Walk the datasets, shifting the position into each one's local range.
        for dataset in self.datasets:
            size = len(dataset)
            if position < size:
                return dataset[position]
            position -= size

        # Not reachable while the wrapped datasets report consistent lengths.
        raise IndexError(
            f"index {idx} is out of range for CombinedDataset of length {total}"
        )

In [None]:
# Build ten copies of the MNIST training set, one per extra-input value 0..9,
# so that every image is paired with every possible second number.
# The transform and the zero tensor do not depend on the loop variable, so
# they are created once instead of on every iteration.
transform = transforms.Compose([
      transforms.ToTensor()
  ])
# One extra-input row per training image (60000 rows are allocated here).
zero_tensor = torch.zeros(60000,1)

dataset_list = []
for i in range(0,10):
  # zero_tensor + i allocates a new tensor filled with i; zero_tensor is not modified.
  mnist_dataset = MNISTWithExtraInput('./data', zero_tensor+i,transform=transform)
  dataset_list.append(mnist_dataset)
# Merge all 10 MNISTWithExtraInput datasets into a single dataset
combined_dataset=CombinedDataset(dataset_list)

In [None]:
# Shuffled mini-batch loader over the combined dataset, 64 samples per batch
train_loader = torch.utils.data.DataLoader(
    combined_dataset,
    batch_size=64,
    shuffle=True,
)


In [None]:
# Spot-check two samples of the combined dataset; each one is an
# (image, label, extra_input) tuple.
# Only the value of the last expression is rendered as the cell output.
train_loader.dataset[0]
train_loader.dataset[8]

In [None]:
import torch.nn as nn
import torch.nn.functional as F
class Network(nn.Module):
  """Two-headed model: classify an MNIST digit and the digit-plus-number sum.

  The convolutional stack turns a 1x28x28 image into 10 digit logits. Those
  logits are concatenated with a 10-wide one-hot encoding of a second number
  and passed through two linear layers to produce 20 logits for the sum.
  """

  def __init__(self):
    super().__init__()
    self.conv1 = nn.Conv2d(in_channels=1, out_channels= 32,kernel_size=3)  # 1x28x28|32x3x3|32x26x26
    self.conv2 = nn.Conv2d(in_channels=32, out_channels= 64,kernel_size=3) # 32x26x26|64x3x3|64x24x24
    # conv3 runs after the 2x2 max-pool in forward(), hence the 12x12 input
    self.conv3 = nn.Conv2d(in_channels=64, out_channels= 128,kernel_size=3) # 64x12x12|128x3x3|128x10x10
    self.out = nn.Linear(128*10*10,10)  # First output head: 10 digit classes
    self.out1 = nn.Linear(20,40)  # Input is 10 digit logits + 10 one-hot entries
    self.out2 = nn.Linear(40,20) # Second output head: 20 logits for the sum

  def forward(self,t,y):
    """Run the network on a batch.

    Args:
      t: Image batch of shape (batch, 1, 28, 28).
      y: One-hot encoding of the second number, reshapeable to (batch, 10).

    Returns:
      Tuple ``(x, x2)``: digit logits of shape (batch, 10) and sum logits of
      shape (batch, 20).
    """
    # conv1 layer
    x = F.relu(self.conv1(t))
    # conv2 layer followed by 2x2 max-pooling
    x = F.relu(self.conv2(x))
    x = F.max_pool2d(x,kernel_size=2,stride=2)
    # conv3 layer
    x = F.relu(self.conv3(x))

    # Take the batch size from the input instead of assuming 64, so partial
    # final batches and other batch sizes work as well.
    batch_size = x.size(0)
    x = self.out(x.reshape(batch_size,-1))

    # Flatten the one-hot input and cast it to the activation dtype before
    # concatenating it with the digit logits.
    y = y.reshape(batch_size,-1).to(x.dtype)
    # Merging first output layer with second input
    comb = torch.cat([x,y],dim = 1)
    x1 = self.out1(comb)
    x2 = self.out2(x1)
    return x,x2
   

In [None]:
def get_num_correct(preds,labels):
  """Count how many rows of ``preds`` have their arg-max equal to the label.

  Args:
    preds: 2-D tensor of per-class scores, one row per sample.
    labels: 1-D tensor of target class indices.

  Returns:
    Number of matching samples as a Python number.
  """
  predicted_classes = preds.argmax(dim=1)
  hits = predicted_classes.eq(labels)
  return hits.sum().item()

In [None]:
import torch.optim as optim

# Train the two-headed network for 5 epochs on the combined dataset.
# The loader is rebuilt here so this cell does not depend on the earlier
# loader cell having been run.
train_loader = torch.utils.data.DataLoader(combined_dataset, batch_size=64, shuffle=True)
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
model = Network().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)

model.train()
for t in range(0,5):
  # Accumulate the loss as a Python float so no autograd history is kept
  # alive across iterations.
  total_loss = 0.0
  total_correct1 = 0
  total_correct2 = 0
  for batch in train_loader:
    images,labels,num = batch
    # Move all input tensors to the training device
    images,labels,num = images.to(device),labels.to(device),num.to(device)
    random_num = num
    # One-hot encode the second number (values 0..9) for the network
    num = F.one_hot(num.to(torch.int64), num_classes = 10)
    # The network takes the image and the encoded number and returns two
    # outputs: logits for the digit on the image, and logits for the sum of
    # that digit and the second number.
    pred1,pred2 = model(images,num)
    # Loss for the digit prediction
    loss1 = F.cross_entropy(pred1,labels)
    # Target for the second head: digit label + second number, as class
    # indices. Only the trailing size-1 dimension of the extra input is
    # dropped, so the batch dimension is always kept.
    label2 = (labels + random_num.squeeze(1)).long()
    # Cross-entropy loss for the sum prediction
    loss2 = F.cross_entropy(pred2,label2)
    # Total loss is the sum of both heads' losses
    loss = loss1+loss2
    # Zero the gradients from the previous step
    optimizer.zero_grad()
    # Compute gradients of the loss with respect to the model's parameters
    loss.backward()
    # Update the model's parameters from these gradients
    optimizer.step()
    total_loss += loss.item()
    # Count correct predictions for both heads
    total_correct1 += get_num_correct(pred1,labels)
    total_correct2 += get_num_correct(pred2,label2)
  print("epoch : " , t , "total loss : ", total_loss, " correct1 : ", total_correct1, "correct2 : ",total_correct2)

epoch :  0 total loss :  10401.4091796875  correct1 :  595649 correct2 :  372379
epoch :  1 total loss :  5014.6826171875  correct1 :  599665 correct2 :  501401
epoch :  2 total loss :  3386.888427734375  correct1 :  599936 correct2 :  533625
epoch :  3 total loss :  2402.40478515625  correct1 :  599984 correct2 :  552877
epoch :  4 total loss :  1811.1605224609375  correct1 :  599994 correct2 :  564475
