Real-time Domain Adaptation in Semantic Segmentation (Course Project)
This repository provides a starter-code setup for the Real-time Domain Adaptation in Semantic Segmentation project of the Advanced Machine Learning course. (Presentation of the project)
- datasets: Contains the dataset classes for the Cityscapes and GTA datasets. The training and validation images for both datasets should also be placed here.
- model: Contains the STDCNet model and the Discriminator model.
- runs: Contains the TensorBoard logs for all the project steps.
- saved_models: Contains the saved models for all the project steps.
- STDCNET_weights: Contains the pre-trained weights for the STDCNet model.
- Download the pre-trained weights from this link and put them in the STDCNET_weights folder.
- Download the Cityscapes and GTA datasets from this link and put them in the datasets folder.
- 2.A: Train the STDCNet model on the Cityscapes dataset and evaluate it on the Cityscapes dataset.
- 2.B: Train the STDCNet model on the GTA dataset and evaluate it on the GTA dataset.
- 2.C.1: Evaluate the best model from step 2.B on the Cityscapes dataset.
- 2.C.2: Train the STDCNet model on the GTA augmented dataset and evaluate it on the Cityscapes dataset.
- 3: Train the STDCNet model with unsupervised adversarial domain adaptation, using labeled synthetic data (source: GTA dataset) and unlabeled real data (target: Cityscapes dataset).
- 4.A: Same setup as step 3 (source: GTA, target: Cityscapes), but with a depthwise discriminator.
- 4.B: Same setup as step 3 (source: GTA, target: Cityscapes), but with a diagonalwise discriminator.
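Steps 3, 4.A and 4.B share the same adversarial objective. The framework-free sketch below illustrates it under the assumption of an AdaptSegNet-style single-level setup: the discriminator receives the segmentation network's softmax output and returns per-patch probabilities that the input came from the target domain. The names `SOURCE_LABEL`, `TARGET_LABEL`, `LAMBDA_ADV` and the helper functions are hypothetical and are not taken from this repository.

```python
import math
from typing import Sequence

# Hypothetical domain labels (AdaptSegNet-style convention, assumed here).
SOURCE_LABEL = 0.0  # GTA: synthetic, labeled
TARGET_LABEL = 1.0  # Cityscapes: real, unlabeled
LAMBDA_ADV = 0.001  # hypothetical weight of the adversarial term


def bce(probs: Sequence[float], label: float, eps: float = 1e-7) -> float:
    """Mean binary cross-entropy between discriminator probabilities and a constant label."""
    total = 0.0
    for p in probs:
        p = min(max(p, eps), 1.0 - eps)  # clamp to avoid log(0)
        total += -(label * math.log(p) + (1.0 - label) * math.log(1.0 - p))
    return total / len(probs)


def segmentation_net_loss(seg_loss_source: float, d_out_target: Sequence[float]) -> float:
    """Supervised loss on source images plus an adversarial term that rewards
    target predictions the discriminator mistakes for source predictions
    (target outputs are scored against SOURCE_LABEL)."""
    return seg_loss_source + LAMBDA_ADV * bce(d_out_target, SOURCE_LABEL)


def discriminator_loss(d_out_source: Sequence[float], d_out_target: Sequence[float]) -> float:
    """The discriminator learns to label source predictions SOURCE_LABEL
    and target predictions TARGET_LABEL."""
    return bce(d_out_source, SOURCE_LABEL) + bce(d_out_target, TARGET_LABEL)


if __name__ == "__main__":
    # Toy per-patch probabilities of "this prediction came from Cityscapes".
    d_src = [0.2, 0.1]
    d_tgt = [0.9, 0.8]
    print("discriminator loss:", discriminator_loss(d_src, d_tgt))
    print("segmentation loss:", segmentation_net_loss(0.5, d_tgt))
```

In a real training loop the two losses are typically minimized in alternation: the segmentation network is updated while the discriminator's weights are frozen, then the discriminator is updated on detached predictions with its own learning rate (compare `--discriminator_learning_rate` in the commands).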
- 2.A:
--train_dataset Cityscapes --val_dataset Cityscapes --pretrain_path STDCNET_weights/STDCNet813M_73.91.tar --batch_size 8 --num_epochs 50 --learning_rate 0.01 --crop_height 512 --crop_width 1024 --tensorboard_path runs/2_A --save_model_path saved_models/2_A --optimizer sgd --loss crossentropy
- 2.B:
--train_dataset GTA --val_dataset GTA --pretrain_path STDCNET_weights/STDCNet813M_73.91.tar --batch_size 8 --num_epochs 50 --crop_height 512 --learning_rate 0.01 --crop_width 1024 --tensorboard_path runs/2_B --save_model_path saved_models/2_B --optimizer sgd --loss crossentropy
- 2.C.1:
--mode val --val_dataset Cityscapes --crop_height 512 --crop_width 1024 --save_model_path saved_models/2_B/best.pth
- 2.C.2:
--train_dataset GTA_aug --val_dataset Cityscapes --pretrain_path STDCNET_weights/STDCNet813M_73.91.tar --batch_size 8 --learning_rate 0.01 --num_epochs 50 --crop_height 512 --crop_width 1024 --tensorboard_path runs/2_C_2 --save_model_path saved_models/2_C_2 --optimizer sgd --loss crossentropy
- 3:
--mode train_adversarial --pretrain_path STDCNET_weights/STDCNet813M_73.91.tar --batch_size 8 --learning_rate 0.01 --discriminator_learning_rate 0.001 --num_epochs 50 --crop_height 512 --crop_width 1024 --tensorboard_path runs/3 --save_model_path saved_models/3
- 4.A:
--mode train_adversarial --depthwise_discriminator depthwise --pretrain_path STDCNET_weights/STDCNet813M_73.91.tar --batch_size 8 --learning_rate 0.01 --discriminator_learning_rate 0.001 --num_epochs 50 --crop_height 512 --crop_width 1024 --tensorboard_path runs/4_A --save_model_path saved_models/4_A
- 4.B:
--mode train_adversarial --depthwise_discriminator diagonalwise --pretrain_path STDCNET_weights/STDCNet813M_73.91.tar --batch_size 8 --learning_rate 0.01 --discriminator_learning_rate 0.001 --num_epochs 50 --crop_height 512 --crop_width 1024 --tensorboard_path runs/4_B --save_model_path saved_models/4_B
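Steps 4.A and 4.B differ from step 3 only in the `--depthwise_discriminator` flag. One rough way to see what a depthwise discriminator buys is to count parameters. The sketch below assumes an AdaptSegNet-style discriminator (five 4x4 convolutions with 64, 128, 256, 512 and 1 output channels, fed 19-class softmax maps) and assumes the depthwise variant replaces every convolution with a depthwise-separable one; these sizes are illustrative and not read from `model/`. Diagonalwise refactorization is, as we understand it, a different way of computing the same depthwise convolution (the depthwise kernels are rearranged into diagonal matrices so standard convolution routines can be used), so it should not change the parameter count.

```python
from typing import Callable, List

# Hypothetical discriminator layout (AdaptSegNet-style), not read from this repo.
NUM_CLASSES = 19  # input channels: softmax over 19 Cityscapes classes
KERNEL = 4
CHANNELS = [NUM_CLASSES, 64, 128, 256, 512, 1]


def standard_conv_params(c_in: int, c_out: int, k: int) -> int:
    """Weights + biases of a regular k x k convolution."""
    return k * k * c_in * c_out + c_out


def depthwise_separable_params(c_in: int, c_out: int, k: int) -> int:
    """Depthwise k x k conv (one filter per input channel, groups = c_in)
    followed by a 1 x 1 pointwise conv that mixes channels."""
    depthwise = k * k * c_in + c_in
    pointwise = c_in * c_out + c_out
    return depthwise + pointwise


def discriminator_params(count_fn: Callable[[int, int, int], int],
                         channels: List[int], k: int) -> int:
    """Sum the per-layer counts over consecutive (c_in, c_out) pairs."""
    return sum(count_fn(c_in, c_out, k) for c_in, c_out in zip(channels[:-1], channels[1:]))


if __name__ == "__main__":
    std = discriminator_params(standard_conv_params, CHANNELS, KERNEL)
    dws = discriminator_params(depthwise_separable_params, CHANNELS, KERNEL)
    print(f"standard: {std:,} params, depthwise separable: {dws:,} params ({std / dws:.1f}x fewer)")
```

Under these assumptions the script reports 2,781,121 parameters for the standard discriminator versus 191,364 for the depthwise-separable one, roughly 14.5 times fewer.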
Train Dataset | Validation Dataset | Accuracy (%) | mIoU (%) | Avg. Train Time per Epoch |
---|---|---|---|---|
Cityscapes | Cityscapes | 81 | 57.8 | 2:33 minutes |
GTA | GTA | 80.8 | 62.0 | 3:28 minutes |
GTA | Cityscapes | 60.1 | 24.6 | N/A (evaluation only) |
GTA augmented | Cityscapes | 70.2 | 30.7 | 5:22 minutes |
Single-layer DA (source: GTA, target: Cityscapes) | Cityscapes | 74.3 | 33.8 | 4:33 minutes |
Single-layer DA (source: GTA, target: Cityscapes), depthwise discriminator | Cityscapes | 73.1 | 32.7 | 4:32 minutes |
Single-layer DA (source: GTA, target: Cityscapes), diagonalwise discriminator | Cityscapes | 74.0 | 33.5 | 4:25 minutes |
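The Accuracy and mIoU columns are usually computed from a per-class confusion matrix accumulated over the validation set. The following is a minimal pure-Python sketch of the standard definitions, assuming pixels labeled 255 are ignored (the usual Cityscapes convention) and that classes absent from both labels and predictions are skipped in the mean; the repository's own evaluation code may differ in these details, and `IGNORE_INDEX` and the function names are hypothetical.

```python
from typing import List, Sequence

IGNORE_INDEX = 255  # assumed label for unlabeled pixels (Cityscapes convention)


def confusion_matrix(labels: Sequence[int], preds: Sequence[int], num_classes: int) -> List[List[int]]:
    """hist[t][p] counts pixels with ground truth t predicted as p; ignored pixels are skipped."""
    hist = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(labels, preds):
        if t == IGNORE_INDEX:
            continue
        hist[t][p] += 1
    return hist


def pixel_accuracy(hist: List[List[int]]) -> float:
    """Fraction of non-ignored pixels whose predicted class matches the label."""
    correct = sum(hist[i][i] for i in range(len(hist)))
    return correct / sum(map(sum, hist))


def mean_iou(hist: List[List[int]]) -> float:
    """Mean over classes of TP / (TP + FP + FN)."""
    n = len(hist)
    ious = []
    for c in range(n):
        tp = hist[c][c]
        fn = sum(hist[c]) - tp                       # row: labeled c, predicted otherwise
        fp = sum(hist[r][c] for r in range(n)) - tp  # column: predicted c, labeled otherwise
        union = tp + fp + fn
        if union > 0:  # skip classes absent from both labels and predictions
            ious.append(tp / union)
    return sum(ious) / len(ious)


if __name__ == "__main__":
    labels = [0, 0, 1, 1, 2, 255]  # flattened toy label map; 255 is ignored
    preds = [0, 1, 1, 1, 2, 0]
    hist = confusion_matrix(labels, preds, num_classes=3)
    print(pixel_accuracy(hist))         # 0.8
    print(round(mean_iou(hist), 4))     # 0.7222
```

Pixel accuracy is dominated by large, frequent classes, while mIoU weights every class equally; this helps explain why the GTA to Cityscapes row keeps 60.1% accuracy but only 24.6% mIoU.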