In [33]:
import os
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
import pickle
from PIL import Image
from conf import settings
from tqdm import tqdm

def _load_dataset_arrays(dataset, data_dir):
    """Load the full training split of ``dataset`` as ``(images, labels)`` NumPy arrays.

    "CIFAR10" / "CIFAR100" are read (downloaded if missing) from
    ``settings.DATA_PATH/<dataset>``; any other name is treated as
    Tiny-ImageNet and read with ``ImageFolder`` from
    ``<data_dir>/tiny-imagenet-200/train``.
    """
    if dataset == "CIFAR10":
        trainset = torchvision.datasets.CIFAR10(
            root=os.path.join(settings.DATA_PATH, dataset), train=True, download=True, transform=None
        )
    elif dataset == "CIFAR100":
        trainset = torchvision.datasets.CIFAR100(
            root=os.path.join(settings.DATA_PATH, dataset), train=True, download=True, transform=None
        )
    else:
        trainset = torchvision.datasets.ImageFolder(os.path.join(data_dir, "tiny-imagenet-200", "train"))

    images, labels = [], []
    for img, label in tqdm(trainset, desc="Loading dataset"):
        images.append(np.array(img))
        labels.append(label)
    return np.array(images), np.array(labels)


def _build_transform(transform_name, transform_params):
    """Build the per-image callable for one named transformation.

    The returned callable takes one uint8 image array (as produced by
    ``np.array`` on a PIL image) and returns the transformed image as a
    uint8 array.

    Raises:
        ValueError: if ``transform_name`` is not a supported transformation.
    """
    if transform_name == "gaussian_noise":
        noise_std = transform_params  # noise std on the [0, 1] pixel scale

        def add_gaussian_noise(img_array):
            img = torch.from_numpy(img_array).float() / 255.0  # normalize to [0, 1]
            img = img + torch.randn_like(img) * noise_std
            # Round before casting: .byte() truncates, which would bias pixels down by ~0.5.
            return (img * 255).round().clamp(0, 255).byte().numpy()

        return add_gaussian_noise

    if transform_name == "brightness":
        pil_transform = transforms.ColorJitter(brightness=transform_params)
    elif transform_name == "shear":
        pil_transform = transforms.RandomAffine(degrees=0, shear=transform_params)
    elif transform_name == "translate":
        pil_transform = transforms.RandomAffine(degrees=0, translate=transform_params)
    elif transform_name == "rotation":
        pil_transform = transforms.RandomRotation(degrees=transform_params)
    else:
        # Without this, an unknown name would silently reuse whatever transform was built last.
        raise ValueError(
            f"Unknown transformation {transform_name!r}; expected one of "
            "'gaussian_noise', 'brightness', 'shear', 'translate', 'rotation'"
        )

    def apply_pil_transform(img_array):
        return np.array(pil_transform(Image.fromarray(img_array)))

    return apply_pil_transform


