In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_addons as tfa
import tensorflow_probability as tfp
from tensorflow.keras.preprocessing.image import img_to_array, array_to_img
from random import randrange
import matplotlib.pyplot as plt
import scipy.ndimage as ndimage
import numpy as np
from numpy import expand_dims

def zoom(img, zoom_factor):
  """Zoom an image about its centre while keeping its height and width.

  Args:
    img: image of shape (H, W) or (H, W, ...). May be a tf.Tensor (e.g. when
      called through tf.py_function) or anything np.asarray accepts.
    zoom_factor: scalar zoom, as a Python/NumPy number or a 0-d tf.Tensor.
      < 1 shrinks the content and zero-pads the border; > 1 crops the centre
      and enlarges it; exactly 1 returns the input unchanged.

  Returns:
    np.ndarray with the same shape and dtype as the input image.
  """
  # Accept tensors (py_function passes EagerTensors) as well as plain arrays.
  img = img.numpy() if hasattr(img, "numpy") else np.asarray(img)
  zoom_factor = float(zoom_factor.numpy() if hasattr(zoom_factor, "numpy") else zoom_factor)
  h, w = img.shape[:2]
  # Never rescale channel (or any trailing) axes.
  trailing = (1,) * (img.ndim - 2)

  if zoom_factor < 1:
    # Zooming out: shrink the whole image, then centre it on a zero canvas.
    zh = max(1, int(np.round(h * zoom_factor)))
    zw = max(1, int(np.round(w * zoom_factor)))
    top = (h - zh) // 2
    left = (w - zw) // 2
    out = np.zeros_like(img)
    # Per-axis ratios zh/h, zw/w make ndimage.zoom return exactly (zh, zw).
    out[top:top + zh, left:left + zw] = ndimage.zoom(img, (zh / h, zw / w) + trailing)
  elif zoom_factor > 1:
    # Zooming in: crop the central region, then enlarge it back to (h, w).
    zh = max(1, int(np.round(h / zoom_factor)))
    zw = max(1, int(np.round(w / zoom_factor)))
    top = (h - zh) // 2
    left = (w - zw) // 2
    crop = img[top:top + zh, left:left + zw]
    # Using h/zh instead of zoom_factor guarantees the output is exactly (h, w);
    # the scalar factor could yield e.g. 29x29 or 27x27 from a 28x28 input.
    out = ndimage.zoom(crop, (h / zh, w / zw) + trailing)
  else:
    # No zoom.
    out = img
  return out

def rotate(img, angle):
  """Rotate an image about its centre by `angle` degrees, keeping its shape.

  Args:
    img: image of shape (H, W) or (H, W, C); a tf.Tensor (e.g. via
      tf.py_function) or anything np.asarray accepts. Rotation is applied in
      the plane of the first two axes.
    angle: rotation in degrees, as a number or a 0-d tf.Tensor.

  Returns:
    np.ndarray with the same shape as the input. Regions rotated in from
    outside the frame are filled by ndimage.rotate's default constant mode.
  """
  # Accept tensors (py_function passes EagerTensors) as well as plain arrays.
  img = img.numpy() if hasattr(img, "numpy") else np.asarray(img)
  angle = float(angle.numpy() if hasattr(angle, "numpy") else angle)
  # reshape=False keeps the output the same size as the input.
  return ndimage.rotate(img, angle, reshape=False)

def rand_scale(s):
  """Return a random multiplicative jitter factor.

  Draws `scale` uniformly between 1 and `s`, then returns either `scale` or
  `1 / scale` with equal probability. For s > 1 the result lies in [1/s, s];
  for s < 1 it lies in [s, 1/s].

  Uses NumPy's global RNG, so it is only re-drawn when actually executed in
  Python (not per element inside a traced tf.data map function).
  """
  scale = np.random.uniform(low=1, high=s)
  # Fair coin flip. randint(2) replaces the original 2**63-1 upper bound,
  # which overflows the default integer dtype where a C long is 32-bit
  # (e.g. Windows with NumPy < 2) and raised ValueError there.
  if np.random.randint(2):
    return scale
  return 1 / scale


