In [1]:
import time
import random
import numpy as np
import pandas as pd
import sys
import pickle
import h5py
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import torch
import learn2learn as l2l
from tqdm import tqdm
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import StepLR, CosineAnnealingLR
from functions.fftc import fft2c_new as fft2c
from functions.fftc import ifft2c_new as ifft2c
from functions.math import complex_abs, complex_mul, complex_conj
# The corase reconstruction is the rss of the zerofilled multi-coil kspaces
# after inverse FT.
from functions.data.transforms import UnetDataTransform_TTTpaper_fixMask, center_crop, scale_rss, normalize_separate_over_ch, rss_torch
# Import a torch.utils.data.Dataset class that takes a list of data examples, a path to those examples
# a data transform and outputs a torch dataset.
from functions.data.mri_dataset import SliceDataset
# Unet architecture as nn.Module
from functions.models.unet import Unet
# Function that returns a MaskFunc object either for generatig random or equispaced masks
from functions.data.subsample import create_mask_for_mask_type
# Implementation of SSIMLoss
from functions.training.losses import SSIMLoss
from functions.helper import evaluate2c_imagepair
### after you install bart 0.7.00 from https://mrirecon.github.io/bart/, import it as follows
sys.path.insert(0,'/cheng/bart-0.7.00/python/')
os.environ['TOOLBOX_PATH'] = "/cheng/bart-0.7.00/"
import bart


plt.rcParams.update({"text.usetex": True, "font.family": "serif", "font.serif": ["Computer Modern Roman"]})

colors = ['b','r','k','g','m','c','tab:brown','tab:orange','tab:pink','tab:gray','tab:olive','tab:purple']

markers = ["v","o","^","1","*",">","d","<","s","P","X"]
FONTSIZE = 22

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# seed
SEED = 1
random.seed(SEED)
np.random.seed(SEED)
torch.cuda.manual_seed(SEED)
torch.manual_seed(SEED)


<torch._C.Generator at 0x7f5787e7a990>

### Load the data

In [2]:
# data path
path_test = '/cheng/metaMRI/metaMRI/data_dict/TTT_paper/TTT_brain_val.yaml'
path_test_sensmaps = '/cheng/metaMRI/metaMRI/data_dict/TTT_paper/sensmap_brain_val/'

# path_test = '/cheng/metaMRI/metaMRI/data_dict/TTT_paper/TTT_brain_test_100.yaml'
# path_test_sensmaps = '/cheng/metaMRI/metaMRI/data_dict/TTT_paper/sensmap_brain_test/'

# mask function and data transform
mask_function = create_mask_for_mask_type(mask_type_str = 'random', self_sup = False, 
                    center_fraction = 0.08, acceleration = 4.0, acceleration_total = 3.0)

data_transform = UnetDataTransform_TTTpaper_fixMask('multicoil', mask_func = mask_function, use_seed=True)

# training dataset and data loader
testset = SliceDataset(dataset = path_test, path_to_dataset='', 
                path_to_sensmaps = path_test_sensmaps, provide_senmaps=True, 
                challenge="multicoil", transform = data_transform, use_dataset_cache=True)

test_dataloader = torch.utils.data.DataLoader(dataset = testset, batch_size = 1, 
                shuffle = False, generator = torch.Generator().manual_seed(1), pin_memory = False)


In [12]:
checkpoint_path = '/cheng/metaMRI/metaMRI/save/E11.12_sup(l1_1e-5)Q_T300_150epoch/E11.12_sup(l1_1e-5)Q_T300_150epoch_E150_best.pth'
model = Unet(in_chans=2, out_chans=2, chans=64, num_pool_layers=4, drop_prob=0.0)
model.load_state_dict(torch.load(checkpoint_path))
model = model.to(device)

### without TTT

In [13]:
loss_l1_history_=[]
loss_ssim_history_=[]
ssim_fct = SSIMLoss()
l1_loss = torch.nn.L1Loss(reduction='sum')

path_mask = '/cheng/metaMRI/ttt_for_deep_learning_cs/unet/test_data/anatomy_shift/mask2d'
with open(path_mask,'rb') as fn:
    mask2d = pickle.load(fn)
