## 1. Setup and Installation

In [None]:
# Report the PyTorch build and whether a CUDA device is visible to this runtime
import torch

has_cuda = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {has_cuda}")
if not has_cuda:
    print("⚠️ WARNING: No GPU detected! Training will be very slow.")
    print("Go to Runtime → Change runtime type → GPU")
else:
    # total_memory is reported in bytes; convert to GiB for display
    gpu_memory_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {gpu_memory_gb:.2f} GB")

In [None]:
# Install additional dependencies
!pip install -q opencv-python-headless
!pip install -q seaborn

print("✓ Dependencies installed")

## 2. Mount Google Drive

Upload your dataset to Google Drive first, then mount it here.

In [None]:
import os

from google.colab import drive
drive.mount('/content/drive')

# Change these paths to match your Google Drive structure
DATASET_PATH = '/content/drive/MyDrive/FaceForensics++_C23'  # Update this path
OUTPUT_PATH = '/content/frames_dataset'
MODEL_SAVE_PATH = '/content/drive/MyDrive/deepfake_model'  # Save model to Drive

# Create the folders from Python instead of `!mkdir -p {…}`: the shell form
# interpolates the path unquoted, so it breaks as soon as a path contains a space.
os.makedirs(OUTPUT_PATH, exist_ok=True)
os.makedirs(MODEL_SAVE_PATH, exist_ok=True)

print(f"✓ Google Drive mounted")
print(f"Dataset path: {DATASET_PATH}")
print(f"Output path: {OUTPUT_PATH}")

# Surface a wrong DATASET_PATH here rather than as "0 frames extracted" later on.
if not os.path.isdir(DATASET_PATH):
    print(f"⚠️ WARNING: Dataset path not found: {DATASET_PATH}")

## 3. Video Preprocessing

Extract frames from videos. This will take 1-2 hours.

In [None]:
import os
import cv2
import numpy as np
from pathlib import Path
from tqdm.notebook import tqdm
import logging

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

class VideoPreprocessor:
    """Sample resized frames from a dataset's video folders into REAL/FAKE image folders.

    Frames are written as JPEG files to ``<output_path>/REAL`` and
    ``<output_path>/FAKE``; both folders are created on construction.

    Args:
        dataset_path: Dataset root containing ``original_sequences`` and
            ``manipulated_sequences``.
        output_path: Folder that receives the ``REAL`` and ``FAKE`` sub-folders.
        frames_per_video: Maximum number of frames sampled from each video.
        target_size: Size tuple handed to ``cv2.resize``.
    """

    # Sub-folders of ``manipulated_sequences`` that are scanned for fake videos.
    FAKE_TYPES = ('Deepfakes', 'Face2Face', 'FaceSwap', 'NeuralTextures')

    def __init__(self, dataset_path, output_path, frames_per_video=10, target_size=(224, 224)):
        self.dataset_path = Path(dataset_path)
        self.output_path = Path(output_path)
        self.frames_per_video = frames_per_video
        self.target_size = target_size

        # Create output directories
        (self.output_path / 'REAL').mkdir(parents=True, exist_ok=True)
        (self.output_path / 'FAKE').mkdir(parents=True, exist_ok=True)

    @staticmethod
    def _sample_frame_indices(total_frames, frames_per_video):
        """Return up to ``frames_per_video`` distinct, evenly spaced frame indices.

        Videos with fewer frames than requested would otherwise yield repeated
        indices (the same frame saved several times), so duplicates are dropped.
        """
        if total_frames <= 0 or frames_per_video <= 0:
            return []
        indices = np.linspace(0, total_frames - 1, frames_per_video, dtype=int)
        return np.unique(indices).tolist()

    def extract_frames(self, video_path, label, prefix=''):
        """Extract sampled frames from one video into ``<output_path>/<label>``.

        Args:
            video_path: Path of the video file to read.
            label: Name of the output sub-folder (``'REAL'`` or ``'FAKE'``).
            prefix: Optional string prepended to the output file names so that
                videos sharing a file name in different source folders do not
                overwrite each other's frames.

        Returns:
            Number of frames actually written; ``0`` if the video cannot be
            opened or any error occurs (errors are logged, not raised).
        """
        cap = None
        try:
            cap = cv2.VideoCapture(str(video_path))
            if not cap.isOpened():
                return 0

            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            video_name = Path(video_path).stem
            if prefix:
                video_name = f"{prefix}_{video_name}"

            extracted_count = 0
            for frame_idx in self._sample_frame_indices(total_frames, self.frames_per_video):
                cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
                ret, frame = cap.read()
                if not ret:
                    continue

                resized_frame = cv2.resize(frame, self.target_size)
                output_file = self.output_path / label / f"{video_name}_frame_{extracted_count:03d}.jpg"
                # Only count frames that were really written to disk.
                if cv2.imwrite(str(output_file), resized_frame):
                    extracted_count += 1

            return extracted_count
        except Exception as e:
            logger.error(f"Error processing {video_path}: {e}")
            return 0
        finally:
            # Release the capture on every path, including errors.
            if cap is not None:
                cap.release()

    def _extract_from_videos(self, video_files, label, desc, prefix=''):
        """Run ``extract_frames`` over ``video_files`` with a progress bar; return the frame total."""
        return sum(
            self.extract_frames(video_path, label, prefix=prefix)
            for video_path in tqdm(video_files, desc=desc)
        )

    def process_fake_sequences(self):
        """Extract frames from every manipulation folder into ``FAKE``.

        Missing folders are logged and skipped.

        Returns:
            Total number of frames written.
        """
        total_extracted = 0

        for fake_type in self.FAKE_TYPES:
            video_dir = self.dataset_path / 'manipulated_sequences' / fake_type / 'c23' / 'videos'
            if not video_dir.exists():
                logger.warning(f"Directory not found: {video_dir}")
                continue

            video_files = sorted(video_dir.glob('*.mp4'))
            logger.info(f"Processing {len(video_files)} {fake_type} videos")

            # All manipulation types share the single FAKE folder, so the type
            # name is used as a file-name prefix to keep same-named videos apart.
            total_extracted += self._extract_from_videos(
                video_files, 'FAKE', fake_type, prefix=fake_type
            )

        return total_extracted

    def process_real_sequences(self):
        """Extract frames from the original (unmanipulated) videos into ``REAL``.

        Returns:
            Total number of frames written; ``0`` if the folder is missing.
        """
        video_dir = self.dataset_path / 'original_sequences' / 'youtube' / 'c23' / 'videos'
        if not video_dir.exists():
            logger.error(f"Directory not found: {video_dir}")
            return 0

        video_files = sorted(video_dir.glob('*.mp4'))
        logger.info(f"Processing {len(video_files)} REAL videos")

        return self._extract_from_videos(video_files, 'REAL', "Original sequences")

# Build the preprocessor and extract FAKE frames first, then REAL frames
preprocessor = VideoPreprocessor(DATASET_PATH, OUTPUT_PATH, frames_per_video=10)
print("Starting preprocessing...")
fake_count = preprocessor.process_fake_sequences()
real_count = preprocessor.process_real_sequences()

# Summary banner
banner = '=' * 50
total_count = fake_count + real_count
print(f"\n{banner}")
print("Preprocessing completed!")
print(f"FAKE frames: {fake_count}")
print(f"REAL frames: {real_count}")
print(f"Total frames: {total_count}")
print(f"{banner}")

## 4. Dataset Analysis (Optional)

In [None]:
import matplotlib.pyplot as plt

# Count the extracted frames per class
real_frames = len(list(Path(OUTPUT_PATH, 'REAL').glob('*.jpg')))
fake_frames = len(list(Path(OUTPUT_PATH, 'FAKE').glob('*.jpg')))

print(f"REAL frames: {real_frames:,}")
print(f"FAKE frames: {fake_frames:,}")
print(f"Total: {real_frames + fake_frames:,}")
# Guard the ratio: with no REAL frames the division would raise ZeroDivisionError
if real_frames:
    print(f"Class ratio (FAKE/REAL): {fake_frames/real_frames:.2f}:1")
else:
    print("Class ratio (FAKE/REAL): n/a (no REAL frames found)")

if real_frames + fake_frames == 0:
    # Nothing to plot; a pie chart over all-zero counts is meaningless
    print("No frames found - run the preprocessing cell first.")
else:
    # Visualize distribution
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.bar(['REAL', 'FAKE'], [real_frames, fake_frames], color=['green', 'red'])
    plt.ylabel('Number of Frames')
    plt.title('Class Distribution')

    plt.subplot(1, 2, 2)
    plt.pie([real_frames, fake_frames], labels=['REAL', 'FAKE'], autopct='%1.1f%%', colors=['green', 'red'])
    plt.title('Class Percentage')
    plt.tight_layout()
    plt.show()

## 5. Model Training

Train the deepfake detection model using ResNet50.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import torchvision.models as models
from PIL import Image
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

class FaceForensicsDataset(Dataset):
    """Image dataset over ``<data_path>/REAL`` (label 0) and ``<data_path>/FAKE`` (label 1).

    Only ``*.jpg`` files directly inside each class folder are collected; a
    missing class folder simply contributes no samples.

    Args:
        data_path: Folder containing the ``REAL`` and ``FAKE`` sub-folders.
        transform: Optional callable applied to each RGB PIL image.
    """

    # (sub-folder name, integer label) in the order samples are stored
    CLASS_LABELS = (('REAL', 0), ('FAKE', 1))

    def __init__(self, data_path, transform=None):
        self.data_path = Path(data_path)
        self.transform = transform
        self.images = []
        self.labels = []

        for class_name, class_label in self.CLASS_LABELS:
            class_dir = self.data_path / class_name
            if not class_dir.exists():
                continue
            # Sort so the sample order does not depend on the filesystem's
            # directory listing order (glob makes no ordering guarantee).
            class_images = sorted(class_dir.glob('*.jpg'))
            self.images.extend(class_images)
            self.labels.extend([class_label] * len(class_images))

    def __len__(self):
        """Return the number of collected image files."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``, with ``transform`` applied if set."""
        image_path = self.images[idx]
        label = self.labels[idx]
        # The context manager closes the underlying file once the RGB copy exists.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return image, label

class DeepfakeDetector(nn.Module):
    """ResNet50 backbone whose final layer is replaced by a two-class head.

    The head is ``Linear -> ReLU -> Dropout(0.5) -> Linear(512, 2)``; the two
    outputs are raw logits (no softmax is applied here).

    Args:
        pretrained: When true, the backbone starts from ImageNet weights.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        if hasattr(models, 'ResNet50_Weights'):
            # Newer torchvision: `pretrained=` is deprecated in favour of
            # `weights=`; IMAGENET1K_V1 is what `pretrained=True` resolves to.
            weights = models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
            self.backbone = models.resnet50(weights=weights)
        else:
            # Older torchvision without the weights enum
            self.backbone = models.resnet50(pretrained=pretrained)

        num_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Sequential(
            nn.Linear(num_features, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, 2)
        )

    def forward(self, x):
        """Return the backbone's two-class logits for input batch ``x``."""
        return self.backbone(x)

print("✓ Model classes defined")

In [None]:
import copy

# Data transforms: augmentation for training, deterministic preprocessing for validation
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load dataset
print("Loading dataset...")
dataset = FaceForensicsDataset(OUTPUT_PATH, transform=train_transform)
print(f"Total samples: {len(dataset)}")
if len(dataset) == 0:
    # Fail with a clear message instead of an obscure sampler error further down
    raise RuntimeError(f"No frames found under {OUTPUT_PATH} - run the preprocessing cell first.")

# Split into train and validation
# NOTE(review): the split is per frame, so frames of one video can land on both
# sides; validation metrics may be optimistic. Consider splitting by video.
val_split = 0.2
val_size = int(len(dataset) * val_split)
train_size = len(dataset) - val_size
train_dataset, val_subset = torch.utils.data.random_split(dataset, [train_size, val_size])

# Both subsets returned by random_split wrap the SAME dataset object, so
# assigning `.dataset.transform` on the validation subset would also strip the
# augmentation from the training subset. Give validation its own shallow copy
# (shares the images/labels lists, so indices stay valid) with its own transform.
val_view = copy.copy(dataset)
val_view.transform = val_transform
val_dataset = torch.utils.data.Subset(val_view, val_subset.indices)

print(f"Train set: {train_size:,}")
print(f"Val set: {val_size:,}")

# Data loaders
BATCH_SIZE = 32
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

print(f"✓ Data loaders created (batch size: {BATCH_SIZE})")

In [None]:
# Initialize model on the GPU when one is available, otherwise on the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

model = DeepfakeDetector(pretrained=True).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
# Halve the LR after 3 epochs without validation-loss improvement.
# The scheduler's `verbose` argument is deprecated in PyTorch, so it is not
# passed; the current LR can be read from optimizer.param_groups[0]['lr'].
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)

# Training function
def train_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``.

    Returns:
        ``(loss, accuracy)``: the per-batch loss values averaged over the
        number of batches, and the fraction of correctly predicted samples.
    """
    model.train()
    running_loss = 0
    n_correct = 0
    n_seen = 0

    for batch_images, batch_labels in tqdm(loader, desc='Training'):
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        # Standard step: clear gradients, forward, backward, update
        optimizer.zero_grad()
        logits = model(batch_images)
        batch_loss = criterion(logits, batch_labels)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()
        # Index of the largest logit per row is the predicted class
        predicted = logits.max(1)[1]
        n_correct += (predicted == batch_labels).sum().item()
        n_seen += batch_labels.size(0)

    mean_loss = running_loss / len(loader)
    accuracy = n_correct / n_seen
    return mean_loss, accuracy

# Validation function
def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without computing gradients.

    Returns:
        ``(loss, accuracy, predictions, labels)``: the per-batch loss values
        averaged over the number of batches, the fraction of correct
        predictions, and the flat lists of predicted and true class indices.
    """
    model.eval()
    running_loss = 0
    n_correct = 0
    n_seen = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch_images, batch_labels in tqdm(loader, desc='Validating'):
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_images)
            running_loss += criterion(logits, batch_labels).item()

            # Index of the largest logit per row is the predicted class
            predicted = logits.max(1)[1]
            n_correct += (predicted == batch_labels).sum().item()
            n_seen += batch_labels.size(0)

            # Move results to host memory so they can be passed to metric functions
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(batch_labels.cpu().numpy())

    mean_loss = running_loss / len(loader)
    accuracy = n_correct / n_seen
    return mean_loss, accuracy, all_preds, all_labels

# Training loop: train, validate, adapt the LR, and checkpoint the best model
EPOCHS = 20
history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}
best_val_loss = float('inf')

print(f"\nStarting training for {EPOCHS} epochs...\n")

for epoch in range(EPOCHS):
    print(f"Epoch {epoch+1}/{EPOCHS}")

    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc, _, _ = validate(model, val_loader, criterion, device)

    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['train_acc'].append(train_acc)
    history['val_acc'].append(val_acc)

    print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
    print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}\n")

    # Report LR changes explicitly rather than relying on the scheduler's own
    # (deprecated) verbose output, so a reduction is always visible in the log.
    lr_before = optimizer.param_groups[0]['lr']
    scheduler.step(val_loss)
    lr_after = optimizer.param_groups[0]['lr']
    if lr_after != lr_before:
        print(f"Learning rate changed: {lr_before:.2e} -> {lr_after:.2e}\n")

    # Save best model (lowest validation loss so far)
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_loss,
            'val_acc': val_acc
        }, f'{MODEL_SAVE_PATH}/best_model.pt')
        print(f"✓ Model saved (Val Loss: {val_loss:.4f})\n")

print("\n" + "="*50)
print("Training completed!")
print("="*50)

## 6. Training Visualization

In [None]:
# Draw loss and accuracy curves side by side from the recorded history
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# One row per panel: axis, history keys, legend labels, y-label, title
panels = (
    (ax1, 'train_loss', 'val_loss', 'Train Loss', 'Val Loss',
     'Loss', 'Training and Validation Loss'),
    (ax2, 'train_acc', 'val_acc', 'Train Accuracy', 'Val Accuracy',
     'Accuracy', 'Training and Validation Accuracy'),
)
for ax, train_key, val_key, train_label, val_label, y_label, title in panels:
    ax.plot(history[train_key], label=train_label, marker='o')
    ax.plot(history[val_key], label=val_label, marker='s')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(y_label)
    ax.set_title(title)
    ax.legend()
    ax.grid(True)

history_png = f'{MODEL_SAVE_PATH}/training_history.png'
plt.tight_layout()
plt.savefig(history_png, dpi=150)
plt.show()

print(f"✓ Training history saved to {history_png}")

## 7. Final Evaluation

In [None]:
# Load best model; map_location lets a GPU-saved checkpoint load on whatever
# device this session is using (e.g. after a restart onto a CPU runtime)
checkpoint = torch.load(f'{MODEL_SAVE_PATH}/best_model.pt', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
print(f"✓ Loaded best model from epoch {checkpoint['epoch']+1}")

# Final evaluation
val_loss, val_acc, all_preds, all_labels = validate(model, val_loader, criterion, device)

# zero_division=0 gives a defined score when FAKE is never predicted or never present
precision = precision_score(all_labels, all_preds, zero_division=0)
recall = recall_score(all_labels, all_preds, zero_division=0)
f1 = f1_score(all_labels, all_preds, zero_division=0)
# Fix the label set so the matrix is always 2x2; without it a run where only
# one class appears yields a 1x1 matrix and the cm[0,1] lookups below fail
cm = confusion_matrix(all_labels, all_preds, labels=[0, 1])

print(f"\n{'='*50}")
print("FINAL EVALUATION RESULTS")
print(f"{'='*50}")
print(f"Validation Accuracy: {val_acc:.4f} ({val_acc*100:.2f}%)")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1-Score: {f1:.4f}")
print(f"\nConfusion Matrix:")
print(f"  True Negatives (REAL as REAL): {cm[0,0]}")
print(f"  False Positives (REAL as FAKE): {cm[0,1]}")
print(f"  False Negatives (FAKE as REAL): {cm[1,0]}")
print(f"  True Positives (FAKE as FAKE): {cm[1,1]}")
print(f"{'='*50}")

# Plot confusion matrix (rows: true label, columns: predicted label)
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
            xticklabels=['REAL', 'FAKE'], 
            yticklabels=['REAL', 'FAKE'])
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.title('Confusion Matrix')
plt.savefig(f'{MODEL_SAVE_PATH}/confusion_matrix.png', dpi=150)
plt.show()

## 8. Inference on Test Images

In [None]:
def predict_image(model, image_path, transform, device):
    """Classify a single image file as REAL or FAKE.

    Returns:
        ``(result, image)``: ``result`` is a dict with the predicted class name,
        its softmax probability (``confidence``) and both class probabilities;
        ``image`` is the RGB PIL image that was loaded.
    """
    model.eval()
    image = Image.open(image_path).convert('RGB')
    # Add a batch dimension of size one before moving to the target device
    batch = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(batch)
        probabilities = torch.softmax(logits, dim=1)
        prediction = logits.argmax(dim=1).item()

    # Class index 0 is REAL, anything else is reported as FAKE
    predicted_name = 'FAKE' if prediction != 0 else 'REAL'
    result = {
        'prediction': predicted_name,
        'confidence': float(probabilities[0, prediction].item()),
        'real_prob': float(probabilities[0, 0].item()),
        'fake_prob': float(probabilities[0, 1].item())
    }

    return result, image

# Test on a few validation samples
test_indices = [0, 100, 200, 300, 400]
fig, axes = plt.subplots(1, 5, figsize=(20, 4))

for ax, idx in zip(axes, test_indices):
    # Switch the frame off first so a panel whose index is past the end of the
    # validation set is left blank instead of showing empty axes
    ax.axis('off')
    if idx >= len(val_dataset):
        continue

    # Map the position in the validation subset back to the underlying dataset
    source_idx = val_dataset.indices[idx]
    img_path = val_dataset.dataset.images[source_idx]
    true_label = val_dataset.dataset.labels[source_idx]

    result, image = predict_image(model, img_path, val_transform, device)

    ax.imshow(image)
    ax.set_title(
        f"Pred: {result['prediction']}\n"
        f"True: {'REAL' if true_label == 0 else 'FAKE'}\n"
        f"Conf: {result['confidence']:.2%}",
        fontsize=10
    )

plt.tight_layout()
plt.savefig(f'{MODEL_SAVE_PATH}/sample_predictions.png', dpi=150)
plt.show()

print("✓ Sample predictions saved")

## 9. Download Results

All results are automatically saved to your Google Drive at the path specified in `MODEL_SAVE_PATH`.

**Saved files:**
- `best_model.pt` - Trained model checkpoint
- `training_history.png` - Training curves
- `confusion_matrix.png` - Confusion matrix heatmap
- `sample_predictions.png` - Sample predictions

You can also download the model directly:

In [None]:
from google.colab import files

# Download model (optional, already saved to Drive)
# files.download(f'{MODEL_SAVE_PATH}/best_model.pt')

# Summary of where the artefacts were written
separator = '=' * 50
print(f"\n{separator}")
print("ALL RESULTS SAVED TO GOOGLE DRIVE")
print(f"{separator}")
print(f"Location: {MODEL_SAVE_PATH}")
print("\nFiles saved:")
for saved_line in (
    "  - best_model.pt",
    "  - training_history.png",
    "  - confusion_matrix.png",
    "  - sample_predictions.png",
):
    print(saved_line)
print(f"{separator}")

## 10. Using the Model Locally

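To use this model on your local machine, first copy the `DeepfakeDetector` class definition (Section 5) into your script — the checkpoint stores only the weights, not the class — then: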

```python
import torch
from PIL import Image
import torchvision.transforms as transforms

# Load model
model = DeepfakeDetector(pretrained=False)
checkpoint = torch.load('best_model.pt', map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Prepare image
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Predict
image = Image.open('test.jpg').convert('RGB')
input_tensor = transform(image).unsqueeze(0)

with torch.no_grad():
    output = model(input_tensor)
    probabilities = torch.softmax(output, dim=1)
    prediction = output.argmax(dim=1).item()

print(f"Prediction: {'REAL' if prediction == 0 else 'FAKE'}")
print(f"Confidence: {probabilities[0, prediction].item():.2%}")
```

## Tips for Better Results

1. **Use GPU Runtime**: Runtime → Change runtime type → GPU (T4)
2. **Monitor Training**: Watch for overfitting (train acc >> val acc)
3. **Adjust Learning Rate**: If loss plateaus early, reduce LR to 0.0001
4. **Increase Epochs**: For better accuracy, train for 30+ epochs
5. **Save Checkpoints**: Model is auto-saved to Google Drive
6. **Check GPU Memory**: Reduce batch size if OOM errors occur
7. **Class Imbalance**: Dataset has 4x more FAKE than REAL - this is expected

## Troubleshooting

- **Out of Memory**: Reduce BATCH_SIZE to 16 or 8
- **Slow Training**: Make sure GPU is enabled
- **Low Accuracy**: Train for more epochs or adjust learning rate
- **Files Not Found**: Verify DATASET_PATH points to your Google Drive folder

## Next Steps

1. Download the trained model from Google Drive
2. Use the inference script locally (see Section 10, "Using the Model Locally")
3. Fine-tune on specific deepfake types
4. Deploy as a web API
5. Optimize model for mobile/edge devices