Retraining (fine-tuning) Current CNN Architectures
==================================================

The purpose of the script in this section is to download the CIFAR-10 data and sort the images into the folder structure required by the TensorFlow fine-tuning tutorial.  The script creates the following folder structure (inside the `temp` data directory).

```
-train_dir
  |--airplane
  |--automobile
  |--bird
  |--cat
  |--deer
  |--dog
  |--frog
  |--horse
  |--ship
  |--truck
-validation_dir
  |--airplane
  |--automobile
  |--bird
  |--cat
  |--deer
  |--dog
  |--frog
  |--horse
  |--ship
  |--truck
```

After this is done, we proceed with the [TensorFlow fine-tuning tutorial](https://github.com/tensorflow/models/tree/master/inception).

In [2]:
# Download/Saving CIFAR-10 images in Inception format
# ---------------------------------------
#
# In this script, we download the CIFAR-10 images and
# transform/save them in the Inception Retraining Format
#
# The end purpose of the files is for retraining the
# Inception model on the CIFAR-10 images.
import os
import tarfile
import _pickle as cPickle
import numpy as np
import urllib.request
import scipy.misc
cifar_link = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
data_dir = 'temp'
# exist_ok keeps this cell re-runnable when the directory already exists.
os.makedirs(data_dir, exist_ok=True)

# Download tar file (skipped when a previous run already fetched it)
target_file = os.path.join(data_dir,"cifar-10-python.tar.gz")
if not os.path.isfile(target_file):
    print('CIFAR-10 file not found. Downloading CIFAR data (Size = 163MB)')
    print('This may take a few minutes, please wait.')
    # Download to a side file and rename only once the transfer finished, so
    # an interrupted download is never mistaken for a complete archive on the
    # next run.
    partial_file = target_file + '.part'
    urllib.request.urlretrieve(cifar_link, partial_file)
    os.replace(partial_file, target_file)

# Extract the archive to disk under data_dir. The context manager closes the
# archive even if extraction fails part-way.
with tarfile.open(target_file) as tar:
    tar.extractall(path=data_dir)

# Class names; a label's integer value is used as an index into this list.
objects = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

# Create one sub-folder per class for both the train and the test split.
# Each class folder is created individually with exist_ok=True, so a tree
# that was only partially created by an earlier run is completed here.
train_folder = 'train_dir'
test_folder = 'validation_dir'
for split_folder in (train_folder, test_folder):
    for class_name in objects:
        folder = os.path.join(data_dir, split_folder, class_name)
        os.makedirs(folder, exist_ok=True)

# Extract images accordingly
data_location = os.path.join(data_dir,'cifar-10-batches-py')
train_names = ['data_batch_' + str(x) for x in range(1,6)]
test_names = ['test_batch']
print(train_names,test_names)

CIFAR-10 file not found. Downloading CIFAR data (Size = 163MB)
This may take a few minutes, please wait.
['data_batch_1', 'data_batch_2', 'data_batch_3', 'data_batch_4', 'data_batch_5'] ['test_batch']


In [4]:
def load_batch_from_file(file):
    """Load one pickled batch file and return the object stored in it.

    Args:
        file: Path of the pickled batch file to read.

    Returns:
        The unpickled Python object. The callers in this notebook index it
        with the keys 'labels', 'filenames' and 'data'.

    Raises:
        OSError: If the file cannot be opened.
        Exception: Whatever the unpickler raises for a corrupt file.
    """
    # NOTE(security): unpickling can execute arbitrary code. Only call this on
    # archives obtained from a trusted source.
    with open(file, "rb") as file_conn:
        # Deserialize the file contents into a Python object. encoding='latin1'
        # controls how 8-bit strings written by a Python 2 pickler are decoded
        # (presumably the batches were pickled under Python 2 -- confirm).
        image_dictionary = cPickle.load(file_conn, encoding='latin1')
    return image_dictionary

def save_image_from_dict(image_dict,folder='data_dir'):
    """Save every image of a batch to disk, grouped into per-class folders.

    Each image is written to ``<data_dir>/<folder>/<class name>/<filename>``,
    where ``data_dir`` and ``objects`` are module-level globals of this
    notebook and the class name is ``objects[label]``.

    Args:
        image_dict: Batch with index-aligned 'labels', 'filenames' and 'data'
            entries. Each 'data' row must hold 3*32*32 values and support
            ``reshape`` (i.e. be a NumPy array).
        folder: Sub-folder of ``data_dir`` that receives the class folders,
            e.g. the train or validation folder name.
    """
    # image_dict.keys() = 'labels', 'filenames', 'data', 'batch_label'
    # Make sure the destination class folders exist before writing into them.
    for class_name in objects:
        os.makedirs(os.path.join(data_dir, folder, class_name), exist_ok=True)

    for ix, label in enumerate(image_dict['labels']):
        folder_path = os.path.join(data_dir, folder, objects[label])
        filename = image_dict['filenames'][ix]

        # Per the CIFAR-10 dataset description, a row stores the red, green
        # and blue planes one after another, each plane in row-major order,
        # so reshaping gives (channel, row, column). reshape returns a new
        # view instead of resizing the batch row in place.
        image_chw = image_dict['data'][ix].reshape(3, 32, 32)

        # Move the channel axis last to get (row, column, channel). A bare
        # transpose() would reverse all axes to (column, row, channel) and
        # save every picture mirrored across its diagonal.
        image_hwc = image_chw.transpose(1, 2, 0)

        # Save image
        # NOTE(review): scipy.misc.imsave was deprecated in SciPy 1.0 and
        # removed in SciPy 1.2; this call needs an older SciPy (with Pillow)
        # or a switch to an image-writing library such as imageio.
        output_location = os.path.join(folder_path, filename)
        scipy.misc.imsave(output_location, image_hwc)

# Sort train images: unpack each training batch into the per-class train folders
train_batch_paths = [
    os.path.join(data_dir, 'cifar-10-batches-py', batch_name)
    for batch_name in train_names
]
for batch_name, batch_path in zip(train_names, train_batch_paths):
    print('Saving images from file: {}'.format(batch_name))
    # Load the pickled batch, then write its images out one by one.
    train_batch = load_batch_from_file(batch_path)
    save_image_from_dict(train_batch, folder=train_folder)



Saving images from file: data_batch_1
Saving images from file: data_batch_2
Saving images from file: data_batch_3
Saving images from file: data_batch_4
Saving images from file: data_batch_5


In [5]:
# Sort test images: unpack each test batch into the per-class validation folders
test_batch_paths = [
    os.path.join(data_dir, 'cifar-10-batches-py', batch_name)
    for batch_name in test_names
]
for batch_name, batch_path in zip(test_names, test_batch_paths):
    print("saving images from file:{}".format(batch_name))
    # Load the pickled batch, then write its images out one by one.
    test_batch = load_batch_from_file(batch_path)
    save_image_from_dict(test_batch, folder=test_folder)

saving images from file:test_batch


In [7]:
# Create labels file: one class name per line, in the order of `objects`
cifar_labels_file = os.path.join(data_dir, 'cifar10_labels.txt')
print('Writing labels file, {}'.format(cifar_labels_file))
label_lines = ["{}\n".format(class_name) for class_name in objects]
with open(cifar_labels_file,'w') as label_file:
    label_file.writelines(label_lines)

Writing labels file, temp\cifar10_labels.txt
