In [2]:
pip install timm

Note: you may need to restart the kernel to use updated packages.


In [3]:
pip install natsort

Note: you may need to restart the kernel to use updated packages.


In [13]:
import os
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
import timm
from PIL import Image

from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

In [5]:
# Directories for real and manipulated videos
real_videos_dir = "/home/ws5/Music/DeepFake/DeepFakeDataset/DFD_original sequences"
manipulated_videos_dir = "/home/ws5/Music/DeepFake/DeepFakeDataset/DFD_manipulated_sequences/DFD_manipulated_sequences"

# Output directories for extracted frames
output_real_dir = "/home/ws5/Music/DeepFake/Output/real"
output_manipulated_dir = "/home/ws5/Music/DeepFake/Output/manipulated"

In [6]:
# Ensure output directories exist
os.makedirs(output_real_dir, exist_ok=True)
os.makedirs(output_manipulated_dir, exist_ok=True)

In [90]:
def extract_frames_from_videos(videos_dir, output_dir, label, max_videos=50):
    video_files = [f for f in os.listdir(videos_dir) if f.endswith(('.mp4', '.avi', '.mov', '.mkv'))]
    video_files = video_files[:max_videos]  # Limit to max_videos

    for video_file in video_files:
        video_path = os.path.join(videos_dir, video_file)
        cap = cv2.VideoCapture(video_path)
        frame_count = 0
        success, image = cap.read()

        while success:
            if frame_count % int(cap.get(cv2.CAP_PROP_FPS)) == 0:
                frame_filename = f"{label}_{video_file}_frame{frame_count // int(cap.get(cv2.CAP_PROP_FPS))}.jpg"
                frame_path = os.path.join(output_dir, frame_filename)
                cv2.imwrite(frame_path, image)
            success, image = cap.read()
            frame_count += 1

        cap.release()

In [92]:
# Extract frames from 100 real and 100 manipulated videos
extract_frames_from_videos(real_videos_dir, output_real_dir, "real", max_videos=100)
extract_frames_from_videos(manipulated_videos_dir, output_manipulated_dir, "manipulated", max_videos=100)
print("Frame extraction completed.")

Frame extraction completed.


In [7]:
# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define image transformations with advanced augmentations
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(20),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
    transforms.RandomAffine(degrees=15, translate=(0.1, 0.1)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [8]:
# Load the dataset
dataset_dir = "/home/ws5/Music/DeepFake/Output"  # Directory where frames are saved
dataset = datasets.ImageFolder(root=dataset_dir, transform=transform)

# Split dataset into training and validation sets
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)


In [11]:
# Dataset class with padding for frames
class DeepFakeDataset(Dataset):
    def __init__(self, root_dir, transform=None, num_frames=20):
        self.root_dir = root_dir
        self.transform = transform
        self.num_frames = num_frames
        self.video_folders = self._get_valid_video_folders()

    def _get_valid_video_folders(self):
        valid_folders = []
        for category in ["real", "manipulated"]:
            category_path = os.path.join(self.root_dir, category)
            if not os.path.exists(category_path):
                print(f"WARNING: {category_path} does not exist.")
                continue
            for video_folder in os.listdir(category_path):
                video_path = os.path.join(category_path, video_folder)
                if not os.path.isdir(video_path):
                    continue
                frames = natsort.natsorted(glob.glob(os.path.join(video_path, "*.jpg")))
                if len(frames) >= self.num_frames:
                    valid_folders.append(video_path)
        print(f"Found {len(valid_folders)} valid video folders.")
        return valid_folders

    def __len__(self):
        return len(self.video_folders)

    def __getitem__(self, idx):
        video_folder = self.video_folders[idx]
        label = 1 if "manipulated" in video_folder else 0  # 1 for deepfake, 0 for real
        frames = natsort.natsorted(glob.glob(os.path.join(video_folder, "*.jpg")))
        frames = frames[:self.num_frames]  # Take only required frames

        # Pad if fewer frames are available
        if len(frames) < self.num_frames:
            last_frame = frames[-1]
            frames.extend([last_frame] * (self.num_frames - len(frames)))

        images = [Image.open(frame).convert("RGB") for frame in frames]
        if self.transform:
            images = [self.transform(img) for img in images]
        images_tensor = torch.stack(images)
        return images_tensor, torch.tensor(label, dtype=torch.long)

# CNN Feature Extractor
class CNNFeatureExtractor(nn.Module):
    def __init__(self, model_name="efficientnet_b0"):
        super(CNNFeatureExtractor, self).__init__()
        self.cnn = models.efficientnet_b0(pretrained=True)
        self.cnn = nn.Sequential(*list(self.cnn.children())[:-2])
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
    
    def forward(self, x):
        x = self.cnn(x)  # Shape: (batch_size * seq_len, 1280, H, W)
        x = self.global_pool(x)  # Shape: (batch_size * seq_len, 1280, 1, 1)
        x = x.flatten(start_dim=1)  # Shape: (batch_size * seq_len, 1280)
        return x

