This Jupyter notebook replicates the results from Figure 4 in the ISMRM abstract "Extending Scan-specific Artifact reduction in K-space (SPARK) to advanced encoding and reconstruction schemes" (Yamin Arefeen et al.).  In particular, this notebook performs simulation experiments comparing generalized-SENSE reconstructions of wave-encoded and non-wave-encoded data, with and without SPARK-based corrections.  

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import warnings
import numpy as np
import numpy.linalg as la
import scipy as sp
import cupy as cp
import matplotlib.pyplot as plt
from bart import bart
from utils import cfl
from utils import signalprocessing as sig
from utils import models
from utils import iterative

# Loading the dataset and selecting the slices that I want to alias and reconstruct 

In [None]:
#-Load the fully sampled multi-channel k-space, reorder its axes to (C, P, M, N),
#-and rescale it so that its largest magnitude equals 100
allkspace = np.transpose(cfl.readcfl('data/kspaceFullFig2'), axes=(3, 2, 0, 1))
allkspace = allkspace / np.abs(allkspace).max() * 100

C, P, M, N = allkspace.shape

In [None]:
#-Simultaneous-multi-slice (SMS) acquisition parameters
beginningSliceIndex = 20              # first slice of the simultaneously excited group
numslices           = 5               # multi-band factor: number of slices excited together
slicedistance       = P // numslices  # spacing (in z index) between the excited slices
fovshift            = 3               # FOV shift factor, in the spirit of blipped-CAIPI

#-In-plane acquisition parameters
Ry  = 2    # in-plane acceleration factor
snr = 150  # noise level; the noise standard deviation is 1/snr

#-Iterative (conjugate-gradient) reconstruction parameters
senseIterations = 20
cudaflag        = 1   # truthy -> move operator data to the GPU with CuPy

#-SPARK ACS sizes.  SMS is modelled by concatenating the slices along the readout (x) direction,
#-so the ACS extent in x is M * numslices.
acsx = M * numslices
acsy = 24

#-Learning parameters for the SPARK models
learningRate      = .0075
sparkIterations   = 200
normalizationflag = 1

#-Excited slice indices: beginningSliceIndex, beginningSliceIndex + slicedistance, ...
slices = np.arange(numslices) * slicedistance + beginningSliceIndex

#-FFT along the slice dimension, then keep only the k-space of the excited slices
allkspaceslices = sig.fft(allkspace, -3)
kspaceslices    = allkspaceslices[:, slices, :, :]

# Viewing the slices that will be used to model sms

In [None]:
#-Coil-combine each selected slice (sig.rsos over axis -4, presumably root-sum-of-squares),
#-normalize, and show the slices side by side in a 1 x numslices mosaic
sig.mosaic(
    sig.nor(sig.rsos(sig.ifft2c(kspaceslices), -4)),
    1,
    numslices,
)

# Calibrate the coil sensitivity maps

In [None]:
#-Estimate one set of sensitivity maps per slice with BART's ecalib (single map, -m 1)
coilslices = np.zeros((C, numslices, M, N), dtype=complex)

for ss in range(numslices):
    print('Callibrating coils for slice %d/%d' % (ss + 1, numslices))
    # BART input layout: move coils last and insert a singleton third axis -> (M, N, 1, C)
    bartInput = kspaceslices[:, ss, :, :].transpose(1, 2, 0)[:, :, np.newaxis, :]
    sensMaps  = bart(1, 'ecalib -m 1', bartInput)
    coilslices[:, ss, :, :] = np.squeeze(sensMaps.transpose(3, 2, 0, 1))
print('done.')

# Shifting the slices/sensitivities and concat along the readout direction 

In [None]:
#-Per-slice shift amounts: evenly spaced slice positions scaled by N / fovshift and rounded to
#-whole samples.  fovshift <= 0 disables the shifts.
if fovshift <= 0:
    shifts = np.zeros(numslices, dtype=int)
else:
    shifts = np.round(np.linspace(-(numslices / 2 - 1), numslices / 2, numslices) * N / fovshift).astype(int)

#-Define the function which performs the (circular) FOV shift of every slice
def performshift(x,shift,direction = 1):
    '''
    Circularly shift every slice of a 4-D array along the phase-encode (last) axis.

    Inputs:
        x:         4-D array indexed as x[:, ss, :, :]; slices live on axis 1 (= axis -3)
        shift:     sequence with one integer shift per slice
        direction: +1 applies the shifts, -1 undoes them
    Outputs:
        complex array with the same shape as x
    '''
    out = np.zeros(x.shape,dtype=complex)

    for ss in range(out.shape[-3]):
        # axis=-1 keeps the roll inside each phase-encode row.  Without an axis, np.roll flattens the
        # whole (coils, M, N) sub-array, so samples pushed off the end of a row would wrap into the next
        # row (and from the last row into the next coil) instead of wrapping around within their own row.
        out[:,ss,:,:] = np.roll(x[:,ss,:,:],direction*shift[ss],axis=-1)
    return out
      
