# Real vs Fake Image Detector - Using ProGAN Dataset

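This notebook trains a ResNet-50 binary classifier to distinguish real images from ProGAN-generated fakes, using the ProGAN dataset organized into per-category `0_real` / `1_fake` folders (structure shown below).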

**Dataset Structure:**
```
/scratch/rzc7ew/progan/
├── airplane/0_real/  (real images)
├── airplane/1_fake/  (ProGAN-generated fakes)
├── bicycle/0_real/
├── bicycle/1_fake/
└── ... (all categories)
```

**Training:**
- **Real (Class 0)**: All images from `*/0_real/` folders
- **Fake (Class 1)**: All images from `*/1_fake/` folders

## Step 1: Configuration

In [None]:
# Configuration: dataset location and training hyperparameters.
# DATASET_ROOT may point at another generator's folder (biggan, stylegan, ...)
# provided it uses the same <category>/0_real, <category>/1_fake layout.
DATASET_ROOT = '/scratch/rzc7ew/progan'

# Optimisation settings
BATCH_SIZE = 32
NUM_WORKERS = 4
LEARNING_RATE = 0.0001
NUM_EPOCHS = 20
PATIENCE = 5  # epochs without val-accuracy improvement before early stopping

# Per-sample probabilities for the training-time blur / JPEG augmentations
BLUR_PROB = 0.5
JPEG_PROB = 0.5

print("Configuration:")
for _label, _value in (
    ("Dataset", DATASET_ROOT),
    ("Batch Size", BATCH_SIZE),
    ("Epochs", NUM_EPOCHS),
    ("Learning Rate", LEARNING_RATE),
):
    print(f"  {_label}: {_value}")

## Step 2: Import Libraries

In [None]:
import os
import sys
import numpy as np
import random
from io import BytesIO
from PIL import Image
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models

from sklearn.metrics import (
    accuracy_score, precision_recall_curve, average_precision_score,
    roc_auc_score, roc_curve, classification_report, confusion_matrix
)
from tqdm.auto import tqdm
import seaborn as sns

# Seed every RNG the notebook draws from (torch, NumPy, stdlib random).
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(42)

# Use CUDA when it is available, otherwise fall back to the CPU.
_cuda_ok = torch.cuda.is_available()
device = torch.device('cuda' if _cuda_ok else 'cpu')
print(f"Using device: {device}")
if _cuda_ok:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## Step 3: Verify Dataset Structure

In [None]:
# Sanity-check the dataset root and preview the category folders it contains.
# isdir (not exists): a regular file at DATASET_ROOT would make os.listdir raise.
if not os.path.isdir(DATASET_ROOT):
    print(f"❌ ERROR: Dataset not found at: {DATASET_ROOT}")
else:
    print(f"✓ Dataset directory found: {DATASET_ROOT}")
    print("\nCategories available:")
    # Sorted so the preview is stable; os.listdir order is filesystem-dependent.
    categories = sorted(
        d for d in os.listdir(DATASET_ROOT)
        if os.path.isdir(os.path.join(DATASET_ROOT, d))
    )
    for i, cat in enumerate(categories[:10], 1):
        print(f"  {i}. {cat}")
    if len(categories) > 10:
        print(f"  ... and {len(categories) - 10} more categories")
    print(f"\nTotal categories: {len(categories)}")

## Step 4: Data Augmentation

In [None]:
class DataAugmentation:
    """Blur / JPEG augmentation in the style of the CNNDetection paper.

    All methods are static and take PIL images. The blur sigma and JPEG
    quality are drawn with NumPy's global RNG; whether each augmentation is
    applied at all is decided with the stdlib ``random`` module.
    """

    # Image modes Pillow's JPEG encoder can write; anything else (RGBA, P,
    # LA, ...) is converted to RGB first instead of raising on save.
    _JPEG_MODES = frozenset({'1', 'L', 'RGB', 'RGBX', 'CMYK', 'YCbCr'})

    @staticmethod
    def apply_gaussian_blur(img, sigma_range=(0.0, 3.0)):
        """Blur ``img`` with a Gaussian whose sigma is uniform in ``sigma_range``.

        Args:
            img: PIL image (anything ``np.array`` can turn into a 2-D or
                3-D pixel array).
            sigma_range: ``(low, high)`` bounds for the sampled sigma.

        Returns:
            A new blurred PIL image, or ``img`` itself when sigma is 0.
        """
        sigma = np.random.uniform(sigma_range[0], sigma_range[1])
        if sigma <= 0:
            return img
        img_array = np.array(img)
        if img_array.ndim == 2:
            # Single-channel image (e.g. mode "L"): blur the whole plane.
            img_array = gaussian_filter(img_array, sigma=sigma)
        else:
            # Blur the first (up to) three channels independently, so an
            # RGBA alpha channel is left untouched and channels never mix.
            for channel in range(min(3, img_array.shape[2])):
                img_array[:, :, channel] = gaussian_filter(
                    img_array[:, :, channel], sigma=sigma
                )
        return Image.fromarray(img_array)

    @staticmethod
    def apply_jpeg_compression(img, quality_range=(30, 100)):
        """Round-trip ``img`` through JPEG at a quality drawn from ``quality_range``.

        Both ends of ``quality_range`` are inclusive. Returns a fully decoded
        PIL image.
        """
        quality = np.random.randint(quality_range[0], quality_range[1] + 1)
        if img.mode not in DataAugmentation._JPEG_MODES:
            img = img.convert('RGB')
        buffer = BytesIO()
        img.save(buffer, format='JPEG', quality=quality)
        buffer.seek(0)
        compressed = Image.open(buffer)
        # Image.open is lazy; decode now rather than on first pixel access.
        compressed.load()
        return compressed

    @staticmethod
    def augment(img, blur_prob=0.5, jpeg_prob=0.5):
        """Apply blur with probability ``blur_prob``, then JPEG with ``jpeg_prob``."""
        if random.random() < blur_prob:
            img = DataAugmentation.apply_gaussian_blur(img)
        if random.random() < jpeg_prob:
            img = DataAugmentation.apply_jpeg_compression(img)
        return img

print("✓ Data augmentation defined")

## Step 5: Dataset Class

In [None]:
class RealFakeDataset(Dataset):
    """Map-style dataset of image paths with binary labels (0 = real, 1 = fake).

    Args:
        image_paths: Sequence of image file paths.
        labels: Sequence of integer labels aligned index-by-index with
            ``image_paths``.
        transform: Optional callable applied last to the PIL image.
        augment: When True, ``DataAugmentation.augment`` (blur/JPEG) runs
            before ``transform``.
        blur_prob: Blur probability forwarded to ``DataAugmentation.augment``.
        jpeg_prob: JPEG probability forwarded to ``DataAugmentation.augment``.

    Raises:
        ValueError: If ``image_paths`` and ``labels`` differ in length.
    """

    def __init__(self, image_paths, labels, transform=None, augment=True,
                 blur_prob=0.5, jpeg_prob=0.5):
        if len(image_paths) != len(labels):
            raise ValueError(
                f"image_paths ({len(image_paths)}) and labels ({len(labels)}) "
                "must have the same length"
            )
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform
        self.augment = augment
        self.blur_prob = blur_prob
        self.jpeg_prob = jpeg_prob

        print(f"Dataset: {len(self.image_paths)} images")
        print(f"  - Real: {sum(1 for lbl in self.labels if lbl == 0)}")
        print(f"  - Fake: {sum(1 for lbl in self.labels if lbl == 1)}")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for ``idx``.

        Unreadable files are replaced by a black 224x224 RGB image (best
        effort, so one bad file does not abort an epoch). The placeholder
        still carries the original label.
        """
        img_path = self.image_paths[idx]
        label = self.labels[idx]

        try:
            # The context manager closes the file handle deterministically;
            # convert() returns an independent, fully loaded copy.
            with Image.open(img_path) as src:
                img = src.convert('RGB')
        except Exception as e:
            print(f"Error loading {img_path}: {e}")
            img = Image.new('RGB', (224, 224))

        if self.augment:
            img = DataAugmentation.augment(img, self.blur_prob, self.jpeg_prob)

        if self.transform:
            img = self.transform(img)

        return img, label

print("✓ Dataset class defined")

## Step 6: Load Real and Fake Images

In [None]:
def _list_images(folder):
    """Return sorted paths of image files directly inside ``folder`` ([] if absent)."""
    if not os.path.isdir(folder):
        return []
    # Sorted so results do not depend on the filesystem's listing order.
    return [
        os.path.join(folder, name)
        for name in sorted(os.listdir(folder))
        if name.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp'))
    ]


def load_real_fake_data(dataset_root, val_fraction=0.2):
    """Collect, balance, shuffle and split real (0_real) / fake (1_fake) images.

    Expects ``dataset_root/<category>/0_real`` and
    ``dataset_root/<category>/1_fake`` folders. The larger class is randomly
    subsampled to the size of the smaller one. The combined list is shuffled
    with the global ``random`` module, and its first ``1 - val_fraction``
    becomes the training split. All directory listings are sorted, so a
    seeded ``random`` gives a reproducible split.

    Args:
        dataset_root: Directory containing one sub-folder per category.
        val_fraction: Fraction of the balanced data held out for validation,
            in ``[0, 1)``.

    Returns:
        ``(train_images, train_labels, val_images, val_labels)`` as lists
        (label 0 = real, 1 = fake), or ``(None, None, None, None)`` if
        ``dataset_root`` is not a directory or either class has no images.

    Raises:
        ValueError: If ``val_fraction`` is outside ``[0, 1)``.
    """
    if not 0.0 <= val_fraction < 1.0:
        raise ValueError(f"val_fraction must be in [0, 1), got {val_fraction}")

    print("="*70)
    print("Loading Real and Fake Images")
    print("="*70)

    if not os.path.isdir(dataset_root):
        print(f"\n❌ ERROR: Dataset not found at: {dataset_root}")
        return None, None, None, None

    categories = sorted(
        d for d in os.listdir(dataset_root)
        if os.path.isdir(os.path.join(dataset_root, d))
    )

    print(f"\nScanning {len(categories)} categories...")

    real_images = []
    fake_images = []
    for category in tqdm(categories, desc="Loading categories"):
        category_path = os.path.join(dataset_root, category)
        real_images.extend(_list_images(os.path.join(category_path, '0_real')))
        fake_images.extend(_list_images(os.path.join(category_path, '1_fake')))

    print(f"\n✓ Found {len(real_images)} real images")
    print(f"✓ Found {len(fake_images)} fake images")

    if not real_images or not fake_images:
        print("\n❌ ERROR: Not enough images!")
        return None, None, None, None

    # Balance classes by subsampling the larger one down to the smaller size.
    min_count = min(len(real_images), len(fake_images))
    print(f"\nBalancing to {min_count} images per class")
    real_images = random.sample(real_images, min_count)
    fake_images = random.sample(fake_images, min_count)

    # Pair each path with its label (0 = real, 1 = fake) and shuffle jointly.
    combined = [(path, 0) for path in real_images] + [(path, 1) for path in fake_images]
    random.shuffle(combined)
    all_images = [path for path, _ in combined]
    all_labels = [lbl for _, lbl in combined]

    split_idx = int((1 - val_fraction) * len(all_images))
    train_images, val_images = all_images[:split_idx], all_images[split_idx:]
    train_labels, val_labels = all_labels[:split_idx], all_labels[split_idx:]

    print("\nFinal split:")
    print(f"  Training: {len(train_images)} images")
    print(f"    - Real: {sum(1 for lbl in train_labels if lbl == 0)}")
    print(f"    - Fake: {sum(1 for lbl in train_labels if lbl == 1)}")
    print(f"  Validation: {len(val_images)} images")
    print(f"    - Real: {sum(1 for lbl in val_labels if lbl == 0)}")
    print(f"    - Fake: {sum(1 for lbl in val_labels if lbl == 1)}")

    return train_images, train_labels, val_images, val_labels

# Load data. load_real_fake_data returns a tuple of Nones on failure; stop
# here with a clear message instead of failing later on len(None).
train_images, train_labels, val_images, val_labels = load_real_fake_data(DATASET_ROOT)
if train_images is None:
    raise RuntimeError(
        f"No usable real/fake images found under {DATASET_ROOT}; "
        "expected <category>/0_real and <category>/1_fake folders."
    )

## Step 7: Define Transforms

In [None]:
# Shared tail: tensor conversion + normalisation with ImageNet channel
# statistics (matching the ImageNet-pretrained ResNet-50 backbone).
_to_normalized_tensor = [
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]

# Training: shorter side -> 256, random 224 crop, random horizontal flip.
train_transform = transforms.Compose(
    [transforms.Resize(256), transforms.RandomCrop(224), transforms.RandomHorizontalFlip()]
    + _to_normalized_tensor
)

# Validation: deterministic 224 center crop, no flipping.
val_transform = transforms.Compose(
    [transforms.Resize(256), transforms.CenterCrop(224)]
    + _to_normalized_tensor
)

print("✓ Transforms defined")

## Step 8: Create Datasets and DataLoaders

In [None]:
# Training split gets blur/JPEG augmentation; validation images are left
# unaugmented so validation metrics reflect the stored images.
train_dataset = RealFakeDataset(
    train_images, train_labels,
    transform=train_transform,
    augment=True,
    blur_prob=BLUR_PROB,
    jpeg_prob=JPEG_PROB
)

val_dataset = RealFakeDataset(
    val_images, val_labels,
    transform=val_transform,
    augment=False
)

# Pinned (page-locked) host memory only speeds up host->GPU copies; on a
# CPU-only run it brings no benefit and PyTorch warns, so enable it for CUDA only.
use_pinned_memory = device.type == 'cuda'

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=use_pinned_memory
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=use_pinned_memory
)

print(f"\n✓ DataLoaders:")
print(f"  - Training batches: {len(train_loader)}")
print(f"  - Validation batches: {len(val_loader)}")

## Step 9: Model Architecture

In [None]:
class FakeImageDetector(nn.Module):
    """ResNet-50 backbone with a single-logit head (real = 0, fake = 1).

    ``forward`` returns raw logits with one output per sample (the final
    ``nn.Linear`` has ``out_features=1``). Apply a sigmoid, or train with
    ``BCEWithLogitsLoss``, to get probabilities.

    Args:
        pretrained: Load ImageNet weights for the backbone when True.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        self.resnet = self._build_backbone(pretrained)
        num_features = self.resnet.fc.in_features
        # Replace the 1000-way ImageNet classifier with a single logit.
        self.resnet.fc = nn.Linear(num_features, 1)

    @staticmethod
    def _build_backbone(pretrained):
        """Create ResNet-50, preferring torchvision's ``weights=`` API.

        ``pretrained=`` is deprecated in torchvision >= 0.13. Older releases
        lack ``ResNet50_Weights`` and still need the legacy flag.
        """
        weights_enum = getattr(models, 'ResNet50_Weights', None)
        if weights_enum is None:
            return models.resnet50(pretrained=pretrained)
        # IMAGENET1K_V1 is what the legacy pretrained=True resolved to.
        return models.resnet50(weights=weights_enum.IMAGENET1K_V1 if pretrained else None)

    def forward(self, x):
        return self.resnet(x)

# Build the detector with pretrained backbone weights and move it to `device`.
model = FakeImageDetector(pretrained=True)
model = model.to(device)
_param_count = sum(tensor.numel() for tensor in model.parameters())
print(f"✓ Model: {_param_count:,} parameters")

## Step 10: Training Setup

In [None]:
# Single-logit binary cross-entropy (the sigmoid is folded into the loss).
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.999))

# Scheduler is stepped with validation accuracy (mode='max'); when it
# plateaus (patience=3 epochs) the learning rate is multiplied by 0.1.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', factor=0.1, patience=3
)

# Per-epoch metric curves, appended to by the training loop.
history = {
    key: []
    for key in ('train_loss', 'val_loss', 'train_acc', 'val_acc', 'val_ap')
}

print("✓ Training setup complete")

## Step 11: Training Functions

In [None]:
def train_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Network returning one logit per sample.
        loader: Iterable of ``(images, labels)`` batches with labels in {0, 1}.
        criterion: Loss called as ``criterion(logits, float_labels)``.
            Assumed to average over the batch (``reduction='mean'``, the
            default for ``BCEWithLogitsLoss``).
        optimizer: Optimizer over ``model``'s parameters.
        device: Device each batch is moved to.

    Returns:
        ``(mean_loss, accuracy)``. ``mean_loss`` weights each batch loss by
        its batch size, so a short final batch is not over-weighted.
        ``accuracy`` thresholds the sigmoid probability at 0.5.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    model.train()
    loss_sum = 0.0
    n_samples = 0
    all_preds = []
    all_labels = []

    pbar = tqdm(loader, desc='Training')
    for images, labels in pbar:
        images = images.to(device)
        labels = labels.float().to(device)

        optimizer.zero_grad()
        outputs = model(images).view(-1)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        batch_size = labels.size(0)
        loss_sum += batch_loss * batch_size
        n_samples += batch_size

        preds = (torch.sigmoid(outputs) > 0.5).float()
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

        pbar.set_postfix({'loss': batch_loss})

    if n_samples == 0:
        raise ValueError("train_epoch() received a loader with no batches")

    epoch_loss = loss_sum / n_samples
    epoch_acc = accuracy_score(all_labels, all_preds)

    return epoch_loss, epoch_acc

def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` with gradients disabled.

    Args:
        model: Network returning one logit per sample.
        loader: Iterable of ``(images, labels)`` batches with labels in {0, 1}.
        criterion: Loss called as ``criterion(logits, float_labels)``.
            Assumed to average over the batch (``reduction='mean'``).
        device: Device each batch is moved to.

    Returns:
        ``(mean_loss, accuracy, average_precision, probs, labels)``:
        - ``mean_loss`` weights each batch by its size.
        - ``accuracy`` thresholds the sigmoid probability at 0.5.
        - ``probs`` and ``labels`` are per-sample lists in loader order.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    model.eval()
    loss_sum = 0.0
    n_samples = 0
    all_preds = []
    all_probs = []
    all_labels = []

    with torch.no_grad():
        for images, labels in tqdm(loader, desc='Validation'):
            images = images.to(device)
            labels = labels.float().to(device)

            outputs = model(images).view(-1)
            loss = criterion(outputs, labels)

            batch_size = labels.size(0)
            loss_sum += loss.item() * batch_size
            n_samples += batch_size

            probs = torch.sigmoid(outputs)
            preds = (probs > 0.5).float()

            all_probs.extend(probs.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    if n_samples == 0:
        raise ValueError("validate() received a loader with no batches")

    epoch_loss = loss_sum / n_samples
    epoch_acc = accuracy_score(all_labels, all_preds)
    epoch_ap = average_precision_score(all_labels, all_probs)

    return epoch_loss, epoch_acc, epoch_ap, all_probs, all_labels

print("✓ Training functions defined")

## Step 12: Train the Model

In [None]:
# Train with early stopping on validation accuracy. The best epoch is
# checkpointed to disk and its weights are restored into `model` at the end,
# so later cells that use `model` see the best epoch, not the last one.
best_val_acc = 0.0
best_state = None
patience_counter = 0
checkpoint_path = 'best_model_real_vs_fake.pth'

print("="*70)
print("Starting Training: Real vs Fake Detection")
print("="*70)

for epoch in range(NUM_EPOCHS):
    print(f"\nEpoch {epoch + 1}/{NUM_EPOCHS}")
    print("-" * 70)

    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc, val_ap, _, _ = validate(model, val_loader, criterion, device)

    # The scheduler was created with mode='max', so it tracks validation accuracy.
    scheduler.step(val_acc)

    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['train_acc'].append(train_acc)
    history['val_acc'].append(val_acc)
    history['val_ap'].append(val_ap)

    print("\nResults:")
    print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
    print(f"  Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f} | Val AP: {val_ap:.4f}")

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        # In-memory CPU snapshot so the best weights can be restored after
        # training without re-reading the checkpoint file.
        best_state = {
            name: tensor.detach().cpu().clone()
            for name, tensor in model.state_dict().items()
        }
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            # Plain floats: NumPy scalars may be rejected by
            # torch.load(weights_only=True).
            'val_acc': float(val_acc),
            'val_ap': float(val_ap)
        }, checkpoint_path)
        print(f"  >>> Best model saved! (Val Acc: {val_acc:.4f})")
        patience_counter = 0
    else:
        patience_counter += 1

    if patience_counter >= PATIENCE:
        print(f"\nEarly stopping after {epoch + 1} epochs")
        break

if best_state is not None:
    model.load_state_dict(best_state)
    print(f"Restored best-epoch weights (Val Acc: {best_val_acc:.4f})")

print("\n" + "="*70)
print("Training Complete!")
print(f"Best validation accuracy: {best_val_acc:.4f}")
print("="*70)

## Steps 13-16: Evaluation, Visualization, Confusion Matrix

*Use the same evaluation code from previous notebooks*

**Expected Results:**
- Accuracy: **80-95%** (real images vs GAN-generated)
- Much better than synthetic approach
- Proper real vs fake detection