In [8]:
import time
import random
import numpy as np
import pandas as pd
import sys
import pickle
import h5py
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import torch
import learn2learn as l2l
from tqdm import tqdm
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import StepLR, CosineAnnealingLR
from functions.fftc import fft2c_new as fft2c
from functions.fftc import ifft2c_new as ifft2c
from functions.math import complex_abs, complex_mul, complex_conj
# The coarse reconstruction is the RSS of the zero-filled multi-coil k-spaces
# after the inverse FT.
from functions.data.transforms import UnetDataTransform_TTTpaper, center_crop, scale, normalize_separate_over_ch
# Import a torch.utils.data.Dataset class that takes a list of data examples, a path to those examples
# and a data transform, and outputs a torch dataset.
from functions.data.mri_dataset import SliceDataset
# Unet architecture as nn.Module
from functions.models.unet import Unet
# Function that returns a MaskFunc object for generating either random or equispaced masks
from functions.data.subsample import create_mask_for_mask_type
# Implementation of SSIMLoss
from functions.training.losses import SSIMLoss
from functions.helper import evaluate2c_imagepair
### after you install bart 0.7.00 from https://mrirecon.github.io/bart/, import it as follows
sys.path.insert(0,'/cheng/bart-0.7.00/python/')
os.environ['TOOLBOX_PATH'] = "/cheng/bart-0.7.00/"
import bart


# Figure styling: LaTeX-rendered serif text (requires a LaTeX installation).
plt.rcParams.update({
    "text.usetex": True,
    "font.family": "serif",
    "font.serif": ["Computer Modern Roman"],
})
FONTSIZE = 22

# Per-curve color and marker palettes for multi-line plots.
colors = ['b', 'r', 'k', 'g', 'm', 'c',
          'tab:brown', 'tab:orange', 'tab:pink', 'tab:gray', 'tab:olive', 'tab:purple']
markers = ["v", "o", "^", "1", "*", ">", "d", "<", "s", "P", "X"]

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Reproducibility: seed the Python, NumPy and PyTorch (CUDA and CPU) RNGs.
SEED = 1
random.seed(SEED)
np.random.seed(SEED)
torch.cuda.manual_seed(SEED)
torch.manual_seed(SEED)


<torch._C.Generator at 0x7fcf6ce6d710>

### Load the data

In [9]:
# Data paths: YAML index of the test slices and the directory holding their
# precomputed coil-sensitivity maps.
# path_test = '/cheng/metaMRI/metaMRI/data_dict/E11.1/Q/brain_test_AXT1POST_Skyra_5-8.yaml'
# path_test_sensmaps = '/cheng/metaMRI/metaMRI/data_dict/E11.1/Q/sensmap_test/'
path_test = '/cheng/metaMRI/metaMRI/data_dict/TTT_paper/TTT_brain_test_100.yaml'
path_test_sensmaps = '/cheng/metaMRI/metaMRI/data_dict/TTT_paper/sensmap_brain_test/'

# Undersampling mask function and data transform
mask_function = create_mask_for_mask_type(mask_type_str = 'random', self_sup = False, 
                    center_fraction = 0.08, acceleration = 4.0, acceleration_total = 3.0)

data_transform = UnetDataTransform_TTTpaper('multicoil', mask_func = mask_function, use_seed=True, mode='adapt')

# Test (evaluation) dataset and data loader -- not the training split.
testset = SliceDataset(dataset = path_test, path_to_dataset='', 
                path_to_sensmaps = path_test_sensmaps, provide_senmaps=True, 
                challenge="multicoil", transform = data_transform, use_dataset_cache=True)
# Later cells average metrics over all slices and would divide by zero on an
# empty split, so fail early with a clear message instead.
assert len(testset) > 0, f"no test slices found for {path_test}"

# batch_size=1: downstream cells squeeze(0) away the batch dimension.
test_dataloader = torch.utils.data.DataLoader(dataset = testset, batch_size = 1, 
                shuffle = False, generator = torch.Generator().manual_seed(1), pin_memory = False)


In [10]:
# Supervised pre-trained U-Net; its 2 input/output channels hold the real and
# imaginary parts of the image.
checkpoint_path = '/cheng/metaMRI/metaMRI/save/E11.7_sup(l1_CA-1e-3-4_P)_T300_150epoch/E11.7_sup(l1_CA-1e-3-4_P)_T300_150epoch_E85_best.pth'
model = Unet(in_chans=2, out_chans=2, chans=32, num_pool_layers=4, drop_prob=0.0)
# map_location lets a checkpoint saved on a GPU load when `device` has fallen
# back to the CPU.
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model = model.to(device)

