In [None]:
import cv2
import torch
import torchvision.transforms as transforms
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import os

# Fail fast: this script is written for a GPU run, so stop immediately
# instead of silently training on the CPU.
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is not available. Ensure you have a compatible PyTorch version installed!")

# Let cuDNN benchmark and cache the fastest convolution algorithms (pays off
# when input sizes stay fixed), and make GPU 0 the current CUDA device.
torch.backends.cudnn.benchmark = True
torch.cuda.set_device(0)

# CUDA availability was verified above, so a CPU fallback here would be dead code.
device = torch.device("cuda")
print(f"Using device: {device}")

class Meso4(nn.Module):
    """Compact MesoNet-style CNN producing two-class logits for RGB images.

    Four conv blocks each end in a 2x2 max-pool, so the spatial size shrinks
    by a factor of 16 and the flattened feature vector has
    ``16 * (H // 16) * (W // 16)`` elements. ``fc1`` is sized from
    ``input_size`` accordingly.

    Args:
        input_size: Spatial size of the input images, either an int (square
            images) or a ``(height, width)`` pair. Defaults to 128, which gives
            the ``16 * 8 * 8`` feature size this model has always used, so
            existing checkpoints keep loading.

    Raises:
        ValueError: if either side of ``input_size`` is smaller than 16
            (the feature map would vanish after four 2x poolings).
    """

    def __init__(self, input_size=128):
        super(Meso4, self).__init__()
        if isinstance(input_size, int):
            height, width = input_size, input_size
        else:
            height, width = input_size
        if height < 16 or width < 16:
            raise ValueError(f"input_size must be at least 16 on each side, got {input_size!r}")

        self.conv1 = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2))
        self.conv2 = nn.Sequential(nn.Conv2d(8, 8, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2))
        self.conv3 = nn.Sequential(nn.Conv2d(8, 16, 5, padding=2), nn.ReLU(), nn.MaxPool2d(2))
        self.conv4 = nn.Sequential(nn.Conv2d(16, 16, 5, padding=2), nn.ReLU(), nn.MaxPool2d(2))
        self.flatten = nn.Flatten()
        # The paddings keep each conv size-preserving; each MaxPool2d(2) floors
        # the size by 2, and four nested floor-halvings equal size // 16.
        self.fc1 = nn.Linear(16 * (height // 16) * (width // 16), 16)
        # NOTE(review): there is no activation between fc1 and fc2, so the two
        # layers collapse to a single linear map. Adding e.g. nn.ReLU() would
        # keep the state_dict keys but change what existing checkpoints compute.
        self.fc2 = nn.Linear(16, 2)

    def forward(self, x):
        """Return unnormalized two-class scores of shape (N, 2).

        ``x`` must be a batch of 3-channel images whose spatial size matches
        the ``input_size`` given at construction.
        """
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.fc2(x)
        return x

class DeepFakeVideoDataset(Dataset):
    """Video-level dataset yielding a fixed number of frames per clip.

    Each item is ``(frames, label)``:
    - ``frames`` stacks ``num_frames`` evenly spaced frames of one video,
      each converted to a tensor by ``transform``.
    - ``label`` is 0 for paths in ``real_videos`` and 1 for ``fake_videos``.

    Tensors are returned on the CPU. Move them to the GPU in the training loop.
    Creating CUDA tensors inside ``__getitem__`` breaks multi-process
    ``DataLoader`` workers and is incompatible with ``pin_memory=True``.
    """

    def __init__(self, real_videos, fake_videos, transform=None, num_frames=5):
        self.real_videos = real_videos
        self.fake_videos = fake_videos
        self.transform = transform
        self.num_frames = num_frames
        # (path, label) pairs: 0 = real, 1 = fake.
        self.data = [(vid, 0) for vid in self.real_videos] + [(vid, 1) for vid in self.fake_videos]

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _selected_counts(frame_count, num_frames):
        """Map frame index -> number of times that frame is sampled.

        Indices are ``num_frames`` evenly spaced positions in
        ``[0, frame_count - 1]``. Clips shorter than ``num_frames`` repeat some
        indices. Counting them (instead of testing membership) keeps every
        sample at exactly ``num_frames`` frames. Returns ``{}`` when either
        argument is not positive.
        """
        if frame_count <= 0 or num_frames <= 0:
            return {}
        indices = np.linspace(0, frame_count - 1, num_frames, dtype=int)
        unique, counts = np.unique(indices, return_counts=True)
        return dict(zip(unique.tolist(), counts.tolist()))

    def _to_tensor(self, frame):
        """Convert one BGR OpenCV frame to a tensor."""
        rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        if self.transform:
            return self.transform(Image.fromarray(rgb))
        # No transform given: mimic torchvision's ToTensor (CHW float in [0, 1])
        # so torch.stack still works instead of failing on PIL images.
        return torch.from_numpy(np.ascontiguousarray(rgb)).permute(2, 0, 1).float().div(255.0)

    def _read_frames(self, video_path):
        """Decode the sampled frames of one video.

        Returns either exactly ``num_frames`` tensors or an empty list when no
        sampled frame could be decoded. If decoding stops before the last
        sampled index (e.g. when the reported frame count is inaccurate), the
        last decoded sample is repeated so every item has the same length.
        """
        cap = cv2.VideoCapture(video_path)
        frames = []
        try:
            plan = self._selected_counts(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), self.num_frames)
            if plan:
                # Stop decoding once the last sampled frame has been read.
                for i in range(max(plan) + 1):
                    ret, frame = cap.read()
                    if not ret:
                        break
                    repeats = plan.get(i, 0)
                    if repeats:
                        frames.extend([self._to_tensor(frame)] * repeats)
        finally:
            cap.release()
        if frames and len(frames) < self.num_frames:
            frames.extend([frames[-1]] * (self.num_frames - len(frames)))
        return frames

    def __getitem__(self, idx):
        """Return ``(frames, label)`` for item ``idx`` as CPU tensors.

        If a video yields no frames, a random other item is tried instead,
        with at most ``len(self)`` attempts in total.

        Raises:
            RuntimeError: if none of the attempted videos yields a frame.
        """
        attempts = max(1, len(self))
        for _ in range(attempts):
            video_path, label = self.data[idx]
            frames = self._read_frames(video_path)
            if frames:
                return torch.stack(frames), torch.tensor(label, dtype=torch.long)
            print(f"Skipping video {video_path} due to no valid frames.")
            idx = np.random.randint(0, len(self))
        raise RuntimeError(f"No readable video found after {attempts} attempts")

def resume_training(model, train_loader, epochs=10, lr=0.001, checkpoint_path="Meso2V1_1.13.pth"):
    model.to(device)  
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    if os.path.exists(checkpoint_path):
        print(f"Loading checkpoint: {checkpoint_path}")
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        print("Checkpoint loaded successfully!")

    for epoch in range(epochs):
        model.train()
        total_loss, correct_train, total_train = 0, 0, 0

        for frames, labels in train_loader:
            frames, labels = frames.to(device, non_blocking=True), labels.to(device, non_blocking=True)

            optimizer.zero_grad()
            frames = frames[:, 0, :, :, :]  

            outputs = model(frames)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            correct_train += (predicted == labels).sum().item()
            total_train += labels.size(0)

        train_acc = 100 * correct_train / total_train
        print(f"Epoch {epoch+1}/{epochs} - Loss: {total_loss/len(train_loader):.4f}, Acc: {train_acc:.2f}%")

        torch.save(model.state_dict(), checkpoint_path)
        print(f"Checkpoint saved: {checkpoint_path}")

# Dataset paths
real_videos_dir = "E:\\dataset\\Dataset Videos\\DFD_original sequences"
fake_videos_dir = "E:\\dataset\\Dataset Videos\\DFD_manipulated_sequences\\DFD_manipulated_sequences"
VIDEO_EXTENSIONS = (".mp4", ".avi", ".mov")
CHECKPOINT_PATH = "Meso2V1_1.13.pth"
# Meso4 applies four 2x2 max-pools and its fc1 expects 16 * 8 * 8 features,
# i.e. 128x128 inputs; the checkpoint loaded into it has that fc1 shape too.
# 64x64 frames would produce 16 * 4 * 4 features and fail inside fc1.
IMAGE_SIZE = 128


def _list_videos(directory):
    """Return sorted full paths of files in ``directory`` with a video extension (case-insensitive)."""
    return sorted(
        os.path.join(directory, name)
        for name in os.listdir(directory)
        if name.lower().endswith(VIDEO_EXTENSIONS)
    )


# The guard lets DataLoader worker processes (spawned on Windows) re-import
# this module without re-running training.
if __name__ == "__main__":
    real_videos = _list_videos(real_videos_dir)
    fake_videos = _list_videos(fake_videos_dir)

    transform = transforms.Compose([transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)), transforms.ToTensor()])
    dataset = DeepFakeVideoDataset(real_videos, fake_videos, transform=transform, num_frames=5)

    # NOTE(review): with num_workers > 0 on Windows, Dataset classes defined in
    # a notebook cell may not be importable by the worker processes. If
    # workers fail, move the classes into a .py module or use num_workers=0.
    train_loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4, pin_memory=True)

    model = Meso4().to(device)
    resume_training(model, train_loader, epochs=10, lr=0.001, checkpoint_path=CHECKPOINT_PATH)


Using device: cuda
Loading checkpoint: Meso2V1_1.13.pth
Checkpoint loaded successfully!
