In [None]:
import os
import torch
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image

# Which generator's samples to classify against the real images.
# SELECT ONE: "VAE", "VQGAN", "WGAN"  (must be a key of `fake_dirs` below)
selected_method = "VAE"

# Paths
real_dir = "../celebA/celeba/real_10000/img_align_celeba"
# Generated-image directory per method. Only the entry for `selected_method`
# is read, so the other paths need not exist. Trailing commas keep the literal
# valid when entries are added or removed.
fake_dirs = {
    "VAE": "../vae_outputs/generated",
    "VQGAN": "../vqgan_outputs/generated",
    "WGAN": "../gan_outputs/generated",
}

class RealFakeDataset(Dataset):
    """Binary real-vs-fake image dataset built from two flat directories.

    Indices ``[0, len(real_imgs))`` map to files from ``real_dir`` (label 1);
    the remaining indices map to files from ``fake_dir`` (label 0). Each item
    is ``(tensor, label)``, where the tensor is ``transforms.ToTensor()``
    applied to the image converted to RGB and resized to ``image_size``.

    Args:
        real_dir: Directory containing real images.
        fake_dir: Directory containing generated images.
        image_size: Size tuple passed to ``PIL.Image.Image.resize``
            (PIL interprets it as ``(width, height)``).
        max_per_class: Keep at most this many files (in sorted filename
            order) from each directory; ``None`` keeps all of them.
    """

    # Mirrors torchvision's default ImageFolder extensions. Other directory
    # entries (sub-directories such as .ipynb_checkpoints, .DS_Store, text
    # files) would make Image.open fail inside __getitem__.
    IMG_EXTENSIONS = (
        ".jpg", ".jpeg", ".png", ".ppm", ".bmp",
        ".pgm", ".tif", ".tiff", ".webp",
    )

    def __init__(self, real_dir, fake_dir, image_size=(128, 128), max_per_class=10000):
        self.real_dir = real_dir
        self.fake_dir = fake_dir
        self.image_size = image_size
        self.real_imgs = self._list_images(real_dir, max_per_class)
        self.fake_imgs = self._list_images(fake_dir, max_per_class)

    @classmethod
    def _list_images(cls, directory, limit):
        """Return up to *limit* sorted image filenames located directly in *directory*."""
        names = sorted(
            name
            for name in os.listdir(directory)
            if name.lower().endswith(cls.IMG_EXTENSIONS)
            and os.path.isfile(os.path.join(directory, name))
        )
        return names[:limit]

    def __len__(self):
        return len(self.real_imgs) + len(self.fake_imgs)

    def __getitem__(self, idx):
        """Return ``(image_tensor, label)``; label is 1 for real and 0 for fake."""
        n_real = len(self.real_imgs)
        if idx < n_real:
            img_path = os.path.join(self.real_dir, self.real_imgs[idx])
            label = 1
        else:
            img_path = os.path.join(self.fake_dir, self.fake_imgs[idx - n_real])
            label = 0
        # The context manager closes the file handle; convert() returns a new
        # in-memory image that stays usable after the file is closed.
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")
        img = img.resize(self.image_size)
        return transforms.ToTensor()(img), label

class SimpleCNN(nn.Module):
    """Small convolutional classifier that emits one logit per image.

    Three 4x4, stride-2, padding-1 convolutions widen 3 -> 64 -> 128 -> 256
    channels, each followed by ReLU. Global average pooling and a linear layer
    then reduce the 256 features to a single output. Input must have 3 channels.
    """

    def __init__(self):
        super().__init__()
        widths = [3, 64, 128, 256]
        layers = []
        for c_in, c_out in zip(widths, widths[1:]):
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.ReLU(),
            ]
        layers += [
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(widths[-1], 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Map a batch of images to logits of shape (batch, 1)."""
        return self.net(x)

def train_classifier(dataloader, device, *, epochs=5, lr=1e-4, model=None):
    """Train a binary real/fake classifier and print per-epoch training accuracy.

    Args:
        dataloader: Iterable of ``(images, labels)`` batches. Labels are 0/1
            values, converted to float targets of shape ``(batch, 1)``.
        device: ``torch.device`` to train on.
        epochs: Number of passes over ``dataloader``.
        lr: Adam learning rate.
        model: Module producing one logit per image (shape ``(batch, 1)``).
            Defaults to a fresh ``SimpleCNN``.

    Returns:
        The trained model.

    Raises:
        ValueError: If ``dataloader`` yields no batches in an epoch.
    """
    if model is None:
        model = SimpleCNN()
    model = model.to(device)
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.BCEWithLogitsLoss()

    for epoch in range(epochs):
        total, correct = 0, 0
        for imgs, labels in dataloader:
            imgs = imgs.to(device)
            labels = labels.float().unsqueeze(1).to(device)
            logits = model(imgs)
            loss = loss_fn(logits, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Metric bookkeeping needs no autograd graph.
            with torch.no_grad():
                preds = (torch.sigmoid(logits) > 0.5).float()
                correct += (preds == labels).sum().item()
            total += labels.size(0)

        if total == 0:
            raise ValueError("dataloader yielded no batches; cannot train or compute accuracy")
        # NOTE: this is accuracy on the training batches while the weights are
        # still being updated, not a held-out evaluation.
        print(f"Epoch {epoch+1} | Accuracy: {correct / total:.4f}")

    return model

if __name__ == "__main__":
    # Mix real images with the selected generator's outputs.
    dataset = RealFakeDataset(
        real_dir,
        fake_dirs[selected_method],
        image_size=(128, 128),
    )
    loader = DataLoader(dataset, shuffle=True, batch_size=32)

    # Prefer the GPU when one is visible to torch.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    print(f"\nTraining classifier for: {selected_method}")
    train_classifier(loader, device)



Training classifier for: VAE
