In [1]:
# pyright: reportGeneralTypeIssues=false
import torch
import torch.nn as nn
from einops import rearrange

# Implementation of the MedFormer model

Para ver todos los detalles podéis descargar el  [paper](https://arxiv.org/pdf/2203.00131.pdf) 

El diagrama de arquitectura del modelo es el siguiente:

![Arquitectura del modelo](figs/medformer.png)

# Convolutional stem 

El convolutional stem consiste en una convolución seguida de un bloque residual como se muestra en la siguiente figura:

![Convolutional stem](figs/convolutional_stem.png)

Para el bloque residual hay varias opciones, en este caso se ha elegido el siguiente:

![Residual block](figs/residual_block1.png)

Comenzaremos implementando el bloque residual que luego tambien emplearemos en varios puntos, y antes de este la secuencia basica norm->activation->conv:



In [None]:
#pyright: reportGeneralTypeIssues=false
from typing import Union, List

class NormActConv(nn.Module):
    """
    Normalization, activation and convolution layer 
    
    Args:
        spatial_dims (int): number of spatial dimensions
        in_ch (int): number of input channels
        out_ch (int): number of output channels
        kernel_size (int): kernel size for convolution
        stride (int): stride for convolution
        padding (int): padding for convolution
        groups (int): groups for convolution
        dilation (int): dilation for convolution
        bias (bool): bias for convolution
        norm (nn.Module): normalization layer (nn.BatchNorm3d, nn.InstanceNorm3d, nn.Identity) also 2d equivelents
        act (nn.Module): activation layer (nn.ReLU,nn.GELU,nn.Identity)
    """

    def __init__(self, spatial_dims,  in_ch, out_ch, padding = 1, kernel_size=3, stride = 1 ,
        groups=1, dilation=1, bias=False, norm=nn.InstanceNorm3d, act =nn.GELU):

        super().__init__()

        self.norm = norm(in_ch) if norm!=nn.Identity else nn.Identity()
        self.act = act() 

        if spatial_dims == 2:
            self.conv = nn.Conv2d(
                in_channels=in_ch, 
                out_channels=out_ch, 
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                groups=groups,
                dilation=dilation,
                bias=bias
            )
        else:
            self.conv = nn.Conv3d(
                in_channels=in_ch, 
                out_channels=out_ch, 
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                groups=groups,
                dilation=dilation,
                bias=bias
            )
        

    def forward(self, x): 
    
        return  self.conv(self.act(self.norm(x)))
        


Implementación del bloque residual:



In [2]:
#right: reportGeneralTypeIssues=false

class ResidualBlock(nn.Module):
    def __init__(self,spatial_dims,  in_ch, out_ch, kernel_size=[3,3,3], stride=1, norm=nn.InstanceNorm3d, act=nn.GELU, preact=True):
        super().__init__()
        
        pad_size = [i//2 for i in kernel_size]

        self.conv1 = NormActConv(spatial_dims, in_ch, out_ch, padding=pad_size, kernel_size = kernel_size, stride=stride,  norm=norm, act=act)
        self.conv2 = NormActConv(spatial_dims, out_ch, out_ch,  padding=pad_size, kernel_size = kernel_size, stride=1, norm=norm, act=act)

        
        if stride != 1 or in_ch != out_ch:
            # ESTO ES LA VERSION ORIGINAL NO ES MUY ORTODOXO POR LA ACTIVACION
            #self.shortcut = NormActConv(in_ch, out_ch, kernel_size, stride=stride, padding=pad_size, norm=norm, act=act)
            
            # Hay que poner activacion identity
            self.shortcut = NormActConv(spatial_dims, in_ch, out_ch, kernel_size=1, stride=stride, padding=0, bias=False, act = nn.Identity)
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.conv2(out)

        out += self.shortcut(residual)

        return out


Ahora ya tenemos los bloques necesarios para el convolutional stem, que consiste en una convolución seguida de un bloque residual, como en la figura de arriba.

In [None]:
class ConvStem(nn.Module):
    """ Convolutional stem: 
            sequential of conv and redidual block"""
    def __init__(self, spatial_dims,  in_ch, out_ch, kernel_size=3,  norm=nn.InstanceNorm3d, act=nn.GELU):
        super().__init__()

        pad_size = kernel_size // 2 
        #inicializar modulos
        

    def forward(self, x):
        #completar

        return out

# Down block

El down block es el encargado de reducir la dimensión en el encoder y generar los semantic maps:

<img src="figs/down_block.png" width="40%"/>

El primer bloque, patch_merging fusiona parches de la siguiente forma:

<img src="figs/patch_merging.png" width="40%"/>


In [21]:
class DepthwiseSeparableConv(nn.Sequential):
    def __init__(self, spatial_dims, in_ch, out_ch, stride=1, kernel_size=3, bias=False):
        
        
        if isinstance(kernel_size, list):
            padding = [i//2 for i in kernel_size]
        else:
            padding = kernel_size // 2
        if spatial_dims == 2:
            super().__init__(nn.Conv2d(
                in_channels=in_ch,
                out_channels=in_ch,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                groups=in_ch,
                bias=bias),
            nn.Conv2d(
                in_channels=in_ch,
                out_channels=out_ch,
                kernel_size=1,
                stride=1,
                padding=0,
                groups=1,
                bias=bias)
            )
        else:
            super().__init__(nn.Conv3d(
                in_channels=in_ch,
                out_channels=in_ch,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                groups=in_ch,
                bias=bias),
            nn.Conv3d(
                in_channels=in_ch,
                out_channels=out_ch,
                kernel_size=1,
                stride=1,
                padding=0,
                groups=1,
                bias=bias)
            )
    

In [19]:
import numpy as np

class PatchMerging(nn.Module):
    """
    Modified patch merging layer that works as down-sampling

    Args:
        dim (int): number of input channels
        out_dim (int): number of output channels
        norm (nn.Module): normalization layer (nn.BatchNorm3d, nn.InstanceNorm3d, nn.Identity) also 2d equivelents
        proj_type (str): projection type ('linear' or 'depthwise')
        down_scale (list): down-sampling scale for each dimension
        kernel_size (list): kernel size for 
        wise separable convolution
    """
    def __init__(self, dim, out_dim, norm=nn.InstanceNorm3d, proj_type='linear', down_scale=[2,2,2], kernel_size=[3,3,3]):
        super().__init__()
        self.dim = dim
        assert proj_type in ['linear', 'depthwise']

        spatial_dims = len(down_scale)
        self.down_scale = down_scale

        merged_dim = int(np.prod(down_scale) * dim)

        if proj_type == 'linear':
            self.reduction = nn.Conv3d(merged_dim, out_dim, kernel_size=1, bias=False)
        else:
            self.reduction = DepthwiseSeparableConv(spatial_dims, merged_dim, out_dim, kernel_size=kernel_size)

        self.norm = norm(merged_dim)

    def forward(self, x):
        """
        x: B, C, D, H, W
        """
        #merged_x = []
        #for i in range(self.down_scale[0]):
        #    for j in range(self.down_scale[1]):
        #        for k in range(self.down_scale[2]):
        #            tmp_x = x[:, :, i::self.down_scale[0], j::self.down_scale[1], k::self.down_scale[2]]
        #            merged_x.append(tmp_x)
        
        #x = torch.cat(merged_x, 1)

        #reimplementar con rearrange

        x = self.norm(x)
        x = self.reduction(x)

        return x


In [13]:
x = torch.arange(0, 4*4*4).reshape(1, 1, 4, 4, 4)

#completar para comprobar
x1 = rearrange(x, '', d1=2, d2=2, h1=2, h2=2, w1=2, w2=2)
print(x1[0,:,0,0,0])

merged_x = []
for i in range(2):
    for j in range(2):
        for k in range(2):
            tmp_x = x[:, :, i::2, j::2, k::2]
            merged_x.append(tmp_x)
        
x2 = torch.cat(merged_x, 1)

print(x2[0,:,0,0,0])


tensor([ 0,  1,  4,  5, 16, 17, 20, 21])
tensor([ 0,  1,  4,  5, 16, 17, 20, 21])


## Bidirectional Transformer Block

Este bloque tiene Multihead Attention comenzaremos por implementar con single head attention y luego lo extenderemos a multihead attention.

![Bidirectional Transformer Block](figs/bidirectional_transformer.png)


Notas para la implementacion:

* En lugar de hacer dos convoluciones (para Q/K y V) hacemos una con doble numero de canales y separamos (torch.split, torch.chunk)

* Añade dropout en dos lugares: a la matriz de atencion de M_out y a la salida de X_out

In [4]:
from typing import Sequence

class BiDirectionalAtt2D(nn.Module):
    """
    Single Head BiDirectional Attention block in 2D.

    Args:
        feat_dim: Number of input X feature channels.
        map_dim: number of M feature channels. (equal input and output)
        out_channels: number of output X feature channels.
        kernel_size: convolution kernel size.
        attn_drop: dropout rate for M attention map.
        proj_drop: dropout rate for X output.

    """

    def __init__(self, feat_dim, map_dim, head_dim, out_dim,  attn_drop=0., proj_drop=0.,
                     kernel_size=[3,3]):
        super().__init__()

        self.head_dim = head_dim
        self.feat_dim = feat_dim
        self.map_dim = map_dim
        self.scale = feat_dim ** (-0.5)
        self.dim_head = feat_dim
        spatial_dims = 2

        self.X_qv = DepthwiseSeparableConv(spatial_dims,feat_dim, self.head_dim*2, kernel_size=kernel_size)
        self.X_out = DepthwiseSeparableConv(spatial_dims, self.head_dim, out_dim, kernel_size=kernel_size)

        self.M_qv = nn.Conv2d(map_dim, self.head_dim*2, kernel_size=1, stride=1, padding=0, bias=False)
        self.M_out = nn.Conv2d(self.head_dim, map_dim, kernel_size=1, stride=1, padding=0, bias=False)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)


    def forward(self, X, M):

        B, C, H, W = X.shape
        b, c, h, w = M.shape
        
        X_qv = self.X_qv(X)
        M_qv = self.M_qv(M)
        

        X_qk, X_v = torch.chunk(X_qv, 2, dim = 1)  # B, inner_dim, H, W
        M_qk, M_v = torch.split(M_qv,self.head_dim, dim = 1)  # B, inner_dim, H, W

        
        # completar

        return X_out, M_out



        

Comprobacion de la implementacion:

In [5]:
X = torch.rand(5,20, 16,16)
M = torch.rand(5,10,4,4)

biAttn = BiDirectionalAtt2D(20,10,15, 30)

X_out, M_out = biAttn(X,M)

print(X_out.shape, M_out.shape)

torch.Size([5, 256, 15]) torch.Size([5, 16, 15])
torch.Size([5, 30, 16, 16]) torch.Size([5, 10, 4, 4])


Bidirectional attention en 2D y 3D

In [15]:
from typing import Sequence

class BiDirectionalAtt(nn.Module):
    """
    Single Head BiDirectional Attention block in 2D and 3D.

    Args:
        spatial_dims: Number of spatial dimensions.
        feat_dim: Number of input X feature channels.
        map_dim: number of M feature channels. (equal input and output)
        head_dim: dimension of head features
        out_dim: number of output X feature channels.
        kernel_size: convolution kernel size.
        attn_drop: dropout rate for M attention map.
        proj_drop: dropout rate for X output.

    """

    def __init__(self,spatial_dims,  feat_dim, map_dim, head_dim, out_dim,  attn_drop=0., proj_drop=0.,
                     kernel_size=3):
        super().__init__()

        self.spatial_dims = spatial_dims
        self.head_dim = head_dim
        self.feat_dim = feat_dim
        self.map_dim = map_dim
        self.scale = head_dim ** (-0.5)
        self.dim_head = head_dim
        
        self.X_qv = DepthwiseSeparableConv(spatial_dims,feat_dim, self.head_dim*2, kernel_size=kernel_size)
        self.X_out = DepthwiseSeparableConv(spatial_dims, self.head_dim, out_dim, kernel_size=kernel_size)

        if self.spatial_dims ==2:
            self.M_qv = nn.Conv2d(map_dim, self.head_dim*2, kernel_size=1, stride=1, padding=0, bias=False)
            self.M_out = nn.Conv2d(self.head_dim, map_dim, kernel_size=1, stride=1, padding=0, bias=False)
        else:
            self.M_qv = nn.Conv3d(map_dim, self.head_dim*2, kernel_size=1, stride=1, padding=0, bias=False)
            self.M_out = nn.Conv3d(self.head_dim, map_dim, kernel_size=1, stride=1, padding=0, bias=False)
        
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)


    def forward(self, X, M):

        if self.spatial_dims ==2:
            B, C, H, W = X.shape
            b, c, h, w = M.shape
            d = D = None
        else:
            B, C, D, H, W = X.shape
            b, c, d, h, w = M.shape

        X_qv = self.X_qv(X)
        M_qv = self.M_qv(M)
        

        X_qk, X_v = torch.chunk(X_qv, 2, dim = 1)  # B, inner_dim, H, W
        M_qk, M_v = torch.split(M_qv,self.head_dim, dim = 1)  # B, inner_dim, H, W

       #Copiar del anterior y modificar

        return X_out, M_out




In [17]:
X = torch.rand(5,20, 16,16)
M = torch.rand(5,10,4,4)


spatial_dims = 2
feat_dim = 20
map_dim = 10
head_dim = 15
out_dim = 25

biAttn2D = BiDirectionalAtt(spatial_dims,  feat_dim, map_dim, head_dim, out_dim)

X_out, M_out = biAttn2D(X,M)

print(X_out.shape, M_out.shape)


X = torch.rand(5,20, 16, 16,16)
M = torch.rand(5,10,4, 4,4)

spatial_dims = 3
feat_dim = 20
map_dim = 10
head_dim = 15
out_dim = 25
biAttn3D = BiDirectionalAtt(spatial_dims,  feat_dim, map_dim, head_dim, out_dim)

X_out, M_out = biAttn3D(X,M)

print(X_out.shape, M_out.shape)

torch.Size([5, 25, 16, 16]) torch.Size([5, 10, 4, 4])
torch.Size([5, 25, 16, 16, 16]) torch.Size([5, 10, 4, 4, 4])


Multihead bidirectional attention

In [4]:
from typing import Sequence

class MultiHeadBiDirectionalAtt(nn.Module):
    """
    Single Head BiDirectional Attention block in 2D and 3D.

    Args:
        spatial_dims: Number of spatial dimensions.
        feat_dim: Number of input X feature channels.
        num_heads: Number of attention heads.
        head_dim: Number of attention head channels.
        map_dim: number of M feature channels. (equal input and output)
        out_dim: number of output X feature channels.
        attn_drop: dropout rate for M attention map.
        proj_drop: dropout rate for X output.

    """

    def __init__(self,spatial_dims,  feat_dim, map_dim, num_heads, head_dim,  out_dim,  attn_drop=0., proj_drop=0.,
                     kernel_size=3):
        super().__init__()

        self.spatial_dims = spatial_dims
        self.feat_dim = feat_dim
        self.map_dim = map_dim
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.inner_dim = head_dim * num_heads
        self.scale = feat_dim ** (-0.5)
        
        
        self.X_qv = DepthwiseSeparableConv(spatial_dims,feat_dim, self.inner_dim*2, kernel_size=kernel_size)
        self.X_out = DepthwiseSeparableConv(spatial_dims, self.inner_dim, out_dim, kernel_size=kernel_size)

        if self.spatial_dims ==2:
            self.M_qv = nn.Conv2d(map_dim, self.inner_dim*2, kernel_size=1, stride=1, padding=0, bias=False)
            self.M_out = nn.Conv2d(self.inner_dim, map_dim, kernel_size=1, stride=1, padding=0, bias=False)
        else:
            self.M_qv = nn.Conv3d(map_dim, self.inner_dim*2, kernel_size=1, stride=1, padding=0, bias=False)
            self.M_out = nn.Conv3d(self.inner_dim, map_dim, kernel_size=1, stride=1, padding=0, bias=False)
        
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)


    def forward(self, X, M):

        if self.spatial_dims ==2:
            B, C, H, W = X.shape
            b, c, h, w = M.shape
            D = d = None
        elif self.spatial_dims ==3:
            B, C, D, H, W = X.shape
            b, c, d, h, w = M.shape
        else:
            raise NotImplementedError

        X_qv = self.X_qv(X)
        M_qv = self.M_qv(M)
        

        X_qk, X_v = torch.chunk(X_qv, 2, dim = 1)  # B, inner_dim, (D) H, W
        M_qk, M_v = torch.split(M_qv,self.inner_dim, dim = 1)  # B, inner_dim, (D) H, W

        #copiar del anterior y modificar para incluir las heads

        return X_out, M_out



