In [None]:
# 复现VPN在CIFAR10数据集上
# 数据及说明 https://blog.csdn.net/DaVinciL/article/details/78793067
# Variational Positive-incentive Noise: How Noise Benefits Models

In [11]:
#生成器为resnet18

# 读取CIFAR10数据集并加载至内存中
import numpy as np
import pickle
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import os
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import time
import warnings
from torch.nn.functional import normalize
from math import sqrt 




# 忽略警告
warnings.filterwarnings("ignore", category=RuntimeWarning)

# 检查并下载CIFAR10数据集
def download_cifar10():
    base_path = "/gemini/data-3"
    dataset_path = os.path.join(base_path, "cifar-10-batches-py")
    
    if not os.path.exists(dataset_path):
        print("Downloading CIFAR-10 dataset...")
        os.makedirs(base_path, exist_ok=True)
        
        import urllib.request
        import tarfile
        
        url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
        filename = os.path.join(base_path, "cifar-10-python.tar.gz")
        
        # 下载数据集
        urllib.request.urlretrieve(url, filename)
        print("Download complete!")
        
        # 解压文件
        with tarfile.open(filename, 'r:gz') as tar:
            tar.extractall(path=base_path)
        print(f"Dataset extracted to: {dataset_path}")
    
    return dataset_path

# 读取文件
def unpickle(file):
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
    return dict

# 提取每一个通道的数据，进行重新排列，最后返回一张32x32的3通道的图片：
def GetPhoto(pixel):
    assert len(pixel) == 3072
    r = pixel[0:1024]; r = np.reshape(r, [32, 32, 1])
    g = pixel[1024:2048]; g = np.reshape(g, [32, 32, 1])
    b = pixel[2048:3072]; b = np.reshape(b, [32, 32, 1])
    photo = np.concatenate([r, g, b], -1)
    return photo

# 按照给出的关键字提取数据 
def getTrainDataByKeyword(keyword, size=(32, 32), normalized=False, filelist=[1,2,3,4,5]):
    base_path = download_cifar10()
    
    # 使用字符串类型处理关键字
    if isinstance(keyword, bytes):
        keyword = keyword.decode('utf-8')
    
    keyword_bytes = str.encode(keyword)
    
    assert keyword_bytes in [b'data', b'labels', b'batch_label', b'filenames']
    assert type(filelist) is list and len(filelist) != 0
    assert type(normalized) is bool
    assert type(size) is tuple

    files = []
    for i in filelist:
        if 1 <= i <= 5 and i not in files:
            files.append(i)

    if len(files) == 0:
        raise ValueError("No valid input files!")

    if keyword_bytes == b'data':
        data = []
        for i in files:
            data.append(unpickle(os.path.join(base_path, f"data_batch_{i}"))[b'data'])
        data = np.concatenate(data, 0)
        array = np.ndarray([len(data), size[0], size[1], 3], dtype=np.float32)
        for i in range(len(data)):
            img = cv2.resize(GetPhoto(data[i]), size)
            if normalized:
                img = img / 255.0
            array[i] = img
        return array
    
    if keyword_bytes == b'labels':
        labels = []
        for i in files:
            labels += unpickle(os.path.join(base_path, f"data_batch_{i}"))[b'labels']
        return labels
    
    elif keyword_bytes == b'batch_label':
        batch_label = []
        for i in files:
            batch_label.append(unpickle(os.path.join(base_path, f"data_batch_{i}"))[b'batch_label'])
        return batch_label
    
    elif keyword_bytes == b'filenames':
        filenames = []
        for i in files:
            filenames += unpickle(os.path.join(base_path, f"data_batch_{i}"))[b'filenames']
        return filenames

# 提取测试集中的数据
def getTestDataByKeyword(keyword, size=(32, 32), normalized=False):
    base_path = download_cifar10()
    
    # 使用字符串类型处理关键字
    if isinstance(keyword, bytes):
        keyword = keyword.decode('utf-8')
    
    keyword_bytes = str.encode(keyword)

    assert keyword_bytes in [b'data', b'labels', b'batch_label', b'filenames']
    assert type(size) is tuple
    assert type(normalized) is bool

    test_file = os.path.join(base_path, "test_batch")
    test_data = unpickle(test_file)

    if keyword_bytes == b'data':
        data = test_data[b'data']
        array = np.ndarray([len(data), size[0], size[1], 3], dtype=np.float32)
        for i in range(len(data)):
            img = cv2.resize(GetPhoto(data[i]), size)
            if normalized:
                img = img / 255.0
            array[i] = img
        return array
    
    elif keyword_bytes == b'labels':
        return test_data[b'labels']
    
    elif keyword_bytes == b'batch_label':
        return test_data[b'batch_label']
    
    elif keyword_bytes == b'filenames':
        return test_data[b'filenames']
    
    else:
        raise NameError(f"Invalid keyword: {keyword}")

