In [1]:
import time
import random
import numpy as np
import pandas as pd
import sys
import pickle
import h5py
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '3'
import torch
import learn2learn as l2l
from tqdm import tqdm
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import StepLR, CosineAnnealingLR
from functions.fftc import fft2c_new as fft2c
from functions.fftc import ifft2c_new as ifft2c
from functions.math import complex_abs, complex_mul, complex_conj
# The coarse reconstruction is the root-sum-of-squares (RSS) of the zero-filled
# multi-coil k-spaces after the inverse Fourier transform.
from functions.data.transforms import UnetDataTransform_sens_TTT, complex_center_crop, center_crop_to_smallest
# Import a torch.utils.data.Dataset class that takes a list of data examples, the path to those examples,
# and a data transform, and returns a torch dataset.
from functions.data.mri_dataset import SliceDataset
# Unet architecture as nn.Module
from functions.models.unet import Unet
# Function that returns a MaskFunc object for generating either random or equispaced masks
from functions.data.subsample import create_mask_for_mask_type
# Implementation of SSIMLoss
from functions.training.losses import SSIMLoss
from functions.helper import evaluate2c_imagepair
### after you install bart 0.7.00 from https://mrirecon.github.io/bart/, import it as follows
sys.path.insert(0,'/cheng/bart-0.7.00/python/')
os.environ['TOOLBOX_PATH'] = "/cheng/bart-0.7.00/"
import bart


# Render figure text through LaTeX (needs a TeX installation) in a Computer Modern serif face.
plt.rcParams.update({
    "text.usetex": True,
    "font.family": "serif",
    "font.serif": ["Computer Modern Roman"],
})

# Colour and marker cycles, presumably for distinguishing plotted curves
# (not used in the cells visible here).
colors = [
    'b', 'r', 'k', 'g', 'm', 'c',
    'tab:brown', 'tab:orange', 'tab:pink', 'tab:gray', 'tab:olive', 'tab:purple',
]

markers = [
    "v", "o", "^", "1", "*", ">",
    "d", "<", "s", "P", "X",
]
FONTSIZE = 22

# Use the GPU when one is visible (CUDA_VISIBLE_DEVICES is pinned in the import cell).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed every RNG the notebook touches so the per-slice TTT runs are repeatable.
SEED = 1
random.seed(SEED)
np.random.seed(SEED)
# torch.manual_seed below also seeds the CUDA generators; this call is kept explicit.
torch.cuda.manual_seed(SEED)
# Trailing ';' stops the notebook from displaying the returned Generator's repr.
torch.manual_seed(SEED);
# Fixed seeds alone do not make GPU training repeatable: cuDNN autotuning and
# nondeterministic convolution kernels would make the TTT trajectories (and the
# reported best epochs) differ from run to run.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True


<torch._C.Generator at 0x7f797c7e96f0>

### Load the data

In [2]:
# Locations of the test-split data dictionary (YAML) and its precomputed sensitivity maps.
path_test = '/cheng/metaMRI/metaMRI/data_dict/E11.1/Q/brain_test_AXT1POST_Skyra_5-8.yaml'
path_test_sensmaps = '/cheng/metaMRI/metaMRI/data_dict/E11.1/Q/sensmap_test/'

# Random undersampling mask (self_sup=False). See create_mask_for_mask_type for the
# exact meaning of `acceleration` vs. `acceleration_total`.
mask_function = create_mask_for_mask_type(
    mask_type_str='random',
    self_sup=False,
    center_fraction=0.08,
    acceleration=4.0,
    acceleration_total=3.0,
)

# Multicoil transform in 'adapt' mode. use_seed=True presumably makes the mask
# deterministic per slice -- NOTE(review): confirm in transforms.py.
data_transform = UnetDataTransform_sens_TTT(
    'multicoil',
    mask_func=mask_function,
    use_seed=True,
    mode='adapt',
)

# Evaluation (test) dataset -- this is not training data.
testset = SliceDataset(
    dataset=path_test,
    path_to_dataset='',
    path_to_sensmaps=path_test_sensmaps,
    provide_senmaps=True,
    challenge="multicoil",
    transform=data_transform,
    use_dataset_cache=True,
)

