# Implementación de SwinUNETR


<img src="figs/swin_unetr.png" width="60%"/>

Algunos detalles:

* Emplea LeakyReLU en lugar de ReLU
* Normaliza con instance normalization


In [None]:
import numpy as np
import torch
from torch import nn
from typing import Union, Sequence, Tuple
from monai.utils import ensure_tuple_rep
from monai.networks.nets.swin_unetr import SwinTransformer

## Implementación del bloque residual (ResBlock)

<img  src="figs/res_block.png" width="50%" />


In [None]:
class ResBlock(nn.Module):
    """
    Residual block (see the figure above) normalised with InstanceNorm.

    Main path:  conv(k=3) -> InstanceNorm -> LeakyReLU -> conv(k=3) -> InstanceNorm
    Skip path:  identity, or a 1x1 conv + InstanceNorm when ``in_channels != out_channels``
                so that both paths have ``out_channels`` channels and can be summed.
    The sum goes through a final LeakyReLU.

    Convolutions have no bias because each one is followed by a normalisation layer.
    The InstanceNorm layers use PyTorch's defaults (no affine parameters), so only the
    conv weights are trainable. This is intended to match MONAI's ``UnetResBlock``
    parameter count; verify with the comparison cell below.

    Args:
        spatial_dims: number of spatial dimensions (2 or 3).
        in_channels: number of input channels.
        out_channels: number of output channels.

    Raises:
        ValueError: if ``spatial_dims`` is not 2 or 3.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        ):
        super().__init__()

        kernel_size = 3
        # "same" padding for an odd kernel with stride 1: spatial size is preserved,
        # which is required for the residual sum.
        padding = kernel_size // 2

        if spatial_dims == 2:
            conv_cls, norm_cls = nn.Conv2d, nn.InstanceNorm2d
        elif spatial_dims == 3:
            conv_cls, norm_cls = nn.Conv3d, nn.InstanceNorm3d
        else:
            raise ValueError("spatial_dims should be 2 or 3.")

        self.adjust_in_channels = in_channels != out_channels

        self.conv1 = conv_cls(in_channels, out_channels, kernel_size, padding=padding, bias=False)
        self.norm1 = norm_cls(out_channels)
        self.conv2 = conv_cls(out_channels, out_channels, kernel_size, padding=padding, bias=False)
        self.norm2 = norm_cls(out_channels)

        if self.adjust_in_channels:
            # 1x1 projection so the residual has out_channels channels.
            self.conv3 = conv_cls(in_channels, out_channels, kernel_size=1, bias=False)
            self.norm3 = norm_cls(out_channels)

        self.lrelu = nn.LeakyReLU()

    def forward(self, inp):
        """
        Apply the residual block.

        Args:
            inp: input tensor; presumably (batch, in_channels, *spatial) - TODO confirm.

        Returns:
            Tensor with ``out_channels`` channels and the same spatial size as ``inp``.
        """
        residual = inp

        out = self.lrelu(self.norm1(self.conv1(inp)))
        out = self.norm2(self.conv2(out))

        if self.adjust_in_channels:
            residual = self.norm3(self.conv3(residual))

        out = out + residual
        return self.lrelu(out)


### Comprobación del ResBlock

In [None]:
from monai.networks.nets.swin_unetr import SwinUNETR

# Reference MONAI model with the same configuration as the model implemented below.
# Keyword arguments make the mapping explicit instead of relying on positional order.
# NOTE(review): recent MONAI releases deprecate `img_size`; check the installed version.
model = SwinUNETR(img_size=128, in_channels=4, out_channels=3, feature_size=48)

print(model.encoder2)
print("num of parameters: ", sum(p.numel() for p in model.encoder2.parameters() if p.requires_grad))

In [None]:
# Our block with the same configuration as MONAI's encoder2 (3D, 48 -> 48 channels).
resblock = ResBlock(3, 48, 48)
print(resblock)
resblock_trainable = sum(
    param.numel() for param in resblock.parameters() if param.requires_grad
)
print("num of parameters: ", resblock_trainable)

## Decoder



In [None]:
class Decoder(nn.Module):
    """
    Decoder stage for Swin UNETR.

        1. Upsamples ``x_down`` with a transposed convolution (kernel 2, stride 2),
           doubling every spatial dimension and mapping
           ``in_channels_down`` -> ``in_channels`` channels.
        2. Concatenates the upsampled tensor with the skip tensor ``x`` along the
           channel dimension (upsampled first), giving ``2 * in_channels`` channels.
        3. Applies a ``ResBlock`` mapping ``2 * in_channels`` -> ``out_channels``.

    Args:
        spatial_dims: number of spatial dimensions (2 or 3).
        in_channels: number of channels of the skip tensor ``x`` (upper level).
        in_channels_down: number of channels of ``x_down`` (lower level).
        out_channels: number of output channels.

    Raises:
        ValueError: if ``spatial_dims`` is not 2 or 3.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        in_channels_down: int,
        out_channels: int,
        ):
        super().__init__()

        upsample_kernel_size = 2
        stride = 2

        if spatial_dims == 2:
            transp_conv_cls = nn.ConvTranspose2d
        elif spatial_dims == 3:
            transp_conv_cls = nn.ConvTranspose3d
        else:
            raise ValueError("spatial_dims should be 2 or 3.")

        # No bias, as in MONAI's UnetrUpBlock default, so parameter counts can be compared.
        self.upsample = transp_conv_cls(
            in_channels_down,
            in_channels,
            kernel_size=upsample_kernel_size,
            stride=stride,
            bias=False,
        )
        # Concatenation doubles the channel count: upsampled + skip.
        self.res_block = ResBlock(spatial_dims, 2 * in_channels, out_channels)

    def forward(self, x_down, x):
        """
        Args:
            x_down: lower-level tensor with ``in_channels_down`` channels.
            x: skip tensor with ``in_channels`` channels whose spatial size is presumably
                twice that of ``x_down`` - TODO confirm against the caller.

        Returns:
            Tensor with ``out_channels`` channels at the spatial size of ``x``.
        """
        upsampled = self.upsample(x_down)
        merged = torch.cat((upsampled, x), dim=1)
        return self.res_block(merged)