# 定义一个数据集类，用于加载CIFAR10数据集
class CIFAR10Dataset(Dataset):
    def __init__(self, data, labels, transform=None):
        self.data = data
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img = self.data[idx]
        label = self.labels[idx]

        if self.transform:
            img = self.transform(img)

        return img, label


import cv2
import os
import numpy as np
import matplotlib.pyplot as plt

def visualize_directly(originals, noises, noise_maps, noisy_images, labels, class_names, epoch):
    """
    保存可视化结果到文件
    每个样本保存: 原始图像、噪声图、噪声图像、方差热图
    同时保存四图拼接的对比图
    
    :param originals: 原始图像列表 (在[0,1]范围)
    :param noises: 噪声列表
    :param noise_maps: 噪声图列表 (单个通道的方差图)
    :param noisy_images: 添加噪声后的图像列表 (在[0,1]范围)
    :param labels: 图像标签列表
    :param class_names: 类别名称列表
    :param epoch: 当前训练轮次
    """
    # 创建主保存目录
    main_dir = f"visualizations/epoch_{epoch}"
    os.makedirs(main_dir, exist_ok=True)
    
    # 为每个类别创建子目录
    for class_name in class_names:
        class_dir = os.path.join(main_dir, class_name)
        os.makedirs(class_dir, exist_ok=True)
    
    # 处理每个样本
    for i, (orig, noise, noise_map, noisy_img, label) in enumerate(zip(
        originals, noises, noise_maps, noisy_images, labels)):
        
        # 获取类别名称
        class_idx = label.item()
        class_name = class_names[class_idx]
        
        # 为当前样本创建目录
        sample_dir = os.path.join(main_dir, class_name, f"sample_{i}")
        os.makedirs(sample_dir, exist_ok=True)
        
        # 1. 保存原始图像
        orig_np = orig.permute(1, 2, 0).cpu().numpy() * 255
        orig_np = np.clip(orig_np, 0, 255).astype(np.uint8)
        orig_bgr = cv2.cvtColor(orig_np, cv2.COLOR_RGB2BGR)
        orig_path = os.path.join(sample_dir, f"{class_name}_original.png")
        cv2.imwrite(orig_path, orig_bgr)
        
        # 2. 保存噪声图 (三通道彩色)
        noise_np = noise.permute(1, 2, 0).cpu().numpy()
        max_val = np.max(np.abs(noise_np))
        scaled_noise = noise_np / (2 * max_val) + 0.5 if max_val > 0 else 0.5
        scaled_noise = np.clip(scaled_noise, 0, 1)
        noise_rgb = (scaled_noise * 255).astype(np.uint8)
        noise_bgr = cv2.cvtColor(noise_rgb, cv2.COLOR_RGB2BGR)
        noise_path = os.path.join(sample_dir, f"{class_name}_noise.png")
        cv2.imwrite(noise_path, noise_bgr)
        
        # 3. 保存噪声图像 (添加噪声后的图像)
        noisy_np = noisy_img.permute(1, 2, 0).cpu().numpy() * 255
        noisy_np = np.clip(noisy_np, 0, 255).astype(np.uint8)
        noisy_bgr = cv2.cvtColor(noisy_np, cv2.COLOR_RGB2BGR)
        noisy_path = os.path.join(sample_dir, f"{class_name}_noisy_image.png")
        cv2.imwrite(noisy_path, noisy_bgr)
        
        # 4. 保存方差热图
        nv_map = noise_map.squeeze().cpu().numpy()
        nv_map = np.nan_to_num(nv_map)
        if np.max(nv_map) - np.min(nv_map) == 0:
            nv_map_normalized = np.zeros_like(nv_map)
        else:
            nv_map_normalized = (nv_map - np.min(nv_map)) / (np.max(nv_map) - np.min(nv_map))
        
        heatmap = cv2.applyColorMap((nv_map_normalized * 255).astype(np.uint8), cv2.COLORMAP_JET)
        heatmap_path = os.path.join(sample_dir, f"{class_name}_variance_heatmap.png")
        cv2.imwrite(heatmap_path, heatmap)
        
        # 5. 创建并保存四图拼接图
        
        create_grid_figure(class_name, epoch, sample_dir, 
                          [orig_bgr, noise_bgr, heatmap, noisy_bgr], 
                          ["Original", "Generated Noise", "Variance Heatmap", "π-Noise Image"])
    
    print(f"Visualization saved for epoch {epoch} at: {main_dir}")





