In [None]:
import os
import random
import numpy as np
from PIL import Image
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, Dataset, Subset
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.models import resnet50, ResNet50_Weights
import warnings

# Silence every warning category for the rest of the run.
warnings.filterwarnings("ignore")

# ----------------------------
# Data roots, batch size, device and RNG seeding
# ----------------------------
train_dir, test_dir = "root1", "root2"
batch_size = 32

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Seed the Python, NumPy and PyTorch generators with the same fixed value.
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

# ----------------------------
# 1) Apply downsampling to balance data
# ----------------------------
def balance_dataset(folder):
    """Downsample every class under *folder* to the size of the smallest class.

    Each immediate sub-directory of *folder* is treated as one class, and the
    image files directly inside it (matched by extension, case-insensitively)
    are its samples.

    Args:
        folder: Path of a directory containing one sub-directory per class.

    Returns:
        A ``(files, labels, class_names)`` tuple. ``files`` holds the sampled
        image paths, ``labels`` the matching integer class indices, and
        ``class_names`` the class directory names in sorted order, so that
        ``class_names[i]`` is the class carrying label ``i``.

    Raises:
        ValueError: If *folder* contains no class sub-directories.
    """
    image_exts = (".jpg", ".png", ".jpeg", ".bmp", ".webp")

    # os.listdir() returns entries in arbitrary order. Sorting makes the
    # name -> index mapping identical for every directory that has the same
    # class folder names (e.g. a train root and a test root).
    class_folders = sorted(
        name for name in os.listdir(folder)
        if os.path.isdir(os.path.join(folder, name))
    )
    if not class_folders:
        raise ValueError(f"no class sub-directories found in {folder!r}")

    class_files = {}
    for cls in class_folders:
        cls_dir = os.path.join(folder, cls)
        # Sorted so that a fixed random seed always selects the same files.
        class_files[cls] = sorted(
            os.path.join(cls_dir, name) for name in os.listdir(cls_dir)
            if name.lower().endswith(image_exts)
        )

    min_count = min(len(files) for files in class_files.values())

    balanced_files = []
    balanced_labels = []
    for idx, cls in enumerate(class_folders):
        balanced_files.extend(random.sample(class_files[cls], min_count))
        balanced_labels.extend([idx] * min_count)

    return balanced_files, balanced_labels, class_folders


# ----------------------------
# 2) Balance train / test data
# ----------------------------
train_files, train_labels, class_names = balance_dataset(train_dir)
test_files, test_labels, test_class_names = balance_dataset(test_dir)

# Labels are positions in the class list returned for each directory, so the
# test labels only mean the same thing as the train labels when both lists
# contain the same names in the same order. Verify the names, then re-index
# the test labels onto the train ordering if the order differs.
if sorted(test_class_names) != sorted(class_names):
    raise ValueError(
        f"class folders differ between train {sorted(class_names)} "
        f"and test {sorted(test_class_names)}"
    )
if test_class_names != class_names:
    to_train_index = [class_names.index(name) for name in test_class_names]
    test_labels = [to_train_index[label] for label in test_labels]

print("Number of sample for each class of train dataset:", {name: train_labels.count(i) for i, name in enumerate(class_names)})
print("Number of sample for each class of test dataset:", {name: test_labels.count(i) for i, name in enumerate(class_names)})

# ----------------------------
# 3) Dataset class
# ----------------------------
class CustomDataset(Dataset):
    """Map-style dataset over parallel lists of image paths and labels.

    Args:
        file_paths: Sequence of image file paths.
        labels: Sequence of labels; ``labels[i]`` belongs to ``file_paths[i]``.
        transform: Optional callable applied to the RGB PIL image before it
            is returned.
    """

    def __init__(self, file_paths, labels, transform=None):
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of samples."""
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Load sample *idx* and return it as an ``(image, label)`` pair."""
        # Every source mode (palette, RGBA, grayscale, ...) is converted to
        # RGB. convert() returns a new in-memory image, so the file opened by
        # Image.open() can be closed as soon as the conversion is done.
        with Image.open(self.file_paths[idx]) as src:
            img = src.convert("RGB")

        label = self.labels[idx]
        if self.transform:
            img = self.transform(img)
        return img, label


# ----------------------------
# 4) Transform
# ----------------------------
# The model in this script starts from ImageNet-pretrained weights, so inputs
# are normalized with the ImageNet channel statistics those weights were
# trained with rather than a generic 0.5 / 0.5.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

# Training: random crop / flip / colour jitter for augmentation.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3),
    transforms.ToTensor(),
    transforms.Normalize(imagenet_mean, imagenet_std)
])

# Evaluation: deterministic resize only, same normalization as training.
test_transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize(imagenet_mean, imagenet_std)
])

# ----------------------------
# 5) Dataset & Dataloader
# ----------------------------
train_dataset = CustomDataset(train_files, train_labels, transform=train_transform)
test_dataset  = CustomDataset(test_files, test_labels, transform=test_transform)
# Same files and labels as train_dataset, but read through the deterministic
# evaluation transform; used only for the validation split below.
val_dataset   = CustomDataset(train_files, train_labels, transform=test_transform)

# ---- train / val (80:20) ---- #
total_train = len(train_dataset)
val_size = int(total_train * 0.2)
train_size = total_train - val_size

