In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

KeyboardInterrupt: 

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.datasets import MNIST
from tqdm import tqdm

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
# Later cells move the models to this device and create tensors on it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# 定义生成对抗网络（GAN）的生成器
class ConditionalGenerator(nn.Module):
    """Label-conditioned, DCGAN-style image generator.

    The class label is looked up in a learned embedding table and concatenated
    with the noise vector. A linear layer projects the result to a 256-channel
    feature map of side ``img_size // 4``; two stride-2 transposed convolutions
    then double the side twice, and a final 3x3 transposed convolution maps to
    ``num_channels``. The output goes through ``tanh``, so pixel values lie in
    [-1, 1].

    Args:
        latent_dim: Length of the noise vector expected by ``forward``.
        img_size: Target side of the square output image. The produced side is
            ``4 * (img_size // 4)``, which equals ``img_size`` only when
            ``img_size`` is a multiple of 4.
        num_channels: Number of channels in the generated image.
        embed_dim: Size of the label embedding vector.
        num_classes: Number of distinct labels (rows of the embedding table).
    """

    def __init__(self, latent_dim=128, img_size=28, num_channels=1, embed_dim=128, num_classes=10):
        super(ConditionalGenerator, self).__init__()
        self.latent_dim = latent_dim
        self.img_size = img_size
        self.num_channels = num_channels
        self.embed_dim = embed_dim
        self.num_classes = num_classes
        self.embedding = nn.Embedding(num_classes, embed_dim)

        base_side = img_size // 4
        projected_features = 256 * base_side * base_side

        # NOTE: the layer order (and therefore the ``main.<idx>`` parameter
        # names) is kept stable so previously saved state_dicts still load.
        self.main = nn.Sequential(
            # Project [noise, label embedding] to a flat feature vector.
            nn.Linear(latent_dim + embed_dim, projected_features),
            nn.BatchNorm1d(projected_features),
            nn.ReLU(True),

            # Flat vector -> (256, base_side, base_side) feature map.
            nn.Unflatten(1, (256, base_side, base_side)),

            # Upsample x2: 256 x s x s -> 128 x 2s x 2s (7 -> 14 for MNIST).
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),

            # Upsample x2: 128 x 2s x 2s -> 64 x 4s x 4s (14 -> 28 for MNIST).
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),

            # Size-preserving projection to the image channels.
            nn.ConvTranspose2d(64, num_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.Tanh()  # squash pixel values into [-1, 1]
        )

    def forward(self, input, origin_y_label):
        """Generate one image per (noise, label) pair.

        Args:
            input: Noise tensor of shape ``(B, latent_dim)``.
            origin_y_label: Integer class labels, shape ``(B,)`` or ``(B, 1)``.

        Returns:
            Tuple ``(images, origin_y_label)``: the generated images with
            ``num_channels`` channels, and the label tensor passed in
            (returned unchanged so callers can pair images with labels).
        """
        # Reshape instead of squeeze(1): this yields (B, embed_dim) for both
        # (B,) and (B, 1) label tensors, including the embed_dim == 1 case
        # where squeeze(1) would drop the feature axis.
        label_vec = self.embedding(origin_y_label).reshape(input.size(0), -1)
        conditioned = torch.cat([input, label_vec], dim=1)
        return self.main(conditioned), origin_y_label