mask = torch.tensor(mask2d[0]).unsqueeze(0).unsqueeze(0).unsqueeze(-1)
mask = mask.to(device)

for iter, batch in tqdm(enumerate(test_dataloader)): 
    kspace, sens_maps, sens_maps_conj, _, fname, slice_num = batch
    kspace = kspace.squeeze(0).to(device)
    sens_maps = sens_maps.squeeze(0).to(device)
    sens_maps_conj = sens_maps_conj.squeeze(0).to(device)

    # input k space
    input_kspace = kspace * mask + 0.0
    input_kspace = input_kspace.to(device)

    # gt image: x
    target_image = ifft2c(kspace)
    # rss combine
    target_image = rss_torch(target_image)
    # sensmap combine
    # target_image = complex_mul(target_image, sens_maps_conj).sum(dim=0, keepdim=False)

    # A†y
    train_inputs = ifft2c(input_kspace) #shape: coils,height,width,2
    train_inputs = rss_torch(train_inputs)
    # train_inputs = complex_mul(train_inputs, sens_maps_conj).sum(dim=0, keepdim=False) #shape: height,width,2
    train_inputs = torch.moveaxis( train_inputs , -1, 0 ) # move complex channels to channel dimension
    # normalize input to have zero mean and std one
    train_inputs, mean, std = normalize_separate_over_ch(train_inputs, eps=1e-11)
    print(mean)
    # fθ(A†y)
    train_outputs = model(train_inputs.unsqueeze(0))
    train_outputs = train_outputs.squeeze(0) * std + mean
    
    # supervised loss [x, fθ(A†y)]
    # [1, 2, 768, 392] -> [1, 768, 392]
    output_image_1c = complex_abs(torch.moveaxis(train_outputs.squeeze(0), 0, -1 )).unsqueeze(0)
    train_targets_1c = complex_abs(target_image).unsqueeze(0)
    #loss_sup = l1_loss(train_outputs_1c, train_targets_1c) / torch.sum(torch.abs(train_targets_1c))
    
    # center crop
    min_size = min(train_targets_1c.shape[-2:])
    crop = torch.Size([min_size, min_size])
    train_targets_1c = center_crop( train_targets_1c, crop )
    output_image_1c = center_crop( output_image_1c, crop )
    
    # SSIM = 1 - loss
    loss_ssim = 1 - ssim_fct(output_image_1c, train_targets_1c, data_range = train_targets_1c.max().unsqueeze(0)).item()

    #loss_l1_history_.append(loss_l1)
    loss_ssim_history_.append(loss_ssim)

0it [00:00, ?it/s]

1it [00:00,  2.00it/s]

tensor([[[2.4544e-05]],

        [[2.5892e-05]]], device='cuda:0')


2it [00:00,  2.34it/s]

tensor([[[2.2407e-05]],

        [[2.4296e-05]]], device='cuda:0')


3it [00:01,  3.10it/s]

tensor([[[2.3728e-05]],

        [[1.9464e-05]]], device='cuda:0')
tensor([[[3.1128e-05]],

        [[3.3436e-05]]], device='cuda:0')


5it [00:01,  3.52it/s]

tensor([[[2.5894e-05]],

        [[2.3860e-05]]], device='cuda:0')


6it [00:01,  3.73it/s]

tensor([[[2.2929e-05]],

        [[2.4232e-05]]], device='cuda:0')


7it [00:02,  3.50it/s]

tensor([[[2.3987e-05]],

        [[2.3277e-05]]], device='cuda:0')
tensor([[[1.5400e-05]],

        [[2.1290e-05]]], device='cuda:0')


9it [00:02,  3.65it/s]

tensor([[[2.6547e-05]],

        [[2.5468e-05]]], device='cuda:0')


10it [00:02,  3.92it/s]

tensor([[[2.5308e-05]],

        [[2.4325e-05]]], device='cuda:0')


11it [00:03,  3.64it/s]

tensor([[[2.8156e-05]],

        [[2.7499e-05]]], device='cuda:0')


12it [00:03,  2.91it/s]

tensor([[[3.2064e-05]],

        [[2.7782e-05]]], device='cuda:0')


13it [00:04,  2.86it/s]

