In [None]:
import os
import warnings

import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import scipy.io  # makes sp.io available on SciPy versions without lazy submodule loading
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from utils import cfl
from utils import models
from utils import signalprocessing as sig
from utils import models

In [None]:
# Location and acquisition tag of the exported SPARK inputs
DATA_DIR = 'forspark'
DATA_TAG = 'Rx6Ry1'

forspark   = sp.io.loadmat('%s/forspark_%s.mat' % (DATA_DIR, DATA_TAG))['forspark']

kspace_orig    = np.transpose(cfl.readcfl('%s/kspace_orig_%s' % (DATA_DIR, DATA_TAG)),axes = (2,0,1))
kspace_grappa  = np.transpose(cfl.readcfl('%s/kspace_grappa_noisy_%s' % (DATA_DIR, DATA_TAG)),axes = (3,2,0,1))

# Scalar parameters from the MATLAB struct. Cast to Python ints: acsx/acsy are used below as
# array sizes and indices, which fails for MATLAB doubles and can overflow for small int dtypes.
Rx    = int(forspark[0][0][0][0][0])
Ry    = int(forspark[0][0][1][0][0])
acsx  = int(forspark[0][0][2][0][0])
acsy  = int(forspark[0][0][3][0][0])

# Move the last axis to the front (it becomes axis -3, the axis the coil combination sums over)
# and add a leading singleton contrast axis.
baseline_coils = np.expand_dims(np.transpose(forspark[0][0][4],(2,0,1)),axis = 0)
coils          = np.expand_dims(np.transpose(forspark[0][0][5],(2,0,1)),axis = 0)

kspace       = np.expand_dims(kspace_orig,axis = 0)
# k-space of the baseline coil images; used as the baseline GRAPPA k-space below
kspaceGrappa = sig.fft2c(baseline_coils)

[E,C,M,N] = kspace.shape

#-Spark parameters
iterations        = 200
learningRate      = .0075
normalizationflag = 1
normalizeAll      = 0   # NOTE(review): not referenced anywhere in this notebook

# Generate the ACS data for training the network from the baseline only

In [None]:
#Generate zero-filled ACS: keep only the central acsx x acsy block of the measured k-space.
acsStartX = M//2 - acsx//2
acsStartY = N//2 - acsy//2
# arange(start, start + size) always has exactly acsx / acsy entries. An upper bound of
# M//2 + acsx//2 would drop one line for odd ACS sizes and make acsregionX[acsx-1] raise.
acsregionX = np.arange(acsStartX, acsStartX + acsx)
acsregionY = np.arange(acsStartY, acsStartY + acsy)

kspaceAcsZerofilled = np.zeros((E,C,M,N),dtype = complex)
kspaceAcsZerofilled[:,:,acsregionX[0]:acsregionX[-1]+1,acsregionY[0]:acsregionY[-1]+1] = \
    kspace[:,:,acsregionX[0]:acsregionX[-1]+1,acsregionY[0]:acsregionY[-1]+1]

# SPARK helper functions

