In [1]:
from  fine_network import *

In [2]:
# Load the fine model weights from fine_model.pth; get_fine_model comes from the star import of fine_network.
model = get_fine_model(
    path='fine_model.pth',
)

<All keys matched successfully> ------- fine_model load successful
model time: 8.537387371063232


In [1]:
from ptflops import get_model_complexity_info

In [2]:
import numpy as np
import torch
from torch import nn
from torch.nn import init
# from model.attention.SelfAttention import ScaledDotProductAttention
# from model.attention.SimplifiedSelfAttention import SimplifiedScaledDotProductAttention


import numpy as np
import torch
from torch import nn
from torch.nn import init

class ScaledDotProductAttention(nn.Module):
    '''
    Multi-head scaled dot-product attention with learned Q/K/V and output projections.

    Queries, keys and values are each projected to ``h`` heads, attention is
    softmax(Q K^T / sqrt(d_k)) per head (optionally re-weighted and masked), dropout is
    applied to the attention map, and the heads are concatenated and projected back
    to ``d_model``.
    '''

    def __init__(self, d_model, d_k, d_v, h,dropout=.1):
        '''
        :param d_model: Feature size of the inputs and of the output
        :param d_k: Per-head dimensionality of queries and keys
        :param d_v: Per-head dimensionality of values
        :param h: Number of heads
        :param dropout: Dropout probability applied to the attention map
        '''
        super().__init__()
        # Keep the creation order (q, k, v, o): it fixes the random initialisation
        # sequence and the state_dict key names that checkpoints rely on.
        self.fc_q = nn.Linear(d_model, h * d_k)
        self.fc_k = nn.Linear(d_model, h * d_k)
        self.fc_v = nn.Linear(d_model, h * d_v)
        self.fc_o = nn.Linear(h * d_v, d_model)
        self.dropout = nn.Dropout(dropout)

        self.d_model, self.d_k, self.d_v, self.h = d_model, d_k, d_v, h

    def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):
        '''
        Attend from ``queries`` to ``keys``/``values``.

        :param queries: Queries (b_s, nq, d_model)
        :param keys: Keys (b_s, nk, d_model)
        :param values: Values (b_s, nk, d_model)
        :param attention_mask: Mask over attention values (b_s, h, nq, nk). True indicates masking.
        :param attention_weights: Multiplicative weights for attention values (b_s, h, nq, nk).
        :return: Tensor of shape (b_s, nq, d_model)
        '''
        batch, n_queries = queries.shape[:2]
        n_keys = keys.shape[1]
        heads = self.h

        # Project, then split the last dim into heads; transpose(1, 2) == permute(0, 2, 1, 3).
        q = self.fc_q(queries).view(batch, n_queries, heads, self.d_k).transpose(1, 2)  # (b_s, h, nq, d_k)
        k = self.fc_k(keys).view(batch, n_keys, heads, self.d_k).permute(0, 2, 3, 1)  # (b_s, h, d_k, nk)
        v = self.fc_v(values).view(batch, n_keys, heads, self.d_v).transpose(1, 2)  # (b_s, h, nk, d_v)

        scores = torch.matmul(q, k) / np.sqrt(self.d_k)  # (b_s, h, nq, nk)
        if attention_weights is not None:
            scores = scores * attention_weights
        if attention_mask is not None:
            scores = scores.masked_fill(attention_mask, -np.inf)
        probs = self.dropout(torch.softmax(scores, -1))

        # Re-join the heads: (b_s, h, nq, d_v) -> (b_s, nq, h*d_v).
        merged = torch.matmul(probs, v).transpose(1, 2).contiguous().view(batch, n_queries, heads * self.d_v)
        return self.fc_o(merged)  # (b_s, nq, d_model)


