# 🎭 Fine-tune Deepfake Detection Model
## Using WildDeepfake Dataset from Hugging Face

This notebook will train your model on the WildDeepfake dataset for better accuracy on YouTube videos.

## Step 1: Install Required Packages

In [None]:
!pip install torch torchvision
!pip install efficientnet-pytorch
!pip install datasets huggingface-hub
!pip install opencv-python-headless
!pip install albumentations
!pip install tqdm
!pip install Pillow

## Step 1.5: Upload Your Existing Model (For Colab Only)

In [None]:
# Google Colab only: to start from an existing checkpoint, uncomment the
# snippet below. It asks for a file upload and moves every uploaded file into
# the weights/ folder under its original name.

# from google.colab import files
# import os
#
# os.makedirs('weights', exist_ok=True)
# print("Please upload your model file (finetuned_model.pth or best_model.pth):")
# uploaded = files.upload()
#
# # Move uploaded file to weights folder (keep original name)
# for filename in uploaded.keys():
#     target_path = f'weights/{filename}'
#     os.rename(filename, target_path)
#     print(f"âœ“ Model uploaded to {target_path}")

# Checkpoint locations, in the order the notebook announces it will try them.
checkpoint_priority_lines = (
    "  - weights/finetuned_advanced.pth (first priority)",
    "  - weights/finetuned_model.pth (second priority)",
    "  - weights/best_model.pth (third priority)",
)

print("âœ“ Skip this cell if running locally")
print("âœ“ The notebook will automatically find and load:")
print("\n".join(checkpoint_priority_lines))

## Step 2: Import Libraries

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from efficientnet_pytorch import EfficientNet
from datasets import load_dataset
import cv2
import numpy as np
from PIL import Image
from tqdm import tqdm
import albumentations as A
from albumentations.pytorch import ToTensorV2
import os

# Pick the compute device: the GPU when PyTorch reports CUDA, else the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

print(f"Using device: {DEVICE}")
if DEVICE == "cuda":
    # Report the name of the first CUDA device.
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## Step 3: Define Model Architecture

