# Swin-Tiny Training on Kaggle (Extracted Frames)

> This notebook trains on **extracted frame datasets** (no video decoding during training).

## Required Kaggle Datasets

1. **FF++ frames**:
   `/kaggle/input/datasets/muhammadqaiser1921/faceforenscis/ffpp_binary_frames`
   - splits: `train`, `val`, `test`
   - labels: `0=real`, `1=fake`

2. **Deepfake frames**:
   `/kaggle/input/datasets/aryansingh16/deepfake-dataset/real_vs_fake/real-vs-fake`
   - splits: `train`, `valid`, `test`
   - labels: `real`, `fake`

## Run Order

- **Cell 1**: Clone repo + verify datasets + install requirements
- **Cell 2**: Load data (RUN ONCE - no need to re-run when updating model)
- **Cell 3**: Build and train model (RE-RUN as needed after editing model code)
- **Cell 4**: View results + download outputs

## Iterative Development Workflow

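After running Cells 1–3 once, you can:
1. Edit `swin_transformer.py` in your repo
2. Commit and push changes
3. Re-run Cell 1 to pull updates
4. Re-run Cell 3 to rebuild and retrain (skip Cell 2 — the already-loaded data is reused)

> Note: Python caches imported modules for the lifetime of the kernel. If a pulled edit does not seem to take effect, restart the kernel and run all cells again.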

In [None]:
import os
import sys
import subprocess
import shutil

# ========== CONFIGURATION ==========
# GitHub coordinates of the training code. GITHUB_URL is derived from the
# username and repository name so that editing either of them actually
# changes what gets cloned (previously the URL was a separate literal).
GITHUB_USERNAME = "MuhammadQaiser1921"
REPO_NAME = "swin-model"
REPO_BRANCH = "main"
GITHUB_URL = f"https://github.com/{GITHUB_USERNAME}/{REPO_NAME}.git"

# Locations of the pre-extracted frame datasets; their presence is checked
# further down in this cell.
FFPP_FRAMES_ROOT = "/kaggle/input/datasets/muhammadqaiser1921/faceforenscis/ffpp_binary_frames"
DEEPFAKE_FRAMES_ROOT = (
    "/kaggle/input/datasets/aryansingh16/deepfake-dataset/real_vs_fake/real-vs-fake"
)

# ========== CLONE / UPDATE REPO ==========
# Make sure the working copy of the repository exists and is up to date:
# update it in place when it is a valid git checkout, otherwise clone it.
os.chdir('/kaggle/working')
repo_path = os.path.join('/kaggle/working', REPO_NAME)

print(f"üìå Repository: {GITHUB_URL}")
print(f"üåø Branch: {REPO_BRANCH}")
print(f"üìÅ Path: {repo_path}\n")


def _fresh_clone():
    """Remove whatever is at ``repo_path`` and clone the branch from scratch.

    WARNING: this deletes the whole directory, including the
    ``models/`` and ``results/`` folders this cell creates inside it.

    Raises:
        subprocess.CalledProcessError: if ``git clone`` fails.
    """
    # Step out of the directory before deleting it.
    os.chdir('/kaggle/working')
    shutil.rmtree(repo_path, ignore_errors=True)
    subprocess.run(["git", "clone", "-b", REPO_BRANCH, GITHUB_URL], check=True)


if not os.path.exists(repo_path):
    print(f"Cloning {GITHUB_URL}...")
    _fresh_clone()
else:
    try:
        # NOTE: after a successful in-place update the working directory
        # stays inside the repository; after a clone it is /kaggle/working.
        os.chdir(repo_path)
        # `git status` is used as a cheap "is this a usable checkout?" probe.
        result = subprocess.run(["git", "status"], capture_output=True, text=True)
        if result.returncode == 0:
            print("‚úì Using existing repo, fetching updates...")
            subprocess.run(["git", "fetch", "--all"], check=True)
            subprocess.run(["git", "checkout", REPO_BRANCH], check=True)
            subprocess.run(["git", "pull", "origin", REPO_BRANCH], check=True)
        else:
            print("‚ö†Ô∏è Invalid repo directory, removing and re-cloning...")
            _fresh_clone()
    except (OSError, subprocess.SubprocessError) as e:
        # Only failures of chdir / git itself trigger the destructive
        # fallback; a failure of the re-clone below propagates to the caller.
        os.chdir('/kaggle/working')
        print(f"‚ö†Ô∏è Error: {e}. Removing and re-cloning...")
        _fresh_clone()

print("‚úÖ Repository ready!\n")

# ========== VERIFY DATASETS ==========
def _check_path(path, label):
    """Report whether a dataset root directory is available.

    Prints a found / not-found message for ``path`` and, when it is a
    directory, up to six of its entries in sorted order.

    Args:
        path: Filesystem path expected to be a dataset root directory.
        label: Human-readable dataset name used in the printed messages.

    Returns:
        True if ``path`` is an existing directory, False otherwise.
    """
    if os.path.isdir(path):
        print(f"‚úÖ {label} found:")
        print(f"   {path}")
        # os.listdir returns entries in arbitrary order; sort so the preview
        # is the same on every run.
        print(f"   Top-level folders: {sorted(os.listdir(path))[:6]}")
        return True

    # A path that exists but is a plain file would make os.listdir raise,
    # so it is reported instead of being listed.
    reason = "is not a directory" if os.path.exists(path) else "not found"
    print(f"‚ùå {label} {reason}:")
    print(f"   {path}")
    return False

# Report the availability of both dataset roots.
for _label, _root in (
    ("FF++ frames", FFPP_FRAMES_ROOT),
    ("Deepfake frames", DEEPFAKE_FRAMES_ROOT),
):
    _check_path(_root, _label)

# ========== OUTPUT DIRECTORIES ==========
# Create the output folders inside the repository checkout.
for _parts in (
    ('models', 'checkpoints'),
    ('models', 'weights'),
    ('results', 'logs'),
):
    os.makedirs(os.path.join(repo_path, *_parts), exist_ok=True)

# ========== INSTALL REQUIREMENTS ==========
# Put the repository's src/ folder first on the import path so the following
# cells can import the project modules directly.
sys.path.insert(0, os.path.join(repo_path, 'src'))
req_file = os.path.join(repo_path, 'requirements.txt')
if not os.path.exists(req_file):
    print("‚ö†Ô∏è requirements.txt not found\n")
else:
    print("Installing requirements...")
    # Use the kernel's own interpreter so packages land in its environment.
    pip_command = [sys.executable, "-m", "pip", "install", "-q", "-r", req_file]
    subprocess.run(pip_command, check=True)
    print("‚úÖ Requirements installed\n")

In [None]:
import tensorflow as tf
from video_train_config import Config
from data_loader import load_data, prepare_datasets

# Step banner.
_rule = "=" * 70
print(_rule, "üöÄ STEP 1: LOAD DATA (RUN ONCE)", _rule, sep="\n")

# GPU check: report how many GPUs TensorFlow can see and list each device.
gpus = tf.config.list_physical_devices('GPU')
print(f"\nüíª GPU available: {len(gpus)} GPU(s)")
if gpus:
    print("\n".join(f"   ‚úì {gpu}" for gpu in gpus))

# Configure dataset roots
# Reuse the roots defined (and checked for existence) in Cell 1 rather than
# repeating the literals here, so the paths that were verified are the paths
# that get trained on. Cell 1 must already have run for the project imports
# in this cell to resolve, so these names are available.
Config.FFPP_FRAMES_ROOT = FFPP_FRAMES_ROOT
Config.DEEPFAKE_FRAMES_ROOT = DEEPFAKE_FRAMES_ROOT
Config.KAGGLE_ENV = True

# Training parameters
Config.BATCH_SIZE = 12
Config.EPOCHS = 20
Config.MAX_IMAGES_PER_CLASS = None  # Set to 500 for quick test

# Echo the effective settings so they are recorded in the notebook output.
print("\n‚öôÔ∏è Configuration:")
print(f"   FF++ root: {Config.FFPP_FRAMES_ROOT}")
print(f"   Deepfake root: {Config.DEEPFAKE_FRAMES_ROOT}")
print(f"   Batch size: {Config.BATCH_SIZE}")
print(f"   Epochs: {Config.EPOCHS}")
print(f"   Max images/class: {Config.MAX_IMAGES_PER_CLASS}")

# Load data
# `data` is kept as a notebook global: Cell 3 reads its 'val_labels',
# 'test_labels' and 'test_paths' entries.
print("\nüìÇ Loading data...")
data = load_data()

# Prepare datasets
# These three datasets are also notebook globals consumed by Cell 3.
train_ds, val_ds, test_ds = prepare_datasets(data)

print("\n‚úÖ Data loading complete. You can now re-run Cell 3 to train different models without reloading data!")

In [None]:
import importlib
import sys

# Python caches imported modules for the lifetime of the kernel, so re-running
# this cell after pulling new code would otherwise keep using the previously
# imported model definition. Reload the model modules when an earlier run has
# already imported them; on the first run nothing is reloaded.
# The model module is reloaded first so that the training module re-binds to
# the fresh definitions when it is reloaded in turn.
# NOTE(review): assumes the editable model code lives in top-level modules
# named `swin_transformer` and `train_video` — confirm against the repo layout.
for _stale_name in ("swin_transformer", "train_video"):
    if _stale_name in sys.modules:
        importlib.reload(sys.modules[_stale_name])

from train_video import build_and_compile_model, train_model

print("\n" + "=" * 70)
print("üöÄ STEP 2: BUILD AND TRAIN MODEL")
print("=" * 70)

# Build and compile the model
model = build_and_compile_model()

# Train the model on the datasets prepared in Cell 2
print("\nüéØ Starting training...\n")
history = train_model(model, train_ds, val_ds)

# Evaluate on validation set
print(f"\nüìä Computing metrics on validation set...")
from train_video import compute_auc_metrics
import numpy as np


def _fmt_metric(metrics, key):
    """Return ``metrics[key]`` formatted to four decimals, or ``'N/A'``.

    ``'N/A'`` is returned when ``metrics`` is empty/None, the key is missing,
    or the stored value is not a number. Applying the ``.4f`` format directly
    to an ``'N/A'`` fallback string raises ``ValueError``.
    """
    value = metrics.get(key) if metrics else None
    try:
        return f"{value:.4f}"
    except (TypeError, ValueError):
        return "N/A"


# NOTE(review): the scores are paired with data['val_labels'] by position, so
# this assumes val_ds yields every validation sample once, in the same order
# as the labels (no shuffling, no dropped remainder) — TODO confirm.
y_val_pred_probs = model.predict(val_ds)
y_val_pred_probs = y_val_pred_probs[:, 1]  # Get probabilities for class 1 (fake)

auc_metrics = compute_auc_metrics(data['val_labels'], y_val_pred_probs)

print("\n" + "=" * 70)
print("‚úÖ TRAINING COMPLETED")
print("=" * 70)

print(f"\nüìà Training Performance:")
print(f"   Final training accuracy: {history.history['accuracy'][-1]:.4f}")
print(f"   Final validation accuracy: {history.history['val_accuracy'][-1]:.4f}")
print(f"   Best validation accuracy: {max(history.history['val_accuracy']):.4f}")

if auc_metrics:
    print("\nüìä Validation AUC Metrics:")
    print(f"   AUC-ROC: {_fmt_metric(auc_metrics, 'auc_roc')}")
    print(f"   PR-AUC: {_fmt_metric(auc_metrics, 'pr_auc')}")
    print(f"   Optimal Threshold: {_fmt_metric(auc_metrics, 'optimal_threshold')}")
    print(f"   Sensitivity: {_fmt_metric(auc_metrics, 'optimal_sensitivity')}")
    print(f"   Specificity: {_fmt_metric(auc_metrics, 'optimal_specificity')}")
    print(f"   F1-Score: {_fmt_metric(auc_metrics, 'optimal_f1')}")

# Test set evaluation (if available)
if test_ds is not None and len(data['test_paths']) > 0:
    print("\nüìä Computing test set metrics...")
    y_test_pred_probs = model.predict(test_ds)
    y_test_pred_probs = y_test_pred_probs[:, 1]
    test_auc_metrics = compute_auc_metrics(data['test_labels'], y_test_pred_probs)
    # _fmt_metric tolerates an empty/None result, which the validation branch
    # above already treats as possible.
    print(f"   Test AUC-ROC: {_fmt_metric(test_auc_metrics, 'auc_roc')}")
    print(f"   Test PR-AUC: {_fmt_metric(test_auc_metrics, 'pr_auc')}")

# Save model
from datetime import datetime
import os

# A timestamp in the file name keeps successive runs from overwriting
# each other.
_timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
final_model_path = os.path.join(Config.CHECKPOINT_DIR, f"swin_tiny_final_{_timestamp}.h5")
model.save(final_model_path)
print(f"\nüíæ Model saved to: {final_model_path}")

# List the configured output directories.
print("\nüíæ Output directories:")
for _output_dir in (Config.CHECKPOINT_DIR, Config.WEIGHTS_DIR, Config.LOG_DIR):
    print(f"   ‚Ä¢ {_output_dir}")

In [None]:
import glob
import json
import os

# Section banner.
_rule = "=" * 70
print(f"\n{_rule}")
print("üì• RESULTS & DOWNLOAD")
print(_rule)

# Output locations, taken from the shared Config.
results_dir, weights_dir, checkpoints_dir = (
    Config.LOG_DIR,
    Config.WEIGHTS_DIR,
    Config.CHECKPOINT_DIR,
)

print(
    f"\nLogs: {results_dir}",
    f"Weights: {weights_dir}",
    f"Checkpoints: {checkpoints_dir}",
    sep="\n",
)

# AUC metrics
def _fmt_metric(metrics, key):
    """Return ``metrics[key]`` formatted to four decimals, or ``'N/A'``.

    ``'N/A'`` is returned when ``metrics`` is not a dict, the key is missing,
    or the stored value is not a number. Applying the ``.4f`` format directly
    to an ``'N/A'`` fallback string raises ``ValueError``.
    """
    value = metrics.get(key) if isinstance(metrics, dict) else None
    try:
        return f"{value:.4f}"
    except (TypeError, ValueError):
        return "N/A"


# The last file in name order is shown.
# NOTE(review): presumably the file names embed a sortable timestamp, making
# the last name the most recent run — confirm against the code that writes them.
auc_files = sorted(glob.glob(os.path.join(results_dir, 'auc_metrics_*.json')))
if auc_files:
    latest = auc_files[-1]
    print(f"\n‚úÖ AUC metrics: {os.path.basename(latest)}")
    with open(latest, 'r', encoding='utf-8') as f:
        metrics = json.load(f)
    print(f"   AUC-ROC: {_fmt_metric(metrics, 'auc_roc')}")
    print(f"   PR-AUC: {_fmt_metric(metrics, 'pr_auc')}")
else:
    print("\n‚ö†Ô∏è No AUC metrics found")

def _report_files(files, heading, missing_message, keep):
    """Print name and size in MB of the last ``keep`` entries of ``files``.

    Prints ``missing_message`` instead when ``files`` is empty.
    """
    if not files:
        print(missing_message)
        return
    print(heading)
    for file_path in files[-keep:]:
        megabytes = os.path.getsize(file_path) / (1024**2)
        print(f"   {os.path.basename(file_path)} ({megabytes:.1f} MB)")


# Weights: show the last three matches in name order.
weight_files = sorted(glob.glob(os.path.join(weights_dir, '*weights*.h5')))
_report_files(weight_files, "\n‚úÖ Weight files:", "\n‚ö†Ô∏è No weights found", 3)

# Checkpoints: show the last two matches in name order.
checkpoint_files = sorted(glob.glob(os.path.join(checkpoints_dir, '*.h5')))
_report_files(checkpoint_files, "\n‚úÖ Checkpoints:", "\n‚ö†Ô∏è No checkpoints found", 2)

print("\nüì• Download: Kaggle ‚Üí Output ‚Üí Download all")
print("=" * 70)