In [100]:
import torch
import torchvision
import torchvision.transforms as transforms

In [101]:
mnist_train_set = torchvision.datasets.FashionMNIST(
    root = './data',
    train = True,
    download = True,
    transform=transforms.Compose([
        transforms.ToTensor()
    ])
)

In [102]:
from torch.utils.data import Dataset
import torch.nn.functional as F


class MnistWithRandomNumberDataset(Dataset):
  def __init__(self, mnist_data, random_nums):
    self.mnist_data = mnist_data
    self.random_nums = random_nums

  def __len__(self):
    return len(self.mnist_data)

  def __getitem__(self, loc):
    img, label = self.mnist_data[loc]
    random_num = self.random_nums[loc]
    sum = label + random_num
    # print(label, random_num, sum)
    # random_num_one_hot = F.one_hot(torch.tensor(random_num), num_classes=10)
    # Convert the tensor to dtype torch.float32
    # random_num_one_hot = random_num_one_hot.to(dtype=torch.float64)
    # print(random_num_one_hot)
    # sum_one_hot = F.one_hot(torch.tensor(sum), num_classes=19)
    return img, label, random_num, sum

In [103]:
from torch.utils.data import DataLoader
import random
random.seed(23)

# Load MNIST data and random numbers
random_nums = [random.randint(0, 9) for i in range(len(mnist_train_set))]
dataset = MnistWithRandomNumberDataset(mnist_train_set, random_nums)
train_loader = DataLoader(dataset, batch_size = 64, shuffle = True)

In [104]:
import torch.optim as optim
torch.set_grad_enabled(True)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f626c6cca90>

In [105]:
def get_num_correct(preds, labels):
  return preds.argmax(dim=1).eq(labels).sum().item()

In [106]:
# Get the first batch
batch = next(iter(train_loader))

# Extract the data and label
images, true_labels, random_nums, true_sums = batch

In [139]:
import torch
import torch.nn as nn

class Network(nn.Module):
  def __init__(self):
    super(Network, self).__init__()
    
    # Convolutional layers to process the image
    self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
    self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
    self.conv3 = nn.Conv2d(64, 128, kernel_size=3)
    self.conv4 = nn.Conv2d(128, 256, kernel_size=3)
    
    
    # Fully connected layers to process the image
    self.fc1 = nn.Linear(in_features=128, out_features=120)
    self.fc2 = nn.Linear(in_features=120, out_features=60)
    self.out = nn.Linear(in_features=60, out_features=10)
    
  def forward(self, x, r):
    # print("x shape1: ", x.shape)
    x = self.conv1(x)
    # print("x shape2: ", x.shape)
    x = F.relu(x)
    x = F.max_pool2d(x, kernel_size=2, stride=2)

    # print("x shape3: ", x.shape)
    x = self.conv2(x)
    # print("x shape4: ", x.shape)
    x = F.relu(x)
    x = F.max_pool2d(x, kernel_size=2, stride=2)


    x = self.conv3(x)
    # print("x shape5: ", x.shape)
    x = F.relu(x)
    x = F.max_pool2d(x, kernel_size=2, stride=2)
    # print("x shape6: ", x.shape)

    x = x.view(-1, 128)
    # print("x shape7: ", x.shape)
    x = self.fc1(x)
    # print("x shape8: ", x.shape)
    x = F.relu(x)
    # print("x shape9: ", x.shape)
    x = self.fc2(x)
    # print("x shape10: ", x.shape)
    x = F.relu(x)
    # print("x shape11: ", x.shape)
    x = self.out(x)
    # print("x shape12: ", x.shape)
    # print("r shape: ", r.unsqueeze(1).shape)

    max_indices_x = torch.argmax(x, dim=1)  # shape: (64,)
    summed_indices = max_indices_x + r  # shape: (64,)
    sum = F.one_hot(summed_indices, num_classes=19) 
    sum = sum.to(dtype=torch.float32)
    sum.requires_grad_()
    # print("one_hot_tensor shape: ", sum.shape)
    # print(sum)
    return x, sum

**building network for single batch**

In [140]:
network = Network()

# train_loader = MnistWithRandomNumberDataset(mnist_train_set, random_nums)
train_loader = DataLoader(dataset, batch_size = 64, shuffle = True)
optimizer = optim.Adam(network.parameters(), lr=0.01)

batch = next(iter(train_loader)) # Get Batch
images, labels, random_nums, sums = batch

label_preds, sum_preds = network(images, random_nums) # Pass Batch
label_loss = F.cross_entropy(label_preds, labels)
sum_loss = F.cross_entropy(sum_preds, sums)
loss = 0.5 * (label_loss + sum_loss) # Calculate Loss
print('loss1:', loss.item())
print('correct1:', get_num_correct(label_preds, labels))
optimizer.zero_grad()
label_loss.backward() # Calculate Gradients
optimizer.step() # Update Weights


label_preds, sum_preds = network(images, random_nums) # Pass Batch
label_loss = F.cross_entropy(label_preds, labels)
sum_loss = F.cross_entropy(sum_preds, sums)
loss = 0.5 * (label_loss + sum_loss) # Calculate Loss
print('loss2:', loss.item())
print('correct2:', get_num_correct(label_preds, labels))

loss1: 2.598670244216919
correct1: 9
loss2: 2.5742971897125244
correct2: 9


**Doing for multiple epochs and batches**

In [None]:
train_loader = DataLoader(dataset, batch_size = 64, shuffle = True)
optimizer = optim.Adam(network.parameters(), lr=0.01)

for epoch in range(3):

    total_loss = 0
    total_correct_label = 0
    total_loss_label = 0
    total_correct_sum = 0
    total_loss_sum = 0

    for batch in train_loader: # Get Batch
        images, labels, random_nums, sums = batch 

        label_preds, sum_preds = network(images, random_nums) # Pass Batch
        label_loss = F.cross_entropy(label_preds, labels)
        sum_loss = F.cross_entropy(sum_preds, sums)
        loss = 0.5 * (label_loss + sum_loss) # Calculate Loss

        optimizer.zero_grad()
        loss.backward() # Calculate Gradients
        optimizer.step() # Update Weights

        total_loss += loss.item()
        total_loss_label += label_loss.item()
        total_loss_sum += sum_loss.item()
        total_correct_label += get_num_correct(label_preds, labels)
        total_correct_sum += get_num_correct(sum_preds, sums)

    print(
        "epoch", epoch, 
        "total_correct_label:", total_correct_label, 
        "total_loss_label:", total_loss_label,
        "total_correct_sum:", total_correct_sum,
        "total_loss_sum:", total_loss_sum,
        "loss:", total_loss
    )

epoch 0 total_correct_label: 45865 total_loss_label: 584.5705933868885 total_correct_sum: 45865 total_loss_sum: 2126.015678882599 loss: 1355.293135046959
