
Backup running statistics is incorrect #42

Open
denizetkar opened this issue Jan 30, 2022 · 2 comments


@denizetkar

When MetaBatchNormLayer's forward is called with backup_running_statistics=True, the running statistics are meant to be copied into the backup variables via:

self.backup_running_mean.data = copy(self.running_mean.data)

However, copy only performs a shallow copy, so the backup ends up sharing the underlying data with the running statistics: whenever the running statistics get updated, so does the backup. Here is a short code snippet that minimally reproduces the behaviour:

import torch as th
from copy import copy

t = th.tensor([1, 2, 3])
t2 = th.empty_like(t)
t2.data = copy(t.data)  # shallow copy: t2 now shares t's underlying storage
t[0] = 5
assert t2[0] == 5  # passes: the "backup" follows the original

Is my understanding of backing up the running statistics wrong, or is this a bug that needs fixing?

@AntreasAntoniou
Owner

I'll need some time to look into this, but I do believe that the behaviour I currently have coded is intentional.

I want the running statistics to be updated within a given episode, but then scrapped at the end if this isn't a training iteration.
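
Concretely, the intended lifecycle is roughly this (a minimal sketch only; the method names and the BatchNorm1d base are illustrative, not the repository's actual code):

from copy import deepcopy

import torch.nn as nn

class MetaBatchNormLayer(nn.BatchNorm1d):
    def forward(self, x, backup_running_statistics=False):
        if backup_running_statistics:
            # Snapshot the statistics before the episode updates them.
            # The snapshot has to own its memory rather than alias the buffers.
            self.backup_running_mean = deepcopy(self.running_mean)
            self.backup_running_var = deepcopy(self.running_var)
        return super().forward(x)

    def restore_backup_stats(self):
        # After a validation/test episode, discard the episode's updates
        # by restoring the snapshot.
        self.running_mean.data = deepcopy(self.backup_running_mean.data)
        self.running_var.data = deepcopy(self.backup_running_var.data)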

@DubiousCactus

@denizetkar is right: that's not your intended behaviour, and the backup is overwritten by the validation pass. You simply need to use deepcopy instead.
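
For example, swapping copy for deepcopy in the reproducer above gives an independent snapshot (a clone() would also do, since the running-statistics buffers don't require gradients):

import torch as th
from copy import deepcopy

t = th.tensor([1, 2, 3])
t2 = th.empty_like(t)
t2.data = deepcopy(t.data)  # deep copy: t2 gets its own storage
t[0] = 5
assert t2[0] == 1  # the backup keeps the original value

Applied to the layer, the quoted line becomes self.backup_running_mean.data = deepcopy(self.running_mean.data).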