class Discriminator(nn.Module):
    """Label-conditioned convolutional discriminator for 28x28 images.

    The image goes through a strided convolutional stack whose spatial size
    shrinks 28 -> 14 -> 7 -> 3 -> 1, leaving ``feature_maps_d`` features per
    sample. Those features are concatenated with a learned embedding of the
    class label and mapped by a linear layer plus sigmoid to the probability
    that the image is real.

    Args:
        img_channels: Number of channels of the input images.
        feature_maps_d: Base width of the convolutional stack.
        embed_dim: Size of the label embedding vector.
        num_classes: Number of distinct labels (rows of the embedding table).
    """

    def __init__(self, img_channels=1, feature_maps_d=64, embed_dim=128, num_classes=10):
        super(Discriminator, self).__init__()
        self.img_channels = img_channels
        self.embedding = nn.Embedding(num_classes, embed_dim)
        # Not used by forward(); retained only so that the module's
        # state_dict keys stay compatible with already-saved checkpoints.
        self.back_ffn = nn.Linear(embed_dim, 28 * 28)

        self.main = nn.Sequential(
            # (B, img_channels, 28, 28) -> (B, feature_maps_d, 14, 14)
            nn.Conv2d(img_channels, feature_maps_d, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            # -> (B, feature_maps_d * 2, 7, 7)
            nn.Conv2d(feature_maps_d, feature_maps_d * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps_d * 2),
            nn.LeakyReLU(0.2, inplace=True),

            # -> (B, feature_maps_d * 4, 3, 3)
            nn.Conv2d(feature_maps_d * 2, feature_maps_d * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps_d * 4),
            nn.LeakyReLU(0.2, inplace=True),

            # 3x3 kernel without padding collapses 3x3 -> 1x1:
            # -> (B, feature_maps_d, 1, 1)
            nn.Conv2d(feature_maps_d * 4, feature_maps_d, 3, 1, 0, bias=False),
        )
        # Image features (feature_maps_d) + label embedding (embed_dim) -> logit.
        self.fc = nn.Linear(feature_maps_d + embed_dim, 1)
        self.sigmoid = nn.Sigmoid()  # turn the logit into a probability

    def forward(self, input, y):
        """Score images as real (towards 1) or fake (towards 0).

        Args:
            input: Images holding ``B * img_channels * 28 * 28`` values with
                the batch on axis 0, e.g. ``(B, img_channels, 28, 28)`` or a
                flattened ``(B, img_channels * 784)`` tensor.
            y: Integer class labels, shape ``(B,)`` or ``(B, 1)``.

        Returns:
            Tensor of shape ``(B, 1)`` with probabilities in [0, 1].
        """
        B = input.size(0)
        # (B, embed_dim) regardless of whether y is (B,) or (B, 1).
        label_vec = self.embedding(y).reshape(B, -1)
        # Use the configured channel count; hard-coding 1 here made any
        # img_channels != 1 model unusable.
        images = input.reshape(B, self.img_channels, 28, 28)
        features = self.main(images).reshape(B, -1)
        logits = self.fc(torch.cat([features, label_vec], dim=1))
        return self.sigmoid(logits)


def train_gan(G, D, dataloader, epochs, device,
              lr_D=2e-4, lr_G=2e-4,
              beta1=0.5,
              noise_dim=128):
    """Train a conditional GAN with alternating discriminator/generator steps.

    For every batch the discriminator is updated once on real images (target
    0.9) and on detached generated images (target 0.1); the generator is then
    updated once on a fresh batch of noise with target 1.0. Both networks use
    Adam with ``betas=(beta1, 0.999)`` and binary cross-entropy loss. The
    average D and G loss is printed after every epoch.

    Args:
        G: Generator called as ``G(noise, labels)`` and returning
            ``(images, labels)``.
        D: Discriminator called as ``D(images, labels)`` and returning a
            probability per sample.
        dataloader: Sized iterable of ``(images, labels)`` batches.
        epochs: Number of passes over ``dataloader``.
        device: Device on which batches, noise and targets are placed. The
            models are not moved here - presumably the caller has already
            placed them on this device.
        lr_D: Learning rate of the discriminator optimizer.
        lr_G: Learning rate of the generator optimizer.
        beta1: First Adam beta for both optimizers.
        noise_dim: Length of the noise vector passed to ``G``.

    Raises:
        ValueError: If ``epochs > 0`` but ``dataloader`` has no batches.
    """
    criterion = nn.BCELoss()

    optimizer_D = torch.optim.Adam(D.parameters(), lr=lr_D, betas=(beta1, 0.999))
    optimizer_G = torch.optim.Adam(G.parameters(), lr=lr_G, betas=(beta1, 0.999))

    loader_len = len(dataloader)
    if epochs > 0 and loader_len == 0:
        # Fail early with a clear message instead of a ZeroDivisionError when
        # the epoch averages are computed.
        raise ValueError("dataloader is empty; nothing to train on")

    # Sample fake labels from the generator's own label range when it exposes
    # an embedding table; otherwise keep the 10-class default.
    num_classes = getattr(getattr(G, "embedding", None), "num_embeddings", 10)

    # Make sure BatchNorm layers learn even if a model was left in eval mode
    # by an earlier inference call.
    G.train()
    D.train()

    def sample_fake_batch(batch_size):
        """Generate `batch_size` images for uniformly random class labels."""
        noise = torch.randn(batch_size, noise_dim, device=device)
        labels = torch.randint(0, num_classes, (batch_size,), device=device).view(batch_size, -1)
        return G(noise, labels)

    print("开始GAN训练...")
    for epoch in range(epochs):
        D_loss_total = 0.0
        G_loss_total = 0.0

        progress = tqdm(dataloader, desc=f"Epoch {epoch + 1}/{epochs}", total=loader_len)
        for real_images, y_label in progress:
            B = real_images.size(0)
            # Flatten per sample without assuming 28 * 28 = 784 pixels.
            real_images = real_images.reshape(B, -1).to(device)
            y_label = y_label.view(B, -1).to(device)

            # --- Discriminator step ---
            optimizer_D.zero_grad()

            # Real images use a soft target (0.9) so D does not grow
            # over-confident.
            real_outputs = D(real_images, y_label)
            real_targets = torch.full((B,), 0.9, device=device)
            errD_real = criterion(real_outputs.view(-1), real_targets)
            errD_real.backward()

            # Fake images are detached so this step does not back-propagate
            # into G. Their target is a soft 0.1.
            fake_images, fake_image_label = sample_fake_batch(B)
            fake_output = D(fake_images.detach(), fake_image_label)
            fake_targets = torch.full((B,), 0.1, device=device)
            errD_fake = criterion(fake_output.view(-1), fake_targets)
            errD_fake.backward()

            optimizer_D.step()
            D_loss_total += errD_real.item() + errD_fake.item()

            # --- Generator step ---
            optimizer_G.zero_grad()
            # Fresh samples, NOT detached: gradients must flow through D
            # back into G. G is pushed towards a hard target of 1.0.
            fake_images, fake_image_label = sample_fake_batch(B)
            output = D(fake_images, fake_image_label)
            errG = criterion(output.view(-1), torch.full((B,), 1.0, device=device))
            errG.backward()
            optimizer_G.step()
            G_loss_total += errG.item()

        # --- Epoch summary ---
        d_loss_avg = D_loss_total / loader_len
        g_loss_avg = G_loss_total / loader_len

        print(f"Epoch [{epoch + 1}/{epochs}] "
              f"D Loss: {d_loss_avg:.4f}, G Loss: {g_loss_avg:.4f}")

In [None]:
from torchvision import transforms

batch_size = 256
epochs = 50
seed = 42

# Seed PyTorch so weight init, shuffling and noise sampling are repeatable
# under Restart & Run All.
torch.manual_seed(seed)

# ToTensor() yields pixels in [0, 1]; Normalize((0.5,), (0.5,)) maps them to
# [-1, 1]. This must match the generator, whose last layer is tanh: with the
# MNIST mean/std normalisation real pixels would reach about 2.8, a range the
# generator can never produce, so the discriminator could tell real from fake
# by pixel range alone.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Load the MNIST training split. download=False: the files must already be
# present under `root`.
mnist_train = MNIST(root='../../dataset_file/mnist_raw', train=True, download=False, transform=transform)
dataloader = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True)

