In [3]:
%matplotlib inline

import math
import io 
import skimage.io as skiio
import tensorflow as tf
import os
import matplotlib.image as mpimg
import PIL
import shutil
import numpy as np
import sys
from PIL import Image

In [4]:
filename='/home/kunho/dog.8893.jpg'

In [5]:
os.path.basename(filename).split('.')[0]

'dog'

In [6]:
class GenerateTFRecord:
    def __init__(self, labels):
        self.labels = labels

    def convert_image_folder(self, img_folder, tfrecord_folder, split_name, tfrecord_filename, _NUM_SHARDS):
        
        assert split_name in ['train', 'validation']
        
        # Get all file names of images present in folder
        img_paths = os.listdir(img_folder)
        img_paths = [os.path.abspath(os.path.join(img_folder, i)) for i in img_paths]
        
        
        num_per_shard = int(math.ceil(len(img_paths) / float(_NUM_SHARDS)))

        with tf.Graph().as_default():
            #image_reader = ImageReader()
            with tf.Session('') as sess:
                for shard_id in range(_NUM_SHARDS):
                    output_filename = self._get_dataset_filename(
                        tfrecord_folder, split_name, shard_id, tfrecord_filename, _NUM_SHARDS)
                    print('output_filename is %s' % (output_filename))
                    with tf.python_io.TFRecordWriter(output_filename) as writer:
                        start_ndx = shard_id * num_per_shard
                        end_ndx = min((shard_id+1) * num_per_shard, len(img_paths))
                        for i in range(start_ndx, end_ndx):
                            sys.stdout.write('\r>> Converting image %d/%d shard %d' % (
                                i+1, len(img_paths), shard_id))
                            sys.stdout.flush()
                            example = self._convert_image(img_paths[i])
                            writer.write(example.SerializeToString())

    def _convert_image(self, img_path):
        label = self._get_label_with_filename(img_path)
        img_shape = mpimg.imread(img_path).shape
        filename = os.path.basename(img_path)

        # Read image data in terms of bytes
        #with tf.gfile.FastGFile(img_path, 'rb') as fid:
        #    image_data = fid.read()
        
        img = Image.open(img_path, mode='r')
        imgByteArr = io.BytesIO()
        img.save(imgByteArr, format=img.format)
        image_data = imgByteArr.getvalue()
            
        #image_string = tf.read_file(final_path[0])
        #image = tf.image.decode_jpeg(image_string)
        #image = tf.cast(image,tf.int8)
        #image = tf.cast(image,tf.float32)
        
        example = tf.train.Example(features = tf.train.Features(feature = {
            'filename': tf.train.Feature(bytes_list = tf.train.BytesList(value = [filename.encode('utf-8')])),
            'rows': tf.train.Feature(int64_list = tf.train.Int64List(value = [img_shape[0]])),
            'cols': tf.train.Feature(int64_list = tf.train.Int64List(value = [img_shape[1]])),
            'channels': tf.train.Feature(int64_list = tf.train.Int64List(value = [img_shape[2]])),
            'image': tf.train.Feature(bytes_list = tf.train.BytesList(value = [image_data])),
            'label': tf.train.Feature(int64_list = tf.train.Int64List(value = [label])),
        }))
        return example
        
    def _get_label_with_filename(self, filename):
        basename = os.path.basename(filename).split('.')[0]
        basename = basename.split('_')[0]
        return self.labels[basename]
    
    def _get_dataset_filename(self,dataset_dir, split_name, shard_id, tfrecord_filename_input, _NUM_SHARDS):
        output_filename = '%s_%s_%05d-of-%05d.tfrecord' % (
        tfrecord_filename_input, split_name, shard_id, _NUM_SHARDS)
        return os.path.join(dataset_dir, output_filename)

In [None]:
img_folder=os.path.join(os.path.expanduser('~'),'anaconda3','examples','cats_and_dog','train')
tf_record_folder=os.path.join(os.path.expanduser('~'),'anaconda3','examples','cats_and_dog','tf_record')
if not os.path.exists('tf_record'):
    os.makedirs('tf_record')
    print("%s was created since it didn't exist!" % (tf_record_folder))
    
labels = {'cat': 0, 'dog': 10}
t = GenerateTFRecord(labels)
t.convert_image_folder( img_folder, tf_record_folder, 'train', 'image', 10)

