In [1]:
import time
import random
import numpy as np
import pandas as pd
import sys
import pickle
import h5py
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import torch
import learn2learn as l2l
from tqdm import tqdm
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import StepLR, CosineAnnealingLR
from functions.fftc import fft2c_new as fft2c
from functions.fftc import ifft2c_new as ifft2c
from functions.math import complex_abs, complex_mul, complex_conj
# The coarse reconstruction is the RSS (root-sum-of-squares) of the zero-filled
# multi-coil k-spaces after the inverse Fourier transform.
from functions.data.transforms import UnetDataTransform_sens_TTT, center_crop
# Import a torch.utils.data.Dataset class that takes a list of data examples, a path to those examples
# and a data transform, and outputs a torch dataset.
from functions.data.mri_dataset import SliceDataset
# Unet architecture as nn.Module
from functions.models.unet import Unet
# Function that returns a MaskFunc object for generating either random or equispaced masks
from functions.data.subsample import create_mask_for_mask_type
# Implementation of SSIMLoss
from functions.training.losses import SSIMLoss
from functions.helper import evaluate2c_imagepair
### after you install bart 0.7.00 from https://mrirecon.github.io/bart/, import it as follows
sys.path.insert(0,'/cheng/bart-0.7.00/python/')
os.environ['TOOLBOX_PATH'] = "/cheng/bart-0.7.00/"
import bart


# LaTeX-rendered serif text for all figures (needs a working LaTeX installation).
plt.rcParams.update({
    "text.usetex": True,
    "font.family": "serif",
    "font.serif": ["Computer Modern Roman"],
})

# Shared plot styling: one colour and one marker per plotted curve.
colors = ['b', 'r', 'k', 'g', 'm', 'c',
          'tab:brown', 'tab:orange', 'tab:pink', 'tab:gray', 'tab:olive', 'tab:purple']
markers = ["v", "o", "^", "1", "*", ">", "d", "<", "s", "P", "X"]
FONTSIZE = 22

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Reproducibility: seed Python, NumPy and CUDA, then PyTorch's default generator.
SEED = 1
for seed_rng in (random.seed, np.random.seed, torch.cuda.manual_seed):
    seed_rng(SEED)
torch.manual_seed(SEED)


<torch._C.Generator at 0x7fbd0a0c16f0>

### Load the data

In [2]:
# Test split: a YAML list of slices plus a directory of precomputed coil
# sensitivity maps.
# Previously used alternative split (E11.1), kept for reference:
# path_test = '/cheng/metaMRI/metaMRI/data_dict/E11.1/Q/brain_test_AXT1POST_Skyra_5-8.yaml'
# path_test_sensmaps = '/cheng/metaMRI/metaMRI/data_dict/E11.1/Q/sensmap_test/'
path_test = '/cheng/metaMRI/metaMRI/data_dict/TTT_paper/TTT_brain_test_100.yaml'
path_test_sensmaps = '/cheng/metaMRI/metaMRI/data_dict/TTT_paper/sensmap_brain_test/'

# Random undersampling mask with an 8% fully sampled centre and 4x acceleration.
# NOTE(review): the meaning of `acceleration_total` (and how it interacts with
# `acceleration` when self_sup=False) is defined in create_mask_for_mask_type — confirm there.
mask_function = create_mask_for_mask_type(mask_type_str='random', self_sup=False,
                                          center_fraction=0.08, acceleration=4.0,
                                          acceleration_total=3.0)

# mode='adapt' makes the transform return the tensors the TTT cells unpack
# (images, k-space, masks, sensitivity maps, ...). use_seed=True presumably
# derives a fixed mask per file so runs are comparable — TODO confirm.
data_transform = UnetDataTransform_sens_TTT('multicoil', mask_func=mask_function,
                                            use_seed=True, mode='adapt')

# Test dataset and data loader: one slice per batch, fixed order.
testset = SliceDataset(dataset=path_test, path_to_dataset='',
                       path_to_sensmaps=path_test_sensmaps, provide_senmaps=True,
                       challenge="multicoil", transform=data_transform,
                       use_dataset_cache=True)

