In [1]:
import argparse, os
import sigpy as sp
import scipy.ndimage as ndimage_c
import numpy as np
import scipy.io as sio
import sys
sys.path.append("./sigpy_e/")
import sigpy_e.cfl as cfl 

import sigpy_e.ext as ext
import sigpy_e.reg as reg
from sigpy_e.linop_e import NFTs,Diags,DLD,Vstacks
import sigpy.mri as mr

In [None]:
# Legacy configuration + coil-sensitivity calibration for the lung dataset.
# This cell was never executed in the saved notebook ("In [None]"); the spiral
# cells that follow rebind the names the reconstruction actually uses
# (fname, mps, S, tshape, data, traj, dcf, nphase, lambda_TV, outer_iter, ...).
fname = '/data/Jiahao/lung_mri/MRI_Raw'
device = 0
fov_scale = (1, 1, 1)  # per-axis multiplier applied to the k-space coordinates
res_scale = 1          # effective value (an earlier, immediately-overwritten 0.75 was removed)
lambda_TV = 0.05
outer_iter = 20
iner_iter = 15         # (sic) inner-iteration count; not read in this cell
n_ref = -1
reg_flag = 1

## data loading
data = cfl.read_cfl(fname + '_datam')
traj = np.real(cfl.read_cfl(fname + '_trajm'))
dcf = cfl.read_cfl(fname + '_dcf2m')

# Readout truncation: count the samples of the first readout whose k-space
# radius is below nf_scale * (max radius), then keep that many leading samples
# of every readout. Assumes the radius grows monotonically along the readout.
nf_scale = 0.75
nf_arr = np.sqrt(np.sum(traj[0, 0, 0, 0, :, :] ** 2, axis=1))
nf_e = np.sum(nf_arr < np.max(nf_arr) * nf_scale)

# Per-axis FOV scaling of the coordinates, in place and in traj's own dtype
# (equivalent to scaling components 0, 1 and 2 one at a time).
scale = fov_scale
traj[..., :3] *= np.asarray(scale, dtype=traj.dtype)

traj = traj[..., :nf_e, :]
data = data[..., :nf_e, :]
dcf = dcf[..., :nf_e, :]

# data layout per this unpacking: (phase, ecalib, coil, npe, nfe, trailing)
nphase, nEcalib, nCoil, npe, nfe, _ = data.shape
tshape = (380, 256, 256)  # hard-coded reconstruction grid for this dataset

# Coil-major flattening of (phase, pe) for calibration. These reshapes only
# succeed when nEcalib and the trailing axis are 1 (and dcf's axes 1, 2 and 5 are 1).
ksp = np.reshape(np.transpose(data, (2, 1, 0, 3, 4, 5)), (nCoil, nphase * npe, nfe))
dcf2 = np.reshape(np.transpose(dcf ** 2, (2, 1, 0, 3, 4, 5)), (nphase * npe, nfe))
coord = np.reshape(np.transpose(traj, (2, 1, 0, 3, 4, 5)), (nphase * npe, nfe, 3))

## sensitivity maps
mps = ext.jsens_calib(ksp, coord, dcf2, device=sp.Device(device), ishape=tshape)
S = sp.linop.Multiply(tshape, mps)

In [2]:
## XD-GRASP recon start point
# Reference (l1 + TV) reconstruction; its phase images are what the
# registration cell registers against each other when reg_flag == 1.
reg_flag = 1
fname = '/data/Jiahao/spiral/fb-data/110422-6/ex17087-P37376'
imgLs = cfl.read_cfl(os.path.join(fname, 'recon_l1_tv_5'))
imgL = imgLs[0, 0, ...].squeeze()

## data loading
data_all = cfl.read_cfl(os.path.join(fname, 'ksp'))
traj_all = np.real(cfl.read_cfl(os.path.join(fname, 'traj')))
mps = cfl.read_cfl(os.path.join(fname, 'sens'))
kdens = sio.loadmat(os.path.join(fname, 'kdens.mat'))['kdens'].squeeze()

# Squeezed k-space layout, named by this unpacking: (resp, card, echo, coil, npe, nfe).
nresp, ncard, ne, nCoil, npe, nfe = data_all.squeeze().shape
tshape = imgL.shape[1:]  # image grid: drop imgL's leading (presumably phase) axis
S = sp.linop.Multiply(tshape, mps)  # multiply by the coil sensitivity maps

