In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler, Subset
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import numpy as np
import matplotlib.pyplot as plt
from collections import Counter
import os
import random
from PIL import Image

# Seed every random number generator in use so results are reproducible
def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed: Integer seed applied to every generator (default 42).
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seeds every CUDA device; PyTorch silently ignores this call when CUDA is unavailable
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # cuDNN autotuning can select different kernels from run to run, so turn it off as well
    torch.backends.cudnn.benchmark = False

set_seed()

# Run on the GPU when CUDA is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"使用设备: {device}")

使用设备: cuda


In [2]:
# Imbalanced two-class (digit 3 vs. digit 4) view of MNIST
class ImbalancedMNIST(Dataset):
    """Binary, deliberately imbalanced subset of MNIST.

    Only digits 3 and 4 are kept. Every digit-3 sample is retained and exposed
    with label 1; only the first ``int(n_fours * imbalance_ratio)`` digit-4
    samples are retained and exposed with label 0.
    """

    # Original MNIST digit -> binary label exposed by this dataset
    LABEL_MAP = {3: 1, 4: 0}

    def __init__(self, root, train=True, transform=None, download=True, imbalance_ratio=0.005, num_classes = 2):
        """
        Args:
            root: Directory passed to ``datasets.MNIST``.
            train: Selects the MNIST training split when True, the test split otherwise.
            transform: Optional transform passed to ``datasets.MNIST``.
            download: Passed to ``datasets.MNIST``.
            imbalance_ratio: Fraction of the digit-4 samples that is kept
                (a fraction of that class itself, not of the majority class).
            num_classes: Accepted for call-site compatibility but ignored; the
                dataset always exposes exactly two classes.

        Raises:
            ValueError: If ``imbalance_ratio`` is negative.
        """
        self.mnist = datasets.MNIST(root=root, train=train, transform=transform, download=download)
        self.num_classes = 2  # two classes only: 0 (digit 4) and 1 (digit 3)

        # Positions inside the wrapped MNIST dataset that make up the imbalanced subset
        self.indices = self._create_imbalanced_indices(imbalance_ratio)

    def _create_imbalanced_indices(self, imbalance_ratio):
        """Return the MNIST indices of all digit-3 samples followed by the kept digit-4 samples."""
        if imbalance_ratio < 0:
            # A negative ratio would become a negative slice bound and silently keep most samples
            raise ValueError(f"imbalance_ratio must be non-negative, got {imbalance_ratio}")

        # Read labels from the targets tensor; iterating the dataset itself would
        # decode and transform every image just to look at its label.
        targets = torch.as_tensor(self.mnist.targets)
        class_3_indices = torch.nonzero(targets == 3).flatten().tolist()
        class_4_indices = torch.nonzero(targets == 4).flatten().tolist()

        # Majority class (digit 3 -> label 1) is kept whole; the minority class
        # (digit 4 -> label 0) is truncated to its first n_samples entries.
        n_samples = int(len(class_4_indices) * imbalance_ratio)
        return class_3_indices + class_4_indices[:n_samples]

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        """Return ``(image, binary_label)`` for position ``idx`` of the subset.

        Raises:
            ValueError: If the underlying sample is neither a 3 nor a 4.
        """
        img, label = self.mnist[self.indices[idx]]

        # Remap the original digit to the binary label: 3 -> 1, 4 -> 0
        new_label = self.LABEL_MAP.get(int(label))
        if new_label is None:
            raise ValueError(f"意外的标签: {label}")

        return img, new_label

In [3]:
# Class-conditional generator for 28x28 single-channel (MNIST) images
class Generator(nn.Module):
    """Maps a latent vector plus a class label to a 1x28x28 image in [-1, 1] (Tanh output)."""

    def __init__(self, latent_dim, num_classes=10):
        super().__init__()
        self.latent_dim, self.num_classes = latent_dim, num_classes

        # One learnable latent_dim-sized vector per class label
        self.label_emb = nn.Embedding(num_classes, latent_dim)

        # Projects the concatenated [noise, label embedding] onto a 128-channel 7x7 map
        self.linear = nn.Sequential(nn.Linear(2 * latent_dim, 7 * 7 * 128))

        # Two x2 upsampling stages: 7x7 -> 14x14 -> 28x28
        upsampling_stack = [
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
            # NOTE(review): the second positional argument of BatchNorm2d is eps, so the
            # original `0.8` sets eps (not momentum); kept as-is — confirm the intent.
            nn.BatchNorm2d(64, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),
        ]
        self.conv_blocks = nn.Sequential(*upsampling_stack)

    def forward(self, noise, labels):
        """Generate one image per (noise row, label) pair."""
        # Condition the noise by appending the embedding of the requested class
        conditioned = torch.cat((noise, self.label_emb(labels)), dim=1)
        # Project, then fold the flat vector into a convolutional feature map
        feature_map = self.linear(conditioned)
        feature_map = feature_map.view(feature_map.size(0), 128, 7, 7)
        return self.conv_blocks(feature_map)

