In [None]:
import importlib 
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import warnings
import numpy as np
import numpy.linalg as la
import scipy as sp
import cupy as cp
import matplotlib.pyplot as plt
from bart import bart
from utils import cfl
from utils import signalprocessing as sig
from utils import models
from utils import iterative

### loading the dataset and the wave PSFs

In [None]:
#-Load the image Fx^H * Fy^H * Fz^H * wave_data and reorder its axes to (3, 2, 0, 1)
img_yz = cfl.readcfl('data/img_yz').transpose(3, 2, 0, 1)

#-Load the fitted y/z wave PSFs, each with two leading singleton axes prepended
PsfY_fit = cfl.readcfl('data/PsfY_fit')[np.newaxis, np.newaxis]
PsfZ_fit = cfl.readcfl('data/PsfZ_fit')[np.newaxis, np.newaxis].transpose(0, 3, 2, 1)

In [None]:
C, P, Nro, N = img_yz.shape

#-Simultaneous-multislice (SMS) acquisition parameters
beginningSliceIndex = 10
numslices_all       = 4
slicedistance       = P // numslices_all
fovshift            = 3   # FOV-shift factor ("blipped-CAIPI"-style)

#-Acquisition parameters
Ry = 5   # undersampling factor: every Ry-th phase-encode line is kept by the masks below
os = 3   # oversampling factor of the wave encoding

#-Iterative method parameters
senseIterations = 20
cudaflag        = 1

learningRate      = .0075
sparkIterations   = 200
normalizationflag = 1

lastSliceIndex = beginningSliceIndex + slicedistance * (numslices_all - 1)
slices_all = np.linspace(beginningSliceIndex, lastSliceIndex, numslices_all).astype(int)
slices     = slices_all[1:]   # drop the first (empty) slice
numslices  = numslices_all - 1

#-SPARK ACS-region sizes (readout x phase-encode)
acsx = Nro
acsy = 30

### selecting slices and associated psf 

In [None]:
#-Restrict the image and the z-PSF to the selected slices; the y-PSF (leading singleton axes) broadcasts over them
img_yz_slices = img_yz[:, slices]
psf_slices    = PsfY_fit * PsfZ_fit[:, slices]

### visualizing slices to be aliased

In [None]:
#-Remove the wave PSF along the readout axis (-2) to get per-slice coil images, then display their
# coil combination (sig.rsos over axis -4, presumably root-sum-of-squares)
imgReadoutFreq        = sig.fft(img_yz_slices, -2)
slices_to_alias_coils = sig.ifft(np.conj(psf_slices) * imgReadoutFreq, -2)
slices_to_alias       = sig.rsos(slices_to_alias_coils, -4)
sig.mosaic(sig.nor(slices_to_alias), 1, numslices)

### generating cartesian k-space and coil profiles 

In [None]:
kspace_slices_cartesian = sig.fft2c(slices_to_alias_coils)

#-Estimate one set of coil maps per slice with BART ecalib (one map, crop threshold 0.5)
coils_slices = np.zeros((C, numslices, Nro, N), dtype=complex)
for ss in range(numslices):
    print('Callibrating coils for slice %d/%d' % (ss + 1, numslices))
    # Move the coil axis last and insert a singleton third axis: shape (Nro, N, 1, C)
    curksp = kspace_slices_cartesian[:, ss].transpose(1, 2, 0)[:, :, np.newaxis, :]
    calib  = bart(1, 'ecalib -m 1 -c .5', curksp)
    coils_slices[:, ss] = np.squeeze(calib.transpose(3, 2, 0, 1))
print('done.')

### applying fov shift to slices and sensitivities 

In [None]:
#-Per-slice shift amounts (in pixels) handed to performshift; no shift when fovshift is not positive
if fovshift <= 0:
    shifts = np.zeros(numslices, dtype=int)
else:
    slicePositions = np.linspace(-(numslices / 2 - 1), numslices / 2, numslices)
    shifts = np.round(slicePositions * N / fovshift).astype(int)

#-Define the function which performs the shifting
def performshift(x,shift,direction = 1):
    '''
    Circularly shift each slice of x along its last (phase-encode) axis.
    Inputs:
        x:         4D array; axis 1 (== axis -3) indexes slices
        shift:     sequence with one integer shift per slice
        direction: +1 to apply the shifts, -1 to undo them
    Outputs:
        complex array of x's shape holding the shifted slices
    '''
    out = np.zeros(x.shape,dtype=complex)

    for ss in range(out.shape[-3]):
        # Roll only along the last axis. np.roll without `axis` would roll the flattened
        # (coil, readout, phase-encode) block and wrap columns into the neighbouring row/coil.
        out[:,ss,:,:] = np.roll(x[:,ss,:,:],direction*shift[ss],axis = -1)
    return out
      
#-Apply the same per-slice shifts to the coil images, the coil maps and the wave-encoded images
slicesShiftedCoils, coils, img_yz_slices_shift = (
    performshift(slices_to_alias_coils, shifts),
    performshift(coils_slices, shifts),
    performshift(img_yz_slices, shifts),
)


