In [1]:
import torch
import torch.nn as nn
from collections import OrderedDict


from functools import partial
from dataclasses import dataclass
from collections import OrderedDict

In [2]:
# Random dummy tensor of shape (2, 32, 16, 16, 16) for ad-hoc shape experiments.
inp = torch.randn(2, 32, 16, 16, 16)

In [58]:
class Conv3dAuto(nn.Conv3d):
    """3-D convolution that computes its own padding from the kernel size.

    Any ``padding`` passed by the caller is overridden. For each spatial
    dimension the padding is ``dilation * (kernel_size // 2)``:

    - With ``dilation == 1`` this equals the classic ``kernel_size // 2``.
    - For odd kernels it also keeps "same"-sized outputs when the conv is
      dilated.

    So with ``stride == 1`` an odd kernel leaves the spatial size unchanged.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.padding = tuple(
            dil * (k // 2) for k, dil in zip(self.kernel_size, self.dilation)
        )
        # nn.Conv3d caches a reversed, duplicated copy of the padding during
        # __init__ and uses it (through F.pad) whenever padding_mode != 'zeros'.
        # Refresh it so the automatic padding also applies to the
        # 'reflect' / 'replicate' / 'circular' modes.
        self._reversed_padding_repeated_twice = [
            p for p in reversed(self.padding) for _ in range(2)
        ]


class ResidualBlock(nn.Module):
    """Generic residual wrapper computing ``blocks(x) + residual``.

    ``blocks`` and ``shortcut`` default to ``nn.Identity`` and are meant to be
    replaced by subclasses. ``residual`` is ``shortcut(x)`` when
    ``should_apply_shortcut`` is true (here: the channel count changes) and
    the raw input ``x`` otherwise. No activation is applied after the
    addition.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        self.blocks = nn.Identity()
        self.shortcut = nn.Identity()

    def forward(self, x):
        residual = self.shortcut(x) if self.should_apply_shortcut else x
        out = self.blocks(x)
        # Out-of-place add. When ``blocks`` is the identity, ``out`` *is* the
        # input tensor, so an in-place ``+=`` would overwrite the caller's data.
        # It would also raise on a leaf tensor that requires grad.
        return out + residual

    @property
    def should_apply_shortcut(self):
        """True when the shortcut must be used (input/output channels differ)."""
        return self.in_channels != self.out_channels


class ResNetResidualBlock(ResidualBlock):
    """Residual block with a ResNet-style projection shortcut.

    The projection is a 1x1x1 conv followed by BatchNorm.

    Args:
        in_channels: channels entering the block.
        out_channels: base channel width; the block emits
            ``out_channels * expansion`` channels.
        expansion: channel multiplier applied to ``out_channels``.
        downsampling: stride used by the block and by its shortcut projection.
        *args, **kwargs: accepted for signature compatibility; unused here.
    """

    def __init__(self, in_channels, out_channels, expansion=1, downsampling=1, *args, **kwargs):
        super().__init__(in_channels, out_channels)

        self.expansion, self.downsampling = expansion, downsampling
        # Factory for the bias-free, self-padding 3x3x3 convs used by subclasses.
        self.conv = partial(Conv3dAuto, kernel_size=3, bias=False)
        if self.should_apply_shortcut:
            self.shortcut = nn.Sequential(OrderedDict(
                {
                    'conv': nn.Conv3d(self.in_channels, self.expanded_channels, kernel_size=1,
                                      stride=self.downsampling, bias=False),
                    'bn': nn.BatchNorm3d(self.expanded_channels),
                }))
        else:
            self.shortcut = None

    @property
    def expanded_channels(self):
        """Number of channels the block outputs: ``out_channels * expansion``."""
        return self.out_channels * self.expansion

    @property
    def should_apply_shortcut(self):
        """True when the input cannot be added to the block output unchanged.

        This is the case when the channel count changes *or* the block is
        strided. In both situations ``x`` and ``blocks(x)`` have different
        shapes, so ``x`` must first go through the strided 1x1x1 projection.
        """
        stride = self.downsampling
        if isinstance(stride, (tuple, list)):
            strided = any(s != 1 for s in stride)
        else:
            strided = stride != 1
        return self.in_channels != self.expanded_channels or strided


