Skip to content

Commit

Permalink
updates to support batch_size != 64
Browse files Browse the repository at this point in the history
The code does not currently support sample_size != batch_size,
so dropped sample_size as a paramater to the model constructor.

Also to suppor this, save_images was updated to clip the number
of images saved at rows * cols.

Also - the validation inputs are now also saved at their native
size. This file is called inputs_small.png.
  • Loading branch information
dribnet committed Oct 8, 2016
1 parent 0852d4b commit 14ff4eb
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
5 changes: 3 additions & 2 deletions model.py
Expand Up @@ -17,7 +17,7 @@ def doresize(x, shape):

class DCGAN(object):
def __init__(self, sess, image_size=128, is_crop=True,
batch_size=64, sample_size = 64, image_shape=[128, 128, 3],
batch_size=64, image_shape=[128, 128, 3],
y_dim=None, z_dim=100, gf_dim=64, df_dim=64,
gfc_dim=1024, dfc_dim=1024, c_dim=3, dataset_name='default',
checkpoint_dir=None):
Expand All @@ -39,7 +39,7 @@ def __init__(self, sess, image_size=128, is_crop=True,
self.batch_size = batch_size
self.image_size = image_size
self.input_size = 32
self.sample_size = sample_size
self.sample_size = batch_size
self.image_shape = image_shape

self.y_dim = y_dim
Expand Down Expand Up @@ -105,6 +105,7 @@ def train(self, config):
sample_images = np.array(sample).astype(np.float32)
sample_input_images = np.array(sample_inputs).astype(np.float32)

save_images(sample_input_images, [8, 8], './samples/inputs_small.png')
save_images(sample_images, [8, 8], './samples/reference.png')

counter = 1
Expand Down
3 changes: 2 additions & 1 deletion utils.py 100644 → 100755
Expand Up @@ -18,7 +18,8 @@ def get_image(image_path, image_size, is_crop=True):
return transform(imread(image_path), image_size, is_crop)

def save_images(images, size, image_path):
return imsave(inverse_transform(images), size, image_path)
num_im = size[0] * size[1]
return imsave(inverse_transform(images[:num_im]), size, image_path)

def imread(path):
return scipy.misc.imread(path).astype(np.float)
Expand Down

0 comments on commit 14ff4eb

Please sign in to comment.