def split_and_transform_dataset(dataset="TINYIMAGENET",
                                data_dir="/home/featurize/data",
                                intersect_proportion=0.0,
                                transformations=None,
                                seed=0,
                                output_dir="/home/featurize/work/RAI2project/RAI2data"):
    """Split a training set into victim/attacker halves and perturb the overlap.

    The shuffled training set is split into a victim subset (first half) and
    an attacker subset of the same size that shares ``intersect_proportion``
    of its samples with the victim subset. For each entry in
    ``transformations`` a separate copy of the attacker subset is produced in
    which a fraction of the shared samples is perturbed, and
    ``((X_victim, y_victim), (X_attacker, y_attacker))`` is pickled to
    ``<output_dir>/<transform_name>/<dataset>_intersect_<intersect_proportion>.pkl``.

    Args:
        dataset: "CIFAR10", "CIFAR100"; any other name loads Tiny-ImageNet
            from ``data_dir``.
        data_dir: root directory containing ``tiny-imagenet-200``.
        intersect_proportion: fraction (0.0 ~ 1.0) of the victim subset that
            also appears in the attacker subset. With 0.0 the untransformed
            split is saved into every transformation's folder.
        transformations: non-empty dict ``{name: (proportion, params)}``;
            ``proportion`` (0.0 ~ 1.0) is the fraction of shared samples to
            perturb and ``params`` is passed to the corresponding transform.
            Supported names: gaussian_noise, brightness, shear, translate,
            rotation.
        seed: seed for both the NumPy and torch global RNGs, so the split
            and the random perturbations are reproducible.
        output_dir: root directory for the per-transformation pickle files.

    Raises:
        ValueError: on empty/None ``transformations``, an unknown
            transformation name, or a proportion outside [0, 1].
    """
    # Validate everything before the slow dataset load so bad arguments fail fast.
    if not transformations:
        raise ValueError("transformations must be a non-empty dict of {name: (proportion, params)}")
    if not 0.0 <= intersect_proportion <= 1.0:
        raise ValueError(f"intersect_proportion must be in [0, 1], got {intersect_proportion}")
    built_transforms = {}
    for transform_name, (proportion, transform_params) in transformations.items():
        if not 0.0 <= proportion <= 1.0:
            raise ValueError(f"proportion for {transform_name!r} must be in [0, 1], got {proportion}")
        built_transforms[transform_name] = (proportion, _build_transform(transform_name, transform_params))

    # torch must be seeded too: the torchvision transforms and randn_like draw from torch's RNG.
    np.random.seed(seed)
    torch.manual_seed(seed)

    X_np, y_np = _load_dataset_arrays(dataset, data_dir)
    num_samples = len(y_np)

    # Victim = first half of a random permutation; the attacker window is the same
    # size but starts `shift` samples earlier, so its first `shift` entries are shared.
    vic_num = num_samples // 2
    train_idx_array = np.arange(num_samples)
    np.random.shuffle(train_idx_array)
    vic_idx = train_idx_array[:vic_num]
    shift = int(intersect_proportion * vic_num)
    start_att_idx = vic_num - shift
    att_idx = train_idx_array[start_att_idx:start_att_idx + vic_num]

    X_set1, y_set1 = X_np[vic_idx], y_np[vic_idx]
    X_set2, y_set2 = X_np[att_idx], y_np[att_idx]
    del X_np, y_np  # fancy indexing copied the subsets; release the full arrays

    for transform_name in transformations:
        os.makedirs(os.path.join(output_dir, transform_name), exist_ok=True)

    output_filename = f"{dataset}_intersect_{intersect_proportion}.pkl"

    if intersect_proportion == 0.0:
        print("No transformations applied as intersect_proportion is set to 0.0")
        for transform_name in transformations:
            output_path = os.path.join(output_dir, transform_name, output_filename)
            with open(output_path, "wb") as f:
                pickle.dump(((X_set1, y_set1), (X_set2, y_set2)), f)
            print(f"Dataset saved without transformations to {output_path}")
        return

    total_samples = shift  # number of samples shared with the victim subset
    print(f"Total selected samples from attackers: {total_samples}")

    for transform_name, (proportion, transform) in built_transforms.items():
        num_to_transform = int(total_samples * proportion)
        # Sample positions inside the attacker subset directly (same RNG draw as
        # choosing from att_idx[:shift]) instead of mapping dataset indices back
        # with a per-sample linear search.
        positions = np.random.choice(shift, size=num_to_transform, replace=False)

        transformed_X_set2 = X_set2.copy()
        print(f"{transform_name.capitalize()} will transform {num_to_transform} samples.")

        for pos in tqdm(positions, desc=f"Processing {transform_name}"):
            transformed_X_set2[pos] = transform(X_set2[pos])

        output_path = os.path.join(output_dir, transform_name, output_filename)
        with open(output_path, "wb") as f:
            pickle.dump(((X_set1, y_set1), (transformed_X_set2, y_set2)), f)

        print(f"Transformed dataset saved to {output_path} for transformation: {transform_name}")

# Example invocation (includes the rotation transform).
# Each entry is {name: (fraction of shared samples to perturb, transform parameters)}.
transformations = {
    "gaussian_noise": (1.0, 0.1),        # 100% of samples get Gaussian noise, std 0.1 (on [0, 1] scale)
    "brightness": (1.0, 0.2),            # 100% of samples: brightness factor drawn from [0.8, 1.2]
    # RandomAffine treats a 2-tuple as a (min, max) range; (15, 15) would be a constant
    # 15° shear, so use (-15, 15) for the intended ±15°.
    "shear": (1.0, (-15, 15)),           # 100% of samples sheared, range ±15°
    "translate": (1.0, (0.25, 0.25)),    # 100% of samples shifted, up to 25% horizontally and vertically
    "rotation": (1.0, (-25, 25)),        # 100% of samples rotated, range ±25°
}


if __name__ == "__main__":  # true when run as a script or inside a Jupyter cell
    split_and_transform_dataset(dataset="TINYIMAGENET",
                                intersect_proportion=1.0,
                                transformations=transformations)


Loading dataset: 100%|██████████| 100000/100000 [00:11<00:00, 8631.94it/s]


Total selected samples from attackers: 50000
Gaussian_noise will transform 50000 samples.


Processing gaussian_noise: 100%|██████████| 50000/50000 [00:07<00:00, 6723.88it/s]


Transformed dataset saved to /home/featurize/work/RAI2project/RAI2data/gaussian_noise/TINYIMAGENET_intersect_1.0.pkl for transformation: gaussian_noise
Brightness will transform 50000 samples.


Processing brightness: 100%|██████████| 50000/50000 [00:06<00:00, 7636.22it/s]


Transformed dataset saved to /home/featurize/work/RAI2project/RAI2data/brightness/TINYIMAGENET_intersect_1.0.pkl for transformation: brightness
Shear will transform 50000 samples.


Processing shear: 100%|██████████| 50000/50000 [00:04<00:00, 10362.67it/s]


Transformed dataset saved to /home/featurize/work/RAI2project/RAI2data/shear/TINYIMAGENET_intersect_1.0.pkl for transformation: shear
Translate will transform 50000 samples.


Processing translate: 100%|██████████| 50000/50000 [00:04<00:00, 10333.62it/s]


Transformed dataset saved to /home/featurize/work/RAI2project/RAI2data/translate/TINYIMAGENET_intersect_1.0.pkl for transformation: translate
Rotation will transform 50000 samples.


Processing rotation: 100%|██████████| 50000/50000 [00:04<00:00, 12397.51it/s]


Transformed dataset saved to /home/featurize/work/RAI2project/RAI2data/rotation/TINYIMAGENET_intersect_1.0.pkl for transformation: rotation
