In [None]:
# default_exp engine.model

In [None]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# DECODE Network

> Definition of the classes and modules we use to build our 3D UNet

In [None]:
#export
from decode_fish.imports import *
import torch.nn as nn
import types
from functools import partial
import torch.nn.functional as F

In [None]:
from nbdev.showdoc import *

In [None]:
#export
def _get_conv(ndim: int):
    "Get Convolution Layer of any dimension"
    assert 1 <= ndim <=3
    return getattr(nn, f'Conv{ndim}d')

In [None]:
#export
def init_func(m, func=nn.init.kaiming_normal_):
    """Apply the initializer `func` to `m.weight` and return `m`.

    Nothing is initialized when `func` is falsy or when `m` has no `weight`
    attribute; `m` is returned in every case.
    """
    should_init = bool(func) and hasattr(m, 'weight')
    if should_init:
        func(m.weight)
    return m

In [None]:
#export
def layer_types(m):
    """Return the list of types of `m`'s elements (if `m` is a list) or of `m`'s direct children (otherwise).

    Any non-list `m` must provide a `children()` method, as PyTorch modules do.
    """
    items = m if isinstance(m, list) else m.children()
    return [type(item) for item in items]

In [None]:
#export
def extract_layer(m, name=torch.nn.modules.Conv3d):
    """Collect all layers of type `name` found inside the children of `m`.

    Each direct child of `m` is searched recursively (the child itself included).
    `m` itself is never part of the result, so a bare layer without children
    yields an empty list. Layers are returned in traversal order.
    """
    return [layer
            for child in m.children()
            for layer in child.modules()
            if isinstance(layer, name)]

In [None]:
#export
def number_of_features_per_level(init_channel_number, num_levels):
    """Return the feature-map count for each of `num_levels` levels.

    Level `k` (starting at 0) gets `init_channel_number * 2 ** k` feature maps,
    i.e. the count doubles from one level to the next.
    """
    features = []
    for level in range(num_levels):
        features.append(init_channel_number * 2 ** level)
    return features