class SimplifiedScaledDotProductAttention(nn.Module):
    '''
    Multi-head scaled dot-product attention without Q/K/V projections.

    The inputs are split directly into ``h`` heads of size ``d_model // h``; only the
    output projection ``fc_o`` is learned.
    '''

    def __init__(self, d_model, h,dropout=.1):
        '''
        :param d_model: Feature size of the inputs and of the output; must be divisible by ``h``
        :param h: Number of heads (must be positive)
        :param dropout: Dropout probability applied to the attention map
        :raises ValueError: if ``h`` is not positive or does not divide ``d_model``
        '''
        super(SimplifiedScaledDotProductAttention, self).__init__()
        # Inputs are reshaped to (h, d_model // h) with no projection, so a non-divisible
        # d_model can never run; fail here instead of with an opaque view() error in forward().
        if h <= 0 or d_model % h != 0:
            raise ValueError(
                "d_model ({}) must be divisible by a positive number of heads h ({})".format(d_model, h))

        self.d_model = d_model
        self.d_k = d_model//h
        self.d_v = d_model//h
        self.h = h

        self.fc_o = nn.Linear(h * self.d_v, d_model)
        self.dropout=nn.Dropout(dropout)

    def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):
        '''
        Attend from ``queries`` to ``keys``/``values`` (no input projections).

        :param queries: Queries (b_s, nq, d_model)
        :param keys: Keys (b_s, nk, d_model)
        :param values: Values (b_s, nk, d_model)
        :param attention_mask: Mask over attention values (b_s, h, nq, nk). True indicates masking.
        :param attention_weights: Multiplicative weights for attention values (b_s, h, nq, nk).
        :return: Tensor of shape (b_s, nq, d_model)
        '''
        b_s, nq = queries.shape[:2]
        nk = keys.shape[1]

        q = queries.view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3)  # (b_s, h, nq, d_k)
        k = keys.view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1)  # (b_s, h, d_k, nk)
        v = values.view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3)  # (b_s, h, nk, d_v)

        att = torch.matmul(q, k) / np.sqrt(self.d_k)  # (b_s, h, nq, nk)
        if attention_weights is not None:
            att = att * attention_weights
        if attention_mask is not None:
            att = att.masked_fill(attention_mask, -np.inf)
        att = torch.softmax(att, -1)
        att=self.dropout(att)

        out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v)  # (b_s, nq, h*d_v)
        out = self.fc_o(out)  # (b_s, nq, d_model)
        return out




############-----------------------------------------------------------------------------############


class PositionAttentionModule(nn.Module):
    """
    Position (spatial) attention branch for 5D feature maps.

    A 3D conv is applied, then every voxel becomes a token of size ``d_model`` and
    single-head self-attention runs across all voxels.
    ``H``, ``W`` and ``D`` are accepted for signature symmetry with the channel branch
    but are not used here.
    """

    def __init__(self,d_model=512,kernel_size=3,H=7,W=7,D=7):
        super().__init__()
        pad = (kernel_size - 1) // 2
        self.cnn = nn.Conv3d(d_model, d_model, kernel_size=kernel_size, padding=pad)
        self.pa = ScaledDotProductAttention(d_model, d_k=d_model, d_v=d_model, h=1)

    def forward(self,x):
        # Unpacking also enforces a 5D (b, c, h, w, d) input.
        batch, channels, _, _, _ = x.shape
        tokens = self.cnn(x).view(batch, channels, -1).permute(0, 2, 1)  # (b, h*w*d, c)
        return self.pa(tokens, tokens, tokens)  # (b, h*w*d, c)


class ChannelAttentionModule(nn.Module):
    """
    Channel attention branch for 5D feature maps.

    A 3D conv is applied, then each channel's flattened spatial map becomes a token and
    projection-free single-head attention runs across channels. The token size is fixed
    to ``H*W*D``, so the input's spatial size is expected to flatten to exactly that.
    """

    def __init__(self,d_model=512,kernel_size=3,H=7,W=7,D=7):
        super().__init__()
        pad = (kernel_size - 1) // 2
        self.cnn = nn.Conv3d(d_model, d_model, kernel_size=kernel_size, padding=pad)
        self.pa = SimplifiedScaledDotProductAttention(H * W * D, h=1)

    def forward(self,x):
        # Unpacking also enforces a 5D (b, c, h, w, d) input.
        batch, channels, _, _, _ = x.shape
        feats = self.cnn(x).view(batch, channels, -1)  # (b, c, h*w*d)
        return self.pa(feats, feats, feats)  # (b, c, h*w*d)




