In [3]:
import time
import random
import numpy as np
import pandas as pd
import sys
import pickle
import h5py
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import torch
import learn2learn as l2l
from tqdm import tqdm
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import StepLR, CosineAnnealingLR
from functions.fftc import fft2c_new as fft2c
from functions.fftc import ifft2c_new as ifft2c
from functions.math import complex_abs, complex_mul, complex_conj
# The coarse reconstruction is the RSS (root-sum-of-squares) of the zero-filled
# multi-coil k-spaces after the inverse Fourier transform.
from functions.data.transforms import UnetDataTransform_sens_TTT, center_crop, scale
# Import a torch.utils.data.Dataset class that takes a list of data examples, a path to those examples
# a data transform and outputs a torch dataset.
from functions.data.mri_dataset import SliceDataset
# Unet architecture as nn.Module
from functions.models.unet import Unet
# Function that returns a MaskFunc object for generating either random or equispaced masks
from functions.data.subsample import create_mask_for_mask_type
# Implementation of SSIMLoss
from functions.training.losses import SSIMLoss
from functions.helper import evaluate2c_imagepair
### after you install bart 0.7.00 from https://mrirecon.github.io/bart/, import it as follows
sys.path.insert(0,'/cheng/bart-0.7.00/python/')
os.environ['TOOLBOX_PATH'] = "/cheng/bart-0.7.00/"
import bart


# Render all figure text through LaTeX using a serif (Computer Modern) face.
plt.rcParams.update({
    "text.usetex": True,
    "font.family": "serif",
    "font.serif": ["Computer Modern Roman"],
})

# Colour and marker lists plus a font-size constant
# (presumably shared by later plotting cells).
colors = [
    'b', 'r', 'k', 'g', 'm', 'c',
    'tab:brown', 'tab:orange', 'tab:pink', 'tab:gray', 'tab:olive', 'tab:purple',
]
markers = ["v", "o", "^", "1", "*", ">", "d", "<", "s", "P", "X"]
FONTSIZE = 22

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed every random number generator used in this notebook with the same value.
SEED = 1
for seed_rng in (random.seed, np.random.seed, torch.cuda.manual_seed):
    seed_rng(SEED)
# Kept as the final expression: the cell displays the returned torch.Generator.
torch.manual_seed(SEED)


<torch._C.Generator at 0x7f1cbe2796d0>

### Load the data

In [4]:
# Test split description (YAML) and the directory holding its sensitivity maps.
# Alternative test set, kept for reference:
# path_test = '/cheng/metaMRI/metaMRI/data_dict/E11.1/Q/brain_test_AXT1POST_Skyra_5-8.yaml'
# path_test_sensmaps = '/cheng/metaMRI/metaMRI/data_dict/E11.1/Q/sensmap_test/'
path_test = '/cheng/metaMRI/metaMRI/data_dict/TTT_paper/TTT_brain_test_100.yaml'
path_test_sensmaps = '/cheng/metaMRI/metaMRI/data_dict/TTT_paper/sensmap_brain_test/'

# Random under-sampling mask and the multi-coil transform applied to each slice.
mask_function = create_mask_for_mask_type(
    mask_type_str='random',
    self_sup=False,
    center_fraction=0.08,
    acceleration=4.0,
    acceleration_total=3.0,
)

data_transform = UnetDataTransform_sens_TTT(
    'multicoil',
    mask_func=mask_function,
    use_seed=True,
    mode='adapt',
)

# Test dataset and its loader: one slice per batch, fixed order.
testset = SliceDataset(
    dataset=path_test,
    path_to_dataset='',
    path_to_sensmaps=path_test_sensmaps,
    provide_senmaps=True,
    challenge="multicoil",
    transform=data_transform,
    use_dataset_cache=True,
)

test_dataloader = torch.utils.data.DataLoader(
    dataset=testset,
    batch_size=1,
    shuffle=False,
    generator=torch.Generator().manual_seed(1),
    pin_memory=False,
)


