# DeepFake Detection

# 1: ResNet50 - Best Model

## Requirements

**To train the model**, first set the training and validation data paths in the constants section. Then run the training cell (uncomment it first if it is commented out) to reproduce the best model. The weights with the highest validation accuracy are saved in the current directory as `best_model.pth`.

**To test the model:**

**Method 1:** set the test image directory in the constants section (`test_data_path = ...`), then run the cells in the model testing section. The model loads its weights from `best_model.pth` and runs the test.

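**Method 2:** set the path of the test list file in the constants section (`test_txt_path = ...`). Each line of that file holds an image path followed by its integer label, separated by whitespace. Then run the cells in the model testing section. The model loads its weights from `best_model.pth` and runs the test.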

**INSTALL**: PyTorch, torchvision, scikit-learn, Pillow, matplotlib

## Imports

In [None]:
import os
import matplotlib.pyplot as plt
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder
from torchvision.transforms import (
    Compose,
    ToTensor,
    Normalize,
    CenterCrop,
    Resize,
    RandomHorizontalFlip,
    RandomAffine,
)

from torchvision.models import resnet50
from sklearn.metrics import accuracy_score, recall_score, precision_score, roc_auc_score

## Constants

In [None]:
# ---- Data locations --------------------------------------------------------
# Training / validation images live in nested sub-directories (see MyDataset).
train_data_path = "./data/train/"
val_data_path = "./data/val/"

# Test data: set ONE of the following, depending on the testing method used.
test_data_path = "./data/test/"  # Method 1: directory tree of images
test_txt_path = "./test.txt"     # Method 2: text file of "<image path> <label>" lines

# ---- Reproducibility and hardware ------------------------------------------
torch.manual_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---- Hyper-parameters ------------------------------------------------------
batch_size = 32          # images per mini-batch
num_classes = 2          # output logits of the classifier head
input_size = (317, 317)  # input image size
num_epochs = 100         # full passes over the training set

## Dataset loader

We use `os.walk()` to find the images, because they are spread across many nested directories.

In [None]:
class MyDataset(Dataset):
    """Image dataset built by walking a directory tree of JPEG files.

    Every ``*.jpg`` / ``*.jpeg`` file (case-insensitive) found inside a
    sub-directory of ``data_path`` (at any depth) is collected. Files placed
    directly in ``data_path`` itself are not collected. The label is ``0``
    when the containing folder's path contains the substring ``"Fake"`` and
    ``1`` otherwise.

    Args:
        data_path: Root directory to search.
        transform: Optional callable applied to each PIL image in ``__getitem__``.
    """

    def __init__(self, data_path, transform=None):
        self.data_path = data_path
        self.transform = transform
        self.imgs = []  # list of (image_path, label) tuples
        self._load_data()

    def _load_data(self):
        """Populate ``self.imgs`` with ``(path, label)`` pairs in a deterministic order."""
        for root, dirs, _ in os.walk(self.data_path):
            # Sorting in place also fixes the order in which os.walk descends,
            # so the dataset order is reproducible across file systems.
            dirs.sort()
            for folder in dirs:
                folder_path = os.path.join(root, folder)
                # NOTE: the whole path is inspected, so a "Fake" component anywhere
                # in data_path itself would mark every image as fake.
                label = 0 if "Fake" in folder_path else 1

                for file in sorted(os.listdir(folder_path)):
                    # Case-insensitive match so ".JPG" / ".jpeg" files are not skipped.
                    if file.lower().endswith(("jpg", "jpeg")):
                        self.imgs.append((os.path.join(folder_path, file), label))

    def __getitem__(self, index):
        """Return ``(image, label)``; the image is RGB and optionally transformed."""
        img_path, label = self.imgs[index]
        # The context manager closes the file handle once the pixels are decoded;
        # convert() returns an independent copy that outlives the handle.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        return img, label

    def __len__(self):
        """Number of collected images."""
        return len(self.imgs)


In [None]:
class MyDatasetTest(Dataset):
    """Dataset whose image paths and labels are listed in a text file.

    Each non-empty line of ``txt_path`` holds an image path followed by an
    integer label, separated by whitespace (e.g. ``img/0001.jpg 1``). Blank
    lines are skipped; tokens after the label are ignored.

    Args:
        txt_path: Path of the list file.
        transform: Optional callable applied to each PIL image in ``__getitem__``.
    """

    def __init__(self, txt_path, transform=None):
        imgs = []
        # The context manager guarantees the list file is closed after parsing.
        with open(txt_path, 'r') as fh:
            for line in fh:
                words = line.split()
                if not words:
                    continue  # tolerate blank lines, e.g. a trailing newline
                imgs.append((words[0], int(words[1])))

        self.imgs = imgs  # list of (image_path, label) tuples
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(image, label)``; the image is RGB and optionally transformed."""
        fn, label = self.imgs[index]
        # Close the file handle once decoded; convert() returns an independent copy.
        with Image.open(fn) as raw:
            img = raw.convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        return img, label

    def __len__(self):
        """Number of listed images."""
        return len(self.imgs)

## ResNet50

In [None]:
# Model
class ResNet(nn.Module):
    """ImageNet-pretrained ResNet-50 with a freshly initialised classification head.

    The stem, ``layer1`` and ``layer2`` are frozen; ``layer3``, ``layer4`` and
    the new fully-connected head are trainable.

    Args:
        num_classes: Number of logits produced by the new head.
    """

    def __init__(self, num_classes):
        super().__init__()
        # NOTE: newer torchvision releases deprecate ``pretrained=True`` in favour of
        # ``weights=ResNet50_Weights.IMAGENET1K_V1``; kept for older-version compatibility.
        self.features = resnet50(pretrained=True)

        # Freeze the whole backbone first ...
        for param in self.features.parameters():
            param.requires_grad = False

        # ... then unfreeze the last two residual stages.
        self.features.layer3.requires_grad_(True)
        self.features.layer4.requires_grad_(True)
        # The replacement head is a new module whose parameters require grad by
        # default. Its input width is read from the existing head, not hard-coded.
        self.features.fc = nn.Linear(self.features.fc.in_features, num_classes)

    def forward(self, x):
        """Return raw (unnormalised) class logits for the image batch ``x``."""
        return self.features(x)

# Instantiate the ResNet-50 classifier defined above ...
model = ResNet(num_classes)
# ... and move its parameters and buffers to the device chosen in the constants cell.
model.to(device)

In [None]:
# Cross-entropy over the raw logits returned by the model.
criterion = nn.CrossEntropyLoss()
# Only the parameters left trainable by ResNet (requires_grad=True) are optimised.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_params, lr=0.001)

## Training Functions

In [None]:
def train_epoch(model, data_loader, criterion, optimizer, device):
    """Run one optimisation pass over ``data_loader``.

    Args:
        model: Network to train; it is switched to training mode.
        data_loader: Iterable of ``(images, labels)`` batches supporting ``len()``.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer that updates ``model``'s parameters.
        device: Device each batch is moved to.

    Returns:
        Mean of the per-batch loss values. An empty loader raises ZeroDivisionError.
    """
    model.train()
    running_total = 0.0

    for batch_images, batch_labels in data_loader:
        batch_images, batch_labels = batch_images.to(device), batch_labels.to(device)

        optimizer.zero_grad()
        batch_loss = criterion(model(batch_images), batch_labels)
        running_total += batch_loss.item()

        batch_loss.backward()
        optimizer.step()

    return running_total / len(data_loader)

In [None]:
def validate(model, data_loader, criterion, device):
    """Evaluate ``model`` on ``data_loader`` without updating its weights.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        data_loader: Iterable of ``(images, labels)`` batches supporting ``len()``.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        device: Device each batch is moved to.

    Returns:
        ``(mean_loss, accuracy)``: ``mean_loss`` averages the per-batch losses, and
        ``accuracy`` compares each arg-max prediction with its true label.
    """
    model.eval()
    loss_sum = 0.0
    true_labels, predicted_labels = [], []

    with torch.no_grad():
        for batch_images, batch_labels in data_loader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_images)
            loss_sum += criterion(logits, batch_labels).item()

            # Predicted class = index of the largest logit.
            true_labels.extend(batch_labels.cpu().numpy())
            predicted_labels.extend(logits.argmax(dim=1).cpu().numpy())

    mean_loss = loss_sum / len(data_loader)
    return mean_loss, accuracy_score(true_labels, predicted_labels)

In [None]:
def train_model(model, train_path, val_path, transform=None, target_transform=None, num_epochs=100, batch_size=32):
    """Train ``model`` and checkpoint the weights with the best validation accuracy.

    Relies on the module-level ``criterion``, ``optimizer`` and ``device``, so the
    global ``optimizer`` must have been built over ``model``'s parameters.

    In each epoch:

    * The state dict is written to ``best_model.pth`` whenever the validation
      accuracy strictly exceeds the best seen so far (starting from 0).
    * The train loss, val loss and val accuracy are appended to
      ``train_losses.txt``, ``val_losses.txt`` and ``val_accuracies.txt``.

    Args:
        model: Network to train.
        train_path: Root directory of the training images (parsed by MyDataset).
        val_path: Root directory of the validation images (parsed by MyDataset).
        transform: Transform for training images.
        target_transform: Transform for validation images (not a label transform).
        num_epochs: Number of epochs to run.
        batch_size: Mini-batch size for both loaders.

    Returns:
        ``(best_val_accuracy, model)``. ``model`` holds the weights from the last
        epoch, which are not necessarily the best checkpoint.
    """
    training_set = MyDataset(train_path, transform=transform)
    validation_set = MyDataset(val_path, transform=target_transform)

    train_loader = DataLoader(training_set, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(validation_set, batch_size=batch_size)

    best_val_accuracy = 0
    log_files = ('train_losses.txt', 'val_losses.txt', 'val_accuracies.txt')

    for epoch in range(num_epochs):
        train_loss = train_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_accuracy = validate(model, val_loader, criterion, device)

        # Checkpoint only on strict improvement.
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            torch.save(model.state_dict(), 'best_model.pth')

        print(f"Epoch [{epoch + 1}/{num_epochs}]\tTrain Loss: {train_loss:.4f}\tVal Loss: {val_loss:.4f}\tVal Accuracy: {val_accuracy:.4f}")

        # Append one value per epoch to each metric log (files accumulate across runs).
        for log_name, value in zip(log_files, (train_loss, val_loss, val_accuracy)):
            with open(log_name, 'a') as log_file:
                log_file.write(f"{value}\n")

    return best_val_accuracy, model

## Test Functions

In [None]:
def load_test_data(data_path, transform, isTxt = 0, batch_size=32):
    """Build a non-shuffling DataLoader over the test images.

    Args:
        data_path: Image directory (parsed with MyDataset) or, when ``isTxt`` is
            truthy, a text file of ``<path> <label>`` lines (parsed with MyDatasetTest).
        transform: Transform applied to every test image.
        isTxt: Selects the text-file format when truthy.
        batch_size: Images per batch (default 32).

    Returns:
        A DataLoader whose batches follow dataset order (``shuffle=False``).
    """
    dataset_cls = MyDatasetTest if isTxt else MyDataset
    test_dataset = dataset_cls(data_path, transform=transform)
    return DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
def test(model, data_path, transform, isTxt = 0):
    """Evaluate ``model`` on a test set and return its classification metrics.

    Label ``1`` is the positive class for recall, precision and AUC. With
    MyDataset, label 1 means the folder path does not contain "Fake".

    Args:
        model: Trained network; it is switched to eval mode.
        data_path: Test image directory, or a list file when ``isTxt`` is truthy.
        transform: Transform applied to every test image.
        isTxt: Selects the text-file input format when truthy.

    Returns:
        ``(accuracy, recall, precision, auc)``. Accuracy, recall and precision use
        the arg-max predictions; AUC uses the softmax probability of class 1.

    Raises:
        ValueError: from ``roc_auc_score`` when the test set contains only one class.
    """
    test_loader = load_test_data(data_path, transform, isTxt)

    model.eval()

    all_labels = []
    all_predictions = []
    all_scores = []  # P(class 1) per sample, needed for a meaningful ROC AUC

    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)

            # Forward pass
            outputs = model(images)

            # Hard decision: index of the largest logit (same rule as validate()).
            predictions = torch.argmax(outputs, dim=1)
            # Soft score for AUC: probability assigned to the positive class.
            scores = torch.softmax(outputs, dim=1)[:, 1]

            all_labels.extend(labels.cpu().numpy())
            all_predictions.extend(predictions.cpu().numpy())
            all_scores.extend(scores.cpu().numpy())

    # Calculate final metrics
    accuracy = accuracy_score(all_labels, all_predictions)
    recall = recall_score(all_labels, all_predictions)
    precision = precision_score(all_labels, all_predictions)
    # ROC AUC needs continuous scores; on hard 0/1 predictions it degenerates
    # to a single operating point.
    auc = roc_auc_score(all_labels, all_scores)

    return accuracy, recall, precision, auc

## Training & Testing 

### Transformations

**Augmentations:** all images are normalized with the ImageNet mean and standard deviation expected by the pretrained torchvision models. The training images are additionally augmented with random horizontal flips and random shear and scale. All images are resized to (317, 317). Exact details can be seen below:

In [None]:
# ImageNet channel statistics, matching the pretrained ResNet-50 backbone.
normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Training pipeline: resize to the network input size, then light geometric
# augmentation. No CenterCrop is needed: Resize already yields exactly input_size.
train_transform = Compose([
    Resize(input_size),
    RandomAffine(0, shear=10, scale=(0.8, 1.2)),  # no rotation; shear up to 10 deg; zoom 0.8-1.2x
    RandomHorizontalFlip(),
    ToTensor(),
    normalize,
])

# Evaluation pipelines (no augmentation) for validation and testing.
val_transform = Compose([
    Resize(input_size),
    ToTensor(),
    normalize,
])

test_transform = Compose([
    Resize(input_size),
    ToTensor(),
    normalize,
])

### Training model

In [None]:
# Train the current ``model``; the best checkpoint is written to best_model.pth.
best_accuracy, best_model = train_model(
    model,
    train_data_path,
    val_data_path,
    transform=train_transform,
    target_transform=val_transform,
    num_epochs=num_epochs,
    batch_size=batch_size,
)
# NOTE: ``best_model`` holds the final-epoch weights; the best ones are in best_model.pth.
# Format the float explicitly: concatenating it to a str raises TypeError.
print(f"Best validation accuracy: {best_accuracy:.4f}")

### Testing model

In [None]:
# Rebuild the architecture and load the checkpoint written by train_model().
# On Kaggle, point this at "/kaggle/input/best-model/best_model.pth" instead.
checkpoint_path = "best_model.pth"
best_model = ResNet(num_classes=num_classes)
best_model.to(device)
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
best_model.load_state_dict(torch.load(checkpoint_path, map_location=device))

### Method 1: Data in directory

In [None]:
# Method 1: evaluate on every image found under test_data_path.
accuracy, recall, precision, auc = test(best_model, test_data_path, test_transform)

report = (
    f"Test Accuracy: {accuracy:.4f}",
    f"Test Recall: {recall:.4f}",
    f"Test Precision: {precision:.4f}",
    f"Test AUC: {auc:.4f}",
)
for report_line in report:
    print(report_line)

### Method 2: Img paths in .txt file

In [None]:
# Method 2: evaluate on the images listed (with labels) in test_txt_path.
accuracy, recall, precision, auc = test(best_model, test_txt_path, test_transform, isTxt=1)

report = (
    f"Test Accuracy: {accuracy:.4f}",
    f"Test Recall: {recall:.4f}",
    f"Test Precision: {precision:.4f}",
    f"Test AUC: {auc:.4f}",
)
for report_line in report:
    print(report_line)