<a href="https://colab.research.google.com/github/SyedMa3/eva8/blob/main/session-2/session_2point5.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import numpy as np
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import Dataset
import torch.optim as optim

In [None]:
class AddDataset(Dataset):
  """Wrap an (image, label) dataset and attach a random integer in [0, 10) to each sample.

  Indexing returns ``((image, label), n)``. ``n`` is drawn from
  ``np.random.randint(10)`` on every access, so reading the same index twice
  can give a different ``n``.
  """

  def __init__(self, mnist_set):
    # Underlying dataset; each item must unpack into (image, label).
    self.data = mnist_set

  def __len__(self):
    return len(self.data)

  def __getitem__(self, index):
    image, label = self.data[index]
    return (image, label), np.random.randint(10)


In [None]:
# MNIST training split; each PIL image is converted to a tensor by ToTensor.
mnist_set = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)

# Pair every MNIST sample with a random second addend for the addition task.
train_set = AddDataset(mnist_set)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



In [None]:
class Network(nn.Module):
  """Two-headed model: classify a digit image, then add the predicted digit to a second integer.

  Image head: conv1 -> conv2 -> fc1 -> fc2 -> fc3, producing 10 digit logits.
  Sum head:   fc4 -> fc5 -> out, fed with (predicted digit, t2) and producing
              19 logits (sums 0..18 when both addends are in 0..9).
  """

  def __init__(self):
    super().__init__()

    # Image head. fc1 expects the conv/pool stack to end at 12 x 5 x 5,
    # which is what a 28x28 single-channel input yields.
    self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=3)
    self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=3)
    self.fc1 = nn.Linear(in_features=12*5*5, out_features=120)
    self.fc2 = nn.Linear(in_features=120, out_features=60)
    self.fc3 = nn.Linear(in_features=60, out_features=10)

    # Sum head: input is the 2-vector (predicted digit, t2).
    self.fc4 = nn.Linear(in_features=2, out_features=30)
    self.fc5 = nn.Linear(in_features=30, out_features=120)
    self.out = nn.Linear(in_features=120, out_features=19)

  def forward(self, t, t2):
    """Return ``(digit_logits, sum_logits)``.

    Args:
      t: image batch of shape (N, 1, H, W) whose conv/pool output is 12x5x5
         (e.g. 28x28). A single unbatched image (1, H, W) is also accepted
         and treated as N=1.
      t2: integer tensor of shape (N,) holding the second addend.

    Returns:
      digit_logits: (N, 10) raw scores for the digit in ``t``.
      sum_logits: (N, 19) raw scores for ``predicted_digit + t2``.
    """
    if t.dim() == 3:
      # Unbatched (C, H, W) input: add the batch dim so flatten(1) below is correct.
      t = t.unsqueeze(0)

    x = F.max_pool2d(F.relu(self.conv1(t)), kernel_size=2, stride=2)
    x = F.max_pool2d(F.relu(self.conv2(x)), kernel_size=2, stride=2)

    # flatten(1) keeps the batch dim intact. reshape(-1, 300) could silently
    # regroup a mis-sized input into a different batch size.
    x = x.flatten(start_dim=1)
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    digit_logits = self.fc3(x)

    # softmax is monotonic, so argmax over logits gives the same class without it.
    # argmax is not differentiable, so the sum head's loss never reaches the image
    # head; the sum head only learns from the predicted-digit value it is given.
    predicted_digit = digit_logits.argmax(dim=1)

    # Cast both addends to float before stacking so differing integer dtypes
    # cannot make torch.stack fail.
    pair = torch.stack((predicted_digit.float(), t2.float()), dim=1)

    h = F.relu(self.fc4(pair))
    h = F.relu(self.fc5(h))
    sum_logits = self.out(h)

    return digit_logits, sum_logits

In [None]:
def get_num_correct(preds, labels):
  """Count rows of ``preds`` whose arg-max column equals the matching entry in ``labels``.

  Returns the count as a Python int.
  """
  predicted_classes = preds.argmax(dim=1)
  matches = predicted_classes == labels
  return matches.sum().item()

In [None]:
# Seed every RNG the notebook draws from (torch: weight init and DataLoader
# shuffling; numpy: AddDataset's random addend) so re-runs are repeatable.
# GPU kernels may still add some nondeterminism.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
network = Network().to(device)
network