In [5]:
# Pre-trained U-Net weights that serve as the starting point for every experiment below.
checkpoint_path = '/cheng/metaMRI/metaMRI/save/E11.7_joint(l1_CA-1e-3-4_P)_T300_150epoch/E11.7_joint(l1_CA-1e-3-4_P)_T300_150epoch_E87_best.pth'
model = Unet(in_chans=2, out_chans=2, chans=32, num_pool_layers=4, drop_prob=0.0)
# map_location remaps the stored tensors onto `device`, so a checkpoint saved from
# a GPU run still loads when the notebook falls back to the CPU.
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model = model.to(device)

### without TTT

In [4]:
# Baseline: score the pre-trained model on every test slice without any adaptation.
loss_l1_history_ = []
# NOTE: despite the name, this list stores 1 - SSIMLoss (higher is better).
loss_ssim_history_ = []
ssim_fct = SSIMLoss()
l1_loss = torch.nn.L1Loss(reduction='sum')

# No parameter is updated in this cell, so autograd is switched off: the forward
# pass then does not record a graph that would never be used.
# NOTE(review): the model is left in whatever train/eval mode it is currently in,
# exactly as before — confirm whether model.eval() is wanted for this baseline.
with torch.no_grad():
    # `batch_idx` replaces the former loop name `iter`, which shadowed the builtin.
    for batch_idx, batch in tqdm(enumerate(test_dataloader)):
        (input_image, target_image, ground_truth_image, mean, std, fname, slice_num,
         input_kspace, input_mask, target_kspace, target_mask, sens_maps,
         binary_background_mask) = batch
        # Only the tensors that are actually used below are moved to the device.
        input_image = input_image.to(device)
        target_image = target_image.to(device)
        std = std.to(device)
        mean = mean.to(device)
        ground_truth_image = ground_truth_image.to(device)

        # fθ(A†y), mapped back with the per-sample mean and std
        model_output = model(input_image)
        model_output = model_output * std + mean

        # supervised loss [x, fθ(A†y)]
        # L1, normalised by the L1 norm of the target
        loss_l1 = (l1_loss(model_output, target_image) / torch.sum(torch.abs(target_image))).item()

        # SSIM = 1 - loss, computed on the magnitude image
        output_image_1c = complex_abs(torch.moveaxis(model_output, 1, -1))
        # center crop both images to the largest centred square
        min_size = min(ground_truth_image.shape[-2:])
        crop = torch.Size([min_size, min_size])
        ground_truth_image = center_crop(ground_truth_image, crop)
        output_image_1c = center_crop(output_image_1c, crop)
        loss_ssim = 1 - ssim_fct(output_image_1c, ground_truth_image, data_range = ground_truth_image.max().unsqueeze(0)).item()

        loss_l1_history_.append(loss_l1)
        loss_ssim_history_.append(loss_ssim)

100it [00:20,  4.88it/s]


In [6]:
# Report the mean of each per-slice metric collected by the baseline evaluation.
# The second value is the SSIM score (1 - SSIM loss), so higher is better.
for label, history in (
    ("Testing average L1 loss: ", loss_l1_history_),
    ("Testing average SSIM loss: ", loss_ssim_history_),
):
    print(label, sum(history) / len(history))

Testing average L1 loss:  0.28124732211232184
Testing average SSIM loss:  0.7954672402143479


### Test-time training

In [6]:
# Test-time training (TTT): for every test slice the model is reset to the
# checkpoint and fine-tuned for TTT_epoch steps on a self-supervised k-space
# consistency loss. L1 and SSIM against the reference are logged before each step.
# NOTE: the reported "best" L1 / SSIM per slice is picked over all iterations
# using the reference images, i.e. with access to the ground truth.
best_loss_l1_history = []
best_loss_l1_index_history = []
best_loss_ssim_history = []
best_loss_ssim_index_history = []
ssim_fct = SSIMLoss()
l1_loss = torch.nn.L1Loss(reduction='sum')
adapt_lr = 0.0001
TTT_epoch = 500

