In [1]:
import argparse 
import sigpy as sp
import scipy.ndimage as ndimage_c
import numpy as np

import sys
sys.path.append("./sigpy_e/")
import sigpy_e.cfl as cfl 

import sigpy_e.ext as ext
import sigpy_e.reg as reg
from sigpy_e.linop_e import NFTs,Diags,DLD,Vstacks
import sigpy.mri as mr

In [4]:
# Run configuration for the first (low-resolution, per-phase) reconstruction pass.
# NOTE(review): absolute, machine-specific path prefix -- consider making it configurable.
fname = '/data/Jiahao/lung_mri/MRI_Raw'   # path prefix; suffixes such as '_datam' are appended
device = 0                                # device id handed to sp.Device
res_scale, lambda_TV = 0.75, 0.05         # resolution scale (fraction of max k-space radius) / TV weight
outer_iter = 20                           # number of outer primal-dual iterations
fov_scale = (1,) * 3                      # per-axis scale factors applied to the trajectory

In [35]:
## data loading
# All inputs share the `fname` prefix: k-space samples, trajectory (real part kept),
# and the dcf array (presumably density-compensation weights -- used as W below).
data = cfl.read_cfl(fname + '_datam')
traj = cfl.read_cfl(fname + '_trajm').real
dcf = cfl.read_cfl(fname + '_dcf2m')


In [36]:
## readout truncation and FOV scaling
# Keep only readout samples whose k-space radius (measured on the first spoke of the
# first phase) is below res_scale * max radius; this sets the reconstruction resolution.
# Uses the configured res_scale instead of a hard-coded 0.75, matching the second pass.
nf_scale = res_scale
nf_arr = np.sqrt(np.sum(traj[0,0,0,0,:,:]**2,axis = 1))
nf_e = np.sum(nf_arr<np.max(nf_arr)*nf_scale)

# Scale every trajectory axis by its FOV factor in one in-place multiply.
# The factor array uses traj's own dtype so traj is not promoted to float64.
scale = fov_scale
assert len(scale) == traj.shape[-1], 'fov_scale needs one factor per trajectory axis'
traj *= np.asarray(scale, dtype=traj.dtype)

# NOTE(review): slicing [:nf_e] assumes the radius grows monotonically along the
# readout (center-out spokes), so the first nf_e samples are exactly those inside
# the cutoff -- confirm for this trajectory.
traj = traj[...,:nf_e,:]
data = data[...,:nf_e,:]
dcf = dcf[...,:nf_e,:]

nphase,nEcalib,nCoil,npe,nfe,_ = data.shape
tshape = (380, 256, 256)  # reconstruction matrix size (hard-coded)

In [37]:
# Flatten all phases into one long set of spokes for joint sensitivity estimation.
# The reshapes only succeed when the nEcalib axis and the trailing axis have size 1.
ksp = data.transpose(2, 1, 0, 3, 4, 5).reshape(nCoil, nphase * npe, nfe)
dcf2 = (dcf ** 2).transpose(2, 1, 0, 3, 4, 5).reshape(nphase * npe, nfe)
coord = traj.transpose(2, 1, 0, 3, 4, 5).reshape(nphase * npe, nfe, 3)

# Coil sensitivity maps on the tshape grid, wrapped as a pointwise-multiply operator.
mps = ext.jsens_calib(ksp, coord, dcf2, device=sp.Device(device), ishape=tshape)
S = sp.linop.Multiply(tshape, mps)

JsenseRecon:   0%|          | 0/10 [00:00<?, ?it/s]

In [38]:
### recon
# Per-phase forward operator W * NFTs * S: density weighting, the per-coil non-uniform
# FT from sigpy_e, and coil sensitivities. Diags combines them with one block per phase.
PFTSs = Diags(
    [
        sp.linop.Multiply((nCoil, npe, nfe), dcf[i, 0, 0, :, :, 0])
        * NFTs((nCoil,) + tshape, traj[i, 0, 0, ...], device=sp.Device(device))
        * S
        for i in range(nphase)
    ],
    oshape=(nphase, nCoil, npe, nfe),
    ishape=(nphase,) + tshape,
)

## preconditioner
# Weighted k-space target, plus a scalar normaliser L = mean |A^H A 1|.
wdata = data[:, 0, ..., 0] * dcf[:, 0, ..., 0]
tmp = PFTSs.H * PFTSs * np.ones((nphase,) + tshape, dtype=np.complex64)
L = np.abs(tmp).mean()