In [None]:
#export
class SingleConv(nn.Sequential):
    """
    One convolutional unit: a Conv3d together with an optional non-linearity and an optional
    batchnorm/groupnorm. The layers are produced by `create_conv` and registered in the
    sequence given by the `order` parameter.
    Args:
        in_channels (int): number of input channels
        out_channels (int): number of output channels
        kernel_size (int or tuple): size of the convolving kernel
        order (string): one character per layer, applied left to right (default: 'gcr'), e.g.
            'cr' -> conv + ReLU
            'gcr' -> groupnorm + conv + ReLU
            'cl' -> conv + LeakyReLU
            'ce' -> conv + ELU
        num_groups (int): number of groups for the GroupNorm
        padding (int or tuple): padding handed to the convolution
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, order='gcr', num_groups=8, padding=1):
        super().__init__()

        layers = create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding=padding)
        for layer_name, layer in layers:
            self.add_module(layer_name, layer)

In [None]:
#export
class DoubleConv(nn.Sequential):
    """
    Two `SingleConv` modules applied one after the other, registered as
    'SingleConv1' and 'SingleConv2'.
    The layers inside each `SingleConv` follow `order`; the default 'gcr' means
    GroupNorm + Conv3d + ReLU, and e.g. order='cbe' gives Conv3d + BatchNorm3d + ELU.
    With the default kernel_size=3 and padding=1 the convolutions keep the spatial size
    of the input, so no cropping is needed in the decoder path.
    Channel plan:
        encoder=True:  in_channels -> max(out_channels // 2, in_channels) -> out_channels
        encoder=False: in_channels -> out_channels -> out_channels
    Args:
        in_channels (int): number of input channels
        out_channels (int): number of output channels
        encoder (bool): if True we're in the encoder path, otherwise we're in the decoder
        kernel_size (int or tuple): size of the convolving kernel
        order (string): determines the order of layers, e.g.
            'cr' -> conv + ReLU
            'gcr' -> groupnorm + conv + ReLU
            'cl' -> conv + LeakyReLU
            'ce' -> conv + ELU
        num_groups (int): number of groups for the GroupNorm
        padding (int or tuple): padding handed to both convolutions
    """

    def __init__(self, in_channels, out_channels, encoder, kernel_size=3, order='gcr', num_groups=8, padding=1):
        super().__init__()

        if encoder:
            # encoder path: widen to half of the target width first, but never below the input width
            mid_channels = max(out_channels // 2, in_channels)
        else:
            # decoder path: reduce to the target width already in the first convolution
            mid_channels = out_channels

        first = SingleConv(in_channels, mid_channels, kernel_size, order, num_groups, padding=padding)
        self.add_module('SingleConv1', first)

        second = SingleConv(mid_channels, out_channels, kernel_size, order, num_groups, padding=padding)
        self.add_module('SingleConv2', second)

In [None]:
#export
class Upsampling(nn.Module):
    """
    Upsamples multi-channel 3D data to the spatial size of a reference tensor, using either
    interpolation or a learned transposed convolution.
    Args:
        transposed_conv (bool): if True uses ConvTranspose3d for upsampling, otherwise uses interpolation
        in_channels (int): number of input channels for the transposed conv
            used only if transposed_conv is True
        out_channels (int): number of output channels for the transposed conv
            used only if transposed_conv is True
        kernel_size (int or tuple): size of the convolving kernel
            used only if transposed_conv is True
        scale_factor (int or tuple): stride of the transposed convolution
            used only if transposed_conv is True
        mode (str): interpolation algorithm passed to `F.interpolate`, e.g.
            'nearest' | 'linear' | 'bilinear' | 'trilinear' | 'area'. Default: 'nearest'
            used only if transposed_conv is False
    """

    def __init__(self, transposed_conv, in_channels=None, out_channels=None, kernel_size=3,
                 scale_factor=(2, 2, 2), mode='nearest'):
        super().__init__()

        if not transposed_conv:
            # parameter-free path: interpolate straight to the requested size
            self.upsample = partial(self._interpolate, mode=mode)
            return

        # learned path; the requested output size is passed at call time so the result can
        # reverse the pooling of the corresponding encoder
        # (D_out = (D_in − 1) ×  stride[0] − 2 ×  padding[0] +  kernel_size[0] +  output_padding[0])
        self.upsample = nn.ConvTranspose3d(in_channels, out_channels, kernel_size=kernel_size,
                                           stride=scale_factor, padding=1)

    def forward(self, encoder_features, x):
        # only the spatial dimensions of `encoder_features` are used (batch and channel dims are dropped)
        target_size = encoder_features.size()[2:]
        return self.upsample(x, target_size)

    @staticmethod
    def _interpolate(x, size, mode):
        return F.interpolate(x, size=size, mode=mode)

In [None]:
#export
class Encoder(nn.Module):
    """
    One level of the encoder path: an optional pooling layer followed by `basic_module`.
    The pooling kernel may differ from the default 2, e.g. for anisotropic volumetric data;
    in that case the decoder path has to undo the same factors.
    Args:
        in_channels (int): number of input channels
        out_channels (int): number of output channels
        conv_kernel_size (int or tuple): size of the convolving kernel
        apply_pooling (bool): if True, pool before applying `basic_module`
        pool_kernel_size (int or tuple): the size of the pooling window
        pool_type (str): pooling layer: 'max' (MaxPool3d) or 'avg' (AvgPool3d)
        basic_module (nn.Module): module class, constructed as
            `basic_module(in_channels, out_channels, encoder=True, kernel_size=..., order=...,
            num_groups=..., padding=...)`; `DoubleConv` by default
        conv_layer_order (string): determines the order of layers
            in `DoubleConv` module. See `DoubleConv` for more info.
        num_groups (int): number of groups for the GroupNorm
        padding (int or tuple): padding handed to the convolutions of `basic_module`
    """

    def __init__(self, in_channels, out_channels, conv_kernel_size=3, apply_pooling=True,
                 pool_kernel_size=2, pool_type='max', basic_module=DoubleConv, conv_layer_order='gcr',
                 num_groups=8, padding=1):
        super().__init__()
        assert pool_type in ['max', 'avg']

        if not apply_pooling:
            self.pooling = None
        elif pool_type == 'max':
            self.pooling = nn.MaxPool3d(kernel_size=pool_kernel_size)
        else:
            self.pooling = nn.AvgPool3d(kernel_size=pool_kernel_size)

        self.basic_module = basic_module(in_channels, out_channels,
                                         encoder=True,
                                         kernel_size=conv_kernel_size,
                                         order=conv_layer_order,
                                         num_groups=num_groups,
                                         padding=padding)

    def forward(self, x):
        pooled = x if self.pooling is None else self.pooling(x)
        return self.basic_module(pooled)

In [None]:
#export
class Decoder(nn.Module):
    """
    One level of the decoder path: upsample the incoming features to the spatial size of the
    matching encoder output, join the two, then apply `basic_module`.
    When `basic_module` is `DoubleConv`, upsampling is done by interpolation and the tensors are
    concatenated along the channel dimension; for any other `basic_module` a ConvTranspose3d is used,
    the tensors are summed, and `basic_module` is built with `out_channels` input channels.
    Args:
        in_channels (int): number of input channels
        out_channels (int): number of output channels
        conv_kernel_size (int or tuple): size of the convolving kernel
        scale_factor (tuple): forwarded to `Upsampling`, where it is the stride of the
            ConvTranspose3d; must reverse the pooling of the corresponding encoder
        basic_module (nn.Module): module class, constructed as
            `basic_module(in_channels, out_channels, encoder=False, kernel_size=..., order=...,
            num_groups=..., padding=...)`; `DoubleConv` by default
        conv_layer_order (string): determines the order of layers
            in `DoubleConv` module. See `DoubleConv` for more info.
        num_groups (int): number of groups for the GroupNorm
        mode (str): interpolation mode forwarded to `Upsampling`
        padding (int or tuple): padding handed to the convolutions of `basic_module`
    """

    def __init__(self, in_channels, out_channels, conv_kernel_size=3, scale_factor=(2, 2, 2), basic_module=DoubleConv,
                 conv_layer_order='gcr', num_groups=8, mode='nearest', padding=1):
        super().__init__()

        # DoubleConv -> interpolation + concatenation; anything else -> transposed conv + summation
        use_concat = basic_module == DoubleConv

        self.upsampling = Upsampling(transposed_conv=not use_concat, in_channels=in_channels,
                                     out_channels=out_channels, kernel_size=conv_kernel_size,
                                     scale_factor=scale_factor, mode=mode)
        self.joining = partial(self._joining, concat=use_concat)

        if not use_concat:
            # after summation the joined tensor has `out_channels` channels
            in_channels = out_channels

        self.basic_module = basic_module(in_channels, out_channels,
                                         encoder=False,
                                         kernel_size=conv_kernel_size,
                                         order=conv_layer_order,
                                         num_groups=num_groups,
                                         padding=padding)

    def forward(self, encoder_features, x):
        upsampled = self.upsampling(encoder_features=encoder_features, x=x)
        joined = self.joining(encoder_features, upsampled)
        return self.basic_module(joined)

    @staticmethod
    def _joining(encoder_features, x, concat):
        if not concat:
            return encoder_features + x
        return torch.cat((encoder_features, x), dim=1)

In [None]:
#export
def conv3d(in_channels, out_channels, kernel_size, bias, padding):
    """Build an `nn.Conv3d` with the given channels, kernel size, bias flag and padding."""
    conv_kwargs = dict(padding=padding, bias=bias)
    return nn.Conv3d(in_channels, out_channels, kernel_size, **conv_kwargs)

def create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding):
    """
    Build the named modules that together make up a single conv layer: the convolution itself
    plus optional non-linearity and optional batchnorm/groupnorm, in the sequence given by `order`.
    Args:
        in_channels (int): number of input channels
        out_channels (int): number of output channels
        kernel_size(int or tuple): size of the convolving kernel
        order (string): one character per module, e.g.
            'cr' -> conv + ReLU
            'gcr' -> groupnorm + conv + ReLU
            'cl' -> conv + LeakyReLU
            'ce' -> conv + ELU
            'bcr' -> batchnorm + conv + ReLU
        num_groups (int): number of groups for the GroupNorm; replaced by 1 when the normalised
            tensor has fewer channels than `num_groups`
        padding (int or tuple): padding handed to the convolution
    Return:
        list of tuple (name, module)
    Raises:
        AssertionError: if `order` has no 'c', starts with a non-linearity, or the channel count
            is not divisible by the number of groups
        ValueError: if `order` contains a character other than 'b', 'g', 'r', 'l', 'e', 'c'
    """
    assert 'c' in order, "Conv layer MUST be present"
    assert order[0] not in 'rle', 'Non-linearity cannot be the first operation in the layer'

    conv_position = order.index('c')
    # the convolution gets a learnable bias only when no batchnorm/groupnorm is requested
    conv_has_bias = 'g' not in order and 'b' not in order

    # factories, so that an activation module is only created when its character is reached
    activations = {
        'r': lambda: ('ReLU', nn.ReLU(inplace=True)),
        'l': lambda: ('LeakyReLU', nn.LeakyReLU(negative_slope=0.1, inplace=True)),
        'e': lambda: ('ELU', nn.ELU(inplace=True)),
    }

    layers = []
    for position, code in enumerate(order):
        # a normalisation placed before the conv sees its input channels, otherwise its output channels
        norm_channels = in_channels if position < conv_position else out_channels

        if code in activations:
            layers.append(activations[code]())
        elif code == 'c':
            layers.append(('conv', conv3d(in_channels, out_channels, kernel_size, conv_has_bias, padding=padding)))
        elif code == 'g':
            # fall back to a single group when there are fewer channels than requested groups
            if norm_channels < num_groups:
                num_groups = 1

            assert norm_channels % num_groups == 0, f'Expected number of channels in input to be divisible by num_groups. num_channels={norm_channels}, num_groups={num_groups}'
            layers.append(('groupnorm', nn.GroupNorm(num_groups=num_groups, num_channels=norm_channels)))
        elif code == 'b':
            layers.append(('batchnorm', nn.BatchNorm3d(norm_channels)))
        else:
            raise ValueError(f"Unsupported layer type '{code}'. MUST be one of ['b', 'g', 'r', 'l', 'e', 'c']")

    return layers

In [None]:
#export
class Abstract3DUNet(nn.Module):
    """
    Encoder/decoder backbone shared by the 3D U-Net variants.

    `forward` first normalises the input as `(x - inp_offset) / inp_scale`, runs it through a stack of
    `Encoder` modules and then through `Decoder` modules, each of which receives the output of the
    matching encoder level as skip connection. No final 1x1 convolution or activation is applied here:
    the output of the last decoder is returned as is.

    Args:
        in_channels (int): number of input channels
        out_channels (int): accepted for interface compatibility; not used by this class
        final_sigmoid (bool): accepted for interface compatibility; not used by this class
        basic_module: module class used inside every encoder/decoder level (e.g. `DoubleConv`)
        f_maps (int, tuple): if int: number of feature maps at the first level (default: 64), level k
            then gets `f_maps * 2 ** k` for k = 0 .. num_levels - 1;
            if tuple: number of feature maps at each level
        layer_order (string): determines the order of layers
            in `SingleConv` module. e.g. 'gcr' stands for GroupNorm3d+Conv3d+ReLU.
            See `SingleConv` for more info
        num_groups (int): number of groups for the GroupNorm
        num_levels (int): number of levels in the encoder/decoder path (applied only if f_maps is an int)
        is_segmentation (bool): accepted for interface compatibility; not used by this class
        testing (bool): stored as `self.testing`; not otherwise used by this class
        conv_kernel_size (int or tuple): size of the convolving kernel in the basic_module
        pool_kernel_size (int or tuple): the size of the pooling window of every encoder but the first
        conv_padding (int or tuple): padding handed to the convolutions
        inp_scale: divisor of the input normalisation
        inp_offset: value subtracted from the input before scaling
        **kwargs: accepted and ignored
    """

    def __init__(self, in_channels, out_channels, final_sigmoid, basic_module, f_maps=64, layer_order='gcr',
                 num_groups=8, num_levels=4, is_segmentation=True, testing=False,
                 conv_kernel_size=3, pool_kernel_size=2, conv_padding=1, inp_scale=1, inp_offset=0, **kwargs):
        super().__init__()

        self.testing = testing
        self.inp_scale = inp_scale
        self.inp_offset = inp_offset

        if isinstance(f_maps, int):
            f_maps = number_of_features_per_level(f_maps, num_levels=num_levels)

        # keyword arguments common to every Encoder and Decoder
        shared = dict(basic_module=basic_module,
                      conv_layer_order=layer_order,
                      conv_kernel_size=conv_kernel_size,
                      num_groups=num_groups,
                      padding=conv_padding)

        # encoder path: one Encoder per entry of `f_maps`
        encoder_path = []
        level_in = in_channels
        for level, level_out in enumerate(f_maps):
            if level == 0:
                # the first encoder works at input resolution, so pooling is skipped
                encoder_path.append(Encoder(level_in, level_out, apply_pooling=False, **shared))
            else:
                # TODO: adapt for anisotropy in the data, i.e. use proper pooling kernel to make the data isotropic after 1-2 pooling operations
                encoder_path.append(Encoder(level_in, level_out, pool_kernel_size=pool_kernel_size, **shared))
            level_in = level_out
        self.encoders = nn.ModuleList(encoder_path)

        # decoder path: one Decoder per pair of neighbouring levels, i.e. `len(f_maps) - 1` modules
        concat_skips = basic_module == DoubleConv
        deep_to_shallow = list(reversed(f_maps))
        decoder_path = []
        for deeper, shallower in zip(deep_to_shallow, deep_to_shallow[1:]):
            # with concatenation the decoder sees the channels of both joined tensors
            decoder_in = deeper + shallower if concat_skips else deeper
            # TODO: if non-standard pooling was used, make sure to use correct striding for transpose conv
            # currently strides with a constant stride: (2, 2, 2)
            decoder_path.append(Decoder(decoder_in, shallower, **shared))
        self.decoders = nn.ModuleList(decoder_path)

    def forward(self, x):
        x = (x - self.inp_offset) / self.inp_scale

        # encoder part: keep the output of every level
        level_outputs = []
        for encoder in self.encoders:
            x = encoder(x)
            level_outputs.append(x)

        # skip connections: every encoder output except the deepest one, ordered deepest first
        skips = list(reversed(level_outputs[:-1]))

        # decoder part: each decoder gets the matching skip and the output of the previous stage
        for decoder, skip in zip(self.decoders, skips):
            x = decoder(skip, x)

        return x

In [None]:
#export
class UNet3D(Abstract3DUNet):
    """
    3D U-Net backbone after
    `"3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation"
        <https://arxiv.org/pdf/1606.06650.pdf>`.
    Fixes `basic_module` to `DoubleConv`, which makes the decoders upsample by interpolation
    (nearest neighbor by default) and join the skip connections by concatenation.
    All arguments, including extra `**kwargs`, are forwarded to `Abstract3DUNet`.
    """

    def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr',
                 num_groups=8, num_levels=4, is_segmentation=True, conv_padding=1, inp_scale=1, inp_offset=0, **kwargs):
        super().__init__(in_channels=in_channels,
                         out_channels=out_channels,
                         final_sigmoid=final_sigmoid,
                         basic_module=DoubleConv,
                         f_maps=f_maps,
                         layer_order=layer_order,
                         num_groups=num_groups,
                         num_levels=num_levels,
                         is_segmentation=is_segmentation,
                         conv_padding=conv_padding,
                         inp_scale=inp_scale,
                         inp_offset=inp_offset,
                         **kwargs)

In [None]:
#export
class IntensityDist(nn.Module):
    """Holds three learnable scalar parameters: `int_conc`, `int_rate` and `int_loc`.

    Each constructor argument is converted with `float()` and stored as a 0-dim
    `torch.nn.Parameter`. The module defines no `forward`.
    NOTE(review): the names suggest concentration / rate / location of an intensity
    distribution — confirm against the code that consumes these parameters.
    """
    def __init__(self, int_conc, int_rate, int_loc):
        super().__init__()
        initial_values = (('int_conc', int_conc), ('int_rate', int_rate), ('int_loc', int_loc))
        for attr, value in initial_values:
            setattr(self, attr, torch.nn.Parameter(torch.tensor(float(value))))

class UnetDecodeNoBn(nn.Module):
    """`UNet3D` backbone followed by four convolutional output heads; `forward` returns a dict.

    Every head is a 3x3x3 convolution (`*_out1`), an ELU, and a 1x1x1 convolution (`*_out2`)
    applied to the backbone features:
        'logits'     (1 channel):  clamped to [-15, 15]
        'xyzi_mu'    (4 channels): tanh on the first three channels; softplus + `int_loc` + 0.01 on the last
        'xyzi_sigma' (4 channels): softplus + 0.01
        'background' (1 channel):  softplus scaled by the backbone's `inp_scale`
    Args:
        ch_in (int): number of input channels of the backbone
        ch_out (int): forwarded to `UNet3D` as `out_channels`
        final_sigmoid (bool): forwarded to `UNet3D`
        depth (int): number of levels of the backbone
        inp_scale (float), inp_offset (float): input normalisation of the backbone
        order (string): layer order of the backbone convolutions, see `SingleConv`
        f_maps (int): number of feature maps at the first backbone level and width of the heads
        p_offset (float): initial value of the bias of the final 'logits' convolution
        int_conc, int_rate, int_loc: initial values of the `IntensityDist` parameters
    """
    def __init__(self, ch_in: int = 1, ch_out: int = 10, final_sigmoid: bool = False, depth: int = 3,
                 inp_scale: float = 1., inp_offset: float = 0., order='bcr', f_maps=64, p_offset=-5.,
                 int_conc=5, int_rate=1, int_loc=1):
        super().__init__()
        self.unet = UNet3D(ch_in, ch_out, final_sigmoid=final_sigmoid, num_levels=depth,
                           layer_order=order, inp_scale=inp_scale, inp_offset=inp_offset, f_maps=f_maps)
        self.p_offset = p_offset
        self.int_dist = IntensityDist(int_conc, int_rate, int_loc)

        # (attribute prefix, output channels) of the four heads, in registration order
        head_specs = (('p', 1), ('xyzi', 4), ('xyzis', 4), ('bg', 1))
        for prefix, n_out in head_specs:
            setattr(self, f'{prefix}_out1', nn.Conv3d(f_maps, f_maps, kernel_size=3, padding=1))
            setattr(self, f'{prefix}_out2', nn.Conv3d(f_maps, n_out, kernel_size=1, padding=0))

        # start the logits at `p_offset`
        nn.init.constant_(self.p_out2.bias, p_offset)

        # re-initialise the head weights: hidden convs with the 'relu' gain, output convs with the 'linear' gain
        for prefix, _ in head_specs:
            nn.init.kaiming_normal_(getattr(self, f'{prefix}_out1').weight, mode='fan_in', nonlinearity='relu')
            nn.init.kaiming_normal_(getattr(self, f'{prefix}_out2').weight, mode='fan_in', nonlinearity='linear')

    def forward(self, x):
        features = self.unet(x)

        logits = self.p_out2(F.elu(self.p_out1(features)))
        logits = torch.clamp(logits, -15., 15)

        mu_raw = self.xyzi_out2(F.elu(self.xyzi_out1(features)))
        xyz_mu = torch.tanh(mu_raw[:, :3])
        # `int_loc` is detached, so no gradient reaches it through this output
        i_mu = F.softplus(mu_raw[:, 3:]) + self.int_dist.int_loc.detach() + 0.01
        xyzi_mu = torch.cat((xyz_mu, i_mu), dim=1)

        sigma_raw = self.xyzis_out2(F.elu(self.xyzis_out1(features)))
        xyzi_sigma = F.softplus(sigma_raw) + 0.01

        bg_raw = self.bg_out2(F.elu(self.bg_out1(features)))
        background = self.unet.inp_scale * F.softplus(bg_raw)

        return {'logits': logits,
                'xyzi_mu': xyzi_mu,
                'xyzi_sigma': xyzi_sigma,
                'background': background}

In [None]:
#export
class InferenceNetwork(nn.Module):
    """`UNet3D` backbone followed by four convolutional output heads; `forward` returns one tensor.

    Every head is a 3x3x3 convolution (`*_out1`), an ELU, and a 1x1x1 convolution (`*_out2`)
    applied to the backbone features. The head outputs are concatenated along the channel
    dimension in this order (use `tensor_to_dict` to split them again):
        channel 0     'logits':     clamped to [-15, 15]
        channels 1-4  'xyzi_mu':    tanh on the first three channels; softplus + `int_loc` + 0.01 on the last
        channels 5-8  'xyzi_sigma': softplus + 0.01
        channel 9     'background': softplus scaled by the backbone's `inp_scale`
    Args:
        ch_in (int): number of input channels of the backbone
        ch_out (int): forwarded to `UNet3D` as `out_channels`
        final_sigmoid (bool): forwarded to `UNet3D`
        depth (int): number of levels of the backbone
        inp_scale (float), inp_offset (float): input normalisation of the backbone
        order (string): layer order of the backbone convolutions, see `SingleConv`
        f_maps (int): number of feature maps at the first backbone level and width of the heads
        p_offset (float): initial value of the bias of the final 'logits' convolution
        int_conc, int_rate, int_loc: initial values of the `IntensityDist` parameters
    """
    def __init__(self, ch_in: int = 1, ch_out: int = 10, final_sigmoid: bool = False, depth: int = 3,
                 inp_scale: float = 1., inp_offset: float = 0., order='bcr', f_maps=64, p_offset=-5.,
                 int_conc=5, int_rate=1, int_loc=1):
        super().__init__()
        self.unet = UNet3D(ch_in, ch_out, final_sigmoid=final_sigmoid, num_levels=depth,
                           layer_order=order, inp_scale=inp_scale, inp_offset=inp_offset, f_maps=f_maps)
        self.p_offset = p_offset
        self.int_dist = IntensityDist(int_conc, int_rate, int_loc)

        # (attribute prefix, output channels) of the four heads, in registration order
        head_specs = (('p', 1), ('xyzi', 4), ('xyzis', 4), ('bg', 1))
        for prefix, n_out in head_specs:
            setattr(self, f'{prefix}_out1', nn.Conv3d(f_maps, f_maps, kernel_size=3, padding=1))
            setattr(self, f'{prefix}_out2', nn.Conv3d(f_maps, n_out, kernel_size=1, padding=0))

        # start the logits at `p_offset`
        nn.init.constant_(self.p_out2.bias, p_offset)

        # re-initialise the head weights: hidden convs with the 'relu' gain, output convs with the 'linear' gain
        for prefix, _ in head_specs:
            nn.init.kaiming_normal_(getattr(self, f'{prefix}_out1').weight, mode='fan_in', nonlinearity='relu')
            nn.init.kaiming_normal_(getattr(self, f'{prefix}_out2').weight, mode='fan_in', nonlinearity='linear')

    def forward(self, x):
        features = self.unet(x)

        logits = self.p_out2(F.elu(self.p_out1(features)))
        logits = torch.clamp(logits, -15., 15)

        mu_raw = self.xyzi_out2(F.elu(self.xyzi_out1(features)))
        xyz_mu = torch.tanh(mu_raw[:, :3])
        # `int_loc` is detached, so no gradient reaches it through this output
        i_mu = F.softplus(mu_raw[:, 3:]) + self.int_dist.int_loc.detach() + 0.01
        xyzi_mu = torch.cat((xyz_mu, i_mu), dim=1)

        sigma_raw = self.xyzis_out2(F.elu(self.xyzis_out1(features)))
        xyzi_sigma = F.softplus(sigma_raw) + 0.01

        bg_raw = self.bg_out2(F.elu(self.bg_out1(features)))
        background = self.unet.inp_scale * F.softplus(bg_raw)

        return torch.cat([logits, xyzi_mu, xyzi_sigma, background], 1)

    def tensor_to_dict(self, x):
        # (key, first channel, one past the last channel) of each slice of the `forward` output
        channel_ranges = (('logits', 0, 1),
                          ('xyzi_mu', 1, 5),
                          ('xyzi_sigma', 5, 9),
                          ('background', 9, 10))
        return {key: x[:, start:stop] for key, start, stop in channel_ranges}

In [None]:
# Smoke test: push a random batch (10 single-channel 20x20x20 volumes) through
# `InferenceNetwork`, split the flat output tensor with `tensor_to_dict` and
# print the shape of every named slice.
model = InferenceNetwork(order= 'ce')
output = model.tensor_to_dict(model(torch.randn([10,1,20,20,20])))
for k in output.keys():
    print(k, output[k].shape)

logits torch.Size([10, 1, 20, 20, 20])
xyzi_mu torch.Size([10, 4, 20, 20, 20])
xyzi_sigma torch.Size([10, 4, 20, 20, 20])
background torch.Size([10, 1, 20, 20, 20])


In [None]:
# Smoke test: the same random-batch check for `UnetDecodeNoBn`, whose forward
# pass already returns a dict, so the shapes can be printed directly.
model = UnetDecodeNoBn(order= 'ce')
output = model(torch.randn([10,1,20,20,20]))
for k in output.keys():
    print(k, output[k].shape)

logits torch.Size([10, 1, 20, 20, 20])
xyzi_mu torch.Size([10, 4, 20, 20, 20])
xyzi_sigma torch.Size([10, 4, 20, 20, 20])
background torch.Size([10, 1, 20, 20, 20])


In [None]:
# Total number of parameters of whatever `model` currently refers to; when the cells are
# run top to bottom this is the `UnetDecodeNoBn(order='ce')` built in the previous cell.
sum(p.numel() for p in model.parameters())

4371213

In [None]:
# Build the model from the `model` section of the default config via hydra and check
# which outputs a forward pass produces.
# NOTE(review): `OmegaConf`, `hydra` and `default_conf` are not imported in this notebook
# explicitly; presumably they come from `from decode_fish.imports import *` — confirm.
cfg = OmegaConf.load(default_conf)
model = hydra.utils.instantiate(cfg.model, int_loc=1, inp_scale=1, inp_offset=0)
model(torch.randn([10,1,20,20,20])).keys()

dict_keys(['logits', 'xyzi_mu', 'xyzi_sigma', 'background'])

In [None]:
# Total number of parameters of the current `model` (the config-instantiated one when the
# cells are run top to bottom).
pytorch_total_params = sum(p.numel() for p in model.parameters())
pytorch_total_params

319053

In [None]:
!nbdev_build_lib

Converted 00_models.ipynb.
Converted 01_psf.ipynb.
Converted 02_microscope.ipynb.
Converted 03_noise.ipynb.
Converted 04_pointsource.ipynb.
Converted 05_gmm_loss.ipynb.
Converted 06_plotting.ipynb.
Converted 07_file_io.ipynb.
Converted 08_dataset.ipynb.
Converted 09_output_trafo.ipynb.
Converted 10_evaluation.ipynb.
Converted 11_emitter_io.ipynb.
Converted 12_utils.ipynb.
Converted 13_train.ipynb.
Converted 15_fit_psf.ipynb.
Converted 16_visualization.ipynb.
Converted 17_eval_routines.ipynb.
Converted index.ipynb.
