Skip to content

Commit

Permalink
add hd-celeba option
Browse files Browse the repository at this point in the history
  • Loading branch information
LynnHo committed Jun 5, 2018
1 parent d77414c commit ff15dbe
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 26 deletions.
21 changes: 21 additions & 0 deletions README.md
Expand Up @@ -42,6 +42,10 @@
- https://pan.baidu.com/s/1eSNpdRG#list/path=%2FCelebA%2FAnno&parentPath=%2F or
- https://drive.google.com/drive/folders/0B7EVK8r0v71pOC0wOVZlQnFfaGs

- [HD-Celeba](https://github.com/LynnHo/HD-CelebA-Cropper) (optional)
- the images of ***img_align_celeba.zip*** are low resolution and uncropped, higher resolution and cropped images are available [here](https://github.com/LynnHo/HD-CelebA-Cropper)
- the high quality data should be placed in ***./data/img_crop_celeba/\*.jpg***

- Examples of training
- see [examples.md](./examples.md) for more examples

Expand Down Expand Up @@ -74,6 +78,23 @@
--experiment_name 384_shortcut1_inject1_none
```

- for 384x384 HD images (need [HD-Celeba](https://github.com/LynnHo/HD-CelebA-Cropper))

```console
CUDA_VISIBLE_DEVICES=0 \
python train.py \
--img_size 384 \
--enc_dim 48 \
--dec_dim 48 \
--dis_dim 48 \
--dis_fc_dim 512 \
--shortcut_layers 1 \
--inject_layers 1 \
--n_sample 24 \
--experiment_name 384_shortcut1_inject1_none \
--use_cropped_img
```

- tensorboard for loss visualization

```console
Expand Down
24 changes: 16 additions & 8 deletions data.py
Expand Up @@ -154,14 +154,22 @@ class Celeba(Dataset):
'Wearing_Necklace': 37, 'Wearing_Necktie': 38, 'Young': 39}

def __init__(self, data_dir, atts, img_resize, batch_size, prefetch_batch=2, drop_remainder=True,
num_threads=16, shuffle=True, buffer_size=4096, repeat=-1, sess=None, part='train'):
num_threads=16, shuffle=True, buffer_size=4096, repeat=-1, sess=None, part='train', crop=True):
super(Celeba, self).__init__()

list_file = os.path.join(data_dir, 'list_attr_celeba.txt')
img_dir = os.path.join(data_dir, 'img_align_celeba')
if crop:
img_dir_jpg = os.path.join(data_dir, 'img_align_celeba')
img_dir_png = os.path.join(data_dir, 'img_align_celeba_png')
else:
img_dir_jpg = os.path.join(data_dir, 'img_crop_celeba')
img_dir_png = os.path.join(data_dir, 'img_crop_celeba_png')

names = np.loadtxt(list_file, skiprows=2, usecols=[0], dtype=np.str)
img_paths = [os.path.join(img_dir, name) for name in names]
if os.path.exists(img_dir_png):
img_paths = [os.path.join(img_dir_png, name.replace('jpg', 'png')) for name in names]
elif os.path.exists(img_dir_jpg):
img_paths = [os.path.join(img_dir_jpg, name) for name in names]

