Skip to content
Browse files

Removed the inference pipeline; it is no longer needed.

  • Loading branch information
HasnainRaz committed Oct 28, 2019
1 parent eeb9183 commit 9e269e39a6b843787a627786db3becaccab6f7ec
Showing with 0 additions and 72 deletions.
  1. +0 −72
@@ -121,75 +121,3 @@ def dataset(self, batch_size, threads=4):
dataset = dataset.shuffle(30).batch(batch_size, drop_remainder=True).prefetch(

return dataset

class DataLoaderInfer(object):
    """Data loader for the SR GAN that prepares a tf.data object for inference.

    Unlike the training loader, this loader reads whole images from a single
    directory, rescales them to the generator's tanh input range, and batches
    them without any cropping or low-res/high-res pairing.
    """

    def __init__(self, image_dir):
        """Initializes the dataloader.

        Args:
            image_dir: The path to the directory containing high resolution
                images. Every entry in the directory is assumed to be an
                image file (no filtering is applied).
        """
        # Collect the full path of every file in the directory.
        self.image_paths = [os.path.join(image_dir, x) for x in os.listdir(image_dir)]

    def _parse_image(self, image_path):
        """Loads and decodes the image at the given path.

        Args:
            image_path: Path to an image file.

        Returns:
            image: A float32 tf tensor of the loaded image, values in [0, 1].
        """
        # NOTE(review): the original right-hand side was lost in extraction;
        # tf.io.read_file is the standard way to feed decode_png — confirm.
        image = tf.io.read_file(image_path)
        image = tf.image.decode_png(image, channels=3)
        # convert_image_dtype also scales integer pixel values into [0, 1].
        image = tf.image.convert_image_dtype(image, tf.float32)

        return image

    def _rescale(self, low_res):
        """Rescales pixel values from [0, 1] to the [-1, 1] range.

        For use with the generator output tanh function.

        Args:
            low_res: The tf tensor of the low res image, values in [0, 1].

        Returns:
            low_res: The tf tensor of the low res image, rescaled to [-1, 1].
        """
        low_res = low_res * 2.0 - 1.0

        return low_res

    def dataset(self, batch_size):
        """Returns a tf dataset object with the inference mappings applied.

        Args:
            batch_size: Int, the number of elements in a batch returned by
                the dataset.

        Returns:
            dataset: A tf dataset object yielding batches of rescaled images.
        """
        # Generate tf dataset from the collected image paths.
        # NOTE(review): the mapped expressions below were lost in extraction
        # and reconstructed from the surrounding comments — confirm against
        # the training loader in this file.
        dataset = tf.data.Dataset.from_tensor_slices(self.image_paths)

        # Read the images.
        dataset = dataset.map(
            self._parse_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)

        # Rescale the values in the input to the [-1, 1] range.
        dataset = dataset.map(
            self._rescale, num_parallel_calls=tf.data.experimental.AUTOTUNE)

        dataset = dataset.shuffle(10)
        # Batch the input, drop remainder to get a defined batch size.
        # Prefetch the data for optimal GPU utilization.
        dataset = dataset.batch(batch_size, drop_remainder=True).prefetch(
            tf.data.experimental.AUTOTUNE)

        return dataset

0 comments on commit 9e269e3

Please sign in to comment.
You can’t perform that action at this time.