In [None]:
# final_Gradio_Interface_Fused.ipynb

import gradio as gr
import torch
import os
import csv
from datetime import datetime
from transformers import ViTForImageClassification, ViTImageProcessor
from transformers import BertForSequenceClassification, BertTokenizer
from PIL import Image
import numpy as np
import timm
import torch.nn as nn
import cv2
from torchvision import transforms

# Load models and tokenizers.
# Model weights are read from directories under ``local_model_root``; every
# model is placed on the CPU and switched to eval mode at the end of this cell.
local_model_root = "./models/"

vit_model_dir = os.path.join(local_model_root, "vit_deepfake_model")
vit_model = ViTForImageClassification.from_pretrained(vit_model_dir).to('cpu')
vit_processor = ViTImageProcessor.from_pretrained(vit_model_dir)

# Both caption classifiers use the same 'bert-base-uncased' tokenizer, so it is
# loaded once and shared under both names.
# NOTE(review): this identifier is resolved through the Hugging Face hub/cache,
# not local_model_root -- confirm whether the tokenizer should instead be loaded
# from the fine-tuned model directories for fully offline use.
bert_original_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_generated_tokenizer = bert_original_tokenizer

bert_original_model = BertForSequenceClassification.from_pretrained(
    os.path.join(local_model_root, "bert-original-caption-model")
).to('cpu')
bert_generated_model = BertForSequenceClassification.from_pretrained(
    os.path.join(local_model_root, "bert-generated-caption-model")
).to('cpu')

# Load video model: a timm Xception backbone with a 2-class head, filled from a
# saved state_dict. weights_only=True (torch >= 1.13) restricts unpickling to
# tensors and plain containers, so a tampered checkpoint cannot execute code.
video_model = timm.create_model('xception', pretrained=False, num_classes=2)
video_state = torch.load(
    os.path.join(local_model_root, "video_model.pth"),
    map_location=torch.device('cpu'),
    weights_only=True,
)
video_model.load_state_dict(video_state)
del video_state  # parameters were copied into the model; drop the extra reference

# Inference only: disable dropout / use running batch-norm statistics.
for _model in (vit_model, bert_original_model, bert_generated_model, video_model):
    _model.eval()

# Logging function
def log_prediction(input_type, label, confidence, log_file="predictions.csv"):
    """Append one prediction record to a CSV log.

    A header row is written first whenever the log file is new or empty.

    Args:
        input_type: Kind of input that was classified (e.g. "image", "text", "video").
        label: Predicted label text to record.
        confidence: Confidence in percent; stored with two decimal places.
        log_file: Path of the CSV file to append to. Defaults to
            "predictions.csv" in the current working directory.
    """
    header = ['timestamp', 'input_type', 'label', 'confidence']
    with open(log_file, mode='a', newline='', encoding='utf-8') as file:
        writer = csv.writer(file)
        # An append-mode stream starts at end-of-file, so position 0 means the file
        # is empty (just created or previously truncated). Checking after opening
        # also avoids the race of testing os.path.isfile() before the open.
        if file.tell() == 0:
            writer.writerow(header)
        writer.writerow([datetime.now().isoformat(), input_type, label, f"{confidence:.2f}"])

# Image prediction function
def predict_image(image):
    """Classify an uploaded image as "Fake" or "Real" with the ViT model.

    Args:
        image: PIL image from the Gradio image input, or None when the user
            submits without uploading anything.

    Returns:
        A result string such as "Image Prediction: Fake (97.12%)", or a prompt
        to upload an image when ``image`` is None. Successful predictions are
        also appended to the CSV log via log_prediction.
    """
    if image is None:
        return "Please upload an image."
    # Uploads such as PNGs may be RGBA or palette images; convert to 3-channel
    # RGB so the processor's per-channel normalisation matches the input.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")

    inputs = vit_processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = vit_model(**inputs)
    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
    confidence, prediction = torch.max(probs, dim=1)

    # NOTE(review): index->label mapping is hard-coded; confirm it matches
    # vit_model.config.id2label from training.
    label_map = {0: "Fake", 1: "Real"}
    label = label_map[prediction.item()]
    confidence_percent = confidence.item() * 100

    log_prediction("image", label, confidence_percent)

    return f"Image Prediction: {label} ({confidence_percent:.2f}%)"

