DeepFake Video Detection using EfficientNetB2 + Transformer

In [None]:
import os
import cv2
import numpy as np
import random
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, classification_report
from sklearn.utils import shuffle, resample, class_weight
import tensorflow as tf
from tensorflow.keras.applications import EfficientNetB2
from tensorflow.keras.applications.efficientnet import preprocess_input
from tensorflow.keras.models import Model, Sequential, load_model
from tensorflow.keras.layers import (
    Input, Dense, Dropout, GlobalAveragePooling2D, GlobalAveragePooling1D,
    LayerNormalization, MultiHeadAttention
)
from tensorflow.keras.callbacks import ReduceLROnPlateau
from tensorflow.keras.optimizers import Adam
import tensorflow.keras.backend as K

Constants

In [None]:
# Number of frames sampled from each video (default for extract_frames).
FRAME_COUNT = 20
# Side length, in pixels, that every sampled frame is resized to.
IMAGE_SIZE = 224
# Seed passed as random_state to the sklearn shuffle/resample calls below.
SEED = 42
# Mini-batch size and number of epochs passed to model.fit.
BATCH_SIZE = 8
EPOCHS = 20

# Seed the global generators as well, so that random.shuffle, NumPy and
# TensorFlow weight initialisation / dropout are repeatable after a
# "Restart Kernel -> Run All".
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)


Feature Extractor: CNN base model using EfficientNetB2 pretrained on ImageNet

In [None]:
# Headless EfficientNetB2 carrying ImageNet weights; it is only ever used via
# predict() in this notebook, i.e. as a fixed per-frame feature extractor.
cnn_base = EfficientNetB2(
    include_top=False,
    weights="imagenet",
    input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3),
)

# Global average pooling collapses the spatial feature map into one feature
# vector per frame.
cnn_model = Sequential()
cnn_model.add(cnn_base)
cnn_model.add(GlobalAveragePooling2D())


Frame Extraction

In [None]:
def extract_frames(video_path, max_frames=FRAME_COUNT):
    """Sample up to ``max_frames`` evenly spaced frames from a video.

    Each selected frame is converted from OpenCV's BGR channel order to RGB
    (the order the ImageNet-pretrained backbone was trained on), resized to
    ``IMAGE_SIZE x IMAGE_SIZE`` and passed through ``preprocess_input``.

    Args:
        video_path: Path of the video file to open with ``cv2.VideoCapture``.
        max_frames: Number of evenly spaced frame indices to request.

    Returns:
        ``np.ndarray`` holding the sampled frames stacked along axis 0. Fewer
        than ``max_frames`` frames are returned (possibly none, giving an
        empty array) when the video cannot be opened, reports no frames, is
        shorter than ``max_frames`` (the evenly spaced indices then collide),
        or decoding stops early. Callers compare ``shape[0]`` against
        ``FRAME_COUNT`` to detect this.
    """
    cap = cv2.VideoCapture(video_path)
    frames = []
    try:
        if not cap.isOpened() or max_frames <= 0:
            return np.array(frames)

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames <= 0:
            return np.array(frames)

        # A set gives O(1) membership tests; duplicates (short videos)
        # collapse, exactly as the membership test on the array did.
        wanted = set(np.linspace(0, total_frames - 1, max_frames, dtype=int).tolist())
        last_wanted = max(wanted)

        # Frames are decoded sequentially; stop once the last requested index
        # has been read.
        for i in range(last_wanted + 1):
            success, frame = cap.read()
            if not success:
                break
            if i in wanted:
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frame = cv2.resize(frame, (IMAGE_SIZE, IMAGE_SIZE))
                frames.append(preprocess_input(frame))
    finally:
        # Always free the decoder handle, even if resizing/conversion raises.
        cap.release()
    return np.array(frames)

Dataset Loader

