#### Note: This notebook was run in Google Colab, so the paths will need to be changed accordingly if it is run elsewhere or from a different location.

#### Link to colab notebook: https://colab.research.google.com/drive/1dUSQPJc7kcRZZE7oIdfXMFR9HE4xKGdT?usp=sharing

#### Link to Dataset : https://drive.google.com/file/d/1LnvVO5eJiSwjOeu7SmNVDDkD9qp6BB9e/view?usp=sharing

In [None]:
!pip install facenet-pytorch albumentations torch torchvision facenet-pytorch albumentations opencv-python pandas numpy scikit-learn matplotlib

In [None]:
import os
import cv2
import torch
from torch.utils.data import Dataset, DataLoader
from albumentations import Compose, RandomBrightnessContrast, HueSaturationValue
import pandas as pd
import numpy as np
from torch import nn
from torch.optim import Adam
from torch.nn import functional as F
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights
from facenet_pytorch import MTCNN, InceptionResnetV1
from sklearn.metrics.pairwise import cosine_similarity
import matplotlib.pyplot as plt

In [None]:
# Select the compute device: CUDA when PyTorch reports it as available, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    print("GPU detected:", torch.cuda.get_device_name(0))

# Root directory of the dataset.
# (On Kaggle, use '/kaggle/input/sentinel-facev1/Sentinel_FaceV1' instead.)
BASE_PATH = '/content/drive/MyDrive/Hackathon/ZenTej/Sentinel_FaceV1'

In [None]:
# Load the training labels table from <BASE_PATH>/Forgery_Dataset/train_labels.csv.
labels_df = pd.read_csv(os.path.join(BASE_PATH, 'Forgery_Dataset', 'train_labels.csv'))

# Keep only the rows whose image file exists on disk. Mapping over the single
# 'image_path' column avoids building a Series for every row (which
# DataFrame.apply(axis=1) does) and, with the explicit bool cast, always
# produces a boolean mask - even when the CSV contains no rows.
image_exists = labels_df['image_path'].map(
    lambda rel_path: os.path.exists(os.path.join(BASE_PATH, 'Forgery_Dataset', rel_path))
).astype(bool)
valid_df = labels_df[image_exists]
print(f"Using {len(valid_df)} valid entries out of {len(labels_df)}")

# Class index assigned to each label string.
label_map = {'real': 0, 'fake': 1}

# Dataset class
class ForgeryDataset(Dataset):
    """Images and labels read from a labels table.

    Indexing yields ``(image, label)``: the image is a float tensor in
    channel-first layout with values divided by 255, and the label is the
    integer that ``label_map`` assigns to the row's ``'label'`` value. If
    anything goes wrong while loading a sample, the error is printed and
    ``(None, None)`` is returned instead.
    """

    def __init__(self, df, base_path, transform=None):
        self.df = df
        self.base_path = base_path
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def _load(self, idx):
        """Read, convert and label the sample at position ``idx``; may raise."""
        record = self.df.iloc[idx]
        full_path = os.path.join(self.base_path, 'Forgery_Dataset', record['image_path'])

        pixels = cv2.imread(full_path)
        if pixels is None:
            raise FileNotFoundError(f"Image not found or corrupt: {full_path}")
        pixels = cv2.cvtColor(pixels, cv2.COLOR_BGR2RGB)
        target = label_map[record['label']]

        if self.transform:
            pixels = self.transform(image=pixels)['image']

        # HWC -> CHW, then scale to floats by dividing by 255
        tensor = torch.from_numpy(pixels.transpose(2, 0, 1)).float() / 255.0
        return tensor, target

    def __getitem__(self, idx):
        try:
            sample = self._load(idx)
        except Exception as e:
            # Best-effort loading: report the problem and hand back a
            # placeholder instead of propagating the error.
            print(f"Error loading idx {idx}: {e} - Skipping")
            return None, None
        return sample

# Colour augmentation pipeline: brightness/contrast jitter followed by a
# hue/saturation/value shift, each configured with p=0.5.
aug_transform = Compose(
    [
        RandomBrightnessContrast(
            p=0.5,
            brightness_limit=0.2,
            contrast_limit=0.2,
        ),
        HueSaturationValue(
            p=0.5,
            hue_shift_limit=20,
            sat_shift_limit=30,
            val_shift_limit=20,
        ),
    ]
)

