<a href="https://colab.research.google.com/github/Alizzie/CS50/blob/main/ST_CA1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [74]:
import requests
import numpy as np
import tensorflow as tf
from pathlib import Path
import pickle
from collections import Counter

In [75]:
np.random.seed(0)

In [76]:
def download_datasets(url, dataset_name):
    path_to_zip = tf.keras.utils.get_file(
        fname=f"{dataset_name}.tar.gz",
        origin=url,
        extract=True)

    path_to_zip = Path(path_to_zip)
    path = path_to_zip.parent / dataset_name
    return path_to_zip, path

In [77]:
# Copy from https://www.cs.toronto.edu/~kriz/cifar.html
def unpickle(file):
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
    return dict

In [78]:
def extract_batches(path, batchname):
    batch_dic = unpickle(str(path / batchname))
    batch_labels = batch_dic.get(b'labels')
    batch_images = batch_dic.get(b'data')
    batch_filenames = batch_dic.get(b'filenames')

    print(f'{batchname}:')
    print(batch_images.shape) # image as 3072 byte, 1024 each rgb channel

    assert((len(batch_labels) == 10000)), 'labels contains not all 10000 labels'
    assert((len(batch_images) == 10000)), 'images contains not all 10000 images'
    assert((batch_images.shape[1] == 3072)), 'images are not in 3072 bytes'
    assert((len(batch_filenames) == 10000)), 'filenames contains not all 10000 filenames'

    for number, count in Counter(batch_labels).items():
        print(f"Number {number}: {count} occurrences")
    return batch_labels, batch_images, batch_filenames

In [79]:
def get_class_labels_number(class_names, dic):
    labels = dic.get(b'label_names')
    labels = list(map(lambda x: x.decode('utf-8'), labels))
    labels_nr = [index for index, value in enumerate(labels) if value in class_names]
    return labels_nr

In [80]:
def filter_class(classes, labels, images, filenames):
    filtered_labels = [label for label in labels if label in classes]
    filtered_images = [images[labels == cls] for cls in classes]
    filtered_filenames = [filenames[labels == cls] for cls in classes]
    return np.array(filtered_labels), np.array(filtered_images), np.array(filtered_filenames)

In [81]:
cifar10_url = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
zip, cifar10_path = download_datasets(cifar10_url, 'cifar-10-batches-py')

In [82]:
# class_labels: 1, 2, 3, 4, 5, 7, 9
cifar10_dic_file = 'batches.meta'
cifar10_dic = unpickle(str(cifar10_path / label_dic_file))
selected_cifar10_classes = get_class_labels_number(['automobile', 'bird',  'cat', 'deer', 'dog', 'horse', 'truck'], label_dic)
print(selected_cifar10_classes)

[1, 2, 3, 4, 5, 7, 9]


In [83]:
## Test batch overview for one batch firsts
#   filename = 'data_batch_1'
#   batch_1_dic = unpickle(str(cifar10_path / filename))
# Overview of dic structure
#print(type(batch_1_dic), len(batch_1_dic))
#for key, value in batch_1_dic.items():
#    print(f"Key: {key}, Type: {type(key)}")
#    print(f"Value: {value}, Type: {type(value)}")

In [84]:
cifar10_labels = []
cifar10_images = []
cifar10_filenames = []
batchnames = ['data_batch_1', 'data_batch_2', 'data_batch_3', 'data_batch_4', 'data_batch_5']

for name in batchnames:
    labels, images, filenames = extract_batches(cifar10_path, name)
    cifar10_labels.extend(labels)
    cifar10_images.extend(images)
    cifar10_filenames.extend(filenames)

assert((len(cifar10_labels) == 50000)), 'labels contains not all 50000 labels'
assert((len(cifar10_images) == 50000)), 'images contains not all 50000 images'
assert((len(cifar10_filenames) == 50000)), 'filenames contains not all 50000 filenames'

cifar10_labels = np.array(cifar10_labels)
cifar10_images = np.array(cifar10_images)
cifar10_filenames = np.array(cifar10_filenames)

data_batch_1:
(10000, 3072)
Number 6: 1030 occurrences
Number 9: 981 occurrences
Number 4: 999 occurrences
Number 1: 974 occurrences
Number 2: 1032 occurrences
Number 7: 1001 occurrences
Number 8: 1025 occurrences
Number 3: 1016 occurrences
Number 5: 937 occurrences
Number 0: 1005 occurrences
data_batch_2:
(10000, 3072)
Number 1: 1007 occurrences
Number 6: 1008 occurrences
Number 8: 987 occurrences
Number 3: 995 occurrences
Number 4: 1010 occurrences
Number 0: 984 occurrences
Number 5: 988 occurrences
Number 2: 1010 occurrences
Number 7: 1026 occurrences
Number 9: 985 occurrences
data_batch_3:
(10000, 3072)
Number 8: 961 occurrences
Number 5: 1029 occurrences
Number 0: 994 occurrences
Number 6: 978 occurrences
Number 9: 1029 occurrences
Number 2: 965 occurrences
Number 3: 997 occurrences
Number 7: 1015 occurrences
Number 4: 990 occurrences
Number 1: 1042 occurrences
data_batch_4:
(10000, 3072)
Number 0: 1003 occurrences
Number 6: 1004 occurrences
Number 2: 1041 occurrences
Number 7: 98

In [85]:
cifar10_labels, cifar10_images, cifar10_filenames = filter_class(selected_cifar10_classes, cifar10_labels, cifar10_images, cifar10_filenames)
print(set(cifar10_labels))
print(cifar10_labels.shape)
print(cifar10_images.shape)
print(cifar10_filenames.shape)

{1, 2, 3, 4, 5, 7, 9}
(35000,)
(7, 5000, 3072)
(7, 5000)


In [86]:
# TODO :
# First merge all datasets
# Load images to rgb
# get overview of images: number for each label, look of images