In [None]:
import importlib 
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import warnings
import numpy as np
import numpy.linalg as la
import scipy as sp
import cupy as cp
import matplotlib.pyplot as plt
from bart import bart
from utils import cfl
from utils import signalprocessing as sig
from utils import models
from utils import iterative

### Defining SPARK helper functions

In [None]:
def reformattingKspaceForSpark(inputKspace,kspaceOriginal,acsregionX,acsregionY,acsx,acsy,normalizationflag,torchDevice=None):
    '''
    Formats k-space for SPARK: stacks the real/imag parts of the input k-space along the channel axis
    and computes the (measured - input) error inside the ACS region, optionally normalized.
    Inputs:
        inputKspace:       E x C x M x N complex,  k-space whose ACS error the networks learn (e.g. GRAPPA/SENSE output)
        kspaceOriginal:    E x C x M x N complex,  measured k-space; its ACS region is the ground truth
        acsregionX/Y:      1-D index arrays,       ACS indices along the last two axes (assumed contiguous)
        acsx/acsy:         int,                    number of entries of acsregionX/acsregionY that are used
        normalizationflag: truthy -> scale the input per contrast and the ACS error per contrast/channel
                           by 1/max|.| (all-zero data keeps a scale of 1 instead of dividing by zero)
        torchDevice:       torch.device or None,   device for the returned tensors; None uses the
                           notebook-level `device`
    Outputs:
        kspace_grappa_split          - float tensor, E x 2C x M x N
        acs_difference_real/imag     - float tensors, E x C x 1 x 1 x (ACS rows) x (ACS cols)
        chan_scale_factors_real/imag - E x C numpy arrays of the factors applied to the ACS error
    '''
    targetDevice = device if torchDevice is None else torchDevice

    def _safeInverseMax(arr):
        # 1/max|arr|; a zero (or NaN) peak falls back to 1 so no inf/NaN scale factor is produced
        peak = np.amax(np.abs(arr))
        return 1/peak if peak > 0 else 1

    [E,C,_,_] = inputKspace.shape
    acsSliceX = slice(acsregionX[0], acsregionX[acsx-1]+1)
    acsSliceY = slice(acsregionY[0], acsregionY[acsy-1]+1)

    #Ground truth measured ACS data
    kspaceAcsCrop       = kspaceOriginal[:,:,acsSliceX,acsSliceY]
    #Input ACS region; kspaceAcsCrop - kspaceAcsGrappa is the supervised error we try to learn
    kspaceAcsGrappa     = inputKspace[:,:,acsSliceX,acsSliceY]
    kspaceAcsDifference = kspaceAcsCrop - kspaceAcsGrappa

    #Writable copies of the real/imag parts (np.imag of real input is read-only)
    acs_difference_real = np.real(kspaceAcsDifference).copy()
    acs_difference_imag = np.imag(kspaceAcsDifference).copy()

    #Network input: real and imaginary parts stacked along the channel axis -> E x 2C x M x N
    kspace_grappa_split = np.concatenate((np.real(inputKspace), np.imag(inputKspace)), axis=1)

    chan_scale_factors_real = np.zeros((E,C),dtype = 'float')
    chan_scale_factors_imag = np.zeros((E,C),dtype = 'float')

    for e in range(E):
        if(normalizationflag):
            kspace_grappa_split[e,:,:,:] *= _safeInverseMax(kspace_grappa_split[e,:,:,:])

        for c in range(C):
            if(normalizationflag):
                scale_factor_real = _safeInverseMax(acs_difference_real[e,c,:,:])
                scale_factor_imag = _safeInverseMax(acs_difference_imag[e,c,:,:])
            else:
                scale_factor_real = 1
                scale_factor_imag = 1

            chan_scale_factors_real[e,c] = scale_factor_real
            chan_scale_factors_imag[e,c] = scale_factor_imag

            acs_difference_real[e,c,:,:] *= scale_factor_real
            acs_difference_imag[e,c,:,:] *= scale_factor_imag

    #E x C x X x Y -> E x C x 1 x 1 x X x Y
    acs_difference_real = np.expand_dims(acs_difference_real, axis=(2,3))
    acs_difference_imag = np.expand_dims(acs_difference_imag, axis=(2,3))

    kspace_grappa_split = torch.from_numpy(kspace_grappa_split).to(targetDevice, dtype=torch.float)
    print('kspace_grappa_split shape: ' + str(kspace_grappa_split.shape))

    acs_difference_real = torch.from_numpy(acs_difference_real).to(targetDevice, dtype=torch.float)
    print('acs_difference_real shape: ' + str(acs_difference_real.shape))

    acs_difference_imag = torch.from_numpy(acs_difference_imag).to(targetDevice, dtype=torch.float)
    print('acs_difference_imag shape: ' + str(acs_difference_imag.shape))

    return kspace_grappa_split, acs_difference_real, acs_difference_imag, chan_scale_factors_real, chan_scale_factors_imag

In [None]:
def trainingSparkNetwork(kspaceGrappaSplit,acsDifferenceReal,acsDifferenceImag,acsx,acsy,learningRate,iterations):
    '''
    Trains one SPARK network per (contrast, channel) for the real part of the ACS error, then one
    per (contrast, channel) for the imaginary part.
    Inputs:
        kspaceGrappaSplit: allContrasts x 2*allChannels x M x N tensor, real/imag-stacked k-space that the
                           networks take as input
        acsDifferenceReal: allContrasts x allChannels x 1 x 1 x Macs x Nacs tensor, real part of the
                           (measured - input) ACS error used as the training target
        acsDifferenceImag: same layout as acsDifferenceReal, imaginary part
        acsx, acsy:        forwarded unchanged to models.SPARK_Netv2 (this notebook passes the ACS index
                           arrays acsregionX / acsregionY)
        learningRate:      scalar, Adam learning rate
        iterations:        int, number of optimization steps per network
    Outputs:
        real_models, real_model_names, imag_models, imag_model_names
        Models are keyed 'modelE<e>C<c>r' / 'modelE<e>C<c>i'; the name lists are in training order.
    Models are moved to the notebook-level `device`.
    '''
    [E,C,_,_,_,_] = acsDifferenceReal.shape

    #Real networks are trained before imaginary ones, so model initialization consumes the RNG in a fixed order
    real_models, real_model_names = _trainSparkModelSet(kspaceGrappaSplit, acsDifferenceReal, 'r', E, C,
                                                        acsx, acsy, learningRate, iterations)
    imag_models, imag_model_names = _trainSparkModelSet(kspaceGrappaSplit, acsDifferenceImag, 'i', E, C,
                                                        acsx, acsy, learningRate, iterations)

    return real_models,real_model_names,imag_models,imag_model_names


def _trainSparkModelSet(kspaceGrappaSplit, acsTarget, suffix, E, C, acsx, acsy, learningRate, iterations):
    '''Trains one SPARK_Netv2 per (contrast e, channel c) against acsTarget[e, c]; returns (models dict, names list).'''
    trainedModels = {}
    modelNames    = []
    criterion     = nn.MSELoss()

    for e in range(E):
        #Batch of one: every channel's network for contrast e sees the same input
        kspsplit = torch.unsqueeze(kspaceGrappaSplit[e,:,:,:], 0)
        for c in range(C):
            model_name = 'model' + 'E' + str(e) + 'C' + str(c) + suffix
            model = models.SPARK_Netv2(coils = C,kernelsize = 3,acsx = acsx, acsy = acsy)
            model.to(device)

            print('Training {}'.format(model_name))
            finalLoss = _fitSparkModel(model, kspsplit, acsTarget[e,c,:,:,:,:], criterion, learningRate, iterations)

            modelNames.append(model_name)
            trainedModels[model_name] = model
            print('Final Loss:   %.10f' % (finalLoss))

    return trainedModels, modelNames