In [29]:
X = torch.rand(5,20, 16,16)
M = torch.rand(5,10,4,4)


spatial_dims = 2
feat_dim = 20
map_dim = 10
head_dim = 15
out_dim = 25
num_heads = 4


biAttn2D = MultiHeadBiDirectionalAtt(spatial_dims,  feat_dim, map_dim, num_heads, head_dim, out_dim)

X_out, M_out = biAttn2D(X,M)

print(X_out.shape, M_out.shape)


X = torch.rand(5,20, 16, 16,16)
M = torch.rand(5,10,4, 4,4)

spatial_dims = 3
feat_dim = 20
map_dim = 10
head_dim = 15
out_dim = 25
num_heads = 5

biAttn3D = MultiHeadBiDirectionalAtt(spatial_dims,  feat_dim, map_dim, num_heads, head_dim, out_dim)

X_out, M_out = biAttn3D(X,M)

print(X_out.shape, M_out.shape)

torch.Size([5, 25, 16, 16]) torch.Size([5, 10, 4, 4])
torch.Size([5, 25, 16, 16, 16]) torch.Size([5, 10, 4, 4, 4])


## Bidirectional Transformer Block

Una vez tenemos la atención bidireccional el transformer bidireccional añade las partes de atencion y Feed Forward como en un transformer normal.

