In [1]:
import os
import random
import cv2
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

In [2]:
# %% --------------------------------------- Fix random seeds -----------------------------------------------------------------
# Seed every RNG used in this notebook (hash seed, `random`, NumPy, torch CPU and CUDA).
SEED = 42
os.environ['PYTHONHASHSEED'] = str(SEED)
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
# NOTE(review): only `deterministic` is set here; `torch.backends.cudnn.benchmark` is left untouched.
torch.backends.cudnn.deterministic = True
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using: {device}")

Using: cpu


In [3]:
# %% ---------------------------------- 数据准备 ---------------------------------------------------------------
def change_image_shape(images):
    """Return ``images`` as a 4-D channel-first array when a layout can be inferred.

    Args:
        images: array with a ``shape`` attribute and a ``reshape`` method.

    Returns:
        * 3-D input ``(N, H, W)`` -> ``(N, 1, H, W)`` (a singleton channel axis is added;
          ``H`` and ``W`` may differ).
        * 4-D input whose last dimension is larger than 3 (i.e. the last axis is not an
          RGB channel axis) -> reshaped to ``(N, shape[1], shape[2], shape[3])``, which
          keeps the existing dimensions.
        * anything else is returned unchanged.
    """
    shape_tuple = images.shape
    if len(shape_tuple) == 3:
        # Use the real height and width instead of assuming a square image.
        images = images.reshape(-1, 1, shape_tuple[1], shape_tuple[2])
    elif len(shape_tuple) == 4 and shape_tuple[-1] > 3:
        # The rank must be compared, not the shape tuple itself (a tuple never equals 4).
        images = images.reshape(-1, shape_tuple[1], shape_tuple[2], shape_tuple[3])
    return images

In [5]:
# Load the MNIST dataset (FashionMNIST is imported but not used in this cell).
from torchvision.datasets import FashionMNIST
from torchvision.datasets import MNIST
from torchvision import transforms

# Build the transform pipeline (tensor conversion only).
transform = transforms.Compose([
    transforms.ToTensor(),
])

# Load the train and test splits.
# NOTE(review): despite the `fashion_mnist` names, both objects are created from `MNIST`.
# NOTE(review): `root` is an absolute, machine-specific path.
fashion_mnist = MNIST(root = '/Users/max/MasterThesisData/MNIST/', train=True, download=True, transform=transform)
test_fashion_mnist = MNIST(root = '/Users/max/MasterThesisData/MNIST/', train=False, download=True, transform=transform)
# The arrays are read from the `.data` / `.targets` attributes rather than by indexing the dataset.
images = fashion_mnist.data.numpy()
labels = fashion_mnist.targets.numpy()
test_images = test_fashion_mnist.data.numpy()
test_labels = test_fashion_mnist.targets.numpy()
# Add a trailing channel axis: (N, 28, 28) -> (N, 28, 28, 1).
images = images.reshape(-1, 28, 28, 1)
test_images = test_images.reshape(-1, 28, 28, 1)
print("Train dataset")
print(images.shape)
print(labels.shape)
print("Test dataset")
print(test_images.shape)
print(test_labels.shape)

Train dataset
(60000, 28, 28, 1)
(60000,)
Test dataset
(10000, 28, 28, 1)
(10000,)


In [6]:
# Keep only the samples of classes 3 and 4 (training split).
images_3 = images[labels == 3]
images_4 = images[labels == 4]
# Shrink class 4 to 0.5% of the class-3 sample count (first `num_4` samples are kept).
num_3 = images_3.shape[0]
num_4 = int(num_3 * 0.005)  # 0.5%
images_4 = images_4[:num_4]
# Build fresh label arrays for the two classes.
labels_3 = np.full((images_3.shape[0],), 3)
labels_4 = np.full((images_4.shape[0],), 4)
# Stack samples and labels: all class-3 samples first, then class 4.
images_new = np.vstack([images_3, images_4])
labels_new = np.concatenate([labels_3, labels_4])
imbalance_images = images_new
imbalance_labels = labels_new

# NOTE: from here on `images` / `labels` hold only the 3-vs-4 subset (the full arrays are overwritten).
images = imbalance_images # samples of classes 3 and 4
labels = imbalance_labels # samples of classes 3 and 4
# Remap in place: 3 -> 1 (majority), 4 -> 0 (minority).
labels[labels == 3] = 1
labels[labels == 4] = 0
print("Train - Imbalance data shape:", images.shape, labels.shape)
print("Train - Imbalance data distribution:", np.unique(labels, return_counts=True))
# Keep only the samples of classes 3 and 4 (test split; no down-sampling here).
test_images_3 = test_images[test_labels == 3]
test_images_4 = test_images[test_labels == 4]
# Build fresh label arrays for the two classes.
test_labels_3 = np.full((test_images_3.shape[0],), 3)
test_labels_4 = np.full((test_images_4.shape[0],), 4)
# Stack samples and labels.
test_images_new = np.vstack([test_images_3, test_images_4])
test_labels_new = np.concatenate([test_labels_3, test_labels_4])
test_imbalance_images = test_images_new
test_imbalance_labels = test_labels_new

test_images = test_imbalance_images # samples of classes 3 and 4
test_labels = test_imbalance_labels # samples of classes 3 and 4

# Remap in place: 3 -> 1, 4 -> 0.
test_labels[test_labels == 3] = 1
test_labels[test_labels == 4] = 0

print("Test - Imbalance data shape:", test_images.shape, test_labels.shape)
print("Test - Imbalance data distribution:", np.unique(test_labels, return_counts=True))

Train - Imbalance data shape: (6161, 28, 28, 1) (6161,)
Train - Imbalance data distribution: (array([0, 1]), array([  30, 6131]))
Test - Imbalance data shape: (1992, 28, 28, 1) (1992,)
Test - Imbalance data distribution: (array([0, 1]), array([ 982, 1010]))


In [10]:
# Upsample the images; a higher resolution is expected to carry more usable information.
# Number of channels (last axis of the NHWC array).
channel = images.shape[-1]

# Resize every training image to 64 x 64 x channel.
# `np.ndarray(shape=...)` allocates without initializing; every slot is overwritten by the loop below.
real = np.ndarray(shape=(images.shape[0], 64, 64, channel))
for i in range(images.shape[0]):
    # The reshape puts the trailing channel axis back on the resized image.
    real[i] = cv2.resize(images[i], (64, 64)).reshape((64, 64, channel))


# Same resizing for the test images.
test_channel = test_images.shape[-1]
test_real = np.ndarray(shape=(test_images.shape[0], 64, 64, test_channel))
for i in range(test_images.shape[0]):
    test_real[i] = cv2.resize(test_images[i], (64, 64)).reshape((64, 64, test_channel))


