Here, I apply SPARK on top of residual RAKI (rRAKI) reconstructions for a range of acceleration rates and ACS sizes, to see which combinations of ACS size and acceleration rate are most worth showing.

In [1]:
#rraki imports
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import scipy.io as sio
import numpy as np
import numpy.matlib
import time
import os

from utils import signalprocessing as sig

#additional spark imports which may be needed
import importlib 
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import warnings
import numpy.linalg as la
import cupy as cp
import matplotlib.pyplot as plt
from bart import bart
from utils import cfl
from utils import signalprocessing as sig
from utils import models
from utils import iterative


Instructions for updating:
non-resource variables are not supported in the long term


### defining rraki helper functions

In [2]:
def weight_variable(shape,vari_name):
    """Return a trainable float32 tf.Variable named `vari_name` with the given shape.

    Values are drawn from a truncated normal distribution with stddev 0.1.
    """
    draws = tf.truncated_normal(shape, stddev=0.1, dtype=tf.float32)
    weights = tf.Variable(draws, name=vari_name)
    return weights

def bias_variable(shape):
    """Return a trainable float32 tf.Variable of the given shape, every entry set to 0.1."""
    return tf.Variable(tf.constant(0.1, shape=shape, dtype=tf.float32))

def conv2d_same(x,W):
    """Stride-1 2-D convolution of x with W using 'SAME' padding."""
    unit_strides = [1, 1, 1, 1]
    return tf.nn.conv2d(x, W, strides=unit_strides, padding='SAME')

def conv2d_dilate_same(x, W,dilate_rate):
    """'SAME'-padded convolution of x with W, dilated by dilate_rate along the second spatial axis only."""
    rates = [1, dilate_rate]
    return tf.nn.convolution(x, W, padding='SAME', dilation_rate=rates)

def conv2d(x, W):
    """Stride-1 2-D convolution of x with W using 'VALID' (no) padding."""
    unit_strides = [1, 1, 1, 1]
    return tf.nn.conv2d(x, W, strides=unit_strides, padding='VALID')

def conv2d_dilate(x, W,dilate_rate):
    """'VALID' (unpadded) convolution of x with W, dilated by dilate_rate along the second spatial axis only."""
    rates = [1, dilate_rate]
    return tf.nn.convolution(x, W, padding='VALID', dilation_rate=rates)

def learning_residual_raki(ACS_input,target_input,accrate_input,sess,\
    ACS_dim_X,ACS_dim_Y,ACS_dim_Z,target_dim_X,target_dim_Y,target_dim_Z,\
    target,target_x_start,target_x_end,target_y_start,target_y_end,\
    ACS):
    '''
    Train the residual RAKI network for one output channel and return its weights.

    Builds, in the current default graph, a three-layer dilated CNN (ReLU after the
    first two layers) and a single-layer linear dilated convolution. Both are trained
    jointly with Adam on the loss
        ||target - linear|| + ||target - linear - cnn||.
    Kernel sizes, channel counts, LearningRate and MaxIteration are read from the
    notebook-level globals of those names, not from the arguments.

    Inputs:
        ACS_input, target_input,
        target_x_start, target_x_end,
        target_y_start, target_y_end:   unused; kept for call compatibility
        accrate_input:                  dilation rate along the second spatial axis
        sess:                           tf.Session used to initialise, train and read back variables
        ACS_dim_X/Y/Z:                  placeholder shape [1, X, Y, Z] of the ACS input
        target_dim_X/Y/Z:               placeholder shape [1, X, Y, Z] of the target
        target:                         rank-4 array fed as the target; its last axis sets the output maps
        ACS:                            array fed as the network input
    Outputs:
        [W1, W2, W3, W_linear, error]:  learned kernels (ndarrays) and the final loss value
    '''
    input_ACS    = tf.placeholder(tf.float32, [1, ACS_dim_X,ACS_dim_Y,ACS_dim_Z])
    input_Target = tf.placeholder(tf.float32, [1, target_dim_X,target_dim_Y,target_dim_Z])

    Input = tf.reshape(input_ACS, [1, ACS_dim_X, ACS_dim_Y, ACS_dim_Z])

    # Only the last (channel) axis of the target is used: it sizes both output layers.
    [target_dim0,target_dim1,target_dim2,target_dim3] = np.shape(target)

    # Non-linear (residual) branch: three dilated VALID convolutions.
    W_conv1 = weight_variable([kernel_x_1, kernel_y_1, ACS_dim_Z, layer1_channels],'W1')
    h_conv1 = tf.nn.relu(conv2d_dilate(Input, W_conv1,accrate_input))

    W_conv2 = weight_variable([kernel_x_2, kernel_y_2, layer1_channels, layer2_channels],'W2')
    h_conv2 = tf.nn.relu(conv2d_dilate(h_conv1, W_conv2,accrate_input))

    W_conv3 = weight_variable([kernel_last_x, kernel_last_y, layer2_channels, target_dim3],'W3')
    h_conv3 = conv2d_dilate(h_conv2, W_conv3,accrate_input)

    # Linear branch: one dilated VALID convolution on the same input.
    W_conv_linear = weight_variable([kernel_x_linear,kernel_y_linear,ACS_dim_Z,target_dim3],'W_lin')
    h_linear = conv2d_dilate(Input,W_conv_linear,accrate_input)

    # Crop the linear output to the residual branch's output size so both can be
    # compared with the target. Taking the leading x_length x y_length window works
    # for even and odd sizes alike (an even/odd-specific crop dropped a row for odd x).
    x_length = int(h_conv3.shape[1])
    y_length = int(h_conv3.shape[2])
    h_linear = h_linear[:, :x_length, :y_length, :]

    # The linear branch alone should fit the target, and linear + residual should fit it together.
    error_norm = tf.norm(input_Target - h_linear) + tf.norm(input_Target - h_linear - h_conv3)
    train_step = tf.train.AdamOptimizer(LearningRate).minimize(error_norm)

    # Very old TensorFlow (< 0.12) only has initialize_all_variables.
    if int((tf.__version__).split('.')[1]) < 12 and int((tf.__version__).split('.')[0]) < 1:
        init = tf.initialize_all_variables()
    else:
        init = tf.global_variables_initializer()
    sess.run(init)

    feed = {input_ACS: ACS, input_Target: target}
    for i in range(MaxIteration):
        sess.run(train_step, feed_dict=feed)
        if i % 100 == 0:
            error_now = sess.run(error_norm, feed_dict=feed)
            print('The',i,'th iteration gives an error',error_now)

    error = sess.run(error_norm, feed_dict=feed)
    return [sess.run(W_conv1),sess.run(W_conv2),sess.run(W_conv3),sess.run(W_conv_linear),error]
    
