In [None]:
import tensorflow as tf
import os 
import numpy as np

# Parameters

In [None]:
# Dataset Constants
# Names of the dataset splits.
# NOTE(review): not referenced by the code visible in this notebook section, and
# "training" differs from TRAIN_DIR ("train") — confirm which spelling is intended.
DATASET_SPLIT = ["training", "validation", "test"]
# Root directory of the dataset; the per-split directories below are joined onto it.
DATASET_PATH = "dataset"
TRAIN_DIR = "train"
VAL_DIR = "validation"
TEST_DIR = "test"

# Sub-directory names joined onto each split directory to locate images and labels.
DATA_DIR = "data"
LABEL_DIR = "label"

# File extension (without the leading dot) used when globbing for image files.
IMG_EXT = "png"

# Helper Functions

In [None]:
def image_feature(value):
    """Wrap an array-like value as a single-element bytes ``tf.train.Feature``.

    The value is serialized with ``serialize_array`` (``tf.io.serialize_tensor``),
    so readers must decode it with ``tf.io.parse_tensor`` and the matching dtype.

    Args:
        value: Tensor or array-like accepted by ``tf.io.serialize_tensor``.

    Returns:
        A ``tf.train.Feature`` whose ``bytes_list`` holds the serialized tensor.
    """
    # serialize_tensor returns an eager string tensor, which BytesList cannot
    # accept directly; bytes_feature unwraps it to raw bytes via .numpy().
    return bytes_feature(serialize_array(value))

def bytes_feature(value):
    """Return a single-element bytes_list Feature from a string / bytes value.

    Args:
        value: ``bytes``, ``str`` (encoded as UTF-8), or an eager string tensor.

    Returns:
        A ``tf.train.Feature`` whose ``bytes_list`` contains the one value.
    """
    if isinstance(value, type(tf.constant(0))):
        # BytesList won't unpack a string from an EagerTensor.
        value = value.numpy()
    if isinstance(value, str):
        # BytesList only accepts bytes; text such as file paths must be encoded.
        value = value.encode("utf-8")
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def float_feature(value):
    """Wrap a single float / double in a one-element float_list Feature."""
    wrapped = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=wrapped)


def int64_feature(value):
    """Wrap a single bool / enum / int / uint in a one-element int64_list Feature."""
    wrapped = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=wrapped)


def float_feature_list(value):
    """Wrap an iterable of floats / doubles in a single float_list Feature.

    Unlike ``float_feature``, ``value`` is passed through as the whole list
    rather than being wrapped as one element.
    """
    wrapped = tf.train.FloatList(value=value)
    return tf.train.Feature(float_list=wrapped)


def create_example(image, path, example):
    """Assemble a ``tf.train.Example`` from an image, its path and an annotation.

    Args:
        image: Image data, stored under "image" via ``image_feature``.
        path: Image path, stored under "path" via ``bytes_feature``.
        example: Mapping providing "area", "bbox", "category_id", "id" and
            "image_id".

    Returns:
        A ``tf.train.Example`` containing the seven features above.
    """
    fields = {
        "image": image_feature(image),
        "path": bytes_feature(path),
        "area": float_feature(example["area"]),
        "bbox": float_feature_list(example["bbox"]),
    }
    # The remaining annotation fields are all scalar integers.
    for key in ("category_id", "id", "image_id"):
        fields[key] = int64_feature(example[key])

    features = tf.train.Features(feature=fields)
    return tf.train.Example(features=features)


def parse_tfrecord_fn(example):
    """Parse one serialized ``tf.train.Example`` into a dict of tensors.

    The "image" bytes are decoded as a 3-channel JPEG and the variable-length
    "bbox" feature is densified; all other features are returned as parsed.

    NOTE(review): ``image_feature`` in this notebook stores the image with
    ``tf.io.serialize_tensor``, not as JPEG bytes — confirm that records fed to
    this parser really contain JPEG-encoded images.

    Args:
        example: A scalar string tensor holding a serialized Example.

    Returns:
        The parsed feature dict with "image" and "bbox" replaced by their
        decoded / dense forms.
    """
    schema = {
        "image": tf.io.FixedLenFeature(shape=[], dtype=tf.string),
        "path": tf.io.FixedLenFeature(shape=[], dtype=tf.string),
        "area": tf.io.FixedLenFeature(shape=[], dtype=tf.float32),
        "bbox": tf.io.VarLenFeature(dtype=tf.float32),
        "category_id": tf.io.FixedLenFeature(shape=[], dtype=tf.int64),
        "id": tf.io.FixedLenFeature(shape=[], dtype=tf.int64),
        "image_id": tf.io.FixedLenFeature(shape=[], dtype=tf.int64),
    }
    parsed = tf.io.parse_single_example(example, schema)

    encoded_image = parsed["image"]
    parsed["image"] = tf.io.decode_jpeg(encoded_image, channels=3)

    # VarLenFeature yields a SparseTensor; convert it to a regular dense tensor.
    sparse_bbox = parsed["bbox"]
    parsed["bbox"] = tf.sparse.to_dense(sparse_bbox)
    return parsed

# Non-Keras helper: serializes with plain tf.io only.
def serialize_array(array):
    """Serialize ``array`` with ``tf.io.serialize_tensor`` and return the result."""
    return tf.io.serialize_tensor(array)