# Read the checkpoint from disk once instead of once per slice. load_state_dict
# copies the values into the model's own parameters, so this dict stays untouched
# by the optimisation and can be reused for every re-initialisation.
checkpoint_state = torch.load(checkpoint_path, map_location=device)

# `batch_idx` replaces the former loop name `iter`, which shadowed the builtin.
for batch_idx, batch in tqdm(enumerate(test_dataloader)):
    (input_image, target_image, ground_truth_image, mean, std, fname, slice_num,
     input_kspace, input_mask, target_kspace, target_mask, sens_maps,
     binary_background_mask) = batch
    # Only the tensors that are actually used below are moved to the device.
    input_image = input_image.to(device)
    target_image = target_image.to(device)
    input_kspace = input_kspace.to(device)
    input_mask = input_mask.to(device)
    sens_maps = sens_maps.to(device)
    std = std.to(device)
    mean = mean.to(device)
    ground_truth_image = ground_truth_image.to(device)

    # The reference quantities do not change during adaptation, so they are
    # prepared once per slice rather than in every TTT iteration.
    min_size = min(ground_truth_image.shape[-2:])
    crop = torch.Size([min_size, min_size])
    ground_truth_image = center_crop(ground_truth_image, crop)
    ssim_data_range = ground_truth_image.max().unsqueeze(0)
    target_l1_norm = torch.sum(torch.abs(target_image))

    # model re-init
    model.load_state_dict(checkpoint_state)

    # each data point TTT: fresh optimizer state for every slice
    optimizer = torch.optim.Adam(model.parameters(), lr=adapt_lr)

    loss_l1_history = []
    loss_ssim_history = []
    self_loss_history = []

    for iteration in range(TTT_epoch):
        # fθ(A†y), mapped back with the per-sample mean and std
        model_output = model(input_image)
        model_output = model_output * std + mean

        # Monitoring only: these metrics never feed the optimizer, so they are
        # evaluated without recording an autograd graph.
        with torch.no_grad():
            # supervised loss [x, fθ(A†y)]
            # L1, normalised by the L1 norm of the target
            loss_l1 = (l1_loss(model_output, target_image) / target_l1_norm).item()
            # SSIM = 1 - loss, on the centre-cropped magnitude image
            output_image_1c = complex_abs(torch.moveaxis(model_output, 1, -1))
            output_image_1c = center_crop(output_image_1c, crop)
            loss_ssim = 1 - ssim_fct(output_image_1c, ground_truth_image, data_range = ssim_data_range).item()
        loss_l1_history.append(loss_l1)
        loss_ssim_history.append(loss_ssim)

        # self-supervised loss
        # scale normalization
        # NOTE(review): input_kspace is overwritten here, so the factor compounds
        # across iterations — confirm that `scale` is meant to be applied to the
        # already rescaled k-space rather than to the original measurement.
        scale_factor = scale(input_kspace.squeeze(0), model)
        input_kspace = scale_factor * input_kspace

        # fθ(A†y)
        model_output = torch.moveaxis(model_output, 1, -1)
        # S fθ(A†y)
        output_sens_image = complex_mul(model_output, sens_maps)
        # FS fθ(A†y)
        Fimg = fft2c(output_sens_image)
        # MFS fθ(A†y) = A fθ(A†y)
        Fimg_forward = Fimg * input_mask
        # consistency loss [y, Afθ(A†y)]
        loss_self = l1_loss(Fimg_forward, input_kspace) / torch.sum(torch.abs(input_kspace))

        optimizer.zero_grad()
        loss_self.backward()
        optimizer.step()
        self_loss_history.append(loss_self.item())

    # Best value of each metric over the TTT iterations and the step it occurred at.
    # Index i refers to the model after i optimisation steps.
    best_loss_l1 = min(loss_l1_history)
    print('best L1: ', best_loss_l1)
    best_loss_l1_epoch = np.argmin(loss_l1_history)
    print('best L1 epoch: ', best_loss_l1_epoch)
    best_loss_ssim = max(loss_ssim_history)
    print('best SSIM:', best_loss_ssim)
    best_loss_ssim_epoch = np.argmax(loss_ssim_history)
    print('best SSIM epoch: ', best_loss_ssim_epoch)

    best_loss_l1_index_history.append(best_loss_l1_epoch)
    best_loss_ssim_index_history.append(best_loss_ssim_epoch)
    best_loss_l1_history.append(best_loss_l1)
    best_loss_ssim_history.append(best_loss_ssim)


