
BatchNorm alters its sliding mean/standard deviation parameters even in testmode if Zygote is called #1429

@thalassocracy

Description

After training a classification model, I am attempting to perform Google DeepDream-like techniques with it, i.e. iteratively altering a new input x to maximize the activations of selected layers of the model.

I am using Zygote in the following manner to compute and apply the gradient with respect to x:

```julia
using Flux, Zygote

ps = Flux.params(x)                       # treat the input x itself as the trainable parameter
loss, back = @sync Zygote.pullback(() -> mod_obj(m[1:layer](x)), ps)
gs = back(one(loss))                      # gradients of the objective w.r.t. ps
x = x .+ gs[x]                            # gradient-ascent step on x
```

I noticed, however, that this process seemed to destabilize the classifier, despite the model being set to testmode. After running the code above for a couple of thousand iterations, the classifier would begin to produce wildly inaccurate results, though it could be fixed by briefly retraining the network. I began to suspect that the BatchNorm layers in my model were partially responsible for this.

I have verified that the BatchNorm layers in my model have their `active` field set to false both before and after the autodiff call. Despite that, μ and σ² change on every iteration. This appears to be the cause of the classifier becoming inaccurate after running the code above.
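For reference, here is a minimal sketch of the kind of check involved. The toy model, layer sizes, and the `sum` objective are illustrative stand-ins for my actual classifier and objective, not my real code:

```julia
using Flux, Zygote

m = Chain(Dense(4, 4), BatchNorm(4), Dense(4, 2))   # toy stand-in for the classifier
Flux.testmode!(m)                                    # BatchNorm `active` is now false

x  = randn(Float32, 4, 8)
μ0 = copy(m[2].μ)                                    # running mean before the pullback

ps = Flux.params(x)
loss, back = Zygote.pullback(() -> sum(m(x)), ps)
back(one(loss))

m[2].μ == μ0   # expected true in testmode, but comes back false per the behaviour described above
```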

On the other hand, when I skip autodiff (via Zygote or Flux) entirely and simply call m(x), the running statistics are not altered.
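Continuing the sketch above, the same check with a plain forward call shows no drift:

```julia
μ0 = copy(m[2].μ)
m(x)               # plain forward pass, no Zygote involved
m[2].μ == μ0       # true: the running statistics stay fixed
```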