### Baseline: evaluation without test-time training (TTT)

In [11]:
def evaluate(model, k_input, k_gt, conj_sensmaps): 
    """Score ``model`` on one slice against the fully-sampled reference.

    Both the reference image x and the network input A†y are formed by coil
    combination: inverse centered FFT of each coil's k-space, multiplication
    with the conjugate sensitivity map, and a sum over coils.

    Args:
        model: network taking a (1, 2, H, W) normalized real/imag image and
            returning a tensor of the same shape.
        k_input: undersampled k-space, [1, coils, height, width, 2].
        k_gt: fully-sampled reference k-space, same layout as ``k_input``.
        conj_sensmaps: conjugate coil sensitivity maps; presumably
            [1, coils, height, width, 2] -- TODO confirm against the transform.

    Returns:
        (loss_l1, loss_ssim) as Python floats:
        - loss_l1: sum|fθ(A†y) - x| / sum|x| over both complex channels
          (relative L1, lower is better).
        - loss_ssim: despite the name, the SSIM value itself (1 - SSIMLoss)
          of the center-cropped magnitude images (higher is better).
    """
    ssim_fct = SSIMLoss()
    l1_loss = torch.nn.L1Loss(reduction='sum')

    # Metrics only: no gradients are needed, and this runs after every TTT
    # step, so skip building the autograd graph.
    with torch.no_grad():
        # Reference image x (complex, channels-first)
        ground_truth_image = ifft2c(k_gt.squeeze(0))  # coils,height,width,2
        ground_truth_image = complex_mul(ground_truth_image, conj_sensmaps.squeeze(0))
        ground_truth_image = ground_truth_image.sum(dim=0, keepdim=False)  # height,width,2
        ground_truth_image = torch.moveaxis(ground_truth_image, -1, 0)  # 2,height,width

        # A†y
        train_inputs = ifft2c(k_input.squeeze(0))  # coils,height,width,2
        train_inputs = complex_mul(train_inputs, conj_sensmaps.squeeze(0))
        train_inputs = train_inputs.sum(dim=0, keepdim=False)  # height,width,2
        train_inputs = torch.moveaxis(train_inputs, -1, 0)  # move complex channels to channel dimension
        # normalize input to have zero mean and std one
        train_inputs, mean, std = normalize_separate_over_ch(train_inputs, eps=1e-11)

        # fθ(A†y), mapped back to the un-normalized intensity range
        model_output = model(train_inputs.unsqueeze(0))
        model_output = model_output.squeeze(0) * std + mean

        # supervised relative L1 [x, fθ(A†y)]
        loss_l1 = (l1_loss(model_output, ground_truth_image) / torch.sum(torch.abs(ground_truth_image))).item()

        # SSIM on magnitude images, center-cropped to a square of the smaller side
        output_image_1c = complex_abs(torch.moveaxis(model_output, 0, -1)).unsqueeze(0)
        ground_truth_image_1c = complex_abs(torch.moveaxis(ground_truth_image, 0, -1)).unsqueeze(0)
        min_size = min(ground_truth_image_1c.shape[-2:])
        crop = torch.Size([min_size, min_size])
        ground_truth_image_1c = center_crop(ground_truth_image_1c, crop)
        output_image_1c = center_crop(output_image_1c, crop)
        # SSIM = 1 - SSIMLoss
        loss_ssim = 1 - ssim_fct(output_image_1c, ground_truth_image_1c, data_range = ground_truth_image_1c.max().unsqueeze(0)).item()

    return loss_l1, loss_ssim

In [12]:
# Baseline: score the pre-trained model on every test slice without any
# test-time adaptation.
loss_l1_history_=[]
loss_ssim_history_=[]

# Pure evaluation: no gradients needed. The loop index is not used, so the
# batch is iterated directly instead of binding `iter`, which would shadow the
# builtin for every later cell.
with torch.no_grad():
    for batch in tqdm(test_dataloader): 
        input_kspace, input_mask, kspace, sens_maps, sens_maps_conj, binary_background_mask, fname, slice_num = batch
        # Only the two k-spaces and the conjugate sensitivity maps are used here.
        input_kspace = input_kspace.to(device)
        kspace = kspace.to(device)
        sens_maps_conj = sens_maps_conj.to(device)

        # input_kspace / kspace:  [1, coils, height, width, channel]
        loss_l1, loss_ssim = evaluate(model, input_kspace, kspace, sens_maps_conj)

        loss_l1_history_.append(loss_l1)
        loss_ssim_history_.append(loss_ssim)

