In [None]:
import os

import numpy as np
import scipy as sp
import scipy.io
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from utils import signalprocessing as sig
from utils import models

# Load the data 

In [None]:
# Load the full and LORAKS-reconstructed k-space plus acquisition parameters (MATLAB struct).
forspark   = sp.io.loadmat('forspark/forspark_loraks.mat')['forspark_loraks']

# Struct fields 0-2 are stored as M x N x C; reorder to C x M x N and add a leading contrast axis (E = 1).
kspace_full             = np.expand_dims(np.transpose(forspark[0][0][0],(2,0,1)),axis = 0)
kspace_loraks           = np.expand_dims(np.transpose(forspark[0][0][1],(2,0,1)),axis = 0)
kspace_loraks_replaced  = np.expand_dims(np.transpose(forspark[0][0][2],(2,0,1)),axis = 0)

# Scalar fields arrive as 1 x 1 arrays; [0][0] extracts the value.
Rx    = forspark[0][0][3][0][0]
Ry    = forspark[0][0][4][0][0]
# The ACS sizes are used as array indices and slice bounds below, so force them to Python ints
# (loadmat can return MATLAB doubles as float64, which numpy rejects as an index).
acsx  = int(forspark[0][0][5][0][0])
acsy  = int(forspark[0][0][6][0][0])

[E,C,M,N] = kspace_full.shape

# Training hyper-parameters.
iterations        = 200
learningRate      = .0075
normalizationflag = 1   # truthy: scale network inputs/targets to unit max magnitude
normalizeAll      = 0   # NOTE(review): not referenced anywhere in this notebook

# Working copies. `kspaceGrappa` keeps the GRAPPA-era name used by the SPARK helpers,
# but here it holds the LORAKS reconstruction.
kspace        = np.copy(kspace_full)
kspaceGrappa  = np.copy(kspace_loraks)

# Generate the ACS region 

In [None]:
# Index arrays of the centered ACS block. Start at M//2 - acsx//2 and take exactly acsx entries,
# so the block has the right length for odd ACS sizes as well as even ones
# (a symmetric +/- acsx//2 range would hold only acsx - 1 indices when acsx is odd).
acsStartX  = M//2 - int(acsx)//2
acsStartY  = N//2 - int(acsy)//2
acsregionX = np.arange(acsStartX, acsStartX + int(acsx))
acsregionY = np.arange(acsStartY, acsStartY + int(acsy))

# Zero-filled k-space that keeps only the ACS block of the full data.
kspaceAcsZerofilled = np.zeros((E,C,M,N),dtype = complex)
acsBlock = (slice(None), slice(None),
            slice(acsregionX[0], acsregionX[-1]+1),
            slice(acsregionY[0], acsregionY[-1]+1))
kspaceAcsZerofilled[acsBlock] = kspace[acsBlock]

# Define SPARK helper functions

