## Tutorial: Creating TFRecords

[medium-post](https://towardsdatascience.com/a-practical-guide-to-tfrecords-584536bc786c)

- [x] Create TFRecords for images
- [x] Sharding
- [ ] Create TFRecords for text
- [ ] Create TFRecords for audio

In [33]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import os

## The TFRecords Converters

In [34]:
def _bytes_feature(data):
    """Wrap a string / bytes value in a ``tf.train.Feature`` backed by a BytesList.

    Eager tensors are converted to their underlying value (via ``.numpy()``)
    before being stored; any other input is stored as given.
    """
    eager_tensor_type = type(tf.constant(0))
    payload = data.numpy() if isinstance(data, eager_tensor_type) else data
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[payload]))


def _float_feature(data):
    """Wrap a single float / double in a ``tf.train.Feature`` backed by a FloatList."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=[data]))

def _int64_feature(data):
    """Wrap a single bool / enum / int / uint in a ``tf.train.Feature`` backed by an Int64List."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[data]))

def serialize_array(array):
    """Return ``array`` serialized with ``tf.io.serialize_tensor``."""
    return tf.io.serialize_tensor(array)

def deserialize_array(array, out_type=None):
    """Parse a serialized tensor back into a tensor.

    Args:
        array: Serialized tensor, as produced by ``serialize_array``.
        out_type: TensorFlow dtype of the serialized tensor. Defaults to
            ``tf.float32`` (the previous hard-coded behaviour). It must match
            the dtype the tensor was serialized with, so pass e.g.
            ``tf.int16`` for the int16 image arrays created in this file.

    Returns:
        The parsed tensor of dtype ``out_type``.
    """
    if out_type is None:
        out_type = tf.float32
    return tf.io.parse_tensor(array, out_type=out_type)


## Generate images 

In [35]:

# ## Create a random image
# image_small_shape = (250, 250, 3)
# num_small_images = 100

# images_small =np.random.randint(
#     low=0, high = 255, size = (num_small_images, *image_small_shape), dtype=np.int16
# )

# print(images_small.shape)
# #plt.imshow(images_small[0])

# ## Create a label 
# labels_small = np.random.randint(
#     low =0 , high = 5, size = (num_small_images, 1) 
# )
# print(labels_small.shape)
# labels_small[0:5]

## Parsing data

In [36]:
def parse_images(images, labels):
    """Serialize every (image, label) pair into a ``tf.train.Example`` byte string.

    Each image must expose a three-element ``shape`` (height, width, depth) and
    each label must be indexable; only ``label[0]`` is stored.

    Returns:
        A list with one serialized example per image, in input order.
    """
    serialized = []
    for position, image in enumerate(images):
        label = labels[position]

        feature_map = {
            'height': _int64_feature(image.shape[0]),
            'width': _int64_feature(image.shape[1]),
            'depth': _int64_feature(image.shape[2]),
            'label': _int64_feature(label[0]),
            'image_raw': _bytes_feature(serialize_array(image))
        }
        example = tf.train.Example(
            features=tf.train.Features(feature=feature_map)
        )
        serialized.append(example.SerializeToString())

    return serialized


def write_images_to_tfr(images, labels, filename: str = "images"):
    """Write all images and their labels into a single TFRecord file.

    Args:
        images: Sequence of images, forwarded to ``parse_images``.
        labels: Sequence of labels aligned with ``images``.
        filename: Output path without extension; ``".tfrecords"`` is appended.

    Returns:
        None. A summary line with the number of written images is printed.
    """
    filename = filename + ".tfrecords"  # add extension

    # Serialize first: if this raises, no writer has been opened yet.
    examples = parse_images(images, labels)

    # The context manager closes the writer even when a write fails.
    with tf.io.TFRecordWriter(filename) as writer:
        for example in examples:
            writer.write(example)

    print(f"Wrote {len(examples)} images to {filename}")
    return None



In [37]:
def parse_tfr_element(element):
    """Decode one serialized example into an ``(image, label)`` pair.

    The example is expected to carry int64 scalars ``height``, ``width``,
    ``depth`` and ``label`` plus a string ``image_raw`` holding a serialized
    int16 tensor, which is reshaped to ``(height, width, depth)``.
    """
    schema = {
        key: tf.io.FixedLenFeature([], tf.int64)
        for key in ('height', 'width', 'depth', 'label')
    }
    schema['image_raw'] = tf.io.FixedLenFeature([], tf.string)

    parsed = tf.io.parse_single_example(element, schema)

    image_shape = (parsed['height'], parsed['width'], parsed['depth'])
    flat_image = tf.io.parse_tensor(parsed['image_raw'], out_type=tf.int16)
    image = tf.reshape(flat_image, image_shape)
    return image, parsed['label']




