Question about implementation of PGI loss #1
Hi, in practice it's imposed per mini-batch (the relaxed version in Eq. 6 is what's implemented). If you performed full-batch gradient-descent updates, you would also use the full-batch penalty, but typically people use mini-batches for stochastic optimisation. I'm not sure why random shuffling would be particularly relevant for this penalty, which matches the class-conditioned, group-averaged softmaxes: it asks for cats from two different domains to be classified with equal recognition rates, and that requirement applies just as well to the sample of the dataset in a mini-batch. A mini-batch is just a smaller sample of data, and you apply a data-dependent penalty the same way you'd apply a data-dependent loss (like the cross-entropy loss) to a mini-batch.
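The per-mini-batch, class-conditioned matching described above can be sketched roughly as follows. This is a minimal numpy illustration of the idea, not the repo's TF1 code; the array names (`logits`, `labels`, `groups`) and the use of a summed KL term are illustrative assumptions.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def pgi_penalty(logits, labels, groups, eps=1e-8):
    """Hypothetical sketch of a relaxed per-mini-batch penalty (in the
    spirit of Eq. 6): for each class c present in both groups, match the
    group-averaged softmax distributions with a KL term."""
    probs = softmax(logits)
    penalty = 0.0
    for c in np.unique(labels):
        in_c = labels == c
        p_mask = in_c & (groups == 0)
        q_mask = in_c & (groups == 1)
        # The class must appear in both groups in this mini-batch,
        # otherwise there is nothing to match for it.
        if p_mask.any() and q_mask.any():
            p = probs[p_mask].mean(axis=0) + eps
            q = probs[q_mask].mean(axis=0) + eps
            penalty += float(np.sum(p * (np.log(p) - np.log(q))))
    return penalty
```

Note that the comparison is always cats-in-group-P versus cats-in-group-Q, never group P versus group Q marginally, so a batch with mostly cats in one domain and mostly dogs in the other contributes only through the class overlap.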
Thanks for your comment @Faruk-Ahmed! Sorry, I did not make it clear. By saying "it may or may not be true for each batch (due to random shuffling)", I meant: if we happen to have a batch in which the images in domain 1 are all (or mostly) cats while the images in domain 2 are all (or mostly) dogs, would it still make sense to force both domains to have equal probability distributions? I may have misunderstood something, but I'm not sure how to actually debug the TF1 code... Ah, I think the key is the class-conditioning (so we always compare distributions for cats vs. cats, not {all data for domain 1} vs. {all data for domain 2}), is this correct?
Yes, exactly: the softmaxes are always matched across domains for the same class. In Eqs. 5 and 6 the samples are drawn from P^c and Q^c, where c is an object category (like cats), not from the marginals P and Q (Eqs. 2 and 3). Sorry this wasn't clear!
Sounds great, thank you for the explanation!
Hi @Faruk-Ahmed, may I ask another question about an implementation detail? Section 3 of your paper says "encouraging matched predictive distributions across the groups with a fixed last layer pushes for over-emphasis on minority-group features in the representation", and I am wondering how this fixed last layer is reflected in the code. In this line: https://github.com/Faruk-Ahmed/syst-gen/blob/29d87ae70e608d0159364ee031b83281266c2a65/mnist/main.py#L214, if I understand correctly, a softmax is used to compute the PGI loss, so what would the last layer be in this case? Thanks!
It's done in L409-410; the last layer isn't optimised for the PGI penalty. |
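The "last layer isn't optimised for the PGI penalty" detail can be made concrete with a small sketch. This is a hypothetical numpy illustration of the update rule, not the repo's TF1 code (where the effect comes from which variables the penalty's optimiser is given); all names here are assumptions.

```python
import numpy as np

def apply_updates(theta_feats, W, grad_task_feats, grad_task_W,
                  grad_pen_feats, grad_pen_W, lr=0.1, lam=1.0):
    """Hypothetical sketch: the feature extractor receives gradients from
    both the task loss and the PGI penalty, while the last layer W is
    updated by the task loss alone -- grad_pen_W is deliberately ignored,
    which is what 'fixing' the last layer for the penalty means here."""
    theta_feats = theta_feats - lr * (grad_task_feats + lam * grad_pen_feats)
    W = W - lr * grad_task_W   # penalty gradient never reaches W
    return theta_feats, W
```

In an autodiff framework this is typically achieved either by stopping gradients through the last layer in the penalty branch, or by excluding the last layer's variables from the penalty's variable list, as in the lines referenced above.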
@Faruk-Ahmed That makes sense, thanks! What is the benefit of optimising PGI with a fixed last layer versus a flexible last layer? Is it because we want to do something similar to IRM, where we differentiate between "representations" (which do not include the last layer) and "last-layer classifiers" (which should be optimal across all environments)? Or is it due to some other implementation consideration?
Initially, I just wanted to use the last layer as the "discriminator of interest" instead of struggling with GAN training, but there was also the intuition that not freezing the last layer might allow two different parts of the predictor layer to specialise to different sets of features, and I wanted to promote learning features that would be used in the "same way" by a predictor (with the models in the domain-invariance literature in mind, including IRM). It's not clear that this choice would always lead to the desired behaviour, but empirically, freezing the layer worked better than not freezing it across all datasets in my early experiments, so I went with it.
@Faruk-Ahmed Yeah this is really helpful, thank you so much for the explanation! |
Glad to help! |
Hi authors,
Thanks for this great work!
I read through your implementation, but I am not sure I understand all the details correctly. In the PGI loss code here: https://github.com/Faruk-Ahmed/syst-gen/blob/29d87ae70e608d0159364ee031b83281266c2a65/mnist/main.py#L199-L219, do you compute this loss over mini-batches or over the entire training set? If I understand correctly, you are optimising it over mini-batches; if so, I wonder whether Eq. 4/5 in the paper still holds. I can imagine that Eq. 4/5 would hold over the entire dataset if we find an invariant representation across both groups P and Q, but it may or may not hold for each batch (due to random shuffling). If you could share your thoughts, that would be great, thank you!