test_dataloader = torch.utils.data.DataLoader(dataset=testset, batch_size=1,
                                              shuffle=False,
                                              generator=torch.Generator().manual_seed(SEED),
                                              pin_memory=False)


In [3]:
# Pretrained U-Net to be evaluated and adapted. It has 2 input/output channels
# (presumably real/imaginary parts; the cells below move axis 1 to the end
# and apply complex_abs).
checkpoint_path = '/cheng/metaMRI/metaMRI/save/E11.6_joint(l1_1e-5_P)_T300_150epoch/E11.6_joint(l1_1e-5_P)_T300_150epoch_E150_best.pth'
model = Unet(in_chans=2, out_chans=2, chans=32, num_pool_layers=4, drop_prob=0.0)
# map_location lets a checkpoint that was saved on the GPU load on a CPU-only
# machine (``device`` falls back to the CPU in the config cell).
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model = model.to(device)

### Baseline: pretrained model without TTT

In [4]:
# Baseline: evaluate the pretrained network on every test slice without any
# adaptation. For each slice, record:
#   - the L1 error against target_image, normalised by the target's L1 norm;
#   - SSIM against the centre-cropped ground truth.
# The "loss_ssim" values are SSIM (= 1 - SSIMLoss, higher is better), not a loss.
# Assumes `model` still holds the pretrained weights, i.e. run this before the
# TTT cell, which leaves `model` adapted to the last slice.
loss_l1_history_ = []
loss_ssim_history_ = []
ssim_fct = SSIMLoss()
l1_loss = torch.nn.L1Loss(reduction='sum')

# Pure inference: disabling autograd avoids building a graph for every slice
# (less GPU memory, faster). The computed values are unchanged.
with torch.no_grad():
    for batch in tqdm(test_dataloader):
        (input_image, target_image, ground_truth_image, mean, std, fname, slice_num,
         input_kspace, input_mask, target_kspace, target_mask, sens_maps,
         binary_background_mask) = batch
        # Only the tensors used below are moved to the device.
        input_image = input_image.to(device)
        target_image = target_image.to(device)
        std = std.to(device)
        mean = mean.to(device)
        ground_truth_image = ground_truth_image.to(device)

        # fθ(A†y), mapped back to the original intensity scale
        model_output = model(input_image) * std + mean

        # Supervised L1 [x, fθ(A†y)], normalised by the L1 norm of the target
        loss_l1 = (l1_loss(model_output, target_image) / torch.sum(torch.abs(target_image))).item()

        # SSIM on complex_abs images; both are centre-cropped to a square of the
        # ground truth's smaller side.
        output_image_1c = complex_abs(torch.moveaxis(model_output, 1, -1))
        min_size = min(ground_truth_image.shape[-2:])
        crop = torch.Size([min_size, min_size])
        ground_truth_image = center_crop(ground_truth_image, crop)
        output_image_1c = center_crop(output_image_1c, crop)
        loss_ssim = 1 - ssim_fct(output_image_1c, ground_truth_image,
                                 data_range=ground_truth_image.max().unsqueeze(0)).item()

        loss_l1_history_.append(loss_l1)
        loss_ssim_history_.append(loss_ssim)

100it [00:20,  4.88it/s]


In [6]:
# Average baseline metrics over all test slices.
# loss_ssim_history_ holds SSIM values (1 - SSIMLoss, higher is better), so the
# second line reports an average SSIM, not a loss.
print("Testing average L1 loss: ", sum(loss_l1_history_) / len(loss_l1_history_))
print("Testing average SSIM: ", sum(loss_ssim_history_) / len(loss_ssim_history_))

Testing average L1 loss:  0.28124732211232184
Testing average SSIM loss:  0.7954672402143479


### Test-time training

