# Transfer a ResNet/EfficientNet backbone into a ResUNet/EfficientUNet

In [1]:
import torch
import timm
from torch import nn
from ptimz.model_zoo import ResNet, UNet

In [2]:
# Number of input modalities: assume T1-weighted and T2-weighted images stacked as channels.
in_chans = 2
# Number of object types the segmentation head predicts.
out_chans = 8

## Split the backbone (ResNet/EfficientNet) into groups of encoders
### EfficientNet

In [3]:
# Only the stem and convolution blocks are inspected below, so build the model
# without a classification head: num_classes=0 makes timm replace the
# classifier (fc layer) with an identity layer.
backbone = timm.create_model('efficientnet_b0', in_chans=in_chans, num_classes=0)

**Let's construct a 2D image input**

In [4]:
# One random 2D sample laid out as (batch, channels, height, width) for Conv2d.
input_2d = torch.rand(1, in_chans, 224, 224)

**EfficientNet consists of a stem -> N convolution blocks -> feature map**  
**Let's go through the EfficientNet convolution blocks and record the output channels of each block**

In [5]:
# Inspection only: switch to eval mode so the forward passes below do not
# update the BatchNorm running statistics of the backbone.
backbone.eval()

efficientnet_stem = nn.Sequential(backbone.conv_stem, backbone.bn1, backbone.act1)
with torch.no_grad():
    stemout = efficientnet_stem(input_2d)
print(f"input {input_2d.shape}")
print(f"stem {stemout.shape} channels {stemout.shape[1]}") # 32 channels, down sample to 1/2

# go through efficientnet blocks; the spatial ratio is taken relative to the
# actual input width (not a hard-coded 224) so it stays correct for other sizes
input_width = input_2d.shape[-1]
blkout = stemout
with torch.no_grad():
    for block_id, blk in enumerate(backbone.blocks):
        blkout = blk(blkout)
        shape = blkout.shape
        print(f"block {block_id} {shape} channles {shape[1]} ratio {shape[-1]/input_width}")

input torch.Size([1, 2, 224, 224])
stem torch.Size([1, 32, 112, 112]) channels 32
block 0 torch.Size([1, 16, 112, 112]) channles 16 ratio 0.5
block 1 torch.Size([1, 24, 56, 56]) channles 24 ratio 0.25
block 2 torch.Size([1, 40, 28, 28]) channles 40 ratio 0.125
block 3 torch.Size([1, 80, 14, 14]) channles 80 ratio 0.0625
block 4 torch.Size([1, 112, 14, 14]) channles 112 ratio 0.0625
block 5 torch.Size([1, 192, 7, 7]) channles 192 ratio 0.03125
block 6 torch.Size([1, 320, 7, 7]) channles 320 ratio 0.03125


## Construct the EfficientUNet

In [6]:
backbone = timm.create_model('efficientnet_b0', in_chans=in_chans)

# Load backbone state_dict here
# checkpoint = torch.load("backbone.pth.tar", map_location='cpu')
# backbone.load_state_dict(checkpoint['state_dict'])

# Group the EfficientNet stages into five UNet encoders, one per resolution
# level (1/2, 1/4, 1/8, 1/16, 1/32 of the input, per the block walkthrough above).
# Blocks that share a resolution (3-4 and 5-6) are merged into one encoder.
unet_encoders = [nn.Sequential(backbone.conv_stem, backbone.bn1, backbone.act1, backbone.blocks[0]),  # 1/2
                 backbone.blocks[1],                    # 1/4
                 backbone.blocks[2],                    # 1/8
                 nn.Sequential(*backbone.blocks[3:5]),  # 1/16
                 nn.Sequential(*backbone.blocks[5:])]   # 1/32
# ptimz UNet config
extra = dict(num_channels=[16, 24, 40, 112, 320], # output channels of the last block in each encoder
             nclasses=out_chans) # classes to segment
conv_cfg = dict(type='Conv2d')
norm_cfg = dict(type='BN2d', requires_grad=True)
efficientunet = UNet(extra, unet_encoders, conv_cfg=conv_cfg, norm_cfg=norm_cfg)

### EfficientUNet prediction

In [7]:
# Switch to inference mode so BatchNorm uses its running statistics (and any
# dropout is disabled) instead of training-time behaviour.
efficientunet.eval()
with torch.no_grad():
    output_2d = efficientunet(input_2d)
# Expect one map per class at the full input resolution.
assert output_2d.shape == (input_2d.shape[0], out_chans, *input_2d.shape[2:])
print(output_2d.shape)

torch.Size([1, 8, 224, 224])


### ResUNet with a 3D ResNet-50 backbone

In [8]:
# One random 3D volume laid out as (batch, channels, depth, height, width) for Conv3d.
input_3d = torch.rand(1, in_chans, 128, 128, 128)

In [9]:
# 3D convolution / normalization settings shared by the backbone and the UNet decoder.
conv_cfg = dict(type='Conv3d')
norm_cfg = dict(type='BN3d', requires_grad=True)
backbone = ResNet(50, in_channels=in_chans, deep_stem=True,
                  first_stride=1, conv_cfg=conv_cfg, norm_cfg=norm_cfg)

# load backbone pretrained weights here
# backbone.load_state_dict(...)

# Encoders: the stem, then maxpool followed by the first residual stage,
# then every remaining residual stage unchanged.
res_stages = [getattr(backbone, stage_name) for stage_name in backbone.res_layers]
unet_encoders = [backbone.stem,
                 nn.Sequential(backbone.maxpool, res_stages[0]),
                 *res_stages[1:]]

# Output channels of each encoder, in order, and number of classes to segment.
extra = dict(num_channels=[64, 256, 512, 1024, 2048],
             nclasses=out_chans)
resunet = UNet(extra, unet_encoders, conv_cfg=conv_cfg, norm_cfg=norm_cfg)

### ResUNet prediction

In [10]:
# Switch to inference mode so BatchNorm uses its running statistics instead of
# per-batch statistics (and does not update them) during prediction.
resunet.eval()
with torch.no_grad():
    output_3d = resunet(input_3d)
# Expect one map per class at the full input resolution.
assert output_3d.shape == (input_3d.shape[0], out_chans, *input_3d.shape[2:])
print(output_3d.shape)

torch.Size([1, 8, 128, 128, 128])