print("Train - Imbalance data shape:", real.shape, labels.shape)
print("Train - Imbalance data distribution:", np.unique(labels, return_counts=True))
print("Test - Imbalance data shape:", test_real.shape, test_labels.shape)
print("Test - Imbalance data distribution:", np.unique(test_labels, return_counts=True))

Train - Imbalance data shape: (6161, 64, 64, 1) (6161,)
Train - Imbalance data distribution: (array([0, 1]), array([  30, 6131]))
Test - Imbalance data shape: (1992, 64, 64, 1) (1992,)
Test - Imbalance data distribution: (array([0, 1]), array([ 982, 1010]))


In [13]:
# Bind the resized arrays to the train/test names used below and print their shapes.
X_train = real
y_train = labels
X_test = test_real
y_test = test_labels
print("Train data shape:", X_train.shape, y_train.shape)
print("Test data shape:", X_test.shape, y_test.shape)

Train data shape: (6161, 64, 64, 1) (6161,)
Test data shape: (1992, 64, 64, 1) (1992,)


In [14]:
# GAN training works best with inputs in [-1, 1]: map x -> (x - 127.5) / 127.5.
# NOTE: this cell rebinds X_train / X_test / y_train / y_test, so it is not idempotent on re-run.
X_train = (X_train.astype('float32') - 127.5) / 127.5
X_test = (X_test.astype('float32') - 127.5) / 127.5

# Convert to PyTorch tensors and reorder the axes (N,H,W,C) -> (N,C,H,W).
X_train = torch.tensor(X_train, dtype=torch.float32).permute(0, 3, 1, 2)
X_test = torch.tensor(X_test, dtype=torch.float32).permute(0, 3, 1, 2)
y_train = torch.tensor(y_train, dtype=torch.long)
y_test = torch.tensor(y_test, dtype=torch.long)

print("Train data shape:", X_train.shape, y_train.shape)
print("Test data shape:", X_test.shape, y_test.shape)

Train data shape: torch.Size([6161, 1, 64, 64]) torch.Size([6161])
Test data shape: torch.Size([1992, 1, 64, 64]) torch.Size([1992])


In [18]:
# Image size as (C, H, W).
img_size = (channel, 64, 64)
# Number of distinct classes present in the training labels.
n_classes = len(torch.unique(y_train))

# Build the datasets and data loaders (training batches are shuffled, test batches are not).
train_dataset = TensorDataset(X_train, y_train)
test_dataset = TensorDataset(X_test, y_test)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True) # batch_size= 128
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False) #batch_size= 128

# `len(loader)` is the number of batches per epoch.
print("Train Loader:",len(train_loader))
print("Test Loader:",len(test_loader))

Train Loader: 49
Test Loader: 16


In [19]:
# %% ---------------------------------- Hyper-parameters ----------------------------------------------------------------
# Dimension of the latent space.
latent_dim = 128
# Training ratio: number of discriminator updates per generator update.
train_ratio = 10
# Adam optimizer parameters (learning rate and the two beta coefficients).
lr = 0.0002
beta1 = 0.5
beta2 = 0.9
# Weight applied to the gradient-penalty term in the discriminator loss.
gp_weight = 10.0 # GP scaling

