In [29]:
import os
import re
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import KFold
from torch.utils.data import SubsetRandomSampler
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

if torch.cuda.is_available():
    device = torch.device("cuda:0")
    print(f'Using device: {device} - {torch.cuda.get_device_name(0)}')
else:
    raise Exception("GPU is not available. Please run this notebook on a system with a GPU.")

Using device: cuda:0 - NVIDIA GeForce RTX 3080


In [30]:
class FingerprintDataset(Dataset):
    def __init__(self, real_dir, altered_dir, transform=None, limit=None):
        self.real_dir = real_dir
        self.altered_dir = altered_dir
        self.transform = transform
        self.limit = limit

        real_images = os.listdir(real_dir)[:limit] if limit else os.listdir(real_dir)
        altered_images = os.listdir(altered_dir)

        self.image_dict = {real_img: self.find_altered(real_img, altered_images) for real_img in real_images}
        self.image_dict = {k: v for k, v in self.image_dict.items() if v is not None}

        print(f'Length of image_dict: {len(self.image_dict)}')

    def find_altered(self, real_img, altered_images):
        real_img_prefix = re.match(r'\d+', real_img).group() if re.match(r'\d+', real_img) else None
        for file in altered_images:
            if file.startswith(real_img_prefix):
                return file
        return None

    def __len__(self):
        return len(self.image_dict)

    def __getitem__(self, idx):
        if idx >= len(self.image_dict):
            raise IndexError('Index out of range')

        real_img, altered_img = list(self.image_dict.items())[idx]

        real_image = Image.open(os.path.join(self.real_dir, real_img))
        altered_image = Image.open(os.path.join(self.altered_dir, altered_img))

        if self.transform:
            real_image = self.transform(real_image)
            altered_image = self.transform(altered_image)

        return real_image, altered_image, 1, real_img, altered_img

In [31]:
transform = transforms.Compose([
    transforms.Grayscale(), 
    transforms.ToTensor(), 
    transforms.Normalize((0.5,), (0.5,))  # Normalize pixel values to [-1, 1]
])

In [32]:
real_dir = 'dataset/Real'
altered_dir = 'dataset/Altered/Altered-Easy'

dataset = FingerprintDataset(real_dir, altered_dir, transform=transform)

n_splits = 5

kfold = KFold(n_splits=n_splits, shuffle=True)

n_samples = len(dataset)

for fold, (train_indices, val_indices) in enumerate(kfold.split(range(n_samples))):
    train_sampler = SubsetRandomSampler(train_indices)
    val_sampler = SubsetRandomSampler(val_indices)

    train_loader = DataLoader(dataset, batch_size=16, sampler=train_sampler, num_workers=6)
    val_loader = DataLoader(dataset, batch_size=16, sampler=val_sampler, num_workers=6)


Length of image_dict: 5956


In [33]:
class SiameseNetwork(nn.Module):
    def __init__(self):
        super(SiameseNetwork, self).__init__()
        
        self.feature_extractor = models.resnet18(pretrained=True)
        
        self.feature_extractor.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        
        self.feature_extractor.fc = nn.Identity()
        
        # Define the final layer that computes the absolute difference between the two outputs
        self.final_layer = nn.Linear(512, 1)

    def forward(self, input1, input2):
        # Ensure the input has at least 4 dimensions
        if len(input1.shape) < 4:
            input1 = input1.unsqueeze(1)
        if len(input2.shape) < 4:
            input2 = input2.unsqueeze(1)

        # Ensure the input has the correct dimensions
        input1 = input1.view(input1.shape[0], -1, input1.shape[2], input1.shape[3])
        input2 = input2.view(input2.shape[0], -1, input2.shape[2], input2.shape[3])

        # Pass both inputs through the feature extractor
        output1 = self.feature_extractor(input1)
        output2 = self.feature_extractor(input2)
        
        # Compute the absolute difference between the two outputs
        diff = torch.abs(output1 - output2)
        
        # Pass the difference through the final layer to get the output
        output = self.final_layer(diff)
        
        return output

In [34]:
model = SiameseNetwork().to(device)

criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)



In [35]:
num_epochs = 1

