In [1]:
import os
import pickle
import pandas as pd
import random
from sklearn.model_selection import train_test_split
from datasets import load_from_disk, load_dataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# set seed
random.seed(42)

In [3]:
# download dataset or load from disk
download_dataset = False

if download_dataset:
    dataset = load_dataset(
                'cifar10'
            )
    os.makedirs("../Dataset", exist_ok=True)
    dataset.save_to_disk("../Dataset/CIFAR10")
else:
    dataset = load_from_disk("../Dataset/CIFAR10")
    
train_dataset = dataset['train']
test_dataset = dataset['test']

In [4]:
# CIFAR2
df_train_cifar2 = pd.DataFrame()
df_train_cifar2['label'] = train_dataset['label']
df_train_cifar2 = df_train_cifar2[(df_train_cifar2['label']==3) | (df_train_cifar2['label']==5)]
print(df_train_cifar2.head())

    label
10      5
13      3
15      3
16      5
17      5


In [5]:
df_test_cifar2 = pd.DataFrame()
df_test_cifar2['label'] = test_dataset['label']
df_test_cifar2 = df_test_cifar2[(df_test_cifar2['label']==3) | (df_test_cifar2['label']==5)]
print(df_test_cifar2.head())

    label
0       3
8       3
12      5
16      5
24      5


In [6]:
cifar2_indices_dir = "./data/cifar2"
os.makedirs(cifar2_indices_dir, exist_ok=True)

train_index_cifar2 = os.path.join(cifar2_indices_dir, "idx-train.pkl")
with open(train_index_cifar2, 'wb') as handle:
    pickle.dump(df_train_cifar2.index.to_list(), handle)

test_index_cifar2 = os.path.join(cifar2_indices_dir, "idx-test.pkl")
with open(test_index_cifar2, 'wb') as handle:
    pickle.dump(df_test_cifar2.index.to_list(), handle)

In [7]:
for k in range(256):
    tmp, _ = train_test_split(df_train_cifar2, train_size=0.5, random_state=42+k, 
                            stratify=df_train_cifar2['label']
                        )
    filename = os.path.join('./data/cifar2/lds-val/sub-idx-{}.pkl'.format(k))
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    with open(filename, 'wb') as handle:
        pickle.dump(tmp.index.to_list(), handle)

In [8]:
with open("./data/cifar2/idx-train.pkl", 'rb') as handle:
    train_indices = pickle.load(handle)
print(train_indices[0:10])

[10, 13, 15, 16, 17, 22, 28, 42, 46, 54]


In [9]:
with open('./data/cifar2/lds-val/sub-idx-0.pkl', 'rb') as handle:
    sub_0 = pickle.load(handle)
print(sub_0[0:10])


[17274, 17585, 25421, 7451, 4801, 39117, 11089, 38052, 39975, 5386]


In [10]:
with open('./data/cifar2/lds-val/sub-idx-1.pkl', 'rb') as handle:
    sub_1 = pickle.load(handle)
print(sub_1[0:10])

[24759, 36596, 21724, 9151, 16896, 27615, 25774, 36980, 12662, 24064]


In [11]:
#  CIFAR10

df_train_cifar10 = pd.DataFrame()
df_train_cifar10['label'] = train_dataset['label']
print(df_train_cifar10.head())

   label
0      0
1      6
2      0
3      2
4      7


In [12]:
df_test_cifar10 = pd.DataFrame()
df_test_cifar10['label'] = test_dataset['label']
print(df_test_cifar10.head())

   label
0      3
1      8
2      8
3      0
4      6


In [13]:
cifar10_indices_dir = "./data/cifar10"
os.makedirs(cifar10_indices_dir, exist_ok=True)

train_index_cifar10 = os.path.join(cifar10_indices_dir, "idx-train.pkl")
with open(train_index_cifar10, 'wb') as handle:
    pickle.dump(df_train_cifar10.index.to_list(), handle)

test_index_cifar10 = "./data/cifar10/idx-test.pkl"
with open(test_index_cifar10, 'wb') as handle:
    pickle.dump(df_test_cifar10.index.to_list(), handle)

In [14]:
for k in range(256):
    tmp, _ = train_test_split(df_train_cifar10, train_size=0.5, random_state=42+k, 
                            stratify=df_train_cifar10['label']
                        )
    filename = os.path.join('./data/cifar10/lds-val/sub-idx-{}.pkl'.format(k))
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    with open(filename, 'wb') as handle:
        pickle.dump(tmp.index.to_list(), handle)

In [15]:
with open("./data/cifar10/idx-train.pkl", 'rb') as handle:
    train_indices = pickle.load(handle)
print(train_indices[0:10])

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]


In [16]:
with open('./data/cifar10/lds-val/sub-idx-0.pkl', 'rb') as handle:
    sub_0 = pickle.load(handle)
print(sub_0[0:10])

[10936, 11178, 26489, 22034, 14307, 6860, 34584, 49472, 32957, 39985]


In [17]:
with open('./data/cifar2/lds-val/sub-idx-1.pkl', 'rb') as handle:
    sub_1 = pickle.load(handle)
print(sub_1[0:10])

[24759, 36596, 21724, 9151, 16896, 27615, 25774, 36980, 12662, 24064]