In [None]:
# Reference: MONAI's decoder5 stage, to be compared with our Decoder in the next cell.
print(model.decoder5)
ref_decoder5_trainable = sum(
    param.numel() for param in model.decoder5.parameters() if param.requires_grad
)
print("num of parameters: ", ref_decoder5_trainable)

In [None]:
# Same configuration as MONAI's decoder5 for feature_size=48:
# skip (upper level) 384 channels, lower level 768 channels, 384 output channels.
decoder5 = Decoder(3, 384, 768, 384)
print(decoder5)
decoder5_trainable = sum(
    param.numel() for param in decoder5.parameters() if param.requires_grad
)
print("num of parameters: ", decoder5_trainable)

In [None]:
class mySwinUNETR(nn.Module):
    """
    Swin UNETR based on: "Hatamizadeh et al.,
    Swin UNETR: Swin Transformers for Semantic Segmentation of Brain Tumors in MRI Images
    <https://arxiv.org/abs/2201.01266>"

    Re-implementation of MONAI's ``SwinUNETR`` built from the ``ResBlock`` and ``Decoder``
    modules of this notebook. MONAI's ``SwinTransformer`` is reused as the backbone.
    """

    def __init__(
        self,
        img_size: Union[Sequence[int], int],
        in_channels: int,
        out_channels: int,
        depths: Sequence[int] = (2, 2, 2, 2),
        num_heads: Sequence[int] = (3, 6, 12, 24),
        feature_size: int = 24,
        norm_name: Union[Tuple, str] = "instance",
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        dropout_path_rate: float = 0.0,
        normalize: bool = True,
        use_checkpoint: bool = False,
        spatial_dims: int = 3,
        downsample="merging",
    ) -> None:
        """
        Args:
            img_size: dimension of input image.
            in_channels: dimension of input channels.
            out_channels: dimension of output channels.
            feature_size: dimension of network feature size.
            depths: number of layers in each stage.
            num_heads: number of attention heads.
            norm_name: kept for signature compatibility with MONAI's ``SwinUNETR``; not used
                here (the ResBlock/Decoder stages use their own fixed normalisation).
            drop_rate: dropout rate.
            attn_drop_rate: attention dropout rate.
            dropout_path_rate: drop path rate.
            normalize: normalize output intermediate features in each stage.
            use_checkpoint: use gradient checkpointing for reduced memory usage.
            spatial_dims: number of spatial dims.
            downsample: module used for downsampling, available options are `"mergingv2"`, `"merging"` and a
                user-specified `nn.Module` following the API defined in :py:class:`monai.networks.nets.PatchMerging`.
                The default is currently `"merging"` (the original version defined in v0.9.0).

        Raises:
            ValueError: for an unsupported ``spatial_dims``, an ``img_size`` not divisible by
                the stage-wise resolution, a rate outside [0, 1], or a ``feature_size`` not
                divisible by 12.

        Examples::

            # for 3D single channel input with size (96,96,96), 4-channel output and feature size of 48.
            >>> net = mySwinUNETR(img_size=(96,96,96), in_channels=1, out_channels=4, feature_size=48)

            # for 3D 4-channel input with size (128,128,128), 3-channel output and (2,4,2,2) layers in each stage.
            >>> net = mySwinUNETR(img_size=(128,128,128), in_channels=4, out_channels=3, depths=(2,4,2,2))

            # for 2D 3-channel input with size (96,96), 2-channel output and gradient checkpointing.
            >>> net = mySwinUNETR(img_size=(96,96), in_channels=3, out_channels=2, use_checkpoint=True, spatial_dims=2)

        """

        super().__init__()

        # Validate before spatial_dims is used to build the per-dimension tuples below.
        if spatial_dims not in (2, 3):
            raise ValueError("spatial dimension should be 2 or 3.")

        img_size = ensure_tuple_rep(img_size, spatial_dims)
        patch_size = ensure_tuple_rep(2, spatial_dims)
        window_size = ensure_tuple_rep(7, spatial_dims)

        # The encoder works down to 1/32 of the input resolution (patch_size ** 5),
        # so every side must be divisible by each stage's downsampling factor.
        for m, p in zip(img_size, patch_size):
            for i in range(5):
                if m % np.power(p, i + 1) != 0:
                    raise ValueError("input image size (img_size) should be divisible by stage-wise image resolution.")

        if not (0 <= drop_rate <= 1):
            raise ValueError("dropout rate should be between 0 and 1.")

        if not (0 <= attn_drop_rate <= 1):
            raise ValueError("attention dropout rate should be between 0 and 1.")

        if not (0 <= dropout_path_rate <= 1):
            raise ValueError("drop path rate should be between 0 and 1.")

        if feature_size % 12 != 0:
            raise ValueError("feature_size should be divisible by 12.")

        self.normalize = normalize

        self.swinViT = SwinTransformer(
            in_chans=in_channels,
            embed_dim=feature_size,
            window_size=window_size,
            patch_size=patch_size,
            depths=depths,
            num_heads=num_heads,
            mlp_ratio=4.0,
            qkv_bias=True,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=dropout_path_rate,
            norm_layer=nn.LayerNorm,
            use_checkpoint=use_checkpoint,
            spatial_dims=spatial_dims,
            downsample=downsample,
        )

        # Top branch, applied to the raw input: in_channels -> feature_size (e.g. 4 -> 48).
        self.encoder1 = ResBlock(spatial_dims, in_channels, feature_size)

        # Applied to the patch-embedding output, 1/2 resolution (feature_size channels).
        self.encoder2 = ResBlock(spatial_dims, feature_size, feature_size)

        # Applied to the 1/4-resolution stage output (2 * feature_size channels).
        self.encoder3 = ResBlock(spatial_dims, 2 * feature_size, 2 * feature_size)

        # Applied to the 1/8-resolution stage output (4 * feature_size channels).
        self.encoder4 = ResBlock(spatial_dims, 4 * feature_size, 4 * feature_size)

        # NOTE: there is no "encoder5". The 1/16-resolution stage output (8 * feature_size
        # channels) is passed directly to decoder5 as its skip input; the original author
        # points out that the reference figure is wrong about this.

        # Applied to the 1/32-resolution stage output (16 * feature_size channels).
        self.bottleneck = ResBlock(spatial_dims, 16 * feature_size, 16 * feature_size)

        # Combines the bottleneck output with the level-3 (1/16) swinViT output.
        self.decoder5 = Decoder(
            spatial_dims,
            in_channels=8 * feature_size,
            in_channels_down=16 * feature_size,
            out_channels=8 * feature_size,
        )
        self.decoder4 = Decoder(
            spatial_dims,
            in_channels=4 * feature_size,
            in_channels_down=8 * feature_size,
            out_channels=4 * feature_size,
        )
        self.decoder3 = Decoder(
            spatial_dims,
            in_channels=2 * feature_size,
            in_channels_down=4 * feature_size,
            out_channels=2 * feature_size,
        )
        self.decoder2 = Decoder(
            spatial_dims,
            in_channels=feature_size,
            in_channels_down=2 * feature_size,
            out_channels=feature_size,
        )
        # Last stage: merges with the full-resolution encoder1 output.
        self.decoder1 = Decoder(
            spatial_dims,
            in_channels=feature_size,
            in_channels_down=feature_size,
            out_channels=feature_size,
        )

        # 1x1 convolution head: feature_size -> out_channels logits.
        conv_cls = nn.Conv2d if spatial_dims == 2 else nn.Conv3d
        self.out = conv_cls(
            in_channels=feature_size,
            out_channels=out_channels,
            kernel_size=1,
            bias=True,
        )

    def forward(self, x_in):
        """
        Run the full encoder-decoder.

        Args:
            x_in: input image; assumes (batch, in_channels, *spatial) with every spatial
                side divisible by 32, as enforced for ``img_size`` - TODO confirm.

        Returns:
            Logits with ``out_channels`` channels.
        """
        # Multi-scale backbone features, indexed 0..4 from 1/2 to 1/32 resolution.
        hidden_states_out = self.swinViT(x_in, self.normalize)

        enc0 = self.encoder1(x_in)
        enc1 = self.encoder2(hidden_states_out[0])
        enc2 = self.encoder3(hidden_states_out[1])
        enc3 = self.encoder4(hidden_states_out[2])
        dec4 = self.bottleneck(hidden_states_out[4])

        dec3 = self.decoder5(dec4, hidden_states_out[3])
        dec2 = self.decoder4(dec3, enc3)
        dec1 = self.decoder3(dec2, enc2)
        dec0 = self.decoder2(dec1, enc1)
        out = self.decoder1(dec0, enc0)

        logits = self.out(out)
        return logits

 


    