#-Shift the slices in image space together with their coil maps, so both see the same FOV shift
slicesShiftedCoils = performshift(sig.ifft2c(kspaceslices),shifts)
shiftedCoils       = performshift(coilslices,shifts)

#-Concatenate the slices along the readout dimension (slice ss fills rows ss*M ... (ss+1)*M - 1) so that
#-readout undersampling models the SMS acquisition.  For C-ordered (C, numslices, M, N) arrays this
#-concatenation is exactly a reshape.
slicesCoils = np.reshape(slicesShiftedCoils, (C, numslices * M, N))
coils       = np.reshape(shiftedCoils, (C, numslices * M, N))

# Generating the undersampling mask and k-space: the readout dimension is undersampled (to model SMS) and the phase-encode dimension is undersampled for in-plane acceleration

In [None]:
#-Sampling mask: every numslices-th line of the stacked readout (models SMS) and every Ry-th phase encode
mask = np.zeros(slicesCoils.shape)
mask[:, ::numslices, ::Ry] = 1

#-Add complex Gaussian noise (standard deviation 1/snr per real/imaginary part), then undersample
noiseStd = 1/snr
noise    = np.random.normal(0, noiseStd, mask.shape) + 1j * np.random.normal(0, noiseStd, mask.shape)
kspace   = mask * (sig.fft2c(slicesCoils) + noise)

# Generating sense operators

In [None]:
def senseForward(x,maps,mask):
    '''SENSE forward model: weight the image by every coil map, apply sig.fft2c
    (presumably a centered 2-D FFT), then multiply by the sampling mask.'''
    coilImages = maps * x
    return mask * sig.fft2c(coilImages)

def senseAdjoint(x,maps,mask):
    '''SENSE adjoint: sig.ifft2c of every coil, then a conj(maps)-weighted sum over axis -3 (the coils).

    `mask` is accepted for symmetry with senseForward but is not applied; the input is expected to
    be already masked (as the output of senseForward is).  Works on numpy or cupy arrays.'''
    xp = cp.get_array_module(x)
    coilImages = sig.ifft2c(x)
    return (xp.conj(maps) * coilImages).sum(axis=-3)

# Performing the SENSE reconstruction 

In [None]:
#-Compute the adjoint of the kspace data (right-hand side of the normal equations)
kadj = senseAdjoint(kspace,coils,mask)

#-Optionally move the operator data to the GPU; the operators select numpy/cupy via cp.get_array_module
if(cudaflag):
    coils   = cp.asarray(coils)
    mask    = cp.asarray(mask)
    kadj    = cp.asarray(kadj)

#-Normal operator E^H E acting on the flattened (M*numslices) x N image
normal = lambda x: senseAdjoint(senseForward(x.reshape(M*numslices,N),coils,mask),\
                                          coils,mask).ravel()

print('SENSE reconstruction ...',end='')
#-Conjugate gradient; kadj is passed twice (presumably right-hand side and initial guess).
#-The iteration count is the senseIterations parameter from the parameter cell.
smsSense = cp.asnumpy(iterative.conjgrad(normal,kadj.ravel(),kadj.ravel(),\
                                         ite = senseIterations)).reshape(M*numslices,N)
print(' Done.')

#-Back to host memory for the numpy code that follows (cp.asnumpy also accepts numpy arrays)
coils = cp.asnumpy(coils)
mask  = cp.asnumpy(mask)
kadj  = cp.asnumpy(kadj)

# Compare SENSE and ground truth 

In [None]:
#-Ground truth: coil-combine the fully sampled stacked data, split it back into slices, undo the FOV shifts
truth = np.squeeze(performshift(
    np.reshape(np.sum(np.conj(coils) * slicesCoils, axis=-3), (numslices, M, N))[np.newaxis],
    shifts, -1))

#-SENSE result, split into slices and unshifted the same way
sense = np.squeeze(performshift(np.reshape(smsSense, (numslices, M, N))[np.newaxis], shifts, -1))

#-Top row: ground truth, bottom row: SENSE (one column per slice)
display = sig.nor(np.concatenate((truth, sense), axis=0))
sig.mosaic(display, 2, numslices)

# Computing sense kspace and acs kspace for SPARK 

