Skip to content

seqgan: sequence generative adversarial nets with policy gradient

Notifications You must be signed in to change notification settings

JJASMINE22/SeqGan

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

2 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

SeqGAN: Sequence Generative Adversarial Nets的pytorch实现


目录

  1. 所需环境 Environment
  2. 模型结构 Structure
  3. 注意事项 Attention
  4. 文件下载 Download
  5. 训练步骤 How2train
  6. 参考资料 Reference

所需环境

Python3.7 PyTorch>=1.7.0+cu110
Numpy==1.19.5 CUDA 11.0+ nltk==3.7 tqdm==4.62.3

模型结构

image

注意事项

  1. 新增基于生成器、对抗器权重的l2正则化误差,降低过拟合概率
  2. 新增基于HingeLoss的对抗器误差
  3. 更改基于参考资料[1]的蒙特卡洛补偿计算方法,使其与数据特征能正确匹配
  4. 数据路径、训练目标等参数均位于config.py,默认使用image_coco数据集

文件下载

链接:https://pan.baidu.com/s/1SNc7uJ3PMxX6gxLrELfjEQ 提取码:LEAK 下载解压后放置于config.py中设置的路径即可。

训练步骤

  1. 默认使用image_coco进行训练,并使用nltk划分单词。
  2. 运行train.py即可开始训练。

Reference

  1. https://github.com/williamSYSU/TextGAN-PyTorch

About

seqgan: sequence generative adversarial nets with policy gradient

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages