In [4]:
import time
import random
import numpy as np
import pandas as pd
import sys
import pickle
import h5py
import copy
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import torch
import learn2learn as l2l
from tqdm import tqdm
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import StepLR, CosineAnnealingLR
from functions.fftc import fft2c_new as fft2c
from functions.fftc import ifft2c_new as ifft2c
from functions.math import complex_abs, complex_mul, complex_conj
# The corase reconstruction is the rss of the zerofilled multi-coil kspaces
# after inverse FT.
from functions.data.transforms import to_tensor, center_crop, scale_sensmap, normalize_separate_over_ch, rss_torch
# Import a torch.utils.data.Dataset class that takes a list of data examples, a path to those examples
# a data transform and outputs a torch dataset.
from functions.data.mri_dataset import SliceDataset
# Unet architecture as nn.Module
from functions.models.unet import Unet
# Function that returns a MaskFunc object either for generatig random or equispaced masks
from functions.data.subsample import create_mask_for_mask_type
# Implementation of SSIMLoss
from functions.training.losses import SSIMLoss
from functions.helper import evaluate2c_imagepair
### after you install bart 0.7.00 from https://mrirecon.github.io/bart/, import it as follows
sys.path.insert(0,'/cheng/bart-0.7.00/python/')
os.environ['TOOLBOX_PATH'] = "/cheng/bart-0.7.00/"
import bart


plt.rcParams.update({"text.usetex": True, "font.family": "serif", "font.serif": ["Computer Modern Roman"]})

colors = ['b','r','k','g','m','c','tab:brown','tab:orange','tab:pink','tab:gray','tab:olive','tab:purple']

markers = ["v","o","^","1","*",">","d","<","s","P","X"]
FONTSIZE = 22

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# seed
SEED = 1
random.seed(SEED)
np.random.seed(SEED)
torch.cuda.manual_seed(SEED)
torch.manual_seed(SEED)


<torch._C.Generator at 0x7fc0301196f0>

### Load the data

In [11]:
### data path ###
mypath = '/media/hdd1/stanford_fastmri_format/'
sensmap_path = '/cheng/metaMRI/metaMRI/data_dict_temp/stanford_sensmap/'
### data ###
with open('/cheng/metaMRI/ttt_for_deep_learning_cs/unet/train_data/stanford_val','rb') as fn:
    val_slices = pickle.load(fn)

with open('/cheng/metaMRI/ttt_for_deep_learning_cs/unet/test_data/dataset_shift/mask2d','rb') as fn:
    mask2d = pickle.load(fn)
mask = torch.tensor(mask2d[0]).unsqueeze(0).unsqueeze(0).unsqueeze(-1)
mask = mask.to(device)


In [24]:
checkpoint_path = '/cheng/metaMRI/metaMRI/save/E_sensmap_joint(l1_1e-5)Q_T300_300epoch_stanford/E_sensmap_joint(l1_1e-5)Q_T300_300epoch_stanford_E150.pth'
model = Unet(in_chans=2, out_chans=2, chans=64, num_pool_layers=4, drop_prob=0.0)
model.load_state_dict(torch.load(checkpoint_path))
model = model.to(device)