100it [00:13,  7.36it/s]


In [13]:
# Mean per-slice metrics of the un-adapted baseline (the "SSIM loss" entry is the SSIM value itself).
for _metric_name, _per_slice in (("L1", loss_l1_history_), ("SSIM", loss_ssim_history_)):
    print(f"Testing average {_metric_name} loss: ", sum(_per_slice) / len(_per_slice))

Testing average L1 loss:  0.21858883425593376
Testing average SSIM loss:  0.8593406730890274


### Test-time training

In [14]:
# Test-time training (TTT): for every test slice, restart from the pre-trained
# weights and adapt the network with the self-supervised k-space consistency
# loss ||y - A fθ(A†y)||_1 / ||y||_1 with A = M F S (mask, Fourier, coil maps).
# After every update the model is scored against the fully-sampled reference.
#
# NOTE: the per-slice "best" L1/SSIM below are picked using the ground truth,
# i.e. they are oracle early-stopping numbers, not what a TTT run without a
# reference could select.
best_loss_l1_history=[]
best_loss_l1_index_history=[]
best_loss_ssim_history=[]
best_loss_ssim_index_history=[]
ssim_fct = SSIMLoss()
l1_loss = torch.nn.L1Loss(reduction='sum')
adapt_lr = 0.0001
TTT_epoch = 500  # adaptation steps per slice
# min()/argmin over the per-step histories below need at least one entry.
assert TTT_epoch >= 1, "TTT_epoch must be >= 1"

# Read the pre-trained checkpoint once instead of once per slice.
# load_state_dict copies these values into the model's parameters, so the
# cached dict is not modified by the optimizer and stays valid for every slice.
initial_state_dict = torch.load(checkpoint_path, map_location=device)

# The loop index is unused; iterating the batches directly also avoids binding
# `iter`, which would shadow the builtin for later cells.
for batch in tqdm(test_dataloader): 
    input_kspace, input_mask, kspace, sens_maps, sens_maps_conj, binary_background_mask, fname, slice_num = batch
    input_kspace = input_kspace.to(device)
    input_mask = input_mask.to(device)
    kspace = kspace.to(device)
    sens_maps = sens_maps.to(device)
    sens_maps_conj = sens_maps_conj.to(device)

    # model re-init: every slice starts from the same pre-trained weights
    model.load_state_dict(initial_state_dict)

    # fresh optimizer per slice so Adam's moment estimates do not carry over
    optimizer = torch.optim.Adam(model.parameters(),lr=adapt_lr)
    
    loss_l1_history = []
    loss_ssim_history = []
    self_loss_history = []

    for iteration in range(TTT_epoch): 
        ####### training ######
        # scale normalization
        scale_factor = scale(input_kspace.squeeze(0), model)

        # A†y
        scaled_input_kspace = scale_factor * input_kspace.squeeze(0)
        train_inputs = ifft2c(scaled_input_kspace) #shape: coils,height,width,2
        train_inputs = complex_mul(train_inputs, sens_maps_conj.squeeze(0))
        train_inputs = train_inputs.sum(dim=0, keepdim=False) #shape: height,width,2
        train_inputs = torch.moveaxis( train_inputs , -1, 0 ) # move complex channels to channel dimension
        # normalize input to have zero mean and std one
        train_inputs, mean, std = normalize_separate_over_ch(train_inputs, eps=1e-11)

        # fθ(A†y)
        model_output = model(train_inputs.unsqueeze(0))
        model_output = model_output.squeeze(0) * std + mean
        model_output = torch.moveaxis(model_output.unsqueeze(0), 1, -1 )
        # S fθ(A†y)
        output_sens_image = complex_mul(model_output, sens_maps.squeeze(0))
        # FS fθ(A†y)
        Fimg = fft2c(output_sens_image)
        # MFS fθ(A†y) = A fθ(A†y)
        Fimg_forward = Fimg * input_mask.squeeze(0)
        # consistency loss [y, Afθ(A†y)]
        loss_self = l1_loss(Fimg_forward, scaled_input_kspace) / torch.sum(torch.abs(scaled_input_kspace))
    
        optimizer.zero_grad()
        loss_self.backward()
        optimizer.step()
        self_loss_history.append(loss_self.item())

        ###### evaluation ######
        # metrics only -- no autograd graph needed
        with torch.no_grad():
            loss_l1, loss_ssim = evaluate(model, input_kspace, kspace, sens_maps_conj)

        loss_l1_history.append(loss_l1)
        loss_ssim_history.append(loss_ssim)


    # Best results for this slice. History index i is the score after i+1
    # updates; the un-adapted (pre-TTT) score is not part of the history.
    best_loss_l1 = min(loss_l1_history)
    print('best L1: ', best_loss_l1)
    best_loss_l1_epoch = np.argmin(loss_l1_history)
    print('best L1 epoch: ', best_loss_l1_epoch)
    # loss_ssim is the SSIM value itself (higher is better), hence max/argmax
    best_loss_ssim = max(loss_ssim_history)
    print('best SSIM:', best_loss_ssim)
    best_loss_ssim_epoch = np.argmax(loss_ssim_history)
    print('best SSIM epoch: ', best_loss_ssim_epoch)

    best_loss_l1_index_history.append(best_loss_l1_epoch)
    best_loss_ssim_index_history.append(best_loss_ssim_epoch)
    best_loss_l1_history.append(best_loss_l1)
    best_loss_ssim_history.append(best_loss_ssim)