In [None]:
def reformattingKspaceForSpark(inputKspace,kspaceOriginal,acsregionX,acsregionY,acsx,acsy,normalizationflag,device=None):
    '''
    Builds the SPARK network input and training targets from a reconstructed k-space.

    Inputs:
        inputKspace:       E x C x M x N complex array, reconstructed k-space (the network input;
                           the notebook passes the LORAKS reconstruction)
        kspaceOriginal:    E x C x M x N complex array whose ACS block holds the measured data
        acsregionX/Y:      1-D index arrays of the ACS rows / columns
        acsx, acsy:        number of ACS rows / columns (acsregionX[acsx-1] is the last ACS row)
        normalizationflag: truthy -> scale the input per contrast, and each target per
                           (contrast, coil), so that its largest magnitude is 1
        device:            torch device for the returned tensors; defaults to the notebook-level `device`
    Outputs:
        kspace_grappa_split:     E x 2C x M x N float tensor, real parts in channels [0, C),
                                 imaginary parts in [C, 2C)
        acs_difference_real:     E x C x 1 x 1 x acsx x acsy float tensor, scaled real part of
                                 (measured ACS - reconstructed ACS)
        acs_difference_imag:     same for the imaginary part
        chan_scale_factors_real: E x C array of the factors applied to each real target
                                 (divide network output by these to undo the scaling)
        chan_scale_factors_imag: same for the imaginary targets
    '''
    if device is None:
        # Keep existing call sites working: fall back to the `device` defined in the notebook.
        device = globals().get('device', torch.device('cpu'))

    [E,C,_,_] = inputKspace.shape

    def _inverseMax(values):
        '''Returns 1 / max|values|, or 1.0 when values are all zero (avoids inf/NaN scale factors).'''
        peak = np.amax(np.abs(values))
        return 1.0 / peak if peak > 0 else 1.0

    acsRows = slice(acsregionX[0], acsregionX[acsx-1]+1)
    acsCols = slice(acsregionY[0], acsregionY[acsy-1]+1)

    # Supervised target d = measured ACS - reconstructed ACS, split into real and imaginary parts.
    kspaceAcsDifference = kspaceOriginal[:,:,acsRows,acsCols] - inputKspace[:,:,acsRows,acsCols]
    # np.array(..., dtype=float) yields writable float copies even for real-valued input
    # (np.imag of a real array is a read-only view and could not be scaled in place).
    acs_difference_real = np.array(np.real(kspaceAcsDifference), dtype=float)
    acs_difference_imag = np.array(np.imag(kspaceAcsDifference), dtype=float)

    # Network input: real and imaginary parts stacked along the coil axis.
    kspace_grappa_split = np.concatenate((np.real(inputKspace), np.imag(inputKspace)), axis=1).astype(float, copy=False)

    chan_scale_factors_real = np.ones((E,C),dtype = 'float')
    chan_scale_factors_imag = np.ones((E,C),dtype = 'float')

    if(normalizationflag):
        for e in range(E):
            kspace_grappa_split[e,:,:,:] *= _inverseMax(kspace_grappa_split[e,:,:,:])
            for c in range(C):
                chan_scale_factors_real[e,c] = _inverseMax(acs_difference_real[e,c,:,:])
                chan_scale_factors_imag[e,c] = _inverseMax(acs_difference_imag[e,c,:,:])

    acs_difference_real *= chan_scale_factors_real[:,:,np.newaxis,np.newaxis]
    acs_difference_imag *= chan_scale_factors_imag[:,:,np.newaxis,np.newaxis]

    # Insert two singleton axes: E x C x 1 x 1 x acsx x acsy.
    acs_difference_real = np.expand_dims(np.expand_dims(acs_difference_real, axis=2), axis=2)
    acs_difference_imag = np.expand_dims(np.expand_dims(acs_difference_imag, axis=2), axis=2)

    kspace_grappa_split = torch.from_numpy(kspace_grappa_split).to(device, dtype=torch.float)
    print('kspace_grappa_split shape: ' + str(kspace_grappa_split.shape))

    acs_difference_real = torch.from_numpy(acs_difference_real).to(device, dtype=torch.float)
    print('acs_difference_real shape: ' + str(acs_difference_real.shape))

    acs_difference_imag = torch.from_numpy(acs_difference_imag).to(device, dtype=torch.float)
    print('acs_target_imag shape: ' + str(acs_difference_imag.shape))

    return kspace_grappa_split, acs_difference_real, acs_difference_imag, chan_scale_factors_real, chan_scale_factors_imag