# Averages over all test slices; the "SSIM loss" values are SSIM scores (1 - loss).
print("Testing average L1 loss: ", sum(best_loss_l1_history) / len(best_loss_l1_history))
print("Testing average L1 loss epoch: ", sum(best_loss_l1_index_history) / len(best_loss_l1_index_history))
print("Testing average SSIM loss: ", sum(best_loss_ssim_history) / len(best_loss_ssim_history))
print("Testing average SSIM loss epoch: ", sum(best_loss_ssim_index_history) / len(best_loss_ssim_index_history))

1it [01:09, 69.02s/it]

best L1:  0.2031061202287674
best L1 epoch:  0
best SSIM: 0.8499593138694763
best SSIM epoch:  6


2it [03:22, 106.67s/it]

best L1:  0.16225752234458923
best L1 epoch:  0
best SSIM: 0.8566301465034485
best SSIM epoch:  2


3it [05:48, 124.95s/it]

best L1:  0.17915204167366028
best L1 epoch:  0
best SSIM: 0.8703818321228027
best SSIM epoch:  7


4it [08:15, 133.70s/it]

best L1:  0.20116591453552246
best L1 epoch:  0
best SSIM: 0.846823513507843
best SSIM epoch:  11


5it [10:43, 138.59s/it]

best L1:  0.21164408326148987
best L1 epoch:  0
best SSIM: 0.845305323600769
best SSIM epoch:  15


6it [12:54, 136.09s/it]

best L1:  0.22145023941993713
best L1 epoch:  0
best SSIM: 0.8213953971862793
best SSIM epoch:  11


7it [15:21, 139.57s/it]

best L1:  0.21510517597198486
best L1 epoch:  0
best SSIM: 0.8729645013809204
best SSIM epoch:  4


8it [17:48, 141.99s/it]

best L1:  0.21809324622154236
best L1 epoch:  0
best SSIM: 0.8461397886276245
best SSIM epoch:  10


9it [20:15, 143.65s/it]

best L1:  0.21747466921806335
best L1 epoch:  0
best SSIM: 0.8439318537712097
best SSIM epoch:  13


10it [22:42, 144.77s/it]

best L1:  0.1904328465461731
best L1 epoch:  0
best SSIM: 0.8541547656059265
best SSIM epoch:  7


11it [25:10, 145.55s/it]

best L1:  0.17379258573055267
best L1 epoch:  0
best SSIM: 0.8653651475906372
best SSIM epoch:  2


12it [27:37, 146.13s/it]

best L1:  0.2131778746843338
best L1 epoch:  1
best SSIM: 0.8267037272453308
best SSIM epoch:  19


13it [30:05, 146.57s/it]

best L1:  0.21087972819805145
best L1 epoch:  0
best SSIM: 0.85866779088974
best SSIM epoch:  10


14it [32:32, 146.89s/it]

best L1:  0.1916617900133133
best L1 epoch:  0
best SSIM: 0.8700779676437378
best SSIM epoch:  7


15it [35:00, 147.10s/it]

best L1:  0.190070241689682
best L1 epoch:  0
best SSIM: 0.8410977125167847
best SSIM epoch:  11


16it [37:48, 153.39s/it]

best L1:  0.21558724343776703
best L1 epoch:  0
best SSIM: 0.8964644074440002
best SSIM epoch:  13


17it [40:27, 155.10s/it]

best L1:  0.2441447675228119
best L1 epoch:  0
best SSIM: 0.8810756802558899
best SSIM epoch:  13
