In [1]:

import glob
import os

import matplotlib.pyplot as plt
import numpy as np
import pretrainedmodels
import pretrainedmodels.utils as utils
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
from sklearn.metrics import roc_auc_score, roc_curve
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import ImageFolder
from tqdm import tqdm

ModuleNotFoundError: No module named 'PIL'

In [1]:
# ====================
# Dataset Definition
# ====================
class GenericFrameDataset(Dataset):
    """Frame-level dataset over ``root_dir/{real,fake}/<video_folder>/*.png``.

    Every PNG found inside a video folder becomes one sample. Frames under
    ``real`` get label 0 and frames under ``fake`` get label 1.

    Note: ``torchvision.datasets.ImageFolder`` numbers classes in sorted
    folder-name order (``fake`` -> 0, ``real`` -> 1), which is the opposite
    of the mapping used here.

    Args:
        root_dir: Directory that contains the ``real`` and ``fake`` folders.
        transform: Optional callable applied to each loaded RGB PIL image.
    """

    def __init__(self, root_dir, transform=None):
        self.samples = []  # list of (image path, integer label) pairs
        self.transform = transform

        for label, class_name in enumerate(['real', 'fake']):
            class_path = os.path.join(root_dir, class_name)

            # glob returns matches in arbitrary, filesystem-dependent order;
            # sorting keeps the index -> sample mapping reproducible.
            for folder in sorted(glob.glob(os.path.join(class_path, '*'))):
                for img_path in sorted(glob.glob(os.path.join(folder, '*.png'))):
                    self.samples.append((img_path, label))

    def __len__(self):
        """Return the number of indexed frames."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load the frame at ``idx`` and return ``(image, label)``.

        The image is converted to RGB and passed through ``transform`` when
        one was supplied; otherwise the PIL image is returned as is.
        """
        img_path, label = self.samples[idx]
        # The context manager releases the file handle once the pixel data
        # has been copied into the converted image.
        with Image.open(img_path) as img_file:
            img = img_file.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, label

NameError: name 'Dataset' is not defined

In [4]:

# ====================
# Model Setup
# ====================
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ImageNet-pretrained Xception backbone from the `pretrainedmodels` package.
xception = pretrainedmodels.xception(pretrained='imagenet')

# Replace the stock classifier with a single-output linear head
# (one logit per image for the binary real/fake task).
num_features = xception.last_linear.in_features
xception.last_linear = nn.Linear(num_features, 1)
xception = xception.to(device)

In [5]:
# ====================
# Transforms & Loader
# ====================
# Preprocessing for the ImageNet-pretrained Xception from `pretrainedmodels`:
# 299x299 input, ToTensor() scales pixels to [0, 1], then each channel is
# normalised with mean = std = 0.5. `transform` is defined exactly once on
# purpose: redefining it without Normalize would feed un-normalised inputs
# to the pretrained weights.
transform = transforms.Compose([
    transforms.Resize((299, 299)),  # Required for Xception
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5]*3, std=[0.5]*3),
])

# Data locations and loader settings; edit these to match your machine.
TRAIN_DIR = 'D:/EE656 project/train_split/train'
VAL_DIR = 'D:/EE656 project/train_split/val'
BATCH_SIZE = 32
NUM_WORKERS = 4

# Each split directory is expected to hold one subfolder per class
# ('real' and 'fake'). ImageFolder numbers classes in sorted folder-name
# order, so with those two names: fake -> 0, real -> 1.
train_dataset = ImageFolder(TRAIN_DIR, transform=transform)
val_dataset = ImageFolder(VAL_DIR, transform=transform)

# Shuffle only the training split; validation order stays fixed.
train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS
)
val_loader = DataLoader(
    val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS
)


In [6]:
# ====================
# Training
# ====================
# Binary objective on raw logits, optimised with Adam over all model weights.
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(xception.parameters(), lr=1e-4)
epochs = 5

for epoch in range(epochs):
    xception.train()
    batch_losses = []  # one scalar loss per mini-batch of this epoch

    progress = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")
    for imgs, labels in progress:
        imgs = imgs.to(device)
        # Integer class ids -> float column vector, matching the
        # single-output head that BCEWithLogitsLoss compares against.
        labels = labels.float().unsqueeze(1).to(device)

        optimizer.zero_grad()
        loss = criterion(xception(imgs), labels)
        loss.backward()
        optimizer.step()
        batch_losses.append(loss.item())

    # Report the mean of the per-batch losses for the epoch.
    running_loss = sum(batch_losses)
    print(f"Epoch {epoch+1}, Loss: {running_loss / len(train_loader):.4f}")

Epoch 1/5:   0%|          | 0/2347 [00:00<?, ?it/s]

Epoch 1/5:   1%|          | 14/2347 [03:21<9:18:30, 14.36s/it]


KeyboardInterrupt: 

In [8]:
# ====================
# Evaluation
# ====================
# Switch layers such as dropout / batch-norm to inference behaviour.
xception.eval()
y_true = []
y_pred = []

# No gradients are needed while scoring the validation split.
with torch.no_grad():
    for imgs, labels in tqdm(val_loader, desc="Evaluating"):
        logits = xception(imgs.to(device))
        # Sigmoid turns each logit into a probability for class index 1.
        probs = torch.sigmoid(logits).cpu().numpy().flatten()

        y_true.extend(labels.numpy())
        y_pred.extend(probs)

auc = roc_auc_score(y_true, y_pred)
print(f"AUC: {auc:.4f}")

# ROC curve over the same scores, annotated with the AUC value.
fpr, tpr, _ = roc_curve(y_true, y_pred)
fig, ax = plt.subplots()
ax.plot(fpr, tpr, label=f'AUC = {auc:.2f}')
ax.set_xlabel('False Positive Rate')
ax.set_ylabel('True Positive Rate')
ax.set_title('ROC Curve - Xception')
ax.legend()
ax.grid()
plt.show()

Evaluating:   0%|          | 0/587 [00:07<?, ?it/s]


KeyboardInterrupt: 