In [7]:
### mask partitioning for self validation: we define a left-out set for early-stopping
where_ones = np.array(np.where(mask2d==1))
m = len( where_ones[0] ) 
self_val_ids = where_ones[:,np.random.randint(0,m,m//20)] # set 5% of the pixels to zero (leave-outs for self validation)

### Test-time training

In [10]:
background_flip = False
eval_with_binary_mask = True
gt_norm = False

In [25]:
# best_loss_l1_history=[]
# best_loss_l1_index_history=[]
early_loss_ssim_history = []
early_loss_ssim_epoch_history = []
ssim_fct = SSIMLoss()
l1_loss = torch.nn.L1Loss(reduction='sum')
adapt_lr = 0.00001
TTT_epoch = 2500
window_size = 5
# each data point TTT

for iter, slice_file_name in enumerate(val_slices): 
    print('Testing sample: ', iter+1)
    slice = slice_file_name['slice']  # slice: 69
    file_name = slice_file_name['filename']  # file_name: 'ge9.h5'
    filename_no_extension, file_extension = os.path.splitext(file_name)  # filename: 'ge9'. file_extension: '.h5'
    ### load the training k-space
    f = h5py.File(mypath + file_name, 'r')
    kspace = f['kspace'][slice] # ground truth k-space: [8,320,320]
    kspace = to_tensor(kspace).to(device)   # complex k-space [8,320,320,2]

    ### load the sensmap
    # your name style and path
    smap_fname = str(filename_no_extension) + '_smaps_slice' + str(slice) + '.h5'
    with h5py.File(sensmap_path + smap_fname, "r") as hf:
        sens_maps = hf["sens_maps"][()] #np.array of shape coils,height,width with complex valued entries
    sens_maps = to_tensor(sens_maps)
    sens_maps_conj = complex_conj(sens_maps)
    sens_maps = sens_maps.to(device)
    sens_maps_conj = sens_maps_conj.to(device)

    binary_background_mask = torch.round(torch.sum(complex_mul(sens_maps_conj,sens_maps),0)[:,:,0:1])
    binary_background_mask = torch.moveaxis( binary_background_mask , -1, 0 ) .to(device)

    # input k space
    input_kspace = kspace * mask + 0.0
    # scale factor

    # gt image by normalized kspace: x
    target_image_1c = complex_abs(complex_mul(ifft2c(kspace), sens_maps_conj).sum(dim=0, keepdim=False)).unsqueeze(0)
    crop_size = torch.Size([min(target_image_1c.shape[-2:]), min(target_image_1c.shape[-2:])])
    if eval_with_binary_mask == True: 
        binary_target_image_1c = (target_image_1c.unsqueeze(0) * binary_background_mask).squeeze(0)
        # center crop for SSIM
        crop_target_image = center_crop( binary_target_image_1c, crop_size )
    else: 
        crop_target_image = center_crop( target_image_1c, crop_size )
    
    if gt_norm == True: 
        std_crop_target_image = crop_target_image.std()
        mean_crop_target_image = crop_target_image.mean()
    else:
        pass

    # A†y
    # train_inputs = complex_mul(train_inputs, sens_maps_conj).sum(dim=0, keepdim=False) #shape: height,width,2
    train_inputs = complex_mul(ifft2c(input_kspace), sens_maps_conj).sum(dim=0, keepdim=False)
    train_inputs = torch.moveaxis( train_inputs , -1, 0 )
    train_inputs, mean, std = normalize_separate_over_ch(train_inputs, eps=1e-11)
    
    loss_l1_history = []
    loss_ssim_history = []
    self_loss_history = []
    val_error_history = []
    
    model_ttt = copy.deepcopy(model)
    optimizer = torch.optim.Adam(model_ttt.parameters(),lr=adapt_lr)

    for iteration in range(TTT_epoch): 
        ###### training ######
        # fθ(A†y)
        model_output = model_ttt(train_inputs.unsqueeze(0))    # [1, 2, height, width]       
        model_output = model_output.squeeze(0) * std + mean
        model_output = model_output.unsqueeze(0)
        model_output = torch.moveaxis(model_output, 1, -1 )    #[1, height, width, 2]

        # S fθ(A†y) [coils, height, width, channel]
        if background_flip == True: 
            output_sens_image = torch.zeros(sens_maps.shape).to(device) 
            for j,s in enumerate(sens_maps):
                ss = s.clone()
                ss[torch.abs(ss)==0.0] = torch.abs(ss).max()
                output_sens_image[j,:,:,0] = model_output[0,:,:,0] * ss[:,:,0] - model_output[0,:,:,1] * ss[:,:,1]
                output_sens_image[j,:,:,1] = model_output[0,:,:,0] * ss[:,:,1] + model_output[0,:,:,1] * ss[:,:,0]

        else: 
            output_sens_image = complex_mul(model_output, sens_maps)
        # FS fθ(A†y)
        Fimg = fft2c(output_sens_image)
        # MFS fθ(A†y) = A fθ(A†y)
        Fimg_forward = Fimg * mask

        ### separate self-validation and test-time training measurements
        Fimg_forward_train = Fimg_forward.clone()
        Fimg_forward_train[:,self_val_ids[0],self_val_ids[1],:] = 0
        input_kspace_most = input_kspace.clone()
        input_kspace_most[:,self_val_ids[0],self_val_ids[1],:] = 0

        # consistency loss [y, Afθ(A†y)]
        loss_self = l1_loss(Fimg_forward_train, input_kspace_most) / torch.sum(torch.abs(input_kspace_most))

        ### compute the self-validation error
        Fimg_forward_val= Fimg_forward[:,self_val_ids[0],self_val_ids[1],:]
        input_kspace_val = input_kspace[:,self_val_ids[0],self_val_ids[1],:]
        val_error = l1_loss(Fimg_forward_val,input_kspace_val).item()
        
        optimizer.zero_grad()
        loss_self.backward()
        optimizer.step()
        #train_loss += loss.item()
        #print('TTT loss: ',loss_self.item())
        self_loss_history.append(loss_self.item())

        ###### evaluation ######
        output_image_1c = complex_abs(model_output).unsqueeze(0)    # [2, height, width] -> [1, height, width]
        if eval_with_binary_mask == True: 
            binary_mask_output_image_1c = (output_image_1c.unsqueeze(0) * binary_background_mask).squeeze(0)
            # center crop for SSIM
            crop_output_image = center_crop( binary_mask_output_image_1c, crop_size )
        else: 
            # center crop for SSIM
            crop_output_image = center_crop( output_image_1c, crop_size )

        if gt_norm == True: 
            # normalization
            crop_output_image = (crop_output_image - crop_output_image.mean()) / crop_output_image.std()
            crop_output_image *= std_crop_target_image
            crop_output_image += mean_crop_target_image
        else:
            pass
        # SSIM = 1 - loss
        loss_ssim = 1 - ssim_fct(crop_output_image, crop_target_image, data_range = crop_target_image.max().unsqueeze(0)).item()
        print(f"Epoch {iteration+1}/{TTT_epoch}, Loss: {loss_self.item():.4f}", end='\r')
        print(f"Epoch {iteration+1}/{TTT_epoch}, SSIM: {loss_ssim:.4f}", end='\r')

        # early stop
        if iteration > 3*window_size:
            if np.mean(val_error_history[-window_size:]) > np.mean(val_error_history[-2*window_size:-window_size]): 
                print('\nAutomatic early stopping activated.')
                break

        #loss_l1_history.append(loss_l1)
        val_error_history.append(val_error)


    early_loss_ssim_history.append(loss_ssim)
    early_loss_ssim_epoch_history.append(iteration+1)
    # best results for an example
    # best_loss_l1 = min(loss_l1_history)
    # print('The best L1: ', best_loss_l1)
    # best_loss_l1_epoch = np.argmin(loss_l1_history)
    # print('The best L1 epoch: ', best_loss_l1_epoch)
    # best_loss_ssim = max(loss_ssim_history)
    # print('The best SSIM through training process:', best_loss_ssim)
    # best_loss_ssim_epoch = np.argmax(loss_ssim_history)
    # print('The best SSIM epoch: ', best_loss_ssim_epoch)

    # # best_loss_l1_index_history.append(best_loss_l1_epoch)
    # best_loss_ssim_index_history.append(best_loss_ssim_epoch)
    # # best_loss_l1_history.append(best_loss_l1)
    # best_loss_ssim_history.append(best_loss_ssim)


# print("Testing average L1 loss: ", sum(best_loss_l1_history) / len(best_loss_l1_history))
# print("Testing average L1 loss epoch: ", sum(best_loss_l1_index_history) / len(best_loss_l1_index_history))
print("Testing average SSIM loss: ", sum(early_loss_ssim_history) / len(early_loss_ssim_history))
print("Testing average SSIM loss epoch: ", sum(early_loss_ssim_epoch_history) / len(early_loss_ssim_epoch_history))

Testing sample:  1
Epoch 18/2500, SSIM: 0.8678
Automatic early stopping activated.
Testing sample:  2
Epoch 17/2500, SSIM: 0.8027
Automatic early stopping activated.
Testing sample:  3
Epoch 17/2500, SSIM: 0.8466
Automatic early stopping activated.
Testing sample:  4
Epoch 17/2500, SSIM: 0.8836
Automatic early stopping activated.
Testing sample:  5
Epoch 17/2500, SSIM: 0.8504
Automatic early stopping activated.
Testing sample:  6
Epoch 20/2500, SSIM: 0.9257
Automatic early stopping activated.
Testing sample:  7
Epoch 18/2500, SSIM: 0.9267
Automatic early stopping activated.
Testing sample:  8
Epoch 17/2500, SSIM: 0.8796
Automatic early stopping activated.
Testing sample:  9
Epoch 17/2500, SSIM: 0.8624
Automatic early stopping activated.
Testing sample:  10
Epoch 17/2500, SSIM: 0.8744
Automatic early stopping activated.
Testing sample:  11
Epoch 17/2500, SSIM: 0.9008
Automatic early stopping activated.
Testing sample:  12
Epoch 17/2500, SSIM: 0.8888
Automatic early stopping activated.
T

Without TTT

In [28]:
COIL = 'sensmap'
gt_norm = False

loss_ssim_history_=[]
ssim_fct = SSIMLoss()
l1_loss = torch.nn.L1Loss(reduction='sum')

for iter, slice_file_name in tqdm(enumerate(val_slices)): 
    slice = slice_file_name['slice']  # slice: 69
    file_name = slice_file_name['filename']  # file_name: 'ge9.h5'
    filename_no_extension, file_extension = os.path.splitext(file_name)  # filename: 'ge9'. file_extension: '.h5'
    ### load the training k-space
    f = h5py.File(mypath + file_name, 'r')
    kspace = f['kspace'][slice] # ground truth k-space: [8,320,320]
    kspace = to_tensor(kspace).to(device)   # complex k-space [8,320,320,2]

    ### load the sensmap
    # your name style and path
    smap_fname = str(filename_no_extension) + '_smaps_slice' + str(slice) + '.h5'
    with h5py.File(sensmap_path + smap_fname, "r") as hf:
        sens_maps = hf["sens_maps"][()] #np.array of shape coils,height,width with complex valued entries
    sens_maps = to_tensor(sens_maps)
    sens_maps_conj = complex_conj(sens_maps)
    sens_maps = sens_maps.to(device)
    sens_maps_conj = sens_maps_conj.to(device)

    binary_background_mask = torch.round(torch.sum(complex_mul(sens_maps_conj,sens_maps),0)[:,:,0:1])
    binary_background_mask = torch.moveaxis( binary_background_mask , -1, 0 ) .to(device)

    # input k space
    input_kspace = kspace * mask + 0.0

    # gt image: x
    if COIL == 'rss':
        target_image_1c = rss_torch(complex_abs(ifft2c(kspace))).unsqueeze(0)
    elif COIL == 'sensmap':
        target_image_1c = complex_abs(complex_mul(ifft2c(kspace), sens_maps_conj).sum(dim=0, keepdim=False)).unsqueeze(0)
    # center crop for SSIM
    crop_size = torch.Size([min(target_image_1c.shape[-2:]), min(target_image_1c.shape[-2:])])
    binary_target_image_1c = (target_image_1c.unsqueeze(0) * binary_background_mask).squeeze(0)
    # center crop for SSIM
    crop_target_image = center_crop( binary_target_image_1c, crop_size )

    if gt_norm: 
        std_crop_target_image = crop_target_image.std()
        mean_crop_target_image = crop_target_image.mean()

    # A†y
    if COIL == 'rss':
        train_inputs = rss_torch(ifft2c(input_kspace))
    elif COIL == 'sensmap':    
        train_inputs = complex_mul(ifft2c(input_kspace), sens_maps_conj).sum(dim=0, keepdim=False)
    # [height, width, 2]
    train_inputs = torch.moveaxis( train_inputs , -1, 0 ) # move complex channels to channel dimension
    # [2, height, width]
    train_inputs, mean, std = normalize_separate_over_ch(train_inputs, eps=1e-11)
    
    # fθ(A†y)
    train_outputs = model(train_inputs.unsqueeze(0)) # [1, 2, height, width]
    train_outputs = train_outputs.squeeze(0) * std + mean
    train_outputs_1c = complex_abs(torch.moveaxis(train_outputs.unsqueeze(0), 1, -1 )).unsqueeze(0) # [1, height, width]

    binary_mask_output_image_1c = (train_outputs_1c.unsqueeze(0) * binary_background_mask).squeeze(0)
    # center crop for SSIM
    crop_output_image = center_crop( binary_mask_output_image_1c, crop_size )
        
    # normalization
    if gt_norm: 
        crop_output_image = (crop_output_image - crop_output_image.mean()) / crop_output_image.std()
        crop_output_image *= std_crop_target_image
        crop_output_image += mean_crop_target_image


    # loss_ssim = 1 - ssim_fct(binary_mask_output_image_1c, crop_target_image, data_range = crop_target_image.max().unsqueeze(0)).item()
    loss_ssim = 1 - ssim_fct(crop_output_image, crop_target_image, data_range = crop_target_image.max().unsqueeze(0)).item()

    # loss_l1_history_.append(loss_l1)
    loss_ssim_history_.append(loss_ssim)

0it [00:00, ?it/s]

88it [00:06, 14.44it/s]


In [29]:
# print("Testing average L1 loss: ", sum(loss_l1_history_) / len(loss_l1_history_))
print("Testing average SSIM loss: ", sum(loss_ssim_history_) / len(loss_ssim_history_))

Testing average SSIM loss:  0.8301290984858166