In [7]:
# Test-time training (TTT).
# For every test slice, restart from the pretrained weights and fine-tune the
# network for TTT_epoch Adam steps. The training signal is a self-supervised
# data-consistency loss that uses only the measured k-space.
# Before each step, the supervised metrics are recorded:
#   - L1 against target_image, normalised;
#   - SSIM against the centre-cropped ground truth.
# Index 0 of each history is therefore the un-adapted pretrained model.
# NOTE: the per-slice "best" values are selected using the ground truth (oracle
# early stopping). They upper-bound what TTT reaches with a fixed or
# ground-truth-free stopping rule.
# Naming: the "*_loss_ssim*" variables hold SSIM (= 1 - SSIMLoss, higher is better).
# After this cell, `model` holds the weights adapted to the LAST slice.
best_loss_l1_history = []
best_loss_l1_index_history = []
best_loss_ssim_history = []
best_loss_ssim_index_history = []
ssim_fct = SSIMLoss()
l1_loss = torch.nn.L1Loss(reduction='sum')
adapt_lr = 0.0001
TTT_epoch = 500

# Read the checkpoint from disk once, not once per slice. load_state_dict()
# copies these tensors into the model's parameters, so the optimizer never
# modifies this dict and it can reset the model for every slice.
pretrained_state = torch.load(checkpoint_path, map_location=device)

for batch in tqdm(test_dataloader):
    (input_image, target_image, ground_truth_image, mean, std, fname, slice_num,
     input_kspace, input_mask, target_kspace, target_mask, sens_maps,
     binary_background_mask) = batch
    input_image = input_image.to(device)
    target_image = target_image.to(device)
    input_kspace = input_kspace.to(device)
    input_mask = input_mask.to(device)
    sens_maps = sens_maps.to(device)
    std = std.to(device)
    mean = mean.to(device)
    ground_truth_image = ground_truth_image.to(device)
    binary_background_mask = binary_background_mask.to(device)

    # Model re-init: each slice is adapted independently from the pretrained
    # weights, with a fresh optimizer (no Adam state carried over).
    model.load_state_dict(pretrained_state)
    optimizer = torch.optim.Adam(model.parameters(), lr=adapt_lr)

    # Quantities that do not change during adaptation are computed once per
    # slice instead of once per iteration.
    min_size = min(ground_truth_image.shape[-2:])
    crop = torch.Size([min_size, min_size])
    ground_truth_image = center_crop(ground_truth_image, crop)
    ssim_data_range = ground_truth_image.max().unsqueeze(0)
    target_l1_norm = torch.sum(torch.abs(target_image))
    kspace_l1_norm = torch.sum(torch.abs(input_kspace))

    loss_l1_history = []
    loss_ssim_history = []
    self_loss_history = []

    for iteration in range(TTT_epoch):
        # fθ(A†y), mapped back to the original intensity scale
        model_output = model(input_image)
        model_output = model_output * std + mean

        # Supervised metrics [x, fθ(A†y)], for monitoring only. They never reach
        # the optimizer, so no autograd graph is built for them.
        with torch.no_grad():
            loss_l1 = (l1_loss(model_output, target_image) / target_l1_norm).item()
            output_image_1c = center_crop(complex_abs(torch.moveaxis(model_output, 1, -1)), crop)
            loss_ssim = 1 - ssim_fct(output_image_1c, ground_truth_image,
                                     data_range=ssim_data_range).item()
        loss_l1_history.append(loss_l1)
        loss_ssim_history.append(loss_ssim)

        # Self-supervised consistency loss [y, A fθ(A†y)] with A = M F S
        model_output = torch.moveaxis(model_output, 1, -1)
        # S fθ(A†y)
        output_sens_image = complex_mul(model_output, sens_maps)
        # F S fθ(A†y)
        Fimg = fft2c(output_sens_image)
        # M F S fθ(A†y) = A fθ(A†y)
        Fimg_forward = Fimg * input_mask
        loss_self = l1_loss(Fimg_forward, input_kspace) / kspace_l1_norm

        optimizer.zero_grad()
        loss_self.backward()
        optimizer.step()
        self_loss_history.append(loss_self.item())

    # Oracle selection of the best iteration for this slice (uses the ground truth).
    best_loss_l1_epoch = int(np.argmin(loss_l1_history))
    best_loss_l1 = loss_l1_history[best_loss_l1_epoch]
    best_loss_ssim_epoch = int(np.argmax(loss_ssim_history))
    best_loss_ssim = loss_ssim_history[best_loss_ssim_epoch]
    print('best L1: ', best_loss_l1)
    print('best L1 epoch: ', best_loss_l1_epoch)
    print('best SSIM:', best_loss_ssim)
    print('best SSIM epoch: ', best_loss_ssim_epoch)

    best_loss_l1_index_history.append(best_loss_l1_epoch)
    best_loss_ssim_index_history.append(best_loss_ssim_epoch)
    best_loss_l1_history.append(best_loss_l1)
    best_loss_ssim_history.append(best_loss_ssim)