# Discriminator with an auxiliary class head for 28x28 single-channel (MNIST) images
class Discriminator(nn.Module):
    """Scores images as real/fake and predicts their class.

    ``forward`` returns ``(validity, label)``: a sigmoid real/fake probability of
    shape (B, 1) and softmax class probabilities of shape (B, num_classes).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.num_classes = num_classes

        # Feature extractor: three stride-2 conv stages, 28 -> 14 -> 7 -> 4
        stage_channels = ((1, 16), (16, 32), (32, 64))
        layers = []
        for c_in, c_out in stage_channels:
            layers.append(nn.Conv2d(c_in, c_out, 3, stride=2, padding=1))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            layers.append(nn.Dropout2d(0.25))
        self.features = nn.Sequential(*layers)

        # 64 channels on a 4x4 map once the three stride-2 convs have run
        self.flatten_size = 64 * 4 * 4

        # Real/fake head
        self.adv_layer = nn.Sequential(nn.Linear(self.flatten_size, 1), nn.Sigmoid())

        # Class head; it emits probabilities (softmax already applied), not logits
        self.aux_layer = nn.Sequential(nn.Linear(self.flatten_size, num_classes), nn.Softmax(dim=1))

    def forward(self, img):
        extracted = self.features(img)
        flat = extracted.view(extracted.size(0), -1)
        return self.adv_layer(flat), self.aux_layer(flat)


In [4]:
# Attention building blocks and the attention-augmented autoencoder for MNIST
class ChannelAttention(nn.Module):
    """Channel attention gate: rescales each channel by a learned weight in (0, 1).

    Globally average- and max-pooled channel descriptors go through a shared
    bottleneck MLP; the two results are summed and squashed with a sigmoid.
    """

    def __init__(self, in_channels, reduction_ratio=16):
        """
        Args:
            in_channels: Number of channels of the input feature map.
            reduction_ratio: Bottleneck reduction factor of the shared MLP.
        """
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # Keep at least one hidden unit: with fewer channels than the ratio the
        # integer division gives 0, and a zero-width bottleneck makes the gate a
        # constant sigmoid(0) = 0.5 regardless of the input.
        hidden_channels = max(1, in_channels // reduction_ratio)

        self.fc = nn.Sequential(
            nn.Linear(in_channels, hidden_channels, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_channels, in_channels, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` scaled per channel; the output has the same shape as ``x``."""
        b, c, _, _ = x.size()

        # Descriptor from global average pooling
        avg_out = self.fc(self.avg_pool(x).view(b, c))
        # Descriptor from global max pooling
        max_out = self.fc(self.max_pool(x).view(b, c))

        out = avg_out + max_out
        return self.sigmoid(out).view(b, c, 1, 1) * x

