In [None]:
# Install necessary libraries
!pip install mtcnn
!pip install torch torchvision torchaudio
!pip install torch-geometric
!pip install opencv-python

#packages
import zipfile, os, shutil, cv2
import torch
import numpy as np
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, random_split
from torch_geometric.data import Data, Batch
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.model_selection import train_test_split
from torch_geometric.nn import GCNConv, global_max_pool
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
import pandas as pd
import os
import cv2
import torch
from PIL import Image



In [None]:
import zipfile

# Unpack the uploaded frame archive into /content/
with zipfile.ZipFile("/content/split_frames_dataset.zip", "r") as archive:
    archive.extractall("/content/")

# Show what /content holds now so the extracted dataset folder can be confirmed
print(os.listdir("/content"))

['.config', 'split_frames_dataset.zip', 'split_frames_dataset', 'sample_data']


In [None]:
import glob

def load_image_paths(base_dir, label, pattern="*.jpg"):
    """Collect every frame file below ``base_dir`` and pair it with ``label``.

    Args:
        base_dir: Directory that is searched recursively (any nesting depth).
        label: Value attached to every returned path.
        pattern: Glob pattern applied to the file name. Defaults to ``"*.jpg"``.

    Returns:
        A list of ``(path, label)`` tuples sorted by path. ``glob`` gives no
        ordering guarantee, so the result is sorted to keep the sample order
        reproducible. The list is empty when nothing matches or ``base_dir``
        does not exist.
    """
    search = os.path.join(base_dir, "**", pattern)
    files = sorted(glob.glob(search, recursive=True))
    return [(f, label) for f in files]

# One (directory, label) entry per split/class: label 0 marks real frames, 1 marks fake ones.
# The "test" folders on disk feed the *_val_files lists.
real_train_files, fake_train_files, real_val_files, fake_val_files = [
    load_image_paths(folder, label)
    for folder, label in (
        ("/content/split_frames_dataset/train/real", 0),
        ("/content/split_frames_dataset/train/fake", 1),
        ("/content/split_frames_dataset/test/real", 0),
        ("/content/split_frames_dataset/test/fake", 1),
    )
]

# Report how many frames were found for each split/class
for caption, found in (
    ("✅ Train Real:", real_train_files),
    ("✅ Train Fake:", fake_train_files),
    ("✅ Test Real:", real_val_files),
    ("✅ Test Fake:", fake_val_files),
):
    print(caption, len(found))

✅ Train Real: 166
✅ Train Fake: 166
✅ Test Real: 42
✅ Test Fake: 42


In [None]:
def image_to_graph(image_tensor, k=9, patch_size=32, debug=True):
    """
    Converts an image tensor [3,H,W] into a k-nearest-neighbour graph where
    each node = flattened patch of size (3*patch_size*patch_size).

    Args:
        image_tensor: Tensor of shape [C, H, W].
        k: Number of most-similar patches each node is connected to. It is
            clamped to ``num_patches - 1`` so a node never links to itself.
        patch_size: Side length of the square, non-overlapping patches.
        debug: When True, print the reason whenever no graph can be built.

    Returns:
        A ``Data`` object with ``x`` of shape [num_patches, C*patch_size*patch_size]
        and ``edge_index`` of shape [2, num_patches * k_effective], or ``None``
        when the image is smaller than one patch, yields fewer than two
        patches, or no edge can be created.
    """
    C, H, W = image_tensor.shape
    if H < patch_size or W < patch_size:
        if debug:
            print(f"⚠️ Skipping tiny frame: {image_tensor.shape}")
        return None

    # Extract patches: [C, num_patches_h, num_patches_w, ps, ps]
    patches = image_tensor.unfold(1, patch_size, patch_size).unfold(2, patch_size, patch_size)
    patches = patches.permute(1, 2, 0, 3, 4).contiguous()   # [num_patches_h, num_patches_w, C, ps, ps]
    patches = patches.view(-1, C * patch_size * patch_size)  # [num_patches, C*ps*ps]

    num_nodes = patches.size(0)
    if num_nodes < 2:
        if debug:
            print(f"⚠️ Too few patches: {patches.shape}")
        return None

    # Build similarity graph. The diagonal is masked explicitly instead of
    # assuming "the best match is always the patch itself": an all-zero patch
    # has cosine similarity 0 with itself, so dropping only the top-ranked
    # entry could keep a self-loop and discard a real neighbour.
    similarity = np.array(cosine_similarity(patches.cpu().numpy()), dtype=np.float64)
    np.fill_diagonal(similarity, -np.inf)

    num_neighbors = min(k, num_nodes - 1)
    if num_neighbors <= 0:
        if debug:
            print("⚠️ No edges created")
        return None

    # Per row: the num_neighbors most similar other patches, least similar first
    neighbors = np.argsort(similarity, axis=1)[:, -num_neighbors:]
    sources = np.repeat(np.arange(num_nodes), num_neighbors)
    edge_index = torch.from_numpy(np.stack([sources, neighbors.reshape(-1)])).long()

    x = patches.float()  # [num_patches, C*ps*ps]

    return Data(x=x, edge_index=edge_index)