def preprocessing_selection(choice):
  """Return the training-time preprocessing function for one training stage.

  Args:
    choice: case-insensitive stage name: "classification", "priming" or
      "detection".

  Returns:
    A function ``(image, label) -> (image, label)`` usable with
    ``tf.data.Dataset.map`` or called eagerly on a single image tensor.
    Images come out as float32 scaled by 1/255.

  Raises:
    NameError: if ``choice`` is not a string or names an unknown stage
      (kept as NameError for backward compatibility with existing callers).

  Per-image random parameters are drawn with ``tf.random`` ops: NumPy random
  calls inside a function given to ``Dataset.map`` execute only once, while
  the function is traced, which would give every image the same augmentation.
  """

  def _rand_scale(s):
    # Graph-safe analogue of rand_scale(): U(1, s) or its reciprocal, 50/50.
    scale = tf.random.uniform([], minval=min(1.0, s), maxval=max(1.0, s))
    return tf.where(tf.random.uniform([]) < 0.5, scale, 1.0 / scale)

  def _match_size(image, reference):
    # tf.py_function output has no static shape: restore rank/channels, then
    # centre-crop or zero-pad back to the reference height and width.
    image.set_shape([None, None, reference.shape[-1]])
    ref_dims = tf.shape(reference)
    return tf.image.resize_with_crop_or_pad(image, ref_dims[0], ref_dims[1])

  def _is_multichannel(image):
    # Hue/saturation adjustments are only applied to multi-channel images.
    return image.shape[-1] is not None and image.shape[-1] > 1

  def _apply_exposure(image, factor):
    # Multiplicative exposure jitter on the 0-255 scale, clipped back into range.
    # (An additive brightness delta of ~1 on this scale would be a near no-op.)
    return tf.clip_by_value(image * factor, 0.0, 255.0)

  def train_preprocess_image_classification(image, label):
    # 224x224 classification augmentation: aspect, zoom, rotation, hue,
    # saturation and exposure jitter.
    low = 128
    high = 448
    image = tf.cast(image, tf.float32)
    deg = tf.random.uniform([], minval=-7.0, maxval=7.0)
    aspect = _rand_scale(0.75)
    scale = tf.cast(tf.random.uniform([], minval=low, maxval=high, dtype=tf.int32), tf.float32) / 224.0
    # Stretch the width by 1/aspect; tf.image.resize takes (new_height, new_width).
    dims = tf.shape(image)
    new_width = tf.maximum(1, tf.cast(tf.cast(dims[1], tf.float32) / aspect, tf.int32))
    image = tf.image.resize(image, tf.stack([dims[0], new_width]))
    reference = image
    image = _match_size(tf.py_function(zoom, [image, scale], tf.float32), reference)
    image = _match_size(tf.py_function(rotate, [image, deg], tf.float32), reference)
    image = tf.image.resize_with_pad(image, target_width=224, target_height=224)
    if _is_multichannel(image):
      image = tf.image.adjust_hue(image, tf.random.uniform([], minval=-0.1, maxval=0.1))
      image = tf.image.adjust_saturation(image, _rand_scale(0.75))
    image = _apply_exposure(image, _rand_scale(0.75))
    image = image / 255
    return image, label

  def train_preprocess_priming_classification(image, label):
    # 448x448 fine-tuning ("priming") preprocessing: random zoom only.
    low = 448
    high = 512
    image = tf.cast(image, tf.float32)
    scale = tf.cast(tf.random.uniform([], minval=low, maxval=high, dtype=tf.int32), tf.float32) / 448.0
    reference = image
    image = _match_size(tf.py_function(zoom, [image, scale], tf.float32), reference)
    image = tf.image.resize_with_pad(image, target_width=448, target_height=448)
    image = image / 255
    return image, label

  def train_preprocess_object_detection(image, label):
    # Square resize to a multiple of 32 in [320, 576], plus colour jitter.
    image = tf.cast(image, tf.float32)
    # Deliberately a NumPy draw: under Dataset.map it happens once at trace
    # time, so all elements of one pipeline share a size and stay batchable.
    resize_num = int(np.random.randint(low=10, high=19)) * 32
    image = tf.image.resize_with_pad(image, target_width=resize_num, target_height=resize_num)
    if _is_multichannel(image):
      image = tf.image.adjust_hue(image, tf.random.uniform([], minval=-0.1, maxval=0.1))
      image = tf.image.adjust_saturation(image, _rand_scale(1.5))
    image = _apply_exposure(image, _rand_scale(1.5))
    image = image / 255
    # NOTE(review): label is passed through unchanged; if it holds box
    # coordinates they are not adjusted for resize_with_pad — confirm.
    return image, label

  preprocessors = {
      "detection": train_preprocess_object_detection,
      "classification": train_preprocess_image_classification,
      "priming": train_preprocess_priming_classification,
  }
  try:
    key = choice.lower()
  except AttributeError:
    raise NameError('Invalid Input') from None
  if key not in preprocessors:
    # The original built this exception without raising it and returned None.
    raise NameError('Invalid Input')
  return preprocessors[key]
  