In [None]:
my_model = mySwinUNETR(img_size=128, in_channels=4, out_channels=3, feature_size=48)

print("num of parameters: ", sum(p.numel() for p in my_model.parameters() if p.requires_grad))

# Sub-modules in the same order as the MONAI listing in the next cell, each line labelled
# with its attribute name so the two outputs can be compared line by line.
# (``bottleneck`` here corresponds to MONAI's ``encoder10``.)
my_module_names = [
    "swinViT", "encoder1", "encoder2", "encoder3", "encoder4", "bottleneck",
    "decoder5", "decoder4", "decoder3", "decoder2", "decoder1", "out",
]

for name in my_module_names:
    module = getattr(my_model, name)
    print(f"{name:<10} num of parameters: ", sum(p.numel() for p in module.parameters() if p.requires_grad))


In [None]:
# Keyword arguments make the mapping explicit instead of relying on positional order.
# NOTE(review): recent MONAI releases deprecate `img_size`; check the installed version.
model = SwinUNETR(img_size=128, in_channels=4, out_channels=3, feature_size=48)


print("num of parameters: ", sum(p.numel() for p in model.parameters() if p.requires_grad))

# Same order as the mySwinUNETR listing; MONAI names the bottleneck block ``encoder10``.
ref_module_names = [
    "swinViT", "encoder1", "encoder2", "encoder3", "encoder4", "encoder10",
    "decoder5", "decoder4", "decoder3", "decoder2", "decoder1", "out",
]
for name in ref_module_names:
    module = getattr(model, name)
    print(f"{name:<10} num of parameters: ", sum(p.numel() for p in module.parameters() if p.requires_grad))

