In [1]:
import os
import cv2
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import timm
from tqdm import tqdm


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# ⚙️ Configuration
DATASET_PATH = "C:/Academics/Project/PBL/temp_dataset"
IMG_SIZE = 224
TARGET_FRAME_COUNT = 32
BATCH_SIZE = 4
FREEZE_BASE = True
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ✅ Dataset Check
if os.path.exists(DATASET_PATH):
    print("✅ Dataset Found!")
else:
    print("❌ Dataset Not Found! Check the path.")


✅ Dataset Found!


In [3]:
# 📁 Get video paths
real_videos = [os.path.join(DATASET_PATH, "real", f) for f in os.listdir(os.path.join(DATASET_PATH, "real")) if f.endswith(".mp4")]
fake_videos = [os.path.join(DATASET_PATH, "fake", f) for f in os.listdir(os.path.join(DATASET_PATH, "fake")) if f.endswith(".mp4")]

print(f"Found {len(real_videos)} real videos and {len(fake_videos)} fake videos.")


Found 10 real videos and 10 fake videos.


In [4]:
# 🧼 Transform
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


In [5]:
# 🎞️ Frame Sampler
def sample_frames(frames, target=TARGET_FRAME_COUNT):
    total = len(frames)
    if total >= target:
        indices = sorted(random.sample(range(total), target))
    else:
        indices = list(range(total)) + [total - 1] * (target - total)
    return [frames[i] for i in indices]


In [6]:
# 📹 Video Frame Extractor
def extract_frames(video_path):
    cap = cv2.VideoCapture(video_path)
    frames = []
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frames.append(frame)
    cap.release()
    return sample_frames(frames)


In [7]:
# 📦 Custom Dataset
class DeepfakeDataset(Dataset):
    def __init__(self, video_paths, labels, transform):
        self.video_paths = video_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.video_paths)

    def __getitem__(self, idx):
        frames = extract_frames(self.video_paths[idx])
        frames = [self.transform(frame) for frame in frames]
        video_tensor = torch.stack(frames)  # Shape: [32, 3, 224, 224]
        return video_tensor, torch.tensor(self.labels[idx], dtype=torch.long)


In [8]:
# 🏷️ Labels
all_videos = real_videos + fake_videos
all_labels = [0] * len(real_videos) + [1] * len(fake_videos)

# 🧪 Create Dataset & DataLoader
dataset = DeepfakeDataset(all_videos, all_labels, transform)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# 🔍 Peek at one batch
for batch in dataloader:
    video_batch, label_batch = batch
    print("Video batch shape:", video_batch.shape)  # Expected: [B, 32, 3, 224, 224]
    print("Label batch:", label_batch)
    break


Video batch shape: torch.Size([4, 32, 3, 224, 224])
Label batch: tensor([1, 1, 0, 1])


In [9]:
# 🧠 Load Pretrained Swin Transformer (Tiny version)
model = timm.create_model("swin_tiny_patch4_window7_224", pretrained=True)

# 🛠️ Modify the classification head
in_features = model.head.in_features
model.head = nn.Linear(in_features, 2)  # 2 classes: real & fake

# 🧊 Freeze base model if desired
if FREEZE_BASE:
    for name, param in model.named_parameters():
        if not name.startswith("head"):
            param.requires_grad = False

# 🚀 Move to device
model = model.to(DEVICE)

# 🔍 Print to verify
print(model)


SwinTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
    (norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
  )
  (layers): Sequential(
    (0): SwinTransformerStage(
      (downsample): Identity()
      (blocks): Sequential(
        (0): SwinTransformerBlock(
          (norm1): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            (qkv): Linear(in_features=96, out_features=288, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=96, out_features=96, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
            (softmax): Softmax(dim=-1)
          )
          (drop_path1): Identity()
          (norm2): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=96, out_features=384, bias=True)
            (act): GELU(approximate='none')
            (drop1): 

In [10]:
# 🧮 Loss and Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)

# 🏋️ Training Loop
def train_one_epoch(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    for video_batch, labels in tqdm(dataloader):
        video_batch = video_batch.to(device)  # [B, 32, 3, 224, 224]
        labels = labels.to(device)

        # Flatten batch: [B * 32, 3, 224, 224]
        B, F, C, H, W = video_batch.shape
        frames = video_batch.view(B * F, C, H, W)

        # Forward pass on all frames
        frame_preds = model(frames)  # [B*F, 2]
        frame_preds = frame_preds.view(B, F, -1)  # [B, 32, 2]

        # Average predictions across frames
        video_preds = frame_preds.mean(dim=1)  # [B, 2]

        # Loss + Backward
        loss = criterion(video_preds, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Stats
        total_loss += loss.item()
        _, predicted = torch.max(video_preds, 1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)

    avg_loss = total_loss / len(dataloader)
    accuracy = correct / total
    return avg_loss, accuracy


In [11]:
# 🏁 Run Training
epochs = 1  # Set to higher (e.g., 5–10) later
for epoch in range(epochs):
    print(f"\n🔁 Epoch {epoch+1}/{epochs}")
    train_loss, train_acc = train_one_epoch(model, dataloader, optimizer, criterion, DEVICE)
    print(f"✅ Train Loss: {train_loss:.4f} | Accuracy: {train_acc:.4f}")



🔁 Epoch 1/1


100%|██████████| 5/5 [04:58<00:00, 59.72s/it]

✅ Train Loss: 4.2256 | Accuracy: 0.0000



