# ResNet-50 Feature Extraction


In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet50
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms


In [2]:
# Pick the compute device once: CUDA when PyTorch reports a usable GPU,
# otherwise the CPU. The bare name on the last line displays the choice.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

In [3]:
from google.colab import files

# Opens Colab's browser file picker and blocks until a selection is made.
# The recorded output of this cell shows resnet50_deepfake_frame.pth being
# saved under that same name; the model-loading cell reads it by that
# relative path.
uploaded = files.upload()


Saving resnet50_deepfake_frame.pth to resnet50_deepfake_frame.pth


In [15]:
from torch.utils.data import Dataset
from PIL import Image
import os
import torch

class VideoFramesDataset(Dataset):
    """Dataset that returns all frames of one video folder as a single tensor.

    Expected layout: ``root_dir/<class_name>/<video_folder>/<frame image>``.
    A video is labelled 1 when its *class directory name* contains "fake"
    (case-insensitive) and 0 otherwise. The label is taken from the class
    directory only, so a root path that happens to contain "fake"
    (e.g. ``deepfake_data/``) does not turn every sample into a fake.

    Args:
        root_dir: Directory holding one sub-directory per class.
        transform: Optional callable applied to each RGB PIL image. The
            results are combined with ``torch.stack``, so it must return
            tensors of identical shape.

    Attributes:
        video_dirs: Paths of the video folders, in sorted (deterministic) order.
        labels: Integer label for each entry of ``video_dirs``.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.video_dirs = []
        self.labels = []

        # os.listdir returns entries in arbitrary order; sort so that sample
        # indices are reproducible between runs and machines.
        for label_name in sorted(os.listdir(root_dir)):  # fake / real
            class_dir = os.path.join(root_dir, label_name)
            if not os.path.isdir(class_dir):
                continue
            label = 1 if "fake" in label_name.lower() else 0
            for video_folder in sorted(os.listdir(class_dir)):  # 000_003, etc
                video_path = os.path.join(class_dir, video_folder)
                if os.path.isdir(video_path):
                    self.video_dirs.append(video_path)
                    self.labels.append(label)

    def __len__(self):
        """Return the number of video folders found under ``root_dir``."""
        return len(self.video_dirs)

    def __getitem__(self, idx):
        """Load every readable frame of video ``idx``.

        Returns:
            ``(frames, label)`` where ``frames`` is the stacked frames in
            filename order, shape (T, C, H, W), and ``label`` is 0 or 1.

        Raises:
            ValueError: If the folder contains no loadable image.
        """
        video_path = self.video_dirs[idx]
        label = self.labels[idx]

        frames = []
        for file in sorted(os.listdir(video_path)):
            if file.lower().endswith((".jpg", ".jpeg", ".png")):
                img_path = os.path.join(video_path, file)
                try:
                    # The context manager closes the underlying file handle
                    # as soon as the pixel data has been converted.
                    with Image.open(img_path) as raw:
                        img = raw.convert("RGB")
                    if self.transform:
                        img = self.transform(img)
                    frames.append(img)
                except Exception as e:
                    # Best effort: one unreadable frame should not discard
                    # the whole video, so report it and carry on.
                    print(f"❌ Could not load {img_path}: {e}")

        if len(frames) == 0:
            raise ValueError(f"No valid frames found in {video_path}")

        frames = torch.stack(frames, dim=0)  # (T, C, H, W)
        return frames, label


In [16]:
# Per-frame preprocessing: resize to 224x224, then convert the PIL image to
# a tensor.
# NOTE(review): there is no Normalize step here - confirm this matches the
# preprocessing used when the checkpoint was trained.
transform = transforms.Compose(
    [transforms.Resize((224, 224)), transforms.ToTensor()]
)

# NOTE(review): this reads the "dataset" directory from disk, which appears to
# be produced by the copy/unzip cells further down - on a fresh kernel those
# need to run first.
dataset = VideoFramesDataset(root_dir="dataset", transform=transform)

# One video per batch, in dataset order.
loader = DataLoader(dataset=dataset, batch_size=1, shuffle=False)


In [7]:
from google.colab import drive
# Mount Google Drive at /content/drive so the next cell can copy
# dataset.zip out of /content/drive/MyDrive.
drive.mount('/content/drive')

Mounted at /content/drive


In [8]:
# Copy the zipped dataset from the mounted Drive into /content/.
!cp /content/drive/MyDrive/dataset.zip /content/

In [9]:
# Extract the archive quietly (-q). -o overwrites existing files without
# asking, so re-running this cell does not stop at unzip's interactive
# "replace?" prompt. Presumably this produces the dataset/ directory that
# the dataset cell reads - verify against the archive contents.
!unzip -q -o dataset.zip

In [11]:
# Rebuild the network, load the fine-tuned weights, then drop the classifier
# so the model outputs the 2048-d vector that feeds the final fc layer.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = resnet50()
model.fc = nn.Linear(2048, 2)  # Match your training setting
# weights_only=True restricts unpickling to tensors and plain containers, so
# an uploaded checkpoint cannot execute arbitrary code while being loaded.
model.load_state_dict(
    torch.load(
        "resnet50_deepfake_frame.pth",
        map_location=device,
        weights_only=True,
    )
)
model.fc = nn.Identity()  # Remove classification head
model = model.to(device)
# eval() switches BatchNorm to its running statistics for inference.
model.eval()


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [17]:
# Run every video through the headless ResNet and collect one
# (1, T, 2048) block of L2-normalised frame features per video.
all_feats = []
all_labels = []
skipped = 0  # running count of videos dropped for any reason

num_videos = len(loader)
pbar = tqdm(enumerate(loader), total=num_videos, desc="Extracting features")

for i, (frames, label) in pbar:
    try:
        # Skip if no valid frames were returned
        if frames is None or frames.numel() == 0:
            print(f"Skipping video {i} due to no valid frames.")
            skipped += 1
            continue

        # The loader uses batch_size=1, so frames is (1, T, C, H, W); drop
        # the batch axis and treat the T frames as the model's batch.
        T, C, H, W = frames.shape[1:]
        x = frames.view(T, C, H, W).to(device)  # (T, C, H, W)

        with torch.no_grad():
            feats = model(x)                 # (T, 2048)
            feats = feats.view(1, T, -1)     # (1, T, 2048)
            feats = F.normalize(feats, p=2, dim=-1)

        # The loader has already batched the label into a tensor; flatten it
        # to shape (1,) instead of re-wrapping it in a new tensor.
        label_tensor = torch.as_tensor(label).reshape(-1).cpu()

        # Append both together so features and labels can never get out of
        # step if something above raised.
        all_feats.append(feats.cpu())
        all_labels.append(label_tensor)

    except Exception as e:
        print(f"⚠️ Skipping video {i}: {e}")
        skipped += 1

    finally:
        # Runs on success, on `continue` and after a caught error, so the
        # bar always reflects progress and the true number of skipped videos.
        postfix = {"% done": f"{(i+1)/num_videos*100:.1f}%"}
        if skipped:
            postfix["skipped"] = skipped
        pbar.set_postfix(postfix)

if skipped:
    print(f"Skipped {skipped} of {num_videos} videos.")

# Only concatenate if there are features
if all_feats:
    # torch.cat along dim 0 requires every video to have the same number of
    # frames T; videos of differing length will raise here.
    features_tensor = torch.cat(all_feats, dim=0)  # (N, T, 2048)
    labels_tensor   = torch.cat(all_labels, dim=0) # (N,)

    os.makedirs("features", exist_ok=True)
    torch.save(features_tensor, "features/all_features.pt")
    torch.save(labels_tensor,   "features/all_labels.pt")

    print(f"✅ Saved {features_tensor.shape[0]} sequences to features/")
else:
    print("No features extracted from any videos.")

Extracting features: 100%|██████████| 2000/2000 [04:44<00:00,  7.02it/s, % done=100.0%]


✅ Saved 2000 sequences to features/


In [18]:
from google.colab import files

# Trigger a browser download for each saved artifact: features first,
# then the matching labels.
for saved_path in ("features/all_features.pt", "features/all_labels.pt"):
    files.download(saved_path)


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>