In [1]:
import requests
import numpy as np
import tensorflow as tf
from pathlib import Path
import pickle
from collections import Counter

# Seed NumPy's legacy global random state so NumPy-based randomness is repeatable between runs.
np.random.seed(0)

In [2]:
def download_datasets(url, dataset_name):
    """Fetch a dataset archive through Keras and derive its extracted folder.

    Args:
        url: Address of the ``.tar.gz`` archive to fetch.
        dataset_name: Base name used for the stored archive
            (``<dataset_name>.tar.gz``) and for the folder path that is
            returned.

    Returns:
        A ``(archive_path, dataset_path)`` tuple of ``pathlib.Path`` objects,
        where ``dataset_path`` is ``archive_path.parent / dataset_name``.

    Note:
        The folder path is derived, not checked. This assumes the archive
        unpacks next to the path returned by ``get_file`` into a folder named
        ``dataset_name`` -- TODO confirm for the installed Keras version.
    """
    archive_path = Path(
        tf.keras.utils.get_file(
            origin=url,
            fname=f"{dataset_name}.tar.gz",
            extract=True,
        )
    )
    dataset_path = archive_path.parent / dataset_name
    return archive_path, dataset_path

In [3]:
# Copied from https://www.cs.toronto.edu/~kriz/cifar.html
def unpickle(file):
    """Deserialize a single pickle file and return the object stored in it.

    Args:
        file: Path of the pickle file to open (opened in binary mode).

    Returns:
        Whatever object the file contains. ``encoding='bytes'`` is forwarded
        to ``pickle.load``; the rest of this notebook accordingly looks up
        byte-string keys such as ``b'labels'`` in the result.
    """
    # NOTE(review): pickle.load can run arbitrary code while loading, so this
    # should only be pointed at files from a trusted source.
    with open(file, 'rb') as handle:
        return pickle.load(handle, encoding='bytes')

In [4]:
def extract_batches(path, batchname):
    """Read one batch file and return its labels, image rows and filenames.

    As a side effect this prints the batch name, the shape of the image
    array and the number of occurrences of every label in the batch.

    Args:
        path: Directory holding the batch file; must support ``path / name``.
        batchname: File name of the batch inside ``path``.

    Returns:
        ``(labels, images, filenames)`` exactly as stored in the unpickled
        batch under the keys ``b'labels'``, ``b'data'`` and ``b'filenames'``.

    Raises:
        AssertionError: If the batch does not hold 10000 labels, 10000 images
            and 10000 filenames, or if an image row is not 3072 values wide.
    """
    batch = unpickle(str(path / batchname))
    labels, pixels, names = (
        batch.get(key) for key in (b'labels', b'data', b'filenames')
    )

    print(f'{batchname}:')
    # Each image is one row of 3072 bytes, 1024 per RGB channel.
    print(pixels.shape)

    # Sanity checks on the batch layout before anything is returned.
    assert len(labels) == 10000, 'labels contains not all 10000 labels'
    assert len(pixels) == 10000, 'images contains not all 10000 images'
    assert pixels.shape[1] == 3072, 'images are not in 3072 bytes'
    assert len(names) == 10000, 'filenames contains not all 10000 filenames'

    # Report the class distribution of this batch.
    tally = Counter(labels)
    for number in tally:
        print(f"Number {number}: {tally[number]} occurrences")
    return labels, pixels, names

In [5]:
def get_class_labels_number(class_names, dic):
    """Translate class names into their numeric label indices.

    Args:
        class_names: Collection of class names (``str``) to look up.
        dic: Mapping whose ``b'label_names'`` entry is a sequence of
            UTF-8 encoded byte strings; a name's position in that sequence
            is its label number.

    Returns:
        List of positions (ascending) whose decoded name is contained in
        ``class_names``.
    """
    decoded_names = [raw.decode('utf-8') for raw in dic.get(b'label_names')]

    matching_positions = []
    for position, name in enumerate(decoded_names):
        if name in class_names:
            matching_positions.append(position)
    return matching_positions

In [6]:
def filter_class(classes, labels, images, filenames):
    """Keep only the samples whose label is one of ``classes``.

    The same boolean mask is applied to labels, images and filenames, so the
    three returned arrays stay aligned: entry ``i`` of each one describes the
    same sample, and the original sample order is preserved.

    Args:
        classes: Iterable of label values to keep.
        labels: 1-D array-like with one label per sample.
        images: Array-like whose first axis indexes samples.
        filenames: 1-D array-like with one filename per sample.

    Returns:
        ``(filtered_labels, filtered_images, filtered_filenames)`` as NumPy
        arrays, all with the same length along the first axis.
    """
    labels = np.asarray(labels)
    images = np.asarray(images)
    filenames = np.asarray(filenames)

    # Selecting per class (images[labels == cls] for each cls) would group the
    # images by class while the labels keep their original order, breaking the
    # label/image correspondence and producing ragged output for unbalanced
    # classes. A single membership mask avoids both problems.
    keep = np.isin(labels, list(classes))
    return labels[keep], images[keep], filenames[keep]

