Python3.7
PyTorch>=1.7.0+cu110
Numpy==1.19.5
CUDA 11.0+
nltk==3.7
tqdm==4.62.3
- 调换生成器中manage_loss与work_loss的反向传递顺序,否则torch将于当前版本报错
- 新增基于生成器work模块、对抗器权重的l2正则化误差,降低过拟合概率
- 新增基于HingeLoss的对抗器误差
- 更改基于参考资料[2]的蒙特卡洛补偿计算方法,使其与数据特征能正确匹配
- 数据路径、训练目标等参数均位于config.py,默认使用image_coco数据集
- PyTorch的嵌入层需将整型变量转换为Long型(64位)
链接:https://pan.baidu.com/s/1SNc7uJ3PMxX6gxLrELfjEQ 提取码:LEAK 下载解压后放置于config.py中设置的路径即可。
- 默认使用image_coco进行训练,并使用nltk划分单词。
- 运行train.py即可开始训练。
https://github.com/CR-Gjx/LeakGAN https://github.com/williamSYSU/TextGAN-PyTorch