In [9]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms, models
from sklearn.metrics import f1_score
from PIL import Image
import pandas as pd

from tqdm import tqdm

In [4]:
class DeepfakeDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.images = []
        self.labels = []

        # Load real images
        real_dir = os.path.join(root_dir, 'train_images/train_images/real_train')
        for img_name in os.listdir(real_dir):
            self.images.append(os.path.join(real_dir, img_name))
            self.labels.append(0)  # 0 for real

        # Load fake images
        fake_dir = os.path.join(root_dir, 'train_images/train_images/fake_train')
        for img_name in os.listdir(fake_dir):
            self.images.append(os.path.join(fake_dir, img_name))
            self.labels.append(1)  # 1 for fake

            
    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        image = Image.open(img_path).convert('RGB')
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label

In [5]:
# Define data augmentation and normalization
data_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load datasets
full_dataset = DeepfakeDataset(root_dir='/kaggle/input/wec-intelligence-sig-2024-recruitment-task-cv', transform=data_transforms)

# Split the dataset
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)

In [7]:
# Load pre-trained VGG model
model = models.vgg16(pretrained=True)
num_features = model.classifier[6].in_features
model.classifier[6] = nn.Linear(num_features, 2)  # Binary classification

# Move model to GPU if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

# Training loop
num_epochs = 50
best_val_f1 = 0.0

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [00:02<00:00, 227MB/s]  