In [None]:
def reformattingKspaceForSpark(inputKspace,kspaceOriginal,acsregionX,acsregionY,acsx,acsy,normalizationflag,torchDevice=None):
    '''
    Builds the network input and the supervised ACS-error targets for SPARK.
    Inputs:
        inputKspace:       E x C x M x N complex,   GRAPPA reconstructed k-space (network input)
        kspaceOriginal:    E' x C x M x N complex,  k-space holding the measured ACS data; broadcast
                                                    against inputKspace (E' may be 1)
        acsregionX/Y:      1-D ACS index arrays; only entries 0 and acsx-1 / acsy-1 are read
        acsx, acsy:        ACS size along each axis
        normalizationflag: if truthy, scale the input per contrast and each target per
                           (contrast, channel) by 1 / max|.| (by 1 when that max is 0)
        torchDevice:       device for the returned tensors; defaults to the notebook-level `device`
    Outputs:
        kspace_grappa_split:          float tensor, E x 2C x M x N (real channels, then imaginary)
        acs_difference_real/imag:     float tensors, E x C x 1 x 1 x (ACS window), normalized
                                      measured-minus-GRAPPA ACS error
        chan_scale_factors_real/imag: numpy E x C, factors applied to each target (divide to undo)
    '''
    # Only looks up the global `device` when no explicit device is given
    dev = torchDevice if torchDevice is not None else device

    def _inverse_peak(arr):
        # 1 / max|arr|; 1 when arr is identically zero (avoids inf scale factors and NaN targets)
        peak = np.amax(np.abs(arr))
        return 1/peak if peak > 0 else 1

    [E,C,_,_] = inputKspace.shape
    acsWindow = (slice(None), slice(None),
                 slice(acsregionX[0], acsregionX[acsx-1]+1),
                 slice(acsregionY[0], acsregionY[acsy-1]+1))

    # Measured ACS minus GRAPPA-reconstructed ACS: the error d the networks learn to predict
    kspaceAcsDifference = kspaceOriginal[acsWindow] - inputKspace[acsWindow]

    # Real and imaginary parts are learned by separate networks. Copy them so the in-place
    # scaling below never hits a read-only view (np.imag of a real array is read-only).
    acs_difference_real = np.real(kspaceAcsDifference).copy()
    acs_difference_imag = np.imag(kspaceAcsDifference).copy()

    # Network input: real channels followed by imaginary channels along axis 1
    kspace_grappa_split = np.concatenate((np.real(inputKspace), np.imag(inputKspace)), axis=1)

    chan_scale_factors_real = np.zeros((E,C),dtype = 'float')
    chan_scale_factors_imag = np.zeros((E,C),dtype = 'float')

    for e in range(E):
        if normalizationflag:
            kspace_grappa_split[e,:,:,:] *= _inverse_peak(kspace_grappa_split[e,:,:,:])

        for c in range(C):
            if normalizationflag:
                scale_factor_real = _inverse_peak(acs_difference_real[e,c,:,:])
                scale_factor_imag = _inverse_peak(acs_difference_imag[e,c,:,:])
            else:
                scale_factor_real = 1
                scale_factor_imag = 1

            chan_scale_factors_real[e,c] = scale_factor_real
            chan_scale_factors_imag[e,c] = scale_factor_imag

            acs_difference_real[e,c,:,:] *= scale_factor_real
            acs_difference_imag[e,c,:,:] *= scale_factor_imag

    # Insert two singleton axes after the channel axis: E x C x 1 x 1 x ACS-window
    acs_difference_real = np.expand_dims(acs_difference_real, axis=(2, 3))
    acs_difference_imag = np.expand_dims(acs_difference_imag, axis=(2, 3))

    kspace_grappa_split = torch.from_numpy(kspace_grappa_split).to(dev, dtype=torch.float)
    print('kspace_grappa_split shape: ' + str(kspace_grappa_split.shape))

    acs_difference_real = torch.from_numpy(acs_difference_real).to(dev, dtype=torch.float)
    print('acs_difference_real shape: ' + str(acs_difference_real.shape))

    acs_difference_imag = torch.from_numpy(acs_difference_imag).to(dev, dtype=torch.float)
    print('acs_difference_imag shape: ' + str(acs_difference_imag.shape))

    return kspace_grappa_split, acs_difference_real, acs_difference_imag, chan_scale_factors_real, chan_scale_factors_imag

In [None]:
def _fitSparkModel(model_name,kspsplit,target,C,acsx,acsy,learningRate,iterations,dev):
    # Creates one fresh SPARK_Netv2, fits it with Adam + MSE to `target`, and returns it.
    model = models.SPARK_Netv2(coils = C,kernelsize = 3,acsx = acsx, acsy = acsy)
    model.to(dev)

    print('Training {}'.format(model_name))

    criterion    = nn.MSELoss()
    optimizer    = optim.Adam(model.parameters(),lr=learningRate)
    running_loss = 0.0

    for epoch in range(iterations):
        optimizer.zero_grad()

        # The model returns a pair; its second element is what gets compared to the ACS target
        _,loss_out = model(kspsplit)
        loss = criterion(loss_out,target)
        loss.backward()
        optimizer.step()

        running_loss = loss.item()
        if epoch == 0:
            print('Training started , loss = %.10f' % (running_loss))

    print('Training Complete, loss = %.10f' % (running_loss))
    return model


