In [1]:
import matplotlib.pyplot as plt
import numpy as np
import pickle
import os
import shutil
from tqdm import tqdm

def unpickle(file):
    """Load and return the object stored in the pickle file at ``file``.

    The file is opened in binary mode and decoded with ``encoding='bytes'``,
    so string data pickled by Python 2 comes back as ``bytes`` (which is why
    callers index the result with keys such as ``b'data'``).

    NOTE: ``pickle.load`` can execute arbitrary code; only use this on
    trusted files.
    """
    with open(file, 'rb') as handle:
        return pickle.load(handle, encoding='bytes')

## cifar100

In [11]:
def _export_cifar100_split(split_file, label_names, out_root='data/cifar-100'):
    """Write every image of one CIFAR-100 split to ``out_root/<label>/<filename>``.

    Parameters
    ----------
    split_file : str
        Path of the pickled split (e.g. the ``train`` or ``test`` file).
    label_names : list of str
        Fine label names, indexed by the integer labels stored in the split.
    out_root : str
        Folder under which one sub-folder per label is created.
    """
    split = unpickle(split_file)
    filenames = [t.decode('utf8') for t in split[b'filenames']]
    fine_labels = split[b'fine_labels']

    # Each row holds 3072 values: the first 1024 are the red channel, the next
    # 1024 green and the last 1024 blue, each a flattened 32x32 plane.
    # Reshape to (N, channel, row, col) and move the channel axis last to get
    # (N, 32, 32, 3) images in one vectorised step.
    pixels = np.asarray(split[b'data'], dtype=np.uint8)
    images = pixels.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)

    # Passing total= lets tqdm show a real progress bar instead of a bare counter.
    for image, filename, label_idx in tqdm(zip(images, filenames, fine_labels),
                                           total=len(filenames)):
        class_dir = f'{out_root}/{label_names[label_idx]}'
        os.makedirs(class_dir, exist_ok=True)
        plt.imsave(f'{class_dir}/{filename}', image)


meta = unpickle('data/cifar-100-python/meta')
fine_label_names = [t.decode('utf8') for t in meta[b'fine_label_names']]

# Train and test images are merged into the same per-label folders.
for split_file in ('data/cifar-100-python/train', 'data/cifar-100-python/test'):
    _export_cifar100_split(split_file, fine_label_names)


50000it [00:21, 2331.04it/s]
10000it [00:04, 2237.91it/s]


## coil100

In [14]:
path = 'data/coil-100'

# Move every '<label>__<name>' PNG in `path` into '<path>/<label>/<name>'.
for entry in tqdm(os.listdir(path)):
    if not entry.endswith('png'):
        continue  # skip anything that is not an image (e.g. class folders)

    # Exactly one '__' separator is expected; anything else raises ValueError.
    class_name, short_name = entry.split('__')
    class_dir = os.path.join(path, class_name)
    os.makedirs(class_dir, exist_ok=True)
    os.rename(os.path.join(path, entry), os.path.join(class_dir, short_name))

100%|██████████| 7180/7180 [00:00<00:00, 45139.18it/s]


## vgg-cats


In [17]:
path = 'data/vgg-cats/'

# Move every '<label>_<name>' JPG in `path` into '<path>/<label>/<name>'.
# The label itself may contain underscores: only the last '_' is the separator.
for entry in tqdm(os.listdir(path)):
    if not entry.endswith('jpg'):
        continue  # skip anything that is not an image (e.g. class folders)

    # rpartition splits on the last '_'; with no '_' the label is empty and
    # the file name is left unchanged.
    class_name, _, short_name = entry.rpartition('_')
    class_dir = os.path.join(path, class_name)
    os.makedirs(class_dir, exist_ok=True)
    os.rename(os.path.join(path, entry), os.path.join(class_dir, short_name))

100%|██████████| 2400/2400 [00:00<00:00, 41736.44it/s]


## cinic10

Randomly delete 90% of the images in each folder, because the full dataset is huge.

In [15]:
# WARNING: this cell permanently deletes files and is NOT idempotent --
# every additional run removes another PRUNE_FRACTION of what is left.
path = 'data/CINIC-10'
PRUNE_FRACTION = 0.9  # share of the images removed from each folder
PRUNE_SEED = 0        # fixed seed so the same files are removed on a fresh copy

prune_rng = np.random.default_rng(PRUNE_SEED)

# Sorting the listings makes the seeded shuffle independent of the
# filesystem's (arbitrary) directory order.
for folder in sorted(os.listdir(path)):
    folder_path = os.path.join(path, folder)
    if not os.path.isdir(folder_path):
        continue  # ignore stray files sitting next to the image folders

    images = sorted(os.listdir(folder_path))
    prune_rng.shuffle(images)
    prune_num = int(len(images) * PRUNE_FRACTION)

    # Plain loop: the deletions are side effects, not values to collect.
    for img in images[:prune_num]:
        os.remove(os.path.join(folder_path, img))

# Data splits: 

* cinic-10: split in 2
* cifar-100: split in 5
* coil-100: split in 5
* stanford-dogs: split in 6
* vgg-cats: not split
* flowers: not split


Total: 20 datasets, 125k samples

In [20]:
# (dataset folder, number of class-disjoint splits to create from it)
data_splits = [('data/cifar-100', 5),
               ('data/coil-100', 5),
               ('data/stanford-dogs', 6),
               ('data/CINIC-10', 2)]

SPLIT_SEED = 0  # fixed seed so the class-to-split assignment is reproducible
split_rng = np.random.default_rng(SPLIT_SEED)

# For each dataset, shuffle its class folders and move them into
# '<data_folder>_1' ... '<data_folder>_<num_splits>'.
for data_folder, num_splits in data_splits:

    # Only sub-folders count as classes; sorting makes the seeded shuffle
    # independent of the filesystem's directory order.
    all_classes = sorted(
        entry for entry in os.listdir(data_folder)
        if os.path.isdir(os.path.join(data_folder, entry))
    )
    split_rng.shuffle(all_classes)

    # Contiguous chunks of the shuffled list. When the class count is not
    # divisible by num_splits, the first `extra` splits get one more class so
    # that no class is left behind in the original folder.
    base, extra = divmod(len(all_classes), num_splits)

    start = 0
    for idx in range(num_splits):
        new_folder = f"{data_folder}_{idx + 1}"
        os.makedirs(new_folder, exist_ok=True)

        stop = start + base + (1 if idx < extra else 0)
        for cls in all_classes[start:stop]:
            shutil.move(os.path.join(data_folder, cls),
                        os.path.join(new_folder, cls))
        start = stop