class DAModule(nn.Module):
    """
    Dual attention (position + channel) for 5D feature maps ``(b, c, h, w, d)``.

    Both branches receive the same input. Their outputs are reshaped back to
    ``(b, c, h, w, d)`` and summed.

    :param d_model: number of input channels ``c``
    :param kernel_size: kernel size of the conv inside each branch
    :param H, W, D: spatial size used to size the channel branch's tokens (``H*W*D``)
    """

    def __init__(self,d_model=512,kernel_size=3,H=7,W=7,D=7):
        super().__init__()
        # Forward kernel_size to both branches (it used to be hard-coded to 3, silently
        # ignoring the argument).
        self.position_attention_module = PositionAttentionModule(
            d_model=d_model, kernel_size=kernel_size, H=H, W=W, D=D)
        self.channel_attention_module = ChannelAttentionModule(
            d_model=d_model, kernel_size=kernel_size, H=H, W=W, D=D)

    def forward(self,input):
        bs,c,h,w,d=input.shape
        p_out=self.position_attention_module(input)  # (bs, h*w*d, c)
        c_out=self.channel_attention_module(input)   # (bs, c, h*w*d)
        # reshape (not view) because p_out is non-contiguous after permute.
        p_out=p_out.permute(0,2,1).reshape(bs,c,h,w,d)
        c_out=c_out.reshape(bs,c,h,w,d)
        return p_out+c_out


from copy import deepcopy
from nnunet.utilities.nd_softmax import softmax_helper
from torch import nn
import torch
import numpy as np
from nnunet.network_architecture.initialization import InitWeights_He
from nnunet.network_architecture.neural_network import SegmentationNetwork
import torch.nn.functional     
       
class Stem(nn.Module):
    """
    Conv -> (optional dropout) -> norm -> nonlinearity unit used by the first encoder stage.

    Each ``*_op`` argument is a layer class and each ``*_kwargs`` argument holds the keyword
    arguments used to build it. ``None`` selects the defaults written below.

    Dropout is only instantiated when ``dropout_op`` is given and
    ``dropout_op_kwargs['p']`` is a positive number. The nonlinearity instance is stored as
    ``lrelu`` whatever ``nonlin`` is.
    """

    def __init__(self, input_channels, output_channels,
                 conv_op=nn.Conv3d, conv_kwargs=None,
                 norm_op=nn.InstanceNorm3d, norm_op_kwargs=None,
                 dropout_op=nn.Dropout3d, dropout_op_kwargs=None,
                 nonlin=nn.LeakyReLU, nonlin_kwargs=None):
        super().__init__()
        self.nonlin_kwargs = ({'negative_slope': 1e-2, 'inplace': True}
                              if nonlin_kwargs is None else nonlin_kwargs)
        self.nonlin = nonlin
        self.dropout_op = dropout_op
        self.dropout_op_kwargs = ({'p': 0, 'inplace': True}
                                  if dropout_op_kwargs is None else dropout_op_kwargs)
        self.norm_op_kwargs = ({'eps': 1e-5, 'affine': True, 'momentum': 0.1}
                               if norm_op_kwargs is None else norm_op_kwargs)
        self.conv_kwargs = ({'kernel_size': 3, 'stride': 1, 'padding': 1, 'dilation': 1, 'bias': True}
                            if conv_kwargs is None else conv_kwargs)
        self.conv_op = conv_op
        self.norm_op = norm_op

        self.conv = self.conv_op(input_channels, output_channels, **self.conv_kwargs)
        # Only look up 'p' when a dropout op is configured (matches the short-circuit semantics).
        drop_p = self.dropout_op_kwargs['p'] if self.dropout_op is not None else None
        self.dropout = (self.dropout_op(**self.dropout_op_kwargs)
                        if drop_p is not None and drop_p > 0 else None)
        self.instnorm = self.norm_op(output_channels, **self.norm_op_kwargs)
        self.lrelu = self.nonlin(**self.nonlin_kwargs)

    def forward(self, x):
        out = self.conv(x)
        out = out if self.dropout is None else self.dropout(out)
        return self.lrelu(self.instnorm(out))
    
