## U-Net: Convolutional Networks for Biomedical Image Segmentation (O. Ronneberger et al., 2015)

In [None]:
from pathlib import Path
import pytorch_lightning as pl
from fastmri.data.subsample import create_mask_for_mask_type
from fastmri.data.transforms import UnetDataTransform
from fastmri.pl_modules import FastMriDataModule, UnetModule

import torch
from torch.nn import functional as F

import numpy as np
import pytorch_lightning as pl
from torch import nn
from torchmetrics.metric import Metric

import fastmri
from fastmri.models.unet import ConvBlock, TransposeConvBlock
from fastmri import evaluate
from fastmri.pl_modules import MriModule
from fastmri.pl_modules.mri_module import DistributedMetricSum

## The code below was adapted from the fastMRI GitHub repository in order to modify the U-Net

# U-Net with Attention Gates

In [None]:
class attention_gate(nn.Module):
    """Additive attention gate for U-Net skip connections.

    Sources:
        https://arxiv.org/pdf/1804.03999.pdf (Oktay et al., "Attention U-Net")
        https://idiotdeveloper.com/attention-unet-in-pytorch/

    The gating signal ``g`` and the skip-connection features ``x`` are each
    projected to ``F_int`` channels by a 1x1 convolution followed by
    InstanceNorm2d. The two projections are summed and passed through a ReLU.
    A second 1x1 convolution + InstanceNorm2d then reduces the result to a
    single channel, and a sigmoid squashes it. The resulting per-pixel
    coefficients in (0, 1) rescale ``x``.

    Args:
        F_g: Number of channels of the gating signal ``g``.
        F_l: Number of channels of the skip-connection features ``x``.
        F_int: Number of channels of the intermediate projection space.
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()

        def conv1x1_norm(in_channels, out_channels):
            # 1x1 convolution (a per-pixel linear map) followed by instance norm.
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0),
                nn.InstanceNorm2d(out_channels),
            )

        # Attribute names are part of the state_dict keys; keep them stable.
        self.Wg = conv1x1_norm(F_g, F_int)
        self.Wx = conv1x1_norm(F_l, F_int)
        self.relu = nn.ReLU(inplace=True)
        self.psi = conv1x1_norm(F_int, 1)
        self.Sigmoid = nn.Sigmoid()

    def forward(self, g, x):
        """Scale ``x`` by attention coefficients computed from ``g`` and ``x``.

        Args:
            g: Gating signal (the decoder features), with F_g channels.
            x: Skip-connection features from the encoder, with F_l channels.
                Must match ``g`` spatially, because the two projections are
                added elementwise.

        Returns:
            ``x`` multiplied by a single-channel map of coefficients in (0, 1).
        """
        combined = self.relu(self.Wg(g) + self.Wx(x))
        coefficients = self.Sigmoid(self.psi(combined))
        return x * coefficients

In [None]:
class AttentionUnet(nn.Module):
    """
    PyTorch implementation of a U-Net model with attention-gated skip connections.

    The encoder/decoder layout follows fastmri's U-Net: ConvBlock /
    TransposeConvBlock stages with average-pool down-sampling. The difference
    is that every skip connection passes through an ``attention_gate`` before
    it is concatenated with the up-sampled decoder features. The up-sampled
    decoder features serve as the gating signal.
    """

    def __init__(
        self,
        in_chans: int,
        out_chans: int,
        chans: int = 32,
        num_pool_layers: int = 4,
        drop_prob: float = 0.0,
    ):
        """
        Args:
            in_chans: Number of channels in the input to the U-Net model.
            out_chans: Number of channels in the output to the U-Net model.
            chans: Number of output channels of the first convolution layer.
            num_pool_layers: Number of down-sampling and up-sampling layers.
            drop_prob: Dropout probability.
        """
        super().__init__()

        self.in_chans = in_chans
        self.out_chans = out_chans
        self.chans = chans
        self.num_pool_layers = num_pool_layers
        self.drop_prob = drop_prob

        # Encoder: the channel count doubles at every level.
        self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)])
        ch = chans
        for _ in range(num_pool_layers - 1):
            self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob))
            ch *= 2

        # Bottleneck between encoder and decoder.
        self.conv = ConvBlock(ch, ch * 2, drop_prob)

        self.up_conv = nn.ModuleList()
        self.up_transpose_conv = nn.ModuleList()
        # One attention gate per decoder level. At each level, the up-sampled
        # decoder output and the matching encoder features both carry `ch`
        # channels; the gate projects them to ch // 2 intermediate channels.
        self.attention_gates = nn.ModuleList()

        for _ in range(num_pool_layers - 1):
            self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch))
            self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob))
            self.attention_gates.append(attention_gate(ch, ch, ch // 2))
            ch //= 2

        # The last decoder level also maps to the requested number of output channels.
        self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch))
        self.up_conv.append(
            nn.Sequential(
                ConvBlock(ch * 2, ch, drop_prob),
                nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1),
            )
        )
        self.attention_gates.append(attention_gate(ch, ch, ch // 2))

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """
        Args:
            image: Input 4D tensor of shape `(N, in_chans, H, W)`.

        Returns:
            Output tensor of shape `(N, out_chans, H, W)`.
        """
        stack = []
        output = image

        # apply down-sampling layers, remembering each level for the skip connections
        for layer in self.down_sample_layers:
            output = layer(output)
            stack.append(output)
            output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0)

        output = self.conv(output)

        # apply up-sampling layers
        for transpose_conv, conv, attention in zip(
            self.up_transpose_conv, self.up_conv, self.attention_gates
        ):
            skip = stack.pop()
            output = transpose_conv(output)

            # avg_pool2d floors odd sizes, so the up-sampled tensor can be one
            # pixel short of its skip connection. Reflect-pad right/bottom
            # BEFORE gating: the attention gate adds its two inputs elementwise
            # and fails on mismatched spatial shapes.
            padding = [0, 0, 0, 0]
            if output.shape[-1] != skip.shape[-1]:
                padding[1] = 1  # padding right
            if output.shape[-2] != skip.shape[-2]:
                padding[3] = 1  # padding bottom
            if any(padding):
                output = F.pad(output, padding, "reflect")

            # Re-weight the skip connection using the decoder features as the gate.
            skip = attention(output, skip)

            output = torch.cat([output, skip], dim=1)
            output = conv(output)

        return output

## Training

### K-space undersampling mask applied to the input data

In [None]:
# Undersampling patterns accepted by create_mask_for_mask_type (used below).
mask_types = ["random", "equispaced", "equispaced_fraction", "magic", "magic_fraction"]

# This experiment uses the first one: "random".
mask_type = mask_types[0]

In [None]:
# Fraction (not count) of low-frequency, center k-space lines to keep fully
# sampled in the mask. fastmri expects this list to have the same length as
# `accelerations`.
center_fractions = [0.09]

In [None]:
# Undersampling (acceleration) factor(s) for the mask: R = 4 keeps roughly a
# quarter of the k-space lines.
accelerations = [4]

In [None]:
# Build the k-space undersampling mask function from the settings above.
mask = create_mask_for_mask_type(mask_type, center_fractions, accelerations)
type(mask)  # display which mask class was created

## Datasets

In [None]:
# Data-specific parameters
challenge = "singlecoil"  # fastMRI challenge track
test_split = "test"       # name of the split used for testing
# Root data folder (presumably holds the <challenge>_train / _val splits — confirm).
data_path = Path('../data/')
# The test split is given its own explicit folder.
test_path = Path('../data/singlecoil_test')

In [None]:
# Dataset subsampling; None means "use all available data".
# sample_rate: fraction of slices to use (train split only); it cannot be set
# together with volume_sample_rate.
sample_rate = None
val_sample_rate = None          # slice fraction for the validation split
test_sample_rate = None         # slice fraction for the test split
volume_sample_rate = None       # fraction of whole volumes (train split)
val_volume_sample_rate = None   # fraction of whole volumes (validation split)
test_volume_sample_rate = None  # fraction of whole volumes (test split)
use_dataset_cache_file = True   # presumably reuses a cached index of the dataset files
combine_train_val = False       # keep the train and validation splits separate

# Data loader arguments
batch_size = 1
num_workers = 0  # 0 = no worker processes; data is loaded in the main process

### Use random masks for the train transform and fixed masks for the val transform

In [None]:
# Training: use_seed=False, so the mask is not seeded and can differ each time
# a slice is loaded (random masks act as data augmentation).
train_transform = UnetDataTransform(challenge, use_seed=False, mask_func=mask)
train_transform

In [None]:
# Validation: seeded masks (use_seed=True, the default), so every evaluation
# sees the same mask for a given volume.
val_transform = UnetDataTransform(challenge, mask_func=mask, use_seed=True)

In [None]:
# Test: no mask function, so no retrospective undersampling is applied
# (presumably because the fastMRI test data is already undersampled — confirm).
test_transform = UnetDataTransform(challenge, mask_func=None)

In [None]:
# Pass every data option defined in the "Datasets" cells, so that changing any
# of them actually takes effect.
# NOTE(review): these keyword arguments match recent fastmri releases; older
# versions may not accept all of them.
data_module = FastMriDataModule(
    data_path=data_path,
    challenge=challenge,
    train_transform=train_transform,
    val_transform=val_transform,
    test_transform=test_transform,
    combine_train_val=combine_train_val,
    test_split=test_split,
    test_path=test_path,
    sample_rate=sample_rate,
    val_sample_rate=val_sample_rate,
    test_sample_rate=test_sample_rate,
    volume_sample_rate=volume_sample_rate,
    val_volume_sample_rate=val_volume_sample_rate,
    test_volume_sample_rate=test_volume_sample_rate,
    use_dataset_cache_file=use_dataset_cache_file,
    batch_size=batch_size,
    num_workers=num_workers,
    distributed_sampler=False,  # bool flag; this notebook trains on a single device
)
data_module.challenge

In [None]:
# Touch the datasets once up front, so missing-data or cache problems should
# surface here rather than in the middle of trainer.fit().
data_module.prepare_data()

## UNet Model

In [None]:
##############################
# UNet Model Hyperparameters #
##############################
in_chans = 1          # number of input channels to the U-Net
out_chans = 1         # number of output channels of the U-Net
chans = 32            # number of output channels of the first U-Net convolution
num_pool_layers = 4   # number of U-Net down-sampling / up-sampling levels
drop_prob = 0.0       # dropout probability
lr = 0.001            # RMSprop learning rate
lr_step_size = 40     # step size (in epochs) of the learning-rate decay schedule
lr_gamma = 0.1        # multiplicative factor applied to the LR at each decay step
weight_decay = 0.0    # weight decay regularization strength

In [None]:
# Constructor arguments shared by both Lightning modules.
unet_kwargs = dict(
    in_chans=in_chans,
    out_chans=out_chans,
    chans=chans,
    num_pool_layers=num_pool_layers,
    drop_prob=drop_prob,
    lr=lr,
    lr_step_size=lr_step_size,
    lr_gamma=lr_gamma,
    weight_decay=weight_decay,
)

# Baseline: fastmri's plain U-Net.
unet_model = UnetModule(**unet_kwargs)

# Attention U-Net: reuse UnetModule's training/validation/optimizer logic and
# swap its network for the AttentionUnet class defined in this notebook.
# NOTE(review): relies on fastmri's UnetModule storing its network as
# `self.unet` and calling it in forward() — confirm for the installed fastmri
# version. The saved hyperparameters still describe a plain Unet, so
# UnetModule.load_from_checkpoint presumably cannot restore this model. Rebuild
# it this way and load the state_dict instead.
att_unet_model = UnetModule(**unet_kwargs)
att_unet_model.unet = AttentionUnet(
    in_chans=in_chans,
    out_chans=out_chans,
    chans=chans,
    num_pool_layers=num_pool_layers,
    drop_prob=drop_prob,
)

# The model handed to trainer.fit(...) in the "Run Training" section.
model = att_unet_model

## Trainer

In [None]:
trainer_config = dict(
    # replace_sampler_ddp=False,  # needed for volume dispatch during val under DDP
    # strategy="ddp",             # distributed strategy for multi-device runs
    # NOTE: seeding is not a Trainer argument; call pl.seed_everything(42) instead.
    accelerator="cpu",            # hardware backend
    devices=1,                    # number of devices (here: CPU processes) to use
    deterministic=True,           # deterministic algorithms: slower but reproducible
    default_root_dir="../logs",   # directory for logs and checkpoints
    max_epochs=10,                # max number of epochs
)

In [None]:
# Build the Lightning trainer from the options collected in `trainer_config`.
trainer = pl.Trainer(**trainer_config)

## Run Training

In [None]:
# Train `model` (must be bound in an earlier cell) on the fastMRI data module.
trainer.fit(model=model, datamodule=data_module)