# One slice per batch, in dataset order. The explicit generator is kept so the loader
# does not consume the global torch RNG for its iterator base seed
# (NOTE(review): DataLoader internals -- confirm for the installed torch version).
test_dataloader = torch.utils.data.DataLoader(
    dataset=testset,
    batch_size=1,
    shuffle=False,
    generator=torch.Generator().manual_seed(1),
    pin_memory=False,
)


### Test-time training

In [3]:
ssim_fct = SSIMLoss()
# Summed (not averaged) L1; every use below divides by the L1 norm of the reference.
l1_loss = torch.nn.L1Loss(reduction='sum')


def TTT(model, TTT_epoch, adapt_lr, slice_data=None):
    """Test-time train `model` on one slice with a self-supervised k-space consistency loss.

    Each of the `TTT_epoch` iterations runs one Adam step (lr=`adapt_lr`, fresh
    optimizer per call) on

        loss_self = || M * F(S * f(A^H y)) - y ||_1 / || y ||_1

    where f(A^H y) = model(input_image) * std + mean, S = sens_maps, F = fft2c,
    M = input_mask and y = input_kspace.

    Args:
        model: network to adapt in place.
        TTT_epoch: number of adaptation steps.
        adapt_lr: Adam learning rate.
        slice_data: optional mapping providing 'input_image', 'target_image',
            'ground_truth_image', 'mean', 'std', 'input_kspace', 'input_mask' and
            'sens_maps'. When None (the original behaviour), these are read from
            the module's global namespace, where the evaluation loops bind them.

    Returns:
        (loss_l1_history, loss_ssim_history, self_loss_history), each of length
        TTT_epoch:
        - normalised L1 error against target_image;
        - SSIM against ground_truth_image (stored as 1 - SSIMLoss, so higher is better);
        - the self-supervised loss.
        The supervised metrics at index i are measured before the i-th update, so
        index 0 is the un-adapted model.
    """
    if slice_data is None:
        slice_data = globals()
    input_image = slice_data['input_image']
    target_image = slice_data['target_image']
    ground_truth_image = slice_data['ground_truth_image']
    mean = slice_data['mean']
    std = slice_data['std']
    input_kspace = slice_data['input_kspace']
    input_mask = slice_data['input_mask']
    sens_maps = slice_data['sens_maps']

    optimizer = torch.optim.Adam(model.parameters(), lr=adapt_lr)

    # Loop-invariant normalisers and SSIM data range: compute once per slice.
    target_l1_norm = torch.sum(torch.abs(target_image))
    kspace_l1_norm = torch.sum(torch.abs(input_kspace))
    data_range = ground_truth_image.max().unsqueeze(0)

    loss_l1_history = []
    loss_ssim_history = []
    self_loss_history = []

    for _ in range(TTT_epoch):
        # fθ(A†y), mapped back out of the normalised domain
        model_output = model(input_image) * std + mean

        # Supervised metrics are for monitoring only, so evaluate them on a detached
        # copy without recording an autograd graph.
        with torch.no_grad():
            output_eval = model_output.detach()
            # L1 [x, fθ(A†y)], normalised by ||x||_1
            loss_l1 = (l1_loss(output_eval, target_image) / target_l1_norm).item()
            # SSIMLoss returns a loss; 1 - loss is the SSIM value that is tracked.
            output_image_1c = complex_abs(torch.moveaxis(output_eval, 1, -1))
            loss_ssim = 1 - ssim_fct(output_image_1c, ground_truth_image, data_range=data_range).item()
        loss_l1_history.append(loss_l1)
        loss_ssim_history.append(loss_ssim)

        # Self-supervised consistency loss [y, A fθ(A†y)]; this path keeps the graph.
        # S fθ(A†y) with the complex channel moved to the last axis
        output_sens_image = complex_mul(torch.moveaxis(model_output, 1, -1), sens_maps)
        # M F S fθ(A†y) = A fθ(A†y)
        Fimg_forward = fft2c(output_sens_image) * input_mask
        loss_self = l1_loss(Fimg_forward, input_kspace) / kspace_l1_norm

        optimizer.zero_grad()
        loss_self.backward()
        optimizer.step()
        self_loss_history.append(loss_self.item())

    return loss_l1_history, loss_ssim_history, self_loss_history