def _fitSparkModel(model, kspsplit, target, criterion, learningRate, iterations):
    '''Runs `iterations` Adam steps on criterion(second model output, target); returns the last loss (0.0 if none).'''
    optimizer    = optim.Adam(model.parameters(), lr=learningRate)
    running_loss = 0.0

    for epoch in range(iterations):
        optimizer.zero_grad()

        #Only the second output is compared to the ACS target (presumably the ACS-region prediction)
        _, loss_out = model(kspsplit)
        loss = criterion(loss_out, target)
        loss.backward()
        optimizer.step()

        running_loss = loss.item()
        if(epoch == 0):
            print('Initial Loss: %.10f' % (running_loss))

    return running_loss

In [None]:
def applySparkCorrection(kspaceToCorrect,kspaceGrappaSplit,real_model,imag_model,chanScaleFactorReal,chanScaleFactorImag):
    '''
    Applies a trained real/imaginary SPARK network pair to correct one channel of k-space.
    Inputs:
        kspaceToCorrect     - M x N complex,              k-space that we want to correct
        kspaceGrappaSplit   - 2*allCoils x M x N tensor,  real/imag-stacked k-space fed to the networks
        real_model          - model,  predicts the (normalized) real part of the correction
        imag_model          - model,  predicts the (normalized) imaginary part of the correction
        chanScaleFactorReal - scalar, factor the real training target was multiplied by (undone here)
        chanScaleFactorImag - scalar, factor the imaginary training target was multiplied by (undone here)
    Outputs:
        corrected - M x N complex, kspaceToCorrect + rescaled network corrections
    '''
    networkInput = torch.unsqueeze(kspaceGrappaSplit, 0)

    #Inference only: no_grad avoids building and holding an autograd graph for two full-size forward passes
    with torch.no_grad():
        correctionr = real_model(networkInput)[0].cpu().numpy()
        correctioni = imag_model(networkInput)[0].cpu().numpy()

    #Undo the per-channel normalization applied to the training targets
    corrected = correctionr[0,0,:,:]/chanScaleFactorReal + 1j * correctioni[0,0,:,:] / chanScaleFactorImag + kspaceToCorrect

    return corrected

### Loading the dataset and setting parameters

In [None]:
#Load the fully sampled wave-encoded k-space, the wave point-spread function and the coil sensitivities;
#a leading singleton axis is added to kspace and two to psf
kspace = np.transpose(cfl.readcfl('data/kspaceWaveFull2d'), (2, 0, 1))[np.newaxis, ...]
coils  = np.transpose(cfl.readcfl('data/coils2d'), (3, 2, 0, 1))
psf    = cfl.readcfl('data/psf2d')[np.newaxis, np.newaxis, ...]

nMaps, nCoils, Nro, Npe = coils.shape

#Acquisition parameters
Ry   = 6      #keep every Ry-th phase-encode line
acsx = Nro    #ACS extent along the readout axis
acsy = 30     #number of ACS phase-encode lines

#Iterative (conjugate-gradient) reconstruction parameters
senseIterations = 20
cudaflag        = 1     #truthy: move data/operators to the GPU with cupy before CG

#SPARK training parameters
learningRate      = .0075
sparkIterations   = 200
normalizationflag = 1

### Generating Cartesian k-space and sampling masks

