**Attention U-Net: Learning Where to Look for the Pancreas**    
*Ozan Oktay, et al.*   
[[paper](https://arxiv.org/abs/1804.03999)]   
MIDL 2018   

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class ConvLayer(nn.Module):
    """Two stacked 3x3x3 convolutions, each followed by batch norm and ReLU.

    Both convolutions use stride 1 and padding 1, so the spatial size of the
    input is preserved; only the channel count changes
    (``in_dim -> hidden_dim -> out_dim``).

    Args:
        in_dim: Number of input channels.
        out_dim: Number of output channels.
        hidden_dim: Channel count between the two convolutions. Any falsy
            value (``None`` or ``0``) falls back to ``out_dim``.
    """

    def __init__(self, in_dim, out_dim, hidden_dim=None) -> None:
        super().__init__()

        mid_dim = hidden_dim or out_dim

        # Build (conv -> bn -> relu) once per channel transition so both
        # halves of the block are guaranteed to share the same recipe.
        stages = []
        for src_dim, dst_dim in ((in_dim, mid_dim), (mid_dim, out_dim)):
            stages.extend([
                nn.Conv3d(src_dim, dst_dim, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm3d(dst_dim),
                nn.ReLU(),
            ])

        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the double-convolution stack to ``x`` and return the result."""
        features = self.conv(x)
        return features

In [None]:
# Encoder
class AttnUNet_Encoder(nn.Module):
    """Four-stage contracting path of the Attention U-Net.

    The stages widen the features 1 -> 32 -> 64 -> 128 -> 256 channels. A
    2x2x2 max-pool with stride 2 sits between consecutive stages, so the last
    stage runs on the input downsampled three times.
    """

    def __init__(self) -> None:
        super().__init__()

        base_dim = 32

        # conv1 is built for a single input channel.
        self.conv1 = ConvLayer(1, base_dim)
        self.conv2 = ConvLayer(base_dim, base_dim * 2)
        self.conv3 = ConvLayer(base_dim * 2, base_dim * 4)
        self.conv4 = ConvLayer(base_dim * 4, base_dim * 8)

        # Width of the deepest stage (256).
        self.hidden_dim = base_dim * 8

        self.pool = nn.MaxPool3d(2, 2)

    def forward(self, x):
        """Encode ``x``.

        Returns:
            A tuple ``(bottleneck, skips)`` where ``bottleneck`` is the output
            of the fourth stage and ``skips`` is the list of the first three
            stage outputs (before pooling), ordered from finest to coarsest.
        """
        skips = []
        current = x
        for stage in (self.conv1, self.conv2, self.conv3):
            current = stage(current)
            skips.append(current)
            current = self.pool(current)

        bottleneck = self.conv4(current)
        return bottleneck, skips


In [None]:
class UpsamplingLayer(nn.Module):
    """Double the spatial size of a 3D feature map, mapping in_dim -> out_dim channels.

    Args:
        in_dim: Number of input channels.
        out_dim: Number of output channels.
        is_deconv: If True (default), use a learned transposed convolution
            with kernel 2 and stride 2. If False, use trilinear interpolation
            followed by a 1x1x1 convolution that performs the channel change.
    """

    def __init__(self, in_dim, out_dim, is_deconv=True) -> None:
        super().__init__()

        if is_deconv:
            self.upsampler = nn.ConvTranspose3d(in_dim, out_dim, kernel_size=2, stride=2, padding=0)
        else:
            # nn.Upsample accepts either `size` or `scale_factor`, never both,
            # and `size` is a spatial size, not a channel count. Interpolation
            # also cannot change the number of channels, so a 1x1x1 conv is
            # appended to honour the in_dim -> out_dim contract.
            self.upsampler = nn.Sequential(
                nn.Upsample(scale_factor=2, mode='trilinear', align_corners=False),
                nn.Conv3d(in_dim, out_dim, kernel_size=1, stride=1, padding=0),
            )

    def forward(self, x):
        """Return ``x`` upsampled by a factor of 2 along each spatial axis."""
        return self.upsampler(x)

In [None]:
class AttentionGate(nn.Module):
    """Additive attention gate producing a per-voxel weight map for a skip connection.

    The skip features are projected with a stride-2 convolution so that they
    line up with the gating signal, which is expected to have half the
    spatial resolution of the skip features. The two projections are summed,
    passed through ReLU, reduced to one channel, squashed with a sigmoid and
    finally interpolated back to the skip-feature resolution.

    Args:
        in_dim: Channels of the skip features (``inputs``).
        coarser_dim: Channels of the coarser gating features (``coarser``).
        hidden_dim: Channels of the intermediate additive-attention space.
    """

    def __init__(self, in_dim, coarser_dim, hidden_dim) -> None:
        super().__init__()

        # The 1x1x1 conv keeps coarser_dim channels, so the batch norm that
        # follows must be sized for coarser_dim as well (not hidden_dim).
        self.GridGateSignal_generator = nn.Sequential(
                nn.Conv3d(coarser_dim, coarser_dim, kernel_size=1, stride=1, padding=0),
                nn.BatchNorm3d(coarser_dim),
                nn.ReLU()
            )

        # w_x: strided projection of the skip features (no bias; w_g carries it).
        # w_g: pointwise projection of the gating signal.
        # psi: collapses the joint features into a single attention channel.
        self.w_x = nn.Conv3d(in_dim, hidden_dim, kernel_size=(2,2,2), stride=(2,2,2), padding=0, bias=False)
        self.w_g = nn.Conv3d(coarser_dim, hidden_dim, kernel_size=1, stride=1, padding=0, bias=True)
        self.psi = nn.Conv3d(hidden_dim, 1, kernel_size=1, stride=1, padding=0)

    def forward(self, inputs, coarser):
        """Compute attention coefficients for ``inputs`` gated by ``coarser``.

        Args:
            inputs: Skip features with ``in_dim`` channels.
            coarser: Gating features with ``coarser_dim`` channels; their
                spatial size must equal that of ``inputs`` after a
                kernel-2/stride-2 convolution.

        Returns:
            A single-channel map with the spatial size of ``inputs`` and
            values in ``[0, 1]``, meant to be multiplied with ``inputs``.
        """
        query = self.GridGateSignal_generator(coarser)

        proj_x = self.w_x(inputs)
        proj_g = self.w_g(query)

        additive = F.relu(proj_x + proj_g)

        # Sigmoid bounds the coefficients to [0, 1]; without it the gate could
        # arbitrarily rescale or flip the sign of the skip features.
        attn_coef = torch.sigmoid(self.psi(additive))

        # Bring the coefficients back to the resolution of the skip features.
        attn_coef = F.interpolate(attn_coef, size=inputs.shape[2:], mode='trilinear', align_corners=False)

        return attn_coef

In [None]:
# Decoder
class AttnUNet_Decoder(nn.Module):
    """Three-stage expanding path of the Attention U-Net with gated skip connections.

    At every stage the skip features are re-weighted by an attention gate
    driven by the coarser decoder features, the coarser features are
    upsampled, both are concatenated along the channel axis and fused by a
    ConvLayer.

    With the default ``hidden_dim=32`` the expected channel counts match
    AttnUNet_Encoder: skips of 32 / 64 / 128 channels and a 256-channel
    bottleneck.

    Args:
        hidden_dim: Channel width of the finest stage; each deeper stage
            doubles it.
        out_dim: Number of channels of the final output.
    """

    def __init__(self, hidden_dim=32, out_dim=2) -> None:
        super().__init__()

        self.hidden_dim = hidden_dim
        self.out_dim    = out_dim

        # Stage 1 (finest): skip has hidden_dim channels, gate has 2*hidden_dim.
        # NOTE(review): ConvLayer ends with BatchNorm + ReLU, so the network
        # output is non-negative rather than raw logits - confirm this is the
        # intended head for the loss being used.
        self.attn1  = AttentionGate(self.hidden_dim, self.hidden_dim*2, self.hidden_dim)
        self.up1    = UpsamplingLayer(self.hidden_dim*2, self.hidden_dim)
        self.conv1  = ConvLayer(self.hidden_dim*2, self.out_dim, self.hidden_dim)

        # Stage 2
        self.hidden_dim *= 2
        self.attn2  = AttentionGate(self.hidden_dim, self.hidden_dim*2, self.hidden_dim)
        self.up2    = UpsamplingLayer(self.hidden_dim*2, self.hidden_dim)
        self.conv2  = ConvLayer(self.hidden_dim*2, self.hidden_dim, self.hidden_dim)

        # Stage 3 (coarsest): gated directly by the encoder bottleneck.
        # self.hidden_dim is left at this stage's width after construction.
        self.hidden_dim *= 2
        self.attn3  = AttentionGate(self.hidden_dim, self.hidden_dim*2, self.hidden_dim)
        self.up3    = UpsamplingLayer(self.hidden_dim*2, self.hidden_dim)
        self.conv3  = ConvLayer(self.hidden_dim*2, self.hidden_dim, self.hidden_dim)

    def forward(self, enc_out, stage_outputs):
        """Decode the bottleneck using attention-gated skip connections.

        Args:
            enc_out: Bottleneck features from the encoder.
            stage_outputs: Sequence of the three encoder skip features,
                ordered from finest to coarsest.

        Returns:
            Feature map with ``out_dim`` channels at the resolution of
            ``stage_outputs[0]``.
        """
        # Coarsest stage: the gate is computed from the features *before*
        # upsampling, then applied to the skip connection.
        attn_g3 = self.attn3(stage_outputs[-1], enc_out) * stage_outputs[-1]
        h3      = self.up3(enc_out)
        h3      = torch.cat([attn_g3, h3], dim=1)
        h3      = self.conv3(h3)

        attn_g2 = self.attn2(stage_outputs[-2], h3) * stage_outputs[-2]
        h2      = self.up2(h3)
        h2      = torch.cat([attn_g2, h2], dim=1)
        h2      = self.conv2(h2)

        attn_g1 = self.attn1(stage_outputs[-3], h2) * stage_outputs[-3]
        h1      = self.up1(h2)
        h1      = torch.cat([attn_g1, h1], dim=1)
        h1      = self.conv1(h1)

        return h1
        