def create_grid_figure(class_name, epoch, sample_dir, images, titles):
    """
    创建并保存带标题的四图网格
    
    :param class_name: 类别名称
    :param epoch: 训练轮次
    :param sample_dir: 保存目录
    :param images: 图像列表 (BGR格式)
    :param titles: 标题列表
    """
    # 创建新图
    fig, axs = plt.subplots(1, 4, figsize=(20, 5))
    fig.suptitle(f"Epoch {epoch} - {class_name} Visualization", fontsize=16)
    
    # 设置每个子图
    for i, (img, title) in enumerate(zip(images, titles)):
        # 转换BGR为RGB
        rgb_img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        axs[i].imshow(rgb_img)
        axs[i].set_title(title, fontsize=12)
        axs[i].axis('off')
    
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    
    # 保存图像
    quad_path = os.path.join(sample_dir, f"{class_name}_grid_{epoch}.png")
    plt.savefig(quad_path, bbox_inches='tight', dpi=120)
    plt.close(fig)

   

# 类别名称 (英文避免中文字体问题)
class_names = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck'
]

# 读取训练集和测试集数据
print("Loading training data...")
train_data = getTrainDataByKeyword('data', normalized=True, filelist=[1,2,3,4,5])  # 使用全部5个训练批次
train_labels = getTrainDataByKeyword('labels', filelist=[1,2,3,4,5])

print("Loading test data...")
test_data = getTestDataByKeyword('data', normalized=True)
test_labels = getTestDataByKeyword('labels')

print(f"Training data shape: {train_data.shape}")
print(f"Test data shape: {test_data.shape}")

# 定义数据预处理转换
transform = transforms.Compose([
    transforms.ToTensor()  # 将图像转换为张量并自动缩放到[0,1]
])

# 创建数据集对象
train_dataset = CIFAR10Dataset(train_data, train_labels, transform=transform)
test_dataset = CIFAR10Dataset(test_data, test_labels, transform=transform)

# 创建数据加载器 (批量大小256，符合论文)
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False, num_workers=2)

# 定义基模型
class BaseModel(nn.Module):
    def __init__(self):
        super(BaseModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)

        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)

        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(256)
        self.relu3 = nn.ReLU()
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)

        self.fc1 = nn.Linear(256 * 4 * 4, 512)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.pool1(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)
        x = self.pool2(x)

        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu3(x)
        x = self.pool3(x)

        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        x = self.dropout(x)
        x = self.fc2(x)

        return x

# 根据论文精确实现的VPN生成器
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

class ResnetVPNGenerator(nn.Module):
    def __init__(self, num_classes=10):
        super(ResnetVPNGenerator, self).__init__()
        self.num_classes = num_classes
        self.gamma = 0.01 * (1.0 / num_classes)  # γ = 0.01 × 1/|Y|
        
        # 使用ResNet18作为特征提取器
        self.resnet = models.resnet18(pretrained=False)
        # 调整ResNet的第一层卷积以适应CIFAR-10 (32x32图像)
        self.resnet.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.resnet.maxpool = nn.Identity()  # 移除maxpool层以适应小尺寸图像
        
        # 替换最后一层全连接层以输出方差 (3*32*32)
        self.resnet.fc = nn.Linear(512, 3 * 32 * 32)
        
        # 零均值假设 (μ = 0)
        self.zero_mean = True

    def sample(self, mu, variance, num=1):
        # 扩展方差和均值的维度以进行采样
        var = variance.expand(num, *variance.size()).transpose(0, 1)
        m = mu.expand(num, *mu.size()).transpose(0, 1)
        # 生成标准正态分布的随机噪声
        epsilon = torch.randn_like(var).to(var.device)
        # 重参数化公式：noise = μ + ε * σ
        noise = var * epsilon + m
        return noise

    def forward(self, x, y):
        batch_size = x.size(0)
        
        # 2. 准备标签特征：将标签索引复制为与图像相同维度的向量
        y_replicated = y.unsqueeze(1).repeat(1, 3 * 32 * 32).float()
        label_vec = self.gamma * y_replicated  # [0~0.01]
        label_vec = label_vec.view(batch_size, 3, 32, 32)
        
        # 3. 组合特征：将图像和标签向量结合
        combined = torch.clamp(x + label_vec, 0.0, 1.0)
        
        # 1. 准备图像特征：直接使用ResNet处理图像
        img_features = self.resnet(combined)
        
        # 4. 通过ResNet生成方差
        log_var = self.resnet.fc(img_features)
        variance = torch.sigmoid(log_var)
        
        # 使用截断，限制方差在[C1,C2]范围
        variance = torch.clamp(variance, 0.01, 0.1)
        variance = (variance-C1)
        
        # 均值设为0（Zero Mean assumption）
        mu = torch.zeros_like(variance)
        
        # 5. 重参数化采样 (噪声大小m=1)
        noise = self.sample(mu, variance)
        
        # 6. 重塑噪声为图像形状
        noise = noise.view(batch_size, 3, 32, 32)
        
        # 7. 应用噪声到原始图像
        noisy_img = x + noise
        noisy_img = torch.clamp(noisy_img, 0.0, 1.0)  # 保持[0,1]范围
        
        # 8. 计算方差热图用于可视化
        noise_map = noise.var(dim=1, keepdim=True)  # 在通道维度计算方差
        noise_map = torch.clamp(noise_map, min=1e-5)
        
        return noisy_img, noise, noise_map

    