# Averages over all slices of the per-slice oracle-best values and iteration indices.
num_slices = len(best_loss_l1_history)
print("Testing average L1 loss: ", sum(best_loss_l1_history) / num_slices)
print("Testing average L1 loss epoch: ", sum(best_loss_l1_index_history) / num_slices)
print("Testing average SSIM: ", sum(best_loss_ssim_history) / num_slices)
print("Testing average SSIM epoch: ", sum(best_loss_ssim_index_history) / num_slices)

0it [00:00, ?it/s]

1it [01:00, 60.85s/it]

best L1:  0.2099648416042328
best L1 epoch:  62
best SSIM: 0.8269193172454834
best SSIM epoch:  55


2it [02:01, 60.69s/it]

best L1:  0.1558871567249298
best L1 epoch:  81
best SSIM: 0.8308753371238708
best SSIM epoch:  81


3it [03:02, 60.64s/it]

best L1:  0.15853410959243774
best L1 epoch:  82
best SSIM: 0.8681472539901733
best SSIM epoch:  60


4it [04:02, 60.65s/it]

best L1:  0.22414475679397583
best L1 epoch:  78
best SSIM: 0.8117135167121887
best SSIM epoch:  63


5it [05:03, 60.68s/it]

best L1:  0.2320133000612259
best L1 epoch:  43
best SSIM: 0.8102740049362183
best SSIM epoch:  43


6it [05:55, 57.83s/it]

best L1:  0.18415094912052155
best L1 epoch:  105
best SSIM: 0.8515225052833557
best SSIM epoch:  71


7it [06:56, 58.73s/it]

best L1:  0.18991127610206604
best L1 epoch:  71
best SSIM: 0.8867068290710449
best SSIM epoch:  49


8it [07:56, 59.34s/it]

best L1:  0.20700393617153168
best L1 epoch:  55
best SSIM: 0.8538332581520081
best SSIM epoch:  41


9it [08:57, 59.58s/it]

best L1:  0.21448153257369995
best L1 epoch:  70
best SSIM: 0.8302619457244873
best SSIM epoch:  50


10it [09:57, 59.74s/it]

best L1:  0.1679583191871643
best L1 epoch:  86
best SSIM: 0.8673034310340881
best SSIM epoch:  65


11it [10:57, 59.85s/it]

best L1:  0.14868514239788055
best L1 epoch:  78
best SSIM: 0.8620362281799316
best SSIM epoch:  62


12it [11:57, 59.91s/it]

best L1:  0.23380343616008759
best L1 epoch:  75
best SSIM: 0.7255272269248962
best SSIM epoch:  35


13it [12:57, 59.92s/it]

best L1:  0.18946310877799988
best L1 epoch:  55
best SSIM: 0.8540046811103821
best SSIM epoch:  52


14it [13:56, 59.84s/it]

best L1:  0.19039054214954376
best L1 epoch:  124
best SSIM: 0.8191182017326355
best SSIM epoch:  109


15it [14:56, 59.83s/it]

best L1:  0.17936040461063385
best L1 epoch:  75
best SSIM: 0.8463544249534607
best SSIM epoch:  59


16it [16:02, 61.75s/it]

best L1:  0.1949545294046402
best L1 epoch:  138
best SSIM: 0.9105905294418335
best SSIM epoch:  93