print("Testing average L1 loss: ", sum(best_loss_l1_history) / len(best_loss_l1_history))
print("Testing average L1 loss epoch: ", sum(best_loss_l1_index_history) / len(best_loss_l1_index_history))
print("Testing average SSIM loss: ", sum(best_loss_ssim_history) / len(best_loss_ssim_history))
print("Testing average SSIM loss epoch: ", sum(best_loss_ssim_index_history) / len(best_loss_ssim_index_history))

0it [00:00, ?it/s]

1it [01:15, 75.81s/it]

best L1:  0.15656623244285583
best L1 epoch:  232
best SSIM: 0.9002406597137451
best SSIM epoch:  205


2it [02:32, 76.61s/it]

best L1:  0.1155308336019516
best L1 epoch:  219
best SSIM: 0.9078854322433472
best SSIM epoch:  162


3it [03:50, 76.99s/it]

best L1:  0.12585926055908203
best L1 epoch:  288
best SSIM: 0.9176030158996582
best SSIM epoch:  276


4it [05:08, 77.29s/it]

best L1:  0.17101293802261353
best L1 epoch:  120
best SSIM: 0.8881748914718628
best SSIM epoch:  104


5it [06:26, 77.49s/it]

best L1:  0.18637390434741974
best L1 epoch:  22
best SSIM: 0.87989741563797
best SSIM epoch:  31


6it [07:36, 75.24s/it]

best L1:  0.16253876686096191
best L1 epoch:  361
best SSIM: 0.8839223980903625
best SSIM epoch:  228


7it [08:54, 76.09s/it]

best L1:  0.16250966489315033
best L1 epoch:  320
best SSIM: 0.9141367673873901
best SSIM epoch:  120


8it [10:12, 76.65s/it]

best L1:  0.18459104001522064
best L1 epoch:  69
best SSIM: 0.8862533569335938
best SSIM epoch:  72


9it [11:30, 77.01s/it]

best L1:  0.17429804801940918
best L1 epoch:  221
best SSIM: 0.8886793851852417
best SSIM epoch:  176


10it [12:48, 77.25s/it]

best L1:  0.13453441858291626
best L1 epoch:  322
best SSIM: 0.9155253171920776
best SSIM epoch:  209


11it [14:05, 77.43s/it]

best L1:  0.10814477503299713
best L1 epoch:  381
best SSIM: 0.9269947409629822
best SSIM epoch:  189


12it [15:23, 77.55s/it]

best L1:  0.1534079611301422
best L1 epoch:  313
best SSIM: 0.8939396142959595
best SSIM epoch:  230


13it [16:41, 77.63s/it]

best L1:  0.13932207226753235
best L1 epoch:  179
best SSIM: 0.923108696937561
best SSIM epoch:  162


14it [17:59, 77.69s/it]

best L1:  0.12940387427806854
best L1 epoch:  291
best SSIM: 0.9256773591041565
best SSIM epoch:  207


15it [19:17, 77.74s/it]

best L1:  0.14469029009342194
best L1 epoch:  246
best SSIM: 0.8972741961479187
best SSIM epoch:  193


16it [20:47, 81.60s/it]

best L1:  0.1494869887828827
best L1 epoch:  332
best SSIM: 0.9454710483551025
best SSIM epoch:  280