for fold, (train_indices, val_indices) in enumerate(kfold.split(range(n_samples))):
    train_sampler = SubsetRandomSampler(train_indices)
    val_sampler = SubsetRandomSampler(val_indices)

    train_loader = DataLoader(dataset, batch_size=16, sampler=train_sampler, num_workers=6)
    val_loader = DataLoader(dataset, batch_size=16, sampler=val_sampler, num_workers=6)

    # Reset the model and optimizer
    model = SiameseNetwork().to(device)
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters())

    # Training loop
    for epoch in range(num_epochs):
        running_loss = 0.0
        for i, (inputs1, inputs2, labels, real_img, altered_img) in enumerate(train_loader):
            inputs1 = inputs1.to(device)
            inputs2 = inputs2.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            outputs = model(inputs1, inputs2)
            outputs = outputs.squeeze()  # Remove the extra dimension

            labels = labels.float().squeeze()
            loss = criterion(outputs, labels)

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / (i+1)))

        val_loss = 0.0
        model.eval()  # Set the model to evaluation mode
        val_loss = 0.0
        val_preds = []
        val_labels_list = []
        with torch.no_grad():  # Disable gradient computation
            for i, (inputs1, inputs2, labels, real_img, altered_img) in enumerate(val_loader):
                inputs1 = inputs1.to(device)
                inputs2 = inputs2.to(device)
                labels = labels.to(device)

                outputs = model(inputs1, inputs2)
                outputs = outputs.squeeze()  # Remove the extra dimension

                labels = labels.float().squeeze()
                loss = criterion(outputs, labels)

                val_loss += loss.item()

                val_preds.extend(outputs.detach().cpu().numpy())
                val_labels_list.extend(labels.detach().cpu().numpy())

            # if i % 10 == 0:
            #     print('Validation loss: %.3f' % (val_loss / len(val_loader)))

        val_preds = np.array(val_preds) > 0.5  # Convert to binary predictions
        val_labels_list = np.array(val_labels_list)
        accuracy = accuracy_score(val_labels_list, val_preds)
        precision = precision_score(val_labels_list, val_preds)
        recall = recall_score(val_labels_list, val_preds)
        f1 = f1_score(val_labels_list, val_preds)

        print('Validation loss: %.3f' % (val_loss / len(val_loader)))
        print('Validation Accuracy: %.3f' % accuracy)
        print('Validation Precision: %.3f' % precision)
        print('Validation Recall: %.3f' % recall)
        print('Validation F1 Score: %.3f' % f1)

        model.train()  # Set the model back to training mode

        print('Finished Training for Fold %d' % (fold + 1))
    print('Finished Training for Fold %d' % (fold + 1))

[1,     1] loss: 0.937
[1,     2] loss: 0.732
[1,     3] loss: 0.569
[1,     4] loss: 0.465
[1,     5] loss: 0.398
[1,     6] loss: 0.336
[1,     7] loss: 0.295
[1,     8] loss: 0.270
[1,     9] loss: 0.243
[1,    10] loss: 0.220
[1,    11] loss: 0.204
[1,    12] loss: 0.188
[1,    13] loss: 0.176
[1,    14] loss: 0.164
[1,    15] loss: 0.154
[1,    16] loss: 0.145
[1,    17] loss: 0.138
[1,    18] loss: 0.131
[1,    19] loss: 0.126
[1,    20] loss: 0.120
[1,    21] loss: 0.115
[1,    22] loss: 0.111
[1,    23] loss: 0.106
[1,    24] loss: 0.102
[1,    25] loss: 0.099
[1,    26] loss: 0.095
[1,    27] loss: 0.092
[1,    28] loss: 0.089
[1,    29] loss: 0.086
[1,    30] loss: 0.083
[1,    31] loss: 0.081
[1,    32] loss: 0.079
[1,    33] loss: 0.077
[1,    34] loss: 0.075
[1,    35] loss: 0.073
[1,    36] loss: 0.071
[1,    37] loss: 0.069
[1,    38] loss: 0.067
[1,    39] loss: 0.066
[1,    40] loss: 0.064
[1,    41] loss: 0.063
[1,    42] loss: 0.061
[1,    43] loss: 0.060
[1,    44] 



[1,     1] loss: 0.831
[1,     2] loss: 0.677
[1,     3] loss: 0.561
[1,     4] loss: 0.472
[1,     5] loss: 0.392
[1,     6] loss: 0.338
[1,     7] loss: 0.295
[1,     8] loss: 0.263
[1,     9] loss: 0.238
[1,    10] loss: 0.218
[1,    11] loss: 0.203
[1,    12] loss: 0.189
[1,    13] loss: 0.174
[1,    14] loss: 0.163
[1,    15] loss: 0.155
[1,    16] loss: 0.146
[1,    17] loss: 0.139
[1,    18] loss: 0.132
[1,    19] loss: 0.126
[1,    20] loss: 0.120
[1,    21] loss: 0.114
[1,    22] loss: 0.110
[1,    23] loss: 0.106
[1,    24] loss: 0.103
[1,    25] loss: 0.099
[1,    26] loss: 0.095
[1,    27] loss: 0.092
[1,    28] loss: 0.090
[1,    29] loss: 0.087
[1,    30] loss: 0.084
[1,    31] loss: 0.081
[1,    32] loss: 0.079
[1,    33] loss: 0.077
[1,    34] loss: 0.075
[1,    35] loss: 0.073
[1,    36] loss: 0.071
[1,    37] loss: 0.069
[1,    38] loss: 0.067
[1,    39] loss: 0.066
[1,    40] loss: 0.065
[1,    41] loss: 0.063
[1,    42] loss: 0.062
[1,    43] loss: 0.060
[1,    44] 



