In [1]:
import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset

import random
import os
import pickle
import numpy as np
from PIL import Image

In [31]:
# Where the pickled training batches live, relative to the notebook.
root = "./"
base_folder = "cifar-10-batches-py"
data_batches = [
    "data_batch_1",
    "data_batch_2",
    "data_batch_3",
    "data_batch_4",
    "data_batch_5",
]

# Size of the victim split: num_samples in total, spread evenly over 10 classes.
num_samples = 15000
per_class_count = num_samples // 10

# Gather the rows and labels of every batch before turning them into arrays.
image_rows = []
label_values = []
for batch_name in data_batches:
    batch_path = os.path.join(root, base_folder, batch_name)
    # NOTE: pickle.load can execute arbitrary code; only point this at trusted files.
    with open(batch_path, "rb") as batch_file:
        batch = pickle.load(batch_file, encoding="latin1")
    image_rows.extend(batch["data"])
    label_values.extend(batch["labels"])

data = np.array(image_rows)
labels = np.array(label_values)

# Placeholders that the split cell below fills in.
victim_data = victim_label = None
shadow_data = shadow_label = None

In [32]:
# Create a victim split. Store the rest as the shadow model data.
#
# For every class id 0-9, `per_class_count` sample indices are drawn without
# replacement for the victim set; all remaining indices of that class go to
# the shadow set.  The per-class pieces are collected in local lists and
# concatenated once at the end, so re-running this cell rebuilds both splits
# from scratch instead of appending to arrays left over from a previous run.

# Fixed seed: Restart & Run All reproduces the same split. Change the seed to
# draw a different one.
split_rng = np.random.default_rng(0)

victim_data_parts, victim_label_parts = [], []
shadow_data_parts, shadow_label_parts = [], []

for class_id in range(10):
    # Positions of every sample carrying this class label.
    indices = np.where(labels == class_id)[0]
    # replace=False guarantees no sample is picked twice for the victim set.
    victim_indices = split_rng.choice(indices, per_class_count, replace=False)
    # Everything of this class that was not chosen above goes to the shadow set.
    shadow_indices = np.setdiff1d(indices, victim_indices)

    victim_data_parts.append(data.take(victim_indices, axis=0))
    victim_label_parts.append(labels.take(victim_indices, axis=0))
    shadow_data_parts.append(data.take(shadow_indices, axis=0))
    shadow_label_parts.append(labels.take(shadow_indices, axis=0))

# Single concatenation per array; rows stay grouped by class in ascending order.
victim_data = np.concatenate(victim_data_parts, axis=0)
victim_label = np.concatenate(victim_label_parts, axis=0)
shadow_data = np.concatenate(shadow_data_parts, axis=0)
shadow_label = np.concatenate(shadow_label_parts, axis=0)

# Each split must pair exactly one label with every data row.
assert len(victim_data) == len(victim_label) == 10 * per_class_count
assert len(shadow_data) == len(shadow_label)
        

In [33]:
# Sanity check: display the dimensions of the victim split array.
victim_data.shape

(15000, 3072)

In [36]:
# Bundle each split as a dict with "data" and "labels" entries,
# mirroring the keys read from the original batch files.
victim_to_save = {"data": victim_data, "labels": victim_label}
shadow_to_save = {"data": shadow_data, "labels": shadow_label}

In [38]:
# Persist both splits as pickle files next to the original batches.
for target_path, payload in (
    ("./cifar-10-batches-py/victim_dataset", victim_to_save),
    ("./cifar-10-batches-py/shadow_dataset", shadow_to_save),
):
    with open(target_path, "wb") as outfile:
        pickle.dump(payload, outfile)