In [1]:
# Import library
import torch
from torch import nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F

In [2]:
torch.manual_seed(0)

<torch._C.Generator at 0x20a3f88e830>

In [3]:
# Load MNIST dataset
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST(root='./mnist_data/', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./mnist_data/', train=False, transform=transform)

# Make sure to have every number

# Select 10% of the data for this experiment
train_data = torch.utils.data.Subset(train_dataset, range(int(len(train_dataset)*0.1)))
test_data = torch.utils.data.Subset(test_dataset, range(int(len(test_dataset)*0.1)))

# Create data loaders
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(test_data, batch_size=64, shuffle=True)

In [4]:
import numpy as np
# check if the train_dataset has each number
# Get all images of each digits
for i in range(10):
    digit_indices = np.where(train_dataset.targets == i)[0]
    print(f"Image of number {i} size is {digit_indices.size}")

Image of number 0 size is 5923
Image of number 1 size is 6742
Image of number 2 size is 5958
Image of number 3 size is 6131
Image of number 4 size is 5842
Image of number 5 size is 5421
Image of number 6 size is 5918
Image of number 7 size is 6265
Image of number 8 size is 5851
Image of number 9 size is 5949


In [5]:
# Define the CNN architecture
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5)
        self.fc1 = nn.Linear(64*4*4, 256)
        self.fc2 = nn.Linear(256, 2)

    def forward_one(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = x.view(x.size()[0], -1)
        x = F.relu(self.fc1(x))
        return x

    def forward(self, input1, input2):
        output1 = self.forward_one(input1)
        output2 = self.forward_one(input2)
        output = self.fc2(torch.abs(output1 - output2))
        return output

In [6]:
# Initialize the network and define the loss and optimizer
model = CNN()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())

# Training loop
for epoch in range(10):
    epoch_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data
        # Get the size of dimension 1 (channels)
        channels = inputs.size(1)
        # Split the inputs into two tensors along dimension 1
        # Get the size of dimension 0 (batch size)
        batch_size = inputs.size(0)
        # Split the inputs into two tensors along dimension 0
        inputs1, inputs2 = torch.split(inputs, batch_size // 2, dim=0)
        batch_size = labels.size(0)

        # Split the labels into two tensors along dimension 0
        labels1, labels2 = torch.split(labels, batch_size // 2, dim=0)

        # Compare the two sets of labels
        labels = (labels1 == labels2).long()
        optimizer.zero_grad()
        
        outputs = model(inputs1, inputs2)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss = loss.item()
    print(f"Epoch {epoch+1}, Loss: {epoch_loss/len(train_loader)}")

Epoch 1, Loss: 0.002386371981590352
Epoch 2, Loss: 0.0012819631461133349
Epoch 3, Loss: 0.0007066637911695115
Epoch 4, Loss: 0.0014740645568421546
Epoch 5, Loss: 0.0020091896678539035
Epoch 6, Loss: 0.001324441283941269
Epoch 7, Loss: 0.00030311329805470525
Epoch 8, Loss: 0.0004126962195051477
Epoch 9, Loss: 6.153716388693515e-05
Epoch 10, Loss: 0.00030202781187093004


In [7]:
# Evaluation mode
model.eval()

correct = 0
total = 0

# No gradient calculation
with torch.no_grad():
    for data in test_loader:
        inputs, labels = data
        # Get the size of dimension 0 (batch size)
        batch_size = inputs.size(0)
        # Split the inputs into two tensors along dimension 0
        inputs1 = inputs[:batch_size // 2]
        inputs2 = inputs[batch_size // 2:]
        labels = (labels[:batch_size // 2] == labels[batch_size // 2:]).long()
        outputs = model(inputs1, inputs2)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    
    accuracy = correct/total*100
print(f'Accuracy of the network on the test images: {accuracy}%')

Accuracy of the network on the test images: 98.0%