In [None]:
class FaceFusionDataset(Dataset):
    """Dataset yielding ``(image_tensor, graph)`` pairs for frame files.

    Args:
        file_list: Sequence of ``(image_path, label)`` tuples.

    Each item is the frame resized to 128x128 and converted with ``ToTensor``,
    together with the graph produced by ``image_to_graph``. The label is
    stored on the graph as ``graph.y``.
    """

    def __init__(self, file_list):
        self.file_list = file_list
        self.transform = transforms.Compose([
            transforms.Resize((128, 128)),
            transforms.ToTensor()
        ])

    def __len__(self):
        return len(self.file_list)

    def _try_load(self, position):
        """Return ``(image_tensor, graph)`` for ``file_list[position]``, or ``None`` if it is unusable."""
        img_path, label = self.file_list[position]

        image = cv2.imread(img_path)
        if image is None:  # the frame could not be read
            return None

        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image_tensor = self.transform(Image.fromarray(image))

        graph = image_to_graph(image_tensor)
        if graph is None:
            return None

        graph.y = torch.tensor(label, dtype=torch.long)
        return image_tensor, graph

    def __getitem__(self, idx):
        """Return the sample at ``idx``, falling back to the following entries if it is unusable.

        When a frame cannot be read or produces no graph, the next entry of
        ``file_list`` (wrapping around) is tried instead, and the returned
        label is the one of the entry that was actually loaded.

        Raises:
            IndexError: If ``idx`` is out of range for ``file_list``.
            RuntimeError: If no entry in ``file_list`` yields a valid graph.
        """
        total = len(self.file_list)

        # First lookup uses idx untouched so an out-of-range index still raises IndexError
        sample = self._try_load(idx)
        attempts = 1
        while sample is None and attempts < total:
            # Path and label are re-read for every attempt; retrying the same
            # broken file would never succeed.
            idx = (idx + 1) % total
            sample = self._try_load(idx)
            attempts += 1

        if sample is None:
            raise RuntimeError("❌ All images failed to generate a valid graph.")
        return sample


In [None]:
# Define Model Architectures
import torchvision.models as models

class CNNBranch(nn.Module):
  """Image branch: a ResNet-18 trunk followed by a linear projection to 600 features."""

  def __init__(self):
    super().__init__()
    backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
    # Keep every child module of the backbone except the last one
    trunk_layers = list(backbone.children())
    trunk_layers.pop()
    self.feature_extractor = nn.Sequential(*trunk_layers)
    self.fc = nn.Linear(512, 600)

  def forward(self,x):
    """Map an image batch to a [batch, 600] embedding."""
    features = self.feature_extractor(x)
    # Collapse everything after the batch dimension before the projection
    features = features.view(x.size(0), -1)
    return self.fc(features)

class HierarchicalGNNBranch(nn.Module):
  """Graph branch: four MLP -> GCNConv -> BatchNorm stages, then a global max pool.

  Stage widths are ``self.fds`` = [80, 160, 400, 600]; the first stage reads
  3072-dimensional node features. The output has one 600-d row per graph.
  """

  def __init__(self):
    super().__init__()
    self.fds = [80,160,400,600]

    # Each stage consumes the previous stage's width; the first one gets the raw 3072-d node features
    stage_inputs = [3072] + self.fds[:-1]
    stage_mlps = []
    for in_dim, fd in zip(stage_inputs, self.fds):
      stage_mlps.append(nn.Sequential(
          nn.Linear(in_dim, fd*4),
          nn.ReLU(),
          nn.Dropout(0.6),
          nn.Linear(fd*4, fd),
          nn.ReLU(),
          nn.Dropout(0.6)
      ))
    self.mlps = nn.ModuleList(stage_mlps)

    graph_convs = []
    for fd in self.fds:
      graph_convs.append(GCNConv(fd, fd))
    self.convs = nn.ModuleList(graph_convs)

    batch_norms = []
    for fd in self.fds:
      batch_norms.append(nn.BatchNorm1d(fd))
    self.norms = nn.ModuleList(batch_norms)

    self.dropout = nn.Dropout(0.6)

  def forward(self,data):
    """Run all stages on ``data.x`` / ``data.edge_index`` and max-pool the nodes per ``data.batch``."""
    h, edge_index, batch = data.x, data.edge_index, data.batch
    for stage in range(len(self.fds)):
      skip = h  # input of this stage, kept for the optional residual
      h = self.norms[stage](self.convs[stage](self.mlps[stage](h), edge_index))
      # The residual is only added from the second stage on and only when widths match.
      # With the widths in self.fds every stage changes the width, so it is never added.
      if stage > 0 and h.size(1) == skip.size(1):
        h = h + skip
      h = self.dropout(F.relu(h))
    return global_max_pool(h, batch)