In [38]:
def parse_single_image(image, label):
    """Build a ``tf.train.Example`` for one image and its label.

    Args:
        image: Array with a three-element ``shape`` (height, width, depth).
        label: Indexable label; only ``label[0]`` is stored.

    Returns:
        The (not yet serialized) ``tf.train.Example``.
    """
    # Define the dictionary -- the structure -- of our single example.
    # The image is stored under 'image_raw' (not 'raw_image') so the records
    # match the schema that parse_tfr_element reads back.
    data = {
        'height': _int64_feature(image.shape[0]),
        'width': _int64_feature(image.shape[1]),
        'depth': _int64_feature(image.shape[2]),
        'image_raw': _bytes_feature(serialize_array(image)),
        'label': _int64_feature(label[0])
    }
    # Create an Example, wrapping the single features.
    out = tf.train.Example(features=tf.train.Features(feature=data))

    return out

In [39]:
# write_images_to_tfr(images_small, labels_small, filename="images")
# data_set = tf.data.TFRecordDataset("images.tfrecords") # create dataset
# data_set = data_set.map(parse_tfr_element) # parse dataset
# data_set.batch(5) # batch dataset
# for sample in data_set.take(5):
#     image = sample[0].numpy()
#     label = sample[1].numpy()
#     print(f"Image shape: {image.shape}")
#     print(f"Label: {label}")

## SHARDING

In [40]:
# Synthetic data for the sharding demo: 500 random images of shape (500, 500, 3).
big_image_size = (500, 500, 3)
num_big_images = 500

# int16 pixel array of shape (500, 500, 500, 3): 375M values, roughly 750 MB in memory.
# NOTE(review): `high` is exclusive in np.random.randint, so values fall in
# [0, 254]; use 256 if the full 8-bit range is intended.
large_images = np.random.randint(
    low = 0, high = 255, size = (num_big_images, *big_image_size), dtype=np.int16
)

# One integer label per image, drawn from [0, 5); shape (500, 1).
large_labels = np.random.randint(
    low = 0, high = 5, size = (num_big_images, 1)
)



In [41]:
# Rough shard count at 10 images per shard.
# NOTE: the unconditional +1 over-counts by one when the number of images is
# an exact multiple of 10 (500 // 10 + 1 == 51, although 50 shards suffice).
splits = (len(large_images) // 10) + 1
splits

51

In [42]:
def write_images_to_tfr_long(images, labels, filename:str = 'Large_images', max_files:int = 10, out_dir:str= "large_TFR"):
    """Write images and labels into several TFRecord shards.

    Shards are named ``{out_dir}/{filename}_{idx}.tfrecords`` and filled in
    input order; the last shard may hold fewer images than the others.

    Args:
        images: Sequence of images; each is passed to ``parse_single_image``.
        labels: Sequence of labels aligned with ``images``.
        filename: Base name of every shard file.
        max_files: Despite its name, the maximum number of *images* written
            to each shard. Must be at least 1.
        out_dir: Output directory; created (including parents) if missing.

    Returns:
        The total number of images written across all shards.

    Raises:
        ValueError: If ``max_files`` is smaller than 1.
    """
    if max_files < 1:
        raise ValueError(f"max_files must be >= 1, got {max_files}")

    # exist_ok avoids the check-then-create race and supports nested paths.
    os.makedirs(out_dir, exist_ok=True)

    total = len(images)
    splits = -(-total // max_files)  # ceiling division: number of shards

    print(f"Writing {splits} files")
    file_count = 0

    for idx in tqdm(range(splits)):
        current_shard_name = f"{out_dir}/{filename}_{idx}.tfrecords"

        # Index range of the images that belong to this shard.
        start = idx * max_files
        stop = min(start + max_files, total)

        # The context manager closes the shard even if serialization or a
        # write raises part-way through.
        with tf.io.TFRecordWriter(current_shard_name) as writer:
            for index in range(start, stop):
                out = parse_single_image(images[index], labels[index])
                writer.write(out.SerializeToString())
                file_count += 1

    print(f"Wrote {file_count} images to {out_dir}/{filename}_*.tfrecords")
    return file_count

In [43]:
# Shard the 500 generated images. `max_files` is the number of images per
# shard, so 30 per shard yields 17 shard files (the last one partially filled).
write_images_to_tfr_long(large_images, large_labels, max_files=30)


Writing 17 files


  0%|          | 0/17 [00:00<?, ?it/s]

Wrote 500 images to large_TFR/Large_images_*.tfrecords


500