In [1]:
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import cv2
import numpy as np
from sklearn.metrics import accuracy_score, roc_auc_score
from transformers import ViTModel, ViTConfig

# Configurations
class Config:
    img_size = 224
    num_frames = 8
    batch_size = 16
    epochs = 10
    lr = 3e-5
    num_classes = 1
    dataset_path = "./"
    real_folder = "real"
    fake_folder = "fake"

# Simple face crop (Celeb-DF faces are centered)
def center_crop(image):
    h, w = image.shape[:2]
    size = min(h, w)
    y, x = (h - size) // 2, (w - size) // 2
    return image[y:y+size, x:x+size]

# Dataset
class CelebDFDataset(Dataset):
    def __init__(self, config, mode='train'):
        self.config = config
        self.mode = mode
        self.samples = []
        
        # Load real and fake video paths
        real_videos = [os.path.join(config.real_folder, f) 
                      for f in os.listdir(config.real_folder) if f.endswith('.mp4')]
        fake_videos = [os.path.join(config.fake_folder, f) 
                      for f in os.listdir(config.fake_folder) if f.endswith('.mp4')]
        
        # Split (80/20)
        np.random.seed(42)
        np.random.shuffle(real_videos)
        split = int(0.8 * len(real_videos))
        
        real_videos = real_videos[:split] if mode == 'train' else real_videos[split:]
        
        # Create samples (0=real, 1=fake)
        for vid in real_videos:
            self.samples.append((vid, 0))
        for vid in fake_videos:
            self.samples.append((vid, 1))
            
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        vid_path, label = self.samples[idx]
        frames = self._sample_frames(vid_path)
        frames = torch.stack([transforms.ToTensor()(frame) for frame in frames])
        return frames, torch.tensor(label, dtype=torch.float32)
    
    def _sample_frames(self, vid_path):
        cap = cv2.VideoCapture(vid_path)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        indices = np.linspace(0, total_frames-1, self.config.num_frames, dtype=int)
        
        frames = []
        for idx in indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
            ret, frame = cap.read()
            if ret:
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frame = center_crop(frame)
                frame = Image.fromarray(frame).resize((self.config.img_size, self.config.img_size))
                frames.append(frame)
        
        cap.release()
        # Ensure we have exactly num_frames frames
        while len(frames) < self.config.num_frames:
            if len(frames) > 0:
                frames.append(frames[-1])  # Duplicate the last frame
            else:
                # If no frames were loaded, pad with black frames
                black_frame = Image.new("RGB", (self.config.img_size, self.config.img_size))
                frames.append(black_frame)
        return frames[:self.config.num_frames]  # Trim if we got extra frames

# Model (using Vision Transformer)
class DeepfakeDetector(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        
        # Use pretrained ViT
        vit_config = ViTConfig(
            image_size=config.img_size,
            num_hidden_layers=4,
            num_attention_heads=4,
            hidden_size=192
        )
        self.vit = ViTModel(vit_config)
        
        # Temporal attention
        self.temporal_attn = nn.MultiheadAttention(192, 4)
        self.classifier = nn.Linear(192, config.num_classes)
        
    def forward(self, x):
        # x: (batch, frames, C, H, W)
        batch, frames = x.shape[0], x.shape[1]
        
        # Process each frame with ViT
        frame_features = []
        for t in range(frames):
            out = self.vit(x[:, t]).last_hidden_state[:, 0]  # CLS token
            frame_features.append(out)
        
        # Temporal attention
        features = torch.stack(frame_features, dim=1)  # (batch, frames, dim)
        features = features.transpose(0, 1)  # (frames, batch, dim)
        attn_out, _ = self.temporal_attn(features, features, features)
        features = attn_out.mean(dim=0)  # (batch, dim)
        
        return torch.sigmoid(self.classifier(features)).squeeze()

# Training
def train(model, loader, criterion, optimizer, device):
    model.train()
    preds, truths = [], []
    
    for frames, labels in loader:
        frames, labels = frames.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(frames)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        preds.extend(outputs.detach().cpu().numpy())
        truths.extend(labels.cpu().numpy())
    
    acc = accuracy_score(truths, np.round(preds))
    auc = roc_auc_score(truths, preds)
    return acc, auc

def evaluate(model, loader, criterion, device):
    model.eval()
    preds, truths = [], []
    
    with torch.no_grad():
        for frames, labels in loader:
            frames, labels = frames.to(device), labels.to(device)
            outputs = model(frames)
            
            preds.extend(outputs.cpu().numpy())
            truths.extend(labels.cpu().numpy())
    
    acc = accuracy_score(truths, np.round(preds))
    auc = roc_auc_score(truths, preds)
    return acc, auc

def main():
    config = Config()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Data
    train_set = CelebDFDataset(config, 'train')
    val_set = CelebDFDataset(config, 'val')
    
    train_loader = DataLoader(train_set, batch_size=config.batch_size, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=config.batch_size)
    
    # Model
    model = DeepfakeDetector(config).to(device)
    criterion = nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
    
    # Training loop
    best_auc = 0
    for epoch in range(config.epochs):
        train_acc, train_auc = train(model, train_loader, criterion, optimizer, device)
        val_acc, val_auc = evaluate(model, val_loader, criterion, device)
        
        print(f"Epoch {epoch+1}/{config.epochs}")
        print(f"Train Acc: {train_acc:.4f} | AUC: {train_auc:.4f}")
        print(f"Val Acc: {val_acc:.4f} | AUC: {val_auc:.4f}")
        
        if val_auc > best_auc:
            best_auc = val_auc
            torch.save(model.state_dict(), "best_model.pth")
    
    print(f"Best Val AUC: {best_auc:.4f}")

if __name__ == "__main__":
    main()

  from .autonotebook import tqdm as notebook_tqdm


Epoch 1/10
Train Acc: 0.4667 | AUC: 0.4107
Val Acc: 0.8000 | AUC: 0.7500
Epoch 2/10
Train Acc: 0.5333 | AUC: 0.4821
Val Acc: 0.8000 | AUC: 0.6875
Epoch 2/10
Train Acc: 0.5333 | AUC: 0.4821
Val Acc: 0.8000 | AUC: 0.6875
Epoch 3/10
Train Acc: 0.5333 | AUC: 0.5714
Val Acc: 0.8000 | AUC: 0.5625
Epoch 3/10
Train Acc: 0.5333 | AUC: 0.5714
Val Acc: 0.8000 | AUC: 0.5625
Epoch 4/10
Train Acc: 0.5333 | AUC: 0.6250
Val Acc: 0.2000 | AUC: 0.4375
Epoch 4/10
Train Acc: 0.5333 | AUC: 0.6250
Val Acc: 0.2000 | AUC: 0.4375
Epoch 5/10
Train Acc: 0.4667 | AUC: 0.6607
Val Acc: 0.2000 | AUC: 0.3125
Epoch 5/10
Train Acc: 0.4667 | AUC: 0.6607
Val Acc: 0.2000 | AUC: 0.3125
Epoch 6/10
Train Acc: 0.4667 | AUC: 0.6964
Val Acc: 0.2000 | AUC: 0.3750
Epoch 6/10
Train Acc: 0.4667 | AUC: 0.6964
Val Acc: 0.2000 | AUC: 0.3750
Epoch 7/10
Train Acc: 0.4667 | AUC: 0.7321
Val Acc: 0.2000 | AUC: 0.3750
Epoch 7/10
Train Acc: 0.4667 | AUC: 0.7321
Val Acc: 0.2000 | AUC: 0.3750
Epoch 8/10
Train Acc: 0.4667 | AUC: 0.7143
Val Acc:

In [1]:
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda
