Joint MAPLE: Accelerated joint T1 and T2* mapping with scan-specific self-supervised networks

MRM link: https://onlinelibrary.wiley.com/doi/10.1002/mrm.29989

In [None]:
# imports
import torch
import torch.optim as optim
import numpy as np
import scipy.io as sio
import sys
import os
import copy

# change the current working directory to the Main.ipynb file directory
%cd /content/drive/MyDrive/Maple/02-Second_Project_JMAPLE/GitHub_Pack

# import py-modules
import utils
import parser_ops
import unrollednet
import relaxsig

# Select the compute device: the first CUDA GPU when one is available,
# otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
dev = "cuda:0" if use_cuda else "cpu"
if use_cuda:
    print('cuda is available')

# Run configuration provided by the local `parser_ops` module.
# NOTE(review): the object returned by get_parser() is used directly as
# `args` (no parse_args() call is visible here) — confirm it already holds
# the parsed options.
parser = parser_ops.get_parser()
args = parser

/content/drive/MyDrive/Maple/02-Second_Project_JMAPLE/GitHub_Pack
cuda is available


# MEMFA Reconstruction Block: Joint ZS-SSL

In [None]:
# Loading data: read the slice, normalise it, and carve out the validation split.

print('Loading data ...')

# multi-channel sensitivity maps -> (nrow, ncol, ncoil)
sens_maps = sio.loadmat(args.data_dir + 'cmap_slice_25.mat')['maps']
# k-space MEMFA data for a single slice -> (nrow, ncol, ncoil, nte, nfa)
kData = sio.loadmat(args.data_dir + 'raw_slice_25.mat')['kData']

# Normalise by the peak k-space magnitude; `scaler` is kept so the
# reconstruction can be rescaled afterwards.
scaler = np.abs(kData).max()
kspace_train = kData / scaler  # scaled k-space data for joint zs_ssl

# Record the data dimensions on `args` so other code can read them from there.
(args.nrow_GLOB, args.ncol_GLOB, args.ncoil_GLOB,
 args.nte_GLOB, args.nfa_GLOB) = kspace_train.shape
args.ncont_GLOB = args.nte_GLOB * args.nfa_GLOB

# Fold the (nte, nfa) axes into a single contrast axis.
kspace_train = kspace_train.reshape(
    (args.nrow_GLOB, args.ncol_GLOB, args.ncoil_GLOB, args.ncont_GLOB))

# uniform 4x4 k-space mask covering all k-space locations
kMask = np.single(np.load(args.data_dir + 'kMask_16x_Uniform_complementary.npy'))
# k-space mask for joint zs-ssl and contrast-specific, with contrasts folded
# the same way as the k-space data
original_mask = kMask.reshape((args.nrow_GLOB, args.ncol_GLOB, args.ncont_GLOB))

# Generate masks: split the acquired locations into a training part and a
# held-out validation part.
cv_trn_mask, cv_val_mask = utils.uniform_selection(
    kspace_train, original_mask, rho=args.rho_val)

# `remainder_mask` is the pool from which per-repetition train/loss masks
# are drawn; the validation mask is stored as complex64.
remainder_mask = np.copy(cv_trn_mask)
cv_val_mask = np.copy(np.complex64(cv_val_mask))

print('size of kspace: ', (1, *kspace_train.shape), ', sensitivity maps: ', sens_maps.shape, ', masks: ', original_mask.shape)

# Per-repetition buffers. Masks and SENSE-1 images have no coil axis;
# the reference k-space keeps it.
rep_image_shape = (args.num_reps, args.nrow_GLOB, args.ncol_GLOB, args.ncont_GLOB)
trn_mask = np.empty(rep_image_shape, dtype=np.complex64)
loss_mask = np.empty(rep_image_shape, dtype=np.complex64)

# train data empty arrays
nw_input = np.empty(rep_image_shape, dtype=np.complex64)
ref_kspace = np.empty((args.num_reps, args.nrow_GLOB, args.ncol_GLOB, args.ncoil_GLOB, args.ncont_GLOB), dtype=np.complex64)

# NOTE: the validation arrays are not preallocated — they are produced
# directly by the expressions below, so an np.empty() placeholder would
# only be thrown away.

print('create training & loss masks and generate network inputs... ')
# train data
for jj in range(args.num_reps):
    # Draw a fresh train/loss split of the remaining (non-validation) samples.
    trn_mask[jj, ...], loss_mask[jj, ...] = utils.uniform_selection(kspace_train, remainder_mask, rho=args.rho_train)

    # A singleton coil axis lets NumPy broadcast the mask over all coils,
    # giving the same product as an explicit np.tile without the copy.
    sub_kspace = kspace_train * trn_mask[jj][..., np.newaxis, :]
    ref_kspace[jj, ...] = kspace_train * loss_mask[jj][..., np.newaxis, :]
    nw_input[jj, ...] = utils.sense1(sub_kspace, sens_maps)

# validation data (leading axis of length 1 acts as the batch dimension)
nw_input_val = utils.sense1(kspace_train * cv_trn_mask[:, :, np.newaxis, :], sens_maps)[np.newaxis]

ref_kspace_val = (kspace_train * cv_val_mask[:, :, np.newaxis, :])[np.newaxis]

# prepare data for the training
sens_maps_original = np.copy(sens_maps) # keep a copy of the original sensitivity maps
sens_maps = np.transpose(sens_maps, (2, 0, 1)) #(nrow,ncol,ncoil) --> (ncoil,nrow,ncol)
sens_maps = torch.from_numpy(sens_maps).to(dev)

# Move the coil axis forward -> (batch, ncoil, nrow, ncol, ncont), then
# convert with the project's complex2real helper.
ref_kspace = utils.complex2real(np.transpose(ref_kspace, (0, 3, 1, 2, 4)))
nw_input = utils.complex2real(nw_input)
ref_kspace_val = utils.complex2real(np.transpose(ref_kspace_val, (0, 3, 1, 2, 4)))
nw_input_val = utils.complex2real(nw_input_val)

print('size of ref kspace: ', ref_kspace.shape, ', nw_input: ', nw_input.shape, ', sensitivity maps: ', sens_maps.shape, ', masks: ', trn_mask.shape)

# Number of batches per epoch. This is a floor, so repetitions that do not
# fill a whole batch are not counted.
total_batch = int(np.floor(np.float32(nw_input.shape[0]) / (args.batchSize)))

trn_mask = torch.from_numpy(trn_mask)
loss_mask = torch.from_numpy(loss_mask)
cv_val_mask = torch.from_numpy(cv_val_mask)
cv_trn_mask = torch.from_numpy(cv_trn_mask)

epoch, val_loss_tracker = 0, 0
trn_loss = 0
model = unrollednet.UnrolledNet(sens_maps).to(dev)

if args.transfer_learning: # transfer learning
    # map_location lets a checkpoint saved on a GPU be restored on whatever
    # device this run uses (including CPU-only machines).
    model.load_state_dict(torch.load(args.model_dir+'saved_weights.pth', map_location=dev))
    model.train()

total_loss = []
total_val_loss = []
# Weight applied to each loss term; not to be confused with `scaler`, the
# k-space normalisation factor.
scalar = torch.tensor([0.5], dtype=torch.float32).to(dev)
optimizer = torch.optim.Adam(model.parameters(), lr = args.zs_ssl_lr)
lowest_val_loss = np.inf
nw_input = torch.from_numpy(nw_input)

# Reference k-space is converted back with real2complex before becoming a tensor.
ref_kspace = torch.from_numpy(utils.real2complex(ref_kspace))
ref_kspace_val = torch.from_numpy(utils.real2complex(ref_kspace_val))

nw_input_val = torch.tensor(nw_input_val, dtype=torch.float32)

print("NUmber of trainable parameters is:")
params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(params)

# training
def _mixed_l1_l2_loss(pred_kspace, ref_kspace_batch, weight):
    """Return weight * normalised-L2 + weight * normalised-L1 of the k-space error.

    Both terms are computed over the whole tensor and divided by the
    corresponding norm of the reference, so the same expression is used for
    the training and the validation loss.
    """
    diff = pred_kspace - ref_kspace_batch
    l2_term = torch.mul(weight, torch.linalg.norm(diff)) / torch.linalg.norm(ref_kspace_batch)
    l1_term = torch.mul(weight, torch.linalg.norm(torch.flatten(diff), ord=1)) / torch.linalg.norm(torch.flatten(ref_kspace_batch), ord=1)
    return l2_term + l1_term