In [3]:
# Test on a fixed respiratory phase (index 0) and all cardiac phases.
nphase = ncard
# traj is (card, 1, 1, ...) with coordinates on the last axis.
# np.squeeze and basic indexing return *views* of traj_all, so copy first. Without
# the copy, the in-place swap below writes through into traj_all, and re-running
# this cell would swap the components back.
traj = np.squeeze(traj_all)[0, :, np.newaxis, np.newaxis, ...].copy()
traj[..., [0, 2]] = traj[..., [2, 0]]  # swap coordinate components 0 and 2
# data: (card, 1, coil, npe, nfe, 1), taken from the squeezed
# (resp, card, echo, coil, npe, nfe) layout with the echo index fixed at 1.
data = np.squeeze(data_all)[0, :, np.newaxis, 1, ..., np.newaxis]
# dcf: sqrt of the density weights, tiled over phases and readouts, shaped
# (phase, 1, 1, npe, nfe, 1). Assumes kdens is 1-D of length nfe (TODO confirm).
dcf = np.tile(np.sqrt(kdens), (nphase, npe, 1))[:, np.newaxis, np.newaxis, ..., np.newaxis]
# imoco recon params
lambda_TV = 0.05
outer_iter = 20
reg_flag = 0  # 0: load motion fields cached by an earlier registration run instead of re-running ANTs

In [4]:
## registration
nphase = 12  # overrides nphase = ncard from the data-selection cell
n_ref = 11   # index of the reference phase every phase is registered against
# Fail fast: later cells index traj/dcf/data (ncard phases) by phase up to nphase.
assert 0 < nphase <= ncard, 'nphase={} exceeds ncard={}'.format(nphase, ncard)
assert 0 <= n_ref < nphase, 'n_ref={} out of range for nphase={}'.format(n_ref, nphase)
print('Registration...')
M_fields = []
iM_fields = []
if reg_flag == 1:
    # ANTs registration of every phase image against the reference phase. The
    # (M, iM) pair is presumably forward/inverse displacement; confirm in reg.ANTsReg.
    for i in range(nphase):
        M_field, iM_field = reg.ANTsReg(np.abs(imgL[n_ref]), np.abs(imgL[i]))
        M_fields.append(M_field)
        iM_fields.append(iM_field)
    M_fields = np.asarray(M_fields)
    iM_fields = np.asarray(iM_fields)
    # Cache the fields so later runs can set reg_flag = 0 and skip ANTs.
    np.save(os.path.join(fname, 'M_mr.npy'), M_fields)
    np.save(os.path.join(fname, 'iM_mr.npy'), iM_fields)
else:
    # NOTE: the cache is not tied to the current nphase / n_ref. It may be stale
    # if either changed since it was written.
    M_fields = np.load(os.path.join(fname, 'M_mr.npy'))
    iM_fields = np.load(os.path.join(fname, 'iM_mr.npy'))
assert len(M_fields) >= nphase and len(iM_fields) >= nphase, \
    'motion fields cover fewer than nphase={} phases'.format(nphase)

# numpy array to list: one field per phase (splits along axis 0)
iM_fields = list(iM_fields)
M_fields = list(M_fields)

# Adapt each motion field to the image grid tshape via reg.M_scale.
print('Motion Field scaling...')
M_fields = [reg.M_scale(M, tshape) for M in M_fields]
iM_fields = [reg.M_scale(M, tshape) for M in iM_fields]

Registration...
Motion Field scaling...


In [5]:
## Per-phase motion-warp operators (section header previously read "low rank")
device = 0

print('Prep...')
Ms, M0s = [], []
for i in range(nphase):
    # Warp built from this phase's motion field, wrapped with DLD on the chosen device.
    M = DLD(reg.interp_op(tshape, M_fields[i]), device=sp.Device(device))
    # Same construction with an all-zero displacement field (presumably an identity warp).
    M0 = DLD(reg.interp_op(tshape, np.zeros(tshape + (3,))), device=sp.Device(device))
    Ms.append(M)
    M0s.append(M0)
# NOTE(review): Ms / M0s are not referenced by the later cells shown here;
# the forward-model cell rebuilds its own per-phase warps.
Ms = Diags(Ms, oshape=(nphase,) + tshape, ishape=(nphase,) + tshape)
M0s = Diags(M0s, oshape=(nphase,) + tshape, ishape=(nphase,) + tshape)


Prep...


In [6]:
# Per-phase forward model W * F_i * S * M_i. M_i warps the shared image with
# phase i's motion field, S applies the coil sensitivities, F_i is the NUFFT
# on phase i's trajectory, and W multiplies by phase i's dcf.
PFTSMs = []
Is = []
for i in range(nphase):
    Is.append(sp.linop.Identity(tshape))
    FTs = NFTs((nCoil,) + tshape, traj[i, 0, 0, ...], device=sp.Device(device))
    M = reg.interp_op(tshape, M_fields[i])
    M = DLD(M, device=sp.Device(device))
    W = sp.linop.Multiply((nCoil, npe, nfe,), dcf[i, 0, 0, :, :, 0])
    FTSM = W * FTs * S * M
    PFTSMs.append(FTSM)
