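A conditional GAN (Conditional Generative Adversarial Network, cGAN) is a GAN variant that gives both the generator and the discriminator extra conditioning information, such as a class label, so that the generator can be steered to produce images of a specific class. A common implementation concatenates the condition (for example a one-hot label vector) with the random noise to form the generator's input, and with the image to form the discriminator's input.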
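A minimal sketch of that input concatenation. It assumes one-hot labels for the 10 MNIST classes, a 100-dimensional noise vector and flattened 28x28 images, which are the sizes the example code uses. The variable names are illustrative only.

```python
import torch
import torch.nn.functional as F

batch_size, noise_dim, num_classes, image_dim = 64, 100, 10, 28 * 28

z = torch.randn(batch_size, noise_dim)                 # random noise
labels = torch.randint(0, num_classes, (batch_size,))  # class indices 0..9
y = F.one_hot(labels, num_classes).float()             # one-hot condition, shape (64, 10)

# Generator input: noise and condition side by side along the feature dimension
g_input = torch.cat((z, y), dim=1)

# Discriminator input: a flattened image and the same condition
images = torch.rand(batch_size, image_dim)
d_input = torch.cat((images, y), dim=1)

print(g_input.shape)  # torch.Size([64, 110])
print(d_input.shape)  # torch.Size([64, 794])
```

`F.one_hot(labels, num_classes).float()` gives the same vectors as the `torch.eye(num_classes)[labels]` indexing trick used in the example code.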

The general steps for implementing a conditional GAN are:

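Define the network structures of the generator and the discriminator. Both must accept the condition as an extra input. The generator concatenates it with the random noise, and the discriminator concatenates it with the real or generated image. The generator outputs an image. The discriminator outputs the probability that the image is real and consistent with the given condition.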

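Define the loss functions for the generator and the discriminator. A basic cGAN uses only the standard adversarial (binary cross-entropy) loss, and the code below follows this design. The loss is computed on image-condition pairs: the discriminator learns to tell real pairs from generated ones, and the generator learns to make its pairs look real. Because the discriminator always sees the condition, the generator is pushed to produce images that match it. Some variants, such as AC-GAN, add an explicit auxiliary classification loss to enforce the condition more directly.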
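For reference, the conditional objective extends the original GAN value function by conditioning both networks on y. The equations below show how the BCE losses in the code relate to it. They are written per sample; `nn.BCELoss` averages them over the batch. The generator loss is the common non-saturating form, which maximizes log D instead of minimizing log(1 - D).

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{(x, y) \sim p_{\text{data}}}\big[\log D(x \mid y)\big]
  + \mathbb{E}_{z \sim p_z,\; y \sim p_y}\big[\log\big(1 - D(G(z \mid y) \mid y)\big)\big]

\mathrm{BCE}(p, t) = -\big[t \log p + (1 - t)\log(1 - p)\big]

\mathcal{L}_D = \tfrac{1}{2}\Big[\mathrm{BCE}\big(D(x \mid y), 1\big) + \mathrm{BCE}\big(D(G(z \mid y) \mid y), 0\big)\Big]
             = -\tfrac{1}{2}\Big[\log D(x \mid y) + \log\big(1 - D(G(z \mid y) \mid y)\big)\Big]

\mathcal{L}_G = \mathrm{BCE}\big(D(G(z \mid y) \mid y), 1\big) = -\log D(G(z \mid y) \mid y)
```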

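Write the training loop. Each iteration proceeds as follows. Sample a batch of real images and their labels from the dataset, draw random noise, and use the generator to produce fake images for the same labels. Feed the real and fake images to the discriminator together with their conditions, compute its loss, and update its parameters. Then compute the generator's loss on the fake images and update the generator.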
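One training iteration can be condensed into a single function. This is a minimal sketch, not the code used in the example. `train_step`, `TinyG` and `TinyD` are hypothetical names, the tiny networks only stand in for real ones, and a random batch replaces MNIST. One detail is worth noting: the generated images are detached in the discriminator step, so the discriminator's loss does not send gradients into the generator.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def train_step(generator, discriminator, opt_g, opt_d, real_images, labels,
               noise_dim=100, num_classes=10):
    """One cGAN iteration: update D on real/fake pairs, then update G.

    Assumes generator(z, y) returns flat images and discriminator(x, y) returns
    probabilities in (0, 1) (sigmoid output), both conditioned on a one-hot y.
    """
    criterion = nn.BCELoss()
    batch = real_images.size(0)
    x_real = real_images.view(batch, -1)                      # flatten to (batch, 784)
    y = F.one_hot(labels, num_classes).float().to(x_real.device)
    z = torch.randn(batch, noise_dim, device=x_real.device)

    # Discriminator step: real pairs -> 1, generated pairs -> 0.
    # fake.detach() keeps this loss from backpropagating into the generator.
    opt_d.zero_grad()
    fake = generator(z, y)
    d_real = discriminator(x_real, y)
    d_fake = discriminator(fake.detach(), y)
    loss_d = (criterion(d_real, torch.ones_like(d_real))
              + criterion(d_fake, torch.zeros_like(d_fake))) / 2
    loss_d.backward()
    opt_d.step()

    # Generator step: score the same fakes with the updated D and push them towards 1.
    opt_g.zero_grad()
    d_fake = discriminator(fake, y)
    loss_g = criterion(d_fake, torch.ones_like(d_fake))
    loss_g.backward()
    opt_g.step()
    return loss_d.item(), loss_g.item()


# Tiny stand-in networks (hypothetical, for demonstration only)
class TinyG(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(100 + 10, 784), nn.Tanh())

    def forward(self, z, y):
        return self.net(torch.cat((z, y), dim=1))


class TinyD(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(784 + 10, 1), nn.Sigmoid())

    def forward(self, x, y):
        return self.net(torch.cat((x, y), dim=1))


G, D = TinyG(), TinyD()
opt_g = torch.optim.Adam(G.parameters(), lr=2e-4)
opt_d = torch.optim.Adam(D.parameters(), lr=2e-4)
images = torch.rand(8, 1, 28, 28) * 2 - 1   # random batch in [-1, 1] standing in for MNIST
labels = torch.randint(0, 10, (8,))
loss_d, loss_g = train_step(G, D, opt_g, opt_d, images, labels)
print(f"Loss D: {loss_d:.4f}, Loss G: {loss_g:.4f}")
```

The generator step reuses `fake`, which saves a second generator forward pass. Regenerating from the same `z` would produce the same images, because the generator's weights have not changed at that point.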

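Train the generator and the discriminator by repeating the training loop, updating both networks' parameters until a preset number of epochs is reached. GAN losses usually oscillate rather than settle, so loss convergence is an unreliable stopping signal. Periodically inspecting generated samples is more informative.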

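Evaluate the model. After training, use the generator to produce images. Assess both their quality and whether they match the requested conditions, either by visual inspection or with quantitative metrics.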
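One quantitative way to check the condition is to classify generated samples with a separately trained MNIST classifier. You then measure how often the predicted digit matches the requested one. This is a minimal sketch under two assumptions. First, such a classifier exists: the `classifier` argument is hypothetical and takes (N, 1, 28, 28) images in [-1, 1], returning 10 logits. Second, the generator has the `generator(z, y)` interface used in the example code. `label_accuracy` is an illustrative name.

```python
import torch
import torch.nn.functional as F


@torch.no_grad()
def label_accuracy(generator, classifier, num_classes=10, per_class=100,
                   noise_dim=100, device="cpu"):
    """Fraction of generated samples that `classifier` assigns to the requested class.

    `classifier` is a hypothetical, separately trained MNIST classifier that accepts
    (N, 1, 28, 28) images in [-1, 1] and returns (N, num_classes) logits.
    """
    generator.eval()
    classifier.eval()
    # per_class samples for every class: labels = [0]*per_class + [1]*per_class + ...
    labels = torch.arange(num_classes, device=device).repeat_interleave(per_class)
    y = F.one_hot(labels, num_classes).float()
    z = torch.randn(labels.numel(), noise_dim, device=device)
    images = generator(z, y).view(-1, 1, 28, 28)
    preds = classifier(images).argmax(dim=1)
    return (preds == labels).float().mean().item()
```

A value close to 1.0 suggests the generator respects the label. It says nothing about image quality, which metrics such as FID are meant to measure.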

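The following simple example shows how to implement a conditional GAN with PyTorch. It uses MNIST digits as the images and their class labels as the condition: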

In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from torchvision.datasets import MNIST


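def to_img(x):
    # Undo Normalize((0.5,), (0.5,)): map values from [-1, 1] back to [0, 1]
    out = 0.5 * (x + 1)
    out = out.clamp(0, 1)  # clamp() limits every value to the interval [min, max]
    out = out.view(-1, 1, 28, 28)  # view() reshapes flat 784-vectors into (N, 1, 28, 28) images
    return out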

# Generator network: concatenates the noise z with the one-hot condition y
class Generator(nn.Module):
    def __init__(self, input_size, output_size):
        super(Generator, self).__init__()
        self.fc = nn.Linear(input_size, 128)
        self.fc2 = nn.Linear(128, output_size)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

    def forward(self, z, y):
        x = torch.cat((z, y), dim=1)
        x = self.fc(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.tanh(x)
        return x

# Discriminator network: concatenates a flattened image x with the one-hot condition y
class Discriminator(nn.Module):
    def __init__(self, input_size, output_size):
        super(Discriminator, self).__init__()
        self.fc = nn.Linear(input_size, 128)
        self.fc2 = nn.Linear(128, output_size)
        self.relu = nn.LeakyReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, y):
        x = torch.cat((x, y), dim=1)
        x = self.fc(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.sigmoid(x)
        return x

# Conditional GAN wrapper: runs the generator, then scores its output with the discriminator
class cGAN(nn.Module):
    def __init__(self, generator, discriminator):
        super(cGAN, self).__init__()
        self.generator = generator
        self.discriminator = discriminator

    def forward(self, z, y):
        fake_images = self.generator(z, y)
        fake_outputs = self.discriminator(fake_images, y)
        return fake_images, fake_outputs

# Hyperparameters
input_size = 100  # dimension of the random noise z
output_size = 784  # image dimension (28x28, flattened)
num_classes = 10  # number of classes
batch_size = 64
lr = 0.0002
num_epochs = 200
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

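# Load the MNIST dataset; Normalize((0.5,), (0.5,)) maps pixels from [0, 1] to [-1, 1], matching the generator's Tanh output
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = MNIST(root='/data/mwj/data', train=True, download=True, transform=transform)  # set root to a local data directory
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)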
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Initialize the generator and discriminator (each input includes the 10-dim one-hot condition)
generator = Generator(input_size + num_classes, output_size)
discriminator = Discriminator(output_size + num_classes, 1)

# Initialize the conditional GAN model (moving it to the device also moves both submodules)
cgan = cGAN(generator, discriminator).to(device)

# Define the loss function and the optimizers
criterion = nn.BCELoss()
optimizer_g = optim.Adam(generator.parameters(), lr=lr)
optimizer_d = optim.Adam(discriminator.parameters(), lr=lr)

# Training loop
for epoch in range(num_epochs):
    for batch_idx, (real_images, labels) in enumerate(train_loader):
        real_images = real_images.to(device)
        # Sample random noise and build the condition
        z = torch.randn(real_images.size(0), input_size).to(device)
        y = torch.eye(num_classes)[labels].to(device)  # convert class labels to one-hot vectors
        
        # Train the discriminator
        optimizer_d.zero_grad()
        fake_images = generator(z, y)
        fake_outputs = discriminator(fake_images.detach(), y)  # detach(): don't backpropagate D's loss into G
        real_outputs = discriminator(real_images.view(-1, output_size), y)
        loss_d_real = criterion(real_outputs, torch.ones_like(real_outputs))  # discriminator loss on real images
        loss_d_fake = criterion(fake_outputs, torch.zeros_like(fake_outputs))  # discriminator loss on generated images
        loss_d = (loss_d_real + loss_d_fake) / 2
        loss_d.backward()
        optimizer_d.step()
        
        # Train the generator
        optimizer_g.zero_grad()
        fake_images, fake_outputs = cgan(z, y)
        loss_g = criterion(fake_outputs, torch.ones_like(fake_outputs))  # generator loss: make D output "real"
        loss_g.backward()
        optimizer_g.step()
        
        # Print training progress
        if batch_idx % 100 == 0:
            print(f"Epoch [{epoch}/{num_epochs}], Batch [{batch_idx}/{len(train_loader)}], "
                  f"Loss D: {loss_d.item():.4f}, Loss G: {loss_g.item():.4f}")
        # Save real images once, and generated images at the end of every epoch
        if epoch == 0 and batch_idx == len(train_loader) - 1:
            save_image(to_img(real_images.detach().cpu()), f'{epoch}_real_images.png')
        if batch_idx == len(train_loader) - 1:
            save_image(to_img(fake_images.detach().cpu()), f'{epoch}_fake_images.png')


Epoch [0/200], Batch [0/938], Loss D: 0.6079, Loss G: 0.6753
Epoch [0/200], Batch [100/938], Loss D: 0.3967, Loss G: 0.7398
Epoch [0/200], Batch [200/938], Loss D: 0.1798, Loss G: 1.5792
Epoch [0/200], Batch [300/938], Loss D: 0.2201, Loss G: 1.3184
Epoch [0/200], Batch [400/938], Loss D: 0.3957, Loss G: 0.8844
Epoch [0/200], Batch [500/938], Loss D: 0.3860, Loss G: 0.9290
Epoch [0/200], Batch [600/938], Loss D: 0.4031, Loss G: 1.0419
Epoch [0/200], Batch [700/938], Loss D: 0.5204, Loss G: 0.8032
Epoch [0/200], Batch [800/938], Loss D: 0.4311, Loss G: 1.0606
Epoch [0/200], Batch [900/938], Loss D: 0.2132, Loss G: 1.8064
Epoch [1/200], Batch [0/938], Loss D: 0.4192, Loss G: 1.1212
Epoch [1/200], Batch [100/938], Loss D: 0.1690, Loss G: 1.7858
Epoch [1/200], Batch [200/938], Loss D: 0.4165, Loss G: 1.1407
Epoch [1/200], Batch [300/938], Loss D: 0.4849, Loss G: 0.9997
Epoch [1/200], Batch [400/938], Loss D: 0.4784, Loss G: 1.0441
Epoch [1/200], Batch [500/938], Loss D: 0.3432, Loss G: 1.3
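The inverse used when saving images follows from the normalization step. `Normalize((0.5,), (0.5,))` computes (x - 0.5) / 0.5 = 2x - 1, mapping [0, 1] pixels to [-1, 1], which is the generator's Tanh range. Its inverse is x = 0.5 * (x' + 1). A quick numerical check follows; the normalization is written out by hand, so only plain PyTorch is needed.

```python
import torch

x = torch.rand(4, 1, 28, 28)       # pixels in [0, 1], as produced by ToTensor()
normalized = (x - 0.5) / 0.5       # what Normalize((0.5,), (0.5,)) computes: 2x - 1
restored = 0.5 * (normalized + 1)  # the correct inverse
shifted = 0.5 * (normalized + 0.5) # an off-by-constant inverse, for comparison

print(torch.allclose(restored, x, atol=1e-6))  # True
# shifted equals x - 0.25: every pixel comes out 0.25 too dark before clamping
```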

In [17]:
generator.eval()
batch_s = 32
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
generator.to(device)
z = torch.randn(batch_s, input_size).to(device)
# Tensor of shape [batch_s] with every element equal to 6: generate batch_s samples of the digit 6
labels = torch.full((batch_s,), 6, dtype=torch.long)
y = torch.eye(num_classes)[labels].to(device)  # convert class labels to one-hot vectors, shape [batch_s, 10]
print(y)
with torch.no_grad():  # inference only, no gradients needed
    fake_images = generator(z, y)
fake_images = to_img(fake_images)
save_image(fake_images, 'example_images.png')

tensor([[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],


In [11]:
# Continue training for another num_epochs epochs
for epoch in range(num_epochs, 2*num_epochs):
    for batch_idx, (real_images, labels) in enumerate(train_loader):
        real_images = real_images.to(device)
        # Sample random noise and build the condition
        z = torch.randn(real_images.size(0), input_size).to(device)
        y = torch.eye(num_classes)[labels].to(device)  # convert class labels to one-hot vectors
        
        # Train the discriminator
        optimizer_d.zero_grad()
        fake_images = generator(z, y)
        fake_outputs = discriminator(fake_images.detach(), y)  # detach(): don't backpropagate D's loss into G
        real_outputs = discriminator(real_images.view(-1, output_size), y)
        loss_d_real = criterion(real_outputs, torch.ones_like(real_outputs))  # discriminator loss on real images
        loss_d_fake = criterion(fake_outputs, torch.zeros_like(fake_outputs))  # discriminator loss on generated images
        loss_d = (loss_d_real + loss_d_fake) / 2
        loss_d.backward()
        optimizer_d.step()
        
        # Train the generator
        optimizer_g.zero_grad()
        fake_images, fake_outputs = cgan(z, y)
        loss_g = criterion(fake_outputs, torch.ones_like(fake_outputs))  # generator loss: make D output "real"
        loss_g.backward()
        optimizer_g.step()
        
        # Print training progress
        if batch_idx % 100 == 0:
            print(f"Epoch [{epoch}/{2*num_epochs}], Batch [{batch_idx}/{len(train_loader)}], "
                  f"Loss D: {loss_d.item():.4f}, Loss G: {loss_g.item():.4f}")
        # Save generated images at the end of every epoch
        if batch_idx == len(train_loader) - 1:
            save_image(to_img(fake_images.detach().cpu()), f'{epoch}_fake_images.png')

Epoch [200/200], Batch [0/938], Loss D: 0.7021, Loss G: 0.7621
Epoch [200/200], Batch [100/938], Loss D: 0.6511, Loss G: 0.8138
Epoch [200/200], Batch [200/938], Loss D: 0.6636, Loss G: 0.8971
Epoch [200/200], Batch [300/938], Loss D: 0.6571, Loss G: 0.9199
Epoch [200/200], Batch [400/938], Loss D: 0.6867, Loss G: 0.7655
Epoch [200/200], Batch [500/938], Loss D: 0.7100, Loss G: 0.6794
Epoch [200/200], Batch [600/938], Loss D: 0.6513, Loss G: 0.8016
Epoch [200/200], Batch [700/938], Loss D: 0.6737, Loss G: 0.8111
Epoch [200/200], Batch [800/938], Loss D: 0.7063, Loss G: 0.7915
Epoch [200/200], Batch [900/938], Loss D: 0.6177, Loss G: 0.8832
Epoch [201/200], Batch [0/938], Loss D: 0.6869, Loss G: 0.7334
Epoch [201/200], Batch [100/938], Loss D: 0.7220, Loss G: 0.7356
Epoch [201/200], Batch [200/938], Loss D: 0.6520, Loss G: 0.7828
Epoch [201/200], Batch [300/938], Loss D: 0.6572, Loss G: 0.7868
Epoch [201/200], Batch [400/938], Loss D: 0.6880, Loss G: 0.7401
Epoch [201/200], Batch [500/9

In [12]:
torch.save(generator.state_dict(), 'generator.pth')
torch.save(discriminator.state_dict(), 'discriminator.pth')
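To reuse the saved weights later, for example in a fresh session, rebuild the same architecture and load the state dict into it. This is a minimal sketch. `load_generator` is a hypothetical helper, and the Generator class is repeated here so the snippet runs on its own. Its layer names and sizes must match the ones used for training, otherwise `load_state_dict` will fail.

```python
import torch
import torch.nn as nn


class Generator(nn.Module):
    # Same architecture as in the training cell
    def __init__(self, input_size, output_size):
        super().__init__()
        self.fc = nn.Linear(input_size, 128)
        self.fc2 = nn.Linear(128, output_size)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

    def forward(self, z, y):
        x = torch.cat((z, y), dim=1)
        return self.tanh(self.fc2(self.relu(self.fc(x))))


def load_generator(path, input_size=100, num_classes=10, output_size=784, device="cpu"):
    """Rebuild the generator and load weights saved with torch.save(model.state_dict(), path)."""
    model = Generator(input_size + num_classes, output_size)
    state = torch.load(path, map_location=device)  # map_location lets GPU-saved weights load on CPU
    model.load_state_dict(state)
    return model.to(device).eval()

# e.g. generator = load_generator('generator.pth', device=device)
```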

In [14]:
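import cv2
import os

def images_to_video(image_folder, video_name, fps):
    # Collect the per-epoch frames ('<epoch>_fake_images.png') and sort them by epoch number;
    # a plain string sort would place '10_...' before '2_...'
    images = [img for img in os.listdir(image_folder) if img.endswith(".png") and "fake_images" in img]
    images.sort(key=lambda name: int(name.split("_", 1)[0]))
    if not images:
        raise FileNotFoundError(f"no *fake_images*.png frames found in {image_folder!r}")
    frame = cv2.imread(os.path.join(image_folder, images[0]))
    height, width, _ = frame.shape

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    video = cv2.VideoWriter(video_name, fourcc, fps, (width, height))

    for image in images:
        video.write(cv2.imread(os.path.join(image_folder, image)))

    video.release()

# Call the function: frames from the current directory, 30 frames per second
images_to_video('.', 'output_video.mp4', 30)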
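The frame files are named `<epoch>_fake_images.png`, so sorting them as strings does not give training order. For example, '10_...' sorts before '1_...' because '0' comes before '_'. The key below sorts by the integer prefix instead. `epoch_key` is an illustrative name.

```python
def epoch_key(filename):
    """Sort key for names like '12_fake_images.png': the integer before the first underscore."""
    return int(filename.split("_", 1)[0])

names = ["10_fake_images.png", "2_fake_images.png", "0_fake_images.png", "1_fake_images.png"]
print(sorted(names))                 # ['0_fake_images.png', '10_fake_images.png', '1_fake_images.png', '2_fake_images.png']
print(sorted(names, key=epoch_key))  # ['0_fake_images.png', '1_fake_images.png', '2_fake_images.png', '10_fake_images.png']
```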
