In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.utils import save_image

# GAN 的生成器
class Generator(nn.Module):
    """Map latent vectors to 3-channel 32x32 images with values in [-1, 1].

    A linear layer projects each latent vector to a 256-channel 4x4 feature
    map; three transposed convolutions (kernel 4, stride 2, padding 1) then
    double the spatial size each time (4 -> 8 -> 16 -> 32). The final Tanh
    bounds every output value to [-1, 1].

    Args:
        latent_dim: Size of the latent vector fed to ``forward``.
    """

    def __init__(self, latent_dim=128):
        super().__init__()

        # Latent vector -> (256, 4, 4) feature map.
        stem = [
            nn.Linear(latent_dim, 256 * 4 * 4),
            nn.ReLU(),
            nn.Unflatten(1, (256, 4, 4)),
        ]

        # Hidden upsampling stages, each followed by a ReLU.
        upsample = []
        for c_in, c_out in ((256, 128), (128, 64)):
            upsample.append(
                nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1)
            )
            upsample.append(nn.ReLU())

        # Last upsampling stage produces the 3 image channels.
        head = [
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*stem, *upsample, *head)

    def forward(self, z):
        """Generate images from a batch of latent vectors ``z`` of shape (N, latent_dim)."""
        return self.model(z)

# GAN 的判别器
class Discriminator(nn.Module):
    """Score 3-channel images with a single probability in [0, 1] per image.

    Three stride-2 convolutions (kernel 4, padding 1), each followed by
    LeakyReLU(0.2), halve the spatial size at every stage. The flattened
    result must be 256 * 4 * 4 values wide for the final linear layer, which
    is what 32x32 inputs produce. A Sigmoid squashes the output to [0, 1].
    """

    def __init__(self):
        super().__init__()

        # Channel widths of the successive convolution stages.
        widths = (3, 64, 128, 256)

        layers = []
        for c_in, c_out in zip(widths, widths[1:]):
            layers.append(nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1))
            layers.append(nn.LeakyReLU(0.2))

        # Classifier head: flatten the feature map and emit one probability.
        layers.extend([
            nn.Flatten(),
            nn.Linear(256 * 4 * 4, 1),
            nn.Sigmoid(),
        ])

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return a (N, 1) tensor of probabilities for the image batch ``x``."""
        return self.model(x)

# Instantiate the two networks; the generator consumes latent vectors of size `latent_dim`.
latent_dim = 128
generator = Generator(latent_dim=latent_dim)
discriminator = Discriminator()

# One Adam optimizer per network, both configured with the same hyperparameters.
optimizer_G, optimizer_D = (
    optim.Adam(net.parameters(), lr=0.0002, betas=(0.5, 0.999))
    for net in (generator, discriminator)
)

# Binary cross-entropy on probabilities (the discriminator ends in a Sigmoid).
criterion = nn.BCELoss()

# Train the GAN
epochs = 10

# Create every tensor on the device the networks live on, so the loop keeps
# working if the models are moved off the CPU.
# Assumes generator and discriminator share one device.
model_device = next(discriminator.parameters()).device

for epoch in range(epochs):
    # Reset per epoch so an empty loader is detected instead of reporting
    # stale (or undefined) losses below.
    d_loss = g_loss = None

    for real_imgs, _ in train_loader:
        real_imgs = real_imgs.to(model_device)
        n_real = real_imgs.size(0)

        # Targets: 1 for "real", 0 for "fake".
        valid = torch.ones(n_real, 1, device=model_device)
        fake = torch.zeros(n_real, 1, device=model_device)

        # ---- Discriminator step ----
        optimizer_D.zero_grad()
        real_loss = criterion(discriminator(real_imgs), valid)
        z = torch.randn(n_real, latent_dim, device=model_device)
        fake_imgs = generator(z)
        # detach(): the discriminator update must not backpropagate into the generator.
        fake_loss = criterion(discriminator(fake_imgs.detach()), fake)
        d_loss = real_loss + fake_loss
        d_loss.backward()
        optimizer_D.step()

        # ---- Generator step ----
        # The generator is rewarded when the discriminator labels its images as real.
        optimizer_G.zero_grad()
        g_loss = criterion(discriminator(fake_imgs), valid)
        g_loss.backward()
        optimizer_G.step()

    if d_loss is None:
        raise RuntimeError("train_loader yielded no batches; nothing to train on")

    # Losses reported here are those of the last batch of the epoch.
    print(f"Epoch {epoch+1}, D Loss: {d_loss.item()}, G Loss: {g_loss.item()}")

# Save a grid of generated images
with torch.no_grad():
    # Sample the noise on the generator's own device so this cell also works
    # when the model has been moved off the CPU.
    z = torch.randn(64, latent_dim, device=next(generator.parameters()).device)
    samples = generator(z).cpu()
    # NOTE(review): the generator ends in Tanh, so `samples` lies in [-1, 1],
    # while save_image does not rescale unless normalize=True. If the training
    # images were normalized to [-1, 1], pass normalize=True with
    # value_range=(-1, 1) here; otherwise negative values are clipped to black.
    # Confirm against the train_loader transforms.
    save_image(samples, 'gan_samples.png', nrow=8)


In [None]:
# Conditional extension of the GAN's Generator and Discriminator models above:
#
# Convert the class label into an embedding vector and concatenate it to the
# input noise vector (z) or to the image.
# Handle the label the same way in the discriminator, so that the input image
# and its label jointly take part in the real/fake decision.
# Generator side: add a label embedding and concatenate it to the input.


class ConditionalGenerator(Generator):
    """Generator whose input is the noise vector concatenated with a label embedding.

    Args:
        latent_dim: Size of the noise vector ``z``; also used as the embedding size.
        num_classes: Number of distinct class labels the embedding table holds.
    """

    def __init__(self, latent_dim=128, num_classes=10):
        # The concatenated input [z, label embedding] is 2 * latent_dim wide,
        # so the backbone's first linear layer must accept that width.
        super().__init__(latent_dim=latent_dim * 2)
        self.label_embedding = nn.Embedding(num_classes, latent_dim)

    def forward(self, z, labels):
        """Generate images for noise ``z`` (N, latent_dim) and integer ``labels`` (N,)."""
        label_embeds = self.label_embedding(labels)
        conditioned = torch.cat([z, label_embeds], dim=1)
        return self.model(conditioned)