In [1]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from scipy import linalg
from torch.nn.utils import spectral_norm
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.models import inception_v3, Inception_V3_Weights

In [2]:
# 1. 数据加载和预处理
def load_data():
    """Build the MNIST training and validation data loaders.

    Every image is converted to a tensor and normalized with mean 0.5 and
    std 0.5, i.e. rescaled to the [-1, 1] interval. The dataset is stored
    under ``./data`` and downloaded on demand.

    Returns:
        tuple: ``(train_loader, val_loader)``. Both use a batch size of 128;
        only the training loader shuffles.
    """
    preprocessing = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])

    def build_loader(is_train):
        # The training split is the only one that gets shuffled.
        split = datasets.MNIST(root='./data', train=is_train, download=True, transform=preprocessing)
        return DataLoader(split, batch_size=128, shuffle=is_train)

    return build_loader(True), build_loader(False)

In [3]:
# 2. 定义生成器网络
class Generator(nn.Module):
    """Fully connected generator: latent vector -> 1x28x28 image in [-1, 1].

    The network is three ``Linear -> LeakyReLU(0.2) -> BatchNorm1d`` stages of
    width 256, 512 and 1024, followed by a 784-unit linear layer and ``Tanh``.
    """

    def __init__(self, latent_dim=100):
        super().__init__()
        self.latent_dim = latent_dim

        # Consecutive pairs of this tuple are the (in, out) sizes of each stage.
        widths = (latent_dim, 256, 512, 1024)
        layers = []
        for stage, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(nn.Linear(n_in, n_out))
            # Only the first activation operates in place.
            layers.append(nn.LeakyReLU(0.2, inplace=(stage == 0)))
            layers.append(nn.BatchNorm1d(n_out))

        # Output head: 784 values (28 * 28) squashed into [-1, 1].
        layers.extend([nn.Linear(1024, 784), nn.Tanh()])
        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Map a ``(batch, latent_dim)`` noise tensor to ``(batch, 1, 28, 28)`` images."""
        flat = self.model(z)
        return flat.view(flat.size(0), 1, 28, 28)

In [4]:
# 3. 定义判别器网络
class Discriminator(nn.Module):
    """Fully connected discriminator: image -> probability of being real.

    The input is flattened to 784 values and passed through
    ``784 -> 512 -> 256 -> 1`` linear layers with LeakyReLU(0.2) between them
    and a final sigmoid.
    """

    def __init__(self):
        super().__init__()
        stack = [
            # Input layer: flattened 28x28 image.
            nn.Linear(784, 512), nn.LeakyReLU(0.2),
            # Hidden layer.
            nn.Linear(512, 256), nn.LeakyReLU(0.2),
            # Output layer: a single real/fake probability.
            nn.Linear(256, 1), nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*stack)

    def forward(self, img):
        """Return a ``(batch, 1)`` tensor of probabilities for ``img``."""
        return self.model(img.view(img.size(0), -1))

In [5]:
# 4. 计算FID (Fréchet Inception Distance)指标
def calculate_fid(real_images, generated_images, inception_model, device, batch_size=64):
    """Compute the Fréchet Inception Distance between two image batches.

    Both batches are expanded to three channels, moved to ``device``, and run
    through ``inception_model`` in chunks of ``batch_size``. The mean and
    covariance of each feature set are then fed to
    ``calculate_frechet_distance``.

    Args:
        real_images: Tensor of reference images.
        generated_images: Tensor of generated images.
        inception_model: Feature extractor applied to the resized images.
        device: Device on which the model runs.
        batch_size: Number of images per forward pass.

    Returns:
        The Fréchet distance between the two feature distributions.
    """
    # Real images first, generated images second: the same order is used for
    # the feature extraction and for the arguments of the distance below.
    image_sets = (
        convert_to_3_channels(real_images).to(device),
        convert_to_3_channels(generated_images).to(device),
    )

    moments = []
    for image_set in image_sets:
        feats = extract_features(image_set, inception_model, device, batch_size)
        # Rows are samples, so the covariance is taken over columns.
        moments.append((np.mean(feats, axis=0), np.cov(feats, rowvar=False)))

    (mu_real, sigma_real), (mu_gen, sigma_gen) = moments
    return calculate_frechet_distance(mu_real, sigma_real, mu_gen, sigma_gen)

def convert_to_3_channels(images):
    """Return ``images`` with its channel dimension tiled to three copies.

    A batch whose dimension 1 already has size 3 is handed back unchanged
    (the very same object); any other batch has dimension 1 repeated three
    times while the remaining dimensions are left alone.
    """
    needs_tiling = images.shape[1] != 3
    return images.repeat(1, 3, 1, 1) if needs_tiling else images

def extract_features(images, inception_model, device, batch_size):
    """Run ``inception_model`` over ``images`` in chunks and stack the outputs.

    The model is switched to eval mode, and gradients are disabled for the
    whole pass. Each chunk of at most ``batch_size`` images is bilinearly
    resized to 299x299 and moved to ``device`` before the forward pass.

    Returns:
        numpy.ndarray: The per-chunk model outputs concatenated along axis 0.
    """
    inception_model.eval()

    def embed(start):
        # Resize one slice to the 299x299 input size, then move it to the
        # model's device for the forward pass.
        resized = F.interpolate(
            images[start:start + batch_size],
            size=(299, 299),
            mode='bilinear',
            align_corners=False,
        )
        outputs = inception_model(resized.to(device))
        # NumPy arrays live on the CPU, so bring the result back first.
        return outputs.cpu().numpy()

    with torch.no_grad():
        per_chunk = [embed(start) for start in range(0, len(images), batch_size)]
    return np.concatenate(per_chunk)

def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Compute the Fréchet distance between two Gaussians.

    The distance between N(mu1, sigma1) and N(mu2, sigma2) is

        ||mu1 - mu2||^2 + Tr(sigma1) + Tr(sigma2) - 2 * Tr(sqrt(sigma1 @ sigma2))

    Args:
        mu1: Mean vector of the first distribution.
        sigma1: Covariance matrix of the first distribution.
        mu2: Mean vector of the second distribution.
        sigma2: Covariance matrix of the second distribution.
        eps: Value added to the diagonal of both covariances when the matrix
            square root of their product is not finite.

    Returns:
        The Fréchet distance as a scalar.

    Raises:
        ValueError: If the means or the covariances have different shapes.
    """
    mu1 = np.atleast_1d(mu1)
    mu2 = np.atleast_1d(mu2)
    sigma1 = np.atleast_2d(sigma1)
    sigma2 = np.atleast_2d(sigma2)

    # Fail loudly instead of letting NumPy broadcast mismatched inputs.
    if mu1.shape != mu2.shape:
        raise ValueError(
            f"Mean vectors have different shapes: {mu1.shape} vs {mu2.shape}"
        )
    if sigma1.shape != sigma2.shape:
        raise ValueError(
            f"Covariance matrices have different shapes: {sigma1.shape} vs {sigma2.shape}"
        )

    diff = mu1 - mu2

    # Matrix square root of the covariance product. ``disp`` is left at its
    # default so the call returns just the matrix.
    covmean = linalg.sqrtm(sigma1.dot(sigma2))

    # A (near-)singular product can yield inf/nan entries; retry with a small
    # diagonal offset so the result stays finite.
    if not np.isfinite(covmean).all():
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # Round-off can leave a spurious imaginary part; keep only the real part.
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    fid = diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)
    return fid