17it [17:06, 62.25s/it]

best L1:  0.18903107941150665
best L1 epoch:  118
best SSIM: 0.9123648405075073
best SSIM epoch:  84


18it [18:09, 62.57s/it]

best L1:  0.15404941141605377
best L1 epoch:  79
best SSIM: 0.9260188937187195
best SSIM epoch:  70


19it [19:15, 63.64s/it]

best L1:  0.17913246154785156
best L1 epoch:  210
best SSIM: 0.838662326335907
best SSIM epoch:  172


20it [20:22, 64.52s/it]

best L1:  0.2617347836494446
best L1 epoch:  107
best SSIM: 0.8112187385559082
best SSIM epoch:  75


21it [21:27, 64.71s/it]

best L1:  0.22814726829528809
best L1 epoch:  190
best SSIM: 0.8276022672653198
best SSIM epoch:  115


22it [22:33, 65.15s/it]

best L1:  0.20765039324760437
best L1 epoch:  165
best SSIM: 0.896100640296936
best SSIM epoch:  64


23it [23:38, 65.12s/it]

best L1:  0.195331871509552
best L1 epoch:  98
best SSIM: 0.8735427260398865
best SSIM epoch:  61


24it [24:44, 65.20s/it]

best L1:  0.18004119396209717
best L1 epoch:  165
best SSIM: 0.857361376285553
best SSIM epoch:  130


25it [25:49, 65.30s/it]

best L1:  0.22384338080883026
best L1 epoch:  124
best SSIM: 0.836162269115448
best SSIM epoch:  63


26it [26:49, 63.77s/it]

best L1:  0.2332524210214615
best L1 epoch:  224
best SSIM: 0.7369975447654724
best SSIM epoch:  104


27it [27:55, 64.30s/it]

best L1:  0.17542509734630585
best L1 epoch:  143
best SSIM: 0.8781123161315918
best SSIM epoch:  99


28it [29:01, 64.71s/it]

best L1:  0.22579443454742432
best L1 epoch:  168
best SSIM: 0.8022077083587646
best SSIM epoch:  92


29it [30:07, 65.28s/it]

best L1:  0.2221979945898056
best L1 epoch:  158
best SSIM: 0.8503038883209229
best SSIM epoch:  119


30it [31:12, 65.24s/it]

best L1:  0.1954353302717209
best L1 epoch:  173
best SSIM: 0.8356055021286011
best SSIM epoch:  76


31it [32:13, 63.89s/it]

best L1:  0.1626785397529602
best L1 epoch:  59
best SSIM: 0.8599189519882202
best SSIM epoch:  49


32it [33:19, 64.51s/it]

best L1:  0.14683789014816284
best L1 epoch:  69
best SSIM: 0.9276238083839417
best SSIM epoch:  59


33it [34:20, 63.34s/it]

best L1:  0.17497548460960388
best L1 epoch:  85
best SSIM: 0.8774202466011047
best SSIM epoch:  54


34it [35:26, 64.16s/it]

best L1:  0.22621920704841614
best L1 epoch:  142
best SSIM: 0.8513139486312866
best SSIM epoch:  3


35it [36:31, 64.62s/it]

best L1:  0.19791394472122192
best L1 epoch:  124
best SSIM: 0.9026313424110413
best SSIM epoch:  77


36it [37:37, 64.97s/it]

best L1:  0.15918700397014618
best L1 epoch:  94
best SSIM: 0.9099578857421875
best SSIM epoch:  81


37it [38:43, 65.21s/it]

best L1:  0.21316008269786835
best L1 epoch:  117
best SSIM: 0.8454534411430359
best SSIM epoch:  110


38it [39:49, 65.52s/it]

best L1:  0.19320228695869446
best L1 epoch:  220
best SSIM: 0.8514959812164307
best SSIM epoch:  140


39it [40:54, 65.39s/it]

best L1:  0.14604486525058746
best L1 epoch:  128
best SSIM: 0.8920755386352539
best SSIM epoch:  81
