<h1 style="text-align:center;font-weight: bold">GAN</h1>
    <h3 style="text-align:left;font-weight: bold">A generative adversarial network is a class of machine learning frameworks designed by Ian Goodfellow and his colleagues in 2014. Two neural networks contest with each other in a game. Given a training set, this technique learns to generate new data with the same statistics as the training set.</h3>


In [None]:
import os
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image
from torchvision.utils import make_grid
from torch.utils.tensorboard import SummaryWriter
from torchsummary import summary
import datetime
import matplotlib.pyplot as plt

In [None]:
!pip install torchsummary
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
torch.manual_seed(1)

In [None]:
# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Training hyper-parameters: mini-batch size, base (generator) learning rate, number of epochs.
batch_size, learning_rate, num_epochs = 16, 0.0002, 1500

In [None]:
import shutil
from PIL import Image, ImageOps


def process_images_in_directory(
    source_dir,
    target_dir,
    extensions=(".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".tif"),
):
    """Copy every image in ``source_dir`` into ``target_dir`` plus four augmented variants.

    For each file ``name`` in ``source_dir`` whose extension (case-insensitive) is in
    ``extensions``, this writes into ``target_dir``:
      original_<name>, flipped_ud_<name> (top/bottom flip), flipped_lr_<name> (mirror),
      rotated_90_<name> and rotated_270_<name> (counter-clockwise rotations).
    Existing files with the same names are overwritten. Non-image files are ignored.

    :param source_dir: directory to read images from (not recursed into).
    :param target_dir: output directory; created if missing.
    :param extensions: accepted lower-case file extensions.
    """
    os.makedirs(target_dir, exist_ok=True)
    accepted = tuple(extensions)
    for file_name in os.listdir(source_dir):
        if not file_name.lower().endswith(accepted):
            continue
        file_path = os.path.join(source_dir, file_name)
        # The context manager closes the file handle that Image.open keeps open lazily.
        with Image.open(file_path) as img:
            variants = {
                "original": img,
                "flipped_ud": ImageOps.flip(img),
                "flipped_lr": ImageOps.mirror(img),
                # expand=True: without it rotate() keeps the original canvas size, so a
                # non-square image would be cropped and padded with black corners.
                "rotated_90": img.rotate(90, expand=True),
                "rotated_270": img.rotate(270, expand=True),
            }
            for prefix, variant in variants.items():
                variant.save(os.path.join(target_dir, f"{prefix}_{file_name}"))

In [None]:
import os
import shutil
from PIL import Image, ImageOps

# Source directory of the labelled images and the flat target directory read by ImageFolder.
source_dir = "/kaggle/input/efficientnet-data/my_label_data/0"
target_dir = "./train/data"

os.makedirs(target_dir, exist_ok=True)
# Copy + augment every source directory into the same target directory, in this order.
# NOTE(review): output names are only prefixed by augmentation type, so equally named files
# from different source directories overwrite each other.
for extra_source in (
    source_dir,
    "/kaggle/input/efficientnet-data/efficient_net_data_me/cropped_objects/0",
    "/kaggle/input/efficientnet-data/efficient2/cropped_objects/0",
):
    process_images_in_directory(extra_source, target_dir)

# shutil.copytree('/kaggle/input/efficientnet-data/efficient_net_data_me/cropped_objects/0', f'{target_dir}', dirs_exist_ok=True)
# shutil.copytree('/kaggle/input/efficientnet-data/efficient2/cropped_objects/0', f'{target_dir}', dirs_exist_ok=True)

# Count the image files that ended up in the target directory.
image_count = len(
    [
        name
        for name in os.listdir(target_dir)
        if name.lower().endswith((".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".tif"))
    ]
)
print(f"Total images in the directory: {image_count}")

In [None]:
# Training pipeline: resize to 64x64, convert to a tensor, then normalize every channel with
# mean 0.5 / std 0.5 so pixel values lie in [-1, 1], the range of the generator's Tanh output.
train_transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

# ImageFolder treats each sub-directory of ./train as one class (the images live in ./train/data).
# Originally the root was ../input/efficientnet-data/test/
train_dataset = datasets.ImageFolder(root="./train", transform=train_transform)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
)

# train_transform = transforms.Compose([
#     transforms.Grayscale(num_output_channels=1),  # 转换为灰度图像
#     transforms.Resize((16, 16)),  # 调整为目标分辨率
#     transforms.ToTensor(),
#     transforms.Normalize([0.5], [0.5])  # 对灰度图像进行归一化
# ])
# train_dataset = datasets.ImageFolder(root='./train', transform=train_transform) #原始为../input/efficientnet-data/test/
# train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

<h1 style="text-align:center;font-weight: bold;">Exploratory Data Analysis</h1>


In [None]:
import torch
import matplotlib.pyplot as plt
from torchvision.utils import make_grid


# Display helper for image batches.
def show_images(images, nrow=8, figsize=(10, 10), mean=0.5, std=0.5):
    """Display a batch of images as a single grid without axis ticks.

    :param images: image batch accepted by ``make_grid`` (presumably (N, C, H, W)); it is
        assumed to be normalized with ``Normalize(mean, std)``, which is inverted here.
    :param nrow: number of images per grid row.
    :param figsize: matplotlib figure size.
    :param mean: normalization mean to undo (default matches the training transform).
    :param std: normalization std to undo (default matches the training transform).
    :return: the matplotlib Axes holding the grid.
    """
    fig, ax = plt.subplots(figsize=figsize)
    ax.set_xticks([])  # hide x-axis ticks
    ax.set_yticks([])  # hide y-axis ticks

    # Undo Normalize(mean, std): the tensors are in [-1, 1], but imshow expects float RGB in
    # [0, 1] and would clip every negative value, producing a dark, distorted preview.
    images = (images.detach() * std + mean).clamp(0, 1)

    # Build the grid and move channels last: (C, H, W) -> (H, W, C) for imshow.
    grid_img = make_grid(images, nrow=nrow).permute(1, 2, 0).cpu().numpy()
    ax.imshow(grid_img)
    return ax


# Preview helper for data loaders.
def show_batch(dl, n_images=64, nrow=8):
    """Show at most ``n_images`` images from the first batch that ``dl`` yields; labels are ignored.

    Does nothing if ``dl`` yields no batches.
    """
    first_batch = next(iter(dl), None)
    if first_batch is None:
        return
    images, _labels = first_batch
    show_images(images[:n_images], nrow=nrow)


# Preview one training batch. train_loader batches hold batch_size images, so fewer than
# n_images may be shown.
show_batch(train_loader, n_images=64, nrow=8)

In [None]:
latent_dim = 128  # length of the noise vector z fed to the generator
image_shape = (3, 64, 64)  # (channels, height, width) of real and generated images
image_dim = int(np.prod(image_shape))  # number of values in one flattened image

<h2 style="text-align:center;font-weight: bold;">Initializing Weights</h2>


In [None]:
def weights_init(m):
    """DCGAN-style initializer meant for ``model.apply(weights_init)``.

    - Modules whose class name contains "Conv" (Conv2d, ConvTranspose2d, ...): weight ~ N(0, 0.02).
      Conv biases are left untouched.
    - Modules whose class name contains "BatchNorm": weight ~ N(1, 0.02), bias = 0.
    All other modules are left unchanged.
    """
    layer_type = type(m).__name__
    if "Conv" in layer_type:
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
        return
    if "BatchNorm" in layer_type:
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        torch.nn.init.zeros_(m.bias)

<h2 style="text-align:center;font-weight: bold;">Generator Network</h2>


In [None]:
import torchvision.models as models
import torch.nn.functional as F


# class Generator(nn.Module):
#     def __init__(self):
#         super(Generator, self).__init__()
#         self.main = nn.Sequential(
#             nn.ConvTranspose2d(latent_dim, 64 * 8, 4, 1, 0, bias=False),
#             nn.BatchNorm2d(64 * 8),
#             nn.ReLU(True),
#             nn.ConvTranspose2d(64 * 8, 64 * 4, 4, 2, 1, bias=False),
#             nn.BatchNorm2d(64 * 4),
#             nn.ReLU(True),
#             nn.ConvTranspose2d(64 * 4, 64 * 2, 4, 2, 1, bias=False),
#             nn.BatchNorm2d(64 * 2),
#             nn.ReLU(True),
#             nn.ConvTranspose2d(64 * 2, 64, 4, 2, 1, bias=False),
#             nn.BatchNorm2d(64),
#             nn.ReLU(True),
#             nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False),
#             nn.Tanh()
#         )


#     def forward(self, input):
#         output = self.main(input)
#         return output
class ResidualBlock(nn.Module):
    """Upsampling residual block: doubles height and width and maps in_channels -> out_channels.

    Main path: Upsample(x2, bilinear) -> 3x3 conv -> BN -> ReLU -> 3x3 conv -> BN
    Shortcut:  Upsample(x2, bilinear) -> 1x1 conv -> BN
    Output:    ReLU(main + shortcut)
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        def upsample():
            return nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)

        def conv(c_in, c_out, k):
            # stride 1 with padding k // 2 keeps the spatial size; no bias since BatchNorm follows.
            return nn.Conv2d(
                c_in, c_out, kernel_size=k, stride=1, padding=k // 2, bias=False
            )

        # Attribute names and layer order are kept so state_dict keys stay compatible.
        self.block = nn.Sequential(
            upsample(),
            conv(in_channels, out_channels, 3),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            conv(out_channels, out_channels, 3),
            nn.BatchNorm2d(out_channels),
        )
        self.shortcut = nn.Sequential(
            upsample(),
            conv(in_channels, out_channels, 1),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, x):
        skip = self.shortcut(x)
        return F.relu(self.block(x) + skip, inplace=True)


class Generator(nn.Module):
    """ResNet-based generator producing 64x64 RGB images in [-1, 1] (Tanh output).

    Input z has shape (N, latent_dim, 1, 1). Spatial path:
    1x1 -> 4x4 (transposed conv) -> 8x8 -> 16x16 -> 32x32 (residual blocks) -> 64x64 (transposed conv).
    """

    def __init__(self, latent_dim=128):
        super().__init__()
        # Project the noise vector to a 512 x 4 x 4 feature map.
        self.initial = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, 512, kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
        )
        # Each block halves the channel count and doubles the resolution: 4 -> 8 -> 16 -> 32.
        widths = (512, 256, 128, 64)
        self.res_blocks = nn.Sequential(
            *[ResidualBlock(c_in, c_out) for c_in, c_out in zip(widths, widths[1:])]
        )
        # 32x32 -> 64x64, down to 3 output channels.
        self.final_conv = nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1, bias=False)
        self.final_activation = nn.Tanh()

    def forward(self, z):
        features = self.res_blocks(self.initial(z))
        return self.final_activation(self.final_conv(features))


# class Generator(nn.Module):
#     def __init__(self, latent_dim):
#         super(Generator, self).__init__()
#         self.main = nn.Sequential(
#             # 输入: latent_dim x 1 x 1 -> 64 * 8 x 4 x 4
#             nn.ConvTranspose2d(latent_dim, 64 * 8, 4, 1, 0, bias=False),
#             nn.BatchNorm2d(64 * 8),
#             nn.ReLU(True),
#             # 64 * 8 x 4 x 4 -> 64 * 4 x 8 x 8
#             nn.ConvTranspose2d(64 * 8, 64 * 4, 4, 2, 1, bias=False),
#             nn.BatchNorm2d(64 * 4),
#             nn.ReLU(True),
#             # 64 * 4 x 8 x 8 -> 64 * 2 x 16 x 16
#             nn.ConvTranspose2d(64 * 4, 1, 4, 2, 1, bias=False),
#             nn.Tanh()  # 将输出归一化到 [-1, 1]
#         )

#     def forward(self, input):
#         return self.main(input)

In [None]:
# Build the generator on the target device and apply the DCGAN-style weight initialization
# (nn.Module.to and nn.Module.apply both return the module itself).
generator = Generator().to(device).apply(weights_init)
print(generator)

In [None]:
# torchsummary's summary() defaults to device="cuda"; pass the actual device type so this also runs on CPU.
summary(generator, (latent_dim, 1, 1), device=device.type)

<h2 style="text-align:center;font-weight: bold;">Discriminator Network</h2>


In [None]:
from torch.nn.utils import spectral_norm


class Discriminator(nn.Module):
    """Critic mapping images of shape (N, 3, 64, 64) to one unbounded score per image, shape (N, 1).

    Four spectrally-normalized 4x4 stride-2 convolutions (64 -> 32 -> 16 -> 8 -> 4 pixels), each
    followed by LeakyReLU(0.2), then a plain 4x4 convolution to a single channel and a flatten.
    There is no final activation, so outputs are raw scores/logits.
    """

    def __init__(self):
        super().__init__()
        widths = [3, 64, 128, 256, 512]
        layers = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers.append(spectral_norm(nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False)))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        layers += [nn.Conv2d(512, 1, 4, 1, 0, bias=False), nn.Flatten()]
        # Kept as a single Sequential named `main` so state_dict keys stay compatible.
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        return self.main(input)


# class Discriminator(nn.Module):
#     def __init__(self):
#         super(Discriminator, self).__init__()
#         self.main = nn.Sequential(
#             # 输入: 16x16 -> 8x8
#             nn.Conv2d(1, 64, 4, 2, 1, bias=False),
#             nn.BatchNorm2d(64),
#             nn.LeakyReLU(0.2, inplace=True),

#             # 8x8 -> 4x4
#             nn.Conv2d(64, 128, 4, 2, 1, bias=False),
#             nn.BatchNorm2d(128),
#             nn.LeakyReLU(0.2, inplace=True),

#             # 4x4 -> 2x2
#             nn.Conv2d(128, 256, 4, 2, 1, bias=False),
#             nn.BatchNorm2d(256),
#             nn.LeakyReLU(0.2, inplace=True),

#             # 2x2 -> 1x1
#             nn.Conv2d(256, 512, 2, 1, 0, bias=False),  # 修改 kernel_size 为 2
#             nn.BatchNorm2d(512),
#             nn.LeakyReLU(0.2, inplace=True),

#             # 输出标量
#             nn.Conv2d(512, 1, 1, 1, 0, bias=False),
#             nn.Sigmoid()
#         )

#     def forward(self, input):
#         return self.main(input).view(-1, 1)

In [None]:
# Build the critic on the target device and apply the DCGAN-style weight initialization
# (nn.Module.to and nn.Module.apply both return the module itself).
discriminator = Discriminator().to(device).apply(weights_init)
print(discriminator)

In [None]:
# torchsummary's summary() defaults to device="cuda"; pass the actual device type so this also runs on CPU.
summary(discriminator, (3, 64, 64), device=device.type)

In [None]:
# The Discriminator ends in a bare conv + Flatten (no Sigmoid), so its outputs are logits.
# BCEWithLogitsLoss applies the sigmoid internally (numerically stable); plain BCELoss would
# reject or mis-score values outside [0, 1].
adversarial_loss = nn.BCEWithLogitsLoss()

In [None]:
def generator_loss(fake_output, label):
    """BCE generator loss: score the discriminator's output on generated images against ``label``.

    ``label`` is normally all ones, i.e. the generator wants its images judged real.
    NOTE: a later cell redefines ``generator_loss`` (hinge form), which shadows this one.
    """
    return adversarial_loss(fake_output, label)

## The generator_loss function takes two parameters:

- fake_output: the discriminator's predictions on generator-produced images.
- label: the ground-truth labels, all set to 1. The generator's goal is to fool the discriminator into classifying its images as real, so its targets are ones.


In [None]:
    def discriminator_loss(output, label):
        disc_loss = adversarial_loss(output, label)
        return disc_loss

In [None]:
import torch.nn.functional as F  # 导入 F 模块


def discriminator_loss(real_output, fake_output):
    """Hinge loss for the discriminator: mean(relu(1 - D(real))) + mean(relu(1 + D(fake)))."""
    loss_on_real = F.relu(1.0 - real_output).mean()
    loss_on_fake = F.relu(1.0 + fake_output).mean()
    return loss_on_real + loss_on_fake


def generator_loss(fake_output, label=None):
    """Hinge/WGAN generator loss: the negated mean critic score on fakes. ``label`` is ignored."""
    return fake_output.mean().neg()

In [None]:
def gradient_penalty(critic, real_data, fake_data, device):
    """WGAN-GP gradient penalty, used to push the critic towards being 1-Lipschitz.

    Draws one random mixing weight per sample, builds x_hat = a * real + (1 - a) * fake,
    and returns mean over the batch of (||grad of critic(x_hat) w.r.t. x_hat||_2 - 1) ** 2.

    :param critic: the discriminator/critic model (any callable on a batch).
    :param real_data: batch of real samples; the first dimension is the batch.
    :param fake_data: batch of generated samples, same shape as ``real_data``.
    :param device: device on which to draw the mixing weights.
    :return: scalar penalty tensor, differentiable w.r.t. the critic's parameters.
    """
    n = real_data.size(0)
    # One mixing weight per sample, broadcast over the remaining dimensions.
    mix = torch.rand(n, 1, 1, 1, device=device)
    mixed = (mix * real_data + (1 - mix) * fake_data).requires_grad_(True)

    scores = critic(mixed)
    (grads,) = torch.autograd.grad(
        outputs=scores,
        inputs=mixed,
        grad_outputs=torch.ones_like(scores, device=device),
        create_graph=True,  # the penalty itself must be backpropagated into the critic
        retain_graph=True,
    )

    per_sample_norm = grads.view(n, -1).norm(2, dim=1)
    return (per_sample_norm - 1).pow(2).mean()

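## The discriminator loss compares:

- its predictions on real (original) images against the ground-truth label 1;
- its predictions on fake (generated) images against the ground-truth label 0.

Note: the training loop below does not use these BCE labels. It optimizes the WGAN-GP critic objective instead: the mean critic score on fakes minus the mean score on reals, plus a gradient penalty.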


In [None]:
# A fixed batch of 128 latent vectors, plus target-label constants for the BCE losses.
# NOTE(review): none of these three names is used by the WGAN-GP training loop below.
fixed_noise = torch.randn(128, latent_dim, 1, 1, device=device)
real_label, fake_label = 1, 0

In [None]:
# Adam optimizers for both networks. The discriminator uses 0.08x the generator's learning rate.
G_optimizer = optim.Adam(generator.parameters(), lr=learning_rate, betas=(0.5, 0.999))
D_optimizer = optim.Adam(
    discriminator.parameters(), lr=learning_rate * 0.08, betas=(0.5, 0.999)
)  # Tuning notes: single particle -> 0.3 is enough; separated particles -> 0.04; single particle -> 0.03

<h2 style="text-align:center;font-weight: bold;">Training our network</h2>


In [None]:
import torch
import os
from torchvision.utils import save_image
import torch.nn.functional as F  # 用于 ReLU

# Training state: per-epoch average losses and the smallest |D_loss - G_loss| checkpointed so far.
D_loss_plot, G_loss_plot = [], []
best_loss_diff = float("inf")
lambda_gp = 10  # weight of the gradient-penalty term in the critic loss
loss_diff_threshold = 0.7  # only epochs whose |D_loss - G_loss| is below this are checkpoint candidates

# Output directories for checkpoints and sample images.
os.makedirs("./t_weights", exist_ok=True)
os.makedirs("./images", exist_ok=True)

for epoch in range(1, num_epochs + 1):
    D_loss_list, G_loss_list = [], []

    for real_images, _ in train_loader:
        # ==================== Critic (discriminator) step ====================
        D_optimizer.zero_grad()
        real_images = real_images.to(device)

        # Generate a batch of fake samples.
        noise_vector = torch.randn(real_images.size(0), latent_dim, 1, 1, device=device)
        generated_image = generator(noise_vector)

        # Critic scores; fakes are detached so this step does not backpropagate into the generator.
        real_output = discriminator(real_images)
        fake_output = discriminator(generated_image.detach())

        gp = gradient_penalty(
            discriminator, real_images, generated_image.detach(), device
        )

        # WGAN-GP critic loss.
        D_loss = -torch.mean(real_output) + torch.mean(fake_output) + lambda_gp * gp
        D_loss.backward()
        D_loss_list.append(D_loss.item())
        D_optimizer.step()

        # ==================== Generator step ====================
        G_optimizer.zero_grad()

        # Score the (non-detached) fakes with the freshly updated critic.
        fake_output = discriminator(generated_image)

        G_loss = -torch.mean(fake_output)
        G_loss.backward()
        G_loss_list.append(G_loss.item())
        G_optimizer.step()

    # Per-epoch averages.
    avg_D_loss = torch.mean(torch.FloatTensor(D_loss_list))
    avg_G_loss = torch.mean(torch.FloatTensor(G_loss_list))
    loss_diff = float(abs(avg_D_loss - avg_G_loss))

    print(
        f"Epoch: [{epoch}/{num_epochs}]: D_loss: {avg_D_loss:.3f}, G_loss: {avg_G_loss:.3f}, Loss_Diff: {loss_diff:.3f}"
    )

    D_loss_plot.append(avg_D_loss)
    G_loss_plot.append(avg_G_loss)

    # Every 10 epochs, save samples from the last batch.
    if epoch % 10 == 0:
        save_image(
            generated_image.detach()[:10],
            f"./images/sample_{epoch}.png",
            nrow=5,
            normalize=True,
        )

    # Checkpoint only when the gap is under the threshold AND smaller than every earlier
    # checkpoint, so a worse later epoch cannot overwrite the best model.
    if loss_diff < loss_diff_threshold and loss_diff < best_loss_diff:
        best_loss_diff = loss_diff
        # Separate file name so the periodic sample of the same epoch is not overwritten.
        save_image(
            generated_image.detach()[:3],
            f"./images/best_sample_{epoch}.png",
            nrow=5,
            normalize=True,
        )
        torch.save(generator.state_dict(), "./t_weights/best_generator.pth")
        torch.save(discriminator.state_dict(), "./t_weights/best_discriminator.pth")
        print(f"Best model saved at epoch {epoch} with Loss_Diff: {best_loss_diff:.3f}")

# Persist the loss curves once training is done.
torch.save(D_loss_plot, "./t_weights/D_loss_plot.pth")
torch.save(G_loss_plot, "./t_weights/G_loss_plot.pth")

In [None]:
# Plot the per-epoch average losses recorded by the training loop, plus their absolute gap.
epochs = range(1, len(D_loss_plot) + 1)  # training epochs are numbered from 1
loss_diff = [abs(d - g) for d, g in zip(D_loss_plot, G_loss_plot)]

plt.figure(figsize=(10, 6))
plt.plot(epochs, D_loss_plot, label="Discriminator Loss (D)", color="red", linewidth=2)
plt.plot(epochs, G_loss_plot, label="Generator Loss (G)", color="blue", linewidth=2)
plt.plot(epochs, loss_diff, label="Loss Difference", color="green", linewidth=2)

# Title and axis labels.
plt.title("GAN Training Loss", fontsize=16)
plt.xlabel("Epoch", fontsize=14)
plt.ylabel("Loss", fontsize=14)

# Create the legend only after all three series are plotted, so each one appears in it.
plt.legend(fontsize=12)
plt.grid(True)
plt.show()

In [None]:
# D_loss_plot, G_loss_plot = [], []
# for epoch in range(1, num_epochs+1):

#     D_loss_list, G_loss_list = [], []

#     for index, (real_images, _) in enumerate(train_loader):
#         D_optimizer.zero_grad()
#         real_images = real_images.to(device)

#         real_target = Variable(torch.ones(real_images.size(0)).to(device))
#         fake_target = Variable(torch.zeros(real_images.size(0)).to(device))

#         real_target = real_target.unsqueeze(1)
#         fake_target = fake_target.unsqueeze(1)

#         D_real_loss = discriminator_loss(discriminator(real_images), real_target)
#         D_real_loss.backward()

#         noise_vector = torch.randn(real_images.size(0), latent_dim, 1, 1, device=device)
#         noise_vector = noise_vector.to(device)

#         generated_image = generator(noise_vector)
#         output = discriminator(generated_image.detach())
#         D_fake_loss = discriminator_loss(output,  fake_target)

#         D_fake_loss.backward()

#         D_total_loss = D_real_loss + D_fake_loss
#         D_loss_list.append(D_total_loss)

#         D_optimizer.step()


#         G_optimizer.zero_grad()
#         G_loss = generator_loss(discriminator(generated_image), real_target)
#         G_loss_list.append(G_loss)

#         G_loss.backward()
#         G_optimizer.step()


#     print('Epoch: [%d/%d]: D_loss: %.3f, G_loss: %.3f' % (
#             (epoch), num_epochs, torch.mean(torch.FloatTensor(D_loss_list)),\
#              torch.mean(torch.FloatTensor(G_loss_list))))

#     D_loss_plot.append(torch.mean(torch.FloatTensor(D_loss_list)))
#     G_loss_plot.append(torch.mean(torch.FloatTensor(G_loss_list)))
#     if epoch > num_epochs*0.9:
#         save_image(generated_image.data[:10], './images/sample_%d'%epoch + '.png', nrow=5, normalize=True)
#         torch.save(generator.state_dict(), './t_weights/generator_epoch_%d.pth' % (epoch))
#         torch.save(discriminator.state_dict(), './t_weights/discriminator_epoch_%d.pth' % (epoch))

<h1 style="text-align:center;font-weight: bold">Outputting Results</h1>


In [None]:
def getImagePaths(path):
    """Return the full path of every file under ``path`` (recursively), in ``os.walk`` order.

    No extension filtering is done; a missing ``path`` yields an empty list.
    """
    return [
        os.path.join(root, name)
        for root, _subdirs, files in os.walk(path)
        for name in files
    ]

In [None]:
import cv2
import math
import matplotlib.pyplot as plt


def display_multiple_img(images_paths):
    """Show the images at ``images_paths`` in a roughly square grid of subplots.

    Files that cannot be read are reported with a printed message and their cell is left blank.
    Unused grid cells are hidden. If the list is empty, a message is printed and nothing is drawn.

    :param images_paths: sequence of image file paths.
    """
    num_images = len(images_paths)
    if num_images == 0:
        # Without this guard cols would be 0 and num_images / cols raises ZeroDivisionError.
        print("No images to display.")
        return

    # Adaptive grid: cols = ceil(sqrt(n)), rows = ceil(n / cols).
    cols = int(math.ceil(math.sqrt(num_images)))
    rows = int(math.ceil(num_images / cols))

    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid; otherwise a single image
    # returns a bare Axes object that has no .ravel().
    figure, axes = plt.subplots(
        nrows=rows, ncols=cols, figsize=(cols * 3, rows * 3), squeeze=False
    )
    flat_axes = axes.ravel()

    for ind, image_path in enumerate(images_paths):
        cell = flat_axes[ind]
        cell.set_axis_off()  # hide axes even if the image fails to load
        try:
            image = cv2.imread(image_path)
            if image is None:
                raise ValueError(f"Image at {image_path} could not be loaded.")
            cell.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))  # OpenCV loads BGR
        except Exception as e:
            print(f"Error displaying image at {image_path}: {e}")

    # Hide grid cells left over when the image count does not fill the grid.
    for cell in flat_axes[num_images:]:
        cell.set_visible(False)

    plt.tight_layout(pad=2.0)  # extra spacing between subplots
    plt.show()

In [None]:
# Show every file the training loop wrote to ./images (periodic samples and best-model samples).
display_multiple_img(getImagePaths("./images"))

In [None]:
!pip install pytorch-fid

In [None]:
import os
import torch
from torchvision.utils import save_image
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder

real_images_dir = "./train"
generated_images_dir = "./generated_images"
os.makedirs(real_images_dir, exist_ok=True)
os.makedirs(generated_images_dir, exist_ok=True)

# Rebuild the generator with the notebook's latent size and load the best checkpoint from training.
generator = Generator(latent_dim=latent_dim)
# generator.load_state_dict(torch.load(f'./t_weights/generator_epoch_{num_epochs}.pth'))  # load last-epoch weights
# map_location="cpu": the checkpoint may contain CUDA tensors; loading onto the CPU also works
# on machines without a GPU (this generator instance lives on the CPU).
generator.load_state_dict(torch.load("./t_weights/best_generator.pth", map_location="cpu"))
generator.eval()


def generate_images(generator, num_images, latent_dim, save_dir):
    """Sample ``num_images`` images from ``generator`` and save each one as a PNG in ``save_dir``.

    Files are named generated_image_001.png, generated_image_002.png, ...; existing files with
    the same names are overwritten. The generator is left in eval mode.

    :param generator: model mapping noise of shape (N, latent_dim, 1, 1) to images.
    :param num_images: number of images to generate.
    :param latent_dim: length of each noise vector.
    :param save_dir: existing directory to write the PNGs into.
    """
    generator.eval()
    # Draw the noise on the same device as the generator's weights so a GPU-resident
    # generator also works (CPU for a parameter-less module).
    first_param = next(generator.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")
    noise_vector = torch.randn(num_images, latent_dim, 1, 1, device=model_device)
    with torch.no_grad():
        generated_images = generator(noise_vector)
    generated_images = torch.nn.functional.interpolate(generated_images, size=(64, 64))

    # Save each image to the target directory.
    for i in range(num_images):
        save_path = os.path.join(save_dir, f"generated_image_{i + 1:03d}.png")
        save_image(generated_images[i], save_path, normalize=True)


def save_real_images(dataset_path, save_dir, num_images=50):
    """Save the first ``num_images`` images of the ImageFolder at ``dataset_path`` as PNGs.

    Images are taken in unshuffled dataset order, passed through the same resize/normalize
    transform as training, and written to ``save_dir`` as real_image_001.png, real_image_002.png, ...
    """
    preprocess = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])
    loader = DataLoader(
        ImageFolder(root=dataset_path, transform=preprocess), batch_size=1, shuffle=False
    )

    # zip with range stops after num_images items (range is consumed first, so no extra load).
    for idx, (batch, _label) in zip(range(num_images), loader):
        target_path = os.path.join(save_dir, f"real_image_{idx + 1:03d}.png")
        save_image(batch[0], target_path, normalize=True)


# Folder that receives the real images exported for the FID comparison.
path_to_real_dataset = "./real_images"
os.makedirs(path_to_real_dataset, exist_ok=True)

# Write 200 generated and 200 real images so pytorch-fid can compare the two folders.
generate_images(generator, num_images=200, latent_dim=latent_dim, save_dir=generated_images_dir)
save_real_images(real_images_dir, path_to_real_dataset, num_images=200)

# Tell the user how to compute FID with the pytorch-fid command-line tool.
fid_hint = (
    f"Run the following command to compute FID:",
    f"pytorch-fid {path_to_real_dataset} {generated_images_dir}",
)
for line in fid_hint:
    print(line)

In [None]:
# Show the images written by generate_images into ./generated_images.
display_multiple_img(getImagePaths("./generated_images"))

In [None]:
# Compute the FID between the exported real images and the generated samples with the pytorch-fid CLI.
!python -m pytorch_fid  ./real_images ./generated_images

In [None]:
# Generate 500 samples into ./generated_images. File names restart at generated_image_001.png,
# so the 200 images used for the FID computation above are overwritten.
generate_images(
    generator, num_images=500, latent_dim=latent_dim, save_dir=generated_images_dir
)

In [None]:
# Archive the ./generated_images folder into generated_images.zip.
!zip -r generated_images.zip ./generated_images