### visualizing shifted slices and wave aliasing

In [None]:
#-Normalized coil combinations of the shifted slices followed by the shifted wave-encoded slices
shiftedCombined = sig.nor(sig.rsos(slicesShiftedCoils, -4))
waveCombined    = sig.nor(sig.rsos(img_yz_slices_shift, -4))
display = np.squeeze(np.concatenate((shiftedCombined[np.newaxis], waveCombined[np.newaxis]), axis=1))
sig.mosaic(display, 1, 2*numslices)

### generating the undersampled cartesian slice-group k-space

In [None]:
#-Slice-group helpers: `col` sums (collapses) the slice axis, `exp` replicates an array over it
col = lambda x: np.sum(x,axis = - 3, keepdims = True)
exp = lambda x: np.repeat(x,repeats = numslices,axis = -3)
# NOTE(review): this acsregionX is redefined (per-slice readout) in the SPARK section before its first use
acsregionX = np.arange((Nro*numslices)//2 - acsx // 2,(Nro*numslices)//2 + acsx//2)
acsregionY = np.arange(N//2 - acsy // 2,N//2 + acsy//2)

#-Generate the undersampling mask: every Ry-th phase-encode line
mask = np.zeros((C,1,Nro,N),dtype = complex)
mask[:,:,:,::Ry] = 1

#-Same pattern plus the fully sampled ACS block; the slice end is inclusive of the last ACS line
# (matching the ACS crop used for SPARK training)
maskAcs = np.copy(mask)
maskAcs[:,:,:,acsregionY[0]:acsregionY[-1]+1] = 1

#-Cartesian k-space of the shifted slices, computed once and masked two ways, then collapsed over slices
kspaceShiftedCartesian = sig.fft2c(slicesShiftedCoils)
kspace    = col(mask * kspaceShiftedCartesian)
kspaceAcs = col(maskAcs * kspaceShiftedCartesian)

### generating slice group sense operators

In [None]:
def senseForward(x,maps,mask):
    '''
    Cartesian slice-group SENSE forward model: weight the image(s) x by the coil maps, take the
    centered 2D FFT (sig.fft2c), sum over the slice axis (col) and apply the sampling mask.
    '''
    return mask * col(sig.fft2c(maps*x))

def senseAdjoint(x,maps,mask):
    '''
    Adjoint of senseForward (assuming sig.ifft2c is the adjoint of sig.fft2c): apply the conjugate
    mask, replicate over the slice axis (exp), inverse 2D FFT, and combine with the conjugate maps
    by summing over axis -4 (kept as a singleton). Works on numpy or cupy arrays.
    '''
    xp = cp.get_array_module(x)
    # conj(mask) makes this the true adjoint (as in senseWaveAdjoint); for the binary masks and the
    # scalar mask 1 used in this notebook it leaves results unchanged
    return xp.sum(xp.conj(maps)*(sig.ifft2c(exp(xp.conj(mask) * x))),-4,keepdims = True)

### performing slice-group sense reconstruction

In [None]:
#-Compute the adjoint of the undersampled k-space (without and with the fully sampled ACS lines)
kadj    = senseAdjoint(kspace,coils,mask)
kadjAcs = senseAdjoint(kspaceAcs,coils,maskAcs)

#-Optionally move the operator inputs to the GPU (senseAdjoint dispatches on its input's array module)
if(cudaflag):
    coils   = cp.asarray(coils)
    mask    = cp.asarray(mask)
    maskAcs = cp.asarray(maskAcs)
    kadj    = cp.asarray(kadj)
    kadjAcs = cp.asarray(kadjAcs)

#-Normal operators A^H A for both sampling patterns; images are flattened for the CG solver
normal = lambda x: senseAdjoint(senseForward(x.reshape(1,numslices,Nro,N),coils,mask),\
                                          coils,mask).ravel()

normalAcs = lambda x: senseAdjoint(senseForward(x.reshape(1,numslices,Nro,N),coils,maskAcs),\
                                          coils,maskAcs).ravel()
print('SENSE reconstruction ...',end='')
# The iteration count comes from the configuration cell (senseIterations) instead of a literal
smsSense = cp.asnumpy(iterative.conjgrad(normal,kadj.ravel(),kadj.ravel(),\
                                         ite = senseIterations)).reshape(1,numslices,Nro,N)
smsSenseAcs = cp.asnumpy(iterative.conjgrad(normalAcs,kadjAcs.ravel(),kadjAcs.ravel(),\
                                         ite = senseIterations)).reshape(1,numslices,Nro,N)

print(' Done.')

#-Bring everything back to host memory for the following cells
coils = cp.asnumpy(coils)
mask  = cp.asnumpy(mask)
kadj  = cp.asnumpy(kadj)
kadjAcs = cp.asnumpy(kadjAcs)
maskAcs = cp.asnumpy(maskAcs)

### evaluating slice-group sense reconstruction

In [None]:
#-Crop the central readout region for display. The center is Nro//2 and the half-width is Nro//(2*os),
# assuming the readout axis is oversampled by `os` (768 and 128 for Nro=768, os=3)
viscrop = lambda x: x[:, Nro//2 - Nro//(2*os) : Nro//2 + Nro//(2*os), :]

#-Reference: coil-combined shifted slices with the FOV shift undone
truth = viscrop(np.squeeze(performshift(np.expand_dims(np.sum(np.conj(coils) * slicesShiftedCoils,-4),axis = 0),shifts,-1),axis = 0))
#-SENSE reconstruction (with ACS lines), un-shifted the same way
sense = viscrop(np.squeeze(performshift(np.expand_dims(np.reshape(smsSenseAcs,(numslices,Nro,N)),axis=0),shifts,-1),axis=0))

display = sig.nor(np.concatenate((truth,sense),axis = 0))
sig.mosaic(display,2,numslices)

In [None]:
#-RMSE (in percent) of the SENSE reconstruction, over all slices and per slice
print('Total rmse:   %.2f' % (100 * sig.rmse(truth, sense)))
for ss in range(numslices):
    sliceRmse = 100 * sig.rmse(truth[ss], sense[ss])
    print('Slice %d rmse: %.2f' % (ss + 1, sliceRmse))


### defining SPARK helper functions

In [None]:
def _inverseMaxOrOne(x):
    '''Return 1/max|x|, or 1 when x is all zeros (avoids inf/nan normalization factors).'''
    peak = np.amax(np.abs(x))
    return 1/peak if peak > 0 else 1

def reformattingKspaceForSpark(inputKspace,kspaceOriginal,acsregionX,acsregionY,acsx,acsy,normalizationflag,torchDevice=None):
    '''
    Formats reconstructed k-space and its ACS-region error as torch tensors for SPARK.
    Inputs:
        inputKspace:       allContrasts x allChannels x M x N complex, reconstructed k-space
        kspaceOriginal:    allContrasts x allChannels x M x N complex, k-space whose ACS region is the reference
        acsregionX/Y:      index arrays of the ACS region along the last two axes
        acsx/acsy:         number of entries of acsregionX/acsregionY spanned by the ACS crop
        normalizationflag: if truthy, scale the network input per contrast and the ACS error per
                           contrast/channel by 1/max magnitude (1 when that maximum is 0)
        torchDevice:       torch device for the returned tensors; defaults to the notebook-level `device`
    Outputs:
        kspace_grappa_split:     allContrasts x 2*allChannels x M x N float tensor (real parts, then imaginary parts)
        acs_difference_real:     allContrasts x allChannels x 1 x 1 x (ACS rows) x (ACS cols) float tensor
        acs_difference_imag:     same, imaginary part
        chan_scale_factors_real: allContrasts x allChannels array of factors applied to the real error
        chan_scale_factors_imag: same, for the imaginary error
    '''
    [E,C,_,_] = inputKspace.shape
    targetDevice = device if torchDevice is None else torchDevice

    acsX = slice(acsregionX[0], acsregionX[acsx-1]+1)
    acsY = slice(acsregionY[0], acsregionY[acsy-1]+1)

    #Reference ACS data, used as the ground truth for the k-space error we want to learn
    kspaceAcsCrop   = kspaceOriginal[:,:,acsX,acsY]
    #Reconstructed ACS region. kspaceAcsCrop - kspaceAcsGrappa is the supervised error we try to learn
    kspaceAcsGrappa = inputKspace[:,:,acsX,acsY]
    kspaceAcsDifference = kspaceAcsCrop - kspaceAcsGrappa

    #Split the error into real and imaginary parts (separate networks learn each)
    acs_difference_real = np.real(kspaceAcsDifference)
    acs_difference_imag = np.imag(kspaceAcsDifference)

    #Network input: real parts of all channels followed by imaginary parts along axis 1 (a new array)
    kspace_grappa_split = np.concatenate((np.real(inputKspace), np.imag(inputKspace)), axis=1)

    #Normalization (factors stay 1 when disabled)
    chan_scale_factors_real = np.ones((E,C),dtype = 'float')
    chan_scale_factors_imag = np.ones((E,C),dtype = 'float')

    if(normalizationflag):
        for e in range(E):
            kspace_grappa_split[e,:,:,:] *= _inverseMaxOrOne(kspace_grappa_split[e,:,:,:])

            for c in range(C):
                chan_scale_factors_real[e,c] = _inverseMaxOrOne(acs_difference_real[e,c,:,:])
                chan_scale_factors_imag[e,c] = _inverseMaxOrOne(acs_difference_imag[e,c,:,:])

                acs_difference_real[e,c,:,:] *= chan_scale_factors_real[e,c]
                acs_difference_imag[e,c,:,:] *= chan_scale_factors_imag[e,c]

    #Insert two singleton axes after the channel axis: E x C x 1 x 1 x (ACS rows) x (ACS cols)
    acs_difference_real = acs_difference_real[:,:,np.newaxis,np.newaxis]
    acs_difference_imag = acs_difference_imag[:,:,np.newaxis,np.newaxis]

    kspace_grappa_split = torch.from_numpy(kspace_grappa_split).to(targetDevice, dtype=torch.float)
    print('kspace_grappa_split shape: ' + str(kspace_grappa_split.shape))

    acs_difference_real = torch.from_numpy(acs_difference_real).to(targetDevice, dtype=torch.float)
    print('acs_difference_real shape: ' + str(acs_difference_real.shape))

    acs_difference_imag = torch.from_numpy(acs_difference_imag).to(targetDevice, dtype=torch.float)
    print('acs_difference_imag shape: ' + str(acs_difference_imag.shape))

    return kspace_grappa_split, acs_difference_real, acs_difference_imag, chan_scale_factors_real, chan_scale_factors_imag

In [None]:
def _trainSparkModelsForTarget(kspaceGrappaSplit,acsDifference,acsx,acsy,learningRate,iterations,suffix):
    '''
    Trains one models.SPARK_Netv2 per (contrast, channel) against acsDifference[e,c] with Adam + MSE.
    Returns (dict name -> model, list of names in training order); names are 'modelE<e>C<c><suffix>'.
    '''
    [E,C,_,_,_,_] = acsDifference.shape
    trainedModels = {}
    trainedNames  = []

    criterion = nn.MSELoss()

    for e in range(E):
        for c in range(C):
            model_name = 'model' + 'E' + str(e) + 'C' + str(c) + suffix
            model = models.SPARK_Netv2(coils = C,kernelsize = 3,acsx = acsx, acsy = acsy)

            model.to(device)

            # Network input: contrast e (all channels, real then imaginary) with a leading batch axis
            kspsplit = torch.unsqueeze(kspaceGrappaSplit[e,:,:,:],0)

            print('Training {}'.format(model_name))

            optimizer    = optim.Adam(model.parameters(),lr = learningRate)
            running_loss = 0.0

            for epoch in range(iterations):
                optimizer.zero_grad()

                # The model's second output is what is compared with the ACS error target
                _,loss_out = model(kspsplit)
                loss = criterion(loss_out,acsDifference[e,c,:,:,:,:])
                loss.backward()
                optimizer.step()

                running_loss = loss.item()
                if(epoch == 0):
                    print('Initial Loss: %.10f' % (running_loss))

            trainedNames.append(model_name)
            trainedModels[model_name] = model
            print('Final Loss:   %.10f' % (running_loss))

    return trainedModels, trainedNames

def trainingSparkNetwork(kspaceGrappaSplit,acsDifferenceReal,acsDifferenceImag,acsx,acsy,learningRate,iterations):
    '''
    Trains SPARK networks (one per contrast and channel, separately for the real and imaginary error).
    Inputs:
        kspaceGrappaSplit: allContrasts x 2*allChannels x M x N tensor, reconstructed k-space (real parts,
                           then imaginary parts) used as the network input
        acsDifferenceReal: allContrasts x allChannels x 1 x 1 x (ACS rows) x (ACS cols) tensor, real part of
                           the (reference - reconstructed) ACS error
        acsDifferenceImag: same shape, imaginary part
        acsx, acsy:        forwarded unchanged to models.SPARK_Netv2 (this notebook passes the ACS index arrays)
        learningRate:      scalar, Adam learning rate
        iterations:        int, number of optimisation steps per network
    Outputs:
        real_models, real_model_names, imag_models, imag_model_names: dicts of trained networks keyed by
        'modelE<e>C<c>r' / 'modelE<e>C<c>i', and the matching name lists in training order
    '''
    #Real models first, then imaginary models (same order as before)
    real_models,real_model_names = _trainSparkModelsForTarget(kspaceGrappaSplit,acsDifferenceReal,acsx,acsy,learningRate,iterations,'r')
    imag_models,imag_model_names = _trainSparkModelsForTarget(kspaceGrappaSplit,acsDifferenceImag,acsx,acsy,learningRate,iterations,'i')

    return real_models,real_model_names,imag_models,imag_model_names

In [None]:
def applySparkCorrection(kspaceToCorrect,kspaceGrappaSplit,real_model,imag_model,chanScaleFactorReal,chanScaleFactorImag):
    '''
    Adds the SPARK-predicted error to one channel of k-space.
    Inputs:
        kspaceToCorrect     - M x N complex array, k-space that we want to correct
        kspaceGrappaSplit   - 2*allcoils x M x N tensor, network input (real parts, then imaginary parts)
        real_model          - callable whose first output's [0,0] slice is the (scaled) real correction
        imag_model          - same, for the imaginary correction
        chanScaleFactorReal - scalar the real training target was multiplied by (undone here)
        chanScaleFactorImag - scalar the imaginary training target was multiplied by (undone here)
    Outputs:
        M x N corrected k-space
    '''
    # Add the batch axis once and run inference without building an autograd graph
    networkInput = torch.unsqueeze(kspaceGrappaSplit,0)
    with torch.no_grad():
        correctionr = real_model(networkInput)[0].cpu().detach().numpy()
        correctioni = imag_model(networkInput)[0].cpu().detach().numpy()

    corrected = correctionr[0,0,:,:]/chanScaleFactorReal + 1j * correctioni[0,0,:,:] / chanScaleFactorImag + kspaceToCorrect

    return corrected

### computing k-space for training spark network 

In [None]:
#-SPARK training pair: the input is the k-space of the SENSE reconstruction without ACS lines, and the
# reference is the fully sampled k-space. Both are collapsed over slices and reordered to (1, C, Nro, N).
kspaceSense    = col(sig.fft2c(coils*smsSense)).transpose(1,0,2,3)
kspaceAcsSpark = col(sig.fft2c(slicesShiftedCoils)).transpose(1,0,2,3)

#-ACS region for a single (non-concatenated) readout
halfAcsX, halfAcsY = acsx // 2, acsy // 2
acsregionX = np.arange(Nro//2 - halfAcsX, Nro//2 + halfAcsX)
acsregionY = np.arange(N//2 - halfAcsY, N//2 + halfAcsY)

### training spark network 

In [None]:
#-Torch device for SPARK: keep the original choice of the second GPU ("cuda:1"), but fall back to the
# first GPU when only one is visible (cuda:1 would not exist) and to the CPU without CUDA
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")

#Reformatting the data
[kspace_grappa_split, acs_difference_real, acs_difference_imag,chan_scale_factors_real,chan_scale_factors_imag] = \
    reformattingKspaceForSpark(kspaceSense,kspaceAcsSpark,acsregionX,acsregionY,acsx,acsy,normalizationflag)

realSparkGrappaModels,realSparkGrappaNames,imagSparkGrappaModels,imagSparkGrappaNames = \
    trainingSparkNetwork(kspace_grappa_split,acs_difference_real,acs_difference_imag,acsregionX,acsregionY,learningRate,sparkIterations)

### generating k-space to which we apply correction

In [None]:
#-k-space of the SENSE reconstruction that used the ACS lines: this is what SPARK corrects
kspaceSenseToCorrect      = np.transpose(col(sig.fft2c(coils*smsSenseAcs)),(1,0,2,3))

#-Reformat it. The network input must go into `kspace_grappa_split` (the name the correction cell reads),
# so the correction is predicted from this k-space, consistently with the recomputed scale factors
[kspace_grappa_split, acs_difference_real, acs_difference_imag,chan_scale_factors_real,chan_scale_factors_imag] = \
    reformattingKspaceForSpark(kspaceSenseToCorrect,kspaceAcsSpark,acsregionX,acsregionY,acsx,acsy,normalizationflag)


### performing correction and acs replacement

In [None]:
#-Apply the trained (contrast 0) models coil by coil to the k-space being corrected
kspaceCorrected    = np.zeros((1,C,Nro,N),dtype = complex)

for e in range(0,1):
    for c in range(0,C):
        model_namer = 'model' + 'E' + str(e) + 'C' + str(c) + 'r'
        model_namei = 'model' + 'E' + str(e) + 'C' + str(c) + 'i'

        real_model = realSparkGrappaModels[model_namer]
        imag_model = imagSparkGrappaModels[model_namei]

        kspaceToCorrect   = kspaceSenseToCorrect[e,c,:,:]
        kspaceGrappaSplit = kspace_grappa_split[e,:,:,:]

        kspaceCorrected[e,c,:,:] = applySparkCorrection(kspaceToCorrect,kspaceGrappaSplit,real_model,imag_model,\
                                                        chan_scale_factors_real[e,c], chan_scale_factors_imag[e,c])

#-ACS replacement: overwrite the whole ACS block (both ends inclusive, matching the ACS crop used
# for training) with the reference k-space
acsRows = slice(acsregionX[0], acsregionX[-1] + 1)
acsCols = slice(acsregionY[0], acsregionY[-1] + 1)

kspaceCorrectedReplaced = np.copy(kspaceCorrected)
kspaceCorrectedReplaced[:,:,acsRows,acsCols] = kspaceAcsSpark[:,:,acsRows,acsCols]
    

### re-performing cartesian slice-group reconstruction with corrected k-space

In [None]:
#-Adjoint of the SPARK-corrected, ACS-replaced k-space; mask 1 because every line is present
kadj = senseAdjoint(np.transpose(kspaceCorrectedReplaced,(1,0,2,3)),coils,1)

if(cudaflag):
    coils   = cp.asarray(coils)
    kadj    = cp.asarray(kadj)

#-Defining the normal operator and performing the reconstruction
normal = lambda x: senseAdjoint(senseForward(x.reshape(1,numslices,Nro,N),coils,1),\
                                          coils,1).ravel()

print('SENSE reconstruction ...',end='')
# The iteration count comes from the configuration cell (senseIterations) instead of a literal
smsSenseSpark = cp.asnumpy(iterative.conjgrad(normal,kadj.ravel(),kadj.ravel(),\
                                         ite = senseIterations)).reshape(1,numslices,Nro,N)

print(' Done.')

coils = cp.asnumpy(coils)
kadj  = cp.asnumpy(kadj)

### visualizing cartesian sense with spark

In [None]:
#-Undo the FOV shift of the SPARK-corrected SENSE reconstruction and crop it with the `viscrop` from the
# SENSE evaluation cell (not redefined here, so every panel uses the same crop)
spark = viscrop(np.squeeze(performshift(np.expand_dims(np.reshape(smsSenseSpark,(numslices,Nro,N)),axis=0),shifts,-1),axis=0))

#-Panels in order: reference, SENSE, SENSE + SPARK
display = sig.nor(np.concatenate((truth,sense,spark),axis = 0))
sig.mosaic(display,3,numslices)

In [None]:
#-RMSE (in percent) of SENSE and SENSE + SPARK against the reference
print('Sense Total rmse:   %.2f' % (100 * sig.rmse(truth, sense)))
print('Spark Total rmse:   %.2f' % (100 * sig.rmse(truth, spark)))

#-A per-slice breakdown is only printed when there is more than one slice
if numslices > 1:
    for ss in range(numslices):
        print('Slice %d:' % (ss + 1))
        print('  sense rmse: %.2f' % (100 * sig.rmse(truth[ss], sense[ss])))
        print('  spark rmse: %.2f' % (100 * sig.rmse(truth[ss], spark[ss])))


### defining the psf used by the wave operators

In [None]:
#-The wave operators use the per-slice psf as-is (no FOV shift is applied to it); take an independent copy
psf = psf_slices.copy()

### defining wave-encoded operators 

In [None]:
def sforwave(x,coils):
    '''Coil expansion: weight the image(s) x by the coil maps.'''
    return coils * x

def sadjwave(x,coils):
    '''Adjoint of sforwave: combine over axis -4 with the conjugate maps, keeping that axis as a singleton.'''
    xp = cp.get_array_module(x)
    return xp.sum(xp.conj(coils)*x,-4,keepdims = True)

Fx    = lambda x: sig.fft(x,ax = -2)    #Perform fft in the readout direction
Fy    = lambda x: sig.fft(x,ax = -1)    #Perform fft in the phaseencode direction
Fxadj = lambda x: sig.ifft(x,ax = -2)   #Perform ifft in the readout direction
Fyadj = lambda x: sig.ifft(x,ax = -1)   #Perform ifft in the phaseencode direction

def waveForward(x,psf):
    '''Apply the wave PSF by pointwise multiplication (x is already Fourier transformed along the readout).'''
    return psf * x

def waveAdjoint(x,psf):
    '''Adjoint of waveForward: pointwise multiplication by the conjugate PSF.'''
    xp = cp.get_array_module(x)
    return xp.conj(psf) * x

def senseWaveForward(x,maps,psf,mask):
    '''Wave-encoded slice-group forward model: coils -> Fx -> PSF -> Fy -> collapse slices (col) -> mask.'''
    return mask*col(Fy(waveForward(Fx(sforwave(x,maps)),psf)))

def senseWaveAdjoint(x,maps,psf,mask):
    '''Adjoint of senseWaveForward: conj(mask) -> replicate slices (exp) -> Fy^H -> conj(PSF) -> Fx^H -> coil combine.'''
    xp = cp.get_array_module(x)
    return sadjwave(Fxadj(waveAdjoint(Fyadj(exp(xp.conj(mask) * x)),psf)),maps)

def analyzePsf(x,psf):
    '''Apply the PSF along the readout to image x: Fx^H(psf * Fx(x)).'''
    return Fxadj(psf*Fx(x))

### generating wave-encoded k-space

In [None]:
#-Generate the undersampling mask for the wave-encoded acquisition: every Ry-th phase-encode line
maskWave = np.zeros((C,1,Nro,N),dtype = complex)
maskWave[:,:,:,::Ry] = 1

#-Same pattern plus the fully sampled ACS block; the slice end is inclusive of the last ACS line
maskWaveAcs = np.copy(maskWave)
maskWaveAcs[:,:,:,acsregionY[0]:acsregionY[-1]+1] = 1

#-Wave-encoded k-space of the shifted slices (readout FFT, PSF, phase-encode FFT), computed once,
# masked two ways and collapsed over slices
kspaceWaveFull = sig.fft(psf*sig.fft(slicesShiftedCoils,-2),-1)
kspaceWave    = col(maskWave * kspaceWaveFull)
kspaceWaveAcs = col(maskWaveAcs * kspaceWaveFull)

### performing wave-encoded reconstruction

In [None]:
#-Compute the adjoint of the wave-encoded k-space (without and with the fully sampled ACS lines)
kadjWave = senseWaveAdjoint(kspaceWave,coils,psf,maskWave)
kadjWaveAcs = senseWaveAdjoint(kspaceWaveAcs,coils,psf,maskWaveAcs)

if(cudaflag):
    coils       = cp.asarray(coils)
    maskWave    = cp.asarray(maskWave)
    maskWaveAcs = cp.asarray(maskWaveAcs)
    kadjWave    = cp.asarray(kadjWave)
    kadjWaveAcs = cp.asarray(kadjWaveAcs)
    psf         = cp.asarray(psf)

#-Defining the normal operators and performing the reconstruction
normalWave = lambda x: senseWaveAdjoint(senseWaveForward(x.reshape(1,numslices,Nro,N),coils,psf,maskWave),\
                                          coils,psf,maskWave).ravel()

normalWaveAcs = lambda x: senseWaveAdjoint(senseWaveForward(x.reshape(1,numslices,Nro,N),coils,psf,maskWaveAcs),\
                                          coils,psf,maskWaveAcs).ravel()

print('WAVE SENSE reconstruction ...',end='')
# The iteration count comes from the configuration cell (senseIterations) instead of a literal
smsWave = cp.asnumpy(iterative.conjgrad(normalWave,kadjWave.ravel(),kadjWave.ravel(),\
                                         ite = senseIterations)).reshape(1,numslices,Nro,N)

smsWaveAcs = cp.asnumpy(iterative.conjgrad(normalWaveAcs,kadjWaveAcs.ravel(),kadjWaveAcs.ravel(),\
                                         ite = senseIterations)).reshape(1,numslices,Nro,N)
print(' Done.')

coils    = cp.asnumpy(coils)
maskWave = cp.asnumpy(maskWave)
kadjWave = cp.asnumpy(kadjWave)
psf      = cp.asnumpy(psf)
maskWaveAcs = cp.asnumpy(maskWaveAcs)
kadjWaveAcs = cp.asnumpy(kadjWaveAcs)

### visualizing wave-encoded reconstruction 

In [None]:
#-Undo the FOV shift and crop the wave-encoded ACS reconstruction. Squeeze only the leading singleton
# axis (as the other display cells do) so the slice axis survives when numslices == 1.
wave = viscrop(np.squeeze(performshift(np.expand_dims(np.reshape(smsWaveAcs,(numslices,Nro,N)),axis=0),shifts,-1),axis=0))

display = sig.nor(np.concatenate((truth,sense,wave),axis = 0))
sig.mosaic(display,3,numslices)

In [None]:
#-RMSE (in percent) of SENSE and wave-encoded SENSE against the reference
print('Sense Total rmse:   %.2f' % (100 * sig.rmse(truth, sense)))
print('Wave  Total rmse:   %.2f' % (100 * sig.rmse(truth, wave)))

for ss in range(numslices):
    print('Slice %d:' % (ss + 1))
    print('  sense rmse: %.2f' % (100 * sig.rmse(truth[ss], sense[ss])))
    print('  wave  rmse: %.2f' % (100 * sig.rmse(truth[ss], wave[ss])))


### setting up k-space to train SPARK

In [None]:
#-Wave SPARK training pair: the input is the wave k-space of the reconstruction without ACS lines, and the
# reference is the fully sampled wave k-space. Both are collapsed over slices and reordered to (1, C, Nro, N).
kspaceWaveSpark    = col(sig.fft(psf*sig.fft(coils*smsWave,-2),-1)).transpose(1,0,2,3)
kspaceWaveAcsSpark = col(sig.fft(psf*sig.fft(slicesShiftedCoils,-2),-1)).transpose(1,0,2,3)

### training wave-spark network 

In [None]:
#-Format the wave k-space for SPARK, then train the real/imaginary networks
[kspace_grappa_split, acs_difference_real, acs_difference_imag, chan_scale_factors_real, chan_scale_factors_imag] = \
    reformattingKspaceForSpark(inputKspace=kspaceWaveSpark, kspaceOriginal=kspaceWaveAcsSpark,
                               acsregionX=acsregionX, acsregionY=acsregionY, acsx=acsx, acsy=acsy,
                               normalizationflag=normalizationflag)

realSparkGrappaModels, realSparkGrappaNames, imagSparkGrappaModels, imagSparkGrappaNames = \
    trainingSparkNetwork(kspaceGrappaSplit=kspace_grappa_split, acsDifferenceReal=acs_difference_real,
                         acsDifferenceImag=acs_difference_imag, acsx=acsregionX, acsy=acsregionY,
                         learningRate=learningRate, iterations=sparkIterations)

### generating k-space to which the correction is applied

In [None]:
#-Wave k-space of the reconstruction that used the ACS lines; this is what SPARK corrects
kspaceSenseToCorrect = col(sig.fft(psf*sig.fft(coils*smsWaveAcs,-2),-1)).transpose(1,0,2,3)

[kspace_grappa_split, acs_difference_real, acs_difference_imag, chan_scale_factors_real, chan_scale_factors_imag] = \
    reformattingKspaceForSpark(inputKspace=kspaceSenseToCorrect, kspaceOriginal=kspaceWaveAcsSpark,
                               acsregionX=acsregionX, acsregionY=acsregionY, acsx=acsx, acsy=acsy,
                               normalizationflag=normalizationflag)


### applying spark correction 

In [None]:
#-Apply the trained (contrast 0) wave models coil by coil to the k-space being corrected
kspaceCorrectedWave    = np.zeros((1,C,Nro,N),dtype = complex)

for e in range(0,1):
    for c in range(0,C):
        model_namer = 'model' + 'E' + str(e) + 'C' + str(c) + 'r'
        model_namei = 'model' + 'E' + str(e) + 'C' + str(c) + 'i'

        real_model = realSparkGrappaModels[model_namer]
        imag_model = imagSparkGrappaModels[model_namei]

        kspaceToCorrect   = kspaceSenseToCorrect[e,c,:,:]
        kspaceGrappaSplit = kspace_grappa_split[e,:,:,:]

        kspaceCorrectedWave[e,c,:,:] = applySparkCorrection(kspaceToCorrect,kspaceGrappaSplit,real_model,imag_model,\
                                                            chan_scale_factors_real[e,c], chan_scale_factors_imag[e,c])

#-ACS replacement: overwrite the whole ACS block (both ends inclusive, matching the ACS crop used
# for training) with the reference wave k-space
acsRows = slice(acsregionX[0], acsregionX[-1] + 1)
acsCols = slice(acsregionY[0], acsregionY[-1] + 1)

kspaceCorrectedReplacedWave = np.copy(kspaceCorrectedWave)
kspaceCorrectedReplacedWave[:,:,acsRows,acsCols] = kspaceWaveAcsSpark[:,:,acsRows,acsCols]

### performing wave-encoded slice-group reconstruction after spark correction

In [None]:
#-Adjoint of the SPARK-corrected, ACS-replaced wave k-space; mask 1 because every line is present
kadjWave = senseWaveAdjoint(np.transpose(kspaceCorrectedReplacedWave,(1,0,2,3)),coils,psf,1)

if(cudaflag):
    coils       = cp.asarray(coils)
    kadjWave    = cp.asarray(kadjWave)
    psf         = cp.asarray(psf)

#-Defining the normal operator and performing the reconstruction
normalWave = lambda x: senseWaveAdjoint(senseWaveForward(x.reshape(1,numslices,Nro,N),coils,psf,1),\
                                          coils,psf,1).ravel()

print('WAVE SENSE reconstruction ...',end='')
# The iteration count comes from the configuration cell (senseIterations) instead of a literal
smsWaveSpark = cp.asnumpy(iterative.conjgrad(normalWave,kadjWave.ravel(),kadjWave.ravel(),\
                                         ite = senseIterations)).reshape(1,numslices,Nro,N)

print(' Done.')

coils    = cp.asnumpy(coils)
kadjWave = cp.asnumpy(kadjWave)
psf      = cp.asnumpy(psf)

### displaying results 

In [None]:
#-Undo the FOV shift of the SPARK-corrected wave reconstruction and crop it with the `viscrop` from the
# SENSE evaluation cell (not redefined here, so every panel uses the same crop)
sparkWave = viscrop(np.squeeze(performshift(np.expand_dims(np.reshape(smsWaveSpark,(numslices,Nro,N)),axis=0),shifts,-1),axis=0))

#-Panels in order: reference, wave, wave + SPARK
display = sig.nor(np.concatenate((truth,wave,sparkWave),axis = 0))
sig.mosaic(display,3,numslices)

In [None]:
#-RMSE (in percent) of wave and wave + SPARK against the reference
print('wave  Total rmse:   %.2f' % (100 * sig.rmse(truth, wave)))
print('spark Total rmse:   %.2f' % (100 * sig.rmse(truth, sparkWave)))

for ss in range(numslices):
    print('Slice %d:' % (ss + 1))
    print('  wave  rmse: %.2f' % (100 * sig.rmse(truth[ss], wave[ss])))
    print('  spark rmse: %.2f' % (100 * sig.rmse(truth[ss], sparkWave[ss])))


### collecting results (for saving)

In [None]:
#-Bundle the cropped reconstructions and the experiment settings into one dict
results = dict(
    truth        = np.squeeze(truth),
    sense        = np.squeeze(sense),
    spark        = np.squeeze(spark),
    wave         = np.squeeze(wave),
    sparkwave    = np.squeeze(sparkWave),
    mbfactor     = numslices,
    Ry           = Ry,
    acsy         = acsy,
    Iterations   = sparkIterations,
    learningRate = learningRate,
    slices       = slices,
    fovshift     = fovshift,
)