In [4]:
checkpoint_path_self = '/cheng/metaMRI/metaMRI/save/E11.3_joint(l1_CA-1e-3-4_P)_T300_120epoch/E11.3_joint(l1_CA-1e-3-4_P)_T300_120epoch_E66_best.pth'

# Read the pretrained weights from disk once. Every slice restarts TTT from this
# same state, so re-reading the file inside the loop is redundant I/O.
# map_location lets a checkpoint saved on another device load onto `device`.
initial_state_self = torch.load(checkpoint_path_self, map_location=device)

model_self = Unet(in_chans=2, out_chans=2, chans=32, num_pool_layers=4, drop_prob=0.0).to(device)
model_self.load_state_dict(initial_state_self)

# Per-slice best metric values and the TTT iteration at which each occurred.
best_loss_l1_history = []
best_loss_l1_index_history = []
best_loss_ssim_history = []
best_loss_ssim_index_history = []

for batch in tqdm(test_dataloader):
    # TTT() reads input_image, target_image, ground_truth_image, mean, std,
    # input_kspace, input_mask and sens_maps from the notebook's global namespace,
    # so they are bound here at top level.
    input_image, target_image, ground_truth_image, mean, std, fname, slice_num, input_kspace, input_mask, target_kspace, target_mask, sens_maps, binary_background_mask = batch
    input_image = input_image.to(device)
    target_image = target_image.to(device)
    input_kspace = input_kspace.to(device)
    input_mask = input_mask.to(device)
    sens_maps = sens_maps.to(device)
    std = std.to(device)
    mean = mean.to(device)
    ground_truth_image = ground_truth_image.to(device)
    binary_background_mask = binary_background_mask.to(device)

    # Reset to the pretrained weights. load_state_dict copies values into the
    # parameters, so the cached initial_state_self is not altered by TTT updates.
    model_self.load_state_dict(initial_state_self)

    # each data point TTT
    loss_l1_history, loss_ssim_history, self_loss_history = TTT(model_self, TTT_epoch=500, adapt_lr=0.0001)

    # Normalised L1 is lower-is-better; the recorded SSIM value is higher-is-better.
    best_loss_l1 = min(loss_l1_history)
    best_loss_l1_epoch = np.argmin(loss_l1_history)
    print('best L1 epoch: ', best_loss_l1_epoch)
    best_loss_ssim = max(loss_ssim_history)
    best_loss_ssim_epoch = np.argmax(loss_ssim_history)
    print('best SSIM epoch: ', best_loss_ssim_epoch)

    best_loss_l1_index_history.append(best_loss_l1_epoch)
    best_loss_ssim_index_history.append(best_loss_ssim_epoch)
    best_loss_l1_history.append(best_loss_l1)
    best_loss_ssim_history.append(best_loss_ssim)


if best_loss_l1_history:
    # NOTE: the "SSIM loss" figures are SSIM values (1 - SSIMLoss), not losses.
    print("Testing average L1 loss: ", sum(best_loss_l1_history) / len(best_loss_l1_history))
    print("Testing average L1 loss epoch: ", sum(best_loss_l1_index_history) / len(best_loss_l1_index_history))
    print("Testing average SSIM loss: ", sum(best_loss_ssim_history) / len(best_loss_ssim_history))
    print("Testing average SSIM loss epoch: ", sum(best_loss_ssim_index_history) / len(best_loss_ssim_index_history))
else:
    print("No test slices were evaluated.")

1it [00:21, 21.16s/it]

best L1 epoch:  163
best SSIM epoch:  108


2it [00:41, 20.55s/it]

best L1 epoch:  88
best SSIM epoch:  50


3it [01:01, 20.51s/it]

best L1 epoch:  68
best SSIM epoch:  65


4it [01:22, 20.64s/it]

best L1 epoch:  123
best SSIM epoch:  104


5it [01:43, 20.75s/it]

best L1 epoch:  339
best SSIM epoch:  5


6it [02:04, 20.84s/it]

best L1 epoch:  446
best SSIM epoch:  5


7it [02:25, 20.93s/it]