17it [22:13, 82.91s/it]

best L1:  0.15201601386070251
best L1 epoch:  437
best SSIM: 0.9445827603340149
best SSIM epoch:  313


18it [23:39, 83.83s/it]

best L1:  0.15429428219795227
best L1 epoch:  265
best SSIM: 0.933418333530426
best SSIM epoch:  210


19it [25:10, 85.85s/it]

best L1:  0.11366625875234604
best L1 epoch:  486
best SSIM: 0.9388250112533569
best SSIM epoch:  325


20it [26:37, 86.33s/it]

best L1:  0.17198723554611206
best L1 epoch:  240
best SSIM: 0.9282688498497009
best SSIM epoch:  214


21it [28:05, 86.68s/it]

best L1:  0.16007517278194427
best L1 epoch:  371
best SSIM: 0.9145439267158508
best SSIM epoch:  260


22it [29:35, 87.83s/it]

best L1:  0.15563474595546722
best L1 epoch:  491
best SSIM: 0.933197557926178
best SSIM epoch:  258


23it [31:03, 87.72s/it]

best L1:  0.16020707786083221
best L1 epoch:  333
best SSIM: 0.9211586117744446
best SSIM epoch:  208


24it [32:30, 87.63s/it]

best L1:  0.13470907509326935
best L1 epoch:  499
best SSIM: 0.9240269064903259
best SSIM epoch:  349


25it [33:58, 87.58s/it]

best L1:  0.16163942217826843
best L1 epoch:  325
best SSIM: 0.9231560230255127
best SSIM epoch:  197


26it [35:16, 84.66s/it]

best L1:  0.11768034845590591
best L1 epoch:  488
best SSIM: 0.9375031590461731
best SSIM epoch:  192


27it [36:43, 85.50s/it]

best L1:  0.1336156278848648
best L1 epoch:  215
best SSIM: 0.9490060210227966
best SSIM epoch:  193


28it [38:14, 87.01s/it]

best L1:  0.14107903838157654
best L1 epoch:  494
best SSIM: 0.9387004971504211
best SSIM epoch:  186


29it [39:41, 87.15s/it]

best L1:  0.1591196358203888
best L1 epoch:  298
best SSIM: 0.9206281304359436
best SSIM epoch:  257


30it [41:08, 87.24s/it]

best L1:  0.13218453526496887
best L1 epoch:  498
best SSIM: 0.9343234300613403
best SSIM epoch:  143


31it [42:26, 84.42s/it]

best L1:  0.14713479578495026
best L1 epoch:  202
best SSIM: 0.8869951963424683
best SSIM epoch:  151


32it [43:54, 85.32s/it]

best L1:  0.1392560452222824
best L1 epoch:  258
best SSIM: 0.9357326030731201
best SSIM epoch:  75


33it [45:12, 83.07s/it]

best L1:  0.15384557843208313
best L1 epoch:  197
best SSIM: 0.8997040390968323
best SSIM epoch:  182


34it [46:42, 85.30s/it]

best L1:  0.1711570769548416
best L1 epoch:  363
best SSIM: 0.9207457304000854
best SSIM epoch:  300


35it [48:13, 86.88s/it]

best L1:  0.16085898876190186
best L1 epoch:  472
best SSIM: 0.929293692111969
best SSIM epoch:  267


36it [49:40, 87.05s/it]

best L1:  0.1485922932624817
best L1 epoch:  314
best SSIM: 0.9195379614830017
best SSIM epoch:  71


37it [51:11, 88.09s/it]

best L1:  0.1462310552597046
best L1 epoch:  277
best SSIM: 0.921417772769928
best SSIM epoch:  274


38it [52:41, 88.82s/it]

best L1:  0.1427035927772522
best L1 epoch:  479
best SSIM: 0.9132251739501953
best SSIM epoch:  201


39it [54:09, 88.41s/it]

best L1:  0.11034882068634033
best L1 epoch:  472
best SSIM: 0.9352681636810303
best SSIM epoch:  353


40it [55:31, 86.77s/it]

best L1:  0.14645756781101227
best L1 epoch:  91
best SSIM: 0.9248132705688477
best SSIM epoch:  92


41it [56:40, 81.40s/it]

best L1:  0.1696406602859497
best L1 epoch:  248
best SSIM: 0.8603676557540894
best SSIM epoch:  234


42it [58:11, 84.14s/it]

best L1:  0.20748218894004822
best L1 epoch:  205
best SSIM: 0.8907877206802368
best SSIM epoch:  198