# ViT for Temporal Modeling
class ViTForTemporalModeling(nn.Module):
    def __init__(self, vit_model_name="google/vit-base-patch16-224"):
        super(ViTForTemporalModeling, self).__init__()
        self.vit = ViTModel.from_pretrained(vit_model_name)
        self.fc = nn.Linear(1280, 768)  # EfficientNet output -> ViT input
    
    def forward(self, x):
        batch_size, seq_len, _ = x.shape
        x = self.fc(x)  # Shape: (batch_size, seq_len, 768)
        x = x.view(batch_size * seq_len, -1)  # Shape: (batch_size * seq_len, 768)
        x = self.vit(x).last_hidden_state  # Shape: (batch_size * seq_len, 768)
        x = x.view(batch_size, seq_len, -1)  # Shape: (batch_size, seq_len, 768)
        return x

# CNN + ViT Hybrid Model
class CNNViTModel(nn.Module):
    def __init__(self, cnn_model_name="efficientnet_b0", vit_model_name="google/vit-base-patch16-224"):
        super(CNNViTModel, self).__init__()
        self.cnn_extractor = CNNFeatureExtractor(cnn_model_name)
        self.vit_model = ViTForTemporalModeling(vit_model_name)
        self.classifier = nn.Linear(768, 1)  # ViT output -> binary classification
    
    def forward(self, x):
        batch_size, seq_len, c, h, w = x.shape
        x = x.view(batch_size * seq_len, c, h, w)  # Shape: (batch_size * seq_len, 3, 224, 224)
        x = self.cnn_extractor(x)  # Shape: (batch_size * seq_len, 1280)
        x = x.view(batch_size, seq_len, -1)  # Shape: (batch_size, seq_len, 1280)
        x = self.vit_model(x)  # Shape: (batch_size, seq_len, 768)
        x = x.mean(dim=1)  # Shape: (batch_size, 768)
        x = self.classifier(x)  # Shape: (batch_size, 1)
        return torch.sigmoid(x)

# Training loop
def train_model(model, train_loader, val_loader, epochs=10):
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device).float()
            optimizer.zero_grad()
            outputs = model(inputs).squeeze()
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        
        # Validation
        model.eval()
        all_labels = []
        all_predictions = []
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device).float()
                outputs = model(inputs).squeeze()
                predicted = (outputs > 0.5).float()
                all_labels.extend(labels.cpu().numpy())
                all_predictions.extend(predicted.cpu().numpy())
                
        print(f"Epoch {epoch+1}/{epochs}")

# Main execution
if __name__ == "__main__":
    # Device configuration
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Image transformations
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Initialize model
    model = CNNViTModel().to(device)

    # Start training
    train_model(model, train_loader, val_loader, epochs=50)

Epoch [1/50]
Epoch [2/50]
Epoch [3/50]
Epoch [4/50]
Epoch [5/50]
Epoch [6/50]
Epoch [7/50]
Epoch [8/50]
Epoch [9/50]
Epoch [10/50]
Epoch [11/50]
Epoch [12/50]
Epoch [13/50]
Epoch [14/50]
Epoch [15/50]
Epoch [16/50]
Epoch [17/50]
Epoch [18/50]
Epoch [19/50]
Epoch [20/50]
Epoch [21/50]
Epoch [22/50]
Epoch [23/50]
Epoch [24/50]
Epoch [25/50]
Epoch [26/50]
Epoch [27/50]
Epoch [28/50]
Epoch [29/50]
Epoch [30/50]
Epoch [31/50]
Epoch [32/50]
Epoch [33/50]
Epoch [34/50]
Epoch [35/50]
Epoch [36/50]
Epoch [37/50]
Epoch [38/50]
Epoch [39/50]
Epoch [40/50]
Epoch [41/50]
Epoch [42/50]
Epoch [43/50]
Epoch [44/50]
Epoch [45/50]
Epoch [46/50]
Epoch [47/50]
Epoch [48/50]
Epoch [49/50]
Epoch [50/50]


In [15]:
acc = accuracy_score(all_labels, all_predictions)
precision = precision_score(all_labels, all_predictions, zero_division=1)
f1 = f1_score(all_labels, all_predictions, zero_division=1)

print(f"Accuracy: {acc:.2f}, Precision: {precision:.2f}, F1-score: {f1:.2f}")

Accuracy: 98.33, Precision: 98.97, F1-score: 99.70
