#### Import libraries

In [None]:
import os
import numpy as np
import tensorflow as tf

#### CIFAR-10 Dataset Analysis

In [None]:
CIFAR_DIR = '/home/commaai-03/Data/dataset/cifar-10-python'


def unpickle(file):
    """Return the object stored in the pickle at path *file*.

    The pickle is read with ``encoding='bytes'``, so ``str`` objects written
    by Python 2 (including dictionary keys) come back as ``bytes``.
    """
    import pickle

    with open(file, 'rb') as handle:
        return pickle.load(handle, encoding='bytes')

In [None]:
# Every entry of CIFAR_DIR except the HTML readme, as a sorted list of paths.
filenames = sorted(
    os.path.join(CIFAR_DIR, entry)
    for entry in os.listdir(CIFAR_DIR)
    if '.html' not in entry
)

# After sorting, the first path is taken as the meta file, the last one as
# the test batch, and everything in between as the training batches.
meta_file = filenames[0]
train_files = filenames[1:-1]
# Kept as a one-element list because CifarDate iterates over its filenames.
test_file = filenames[-1:]

#### Test Data Handling

This is just a sample taken from the test batch; it shows the details of the CIFAR data format.

In [None]:
# `test_file` is a one-element list (the form CifarDate iterates over), so
# hand its single path to unpickle(); open() cannot accept a list.
test_data = unpickle(test_file[0])
# Show the top-level keys of the batch dictionary.
for k in test_data:
    print(k)

In [None]:
# Keys of the raw batch dictionary that are inspected in this cell.
inspected_keys = (b'batch_label', b'labels', b'data', b'filenames')

# First the Python type stored under each key ...
for key in inspected_keys:
    print(type(test_data[key]))

# ... then the batch label in full, followed by the first five entries of
# each of the remaining fields.
print(test_data[inspected_keys[0]])
for key in inspected_keys[1:]:
    print(test_data[key][:5])

In [None]:
# Preview one image: take the first row of the raw data array, which holds a
# single image as a flat vector. `test_data` must still be the raw dictionary
# returned by unpickle() at this point.
img_arr = test_data[b'data'][0]
# 32 * 32 * 3 (R,G,B)
# View the flat vector as 3 planes of 32x32 values ...
img_arr_reshaped = img_arr.reshape((3, 32, 32))
# ... then move the plane axis last, giving a (32, 32, 3) array for imshow.
img = img_arr_reshaped.transpose(1, 2, 0)
# NOTE(review): `plt` is not used in this cell; only `imshow` is.
import matplotlib.pyplot as plt
from matplotlib.pyplot import imshow

# IPython magic (not valid in a plain .py script): draw figures inline.
%matplotlib inline
imshow(img)

#### Class CifarData

In [None]:
class CifarDate:
    """In-memory batch loader for pickled CIFAR-10 batch files.

    Only examples whose label is in ``KEPT_LABELS`` (0 and 1) are retained, so
    the loader serves a two-class subset of the files it is given.

    Attributes:
        num_examples: number of examples kept after filtering.
    """

    # Labels retained while loading; examples with any other label are dropped.
    KEPT_LABELS = (0, 1)

    def __init__(self, filenames, need_shuffle):
        """Load, filter and optionally shuffle the examples in *filenames*.

        Args:
            filenames: iterable of paths to pickled batch files.
            need_shuffle: when true, the examples are shuffled once after
                loading and again every time ``next_batch`` wraps around.

        Raises:
            ValueError: if no example with a kept label was loaded.
        """
        all_data = []
        all_label = []
        for filename in filenames:
            data, labels = self.load_data(filename)
            data = np.asarray(data)
            labels = np.asarray(labels)
            # Select the rows of this file in one vectorised step instead of
            # appending them one by one.
            keep = np.isin(labels, self.KEPT_LABELS)
            all_data.append(data[keep])
            all_label.append(labels[keep])
        if sum(len(part) for part in all_label) == 0:
            raise ValueError(
                'no examples with labels %s found in the given files'
                % str(self.KEPT_LABELS))
        self._data = np.vstack(all_data)
        self._label = np.hstack(all_label)
        print('[CIFAR-10]: Data shape-> %s' % str(self._data.shape))
        print('[CIFAR-10]: Label shape-> %s' % str(self._label.shape))

        self.num_examples = self._data.shape[0]
        self._need_shuffle = need_shuffle
        # Index of the first example of the next batch.
        self._indicator = 0
        if self._need_shuffle:
            self._shuffle_data()

    def load_data(self, filename):
        """Return the ``(data, labels)`` pair stored in one pickled batch file.

        The pickle is read with ``encoding='bytes'`` and the entries under the
        ``b'data'`` and ``b'labels'`` keys are returned.
        """
        import pickle
        # NOTE: pickle.load can execute arbitrary code; only use trusted files.
        with open(filename, mode='rb') as f:
            data = pickle.load(f, encoding='bytes')
        return data[b'data'], data[b'labels']

    def _shuffle_data(self):
        """Apply one random permutation to the data and the labels together."""
        index = np.random.permutation(self.num_examples)
        self._data = self._data[index]
        self._label = self._label[index]

    def next_batch(self, batch_size):
        """Return the next ``(batch_data, batch_label)`` pair.

        Batches are consecutive slices of the stored examples. When fewer than
        *batch_size* examples are left, the remainder is completed with
        examples from the start of the next pass; the data is reshuffled
        before that pass when shuffling is enabled.

        Args:
            batch_size: number of examples to return; must be between 1 and
                ``num_examples`` inclusive.

        Raises:
            ValueError: if *batch_size* is outside that range. A single
                wrap-around cannot supply more than ``num_examples`` examples.
        """
        if batch_size <= 0:
            raise ValueError('batch_size must be positive, got %r'
                             % (batch_size,))
        if batch_size > self.num_examples:
            raise ValueError('batch_size %r is larger than the %d available '
                             'examples' % (batch_size, self.num_examples))

        start = self._indicator
        end = start + batch_size
        if end <= self.num_examples:
            self._indicator = end
            return self._data[start:end], self._label[start:end]

        # Not enough examples left: keep the tail of the current pass ...
        data_rest_part = self._data[start:]
        label_rest_part = self._label[start:]
        # (shuffling rebinds self._data / self._label to new arrays, so the
        # tail slices taken above are unaffected)
        if self._need_shuffle:
            self._shuffle_data()
        # ... and complete the batch from the head of the next pass.
        missing = batch_size - (self.num_examples - start)
        self._indicator = missing
        batch_data = np.concatenate(
            (data_rest_part, self._data[:missing]), axis=0)
        batch_label = np.concatenate(
            (label_rest_part, self._label[:missing]), axis=0)
        return batch_data, batch_label