# Batch builder that discards samples whose first element is None
def collate_fn(batch):
    """Collate ``batch`` after dropping samples whose first element is None.

    Returns ``(None, None)`` when no sample is left; otherwise returns the
    result of PyTorch's default collate applied to the remaining samples.
    """
    usable = [sample for sample in batch if sample[0] is not None]
    if not usable:
        return None, None
    default_collate = torch.utils.data.dataloader.default_collate
    return default_collate(usable)


In [None]:
# Training data: augmented dataset over the valid rows, served in shuffled
# batches of 32 by two worker processes; collate_fn drops unreadable samples.
train_dataset = ForgeryDataset(valid_df, BASE_PATH, transform=aug_transform)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=2,
    collate_fn=collate_fn,
)

# Face detector (keep_all=True) and a VGGFace2-pretrained embedding network in eval mode.
mtcnn = MTCNN(keep_all=True, device=device)
resnet = InceptionResnetV1(pretrained='vggface2')
resnet = resnet.eval().to(device)

# ImageNet-pretrained EfficientNet-B0 whose final linear layer is replaced by
# a 2-output head (matching the two entries of label_map).
deepfake_model = efficientnet_b0(weights=EfficientNet_B0_Weights.IMAGENET1K_V1)
deepfake_model = deepfake_model.to(device)
in_features = deepfake_model.classifier[1].in_features
deepfake_model.classifier[1] = nn.Linear(in_features, 2).to(device)

# The optimizer is created after the head swap so it sees the new layer's parameters.
optimizer = Adam(deepfake_model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

In [None]:
# Fine-tune deepfake_model for a fixed number of epochs, then save its weights.
num_epochs = 2
for epoch in range(num_epochs):
    deepfake_model.train()
    running_loss = 0.0
    batches_seen = 0  # counts only batches that actually contributed to running_loss
    for images, labels in train_dataloader:
        # collate_fn yields (None, None) when every sample of a batch failed to load
        if images is None:
            continue
        images = images.to(device)
        labels = labels.to(device)

        # The model is fed the real part of the 2-D FFT of each image
        dct_imgs = torch.fft.fft2(images).real
        outputs = deepfake_model(dct_imgs)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        batches_seen += 1

    # Average over the batches that were really processed: skipped batches
    # must not dilute the mean, and an epoch with no usable batch must not
    # divide by zero.
    if batches_seen:
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss / batches_seen}")
    else:
        print(f"Epoch {epoch+1}/{num_epochs}: no valid batches were processed")

torch.save(deepfake_model.state_dict(), '/content/deepfake_model.pth')
print("Model saved.")

In [None]:
def _first_face(faces):
    """Return one 3-D face tensor from an MTCNN result, or None if there is none."""
    if faces is None:
        return None
    # The detector in this notebook is built with keep_all=True, so a 4-D
    # result is treated as a batch of faces and its first entry is used. A
    # 3-D result (single face) is passed through unchanged.
    if faces.dim() == 4:
        return faces[0] if faces.shape[0] > 0 else None
    return faces


def process_input(input1, input2=None, is_video=False):
    """Run face matching, a liveness heuristic and forgery classification.

    Args:
        input1: Source passed to ``cv2.VideoCapture`` when ``is_video`` is
            true, otherwise an image passed directly to ``mtcnn``.
        input2: Optional second image passed to ``mtcnn``; when given, its
            face embedding is compared with the one from ``input1``.
        is_video: Whether ``input1`` is a video source.

    Returns:
        A tuple ``(match_score, liveness_score, authenticity)``:
        ``match_score`` is the cosine similarity of the two embeddings, or
        None when ``input2`` is not given; ``liveness_score`` is ``1.0`` for
        images and, for videos, whether the variance of the per-frame mean
        face value exceeds 0.1; ``authenticity`` is ``'Authentic'``,
        ``'Forged'`` or None when no face was found. ``(None, None, None)``
        is returned when a video yields no face, or when ``input2`` is given
        and either face is missing.
    """
    if is_video:
        cap = cv2.VideoCapture(input1)
        frames = []
        try:
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break
                # OpenCV decodes frames as BGR; convert to RGB as done for
                # every other image in this file before face detection.
                rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                face = _first_face(mtcnn(rgb_frame))
                if face is not None:
                    frames.append(face)
        finally:
            # Release the capture even if detection raises mid-stream
            cap.release()
        if not frames:
            return None, None, None
        face1 = frames[0]
        liveness_score = np.var([f.mean().item() for f in frames]) > 0.1
    else:
        face1 = _first_face(mtcnn(input1))
        liveness_score = 1.0

    # `is not None` rather than truthiness: a numpy array has no single truth value
    if input2 is not None:
        face2 = _first_face(mtcnn(input2))
        if face1 is None or face2 is None:
            return None, None, None
        with torch.no_grad():
            emb1 = resnet(face1.unsqueeze(0).to(device)).detach().cpu().numpy()
            emb2 = resnet(face2.unsqueeze(0).to(device)).detach().cpu().numpy()
        match_score = cosine_similarity(emb1, emb2)[0][0]
    else:
        match_score = None

    if face1 is not None:
        face1 = face1.to(device)
        # NOTE(review): training feeds whole images scaled by 1/255; MTCNN face
        # crops may be normalised differently - confirm the inputs are comparable.
        dct1 = torch.fft.fft2(face1.unsqueeze(0)).real
        # Classify in eval mode without gradients so dropout/BatchNorm behave
        # deterministically, then restore whatever mode the model was in.
        was_training = deepfake_model.training
        deepfake_model.eval()
        try:
            with torch.no_grad():
                auth_pred = deepfake_model(dct1)
        finally:
            deepfake_model.train(was_training)
        # Index 0 is the 'real' class in label_map
        authenticity_prob = F.softmax(auth_pred, dim=1)[0][0].item()
        authenticity = 'Authentic' if authenticity_prob > 0.5 else 'Forged'
    else:
        authenticity = None

    return match_score, liveness_score, authenticity