In [None]:
class DeepfakeEfficientNet(nn.Module):
    """EfficientNet-B0 whose `_fc` layer is replaced by a small MLP head.

    The head maps the backbone features through 512 and 256 hidden units
    (each followed by BatchNorm1d and ReLU) down to a single output value;
    no sigmoid is applied inside the model.

    Args:
        pretrained: if True the backbone is built with
            `EfficientNet.from_pretrained`, otherwise with
            `EfficientNet.from_name`.
        dropout: base dropout probability. The three Dropout layers of the
            head use `dropout`, `dropout * 0.7` and `dropout * 0.5`.
    """

    def __init__(self, pretrained=True, dropout=0.5):
        super().__init__()

        # Choose the backbone constructor once, then build 'efficientnet-b0'.
        make_backbone = EfficientNet.from_pretrained if pretrained else EfficientNet.from_name
        self.efficientnet = make_backbone('efficientnet-b0')

        # Width of the feature vector that feeds the original classifier.
        head_in = self.efficientnet._fc.in_features

        # Layer order is kept fixed so state-dict keys (`_fc.0` ... `_fc.9`)
        # stay stable across checkpoints.
        head_layers = [
            nn.Dropout(dropout),
            nn.Linear(head_in, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(dropout * 0.7),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(dropout * 0.5),
            nn.Linear(256, 1),
        ]
        self.efficientnet._fc = nn.Sequential(*head_layers)

    def forward(self, x):
        """Run `x` through the backbone (including the replaced head)."""
        return self.efficientnet(x)

print("âœ“ Model architecture defined")

## Step 4: Load WildDeepfake Dataset

In [None]:
# Announce the load up front, since the call below may take a while.
print("Loading WildDeepfake dataset from Hugging Face...")
print("This may take a few minutes...")

# Fetch the "train" split of the WildDeepfake dataset.
dataset = load_dataset("xingjunm/WildDeepfake", split="train")

# Summarise what was loaded: size first, then the feature schema.
print(f"\nâœ“ Dataset loaded!")
print(f"Total samples: {len(dataset)}")
print(f"\nDataset features: {dataset.features}")

# Print the first record so the field names and value types are visible.
print(f"\nSample data:")
print(dataset[0])

## Step 5: Create Custom Dataset Class

In [None]:
class WildDeepfakeDataset(Dataset):
    """Torch `Dataset` view over a Hugging Face dataset of labelled images.

    Every record of `hf_dataset` is read through its 'image' and 'label'
    entries. Images are converted to 3-channel RGB, resized to 224x224 and
    optionally passed through an albumentations-style `transform`
    (called as `transform(image=...)`, result read from the 'image' key).

    Args:
        hf_dataset: indexable collection of records with 'image' and 'label'.
        transform: optional callable applied to the resized image.
    """

    def __init__(self, hf_dataset, transform=None):
        self.dataset = hf_dataset
        self.transform = transform

    def __len__(self):
        """Number of records in the wrapped dataset."""
        return len(self.dataset)

    @staticmethod
    def _to_rgb_array(image):
        """Return `image` as a contiguous H x W x 3 array.

        Objects exposing `.convert` (PIL images) are converted with
        `convert("RGB")`, which also covers palette, grey+alpha and similar
        modes. Plain arrays are handled by channel count: 1 or 2 channels are
        treated as grey (plus alpha) and replicated, 4 channels drop the last
        one, 3 channels are returned as they are.

        Raises:
            ValueError: if the array is not 2-D/3-D or has an unsupported
                number of channels.
        """
        if hasattr(image, "convert"):
            image = image.convert("RGB")

        pixels = np.array(image)
        if pixels.ndim == 2:
            pixels = pixels[:, :, np.newaxis]
        if pixels.ndim != 3:
            raise ValueError(f"Expected a 2-D or 3-D image, got shape {pixels.shape}")

        channels = pixels.shape[2]
        if channels in (1, 2):
            # Grey or grey+alpha: replicate the grey plane, ignore alpha.
            pixels = np.repeat(pixels[:, :, :1], 3, axis=2)
        elif channels == 4:
            # RGBA: drop the alpha plane.
            pixels = pixels[:, :, :3]
        elif channels != 3:
            raise ValueError(f"Unsupported number of image channels: {channels}")

        # OpenCV needs a contiguous buffer; slicing above may not give one.
        return np.ascontiguousarray(pixels)

    def __getitem__(self, idx):
        """Return `(image, label)` for record `idx`.

        `label` is a float32 scalar tensor. `image` is the transform's output
        when a transform is set, otherwise the resized H x W x 3 array.
        """
        item = self.dataset[idx]

        # Get image and label
        image = self._to_rgb_array(item['image'])
        label = item['label']  # 0 = real, 1 = fake

        # Resize to 224x224
        image = cv2.resize(image, (224, 224))

        # Apply augmentation
        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']

        return image, torch.tensor(label, dtype=torch.float32)

print("âœ“ Dataset class defined")

## Step 6: Define Data Augmentation

In [None]:
# Random augmentations applied only to training images, in this order.
# NOTE(review): `var_limit`, `quality_lower` and `quality_upper` are argument
# names from older albumentations releases - confirm the installed version
# still accepts them.
_train_only_augmentations = [
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
    A.GaussianBlur(blur_limit=(3, 7), p=0.3),
    A.GaussNoise(var_limit=(10.0, 50.0), p=0.3),
    A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=0.3),
    A.Rotate(limit=15, p=0.3),
    A.ImageCompression(quality_lower=60, quality_upper=100, p=0.3),
]

# Training pipeline: the random augmentations, then per-channel normalisation
# and conversion to a tensor.
# The mean/std values are presumably the usual ImageNet statistics - verify.
train_transform = A.Compose([
    *_train_only_augmentations,
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2(),
])

