In [1]:
import torch
import numpy as np
from torch import nn
from torch.nn import functional as F
from torch.nn import CrossEntropyLoss
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import TensorDataset
from torch.optim import Adam
from torch.optim.lr_scheduler import OneCycleLR
from torchmetrics.classification import MultilabelAccuracy
from sklearn.model_selection import train_test_split

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"

## Data Preparation

In [3]:
X_train = np.load("data/X_train.npy")
X_val = np.load("data/X_val.npy")


y_train = np.load("data/Y_train.npy")
y_val = np.load("data/Y_val.npy")

In [4]:
X_train_tensor = torch.tensor(X_train, dtype=torch.float32) / 255
X_val_tensor = torch.tensor(X_val, dtype=torch.float32) / 255
y_train_tensor = torch.tensor(y_train, dtype=torch.int64)
y_val_tensor = torch.tensor(y_val, dtype=torch.int64)

In [5]:
train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
val_dataset = TensorDataset(X_val_tensor, y_val_tensor)

## Model Building

This problem can be framed as a multilabel classification because one image has two labels.

So unlike one digit MNIST classification, output layer should have 20 units.

CNN-based architecture is used.

3 consecutive conv blocks, each with 1 convolutional layer and 1 max pooling layer, are used for feature extraction.

As image size is 64 $\times$ 64, 3 conv layers are enough.

Classifier consists of 3 dense layers.

In [6]:
class ClassificationNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = self.conv_module(1, 32)  # 32
        self.conv2 = self.conv_module(32, 64)  # 16
        self.conv3 = self.conv_module(64, 128)  # 8

        self.classifier = self.classification_module(
            8 * 8 * 128, 10 * 2
        )  # Multi-label

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)

        y = self.classifier(x)

        return y

    @staticmethod
    def conv_module(in_channels, out_channels):
        module = nn.Sequential(
            # nn.Conv2d(in_channels, in_channels, 3, 1, 1),
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )

        return module

    @staticmethod
    def classification_module(in_features, num_class):
        classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features, 2048),
            nn.ReLU(),
            nn.BatchNorm1d(2048),
            nn.Linear(2048, 256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Linear(256, num_class),
            nn.Softmax(dim=1),
        )

        return classifier

In [7]:
model = ClassificationNet().to(device)

In [8]:
train_loader = DataLoader(
    train_dataset,
    batch_size=128,
    shuffle=True,
    pin_memory=True,
)
val_loader = DataLoader(val_dataset, batch_size=128, pin_memory=True)

In [9]:
epochs = 50
optimizer = Adam(model.parameters(), lr=1e-2)
criterion = CrossEntropyLoss()
lr_scheduler = OneCycleLR(optimizer, max_lr=1e-2, epochs=epochs, steps_per_epoch=15)
metric = MultilabelAccuracy(num_labels=10 * 2).to(device)

In [10]:
for epoch in range(epochs):
    model.train()
    train_loss = 0

    for data, labels in train_loader:
        data = data.to(device).unsqueeze(1)
        labels = labels.to(device)

        first_digit_onehot = F.one_hot(labels[:, 0], num_classes=10)
        second_digit_onehot = F.one_hot(labels[:, 1], num_classes=10)
        one_hot_labels = torch.cat(
            (first_digit_onehot, second_digit_onehot), dim=1
        ).to(torch.float32)

        optimizer.zero_grad()

        out = model(data)

        loss = criterion(out, one_hot_labels)

        train_loss += loss.item()

        loss.backward()
        optimizer.step()

    if epoch > 0:
        print("-" * 20)
    print(f"Epoch {epoch+1}")
    print(f"Train Loss: {train_loss / len(train_loader)}")

    model.eval()
    val_loss = 0

    with torch.no_grad():

        for data, labels in val_loader:
            data = data.to(device).unsqueeze(1)
            labels = labels.to(device)

            first_digit_onehot = F.one_hot(labels[:, 0], num_classes=10)
            second_digit_onehot = F.one_hot(labels[:, 1], num_classes=10)
            one_hot_labels = torch.cat(
                (first_digit_onehot, second_digit_onehot), dim=1
            ).to(torch.float32)

            out = model(data)
            loss = criterion(out, one_hot_labels)

            metric.update(out, one_hot_labels)
            val_loss += loss.item()

        print(f"Val Loss: {val_loss / len(val_loader)}")

        print(f"Accuracy on all validation data: {metric.compute().item()}")

Epoch 1
Train Loss: 5.613500724585292
Val Loss: 5.490601503396336
Accuracy on all validation data: 0.9180700182914734
--------------------
Epoch 2
Train Loss: 5.301175937104149
Val Loss: 5.28937715216528
Accuracy on all validation data: 0.9276624917984009
--------------------
Epoch 3
Train Loss: 5.233660463327036
Val Loss: 5.292125732083864
Accuracy on all validation data: 0.9308183193206787
--------------------
Epoch 4
Train Loss: 5.212146262391307
Val Loss: 5.2530099410045
Accuracy on all validation data: 0.9332712888717651
--------------------
Epoch 5
Train Loss: 5.198953593500887
Val Loss: 5.246218639084056
Accuracy on all validation data: 0.9348629713058472
--------------------
Epoch 6
Train Loss: 5.192322243516818
Val Loss: 5.218808807904208
Accuracy on all validation data: 0.9363750219345093
--------------------
Epoch 7
Train Loss: 5.185159026624295
Val Loss: 5.213337186016614
Accuracy on all validation data: 0.9375128746032715
--------------------
Epoch 8
Train Loss: 5.17552690

In [11]:
torch.save(model.state_dict(), "best.pt")