def _trainSparkModelSet(kspaceGrappaSplit,targets,suffix,E,C,acsx,acsy,learningRate,iterations,dev):
    # Trains one model per (contrast, channel) against targets[e,c]; names end with `suffix`.
    trained = {}
    names   = []
    for e in range(E):
        # Network input for contrast e with a batch axis of size 1 (same for every channel)
        kspsplit = torch.unsqueeze(kspaceGrappaSplit[e,:,:,:],0)
        for c in range(C):
            model_name = 'model' + 'E' + str(e) + 'C' + str(c) + suffix
            trained[model_name] = _fitSparkModel(model_name,kspsplit,targets[e,c,:,:,:,:],
                                                 C,acsx,acsy,learningRate,iterations,dev)
            names.append(model_name)
    return trained, names


def trainingSparkNetwork(kspaceGrappaSplit,acsDifferenceReal,acsDifferenceImag,acsx,acsy,learningRate,iterations,torchDevice=None):
    '''
    Trains one SPARK network per (contrast, channel) for the real part of the ACS error and one
    for the imaginary part.
    Inputs:
        kspaceGrappaSplit: allContrasts x 2 * allChannels x M x N,         GRAPPA reconstructed k-space
                                                                           (real then imaginary channels),
                                                                           used as network input
        acsDifferenceReal: allContrasts x allChannels x 1 x 1 x ACS-size,  measured-minus-GRAPPA ACS error,
                                                                           real part (training target)
        acsDifferenceImag: same shape as acsDifferenceReal,                imaginary part (training target)
        acsx, acsy:        forwarded unchanged to models.SPARK_Netv2 (this notebook passes the ACS index arrays)
        learningRate:      scalar,                                         Adam learning rate
        iterations:        scalar,                                         optimization steps per network
        torchDevice:       device to train on; defaults to the notebook-level `device`
    Outputs:
        real_models, real_model_names, imag_models, imag_model_names:
            dicts {name: trained model} and lists of names, named 'modelE<e>C<c>r' / 'modelE<e>C<c>i'
    '''
    # Only looks up the global `device` when no explicit device is given
    dev = torchDevice if torchDevice is not None else device

    [E,C,_,_,_,_] = acsDifferenceReal.shape

    # All real models are trained before all imaginary models (same initialization order as before)
    real_models, real_model_names = _trainSparkModelSet(kspaceGrappaSplit,acsDifferenceReal,'r',
                                                        E,C,acsx,acsy,learningRate,iterations,dev)
    imag_models, imag_model_names = _trainSparkModelSet(kspaceGrappaSplit,acsDifferenceImag,'i',
                                                        E,C,acsx,acsy,learningRate,iterations,dev)

    return real_models,real_model_names,imag_models,imag_model_names

In [None]:
def applySparkCorrection(kspaceToCorrect,kspaceGrappaSplit,real_model,imag_model,chanScaleFactorReal,chanScaleFactorImag):
    '''
    Adds the SPARK-predicted correction to one channel of k-space.
    Inputs:
        kspaceToCorrect     - M x N,                Kspace that we want to correct
        kspaceGrappaSplit   - 2*allcoils x M x N    Network input for this contrast (real then imaginary channels)
        real_model          - model                 Model predicting the (normalized) real correction
        imag_model          - model                 Model predicting the (normalized) imaginary correction
        chanScaleFactorReal - scalar                Normalization applied to the real training target
        chanScaleFactorImag - scalar                Normalization applied to the imaginary training target
    Outputs:
        corrected - M x N       kspaceToCorrect + real/chanScaleFactorReal + 1j*imag/chanScaleFactorImag
    '''
    batchedInput = torch.unsqueeze(kspaceGrappaSplit,0)

    # Inference only: don't build autograd graphs for the forward passes
    with torch.no_grad():
        correctionr = real_model(batchedInput)[0].cpu().detach().numpy()
        correctioni = imag_model(batchedInput)[0].cpu().detach().numpy()

    corrected = correctionr[0,0,:,:]/chanScaleFactorReal + 1j * correctioni[0,0,:,:] / chanScaleFactorImag + kspaceToCorrect

    return corrected

