In [1]:
import os

if not os.path.exists('CIFAR10_data'):
    
    !wget https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
    !mkdir CIFAR10_data
    !tar -xf cifar-10-python.tar.gz -C CIFAR10_data

from tqdm import tqdm

import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np

from models import CIFAR_CNN
from trainer import Trainer
from utils import unpickle

datadir = 'CIFAR10_data/cifar-10-batches-py/'
batches_train = sorted([datadir + batch for batch in os.listdir(datadir) if 'data_batch' in batch], key=lambda x: int(x[-1]))
batch_test = datadir + 'test_batch'

logdir = 'tf_logs/standard/'

In [2]:
def _to_nhwc(flat_images):
    """Convert flat CIFAR rows into float32 images of shape (N, 32, 32, 3).

    Each row is reshaped to (3, 32, 32) (channel-first, as the original reshape
    assumed) and then transposed to channel-last.
    """
    images = np.reshape(flat_images.astype(np.float32), [-1, 3, 32, 32])
    return np.transpose(images, [0, 2, 3, 1])


# Collect per-batch arrays and concatenate once; concatenating inside the loop
# re-copies the growing array on every iteration.
train_images, train_labels = [], []
for batch_path in tqdm(batches_train):
    batch = unpickle(batch_path)
    train_images.append(_to_nhwc(batch[b'data']))
    train_labels.append(np.asarray(batch[b'labels']))

cifar = np.concatenate(train_images, axis=0)
labels = np.concatenate(train_labels, axis=0)

test_batch = unpickle(batch_test)
# Cast the test set to float32 as well; previously it was scaled from uint8,
# which produced float64 and made train/test dtypes inconsistent.
cifar_test = _to_nhwc(test_batch[b'data'])
labels_test = np.array(test_batch[b'labels'])

assert cifar.shape[0] == labels.shape[0], 'train images/labels count mismatch'
assert cifar_test.shape[0] == labels_test.shape[0], 'test images/labels count mismatch'
assert cifar.dtype == cifar_test.dtype == np.float32

# Scale pixel values from [0, 255] to [-1, 1].
data_train = (cifar / 127.5 - 1.0, labels)
data_test = (cifar_test / 127.5 - 1.0, labels_test)

100%|██████████| 5/5 [00:02<00:00,  2.17it/s]


In [3]:
# Training schedule. P_EPOCHS is forwarded as `p_epochs`; judging by the
# captured output it sets the logging interval (reports every 5 epochs) — TODO
# confirm against Trainer.train.
N_EPOCHS = 20
P_EPOCHS = 5

tf.reset_default_graph()

DNN = CIFAR_CNN(logdir, 'CNN')

sess = tf.InteractiveSession()
try:
    sess.run(tf.global_variables_initializer())

    trainer = Trainer(sess, DNN, data_train)
    trainer.train(n_epochs=N_EPOCHS, p_epochs=P_EPOCHS)

    test_acc = DNN.evaluate(sess, data_test)
    print('Test Accuracy : {:.5f}'.format(test_acc))
finally:
    # Always release the session, even if training or evaluation raises;
    # otherwise the InteractiveSession stays open (and installed as the default
    # session) for the rest of the kernel's life.
    sess.close()

Epoch : 5   | Loss : 0.43010 | Train Accuracy : 0.84832
Epoch : 10  | Loss : 0.13459 | Train Accuracy : 0.95272
Epoch : 15  | Loss : 0.06263 | Train Accuracy : 0.97884
Epoch : 20  | Loss : 0.04240 | Train Accuracy : 0.98498
Test Accuracy : 0.73510
