
Projected GANs for image-to-image translation? #71

Closed
hukkelas opened this issue Mar 23, 2022 · 13 comments

@hukkelas

Hi,

Are you familiar with any work that has applied projected GANs to image-to-image translation? I spent a couple of days trying to get projected GANs to work for image inpainting of human bodies, but I kept running into the discriminator learning to separate real from generated examples very early in training (often within less than 100k images).

I experimented with several methods to prevent this behavior:

  • With/without separable discriminator.
  • With/without heavy data augmentation for the discriminator.
  • Blurring the discriminator input images for the first 200k images (sketched at the end of this comment).
  • Changing the model size of the generator.

Note that the discriminator never observes the conditional information; I only feed it the generated/real RGB image.
Also, the discriminator follows the implementation in this repo.

I would appreciate any tips or pointers to related work that might be relevant for this use case.
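For reference, the blurring I mean is roughly the following (a minimal sketch using torchvision; the 200k fade length matches my runs, while the initial sigma of 2 is just an illustrative choice):

import torchvision.transforms.functional as TF

def blur_sigma(cur_nimg, blur_fade_kimg=200, blur_init_sigma=2.0):
    # Linearly fade the blur strength to zero over the first blur_fade_kimg * 1000 images.
    return max(1 - cur_nimg / (blur_fade_kimg * 1000), 0) * blur_init_sigma

def blur_discriminator_input(img, cur_nimg):
    # img: (B, 3, H, W) tensor in [-1, 1]; apply Gaussian blur only while sigma > 0.
    sigma = blur_sigma(cur_nimg)
    if sigma > 0:
        kernel_size = int(sigma * 6) | 1  # odd kernel covering roughly +/- 3 sigma
        img = TF.gaussian_blur(img, kernel_size=kernel_size, sigma=sigma)
    return img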

@xl-sr
Contributor

xl-sr commented Mar 24, 2022

Hi :)

Can you post some loss curves / logits etc? I assume the discriminator quickly overpowers the generator?
If you haven't already, give the patch discriminator a try (it's an argument of the discriminator); this should lead to more stable training.
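For context, "patch discriminator" here means a PatchGAN-style head that outputs a grid of per-patch logits instead of a single scalar. A minimal sketch of the idea (not this repo's exact implementation, which just exposes it as a flag on the discriminator):

import torch.nn as nn

class PatchHead(nn.Module):
    # Maps a feature map (B, C, H, W) to a (B, 1, H', W') grid of logits,
    # so each spatial location is judged real/fake independently.
    def __init__(self, channels):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(channels, 1, kernel_size=1),
        )

    def forward(self, feat):
        return self.net(feat)  # per-patch logits, averaged or summed in the loss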

@hukkelas
Author

Yes, the discriminator immediately overpowers the generator during training.
The image below shows the real/fake logits of the discriminator during training (the x-axis is in millions of images).
In the figure, you can see:

  • Purple: StyleGAN-like discriminator (trained with R1 regularization & epsilon penalty).
  • Orange: Separable discriminator.
  • Blue: Separable patch discriminator.
  • Green: Separable patch + blurring the first 200k images in training.
  • Red: Separable patch + blur200k + DiffAugment (cutout, color).
  • Brown: Same as red, but blurring for 600k images instead of 200k.

[figure: real/fake discriminator logits during training]

The projected discriminators perform surprisingly well in terms of FID, but the observed image quality is significantly poorer than the StyleGAN baseline. This seems to be better reflected in LPIPS.
[figure: FID and LPIPS curves]

I have not tested with a larger generator architecture yet. Currently the model is similar to a U-net with roughly 8M parameters, which is quite small compared to what you were testing with in the paper (IIRC).

Thanks for the rapid answers on GitHub, by the way :)

@xl-sr
Contributor

xl-sr commented Mar 25, 2022

Thanks for the rapid answers on GitHub, by the way :)

You're welcome :)

This does not seem like collapse, though; the losses are not directly comparable (PG uses a hinge loss, SG uses the non-saturating loss), so training seems stable. If I am not mistaken, you cut off the x-axis for the LPIPS plot, so it appears that PG is initially better than your baseline? It is also still improving, just more slowly.
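(For anyone comparing the raw curves: the two discriminator objectives look roughly like this, which is why the numbers are not on the same scale. A sketch, not the exact loss code of either repo.)

import torch.nn.functional as F

def d_hinge_loss(real_logits, fake_logits):
    # Projected GAN style: hinge loss on the discriminator logits.
    return (F.relu(1 - real_logits) + F.relu(1 + fake_logits)).mean()

def d_logistic_loss(real_logits, fake_logits):
    # StyleGAN style: logistic loss (paired with the non-saturating generator loss).
    return (F.softplus(-real_logits) + F.softplus(fake_logits)).mean()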

Are you using a dataset with many faces? This is currently a weakness of PG that I mentioned in other issues.
So, you can try plugging in a different feature network, e.g., CLIP or ResNet50-MoCo-v2. It might improve results and also helps make sure you're not spuriously improving FID; see this very cool recent study on this issue: https://arxiv.org/abs/2203.06026

@hukkelas
Author

This does not seem like collapse, though; the losses are not directly comparable (PG uses a hinge loss, SG uses the non-saturating loss), so training seems stable.

The first plot shows the raw logits of the discriminator (without the sigmoid for the StyleGAN D), not the loss. Thus, it should at least be comparable to the StyleGAN logits in terms of the gap between real and fake examples (though perhaps the scale of the logits will be different?). You're right that it's not collapsing, but from my previous experience with the non-saturating/Wasserstein losses I would have expected the gap to be slightly smaller. Then again, perhaps the poor sample quality is due to the feature network, not the training dynamics.

If I am not mistaken you cut off the x-axis for the LPIPS plot, so it appears that PG is initially better than your baseline?

Yeah, correct!

Are you using a dataset with many faces?

Yeah, I'm currently using a dataset of only human bodies (example illustration). I'll do some further ablations with CLIP/mocov2 resnet!

see this very cool recent study on this issue: https://arxiv.org/abs/2203.06026

Thanks for the reference, it seems very interesting and quite relevant to my dataset :)

@xl-sr
Contributor

xl-sr commented Mar 25, 2022

Then again, perhaps the poor sample quality is due to the feature network, not the training dynamics.

In my experience, the training dynamics are quite different from standard GAN training, so previous experiences might be misleading :)

I'll do some further ablations with CLIP/mocov2 resnet!

Awesome, keep me updated on the results! You could also take this further and pretrain a model on your specific dataset with, e.g., MoCo v2. That should definitely give you useful features for human body generation.

@hukkelas
Author

Thanks for the answers, will post an update later!

@hukkelas
Author

A quick question regarding normalization of input images: do you use the standard imagenet normalization for the feature network in D, or do you keep the stylegan normalization to the range [-1, 1]?

@xl-sr
Contributor

xl-sr commented Mar 25, 2022

tf_efficientnet_lite0 uses the SG normalization. If you use a different network, you need to adjust the normalization.
Here is a snippet that I use for these transformations; the imports are lists of the feature-network names. You can plug this in right before you feed data into the network.

import torch

# Lists of feature-network names, grouped by the input normalization each backbone expects.
from feature_networks.constants import NORMALIZED_0_5, NORMALIZED_IN, NORMALIZED_CLIP


def norm_with_stats(x, stats):
    # Remap x from the [-1, 1] (mean=0.5, std=0.5) convention to the backbone's stats:
    # for r in [0, 1] and x = 2r - 1, (r - mean) / std = x * (0.5 / std) + (0.5 - mean) / std.
    x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.5 / stats['std'][0]) + (0.5 - stats['mean'][0]) / stats['std'][0]
    x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.5 / stats['std'][1]) + (0.5 - stats['mean'][1]) / stats['std'][1]
    x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.5 / stats['std'][2]) + (0.5 - stats['mean'][2]) / stats['std'][2]
    x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
    return x


def get_backbone_normstats(backbone):
    if backbone in NORMALIZED_0_5:
        return None  # already matches the [-1, 1] input convention

    elif backbone in NORMALIZED_IN:
        return {
            'mean': [0.485, 0.456, 0.406],
            'std': [0.229, 0.224, 0.225],
        }

    elif backbone in NORMALIZED_CLIP:
        return {
            'mean': [0.48145466, 0.4578275, 0.40821073],
            'std': [0.26862954, 0.26130258, 0.27577711],
        }

    else:
        raise NotImplementedError(backbone)
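Plugging it in looks roughly like this (`generator_output`, `feature_network`, and the backbone name are placeholders, not names from this repo):

stats = get_backbone_normstats('resnet50_clip')  # placeholder backbone name
x = generator_output                             # generator output in [-1, 1]
if stats is not None:                            # None means the backbone expects [-1, 1] already
    x = norm_with_stats(x, stats)
features = feature_network(x)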

@hukkelas
Author

hukkelas commented Mar 25, 2022

Thanks! Do you keep the [-1, 1] normalization for the generator?

@xl-sr
Contributor

xl-sr commented Mar 25, 2022

yes :)

@xl-sr
Contributor

xl-sr commented Apr 5, 2022

gonna close this now, keep me updated on your results :)

@xl-sr xl-sr closed this as completed Apr 5, 2022
@hukkelas
Author

hukkelas commented Apr 7, 2022

Here is an update after a couple of days of experiments with projected GANs for image inpainting of human figures.

From my experience, projected GANs converge quickly; however, they are prone to mode collapse early in training, and some feature networks are more unstable to train than others. All models that I’ve trained have suffered from mode collapse, where the model generates deterministic completions for the same conditional input, or the diversity is limited to simple semantic changes (e.g., only changing the color of the clothes, not the general appearance). To mitigate this issue, I’ve experimented with blurring the discriminator images and turning the separable/patch discriminator on and off. Generally, blurring for the first N images (tested with 200k-1M images) seems to improve diversity somewhat, but it is still far from the diversity of the baseline.

This figure shows various experiments with a ViT discriminator, with separable/patch/blurring turned on/off. I've used the ViT model from "Masked Autoencoders Are Scalable Vision Learners". The figure includes the logits of the discriminator, FID-CLIP (from the paper you linked), LPIPS diversity, and LPIPS. You can observe that blurring improves diversity; however, too much (1M images) seemed to collapse training. The image quality of the model is quite similar to the baseline, but the diversity is significantly worse. I also noticed some surprising results in these runs, e.g., the ViT with only blurring trains fine, while adding the patch/separable options can collapse training. This might be randomness, though.

Note that some plots do not have the full graph of FID-CLIP, as I implemented this in the middle of training runs.
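For reference, the LPIPS diversity above is computed roughly like this (a sketch using the lpips package; the masking/cropping of the inpainted region is omitted):

import itertools
import torch
import lpips

lpips_fn = lpips.LPIPS(net='alex').eval()

@torch.no_grad()
def lpips_diversity(completions):
    # completions: list of images in [-1, 1], all generated from the same
    # conditional input but with different latent codes.
    dists = [lpips_fn(a, b).mean() for a, b in itertools.combinations(completions, 2)]
    return torch.stack(dists).mean()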

This figure shows different feature networks that I've tested. The models are:

  • EfficientNet, as implemented in your code.
  • RN50 CLIP: very unstable to train for most settings, but produced quite good results when training did not collapse.
  • RN50 DensePose CSE: an RN50 trained for dense pose estimation of humans (see here). This performs quite well in terms of diversity and quality.
  • RN50 SwAV and RN50 MoCo: these models are much worse in terms of diversity and quality compared to the baseline. I also tried to train my own self-supervised model (MoCo v2 with an RN50) on my dataset, with no luck.

From the experiments, I find that RN50 CLIP, RN50 DensePose, and the MAE ViT provide quite good features for human figure synthesis.

In summary, I find the results promising and will continue my experiments. The current issues are training instability and reduced diversity. I believe this comes down to instability early in training, where simply blurring the first iterations diminishes the issue with no observable cost to image quality. I'm happy to hear if you have any suggestions or ideas to combat these issues :)

@DRJYYDS

DRJYYDS commented Jun 30, 2022

Here is an update after a couple of days of experiments with projected GANs for image inpainting of human figures. [...]

You can probably try generating images at a smaller resolution first and then move to higher resolutions; this can help stabilize training. You could also try adding noise to the images to manually give the "real data distribution" and "fake data distribution" some overlap, a trick that has also been found useful for GANs and diffusion models (roughly as in the sketch below).
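A rough sketch of that noise trick (often called instance noise), with the noise level decayed over training; the schedule values are just examples:

import torch

def add_instance_noise(img, cur_nimg, fade_kimg=500, init_sigma=0.1):
    # Add the same Gaussian noise to both real and fake images before the
    # discriminator; decay it to zero over the first fade_kimg * 1000 images.
    sigma = max(1 - cur_nimg / (fade_kimg * 1000), 0) * init_sigma
    if sigma > 0:
        img = img + sigma * torch.randn_like(img)
    return img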

Hope this is helpful for you. I find what you're doing interesting :)