In [None]:
def trainingSparkNetwork(kspaceGrappaSplit,acsDifferenceReal,acsDifferenceImag,acsx,acsy,learningRate,iterations,device=None):
    '''
    Trains one SPARK network per (contrast, coil) pair for the real ACS error, then one per pair
    for the imaginary ACS error.

    Inputs:
        kspaceGrappaSplit: E x 2C x M x N tensor,            reconstructed k-space (real parts then
                                                             imaginary parts along dim 1); network input
        acsDifferenceReal: E x C x 1 x 1 x X x Y tensor,     real part of the ACS error; training target
        acsDifferenceImag: E x C x 1 x 1 x X x Y tensor,     imaginary part of the ACS error; training target
        acsx, acsy:        forwarded unchanged to models.SPARK_Netv2 (the notebook passes the
                           ACS index arrays acsregionX / acsregionY here, not the sizes)
        learningRate:      scalar, Adam learning rate
        iterations:        int, number of optimizer steps per network
        device:            torch device for the networks; defaults to the notebook-level `device`
    Outputs:
        real_models:       dict name -> trained real-error network, names 'modelE{e}C{c}r'
        real_model_names:  list of those names in training order
        imag_models:       dict name -> trained imaginary-error network, names 'modelE{e}C{c}i'
        imag_model_names:  list of those names in training order
    '''
    if device is None:
        # Keep existing call sites working: fall back to the `device` defined in the notebook.
        device = globals().get('device', torch.device('cpu'))

    [E,C,_,_,_,_] = acsDifferenceReal.shape

    def _trainOne(model_name, networkInput, target):
        '''Fits a freshly initialized SPARK_Netv2 so its ACS output matches `target`; returns the model.'''
        model = models.SPARK_Netv2(coils = C,kernelsize = 3,acsx = acsx, acsy = acsy)
        model.to(device)

        print('Training {}'.format(model_name))

        criterion    = nn.MSELoss()
        optimizer    = optim.Adam(model.parameters(),lr = learningRate)
        running_loss = 0.0

        for epoch in range(iterations):
            optimizer.zero_grad()

            # The second model output is what is compared against the ACS target.
            _,loss_out = model(networkInput)
            loss = criterion(loss_out,target)
            loss.backward()
            optimizer.step()

            running_loss = loss.item()
            if(epoch == 0):
                print('Training started , loss = %.10f' % (running_loss))

        print('Training Complete, loss = %.10f' % (running_loss))
        return model

    def _trainAll(targets, suffix):
        '''Trains every (contrast, coil) model for one component; suffix is 'r' or 'i'.'''
        trained = {}
        names   = []
        for e in range(E):
            # Same input for all coils of a contrast: add the batch dimension once.
            networkInput = kspaceGrappaSplit[e].unsqueeze(0)
            for c in range(C):
                model_name = 'model' + 'E' + str(e) + 'C' + str(c) + suffix
                trained[model_name] = _trainOne(model_name, networkInput, targets[e,c])
                names.append(model_name)
        return trained, names

    # Real models are all trained before the imaginary ones (same order as before, so seeded runs match).
    real_models, real_model_names = _trainAll(acsDifferenceReal, 'r')
    imag_models, imag_model_names = _trainAll(acsDifferenceImag, 'i')

    return real_models,real_model_names,imag_models,imag_model_names

In [None]:
def applySparkCorrection(kspaceToCorrect,kspaceGrappaSplit,real_model,imag_model,chanScaleFactorReal,chanScaleFactorImag):
    '''
    Applies a trained real/imaginary SPARK model pair to correct one coil of one contrast.
    Inputs:
        kspaceToCorrect     - M x N complex array,  k-space that we want to correct
        kspaceGrappaSplit   - 2C x M x N tensor,    network input for this contrast
                                                    (real parts, then imaginary parts)
        real_model          - model,                predicts the scaled real part of the k-space error
        imag_model          - model,                predicts the scaled imaginary part of the k-space error
        chanScaleFactorReal - scalar,               factor the real training target was multiplied by
        chanScaleFactorImag - scalar,               factor the imaginary training target was multiplied by
    Outputs:
        corrected - M x N complex array, kspaceToCorrect plus the un-scaled predicted error
    '''
    networkInput = kspaceGrappaSplit.unsqueeze(0)

    # Inference only: no_grad avoids building autograd graphs for the two forward passes.
    with torch.no_grad():
        correctionr = real_model(networkInput)[0].cpu().detach().numpy()
        correctioni = imag_model(networkInput)[0].cpu().detach().numpy()

    # Take batch 0, channel 0 of each model's first output and divide to undo the target scaling.
    corrected = correctionr[0,0,:,:]/chanScaleFactorReal + 1j * correctioni[0,0,:,:] / chanScaleFactorImag + kspaceToCorrect

    return corrected

# Training models

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Seed torch so network initialization, and therefore the SPARK result, repeats across runs
# (some CUDA kernels may still be nondeterministic).
torch.manual_seed(0)

