Skip to content

Commit

Permalink
reproducing unetr arch from monai project, results to be tested on br…
Browse files Browse the repository at this point in the history
…ats!
  • Loading branch information
black0017 committed Jul 26, 2021
1 parent 05b3d79 commit dbefdb3
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 10 deletions.
10 changes: 5 additions & 5 deletions self_attention_cv/UnetTr/UnetTr.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,15 +89,15 @@ def __init__(self, img_shape=(128, 128, 128), input_dim=4, output_dim=3,
self.z3_deconv = TranspConv3DBlock(base_filters * 2, base_filters)

# Yellow blocks in Fig.1
self.z9_conv = Conv3DBlock(base_filters * 8 * 2, base_filters * 8, double=True)
self.z6_conv = Conv3DBlock(base_filters * 4 * 2, base_filters * 4, double=True)
self.z3_conv = Conv3DBlock(base_filters * 2 * 2, base_filters * 2, double=True)
self.z9_conv = Conv3DBlock(base_filters * 8 * 2, base_filters * 8, double=True, norm=self.norm)
self.z6_conv = Conv3DBlock(base_filters * 4 * 2, base_filters * 4, double=True, norm=self.norm)
self.z3_conv = Conv3DBlock(base_filters * 2 * 2, base_filters * 2, double=True, norm=self.norm)
# out convolutions
self.out_conv = nn.Sequential(
# last yellow conv block
Conv3DBlock(base_filters * 2, base_filters, double=True),
Conv3DBlock(base_filters * 2, base_filters, double=True, norm=self.norm),
# grey block, final classification layer
Conv3DBlock(base_filters, output_dim, kernel_size=1, double=False))
nn.Conv3d(base_filters, output_dim, kernel_size=1, stride=1))

def forward(self, x):
transf_input = self.embed(x)
Expand Down
22 changes: 17 additions & 5 deletions self_attention_cv/UnetTr/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,23 +44,35 @@ def forward(self, x):
class TranspConv3DBlock(nn.Module):
def __init__(self, in_planes, out_planes):
super().__init__()
self.block = nn.ConvTranspose3d(in_planes, out_planes, kernel_size=2, stride=2, padding=0, output_padding=0)
self.block = nn.ConvTranspose3d(in_planes, out_planes, kernel_size=2, stride=2,
padding=0, output_padding=0,bias=False)

def forward(self, x):
y = self.block(x)
return y


# blue box in Fig.1
class BlueBlock(nn.Module):
def __init__(self, in_planes, out_planes, layers=1):
def __init__(self, in_planes, out_planes, layers=1, conv_block=False):
"""
blue box in Fig.1
Args:
in_planes: in channels of transpose convolution
out_planes: out channels of transpose convolution
layers: number of blue blocks, transpose convs
conv_block: whether to include a conv block after each transpose conv. deafaults to False
"""
super().__init__()
self.blocks = nn.ModuleList([TranspConv3DBlock(in_planes, out_planes),
Conv3DBlock(out_planes, out_planes, double=False)])
])
if conv_block:
self.blocks.append(Conv3DBlock(out_planes, out_planes, double=False))

if int(layers)>=2:
for _ in range(int(layers) - 1):
self.blocks.append(TranspConv3DBlock(out_planes, out_planes))
self.blocks.append(Conv3DBlock(out_planes, out_planes, double=False))
if conv_block:
self.blocks.append(Conv3DBlock(out_planes, out_planes, double=False))

def forward(self, x):
for blk in self.blocks:
Expand Down
13 changes: 13 additions & 0 deletions tests/test_unetTR.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,17 @@ def test_unettr():
model = UNETR(img_shape=(64, 64, 64), input_dim=1, output_dim=1).to(device)
assert model(a).shape == (1, 1, 64, 64, 64)

num_heads = 12 # 12 normally
embed_dim = 768 # 768 normally
roi_size = (128,128,64)
model = UNETR(img_shape=tuple(roi_size), input_dim=4, output_dim=3,
embed_dim=embed_dim, patch_size=16, num_heads=num_heads,
ext_layers=[3, 6, 9, 12], norm='instance',
base_filters=16,
dim_linear_block=3072)
print(model)
pytorch_total_params = sum(p.numel() for p in model.parameters())
print(pytorch_total_params)
print(pytorch_total_params - 101910630)

test_unettr()

0 comments on commit dbefdb3

Please sign in to comment.