In [None]:
def load_dataset(folder_path):
    """Build per-video feature sequences for the ``real`` and ``fake`` classes.

    Expects ``folder_path`` to contain the sub-directories ``real`` (label 0)
    and ``fake`` (label 1). Every file is sampled with ``extract_frames`` and
    the frames are encoded with ``cnn_model``; videos that do not yield exactly
    ``FRAME_COUNT`` frames, or that raise while being processed, are skipped
    and counted.

    Args:
        folder_path: Directory holding the ``real`` and ``fake`` sub-folders.

    Returns:
        Tuple ``(X, y)`` of NumPy arrays: ``X`` stacks one feature sequence
        per accepted video, ``y`` holds the matching integer labels.
    """
    X, y = [], []
    for label, class_dir in enumerate(["real", "fake"]):
        class_path = os.path.join(folder_path, class_dir)
        # os.listdir returns entries in arbitrary order, so sort before the
        # seeded shuffle to make the processing order reproducible. Entries
        # that are not regular files (e.g. nested folders) are ignored.
        video_files = sorted(
            name for name in os.listdir(class_path)
            if os.path.isfile(os.path.join(class_path, name))
        )
        random.Random(SEED).shuffle(video_files)

        skipped = 0
        for video_file in tqdm(video_files, desc=f"Loading {class_dir}"):
            video_path = os.path.join(class_path, video_file)
            try:
                frames = extract_frames(video_path)
                if frames.shape[0] != FRAME_COUNT:
                    skipped += 1
                    continue
                features = cnn_model.predict(frames, verbose=0)
            except Exception as e:
                # Best effort: one broken file must not abort the whole load.
                skipped += 1
                print(f"Error processing {video_path}: {e}")
                continue
            X.append(features)
            y.append(label)

        if skipped:
            print(
                f"Skipped {skipped} of {len(video_files)} '{class_dir}' videos "
                f"(errors or fewer than {FRAME_COUNT} sampled frames)."
            )
    return np.array(X), np.array(y)


Transformer Encoder

In [None]:
def transformer_encoder(inputs, head_size=64, num_heads=4, ff_dim=256, dropout=0.3):
    """Apply one Transformer encoder block to a sequence tensor.

    Layer normalisation is applied before each sub-layer, and each sub-layer's
    output is added back to its un-normalised input.

    Args:
        inputs: Keras tensor whose last axis is the feature dimension.
        head_size: ``key_dim`` of the multi-head attention layer.
        num_heads: Number of attention heads.
        ff_dim: Width of the hidden feed-forward layer.
        dropout: Dropout rate used inside attention and after each sub-layer.

    Returns:
        Tensor with the same last-axis size as ``inputs``.
    """
    # --- self-attention sub-layer -------------------------------------
    normed = LayerNormalization(epsilon=1e-6)(inputs)
    attention = MultiHeadAttention(num_heads=num_heads, key_dim=head_size, dropout=dropout)
    attended = attention(normed, normed)
    attended = Dropout(dropout)(attended)
    skip = attended + inputs

    # --- position-wise feed-forward sub-layer -------------------------
    hidden = LayerNormalization(epsilon=1e-6)(skip)
    hidden = Dense(ff_dim, activation="relu")(hidden)
    hidden = Dropout(dropout)(hidden)
    # Project back to the input width so the residual addition is valid.
    projected = Dense(inputs.shape[-1])(hidden)
    return projected + skip

Model Building

In [None]:
def build_transformer_model(input_shape, num_classes=2):
    """Assemble the sequence classifier that sits on top of the frame features.

    Two stacked ``transformer_encoder`` blocks are followed by average pooling
    over the sequence axis and a small dense head.

    Args:
        input_shape: Shape of one sample, excluding the batch axis.
        num_classes: Accepted for interface compatibility but not used; the
            head is always a single sigmoid unit.

    Returns:
        An uncompiled ``Model`` mapping a feature sequence to one probability.
    """
    sequence_in = Input(shape=input_shape)

    encoded = sequence_in
    for _ in range(2):
        encoded = transformer_encoder(encoded)

    pooled = GlobalAveragePooling1D()(encoded)
    head = Dropout(0.4)(pooled)
    head = Dense(32, activation="relu")(head)
    head = Dropout(0.2)(head)
    probability = Dense(1, activation="sigmoid")(head)
    return Model(inputs=sequence_in, outputs=probability)