class Block(nn.Module):
    """
    Residual conv block: conv -> (optional dropout) -> norm, summed with a shortcut, then the nonlinearity.

    The shortcut is always a 1x1 conv + norm that uses the same stride as the main conv.
    When the block preserves the tensor shape (``stride == 1`` and
    ``input_channels == output_channels``), a norm-only identity branch is also added to
    the shortcut.

    Each ``*_op`` argument is a layer class and each ``*_kwargs`` argument holds the keyword
    arguments used to build it. ``None`` selects the defaults below. ``conv_kwargs`` must
    contain a ``'stride'`` entry.
    """

    def __init__(self, input_channels, output_channels,
                 conv_op=nn.Conv3d, conv_kwargs=None,
                 norm_op=nn.InstanceNorm3d, norm_op_kwargs=None,
                 dropout_op=nn.Dropout3d, dropout_op_kwargs=None,
                 nonlin=nn.LeakyReLU, nonlin_kwargs=None):
        super().__init__()
        if nonlin_kwargs is None:
            nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}
        if dropout_op_kwargs is None:
            dropout_op_kwargs = {'p': 0, 'inplace': True}
        if norm_op_kwargs is None:
            norm_op_kwargs = {'eps': 1e-5, 'affine': True, 'momentum': 0.1}
        if conv_kwargs is None:
            conv_kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1, 'dilation': 1, 'bias': True}

        self.nonlin_kwargs = nonlin_kwargs
        self.nonlin = nonlin
        self.dropout_op = dropout_op
        self.dropout_op_kwargs = dropout_op_kwargs
        self.norm_op_kwargs = norm_op_kwargs
        self.conv_kwargs = conv_kwargs
        self.conv_op = conv_op
        self.norm_op = norm_op

        self.conv = self.conv_op(input_channels, output_channels, **self.conv_kwargs)
        if self.dropout_op is not None and self.dropout_op_kwargs['p'] is not None and self.dropout_op_kwargs[
            'p'] > 0:
            self.dropout = self.dropout_op(**self.dropout_op_kwargs)
        else:
            self.dropout = None
        self.instnorm = self.norm_op(output_channels, **self.norm_op_kwargs)
        self.lrelu = self.nonlin(**self.nonlin_kwargs)

        self.stride = conv_kwargs['stride']
        # 1x1 projection shortcut; shares the main conv's stride so the spatial sizes match.
        shortcut_kwargs = {'kernel_size': 1, 'stride': self.stride, 'padding': 0, 'dilation': 1, 'bias': True}
        self.convbn1x1 = nn.Sequential(self.conv_op(input_channels, output_channels, **shortcut_kwargs),
                                       self.norm_op(output_channels, **self.norm_op_kwargs))
        # The identity branch normalises the raw input, so it is only valid when the shape is
        # preserved. With input_channels != output_channels a norm sized for output_channels
        # would be applied to the input and forward() would always fail.
        if self.stride == 1 and input_channels == output_channels:
            self.identity_bn = self.norm_op(output_channels, **self.norm_op_kwargs)
        else:
            self.identity_bn = None

    def forward(self, x):
        identity = x
        out = self.conv(x)
        if self.dropout is not None:
            out = self.dropout(out)
        out = self.instnorm(out)

        shortcut = self.convbn1x1(identity)
        if self.identity_bn is not None:
            shortcut = shortcut + self.identity_bn(identity)
        return self.lrelu(out + shortcut)
        
class ResVGG_Backbone(nn.Module):
    """
    Encoder with one stage per entry of ``channels``.

    Stage 0 is built from ``Stem`` units at full resolution. Every later stage is built from
    ``Block`` units whose first conv has stride 2. ``image_size`` is accepted for API
    compatibility but is not used.
    """

    def __init__(self, in_channels = 1, image_size = [96,96,96], channels = [32,64,128,256,320,320]):
        super().__init__()
        widths = [in_channels] + list(channels)
        stages = [self.make_stage(widths[0], widths[1], Stem, stride=1)]
        for prev_width, width in zip(widths[1:-1], widths[2:]):
            stages.append(self.make_stage(prev_width, width, Block, stride=2))
        self.blocks = nn.ModuleList(stages)

    def make_stage(self,in_channels,output_channels,block,stride=1):
        """Build one stage: a ``block`` with the given stride, then a unit-stride ``block`` of the same width."""
        first_conv_kwargs = {'kernel_size': 3, 'stride': stride, 'padding': 1, 'dilation': 1, 'bias': True}
        entry = block(input_channels=in_channels, output_channels=output_channels,
                      conv_kwargs=first_conv_kwargs)  # changes width and (if stride > 1) resolution
        refine = block(input_channels=output_channels, output_channels=output_channels)  # default kwargs, stride 1
        return nn.Sequential(entry, refine)

    def forward(self, x):
        """Return ``(skips, bottleneck)``: every stage output except the last, deepest first, plus the last one."""
        features = []
        for stage in self.blocks:
            x = stage(x)
            features.append(x)
        features.pop()  # the deepest output is returned separately as the bottleneck
        return features[::-1], x
    