In [7]:
def get_data(batchnames, batch_size, path, classes=None):
    """Load several batch files, join them and keep only the wanted classes.

    Prints the set of remaining labels and the shapes of the three filtered
    arrays (in addition to whatever ``extract_batches`` prints per batch).

    Args:
        batchnames: Iterable of batch file names located in ``path``.
        batch_size: Total number of samples expected across all listed
            batches; used only for the sanity assertions.
        path: Directory containing the batch files.
        classes: Label numbers to keep. When ``None`` (the default) the
            notebook-level ``selected_cifar10_classes`` is used, so existing
            three-argument calls behave as before.

    Returns:
        ``(labels, images, filenames)`` as returned by ``filter_class``.

    Raises:
        AssertionError: If the combined batches do not contain exactly
            ``batch_size`` labels, images and filenames.
    """
    if classes is None:
        # Backward-compatible fallback to the global defined in a later cell;
        # passing ``classes`` explicitly removes that hidden dependency.
        classes = selected_cifar10_classes

    data_labels = []
    data_images = []
    data_filenames = []

    for name in batchnames:
        labels, images, filenames = extract_batches(path, name)
        data_labels.extend(labels)
        data_images.extend(images)
        data_filenames.extend(filenames)

    assert((len(data_labels) == batch_size)), f'labels contains not all {batch_size} labels'
    assert((len(data_images) == batch_size)), f'images contains not all {batch_size} images'
    assert((len(data_filenames) == batch_size)), f'filenames contains not all {batch_size} filenames'

    # Convert to arrays before filtering so boolean masks can be applied.
    data_labels = np.array(data_labels)
    data_images = np.array(data_images)
    data_filenames = np.array(data_filenames)
    data_labels, data_images, data_filenames = filter_class(classes, data_labels, data_images, data_filenames)
    print(set(data_labels))
    print(data_labels.shape)
    print(data_images.shape)
    print(data_filenames.shape)
    return data_labels, data_images, data_filenames

In [8]:
# Download locations of the CIFAR archives.
cifar10_url = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
cifar100_url = 'https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz'
# The archive path is bound to `cifar10_archive` rather than `zip`, so the
# built-in zip() is not shadowed for the rest of the notebook.
cifar10_archive, cifar10_path = download_datasets(cifar10_url, 'cifar-10-batches-py')

In [9]:
# Expected label numbers for the classes below: 1, 2, 3, 4, 5, 7, 9
dic_file = 'batches.meta'
needed_cifar10_classes = ['automobile', 'bird',  'cat', 'deer', 'dog', 'horse', 'truck']
cifar10_dic = unpickle(str(cifar10_path / dic_file))
# Pass the list defined above instead of repeating the literal, so the class
# selection has a single source of truth.
selected_cifar10_classes = get_class_labels_number(needed_cifar10_classes, cifar10_dic)
print(selected_cifar10_classes)

[1, 2, 3, 4, 5, 7, 9]


In [10]:
# Training data is spread over data_batch_1 ... data_batch_5 (50000 samples in
# total); the test data is the single file test_batch (10000 samples).
batchnames = [f'data_batch_{index}' for index in range(1, 6)]
(cifar10_labels,
 cifar10_images,
 cifar10_filenames) = get_data(batchnames, 50000, cifar10_path)
(cifar10_test_labels,
 cifar10_test_images,
 cifar10_test_filenames) = get_data(['test_batch'], 10000, cifar10_path)

data_batch_1:
(10000, 3072)
Number 6: 1030 occurrences
Number 9: 981 occurrences
Number 4: 999 occurrences
Number 1: 974 occurrences
Number 2: 1032 occurrences
Number 7: 1001 occurrences
Number 8: 1025 occurrences
Number 3: 1016 occurrences
Number 5: 937 occurrences
Number 0: 1005 occurrences
data_batch_2:
(10000, 3072)
Number 1: 1007 occurrences
Number 6: 1008 occurrences
Number 8: 987 occurrences
Number 3: 995 occurrences
Number 4: 1010 occurrences
Number 0: 984 occurrences
Number 5: 988 occurrences
Number 2: 1010 occurrences
Number 7: 1026 occurrences
Number 9: 985 occurrences
data_batch_3:
(10000, 3072)
Number 8: 961 occurrences
Number 5: 1029 occurrences
Number 0: 994 occurrences
Number 6: 978 occurrences
Number 9: 1029 occurrences
Number 2: 965 occurrences
Number 3: 997 occurrences
Number 7: 1015 occurrences
Number 4: 990 occurrences
Number 1: 1042 occurrences
data_batch_4:
(10000, 3072)
Number 0: 1003 occurrences
Number 6: 1004 occurrences
Number 2: 1041 occurrences
Number 7: 98

In [11]:
# TODO:
# - First, merge all datasets.
# - Load the images as RGB.
# - Get an overview of the images: the count for each label and what the images look like.