def conv_bn(in_channels, out_channels, conv, *args, **kwargs):
    """Build a ``conv -> BatchNorm3d`` pair as a named ``nn.Sequential``.

    Args:
        in_channels: input channels forwarded to ``conv``.
        out_channels: output channels of ``conv`` and feature count of the BatchNorm.
        conv: conv layer factory, called as
            ``conv(in_channels, out_channels, *args, **kwargs)``.

    Returns:
        ``nn.Sequential`` whose submodules are named ``'conv'`` and ``'bn'``.
    """
    layers = OrderedDict()
    layers['conv'] = conv(in_channels, out_channels, *args, **kwargs)
    layers['bn'] = nn.BatchNorm3d(out_channels)
    return nn.Sequential(layers)


class ResNetBasicBlock(ResNetResidualBlock):
    """ResNet "basic" block: conv-BN -> activation -> conv-BN.

    The first 3x3x3 conv carries the block stride (``downsampling``). The
    second conv maps to ``expanded_channels``.
    """
    expansion = 1

    def __init__(self, in_channels, out_channels, activation=nn.ReLU, *args, **kwargs):
        super().__init__(in_channels, out_channels, *args, **kwargs)
        first = conv_bn(self.in_channels, self.out_channels, conv=self.conv,
                        bias=False, stride=self.downsampling)
        act = activation()
        second = conv_bn(self.out_channels, self.expanded_channels, conv=self.conv,
                         bias=False)
        self.blocks = nn.Sequential(first, act, second)


class ResNetBottleNeckBlock(ResNetResidualBlock):
    """ResNet bottleneck block.

    The path is 1x1x1 reduce -> 3x3x3 (strided) -> 1x1x1 expand. The block
    outputs ``out_channels * 4`` channels.
    """
    expansion = 4

    def __init__(self, in_channels, out_channels, activation=nn.ReLU, *args, **kwargs):
        super().__init__(in_channels, out_channels, expansion=4, *args, **kwargs)
        width = self.out_channels
        narrow = conv_bn(self.in_channels, width, self.conv, kernel_size=1)
        act_a = activation()
        spatial = conv_bn(width, width, self.conv, kernel_size=3, stride=self.downsampling)
        act_b = activation()
        widen = conv_bn(width, self.expanded_channels, self.conv, kernel_size=1)
        self.blocks = nn.Sequential(narrow, act_a, spatial, act_b, widen)


class ResNetLayer(nn.Module):
    """A stage of ``n`` stacked residual blocks.

    Only the first block may downsample. It uses stride 2 when
    ``in_channels != out_channels`` and stride 1 otherwise. The remaining
    ``n - 1`` blocks use stride 1 and take ``out_channels * block.expansion``
    input channels.
    """
    def __init__(self, in_channels, out_channels, block=ResNetBasicBlock, n=1, *args, **kwargs):
        super().__init__()
        # 'We perform downsampling directly by convolutional layers that have a stride of 2.'
        stride = 1 if in_channels == out_channels else 2
        stage_channels = out_channels * block.expansion

        layers = [block(in_channels, out_channels, *args, **kwargs, downsampling=stride)]
        for _ in range(n - 1):
            layers.append(block(stage_channels, out_channels, downsampling=1, *args, **kwargs))
        self.blocks = nn.Sequential(*layers)

    def forward(self, x):
        return self.blocks(x)


