In [1]:
# step 0.1 Face Swap Classifier Training Notebook with Frame Extraction

import os
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt





In [2]:
# step 0.2


import os
import cv2
import mediapipe as mp

def extract_frames_from_videos(video_folder, output_folder, split, class_name, frames_per_video=100,
                               draw_box=True, face_size=(112, 112)):
    """
    Extract up to `frames_per_video` face crops from each video in `video_folder`.

    Frames are sampled with a fixed stride of max(1, total_frames // frames_per_video).
    MediaPipe face detection runs on each sampled frame, and the first detection it
    returns is cropped. The crop is optionally outlined with a green border, resized to
    `face_size` and written as JPEG to:
        output_folder/split/class_name/<video_stem>_face<frame_index>.jpg

    Sampled frames with no detection (or an empty crop) are skipped, so a video can
    yield fewer than `frames_per_video` images. Crops that already exist on disk count
    toward the per-video quota, so re-running the function does not append extra frames.

    Args:
        video_folder: directory containing .mp4/.avi/.mov files (not searched recursively).
        output_folder: root directory for the extracted faces.
        split: sub-directory for the dataset split (e.g. "train_combined").
        class_name: sub-directory for the class label (e.g. "fake", "real").
        frames_per_video: maximum number of crops kept per video; must be positive.
        draw_box: if True (original behaviour), draw a 2-px green border on each crop.
            NOTE(review): the border becomes part of every saved image.
        face_size: (width, height) passed to cv2.resize.

    Raises:
        ValueError: if `frames_per_video` is not positive.
        OSError: if cv2.imwrite fails to write a crop.
    """
    # Validate before touching MediaPipe or the filesystem.
    if frames_per_video <= 0:
        raise ValueError(f"frames_per_video must be positive, got {frames_per_video}")

    output_class_dir = os.path.join(output_folder, split, class_name)
    os.makedirs(output_class_dir, exist_ok=True)

    face_detection = mp.solutions.face_detection.FaceDetection(
        model_selection=0, min_detection_confidence=0.5
    )
    try:
        # sorted() makes the processing order (and the printed log) deterministic.
        for video_name in sorted(os.listdir(video_folder)):
            if not video_name.lower().endswith(('.mp4', '.avi', '.mov')):
                continue

            written, existing = _extract_faces_from_video(
                face_detection,
                os.path.join(video_folder, video_name),
                output_class_dir,
                frames_per_video,
                draw_box,
                face_size,
            )
            message = f"Saved {written} face-cropped frames from {video_name} to {output_class_dir}"
            if existing:
                message += f" ({existing} already present)"
            print(message)
    finally:
        # Release the detector even if a video fails mid-way.
        face_detection.close()


def _extract_faces_from_video(face_detection, video_path, output_class_dir, frames_per_video,
                              draw_box, face_size):
    """Sample one video and write its face crops; return (newly_written, already_present)."""
    stem = os.path.splitext(os.path.basename(video_path))[0]
    cap = cv2.VideoCapture(video_path)
    try:
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # max() keeps the stride >= 1 for short videos or an unreported frame count.
        interval = max(1, total_frames // frames_per_video)

        written = 0
        existing = 0
        frame_idx = 0
        while cap.isOpened() and written + existing < frames_per_video:
            ret, frame = cap.read()
            if not ret:
                break

            if frame_idx % interval == 0:
                frame_path = os.path.join(output_class_dir, f"{stem}_face{frame_idx}.jpg")
                if os.path.exists(frame_path):
                    # Produced by an earlier run: counts toward the quota, not re-written.
                    existing += 1
                else:
                    face = _crop_first_face(face_detection, frame)
                    if face is not None:
                        if draw_box:
                            cv2.rectangle(face, (0, 0), (face.shape[1], face.shape[0]), (0, 255, 0), 2)
                        if not cv2.imwrite(frame_path, cv2.resize(face, face_size)):
                            raise OSError(f"cv2.imwrite failed for {frame_path}")
                        written += 1

            frame_idx += 1
        return written, existing
    finally:
        cap.release()


def _crop_first_face(face_detection, frame):
    """Return the BGR crop of the first face MediaPipe reports in `frame`, or None."""
    results = face_detection.process(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    if not results.detections:
        return None

    box = results.detections[0].location_data.relative_bounding_box
    h, w = frame.shape[:2]
    # Convert relative coordinates to pixels and clamp the box to the frame.
    x1 = max(0, int(box.xmin * w))
    y1 = max(0, int(box.ymin * h))
    x2 = min(w, int((box.xmin + box.width) * w))
    y2 = min(h, int((box.ymin + box.height) * h))

    crop = frame[y1:y2, x1:x2]
    return crop if crop.size else None


# --------- CALLING CODE ---------
# Each source folder data/<split>/<class> is written to data/data_faces/<split>_combined/<class>.
#
# Train/val extraction is disabled here (presumably already run). Re-enable to redo it:
# extract_frames_from_videos("data/train/fake", "data/data_faces", "train_combined", "fake", frames_per_video=100)
# extract_frames_from_videos("data/train/real", "data/data_faces", "train_combined", "real", frames_per_video=100)
# extract_frames_from_videos("data/val/fake", "data/data_faces", "val_combined", "fake", frames_per_video=100)
# extract_frames_from_videos("data/val/real", "data/data_faces", "val_combined", "real", frames_per_video=100)

# Extract the test split, one class folder at a time (fake first, then real).
for _class_name in ("fake", "real"):
    extract_frames_from_videos(
        f"data/test/{_class_name}",
        "data/data_faces",
        "test_combined",
        _class_name,
        frames_per_video=100,
    )




Saved 100 face-cropped frames from 917_924.mp4 to data/data_faces\test_combined\fake
Saved 100 face-cropped frames from 918_934.mp4 to data/data_faces\test_combined\fake
Saved 100 face-cropped frames from 919_015.mp4 to data/data_faces\test_combined\fake
Saved 100 face-cropped frames from 920_811.mp4 to data/data_faces\test_combined\fake
Saved 28 face-cropped frames from 923_023.mp4 to data/data_faces\test_combined\fake
Saved 100 face-cropped frames from 924_917.mp4 to data/data_faces\test_combined\fake
Saved 100 face-cropped frames from 925_933.mp4 to data/data_faces\test_combined\fake
Saved 100 face-cropped frames from 927_912.mp4 to data/data_faces\test_combined\fake
Saved 100 face-cropped frames from 928_160.mp4 to data/data_faces\test_combined\fake
Saved 100 face-cropped frames from 929_962.mp4 to data/data_faces\test_combined\fake
Saved 100 face-cropped frames from 932_384.mp4 to data/data_faces\test_combined\fake
Saved 100 face-cropped frames from 933_925.mp4 to data/data_faces\

In [None]:

# ---------- STEP 1: Data transforms ----------
# Standard ImageNet per-channel statistics, matching the ImageNet-pretrained backbone.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Upsample the saved face crops to 224x224, convert to a float tensor, then normalise.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)

In [None]:
# ---------- STEP 2: Load datasets ----------
# ImageFolder expects one sub-folder per class and assigns class indices in sorted
# folder-name order. With the "fake"/"real" folders made by the extraction step,
# that gives fake -> 0 and real -> 1.
data_faces_dir = "./data/data_faces"

train_dataset = datasets.ImageFolder(
    root=os.path.join(data_faces_dir, "train_combined"),
    transform=transform
)

val_dataset = datasets.ImageFolder(
    root=os.path.join(data_faces_dir, "val_combined"),
    transform=transform
)

# The single sigmoid output is trained with BCE against these indices, so there must be
# exactly two classes, and both splits must agree on the mapping. A missing class folder
# would otherwise silently shift the labels.
assert len(train_dataset.class_to_idx) == 2, f"expected 2 classes, got {train_dataset.class_to_idx}"
assert train_dataset.class_to_idx == val_dataset.class_to_idx, (
    f"class mapping mismatch: train={train_dataset.class_to_idx} val={val_dataset.class_to_idx}"
)
print("Class mapping:", train_dataset.class_to_idx)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)


In [None]:
# Inspect one training batch: shapes, labels, and the first image.
import matplotlib.pyplot as plt
import numpy as np


def imshow(img, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Display a normalised CHW image tensor.

    Undoes transforms.Normalize(mean, std) per channel. The defaults are the ImageNet
    statistics used by `transform` in STEP 1. The result is clipped to [0, 1] so
    matplotlib does not warn about out-of-range values.
    """
    npimg = img.numpy()
    npimg = npimg * np.asarray(std).reshape(-1, 1, 1) + np.asarray(mean).reshape(-1, 1, 1)
    plt.imshow(np.clip(np.transpose(npimg, (1, 2, 0)), 0.0, 1.0))
    plt.show()


# Only the first batch is needed, so take it directly instead of looping and breaking.
inputs, labels = next(iter(train_loader))
print("Batch index:", 0)
print("Input batch shape:", inputs.shape)
print("Label batch shape:", labels.shape)
print("Sample labels:", labels)

imshow(inputs[0].cpu())  # Show the first image in the batch
print("Label of the first image:", labels[0])

In [None]:
# ---------- STEP 3: Model Definition ----------
class FaceSwapDetector(nn.Module):
    """Pretrained ResNet-18 with a single-output head followed by a sigmoid.

    forward() returns one value in (0, 1) per input image. The backbone is stored as
    `self.model`, so state_dict keys are prefixed with "model.". Keep that attribute
    name so saved checkpoints stay loadable.
    """

    def __init__(self):
        super().__init__()
        backbone = models.resnet18(pretrained=True)
        # Swap the original classification layer for a single-logit linear head.
        backbone.fc = nn.Linear(backbone.fc.in_features, 1)
        self.model = backbone

    def forward(self, x):
        logits = self.model(x)
        return torch.sigmoid(logits)


model = FaceSwapDetector()

In [None]:

# ---------- STEP 4: Loss, optimizer, and device ----------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Move the model first, as the torch.optim docs recommend, so the optimizer is
# constructed over the parameters already on the target device.
model.to(device)

# BCELoss takes probabilities, which matches the sigmoid at the end of FaceSwapDetector.
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# nn.Module has no Keras-style .summary(); print the module tree and parameter counts.
n_params = sum(p.numel() for p in model.parameters())
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(model)
print(f"Parameters: {n_params:,} total, {n_trainable:,} trainable")


In [None]:

# ---------- STEP 5: Training loop ----------
# Records per-epoch training loss, training accuracy and validation accuracy (both in %).

NUM_EPOCHS = 10
DECISION_THRESHOLD = 0.5  # a sigmoid output above this is predicted as class index 1


def _train_one_epoch(model, loader, criterion, optimizer, device, desc=None):
    """Run one optimisation pass over `loader`.

    Returns (mean_loss, accuracy_percent). The loss is averaged per sample, so a smaller
    final batch is weighted correctly.

    Raises:
        ValueError: if `loader` yields no samples.
    """
    model.train()
    loss_sum = 0.0
    correct = 0
    total = 0
    for inputs, labels in tqdm(loader, desc=desc):
        inputs, labels = inputs.to(device), labels.to(device).float()
        optimizer.zero_grad()
        # view(-1) rather than squeeze(): squeeze() turns a size-1 batch into a 0-d
        # tensor, and BCELoss rejects that against the 1-d label tensor.
        outputs = model(inputs).view(-1)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        loss_sum += loss.item() * batch_size  # criterion averages over the batch
        predicted = (outputs.detach() > DECISION_THRESHOLD).float()
        correct += (predicted == labels).sum().item()
        total += batch_size

    if total == 0:
        raise ValueError("training loader yielded no samples")
    return loss_sum / total, 100 * correct / total


def _evaluate_accuracy(model, loader, device):
    """Return classification accuracy (%) of `model` over `loader`, without gradients.

    Raises:
        ValueError: if `loader` yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device).float()
            outputs = model(inputs).view(-1)
            predicted = (outputs > DECISION_THRESHOLD).float()
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

    if total == 0:
        raise ValueError("validation loader yielded no samples")
    return 100 * correct / total


train_losses = []
train_accuracies = []
val_accuracies = []

for epoch in range(NUM_EPOCHS):
    epoch_loss, epoch_train_acc = _train_one_epoch(
        model, train_loader, criterion, optimizer, device, desc=f"Epoch {epoch+1}"
    )
    train_losses.append(epoch_loss)
    train_accuracies.append(epoch_train_acc)
    print(f"Epoch {epoch+1}, Loss: {epoch_loss:.4f}, Train Accuracy: {epoch_train_acc:.2f}%")

    val_acc = _evaluate_accuracy(model, val_loader, device)
    val_accuracies.append(val_acc)
    print(f"Validation Accuracy: {val_acc:.2f}%")






In [None]:
# Per-epoch mean training loss, shown as this cell's output.
list(train_losses)


In [None]:

# ---------- STEP 6: Plot validation accuracy ----------
# Epochs are numbered from 1 to match the "Epoch N" lines printed during training.
epochs = range(1, len(val_accuracies) + 1)
fig, ax = plt.subplots()
ax.plot(epochs, val_accuracies, label='Validation Accuracy')
ax.set_xlabel("Epoch")
ax.set_ylabel("Accuracy (%)")
ax.set_title("Validation Accuracy")
ax.legend()
plt.show()




In [None]:
# Training accuracy per epoch; epochs are numbered from 1 to match the training log.
epochs = range(1, len(train_accuracies) + 1)
fig, ax = plt.subplots()
ax.plot(epochs, train_accuracies, label='Train Accuracy', color='green')
ax.set_xlabel("Epoch")
ax.set_ylabel("Accuracy (%)")
ax.set_title("Training Accuracy Over Epochs")
ax.legend()
plt.show()


In [None]:
# ---------- STEP 7: Plot training loss only ----------
# Epochs are numbered from 1 to match the training log.
epochs = range(1, len(train_losses) + 1)
fig, ax = plt.subplots()
ax.plot(epochs, train_losses, label='Train Loss', color='blue')
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_title("Training Loss Over Epochs")
ax.legend()
plt.show()


In [None]:
# ---------- Combined Train and Validation Accuracy Plot ----------
# Both series hold percentages, recorded once per epoch and numbered from 1.
epochs = range(1, len(train_accuracies) + 1)
fig, ax = plt.subplots()
ax.plot(epochs, train_accuracies, label='Train Accuracy', color='blue')
ax.plot(range(1, len(val_accuracies) + 1), val_accuracies, label='Validation Accuracy', color='orange')

ax.set_xlabel("Epoch")
ax.set_ylabel("Accuracy (%)")
ax.set_title("Training and Validation Accuracy Over Epochs")
ax.legend()
plt.show()


In [None]:
# Save only the learned weights; reloading requires rebuilding FaceSwapDetector first.
state_dict_path = "face_swap_model.pth"
torch.save(model.state_dict(), state_dict_path)


In [None]:
# Pickle the whole module (architecture + weights). Unpickling it later needs the
# FaceSwapDetector class to be importable under the same name.
full_model_path = "full_face_swap_model.pth"
torch.save(model, full_model_path)


In [None]:
# ---------- Reload the trained model ----------
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Preferred: rebuild the architecture, then load only the saved weights (state_dict).
model = FaceSwapDetector()
model.load_state_dict(torch.load("face_swap_model.pth", map_location=device))
model.to(device)
model.eval()

# Alternative: load the fully pickled model saved with torch.save(model, ...).
# NOTE: this unpickles arbitrary Python objects; only load files you trust.
# Recent torch versions default torch.load to weights_only=True, which rejects whole
# pickled modules, so weights_only=False must then be passed explicitly.
# model = torch.load("full_face_swap_model.pth", map_location=device, weights_only=False)
# model.eval()