tensor([[[2.1454e-05]],

        [[2.1858e-05]]], device='cuda:0')


14it [00:04,  3.07it/s]

tensor([[[2.8798e-05]],

        [[2.8197e-05]]], device='cuda:0')


15it [00:04,  3.14it/s]

tensor([[[2.6253e-05]],

        [[2.6041e-05]]], device='cuda:0')


16it [00:04,  3.43it/s]

tensor([[[3.5860e-05]],

        [[3.5055e-05]]], device='cuda:0')


17it [00:05,  3.54it/s]

tensor([[[1.6724e-05]],

        [[1.5973e-05]]], device='cuda:0')


18it [00:05,  3.66it/s]

tensor([[[2.5305e-05]],

        [[2.3914e-05]]], device='cuda:0')


19it [00:05,  3.46it/s]

tensor([[[2.1341e-05]],

        [[2.4033e-05]]], device='cuda:0')
tensor([[[2.4893e-05]],

        [[2.5757e-05]]], device='cuda:0')


21it [00:06,  4.02it/s]

tensor([[[2.2489e-05]],

        [[2.1666e-05]]], device='cuda:0')
tensor([[[2.0520e-05]],

        [[1.8339e-05]]], device='cuda:0')


23it [00:06,  4.46it/s]

tensor([[[3.0202e-05]],

        [[2.7851e-05]]], device='cuda:0')


24it [00:06,  4.48it/s]

tensor([[[2.7083e-05]],

        [[2.8162e-05]]], device='cuda:0')


25it [00:06,  4.52it/s]

tensor([[[2.2934e-05]],

        [[2.1481e-05]]], device='cuda:0')


26it [00:07,  4.56it/s]

tensor([[[1.9762e-05]],

        [[2.2637e-05]]], device='cuda:0')


27it [00:07,  4.22it/s]

tensor([[[2.4026e-05]],

        [[2.1864e-05]]], device='cuda:0')


28it [00:07,  4.11it/s]

tensor([[[2.8324e-05]],

        [[2.8431e-05]]], device='cuda:0')


29it [00:07,  4.27it/s]

tensor([[[3.4062e-05]],

        [[3.3027e-05]]], device='cuda:0')


30it [00:08,  4.07it/s]

tensor([[[3.0716e-05]],

        [[3.4193e-05]]], device='cuda:0')


31it [00:08,  4.09it/s]

tensor([[[2.4977e-05]],

        [[2.3801e-05]]], device='cuda:0')


32it [00:08,  4.03it/s]

tensor([[[2.9639e-05]],

        [[3.1704e-05]]], device='cuda:0')


33it [00:08,  3.96it/s]

tensor([[[2.7951e-05]],

        [[2.7026e-05]]], device='cuda:0')
tensor([[[2.8203e-05]],

        [[2.6752e-05]]], device='cuda:0')


35it [00:09,  4.24it/s]

tensor([[[2.5343e-05]],

        [[2.5013e-05]]], device='cuda:0')
tensor([[[2.1994e-05]],

        [[1.8631e-05]]], device='cuda:0')


37it [00:09,  4.34it/s]

tensor([[[1.8391e-05]],

        [[1.7150e-05]]], device='cuda:0')


38it [00:10,  4.47it/s]

tensor([[[2.8771e-05]],

        [[2.7448e-05]]], device='cuda:0')


39it [00:10,  3.60it/s]

tensor([[[3.7722e-05]],

        [[3.5158e-05]]], device='cuda:0')


40it [00:10,  3.56it/s]

tensor([[[2.6789e-05]],

        [[2.2879e-05]]], device='cuda:0')


41it [00:11,  3.45it/s]

tensor([[[2.2359e-05]],

        [[2.2804e-05]]], device='cuda:0')


42it [00:11,  3.17it/s]

tensor([[[2.7459e-05]],

        [[2.7466e-05]]], device='cuda:0')


43it [00:11,  3.55it/s]

tensor([[[2.0467e-05]],

        [[1.7880e-05]]], device='cuda:0')


44it [00:11,  3.60it/s]

tensor([[[2.7222e-05]],

        [[2.7438e-05]]], device='cuda:0')


45it [00:12,  3.39it/s]