In [6]:
# 5. 训练GAN模型
def train_gan(epochs = 100):
    """Train the GAN on MNIST and write samples, curves and weights to disk.

    Each iteration first updates the generator (binary cross-entropy between
    the discriminator's verdict on fresh fakes and the "real" label) and then
    the discriminator (mean of the cross-entropy on real images and on the
    detached fakes). Every 10th epoch, and once more after training, 5000
    generated images are scored against 5000 validation images with FID and
    saved as an image grid.

    Args:
        epochs: Number of passes over the training set.

    Side effects:
        Prints progress, creates ``./结果汇总`` if it is missing, and writes
        the sample grids, the training curves and both state dicts there.
    """
    # Select the device.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"使用设备: {device}")

    # All artifacts go here. Create the directory up front so the first save
    # cannot fail after epochs of training.
    output_dir = "./结果汇总"
    os.makedirs(output_dir, exist_ok=True)

    # Load the data.
    train_loader, val_loader = load_data()

    # Training settings.
    latent_dim = 100
    sample_interval = 10
    num_eval_samples = 5000

    # Build the models.
    generator = Generator(latent_dim).to(device)
    discriminator = Discriminator().to(device)

    # Loss function and optimizers.
    adversarial_loss = nn.BCELoss().to(device)

    lr = 0.0002
    betas = (0.5, 0.95)
    optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=betas)
    optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=betas)

    # Inception network used as the FID feature extractor.
    inception_model = inception_v3(weights=Inception_V3_Weights.DEFAULT, transform_input=False)
    inception_model.fc = nn.Identity()  # drop the final fully connected layer
    inception_model = inception_model.to(device)
    inception_model.eval()

    # The validation loader does not shuffle, so the reference batch is the
    # same for every FID evaluation; collect it once.
    val_samples = get_val_samples(val_loader, num_eval_samples, device)

    # Per-iteration losses and per-evaluation FID scores.
    g_losses = []
    d_losses = []
    fid_scores = []

    # Training loop.
    for epoch in range(epochs):
        for i, (real_imgs, _) in enumerate(train_loader):
            real_imgs = real_imgs.to(device)
            batch_size = real_imgs.size(0)

            # Label 1 for real images, 0 for generated ones; allocated
            # directly on the target device.
            valid = torch.ones(batch_size, 1, device=device)
            fake = torch.zeros(batch_size, 1, device=device)

            # ---------------------
            #  Train the generator
            # ---------------------
            optimizer_G.zero_grad()

            # Sample latent noise.
            z = torch.randn(batch_size, latent_dim).to(device)

            # Produce fake images.
            gen_imgs = generator(z)

            # The generator wants the discriminator to call its images real.
            g_loss = adversarial_loss(discriminator(gen_imgs), valid)

            g_loss.backward()
            optimizer_G.step()

            # -------------------------
            #  Train the discriminator
            # -------------------------
            optimizer_D.zero_grad()

            # Classify real and generated images; the fakes are detached so
            # no gradient flows back into the generator here.
            real_loss = adversarial_loss(discriminator(real_imgs), valid)
            fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
            d_loss = (real_loss + fake_loss) / 2

            d_loss.backward()
            optimizer_D.step()

            # Record the losses.
            g_losses.append(g_loss.item())
            d_losses.append(d_loss.item())

            # Report progress every 100 batches.
            if i % 100 == 0:
                print(
                    f"[Epoch {epoch}/{epochs}] [Batch {i}/{len(train_loader)}] "
                    f"[D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]"
                )

        # Periodic FID evaluation at the end of every `sample_interval`-th epoch.
        if (epoch + 1) % sample_interval == 0:
            generated_samples = generate_samples(generator, num_eval_samples, latent_dim, device)

            fid = calculate_fid(val_samples, generated_samples, inception_model, device)
            fid_scores.append(fid)
            print(f"[Epoch {epoch}/{epochs}] [FID score: {fid:.4f}]")

            # Keep a grid of the generated samples.
            save_samples(generated_samples, f"{output_dir}/samples_epoch_{epoch}.png")

    # Final samples and final FID score.
    final_samples = generate_samples(generator, num_eval_samples, latent_dim, device)
    save_samples(final_samples, f"{output_dir}/final_generated_samples.png")

    final_fid = calculate_fid(val_samples, final_samples, inception_model, device)
    print(f"Final FID score: {final_fid:.4f}")

    # Persist the weights before plotting so that a plotting failure cannot
    # cost the trained models.
    torch.save(generator.state_dict(), f"{output_dir}/generator_model.pth")
    torch.save(discriminator.state_dict(), f"{output_dir}/discriminator_model.pth")

    # Plot the loss and FID curves.
    plot_losses(g_losses, d_losses, fid_scores, final_fid)

In [7]:
# 辅助函数
def generate_samples(generator, num_samples, latent_dim, device):
    """Draw ``num_samples`` outputs from ``generator`` without tracking gradients.

    Standard-normal noise of shape ``(num_samples, latent_dim)`` is sampled,
    moved to ``device`` and passed through the generator in a single call.
    The generator's train/eval mode is not touched here.

    Returns:
        The generator output, moved to the CPU for saving and post-processing.
    """
    noise = torch.randn(num_samples, latent_dim).to(device)
    with torch.no_grad():
        return generator(noise).cpu()

def get_val_samples(val_loader, num_samples, device):
    """Collect the first ``num_samples`` images yielded by ``val_loader``.

    Batches are gathered until the number of batches seen times the size of
    the current batch reaches ``num_samples`` (or the loader is exhausted).
    The result is concatenated, truncated to ``num_samples`` and moved to
    ``device``.
    """
    collected = []
    for batches_seen, (images, _) in enumerate(val_loader, start=1):
        collected.append(images)
        # The running total is estimated from the size of the current batch.
        if batches_seen * images.size(0) >= num_samples:
            break
    return torch.cat(collected)[:num_samples].to(device)

def save_samples(samples, filename):
    """Save a batch of generated images as one image grid.

    Args:
        samples: Image tensor with values in [-1, 1].
        filename: Destination path of the grid image. Missing parent
            directories are created.
    """
    from torchvision.utils import save_image

    # save_image does not create directories, so make sure the target folder
    # exists instead of failing after the samples were already generated.
    if isinstance(filename, (str, os.PathLike)):
        target_dir = os.path.dirname(filename)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)

    # Map the pixel values from [-1, 1] to [0, 1].
    samples = (samples + 1) / 2
    # 25 images per grid row. normalize=True additionally min-max rescales
    # the whole batch before writing.
    save_image(samples, filename, nrow=25, normalize=True)