# 定义模型
base_model = BaseModel()
vpn_generator = VPNGenerator(num_classes=10, noise_m=1)

# 定义损失函数和优化器 (使用论文指定的学习率0.001)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(list(base_model.parameters()) + list(vpn_generator.parameters()), 
                      lr=0.001, weight_decay=1e-4)

# 学习率调度器 (论文未指定，但StepLR有助于收敛)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

# 训练模型
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

base_model.to(device)
vpn_generator.to(device)

# 训练轮数40 (论文指定浅层模型)
num_epochs = 100


# 根据论文精确实现的VPN损失函数
def vpn_loss_function(base_output, vpn_output, gt_labels, noise_cov, criterion, c1=0.01, c2=0.3):
    """
    计算VPN损失函数
    
    Args:
        base_output: 基模型的输出
        vpn_output: 加入VPN噪声后的基模型输出
        gt_labels: 真实标签
        noise_cov: 噪声的协方差矩阵 (方差)
        criterion: 基础损失函数 (如CrossEntropyLoss)
        c1, c2: 噪声强度约束参数
    
    Returns:
        VPN损失以及基模型损失
    """
    # 基模型损失
    base_loss = criterion(base_output, gt_labels)
    
    # VPN损失 (公式来自论文中的LVPN)
    vpn_loss = -torch.mean(torch.log_softmax(vpn_output, dim=1).gather(1, gt_labels.unsqueeze(1)).squeeze(1))
    
    # 噪声强度约束损失
    # 计算每个样本的协方差矩阵范数 (Frobenius范数)
    noise_norm = torch.norm(noise_cov.view(noise_cov.size(0), -1), p='fro', dim=1)
    # 应用约束 C1 ≤ ∥Σ∥ ≤ C2
    constraint_loss = torch.mean(torch.relu(c1 - noise_norm) + torch.relu(noise_norm - c2))
    
    # 总损失 (结合基模型损失、VPN损失和约束损失)
    total_loss = 0.7*base_loss + 0.3*vpn_loss + 0* constraint_loss
    
    return total_loss, base_loss, vpn_loss 