tensor([[[2.8288e-05]],

        [[2.9869e-05]]], device='cuda:0')


46it [00:12,  3.61it/s]

tensor([[[1.6752e-05]],

        [[1.7911e-05]]], device='cuda:0')


47it [00:12,  3.81it/s]

tensor([[[3.6628e-05]],

        [[3.4636e-05]]], device='cuda:0')
tensor([[[1.8743e-05]],

        [[1.9075e-05]]], device='cuda:0')


49it [00:13,  3.74it/s]

tensor([[[2.1987e-05]],

        [[2.2822e-05]]], device='cuda:0')


50it [00:13,  3.97it/s]

tensor([[[1.5092e-05]],

        [[1.5801e-05]]], device='cuda:0')


51it [00:13,  3.69it/s]

tensor([[[2.9518e-05]],

        [[2.7896e-05]]], device='cuda:0')


52it [00:14,  3.53it/s]

tensor([[[2.7912e-05]],

        [[2.8781e-05]]], device='cuda:0')


53it [00:14,  3.32it/s]

tensor([[[3.6578e-05]],

        [[3.1964e-05]]], device='cuda:0')


54it [00:14,  3.35it/s]

tensor([[[3.4687e-05]],

        [[3.3790e-05]]], device='cuda:0')


55it [00:14,  3.63it/s]

tensor([[[2.3029e-05]],

        [[2.3699e-05]]], device='cuda:0')


56it [00:15,  3.79it/s]

tensor([[[2.8720e-05]],

        [[3.0031e-05]]], device='cuda:0')


58it [00:15,  4.57it/s]

tensor([[[2.4190e-05]],

        [[2.5199e-05]]], device='cuda:0')
tensor([[[2.0580e-05]],

        [[2.0444e-05]]], device='cuda:0')


59it [00:15,  4.71it/s]

tensor([[[3.0081e-05]],

        [[3.3852e-05]]], device='cuda:0')


60it [00:15,  4.27it/s]

tensor([[[2.7212e-05]],

        [[2.4824e-05]]], device='cuda:0')


61it [00:16,  4.44it/s]

tensor([[[2.8045e-05]],

        [[2.7181e-05]]], device='cuda:0')


62it [00:16,  4.16it/s]

tensor([[[3.6650e-05]],

        [[3.6007e-05]]], device='cuda:0')


63it [00:16,  3.84it/s]

tensor([[[2.1767e-05]],

        [[1.9188e-05]]], device='cuda:0')


64it [00:17,  3.22it/s]

tensor([[[3.0953e-05]],

        [[2.9698e-05]]], device='cuda:0')


65it [00:17,  3.36it/s]

tensor([[[2.6129e-05]],

        [[2.6757e-05]]], device='cuda:0')


66it [00:17,  3.24it/s]

tensor([[[3.4447e-05]],

        [[3.3622e-05]]], device='cuda:0')


67it [00:18,  3.34it/s]

tensor([[[2.2101e-05]],

        [[2.1154e-05]]], device='cuda:0')


68it [00:18,  3.42it/s]

tensor([[[3.1400e-05]],

        [[2.7733e-05]]], device='cuda:0')


69it [00:18,  3.61it/s]

tensor([[[3.2099e-05]],

        [[3.3824e-05]]], device='cuda:0')
tensor([[[2.5905e-05]],

        [[2.5524e-05]]], device='cuda:0')


71it [00:18,  4.28it/s]

tensor([[[2.2001e-05]],

        [[1.9387e-05]]], device='cuda:0')
tensor([[[2.9121e-05]],

        [[2.9865e-05]]], device='cuda:0')


73it [00:19,  3.89it/s]

tensor([[[2.6556e-05]],

        [[3.0259e-05]]], device='cuda:0')


74it [00:19,  3.89it/s]

tensor([[[2.1807e-05]],

        [[2.5479e-05]]], device='cuda:0')


75it [00:20,  3.75it/s]

tensor([[[2.3835e-05]],

        [[2.1595e-05]]], device='cuda:0')


76it [00:20,  3.65it/s]

tensor([[[3.2969e-05]],

        [[2.9667e-05]]], device='cuda:0')


77it [00:20,  3.88it/s]

tensor([[[2.8095e-05]],

        [[2.9341e-05]]], device='cuda:0')


78it [00:20,  3.97it/s]

tensor([[[2.7864e-05]],

        [[2.6393e-05]]], device='cuda:0')


79it [00:20,  4.18it/s]

tensor([[[2.3264e-05]],

        [[2.1031e-05]]], device='cuda:0')


80it [00:21,  4.13it/s]

tensor([[[2.8661e-05]],

        [[2.7013e-05]]], device='cuda:0')


81it [00:21,  4.14it/s]

tensor([[[4.1873e-05]],

        [[3.5499e-05]]], device='cuda:0')


82it [00:21,  4.20it/s]

tensor([[[2.9939e-05]],

        [[2.8765e-05]]], device='cuda:0')


83it [00:21,  4.02it/s]

tensor([[[2.7381e-05]],

        [[2.7366e-05]]], device='cuda:0')
tensor([[[2.2985e-05]],

        [[2.0978e-05]]], device='cuda:0')


85it [00:22,  4.61it/s]

tensor([[[2.5527e-05]],

        [[2.4880e-05]]], device='cuda:0')


86it [00:22,  4.35it/s]

tensor([[[3.2155e-05]],

        [[3.1297e-05]]], device='cuda:0')


87it [00:22,  3.84it/s]

tensor([[[2.6446e-05]],

        [[2.7081e-05]]], device='cuda:0')


88it [00:23,  4.03it/s]

tensor([[[1.9350e-05]],

        [[2.0518e-05]]], device='cuda:0')


89it [00:23,  3.88it/s]

tensor([[[1.9035e-05]],

        [[1.9159e-05]]], device='cuda:0')


90it [00:23,  3.34it/s]

tensor([[[2.8182e-05]],

        [[2.7903e-05]]], device='cuda:0')


91it [00:24,  3.11it/s]

tensor([[[3.0631e-05]],

        [[2.8593e-05]]], device='cuda:0')


92it [00:24,  3.35it/s]

tensor([[[2.2936e-05]],

        [[2.2340e-05]]], device='cuda:0')


93it [00:24,  3.57it/s]

tensor([[[2.2520e-05]],

        [[2.5201e-05]]], device='cuda:0')
tensor([[[1.9952e-05]],

        [[1.8891e-05]]], device='cuda:0')


95it [00:25,  3.82it/s]

tensor([[[2.9902e-05]],

        [[3.2756e-05]]], device='cuda:0')
tensor([[[2.8206e-05]],

        [[2.7420e-05]]], device='cuda:0')


97it [00:25,  4.01it/s]

tensor([[[3.4298e-05]],

        [[3.3938e-05]]], device='cuda:0')


98it [00:25,  3.88it/s]

tensor([[[2.5982e-05]],

        [[2.6098e-05]]], device='cuda:0')


99it [00:26,  3.50it/s]

tensor([[[3.1378e-05]],

        [[3.6152e-05]]], device='cuda:0')


100it [00:26,  3.75it/s]

tensor([[[2.5156e-05]],

        [[2.5939e-05]]], device='cuda:0')





In [14]:
#print("Testing average L1 loss: ", sum(loss_l1_history_) / len(loss_l1_history_))
print("Testing average SSIM loss: ", sum(loss_ssim_history_) / len(loss_ssim_history_))

Testing average SSIM loss:  0.9217642635107041


### Test-time training

In [8]:
best_loss_l1_history=[]
best_loss_l1_index_history=[]
best_loss_ssim_history=[]
best_loss_ssim_index_history=[]
ssim_fct = SSIMLoss()
l1_loss = torch.nn.L1Loss(reduction='sum')
adapt_lr = 0.0001
TTT_epoch = 1500

