In [1]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms, models
from torchvision.io import decode_image
from sklearn.model_selection import KFold
import numpy as np
import os
import cv2
from PIL import Image

In [2]:
# Prepare the data - create a KFold split
random_seed = 42
np.random.seed(random_seed)
torch.manual_seed(random_seed)

# "ours" dataset
authentic_path = "dane/CelebA/authentic"
spoof_path = "dane/CelebA/spoof"

# Load images
authentic_images = [os.path.join(authentic_path, img) for img in os.listdir(authentic_path)]
spoof_images = [os.path.join(spoof_path, img) for img in os.listdir(spoof_path)]

# Create X and y from all images
X = authentic_images + spoof_images
y = np.concatenate((np.zeros(len(authentic_images)), np.ones(len(spoof_images)))).astype(np.int64)

# KFold
kf = KFold(n_splits=5, random_state=random_seed, shuffle=True)

In [3]:
# Custom dataset to load images and labels
class OurDataset(torch.utils.data.Dataset):
    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        image = Image.open(image_path).convert("RGB")
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label, image_path  # Return image, label, and file path

In [4]:
# Create a text file to store incorrect predictions
with open("incorrect_predictions_swintransformer.txt", "w") as f:
    f.write("name\tpredicted\tactual\n")
    f.write("(0 - authentic, 1 - spoof)\n")

In [5]:
# Create a transform to resize images
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

In [6]:
# Create, test and evaluate the model for each fold
for fold, (train_idx, test_idx) in enumerate(kf.split(X)):
    print(f"Fold {fold}")
    X_train = [X[i] for i in train_idx]
    y_train = y[train_idx]
    X_test = [X[i] for i in test_idx]
    y_test = y[test_idx]
    
    # Create custom datasets
    train_dataset = OurDataset(X_train, y_train, transform=transform)
    test_dataset = OurDataset(X_test, y_test, transform=transform)
    
    # Create dataloaders
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
    
    # Device
    device = torch.device("mps" if torch.cuda.is_available() else "cpu")
    
    # Create a model
    model = models.swin_v2_s(weights=None)   # s or v2_s
    num_ftrs = model.head.in_features
    model.head = torch.nn.Linear(num_ftrs, 2) # Two output neurons
    model = model.to(device)

    # Train the model
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    num_epochs = 10
    
    model.train()

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for images, labels, _ in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels.long())
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {running_loss / len(train_loader)}")
        
    # Evaluate the model
    correct = 0
    total = 0
    incorrect_files = []
    
    accuracy = 0

    model.eval()

    with torch.no_grad():
        for images, labels, file_paths in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # Collect file names of incorrectly predicted images
            with open("incorrect_predictions_swintransformer.txt", "a") as f:
                for i in range(len(predicted)):
                    if predicted[i].item() != labels[i].item():
                        f.write(file_paths[i] + "\t" + str(predicted[i].item()) + "\t" + str(labels[i].item()) + "\n")

    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")
    print("-" * 50)

Fold 0
Epoch 1/10, Loss: 0.7311827356521964
Epoch 2/10, Loss: 0.6867667886991848
Epoch 3/10, Loss: 0.6800270661117712
Epoch 4/10, Loss: 0.6797727919740528
Epoch 5/10, Loss: 0.6787192307063984
Epoch 6/10, Loss: 0.6787282568329955
Epoch 7/10, Loss: 0.6789998850128266
Epoch 8/10, Loss: 0.678522928748635
Epoch 9/10, Loss: 0.6787379395403936
Epoch 10/10, Loss: 0.6791095795623144
Test Accuracy: 58.60%
--------------------------------------------------
Fold 1
Epoch 1/10, Loss: 0.7428869377802315
Epoch 2/10, Loss: 0.6900480958783234
Epoch 3/10, Loss: 0.6804863088465646
Epoch 4/10, Loss: 0.6792258324821445
Epoch 5/10, Loss: 0.6794366477060566
Epoch 6/10, Loss: 0.6795762480854781
Epoch 7/10, Loss: 0.6793073624424133
Epoch 8/10, Loss: 0.6788225710082095
Epoch 9/10, Loss: 0.6791249136577653
Epoch 10/10, Loss: 0.6787023382129042
Test Accuracy: 59.01%
--------------------------------------------------
Fold 2
Epoch 1/10, Loss: 0.7415686093748546
Epoch 2/10, Loss: 0.6836627777569009
Epoch 3/10, Loss: 