In [10]:
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for inputs, labels in tqdm(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * inputs.size(0)

    # Validation
    model.eval()
    val_preds = []
    val_labels = []
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            val_preds.extend(preds.cpu().numpy())
            val_labels.extend(labels.cpu().numpy())

    val_f1 = f1_score(val_labels, val_preds, average='weighted')
    
    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        torch.save(model.state_dict(), 'best_deepfake_detection_model.pth')

    epoch_loss = running_loss / len(train_dataset)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Val F1: {val_f1:.4f}")


100%|██████████| 40/40 [00:10<00:00,  3.68it/s]


Epoch 1/50, Loss: 0.1006, Val F1: 0.9875


100%|██████████| 40/40 [00:10<00:00,  3.67it/s]


Epoch 2/50, Loss: 0.0585, Val F1: 0.9937


100%|██████████| 40/40 [00:10<00:00,  3.68it/s]


Epoch 3/50, Loss: 0.0171, Val F1: 0.9906


100%|██████████| 40/40 [00:10<00:00,  3.68it/s]


Epoch 4/50, Loss: 0.0488, Val F1: 0.9812


100%|██████████| 40/40 [00:10<00:00,  3.67it/s]


Epoch 5/50, Loss: 0.0261, Val F1: 0.9969


100%|██████████| 40/40 [00:10<00:00,  3.69it/s]


Epoch 6/50, Loss: 0.0032, Val F1: 0.9969


100%|██████████| 40/40 [00:10<00:00,  3.67it/s]


Epoch 7/50, Loss: 0.0520, Val F1: 0.9844


100%|██████████| 40/40 [00:10<00:00,  3.68it/s]


Epoch 8/50, Loss: 0.0259, Val F1: 0.9875


100%|██████████| 40/40 [00:10<00:00,  3.69it/s]


Epoch 9/50, Loss: 0.0070, Val F1: 0.9937


100%|██████████| 40/40 [00:10<00:00,  3.69it/s]


Epoch 10/50, Loss: 0.0009, Val F1: 1.0000


100%|██████████| 40/40 [00:10<00:00,  3.65it/s]


Epoch 11/50, Loss: 0.1767, Val F1: 0.9656


100%|██████████| 40/40 [00:10<00:00,  3.69it/s]


Epoch 12/50, Loss: 0.1715, Val F1: 0.9844


100%|██████████| 40/40 [00:10<00:00,  3.69it/s]


Epoch 13/50, Loss: 0.0458, Val F1: 0.9875


100%|██████████| 40/40 [00:10<00:00,  3.69it/s]


Epoch 14/50, Loss: 0.0169, Val F1: 0.9906


100%|██████████| 40/40 [00:10<00:00,  3.68it/s]


Epoch 15/50, Loss: 0.0205, Val F1: 0.9937


100%|██████████| 40/40 [00:10<00:00,  3.69it/s]


Epoch 16/50, Loss: 0.0176, Val F1: 0.9938


100%|██████████| 40/40 [00:10<00:00,  3.68it/s]


Epoch 17/50, Loss: 0.0011, Val F1: 0.9938


100%|██████████| 40/40 [00:10<00:00,  3.67it/s]


Epoch 18/50, Loss: 0.0019, Val F1: 0.9875


100%|██████████| 40/40 [00:10<00:00,  3.68it/s]


Epoch 19/50, Loss: 0.0560, Val F1: 0.9750


100%|██████████| 40/40 [00:11<00:00,  3.63it/s]


Epoch 20/50, Loss: 0.0076, Val F1: 0.9969


100%|██████████| 40/40 [00:10<00:00,  3.67it/s]


Epoch 21/50, Loss: 0.0673, Val F1: 0.9781


100%|██████████| 40/40 [00:10<00:00,  3.68it/s]


Epoch 22/50, Loss: 0.0286, Val F1: 0.9875


100%|██████████| 40/40 [00:10<00:00,  3.67it/s]


Epoch 23/50, Loss: 0.0094, Val F1: 0.9937


100%|██████████| 40/40 [00:10<00:00,  3.67it/s]


Epoch 24/50, Loss: 0.0012, Val F1: 1.0000


100%|██████████| 40/40 [00:10<00:00,  3.66it/s]


Epoch 25/50, Loss: 0.0003, Val F1: 0.9969


100%|██████████| 40/40 [00:10<00:00,  3.66it/s]


Epoch 26/50, Loss: 0.0002, Val F1: 0.9938


100%|██████████| 40/40 [00:10<00:00,  3.67it/s]


Epoch 27/50, Loss: 0.0001, Val F1: 0.9938


100%|██████████| 40/40 [00:10<00:00,  3.67it/s]


Epoch 28/50, Loss: 0.0000, Val F1: 1.0000


100%|██████████| 40/40 [00:10<00:00,  3.67it/s]


Epoch 29/50, Loss: 0.0000, Val F1: 0.9969


100%|██████████| 40/40 [00:10<00:00,  3.67it/s]


Epoch 30/50, Loss: 0.0000, Val F1: 0.9969


100%|██████████| 40/40 [00:10<00:00,  3.66it/s]


Epoch 31/50, Loss: 0.0000, Val F1: 0.9969


100%|██████████| 40/40 [00:10<00:00,  3.67it/s]


Epoch 32/50, Loss: 0.0000, Val F1: 0.9969


100%|██████████| 40/40 [00:10<00:00,  3.66it/s]


Epoch 33/50, Loss: 0.0002, Val F1: 0.9906


100%|██████████| 40/40 [00:10<00:00,  3.67it/s]


Epoch 34/50, Loss: 0.0016, Val F1: 1.0000


100%|██████████| 40/40 [00:10<00:00,  3.67it/s]


Epoch 35/50, Loss: 0.0003, Val F1: 1.0000


100%|██████████| 40/40 [00:11<00:00,  3.63it/s]


Epoch 36/50, Loss: 0.4594, Val F1: 0.8263


100%|██████████| 40/40 [00:10<00:00,  3.67it/s]


Epoch 37/50, Loss: 0.2459, Val F1: 0.8956


100%|██████████| 40/40 [00:10<00:00,  3.67it/s]


Epoch 38/50, Loss: 0.1366, Val F1: 0.8563


100%|██████████| 40/40 [00:10<00:00,  3.67it/s]


Epoch 39/50, Loss: 0.0541, Val F1: 0.9750


100%|██████████| 40/40 [00:10<00:00,  3.69it/s]


Epoch 40/50, Loss: 0.0140, Val F1: 0.9906


100%|██████████| 40/40 [00:10<00:00,  3.68it/s]


Epoch 41/50, Loss: 0.0243, Val F1: 0.9906


100%|██████████| 40/40 [00:10<00:00,  3.67it/s]


Epoch 42/50, Loss: 0.0161, Val F1: 0.9938


100%|██████████| 40/40 [00:10<00:00,  3.67it/s]


Epoch 43/50, Loss: 0.0123, Val F1: 0.9875


100%|██████████| 40/40 [00:10<00:00,  3.66it/s]


Epoch 44/50, Loss: 0.0229, Val F1: 0.9938


100%|██████████| 40/40 [00:10<00:00,  3.66it/s]


Epoch 45/50, Loss: 0.0073, Val F1: 0.9937


100%|██████████| 40/40 [00:10<00:00,  3.66it/s]


Epoch 46/50, Loss: 0.0083, Val F1: 0.9969


100%|██████████| 40/40 [00:10<00:00,  3.68it/s]


Epoch 47/50, Loss: 0.0025, Val F1: 0.9906


100%|██████████| 40/40 [00:10<00:00,  3.67it/s]


Epoch 48/50, Loss: 0.0361, Val F1: 0.9969


100%|██████████| 40/40 [00:11<00:00,  3.63it/s]


Epoch 49/50, Loss: 0.1109, Val F1: 0.9844


100%|██████████| 40/40 [00:10<00:00,  3.67it/s]


Epoch 50/50, Loss: 0.0211, Val F1: 0.9906


## Generating outputs on test set

In [30]:
class DeepfakeTest(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.images = []

        # Load real images
        self.images_dir = os.path.join(root_dir, 'test_images/test_images/')
        for img_name in os.listdir(self.images_dir):
            self.images.append(img_name)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image_name = self.images[idx]
        img_path = os.path.join(self.images_dir, image_name)
        image_id = image_name.split('.')[0].split('_')[2]
        image = Image.open(img_path).convert('RGB') 

        if self.transform:
            image = self.transform(image)

        return image, image_id


test_transforms = data_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [36]:
import torch
from torch.utils.data import DataLoader
import pandas as pd
from tqdm import tqdm

# Assuming the DeepfakeTest dataset and model are defined earlier in the script

test_dataset = DeepfakeTest("/kaggle/input/wec-intelligence-sig-2024-recruitment-task-cv/", test_transforms)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=4)

# Load best model for testing
model.load_state_dict(torch.load('best_deepfake_detection_model.pth'))
model.eval()

results = {
    'ID': [],
    'TARGET': [],
}

with torch.no_grad():
    for inputs, image_ids in tqdm(test_loader):
        inputs = inputs.to(device)
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        results['TARGET'].extend(preds.cpu().numpy())
        results['ID'].extend(image_ids)


results_df = pd.DataFrame(results)
results_df['ID'] = results_df['ID'].astype(int)
results_df = results_df.sort_values('ID')

results_df.to_csv('test_predictions.csv', index=False)

  model.load_state_dict(torch.load('best_deepfake_detection_model.pth'))
100%|██████████| 400/400 [00:02<00:00, 151.51it/s]


In [39]:
output_df = pd.read_csv('test_predictions.csv')
output_df.head()

Unnamed: 0,ID,TARGET
0,1,0
1,2,0
2,3,1
3,4,0
4,5,1