In [None]:
#-Fully sampled (mask = 1) k-space of the SENSE image, and the ground-truth k-space that supplies the ACS
kspaceSense = senseForward(smsSense,coils,1)
kspaceAcs   = sig.fft2c(slicesCoils)

#-ACS index ranges: acsx x acsy samples centred on the middle of the stacked k-space
acsregionX = (M*numslices)//2 + np.arange(-(acsx//2), acsx//2)
acsregionY = N//2 + np.arange(-(acsy//2), acsy//2)

# Defining SPARK helper functions 

In [None]:
def reformattingKspaceForSpark(inputKspace,kspaceOriginal,acsregionX,acsregionY,acsx,acsy,normalizationflag):
    '''
    Prepare network inputs and normalized ACS-error targets for SPARK training.

    The "grappa" names below are historical: `inputKspace` is the k-space of whichever
    reconstruction is being corrected (SENSE, wave-SENSE, ...).

    Inputs:
        inputKspace:       E x C x X x Y complex array, reconstructed k-space to be corrected
        kspaceOriginal:    E x C x X x Y complex array whose ACS region is the ground truth
        acsregionX/Y:      ACS index arrays; only entries 0 and acsx-1 / acsy-1 are used, so the ACS
                           is the contiguous, inclusive range between them
        acsx, acsy:        number of ACS samples in x and y
        normalizationflag: truthy -> scale each contrast's input by 1/max|.| and each channel's real and
                           imaginary targets by their own 1/max|.| (all-zero data is left unscaled)
    Outputs:
        kspace_grappa_split:     float tensor on `device`, E x 2C x X x Y (real channels, then imaginary)
        acs_difference_real:     float tensor on `device`, E x C x 1 x 1 x acsx x acsy
        acs_difference_imag:     same as above, for the imaginary part
        chan_scale_factors_real: E x C numpy array of the factors applied to the real targets
        chan_scale_factors_imag: E x C numpy array of the factors applied to the imaginary targets
    Uses the module-level `device` (a torch.device), which must be defined before the call.
    '''
    def _inversePeak(a):
        # 1/max|a|, or 1 when a is identically zero.  An infinite scale would turn an all-zero target
        # into NaNs (0 * inf), poisoning training and the later division by the scale factor.
        peak = np.amax(np.abs(a))
        return 1 / peak if peak > 0 else 1

    [E,C,_,_] = inputKspace.shape
    xsl = slice(acsregionX[0],acsregionX[acsx-1]+1)
    ysl = slice(acsregionY[0],acsregionY[acsy-1]+1)

    #Measured ACS data (ground truth) minus reconstructed ACS data: the error the networks learn
    kspaceAcsCrop       = kspaceOriginal[:,:,xsl,ysl]
    kspaceAcsGrappa     = inputKspace[:,:,xsl,ysl]
    kspaceAcsDifference = kspaceAcsCrop - kspaceAcsGrappa

    #Real and imaginary parts of the error are learned by separate networks; copies keep them writable
    acs_difference_real = np.real(kspaceAcsDifference).copy()
    acs_difference_imag = np.imag(kspaceAcsDifference).copy()

    #Network input: real channels followed by imaginary channels along axis 1 (a new array)
    kspace_grappa_split = np.concatenate((np.real(inputKspace), np.imag(inputKspace)), axis=1)

    #Normalization (factors stay 1 when it is disabled)
    chan_scale_factors_real = np.ones((E,C),dtype = 'float')
    chan_scale_factors_imag = np.ones((E,C),dtype = 'float')

    if(normalizationflag):
        for e in range(E):
            kspace_grappa_split[e,:,:,:] *= _inversePeak(kspace_grappa_split[e,:,:,:])

            for c in range(C):
                chan_scale_factors_real[e,c] = _inversePeak(acs_difference_real[e,c,:,:])
                chan_scale_factors_imag[e,c] = _inversePeak(acs_difference_imag[e,c,:,:])

                acs_difference_real[e,c,:,:] *= chan_scale_factors_real[e,c]
                acs_difference_imag[e,c,:,:] *= chan_scale_factors_imag[e,c]

    #Two singleton axes so each (contrast, channel) target is 1 x 1 x acsx x acsy
    acs_difference_real = acs_difference_real[:,:,np.newaxis,np.newaxis,:,:]
    acs_difference_imag = acs_difference_imag[:,:,np.newaxis,np.newaxis,:,:]

    kspace_grappa_split = torch.from_numpy(kspace_grappa_split).to(device, dtype=torch.float)
    print('kspace_grappa_split shape: ' + str(kspace_grappa_split.shape))

    acs_difference_real = torch.from_numpy(acs_difference_real).to(device, dtype=torch.float)
    print('acs_difference_real shape: ' + str(acs_difference_real.shape))

    acs_difference_imag = torch.from_numpy(acs_difference_imag).to(device, dtype=torch.float)
    print('acs_target_imag shape: ' + str(acs_difference_imag.shape))

    return kspace_grappa_split, acs_difference_real, acs_difference_imag, chan_scale_factors_real, chan_scale_factors_imag