class ConvDropoutNormNonlin(nn.Module):
    """
    Conv -> (optional dropout) -> norm -> nonlinearity, used by the decoder.

    Each ``*_op`` argument is a layer class and each ``*_kwargs`` argument holds the keyword
    arguments used to build it. ``None`` selects the defaults written below.

    Dropout is only instantiated when ``dropout_op`` is given and
    ``dropout_op_kwargs['p']`` is a positive number. The nonlinearity instance is stored as
    ``lrelu`` whatever ``nonlin`` is.
    """

    def __init__(self, input_channels, output_channels,
                 conv_op=nn.Conv3d, conv_kwargs=None,
                 norm_op=nn.InstanceNorm3d, norm_op_kwargs=None,
                 dropout_op=nn.Dropout3d, dropout_op_kwargs=None,
                 nonlin=nn.LeakyReLU, nonlin_kwargs=None):
        super().__init__()
        self.nonlin = nonlin
        self.nonlin_kwargs = ({'negative_slope': 1e-2, 'inplace': True}
                              if nonlin_kwargs is None else nonlin_kwargs)
        self.dropout_op = dropout_op
        self.dropout_op_kwargs = ({'p': 0, 'inplace': True}
                                  if dropout_op_kwargs is None else dropout_op_kwargs)
        self.norm_op = norm_op
        self.norm_op_kwargs = ({'eps': 1e-5, 'affine': True, 'momentum': 0.1}
                               if norm_op_kwargs is None else norm_op_kwargs)
        self.conv_op = conv_op
        self.conv_kwargs = ({'kernel_size': 3, 'stride': 1, 'padding': 1, 'dilation': 1, 'bias': True}
                            if conv_kwargs is None else conv_kwargs)

        self.conv = self.conv_op(input_channels, output_channels, **self.conv_kwargs)
        has_dropout = (self.dropout_op is not None
                       and self.dropout_op_kwargs['p'] is not None
                       and self.dropout_op_kwargs['p'] > 0)
        self.dropout = self.dropout_op(**self.dropout_op_kwargs) if has_dropout else None
        self.instnorm = self.norm_op(output_channels, **self.norm_op_kwargs)
        self.lrelu = self.nonlin(**self.nonlin_kwargs)

    def forward(self, x):
        y = self.conv(x)
        if self.dropout is not None:
            y = self.dropout(y)
        y = self.instnorm(y)
        return self.lrelu(y)
class Decoder(nn.Module):
    """
    U-Net style decoder.

    For each level, from deepest to shallowest, the decoder upsamples with a stride-2
    transposed conv, concatenates ``skips[i]`` along channels, and refines the result with
    two ConvDropoutNormNonlin layers. ``in_channels`` and ``image_size`` are accepted for API
    compatibility but are not used.
    """

    def __init__(self,in_channels = 1, image_size = [96,96,96], channels = [32,64,128,256,320,320]):
        super().__init__()

        # Upsamplers: channels[-1] -> channels[-2] -> ... -> channels[0].
        self.tu = []
        for channel in range(1,len(channels)):
            in_channel,out_channel = channels[-(channel)],channels[-(channel+1)]
            self.tu.append(nn.ConvTranspose3d(in_channel,out_channel,kernel_size=2, stride=2))
        # Each refine stage sees the upsampled map concatenated with a skip of the same width,
        # hence the 2x input channels.
        self.decoders = []
        decoders_channels = channels[:-1]
        for channel in range(1,len(decoders_channels)+1):
            in_channel = decoders_channels[-(channel)]
            self.decoders.append(nn.Sequential(
            ConvDropoutNormNonlin(
                        input_channels=in_channel*2,output_channels=in_channel),
            ConvDropoutNormNonlin(
                        input_channels=in_channel,output_channels=in_channel)
            ))
        self.decoders = nn.ModuleList(self.decoders)
        self.tu = nn.ModuleList(self.tu)

    def forward(self,x,skips):
        """
        :param x: bottleneck feature map
        :param skips: encoder features ordered deepest first; ``skips[i]`` is used at level ``i``
        :return: decoded feature map with ``channels[0]`` channels
        """
        # enumerate instead of a hard-coded index list, so decoders deeper than
        # 5 levels use every stage rather than being silently truncated by zip().
        for i, (up, decode) in enumerate(zip(self.tu, self.decoders)):
            x = up(x)
            x = decode(torch.cat((x,skips[i]),1))
        return x
