In [None]:
# Colab-only helper for attaching Google Drive to the notebook's filesystem.
from google.colab import drive

# Mount Google Drive at /content/drive.
drive.mount('/content/drive')

# Drive folder for the palmprint data. No later cell in this notebook reads
# this variable.
# NOTE(review): this path spells the Drive root as 'My Drive' (with a space),
# while the download cell writes under '/content/drive/MyDrive' — confirm both
# resolve to the same location before relying on this path.
dataset_path = '/content/drive/My Drive/Tongji_Palmprint'


Mounted at /content/drive


In [None]:
!pip install gdown




In [None]:
import os

import gdown

# The shared link to the dataset file
url = 'https://drive.google.com/uc?id=15hEsOm0fZKUHpFNChPSjwiRfMczxcnVQ'
output = '/content/drive/MyDrive/Tongji_Contactless_Palmprint_Dataset.zip'  # Path in your Google Drive

# The archive is several gigabytes, so only download it when it is not already
# on Drive. If a previous download was interrupted and left a partial file
# behind, delete that file to force a fresh download.
if os.path.exists(output):
    print(f"Archive already present at {output}; skipping download.")
else:
    gdown.download(url, output, quiet=False)


Downloading...
From (original): https://drive.google.com/uc?id=15hEsOm0fZKUHpFNChPSjwiRfMczxcnVQ
From (redirected): https://drive.google.com/uc?id=15hEsOm0fZKUHpFNChPSjwiRfMczxcnVQ&confirm=t&uuid=86e174bd-f919-4595-927e-ecc4f73d3e7e
To: /content/drive/MyDrive/Tongji_Contactless_Palmprint_Dataset.zip
100%|██████████| 4.95G/4.95G [01:02<00:00, 79.6MB/s]


'/content/drive/MyDrive/Tongji_Contactless_Palmprint_Dataset.zip'

In [None]:
import zipfile
import os

# Path to the zip file. The download cell saves the archive on Google Drive,
# so it has to be read from there rather than from /content directly.
zip_file_path = '/content/drive/MyDrive/Tongji_Contactless_Palmprint_Dataset.zip'
extract_dir = '/content/extracted_files'


def extract_zip(zip_path, dest_dir):
    """Extract every member of a zip archive into a directory.

    Args:
        zip_path: Path of the ``.zip`` archive to read.
        dest_dir: Directory the members are extracted into.

    Returns:
        The list of member names stored in the archive.

    Raises:
        FileNotFoundError: If ``zip_path`` does not exist.
        zipfile.BadZipFile: If ``zip_path`` is not a valid zip archive.
    """
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        file_list = zip_ref.namelist()
        # Show a count and a short preview; printing every member name of a
        # large archive floods the cell output.
        print("Files in the zip file:", len(file_list))
        print("First entries:", file_list[:5])

        zip_ref.extractall(dest_dir)
    print(f"Files have been extracted to {dest_dir}")
    return file_list


# Extract the zip file
file_list = extract_zip(zip_file_path, extract_dir)


FileNotFoundError: [Errno 2] No such file or directory: '/content/Tongji_Contactless_Palmprint_Dataset.zip'

In [None]:
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, models

# Run on the first CUDA GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:

# Define the Siamese Network
class SiameseNetwork(nn.Module):
    """Siamese embedding network built on VGG-16 convolutional features.

    Both inputs are passed through the same weights: the VGG-16 ``features``
    stack, an adaptive 7x7 average pool, and a three-layer fully connected
    head that maps each image to a 256-dimensional embedding.
    """

    def __init__(self):
        super().__init__()
        # Newer torchvision releases select pretrained weights through the
        # ``weights=`` argument and deprecate ``pretrained=True``; fall back to
        # the old flag only when the weights enum is not available.
        weights_enum = getattr(models, "VGG16_Weights", None)
        if weights_enum is not None:
            backbone = models.vgg16(weights=weights_enum.IMAGENET1K_V1)
        else:
            backbone = models.vgg16(pretrained=True)
        self.cnn = backbone.features
        # The head expects 512 * 7 * 7 features. Pooling to a fixed 7x7 grid
        # keeps that true for any input resolution; for 224x224 inputs the
        # feature map is already 7x7, so the pool leaves it unchanged.
        self.pool = nn.AdaptiveAvgPool2d((7, 7))
        self.fc1 = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, 256)
        )

    def forward_once(self, x):
        """Embed one batch of images.

        Args:
            x: Image batch accepted by the convolutional backbone.

        Returns:
            One embedding row per input image.
        """
        x = self.cnn(x)
        x = self.pool(x)
        # Flatten everything except the batch dimension.
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        return x

    def forward(self, input1, input2):
        """Embed both members of a pair with the shared weights.

        Returns:
            Tuple ``(output1, output2)`` of embeddings for ``input1`` and
            ``input2`` respectively.
        """
        output1 = self.forward_once(input1)
        output2 = self.forward_once(input2)
        return output1, output2

# Define the Contrastive Loss
class ContrastiveLoss(nn.Module):
    """Contrastive loss over pairs of embeddings.

    Label convention: ``0`` marks a pair that should be close (its squared
    distance is penalised) and ``1`` marks a pair that should be at least
    ``margin`` apart (its squared shortfall from the margin is penalised).

    Args:
        margin: Distance beyond which a label-1 pair contributes no loss.
    """

    def __init__(self, margin=2.0):
        super().__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        """Return the mean contrastive loss for a batch of pairs.

        Args:
            output1: Embeddings of the first image of each pair.
            output2: Embeddings of the second image of each pair.
            label: One value per pair (0 or 1); may be shaped ``(batch,)``
                or ``(batch, 1)``.

        Returns:
            Scalar tensor holding the mean loss over the batch.
        """
        euclidean_distance = F.pairwise_distance(output1, output2)
        # pairwise_distance yields one distance per pair with shape (batch,).
        # Labels collated by a DataLoader typically arrive as (batch, 1), and
        # multiplying the two would broadcast to (batch, batch), mixing every
        # distance with every label. Align the label to the distance first.
        if label.numel() == euclidean_distance.numel():
            label = label.reshape(euclidean_distance.shape)
        similar_term = (1 - label) * torch.pow(euclidean_distance, 2)
        dissimilar_term = label * torch.pow(
            torch.clamp(self.margin - euclidean_distance, min=0.0), 2
        )
        return torch.mean(similar_term + dissimilar_term)

# Custom Dataset for Siamese Network
class SiameseDataset(Dataset):
    """Dataset that yields randomly drawn image pairs for Siamese training.

    Each item is ``(img0, img1, label)`` where ``label`` is a float32 tensor
    of shape ``(1,)`` holding ``0.0`` when both images share a class and
    ``1.0`` when they do not. Pairs are drawn at random on every access, so
    the ``index`` passed to ``__getitem__`` is ignored.

    Args:
        imageFolderDataset: Object exposing ``imgs``, a sequence of
            ``(path, class_index)`` tuples.
        transform: Optional callable applied to each loaded image.

    Note:
        Samples are grouped by class once, at construction time; changes made
        to ``imageFolderDataset.imgs`` afterwards are not reflected in that
        grouping.
    """

    def __init__(self, imageFolderDataset, transform=None):
        self.imageFolderDataset = imageFolderDataset
        self.transform = transform
        # Group samples by class so a same-class partner can be drawn directly
        # instead of by repeated random retries over the whole dataset.
        self._samples_by_class = {}
        for sample in imageFolderDataset.imgs:
            self._samples_by_class.setdefault(sample[1], []).append(sample)

    @staticmethod
    def _load_rgb(path):
        """Load the image at ``path`` as RGB, closing the file afterwards."""
        with Image.open(path) as image:
            return image.convert("RGB")

    def __getitem__(self, index):
        """Return a random ``(img0, img1, label)`` triple.

        Raises:
            ValueError: If a different-class pair is requested but the
                dataset contains fewer than two classes.
        """
        img0_tuple = random.choice(self.imageFolderDataset.imgs)
        # Flip a coin: same-class pair or different-class pair.
        should_get_same_class = random.randint(0, 1)
        if should_get_same_class:
            # Uniform over the images of img0's class (img0 itself included).
            img1_tuple = random.choice(self._samples_by_class[img0_tuple[1]])
        else:
            if len(self._samples_by_class) < 2:
                # Retrying could never succeed; fail loudly instead of hanging.
                raise ValueError(
                    "Cannot draw a different-class pair: the dataset has "
                    "fewer than two classes."
                )
            while True:
                img1_tuple = random.choice(self.imageFolderDataset.imgs)
                if img0_tuple[1] != img1_tuple[1]:
                    break

        img0 = self._load_rgb(img0_tuple[0])
        img1 = self._load_rgb(img1_tuple[0])

        if self.transform:
            img0 = self.transform(img0)
            img1 = self.transform(img1)

        is_different = float(img0_tuple[1] != img1_tuple[1])
        return img0, img1, torch.tensor([is_different], dtype=torch.float32)

    def __len__(self):
        """Return the number of images in the wrapped dataset."""
        return len(self.imageFolderDataset.imgs)

In [None]:
# Training process
def train_siamese_network(train_dataloader, num_epochs=20, learning_rate=0.0001):
    """Train a freshly constructed ``SiameseNetwork`` with contrastive loss.

    Args:
        train_dataloader: Iterable yielding ``(img0, img1, label)`` batches.
        num_epochs: Number of passes over ``train_dataloader``.
        learning_rate: Learning rate handed to the Adam optimizer.

    Returns:
        The trained model, placed on the module-level ``device``.
    """
    net = SiameseNetwork().to(device)
    contrastive = ContrastiveLoss()
    adam = optim.Adam(net.parameters(), lr=learning_rate)

    for epoch in range(num_epochs):
        for step, batch in enumerate(train_dataloader):
            # Move both images and the label of the batch to the target device.
            first, second, target = (tensor.to(device) for tensor in batch)

            adam.zero_grad()
            embedding_first, embedding_second = net(first, second)
            loss_contrastive = contrastive(embedding_first, embedding_second, target)
            loss_contrastive.backward()
            adam.step()

            # Report progress on every 10th batch, starting with the first.
            if step % 10:
                continue
            print(f"Epoch number {epoch+1}/{num_epochs}, Current loss {loss_contrastive.item():.4f}")

    return net

In [None]:
# Define transformations and dataset
# Preprocessing applied to every image: resize to 224x224, convert to a
# tensor, then normalise each channel with the given mean and std.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [None]:
# NOTE(review): the extraction cell writes to '/content/extracted_files' and
# nothing visible in this notebook creates 'data/training' — confirm that this
# root points at the extracted images (one sub-folder per class).
folder_dataset = torchvision.datasets.ImageFolder(root="data/training")
siamese_dataset = SiameseDataset(imageFolderDataset=folder_dataset, transform=transform)
# Use at most 8 loader workers, but never more than the CPUs the runtime
# reports, to avoid oversubscribing small machines.
train_dataloader = DataLoader(
    siamese_dataset,
    shuffle=True,
    num_workers=min(8, os.cpu_count() or 1),
    batch_size=32,
)

# Train the Siamese Network
siamese_model = train_siamese_network(train_dataloader)

# Save the trained model
torch.save(siamese_model.state_dict(), 'siamese_network.pt')

# Function to plot a loss curve against the iteration count
def show_plot(iteration, loss):
    """Draw ``loss`` against ``iteration`` as a line chart and display it.

    Args:
        iteration: Sequence of x-values (iteration numbers).
        loss: Sequence of y-values (loss at each iteration), same length as
            ``iteration``.
    """
    fig, ax = plt.subplots()
    ax.plot(iteration, loss)
    # Label the figure so it can be read on its own.
    ax.set_xlabel("Iteration")
    ax.set_ylabel("Loss")
    ax.set_title("Training loss")
    plt.show()

# The model is trained and saved once, earlier in this cell. Running the full
# training a second time here would only discard that model and overwrite its
# checkpoint, so no further training is performed.