In [None]:
#Centered ACS index ranges along the readout (X) and phase-encode (Y) axes
acsregionX = np.arange(Nro//2 - acsx//2, Nro//2 + acsx//2)
acsregionY = np.arange(Npe//2 - acsy//2, Npe//2 + acsy//2)

#Strip the wave PSF in hybrid space: ifft along the last axis, multiply by conj(psf), fft back
kspace_cartesian = sig.fft(np.conj(psf) * sig.ifft(kspace, -1), -1)

#Uniform undersampling (every Ry-th line of the last axis), and the same pattern plus the ACS block
mask = np.zeros((1, nCoils, Nro, Npe), dtype=complex)
mask[..., ::Ry] = 1

maskAcs = mask.copy()
maskAcs[:, :, acsregionX[0]:acsregionX[acsx-1]+1, acsregionY[0]:acsregionY[acsy-1]+1] = 1

maskFull = 1

kspaceUndersampledCart    = mask * kspace_cartesian
kspaceUndersampledAcsCart = maskAcs * kspace_cartesian

### defining sense operators

In [None]:
def senseForward(x,maps,mask):
    '''SENSE forward model: weight by maps, sum over axis -4, centered 2-D FFT, apply the sampling mask.'''
    backend    = cp.get_array_module(x)
    coilImages = backend.sum(maps * x, axis=-4, keepdims=True)
    return mask * sig.fft2c(coilImages)

def senseAdjoint(x,maps,mask):
    '''SENSE adjoint: mask, centered 2-D inverse FFT, weight by conj(maps), sum over axis -3.'''
    backend    = cp.get_array_module(x)
    coilImages = sig.ifft2c(mask * x)
    return backend.sum(backend.conj(maps) * coilImages, axis=-3, keepdims=True)

### performing sense reconstruction

In [None]:
#-Compute the adjoint of the undersampled k-space data (without and with ACS lines)
kadj    = senseAdjoint(kspaceUndersampledCart,coils,mask)
kadjAcs = senseAdjoint(kspaceUndersampledAcsCart,coils,maskAcs)

if(cudaflag):
    coils   = cp.asarray(coils)
    mask    = cp.asarray(mask)
    kadj    = cp.asarray(kadj)
    kadjAcs = cp.asarray(kadjAcs)
    maskAcs = cp.asarray(maskAcs)

#-Defining the normal operators (A^H A) and performing the reconstructions
normal = lambda x: senseAdjoint(senseForward(x.reshape(nMaps,1,Nro,Npe),coils,mask),\
                                          coils,mask).ravel()
normalAcs = lambda x: senseAdjoint(senseForward(x.reshape(nMaps,1,Nro,Npe),coils,maskAcs),\
                                          coils,maskAcs).ravel()

#Use the configured iteration count (previously hard-coded to 20 here, unlike the other CG solves)
print('SENSE reconstruction ...',end='')
sense = cp.asnumpy(iterative.conjgrad(normal,kadj.ravel(),kadj.ravel(),\
                                         ite = senseIterations)).reshape(nMaps,1,Nro,Npe)
senseAcs = cp.asnumpy(iterative.conjgrad(normalAcs,kadjAcs.ravel(),kadjAcs.ravel(),\
                                         ite = senseIterations)).reshape(nMaps,1,Nro,Npe)
print(' Done.')

#Back to host memory (cp.asnumpy also accepts numpy arrays when cudaflag is off)
coils   = cp.asnumpy(coils)
mask    = cp.asnumpy(mask)
kadj    = cp.asnumpy(kadj)
kadjAcs = cp.asnumpy(kadjAcs)
maskAcs = cp.asnumpy(maskAcs)

#### Quick comparison of SENSE reconstructions with and without ACS

In [None]:
#Central 256 readout samples (Nro//2 +/- 128) used for display and error metrics
cropreg   = np.arange(Nro//2 - 128, Nro//2 + 128)
sensed    = sense[:1, :, cropreg, :]
senseAcsd = senseAcs[:1, :, cropreg, :]

#Side by side: SENSE without ACS | SENSE with ACS
display = sig.nor(np.concatenate((sensed[0], senseAcsd[0]), axis=0))
sig.mosaic(display, 1, 2, clim=[0, .8])

### computing k-space to learn with spark

In [None]:
kspaceSenseCart = senseForward(sense, coils, mask=1)  #re-project the no-ACS SENSE image to fully sampled k-space

### performing spark training

In [None]:
#Use the second GPU when there is one, otherwise the first GPU, otherwise the CPU;
#requesting "cuda:1" on a single-GPU machine fails with an invalid device ordinal.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

#Reformatting the data: network input plus normalized real/imag ACS-error targets
[kspace_grappa_split, acs_difference_real, acs_difference_imag,chan_scale_factors_real,chan_scale_factors_imag] = \
    reformattingKspaceForSpark(kspaceSenseCart,kspace_cartesian,acsregionX,acsregionY,acsx,acsy,normalizationflag)

#NOTE(review): the ACS index arrays (acsregionX/acsregionY) go into the acsx/acsy slots, which are forwarded
#to models.SPARK_Netv2 -- confirm the model expects indices rather than sizes.
realSparkGrappaModels,realSparkGrappaNames,imagSparkGrappaModels,imagSparkGrappaNames = \
    trainingSparkNetwork(kspace_grappa_split,acs_difference_real,acs_difference_imag,acsregionX,acsregionY,learningRate,sparkIterations)

### generating k-space to which we apply correction

In [None]:
#k-space to correct: fully sampled re-projection of the SENSE image reconstructed with the ACS lines
kspaceSenseToCorrect = senseForward(senseAcs, coils, mask=1)

#Reformat it into network input + per-channel scale factors
(kspace_grappa_split, acs_difference_real, acs_difference_imag,
 chan_scale_factors_real, chan_scale_factors_imag) = reformattingKspaceForSpark(
    kspaceSenseToCorrect, kspace_cartesian, acsregionX, acsregionY, acsx, acsy, normalizationflag)

### applying correction and acs replacement

In [None]:
#Correct each coil with the contrast-0 real/imaginary network pair trained for that coil
kspaceCorrected    = np.zeros((1,nCoils,Nro,Npe),dtype = complex)

for c in range(0,nCoils):
    model_namer = 'model' + 'E' + str(0) + 'C' + str(c) + 'r'
    model_namei = 'model' + 'E' + str(0) + 'C' + str(c) + 'i'

    real_model = realSparkGrappaModels[model_namer]
    imag_model = imagSparkGrappaModels[model_namei]

    kspaceToCorrect   = kspaceSenseToCorrect[0,c,:,:]
    kspaceGrappaSplit = kspace_grappa_split[0,:,:,:]

    currentCorrected = \
            applySparkCorrection(kspaceToCorrect,kspaceGrappaSplit,real_model,imag_model,\
                chan_scale_factors_real[0,c], chan_scale_factors_imag[0,c])

    kspaceCorrected[:,c,:,:] = currentCorrected

#ACS replacement: overwrite the ACS region with the measured Cartesian k-space.
#The stop index carries +1 (the last ACS index is inclusive), matching how the ACS region is cropped and
#masked elsewhere; without it the last ACS row and column kept the network estimate.
acsSliceX = slice(acsregionX[0],acsregionX[acsx-1]+1)
acsSliceY = slice(acsregionY[0],acsregionY[acsy-1]+1)

kspaceCorrectedReplaced    = np.copy(kspaceCorrected)
kspaceCorrectedReplaced[:,:,acsSliceX,acsSliceY] = kspace_cartesian[:,:,acsSliceX,acsSliceY]

### Computing the SPARK-corrected SENSE reconstruction by re-solving the SENSE problem

In [None]:
#-Compute the adjoint of the SPARK-corrected, ACS-replaced k-space (treated as fully sampled: mask = 1)
kadjSparkCart  = senseAdjoint(kspaceCorrectedReplaced,coils,1)

if(cudaflag):
    coils           = cp.asarray(coils)
    kadjSparkCart   = cp.asarray(kadjSparkCart)

#-Defining the normal operator and performing the reconstruction
normalSpark = lambda x: senseAdjoint(senseForward(x.reshape(nMaps,1,Nro,Npe),coils,1),\
                                          coils,1).ravel()

print('SENSE reconstructions ...',end='')
sparkCart      = cp.asnumpy(iterative.conjgrad(normalSpark,kadjSparkCart.ravel(),kadjSparkCart.ravel(),\
                                         ite = senseIterations)).reshape(nMaps,1,Nro,Npe)
print(' Done.')

#Back to host memory under the same names, so no stale GPU copy of kadjSparkCart is left behind
coils          = cp.asnumpy(coils)
kadjSparkCart  = cp.asnumpy(kadjSparkCart)

#### Quick display of the SPARK comparison

In [None]:
sparkcartd = sparkCart[:1, :, cropreg, :]

#Side by side: SENSE with ACS | SPARK-corrected SENSE
panels  = (senseAcsd[0], sparkcartd[0])
display = sig.nor(np.concatenate(panels, axis=0))
sig.mosaic(display, 1, 2, clim=[0, .8])

### defining wave-encoding operators 

In [None]:
def Sf(x,coils):
    '''Sensitivity forward: weight the map images by coils and sum over the maps axis (-4).'''
    backend = cp.get_array_module(x)
    return backend.sum(coils * x, axis=-4, keepdims=True)

def Sa(x,coils):
    '''Sensitivity adjoint: weight coil images by conj(coils) and sum over the coil axis (-3).'''
    backend = cp.get_array_module(x)
    return backend.sum(backend.conj(coils) * x, axis=-3, keepdims=True)

#1-D transforms: X = readout direction (axis -2), Y = phase-encode direction (axis -1)
def Fxf(x):
    return sig.fft(x, ax=-2)

def Fyf(x):
    return sig.fft(x, ax=-1)

def Fxa(x):
    return sig.ifft(x, ax=-2)

def Fya(x):
    return sig.ifft(x, ax=-1)

def Pf(x,psf):
    '''Forward wave modulation: multiply by the point-spread function.'''
    return psf * x

def Pa(x,psf):
    '''Adjoint wave modulation: multiply by the conjugate point-spread function.'''
    backend = cp.get_array_module(x)
    return backend.conj(psf) * x

def senseWaveForward(x,maps,psf,mask):
    '''Wave-SENSE forward: coils -> readout FFT -> PSF -> phase-encode FFT -> sampling mask.'''
    hybrid = Pf(Fxf(Sf(x, maps)), psf)
    return mask * Fyf(hybrid)

def senseWaveAdjoint(x,maps,psf,mask):
    '''Wave-SENSE adjoint: mask -> phase-encode iFFT -> conj PSF -> readout iFFT -> coil combine.'''
    hybrid = Pa(Fya(mask * x), psf)
    return Sa(Fxa(hybrid), maps)

### generating under-sampled wave-encoded k-space with/without ACS region 

In [None]:
#Centered ACS index ranges (same definition as in the Cartesian section)
acsregionX = np.arange(Nro//2 - acsx//2, Nro//2 + acsx//2)
acsregionY = np.arange(Npe//2 - acsy//2, Npe//2 + acsy//2)

maskFull = 1

#Undersampling masks broadcast over coils: every Ry-th phase-encode line, without and with the ACS block
maskNoacs = np.zeros((1, 1, Nro, Npe))
maskNoacs[..., ::Ry] = 1

maskAcs = maskNoacs.copy()
maskAcs[:, :, acsregionX[0]:acsregionX[acsx-1]+1, acsregionY[0]:acsregionY[acsy-1]+1] = 1

kspaceUndersampled    = maskNoacs * kspace
kspaceUndersampledAcs = maskAcs * kspace

### performing wave-encoded reconstructions 

In [None]:
#-Adjoints of the fully sampled, undersampled, and undersampled + ACS wave k-space
kadjFull  = senseWaveAdjoint(kspace,coils,psf,maskFull)
kadj      = senseWaveAdjoint(kspaceUndersampled,coils,psf,maskNoacs)
kadjAcs   = senseWaveAdjoint(kspaceUndersampledAcs,coils,psf,maskAcs)

if(cudaflag):
    #Move operators and right-hand sides to the GPU
    coils, psf, maskNoacs, maskAcs = [cp.asarray(a) for a in (coils, psf, maskNoacs, maskAcs)]
    kadj, kadjAcs, kadjFull        = [cp.asarray(a) for a in (kadj, kadjAcs, kadjFull)]

#-Normal operators A^H A for each sampling pattern (coils/psf/masks are looked up at call time)
def normalWaveFull(x):
    return senseWaveAdjoint(senseWaveForward(x.reshape(nMaps,1,Nro,Npe),coils,psf,maskFull),
                            coils,psf,maskFull).ravel()

def normalWave(x):
    return senseWaveAdjoint(senseWaveForward(x.reshape(nMaps,1,Nro,Npe),coils,psf,maskNoacs),
                            coils,psf,maskNoacs).ravel()

def normalWaveAcs(x):
    return senseWaveAdjoint(senseWaveForward(x.reshape(nMaps,1,Nro,Npe),coils,psf,maskAcs),
                            coils,psf,maskAcs).ravel()

def _cgWaveSolve(normalOp, rhs):
    '''Runs iterative.conjgrad with rhs for both vector arguments; returns a host array nMaps x 1 x Nro x Npe.'''
    return cp.asnumpy(iterative.conjgrad(normalOp, rhs.ravel(), rhs.ravel(),
                                         ite = senseIterations)).reshape(nMaps,1,Nro,Npe)

print('WAVE SENSE reconstructions ...',end='')
wave    = _cgWaveSolve(normalWave, kadj)
waveAcs = _cgWaveSolve(normalWaveAcs, kadjAcs)
full    = _cgWaveSolve(normalWaveFull, kadjFull)
print(' Done.')

#Back to host memory
coils, psf, maskNoacs, kadj, maskAcs, kadjAcs, kadjFull = \
    [cp.asnumpy(a) for a in (coils, psf, maskNoacs, kadj, maskAcs, kadjAcs, kadjFull)]

### displaying wave-encoded reconstructions

In [None]:
#Central 256 readout samples (same crop as the Cartesian comparison)
cropreg = np.arange(Nro//2 - 128, Nro//2 + 128)
fulld, waved, waveAcsd = (recon[:1, :, cropreg, :] for recon in (full, wave, waveAcs))

#Fully sampled | undersampled | undersampled + ACS
display = sig.nor(np.concatenate([fulld[0], waved[0], waveAcsd[0]], axis=0))
sig.mosaic(display, 1, 3, clim=[0, .8])

### quantifying reconstructions 

In [None]:
#Percent RMSE of each reconstruction against the fully sampled wave reconstruction
for fmt, recon in (('Wave rmse:         %.2f', waved),
                   ('Waveacs rmse:      %.2f', waveAcsd),
                   ('Senseacs rmse:     %.2f', senseAcsd),
                   ('senseSpark rmse:   %.2f', sparkcartd)):
    print(fmt % (sig.rmse(fulld, recon)*100))

### computing wave-encoded k-space to learn with SPARK 

In [None]:
kspaceSense = senseWaveForward(wave, coils, psf, mask=maskFull)  #fully sampled re-projection of the no-ACS wave image

### performing spark training 

In [None]:
#Use the second GPU when there is one, otherwise the first GPU, otherwise the CPU;
#requesting "cuda:1" on a single-GPU machine fails with an invalid device ordinal.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

#Reformatting the data: network input plus normalized real/imag ACS-error targets
[kspace_grappa_split, acs_difference_real, acs_difference_imag,chan_scale_factors_real,chan_scale_factors_imag] = \
    reformattingKspaceForSpark(kspaceSense,kspace,acsregionX,acsregionY,acsx,acsy,normalizationflag)

#NOTE(review): the ACS index arrays (acsregionX/acsregionY) go into the acsx/acsy slots, which are forwarded
#to models.SPARK_Netv2 -- confirm the model expects indices rather than sizes.
realSparkGrappaModels,realSparkGrappaNames,imagSparkGrappaModels,imagSparkGrappaNames = \
    trainingSparkNetwork(kspace_grappa_split,acs_difference_real,acs_difference_imag,acsregionX,acsregionY,learningRate,sparkIterations)

### generating k-space to which we apply correction

In [None]:
#k-space to correct: the same no-ACS wave k-space the networks were trained on.
#NOTE(review): the Cartesian section corrects the re-projection of the ACS-including reconstruction instead;
#the analogous choice here would be senseWaveForward(waveAcs,coils,psf,maskFull) -- confirm which is intended.
kspaceSenseToCorrect = np.copy(kspaceSense)

#Reformat it into network input + per-channel scale factors
(kspace_grappa_split, acs_difference_real, acs_difference_imag,
 chan_scale_factors_real, chan_scale_factors_imag) = reformattingKspaceForSpark(
    kspaceSenseToCorrect, kspace, acsregionX, acsregionY, acsx, acsy, normalizationflag)

### performing correction with ACS replacement 

In [None]:
#Correct each coil with the contrast-0 real/imaginary network pair trained for that coil
kspaceCorrected    = np.zeros((1,nCoils,Nro,Npe),dtype = complex)

for c in range(0,nCoils):
    model_namer = 'model' + 'E' + str(0) + 'C' + str(c) + 'r'
    model_namei = 'model' + 'E' + str(0) + 'C' + str(c) + 'i'

    real_model = realSparkGrappaModels[model_namer]
    imag_model = imagSparkGrappaModels[model_namei]

    kspaceToCorrect   = kspaceSenseToCorrect[0,c,:,:]
    kspaceGrappaSplit = kspace_grappa_split[0,:,:,:]

    currentCorrected = \
            applySparkCorrection(kspaceToCorrect,kspaceGrappaSplit,real_model,imag_model,\
                chan_scale_factors_real[0,c], chan_scale_factors_imag[0,c])

    kspaceCorrected[:,c,:,:] = currentCorrected

#ACS replacement: overwrite the ACS region with the measured wave k-space.
#The stop index carries +1 (the last ACS index is inclusive), matching how the ACS region is cropped and
#masked elsewhere; without it the last ACS row and column kept the network estimate.
acsSliceX = slice(acsregionX[0],acsregionX[acsx-1]+1)
acsSliceY = slice(acsregionY[0],acsregionY[acsy-1]+1)

kspaceCorrectedReplaced    = np.copy(kspaceCorrected)
kspaceCorrectedReplaced[:,:,acsSliceX,acsSliceY] = kspace[:,:,acsSliceX,acsSliceY]
    

### Computing the final SPARK reconstruction by re-solving the wave-encoded SENSE problem

In [None]:
#-Adjoint of the SPARK-corrected, ACS-replaced wave k-space (treated as fully sampled)
kadjSpark  = senseWaveAdjoint(kspaceCorrectedReplaced,coils,psf,maskFull)

if(cudaflag):
    coils, psf, kadjSpark = [cp.asarray(a) for a in (coils, psf, kadjSpark)]

#-Normal operator A^H A (coils/psf are looked up at call time)
def normalWaveSpark(x):
    return senseWaveAdjoint(senseWaveForward(x.reshape(nMaps,1,Nro,Npe),coils,psf,maskFull),
                            coils,psf,maskFull).ravel()

print('WAVE SENSE reconstructions ...',end='')
spark = cp.asnumpy(iterative.conjgrad(normalWaveSpark, kadjSpark.ravel(), kadjSpark.ravel(),
                                      ite = senseIterations)).reshape(nMaps,1,Nro,Npe)
print(' Done.')

#Back to host memory
coils, psf, kadjSpark = [cp.asnumpy(a) for a in (coils, psf, kadjSpark)]

### displaying comparisons 

In [None]:
sparkd = spark[:1, :, cropreg, :]

#SENSE | SPARK-SENSE | wave | SPARK-wave
panels  = [senseAcsd[0], sparkcartd[0], waveAcsd[0], sparkd[0]]
display = sig.nor(np.concatenate(panels, axis=0))
sig.mosaic(display, 1, 4, clim=[0, .8])

In [None]:
#Percent RMSE against the fully sampled wave reconstruction
for fmt, recon in (('Sense rmse:      %.2f', senseAcsd),
                   ('senseSpark rmse: %.2f', sparkcartd),
                   ('Wave rmse:       %.2f', waveAcsd),
                   ('waveSpark rmse:  %.2f', sparkd)):
    print(fmt % (sig.rmse(fulld, recon)*100))


### Saving results 

In [None]:
#Collect reconstructions and key parameters.
#NOTE(review): `ori` (orientation) and `sli` (slice) are never defined in this notebook, so referencing them
#directly raised NameError; they are looked up defensively and stored as None when absent -- TODO define them
#(presumably from whatever produced the data/*2d files).
#TODO: despite the section title, nothing is written to disk yet.
results = {'full': np.squeeze(full),
           'wave': np.squeeze(waveAcs),
           'spark': np.squeeze(spark),
           'sense': np.squeeze(senseAcs),
           'sensespark' : np.squeeze(sparkCart),
           'Ry': Ry,
           'acsy': acsy,
           'Iterations': sparkIterations,
           'learningRate': learningRate,
           'orientation': globals().get('ori'),
           'slice': globals().get('sli')}
