# Implementacion UNETR

![alt](figs/unetr.png)


Todas las convoluciones de los bloques tienen bias = False (lógico, ya que hay batchnorm). La única excepción es la convolución 1x1 final de salida, que usa bias = True.

Diferencias entre dibujo e implementacion:

* Emplea LeakyReLU en lugar de ReLU
* El primer bloque azul de cada encoder simplemente es un TransposeConv para reducir de 768 al numero de canales de la capa
* El resto de las cajas azules tienen el TransposeConv y dos bloques amarillos

## Secuencia bloques amarillos:

In [1]:
from typing import Sequence, Tuple, Union, Optional
from torch import nn
import torch

class BloquesAmarillos(nn.Sequential):
    """
    Sequence of two "yellow" blocks; each yellow block is Conv -> BatchNorm -> LeakyReLU.

    The first convolution maps ``in_channels`` to ``out_channels`` and the second
    one keeps ``out_channels``. Convolutions carry no bias because each of them is
    immediately followed by a batch-norm layer. Padding is ``kernel_size // 2``,
    so odd kernel sizes preserve the spatial size of the input.

    Args:
        spatial_dims: number of spatial dimensions (2 or 3).
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: convolution kernel size.

    Raises:
        NotImplementedError: if ``spatial_dims`` is neither 2 nor 3.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
    ):
        # Pick the layer classes that match the dimensionality of the data.
        if spatial_dims == 2:
            conv_type, norm_type = nn.Conv2d, nn.BatchNorm2d
        elif spatial_dims == 3:
            conv_type, norm_type = nn.Conv3d, nn.BatchNorm3d
        else:
            raise NotImplementedError("Unsupported spatial_dims: {}".format(spatial_dims))

        padding = kernel_size // 2
        module_list = []
        # Two yellow blocks: in_channels -> out_channels, then out_channels -> out_channels.
        for block_in_channels in (in_channels, out_channels):
            module_list.extend(
                [
                    conv_type(block_in_channels, out_channels, kernel_size, padding=padding, bias=False),
                    norm_type(out_channels),
                    nn.LeakyReLU(),
                ]
            )
        super().__init__(*module_list)

    # No forward() is needed: nn.Sequential already applies the layers in order.


In [2]:
# Build a 3D yellow-block stack going from 16 to 32 channels and print its layers.
bloque_amarillo = BloquesAmarillos(3, 16, 32)

print(bloque_amarillo)

BloquesAmarillos(
  (0): Conv3d(16, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
  (1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): LeakyReLU(negative_slope=0.01)
  (3): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
  (4): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (5): LeakyReLU(negative_slope=0.01)
)


## Encoder

Corresponde con los bloques azules de la figura.

In [3]:

class Encoder(nn.Module):
    """
    Chain of "blue" blocks applied to a feature map.

    The first blue block is only a transposed convolution (kernel 2, stride 2,
    no bias) that maps ``in_channels`` to ``out_channels``. Every following blue
    block is a transposed convolution (kernel 2, stride 2, no bias) that keeps
    ``out_channels``, followed by a ``BloquesAmarillos`` stack.
    NOTE: this differs from what the figure shows.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        num_bloques_azules: int,
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions; 2 selects the 2D layers,
                any other value selects the 3D layers.
            in_channels: number of input channels.
            out_channels: number of output channels.
            num_bloques_azules: total number of blue blocks, counting the first
                (transposed-convolution only) block. Must be at least 1.

        Raises:
            ValueError: if ``num_bloques_azules`` is smaller than 1.
        """

        super().__init__()

        if num_bloques_azules < 1:
            raise ValueError("num_bloques_azules should be at least 1.")

        # The 2D and 3D variants only differ in the transposed-convolution class.
        if spatial_dims == 2:
            conv_transpose = nn.ConvTranspose2d
        else:
            conv_transpose = nn.ConvTranspose3d

        # First blue block: only the channel-reducing transposed convolution.
        self.primer_bloque = conv_transpose(in_channels, out_channels, 2, 2, bias=False)

        # Remaining blue blocks; kept in an nn.ModuleList because forward() iterates over it.
        self.blocks = nn.ModuleList(
            [
                nn.Sequential(
                    conv_transpose(out_channels, out_channels, 2, 2, bias=False),
                    BloquesAmarillos(spatial_dims, out_channels, out_channels),
                )
                for _ in range(num_bloques_azules - 1)
            ]
        )

    def forward(self, x):
        """Apply the first blue block and then every remaining blue block in order."""
        x = self.primer_bloque(x)
        for blk in self.blocks:
            x = blk(x)
        return x




## Comprobacion
El número de parametros tiene que salir 2819072

In [4]:
# Sanity check: a 3D encoder with 3 blue blocks (768 -> 128 channels).
encoder2 = Encoder(3,768, 128, 3)
print(encoder2)

# Count only the trainable parameters (requires_grad=True).
num_params = sum(p.numel() for p in encoder2.parameters() if p.requires_grad)
print("num params = ", num_params)

Encoder(
  (primer_bloque): ConvTranspose3d(768, 128, kernel_size=(2, 2, 2), stride=(2, 2, 2), bias=False)
  (blocks): ModuleList(
    (0): Sequential(
      (0): ConvTranspose3d(128, 128, kernel_size=(2, 2, 2), stride=(2, 2, 2), bias=False)
      (1): BloquesAmarillos(
        (0): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        (1): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.01)
        (3): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        (4): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): LeakyReLU(negative_slope=0.01)
      )
    )
    (1): Sequential(
      (0): ConvTranspose3d(128, 128, kernel_size=(2, 2, 2), stride=(2, 2, 2), bias=False)
      (1): BloquesAmarillos(
        (0): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=

## Decoder

![alt](figs/decoder.png)

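* Está formado por un bloque verde que aumenta la resolución y reduce canales (en el código es un ConvTranspose con kernel 2 y stride 2, no una interpolación)
* En ese upsampling se reduce el número de canales para que coincida con el número de canales de la skip connection del encoder con la que se va a concatenar
* Dos bloques amarillos que hacen lo mismo que los de antes
* El número de canales en la entrada de los bloques amarillos es el doble que en el encoder, porque se concatenan la salida del upsampling y la skip connection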




In [5]:
class Decoder(nn.Module):
    """
    Decoder stage: upsample the lower-level features, concatenate them with the
    features of the current level along the channel axis, and run the result
    through a ``BloquesAmarillos`` stack.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        in_channels_down: int,
        out_channels: int,
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions; 2 selects the 2D transposed
                convolution, any other value selects the 3D one.
            in_channels: number of channels the upsampling produces; the yellow
                blocks receive ``2 * in_channels`` after the concatenation.
            in_channels_down: number of input channels from the lower level.
            out_channels: number of output channels.
        """

        super().__init__()
        upsample_type = nn.ConvTranspose2d if spatial_dims == 2 else nn.ConvTranspose3d
        # Kernel 2 / stride 2, no bias; also reduces the channel count.
        self.upsample = upsample_type(in_channels_down, in_channels, 2, 2, bias=False)
        self.decoder = BloquesAmarillos(spatial_dims, 2 * in_channels, out_channels)

    def forward(self, x_down, x):
        """
        Args:
            x_down: features coming from the lower level (upsampled here).
            x: features of the current level, concatenated after the upsampled ones.

        Returns:
            Output of the yellow blocks applied to the concatenated features.
        """
        upsampled = self.upsample(x_down)
        merged = torch.cat((upsampled, x), dim=1)
        return self.decoder(merged)

In [6]:
# Sanity check: 3D decoder with in_channels=256, in_channels_down=512, out_channels=256.
decoder5 = Decoder(3, 256, 512, 256)
print(decoder5)

# Count only the trainable parameters (requires_grad=True).
num_params = sum(p.numel() for p in decoder5.parameters() if p.requires_grad)
print(num_params)



Decoder(
  (upsample): ConvTranspose3d(512, 256, kernel_size=(2, 2, 2), stride=(2, 2, 2), bias=False)
  (decoder): BloquesAmarillos(
    (0): Conv3d(512, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
    (1): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.01)
    (3): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
    (4): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): LeakyReLU(negative_slope=0.01)
  )
)
6358016


In [37]:
from typing import Sequence, Tuple, Union

import torch.nn as nn


from monai.networks.nets.vit import ViT


# para que funcion tanto si spatial_dims es 2 como 3
from monai.utils.misc import ensure_tuple_rep


class myUNETR(nn.Module):
    """
    UNETR-style network built from a ViT backbone plus the ``BloquesAmarillos``,
    ``Encoder`` and ``Decoder`` modules of this file.

    Based on: "Hatamizadeh et al.,
    UNETR: Transformers for 3D Medical Image Segmentation <https://arxiv.org/abs/2103.10504>"
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        img_size: Union[Sequence[int], int],
        feature_size: int = 64,
        hidden_size: int = 768,
        mlp_dim: int = 3072,
        num_heads: int = 12,
        pos_embed: str = "conv",
        norm_name: Union[Tuple, str] = "instance",
        conv_block: bool = True,
        res_block: bool = True,
        dropout_rate: float = 0.0,
        spatial_dims: int = 3,
        qkv_bias: bool = False,
    ) -> None:
        """
        Args:
            in_channels: dimension of input channels.
            out_channels: dimension of output channels.
            img_size: dimension of input image (an int is repeated for every spatial dim).
            feature_size: dimension of network feature size.
            hidden_size: dimension of hidden layer.
            mlp_dim: dimension of feedforward layer.
            num_heads: number of attention heads.
            pos_embed: position embedding layer type.
            norm_name: feature normalization type and arguments. Accepted but not
                used by this implementation.
            conv_block: accepted but not used by this implementation.
            res_block: accepted but not used by this implementation.
            dropout_rate: fraction of the input units to drop.
            spatial_dims: number of spatial dims.
            qkv_bias: apply the bias term for the qkv linear layer in self attention block

        Raises:
            ValueError: if ``dropout_rate`` is outside [0, 1] or ``hidden_size`` is
                not divisible by ``num_heads``.

        Examples::

            # single channel input, 4-channel output, image size (96,96,96), feature size 32
            >>> net = myUNETR(in_channels=1, out_channels=4, img_size=(96,96,96), feature_size=32)

            # single channel input, 4-channel output, 2D image size (96,96), feature size 32
            >>> net = myUNETR(in_channels=1, out_channels=4, img_size=96, feature_size=32, spatial_dims=2)

            # 4-channel input, 3-channel output, image size (128,128,128), conv position embedding
            >>> net = myUNETR(in_channels=4, out_channels=3, img_size=(128,128,128), pos_embed='conv')

        """

        super().__init__()

        if not (0 <= dropout_rate <= 1):
            raise ValueError("dropout_rate should be between 0 and 1.")
        if hidden_size % num_heads != 0:
            raise ValueError("hidden_size should be divisible by num_heads.")

        # Fixed backbone configuration: 12 transformer layers, patches of 16 per axis.
        self.num_layers = 12
        img_size = ensure_tuple_rep(img_size, spatial_dims)
        self.patch_size = ensure_tuple_rep(16, spatial_dims)
        # Patch-grid size along each spatial axis.
        self.feat_size = tuple(extent // patch for extent, patch in zip(img_size, self.patch_size))
        self.hidden_size = hidden_size
        self.classification = False

        vit_settings = dict(
            in_channels=in_channels,
            img_size=img_size,
            patch_size=self.patch_size,
            hidden_size=hidden_size,
            mlp_dim=mlp_dim,
            num_layers=self.num_layers,
            num_heads=num_heads,
            pos_embed=pos_embed,
            classification=self.classification,
            dropout_rate=dropout_rate,
            spatial_dims=spatial_dims,
            qkv_bias=qkv_bias,
        )
        self.vit = ViT(**vit_settings)

        def skip_encoder(width_factor, n_blue_blocks):
            # Branch that turns a ViT hidden state into a skip connection.
            return Encoder(
                spatial_dims=spatial_dims,
                in_channels=hidden_size,
                out_channels=feature_size * width_factor,
                num_bloques_azules=n_blue_blocks,
            )

        def merge_decoder(width, width_below):
            # Decoder stage whose output width equals the width of its skip connection.
            return Decoder(
                spatial_dims=spatial_dims,
                in_channels=width,
                in_channels_down=width_below,
                out_channels=width,
            )

        # Skip branch fed directly with the input image.
        self.encoder1 = BloquesAmarillos(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=feature_size,
        )
        self.encoder2 = skip_encoder(2, 3)
        self.encoder3 = skip_encoder(4, 2)
        self.encoder4 = skip_encoder(8, 1)

        self.decoder5 = merge_decoder(feature_size * 8, hidden_size)
        self.decoder4 = merge_decoder(feature_size * 4, feature_size * 8)
        self.decoder3 = merge_decoder(feature_size * 2, feature_size * 4)
        self.decoder2 = merge_decoder(feature_size, feature_size * 2)

        # Final 1x1 convolution (with bias) producing the output channels.
        head_type = nn.Conv2d if spatial_dims == 2 else nn.Conv3d
        self.out = head_type(
            in_channels=feature_size,
            out_channels=out_channels,
            kernel_size=1,
            bias=True,
        )

        # proj_feat() views tokens as (batch, *feat_size, hidden) and then moves
        # the hidden axis right after the batch axis.
        self.proj_axes = (0, spatial_dims + 1, *range(1, spatial_dims + 1))
        self.proj_view_shape = [*self.feat_size, self.hidden_size]

    def proj_feat(self, x):
        """View ``x`` as (batch, *feat_size, hidden_size) and permute it to (batch, hidden_size, *feat_size)."""
        batch = x.size(0)
        grid = x.view([batch, *self.proj_view_shape])
        return grid.permute(self.proj_axes).contiguous()

    def forward(self, x_in):
        """Run the ViT, build the skip connections and decode them into the output map."""
        tokens, hidden_states = self.vit(x_in)

        # Skip connections: the raw input plus ViT hidden states 3, 6 and 9.
        skip1 = self.encoder1(x_in)
        skip2 = self.encoder2(self.proj_feat(hidden_states[3]))
        skip3 = self.encoder3(self.proj_feat(hidden_states[6]))
        skip4 = self.encoder4(self.proj_feat(hidden_states[9]))

        # Decode starting from the final ViT output, merging one skip per stage.
        features = self.proj_feat(tokens)
        features = self.decoder5(features, skip4)
        features = self.decoder4(features, skip3)
        features = self.decoder3(features, skip2)
        features = self.decoder2(features, skip1)
        return self.out(features)


In [38]:
# Full 3D model: 4 input channels, 3 output channels, 128^3 input, base width 64.
# NOTE(review): norm_name is accepted by myUNETR.__init__ but never referenced in its body.
model = myUNETR(in_channels=4, out_channels=3, img_size=(128,128,128), feature_size=64, norm_name='batch')

In [39]:
# Report the trainable-parameter count of the whole model and of each encoder branch.
blocks = [model, model.encoder1, model.encoder2, model.encoder3, model.encoder4]

for block in blocks:
    # Only parameters with requires_grad=True are counted.
    num_params = sum(p.numel() for p in block.parameters() if p.requires_grad)
    print(f"Number of trainable parameters: {num_params}")

Number of trainable parameters: 142451907
Number of trainable parameters: 117760
Number of trainable parameters: 2819072
Number of trainable parameters: 5637120
Number of trainable parameters: 3145728


# Variaciones

* Cambiar el tipo de normalización
* Cambiar el tipo de activación
* Bloques amarillos con conexión residual
* Bloques azules con conexión residual