zi2zi (字到字, meaning "from character to character") uses GANs to learn the fonts of East Asian scripts; it is an application and extension of the recently popular pix2pix model to Chinese characters.
Details can be found in this blog post.
The network structure is based off pix2pix, with the addition of category embedding and two other losses, category loss and constant loss, from AC-GAN and DTN respectively.
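As a rough sketch of how these terms combine on the generator side, the weighted sum below uses the `L1_penalty` and `Lconst_penalty` flags of `train.py`; whether the category term carries its own weight is an assumption here, so treat this as illustrative rather than the exact formulation in the model code:

```python
def generator_loss(cheat_loss, l1_loss, category_loss, const_loss,
                   L1_penalty=100, Lconst_penalty=15):
    """Illustrative weighted sum of the generator's loss terms.

    cheat_loss    : adversarial loss (fool the discriminator), from pix2pix
    l1_loss       : pixel-wise L1 distance to the target glyph
    category_loss : font-classification loss, from AC-GAN
    const_loss    : consistency of the encoded source, from DTN
    """
    return (cheat_loss
            + L1_penalty * l1_loss
            + category_loss
            + Lconst_penalty * const_loss)
```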
After sufficient training, d_loss will drop to near zero and the model's performance plateaus. Label shuffling mitigates this problem by presenting new challenges to the model.
Specifically, within a given minibatch, for the same set of source characters, we generate two sets of target characters: one with the correct embedding labels, the other with shuffled labels. The shuffled set likely will not have corresponding target images to compute L1_Loss, but it can serve as a good source for all the other losses, forcing the model to generalize beyond the limited set of provided examples. Empirically, label shuffling improves the model's generalization on unseen data with better details, and decreases the required number of characters.
You can enable label shuffling by setting the flip_labels=1 option in the train.py script. It is recommended that you enable this for further tuning after d_loss flatlines around zero.
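The shuffling step can be sketched as follows; the function name and seeding are illustrative, not the actual `train.py` internals:

```python
import random

def shuffle_labels(labels, seed=None):
    """Return a shuffled copy of the per-sample embedding labels.

    The shuffled batch reuses the same source characters but pairs them
    with (most likely) wrong font labels, so L1 loss is skipped for it
    while the adversarial, category, and constant losses still apply.
    """
    shuffled = list(labels)
    random.Random(seed).shuffle(shuffled)
    return shuffled

# one minibatch of embedding ids for four samples across three fonts
labels = [0, 1, 2, 1]
shuffled = shuffle_labels(labels, seed=42)
```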
Download tons of fonts as you please
- Python 2.7
- CUDA
- cudnn
- Tensorflow >= 1.0.1
- Pillow (PIL)
- numpy >= 1.12.1
- scipy >= 0.18.1
- imageio
To avoid an IO bottleneck, preprocessing is necessary to pickle your data into binary format and persist it in memory during training.
First run the below command to get the font images:
python font2img.py --src_font=src.ttf
--dst_font=tgt.otf
--charset=CN
--sample_count=1000
--sample_dir=dir
--label=0
--filter=1
--shuffle=1
Four default charsets are offered: CN, CN_T (traditional), JP, and KR. You can also point charset to a one-line file, and the script will generate images for the characters in it. Note that the filter option is highly recommended: it pre-samples some characters and filters out all images that share the same hash, which usually indicates that the character is missing from the font. label indicates the index in the category embeddings that this font is associated with; it defaults to 0.
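The hash-based filter can be sketched like this; it is a simplified stand-in for what the --filter option does (the real implementation pre-samples characters first and may use a different hash):

```python
import hashlib

def dedupe_by_hash(images):
    """Keep only the first image for each distinct content hash.

    Fonts usually render every missing character as the same fallback
    glyph (a blank or 'tofu' box), so repeated hashes are a cheap signal
    that those characters are absent from the font.
    """
    seen = set()
    kept = []
    for img_bytes in images:
        digest = hashlib.md5(img_bytes).hexdigest()
        if digest not in seen:
            seen.add(digest)
            kept.append(img_bytes)
    return kept

# two distinct glyphs plus a repeated fallback rendering
kept = dedupe_by_hash([b"glyph-a", b"tofu", b"tofu", b"glyph-b"])
```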
After obtaining all images, run package.py to pickle the images and their corresponding labels into binary format:
python package.py --dir=image_directories
--save_dir=binary_save_directory
--split_ratio=[0,1]
After running this, you will find two objects, train.obj and val.obj, under save_dir for training and validation, respectively.
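The packaging step can be sketched as below; the record format (pickled (label, image_bytes) tuples appended to one file) is an assumption for illustration, not necessarily the exact layout package.py uses:

```python
import os
import pickle
import random

def package(examples, save_dir, split_ratio=0.1, seed=0):
    """Shuffle (label, image_bytes) examples and write train/val binaries."""
    examples = list(examples)
    random.Random(seed).shuffle(examples)
    n_val = int(len(examples) * split_ratio)
    splits = {"val.obj": examples[:n_val], "train.obj": examples[n_val:]}
    for name, subset in splits.items():
        with open(os.path.join(save_dir, name), "wb") as f:
            for example in subset:
                pickle.dump(example, f)
```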
experiment/
└── data
├── train.obj
└── val.obj
Create an experiment directory under the root of the project, and a data directory within it to place the two binaries. This directory layout enforces better data isolation, especially if you have multiple experiments running.
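From the project root, the layout above can be created with:

```shell
# create the experiment layout; then move train.obj and val.obj into
# experiment/data/ (e.g. with mv) before starting training
mkdir -p experiment/data
```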
To start training, run the following command:
python train.py --experiment_dir=experiment
--experiment_id=0
--batch_size=16
--lr=0.001
--epoch=40
--sample_steps=50
--schedule=20
--L1_penalty=100
--Lconst_penalty=15
schedule here means the number of epochs after which the learning rate decays by half. The train command will create sample, logs, and checkpoint directories under experiment_dir if they do not exist; there you can check and manage the progress of your training.
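The decay can be sketched as follows (illustrative; the exact point at which train.py applies the first halving may differ):

```python
def decayed_lr(base_lr, epoch, schedule):
    # halve the learning rate once every `schedule` epochs
    return base_lr * (0.5 ** (epoch // schedule))
```

With --lr=0.001 --schedule=20, epochs 0-19 train at 0.001, epochs 20-39 at 0.0005, and so on.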
After training is done, run the below command to infer test data:
python infer.py --model_dir=checkpoint_dir/
--batch_size=16
--source_obj=binary_obj_path
--embedding_ids=label[s] of the font, separated by commas
--save_dir=save_dir/
You can also do interpolation with this command:
python infer.py --model_dir=checkpoint_dir/
--batch_size=10
--source_obj=obj_path
--embedding_ids=label[s] of the font, separated by commas
--save_dir=frames/
--output_gif=gif_path
--interpolate=1
--steps=10
--uroboros=1
It will run through all pairs of fonts specified in embedding_ids and interpolate between them over the specified number of steps.
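The interpolation itself is just a linear blend between two font embedding vectors, fed to the generator at each step to render one frame; here is a sketch with plain lists (the real embeddings are learned tensors):

```python
def interpolate(e0, e1, steps):
    """Yield steps + 1 embeddings blending e0 into e1."""
    for i in range(steps + 1):
        t = i / float(steps)
        yield [(1.0 - t) * a + t * b for a, b in zip(e0, e1)]

# blend between two toy 2-d "embeddings" over 4 steps (5 frames total)
frames = list(interpolate([0.0, 1.0], [1.0, 0.0], steps=4))
```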
A pretrained model can be downloaded here; it is trained with 27 fonts, and only the generator is saved to reduce the model size. You can use the encoder in this pretrained model to accelerate the training process.
Code derived and rehashed from:
- pix2pix-tensorflow by yenchenlin
- Domain Transfer Network by yunjey
- ac-gan by buriburisuri
- dc-gan by carpedm20
- original pix2pix torch code by phillipi
Apache 2.0
Translator: Wu Junkai
This README was translated without the author's permission; the original is here.
The translator is just a student with limited writing skill, and parts of the translation were machine-translated to save effort. Please forgive any errors or awkward wording.
Images are served via jsDelivr for faster loading.