In [39]:
## reconstruction
# Primal-dual loop for the per-phase low-resolution images:
#   Y  -- dual variable for the data term (same shape as wdata)
#   q2 -- image estimate, shape (nphase,) + tshape
sigma = 0.4   # dual step size
tau = 0.4     # primal step size
q2 = np.zeros((nphase,) + tshape, dtype=np.complex64)
q20 = np.zeros_like(q2)
Y = np.zeros_like(wdata)
for i in range(outer_iter):
    # Dual step: prox of the conjugate of 0.5*||(1/L) A q2 - wdata||^2.
    Y = (Y + sigma * (1 / L * PFTSs * q2 - wdata)) / (1 + sigma)

    # Primal step: move along -A^H Y, then apply ext.TVt_prox with weight lambda_TV.
    q20 = q2
    q2 = np.complex64(ext.TVt_prox(q2 - tau * PFTSs.H * Y, lambda_TV))

    rel_change = np.linalg.norm(q2 - q20) / np.linalg.norm(q2)
    print(f'outer iter:{i}, res:{rel_change}')

    # NOTE(review): the full image stack is rewritten on every iteration (checkpointing);
    # move this after the loop if intermediate snapshots are not needed.
    cfl.write_cfl(fname + '_mrL', q2)

outer iter:0, res:1.0
outer iter:1, res:0.6152777075767517
outer iter:2, res:0.41512995958328247
outer iter:3, res:0.28948768973350525
outer iter:4, res:0.201780304312706
outer iter:5, res:0.13650698959827423
outer iter:6, res:0.08614236116409302
outer iter:7, res:0.04683448374271393
outer iter:8, res:0.017114846035838127
outer iter:9, res:0.010728595778346062
outer iter:10, res:0.025127502158284187
outer iter:11, res:0.03456598147749901
outer iter:12, res:0.03831390663981438
outer iter:13, res:0.03708263486623764
outer iter:14, res:0.0320313386619091
outer iter:15, res:0.02461172454059124
outer iter:16, res:0.016355693340301514
outer iter:17, res:0.008692963048815727
outer iter:18, res:0.003517814911901951
outer iter:19, res:0.004671511240303516


In [2]:
# Parameters for the second (motion-compensated) reconstruction pass.
# NOTE(review): the earlier config cell also assigns res_scale (0.75); whichever cell
# ran last wins, so the value actually used depends on execution order.
res_scale = 1                   # fraction of the maximum k-space radius kept (1 = full readout)
outer_iter, iner_iter = 20, 15  # iner_iter is not referenced in this chunk
n_ref = -1                      # index of the reference phase given to reg.ANTsReg (-1 = last)
reg_flag = 1                    # 1: run registration and save fields; otherwise load saved fields

In [5]:
## data loading
# Reload the raw inputs so this pass starts from untruncated, unscaled arrays.
data = cfl.read_cfl(fname+'_datam')
traj = np.real(cfl.read_cfl(fname+'_trajm'))
dcf = cfl.read_cfl(fname+'_dcf2m')

# Keep only readout samples whose k-space radius (measured on the first spoke of the
# first phase) is below res_scale * max radius.
nf_scale = res_scale
nf_arr = np.sqrt(np.sum(traj[0,0,0,0,:,:]**2,axis = 1))
nf_e = np.sum(nf_arr<np.max(nf_arr)*nf_scale)

# Scale every trajectory axis by its FOV factor in one in-place multiply.
# The factor array uses traj's own dtype so traj is not promoted to float64.
scale = fov_scale
assert len(scale) == traj.shape[-1], 'fov_scale needs one factor per trajectory axis'
traj *= np.asarray(scale, dtype=traj.dtype)

# NOTE(review): slicing [:nf_e] assumes center-out readouts (radius increasing along
# the readout) -- confirm for this trajectory.
traj = traj[...,:nf_e,:]
data = data[...,:nf_e,:]
dcf = dcf[...,:nf_e,:]

nphase,nEcalib,nCoil,npe,nfe,_ = data.shape
tshape = (380,256,256)  # reconstruction matrix size (hard-coded)

In [6]:
## calibration
print('Calibration...')
# Flatten all phases into one long set of spokes for joint sensitivity estimation.
# The reshapes only succeed when the nEcalib axis and the trailing axis have size 1.
ksp = np.reshape(np.transpose(data,(2,1,0,3,4,5)),(nCoil,nphase*npe,nfe))
dcf2 = np.reshape(np.transpose(dcf**2,(2,1,0,3,4,5)),(nphase*npe,nfe))
coord = np.reshape(np.transpose(traj,(2,1,0,3,4,5)),(nphase*npe,nfe,3))