# Build the generator and the discriminator on the selected device.
G = ConditionalGenerator().to(device)
D = Discriminator().to(device)

# Train the GAN.
train_gan(G, D, dataloader, epochs=epochs, device=device)

In [None]:
#检测生成质量
import matplotlib.pyplot as plt


def generate_and_plot(G, num_images=16, latent_dim=128, num_classes=10, img_size=28):
    """Sample images from the generator and show them in a grid.

    Random noise and uniformly random class labels are fed to ``G``; each
    generated image is drawn in grayscale with its label as the subplot title.
    The grid is sized to fit ``num_images`` (4x4 for the default 16) and any
    unused cells are left blank.

    Args:
        G: Generator called as ``G(noise, labels)`` and returning
            ``(images, labels)``. Assumed to produce single-channel
            ``img_size`` x ``img_size`` images.
        num_images: Number of images to generate; must be at least 1.
        latent_dim: Length of the noise vector expected by ``G``.
        num_classes: Labels are drawn from ``0 .. num_classes - 1``.
        img_size: Side of the generated square images.

    Raises:
        ValueError: If ``num_images`` is smaller than 1.
    """
    import math

    if num_images < 1:
        raise ValueError("num_images must be at least 1")

    # Create inputs on the generator's own device; fall back to the notebook's
    # global `device` for a generator without parameters.
    first_param = next(G.parameters(), None)
    target_device = first_param.device if first_param is not None else device

    noise = torch.randn(num_images, latent_dim, device=target_device)
    # Labels have shape (num_images, 1).
    labels = torch.randint(0, num_classes, (num_images,), device=target_device)
    labels = labels.view(num_images, -1)

    # Sample in eval mode so BatchNorm uses its running statistics and is not
    # updated by a pure visualisation call; restore the previous mode after.
    was_training = G.training
    G.eval()
    try:
        with torch.no_grad():
            generated_images, _ = G(noise, labels)
    finally:
        G.train(was_training)
    generated_images = generated_images.view(num_images, img_size, img_size).cpu().numpy()

    # Smallest near-square grid with at least num_images cells.
    cols = math.isqrt(num_images - 1) + 1
    rows = -(-num_images // cols)
    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid.
    fig, axes = plt.subplots(rows, cols, figsize=(2 * cols, 2 * rows), squeeze=False)
    for idx, ax in enumerate(axes.flat):
        ax.axis('off')
        if idx >= num_images:
            continue  # blank filler cell
        ax.imshow(generated_images[idx], cmap='gray')
        ax.set_title(f'Label: {int(labels[idx].item())}')
    plt.tight_layout()
    plt.show()


# Sample 16 labelled digits from the trained generator and plot them.
generate_and_plot(G, num_images=16)

In [None]:
# Persist the learned weights (state_dict only) of both networks.
for checkpoint_path, network in (('generator.pth', G), ('discriminator.pth', D)):
    torch.save(network.state_dict(), checkpoint_path)