Code for Semantically Robust Unpaired Image Translation for Data with Unmatched Semantics Statistics (SRUNIT), ICCV 2021
This is a PyTorch (re-)implementation of the method presented in the following paper:
Semantically Robust Unpaired Image Translation for Data with Unmatched Semantics Statistics
Zhiwei Jia, Bodi Yuan, Kangkang Wang, Hong Wu, David Clifford, Zhiqiang Yuan, Hao Su
UC San Diego and Google X
ICCV 2021
The code is largely adapted from the official PyTorch implementation of CUT (https://github.com/taesungp/contrastive-unpaired-translation). Specifically, SRUNIT deals with the common and detrimental semantic flipping issue in unpaired image translation by equipping CUT with a novel multi-scale semantic robustness regularization strategy. The key idea is to encourage a consistent translation, such that contents of the same semantics are not transformed into contents of several different semantics. The data flow is illustrated below.
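As a complement to the figure, here is a minimal code sketch of the core idea (not the repo's exact implementation; `generator_tail` and the uniform-noise scheme are illustrative assumptions): slightly perturb an intermediate feature and penalize how much the translated output changes relative to the perturbation.

```python
import torch

def sr_loss_sketch(generator_tail, feat, max_noise=1e-3):
    # generator_tail: placeholder for the layers of G that decode an
    # intermediate feature map `feat` into the translated image.
    noise = (torch.rand_like(feat) * 2 - 1) * max_noise  # uniform in [-max_noise, max_noise]
    out_clean = generator_tail(feat)
    out_perturbed = generator_tail(feat + noise)
    # A small feature perturbation should not flip the output semantics,
    # so we bound the output change relative to the noise magnitude.
    return (out_clean - out_perturbed).norm() / (noise.norm() + 1e-8)
```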
We provide two versions of the backward pass for the semantic robustness (SR) loss `self.compute_reg_loss()`: `v1`, where the SR loss updates G (the generator) but not F (the domain-invariant feature extractor in CUT), and `v2`, where it updates both. We find that `v2` consistently outperforms `v1` and use it as the default value for `--reg_type`.
```python
# Update G and F.
self.set_requires_grad(self.netD, False)
self.set_requires_grad(self.netF, True)
self.optimizer_G.zero_grad()
self.loss_G = self.compute_G_loss()
self.optimizer_F.zero_grad()
if self.opt.reg_type == 'v1':
    # v1: the SR loss updates G only; freeze F before computing it.
    self.loss_G.backward(retain_graph=True)
    self.set_requires_grad(self.netF, False)
    self.loss_reg = self.compute_reg_loss()
    loss = self.reg * self.loss_reg
elif self.opt.reg_type == 'v2':
    # v2 (default): the SR loss updates both G and F.
    self.loss_reg = self.compute_reg_loss()
    loss = self.loss_G + self.reg * self.loss_reg
loss.backward()
self.optimizer_G.step()
self.optimizer_F.step()
```
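Here `set_requires_grad` is the standard helper from the CUT/CycleGAN codebase that freezes or unfreezes a network's parameters; for reference, it looks roughly like:

```python
def set_requires_grad(self, nets, requires_grad=False):
    """Enable/disable gradients for all parameters of the given networks."""
    if not isinstance(nets, list):
        nets = [nets]
    for net in nets:
        if net is not None:
            for param in net.parameters():
                param.requires_grad = requires_grad
```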
To make SRUNIT's training speed comparable to CUT's, at each gradient update we sample a single scale at which to inject noise into the features, as shown below.
```python
# By default reg_layers = '0,1,2,3,4' (the multi-scale loss).
choices = [int(l) for l in self.opt.reg_layers.split(',')]
self.choice = np.random.choice(choices, 1)[0]
self.feats_perturbed, self.noise_magnitude = self.netG(
    layers=[self.nce_layers[self.choice]],
    feats=[self.feats_real_A[self.choice]],
    noises=[self.opt.reg_noise]
)
```
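Inside `self.netG`, the selected feature map is perturbed before being decoded. The exact scheme lives in the generator; a minimal sketch, assuming per-position uniform noise (`perturb_feature` is a hypothetical name), is:

```python
import torch

def perturb_feature(feat, max_noise):
    # feat: (N, C, H, W) intermediate feature map from the generator.
    # Draw one noise magnitude per spatial position in [0, max_noise],
    # then add uniform noise of that magnitude to the feature map. The
    # magnitudes are returned so the SR loss can normalize by them.
    n, _, h, w = feat.shape
    noise_mag = torch.rand(n, 1, h, w, device=feat.device) * max_noise
    noise = (torch.rand_like(feat) * 2.0 - 1.0) * noise_mag
    return feat + noise, noise_mag
```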
To further reduce the computational burden, each time we only sample feature vectors at 256 positions using the `sample_ids` generated by `PatchSampleF`, similar to the contrastive loss in CUT. The pseudo-code is shown below.
```python
def compute_reg_loss(self):
    # ... some steps omitted here (sample_ids come from PatchSampleF) ...
    # f_q, f_k are the sampled features from the original source images
    # and from the feature-perturbed ones, respectively.
    noise_mag = self.noise_magnitude[0]  # Only one scale at a time.
    noise_mag = (noise_mag.flatten(1, 3)[:, sample_ids]).flatten(0, 1)
    loss = euc_dis(f_q, f_k) / noise_mag  # euc_dis is a distance function.
    return loss.mean()
```
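The distance function `euc_dis` is left abstract above; assuming `f_q` and `f_k` are `(num_patches, C)` tensors of sampled feature vectors, a straightforward choice would be the per-vector Euclidean distance:

```python
def euc_dis(f_q, f_k, eps=1e-8):
    # L2 distance between corresponding feature vectors; eps keeps the
    # subsequent division by the noise magnitude numerically stable.
    return ((f_q - f_k) ** 2).sum(dim=1).add(eps).sqrt()
```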
While most of the hyper-parameters are inherited from CycleGAN and CUT, we have four hyper-parameters that can be domain-dependent and might need some tuning. We list them below with their default values shown in parentheses.

- `--reg_noise` (0.001): the max magnitude of the injected noise in computing the SR loss
- `--reg` (0.001): the coefficient of the SR loss
- `--reg_layers` ('0,1,2,3,4'): which layers/scales to use in the multi-scale SR loss
- `--inact_epochs` (100): the number of initial epochs during which the SR loss is inactive
Some notes:
- When tuning `reg`, setting it too large can make training unstable and crash near the end.
- Using all scales (i.e., `reg_layers='0,1,2,3,4'`) is usually a good strategy.
- We usually set `inact_epochs` to 1/4 of the total number of training epochs (which is 400 in the Label-to-Image example).
With the task Label-to-Image from Cityscapes as an example, the training script is

```bash
python train.py --dataroot=$DATA_FOLDER --preprocess=crop --n_epochs=200 --n_epochs_decay=200 \
    --reg_layers=0,1,2,3,4 --reg_noise=0.001 --reg=0.001 --inact_epochs=100 --name=$MODEL_NAME
```
The path `$DATA_FOLDER` should be structured such that `$DATA_FOLDER/trainA` and `$DATA_FOLDER/trainB` contain images from the source and the target domain, respectively.
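For instance, a valid layout looks like this (the `val*` splits are used by the evaluation script described below):

```
$DATA_FOLDER/
├── trainA/   # source-domain images (e.g., Cityscapes label maps)
├── trainB/   # target-domain images (e.g., Cityscapes photos)
├── valA/
└── valB/
```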
The model is trained for 200 + 200 = 400 epochs, and we set `inact_epochs` to 1/4 of that.
Notice that this implementation (as adapted from CUT) currently does not support multi-GPU training; the default batch size is 1.
The evaluation script is

```bash
python test.py --dataroot=$DATA_FOLDER --name=$MODEL_NAME --epoch=latest --preprocess=none \
    --phase=val --output_path=$OUTPUT_PATH
```
Similarly, there should be `valA` & `valB` or `testA` & `testB` (`*B` can be empty) under the directory `$DATA_FOLDER`, and the `--phase` flag controls which split is loaded for inference.
Here we show the reproduced results on the Label-to-Image task from the Cityscapes dataset. As mentioned in the paper, we sub-sample the images to create a statistical discrepancy between the source and the target domain, which is a natural setup in most real-world unpaired image translation tasks.
Specifically, we use K-means to generate two clusters of images based on their different semantic distributions (illustrated below).
We list the original filenames of the source and target images used in our setup in `examples/src_domain_paths.txt` and `examples/tar_domain_paths.txt`, respectively.
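The clustering script itself is not part of this repo; a minimal sketch of how such a split could be produced (assuming per-image semantic label maps and scikit-learn; `split_by_semantics` is a hypothetical helper) is:

```python
import numpy as np
from sklearn.cluster import KMeans

def split_by_semantics(label_maps, num_classes):
    # label_maps: list of (H, W) integer arrays of semantic class IDs.
    # Describe each image by its class-frequency histogram, then split
    # the dataset into two clusters with differing semantics statistics.
    hists = np.stack([
        np.bincount(lm.ravel(), minlength=num_classes) / lm.size
        for lm in label_maps
    ])
    cluster = KMeans(n_clusters=2, n_init=10, random_state=0).fit_predict(hists)
    return np.where(cluster == 0)[0], np.where(cluster == 1)[0]
```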
We train the previous state-of-the-art method CUT using the same code but setting `reg=0.0` (i.e., deactivating the SR loss).
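Concretely, assuming the same data folder, the baseline run only changes the SR coefficient (`$BASELINE_NAME` is a placeholder):

```bash
python train.py --dataroot=$DATA_FOLDER --preprocess=crop --n_epochs=200 --n_epochs_decay=200 \
    --reg=0.0 --name=$BASELINE_NAME
```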
The numerical results below demonstrate the clear advantage of SRUNIT over CUT.
|        | PixelAcc (%) | ClassAcc (%) | Mean IoU (%) |
|--------|--------------|--------------|--------------|
| CUT    | 74.39        | 30.61        | 23.86        |
| SRUNIT | 78.39        | 36.70        | 28.76        |
If you would like to cite our paper, please refer to
```bibtex
@inproceedings{jia2021semantically,
  title={Semantically Robust Unpaired Image Translation for Data with Unmatched Semantics Statistics},
  author={Jia, Zhiwei and Yuan, Bodi and Wang, Kangkang and Wu, Hong and Clifford, David and Yuan, Zhiqiang and Su, Hao},
  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
  pages={14273--14283},
  year={2021}
}
```