In [None]:
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [None]:
import os

from dotenv import load_dotenv

# Pull variables from the .env file (path resolved against the current working directory)
# into the process environment.
load_dotenv('.env')

# GitHub access settings; `token` is None when GITHUB_TOKEN is not defined.
token = os.environ.get("GITHUB_TOKEN")
username, repo = 'alessandroardenghi', 'SemanticSegmentation'

Data Preprocessing

In [68]:
import random
import tensorflow as tf

SEED = 1
DATA_DIR = 'data'
n_elements = 5000   # number of .tif tiles sampled from DATA_DIR
n_folds = 5         # number of cross-validation folds

random.seed(SEED)

# os.listdir returns entries in an arbitrary, filesystem-dependent order. Sort first so the
# seeded sample below picks the same files on every machine.
datapoints = sorted(
    f'{DATA_DIR}/{element}' for element in os.listdir(DATA_DIR) if element.endswith('.tif')
)
print(f'Example Datapoints:\n{datapoints[:3]}')

if len(datapoints) < n_elements:
    raise ValueError(
        f'Requested {n_elements} datapoints but only {len(datapoints)} .tif files '
        f'were found in {DATA_DIR!r}'
    )

# random.sample already returns the picks in random selection order, so no extra shuffle is needed.
selected_datapoints = random.sample(datapoints, n_elements)

# Retained in case later cells reference it; no longer used to build the folds.
fold_size = n_elements // n_folds
# Strided slicing puts every selected datapoint into exactly one fold. Fold sizes differ by at
# most one when n_elements is not a multiple of n_folds, instead of the remainder being dropped.
folds = [selected_datapoints[i::n_folds] for i in range(n_folds)]

# Sanity check: the folds together cover the whole selection.
assert sum(len(fold) for fold in folds) == n_elements
for i, fold in enumerate(folds):
    print(f'n_datapoints in fold {i} = {len(fold)}')

Example Datapoints:
['data/0000000224-0000028448.tif', 'data/0000000224-0000028672.tif', 'data/0000000224-0000028896.tif']
n_datapoints in fold 0 = 1000
n_datapoints in fold 1 = 1000
n_datapoints in fold 2 = 1000
n_datapoints in fold 3 = 1000
n_datapoints in fold 4 = 1000


In [69]:
# One output TFRecord path per fold. Derive the count from n_folds, not a literal, so the
# writer loop stays in sync with the fold split.
TFRecord_filenames = [f'TFRecord_dir/TFRecord_fold{i}' for i in range(n_folds)]

In [70]:
def _bytes_feature(value):
    """Wrap a single bytes/string value in a ``tf.train.Feature`` holding a BytesList.

    Adapted from the TensorFlow official documentation.

    Args:
        value: a bytes object, or an EagerTensor whose ``.numpy()`` yields the bytes.

    Returns:
        tf.train.Feature with a one-element ``bytes_list``.
    """
    # BytesList cannot unpack a string directly from an EagerTensor, so convert it first.
    raw = value.numpy() if isinstance(value, type(tf.constant(0))) else value
    byte_list = tf.train.BytesList(value=[raw])
    return tf.train.Feature(bytes_list=byte_list)

def serialize_datapoint(datapoint_filename):
    """Serialize one file on disk into a ``tf.train.Example`` byte string.

    The Example has two bytes features:
      - ``'image_raw'``: the file's contents, read verbatim in binary mode (not decoded);
      - ``'file_name'``: the file's basename, UTF-8 encoded.

    Args:
        datapoint_filename: path of the file to serialize.

    Returns:
        bytes: the serialized ``tf.train.Example``.
    """
    with open(datapoint_filename, 'rb') as handle:
        raw_image = handle.read()
    name_bytes = os.path.basename(datapoint_filename).encode('utf-8')

    features = tf.train.Features(feature={
        'image_raw': _bytes_feature(raw_image),
        'file_name': _bytes_feature(name_bytes),
    })
    return tf.train.Example(features=features).SerializeToString()

In [71]:
# Write each fold to its own TFRecord file.
# Open ONE writer per fold and let the `with` block close (and flush) it. Each
# TFRecordWriter opens its target for writing, so constructing a new writer per datapoint
# would leave each file holding at most the last record.
assert len(TFRecord_filenames) == len(folds), 'expected exactly one TFRecord filename per fold'
for tfr_name, fold in zip(TFRecord_filenames, folds):
    # The writer does not create missing parent directories.
    os.makedirs(os.path.dirname(tfr_name) or '.', exist_ok=True)
    with tf.io.TFRecordWriter(tfr_name) as writer:
        for datapoint in fold:
            writer.write(serialize_datapoint(datapoint))