In [None]:
def _trainSparkModelSet(kspaceGrappaSplit,acsDifference,suffix,acsx,acsy,learningRate,iterations):
    '''Train one SPARK_Netv2 per (contrast, channel) of acsDifference; model names end in `suffix`.'''
    [E,C,_,_,_,_] = acsDifference.shape

    trained      = {}
    trainedNames = []

    criterion = nn.MSELoss()

    for e in range(0,E):
        # The network input of contrast e is shared by all of its channels
        kspsplit = torch.unsqueeze(kspaceGrappaSplit[e,:,:,:],dim = 0)

        for c in range(0,C):
            model_name = 'model' + 'E' + str(e) + 'C' + str(c) + suffix
            model = models.SPARK_Netv2(coils = C,kernelsize = 3,acsx = acsx, acsy = acsy)

            model.to(device)

            print('Training {}'.format(model_name))

            optimizer    = optim.Adam(model.parameters(),lr = learningRate)
            running_loss = 0.0

            for epoch in range(iterations):
                optimizer.zero_grad()

                # second output of the model is the prediction compared against the ACS target
                _,loss_out = model(kspsplit)
                loss = criterion(loss_out,acsDifference[e,c,:,:,:,:])
                loss.backward()
                optimizer.step()

                running_loss = loss.item()
                if(epoch == 0):
                    print('Initial Loss: %.10f' % (running_loss))

            trainedNames.append(model_name)
            trained.update({model_name : model})
            print('Final Loss:   %.10f' % (running_loss))

    return trained,trainedNames

def trainingSparkNetwork(kspaceGrappaSplit,acsDifferenceReal,acsDifferenceImag,acsx,acsy,learningRate,iterations):
    '''
    Train SPARK networks (one per contrast and channel, separately for the real and imaginary error).

    Inputs:
        kspaceGrappaSplit: E x 2C x X x Y tensor, reconstructed k-space (real channels then imaginary)
                           used as the network input
        acsDifferenceReal: E x C x 1 x 1 x acsX x acsY tensor, real part of the (normalized) difference
                           between measured and reconstructed ACS
        acsDifferenceImag: same as above, for the imaginary part
        acsx, acsy:        forwarded unchanged to models.SPARK_Netv2 (the callers in this notebook pass
                           the ACS index arrays)
        learningRate:      Adam learning rate
        iterations:        number of optimization steps per network
    Outputs:
        real_models, real_model_names, imag_models, imag_model_names: dicts mapping names of the form
        'modelE<e>C<c>r' / 'modelE<e>C<c>i' to trained models, plus the names in training order.
    Uses the module-level `device` (a torch.device).
    '''
    # All real models are trained before the imaginary ones, as in the original two-loop version,
    # so model initialisation consumes the random generator in the same order.
    real_models,real_model_names = _trainSparkModelSet(kspaceGrappaSplit,acsDifferenceReal,'r',
                                                       acsx,acsy,learningRate,iterations)
    imag_models,imag_model_names = _trainSparkModelSet(kspaceGrappaSplit,acsDifferenceImag,'i',
                                                       acsx,acsy,learningRate,iterations)

    return real_models,real_model_names,imag_models,imag_model_names

In [None]:
def applySparkCorrection(kspaceToCorrect,kspaceGrappaSplit,real_model,imag_model,chanScaleFactorReal,chanScaleFactorImag):
    '''
    Correct one channel of k-space with its trained real and imaginary SPARK models.

    Inputs:
        kspaceToCorrect     - X x Y complex array, k-space of the channel to correct
        kspaceGrappaSplit   - 2C x X x Y tensor, network input (real channels then imaginary)
        real_model          - model predicting the (normalized) real part of the k-space error
        imag_model          - model predicting the (normalized) imaginary part of the k-space error
        chanScaleFactorReal - scale applied to this channel's real training target
        chanScaleFactorImag - scale applied to this channel's imaginary training target
    Outputs:
        corrected - X x Y complex array: kspaceToCorrect plus the de-normalized predicted error
    '''
    networkInput = torch.unsqueeze(kspaceGrappaSplit,dim=0)

    # Inference only: disabling autograd avoids building (and holding on to) a graph per forward pass
    with torch.no_grad():
        correctionr = real_model(networkInput)[0].detach().cpu().numpy()
        correctioni = imag_model(networkInput)[0].detach().cpu().numpy()

    # Undo the per-channel target normalization before adding the predicted error
    corrected = correctionr[0,0,:,:]/chanScaleFactorReal + 1j * correctioni[0,0,:,:] / chanScaleFactorImag + kspaceToCorrect

    return corrected

