# Deepfake Detection Training Pipeline
## Command Center for Google Colab Pro

This notebook orchestrates the complete training pipeline:
1. Environment setup (gcsfuse, dependencies)
2. Data loading and splitting
3. Model training with best checkpoint saving
4. Evaluation and persistence

## 1. Install Dependencies

In [None]:
# Install requirements
!pip install -q -r ../requirements.txt

## 2. Mount Google Cloud Storage with gcsfuse

In [None]:
# Install gcsfuse
!echo "deb https://packages.cloud.google.com/apt gcsfuse-focal main" | sudo tee /etc/apt/sources.list.d/gcsfuse.list
!curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -
!sudo apt-get update
!sudo apt-get install -y gcsfuse

In [None]:
# Authenticate with Google Cloud
from google.colab import auth
auth.authenticate_user()

# Configure project
PROJECT_ID = "your-project-id"  # REPLACE WITH YOUR PROJECT ID
BUCKET_NAME = "your-bucket-name"  # REPLACE WITH YOUR BUCKET NAME

!gcloud config set project {PROJECT_ID}

In [None]:
# Create mount point and mount bucket
!mkdir -p /content/gcs_data
!gcsfuse --implicit-dirs {BUCKET_NAME} /content/gcs_data

# Verify mount
!ls -lh /content/gcs_data

## 3. Import Modules

In [None]:
import os
import random
import sys

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# Add parent directory to path for imports
sys.path.insert(0, os.path.abspath('..'))

# Import our custom modules
from src.data_loader import get_data_mixed_structure, LocalImageDataset, save_splits, load_splits
from src.preprocessing import get_transforms
from src.models import get_model
from src.trainer import main_training_loop, test_model
from configs.config import Config

print("All modules imported successfully!")

## 4. Configuration

In [None]:
# Create necessary directories
Config.create_directories()

# Verify GCS mount; a failed check only warns, it does not stop the notebook
if Config.validate_paths():
    print("✓ GCS mount verified")
else:
    print("⚠ Warning: GCS mount point not found. Check your gcsfuse setup.")

# Set random seeds for reproducibility.
# Seed Python's and NumPy's global generators as well as torch's, so any code
# drawing from those sources is repeatable too.
random.seed(Config.SEED)
np.random.seed(Config.SEED)
torch.manual_seed(Config.SEED)

# Get device
device = Config.get_device()
print(f"Using device: {device}")

# Display configuration
print("\nConfiguration:")
for key, value in Config.get_config_dict().items():
    print(f"  {key}: {value}")

## 5. Model Selection

In [None]:
# Architecture to train.
# Options: resnet34, resnet50, efficientnet_b0, efficientnet_b4, vit_b_16, vit_b_32
MODEL_NAME = 'resnet34'

# Model type used when selecting preprocessing: any name containing 'vit'
# is treated as a ViT, everything else as a CNN.
if 'vit' in MODEL_NAME:
    MODEL_TYPE = 'vit'
else:
    MODEL_TYPE = 'cnn'

print(f"Selected model: {MODEL_NAME}", f"Model type: {MODEL_TYPE}", sep="\n")

## 6. Data Loading and Splitting

In [None]:
# True  -> Option 1: create new splits and save them
# False -> Option 2: load the splits saved by an earlier run
CREATE_NEW_SPLITS = True

# Both options use the same per-model .pkl file
splits_path = os.path.join(Config.SPLITS_DIR, f'{MODEL_NAME}_splits.pkl')

if not CREATE_NEW_SPLITS:
    print("Loading existing splits...")
    train_data, val_data, test_data = load_splits(splits_path)
else:
    print("Creating new data splits...")

    # Each category key in Config.PATHS maps to the keyword argument "<key>_path"
    dataset_categories = (
        'celeb_real',
        'youtube_real',
        'celeb_synthesis',
        'ffhq_real',
        'stylegan_fake',
        'stablediffusion_fake',
    )
    category_paths = {
        f'{category}_path': Config.PATHS[category]
        for category in dataset_categories
    }

    train_data, val_data, test_data = get_data_mixed_structure(
        **category_paths,
        train_ratio=Config.TRAIN_RATIO,
        val_ratio=Config.VAL_RATIO,
        test_ratio=Config.TEST_RATIO,
        seed=Config.SEED,
        # max_samples_per_category=1000  # Uncomment for quick debugging
    )

    # Save splits for reproducibility
    save_splits(train_data, val_data, test_data, splits_path)

## 7. Create Datasets and DataLoaders

In [None]:
# Transforms: one for training, one shared by validation and test
train_transform, val_transform = (
    get_transforms(split=split_name, model_type=MODEL_TYPE, img_size=Config.IMG_SIZE)
    for split_name in ('train', 'val')
)

# Datasets: element 0 and element 1 of each split are passed positionally
train_dataset = LocalImageDataset(train_data[0], train_data[1], transform=train_transform)
val_dataset = LocalImageDataset(val_data[0], val_data[1], transform=val_transform)
test_dataset = LocalImageDataset(test_data[0], test_data[1], transform=val_transform)