def cnn_linear(input_kspace,w_linear,acc_rate,sess):
    """Apply the learned linear kernel to input_kspace as a dilated VALID convolution.

    The op is evaluated in `sess` and returned as an array. Each call adds a new op
    to the current default graph.
    """
    linear_op = conv2d_dilate(input_kspace, w_linear, acc_rate)
    return sess.run(linear_op)

def cnn_3layer(input_kspace,w1,w2,w3,acc_rate,sess):
    """Apply the learned three-layer dilated CNN to input_kspace and return the result as an array.

    ReLU follows the first two layers; the last layer is linear. Each call adds new
    ops to the current default graph.
    """
    features = input_kspace
    for hidden_kernel in (w1, w2):
        features = tf.nn.relu(conv2d_dilate(features, hidden_kernel, acc_rate))
    output = conv2d_dilate(features, w3, acc_rate)
    return sess.run(output)

### loading k-space and defining fft operators 

In [3]:
def fft2c_raki(x):
    """Forward transform used by rRAKI: sig.fft over dimension 0, then dimension 1.

    Assumes sig.fft's second argument is the axis.
    """
    return sig.fft(sig.fft(x, 0), 1)


def ifft2c_raki(x):
    """Inverse of fft2c_raki: sig.ifft over dimension 0, then dimension 1."""
    return sig.ifft(sig.ifft(x, 0), 1)


# Multi-coil ground-truth images (stored under key 'IMG'), taken to k-space.
image_coils_truth = sio.loadmat('data/img_grappa_32chan.mat')['IMG']

kspace_truth_raki = fft2c_raki(image_coils_truth)
M, N, C = kspace_truth_raki.shape

### setting non-changing residual-raki parameters

In [4]:
# Fraction of GPU memory each TensorFlow session may claim.
GPU_FRAC = 1/4

# First k-space axis is fully sampled (Rx = 1) and the ACS block spans all M rows.
Rx   = 1
acsx = M

#### Linear Network Parameters ####
kernel_x_linear, kernel_y_linear = 5, 2

#### RAKI Network Parameters ####
kernel_x_1,    kernel_y_1    = 5, 2   # layer 1 kernel
kernel_x_2,    kernel_y_2    = 1, 1   # layer 2 kernel
kernel_last_x, kernel_last_y = 3, 2   # output layer kernel

layer1_channels = 32   # feature maps after layer 1
layer2_channels = 8    # feature maps after layer 2

MaxIteration = 1000    # Adam steps per channel
LearningRate = 3e-3

### setting ablation parameters

In [5]:
all_Ry   = [5, 6]
all_acsy = [20, 24, 30, 36, 40]

# Every (acceleration, ACS width) combination; Ry varies slowest.
all_parameters = [{'Ry': ry, 'acsy': acsy} for ry in all_Ry for acsy in all_acsy]

### designing residual raki-reconstruction function