class UNet(nn.Module):
    """
    ResVGG encoder + transposed-conv decoder + 1x1x1 conv head producing ``num_classes`` maps.

    ``image_size`` is passed through to the encoder and decoder, which do not use it.
    """

    def __init__(self,in_channels = 4,num_classes=3, image_size = [96,96,96], channels = [32,64,128,256,320,320]):
        super().__init__()
        shared = dict(in_channels=in_channels, image_size=image_size, channels=channels)
        self.backbone = ResVGG_Backbone(**shared)
        self.decoder = Decoder(**shared)
        self.logits = nn.Conv3d(channels[0], num_classes, 1, 1)

    def forward(self,x):
        skips, bottleneck = self.backbone(x)
        return self.logits(self.decoder(bottleneck, skips))
        



Please cite the following paper when using nnUNet:

Isensee, F., Jaeger, P.F., Kohl, S.A.A. et al. "nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation." Nat Methods (2020). https://doi.org/10.1038/s41592-020-01008-z


If you have questions or suggestions, feel free to open an issue at https://github.com/MIC-DKFZ/nnUNet



In [6]:
# Build the UNet defined above with default arguments on the current CUDA device.
# Note: this rebinds `model`, replacing the fine model loaded at the top of the notebook.
model = UNet().to('cuda')

In [7]:
# input_res excludes the batch dim: 4 input channels (UNet's default in_channels) x 128^3 voxels.
# ptflops reports multiply-accumulates (GMac), so `flops` holds a MAC count string.
flops, params = get_model_complexity_info(
    model,
    (4, 128, 128, 128),
    as_strings=True,
    print_per_layer_stat=True,
)
print(flops, params)

UNet(
  31.72 M, 100.000% Params, 542.69 GMac, 100.000% MACs, 
  (backbone): ResVGG_Backbone(
    14.55 M, 45.879% Params, 148.05 GMac, 27.281% MACs, 
    (blocks): ModuleList(
      14.55 M, 45.879% Params, 148.05 GMac, 27.281% MACs, 
      (0): Sequential(
        31.3 k, 0.099% Params, 65.77 GMac, 12.119% MACs, 
        (0): Stem(
          3.55 k, 0.011% Params, 7.52 GMac, 1.385% MACs, 
          (conv): Conv3d(3.49 k, 0.011% Params, 7.31 GMac, 1.348% MACs, 4, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
          (instnorm): InstanceNorm3d(64, 0.000% Params, 134.22 MMac, 0.025% MACs, 32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (lrelu): LeakyReLU(0, 0.000% Params, 67.11 MMac, 0.012% MACs, negative_slope=0.01, inplace=True)
        )
        (1): Stem(
          27.74 k, 0.087% Params, 58.25 GMac, 10.734% MACs, 
          (conv): Conv3d(27.68 k, 0.087% Params, 58.05 GMac, 10.697% MACs, 32, 32, kernel_size=(3, 3, 3), stride=(1, 1,

In [8]:
# Display the module tree of the UNet built above (rich repr of the last expression).
model

UNet(
  (backbone): ResVGG_Backbone(
    (blocks): ModuleList(
      (0): Sequential(
        (0): Stem(
          (conv): Conv3d(4, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
          (instnorm): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (lrelu): LeakyReLU(negative_slope=0.01, inplace=True)
        )
        (1): Stem(
          (conv): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
          (instnorm): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (lrelu): LeakyReLU(negative_slope=0.01, inplace=True)
        )
      )
      (1): Sequential(
        (0): Block(
          (conv): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
          (instnorm): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (lrelu): LeakyReLU(negative_slope=0.01, inplace=True)
          (convbn