# Validation tensors never change between epochs: move them to the device once.
nw_input_val = nw_input_val.to(dev)
ref_kspace_val_dev = ref_kspace_val.to(dev)
val_trn_mask_dev = torch.unsqueeze(cv_trn_mask, 0).to(dev)
val_loss_mask_dev = torch.unsqueeze(cv_val_mask, 0).to(dev)

# Stop after args.epochs epochs, or once the validation loss has failed to
# improve for args.stop_training consecutive epochs.
while epoch < args.epochs and val_loss_tracker < args.stop_training:
    # Accumulate a plain float so the per-batch autograd graphs are released
    # as soon as each step finishes (and so an epoch with zero batches still
    # prints cleanly).
    epoch_loss = 0.0
    for j in range(total_batch):
        # Slicing clamps at the end of the tensor, so no special case is
        # needed for a short final batch.
        batch = slice(j * args.batchSize, (j + 1) * args.batchSize)
        current_batch = nw_input[batch].to(dev)
        current_kspace = ref_kspace[batch].to(dev)
        model.set_trn_mask(trn_mask[batch].to(dev))
        model.set_loss_mask(loss_mask[batch].to(dev))
        nw_output_image, nw_output_kspace, *_ = model.forward(current_batch, epoch)

        trn_loss = _mixed_l1_l2_loss(nw_output_kspace, current_kspace, scalar)

        optimizer.zero_grad()
        trn_loss.backward()
        optimizer.step()
        epoch_loss += trn_loss.item()
    total_loss.append(epoch_loss)
    print(f"Training loss in epoch {epoch} is {epoch_loss}")

    with torch.no_grad():
        model.set_trn_mask(val_trn_mask_dev)
        model.set_loss_mask(val_loss_mask_dev)
        val_output_kspace = model.forward(nw_input_val, epoch)[1]
        # Same weighting as the training loss. (Previously a misplaced
        # parenthesis weighted the L1 term by scalar**2 here.)
        val_loss = _mixed_l1_l2_loss(val_output_kspace, ref_kspace_val_dev, scalar)
        print(f"Validation loss in epoch {epoch} is {val_loss.item()}")
        total_val_loss.append(val_loss)
        if val_loss <= lowest_val_loss:
            # New best: checkpoint the weights and reset the patience counter.
            lowest_val_loss = val_loss
            val_loss_tracker = 0
            torch.save(model.state_dict(), args.model_dir+'saved_weights.pth')
        else:
            val_loss_tracker += 1
    epoch += 1

print("training of joint zs-ssl network is finished")
print("MEMFA reconstruction using trained network ...")

# MEMFA reconstruction using trained network

# The full acquisition mask is passed for both mask arguments of the network.
test_maskc = torch.tensor(original_mask[None]).to(dev)
model = unrollednet.UnrolledNet(sens_maps, test_maskc, test_maskc).to(dev)
# map_location keeps the checkpoint loadable when the current device differs
# from the one it was saved on.
model.load_state_dict(torch.load(args.model_dir+'saved_weights.pth', map_location=dev))
model.eval()

# Network input: SENSE-1 combination of all acquired samples. The singleton
# coil axis broadcasts the mask over every coil.
nw_input = utils.sense1(kspace_train * original_mask[..., np.newaxis, :], sens_maps_original)
nw_input = torch.from_numpy(utils.complex2real(nw_input)).type(torch.FloatTensor)
nw_input = torch.unsqueeze(nw_input, 0).to(dev)

# Inference only: skip autograd bookkeeping to save memory.
with torch.no_grad():
    zs_ssl_recon, _, _ = model.forward(nw_input)