best L1 epoch:  483
best SSIM epoch:  6


8it [02:48, 21.59s/it]

best L1 epoch:  349
best SSIM epoch:  4


9it [03:10, 21.55s/it]

best L1 epoch:  82
best SSIM epoch:  65


10it [03:31, 21.45s/it]

best L1 epoch:  120
best SSIM epoch:  76


11it [03:52, 21.41s/it]

best L1 epoch:  139
best SSIM epoch:  40


12it [04:13, 21.36s/it]

best L1 epoch:  73
best SSIM epoch:  67


13it [04:35, 21.29s/it]

best L1 epoch:  484
best SSIM epoch:  0


14it [04:56, 21.22s/it]

best L1 epoch:  479
best SSIM epoch:  5


15it [05:17, 21.16s/it]

best L1 epoch:  495
best SSIM epoch:  8


16it [05:38, 21.10s/it]

best L1 epoch:  402
best SSIM epoch:  8


17it [05:59, 21.21s/it]

best L1 epoch:  301
best SSIM epoch:  59


18it [06:20, 21.27s/it]

best L1 epoch:  296
best SSIM epoch:  63


19it [06:42, 21.32s/it]

best L1 epoch:  237
best SSIM epoch:  64


20it [07:03, 21.35s/it]

best L1 epoch:  315
best SSIM epoch:  13


### Test-time training of the meta-learned (MAML) model

In [22]:
checkpoint_path_meta = '/cheng/metaMRI/metaMRI/save/E11.3_maml(l1_CA-1e-3-4_P)_T300_200epoch/E11.3_maml(l1_CA-1e-3-4_P)_T300_200epoch_E79_best.pth'

# Read the meta-learned weights from disk once. Every slice restarts TTT from this
# same state, so re-reading the file inside the loop is redundant I/O.
# map_location lets a checkpoint saved on another device load onto `device`.
initial_state_meta = torch.load(checkpoint_path_meta, map_location=device)

model_meta = Unet(in_chans=2, out_chans=2, chans=32, num_pool_layers=4, drop_prob=0.0).to(device)
model_meta.load_state_dict(initial_state_meta)

# Per-slice best metric values and the TTT iteration at which each occurred.
best_loss_l1_history = []
best_loss_l1_index_history = []
best_loss_ssim_history = []
best_loss_ssim_index_history = []

for batch in tqdm(test_dataloader):
    # TTT() reads input_image, target_image, ground_truth_image, mean, std,
    # input_kspace, input_mask and sens_maps from the notebook's global namespace,
    # so they are bound here at top level.
    input_image, target_image, ground_truth_image, mean, std, fname, slice_num, input_kspace, input_mask, target_kspace, target_mask, sens_maps, binary_background_mask = batch
    input_image = input_image.to(device)
    target_image = target_image.to(device)
    input_kspace = input_kspace.to(device)
    input_mask = input_mask.to(device)
    sens_maps = sens_maps.to(device)
    std = std.to(device)
    mean = mean.to(device)
    ground_truth_image = ground_truth_image.to(device)
    binary_background_mask = binary_background_mask.to(device)

    # Reset to the meta-learned weights. load_state_dict copies values into the
    # parameters, so the cached initial_state_meta is not altered by TTT updates.
    model_meta.load_state_dict(initial_state_meta)

    # each data point TTT
    loss_l1_history, loss_ssim_history, self_loss_history = TTT(model_meta, TTT_epoch=500, adapt_lr=0.0001)

    # Normalised L1 is lower-is-better; the recorded SSIM value is higher-is-better.
    best_loss_l1 = min(loss_l1_history)
    best_loss_l1_epoch = np.argmin(loss_l1_history)
    print('best L1 epoch: ', best_loss_l1_epoch)
    best_loss_ssim = max(loss_ssim_history)
    best_loss_ssim_epoch = np.argmax(loss_ssim_history)
    print('best SSIM epoch: ', best_loss_ssim_epoch)

    best_loss_l1_index_history.append(best_loss_l1_epoch)
    best_loss_ssim_index_history.append(best_loss_ssim_epoch)
    best_loss_l1_history.append(best_loss_l1)
    best_loss_ssim_history.append(best_loss_ssim)