# Perform SPARK training 

In [None]:
#-Prefer the second GPU when it exists, then the first GPU, then the CPU.
#-Requesting "cuda:1" on a single-GPU machine would fail as soon as tensors are moved to it.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

#-Add the leading contrast axis (E = 1) expected by the SPARK helpers
kspaceSense = np.expand_dims(np.squeeze(kspaceSense),axis = 0)
kspaceAcs   = np.expand_dims(np.squeeze(kspaceAcs),axis = 0)

#-Reformat the data and train one real and one imaginary model per coil
[kspace_grappa_split, acs_difference_real, acs_difference_imag,chan_scale_factors_real,chan_scale_factors_imag] = \
    reformattingKspaceForSpark(kspaceSense,kspaceAcs,acsregionX,acsregionY,acsx,acsy,normalizationflag)

realSparkGrappaModels,realSparkGrappaNames,imagSparkGrappaModels,imagSparkGrappaNames = \
    trainingSparkNetwork(kspace_grappa_split,acs_difference_real,acs_difference_imag,acsregionX,acsregionY,learningRate,sparkIterations)

# Perform correction and ACS replacement 

In [None]:
#-Apply the trained SPARK models coil by coil (only contrast 0 exists here)
kspaceCorrected    = np.zeros((C,M*numslices,N),dtype = complex)
kspaceGrappaSplit  = kspace_grappa_split[0,:,:,:]   #network input, shared by every coil's models

for c in range(0,C):
    model_namer = 'model' + 'E' + str(0) + 'C' + str(c) + 'r'
    model_namei = 'model' + 'E' + str(0) + 'C' + str(c) + 'i'

    kspaceCorrected[c,:,:] = applySparkCorrection(kspaceSense[0,c,:,:],kspaceGrappaSplit,
                                                  realSparkGrappaModels[model_namer],
                                                  imagSparkGrappaModels[model_namei],
                                                  chan_scale_factors_real[0,c],
                                                  chan_scale_factors_imag[0,c])

#-ACS replacement: overwrite the ACS region with the measured k-space.  The slice stops are the last
#-ACS index + 1 so the final ACS row/column is replaced too (the training ACS range is inclusive).
acsSliceX = slice(acsregionX[0],acsregionX[acsx-1]+1)
acsSliceY = slice(acsregionY[0],acsregionY[acsy-1]+1)

kspaceCorrectedReplaced    = np.copy(kspaceCorrected)
kspaceCorrectedReplaced[:,acsSliceX,acsSliceY] = kspaceAcs[0,:,acsSliceX,acsSliceY]
    

# Computing sense image from corrected kspace 

In [None]:
#Coil-combine the ACS-replaced corrected k-space (IFFT, then a conj(coils)-weighted sum over coils),
#then split it into slices and undo the FOV shifts
spark = np.sum(np.conj(coils) * sig.ifft2c(kspaceCorrectedReplaced), axis=-3)
spark = np.squeeze(performshift(spark.reshape(numslices, M, N)[np.newaxis], shifts, -1))

# Compare truth, sense, and spark 

In [None]:
#Rows: ground truth, SENSE, SENSE + SPARK (one column per slice)
stackedRows = (truth, sense, spark)
display = sig.nor(np.concatenate(stackedRows, axis=0))
sig.mosaic(display, len(stackedRows), numslices)

# Defining wave encoding parameters

In [None]:
#-Geometry and wave-encoding parameters
slicethickness = 1e-3              #[m]

os       = 3                       #readout oversampling factor
cycles   = 2                       #gradients follow sin/cos(cycles * pi * t / Tadc)
Gymax    = 16 * 1e-3               #[T/m]
Gzmax    = 16 * 1e-3               #[T/m]
Tadc     = 1432.7*1e-6             #[s]
gamma    = 42.577*1e6              #[Hz/T]
FOVy     = N * 1e-3                #[m]
FOVz     = (numslices+1) * slicedistance * slicethickness   #[m]

Nro      = M * os                  #oversampled readout length
yind     = np.linspace(-FOVy/2,FOVy/2,N)           #[m]
zind     = np.linspace(-FOVz/2,FOVz/2,numslices)   #[m]
adc      = np.linspace(0,Tadc,Nro)                 #[s]