43it [59:29, 82.26s/it]

best L1:  0.156113862991333
best L1 epoch:  348
best SSIM: 0.9179937243461609
best SSIM epoch:  160


44it [1:00:56, 83.82s/it]

best L1:  0.14013414084911346
best L1 epoch:  498
best SSIM: 0.9349681735038757
best SSIM epoch:  389


45it [1:02:24, 84.92s/it]

best L1:  0.1556912064552307
best L1 epoch:  249
best SSIM: 0.9504179358482361
best SSIM epoch:  248


46it [1:03:50, 85.24s/it]

best L1:  0.1553000509738922
best L1 epoch:  184
best SSIM: 0.9343366026878357
best SSIM epoch:  152


47it [1:05:20, 86.84s/it]

best L1:  0.15103453397750854
best L1 epoch:  77
best SSIM: 0.9214394688606262
best SSIM epoch:  81


48it [1:06:38, 84.14s/it]

best L1:  0.18310458958148956
best L1 epoch:  126
best SSIM: 0.8591356873512268
best SSIM epoch:  112


49it [1:08:06, 85.14s/it]

best L1:  0.12948082387447357
best L1 epoch:  461
best SSIM: 0.9298696517944336
best SSIM epoch:  252


50it [1:09:36, 86.77s/it]

best L1:  0.17948895692825317
best L1 epoch:  137
best SSIM: 0.8464381098747253
best SSIM epoch:  124


51it [1:10:54, 84.09s/it]

best L1:  0.13182635605335236
best L1 epoch:  190
best SSIM: 0.9040126800537109
best SSIM epoch:  165


52it [1:12:21, 85.10s/it]

best L1:  0.14341306686401367
best L1 epoch:  252
best SSIM: 0.931383490562439
best SSIM epoch:  222


53it [1:13:49, 85.80s/it]

best L1:  0.1434725970029831
best L1 epoch:  475
best SSIM: 0.9316346645355225
best SSIM epoch:  257


54it [1:15:16, 86.28s/it]

best L1:  0.13599349558353424
best L1 epoch:  311
best SSIM: 0.9335466027259827
best SSIM epoch:  191


55it [1:16:44, 86.64s/it]

best L1:  0.16729936003684998
best L1 epoch:  313
best SSIM: 0.9062685370445251
best SSIM epoch:  262


56it [1:18:11, 86.88s/it]

best L1:  0.1597057729959488
best L1 epoch:  325
best SSIM: 0.9029427766799927
best SSIM epoch:  229


57it [1:19:37, 86.61s/it]

best L1:  0.12829375267028809
best L1 epoch:  496
best SSIM: 0.9480673670768738
best SSIM epoch:  177


58it [1:21:05, 86.86s/it]

best L1:  0.17200276255607605
best L1 epoch:  275
best SSIM: 0.9034411311149597
best SSIM epoch:  112


59it [1:22:32, 87.04s/it]

best L1:  0.1568334996700287
best L1 epoch:  318
best SSIM: 0.9243959784507751
best SSIM epoch:  213


60it [1:24:00, 87.17s/it]

best L1:  0.14688006043434143
best L1 epoch:  491
best SSIM: 0.9259443879127502
best SSIM epoch:  276


61it [1:25:27, 87.26s/it]

best L1:  0.15750569105148315
best L1 epoch:  416
best SSIM: 0.9373766779899597
best SSIM epoch:  332


62it [1:26:50, 85.97s/it]

best L1:  0.13959167897701263
best L1 epoch:  398
best SSIM: 0.9423765540122986
best SSIM epoch:  312


63it [1:28:16, 85.96s/it]

best L1:  0.17734327912330627
best L1 epoch:  244
best SSIM: 0.9134417176246643
best SSIM epoch:  244


64it [1:29:40, 85.52s/it]

best L1:  0.12993702292442322
best L1 epoch:  364
best SSIM: 0.9228552579879761
best SSIM epoch:  245


65it [1:31:08, 86.09s/it]

best L1:  0.13109123706817627
best L1 epoch:  288
best SSIM: 0.9590017199516296
best SSIM epoch:  246


66it [1:32:35, 86.51s/it]

best L1:  0.14107780158519745
best L1 epoch:  366
best SSIM: 0.92348313331604
best SSIM epoch:  313


67it [1:34:03, 86.83s/it]

best L1:  0.1575208455324173
best L1 epoch:  201
best SSIM: 0.9149660468101501
best SSIM epoch:  178
