In [1]:
# Project root; the dataset and model directories used later are built from it.
# Presumably a Colab-mounted Google Drive folder - adjust when running elsewhere.
root_dir = "/content/drive/MyDrive/Zero_Shot_DeepFake_Image_Classification/"

In [2]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.models import vit_b_16
import torch.optim as optim
from tqdm import tqdm
import os
from sklearn.metrics import confusion_matrix
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import random

In [3]:
# Seed every random number generator in use (Python, NumPy, PyTorch CPU and
# the current CUDA device) with the same value so reruns draw the same numbers.
seed = 43
for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    seed_fn(seed)

# Disable cuDNN autotuning and request deterministic cuDNN algorithms.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [4]:
# Every split lives under one dataset folder inside the project root.
dataset_root_dir = root_dir + 'DeepfakeEmpiricalStudy/dataset/'

# NOTE(review): the training images are read from 'CELEB/test' - confirm this
# is intentional and not meant to be a dedicated training folder.
train_dir = dataset_root_dir + 'CELEB/test'
val_dir = dataset_root_dir + 'CELEB/val'

# One 'test' folder per evaluation dataset, in the order they are reported.
test_dirs = [
    dataset_root_dir + benchmark + '/test'
    for benchmark in ('CELEB-M', 'DF', 'DFD', 'F2F', 'FS-I', 'NT-I')
]

# Folder holding the pretrained ViT weights and the checkpoints saved by training.
models_root_dir = root_dir + 'DeepfakeEmpiricalStudy_Models/'

In [5]:
# Training hyperparameters.
batch_size = 64
num_epochs = 5
learning_rate = 1e-4

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Per-channel statistics handed to Normalize (the values torchvision documents
# for its ImageNet-pretrained models).
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

# Preprocessing shared by every split: resize to 224x224, convert to a tensor,
# then normalize each channel.
preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
]
transform = transforms.Compose(preprocessing_steps)

# ImageFolder derives the labels from the sub-folder names of each directory.
train_dataset = datasets.ImageFolder(train_dir, transform=transform)
val_dataset = datasets.ImageFolder(val_dir, transform=transform)

# Only the training loader shuffles; validation keeps a fixed order.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=batch_size)

# class TransformerBasedModel(nn.Module):
#     def __init__(self, num_classes=2):
#         super(TransformerBasedModel, self).__init__()
#         self.vit = vit_b_16(pretrained=True)
#         #self.vit.heads = nn.Linear(self.vit.heads.in_features, num_classes)

#     def forward(self, x):
#         return self.vit(x)

In [6]:
class TransformerBasedModel(nn.Module):
    """ViT-B/16 whose classification head is replaced by a new linear layer.

    The backbone is initialised from a checkpoint on disk, and every backbone
    parameter is left trainable, so the whole network is fine-tuned.

    Args:
        num_classes: Number of logits produced by the replacement head.
        weights_path: Path to the ViT-B/16 state dict. When ``None`` (default)
            ``models_root_dir + 'vit_b_16-c867db91.pth'`` is used.
    """

    def __init__(self, num_classes=2, weights_path=None):
        super().__init__()
        if weights_path is None:
            weights_path = models_root_dir + 'vit_b_16-c867db91.pth'

        # Build the bare architecture; the weights come from the local file
        # below. Calling without arguments gives an uninitialised model and
        # avoids the deprecated `pretrained=` keyword.
        self.vit = vit_b_16()

        # torch.load unpickles the file, so only point it at trusted
        # checkpoints (consider weights_only=True where the installed torch
        # supports it). map_location="cpu" lets a checkpoint saved from a GPU
        # load on a CPU-only machine; the model is moved to its device later.
        vit_weights = torch.load(weights_path, map_location="cpu")
        self.vit.load_state_dict(vit_weights)

        # Make it explicit that the whole backbone is fine-tuned, not frozen.
        for param in self.vit.parameters():
            param.requires_grad = True

        # The same Linear module is registered under two names
        # (`classifier` and `vit.heads.head`). Both are kept so the keys of
        # state_dict() stay compatible with checkpoints that were already saved.
        self.classifier = nn.Linear(self.vit.heads.head.in_features, num_classes)
        self.vit.heads.head = self.classifier

    def forward(self, x):
        """Return the logits the ViT (with the replaced head) produces for ``x``."""
        return self.vit(x)


# Build the two-class model and move it to the selected device.
model = TransformerBasedModel(num_classes=2)
model = model.to(device)

# Cross-entropy loss on the logits; Adam updates every parameter of the model.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

  vit_weights = torch.load(models_root_dir + 'vit_b_16-c867db91.pth')


