In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from torchvision.datasets import ImageFolder
import timm
from sklearn.metrics import accuracy_score, roc_auc_score, classification_report
import numpy as np
from tqdm import tqdm
import os

In [None]:
import random
from pathlib import Path
import shutil

def split_train_to_val(source_dir, val_ratio=0.2, seed=42,
                       classes=("0_real", "1_fake"), val_dir=None):
    """Move a random fraction of each class's training images into a validation folder.

    Files are physically moved, so the function is idempotent per class: a class
    whose validation folder already contains images is skipped. Re-running the
    notebook (Restart & Run All) therefore does not keep draining the training set.

    Args:
        source_dir: Training root containing one sub-folder per class.
        val_ratio: Fraction of each class's images to move, in [0, 1). At least one
            image is moved per class, but never all of them.
        seed: Seed for a private RNG. The candidate file list is sorted before
            sampling, so the same seed selects the same files on any filesystem.
        classes: Names of the class sub-folders to split.
        val_dir: Destination root. Defaults to a ``val`` folder next to ``source_dir``.

    Raises:
        ValueError: If ``val_ratio`` is outside [0, 1).
    """
    if not 0 <= val_ratio < 1:
        raise ValueError(f"val_ratio must be in [0, 1), got {val_ratio}")

    source_dir = Path(source_dir)
    val_dir = source_dir.parent / "val" if val_dir is None else Path(val_dir)

    # Private RNG: reproducible without clobbering the global `random` state.
    rng = random.Random(seed)
    image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'}

    def _images_in(folder):
        # Sorted so sampling does not depend on the filesystem's iterdir() order.
        return sorted(f for f in folder.iterdir()
                      if f.is_file() and f.suffix.lower() in image_extensions)

    for class_name in classes:
        train_class_dir = source_dir / class_name
        val_class_dir = val_dir / class_name

        if not train_class_dir.exists():
            print(f"Warning: {train_class_dir} does not exist. Skipping.")
            continue

        # Already split on a previous run: moving more files would shrink train again.
        if val_class_dir.is_dir() and _images_in(val_class_dir):
            print(f"{val_class_dir} already contains images. Skipping to avoid re-splitting.")
            continue

        files = _images_in(train_class_dir)
        if not files:
            print(f"No images found in {train_class_dir}")
            continue

        # At least 1 image goes to validation, but always leave 1 in training.
        # An empty class folder makes ImageFolder fail or drop the class.
        num_to_move = min(max(1, int(len(files) * val_ratio)), len(files) - 1)
        if num_to_move == 0:
            print(f"Only {len(files)} image in {train_class_dir}; leaving it in training.")
            continue

        val_class_dir.mkdir(parents=True, exist_ok=True)
        print(f"Moving {num_to_move}/{len(files)} images from {class_name} to validation")

        for file_path in rng.sample(files, num_to_move):
            shutil.move(str(file_path), str(val_class_dir / file_path.name))

    print(f"\nDone! Validation set created at: {val_dir}")

In [None]:
# Carve a 25% validation split out of data/train (moves files into data/val).
split_train_to_val(source_dir='data/train', val_ratio=0.25, seed=42)

In [None]:
from torchvision.transforms import v2 as T

# Shared preprocessing constants, so the train and val pipelines cannot drift apart.
# NOTE(review): the checkpoint created in the model cell ('..._in1k_512') was
# fine-tuned at 512x512. Confirm that 384 here is intentional.
IMG_SIZE = 384
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: geometric and photometric augmentation on the PIL image,
# then tensor conversion, normalization, and random erasing on the float tensor.
train_transform = T.Compose([
    T.Resize((IMG_SIZE, IMG_SIZE)),
    T.RandomHorizontalFlip(p=0.5),
    T.RandomVerticalFlip(p=0.5),
    T.RandomRotation(15),
    T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
    T.ToImage(),                           # PIL -> image tensor (no value scaling)
    T.ToDtype(torch.float32, scale=True),  # uint8 [0, 255] -> float32 [0, 1]
    T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    # Applied after Normalize; value='random' fills erased patches with random values.
    T.RandomErasing(p=0.5, scale=(0.02, 0.2), ratio=(0.3, 3.3), value='random'),
])

