In [18]:
import os
import pickle
import pandas as pd
import random
from sklearn.model_selection import train_test_split
from datasets import load_from_disk, load_dataset

In [19]:
# Seed Python's built-in `random` module so anything drawing from it is repeatable.
# The train_test_split calls further down pass their own explicit random_state.
random.seed(a=42)

In [20]:
# Load CIFAR10 from the local copy when one exists; otherwise download it
# once and persist it so later runs can load it from disk.
dataset_path = "../Dataset/CIFAR10"

if os.path.exists(dataset_path):
    # a local copy is already present: reuse it
    dataset = load_from_disk(dataset_path)
else:
    # first run: fetch the dataset, then store it under dataset_path
    dataset = load_dataset('cifar10')
    os.makedirs(dataset_path, exist_ok=True)
    dataset.save_to_disk(dataset_path)

# expose the two splits used by the cells below
train_dataset, test_dataset = dataset['train'], dataset['test']


In [21]:
# train set of CIFAR2
# Build a label-only frame whose index is the row position in the CIFAR10 train split.
train_cifar2 = pd.DataFrame()
train_cifar2['label'] = train_dataset['label']

# Keep the two classes that make up CIFAR2: labels 3 and 5.
# The output recorded for this cell contains only labels 3 and 5, so the
# previous filter on labels 1 and 7 could not reproduce it.
# NOTE(review): presumably 3 = cat and 5 = dog in the Hugging Face `cifar10`
# label order -- confirm against dataset.features['label'].names.
train_cifar2 = train_cifar2[train_cifar2['label'].isin([3, 5])]
assert set(train_cifar2['label'].unique()) == {3, 5}, "expected exactly labels 3 and 5"

# Keep a class-stratified half of the filtered rows; the fixed random_state
# makes the selection repeatable.
train_cifar2, _ = train_test_split(train_cifar2, train_size=0.5, random_state=42, stratify=train_cifar2['label'])
print(train_cifar2.head())
print(train_cifar2.shape)

       label
17274      3
17585      3
25421      3
7451       5
4801       3
(5000, 1)


In [22]:
# test set of CIFAR2
# Build a label-only frame whose index is the row position in the CIFAR10 test split.
test_cifar2 = pd.DataFrame()
test_cifar2['label'] = test_dataset['label']

# Keep the two classes that make up CIFAR2: labels 3 and 5.
# The head recorded for this cell shows only labels 3 and 5, so the previous
# filter on labels 1 and 7 could not reproduce it.
# NOTE(review): presumably 3 = cat and 5 = dog in the Hugging Face `cifar10`
# label order -- confirm against dataset.features['label'].names.
test_cifar2 = test_cifar2[test_cifar2['label'].isin([3, 5])]
assert set(test_cifar2['label'].unique()) == {3, 5}, "expected exactly labels 3 and 5"

# The head is printed before subsampling, so it shows the first filtered rows
# in original order.
print(test_cifar2.head())

# Keep a class-stratified half of the filtered rows; the fixed random_state
# makes the selection repeatable.
test_cifar2, _ = train_test_split(test_cifar2, train_size=0.5, random_state=42, stratify=test_cifar2['label'])
print(test_cifar2.shape)

    label
0       3
8       3
12      5
16      5
24      5
(1000, 1)


In [23]:
# Persist the selected row indices so other scripts can rebuild the same
# CIFAR2 train and test subsets.
cifar2_indices_path = "./data/"
os.makedirs(cifar2_indices_path, exist_ok=True)

# destination files for the train and test index lists
train_index_cifar2 = os.path.join(cifar2_indices_path, "idx-train.pkl")
test_index_cifar2 = os.path.join(cifar2_indices_path, "idx-test.pkl")

# Each file holds a plain Python list of the frame's index values
# (train first, then test).
for target_path, frame in (
    (train_index_cifar2, train_cifar2),
    (test_index_cifar2, test_cifar2),
):
    with open(target_path, 'wb') as handle:
        pickle.dump(frame.index.to_list(), handle)

In [24]:
# Generate the subsets used for LDS validation: 256 class-stratified halves
# of the CIFAR2 train set, one pickle file of index values per subset.
filename_template = './data/lds-val/sub-idx-{}.pkl'

# every subset file lives in the same directory, so create it once up front
os.makedirs(os.path.dirname(filename_template), exist_ok=True)

for k in range(256):
    # a different random_state per k (42, 43, ...) yields a different,
    # repeatable half for each subset
    subset, _ = train_test_split(
        train_cifar2,
        train_size=0.5,
        random_state=42 + k,
        stratify=train_cifar2['label'],
    )
    filename = filename_template.format(k)
    with open(filename, 'wb') as handle:
        pickle.dump(subset.index.to_list(), handle)

In [25]:
# Sanity check: read the saved index files back and show the first ten
# entries of each.
def _read_pickled_indices(path):
    """Return the object stored in the pickle file at `path`."""
    # these files are written by the cells above; do not point this at untrusted pickles
    with open(path, 'rb') as handle:
        return pickle.load(handle)


train_indices = _read_pickled_indices("./data/idx-train.pkl")
print(train_indices[0:10])

test_indices = _read_pickled_indices("./data/idx-test.pkl")
print(test_indices[0:10])

# first two LDS validation subsets
sub_0 = _read_pickled_indices('./data/lds-val/sub-idx-0.pkl')
print(sub_0[0:10])

sub_1 = _read_pickled_indices('./data/lds-val/sub-idx-1.pkl')
print(sub_1[0:10])

with open('./data/lds-val/sub-idx-1.pkl', 'rb') as handle:
    sub_1 = pickle.load(handle)
print(sub_1[0:10])

[17274, 17585, 25421, 7451, 4801, 39117, 11089, 38052, 39975, 5386]
[7336, 5290, 1268, 6683, 7493, 7668, 7118, 8201, 3390, 7174]
[9967, 43790, 15845, 49871, 13092, 24485, 46444, 7720, 21366, 49174]
[9634, 10163, 24298, 24627, 4703, 16755, 46100, 21062, 36606, 1883]