In [23]:
# %% ---------------------------------- 模型设置 -------------------------------------------------------------------
# 构建解码器模型
class Decoder(nn.Module):
    """Generator/decoder: turns a latent vector into a ``channel`` x 64 x 64 image.

    The final activation is ``Tanh``, so output values lie in [-1, 1].
    """

    def __init__(self, latent_dim, channel):
        """
        Args:
            latent_dim: size of the latent vector consumed by the first linear layer.
            channel: number of channels of the produced image.
        """
        super().__init__()

        # Linear projection of the latent vector onto a 256 x 4 x 4 feature volume.
        self.fc = nn.Sequential(
            nn.Linear(latent_dim, 256 * 4 * 4),
            nn.LeakyReLU(0.2),
        )

        # Three stride-2 transposed convolutions with batch norm (4 -> 8 -> 16 -> 32),
        # created in the same order as the layers appear in the network.
        stages = []
        for c_in, c_out in ((256, 128), (128, 128), (128, 64)):
            stages.extend([
                nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2),
            ])
        # Last upsampling step (32 -> 64) maps to the image channels and squashes with tanh.
        stages.extend([
            nn.ConvTranspose2d(64, channel, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        ])
        self.deconv = nn.Sequential(*stages)

    def forward(self, x):
        """Decode a batch of latent vectors into images."""
        hidden = self.fc(x)
        volume = hidden.view(-1, 256, 4, 4)
        return self.deconv(volume)

In [None]:
# 注意力机制
class SelfAttention(nn.Module):
    """Self-attention over the spatial positions of a feature map, with a learned residual gate."""

    def __init__(self, in_dim):
        """
        Args:
            in_dim: number of channels of the incoming feature map.
        """
        super().__init__()

        # 1x1 convolutions producing query/key (in_dim // 8 channels) and value (in_dim channels).
        reduced_dim = in_dim // 8
        self.query = nn.Conv2d(in_dim, reduced_dim, kernel_size=1)
        self.key = nn.Conv2d(in_dim, reduced_dim, kernel_size=1)
        self.value = nn.Conv2d(in_dim, in_dim, kernel_size=1)

        # Residual gate, initialised to zero so the module starts out as the identity.
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Return ``gamma * attention(x) + x`` with the same shape as ``x``."""
        n, c, dim_a, dim_b = x.size()
        positions = dim_a * dim_b

        # Pairwise scores between all spatial positions, normalised over the last axis.
        q = self.query(x).view(n, -1, positions).transpose(1, 2)
        k = self.key(x).view(n, -1, positions)
        weights = F.softmax(torch.bmm(q, k), dim=-1)

        # Weighted sum of the values, folded back into the input layout.
        v = self.value(x).view(n, -1, positions)
        attended = torch.bmm(v, weights.transpose(1, 2)).view(n, c, dim_a, dim_b)

        return self.gamma * attended + x

# 构建编码器模型
class Encoder(nn.Module):
    """Convolutional encoder with self-attention after every stage; maps an image to a latent vector."""

    def __init__(self, img_size, latent_dim):
        """
        Args:
            img_size: input image size as (C, H, W); only ``img_size[0]`` (channels) is read here.
            latent_dim: size of the produced latent vector.
        """
        super().__init__()

        # Four stride-2 stages (64 -> 32 -> 16 -> 8 -> 4 spatially for 64x64 inputs),
        # each made of Conv2d -> LeakyReLU -> SelfAttention.
        widths = (img_size[0], 64, 128, 128, 256)
        blocks = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            blocks.extend([
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.LeakyReLU(0.2),
                SelfAttention(c_out),
            ])
        self.conv = nn.Sequential(*blocks)

        # Flatten the 256 x 4 x 4 volume and project it to the latent space.
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(4 * 4 * 256, latent_dim),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        """Encode a batch of images; the convolutional output is also kept on ``self.features``."""
        feature_maps = self.conv(x)
        # Cached on the module so the last convolutional features can be read after the call.
        self.features = feature_maps
        return self.fc(feature_maps)

In [25]:
# 构建嵌入模型
class LabelEmbedding(nn.Module):
    """Conditions a latent/noise vector on a class label by element-wise multiplication."""

    def __init__(self, n_classes, latent_dim):
        """
        Args:
            n_classes: number of classes (rows of the embedding table).
            latent_dim: size of each embedding vector; must match the noise dimension.
        """
        super().__init__()
        self.embedding = nn.Embedding(n_classes, latent_dim)

    def forward(self, noise, label):
        """
        Args:
            noise: noise/latent vectors.
            label: integer class labels used to index the embedding table.

        Returns:
            ``noise`` multiplied element-wise by the embedding of ``label``.
        """
        class_vector = self.embedding(label)
        # Drop axis 1 when it has size 1 (e.g. labels supplied with a trailing singleton axis).
        class_vector = class_vector.squeeze(1)
        return noise * class_vector

In [26]:
# 构建自编码器
class Autoencoder(nn.Module):
    """Label-conditioned autoencoder: encoder -> label embedding -> decoder."""

    def __init__(self, encoder, decoder, embedding):
        """
        Args:
            encoder: module mapping an image batch to latent vectors.
            decoder: module mapping (label-conditioned) latent vectors to images.
            embedding: module called as ``embedding(latent, label)``.
        """
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.embedding = embedding

    def forward(self, img, label):
        """Reconstruct ``img`` after conditioning its latent code on ``label``."""
        conditioned = self.embedding(self.encoder(img), label)
        return self.decoder(conditioned)

In [27]:
# 构建判别器模型
class Discriminator(nn.Module):
    """Label-conditioned discriminator that outputs one raw logit per (image, label) pair."""

    def __init__(self, img_size, n_classes):
        """
        Args:
            img_size: input image size as (C, H, W); only ``img_size[0]`` (channels) is read here.
            n_classes: number of classes for the label embedding table.
        """
        super().__init__()

        flat_dim = 4 * 4 * 256

        # Four stride-2 convolutions (64 -> 32 -> 16 -> 8 -> 4 spatially for 64x64 inputs).
        widths = (img_size[0], 64, 128, 128, 256)
        conv_layers = []
        for c_in, c_out in zip(widths, widths[1:]):
            conv_layers.append(nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1))
            conv_layers.append(nn.LeakyReLU(0.2))
        self.conv = nn.Sequential(*conv_layers)

        # Label branch: embedding lookup projected to the flattened image-feature size.
        self.label_embedding = nn.Sequential(
            nn.Embedding(n_classes, 512),
            nn.Flatten(),
            nn.Linear(512, flat_dim),
            nn.LeakyReLU(0.2),
        )

        # Head producing a single logit (no sigmoid is applied here).
        self.classifier = nn.Sequential(
            nn.Linear(flat_dim, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1),
        )

    def forward(self, img, label):
        """Score ``img`` under ``label``; image and label features are fused by multiplication."""
        flat_img = self.conv(img).view(-1, 4 * 4 * 256)
        fused = flat_img * self.label_embedding(label)
        return self.classifier(fused)

In [28]:
# 构建生成器（继承预训练的解码器和嵌入层）
class Generator(nn.Module):
    """Generator assembled from an existing label-embedding module and decoder."""

    def __init__(self, embedding, decoder):
        """
        Args:
            embedding: module called as ``embedding(noise, label)``.
            decoder: module mapping the conditioned latent vectors to images.
        """
        super().__init__()
        self.embedding = embedding
        self.decoder = decoder

    def forward(self, noise, label):
        """Generate images from ``noise`` conditioned on ``label``."""
        conditioned = self.embedding(noise, label)
        return self.decoder(conditioned)

In [29]:
# %% ---------------------------------- 损失函数和训练函数 ----------------------------------------------------------------
# 判别器损失函数
def discriminator_loss(real_logits, fake_logits, wrong_label_logits):
    """Discriminator loss: sum of three binary cross-entropy terms on raw logits.

    Args:
        real_logits: scores for real images with their labels (target 1).
        fake_logits: scores for generated images (target 0).
        wrong_label_logits: scores for real images paired with other labels (target 0).

    Returns:
        Scalar tensor ``wrong_label_loss + fake_loss + real_loss``.
    """
    def bce(logits, make_target):
        # `make_target` builds an all-ones or all-zeros tensor shaped like the logits.
        return F.binary_cross_entropy_with_logits(logits, make_target(logits))

    loss_real = bce(real_logits, torch.ones_like)
    loss_fake = bce(fake_logits, torch.zeros_like)
    loss_mismatch = bce(wrong_label_logits, torch.zeros_like)

    return loss_mismatch + loss_fake + loss_real

In [30]:
# 生成器损失函数
def generator_loss(fake_logits):
    """Generator loss: binary cross-entropy of the fake logits against an all-ones target.

    Args:
        fake_logits: discriminator scores (raw logits) for generated images.

    Returns:
        Scalar loss tensor.
    """
    wants_real = torch.ones_like(fake_logits)
    return F.binary_cross_entropy_with_logits(fake_logits, wants_real)

# 梯度惩罚函数
def gradient_penalty(discriminator, real_images, fake_images, labels):
    """Gradient penalty on random interpolations between real and generated images.

    Args:
        discriminator: callable ``discriminator(images, labels)`` returning per-sample scores.
        real_images: batch of real images; its device and dtype are used for the mixing weights.
        fake_images: batch of generated images with the same shape as ``real_images``.
        labels: class labels passed to the discriminator together with the interpolations.

    Returns:
        Scalar tensor: mean over the batch of ``(||grad||_2 - 1) ** 2``, where the gradient is
        taken with respect to the interpolated images. The graph is kept (``create_graph=True``)
        so the penalty can be back-propagated.
    """
    batch_size = real_images.size(0)

    # One mixing coefficient per sample, broadcast over all remaining axes. It is created on
    # the inputs' own device/dtype instead of depending on a module-level `device` variable.
    alpha_shape = (batch_size,) + (1,) * (real_images.dim() - 1)
    alpha = torch.rand(alpha_shape, device=real_images.device, dtype=real_images.dtype)
    interpolated = real_images + alpha * (fake_images - real_images)
    interpolated.requires_grad_(True)

    # Discriminator scores for the interpolated images.
    disc_interpolates = discriminator(interpolated, labels)

    # Gradient of the scores with respect to the interpolated images.
    gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolated,
                                    grad_outputs=torch.ones_like(disc_interpolates),
                                    create_graph=True, retain_graph=True)[0]

    # Per-sample L2 norm; `reshape` also works when the gradient is not contiguous.
    gradients = gradients.reshape(batch_size, -1)
    gradient_norm = gradients.norm(2, dim=1)

    # Penalise deviations of the norm from 1.
    penalty = ((gradient_norm - 1) ** 2).mean()

    return penalty

In [31]:
# 训练一个epoch的函数
def train_epoch(discriminator, generator, dataloader, d_optimizer, g_optimizer, epoch):
    """Train the discriminator and generator for one pass over ``dataloader``.

    For every batch the discriminator is updated ``train_ratio`` times (BCE loss on real,
    fake and label-mismatched pairs plus ``gp_weight`` times the gradient penalty), then the
    generator is updated once. Relies on the module-level ``train_ratio``, ``latent_dim``,
    ``n_classes``, ``gp_weight`` and ``device``.

    Args:
        discriminator: discriminator model, called as ``discriminator(images, labels)``.
        generator: generator model, called as ``generator(noise, labels)``.
        dataloader: iterable of ``(real_images, labels)`` batches.
        d_optimizer: optimizer for the discriminator.
        g_optimizer: optimizer for the generator.
        epoch: current epoch index (only used in the progress messages).

    Returns:
        ``(mean_d_loss, mean_g_loss)``. The discriminator value recorded per batch is the
        total loss of the last of its ``train_ratio`` updates.
    """
    discriminator.train()
    generator.train()

    d_losses = []
    g_losses = []

    for batch_idx, (real_images, labels) in enumerate(dataloader):
        batch_size = real_images.size(0)
        real_images, labels = real_images.to(device), labels.to(device)

        # Several discriminator updates per generator update.
        for _ in range(train_ratio):
            # Random noise and labels for the fake images.
            noise = torch.randn(batch_size, latent_dim, device=device)
            fake_labels = torch.randint(0, n_classes, (batch_size,), device=device)

            # Labels for the "real image, wrong label" pairs. Shifting the true label by a
            # random non-zero offset guarantees a mismatch; drawing them uniformly would
            # reproduce the true label for about 1/n_classes of the samples and train the
            # same (image, correct label) pair towards both 1 and 0.
            if n_classes > 1:
                label_shift = torch.randint(1, n_classes, labels.shape, device=device)
                wrong_labels = (labels + label_shift) % n_classes
            else:
                # With a single class no mismatching label exists.
                wrong_labels = torch.randint(0, n_classes, (batch_size,), device=device)

            # Reset discriminator gradients.
            d_optimizer.zero_grad()

            # Generate fake images.
            fake_images = generator(noise, fake_labels)

            # Scores for real, fake (detached from the generator) and mismatched pairs.
            real_logits = discriminator(real_images, labels)
            fake_logits = discriminator(fake_images.detach(), fake_labels)
            wrong_label_logits = discriminator(real_images, wrong_labels)

            # Discriminator loss plus weighted gradient penalty.
            d_loss = discriminator_loss(real_logits, fake_logits, wrong_label_logits)
            gp = gradient_penalty(discriminator, real_images, fake_images.detach(), labels)
            d_total_loss = d_loss + gp_weight * gp

            # Back-propagate and update the discriminator.
            d_total_loss.backward()
            d_optimizer.step()

        # Generator update with fresh noise and labels.
        noise = torch.randn(batch_size, latent_dim, device=device)
        fake_labels = torch.randint(0, n_classes, (batch_size,), device=device)

        # Reset generator gradients.
        g_optimizer.zero_grad()

        # Generate fake images and score them (gradients flow back into the generator).
        fake_images = generator(noise, fake_labels)
        fake_logits = discriminator(fake_images, fake_labels)

        # Generator loss.
        g_loss = generator_loss(fake_logits)

        # Back-propagate and update the generator.
        g_loss.backward()
        g_optimizer.step()

        # Record the losses.
        d_losses.append(d_total_loss.item())
        g_losses.append(g_loss.item())

        if batch_idx % 10 == 0:
            print(f"Epoch {epoch} [{batch_idx}/{len(dataloader)}] - D Loss: {d_total_loss.item():.4f}, G Loss: {g_loss.item():.4f}")

    # Mean loss over the batches of this epoch.
    return sum(d_losses) / len(d_losses), sum(g_losses) / len(g_losses)


In [None]:
# %% ---------------------------------- Autoencoder training ----------------------------------------------------------------
# Instantiate the models and move them to the selected device.
encoder = Encoder(img_size, latent_dim).to(device)
decoder = Decoder(latent_dim, channel).to(device)
embedding = LabelEmbedding(n_classes, latent_dim).to(device)
autoencoder = Autoencoder(encoder, decoder, embedding).to(device)

# Optimizer over all autoencoder parameters (encoder, decoder and label embedding).
ae_optimizer = optim.Adam(autoencoder.parameters(), lr=lr, betas=(beta1, beta2))

# 训练函数
def train_autoencoder(autoencoder, dataloader, optimizer, num_epochs):
    """Train the autoencoder with an L1 (mean absolute error) reconstruction loss.

    Args:
        autoencoder: model called as ``autoencoder(images, labels)``.
        dataloader: iterable of ``(images, labels)`` batches supporting ``len()``.
        optimizer: optimizer over the autoencoder parameters.
        num_epochs: number of passes over ``dataloader``.

    Returns:
        List with the mean batch loss of every epoch.
    """
    # Make sure layers such as BatchNorm are in training mode, even if an earlier cell
    # switched the model to eval() for visualisation.
    autoencoder.train()

    # Move batches to wherever the model lives rather than relying on a global variable.
    first_param = next(autoencoder.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    losses = []

    for epoch in range(num_epochs):
        epoch_loss = []

        for batch_idx, (images, labels) in enumerate(dataloader):
            images, labels = images.to(model_device), labels.to(model_device)

            # Forward pass.
            reconstructed = autoencoder(images, labels)

            # Reconstruction loss (MAE).
            loss = F.l1_loss(reconstructed, images)

            # Backward pass and parameter update.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss.append(loss.item())

            if batch_idx % 10 == 0:
                print(f"Epoch {epoch+1}/{num_epochs} [{batch_idx}/{len(dataloader)}] - Loss: {loss.item():.4f}")

        losses.append(sum(epoch_loss) / len(epoch_loss))
        print(f"Epoch {epoch+1}/{num_epochs} completed - Avg Loss: {losses[-1]:.4f}")

    return losses

# Train the autoencoder for 30 epochs and keep the per-epoch mean losses.
print("开始训练自编码器...")
ae_losses = train_autoencoder(autoencoder, train_loader, ae_optimizer, num_epochs=30)

In [None]:
# %% ---------------------------------- 显示自编码器重建结果 ----------------------------------------------------------------
# 评估自编码器并显示结果
def show_reconstructed_images():
    """Plot one test image per class next to its autoencoder reconstruction and save the figure.

    Uses the module-level ``autoencoder``, ``test_loader``, ``n_classes``, ``channel`` and
    ``device``. The autoencoder is switched to eval mode and left in that mode. The figure is
    written to ``./bagan_gp_results/autoencoder_reconstruction.png``.
    """
    autoencoder.eval()

    # Collect the first test sample found for every class.
    show_test_images = []
    show_test_labels = []

    for c in range(n_classes):
        for images, labels in test_loader:
            idx = (labels == c).nonzero(as_tuple=True)[0]
            if len(idx) > 0:
                show_test_images.append(images[idx[0]].unsqueeze(0))
                show_test_labels.append(labels[idx[0]].unsqueeze(0))
                break

    # Stack the picked samples into one batch.
    show_test_images = torch.cat(show_test_images, dim=0).to(device)
    show_test_labels = torch.cat(show_test_labels, dim=0).to(device)

    # Reconstruct without tracking gradients.
    with torch.no_grad():
        reconstructed = autoencoder(show_test_images, show_test_labels)

    # Convert to NumPy for plotting.
    show_test_images = show_test_images.cpu().numpy()
    reconstructed = reconstructed.cpu().numpy()

    # Map from [-1, 1] back to [0, 1].
    show_test_images = show_test_images * 0.5 + 0.5
    reconstructed = reconstructed * 0.5 + 0.5

    # Top row: originals; bottom row: reconstructions.
    plt.figure(figsize=(2*n_classes, 4))

    for i in range(n_classes):
        # Original image.
        ax = plt.subplot(2, n_classes, i+1)
        if channel == 3:
            plt.imshow(np.transpose(show_test_images[i], (1, 2, 0)))
        else:
            plt.imshow(show_test_images[i].reshape(64, 64), cmap='gray')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)

        # Reconstructed image.
        ax = plt.subplot(2, n_classes, i + n_classes + 1)
        if channel == 3:
            plt.imshow(np.transpose(reconstructed[i], (1, 2, 0)))
        else:
            plt.imshow(reconstructed[i].reshape(64, 64), cmap='gray')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)

    # The output directory is otherwise only created by a later cell, so a fresh
    # top-to-bottom run would fail at savefig without this.
    os.makedirs('./bagan_gp_results', exist_ok=True)
    plt.savefig('./bagan_gp_results/autoencoder_reconstruction.png')
    plt.show()

# Display the reconstruction results.
show_reconstructed_images()

In [None]:
# %% ---------------------------------- BAGAN-GP training ----------------------------------------------------------------
# Instantiate the BAGAN-GP models.
discriminator = Discriminator(img_size, n_classes).to(device)
# The generator reuses the `embedding` and `decoder` objects created for the autoencoder
# (same module instances, so their weights are shared).
generator = Generator(embedding, decoder).to(device)

# Optimizers.
d_optimizer = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, beta2))
g_optimizer = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, beta2))

# Create the directory that stores the results.
os.makedirs('bagan_gp_results', exist_ok=True)

# 生成并保存图像的函数
def generate_and_save_images(generator, epoch):
    """Generate a grid of samples for every class and save it as a PNG.

    The first row shows one real test image per class; the rows below show images generated
    from fixed noise vectors, with one column per class label. Uses the module-level
    ``test_loader``, ``n_classes``, ``latent_dim``, ``channel`` and ``device``. The generator
    is switched to eval mode and left in that mode.

    Args:
        generator: generator model, called as ``generator(noise, labels)``.
        epoch: index used in the output file name ``generated_plot_<epoch>.png``.
    """
    generator.eval()

    # Fixed noise to track progress across epochs. A private RandomState seeded with 42
    # yields the same draws as seeding the global NumPy RNG, without resetting the global
    # RNG on every call.
    noise_rng = np.random.RandomState(42)
    latent_gen = torch.tensor(noise_rng.normal(size=(n_classes, latent_dim)),
                              dtype=torch.float32).to(device)

    # Pick the first test image found for every class.
    test_images = []
    for c in range(n_classes):
        for images, labels in test_loader:
            idx = (labels == c).nonzero(as_tuple=True)[0]
            if len(idx) > 0:
                test_images.append(images[idx[0]].unsqueeze(0))
                break

    test_images = torch.cat(test_images, dim=0)

    # Map from [-1, 1] back to [0, 1] for display.
    test_images_np = test_images.cpu().numpy() * 0.5 + 0.5

    # Canvas with (n_classes + 1) rows and n_classes columns.
    plt.figure(figsize=(2*n_classes, 2*(n_classes+1)))

    # First row: real images.
    for i in range(n_classes):
        ax = plt.subplot(n_classes+1, n_classes, i+1)
        if channel == 3:
            plt.imshow(np.transpose(test_images_np[i], (1, 2, 0)))
        else:
            plt.imshow(test_images_np[i].reshape(64, 64), cmap='gray')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)

    # Generated images: column c holds class c, row i + 1 holds noise vector i.
    with torch.no_grad():
        for c in range(n_classes):
            # Label tensor filled with class c.
            class_labels = torch.ones(n_classes, dtype=torch.long, device=device) * c

            # Generate the images.
            generated_images = generator(latent_gen, class_labels)

            # Convert to NumPy and map to [0, 1].
            generated_images_np = generated_images.cpu().numpy() * 0.5 + 0.5

            # Draw the generated images.
            for i in range(n_classes):
                ax = plt.subplot(n_classes+1, n_classes, (i+1)*n_classes+1+c)
                if channel == 3:
                    plt.imshow(np.transpose(generated_images_np[i], (1, 2, 0)))
                else:
                    plt.imshow(generated_images_np[i].reshape(64, 64), cmap='gray')
                ax.get_xaxis().set_visible(False)
                ax.get_yaxis().set_visible(False)

    # Make the function usable even when the results directory has not been created yet.
    os.makedirs('bagan_gp_results', exist_ok=True)
    plt.savefig(f'bagan_gp_results/generated_plot_{epoch}.png')
    plt.close()

# Train BAGAN-GP.
print("开始训练BAGAN-GP...")
d_loss_history = []
g_loss_history = []
learning_steps = 50

for learning_step in range(learning_steps):
    print(f'学习步骤 # {learning_step + 1} {"-" * 50}')
    
    # Train for one epoch.
    d_loss, g_loss = train_epoch(discriminator, generator, train_loader, d_optimizer, g_optimizer, learning_step)
    
    # Record the mean losses of this epoch.
    d_loss_history.append(d_loss)
    g_loss_history.append(g_loss)
    
    # Save a grid of generated images after every step.
    generate_and_save_images(generator, learning_step)
    
    # Save checkpoints every 10 steps; the file name carries the 0-based step index (9, 19, ...).
    if (learning_step + 1) % 10 == 0:
        torch.save(generator.state_dict(), f'bagan_gp_results/generator_{learning_step}.pt')
        torch.save(discriminator.state_dict(), f'bagan_gp_results/discriminator_{learning_step}.pt')

# Plot the loss history.
plt.figure(figsize=(10, 5))
plt.plot(d_loss_history, label='D Loss')
# This curve is the generator loss, so it is labelled 'G Loss' (it was mislabelled 'C Loss').
plt.plot(g_loss_history, label='G Loss')
plt.legend()
plt.title('Train Loss History')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.savefig('bagan_gp_results/loss_history.png')
plt.show()

In [None]:
# Save gif from generated images
import imageio
import os

# Directory containing the generated images. The name `results_dir` is used instead of `dir`
# so the builtin `dir()` is not shadowed for the rest of the kernel session.
results_dir = 'bagan_gp_results/'

# Create the directory if it doesn't exist
os.makedirs(results_dir, exist_ok=True)

# List the directory once instead of once per epoch.
available_files = set(os.listdir(results_dir))

# Collect all the images, in epoch order
ims = []
for i in range(learning_steps):
    fname = 'generated_plot_%d.png' % i
    if fname in available_files:
        print('loading png...', i)
        im = imageio.imread(results_dir + fname)
        ims.append(im)

# Check if any images were found
if ims:
    print('saving as gif...')
    imageio.mimsave(results_dir + 'training_demo.gif', ims, fps=3)
    print(f'GIF saved to {results_dir}training_demo.gif')
else:
    print('No images found to create GIF')

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, TensorDataset
import os

# Select the device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Parameters (must match the values used for training).
latent_dim = 128
n_classes = 2  # two classes in this notebook: 0 (minority) and 1 (majority)
channel = 1  # MNIST images are single-channel

# Rebuild the model components.
decoder = Decoder(latent_dim, channel).to(device)
embedding = LabelEmbedding(n_classes, latent_dim).to(device)
generator = Generator(embedding, decoder).to(device)

# Load the trained weights (checkpoint written after the 50th learning step, index 49).
model_path = 'bagan_gp_results/generator_49.pt'
# `map_location` lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
generator.load_state_dict(torch.load(model_path, map_location=device))
generator.eval()  # switch to evaluation mode




def generate_minority_samples(generator, class_label=0, num_samples=100, seed=42):
    """Generate samples of one class with a trained generator.

    Uses the module-level ``latent_dim`` and ``device``.

    Args:
        generator: trained generator, called as ``generator(noise, labels)``.
        class_label: class to generate (0 is the minority class in this notebook).
        num_samples: number of samples to generate.
        seed: seed of the private noise RNG. The default matches the previously hard-coded
            seed. Pass ``None`` to draw from the global torch RNG, or a different value to
            obtain a different set of samples.

    Returns:
        NumPy array of generated images with the generator output mapped to [0, 1]
        via ``x * 0.5 + 0.5``.
    """
    # Private RNG: reproducible sampling without resetting the global torch RNG as a side effect.
    noise_rng = None
    if seed is not None:
        noise_rng = torch.Generator(device=device)
        noise_rng.manual_seed(seed)

    # Noise vectors and a constant label tensor.
    noise = torch.randn(num_samples, latent_dim, generator=noise_rng, device=device)
    labels = torch.full((num_samples,), class_label, dtype=torch.long, device=device)

    # Sample in eval mode, then restore whatever mode the caller had set.
    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            generated_images = generator(noise, labels)
    finally:
        generator.train(was_training)

    # Convert to NumPy and map to [0, 1] for display.
    generated_images_np = generated_images.cpu().numpy() * 0.5 + 0.5

    return generated_images_np

# Generate 100 minority-class samples.
minority_samples = generate_minority_samples(generator, class_label=0, num_samples=100)



def plot_generated_samples(images, rows=10, cols=10):
    """Show generated samples on a ``rows`` x ``cols`` grid and save the figure.

    Args:
        images: array of images in channel-first layout ``(N, C, H, W)``.
        rows: number of grid rows.
        cols: number of grid columns.

    The figure is written to ``bagan_gp_results/generated_minority_samples.png``. Grid cells
    beyond ``len(images)`` are left empty.
    """
    # squeeze=False always returns a 2-D array of axes, so a 1x1 (or 1xN) grid also works.
    fig, axes = plt.subplots(rows, cols, figsize=(cols*1.5, rows*1.5), squeeze=False)

    for i, ax in enumerate(axes.flatten()):
        if i < len(images):
            if images.shape[1] == 1:  # single-channel image
                # Take the only channel instead of assuming a 64x64 image.
                ax.imshow(images[i][0], cmap='gray')
            else:  # multi-channel image: channel-first -> channel-last
                ax.imshow(np.transpose(images[i], (1, 2, 0)))

        ax.axis('off')

    plt.tight_layout()
    # Make sure the output directory exists before saving.
    os.makedirs('bagan_gp_results', exist_ok=True)
    plt.savefig('bagan_gp_results/generated_minority_samples.png')
    plt.show()

# Display the generated minority-class samples.
plot_generated_samples(minority_samples)



def save_generated_samples(images, output_dir='bagan_gp_results/generated_samples'):
    """Save every generated sample as ``minority_sample_<i>.png``.

    Args:
        images: iterable of channel-first images ``(C, H, W)`` with values in [0, 1].
        output_dir: directory receiving the PNG files (created if missing).
    """
    os.makedirs(output_dir, exist_ok=True)

    for i, img in enumerate(images):
        # Decide per image whether it is grayscale, instead of reading the global `channel`.
        is_grayscale = img.shape[0] == 1
        if is_grayscale:
            pixels = img[0]  # (H, W), any image size
        else:
            pixels = np.transpose(img, (1, 2, 0))  # (H, W, C)

        # Convert to uint8 in [0, 255]; clipping prevents wrap-around for out-of-range values.
        pixels = (np.clip(pixels, 0.0, 1.0) * 255).astype(np.uint8)

        path = f'{output_dir}/minority_sample_{i}.png'
        if is_grayscale:
            # Fixed vmin/vmax keep the original intensities; without them the colormap is
            # rescaled to each image's own min/max.
            plt.imsave(path, pixels, cmap='gray', vmin=0, vmax=255)
        else:
            plt.imsave(path, pixels)

# Save the generated minority-class samples.
save_generated_samples(minority_samples)




def augment_dataset_with_generated_samples(X_train, y_train, generated_samples, class_label=0,
                                           generated_range=(0.0, 1.0), train_range=(-1.0, 1.0)):
    """Append generated samples (and their labels) to a training set.

    In this notebook the training tensors are normalised to [-1, 1], while
    ``generate_minority_samples`` returns images mapped to [0, 1]. The generated samples are
    therefore rescaled linearly from ``generated_range`` to ``train_range`` before being
    concatenated, so both parts of the augmented set share one value range.

    Args:
        X_train: original training features (tensor).
        y_train: original training labels (tensor).
        generated_samples: generated samples (array or tensor) shaped like ``X_train`` rows.
        class_label: label assigned to every generated sample.
        generated_range: ``(low, high)`` value range of ``generated_samples``; pass ``None``
            to skip rescaling (e.g. when the samples are already in the training range).
        train_range: ``(low, high)`` value range of ``X_train``.

    Returns:
        ``(X_augmented, y_augmented)`` with the generated samples appended after the originals.
    """
    # Convert without mutating the caller's array.
    gen_samples = torch.as_tensor(generated_samples, dtype=torch.float32)

    # Bring the generated samples into the value range of the training data.
    if generated_range is not None and train_range is not None:
        src_low, src_high = generated_range
        dst_low, dst_high = train_range
        scale = (dst_high - dst_low) / (src_high - src_low)
        gen_samples = (gen_samples - src_low) * scale + dst_low

    # Match device and dtype of the training tensors so the concatenation cannot fail.
    gen_samples = gen_samples.to(device=X_train.device, dtype=X_train.dtype)

    # Labels for the generated samples.
    gen_labels = torch.full((gen_samples.shape[0],), class_label,
                            dtype=y_train.dtype, device=y_train.device)

    # Originals first, generated samples after them.
    X_augmented = torch.cat([X_train, gen_samples], dim=0)
    y_augmented = torch.cat([y_train, gen_labels], dim=0)

    return X_augmented, y_augmented

# 假设X_train和y_train是您的原始训练数据
# X_augmented, y_augmented = augment_dataset_with_generated_samples(X_train, y_train, minority_samples, class_label=0)
# 
# # 创建新的数据加载器
# augmented_dataset = TensorDataset(X_augmented, y_augmented)
# augmented_loader = DataLoader(augmented_dataset, batch_size=128, shuffle=True)

In [None]:
# 导入必要的库
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, TensorDataset
import os
from scipy.linalg import sqrtm
from tqdm import tqdm
import torch.nn.functional as F
from torchvision.models import inception_v3
from torchvision import transforms
from PIL import Image

# Select the device (GPU when available, otherwise CPU).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# 加载Inception模型用于FID和IS计算
def load_inception_model():
    """Build a pretrained Inception-v3 with its final fully connected layer replaced by identity.

    Returns:
        The model in eval mode, moved to the module-level ``device``.
    """
    backbone = inception_v3(pretrained=True, transform_input=False)
    # Swap the classification layer for a pass-through so the forward pass ends at its input features.
    backbone.fc = torch.nn.Identity()
    backbone.eval()
    backbone = backbone.to(device)
    return backbone

# Image preprocessing for the Inception-based metrics
def preprocess_images(images, target_size=(299, 299)):
    """
    Convert one image or a batch of images into the tensor fed to Inception.

    NumPy arrays and torch tensors go through the same steps: cast to float32,
    add a batch dimension to a single 3-D image, move channels-last data to
    NCHW, rescale 0-255 data to [0, 1], replicate a single channel three times,
    resize to ``target_size`` and normalize with the per-channel mean/std
    constants below (the usual ImageNet statistics).

    Args:
        images: ``np.ndarray`` or ``torch.Tensor`` holding one image (3-D) or a
            batch (4-D). Values are treated as already in [0, 1] when the
            maximum is <= 1.0 and as 0-255 otherwise.
        target_size: Size passed to ``transforms.Resize``.

    Returns:
        Float tensor in NCHW layout; three channels when the input had one or
        three channels.
    """
    if isinstance(images, np.ndarray):
        # ascontiguousarray also copes with integer dtypes and strided views
        images = torch.from_numpy(np.ascontiguousarray(images, dtype=np.float32))
    else:
        # Tensors previously skipped every step below; Normalize needs floats
        images = images.float()

    # A 3-D input is a single image: add the batch dimension
    if images.dim() == 3:
        images = images.unsqueeze(0)

    # Heuristic layout check: if dim 1 is neither 1 nor 3 the data is taken to
    # be channels-last (NHWC) and is moved to NCHW
    if images.size(1) not in (1, 3):
        images = images.permute(0, 3, 1, 2)

    # Bring 0-255 data into [0, 1]; data already in [0, 1] is left untouched
    # (the old code multiplied by 255 only to divide by 255 again)
    if images.numel() > 0 and images.max() > 1.0:
        images = images / 255.0

    # Single-channel images are replicated to three channels
    if images.size(1) == 1:
        images = images.repeat(1, 3, 1, 1)

    preprocess = transforms.Compose([
        transforms.Resize(target_size),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    return preprocess(images)
# FID computation
def _extract_inception_features(images, model, batch_size, desc):
    """Run ``model`` over ``images`` in mini-batches and stack the outputs row-wise as float64."""
    collected = []
    with torch.no_grad():
        for start in tqdm(range(0, len(images), batch_size), desc=desc):
            batch = preprocess_images(images[start:start + batch_size]).to(device)
            collected.append(model(batch).cpu().numpy())
    return np.vstack(collected).astype(np.float64)


def calculate_fid(real_images, generated_images, model, batch_size=32):
    """
    Compute the Frechet Inception Distance between two image sets.

    Args:
        real_images: Real images; sliced along the first axis and handed to
            ``preprocess_images`` batch by batch.
        generated_images: Generated images, handled the same way.
        model: Feature extractor; called on each preprocessed batch and
            expected to return one feature row per image.
        batch_size: Number of images per forward pass.

    Returns:
        The FID score as a float.

    Raises:
        ValueError: If either set has fewer than two images, because a
            covariance matrix cannot be estimated from a single sample.
    """
    if len(real_images) < 2 or len(generated_images) < 2:
        raise ValueError("calculate_fid needs at least two real and two generated images")

    model.eval()

    # The same extraction loop serves both image sets
    real_features = _extract_inception_features(
        real_images, model, batch_size, "Processing real images")
    gen_features = _extract_inception_features(
        generated_images, model, batch_size, "Processing generated images")

    # Mean and covariance of each feature distribution
    mu_real = real_features.mean(axis=0)
    sigma_real = np.atleast_2d(np.cov(real_features, rowvar=False))

    mu_gen = gen_features.mean(axis=0)
    sigma_gen = np.atleast_2d(np.cov(gen_features, rowvar=False))

    mean_diff = mu_real - mu_gen
    covmean = sqrtm(sigma_real.dot(sigma_gen))

    # With fewer samples than feature dimensions the covariance product is
    # singular and sqrtm may return inf/nan; retry with a small diagonal offset
    if not np.isfinite(covmean).all():
        jitter = np.eye(sigma_real.shape[0]) * 1e-6
        covmean = sqrtm((sigma_real + jitter).dot(sigma_gen + jitter))

    # Numerical error can leave an imaginary component; keep the real part
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    fid = (mean_diff.dot(mean_diff)
           + np.trace(sigma_real) + np.trace(sigma_gen)
           - 2.0 * np.trace(covmean))

    return float(fid)

# Inception Score computation
def calculate_is(generated_images, model, batch_size=32, splits=10):
    """
    Compute the Inception Score of a set of generated images.

    Args:
        generated_images: Generated images; sliced along the first axis and
            handed to ``preprocess_images`` batch by batch.
        model: Inception model. If it has no ``fc`` attribute or its ``fc`` is
            an ``Identity``, a fresh pretrained Inception-v3 with its
            classifier is loaded instead, because class probabilities are
            required.
        batch_size: Number of images per forward pass.
        splits: Number of chunks the predictions are divided into; capped at
            the number of images so no chunk is empty.

    Returns:
        Tuple ``(mean, std)`` of the per-chunk scores, as floats.

    Raises:
        ValueError: If ``generated_images`` is empty.
    """
    n_images = len(generated_images)
    if n_images == 0:
        raise ValueError("calculate_is needs at least one generated image")

    # IS needs the full Inception model including the final classification layer
    if not hasattr(model, 'fc') or isinstance(model.fc, torch.nn.Identity):
        print("Loading full Inception model for IS calculation...")
        full_model = inception_v3(pretrained=True, transform_input=False)
        full_model = full_model.to(device)
    else:
        full_model = model
    # Ensure inference mode for a caller-supplied model too, as calculate_fid does
    full_model.eval()

    # Class probabilities for every generated image
    preds = []
    with torch.no_grad():
        for start in tqdm(range(0, n_images, batch_size), desc="Calculating IS"):
            batch = preprocess_images(generated_images[start:start + batch_size]).to(device)
            preds.append(F.softmax(full_model(batch), dim=1).cpu().numpy())
    preds = np.vstack(preds).astype(np.float64)

    # eps keeps log() finite when softmax underflows to exactly 0
    eps = 1e-16
    # Never request more chunks than images; array_split also keeps the
    # remainder samples that plain floor-division slicing would drop
    n_parts = max(1, min(int(splits), n_images))

    scores = []
    for part in np.array_split(preds, n_parts):
        marginal = part.mean(axis=0, keepdims=True)
        kl = part * (np.log(part + eps) - np.log(marginal + eps))
        scores.append(np.exp(kl.sum(axis=1).mean()))

    return float(np.mean(scores)), float(np.std(scores))

# Real reference images for the FID computation
def load_real_dataset(root='/Users/max/MasterThesisData/MNIST/', digit=4, image_size=64):
    """
    Load the MNIST training images of one digit, resized and scaled to [0, 1].

    Args:
        root: Directory of the MNIST data; downloaded there if missing.
        digit: MNIST class to keep. The default 4 is the class the original
            code associated with minority class 0 of the model.
        image_size: Side length, in pixels, of the square output images.

    Returns:
        Float32 tensor of shape ``(N, 1, image_size, image_size)`` with values
        in [0, 1].
    """
    from torchvision.datasets import MNIST

    # Pixels are read from ``dataset.data`` directly, so no transform is needed
    dataset = MNIST(root=root, train=True, download=True)

    # Keep only the requested digit class
    selected = np.where(np.array(dataset.targets) == digit)[0]
    digits = dataset.data[selected].numpy().reshape(-1, 28, 28)

    # Resize every image and store it channels-first (N, C, H, W)
    resized = np.zeros((len(digits), 1, image_size, image_size), dtype=np.float32)
    for i, img in enumerate(digits):
        resized[i, 0] = cv2.resize(img, (image_size, image_size))

    # x / 255 equals the former [-1, 1] normalization followed by * 0.5 + 0.5,
    # up to float rounding
    return torch.from_numpy(resized / 255.0)

# Example: evaluate the generated images
def evaluate_generated_images(model_path='bagan_gp_results/generator_49.pt',
                              num_samples=1000,
                              results_path='bagan_gp_results/evaluation_metrics.json'):
    """
    Generate minority-class samples and score them with FID and Inception Score.

    Relies on ``Decoder``, ``LabelEmbedding``, ``Generator`` and
    ``generate_minority_samples`` being defined elsewhere in the notebook.

    Args:
        model_path: Path of the generator state dict to load.
        num_samples: Number of samples to generate; also the maximum number of
            real samples used for the FID.
        results_path: JSON file the metrics are written to.

    Returns:
        Tuple ``(fid_score, is_mean, is_std)``.
    """
    import json

    # Model hyper-parameters; presumably they must match those used in training
    latent_dim = 128
    n_classes = 2
    channel = 1

    # Build the generator from its components
    decoder = Decoder(latent_dim, channel).to(device)
    embedding = LabelEmbedding(n_classes, latent_dim).to(device)
    generator = Generator(embedding, decoder).to(device)

    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine
    generator.load_state_dict(torch.load(model_path, map_location=device))
    generator.eval()

    # Generate minority-class samples
    print("Generating minority class samples...")
    minority_samples = generate_minority_samples(generator, class_label=0, num_samples=num_samples)

    # Load the real samples
    print("Loading real samples...")
    real_samples = load_real_dataset()

    # Use a random subset of the real samples, no larger than the generated set
    if len(real_samples) > num_samples:
        indices = np.random.choice(len(real_samples), num_samples, replace=False)
        real_samples = real_samples[indices]

    # Load the Inception model
    print("Loading Inception model...")
    inception_model = load_inception_model()

    # FID
    print("Calculating FID score...")
    fid_score = calculate_fid(real_samples, minority_samples, inception_model)
    print(f"FID Score: {fid_score:.4f}")

    # Inception Score
    print("Calculating Inception Score...")
    is_mean, is_std = calculate_is(minority_samples, inception_model)
    print(f"Inception Score: {is_mean:.4f} ± {is_std:.4f}")

    results = {
        "FID": float(fid_score),
        "IS_mean": float(is_mean),
        "IS_std": float(is_std)
    }

    # Make sure the output directory exists before writing the JSON file
    results_dir = os.path.dirname(results_path)
    if results_dir:
        os.makedirs(results_dir, exist_ok=True)
    with open(results_path, 'w') as f:
        json.dump(results, f, indent=4)

    print(f"Results saved to '{results_path}'")

    return fid_score, is_mean, is_std

# Run the evaluation
if __name__ == "__main__":
    # cv2 is used by load_real_dataset (cv2.resize)
    import cv2
    # Seed torch and numpy so the results are reproducible
    torch.manual_seed(42)
    np.random.seed(42)
    
    fid_score, is_mean, is_std = evaluate_generated_images()