<img src="figs/bidirectional_transformer_layer1.png" width="40%"/>

In [None]:
class BidirectionAttentionBlock(nn.Module):
    """Bidirectional Attention Block
    Args:
        feat_dim (int): number of channels of feature map
        map_dim (int): number of channels of semantic map
        out_dim (int): number of channels of output feature map
        heads (int): number of heads
        dim_head (int): dimension of each head
        norm (nn.Module, optional): normalization layer. Defaults to nn.InstanceNorm3d.
        act (nn.Module, optional): activation layer. Defaults to nn.GELU.
        expansion (int, optional): expansion ratio of MLP. Defaults to 4.
    """
    def __init__(self, feat_dim, map_dim, out_dim, heads, dim_head, norm=nn.InstanceNorm3d, act=nn.GELU,
                expansion=4, attn_drop=0., proj_drop=0., map_size=[8, 8, 8], 
                proj_type='depthwise', kernel_size=[3,3,3]):
        super().__init__()

        spatial_dims = len(map_size)
        
        self.norm1 = nn.LayerNorm(feat_dim)  # norm layer for feature map
        self.norm2 = nn.LayerNorm(map_dim) if norm else nn.Identity() # norm layer for semantic map
        
        self.attn = MultiHeadBiDirectionalAtt(spatial_dims, feat_dim, map_dim, out_dim, heads, dim_head, attn_drop=attn_drop, proj_drop=proj_drop, map_size=map_size, proj_type=proj_type, kernel_size=kernel_size)

        
        if feat_dim != out_dim:
            self.shortcut = NormActConv(spatial_dims, feat_dim, out_dim, 1, padding=0, norm=norm, act=nn.Identity)
        else:
            self.shortcut = nn.Identity()

        if proj_type == 'linear':
            self.feedforward = FusedMBConv(out_dim, out_dim, expansion=expansion, kernel_size=1, act=act, norm=norm)
        else:
            self.feedforward = MBConv(out_dim, out_dim, expansion=expansion, kernel_size=kernel_size, act=act, norm=norm)

    def forward(self, x, semantic_map):
        
        feat = self.norm1(x)
        mapp = self.norm2(semantic_map)

        out, mapp = self.attn(feat, mapp)

        out += self.shortcut(x)
        out = self.feedforward(out)

        mapp += semantic_map

        return out, mapp


