Commit

upload code
yzou2 committed Feb 8, 2021
1 parent 8871b94 commit d90613f
Showing 19 changed files with 5,365 additions and 1 deletion.
180 changes: 179 additions & 1 deletion README.md 100644 → 100755
@@ -1 +1,179 @@
# DG-Net-PP
[![License CC BY-NC-SA 4.0](https://img.shields.io/badge/license-CC4.0-blue.svg)](https://raw.githubusercontent.com/nvlabs/SPADE/master/LICENSE.md)
![Python 3.6](https://img.shields.io/badge/python-3.6-green.svg)
[![Language grade: Python](https://img.shields.io/lgtm/grade/python/g/NVlabs/DG-Net.svg?logo=lgtm&logoWidth=18)](https://lgtm.com/projects/g/NVlabs/DG-Net/context:python)

## Joint Discriminative and Generative Learning for Person Re-identification
![](NxN.jpg)

[[Project]](http://zdzheng.xyz/DG-Net/) [[Paper]](https://arxiv.org/abs/1904.07223) [[YouTube]](https://www.youtube.com/watch?v=ubCrEAIpQs4) [[Bilibili]](https://www.bilibili.com/video/av51439240) [[Poster]](http://zdzheng.xyz/images/DGNet_poster.pdf)
[[Supp]](http://jankautz.com/publications/JointReID_CVPR19_supp.pdf)

Joint Discriminative and Generative Learning for Person Re-identification, CVPR 2019 (Oral)<br>
[Zhedong Zheng](http://zdzheng.xyz/), [Xiaodong Yang](https://xiaodongyang.org/), [Zhiding Yu](https://chrisding.github.io/), [Liang Zheng](http://liangzheng.com.cn/), [Yi Yang](https://www.uts.edu.au/staff/yi.yang), [Jan Kautz](http://jankautz.com/) <br>

## Table of contents
* [News](#news)
* [Features](#features)
* [Prerequisites](#prerequisites)
* [Getting Started](#getting-started)
* [Installation](#installation)
* [Dataset Preparation](#dataset-preparation)
* [Testing](#testing)
* [Training](#training)
* [DG-Market](#dg-market)
* [Tips](#tips)
* [Citation](#citation)
* [Related Work](#related-work)
* [License](#license)

## News
- 08/24/2019: We add the direct transfer learning results of DG-Net [here](https://github.com/NVlabs/DG-Net#person-re-id-evaluation).
- 08/01/2019: We add the support of multi-GPU training: `python train.py --config configs/latest.yaml --gpu_ids 0,1`.

## Features
We currently support:
- Multi-GPU training (fp32)
- [APEX](https://github.com/NVIDIA/apex) to save GPU memory (fp16/fp32)
- Multi-query evaluation
- Random erasing
- Visualize training curves
- Generate all figures in the paper

## Prerequisites

- Python 3.6
- GPU memory >= 15G (fp32)
- GPU memory >= 10G (fp16/fp32)
- NumPy
- PyTorch 1.0+
- [Optional] APEX (fp16/fp32)

## Getting Started
### Installation
- Install [PyTorch](http://pytorch.org/)
- Install torchvision from source:
```bash
git clone https://github.com/pytorch/vision
cd vision
python setup.py install
```
- [Optional] Install APEX from source (you may skip this step):
```bash
git clone https://github.com/NVIDIA/apex.git
cd apex
python setup.py install --cuda_ext --cpp_ext
```
- Clone this repo:
```bash
git clone https://github.com/NVlabs/DG-Net.git
cd DG-Net/
```

Our code is tested on PyTorch 1.0.0+ and torchvision 0.2.1+.

### Dataset Preparation
Download the dataset [Market-1501](http://www.liangzheng.com.cn/Project/project_reid.html) [[Google Drive]](https://drive.google.com/file/d/0B8-rUzbwVRk0c054eEozWG9COHM/view) [[Baidu Disk]](https://pan.baidu.com/s/1ntIi2Op)

Preparation: put the images with the same ID into one folder. You may use
```bash
python prepare-market.py # for Market-1501
```
Remember to modify the dataset path in the script to your own path.
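
For reference, below is a minimal sketch of this folder-based preparation, assuming the standard Market-1501 filename convention (the first underscore-separated field, e.g. `0002` in `0002_c1s1_000451_03.jpg`, is the person ID). The paths are placeholders and the actual `prepare-market.py` may differ.
```python
import os
import shutil

# Placeholder paths -- point these at your own Market-1501 download.
download_path = '../Market-1501'
src_dir = os.path.join(download_path, 'bounding_box_train')
save_dir = os.path.join(download_path, 'pytorch', 'train_all')
os.makedirs(save_dir, exist_ok=True)

for name in os.listdir(src_dir):
    if not name.endswith('.jpg'):
        continue
    pid = name.split('_')[0]  # person ID is the first underscore-separated field
    dst_dir = os.path.join(save_dir, pid)
    os.makedirs(dst_dir, exist_ok=True)
    shutil.copyfile(os.path.join(src_dir, name), os.path.join(dst_dir, name))
```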

### Testing

#### Download the trained model
We provide our trained model. You may download it from [Google Drive](https://drive.google.com/open?id=1lL18FZX1uZMWKzaZOuPe3IuAdfUYyJKH) (or [Baidu Disk](https://pan.baidu.com/s/1503831XfW0y4g3PHir91yw), password: rqvf) and move it to the `outputs` folder.
```
├── outputs/
│   ├── E0.5new_reid0.5_w30000
├── models
│   ├── best/
```
#### Person re-id evaluation
- Supervised learning

| | Market-1501 | DukeMTMC-reID | MSMT17 | CUHK03-NP |
|---|--------------|----------------|----------|-----------|
| Rank@1 | 94.8% | 86.6% | 77.2% | 65.6% |
| mAP | 86.0% | 74.8% | 52.3% | 61.1% |


- Direct transfer learning
To verify the generalizability of DG-Net, we train the model on dataset A and directly test the model on dataset B (with no adaptation).
We denote the direct transfer learning protocol as `A→B`.

| |Market→Duke|Duke→Market|Market→MSMT|MSMT→Market|Duke→MSMT|MSMT→Duke|
|---|----------------|----------------| -------------- |----------------| -------------- |----------------|
| Rank@1 | 42.62% | 56.12% | 17.11% | 61.76% | 20.59% | 61.89% |
| Rank@5 | 58.57% | 72.18% | 26.66% | 77.67% | 31.67% | 75.81% |
| Rank@10 | 64.63% | 78.12% | 31.62% | 83.25% | 37.04% | 80.34% |
| mAP | 24.25% | 26.83% | 5.41% | 33.62% | 6.35% | 40.69% |


#### Image generation evaluation

Please check the `README.md` in `./visual_tools`.

You may use `./visual_tools/test_folder.py` to generate a large number of images and then run the evaluation. The only thing you need to modify is the data path in [SSIM](https://github.com/layumi/PerceptualSimilarity) and [FID](https://github.com/layumi/TTUR).

### Training

#### Train a teacher model
You may directly download our trained teacher model from [Google Drive](https://drive.google.com/open?id=1lL18FZX1uZMWKzaZOuPe3IuAdfUYyJKH) (or [Baidu Disk](https://pan.baidu.com/s/1503831XfW0y4g3PHir91yw) password: rqvf).
If you prefer to train it yourself, please follow the [person re-id baseline](https://github.com/layumi/Person_reID_baseline_pytorch) repository to train a teacher model, then copy it into `./models`.
```
├── models/
│   ├── best/                /* teacher model for Market-1501
│   │   ├── net_last.pth     /* model file
│   │   ├── ...
```

#### Train DG-Net
1. Set up the yaml file. Check out `configs/latest.yaml`. Change the `data_root` field to the path of your prepared folder-based dataset, e.g. `../Market-1501/pytorch` (a small sanity-check sketch follows step 3 below).


2. Start training:
```bash
python train.py --config configs/latest.yaml
```
Or train with low precision (fp16):
```bash
python train.py --config configs/latest-fp16.yaml
```
Intermediate image outputs and model binary files are saved in `outputs/latest`.

3. Check the loss log:
```bash
tensorboard --logdir logs/latest
```
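
Relating to step 1, here is an optional, hedged sanity check for the yaml setup. It assumes PyYAML is installed and that the config exposes a `data_root` (or `data_root_a`) field; the exact field names may differ across the config files in this repo.
```python
import os
import yaml

# Load the config you plan to train with (path is an example).
with open('configs/latest.yaml') as f:
    cfg = yaml.safe_load(f)

# The field may be data_root (single-domain) or data_root_a/b (cross-domain).
data_root = cfg.get('data_root') or cfg.get('data_root_a')
assert data_root and os.path.isdir(data_root), f'Dataset folder not found: {data_root}'
print('data root OK:', data_root)
```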

## DG-Market
![](https://github.com/layumi/DG-Net/blob/gh-pages/index_files/DGMarket-logo.png)

We provide our generated images as a large-scale synthetic dataset called DG-Market. This dataset is generated by our DG-Net and consists of 128,307 images (613MB), about 10 times larger than the training set of the original Market-1501 (and even more images can be generated with DG-Net). It can serve as an unlabeled training source for semi-supervised learning (a minimal loading sketch follows the table below). You may download the dataset from [Google Drive](https://drive.google.com/file/d/126Gn90Tzpk3zWp2c7OBYPKc-ZjhptKDo/view?usp=sharing) (or [Baidu Disk](https://pan.baidu.com/s/1n4M6s-qvE08J8SOOWtWfgw), password: qxyh).

| | DG-Market | Market-1501 (training) |
|---|--------------|-------------|
| #identity| - | 751 |
| #images| 128,307 | 12,936 |
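
As a rough illustration of plugging DG-Market in as unlabeled data, here is a minimal PyTorch loading sketch. It assumes the downloaded images sit in a single folder (adjust the path and listing if the release is organized differently); the transform values are placeholders.
```python
import os
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

class UnlabeledImages(Dataset):
    """Flat folder of images with no identity labels (e.g., DG-Market)."""
    def __init__(self, root, transform=None):
        self.paths = sorted(os.path.join(root, f)
                            for f in os.listdir(root) if f.endswith('.jpg'))
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        img = Image.open(self.paths[idx]).convert('RGB')
        return self.transform(img) if self.transform else img

transform = transforms.Compose([
    transforms.Resize((256, 128)),  # typical re-id input size
    transforms.ToTensor(),
])
loader = DataLoader(UnlabeledImages('../DG-Market', transform),
                    batch_size=64, shuffle=True, num_workers=4)
```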

## Tips
Note the format of the camera ID and the number of cameras. Some datasets (e.g., MSMT17) have more than 10 cameras, so you need to modify the preparation and evaluation code to read double-digit camera IDs. Some vehicle re-id datasets (e.g., VeRi) use different naming rules, so the preparation and evaluation code also needs to be adjusted accordingly.
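
As an illustration of the point about double-digit camera IDs, here is a hedged parsing sketch. It assumes Market-style filenames such as `0002_c1s1_000451_03.jpg` after preparation; it is not the repository's actual preparation code, and the second example filename is hypothetical.
```python
import re

# Parse person ID and camera ID from a Market-style filename.
# Using a regex (rather than fixed-width slicing) handles camera IDs
# with one or two digits, e.g. 'c3' as well as 'c14'.
NAME_PATTERN = re.compile(r'^(?P<pid>-?\d+)_c(?P<cam>\d+)')

def parse_name(filename):
    match = NAME_PATTERN.match(filename)
    if match is None:
        raise ValueError(f'Unexpected filename: {filename}')
    return int(match.group('pid')), int(match.group('cam'))

print(parse_name('0002_c1s1_000451_03.jpg'))  # -> (2, 1)
print(parse_name('0010_c14_0005_00.jpg'))     # -> (10, 14), double-digit camera
```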

## Citation
Please cite this paper if it helps your research:
```bibtex
@inproceedings{zheng2019joint,
title={Joint discriminative and generative learning for person re-identification},
author={Zheng, Zhedong and Yang, Xiaodong and Yu, Zhiding and Zheng, Liang and Yang, Yi and Kautz, Jan},
booktitle={IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
year={2019}
}
```

## Related Work
Other GAN-based methods compared in the paper include [LSGAN](https://github.com/layumi/DCGAN-pytorch), [FDGAN](https://github.com/layumi/FD-GAN) and [PG2GAN](https://github.com/charliememory/Pose-Guided-Person-Image-Generation). We forked the code and made some changes for evaluation; we thank the authors for their great work. We would also like to thank the great projects [person re-id baseline](https://github.com/layumi/Person_reID_baseline_pytorch), [MUNIT](https://github.com/NVlabs/MUNIT) and [DRIT](https://github.com/HsinYingLee/DRIT).

## License
Copyright (C) 2019 NVIDIA Corporation. All rights reserved. Licensed under the [CC BY-NC-SA 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode) (**Attribution-NonCommercial-ShareAlike 4.0 International**). The code is released for academic research use only. For commercial use, please contact [researchinquiries@nvidia.com](mailto:researchinquiries@nvidia.com).
119 changes: 119 additions & 0 deletions configs/duke2market.yaml
@@ -0,0 +1,119 @@
# Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
# Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).

apex: false # Set True to use float16.
B_w: 0.2 # The loss weight of fine-grained loss, which is named as `alpha` in the DG-Net paper.
ID_class_a: 702 # The number of ID classes of source domain. For example, 751 for Market, 702 for DukeMTMC
ID_class_b: 0 # The number of ID classes of target domain. For example, 751 for Market, 702 for DukeMTMC
ID_stride: 1 # Stride in Appearance encoder
ID_style: AB # For the time being, we only support AB.
batch_size: 8 # BatchSize
xx_port: 0.9 # The portion of the single-domain batch (src-src/target-target, or AA/BB)
ab_port: 0.1 # The portion of the cross-domain batch (src-target, or AB); xx_port + ab_port == 1
aa: true # Whether to use the AA-type batch
ab: true # Whether to use the AB-type batch
bb: true # Whether to use the BB-type batch
aa_drop: true # Whether to drop the aa batch in self-training
beta1: 0 # Adam hyperparameter
beta2: 0.999 # Adam hyperparameter
crop_image_height: 256 # Input height
crop_image_width: 128 # Input width
data_root_a: /workspace/home/re-id/datasets/DukeMTMC-reID/pytorch # Source dataset root
data_root_b: /workspace/home/re-id/datasets/Market/pytorch/ # Target dataset root
src_model_dir: /workspace/home/re-id/DG-Net++/models/zzd_duke/checkpoints # Source model root
dis_update_iter: 1 # The frequency to update the discriminator
dis:
# for image discriminator
LAMBDA: 0.01 # the hyperparameter for the regularization term
activ: lrelu # activation function style [relu/lrelu/prelu/selu/tanh]
dim: 32 # number of filters in the bottommost layer
gan_type: lsgan # GAN loss [lsgan/nsgan]
n_layer: 2 # number of layers in D
n_res: 4 # number of residual blocks in D
non_local: 0 # number of non_local layers
norm: none # normalization layer [none/bn/in/ln]
num_scales: 3 # number of scales
pad_type: reflect # padding type [zero/reflect]
# for domain id discriminator
id_ganType: lsgan # the type of the network of ID discriminator
id_activ: lrelu # activation function style [relu/lrelu/prelu/selu/tanh]
id_norm: bn # normalization layer [none/bn/in/ln]
id_nLayer: 4 # number of layers in domain id discriminator
id_nFilter: 1024 # number of layer filters in domain id discriminator
id_ds: 2 # down sampling rate in domain id discriminator
display_size: 16 # How many images to display
erasing_p: 0.5 # Random erasing probability [0-1]
gamma: 0.1 # Learning Rate Decay (except appearance encoder)
gamma2: 0.1 # Learning Rate Decay (for appearance encoder)
gan_w: 1 # The weight of gan loss
gen:
activ: lrelu # activation function style [relu/lrelu/prelu/selu/tanh]
dec: basic # [basic/parallel/series]
dim: 16 # number of filters in the bottommost layer
dropout: 0 # use dropout in the generator
id_dim: 2048 # length of appearance code
mlp_dim: 512 # number of filters in MLP
mlp_norm: none # norm in mlp [none/bn/in/ln]
n_downsample: 2 # number of downsampling layers in content encoder
n_res: 4 # number of residual blocks in content encoder/decoder
non_local: 0 # number of non_local layers
pad_type: reflect # padding type [zero/reflect]
tanh: false # use tanh or not at the last layer
init: kaiming # initialization [gaussian/kaiming/xavier/orthogonal]
id_adv_w: 1.000 # The initial weight of domain ID adversarial loss
id_adv_w_max: 1.000 # The maximum weight of domain ID adversarial loss
adv_warm_max_round: 1 # The maximum round for domain ID adversarial training
adv_warm_scale: 0.0000 # How fast to warm up the domain ID adversarial training
id_w: 1.0 # The weight of ID loss
test_batchsize: 80 # The batch size in test time
input_dim_a: 1 # We use the gray-scale input, so the input dim is 1
input_dim_b: 1 # We use the gray-scale input, so the input dim is 1
log_iter: 1 # How often do you want to log the training stats
lr_id_d: 0.00001 # Initial domain ID discriminator learning rate
lr2: 0.00001 # Initial appearance encoder learning rate
lr2_ramp_factor: 60 # The factor to multiply the lr2 after switching to self-training
lr_d: 0.000001 # Initial discriminator learning rate
lr_g: 0.000001 # Initial generator (except appearance encoder) learning rate
lr_policy: multistep # Learning rate scheduler [multistep|constant|step]
max_cyc_w: 2 # The maximum weight for cycle loss
max_teacher_w: 2 # The maximum weight for prime loss (teacher KL loss)
max_w: 1 # The maximum weight for feature reconstruction losses
new_size: 128 # The resized size
norm_id: false # Do we normalize the appearance code
num_workers: 4 # Number of workers to load the data
id_tgt: false # Whether to use identification loss on target domain
tgt_pos: 0.0 # Whether to use the identification loss of the positive samples (the samples with the same pseudo-ID of a given sample) in target domain
pid_w: 1.0 # Positive ID loss
pool: max # Pooling layer for the appearance encoder
recon_s_w: 1 # The initial weight for structure code reconstruction
recon_f_w: 1 # The initial weight for appearance code reconstruction
recon_id_w: 0.5 # The initial weight for ID reconstruction
recon_x_cyc_w: 2 # The initial weight for cycle reconstruction
recon_x_w: 5 # The initial weight for self-reconstruction
recon_xp_w: 5 # The initial weight for self-identity reconstruction
recon_xp_tgt_w: 0 # The weight for self-positive-identity reconstruction in target domain
single: gray # Make input to gray-scale
snapshot_save_iter: 10000 # How often to save the checkpoint
sqrt: false # Whether to use the square loss.
step_size: 120000 # When to decay the learning rate
teacher: "best-duke" # Teacher model name. For Market, you can set "best"; For DukeMTMC, you can set `best-duke`; "" means no teacher.
teacher_w: 2.0 # The initial weight for prime loss (teacher KL loss)
teacher_style: 0 # Teacher style: 0 - our smooth dynamic label; 1 - one-hot dynamic pseudo label; 2 - conditional hard static label; 3 - LSRO, static smooth label; 4 - dynamic soft two-label
teacher_tgt: false # Whether to use the teacher loss on target samples
train_bn: true # Whether we train the BN layers on the generated images.
use_decoder_again: true # Whether we train the decoder on the generated images.
use_encoder_again: 0.5 # The probability of training the structure encoder on the generated images.
vgg_w: 0 # The weight of the VGG perceptual loss (not used here).
warm_iter: 0 # When to start warm up the losses (fine-grained/feature reconstruction losses).
warm_scale: 0.0000 # How fast to warm up
warm_teacher_iter: 0 # When to start warm up the prime loss
weight_decay: 0.0005 # Weight decay
max_iter: 300000 # When to end the training
max_round: 40 # Maximum self-training rounds
epoch_round: 2 # Epochs per round in self-training
epoch_round_adv: 22 # Maximal training epochs of domain ID adversarial training
randseed: 3 # Random seed
time_constraint: false # Whether to use sequential time constraint in pseudo-label generation
clustering:
eps: 0.45
min_samples: 7
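
The `clustering` block above suggests a density-based clustering step (`eps` and `min_samples` are DBSCAN-style parameters) for pseudo-label generation. The following is only a rough sketch of how such parameters might be used, with randomly generated placeholder features; it is not the repository's actual pipeline.
```python
import numpy as np
from sklearn.cluster import DBSCAN

# Placeholder target-domain features; in practice these would come from
# the appearance encoder applied to target-domain images.
features = np.random.rand(1000, 2048).astype(np.float32)
features /= np.linalg.norm(features, axis=1, keepdims=True)  # L2-normalize

clusterer = DBSCAN(eps=0.45, min_samples=7, n_jobs=-1)
pseudo_labels = clusterer.fit_predict(features)  # -1 marks unclustered outliers

num_clusters = len(set(pseudo_labels)) - (1 if -1 in pseudo_labels else 0)
print('pseudo-identities found:', num_clusters)
```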
