In [1]:
%pip install timm tqdm scikit-learn matplotlib

Note: you may need to restart the kernel to use updated packages.


In [2]:
import os
from PIL import Image
import torch
from torch.utils.data import Dataset

class DeepFakeDataset(Dataset):
    def __init__(self, dataset_dir, transform=None):
        self.dataset_dir = dataset_dir
        self.transform = transform
        self.images = []
        self.labels = []

        for label_dir in os.listdir(dataset_dir):
            label_path = os.path.join(dataset_dir, label_dir)
            if os.path.isdir(label_path):
                label = 0 if label_dir == "real" else 1
                for image_name in os.listdir(label_path):
                    image_path = os.path.join(label_path, image_name)
                    if image_path.lower().endswith(('.jpg', '.jpeg', '.png')):
                        self.images.append(image_path)
                        self.labels.append(label)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        try:
            image = Image.open(self.images[idx])
            if image.mode != "RGB":
                image = image.convert("RGB")

            label = self.labels[idx]
            if self.transform:
                image = self.transform(image)
            return image, label
        except Exception as e:
            print(f"Error loading image {self.images[idx]}: {str(e)}")
            placeholder_image = torch.zeros((3, 224, 224)) if self.transform else Image.new('RGB', (224, 224), color='black')
            return placeholder_image, self.labels[idx]


In [3]:
from torchvision import transforms

def get_transforms():
    train_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(20),
        transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    valid_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    return train_transform, valid_transform


In [4]:
import shutil
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

def prepare_data(dataset_dir='data', batch_size=32):
    split_data_dir = os.path.join(dataset_dir, "split_data")

    if not os.path.exists(split_data_dir):
        print("Creating split data directories...")
        os.makedirs(split_data_dir, exist_ok=True)
        for split in ["train", "val", "test"]:
            for label in ["real", "fake"]:
                os.makedirs(os.path.join(split_data_dir, split, label), exist_ok=True)

        for label in ["real", "fake"]:
            source_dir = os.path.join(dataset_dir, label)
            images = [f for f in os.listdir(source_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
            train, temp = train_test_split(images, test_size=0.2, random_state=42)
            val, test = train_test_split(temp, test_size=0.5, random_state=42)

            for split, split_imgs in zip(["train", "val", "test"], [train, val, test]):
                for img in tqdm(split_imgs, desc=f"Copying {label} {split} images"):
                    shutil.copy2(os.path.join(source_dir, img), os.path.join(split_data_dir, split, label, img))

    train_tf, valid_tf = get_transforms()
    train_ds = DeepFakeDataset(os.path.join(split_data_dir, "train"), transform=train_tf)
    val_ds = DeepFakeDataset(os.path.join(split_data_dir, "val"), transform=valid_tf)
    test_ds = DeepFakeDataset(os.path.join(split_data_dir, "test"), transform=valid_tf)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)
    test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

    return train_loader, val_loader, test_loader


In [5]:
import torch
import torch.nn as nn
from timm import create_model

class EfficientNetV2(nn.Module):
    def __init__(self, num_classes=1, dropout_rate=0.3, pretrained=True):
        super().__init__()
        self.base_model = create_model('tf_efficientnetv2_l', pretrained=pretrained, num_classes=0)
        num_features = self.base_model.num_features
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(num_features, 512),
            nn.SiLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(dropout_rate),
            nn.Linear(512, num_classes)
        )
        self.freeze_layers()

    def freeze_layers(self):
        for param in self.base_model.parameters():
            param.requires_grad = False
        for param in list(self.base_model.parameters())[-30:]:
            param.requires_grad = True

    def forward(self, x):
        features = self.base_model.forward_features(x)
        out = self.classifier(features)
        # Squeeze if single output node (binary classification), and shape is [batch_size, 1]
        if out.dim() == 2 and out.size(1) == 1:
            out = out.squeeze(1)
        return out


In [6]:
import torch.optim as optim