Network(
  (conv1): Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(6, 12, kernel_size=(3, 3), stride=(1, 1))
  (fc1): Linear(in_features=300, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=60, bias=True)
  (fc3): Linear(in_features=60, out_features=10, bias=True)
  (fc4): Linear(in_features=2, out_features=30, bias=True)
  (fc5): Linear(in_features=30, out_features=120, bias=True)
  (out): Linear(in_features=120, out_features=19, bias=True)
)

In [None]:
# Hyperparameters for the joint digit-classification + addition training run.
BATCH_SIZE = 100
LEARNING_RATE = 0.01
NUM_EPOCHS = 15  # NOTE(review): the saved output shows only epochs 0-9; re-run to refresh it.

train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=BATCH_SIZE,
    shuffle=True
)

optimiser = optim.Adam(network.parameters(), lr=LEARNING_RATE)

# enable_grad() keeps this cell working even if a cell that called
# torch.set_grad_enabled(False) was run before it (backward() would fail otherwise).
with torch.enable_grad():
  for epoch in range(NUM_EPOCHS):
    total_loss = 0
    total_image_correct = 0
    total_sum_correct = 0

    for (images, labels), ns in train_loader:
      images = images.to(device)
      labels = labels.to(device)
      ns = ns.to(device)
      sums = labels + ns  # target for the addition head

      digit_logits, sum_logits = network(images, ns)

      image_loss = F.cross_entropy(digit_logits, labels)
      add_loss = F.cross_entropy(sum_logits, sums)
      # A single backward over the summed loss accumulates the same gradients as
      # two separate backward() calls, and stays valid if the heads ever share graph nodes.
      loss = image_loss + add_loss

      optimiser.zero_grad()
      loss.backward()
      optimiser.step()

      total_loss += loss.item()
      total_image_correct += get_num_correct(digit_logits, labels)
      total_sum_correct += get_num_correct(sum_logits, sums)

    print(
        "epoch", epoch,
        "loss", total_loss,
        "image_correct", total_image_correct,
        "sum_correct", total_sum_correct
    )


epoch 0 loss 1409.4939019382 image_correct 59341 sum_correct 40764
epoch 1 loss 290.56004270911217 image_correct 59408 sum_correct 55887
epoch 2 loss 204.38898973912 image_correct 59439 sum_correct 58158
epoch 3 loss 196.818211004138 image_correct 59388 sum_correct 58663
epoch 4 loss 164.41627035290003 image_correct 59421 sum_correct 59073
epoch 5 loss 153.78037232533097 image_correct 59491 sum_correct 59160
epoch 6 loss 135.7962089497596 image_correct 59471 sum_correct 59099
epoch 7 loss 133.7856511529535 image_correct 59465 sum_correct 59272
epoch 8 loss 99.13924493454397 image_correct 59596 sum_correct 59476
epoch 9 loss 117.81851453334093 image_correct 59473 sum_correct 59087


In [None]:
# Sanity-check the trained network on the first training sample.
# torch.no_grad() scopes gradient tracking to this block instead of flipping
# global state that every later cell would inherit.
(image, label), n = train_set[0]

with torch.no_grad():
  digit_logits, sum_logits = network(
      image.unsqueeze(0).to(device),  # add a batch dim of size 1
      torch.tensor([n]).to(device),
  )

print(
    "predicted digit:", digit_logits.argmax(dim=1).item(),
    "predicted sum:", sum_logits.argmax(dim=1).item(),
    "| true digit:", label, "n:", n, "true sum:", label + n,
)

(tensor([[-11.6144, -27.3194, -15.9017,   0.7179, -29.3899,  15.2625, -22.8032,
         -14.0756,  -6.7766,  -3.3354]], device='cuda:0'), tensor([[  -0.6735,    4.4018,    4.7424,    4.9949,    6.1105,    5.9598,
            3.4933,    7.3932,   12.9197,    9.6122,    7.2826,    4.9864,
            2.0107,    4.3197,    3.9219,  -19.0030,  -57.6831, -109.1557,
         -187.5523]], device='cuda:0')) tt: 5 3


In [None]:
# Make sure autograd is switched on globally, e.g. after any cell that called torch.set_grad_enabled(False).
torch.autograd.set_grad_enabled(True)


<torch.autograd.grad_mode.set_grad_enabled at 0x7ff4538ba0d0>