
Thank you very much for the code. But when I run test_TransUnet.py, it starts reporting errors. Why is that? Could you please help me solve it? Thank you #11

Closed
yezhengjie opened this issue Sep 14, 2021 · 7 comments

Comments

@yezhengjie

Thank you very much for the code. But when I run test_TransUnet.py, it starts reporting errors. Why is that?
```
Traceback (most recent call last):
  File "self-attention-cv/tests/test_TransUnet.py", line 14, in <module>
    test_TransUnet()
  File "/self-attention-cv/tests/test_TransUnet.py", line 11, in test_TransUnet
    y = model(a)
  File "C:\Users\dell.conda\envs\myenv\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "self-attention-cv\self_attention_cv\transunet\trans_unet.py", line 88, in forward
    y = self.project_patches_back(y)
  File "C:\Users\dell.conda\envs\myenv\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "C:\Users\dell.conda\envs\myenv\lib\site-packages\torch\nn\modules\linear.py", line 93, in forward
    return F.linear(input, self.weight, self.bias)
  File "C:\Users\dell.conda\envs\myenv\lib\site-packages\torch\nn\functional.py", line 1692, in linear
    output = input.matmul(weight.t())
RuntimeError: mat1 dim 1 must match mat2 dim 0

Process finished with exit code 1
```
Could you please help me solve it? Thank you.
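For context, as far as I can tell this RuntimeError comes from nn.Linear receiving an input whose last dimension does not match its in_features. A minimal sketch with made-up shapes (unrelated to the model) reproduces the same failure:

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 8)      # expects the input's last dim to be 4
x = torch.rand(2, 16, 1024)  # last dim is 1024, not 4
y = layer(x)                 # RuntimeError (exact wording varies by PyTorch version)
```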

@black0017
Contributor

I recently changed that. I will check it. What parameters are you passing to the model?

@yezhengjie
Author

I didn't change the parameters in test_TransUnet.py:

```python
a = torch.rand(2, 3, 128, 128).to(device)
model = TransUnet(in_channels=3, img_dim=128, vit_blocks=1,
                  vit_dim_linear_mhsa_block=512, classes=5,
                  patch_size=4).to(device)
y = model(a)
```
I also tried print(dim_out_vit_tokens, token_dim), and the output is 4 16384.
I don't quite understand whether these two values are correct. I hope you can take another look.
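For reference, redoing that arithmetic by hand (assuming vit_channels = 1024, i.e. inplanes * 8 as in the model) gives the same two numbers:

```python
img_dim, patch_size, vit_channels = 128, 4, 1024  # vit_channels = inplanes * 8

img_dim_vit = img_dim // 16                            # 8
dim_out_vit_tokens = (img_dim_vit // patch_size) ** 2  # (8 // 4) ** 2 = 4
token_dim = vit_channels * patch_size ** 2             # 1024 * 16 = 16384
print(dim_out_vit_tokens, token_dim)                   # 4 16384
```

So the values match the formulas; what I can't tell is whether dim_out_vit_tokens is the right in_features for project_patches_back.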

@black0017
Contributor

black0017 commented Sep 14, 2021

Hello, I still cannot reproduce it. Can you please discard your local changes and `$ git pull` so that you have the latest version?

@yezhengjie
Author

After many rounds of debugging, I found that the problematic code is the following three lines.
Could you confirm that there are no errors in them?

```python
dim_out_vit_tokens = (self.img_dim_vit // patch_size) ** 2
token_dim = vit_channels * (patch_size ** 2)
self.project_patches_back = nn.Linear(dim_out_vit_tokens, token_dim)
```

I don't really understand what the variable dim_out_vit_tokens means.
But when patch_size=4 and I try self.project_patches_back = nn.Linear(1024, token_dim), there is no error.
Is there a problem with the definition of dim_out_vit_tokens?
What does it mean, and what do these three lines of code do?
Thank you for your help.

@black0017
Contributor

black0017 commented Sep 14, 2021

Ignoring batch size: for an input image of shape [3, img_dim, img_dim], the output of the convolutional encoder will be [1024, img_dim_vit, img_dim_vit].

It is then processed by the ViT, whose input spatial dim is img_dim_vit = img_dim // 16.

The spatial dims will be patchified, so we need self.img_dim_vit to be divisible by patch_size.

After patchification, the number of input and output tokens of the ViT will be (self.img_dim_vit // patch_size) ** 2,

while the dimension of each restored token, aka token_dim, will be token_dim = vit_channels * (patch_size ** 2).

However, the spatial dimensions need to be restored somehow; that's why I added the self.project_patches_back layer, which maps each output token from the ViT's dimension back to token_dim.
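
To make that concrete, here is the arithmetic for the parameters from your test, using the defaults where the test doesn't set them:

```python
img_dim = 128
patch_size = 4
vit_channels = 1024        # inplanes * 8
vit_transformer_dim = 768  # hidden size D (the default)

img_dim_vit = img_dim // 16                            # 8: spatial dim entering the ViT
number_of_patches = (img_dim_vit // patch_size) ** 2   # 4: token sequence length
token_dim = vit_channels * patch_size ** 2             # 16384: features per restored patch

# The ViT output is [batch, number_of_patches, vit_transformer_dim], so
# project_patches_back must be nn.Linear(vit_transformer_dim, token_dim);
# using nn.Linear(number_of_patches, token_dim) is what triggered the mismatch.
```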

vit_transformer_dim is a free parameter, the hidden size D in the paper: https://arxiv.org/pdf/2102.04306.pdf

Here is the new version that I just pushed to the main branch:

```python
class TransUnet(nn.Module):
    def __init__(self, *, img_dim, in_channels, classes,
                 vit_blocks=12,
                 vit_heads=12,
                 vit_dim_linear_mhsa_block=3072,
                 patch_size=8,
                 vit_transformer_dim=768,
                 vit_transformer=None,
                 vit_channels=None,
                 ):
        super().__init__()
        self.inplanes = 128
        self.patch_size = patch_size
        self.vit_transformer_dim = vit_transformer_dim
        vit_channels = self.inplanes * 8 if vit_channels is None else vit_channels

        # must be 128 channels and half spat dims.
        in_conv1 = nn.Conv2d(in_channels, self.inplanes, kernel_size=7, stride=2, padding=3,
                             bias=False)
        bn1 = nn.BatchNorm2d(self.inplanes)
        self.init_conv = nn.Sequential(in_conv1, bn1, nn.ReLU(inplace=True))
        self.conv1 = Bottleneck(self.inplanes, self.inplanes * 2, stride=2)
        self.conv2 = Bottleneck(self.inplanes * 2, self.inplanes * 4, stride=2)
        self.conv3 = Bottleneck(self.inplanes * 4, vit_channels, stride=2)

        self.img_dim_vit = img_dim // 16

        assert (self.img_dim_vit % patch_size == 0), "Vit patch_dim not divisible"

        self.vit = ViT(img_dim=self.img_dim_vit,
                       in_channels=vit_channels,  # input features' channels (encoder)
                       patch_dim=patch_size,
                       # transformer inside dimension that input features will be projected
                       # out will be [batch, dim_out_vit_tokens, dim ]
                       dim=vit_transformer_dim,
                       blocks=vit_blocks,
                       heads=vit_heads,
                       dim_linear_block=vit_dim_linear_mhsa_block,
                       classification=False) if vit_transformer is None else vit_transformer

        # to project patches back - undoes vit's patchification
        token_dim = vit_channels * (patch_size ** 2)
        self.project_patches_back = nn.Linear(vit_transformer_dim, token_dim)
        # upsampling path
        self.vit_conv = SignleConv(in_ch=vit_channels, out_ch=512)
        self.dec1 = Up(vit_channels, 256)
        self.dec2 = Up(512, 128)
        self.dec3 = Up(256, 64)
        self.dec4 = Up(64, 16)
        self.conv1x1 = nn.Conv2d(in_channels=16, out_channels=classes, kernel_size=1)

    def forward(self, x):
        # ResNet 50-like encoder
        x2 = self.init_conv(x)
        x4 = self.conv1(x2)
        x8 = self.conv2(x4)
        x16 = self.conv3(x8)  # out shape of 1024, img_dim_vit, img_dim_vit
        y = self.vit(x16)  # out shape of number_of_patches, vit_transformer_dim

        # from [number_of_patches, vit_transformer_dim] -> [number_of_patches, token_dim]
        y = self.project_patches_back(y)

        # from [batch, number_of_patches, token_dim] -> [batch, channels, img_dim_vit, img_dim_vit]
        y = rearrange(y, 'b (x y) (patch_x patch_y c) -> b c (patch_x x) (patch_y y)',
                      x=self.img_dim_vit // self.patch_size, y=self.img_dim_vit // self.patch_size,
                      patch_x=self.patch_size, patch_y=self.patch_size)

        y = self.vit_conv(y)
        y = self.dec1(y, x8)
        y = self.dec2(y, x4)
        y = self.dec3(y, x2)
        y = self.dec4(y)
        return self.conv1x1(y)
```
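
With this version, your original test parameters should go through. A quick smoke test (the expected output shape assumes the decoder restores the full input resolution):

```python
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = TransUnet(in_channels=3, img_dim=128, vit_blocks=1,
                  vit_dim_linear_mhsa_block=512, classes=5,
                  patch_size=4).to(device)
a = torch.rand(2, 3, 128, 128).to(device)
y = model(a)
print(y.shape)  # expected: torch.Size([2, 5, 128, 128])
```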

What do you think? Do you still have problems making that work?

Let me know.

@yezhengjie
Author

Thank you very much for the latest code you provided. It helped me understand the paper exactly. Now the code runs normally.

@black0017
Contributor

Awesome. It also helped me improve the implementation. I'm closing the issue now! Cheers, N.