In [None]:
def grad_cam(model, input_tensor, target_class=0):
    """Compute a Grad-CAM heatmap from the last block of ``model.features``.

    The model is switched to eval mode, a forward/backward pass is run for
    the score ``output[0][target_class]``, and the activations of
    ``model.features[-1]`` are weighted by their channel-wise mean gradients,
    summed over channels and passed through ReLU.

    Args:
        model: Module exposing an indexable ``features`` container and
            returning a 2-D ``(batch, classes)`` output.
        input_tensor: Input batch for ``model``.
        target_class: Class index to explain; values above the last class
            index are clamped to it.

    Returns:
        The heatmap as a numpy array (batch dimension first).
    """
    model.eval()
    grads = []
    activations = []

    def forward_hook(module, inputs, output):
        activations.append(output)
        # Capture d(score)/d(output) with a tensor hook instead of the
        # deprecated Module.register_backward_hook; list.append returns
        # None, so the gradient itself is left untouched.
        output.register_hook(grads.append)

    f_handle = model.features[-1].register_forward_hook(forward_hook)
    try:
        output = model(input_tensor)
        model.zero_grad()
        target_class = min(target_class, output.size(1) - 1)
        output[0][target_class].backward()
    finally:
        # Always detach the hook so a failed pass cannot leave it on the model
        f_handle.remove()

    # Channel weights: gradient averaged over batch and spatial dimensions
    pooled_grads = torch.mean(grads[0], dim=[0, 2, 3])
    heatmap = torch.sum(pooled_grads.unsqueeze(-1).unsqueeze(-1) * activations[0], dim=1)
    heatmap = F.relu(heatmap)

    return heatmap.detach().cpu().numpy()

In [None]:
# Visualise a Grad-CAM heatmap for one sample image from the dataset.
sample_img_path = os.path.join(BASE_PATH, 'Forgery_Dataset/real/10564.jpg')
sample_img = cv2.imread(sample_img_path)
sample_face_tensor = None

if sample_img is None:
    # cv2.imread returns None for a missing/unreadable file; report that
    # instead of letting cv2.cvtColor raise on it.
    print(f"Could not read sample image: {sample_img_path}")
else:
    sample_img = cv2.cvtColor(sample_img, cv2.COLOR_BGR2RGB)
    sample_face_tensor = mtcnn(sample_img)

    if sample_face_tensor is not None and len(sample_face_tensor) > 0:
        # Use the first detected face, with a leading batch dimension
        sample_face = sample_face_tensor[0].unsqueeze(0).to(device)
        sample_dct = torch.fft.fft2(sample_face).real
        heatmap = grad_cam(deepfake_model, sample_dct, target_class=0)
        plt.imshow(heatmap[0], cmap='hot')
        plt.show()
    else:
        print("No face detected in the sample image for Grad-CAM.")

In [None]:
# Write the current state_dict of deepfake_model to /content/deepfake_model.pth
# (the same path the training cell saves to) and report the location.
torch.save(deepfake_model.state_dict(), '/content/deepfake_model.pth')
print("Model saved at /content/deepfake_model.pth")