In [2]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
from timm.models.vision_transformer import vit_base_patch16_224


  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Root folder of the image dataset, given relative to the notebook's working
# directory. It is handed to ImageFolder as `root` when the dataset is loaded.
data_dir = "Datasets"

In [5]:
# Preprocessing applied to every image as it is loaded:
#   1. resize to 224x224 (the input size named by the ViT variant imported above),
#   2. convert to a tensor,
#   3. normalise with mean 0.5 and std 0.5.
# NOTE(review): a single-element mean/std is presumably broadcast over all
# channels -- confirm against the image format actually stored on disk.
_preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
]
transform = transforms.Compose(_preprocessing_steps)

In [6]:
# Load the full dataset from `data_dir`, applying `transform` to every image.
dataset = datasets.ImageFolder(root=data_dir, transform=transform)

# Split dataset into train (80%), validation (10%), and test (the remainder,
# so the three sizes always add up to len(dataset)).
#
# The split is driven by a dedicated, seeded generator: without it the
# membership of each subset changes on every kernel restart, which makes the
# reported metrics non-reproducible and lets weights saved by one run be
# evaluated later on samples that run had trained on.
SPLIT_SEED = 42
train_size = int(0.8 * len(dataset))
val_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - val_size
train_dataset, val_dataset, test_dataset = random_split(
    dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [7]:
# Data loaders: one per split, all with the same batch size. Only the
# training loader shuffles; validation and test keep a fixed order.
BATCH_SIZE = 32
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

In [8]:
# Define Vision Transformer model
class FingerprintClassifier(nn.Module):
    """ViT-Base/16 (224x224) backbone with a freshly initialised linear head.

    The backbone is created with ``pretrained=True`` and its ``head`` layer is
    swapped for an ``nn.Linear`` that maps the head's input features to
    ``num_classes`` outputs.

    Args:
        num_classes: Number of output units of the replacement head.
    """

    def __init__(self, num_classes=3):
        super().__init__()
        backbone = vit_base_patch16_224(pretrained=True)
        # Reuse the width of the original head so the new layer plugs into
        # the same feature vector.
        feature_dim = backbone.head.in_features
        backbone.head = nn.Linear(feature_dim, num_classes)
        self.model = backbone

    def forward(self, x):
        """Run ``x`` through the wrapped backbone and return its output."""
        logits = self.model(x)
        return logits

In [9]:
# Initialize model, loss function, and optimizer.
# The tunable values are named so they are easy to spot and change.
NUM_CLASSES = 3
LEARNING_RATE = 1e-3

model = FingerprintClassifier(num_classes=NUM_CLASSES)
criterion = nn.CrossEntropyLoss()
# Adam updates every parameter of the model (backbone and head alike).
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [10]:
# Training loop
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=10, device=None):
    """Train ``model`` and print the mean train/validation loss after each epoch.

    Args:
        model: ``nn.Module`` that is optimised in place.
        train_loader: Sized iterable of ``(images, labels)`` batches used for
            the gradient steps.
        val_loader: Sized iterable of ``(images, labels)`` batches that is only
            evaluated (no gradients, no parameter updates).
        criterion: Callable ``criterion(outputs, labels)`` returning a scalar
            loss tensor.
        optimizer: Optimizer that updates ``model``'s parameters.
        num_epochs: Number of full passes over ``train_loader``.
        device: Optional device each batch is moved to before the forward pass.
            When ``None`` the device of the model's first parameter is used;
            for a model without parameters the batches are left where they are.

    Returns:
        None. One summary line per epoch is written to stdout.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else None

    for epoch in range(num_epochs):
        # Switch back to training mode at the start of EVERY epoch: the
        # validation pass below leaves the model in eval mode, so doing this
        # only once before the loop would train epochs 2..N with dropout
        # (and any other train-only behaviour) disabled.
        model.train()
        running_loss = 0.0
        for images, labels in train_loader:
            if device is not None:
                images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        val_loss = 0.0
        model.eval()
        with torch.no_grad():
            for images, labels in val_loader:
                if device is not None:
                    images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item()

        # Losses are averaged per batch. An empty loader yields NaN instead of
        # aborting the whole run with ZeroDivisionError.
        num_train_batches = len(train_loader)
        num_val_batches = len(val_loader)
        mean_train_loss = running_loss / num_train_batches if num_train_batches else float("nan")
        mean_val_loss = val_loss / num_val_batches if num_val_batches else float("nan")

        print(f"Epoch [{epoch+1}/{num_epochs}], "
              f"Train Loss: {mean_train_loss:.4f}, "
              f"Validation Loss: {mean_val_loss:.4f}")

In [11]:
# Train the model for 10 epochs; a loss summary is printed after each one.
train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=10,
)

Epoch [1/10], Train Loss: 1.6837, Validation Loss: 1.2407
Epoch [2/10], Train Loss: 0.6444, Validation Loss: 0.3714
Epoch [3/10], Train Loss: 0.2106, Validation Loss: 0.2003
Epoch [4/10], Train Loss: 0.0645, Validation Loss: 0.0424
Epoch [5/10], Train Loss: 0.0424, Validation Loss: 0.0377
Epoch [6/10], Train Loss: 0.0116, Validation Loss: 0.1488
Epoch [7/10], Train Loss: 0.0493, Validation Loss: 0.0321
Epoch [8/10], Train Loss: 0.0193, Validation Loss: 0.0337
Epoch [9/10], Train Loss: 0.0346, Validation Loss: 0.0464
Epoch [10/10], Train Loss: 0.0412, Validation Loss: 0.0539


In [12]:
# Evaluate the model on the test set
def evaluate_model(model, test_loader, criterion, device=None):
    """Print the mean loss and the accuracy of ``model`` over ``test_loader``.

    Args:
        model: ``nn.Module`` to evaluate; it is put into eval mode.
        test_loader: Sized iterable of ``(images, labels)`` batches.
        criterion: Callable ``criterion(outputs, labels)`` returning a scalar
            loss tensor.
        device: Optional device each batch is moved to before the forward pass.
            When ``None`` the device of the model's first parameter is used;
            for a model without parameters the batches are left where they are.

    Returns:
        None. A single summary line is written to stdout. The loss is averaged
        per batch, the accuracy per sample; both are reported as ``nan`` when
        ``test_loader`` is empty.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else None

    model.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            if device is not None:
                images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            test_loss += loss.item()
            # Predicted class = index of the largest output along dim 1.
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Guard both divisions so an empty loader reports NaN rather than raising
    # ZeroDivisionError.
    num_batches = len(test_loader)
    mean_loss = test_loss / num_batches if num_batches else float("nan")
    accuracy = correct / total if total else float("nan")
    print(f"Test Loss: {mean_loss:.4f}, Test Accuracy: {accuracy:.4f}")


In [13]:
# Evaluate the trained model on the held-out test split; the loss and
# accuracy are printed by the function itself.
evaluate_model(model=model, test_loader=test_loader, criterion=criterion)

Test Loss: 0.0663, Test Accuracy: 0.9757


In [14]:
# Save the model: only the state_dict is written, so loading it later
# requires building a FingerprintClassifier first and then calling
# load_state_dict on it.
weights = model.state_dict()
torch.save(weights, 'fingerprint_classifier.pth')