In [40]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from sklearn.metrics import f1_score
import numpy as np
import os


class PairedDataset(Dataset):
    def __init__(self, real_dir, fake_dir, transform=None):
        self.real_paths = sorted([os.path.join(real_dir, fname) for fname in os.listdir(real_dir) 
                                  if os.path.isfile(os.path.join(real_dir, fname))])
        self.fake_paths = sorted([os.path.join(fake_dir, fname) for fname in os.listdir(fake_dir) 
                                  if os.path.isfile(os.path.join(fake_dir, fname))])
        self.transform = transform

        assert len(self.real_paths) == len(self.fake_paths)

    def __len__(self):
        return len(self.real_paths) + len(self.fake_paths)

    def __getitem__(self, idx):
        if idx < len(self.real_paths):
            image = Image.open(self.real_paths[idx]).convert('RGB')
            label = torch.tensor(1.)
        else:
            image = Image.open(self.fake_paths[idx - len(self.real_paths)]).convert('RGB')
            label = torch.tensor(0.)

        if self.transform:
            image = self.transform(image)

        return image, label


class Classifier(nn.Module):
    def __init__(self):
        super(Classifier, self).__init__()

        self.model = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2),

            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.BatchNorm2d(128),
            nn.Dropout(0.2),  # 添加 Dropout

            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.BatchNorm2d(256),
            nn.Dropout(0.3),  # 添加 Dropout

            nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.BatchNorm2d(512),

            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(512, 1)
        )

    def forward(self, x):
        return self.model(x)

class Generator(nn.Module):
    def __init__(self, z_dim):
        super(Generator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(z_dim, 512 * 7 * 7),
            nn.Unflatten(1, (512, 7, 7)),

            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.BatchNorm2d(256),

            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.BatchNorm2d(128),

            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.BatchNorm2d(64),

            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, z):
        return self.model(z)

In [54]:
# 较为简单的模型结构，降低模型复杂度
class Classifier(nn.Module):
    def __init__(self):
        super(Classifier, self).__init__()

        self.model = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.BatchNorm2d(32),
            nn.Dropout(0.2),  # 添加 Dropout

            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.BatchNorm2d(64),
            nn.Dropout(0.3),  # 添加 Dropout

            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(64, 1)
        )

    def forward(self, x):
        return self.model(x)

class Generator(nn.Module):
    def __init__(self, z_dim):
        super(Generator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(z_dim, 128 * 7 * 7),
            nn.Unflatten(1, (128, 7, 7)),

            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.BatchNorm2d(64),

            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, z):
        return self.model(z)

In [68]:
from torchvision import transforms

# 增加数据增强
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),  # 随机水平翻转
    transforms.RandomVerticalFlip(),  # 随机垂直翻转
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.RandomRotation(40),  # 在[-30, 30]范围内随机旋转
    transforms.RandomAffine(degrees=0, shear=10, scale=(0.8,1.2)),  # 随机仿射变换
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),  # 色彩抖动
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # 正规化
    transforms.RandomErasing(),  # 随机擦除
])

valid_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

batch_size = 256
train_dataset = PairedDataset("./data/dataset/train/real/", "./data/dataset/train/fake/", transform=train_transform)
valid_dataset = PairedDataset("./data/dataset/valid/real/", "./data/dataset/valid/fake/", transform=valid_transform)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

In [None]:
from tqdm import tqdm
# 初始化生成器和判别器
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
z_dim = 100
generator = Generator(z_dim).to(device)
classifier = Classifier().to(device)

# 初始化优化器和损失函数
gen_optimizer = torch.optim.Adam(generator.parameters(), lr=0.001, weight_decay=1e-5)
cls_optimizer = torch.optim.Adam(classifier.parameters(), lr=0.0002, weight_decay=1e-5)
loss_fn = nn.BCEWithLogitsLoss()

# 用于保存每个批次的 loss 值
losses = []

