
TransUNet - Why is the patch_dim set to 1? #10

Closed
dsitnik opened this issue Sep 7, 2021 · 7 comments

dsitnik commented Sep 7, 2021

Hi,

Can you please explain why the patch_dim is set to 1 in the TransUNet class? Thank you in advance!

black0017 commented Sep 8, 2021

Hello @dsitnik, I used it based on the figure:
[TransUNet architecture figure from the paper]

Since the number of output tokens of the ViT is n_patch = (h/16)*(w/16), and the ViT's input (after the convs) is something like [batch, channels, h/16, w/16], I can only imagine that the patch_dim is set to 1.

There are indeed many details missing from the paper, but I guess this one is OK. What do you think?
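
A quick way to check this (a standalone sketch with made-up sizes, not code from the repo): with a patch size of 1, every spatial position of the conv feature map becomes exactly one token, so the token count is (h/16)*(w/16).

import torch
from einops import rearrange

# feature map after the conv encoder for a 256x256 input: [batch, channels, h/16, w/16]
feat = torch.randn(2, 1024, 16, 16)

p = 1  # patch_dim = 1
tokens = rearrange(feat, 'b c (patch_x x) (patch_y y) -> b (x y) (patch_x patch_y c)', patch_x=p, patch_y=p)
print(tokens.shape)  # torch.Size([2, 256, 1024]) -> 256 tokens = (256/16) * (256/16)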

dsitnik commented Sep 9, 2021

I think the patch dimensionality should be a hyper-parameter. In the paper (Table 3), they investigated the influence of patch size. If you have an input image of size 3x256x256, after the convolution layers the feature map size would be 1024x16x16. You should then choose a patch size < 16 (e.g. 2, 4, 8).
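
To make the trade-off concrete (a small sketch assuming the 1024x16x16 feature map above; the variable names are just for illustration):

feature_hw, feature_c = 16, 1024
for patch_size in (1, 2, 4, 8):
    n_tokens = (feature_hw // patch_size) ** 2        # tokens = (16/p)^2
    token_dim = patch_size * patch_size * feature_c   # dim of each token = p*p*c
    print(patch_size, n_tokens, token_dim)
# patch_size=1 -> 256 tokens of dim 1024; patch_size=8 -> 4 tokens of dim 65536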

Changes I made:

  • I included the patch_dim parameter in the TransUnet class and forward it to the ViT.
  • In TransUnet's forward function, I changed the rearrange to:
y = rearrange(y, 'b (x y) (patch_x patch_y c) -> b c (patch_x x) (patch_y y)',
              x=self.img_dim_vit//self.patch_size, y=self.img_dim_vit//self.patch_size,
              patch_x=self.patch_size, patch_y=self.patch_size)
  • In ViT's forward function, I changed the rearrange function to:
# from [batch, channels, h, w] to [batch, tokens, N], N = p*p*c, tokens = (h/p)*(w/p)
img_patches = rearrange(img,
                        'b c (patch_x x) (patch_y y) -> b (x y) (patch_x patch_y c)',
                        x=self.img_dim//self.p, y=self.img_dim//self.p, patch_x=self.p, patch_y=self.p)

  • Also, I defined self.project_patches_back = nn.Linear(dim, self.token_dim) in the ViT class. I call this layer in forward just before the return line:

y = self.transformer(patch_embeddings, mask)
y = self.project_patches_back(y)
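
For reference, here is a shape walk-through of the two rearranges plus the projection back (a standalone sketch with assumed sizes, not the actual class code; dim is the transformer embedding size and token_dim = p*p*c):

import torch
import torch.nn as nn
from einops import rearrange

b, c, hw, p, dim = 2, 1024, 16, 2, 512            # assumed sizes
token_dim = p * p * c

feat = torch.randn(b, c, hw, hw)                   # output of the conv encoder
tokens = rearrange(feat, 'b c (patch_x x) (patch_y y) -> b (x y) (patch_x patch_y c)',
                   patch_x=p, patch_y=p)
emb = nn.Linear(token_dim, dim)(tokens)            # existing patch projection
# ... the transformer keeps the shape [b, (hw/p)**2, dim] ...
back = nn.Linear(dim, token_dim)(emb)              # proposed project_patches_back
out = rearrange(back, 'b (x y) (patch_x patch_y c) -> b c (patch_x x) (patch_y y)',
                x=hw // p, y=hw // p, patch_x=p, patch_y=p)
print(tokens.shape, emb.shape, out.shape)          # [2, 64, 4096] [2, 64, 512] [2, 1024, 16, 16]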

Hope these changes are correct. If you have any opinion about this, I would appreciate it.

@black0017

Yes that makes sense! Some remarks:

  • the patch size should evenly divide the feature map's spatial dim
  • if possible, the ViT class should remain unchanged and all the logic encapsulated in the TransUnet class. Do you think this is possible? If so, can you please try to restructure it like this?
  • if you have some results, I would like to know

I am reopening the issue! Thanks for the contribution!!!

black0017 reopened this Sep 9, 2021

black0017 commented Sep 9, 2021

It's not clear to me why this change is necessary in the ViT class. Would this change make the ViT class run as before for the other architectures?

# before
img_patches = rearrange(img, 'b c (patch_x x) (patch_y y) -> b (x y) (patch_x patch_y c)',
                        patch_x=self.p, patch_y=self.p)
# after
img_patches = rearrange(img,
                        'b c (patch_x x) (patch_y y) -> b (x y) (patch_x patch_y c)',
                        x=self.img_dim//self.p, y=self.img_dim//self.p, patch_x=self.p, patch_y=self.p)

Are you sure that this change is necessary?

In einops axis decomposition, specifying only one axis length per decomposed axis should be enough.
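
A quick check of that (a standalone example, not repo code): einops infers the remaining lengths from the input shape, so passing only the patch lengths gives the same result.

import torch
from einops import rearrange

img = torch.randn(2, 3, 32, 32)
a = rearrange(img, 'b c (patch_x x) (patch_y y) -> b (x y) (patch_x patch_y c)', patch_x=8, patch_y=8)
b = rearrange(img, 'b c (patch_x x) (patch_y y) -> b (x y) (patch_x patch_y c)', x=4, y=4, patch_x=8, patch_y=8)
assert a.shape == b.shape == (2, 16, 192)   # 16 tokens of dim 8*8*3 either way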

dsitnik commented Sep 9, 2021

I included x, y, patch_x, and patch_y just to make it clearer from the code what is going on. If self.p is fixed to 1 and this is the only change made, the ViT class should work with the other architectures as before. However, if you include y = self.project_patches_back(y), there might be some dimensionality issues. Maybe the projection back should be included in the TransUnet class after calling the ViT?
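
One possible layout for that (a hypothetical sketch of the wiring, not the repo's actual TransUnet code):

import torch.nn as nn
from einops import rearrange

class TransUnetBottleneck(nn.Module):              # hypothetical wrapper name
    def __init__(self, vit, dim, token_dim, img_dim_vit, patch_size):
        super().__init__()
        self.vit = vit                              # unchanged ViT, returns [b, tokens, dim]
        self.project_patches_back = nn.Linear(dim, token_dim)
        self.img_dim_vit, self.patch_size = img_dim_vit, patch_size

    def forward(self, feat):
        y = self.vit(feat)                          # [b, tokens, dim]
        y = self.project_patches_back(y)            # [b, tokens, patch_size*patch_size*c]
        return rearrange(y, 'b (x y) (patch_x patch_y c) -> b c (patch_x x) (patch_y y)',
                         x=self.img_dim_vit // self.patch_size,
                         y=self.img_dim_vit // self.patch_size,
                         patch_x=self.patch_size, patch_y=self.patch_size)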

black0017 commented Sep 9, 2021

Maybe the y = self.project_patches_back(y) should be included in the TransUnet class after calling the ViT?

Yes, exactly!

I included x, y, patch_x, and patch_y just to make it clearer from the code what is going on. If self.p is fixed to 1 and this is the only change made, the ViT class should work with the other architectures as before.

Is this change just for readability then, or is it necessary for it to work? I don't mind changing it, but I have to check that all the other architectures that are based on the ViT still work fine.

Thanks again.

black0017 commented Sep 11, 2021

Hello @dsitnik, I added the proposed changes to the TransUnet architecture. Let me know if you find any problems.

cheers,
N.
