In [1]:
import os

# Work from the repository root, i.e. the parent of the directory the notebook
# was started in, so that the local modules imported below (`data`, `evaluate`,
# `pdnet`) can be found.
# The starting directory is remembered on the first run, so re-running this
# cell is idempotent instead of climbing one more level each time like `%cd ..`.
if 'notebook_dir' not in globals():
    notebook_dir = os.getcwd()
os.chdir(os.path.dirname(notebook_dir))
os.getcwd()

/volatile/home/Zaccharie/workspace/fastmri-reproducible-benchmark


In [2]:
%load_ext autoreload
%autoreload 2
%matplotlib nbagg
import os
import os.path as op
from keras.backend.tensorflow_backend import set_session
from keras.callbacks import TensorBoard, ModelCheckpoint, ReduceLROnPlateau
from keras.models import load_model
from keras.utils.vis_utils import model_to_dot
from keras_tqdm import TQDMNotebookCallback
from keras import backend as K
import tensorflow as tf
from tqdm import tqdm_notebook

from data import MaskShifted2DSequence, MaskShiftedSingleImage2DSequence
from evaluate import psnr, ssim
from pdnet import pdnet, invnet

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [3]:
# Configure the TensorFlow 1.x session that Keras will run in.
tf.logging.set_verbosity(tf.logging.INFO)
config = tf.ConfigProto(
    # grow GPU memory usage on demand instead of reserving it all up front
    gpu_options=tf.GPUOptions(allow_growth=True),
    # log on which device each op runs
    # (nothing gets printed in Jupyter, only if you run it standalone)
    log_device_placement=True,
)
sess = tf.Session(config=config)
# register this session as Keras' default session
set_session(sess)

In [4]:
# Dataset locations.
# The root defaults to the original author's mount point. Set the
# FASTMRI_DATA_DIR environment variable to point the notebook at another copy
# of the data without editing this cell.
# The trailing '' component keeps the trailing slash the original paths had;
# the data sequences may rely on it — TODO confirm before removing.
data_root = os.environ.get('FASTMRI_DATA_DIR', '/media/Zaccharie/UHRes')
train_path = op.join(data_root, 'singlecoil_train', 'singlecoil_train', '')
val_path = op.join(data_root, 'singlecoil_val', '')
test_path = op.join(data_root, 'singlecoil_test', '')

In [5]:
# Hard-coded dataset sizes: number of samples (presumably 2D slices — confirm)
# and number of volumes in the train / validation folders.
# They must be updated by hand if the data folders change.
n_samples_train, n_samples_val = 34_742, 7_135
n_volumes_train, n_volumes_val = 973, 199

In [6]:
# Training / validation data sequences, built with the same `af` setting.
AF = 4  # passed as `af=`; presumably the k-space acceleration factor
train_gen, val_gen = (
    MaskShiftedSingleImage2DSequence(train_path, af=AF),
    MaskShiftedSingleImage2DSequence(val_path, af=AF),
)

In [21]:
# Network hyper-parameters, forwarded to both `pdnet` and `invnet` below.
# Smaller debug setting used previously: n_primal=2, n_dual=2, n_iter=2, n_filters=8.
run_params = dict(
    n_primal=5,
    n_dual=5,
    n_iter=10,
    n_filters=32,
)

n_epochs = 2
run_id = f'pdnet_af{AF}'
# The doubled braces leave a literal '{epoch:02d}' placeholder in the path,
# which ModelCheckpoint fills in when it saves.
chkpt_path = f'checkpoints/{run_id}-{{epoch:02d}}.hdf5'

In [28]:
# Checkpoint every `chkpt_period` epochs.
# The period is capped at `n_epochs` so that short runs still save at least the
# last epoch: with a fixed period of 50 and n_epochs = 2, no checkpoint was
# ever written.
chkpt_period = max(1, min(50, n_epochs))
chkpt_cback = ModelCheckpoint(chkpt_path, period=chkpt_period)
log_dir = op.join('logs', run_id)
tboard_cback = TensorBoard(
    log_dir=log_dir,
    histogram_freq=0,
    write_graph=True,
    write_images=False,
)
# Lower the LR when val_loss stops improving for `patience` epochs, never below min_lr.
# NOTE(review): with n_epochs = 2 and patience = 3 this can never trigger.
lr_on_plat_cback = ReduceLROnPlateau(monitor='val_loss', min_lr=5*1e-5, mode='auto', patience=3)
# Fix from https://github.com/bstriner/keras-tqdm/issues/31#issuecomment-516325065.
# Only the `metric_format` argument is active; the method re-bindings below are
# currently disabled.
tqdm_cb = TQDMNotebookCallback(metric_format="{name}: {value:e}")
# tqdm_cb.on_train_batch_begin = tqdm_cb.on_batch_begin
# tqdm_cb.on_train_batch_end = tqdm_cb.on_batch_end
# tqdm_cb.on_test_begin = lambda x,y:None
# tqdm_cb.on_test_end = lambda x,y:None