In [None]:
class MultiHeadBiDirectionalAttLayer(nn.Module):
    """
    Multi head bi-directional attention layer (repeats MultiHeadBiDirectionalAtt for num_blocks times)

    Args:
        spatial_dims: Number of spatial dimensions.
        feat_dim: Number of input X feature channels.
        num_heads: Number of attention heads.
        head_dim: Number of attention head channels.
        map_dim: number of M feature channels. (equal input and output)
        out_dim: number of output X feature channels.
        attn_drop: dropout rate for M attention map.
        proj_drop: dropout rate for X output.

    """


    def __init__(self, spatial_dims, feat_dim, map_dim, out_dim, num_blocks, heads=4, dim_head=64, expansion=4, attn_drop=0., proj_drop=0., map_size=[8,8,8], proj_type='depthwise', act=nn.GELU, kernel_size=[3,3,3]):
        super().__init__()

        self.spatial_dims = spatial_dims
        dim1 = feat_dim
        dim2 = out_dim

        self.blocks = nn.ModuleList([])
        for i in range(num_blocks):
            self.blocks.append(MultiHeadBiDirectionalAtt(spatial_dims, 
                        feat_dim, map_dim, out_dim, heads, dim_head, attn_drop=attn_drop, proj_drop=proj_drop, kernel_size=kernel_size))
            dim1 = out_dim

    def forward(self, x, semantic_map):
        for block in self.blocks:
            x, semantic_map = block(x, semantic_map)

        
        return x, semantic_map