In [7]:
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs):
    """Train ``model`` and keep the weights with the best validation accuracy.

    After every epoch the model is scored on ``val_loader``; whenever the
    validation accuracy improves, the state dict is written to
    ``models_root_dir + 'best_vit_model.pth'``.

    Relies on the module-level names ``device``, ``models_root_dir`` and
    ``evaluate_model``.

    Args:
        model: Network to optimise; it is updated in place.
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        val_loader: Iterable of ``(inputs, labels)`` validation batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer that already holds the model's parameters.
        num_epochs: Number of passes over ``train_loader``.
    """
    best_acc = 0.0

    for epoch in range(num_epochs):
        # evaluate_model() switches the model to eval mode at the end of each
        # epoch, so training mode must be re-enabled here. Calling train()
        # only once before the loop would leave every epoch after the first
        # running with dropout and similar layers disabled.
        model.train()

        running_loss = 0.0
        correct = 0
        total = 0
        for inputs, labels in tqdm(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # loss.item() is the batch mean, so weight it by the batch size to
            # get a per-sample average once divided by the sample count.
            running_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        # An empty loader would otherwise raise ZeroDivisionError.
        samples_seen = max(total, 1)
        train_acc = correct / samples_seen
        val_acc = evaluate_model(model, val_loader, criterion)[0]

        print("Epoch "+str(epoch+1)+", Loss: "+str(running_loss/samples_seen)+", Train Accuracy: "+str(train_acc)+", Val Accuracy: "+str(val_acc))

        # Checkpoint only on a strict improvement of the validation accuracy.
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), models_root_dir + 'best_vit_model.pth')
            print('Model saved!')

    print(f"Training complete. Best validation accuracy: {best_acc:.4f}")

def evaluate_model(model, loader, criterion):
    """Score ``model`` on ``loader`` and collect labels and predictions.

    The model is switched to eval mode and is left in that mode on return;
    gradients are disabled while iterating. Relies on the module-level
    ``device``.

    Args:
        model: Network to evaluate.
        loader: Iterable of ``(inputs, labels)`` batches.
        criterion: Not used by this function; kept so existing callers that
            pass it keep working.

    Returns:
        A tuple ``(accuracy, labels, predictions)`` where ``accuracy`` is the
        fraction of correct predictions (``0.0`` when ``loader`` yields no
        samples) and the other two are NumPy arrays in iteration order.
    """
    model.eval()
    correct = 0
    total = 0
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            # The predicted class is the index of the largest logit.
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # Guard against an empty loader instead of raising ZeroDivisionError.
    accuracy = correct / total if total else 0.0
    return accuracy, np.array(all_labels), np.array(all_preds)

In [8]:
def plot_confusion_matrix(cm, classes, title='Confusion Matrix'):
    """Show ``cm`` as an annotated heatmap on a new 8x6 figure.

    Args:
        cm: Confusion matrix to draw; cells are annotated with integer
            formatting (``fmt='d'``).
        classes: Tick labels applied to both axes.
        title: Text placed above the plot.
    """
    plt.figure(figsize=(8, 6))

    heatmap_style = {'annot': True, 'fmt': 'd', 'cmap': 'Blues'}
    sns.heatmap(cm, xticklabels=classes, yticklabels=classes, **heatmap_style)

    # The x axis is labelled as predictions and the y axis as ground truth.
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title(title)
    plt.show()

In [None]:
# Run the fine-tuning loop; train_model saves the best-validation weights to disk.
train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=num_epochs,
)

100%|██████████| 32/32 [12:05<00:00, 22.68s/it]


In [None]:
# Restore the checkpoint with the best validation accuracy. map_location keeps
# the load working when the checkpoint was written on a different device
# (for example saved on GPU, evaluated on CPU).
model.load_state_dict(torch.load(models_root_dir + 'best_vit_model.pth', map_location=device))

all_labels_combined = []
all_preds_combined = []
test_accuracies = []  # one accuracy per test set, reused for the average below

# NOTE(review): the plots label index 0 as 'real' and index 1 as 'fake'.
# ImageFolder assigns indices from the sorted sub-folder names, so confirm via
# test_dataset.class_to_idx that this matches the folder layout.
for test_dir in test_dirs:
    test_dataset = datasets.ImageFolder(test_dir, transform=transform)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    test_acc, all_labels, all_preds = evaluate_model(model, test_loader, criterion)
    test_accuracies.append(test_acc)
    print(f"Test Accuracy for {test_dir}: {test_acc:.4f}")

    cm = confusion_matrix(all_labels, all_preds, labels=[0, 1])
    plot_confusion_matrix(cm, classes=['real', 'fake'], title=f'Confusion Matrix for {test_dir}')

    all_labels_combined.extend(all_labels)
    all_preds_combined.extend(all_preds)

cm_combined = confusion_matrix(all_labels_combined, all_preds_combined, labels=[0, 1])
# Unweighted mean of the per-dataset accuracies collected in the loop; there is
# no need to rebuild the loaders and run inference on every test set again.
print(f"Average Accuracy: {np.mean(test_accuracies):.4f}")
plot_confusion_matrix(cm_combined, classes=['real', 'fake'], title='Combined Confusion Matrix')