mps = ext.jsens_calib(ksp,coord,dcf2,device = sp.Device(device),ishape = tshape)
S = sp.linop.Multiply(tshape, mps)

# Per-phase images written by the first (low-resolution) reconstruction pass.
imgL = cfl.read_cfl(fname+'_mrL')
imgL = np.squeeze(imgL)

## registration
print('Registration...')
M_fields = []
iM_fields = []
# Use equality, not identity: `reg_flag is 1` only worked through small-int caching
# and triggers "SyntaxWarning: 'is' with a literal".
if reg_flag == 1:
    # Register each phase magnitude against the reference phase imgL[n_ref].
    # The reference magnitude is loop-invariant, so compute it once.
    ref_mag = np.abs(imgL[n_ref])
    for i in range(nphase):
        M_field, iM_field = reg.ANTsReg(ref_mag, np.abs(imgL[i]))
        M_fields.append(M_field)
        iM_fields.append(iM_field)
    M_fields = np.asarray(M_fields)
    iM_fields = np.asarray(iM_fields)
    # Cache the fields so a later run with reg_flag != 1 can skip registration.
    np.save(fname+'_M_mr.npy',M_fields)
    np.save(fname+'_iM_mr.npy',iM_fields)
else:
    M_fields = np.load(fname+'_M_mr.npy')
    iM_fields = np.load(fname+'_iM_mr.npy')

# Split the stacked per-phase arrays back into lists (one entry per phase).
iM_fields = list(iM_fields)
M_fields = list(M_fields)

# Scale each motion field to the reconstruction grid with reg.M_scale.
# NOTE(review): the original carried a "TODO scale M_field" marker here; confirm
# nothing beyond this M_scale call is still outstanding.
print('Motion Field scaling...')
M_fields = [reg.M_scale(M,tshape) for M in M_fields]
iM_fields = [reg.M_scale(M,tshape) for M in iM_fields]


"is" with a literal. Did you mean "=="?


"is" with a literal. Did you mean "=="?



Calibration...


JsenseRecon:   0%|          | 0/10 [00:00<?, ?it/s]

Registration...



"is" with a literal. Did you mean "=="?



Motion Field scaling...


In [7]:
## low rank
print('Prep...')
# Per-phase motion operators (interpolation with M_fields[i]) and zero-displacement
# operators, each wrapped with DLD for the compute device.
Ms = []
M0s = []
for i in range(nphase):
    # Alternative kept for reference: reg.interp_op(tshape,iM_fields[i],M_fields[i])
    M = reg.interp_op(tshape,M_fields[i])
    M0 = reg.interp_op(tshape,np.zeros(tshape+(3,)))
    M = DLD(M,device=sp.Device(device))
    M0 = DLD(M0,device=sp.Device(device))
    Ms.append(M)
    M0s.append(M0)
# Keep the per-phase operators so the forward model below reuses them instead of
# rebuilding every interpolation operator a second time.
M_per_phase = list(Ms)
Ms = Diags(Ms,oshape=(nphase,)+tshape,ishape=(nphase,)+tshape)
M0s = Diags(M0s,oshape=(nphase,)+tshape,ishape=(nphase,)+tshape)

# Motion-compensated forward model: a single image (tshape) is replicated to every
# phase (Vstacks of identities), interpolated with that phase's motion field, then
# mapped through coil sensitivities S, the NUFFT and density weighting W.
PFTSMs = []
Is = []
for i in range(nphase):
    Is.append(sp.linop.Identity(tshape))
    FTs = NFTs((nCoil,)+tshape,traj[i,0,0,...],device=sp.Device(device))
    W = sp.linop.Multiply((nCoil,npe,nfe,),dcf[i,0,0,:,:,0])
    FTSM = W*FTs*S*M_per_phase[i]
    PFTSMs.append(FTSM)
PFTSMs = Diags(PFTSMs,oshape=(nphase,nCoil,npe,nfe,),ishape=(nphase,)+tshape)*Vstacks(Is,ishape=tshape,oshape=(nphase,)+tshape)