## Semantic map generation

Seguiremos con el bloque de la Figura 8.

<img src="figs/semantic_map_generation.png" width="30%"/>

In [52]:
class SemanticMapGeneration(nn.Module):
    """
    Semantic Map Generation Module

    Args:
        spatial_dims: Number of spatial dimensions.
        feat_dim: Number of input X feature channels.
        w: Semantic map width.
        h: Semantic map height.
    """

    def __init__(self,spatial_dims,  feat_dim, w, h, d = None):
        super().__init__()

        self.spatial_dims = spatial_dims
        self.feat_dim = feat_dim
        self.w = w
        self.h = h
        self.d = d


        if self.spatial_dims ==2:
            self.weight_map_size = w * h 
            self.conv = nn.Conv2d(self.feat_dim, self.feat_dim + self.w * self.h , kernel_size=3, stride=1, padding=1, bias=False)
        else:
            self.weight_map_size = w * h * d
            self.conv = nn.Conv3d(self.feat_dim, self.feat_dim +  self.w * self.h * self.d, kernel_size=3, stride=1, padding=1, bias=False)
        
        
    def forward(self, x):


        tmp  = self.conv(X)
        
        token_map, weight_map  = torch.split(tmp, [self.feat_dim, self.weight_map_size], dim=1)

        if self.spatial_dims ==2:
            weight_map = rearrange(weight_map, 'b c h w -> b c (h w)')
            token_map = rearrange(token_map, 'b c h w -> b c (h w)')
        else:
            weight_map = rearrange(weight_map, 'b c d w h -> b c (d w h)')
            token_map = rearrange(token_map, 'b c d w h  -> b c (d w h)')
        
        weight_map = torch.softmax(weight_map, dim=-1)

        print(weight_map.shape, token_map.shape)
        out = torch.einsum('bsS, bcS->bcs', weight_map, token_map)
        if self.spatial_dims ==2:
            out = rearrange(out, 'b c (h w) -> b c h w', h=self.h, w=self.w)
        else:
            out = rearrange(out, 'b c (d h w) -> b c d h w', h=self.h, w=self.w, d=self.d)
        return out


        