# DataLoaders: identical settings except that only the training loader shuffles
loader_options = dict(
    batch_size=Config.BATCH_SIZE,
    num_workers=Config.NUM_WORKERS,
    pin_memory=True,
)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_options)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_options)

print(f"\nDataLoaders created:")
print(f"  Train batches: {len(train_loader)}")
print(f"  Val batches:   {len(val_loader)}")
print(f"  Test batches:  {len(test_loader)}")

## 8. Initialize Model, Loss, and Optimizer

In [None]:
# StepLR settings: scale the learning rate by LR_DECAY_FACTOR after every
# LR_DECAY_EVERY scheduler steps (presumably one step per epoch inside
# main_training_loop; confirm in src/trainer).
LR_DECAY_EVERY = 5
LR_DECAY_FACTOR = 0.1

# Network under training, built by the project's model factory
model = get_model(
    model_name=MODEL_NAME,
    num_classes=Config.NUM_CLASSES,
    pretrained=Config.PRETRAINED,
    device=device,
)

# Classification loss
criterion = nn.CrossEntropyLoss()

# Adam over all model parameters, with an optional step-wise LR schedule on top
optimizer = torch.optim.Adam(model.parameters(), lr=Config.LEARNING_RATE)
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=LR_DECAY_EVERY,
    gamma=LR_DECAY_FACTOR,
)

print("\nModel initialized and ready for training!")

## 9. Train Model

In [None]:
# Run the training loop and keep the returned history for the plots below
history = main_training_loop(
    # what to train and on which data
    model=model,
    model_name=MODEL_NAME,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    # optimisation objects
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    # run length, stopping settings and checkpoint location from Config
    num_epochs=Config.NUM_EPOCHS,
    patience=Config.PATIENCE,
    min_delta=Config.MIN_DELTA,
    checkpoint_dir=Config.CHECKPOINT_DIR,
)

## 10. Plot Training History

In [None]:
import matplotlib.pyplot as plt

# One entry per panel:
# (train key, val key, train legend, val legend, y-axis label, panel title)
panel_specs = [
    ('train_loss', 'val_loss', 'Train Loss', 'Val Loss',
     'Loss', 'Training and Validation Loss'),
    ('train_acc', 'val_acc', 'Train Acc', 'Val Acc',
     'Accuracy', 'Training and Validation Accuracy'),
]

# Loss on the left, accuracy on the right
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

for ax, spec in zip((ax1, ax2), panel_specs):
    train_key, val_key, train_label, val_label, y_label, title = spec
    ax.plot(history[train_key], label=train_label, marker='o')
    ax.plot(history[val_key], label=val_label, marker='s')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(y_label)
    ax.set_title(title)
    ax.legend()
    ax.grid(True)

# Same path is used for saving and for the confirmation message
plot_path = f'{Config.LOGS_DIR}/{MODEL_NAME}_training_history.png'

plt.tight_layout()
plt.savefig(plot_path, dpi=150)
plt.show()

print(f"Training plots saved to {plot_path}")

## 11. Load Best Model and Evaluate on Test Set

In [None]:
from src.models import load_checkpoint

# Restore weights from "<MODEL_NAME>_best.pth" in the checkpoint directory
# (presumably written by the training loop; confirm in src/trainer)
best_model_path = os.path.join(Config.CHECKPOINT_DIR, f'{MODEL_NAME}_best.pth')
model = load_checkpoint(model, best_model_path, device=device)

# Score the restored model on the held-out test loader
test_loss, test_acc = test_model(model, test_loader, criterion, device)

print(
    f"\nFinal Test Results:",
    f"  Test Loss: {test_loss:.4f}",
    f"  Test Accuracy: {test_acc:.4f}",
    sep="\n",
)

## 12. Save Model for Streamlit App

In [None]:
# Save model in simple format for inference
inference_model_path = f'{MODEL_NAME}.pth'

# Move every tensor to CPU before saving so the file also loads on machines
# without a GPU, even when the loader does not pass map_location.
# state_dict() returns a fresh mapping, so rebinding its entries does not
# touch the live model.
inference_state_dict = model.state_dict()
for param_name, tensor in list(inference_state_dict.items()):
    inference_state_dict[param_name] = tensor.cpu()

torch.save(inference_state_dict, inference_model_path)

print(f"Model saved for inference: {inference_model_path}")
print(f"Download this file to use with the Streamlit app!")

# Download to local machine (in Colab)
from google.colab import files
files.download(inference_model_path)

## 13. Cleanup (Optional)

In [None]:
# Unmount GCS bucket
import subprocess

# Run fusermount directly so the exit status can be checked; success is only
# reported when the unmount actually worked.
unmount_result = subprocess.run(
    ['fusermount', '-u', '/content/gcs_data'],
    capture_output=True,
    text=True,
)

if unmount_result.returncode == 0:
    print("GCS bucket unmounted successfully")
else:
    print(f"Failed to unmount GCS bucket: {unmount_result.stderr.strip()}")

---
## Training Complete! 🎉

Next steps:
1. Download the best model checkpoint
2. Use it with the Streamlit app for inference
3. Explore Grad-CAM visualizations in the app