In [1]:
import torch
import torch.nn as nn
from torchvision import models

class DeepFakeDetector(nn.Module):
    """Video-level classifier: per-frame ResNet-50 features fed to an LSTM.

    Every frame is encoded by a ResNet-50 whose classification head has been
    replaced with ``nn.Identity``. The resulting sequence of feature vectors is
    run through a single-layer LSTM, and the last hidden state is projected to
    one sigmoid-activated score per video.

    Args:
        pretrained (bool): Load the ImageNet weights for the backbone. Defaults
            to ``True``, which is what the class always did; pass ``False`` to
            build the network without downloading weights.
        hidden_size (int): Width of the LSTM hidden state. Defaults to 128.
    """

    def __init__(self, pretrained=True, hidden_size=128):
        super().__init__()
        # torchvision 0.13 deprecated `pretrained=` in favour of `weights=`.
        # Use the new API when it exists and fall back on older releases.
        if hasattr(models, "ResNet50_Weights"):
            weights = models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
            self.cnn = models.resnet50(weights=weights)
        else:
            self.cnn = models.resnet50(pretrained=pretrained)
        # Read the feature width from the backbone instead of hard-coding it
        # (2048 for ResNet-50), then drop the classification head.
        feature_dim = self.cnn.fc.in_features
        self.cnn.fc = nn.Identity()  # Remove final FC layer
        self.rnn = nn.LSTM(input_size=feature_dim, hidden_size=hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):
        """Score a batch of frame sequences.

        Args:
            x (torch.Tensor): Frames shaped (batch_size, seq_len, C, H, W).

        Returns:
            torch.Tensor: Sigmoid outputs shaped (batch_size, 1).

        Raises:
            ValueError: If ``x`` is not a 5-D tensor.
        """
        if x.dim() != 5:
            raise ValueError(
                f"expected input of shape (batch_size, seq_len, C, H, W), got {tuple(x.shape)}"
            )
        # Encode one time step at a time; each call sees (batch_size, C, H, W).
        cnn_features = torch.stack(
            [self.cnn(x[:, t]) for t in range(x.size(1))], dim=1
        )  # (batch_size, seq_len, feature_dim)
        # Only the final hidden state of the (last) LSTM layer is used.
        _, (hidden, _) = self.rnn(cnn_features)
        output = self.fc(hidden[-1])
        return torch.sigmoid(output)


In [34]:
import os
import pandas as pd
import torch
import torchvision
from torch.utils.data import Dataset
from PIL import Image


class SingleVideoDataset(Dataset):
    """A single video, stored as a folder of frame images, exposed as one sample."""

    def __init__(self, video_folder, label, transform=None):
        """
        Args:
            video_folder (str): Path to the folder containing frames of a single video.
            label (int/str): Label for the video.
            transform (callable, optional): Optional transform to be applied on each frame.
        """
        self.video_folder = video_folder
        self.label = label
        # Natural (numeric-aware) ordering: a plain lexicographic sort would put
        # "frame_10" before "frame_2" and scramble the temporal order whenever
        # frame numbers are not zero-padded. Zero-padded names sort the same
        # either way.
        self.frame_paths = [
            os.path.join(video_folder, name)
            for name in sorted(os.listdir(video_folder), key=self._natural_key)
        ]
        self.transform = transform

    @staticmethod
    def _natural_key(name):
        """Sort key that compares runs of ASCII digits in ``name`` as integers."""
        import re  # local import: the notebook's import cell does not provide `re`

        # re.split with a capture group alternates text / digits / text / ...,
        # so odd positions are always digit runs and the keys stay comparable.
        return [
            int(token) if position % 2 else token
            for position, token in enumerate(re.split(r"([0-9]+)", name))
        ]

    def __len__(self):
        # Return 1 because each instance is a whole video (sequence of frames)
        return 1

    def __getitem__(self, idx):
        """Load every frame of the video.

        Args:
            idx (int): Must address the single sample (0, or -1).

        Returns:
            tuple: ``(frames, label)`` where ``frames`` is the per-frame results
            stacked along a new leading dimension (T, C, H, W).

        Raises:
            IndexError: If ``idx`` does not address the single sample. Without
                this, ``for item in dataset`` would never terminate.
            RuntimeError: If the video folder contains no frames.
        """
        if idx not in (0, -1):
            raise IndexError(f"SingleVideoDataset holds exactly one item, got index {idx}")
        if not self.frame_paths:
            raise RuntimeError(f"No frames found in {self.video_folder}")
        frames = []
        for img_path in self.frame_paths:
            # convert() returns a fully loaded copy, so the file handle can be
            # released as soon as the with-block exits.
            with Image.open(img_path) as raw_image:
                image = raw_image.convert('RGB')
            if self.transform:
                image = self.transform(image)
            frames.append(image)
        # Stack frames into a single tensor (T, C, H, W); this requires the
        # transform to have produced tensors of identical shape.
        frames = torch.stack(frames, dim=0)
        return frames, self.label
    