# random_split decides which indices go to which split. The validation
# indices are then re-pointed at val_dataset so that validation accuracy is
# not perturbed by the random crop / flip / colour jitter used for training.
train_subset, val_split = torch.utils.data.random_split(train_dataset, [train_size, val_size])
val_subset = Subset(val_dataset, val_split.indices)

train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
val_loader   = DataLoader(val_subset,   batch_size=batch_size, shuffle=False)
test_loader  = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

print("Balanced Train size:", len(train_subset))
print("Balanced Val size:", len(val_subset))
print("Balanced Test size:", len(test_dataset))


# ----------------------------
# 6) Model
# ----------------------------
# ImageNet-pretrained ResNet-50 whose final fully connected layer is replaced
# by a fresh head with one output per class folder found in the training data.
model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
model.fc = nn.Linear(model.fc.in_features, len(class_names))
model = model.to(device)

# ----------------------------
# 7) Optimizer & Scheduler
# ----------------------------
# Cross-entropy loss on the raw class scores produced by the model.
criterion = nn.CrossEntropyLoss()

# Adam over every model parameter (the whole network is fine-tuned).
optimizer = optim.Adam(params=model.parameters(), lr=0.0001)

# Multiply the learning rate by 0.1 after every 5 scheduler steps.
scheduler = optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=5, gamma=0.1)

# ----------------------------
# 8) Training / Validation function
# ----------------------------
def train_one_epoch(model, loader, optimizer, criterion):
    """Run one optimisation pass over *loader* and return the mean batch loss.

    Args:
        model: Network to train; switched to training mode here.
        loader: Iterable yielding ``(images, labels)`` batches.
        optimizer: Optimizer that updates the model's parameters.
        criterion: Loss callable taking ``(outputs, labels)``.

    Returns:
        Sum of the per-batch loss values divided by the number of batches.
    """
    model.train()

    batch_losses = []
    for batch_imgs, batch_labels in loader:
        # Move the batch to the module-level ``device``.
        batch_imgs = batch_imgs.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        batch_loss = criterion(model(batch_imgs), batch_labels)
        batch_loss.backward()
        optimizer.step()

        batch_losses.append(batch_loss.item())

    return sum(batch_losses) / len(loader)

def evaluate(model, loader):
    """Return the accuracy of *model* on *loader* as a percentage.

    Args:
        model: Network to evaluate; switched to evaluation mode here.
        loader: Iterable yielding ``(images, labels)`` batches.

    Returns:
        Correctly classified samples divided by total samples, times 100.
    """
    model.eval()

    n_correct = 0
    n_seen = 0
    # No gradients are needed for scoring.
    with torch.no_grad():
        for batch_imgs, batch_labels in loader:
            # Move the batch to the module-level ``device``.
            batch_imgs = batch_imgs.to(device)
            batch_labels = batch_labels.to(device)

            # Predicted class = index of the largest score along dim 1.
            predicted = model(batch_imgs).argmax(dim=1)
            n_correct += predicted.eq(batch_labels).sum().item()
            n_seen += batch_labels.size(0)

    return n_correct / n_seen * 100

# ----------------------------
# 9) Training Loop
# ----------------------------
# Single source of truth for the loop bound and the progress message, so the
# two cannot drift apart when the epoch count is changed.
num_epochs = 20

for epoch in range(num_epochs):
    loss = train_one_epoch(model, train_loader, optimizer, criterion)
    val_acc = evaluate(model, val_loader)
    # Advance the learning-rate schedule once per epoch, after the optimizer
    # updates performed inside train_one_epoch.
    scheduler.step()

    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss:.4f}, Val Acc: {val_acc:.2f}%")

# ----------------------------
# 10) Test Accuracy
# ----------------------------
# Scored with the weights as they stand after the final epoch.
test_acc = evaluate(model, test_loader)
print("Test Accuracy:", test_acc)


Number of sample for each class of train dataset: {'fake': 153, 'real': 153}
Number of sample for each class of test dataset: {'fake': 110, 'real': 110}
Balanced Train size: 245
Balanced Val size: 61
Balanced Test size: 220
Epoch 1/20, Loss: 0.6318, Val Acc: 78.69%
Epoch 2/20, Loss: 0.3797, Val Acc: 80.33%
Epoch 3/20, Loss: 0.2615, Val Acc: 83.61%
Epoch 4/20, Loss: 0.1706, Val Acc: 81.97%
Epoch 5/20, Loss: 0.1499, Val Acc: 83.61%
Epoch 6/20, Loss: 0.1267, Val Acc: 86.89%
Epoch 7/20, Loss: 0.0938, Val Acc: 85.25%
Epoch 8/20, Loss: 0.0974, Val Acc: 77.05%
Epoch 9/20, Loss: 0.1292, Val Acc: 83.61%
Epoch 10/20, Loss: 0.1036, Val Acc: 78.69%
Epoch 11/20, Loss: 0.0768, Val Acc: 83.61%
Epoch 12/20, Loss: 0.0785, Val Acc: 83.61%
Epoch 13/20, Loss: 0.0913, Val Acc: 75.41%
Epoch 14/20, Loss: 0.1129, Val Acc: 80.33%
Epoch 15/20, Loss: 0.1092, Val Acc: 86.89%
Epoch 16/20, Loss: 0.0976, Val Acc: 88.52%
Epoch 17/20, Loss: 0.0855, Val Acc: 86.89%
Epoch 18/20, Loss: 0.0641, Val Acc: 83.61%
Epoch 19/20