# Undo the k-space normalisation and restore separate (nte, nfa) axes.
zs_ssl_recon = (utils.real2complex(zs_ssl_recon.squeeze())*scaler).detach().cpu().numpy()
zs_ssl_recon_final = np.reshape(zs_ssl_recon,(args.nrow_GLOB,args.ncol_GLOB,args.nte_GLOB,args.nfa_GLOB))
np.save(args.recon_dir+'joint_zs_ssl_recon_slice_25.npy',zs_ssl_recon_final)

print("joint zs-ssl reconstruction is finished")

#utils.display_all(abs(zs_ssl_recon_final))

# Parameter Mapping

In [None]:
# define necessary in-script functions
def T1_Mapping(lr, epochs, target, datc, map_init):
    """Fit a `relaxsig.T1_Signal` model to `target` with Adam.

    Args:
        lr: Adam learning rate.
        epochs: number of optimisation steps.
        target: signal the model output is compared to via `utils.MSEc`.
        datc: accepted for signature compatibility; not used in this function.
        map_init: initial values passed to `relaxsig.T1_Signal`.

    Returns:
        The fitted signal model, moved to `dev`.

    Reads the module-level `ang_use`, `TR`, `B1`, `phase_T1` and `dev`.
    """
    # NOTE(review): meaning of the 1e-1 constructor argument is defined in relaxsig — confirm.
    signal_model = relaxsig.T1_Signal(map_init, 1e-1)
    signal_model.set_constants(ang_use, TR, B1, phase_T1)
    signal_model = signal_model.to(dev)

    adam = optim.Adam(signal_model.parameters(), lr=lr)

    for _ in range(epochs):
        adam.zero_grad()
        fit_error = utils.MSEc(signal_model(), target)
        fit_error.backward()
        adam.step()

    print('T1 initialization finished after {} epochs.'.format(epochs))
    return signal_model

def T2_Mapping(lr, epochs, target, datc, map_init):
    """Fit a `relaxsig.T2_Signal` model to `target` with Adam.

    Args:
        lr: Adam learning rate.
        epochs: number of optimisation steps.
        target: signal the model output is compared to via `utils.MSEc`.
        datc: accepted for signature compatibility; not used in this function.
        map_init: initial values passed to `relaxsig.T2_Signal`.

    Returns:
        The fitted signal model, moved to `dev`.

    Reads the module-level `TE_use` and `dev`.
    """
    # NOTE(review): meaning of the 1e-10 constructor argument is defined in relaxsig — confirm.
    signal_model = relaxsig.T2_Signal(map_init, 1e-10)
    signal_model.set_te(TE_use)
    signal_model = signal_model.to(dev)

    adam = optim.Adam(signal_model.parameters(), lr=lr)

    for _ in range(epochs):
        adam.zero_grad()
        fit_error = utils.MSEc(signal_model(), target)
        fit_error.backward()
        adam.step()

    print('T2 initialization finished after {} epochs.'.format(epochs))
    return signal_model

def Joint_Mapping_Init(lr, epochs, target, map_init):
    """Fit a `relaxsig.Joint_Signal` model to `target` with Adam.

    Only the `utils.MSEc` mismatch between the model output and `target` is
    minimised. The loss is printed every 1000 iterations.

    Args:
        lr: Adam learning rate.
        epochs: number of optimisation steps.
        target: signal the model output is compared to.
        map_init: initial values passed to `relaxsig.Joint_Signal`.

    Returns:
        The fitted signal model, moved to `dev`.

    Reads the module-level `ang_use`, `TE_use`, `TR`, `B1`, `phase` and `dev`.
    """
    signal_model = relaxsig.Joint_Signal(map_init, 1e0)
    signal_model.set_constants(ang_use, TE_use, TR, B1, phase)
    signal_model = signal_model.to(dev)

    adam = optim.Adam(signal_model.parameters(), lr=lr)

    for step in range(1, epochs + 1):
        adam.zero_grad()
        fit_error = utils.MSEc(signal_model(), target)

        # Progress report; the value shown is the loss before this step's update.
        if step % 1000 == 0:
            print('iter', step, 'train loss: ' + str(fit_error.cpu().detach().numpy()))

        fit_error.backward()
        adam.step()

    print('joint initialization finished after {} epochs.'.format(epochs))
    return signal_model

