**Attention U-Net: Learning Where to Look for the Pancreas**    
*Ozan Oktay, et al.*   
[[paper](https://arxiv.org/abs/1804.03999)]   
MIDL 2018   

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class ConvLayer(nn.Module):
    """Two consecutive Conv3d -> BatchNorm3d -> ReLU stages.

    Both convolutions use a 3-wide kernel with stride 1 and padding 1, so
    the spatial size of the input is preserved; only the channel count
    changes (``in_dim`` -> ``hidden_dim`` -> ``out_dim``).

    Args:
        in_dim: Number of channels of the input volume.
        out_dim: Number of channels produced by the second convolution.
        hidden_dim: Number of channels between the two convolutions.
            Any falsy value (``None`` or ``0``) falls back to ``out_dim``.
    """

    def __init__(self, in_dim, out_dim, hidden_dim=None) -> None:
        super().__init__()

        # Falsy hidden_dim means "use the output width for the middle too".
        mid_dim = hidden_dim or out_dim

        # Build the stages in order so parameters are created (and
        # therefore initialised) in the sequence conv/bn, conv/bn.
        stages = []
        for src_dim, dst_dim in ((in_dim, mid_dim), (mid_dim, out_dim)):
            stages += [
                nn.Conv3d(src_dim, dst_dim, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm3d(dst_dim),
                nn.ReLU(),
            ]

        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both conv stages and return the result."""
        return self.conv(x)

In [None]:
# Encoder
class AttnUNet_Encoder(nn.Module):
    """Contracting path of the Attention U-Net.

    Four ``ConvLayer`` stages whose channel width doubles at every stage
    (``base_dim``, 2x, 4x, 8x). A ``MaxPool3d(2, 2)`` sits between
    consecutive stages, halving each spatial dimension.

    Args:
        in_channels: Number of channels of the input volume.
        base_dim: Channel width of the first stage; each later stage
            doubles it.

    Attributes:
        hidden_dim: Channel width of the deepest stage (``base_dim * 8``).
    """

    def __init__(self, in_channels=1, base_dim=16) -> None:
        super(AttnUNet_Encoder, self).__init__()

        # Channel widths must be ints: nn.Conv3d cannot allocate its weight
        # tensor from a float channel count, so avoid true division here.
        dims = [base_dim * (2 ** level) for level in range(4)]

        self.conv1 = ConvLayer(in_channels, dims[0])
        self.conv2 = ConvLayer(dims[0], dims[1])
        self.conv3 = ConvLayer(dims[1], dims[2])
        self.conv4 = ConvLayer(dims[2], dims[3])

        # Kept as an attribute holding the width of the deepest stage.
        self.hidden_dim = dims[3]

        self.pool = nn.MaxPool3d(2, 2)

    def forward(self, x):
        """Encode ``x``.

        Returns:
            A tuple ``(h4, stage_outputs)`` where ``h4`` is the output of
            the deepest stage and ``stage_outputs`` is ``[h1, h2, h3]``,
            the un-pooled outputs of the first three stages, ordered from
            shallowest to deepest.
        """
        h1 = self.conv1(x)
        p1 = self.pool(h1)

        h2 = self.conv2(p1)
        p2 = self.pool(h2)

        h3 = self.conv3(p2)
        p3 = self.pool(h3)

        h4 = self.conv4(p3)

        stage_outputs = [h1, h2, h3]

        return h4, stage_outputs


In [None]:
class UpsamplingLayer(nn.Module):
    """Upsample a 5-D feature volume by a factor of 2 along each spatial axis.

    NOTE(review): the original definition was unfinished (the
    ``kernel_size`` argument had no value, and there was no non-deconv
    branch or ``forward``). The choices below are the minimal completion
    that inverts the encoder's ``MaxPool3d(2, 2)``; confirm they match the
    intended design.

    Args:
        in_dim: Number of channels of the input volume.
        out_dim: Number of channels of the upsampled output.
        is_deconv: If true, upsample with a learned transposed convolution;
            otherwise use trilinear interpolation followed by a 1x1x1
            convolution that maps ``in_dim`` to ``out_dim`` channels.
    """

    def __init__(self, in_dim, out_dim, is_deconv=True) -> None:
        super().__init__()

        if is_deconv:
            # kernel_size == stride == 2 doubles every spatial dimension
            # exactly: (n - 1) * 2 + 2 == 2 * n.
            self.upsampler = nn.ConvTranspose3d(
                in_dim, out_dim, kernel_size=2, stride=2
            )
        else:
            # Interpolation leaves the channel count untouched, so a
            # pointwise convolution is used to reach out_dim.
            # NOTE(review): assumed design for the non-deconv path.
            self.upsampler = nn.Sequential(
                nn.Upsample(scale_factor=2, mode="trilinear", align_corners=False),
                nn.Conv3d(in_dim, out_dim, kernel_size=1),
            )

    def forward(self, x):
        """Return ``x`` upsampled 2x spatially with ``out_dim`` channels."""
        return self.upsampler(x)

In [None]:
# Decoder
class AttnUNet_Decoder(nn.Module):
    def __init__(self) -> None:
        super(AttnUNet_Decoder, self).__init__()
        