# Deepfake Detection Training - DFD Dataset
## With Advanced Generalization & Early Stopping

In [None]:
# Check GPU: print the CUDA device name and total memory, or abort the
# notebook early when no CUDA device is visible.
import torch

print("="*70)
if not torch.cuda.is_available():
    print("❌ NO GPU! Runtime → Change runtime type → GPU")
    # RuntimeError (still an Exception subclass) instead of a bare Exception,
    # so the failure carries a meaningful type.
    raise RuntimeError("GPU required")

print(f"✅ GPU: {torch.cuda.get_device_name(0)}")
print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory/1e9:.1f} GB")
print("="*70)

In [None]:
# Install the third-party packages that later cells import:
#   efficientnet-pytorch - model backbone used by the training script
#   kagglehub            - dataset download
#   albumentations       - image augmentation used by the training script
# torch, torchvision, Pillow, numpy and tqdm are imported later as well but are
# not installed here - presumably they ship with the runtime; verify.
!pip install -q efficientnet-pytorch kagglehub albumentations

In [None]:
# Fetch the DFD dataset through kagglehub; `path` is reused by later cells.
import kagglehub

DATASET_HANDLE = "sanikatiwarekar/deep-fake-detection-dfd-entire-original-dataset"

print("Downloading DFD dataset...")
path = kagglehub.dataset_download(DATASET_HANDLE)
print(f"\n✅ Dataset downloaded to: {path}")

In [None]:
# Explore dataset structure: print the directory tree under the download
# location, three levels deep, listing at most five sub-directories per folder.
import os
from pathlib import Path

dataset_path = Path(path)
print("\nDataset structure:")
for root, dirs, files in os.walk(dataset_path):
    # Depth of `root` below the dataset root (0 for the root itself).
    level = len(Path(root).relative_to(dataset_path).parts)
    indent = ' ' * 2 * level
    print(f"{indent}{os.path.basename(root)}/")

    if level >= 2:
        # Nothing deeper than this level is displayed, so stop os.walk from
        # descending further instead of traversing the whole dataset.
        dirs.clear()
        continue

    for d in dirs[:5]:
        print(f"{indent}  {d}/")
    if len(dirs) > 5:
        print(f"{indent}  ... and {len(dirs)-5} more")

In [None]:
# Organize dataset: copy every image found under `dataset_path` (set by the
# exploration cell) into dataset_organized/Real or dataset_organized/Fake,
# based on keywords in the name of the directory that contains it.
import os
import shutil
from pathlib import Path

print("Organizing dataset...")

organized = Path('dataset_organized')
(organized / 'Real').mkdir(parents=True, exist_ok=True)
(organized / 'Fake').mkdir(parents=True, exist_ok=True)

IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp')
# Checked in this order: a directory matching a "real" keyword wins.
REAL_KEYWORDS = ('real', 'original')
FAKE_KEYWORDS = ('fake', 'manipulated', 'deepfake')

counts = {'Real': 0, 'Fake': 0}

# Process all images
for root, dirs, files in os.walk(dataset_path):
    # Sorted traversal keeps the numeric prefixes stable between runs.
    dirs.sort()

    # Match keywords only on the part of the path *below* the dataset root.
    # The dataset handle itself contains "deep-fake" and "original"; if the
    # download location embeds that handle (expected for kagglehub's cache,
    # verify), matching on the absolute path would label every image as Real.
    rel_lower = os.path.relpath(root, dataset_path).lower()
    if any(keyword in rel_lower for keyword in REAL_KEYWORDS):
        label = 'Real'
    elif any(keyword in rel_lower for keyword in FAKE_KEYWORDS):
        label = 'Fake'
    else:
        continue

    for file in sorted(files):
        if not file.lower().endswith(IMAGE_EXTENSIONS):
            continue

        filepath = os.path.join(root, file)
        dest = organized / label / f'{counts[label]:06d}_{file}'
        if not dest.exists():
            shutil.copy(filepath, dest)
        # Count every image, copied now or already present, so that a re-run
        # maps each source file to the same destination name instead of
        # producing shifted duplicates.
        counts[label] += 1
        if counts[label] % 1000 == 0:
            print(f"  {label}: {counts[label]:,}")

