diff --git a/README.md b/README.md old mode 100644 new mode 100755 index 53b4ee6..0abc121 --- a/README.md +++ b/README.md @@ -1 +1,179 @@ -# DG-Net-PP +[![License CC BY-NC-SA 4.0](https://img.shields.io/badge/license-CC4.0-blue.svg)](https://raw.githubusercontent.com/nvlabs/SPADE/master/LICENSE.md) +![Python 3.6](https://img.shields.io/badge/python-3.6-green.svg) +[![Language grade: Python](https://img.shields.io/lgtm/grade/python/g/NVlabs/DG-Net.svg?logo=lgtm&logoWidth=18)](https://lgtm.com/projects/g/NVlabs/DG-Net/context:python) + +## Joint Discriminative and Generative Learning for Person Re-identification +![](NxN.jpg) + +[[Project]](http://zdzheng.xyz/DG-Net/) [[Paper]](https://arxiv.org/abs/1904.07223) [[YouTube]](https://www.youtube.com/watch?v=ubCrEAIpQs4) [[Bilibili]](https://www.bilibili.com/video/av51439240) [[Poster]](http://zdzheng.xyz/images/DGNet_poster.pdf) +[[Supp]](http://jankautz.com/publications/JointReID_CVPR19_supp.pdf) + +Joint Discriminative and Generative Learning for Person Re-identification, CVPR 2019 (Oral)
+[Zhedong Zheng](http://zdzheng.xyz/), [Xiaodong Yang](https://xiaodongyang.org/), [Zhiding Yu](https://chrisding.github.io/), [Liang Zheng](http://liangzheng.com.cn/), [Yi Yang](https://www.uts.edu.au/staff/yi.yang), [Jan Kautz](http://jankautz.com/)
+
+## Table of contents
+* [News](#news)
+* [Features](#features)
+* [Prerequisites](#prerequisites)
+* [Getting Started](#getting-started)
+    * [Installation](#installation)
+    * [Dataset Preparation](#dataset-preparation)
+    * [Testing](#testing)
+    * [Training](#training)
+* [DG-Market](#dg-market)
+* [Tips](#tips)
+* [Citation](#citation)
+* [Related Work](#related-work)
+* [License](#license)
+
+## News
+- 08/24/2019: We added the direct transfer learning results of DG-Net [here](https://github.com/NVlabs/DG-Net#person-re-id-evaluation).
+- 08/01/2019: We added support for multi-GPU training: `python train.py --config configs/latest.yaml --gpu_ids 0,1`.
+
+## Features
+We support:
+- Multi-GPU training (fp32)
+- [APEX](https://github.com/NVIDIA/apex) to save GPU memory (fp16/fp32)
+- Multi-query evaluation
+- Random erasing
+- Visualize training curves
+- Generate all figures in the paper
+
+## Prerequisites
+
+- Python 3.6
+- GPU memory >= 15G (fp32)
+- GPU memory >= 10G (fp16/fp32)
+- NumPy
+- PyTorch 1.0+
+- [Optional] APEX (fp16/fp32)
+
+## Getting Started
+### Installation
+- Install [PyTorch](http://pytorch.org/)
+- Install torchvision from the source:
+```
+git clone https://github.com/pytorch/vision
+cd vision
+python setup.py install
+```
+- [Optional] Install APEX from the source (you may skip this step):
+```
+git clone https://github.com/NVIDIA/apex.git
+cd apex
+python setup.py install --cuda_ext --cpp_ext
+```
+- Clone this repo:
+```bash
+git clone https://github.com/NVlabs/DG-Net.git
+cd DG-Net/
+```
+
+Our code is tested on PyTorch 1.0.0+ and torchvision 0.2.1+.
+
+### Dataset Preparation
+Download the dataset [Market-1501](http://www.liangzheng.com.cn/Project/project_reid.html) [[Google Drive]](https://drive.google.com/file/d/0B8-rUzbwVRk0c054eEozWG9COHM/view) [[Baidu Disk]](https://pan.baidu.com/s/1ntIi2Op)
+
+Preparation: put the images with the same ID in one folder. You may use
+```bash
+python prepare-market.py          # for Market-1501
+```
+Remember to change the dataset path to your own path.
+
+### Testing
+
+#### Download the trained model
+We provide our trained model. You may download it from [Google Drive](https://drive.google.com/open?id=1lL18FZX1uZMWKzaZOuPe3IuAdfUYyJKH) (or [Baidu Disk](https://pan.baidu.com/s/1503831XfW0y4g3PHir91yw), password: rqvf) and move it to the `outputs` folder.
+```
+├── outputs/
+│   ├── E0.5new_reid0.5_w30000
+├── models
+│   ├── best/
+```
+#### Person re-id evaluation
+- Supervised learning
+
+|   | Market-1501 | DukeMTMC-reID | MSMT17 | CUHK03-NP |
+|---|-------------|---------------|--------|-----------|
+| Rank@1 | 94.8% | 86.6% | 77.2% | 65.6% |
+| mAP    | 86.0% | 74.8% | 52.3% | 61.1% |
+
+
+- Direct transfer learning
+To verify the generalizability of DG-Net, we train the model on dataset A and directly test it on dataset B (with no adaptation).
+We denote the direct transfer learning protocol as `A→B`.
+
+|   |Market→Duke|Duke→Market|Market→MSMT|MSMT→Market|Duke→MSMT|MSMT→Duke|
+|---|-----------|-----------|-----------|-----------|---------|---------|
+| Rank@1  | 42.62% | 56.12% | 17.11% | 61.76% | 20.59% | 61.89% |
+| Rank@5  | 58.57% | 72.18% | 26.66% | 77.67% | 31.67% | 75.81% |
+| Rank@10 | 64.63% | 78.12% | 31.62% | 83.25% | 37.04% | 80.34% |
+| mAP     | 24.25% | 26.83% | 5.41%  | 33.62% | 6.35%  | 40.69% |
+
+
+#### Image generation evaluation
+
+Please check the `README.md` in the `./visual_tools` folder.
+
+You may use `./visual_tools/test_folder.py` to generate a large number of images and then run the evaluation. The only thing you need to modify is the data path in [SSIM](https://github.com/layumi/PerceptualSimilarity) and [FID](https://github.com/layumi/TTUR).
+
+### Training
+
+#### Train a teacher model
+You may directly download our trained teacher model from [Google Drive](https://drive.google.com/open?id=1lL18FZX1uZMWKzaZOuPe3IuAdfUYyJKH) (or [Baidu Disk](https://pan.baidu.com/s/1503831XfW0y4g3PHir91yw), password: rqvf).
+If you want to train it yourself, please check the [person re-id baseline](https://github.com/layumi/Person_reID_baseline_pytorch) repository to train a teacher model, then copy it into `./models`.
+```
+├── models/
+│   ├── best/                   /* teacher model for Market-1501
+│   │   ├── net_last.pth        /* model file
+│   │   ├── ...
+```
+
+#### Train DG-Net
+1. Set up the yaml file. Check out `configs/latest.yaml`. Change the data_root field to the path of your prepared folder-based dataset, e.g. `../Market-1501/pytorch`.
+
+2. Start training
+```
+python train.py --config configs/latest.yaml
+```
+Or train with low precision (fp16)
+```
+python train.py --config configs/latest-fp16.yaml
+```
+Intermediate image outputs and model binary files are saved in `outputs/latest`.
+
+3. Check the loss log
+```
+ tensorboard --logdir logs/latest
+```
+
+## DG-Market
+![](https://github.com/layumi/DG-Net/blob/gh-pages/index_files/DGMarket-logo.png)
+
+We provide our generated images as a large-scale synthetic dataset called DG-Market. This dataset is generated by our DG-Net and consists of 128,307 images (613MB), about 10 times larger than the training set of the original Market-1501 (and even more images can be generated with DG-Net). It can be used as a source of unlabeled training data for semi-supervised learning. You may download the dataset from [Google Drive](https://drive.google.com/file/d/126Gn90Tzpk3zWp2c7OBYPKc-ZjhptKDo/view?usp=sharing) (or [Baidu Disk](https://pan.baidu.com/s/1n4M6s-qvE08J8SOOWtWfgw), password: qxyh).
+
+|   | DG-Market | Market-1501 (training) |
+|---|-----------|------------------------|
+| #identity | -       | 751    |
+| #images   | 128,307 | 12,936 |
+
+## Tips
+Note the format of the camera id and the number of cameras. For some datasets (e.g., MSMT17), there are more than 10 cameras, so you need to modify the preparation and evaluation code to read the double-digit camera id. For some vehicle re-id datasets (e.g., VeRi) with different naming rules, you also need to modify the preparation and evaluation code. A minimal filename-parsing sketch is given after the Related Work section below.
+
+## Citation
+Please cite this paper if it helps your research:
+```bibtex
+@inproceedings{zheng2019joint,
+  title={Joint discriminative and generative learning for person re-identification},
+  author={Zheng, Zhedong and Yang, Xiaodong and Yu, Zhiding and Zheng, Liang and Yang, Yi and Kautz, Jan},
+  booktitle={IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
+  year={2019}
+}
+```
+
+## Related Work
+Other GAN-based methods compared in the paper include [LSGAN](https://github.com/layumi/DCGAN-pytorch), [FDGAN](https://github.com/layumi/FD-GAN) and [PG2GAN](https://github.com/charliememory/Pose-Guided-Person-Image-Generation). We forked the code and made some changes for evaluation; we thank the authors for their great work. We would also like to thank the great projects [person re-id baseline](https://github.com/layumi/Person_reID_baseline_pytorch), [MUNIT](https://github.com/NVlabs/MUNIT) and [DRIT](https://github.com/HsinYingLee/DRIT).
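+
+As a rough illustration of the camera-id handling mentioned in the Tips section above, the snippet below is a minimal sketch and is not part of the released code: it assumes Market-1501/DukeMTMC-style file names such as `0001_c1s1_001051_00.jpg`, and the exact pattern for other datasets (e.g., MSMT17 or VeRi) should be checked against their own naming rules.
+```python
+import re
+
+def parse_pid_camid(filename):
+    """Parse the person ID and camera ID from a Market/Duke-style file name.
+
+    Matching one or more digits after 'c' keeps double-digit camera IDs intact,
+    whereas a single-digit pattern would silently truncate camera 11 to 1.
+    """
+    match = re.search(r'(\d+)_c(\d+)', filename)
+    if match is None:
+        raise ValueError('Unexpected file name: %s' % filename)
+    return int(match.group(1)), int(match.group(2))
+
+print(parse_pid_camid('0001_c1s1_001051_00.jpg'))   # (1, 1)
+print(parse_pid_camid('0001_c11s1_001051_00.jpg'))  # (1, 11)
+```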
+
+## License
+Copyright (C) 2019 NVIDIA Corporation. All rights reserved. Licensed under the [CC BY-NC-SA 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode) (**Attribution-NonCommercial-ShareAlike 4.0 International**). The code is released for academic research use only. For commercial use, please contact [researchinquiries@nvidia.com](mailto:researchinquiries@nvidia.com).
diff --git a/configs/duke2market.yaml b/configs/duke2market.yaml
new file mode 100644
index 0000000..887b409
--- /dev/null
+++ b/configs/duke2market.yaml
@@ -0,0 +1,119 @@
+# Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
+# Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+
+apex: false # Set True to use float16.
+B_w: 0.2 # The loss weight of the fine-grained loss, which is named `alpha` in the DG-Net paper.
+ID_class_a: 702 # The number of ID classes of the source domain. For example, 751 for Market, 702 for DukeMTMC
+ID_class_b: 0 # The number of ID classes of the target domain. For example, 751 for Market, 702 for DukeMTMC
+ID_stride: 1 # Stride in the appearance encoder
+ID_style: AB # For the time being, we only support AB.
+batch_size: 8 # Batch size
+xx_port: 0.9 # The portion of the single-domain batch (src-src/target-target, or AA/BB)
+ab_port: 0.1 # The portion of the cross-domain batch (src-target, or AB); xx_port + ab_port == 1
+aa: true # Whether to use the AA-type batch
+ab: true # Whether to use the AB-type batch
+bb: true # Whether to use the BB-type batch
+aa_drop: true # Whether to drop the aa batch in self-training
+beta1: 0 # Adam hyperparameter
+beta2: 0.999 # Adam hyperparameter
+crop_image_height: 256 # Input height
+crop_image_width: 128 # Input width
+data_root_a: /workspace/home/re-id/datasets/DukeMTMC-reID/pytorch # Source dataset root
+data_root_b: /workspace/home/re-id/datasets/Market/pytorch/ # Target dataset root
+src_model_dir: /workspace/home/re-id/DG-Net++/models/zzd_duke/checkpoints # Source model root
+dis_update_iter: 1 # The frequency to update the discriminator
+dis:
+  # for image discriminator
+  LAMBDA: 0.01 # the hyperparameter for the regularization term
+  activ: lrelu # activation function style [relu/lrelu/prelu/selu/tanh]
+  dim: 32 # number of filters in the bottommost layer
+  gan_type: lsgan # GAN loss [lsgan/nsgan]
+  n_layer: 2 # number of layers in D
+  n_res: 4 # number of residual blocks in D
+  non_local: 0 # number of non_local layers
+  norm: none # normalization layer [none/bn/in/ln]
+  num_scales: 3 # number of scales
+  pad_type: reflect # padding type [zero/reflect]
+  # for domain id discriminator
+  id_ganType: lsgan # the type of the network of the ID discriminator
+  id_activ: lrelu # activation function style [relu/lrelu/prelu/selu/tanh]
+  id_norm: bn # normalization layer [none/bn/in/ln]
+  id_nLayer: 4 # number of layers in the domain ID discriminator
+  id_nFilter: 1024 # number of layer filters in the domain ID discriminator
+  id_ds: 2 # downsampling rate in the domain ID discriminator
+display_size: 16 # How many images to display
+erasing_p: 0.5 # Random erasing probability [0-1]
+gamma: 0.1 # Learning Rate Decay (except appearance encoder)
+gamma2: 0.1 # Learning Rate Decay (for appearance encoder)
+gan_w: 1 # The weight of gan loss
+gen:
+  activ: lrelu # activation function style [relu/lrelu/prelu/selu/tanh]
+  dec: basic # [basic/parallel/series]
+  dim: 16 # number of filters in the bottommost layer
+  dropout: 0 # use dropout in the generator
+  id_dim: 2048 # length of appearance code
+  mlp_dim: 512 # number of filters in MLP
+  mlp_norm: none # norm in mlp [none/bn/in/ln]
+  n_downsample: 2 # number of downsampling layers in content encoder
+  n_res: 4 # number of residual blocks in content encoder/decoder
+  non_local: 0 # number of non_local layer
+  pad_type: reflect # padding type [zero/reflect]
+  tanh: false # use tanh or not at the last layer
+  init: kaiming # initialization [gaussian/kaiming/xavier/orthogonal]
+id_adv_w: 1.000 # The initial weight of domain ID adversarial loss
+id_adv_w_max: 1.000 # The maximum weight of domain ID adversarial loss
+adv_warm_max_round: 1 # The maximum round for domain ID adversarial training
+adv_warm_scale: 0.0000 # How fast to warm up the domain ID adversarial training
+id_w: 1.0 # The weight of ID loss
+test_batchsize: 80 # The batch size at test time
+input_dim_a: 1 # We use the gray-scale input, so the input dim is 1
+input_dim_b: 1 # We use the gray-scale input, so the input dim is 1
+log_iter: 1 # How often to log the training stats
+lr_id_d: 0.00001 # Initial domain ID discriminator learning rate
+lr2: 0.00001 # Initial appearance encoder learning rate
+lr2_ramp_factor: 60 # The factor to multiply the lr2 after switching to self-training
+lr_d: 0.000001 # Initial discriminator learning rate
+lr_g: 0.000001 # Initial generator (except appearance encoder) learning rate
+lr_policy: multistep # Learning rate scheduler [multistep|constant|step]
+max_cyc_w: 2 # The maximum weight for cycle loss
+max_teacher_w: 2 # The maximum weight for prime loss (teacher KL loss)
+max_w: 1 # The maximum weight for feature reconstruction losses
+new_size: 128 # The resized image size
+norm_id: false # Whether to normalize the appearance code
+num_workers: 4 # Number of workers to load the data
+id_tgt: false # Whether to use identification loss on the target domain
+tgt_pos: 0.0 # Whether to use the identification loss of the positive samples (the samples with the same pseudo-ID as a given sample) in the target domain
+pid_w: 1.0 # Positive ID loss weight
+pool: max # Pooling layer for the appearance encoder
+recon_s_w: 1 # The initial weight for structure code reconstruction
+recon_f_w: 1 # The initial weight for appearance code reconstruction
+recon_id_w: 0.5 # The initial weight for ID reconstruction
+recon_x_cyc_w: 2 # The initial weight for cycle reconstruction
+recon_x_w: 5 # The initial weight for self-reconstruction
+recon_xp_w: 5 # The initial weight for self-identity reconstruction
+recon_xp_tgt_w: 0 # The weight for self-positive-identity reconstruction in the target domain
+single: gray # Convert the input to gray-scale
+snapshot_save_iter: 10000 # How often to save the checkpoint
+sqrt: false # Whether to use square loss.
+step_size: 120000 # When to decay the learning rate
+teacher: "best-duke" # Teacher model name. For Market, you can set "best"; for DukeMTMC, you can set `best-duke`; "" means no teacher.
+teacher_w: 2.0 # The initial weight for prime loss (teacher KL loss)
+teacher_style: 0 # Teacher style: 0-Our smooth dynamic label; 1-One-hot dynamic pseudo label; 2-Conditional hard static label; 3-LSRO, static smooth label; 4-Dynamic Soft Two-label
+teacher_tgt: false # Whether to use the teacher loss on target samples
+train_bn: true # Whether we train the bn for the generated image.
+use_decoder_again: true # Whether we train the decoder on the generated image.
+use_encoder_again: 0.5 # The probability we train the structure encoder on the generated image.
+vgg_w: 0 # We do not use vgg as one kind of inception loss.
+warm_iter: 0 # When to start warming up the losses (fine-grained/feature reconstruction losses).
+warm_scale: 0.0000 # How fast to warm up
+warm_teacher_iter: 0 # When to start warming up the prime loss
+weight_decay: 0.0005 # Weight decay
+max_iter: 300000 # When to end the training
+max_round: 40 # Maximum self-training rounds
+epoch_round: 2 # Epochs per round in self-training
+epoch_round_adv: 22 # Maximum training epochs of domain ID adversarial training
+randseed: 3 # Random seed
+time_constraint: false # Whether to use the sequential time constraint in pseudo-label generation
+clustering:
+  eps: 0.45
+  min_samples: 7
diff --git a/configs/market2duke.yaml b/configs/market2duke.yaml
new file mode 100644
index 0000000..efe251d
--- /dev/null
+++ b/configs/market2duke.yaml
@@ -0,0 +1,119 @@
+# Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
+# Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+
+apex: false # Set True to use float16.
+B_w: 0.2 # The loss weight of the fine-grained loss, which is named `alpha` in the DG-Net paper.
+ID_class_a: 751 # The number of ID classes of the source domain. For example, 751 for Market, 702 for DukeMTMC
+ID_class_b: 0 # The number of ID classes of the target domain. For example, 751 for Market, 702 for DukeMTMC
+ID_stride: 1 # Stride in the appearance encoder
+ID_style: AB # For the time being, we only support AB.
+batch_size: 8 # Batch size
+xx_port: 0.9 # The portion of the single-domain batch (src-src/target-target, or AA/BB)
+ab_port: 0.1 # The portion of the cross-domain batch (src-target, or AB); xx_port + ab_port == 1
+aa: true # Whether to use the AA-type batch
+ab: true # Whether to use the AB-type batch
+bb: true # Whether to use the BB-type batch
+aa_drop: true # Whether to drop the aa batch in self-training
+beta1: 0 # Adam hyperparameter
+beta2: 0.999 # Adam hyperparameter
+crop_image_height: 256 # Input height
+crop_image_width: 128 # Input width
+data_root_a: /workspace/home/re-id/datasets/Market/pytorch # Source dataset root
+data_root_b: /workspace/home/re-id/datasets/DukeMTMC-reID/pytorch/ # Target dataset root
+src_model_dir: /workspace/home/re-id/DG-Net++/models/zzd_market/checkpoints # Source model root
+dis_update_iter: 1 # The frequency to update the discriminator
+dis:
+  # for image discriminator
+  LAMBDA: 0.01 # the hyperparameter for the regularization term
+  activ: lrelu # activation function style [relu/lrelu/prelu/selu/tanh]
+  dim: 32 # number of filters in the bottommost layer
+  gan_type: lsgan # GAN loss [lsgan/nsgan]
+  n_layer: 2 # number of layers in D
+  n_res: 4 # number of residual blocks in D
+  non_local: 0 # number of non_local layers
+  norm: none # normalization layer [none/bn/in/ln]
+  num_scales: 3 # number of scales
+  pad_type: reflect # padding type [zero/reflect]
+  # for domain id discriminator
+  id_ganType: lsgan # the type of the network of the ID discriminator
+  id_activ: lrelu # activation function style [relu/lrelu/prelu/selu/tanh]
+  id_norm: bn # normalization layer [none/bn/in/ln]
+  id_nLayer: 4 # number of layers in the domain ID discriminator
+  id_nFilter: 1024 # number of layer filters in the domain ID discriminator
+  id_ds: 2 # downsampling rate in the domain ID discriminator
+display_size: 16 # How many images to display
+erasing_p: 0.5 # Random erasing probability [0-1]
+gamma: 0.1 # Learning Rate Decay (except appearance encoder)
+gamma2: 0.1 # Learning Rate Decay (for appearance encoder)
+gan_w: 1 # The weight of gan loss
+gen:
+  activ: lrelu # activation function style [relu/lrelu/prelu/selu/tanh]
+  dec: basic # [basic/parallel/series]
+  dim: 16 # number of filters in the bottommost layer
+  dropout: 0 # use dropout in the generator
+  id_dim: 2048 # length of appearance code
+  mlp_dim: 512 # number of filters in MLP
+  mlp_norm: none # norm in mlp [none/bn/in/ln]
+  n_downsample: 2 # number of downsampling layers in content encoder
+  n_res: 4 # number of residual blocks in content encoder/decoder
+  non_local: 0 # number of non_local layer
+  pad_type: reflect # padding type [zero/reflect]
+  tanh: false # use tanh or not at the last layer
+  init: kaiming # initialization [gaussian/kaiming/xavier/orthogonal]
+id_adv_w: 1.000 # The initial weight of domain ID adversarial loss
+id_adv_w_max: 1.000 # The maximum weight of domain ID adversarial loss
+adv_warm_max_round: 1 # The maximum round for domain ID adversarial training
+adv_warm_scale: 0.0000 # How fast to warm up the domain ID adversarial training
+id_w: 1.0 # The weight of ID loss
+test_batchsize: 80 # The batch size at test time
+input_dim_a: 1 # We use the gray-scale input, so the input dim is 1
+input_dim_b: 1 # We use the gray-scale input, so the input dim is 1
+log_iter: 1 # How often to log the training stats
+lr_id_d: 0.00001 # Initial domain ID discriminator learning rate
+lr2: 0.00001 # Initial appearance encoder learning rate
+lr2_ramp_factor: 60 # The factor to multiply the lr2 after switching to self-training
+lr_d: 0.000001 # Initial discriminator learning rate
+lr_g: 0.000001 # Initial generator (except appearance encoder) learning rate
+lr_policy: multistep # Learning rate scheduler [multistep|constant|step]
+max_cyc_w: 2 # The maximum weight for cycle loss
+max_teacher_w: 2 # The maximum weight for prime loss (teacher KL loss)
+max_w: 1 # The maximum weight for feature reconstruction losses
+new_size: 128 # The resized image size
+norm_id: false # Whether to normalize the appearance code
+num_workers: 4 # Number of workers to load the data
+id_tgt: false # Whether to use identification loss on the target domain
+tgt_pos: 0.0 # Whether to use the identification loss of the positive samples (the samples with the same pseudo-ID as a given sample) in the target domain
+pid_w: 1.0 # Positive ID loss weight
+pool: max # Pooling layer for the appearance encoder
+recon_s_w: 1 # The initial weight for structure code reconstruction
+recon_f_w: 1 # The initial weight for appearance code reconstruction
+recon_id_w: 0.5 # The initial weight for ID reconstruction
+recon_x_cyc_w: 2 # The initial weight for cycle reconstruction
+recon_x_w: 5 # The initial weight for self-reconstruction
+recon_xp_w: 5 # The initial weight for self-identity reconstruction
+recon_xp_tgt_w: 0 # The weight for self-positive-identity reconstruction in the target domain
+single: gray # Convert the input to gray-scale
+snapshot_save_iter: 10000 # How often to save the checkpoint
+sqrt: false # Whether to use square loss.
+step_size: 120000 # When to decay the learning rate
+teacher: "best" # Teacher model name. For Market, you can set "best"; for DukeMTMC, you can set `best-duke`; "" means no teacher.
+teacher_w: 2.0 # The initial weight for prime loss (teacher KL loss)
+teacher_style: 0 # Teacher style: 0-Our smooth dynamic label; 1-One-hot dynamic pseudo label; 2-Conditional hard static label; 3-LSRO, static smooth label; 4-Dynamic Soft Two-label
+teacher_tgt: false # Whether to use the teacher loss on target samples
+train_bn: true # Whether we train the bn for the generated image.
+use_decoder_again: true # Whether we train the decoder on the generated image.
+use_encoder_again: 0.5 # The probability we train the structure encoder on the generated image.
+vgg_w: 0 # We do not use vgg as one kind of inception loss.
+warm_iter: 0 # When to start warming up the losses (fine-grained/feature reconstruction losses).
+warm_scale: 0.0000 # How fast to warm up
+warm_teacher_iter: 0 # When to start warming up the prime loss
+weight_decay: 0.0005 # Weight decay
+max_iter: 300000 # When to end the training
+max_round: 40 # Maximum self-training rounds
+epoch_round: 2 # Epochs per round in self-training
+epoch_round_adv: 22 # Maximum training epochs of domain ID adversarial training
+randseed: 3 # Random seed
+time_constraint: false # Whether to use the sequential time constraint in pseudo-label generation
+clustering:
+  eps: 0.45
+  min_samples: 7
diff --git a/data.py b/data.py
new file mode 100755
index 0000000..3e85795
--- /dev/null
+++ b/data.py
@@ -0,0 +1,129 @@
+"""
+Copyright (C) 2018 NVIDIA Corporation. All rights reserved.
+Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+"""
+import torch.utils.data as data
+import os.path
+
+def default_loader(path):
+    return Image.open(path).convert('RGB')
+
+
+def default_flist_reader(flist):
+    """
+    flist format: impath label\nimpath label\n ...(same as Caffe's filelist)
+    """
+    imlist = []
+    with open(flist, 'r') as rf:
+        for line in rf.readlines():
+            impath = line.strip()
+            imlist.append(impath)
+
+    return imlist
+
+
+class ImageFilelist(data.Dataset):
+    def __init__(self, root, flist, transform=None,
+                 flist_reader=default_flist_reader, loader=default_loader):
+        self.root = root
+        self.imlist = flist_reader(flist)
+        self.transform = transform
+        self.loader = loader
+
+    def __getitem__(self, index):
+        impath = self.imlist[index]
+        img = self.loader(os.path.join(self.root, impath))
+        if self.transform is not None:
+            img = self.transform(img)
+
+        return img
+
+    def __len__(self):
+        return len(self.imlist)
+
+
+class ImageLabelFilelist(data.Dataset):
+    def __init__(self, root, flist, transform=None,
+                 flist_reader=default_flist_reader, loader=default_loader):
+        self.root = root
+        self.imlist = flist_reader(os.path.join(self.root, flist))
+        self.transform = transform
+        self.loader = loader
+        self.classes = sorted(list(set([path.split('/')[0] for path in self.imlist])))
+        self.class_to_idx = {self.classes[i]: i for i in range(len(self.classes))}
+        self.imgs = [(impath, self.class_to_idx[impath.split('/')[0]]) for impath in self.imlist]
+
+    def __getitem__(self, index):
+        impath, label = self.imgs[index]
+        img = self.loader(os.path.join(self.root, impath))
+        if self.transform is not None:
+            img = self.transform(img)
+        return img, label
+
+    def __len__(self):
+        return len(self.imgs)
+
+###############################################################################
+# Code from
+# https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
+# Modified the original code so that it also loads images from the current
+# directory as well as the subdirectories
+###############################################################################
+
+import torch.utils.data as data
+
+from PIL import Image
+import os
+import os.path
+
+IMG_EXTENSIONS = [
+    '.jpg', '.JPG', '.jpeg', '.JPEG',
+    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
+]
+
+
+def is_image_file(filename):
+    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+
+def make_dataset(dir):
+    images = 
[] + assert os.path.isdir(dir), '%s is not a valid directory' % dir + + for root, _, fnames in sorted(os.walk(dir)): + for fname in fnames: + if is_image_file(fname): + path = os.path.join(root, fname) + images.append(path) + + return images + + +class ImageFolder(data.Dataset): + + def __init__(self, root, transform=None, return_paths=False, + loader=default_loader): + imgs = sorted(make_dataset(root)) + if len(imgs) == 0: + raise(RuntimeError("Found 0 images in: " + root + "\n" + "Supported image extensions are: " + + ",".join(IMG_EXTENSIONS))) + + self.root = root + self.imgs = imgs + self.transform = transform + self.return_paths = return_paths + self.loader = loader + + def __getitem__(self, index): + path = self.imgs[index] + img = self.loader(path) + if self.transform is not None: + img = self.transform(img) + if self.return_paths: + return img, path + else: + return img + + def __len__(self): + return len(self.imgs) diff --git a/networks.py b/networks.py new file mode 100644 index 0000000..d2274bc --- /dev/null +++ b/networks.py @@ -0,0 +1,1017 @@ +""" +Copyright (C) 2018 NVIDIA Corporation. All rights reserved. +Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). +""" +from torch import nn +from torch.autograd import Variable +import torch +import torch.nn.functional as F +from torchvision import models +from utils import weights_init + +try: + from itertools import izip as zip +except ImportError: # will be 3.x series + pass + + +################################################################################## +# Discriminator +################################################################################## + +class IdDis(nn.Module): + # Domain ID Discriminator architecture + def __init__(self, input_dim, params, fp16): + super(IdDis, self).__init__() + self.n_layer = params['id_nLayer'] + self.dim = params['id_nFilter'] + self.gan_type = params['id_ganType'] + self.activ = params['id_activ'] + self.norm = params['id_norm'] + self.ds = params['id_ds'] + self.input_dim = input_dim + self.fp16 = fp16 + + self.fcnet = self.one_fcnet() + self.fcnet.apply(weights_init('gaussian')) + + def one_fcnet(self): + dim = self.dim + fcnet_x = [] + fcnet_x += [FcBlock(self.input_dim, dim, norm=self.norm, activation=self.activ)] + for i in range(self.n_layer - 2): + dim2 = max(dim // self.ds, 32) + fcnet_x += [FcBlock(dim, dim2, norm=self.norm, activation=self.activ)] + dim = dim2 + fcnet_x += [nn.Linear(dim, 1)] + fcnet_x = nn.Sequential(*fcnet_x) + return fcnet_x + + def forward(self, x): + outputs = self.fcnet(x) + outputs = torch.squeeze(outputs) + return outputs + + def calc_dis_loss_ab(self, input_s, input_t): + outs0 = self.forward(input_s) + outs1 = self.forward(input_t) + loss = 0 + + reg = 0.0 + for it, (out0, out1) in enumerate(zip(outs0, outs1)): + if self.gan_type == 'lsgan': + loss += torch.mean((out0 - 0)**2) + torch.mean((out1 - 1)**2) # 0 indicates source and 1 indicates target + elif self.gan_type == 'nsgan': + all0 = Variable(torch.zeros_like(out0.data).cuda(), requires_grad=False) + all1 = Variable(torch.ones_like(out1.data).cuda(), requires_grad=False) + loss += torch.mean(F.binary_cross_entropy(F.sigmoid(out0), all0) + + F.binary_cross_entropy(F.sigmoid(out1), all1)) + else: + assert 0, "Unsupported GAN type: {}".format(self.gan_type) + + loss = loss+reg + return loss, reg, 0.0 + + def calc_dis_loss_aa(self, input_s0, input_s1): + outs0 = self.forward(input_s0) + outs1 = self.forward(input_s1) + loss = 0 + + reg = 0.0 
+ for it, (out0, out1) in enumerate(zip(outs0, outs1)): + if self.gan_type == 'lsgan': + loss += torch.mean((out0 - 0)**2) + torch.mean((out1 - 0)**2) + elif self.gan_type == 'nsgan': + all0 = Variable(torch.zeros_like(out0.data).cuda(), requires_grad=False) + all1 = Variable(torch.zeros_like(out1.data).cuda(), requires_grad=False) + loss += torch.mean(F.binary_cross_entropy(F.sigmoid(out0), all0) + + F.binary_cross_entropy(F.sigmoid(out1), all1)) + else: + assert 0, "Unsupported GAN type: {}".format(self.gan_type) + + loss = loss+reg + return loss, reg, 0.0 + + def calc_dis_loss_bb(self, input_t0, input_t1): + outs0 = self.forward(input_t0) + outs1 = self.forward(input_t1) + loss = 0 + + reg = 0.0 + for it, (out0, out1) in enumerate(zip(outs0, outs1)): + if self.gan_type == 'lsgan': + loss += torch.mean((out0 - 1)**2) + torch.mean((out1 - 1)**2) + elif self.gan_type == 'nsgan': + all0 = Variable(torch.ones_like(out0.data).cuda(), requires_grad=False) + all1 = Variable(torch.ones_like(out1.data).cuda(), requires_grad=False) + loss += torch.mean(F.binary_cross_entropy(F.sigmoid(out0), all0) + + F.binary_cross_entropy(F.sigmoid(out1), all1)) + else: + assert 0, "Unsupported GAN type: {}".format(self.gan_type) + + loss = loss+reg + return loss, reg, 0.0 + + def calc_gen_loss(self, input_t): + outs0 = self.forward(input_t) + loss = 0 + Drift = 0.001 + + for it, (out0) in enumerate(outs0): + if self.gan_type == 'lsgan': + loss += torch.mean((out0)**2) * 2 # LSGAN + elif self.gan_type == 'nsgan': + all0 = Variable(torch.zeros_like(out0.data).cuda(), requires_grad=False) + loss += torch.mean(F.binary_cross_entropy(F.sigmoid(out0), all0)) + else: + assert 0, "Unsupported GAN type: {}".format(self.gan_type) + return loss + +class MsImageDis(nn.Module): + # Multi-scale discriminator architecture + def __init__(self, input_dim, params, fp16): + super(MsImageDis, self).__init__() + self.n_layer = params['n_layer'] + self.gan_type = params['gan_type'] + self.dim = params['dim'] + self.norm = params['norm'] + self.activ = params['activ'] + self.num_scales = params['num_scales'] + self.pad_type = params['pad_type'] + self.LAMBDA = params['LAMBDA'] + self.non_local = params['non_local'] + self.n_res = params['n_res'] + self.input_dim = input_dim + self.fp16 = fp16 + self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False) + if not self.gan_type == 'wgan': + self.cnns = nn.ModuleList() + for _ in range(self.num_scales): + Dis = self._make_net() + Dis.apply(weights_init('gaussian')) + self.cnns.append(Dis) + else: + self.cnn = self.one_cnn() + + def _make_net(self): + dim = self.dim + cnn_x = [] + cnn_x += [Conv2dBlock(self.input_dim, dim, 1, 1, 0, norm=self.norm, activation=self.activ, pad_type=self.pad_type)] + cnn_x += [Conv2dBlock(dim, dim, 3, 1, 1, norm=self.norm, activation=self.activ, pad_type=self.pad_type)] + cnn_x += [Conv2dBlock(dim, dim, 3, 2, 1, norm=self.norm, activation=self.activ, pad_type=self.pad_type)] + for i in range(self.n_layer - 1): + dim2 = min(dim*2, 512) + cnn_x += [Conv2dBlock(dim, dim, 3, 1, 1, norm=self.norm, activation=self.activ, pad_type=self.pad_type)] + #cnn_x += [ResBlock(dim, norm=self.norm, activation=self.activ, pad_type=self.pad_type, res_type='basic')] + cnn_x += [Conv2dBlock(dim, dim2, 3, 2, 1, norm=self.norm, activation=self.activ, pad_type=self.pad_type)] + dim = dim2 + if self.non_local>1: + cnn_x += [NonlocalBlock(dim)] + for i in range(self.n_res): + cnn_x += [ResBlock(dim, norm=self.norm, activation=self.activ, 
pad_type=self.pad_type, res_type='basic')] + if self.non_local>0: + cnn_x += [NonlocalBlock(dim)] + cnn_x += [nn.Conv2d(dim, 1, 1, 1, 0)] + cnn_x = nn.Sequential(*cnn_x) + return cnn_x + + def one_cnn(self): + dim = self.dim + cnn_x = [] + cnn_x += [Conv2dBlock(self.input_dim, dim, 4, 2, 1, norm='none', activation=self.activ, pad_type=self.pad_type)] + for i in range(5): + dim2 = min(dim*2, 512) + cnn_x += [Conv2dBlock(dim, dim2, 4, 2, 1, norm=self.norm, activation=self.activ, pad_type=self.pad_type)] + dim = dim2 + cnn_x += [nn.Conv2d(dim, 1, (4,2), 1, 0)] + cnn_x = nn.Sequential(*cnn_x) + return cnn_x + + def forward(self, x): + if not self.gan_type == 'wgan': + outputs = [] + for model in self.cnns: + outputs.append(model(x)) + x = self.downsample(x) + else: + outputs = self.cnn(x) + outputs = torch.squeeze(outputs) + return outputs + + def calc_dis_loss(self, input_fake, input_real): + # calculate the loss to train D + input_real.requires_grad_() + outs0 = self.forward(input_fake) + outs1 = self.forward(input_real) + loss = 0 + reg = 0 + Drift = 0.001 + LAMBDA = self.LAMBDA + + if self.gan_type == 'wgan': + loss += torch.mean(outs0) - torch.mean(outs1) + # progressive gan + loss += Drift*( torch.sum(outs0**2) + torch.sum(outs1**2)) + alpha = torch.FloatTensor(input_fake.shape).uniform_(0., 1.) + alpha = alpha.cuda() + differences = input_fake - input_real + interpolates = Variable(input_real + (alpha*differences), requires_grad=True) + dis_interpolates = self.forward(interpolates) + gradient_penalty = self.compute_grad2(dis_interpolates, interpolates).mean() + loss += LAMBDA*gradient_penalty + return loss + + for it, (out0, out1) in enumerate(zip(outs0, outs1)): + if self.gan_type == 'lsgan': + loss += torch.mean((out0 - 0)**2) + torch.mean((out1 - 1)**2) + # regularization + reg += LAMBDA* self.compute_grad2(out1, input_real).mean() + elif self.gan_type == 'nsgan': + all0 = Variable(torch.zeros_like(out0.data).cuda(), requires_grad=False) + all1 = Variable(torch.ones_like(out1.data).cuda(), requires_grad=False) + loss += torch.mean(F.binary_cross_entropy(F.sigmoid(out0), all0) + + F.binary_cross_entropy(F.sigmoid(out1), all1)) + else: + assert 0, "Unsupported GAN type: {}".format(self.gan_type) + + loss = loss+reg + return loss, reg + + def calc_gen_loss(self, input_fake): + # calculate the loss to train G + outs0 = self.forward(input_fake) + loss = 0 + Drift = 0.001 + if self.gan_type == 'wgan': + loss += -torch.mean(outs0) + # progressive gan + loss += Drift*torch.sum(outs0**2) + return loss + + for it, (out0) in enumerate(outs0): + if self.gan_type == 'lsgan': + loss += torch.mean((out0 - 1)**2) * 2 # LSGAN + elif self.gan_type == 'nsgan': + all1 = Variable(torch.ones_like(out0.data).cuda(), requires_grad=False) + loss += torch.mean(F.binary_cross_entropy(F.sigmoid(out0), all1)) + else: + assert 0, "Unsupported GAN type: {}".format(self.gan_type) + return loss + + def compute_grad2(self, d_out, x_in): + batch_size = x_in.size(0) + grad_dout = torch.autograd.grad( + outputs=d_out.sum(), inputs=x_in, + create_graph=True, retain_graph=True, only_inputs=True + )[0] + grad_dout2 = grad_dout.pow(2) + assert(grad_dout2.size() == x_in.size()) + reg = grad_dout2.view(batch_size, -1).sum(1) + return reg + +################################################################################## +# Generator +################################################################################## + +class AdaINGen(nn.Module): + # AdaIN auto-encoder architecture + def __init__(self, input_dim, params, 
fp16): + super(AdaINGen, self).__init__() + dim = params['dim'] + n_downsample = params['n_downsample'] + n_res = params['n_res'] + activ = params['activ'] + pad_type = params['pad_type'] + mlp_dim = params['mlp_dim'] + mlp_norm = params['mlp_norm'] + id_dim = params['id_dim'] + which_dec = params['dec'] + dropout = params['dropout'] + tanh = params['tanh'] + non_local = params['non_local'] + + # content encoder + self.enc_content = ContentEncoder(n_downsample, n_res, input_dim, dim, 'in', activ, pad_type=pad_type, dropout=dropout, tanh=tanh, res_type='basic') + + self.output_dim = self.enc_content.output_dim + if which_dec =='basic': + self.dec = Decoder(n_downsample, n_res, self.output_dim, 3, dropout=dropout, res_norm='adain', activ=activ, pad_type=pad_type, res_type='basic', non_local = non_local, fp16 = fp16) + elif which_dec =='slim': + self.dec = Decoder(n_downsample, n_res, self.output_dim, 3, dropout=dropout, res_norm='adain', activ=activ, pad_type=pad_type, res_type='slim', non_local = non_local, fp16 = fp16) + elif which_dec =='series': + self.dec = Decoder(n_downsample, n_res, self.output_dim, 3, dropout=dropout, res_norm='adain', activ=activ, pad_type=pad_type, res_type='series', non_local = non_local, fp16 = fp16) + elif which_dec =='parallel': + self.dec = Decoder(n_downsample, n_res, self.output_dim, 3, dropout=dropout, res_norm='adain', activ=activ, pad_type=pad_type, res_type='parallel', non_local = non_local, fp16 = fp16) + else: + ('unkonw decoder type') + + # MLP to generate AdaIN parameters + self.mlp_w1 = MLP(id_dim, 2*self.output_dim, mlp_dim, 3, norm=mlp_norm, activ=activ) + self.mlp_w2 = MLP(id_dim, 2*self.output_dim, mlp_dim, 3, norm=mlp_norm, activ=activ) + self.mlp_w3 = MLP(id_dim, 2*self.output_dim, mlp_dim, 3, norm=mlp_norm, activ=activ) + self.mlp_w4 = MLP(id_dim, 2*self.output_dim, mlp_dim, 3, norm=mlp_norm, activ=activ) + + self.mlp_b1 = MLP(id_dim, 2*self.output_dim, mlp_dim, 3, norm=mlp_norm, activ=activ) + self.mlp_b2 = MLP(id_dim, 2*self.output_dim, mlp_dim, 3, norm=mlp_norm, activ=activ) + self.mlp_b3 = MLP(id_dim, 2*self.output_dim, mlp_dim, 3, norm=mlp_norm, activ=activ) + self.mlp_b4 = MLP(id_dim, 2*self.output_dim, mlp_dim, 3, norm=mlp_norm, activ=activ) + + self.apply(weights_init(params['init'])) + + def encode(self, images): + # encode an image to its content and style codes + content = self.enc_content(images) + return content + + def decode(self, content, ID): + # decode style codes to an image + ID1 = ID[:,:2048] + ID2 = ID[:,2048:4096] + ID3 = ID[:,4096:6144] + ID4 = ID[:,6144:] + adain_params_w = torch.cat( (self.mlp_w1(ID1), self.mlp_w2(ID2), self.mlp_w3(ID3), self.mlp_w4(ID4)), 1) + adain_params_b = torch.cat( (self.mlp_b1(ID1), self.mlp_b2(ID2), self.mlp_b3(ID3), self.mlp_b4(ID4)), 1) + self.assign_adain_params(adain_params_w, adain_params_b, self.dec) + images = self.dec(content) + return images + + def assign_adain_params(self, adain_params_w, adain_params_b, model): + # assign the adain_params to the AdaIN layers in model + dim = self.output_dim + for m in model.modules(): + if m.__class__.__name__ == "AdaptiveInstanceNorm2d": + mean = adain_params_b[:,:dim].contiguous() + std = adain_params_w[:,:dim].contiguous() + m.bias = mean.view(-1) + m.weight = std.view(-1) + if adain_params_w.size(1)>dim : #Pop the parameters + adain_params_b = adain_params_b[:,dim:] + adain_params_w = adain_params_w[:,dim:] + + def get_num_adain_params(self, model): + # return the number of AdaIN parameters needed by the model + num_adain_params = 0 + for m 
in model.modules(): + if m.__class__.__name__ == "AdaptiveInstanceNorm2d": + num_adain_params += m.num_features + return num_adain_params + + +class VAEGen(nn.Module): + # VAE architecture + def __init__(self, input_dim, params): + super(VAEGen, self).__init__() + dim = params['dim'] + n_downsample = params['n_downsample'] + n_res = params['n_res'] + activ = params['activ'] + pad_type = params['pad_type'] + + # content encoder + self.enc = ContentEncoder(n_downsample, n_res, input_dim, dim, 'in', activ, pad_type=pad_type) + self.dec = Decoder(n_downsample, n_res, self.enc.output_dim, input_dim, res_norm='in', activ=activ, pad_type=pad_type) + + def forward(self, images): + # This is a reduced VAE implementation where we assume the outputs are multivariate Gaussian distribution with mean = hiddens and std_dev = all ones. + hiddens = self.encode(images) + if self.training == True: + noise = Variable(torch.randn(hiddens.size()).cuda(hiddens.data.get_device())) + images_recon = self.decode(hiddens + noise) + else: + images_recon = self.decode(hiddens) + return images_recon, hiddens + + def encode(self, images): + hiddens = self.enc(images) + noise = Variable(torch.randn(hiddens.size()).cuda(hiddens.data.get_device())) + return hiddens, noise + + def decode(self, hiddens): + images = self.dec(hiddens) + return images + + +################################################################################## +# Encoder and Decoders +################################################################################## + +class StyleEncoder(nn.Module): + def __init__(self, n_downsample, input_dim, dim, style_dim, norm, activ, pad_type): + super(StyleEncoder, self).__init__() + self.model = [] + # Here I change the stride to 2. + self.model += [Conv2dBlock(input_dim, dim, 3, 2, 1, norm=norm, activation=activ, pad_type=pad_type)] + self.model += [Conv2dBlock(dim, dim, 3, 1, 1, norm=norm, activation=activ, pad_type=pad_type)] + for i in range(2): + self.model += [Conv2dBlock(dim, 2 * dim, 3, 2, 1, norm=norm, activation=activ, pad_type=pad_type)] + dim *= 2 + for i in range(n_downsample - 2): + self.model += [Conv2dBlock(dim, dim, 3, 2, 1, norm=norm, activation=activ, pad_type=pad_type)] + self.model += [nn.AdaptiveAvgPool2d(1)] # global average pooling + self.model += [nn.Conv2d(dim, style_dim, 1, 1, 0)] + self.model = nn.Sequential(*self.model) + self.output_dim = dim + + def forward(self, x): + return self.model(x) + +class ContentEncoder(nn.Module): + def __init__(self, n_downsample, n_res, input_dim, dim, norm, activ, pad_type, dropout, tanh, res_type='basic'): + super(ContentEncoder, self).__init__() + self.model = [] + # Here I change the stride to 2. 
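+        # Note: the stride-2 3x3 conv below halves the spatial resolution, and the following stride-1 3x3 conv doubles the channel dimension before the downsampling loop.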
+ self.model += [Conv2dBlock(input_dim, dim, 3, 2, 1, norm=norm, activation=activ, pad_type=pad_type)] + self.model += [Conv2dBlock(dim, 2*dim, 3, 1, 1, norm=norm, activation=activ, pad_type=pad_type)] + dim *=2 # 32dim + # downsampling blocks + for i in range(n_downsample-1): + self.model += [Conv2dBlock(dim, dim, 3, 1, 1, norm=norm, activation=activ, pad_type=pad_type)] + self.model += [Conv2dBlock(dim, 2 * dim, 3, 2, 1, norm=norm, activation=activ, pad_type=pad_type)] + dim *= 2 + # residual blocks + self.model += [ResBlocks(n_res, dim, norm=norm, activation=activ, pad_type=pad_type, res_type=res_type)] + # 64 -> 128 + self.model += [ASPP(dim, norm=norm, activation=activ, pad_type=pad_type)] + dim *= 2 + if tanh: + self.model +=[nn.Tanh()] + self.model = nn.Sequential(*self.model) + self.output_dim = dim + + def forward(self, x): + return self.model(x) + +class ContentEncoder_ImageNet(nn.Module): + def __init__(self, n_downsample, n_res, input_dim, dim, norm, activ, pad_type): + super(ContentEncoder_ImageNet, self).__init__() + self.model = models.resnet50(pretrained=True) + # remove the final downsample + self.model.layer4[0].downsample[0].stride = (1,1) + self.model.layer4[0].conv2.stride = (1,1) + # (256,128) ----> (16,8) + + def forward(self, x): + x = self.model.conv1(x) + x = self.model.bn1(x) + x = self.model.relu(x) + x = self.model.maxpool(x) + x = self.model.layer1(x) + x = self.model.layer2(x) + x = self.model.layer3(x) + x = self.model.layer4(x) + return x + +class Decoder(nn.Module): + def __init__(self, n_upsample, n_res, dim, output_dim, dropout=0, res_norm='adain', activ='relu', pad_type='zero', res_type='basic', non_local=False, fp16 = False): + super(Decoder, self).__init__() + self.input_dim = dim + self.model = [] + self.model += [nn.Dropout(p = dropout)] + self.model += [ResBlocks(n_res, dim, res_norm, activ, pad_type=pad_type, res_type=res_type)] + # non-local + if non_local>0: + self.model += [NonlocalBlock(dim)] + print('use non-local!') + for i in range(n_upsample): + self.model += [nn.Upsample(scale_factor=2), + Conv2dBlock(dim, dim // 2, 5, 1, 2, norm='ln', activation=activ, pad_type=pad_type, fp16 = fp16)] + dim //= 2 + # use reflection padding in the last conv layer + self.model += [Conv2dBlock(dim, dim, 3, 1, 1, norm='none', activation=activ, pad_type=pad_type)] + self.model += [Conv2dBlock(dim, dim, 3, 1, 1, norm='none', activation=activ, pad_type=pad_type)] + self.model += [Conv2dBlock(dim, output_dim, 1, 1, 0, norm='none', activation='none', pad_type=pad_type)] + self.model = nn.Sequential(*self.model) + + def forward(self, x): + output = self.model(x) + return output + +################################################################################## +# Sequential Models +################################################################################## +class ResBlocks(nn.Module): + def __init__(self, num_blocks, dim, norm='in', activation='relu', pad_type='zero', res_type='basic'): + super(ResBlocks, self).__init__() + self.model = [] + self.res_type = res_type + for i in range(num_blocks): + self.model += [ResBlock(dim, norm=norm, activation=activation, pad_type=pad_type, res_type=res_type)] + self.model = nn.Sequential(*self.model) + + def forward(self, x): + return self.model(x) + +class MLP(nn.Module): + def __init__(self, input_dim, output_dim, dim, n_blk, norm='in', activ='relu'): + + super(MLP, self).__init__() + self.model = [] + self.model += [LinearBlock(input_dim, dim, norm=norm, activation=activ)] + for i in range(n_blk - 2): + self.model 
+= [LinearBlock(dim, dim, norm=norm, activation=activ)] + self.model += [LinearBlock(dim, output_dim, norm='none', activation='none')] # no output activations + self.model = nn.Sequential(*self.model) + + def forward(self, x): + return self.model(x.view(x.size(0), -1)) + +# enlarge the ID 2time +class Deconv(nn.Module): + def __init__(self, input_dim, output_dim): + super(Deconv, self).__init__() + model = [] + model += [nn.ConvTranspose2d( input_dim, output_dim, kernel_size=(2,2), stride=2)] + model += [nn.InstanceNorm2d(output_dim)] + model += [nn.ReLU(inplace=True)] + model += [nn.Conv2d( output_dim, output_dim, kernel_size=(1,1), stride=1)] + self.model = nn.Sequential(*model) + def forward(self, x): + return self.model(x) + +################################################################################## +# Basic Blocks +################################################################################## +class ResBlock(nn.Module): + def __init__(self, dim, norm, activation='relu', pad_type='zero', res_type='basic'): + super(ResBlock, self).__init__() + + model = [] + if res_type=='basic' or res_type=='nonlocal': + model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)] + model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)] + elif res_type=='slim': + dim_half = dim//2 + model += [Conv2dBlock(dim ,dim_half, 1, 1, 0, norm='in', activation=activation, pad_type=pad_type)] + model += [Conv2dBlock(dim_half, dim_half, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)] + model += [Conv2dBlock(dim_half, dim_half, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)] + model += [Conv2dBlock(dim_half, dim, 1, 1, 0, norm='in', activation='none', pad_type=pad_type)] + elif res_type=='series': + model += [Series2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)] + model += [Series2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)] + elif res_type=='parallel': + model += [Parallel2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)] + model += [Parallel2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)] + else: + ('unkown block type') + self.res_type = res_type + self.model = nn.Sequential(*model) + if res_type=='nonlocal': + self.nonloc = NonlocalBlock(dim) + + def forward(self, x): + if self.res_type == 'nonlocal': + x = self.nonloc(x) + residual = x + out = self.model(x) + out += residual + return out + +class NonlocalBlock(nn.Module): + def __init__(self, in_dim, norm='in'): + super(NonlocalBlock, self).__init__() + self.chanel_in = in_dim + + self.query_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1) + self.key_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1) + self.value_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim , kernel_size= 1) + self.gamma = nn.Parameter(torch.zeros(1)) + + self.softmax = nn.Softmax(dim=-1) # + def forward(self,x): + """ + inputs : + x : input feature maps( B X C X W X H) + returns : + out : self attention value + input feature + attention: B X N X N (N is Width*Height) + """ + m_batchsize,C,width ,height = x.size() + proj_query = self.query_conv(x).view(m_batchsize,-1,width*height).permute(0,2,1) # B X CX(N) + proj_key = self.key_conv(x).view(m_batchsize,-1,width*height) # B X C x (*W*H) + energy = torch.bmm(proj_query, proj_key) # transpose check + attention = self.softmax(energy) # 
BX (N) X (N) + proj_value = self.value_conv(x).view(m_batchsize,-1,width*height) # B X C X N + + out = torch.bmm(proj_value,attention.permute(0,2,1) ) + out = out.view(m_batchsize,C,width,height) + + out = self.gamma*out + x + return out + +class ASPP(nn.Module): + # ASPP (a) + def __init__(self, dim, norm='in', activation='relu', pad_type='zero'): + super(ASPP, self).__init__() + dim_part = dim//2 + self.conv1 = Conv2dBlock(dim,dim_part, 1, 1, 0, norm=norm, activation='none', pad_type=pad_type) + + self.conv6 = [] + self.conv6 += [Conv2dBlock(dim,dim_part, 1, 1, 0, norm=norm, activation=activation, pad_type=pad_type)] + self.conv6 += [Conv2dBlock(dim_part,dim_part, 3, 1, 3, norm=norm, activation='none', pad_type=pad_type, dilation=3)] + self.conv6 = nn.Sequential(*self.conv6) + + self.conv12 = [] + self.conv12 += [Conv2dBlock(dim,dim_part, 1, 1, 0, norm=norm, activation=activation, pad_type=pad_type)] + self.conv12 += [Conv2dBlock(dim_part,dim_part, 3, 1, 6, norm=norm, activation='none', pad_type=pad_type, dilation=6)] + self.conv12 = nn.Sequential(*self.conv12) + + self.conv18 = [] + self.conv18 += [Conv2dBlock(dim,dim_part, 1, 1, 0, norm=norm, activation=activation, pad_type=pad_type)] + self.conv18 += [Conv2dBlock(dim_part,dim_part, 3, 1, 9, norm=norm, activation='none', pad_type=pad_type, dilation=9)] + self.conv18 = nn.Sequential(*self.conv18) + + self.fuse = Conv2dBlock(4*dim_part,2*dim, 1, 1, 0, norm=norm, activation='none', pad_type=pad_type) + + def forward(self, x): + conv1 = self.conv1(x) + conv6 = self.conv6(x) + conv12 = self.conv12(x) + conv18 = self.conv18(x) + out = torch.cat((conv1,conv6,conv12, conv18), dim=1) + out = self.fuse(out) + return out + + +class Conv2dBlock(nn.Module): + def __init__(self, input_dim ,output_dim, kernel_size, stride, + padding=0, norm='none', activation='relu', pad_type='zero', dilation=1, fp16 = False): + super(Conv2dBlock, self).__init__() + self.use_bias = True + # initialize padding + if pad_type == 'reflect': + self.pad = nn.ReflectionPad2d(padding) + elif pad_type == 'replicate': + self.pad = nn.ReplicationPad2d(padding) + elif pad_type == 'zero': + self.pad = nn.ZeroPad2d(padding) + else: + assert 0, "Unsupported padding type: {}".format(pad_type) + + # initialize normalization + norm_dim = output_dim + if norm == 'bn': + self.norm = nn.BatchNorm2d(norm_dim) + elif norm == 'in': + self.norm = nn.InstanceNorm2d(norm_dim) + elif norm == 'ln': + self.norm = LayerNorm(norm_dim, fp16 = fp16) + elif norm == 'adain': + self.norm = AdaptiveInstanceNorm2d(norm_dim) + elif norm == 'none' or norm == 'sn': + self.norm = None + else: + assert 0, "Unsupported normalization: {}".format(norm) + + # initialize activation + if activation == 'relu': + self.activation = nn.ReLU(inplace=True) + elif activation == 'lrelu': + self.activation = nn.LeakyReLU(0.2, inplace=True) + elif activation == 'prelu': + self.activation = nn.PReLU() + elif activation == 'selu': + self.activation = nn.SELU(inplace=True) + elif activation == 'tanh': + self.activation = nn.Tanh() + elif activation == 'none': + self.activation = None + else: + assert 0, "Unsupported activation: {}".format(activation) + + # initialize convolution + if norm == 'sn': + self.conv = SpectralNorm(nn.Conv2d(input_dim, output_dim, kernel_size, stride, dilation=dilation, bias=self.use_bias)) + else: + self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, dilation=dilation, bias=self.use_bias) + + def forward(self, x): + x = self.conv(self.pad(x)) + if self.norm: + x = self.norm(x) + if 
self.activation: + x = self.activation(x) + return x + + +class FcBlock(nn.Module): + def __init__(self, input_dim ,output_dim, norm='none', activation='relu',fp16 = False): + super(FcBlock, self).__init__() + + # initialize normalization + norm_dim = output_dim + if norm == 'bn': + self.norm = nn.BatchNorm1d(norm_dim) + elif norm == 'in': + #self.norm = nn.InstanceNorm2d(norm_dim, track_running_stats=True) + self.norm = nn.InstanceNorm1d(norm_dim) + elif norm == 'ln': + self.norm = LayerNorm(norm_dim, fp16 = fp16) + elif norm == 'adain': + self.norm = AdaptiveInstanceNorm1d(norm_dim) + elif norm == 'none' or norm == 'sn': + self.norm = None + else: + assert 0, "Unsupported normalization: {}".format(norm) + + # initialize activation + if activation == 'relu': + self.activation = nn.ReLU(inplace=True) + elif activation == 'lrelu': + self.activation = nn.LeakyReLU(0.2, inplace=True) + elif activation == 'prelu': + self.activation = nn.PReLU() + elif activation == 'selu': + self.activation = nn.SELU(inplace=True) + elif activation == 'tanh': + self.activation = nn.Tanh() + elif activation == 'none': + self.activation = None + else: + assert 0, "Unsupported activation: {}".format(activation) + + # initialize fc layer + self.linear = nn.Linear(input_dim, output_dim) + + def forward(self, x): + x = self.linear(x) + if self.norm: + x = self.norm(x) + if self.activation: + x = self.activation(x) + return x + + +class Series2dBlock(nn.Module): + def __init__(self, input_dim ,output_dim, kernel_size, stride, + padding=0, norm='none', activation='relu', pad_type='zero'): + super(Series2dBlock, self).__init__() + self.use_bias = True + # initialize padding + if pad_type == 'reflect': + self.pad = nn.ReflectionPad2d(padding) + elif pad_type == 'replicate': + self.pad = nn.ReplicationPad2d(padding) + elif pad_type == 'zero': + self.pad = nn.ZeroPad2d(padding) + else: + assert 0, "Unsupported padding type: {}".format(pad_type) + + # initialize normalization + norm_dim = output_dim + if norm == 'bn': + self.norm = nn.BatchNorm2d(norm_dim) + elif norm == 'in': + self.norm = nn.InstanceNorm2d(norm_dim) + elif norm == 'ln': + self.norm = LayerNorm(norm_dim) + elif norm == 'adain': + self.norm = AdaptiveInstanceNorm2d(norm_dim) + elif norm == 'none': + self.norm = None + else: + assert 0, "Unsupported normalization: {}".format(norm) + + # initialize activation + if activation == 'relu': + self.activation = nn.ReLU(inplace=True) + elif activation == 'lrelu': + self.activation = nn.LeakyReLU(0.2, inplace=True) + elif activation == 'prelu': + self.activation = nn.PReLU() + elif activation == 'selu': + self.activation = nn.SELU(inplace=True) + elif activation == 'tanh': + self.activation = nn.Tanh() + elif activation == 'none': + self.activation = None + else: + assert 0, "Unsupported activation: {}".format(activation) + + # initialize convolution + self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias) + + self.instance_norm = nn.InstanceNorm2d(norm_dim) + + def forward(self, x): + x = self.conv(self.pad(x)) + x = self.norm(x) + x + x = self.instance_norm(x) + if self.activation: + x = self.activation(x) + return x + +class Parallel2dBlock(nn.Module): + def __init__(self, input_dim ,output_dim, kernel_size, stride, + padding=0, norm='none', activation='relu', pad_type='zero'): + super(Parallel2dBlock, self).__init__() + self.use_bias = True + # initialize padding + if pad_type == 'reflect': + self.pad = nn.ReflectionPad2d(padding) + elif pad_type == 'replicate': + self.pad = 
nn.ReplicationPad2d(padding) + elif pad_type == 'zero': + self.pad = nn.ZeroPad2d(padding) + else: + assert 0, "Unsupported padding type: {}".format(pad_type) + + # initialize normalization + norm_dim = output_dim + if norm == 'bn': + self.norm = nn.BatchNorm2d(norm_dim) + elif norm == 'in': + self.norm = nn.InstanceNorm2d(norm_dim) + elif norm == 'ln': + self.norm = LayerNorm(norm_dim) + elif norm == 'adain': + self.norm = AdaptiveInstanceNorm2d(norm_dim) + elif norm == 'none': + self.norm = None + else: + assert 0, "Unsupported normalization: {}".format(norm) + + # initialize activation + if activation == 'relu': + self.activation = nn.ReLU(inplace=True) + elif activation == 'lrelu': + self.activation = nn.LeakyReLU(0.2, inplace=True) + elif activation == 'prelu': + self.activation = nn.PReLU() + elif activation == 'selu': + self.activation = nn.SELU(inplace=True) + elif activation == 'tanh': + self.activation = nn.Tanh() + elif activation == 'none': + self.activation = None + else: + assert 0, "Unsupported activation: {}".format(activation) + + # initialize convolution + self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias) + + self.instance_norm = nn.InstanceNorm2d(norm_dim) + + def forward(self, x): + x = self.conv(self.pad(x)) + self.norm(x) + x = self.instance_norm(x) + if self.activation: + x = self.activation(x) + return x + +class LinearBlock(nn.Module): + def __init__(self, input_dim, output_dim, norm='none', activation='relu'): + super(LinearBlock, self).__init__() + use_bias = True + # initialize fully connected layer + self.fc = nn.Linear(input_dim, output_dim, bias=use_bias) + + # initialize normalization + norm_dim = output_dim + if norm == 'bn': + self.norm = nn.BatchNorm1d(norm_dim) + elif norm == 'in': + self.norm = nn.InstanceNorm1d(norm_dim) + elif norm == 'ln': + self.norm = LayerNorm(norm_dim) + elif norm == 'none': + self.norm = None + else: + assert 0, "Unsupported normalization: {}".format(norm) + + # initialize activation + if activation == 'relu': + self.activation = nn.ReLU(inplace=True) + elif activation == 'lrelu': + self.activation = nn.LeakyReLU(0.2, inplace=True) + elif activation == 'prelu': + self.activation = nn.PReLU() + elif activation == 'selu': + self.activation = nn.SELU(inplace=True) + elif activation == 'tanh': + self.activation = nn.Tanh() + elif activation == 'none': + self.activation = None + else: + assert 0, "Unsupported activation: {}".format(activation) + + def forward(self, x): + out = self.fc(x) + if self.norm: + #reshape input + out = out.unsqueeze(1) + out = self.norm(out) + out = out.view(out.size(0),out.size(2)) + if self.activation: + out = self.activation(out) + return out + +################################################################################## +# VGG network definition +################################################################################## +class Vgg16(nn.Module): + def __init__(self): + super(Vgg16, self).__init__() + self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1) + self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) + + self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1) + self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1) + + self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1) + self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) + self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) + + self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, 
stride=1, padding=1) + self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + + self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + + def forward(self, X): + h = F.relu(self.conv1_1(X), inplace=True) + h = F.relu(self.conv1_2(h), inplace=True) + h = F.max_pool2d(h, kernel_size=2, stride=2) + + h = F.relu(self.conv2_1(h), inplace=True) + h = F.relu(self.conv2_2(h), inplace=True) + h = F.max_pool2d(h, kernel_size=2, stride=2) + + h = F.relu(self.conv3_1(h), inplace=True) + h = F.relu(self.conv3_2(h), inplace=True) + h = F.relu(self.conv3_3(h), inplace=True) + h = F.max_pool2d(h, kernel_size=2, stride=2) + + h = F.relu(self.conv4_1(h), inplace=True) + h = F.relu(self.conv4_2(h), inplace=True) + h = F.relu(self.conv4_3(h), inplace=True) + + h = F.relu(self.conv5_1(h), inplace=True) + h = F.relu(self.conv5_2(h), inplace=True) + h = F.relu(self.conv5_3(h), inplace=True) + relu5_3 = h + + return relu5_3 + +################################################################################## +# Normalization layers +################################################################################## +class AdaptiveInstanceNorm2d(nn.Module): + def __init__(self, num_features, eps=1e-5, momentum=0.1): + super(AdaptiveInstanceNorm2d, self).__init__() + self.num_features = num_features + self.eps = eps + self.momentum = momentum + # weight and bias are dynamically assigned + self.weight = None + self.bias = None + # just dummy buffers, not used + self.register_buffer('running_mean', torch.zeros(num_features)) + self.register_buffer('running_var', torch.ones(num_features)) + + def forward(self, x): + assert self.weight is not None and self.bias is not None, "Please assign weight and bias before calling AdaIN!" + b, c = x.size(0), x.size(1) + running_mean = self.running_mean.repeat(b).type_as(x) + running_var = self.running_var.repeat(b).type_as(x) + # Apply instance norm + x_reshaped = x.contiguous().view(1, b * c, *x.size()[2:]) + out = F.batch_norm( + x_reshaped, running_mean, running_var, self.weight, self.bias, + True, self.momentum, self.eps) + + return out.view(b, c, *x.size()[2:]) + + def __repr__(self): + return self.__class__.__name__ + '(' + str(self.num_features) + ')' + + +class LayerNorm(nn.Module): + def __init__(self, num_features, eps=1e-5, affine=True, fp16=False): + super(LayerNorm, self).__init__() + self.num_features = num_features + self.affine = affine + self.eps = eps + self.fp16 = fp16 + if self.affine: + self.gamma = nn.Parameter(torch.Tensor(num_features).uniform_()) + self.beta = nn.Parameter(torch.zeros(num_features)) + def forward(self, x): + shape = [-1] + [1] * (x.dim() - 1) + if x.type() == 'torch.cuda.HalfTensor': # For Safety + mean = x.view(-1).float().mean().view(*shape) + std = x.view(-1).float().std().view(*shape) + mean = mean.half() + std = std.half() + else: + mean = x.view(x.size(0), -1).mean(1).view(*shape) + std = x.view(x.size(0), -1).std(1).view(*shape) + + x = (x - mean) / (std + self.eps) + if self.affine: + shape = [1, -1] + [1] * (x.dim() - 2) + x = x * self.gamma.view(*shape) + self.beta.view(*shape) + return x + + diff --git a/prepare-duke.py b/prepare-duke.py new file mode 100755 index 0000000..be46672 --- /dev/null +++ b/prepare-duke.py @@ -0,0 +1,93 @@ +""" +Copyright (C) 2019 NVIDIA Corporation. 
All rights reserved. +Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). +""" + +import os +from shutil import copyfile + +# You only need to change this line to your dataset download path +download_path = 'datasets/DukeMTMC-reID' + +if not os.path.isdir(download_path): + print('please change the download_path') + +save_path = download_path + '/pytorch' +if not os.path.isdir(save_path): + os.mkdir(save_path) +#----------------------------------------- +#query +query_path = download_path + '/query' +query_save_path = download_path + '/pytorch/query' +if not os.path.isdir(query_save_path): + os.mkdir(query_save_path) + +for root, dirs, files in os.walk(query_path, topdown=True): + for name in files: + if not name[-3:]=='jpg': + continue + ID = name.split('_') + src_path = query_path + '/' + name + dst_path = query_save_path + '/' + ID[0] + if not os.path.isdir(dst_path): + os.mkdir(dst_path) + copyfile(src_path, dst_path + '/' + name) + +#----------------------------------------- +#gallery +gallery_path = download_path + '/bounding_box_test' +gallery_save_path = download_path + '/pytorch/gallery' +if not os.path.isdir(gallery_save_path): + os.mkdir(gallery_save_path) + +for root, dirs, files in os.walk(gallery_path, topdown=True): + for name in files: + if not name[-3:]=='jpg': + continue + ID = name.split('_') + src_path = gallery_path + '/' + name + dst_path = gallery_save_path + '/' + ID[0] + if not os.path.isdir(dst_path): + os.mkdir(dst_path) + copyfile(src_path, dst_path + '/' + name) + +#--------------------------------------- +#train_all +train_path = download_path + '/bounding_box_train' +train_save_path = download_path + '/pytorch/train_all' +if not os.path.isdir(train_save_path): + os.mkdir(train_save_path) + +for root, dirs, files in os.walk(train_path, topdown=True): + for name in files: + if not name[-3:]=='jpg': + continue + ID = name.split('_') + src_path = train_path + '/' + name + dst_path = train_save_path + '/' + ID[0] + if not os.path.isdir(dst_path): + os.mkdir(dst_path) + copyfile(src_path, dst_path + '/' + name) + + +#--------------------------------------- +#train_val +train_path = download_path + '/bounding_box_train' +train_save_path = download_path + '/pytorch/train' +val_save_path = download_path + '/pytorch/val' +if not os.path.isdir(train_save_path): + os.mkdir(train_save_path) + os.mkdir(val_save_path) + +for root, dirs, files in os.walk(train_path, topdown=True): + for name in files: + if not name[-3:]=='jpg': + continue + ID = name.split('_') + src_path = train_path + '/' + name + dst_path = train_save_path + '/' + ID[0] + if not os.path.isdir(dst_path): + os.mkdir(dst_path) + dst_path = val_save_path + '/' + ID[0] #first image is used as val image + os.mkdir(dst_path) + copyfile(src_path, dst_path + '/' + name) diff --git a/prepare-market.py b/prepare-market.py new file mode 100755 index 0000000..7a81761 --- /dev/null +++ b/prepare-market.py @@ -0,0 +1,113 @@ +""" +Copyright (C) 2019 NVIDIA Corporation. All rights reserved. +Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
+""" + +import os +from shutil import copyfile + +# You only need to change this line to your dataset download path +download_path = 'datasets/Market' + +if not os.path.isdir(download_path): + print('please change the download_path') + +save_path = download_path + '/pytorch' +if not os.path.isdir(save_path): + os.mkdir(save_path) +#----------------------------------------- +#query +query_path = download_path + '/query' +query_save_path = download_path + '/pytorch/query' +if not os.path.isdir(query_save_path): + os.mkdir(query_save_path) + +for root, dirs, files in os.walk(query_path, topdown=True): + for name in files: + if not name[-3:]=='jpg': + continue + ID = name.split('_') + src_path = query_path + '/' + name + dst_path = query_save_path + '/' + ID[0] + if not os.path.isdir(dst_path): + os.mkdir(dst_path) + copyfile(src_path, dst_path + '/' + name) + +#----------------------------------------- +#multi-query +query_path = download_path + '/gt_bbox' +# for dukemtmc-reid, we do not need multi-query +if os.path.isdir(query_path): + query_save_path = download_path + '/pytorch/multi-query' + if not os.path.isdir(query_save_path): + os.mkdir(query_save_path) + + for root, dirs, files in os.walk(query_path, topdown=True): + for name in files: + if not name[-3:]=='jpg': + continue + ID = name.split('_') + src_path = query_path + '/' + name + dst_path = query_save_path + '/' + ID[0] + if not os.path.isdir(dst_path): + os.mkdir(dst_path) + copyfile(src_path, dst_path + '/' + name) + +#----------------------------------------- +#gallery +gallery_path = download_path + '/bounding_box_test' +gallery_save_path = download_path + '/pytorch/gallery' +if not os.path.isdir(gallery_save_path): + os.mkdir(gallery_save_path) + +for root, dirs, files in os.walk(gallery_path, topdown=True): + for name in files: + if not name[-3:]=='jpg': + continue + ID = name.split('_') + src_path = gallery_path + '/' + name + dst_path = gallery_save_path + '/' + ID[0] + if not os.path.isdir(dst_path): + os.mkdir(dst_path) + copyfile(src_path, dst_path + '/' + name) + +#--------------------------------------- +#train_all +train_path = download_path + '/bounding_box_train' +train_save_path = download_path + '/pytorch/train_all' +if not os.path.isdir(train_save_path): + os.mkdir(train_save_path) + +for root, dirs, files in os.walk(train_path, topdown=True): + for name in files: + if not name[-3:]=='jpg': + continue + ID = name.split('_') + src_path = train_path + '/' + name + dst_path = train_save_path + '/' + ID[0] + if not os.path.isdir(dst_path): + os.mkdir(dst_path) + copyfile(src_path, dst_path + '/' + name) + + +#--------------------------------------- +#train_val +train_path = download_path + '/bounding_box_train' +train_save_path = download_path + '/pytorch/train' +val_save_path = download_path + '/pytorch/val' +if not os.path.isdir(train_save_path): + os.mkdir(train_save_path) + os.mkdir(val_save_path) + +for root, dirs, files in os.walk(train_path, topdown=True): + for name in files: + if not name[-3:]=='jpg': + continue + ID = name.split('_') + src_path = train_path + '/' + name + dst_path = train_save_path + '/' + ID[0] + if not os.path.isdir(dst_path): + os.mkdir(dst_path) + dst_path = val_save_path + '/' + ID[0] #first image is used as val image + os.mkdir(dst_path) + copyfile(src_path, dst_path + '/' + name) diff --git a/random_erasing.py b/random_erasing.py new file mode 100755 index 0000000..0bf8030 --- /dev/null +++ b/random_erasing.py @@ -0,0 +1,61 @@ +""" +Copyright (C) 2019 NVIDIA Corporation. 
All rights reserved. +Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). +""" + +from __future__ import absolute_import + +from torchvision.transforms import * + +from PIL import Image +import random +import math +import numpy as np +import torch + +class RandomErasing(object): + """ Randomly selects a rectangle region in an image and erases its pixels. + 'Random Erasing Data Augmentation' by Zhong et al. + See https://arxiv.org/pdf/1708.04896.pdf + Args: + probability: The probability that the Random Erasing operation will be performed. + sl: Minimum proportion of erased area against input image. + sh: Maximum proportion of erased area against input image. + r1: Minimum aspect ratio of erased area. + mean: Erasing value. + """ + + def __init__(self, probability = 0.5, sl = 0.02, sh = 0.4, r1 = 0.3, mean=[0.4914, 0.4822, 0.4465]): + self.probability = probability + self.mean = mean + self.sl = sl + self.sh = sh + self.r1 = r1 + random.seed(7) + + def __call__(self, img): + + if random.uniform(0, 1) > self.probability: + return img + + for attempt in range(100): + area = img.size()[1] * img.size()[2] + + target_area = random.uniform(self.sl, self.sh) * area + aspect_ratio = random.uniform(self.r1, 1/self.r1) + + h = int(round(math.sqrt(target_area * aspect_ratio))) + w = int(round(math.sqrt(target_area / aspect_ratio))) + + if w < img.size()[2] and h < img.size()[1]: + x1 = random.randint(0, img.size()[1] - h) + y1 = random.randint(0, img.size()[2] - w) + if img.size()[0] == 3: + img[0, x1:x1+h, y1:y1+w] = self.mean[0] + img[1, x1:x1+h, y1:y1+w] = self.mean[1] + img[2, x1:x1+h, y1:y1+w] = self.mean[2] + else: + img[0, x1:x1+h, y1:y1+w] = self.mean[0] + return img.detach() + + return img.detach() diff --git a/reIDfolder.py b/reIDfolder.py new file mode 100644 index 0000000..62eafe7 --- /dev/null +++ b/reIDfolder.py @@ -0,0 +1,108 @@ +""" +Copyright (C) 2018 NVIDIA Corporation. All rights reserved. +Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
+""" + +from torchvision import datasets +import os +import numpy as np +import random + +class ReIDFolder(datasets.ImageFolder): + + def __init__(self, root, transform): + super(ReIDFolder, self).__init__(root, transform) + targets = np.asarray([s[1] for s in self.samples]) + self.targets = targets + self.img_num = len(self.samples) + print(self.img_num) + + def _get_cam_id(self, path): + camera_id = [] + filename = os.path.basename(path) + camera_id = filename.split('c')[1][0] + return int(camera_id)-1 + + def _get_pos_sample(self, target, index, path): + pos_index = np.argwhere(self.targets == target) + pos_index = pos_index.flatten() + pos_index = np.setdiff1d(pos_index, index) + if len(pos_index)==0: # in the query set, only one sample + return path + else: + rand = random.randint(0,len(pos_index)-1) + return self.samples[pos_index[rand]][0] + + def _get_neg_sample(self, target): + neg_index = np.argwhere(self.targets != target) + neg_index = neg_index.flatten() + rand = random.randint(0,len(neg_index)-1) + return self.samples[neg_index[rand]] + + def __getitem__(self, index): + path, target = self.samples[index] + sample = self.loader(path) + + pos_path = self._get_pos_sample(target, index, path) + pos = self.loader(pos_path) + + if self.transform is not None: + sample = self.transform(sample) + pos = self.transform(pos) + + if self.target_transform is not None: + target = self.target_transform(target) + + return sample, target, pos + +class ReIDFolder_mix(datasets.ImageFolder): + + def __init__(self, root, transform, idx_list): + super(ReIDFolder_mix, self).__init__(root, transform) + self.idx_list = idx_list + targets = np.asarray([s[1] for s in self.samples]) + self.targets = targets + self.img_num = len(self.samples) + print(self.img_num) + + def _get_cam_id(self, path): + camera_id = [] + filename = os.path.basename(path) + camera_id = filename.split('c')[1][0] + return int(camera_id)-1 + + def _get_pos_sample(self, target, index, path): + pos_index = np.argwhere(self.targets == target) + pos_index = pos_index.flatten() + pos_index = np.setdiff1d(pos_index, index) + if len(pos_index)==0: # in the query set, only one sample + return path + else: + rand = random.randint(0,len(pos_index)-1) + return self.samples[pos_index[rand]][0] + + def _get_neg_sample(self, target): + neg_index = np.argwhere(self.targets != target) + neg_index = neg_index.flatten() + rand = random.randint(0,len(neg_index)-1) + return self.samples[neg_index[rand]] + + def __getitem__(self, index): + idx = self.idx_list[index] + path, target = self.samples[idx] + sample = self.loader(path) + + pos_path = self._get_pos_sample(target, idx, path) + pos = self.loader(pos_path) + + if self.transform is not None: + sample = self.transform(sample) + pos = self.transform(pos) + + if self.target_transform is not None: + target = self.target_transform(target) + + return sample, target, pos + + def __len__(self): + return len(self.idx_list) \ No newline at end of file diff --git a/reIDmodel.py b/reIDmodel.py new file mode 100644 index 0000000..30d98f2 --- /dev/null +++ b/reIDmodel.py @@ -0,0 +1,332 @@ +""" +Copyright (C) 2018 NVIDIA Corporation. All rights reserved. +Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
+""" + +import torch +import torch.nn as nn +from torch.nn import init +from torchvision import models + +# PretrainedModel = 'pretrainedmodels/resnet50.pth' # (PretrainedModel = Path to the ImageNet pretrained model) +###################################################################### +def weights_init_kaiming(m): + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') + elif classname.find('Linear') != -1: + init.kaiming_normal_(m.weight.data, a=0, mode='fan_out') + init.constant_(m.bias.data, 0.0) + elif classname.find('InstanceNorm1d') != -1: + init.normal_(m.weight.data, 1.0, 0.02) + init.constant_(m.bias.data, 0.0) + +def weights_init_classifier(m): + classname = m.__class__.__name__ + if classname.find('Linear') != -1: + init.normal_(m.weight.data, std=0.001) + init.constant_(m.bias.data, 0.0) + +def fix_bn(m): + classname = m.__class__.__name__ + if classname.find('BatchNorm') != -1: + m.eval() + +# Defines the new fc layer and classification layer +# |--Linear--|--bn--|--relu--|--Linear--| +class ClassBlock(nn.Module): + def __init__(self, input_dim, class_num, droprate=0.5, relu=False, num_bottleneck=512): + super(ClassBlock, self).__init__() + add_block = [] + add_block += [nn.Linear(input_dim, num_bottleneck)] + #num_bottleneck = input_dim # We remove the input_dim + add_block += [nn.BatchNorm1d(num_bottleneck, affine=True)] + if relu: + add_block += [nn.LeakyReLU(0.1)] + if droprate>0: + add_block += [nn.Dropout(p=droprate)] + add_block = nn.Sequential(*add_block) + add_block.apply(weights_init_kaiming) + + classifier = [] + classifier += [nn.Linear(num_bottleneck, class_num)] + classifier = nn.Sequential(*classifier) + classifier.apply(weights_init_classifier) + + self.add_block = add_block + self.classifier = classifier + def forward(self, x): + x = self.add_block(x) + x = self.classifier(x) + return x + +# Define the ResNet50-based Model +class ft_net(nn.Module): + + def __init__(self, class_num, norm=False, pool='avg', stride=2): + super(ft_net, self).__init__() + if norm: + self.norm = True + else: + self.norm = False + model_ft = models.resnet50(pretrained=True) + #model_ft = models.resnet50() + #model_ft.load_state_dict(torch.load(PretrainedModel)) + # avg pooling to global pooling + self.part = 4 + if pool=='max': + model_ft.partpool = nn.AdaptiveMaxPool2d((self.part,1)) + model_ft.avgpool = nn.AdaptiveMaxPool2d((1,1)) + else: + model_ft.partpool = nn.AdaptiveAvgPool2d((self.part,1)) + model_ft.avgpool = nn.AdaptiveAvgPool2d((1,1)) + # remove the final downsample + if stride == 1: + model_ft.layer4[0].downsample[0].stride = (1,1) + model_ft.layer4[0].conv2.stride = (1,1) + + self.model = model_ft + self.classifier = ClassBlock(2048, class_num) + + def forward(self, x): + x = self.model.conv1(x) + x = self.model.bn1(x) + x = self.model.relu(x) + x = self.model.maxpool(x) + x = self.model.layer1(x) + x = self.model.layer2(x) # -> 512 32*16 + x = self.model.layer3(x) + x = self.model.layer4(x) + f = self.model.partpool(x) # 8 * 2048 4*1 + x = self.model.avgpool(x) # 8 * 2048 1*1 + + x = x.view(x.size(0),x.size(1)) + f = f.view(f.size(0),f.size(1)*self.part) + if self.norm: + fnorm = torch.norm(f, p=2, dim=1, keepdim=True) + 1e-8 + f = f.div(fnorm.expand_as(f)) + x = self.classifier(x) + return f, x + +# Define the AB Model +class ft_netAB(nn.Module): + + def __init__(self, class_num, norm=False, stride=2, droprate=0.5, pool='avg'): + super(ft_netAB, self).__init__() + model_ft = 
models.resnet50(pretrained=True) + # model_ft = models.resnet50() + # model_ft.load_state_dict(torch.load(PretrainedModel)) + self.part = 4 + if pool=='max': + model_ft.partpool = nn.AdaptiveMaxPool2d((self.part,1)) + model_ft.avgpool = nn.AdaptiveMaxPool2d((1,1)) + else: + model_ft.partpool = nn.AdaptiveAvgPool2d((self.part,1)) + model_ft.avgpool = nn.AdaptiveAvgPool2d((1,1)) + + self.model = model_ft + + if stride == 1: + self.model.layer4[0].downsample[0].stride = (1,1) + self.model.layer4[0].conv2.stride = (1,1) + + self.classifier1 = ClassBlock(2048, class_num, 0.5) + self.classifier2 = ClassBlock(2048, class_num, 0.75) + + def forward(self, x): + x = self.model.conv1(x) + x = self.model.bn1(x) + x = self.model.relu(x) + x = self.model.maxpool(x) + x = self.model.layer1(x) + x = self.model.layer2(x) + x = self.model.layer3(x) + x = self.model.layer4(x) + f = self.model.partpool(x) + f = f.view(f.size(0),f.size(1)*self.part) + f = f.detach() # no gradient + x = self.model.avgpool(x) + x = x.view(x.size(0), x.size(1)) + x1 = self.classifier1(x) + x2 = self.classifier2(x) + x=[] + x.append(x1) + x.append(x2) + return f, x + +class ft_netABe(nn.Module): + + def __init__(self, class_num, norm=False, stride=2, droprate=0.5, pool='avg'): + super(ft_netABe, self).__init__() + model_ft = models.resnet50(pretrained=True) + # model_ft = models.resnet50() + # model_ft.load_state_dict(torch.load(PretrainedModel)) + self.part = 4 + if pool=='max': + model_ft.partpool = nn.AdaptiveMaxPool2d((self.part,1)) + model_ft.avgpool = nn.AdaptiveMaxPool2d((1,1)) + else: + model_ft.partpool = nn.AdaptiveAvgPool2d((self.part,1)) + model_ft.avgpool = nn.AdaptiveAvgPool2d((1,1)) + #model_ft.final_bn = nn.BatchNorm1d(2048) + #model_ft.final_bn.apply(weights_init_kaiming) + self.model = model_ft + + if stride == 1: + self.model.layer4[0].downsample[0].stride = (1,1) + self.model.layer4[0].conv2.stride = (1,1) + + self.classifier1 = ClassBlock(2048, class_num, 0.5) + self.classifier2 = ClassBlock(2048, class_num, 0.75) + + def forward(self, x): + x = self.model.conv1(x) + x = self.model.bn1(x) + x = self.model.relu(x) + x = self.model.maxpool(x) + x = self.model.layer1(x) + x = self.model.layer2(x) + x = self.model.layer3(x) + x = self.model.layer4(x) + f = self.model.partpool(x) + f = f.view(f.size(0),f.size(1)*self.part) + f = f.detach() # no gradient + x = self.model.avgpool(x) + x = x.view(x.size(0), x.size(1)) + #x = self.model.final_bn(x) + x1 = self.classifier1(x) + x2 = self.classifier2(x) + xo=[] + xo.append(x1) + xo.append(x2) + return f, xo, x + +# Define the DenseNet121-based Model +class ft_net_dense(nn.Module): + + def __init__(self, class_num ): + super().__init__() + model_ft = models.densenet121(pretrained=True) + model_ft.features.avgpool = nn.AdaptiveAvgPool2d((1,1)) + model_ft.fc = nn.Sequential() + self.model = model_ft + # For DenseNet, the feature dim is 1024 + self.classifier = ClassBlock(1024, class_num) + + def forward(self, x): + x = self.model.features(x) + x = torch.squeeze(x) + x = self.classifier(x) + return x + +# Define the ResNet50-based Model (Middle-Concat) +# In the spirit of "The Devil is in the Middle: Exploiting Mid-level Representations for Cross-Domain Instance Matching." Yu, Qian, et al. arXiv:1711.08106 (2017). 
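+# The model below concatenates the globally pooled layer3 (1024-dim) and layer4 (2048-dim)
+# features into a single 3072-dim descriptor before feeding it to the ClassBlock classifier.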
+class ft_net_middle(nn.Module): + + def __init__(self, class_num ): + super(ft_net_middle, self).__init__() + model_ft = models.resnet50(pretrained=True) + # model_ft = models.resnet50() + # model_ft.load_state_dict(torch.load(PretrainedModel)) + # avg pooling to global pooling + model_ft.avgpool = nn.AdaptiveAvgPool2d((1,1)) + self.model = model_ft + self.classifier = ClassBlock(2048+1024, class_num) + + def forward(self, x): + x = self.model.conv1(x) + x = self.model.bn1(x) + x = self.model.relu(x) + x = self.model.maxpool(x) + x = self.model.layer1(x) + x = self.model.layer2(x) + x = self.model.layer3(x) + # x0 n*1024*1*1 + x0 = self.model.avgpool(x) + x = self.model.layer4(x) + # x1 n*2048*1*1 + x1 = self.model.avgpool(x) + x = torch.cat((x0,x1),1) + x = torch.squeeze(x) + x = self.classifier(x) + return x + +# Part Model proposed in Yifan Sun etal. (2018) +class PCB(nn.Module): + def __init__(self, class_num ): + super(PCB, self).__init__() + + self.part = 4 # We cut the pool5 to 4 parts + model_ft = models.resnet50(pretrained=True) + # model_ft = models.resnet50() + # model_ft.load_state_dict(torch.load(PretrainedModel)) + self.model = model_ft + self.avgpool = nn.AdaptiveAvgPool2d((self.part,1)) + self.dropout = nn.Dropout(p=0.5) + # remove the final downsample + self.model.layer4[0].downsample[0].stride = (1,1) + self.model.layer4[0].conv2.stride = (1,1) + self.softmax = nn.Softmax(dim=1) + # define 4 classifiers + for i in range(self.part): + name = 'classifier'+str(i) + setattr(self, name, ClassBlock(2048, class_num, True, False, 256)) + + def forward(self, x): + x = self.model.conv1(x) + x = self.model.bn1(x) + x = self.model.relu(x) + x = self.model.maxpool(x) + + x = self.model.layer1(x) + x = self.model.layer2(x) + x = self.model.layer3(x) + x = self.model.layer4(x) + x = self.avgpool(x) + f = x + f = f.view(f.size(0),f.size(1)*self.part) + x = self.dropout(x) + part = {} + predict = {} + # get part feature batchsize*2048*4 + for i in range(self.part): + part[i] = x[:,:,i].contiguous() + part[i] = part[i].view(x.size(0), x.size(1)) + name = 'classifier'+str(i) + c = getattr(self,name) + predict[i] = c(part[i]) + + y=[] + for i in range(self.part): + y.append(predict[i]) + + return f, y + +class PCB_test(nn.Module): + def __init__(self,model): + super(PCB_test,self).__init__() + self.part = 6 + self.model = model.model + self.avgpool = nn.AdaptiveAvgPool2d((self.part,1)) + # remove the final downsample + self.model.layer3[0].downsample[0].stride = (1,1) + self.model.layer3[0].conv2.stride = (1,1) + + self.model.layer4[0].downsample[0].stride = (1,1) + self.model.layer4[0].conv2.stride = (1,1) + + def forward(self, x): + x = self.model.conv1(x) + x = self.model.bn1(x) + x = self.model.relu(x) + x = self.model.maxpool(x) + + x = self.model.layer1(x) + x = self.model.layer2(x) + x = self.model.layer3(x) + x = self.model.layer4(x) + x = self.avgpool(x) + y = x.view(x.size(0),x.size(1),x.size(2)) + return y + + diff --git a/re_ranking_one.py b/re_ranking_one.py new file mode 100755 index 0000000..7c35d32 --- /dev/null +++ b/re_ranking_one.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python2/python3 +# -*- coding: utf-8 -*- +""" +Created on Mon Jun 26 14:46:56 2017 +@author: luohao +Modified by Houjing Huang, 2017-12-22. +- This version accepts distance matrix instead of raw features. +- The difference of `/` division between python 2 and 3 is handled. +- numpy.float16 is replaced by numpy.float32 for numerical precision. + +Modified by Zhedong Zheng, 2018-1-12. 
+- replace sort with topK, which save about 30s. +""" + +""" +CVPR2017 paper:Zhong Z, Zheng L, Cao D, et al. Re-ranking Person Re-identification with k-reciprocal Encoding[J]. 2017. +url:http://openaccess.thecvf.com/content_cvpr_2017/papers/Zhong_Re-Ranking_Person_Re-Identification_CVPR_2017_paper.pdf +Matlab version: https://github.com/zhunzhong07/person-re-ranking +""" + +""" +API +q_g_dist: query-gallery distance matrix, numpy array, shape [num_query, num_gallery] +q_q_dist: query-query distance matrix, numpy array, shape [num_query, num_query] +g_g_dist: gallery-gallery distance matrix, numpy array, shape [num_gallery, num_gallery] +k1, k2, lambda_value: parameters, the original paper is (k1=20, k2=6, lambda_value=0.3) +Returns: + final_dist: re-ranked distance, numpy array, shape [num_query, num_gallery] +""" + + +import numpy as np + +def k_reciprocal_neigh( initial_rank, i, k1): + forward_k_neigh_index = initial_rank[i,:k1+1] + backward_k_neigh_index = initial_rank[forward_k_neigh_index,:k1+1] + fi = np.where(backward_k_neigh_index==i)[0] + return forward_k_neigh_index[fi] + +def re_ranking_one(original_dist, k1=20, k2=6, lambda_value=0.3): + # The following naming, e.g. gallery_num, is different from outer scope. + # Don't care about it. + original_dist = 2. - 2 * original_dist # change the cosine similarity metric to euclidean similarity metric + original_dist = np.power(original_dist, 2).astype(np.float32) + original_dist = np.transpose(1. * original_dist/np.max(original_dist,axis = 0)) + V = np.zeros_like(original_dist).astype(np.float32) + #initial_rank = np.argsort(original_dist).astype(np.int32) + # top K1+1 + initial_rank = np.argpartition( original_dist, range(1,k1+1) ) + + all_num = original_dist.shape[0] + query_num = all_num + + for i in range(all_num): + # k-reciprocal neighbors + k_reciprocal_index = k_reciprocal_neigh( initial_rank, i, k1) + k_reciprocal_expansion_index = k_reciprocal_index + for j in range(len(k_reciprocal_index)): + candidate = k_reciprocal_index[j] + candidate_k_reciprocal_index = k_reciprocal_neigh( initial_rank, candidate, int(np.around(k1/2))) + if len(np.intersect1d(candidate_k_reciprocal_index,k_reciprocal_index))> 2./3*len(candidate_k_reciprocal_index): + k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index,candidate_k_reciprocal_index) + + k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index) + weight = np.exp(-original_dist[i,k_reciprocal_expansion_index]) + V[i,k_reciprocal_expansion_index] = 1.*weight/np.sum(weight) + + original_dist = original_dist[:query_num,] + if k2 != 1: + V_qe = np.zeros_like(V,dtype=np.float32) + for i in range(all_num): + V_qe[i,:] = np.mean(V[initial_rank[i,:k2],:],axis=0) + V = V_qe + del V_qe + del initial_rank + invIndex = [] + for i in range(all_num): + invIndex.append(np.where(V[:,i] != 0)[0]) + + jaccard_dist = np.zeros_like(original_dist,dtype = np.float32) + + for i in range(query_num): + temp_min = np.zeros(shape=[1,all_num],dtype=np.float32) + indNonZero = np.where(V[i,:] != 0)[0] + indImages = [] + indImages = [invIndex[ind] for ind in indNonZero] + for j in range(len(indNonZero)): + temp_min[0,indImages[j]] = temp_min[0,indImages[j]]+ np.minimum(V[i,indNonZero[j]],V[indImages[j],indNonZero[j]]) + jaccard_dist[i] = 1-temp_min/(2.-temp_min) + + final_dist = jaccard_dist*(1-lambda_value) + original_dist*lambda_value + del original_dist + del V + del jaccard_dist + #final_dist = final_dist[:query_num,query_num:] + return final_dist diff --git a/reid_eval/README.md 
b/reid_eval/README.md
new file mode 100755
index 0000000..cd20b2c
--- /dev/null
+++ b/reid_eval/README.md
@@ -0,0 +1,19 @@
+## Evaluation
+The results are slightly different from those reported in the paper.
+
+- For market2duke
+```bash
+python test_2label_duke.py --name market2duke3 --which_epoch 231805
+```
+The result is `Rank@1:0.7931 Rank@5:0.8793 Rank@10:0.8990 mAP:0.6436`.
+
+`--name`: the model name, i.e., the folder under `outputs/`.
+
+`--which_epoch`: which saved checkpoint (training iteration) to load.
+
+- For duke2market
+```bash
+python test_2label_market.py --name duke2market7 --which_epoch 172353
+```
+
+The result is `Rank@1:0.8260 Rank@5:0.9136 Rank@10:0.9388 mAP:0.6400`.
diff --git a/reid_eval/evaluate_gpu.py b/reid_eval/evaluate_gpu.py
new file mode 100755
index 0000000..7cc56a7
--- /dev/null
+++ b/reid_eval/evaluate_gpu.py
@@ -0,0 +1,137 @@
+
+"""
+Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
+Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+"""
+
+import scipy.io
+import torch
+import numpy as np
+import time
+import os
+import matplotlib
+matplotlib.use('agg')
+import matplotlib.pyplot as plt
+#######################################################################
+# Evaluate
+
+def evaluate(qf,ql,qc,gf,gl,gc):
+    query = qf.view(-1,1)
+    # print(query.shape)
+    score = torch.mm(gf,query)
+    score = score.squeeze(1).cpu()
+    score = score.numpy()
+    # predict index
+    index = np.argsort(score)  #from small to large
+    index = index[::-1]
+    # index = index[0:2000]
+    # good index
+    query_index = np.argwhere(gl==ql)
+    #same camera
+    camera_index = np.argwhere(gc==qc)
+
+    good_index = np.setdiff1d(query_index, camera_index, assume_unique=True)
+    junk_index1 = np.argwhere(gl==-1)
+    junk_index2 = np.intersect1d(query_index, camera_index)
+    junk_index = np.append(junk_index2, junk_index1) #.flatten())
+
+    CMC_tmp = compute_mAP(index, qc, good_index, junk_index)
+    return CMC_tmp
+
+
+def compute_mAP(index, qc, good_index, junk_index):
+    ap = 0
+    cmc = torch.IntTensor(len(index)).zero_()
+    if good_index.size==0:   # if empty
+        cmc[0] = -1
+        return ap,cmc
+
+    # remove junk_index
+    ranked_camera = gallery_cam[index]
+    mask = np.in1d(index, junk_index, invert=True)
+    mask2 = np.in1d(index, np.append(good_index,junk_index), invert=True)
+    index = index[mask]
+    ranked_camera = ranked_camera[mask]
+
+    # find good_index index
+    ngood = len(good_index)
+    mask = np.in1d(index, good_index)
+    rows_good = np.argwhere(mask==True)
+    rows_good = rows_good.flatten()
+
+    cmc[rows_good[0]:] = 1
+    for i in range(ngood):
+        d_recall = 1.0/ngood
+        precision = (i+1)*1.0/(rows_good[i]+1)
+        if rows_good[i]!=0:
+            old_precision = i*1.0/rows_good[i]
+        else:
+            old_precision=1.0
+        ap = ap + d_recall*(old_precision + precision)/2
+
+    return ap, cmc
+
+######################################################################
+result = scipy.io.loadmat('pytorch_result.mat')
+query_feature = torch.FloatTensor(result['query_f'])
+query_cam = result['query_cam'][0]
+query_label = result['query_label'][0]
+gallery_feature = torch.FloatTensor(result['gallery_f'])
+gallery_cam = result['gallery_cam'][0]
+gallery_label = result['gallery_label'][0]
+
+multi = os.path.isfile('multi_query.mat')
+
+if multi:
+    m_result = scipy.io.loadmat('multi_query.mat')
+    mquery_feature = torch.FloatTensor(m_result['mquery_f'])
+    mquery_cam = m_result['mquery_cam'][0]
+    mquery_label = m_result['mquery_label'][0]
+    mquery_feature = mquery_feature.cuda()
+
+query_feature = query_feature.cuda()
+gallery_feature = gallery_feature.cuda()
+
+print(query_feature.shape)
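+# Note: each feature saved by test_2label_*.py is the concatenation of two
+# L2-normalized 512-dim embeddings from ft_netAB (dims 0:512 from classifier1,
+# dims 512:1024 from classifier2). The loop below sweeps a weight `alpha` over
+# the second half: alpha=0 keeps only the first embedding, alpha=0.5 blends the
+# two, and the special value -1 zeroes dims 0:512 so that only the second
+# embedding is used. For example, with alpha=0.5 the query is rescaled as
+#   qf = query_feature[i].clone()
+#   qf[512:1024] *= 0.5
+# before being matched against the gallery by inner product in evaluate().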
+alpha = [0, 0.5, -1] +#print(query_label) +for j in range(len(alpha)): + CMC = torch.IntTensor(len(gallery_label)).zero_() + ap = 0.0 + for i in range(len(query_label)): + qf = query_feature[i].clone() + if alpha[j] == -1: + qf[0:512] *= 0 + else: + qf[512:1024] *= alpha[j] # Why? + ap_tmp, CMC_tmp = evaluate(qf,query_label[i],query_cam[i],gallery_feature,gallery_label,gallery_cam) + if CMC_tmp[0]==-1: + continue + CMC = CMC + CMC_tmp + ap += ap_tmp + #print(i, CMC_tmp[0]) + + CMC = CMC.float() + CMC = CMC/len(query_label) #average CMC + print('Alpha:%.2f Rank@1:%.4f Rank@5:%.4f Rank@10:%.4f mAP:%.4f'%(alpha[j], CMC[0],CMC[4],CMC[9],ap/len(query_label))) + +# multiple-query +CMC = torch.IntTensor(len(gallery_label)).zero_() +ap = 0.0 +if multi: + malpha = 0.5 ###### + for i in range(len(query_label)): + mquery_index1 = np.argwhere(mquery_label==query_label[i]) + mquery_index2 = np.argwhere(mquery_cam==query_cam[i]) + mquery_index = np.intersect1d(mquery_index1, mquery_index2) + mq = torch.mean(mquery_feature[mquery_index,:], dim=0) + mq[512:1024] *= malpha + ap_tmp, CMC_tmp = evaluate(mq,query_label[i],query_cam[i],gallery_feature,gallery_label,gallery_cam) + if CMC_tmp[0]==-1: + continue + CMC = CMC + CMC_tmp + ap += ap_tmp + #print(i, CMC_tmp[0]) + CMC = CMC.float() + CMC = CMC/len(query_label) #average CMC + print('multi Rank@1:%.4f Rank@5:%.4f Rank@10:%.4f mAP:%.4f'%(CMC[0],CMC[4],CMC[9],ap/len(query_label))) diff --git a/reid_eval/test_2label_duke.py b/reid_eval/test_2label_duke.py new file mode 100755 index 0000000..f25e40f --- /dev/null +++ b/reid_eval/test_2label_duke.py @@ -0,0 +1,238 @@ + +""" +Copyright (C) 2019 NVIDIA Corporation. All rights reserved. +Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). +""" + +from __future__ import print_function, division + +import sys +sys.path.append('..') +import argparse +import torch +import torch.nn as nn +import torch.optim as optim +from torch.optim import lr_scheduler +from torch.autograd import Variable +import numpy as np +import torchvision +from torchvision import datasets, models, transforms +import time +import os +import scipy.io +import yaml +from reIDmodel import ft_net, ft_netAB, ft_net_dense, PCB, PCB_test + +###################################################################### +# Options +# -------- +parser = argparse.ArgumentParser(description='Training') +parser.add_argument('--gpu_ids',default='0', type=str,help='gpu_ids: e.g. 
0 0,1,2 0,2') +parser.add_argument('--which_epoch',default=100000, type=int, help='80000') +parser.add_argument('--test_dir',default='datasets/DukeMTMC-reID/pytorch',type=str, help='./test_data') +parser.add_argument('--name', default='test', type=str, help='save model path') +parser.add_argument('--batchsize', default=80, type=int, help='batchsize') +parser.add_argument('--use_dense', action='store_true', help='use densenet121' ) +parser.add_argument('--PCB', action='store_true', help='use PCB' ) +parser.add_argument('--multi', action='store_true', help='use multiple query' ) + +opt = parser.parse_args() + +str_ids = opt.gpu_ids.split(',') +#which_epoch = opt.which_epoch +name = opt.name +test_dir = opt.test_dir + +gpu_ids = [] +for str_id in str_ids: + id = int(str_id) + if id >=0: + gpu_ids.append(id) + +# set gpu ids +if len(gpu_ids)>0: + torch.cuda.set_device(gpu_ids[0]) + +###################################################################### +# Load Data +# --------- +# +# We will use torchvision and torch.utils.data packages for loading the +# data. +# +data_transforms = transforms.Compose([ + transforms.Resize((256,128), interpolation=3), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) +############### Ten Crop + #transforms.TenCrop(224), + #transforms.Lambda(lambda crops: torch.stack( + # [transforms.ToTensor()(crop) + # for crop in crops] + # )), + #transforms.Lambda(lambda crops: torch.stack( + # [transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(crop) + # for crop in crops] + # )) +]) + +if opt.PCB: + data_transforms = transforms.Compose([ + transforms.Resize((384,192), interpolation=3), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + ]) + + +data_dir = test_dir +image_datasets = {x: datasets.ImageFolder( os.path.join(data_dir,x) ,data_transforms) for x in ['gallery','query']} +dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=opt.batchsize, + shuffle=False, num_workers=16) for x in ['gallery','query']} + +class_names = image_datasets['query'].classes +use_gpu = torch.cuda.is_available() + +###################################################################### +# Load model +#--------------------------- +def load_network(network, save_path): + state_dict = torch.load(save_path) + network.load_state_dict(state_dict['a'], strict=False) + return network + +def get_model_stats(): + checkpoint_name = 'id_%08d.pt'%opt.which_epoch + main_folder = os.path.join('../outputs', name) + folders = os.listdir(main_folder) + for folder in folders: + if not folder.isdigit(): + continue + checkpoint_folder = os.path.join(main_folder, folder, 'checkpoints') + checkpoints = os.listdir(checkpoint_folder) + if checkpoint_name in checkpoints: + checkpoint_path = os.path.join(checkpoint_folder, checkpoint_name) + checkpoint = torch.load(checkpoint_path) + return checkpoint_path, checkpoint['a']['classifier2.classifier.0.weight'].size()[0] + if use_gpu: + del checkpoint + torch.cuda.empty_cache() + + +###################################################################### +# Extract feature +# ---------------------- +# +# Extract feature from a trained model. 
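+# Each image is passed through the network twice (original and a horizontally
+# flipped copy), the two outputs are summed, and the two 512-dim halves of the
+# summed feature are L2-normalized separately before evaluation.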
+# +def fliplr(img): + '''flip horizontal''' + inv_idx = torch.arange(img.size(3)-1,-1,-1).long() # N x C x H x W + img_flip = img.index_select(3,inv_idx) + return img_flip + +def norm(f, dim = 1): + f = f.squeeze() + fnorm = torch.norm(f, p=2, dim=dim, keepdim=True) + f = f.div(fnorm.expand_as(f)) + return f + +def extract_feature(model,dataloaders): + features = torch.FloatTensor() + count = 0 + for data in dataloaders: + img, label = data + n, c, h, w = img.size() + count += n + #print(count) + if opt.use_dense: + ff = torch.FloatTensor(n,1024).zero_() + else: + ff = torch.FloatTensor(n,1024).zero_() + if opt.PCB: + ff = torch.FloatTensor(n,2048,6).zero_() # we have six parts + for i in range(2): + if(i==1): + img = fliplr(img) + input_img = Variable(img.cuda()) + f, x = model(input_img) + x[0] = norm(x[0]) + x[1] = norm(x[1]) + f = torch.cat((x[0],x[1]), dim=1) #use 512-dim feature + f = f.data.cpu() + ff = ff+f + + ff[:, 0:512] = norm(ff[:, 0:512], dim=1) + ff[:, 512:1024] = norm(ff[:, 512:1024], dim =1) + # norm feature + if opt.PCB: + # feature size (n,2048,6) + # 1. To treat every part equally, I calculate the norm for every 2048-dim part feature. + # 2. To keep the cosine score==1, sqrt(6) is added to norm the whole feature (2048*6). + fnorm = torch.norm(ff, p=2, dim=1, keepdim=True) * np.sqrt(6) + ff = ff.div(fnorm.expand_as(ff)) + ff = ff.view(ff.size(0), -1) + + features = torch.cat((features,ff), 0) + return features + +def get_id(img_path): + camera_id = [] + labels = [] + for path, v in img_path: + #filename = path.split('/')[-1] + filename = os.path.basename(path) + label = filename[0:4] + camera = filename.split('c')[1] + if label[0:2]=='-1': + labels.append(-1) + else: + labels.append(int(label)) + camera_id.append(int(camera[0])) + return camera_id, labels + +gallery_path = image_datasets['gallery'].imgs +query_path = image_datasets['query'].imgs + +gallery_cam,gallery_label = get_id(gallery_path) +query_cam,query_label = get_id(query_path) + +###################################################################### +# Load Collected data Trained model +print('-------test-----------') + +###load config### +config_path = os.path.join('../outputs',name,'config.yaml') +with open(config_path, 'r') as stream: + config = yaml.load(stream) + +model_path, output_dim = get_model_stats() + +model_structure = ft_netAB(output_dim, norm=config['norm_id'], stride=config['ID_stride'], pool=config['pool']) + +model = load_network(model_structure, model_path) + +# Remove the final fc layer and classifier layer +model.model.fc = nn.Sequential() +model.classifier1.classifier = nn.Sequential() +model.classifier2.classifier = nn.Sequential() + +# Change to test mode +model = model.eval() +if use_gpu: + model = model.cuda() + +# Extract feature +with torch.no_grad(): + gallery_feature = extract_feature(model,dataloaders['gallery']) + query_feature = extract_feature(model,dataloaders['query']) + if opt.multi: + mquery_feature = extract_feature(model,dataloaders['multi-query']) + +# Save to Matlab for check +result = {'gallery_f':gallery_feature.numpy(),'gallery_label':gallery_label,'gallery_cam':gallery_cam,'query_f':query_feature.numpy(),'query_label':query_label,'query_cam':query_cam} +scipy.io.savemat('pytorch_result.mat',result) +if opt.multi: + result = {'mquery_f':mquery_feature.numpy(),'mquery_label':mquery_label,'mquery_cam':mquery_cam} + scipy.io.savemat('multi_query.mat',result) + +os.system('python evaluate_gpu.py') diff --git a/reid_eval/test_2label_market.py 
b/reid_eval/test_2label_market.py new file mode 100755 index 0000000..8c687f9 --- /dev/null +++ b/reid_eval/test_2label_market.py @@ -0,0 +1,247 @@ + +""" +Copyright (C) 2019 NVIDIA Corporation. All rights reserved. +Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). +""" + +from __future__ import print_function, division + +import sys +sys.path.append('..') +import argparse +import torch +import torch.nn as nn +import torch.optim as optim +from torch.optim import lr_scheduler +from torch.autograd import Variable +import numpy as np +import torchvision +from torchvision import datasets, models, transforms +import time +import os +import scipy.io +import yaml +from reIDmodel import ft_net, ft_netAB, ft_net_dense, PCB, PCB_test + +###################################################################### +# Options +# -------- +parser = argparse.ArgumentParser(description='Training') +parser.add_argument('--gpu_ids',default='0', type=str,help='gpu_ids: e.g. 0 0,1,2 0,2') +parser.add_argument('--which_epoch',default=90000, type=int, help='80000') +parser.add_argument('--test_dir',default='datasets/Market/pytorch',type=str, help='./test_data') +parser.add_argument('--name', default='test', type=str, help='save model path') +parser.add_argument('--batchsize', default=80, type=int, help='batchsize') +parser.add_argument('--use_dense', action='store_true', help='use densenet121' ) +parser.add_argument('--PCB', action='store_true', help='use PCB' ) +parser.add_argument('--multi', action='store_true', help='use multiple query' ) + +opt = parser.parse_args() + +str_ids = opt.gpu_ids.split(',') +#which_epoch = opt.which_epoch +name = opt.name +test_dir = opt.test_dir + +gpu_ids = [] +for str_id in str_ids: + id = int(str_id) + if id >=0: + gpu_ids.append(id) + +# set gpu ids +if len(gpu_ids)>0: + torch.cuda.set_device(gpu_ids[0]) + +###################################################################### +# Load Data +# --------- +# +# We will use torchvision and torch.utils.data packages for loading the +# data. 
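+# Unlike the DukeMTMC-reID evaluation script, this one also builds a 'multi-query'
+# loader (prepared from gt_bbox by prepare-market.py), which the --multi option uses.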
+# +data_transforms = transforms.Compose([ + transforms.Resize((256,128), interpolation=3), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) +############### Ten Crop + #transforms.TenCrop(224), + #transforms.Lambda(lambda crops: torch.stack( + # [transforms.ToTensor()(crop) + # for crop in crops] + # )), + #transforms.Lambda(lambda crops: torch.stack( + # [transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(crop) + # for crop in crops] + # )) +]) + +if opt.PCB: + data_transforms = transforms.Compose([ + transforms.Resize((384,192), interpolation=3), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + ]) + + +data_dir = test_dir +image_datasets = {x: datasets.ImageFolder( os.path.join(data_dir,x) ,data_transforms) for x in ['gallery','query','multi-query']} +dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=opt.batchsize, + shuffle=False, num_workers=16) for x in ['gallery','query','multi-query']} + +class_names = image_datasets['query'].classes +use_gpu = torch.cuda.is_available() + +###################################################################### +# Load model +#--------------------------- +def load_network(network, save_path): + state_dict = torch.load(save_path) + network.load_state_dict(state_dict['a'], strict=False) + return network + +def get_model_stats(): + checkpoint_name = 'id_%08d.pt'%opt.which_epoch + main_folder = os.path.join('../outputs', name) + folders = os.listdir(main_folder) + for folder in folders: + if not folder.isdigit(): + continue + checkpoint_folder = os.path.join(main_folder, folder, 'checkpoints') + checkpoints = os.listdir(checkpoint_folder) + if checkpoint_name in checkpoints: + checkpoint_path = os.path.join(checkpoint_folder, checkpoint_name) + checkpoint = torch.load(checkpoint_path) + return checkpoint_path, checkpoint['a']['classifier2.classifier.0.weight'].size()[0] + if use_gpu: + del checkpoint + torch.cuda.empty_cache() + + print('No checkpoint found.') + + +###################################################################### +# Extract feature +# ---------------------- +# +# Extract feature from a trained model. +# +def fliplr(img): + '''flip horizontal''' + inv_idx = torch.arange(img.size(3)-1,-1,-1).long() # N x C x H x W + img_flip = img.index_select(3,inv_idx) + return img_flip + +def norm(f): + f = f.squeeze() + fnorm = torch.norm(f, p=2, dim=1, keepdim=True) + f = f.div(fnorm.expand_as(f)) + return f + +def extract_feature(model,dataloaders): + features = torch.FloatTensor() + count = 0 + for data in dataloaders: + img, label = data + n, c, h, w = img.size() + count += n + #print(count) + if opt.use_dense: + ff = torch.FloatTensor(n,1024).zero_() + else: + ff = torch.FloatTensor(n,1024).zero_() + if opt.PCB: + ff = torch.FloatTensor(n,2048,6).zero_() # we have six parts + for i in range(2): + if(i==1): + img = fliplr(img) + input_img = Variable(img.cuda()) + f, x = model(input_img) + x[0] = norm(x[0]) + x[1] = norm(x[1]) + f = torch.cat((x[0],x[1]), dim=1) #use 512-dim feature + f = f.data.cpu() + ff = ff+f + + ff[:, 0:512] = norm(ff[:, 0:512]) + ff[:, 512:1024] = norm(ff[:, 512:1024]) + + # norm feature + if opt.PCB: + # feature size (n,2048,6) + # 1. To treat every part equally, I calculate the norm for every 2048-dim part feature. + # 2. To keep the cosine score==1, sqrt(6) is added to norm the whole feature (2048*6). 
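+        # Dividing each 2048-dim part by its own norm times sqrt(6) gives every
+        # part a norm of 1/sqrt(6), so the flattened 2048*6 vector has unit length
+        # and query-gallery cosine scores stay within [-1, 1].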
+ fnorm = torch.norm(ff, p=2, dim=1, keepdim=True) * np.sqrt(6) + ff = ff.div(fnorm.expand_as(ff)) + ff = ff.view(ff.size(0), -1) + + features = torch.cat((features,ff), 0) + return features + +def get_id(img_path): + camera_id = [] + labels = [] + for path, v in img_path: + #filename = path.split('/')[-1] + filename = os.path.basename(path) + label = filename[0:4] + camera = filename.split('c')[1] + if label[0:2]=='-1': + labels.append(-1) + else: + labels.append(int(label)) + camera_id.append(int(camera[0])) + return camera_id, labels + +gallery_path = image_datasets['gallery'].imgs +query_path = image_datasets['query'].imgs +mquery_path = image_datasets['multi-query'].imgs + +gallery_cam,gallery_label = get_id(gallery_path) +query_cam,query_label = get_id(query_path) +mquery_cam,mquery_label = get_id(mquery_path) + +###################################################################### +# Load Collected data Trained model +print('-------test-----------') + +###load config### +config_path = os.path.join('../outputs',name,'config.yaml') +with open(config_path, 'r') as stream: + config = yaml.load(stream) + +model_path, output_dim = get_model_stats() + +model_structure = ft_netAB(output_dim, norm=config['norm_id'], stride=config['ID_stride'], pool=config['pool']) + +model = load_network(model_structure, model_path) + +# Remove the final fc layer and classifier layer +model.model.fc = nn.Sequential() +model.classifier1.classifier = nn.Sequential() +model.classifier2.classifier = nn.Sequential() + +# Change to test mode +model = model.eval() +if use_gpu: + model = model.cuda() + +# Extract feature +since = time.time() +with torch.no_grad(): + gallery_feature = extract_feature(model,dataloaders['gallery']) + query_feature = extract_feature(model,dataloaders['query']) + time_elapsed = time.time() - since + print('Extract features complete in {:.0f}m {:.0f}s'.format( + time_elapsed // 60, time_elapsed % 60)) + if opt.multi: + mquery_feature = extract_feature(model,dataloaders['multi-query']) + +# Save to Matlab for check +result = {'gallery_f':gallery_feature.numpy(),'gallery_label':gallery_label,'gallery_cam':gallery_cam,'query_f':query_feature.numpy(),'query_label':query_label,'query_cam':query_cam} +scipy.io.savemat('pytorch_result.mat',result) +if opt.multi: + result = {'mquery_f':mquery_feature.numpy(),'mquery_label':mquery_label,'mquery_cam':mquery_cam} + scipy.io.savemat('multi_query.mat',result) + +os.system('python evaluate_gpu.py') diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..ab70475 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,8 @@ +opencv-python +h5py +pillow<7 +cython +tensorboardX +tensorflow==1.13.1 +scikit-image +ipdb diff --git a/train.py b/train.py new file mode 100644 index 0000000..b228a7f --- /dev/null +++ b/train.py @@ -0,0 +1,273 @@ +""" +Copyright (C) 2019 NVIDIA Corporation. All rights reserved. +Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
+""" +from utils import get_mix_data_loaders, get_data_loader_folder, prepare_sub_folder_pseudo, write_html, write_loss, get_config, write_2images, Timer +import argparse +from trainer import DGNetpp_Trainer +import torch.backends.cudnn as cudnn +import torch +import random as rn +import numpy.random as random +try: + from itertools import izip as zip +except ImportError: # will be 3.x series + pass +import os +import sys +import tensorboardX +import shutil + +# set random seed +def set_seed(seed=0): + rn.seed(seed) + os.environ['PYTHONHASHSEED'] = str(seed) + random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) # if you are using multi-GPU. + # cudnn.enabled = False + cudnn.deterministic = True + cudnn.benchmark = False + +parser = argparse.ArgumentParser() +parser.add_argument('--config', type=str, default='configs/latest.yaml', help='Path to the config file.') +parser.add_argument('--output_path', type=str, default='.', help="outputs path") +parser.add_argument('--name', type=str, default='latest_ablation', help="outputs path") +parser.add_argument("--resume", action="store_true") +parser.add_argument('--trainer', type=str, default='DGNet++', help="DGNet++") +parser.add_argument('--gpu_ids',default='0', type=str,help='gpu_ids: e.g. 0 0,1,2 0,2') +opts = parser.parse_args() + +str_ids = opts.gpu_ids.split(',') +gpu_ids = [] +for str_id in str_ids: + gpu_ids.append(int(str_id)) +num_gpu = len(gpu_ids) +if num_gpu > 1: + raise Exception('Currently only single GPU training is supported!') + +# Load experiment setting +config = get_config(opts.config) +set_seed(config['randseed']) +max_iter = config['max_iter'] +display_size = config['display_size'] +config['vgg_model_path'] = opts.output_path + +# preparing sampling images +train_loader_a_sample = get_data_loader_folder(os.path.join(config['data_root_a'], 'train_all'), config['batch_size'], False, + config['new_size'], config['crop_image_height'], config['crop_image_width'], config['num_workers'], False) +train_loader_b_sample = get_data_loader_folder(os.path.join(config['data_root_b'], 'train_all'), config['batch_size'], False, + config['new_size'], config['crop_image_height'], config['crop_image_width'], config['num_workers'], False) + +train_aba_rand = random.permutation(train_loader_a_sample.dataset.img_num)[0:display_size] +train_abb_rand = random.permutation(train_loader_b_sample.dataset.img_num)[0:display_size] +train_aab_rand = random.permutation(train_loader_a_sample.dataset.img_num)[0:display_size] +train_bbb_rand = random.permutation(train_loader_b_sample.dataset.img_num)[0:display_size] + +train_display_images_aba = torch.stack([train_loader_a_sample.dataset[i][0] for i in train_aba_rand]).cuda() +train_display_images_abb = torch.stack([train_loader_b_sample.dataset[i][0] for i in train_abb_rand]).cuda() +train_display_images_aaa = torch.stack([train_loader_a_sample.dataset[i][0] for i in train_aba_rand]).cuda() +train_display_images_aab = torch.stack([train_loader_a_sample.dataset[i][0] for i in train_aab_rand]).cuda() +train_display_images_bba = torch.stack([train_loader_b_sample.dataset[i][0] for i in train_abb_rand]).cuda() +train_display_images_bbb = torch.stack([train_loader_b_sample.dataset[i][0] for i in train_bbb_rand]).cuda() + +# Setup logger and output folders +model_name = os.path.splitext(os.path.basename(opts.config))[0] +train_writer = tensorboardX.SummaryWriter(os.path.join(opts.output_path + "/logs", model_name)) +output_directory = 
os.path.join(opts.output_path + "/outputs", model_name) +if not os.path.exists(output_directory): + os.makedirs(output_directory) +else: + shutil.rmtree(output_directory) + os.makedirs(output_directory) +shutil.copyfile(opts.config, os.path.join(output_directory, 'config.yaml')) # copy config file to output folder +shutil.copyfile('trainer.py', os.path.join(output_directory, 'trainer.py')) # copy file to output folder +shutil.copyfile('reIDmodel.py', os.path.join(output_directory, 'reIDmodel.py')) # copy file to output folder +shutil.copyfile('networks.py', os.path.join(output_directory, 'networks.py')) # copy file to output folder + +checkpoint_directory_prev = config['src_model_dir'] + +countaa, countab, countba, countbb = 1, 1, 1, 1 +count_dis_update = config['dis_update_iter'] +nepoch = 0 +iterations = 0 +epoch_round = config['epoch_round_adv'] +lr_decayed = False +mAP_list, rank1_list, rank5_list, rank10_list = [], [], [], [] +for round_idx in range(config['max_round']): + ### setup folders + round_output_directory = os.path.join(output_directory, str(round_idx)) + checkpoint_directory, image_directory, pseudo_directory = prepare_sub_folder_pseudo(round_output_directory) + config['data_root'] = pseudo_directory + + # In the initial round, we disenable self-training and warmup the network with adversarial training + # At the round of adv_warm_max_round, we switch to self-training + if round_idx == config['adv_warm_max_round']: + config['lr2'] *= config['lr2_ramp_factor'] + config['id_adv_w'] = 0.0 + config['id_adv_w_max'] = 0.0 + config['id_tgt'] = True + config['teacher'] = '' # we do not use teacher in the self-training + if config['aa_drop']: + config['aa'] = False + + ### Evaluate source model ### + if round_idx == 0: + ### Model initialization with source model for test ### + if opts.trainer == 'DGNet++': + trainer = DGNetpp_Trainer(config) + trainer.cuda() + _ = trainer.resume_DAt1(checkpoint_directory_prev) if round_idx > 0 else trainer.resume_DAt0(checkpoint_directory_prev) + + trainer.test(config) + write_loss(iterations, trainer, train_writer) + rank1 = trainer.rank_1 + rank5 = trainer.rank_5 + rank10 = trainer.rank_10 + mAP0 = trainer.mAP_zero + mAP05 = trainer.mAP_half + mAPn1 = trainer.mAP_neg_one + + mAP_list.append(mAP05) + rank1_list.append(rank1.numpy()) + rank5_list.append(rank5.numpy()) + rank10_list.append(rank10.numpy()) + + ### Pseudo-label generation ### + trainer.pseudo_label_generate(config) + + ### Model initialization w.r.t. 
current pseudo labels for train ### + if round_idx == 0: + config['ID_class_b'] = 0 # In the initial round, we disenable self-training + if opts.trainer == 'DGNet++': + trainer = DGNetpp_Trainer(config) + trainer.cuda() + _ = trainer.resume_DAt1(checkpoint_directory_prev) if round_idx > 0 else trainer.resume_DAt0(checkpoint_directory_prev) + + trainer.rank_1 = rank1 + trainer.rank_5 = rank5 + trainer.rank_10 = rank10 + trainer.mAP_zero = mAP0 + trainer.mAP_half = mAP05 + trainer.mAP_neg_one = mAPn1 + ### DGNet++ Training ### + # data initialize + train_loader_a, train_loader_b, _, _ = get_mix_data_loaders(config) + print('Note that dataloader may hang with too much nworkers.') + mixData_size = 2 * min(config['sample_a'], config['sample_b']) + config['epoch_iteration'] = mixData_size // config['batch_size'] + print('Every epoch need %d iterations' % config['epoch_iteration']) + + # training + subiterations = 0 + epoch_ridx = 0 + while epoch_ridx < epoch_round: + for it, ((images_a, labels_a, pos_a), (images_b, labels_b, pos_b)) in enumerate(zip(train_loader_a, train_loader_b)): + trainer.update_learning_rate() + + print('labels_a: ' + str(labels_a)) + print('labels_b: ' + str(labels_b)) + images_a, images_b = images_a.cuda().detach(), images_b.cuda().detach() + pos_a, pos_b = pos_a.cuda().detach(), pos_b.cuda().detach() + labels_a, labels_b = labels_a.cuda().detach(), labels_b.cuda().detach() + + with Timer("Elapsed time in update: %f"): + # Main training code + if labels_a[0] < config['ID_class_a'] and labels_b[0] < config['ID_class_a'] and config['aa']: # aa + print('aa') + if countaa == count_dis_update: + trainer.dis_update_aa(images_a, images_b, config) + countaa = 0 + trainer.gen_update_aa(images_a, labels_a, pos_a, images_b, labels_b, pos_b, config, subiterations) + countaa += 1 + elif labels_a[0] < config['ID_class_a'] and labels_b[0] >= config['ID_class_a'] and config['ab']: # ab + print('ab') + if countab == count_dis_update: + trainer.dis_update_ab(images_a, images_b, config) + countab = 0 + trainer.gen_update_ab(images_a, labels_a, pos_a, images_b, labels_b, pos_b, config, subiterations) + countab += 1 + elif labels_a[0] >= config['ID_class_a'] and labels_b[0] < config['ID_class_a'] and config['ab']: # ba + print('ba') + if countba == count_dis_update: + trainer.dis_update_ab(images_b, images_a, config) + countba = 0 + trainer.gen_update_ab(images_b, labels_b, pos_b, images_a, labels_a, pos_a, config, subiterations) + countba += 1 + elif labels_a[0] >= config['ID_class_a'] and labels_b[0] >= config['ID_class_a'] and config['bb']: # bb + print('bb') + if countbb == count_dis_update: + trainer.dis_update_bb(images_a, images_b, config) + countbb = 0 + trainer.gen_update_bb(images_a, labels_a, pos_a, images_b, labels_b, pos_b, config, subiterations) + countbb += 1 + + torch.cuda.synchronize() + # Dump training stats in log file + if (iterations + 1) % config['log_iter'] == 0: + print("\033[1m Round: %02d Epoch: %02d Iteration: %08d/%08d \033[0m \n" % (round_idx, nepoch, subiterations + 1, config['epoch_iteration'] * epoch_round), end=" ") + write_loss(iterations, trainer, train_writer) + + iterations += 1 + subiterations += 1 + if iterations >= max_iter: + # Save network weights + trainer.save(checkpoint_directory, iterations) + print('Max mAP: ' + str(max(mAP_list)*100) + '%') + print('Max rank 1 accuracy: ' + str(max(rank1_list)*100) + '%') + print('Max rank 5 accuracy: ' + str(max(rank5_list)*100) + '%') + print('Max rank 10 accuracy: ' + str(max(rank10_list)*100) + '%') + 
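+                # weights were already saved above; stop training once the global iteration budget (max_iter) is reached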
sys.exit('Finish training') + nepoch += 1 + + # test in target domain in every epoch + trainer.test(config) + write_loss(iterations, trainer, train_writer) + rank1 = trainer.rank_1 + rank5 = trainer.rank_5 + rank10 = trainer.rank_10 + mAP0 = trainer.mAP_zero + mAP05 = trainer.mAP_half + mAPn1 = trainer.mAP_neg_one + + mAP_list.append(mAP05) + rank1_list.append(rank1.numpy()) + rank5_list.append(rank5.numpy()) + rank10_list.append(rank10.numpy()) + # save generated images in every round + with torch.no_grad(): + image_outputs = trainer.sample_ab(train_display_images_aba, train_display_images_abb) + write_2images(image_outputs, display_size, image_directory, 'train_ab_%08d' % (iterations + 1)) + del image_outputs + + with torch.no_grad(): + image_outputs = trainer.sample_aa(train_display_images_aaa, train_display_images_aab) + write_2images(image_outputs, display_size, image_directory, 'train_aa_%08d' % (iterations + 1)) + del image_outputs + + with torch.no_grad(): + image_outputs = trainer.sample_bb(train_display_images_bba, train_display_images_bbb) + write_2images(image_outputs, display_size, image_directory, 'train_bb_%08d' % (iterations + 1)) + del image_outputs + + # regenerate data loaders in every epoch + train_loader_a, train_loader_b, _, _ = get_mix_data_loaders(config) + + # adjust the total epochs per round + epoch_ridx += 1 + if epoch_ridx == epoch_round and round_idx == 0: + epoch_round = config['epoch_round'] + break + + # Save network weights + trainer.save(checkpoint_directory, iterations) + + # update model_prev_folder + checkpoint_directory_prev = checkpoint_directory + +print('Max mAP: {:.2%}'.format(max(mAP_list))) +print('Max rank 1 accuracy: {:.2%}'.format(max(rank1_list))) +print('Max rank 5 accuracy: {:.2%}'.format(max(rank5_list))) +print('Max rank 10 accuracy: {:.2%}'.format(max(rank10_list))) +print('Finish training') diff --git a/trainer.py b/trainer.py new file mode 100644 index 0000000..db29381 --- /dev/null +++ b/trainer.py @@ -0,0 +1,1617 @@ +""" +Copyright (C) 2019 NVIDIA Corporation. All rights reserved. +Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). +""" +from networks import AdaINGen, MsImageDis, IdDis +from reIDmodel import ft_net, ft_netABe +from utils import get_model_list, vgg_preprocess, load_vgg16, get_scheduler +from torch.autograd import Variable +import torch +import torch.nn as nn +import torchvision +import copy +import os +import cv2 +import numpy as np +from random_erasing import RandomErasing +from shutil import copyfile, copytree +import random +import yaml +from re_ranking_one import re_ranking_one +from sklearn.cluster import DBSCAN + + + +def to_gray(half=False): #simple + def forward(x): + x = torch.mean(x, dim=1, keepdim=True) + if half: + x = x.half() + return x + return forward + +def to_edge(x): + x = x.data.cpu() + out = torch.FloatTensor(x.size(0), x.size(2), x.size(3)) + for i in range(x.size(0)): + xx = recover(x[i,:,:,:]) # 3 channel, 256x128x3 + xx = cv2.cvtColor(xx, cv2.COLOR_RGB2GRAY) # 256x128x1 + xx = cv2.Canny(xx, 10, 200) #256x128 + xx = xx/255.0 - 0.5 # {-0.5,0.5} + xx += np.random.randn(xx.shape[0],xx.shape[1])*0.1 #add random noise + xx = torch.from_numpy(xx.astype(np.float32)) + out[i,:,:] = xx + out = out.unsqueeze(1) + return out.cuda() + +def scale2(x): + if x.size(2) > 128: # do not need to scale the input + return x + x = torch.nn.functional.upsample(x, scale_factor=2, mode='nearest') #bicubic is not available for the time being. 
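+    # inputs taller than 128 px were returned unchanged above; everything else is doubled (nearest-neighbor) before reaching the re-ID encoder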
+ return x + +def recover(inp): + inp = inp.numpy().transpose((1, 2, 0)) + mean = np.array([0.485, 0.456, 0.406]) + std = np.array([0.229, 0.224, 0.225]) + inp = std * inp + mean + inp = inp * 255.0 + inp = np.clip(inp, 0, 255) + inp = inp.astype(np.uint8) + return inp + +def train_bn(m): + classname = m.__class__.__name__ + if classname.find('BatchNorm') != -1: + m.train() + +def fliplr(img): + '''flip horizontal''' + inv_idx = torch.arange(img.size(3)-1,-1,-1).long().cuda() # N x C x H x W + img_flip = img.index_select(3,inv_idx) + return img_flip + +def update_teacher(model_s, model_t, alpha=0.999): + for param_s, param_t in zip(model_s.parameters(), model_t.parameters()): + param_t.data.mul_(alpha).add_(1 - alpha, param_s.data) + +def predict_label(teacher_models, inputs, num_class, alabel, slabel, teacher_style=0): +# teacher_style: +# 0: Our smooth dynamic label +# 1: Pseudo label, hard dynamic label +# 2: Conditional label, hard static label +# 3: LSRO, static smooth label +# 4: Dynamic Soft Two-label +# alabel is appearance label + if teacher_style == 0: + count = 0 + sm = nn.Softmax(dim=1) + for teacher_model in teacher_models: + _, outputs_t1 = teacher_model(inputs) + outputs_t1 = sm(outputs_t1.detach()) + _, outputs_t2 = teacher_model(fliplr(inputs)) + outputs_t2 = sm(outputs_t2.detach()) + if count==0: + outputs_t = outputs_t1 + outputs_t2 + else: + outputs_t = outputs_t * opt.alpha # old model decay + outputs_t += outputs_t1 + outputs_t2 + count +=2 + elif teacher_style == 1: # dynamic one-hot label + count = 0 + sm = nn.Softmax(dim=1) + for teacher_model in teacher_models: + _, outputs_t1 = teacher_model(inputs) + outputs_t1 = sm(outputs_t1.detach()) # change softmax to max + _, outputs_t2 = teacher_model(fliplr(inputs)) + outputs_t2 = sm(outputs_t2.detach()) + if count==0: + outputs_t = outputs_t1 + outputs_t2 + else: + outputs_t = outputs_t * opt.alpha # old model decay + outputs_t += outputs_t1 + outputs_t2 + count +=2 + _, dlabel = torch.max(outputs_t.data, 1) + outputs_t = torch.zeros(inputs.size(0), num_class).cuda() + for i in range(inputs.size(0)): + outputs_t[i, dlabel[i]] = 1 + elif teacher_style == 2: # appearance label + outputs_t = torch.zeros(inputs.size(0), num_class).cuda() + for i in range(inputs.size(0)): + outputs_t[i, alabel[i]] = 1 + elif teacher_style == 3: # LSRO + outputs_t = torch.ones(inputs.size(0), num_class).cuda() + elif teacher_style == 4: #Two-label + count = 0 + sm = nn.Softmax(dim=1) + for teacher_model in teacher_models: + _, outputs_t1 = teacher_model(inputs) + outputs_t1 = sm(outputs_t1.detach()) + _, outputs_t2 = teacher_model(fliplr(inputs)) + outputs_t2 = sm(outputs_t2.detach()) + if count==0: + outputs_t = outputs_t1 + outputs_t2 + else: + outputs_t = outputs_t * opt.alpha # old model decay + outputs_t += outputs_t1 + outputs_t2 + count +=2 + mask = torch.zeros(outputs_t.shape) + mask = mask.cuda() + for i in range(inputs.size(0)): + mask[i, alabel[i]] = 1 + mask[i, slabel[i]] = 1 + outputs_t = outputs_t*mask + else: + print('not valid style. 
teacher-style is in [0-3].') + + s = torch.sum(outputs_t, dim=1, keepdim=True) + s = s.expand_as(outputs_t) + outputs_t = outputs_t/s + return outputs_t + +###################################################################### +# Load model +#--------------------------- +def load_network(network, name): + save_path = os.path.join('./models',name,'net_last.pth') + network.load_state_dict(torch.load(save_path)) + return network + +def load_config(name): + config_path = os.path.join('./models',name,'opts.yaml') + with open(config_path, 'r') as stream: + config = yaml.load(stream) + return config + +def norm(f, dim = 1): + f = f.squeeze() + fnorm = torch.norm(f, p=2, dim=dim, keepdim=True) + f = f.div(fnorm.expand_as(f)) + return f + +def get_id(img_path, time_constraint = False): + camera_id = [] + time_id = [] + labels = [] + for path, v in img_path: + # filename = path.split('/')[-1] + filename = os.path.basename(path) + label = filename[0:4] + camera = filename.split('c')[1] + if time_constraint: + metadata = filename.split('_') + num_metadata = len(metadata) + if num_metadata == 3: + time = filename.split('f')[1] + elif num_metadata == 4: + time = metadata[2] + # print(camera) + if label[0:2] == '-1': + labels.append(-1) + else: + labels.append(int(label)) + camera_id.append(int(camera[0])) + if time_constraint: + if num_metadata == 3: + time_id.append(int(time[0:7])) + elif num_metadata == 4: + time_id.append(int(time[0:6])) + return camera_id, labels, time_id + +def evaluate(qf, ql, qc, gf, gl, gc): + query = qf.view(-1, 1) + # print(query.shape) + score = torch.mm(gf, query) + score = score.squeeze(1).cpu() + score = score.numpy() + # predict index + index = np.argsort(score) # from small to large + index = index[::-1] + # index = index[0:2000] + # good index + query_index = np.argwhere(gl == ql) + # same camera + camera_index = np.argwhere(gc == qc) + + good_index = np.setdiff1d(query_index, camera_index, assume_unique=True) + junk_index1 = np.argwhere(gl == -1) + junk_index2 = np.intersect1d(query_index, camera_index) + junk_index = np.append(junk_index2, junk_index1) # .flatten()) + + CMC_tmp = compute_mAP(index, gc, good_index, junk_index) + return CMC_tmp + +def compute_mAP(index, gc, good_index, junk_index): + ap = 0 + cmc = torch.IntTensor(len(index)).zero_() + if good_index.size == 0: # if empty + cmc[0] = -1 + return ap, cmc + + # remove junk_index + ranked_camera = gc[index] + mask = np.in1d(index, junk_index, invert=True) + mask2 = np.in1d(index, np.append(good_index, junk_index), invert=True) + index = index[mask] + ranked_camera = ranked_camera[mask] + + # find good_index index + ngood = len(good_index) + mask = np.in1d(index, good_index) + rows_good = np.argwhere(mask == True) + rows_good = rows_good.flatten() + + cmc[rows_good[0]:] = 1 + for i in range(ngood): + d_recall = 1.0 / ngood + precision = (i + 1) * 1.0 / (rows_good[i] + 1) + if rows_good[i] != 0: + old_precision = i * 1.0 / rows_good[i] + else: + old_precision = 1.0 + ap = ap + d_recall * (old_precision + precision) / 2 + + return ap, cmc + +class DGNetpp_Trainer(nn.Module): + def __init__(self, hyperparameters): + super(DGNetpp_Trainer, self).__init__() + lr_g = hyperparameters['lr_g'] + lr_d = hyperparameters['lr_d'] + lr_id_d = hyperparameters['lr_id_d'] + ID_class_a = hyperparameters['ID_class_a'] + + # Initiate the networks + # We do not need to manually set fp16 in the network. So here I set fp16=False. 
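+        # Build the networks (see the assignments below): one AdaINGen auto-encoder per domain (gen_a / gen_b),
+        # a single appearance/ID encoder shared by both domains (id_b aliases id_a),
+        # a shared multi-scale image discriminator (dis_b aliases dis_a),
+        # and an ID-feature discriminator (id_dis) used for the ID domain-adversarial loss.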
+ self.gen_a = AdaINGen(hyperparameters['input_dim_a'], hyperparameters['gen'], fp16=False) # auto-encoder for domain a + self.gen_b = AdaINGen(hyperparameters['input_dim_b'], hyperparameters['gen'], fp16=False) # auto-encoder for domain b + + if not 'ID_stride' in hyperparameters.keys(): + hyperparameters['ID_stride'] = 2 + + self.id_a = ft_netABe(ID_class_a + hyperparameters['ID_class_b'], stride=hyperparameters['ID_stride'], norm=hyperparameters['norm_id'], pool=hyperparameters['pool']) + + self.id_b = self.id_a + self.dis_a = MsImageDis(3, hyperparameters['dis'], fp16=False) # discriminator for domain a + self.dis_b = self.dis_a + + self.id_dis = IdDis(hyperparameters['gen']['id_dim'], hyperparameters['dis'], fp16=False) # ID discriminator + + # load teachers + if hyperparameters['teacher'] != "": + teacher_name = hyperparameters['teacher'] + print(teacher_name) + teacher_names = teacher_name.split(',') + teacher_model = nn.ModuleList() + teacher_count = 0 + for teacher_name in teacher_names: + config_tmp = load_config(teacher_name) + if 'stride' in config_tmp: + stride = config_tmp['stride'] + else: + stride = 2 + model_tmp = ft_net(ID_class_a, stride = stride) + teacher_model_tmp = load_network(model_tmp, teacher_name) + teacher_model_tmp.model.fc = nn.Sequential() # remove the original fc layer in ImageNet + teacher_model_tmp = teacher_model_tmp.cuda() + teacher_model.append(teacher_model_tmp.cuda().eval()) + teacher_count += 1 + self.teacher_model = teacher_model + if hyperparameters['train_bn']: + self.teacher_model = self.teacher_model.apply(train_bn) + + self.instancenorm = nn.InstanceNorm2d(512, affine=False) + display_size = int(hyperparameters['display_size']) + + # RGB to one channel + if hyperparameters['single'] == 'edge': + self.single = to_edge + else: + self.single = to_gray(False) + + # Random Erasing when training + if not 'erasing_p' in hyperparameters.keys(): + hyperparameters['erasing_p'] = 0 + self.single_re = RandomErasing(probability=hyperparameters['erasing_p'], mean=[0.0, 0.0, 0.0]) + + if not 'T_w' in hyperparameters.keys(): + hyperparameters['T_w'] = 1 + # Setup the optimizers + beta1 = hyperparameters['beta1'] + beta2 = hyperparameters['beta2'] + dis_a_params = list(self.dis_a.parameters()) + gen_a_params = list(self.gen_a.parameters()) + gen_b_params = list(self.gen_b.parameters()) + id_dis_params = list(self.id_dis.parameters()) + + self.dis_a_opt = torch.optim.Adam([p for p in dis_a_params if p.requires_grad], + lr=lr_d, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay']) + self.id_dis_opt = torch.optim.Adam([p for p in id_dis_params if p.requires_grad], + lr=lr_id_d, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay']) + self.gen_a_opt = torch.optim.Adam([p for p in gen_a_params if p.requires_grad], + lr=lr_g, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay']) + self.gen_b_opt = torch.optim.Adam([p for p in gen_b_params if p.requires_grad], + lr=lr_g, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay']) + # id params + ignored_params = (list(map(id, self.id_a.classifier1.parameters())) + + list(map(id, self.id_a.classifier2.parameters()))) + base_params = filter(lambda p: id(p) not in ignored_params, self.id_a.parameters()) + lr2 = hyperparameters['lr2'] + self.id_opt = torch.optim.SGD([ + {'params': base_params, 'lr': lr2}, + {'params': self.id_a.classifier1.parameters(), 'lr': lr2 * 10}, + {'params': self.id_a.classifier2.parameters(), 'lr': lr2 * 10} + ], 
weight_decay=hyperparameters['weight_decay'], momentum=0.9, nesterov=True) + + self.dis_a_scheduler = get_scheduler(self.dis_a_opt, hyperparameters) + self.id_dis_scheduler = get_scheduler(self.id_dis_opt, hyperparameters) + self.id_dis_scheduler.gamma = hyperparameters['gamma2'] + self.gen_a_scheduler = get_scheduler(self.gen_a_opt, hyperparameters) + self.gen_b_scheduler = get_scheduler(self.gen_b_opt, hyperparameters) + self.id_scheduler = get_scheduler(self.id_opt, hyperparameters) + self.id_scheduler.gamma = hyperparameters['gamma2'] + + # ID Loss + self.id_criterion = nn.CrossEntropyLoss() + self.criterion_teacher = nn.KLDivLoss(size_average=False) + # Load VGG model if needed + if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0: + self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models') + self.vgg.eval() + for param in self.vgg.parameters(): + param.requires_grad = False + + def to_re(self, x): + out = torch.FloatTensor(x.size(0), x.size(1), x.size(2), x.size(3)) + out = out.cuda() + for i in range(x.size(0)): + out[i, :, :, :] = self.single_re(x[i, :, :, :]) + return out + + def recon_criterion(self, input, target): + diff = input - target.detach() + return torch.mean(torch.abs(diff[:])) + + def recon_criterion_sqrt(self, input, target): + diff = input - target + return torch.mean(torch.sqrt(torch.abs(diff[:]) + 1e-8)) + + def recon_criterion2(self, input, target): + diff = input - target + return torch.mean(diff[:] ** 2) + + def recon_cos(self, input, target): + cos = torch.nn.CosineSimilarity() + cos_dis = 1 - cos(input, target) + return torch.mean(cos_dis[:]) + + def forward(self, x_a, x_b): + self.eval() + s_a = self.gen_a.encode(self.single(x_a)) + s_b = self.gen_b.encode(self.single(x_b)) + f_a, _, _ = self.id_a(scale2(x_a)) + f_b, _, _ = self.id_b(scale2(x_b)) + x_ba = self.gen_b.decode(s_b, f_a) + x_ab = self.gen_a.decode(s_a, f_b) + self.train() + return x_ab, x_ba + + def gen_update_ab(self, x_a, l_a, xp_a, x_b, l_b, xp_b, hyperparameters, iteration): + # ppa, ppb is the same person + self.gen_a_opt.zero_grad() + self.gen_b_opt.zero_grad() + self.id_opt.zero_grad() + self.id_dis_opt.zero_grad() + # encode + s_a = self.gen_a.encode(self.single(x_a)) + s_b = self.gen_b.encode(self.single(x_b)) + f_a, p_a, fe_a = self.id_a(scale2(x_a)) + f_b, p_b, fe_b = self.id_b(scale2(x_b)) + # autodecode + x_a_recon = self.gen_a.decode(s_a, f_a) + x_b_recon = self.gen_b.decode(s_b, f_b) + + # encode the same ID different photo + fp_a, pp_a, fe_pa = self.id_a(scale2(xp_a)) + fp_b, pp_b, fe_pb = self.id_b(scale2(xp_b)) + + # decode the same person + x_a_recon_p = self.gen_a.decode(s_a, fp_a) + x_b_recon_p = self.gen_b.decode(s_b, fp_b) + + # has gradient + x_ba = self.gen_b.decode(s_b, f_a) + x_ab = self.gen_a.decode(s_a, f_b) + # no gradient + x_ba_copy = Variable(x_ba.data, requires_grad=False) + x_ab_copy = Variable(x_ab.data, requires_grad=False) + + rand_num = random.uniform(0, 1) + ################################# + # encode structure + if hyperparameters['use_encoder_again'] >= rand_num: + # encode again (encoder is tuned, input is fixed) + s_a_recon = self.gen_a.enc_content(self.single(x_ab_copy)) + s_b_recon = self.gen_b.enc_content(self.single(x_ba_copy)) + else: + # copy the encoder + self.enc_content_a_copy = copy.deepcopy(self.gen_a.enc_content) + self.enc_content_a_copy = self.enc_content_a_copy.eval() + self.enc_content_b_copy = copy.deepcopy(self.gen_b.enc_content) + self.enc_content_b_copy = self.enc_content_b_copy.eval() + # encode again 
(encoder is fixed, input is tuned) + s_a_recon = self.enc_content_a_copy(self.single(x_ab)) + s_b_recon = self.enc_content_b_copy(self.single(x_ba)) + + ################################# + # encode appearance + self.id_a_copy = copy.deepcopy(self.id_a) + self.id_a_copy = self.id_a_copy.eval() + if hyperparameters['train_bn']: + self.id_a_copy = self.id_a_copy.apply(train_bn) + self.id_b_copy = self.id_a_copy + # encode again (encoder is fixed, input is tuned) + f_a_recon, p_a_recon, _ = self.id_a_copy(scale2(x_ba)) + f_b_recon, p_b_recon, _ = self.id_b_copy(scale2(x_ab)) + + # teacher Loss + # Tune the ID model + log_sm = nn.LogSoftmax(dim=1) + if hyperparameters['teacher_w'] > 0 and hyperparameters['teacher'] != "": + if hyperparameters['ID_style'] == 'normal': + _, p_a_student, _ = self.id_a(scale2(x_ba_copy)) + p_a_student = log_sm(p_a_student) + p_a_teacher = predict_label(self.teacher_model, scale2(x_ba_copy)) + self.loss_teacher = self.criterion_teacher(p_a_student, p_a_teacher) / p_a_student.size(0) + + _, p_b_student, _ = self.id_b(scale2(x_ab_copy)) + p_b_student = log_sm(p_b_student) + p_b_teacher = predict_label(self.teacher_model, scale2(x_ab_copy)) + self.loss_teacher += self.criterion_teacher(p_b_student, p_b_teacher) / p_b_student.size(0) + elif hyperparameters['ID_style'] == 'AB': + # normal teacher-student loss + # BA -> LabelA(smooth) + LabelB(batchB) + _, p_ba_student, _ = self.id_a(scale2(x_ba_copy)) # f_a, s_b + p_a_student = log_sm(p_ba_student[0]) + with torch.no_grad(): + p_a_teacher = predict_label(self.teacher_model, scale2(x_ba_copy), + num_class=hyperparameters['ID_class_a'], alabel=l_a, slabel=l_b, + teacher_style=hyperparameters['teacher_style']) + p_a_teacher = torch.cat( + (p_a_teacher, torch.zeros((p_a_teacher.size(0), hyperparameters['ID_class_b'])).cuda()), + 1).detach() + self.loss_teacher = self.criterion_teacher(p_a_student, p_a_teacher) / p_a_student.size(0) + + _, p_ab_student, _ = self.id_b(scale2(x_ab_copy)) # f_b, s_a + # branch b loss + # here we give different label + self.loss_teacher = hyperparameters['T_w'] * self.loss_teacher + loss_B = self.id_criterion(p_ab_student[1], l_a) + self.loss_teacher = self.loss_teacher + hyperparameters['B_w'] * loss_B + else: + self.loss_teacher = 0.0 + + # decode again (if needed) + if hyperparameters['use_decoder_again']: + x_aba = self.gen_a.decode(s_a_recon, f_a_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None + x_bab = self.gen_b.decode(s_b_recon, f_b_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None + else: + self.mlp_w_a_copy = copy.deepcopy(self.gen_a.mlp_w) + self.mlp_b_a_copy = copy.deepcopy(self.gen_a.mlp_b) + self.dec_a_copy = copy.deepcopy(self.gen_a.dec) # Error + ID = f_a_recon + ID_Style = ID.view(ID.shape[0], ID.shape[1], 1, 1) + adain_params_w_a = self.mlp_w_a_copy(ID_Style) + adain_params_b_a = self.mlp_b_a_copy(ID_Style) + self.gen_a.assign_adain_params(adain_params_w_a, adain_params_b_a, self.dec_a_copy) + x_aba = self.dec_a_copy(s_a_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None + + self.mlp_w_b_copy = copy.deepcopy(self.gen_b.mlp_w) + self.mlp_b_b_copy = copy.deepcopy(self.gen_b.mlp_b) + self.dec_b_copy = copy.deepcopy(self.gen_b.dec) # Error + ID = f_b_recon + ID_Style = ID.view(ID.shape[0], ID.shape[1], 1, 1) + adain_params_w_b = self.mlp_w_b_copy(ID_Style) + adain_params_b_b = self.mlp_b_b_copy(ID_Style) + self.gen_a.assign_adain_params(adain_params_w_b, adain_params_b_b, self.dec_b_copy) + x_bab = self.dec_b_copy(s_b_recon) if hyperparameters['recon_x_cyc_w'] 
> 0 else None + + # auto-encoder image reconstruction + self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a) + self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b) + self.loss_gen_recon_xp_a = self.recon_criterion(x_a_recon_p, x_a) + self.loss_gen_recon_xp_b = self.recon_criterion(x_b_recon_p, x_b) + + # feature reconstruction + self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a) if hyperparameters['recon_s_w'] > 0 else 0 + self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b) if hyperparameters['recon_s_w'] > 0 else 0 + self.loss_gen_recon_f_a = self.recon_criterion(f_a_recon, f_a) if hyperparameters['recon_f_w'] > 0 else 0 + self.loss_gen_recon_f_b = self.recon_criterion(f_b_recon, f_b) if hyperparameters['recon_f_w'] > 0 else 0 + + # Random Erasing only effect the ID and PID loss. + if hyperparameters['erasing_p'] > 0: + x_a_re = self.to_re(scale2(x_a.clone())) + x_b_re = self.to_re(scale2(x_b.clone())) + xp_a_re = self.to_re(scale2(xp_a.clone())) + xp_b_re = self.to_re(scale2(xp_b.clone())) + _, p_a, _ = self.id_a(x_a_re) + _, p_b, _ = self.id_b(x_b_re) + # encode the same ID different photo + _, pp_a, _ = self.id_a(xp_a_re) + _, pp_b, _ = self.id_b(xp_b_re) + + # ID loss AND Tune the Generated image + weight_B = hyperparameters['teacher_w'] * hyperparameters['B_w'] + if hyperparameters['id_tgt']: + self.loss_id = self.id_criterion(p_a[0], l_a) + self.id_criterion(p_b[0], l_b) \ + + weight_B * (self.id_criterion(p_a[1], l_a) + self.id_criterion(p_b[1], l_b)) + self.loss_pid = self.id_criterion(pp_a[0], l_a) + hyperparameters['tgt_pos'] * self.id_criterion(pp_b[0],l_b) # + weight_B * ( self.id_criterion(pp_a[1], l_a) + self.id_criterion(pp_b[1], l_b) ) + self.loss_gen_recon_id = self.id_criterion(p_a_recon[0], l_a) + self.id_criterion(p_b_recon[0], l_b) + else: + self.loss_id = self.id_criterion(p_a[0], l_a) + weight_B * self.id_criterion(p_a[1], l_a) + self.loss_pid = self.id_criterion(pp_a[0], l_a) + self.loss_gen_recon_id = self.id_criterion(p_a_recon[0], l_a) + + # print(f_a_recon, f_a) + self.loss_gen_cycrecon_x_a = self.recon_criterion(x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0 + self.loss_gen_cycrecon_x_b = self.recon_criterion(x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0 + # GAN loss + self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba) + self.loss_gen_adv_b = self.dis_a.calc_gen_loss(x_ab) + # domain-invariant perceptual loss + self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0 + self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0 + + # ID domain adversarial loss + self.loss_gen_id_adv = self.id_dis.calc_gen_loss(fe_b) if hyperparameters['id_adv_w'] > 0 else 0 + + if iteration > hyperparameters['warm_iter']: + hyperparameters['recon_f_w'] += hyperparameters['warm_scale'] + hyperparameters['recon_f_w'] = min(hyperparameters['recon_f_w'], hyperparameters['max_w']) + hyperparameters['recon_s_w'] += hyperparameters['warm_scale'] + hyperparameters['recon_s_w'] = min(hyperparameters['recon_s_w'], hyperparameters['max_w']) + hyperparameters['recon_x_cyc_w'] += hyperparameters['warm_scale'] + hyperparameters['recon_x_cyc_w'] = min(hyperparameters['recon_x_cyc_w'], hyperparameters['max_cyc_w']) + + if iteration > hyperparameters['warm_teacher_iter']: + hyperparameters['teacher_w'] += hyperparameters['warm_scale'] + hyperparameters['teacher_w'] = min(hyperparameters['teacher_w'], hyperparameters['max_teacher_w']) + 
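+        # ramp the ID-adversarial weight by 'adv_warm_scale' per generator update, capped at 'id_adv_w_max'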
+ hyperparameters['id_adv_w'] += hyperparameters['adv_warm_scale'] + hyperparameters['id_adv_w'] = min(hyperparameters['id_adv_w'], hyperparameters['id_adv_w_max']) + + # total loss + self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \ + hyperparameters['gan_w'] * self.loss_gen_adv_b + \ + hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \ + hyperparameters['recon_xp_w'] * self.loss_gen_recon_xp_a + \ + hyperparameters['recon_f_w'] * self.loss_gen_recon_f_a + \ + hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \ + hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \ + hyperparameters['recon_xp_w'] * hyperparameters['recon_xp_tgt_w'] * self.loss_gen_recon_xp_b + \ + hyperparameters['recon_f_w'] * self.loss_gen_recon_f_b + \ + hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \ + hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \ + hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \ + hyperparameters['id_w'] * self.loss_id + \ + hyperparameters['pid_w'] * self.loss_pid + \ + hyperparameters['recon_id_w'] * self.loss_gen_recon_id + \ + hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \ + hyperparameters['vgg_w'] * self.loss_gen_vgg_b + \ + hyperparameters['teacher_w'] * self.loss_teacher + \ + hyperparameters['id_adv_w'] * self.loss_gen_id_adv + self.loss_gen_total.backward() + self.gen_a_opt.step() + self.gen_b_opt.step() + self.id_opt.step() + print( + "L_total: %.4f, L_gan: %.4f, L_adv: %.4f, Lx: %.4f, Lxp: %.4f, Lrecycle:%.4f, Lf: %.4f, Ls: %.4f, Recon-id: %.4f, id: %.4f, pid:%.4f, teacher: %.4f" % ( + self.loss_gen_total, \ + hyperparameters['gan_w'] * (self.loss_gen_adv_a + self.loss_gen_adv_b), \ + hyperparameters['id_adv_w'] * (self.loss_gen_id_adv), \ + hyperparameters['recon_x_w'] * (self.loss_gen_recon_x_a + self.loss_gen_recon_x_b), \ + hyperparameters['recon_xp_w'] * (self.loss_gen_recon_xp_a + hyperparameters['recon_xp_tgt_w'] * self.loss_gen_recon_xp_b), \ + hyperparameters['recon_x_cyc_w'] * (self.loss_gen_cycrecon_x_a + self.loss_gen_cycrecon_x_b), \ + hyperparameters['recon_f_w'] * (self.loss_gen_recon_f_a + self.loss_gen_recon_f_b), \ + hyperparameters['recon_s_w'] * (self.loss_gen_recon_s_a + self.loss_gen_recon_s_b), \ + hyperparameters['recon_id_w'] * self.loss_gen_recon_id, \ + hyperparameters['id_w'] * self.loss_id, \ + hyperparameters['pid_w'] * self.loss_pid, \ + hyperparameters['teacher_w'] * self.loss_teacher)) + + def gen_update_aa(self, x_a, l_a, xp_a, x_b, l_b, xp_b, hyperparameters, iteration): + # ppa, ppb is the same person + self.gen_a_opt.zero_grad() + self.id_opt.zero_grad() + self.id_dis_opt.zero_grad() + # encode + s_a = self.gen_a.encode(self.single(x_a)) + s_b = self.gen_a.encode(self.single(x_b)) + f_a, p_a, _ = self.id_a(scale2(x_a)) + f_b, p_b, _ = self.id_a(scale2(x_b)) + # autodecode + x_a_recon = self.gen_a.decode(s_a, f_a) + x_b_recon = self.gen_a.decode(s_b, f_b) + + # encode the same ID different photo + fp_a, pp_a, _ = self.id_a(scale2(xp_a)) + fp_b, pp_b, _ = self.id_a(scale2(xp_b)) + + # decode the same person + x_a_recon_p = self.gen_a.decode(s_a, fp_a) + x_b_recon_p = self.gen_a.decode(s_b, fp_b) + + # has gradient + x_ba = self.gen_a.decode(s_b, f_a) + x_ab = self.gen_a.decode(s_a, f_b) + # no gradient + x_ba_copy = Variable(x_ba.data, requires_grad=False) + x_ab_copy = Variable(x_ab.data, requires_grad=False) + + rand_num = random.uniform(0, 1) + ################################# + # encode structure + if hyperparameters['use_encoder_again'] >= rand_num: + # 
encode again (encoder is tuned, input is fixed) + s_a_recon = self.gen_a.enc_content(self.single(x_ab_copy)) + s_b_recon = self.gen_a.enc_content(self.single(x_ba_copy)) + else: + # copy the encoder + self.enc_content_copy = copy.deepcopy(self.gen_a.enc_content) + self.enc_content_copy = self.enc_content_copy.eval() + # encode again (encoder is fixed, input is tuned) + s_a_recon = self.enc_content_copy(self.single(x_ab)) + s_b_recon = self.enc_content_copy(self.single(x_ba)) + + ################################# + # encode appearance + self.id_a_copy = copy.deepcopy(self.id_a) + self.id_a_copy = self.id_a_copy.eval() + if hyperparameters['train_bn']: + self.id_a_copy = self.id_a_copy.apply(train_bn) + self.id_b_copy = self.id_a_copy + # encode again (encoder is fixed, input is tuned) + f_a_recon, p_a_recon, _ = self.id_a_copy(scale2(x_ba)) + f_b_recon, p_b_recon, _ = self.id_b_copy(scale2(x_ab)) + + # teacher Loss + # Tune the ID model + log_sm = nn.LogSoftmax(dim=1) + if hyperparameters['teacher_w'] > 0 and hyperparameters['teacher'] != "": + if hyperparameters['ID_style'] == 'normal': + _, p_a_student, _ = self.id_a(scale2(x_ba_copy)) + p_a_student = log_sm(p_a_student) + p_a_teacher = predict_label(self.teacher_model, scale2(x_ba_copy)) + self.loss_teacher = self.criterion_teacher(p_a_student, p_a_teacher) / p_a_student.size(0) + + _, p_b_student, _ = self.id_a(scale2(x_ab_copy)) + p_b_student = log_sm(p_b_student) + p_b_teacher = predict_label(self.teacher_model, scale2(x_ab_copy)) + self.loss_teacher += self.criterion_teacher(p_b_student, p_b_teacher) / p_b_student.size(0) + elif hyperparameters['ID_style'] == 'AB': + # normal teacher-student loss + # BA -> LabelA(smooth) + LabelB(batchB) + _, p_ba_student, _ = self.id_a(scale2(x_ba_copy)) # f_a, s_b + p_a_student = log_sm(p_ba_student[0]) + with torch.no_grad(): + p_a_teacher = predict_label(self.teacher_model, scale2(x_ba_copy), + num_class=hyperparameters['ID_class_a'], alabel=l_a, slabel=l_b, + teacher_style=hyperparameters['teacher_style']) + p_a_teacher = torch.cat( + (p_a_teacher, torch.zeros((p_a_teacher.size(0), hyperparameters['ID_class_b'])).cuda()), + 1).detach() + self.loss_teacher = self.criterion_teacher(p_a_student, p_a_teacher) / p_a_student.size(0) + + _, p_ab_student, _ = self.id_a(scale2(x_ab_copy)) # f_b, s_a + p_b_student = log_sm(p_ab_student[0]) + with torch.no_grad(): + p_b_teacher = predict_label(self.teacher_model, scale2(x_ab_copy), + num_class=hyperparameters['ID_class_a'], alabel=l_b, slabel=l_a, + teacher_style=hyperparameters['teacher_style']) + p_b_teacher = torch.cat( + (p_b_teacher, torch.zeros((p_b_teacher.size(0), hyperparameters['ID_class_b'])).cuda()), + 1).detach() + self.loss_teacher += self.criterion_teacher(p_b_student, p_b_teacher) / p_b_student.size(0) + + # branch b loss + # here we give different label + loss_B = self.id_criterion(p_ba_student[1], l_b) + self.id_criterion(p_ab_student[1], l_a) + self.loss_teacher = hyperparameters['T_w'] * self.loss_teacher + hyperparameters['B_w'] * loss_B + else: + self.loss_teacher = 0.0 + + # decode again (if needed) + if hyperparameters['use_decoder_again']: + x_aba = self.gen_a.decode(s_a_recon, f_a_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None + x_bab = self.gen_a.decode(s_b_recon, f_b_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None + else: + self.mlp_w_copy = copy.deepcopy(self.gen_a.mlp_w) + self.mlp_b_copy = copy.deepcopy(self.gen_a.mlp_b) + self.dec_copy = copy.deepcopy(self.gen_a.dec) # Error + ID = f_a_recon + ID_Style = 
ID.view(ID.shape[0], ID.shape[1], 1, 1) + adain_params_w = self.mlp_w_copy(ID_Style) + adain_params_b = self.mlp_b_copy(ID_Style) + self.gen_a.assign_adain_params(adain_params_w, adain_params_b, self.dec_copy) + x_aba = self.dec_copy(s_a_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None + + ID = f_b_recon + ID_Style = ID.view(ID.shape[0], ID.shape[1], 1, 1) + adain_params_w = self.mlp_w_copy(ID_Style) + adain_params_b = self.mlp_b_copy(ID_Style) + self.gen_a.assign_adain_params(adain_params_w, adain_params_b, self.dec_copy) + x_bab = self.dec_copy(s_b_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None + + # auto-encoder image reconstruction + self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a) + self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b) + self.loss_gen_recon_xp_a = self.recon_criterion(x_a_recon_p, x_a) + self.loss_gen_recon_xp_b = self.recon_criterion(x_b_recon_p, x_b) + + # feature reconstruction + self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a) if hyperparameters['recon_s_w'] > 0 else 0 + self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b) if hyperparameters['recon_s_w'] > 0 else 0 + self.loss_gen_recon_f_a = self.recon_criterion(f_a_recon, f_a) if hyperparameters['recon_f_w'] > 0 else 0 + self.loss_gen_recon_f_b = self.recon_criterion(f_b_recon, f_b) if hyperparameters['recon_f_w'] > 0 else 0 + + # Random Erasing only effect the ID and PID loss. + if hyperparameters['erasing_p'] > 0: + x_a_re = self.to_re(scale2(x_a.clone())) + x_b_re = self.to_re(scale2(x_b.clone())) + xp_a_re = self.to_re(scale2(xp_a.clone())) + xp_b_re = self.to_re(scale2(xp_b.clone())) + _, p_a, _ = self.id_a(x_a_re) + _, p_b, _ = self.id_a(x_b_re) + # encode the same ID different photo + _, pp_a, _ = self.id_a(xp_a_re) + _, pp_b, _ = self.id_a(xp_b_re) + + # ID loss AND Tune the Generated image + weight_B = hyperparameters['teacher_w'] * hyperparameters['B_w'] + self.loss_id = self.id_criterion(p_a[0], l_a) + self.id_criterion(p_b[0], l_b) \ + + weight_B * (self.id_criterion(p_a[1], l_a) + self.id_criterion(p_b[1], l_b)) + self.loss_pid = self.id_criterion(pp_a[0], l_a) + self.id_criterion(pp_b[0], + l_b) # + weight_B * ( self.id_criterion(pp_a[1], l_a) + self.id_criterion(pp_b[1], l_b) ) + self.loss_gen_recon_id = self.id_criterion(p_a_recon[0], l_a) + self.id_criterion(p_b_recon[0], l_b) + + # print(f_a_recon, f_a) + self.loss_gen_cycrecon_x_a = self.recon_criterion(x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0 + self.loss_gen_cycrecon_x_b = self.recon_criterion(x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0 + # GAN loss + self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba) + self.loss_gen_adv_b = self.dis_a.calc_gen_loss(x_ab) + # domain-invariant perceptual loss + self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0 + self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0 + + # ID domain adversarial loss + self.loss_gen_id_adv = 0.0 + + if iteration > hyperparameters['warm_iter']: + hyperparameters['recon_f_w'] += hyperparameters['warm_scale'] + hyperparameters['recon_f_w'] = min(hyperparameters['recon_f_w'], hyperparameters['max_w']) + hyperparameters['recon_s_w'] += hyperparameters['warm_scale'] + hyperparameters['recon_s_w'] = min(hyperparameters['recon_s_w'], hyperparameters['max_w']) + hyperparameters['recon_x_cyc_w'] += hyperparameters['warm_scale'] + hyperparameters['recon_x_cyc_w'] = 
min(hyperparameters['recon_x_cyc_w'], hyperparameters['max_cyc_w']) + + if iteration > hyperparameters['warm_teacher_iter']: + hyperparameters['teacher_w'] += hyperparameters['warm_scale'] + hyperparameters['teacher_w'] = min(hyperparameters['teacher_w'], hyperparameters['max_teacher_w']) + # total loss + self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \ + hyperparameters['gan_w'] * self.loss_gen_adv_b + \ + hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \ + hyperparameters['recon_xp_w'] * self.loss_gen_recon_xp_a + \ + hyperparameters['recon_f_w'] * self.loss_gen_recon_f_a + \ + hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \ + hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \ + hyperparameters['recon_xp_w'] * self.loss_gen_recon_xp_b + \ + hyperparameters['recon_f_w'] * self.loss_gen_recon_f_b + \ + hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \ + hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \ + hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \ + hyperparameters['id_w'] * self.loss_id + \ + hyperparameters['pid_w'] * self.loss_pid + \ + hyperparameters['recon_id_w'] * self.loss_gen_recon_id + \ + hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \ + hyperparameters['vgg_w'] * self.loss_gen_vgg_b + \ + hyperparameters['teacher_w'] * self.loss_teacher + + self.loss_gen_total.backward() + self.gen_a_opt.step() + self.id_opt.step() + print( + "L_total: %.4f, L_gan: %.4f, L_adv: %.4f, Lx: %.4f, Lxp: %.4f, Lrecycle:%.4f, Lf: %.4f, Ls: %.4f, Recon-id: %.4f, id: %.4f, pid:%.4f, teacher: %.4f" % ( + self.loss_gen_total, \ + hyperparameters['gan_w'] * (self.loss_gen_adv_a + self.loss_gen_adv_b), \ + hyperparameters['id_adv_w'] * (self.loss_gen_id_adv), \ + hyperparameters['recon_x_w'] * (self.loss_gen_recon_x_a + self.loss_gen_recon_x_b), \ + hyperparameters['recon_xp_w'] * (self.loss_gen_recon_xp_a + self.loss_gen_recon_xp_b), \ + hyperparameters['recon_x_cyc_w'] * (self.loss_gen_cycrecon_x_a + self.loss_gen_cycrecon_x_b), \ + hyperparameters['recon_f_w'] * (self.loss_gen_recon_f_a + self.loss_gen_recon_f_b), \ + hyperparameters['recon_s_w'] * (self.loss_gen_recon_s_a + self.loss_gen_recon_s_b), \ + hyperparameters['recon_id_w'] * self.loss_gen_recon_id, \ + hyperparameters['id_w'] * self.loss_id, \ + hyperparameters['pid_w'] * self.loss_pid, \ + hyperparameters['teacher_w'] * self.loss_teacher)) + + def gen_update_bb(self, x_a, l_a, xp_a, x_b, l_b, xp_b, hyperparameters, iteration): + # ppa, ppb is the same person + self.gen_b_opt.zero_grad() + self.id_opt.zero_grad() + self.id_dis_opt.zero_grad() + # encode + s_a = self.gen_b.encode(self.single(x_a)) + s_b = self.gen_b.encode(self.single(x_b)) + f_a, p_a, fe_a = self.id_b(scale2(x_a)) + f_b, p_b, fe_b = self.id_b(scale2(x_b)) + # autodecode + x_a_recon = self.gen_b.decode(s_a, f_a) + x_b_recon = self.gen_b.decode(s_b, f_b) + + # encode the same ID different photo + fp_a, pp_a, fe_pa = self.id_b(scale2(xp_a)) + fp_b, pp_b, fe_pb = self.id_b(scale2(xp_b)) + + # decode the same person + x_a_recon_p = self.gen_b.decode(s_a, fp_a) + x_b_recon_p = self.gen_b.decode(s_b, fp_b) + + # has gradient + x_ba = self.gen_b.decode(s_b, f_a) + x_ab = self.gen_b.decode(s_a, f_b) + # no gradient + x_ba_copy = Variable(x_ba.data, requires_grad=False) + x_ab_copy = Variable(x_ab.data, requires_grad=False) + + rand_num = random.uniform(0, 1) + ################################# + # encode structure + if hyperparameters['use_encoder_again'] >= rand_num: + # encode again 
(encoder is tuned, input is fixed) + s_a_recon = self.gen_b.enc_content(self.single(x_ab_copy)) + s_b_recon = self.gen_b.enc_content(self.single(x_ba_copy)) + else: + # copy the encoder + self.enc_content_copy = copy.deepcopy(self.gen_b.enc_content) + self.enc_content_copy = self.enc_content_copy.eval() + # encode again (encoder is fixed, input is tuned) + s_a_recon = self.enc_content_copy(self.single(x_ab)) + s_b_recon = self.enc_content_copy(self.single(x_ba)) + + ################################# + # encode appearance + self.id_a_copy = copy.deepcopy(self.id_b) + self.id_a_copy = self.id_a_copy.eval() + if hyperparameters['train_bn']: + self.id_a_copy = self.id_a_copy.apply(train_bn) + self.id_b_copy = self.id_a_copy + # encode again (encoder is fixed, input is tuned) + f_a_recon, p_a_recon, _ = self.id_a_copy(scale2(x_ba)) + f_b_recon, p_b_recon, _ = self.id_b_copy(scale2(x_ab)) + + # teacher Loss + self.loss_teacher = 0.0 + + # decode again (if needed) + if hyperparameters['use_decoder_again']: + x_aba = self.gen_b.decode(s_a_recon, f_a_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None + x_bab = self.gen_b.decode(s_b_recon, f_b_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None + else: + self.mlp_w_copy = copy.deepcopy(self.gen_b.mlp_w) + self.mlp_b_copy = copy.deepcopy(self.gen_b.mlp_b) + self.dec_copy = copy.deepcopy(self.gen_b.dec) # Error + ID = f_a_recon + ID_Style = ID.view(ID.shape[0], ID.shape[1], 1, 1) + adain_params_w = self.mlp_w_copy(ID_Style) + adain_params_b = self.mlp_b_copy(ID_Style) + self.gen_b.assign_adain_params(adain_params_w, adain_params_b, self.dec_copy) + x_aba = self.dec_copy(s_a_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None + + ID = f_b_recon + ID_Style = ID.view(ID.shape[0], ID.shape[1], 1, 1) + adain_params_w = self.mlp_w_copy(ID_Style) + adain_params_b = self.mlp_b_copy(ID_Style) + self.gen_b.assign_adain_params(adain_params_w, adain_params_b, self.dec_copy) + x_bab = self.dec_copy(s_b_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None + + # auto-encoder image reconstruction + self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a) + self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b) + self.loss_gen_recon_xp_a = self.recon_criterion(x_a_recon_p, x_a) + self.loss_gen_recon_xp_b = self.recon_criterion(x_b_recon_p, x_b) + + # feature reconstruction + self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a) if hyperparameters['recon_s_w'] > 0 else 0 + self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b) if hyperparameters['recon_s_w'] > 0 else 0 + self.loss_gen_recon_f_a = self.recon_criterion(f_a_recon, f_a) if hyperparameters['recon_f_w'] > 0 else 0 + self.loss_gen_recon_f_b = self.recon_criterion(f_b_recon, f_b) if hyperparameters['recon_f_w'] > 0 else 0 + + # Random Erasing only effect the ID and PID loss. 
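+        # (the erased copies below are only used to recompute the classification predictions; the reconstruction losses keep the clean images)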
+ if hyperparameters['erasing_p'] > 0: + x_a_re = self.to_re(scale2(x_a.clone())) + x_b_re = self.to_re(scale2(x_b.clone())) + xp_a_re = self.to_re(scale2(xp_a.clone())) + xp_b_re = self.to_re(scale2(xp_b.clone())) + _, p_a, _ = self.id_b(x_a_re) + _, p_b, _ = self.id_b(x_b_re) + # encode the same ID different photo + _, pp_a, _ = self.id_b(xp_a_re) + _, pp_b, _ = self.id_b(xp_b_re) + + # ID loss AND Tune the Generated image + if hyperparameters['id_tgt']: + weight_B = hyperparameters['teacher_w'] * hyperparameters['B_w'] + self.loss_id = self.id_criterion(p_a[0], l_a) + self.id_criterion(p_b[0], l_b) \ + + weight_B * (self.id_criterion(p_a[1], l_a) + self.id_criterion(p_b[1], l_b)) + self.loss_pid = self.id_criterion(pp_a[0], l_a) + self.id_criterion(pp_b[0], + l_b) # + weight_B * ( self.id_criterion(pp_a[1], l_a) + self.id_criterion(pp_b[1], l_b) ) + self.loss_pid *= self.loss_pid*hyperparameters['tgt_pos'] + self.loss_gen_recon_id = self.id_criterion(p_a_recon[0], l_a) + self.id_criterion(p_b_recon[0], l_b) + else: + self.loss_id = 0.0 + self.loss_pid = 0.0 + self.loss_gen_recon_id = 0.0 + + # print(f_a_recon, f_a) + self.loss_gen_cycrecon_x_a = self.recon_criterion(x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0 + self.loss_gen_cycrecon_x_b = self.recon_criterion(x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0 + # GAN loss + self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba) + self.loss_gen_adv_b = self.dis_a.calc_gen_loss(x_ab) + # domain-invariant perceptual loss + self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0 + self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0 + + # ID domain adversarial loss + self.loss_gen_id_adv = ( self.id_dis.calc_gen_loss(fe_b) + self.id_dis.calc_gen_loss(fe_a) ) / 2 if hyperparameters['id_adv_w'] > 0 else 0 + + if iteration > hyperparameters['warm_iter']: + hyperparameters['recon_f_w'] += hyperparameters['warm_scale'] + hyperparameters['recon_f_w'] = min(hyperparameters['recon_f_w'], hyperparameters['max_w']) + hyperparameters['recon_s_w'] += hyperparameters['warm_scale'] + hyperparameters['recon_s_w'] = min(hyperparameters['recon_s_w'], hyperparameters['max_w']) + hyperparameters['recon_x_cyc_w'] += hyperparameters['warm_scale'] + hyperparameters['recon_x_cyc_w'] = min(hyperparameters['recon_x_cyc_w'], hyperparameters['max_cyc_w']) + + if iteration > hyperparameters['warm_teacher_iter']: + hyperparameters['teacher_w'] += hyperparameters['warm_scale'] + hyperparameters['teacher_w'] = min(hyperparameters['teacher_w'], hyperparameters['max_teacher_w']) + + hyperparameters['id_adv_w'] += hyperparameters['adv_warm_scale'] + hyperparameters['id_adv_w'] = min(hyperparameters['id_adv_w'], hyperparameters['id_adv_w_max']) + # total loss + self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \ + hyperparameters['gan_w'] * self.loss_gen_adv_b + \ + hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \ + hyperparameters['recon_xp_w'] * hyperparameters['recon_xp_tgt_w'] * self.loss_gen_recon_xp_a + \ + hyperparameters['recon_f_w'] * self.loss_gen_recon_f_a + \ + hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \ + hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \ + hyperparameters['recon_xp_w'] * hyperparameters['recon_xp_tgt_w'] * self.loss_gen_recon_xp_b + \ + hyperparameters['recon_f_w'] * self.loss_gen_recon_f_b + \ + hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \ + 
hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \ + hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \ + hyperparameters['id_w'] * self.loss_id + \ + hyperparameters['pid_w'] * self.loss_pid + \ + hyperparameters['recon_id_w'] * self.loss_gen_recon_id + \ + hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \ + hyperparameters['vgg_w'] * self.loss_gen_vgg_b + \ + hyperparameters['teacher_w'] * self.loss_teacher + \ + hyperparameters['id_adv_w'] * self.loss_gen_id_adv + self.loss_gen_total.backward() + self.gen_b_opt.step() + self.id_opt.step() + print( + "L_total: %.4f, L_gan: %.4f, L_adv: %.4f, Lx: %.4f, Lxp: %.4f, Lrecycle:%.4f, Lf: %.4f, Ls: %.4f, Recon-id: %.4f, id: %.4f, pid:%.4f, teacher: %.4f" % ( + self.loss_gen_total, \ + hyperparameters['gan_w'] * (self.loss_gen_adv_a + self.loss_gen_adv_b), \ + hyperparameters['id_adv_w'] * (self.loss_gen_id_adv), \ + hyperparameters['recon_x_w'] * (self.loss_gen_recon_x_a + self.loss_gen_recon_x_b), \ + hyperparameters['recon_xp_w'] * hyperparameters['recon_xp_tgt_w'] * (self.loss_gen_recon_xp_a + self.loss_gen_recon_xp_b), \ + hyperparameters['recon_x_cyc_w'] * (self.loss_gen_cycrecon_x_a + self.loss_gen_cycrecon_x_b), \ + hyperparameters['recon_f_w'] * (self.loss_gen_recon_f_a + self.loss_gen_recon_f_b), \ + hyperparameters['recon_s_w'] * (self.loss_gen_recon_s_a + self.loss_gen_recon_s_b), \ + hyperparameters['recon_id_w'] * self.loss_gen_recon_id, \ + hyperparameters['id_w'] * self.loss_id, \ + hyperparameters['pid_w'] * self.loss_pid, \ + hyperparameters['teacher_w'] * self.loss_teacher)) + + def compute_vgg_loss(self, vgg, img, target): + img_vgg = vgg_preprocess(img) + target_vgg = vgg_preprocess(target) + img_fea = vgg(img_vgg) + target_fea = vgg(target_vgg) + return torch.mean((self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2) + + def sample_ab(self, x_a, x_b): + self.eval() + x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2, x_aba, x_bab = [], [], [], [], [], [], [], [] + for i in range(x_a.size(0)): + s_a = self.gen_a.encode(self.single(x_a[i].unsqueeze(0))) + s_b = self.gen_b.encode(self.single(x_b[i].unsqueeze(0))) + f_a, _, _ = self.id_a(scale2(x_a[i].unsqueeze(0))) + f_b, _, _ = self.id_b(scale2(x_b[i].unsqueeze(0))) + x_a_recon.append(self.gen_a.decode(s_a, f_a)) + x_b_recon.append(self.gen_b.decode(s_b, f_b)) + x_ba = self.gen_b.decode(s_b, f_a) + x_ab = self.gen_a.decode(s_a, f_b) + x_ba1.append(x_ba) + x_ba2.append(self.gen_b.decode(s_b, f_a)) + x_ab1.append(x_ab) + x_ab2.append(self.gen_a.decode(s_a, f_b)) + # cycle + s_b_recon = self.gen_b.enc_content(self.single(x_ba)) + s_a_recon = self.gen_a.enc_content(self.single(x_ab)) + f_a_recon, _, _ = self.id_a(scale2(x_ba)) + f_b_recon, _, _ = self.id_b(scale2(x_ab)) + x_aba.append(self.gen_a.decode(s_a_recon, f_a_recon)) + x_bab.append(self.gen_b.decode(s_b_recon, f_b_recon)) + + x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon) + x_aba, x_bab = torch.cat(x_aba), torch.cat(x_bab) + x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2) + x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2) + self.train() + + return x_a, x_a_recon, x_aba, x_ab1, x_ab2, x_b, x_b_recon, x_bab, x_ba1, x_ba2 + + def sample_aa(self, x_a, x_b): + self.eval() + x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2, x_aba, x_bab = [], [], [], [], [], [], [], [] + for i in range(x_a.size(0)): + s_a = self.gen_a.encode(self.single(x_a[i].unsqueeze(0))) + s_b = self.gen_a.encode(self.single(x_b[i].unsqueeze(0))) + f_a, _, _ = 
self.id_a(scale2(x_a[i].unsqueeze(0))) + f_b, _, _ = self.id_a(scale2(x_b[i].unsqueeze(0))) + x_a_recon.append(self.gen_a.decode(s_a, f_a)) + x_b_recon.append(self.gen_a.decode(s_b, f_b)) + x_ba = self.gen_a.decode(s_b, f_a) + x_ab = self.gen_a.decode(s_a, f_b) + x_ba1.append(x_ba) + x_ba2.append(self.gen_a.decode(s_b, f_a)) + x_ab1.append(x_ab) + x_ab2.append(self.gen_a.decode(s_a, f_b)) + # cycle + s_b_recon = self.gen_a.enc_content(self.single(x_ba)) + s_a_recon = self.gen_a.enc_content(self.single(x_ab)) + f_a_recon, _, _ = self.id_a(scale2(x_ba)) + f_b_recon, _, _ = self.id_a(scale2(x_ab)) + x_aba.append(self.gen_a.decode(s_a_recon, f_a_recon)) + x_bab.append(self.gen_a.decode(s_b_recon, f_b_recon)) + + x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon) + x_aba, x_bab = torch.cat(x_aba), torch.cat(x_bab) + x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2) + x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2) + self.train() + + return x_a, x_a_recon, x_aba, x_ab1, x_ab2, x_b, x_b_recon, x_bab, x_ba1, x_ba2 + + def sample_bb(self, x_a, x_b): + self.eval() + x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2, x_aba, x_bab = [], [], [], [], [], [], [], [] + for i in range(x_a.size(0)): + s_a = self.gen_b.encode(self.single(x_a[i].unsqueeze(0))) + s_b = self.gen_b.encode(self.single(x_b[i].unsqueeze(0))) + f_a, _, _ = self.id_b(scale2(x_a[i].unsqueeze(0))) + f_b, _, _ = self.id_b(scale2(x_b[i].unsqueeze(0))) + x_a_recon.append(self.gen_b.decode(s_a, f_a)) + x_b_recon.append(self.gen_b.decode(s_b, f_b)) + x_ba = self.gen_b.decode(s_b, f_a) + x_ab = self.gen_b.decode(s_a, f_b) + x_ba1.append(x_ba) + x_ba2.append(self.gen_b.decode(s_b, f_a)) + x_ab1.append(x_ab) + x_ab2.append(self.gen_b.decode(s_a, f_b)) + # cycle + s_b_recon = self.gen_b.enc_content(self.single(x_ba)) + s_a_recon = self.gen_b.enc_content(self.single(x_ab)) + f_a_recon, _, _ = self.id_b(scale2(x_ba)) + f_b_recon, _, _ = self.id_b(scale2(x_ab)) + x_aba.append(self.gen_b.decode(s_a_recon, f_a_recon)) + x_bab.append(self.gen_b.decode(s_b_recon, f_b_recon)) + + x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon) + x_aba, x_bab = torch.cat(x_aba), torch.cat(x_bab) + x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2) + x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2) + self.train() + + return x_a, x_a_recon, x_aba, x_ab1, x_ab2, x_b, x_b_recon, x_bab, x_ba1, x_ba2 + + def dis_update_ab(self, x_a, x_b, hyperparameters): + self.dis_a_opt.zero_grad() + self.id_dis_opt.zero_grad() + # self.dis_b_opt.zero_grad() + # encode + # x_a_single = self.single(x_a) + s_a = self.gen_a.encode(self.single(x_a)) + s_b = self.gen_b.encode(self.single(x_b)) + f_a, _, fe_a = self.id_a(scale2(x_a)) + f_b, _, fe_b = self.id_b(scale2(x_b)) + # decode (cross domain) + x_ba = self.gen_b.decode(s_b, f_a) + x_ab = self.gen_a.decode(s_a, f_b) + # print(x_ab) + # D loss + self.loss_dis_a, reg_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a) + self.loss_dis_b, reg_b = self.dis_a.calc_dis_loss(x_ab.detach(), x_b) + self.loss_id_dis_ab, _, _ = self.id_dis.calc_dis_loss_ab(fe_a.detach(), fe_b.detach()) + self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + hyperparameters['gan_w'] * self.loss_dis_b + self.loss_id_dis_total = hyperparameters['id_adv_w'] * self.loss_id_dis_ab + print("DLoss: %.4f" % self.loss_dis_total, "Reg: %.4f" % (reg_a + reg_b), "ID_adv: %.4f" % self.loss_id_dis_total) + self.loss_dis_total.backward() + self.loss_id_dis_total.backward() + # check gradient norm + self.loss_total_norm = 0.0 + for p in 
self.id_dis.parameters(): + param_norm = p.grad.data.norm(2) + self.loss_total_norm += param_norm.item() ** 2 + self.loss_total_norm = self.loss_total_norm ** (1. / 2) + # + self.dis_a_opt.step() + self.id_dis_opt.step() + # self.dis_b_opt.step() + + def dis_update_aa(self, x_a, x_b, hyperparameters): + self.dis_a_opt.zero_grad() + self.id_dis_opt.zero_grad() + # encode + # x_a_single = self.single(x_a) + s_a = self.gen_a.encode(self.single(x_a)) + s_b = self.gen_a.encode(self.single(x_b)) + f_a, _, fe_a = self.id_a(scale2(x_a)) + f_b, _, fe_b = self.id_a(scale2(x_b)) + # decode (cross domain) + x_ba = self.gen_a.decode(s_b, f_a) + x_ab = self.gen_a.decode(s_a, f_b) + # print(x_ab) + # D loss + self.loss_dis_a, reg_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a) + self.loss_dis_b, reg_b = self.dis_a.calc_dis_loss(x_ab.detach(), x_b) + self.loss_id_dis_aa, _, _ = self.id_dis.calc_dis_loss_aa(fe_a.detach(), fe_b.detach()) + self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + hyperparameters['gan_w'] * self.loss_dis_b + self.loss_id_dis_total = hyperparameters['id_adv_w'] * self.loss_id_dis_aa + print("DLoss: %.4f" % self.loss_dis_total, "Reg: %.4f" % (reg_a + reg_b), "ID_adv: %.4f" % self.loss_id_dis_total) + self.loss_dis_total.backward() + self.loss_id_dis_total.backward() + # check gradient norm + self.loss_total_norm = 0.0 + for p in self.id_dis.parameters(): + param_norm = p.grad.data.norm(2) + self.loss_total_norm += param_norm.item() ** 2 + self.loss_total_norm = self.loss_total_norm ** (1. / 2) + # + self.dis_a_opt.step() + self.id_dis_opt.step() + + def dis_update_bb(self, x_a, x_b, hyperparameters): + self.dis_a_opt.zero_grad() + self.id_dis_opt.zero_grad() + # encode + # x_a_single = self.single(x_a) + s_a = self.gen_b.encode(self.single(x_a)) + s_b = self.gen_b.encode(self.single(x_b)) + f_a, _, fe_a = self.id_b(scale2(x_a)) + f_b, _, fe_b = self.id_b(scale2(x_b)) + # decode (cross domain) + x_ba = self.gen_b.decode(s_b, f_a) + x_ab = self.gen_b.decode(s_a, f_b) + # print(x_ab) + # D loss + self.loss_dis_a, reg_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a) + self.loss_dis_b, reg_b = self.dis_a.calc_dis_loss(x_ab.detach(), x_b) + self.loss_id_dis_bb, _, _ = self.id_dis.calc_dis_loss_bb(fe_a.detach(), fe_b.detach()) + self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + hyperparameters['gan_w'] * self.loss_dis_b + self.loss_id_dis_total = hyperparameters['id_adv_w'] * self.loss_id_dis_bb + print("DLoss: %.4f" % self.loss_dis_total, "Reg: %.4f" % (reg_a + reg_b), "ID_adv: %.4f" % self.loss_id_dis_total) + self.loss_dis_total.backward() + self.loss_id_dis_total.backward() + # check gradient norm + self.loss_total_norm = 0.0 + for p in self.id_dis.parameters(): + param_norm = p.grad.data.norm(2) + self.loss_total_norm += param_norm.item() ** 2 + self.loss_total_norm = self.loss_total_norm ** (1. 
/ 2) + # + self.dis_a_opt.step() + self.id_dis_opt.step() + + def update_learning_rate(self): + if self.dis_a_scheduler is not None: + self.dis_a_scheduler.step() + # if self.dis_b_scheduler is not None: + # self.dis_b_scheduler.step() + if self.gen_a_scheduler is not None: + self.gen_a_scheduler.step() + if self.gen_b_scheduler is not None: + self.gen_b_scheduler.step() + if self.id_scheduler is not None: + self.id_scheduler.step() + if self.id_dis_scheduler is not None: + self.id_dis_scheduler.step() + + def scale_learning_rate(self, lr_decayed, lr_recover, hyperparameters): + if not lr_decayed: + if lr_recover: + for g in self.dis_a_opt.param_groups: + g['lr'] *= hyperparameters['gamma'] + for g in self.gen_a_opt.param_groups: + g['lr'] *= hyperparameters['gamma'] + for g in self.gen_b_opt.param_groups: + g['lr'] *= hyperparameters['gamma'] + for g in self.id_opt.param_groups: + g['lr'] *= hyperparameters['gamma2'] + for g in self.id_dis_opt.param_groups: + g['lr'] *= hyperparameters['gamma2'] + elif not lr_recover: + for g in self.id_opt.param_groups: + g['lr'] = g['lr'] * hyperparameters['lr2_ramp_factor'] + elif lr_decayed: + for g in self.dis_a_opt.param_groups: + g['lr'] = g['lr'] / hyperparameters['gamma'] * hyperparameters['lr2_ramp_factor'] + for g in self.gen_a_opt.param_groups: + g['lr'] = g['lr'] / hyperparameters['gamma'] * hyperparameters['lr2_ramp_factor'] + for g in self.gen_b_opt.param_groups: + g['lr'] = g['lr'] / hyperparameters['gamma'] * hyperparameters['lr2_ramp_factor'] + for g in self.id_opt.param_groups: + g['lr'] = g['lr'] / hyperparameters['gamma2'] * hyperparameters['lr2_ramp_factor'] + for g in self.id_dis_opt.param_groups: + g['lr'] = g['lr'] / hyperparameters['gamma2'] * hyperparameters['lr2_ramp_factor'] + + def resume(self, checkpoint_dir, hyperparameters): + # Load generators + last_model_name = get_model_list(checkpoint_dir, "gen_a") + state_dict = torch.load(last_model_name) + self.gen_a.load_state_dict(state_dict['a']) + last_model_name = get_model_list(checkpoint_dir, "gen_b") + state_dict = torch.load(last_model_name) + self.gen_b.load_state_dict(state_dict['b']) + iterations = int(last_model_name[-11:-3]) + # Load discriminators + last_model_name = get_model_list(checkpoint_dir, "dis_a") + state_dict = torch.load(last_model_name) + self.dis_a.load_state_dict(state_dict['a']) + # last_model_name = get_model_list(checkpoint_dir, "dis_b") + # state_dict = torch.load(last_model_name) + self.dis_b = self.dis_a + # Load ID dis + last_model_name = get_model_list(checkpoint_dir, "id") + state_dict = torch.load(last_model_name) + self.id_a.load_state_dict(state_dict['a']) + self.id_b = self.id_a + # Load optimizers + try: + state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt')) + self.dis_a_opt.load_state_dict(state_dict['dis_a']) + # self.dis_b_opt.load_state_dict(state_dict['dis_b']) + self.gen_a_opt.load_state_dict(state_dict['gen_a']) + self.gen_b_opt.load_state_dict(state_dict['gen_b']) + self.id_opt.load_state_dict(state_dict['id']) + except: + pass + # Reinitilize schedulers + self.dis_a_scheduler = get_scheduler(self.dis_a_opt, hyperparameters, iterations) + # self.dis_b_scheduler = get_scheduler(self.dis_b_opt, hyperparameters, iterations) + self.gen_a_scheduler = get_scheduler(self.gen_a_opt, hyperparameters, iterations) + self.gen_b_scheduler = get_scheduler(self.gen_b_opt, hyperparameters, iterations) + self.id_scheduler = get_scheduler(self.id_opt, hyperparameters, iterations) + print('Resume from iteration %d' % iterations) + 
return iterations + + def resume_DAt0(self, checkpoint_dir): + # Load generators + last_model_name = get_model_list(checkpoint_dir, "gen") + # last_model_name = get_model_list('/home/yazou/Projects/DGNet-master/outputs/latest/checkpoints', "gen") + state_dict = torch.load(last_model_name) + self.gen_a.load_state_dict(state_dict['a'],strict=False) + # last_model_name = get_model_list(checkpoint_dir, "gen_b") + # state_dict = torch.load(last_model_name) + self.gen_b.load_state_dict(state_dict['a'],strict=False) + iterations = 0 + # Load discriminators + last_model_name = get_model_list(checkpoint_dir, "dis") + state_dict = torch.load(last_model_name) + self.dis_a.load_state_dict(state_dict['a'],strict=False) + # last_model_name = get_model_list(checkpoint_dir, "dis_b") + # state_dict = torch.load(last_model_name) + self.dis_b = self.dis_a + # Load ID dis + last_model_name = get_model_list(checkpoint_dir, "id") + state_dict = torch.load(last_model_name) + classifier1 = self.id_a.classifier1.classifier + classifier2 = self.id_a.classifier2.classifier + self.id_a.classifier1.classifier = nn.Sequential() + self.id_a.classifier2.classifier = nn.Sequential() + self.id_a.load_state_dict(state_dict['a'], strict=False) + self.id_a.classifier1.classifier = classifier1 + self.id_a.classifier2.classifier = classifier2 + self.id_b = self.id_a + print('Resume from iteration %d' % iterations) + #self.save(checkpoint_dir, 0) + return iterations + + def resume_DAt1(self, checkpoint_dir): + # Load generators + last_model_name = get_model_list(checkpoint_dir, "gen_a") + # last_model_name = get_model_list('/home/yazou/Projects/DGNet-master/outputs/latest/checkpoints', "gen") + state_dict = torch.load(last_model_name) + self.gen_a.load_state_dict(state_dict['a']) + # last_model_name = get_model_list(checkpoint_dir, "gen_b") + # state_dict = torch.load(last_model_name) + last_model_name = get_model_list(checkpoint_dir, "gen_b") + # last_model_name = get_model_list('/home/yazou/Projects/DGNet-master/outputs/latest/checkpoints', "gen") + state_dict = torch.load(last_model_name) + self.gen_b.load_state_dict(state_dict['b']) + iterations = 0 + # Load discriminators + last_model_name = get_model_list(checkpoint_dir, "dis_a") + state_dict = torch.load(last_model_name) + self.dis_a.load_state_dict(state_dict['a']) + # last_model_name = get_model_list(checkpoint_dir, "dis_b") + # state_dict = torch.load(last_model_name) + self.dis_b = self.dis_a + # Load ID dis + last_model_name = get_model_list(checkpoint_dir, "id") + state_dict = torch.load(last_model_name) + classifier1 = self.id_a.classifier1.classifier + classifier2 = self.id_a.classifier2.classifier + self.id_a.classifier1.classifier = nn.Sequential() + self.id_a.classifier2.classifier = nn.Sequential() + self.id_a.load_state_dict(state_dict['a'], strict=False) + self.id_a.classifier1.classifier = classifier1 + self.id_a.classifier2.classifier = classifier2 + self.id_b = self.id_a + print('Resume from iteration %d' % iterations) + #self.save(checkpoint_dir, 0) + return iterations + + def save(self, snapshot_dir, iterations): + # Save generators, discriminators, and optimizers + gen_a_name = os.path.join(snapshot_dir, 'gen_a_%08d.pt' % (iterations + 1)) + gen_b_name = os.path.join(snapshot_dir, 'gen_b_%08d.pt' % (iterations + 1)) + dis_a_name = os.path.join(snapshot_dir, 'dis_a_%08d.pt' % (iterations + 1)) + dis_b_name = os.path.join(snapshot_dir, 'dis_b_%08d.pt' % (iterations + 1)) + id_name = os.path.join(snapshot_dir, 'id_%08d.pt' % (iterations + 1)) + opt_name = 
os.path.join(snapshot_dir, 'optimizer.pt') + torch.save({'a': self.gen_a.state_dict()}, gen_a_name) + torch.save({'b': self.gen_b.state_dict()}, gen_b_name) + torch.save({'a': self.dis_a.state_dict()}, dis_a_name) + torch.save({'b': self.dis_b.state_dict()}, dis_b_name) + torch.save({'a': self.id_a.state_dict()}, id_name) + torch.save({'gen_a': self.gen_a_opt.state_dict(), 'gen_b': self.gen_b_opt.state_dict(), 'id': self.id_opt.state_dict(), 'dis_a': self.dis_a_opt.state_dict(), 'dis_b': self.dis_a_opt.state_dict()}, + opt_name) + + def test(self, opt): + self.eval() + test_dir = opt['data_root_b'] + data_transforms = torchvision.transforms.Compose([ + torchvision.transforms.Resize((256, 128), interpolation=3), + torchvision.transforms.ToTensor(), + torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + ############### Ten Crop + # transforms.TenCrop(224), + # transforms.Lambda(lambda crops: torch.stack( + # [transforms.ToTensor()(crop) + # for crop in crops] + # )), + # transforms.Lambda(lambda crops: torch.stack( + # [transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(crop) + # for crop in crops] + # )) + ]) + data_dir = test_dir + image_datasets = {x: torchvision.datasets.ImageFolder(os.path.join(data_dir, x), data_transforms) for x in + ['gallery', 'query']} + dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=opt['test_batchsize'], + shuffle=False, num_workers=16) for x in ['gallery', 'query']} + + gallery_path = image_datasets['gallery'].imgs + query_path = image_datasets['query'].imgs + + gallery_cam, gallery_label, _ = get_id(gallery_path, time_constraint = opt['time_constraint']) + query_cam, query_label, _ = get_id(query_path, time_constraint = opt['time_constraint']) + + ###################################################################### + # Load Collected data Trained model + print('-------test-----------') + + # Extract feature + with torch.no_grad(): + gallery_feature = self.extract_feature(dataloaders['gallery'], opt) + query_feature = self.extract_feature(dataloaders['query'], opt) + + gallery_label = np.array(gallery_label) + gallery_cam = np.array(gallery_cam) + query_label = np.array(query_label) + query_cam = np.array(query_cam) + alpha = [0, 0.5, -1] + mAP_alpha = [0]*3 + # print(query_label) + for j in range(len(alpha)): + CMC = torch.IntTensor(len(gallery_label)).zero_() + ap = 0.0 + for i in range(len(query_label)): + qf = query_feature[i].clone() + if alpha[j] == -1: + qf[0:512] *= 0 + else: + qf[512:1024] *= alpha[j] + + ap_tmp, CMC_tmp = evaluate(qf, query_label[i], query_cam[i], gallery_feature, gallery_label, + gallery_cam) + if CMC_tmp[0] == -1: + continue + CMC = CMC + CMC_tmp + ap += ap_tmp + # print(i, CMC_tmp[0]) + + CMC = CMC.float() + CMC = CMC / len(query_label) # average CMC + print('Alpha:%.2f Rank@1:%.4f Rank@5:%.4f Rank@10:%.4f mAP:%.4f' % ( + alpha[j], CMC[0], CMC[4], CMC[9], ap / len(query_label))) + mAP_alpha[j] = ap / len(query_label) + self.rank_1 = CMC[0] + self.rank_5 = CMC[4] + self.rank_10 = CMC[9] + self.mAP_zero = mAP_alpha[0] + self.mAP_half = mAP_alpha[1] + self.mAP_neg_one = mAP_alpha[2] + + del gallery_feature, query_feature, query_label + self.train() + + return + + def pseudo_label_generate(self, opt): + ### Feature extraction ### + self.eval() + test_dir = opt['data_root_b'] + data_transforms = torchvision.transforms.Compose([ + torchvision.transforms.Resize((256, 128), interpolation=3), + torchvision.transforms.ToTensor(), + 
torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + ]) + data_dir = test_dir + image_datasets = {x: torchvision.datasets.ImageFolder(os.path.join(data_dir, x), data_transforms) for x in ['train_all']} + dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=opt['test_batchsize'], + shuffle=False, num_workers=16) for x in ['train_all']} + train_path = image_datasets['train_all'].imgs + # Extract feature + with torch.no_grad(): + train_feature = self.extract_feature(dataloaders['train_all'], opt) + self.train() + + ### clustering ### + labels = self.clustering(train_feature, train_path, opt) + + ### copy and save images ### + n_samples = train_feature.shape[0] + opt['ID_class_b'] = int(max(labels)) + 1 + self.copy_save(labels, train_path, n_samples, opt) + return + + def clustering(self, train_feature, train_path, opt): + + ###################################################################### + alpha = 0.5 + + n_samples = train_feature.shape[0] + train_feature_clone = train_feature.clone() + train_feature_clone[:, 512:1024] *= alpha # since we count 0.5 for the fine-grained feature. 0.7*0.7=0.49 + train_dist = torch.mm(train_feature_clone, torch.transpose(train_feature, 0, 1)) / (1 + alpha) + print(train_dist) + + if opt['time_constraint']: + print('--------------------------Use Time Constraint---------------------------') + train_camera_id, train_time_id, train_labels = get_id(train_path, time_constraint = opt['time_constraint']) + train_time_id = np.asarray(train_time_id) + train_camera_id = np.asarray(train_camera_id) + + # Long Time + for i in range(n_samples): + t_time = train_time_id[i] + index = np.argwhere(np.absolute(train_time_id - t_time) > 40000).flatten() + train_dist[i, index] = -1 + print(len(index)) + + # Same Camera Long Time + for i in range(n_samples): + t_time = train_time_id[i] + t_cam = train_camera_id[i] + index = np.argwhere(np.absolute(train_time_id - t_time) > 5000).flatten() + c_index = np.argwhere(train_camera_id == t_cam).flatten() + index = np.intersect1d(index, c_index) + train_dist[i, index] = -1 + print(len(index)) + + print('--------------------------Start Re-ranking---------------------------') + train_dist = re_ranking_one(train_dist.cpu().numpy()) + print('--------------------------Clustering---------------------------') + # cluster + min_samples = opt['clustering']['min_samples'] + eps = opt['clustering']['eps'] + + cluster = DBSCAN(eps=eps, min_samples=min_samples, metric='precomputed', n_jobs=8) + ### non-negative clustering + train_dist = np.maximum(train_dist, 0) + ### + cluster = cluster.fit(train_dist) + print('Cluster Class Number: %d' % len(np.unique(cluster.labels_))) + # center = cluster.core_sample_indices_ + labels = cluster.labels_ + + return labels + + def copy_save(self, labels, train_path, n_samples, opt): + ### copy pseudo-labels in target ### + save_path = opt['data_root'] + '/train_all' + sample_b_valid = 0 + for i in range(n_samples): + if labels[i] != -1: + src_path = train_path[i][0] + dst_id = labels[i] + dst_path = save_path + '/' + 'B_' + str(int(dst_id)) + if not os.path.isdir(dst_path): + os.mkdir(dst_path) + copyfile(src_path, dst_path + '/' + os.path.basename(src_path)) + sample_b_valid += 1 + + opt['sample_b'] = sample_b_valid + + ### copy ground truth in source ### + # train_all + src_all_path = opt['data_root_a'] + # for dukemtmc-reid, we do not need multi-query + src_train_all_path = os.path.join(src_all_path, 'train_all') + subfolder_list = os.listdir(src_train_all_path) 
+ file_list = [] + for path, subdirs, files in os.walk(src_train_all_path): + for name in files: + file_list.append(os.path.join(path, name)) + opt['ID_class_a'] = len(subfolder_list) + opt['sample_a'] = len(file_list) + for name in subfolder_list: + copytree(src_train_all_path + '/' + name, save_path + '/A_' + name) + + return + + + def extract_feature(self, dataloaders, opt): + model = copy.deepcopy(self.id_a) + if opt['train_bn']: + model = model.apply(train_bn) + # Remove the final fc layer and classifier layer + model.model.fc = nn.Sequential() + model.classifier1.classifier = nn.Sequential() + model.classifier2.classifier = nn.Sequential() + model.eval() + features = torch.FloatTensor() + count = 0 + for data in dataloaders: + img, label = data + img, label = img.cuda().detach(), label.cuda().detach() + n, c, h, w = img.size() + count += n + #print(count) + ff = torch.FloatTensor(n,1024).zero_() + for i in range(2): + if(i==1): + img = fliplr(img) + input_img = Variable(img) + f, x, _ = model(input_img) + x[0] = norm(x[0]) + x[1] = norm(x[1]) + f = torch.cat((x[0],x[1]), dim=1) #use 512-dim feature + f = f.data.cpu() + ff = ff+f + + ff[:, 0:512] = norm(ff[:, 0:512], dim=1) + ff[:, 512:1024] = norm(ff[:, 512:1024], dim =1) + features = torch.cat((features,ff), 0) + del model + return features + diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..5f56b6c --- /dev/null +++ b/utils.py @@ -0,0 +1,461 @@ +""" +Copyright (C) 2019 NVIDIA Corporation. All rights reserved. +Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). +""" + +from torch.utils.data import DataLoader, Subset +from torch.autograd import Variable +from torch.optim import lr_scheduler +from torchvision import transforms +from data import ImageFilelist +from torchvision.datasets import ImageFolder +from reIDfolder import ReIDFolder, ReIDFolder_mix +import torch +import os +import math +import torchvision.utils as vutils +import yaml +import numpy as np +from random import shuffle +import torch.nn.init as init +import time +from operator import itemgetter +from shutil import rmtree +# Methods +# get_all_data_loaders : primary data loader interface (load trainA, testA, trainB, testB) +# get_data_loader_list : list-based data loader +# get_data_loader_folder : folder-based data loader +# get_config : load yaml file +# eformat : +# write_2images : save output image +# prepare_sub_folder : create checkpoints and images folders for saving outputs +# prepare_sub_folder_pseudo : create checkpoints, images and pseudo labels folders for saving outputs +# get_mix_data_loaders +# write_one_row_html : write one row of the html file for output images +# write_html : create the html file. 
+# write_loss +# slerp +# get_slerp_interp +# get_model_list +# load_vgg16 +# vgg_preprocess +# get_scheduler +# weights_init + +def get_mix_data_loaders(conf): + batch_size = conf['batch_size'] + num_workers = conf['num_workers'] + if 'new_size' in conf: + new_size_a= conf['new_size'] + new_size_b = conf['new_size'] + else: + new_size_a = conf['new_size_a'] + new_size_b = conf['new_size_b'] + height = conf['crop_image_height'] + width = conf['crop_image_width'] + + # generate the list of the mixed data folder + train_path = conf['data_root'] + mixData = ImageFolder(train_path) + ab_list = mixData.imgs + ab_idx = [i for i in range(len(ab_list))] + size_a = conf['sample_a'] + size_b = conf['sample_b'] + # full lists of two datasets + a_full_idx = ab_idx[0:size_a] + b_full_idx = ab_idx[size_a:] + # generate two sample lists of two datasets with equal size + if size_a > size_b: + sel_idx = list(np.random.choice(size_a, size_b, replace=False)) + a_idx = list(itemgetter(*sel_idx)(a_full_idx)) + b_idx = b_full_idx.copy() + elif size_b > size_a: + sel_idx = list(np.random.choice(size_b, size_a, replace=False)) + b_idx = list(itemgetter(*sel_idx)(b_full_idx)) + a_idx = a_full_idx.copy() + else: + a_idx = a_full_idx.copy() + b_idx = b_full_idx.copy() + + a_idx_a = a_idx.copy() + a_idx_b = a_idx.copy() + b_idx_a = b_idx.copy() + b_idx_b = b_idx.copy() + + # generate two lists for train_loader_a and train_loader_b + ab_port = conf['ab_port'] + bs = conf['batch_size'] + size_domain = min(size_a, size_b) + ab_num = math.floor(ab_port * size_domain) // bs * bs + xx_num = (size_domain - ab_num) // bs * bs + idx_la = [] # list for loader_a + idx_lb = [] # list for loader_b + + sel_idx_ab_a = list(np.random.choice(size_domain, ab_num, replace=False)) + sel_idx_ab_b = list(np.random.choice(size_domain, ab_num, replace=False)) + sel_idx_ba_a = list(np.random.choice(size_domain, ab_num, replace=False)) + sel_idx_ba_b = list(np.random.choice(size_domain, ab_num, replace=False)) + + aa_idx_a = [a_idx_a[i] for i in range(size_domain) if i not in sel_idx_ab_a] # batch aa for train_loader_a + aa_idx_b = [a_idx_b[i] for i in range(size_domain) if i not in sel_idx_ba_b] # batch aa for train_loader_b + bb_idx_a = [b_idx_a[i] for i in range(size_domain) if i not in sel_idx_ba_a] # batch bb for train_loader_a + bb_idx_b = [b_idx_b[i] for i in range(size_domain) if i not in sel_idx_ab_b] # batch bb for train_loader_b + shuffle(aa_idx_a) + shuffle(aa_idx_b) + shuffle(bb_idx_a) + shuffle(bb_idx_b) + aa_idx_a = aa_idx_a[:xx_num] + aa_idx_b = aa_idx_b[:xx_num] + bb_idx_a = bb_idx_a[:xx_num] + bb_idx_b = bb_idx_b[:xx_num] + ab_idx_a, ab_idx_b, ba_idx_a, ba_idx_b = [], [], [], [] + if sel_idx_ab_a != []: + ab_idx_a = list(itemgetter(*sel_idx_ab_a)(a_idx_a)) # batch ab for train_loader_a + if sel_idx_ab_b != []: + ab_idx_b = list(itemgetter(*sel_idx_ab_b)(b_idx_b)) # batch ab for train_loader_b + if sel_idx_ba_a != []: + ba_idx_a = list(itemgetter(*sel_idx_ba_a)(b_idx_a)) # batch ab for train_loader_a + if sel_idx_ba_b != []: + ba_idx_b = list(itemgetter(*sel_idx_ba_b)(a_idx_b)) # batch ab for train_loader_b + + aa_thresh = conf['xx_port'] / 2 + bb_thresh = aa_thresh * 2 + ab_thresh = bb_thresh + conf['ab_port'] / 2 + while aa_idx_b or bb_idx_a or ab_idx_a or ba_idx_a: + dice = np.random.uniform(0, 1) + if dice <= aa_thresh: + if not aa_idx_a: + continue + for _ in range(batch_size): + idx_la.append(aa_idx_a.pop()) + idx_lb.append(aa_idx_b.pop()) + elif dice > aa_thresh and dice <= bb_thresh: + if not bb_idx_a: + continue + 
for _ in range(batch_size): + idx_la.append(bb_idx_a.pop()) + idx_lb.append(bb_idx_b.pop()) + elif dice > bb_thresh and dice <= ab_thresh: + if not ab_idx_a: + continue + for _ in range(batch_size): + idx_la.append(ab_idx_a.pop()) + idx_lb.append(ab_idx_b.pop()) + else: + if not ba_idx_a: + continue + for _ in range(batch_size): + idx_la.append(ba_idx_a.pop()) + idx_lb.append(ba_idx_b.pop()) + + train_loader_a = get_data_loader_folder_mix(os.path.join(conf['data_root'], 'train_all'), idx_la, batch_size, True, + new_size_a, height, width, num_workers, True) + test_loader_a = get_data_loader_folder(os.path.join(conf['data_root_a'], 'query'), batch_size, False, + new_size_a, height, width, num_workers, False) + train_loader_b = get_data_loader_folder_mix(os.path.join(conf['data_root'], 'train_all'), idx_lb, batch_size, True, + new_size_b, height, width, num_workers, True) + test_loader_b = get_data_loader_folder(os.path.join(conf['data_root_b'], 'query'), batch_size, False, + new_size_b, height, width, num_workers, False) + + return train_loader_a, train_loader_b, test_loader_a, test_loader_b + +def get_all_data_loaders(conf): + batch_size = conf['batch_size'] + num_workers = conf['num_workers'] + if 'new_size' in conf: + new_size_a= conf['new_size'] + new_size_b = conf['new_size'] + else: + new_size_a = conf['new_size_a'] + new_size_b = conf['new_size_b'] + height = conf['crop_image_height'] + width = conf['crop_image_width'] + + if 'data_root' in conf: + train_loader_a = get_data_loader_folder(os.path.join(conf['data_root'], 'train_all'), batch_size, True, + new_size_a, height, width, num_workers, True) + test_loader_a = get_data_loader_folder(os.path.join(conf['data_root'], 'query'), batch_size, False, + new_size_a, height, width, num_workers, False) + train_loader_b = get_data_loader_folder(os.path.join(conf['data_root'], 'train_all'), batch_size, True, + new_size_b, height, width, num_workers, True) + test_loader_b = get_data_loader_folder(os.path.join(conf['data_root'], 'query'), batch_size, False, + new_size_b, height, width, num_workers, False) + else: + train_loader_a = get_data_loader_list(conf['data_folder_train_a'], conf['data_list_train_a'], batch_size, True, + new_size_a, height, width, num_workers, True) + test_loader_a = get_data_loader_list(conf['data_folder_test_a'], conf['data_list_test_a'], batch_size, False, + new_size_a, height, width, num_workers, False) + train_loader_b = get_data_loader_list(conf['data_folder_train_b'], conf['data_list_train_b'], batch_size, True, + new_size_b, height, width, num_workers, True) + test_loader_b = get_data_loader_list(conf['data_folder_test_b'], conf['data_list_test_b'], batch_size, False, + new_size_b, height, width, num_workers, False) + return train_loader_a, train_loader_b, test_loader_a, test_loader_b + + +def get_data_loader_list(root, file_list, batch_size, train, new_size=None, + height=256, width=128, num_workers=4, crop=True): + transform_list = [transforms.ToTensor(), + transforms.Normalize((0.485, 0.456, 0.406), + (0.229, 0.224, 0.225))] + transform_list = [transforms.RandomCrop((height, width))] + transform_list if crop else transform_list + transform_list = [transforms.Pad(10, padding_mode='edge')] + transform_list if train else transform_list + transform_list = [transforms.Resize((height, width), interpolation=3)] + transform_list if new_size is not None else transform_list + transform_list = [transforms.RandomHorizontalFlip()] + transform_list if train else transform_list + transform = transforms.Compose(transform_list) 
+ dataset = ImageFilelist(root, file_list, transform=transform) + loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=num_workers) + return loader + +def get_data_loader_folder(input_folder, batch_size, train, new_size=None, + height=256, width=128, num_workers=4, crop=True): + transform_list = [transforms.ToTensor(), + transforms.Normalize((0.485, 0.456, 0.406), + (0.229, 0.224, 0.225))] + transform_list = [transforms.RandomCrop((height, width))] + transform_list if crop else transform_list + transform_list = [transforms.Pad(10, padding_mode='edge')] + transform_list if train else transform_list + transform_list = [transforms.Resize((height,width), interpolation=3)] + transform_list if new_size is not None else transform_list + transform_list = [transforms.RandomHorizontalFlip()] + transform_list if train else transform_list + transform = transforms.Compose(transform_list) + dataset = ReIDFolder(input_folder, transform=transform) + loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=num_workers) + return loader + +def get_data_loader_folder_mix(input_folder, idx_list, batch_size, train, new_size=None, + height=256, width=128, num_workers=4, crop=True): + transform_list = [transforms.ToTensor(), + transforms.Normalize((0.485, 0.456, 0.406), + (0.229, 0.224, 0.225))] + transform_list = [transforms.RandomCrop((height, width))] + transform_list if crop else transform_list + transform_list = [transforms.Pad(10, padding_mode='edge')] + transform_list if train else transform_list + transform_list = [transforms.Resize((height,width), interpolation=3)] + transform_list if new_size is not None else transform_list + transform_list = [transforms.RandomHorizontalFlip()] + transform_list if train else transform_list + transform = transforms.Compose(transform_list) + dataset = ReIDFolder_mix(input_folder, transform=transform, idx_list=idx_list) + loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False, drop_last=True, num_workers=num_workers) + return loader + +def get_config(config): + with open(config, 'r') as stream: + return yaml.load(stream) + + +def eformat(f, prec): + s = "%.*e"%(prec, f) + mantissa, exp = s.split('e') + # add 1 to digits as 1 is taken by sign +/- + return "%se%d"%(mantissa, int(exp)) + + +def __write_images(image_outputs, display_image_num, file_name): + image_outputs = [images.expand(-1, 3, -1, -1) for images in image_outputs] # expand gray-scale images to 3 channels + image_tensor = torch.cat([images[:display_image_num] for images in image_outputs], 0) + image_grid = vutils.make_grid(image_tensor.data, nrow=display_image_num, padding=0, normalize=True, scale_each=True) + vutils.save_image(image_grid, file_name, nrow=1) + + +def write_2images(image_outputs, display_image_num, image_directory, postfix): + n = len(image_outputs) + __write_images(image_outputs[0:n//2], display_image_num, '%s/gen_a2b_%s.jpg' % (image_directory, postfix)) + __write_images(image_outputs[n//2:n], display_image_num, '%s/gen_b2a_%s.jpg' % (image_directory, postfix)) + +def vis_2images(image_outputs, display_image_num, image_directory, postfix): + __write_images(image_outputs, display_image_num, '%s/batch_%s.jpg' % (image_directory, postfix)) + +def prepare_sub_folder(output_directory): + image_directory = os.path.join(output_directory, 'images') + if not os.path.exists(image_directory): + print("Creating directory: {}".format(image_directory)) + os.makedirs(image_directory) + 
checkpoint_directory = os.path.join(output_directory, 'checkpoints') + if not os.path.exists(checkpoint_directory): + print("Creating directory: {}".format(checkpoint_directory)) + os.makedirs(checkpoint_directory) + return checkpoint_directory, image_directory + +def prepare_sub_folder_pseudo(output_directory): + image_directory = os.path.join(output_directory, 'images') + if not os.path.exists(image_directory): + print("Creating directory: {}".format(image_directory)) + os.makedirs(image_directory) + checkpoint_directory = os.path.join(output_directory, 'checkpoints') + if not os.path.exists(checkpoint_directory): + print("Creating directory: {}".format(checkpoint_directory)) + os.makedirs(checkpoint_directory) + pseudo_directory = os.path.join(output_directory, 'pseudo_train') + if not os.path.exists(pseudo_directory): + print("Creating directory: {}".format(pseudo_directory)) + os.makedirs(pseudo_directory) + os.makedirs(pseudo_directory + '/train_all') + else: + rmtree(pseudo_directory) + os.makedirs(pseudo_directory) + os.makedirs(pseudo_directory + '/train_all') + + return checkpoint_directory, image_directory, pseudo_directory + + +def write_one_row_html(html_file, iterations, img_filename, all_size): + html_file.write("
<h3>iteration [%d] (%s)</h3>" % (iterations,img_filename.split('/')[-1])) + html_file.write(""" + <p><a href="%s"> + <img src="%s" style="width:%dpx"> + </a><br> + <p> + """ % (img_filename, img_filename, all_size)) + return + + +def write_html(filename, iterations, image_save_iterations, image_directory, all_size=1536): + html_file = open(filename, "w") + html_file.write(''' + <!DOCTYPE html> + <html> + <head> + <title>Experiment name = %s</title> + <meta http-equiv="refresh" content="30"> + </head> + <body> + ''' % os.path.basename(filename)) + html_file.write("<h3>current</h3>
") + write_one_row_html(html_file, iterations, '%s/gen_a2b_train_current.jpg' % (image_directory), all_size) + write_one_row_html(html_file, iterations, '%s/gen_b2a_train_current.jpg' % (image_directory), all_size) + for j in range(iterations, image_save_iterations-1, -1): + if j % image_save_iterations == 0: + write_one_row_html(html_file, j, '%s/gen_a2b_test_%08d.jpg' % (image_directory, j), all_size) + write_one_row_html(html_file, j, '%s/gen_b2a_test_%08d.jpg' % (image_directory, j), all_size) + write_one_row_html(html_file, j, '%s/gen_a2b_train_%08d.jpg' % (image_directory, j), all_size) + write_one_row_html(html_file, j, '%s/gen_b2a_train_%08d.jpg' % (image_directory, j), all_size) + html_file.write("") + html_file.close() + + +def write_loss(iterations, trainer, train_writer): + members = [attr for attr in dir(trainer) \ + if not callable(getattr(trainer, attr)) and not attr.startswith("__") and ('loss' in attr or 'rank_' in attr or 'mAP_' in attr or 'grad' in attr or 'nwd' in attr)] + for m in members: + train_writer.add_scalar(m, getattr(trainer, m), iterations + 1) + + +def slerp(val, low, high): + """ + original: Animating Rotation with Quaternion Curves, Ken Shoemake + https://arxiv.org/abs/1609.04468 + Code: https://github.com/soumith/dcgan.torch/issues/14, Tom White + """ + omega = np.arccos(np.dot(low / np.linalg.norm(low), high / np.linalg.norm(high))) + so = np.sin(omega) + return np.sin((1.0 - val) * omega) / so * low + np.sin(val * omega) / so * high + + +def get_slerp_interp(nb_latents, nb_interp, z_dim): + """ + modified from: PyTorch inference for "Progressive Growing of GANs" with CelebA snapshot + https://github.com/ptrblck/prog_gans_pytorch_inference + """ + + latent_interps = np.empty(shape=(0, z_dim), dtype=np.float32) + for _ in range(nb_latents): + low = np.random.randn(z_dim) + high = np.random.randn(z_dim) # low + np.random.randn(512) * 0.7 + interp_vals = np.linspace(0, 1, num=nb_interp) + latent_interp = np.array([slerp(v, low, high) for v in interp_vals], + dtype=np.float32) + latent_interps = np.vstack((latent_interps, latent_interp)) + + return latent_interps[:, :, np.newaxis, np.newaxis] + + +# Get model list for resume +def get_model_list(dirname, key): + if os.path.exists(dirname) is False: + return None + gen_models = [os.path.join(dirname, f) for f in os.listdir(dirname) if + os.path.isfile(os.path.join(dirname, f)) and key in f and ".pt" in f] + if gen_models is None: + return None + gen_models.sort() + last_model_name = gen_models[-1] + return last_model_name + + +def load_vgg16(model_dir): + """ Use the model from https://github.com/abhiskk/fast-neural-style/blob/master/neural_style/utils.py """ + if not os.path.exists(model_dir): + os.mkdir(model_dir) + if not os.path.exists(os.path.join(model_dir, 'vgg16.weight')): + if not os.path.exists(os.path.join(model_dir, 'vgg16.t7')): + os.system('wget https://www.dropbox.com/s/76l3rt4kyi3s8x7/vgg16.t7?dl=1 -O ' + os.path.join(model_dir, 'vgg16.t7')) + vgglua = load_lua(os.path.join(model_dir, 'vgg16.t7')) + vgg = Vgg16() + for (src, dst) in zip(vgglua.parameters()[0], vgg.parameters()): + dst.data[:] = src + torch.save(vgg.state_dict(), os.path.join(model_dir, 'vgg16.weight')) + vgg = Vgg16() + vgg.load_state_dict(torch.load(os.path.join(model_dir, 'vgg16.weight'))) + return vgg + + +def vgg_preprocess(batch): + tensortype = type(batch.data) + (r, g, b) = torch.chunk(batch, 3, dim = 1) + batch = torch.cat((b, g, r), dim = 1) # convert RGB to BGR + batch = (batch + 1) * 255 * 0.5 # [-1, 1] -> [0, 255] + 
mean = tensortype(batch.data.size()) + mean[:, 0, :, :] = 103.939 + mean[:, 1, :, :] = 116.779 + mean[:, 2, :, :] = 123.680 + batch = batch.sub(Variable(mean)) # subtract mean + return batch + + +def get_scheduler(optimizer, hyperparameters, iterations=-1): + if 'lr_policy' not in hyperparameters or hyperparameters['lr_policy'] == 'constant': + scheduler = None # constant scheduler + elif hyperparameters['lr_policy'] == 'step': + scheduler = lr_scheduler.StepLR(optimizer, step_size=hyperparameters['step_size'], + gamma=hyperparameters['gamma'], last_epoch=iterations) + elif hyperparameters['lr_policy'] == 'multistep': + #50000 -- 75000 -- + step = hyperparameters['step_size'] + scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[step, step+step//2, step+step//2+step//4], + gamma=hyperparameters['gamma'], last_epoch=iterations) + else: + return NotImplementedError('learning rate policy [%s] is not implemented', hyperparameters['lr_policy']) + return scheduler + + +def weights_init(init_type='gaussian'): + def init_fun(m): + classname = m.__class__.__name__ + if (classname.find('Conv') == 0 or classname.find('Linear') == 0) and hasattr(m, 'weight'): + # print m.__class__.__name__ + if init_type == 'gaussian': + init.normal_(m.weight.data, 0.0, 0.02) + elif init_type == 'xavier': + init.xavier_normal_(m.weight.data, gain=math.sqrt(2)) + elif init_type == 'kaiming': + init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') + elif init_type == 'orthogonal': + init.orthogonal_(m.weight.data, gain=math.sqrt(2)) + elif init_type == 'default': + pass + else: + assert 0, "Unsupported initialization: {}".format(init_type) + if hasattr(m, 'bias') and m.bias is not None: + init.constant_(m.bias.data, 0.0) + + return init_fun + + +class Timer: + def __init__(self, msg): + self.msg = msg + self.start_time = None + + def __enter__(self): + self.start_time = time.time() + + def __exit__(self, exc_type, exc_value, exc_tb): + print(self.msg % (time.time() - self.start_time)) + +
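
For reference, below is a minimal, self-contained sketch of the pseudo-label clustering step implemented in `pseudo_label_generate()`/`clustering()` above. The helper name `cluster_pseudo_labels` and the default `eps`/`min_samples` values are illustrative assumptions, not taken from the repository's configs, and the time-constraint filtering and `re_ranking_one` steps of the actual pipeline are omitted here; it assumes each 512-dim half of the 1024-dim feature is L2-normalized, as produced by `extract_feature()`.

```python
# Sketch only: weighted two-part features -> cosine-style distance -> DBSCAN pseudo labels.
import numpy as np
import torch
from sklearn.cluster import DBSCAN

def cluster_pseudo_labels(features: torch.Tensor, alpha: float = 0.5,
                          eps: float = 0.5, min_samples: int = 4) -> np.ndarray:
    # features: (N, 1024); each 512-dim half assumed L2-normalized
    weighted = features.clone()
    weighted[:, 512:1024] *= alpha                   # down-weight the fine-grained half
    sim = weighted @ features.t() / (1 + alpha)      # weighted similarity, roughly in [-1, 1]
    dist = np.maximum(1.0 - sim.cpu().numpy(), 0.0)  # convert to a non-negative distance matrix
    labels = DBSCAN(eps=eps, min_samples=min_samples,
                    metric='precomputed', n_jobs=8).fit(dist).labels_
    return labels  # one pseudo identity per sample; -1 marks noise and is discarded upstream

# usage sketch: feats = trainer.extract_feature(loader, opt); labels = cluster_pseudo_labels(feats)
```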
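The batch-type proportions in `get_mix_data_loaders()` are easier to read with concrete numbers; the `xx_port`/`ab_port` values below are assumptions for illustration only, not the repository's configuration.

```python
# With these illustrative ports, the dice roll maps to batch types as:
# dice in (0, 0.25] -> (a, a) same-domain pair, (0.25, 0.5] -> (b, b),
# (0.5, 0.75] -> (a, b) cross-domain pair, else -> (b, a).
xx_port, ab_port = 0.5, 0.5          # assumed config values
aa_thresh = xx_port / 2              # 0.25
bb_thresh = aa_thresh * 2            # 0.50
ab_thresh = bb_thresh + ab_port / 2  # 0.75
print(aa_thresh, bb_thresh, ab_thresh)  # 0.25 0.5 0.75
```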