#-Normalized wave gradient waveforms
gradienty = np.sin(cycles * np.pi * adc / Tadc)
gradientz = np.cos(cycles * np.pi * adc / Tadc)

#-Point spread functions, shape (1, numslices, Nro, N):
#-psf[0, ss, k, yy] = exp(1j*gamma*Tadc*(Gymax*gradienty[k]*yind[yy] + Gzmax*gradientz[k]*zind[ss]))
yTerm = (Gymax * gradienty)[np.newaxis, :, np.newaxis] * yind[np.newaxis, np.newaxis, :]
zTerm = (Gzmax * gradientz)[np.newaxis, :, np.newaxis] * zind[:, np.newaxis, np.newaxis]
psf   = np.exp(1j * gamma * Tadc * (yTerm + zTerm))[np.newaxis]

#-Apply the same FOV shifts as the images, then stack the slices along the oversampled readout
psfShifted = performshift(psf,shifts)
psfRo      = np.reshape(psfShifted, (1, numslices * Nro, N))
    

# Defining wave encoding operators 

In [None]:
padreadout = (Nro - M) // 2   #zero samples added on each side of a slice's readout to reach Nro

def sforwave(x,coils):
    '''Coil expansion for the wave model: broadcast-multiply the image x by every coil map.'''
    weighted = coils * x
    return weighted

def sadjwave(x,coils):
    '''Adjoint of the coil expansion: sum of conj(coils) * x over axis -3 (the coil axis).'''
    xp = cp.get_array_module(x)
    return (xp.conj(coils) * x).sum(axis=-3)

def Fx(x):
    '''FFT along the readout direction (axis -2).'''
    return sig.fft(x, ax=-2)

def Fy(x):
    '''FFT along the phase-encode direction (axis -1).'''
    return sig.fft(x, ax=-1)

def Fxadj(x):
    '''Inverse FFT along the readout direction (axis -2).'''
    return sig.ifft(x, ax=-2)

def Fyadj(x):
    '''Inverse FFT along the phase-encode direction (axis -1).'''
    return sig.ifft(x, ax=-1)

def waveForward(x,psf):
    '''Wave forward operation: pointwise multiplication by the wave PSF
    (applied after the readout FFT and before the phase-encode FFT in senseWaveForward).'''
    encoded = psf * x
    return encoded

def waveAdjoint(x,psf):
    '''Adjoint wave operation: pointwise multiplication by the conjugate PSF.'''
    xp = cp.get_array_module(x)
    conjugatePsf = xp.conj(psf)
    return conjugatePsf * x

def resize(x,padreadout,inpdims=None):
    '''
    Zero-pad the readout axis (axis -2) by `padreadout` samples on each side.

    Inputs:
        x:          numpy or cupy array of any rank >= 2
        padreadout: number of zeros added before and after the readout samples
        inpdims:    expected rank of x (kept for backward compatibility); None -> x.ndim
    Raises:
        ValueError if inpdims does not match x.ndim (previously such calls silently returned None)
    '''
    xp = cp.get_array_module(x)

    if inpdims is None:
        inpdims = x.ndim
    if inpdims != x.ndim:
        raise ValueError('resize: inpdims=%d does not match input with %d dimensions' % (inpdims, x.ndim))

    padwidth = [(0,0)] * x.ndim
    padwidth[-2] = (padreadout,padreadout)
    return xp.pad(x,padwidth,mode = 'constant',constant_values = 0)
    
def crop(x):
    '''Undo the readout zero-padding of `resize`: keep rows padreadout ... padreadout + M - 1 of axis 2.
    (Cropping around Nro//2 +/- M//2 keeps only M-1 rows when M is odd.)'''
    return x[:,:,padreadout:padreadout + M,:]

def senseWaveForward(x,maps,psf,padreadout,mask):
    '''Wave-encoded SENSE forward model: coil expansion, per-slice zero-padding of the readout to Nro,
    readout FFT, wave PSF, phase-encode FFT, and finally the sampling mask.'''
    xp = cp.get_array_module(x)
    coilImages = xp.reshape(sforwave(x,maps),(C,numslices,M,N))     # split the stacked slices
    padded     = resize(coilImages,padreadout,inpdims = 4)           # zero-pad each slice's readout
    stacked    = xp.reshape(padded,(C,numslices*Nro,N))              # re-stack on the oversampled readout
    return mask*Fy(waveForward(Fx(stacked),psf))

def senseWaveAdjoint(x,maps,psf,mask):
    '''Adjoint of senseWaveForward: conj(mask), phase-encode IFFT, conjugate PSF, readout IFFT,
    per-slice crop back to M readout samples, and conj(maps)-weighted coil combination.
    Works on numpy or cupy arrays; every reshape goes through the array module of x.'''
    xp = cp.get_array_module(x)
    hybrid   = Fxadj(waveAdjoint(Fyadj(xp.conj(mask) * x),psf))
    perSlice = xp.reshape(hybrid,(C,numslices,Nro,N))
    cropped  = xp.reshape(crop(perSlice),(C,numslices*M,N))
    return sadjwave(cropped,maps)

def analyzePsf(x,psf,padreadout):
    '''Apply the wave PSF to a slice-stacked image: split into slices, zero-pad each readout to Nro,
    re-stack, then readout FFT -> multiply by psf -> readout IFFT.'''
    padded = resize(np.reshape(x,(numslices,M,N)),padreadout,inpdims = 3)
    stackedPadded = np.reshape(padded,(1,numslices*Nro,N))
    return Fxadj(psf*Fx(stackedPadded))

# Generating wave encoded kspace 

In [None]:
#-Sampling mask on the oversampled, slice-stacked readout grid:
#-every numslices-th readout line and every Ry-th phase encode
maskWave = np.zeros((C,M*os*numslices,N))
maskWave[:,::numslices,::Ry] = 1

#-Complex Gaussian noise (std 1/snr per real/imaginary part) added to the wave-encoded k-space, then masked
noiseStd  = 1/snr
noiseWave = np.random.normal(0,noiseStd,maskWave.shape) + 1j*np.random.normal(0,noiseStd,maskWave.shape)
kspaceWave = maskWave*(
    Fy(waveForward(Fx(np.reshape(resize(np.reshape(slicesCoils,(C,numslices,M,N)),padreadout,inpdims = 4),
                                 (C,numslices*Nro,N))),psfRo))
    + noiseWave)


# Performing wave-encoded reconstruction 

In [None]:
#-Compute the adjoint of the wave-encoded kspace data (right-hand side of the normal equations)
kadjWave = senseWaveAdjoint(kspaceWave,coils,psfRo,maskWave)

#-Optionally move the operator data to the GPU; the operators select numpy/cupy via cp.get_array_module
if(cudaflag):
    coils       = cp.asarray(coils)
    maskWave    = cp.asarray(maskWave)
    kadjWave    = cp.asarray(kadjWave)
    psfRo       = cp.asarray(psfRo)

#-Normal operator for wave-encoded SENSE acting on the flattened (M*numslices) x N image
normalWave = lambda x: senseWaveAdjoint(senseWaveForward(x.reshape(M*numslices,N),coils,psfRo,padreadout,maskWave),\
                                          coils,psfRo,maskWave).ravel()

print('WAVE SENSE reconstruction ...',end='')
#-Conjugate gradient with the senseIterations parameter as the iteration count
smsWave = cp.asnumpy(iterative.conjgrad(normalWave,kadjWave.ravel(),kadjWave.ravel(),\
                                         ite = senseIterations)).reshape(M*numslices,N)
print(' Done.')

#-Back to host memory for the numpy code that follows (cp.asnumpy also accepts numpy arrays)
coils    = cp.asnumpy(coils)
maskWave = cp.asnumpy(maskWave)
kadjWave = cp.asnumpy(kadjWave)
psfRo    = cp.asnumpy(psfRo)

# Comparing ground truth, sense, wave 

In [None]:
#-Split the wave-SENSE image into slices and undo the FOV shifts
wave = np.squeeze(performshift(smsWave.reshape(numslices, M, N)[np.newaxis], shifts, -1))

#-Rows: ground truth, SENSE, wave-SENSE (one column per slice)
display = sig.nor(np.concatenate((truth, sense, wave), axis=0))
sig.mosaic(display, 3, numslices)

In [None]:
#-sig.rmse against the ground truth (x100): first over all slices, then slice by slice
print('Sense Total rmse:   %.2f' % (sig.rmse(truth,sense)*100) )
print('Wave  Total rmse:   %.2f' % (sig.rmse(truth,wave)*100) )

for ss, (truthSlice, senseSlice, waveSlice) in enumerate(zip(truth, sense, wave)):
    print('Slice %d:' %(ss+1))
    print('  sense rmse: %.2f' % (sig.rmse(truthSlice,senseSlice)*100))
    print('  wave  rmse: %.2f' % (sig.rmse(truthSlice,waveSlice)*100))


# Setting up kspaces for wave-encoded SPARK