def train_model(model, train_loader, valid_loader, num_epochs=10, device='cuda'):
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
    os.makedirs("saved_models", exist_ok=True)

    best_val_loss = float('inf')
    patience, epochs_without_improve = 5, 0
    train_losses, valid_losses, train_accuracies, valid_accuracies = [], [], [], []

    for epoch in range(num_epochs):
        model.train()
        running_loss, correct, total = 0.0, 0, 0
        for images, labels in tqdm(train_loader, desc=f"Training Epoch {epoch+1}"):
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels.float())
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * images.size(0)
            preds = (torch.sigmoid(outputs) > 0.5).long()
            correct += (preds == labels.view(-1)).sum().item()
            total += labels.size(0)

        train_loss = running_loss / total
        train_acc = correct / total

        model.eval()
        val_loss, val_correct, val_total = 0.0, 0, 0
        with torch.no_grad():
            for images, labels in valid_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels.float())
                val_loss += loss.item() * images.size(0)
                preds = (torch.sigmoid(outputs) > 0.5).long()
                val_correct += (preds == labels.view(-1)).sum().item()
                val_total += labels.size(0)

        val_loss /= val_total
        val_acc = val_correct / val_total

        train_losses.append(train_loss)
        valid_losses.append(val_loss)
        train_accuracies.append(train_acc)
        valid_accuracies.append(val_acc)

        print(f"Epoch {epoch+1}: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}, Train Acc={train_acc:.4f}, Val Acc={val_acc:.4f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_without_improve = 0
            save_model(epoch, model, optimizer, train_loss, val_loss, device)
        else:
            epochs_without_improve += 1

        if epochs_without_improve >= patience:
            print("Early stopping")
            break

        scheduler.step()

    return train_losses, valid_losses, train_accuracies, valid_accuracies


In [7]:
def save_model(epoch, model, optimizer, train_loss, val_loss, device):
    torch.save(model.state_dict(), f'saved_models/model_epoch_{epoch+1}_weights.pth')
    torch.save(model, f'saved_models/model_epoch_{epoch+1}_full.pth')
    torch.save({
        'epoch': epoch+1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_loss,
        'val_loss': val_loss
    }, f'saved_models/checkpoint_epoch_{epoch+1}.pth')

    model.eval()
    example_input = torch.randn(1, 3, 224, 224).to(device)
    traced = torch.jit.trace(model, example_input)
    traced.save(f'saved_models/model_epoch_{epoch+1}_traced.pt')


In [8]:
from sklearn.metrics import precision_score, recall_score, f1_score

def evaluate_model(model, test_loader, device='cuda'):
    model.eval()
    criterion = nn.BCEWithLogitsLoss()
    test_loss, correct, total = 0.0, 0, 0
    all_preds, all_labels = [], []

    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Evaluating"):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels.float())
            test_loss += loss.item() * images.size(0)
            preds = (torch.sigmoid(outputs) > 0.5).long()
            correct += (preds == labels.view(-1)).sum().item()
            total += labels.size(0)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    acc = correct / total
    test_loss /= total
    precision = precision_score(all_labels, all_preds)
    recall = recall_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds)

    print(f"Loss: {test_loss:.4f}, Acc: {acc:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")


In [9]:
import matplotlib.pyplot as plt

def plot_metrics(train_losses, valid_losses, train_accuracies, valid_accuracies):
    epochs = range(1, len(train_losses)+1)
    plt.figure(figsize=(12,5))
    plt.subplot(1, 2, 1)
    plt.plot(epochs, train_losses, 'b', label='Train Loss')
    plt.plot(epochs, valid_losses, 'r', label='Val Loss')
    plt.title('Loss Over Epochs')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(epochs, train_accuracies, 'b', label='Train Acc')
    plt.plot(epochs, valid_accuracies, 'r', label='Val Acc')
    plt.title('Accuracy Over Epochs')
    plt.legend()

    plt.tight_layout()
    plt.savefig("training_metrics.png")
    plt.show()


In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
train_loader, val_loader, test_loader = prepare_data(dataset_dir='data', batch_size=32)

model = EfficientNetV2().to(device)
train_losses, val_losses, train_accuracies, val_accuracies = train_model(
    model, train_loader, val_loader, num_epochs=10, device=device
)

plot_metrics(train_losses, val_losses, train_accuracies, val_accuracies)
evaluate_model(model, test_loader, device=device)


Training Epoch 1: 100%|██████████| 250/250 [02:48<00:00,  1.48it/s]


Epoch 1: Train Loss=0.0808, Val Loss=0.0124, Train Acc=0.9684, Val Acc=0.9990


  if out.dim() == 2 and out.size(1) == 1:
Training Epoch 2: 100%|██████████| 250/250 [02:10<00:00,  1.92it/s]


Epoch 2: Train Loss=0.0283, Val Loss=0.0094, Train Acc=0.9908, Val Acc=0.9960


Training Epoch 3: 100%|██████████| 250/250 [02:10<00:00,  1.92it/s]


Epoch 3: Train Loss=0.0170, Val Loss=0.0060, Train Acc=0.9948, Val Acc=0.9990


Training Epoch 4:  39%|███▉      | 97/250 [00:51<01:19,  1.93it/s]