# 训练循环
num_epochs = 5
all_gen_loss = []
all_cls_loss = []
for epoch in range(num_epochs):
    # 切换到训练模式
    generator.train()
    classifier.train()

    progress_bar = tqdm(train_dataloader, desc='Epoch {:03d}'.format(epoch + 1), leave=False, disable=False)
    for real_images, real_labels in progress_bar:
        real_labels = real_labels.view(-1, 1).to(device)
        real_images = real_images.to(device)

        # 训练判别器
        cls_optimizer.zero_grad()
        real_outputs = classifier(real_images)
        cls_loss = loss_fn(real_outputs, real_labels)
        all_cls_loss.append(cls_loss)   # <============= 添加cls_loss
        cls_loss.backward()
        cls_optimizer.step()

        # 训练生成器s
        gen_optimizer.zero_grad()
        z = torch.randn(batch_size, z_dim).to(device)
        fake_images = generator(z)
        outputs = classifier(fake_images)
        gen_loss = loss_fn(outputs, real_labels)
        all_gen_loss.append(gen_loss)   # <============= 添加gen_loss
        gen_loss.backward()
        gen_optimizer.step()

        # 更新进度条
        progress_bar.set_postfix({'gen_loss': '{:.6f}'.format(gen_loss.item()), 'cls_loss': '{:.6f}'.format(cls_loss.item())})

        losses.append((cls_loss.item(), gen_loss.item()))

    progress_bar.close()

    # 切换到评估模式
    generator.eval()
    classifier.eval()

    with torch.no_grad():
        avg_f1 = 0.
        num_batches = 0
        for real_images, real_labels in valid_dataloader:
            real_labels = real_labels.view(-1, 1).to(device)
            real_images = real_images.to(device)

            with torch.no_grad():
                # 评估生成器
                z = torch.randn(batch_size, z_dim).to(device)
                fake_images = generator(z)
                outputs = classifier(fake_images)
                predicted = torch.round(torch.sigmoid(outputs)).cpu().numpy()
                actual = np.ones(batch_size)
                f1 = f1_score(actual, predicted)
                avg_f1 += f1
                num_batches += 1

        avg_f1 /= num_batches
        print(f'Epoch [{epoch+1}/{num_epochs}], Generator Loss: {gen_loss.item()}, Classifier Loss: {cls_loss.item()}, Average F1 Score: {avg_f1}')
        # print(f'Epoch [{epoch+1}/{num_epochs}], Generator Loss: {gen_loss.item()}, Classifier Loss: {cls_loss.item()}, F1 Score: {f1}')

                                                                                                

Epoch [1/5], Generator Loss: 0.7097705006599426, Classifier Loss: 0.5658202171325684, Average F1 Score: 0.5337256880828577


                                                                                                

Epoch [2/5], Generator Loss: 0.7038366794586182, Classifier Loss: 0.47440698742866516, Average F1 Score: 0.9035801475928896


                                                                                                

Epoch [3/5], Generator Loss: 0.6951031684875488, Classifier Loss: 0.5155269503593445, Average F1 Score: 0.9555080908636563


                                                                                                

Epoch [4/5], Generator Loss: 0.6942633390426636, Classifier Loss: 0.4731321930885315, Average F1 Score: 0.9767274571632736


                                                                                                

In [63]:
torch.save(classifier, 'classifier.pth')
torch.save(generator, 'generator.pth')


In [65]:
from datetime import datetime
import csv
import os

# 加载模型
model_path = "./classifier.pth"
classifier = torch.load(model_path)

# 确保模型在正确的设备上
classifier = classifier.to(device)

# 我们首先定义一个函数，该函数接受图片路径作为输入，并返回模型的预测结果
def predict(image_path, transform, classifier, device):
    image = Image.open(image_path).convert('RGB')
    if transform:
        image = transform(image)
    image = image.unsqueeze(0).to(device)
    classifier.eval()
    with torch.no_grad():
        outputs = classifier(image)
        predicted = torch.round(torch.sigmoid(outputs))
    return predicted.item()

import os

def save_predictions_to_csv(folder_path, csv_file_path, transform, classifier, device):
    with open(csv_file_path, 'w', newline='') as csvfile:
        fieldnames = ['file_path', 'prediction', 'prediction_time']
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()

        # 列出文件夹中的所有文件
        image_paths = [os.path.join(folder_path, f) for f in os.listdir(folder_path) if os.path.isfile(os.path.join(folder_path, f))]

        for image_path in image_paths:
            prediction = predict(image_path, transform, classifier, device)
            prediction_time = datetime.now().isoformat()
            writer.writerow({'file_path': image_path, 'prediction': prediction, 'prediction_time': prediction_time})

folder_path = "./data/dataset/test/real/"  # 图片文件夹路径
csv_file_path = "./predictions.csv"  # CSV 文件的路径
save_predictions_to_csv(folder_path, csv_file_path, transform, classifier, device)