Loading Data

In [None]:
# Dataset root containing "train/" and "test/" (each with real/ and fake/).
# Set the DEEPFAKE_DATA_DIR environment variable to run on another machine;
# the default keeps the original location.
DATA_DIR = os.environ.get("DEEPFAKE_DATA_DIR", "D:/DeepFake2.0/split_dataset_part2")

X_train, y_train = load_dataset(f"{DATA_DIR}/train")
X_test, y_test = load_dataset(f"{DATA_DIR}/test")

# Fail fast with a clear message instead of an obscure error further down.
assert len(X_train) > 0, f"No usable training videos found under {DATA_DIR}/train"
assert len(X_test) > 0, f"No usable test videos found under {DATA_DIR}/test"

X_train, y_train = shuffle(X_train, y_train, random_state=SEED)

Handling the Imbalanced Dataset

In [None]:
# Balance the training set by up-sampling (with replacement) whichever class
# has fewer videos. Nothing is resampled when the classes are already equal,
# so re-running this cell does not bootstrap the data a second time.
X_real = X_train[y_train == 0]
X_fake = X_train[y_train == 1]
if len(X_real) == 0 or len(X_fake) == 0:
    raise ValueError(
        f"Cannot balance classes: {len(X_real)} real and {len(X_fake)} fake training videos."
    )

if len(X_real) < len(X_fake):
    X_real = resample(X_real, replace=True, n_samples=len(X_fake), random_state=SEED)
elif len(X_fake) < len(X_real):
    X_fake = resample(X_fake, replace=True, n_samples=len(X_real), random_state=SEED)

X_train = np.concatenate([X_fake, X_real])
# Labels: 1.0 for fake, 0.0 for real, in the same order as the features.
y_train = np.concatenate([np.ones(len(X_fake)), np.zeros(len(X_real))])
X_train, y_train = shuffle(X_train, y_train, random_state=SEED)

Focal Loss 

In [None]:
def binary_focal_loss(gamma=2.0, alpha=0.25):
    """Create a binary focal-loss function for a sigmoid output.

    Args:
        gamma: Exponent of the ``(1 - p_t)`` modulating factor.
        alpha: Weight applied where ``y_true`` is 1; ``1 - alpha`` is applied
            where ``y_true`` is 0.

    Returns:
        A ``focal_loss(y_true, y_pred)`` callable that returns the loss
        averaged over the last axis.
    """
    def focal_loss(y_true, y_pred):
        targets = K.cast(y_true, K.floatx())
        # Clip probabilities away from 0 and 1 so the log below stays finite.
        eps = K.epsilon()
        probs = K.clip(y_pred, eps, 1.0 - eps)

        negatives = 1 - targets
        # p_t: probability the model assigned to the true class.
        p_t = targets * probs + negatives * (1 - probs)
        alpha_t = targets * alpha + negatives * (1 - alpha)
        focal_weight = K.pow((1 - p_t), gamma)

        per_element = -alpha_t * focal_weight * K.log(p_t)
        return K.mean(per_element, axis=-1)
    return focal_loss

Training

In [None]:
# "Balanced" class weights keyed by the class labels actually present in
# y_train (integer keys, as expected by the class_weight argument of fit),
# rather than by assumed positions 0 and 1.
train_classes = np.unique(y_train)
class_weights = dict(
    zip(
        train_classes.astype(int).tolist(),
        class_weight.compute_class_weight(
            class_weight="balanced", classes=train_classes, y=y_train
        ).tolist(),
    )
)

model = build_transformer_model(input_shape=(FRAME_COUNT, cnn_model.output_shape[1]))
model.compile(optimizer=Adam(learning_rate=1e-4), loss=binary_focal_loss(), metrics=["accuracy"])