# 双重训练模型 (基模型和VPN生成器一起训练)
for epoch in range(num_epochs):
    base_model.train()
    vpn_generator.train()
    running_loss = 0.0
    total_samples = 0
    correct_predictions = 0
    
    # 存储样本用于可视化
    sample_originals, sample_noises, sample_noise_maps, sample_noisy_imgs, sample_labels = [], [], [], [], []
    
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        
        # 生成VPN噪声 (噪声大小m=1)
        noisy_data, noise, noise_map = vpn_generator(data, target)


                # 保存一些样本用于可视化（每epoch的第一个batch）
        if batch_idx == 0 and epoch % 5 == 0:  # 每5个epoch可视化一次
            for i in range(5):
                # 使用.detach()将张量从计算图中分离
                sample_originals.append(data[i].clone().detach())
                sample_noises.append(noise[i].clone().detach())
                sample_noise_maps.append(noise_map[i].clone().detach())
                sample_noisy_imgs.append(noisy_data[i].clone().detach())
                sample_labels.append(target[i].clone().detach())
        
        
        # 计算基模型在干净数据和噪声数据上的输出
        base_output = base_model(data)
        vpn_output = base_model(noisy_data)
        
        # 计算联合损失
        loss, base_loss, vpn_loss = vpn_loss_function(base_output, vpn_output, target, noise, criterion)
        
        # 反向传播和优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # 统计信息
        running_loss += loss.item() * data.size(0)
        _, predicted = torch.max(vpn_output, 1)
        correct_predictions += (predicted == target).sum().item()
        total_samples += data.size(0)
        
        if batch_idx % 50 == 0:
            print(f"Epoch {epoch + 1}/{num_epochs}, Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}, Base Loss: {base_loss.item():.4f}, VPN Loss: {vpn_loss.item():.4f}")
    
    scheduler.step()

        # 可视化噪声（每5个epoch）
    if epoch % 5 == 0 and len(sample_originals) > 0:
        visualize_directly(
            sample_originals,
            sample_noises,
            sample_noise_maps,
            sample_noisy_imgs,
            sample_labels,
            class_names,
            epoch
        )
    
    epoch_loss = running_loss / total_samples
    epoch_acc = correct_predictions / total_samples
    print(f"Epoch {epoch+1}/{num_epochs} - Total Loss: {epoch_loss:.4f}, Acc: {epoch_acc:.4f}")
    

# 测试模型
print("\nTesting model...")
base_model.eval()
vpn_generator.eval()
correct = 0
total = 0
test_loss = 0.0

# 存储测试样本用于可视化
test_originals, test_noises, test_noise_maps, test_noisy_imgs, test_labels = [], [], [], [], []

with torch.no_grad():
    for batch_idx, (data, target) in enumerate(test_loader):
        data, target = data.to(device), target.to(device)
        
        # 生成VPN噪声
        noisy_data, noise, noise_map = vpn_generator(data, target)
        
        # 保存一些样本用于可视化
        if batch_idx == 0:
            for i in range(5):
                # 使用.detach()将张量从计算图中分离
                test_originals.append(data[i].clone().detach())
                test_noises.append(noise[i].clone().detach())
                test_noise_maps.append(noise_map[i].clone().detach())
                test_noisy_imgs.append(noisy_data[i].clone().detach())
                test_labels.append(target[i].clone().detach())
        
        # 计算预测结果
        output = base_model(noisy_data)
        loss = criterion(output, target)
        
        test_loss += loss.item() * data.size(0)
        _, predicted = torch.max(output, 1)
        total += target.size(0)
        correct += (predicted == target).sum().item()

# 最终测试可视化
if len(test_originals) > 0:
    visualize_directly(
        test_originals,
        test_noises,
        test_noise_maps,
        test_noisy_imgs,
        test_labels,
        class_names,
        "Final"
    )

test_loss = test_loss / total
accuracy = correct / total

print(f"\nTest Results:")
print(f"Loss: {test_loss:.4f}")
print(f"Accuracy: {accuracy:.4f} ({correct}/{total})")

Loading training data...
Loading test data...
Training data shape: (50000, 32, 32, 3)
Test data shape: (10000, 32, 32, 3)
Using device: cuda
Epoch 1/100, Batch 0/196, Loss: 2.4464, Base Loss: 2.4452, VPN Loss: 2.4491
Epoch 1/100, Batch 50/196, Loss: 1.8232, Base Loss: 1.7518, VPN Loss: 1.9899
Epoch 1/100, Batch 100/196, Loss: 1.6892, Base Loss: 1.6264, VPN Loss: 1.8358
Epoch 1/100, Batch 150/196, Loss: 1.6439, Base Loss: 1.5855, VPN Loss: 1.7803
Visualization saved for epoch 0 at: visualizations/epoch_0
Epoch 1/100 - Total Loss: 2.3457, Acc: 0.3165
Epoch 2/100, Batch 0/196, Loss: 1.5784, Base Loss: 1.5239, VPN Loss: 1.7058
Epoch 2/100, Batch 50/196, Loss: 1.3313, Base Loss: 1.2584, VPN Loss: 1.5012
Epoch 2/100, Batch 100/196, Loss: 1.3743, Base Loss: 1.2528, VPN Loss: 1.6578
Epoch 2/100, Batch 150/196, Loss: 1.3447, Base Loss: 1.2445, VPN Loss: 1.5785
Epoch 2/100 - Total Loss: 1.3472, Acc: 0.4487
Epoch 3/100, Batch 0/196, Loss: 1.0485, Base Loss: 0.9291, VPN Loss: 1.3269
Epoch 3/100, B

KeyboardInterrupt: 