# Loading the CIFAR-10 dataset

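First, download the data [from this link](https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz) and extract the archive, so that the `cifar-10-batches-py` folder sits in the notebook's working directory (this is where the loader below looks by default).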

Then, you may use the code below.

In [None]:
import tensorflow as tf
import os, struct
import numpy as np
import matplotlib.pyplot as plt
import pickle
%matplotlib notebook

**Load the data**

In [None]:
def unpickle(file_path):
    """Read one pickled CIFAR-10 batch file and return its contents.

    The file is opened in binary mode and unpickled with ``encoding='bytes'``,
    so string keys stored in the file come back as ``bytes``. Callers
    therefore index the result with keys such as ``b'data'``.

    Security note: ``pickle.load`` can execute arbitrary code. Only use this
    function on trusted files, such as the official CIFAR-10 download.
    """
    with open(file_path, 'rb') as handle:
        contents = pickle.load(handle, encoding='bytes')
    return contents

def load_data_cifar(path='.', test_percentage=0.8, seed=None):
    """Load CIFAR-10 from an extracted ``cifar-10-batches-py`` folder.

    The five training batches (``data_batch_1`` to ``data_batch_5``) are
    concatenated into the training set. The official ``test_batch`` is
    randomly split into a test part and a validation part.

    Args:
        path: Directory that contains the ``cifar-10-batches-py`` folder.
        test_percentage: Fraction of ``test_batch`` kept as the test set.
            Despite the name, this is a fraction in [0, 1], not a
            percentage. The remaining rows become the validation set.
        seed: Optional seed for the test/validation split. ``None`` (the
            default) keeps the previous behaviour and draws from NumPy's
            global random state, so an earlier ``np.random.seed(...)`` still
            controls the split. Passing an int makes the split reproducible
            on its own.

    Returns:
        Tuple ``(X_train, Y_train, X_test, Y_test, X_val, Y_val, label_names)``.

        - ``X_*``: the batches' ``b'data'`` arrays cast to float32, one image
          per row.
        - ``Y_*``: int32 column vectors of shape ``(n, 1)``.
        - ``label_names``: the ``b'label_names'`` entry of ``batches.meta``.
          Its items are presumably ``bytes`` because of
          ``encoding='bytes'``; verify before comparing them to ``str``.

    Raises:
        ValueError: If ``test_percentage`` is outside [0, 1].
    """
    if not 0.0 <= test_percentage <= 1.0:
        raise ValueError(
            "test_percentage must be a fraction in [0, 1], got {!r}".format(
                test_percentage))

    batch_dir = os.path.join(path, 'cifar-10-batches-py')

    # Load the five training batches, then concatenate once instead of
    # re-stacking an ever-growing array on every iteration.
    x_parts, y_parts = [], []
    for i in range(1, 6):
        batch = unpickle(os.path.join(batch_dir, 'data_batch_' + str(i)))
        x_parts.append(batch[b'data'].astype(np.float32))
        y_parts.append(np.array(batch[b'labels']).astype(np.int32))
    X_train = np.concatenate(x_parts, axis=0)
    Y_train = np.concatenate(y_parts)[:, None]  # column vector (n, 1)

    # Load the official test batch; it is split into test/validation below.
    test_batch = unpickle(os.path.join(batch_dir, 'test_batch'))
    X_test_ini = test_batch[b'data'].astype(np.float32)
    Y_test_ini = np.array(test_batch[b'labels']).astype(np.int32)

    # Random split. np.random (the module) and a Generator both expose
    # .permutation. Using the module when seed is None preserves reliance on
    # the global seed.
    rng = np.random if seed is None else np.random.default_rng(seed)
    n_test_ini = X_test_ini.shape[0]
    perm = rng.permutation(n_test_ini)
    n_test_cases = int(test_percentage * n_test_ini)
    test_idx, val_idx = perm[:n_test_cases], perm[n_test_cases:]

    X_test, Y_test = X_test_ini[test_idx, :], Y_test_ini[test_idx, None]
    X_val, Y_val = X_test_ini[val_idx, :], Y_test_ini[val_idx, None]

    # Human-readable class names.
    meta = unpickle(os.path.join(batch_dir, 'batches.meta'))
    label_names = meta[b'label_names']

    return X_train, Y_train, X_test, Y_test, X_val, Y_val, label_names

In [None]:
# Load CIFAR-10 from the current directory and report every array's shape.
X_train, Y_train, X_test, Y_test, X_val, Y_val, label_names = load_data_cifar(path='.')
# Pair each label with its own array, so a label can no longer be printed
# next to the wrong array (previously "Y_val" showed X_val.shape).
for split_name, split_array in (("X_train", X_train), ("Y_train", Y_train),
                                ("X_test", X_test), ("Y_test", Y_test),
                                ("X_val", X_val), ("Y_val", Y_val)):
    print("Shape of " + split_name + ": ", split_array.shape)

**Auxiliary plotting function**

In [None]:
def plot_image_cifar(x_row, figsize=(1, 1)):
    """Display one CIFAR-10 image that is stored as a flat row of pixels.

    Args:
        x_row: Flat array of 3*32*32 values in channel-first order. It is
            reshaped to (3, 32, 32) here, and imshow renders the three
            channels as R, G, B. Values are assumed to be on a 0-255 (8-bit)
            scale; confirm this against the data source.
        figsize: Matplotlib figure size in inches. The default of (1, 1)
            matches the previous fixed size.
    """
    # (3, 32, 32) channel-first -> (32, 32, 3) channel-last, as imshow
    # expects. Divide by 255 (not 256) so that a full-intensity 8-bit value
    # maps exactly to 1.0.
    image = np.asarray(x_row).reshape(3, 32, 32).transpose(1, 2, 0) / 255.

    plt.figure(figsize=figsize)
    plt.imshow(image)
    plt.axis('off')  # also hides axis labels, so no xlabel call is needed
    plt.show()

In [None]:
# Show the first training image (row 0 of X_train).
plot_image_cifar(X_train[0])