In [None]:
#-Fully sampled (mask = 1) wave k-space of the wave-SENSE image, and the noise-free wave k-space of the
#-ground truth that supplies the ACS; both get a leading contrast axis (E = 1)
kspaceWaveSpark = senseWaveForward(smsWave,coils,psfRo,padreadout,mask = 1)[np.newaxis]
kspaceWaveAcs   = Fy(waveForward(Fx(np.reshape(resize(np.reshape(slicesCoils,(C,numslices,M,N)),
                                                      padreadout,inpdims = 4),
                                               (C,numslices*Nro,N))),psfRo))[np.newaxis]

#-In wave, the ACS spans the entire oversampled, slice-stacked readout
acsregionX = np.arange(M * numslices * os)
acsx = len(acsregionX)

# Performing wave-encoded SPARK reconstruction

In [None]:
#-Build the SPARK inputs/targets from the wave k-space and train one real and one imaginary model per coil
(kspace_grappa_split, acs_difference_real, acs_difference_imag,
 chan_scale_factors_real, chan_scale_factors_imag) = reformattingKspaceForSpark(
    kspaceWaveSpark, kspaceWaveAcs, acsregionX, acsregionY, acsx, acsy, normalizationflag)

(realSparkGrappaModels, realSparkGrappaNames,
 imagSparkGrappaModels, imagSparkGrappaNames) = trainingSparkNetwork(
    kspace_grappa_split, acs_difference_real, acs_difference_imag,
    acsregionX, acsregionY, learningRate, sparkIterations)

# Performing the correction 

In [None]:
#-Apply the trained wave SPARK models coil by coil (only contrast 0 exists here)
kspaceCorrectedWave    = np.zeros((C,M*numslices*os,N),dtype = complex)
kspaceGrappaSplit      = kspace_grappa_split[0,:,:,:]   #network input, shared by every coil's models

for c in range(0,C):
    model_namer = 'model' + 'E' + str(0) + 'C' + str(c) + 'r'
    model_namei = 'model' + 'E' + str(0) + 'C' + str(c) + 'i'

    kspaceCorrectedWave[c,:,:] = applySparkCorrection(kspaceWaveSpark[0,c,:,:],kspaceGrappaSplit,
                                                      realSparkGrappaModels[model_namer],
                                                      imagSparkGrappaModels[model_namei],
                                                      chan_scale_factors_real[0,c],
                                                      chan_scale_factors_imag[0,c])

#-ACS replacement: overwrite the ACS region with the reference k-space.  The slice stops are the last
#-ACS index + 1 so the final ACS row/column is replaced too (the training ACS range is inclusive).
acsSliceX = slice(acsregionX[0],acsregionX[acsx-1]+1)
acsSliceY = slice(acsregionY[0],acsregionY[acsy-1]+1)

kspaceCorrectedReplacedWave   = np.copy(kspaceCorrectedWave)
kspaceCorrectedReplacedWave[:,acsSliceX,acsSliceY] = kspaceWaveAcs[0,:,acsSliceX,acsSliceY]

#-Wave adjoint (fully sampled, mask = 1) and coil combination, then split into slices and unshift
sparkWave = senseWaveAdjoint(kspaceCorrectedReplacedWave,coils,psfRo,mask = 1)
sparkWave = np.squeeze(performshift(np.expand_dims(np.reshape(sparkWave,(numslices,M,N)),axis=0),shifts,-1))

# Displaying results 

In [None]:
#Rows: ground truth, wave-SENSE, wave-SENSE + SPARK (one column per slice)
waveRows = (truth, wave, sparkWave)
display = sig.nor(np.concatenate(waveRows, axis=0))
sig.mosaic(display, len(waveRows), numslices)

# Saving results 

In [None]:
# Explicit submodule import: older SciPy releases do not make `scipy.io` reachable as `sp.io`
# after only `import scipy as sp`.
import scipy.io

#-Images and the parameters needed to reproduce them
results = {'truth': np.squeeze(truth),
           'sense': np.squeeze(sense),
           'spark': np.squeeze(spark),
           'wave':  np.squeeze(wave),
           'sparkwave': np.squeeze(sparkWave),
           'mbfactor': numslices,
           'Ry': Ry,
           'acsy': acsy,
           'Iterations': sparkIterations,
           'learningRate': learningRate,
           'snr' : snr,
           'slices':  slices,
           'slicedistance': slicedistance,
           'os' : os,
           'cycles' : cycles,
           'Gymax' :Gymax,
           'Gzmax' :Gzmax,
           'Tadc':Tadc,
           'FOVy':FOVy,
           'FOVz':FOVz,
           'fovshift':fovshift}

sp.io.savemat('figure4results.mat', results, oned_as='row')