In [4]:
import os
import glob
import numpy as np
from shutil import copyfile
# When True, each selected image is symlinked into the output tree; when False,
# the file is copied instead.
symlink = True
# When True, CIFAR samples from the 'valid' set are placed in the output 'train'
# folder together with the 'train' samples, and no 'valid' output folders are
# created. When False, 'train' and 'valid' are kept as separate output folders.
combine_train_valid = False

# CIFAR-10 constituent samples' extraction
This notebook shows how to construct a dataset that contains only the CIFAR-10 samples of CINIC-10. The resulting dataset can be used for other tasks, or to assess models trained on the ImageNet constituents and so understand how well those models cope with distribution shift.

#### ENSURE THAT CINIC-10 IS DOWNLOADED AND STORED IN ../data/cinic-10

In [2]:
# Root of the downloaded CINIC-10 dataset (input) and root of the CIFAR-only
# subset this notebook builds (output). Both are relative to the notebook's
# working directory.
cinic_directory = "../data/cinic-10"
cifar_directory = "../data/cinic-10-cifar"
classes = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]
sets = ['train', 'valid', 'test']

# Create <cifar_directory>/<set>/<class> for every output split.
# `exist_ok=True` makes the cell safe to re-run and avoids the check-then-create
# race of `if not os.path.exists(...)`; os.makedirs also creates the intermediate
# <cifar_directory> and <cifar_directory>/<set> folders as needed.
for s in sets:
    # When train and valid are combined, valid samples are stored under
    # 'train', so no 'valid' output folders are needed.
    if s == 'valid' and combine_train_valid:
        continue
    for c in classes:
        os.makedirs('{}/{}/{}'.format(cifar_directory, s, c), exist_ok=True)

In [3]:
def _link_or_copy(src, dest, use_symlink):
    """Make the image `src` available at `dest`.

    Parameters
    ----------
    src : str
        Path of the source image (may be relative to the working directory).
    dest : str
        Path at which the image should appear in the output tree.
    use_symlink : bool
        If True create a symbolic link at `dest`, otherwise copy the file.
    """
    if not use_symlink:
        copyfile(src, dest)
        return

    # A dangling link (e.g. one left behind by an earlier run that linked to a
    # relative path) is useless; remove it so that it gets recreated below.
    if os.path.islink(dest) and not os.path.exists(dest):
        os.remove(dest)

    # lexists() is True for any existing entry (file or link), so re-running
    # the cell skips what is already in place instead of raising.
    if not os.path.lexists(dest):
        # A symlink stores its target verbatim, and a relative target is
        # resolved against the directory that contains the link, NOT against
        # the working directory. Linking to the relative glob result would
        # therefore produce dangling links, so use the absolute path.
        os.symlink(os.path.abspath(src), dest)


for s in sets:
    # With combine_train_valid the 'valid' samples are filed under 'train';
    # every other split keeps its own name.
    dest_set = 'train' if (combine_train_valid and s == 'valid') else s
    for c in classes:
        source_directory = '{}/{}/{}'.format(cinic_directory, s, c)
        filenames = glob.glob('{}/*.png'.format(source_directory))
        for fn in filenames:
            # basename() handles the platform's path separator, unlike split('/').
            base_fn = os.path.basename(fn)
            # Only files whose name contains 'cifar' are selected.
            if 'cifar' not in base_fn:
                continue
            dest_fn = '{}/{}/{}/{}'.format(cifar_directory, dest_set, c, base_fn)
            _link_or_copy(fn, dest_fn, symlink)
                    