<a href="https://colab.research.google.com/github/arunbalu2002/Deep_fake_detection/blob/main/deepfake_mamba_server_(1).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install mamba-ssm
!pip install gradio

Collecting mamba-ssm
  Downloading mamba_ssm-2.2.4.tar.gz (91 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/91.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m91.8/91.8 kB[0m [31m8.0 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting ninja (from mamba-ssm)
  Using cached ninja-1.11.1.4-py3-none-manylinux_2_12_x86_64.manylinux2010_x86_64.whl.metadata (5.0 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch->mamba-ssm)
  Using cached nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch->mamba-ssm)
  Using cached nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (fr

In [None]:
# Mount Google Drive so the checkpoint paths under /content/drive/MyDrive are readable.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Work from the Drive root. NOTE(review): all paths used below are absolute, so this cd looks optional.
%cd /content/drive/MyDrive

/content/drive/MyDrive


In [None]:
import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
import timm
import torch.nn as nn
from mamba_ssm.models.mixer_seq_simple import Mamba


# MetricsTracker is defined here even though inference never calls it.
# NOTE(review): presumably the training checkpoint pickles a MetricsTracker, and
# torch.load(weights_only=False) needs this class in __main__ to unpickle it — confirm.
class MetricsTracker:
    """Per-epoch training/validation history plus best-accuracy bookkeeping."""

    def __init__(self):
        # Per-epoch histories, extended by update().
        self.train_losses, self.train_accuracies = [], []
        self.val_losses, self.val_accuracies = [], []
        # Early-stopping state.
        self.best_val_acc = 0
        self.epochs_without_improvement = 0

    def update(self, train_loss, train_acc, val_loss, val_acc):
        """Record one epoch of metrics.

        Returns:
            True when ``val_acc`` strictly exceeds the best value seen so far
            (which also resets the no-improvement counter); otherwise the
            counter is incremented and False is returned.
        """
        for history, value in (
            (self.train_losses, train_loss),
            (self.train_accuracies, train_acc),
            (self.val_losses, val_loss),
            (self.val_accuracies, val_acc),
        ):
            history.append(value)

        if not (val_acc > self.best_val_acc):
            self.epochs_without_improvement += 1
            return False

        self.best_val_acc = val_acc
        self.epochs_without_improvement = 0
        return True

class VisionMambaClassifier(nn.Module):
    """EfficientNet-B0 features followed by a stack of residual Mamba blocks.

    The backbone's last feature map is projected to ``dim`` channels with a 1x1
    conv, flattened into a token sequence (one token per spatial cell), given a
    learned positional embedding, passed through ``depth`` pre-norm residual
    Mamba blocks, mean-pooled over tokens and mapped by an MLP head to 2 logits
    (predict_image reads index 0 as "Fake" and index 1 as "Real").

    Args:
        image_size: Side length of the square input the model is built for. The
            patch-embedding LayerNorm is shaped for the resulting feature map, so
            inputs producing a different feature-map size fail at that layer.
        patch_size: Stored as ``self.patch_size``; not used by forward().
        dim: Token embedding width.
        depth: Number of Mamba blocks.
    """

    def __init__(self, image_size=224, patch_size=16, dim=768, depth=24):
        super().__init__()

        # Pretrained EfficientNet-B0 that returns a list of intermediate feature maps.
        self.backbone = timm.create_model('efficientnet_b0', pretrained=True, features_only=True)
        backbone_channels = self.backbone.feature_info[-1]['num_chs']

        # Probe the backbone once to learn the side length of its last feature map.
        # Run the probe in eval mode: a train-mode forward would fold this all-zeros
        # batch into the pretrained BatchNorm running statistics.
        was_training = self.backbone.training
        self.backbone.eval()
        with torch.no_grad():
            dummy_input = torch.zeros(1, 3, image_size, image_size)
            feature_size = self.backbone(dummy_input)[-1].shape[-1]  # expected 7 for 224x224
        self.backbone.train(was_training)

        self.patch_size = patch_size
        self.num_patches = feature_size * feature_size

        # 1x1 conv to `dim` channels; the LayerNorm normalises jointly over
        # channels and the (feature_size x feature_size) grid.
        self.patch_embed = nn.Sequential(
            nn.Conv2d(backbone_channels, dim, kernel_size=1),
            nn.LayerNorm([dim, feature_size, feature_size]),
            nn.GELU()
        )

        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, dim))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

        self.dropout = nn.Dropout(0.2)

        # Each block is LayerNorm -> Mamba -> Dropout; forward() adds the residual.
        # NOTE(review): the causal-conv1d CUDA fast path only supports d_conv in 2..4;
        # with d_conv=8 this presumably relies on Mamba's plain Conv1d fallback —
        # confirm. Changing it would invalidate existing checkpoints.
        self.blocks = nn.ModuleList([
            nn.Sequential(
                nn.LayerNorm(dim),
                Mamba(
                    d_model=dim,
                    d_state=32,
                    d_conv=8,
                    expand=4
                ),
                nn.Dropout(0.2)
            )
            for _ in range(depth)
        ])

        self.norm = nn.LayerNorm(dim)
        # Two-stage MLP head: dim -> dim -> dim // 2 -> 2 logits.
        self.head = nn.Sequential(
            nn.Linear(dim, dim),
            nn.LayerNorm(dim),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(dim, dim // 2),
            nn.LayerNorm(dim // 2),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(dim // 2, 2)
        )

    def forward(self, x):
        """Return class logits for a batch of images.

        Args:
            x: Image batch; assumed (B, 3, image_size, image_size) as in the
                construction-time probe — TODO confirm for other callers.

        Returns:
            Logits of shape (B, 2).
        """
        # Last backbone feature map -> (B, dim, h, w) grid of token embeddings.
        features = self.backbone(x)[-1]
        x = self.patch_embed(features)
        # (B, dim, h, w) -> (B, h*w, dim): one token per spatial cell.
        x = x.flatten(2).transpose(1, 2)

        # patch_embed's LayerNorm pins the spatial size, so the lengths normally
        # match; the slice is a defensive fallback for shorter sequences.
        if x.size(1) != self.pos_embed.size(1):
            x = x + self.pos_embed[:, :x.size(1), :]
        else:
            x = x + self.pos_embed

        x = self.dropout(x)

        # Pre-norm residual blocks: x + Dropout(Mamba(LayerNorm(x))).
        for block in self.blocks:
            x = x + block(x)

        # Mean-pool over tokens, normalise, classify.
        x = x.mean(dim=1)
        x = self.norm(x)
        return self.head(x)

# Module-level state shared by load_model() and the prediction helpers.
model = None  # populated lazily by load_model()
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Preprocessing: resize to the 224x224 input the classifier is built for, convert
# to a tensor, then normalise with the standard ImageNet channel mean/std.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

def load_model(model_path='/content/drive/MyDrive/Mamba/best_model.pth'):
    """Load the classifier checkpoint into the module-level ``model`` (once).

    The global is assigned only after the weights load successfully. A failed
    attempt therefore leaves ``model`` as None, so a later call can retry
    (previously an untrained, train-mode model stayed behind and was used for
    predictions).

    Args:
        model_path: Checkpoint written by ``torch.save``: either a dict with a
            ``'model_state_dict'`` entry or a bare state dict.
            NOTE(review): this default (``Mamba/``) differs from the
            ``vision_mamba/`` path used by predict_image and the UI textbox.
            The callers in this file always pass a path explicitly.

    Returns:
        A status string: "Model loaded successfully!", "Model already loaded",
        or "Error loading model: ...". predict_image checks for the substring
        "Error".
    """
    global model
    if model is not None:
        return "Model already loaded"

    try:
        candidate = VisionMambaClassifier()
        # SECURITY: weights_only=False unpickles arbitrary objects (needed when the
        # checkpoint embeds non-tensor objects). Only load trusted files: this path
        # can come from a textbox on a publicly shared Gradio app.
        checkpoint = torch.load(model_path, map_location=device, weights_only=False)
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint
        candidate.load_state_dict(state_dict)
        candidate = candidate.to(device)
        candidate.eval()
    except Exception as e:
        return f"Error loading model: {str(e)}"

    model = candidate
    return "Model loaded successfully!"

def predict_image(image, model_path='/content/drive/MyDrive/vision_mamba/best_model.pth'):
    """Classify a single image as "Real" or "Fake" using the module-level model.

    Args:
        image: A PIL image, or anything ``Image.fromarray`` accepts (e.g. a NumPy
            array). It is converted to RGB before preprocessing.
        model_path: Checkpoint passed to load_model() if no model is loaded yet.

    Returns:
        ``(message, label, fake_prob, real_prob)``. ``label`` is "Real" or
        "Fake", and the probabilities are softmax scores. On any failure,
        ``message`` describes the problem and the other three entries are None.
    """
    if image is None:
        return "No image provided", None, None, None

    # Lazily load the model on first use.
    global model
    if model is None:
        status = load_model(model_path)
        if "Error" in status:
            return status, None, None, None

    try:
        pil_image = image if isinstance(image, Image.Image) else Image.fromarray(image)
        batch = transform(pil_image.convert('RGB')).unsqueeze(0).to(device)

        with torch.no_grad():
            logits = model(batch)
            scores = torch.softmax(logits, dim=1)
            winner = torch.argmax(logits, dim=1)

        # Class index 1 is "Real"; anything else is "Fake".
        label = 'Fake' if winner.item() != 1 else 'Real'
        probs = scores[0].cpu().numpy()
        fake_prob, real_prob = float(probs[0]), float(probs[1])

        report = "\n".join([
            f"Prediction: {label} (Confidence: {max(probs) * 100:.2f}%)",
            f"Fake probability: {fake_prob * 100:.2f}%",
            f"Real probability: {real_prob * 100:.2f}%",
        ])
        return report, label, fake_prob, real_prob

    except Exception as e:
        return f"Error during prediction: {str(e)}", None, None, None

def update_plot(fake_prob, real_prob):
    """Build a horizontal bar chart of the Fake/Real confidence scores.

    Args:
        fake_prob: Probability of the "Fake" class (0-1), or None.
        real_prob: Probability of the "Real" class (0-1), or None.

    Returns:
        A ``matplotlib.figure.Figure`` suitable for ``gr.Plot``, or None when
        either probability is missing (e.g. the prediction failed).
    """
    if fake_prob is None or real_prob is None:
        return None

    # Use the object-oriented Figure API instead of pyplot. A pyplot figure is
    # registered globally and never closed here, so a long-running server would
    # accumulate one open figure per request.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(6, 3))
    ax = fig.subplots()
    categories = ['Fake', 'Real']
    values = [fake_prob * 100, real_prob * 100]
    # Both bars take the verdict colour: red-orange when Fake wins, green when
    # Real wins. On an exact tie, Fake is green and Real is red-orange.
    colors = ['#FF5733' if values[0] > values[1] else '#33FF57',
              '#33FF57' if values[1] > values[0] else '#FF5733']

    ax.barh(categories, values, color=colors)
    ax.set_xlim(0, 100)
    ax.set_xlabel('Confidence (%)')
    ax.set_title('Prediction confidence')

    # Label each bar with its percentage just past the bar's end.
    for i, v in enumerate(values):
        ax.text(v + 1, i, f"{v:.1f}%", va='center')

    fig.tight_layout()
    return fig

def create_interface():
    """Create the Gradio Blocks UI for deepfake detection (does not launch it).

    Layout: an "Upload Image" tab and a webcam-only "Camera" tab feed a single
    "Analyze Image" button. A textbox selects the checkpoint, and "Load Model"
    calls load_model(). Results are shown as text, a label, the two class
    probabilities and a bar chart.

    Returns:
        The ``gr.Blocks`` app; call ``.launch()`` on it to serve.
    """
    import inspect

    def _webcam_image(**kwargs):
        # Gradio >= 4 chooses input sources with ``sources=[...]``, while Gradio 3.x
        # used ``source="webcam"``. Use whichever the installed version supports.
        if "sources" in inspect.signature(gr.Image.__init__).parameters:
            return gr.Image(sources=["webcam"], **kwargs)
        return gr.Image(source="webcam", **kwargs)

    with gr.Blocks(title="Deepfake Detection") as interface:
        gr.Markdown("# Deepfake Detection")
        gr.Markdown("Upload an image or take a picture with your camera to check if it's real or AI-generated")

        with gr.Row():
            with gr.Column(scale=1):
                # Input methods
                with gr.Tab("Upload Image"):
                    input_image = gr.Image(type="pil", label="Upload Image")

                with gr.Tab("Camera"):
                    # Webcam-only input. The component renders its own capture
                    # control, so no JavaScript-driven trigger button is needed.
                    camera_input = _webcam_image(label="Take Photo", type="pil")

                # Checkpoint path handed to load_model()/predict_image()
                model_path = gr.Textbox(label="Model Path", value="/content/drive/MyDrive/vision_mamba/best_model.pth")

                # Buttons
                with gr.Row():
                    load_button = gr.Button("Load Model")
                    analyze_button = gr.Button("Analyze Image", variant="primary")

            with gr.Column(scale=1):
                # Results display
                result_text = gr.Textbox(label="Results", lines=5)

                with gr.Row():
                    prediction_label = gr.Label(label="Prediction")

                with gr.Row():
                    fake_score = gr.Number(label="Fake Score", value=0, interactive=False)
                    real_score = gr.Number(label="Real Score", value=0, interactive=False)

                # Bar chart of the two class probabilities
                with gr.Row():
                    gr.Markdown("### Confidence Visualization")
                    fake_bar = gr.Plot(label="Confidence Scores")

        def process_input(image, camera_img, model_path):
            """Classify the uploaded image, or the webcam capture if nothing was uploaded."""
            img_to_process = image if image is not None else camera_img
            if img_to_process is None:
                return "Please provide an image through upload or camera", None, None, None, None

            result, pred_class, fake_prob, real_prob = predict_image(img_to_process, model_path)

            # update_plot returns None when the prediction failed.
            plot = update_plot(fake_prob, real_prob)

            # gr.Label expects {label: confidence}
            label_output = {pred_class: max(fake_prob, real_prob)} if pred_class else None

            return result, label_output, fake_prob, real_prob, plot

        # Connect functions to events
        load_button.click(fn=load_model, inputs=[model_path], outputs=[result_text])

        analyze_button.click(
            fn=process_input,
            inputs=[input_image, camera_input, model_path],
            outputs=[result_text, prediction_label, fake_score, real_score, fake_bar]
        )

    return interface

# Launch the interface if run directly. A notebook kernel executes cells in
# __main__, so this runs when the cell is executed.
if __name__ == "__main__":
    interface = create_interface()
    # share=True also publishes a temporary public *.gradio.live URL (see the output below).
    # NOTE(review): anyone with that link can type a path into "Model Path", which is passed
    # to torch.load(weights_only=False) — consider share=False or hiding that textbox when sharing.
    interface.launch(share=True)

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://ddac3403b8d8a65d88.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)
