Commit
GAN dataset modif
ColasGael committed Nov 26, 2018
1 parent 09d4c93 commit 89e27cf
Showing 11 changed files with 136 additions and 592 deletions.
37 changes: 25 additions & 12 deletions Colorizing-with-GANs/dataset.py
@@ -7,7 +7,7 @@

CIFAR10_DATASET = 'cifar10'
PLACES365_DATASET = 'places365'

MOMENTSINTIME_DATASET = 'Moments_in_Time_Mini'

class BaseDataset():
def __init__(self, name, path, training=True, augment=True):
@@ -34,6 +34,7 @@ def __iter__(self):
def __getitem__(self, index):
val = self.data[index]
try:
# OLD : img = imread(val) if isinstance(val, str) else val
img = np.load(val) if isinstance(val, str) else val

if self.augment and np.random.binomial(1, 0.5) == 1:
@@ -44,7 +45,7 @@ def __getitem__(self, index):

return img

def generator(self, batch_size, recusrive=False):
def generator(self, batch_size, recursive=False):
start = 0
total = len(self)

@@ -61,7 +62,7 @@ def generator(self, batch_size, recusrive=False):
start = end
yield np.array(items)

if recusrive:
if recursive:
start = 0

else:
@@ -116,15 +117,27 @@ def __init__(self, path, training=True, augment=True):

def load(self):
if self.training:
#data = np.array(
# glob.glob(self.path + '/data_256/**/*.jpg', recursive=True))
#data = np.array(
#glob.glob("C:\\Users\\rafae\\Desktop\\Classes\\CS 230 - Stanford\\Colorizing-with-GANs\\dataset\\places365\\data_256\*.jpg"))
#data = np.array(glob.glob("C:\\Users\\rafae\\Desktop\\Classes\\CS 230 - Stanford\\frames\\*"))
data = np.array(glob.glob("/home/ubuntu/Automatic-Video-Colorization/data/Moments_in_Time_Mini/training/frames2/*"))
data = np.array(
glob.glob(self.path + '/data_256/**/*.jpg', recursive=True))

else:
#data = np.array(glob.glob(self.path + '/val_256/*.jpg'))
#data = np.array(glob.glob("C:\\Users\\rafae\\Desktop\\Classes\\CS 230 - Stanford\\Colorizing-with-GANs\\dataset\\places365\\val_256\*.jpg"))
data = np.array(glob.glob("/home/ubuntu/Automatic-Video-Colorization/data/Moments_in_Time_Mini/training/frames/*"))
data = np.array(glob.glob(self.path + '/val_256/*.jpg'))

return data


class MomentsInTimeDataset(BaseDataset):
def __init__(self, path, training=True, augment=True):
super(MomentsInTimeDataset, self).__init__(MOMENTSINTIME_DATASET, path, training, augment)

def load(self):
if self.training:
data = np.array(
glob.glob(self.path + '/training/frames2/*', recursive=True))

#data = np.array(glob.glob("/home/ubuntu/Automatic-Video-Colorization/data/Moments_in_Time_Mini/training/frames2/*"))
else:
data = np.array(glob.glob(self.path + '/training/frames1/*', recursive=True))
#data = np.array(glob.glob("/home/ubuntu/Automatic-Video-Colorization/data/Moments_in_Time_Mini/training/frames/*"))

return data
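
With this commit, BaseDataset.__getitem__ loads each sample with np.load instead of imread, and MomentsInTimeDataset globs one file per sample from training/frames2 (training) or training/frames1 (validation). Below is a minimal sketch of how such files could be produced, assuming each .npy file stacks two consecutive RGB frames into an (H, W, 6) array so the model can split channels 0:3 (current frame) and 3:6 (previous frame); imageio, the directory layout, and the helper names are assumptions, not part of this commit.

import os
import numpy as np
from imageio import imread  # assumption: source frames are stored as ordinary image files

def save_frame_pair(curr_path, prev_path, out_path):
    # Stack the current and previous RGB frames into one (H, W, 6) array so that
    # BaseDataset.__getitem__ can np.load() it and the model can split channels
    # 0:3 (current) and 3:6 (previous).
    curr = imread(curr_path)  # (H, W, 3)
    prev = imread(prev_path)  # (H, W, 3)
    np.save(out_path, np.concatenate([curr, prev], axis=-1))

def build_pairs(frame_dir, out_dir):
    # Pair each frame with its predecessor; file names are assumed to sort in temporal order.
    frames = sorted(os.listdir(frame_dir))
    os.makedirs(out_dir, exist_ok=True)
    for prev_name, curr_name in zip(frames, frames[1:]):
        out_name = os.path.splitext(curr_name)[0] + '.npy'
        save_frame_pair(os.path.join(frame_dir, curr_name),
                        os.path.join(frame_dir, prev_name),
                        os.path.join(out_dir, out_name))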
130 changes: 0 additions & 130 deletions Colorizing-with-GANs/dataset.py~

This file was deleted.

5 changes: 4 additions & 1 deletion Colorizing-with-GANs/main.py
@@ -4,7 +4,7 @@
import tensorflow as tf
from options import ModelOptions
from models import Cifar10Model, Places365Model
from dataset import CIFAR10_DATASET, PLACES365_DATASET
from dataset import CIFAR10_DATASET, PLACES365_DATASET, MOMENTSINTIME_DATASET


def main(options):
@@ -27,6 +27,9 @@ def main(options):

elif options.dataset == PLACES365_DATASET:
model = Places365Model(sess, options)

elif options.dataset == MOMENTSINTIME_DATASET:
model = MomentsInTimeModel(sess, options)

if not os.path.exists(options.checkpoints_path):
os.makedirs(options.checkpoints_path)
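
For reference, a minimal sketch (not the repository's code) of how the dispatch extended by this hunk could be written; it assumes MomentsInTimeModel is also importable from models (the hunk above only extends the dataset import), and a dict lookup fails loudly on an unknown dataset name instead of leaving model undefined.

from dataset import CIFAR10_DATASET, PLACES365_DATASET, MOMENTSINTIME_DATASET
from models import Cifar10Model, Places365Model, MomentsInTimeModel  # assumed import

MODEL_BY_DATASET = {
    CIFAR10_DATASET: Cifar10Model,
    PLACES365_DATASET: Places365Model,
    MOMENTSINTIME_DATASET: MomentsInTimeModel,
}

def build_model(sess, options):
    # Equivalent to the if/elif chain in main(), but raises on a typo in
    # options.dataset rather than falling through silently.
    try:
        return MODEL_BY_DATASET[options.dataset](sess, options)
    except KeyError:
        raise ValueError('unknown dataset: %r' % options.dataset)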
74 changes: 62 additions & 12 deletions Colorizing-with-GANs/models.py
@@ -43,14 +43,13 @@ def train(self):

self.epoch = epoch + 1
self.iteration = 0


generator = self.dataset_train.generator(self.options.batch_size)
progbar = keras.utils.Progbar(total, stateful_metrics=['epoch', 'iteration', 'step'])

for input_rgb in generator:

feed_dic = {self.input_rgb: input_rgb[:,:,:,0:3], self.input_rgb_t1: input_rgb[:,:,:,3:6]}
# OLD : feed_dict = {self.input_rgb: input_rgb}
feed_dic = {self.input_rgb: input_rgb[:,:,:,0:3], self.input_rgb_prev: input_rgb[:,:,:,3:6]}

self.iteration = self.iteration + 1
self.sess.run([self.dis_train], feed_dict=feed_dic)
@@ -104,8 +103,8 @@ def evaluate(self):
result = []

for input_rgb in test_generator:
#feed_dic = {self.input_rgb: input_rgb}
feed_dic = {self.input_rgb: input_rgb[:,:,:,0:3], self.input_rgb_t1: input_rgb[:,:,:,3:6]}
# OLD : feed_dic = {self.input_rgb: input_rgb}
feed_dic = {self.input_rgb: input_rgb[:,:,:,0:3], self.input_rgb_prev: input_rgb[:,:,:,3:6]}

self.sess.run([self.dis_loss, self.gen_loss, self.accuracy], feed_dict=feed_dic)

@@ -129,12 +128,15 @@ def sample(self, show=True):
self.build()

input_rgb = next(self.sample_generator)
feed_dic = {self.input_rgb: input_rgb[:,:,:,0:3], self.input_rgb_t1: input_rgb[:,:,:,3:6]}
#feed_dic = {self.input_rgb: input_rgb}

# OLD : feed_dic = {self.input_rgb: input_rgb}
feed_dic = {self.input_rgb: input_rgb[:,:,:,0:3], self.input_rgb_prev: input_rgb[:,:,:,3:6]}

step, rate = self.sess.run([self.global_step, self.learning_rate])
fake_image, input_gray = self.sess.run([self.sampler, self.input_gray], feed_dict=feed_dic)
fake_image = postprocess(tf.convert_to_tensor(fake_image), colorspace_in=self.options.color_space, colorspace_out=COLORSPACE_RGB)

# OLD : img = stitch_images(input_gray, input_rgb, fake_image.eval())
img = stitch_images(input_gray, input_rgb[:,:,:,3:6], fake_image.eval())

if not os.path.exists(self.samples_dir):
@@ -156,8 +158,9 @@ def turing_test(self):

while count < self.options.test_size:
input_rgb = next(gen)
#feed_dic = {self.input_rgb: input_rgb}
feed_dic = {self.input_rgb: input_rgb[:,:,:,0:3], self.input_rgb_t1: input_rgb[:,:,:,3:6]}

# OLD : feed_dic = {self.input_rgb: input_rgb}
feed_dic = {self.input_rgb: input_rgb[:,:,:,0:3], self.input_rgb_prev: input_rgb[:,:,:,3:6]}
fake_image = self.sess.run(self.sampler, feed_dict=feed_dic)
fake_image = postprocess(tf.convert_to_tensor(fake_image), colorspace_in=self.options.color_space, colorspace_out=COLORSPACE_RGB)

@@ -180,11 +183,11 @@ def build(self):
kernel = self.options.kernel_size

self.input_rgb = tf.placeholder(tf.float32, shape=(None, None, None, 3), name='input_rgb')
self.input_rgb_t1 = tf.placeholder(tf.float32, shape=(None, None, None, 3), name='input_rgb_t1')
self.input_rgb_prev = tf.placeholder(tf.float32, shape=(None, None, None, 3), name='input_rgb_prev')

self.input_gray = tf.image.rgb_to_grayscale(self.input_rgb)
self.input_color = preprocess(self.input_rgb, colorspace_in=COLORSPACE_RGB, colorspace_out=self.options.color_space)
self.input_color_t1 = preprocess(self.input_rgb_t1, colorspace_in=COLORSPACE_RGB, colorspace_out=self.options.color_space)
self.input_color_t1 = preprocess(self.input_rgb_prev, colorspace_in=COLORSPACE_RGB, colorspace_out=self.options.color_space)

gen = gen_factory.create(tf.concat([self.input_gray, self.input_color_t1],3), kernel, seed)
dis_real = dis_factory.create(tf.concat([self.input_color, self.input_color_t1], 3), kernel, seed)
@@ -199,7 +202,7 @@ def build(self):
self.dis_loss = tf.reduce_mean(dis_real_ce + dis_fake_ce)

self.gen_loss_gan = tf.reduce_mean(gen_ce)
#self.gen_loss_l1 = tf.reduce_mean(tf.abs(self.input_color - gen)) * self.options.l1_weight
# OLD : self.gen_loss_l1 = tf.reduce_mean(tf.abs(self.input_color - gen)) * self.options.l1_weight
self.gen_loss_l1 = tf.reduce_mean(tf.abs(self.input_gray - tf.image.rgb_to_grayscale(gen))) * self.options.l1_weight

self.gen_loss = self.gen_loss_l1 #self.gen_loss_gan + self.gen_loss_l1
@@ -360,3 +363,50 @@ def create_dataset(self, training=True):
path=self.options.dataset_path,
training=training,
augment=self.options.augment)

class MomentsInTimeModel(BaseModel):
def __init__(self, sess, options):
super(MomentsInTimeModel, self).__init__(sess, options)

def create_generator(self):
kernels_gen_encoder = [
(64, 1, 0), # [batch, 256, 256, ch] => [batch, 256, 256, 64]
(64, 2, 0), # [batch, 256, 256, 64] => [batch, 128, 128, 64]
(128, 2, 0), # [batch, 128, 128, 64] => [batch, 64, 64, 128]
(256, 2, 0), # [batch, 64, 64, 128] => [batch, 32, 32, 256]
(512, 2, 0), # [batch, 32, 32, 256] => [batch, 16, 16, 512]
(512, 2, 0), # [batch, 16, 16, 512] => [batch, 8, 8, 512]
(512, 2, 0), # [batch, 8, 8, 512] => [batch, 4, 4, 512]
(512, 2, 0) # [batch, 4, 4, 512] => [batch, 2, 2, 512]
]

kernels_gen_decoder = [
(512, 2, 0.5), # [batch, 2, 2, 512] => [batch, 4, 4, 512]
(512, 2, 0.5), # [batch, 4, 4, 512] => [batch, 8, 8, 512]
(512, 2, 0.5), # [batch, 8, 8, 512] => [batch, 16, 16, 512]
(256, 2, 0), # [batch, 16, 16, 512] => [batch, 32, 32, 256]
(128, 2, 0), # [batch, 32, 32, 256] => [batch, 64, 64, 128]
(64, 2, 0), # [batch, 64, 64, 128] => [batch, 128, 128, 64]
(64, 2, 0) # [batch, 128, 128, 64] => [batch, 256, 256, 64]
]

return Generator('gen', kernels_gen_encoder, kernels_gen_decoder)

def create_discriminator(self):
kernels_dis = [
(64, 2, 0), # [batch, 256, 256, ch] => [batch, 128, 128, 64]
(128, 2, 0), # [batch, 128, 128, 64] => [batch, 64, 64, 128]
(256, 2, 0), # [batch, 64, 64, 128] => [batch, 32, 32, 256]
(512, 2, 0), # [batch, 32, 32, 256] => [batch, 16, 16, 512]
(512, 2, 0), # [batch, 16, 16, 512] => [batch, 8, 8, 512]
(512, 2, 0), # [batch, 8, 8, 512] => [batch, 4, 4, 512]
(512, 1, 0), # [batch, 4, 4, 512] => [batch, 4, 4, 512]
]

return Discriminator('dis', kernels_dis)

def create_dataset(self, training=True):
return MomentsInTimeDataset(
path=self.options.dataset_path,
training=training,
augment=self.options.augment)
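
To close the loop, a short usage sketch of the new dataset class feeding the placeholders defined in build(); the path below is a placeholder, and the channel split mirrors the feed_dic construction in train() and evaluate().

from dataset import MomentsInTimeDataset

# Assumed directory layout: <path>/training/frames2/*.npy with (H, W, 6) frame pairs.
dataset = MomentsInTimeDataset(path='path/to/Moments_in_Time_Mini', training=True)
batch = next(dataset.generator(batch_size=4))   # shape (4, H, W, 6): two stacked RGB frames
current_frames = batch[:, :, :, 0:3]            # fed as self.input_rgb
previous_frames = batch[:, :, :, 3:6]           # fed as self.input_rgb_prev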