# Halve the learning rate when validation accuracy has not improved for two
# epochs; the metric is maximised, so the mode is stated explicitly.
lr_schedule = ReduceLROnPlateau(monitor="val_accuracy", mode="max", patience=2, factor=0.5, min_lr=1e-6)

# NOTE(review): the test split doubles as validation data here, so the
# learning-rate schedule is driven by test performance and the reported test
# metrics are optimistic. Consider holding out a separate validation split.
history = model.fit(
    X_train, y_train,
    validation_data=(X_test, y_test),
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    class_weight=class_weights,
    callbacks=[lr_schedule],
)

Saving Model

In [None]:
# Persist the trained classifier; predict_video() loads it back from this same
# file name.
model.save("deepfake_efficientnetb2_transformer_balanced_focal2.keras")

Evaluation 

In [None]:
# Score the test split and turn the sigmoid outputs into hard labels.
# NOTE(review): the cut-off used here is 0.77, whereas predict_video() labels
# with 0.5 -- confirm which threshold is intended.
y_pred_probs = model.predict(X_test)
y_pred = np.where(y_pred_probs > 0.77, 1, 0)

# Rows are true labels, columns are predictions; label 0 = REAL, 1 = FAKE.
cm = confusion_matrix(y_test, y_pred)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["REAL", "FAKE"])
disp.plot(cmap="Blues")
disp.ax_.set_title("Confusion Matrix")
plt.show()

Accuracy and Loss Graphs

In [None]:
# Training curves side by side: accuracy on the left, loss on the right.
fig, (acc_ax, loss_ax) = plt.subplots(1, 2, figsize=(12, 5))

acc_ax.plot(history.history["accuracy"], label="Train Accuracy")
acc_ax.plot(history.history["val_accuracy"], label="Validation Accuracy")
acc_ax.set_title("Accuracy Over Epochs")
acc_ax.set_xlabel("Epoch")
acc_ax.set_ylabel("Accuracy")
acc_ax.legend()

loss_ax.plot(history.history["loss"], label="Train Loss")
loss_ax.plot(history.history["val_loss"], label="Validation Loss")
loss_ax.set_title("Loss Over Epochs")
loss_ax.set_xlabel("Epoch")
loss_ax.set_ylabel("Loss")
loss_ax.legend()

fig.tight_layout()
plt.show()

Inference

In [None]:
def predict_video(video_path, threshold=0.5, model=None):
    """Classify a single video as REAL or FAKE and print the verdict.

    The video is sampled with ``extract_frames``, encoded with ``cnn_model``
    and scored by the Transformer classifier.

    Args:
        video_path: Path of the video to classify.
        threshold: Scores below this value are labelled REAL, all others
            FAKE. Defaults to 0.5; note that the evaluation cell of this
            notebook thresholds at 0.77.
        model: Optional already-loaded classifier. When omitted, the saved
            ``.keras`` file is loaded from disk; pass a model to avoid that
            cost when classifying many videos.

    Returns:
        None. The result, or a message when the video does not yield
        ``FRAME_COUNT`` frames, is printed.
    """
    frames = extract_frames(video_path)
    if frames.shape[0] != FRAME_COUNT:
        print("Insufficient frames for inference.")
        return
    features = cnn_model.predict(frames, verbose=0)
    # Add the batch axis: the classifier expects a batch of feature sequences.
    features = np.expand_dims(features, axis=0)
    if model is None:
        # compile=False: only inference is needed, so the custom focal loss
        # does not have to be supplied when deserialising.
        model = load_model("deepfake_efficientnetb2_transformer_balanced_focal2.keras", compile=False)
    prediction = model.predict(features)[0][0]
    label = "REAL" if prediction < threshold else "FAKE"
    # Confidence is the probability mass on the side of the reported label.
    confidence = (1 - prediction) * 100 if label == "REAL" else prediction * 100
    print(f"Prediction: {label} ({confidence:.2f}% confidence)")