# Load the original training dataset to be split into partitions.

Download the dataset in TFRecord format.

In [1]:
!gdown https://drive.google.com/uc?id=1wD3vKqKEFh6OfrfLNtOENF-lbe4auQDb

Downloading...
From: https://drive.google.com/uc?id=1wD3vKqKEFh6OfrfLNtOENF-lbe4auQDb
To: /content/train_tfrecords0.record
100% 2.94G/2.94G [00:21<00:00, 138MB/s]


Define a set of loading and decoding functions.

In [2]:
import tensorflow as tf
import numpy as np

def load_tf_records(filepath):
    """Open every TFRecord file matching ``filepath`` as a single dataset.

    Args:
        filepath: Path or glob pattern, expanded with ``tf.io.gfile.glob``.

    Returns:
        A ``tf.data.TFRecordDataset`` over the matched files, created with
        ``num_parallel_reads=tf.data.AUTOTUNE``. Its elements are still
        serialized records.

    Raises:
        FileNotFoundError: If no file matches ``filepath``.
    """
    filenames = tf.io.gfile.glob(filepath)
    if not filenames:
        # Fail loudly here rather than handing an empty file list to
        # TFRecordDataset and discovering the problem further downstream.
        raise FileNotFoundError(
            'No TFRecord files match: {}'.format(filepath))
    return tf.data.TFRecordDataset(
        filenames, num_parallel_reads=tf.data.AUTOTUNE)

def tf_records_file_features_description():
    """Return the parsing spec for one serialized example.

    Every entry is a scalar ``tf.io.FixedLenFeature`` (shape ``[]``): the
    image dimensions as ``tf.int64``, the image payload and file name as
    ``tf.string``, and the four labels (P, K, Mg, Ph) as ``tf.float32``.

    Returns:
        dict mapping feature key to its ``tf.io.FixedLenFeature``.
    """
    dtype_by_key = (
        ('image/height', tf.int64),
        ('image/width', tf.int64),
        ('image', tf.string),
        ('image/filename', tf.string),
        ('label/P', tf.float32),
        ('label/K', tf.float32),
        ('label/Mg', tf.float32),
        ('label/Ph', tf.float32),
    )
    return {
        key: tf.io.FixedLenFeature([], dtype) for key, dtype in dtype_by_key
    }

def decode_dataset(example_proto):
    """Decode one serialized example into ``(image, label, height, width)``.

    Args:
        example_proto: Serialized example, parsed with the spec returned by
            ``tf_records_file_features_description``.

    Returns:
        Tuple ``(image, label, height, width)``: ``image`` is the stored byte
        string decoded as ``tf.int16`` and reshaped to
        ``[height, width, 150]``; ``label`` is the list ``[P, K, Mg, Ph]``;
        ``height`` and ``width`` are the parsed image dimensions.
    """
    parsed = tf.io.parse_single_example(
        example_proto, tf_records_file_features_description())

    height = parsed['image/height']
    width = parsed['image/width']

    # The image is stored as a raw byte string: reinterpret it as int16 and
    # restore the (height, width, 150) layout.
    pixels = tf.io.decode_raw(parsed['image'], tf.int16)
    image = tf.reshape(pixels, [height, width, 150])

    label = [
        parsed[key] for key in ('label/P', 'label/K', 'label/Mg', 'label/Ph')
    ]

    # NOTE(review): 'image/filename' is required by the parsing spec but is
    # not part of the returned tuple, so the file name is lost from here on.
    return image, label, height, width

Load training dataset.

In [3]:
# Location of the TFRecord file produced by the download cell.
dataset_tf_records_path = '/content/train_tfrecords0.record'

# Read the serialized records first, then decode each one into
# (image, label, height, width) with parallel map calls.
raw_records = load_tf_records(dataset_tf_records_path)
dataset = raw_records.map(decode_dataset, num_parallel_calls=tf.data.AUTOTUNE)

# Shuffle and split training dataset into 5 partitions.

In [4]:
num_images = 1732
num_splits = 5

# Shuffle once with a fixed seed (the buffer spans the whole dataset) and cache
# the shuffled result.
dataset = dataset.shuffle(num_images, seed=958).cache()

# One complete pass fills the cache, so every later skip()/take() view reads
# the same order and the partitions cannot overlap. Counting during that pass
# also verifies that num_images matches the data.
num_cached = sum(1 for _ in dataset)
assert num_cached == num_images, (
    'Expected {} examples but the dataset yielded {}'.format(
        num_images, num_cached))

# Spread the examples as evenly as possible; the remainder goes to the last
# partitions, giving sizes 346, 346, 346, 347, 347 for 1732 examples.
base_size, remainder = divmod(num_images, num_splits)
split_sizes = [
    base_size + (1 if k >= num_splits - remainder else 0)
    for k in range(num_splits)
]