def Joint_Mapping(lr, epochs, target, mu, datc, map_init, scaler):
    """Fit a `relaxsig.Joint_Signal` model using two loss terms.

    The minimised loss is `loss1 + mu * loss2`, where `loss1` is the
    `utils.MSEc` mismatch between the model output and `target`, and `loss2`
    is the `utils.MSEc` mismatch between `utils.Ac_Combined(...)` applied to
    the model output (with the module-level `maskc` and `coilc`) and `datc`.

    Every 100 iterations the current maps are compared with the module-level
    reference maps (`T1_ideal`, `T2_ideal`, `delB_ideal`, `ideal`) inside
    `brain_mask`. A copy of the model is kept whenever its T1 or T2 NRMSE is
    the best so far. Optimisation stops early once the NRMSE of the generated
    image improves by no more than `args.Tol` between checkpoints.

    Args:
        lr: Adam learning rate.
        epochs: maximum number of optimisation steps.
        target: signal used in `loss1`.
        mu: weight of `loss2`.
        datc: data used in `loss2`.
        map_init: initial values passed to `relaxsig.Joint_Signal`.
        scaler: factor applied to the model output before the image NRMSE.

    Returns:
        `(T1_best_model, T2_best_model, model_param)`: the copies with the
        lowest T1 and T2 NRMSE, and the model in its final state. If no
        checkpoint ever recorded a best model (fewer than 100 iterations, or
        no NRMSE comparison succeeded), copies of the final model are
        returned in their place.
    """
    model_param = relaxsig.Joint_Signal(map_init, 1e0)
    model_param.set_constants(ang_use, TE_use, TR, B1, phase)
    model_param = model_param.to(dev)

    optimizer = optim.Adam(model_param.parameters(), lr=lr)

    Gen_1 = np.inf  # image NRMSE at the previous checkpoint
    T1_nrmse_min, T2_nrmse_min = np.inf, np.inf
    # Start unset so the return below can tell whether a best model was found.
    T1_best_model, T2_best_model = None, None

    for k in range(epochs):
        optimizer.zero_grad()
        a = model_param()

        loss1 = utils.MSEc(a, target)
        loss2 = utils.MSEc(utils.Ac_Combined(a.permute(3, 2, 0, 1)[None], maskc, coilc), datc)
        loss = loss1 + mu * loss2

        if (k + 1) % 100 == 0:
            # The model stores rates (R1, R2); invert them to get T1 and T2.
            T1 = 1 / model_param.R1.detach().squeeze().cpu().numpy()
            T2 = 1 / model_param.R2.detach().squeeze().cpu().numpy()
            delB = model_param.delB.detach().squeeze().cpu().numpy()
            # Monitoring only: no autograd graph needed for this forward pass.
            with torch.no_grad():
                Gen = model_param().detach().squeeze().cpu().numpy() * scaler

            T1_nrmse = utils.NRMSE(T1, T1_ideal, brain_mask, up_sat=2500, dw_sat=0)
            T2_nrmse = utils.NRMSE(T2, T2_ideal, brain_mask, up_sat=100, dw_sat=0)
            delB_nrmse = utils.NRMSE(delB, delB_ideal, brain_mask)
            Gen_NRMSE = utils.NRMSE(Gen, ideal, brain_mask)

            if T1_nrmse <= T1_nrmse_min:
                T1_nrmse_min = T1_nrmse
                T1_best_model = copy.deepcopy(model_param)

            if T2_nrmse <= T2_nrmse_min:
                T2_nrmse_min = T2_nrmse
                T2_best_model = copy.deepcopy(model_param)

            print('------------------------------')
            print('iter', k+1, 'train loss: ' + str(loss.cpu().detach().numpy()))
            print('T1 NRMSE= {}'.format(T1_nrmse))
            print('T2 NRMSE= {}'.format(T2_nrmse))
            print('delB NRMSE= {}'.format(delB_nrmse))
            print('Gen NRMSE= {}'.format(Gen_NRMSE))
            print('Best T1 NRMSE:{}'.format(T1_nrmse_min))
            print('Best T2 NRMSE:{}'.format(T2_nrmse_min))

            # Early stop: image NRMSE no longer improving by more than args.Tol.
            if Gen_1 - Gen_NRMSE <= args.Tol:
                break
            Gen_1 = Gen_NRMSE

        loss.backward()
        optimizer.step()

    # Avoid an UnboundLocalError when no checkpoint recorded a best model.
    if T1_best_model is None:
        T1_best_model = copy.deepcopy(model_param)
    if T2_best_model is None:
        T2_best_model = copy.deepcopy(model_param)

    return T1_best_model, T2_best_model, model_param


