In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torchvision.datasets import MNIST
from torchvision import transforms
import torch.optim as optim

In [9]:
cuda = torch.cuda.is_available()
print("CUDA Available:", cuda)

device = torch.device("cuda" if cuda else "cpu")

CUDA Available: True
CUDA Available: True


In [10]:
def encoding(x):
    binary = bin(int(x))[2:]
    result = list(map(float, list(str(0) * (5 - len(binary)) + str(binary))))
    result = torch.tensor(result,device=device).unsqueeze(dim=0)
    return result

In [11]:
class MyDataset(Dataset):
    def __init__(self, root="./data", to_train=True, transform=None):
        super().__init__()

        if not transform:
            transform = transforms.Compose([transforms.ToTensor()])

        self.mnist_data = MNIST(
            root=root,
            train=to_train,
            transform=transform,
            download=True)

    def __getitem__(self, index_to_fetch):
        image, target = self.mnist_data.__getitem__(index_to_fetch)
        random_num = torch.randint(0, 9, (1,), dtype=torch.float32, device=device)

        label_plus_random_num_encoded = encoding(target + random_num)

        return image, target, random_num, label_plus_random_num_encoded

    def __len__(self):
        return len(self.mnist_data)

In [12]:
class Network(nn.Module):
    @staticmethod
    def _get_conv_layer(in_features, out_features, kernel_size=3):
        return nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, bias=False)

    def __init__(self):
        super().__init__()
        self.conv1 = Network._get_conv_layer(in_features=1, out_features=32)
        self.conv2 = Network._get_conv_layer(in_features=32, out_features=32)
        self.conv3 = Network._get_conv_layer(in_features=32, out_features=64)
        self.conv4 = Network._get_conv_layer(in_features=64, out_features=64)

        self.conv_fc1 = nn.Linear(20 * 20 * 64, 256, bias=False)
        # self.conv_fc2 = nn.Linear(256, 256, bias=False)

        self.fc1 = nn.Linear(in_features=1, out_features=16, bias=False)
        self.fc2 = nn.Linear(16, 32, bias=False)
        self.fc3 = nn.Linear(32, 64, bias=False)

        self.fc_combined_1 = nn.Linear(256 + 64, 256)
        self.fc_combined_2 = nn.Linear(256, 256)
        self.output1_fc = nn.Linear(256, 10)
        self.output2_fc = nn.Linear(256, 5)

    def forward(self, x, y):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = x.view(x.size(0), -1)
        x = self.conv_fc1(F.relu(x))
        # x = self.conv_fc2(F.relu(x))

        y = F.relu(self.fc1(y))
        y = F.relu(self.fc2(y))
        y = F.relu(self.fc3(y))

        z = torch.concat([x, y], dim=1)
        z = F.relu(self.fc_combined_1(z))
        z = F.relu(self.fc_combined_2(z))

        out1 = self.output1_fc(z)
        out2 = self.output2_fc(z)

        return F.softmax(out1, dim=1), out2

In [13]:
def train(train_set, network):
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=100)
    optimizer = optim.Adam(network.parameters(), lr=0.01)
    criterion1 = F.cross_entropy
    criterion2 = nn.BCEWithLogitsLoss()
    # criterion2 = nn.MultiLabelSoftMarginLoss()

    for epoch in range(30):

        total_loss = 0
        total_correct = 0

        for batch in train_loader:  # Get Batch
            images, labels, random_nums, labels_plus_random_nums_encoded = batch

            images = images.to(device)
            labels = labels.to(device)
            labels_plus_random_nums_encoded.to(device)

            mnist_preds, number_preds = network(images, random_nums)  # Pass Batch

            # labels_plus_rand_nums = labels + random_nums.squeeze(dim=1)
            # labels_plus_rand_nums = encode(labels_plus_rand_nums)

            labels_plus_random_nums_encoded = labels_plus_random_nums_encoded.squeeze(dim=1)

            loss1 = criterion1(mnist_preds, labels)  # Calculate loss 1
            loss2 = criterion2(number_preds, labels_plus_random_nums_encoded)  # Calculate loss 2

            loss = loss1 + loss2  # combine both the loss values

            optimizer.zero_grad()  # clear previous gradients

            loss.backward()  # Calculate Gradients

            optimizer.step()  # Update Weights

            total_loss += loss.item()

            # check correct predictions for MNIST
            mnist_correct = mnist_preds.argmax(dim=1).eq(labels).sum().item()
            total_correct += mnist_correct

            # check correct predictions for number outputs
            numbers_preds_correct = number_preds.eq(labels_plus_random_nums_encoded).sum().item()
            total_correct += numbers_preds_correct

        print(
            "epoch", epoch,
            "total_correct:", total_correct,
            "total_loss:", total_loss,
            "mnist_preds_correct:", mnist_correct,
            "numbers_preds_correct:", numbers_preds_correct
        )


In [14]:
model = Network()
model = model.to(device)
train(MyDataset(), model)

epoch 0 total_correct: 30182 total_loss: 1482.0938935279846 mnist_preds_correct: 68 numbers_preds_correct: 0
epoch 1 total_correct: 47383 total_loss: 1221.2797288894653 mnist_preds_correct: 84 numbers_preds_correct: 0
epoch 2 total_correct: 50975 total_loss: 1128.7114737033844 mnist_preds_correct: 82 numbers_preds_correct: 0
epoch 3 total_correct: 55423 total_loss: 1056.1598411798477 mnist_preds_correct: 97 numbers_preds_correct: 0
epoch 4 total_correct: 56521 total_loss: 1022.8202037811279 mnist_preds_correct: 92 numbers_preds_correct: 0
epoch 5 total_correct: 56772 total_loss: 1006.7159051895142 mnist_preds_correct: 95 numbers_preds_correct: 0
epoch 6 total_correct: 56993 total_loss: 995.5969506502151 mnist_preds_correct: 97 numbers_preds_correct: 0
epoch 7 total_correct: 57130 total_loss: 976.068575501442 mnist_preds_correct: 98 numbers_preds_correct: 0
epoch 8 total_correct: 54523 total_loss: 1045.0489032268524 mnist_preds_correct: 95 numbers_preds_correct: 0
epoch 9 total_correct: