# Project Overview

### <span style="color:red;">This notebook fine-tunes a pretrained Swin Transformer — introduced in the paper *Swin Transformer: Hierarchical Vision Transformer using Shifted Windows* by Liu et al. — as a binary face anti-spoofing classifier, and reports the validation EER, the test-set HTER at the validation EER threshold, and the test-set AUC.</span>


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
from torch.utils.data import DataLoader, Dataset
from transformers import AutoImageProcessor, SwinForImageClassification
from PIL import Image
import requests
from torchvision.datasets import ImageFolder
from torchvision.transforms import transforms
from sklearn.metrics import roc_auc_score
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve


Load the pretrained Swin Transformer (Vision Transformer) checkpoint and its image processor from https://huggingface.co/microsoft/swin-tiny-patch4-window7-224

In [None]:

# Pretrained Swin-Tiny checkpoint from the Hugging Face Hub (patch 4, window 7, 224 px input, per its name).
SWIN_CHECKPOINT = "microsoft/swin-tiny-patch4-window7-224"

# NOTE(review): image_processor is not used in the cells shown; preprocessing is done with torchvision transforms.
image_processor = AutoImageProcessor.from_pretrained(SWIN_CHECKPOINT)
model = SwinForImageClassification.from_pretrained(SWIN_CHECKPOINT)


In [None]:

# Set seeds for reproducibility
def set_seed(seed):
    """Seed every random number generator the notebook relies on.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch's CPU RNG.
    When CUDA is available it also seeds all GPUs. Finally it switches cuDNN to
    deterministic, non-benchmarking mode.

    Args:
        seed: integer seed applied to every generator.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)  # covers every visible GPU
    # Give up cuDNN autotuning speed in exchange for run-to-run repeatability.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


# Seed once; the classifier head and data-loader shuffling in later cells draw from these RNGs.
SEED = 42  # any fixed integer works
set_seed(SEED)

# Define a custom dataset class
class CustomDataset(Dataset):
    """Map-style dataset over two parallel, index-aligned sequences.

    Args:
        data: indexable collection of samples.
        labels: indexable collection of labels; ``labels[i]`` belongs to ``data[i]``.
        transform: optional callable applied to each sample (never to the label)
            when an item is fetched. It is skipped when it is falsy (e.g. ``None``).
    """

    def __init__(self, data, labels, transform=None):
        self.data, self.labels, self.transform = data, labels, transform

    def __len__(self):
        # The length comes from ``data``; ``labels`` is assumed to be the same length.
        return len(self.data)

    def __getitem__(self, idx):
        sample, label = self.data[idx], self.labels[idx]
        return (self.transform(sample) if self.transform else sample), label

# Define data transforms with normalization: resize to the 224x224 input size named in the
# Swin checkpoint, convert to tensors, then normalize with the ImageNet mean and std.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# Data location. Each split directory must follow the ImageFolder layout
# (one sub-folder per class). Edit DATA_ROOT to point at your copy of the data.
DATA_ROOT = '/u/45/demo_file/long'
train_data_dir = f'{DATA_ROOT}/train/'
val_data_dir = f'{DATA_ROOT}/validations/'
test_data_dir = f'{DATA_ROOT}/test/'

# Load the training, validation, and test datasets using ImageFolder
train_dataset = ImageFolder(root=train_data_dir, transform=transform)
val_dataset = ImageFolder(root=val_data_dir, transform=transform)
test_dataset = ImageFolder(root=test_data_dir, transform=transform)

# ImageFolder numbers classes from each split's own sorted sub-folder names. A missing or
# differently named folder in one split would therefore silently remap that split's labels,
# so fail loudly instead.
for split_name, split_dataset in (('validation', val_dataset), ('test', test_dataset)):
    if split_dataset.class_to_idx != train_dataset.class_to_idx:
        raise ValueError(
            f"{split_name} class mapping {split_dataset.class_to_idx} does not match "
            f"training class mapping {train_dataset.class_to_idx}"
        )

# The single-logit head and BCEWithLogitsLoss used later require labels in {0, 1}.
if len(train_dataset.classes) != 2:
    raise ValueError(
        f"Expected exactly 2 classes for binary classification, found {train_dataset.classes}"
    )

# Integer class labels per split (ImageFolder's `targets`)
train_labels = train_dataset.targets
val_labels = val_dataset.targets
test_labels = test_dataset.targets

# Create data loaders; only the training data is shuffled.
batch_size = 8  # Adjust the batch size as needed
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Print number of images and classes
num_train_images = len(train_dataset)
num_val_images = len(val_dataset)
num_test_images = len(test_dataset)
num_classes = len(train_dataset.classes)

print(f"Number of training images: {num_train_images}")
print(f"Number of validation images: {num_val_images}")
print(f"Number of test images: {num_test_images}")
print(f"Number of classes: {num_classes}")
print(f"Class names: {train_dataset.classes}")


In [None]:
# Define a binary classification head
class BinaryClassificationHead(nn.Module):
    """Two-layer MLP that maps a feature vector to one raw logit.

    Args:
        input_size: width of the incoming feature vector.
        hidden_size: width of the hidden ReLU layer.

    The final layer has a single output unit and no activation. It is paired with
    ``nn.BCEWithLogitsLoss``, which applies the sigmoid itself.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        # Build fc1 before fc2 so seeded weight initialisation stays reproducible.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, 1)

    def forward(self, x):
        hidden = torch.relu(self.fc1(x))
        return self.fc2(hidden)

# Modify the Swin Transformer model for binary classification by swapping its classification
# layer for the binary head. The head's input width is read from the backbone
# (`model.swin.num_features` is the pooled feature size the original classifier consumed;
# 768 for swin-tiny) instead of being hard-coded. This keeps the cell correct for other Swin
# checkpoints and independent of whichever head a previous run of this cell installed.
HEAD_HIDDEN_SIZE = 32  # width of the head's hidden ReLU layer
classifier_head = BinaryClassificationHead(model.swin.num_features, HEAD_HIDDEN_SIZE)
model.classifier = classifier_head

# The head emits raw logits; BCEWithLogitsLoss applies the sigmoid internally.
criterion = nn.BCEWithLogitsLoss()
# Build the optimizer after the head swap so it includes the new head's parameters.
optimizer = optim.AdamW(model.parameters(), lr=1e-4)


In [None]:


# Fine-tune with early stopping on validation loss.
# Requires model, train_loader, val_loader, criterion and optimizer from the cells above.
# Each epoch also computes the validation ROC and EER. For the best epoch (lowest validation
# loss), the EER threshold is kept in `best_eer_threshold` and the weights are saved to disk.
# Those weights are reloaded into `model` after the loop.

num_epochs = 10
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Early stopping parameters
patience = 2                  # epochs without improvement tolerated before stopping
verbose = True
delta = 0.001                 # minimum validation-loss decrease that counts as an improvement
best_val_loss = float('inf')
counter = 0
best_eer_threshold = 0.5      # placeholder; replaced by the best epoch's validation EER threshold
best_model_path = 'best_model.pth'

# Per-epoch history for plotting
train_losses = []
val_losses = []
best_fpr, best_tpr = None, None  # validation ROC of the best epoch

for epoch in range(num_epochs):
    # --- Training ---
    model.train()
    running_loss = 0.0
    for data, labels in train_loader:
        data, labels = data.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(data).logits  # one logit per image from the binary head
        loss = criterion(outputs, labels.unsqueeze(1).float())
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    # --- Validation ---
    model.eval()
    val_loss = 0.0
    val_labels = []  # true labels
    val_scores = []  # sigmoid probabilities, flattened to 1-D for scikit-learn
    with torch.no_grad():
        for val_data, val_labels_batch in val_loader:
            val_data, val_labels_batch = val_data.to(device), val_labels_batch.to(device)
            val_outputs = model(val_data).logits
            val_loss += criterion(val_outputs, val_labels_batch.unsqueeze(1).float()).item()
            val_scores.extend(torch.sigmoid(val_outputs).flatten().cpu().numpy())
            val_labels.extend(val_labels_batch.cpu().numpy())

    avg_train_loss = running_loss / len(train_loader)
    avg_val_loss = val_loss / len(val_loader)
    # Record history before any early-stopping `break` so the final epoch is plotted too.
    train_losses.append(avg_train_loss)
    val_losses.append(avg_val_loss)

    print(f"Epoch {epoch+1}, Training Loss: {avg_train_loss:.4f}, Validation Loss: {avg_val_loss:.4f}")

    # --- EER on the validation set (label 1 is the positive class) ---
    fpr, tpr, thresholds = roc_curve(val_labels, val_scores, pos_label=1)
    fnr = 1 - tpr
    eer_threshold_idx = np.nanargmin(np.abs(fnr - fpr))  # ROC point where FNR is closest to FPR
    eer_threshold = thresholds[eer_threshold_idx]
    eer_fpr = fpr[eer_threshold_idx]
    eer_fnr = fnr[eer_threshold_idx]
    # On a finite ROC the two rates rarely coincide exactly, so report their mean as the EER.
    eer = (eer_fpr + eer_fnr) / 2

    print(f"Epoch {epoch+1}, EER: {eer:.4f}, EER Threshold: {eer_threshold:.4f}")

    # --- Early stopping based on validation loss ---
    if avg_val_loss < best_val_loss - delta:
        best_val_loss = avg_val_loss
        best_eer_threshold = eer_threshold  # threshold used later for the test-set HTER
        best_fpr, best_tpr = fpr, tpr
        counter = 0
        torch.save(model.state_dict(), best_model_path)
    else:
        counter += 1
        if counter >= patience:
            if verbose:
                print(f"Early stopping at epoch {epoch+1}")
            break

# After early stopping, the last epoch's weights are not the ones that produced
# `best_eer_threshold`. Reload the checkpoint saved by this run so `model` matches the threshold.
if best_fpr is not None:
    model.load_state_dict(torch.load(best_model_path, map_location=device, weights_only=True))

# Plot the loss curves and the validation ROC. Use the best epoch's ROC, or the last
# computed one if no epoch improved.
roc_fpr, roc_tpr = (best_fpr, best_tpr) if best_fpr is not None else (fpr, tpr)
epochs_run = range(1, len(train_losses) + 1)

fig, (loss_ax, roc_ax) = plt.subplots(1, 2, figsize=(12, 5))

loss_ax.plot(epochs_run, train_losses, label='Training Loss', color='blue')
loss_ax.plot(epochs_run, val_losses, label='Validation Loss', color='orange')
loss_ax.set_title('Training and Validation Loss')
loss_ax.set_xlabel('Epochs')
loss_ax.set_ylabel('Loss')
loss_ax.legend()
loss_ax.grid()

roc_ax.plot(roc_fpr, roc_tpr, label='ROC Curve', color='blue')
roc_ax.plot([0, 1], [0, 1], 'k--')  # chance-level diagonal
roc_ax.set_title('Validation ROC Curve')
roc_ax.set_xlabel('False Positive Rate (FPR)')
roc_ax.set_ylabel('True Positive Rate (TPR)')
roc_ax.set_xlim([0, 1])
roc_ax.set_ylim([0, 1])
roc_ax.grid()
roc_ax.legend()

fig.tight_layout()
plt.show()


## Training complete: evaluate the best checkpoint on the test set (HTER at the validation EER threshold, and AUC)

In [None]:


# Test-set evaluation. Requires model, device, test_loader and best_eer_threshold from the
# cells above. The checkpoint saved at the best validation epoch is reloaded first, because
# `best_eer_threshold` was computed from those weights, not from the final epoch's weights.
#
# Label convention used below: label 0 = genuine, label 1 = impostor/attack, and the score
# sigmoid(logit) is the model's probability for label 1.
# NOTE(review): ImageFolder assigns labels from sorted folder names; confirm that
# train_dataset.classes lists the genuine class first.

model.load_state_dict(torch.load('best_model.pth', map_location=device, weights_only=True))
model.eval()  # don't rely on the mode the training cell happened to leave the model in

# --- Now calculate HTER on the test set using the EER threshold from validation ---

# Step 1: Gather predictions on the test set
test_labels = []
test_scores = []
with torch.no_grad():
    for inputs, labels in test_loader:
        logits = model(inputs.to(device)).logits
        test_scores.extend(torch.sigmoid(logits).flatten().cpu().numpy())  # 1-D probabilities
        test_labels.extend(labels.cpu().numpy())

test_labels = np.asarray(test_labels)
test_scores = np.asarray(test_scores)

# Steps 2-3: Classify test samples at the validation EER threshold.
# A score >= threshold means "predicted impostor", matching roc_curve's convention.
genuine_mask = test_labels == 0
impostor_mask = ~genuine_mask
total_genuine = int(genuine_mask.sum())
total_impostors = int(impostor_mask.sum())
false_rejects = int(np.sum(test_scores[genuine_mask] >= best_eer_threshold))   # genuine flagged as impostor
false_accepts = int(np.sum(test_scores[impostor_mask] < best_eer_threshold))   # impostor passed as genuine

# Step 4: Calculate FAR and FRR while avoiding division by zero
far = false_accepts / total_impostors if total_impostors > 0 else 0  # False Acceptance Rate
frr = false_rejects / total_genuine if total_genuine > 0 else 0      # False Rejection Rate

# Step 5: Calculate HTER on the test set
hter = (far + frr) / 2  # Half Total Error Rate

print(f"Half Total Error Rate (HTER) on test set using EER threshold: {hter * 100:.2f}%")

# Step 6: Calculate AUC on the test set (threshold-independent)
auc = roc_auc_score(test_labels, test_scores)
print(f"Area Under the ROC Curve (AUC) on test set: {auc * 100:.2f}%")
