In [1]:
import os
import cv2
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
dataset_path = "datasets/FF++"
REAL_PATH = os.path.join(dataset_path, "real")
FAKE_PATH = os.path.join(dataset_path, "fake")
FRAME_COUNT = 10
FRAME_SIZE = (128, 128)
BATCH_SIZE = 8
MAX_VIDEOS = 700
NUM_WORKERS = 4
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

Using device: cuda


In [3]:
class VideoDataset(Dataset):
    def __init__(self, real_path, fake_path, frame_count=10, frame_size=(128,128), max_videos=700, transform=None):
        self.real_videos = sorted(os.listdir(real_path))[:max_videos]
        self.fake_videos = sorted(os.listdir(fake_path))[:max_videos]
        self.real_paths = [os.path.join(real_path, v) for v in self.real_videos]
        self.fake_paths = [os.path.join(fake_path, v) for v in self.fake_videos]
        self.all_paths = self.real_paths + self.fake_paths
        self.labels = [0]*len(self.real_paths) + [1]*len(self.fake_paths)
        self.frame_count = frame_count
        self.frame_size = frame_size
        self.transform = transform

    def __len__(self):
        return len(self.all_paths)

    def __getitem__(self, idx):
        video_path = self.all_paths[idx]
        label = self.labels[idx]
        
        # Frame extraction
        cap = cv2.VideoCapture(video_path)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        step = max(total_frames // self.frame_count, 1)
        frames = []
        for i in range(self.frame_count):
            cap.set(cv2.CAP_PROP_POS_FRAMES, i * step)
            ret, frame = cap.read()
            if not ret: break
            frame = cv2.resize(frame, self.frame_size)
            frames.append(frame)
        cap.release()
        
        # Convert to tensor
        video = torch.tensor(np.stack(frames), dtype=torch.float32).permute(0,3,1,2)/255.0
        
        # Apply transforms
        if self.transform:
            video = self.transform(video)
            
        return video, label

In [4]:
# Data augmentation
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
])

# Create datasets
full_dataset = VideoDataset(REAL_PATH, FAKE_PATH, 
                          frame_count=FRAME_COUNT, 
                          frame_size=FRAME_SIZE,
                          max_videos=MAX_VIDEOS,
                          transform=train_transform)

# Split data
train_size = int(0.7 * len(full_dataset))
val_size = int(0.15 * len(full_dataset))
test_size = len(full_dataset) - train_size - val_size

train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
    full_dataset, [train_size, val_size, test_size])

# Data loaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, 
                         shuffle=True, num_workers=NUM_WORKERS)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, 
                       shuffle=False, num_workers=NUM_WORKERS)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, 
                        shuffle=False, num_workers=NUM_WORKERS)

In [5]:
class VideoClassifier(nn.Module):
    def __init__(self, frame_size=128, num_classes=2):
        super().__init__()
        # Feature extraction backbone
        self.feature_extractor = models.resnet18(pretrained=True)
        self.feature_extractor.fc = nn.Identity()  # Remove final layer
        
        # Temporal modeling
        self.lstm = nn.LSTM(input_size=512, hidden_size=256, 
                          num_layers=2, batch_first=True, dropout=0.5)
        
        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        # x shape: (batch, frames, channels, height, width)
        batch_size, frames, channels, h, w = x.size()
        x = x.view(batch_size*frames, channels, h, w)
        features = self.feature_extractor(x)
        features = features.view(batch_size, frames, -1)
        lstm_out, _ = self.lstm(features)
        out = self.classifier(lstm_out[:, -1, :])  # Take last timestep
        return out



In [6]:
# Initialize model
model = VideoClassifier().to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3)



In [7]:
def train_epoch(model, dataloader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    for inputs, labels in tqdm(dataloader, desc="Training"):
        inputs = inputs.to(device)
        labels = labels.to(device)
        
        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Metrics
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
    
    return running_loss/len(dataloader), correct/total

def validate(model, dataloader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for inputs, labels in tqdm(dataloader, desc="Validation"):
            inputs = inputs.to(device)
            labels = labels.to(device)
            
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    
    return running_loss/len(dataloader), correct/total


In [8]:
# Training parameters
NUM_EPOCHS = 50
best_val_loss = float('inf')
patience = 5
trigger_times = 0

train_losses = []
val_losses = []
train_accs = []
val_accs = []

for epoch in range(NUM_EPOCHS):
    print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}")
    
    train_loss, train_acc = train_epoch(model, train_loader, 
                                      criterion, optimizer, DEVICE)
    val_loss, val_acc = validate(model, val_loader, criterion, DEVICE)
    
    scheduler.step(val_loss)
    
    # Early stopping
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), "best_model.pth")
        trigger_times = 0
    else:
        trigger_times += 1
        if trigger_times >= patience:
            print("Early stopping!")
            break
    
    # Logging
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_accs.append(train_acc)
    val_accs.append(val_acc)
    
    print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
    print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")


Epoch 1/50


Training: 100%|██████████| 35/35 [02:04<00:00,  3.55s/it]
Validation: 100%|██████████| 8/8 [00:27<00:00,  3.40s/it]


Train Loss: 0.6938 | Train Acc: 0.5107
Val Loss: 0.7009 | Val Acc: 0.4667

Epoch 2/50


Training: 100%|██████████| 35/35 [02:02<00:00,  3.50s/it]
Validation: 100%|██████████| 8/8 [00:24<00:00,  3.09s/it]


Train Loss: 0.6905 | Train Acc: 0.5179
Val Loss: 0.6957 | Val Acc: 0.5500

Epoch 3/50


Training: 100%|██████████| 35/35 [01:58<00:00,  3.38s/it]
Validation: 100%|██████████| 8/8 [00:24<00:00,  3.06s/it]


Train Loss: 0.6822 | Train Acc: 0.6071
Val Loss: 0.6863 | Val Acc: 0.5833

Epoch 4/50


Training:  46%|████▌     | 16/35 [01:03<01:15,  3.98s/it]


KeyboardInterrupt: 