This notebook contains a PyTorch (Lightning) implementation of U-Net, with code for training the model (and for running pretrained models), based on the paper:

U-Net: Convolutional Networks for Biomedical Image Segmentation (O. Ronneberger et al., MICCAI 2015)

In [None]:
from pathlib import Path
import pytorch_lightning as pl
from fastmri.data.subsample import create_mask_for_mask_type
from fastmri.data.transforms import UnetDataTransform
from fastmri.pl_modules import FastMriDataModule, UnetModule

## K-space mask for undersampling the input data

In [None]:
# Mask type names to choose from; the first entry ("random") is selected.
mask_types = ["random", "equispaced", "equispaced_fraction", "magic", "magic_fraction"]
mask_type = mask_types[0]

In [None]:
# Fraction (not a count) of the central k-space lines to include in the mask.
center_fractions = [0.09]

In [None]:
# Undersampling acceleration factor(s) passed to the mask function.
accelerations = [4]

In [None]:
# Build the k-space mask function from the chosen type, center fractions and
# accelerations, then display its type.
mask = create_mask_for_mask_type(mask_type, center_fractions, accelerations)
type(mask)

### Use random masks for the train transform and fixed masks for the val transform

In [None]:
# Data-specific parameters
data_path = Path("../data/singlecoil_datasets")
test_path = data_path / "singlecoil_test"
challenge = "singlecoil"
test_split = "test"

# Subsampling controls. `sample_rate` is the fraction of slices to use
# (train split only); None means all slices are used. It cannot be set
# together with `volume_sample_rate`.
sample_rate = None
val_sample_rate = test_sample_rate = None
volume_sample_rate = val_volume_sample_rate = test_volume_sample_rate = None

use_dataset_cache_file = True
combine_train_val = False

# Data loader arguments
batch_size, num_workers = 1, 4

In [None]:
# Training transform: masks are applied with use_seed=False, i.e. unseeded.
train_transform = UnetDataTransform(challenge, mask_func=mask, use_seed=False)
train_transform

In [None]:
# Validation transform: same mask function, but keeps UnetDataTransform's
# default `use_seed` (presumably seeded -> fixed masks; verify in fastmri).
val_transform = UnetDataTransform(challenge, mask_func=mask)

In [None]:
# Test transform: no mask function is passed.
test_transform = UnetDataTransform(challenge)

In [None]:
# Build the Lightning data module. Every data setting from the parameters cell
# is passed explicitly, so changing any of them there actually takes effect.
data_module = FastMriDataModule(
    data_path=data_path,
    challenge=challenge,
    train_transform=train_transform,
    val_transform=val_transform,
    test_transform=test_transform,
    combine_train_val=combine_train_val,
    test_split=test_split,
    test_path=test_path,
    sample_rate=sample_rate,
    val_sample_rate=val_sample_rate,
    test_sample_rate=test_sample_rate,
    volume_sample_rate=volume_sample_rate,
    val_volume_sample_rate=val_volume_sample_rate,
    test_volume_sample_rate=test_volume_sample_rate,
    use_dataset_cache_file=use_dataset_cache_file,
    batch_size=batch_size,
    num_workers=num_workers,
    distributed_sampler=False,  # not training with DDP in this notebook
)
data_module.challenge

In [None]:
# Prepare the data module, confirming the datasets are reachable before training.
data_module.prepare_data()

## U-Net Model

In [None]:
###############################
# U-Net Model Hyperparameters #
###############################
in_chans = 1          # number of input channels to U-Net
out_chans = 1         # number of output channels of U-Net
chans = 32            # number of top-level U-Net channels
num_pool_layers = 4   # number of U-Net pooling layers
drop_prob = 0.0       # dropout probability
lr = 0.001            # RMSProp learning rate
lr_step_size = 40     # step size (in epochs) of the learning-rate decay schedule
lr_gamma = 0.1        # factor applied to the learning rate at each decay step
weight_decay = 0.0    # weight decay regularization strength

In [None]:
# Wrap the U-Net architecture and its optimization settings in a Lightning module.
model = UnetModule(
    # architecture
    in_chans=in_chans,
    out_chans=out_chans,
    chans=chans,
    num_pool_layers=num_pool_layers,
    drop_prob=drop_prob,
    # optimization
    lr=lr,
    lr_step_size=lr_step_size,
    lr_gamma=lr_gamma,
    weight_decay=weight_decay,
)

## Trainer

In [None]:
# `seed` is not a Trainer argument, so seed the RNGs explicitly. Lightning's
# seed_everything seeds Python, NumPy and torch (and, with workers=True, the
# dataloader workers). Without this, `deterministic=True` below still gives
# different results from run to run.
seed = 42
pl.seed_everything(seed, workers=True)

trainer_config = dict(
    # gpus=1,                    # uncomment to train on one GPU
    replace_sampler_ddp=False,   # this is necessary for volume dispatch during val
    strategy=None,               # what distributed version to use
    deterministic=True,          # makes things slower, but deterministic
    default_root_dir="../logs",  # directory for logs and checkpoints
    max_epochs=50,               # max number of epochs
)

In [None]:
# Instantiate the Lightning trainer from the configuration dict.
trainer = pl.Trainer(**trainer_config)

## Run Training

In [None]:
# Train the model, with the data module supplying the dataloaders.
trainer.fit(model, datamodule=data_module)