# Validation pipeline: normalisation and tensor conversion only, so metrics
# are computed on unaugmented images.
val_transform = A.Compose([
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2(),
])

print("âœ“ Augmentation pipelines defined")

## Step 7: Split Dataset and Create DataLoaders

In [None]:
# Split dataset (80% train, 20% validation)
SPLIT_SEED = 42  # fixed so the split is the same on every run
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

# Shuffle before slicing: taking the first 80% / last 20% of the stored order
# gives a skewed validation set whenever the records are grouped (for
# instance by label).
# NOTE(review): if several records come from the same video, a random split
# can still place near-identical frames in both sets - confirm whether a
# per-video split is needed.
shuffled_dataset = dataset.shuffle(seed=SPLIT_SEED)
train_dataset_hf = shuffled_dataset.select(range(train_size))
val_dataset_hf = shuffled_dataset.select(range(train_size, len(shuffled_dataset)))

# Create custom datasets
train_dataset = WildDeepfakeDataset(train_dataset_hf, transform=train_transform)
val_dataset = WildDeepfakeDataset(val_dataset_hf, transform=val_transform)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")

# Create dataloaders
BATCH_SIZE = 32 if DEVICE == "cuda" else 16
use_pinned_memory = DEVICE == "cuda"

# The model head contains BatchNorm1d layers, which cannot run in training
# mode on a batch of one sample. Drop the trailing batch only in that case.
drop_single_sample_batch = len(train_dataset) % BATCH_SIZE == 1

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    pin_memory=use_pinned_memory,
    drop_last=drop_single_sample_batch
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
    pin_memory=use_pinned_memory
)

print(f"\nâœ“ DataLoaders created")
print(f"Batch size: {BATCH_SIZE}")
print(f"Training batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")

## Step 8: Initialize Model and Training Components

In [None]:
# Initialize model
print("Initializing model...")
model = DeepfakeEfficientNet(pretrained=False, dropout=0.5)

# Try to load your existing model (in order of preference)
MODEL_PATHS = [
    "weights/finetuned_advanced.pth",
    "weights/finetuned_model.pth",
    "weights/best_model.pth"
]

# Older checkpoints may name the backbone attribute 'net' instead of
# 'efficientnet'. Only a leading prefix is rewritten.
LEGACY_PREFIX = 'net.'
CURRENT_PREFIX = 'efficientnet.'

model_loaded = False
for model_path in MODEL_PATHS:
    if os.path.exists(model_path):
        print(f"Loading existing model from {model_path}...")
        try:
            # NOTE: torch.load unpickles the file - only load checkpoints
            # that come from a source you trust.
            checkpoint = torch.load(model_path, map_location=DEVICE)

            # Handle different checkpoint formats
            if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
                state_dict = checkpoint['model_state_dict']
                print(f"âœ“ Loaded model from epoch {checkpoint.get('epoch', 'unknown')}")
            else:
                state_dict = checkpoint

            # Fix key mismatch (net. -> efficientnet.). Slice off the prefix
            # rather than str.replace, which would also rewrite any later
            # 'net.' occurrence inside the key.
            new_state_dict = {
                (CURRENT_PREFIX + key[len(LEGACY_PREFIX):] if key.startswith(LEGACY_PREFIX) else key): value
                for key, value in state_dict.items()
            }

            # strict=False tolerates partial checkpoints, so check the result:
            # a checkpoint with no matching key would otherwise be reported as
            # loaded while the model keeps its random initial weights.
            load_result = model.load_state_dict(new_state_dict, strict=False)
            matched_keys = set(model.state_dict().keys()) - set(load_result.missing_keys)
            if not matched_keys:
                raise ValueError("checkpoint has no parameter names in common with the model")
            if load_result.missing_keys or load_result.unexpected_keys:
                print(
                    f"  Note: {len(load_result.missing_keys)} model keys missing from the checkpoint, "
                    f"{len(load_result.unexpected_keys)} checkpoint keys not used"
                )

            print(f"âœ“ Successfully loaded {model_path}")
            print("âœ“ Continuing training from your existing model")
            model_loaded = True
            break
        except Exception as e:
            # Best effort: report the failure and try the next candidate.
            print(f"âš  Could not load {model_path}: {e}")
            continue