In [None]:
# Side-by-side look at decoder2: ours first, then MONAI's.
for decoder_stage in (my_model.decoder2, model.decoder2):
    print(decoder_stage)

In [None]:
# Side-by-side look at encoder1: ours first, then MONAI's.
for encoder_stage in (my_model.encoder1, model.encoder1):
    print(encoder_stage)

In [None]:
# Compare the two swinViT backbones parameter by parameter, matching by name.
# Pairing by position with zip() would silently drop the extra parameters of the
# longer model and misalign every pair that follows a missing parameter.
ref_params = dict(model.swinViT.named_parameters())
my_params = dict(my_model.swinViT.named_parameters())

# Reference order first, then any names that only exist in my_model.
all_param_names = list(ref_params) + [name for name in my_params if name not in ref_params]
for name in all_param_names:
    ref_param = ref_params.get(name)
    my_param = my_params.get(name)
    ref_shape = None if ref_param is None else ref_param.shape
    my_shape = None if my_param is None else my_param.shape
    if ref_shape != my_shape:
        print(name, ref_shape)
        print(name, my_shape)

print("num of parameters: ", sum(p.numel() for p in model.swinViT.parameters() if p.requires_grad))
print("num of parameters: ", sum(p.numel() for p in my_model.swinViT.parameters() if p.requires_grad))


In [None]:
# List every parameter of the MONAI swinViT backbone together with its shape.
for param_name, param in model.swinViT.named_parameters():
    print(param_name, param.shape)