for iter, batch in tqdm(enumerate(test_dataloader)): 
    input_kspace, input_mask, kspace, sens_maps, sens_maps_conj, binary_background_mask, fname, slice_num = batch
    input_kspace = input_kspace.to(device)
    input_mask = input_mask.to(device)
    kspace = kspace.to(device)
    sens_maps = sens_maps.to(device)
    sens_maps_conj = sens_maps_conj.to(device)

    # model re-init
    model.load_state_dict(torch.load(checkpoint_path))

    # each data point TTT
    optimizer = torch.optim.Adam(model.parameters(),lr=adapt_lr)
    
    loss_l1_history = []
    loss_ssim_history = []
    self_loss_history = []

    for iteration in range(TTT_epoch): 
        ####### training ######
        # scale normalization
        scale_factor = scale(input_kspace.squeeze(0), model)

        # A†y
        scaled_input_kspace = scale_factor * input_kspace.squeeze(0)
        train_inputs = ifft2c(scaled_input_kspace) #shape: coils,height,width,2
        train_inputs = complex_mul(train_inputs, sens_maps_conj.squeeze(0))
        train_inputs = train_inputs.sum(dim=0, keepdim=False) #shape: height,width,2
        train_inputs = torch.moveaxis( train_inputs , -1, 0 ) # move complex channels to channel dimension
        # normalize input to have zero mean and std one
        #train_inputs, mean, std = normalize_separate_over_ch(train_inputs, eps=1e-11)

        # fθ(A†y)
        model_output = model(train_inputs.unsqueeze(0))
        model_output = model_output.squeeze(0)# * std + mean
        model_output = torch.moveaxis(model_output.unsqueeze(0), 1, -1 )
        # S fθ(A†y)
        output_sens_image = complex_mul(model_output, sens_maps.squeeze(0))
        # FS fθ(A†y)
        Fimg = fft2c(output_sens_image)
        # MFS fθ(A†y) = A fθ(A†y)
        Fimg_forward = Fimg * input_mask.squeeze(0)
        # consistency loss [y, Afθ(A†y)]
        loss_self = l1_loss(Fimg_forward, scaled_input_kspace) / torch.sum(torch.abs(scaled_input_kspace))
    
        optimizer.zero_grad()
        loss_self.backward()
        optimizer.step()
        #train_loss += loss.item()
        #print('TTT loss: ',loss_self.item())
        self_loss_history.append(loss_self.item())


        ###### evaluation ######
        loss_l1, loss_ssim = evaluate(model, scaled_input_kspace, kspace, sens_maps_conj)

        loss_l1_history.append(loss_l1)
        loss_ssim_history.append(loss_ssim)


    # best results for an example
    best_loss_l1 = min(loss_l1_history)
    print('best L1: ', best_loss_l1)
    best_loss_l1_epoch = np.argmin(loss_l1_history)
    print('best L1 epoch: ', best_loss_l1_epoch)
    best_loss_ssim = max(loss_ssim_history)
    print('best SSIM:', best_loss_ssim)
    best_loss_ssim_epoch = np.argmax(loss_ssim_history)
    print('best SSIM epoch: ', best_loss_ssim_epoch)

    best_loss_l1_index_history.append(best_loss_l1_epoch)
    best_loss_ssim_index_history.append(best_loss_ssim_epoch)
    best_loss_l1_history.append(best_loss_l1)
    best_loss_ssim_history.append(best_loss_ssim)


print("Testing average L1 loss: ", sum(best_loss_l1_history) / len(best_loss_l1_history))
print("Testing average L1 loss epoch: ", sum(best_loss_l1_index_history) / len(best_loss_l1_index_history))
print("Testing average SSIM loss: ", sum(best_loss_ssim_history) / len(best_loss_ssim_history))
print("Testing average SSIM loss epoch: ", sum(best_loss_ssim_index_history) / len(best_loss_ssim_index_history))

1it [04:50, 290.84s/it]

best L1:  0.12027107179164886
best L1 epoch:  1496
best SSIM: 0.8916576504707336
best SSIM epoch:  1494


2it [10:12, 309.06s/it]

best L1:  0.13555407524108887
best L1 epoch:  1476
best SSIM: 0.8540239930152893
best SSIM epoch:  1375


3it [14:34, 287.54s/it]

best L1:  0.1827261745929718
best L1 epoch:  1488
best SSIM: 0.8546640872955322
best SSIM epoch:  1477


4it [19:25, 289.03s/it]

best L1:  0.19028446078300476
best L1 epoch:  1497
best SSIM: 0.8113265633583069
best SSIM epoch:  0


4it [23:26, 351.62s/it]


KeyboardInterrupt: 