if not model_loaded:
    print("âš  No existing model found")
    print("âœ“ Starting training from pretrained EfficientNet-B0")
    # Rebuild from scratch so a failed, partial load leaves nothing behind.
    model = DeepfakeEfficientNet(pretrained=True, dropout=0.5)

model = model.to(DEVICE)

# Loss function (takes the raw model output; sigmoid is applied inside)
criterion = nn.BCEWithLogitsLoss()

# Optimizer - using lower learning rate for fine-tuning existing model
lr = 0.00005 if model_loaded else 0.0001  # Lower LR if continuing training
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)

# Learning rate scheduler: halve the LR after 2 epochs without improvement of
# the monitored value. `verbose` is not passed: it is deprecated in PyTorch
# and rejected by recent releases.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=2
)

print("\nâœ“ Model initialized")
print(f"Learning rate: {lr}")
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

## Step 9: Training Loop

In [None]:
def _run_epoch(model, loader, criterion, device, optimizer, training, desc):
    """Run one pass over `loader` and return `(mean_loss, accuracy_percent)`.

    Shared implementation of `train_epoch` and `validate_epoch`.

    Args:
        model: module called as `model(images)`.
        loader: iterable of `(images, labels)` batches.
        criterion: loss called as `criterion(outputs, labels)`. The epoch loss
            is a per-sample mean, which assumes the criterion returns the
            mean over the batch.
        device: device the batches are moved to.
        optimizer: optimizer stepped after every batch when `training` is
            True; unused otherwise.
        training: True to update the model, False to only evaluate it.
        desc: label shown on the progress bar.

    Raises:
        ValueError: if `loader` yields no samples.
    """
    # train(True) is train(), train(False) is eval().
    model.train(training)
    loss_sum = 0.0
    correct = 0
    total = 0

    pbar = tqdm(loader, desc=desc)
    # Gradients are only tracked while training.
    with torch.set_grad_enabled(training):
        for images, labels in pbar:
            images = images.to(device)
            # Labels arrive as shape (batch,); add a column to match outputs.
            labels = labels.to(device).unsqueeze(1)

            if training:
                optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)

            if training:
                loss.backward()
                optimizer.step()

            # Weight the batch loss by its size so a short final batch does
            # not count as much as a full one.
            batch_size = labels.size(0)
            loss_sum += loss.item() * batch_size
            predictions = (torch.sigmoid(outputs) > 0.5).float()
            correct += (predictions == labels).sum().item()
            total += batch_size

            pbar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'acc': f'{100 * correct / total:.2f}%'
            })

    if total == 0:
        raise ValueError(f"{desc} loader produced no samples")

    epoch_loss = loss_sum / total
    epoch_acc = 100 * correct / total
    return epoch_loss, epoch_acc

def train_epoch(model, loader, criterion, optimizer, device):
    """Train `model` for one epoch.

    Returns:
        `(epoch_loss, epoch_acc)`: per-sample mean loss and accuracy in
        percent, with predictions thresholded at sigmoid(output) > 0.5.
    """
    return _run_epoch(model, loader, criterion, device, optimizer, True, "Training")

def validate_epoch(model, loader, criterion, device):
    """Evaluate `model` on `loader` without updating it.

    Returns:
        `(epoch_loss, epoch_acc)`: per-sample mean loss and accuracy in
        percent, with predictions thresholded at sigmoid(output) > 0.5.
    """
    return _run_epoch(model, loader, criterion, device, None, False, "Validation")

print("âœ“ Training functions defined")

## Step 10: Train the Model