if best_loss_l1_history:
    # NOTE: the "SSIM loss" figures are SSIM values (1 - SSIMLoss), not losses.
    print("Testing average L1 loss: ", sum(best_loss_l1_history) / len(best_loss_l1_history))
    print("Testing average L1 loss epoch: ", sum(best_loss_l1_index_history) / len(best_loss_l1_index_history))
    print("Testing average SSIM loss: ", sum(best_loss_ssim_history) / len(best_loss_ssim_history))
    print("Testing average SSIM loss epoch: ", sum(best_loss_ssim_index_history) / len(best_loss_ssim_index_history))
else:
    print("No test slices were evaluated.")

1it [00:19, 19.96s/it]

best L1 epoch:  199
best SSIM epoch:  126


2it [00:40, 20.07s/it]

best L1 epoch:  221
best SSIM epoch:  112


3it [01:00, 20.23s/it]

best L1 epoch:  262
best SSIM epoch:  106


4it [01:21, 20.40s/it]

best L1 epoch:  162
best SSIM epoch:  113


5it [01:41, 20.52s/it]

best L1 epoch:  341
best SSIM epoch:  143


6it [02:02, 20.61s/it]

best L1 epoch:  358
best SSIM epoch:  6


7it [02:23, 20.68s/it]

best L1 epoch:  365
best SSIM epoch:  5


8it [02:44, 20.71s/it]

best L1 epoch:  407
best SSIM epoch:  6


9it [03:05, 20.75s/it]

best L1 epoch:  195
best SSIM epoch:  66


10it [03:26, 20.80s/it]

best L1 epoch:  177
best SSIM epoch:  56


11it [03:46, 20.82s/it]

best L1 epoch:  132
best SSIM epoch:  131


12it [04:07, 20.82s/it]

best L1 epoch:  145
best SSIM epoch:  67


13it [04:28, 20.83s/it]

best L1 epoch:  469
best SSIM epoch:  3


14it [04:49, 20.84s/it]

best L1 epoch:  336
best SSIM epoch:  2


15it [05:10, 20.84s/it]

best L1 epoch:  302
best SSIM epoch:  6


16it [05:31, 20.83s/it]

best L1 epoch:  314
best SSIM epoch:  7


17it [05:52, 20.97s/it]

best L1 epoch:  308
best SSIM epoch:  15


18it [06:13, 21.06s/it]

best L1 epoch:  246
best SSIM epoch:  29


19it [06:34, 21.13s/it]

best L1 epoch:  209
best SSIM epoch:  31


20it [06:56, 21.16s/it]

best L1 epoch:  322
best SSIM epoch:  36


21it [07:16, 21.02s/it]

best L1 epoch:  424
best SSIM epoch:  157


22it [07:37, 20.91s/it]

best L1 epoch:  338
best SSIM epoch:  181


23it [07:58, 20.84s/it]

best L1 epoch:  321
best SSIM epoch:  200


24it [08:18, 20.78s/it]

best L1 epoch:  331
best SSIM epoch:  41


25it [08:39, 20.88s/it]

best L1 epoch:  371
best SSIM epoch:  156


26it [09:01, 20.96s/it]

best L1 epoch:  324
best SSIM epoch:  39


27it [09:22, 21.02s/it]

best L1 epoch:  366
best SSIM epoch:  38


28it [09:43, 21.05s/it]

best L1 epoch:  299
best SSIM epoch:  98


29it [10:04, 21.07s/it]

best L1 epoch:  323
best SSIM epoch:  6


30it [10:25, 21.09s/it]

best L1 epoch:  376
best SSIM epoch:  5


31it [10:46, 21.11s/it]

best L1 epoch:  336
best SSIM epoch:  4


32it [11:07, 21.11s/it]

best L1 epoch:  309
best SSIM epoch:  10


33it [11:29, 21.13s/it]

best L1 epoch:  498
best SSIM epoch:  3


34it [11:50, 21.13s/it]

best L1 epoch:  499
best SSIM epoch:  4


35it [12:11, 21.13s/it]

best L1 epoch:  496
best SSIM epoch:  4


36it [12:32, 21.14s/it]

best L1 epoch:  497
best SSIM epoch:  2


