In [1]:
# cifar-10-batches-py/preprocess.ipynb
import os
import shutil
import numpy as np
from PIL import Image
from tqdm import tqdm

In [2]:
# unpickle function
def unpickle(file):
    """Load a pickled file and return the object stored in it.

    Args:
        file: Path of the pickle file to read (opened in binary mode).

    Returns:
        The unpickled object. The file is decoded with ``encoding='bytes'``,
        so 8-bit strings written by Python 2 (dict keys included) come back
        as ``bytes`` rather than ``str``.

    Note:
        ``pickle.load`` can execute arbitrary code while loading; only call
        this on files from a trusted source.
    """
    import pickle
    with open(file, 'rb') as fo:
        # Return directly instead of binding the result to a local named
        # `dict`, which would shadow the builtin type.
        return pickle.load(fo, encoding='bytes')

In [3]:
# Collect the training batch files found in the current directory.
# os.listdir() returns entries in arbitrary order, so the names are sorted to
# keep the processing order (and the list shown below) deterministic. The
# prefix is 'data_batch_' rather than just 'data' so an unrelated entry whose
# name merely starts with "data" is not mistaken for a batch file.
train_batches = sorted(item for item in os.listdir('.') if item.startswith('data_batch_'))
train_batches

['data_batch_1',
 'data_batch_2',
 'data_batch_3',
 'data_batch_4',
 'data_batch_5']

In [4]:
# Peek at the first training batch: print its keys, then for each per-image
# field print the first entry together with the number of entries.
test = unpickle(train_batches[0])
print(test.keys())
for field in (b'labels', b'data', b'filenames'):
    column = test[field]
    print(column[0], len(column))

dict_keys([b'batch_label', b'labels', b'data', b'filenames'])
6 10000
[ 59  43  50 ... 140  84  72] 10000
b'leptodactylus_pentadactylus_s_000004.png' 10000


In [5]:
# Read the class names stored in the metadata file and decode each one from
# bytes to str; the position in the list is the numeric label.
raw_label_names = unpickle('batches.meta')[b'label_names']
label_names = [raw_name.decode('utf-8') for raw_name in raw_label_names]
label_names

['airplane',
 'automobile',
 'bird',
 'cat',
 'deer',
 'dog',
 'frog',
 'horse',
 'ship',
 'truck']

In [6]:
# (Re)create the output tree: <data_root>/{train,test}/<class name>/
data_root = os.path.join('..', 'cifar10')  # where to store dataset
# Start from a clean slate so files from an earlier run cannot linger.
# NOTE: this deletes everything that already exists under data_root.
if os.path.exists(data_root):
    shutil.rmtree(data_root)
# The loop variable is `split`, not `dir`: at notebook top level a variable
# named `dir` would shadow the builtin dir() for the rest of the session.
for split in ('train', 'test'):
    split_dir = os.path.join(data_root, split)
    # makedirs also creates data_root itself on the first iteration.
    os.makedirs(split_dir)
    for cls_name in label_names:
        os.mkdir(os.path.join(split_dir, cls_name))

In [7]:
# Write every image of every training batch as a file under
# <data_root>/train/<class name>/<original filename>.
train_dir = os.path.join(data_root, 'train')
for file in train_batches:
    data = unpickle(file)
    data_len = len(data[b'labels'])
    records = zip(data[b'labels'], data[b'data'], data[b'filenames'])
    for label, im_array, filename in tqdm(records, total=data_len):
        # Reshape the flat row to (3, 32, 32), then move the first axis last
        # so the array handed to PIL is (32, 32, 3).
        pixels = np.reshape(im_array, (3, 32, 32)).transpose(1, 2, 0)
        target = os.path.join(train_dir, label_names[label], filename.decode('utf-8'))
        Image.fromarray(pixels).save(target)

100%|███████████████████████████████████████████████████████████████████████████| 10000/10000 [00:13<00:00, 723.92it/s]
100%|███████████████████████████████████████████████████████████████████████████| 10000/10000 [00:13<00:00, 760.46it/s]
100%|███████████████████████████████████████████████████████████████████████████| 10000/10000 [00:12<00:00, 772.45it/s]
100%|███████████████████████████████████████████████████████████████████████████| 10000/10000 [00:13<00:00, 766.30it/s]
100%|███████████████████████████████████████████████████████████████████████████| 10000/10000 [00:13<00:00, 735.66it/s]


In [8]:
# Load the test batch with the same unpickle helper used for the training
# batches, and display its keys.
test_batch = unpickle('test_batch')
test_batch.keys()

dict_keys([b'batch_label', b'labels', b'data', b'filenames'])

In [9]:
# Write every image of the test batch as a file under
# <data_root>/test/<class name>/<original filename>.
# The progress-bar total is computed from test_batch itself instead of
# reusing a length variable left over from the training loop, so this cell
# does not rely on hidden state or on both sets having the same size.
test_len = len(test_batch[b'labels'])
for label, im_array, filename in \
        tqdm(zip(test_batch[b'labels'], test_batch[b'data'], test_batch[b'filenames']), total=test_len):
    # Reshape the flat row to (3, 32, 32), then move the first axis last so
    # the array handed to PIL is (32, 32, 3).
    im_array = np.transpose(np.reshape(im_array, (3, 32, 32)), (1, 2, 0))
    im = Image.fromarray(im_array)
    im.save(os.path.join(data_root, 'test', label_names[label], filename.decode(encoding='utf-8')))

100%|███████████████████████████████████████████████████████████████████████████| 10000/10000 [00:11<00:00, 844.19it/s]