class ResNetEncoder(nn.Module):
    """
    ResNet encoder: a strided stem followed by stages of residual blocks whose
    channel width grows from stage to stage.

    Args:
        in_channels: channels of the input volume.
        blocks_sizes: base width of each stage.
        deepths: number of blocks in each stage; must have one entry per
            element of ``blocks_sizes``.
        activation: activation module factory used in the stem and blocks.
        block: residual block class (e.g. ``ResNetBasicBlock``).
        *args, **kwargs: forwarded to every ``ResNetLayer``.

    Raises:
        ValueError: if ``deepths`` and ``blocks_sizes`` differ in length
            (previously the extra entries were silently dropped by ``zip``).
    """
    def __init__(self, in_channels=3, blocks_sizes=(64, 128, 256, 512), deepths=(2, 2, 2, 2),
                 activation=nn.ReLU, block=ResNetBasicBlock, *args, **kwargs):
        super().__init__()
        if len(deepths) != len(blocks_sizes):
            raise ValueError(
                f"deepths must have one entry per stage: got {len(deepths)} depths "
                f"for {len(blocks_sizes)} block sizes"
            )

        self.blocks_sizes = blocks_sizes

        # Stem: stride-2 7x7x7 conv, then a stride-2 max-pool.
        self.gate = nn.Sequential(
            nn.Conv3d(in_channels, self.blocks_sizes[0], kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm3d(self.blocks_sizes[0]),
            activation(),
            nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
        )

        self.in_out_block_sizes = list(zip(blocks_sizes, blocks_sizes[1:]))
        stages = [
            ResNetLayer(blocks_sizes[0], blocks_sizes[0], n=deepths[0], activation=activation,
                        block=block, *args, **kwargs)
        ]
        for (stage_in, stage_out), depth in zip(self.in_out_block_sizes, deepths[1:]):
            # The previous stage emits ``stage_in * block.expansion`` channels.
            stages.append(
                ResNetLayer(stage_in * block.expansion, stage_out, n=depth,
                            activation=activation, block=block, *args, **kwargs)
            )
        self.blocks = nn.ModuleList(stages)

    def forward(self, x):
        x = self.gate(x)
        for block in self.blocks:
            x = block(x)
        return x



class ResnetDecoder_classic(nn.Module):
    """
    Classic ResNet head: global average pooling over the three spatial
    dimensions, then a single ``nn.Linear`` producing one raw (un-normalised)
    score per class.
    """
    def __init__(self, in_features, n_classes):
        super().__init__()
        self.avg = nn.AdaptiveAvgPool3d((1, 1, 1))
        self.decoder = nn.Linear(in_features, n_classes)

    def forward(self, x):
        pooled = self.avg(x)
        # Drop the trailing singleton spatial dims: (N, C, 1, 1, 1) -> (N, C).
        flat = pooled.view(pooled.shape[0], -1)
        return self.decoder(flat)


class ResnetDecoder(nn.Module):
    """
    ResNet head: global average pooling, then three pointwise (1x1x1) convs.

    Because the pooled volume is 1x1x1, the convs act like a 3-layer MLP on
    the pooled features, with ReLU between the layers. A final Sigmoid gives
    an independent score in (0, 1) for each of the ``n_classes`` outputs.
    The result is flattened to shape ``(N, n_classes)``.
    """
    def __init__(self, in_features, n_classes):
        super().__init__()
        self.avg = nn.AdaptiveAvgPool3d((1, 1, 1))

        def pointwise(c_in, c_out):
            return nn.Conv3d(c_in, c_out, kernel_size=(1, 1, 1), stride=(1, 1, 1))

        self.decoder = nn.Sequential(
            pointwise(in_features, in_features),
            nn.ReLU(),
            pointwise(in_features, in_features),
            nn.ReLU(),
            pointwise(in_features, n_classes),
            nn.Sigmoid(),
            nn.Flatten(),
        )

    def forward(self, x):
        return self.decoder(self.avg(x))

In [79]:
class ResNet(nn.Module):
    """Full network: a ``ResNetEncoder`` followed by a ``ResnetDecoder`` head.

    Args:
        in_channels: channels of the input volume.
        n_classes: number of outputs produced by the head.
        *args, **kwargs: forwarded to ``ResNetEncoder`` (e.g. ``block``, ``deepths``).
    """

    def __init__(self, in_channels, n_classes, *args, **kwargs):
        super().__init__()
        self.n_classes = n_classes
        self.encoder = ResNetEncoder(in_channels, *args, **kwargs)
        # The head's input width is the expanded channel count of the
        # encoder's very last residual block.
        last_block = self.encoder.blocks[-1].blocks[-1]
        self.decoder = ResnetDecoder(last_block.expanded_channels, self.n_classes)

    def forward(self, x):
        features = self.encoder(x)
        return self.decoder(features)

In [80]:
def _build_resnet(in_channels, n_classes, block, depths):
    """Shared constructor behind the ``resnet*`` factory functions."""
    return ResNet(in_channels, n_classes, block=block, deepths=list(depths))


def resnet(in_channels, n_classes):
    """Small custom ResNet: basic blocks, stage depths [2, 2, 1, 1]."""
    return _build_resnet(in_channels, n_classes, ResNetBasicBlock, (2, 2, 1, 1))


def resnet18(in_channels, n_classes):
    """ResNet-18 layout: basic blocks, stage depths [2, 2, 2, 2]."""
    return _build_resnet(in_channels, n_classes, ResNetBasicBlock, (2, 2, 2, 2))


def resnet34(in_channels, n_classes):
    """ResNet-34 layout: basic blocks, stage depths [3, 4, 6, 3]."""
    return _build_resnet(in_channels, n_classes, ResNetBasicBlock, (3, 4, 6, 3))


def resnet50(in_channels, n_classes):
    """ResNet-50 layout: bottleneck blocks, stage depths [3, 4, 6, 3]."""
    return _build_resnet(in_channels, n_classes, ResNetBottleNeckBlock, (3, 4, 6, 3))


def resnet101(in_channels, n_classes):
    """ResNet-101 layout: bottleneck blocks, stage depths [3, 4, 23, 3]."""
    return _build_resnet(in_channels, n_classes, ResNetBottleNeckBlock, (3, 4, 23, 3))


def resnet152(in_channels, n_classes):
    """ResNet-152 layout: bottleneck blocks, stage depths [3, 8, 36, 3]."""
    return _build_resnet(in_channels, n_classes, ResNetBottleNeckBlock, (3, 8, 36, 3))

In [88]:
from torchsummary import summary

# Per-layer output shapes and parameter counts on CPU.
# input_size excludes the batch dimension (the table reports it as -1).
model = resnet(1, 3).cpu()
summary(model, (1, 16, 16, 16), device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1          [-1, 64, 8, 8, 8]          21,952
       BatchNorm3d-2          [-1, 64, 8, 8, 8]             128
              ReLU-3          [-1, 64, 8, 8, 8]               0
         MaxPool3d-4          [-1, 64, 4, 4, 4]               0
        Conv3dAuto-5          [-1, 64, 4, 4, 4]         110,592
       BatchNorm3d-6          [-1, 64, 4, 4, 4]             128
              ReLU-7          [-1, 64, 4, 4, 4]               0
        Conv3dAuto-8          [-1, 64, 4, 4, 4]         110,592
       BatchNorm3d-9          [-1, 64, 4, 4, 4]             128
 ResNetBasicBlock-10          [-1, 64, 4, 4, 4]               0
       Conv3dAuto-11          [-1, 64, 4, 4, 4]         110,592
      BatchNorm3d-12          [-1, 64, 4, 4, 4]             128
             ReLU-13          [-1, 64, 4, 4, 4]               0
       Conv3dAuto-14          [-1, 64, 

In [82]:
# Display the classification head of `model` (rendered as the cell output).
model.decoder

ResnetDecoder(
  (avg): AdaptiveAvgPool3d(output_size=(1, 1, 1))
  (decoder): Sequential(
    (0): Conv3d(512, 512, kernel_size=(1, 1, 1), stride=(1, 1, 1))
    (1): ReLU()
    (2): Conv3d(512, 512, kernel_size=(1, 1, 1), stride=(1, 1, 1))
    (3): ReLU()
    (4): Conv3d(512, 3, kernel_size=(1, 1, 1), stride=(1, 1, 1))
    (5): Sigmoid()
    (6): Flatten(start_dim=1, end_dim=-1)
  )
)

In [89]:
# Transfer-learning setup: freeze the existing network, then swap in a fresh
# head with a different number of outputs. Only the new head stays trainable.
# A pretrained checkpoint would be loaded here instead of the fresh model.
# torch.load unpickles arbitrary objects, so only load trusted files:
#   model = torch.load('./models/ACDC_resnet/best.pt').cpu()
model = resnet(1, 3)
n_classes = 4

# Freeze every existing parameter (encoder and the old head).
# NOTE: this does not stop BatchNorm layers from updating their running
# statistics in train() mode. Call model.encoder.eval() while fine-tuning if
# those statistics should stay fixed.
model.requires_grad_(False)

model.decoder = ResnetDecoder(model.encoder.blocks[-1].blocks[-1].expanded_channels, n_classes)
# Keep the model's stored class count in sync with the new head.
model.n_classes = n_classes


In [90]:
# Display the full architecture after the head swap (rendered as the cell output).
model

ResNet(
  (encoder): ResNetEncoder(
    (gate): Sequential(
      (0): Conv3d(1, 64, kernel_size=(7, 7, 7), stride=(2, 2, 2), padding=(3, 3, 3), bias=False)
      (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): MaxPool3d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    )
    (blocks): ModuleList(
      (0): ResNetLayer(
        (blocks): Sequential(
          (0): ResNetBasicBlock(
            (blocks): Sequential(
              (0): Sequential(
                (conv): Conv3dAuto(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
                (bn): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              )
              (1): ReLU()
              (2): Sequential(
                (conv): Conv3dAuto(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
                (bn): BatchNorm3d(64, eps=1e-05, momentum=0