In [None]:
# Build an unshuffled loader over the test batch (only labels 0 and 1 are
# kept by CifarDate) and display one batch of 32 examples as the cell output.
# NOTE: this rebinds `test_data`, which previously held the raw batch dict.
test_data = CifarDate(test_file, False)
test_data.next_batch(32)

In [None]:
# Build a shuffled loader over all training batches (only labels 0 and 1 are
# kept by CifarDate) and display one batch of 32 examples as the cell output.
# This call advances the loader, so training starts from the second batch.
train_data = CifarDate(train_files, True)
train_data.next_batch(32)

#### Draw the graph

In [None]:
# Inputs: a batch of flattened images (3072 = 3 * 32 * 32 values each) and
# the matching integer class labels.
x = tf.placeholder(tf.float32, shape=(None, 3072))
# `(None)` is just `None`, i.e. a completely unknown shape; `[None]` declares
# the intended rank-1 vector of labels with an unknown batch size.
y = tf.placeholder(tf.int64, shape=[None])

# A single neuron: one weight per input value plus a bias.
w = tf.get_variable('w', shape=[x.get_shape()[-1], 1],
                    initializer=tf.random_normal_initializer(0, 1))
b = tf.get_variable('b', shape=[1],
                    initializer=tf.initializers.constant(0.0))

# Logit and, through the sigmoid, the predicted probability of class 1.
y_ = tf.matmul(x, w) + b
p_y_1 = tf.math.sigmoid(y_)

# Labels as a float column vector so they line up with p_y_1.
y_reshape = tf.reshape(y, [-1, 1])
y_reshape_float = tf.cast(y_reshape, tf.float32)

# Mean squared error between the label and the predicted probability.
loss = tf.reduce_mean(tf.square(y_reshape_float - p_y_1))

# Predict class 1 when the probability exceeds 0.5; accuracy is the fraction
# of predictions in the batch that match the labels.
predict = p_y_1 > 0.5
correct_prediction = tf.equal(tf.cast(predict, tf.int64), y_reshape)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float64))

with tf.name_scope('train_op'):
    train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)

In [None]:
# Variable initialiser and training hyper-parameters.
init = tf.initializers.global_variables()
batch_size = 20
train_steps = 100000

# Log which device each op is placed on when the session is created.
session_config = tf.ConfigProto(log_device_placement=True)
# Tensors/ops evaluated on every step; train_op applies the update.
fetches = [loss, accuracy, train_op]

with tf.Session(config=session_config) as sess:
    sess.run(init)
    for i in range(train_steps):
        batch_data, batch_label = train_data.next_batch(batch_size)
        # NOTE(review): the batch is fed exactly as loaded, with no scaling;
        # confirm the value range suits the sigmoid unit.
        feed = {x: batch_data, y: batch_label}
        loss_val, acc_val, _ = sess.run(fetches, feed_dict=feed)
        # Report the metrics of the current batch every 500 steps.
        if (i+1) % 500 != 0:
            continue
        print('[Train]: Step: %d, loss: %4.5f, acc: %4.5f'
              % (i+1, loss_val, acc_val))