def plot_losses(g_losses, d_losses, fid_scores, final_fid, sample_interval=10):
    """Plot the loss curves and the FID history, then save and show the figure.

    Args:
        g_losses: Generator loss per training iteration.
        d_losses: Discriminator loss per training iteration.
        fid_scores: FID scores measured during training, one per evaluation.
            May be empty when training was too short for any evaluation.
        final_fid: FID score measured after training; drawn as a dashed line.
        sample_interval: Number of epochs between two FID evaluations; the
            k-th score is placed at epoch ``k * sample_interval``.

    Side effects:
        Writes ``./结果汇总/training_curves.png`` (creating the directory if
        needed) and displays the figure.
    """
    out_path = './结果汇总/training_curves.png'
    os.makedirs(os.path.dirname(out_path), exist_ok=True)

    fig, (loss_ax, fid_ax) = plt.subplots(1, 2, figsize=(9, 4))

    # Left panel: both losses over training iterations.
    loss_ax.plot(g_losses, label='Generator Loss')
    loss_ax.plot(d_losses, label='Discriminator Loss')
    loss_ax.set_xlabel('Iterations')
    loss_ax.set_ylabel('Loss')
    loss_ax.legend()

    # Right panel: FID over epochs. The first score is measured after
    # `sample_interval` completed epochs, not at epoch 0.
    fid_epochs = [sample_interval * (k + 1) for k in range(len(fid_scores))]
    fid_ax.plot(fid_epochs, fid_scores)

    # Dashed reference lines: the final FID, plus the best intermediate FID
    # when at least one evaluation took place (min() of an empty list raises).
    reference_levels = [final_fid]
    if fid_scores:
        reference_levels.append(min(fid_scores))

    # Anchor the value labels at the left end of the plotted data.
    label_x = fid_epochs[0] if fid_epochs else 0
    for level in reference_levels:
        fid_ax.axhline(y=level, color='r', linestyle='--', linewidth=1.5, alpha=0.7)
        fid_ax.text(x=label_x, y=level, s=f'{round(level, 2)}', fontsize=8, color='red')

    fid_ax.set_xlabel('Epochs')
    fid_ax.set_ylabel('FID Score')

    fig.tight_layout()
    fig.savefig(out_path)
    plt.show()

In [None]:
# Entry point: run the full GAN training for 500 epochs when this cell/module
# is executed as the main program.
if __name__ == "__main__":
    train_gan(epochs=500)

使用设备: cuda
[Epoch 0/500] [Batch 0/469] [D loss: 0.6924] [G loss: 0.7131]
[Epoch 0/500] [Batch 100/469] [D loss: 0.0857] [G loss: 2.1900]
[Epoch 0/500] [Batch 200/469] [D loss: 0.1326] [G loss: 2.5792]
[Epoch 0/500] [Batch 300/469] [D loss: 0.2026] [G loss: 1.3799]
[Epoch 0/500] [Batch 400/469] [D loss: 0.5106] [G loss: 1.7811]
[Epoch 1/500] [Batch 0/469] [D loss: 0.7467] [G loss: 0.3512]
[Epoch 1/500] [Batch 100/469] [D loss: 0.9000] [G loss: 0.2552]
[Epoch 1/500] [Batch 200/469] [D loss: 0.8391] [G loss: 0.3193]
[Epoch 1/500] [Batch 300/469] [D loss: 0.8416] [G loss: 0.2853]
[Epoch 1/500] [Batch 400/469] [D loss: 0.8382] [G loss: 0.3125]
[Epoch 2/500] [Batch 0/469] [D loss: 0.8025] [G loss: 0.3364]
[Epoch 2/500] [Batch 100/469] [D loss: 0.7783] [G loss: 0.3759]
[Epoch 2/500] [Batch 200/469] [D loss: 0.6800] [G loss: 0.5777]
[Epoch 2/500] [Batch 300/469] [D loss: 0.7054] [G loss: 0.8172]
[Epoch 2/500] [Batch 400/469] [D loss: 0.6854] [G loss: 0.6591]
[Epoch 3/500] [Batch 0/469] [D loss