图像修复
- Python3.6
- Tensorflow1.4或者更高的版本,除了tensorflow2.0
- Opencv
- numpy
- scipy
- easydict
对于给定的数据集,训练阶段分为两部分。首先使用confidence-driven reconstruction
损失预训练整个网络,然后在前一阶段收敛后使用adversarial
和ID-MRF loss
进行训练。
- pretrain
python train.py --dataset [DATASET_NAME] --data_file [DATASET_TRAININGFILE] --gpu_ids [NUM] --pretrain_network 1 --batch_size 16
- finetune
python train.py --dataset [DATASET_NAME] --data_file [DATASET_TRAININGFILE] --gpu_ids [NUM] --pretrain_network 0 --load_model_dir [PRETRAINED_MODEL_PATH] --batch_size 8
参数说明:
- DATASET_TRAININGFILE:训练集文件,包含所有的训练图像的地址,可以使txt文件
- mask_type:掩膜类型
- 下载预训练模型:paris_streetview,CelebA-HQ_256,CelebA-HQ_512,Places2。注意一点是这是谷歌云盘,需要科学上网。
- 解压预训练模型到
./checkpoints
下,当然也可以选择其他文件夹 - 调用
test.py
文件,然后设定--dataset_path
和--load_model_dir
后即可:
python test.py --dataset paris_streetview --data_file ./imgs/paris-streetview_256x256/ --load_model_dir ./checkpoints/paris-streetview_256x256_rect --random_mask 0
- 修改测试过程中的图像resize
- 修改训练过程中的训练数据显示
- 修改网络的输入过程:输入ground_true和gound true with mask,而不是random mask