In [53]:
x = torch.rand(5, 64, 4096)
y = torch.rand(5, 20, 4096)
print(x.shape, y.shape)
torch.einsum('b s S,b c S->b c s', y, x).shape

torch.Size([5, 64, 4096]) torch.Size([5, 20, 4096])


torch.Size([5, 64, 20])

In [56]:
spatial_dims = 2
feat_dim = 20
w = 4
h = 4
d = 4

semanticMap = SemanticMapGeneration(spatial_dims,  feat_dim, w, h, d)

X = torch.rand(5,20, 16, 16)

y = semanticMap(X)
print(y.shape)

torch.Size([5, 16, 256]) torch.Size([5, 20, 256])
torch.Size([5, 20, 4, 4])


# Semantic map fusion

Seguiremos con el bloque de la Figura 8.C


<img src="figs/semantic_map_fusion.png" width="30%"/>

La explicacion de este bloque es la siguiente:
Multi-scale fusion plays a vital role in dense prediction tasks
to combine the high-level semantic and low-level detailed information. The semantic map in BMHA is naturally suitable for multi-scale fusion with a minimal computation overhead. See in Fig.
1 (C), given 2D semantic maps from multiple scales: M1, M2, . . . , Mn, we first flatten them and
concatenate them together into a long 1D token sequence MF . The sequence MF contains all tokens from all scales and is then fed into conventional Transformer blocks for multi-scale semantic
fusion. The fused sequence is then chunked and reshaped back to 2D semantic maps. Unlike previous approaches fuse multi-scale features locally, such as fusing with resized multi-scale feature9 or
with atrous spatial pyramid pooling47, the proposed approach propagates information across all tokens at every scale via the all-to-all attention to form a semantically and spatially global multi-scale
fusion.

Algun comentario:

* Teoricamente cada semantic map de cada resolucion tiene una dimension diferente
* Antes de concatenar los tokens de cada resolucion se proyectan a una dimension comun "dim"
* Tras el transformer cada bloque se lleva a la dimension original de cada resolucion
* en la implementacion original distingue entre dropout de attencion y de FF, aunque luego solo usa el de FF

In [83]:
class SemanticMapFusion(nn.Module):
    """semantic map fusion module
    Args:
        spatial_dims: Number of spatial dimensions.
        in_dim_list: number of input feature channels for each resolution
        dim: Number of feature channels 
        heads: Number of attention heads.
        depth: Number of transformer blocks.

    """
    
    def __init__(self, spatial_dims, in_dim_list, dim, heads, depth=1,  dropout=0.):
        super().__init__()

        self.spatial_dims = spatial_dims
        self.dim = dim

        # project all maps to the same channel num
        self.in_proj = nn.ModuleList([]) # importante cuando guardemos listas de bloques
        for i in range(len(in_dim_list)):
            if self.spatial_dims == 2:
                self.in_proj.append(nn.Conv2d(in_dim_list[i], dim, kernel_size=1, bias=False))
            else:
                self.in_proj.append(nn.Conv3d(in_dim_list[i], dim, kernel_size=1, bias=False))


        self.transformer =  nn.Sequential(
            *[torch.nn.TransformerEncoderLayer(d_model=dim, nhead=heads, dim_feedforward=dim, dropout=dropout, 
                activation='gelu', batch_first=True, norm_first=True) for k in range(4)])

        # project all maps back to their origin channel num
        self.out_proj = nn.ModuleList([])
        for i in range(len(in_dim_list)):
            if self.spatial_dims == 2:
                self.out_proj.append(nn.Conv2d(dim, in_dim_list[i], kernel_size=1, bias=False))
            else:
                self.out_proj.append(nn.Conv3d(dim, in_dim_list[i], kernel_size=1, bias=False))

    def forward(self, map_list):
        if self.spatial_dims == 2:
            b, c, h, w = map_list[0].shape
            d = None
        else:
            b, c, d, h, w = map_list[0].shape

        #proj_maps = [self.in_proj[i](map_list[i]).view(B, self.dim, -1).permute(0, 2, 1) for i in range(len(map_list))]
        #B, L, C where L=DHW
        proj_maps = [in_proj(map) for map, in_proj in  zip(map_list, self.in_proj)]
        if self.spatial_dims == 2:
            proj_maps = [rearrange(map, 'b c h w -> b (h w) c', h=h, w=w) for map in proj_maps]
        else:
            proj_maps = [rearrange(map, 'b c d h w -> b (d h w) c', d=d, h=h, w=w) for map in proj_maps]

        proj_maps = torch.cat(proj_maps, dim=1) #all tokens as a 1D sequence
        attned_maps = self.transformer(proj_maps)

        attend_maps = attned_maps.chunk(len(map_list), dim=1) # split the tensor into maps of different resolutions

        if self.spatial_dims == 2:
            maps_out = [rearrange(attend_map, 'b (h w) c -> b c h w', h=h, w=w) for attend_map in attend_maps]
        else:
            maps_out = [rearrange(attend_map, 'b (d h w) c -> b c d h w', d=d, h=h, w=w) for attend_map in attend_maps]
        
        
        
        maps_out = [out_proj(attend_map) for out_proj, attend_map in zip(self.out_proj, maps_out)]

        return maps_out


In [84]:
m = torch.randn(5, 20, 4, 4)
M = [m, m, m]

sF = SemanticMapFusion(2, [20, 20, 20], 20, 4, depth=4)

Mo = sF(M)


print(Mo[0].shape)


torch.Size([5, 20, 4, 4])


# Downsampling and upsampling


In [None]:

class MedFormerEncoder(nn.Module):
    def __init__(self, in_ch, out_ch, conv_num, trans_num, down_scale=[2,2,2], kernel_size=[3,3,3],
                conv_block=BasicBlock, heads=4, dim_head=64, expansion=1, attn_drop=0., 
                proj_drop=0., map_size=[8,8,8], proj_type='depthwise', norm=nn.BatchNorm3d, 
                act=nn.GELU, map_generate=False, map_dim=None):
        super().__init__()
        

        map_dim = out_ch if map_dim is None else map_dim
        self.map_generate = map_generate
        if map_generate:
            self.map_gen = SemanticMapGeneration(out_ch, map_dim, map_size)

        self.patch_merging = PatchMerging(in_ch, out_ch, norm=norm, proj_type=proj_type, down_scale=down_scale, kernel_size=kernel_size)

        block_list = []
        for i in range(conv_num):
            block_list.append(conv_block(out_ch, out_ch, norm=norm, act=act, kernel_size=kernel_size))
        self.conv_blocks = nn.Sequential(*block_list)

        self.trans_blocks = BasicLayer(out_ch, map_dim, out_ch, num_blocks=trans_num, heads=heads, \
                dim_head=dim_head, norm=norm, act=act, expansion=expansion, attn_drop=attn_drop, \
                proj_drop=proj_drop, map_size=map_size, proj_type=proj_type, kernel_size=kernel_size)

    def forward(self, x):
        x = self.patch_merging(x)

        out = self.conv_blocks(x)

        if self.map_generate:
            semantic_map = self.map_gen(out)
        else:
            semantic_map = None

        out, semantic_map = self.trans_blocks(out, semantic_map)
            

        return out, semantic_map


In [64]:
class MedFormerDecoder(nn.Module):
    def __init__(self, in_ch, out_ch, conv_num, trans_num, up_scale=[2,2,2], kernel_size=[3,3,3], 
                conv_block=BasicBlock, heads=4, dim_head=64, expansion=4, attn_drop=0., proj_drop=0.,
                map_size=[4,8,8], proj_type='depthwise', norm=nn.BatchNorm3d, act=nn.GELU, 
                map_dim=None, map_shortcut=False):
        super().__init__()


        self.map_shortcut = map_shortcut
        map_dim = out_ch if map_dim is None else map_dim
        if map_shortcut:
            self.map_reduction = nn.Conv3d(in_ch+out_ch, map_dim, kernel_size=1, bias=False)
        else:
            self.map_reduction = nn.Conv3d(in_ch, map_dim, kernel_size=1, bias=False)

        self.trans_blocks = BasicLayer(in_ch+out_ch, map_dim, out_ch, num_blocks=trans_num, \

                    heads=heads, dim_head=dim_head, norm=norm, act=act, expansion=expansion, \
                    attn_drop=attn_drop, proj_drop=proj_drop, map_size=map_size,\
                    proj_type=proj_type, kernel_size=kernel_size)

        if trans_num == 0:
            dim1 = in_ch+out_ch
        else:
            dim1 = out_ch

        conv_list = []
        for i in range(conv_num):
            conv_list.append(conv_block(dim1, out_ch, kernel_size=kernel_size, norm=norm, act=act))
            dim1 = out_ch
        self.conv_blocks = nn.Sequential(*conv_list)

    def forward(self, x1, x2, map1, map2=None):
        # x1: low-res feature, x2: high-res feature shortcut from encoder
        # map1: semantic map from previous low-res layer
        # map2: semantic map from encoder shortcut, might be none if we don't have the map from encoder
        
        x1 = F.interpolate(x1, size=x2.shape[-3:], mode='trilinear', align_corners=True)
        feat = torch.cat([x1, x2], dim=1)
        
        if self.map_shortcut and map2 is not None:
            semantic_map = torch.cat([map1, map2], dim=1)
        else:
            semantic_map = map1
        
        if semantic_map is not None:
            semantic_map = self.map_reduction(semantic_map)

        out, semantic_map = self.trans_blocks(feat, semantic_map)
        out = self.conv_blocks(out)

        return out, semantic_map
       



Sequential(
  (0): TransformerEncoderLayer(
    (self_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=20, out_features=20, bias=True)
    )
    (linear1): Linear(in_features=20, out_features=20, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (linear2): Linear(in_features=20, out_features=20, bias=True)
    (norm1): LayerNorm((20,), eps=1e-05, elementwise_affine=True)
    (norm2): LayerNorm((20,), eps=1e-05, elementwise_affine=True)
    (dropout1): Dropout(p=0.1, inplace=False)
    (dropout2): Dropout(p=0.1, inplace=False)
  )
  (1): TransformerEncoderLayer(
    (self_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=20, out_features=20, bias=True)
    )
    (linear1): Linear(in_features=20, out_features=20, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (linear2): Linear(in_features=20, out_features=20, bias=True)
    (norm1): LayerNorm((20,), eps=1e-05, elementwise_affine=True)
  