In [171]:
import os
import re
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
import torch.nn as nn
import torch.optim as optim

if torch.cuda.is_available():
    device = torch.device("cuda:0")
    print(f'Using device: {device} - {torch.cuda.get_device_name(0)}')
else:
    raise Exception("GPU is not available. Please run this notebook on a system with a GPU.")

Using device: cuda:0 - NVIDIA GeForce RTX 3080


In [172]:
class FingerprintDataset(Dataset):
    def __init__(self, real_dir, altered_dir, transform=None, limit=None):
        self.real_dir = real_dir
        self.altered_dir = altered_dir
        self.transform = transform
        self.limit = limit

        real_images = os.listdir(real_dir)[:limit] if limit else os.listdir(real_dir)
        altered_images = os.listdir(altered_dir)

        self.image_dict = {real_img: self.find_altered(real_img, altered_images) for real_img in real_images}
        self.image_dict = {k: v for k, v in self.image_dict.items() if v is not None}

        print(f'Length of image_dict: {len(self.image_dict)}')

    def find_altered(self, real_img, altered_images):
        real_img_prefix = re.match(r'\d+', real_img).group() if re.match(r'\d+', real_img) else None
        for file in altered_images:
            if file.startswith(real_img_prefix):
                return file
        return None

    def __len__(self):
        return len(self.image_dict)

    def __getitem__(self, idx):
        if idx >= len(self.image_dict):
            raise IndexError('Index out of range')

        real_img, altered_img = list(self.image_dict.items())[idx]

        real_image = Image.open(os.path.join(self.real_dir, real_img))
        altered_image = Image.open(os.path.join(self.altered_dir, altered_img))

        if self.transform:
            real_image = self.transform(real_image)
            altered_image = self.transform(altered_image)

        return real_image, altered_image, 1, real_img, altered_img

In [173]:
transform = transforms.Compose([
    transforms.Grayscale(), 
    transforms.ToTensor(), 
    transforms.Normalize((0.5,), (0.5,))  # Normalize pixel values to [-1, 1]
])

In [174]:
real_dir = 'dataset/Real'
altered_dir = 'dataset/Altered/Altered-Easy'

dataset = FingerprintDataset(real_dir, altered_dir, transform=transform)

#TODO SET NUMBER OF THREADS ACCORDINGLY

dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=6)

Length of image_dict: 5956


In [175]:
class SiameseNetwork(nn.Module):
    def __init__(self):
        super(SiameseNetwork, self).__init__()
        
        self.feature_extractor = models.resnet18(pretrained=True)
        
        self.feature_extractor.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        
        self.feature_extractor.fc = nn.Identity()
        
        # Define the final layer that computes the absolute difference between the two outputs
        self.final_layer = nn.Linear(512, 1)

    def forward(self, input1, input2):
        # Ensure the input has at least 4 dimensions
        if len(input1.shape) < 4:
            input1 = input1.unsqueeze(1)
        if len(input2.shape) < 4:
            input2 = input2.unsqueeze(1)

        # Ensure the input has the correct dimensions
        input1 = input1.view(input1.shape[0], -1, input1.shape[2], input1.shape[3])
        input2 = input2.view(input2.shape[0], -1, input2.shape[2], input2.shape[3])

        # Pass both inputs through the feature extractor
        output1 = self.feature_extractor(input1)
        output2 = self.feature_extractor(input2)
        
        # Compute the absolute difference between the two outputs
        diff = torch.abs(output1 - output2)
        
        # Pass the difference through the final layer to get the output
        output = self.final_layer(diff)
        
        return output

In [176]:
model = SiameseNetwork().to(device)

criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [177]:
num_epochs = 1
running_loss = 0.0

for epoch in range(num_epochs):
    for i, (inputs1, inputs2, labels, real_img, altered_img) in enumerate(dataloader):
        print(f'Training with real image: {real_img}, altered image: {altered_img}')

        inputs1 = inputs1.to(device)
        inputs2 = inputs2.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        outputs = model(inputs1, inputs2)
        outputs = outputs.squeeze()  # Remove the extra dimension

        labels = labels.float().squeeze()
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / (i+1)))

    running_loss = 0.0

print('Finished Training')

Training with real image: ('125__M_Left_thumb_finger.BMP', '289__F_Right_thumb_finger.BMP', '238__M_Right_index_finger.BMP', '363__M_Left_ring_finger.BMP', '98__M_Left_thumb_finger.BMP', '534__F_Right_thumb_finger.BMP', '211__M_Left_thumb_finger.BMP', '202__M_Left_little_finger.BMP', '308__M_Right_little_finger.BMP', '45__M_Right_index_finger.BMP', '483__M_Left_ring_finger.BMP', '186__M_Right_thumb_finger.BMP', '478__M_Left_ring_finger.BMP', '404__M_Right_middle_finger.BMP', '165__M_Left_middle_finger.BMP', '293__F_Left_middle_finger.BMP'), altered image: ('125__M_Right_little_finger_Obl.BMP', '289__F_Right_middle_finger_Obl.BMP', '238__M_Left_little_finger_Obl.BMP', '363__M_Left_middle_finger_Zcut.BMP', '98__M_Right_middle_finger_Obl.BMP', '534__F_Left_ring_finger_CR.BMP', '211__M_Left_middle_finger_CR.BMP', '202__M_Right_ring_finger_Zcut.BMP', '308__M_Left_middle_finger_CR.BMP', '45__M_Right_little_finger_Obl.BMP', '483__M_Left_middle_finger_CR.BMP', '186__M_Left_middle_finger_CR.BMP