class SpatialAttention(nn.Module):
    """Spatial attention gate: rescales every pixel by a learned weight in (0, 1).

    The per-pixel channel mean and channel max are stacked into a 2-channel map,
    convolved down to one channel and passed through a sigmoid.
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        assert kernel_size in (3, 7), '空间注意力核大小必须为3或7'
        # "Same" padding for the two supported kernel sizes
        padding = 1 if kernel_size != 7 else 3

        self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # The unpacking itself is the point: it raises unless x has exactly four dimensions
        _, _, _, _ = x.size()

        # Collapse the channel axis twice: once by mean, once by max
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_max = x.max(dim=1, keepdim=True)[0]

        # Stack both summaries and turn them into a single-channel attention map
        stacked = torch.cat((channel_mean, channel_max), dim=1)
        attention = self.sigmoid(self.conv(stacked))

        return attention * x

class Autoencoder(nn.Module):
    """Convolutional autoencoder for 1x28x28 images with attention after every conv stage.

    The encoder maps an image to a ``latent_dim``-sized vector; the decoder maps
    such a vector back to a 1x28x28 image in [-1, 1] (Tanh output).
    """

    def __init__(self, latent_dim=100):
        super().__init__()
        self.latent_dim = latent_dim

        def conv_stage(c_in, c_out, stride):
            """3x3 conv -> batch norm -> LeakyReLU(0.2)."""
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, 3, stride=stride, padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            )

        # ---------------- Encoder ----------------
        self.encoder_block1 = conv_stage(1, 16, stride=2)    # 28x28 -> 14x14
        self.ca1 = ChannelAttention(16)
        self.sa1 = SpatialAttention()

        self.encoder_block2 = conv_stage(16, 32, stride=2)   # 14x14 -> 7x7
        self.ca2 = ChannelAttention(32)
        self.sa2 = SpatialAttention()

        self.encoder_block3 = conv_stage(32, 64, stride=1)   # 7x7 -> 7x7
        self.ca3 = ChannelAttention(64)
        self.sa3 = SpatialAttention()

        # Flattened 64x7x7 feature map -> latent vector
        self.fc = nn.Linear(64 * 7 * 7, latent_dim)

        # ---------------- Decoder ----------------
        # Latent vector -> flattened 64x7x7 feature map
        self.decoder_input = nn.Linear(latent_dim, 64 * 7 * 7)

        # Same conv stage as above, preceded by a batch norm on the projected map
        self.decoder_block1 = nn.Sequential(
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 32, 3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.ca4 = ChannelAttention(32)
        self.sa4 = SpatialAttention()

        self.upsample1 = nn.Upsample(scale_factor=2)         # 7x7 -> 14x14

        self.decoder_block2 = conv_stage(32, 16, stride=1)
        self.ca5 = ChannelAttention(16)
        self.sa5 = SpatialAttention()

        self.upsample2 = nn.Upsample(scale_factor=2)         # 14x14 -> 28x28

        self.output_layer = nn.Sequential(
            nn.Conv2d(16, 1, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def encode(self, img):
        """Encode a batch of images into latent vectors of size ``latent_dim``."""
        encoder_stages = (
            (self.encoder_block1, self.ca1, self.sa1),
            (self.encoder_block2, self.ca2, self.sa2),
            (self.encoder_block3, self.ca3, self.sa3),
        )
        x = img
        # Each stage: convolution, then channel attention, then spatial attention
        for conv_block, channel_att, spatial_att in encoder_stages:
            x = spatial_att(channel_att(conv_block(x)))

        flat = x.view(x.shape[0], -1)
        return self.fc(flat)

    def decode(self, z):
        """Decode a batch of latent vectors back into images."""
        x = self.decoder_input(z)
        x = x.view(x.shape[0], 64, 7, 7)

        decoder_stages = (
            (self.decoder_block1, self.ca4, self.sa4, self.upsample1),
            (self.decoder_block2, self.ca5, self.sa5, self.upsample2),
        )
        # Each stage: convolution, channel attention, spatial attention, then x2 upsampling
        for conv_block, channel_att, spatial_att, upsample in decoder_stages:
            x = upsample(spatial_att(channel_att(conv_block(x))))

        return self.output_layer(x)

    def forward(self, img):
        """Reconstruct ``img`` by encoding and immediately decoding it."""
        return self.decode(self.encode(img))

In [5]:
def create_balanced_mini_batches(dataset, batch_size, num_classes):
    """Group a labelled dataset into class-balanced mini-batches.

    Every batch holds ``batch_size // num_classes`` samples of each class. The
    number of batches is the size of the largest class divided by that per-class
    share. Samples of the largest class(es) are drawn without replacement across
    batches; other classes are re-drawn for every batch, with replacement when a
    class has fewer samples than its per-batch share.

    Args:
        dataset: Iterable of ``(image, label)`` pairs whose labels are integers
            (or integer tensors) in ``range(num_classes)``.
        batch_size: Total number of samples per batch.
        num_classes: Number of classes.

    Returns:
        A list of batches; each batch is a list of ``(image, label, dataset_index)``
        tuples ordered class by class.

    Raises:
        ValueError: If ``num_classes`` is not positive, if ``batch_size`` is
            smaller than ``num_classes``, or if a class has no samples while
            at least one batch has to be built.
    """
    if num_classes < 1:
        raise ValueError(f"num_classes must be at least 1, got {num_classes}")

    # Samples per class per batch (M/N)
    samples_per_class_per_batch = batch_size // num_classes
    if samples_per_class_per_batch < 1:
        # Would otherwise surface below as an opaque division by zero
        raise ValueError(
            f"batch_size ({batch_size}) must be at least num_classes ({num_classes})"
        )

    # Bucket the samples by class; int() also accepts integer tensor labels
    samples_by_class = [[] for _ in range(num_classes)]
    for idx, (img, label) in enumerate(dataset):
        samples_by_class[int(label)].append((img, label, idx))

    samples_count = [len(samples) for samples in samples_by_class]
    max_class_size = max(samples_count)

    # Number of mini-batches (K), driven by the largest class
    num_batches = max_class_size // samples_per_class_per_batch

    if num_batches > 0:
        empty_classes = [j for j, count in enumerate(samples_count) if count == 0]
        if empty_classes:
            raise ValueError(f"cannot build balanced batches: no samples for classes {empty_classes}")

    balanced_batches = []
    for _ in range(num_batches):
        batch = []
        for j, class_samples in enumerate(samples_by_class):
            if len(class_samples) < samples_per_class_per_batch:
                # Too few samples left in this class: draw with replacement
                selected_samples = random.choices(class_samples, k=samples_per_class_per_batch)
            else:
                # Enough samples: draw distinct positions
                selected_positions = random.sample(range(len(class_samples)), samples_per_class_per_batch)
                selected_samples = [class_samples[pos] for pos in selected_positions]

                # The largest class is consumed so each of its samples is used at most once
                if samples_count[j] == max_class_size:
                    # Pop from the back first so earlier positions stay valid
                    for pos in sorted(selected_positions, reverse=True):
                        class_samples.pop(pos)

            batch.extend(selected_samples)
        balanced_batches.append(batch)
    return balanced_batches
# BAGAN: class-conditional GAN whose latent space is seeded by a pretrained autoencoder
class BAGAN:
    """Balancing GAN for the two-class (digit 4 -> 0, digit 3 -> 1) imbalanced MNIST task.

    Workflow: ``pretrain_autoencoder`` estimates a per-class mean/variance of the
    autoencoder's latent codes and copies the means into the generator's label
    embedding; ``train`` or ``train_gan_with_balanced_batches`` then trains the
    generator and discriminator; ``generate_balanced_dataset`` writes synthetic
    images for under-represented classes; ``evaluate_model`` reports the
    discriminator's classification accuracy.
    """

    def __init__(self, latent_dim=100, batch_size=64, root='./data', imbalance_ratio=0.005, num_classes=2):
        """
        Args:
            latent_dim: Size of the latent vectors.
            batch_size: Batch size used for loading, training and generation.
            root: Data directory passed to the dataset.
            imbalance_ratio: Passed to ``ImbalancedMNIST``.
            num_classes: Accepted for call-site compatibility but ignored; the
                pipeline always works with two classes.
        """
        self.latent_dim = latent_dim
        self.batch_size = batch_size
        self.imbalance_ratio = imbalance_ratio
        self.num_classes = 2  # fixed: 0 (original digit 4) and 1 (original digit 3)
        # Single device reference used by every method (the module-level `device`)
        self.device = device

        # MNIST is single-channel, so one mean/std value is enough; maps pixels to [-1, 1]
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ])

        # Imbalanced MNIST restricted to digits 3 and 4
        self.dataset = ImbalancedMNIST(
            root=root,
            train=True,
            transform=self.transform,
            download=True,
            imbalance_ratio=imbalance_ratio,
            num_classes= 2,
        )

        self.autoencoder = Autoencoder(latent_dim).to(self.device)
        self.generator = Generator(latent_dim, self.num_classes).to(self.device)
        self.discriminator = Discriminator(self.num_classes).to(self.device)

        # Neutral latent statistics (zero mean, unit variance) so that training and
        # sampling also work before pretrain_autoencoder has estimated the real ones.
        self.latent_means = torch.zeros(self.num_classes, self.latent_dim).to(self.device)
        self.latent_vars = torch.ones(self.num_classes, self.latent_dim).to(self.device)

        self.class_counts = self._get_class_distribution()
        print(f"Class Distribution: {self.class_counts}")

        # Per-sample weights (max class count / own class count) for balanced sampling
        self.weights = self._compute_weights()

    # ------------------------------------------------------------------ helpers

    def _dataset_labels(self):
        """Return the integer label of every dataset sample, reading the dataset only once.

        Reading a label requires loading the sample, so the list is cached and
        shared by the class-distribution, weighting and per-class loader code.
        """
        if getattr(self, '_labels_cache', None) is None:
            # int() accepts both plain integers and integer tensors
            self._labels_cache = [int(label) for _, label in self.dataset]
        return self._labels_cache

    @staticmethod
    def _auxiliary_loss(class_probs, targets):
        """Negative log-likelihood of ``targets`` under already-normalised class probabilities.

        The discriminator's class head ends in a softmax. ``nn.CrossEntropyLoss``
        expects raw logits and applies log-softmax itself, so using it on those
        probabilities would normalise twice and flatten the gradients. Taking the
        log of the probabilities and applying NLL gives the intended cross entropy.
        """
        log_probs = torch.log(torch.clamp(class_probs, min=1e-8))  # clamp avoids log(0)
        return nn.functional.nll_loss(log_probs, targets)

    def _shift_to_class(self, z, labels):
        """Scale/shift standard-normal rows of ``z`` to the latent statistics of ``labels``."""
        return z * torch.sqrt(self.latent_vars[labels]) + self.latent_means[labels]

    def _sample_generator_inputs(self, n_samples):
        """Draw ``n_samples`` random labels with matching class-conditional latent vectors.

        Returns:
            ``(z, labels)`` with ``z`` of shape (n_samples, latent_dim).
        """
        z = torch.randn(n_samples, self.latent_dim).to(self.device)
        labels = torch.randint(0, self.num_classes, (n_samples,)).to(self.device)
        return self._shift_to_class(z, labels), labels

    def _get_class_distribution(self):
        """Return a Counter mapping each class label to its number of samples."""
        return Counter(self._dataset_labels())

    def _compute_weights(self):
        """Return one weight per sample: largest class count divided by the sample's class count."""
        if not self.class_counts:
            return []
        max_count = max(self.class_counts.values())
        return [max_count / self.class_counts[label] for label in self._dataset_labels()]

    def _create_dataloaders(self):
        """Build a shuffled loader over the whole dataset and one shuffled loader per class.

        Returns:
            ``(dataloader, class_loaders)`` where ``class_loaders`` maps a class
            index to its loader; classes without samples are omitted.
        """
        dataloader = DataLoader(
            self.dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=4
        )

        labels = self._dataset_labels()
        class_loaders = {}
        for class_idx in range(self.num_classes):
            indices = [i for i, y in enumerate(labels) if y == class_idx]
            if indices:  # skip classes that have no samples
                class_loaders[class_idx] = DataLoader(
                    Subset(self.dataset, indices),
                    batch_size=self.batch_size,
                    shuffle=True,
                    num_workers=4
                )

        return dataloader, class_loaders

    # ------------------------------------------------------- autoencoder stage

    def pretrain_autoencoder(self, epochs=50, lr=0.0002):
        """Train the autoencoder class by class and record per-class latent statistics.

        After every pass over a class, ``latent_means``/``latent_vars`` of that
        class are overwritten with the mean/variance of the latent codes seen in
        that pass, so the stored values come from the last epoch. Finally the
        means are copied into the generator's label embedding.

        Args:
            epochs: Number of training epochs.
            lr: Adam learning rate.
        """
        print("预训练自动编码器...")

        _, class_loaders = self._create_dataloaders()

        optimizer = optim.Adam(self.autoencoder.parameters(), lr=lr, betas=(0.5, 0.999))
        criterion = nn.MSELoss()

        # Start from neutral statistics: zero mean, unit variance
        self.latent_means = torch.zeros(self.num_classes, self.latent_dim).to(self.device)
        self.latent_vars = torch.ones(self.num_classes, self.latent_dim).to(self.device)

        self.autoencoder.train()
        for epoch in range(epochs):
            total_loss = 0
            samples_count = 0

            for class_idx, loader in class_loaders.items():
                class_latent_vectors = []

                for imgs, _ in loader:
                    imgs = imgs.to(self.device)

                    optimizer.zero_grad()

                    latent = self.autoencoder.encode(imgs)
                    reconstructed = self.autoencoder.decode(latent)

                    # Keep the codes (without graph) for this class's statistics
                    class_latent_vectors.append(latent.detach())

                    loss = criterion(reconstructed, imgs)
                    loss.backward()
                    optimizer.step()

                    total_loss += loss.item() * imgs.size(0)
                    samples_count += imgs.size(0)

                if class_latent_vectors:
                    class_latent = torch.cat(class_latent_vectors, dim=0)
                    self.latent_means[class_idx] = class_latent.mean(dim=0)
                    # The sample variance of a single vector is NaN; keep the
                    # previous value in that case instead of poisoning the noise.
                    if class_latent.size(0) > 1:
                        self.latent_vars[class_idx] = class_latent.var(dim=0)

            avg_loss = total_loss / samples_count if samples_count > 0 else 0
            print(f"Epoch [{epoch+1}/{epochs}] Autoencoder Loss: {avg_loss:.4f}")

        print("将自动编码器知识转移到生成器...")
        self._init_generator_from_autoencoder()

    def _init_generator_from_autoencoder(self):
        """Initialise the generator's label embedding with the per-class latent means."""
        with torch.no_grad():
            for class_idx in range(self.num_classes):
                self.generator.label_emb.weight[class_idx] = self.latent_means[class_idx]

    # ---------------------------------------------------------------- GAN stage

    def train_gan_with_balanced_batches(self, epochs = 100, sample_interval=200):
        """Train the GAN on class-balanced mini-batches rebuilt at the start of every epoch.

        The discriminator is updated first, then the generator. Real/fake targets
        are smoothed with uniform noise in [0, 0.1).

        Args:
            epochs: Number of training epochs.
            sample_interval: Save a sample figure every this many batches.
        """
        adversarial_loss = torch.nn.BCELoss()
        auxiliary_loss = self._auxiliary_loss

        optimizer_G = torch.optim.Adam(self.generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
        optimizer_D = torch.optim.Adam(self.discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

        # Earlier calls (generation, evaluation) may have left the networks in eval mode
        self.generator.train()
        self.discriminator.train()

        for epoch in range(epochs):
            # Use the configured batch size rather than a hard-coded one
            balanced_batches = create_balanced_mini_batches(self.dataset, self.batch_size, self.num_classes)

            for i, batch in enumerate(balanced_batches):
                real_imgs = torch.stack([img for img, _, _ in batch]).to(self.device)
                labels = torch.tensor([int(label) for _, label, _ in batch]).to(self.device)

                n_samples = real_imgs.size(0)

                # Real/fake targets with label smoothing for stability
                valid = torch.ones(n_samples, 1).to(self.device)
                fake = torch.zeros(n_samples, 1).to(self.device)
                valid = valid - 0.1 * torch.rand(valid.shape).to(self.device)
                fake = fake + 0.1 * torch.rand(fake.shape).to(self.device)

                # ---------------------
                #  Train Discriminator
                # ---------------------
                optimizer_D.zero_grad()

                real_pred, real_aux = self.discriminator(real_imgs)
                d_real_loss = 0.5 * (adversarial_loss(real_pred, valid) + auxiliary_loss(real_aux, labels))

                z, gen_labels = self._sample_generator_inputs(n_samples)
                gen_imgs = self.generator(z, gen_labels)

                # detach(): the discriminator step must not back-propagate into the generator
                fake_pred, fake_aux = self.discriminator(gen_imgs.detach())
                d_fake_loss = 0.5 * (adversarial_loss(fake_pred, fake) + auxiliary_loss(fake_aux, gen_labels))

                d_loss = 0.5 * (d_real_loss + d_fake_loss)
                d_loss.backward()
                optimizer_D.step()

                # -----------------
                #  Train Generator
                # -----------------
                optimizer_G.zero_grad()

                z, gen_labels = self._sample_generator_inputs(n_samples)
                gen_imgs = self.generator(z, gen_labels)

                validity, pred_label = self.discriminator(gen_imgs)
                g_loss = 0.5 * (adversarial_loss(validity, valid) + auxiliary_loss(pred_label, gen_labels))

                g_loss.backward()
                optimizer_G.step()

                if i % 20 == 0:
                    print(
                        f"[Epoch {epoch}/{epochs}] [Batch {i}/{len(balanced_batches)}] "
                        f"[D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]"
                    )

                batches_done = epoch * len(balanced_batches) + i
                if batches_done % sample_interval == 0:
                    self.sample_images(batches_done)

    def train(self, epochs=200, lr=0.0002, b1=0.5, b2=0.999, sample_interval=200):
        """Train the GAN on the ordinary (unbalanced) shuffled data loader.

        The generator is updated first, then the discriminator, once per batch.

        Args:
            epochs: Number of training epochs.
            lr: Adam learning rate for both networks.
            b1: Adam beta1.
            b2: Adam beta2.
            sample_interval: Save a sample figure every this many batches.
        """
        print("开始训练BAGAN...")

        dataloader, _ = self._create_dataloaders()

        adversarial_loss = nn.BCELoss()
        auxiliary_loss = self._auxiliary_loss

        optimizer_G = optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))
        optimizer_D = optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))

        # Earlier calls (generation, evaluation) may have left the networks in eval mode
        self.generator.train()
        self.discriminator.train()

        for epoch in range(epochs):
            for i, (real_imgs, labels) in enumerate(dataloader):
                n_samples = real_imgs.size(0)

                real_imgs = real_imgs.to(self.device)
                labels = labels.to(self.device)

                valid = torch.ones(n_samples, 1).to(self.device)
                fake = torch.zeros(n_samples, 1).to(self.device)

                # -----------------
                #  Train Generator
                # -----------------
                optimizer_G.zero_grad()

                z, gen_labels = self._sample_generator_inputs(n_samples)
                gen_imgs = self.generator(z, gen_labels)

                validity, pred_label = self.discriminator(gen_imgs)
                g_loss = 0.5 * (adversarial_loss(validity, valid) + auxiliary_loss(pred_label, gen_labels))

                g_loss.backward()
                optimizer_G.step()

                # ---------------------
                #  Train Discriminator
                # ---------------------
                optimizer_D.zero_grad()

                real_pred, real_aux = self.discriminator(real_imgs)
                d_real_loss = 0.5 * (adversarial_loss(real_pred, valid) + auxiliary_loss(real_aux, labels))

                # detach(): the discriminator step must not back-propagate into the generator
                fake_pred, fake_aux = self.discriminator(gen_imgs.detach())
                d_fake_loss = 0.5 * (adversarial_loss(fake_pred, fake) + auxiliary_loss(fake_aux, gen_labels))

                d_loss = 0.5 * (d_real_loss + d_fake_loss)

                d_loss.backward()
                optimizer_D.step()

                if i % 50 == 0:
                    print(
                        f"[Epoch {epoch}/{epochs}] [Batch {i}/{len(dataloader)}] "
                        f"[D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]"
                    )

                batches_done = epoch * len(dataloader) + i
                if batches_done % sample_interval == 0:
                    self.sample_images(batches_done)

    # ------------------------------------------------------------------ outputs

    def sample_images(self, batches_done):
        """Generate one image per class and save the figure.

        The file is written to ``bagan_mnist_binary_samples/sample_<batches_done>.png``.
        """
        # One row, one column per class; squeeze=False keeps axs two-dimensional
        n_row, n_col = 1, self.num_classes
        fig, axs = plt.subplots(n_row, n_col, figsize=(n_col * 2, n_row * 2), squeeze=False)

        with torch.no_grad():
            for class_idx in range(self.num_classes):
                z = torch.randn(1, self.latent_dim).to(self.device)
                label = torch.tensor([class_idx], device=self.device)
                z = self._shift_to_class(z, label)

                gen_img = self.generator(z, label)

                img = gen_img[0].cpu().detach().numpy()
                img = (img + 1) / 2  # [-1, 1] -> [0, 1]
                img = img.reshape(28, 28)

                # Class 0 stands for the original digit 4, every other class for digit 3
                original_digit = "4" if class_idx == 0 else "3"
                ax = axs[0, class_idx]
                ax.imshow(img, cmap='gray')
                ax.set_title(f"Class {class_idx} (Original Digit {original_digit})")
                ax.axis('off')

        fig.tight_layout()

        save_dir = "bagan_mnist_binary_samples"
        os.makedirs(save_dir, exist_ok=True)

        fig.savefig(f"{save_dir}/sample_{batches_done}.png")
        plt.close(fig)

    def generate_balanced_dataset(self, samples_per_class=1000, output_dir="./augmented_mnist_binary"):
        """Generate synthetic images for every class with fewer than ``samples_per_class`` real samples.

        Images are saved as ``<output_dir>/<class_idx>/gen_<n>.png``. The printed
        final statistics count the PNG files found in those directories.

        Args:
            samples_per_class: Target number of samples per class.
            output_dir: Root directory for the generated images.
        """
        print(f"为每个类别生成 {samples_per_class} 个样本...")

        os.makedirs(output_dir, exist_ok=True)
        for class_idx in range(self.num_classes):
            os.makedirs(os.path.join(output_dir, str(class_idx)), exist_ok=True)

        # Human-readable class names for the log output
        class_name = {0: "Number 4(Class 0)", 1: "Number 3(Class 1)"}

        if self.num_classes != 2:
            print(f"警告: 当前设置为{self.num_classes}个类别,而不是预期的2个类别")

        self.generator.eval()
        with torch.no_grad():
            for class_idx in range(self.num_classes):
                if class_idx >= 2:  # only classes 0 and 1 are handled
                    print(f"跳过类别 {class_idx}，因为当前是二分类模式")
                    continue

                real_samples = self.class_counts.get(class_idx, 0)
                if real_samples >= samples_per_class:
                    print(f"{class_name[class_idx]} 已经有 {real_samples} 个样本，不需要增强")
                    continue

                to_generate = samples_per_class - real_samples
                print(f"为{class_name[class_idx]}生成 {to_generate} 个额外样本")

                batch_size = min(self.batch_size, to_generate)
                num_batches = (to_generate + batch_size - 1) // batch_size  # ceiling division

                for batch in range(num_batches):
                    # The last batch may be smaller than the others
                    current_batch_size = min(batch_size, to_generate - batch * batch_size)

                    z = torch.randn(current_batch_size, self.latent_dim).to(self.device)
                    labels = torch.full((current_batch_size,), class_idx, dtype=torch.long).to(self.device)
                    z = self._shift_to_class(z, labels)

                    gen_imgs = self.generator(z, labels)

                    # [-1, 1] -> [0, 255] for the whole batch at once
                    pixels = gen_imgs.cpu().detach().numpy()
                    pixels = ((pixels + 1) / 2).reshape(-1, 28, 28) * 255
                    pixels = pixels.astype(np.uint8)

                    for idx, img_array in enumerate(pixels):
                        img_idx = batch * batch_size + idx
                        img = Image.fromarray(img_array, mode='L')  # greyscale
                        img.save(os.path.join(output_dir, str(class_idx), f"gen_{img_idx}.png"))

        # Count the PNG files now present for each of the two classes
        total_samples = {0: 0, 1: 0}
        for class_idx in range(2):
            class_dir = os.path.join(output_dir, str(class_idx))
            if os.path.exists(class_dir):
                files = [f for f in os.listdir(class_dir) if f.endswith('.png')]
                total_samples[class_idx] = len(files)

        print(f"数据增强完成！增强后的数据集保存在 {output_dir}")
        print(f"最终数据集统计：")
        print(f"- {class_name[0]}: {total_samples[0]} 个样本")
        print(f"- {class_name[1]}: {total_samples[1]} 个样本")

    def evaluate_model(self, test_loader):
        """Report the discriminator's class-prediction accuracy on ``test_loader``.

        Args:
            test_loader: Iterable of ``(images, labels)`` batches with labels in {0, 1}.

        Returns:
            Overall accuracy in percent, or ``None`` when the loader yields no samples.
        """
        self.discriminator.eval()
        correct = 0
        total = 0

        class_name = {0: "Number 4(Class 0)", 1: "Number 3(Class 1)"}
        class_correct = Counter()
        class_total = Counter()

        with torch.no_grad():
            for imgs, labels in test_loader:
                imgs = imgs.to(self.device)
                labels = labels.to(self.device)

                _, class_probs = self.discriminator(imgs)
                _, predicted = torch.max(class_probs, 1)
                hits = predicted == labels

                total += labels.size(0)
                correct += hits.sum().item()

                # Per-class tallies computed on whole tensors instead of sample by sample
                for class_idx in range(self.num_classes):
                    in_class = labels == class_idx
                    class_total[class_idx] += int(in_class.sum().item())
                    class_correct[class_idx] += int((hits & in_class).sum().item())

        if total == 0:
            # Nothing to evaluate; avoid dividing by zero
            print("测试集为空，无法计算准确率")
            return None

        accuracy = 100 * correct / total
        print(f"在测试集上的总体准确率: {accuracy:.2f}%")

        for class_idx in range(self.num_classes):
            if class_total[class_idx] > 0:
                class_acc = 100 * class_correct[class_idx] / class_total[class_idx]
                print(f"{class_name[class_idx]}准确率: {class_acc:.2f}% ({class_correct[class_idx]}/{class_total[class_idx]})")

        return accuracy

