
TridentAdapt: Learning Domain-invariance via Source-Target Confrontation and Self-induced Cross-domain Augmentation (BMVC 2021, Official PyTorch implementation, Paper Page)

Abstract

Due to the difficulty of obtaining ground-truth labels, learning from virtual-world datasets is of great interest for real-world applications like semantic segmentation. From a domain adaptation perspective, the key challenge is to learn a domain-agnostic representation of the inputs in order to benefit from virtual data. In this paper, we propose a novel trident-like architecture that enforces a shared feature encoder to satisfy confrontational source and target constraints simultaneously, thus learning a domain-invariant feature space. Moreover, we introduce a novel training pipeline enabling self-induced cross-domain data augmentation during the forward pass, which further reduces the domain gap. Combined with a self-training process, we obtain state-of-the-art results on benchmark datasets (e.g. GTA5 or Synthia to Cityscapes adaptation).
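To make the trident idea concrete, the following is a minimal PyTorch sketch of the overall shape of such a model: one shared encoder whose features feed a source-constraint branch, a target-constraint branch, and a segmentation head. The module names, layer sizes, and branch outputs are illustrative assumptions, not the classes used in this repository.

import torch
import torch.nn as nn

# Conceptual sketch only: a shared feature encoder whose output must satisfy
# both a source-branch and a target-branch constraint while also serving a
# segmentation head. Layer sizes and module names are illustrative.
class TridentSketch(nn.Module):
    def __init__(self, num_classes=19):
        super().__init__()
        self.encoder = nn.Sequential(                          # shared feature encoder
            nn.Conv2d(3, 64, 3, padding=1), nn.ReLU(),
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(),
        )
        self.source_branch = nn.Conv2d(128, 3, 3, padding=1)   # source-domain constraint branch
        self.target_branch = nn.Conv2d(128, 3, 3, padding=1)   # target-domain constraint branch
        self.seg_head = nn.Conv2d(128, num_classes, 1)         # label predictor

    def forward(self, x):
        feat = self.encoder(x)                                 # ideally domain-invariant features
        return self.seg_head(feat), self.source_branch(feat), self.target_branch(feat)

if __name__ == "__main__":
    seg, src_out, tgt_out = TridentSketch()(torch.randn(1, 3, 256, 512))
    print(seg.shape, src_out.shape, tgt_out.shape)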

Demo Results

Datasets

The data folder /data should follow the original structure of each dataset (e.g. GTA5 --> Cityscapes):

|---data
    |--- Cityscapes
    |   |--- gtFine
    |   |--- leftImg8bit
    |--- GTA5
        |--- images
        |--- labels
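As a quick sanity check before training, a few lines of Python can confirm that the expected sub-folders are in place. The DATA_ROOT value and the script itself are only illustrative and are not part of the repository.

import os

# Illustrative layout check mirroring the tree above; not part of the repository.
DATA_ROOT = "/data"  # adjust to wherever the datasets actually live
for rel in ("Cityscapes/gtFine", "Cityscapes/leftImg8bit", "GTA5/images", "GTA5/labels"):
    path = os.path.join(DATA_ROOT, rel)
    print(("OK      " if os.path.isdir(path) else "MISSING ") + path)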

Usage

Run the training scripts, e.g. for GTA5 --> Cityscapes:

Stage 1 (run from the ./TridentAdapt_GTA5 folder):

python train_gta2city_stg1.py --gta5_data_path /data/GTA5 --city_data_path /data/Cityscapes

After stage 1, generate pseudo-labels and place them in /data/Cityscapes for stage 2 self-training:

python pseudolabel_generator.py ./weights --city_data_path ./data/Cityscapes
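For orientation, the snippet below sketches the common recipe behind this kind of pseudo-label generation: take the per-pixel argmax of the network's softmax output and mark low-confidence pixels with the ignore index. The function name and the 0.9 threshold are assumptions for illustration; the actual logic lives in pseudolabel_generator.py.

import torch

# Rough illustration of confidence-thresholded pseudo-labelling for self-training.
# The threshold value and function name are illustrative, not the repository's.
def make_pseudo_label(logits: torch.Tensor, threshold: float = 0.9) -> torch.Tensor:
    """logits: (N, C, H, W) raw segmentation outputs for target-domain images."""
    probs = torch.softmax(logits, dim=1)
    conf, label = probs.max(dim=1)        # per-pixel confidence and predicted class
    label[conf < threshold] = 255         # 255 = ignore index for uncertain pixels
    return label

if __name__ == "__main__":
    fake_logits = torch.randn(1, 19, 64, 128)   # 19 Cityscapes classes
    print(make_pseudo_label(fake_logits).shape)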

Stage 2:

python train_gta2city_stg2.py --gta5_data_path /data/GTA5 --city_data_path /data/Cityscapes

Run the evaluation script:

python evaluate_val.py ./weights --city_data_path /data/Cityscapes
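For reference, mIoU on this benchmark is typically computed from a per-class confusion matrix, as in the sketch below. evaluate_val.py implements the official evaluation; the helper here, including its name and signature, is only an assumption for illustration.

import numpy as np

# Generic mIoU from a confusion matrix; a reference sketch, not evaluate_val.py.
def mean_iou(preds, gts, num_classes=19, ignore_index=255):
    """preds, gts: iterables of equally shaped integer label arrays."""
    conf = np.zeros((num_classes, num_classes), dtype=np.int64)
    for p, g in zip(preds, gts):
        mask = g != ignore_index
        idx = num_classes * g[mask].astype(np.int64) + p[mask].astype(np.int64)
        conf += np.bincount(idx, minlength=num_classes ** 2).reshape(num_classes, num_classes)
    inter = np.diag(conf)
    union = conf.sum(axis=0) + conf.sum(axis=1) - inter
    return (inter / np.maximum(union, 1)).mean()

if __name__ == "__main__":
    pred = np.random.randint(0, 19, size=(64, 128))
    gt = np.random.randint(0, 19, size=(64, 128))
    print(mean_iou([pred], [gt]))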

Pretrained Weights

We provide our pretrained models and generated pseudo labels:

GTA5 Pseudo Labels,

GTA5 final model (peaked at 53.5 mIoU).

Synthia Pseudo Labels,

Synthia final model (54.4 mIoU on 13 classes).

Citation

If you find this work useful and would like to use our code or models in your research, please cite it as follows.

@article{shen2021tridentadapt,
  title={TridentAdapt: Learning Domain-invariance via Source-Target Confrontation and Self-induced Cross-domain Augmentation},
  author={Shen, Fengyi and Gurram, Akhil and Tuna, Ahmet Faruk and Urfalioglu, Onay and Knoll, Alois},
  journal={arXiv preprint arXiv:2111.15300},
  year={2021}
}

Acknowledgement

Our implementation is inspired by AdaptSeg and DISE.