# Train the network on the baseline only

In [None]:
# Prefer the second GPU (original setup); fall back to the first GPU if only one exists, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")

# Fix the seed so network initialization is reproducible across kernel restarts.
torch.manual_seed(0)

#Reformatting the data
[kspace_grappa_split, acs_difference_real, acs_difference_imag,chan_scale_factors_real,chan_scale_factors_imag] = \
    reformattingKspaceForSpark(kspaceGrappa,kspaceAcsZerofilled,acsregionX,acsregionY,acsx,acsy,normalizationflag)

realSparkGrappaModels,realSparkGrappaNames,imagSparkGrappaModels,imagSparkGrappaNames = \
    trainingSparkNetwork(kspace_grappa_split,acs_difference_real,acs_difference_imag,acsregionX,acsregionY,learningRate,iterations)

# Correct the baseline only

In [None]:
#will use each model contrast to reconstruct each recon contrast
kspaceCorrected    = np.zeros((E,C,M,N),dtype = complex)


for reconContrast in range(0,E):
    for c in range(0,C):
        #Perform reconstruction coil by coil
        model_namer = 'model' + 'E' + str(reconContrast) + 'C' + str(c) + 'r'
        model_namei = 'model' + 'E' + str(reconContrast) + 'C' + str(c) + 'i'

        real_model = realSparkGrappaModels[model_namer]
        imag_model = imagSparkGrappaModels[model_namei]

        kspaceToCorrect   = kspaceGrappa[reconContrast,c,:,:]
        kspaceGrappaSplit = kspace_grappa_split[reconContrast,:,:,:]

        currentCorrected = \
                applySparkCorrection(kspaceToCorrect,kspaceGrappaSplit,real_model,imag_model,\
                    chan_scale_factors_real[reconContrast,c], chan_scale_factors_imag[reconContrast,c])

        kspaceCorrected[reconContrast,c,:,:] = currentCorrected  
        
kspace_baseline_spark = np.copy(kspaceCorrected)

# Store the SPARK-corrected baseline k-space (`kspaces_spark`); the coil combination of GRAPPA, SPARK, and truth follows below

In [None]:
# This is the same correction as the baseline correction cell: same models, same inputs and
# same scale factors. Reuse that result instead of running all 2*E*C networks a second time.
kspaceCorrected = np.copy(kspace_baseline_spark)
kspaces_spark   = np.copy(kspaceCorrected)

# Coil-combine and compare the baselines (truth, GRAPPA, SPARK)

In [None]:
def cc(x):
    '''Coil-combine x along axis -3 with the maps in `coils`: sum(conj(S) * x) / (1e-12 + sum(S * conj(S)).'''
    weights = np.conj(coils)
    return np.sum(weights * x, axis = -3) / (1e-12 + np.sum(coils * weights, -3))

truth             = cc(sig.ifft2c(kspace))
baseline_grappa   = cc(baseline_coils)
baseline_spark    = cc(sig.ifft2c(kspace_baseline_spark))

print("BASELINE RMSES")
for rowFormat, recon in (("  grappa: %.2f", baseline_grappa), ("  spark:  %.2f", baseline_spark)):
    print(rowFormat % (sig.rmse(truth, recon) * 100))

In [None]:
# Side by side: truth | GRAPPA | SPARK (baselines)
panels  = (truth, baseline_grappa, baseline_spark)
display = np.concatenate(panels, axis=0)
sig.mosaic(sig.nor(display), 1, 3)