splits = []
split_start = 0
for split_size in split_sizes:
    splits.append(dataset.skip(split_start).take(split_size))
    split_start += split_size

# Keep the individual names available alongside the list.
split_1, split_2, split_3, split_4, split_5 = splits

# Save newly created partitions as TFRecord files.

Define a set of functions required to save the new partitions as TFRecord files.

In [5]:
#Define some utilities

def int64_feature(value):
  """Wrap a single integer in a ``tf.train.Feature`` holding an int64 list."""
  int64_list = tf.train.Int64List(value=[value])
  return tf.train.Feature(int64_list=int64_list)


def int64_list_feature(value):
  """Wrap a sequence of integers in a ``tf.train.Feature`` holding an int64 list."""
  int64_list = tf.train.Int64List(value=value)
  return tf.train.Feature(int64_list=int64_list)


def bytes_feature(value):
  """Wrap a single byte string in a ``tf.train.Feature`` holding a bytes list."""
  bytes_list = tf.train.BytesList(value=[value])
  return tf.train.Feature(bytes_list=bytes_list)


def bytes_list_feature(value):
  """Wrap a sequence of byte strings in a ``tf.train.Feature`` holding a bytes list."""
  bytes_list = tf.train.BytesList(value=value)
  return tf.train.Feature(bytes_list=bytes_list)


def float_feature(value):
  """Wrap a single float in a ``tf.train.Feature`` holding a float list."""
  float_list = tf.train.FloatList(value=[value])
  return tf.train.Feature(float_list=float_list)


def float_list_feature(value):
  """Wrap a sequence of floats in a ``tf.train.Feature`` holding a float list."""
  float_list = tf.train.FloatList(value=value)
  return tf.train.Feature(float_list=float_list)

In [6]:
#Define the encoding of the resulting tfrecords file

def create_tf_example(image, label, height, width, filename=b''):
    """Encode one decoded sample as a ``tf.train.Example``.

    The example carries every key listed in
    ``tf_records_file_features_description``, including 'image/filename', so
    the written records can be parsed again with that same spec.

    Args:
        image: Object exposing ``.numpy()``; the raw bytes of the resulting
            array are stored under 'image'.
        label: Indexable holding four values in the order ``[P, K, Mg, Ph]``.
        height: Image height, stored under 'image/height'.
        width: Image width, stored under 'image/width'.
        filename: Optional source file name as bytes, str, or an object
            exposing ``.numpy()``. Defaults to an empty byte string.

    Returns:
        The assembled ``tf.train.Example``.
    """
    image_bytes = image.numpy().tobytes()

    # Normalize the file name to bytes, whatever form the caller provides.
    if hasattr(filename, 'numpy'):
        filename = filename.numpy()
    if isinstance(filename, str):
        filename = filename.encode('utf-8')

    P, K, Mg, Ph = label[0], label[1], label[2], label[3]

    feature = {
        'image/height': int64_feature(height),
        'image/width': int64_feature(width),
        'image': bytes_feature(image_bytes),
        # Always written: the parsing spec declares this key as a required
        # scalar string feature.
        'image/filename': bytes_feature(filename),

        'label/P': float_feature(P),
        'label/K': float_feature(K),
        'label/Mg': float_feature(Mg),
        'label/Ph': float_feature(Ph),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))

In [7]:
def save_tfrecord_file_from_ds(dataset, output_path):
    """Serialize every element of ``dataset`` into one TFRecord file.

    Args:
        dataset: Iterable yielding ``(image, label, height, width)`` tuples;
            each tuple is encoded with ``create_tf_example``.
        output_path: Destination path of the TFRecord file.
    """
    # The context manager closes the writer even when encoding or writing an
    # element raises, so the file handle is never left open.
    with tf.io.TFRecordWriter(output_path) as writer:
        for image, label, height, width in dataset:
            tf_example = create_tf_example(image, label, height, width)
            writer.write(tf_example.SerializeToString())

    print('Successfully created the TFRecord file: {}'.format(output_path))

Save TFRecord files to hosted runtime path.

In [8]:
# Destination template; '{}' is replaced by the 1-based partition number.
output_path = '/content/split_{}.record'

for i, data in enumerate(splits, start=1):
    save_tfrecord_file_from_ds(data, output_path.format(i))

Successfully created the TFRecord file: /content/split_1.record
Successfully created the TFRecord file: /content/split_2.record
Successfully created the TFRecord file: /content/split_3.record
Successfully created the TFRecord file: /content/split_4.record
Successfully created the TFRecord file: /content/split_5.record