In [None]:
# Training configuration
NUM_EPOCHS = 10
SAVE_PATH = "weights/wilddeepfake_model.pth"

# Create the checkpoint directory derived from SAVE_PATH, so changing the
# path above cannot leave the save step without a folder.
os.makedirs(os.path.dirname(SAVE_PATH) or ".", exist_ok=True)

print("=" * 60)
print("STARTING TRAINING")
print("=" * 60)
print(f"Epochs: {NUM_EPOCHS}")
print(f"Batch size: {BATCH_SIZE}")
print(f"Device: {DEVICE}")
print("=" * 60)

best_val_loss = float('inf')
best_epoch = None  # stays None until a checkpoint is written
history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}

for epoch in range(NUM_EPOCHS):
    print(f"\nEpoch {epoch + 1}/{NUM_EPOCHS}")
    print("-" * 60)

    # Train
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, DEVICE)

    # Validate
    val_loss, val_acc = validate_epoch(model, val_loader, criterion, DEVICE)

    # Update scheduler (it monitors the validation loss)
    scheduler.step(val_loss)

    # Save history
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)

    # Print epoch summary
    print(f"\nEpoch {epoch + 1} Summary:")
    print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
    print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")
    # Show the learning rate in effect after the scheduler update.
    print(f"Learning rate: {optimizer.param_groups[0]['lr']:.2e}")

    # Save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_epoch = epoch
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_loss,
            'val_acc': val_acc,
        }, SAVE_PATH)
        print(f"âœ“ Model saved to {SAVE_PATH}")

print("\n" + "=" * 60)
print("TRAINING COMPLETE!")
print("=" * 60)
if best_epoch is None:
    # No epoch beat the initial value (for example every validation loss was
    # NaN, or NUM_EPOCHS is 0), so nothing was written to SAVE_PATH.
    print("No checkpoint was saved: validation loss never improved (check for NaN losses)")
else:
    print(f"Best validation loss: {best_val_loss:.4f}")
    print(f"Model saved to: {SAVE_PATH}")

## Step 11: Plot Training History

In [None]:
import matplotlib.pyplot as plt

# Number epochs from 1 so the x-axis matches the "Epoch N" lines printed
# during training.
epoch_numbers = list(range(1, len(history['train_loss']) + 1))

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

# One panel per metric: (axis, history suffix, legend suffix, y label, title).
panels = [
    (ax1, 'loss', 'Loss', 'Loss', 'Training and Validation Loss'),
    (ax2, 'acc', 'Acc', 'Accuracy (%)', 'Training and Validation Accuracy'),
]

for ax, key, legend_suffix, y_label, title in panels:
    ax.plot(epoch_numbers, history[f'train_{key}'], label=f'Train {legend_suffix}', marker='o')
    ax.plot(epoch_numbers, history[f'val_{key}'], label=f'Val {legend_suffix}', marker='o')
    ax.set_xticks(epoch_numbers)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(y_label)
    ax.set_title(title)
    ax.legend()
    ax.grid(True)

plt.tight_layout()
plt.savefig('training_history.png', dpi=150, bbox_inches='tight')
plt.show()

print("âœ“ Training history plotted and saved to 'training_history.png'")

## Step 12: Download Trained Model (For Colab)

In [None]:
# Google Colab only: uncomment the lines below to download the checkpoint
# written by the training cell.
# from google.colab import files
# files.download('weights/wilddeepfake_model.pth')
# print("âœ“ Model downloaded! Place it in your project's weights/ folder")

divider = "=" * 60
next_steps = (
    "1. Copy 'wilddeepfake_model.pth' to your project's weights/ folder",
    "2. Rename it to 'finetuned_model.pth' or 'best_model.pth'",
    "3. Restart your backend server: python backend_server.py",
    "4. Test the extension - accuracy should be much better!",
)

# Framed list of follow-up actions.
print("\n" + divider)
print("NEXT STEPS:")
print(divider)
for step in next_steps:
    print(step)
print(divider)