In [None]:
# Mount Google Drive at /content/gdrive so the dataset archive stored on Drive
# (read by the unzip cell below) is accessible from this Colab runtime.
from google.colab import drive

drive.mount("/content/gdrive")

Mounted at /content/gdrive


In [None]:
# Extract the image archive into /content; it is expected to produce /content/img_align_celeba.
# -n: never overwrite existing files, so re-running the notebook does not block on
#     unzip's interactive "replace ...?" prompt.
# -d: explicit destination instead of relying on the current working directory.
!unzip -qq -n "/content/gdrive/MyDrive/Teaching/CC/ИИКб-21/13/Dataset/data.zip" -d /content

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, utils
from torchvision.datasets import ImageFolder
from PIL import Image
import os
import matplotlib.pyplot as plt

# Select the compute device: CUDA when a GPU is visible to torch, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Dataset loading and preprocessing
class CustomDataset(Dataset):
    """Flat-directory image dataset that yields transformed images without labels.

    Every regular file directly inside ``root_dir`` whose name ends with one of
    ``extensions`` (compared case-insensitively) is treated as an image.
    Subdirectories are not scanned. Paths are sorted by file name so that a given
    index always refers to the same file, independent of filesystem listing order.

    Args:
        root_dir: Directory containing the image files.
        transform: Optional callable applied to each RGB PIL image.
        extensions: Accepted file-name suffixes; defaults to ``('.jpg', '.png')``.
    """

    def __init__(self, root_dir, transform=None, extensions=('.jpg', '.png')):
        self.root_dir = root_dir
        self.transform = transform
        # Lower-case once so `.JPG`, `.Png`, ... are not silently skipped.
        suffixes = tuple(ext.lower() for ext in extensions)
        # scandir gives the file/dir distinction cheaply; sorting makes the
        # index -> file mapping deterministic (os.listdir order is arbitrary).
        with os.scandir(root_dir) as entries:
            names = sorted(
                entry.name
                for entry in entries
                if entry.is_file() and entry.name.lower().endswith(suffixes)
            )
        self.image_paths = [os.path.join(root_dir, name) for name in names]

    def __len__(self):
        """Number of image files found in ``root_dir``."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load image ``idx``, convert it to RGB and apply ``transform`` if given."""
        img_path = self.image_paths[idx]
        # The context manager closes the file handle deterministically; convert()
        # returns a new, fully loaded image that stays valid after the close.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image

# Data location: the unzip step is expected to create this folder.
dataset_path = "/content/img_align_celeba"