class AllVideosDataset(Dataset):
    """Dataset of videos, one item per video listed in a metadata CSV."""

    def __init__(self, root_dir, metadata_path, transform=None):
        """
        Args:
            root_dir (str): Directory with all video folders.
            metadata_path (str): Path to the CSV file with video_id and label.
            transform (callable, optional): Optional transform to be applied on each frame.
        """
        self.root_dir = root_dir
        self.metadata = pd.read_csv(metadata_path)
        self.transform = transform
        self.video_datasets = []
        no_data_videos = []
        # Walk the two columns directly rather than using iterrows(): iterrows()
        # builds one Series per row and upcasts mixed numeric columns to a
        # common dtype, so an integer video_id can come back as 12.0 and yield
        # the non-existent folder "images/12.0".
        for video_id, label in zip(self.metadata['video_id'], self.metadata['Label']):
            video_folder = os.path.join(root_dir, "images", f"{video_id}")
            # isdir (not exists): a stray file with the video's name would make
            # os.listdir raise NotADirectoryError.
            if os.path.isdir(video_folder) and os.listdir(video_folder):
                self.video_datasets.append(SingleVideoDataset(video_folder, label, transform))
            else:
                no_data_videos.append(video_id)
        if no_data_videos:
            print(f"Warning: No data found for {len(no_data_videos)} videos.")

    def __len__(self):
        # Each item is a whole video
        return len(self.video_datasets)

    def __getitem__(self, idx):
        # Return the entire sequence for the idx-th video.
        # [0] selects the only item of that SingleVideoDataset, i.e. the
        # (frames, label) pair it produces.
        return self.video_datasets[idx][0]

In [None]:
# Per-frame preprocessing: resize to 224x224, convert to a tensor, then
# normalise each channel with the mean/std listed below.
transform = torchvision.transforms.Compose(
    transforms=[
        torchvision.transforms.Resize(size=(224, 224)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# One dataset item per video; the loader shuffles and groups 32 videos per batch.
dataset = AllVideosDataset(
    root_dir='data',
    metadata_path='data/metadata.csv',
    transform=transform,
)
dataloader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=32,
    shuffle=True,
)

print(f'Total samples in dataset: {len(dataset)}')

# Pull a single batch through the loader and show the shape of its first element.
first_batch = next(iter(dataloader))
print(first_batch[0].shape)


Total samples in dataset: 176


In [None]:
# training setup
model = DeepFakeDetector()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = torch.nn.BCELoss()

num_epochs = 10
# Freshly built modules are already in training mode; setting it explicitly
# keeps the cell correct if it is re-run after the model was switched to eval().
model.train()
for epoch in range(num_epochs):
    for images, labels in dataloader:
        optimizer.zero_grad()
        outputs = model(images)
        # Drop only the trailing singleton dimension. A bare squeeze() would
        # also remove the batch dimension when a batch holds a single video,
        # and BCELoss would then reject the input/target shape mismatch.
        loss = criterion(outputs.squeeze(-1), labels.float())
        loss.backward()
        optimizer.step()




KeyboardInterrupt: 