In [1]:
import tensorflow as tf

In [2]:
import glob

In [3]:
from itertools import groupby
from collections import defaultdict

In [4]:
# Collect every JPEG path inside the breed directories (folders whose name starts with "n02").
# NOTE: glob makes no promise about the order of the returned paths.
image_filenames = glob.glob("./imagenet-dogs/n02*/*.jpg")
# Preview the first two matches.
image_filenames[0:2]

['./imagenet-dogs/n02097658-silky_terrier/n02097658_26.jpg',
 './imagenet-dogs/n02097658-silky_terrier/n02097658_4869.jpg']

In [5]:
# Breed name -> list of image paths. defaultdict(list) lets the split loop
# append to a breed's list without creating the key first.
training_dataset = defaultdict(list)
testing_dataset = defaultdict(list)

In [6]:
# Build (breed, path) pairs. The breed is the third "/"-separated component of
# the path, i.e. the name of the directory that holds the image.
image_filename_with_breed = [
    (path.split("/")[2], path)
    for path in image_filenames
]

In [7]:
# Split each breed's images: every 5th image (index 0, 5, 10, ...) goes to the
# testing set, the rest go to the training set.
# NOTE: groupby only merges *consecutive* items that share a key, so this relies
# on image_filename_with_breed keeping each breed's files next to each other.
for dog_breed, breed_images in groupby(image_filename_with_breed, lambda x: x[0]):
    for i, breed_image in enumerate(breed_images):
        if i % 5 == 0:
            testing_dataset[dog_breed].append(breed_image[1])
        else:
            training_dataset[dog_breed].append(breed_image[1])
    breed_training_count = len(training_dataset[dog_breed])
    breed_testing_count = len(testing_dataset[dog_breed])

    # Sanity check on the *testing* share of this breed. The message talks about
    # testing images, so the ratio must be computed from the testing count
    # (the previous expression measured the training share instead).
    assert round(breed_testing_count / (breed_training_count + breed_testing_count), 2) > 0.18, "Not enough testing images."

In [8]:
# Notebook-wide session; write_records_file uses it to evaluate the image tensors.
sess = tf.Session()

In [10]:
def write_records_file(dataset,record_location):
    """
    Serialize a dataset of images into TFRecord shard files.

    Each image is read, decoded as JPEG, converted to grayscale, resized,
    cast to uint8 and stored with its breed label in a tf.train.Example
    (features 'label' and 'image', both byte strings).

    A new shard is opened every 100 images. Shards are named
    "{record_location}-{current_index}.tfrecords", where current_index is the
    number of images visited before the shard was opened.

    Args:
        dataset: Mapping of breed name -> iterable of image file paths.
        record_location: Path prefix used for the shard file names.

    Relies on the notebook-level `sess` (tf.Session) to evaluate tensors.
    Images whose evaluation fails are reported with print() and skipped.
    """
    writer = None

    # Counts every image visited (skipped ones included); drives shard rollover.
    current_index = 0
    try:
        for breed, images_filenames in dataset.items():
            # The label is the same for every image of the breed.
            image_label = breed.encode("utf-8")
            for image_filename in images_filenames:
                # Only the shard rollover belongs under this condition. The
                # per-image work below must run for *every* image, not just for
                # the first one of each shard.
                if current_index % 100 == 0:
                    if writer:
                        writer.close()
                    record_filename = "{record_location}-{current_index}.tfrecords".format(record_location=record_location,current_index=current_index)
                    writer = tf.python_io.TFRecordWriter(record_filename)
                current_index += 1

                # NOTE: these calls add new ops to the default graph for each image.
                image_file = tf.read_file(image_filename)
                image = tf.image.decode_jpeg(image_file)
                grayscale_image = tf.image.rgb_to_grayscale(image)
                # NOTE(review): positional (height, width) form; some TensorFlow
                # releases expect a single size argument - confirm against the
                # installed version.
                resized_image = tf.image.resize_images(grayscale_image,250,151)
                try:
                    # Reading and decoding actually happen here, when the graph
                    # is evaluated, so this is where a bad file raises.
                    image_bytes = sess.run(tf.cast(resized_image,tf.uint8)).tobytes()
                except tf.errors.OpError:
                    print(image_filename)
                    continue

                example = tf.train.Example(features=tf.train.Features(feature={'label':tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_label])),'image':tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_bytes]))}))
                writer.write(example.SerializeToString())
    finally:
        # writer stays None for an empty dataset; also close on errors so the
        # last shard is not left open.
        if writer:
            writer.close()

In [None]:
# Write the testing split; shard files are named "./output-<index>.tfrecords".
write_records_file(testing_dataset,"./output")