In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, TensorDataset
import os

# Select the compute device: GPU when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model hyperparameters.
latent_dim = 128
n_classes = 2  # two classes: 0 (minority) and 1 (majority)
channel = 1  # MNIST images are single-channel

# Build the generator from its components.
decoder = Decoder(latent_dim, channel).to(device)
embedding = LabelEmbedding(n_classes, latent_dim).to(device)
generator = Generator(embedding, decoder).to(device)

# Load the trained weights.
# Assumes 50 epochs were saved, so index 49 is the last checkpoint.
model_path = 'bagan_gp_results/generator_49.pt'
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# on a GPU still loads when this runs on a CPU-only machine.
generator.load_state_dict(torch.load(model_path, map_location=device))
generator.eval()  # switch to evaluation mode




def generate_minority_samples(generator, class_label=0, num_samples=100, seed=42):
    """
    Generate samples of one class with a trained conditional generator.

    Args:
        generator: trained generator, called as ``generator(noise, labels)``.
        class_label: class to generate (0 is the minority class).
        num_samples: number of samples to generate.
        seed: value passed to ``torch.manual_seed`` before sampling the noise.
            The default (42) reproduces the same samples on every call; pass
            ``None`` to leave the global RNG untouched and get fresh samples.

    Returns:
        NumPy array holding the generator output after ``x * 0.5 + 0.5``.

    Note:
        Reads the module-level ``latent_dim`` and ``device``.
    """
    # Seeding is optional so that repeated calls can produce different samples.
    if seed is not None:
        torch.manual_seed(seed)

    # One latent noise vector and one label per requested sample.
    noise = torch.randn(num_samples, latent_dim, device=device)
    labels = torch.full((num_samples,), class_label, dtype=torch.long, device=device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        generated_images = generator(noise, labels)

    # x * 0.5 + 0.5 maps [-1, 1] onto [0, 1] for display
    # (assumes the generator output lies in [-1, 1] — TODO confirm).
    return generated_images.cpu().numpy() * 0.5 + 0.5

# Generate 100 minority-class samples
minority_samples = generate_minority_samples(generator, class_label=0, num_samples=100)



def plot_generated_samples(images, rows=10, cols=10):
    """
    Show generated images on a ``rows`` x ``cols`` grid and save the figure.

    Args:
        images: image array indexed as (N, C, H, W); C == 1 is drawn in
            grayscale, anything else is treated as channels-first colour.
        rows: number of grid rows.
        cols: number of grid columns.

    Side effects:
        Writes 'bagan_gp_results/generated_minority_samples.png' (creating
        the directory if needed) and calls ``plt.show()``.
    """
    # squeeze=False keeps `axes` a 2-D array even for 1x1 or 1xN grids, so
    # flatten() below always works.
    fig, axes = plt.subplots(rows, cols, figsize=(cols*1.5, rows*1.5), squeeze=False)

    for i, ax in enumerate(axes.flatten()):
        if i < len(images):
            if images.shape[1] == 1:  # single-channel image
                # Drop the channel axis instead of reshaping to a fixed
                # 64x64, so any image size is supported.
                ax.imshow(images[i][0], cmap='gray')
            else:  # multi-channel image: CHW -> HWC for imshow
                ax.imshow(np.transpose(images[i], (1, 2, 0)))

        # Hide the axes on every cell, including empty ones.
        ax.axis('off')

    plt.tight_layout()
    output_path = 'bagan_gp_results/generated_minority_samples.png'
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    plt.savefig(output_path)
    plt.show()

# Display the generated minority-class samples
plot_generated_samples(minority_samples)



def save_generated_samples(images, output_dir='bagan_gp_results/generated_samples'):
    """
    Save each generated image as an individual PNG file.

    Args:
        images: iterable of channels-first (C, H, W) NumPy images with values
            expected in [0, 1].
        output_dir: directory to write to; created if it does not exist.

    Files are named ``minority_sample_<index>.png``.
    """
    os.makedirs(output_dir, exist_ok=True)

    for i, img in enumerate(images):
        # Clip before the uint8 cast: values slightly outside [0, 1] would
        # otherwise wrap around (e.g. 1.01 * 255 -> 1 after the cast).
        pixels = (np.clip(img, 0.0, 1.0) * 255).astype(np.uint8)
        path = f'{output_dir}/minority_sample_{i}.png'

        if pixels.shape[0] == 1:  # single-channel image
            # Fix vmin/vmax: without them imsave stretches every image to its
            # own min/max, changing the relative brightness between samples.
            plt.imsave(path, pixels[0], cmap='gray', vmin=0, vmax=255)
        else:  # multi-channel image: CHW -> HWC; cmap does not apply to RGB
            plt.imsave(path, np.transpose(pixels, (1, 2, 0)))

# Save the generated minority-class samples
save_generated_samples(minority_samples)




def augment_dataset_with_generated_samples(X_train, y_train, generated_samples, class_label=0):
    """
    Append generated samples and their labels to an existing training set.

    Args:
        X_train: tensor of original training features.
        y_train: tensor of original training labels.
        generated_samples: generated samples (NumPy array or tensor) whose
            trailing dimensions match ``X_train``.
        class_label: label assigned to every generated sample.

    Returns:
        Tuple ``(X_augmented, y_augmented)`` with the generated samples
        concatenated after the original data along dimension 0.
    """
    # Match the dtype and device of the training tensors so that torch.cat
    # works when the data lives on a GPU and does not change the dtype.
    # as_tensor also accepts tensor input without torch.tensor's copy warning.
    gen_samples = torch.as_tensor(
        generated_samples, dtype=X_train.dtype, device=X_train.device
    )

    # One label per generated sample, in the same dtype/device as y_train.
    gen_labels = torch.full(
        (len(gen_samples),), class_label, dtype=y_train.dtype, device=y_train.device
    )

    X_augmented = torch.cat([X_train, gen_samples], dim=0)
    y_augmented = torch.cat([y_train, gen_labels], dim=0)

    return X_augmented, y_augmented

# Example usage, assuming X_train and y_train hold your original training data:
# X_augmented, y_augmented = augment_dataset_with_generated_samples(X_train, y_train, minority_samples, class_label=0)
# 
# # Build a new data loader over the augmented data
# augmented_dataset = TensorDataset(X_augmented, y_augmented)
# augmented_loader = DataLoader(augmented_dataset, batch_size=128, shuffle=True)

In [None]:
# 导入必要的库
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, TensorDataset
import os
from scipy.linalg import sqrtm
from tqdm import tqdm
import torch.nn.functional as F
from torchvision.models import inception_v3
from torchvision import transforms
from PIL import Image

# Select the compute device: GPU when available, otherwise CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# 加载Inception模型用于FID和IS计算
def load_inception_model():
    """
    Build the pretrained Inception-v3 network used as a feature extractor.

    The final fully connected layer is replaced by an identity, so the
    forward pass returns that layer's input instead of class logits. The
    network is put in evaluation mode and moved to the module-level
    ``device``.

    NOTE(review): ``pretrained=`` is reported as deprecated by newer
    torchvision releases in favour of ``weights=`` — confirm against the
    installed version.
    """
    net = inception_v3(pretrained=True, transform_input=False)
    # Replace the classification head with a pass-through.
    net.fc = torch.nn.Identity()
    # eval() and to() both return the module, so they can be chained.
    return net.eval().to(device)

# 预处理图像函数
# 修改后的预处理图像函数
def preprocess_images(images, target_size=(299, 299)):
    """
    Convert images into the input format expected by the Inception model.

    NumPy arrays and torch tensors are handled identically:

    * a single 3-D image gets a leading batch dimension;
    * a batch whose second dimension is neither 1 nor 3 is treated as NHWC
      and permuted to NCHW;
    * values above 1.0 are taken to be in [0, 255] and divided by 255;
    * single-channel images are repeated to three channels.

    Args:
        images: NumPy array or torch tensor holding one image or a batch.
        target_size: (height, width) the images are resized to.

    Returns:
        Float tensor in NCHW layout, resized to ``target_size`` and
        normalized with the mean/std below (presumably the ImageNet
        statistics — confirm).
    """
    if isinstance(images, np.ndarray):
        images = torch.from_numpy(images)
    # Normalize needs a floating-point tensor; this also covers uint8 input.
    images = images.float()

    if images.dim() == 3:  # single image -> batch of one
        images = images.unsqueeze(0)

    # Make sure the layout is NCHW.
    if images.size(1) not in (1, 3):
        images = images.permute(0, 3, 1, 2)

    # Inputs in [0, 255] are rescaled to [0, 1]; inputs already at or below
    # 1.0 are left unchanged.
    if images.numel() > 0 and images.max() > 1.0:
        images = images / 255.0

    # Inception expects three channels: replicate grayscale across RGB.
    if images.size(1) == 1:
        images = images.repeat(1, 3, 1, 1)

    preprocess = transforms.Compose([
        transforms.Resize(target_size),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    return preprocess(images)
# 计算FID分数
def _inception_features(images, model, batch_size, desc):
    """Run `model` over `images` in batches and return the stacked outputs as a NumPy array."""
    outputs = []
    with torch.no_grad():
        for start in tqdm(range(0, len(images), batch_size), desc=desc):
            batch = preprocess_images(images[start:start + batch_size]).to(device)
            outputs.append(model(batch).cpu().numpy())
    return np.vstack(outputs)


def calculate_fid(real_images, generated_images, model, batch_size=32, eps=1e-6):
    """
    Compute the Frechet Inception Distance between two image sets.

    Args:
        real_images: real images (NCHW array or tensor).
        generated_images: generated images (NCHW array or tensor).
        model: pretrained Inception-v3 feature extractor; it is switched to
            evaluation mode here.
        batch_size: number of images per forward pass.
        eps: diagonal offset added to both covariances when the matrix
            square root of their product is not finite.

    Returns:
        The FID score.

    Raises:
        ValueError: if either set has fewer than two images, because a
            covariance cannot be estimated from a single sample.
    """
    if len(real_images) < 2 or len(generated_images) < 2:
        raise ValueError("calculate_fid requires at least two images in each set")

    model.eval()

    real_features = _inception_features(
        real_images, model, batch_size, "Processing real images")
    gen_features = _inception_features(
        generated_images, model, batch_size, "Processing generated images")

    # Mean and covariance of each feature set.
    mu_real = np.mean(real_features, axis=0)
    sigma_real = np.cov(real_features, rowvar=False)

    mu_gen = np.mean(gen_features, axis=0)
    sigma_gen = np.cov(gen_features, rowvar=False)

    # FID = ||mu_r - mu_g||^2 + Tr(S_r + S_g - 2 * sqrt(S_r S_g))
    ssdiff = np.sum((mu_real - mu_gen) ** 2)
    covmean = sqrtm(sigma_real.dot(sigma_gen))

    # A (near-)singular product can produce inf/nan; retry with a small
    # offset on the diagonals to stabilise the square root.
    if not np.isfinite(covmean).all():
        offset = np.eye(sigma_real.shape[0]) * eps
        covmean = sqrtm((sigma_real + offset).dot(sigma_gen + offset))

    # Numerical error can leave a complex result; keep only the real part.
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    fid = ssdiff + np.trace(sigma_real + sigma_gen - 2 * covmean)

    return fid

# 计算IS分数
def calculate_is(generated_images, model, batch_size=32, splits=10):
    """
    Compute the Inception Score of a set of generated images.

    Args:
        generated_images: generated images (NCHW array or tensor).
        model: Inception-v3 model. If it has no ``fc`` attribute or its
            ``fc`` is an ``Identity``, a full pretrained model with the
            classification layer is loaded instead.
        batch_size: number of images per forward pass.
        splits: number of groups the predictions are divided into; capped at
            the number of images so that no group is empty.

    Returns:
        Tuple ``(mean, std)`` of the per-split scores.

    Raises:
        ValueError: if ``generated_images`` is empty.
    """
    # rel_entr(p, q) is p * log(p / q) with the 0 * log(0) = 0 convention, so
    # probabilities that underflow to zero in the softmax do not yield NaN.
    from scipy.special import rel_entr

    n_images = len(generated_images)
    if n_images == 0:
        raise ValueError("calculate_is requires at least one generated image")

    # IS needs class probabilities, i.e. the model with its final layer.
    if not hasattr(model, 'fc') or isinstance(model.fc, torch.nn.Identity):
        print("Loading full Inception model for IS calculation...")
        full_model = inception_v3(pretrained=True, transform_input=False)
        full_model.eval()
        full_model = full_model.to(device)
    else:
        full_model = model

    # Class probabilities for every generated image.
    preds = []
    with torch.no_grad():
        for start in tqdm(range(0, n_images, batch_size), desc="Calculating IS"):
            batch = preprocess_images(generated_images[start:start + batch_size])
            batch = batch.to(device)
            preds.append(F.softmax(full_model(batch), dim=1).cpu().numpy())

    preds = np.vstack(preds).astype(np.float64)

    # array_split keeps the remainder images instead of dropping them and
    # never produces an empty group once splits is capped at n_images.
    scores = []
    for part in np.array_split(preds, min(splits, n_images)):
        marginal = np.mean(part, axis=0, keepdims=True)
        # Mean KL(p(y|x) || p(y)) over the images of this split.
        kl = np.mean(np.sum(rel_entr(part, marginal), axis=1))
        scores.append(np.exp(kl))

    return float(np.mean(scores)), float(np.std(scores))

# 加载真实数据集进行FID计算
def load_real_dataset(root='/Users/max/MasterThesisData/MNIST/', digit=4, image_size=64):
    """
    Load the real MNIST training images of one digit for FID computation.

    Args:
        root: MNIST root directory; the dataset is downloaded there if it is
            missing.
        digit: MNIST digit to keep. The default 4 corresponds to minority
            class 0 of the model, per the original author's note.
        image_size: side length the 28x28 images are resized to.

    Returns:
        Float32 tensor of shape (N, 1, image_size, image_size) with values in
        [0, 1].
    """
    # Imported here so the function does not rely on `cv2` having been
    # imported elsewhere (e.g. inside the __main__ guard).
    import cv2
    from torchvision.datasets import MNIST

    # Only the raw .data/.targets tensors are read, so no transform is needed.
    dataset = MNIST(root=root, train=True, download=True)

    # Keep only the samples of the requested digit.
    indices = np.where(np.array(dataset.targets) == digit)[0]
    images = dataset.data[indices].numpy()

    # Resize every image and store it with a trailing channel axis (N, H, W, 1).
    resized_images = np.zeros((len(images), image_size, image_size, 1), dtype=np.float32)
    for i, img in enumerate(images):
        resized_images[i, ..., 0] = cv2.resize(img, (image_size, image_size))

    # Scale the 0-255 pixel values straight to [0, 1]; the earlier round trip
    # through [-1, 1] and back was mathematically the same mapping.
    scaled_images = resized_images / 255.0

    # (N, H, W, C) -> (N, C, H, W)
    return torch.from_numpy(scaled_images).permute(0, 3, 1, 2)

# 示例：评估生成的图像
def evaluate_generated_images():
    """
    Evaluate the trained generator with FID and Inception Score.

    Rebuilds the generator, loads 'bagan_gp_results/generator_49.pt',
    generates 1000 minority-class samples, compares them with up to 1000 real
    samples and writes the metrics to
    'bagan_gp_results/evaluation_metrics.json'.

    Returns:
        Tuple ``(fid_score, is_mean, is_std)``.
    """
    import json

    # Model hyperparameters.
    # NOTE: generate_minority_samples reads the module-level latent_dim and
    # device, not these locals, so the two must agree.
    latent_dim = 128
    n_classes = 2
    channel = 1

    # Build the generator from its components.
    decoder = Decoder(latent_dim, channel).to(device)
    embedding = LabelEmbedding(n_classes, latent_dim).to(device)
    generator = Generator(embedding, decoder).to(device)

    # Load the trained weights; map_location lets a checkpoint saved on a GPU
    # load on a CPU-only machine.
    model_path = 'bagan_gp_results/generator_49.pt'
    generator.load_state_dict(torch.load(model_path, map_location=device))
    generator.eval()

    # Generate minority-class samples.
    print("Generating minority class samples...")
    minority_samples = generate_minority_samples(generator, class_label=0, num_samples=1000)

    # Load the real samples.
    print("Loading real samples...")
    real_samples = load_real_dataset()

    # Evaluate on a random subset of at most 1000 real samples.
    if len(real_samples) > 1000:
        indices = np.random.choice(len(real_samples), 1000, replace=False)
        real_samples = real_samples[indices]

    # Load the Inception feature extractor.
    print("Loading Inception model...")
    inception_model = load_inception_model()

    # FID
    print("Calculating FID score...")
    fid_score = calculate_fid(real_samples, minority_samples, inception_model)
    print(f"FID Score: {fid_score:.4f}")

    # Inception Score
    print("Calculating Inception Score...")
    is_mean, is_std = calculate_is(minority_samples, inception_model)
    print(f"Inception Score: {is_mean:.4f} ± {is_std:.4f}")

    # Plain floats so that the values are JSON-serializable.
    results = {
        "FID": float(fid_score),
        "IS_mean": float(is_mean),
        "IS_std": float(is_std)
    }

    # Write the metrics to a JSON file.
    results_path = 'bagan_gp_results/evaluation_metrics.json'
    os.makedirs(os.path.dirname(results_path), exist_ok=True)
    with open(results_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=4)

    print(f"Results saved to '{results_path}'")

    return fid_score, is_mean, is_std

# Run the evaluation
if __name__ == "__main__":
    # cv2 is referenced by load_real_dataset, which has no import of its own
    import cv2
    # Seed the RNGs so that results are reproducible
    torch.manual_seed(42)
    np.random.seed(42)
    
    fid_score, is_mean, is_std = evaluate_generated_images()