In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  _torch_pytree._register_pytree_node(


# 生成器

In [2]:
class Generator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Generator, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(True),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(True),
            nn.Linear(hidden_size, output_size),
            nn.Tanh()  # Tanh激活函数使得输出范围在-1到1之间
        )

    def forward(self, x):
        return self.net(x)


# 判别器

In [3]:
class Discriminator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(hidden_size, hidden_size),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(hidden_size, output_size),
            nn.Sigmoid()  # Sigmoid激活函数将输出压缩到0到1之间
        )

    def forward(self, x):
        return self.net(x)


In [4]:
# 超参数设置
latent_size = 64  # 生成器输入的维度
hidden_size = 256  # 隐藏层维度
image_size = 784  # 28x28 图像展平后的维度
output_size = 1  # 判别器输出的维度

# 创建模型
G = Generator(latent_size, hidden_size, image_size).to(device)
D = Discriminator(image_size, hidden_size, output_size).to(device)

# 优化器
lr = 0.0002
G_optimizer = optim.Adam(G.parameters(), lr=lr)
D_optimizer = optim.Adam(D.parameters(), lr=lr)

# 损失函数
criterion = nn.BCELoss()


In [5]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

mnist_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
dataloader = DataLoader(mnist_data, batch_size=64, shuffle=True)


In [6]:
num_epochs = 50
real_label = 1
fake_label = 0



for epoch in range(num_epochs):
    for i, (images, _) in enumerate(dataloader):
        # 训练判别器
        D.zero_grad()
        # 真实图像
        real_images = images.view(-1, image_size).to(device)
        real_labels = torch.full((images.size(0),), real_label, dtype=torch.float, device=device)
        real_output = D(real_images)
        # print(real_labels.shape)
        # print(real_output.shape)
        D_loss_real = criterion(real_output, real_labels.view(-1, 1))
        D_loss_real.backward()
        D_x = real_output.mean().item()

        # 生成假图像
        noise = torch.randn(images.size(0), latent_size, device=device)
        fake_images = G(noise)
        fake_labels = torch.full((images.size(0),), fake_label, dtype=torch.float, device=device)
        fake_output = D(fake_images.detach())
        D_loss_fake = criterion(fake_output, fake_labels.view(-1, 1))
        D_loss_fake.backward()
        D_G_z1 = fake_output.mean().item()

        # 更新判别器D
        D_optimizer.step()

        # 训练生成器
        G.zero_grad()
        label = torch.full((images.size(0),), real_label, dtype=torch.float, device=device)
        output = D(fake_images)
        G_loss = criterion(output, label.view(-1, 1))
        G_loss.backward()
        D_G_z2 = output.mean().item()

        # 更新生成器G
        G_optimizer.step()
        
        if i % 1 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(dataloader)}], D_loss: {D_loss_real.item()+D_loss_fake.item():.4f}, G_loss: {G_loss.item():.4f}, D(x): {D_x:.4f}, D(G(z)): {D_G_z1:.4f}/{D_G_z2:.4f}')


Epoch [1/50], Step [1/938], D_loss: 1.4805, G_loss: 0.7207, D(x): 0.4437, D(G(z)): 0.4867/0.4864
Epoch [1/50], Step [2/938], D_loss: 1.3308, G_loss: 0.7146, D(x): 0.5184, D(G(z)): 0.4900/0.4894
Epoch [1/50], Step [3/938], D_loss: 1.2111, G_loss: 0.7087, D(x): 0.5876, D(G(z)): 0.4930/0.4923
Epoch [1/50], Step [4/938], D_loss: 1.1161, G_loss: 0.7018, D(x): 0.6505, D(G(z)): 0.4963/0.4957
Epoch [1/50], Step [5/938], D_loss: 1.0337, G_loss: 0.6946, D(x): 0.7113, D(G(z)): 0.4998/0.4993
Epoch [1/50], Step [6/938], D_loss: 0.9640, G_loss: 0.6866, D(x): 0.7686, D(G(z)): 0.5037/0.5033
Epoch [1/50], Step [7/938], D_loss: 0.9144, G_loss: 0.6790, D(x): 0.8139, D(G(z)): 0.5075/0.5071
Epoch [1/50], Step [8/938], D_loss: 0.8747, G_loss: 0.6675, D(x): 0.8570, D(G(z)): 0.5133/0.5130
Epoch [1/50], Step [9/938], D_loss: 0.8396, G_loss: 0.6595, D(x): 0.8955, D(G(z)): 0.5176/0.5171
Epoch [1/50], Step [10/938], D_loss: 0.8262, G_loss: 0.6486, D(x): 0.9187, D(G(z)): 0.5234/0.5228
Epoch [1/50], Step [11/938], 

KeyboardInterrupt: 