#Reformatting the data: network inputs, scaled ACS-error targets, and per-coil scale factors
[kspace_grappa_split, acs_difference_real, acs_difference_imag,chan_scale_factors_real,chan_scale_factors_imag] = \
    reformattingKspaceForSpark(kspaceGrappa,kspaceAcsZerofilled,acsregionX,acsregionY,acsx,acsy,normalizationflag)

# The ACS index arrays (not the sizes) go into the acsx/acsy slots: trainingSparkNetwork forwards
# them unchanged to models.SPARK_Netv2, and its docstring describes them as ACS indices.
realSparkGrappaModels,realSparkGrappaNames,imagSparkGrappaModels,imagSparkGrappaNames = \
    trainingSparkNetwork(kspace_grappa_split,acs_difference_real,acs_difference_imag,acsregionX,acsregionY,learningRate,iterations)

# Perform correction and replacement 

In [None]:
# Correct every (contrast, coil) of the LORAKS k-space with that pair's own real/imag SPARK models.
kspaceCorrected    = np.zeros((E,C,M,N),dtype = complex)

for contrastIdx in range(E):
    # All coils of a contrast share the same network input.
    networkInput = kspace_grappa_split[contrastIdx,:,:,:]
    for coilIdx in range(C):
        nameStem = 'model' + 'E' + str(contrastIdx) + 'C' + str(coilIdx)
        kspaceCorrected[contrastIdx,coilIdx,:,:] = applySparkCorrection(
            kspaceGrappa[contrastIdx,coilIdx,:,:],
            networkInput,
            realSparkGrappaModels[nameStem + 'r'],
            imagSparkGrappaModels[nameStem + 'i'],
            chan_scale_factors_real[contrastIdx,coilIdx],
            chan_scale_factors_imag[contrastIdx,coilIdx])

# Overwrite the ACS block of the corrected k-space with the corresponding entries of `kspace`.
acsRows = slice(acsregionX[0], acsregionX[acsx-1]+1)
acsCols = slice(acsregionY[0], acsregionY[acsy-1]+1)
kspaceCorrectedReplaced = np.copy(kspaceCorrected)
kspaceCorrectedReplaced[:,:,acsRows,acsCols] = kspace[:,:,acsRows,acsCols]


# Generate reconstructions and compare

In [None]:
# Image-domain magnitude for each dataset: sig.ifft2c then sig.rsos over axis -3, the coil axis
# (presumably centered 2-D IFFT and root-sum-of-squares; see utils.signalprocessing).
truth    = sig.rsos(sig.ifft2c(kspace),-3)
loraks   = sig.rsos(sig.ifft2c(kspace_loraks_replaced),-3)
spark    = sig.rsos(sig.ifft2c(kspaceCorrectedReplaced),-3)

# Normalize each image and tile them as truth | LORAKS | SPARK.
# Named `comparisonMosaic` rather than `display` so IPython's display() is not shadowed in later cells.
comparisonMosaic = np.concatenate((sig.nor(truth),sig.nor(loraks),sig.nor(spark)),axis = 0)
sig.mosaic(comparisonMosaic,1,3)

In [None]:
# Error of each reconstruction against the full-k-space reference, scaled by 100.
print("RMSE LORAKS:   %.2f" % (100 * sig.rmse(truth, loraks)))
print("RMSE SPARK:    %.2f" % (100 * sig.rmse(truth, spark)))

# Save results for display in MATLAB

In [None]:
# Bundle images, k-space, and acquisition/training settings for inspection in MATLAB.
results = {'truth': np.squeeze(truth),
           'loraks': np.squeeze(loraks),
           'spark': np.squeeze(spark),
           'kspace_full' : np.squeeze(kspace_full),
           'kspace_loraks' : np.squeeze(kspace_loraks),
           'kspace_spark' : np.squeeze(kspaceCorrectedReplaced),
           'Rx': Rx,
           'Ry': Ry,
           'acsx': acsx,
           'acsy': acsy,
           'Iterations': iterations,
           'learningRate': learningRate}

# savemat does not create parent directories; make sure the output folder exists.
os.makedirs('results', exist_ok=True)
sp.io.savemat('results/loraks_results.mat', results, oned_as='row')