In [None]:
def preprocessing(training_dataset, data_augmentation_split, type_of_preprocessing, size_of_dataset):
  """Augment the first part of a dataset and only normalise the rest.

  Args:
    training_dataset: tf.data.Dataset of (image, label) pairs. Load it with
      ``as_supervised=True`` so map functions receive two arguments.
    data_augmentation_split: percentage (0-100) of the first
      ``size_of_dataset`` elements that go through augmentation.
    type_of_preprocessing: stage name passed to preprocessing_selection
      ("classification", "priming" or "detection").
    size_of_dataset: number of elements in ``training_dataset``, e.g.
      ``int(Info.splits['train'].num_examples)`` from ``tfds.load(...,
      with_info=True)``.

  Returns:
    A dataset yielding the augmented elements followed by the remaining,
    only-normalised elements (both float32, scaled by 1/255).

  Development notes:
    * The split is made on the Dataset object itself with take/skip instead
      of loading separate tfds splits.
    * zoom/rotate only work on NumPy arrays, so the augmentation functions
      call them through tf.py_function. tf.keras.preprocessing.image.random_rotate
      did not work on tensors; tfa.image.rotate does.
    * The non-augmented part must be cast/normalised the same way, otherwise
      the dtypes differ and concatenate() fails.
  """
  augment_count = int((data_augmentation_split / 100) * size_of_dataset)
  # Derive the remainder from the total so truncating both parts separately
  # can never drop an element.
  plain_count = size_of_dataset - augment_count

  augment_part = training_dataset.take(augment_count)
  plain_part = training_dataset.skip(augment_count).take(plain_count)

  def normalize(image, label):
    # Same dtype and scaling as the augmentation output, so the parts concatenate.
    image = tf.cast(image, tf.float32)
    image = image / 255
    return image, label

  preprocessing_function = preprocessing_selection(type_of_preprocessing)
  augment_part = augment_part.map(preprocessing_function, num_parallel_calls=4)
  plain_part = plain_part.map(normalize, num_parallel_calls=4)
  # Dataset.concatenate returns a NEW dataset; the original discarded it and
  # returned only the augmented subset.
  # NOTE(review): augmented images are resized (e.g. 224x224 for
  # "classification") while normalised ones keep their original size, so the
  # result cannot be batched as-is — confirm the intended handling.
  return augment_part.concatenate(plain_part)

# Load MNIST as (image, label) pairs; as_supervised=True is required so the
# preprocessing map functions receive two arguments.
Dataset, Info = tfds.load('mnist', split=['train', 'test'], with_info=True, as_supervised=True)
Train = Dataset[0]
Test = Dataset[1]
Size = int(Info.splits['train'].num_examples)
# Augment 5% of the training set with the "classification" pipeline.
train_ds = preprocessing(Train, 5, "classification", Size)
train_ds_sample = train_ds.take(1)

# Show the single sampled example; MNIST is single-channel, so drop the
# channel axis for imshow.
for x, y in train_ds_sample:
  x = tf.reshape(x, [x.shape[0], x.shape[1]])
  plt.imshow(x, cmap='gray')
  plt.title(f"label: {int(y)}")
  plt.show()

# TODO: batch/prefetch once element shapes are uniform (batch_size is not defined yet).
#train_ds = train_ds.batch(batch_size)
#train_ds = train_ds.prefetch(1)

In [None]:
# Testing: visual check of the "classification" augmentation on one local RGB image.
IMAGE_PATH = 'robertducky5.jpg'  # relative to the notebook's working directory
image_string = tf.io.read_file(IMAGE_PATH)
image_decoded = tf.image.decode_jpeg(image_string, channels=3)
image = tf.cast(image_decoded, tf.float32)
print(image.shape)
label = 'A'
preprocessing_function = preprocessing_selection("classification")
img, lb = preprocessing_function(image, label)
# Spline interpolation and colour jitter may push values slightly outside
# [0, 1]; clip explicitly so imshow does not warn and clip on its own.
plt.imshow(tf.clip_by_value(img, 0.0, 1.0))
plt.title(f"label: {lb}")
plt.show()

In [None]:
# NOTE: the two string literals below are inert reference snippets (never
# executed). The first shows how a map function is wired into an input
# pipeline (map -> batch -> prefetch); the second is a simpler augmentation
# that flips, jitters brightness/saturation and clips to [0, 1].
'''
dataset = dataset.map(train_preprocess, num_parallel_calls=4)
dataset = dataset.batch(batch_size)
dataset = dataset.prefetch(1)
'''
'''
def train_preprocess(image, label):
    image = tf.image.random_flip_left_right(image)

    image = tf.image.random_brightness(image, max_delta=32.0 / 255.0)
    image = tf.image.random_saturation(image, lower=0.5, upper=1.5)

    #Make sure the image is still in [0, 1]
    image = tf.clip_by_value(image, 0.0, 1.0)

    return image, label
'''