output_filename is /home/kunho/anaconda3/examples/cats_and_dog/tf_record/image_train_00000-of-00010.tfrecord
>> Converting image 2500/25000 shard 0output_filename is /home/kunho/anaconda3/examples/cats_and_dog/tf_record/image_train_00001-of-00010.tfrecord
>> Converting image 5000/25000 shard 1output_filename is /home/kunho/anaconda3/examples/cats_and_dog/tf_record/image_train_00002-of-00010.tfrecord
>> Converting image 7500/25000 shard 2output_filename is /home/kunho/anaconda3/examples/cats_and_dog/tf_record/image_train_00003-of-00010.tfrecord
>> Converting image 10000/25000 shard 3output_filename is /home/kunho/anaconda3/examples/cats_and_dog/tf_record/image_train_00004-of-00010.tfrecord
>> Converting image 12500/25000 shard 4output_filename is /home/kunho/anaconda3/examples/cats_and_dog/tf_record/image_train_00005-of-00010.tfrecord
>> Converting image 15000/25000 shard 5output_filename is /home/kunho/anaconda3/examples/cats_and_dog/tf_record/image_train_00006-of-00010.tfrecord
>> Con

In [7]:
class TFRecordExtractor:
    def __init__(self, tfrecord_file):
        self.tfrecord_file = os.path.abspath(tfrecord_file)

    def _extract_fn(self, tfrecord):
        # Extract features using the keys set during creation
        features = {
            'filename': tf.FixedLenFeature([], tf.string),
            'rows': tf.FixedLenFeature([], tf.int64),
            'cols': tf.FixedLenFeature([], tf.int64),
            'channels': tf.FixedLenFeature([], tf.int64),
            'image': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64)
        }

        # Extract the data record
        sample = tf.parse_single_example(tfrecord, features)

        image = tf.image.decode_image(sample['image'])        
        img_shape = tf.stack([sample['rows'], sample['cols'], sample['channels']])
        label = sample['label']
        filename = sample['filename']
        return [image, label, filename, img_shape]        

    def extract_image(self):
        # Create folder to store extracted images
        folder_path = './ExtractedImages'
        shutil.rmtree(folder_path, ignore_errors = True)
        os.mkdir(folder_path)

        # Pipeline of dataset and iterator 
        dataset = tf.data.TFRecordDataset([self.tfrecord_file])
        dataset = dataset.map(self._extract_fn).batch(64)
        #dataset = dataset.map(self._extract_fn),shuffle(True).batch(32)
        #tfrecord_dataset = tfrecord_dataset.map(lambda   x:_parse_(x)).shuffle(True).batch(batch_size)
        iterator = dataset.make_one_shot_iterator()
        next_image_data = iterator.get_next()

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())

            try:
                # Keep extracting data till TFRecord is exhausted
                while True:
                    image_data = sess.run(next_image_data)

                    # Check if image shape is same after decoding
                    if not np.array_equal(image_data[0].shape, image_data[3]):
                        print('Image {} not decoded properly'.format(image_data[2]))
                        continue
                        
                    save_path = os.path.abspath(os.path.join(folder_path, image_data[2].decode('utf-8')))
                    #image_data[0] = tf.cast(image_data[0],tf.int8)
                    mpimg.imsave(save_path, image_data[0])
                    print('Save path = ', save_path, ', Label = ', image_data[1])
            except:
                pass


In [8]:
tf_record_folder=os.path.join(os.path.expanduser('~'),'anaconda3','examples','cats_and_dog','tf_record')

onlyfiles = [f for f in os.listdir(tf_record_folder) if os.path.isfile(os.path.join(tf_record_folder, f))]
print(onlyfiles)
for file in onlyfiles:
    t = TFRecordExtractor(os.path.join(tf_record_folder,file))
    t.extract_image()


['image_train_00000-of-00010.tfrecord', 'image_train_00001-of-00010.tfrecord', 'image_train_00002-of-00010.tfrecord', 'image_train_00003-of-00010.tfrecord', 'image_train_00004-of-00010.tfrecord', 'image_train_00005-of-00010.tfrecord', 'image_train_00006-of-00010.tfrecord', 'image_train_00007-of-00010.tfrecord', 'image_train_00008-of-00010.tfrecord', 'image_train_00009-of-00010.tfrecord']
Instructions for updating:
Colocations handled automatically by placer.


In [None]:
    t = TFRecordExtractor('./images_byte.tfrecord')
    t.extract_image()

In [None]:
cat_img = skiio.imread(os.path.join(os.path.expanduser('~'),'anaconda3','examples','cats_and_dog','train','cat.1000.jpg'))
skiio.imshow(cat_img)

In [None]:
cat_img = skiio.imread(os.path.join(os.path.expanduser('~'),'anaconda3','examples','cats_and_dog','ExtractedImages','cat.1000.jpg'))
skiio.imshow(cat_img)