This notebook can be used to reproduce any of the experiments conducted to obtain the empirical scaling law for compressive sensing in the context of accelerated MRI, as described in _Section 4: Empirical scaling laws for compressive sensing_ of the paper **Scaling Laws For Deep Learning Based Image Reconstruction**.

In [None]:
import os
import json
import numpy as np

from fastmri.main_functions_helpers import *

# Expose only GPU 0 to CUDA for this kernel; change the index to run on a different device.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

Specify which experiments to run by indicating training set size and network size. Corresponding hyperparameters are loaded automatically.

For all available combinations of training set size and network size, see the config files in `options/`.

In [None]:
# =====================================================
# User settings -- edit this section before running
# =====================================================
# Set to True to start training or to continue a previous training run
training = True
# Set to True to evaluate the last and the best checkpoint on the validation and test set
testing = True

# One ID per experiment; the ID becomes part of the experiment name
exp_nums = ['001', '002']
# Location of the fastMRI brain directory holding both the training and the validation set.
# NOTE(review): the value carries a "brain_path: " prefix and is handed unchanged to
# create_fastmri_dirs_yaml -- presumably it is written out as a YAML entry; confirm
# before changing its format.
path_to_fastMRI_brain_dataset = "brain_path: ../../../media/ssd1/fastMRIdata/brain"
# Training set size of each experiment
train_sizes = [50, 100]
# Network size of each experiment, given as the number of channels in the first layer
channels = [64, 128]

# =====================================================

# Sanity check: the three lists are matched up position by position,
# so each of them needs exactly one entry per experiment ID.
if not (len(train_sizes) == len(exp_nums) and len(channels) == len(exp_nums)):
    raise ValueError("Specify experiment ID for each experiment") 

Load hyperparameter configurations for each experiment from options/

In [None]:
# Collect one hyperparameter dict per experiment, in the same order as
# train_sizes / channels. Each config is a JSON document stored under options/.
hps = []
for n_train, n_chans in zip(train_sizes, channels):
    config_file = f"options/trainsize{n_train}_channels{n_chans}.txt"

    # Parse the JSON config and keep it for the training/testing loop
    with open(config_file) as config_handle:
        hps.append(json.load(config_handle))

Run training/testing

In [None]:
# Train and/or evaluate every configured experiment, once per repetition.
# Uses `exp_nums`, `hps`, `training`, `testing` and `path_to_fastMRI_brain_dataset`
# defined in the cells above, plus the helpers build_args, cli_main,
# create_fastmri_dirs_yaml and evaluate_reconstructions.
for ee, exp_num in enumerate(exp_nums):

    hp = hps[ee]
    # hp['num_runs'][0] is the number of repetitions of this experiment
    num_runs = list(np.arange(hp['num_runs'][0]))
    for rr in num_runs:
        # Experiment name doubles as the output directory name.
        # The '_l4', '_bs1' and '_lr001' tags are fixed strings, not read from hp.
        exp_name = 'E{}_t{}_l4c{}_bs1_lr001'.format(
            exp_num, hp['num_examples'][0], hp['chans'][0]
        )
        # Repetitions after the first get a 1-based '_run<k>' suffix
        if rr > 0:
            exp_name = exp_name + '_run{}'.format(rr+1)
        # Create the experiment directory if needed (no check-then-create race)
        os.makedirs('./'+exp_name, exist_ok=True)
        create_fastmri_dirs_yaml(path_to_fastMRI_brain_dataset,exp_name)

        ########
        # Training
        ########
        if training:
            print('\n{} - Training\n'.format(exp_name))
            args = build_args(hp,rr)
            cli_main(args)
            print('\n{} - Training finished\n'.format(exp_name))

        ########
        # Testing
        ########
        if testing:
            print('\n{} - Testing\n'.format(exp_name))
            test_modes = ["test_on_val","test_on_test"]

            for test_mode in test_modes:
                # Short tag used in the metrics filename:
                # 'test_on_val' -> 'val', 'test_on_test' -> 'test'
                tm = test_mode[len('test_on_'):]

                for resume_from_which_checkpoint in ["last","best"]:

                    # NOTE(review): resume_from_which_checkpoint is not passed to
                    # build_args, so it is unclear from this cell how "last" and
                    # "best" end up selecting different checkpoints -- confirm
                    # against build_args.
                    args = build_args(hp,rr,test_mode)
                    args.mode = "test"
                    args.logger = False
                    # NOTE(review): both test modes read from multicoil_val;
                    # presumably test_mode selects the subset downstream -- confirm.
                    args.test_path = args.data_path/"multicoil_val"
                    cli_main(args)

                    metrics_stem = './'+exp_name+'/log_files/metrics_'+exp_name+'_{}_{}'.format(tm,resume_from_which_checkpoint)
                    epoch_suffix = ''
                    if resume_from_which_checkpoint == "best":
                        # Tag the file with the epoch found between 'epoch=' and
                        # '-step' in the checkpoint path.
                        ckpt = str(args.resume_from_checkpoint)
                        ind1 = ckpt.find('epoch=')
                        ind2 = ckpt.find('-step', ind1) if ind1 != -1 else -1
                        if ind1 != -1 and ind2 != -1:
                            epoch = ckpt[ind1+len('epoch='):ind2]
                            epoch_suffix = '_'+epoch+'ep'
                        else:
                            # Markers missing: keep the untagged filename rather
                            # than slicing the path at a bogus (-1) index.
                            print('Warning: could not parse epoch from checkpoint path {!r}; '
                                  'metrics file is saved without epoch tag'.format(ckpt))
                    metrics_filename = metrics_stem + epoch_suffix + '.pkl'
                    evaluate_reconstructions(test_mode,metrics_filename)

            print('\n{} - Testing finished\n'.format(exp_name))