In [6]:
# Example usage
def main():
    """Run the whole pipeline: pretrain, train, augment, evaluate and plot samples.

    Writes ``mnist_binary_bagan_samples.png`` (one generated image per class) and
    ``mnist_binary_bagan_multiple_samples.png`` (five generated images per class).
    """
    latent_dim = 100
    batch_size = 64
    imbalance_ratio = 0.005  # keep only 0.5% of the minority-class (digit 4) samples

    bagan = BAGAN(
        latent_dim=latent_dim,
        batch_size=batch_size,
        root='./data',
        imbalance_ratio=imbalance_ratio,
        num_classes=2  # binary task, stated explicitly
    )

    bagan.pretrain_autoencoder(epochs=30)

    # Alternative: bagan.train(epochs=50, sample_interval=500) trains on the
    # ordinary (unbalanced) loader instead of balanced mini-batches.
    bagan.train_gan_with_balanced_batches(epochs=50, sample_interval=500)

    bagan.generate_balanced_dataset(samples_per_class=1000)

    # MNIST test split restricted to digits 3 and 4
    mnist_test = datasets.MNIST(
        root='./data',
        train=False,
        transform=bagan.transform,
        download=True
    )

    idx = (mnist_test.targets == 3) | (mnist_test.targets == 4)
    mnist_test.data = mnist_test.data[idx]
    mnist_test.targets = mnist_test.targets[idx]
    # Remap labels like the training set: digit 4 -> class 0, digit 3 -> class 1
    mnist_test.targets = (mnist_test.targets == 3).long()

    test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=False)

    bagan.evaluate_model(test_loader)

    class_name = {0: "Number 4(Class 0)", 1: "Number 3(Class 1)"}

    def generate_digit(class_idx):
        """Return one generated 28x28 image of ``class_idx`` with values in [0, 1]."""
        z = torch.randn(1, latent_dim).to(device)
        label = torch.tensor([class_idx], device=device)
        # Move the noise to the latent statistics of the requested class
        z = z * torch.sqrt(bagan.latent_vars[class_idx]) + bagan.latent_means[class_idx]
        gen_img = bagan.generator(z, label)
        img = gen_img[0].cpu().detach().numpy()
        img = (img + 1) / 2  # [-1, 1] -> [0, 1]
        return img.reshape(28, 28)

    # Sample in inference mode regardless of what the previous steps left behind
    bagan.generator.eval()

    # Figure 1: one generated sample per class
    fig, axes = plt.subplots(1, 2, figsize=(6, 3))
    with torch.no_grad():
        for i in range(2):
            axes[i].imshow(generate_digit(i), cmap='gray')
            axes[i].set_title(f"{class_name[i]}")
            axes[i].axis('off')

    fig.tight_layout()
    fig.savefig("mnist_binary_bagan_samples.png")
    plt.close(fig)

    # Figure 2: several generated samples per class, one row per class
    n_samples = 5
    fig, axes = plt.subplots(2, n_samples, figsize=(n_samples*2, 4))

    with torch.no_grad():
        for class_idx in range(2):
            for j in range(n_samples):
                ax = axes[class_idx, j]
                ax.imshow(generate_digit(class_idx), cmap='gray')
                # Hide ticks and frame but leave the axis on: axis('off') would
                # also hide the row label that is set on the first column below.
                ax.set_xticks([])
                ax.set_yticks([])
                for spine in ax.spines.values():
                    spine.set_visible(False)

            # Label each row with its class
            axes[class_idx, 0].set_ylabel(class_name[class_idx])

    fig.tight_layout()
    fig.savefig("mnist_binary_bagan_multiple_samples.png")
    plt.close(fig)

