In [None]:
import time
import importlib 
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import warnings
import numpy as np
import numpy.linalg as la
import scipy as sp
import cupy as cp
import matplotlib.pyplot as plt
from bart import bart
from utils import cfl
from utils import signalprocessing as sig
from utils import models
from utils import iterative

# Defining helper functions and the SPARK model 

In [None]:
def fft3(x):
    """Forward FFT over the last three axes.

    Applies ``sig.fft`` successively along axis -3, then -2, then -1 and
    returns the transformed array.
    """
    result = x
    for axis in (-3, -2, -1):
        result = sig.fft(result, axis)
    return result

def ifft3(x):
    """Inverse FFT over the last three axes.

    Applies ``sig.ifft`` successively along axis -3, then -2, then -1 and
    returns the transformed array.
    """
    result = x
    for axis in (-3, -2, -1):
        result = sig.ifft(result, axis)
    return result

class SPARK_3D_net(nn.Module):
    """Residual 3D CNN mapping stacked GRAPPA k-space to a single-channel correction map.

    The input stacks real and imaginary parts of all coils, i.e. it has
    ``2 * coils`` channels. ``forward`` returns the full single-channel output
    and its crop to the ACS region (the part used for the training loss).

    Args:
        coils: number of coils ``C``; the input must have ``2*C`` channels.
        kernelsize: kernel size of the non-1x1 convolutions (int or 3-tuple).
            Padding is ``kernelsize // 2`` so odd sizes keep the spatial size,
            which the residual additions in ``forward`` require.
        acsx, acsy, acsz: ACS index arrays along each spatial axis; only their
            first and last entries are used, to crop ``loss_out``.
    """

    def __init__(self,coils,kernelsize,acsx,acsy,acsz):
        super().__init__()
        self.acsx = acsx
        self.acsy = acsy
        self.acsz = acsz

        # "Same" padding derived from the kernel size. A hard-coded padding of 1
        # only preserved the spatial size for kernelsize == 3; any other odd size
        # made `x + y` / `z + y` in forward() fail with a shape mismatch.
        if isinstance(kernelsize, int):
            pad = kernelsize // 2
        else:
            pad = tuple(k // 2 for k in kernelsize)
        # Bottleneck width before the final projection. Clamped to >= 1 so that
        # fewer than 4 coils does not create a zero-channel layer.
        bottleneck = max(1, coils // 4)

        # First residual unit: 2C -> 2C -> C (1x1) -> 2C, added to the input.
        self.conv1 = nn.Conv3d(coils*2, coils*2, kernelsize, padding=pad, bias=False)
        self.conv2 = nn.Conv3d(coils*2, coils, 1, padding=0, bias=False)
        self.conv3 = nn.Conv3d(coils, coils*2, kernelsize, padding=pad, bias=False)
        # Second residual unit with the same layout.
        self.conv4 = nn.Conv3d(coils*2, coils*2, kernelsize, padding=pad, bias=False)
        self.conv5 = nn.Conv3d(coils*2, coils, 1, padding=0, bias=False)
        self.conv6 = nn.Conv3d(coils, coils*2, kernelsize, padding=pad, bias=False)
        # Output head: 2C -> C -> bottleneck (1x1) -> 1 channel.
        self.conv7 = nn.Conv3d(coils*2, coils, kernelsize, padding=pad, bias=False)
        self.conv8 = nn.Conv3d(coils, bottleneck, 1, padding=0, bias=False)
        self.conv9 = nn.Conv3d(bottleneck, 1, kernelsize, padding=pad, bias=False)

    def naliniRelu(self,inp):
        """Piecewise-linear activation (attempt at Nalini's custom nonlinearity).

        From "Joint Frequency- and Image-Space Learning for Fourier Imaging".
        Identity on [-1, 1], slope 1.5 for inp > 1 and slope 0.5 for inp < -1.
        NOTE(review): the two tails are asymmetric; confirm this matches the paper.
        """
        return inp + F.relu((inp-1)/2) + F.relu((-inp-1)/2)

    def forward(self, x):
        """Return ``(out, loss_out)``: the full output and its ACS-region crop.

        ``x`` is indexed as (batch, channel, X, Y, Z); the crop keeps the
        inclusive index range ``acs*[0] .. acs*[-1]`` on each spatial axis.
        """
        # First residual unit.
        y = self.naliniRelu(self.conv1(x))
        y = self.naliniRelu(self.conv2(y))
        y = self.naliniRelu(self.conv3(y))
        y = x + y
        # Second residual unit.
        z = self.naliniRelu(self.conv4(y))
        z = self.naliniRelu(self.conv5(z))
        z = self.naliniRelu(self.conv6(z))
        out = z + y
        # Output head (no activation after the last convolution).
        out = self.naliniRelu(self.conv7(out))
        out = self.naliniRelu(self.conv8(out))
        out = self.conv9(out)

        loss_out = out[:,:,self.acsx[0]:self.acsx[-1]+1,self.acsy[0]:self.acsy[-1]+1,self.acsz[0]:self.acsz[-1]+1]

        return out, loss_out

# Setting the parameters and loading the dataset 

In [None]:
print('Loading the dataset... ',end='')

start = time.time()
#Loading fully sampled kspace, grappa recon, and parameters for BRAIN dataset
# The .cfl arrays are transposed so that their axis 3 (coils) comes first -> (C, M, N, P).
kspace       = np.transpose(cfl.readcfl('data/kspaceFullFor3dsparkAcsrec'),(3,0,1,2))
kspaceGrappa = np.transpose(cfl.readcfl('data/kspaceGrappaFor3dsparkAcsrec'),(3,0,1,2))
# Fields used in this notebook: [0] acceleration factors, [1] ACS sizes,
# [2] ground-truth image, [3] mask.
for3dspark   = sp.io.loadmat('data/for3DsparkAcsrec.mat')['for3Dspark'][0][0]

[C,M,N,P] = kspace.shape

# Acceleration factors along x, y, z.
Rx = for3dspark[0][0][0]
Ry = for3dspark[0][0][1]
Rz = for3dspark[0][0][2]

# ACS sizes along x, y, z: entries 0, 1, 2, matching the acceleration factors.
# (acsy previously read entry 2, i.e. it silently reused the z ACS size.)
acsx = for3dspark[1][0][0]
acsy = for3dspark[1][0][1]
acsz = for3dspark[1][0][2]

mask = for3dspark[3]

#Defining some SPARK parameters
normalizationflag = 1
measuredReplace   = 1  #If we want to replace measured data (as well as ACS)
iterations        = 1000 
learningRate      = .002
kernelsize        = 3

# Prefer the second GPU as before, but fall back to the first one when only a
# single GPU exists (requesting "cuda:1" there fails as soon as tensors move).
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

print('Elapsed Time is %.3f seconds' % (time.time()-start))

print('GRAPPA Parameters: ')
print('  Dimensions:   %d x %d x %d x %d' %(C,M,N,P))
print('  Acceleration: %d x %d x %d' % (Rx,Ry,Rz))
print('  ACS Sizes:    %d x %d x %d' % (acsx,acsy,acsz))

print('SPARK Parameters: ')
print('  Iterations: %d' % iterations)
print('  Stepsize:   %.3f' % learningRate)
print('  Kernel:     %d' % kernelsize)

# Generating zero-filled ACS k-space

In [None]:
#-Generating zerofilled kspace
# Centered ACS index ranges along x, y, z. Each range starts at
# centre - size//2 and holds exactly `size` samples. The earlier end point
# `centre + size//2` made the range one sample short for odd ACS sizes, so
# acsregion*[acs*-1] below raised IndexError. Even sizes give the same range as before.
acsregionX = np.arange(M//2 - int(acsx)//2, M//2 - int(acsx)//2 + int(acsx))
acsregionY = np.arange(N//2 - int(acsy)//2, N//2 - int(acsy)//2 + int(acsy))
acsregionZ = np.arange(P//2 - int(acsz)//2, P//2 - int(acsz)//2 + int(acsz))

# k-space that is zero everywhere except the ACS block, which is copied from kspace.
kspaceAcsZerofilled = np.zeros(kspace.shape,dtype = complex)
kspaceAcsZerofilled[:,acsregionX[0]:acsregionX[acsx-1]+1,acsregionY[0]:acsregionY[acsy-1]+1,acsregionZ[0]:acsregionZ[acsz-1]+1] = \
    kspace[:,acsregionX[0]:acsregionX[acsx-1]+1,acsregionY[0]:acsregionY[acsy-1]+1,acsregionZ[0]:acsregionZ[acsz-1]+1]



# Generating the ACS-replaced GRAPPA reconstruction

In [None]:
#-Generating ACS replaced GRAPPA recon
# Index tuple for the ACS block on every coil (coil axis kept whole).
_acs_block = (slice(None),
              slice(acsregionX[0], acsregionX[acsx-1] + 1),
              slice(acsregionY[0], acsregionY[acsy-1] + 1),
              slice(acsregionZ[0], acsregionZ[acsz-1] + 1))

# GRAPPA k-space with the measured ACS block pasted back in.
tmp = np.copy(kspaceGrappa)
tmp[_acs_block] = kspace[_acs_block]
del _acs_block

# Image-domain volume combined over the coil axis (-4) with sig.rsos.
grappa = sig.rsos(ifft3(tmp), -4)

In [None]:
#-Reformatting kspace for SPARK
# Index expression selecting the ACS block on every coil (coil axis kept whole).
acs_index = (slice(None),
             slice(acsregionX[0], acsregionX[acsx-1]+1),
             slice(acsregionY[0], acsregionY[acsy-1]+1),
             slice(acsregionZ[0], acsregionZ[acsz-1]+1))

# Training target: measured ACS k-space minus GRAPPA's estimate of it.
kspaceAcsCrop       = kspace[acs_index]
kspaceAcsGrappa     = kspaceGrappa[acs_index]
kspaceAcsDifference = kspaceAcsCrop - kspaceAcsGrappa

# Network input: GRAPPA k-space with real and imaginary parts concatenated
# along the coil axis -> (2C, M, N, P).
kspace_grappa       = np.copy(kspaceGrappa)
kspace_grappa_real  = np.real(kspace_grappa)
kspace_grappa_imag  = np.imag(kspace_grappa)
kspace_grappa_split = np.concatenate((kspace_grappa_real, kspace_grappa_imag), axis=0)


def _inverse_peak(values, axis=None):
    """Return 1 / max|values| (reduced over `axis`), using 1 where the peak is 0.

    An all-zero channel would otherwise yield an infinite scale factor, and
    0 * inf turns every training target value of that channel into NaN.
    """
    peak = np.amax(np.abs(values), axis=axis)
    return 1.0 / np.where(peak > 0, peak, 1.0)


#Let's do some normalization
if normalizationflag:
    scale_factor_input = _inverse_peak(kspace_grappa_split)
    kspace_grappa_split *= scale_factor_input
    # One factor per coil, taken over that coil's three spatial axes.
    chan_scale_factors_real = _inverse_peak(np.real(kspaceAcsDifference), axis=(1, 2, 3))
    chan_scale_factors_imag = _inverse_peak(np.imag(kspaceAcsDifference), axis=(1, 2, 3))
else:
    chan_scale_factors_real = np.ones(C, dtype=float)
    chan_scale_factors_imag = np.ones(C, dtype=float)

# Scale out of place: np.real/np.imag return views of a complex array, so an
# in-place `*=` would also silently rescale kspaceAcsDifference itself.
acs_difference_real = np.real(kspaceAcsDifference) * chan_scale_factors_real[:, None, None, None]
acs_difference_imag = np.imag(kspaceAcsDifference) * chan_scale_factors_imag[:, None, None, None]

# (C, X, Y, Z) -> (C, 1, 1, X, Y, Z): indexing [c] then gives the
# (batch=1, channel=1, X, Y, Z) shape of the network's loss_out.
acs_difference_real = np.expand_dims(np.expand_dims(acs_difference_real, axis=1), axis=1)
acs_difference_imag = np.expand_dims(np.expand_dims(acs_difference_imag, axis=1), axis=1)

# Move to the training device as float32; the input gets a leading batch axis.
kspace_grappa_split = torch.from_numpy(kspace_grappa_split).unsqueeze(0).to(device, dtype=torch.float)
acs_difference_real = torch.from_numpy(acs_difference_real).to(device, dtype=torch.float)
acs_difference_imag = torch.from_numpy(acs_difference_imag).to(device, dtype=torch.float)

# Training the SPARK network 

## Training the real SPARK network

In [None]:
# Shared loss for every per-coil model trained below.
criterion = nn.MSELoss()

real_models      = {}
real_model_names = []

# Loss history for later analysis: one row per epoch, one column per coil.
realLoss = np.zeros((iterations, C))

for c in range(C):
    # Each coil gets its own network, trained to predict the real part of that
    # coil's normalized ACS difference from the full GRAPPA input.
    model_name = f'modelC{c}r'
    model = SPARK_3D_net(coils=C, kernelsize=kernelsize,
                         acsx=acsregionX, acsy=acsregionY, acsz=acsregionZ)
    model.to(device)

    print(f'Training {model_name}')
    start = time.time()

    optimizer = optim.Adam(model.parameters(), lr=learningRate)
    running_loss = 0

    for epoch in range(iterations):
        optimizer.zero_grad()
        # Only the ACS crop of the output enters the loss.
        _, loss_out = model(kspace_grappa_split)
        loss = criterion(loss_out, acs_difference_real[c])
        loss.backward()
        optimizer.step()

        realLoss[epoch, c] = running_loss = loss.item()
        if epoch == 0:
            print('   Starting Loss: %.10f' % running_loss)

    real_model_names.append(model_name)
    real_models[model_name] = model

    print('   Ending Loss:   %.10f' % running_loss)
    print('   Training Time: %.3f seconds' % (time.time() - start))

## Training the imaginary SPARK network 

In [None]:
# Shared loss for every per-coil model trained below.
criterion = nn.MSELoss()

imag_models      = {}
imag_model_names = []

# Loss history for later analysis: one row per epoch, one column per coil.
imagLoss = np.zeros((iterations, C))

for c in range(C):
    # Each coil gets its own network, trained to predict the imaginary part of
    # that coil's normalized ACS difference from the full GRAPPA input.
    model_name = f'modelC{c}i'
    model = SPARK_3D_net(coils=C, kernelsize=kernelsize,
                         acsx=acsregionX, acsy=acsregionY, acsz=acsregionZ)
    model.to(device)

    print(f'Training {model_name}')
    start = time.time()

    optimizer = optim.Adam(model.parameters(), lr=learningRate)
    running_loss = 0

    for epoch in range(iterations):
        optimizer.zero_grad()
        # Only the ACS crop of the output enters the loss.
        _, loss_out = model(kspace_grappa_split)
        loss = criterion(loss_out, acs_difference_imag[c])
        loss.backward()
        optimizer.step()

        imagLoss[epoch, c] = running_loss = loss.item()
        if epoch == 0:
            print('   Starting Loss: %.10f' % running_loss)

    imag_model_names.append(model_name)
    imag_models[model_name] = model

    print('   Ending Loss:   %.10f' % running_loss)
    print('   Training Time: %.3f seconds' % (time.time() - start))


# Performing SPARK correction 

In [None]:
print('Performing coil-by-coil correction... ', end = '')
start = time.time()

# Corrected k-space = GRAPPA k-space + network-predicted correction, per coil.
kspaceCorrected = np.zeros((C,M,N,P),dtype = complex)

# Inference only: torch.no_grad() stops autograd from recording these
# full-volume forward passes, so activations are not kept for a backward
# pass that never happens (lower peak device memory).
with torch.no_grad():
    for c in range(0,C):
        #Perform reconstruction coil by coil
        model_namer = 'model' + 'C' + str(c) + 'r'
        model_namei = 'model' + 'C' + str(c) + 'i'

        real_model = real_models[model_namer]
        imag_model = imag_models[model_namei]

        # Element [0] is the full-volume output; [1] is only the ACS crop used in training.
        correctionr = real_model(kspace_grappa_split)[0].detach().cpu().numpy()
        correctioni = imag_model(kspace_grappa_split)[0].detach().cpu().numpy()

        # Undo the per-coil target normalization, then add onto the GRAPPA k-space.
        kspaceCorrected[c,:,:,:] = correctionr[0,0,:,:,:]/chan_scale_factors_real[c] + \
            1j * correctioni[0,0,:,:,:] / chan_scale_factors_imag[c] + kspaceGrappa[c,:,:,:]

print('Elapsed Time is %.3f seconds' % (time.time()-start))


# Performing ACS replacement 

In [None]:
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#Performing ACS replacement and ifft/rsos coil combine reconstruction
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
print('Performing ACS replacement, ifft, and rsos coil combination... ', end = '')
start = time.time()

#ACS replaced: the measured ACS block overwrites the corrected k-space there
kspaceCorrectedReplaced    = np.copy(kspaceCorrected)

kspaceCorrectedReplaced[:,acsregionX[0]:acsregionX[acsx-1]+1,acsregionY[0]:acsregionY[acsy-1]+1,acsregionZ[0]:acsregionZ[acsz-1]+1] = \
    kspace[:,acsregionX[0]:acsregionX[acsx-1]+1,acsregionY[0]:acsregionY[acsy-1]+1,acsregionZ[0]:acsregionZ[acsz-1]+1]

#Sampled Data replacement
if(measuredReplace):
    # Every Rx/Ry/Rz-th point is overwritten with the measured data.
    # NOTE(review): assumes the sampling lattice starts at index 0 on each axis — confirm.
    kspaceCorrectedReplaced[:,::Rx,::Ry,::Rz] = kspace[:,::Rx,::Ry,::Rz]
    # NOTE(review): the mask (for3dspark[3]) is only applied when measuredReplace
    # is set — confirm this coupling is intended.
    kspaceCorrectedReplaced *= np.expand_dims(mask,axis=0)
    
#Perform IFFT and coil combine
truth  = for3dspark[2]
# `grappa` is not recomputed here: the GRAPPA cell already set it to
# sig.rsos(ifft3(tmp), -4), and `tmp` has not changed since.
spark  = sig.rsos(ifft3(kspaceCorrectedReplaced),-4)
print('Elapsed Time is %.3f seconds' % (time.time()-start))

# Saving the results 

In [None]:
print('Saving results... ', end = '')
start = time.time()
# Reconstructions plus the parameters needed to interpret them. All three
# axes' acceleration factors and ACS sizes are saved for this 3D dataset.
results = {'groundTruth': np.squeeze(truth),
           'grappa': np.squeeze(grappa),
           'spark': np.squeeze(spark),
           'Rx': Rx,
           'Ry': Ry,
           'Rz': Rz,
           'acsx': acsx,
           'acsy': acsy,
           'acsz': acsz,
           'Iterations': iterations,
           'learningRate': learningRate,
           'realLoss':realLoss,
           'imagLoss':imagLoss}

sp.io.savemat('3dVolumeSparkAcsrecon', results, oned_as='row')
print('Elapsed Time is %.3f seconds' % (time.time()-start))