In [2]:
import os
import pickle
import pandas as pd
import random
from sklearn.model_selection import train_test_split
from datasets import load_from_disk, load_dataset

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# set seed
random.seed(42)

In [4]:
# download dataset or load from disk
dataset_path = "../Dataset/CIFAR10"

# check if the dataset is already downloaded
if not os.path.exists(dataset_path):
    dataset = load_dataset(
                'cifar10'
            )
    os.makedirs(dataset_path, exist_ok=True)
    dataset.save_to_disk(dataset_path)
else:
    dataset = load_from_disk(dataset_path)
    
train_dataset = dataset['train']
test_dataset = dataset['test']


In [5]:
# train set of CIFAR10
train_cifar10 = pd.DataFrame()
train_cifar10['label'] = train_dataset['label']
train_cifar10, _ = train_test_split(train_cifar10, train_size=0.5, random_state=42, stratify=train_cifar10['label'])
print(train_cifar10.head())
print(train_cifar10.shape)

       label
10936      8
11178      1
26489      6
22034      6
14307      6
(25000, 1)


In [6]:
# test set of CIFAR10
test_cifar10 = pd.DataFrame()
test_cifar10['label'] = test_dataset['label']
print(test_cifar10.head())
test_cifar10, _ = train_test_split(test_cifar10, train_size=0.5, random_state=42, stratify=test_cifar10['label'])
print(test_cifar10.shape)

   label
0      3
1      8
2      8
3      0
4      6
(5000, 1)


In [7]:
# save indices
cifar10_indices_path = "./data/"
os.makedirs(cifar10_indices_path, exist_ok=True)

# train set indices
train_index_cifar10 = os.path.join(cifar10_indices_path, "idx-train.pkl")
with open(train_index_cifar10, 'wb') as handle:
    pickle.dump(test_cifar10.index.to_list(), handle)

# test set indices
test_index_cifar10 = os.path.join(cifar10_indices_path, "idx-test.pkl")
with open(test_index_cifar10, 'wb') as handle:
    pickle.dump(test_cifar10.index.to_list(), handle)

In [8]:
# generate subsets for lds validation
for k in range(256):
    tmp, _ = train_test_split(train_cifar10, train_size=0.5, random_state=42+k, stratify=train_cifar10['label'])
    filename = os.path.join('./data/lds-val/sub-idx-{}.pkl'.format(k))
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    with open(filename, 'wb') as handle:
        pickle.dump(tmp.index.to_list(), handle)

In [9]:
# validate indices
with open("./data/idx-train.pkl", 'rb') as handle:
    train_indices = pickle.load(handle)
print(train_indices[0:10])

with open("./data/idx-test.pkl", 'rb') as handle:
    test_indices = pickle.load(handle)
print(test_indices[0:10])

with open('./data/lds-val/sub-idx-0.pkl', 'rb') as handle:
    sub_0 = pickle.load(handle)
print(sub_0[0:10])

with open('./data/lds-val/sub-idx-1.pkl', 'rb') as handle:
    sub_1 = pickle.load(handle)
print(sub_1[0:10])

[7813, 2997, 972, 8717, 2015, 590, 5854, 9648, 3373, 9786]
[7813, 2997, 972, 8717, 2015, 590, 5854, 9648, 3373, 9786]
[10253, 20048, 27765, 30423, 36629, 24578, 32158, 10965, 47543, 27000]
[23937, 8446, 24379, 23473, 8565, 17363, 31481, 46971, 46123, 1308]