real_count = counts['Real']
fake_count = counts['Fake']

print(f"\n✅ Organized: {real_count:,} real, {fake_count:,} fake")
print(f"Total: {real_count + fake_count:,} images")

In [None]:
# Advanced training script with generalization techniques
training_script = '''
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from efficientnet_pytorch import EfficientNet
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np
from pathlib import Path
import random
from tqdm import tqdm

# Seed the torch, numpy and stdlib random generators with the same value
SEED = 42
for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    seed_fn(SEED)

# EfficientNet-B0 backbone with a dropout-regularised single-logit head
class DeepfakeDetector(nn.Module):
    """Pretrained EfficientNet-B0 whose classifier is replaced by a small MLP.

    The replacement head is Dropout -> Linear(in_features, 512) -> ReLU ->
    Dropout -> Linear(512, 1); no activation follows the last layer.

    Args:
        dropout: Drop probability of the first dropout layer. The second
            dropout layer uses half of this value.
    """

    def __init__(self, dropout=0.5):
        super().__init__()
        backbone = EfficientNet.from_pretrained('efficientnet-b0')
        head = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(backbone._fc.in_features, 512),
            nn.ReLU(),
            nn.Dropout(dropout * 0.5),
            nn.Linear(512, 1),
        )
        backbone._fc = head
        self.efficientnet = backbone

    def forward(self, x):
        """Return the backbone output for the batch ``x``."""
        return self.efficientnet(x)

# Augmentation pipelines: a randomised one for training, a deterministic
# resize-and-normalise one for validation.
def _normalize_and_tensorize():
    """Return the final two steps shared by both pipelines."""
    return [
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]


def _build_train_transform():
    """Assemble the training pipeline: geometric, noise/blur, colour, detail, dropout."""
    geometric = [
        A.Resize(256, 256),
        A.RandomCrop(224, 224),
        A.HorizontalFlip(p=0.5),
        A.Rotate(limit=15, p=0.5),
    ]
    noise_or_blur = A.OneOf([
        A.GaussNoise(var_limit=(10.0, 50.0)),
        A.GaussianBlur(blur_limit=(3, 7)),
        A.MotionBlur(blur_limit=7),
    ], p=0.5)
    colour_jitter = A.OneOf([
        A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2),
        A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20),
        A.RGBShift(r_shift_limit=20, g_shift_limit=20, b_shift_limit=20),
    ], p=0.5)
    local_detail = A.OneOf([
        A.CLAHE(clip_limit=2),
        A.Sharpen(),
        A.Emboss(),
    ], p=0.3)
    cutout = A.CoarseDropout(max_holes=8, max_height=32, max_width=32, p=0.3)

    steps = geometric + [noise_or_blur, colour_jitter, local_detail, cutout]
    return A.Compose(steps + _normalize_and_tensorize())


train_transform = _build_train_transform()

val_transform = A.Compose([A.Resize(224, 224)] + _normalize_and_tensorize())

# Dataset
class DeepfakeDataset(Dataset):
    """Images read from ``root_dir/Real`` (label 0) and ``root_dir/Fake`` (label 1).

    Files are collected once in ``__init__``, in sorted order, and then shuffled
    with the stdlib ``random`` module, so the sample order depends only on the
    file names and the current ``random`` seed.

    Args:
        root_dir: Directory containing the ``Real`` and ``Fake`` sub-folders.
            A missing sub-folder is skipped silently.
        transform: Optional callable invoked as ``transform(image=array)`` whose
            result is indexed with ``['image']``.
    """

    IMAGE_SUFFIXES = ('.jpg', '.jpeg', '.png', '.bmp')

    def __init__(self, root_dir, transform=None):
        self.root_dir = Path(root_dir)
        self.transform = transform
        self.samples = []

        for label, folder in [(0, 'Real'), (1, 'Fake')]:
            folder_path = self.root_dir / folder
            if not folder_path.exists():
                continue
            # glob() yields entries in filesystem order; sorting first makes the
            # seeded shuffle below reproducible across machines and runs.
            for img_path in sorted(folder_path.glob('*')):
                if img_path.suffix.lower() in self.IMAGE_SUFFIXES:
                    self.samples.append((str(img_path), label))

        random.shuffle(self.samples)

    def __len__(self):
        """Return the number of collected samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``; the label is a float32 tensor."""
        img_path, label = self.samples[idx]
        # The context manager closes the file handle as soon as the pixels
        # have been copied into the array.
        with Image.open(img_path) as img:
            image = np.array(img.convert('RGB'))

        if self.transform:
            image = self.transform(image=image)['image']

        return image, torch.tensor(label, dtype=torch.float32)

# Stop training once the validation loss stops improving
class EarlyStopping:
    """Track validation loss and set ``early_stop`` after too many stalled epochs.

    Args:
        patience: Number of consecutive non-improving calls that sets
            ``early_stop`` to True.
        min_delta: Amount by which the loss must drop below the best value
            seen so far to count as an improvement.
    """

    def __init__(self, patience=5, min_delta=0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        """Record ``val_loss`` and update ``counter``, ``best_loss`` and ``early_stop``."""
        # The first observation only establishes the baseline.
        if self.best_loss is None:
            self.best_loss = val_loss
            return

        stalled = val_loss > self.best_loss - self.min_delta
        if not stalled:
            self.best_loss = val_loss
            self.counter = 0
            return

        self.counter += 1
        print(f"EarlyStopping counter: {self.counter}/{self.patience}")
        if self.counter >= self.patience:
            self.early_stop = True

# Focal loss: scales each sample's BCE by how poorly it is classified
class FocalLoss(nn.Module):
    """Binary focal loss computed from raw logits.

    Each sample's binary cross-entropy is multiplied by
    ``alpha * (1 - pt) ** gamma`` with ``pt = exp(-bce)``, and the mean over
    all elements is returned.

    Args:
        alpha: Constant scale applied to every sample.
        gamma: Exponent of the ``(1 - pt)`` modulating factor.
    """

    def __init__(self, alpha=0.25, gamma=2):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        """Return the mean focal loss of logits ``inputs`` against ``targets``."""
        per_sample_bce = nn.functional.binary_cross_entropy_with_logits(
            inputs, targets, reduction='none'
        )
        prob_correct = torch.exp(-per_sample_bce)
        modulator = (1 - prob_correct) ** self.gamma
        return (self.alpha * modulator * per_sample_bce).mean()

# Training
#
# NOTE: this script is stored in the notebook inside a regular (non-raw)
# triple-quoted string, so any backslash escape written here is expanded
# before the file is saved. A newline escape inside a quoted string would
# split that string across two lines and make the generated script fail to
# parse, which is why blank lines are produced with bare print() calls.
def _run_epoch(model, loader, criterion, device, optimizer=None, desc='Training'):
    """Run one pass over ``loader`` and return ``(mean_batch_loss, accuracy_percent)``.

    When ``optimizer`` is given the model is put in train mode and updated
    after every batch, with the gradient norm clipped to 1.0. Without an
    optimizer the model is put in eval mode and gradients are disabled.
    """
    is_training = optimizer is not None
    model.train(is_training)

    total_loss = 0.0
    correct = 0
    seen = 0

    with torch.set_grad_enabled(is_training):
        for images, labels in tqdm(loader, desc=desc):
            images, labels = images.to(device), labels.to(device)

            if is_training:
                optimizer.zero_grad()

            # squeeze(1) removes only the trailing size-1 axis. A bare
            # squeeze() would also drop the batch axis of a final batch that
            # holds a single sample, and the loss would then reject the
            # mismatching shapes.
            outputs = model(images).squeeze(1)
            loss = criterion(outputs, labels)

            if is_training:
                loss.backward()
                # Gradient clipping for stability
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()

            total_loss += loss.item()
            preds = (torch.sigmoid(outputs) > 0.5).float()
            correct += (preds == labels).sum().item()
            seen += labels.size(0)

    return total_loss / len(loader), 100 * correct / seen


def train(data_dir='./dataset_organized', num_epochs=30, batch_size=32):
    """Train a DeepfakeDetector and save the best checkpoint.

    The shuffled dataset is split 80/20 into train and validation parts. After
    every epoch the learning rate is scheduled on the validation loss, the
    model with the highest validation accuracy so far is written to
    ``weights/best_model.pth``, and early stopping is evaluated on the
    validation loss.

    Args:
        data_dir: Directory holding the ``Real`` and ``Fake`` image folders.
        num_epochs: Maximum number of epochs to run.
        batch_size: Batch size of both data loaders.

    Raises:
        RuntimeError: If the dataset is too small to give both splits at
            least one sample.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Load dataset
    full_dataset = DeepfakeDataset(data_dir, transform=None)

    # Split train/val (80/20)
    train_size = int(0.8 * len(full_dataset))
    val_size = len(full_dataset) - train_size
    if train_size == 0 or val_size == 0:
        raise RuntimeError(
            f"Not enough images in {data_dir} for a train/val split "
            f"(found {len(full_dataset)})"
        )

    # Both splits reuse the single shuffled sample list of full_dataset so
    # that they are disjoint; only the transform differs.
    train_dataset = DeepfakeDataset(data_dir, transform=train_transform)
    train_dataset.samples = full_dataset.samples[:train_size]

    val_dataset = DeepfakeDataset(data_dir, transform=val_transform)
    val_dataset.samples = full_dataset.samples[train_size:]

    print()
    print(f"Train: {len(train_dataset):,} | Val: {len(val_dataset):,}")

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

    # Model
    model = DeepfakeDetector(dropout=0.5).to(device)

    # Loss and optimizer
    criterion = FocalLoss()
    optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
    # The scheduler's verbose flag is deprecated in recent PyTorch releases,
    # so learning-rate reductions are reported explicitly further down.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)

    # Early stopping
    early_stopping = EarlyStopping(patience=5, min_delta=0.001)

    # Training loop
    best_val_acc = 0
    Path('weights').mkdir(exist_ok=True)

    for epoch in range(num_epochs):
        print()
        print('=' * 70)
        print(f"Epoch {epoch+1}/{num_epochs}")
        print('=' * 70)

        train_loss, train_acc = _run_epoch(
            model, train_loader, criterion, device, optimizer=optimizer, desc='Training'
        )
        val_loss, val_acc = _run_epoch(
            model, val_loader, criterion, device, desc='Validation'
        )

        print()
        print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
        print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")

        # Learning rate scheduling
        lr_before = optimizer.param_groups[0]['lr']
        scheduler.step(val_loss)
        lr_after = optimizer.param_groups[0]['lr']
        if lr_after < lr_before:
            print(f"Reducing learning rate: {lr_before:.2e} -> {lr_after:.2e}")

        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc,
                'val_loss': val_loss
            }, 'weights/best_model.pth')
            print(f"✅ Saved best model (Val Acc: {val_acc:.2f}%)")

        # Early stopping check
        early_stopping(val_loss)
        if early_stopping.early_stop:
            print()
            print(f"⚠️ Early stopping triggered at epoch {epoch+1}")
            print(f"Best validation accuracy: {best_val_acc:.2f}%")
            break

    print()
    print('=' * 70)
    print(f"✅ Training complete! Best Val Acc: {best_val_acc:.2f}%")
    print('=' * 70)

if __name__ == '__main__':
    train()
'''

# Write the training script to disk so the next cell can run it with `python`.
# The script text contains non-ASCII characters, so the encoding is pinned to
# UTF-8 instead of relying on the platform default.
with open('train_advanced.py', 'w', encoding='utf-8') as f:
    f.write(training_script)

print("✅ Training script created")

In [None]:
# Train model
!python train_advanced.py

In [None]:
# Download trained model
from google.colab import files
from pathlib import Path

model_path = Path('weights/best_model.pth')
if model_path.exists():
    print(f"Downloading model ({model_path.stat().st_size/1e6:.1f} MB)...")
    files.download(str(model_path))
    print("\n✅ Done! Place in weights/ folder and run video_detection.py")
else:
    print("Model not found")
    !ls -lh weights/