class FuNetA(nn.Module):
  """Additive fusion: sums the CNN and GNN embeddings and maps the result to 2 logits."""

  def __init__(self):
    super().__init__()
    self.cnn = CNNBranch()
    self.gnn = HierarchicalGNNBranch()
    self.fc = nn.Linear(600,2)

  def forward(self,img,graph):
    image_features = self.cnn(img)
    graph_features = self.gnn(graph)
    fused = image_features + graph_features  # element-wise sum of the two 600-d embeddings
    return self.fc(fused)

class FuNetM(nn.Module):
  """Multiplicative fusion: multiplies the CNN and GNN embeddings element-wise, then maps to 2 logits."""

  def __init__(self):
    super().__init__()
    self.cnn = CNNBranch()
    self.gnn = HierarchicalGNNBranch()
    self.fc = nn.Linear(600,2)

  def forward(self,img,graph):
    image_features = self.cnn(img)
    graph_features = self.gnn(graph)
    fused = image_features * graph_features  # element-wise product of the two 600-d embeddings
    return self.fc(fused)


class FuNetC(nn.Module):
  """Concatenation fusion: joins the CNN and GNN embeddings into 1200 features, then maps to 2 logits."""

  def __init__(self):
    super().__init__()
    self.cnn = CNNBranch()
    self.gnn = HierarchicalGNNBranch()
    self.fc = nn.Linear(1200,2)

  def forward(self,img,graph):
    image_features = self.cnn(img)
    graph_features = self.gnn(graph)
    fused = torch.cat([image_features, graph_features], dim=1)  # CNN features first, GNN features second
    return self.fc(fused)

In [None]:
from torch_geometric.data import Batch
from torch.utils.data import DataLoader

# Each dataset holds the real frames followed by the fake frames of its split
train_files = real_train_files + fake_train_files
val_files = real_val_files + fake_val_files
train_set = FaceFusionDataset(train_files)
val_set = FaceFusionDataset(val_files)

# Custom collate function
def collate_fn(batch):
    """Combine ``(image_tensor, graph)`` samples into a stacked image tensor and one graph ``Batch``.

    ``None`` entries are dropped first; if nothing is left, ``(None, None)`` is returned.
    """
    kept = list(filter(lambda sample: sample is not None, batch))
    if not kept:
        return None, None

    image_list = []
    graph_list = []
    for sample in kept:
        image_list.append(sample[0])
        graph_list.append(sample[1])

    images = torch.stack(image_list)
    graphs = Batch.from_data_list(graph_list)
    return images, graphs


# Both loaders share the batch size and the collate function; only the training loader shuffles
loader_options = dict(batch_size=8, collate_fn=collate_fn)
train_loader = DataLoader(train_set, shuffle=True, **loader_options)
val_loader = DataLoader(val_set, shuffle=False, **loader_options)

# Dataset sizes, for a quick sanity check against the file counts
print("✅ Train samples:", len(train_set))
print("✅ Val samples:", len(val_set))


✅ Train samples: 332
✅ Val samples: 84


In [None]:
def train(model, loader, optimizer, device):
    """Run one training epoch and return the mean cross-entropy loss.

    Args:
        model: Module called as ``model(images, graphs)`` and returning class logits.
        loader: Iterable of ``(images, graphs)`` batches; ``graphs.y`` holds the targets.
            Batches whose ``images`` is ``None`` are skipped.
        optimizer: Optimizer stepped once per processed batch.
        device: Device the batch is moved to.

    Returns:
        The loss averaged over the batches that were actually processed
        (skipped batches do not dilute the mean), or NaN if none was processed.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    for images, graphs in loader:
        if images is None:
            continue
        images, graphs = images.to(device), graphs.to(device)
        optimizer.zero_grad()
        out = model(images, graphs)
        loss = F.cross_entropy(out, graphs.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        return float("nan")
    return total_loss / num_batches

def validate(model, loader, device):
    """Compute the mean cross-entropy loss over ``loader`` without updating the model.

    Args:
        model: Module called as ``model(images, graphs)`` and returning class logits.
        loader: Iterable of ``(images, graphs)`` batches; ``graphs.y`` holds the targets.
            Batches whose ``images`` is ``None`` are skipped, as in ``train``.
        device: Device the batch is moved to.

    Returns:
        The loss averaged over the processed batches, or NaN if none was processed.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for images, graphs in loader:
            if images is None:
                continue
            images, graphs = images.to(device), graphs.to(device)
            out = model(images, graphs)
            loss = F.cross_entropy(out, graphs.y)
            total_loss += loss.item()
            num_batches += 1

    if num_batches == 0:
        return float("nan")
    return total_loss / num_batches

def evaluate(model, loader, device):
    """Collect predictions over ``loader`` and return classification metrics.

    Args:
        model: Module called as ``model(images, graphs)`` and returning 2-class logits.
        loader: Iterable of ``(images, graphs)`` batches; ``graphs.y`` holds the targets.
            Batches whose ``images`` is ``None`` are skipped, as in ``train``.
        device: Device the batch is moved to.

    Returns:
        Dict with keys ``accuracy``, ``precision``, ``recall``, ``f1`` and ``auc``.
        The AUC is computed from the softmax probability of class index 1;
        the other metrics use the argmax prediction.
    """
    model.eval()
    y_true, y_pred, y_prob = [], [], []
    with torch.no_grad():
        for images, graphs in loader:
            if images is None:
                continue
            images, graphs = images.to(device), graphs.to(device)
            logits = model(images, graphs)
            probs = F.softmax(logits, dim=1)[:, 1]
            preds = torch.argmax(logits, dim=1)
            y_true.extend(graphs.y.cpu().tolist())
            y_pred.extend(preds.cpu().tolist())
            y_prob.extend(probs.cpu().tolist())
    return {
        'accuracy': accuracy_score(y_true, y_pred),
        'precision': precision_score(y_true, y_pred),
        'recall': recall_score(y_true, y_pred),
        'f1': f1_score(y_true, y_pred),
        'auc': roc_auc_score(y_true, y_prob)
    }

class EarlyStopping:
    """Track the validation loss and signal when it has stopped improving.

    Args:
        patience: Number of consecutive non-improving calls after which
            ``early_stop`` is set to True.
        delta: Minimum decrease of the loss that counts as an improvement.

    After each call, ``best_loss`` holds the lowest loss seen so far and
    ``best_model_state`` a snapshot of the model weights at that point
    (``None`` until the first improvement).
    """

    def __init__(self, patience=5, delta=0.0001):
        self.patience = patience
        self.delta = delta
        self.counter = 0
        self.best_loss = float('inf')
        self.early_stop = False
        self.best_model_state = None

    def __call__(self, val_loss, model):
        """Record ``val_loss``; snapshot ``model`` on improvement, otherwise count towards patience."""
        if val_loss < self.best_loss - self.delta:
            self.best_loss = val_loss
            self.counter = 0
            # state_dict() hands back tensors that share storage with the live
            # parameters, so they must be cloned; otherwise later optimizer
            # steps would silently overwrite the "best" snapshot.
            self.best_model_state = {
                key: value.detach().clone()
                for key, value in model.state_dict().items()
            }
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

In [None]:
#training loop + validation loop + results
# Number of epochs each model is trained for (early stopping may end a run sooner)
NUM_EPOCHS=2

# Display name -> model class. The classes are instantiated one at a time inside the loop.
# The dict is deliberately not called `models`: that name is the torchvision module
# CNNBranch looks up when it is constructed, and rebinding it breaks any re-run.
model_factories={
    'FuNet-A':FuNetA,
    #'FuNet-M':FuNetM,
    #'FuNet-C':FuNetC

}
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_histories={}
results={}
for name,model_factory in model_factories.items():
  model=model_factory().to(device)
  optimizer=torch.optim.AdamW(model.parameters(),lr=2e-4)
  early_stopper=EarlyStopping(patience=5)

  train_losses=[]
  val_losses=[]
  for epoch in range(1,NUM_EPOCHS+1):
    train_loss=train(model,train_loader,optimizer,device)
    val_loss=validate(model,val_loader,device)
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    early_stopper(val_loss,model)
    print(f"Epoch:{epoch} : Train Loss={train_loss:.4f}, Validation Loss={val_loss:.4f}")
    if early_stopper.early_stop:
      print("Early stopping")
      break

  # Everything below runs once per model, so every entry of model_factories
  # gets its own history and metrics (not just the last one).
  if early_stopper.best_model_state is not None:
    # Restore the weights recorded at the lowest validation loss
    model.load_state_dict(early_stopper.best_model_state)
  model_histories[name]={
      'train_losses':train_losses,
      'val_losses':val_losses
  }
  results[name]=evaluate(model,val_loader,device)





Epoch:1 : Train Loss=0.7040, Validation Loss=0.4438
Epoch:2 : Train Loss=0.2771, Validation Loss=0.0771


In [None]:

# Write only the weights (state_dict) of `model` to disk
checkpoint_path = "funet_a_full.pth"
torch.save(model.state_dict(), checkpoint_path)

print("Funet A model saved successfully!")

Funet A model saved successfully!


In [None]:

# Colab-only step: hand the saved weights file to the browser as a download
from google.colab import files
files.download("funet_a_full.pth")


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>