Skip to content

Commit

Permalink
train_ali can now train on other fuel datasets
Browse files Browse the repository at this point in the history
Used fuel_helper to add dataset reading options
to train_ali, including:

  --dataset
  --color-convert
  --batch-size
  --monitor-every
  --checkpoint-every

Note that --image-size was also added, but it is not
yet implemented: the network still assumes the image
size is 64x64.
  • Loading branch information
dribnet committed Aug 14, 2016
1 parent a9252d4 commit 2fa2727
Showing 1 changed file with 51 additions and 16 deletions.
67 changes: 51 additions & 16 deletions experiments/train_ali.py
Expand Up @@ -25,9 +25,8 @@

from wrapper.interface import AliModel
from utils.samplecheckpoint import SampleCheckpoint
from utils.fuel_helper import create_custom_streams

BATCH_SIZE = 100
MONITORING_BATCH_SIZE = 500
NUM_EPOCHS = 123
IMAGE_SIZE = (64, 64)
NUM_CHANNELS = 3
Expand All @@ -39,7 +38,7 @@
LEAK = 0.02


def create_model_brick():
def create_model_brick(model_stream):
layers = [
conv_brick(2, 1, 64), bn_brick(), LeakyRectifier(leak=LEAK),
conv_brick(7, 2, 128), bn_brick(), LeakyRectifier(leak=LEAK),
Expand Down Expand Up @@ -109,16 +108,15 @@ def create_model_brick():
x_discriminator.layers[0].use_bias = True
x_discriminator.layers[0].tied_biases = True
ali.initialize()
raw_marginals, = next(
create_celeba_data_streams(500, 500)[0].get_epoch_iterator())
raw_marginals, = next(model_stream.get_epoch_iterator())
b_value = get_log_odds(raw_marginals)
decoder_mapping.layers[-2].b.set_value(b_value)

return ali


def create_models():
ali = create_model_brick()
def create_models(model_stream):
ali = create_model_brick(model_stream)
x = tensor.tensor4('features')
z = ali.theano_rng.normal(size=(x.shape[0], NLAT, 1, 1))

Expand All @@ -145,8 +143,27 @@ def _create_model(with_dropout):
return model, bn_model, bn_updates


def create_main_loop(save_path, subdir):
model, bn_model, bn_updates = create_models()
def create_main_loop(save_path, subdir, dataset, color_convert,
batch_size, monitor_every, checkpoint_every, image_size):

if dataset is None:
streams = create_celeba_data_streams(batch_size, batch_size)
model_stream = create_celeba_data_streams(500, 500)[0]
else:
streams = create_custom_streams(filename=dataset,
training_batch_size=batch_size,
monitoring_batch_size=batch_size,
include_targets=False,
color_convert=color_convert)
model_stream = create_custom_streams(filename=dataset,
training_batch_size=500,
monitoring_batch_size=500,
include_targets=False,
color_convert=color_convert)[0]

main_loop_stream, train_monitor_stream, valid_monitor_stream = streams[:3]

model, bn_model, bn_updates = create_models(model_stream)
ali, = bn_model.top_bricks
discriminator_loss, generator_loss = bn_model.outputs

Expand All @@ -155,8 +172,7 @@ def create_main_loop(save_path, subdir):
step_rule, generator_loss,
ali.generator_parameters, step_rule)
algorithm.add_updates(bn_updates)
streams = create_celeba_data_streams(BATCH_SIZE, MONITORING_BATCH_SIZE)
main_loop_stream, train_monitor_stream, valid_monitor_stream = streams

bn_monitored_variables = (
[v for v in bn_model.auxiliary_variables if 'norm' not in v.name] +
bn_model.outputs)
Expand All @@ -168,11 +184,14 @@ def create_main_loop(save_path, subdir):
FinishAfter(after_n_epochs=NUM_EPOCHS),
DataStreamMonitoring(
bn_monitored_variables, train_monitor_stream, prefix="train",
updates=bn_updates),
updates=bn_updates, before_first_epoch=True,
every_n_epochs=monitor_every),
DataStreamMonitoring(
monitored_variables, valid_monitor_stream, prefix="valid"),
Checkpoint(save_path, after_epoch=True, after_training=True,
use_cpickle=True),
monitored_variables, valid_monitor_stream, prefix="valid",
before_first_epoch=False, every_n_epochs=monitor_every),
Checkpoint(save_path, every_n_epochs=checkpoint_every,
before_training=True, after_epoch=True, after_training=True,
use_cpickle=True),
SampleCheckpoint(interface=AliModel, z_dim=NLAT, image_size=IMAGE_SIZE, channels=NUM_CHANNELS, dataset="celeba_64", split="valid", save_subdir=subdir, before_training=True, after_epoch=True),
ProgressBar(),
Printing(),
Expand All @@ -189,5 +208,21 @@ def create_main_loop(save_path, subdir):
default="ali_celeba.zip", help="Model to save")
parser.add_argument("--subdir", dest='subdir', type=str, default="output",
help="Subdirectory for output files (images)")
parser.add_argument('--dataset', dest='dataset', default=None,
help="Dataset for training.")
parser.add_argument("--image-size", dest='image_size', type=int, default=64,
help="size of (offset) images")
parser.add_argument('--color-convert', dest='color_convert',
default=False, action='store_true',
help="Convert source dataset to color from grayscale.")
parser.add_argument("--batch-size", type=int, dest="batch_size",
default=100, help="Size of each mini-batch")
parser.add_argument("--monitor-every", type=int, dest="monitor_every",
default=4, help="Frequency in epochs for monitoring")
parser.add_argument("--checkpoint-every", type=int,
dest="checkpoint_every", default=1,
help="Frequency in epochs for checkpointing")
args = parser.parse_args()
create_main_loop(args.model, args.subdir).run()
create_main_loop(args.model, args.subdir, args.dataset,
args.color_convert, args.batch_size, args.monitor_every,
args.checkpoint_every, args.image_size).run()

0 comments on commit 2fa2727

Please sign in to comment.