fix ViT model output + rewrite attention layer + adapt torchvision script #230
Conversation
I'm not sure that I want the porting scripts to be generalised to all models. We used it for CNNs because it's convenient, but even there the script is not always directly usable (for example, SqueezeNets require you to remove the …).
Yeah, that's the same reason I only linked to it from the model card for the HF upload that used it (for reproducibility). I was afraid to suggest to users that it is a robust way to port weights from torchvision. Maybe the …
I agree we need not worry about having a single one-size-fits-all script. If one script per model family (or group of model families) helps simplify the porting code, that sounds good to me.
The problem with ViT is in the attention module; the weights probably have to be copied in some particular fashion, and I will have to investigate further. I really want to get ViT in because it is the most popular vision backbone these days.
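For reference, here is a hypothetical sketch of what copying the fused attention projection from torchvision might look like. torchvision's `nn.MultiheadAttention` stores the query/key/value projection as a single `in_proj_weight` of shape `(3 * planes, planes)` plus an `in_proj_bias`; the placeholder arrays and any reordering below are assumptions, and working out the right ordering is exactly the investigation mentioned above.

```julia
using Flux

planes = 768                                  # hypothetical embedding size (ViT-B)
qkv = Dense(planes => 3planes; bias = true)   # fused q/k/v projection on the Flux side

# PyTorch stores Linear-style weights as (out, in), which matches the layout of
# Flux's Dense, so a direct copy of the fused projection is the natural first attempt.
# The arrays below are placeholders for tensors loaded from the torchvision checkpoint.
torch_in_proj_weight = rand(Float32, 3planes, planes)
torch_in_proj_bias   = rand(Float32, 3planes)
copyto!(qkv.weight, torch_in_proj_weight)
copyto!(qkv.bias, torch_in_proj_bias)

# If outputs still disagree, the q/k/v blocks (or their per-head sub-blocks) may
# need to be split and reordered to match how the Flux layer slices the projection:
q_w = torch_in_proj_weight[1:planes, :]
k_w = torch_in_proj_weight[(planes + 1):(2planes), :]
v_w = torch_in_proj_weight[(2planes + 1):end, :]
```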
@@ -1,5 +1,5 @@
 """
-    MHAttention(planes::Integer, nheads::Integer = 8; qkv_bias::Bool = false,
+    MultiHeadSelfAttention(planes::Integer, nheads::Integer = 8; qkv_bias::Bool = false,
Made the name more informative.
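Since a later comment notes that the attention layer was rewritten on top of NNlib, here is a minimal, hypothetical sketch of what such a layer can look like using NNlib's `dot_product_attention`. The name `MHSASketch`, the field layout, and the absence of dropout are all assumptions for illustration, not the actual Metalhead code.

```julia
using Flux, NNlib
using MLUtils: chunk

# Illustrative self-attention layer: a fused qkv projection followed by
# NNlib's dot_product_attention and an output projection.
struct MHSASketch{Q, P}
    nheads::Int
    qkv_layer::Q
    projection::P
end
Flux.@functor MHSASketch

function MHSASketch(planes::Integer, nheads::Integer = 8; qkv_bias::Bool = false)
    @assert planes % nheads == 0 "planes must be divisible by nheads"
    return MHSASketch(nheads, Dense(planes => 3planes; bias = qkv_bias),
                      Dense(planes => planes))
end

# x has shape (planes, seq_len, batch)
function (m::MHSASketch)(x::AbstractArray{<:Any, 3})
    q, k, v = chunk(m.qkv_layer(x), 3; dims = 1)            # split the fused projection
    y, _ = NNlib.dot_product_attention(q, k, v; nheads = m.nheads)
    return m.projection(y)
end
```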
     pool === :class ? x -> x[:, 1, :] : seconddimmean),
-    Chain(LayerNorm(embedplanes), Dense(embedplanes, nclasses, tanh_fast)))
+    Chain(LayerNorm(embedplanes), Dense(embedplanes, nclasses)))
This final `tanh` had no reason to exist.
After rewriting the attention layer on top of NNlib and removing the final `tanh` from ViT, I can reproduce PyTorch's outputs, although there is still a slight discrepancy:

Flux:
acoustic guitar: 0.90519154
stage: 0.0040107034
harmonica: 0.0028614246
microphone: 0.002621256
electric guitar: 0.0025401094

PyTorch:
acoustic guitar: 0.90745604
stage: 0.0038461224
harmonica: 0.002782756
microphone: 0.0025289422
electric guitar: 0.0023941135

This could be due to differences in the implementation of layer norm; see FluxML/Flux.jl#2220.
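For reproducibility, here is a hedged sketch of how the Flux-side top-5 listing above might be produced. `model`, `img` (a preprocessed test image), and the ImageNet `labels` vector are hypothetical names assumed to be set up elsewhere; they are not part of this PR.

```julia
using Flux: softmax

probs = softmax(vec(model(img)))                  # class probabilities over ImageNet labels
for i in partialsortperm(probs, 1:5; rev = true)  # indices of the top-5 classes
    println(labels[i], ": ", probs[i])
end
```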
I'm happy with this. If I can get an approval, I'll merge and move on.
Looks like changing the implementation of LayerNorm has little effect, so I don't know why we observe these discrepancies. I added LayerNormv2 to the Layers module but didn't use it anywhere, since I'm not sure it will really be needed.
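As a point of reference, here is a minimal sketch of a PyTorch-style layer norm, assuming the difference discussed in FluxML/Flux.jl#2220 is where `eps` enters the computation (added to the standard deviation in Flux's `normalise`, versus to the variance inside the square root in PyTorch). This only illustrates the idea a `LayerNormv2` might capture, not its actual code.

```julia
using Statistics

# Normalise with eps inside the square root (PyTorch's convention), as opposed
# to adding eps to the standard deviation after the square root.
function layernorm_pytorch_style(x::AbstractArray; dims = 1, eps = 1.0f-5)
    μ  = mean(x; dims)
    σ² = var(x; dims, mean = μ, corrected = false)
    return (x .- μ) ./ sqrt.(σ² .+ eps)
end
```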
A few minor changes before merging but otherwise looks good
@@ -100,9 +102,10 @@ end
 @functor ViT
 
 function ViT(config::Symbol; imsize::Dims{2} = (256, 256), patch_size::Dims{2} = (16, 16),
-             pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000)
+             pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000,
+             qkv_bias = false)
Unless it is typical to adjust this toggle, I think it should not get exposed going from `vit` to `ViT`. The logic within the codebase has been to make the uppercase exports as simple as possible.
I had to add it since the default for torchvision is `true`, while here it is `false`. The torchvision model is given by `ViT(:base, imsize = (224, 224), qkv_bias = true)`.
I think we should change the defaults here to match that before tagging the breaking release, but this can be done in another PR.
Okay, so change the default to `true` and remove the keyword? I assume you almost always want it as `true`.
Yes, I'll do it in the next PR.
The current status of this PR is that all weights are copied, but the outputs on the test image differ (and Flux's don't make sense).
Closes #231