# Preprocessing: resize the shorter side to 64, centre-crop to 64x64, convert to a
# tensor in [0, 1], then map each RGB channel to [-1, 1] ((x - 0.5) / 0.5) so real
# images share the value range of the generator's Tanh output.
transform = transforms.Compose([
    transforms.Resize(64),
    transforms.CenterCrop(64),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Data loading
dataset = CustomDataset(root_dir=dataset_path, transform=transform)
# Fail fast with an actionable message; an empty dataset would otherwise surface
# as an opaque sampler ValueError from DataLoader(shuffle=True).
assert len(dataset) > 0, f"No images found in {dataset_path!r}; check the unzip step."
dataloader = DataLoader(dataset, batch_size=128, shuffle=True)


In [None]:
class Generator(nn.Module):
    """DCGAN-style generator mapping latent vectors to 3x64x64 images in [-1, 1].

    A linear layer projects ``z`` of shape ``(batch, latent_dim)`` to a 128x8x8
    feature map. Three stride-2 transposed convolutions then upsample it
    8 -> 16 -> 32 -> 64 while narrowing channels 128 -> 64 -> 32 -> 3.

    Args:
        latent_dim: Length of the input noise vector.
    """

    def __init__(self, latent_dim):
        super().__init__()
        # Projection from the latent vector to a 128-channel 8x8 map.
        layers = [
            nn.Linear(latent_dim, 8 * 8 * 128),
            nn.ReLU(True),
            nn.Unflatten(1, (128, 8, 8)),
        ]
        # Two hidden upsampling stages: transposed conv -> BatchNorm -> ReLU.
        for in_ch, out_ch in ((128, 64), (64, 32)):
            layers.extend([
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(True),
            ])
        # Final upsampling to 3 channels; Tanh bounds the output to [-1, 1].
        layers.extend([
            nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        ])
        # Same module order/indices as a literal nn.Sequential(...), so the
        # state_dict keys ("model.<i>.*") are unchanged.
        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Generate a ``(batch, 3, 64, 64)`` image batch from noise ``z``."""
        return self.model(z)

In [None]:
class Discriminator(nn.Module):
    """DCGAN-style discriminator scoring 3x64x64 images with a value in [0, 1].

    Three stride-2 convolutions shrink the spatial size 64 -> 32 -> 16 -> 8 while
    widening channels 3 -> 32 -> 64 -> 128. A linear layer on the flattened
    128x8x8 map then gives one score per image, squashed by a Sigmoid. Because the
    linear input size is hard-coded to ``8 * 8 * 128``, the network is sized for
    64x64 inputs.
    """

    def __init__(self):
        super().__init__()
        # First downsampling stage has no BatchNorm.
        layers = [
            nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Two further stages: conv -> BatchNorm -> LeakyReLU.
        for in_ch, out_ch in ((32, 64), (64, 128)):
            layers.extend([
                nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2, inplace=True),
            ])
        # Classifier head; Sigmoid keeps the output in [0, 1] for BCELoss.
        layers.extend([
            nn.Flatten(),
            nn.Linear(8 * 8 * 128, 1),
            nn.Sigmoid(),
        ])
        # Module order/indices match a literal nn.Sequential(...), keeping state_dict keys.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return a ``(batch, 1)`` tensor of scores in [0, 1] for images ``x``."""
        return self.model(x)


In [None]:
# Model and optimizer initialisation.
# Seed torch's global RNG (torch.manual_seed also seeds CUDA) so weight
# initialisation and later torch random draws are reproducible on a fresh kernel.
# GPU runs may still differ slightly because some cuDNN kernels are nondeterministic.
SEED = 42
torch.manual_seed(SEED)

latent_dim = 100  # length of the noise vector z fed to the generator
generator = Generator(latent_dim).to(device)
discriminator = Discriminator().to(device)

# The discriminator ends in a Sigmoid, so plain binary cross-entropy on probabilities.
criterion = nn.BCELoss()

# Adam hyperparameters shared by both networks.
LR = 0.0002
BETAS = (0.5, 0.999)
optimizer_G = optim.Adam(generator.parameters(), lr=LR, betas=BETAS)
optimizer_D = optim.Adam(discriminator.parameters(), lr=LR, betas=BETAS)


In [None]:

# Training loop: for every batch, take one discriminator step and then one
# generator step. At the end of each epoch, save a preview grid generated from a
# fixed noise batch.
epochs = 50
# Fixed latent batch reused every epoch so the saved grids are comparable over time.
fixed_noise = torch.randn(16, latent_dim, device=device)

os.makedirs('generated_images', exist_ok=True)

for epoch in range(epochs):
    for i, real_images in enumerate(dataloader):
        real_images = real_images.to(device)
        # The last batch of an epoch may be smaller than the loader's batch size.
        batch_size = real_images.size(0)
        # Targets are created directly on the device (no per-step host->device copy).
        real_targets = torch.ones((batch_size, 1), device=device)
        fake_targets = torch.zeros((batch_size, 1), device=device)

        # --- Discriminator step: push D(real) towards 1 and D(G(z)) towards 0 ---
        optimizer_D.zero_grad()
        d_loss_real = criterion(discriminator(real_images), real_targets)

        z = torch.randn(batch_size, latent_dim, device=device)
        fake_images = generator(z)
        # detach(): the discriminator loss must not backpropagate into the generator.
        d_loss_fake = criterion(discriminator(fake_images.detach()), fake_targets)

        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        optimizer_D.step()

        # --- Generator step: make the (just updated) discriminator label fakes as real ---
        optimizer_G.zero_grad()
        g_loss = criterion(discriminator(fake_images), real_targets)
        g_loss.backward()
        optimizer_G.step()

        if i % 100 == 0:
            print(f"Epoch [{epoch+1}/{epochs}] Batch {i}/{len(dataloader)} "
                  f"Loss D: {d_loss.item():.4f}, Loss G: {g_loss.item():.4f}")

    # Save a preview grid for this epoch.
    # NOTE(review): the generator stays in train mode here, so its BatchNorm layers
    # use this 16-sample batch's statistics and also update their running stats.
    # Wrap in generator.eval()/generator.train() if the running stats matter later.
    with torch.no_grad():
        preview = generator(fixed_noise).cpu()
        preview = (preview + 1) / 2.0  # map Tanh range [-1, 1] to [0, 1] for saving
        grid = utils.make_grid(preview, nrow=4)
        utils.save_image(grid, f"generated_images/epoch_{epoch+1}.png")

print("Training finished. Generated images are saved in 'generated_images' folder.")


Epoch [1/50] Batch 0/1583 Loss D: 1.3897, Loss G: 0.9324
Epoch [1/50] Batch 100/1583 Loss D: 0.0368, Loss G: 5.5936
Epoch [1/50] Batch 200/1583 Loss D: 0.1660, Loss G: 4.9337
Epoch [1/50] Batch 300/1583 Loss D: 0.5893, Loss G: 1.3278
Epoch [1/50] Batch 400/1583 Loss D: 0.2461, Loss G: 3.4950
Epoch [1/50] Batch 500/1583 Loss D: 0.2640, Loss G: 3.0103
Epoch [1/50] Batch 600/1583 Loss D: 0.2076, Loss G: 2.9829
Epoch [1/50] Batch 700/1583 Loss D: 0.3458, Loss G: 2.8806
Epoch [1/50] Batch 800/1583 Loss D: 0.3739, Loss G: 2.2323
Epoch [1/50] Batch 900/1583 Loss D: 0.3107, Loss G: 3.1705
Epoch [1/50] Batch 1000/1583 Loss D: 0.2178, Loss G: 3.4153
Epoch [1/50] Batch 1100/1583 Loss D: 0.2402, Loss G: 3.3543
Epoch [1/50] Batch 1200/1583 Loss D: 0.8169, Loss G: 1.3823
Epoch [1/50] Batch 1300/1583 Loss D: 0.2914, Loss G: 3.0399
Epoch [1/50] Batch 1400/1583 Loss D: 0.3812, Loss G: 1.9732
Epoch [1/50] Batch 1500/1583 Loss D: 0.2559, Loss G: 2.8487
Epoch [2/50] Batch 0/1583 Loss D: 0.3616, Loss G: 2.