In [None]:
def parse_single_image(image, label):
    """Build a ``tf.train.Example`` for one image / label pair.

    The first three entries of ``image.shape`` are recorded as height, width
    and depth; the image and label themselves are stored via ``image_feature``.

    Args:
        image: Array-like with at least three dimensions.
        label: Array-like label stored under "label/raw".

    Returns:
        A ``tf.train.Example`` wrapping the five features.
    """
    # Shape metadata comes first, in the same order as the stored keys.
    record = {}
    for axis, key in enumerate(("image/height", "image/width", "image/depth")):
        record[key] = int64_feature(image.shape[axis])

    record["image/raw_image"] = image_feature(image)
    record["label/raw"] = image_feature(label)

    features = tf.train.Features(feature=record)
    return tf.train.Example(features=features)

In [None]:
def write_images_to_tfr_long(images, labels, filename:str="large_images", max_files:int=10, out_dir:str="/content/"):
    """Write image / label pairs to one or more TFRecord shard files.

    Samples are written in order, ``max_files`` per shard, to files named
    ``{shard_number}_{shard_total}{filename}.tfrecords`` inside ``out_dir``.

    Args:
        images: Indexable, sized collection of images.
        labels: Indexable, sized collection of labels, aligned with ``images``.
        filename: Suffix used in each shard's file name.
        max_files: Maximum number of samples per shard; must be positive.
        out_dir: Directory the shards are written to (with or without a
            trailing separator).

    Returns:
        The total number of samples written.

    Raises:
        ValueError: If ``max_files`` is not positive or there are fewer labels
            than images.
    """
    if max_files < 1:
        raise ValueError(f"max_files must be a positive integer, got {max_files}")

    total = len(images)
    if len(labels) < total:
        # Fail before any shard is created instead of part-way through writing.
        raise ValueError(
            f"Got {total} images but only {len(labels)} labels; every image needs a label"
        )

    # The progress bar is cosmetic, so import it here and degrade to a plain
    # loop when tqdm is unavailable rather than failing with a NameError.
    try:
        from tqdm import tqdm as progress
    except ImportError:
        def progress(iterable):
            return iterable

    # Ceiling division: number of shards needed to hold every sample.
    splits = -(-total // max_files)
    print(f"\nUsing {splits} shard(s) for {len(images)} files, with up to {max_files} samples per shard")

    file_count = 0
    for i in progress(range(splits)):
        # os.path.join keeps the path valid whether or not out_dir ends with a separator.
        current_shard_name = os.path.join(out_dir, f"{i+1}_{splits}{filename}.tfrecords")
        first = i * max_files
        last = min(first + max_files, total)  # the final shard may be partially filled

        # The context manager closes the writer even if serialization fails.
        with tf.io.TFRecordWriter(current_shard_name) as writer:
            for index in range(first, last):
                # Create the required Example representation.
                out = parse_single_image(image=images[index], label=labels[index])
                writer.write(out.SerializeToString())
                file_count += 1

    print(f"\nWrote {file_count} elements to TFRecord")
    return file_count

# Prepare Masks

In [None]:
def prepare_masks(image, mask, class_values):
    """Convert a class-index mask into stacked per-class binary masks.

    One channel is produced per entry of ``class_values`` (1.0 where ``mask``
    equals that value, else 0.0). When more than one class is requested, an
    extra background channel is appended that is 1.0 wherever no requested
    class is present.

    Args:
        image: The image belonging to ``mask``; returned unchanged.
        mask: NumPy array of class values.
        class_values: Iterable of class values to extract.

    Returns:
        A tuple ``(image, mask)`` where ``mask`` is a float array with the
        class channels stacked on a new last axis.
    """
    # Extract certain classes from mask (e.g. cars).
    masks = [(mask == v) for v in class_values]
    mask = np.stack(masks, axis=-1).astype('float')

    # Add background if mask is not binary.
    if mask.shape[-1] != 1:
        background = 1 - mask.sum(axis=-1, keepdims=True)
        mask = np.concatenate((mask, background), axis=-1)

    # Previously the result was computed but never returned.
    return image, mask

In [None]:
# Root directory of each dataset split.
train_dir = os.path.join(DATASET_PATH, TRAIN_DIR)
val_dir = os.path.join(DATASET_PATH, VAL_DIR)
test_dir = os.path.join(DATASET_PATH, TEST_DIR)

# Image and label directories inside each split.
train_img_dir = os.path.join(train_dir, DATA_DIR)
train_label_dir = os.path.join(train_dir, LABEL_DIR)
val_img_dir = os.path.join(val_dir, DATA_DIR)
val_label_dir = os.path.join(val_dir, LABEL_DIR)
test_img_dir = os.path.join(test_dir, DATA_DIR)
test_label_dir = os.path.join(test_dir, LABEL_DIR)
# Backward-compatible alias: this path was originally named test_val_dir,
# unlike the *_label_dir names used for the train and validation splits.
test_val_dir = test_label_dir

In [None]:
# Collect every training image path whose extension matches IMG_EXT.
# NOTE(review): the ordering of glob results is not established here — sort the
# list if images are later paired with labels by position.
train_img_filenames = tf.io.gfile.glob(f"{train_img_dir}/*.{IMG_EXT}")