# data preparation for parameter mapping

# Coil sensitivities with two trailing singleton axes, then reordered so the
# coil axis comes first and the two image axes come last.
coil_sens = sens_maps_original[..., None, None]
coilc = torch.tensor(np.transpose(coil_sens, (2, 4, 3, 0, 1)), dtype=torch.cfloat)

# contrast-specific k-space mask, with a singleton coil axis inserted
kMask = np.expand_dims(kMask, axis=2)

# sub-sampled k-space data (ncoil, nfa, nte, nrow, ncol)
datc = torch.tensor(np.transpose(kMask * kData, (2, 4, 3, 0, 1)), dtype=torch.cfloat)

maskc = torch.tensor(np.transpose(kMask, (2, 4, 3, 0, 1)))

coilc = coilc.to(dev)
datc = datc.to(dev)
maskc = maskc.to(dev)

brain_mask = np.load(args.data_dir + 'brain_mask_slice_25.npy')  # for error calculations
skull_mask = np.load(args.data_dir + 'skull_mask_slice_25.npy')  # for maps show

# sequence parameters: three flip angles (4, 10, 16 degrees), converted to radians
ang = np.arange(4, 17, 6)
ang = (ang * np.pi) / 180
ang_use = torch.tensor(ang[None, None, None, :]).float()

# sequence parameters: six echo times (3.6 to 28.6 in steps of 5)
TE = np.arange(3.6, 29, 5)
TE_use = torch.tensor(TE[None, None, :, None]).float()
TR = 34
TE_use = TE_use.to(dev)
ang_use = ang_use.to(dev)

# reconstruction of ideal image for calculation of the error for generated synthetic image:
# sensitivity-weighted coil combination of the inverse-transformed full k-space,
# with eps guarding the division where the coil sum is zero
coil_weighted_sum = (coil_sens.conj() * utils.ift2_np(kData)).sum(2)
coil_power = (abs(coil_sens) ** 2).sum(2) + np.finfo(float).eps
ideal = coil_weighted_sum / coil_power

B1 = sio.loadmat(args.data_dir + 'b1_slc_25.mat')['Output']  # B1 map
b1 = np.flipud(B1)  # flipped along the first axis (a negatively-strided view)
C1 = np.copy(b1)  # handling the negative stride: torch.tensor needs a real copy
B1 = torch.tensor(C1[..., None, None], dtype=torch.float).to(dev)

# golden fully-sampled maps as the ground truths
T1_ideal = np.load(args.golden_dir + 'T1_Golden_slc_25.npy')
T2_ideal = np.load(args.golden_dir + 'T2_Golden_slc_25.npy')
delB_ideal = np.load(args.golden_dir + 'delB_Golden_slc_25.npy')

# acceleration rate demonstration: total mask entries over entries that sum to the sampled count
acc = int(np.round(np.prod(kMask.shape) / kMask.sum()))
print('acc: ' + str(acc) + 'x')

# MEMFA reconstructed image from Recon_Block
recon = np.load(args.recon_dir + 'joint_zs_ssl_recon_slice_25.npy')

# parameter estimation settings: each target is normalised by its own peak magnitude
mapping_scaler = np.abs(recon).max()
mapping_scaler_T1 = np.abs(recon[..., 0, :]).max()  # TE = 3.6ms selected
mapping_scaler_T2 = np.abs(recon[..., :, 1]).max()  # FA = 10 degree selected

target = torch.tensor(recon, dtype=torch.cfloat)
target_T1 = target[..., 0, :] / mapping_scaler_T1  # TE = 3.6ms selected
target_T2 = target[..., :, 1] / mapping_scaler_T2  # FA = 10 degree selected
target = target / mapping_scaler


