Thank you very much for the code. But when I run test_TransUnet.py, it starts reporting errors. Why is that? Could you please help me solve it? Thank you #11
Comments
I recently changed that. I will check it. What parameters are you passing to the model?
I didn't change the parameters in test_TransUnet.py.
Hello, I still cannot reproduce it. Can you please discard your local changes and …
After much debugging, I found that the problematic code is the following three lines. I don't really understand what the variable …
Ignoring the batch size: for an input image of spatial size `img_dim`, the CNN encoder downsamples by 16, so the spatial dim that enters the ViT will be `img_dim // 16`. That spatial dim will be patchified, so we need `img_dim // 16` to be divisible by `patch_size`. After patchification, the input and output tokens of the ViT have dimension `vit_transformer_dim`, while the per-token dimension needed to restore the feature map (aka `token_dim`) is `vit_channels * patch_size ** 2`. The spatial dimensions need to be restored somehow, which is why I added the `project_patches_back` projection. Here is the updated class:

```python
import torch.nn as nn
from einops import rearrange

# Bottleneck, SignleConv, Up, and ViT are defined elsewhere in this repo


class TransUnet(nn.Module):
    def __init__(self, *, img_dim, in_channels, classes,
                 vit_blocks=12,
                 vit_heads=12,
                 vit_dim_linear_mhsa_block=3072,
                 patch_size=8,
                 vit_transformer_dim=768,
                 vit_transformer=None,
                 vit_channels=None,
                 ):
        super().__init__()
        self.inplanes = 128
        self.patch_size = patch_size
        self.vit_transformer_dim = vit_transformer_dim
        vit_channels = self.inplanes * 8 if vit_channels is None else vit_channels

        # must be 128 channels and half spatial dims
        in_conv1 = nn.Conv2d(in_channels, self.inplanes, kernel_size=7, stride=2, padding=3,
                             bias=False)
        bn1 = nn.BatchNorm2d(self.inplanes)
        self.init_conv = nn.Sequential(in_conv1, bn1, nn.ReLU(inplace=True))

        self.conv1 = Bottleneck(self.inplanes, self.inplanes * 2, stride=2)
        self.conv2 = Bottleneck(self.inplanes * 2, self.inplanes * 4, stride=2)
        self.conv3 = Bottleneck(self.inplanes * 4, vit_channels, stride=2)

        self.img_dim_vit = img_dim // 16
        assert (self.img_dim_vit % patch_size == 0), "Vit patch_dim not divisible"

        self.vit = ViT(img_dim=self.img_dim_vit,
                       in_channels=vit_channels,  # input features' channels (encoder)
                       patch_dim=patch_size,
                       # transformer inside dimension that input features will be projected
                       # out will be [batch, dim_out_vit_tokens, dim]
                       dim=vit_transformer_dim,
                       blocks=vit_blocks,
                       heads=vit_heads,
                       dim_linear_block=vit_dim_linear_mhsa_block,
                       classification=False) if vit_transformer is None else vit_transformer

        # to project patches back - undoes vit's patchification
        token_dim = vit_channels * (patch_size ** 2)
        self.project_patches_back = nn.Linear(vit_transformer_dim, token_dim)

        # upsampling path
        self.vit_conv = SignleConv(in_ch=vit_channels, out_ch=512)
        self.dec1 = Up(vit_channels, 256)
        self.dec2 = Up(512, 128)
        self.dec3 = Up(256, 64)
        self.dec4 = Up(64, 16)
        self.conv1x1 = nn.Conv2d(in_channels=16, out_channels=classes, kernel_size=1)

    def forward(self, x):
        # ResNet 50-like encoder
        x2 = self.init_conv(x)
        x4 = self.conv1(x2)
        x8 = self.conv2(x4)
        x16 = self.conv3(x8)  # out shape of 1024, img_dim_vit, img_dim_vit
        y = self.vit(x16)  # out shape of number_of_patches, vit_transformer_dim

        # from [number_of_patches, vit_transformer_dim] -> [number_of_patches, token_dim]
        y = self.project_patches_back(y)

        # from [batch, number_of_patches, token_dim] -> [batch, channels, img_dim_vit, img_dim_vit]
        y = rearrange(y, 'b (x y) (patch_x patch_y c) -> b c (patch_x x) (patch_y y)',
                      x=self.img_dim_vit // self.patch_size, y=self.img_dim_vit // self.patch_size,
                      patch_x=self.patch_size, patch_y=self.patch_size)

        y = self.vit_conv(y)
        y = self.dec1(y, x8)
        y = self.dec2(y, x4)
        y = self.dec3(y, x2)
        y = self.dec4(y)
        return self.conv1x1(y)
```

What do you think? Do you still have problems making that work? Let me know.
Thank you very much for the latest code you provided. It helped me understand this paper exactly. Now the code runs normally.
Awesome. It also helped me improve the implementation. I'll close the issue now! Cheers, N.
Thank you very much for the code. But when I run test_TransUnet.py, it starts reporting errors. Why is that?
```
Traceback (most recent call last):
  File "self-attention-cv/tests/test_TransUnet.py", line 14, in <module>
    test_TransUnet()
  File "/self-attention-cv/tests/test_TransUnet.py", line 11, in test_TransUnet
    y = model(a)
  File "C:\Users\dell.conda\envs\myenv\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "self-attention-cv\self_attention_cv\transunet\trans_unet.py", line 88, in forward
    y = self.project_patches_back(y)
  File "C:\Users\dell.conda\envs\myenv\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "C:\Users\dell.conda\envs\myenv\lib\site-packages\torch\nn\modules\linear.py", line 93, in forward
    return F.linear(input, self.weight, self.bias)
  File "C:\Users\dell.conda\envs\myenv\lib\site-packages\torch\nn\functional.py", line 1692, in linear
    output = input.matmul(weight.t())
RuntimeError: mat1 dim 1 must match mat2 dim 0

Process finished with exit code 1
```
Could you please help me solve it? Thank you.
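For context, this `RuntimeError` is PyTorch's generic message for a shape mismatch inside `nn.Linear`: the last dimension of the input must equal the layer's `in_features`. A minimal standalone reproduction (illustrative only, not the repo's code; the layer sizes here are arbitrary):

```python
import torch
import torch.nn as nn

# A layer expecting 768-dimensional inputs, analogous to
# project_patches_back = nn.Linear(vit_transformer_dim, token_dim)
proj = nn.Linear(768, 1024)

# An input whose last dim does not match in_features triggers the error
bad_input = torch.randn(2, 4, 512)
try:
    proj(bad_input)
except RuntimeError as err:
    print(type(err).__name__)  # RuntimeError (exact wording varies by torch version)
```

In the issue here, the tokens coming out of the ViT did not have the width that `project_patches_back` expected, which the updated class above fixes by projecting from `vit_transformer_dim` to `token_dim`.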