## precondition
print('Preconditioner calculation...')
tmp = PFTSMs.H*PFTSMs*np.complex64(np.ones(tshape))
L=np.mean(np.abs(tmp))
# NOTE(review): magic 1e4 scaling of the data term. alpha in the recon cell is
# derived from wdata, so it scales with this factor too.
wdata = data[:,0,:,:,:,0]*dcf[:,0,:,:,:,0]*1e4

TV = sp.linop.FiniteDifference(PFTSMs.ishape,axes = (0,1,2))

Prep...
Preconditioner calculation...


In [8]:
####### debug
print('TV dim:{}'.format(TV.oshape))
# NOTE(review): proxg is not used by the loop below. sp.prox.UnitaryTransform is only
# exact for a unitary operator, and FiniteDifference is not unitary.
proxg = sp.prox.UnitaryTransform(sp.prox.L1Reg(TV.oshape, lambda_TV), TV)

# Primal-dual (Arrow-Hurwicz / PDHG-style, no extrapolation) iteration -- not ADMM.
#   p -- dual variable for the data term (shape of wdata)
#   q -- dual variable for the TV term (shape TV.oshape)
#   X -- single motion-compensated image (shape tshape)
print('Recon...')
alpha = np.max(np.abs(PFTSMs.H*wdata))
###### debug
print('alpha:{}'.format(alpha))
sigma = 0.4   # dual step size
tau = 0.4     # primal step size
X = np.zeros(tshape,dtype=np.complex64)
p = np.zeros_like(wdata)
X0 = np.zeros_like(X)
q = np.zeros((3,)+tshape,dtype=np.complex64)
for i in range(outer_iter):
    # Data dual: prox of the conjugate of 0.5*||PFTSMs X - wdata||^2.
    p = (p + sigma*(PFTSMs*X-wdata))/(1+sigma)
    # TV dual: ascent step, then elementwise projection onto |q| <= alpha.
    q = (q + sigma*TV*X)
    q = q/(np.maximum(np.abs(q),alpha)/alpha)
    # Primal: descent step using the adjoints of both operators.
    X0 = X
    X = X-tau*(1/L*PFTSMs.H*p + lambda_TV*TV.H*q)
    print('outer iter:{}, res:{}'.format(i,np.linalg.norm(X-X0)/np.linalg.norm(X)))

    # NOTE(review): rewritten on every iteration (checkpointing); move after the loop
    # if intermediate snapshots are not needed.
    cfl.write_cfl(fname+'_imoco_new', X)

TV dim:[3, 380, 256, 256]
Recon...
alpha:8.979140281677246
outer iter:0, res:1.0
outer iter:1, res:0.6154823899269104
outer iter:2, res:0.41538119316101074
outer iter:3, res:0.28970643877983093
outer iter:4, res:0.20190787315368652
outer iter:5, res:0.13648439943790436
outer iter:6, res:0.08590278029441833
outer iter:7, res:0.04626091569662094
outer iter:8, res:0.015832984820008278
outer iter:9, res:0.010160350240767002
outer iter:10, res:0.02571268565952778
outer iter:11, res:0.035586945712566376
outer iter:12, res:0.03962282836437225
outer iter:13, res:0.03857368603348732
outer iter:14, res:0.033607207238674164
outer iter:15, res:0.02619597315788269
outer iter:16, res:0.01792795956134796
outer iter:17, res:0.010354213416576385
outer iter:18, res:0.005410945042967796
outer iter:19, res:0.005533785559237003


In [47]:
# Debug inspection: shape of the first scaled motion field.
M_fields[0].shape

(380, 256, 256, 3)

In [48]:
# Debug inspection: shape of the weighted k-space target used by the recon loop.
wdata.shape

(2, 8, 38514, 199)

In [49]:
# Debug inspection: shape of the k-space array after readout truncation.
data.shape

(2, 1, 8, 38514, 199, 1)

In [52]:
# Debug inspection: output shape of the motion-compensated forward operator;
# it must match wdata.shape for the data-term update in the recon loop.
PFTSMs.oshape

[2, 8, 38514, 199]

In [16]:
# Re-export the cached ANTs registration fields (forward and inverse) via cfl.write_cfl.
# NOTE(review): these output suffixes lack the '_' separator used by every other suffix
# in this notebook (e.g. '_M_mr.npy'); kept unchanged so existing outputs don't move.
M_fields_ants = np.load(fname + '_M_mr.npy')
iM_fields_ants = np.load(fname + '_iM_mr.npy')
for suffix, field in (('M_ants', M_fields_ants), ('iM_ants', iM_fields_ants)):
    cfl.write_cfl(fname + suffix, field)