train_ali: added cli for z-dim, num-epochs, and splits
dribnet committed Aug 14, 2016
1 parent 2fa2727 commit 67f2c7d
Showing 1 changed file with 29 additions and 21 deletions.
experiments/train_ali.py: 29 additions & 21 deletions
@@ -27,27 +27,24 @@
 from utils.samplecheckpoint import SampleCheckpoint
 from utils.fuel_helper import create_custom_streams
 
-NUM_EPOCHS = 123
-IMAGE_SIZE = (64, 64)
 NUM_CHANNELS = 3
-NLAT = 256
 GAUSSIAN_INIT = IsotropicGaussian(std=0.01)
 ZERO_INIT = Constant(0)
 LEARNING_RATE = 1e-4
 BETA1 = 0.5
 LEAK = 0.02
 
 
-def create_model_brick(model_stream):
+def create_model_brick(model_stream, image_size, z_dim):
     layers = [
         conv_brick(2, 1, 64), bn_brick(), LeakyRectifier(leak=LEAK),
         conv_brick(7, 2, 128), bn_brick(), LeakyRectifier(leak=LEAK),
         conv_brick(5, 2, 256), bn_brick(), LeakyRectifier(leak=LEAK),
         conv_brick(7, 2, 256), bn_brick(), LeakyRectifier(leak=LEAK),
         conv_brick(4, 1, 512), bn_brick(), LeakyRectifier(leak=LEAK),
-        conv_brick(1, 1, 2 * NLAT)]
+        conv_brick(1, 1, 2 * z_dim)]
     encoder_mapping = ConvolutionalSequence(
-        layers=layers, num_channels=NUM_CHANNELS, image_size=IMAGE_SIZE,
+        layers=layers, num_channels=NUM_CHANNELS, image_size=(image_size, image_size),
         use_bias=False, name='encoder_mapping')
     encoder = GaussianConditional(encoder_mapping, name='encoder')
 
@@ -59,7 +56,7 @@ def create_model_brick(model_stream):
         conv_transpose_brick(2, 1, 64), bn_brick(), LeakyRectifier(leak=LEAK),
         conv_brick(1, 1, NUM_CHANNELS), Logistic()]
     decoder_mapping = ConvolutionalSequence(
-        layers=layers, num_channels=NLAT, image_size=(1, 1), use_bias=False,
+        layers=layers, num_channels=z_dim, image_size=(1, 1), use_bias=False,
         name='decoder_mapping')
     decoder = DeterministicConditional(decoder_mapping, name='decoder')
 
@@ -70,15 +67,15 @@ def create_model_brick(model_stream):
         conv_brick(7, 2, 256), bn_brick(), LeakyRectifier(leak=LEAK),
         conv_brick(4, 1, 512), bn_brick(), LeakyRectifier(leak=LEAK)]
     x_discriminator = ConvolutionalSequence(
-        layers=layers, num_channels=NUM_CHANNELS, image_size=IMAGE_SIZE,
+        layers=layers, num_channels=NUM_CHANNELS, image_size=(image_size, image_size),
         use_bias=False, name='x_discriminator')
     x_discriminator.push_allocation_config()
 
     layers = [
         conv_brick(1, 1, 1024), LeakyRectifier(leak=LEAK),
         conv_brick(1, 1, 1024), LeakyRectifier(leak=LEAK)]
     z_discriminator = ConvolutionalSequence(
-        layers=layers, num_channels=NLAT, image_size=(1, 1), use_bias=False,
+        layers=layers, num_channels=z_dim, image_size=(1, 1), use_bias=False,
         name='z_discriminator')
     z_discriminator.push_allocation_config()
 
@@ -115,10 +112,10 @@ def create_model_brick(model_stream):
     return ali
 
 
-def create_models(model_stream):
-    ali = create_model_brick(model_stream)
+def create_models(model_stream, image_size, z_dim):
+    ali = create_model_brick(model_stream, image_size, z_dim)
     x = tensor.tensor4('features')
-    z = ali.theano_rng.normal(size=(x.shape[0], NLAT, 1, 1))
+    z = ali.theano_rng.normal(size=(x.shape[0], z_dim, 1, 1))
 
     def _create_model(with_dropout):
         cg = ComputationGraph(ali.compute_losses(x, z))
@@ -143,8 +140,9 @@ def _create_model(with_dropout):
     return model, bn_model, bn_updates
 
 
-def create_main_loop(save_path, subdir, dataset, color_convert,
-                     batch_size, monitor_every, checkpoint_every, image_size):
+def create_main_loop(save_path, subdir, dataset, splits, color_convert,
+                     batch_size, monitor_every, checkpoint_every, num_epochs,
+                     image_size, z_dim):
 
     if dataset is None:
         streams = create_celeba_data_streams(batch_size, batch_size)
@@ -154,16 +152,18 @@ def create_main_loop(save_path, subdir, dataset, color_convert,
                                         training_batch_size=batch_size,
                                         monitoring_batch_size=batch_size,
                                         include_targets=False,
-                                        color_convert=color_convert)
+                                        color_convert=color_convert,
+                                        split_names=splits)
         model_stream = create_custom_streams(filename=dataset,
                                              training_batch_size=500,
                                              monitoring_batch_size=500,
                                              include_targets=False,
-                                             color_convert=color_convert)[0]
+                                             color_convert=color_convert,
+                                             split_names=splits)[0]
 
     main_loop_stream, train_monitor_stream, valid_monitor_stream = streams[:3]
 
-    model, bn_model, bn_updates = create_models(model_stream)
+    model, bn_model, bn_updates = create_models(model_stream, image_size, z_dim)
     ali, = bn_model.top_bricks
     discriminator_loss, generator_loss = bn_model.outputs
 
@@ -181,7 +181,7 @@ def create_main_loop(save_path, subdir, dataset, color_convert,
                                              model.outputs)
     extensions = [
         Timing(),
-        FinishAfter(after_n_epochs=NUM_EPOCHS),
+        FinishAfter(after_n_epochs=num_epochs),
         DataStreamMonitoring(
             bn_monitored_variables, train_monitor_stream, prefix="train",
             updates=bn_updates, before_first_epoch=True,
@@ -192,7 +192,7 @@ def create_main_loop(save_path, subdir, dataset, color_convert,
         Checkpoint(save_path, every_n_epochs=checkpoint_every,
                    before_training=True, after_epoch=True, after_training=True,
                    use_cpickle=True),
-        SampleCheckpoint(interface=AliModel, z_dim=NLAT, image_size=IMAGE_SIZE, channels=NUM_CHANNELS, dataset="celeba_64", split="valid", save_subdir=subdir, before_training=True, after_epoch=True),
+        SampleCheckpoint(interface=AliModel, z_dim=z_dim, image_size=(image_size, image_size), channels=NUM_CHANNELS, dataset=dataset, split=splits[1], save_subdir=subdir, before_training=True, after_epoch=True),
         ProgressBar(),
         Printing(),
     ]
@@ -210,6 +210,8 @@ def create_main_loop(save_path, subdir, dataset, color_convert,
                         help="Subdirectory for output files (images)")
     parser.add_argument('--dataset', dest='dataset', default=None,
                         help="Dataset for training.")
+    parser.add_argument('--splits', dest='splits', default="train,valid,test",
+                        help="train/valid/test dataset split names")
     parser.add_argument("--image-size", dest='image_size', type=int, default=64,
                         help="size of (offset) images")
     parser.add_argument('--color-convert', dest='color_convert',
@@ -222,7 +224,13 @@ def create_main_loop(save_path, subdir, dataset, color_convert,
     parser.add_argument("--checkpoint-every", type=int,
                         dest="checkpoint_every", default=1,
                         help="Frequency in epochs for checkpointing")
+    parser.add_argument("--num-epochs", type=int, dest="num_epochs",
+                        default=123, help="Stop training after num-epochs.")
+    parser.add_argument("--z-dim", type=int, dest="z_dim",
+                        default=256, help="Z-vector dimension")
     args = parser.parse_args()
-    create_main_loop(args.model, args.subdir, args.dataset,
+    splits = args.splits.split(",")
+    create_main_loop(args.model, args.subdir, args.dataset, splits,
                      args.color_convert, args.batch_size, args.monitor_every,
-                     args.checkpoint_every, args.image_size).run()
+                     args.checkpoint_every, args.num_epochs, args.image_size,
+                     args.z_dim).run()
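
The new flags expose what were previously module-level constants, and their argparse defaults (123 epochs, z-dim 256) match the constants they replace. A typical invocation would be "python experiments/train_ali.py --dataset my_dataset.hdf5 --splits train,valid,test --num-epochs 123 --z-dim 256", where the dataset filename is a placeholder. The sketch below drives the updated create_main_loop directly; the save path, subdir, batch_size, and monitor_every values are assumptions, since their argparse defaults fall outside this diff:

    # Hypothetical direct call following the new signature in this commit.
    # Placeholder values are marked; num_epochs, image_size, and z_dim match
    # the old module-level constants they replace.
    splits = "train,valid,test".split(",")  # mirrors the new --splits parsing
    create_main_loop("ali_model.zip",       # save_path (placeholder)
                     "samples",             # subdir for output images (placeholder)
                     "my_dataset.hdf5",     # dataset (placeholder HDF5 file)
                     splits,
                     False,                 # color_convert
                     100,                   # batch_size (assumed)
                     1,                     # monitor_every (assumed)
                     1,                     # checkpoint_every (default per this diff)
                     123,                   # num_epochs (old NUM_EPOCHS)
                     64,                    # image_size (old IMAGE_SIZE)
                     256).run()             # z_dim (old NLAT)

Note that SampleCheckpoint now draws its samples from splits[1], so --splits must name at least two splits; with the default "train,valid,test" that is "valid", matching the previously hard-coded split.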