def _unit_phasor(signal):
    """Return exp(1j * angle(signal)) as a complex tensor on `dev`."""
    angle = torch.tensor(np.angle(signal), dtype=torch.cfloat).to(dev)
    return torch.exp(1j * angle)


phase = _unit_phasor(target)  # phase estimation for joint signal model
phase_T1 = _unit_phasor(target_T1)  # phase estimation for T1 signal model

target = target.to(dev)
target_T1 = target_T1.to(dev)
target_T2 = target_T2.to(dev)

# first step of the fast initialization with individual signal models and conventional parameter fitting
print('first step of initialization with individual signal models ..')

# Random starting maps for the T1 fit.
t1_map_init = [
    torch.rand(args.nrow_GLOB, args.ncol_GLOB, 1, 2),
    torch.rand(args.nrow_GLOB, args.ncol_GLOB, 1),
]
init_model_T1 = T1_Mapping(1e-3, 500, target_T1, datc[:, :, 0, ...], t1_map_init)

# Random starting maps for the T2 fit (drawn after the T1 fit, keeping the
# random-number sequence in the same order).
t2_map_init = [
    torch.rand(args.nrow_GLOB, args.ncol_GLOB, 1, 2),
    torch.rand(args.nrow_GLOB, args.ncol_GLOB, 1),
    torch.rand(args.nrow_GLOB, args.ncol_GLOB, 1),
]
# NOTE(review): datc[:, 0, ...] selects index 0 on the flip-angle axis while
# target_T2 uses index 1; T2_Mapping ignores its datc argument, so this has
# no effect today — confirm which index is intended.
init_model_T2 = T2_Mapping(1e-3, 3000, target_T2, datc[:, 0, ...], t2_map_init)

# second step of initialization with joint signal model and conventional parameter fitting
print('second step of initialization with joint signal model ..')
# Seed the joint model from the single-model fits: M0, R2 and delB come from
# the T2 fit, R1 from the T1 fit.
joint_map_init = [
    init_model_T2.M0.unsqueeze(3),
    init_model_T1.R1.unsqueeze(-1),
    init_model_T2.R2.unsqueeze(-1),
    init_model_T2.delB.unsqueeze(-1),
]
init_model_joint = Joint_Mapping_Init(args.init_LR, args.init_Epochs, target, joint_map_init)

# Joint MAPLE parameter mapping with loss1 and loss2
print('Joint MAPLE parameter mapping ..')
refined_map_init = [
    init_model_joint.M0,
    init_model_joint.R1,
    init_model_joint.R2,
    init_model_joint.delB,
]
# The k-space data is divided by the same factor that was used to normalise `target`.
T1_model, T2_model, model_comb = Joint_Mapping(
    args.JM_LR, args.JM_Epochs, target, args.MU,
    datc / mapping_scaler, refined_map_init, scaler=mapping_scaler)

# output maps: M0 and R1 are taken from the best-T1 model, R2 and delB from
# the best-T2 model; the synthetic image comes from the final model.
# (.cpu() is a no-op for tensors that already live on the CPU.)
M0_est = torch.view_as_complex(T1_model.M0.detach()).squeeze().cpu().numpy()
R1_est = T1_model.R1.detach().squeeze().cpu().numpy()
R2_est = T2_model.R2.detach().squeeze().cpu().numpy()
delB_est = T2_model.delB.detach().squeeze().cpu().numpy()
gen_img_Gen = model_comb().detach() * mapping_scaler

# Relaxation times are the reciprocals of the fitted rates.
T1_est = 1 / R1_est
T2_est = 1 / R2_est

print('Mapping Process Finished!')

# save estimated parameter maps
for file_stem, estimate in (
    ('M0_slc_25_R16cus', M0_est),
    ('T1_slc_25_R16cus', T1_est),
    ('T2_slc_25_R16cus', T2_est),
    ('delB_slc_25_R16cus', delB_est),
    ('gen_img_slc_25_R16cus', gen_img_Gen.cpu().numpy()),
):
    np.save(args.param_dir + file_stem, estimate)

utils.display_maps(M0_est, T1_est, T2_est, delB_est, skull_mask)