In [31]:
# Build the PDNet model and the `invnet` model it is compared against, with the
# same hyper-parameters but different learning rates.
model = pdnet(lr=1e-3, **run_params)
# NOTE(review): `simple_model` is only used by the commented-out experiment cells below.
simple_model = invnet(lr=1e-2, **run_params)
# `summary` prints the table itself and returns None; wrapping it in print()
# only appended a spurious "None" to the output.
model.summary(line_length=150)

______________________________________________________________________________________________________________________________________________________
Layer (type)                                     Output Shape                     Param #           Connected to                                      
kspace_input (InputLayer)                        (None, 640, None, 1)             0                                                                   
______________________________________________________________________________________________________________________________________________________
buffer_primal (Lambda)                           (None, 640, None, 5)             0                 kspace_input[0][0]                                
______________________________________________________________________________________________________________________________________________________
mask_input (InputLayer)                          (None, 640, None)                0           

In [None]:
# Train PDNet. Progress is reported by the tqdm notebook callback, hence verbose=0.
# The returned History is stored in `history` so that per-epoch loss / val_loss
# can be inspected later without relying on `_` / Out[...].
# NOTE(review): steps_per_epoch / validation_steps are set to the sample counts;
# confirm this matches what one step of train_gen / val_gen yields
# (compare with len(train_gen) / len(val_gen)).
# NOTE(review): workers=35 and max_queue_size=100 are tuned to the original machine.
history = model.fit_generator(
    train_gen,
    steps_per_epoch=n_samples_train,
    epochs=n_epochs,
    validation_data=val_gen,
    validation_steps=n_samples_val,
    verbose=0,
    callbacks=[tqdm_cb, tboard_cback, chkpt_cback, lr_on_plat_cback],
    max_queue_size=100,
    use_multiprocessing=True,
    workers=35,
)

HBox(children=(IntProgress(value=0, description='Training', max=2, style=ProgressStyle(description_width='init…

HBox(children=(IntProgress(value=0, description='Epoch 0', max=34742, style=ProgressStyle(description_width='i…

In [19]:
# # simple overfit trials

# data = train_gen[0]
# val_data = val_gen[0]
# simple_model.fit(
#     x=data[0], 
#     y=data[1], 
# #     validation_data=val_data, 
#     batch_size=data[0][0].shape[0], 
# #     batch_size=1, 
#     epochs=1, 
#     verbose=2, 
#     shuffle=False,
# )

In [20]:
# # overfitting trials

# data = train_gen[0]
# val_data = val_gen[0]
# model.fit(
#     x=data[0], 
#     y=data[1], 
# #     validation_data=val_data, 
#     batch_size=data[0][0].shape[0], 
# #     batch_size=1, 
#     epochs=200, 
#     verbose=2, 
#     shuffle=False,
# )

In [13]:
# Disabled helpers to center-crop reconstructions for display / metrics.
# def crop_center_mc(img, cropx, cropy=None):
#     # taken from https://stackoverflow.com/questions/39382412/crop-center-portion-of-a-numpy-image/39382475
#     if cropy is None:
#         cropy = cropx
#     y, x = img.shape[1:]
#     startx = x//2 - (cropx//2)
#     starty = y//2 - (cropy//2)
#     return img[:, starty:starty+cropy, startx:startx+cropx]

# NOTE(review): from_pd_out_to_im_crop ignores its `crop` argument and always
# crops to 320; it should call crop_center_mc(res, crop) if re-enabled.
# def from_pd_out_to_im_crop(imgs, crop):
#     res = np.fft.fftshift(imgs, axes=[1, 2])
#     res = crop_center_mc(res, 320)
#     return res

In [14]:
# im_recos = model.predict_on_batch(data[0])
# im_simple = simple_model.predict_on_batch(data[0])

# import matplotlib.pyplot as plt
# import numpy as np
# focus_idx = 10
# fig, axs = plt.subplots(1, 3, sharex=True, sharey=True)
# axs[0].imshow(from_pd_out_to_im_crop(np.squeeze(data[1]), 320)[focus_idx])
# axs[1].imshow(from_pd_out_to_im_crop(np.squeeze(im_recos), 320)[focus_idx])
# axs[2].imshow(from_pd_out_to_im_crop(np.squeeze(im_simple), 320)[focus_idx])

In [15]:
# gt = np.squeeze(data[1])
# pred = np.squeeze(im_recos)
# pred_simple = np.squeeze(im_simple)

# psnr(from_pd_out_to_im_crop(gt, 320), from_pd_out_to_im_crop(pred, 320)), psnr(from_pd_out_to_im_crop(gt, 320), from_pd_out_to_im_crop(pred_simple, 320))

In [16]:
# ssim(from_pd_out_to_im_crop(gt, 320), from_pd_out_to_im_crop(pred, 320)), ssim(from_pd_out_to_im_crop(gt, 320), from_pd_out_to_im_crop(pred_simple, 320))

In [17]:
# psnr(gt, pred), psnr(gt, pred_simple)

In [18]:
# plt.figure()
# _ = plt.hist(
#     [
#         np.abs(np.squeeze(data[1][11])).flatten(),
#         np.abs(np.squeeze(im_recos[11])).flatten()
#     ], 
#     bins=100,
# )