# Vstacks of identities presumably replicates the single image into nphase
# copies; Diags then applies each phase's model to its copy.
PFTSMs = Diags(PFTSMs, oshape=(nphase, nCoil, npe, nfe,), ishape=(nphase,) + tshape) \
    * Vstacks(Is, ishape=tshape, oshape=(nphase,) + tshape)

## precondition
print('Preconditioner calculation...')
# Scalar step normalization: mean magnitude of A^H A applied to an all-ones image.
tmp = PFTSMs.H * PFTSMs * np.complex64(np.ones(tshape))
L = np.mean(np.abs(tmp))
# Weighted k-space data (same dcf weighting as W), scaled by an empirical 1e4.
# Only the first nphase phases are used, so wdata matches PFTSMs.oshape even
# when nphase is smaller than the number of acquired cardiac phases.
wdata = data[:nphase, 0, :, :, :, 0] * dcf[:nphase, 0, :, :, :, 0] * 1e4

TV = sp.linop.FiniteDifference(PFTSMs.ishape, axes=(0, 1, 2))

Preconditioner calculation...


In [7]:
####### debug
print('TV dim:{}'.format(TV.oshape))
# NOTE(review): proxg is not used by the loop below. FiniteDifference is also not
# unitary, so UnitaryTransform would not yield the exact TV proximal operator.
proxg = sp.prox.UnitaryTransform(sp.prox.L1Reg(TV.oshape, lambda_TV), TV)

# Primal-dual iteration. X is not extrapolated between steps, so this is
# Arrow-Hurwicz-like rather than ADMM:
#   p: dual of the data term 0.5*||A X - wdata||^2 (prox of its conjugate, step sigma)
#   q: dual of the TV term, clipped elementwise to magnitude <= alpha
#   X: primal step of size tau, with the data-term gradient scaled by 1/L
print('Recon...')
alpha = np.max(np.abs(PFTSMs.H * wdata))
###### debug
print('alpha:{}'.format(alpha))
sigma = 0.4  # dual step size
tau = 0.4    # primal step size
X = np.zeros(tshape, dtype=np.complex64)
p = np.zeros_like(wdata)
X0 = np.zeros_like(X)
q = np.zeros((3,) + tshape, dtype=np.complex64)
for i in range(outer_iter):
    p = (p + sigma * (PFTSMs * X - wdata)) / (1 + sigma)
    q = q + sigma * TV * X
    q = q / (np.maximum(np.abs(q), alpha) / alpha)
    X0 = X
    X = X - tau * (1 / L * PFTSMs.H * p + lambda_TV * TV.H * q)
    # Relative change of X over this iteration (1.0 on the first one, since X0 starts at 0).
    print('outer iter:{}, res:{}'.format(i, np.linalg.norm(X - X0) / np.linalg.norm(X)))

# Write the reconstruction once, after the loop, instead of rewriting the full
# volume to disk on every iteration.
cfl.write_cfl(os.path.join(fname, 'imoco_test'), X)

TV dim:[3, 28, 256, 256]
Recon...
alpha:7237.257580239187
outer iter:0, res:1.0
outer iter:1, res:0.6088739426550389
outer iter:2, res:0.4014321874683403
outer iter:3, res:0.2694029279751769
outer iter:4, res:0.17768674081268618
outer iter:5, res:0.11276571037934194
outer iter:6, res:0.07049511088836011
outer iter:7, res:0.05138533672341825
outer iter:8, res:0.04987250497471724
outer iter:9, res:0.052204013020284046
outer iter:10, res:0.05111733399744803
outer iter:11, res:0.04593337897661696
outer iter:12, res:0.038520562168982204
outer iter:13, res:0.03126929965094747
outer iter:14, res:0.025822314392938483
outer iter:15, res:0.022332723427159754
outer iter:16, res:0.01988433819464816
outer iter:17, res:0.017644932049827807
outer iter:18, res:0.01537155437665405
outer iter:19, res:0.013205140656093527


In [52]:
# Re-export the cached ANTs motion fields (M and iM) in .cfl format.
M_fields_ants, iM_fields_ants = (
    np.load(os.path.join(fname, npy_name)) for npy_name in ('M_mr.npy', 'iM_mr.npy')
)
for cfl_name, field in (('M_ants', M_fields_ants), ('iM_ants', iM_fields_ants)):
    cfl.write_cfl(os.path.join(fname, cfl_name), field)