In [31]:
import sigpy as sp
import os
import torch
import matplotlib.pyplot as plt
import h5py
import numpy as np
import glob
from tqdm import tqdm
import json
import sys
os.environ["OMP_NUM_THREADS"] = "1"
os.environ['TOOLBOX_PATH'] = '/home/asad/bart'
sys.path.append('/home/asad/bart/python')
from bart import bart
from multiprocessing import Pool
import random
import zipfile

# Preprocess

In [None]:
def create_masks(R, delta_R, acs_lines, LENGTH):
    """Draw a random line-sampling mask and a further-undersampled copy of it.

    A square ``LENGTH x LENGTH`` mask is built whose columns are switched on
    for a centred band of ``acs_lines`` columns plus randomly chosen columns
    outside that band, so that ``floor(LENGTH / R)`` columns are sampled in
    total. The second mask is a copy with ``delta_R`` of the randomly chosen
    columns switched off again.

    Both masks are then passed through ``sp.resize`` to ``[384, 320]`` and
    their first and last 32 rows are overwritten with rows 32..63
    (presumably to fill the rows that resizing a 320-row mask leaves empty
    -- TODO confirm for other LENGTH values).

    Args:
        R: acceleration factor of the first mask.
        delta_R: number of random columns removed to form the second mask.
        acs_lines: width of the always-sampled centre band, in columns.
        LENGTH: side length of the square mask before resizing.

    Returns:
        Tuple ``(mask, further_mask)`` of float arrays of shape (384, 320).
    """
    # Columns left to draw at random once the centre band is accounted for.
    n_random = int(np.floor(LENGTH / R) - acs_lines)

    acs_cols = np.arange((LENGTH - acs_lines) // 2, (LENGTH + acs_lines) // 2)
    candidate_cols = np.setdiff1d(np.arange(LENGTH), acs_cols)
    picked_cols = np.random.choice(candidate_cols, size=n_random, replace=False)

    base = np.zeros((LENGTH, LENGTH))
    base[:, acs_cols] = 1.
    base[:, picked_cols] = 1.

    # Shuffle in place, then drop the first delta_R of the random columns.
    random.shuffle(picked_cols)
    reduced = base.copy()
    reduced[:, picked_cols[:delta_R]] = 0.

    def _to_output_grid(square_mask):
        # Resize to the fixed output grid and replicate rows 32..63 into the
        # top and bottom 32-row borders.
        resized = sp.resize(square_mask, [384, 320])
        resized[0:32] = resized[32:64]
        resized[352:384] = resized[32:64]
        return resized

    return _to_output_grid(base), _to_output_grid(reduced)

def task(i):
    """Preprocess sample ``i`` and write its ground truth, maps and masks to disk.

    Reads the module-level ``ksp_list``, ``gt_list`` and ``save_root``.
    Steps, in order:

    1. Load the k-space stored under key ``'ksp'`` of ``ksp_list[i]``.
    2. Run BART ``cc`` on a 33-wide resize of it and, when at least 4 coils
       are present, ``ccapply -p 4`` on the full k-space.
    3. Estimate sensitivity maps with ``ecalib -m 1 -c0`` and reconstruct
       the reference image with ``pics -S -i 30``.
    4. Divide the reference by ``norm_const_acs_99`` from ``gt_list[i]``.
    5. Draw four mask pairs with ``create_masks`` (R = 2, 4, 6, 8).
    6. Save ``gt.npy`` (real/imag stacked on the first axis), ``maps.npy``
       and the eight mask files under ``<save_root>/<i>/``.

    Args:
        i: index into ``ksp_list`` / ``gt_list``; also the output folder name.

    Returns:
        None. All results are written to disk.
    """
    # Workers created by fork inherit the parent's NumPy global RNG state,
    # so without re-seeding several workers draw exactly the same masks.
    # Seeding both RNGs used by create_masks from the sample index gives
    # every sample its own masks and makes them reproducible.
    np.random.seed(i)
    random.seed(i)

    ksp_white = np.array(torch.load(ksp_list[i])['ksp']).transpose(1, 2, 0)
    ref_white = sp.resize(ksp_white, [ksp_white.shape[0], 33, ksp_white.shape[2]])

    # Insert a singleton third axis before the coil axis for the BART calls.
    ksp_white = ksp_white[:, :, None, ...]
    ref_white = ref_white[:, :, None, ...]

    cc_mat = bart(1, 'cc', ref_white)
    if ksp_white.shape[-1] >= 4:
        ksp_white = bart(1, 'ccapply -p 4', ksp_white, cc_mat)
    s_maps = bart(1, 'ecalib -m 1 -c0', ksp_white)

    gt = bart(1, 'pics -S -i 30', ksp_white, s_maps)
    gt = gt / float(torch.load(gt_list[i])['norm_const_acs_99'])
    gt = gt[None]

    maps = s_maps.squeeze().transpose(2, 0, 1)
    if ksp_white.shape[-1] < 4:
        # Pad with all-zero maps up to 4 coils. The slice is sized from the
        # data so that any coil count below 4 works, not only exactly 2.
        maps_zeros = np.zeros((4,) + maps.shape[1:], dtype=np.complex64)
        maps_zeros[:maps.shape[0]] = maps
        maps = maps_zeros

    # (R of the base mask, number of random lines removed for the delta mask).
    # The delta mask is saved under the next integer (mask_2 / mask_delta_3, ...).
    masks = {}
    for accel, n_removed in ((2, 54), (4, 16), (6, 8), (8, 5)):
        base_mask, delta_mask = create_masks(R=accel, delta_R=n_removed, acs_lines=20, LENGTH=320)
        masks["mask_" + str(accel)] = base_mask
        masks["mask_delta_" + str(accel + 1)] = delta_mask

    normalised_slices = np.stack((gt.real, gt.imag), axis=1)

    cur_path = os.path.join(save_root, str(i))
    # exist_ok avoids the check-then-create race between worker processes.
    os.makedirs(cur_path, exist_ok=True)

    np.save(os.path.join(cur_path, "gt.npy"), normalised_slices[0])
    np.save(os.path.join(cur_path, "maps.npy"), maps)
    for mask_name, mask_array in masks.items():
        np.save(os.path.join(cur_path, mask_name + ".npy"), mask_array)

In [None]:
# Configuration shared by the preprocessing cells.
LENGTH = 384
device           = sp.cpu_device
n_proc           = 30 # number of cpu cores to use, when possible

train = "brain"
# glob.glob returns paths in arbitrary order. task(i) pairs ksp_list[i] with
# gt_list[i], so both lists are sorted to keep the pairing and the
# index -> file mapping stable across runs.
# NOTE(review): assumes ksp/ and gt/ use matching file names -- confirm.
ksp_list = sorted(glob.glob("/csiNAS/mridata/fastmri_" + train + "_white/ksp/*.pt"))
gt_list = sorted(glob.glob("/csiNAS/mridata/fastmri_" + train + "_white/gt/*.pt"))
class_dict = {}
save_root = "/home/asad/ambient-diffusion-mri/data/fastMRI/numpy/ksp_" + train + "MRI_" + str(LENGTH) + "/"
count = 0

In [None]:
# task(i) indexes ksp_list[i] and gt_list[i], so never request more indices
# than there are files; the previous cap of 10000 samples is kept.
n_samples = min(10000, len(ksp_list), len(gt_list))

with Pool(n_proc) as p:
    # task returns nothing; iterating only drives the pool and the progress bar.
    for _task_result in tqdm(p.imap(task, range(n_samples)), total=n_samples):
        pass

# NOTE: class_dict is not filled anywhere in the visible cells, so "labels"
# is written with whatever it currently holds.
json_output = {"labels": [[k, v] for k, v in class_dict.items()]}

j = json.dumps(json_output, indent=4)
# Make sure the output folder exists even when no sample was processed.
os.makedirs(save_root, exist_ok=True)
with open(os.path.join(save_root, "dataset.json"), "w") as f:
    print(j, file=f)

In [None]:
# Single-sample run of the BART steps used in task(), for sample index 651,
# leaving the intermediate arrays available for inspection.
ksp_white = np.array(torch.load(ksp_list[651])['ksp']).transpose(1, 2, 0)
ref_white = sp.resize(ksp_white, [ksp_white.shape[0], 33, ksp_white.shape[2]])

# Insert a singleton third axis before the coil axis for the BART calls.
ksp_white = ksp_white[:, :, None, ...]
ref_white = ref_white[:, :, None, ...]

cc_mat = bart(1, 'cc', ref_white)

if ksp_white.shape[-1] >= 4:
    ksp_white = bart(1, 'ccapply -p 4', ksp_white, cc_mat)

s_maps = bart(1, 'ecalib -m 1 -c0', ksp_white)

gt = bart(1, 'pics -S -i 30', ksp_white, s_maps)
maps = s_maps.squeeze().transpose(2, 0, 1)
if ksp_white.shape[-1] < 4:
    # Pad with all-zero maps up to 4 coils. The slice is sized from the data
    # so that any coil count below 4 works, not only exactly 2.
    maps_zeros = np.zeros((4,) + maps.shape[1:], dtype=np.complex64)
    maps_zeros[:maps.shape[0]] = maps
    maps = maps_zeros

# BART CGSENSE

In [27]:
# Single-sample comparison: adjoint, BART pics (-l1 and -l2) and the stored
# ambient reconstruction, all for validation sample 0 with mask_8.
file = "/csiNAS/asad/data/brain_fastMRI/val_samples_ambient/sample_0.pt"
data = torch.load(file)

ambient_recon = np.array(torch.load("/home/asad/ambient-diffusion-mri/results/One-Step/brainMRI_384_R=4-4/0/000000.pt")['recon'])
mask = np.array(data['mask_8'])
maps = np.array(data['s_map'])
ksp = np.array(data['ksp'])
img = np.array(data['gt'])

# Adjoint reconstruction: inverse FFT of the masked k-space, combined with
# the conjugate maps by summing over the first (coil) axis.
sampled_ksp = mask*ksp
coil_imgs = sp.ifft(sampled_ksp, axes=(1, 2))
img_out = np.sum(np.conj(maps)*coil_imgs, axis=0)

# Move the coil axis last and insert a singleton axis in front of it for BART.
maps_bart = maps.transpose(1, 2, 0)[...,None,:]
coil_ksp_bart = sampled_ksp.transpose(1, 2, 0)[...,None,:]
mask_bart = mask[0]

sampled_ksp_bart = mask_bart[...,None,None]*coil_ksp_bart
recon_bart_l1 = bart(1, 'pics -l1 -r 0.001 -i 100 -S', sampled_ksp_bart, maps_bart)
# The -l2 reconstruction is needed for nrmse_l2 below and by the comparison
# figure; with this call commented out both fail with a NameError.
recon_bart_l2 = bart(1, 'pics -l2 -r 0.001 -i 100 -S', sampled_ksp_bart, maps_bart)

# NOTE: ssim, psnr and nrmse are defined in the "Ambient Metrics" section
# further down; that cell has to be executed before these three lines.
print(ssim(abs(img), abs(recon_bart_l1), data_range=abs(img).max() - abs(img).min()))
print(psnr(gt=abs(img), est=abs(recon_bart_l1), max_pixel=np.amax(abs(img))))
print(nrmse(abs(img), abs(recon_bart_l1)))

# Complex-valued NRMSE of each reconstruction, rounded for the figure titles.
nrmse_adj = round(np.linalg.norm(img_out - img) / np.linalg.norm(img), 2)
nrmse_l1 = round(np.linalg.norm(recon_bart_l1 - img) / np.linalg.norm(img), 2)
nrmse_l2 = round(np.linalg.norm(recon_bart_l2 - img) / np.linalg.norm(img), 2)
nrmse_recon = round(np.linalg.norm(ambient_recon - img) / np.linalg.norm(img), 2)

In [110]:
# One row per panel: (subplot position, title, image, vmax).
# Positions 1-6 fill the top row of the 2x6 grid; positions 9-12 place each
# difference map directly below its reconstruction (3-6). The mask is shown
# unflipped and with automatic colour limits, as before.
comparison_panels = [
    (1, 'Original', np.flipud(np.abs(img)), 1),
    (2, 'Mask, R=8', mask[0], None),
    (3, 'Adjoint, NRMSE: ' + str(nrmse_adj), np.flipud(np.abs(img_out)), 1),
    (9, 'Diff', np.flipud(np.abs(img_out - img)), 0.1),
    (4, 'L1 FISTA, NRMSE: ' + str(nrmse_l1), np.flipud(np.abs(recon_bart_l1)), 1),
    (10, 'Diff', np.flipud(np.abs(recon_bart_l1 - img)), 0.1),
    (5, 'L2, NRMSE: ' + str(nrmse_l2), np.flipud(np.abs(recon_bart_l2)), 1),
    (11, 'Diff', np.flipud(np.abs(recon_bart_l2 - img)), 0.1),
    (6, 'Ambient, NRMSE: ' + str(nrmse_recon), np.flipud(np.abs(ambient_recon)), 1),
    (12, 'Diff', np.flipud(np.abs(ambient_recon - img)), 0.1),
]

plt.figure(figsize=(16, 6))

for panel_position, panel_title, panel_image, panel_vmax in comparison_panels:
    plt.subplot(2, 6, panel_position)
    plt.title(panel_title)
    plt.imshow(panel_image, cmap='gray', vmax=panel_vmax)
    plt.axis('off')

In [109]:
# SSIM / NRMSE / PSNR of the BART `pics -l1` reconstruction over validation
# samples 0..99, using each sample's mask_8.
# NOTE: ssim, psnr and nrmse are defined in the "Ambient Metrics" section
# further down; that cell has to be executed before this one.
ssim_list = []
nrmse_list = []
psnr_list = []

for i in range(100):
    # Loop-local names carry a "sample_" prefix so this loop does not
    # overwrite the file / data / ksp / mask / maps variables of the
    # single-sample comparison, which the comparison figure still reads.
    sample_file = "/csiNAS/asad/data/brain_fastMRI/val_samples_ambient/sample_" + str(i) + ".pt"
    sample_data = torch.load(sample_file)

    sample_ksp = np.array(sample_data['ksp']) * np.array(sample_data['mask_8'])
    sample_maps = np.array(sample_data['s_map'])
    gt_img = np.array(sample_data['gt'])[None,None]

    # Move the first axis last and insert a singleton axis in front of it for BART.
    sample_ksp = sample_ksp.transpose(1, 2, 0)[:,:,None,...]
    sample_maps = sample_maps.transpose(1, 2, 0)[:,:,None,...]
    cplx_recon = bart(1, 'pics -l1 -r 0.001 -i 100 -S', sample_ksp, sample_maps)[None,None]

    # All three metrics are computed on the same pair of 2-D magnitude images.
    gt_mag = abs(gt_img[0,0])
    recon_mag = abs(cplx_recon[0,0])

    ssim_list.append(ssim(gt_mag, recon_mag, data_range=gt_mag.max() - gt_mag.min()))
    psnr_list.append(psnr(gt=gt_mag, est=recon_mag, max_pixel=np.amax(gt_mag)))
    nrmse_list.append(nrmse(gt_mag, recon_mag))

ssim_list = np.array(ssim_list)
nrmse_list = np.array(nrmse_list)
psnr_list = np.array(psnr_list)

print("Average SSIM: " + str(np.mean(ssim_list)))
print("Average NRMSE: " + str(np.mean(nrmse_list)))
print("Average PSNR: " + str(np.mean(psnr_list)))

print("Dev SSIM: " + str(np.std(ssim_list)))
print("Dev NRMSE: " + str(np.std(nrmse_list)))
print("Dev PSNR: " + str(np.std(psnr_list)))

# Ambient Metrics

In [37]:
from skimage.metrics import structural_similarity as ssim

def nrmse(x, y):
    """Return the norm of ``x - y`` divided by the norm of ``x``.

    ``x`` acts as the reference. There is no guard against a zero-norm
    reference.
    """
    residual_norm = np.linalg.norm(x - y)
    reference_norm = np.linalg.norm(x)
    return residual_norm / reference_norm

def psnr(gt, est, max_pixel):
    """Peak signal-to-noise ratio of ``est`` against ``gt``.

    Args:
        gt: reference array.
        est: estimate, broadcastable against ``gt``.
        max_pixel: peak value used in the numerator of the ratio.

    Returns:
        ``20 * log10(max_pixel / sqrt(mse))``, or the fixed value 100 when
        the mean squared error is exactly zero.
    """
    mse = np.mean((gt - est) ** 2)
    if mse != 0:
        return 20 * np.log10(max_pixel / np.sqrt(mse))
    # Zero error: the ratio is undefined, so a fixed sentinel is returned.
    return 100

In [42]:
# SSIM / NRMSE / PSNR of the stored ambient reconstructions (R=6-9 results
# folder) against the ground truth of validation samples 0..99.
ssim_list = []
nrmse_list = []
psnr_list = []

for i in range(100):
    gt_img = np.array(torch.load("/csiNAS/asad/data/brain_fastMRI/val_samples_ambient/sample_" + str(i) + ".pt")['gt'])[None,None]

    # Result files are named with the index zero-padded to six digits.
    recon_path = "/home/asad/ambient-diffusion-mri/results/brainMRI_384_R=6-9/0/" + str(i).zfill(6) + ".pt"
    # Only cplx_recon is assigned: also rebinding `ambient_recon` here would
    # replace the 2-D image shown in the comparison figure with a 4-D array
    # from a different results folder.
    cplx_recon = np.array(torch.load(recon_path)['recon'])[None,None]

    # All three metrics are computed on the same pair of 2-D magnitude images.
    gt_mag = abs(gt_img[0,0])
    recon_mag = abs(cplx_recon[0,0])

    ssim_list.append(ssim(gt_mag, recon_mag, data_range=gt_mag.max() - gt_mag.min()))
    psnr_list.append(psnr(gt=gt_mag, est=recon_mag, max_pixel=np.amax(gt_mag)))
    nrmse_list.append(nrmse(gt_mag, recon_mag))

ssim_list = np.array(ssim_list)
nrmse_list = np.array(nrmse_list)
psnr_list = np.array(psnr_list)

In [44]:
# Mean of each metric over the evaluated samples.
for metric_label, metric_values in (
    ("Average SSIM: ", ssim_list),
    ("Average NRMSE: ", nrmse_list),
    ("Average PSNR: ", psnr_list),
):
    print(metric_label + str(np.mean(metric_values)))

Average SSIM: 0.8877061980631862
Average NRMSE: 0.14436772479382706
Average PSNR: 30.94051787476672


In [25]:
# Standard deviation of each metric over the evaluated samples.
for metric_label, metric_values in (
    ("Dev SSIM: ", ssim_list),
    ("Dev NRMSE: ", nrmse_list),
    ("Dev PSNR: ", psnr_list),
):
    print(metric_label + str(np.std(metric_values)))

Dev SSIM: 0.02110860561367658
Dev NRMSE: 0.018735356852049117
Dev PSNR: 1.3976956427806702


# Hyperparameter Search (regularization weight `-r` for BART `pics -l1`)

In [None]:
# Sweep the `-r` regularization weight of BART `pics -l1` on the single
# sample prepared in the BART comparison cell (sampled_ksp_bart, maps_bart,
# img) and record the NRMSE for every value.
# NOTE: the name `nrmse` is bound to a list here (the plot cell reads it),
# which shadows the nrmse() helper from the "Ambient Metrics" section;
# re-run that cell before calling nrmse() again.
nrmse = []
lambdas=[0.000001, 0.000004, 0.000007, 0.00001, 0.00004, 0.00007, 0.0001, 0.0004, 0.0007, 0.001, 0.004, 0.007, 0.01, 0.04, 0.07, 0.1]

for lam in lambdas:
    # A separate name keeps recon_bart_l1 (shown in the comparison figure)
    # from being overwritten by the last sweep value.
    recon_sweep = bart(1, 'pics -l1 -r ' + str(lam) + ' -i 100 -d5 -S', sampled_ksp_bart, maps_bart)
    # Compare against the whole 2-D ground truth: img[0] would select only
    # its first row and broadcast that row against the reconstruction.
    nrmse.append(np.linalg.norm(img - recon_sweep) / np.linalg.norm(img))

In [None]:
# NRMSE per regularization weight; the weights are plotted as category
# labels (strings), so they are evenly spaced along the x axis.
fig, ax = plt.subplots(figsize=(20, 3))
ax.plot(list(map(str, lambdas)), nrmse)
ax.set_xlabel("lambda")
ax.set_ylabel("nrmse")