# Apply the baseline-trained SPARK models to all other (Monte Carlo) GRAPPA-reconstructed k-spaces

In [None]:
# Switch to the Monte Carlo GRAPPA replicas: from here on E counts replicas.
kspaceGrappa = np.copy(kspace_grappa)
[E,C,M,N] = kspaceGrappa.shape

# Reformat the replicas for the network. Only the network input (kspace_grappa_split) is needed
# downstream. The per-replica targets and scale factors go to mc_* names, so that
# chan_scale_factors_real/imag keep the baseline values the E0 models were trained with.
# The correction cell divides by chan_scale_factors_*[0,c].
[kspace_grappa_split, mc_acs_difference_real, mc_acs_difference_imag, mc_chan_scale_factors_real, mc_chan_scale_factors_imag] = \
    reformattingKspaceForSpark(kspaceGrappa,kspaceAcsZerofilled,acsregionX,acsregionY,acsx,acsy,normalizationflag)

In [None]:
# Correct every replica with the contrast-0 (baseline) models. The replica index selects the
# k-space and network input; un-normalization uses chan_scale_factors_*[0, c].
kspaceCorrectedReplica = np.zeros((E,C,M,N),dtype = complex)
baselineTag = 'model' + 'E' + str(0)

for reconContrast in range(E):
    for c in range(C):
        real_model = realSparkGrappaModels[baselineTag + 'C' + str(c) + 'r']
        imag_model = imagSparkGrappaModels[baselineTag + 'C' + str(c) + 'i']

        kspaceCorrectedReplica[reconContrast,c,:,:] = applySparkCorrection(
            kspaceGrappa[reconContrast,c,:,:],
            kspace_grappa_split[reconContrast,:,:,:],
            real_model,
            imag_model,
            chan_scale_factors_real[0,c],
            chan_scale_factors_imag[0,c])

kspaces_spark_montecarlo = np.copy(kspaceCorrectedReplica)

# Coil-combine the GRAPPA and SPARK Monte Carlo replicas and compute the RMSE of every replica

In [None]:
# Coil-combined images of every GRAPPA replica and of its SPARK-corrected counterpart
grappa_montecarlo, spark_montecarlo = (cc(sig.ifft2c(ksp)) for ksp in (kspace_grappa, kspaces_spark_montecarlo))

In [None]:
rmsespark  = np.zeros((E,1))
rmsegrappa = np.zeros((E,1))

for ee in range(E):
    # Score each method against its own reconstruction
    rmsespark[ee]  = sig.rmse(truth, spark_montecarlo[ee,:,:])
    rmsegrappa[ee] = sig.rmse(truth, grappa_montecarlo[ee,:,:])

fig, ax = plt.subplots()
ax.plot(rmsespark, label='SPARK')
ax.plot(rmsegrappa, label='GRAPPA')
ax.set(xlabel='Monte Carlo replica', ylabel='RMSE', title='RMSE per Monte Carlo replica')
ax.legend()
plt.show()


## Display a randomly chosen replica (GRAPPA vs. SPARK)

In [None]:
# One randomly chosen replica: GRAPPA (left) vs SPARK (right)
randint = np.random.randint(0,E)
display = np.stack((grappa_montecarlo[randint], spark_montecarlo[randint]), axis=0)
sig.mosaic(sig.nor(display), 1, 2)

# Save results

In [None]:
results = {'baseline_grappa':   np.squeeze(baseline_grappa),
           'baseline_spark' :   np.squeeze(baseline_spark),
           'grappa_montecarlo': np.squeeze(grappa_montecarlo),
           'spark_montecarlo' : np.squeeze(spark_montecarlo),
           'Rx': Rx,
           'Ry': Ry,
           'acsx': acsx,
           'acsy': acsy,
           'Iterations': iterations,
           'learningRate': learningRate}

# savemat does not create missing directories
os.makedirs('results', exist_ok=True)
sp.io.savemat('results/results_Rx%dRy%d.mat' % (Rx,Ry), results, oned_as='row')