In [6]:
def residual_raki(MaxIteration,LearningRate,Rx,Ry,acsx,acsy,GPU_FRAC,\
    kspace_truth_raki=kspace_truth_raki,kernel_x_linear=kernel_x_linear,\
    kernel_y_linear=kernel_y_linear,kernel_x_1=kernel_x_1,kernel_y_1=kernel_y_1,kernel_x_2=kernel_x_2,\
    kernel_y_2=kernel_y_2,kernel_last_x=kernel_last_x,kernel_last_y=kernel_last_y,
    layer1_channels=layer1_channels,layer2_channels=layer2_channels):
    '''
    Simulate an undersampled acquisition, train residual RAKI on its ACS block and
    reconstruct the missing k-space lines.

    Inputs:
        MaxIteration, LearningRate: accepted for interface compatibility.
            NOTE(review): learning_residual_raki reads the notebook globals of the same
            names, so these arguments currently have no effect.
        Rx, Ry:      undersampling factors along the first / second k-space axis
        acsx, acsy:  size of the fully sampled central ACS block along each axis
        GPU_FRAC:    per_process_gpu_memory_fraction for every TF session
        kspace_truth_raki: fully sampled k-space, M x N x C (M, N, C are notebook globals)
        kernel_* / layer*_channels: used here to size the weight buffers and target offsets.
            NOTE(review): learning_residual_raki builds its layers from the globals, so
            these must agree with them.
    Outputs (normalised so the largest acquired magnitude is 0.015; the reconstructions
    are cropped to start at the first sampled column):
        [rRAKI recon, rRAKI recon with measured ACS re-inserted,
         linear-only recon, linear-only recon with measured ACS re-inserted,
         residual (CNN-only) estimate, normalised full ground truth]
    '''
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # 1. Simulate the acquisition: uniform Rx x Ry sampling plus a fully
    #    sampled acsx x acsy block at the centre of k-space.
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    acsregionX = np.arange(M//2 - acsx // 2,M//2 + acsx//2)
    acsregionY = np.arange(N//2 - acsy // 2,N//2 + acsy//2)

    kspace_raki_undersampled_withacs = np.zeros((M,N,C),dtype = complex)
    kspace_raki_undersampled_withacs[::Rx,::Ry,:] = kspace_truth_raki[::Rx,::Ry,:]
    kspace_raki_undersampled_withacs[acsregionX[0]:acsregionX[acsx-1]+1,acsregionY[0]:acsregionY[acsy-1]+1,:]\
        = kspace_truth_raki[acsregionX[0]:acsregionX[acsx-1]+1,acsregionY[0]:acsregionY[acsy-1]+1,:]

    # Scale so the largest acquired magnitude is 0.015; the ground truth gets the
    # same factor so all returned arrays share one scale.
    kspace = np.copy(kspace_raki_undersampled_withacs)
    normalize = 0.015/np.max(abs(kspace[:]))
    kspace = np.multiply(kspace,normalize)
    cur_kspace_truth_raki = np.multiply(kspace_truth_raki,normalize)

    [m1,n1,no_ch] = np.shape(kspace)
    kspace_all = np.copy(kspace)

    # Drop leading all-zero columns (axis 1) so the data starts at the first sampled column.
    kspace = np.copy(kspace_all)
    mask = np.squeeze(np.matlib.sum(np.matlib.sum(np.abs(kspace),0),1))>0
    picks = np.where(mask == 1)
    kspace = kspace[:,np.int32(picks[0][0]):n1+1,:]
    kspace_all = kspace_all[:,np.int32(picks[0][0]):n1+1,:]

    # Untouched copy of the acquired data, used at the end to re-insert the ACS columns.
    kspace_NEVER_TOUCH = np.copy(kspace_all)

    # Sampled columns in the cropped frame; unit spacing between consecutive sampled
    # columns marks the contiguous ACS block.
    mask = np.squeeze(np.matlib.sum(np.matlib.sum(np.abs(kspace),0),1))>0
    picks = np.where(mask == 1)
    d_picks = np.diff(picks,1)
    indic = np.where(d_picks == 1)

    # Rows that contain any data bound the ACS block along axis 0.
    mask_x = np.squeeze(np.matlib.sum(np.matlib.sum(np.abs(kspace),2),1))>0
    picks_x = np.where(mask_x == 1)
    x_start = picks_x[0][0]
    x_end = picks_x[0][-1]

    print('ACS signal found in the input data')
    indic = indic[1][:]
    # center_start and center_end are both INCLUSIVE column indices of the ACS block.
    center_start = picks[0][indic[0]]
    center_end = picks[0][indic[-1]+1]
    ACS = kspace[x_start:x_end+1,center_start:center_end+1,:]
    [ACS_dim_X, ACS_dim_Y, ACS_dim_Z] = np.shape(ACS)

    # Stack real parts then imaginary parts along the channel axis (no_ch -> 2*no_ch).
    ACS_re = np.zeros([ACS_dim_X,ACS_dim_Y,ACS_dim_Z*2])
    ACS_re[:,:,0:no_ch] = np.real(ACS)
    ACS_re[:,:,no_ch:no_ch*2] = np.imag(ACS)

    # The spacing between the first two sampled columns is taken as the acceleration rate.
    acc_rate = d_picks[0][0]
    no_channels = ACS_dim_Z*2

    [ACS_dim_X, ACS_dim_Y, ACS_dim_Z] = np.shape(ACS_re)
    ACS = np.reshape(ACS_re, [1,ACS_dim_X, ACS_dim_Y, ACS_dim_Z])
    ACS = np.float32(ACS)

    # Learned kernels for every output channel (last axis = output channel).
    w_linear_all = \
        np.zeros([kernel_x_linear, kernel_y_linear, no_channels, acc_rate - 1, no_channels],dtype=np.float32)
    w1_all = np.zeros([kernel_x_1, kernel_y_1, no_channels, layer1_channels, no_channels],dtype=np.float32)
    w2_all = np.zeros([kernel_x_2, kernel_y_2, layer1_channels,layer2_channels,no_channels],dtype=np.float32)
    w3_all = np.zeros([kernel_last_x, kernel_last_y, layer2_channels,acc_rate - 1, no_channels],dtype=np.float32)

    # Extent of the trainable target region inside the ACS block, as set by the
    # kernel sizes of the VALID convolutions.
    target_x_start = np.int32(np.ceil(kernel_x_1/2) + np.floor(kernel_x_2/2) + np.floor(kernel_last_x/2) -1)
    target_x_end = np.int32(ACS_dim_X - target_x_start -1)

    target_y_start = np.int32((np.ceil(kernel_y_1/2)-1) + (np.ceil(kernel_y_2/2)-1) + (np.ceil(kernel_last_y/2)-1)) * acc_rate
    target_y_end = ACS_dim_Y  - np.int32((np.floor(kernel_y_1/2) + np.floor(kernel_y_2/2) + np.floor(kernel_last_y/2))) * acc_rate -1

    target_dim_X = target_x_end - target_x_start + 1
    target_dim_Y = target_y_end - target_y_start + 1
    target_dim_Z = acc_rate - 1

    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # 2. Training: one network per real/imag channel.
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    print('go!')
    time_Learn_start = time.time()

    config = tf.ConfigProto()
    config.gpu_options.per_process_gpu_memory_fraction = GPU_FRAC

    for ind_c in range(ACS_dim_Z):
        # Targets: ACS values of channel ind_c at each skipped offset ind_acc+1
        # relative to the sampling grid.
        target = np.zeros([1,target_dim_X,target_dim_Y,target_dim_Z])
        print('learning channel #',ind_c+1)
        time_channel_start = time.time()

        for ind_acc in range(acc_rate-1):
            target_y_start = np.int32((np.ceil(kernel_y_1/2)-1) + (np.ceil(kernel_y_2/2)-1) + (np.ceil(kernel_last_y/2)-1)) * acc_rate + ind_acc + 1
            target_y_end = ACS_dim_Y  - np.int32((np.floor(kernel_y_1/2) + (np.floor(kernel_y_2/2)) + np.floor(kernel_last_y/2))) * acc_rate + ind_acc
            target[0,:,:,ind_acc] = ACS[0,target_x_start:target_x_end + 1, target_y_start:target_y_end +1,ind_c]

        # Fresh session per channel, closed even on error; the graph is reset so
        # ops and variables do not accumulate across channels.
        with tf.Session(config=config) as sess:
            [w1,w2,w3,w_linear,error]=learning_residual_raki(ACS,target,acc_rate,sess,\
                ACS_dim_X,ACS_dim_Y,ACS_dim_Z,target_dim_X,target_dim_Y,target_dim_Z,target,\
                target_x_start,target_x_end,target_y_start,target_y_end,\
                ACS)
        tf.reset_default_graph()

        w1_all[:,:,:,:,ind_c] = w1
        w2_all[:,:,:,:,ind_c] = w2
        w3_all[:,:,:,:,ind_c] = w3
        w_linear_all[:,:,:,:,ind_c] = w_linear

        time_channel_end = time.time()
        print('Time Cost:',time_channel_end-time_channel_start,'s')
        print('Norm of Error = ',error)

    time_Learn_end = time.time()
    print('lerning step costs:',(time_Learn_end - time_Learn_start)/60,'min')

    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # 3. Reconstruction of the missing lines.
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Keep only the columns on the grid 0, acc_rate, 2*acc_rate, ...; every other
    # sampled column (the off-grid ACS lines) is zeroed and reconstructed.
    kspace = np.copy(kspace_all)
    over_samp = np.setdiff1d(picks,np.int32([range(0, n1,acc_rate)]))
    kspace_und = kspace
    kspace_und[:,over_samp,:] = 0
    [dim_kspaceUnd_X,dim_kspaceUnd_Y,dim_kspaceUnd_Z] = np.shape(kspace_und)

    kspace_und_re = np.zeros([dim_kspaceUnd_X,dim_kspaceUnd_Y,dim_kspaceUnd_Z*2])
    kspace_und_re[:,:,0:dim_kspaceUnd_Z] = np.real(kspace_und)
    kspace_und_re[:,:,dim_kspaceUnd_Z:dim_kspaceUnd_Z*2] = np.imag(kspace_und)
    kspace_und_re = np.float32(kspace_und_re)
    kspace_und_re = np.reshape(kspace_und_re,[1,dim_kspaceUnd_X,dim_kspaceUnd_Y,dim_kspaceUnd_Z*2])

    # Output buffers are separate copies, so the network input stays the acquired
    # data for every channel while reconstructed lines are written.
    kspace_recon              = np.copy(kspace_und_re)
    kspace_recon_linear       = np.copy(kspace_und_re)
    kspace_recon_residual     = np.zeros(kspace_recon.shape,dtype = complex)

    for ind_c in range(0,no_channels):
        print('Reconstruting Channel #',ind_c+1)

        # Learned kernels for this output channel.
        w1 = np.float32(w1_all[:,:,:,:,ind_c])
        w2 = np.float32(w2_all[:,:,:,:,ind_c])
        w3 = np.float32(w3_all[:,:,:,:,ind_c])
        w_linear = np.float32(w_linear_all[:,:,:,:,ind_c])

        # Every cnn_* call adds new ops to the default graph, with the numpy input
        # converted to a graph constant. Close the session and reset the graph so
        # neither grows across channels.
        with tf.Session(config=config) as sess:
            residual_recon = cnn_3layer(kspace_und_re,w1,w2,w3,acc_rate,sess)
            linear_recon   = cnn_linear(kspace_und_re,w_linear,acc_rate,sess)
        tf.reset_default_graph()

        # Same crop as in training: the leading x_length x y_length window, for any parity.
        x_length = residual_recon.shape[1]
        y_length = residual_recon.shape[2]
        linear_recon = linear_recon[:, :x_length, :y_length, :]

        target_x_end_kspace = dim_kspaceUnd_X - target_x_start

        for ind_acc in range(0,acc_rate-1):
            target_y_start = \
                np.int32((np.ceil(kernel_y_1/2)-1) + np.int32((np.ceil(kernel_y_2/2)-1)) + \
                np.int32(np.ceil(kernel_last_y/2)-1)) * acc_rate + ind_acc + 1

            target_y_end_kspace = \
                dim_kspaceUnd_Y - np.int32((np.floor(kernel_y_1/2)) + (np.floor(kernel_y_2/2)) + np.floor(kernel_last_y/2)) * acc_rate + ind_acc

            # Only every acc_rate-th network output is used; it fills the skipped
            # column at offset ind_acc+1.
            kspace_recon[0,target_x_start:target_x_end_kspace,target_y_start:target_y_end_kspace+1:acc_rate,ind_c] = \
                linear_recon[0,:,::acc_rate,ind_acc] + residual_recon[0,:,::acc_rate,ind_acc]
            kspace_recon_linear[0,target_x_start:target_x_end_kspace,target_y_start:target_y_end_kspace+1:acc_rate,ind_c] \
                = linear_recon[0,:,::acc_rate,ind_acc]
            kspace_recon_residual[0,target_x_start:target_x_end_kspace,target_y_start:target_y_end_kspace+1:acc_rate,ind_c] = \
                residual_recon[0,:,::acc_rate,ind_acc]

    kspace_recon          = np.squeeze(kspace_recon)
    kspace_recon_linear   = np.squeeze(kspace_recon_linear)
    kspace_recon_residual = np.squeeze(kspace_recon_residual)

    # Re-assemble complex k-space: first half of the channels is real, second half imaginary.
    half = np.int32(no_channels/2)
    kspace_recon_complex = (kspace_recon[:,:,0:half] + \
                np.multiply(kspace_recon[:,:,half:no_channels],1j))

    kspace_recon_complex_linear = (kspace_recon_linear[:,:,0:half] + \
                np.multiply(kspace_recon_linear[:,:,half:no_channels],1j))

    kspace_recon_complex_residual = (kspace_recon_residual[:,:,0:half] + \
                np.multiply(kspace_recon_residual[:,:,half:no_channels],1j))

    # Re-insert the acquired ACS columns. center_end is inclusive, hence the +1;
    # without it the last ACS column would keep its network estimate.
    kspace_recon_complex_acs = np.copy(kspace_recon_complex)
    kspace_recon_complex_acs[:,center_start:center_end+1,:] = kspace_NEVER_TOUCH[:,center_start:center_end+1,:]

    kspace_recon_complex_linear_acs = np.copy(kspace_recon_complex_linear)
    kspace_recon_complex_linear_acs[:,center_start:center_end+1,:] = \
        kspace_NEVER_TOUCH[:,center_start:center_end+1,:]

    return [kspace_recon_complex,kspace_recon_complex_acs,\
            kspace_recon_complex_linear,kspace_recon_complex_linear_acs,\
            kspace_recon_complex_residual,cur_kspace_truth_raki]

### performing rraki reconstructions 

In [None]:
# One slot per (Ry, acsy) setting along the last axis.
n_settings = len(all_parameters)
kspace_rraki_recon_all      = np.zeros((M, N, C, n_settings), dtype=complex)
kspace_rraki_acs_recon_all  = np.zeros((M, N, C, n_settings), dtype=complex)
kspace_linear_recon_all     = np.zeros((M, N, C, n_settings), dtype=complex)
kspace_linear_acs_recon_all = np.zeros((M, N, C, n_settings), dtype=complex)
kspace_residual_est_all     = np.zeros((M, N, C, n_settings), dtype=complex)
kspace_raki_truth_all       = np.zeros((M, N, C, n_settings), dtype=complex)

for index, parameter in enumerate(all_parameters):
    Ry, acsy = parameter['Ry'], parameter['acsy']

    print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
    print('Training %d/%d || Ry %d || Acsy %d' % (index + 1, n_settings, Ry, acsy))
    print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')

    outputs = residual_raki(MaxIteration, LearningRate, Rx, Ry, acsx, acsy, GPU_FRAC)
    kspace_rraki, kspace_rraki_acs, kspace_linear, kspace_linear_acs, kspace_residual, kspace_truth = outputs

    # Store each returned volume in its result array.
    for store, result in ((kspace_rraki_recon_all, kspace_rraki),
                          (kspace_rraki_acs_recon_all, kspace_rraki_acs),
                          (kspace_linear_recon_all, kspace_linear),
                          (kspace_linear_acs_recon_all, kspace_linear_acs),
                          (kspace_residual_est_all, kspace_residual),
                          (kspace_raki_truth_all, kspace_truth)):
        store[:, :, :, index] = result

~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Training 1/10 || Ry 5 || Acsy 20
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
ACS signal found in the input data
go!
learning channel # 1
The 0 th iteration gives an error 0.09506537
The 100 th iteration gives an error 0.0044419398
The 200 th iteration gives an error 0.003182366
The 300 th iteration gives an error 0.0027491986
The 400 th iteration gives an error 0.0024614115
The 500 th iteration gives an error 0.002278151
The 600 th iteration gives an error 0.0021487987
The 700 th iteration gives an error 0.0020432384
The 800 th iteration gives an error 0.002031209
The 900 th iteration gives an error 0.0019480456
Time Cost: 41.30243277549744 s
Norm of Error =  0.0018904072
learning channel # 2
The 0 th iteration gives an error 0.10017571
The 100 th iteration gives an error 0.0043185577
The 200 th iteration gives an error 0.0031654616
The 300 th iteration gives an error 0.0027962942
The 400 th iteration gives an error 0.0025299033
The 500 th iteration gives an e

The 300 th iteration gives an error 0.0027335868
The 400 th iteration gives an error 0.0023968956
The 500 th iteration gives an error 0.002249312
The 600 th iteration gives an error 0.0021704487
The 700 th iteration gives an error 0.0020970602
The 800 th iteration gives an error 0.0019862982
The 900 th iteration gives an error 0.001969071
Time Cost: 3.749126672744751 s
Norm of Error =  0.0019006375
learning channel # 16
The 0 th iteration gives an error 0.09148891
The 100 th iteration gives an error 0.0042047543
The 200 th iteration gives an error 0.0029646151
The 300 th iteration gives an error 0.0025388168
The 400 th iteration gives an error 0.0022457084
The 500 th iteration gives an error 0.0021381073
The 600 th iteration gives an error 0.0019668434
The 700 th iteration gives an error 0.001909819
The 800 th iteration gives an error 0.0018042708
The 900 th iteration gives an error 0.0017215719
Time Cost: 3.586442708969116 s
Norm of Error =  0.0017382996
learning channel # 17
The 0 th

The 900 th iteration gives an error 0.0018251757
Time Cost: 3.543794631958008 s
Norm of Error =  0.0017361259
learning channel # 30
The 0 th iteration gives an error 0.09187326
The 100 th iteration gives an error 0.003571893
The 200 th iteration gives an error 0.0026381682
The 300 th iteration gives an error 0.0022333637
The 400 th iteration gives an error 0.0020251446
The 500 th iteration gives an error 0.0018520684
The 600 th iteration gives an error 0.0017769295
The 700 th iteration gives an error 0.0016298769
The 800 th iteration gives an error 0.0015820565
The 900 th iteration gives an error 0.0015028722
Time Cost: 3.533731698989868 s
Norm of Error =  0.0014648675
learning channel # 31
The 0 th iteration gives an error 0.10765435
The 100 th iteration gives an error 0.0050955005
The 200 th iteration gives an error 0.0037243045
The 300 th iteration gives an error 0.0032175342
The 400 th iteration gives an error 0.0029493095
The 500 th iteration gives an error 0.0027533344
The 600 th

The 300 th iteration gives an error 0.0025440224
The 400 th iteration gives an error 0.0022768746
The 500 th iteration gives an error 0.0021622719
The 600 th iteration gives an error 0.0020269514
The 700 th iteration gives an error 0.0018883445
The 800 th iteration gives an error 0.0018332351
The 900 th iteration gives an error 0.0017343468
Time Cost: 3.5506982803344727 s
Norm of Error =  0.001707549
learning channel # 45
The 0 th iteration gives an error 0.08406578
The 100 th iteration gives an error 0.003683675
The 200 th iteration gives an error 0.002882029
The 300 th iteration gives an error 0.0024672365
The 400 th iteration gives an error 0.002264986
The 500 th iteration gives an error 0.0021205484
The 600 th iteration gives an error 0.0019787848
The 700 th iteration gives an error 0.0019693403
The 800 th iteration gives an error 0.0018673411
The 900 th iteration gives an error 0.0017544356
Time Cost: 3.4994637966156006 s
Norm of Error =  0.001837212
learning channel # 46
The 0 th

The 900 th iteration gives an error 0.0017281404
Time Cost: 3.603980302810669 s
Norm of Error =  0.0017500864
learning channel # 59
The 0 th iteration gives an error 0.10536118
The 100 th iteration gives an error 0.007503024
The 200 th iteration gives an error 0.004752017
The 300 th iteration gives an error 0.0039311093
The 400 th iteration gives an error 0.0034560403
The 500 th iteration gives an error 0.003143086
The 600 th iteration gives an error 0.0029760585
The 700 th iteration gives an error 0.0028857654
The 800 th iteration gives an error 0.0027485043
The 900 th iteration gives an error 0.0026949013
Time Cost: 3.57247257232666 s
Norm of Error =  0.0025838874
learning channel # 60
The 0 th iteration gives an error 0.0828723
The 100 th iteration gives an error 0.003994931
The 200 th iteration gives an error 0.002887839
The 300 th iteration gives an error 0.002430635
The 400 th iteration gives an error 0.0022209855
The 500 th iteration gives an error 0.0020505688
The 600 th iterat

The 0 th iteration gives an error 0.10781421
The 100 th iteration gives an error 0.005294746
The 200 th iteration gives an error 0.003911422
The 300 th iteration gives an error 0.0034094146
The 400 th iteration gives an error 0.0031438959
The 500 th iteration gives an error 0.0029358864
The 600 th iteration gives an error 0.002795894
The 700 th iteration gives an error 0.0026132152
The 800 th iteration gives an error 0.0025739204
The 900 th iteration gives an error 0.0024653277
Time Cost: 3.720940351486206 s
Norm of Error =  0.002443985
learning channel # 7
The 0 th iteration gives an error 0.09211948
The 100 th iteration gives an error 0.004518167
The 200 th iteration gives an error 0.0036569545
The 300 th iteration gives an error 0.0032835216
The 400 th iteration gives an error 0.0029590542
The 500 th iteration gives an error 0.0028105816
The 600 th iteration gives an error 0.0026464309
The 700 th iteration gives an error 0.0026107118
The 800 th iteration gives an error 0.0024808953


### defining spark helper functions

In [None]:
def reformattingKspaceForSpark(inputKspace,kspaceOriginal,acsregionX,acsregionY,acsx,acsy,normalizationflag,device=None):
    '''
    Build SPARK network inputs and ACS-error targets.

    Inputs:
        inputKspace:       E x C x M x N complex, reconstructed k-space to be corrected
        kspaceOriginal:    E x C x M x N complex, k-space whose ACS region is the measured reference
        acsregionX/Y:      index arrays of the ACS region along the last two axes
        acsx, acsy:        number of ACS indices along each of those axes
        normalizationflag: if truthy, scale the network input of each contrast and each
                           real/imag ACS error of each contrast/channel to unit peak magnitude
                           (an all-zero array keeps scale 1 instead of producing inf/NaN)
        device:            torch device for the returned tensors; defaults to the
                           notebook-level global `device`
    Outputs:
        kspace_grappa_split:     E x 2C x M x N float tensor (real parts, then imaginary parts)
        acs_difference_real:     E x C x 1 x 1 x acsx x acsy float tensor, (measured - recon).real, scaled
        acs_difference_imag:     E x C x 1 x 1 x acsx x acsy float tensor, (measured - recon).imag, scaled
        chan_scale_factors_real: E x C array of the factors applied to the real targets
        chan_scale_factors_imag: E x C array of the factors applied to the imaginary targets
    '''
    if device is None:
        device = globals()['device']

    def _inverse_peak(values):
        # 1 / max|values|, or 1 for an all-zero array so the scaling never yields inf/NaN.
        peak = np.amax(np.abs(values))
        return 1/peak if peak > 0 else 1.0

    [E,C,_,_] = inputKspace.shape
    acs_rows = slice(acsregionX[0], acsregionX[acsx-1]+1)
    acs_cols = slice(acsregionY[0], acsregionY[acsy-1]+1)

    # Measured ACS minus reconstructed ACS: the error the SPARK networks learn to predict.
    kspaceAcsCrop       = kspaceOriginal[:,:,acs_rows,acs_cols]
    kspaceAcsGrappa     = inputKspace[:,:,acs_rows,acs_cols]
    kspaceAcsDifference = kspaceAcsCrop - kspaceAcsGrappa

    # Owned, writable copies (np.imag of a real array is read-only) so they can be scaled in place.
    acs_difference_real = np.array(np.real(kspaceAcsDifference))
    acs_difference_imag = np.array(np.imag(kspaceAcsDifference))

    # Network input: real parts of all channels followed by their imaginary parts.
    kspace_grappa_split = np.concatenate((np.real(inputKspace), np.imag(inputKspace)), axis=1)

    chan_scale_factors_real = np.ones((E,C),dtype = 'float')
    chan_scale_factors_imag = np.ones((E,C),dtype = 'float')

    if normalizationflag:
        for e in range(E):
            kspace_grappa_split[e,:,:,:] *= _inverse_peak(kspace_grappa_split[e,:,:,:])

            for c in range(C):
                chan_scale_factors_real[e,c] = _inverse_peak(acs_difference_real[e,c,:,:])
                chan_scale_factors_imag[e,c] = _inverse_peak(acs_difference_imag[e,c,:,:])

                acs_difference_real[e,c,:,:] *= chan_scale_factors_real[e,c]
                acs_difference_imag[e,c,:,:] *= chan_scale_factors_imag[e,c]

    # E x C x acsx x acsy -> E x C x 1 x 1 x acsx x acsy
    acs_difference_real = np.expand_dims(np.expand_dims(acs_difference_real, axis=2), axis=2)
    acs_difference_imag = np.expand_dims(np.expand_dims(acs_difference_imag, axis=2), axis=2)

    kspace_grappa_split = torch.from_numpy(kspace_grappa_split).to(device, dtype=torch.float)
    acs_difference_real = torch.from_numpy(acs_difference_real).to(device, dtype=torch.float)
    acs_difference_imag = torch.from_numpy(acs_difference_imag).to(device, dtype=torch.float)

    return kspace_grappa_split, acs_difference_real, acs_difference_imag, chan_scale_factors_real, chan_scale_factors_imag

In [None]:
def trainingSparkNetwork(kspaceGrappaSplit,acsDifferenceReal,acsDifferenceImag,acsx,acsy,learningRate,iterations):
    '''
    Train one SPARK network per contrast and channel for the real part of the ACS error,
    then one per contrast and channel for the imaginary part.
    Inputs:
        kspaceGrappaSplit: allContrasts x 2*allChannels x M x N       Reconstructed k-space (real parts, then
                                                                      imaginary parts) used as network input
        acsDifferenceReal: allContrasts x allChannels x 1 x 1 x acsx x acsy   Real part of the
                                                                      measured-minus-reconstructed ACS
        acsDifferenceImag: allContrasts x allChannels x 1 x 1 x acsx x acsy   Imaginary part of the same
        acsx, acsy:        scalars,                                   ACS size, forwarded to models.SPARK_Netv2
        learningRate:      scalar,                                    Adam learning rate for every network
        iterations:        scalar,                                    Optimisation steps per network
    Outputs:
        real_models, real_model_names: dict name -> trained model and the names in training
                                       order; names are 'modelE{e}C{c}r'
        imag_models, imag_model_names: the same for the imaginary part (suffix 'i')
    '''
    real_models, real_model_names = _trainSparkComponentModels(
        kspaceGrappaSplit, acsDifferenceReal, 'r', acsx, acsy, learningRate, iterations)
    imag_models, imag_model_names = _trainSparkComponentModels(
        kspaceGrappaSplit, acsDifferenceImag, 'i', acsx, acsy, learningRate, iterations)
    return real_models,real_model_names,imag_models,imag_model_names


def _trainSparkComponentModels(kspaceGrappaSplit, acsTargets, suffix, acsx, acsy, learningRate, iterations):
    '''Train one SPARK_Netv2 per (contrast, channel) against acsTargets; return (models dict, names list).'''
    [E,C,_,_,_,_] = acsTargets.shape

    trained_models = {}
    model_names = []
    criterion = nn.MSELoss()

    for e in range(0,E):
        # The network input depends only on the contrast, so build it once per contrast.
        kspsplit = torch.unsqueeze(kspaceGrappaSplit[e,:,:,:], dim=0)

        for c in range(0,C):
            model_name = 'model' + 'E' + str(e) + 'C' + str(c) + suffix
            model = models.SPARK_Netv2(coils = C,kernelsize = 3,acsx = acsx, acsy = acsy)
            model.to(device)

            print('Training {}'.format(model_name))

            optimizer = optim.Adam(model.parameters(), lr=learningRate)
            running_loss = 0.0

            for epoch in range(iterations):
                optimizer.zero_grad()

                # The model's second output is the one compared with the ACS error target.
                _, loss_out = model(kspsplit)
                loss = criterion(loss_out, acsTargets[e,c,:,:,:,:])
                loss.backward()
                optimizer.step()

                running_loss = loss.item()
                if epoch == 0:
                    print('Training started , loss = %.10f' % (running_loss))

            model_names.append(model_name)
            trained_models[model_name] = model

            print('Training Complete, loss = %.10f' % (running_loss))

    return trained_models, model_names

In [None]:
def applySparkCorrection(kspaceToCorrect,kspaceGrappaSplit,real_model,imag_model,chanScaleFactorReal,chanScaleFactorImag):
    '''
    Apply a trained pair of SPARK correction models to one coil's k-space.

    Inputs:
        kspaceToCorrect     - M x N array          k-space (one coil) that we want to correct
        kspaceGrappaSplit   - torch tensor         network input for this contrast
                                                   (the caller passes kspace_grappa_split[contrast,:,:,:])
        real_model          - model                network predicting the real part of the correction
        imag_model          - model                network predicting the imaginary part of the correction
        chanScaleFactorReal - scalar               the real correction is divided by this
                                                   (presumably the normalization used in training)
        chanScaleFactorImag - scalar               the imaginary correction is divided by this
    Outputs:
        corrected           - M x N complex array  kspaceToCorrect plus the rescaled correction
    '''
    # Both models consume the same batched input, so add the batch axis once.
    networkInput = torch.unsqueeze(kspaceGrappaSplit, 0)

    # Pure inference: no_grad avoids building (and then discarding) an autograd graph.
    with torch.no_grad():
        # Each model returns a tuple; element 0 is the correction map.
        correctionr = real_model(networkInput)[0].detach().cpu().numpy()
        correctioni = imag_model(networkInput)[0].detach().cpu().numpy()

    # Take the single [0,0] map, undo the per-channel scaling and add it to the input k-space.
    corrected = correctionr[0,0,:,:]/chanScaleFactorReal + 1j * correctioni[0,0,:,:] / chanScaleFactorImag + kspaceToCorrect

    return corrected

### Defining fixed SPARK training parameters

In [None]:
E = 1
normalizationflag = 1
normalizeAll      = 0
iterations        = 200
learningRate      = .0075

### spark reconstruction loop

In [None]:
# Corrected multi-coil k-space for every (Ry, ACS size) configuration, stacked on axis 0.
kspace_spark_all = np.zeros((len(all_parameters),C,M,N),dtype = complex)

# The training device does not change between configurations, so choose it once.
# NOTE(review): trainingSparkNetwork appears to read `device` as a global — confirm.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

for index,parameter in enumerate(all_parameters):
    Ry   = parameter['Ry']
    acsy = parameter['acsy']

    print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
    print('Training %d/%d || Ry %d || Acsy %d' % (index+1,len(all_parameters),Ry,acsy))
    print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')

    ### Defining k-space to be corrected in SPARK framework ###
    # Residual-RAKI reconstruction (to correct) and reference k-space for this configuration,
    # reordered so coils sit on axis 1 (presumably (1, C, M, N)).
    kspaceGrappa = np.transpose(np.expand_dims(kspace_rraki_recon_all[:,:,:,index],axis = 0),(0,3,1,2))
    kspace       = np.transpose(np.expand_dims(kspace_raki_truth_all[:,:,:,index],axis = 0),(0,3,1,2))

    ### Defining acs region in SPARK framework ###
    acsregionX = np.arange(M//2 - acsx // 2,M//2 + acsx//2)
    acsregionY = np.arange(N//2 - acsy // 2,N//2 + acsy//2)
    # Inclusive ACS bounds as slices, shared by the zero-filled ACS and the final ACS
    # replacement so both always cover exactly the same rows/columns.
    acs_rows = slice(acsregionX[0], acsregionX[acsx-1] + 1)
    acs_cols = slice(acsregionY[0], acsregionY[acsy-1] + 1)

    kspaceAcsZerofilled = np.zeros((E,C,M,N),dtype = complex)
    kspaceAcsZerofilled[:,:,acs_rows,acs_cols] = kspace[:,:,acs_rows,acs_cols]

    ### Training Network ###
    [kspace_grappa_split, acs_difference_real, acs_difference_imag,chan_scale_factors_real,chan_scale_factors_imag] = \
        reformattingKspaceForSpark(kspaceGrappa,kspaceAcsZerofilled,acsregionX,acsregionY,acsx,acsy,normalizationflag)

    realSparkGrappaModels,realSparkGrappaNames,imagSparkGrappaModels,imagSparkGrappaNames = \
        trainingSparkNetwork(kspace_grappa_split,acs_difference_real,acs_difference_imag,acsregionX,acsregionY,learningRate,iterations)

    ### Applying SPARK correction with ACS replacement ###
    kspaceCorrected = np.zeros((E,C,M,N),dtype = complex)

    for reconContrast in range(0,E):
        # The network input depends only on the contrast, not on the coil being corrected.
        kspaceGrappaSplit = kspace_grappa_split[reconContrast,:,:,:]

        for c in range(0,C):
            # Perform reconstruction coil by coil with that coil's real/imag model pair.
            model_namer = 'model' + 'E' + str(reconContrast) + 'C' + str(c) + 'r'
            model_namei = 'model' + 'E' + str(reconContrast) + 'C' + str(c) + 'i'

            real_model = realSparkGrappaModels[model_namer]
            imag_model = imagSparkGrappaModels[model_namei]

            kspaceToCorrect = kspaceGrappa[reconContrast,c,:,:]

            kspaceCorrected[reconContrast,c,:,:] = \
                applySparkCorrection(kspaceToCorrect,kspaceGrappaSplit,real_model,imag_model,\
                    chan_scale_factors_real[reconContrast,c], chan_scale_factors_imag[reconContrast,c])

    # Overwrite the whole ACS block (including its last row/column) with the reference k-space.
    kspaceCorrectedReplaced = np.copy(kspaceCorrected)
    kspaceCorrectedReplaced[:,:,acs_rows,acs_cols] = kspace[:,:,acs_rows,acs_cols]

    kspace_spark_all[index,:,:,:] = kspaceCorrectedReplaced[0,:,:,:]

### Computing rraki and spark reconstructions

In [None]:
# Coil-combined (presumably root-sum-of-squares) images for every parameter set,
# with the parameter-set index on the last axis.
rraki_acs_recons = sig.rsos(ifft2c_raki(kspace_rraki_acs_recon_all),2)
rraki_recons     = sig.rsos(ifft2c_raki(kspace_rraki_recon_all),2)
# kspace_spark_all is (params, C, M, N): combine over coils, then move params last.
spark_recons     = np.transpose(sig.rsos(sig.ifft2c(kspace_spark_all),-3),(1,2,0))

# Reference image. It is rebuilt with the same expression the training loop uses for its
# last configuration, rather than relying on the `kspace` variable leaked by that cell.
# NOTE(review): assumes the reference k-space is identical across parameter sets — confirm.
kspace_truth = np.transpose(np.expand_dims(kspace_raki_truth_all[:,:,:,len(all_parameters)-1],axis = 0),(0,3,1,2))
truth = sig.rsos(sig.ifft2c(kspace_truth),-3)

### Evaluating a single parameter set

In [None]:
pp = 1  # index into all_parameters of the configuration to inspect

# Read Ry / ACS size from the selected configuration, not from the loop variable
# `parameter`, which still holds the last configuration trained.
selected_parameter = all_parameters[pp]
Ry   = selected_parameter['Ry']
acsy = selected_parameter['acsy']
print('Ry: %d || Acs y: %d' % (Ry,acsy))

# RMSE against the reference image, scaled by 100.
print('RMSE:')
print('   rraki no acs: %.2f' % (sig.rmse(np.squeeze(truth),np.squeeze(rraki_recons[:,:,pp]))*100))
print('   rraki w/ acs: %.2f' % (sig.rmse(np.squeeze(truth),np.squeeze(rraki_acs_recons[:,:,pp]))*100))
print('   rraki spark:  %.2f' % (sig.rmse(np.squeeze(truth),np.squeeze(spark_recons[:,:,pp]))*100))

# Residual RAKI with ACS, then SPARK-corrected, shown as a 1 x 2 mosaic.
# Named so it does not shadow IPython's built-in `display`.
comparison_images = np.stack((rraki_acs_recons[:,:,pp], spark_recons[:,:,pp]), axis=0)
sig.mosaic(sig.nor(comparison_images),1,2)

### Saving results

In [None]:
# Bundle the reference, every reconstruction and the sweep parameters into one .mat file.
results = {'truth':          np.squeeze(truth),
           'all_rraki' :     np.squeeze(rraki_recons),
           'all_rraki_acs':  np.squeeze(rraki_acs_recons),
           'all_spark':      np.squeeze(spark_recons),
           'acs_sizes':      np.squeeze(all_acsy),
           'accelerations':  np.squeeze(all_Ry),
           'all_parameters': np.squeeze(all_parameters)}

# savemat does not create missing directories; make sure the output folder exists.
os.makedirs('results', exist_ok=True)
sio.savemat('results/residual_raki_with_spark_ablation.mat', results, oned_as='row')