37it [12:53, 21.14s/it]

best L1 epoch:  498
best SSIM epoch:  210


38it [13:14, 21.14s/it]

best L1 epoch:  494
best SSIM epoch:  128


39it [13:36, 21.15s/it]

best L1 epoch:  487
best SSIM epoch:  137


40it [13:57, 21.16s/it]

best L1 epoch:  378
best SSIM epoch:  193


41it [14:18, 21.22s/it]

best L1 epoch:  437
best SSIM epoch:  286


42it [14:40, 21.49s/it]

best L1 epoch:  384
best SSIM epoch:  252


43it [15:01, 21.21s/it]

best L1 epoch:  338
best SSIM epoch:  208


44it [15:21, 21.02s/it]

best L1 epoch:  274
best SSIM epoch:  147


45it [15:42, 21.05s/it]

best L1 epoch:  246
best SSIM epoch:  186


46it [16:04, 21.06s/it]

best L1 epoch:  287
best SSIM epoch:  181


47it [16:25, 21.09s/it]

best L1 epoch:  179
best SSIM epoch:  129


48it [16:46, 21.10s/it]

best L1 epoch:  141
best SSIM epoch:  81


49it [17:06, 20.94s/it]

best L1 epoch:  189
best SSIM epoch:  58


50it [17:27, 20.83s/it]

best L1 epoch:  194
best SSIM epoch:  63


51it [17:47, 20.76s/it]

best L1 epoch:  189
best SSIM epoch:  47


52it [18:08, 20.70s/it]

best L1 epoch:  224
best SSIM epoch:  9


53it [18:29, 20.84s/it]

best L1 epoch:  307
best SSIM epoch:  146


54it [18:50, 20.92s/it]

best L1 epoch:  317
best SSIM epoch:  64


55it [19:11, 20.99s/it]

best L1 epoch:  331
best SSIM epoch:  158


56it [19:33, 21.02s/it]

best L1 epoch:  351
best SSIM epoch:  51


57it [19:53, 20.89s/it]

best L1 epoch:  292
best SSIM epoch:  4


58it [20:14, 20.79s/it]

best L1 epoch:  375
best SSIM epoch:  5


59it [20:34, 20.71s/it]

best L1 epoch:  431
best SSIM epoch:  5


60it [20:55, 20.68s/it]

best L1 epoch:  369
best SSIM epoch:  6


61it [21:16, 20.82s/it]

best L1 epoch:  158
best SSIM epoch:  85


62it [21:37, 20.91s/it]

best L1 epoch:  169
best SSIM epoch:  75


63it [21:58, 20.97s/it]

best L1 epoch:  194
best SSIM epoch:  92


64it [22:19, 21.02s/it]

best L1 epoch:  126
best SSIM epoch:  8


65it [22:40, 20.89s/it]

best L1 epoch:  280
best SSIM epoch:  168


66it [23:01, 20.81s/it]

best L1 epoch:  252
best SSIM epoch:  188


67it [23:21, 20.74s/it]

best L1 epoch:  259
best SSIM epoch:  181


68it [23:42, 20.70s/it]

best L1 epoch:  177
best SSIM epoch:  112


69it [24:03, 20.84s/it]

best L1 epoch:  168
best SSIM epoch:  100


70it [24:24, 20.93s/it]

best L1 epoch:  114
best SSIM epoch:  69


71it [24:45, 20.99s/it]

best L1 epoch:  116
best SSIM epoch:  84


72it [25:06, 21.02s/it]

best L1 epoch:  157
best SSIM epoch:  91


73it [25:27, 21.05s/it]

best L1 epoch:  164
best SSIM epoch:  35


74it [25:49, 21.07s/it]

best L1 epoch:  141
best SSIM epoch:  30


75it [26:10, 21.09s/it]

best L1 epoch:  188
best SSIM epoch:  32


76it [26:31, 21.10s/it]

best L1 epoch:  86
best SSIM epoch:  47


77it [26:51, 20.77s/it]

best L1 epoch:  331
best SSIM epoch:  211


78it [27:11, 20.55s/it]

best L1 epoch:  314
best SSIM epoch:  171


79it [27:31, 20.39s/it]

