Skip to content
Permalink
Browse files

Removed infer pipeline, not needed

  • Loading branch information
HasnainRaz committed Oct 28, 2019
1 parent eeb9183 commit 9e269e39a6b843787a627786db3becaccab6f7ec
Showing with 0 additions and 72 deletions.
  1. +0 −72 dataloader.py
@@ -121,75 +121,3 @@ def dataset(self, batch_size, threads=4):
dataset = dataset.shuffle(30).batch(batch_size, drop_remainder=True).prefetch(tf.data.experimental.AUTOTUNE)

return dataset


class DataLoaderInfer(object):
"""Data Loader for the SR GAN, that prepares a tf data object for training."""

def __init__(self, image_dir):
"""
Initializes the dataloader.
Args:
image_dir: The path to the directory containing high resolution images.
hr_image_size: Integer, the crop size of the images to train on (High
resolution images will be cropped to this width and height).
Returns:
The dataloader object.
"""
self.image_paths = [os.path.join(image_dir, x) for x in os.listdir(image_dir)]

def _parse_image(self, image_path):
"""
Function that loads the images given the path.
Args:
image_path: Path to an image file.
Returns:
image: A tf tensor of the loaded image.
"""

image = tf.io.read_file(image_path)
image = tf.image.decode_png(image, channels=3)
image = tf.image.convert_image_dtype(image, tf.float32)

return image

def _rescale(self, low_res):
"""
Function that rescales the pixel values to the -1 to 1 range.
For use with the generator output tanh function.
Args:
low_res: The tf tensor of the low res image.
high_res: The tf tensor of the high res image.
Returns:
low_res: The tf tensor of the low res image, rescaled.
high_res: the tf tensor of the high res image, rescaled.
"""
low_res = low_res * 2.0 - 1.0

return low_res

def dataset(self, batch_size):
"""
Returns a tf dataset object with specified mappings.
Args:
batch_size: Int, The number of elements in a batch returned by the dataset.
threads: Int, CPU threads to use for multi-threaded operation.
Returns:
dataset: A tf dataset object.
"""

# Generate tf dataset from high res image paths.
dataset = tf.data.Dataset.from_tensor_slices(self.image_paths)

# Read the images
dataset = dataset.map(self._parse_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)

# Rescale the values in the input
dataset = dataset.map(self._rescale, num_parallel_calls=tf.data.experimental.AUTOTUNE)

dataset = dataset.shuffle(10)
# Batch the input, drop remainder to get a defined batch size.
# Prefetch the data for optimal GPU utilization.
dataset = dataset.batch(batch_size, drop_remainder=True).prefetch(tf.data.experimental.AUTOTUNE)

return dataset

0 comments on commit 9e269e3

Please sign in to comment.
You can’t perform that action at this time.