In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# **Model Training**

In [20]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms, datasets
from torch.utils.data import DataLoader
import os
from tqdm import tqdm
from datetime import datetime
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, classification_report
import numpy as np

class DeepFakeDetector:
    """Fine-tune an ImageNet-pretrained ResNet-50 as an image classifier.

    The constructor builds the transforms, splits the images found under
    ``data_dir`` into 80% training / 20% validation, and prepares the model,
    loss, optimizer and learning-rate scheduler.

    Args:
        data_dir: Root folder in ``torchvision.datasets.ImageFolder`` layout
            (one sub-folder per class).
        batch_size: Batch size used by both data loaders.
    """

    def __init__(self, data_dir, batch_size=32):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        self.data_dir = data_dir
        self.batch_size = batch_size
        self.setup_data_transforms()
        self.setup_datasets()
        self.setup_model()

    def setup_data_transforms(self):
        """Create the augmented training pipeline and the plain validation pipeline."""
        self.data_transforms = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.RandomRotation(15),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
            transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        # No random augmentation here: validation and inference must be deterministic.
        self.val_transforms = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    @staticmethod
    def _split_indices(total, train_fraction=0.8):
        """Return shuffled, disjoint ``(train, val)`` index lists covering ``range(total)``."""
        train_size = int(train_fraction * total)
        shuffled = torch.randperm(total).tolist()
        return shuffled[:train_size], shuffled[train_size:]

    def setup_datasets(self):
        """Scan ``data_dir`` and build the train/validation subsets and loaders.

        The two subsets must not share one dataset object: assigning the
        validation transform to a shared ``ImageFolder`` would also replace
        the training transform and silently disable augmentation. Each subset
        therefore gets its own dataset view with its own transform.
        """
        import copy  # only needed here

        train_base = datasets.ImageFolder(self.data_dir, transform=self.data_transforms)
        # A shallow copy reuses the already-scanned file list (no second
        # directory walk) while carrying an independent ``transform``.
        val_base = copy.copy(train_base)
        val_base.transform = self.val_transforms

        train_indices, val_indices = self._split_indices(len(train_base))
        self.train_dataset = torch.utils.data.Subset(train_base, train_indices)
        self.val_dataset = torch.utils.data.Subset(val_base, val_indices)

        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=2,
            pin_memory=True
        )

        self.val_loader = DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=2,
            pin_memory=True
        )

        self.classes = train_base.classes
        print(f"Classes: {self.classes}")
        print(f"Total images: {len(train_base)}")
        print(f"Training images: {len(self.train_dataset)}")
        print(f"Validation images: {len(self.val_dataset)}")

    def setup_model(self):
        """Build the network, loss, optimizer and LR scheduler."""
        weights = models.ResNet50_Weights.DEFAULT
        self.model = models.resnet50(weights=weights)

        # Freeze every pretrained parameter tensor except the last four.
        # NOTE(review): for torchvision's ResNet-50 those four are presumably
        # the final block's last BatchNorm weight/bias plus the original fc
        # weight/bias (the latter are discarded below) - confirm this is the
        # intended amount of fine-tuning.
        for param in list(self.model.parameters())[:-4]:
            param.requires_grad = False

        # Replace the ImageNet head with a small MLP ending in two logits.
        num_ftrs = self.model.fc.in_features
        self.model.fc = nn.Sequential(
            nn.Linear(num_ftrs, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 2)
        )

        self.model = self.model.to(self.device)

        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.AdamW(self.model.parameters(), lr=0.001, weight_decay=0.01)
        # ``verbose`` is deprecated on PyTorch schedulers, so it is not passed;
        # ``train`` prints a message itself whenever the learning rate drops.
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='min', patience=3, factor=0.1
        )

    def train_one_epoch(self, epoch):
        """Run one optimisation pass over the training loader.

        Args:
            epoch: Zero-based epoch index (used only for the progress-bar label).

        Returns:
            ``(mean_loss, accuracy_percent)`` averaged over all training
            samples seen; ``(0.0, 0.0)`` if the loader yields nothing.
        """
        self.model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        pbar = tqdm(self.train_loader, desc=f'Epoch {epoch+1}')
        for inputs, labels in pbar:
            inputs, labels = inputs.to(self.device), labels.to(self.device)

            self.optimizer.zero_grad()
            outputs = self.model(inputs)
            loss = self.criterion(outputs, labels)
            loss.backward()
            self.optimizer.step()

            # Weight by batch size so a short final batch is not over-counted.
            batch_count = labels.size(0)
            running_loss += loss.item() * batch_count
            predicted = outputs.argmax(dim=1)
            total += batch_count
            correct += (predicted == labels).sum().item()

            pbar.set_postfix({'loss': loss.item(), 'acc': 100. * correct / total})

        if total == 0:
            return 0.0, 0.0
        return running_loss / total, 100. * correct / total

    def validate(self):
        """Evaluate the model on the validation loader without gradients.

        Returns:
            ``(mean_loss, accuracy_percent, predictions, labels)`` where the
            last two are flat lists of class indices in loader order.
        """
        self.model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        all_preds = []
        all_labels = []

        with torch.no_grad():
            for inputs, labels in tqdm(self.val_loader, desc='Validation'):
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                outputs = self.model(inputs)
                loss = self.criterion(outputs, labels)

                batch_count = labels.size(0)
                val_loss += loss.item() * batch_count
                predicted = outputs.argmax(dim=1)
                total += batch_count
                correct += (predicted == labels).sum().item()

                all_preds.extend(predicted.cpu().tolist())
                all_labels.extend(labels.cpu().tolist())

        if total == 0:
            return 0.0, 0.0, all_preds, all_labels
        return (val_loss / total,
                100. * correct / total,
                all_preds,
                all_labels)

    def _report_final_metrics(self, all_labels, all_preds):
        """Print the classification report and show the confusion matrix."""
        # Explicit label ids keep the report aligned with ``self.classes``
        # even if one class never appears in the validation split.
        label_ids = list(range(len(self.classes)))

        print("\nClassification Report:")
        print(classification_report(
            all_labels, all_preds, labels=label_ids, target_names=self.classes
        ))

        cm = confusion_matrix(all_labels, all_preds, labels=label_ids)
        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=self.classes, yticklabels=self.classes)
        plt.title('Confusion Matrix')
        plt.xlabel('Predicted')
        plt.ylabel('True')
        plt.show()

    @staticmethod
    def _plot_history(train_losses, val_losses, train_accs, val_accs):
        """Plot per-epoch loss and accuracy curves side by side."""
        plt.figure(figsize=(12, 4))

        plt.subplot(1, 2, 1)
        plt.plot(train_losses, label='Train')
        plt.plot(val_losses, label='Validation')
        plt.title('Loss over time')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()

        plt.subplot(1, 2, 2)
        plt.plot(train_accs, label='Train')
        plt.plot(val_accs, label='Validation')
        plt.title('Accuracy over time')
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy (%)')
        plt.legend()

        plt.tight_layout()
        plt.show()

    def train(self, num_epochs=10):
        """Train for ``num_epochs`` epochs, checkpointing the best model.

        After every epoch the scheduler is stepped on the validation loss and,
        when validation accuracy improves, a checkpoint is written to
        ``best_deepfake_detector.pth`` in the working directory. The last
        epoch also prints a classification report and shows a confusion
        matrix; loss/accuracy curves are plotted at the end.

        Args:
            num_epochs: Number of epochs to run.

        Returns:
            Dict with per-epoch lists ``train_loss``, ``train_acc``,
            ``val_loss`` and ``val_acc``.
        """
        best_val_acc = 0.0
        train_losses, train_accs = [], []
        val_losses, val_accs = [], []

        for epoch in range(num_epochs):
            train_loss, train_acc = self.train_one_epoch(epoch)
            train_losses.append(train_loss)
            train_accs.append(train_acc)

            val_loss, val_acc, all_preds, all_labels = self.validate()
            val_losses.append(val_loss)
            val_accs.append(val_acc)

            print(f'\nEpoch {epoch+1}/{num_epochs}:')
            print(f'Training Loss: {train_loss:.4f} Acc: {train_acc:.2f}%')
            print(f'Validation Loss: {val_loss:.4f} Acc: {val_acc:.2f}%')

            # Step the plateau scheduler and report a reduction ourselves.
            previous_lr = self.optimizer.param_groups[0]['lr']
            self.scheduler.step(val_loss)
            current_lr = self.optimizer.param_groups[0]['lr']
            if current_lr != previous_lr:
                print(f'Learning rate reduced from {previous_lr:.2e} to {current_lr:.2e}')

            if val_acc > best_val_acc:
                best_val_acc = val_acc
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'val_acc': val_acc,
                }, 'best_deepfake_detector.pth')

            if epoch == num_epochs - 1:
                self._report_final_metrics(all_labels, all_preds)

        self._plot_history(train_losses, val_losses, train_accs, val_accs)

        return {
            'train_loss': train_losses,
            'train_acc': train_accs,
            'val_loss': val_losses,
            'val_acc': val_accs,
        }

    def predict_image(self, image_path):
        """Classify a single image file.

        Args:
            image_path: Path to an image readable by torchvision's default loader.

        Returns:
            Dict with ``class`` (predicted class name), ``confidence``
            (softmax probability of that class) and ``probabilities``
            (class name -> softmax probability).
        """
        self.model.eval()
        image = datasets.folder.default_loader(image_path)
        batch = self.val_transforms(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            logits = self.model(batch)
            probabilities = torch.nn.functional.softmax(logits, dim=1)[0]

        # Plain int so it can index the Python list of class names directly.
        predicted_index = int(probabilities.argmax().item())

        return {
            'class': self.classes[predicted_index],
            'confidence': probabilities[predicted_index].item(),
            'probabilities': {
                cls: prob.item()
                for cls, prob in zip(self.classes, probabilities)
            }
        }

# Usage example:
if __name__ == "__main__":
    # Script entry point: build the detector, train it, then classify one sample image.
    data_dir = "/content/drive/MyDrive/deepfake-and-real-images/Dataset/Test"
    sample_image = '/content/drive/MyDrive/deepfake-and-real-images/Dataset/Test/Fake/fake_0.jpg'

    detector = DeepFakeDetector(data_dir)
    detector.train(num_epochs=10)

    # Show the predicted class, its confidence, and the full per-class distribution.
    result = detector.predict_image(sample_image)
    print(f"Prediction: {result['class']}")
    print(f"Confidence: {result['confidence']:.2f}")
    print("Class probabilities:", result['probabilities'])

Writing deepfake_detector.py


# **Gradio**

In [26]:
import gradio as gr
import torch
from PIL import Image
import os
from deepfake_detector import DeepFakeDetector

class GradioDeepFakeDetector:
    """Adapter that exposes a ``DeepFakeDetector`` to a Gradio interface.

    Args:
        model_path: Checkpoint file containing a ``model_state_dict`` entry.
            If the file does not exist a notice is printed and the model keeps
            the weights ``DeepFakeDetector`` initialised it with.
        data_dir: Dataset root forwarded to ``DeepFakeDetector``.
    """

    # Status line shown while no image is loaded.
    IDLE_STATUS = "### Status: Waiting for image..."

    def __init__(self, model_path, data_dir):
        self.detector = DeepFakeDetector(data_dir)

        if os.path.exists(model_path):
            checkpoint = torch.load(model_path, map_location=self.detector.device, weights_only=True)
            self.detector.model.load_state_dict(checkpoint['model_state_dict'])
            print(f"Loaded model from {model_path}")
        else:
            print("No pre-trained model found. Please train the model first.")

        self.detector.model.eval()

    def predict(self, image):
        """Classify an uploaded image and format the result for the UI.

        Args:
            image: PIL image supplied by the Gradio image component, or
                ``None`` when nothing has been uploaded.

        Returns:
            A 5-tuple ``(results_markdown, fake_probability, real_probability,
            analysed_image, status_message)``. Both probabilities are floats
            in ``[0, 1]`` on every path so the slider outputs always receive
            a number.
        """
        if image is None:
            return "Please upload an image to analyze.", 0.0, 0.0, None, self.IDLE_STATUS

        import tempfile  # only needed here

        # A unique file per request avoids concurrent users overwriting each
        # other's upload. PNG is lossless, so the classifier sees the pixels
        # as uploaded rather than a freshly JPEG-compressed copy.
        fd, temp_path = tempfile.mkstemp(suffix=".png")
        os.close(fd)

        try:
            # Converting to RGB inside the ``try`` means palette/alpha images
            # are handled and any save failure is reported through the UI.
            image.convert("RGB").save(temp_path, format="PNG")

            result = self.detector.predict_image(temp_path)

            prediction = result['class']
            confidence = result['confidence'] * 100

            # Probabilities stay in [0, 1] for the slider components.
            fake_prob = result['probabilities'].get('Fake', 0)
            real_prob = result['probabilities'].get('Real', 0)

            status = "⚠️ LIKELY FAKE" if prediction == "Fake" else "✅ LIKELY REAL"

            response = f"### {status}\n\n"
            response += f"**Overall Confidence:** {confidence:.1f}%\n\n"

            return (
                response,
                fake_prob,
                real_prob,
                image,
                f"Analysis complete - {prediction} ({confidence:.1f}% confidence)"
            )

        except Exception as e:
            # UI boundary: surface the failure to the user instead of crashing the app.
            return f"Error processing image: {str(e)}", 0.0, 0.0, None, "Error occurred during analysis"
        finally:
            if os.path.exists(temp_path):
                os.remove(temp_path)

def create_gradio_interface(
    model_path="best_deepfake_detector.pth",
    data_dir="/content/drive/MyDrive/deepfake-and-real-images/Dataset/Test",
    example_paths=None,
):
    """Build the Gradio Blocks UI for the deepfake detector.

    Args:
        model_path: Checkpoint passed to ``GradioDeepFakeDetector``.
        data_dir: Dataset root passed to ``GradioDeepFakeDetector``.
        example_paths: Optional list of image paths offered as clickable
            examples. Defaults to two images from the Drive dataset. Paths
            that do not exist are skipped, and the examples section is
            omitted entirely when none exist.

    Returns:
        The assembled ``gr.Blocks`` interface (not yet launched).
    """
    detector = GradioDeepFakeDetector(
        model_path=model_path,
        data_dir=data_dir
    )

    if example_paths is None:
        example_paths = [
            "/content/drive/MyDrive/deepfake-and-real-images/Dataset/Test/Fake/fake_7.jpg",
            "/content/drive/MyDrive/deepfake-and-real-images/Dataset/Test/Fake/fake_189.jpg",
        ]
    # Only offer examples that are really on disk; a missing file would
    # otherwise break the examples widget.
    available_examples = [path for path in example_paths if os.path.exists(path)]

    idle_status = "### Status: Waiting for image..."

    with gr.Blocks(
        title="DeepFake Image Detector",
        theme=gr.themes.Soft(
            primary_hue="blue",
            secondary_hue="gray",
        ),
    ) as interface:
        gr.Markdown("""
        # 🔍 DeepFake Image Detector

        Upload an image to check if it's authentic or artificially generated.
        Our AI model will analyze the image and provide detailed authenticity scores.

        ---
        """)

        with gr.Row():
            # Left column - input image, action buttons and status line
            with gr.Column(scale=1):
                input_image = gr.Image(
                    label="📤 Upload Image for Analysis",
                    type="pil",
                    height=400,
                    show_label=True,
                    container=True,
                )

                with gr.Row():
                    submit_btn = gr.Button(
                        "🔍 Analyze Image",
                        variant="primary",
                        size="lg"
                    )
                    clear_btn = gr.Button(
                        "🗑️ Clear",
                        variant="secondary",
                        size="lg"
                    )

                status_text = gr.Markdown(
                    idle_status,
                    elem_id="status_display"
                )

            # Right column - analysed image, verdict text and probability bars
            with gr.Column(scale=1):
                output_image = gr.Image(
                    label="Analyzed Image",
                    type="pil",
                    height=300,
                    visible=True
                )

                results_text = gr.Markdown(
                    label="Analysis Results",
                )

                gr.Markdown("### Probability Distribution")

                with gr.Row():
                    with gr.Column():
                        fake_prob = gr.Slider(
                            label="Fake Probability",
                            minimum=0,
                            maximum=1,
                            value=0,
                            interactive=False,
                            info="Likelihood of being artificially generated"
                        )

                    with gr.Column():
                        real_prob = gr.Slider(
                            label="Real Probability",
                            minimum=0,
                            maximum=1,
                            value=0,
                            interactive=False,
                            info="Likelihood of being authentic"
                        )

        # Show the heading only when there is something to list under it.
        if available_examples:
            gr.Markdown("### Try with Example Images")
            gr.Examples(
                examples=available_examples,
                inputs=input_image,
                label="Example Images",
                examples_per_page=4
            )

        # Event handlers
        submit_btn.click(
            fn=detector.predict,
            inputs=[input_image],
            outputs=[results_text, fake_prob, real_prob, output_image, status_text]
        )

        # Clear resets every output, including the verdict text, so no stale
        # result stays on screen next to an empty input.
        clear_btn.click(
            lambda: [None, "", 0.0, 0.0, None, idle_status],
            inputs=[],
            outputs=[input_image, results_text, fake_prob, real_prob, output_image, status_text]
        )

    return interface

if __name__ == "__main__":
    # Build the UI, then serve it with a public share link and error display enabled.
    launch_options = {"share": True, "show_error": True}
    demo = create_gradio_interface()
    demo.launch(**launch_options)

Using device: cuda
Classes: ['Fake', 'Real']
Total images: 10905
Training images: 8724
Validation images: 2181
Loaded model from best_deepfake_detector.pth
Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://57fb481ab93597df32.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)