[1,     1] loss: 1.418
[1,     2] loss: 1.101
[1,     3] loss: 0.903
[1,     4] loss: 0.739
[1,     5] loss: 0.613
[1,     6] loss: 0.526
[1,     7] loss: 0.459
[1,     8] loss: 0.412
[1,     9] loss: 0.374
[1,    10] loss: 0.340
[1,    11] loss: 0.311
[1,    12] loss: 0.286
[1,    13] loss: 0.266
[1,    14] loss: 0.248
[1,    15] loss: 0.233
[1,    16] loss: 0.220
[1,    17] loss: 0.208
[1,    18] loss: 0.198
[1,    19] loss: 0.189
[1,    20] loss: 0.180
[1,    21] loss: 0.172
[1,    22] loss: 0.165
[1,    23] loss: 0.158
[1,    24] loss: 0.152
[1,    25] loss: 0.146
[1,    26] loss: 0.141
[1,    27] loss: 0.136
[1,    28] loss: 0.132
[1,    29] loss: 0.128
[1,    30] loss: 0.124
[1,    31] loss: 0.120
[1,    32] loss: 0.116
[1,    33] loss: 0.113
[1,    34] loss: 0.110
[1,    35] loss: 0.108
[1,    36] loss: 0.105
[1,    37] loss: 0.103
[1,    38] loss: 0.100
[1,    39] loss: 0.098
[1,    40] loss: 0.096
[1,    41] loss: 0.093
[1,    42] loss: 0.091
[1,    43] loss: 0.089
[1,    44] 



[1,     1] loss: 0.725
[1,     2] loss: 0.607
[1,     3] loss: 0.520
[1,     4] loss: 0.432
[1,     5] loss: 0.368
[1,     6] loss: 0.315
[1,     7] loss: 0.278
[1,     8] loss: 0.250
[1,     9] loss: 0.225
[1,    10] loss: 0.204
[1,    11] loss: 0.186
[1,    12] loss: 0.173
[1,    13] loss: 0.160
[1,    14] loss: 0.150
[1,    15] loss: 0.141
[1,    16] loss: 0.134
[1,    17] loss: 0.127
[1,    18] loss: 0.120
[1,    19] loss: 0.115
[1,    20] loss: 0.110
[1,    21] loss: 0.105
[1,    22] loss: 0.101
[1,    23] loss: 0.097
[1,    24] loss: 0.093
[1,    25] loss: 0.090
[1,    26] loss: 0.087
[1,    27] loss: 0.084
[1,    28] loss: 0.081
[1,    29] loss: 0.078
[1,    30] loss: 0.076
[1,    31] loss: 0.073
[1,    32] loss: 0.071
[1,    33] loss: 0.069
[1,    34] loss: 0.068
[1,    35] loss: 0.066
[1,    36] loss: 0.064
[1,    37] loss: 0.063
[1,    38] loss: 0.061
[1,    39] loss: 0.060
[1,    40] loss: 0.058
[1,    41] loss: 0.057
[1,    42] loss: 0.056
[1,    43] loss: 0.055
[1,    44] 



[1,     1] loss: 0.379
[1,     2] loss: 0.345
[1,     3] loss: 0.281
[1,     4] loss: 0.242
[1,     5] loss: 0.202
[1,     6] loss: 0.177
[1,     7] loss: 0.157
[1,     8] loss: 0.142
[1,     9] loss: 0.129
[1,    10] loss: 0.117
[1,    11] loss: 0.108
[1,    12] loss: 0.101
[1,    13] loss: 0.094
[1,    14] loss: 0.089
[1,    15] loss: 0.083
[1,    16] loss: 0.078
[1,    17] loss: 0.074
[1,    18] loss: 0.071
[1,    19] loss: 0.067
[1,    20] loss: 0.065
[1,    21] loss: 0.062
[1,    22] loss: 0.060
[1,    23] loss: 0.058
[1,    24] loss: 0.055
[1,    25] loss: 0.053
[1,    26] loss: 0.051
[1,    27] loss: 0.050
[1,    28] loss: 0.048
[1,    29] loss: 0.047
[1,    30] loss: 0.045
[1,    31] loss: 0.044
[1,    32] loss: 0.043
[1,    33] loss: 0.041
[1,    34] loss: 0.041
[1,    35] loss: 0.040
[1,    36] loss: 0.038
[1,    37] loss: 0.038
[1,    38] loss: 0.037
[1,    39] loss: 0.036
[1,    40] loss: 0.035
[1,    41] loss: 0.034
[1,    42] loss: 0.034
[1,    43] loss: 0.033
[1,    44] 