# Run the example only when executed directly (this includes a notebook cell)
if __name__ == "__main__":
    main()

Class Distribution: Counter({1: 6131, 0: 29})
预训练自动编码器...
Epoch [1/30] Autoencoder Loss: 0.5766
Epoch [2/30] Autoencoder Loss: 0.2345
Epoch [3/30] Autoencoder Loss: 0.1143
Epoch [4/30] Autoencoder Loss: 0.0723
Epoch [5/30] Autoencoder Loss: 0.0527
Epoch [6/30] Autoencoder Loss: 0.0416
Epoch [7/30] Autoencoder Loss: 0.0346
Epoch [8/30] Autoencoder Loss: 0.0299
Epoch [9/30] Autoencoder Loss: 0.0262
Epoch [10/30] Autoencoder Loss: 0.0239
Epoch [11/30] Autoencoder Loss: 0.0218
Epoch [12/30] Autoencoder Loss: 0.0200
Epoch [13/30] Autoencoder Loss: 0.0191
Epoch [14/30] Autoencoder Loss: 0.0180
Epoch [15/30] Autoencoder Loss: 0.0168
Epoch [16/30] Autoencoder Loss: 0.0160
Epoch [17/30] Autoencoder Loss: 0.0155
Epoch [18/30] Autoencoder Loss: 0.0148
Epoch [19/30] Autoencoder Loss: 0.0145
Epoch [20/30] Autoencoder Loss: 0.0138
Epoch [21/30] Autoencoder Loss: 0.0137
Epoch [22/30] Autoencoder Loss: 0.0130
Epoch [23/30] Autoencoder Loss: 0.0129
Epoch [24/30] Autoencoder Loss: 0.0127
Epoch [25/30] A