# Validation pipeline: same resize and normalization, no augmentation.
val_transform = T.Compose([
    T.Resize((IMG_SIZE, IMG_SIZE)),
    T.ToImage(),
    T.ToDtype(torch.float32, scale=True),
    T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

In [None]:
train_dataset = ImageFolder('data/train', transform=train_transform)
val_dataset   = ImageFolder('data/val',   transform=val_transform)

# ImageFolder derives labels from the class folders it finds in each split separately.
# If one split were missing a class folder, the same index would mean different
# classes in train and val, so fail loudly instead of silently mislabeling.
assert train_dataset.class_to_idx == val_dataset.class_to_idx, (
    f"class mismatch: train={train_dataset.class_to_idx} val={val_dataset.class_to_idx}"
)

# Downstream cells assume a binary task: softmax[:, 1] is the "fake" score, the loss
# has 2 weights, and the report uses ['Real', 'Fake']. With the 0_real/1_fake folder
# names, ImageFolder's sorted class order gives index 1 to fake.
num_classes = len(train_dataset.classes)
assert num_classes == 2, f"expected 2 classes, found {train_dataset.classes}"

# Don't request more loader workers than the machine has CPUs.
num_workers = min(8, os.cpu_count() or 1)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True,
                          num_workers=num_workers, pin_memory=True,
                          generator=torch.Generator().manual_seed(42))  # reproducible shuffle order
val_loader   = DataLoader(val_dataset,   batch_size=64, shuffle=False,
                          num_workers=num_workers, pin_memory=True)

In [None]:
import timm

# ConvNeXt-V2 Huge, FCMAE-pretrained, then fine-tuned on ImageNet-22k and ImageNet-1k
# at 512x512. The classifier head is re-initialized with `num_classes` outputs.
# Requires a timm release that ships ConvNeXt-V2 weights (>= 0.9.2 per earlier notes
# — TODO confirm).
# NOTE(review): the transforms feed 384x384 inputs, not the 512 this tag was tuned at.
# timm may also publish a 384 fine-tune of this model; list the available tags with
# timm.list_pretrained('convnextv2_huge*').
# Alternatives for the tag:
#   'convnextv2_huge.fcmae' - presumably self-supervised FCMAE weights only (no supervised fine-tune)
#   'convnextv2_huge'       - whichever pretrained tag timm treats as the default
model = timm.create_model('convnextv2_huge.fcmae_ft_in22k_in1k_512', pretrained=True,
                          num_classes=num_classes)

In [None]:
# Prefer the GPU when one is visible; otherwise fall back to CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = model.to(device=device)

In [None]:
# Loss scaling is only meaningful for CUDA fp16 autocast, so enable it explicitly on
# GPU and turn it off on CPU instead of relying on the "CUDA not available" warning path.
_amp_on_cuda = device.type == 'cuda'
if hasattr(torch.amp, 'GradScaler'):
    # Newer torch: device-generic scaler; torch.cuda.amp.GradScaler is deprecated there.
    scaler = torch.amp.GradScaler('cuda', enabled=_amp_on_cuda)
else:
    scaler = torch.cuda.amp.GradScaler(enabled=_amp_on_cuda)

In [None]:
# Class-weighted cross-entropy: mistakes on index 1 (fake) cost 1.5x those on index 0 (real).
# Alternative without weights: nn.CrossEntropyLoss(label_smoothing=0.1)
class_weighting = torch.tensor([1.0, 1.5], device=device)
criterion = nn.CrossEntropyLoss(weight=class_weighting)
del class_weighting

In [None]:
num_epochs = 30
max_lr = 1e-4
total_steps = len(train_loader) * num_epochs  # the scheduler is stepped once per batch

# OneCycleLR owns the learning rate. At construction it overwrites the optimizer's lr
# with max_lr / div_factor (default 25, so 4e-6 here). It warms up to max_lr over the
# first pct_start (10%) of steps, then cosine-anneals down. The lr given to AdamW is
# therefore never used as-is; passing max_lr keeps a single source of truth.
# With the default cycle_momentum=True it also cycles AdamW's beta1 (0.95 <-> 0.85).
optimizer = optim.AdamW(model.parameters(), lr=max_lr, weight_decay=0.05)
scheduler = optim.lr_scheduler.OneCycleLR(optimizer,
                                          max_lr=max_lr,
                                          total_steps=total_steps,
                                          pct_start=0.1,
                                          anneal_strategy='cos')

In [None]:
# Train for num_epochs; after each epoch evaluate on val and checkpoint on best AUC.
best_auc = 0.0
# Mixed precision (autocast + GradScaler) only on GPU; plain fp32 on CPU.
amp_enabled = device.type == 'cuda'

for epoch in range(num_epochs):
    # ---- training ----
    model.train()
    running_loss = 0.0

    for inputs, labels in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}'):
        # The loaders use pin_memory=True, so these host->device copies can be async.
        inputs = inputs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)

        # torch.autocast replaces the deprecated torch.cuda.amp.autocast. It is
        # explicitly disabled on CPU rather than warning and disabling itself.
        with torch.autocast(device_type=device.type, enabled=amp_enabled):
            outputs = model(inputs)
            loss = criterion(outputs, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()  # per-batch: OneCycleLR was sized with len(train_loader) * num_epochs

        running_loss += loss.item()

    train_loss = running_loss / max(1, len(train_loader))

    # ---- validation (fp32, no gradients) ----
    model.eval()
    all_preds = []
    all_labels = []
    all_probs = []

    with torch.no_grad():
        for inputs, labels in val_loader:
            outputs = model(inputs.to(device, non_blocking=True))
            all_probs.extend(torch.softmax(outputs, dim=1)[:, 1].cpu().numpy())  # P(class index 1)
            all_preds.extend(outputs.argmax(dim=1).cpu().numpy())
            all_labels.extend(labels.numpy())

    auc = roc_auc_score(all_labels, all_probs)
    acc = accuracy_score(all_labels, all_preds)

    print(f"Epoch {epoch+1} - Train Loss: {train_loss:.5f} - Val AUC: {auc:.5f} - Val Acc: {acc:.5f}")

    # Keep only the checkpoint with the highest validation AUC.
    if auc > best_auc:
        best_auc = auc
        torch.save(model.state_dict(), 'best_aigc_detector.pth')
        print(f"New best model saved! AUC: {auc:.5f}")

In [None]:
# The all_labels / all_preds left by the training loop belong to the LAST epoch, which
# need not be the best-AUC checkpoint. Reload the saved best weights and re-evaluate,
# so this report describes the model that is actually kept on disk.
model.load_state_dict(torch.load('best_aigc_detector.pth', map_location=device, weights_only=True))
model.eval()

best_labels, best_preds, best_probs = [], [], []
with torch.no_grad():
    for inputs, labels in val_loader:
        outputs = model(inputs.to(device))
        best_probs.extend(torch.softmax(outputs, dim=1)[:, 1].cpu().numpy())
        best_preds.extend(outputs.argmax(dim=1).cpu().numpy())
        best_labels.extend(labels.numpy())

print(classification_report(best_labels, best_preds, target_names=['Real', 'Fake']))
print(f"Final Best Validation AUC: {best_auc:.5f}")
print(f"Re-evaluated AUC of reloaded checkpoint: {roc_auc_score(best_labels, best_probs):.5f}")