In [4]:
import os
import warnings

import cv2
import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from facenet_pytorch import MTCNN, InceptionResnetV1
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
warnings.filterwarnings("ignore")

In [5]:
# Device configuration
# Default to the CPU and switch to the first CUDA device when torch reports one.
DEVICE = 'cpu'
if torch.cuda.is_available():
    DEVICE = 'cuda:0'

In [6]:
# Initialize MTCNN and model
# Face detector: returns a single face crop per image (keep_all is left at its
# default) and, with post_process=False, does not standardise the crop.
mtcnn = MTCNN(select_largest=False, post_process=False, device=DEVICE).to(DEVICE).eval()

# Backbone with a single-logit classification head.
model = InceptionResnetV1(pretrained="vggface2", classify=True, num_classes=1)

# Checkpoint location. It can be overridden with the DEEPFAKE_CHECKPOINT
# environment variable so the notebook is not tied to one machine; the default
# is the path the notebook was originally written against.
CHECKPOINT_PATH = os.environ.get(
    "DEEPFAKE_CHECKPOINT",
    "C:\\Users\\mussa\\OneDrive\\Documents\\DeepFake-Detection\\20180402-114759-vggface2.pt",
)

# NOTE(review): torch.load unpickles the file, so only point this at a
# checkpoint obtained from a trusted source.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)

# Accept either a bare state dict or a training checkpoint that wraps one
# under "model_state_dict"; anything else is assumed to be a module.
if isinstance(checkpoint, dict):
    state_dict = checkpoint.get("model_state_dict", checkpoint)
else:
    state_dict = checkpoint.state_dict()

# Previously the checkpoint was read from disk and then discarded. Apply every
# tensor whose name and shape match the model, and report the rest instead of
# failing on a head with a different number of classes.
model_state = model.state_dict()
compatible = {
    key: value
    for key, value in state_dict.items()
    if key in model_state and getattr(value, "shape", None) == model_state[key].shape
}
skipped = [key for key in state_dict if key not in compatible]
model.load_state_dict(compatible, strict=False)
if skipped:
    # NOTE(review): if the classification head (presumably the `logits.*`
    # tensors) is listed here, it was not restored from the checkpoint and the
    # real/fake scores should not be trusted until a fine-tuned checkpoint is
    # supplied - confirm which file is meant to be loaded.
    print(f"Checkpoint tensors not loaded (missing from model or shape mismatch): {skipped}")

model.to(DEVICE)
model.eval()

InceptionResnetV1(
  (conv2d_1a): BasicConv2d(
    (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
    (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
  )
  (conv2d_2a): BasicConv2d(
    (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
  )
  (conv2d_2b): BasicConv2d(
    (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
  )
  (maxpool_3a): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2d_3b): BasicConv2d(
    (conv): Conv2d(64, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(80, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
  )
  (conv2d_4a): 

In [7]:
# Define the predict function for video
def predict_video(input_video):
    """Classify a video as real or fake and write a Grad-CAM annotated copy.

    Every ``frame_skip``-th frame is read, downscaled to half the source
    resolution, and searched for a face with ``mtcnn``. When a face is found,
    ``model`` scores the crop and a Grad-CAM visualisation of that crop is
    pasted over the face's bounding box in the output frame. Frames without a
    face, or frames whose processing raises, are written unmodified.

    Args:
        input_video: Path to the video file, as passed to ``cv2.VideoCapture``.

    Returns:
        A ``(label, output_path)`` tuple. ``label`` is ``"real"`` or ``"fake"``
        from the mean sigmoid score over all scored frames (threshold 0.5),
        ``"Unknown"`` when no frame could be scored, or an error message when
        the video cannot be opened (in which case ``output_path`` is ``None``).
        ``output_path`` is otherwise ``'output.avi'``.
    """
    cap = cv2.VideoCapture(input_video)
    if not cap.isOpened():
        return "Error: Could not open video file.", None

    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    # Work at half resolution to keep detection and writing cheap.
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) // 2
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) // 2

    out = cv2.VideoWriter('output.avi', cv2.VideoWriter_fourcc(*'XVID'), 20.0, (frame_width, frame_height))

    frame_skip = 10  # Process every 10th frame

    # Loop-invariant Grad-CAM configuration.
    target_layers = [model.block8.branch1[-1]]
    targets = [ClassifierOutputTarget(0)]

    fake_scores = []
    try:
        for i in range(0, frame_count, frame_skip):
            cap.set(cv2.CAP_PROP_POS_FRAMES, i)
            ret, frame = cap.read()
            if not ret:
                continue

            frame = cv2.resize(frame, (frame_width, frame_height))
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            pil_image = Image.fromarray(frame_rgb)

            try:
                # Detect first so the bounding box is available for the
                # overlay, then crop the most probable face from that box.
                boxes, probs = mtcnn.detect(pil_image)
                if boxes is None:
                    out.write(frame)
                    continue

                boxes = np.asarray(boxes, dtype=np.float32)
                probs = np.asarray(probs, dtype=np.float32)
                best = int(np.argmax(probs))

                face = mtcnn.extract(pil_image, boxes[[best]], None)
                if face is None:
                    out.write(frame)
                    continue

                face = face.unsqueeze(0)
                face = F.interpolate(face, size=(256, 256), mode='bilinear', align_corners=False)
                face = face.to(DEVICE).to(torch.float32) / 255.0

                cam = GradCAM(model=model, target_layers=target_layers, use_cuda=torch.cuda.is_available())
                grayscale_cam = cam(input_tensor=face, targets=targets, eigen_smooth=True)[0, :]

                # show_cam_on_image expects a float image in [0, 1]. The tensor
                # is already scaled, so it must not be truncated to integers
                # (which would turn the whole crop black).
                face_image_to_plot = face.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()
                face_image_to_plot = np.clip(face_image_to_plot, 0.0, 1.0)
                visualization = show_cam_on_image(face_image_to_plot, grayscale_cam, use_rgb=True)

                with torch.no_grad():
                    output = model(face)
                    output = torch.sigmoid(output).squeeze().cpu().numpy()
                # Record the score before drawing so a drawing failure does
                # not discard the prediction.
                fake_scores.append(float(output))

                # Paste the CAM over the detected face. The visualisation is
                # 256x256 RGB while the frame is BGR at (frame_width,
                # frame_height), so it is colour-converted and resized to the
                # clamped bounding box rather than blended with the full frame.
                x1, y1, x2, y2 = boxes[best][:4]
                x1 = max(int(x1), 0)
                y1 = max(int(y1), 0)
                x2 = min(int(x2), frame_width)
                y2 = min(int(y2), frame_height)

                overlay = frame.copy()
                if x2 > x1 and y2 > y1:
                    visualization_bgr = cv2.cvtColor(visualization, cv2.COLOR_RGB2BGR)
                    overlay[y1:y2, x1:x2] = cv2.resize(visualization_bgr, (x2 - x1, y2 - y1))
                out.write(overlay)
            except Exception as e:
                # Best effort: keep the video going with the untouched frame.
                print(f"Error processing frame {i}: {e}")
                out.write(frame)
    finally:
        # Always release the capture and finalise the output file.
        cap.release()
        out.release()

    if not fake_scores:
        return "Unknown", 'output.avi'

    # Aggregate over every scored frame instead of reporting only the last one.
    mean_score = sum(fake_scores) / len(fake_scores)
    prediction = "real" if mean_score < 0.5 else "fake"
    return prediction, 'output.avi'

In [8]:
# Setup Gradio interface
# One uploaded video in; the predicted class label and the annotated video out.
interface = gr.Interface(
    predict_video,
    [gr.Video(label="Input Video")],
    [gr.Label(label="Class"), gr.Video(label="Output Video with Explainability")],
)

# Start the local web server for the demo.
interface.launch()

Running on local URL:  http://127.0.0.1:7862

Thanks for being a Gradio user! If you have questions or feedback, please join our Discord server and chat with us: https://discord.gg/feTf9x3ZSB

To create a public link, set `share=True` in `launch()`.




IMPORTANT: You are using gradio version 3.46.1, however version 4.29.0 is available, please upgrade.
--------