best L1 epoch:  353
best SSIM epoch:  13


80it [27:51, 20.30s/it]

best L1 epoch:  350
best SSIM epoch:  18


81it [28:12, 20.39s/it]

best L1 epoch:  322
best SSIM epoch:  117


82it [28:32, 20.45s/it]

best L1 epoch:  297
best SSIM epoch:  170


83it [28:53, 20.48s/it]

best L1 epoch:  318
best SSIM epoch:  144


84it [29:13, 20.52s/it]

best L1 epoch:  329
best SSIM epoch:  109


85it [29:34, 20.70s/it]

best L1 epoch:  179
best SSIM epoch:  90


86it [29:56, 20.82s/it]

best L1 epoch:  112
best SSIM epoch:  80


87it [30:18, 21.40s/it]

best L1 epoch:  154
best SSIM epoch:  100


88it [30:40, 21.48s/it]

best L1 epoch:  219
best SSIM epoch:  15


89it [31:01, 21.21s/it]

best L1 epoch:  159
best SSIM epoch:  53


90it [31:21, 21.02s/it]

best L1 epoch:  232
best SSIM epoch:  36


91it [31:42, 20.89s/it]

best L1 epoch:  220
best SSIM epoch:  75


92it [32:02, 20.79s/it]

best L1 epoch:  242
best SSIM epoch:  137


93it [32:23, 20.89s/it]

best L1 epoch:  202
best SSIM epoch:  197


94it [32:45, 20.96s/it]

best L1 epoch:  207
best SSIM epoch:  194


95it [33:06, 21.01s/it]

best L1 epoch:  178
best SSIM epoch:  146


96it [33:27, 21.04s/it]

best L1 epoch:  173
best SSIM epoch:  80


97it [33:47, 20.91s/it]

best L1 epoch:  287
best SSIM epoch:  198


98it [34:08, 20.81s/it]

best L1 epoch:  284
best SSIM epoch:  138


99it [34:29, 20.74s/it]

best L1 epoch:  216
best SSIM epoch:  56


100it [34:49, 20.69s/it]

best L1 epoch:  245
best SSIM epoch:  75


101it [35:10, 20.82s/it]

best L1 epoch:  367
best SSIM epoch:  254


102it [35:31, 20.91s/it]

best L1 epoch:  346
best SSIM epoch:  8


103it [35:52, 20.97s/it]

best L1 epoch:  339
best SSIM epoch:  13


104it [36:14, 21.01s/it]

best L1 epoch:  308
best SSIM epoch:  140


105it [36:34, 20.87s/it]

best L1 epoch:  381
best SSIM epoch:  137


106it [36:55, 20.78s/it]

best L1 epoch:  305
best SSIM epoch:  153


107it [37:15, 20.72s/it]

best L1 epoch:  336
best SSIM epoch:  137


108it [37:36, 20.68s/it]

best L1 epoch:  268
best SSIM epoch:  112


109it [37:57, 20.82s/it]

best L1 epoch:  199
best SSIM epoch:  98


110it [38:18, 20.92s/it]

best L1 epoch:  165
best SSIM epoch:  83


111it [38:39, 20.99s/it]

best L1 epoch:  60
best SSIM epoch:  66


112it [39:00, 21.03s/it]

best L1 epoch:  50
best SSIM epoch:  38


113it [39:22, 21.06s/it]

best L1 epoch:  196
best SSIM epoch:  156


114it [39:43, 21.07s/it]

best L1 epoch:  219
best SSIM epoch:  157


115it [40:04, 21.08s/it]

best L1 epoch:  165
best SSIM epoch:  106


116it [40:25, 21.09s/it]

best L1 epoch:  127
best SSIM epoch:  68


117it [40:46, 21.10s/it]

best L1 epoch:  386
best SSIM epoch:  325


118it [41:07, 21.10s/it]

best L1 epoch:  330
best SSIM epoch:  180


119it [41:28, 21.12s/it]

best L1 epoch:  297
best SSIM epoch:  148


120it [41:49, 21.12s/it]

best L1 epoch:  272
best SSIM epoch:  108


121it [42:10, 20.96s/it]

best L1 epoch:  318
best SSIM epoch:  21