att_id = [Celeba.att_dict[att] + 1 for att in atts]
labels = np.loadtxt(list_file, skiprows=2, usecols=att_id, dtype=np.int64)
Expand All @@ -177,12 +185,12 @@ def __init__(self, data_dir, atts, img_resize, batch_size, prefetch_batch=2, dro
img_size = 170

def _map_func(img, label):
img = tf.image.crop_to_bounding_box(img, offset_h, offset_w, img_size, img_size)
img = tf.image.resize_images(img, [img_resize, img_resize]) / 127.5 - 1
if crop:
img = tf.image.crop_to_bounding_box(img, offset_h, offset_w, img_size, img_size)
# img = tf.image.resize_images(img, [img_resize, img_resize]) / 127.5 - 1
# or
# img = tf.image.resize_images(img, [img_resize, img_resize], tf.image.ResizeMethod.BICUBIC)
# img = (img - tf.reduce_min(img)) / (tf.reduce_max(img) - tf.reduce_min(img))
# img = img * 2 - 1
img = tf.image.resize_images(img, [img_resize, img_resize], tf.image.ResizeMethod.BICUBIC)
img = tf.clip_by_value(img, 0, 255) / 127.5 - 1
label = (label + 1) // 2
return img, label

Expand Down
3 changes: 2 additions & 1 deletion test.py
Expand Up @@ -46,6 +46,7 @@
thres_int = args['thres_int']
test_int = args_.test_int
# others
use_cropped_img = args['use_cropped_img']
experiment_name = args_.experiment_name


Expand All @@ -55,7 +56,7 @@

# data
sess = tl.session()
te_data = data.Celeba('./data', atts, img_size, 1, part='test', sess=sess)
te_data = data.Celeba('./data', atts, img_size, 1, part='test', sess=sess, crop=not use_cropped_img)

# models
Genc = partial(models.Genc, dim=enc_dim, n_layers=enc_layers)
Expand Down
3 changes: 2 additions & 1 deletion test_multi.py
Expand Up @@ -48,6 +48,7 @@
thres_int = args['thres_int']
test_ints = args_.test_ints
# others
use_cropped_img = args['use_cropped_img']
experiment_name = args_.experiment_name

assert test_atts is not None, 'test_atts should be chosen in %s' % (str(atts))
Expand All @@ -63,7 +64,7 @@

# data
sess = tl.session()
te_data = data.Celeba('./data', atts, img_size, 1, part='test', sess=sess)
te_data = data.Celeba('./data', atts, img_size, 1, part='test', sess=sess, crop=not use_cropped_img)

# models
Genc = partial(models.Genc, dim=enc_dim, n_layers=enc_layers)
Expand Down
3 changes: 2 additions & 1 deletion test_slide.py
Expand Up @@ -52,6 +52,7 @@
test_int_max = args_.test_int_max
n_slide = args_.n_slide
# others
use_cropped_img = args['use_cropped_img']
experiment_name = args_.experiment_name

assert test_att is not None, 'test_att should be chosen in %s' % (str(atts))
Expand All @@ -63,7 +64,7 @@

# data
sess = tl.session()
te_data = data.Celeba('./data', atts, img_size, 1, part='test', sess=sess)
te_data = data.Celeba('./data', atts, img_size, 1, part='test', sess=sess, crop=not use_cropped_img)

# models
Genc = partial(models.Genc, dim=enc_dim, n_layers=enc_layers)
Expand Down
32 changes: 17 additions & 15 deletions train.py
Expand Up @@ -28,26 +28,27 @@
att_default = ['Bald', 'Bangs', 'Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Bushy_Eyebrows', 'Eyeglasses', 'Male', 'Mouth_Slightly_Open', 'Mustache', 'No_Beard', 'Pale_Skin', 'Young']
parser.add_argument('--atts', dest='atts', default=att_default, choices=data.Celeba.att_dict.keys(), nargs='+', help='attributes to learn')
parser.add_argument('--img_size', dest='img_size', type=int, default=128, help='image size')
parser.add_argument('--shortcut_layers', dest='shortcut_layers', type=int, default=1, help='shortcut_layers')
parser.add_argument('--inject_layers', dest='inject_layers', type=int, default=0, help='inject_layers')
parser.add_argument('--enc_dim', dest='enc_dim', type=int, default=64, help='enc_dim')
parser.add_argument('--dec_dim', dest='dec_dim', type=int, default=64, help='dec_dim')
parser.add_argument('--dis_dim', dest='dis_dim', type=int, default=64, help='dis_dim')
parser.add_argument('--dis_fc_dim', dest='dis_fc_dim', type=int, default=1024, help='dis_fc_dim')
parser.add_argument('--enc_layers', dest='enc_layers', type=int, default=5, help='enc_layers')
parser.add_argument('--dec_layers', dest='dec_layers', type=int, default=5, help='dec_layers')
parser.add_argument('--dis_layers', dest='dis_layers', type=int, default=5, help='dis_layers')
parser.add_argument('--shortcut_layers', dest='shortcut_layers', type=int, default=1)
parser.add_argument('--inject_layers', dest='inject_layers', type=int, default=0)
parser.add_argument('--enc_dim', dest='enc_dim', type=int, default=64)
parser.add_argument('--dec_dim', dest='dec_dim', type=int, default=64)
parser.add_argument('--dis_dim', dest='dis_dim', type=int, default=64)
parser.add_argument('--dis_fc_dim', dest='dis_fc_dim', type=int, default=1024)
parser.add_argument('--enc_layers', dest='enc_layers', type=int, default=5)
parser.add_argument('--dec_layers', dest='dec_layers', type=int, default=5)
parser.add_argument('--dis_layers', dest='dis_layers', type=int, default=5)
# training
parser.add_argument('--epoch', dest='epoch', type=int, default=200, help='# of epochs')
parser.add_argument('--batch_size', dest='batch_size', type=int, default=32, help='batch size')
parser.add_argument('--lr', dest='lr', type=float, default=0.0002, help='learning rate')
parser.add_argument('--n_d', dest='n_d', type=int, default=5, help='# of d updates per g update')
parser.add_argument('--b_distribution', dest='b_distribution', default='none', choices=['none', 'uniform', 'truncated_normal'], help='b_distribution')
parser.add_argument('--thres_int', dest='thres_int', type=float, default=0.5, help='thres_int')
parser.add_argument('--test_int', dest='test_int', type=float, default=1.0, help='test_int')
parser.add_argument('--b_distribution', dest='b_distribution', default='none', choices=['none', 'uniform', 'truncated_normal'])
parser.add_argument('--thres_int', dest='thres_int', type=float, default=0.5)
parser.add_argument('--test_int', dest='test_int', type=float, default=1.0)
parser.add_argument('--n_sample', dest='n_sample', type=int, default=64, help='# of sample images')
# others
parser.add_argument('--experiment_name', dest='experiment_name', default=datetime.datetime.now().strftime("%I:%M%p on %B %d, %Y"), help='experiment_name')
parser.add_argument('--use_cropped_img', dest='use_cropped_img', action='store_true')
parser.add_argument('--experiment_name', dest='experiment_name', default=datetime.datetime.now().strftime("%I:%M%p on %B %d, %Y"))

args = parser.parse_args()
# model
Expand All @@ -73,6 +74,7 @@
test_int = args.test_int
n_sample = args.n_sample
# others
use_cropped_img = args.use_cropped_img
experiment_name = args.experiment_name

pylib.mkdir('./output/%s' % experiment_name)
Expand All @@ -86,8 +88,8 @@

# data
sess = tl.session()
tr_data = data.Celeba('./data', atts, img_size, batch_size, part='train', sess=sess)
val_data = data.Celeba('./data', atts, img_size, n_sample, part='val', shuffle=False, sess=sess)
tr_data = data.Celeba('./data', atts, img_size, batch_size, part='train', sess=sess, crop=not use_cropped_img)
val_data = data.Celeba('./data', atts, img_size, n_sample, part='val', shuffle=False, sess=sess, crop=not use_cropped_img)

# models
Genc = partial(models.Genc, dim=enc_dim, n_layers=enc_layers)
Expand Down

0 comments on commit ff15dbe

Please sign in to comment.