- Python3.7
- PyTorch>=1.7.0+cu110
- TorchVision>=0.8.1+cu110
- Numpy==1.19.5
- CUDA 11.0+
- DANN结构擅于避免模型过学习
- feature_extractor与domain_classifier模块合并构成域分类器
- feature_extractor与label_predictor模块合并构成样本分类器
- 通过输入真实数据与抽象数据,输出基于域分类的dc_loss,用于domain_classifier的反向传递
- 将真实数据输入样本分类器,将lp_loss作用于feature_extractor,并将lp_loss-dc_loss作用于label_predictor
- 优化单个分体模型时,将计算合并模型的梯度,需使用detach()或zero_grad()转为常量
- 默认使用mnist作为真实样本,svhn作为抽象样本
- 首次运行将自行下载以上两种数据集
- 运行train.py即可开始训练