# Text prediction function (both original and generated)
def _classify_caption(text, tokenizer, model):
    """Run one BERT caption classifier on ``text``.

    Returns:
        A ``(label, confidence_percent)`` pair where label is "Fake News" or
        "Real News" and confidence_percent is the winning softmax probability x 100.
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = torch.nn.functional.softmax(logits, dim=-1)
    confidence, prediction = torch.max(probs, dim=1)
    # NOTE(review): hard-coded index->label mapping; confirm against model.config.id2label.
    label_map = {0: "Fake News", 1: "Real News"}
    return label_map[prediction.item()], confidence.item() * 100


def predict_both_texts(text):
    """Classify a caption with both the original- and generated-caption BERT models.

    Args:
        text: Caption text from the Gradio textbox.

    Returns:
        A 2-tuple of result strings, one per model, in the order
        (original model, generated model). For empty/whitespace-only input both
        entries are a prompt to enter text and nothing is logged.
    """
    if text is None or not text.strip():
        message = "Please enter some caption text."
        return message, message

    label_original, confidence_original = _classify_caption(
        text, bert_original_tokenizer, bert_original_model
    )
    label_generated, confidence_generated = _classify_caption(
        text, bert_generated_tokenizer, bert_generated_model
    )

    # One CSV row covers both models; the logged confidence is their mean.
    log_prediction(
        "text",
        f"Original: {label_original}, Generated: {label_generated}",
        (confidence_original + confidence_generated) / 2,
    )

    return (
        f"Original Text Model: {label_original} ({confidence_original:.2f}%)",
        f"Generated Text Model: {label_generated} ({confidence_generated:.2f}%)",
    )

# Video helpers: frame sampling and preprocessing for the video model.
def extract_frames(video_file, num_frames=16):
    """Sample up to ``num_frames`` evenly spaced RGB frames from a video file.

    Args:
        video_file: Path to the video (presumably the file path supplied by the
            Gradio video input -- confirm if called from elsewhere).
        num_frames: Number of frame positions to sample across the clip.

    Returns:
        A list of frames converted from OpenCV's BGR order to RGB. Frames that
        fail to decode are skipped, so the list may be shorter than
        ``num_frames``; it is empty when the video cannot be opened or reports
        no frames.
    """
    frames = []
    cap = cv2.VideoCapture(video_file)
    try:
        if not cap.isOpened():
            return frames
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Some containers report 0 (or a negative value) here; linspace would then
        # yield negative indices, so return no frames instead.
        if total_frames <= 0:
            return frames
        frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
        for i in frame_indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(i))
            ret, frame = cap.read()
            if not ret:
                continue
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    finally:
        # Always release the capture handle, even if decoding raises.
        cap.release()
    return frames

def preprocess_frames(frames):
    """Convert a sequence of frames into one normalised batch tensor.

    Each frame goes through ToPILImage -> Resize((224, 224)) -> ToTensor ->
    Normalize(ImageNet mean/std), and the results are stacked along a new first
    (batch) dimension. For 3-channel RGB frames (as produced by extract_frames)
    the result has shape (len(frames), 3, 224, 224).

    NOTE(review): timm's Xception configs typically default to 299x299 input
    with 0.5 mean/std; confirm these settings match how video_model was trained.

    Raises:
        ValueError: If ``frames`` is empty (torch.stack cannot stack nothing).
    """
    if len(frames) == 0:
        raise ValueError("preprocess_frames() received no frames to process.")
    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    return torch.stack([transform(frame) for frame in frames])

def predict_video(video):
    """Classify a video as "Fake" or "Real" by majority vote over sampled frames.

    Frames are sampled with extract_frames, batched with preprocess_frames and
    scored by video_model. Each frame votes for its arg-max class; the label is
    the rounded mean vote. The reported confidence is the mean softmax
    probability of the chosen class across the sampled frames.

    Args:
        video: Video input from Gradio (passed to extract_frames), or None when
            nothing was uploaded.

    Returns:
        A result string such as "Video Prediction: Fake (83.10%)", or a message
        when no video was given or no frames could be decoded. Only successful
        predictions are logged.
    """
    if video is None:
        return "Please upload a video."
    frames = extract_frames(video)
    if not frames:
        return "Video Prediction: could not read any frames from the video."
    frames_tensor = preprocess_frames(frames)

    with torch.no_grad():
        outputs = video_model(frames_tensor)
    probs = torch.nn.functional.softmax(outputs, dim=-1)
    per_frame_votes = outputs.argmax(dim=1)
    # torch.round rounds half to even, so an exact 50/50 split yields 0 ("Real").
    prediction = int(torch.round(per_frame_votes.float().mean()).item())

    # NOTE(review): class 1 is reported as "Fake" here, the opposite orientation
    # of the image model's {0: "Fake", 1: "Real"} map -- confirm against training.
    label = "Fake" if prediction else "Real"
    confidence = probs[:, prediction].mean().item() * 100

    log_prediction("video", label, confidence)

    return f"Video Prediction: {label} ({confidence:.2f}%)"

# Gradio interfaces: one gr.Interface per input type, combined into tabs below.

# Image tab: a PIL image in, a single text verdict out.
image_upload = gr.Image(type="pil", label="Upload an Image")
image_result = gr.Textbox(label="Prediction Result")
image_interface = gr.Interface(
    fn=predict_image,
    inputs=image_upload,
    outputs=image_result,
    title="Deepfake Image Detector",
)

# Text tab: one caption in, one verdict per BERT model out (same order as the
# tuple returned by predict_both_texts).
caption_box = gr.Textbox(lines=5, placeholder="Paste any caption text here...")
caption_results = [
    gr.Textbox(label="Original Text Model Prediction"),
    gr.Textbox(label="Generated Text Model Prediction"),
]
both_text_interface = gr.Interface(
    fn=predict_both_texts,
    inputs=caption_box,
    outputs=caption_results,
    title="Fake News Detection (Original and Generated Captions)",
)

# Video tab: an uploaded video in, a single text verdict out.
video_upload = gr.Video(label="Upload a Video")
video_result = gr.Textbox(label="Prediction Result")
video_interface = gr.Interface(
    fn=predict_video,
    inputs=video_upload,
    outputs=video_result,
    title="Deepfake Video Detector",
)

# Launch with tabs.
# The combined app is bound to a module-level ``demo`` so the file can be
# imported (e.g. by Gradio's reload mode or tests) without starting a server.
# In a notebook or when run as a script, __name__ is "__main__", so it launches.
demo = gr.TabbedInterface(
    [image_interface, both_text_interface, video_interface],
    tab_names=["Detect Deepfake Image", "Detect Fake News (Text Captions)", "Detect Deepfake Video"],
)

if __name__ == "__main__":
    # "0.0.0.0" listens on all network interfaces (reachable from other hosts);
    # set GRADIO_SERVER_NAME=127.0.0.1 to keep the app local-only.
    demo.launch(server_name=os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0"))
