In [1]:
import time
import random
import numpy as np
import pandas as pd
import sys
import pickle
import h5py
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import torch
import learn2learn as l2l
from tqdm import tqdm
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import StepLR, CosineAnnealingLR
from functions.fftc import fft2c_new as fft2c
from functions.fftc import ifft2c_new as ifft2c
from functions.math import complex_abs, complex_mul, complex_conj
# The corase reconstruction is the rss of the zerofilled multi-coil kspaces
# after inverse FT.
from functions.data.transforms import UnetDataTransform_TTTpaper_fixMask, center_crop, scale_rss, normalize_separate_over_ch, rss_torch
# Import a torch.utils.data.Dataset class that takes a list of data examples, a path to those examples
# a data transform and outputs a torch dataset.
from functions.data.mri_dataset import SliceDataset
# Unet architecture as nn.Module
from functions.models.unet import Unet
# Function that returns a MaskFunc object either for generatig random or equispaced masks
from functions.data.subsample import create_mask_for_mask_type
# Implementation of SSIMLoss
from functions.training.losses import SSIMLoss
from functions.helper import evaluate2c_imagepair
### after you install bart 0.7.00 from https://mrirecon.github.io/bart/, import it as follows
sys.path.insert(0,'/cheng/bart-0.7.00/python/')
os.environ['TOOLBOX_PATH'] = "/cheng/bart-0.7.00/"
import bart


plt.rcParams.update({"text.usetex": True, "font.family": "serif", "font.serif": ["Computer Modern Roman"]})

colors = ['b','r','k','g','m','c','tab:brown','tab:orange','tab:pink','tab:gray','tab:olive','tab:purple']

markers = ["v","o","^","1","*",">","d","<","s","P","X"]
FONTSIZE = 22

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# seed
SEED = 1
random.seed(SEED)
np.random.seed(SEED)
torch.cuda.manual_seed(SEED)
torch.manual_seed(SEED)


<torch._C.Generator at 0x7f03a00fd6f0>

### Load the data

In [2]:
# data path
path_test = '/cheng/metaMRI/metaMRI/data_dict/TTT_paper/TTT_brain_val.yaml'
path_test_sensmaps = '/cheng/metaMRI/metaMRI/data_dict/TTT_paper/sensmap_brain_val/'

# path_test = '/cheng/metaMRI/metaMRI/data_dict/TTT_paper/TTT_brain_test_100.yaml'
# path_test_sensmaps = '/cheng/metaMRI/metaMRI/data_dict/TTT_paper/sensmap_brain_test/'

# mask function and data transform
mask_function = create_mask_for_mask_type(mask_type_str = 'random', self_sup = False, 
                    center_fraction = 0.08, acceleration = 4.0, acceleration_total = 3.0)

data_transform = UnetDataTransform_TTTpaper_fixMask('multicoil', mask_func = mask_function, use_seed=True)

# training dataset and data loader
testset = SliceDataset(dataset = path_test, path_to_dataset='', 
                path_to_sensmaps = path_test_sensmaps, provide_senmaps=True, 
                challenge="multicoil", transform = data_transform, use_dataset_cache=True)

test_dataloader = torch.utils.data.DataLoader(dataset = testset, batch_size = 1, 
                shuffle = False, generator = torch.Generator().manual_seed(1), pin_memory = False)


In [5]:
checkpoint_path = '/cheng/metaMRI/metaMRI/save/E11.12_joint(l1_1e-5)P_T300_300epoch/E11.12_joint(l1_1e-5)P_T300_300epoch_E300_best.pth'
# '/cheng/metaMRI/metaMRI/save/E11.12_sup(l1_1e-5)P_T300_150epoch/E11.12_sup(l1_1e-5)P_T300_150epoch_E150_best.pth'
model = Unet(in_chans=2, out_chans=2, chans=64, num_pool_layers=4, drop_prob=0.0)
model.load_state_dict(torch.load(checkpoint_path))
model = model.to(device)

### without TTT

In [31]:
loss_l1_history_=[]
loss_ssim_history_=[]
ssim_fct = SSIMLoss()
l1_loss = torch.nn.L1Loss(reduction='sum')

path_mask = '/cheng/metaMRI/ttt_for_deep_learning_cs/unet/test_data/anatomy_shift/mask2d'
with open(path_mask,'rb') as fn:
    mask2d = pickle.load(fn)
mask = torch.tensor(mask2d[0]).unsqueeze(0).unsqueeze(0).unsqueeze(-1)
mask = mask.to(device)

for iter, batch in tqdm(enumerate(test_dataloader)): 
    kspace, sens_maps, sens_maps_conj, _, fname, slice_num = batch
    kspace = kspace.squeeze(0).to(device)
    sens_maps = sens_maps.squeeze(0).to(device)
    sens_maps_conj = sens_maps_conj.squeeze(0).to(device)

    # input k space
    input_kspace = kspace * mask + 0.0
    input_kspace = input_kspace.to(device)

    # gt image: x
    target_image = ifft2c(kspace)
    # rss combine
    target_image = rss_torch(target_image)
    # sensmap combine
    # target_image = complex_mul(target_image, sens_maps_conj).sum(dim=0, keepdim=False)

    # A†y
    train_inputs = ifft2c(input_kspace) #shape: coils,height,width,2
    train_inputs = rss_torch(train_inputs)
    # train_inputs = complex_mul(train_inputs, sens_maps_conj).sum(dim=0, keepdim=False) #shape: height,width,2
    train_inputs = torch.moveaxis( train_inputs , -1, 0 ) # move complex channels to channel dimension
    # normalize input to have zero mean and std one
    train_inputs, mean, std = normalize_separate_over_ch(train_inputs, eps=1e-11)
    # fθ(A†y)
    train_outputs = model(train_inputs.unsqueeze(0))
    train_outputs = train_outputs.squeeze(0) * std + mean
    
    # supervised loss [x, fθ(A†y)]
    # [1, 2, 768, 392] -> [1, 768, 392]
    output_image_1c = complex_abs(torch.moveaxis(train_outputs.squeeze(0), 0, -1 )).unsqueeze(0)
    train_targets_1c = complex_abs(target_image).unsqueeze(0)
    #loss_sup = l1_loss(train_outputs_1c, train_targets_1c) / torch.sum(torch.abs(train_targets_1c))
    
    # center crop
    min_size = min(train_targets_1c.shape[-2:])
    crop = torch.Size([min_size, min_size])
    train_targets_1c = center_crop( train_targets_1c, crop )
    output_image_1c = center_crop( output_image_1c, crop )
    
    # SSIM = 1 - loss
    loss_ssim = 1 - ssim_fct(output_image_1c, train_targets_1c, data_range = train_targets_1c.max().unsqueeze(0)).item()

    #loss_l1_history_.append(loss_l1)
    loss_ssim_history_.append(loss_ssim)

100it [04:48,  2.88s/it]


In [32]:
#print("Testing average L1 loss: ", sum(loss_l1_history_) / len(loss_l1_history_))
print("Testing average SSIM loss: ", sum(loss_ssim_history_) / len(loss_ssim_history_))

Testing average SSIM loss:  0.8495896404981613


### Test-time training

In [3]:
def normalize(im1,im2):
    # im1: ground truth
    # im2: reconstruction
    im2 = (im2-im2.mean()) / im2.std()
    im2 *= im1.std()
    im2 += im1.mean()
    return im1,im2

In [8]:
best_loss_l1_history=[]
best_loss_l1_index_history=[]
best_loss_ssim_history=[]
best_loss_ssim_index_history=[]
ssim_fct = SSIMLoss()
l1_loss = torch.nn.L1Loss(reduction='sum')
adapt_lr = 0.00001
TTT_epoch = 1500
# each data point TTT
optimizer = torch.optim.Adam(model.parameters(),lr=adapt_lr)


path_mask = '/cheng/metaMRI/ttt_for_deep_learning_cs/unet/test_data/anatomy_shift/mask2d'
with open(path_mask,'rb') as fn:
    mask2d = pickle.load(fn)
mask = torch.tensor(mask2d[0]).unsqueeze(0).unsqueeze(0).unsqueeze(-1)
mask = mask.to(device)

for iter, batch in tqdm(enumerate(test_dataloader)): 
    print('Date point: ', iter+1)
    kspace, sens_maps, sens_maps_conj, _, fname, slice_num = batch
    kspace = kspace.squeeze(0).to(device)
    sens_maps = sens_maps.squeeze(0).to(device)
    sens_maps_conj = sens_maps_conj.squeeze(0).to(device)

    # input k space
    input_kspace = kspace * mask + 0.0
    input_kspace = input_kspace.to(device)

    # gt image: x
    target_image = ifft2c(kspace)
    # rss combine
    target_image = rss_torch(target_image)
    # sensmap combine
    # target_image = complex_mul(target_image, sens_maps_conj).sum(dim=0, keepdim=False)
    train_targets_1c = complex_abs(target_image).unsqueeze(0)
    # center crop
    min_size = min(train_targets_1c.shape[-2:])
    crop = torch.Size([min_size, min_size])
    train_targets_1c = center_crop( train_targets_1c, crop )

    # A†y
    train_inputs = ifft2c(input_kspace) #shape: coils,height,width,2
    train_inputs = rss_torch(train_inputs)
    # train_inputs = complex_mul(train_inputs, sens_maps_conj).sum(dim=0, keepdim=False) #shape: height,width,2
    train_inputs = torch.moveaxis( train_inputs , -1, 0 ) # move complex channels to channel dimension
    # normalize input to have zero mean and std one
    train_inputs, mean, std = normalize_separate_over_ch(train_inputs, eps=1e-11)

    # model re-init
    model.load_state_dict(torch.load(checkpoint_path))
    
    loss_l1_history = []
    loss_ssim_history = []
    self_loss_history = []

    for iteration in range(TTT_epoch): 
        ####### inference ######
        # fθ(A†y)
        model_output = model(train_inputs.unsqueeze(0))
        model_output = model_output.squeeze(0) * std + mean
        model_output = torch.moveaxis(model_output.unsqueeze(0), 1, -1 )
        ###### training ######
        # S fθ(A†y)
        output_sens_image = complex_mul(model_output, sens_maps.squeeze(0))
        # FS fθ(A†y)
        Fimg = fft2c(output_sens_image)
        # MFS fθ(A†y) = A fθ(A†y)
        Fimg_forward = Fimg * mask
        # consistency loss [y, Afθ(A†y)]
        loss_self = l1_loss(Fimg_forward, input_kspace) / torch.sum(torch.abs(input_kspace))
    
        optimizer.zero_grad()
        loss_self.backward()
        optimizer.step()
        #train_loss += loss.item()
        #print('TTT loss: ',loss_self.item())
        self_loss_history.append(loss_self.item())

        ###### evaluation ######
        output_image_1c = complex_abs(model_output.squeeze(0)).unsqueeze(0)
        #loss_sup = l1_loss(train_outputs_1c, train_targets_1c) / torch.sum(torch.abs(train_targets_1c))
        # center crop
        output_image_1c = center_crop( output_image_1c, crop )
    
        # SSIM = 1 - loss
        train_targets_1c,output_image_1c = normalize(train_targets_1c,output_image_1c)
        loss_ssim = 1 - ssim_fct(output_image_1c, train_targets_1c, data_range = train_targets_1c.max().unsqueeze(0)).item()
        print(f"Epoch {iteration+1}/{TTT_epoch}, Loss: {loss_ssim:.4f}", end='\r')

        loss_ssim_history.append(loss_ssim)



    # best results for an example
    # best_loss_l1 = min(loss_l1_history)
    # print('best L1: ', best_loss_l1)
    # best_loss_l1_epoch = np.argmin(loss_l1_history)
    # print('best L1 epoch: ', best_loss_l1_epoch)
    best_loss_ssim = max(loss_ssim_history)
    print('best SSIM:', best_loss_ssim)
    best_loss_ssim_epoch = np.argmax(loss_ssim_history)
    print('best SSIM epoch: ', best_loss_ssim_epoch)

    # best_loss_l1_index_history.append(best_loss_l1_epoch)
    best_loss_ssim_index_history.append(best_loss_ssim_epoch)
    # best_loss_l1_history.append(best_loss_l1)
    best_loss_ssim_history.append(best_loss_ssim)


# print("Testing average L1 loss: ", sum(best_loss_l1_history) / len(best_loss_l1_history))
# print("Testing average L1 loss epoch: ", sum(best_loss_l1_index_history) / len(best_loss_l1_index_history))
print("Testing average SSIM loss: ", sum(best_loss_ssim_history) / len(best_loss_ssim_history))
print("Testing average SSIM loss epoch: ", sum(best_loss_ssim_index_history) / len(best_loss_ssim_index_history))

0it [00:00, ?it/s]

Date point:  1


  return compare_ssim(


Epoch 1498/1500, Loss: 0.9030

1it [03:29, 209.23s/it]

best SSIM: 0.90819734: 0.9032
best SSIM epoch:  858
Date point:  2
Epoch 1498/1500, Loss: 0.8936

2it [07:01, 211.29s/it]

best SSIM: 0.90278095: 0.8929
best SSIM epoch:  522
Date point:  3
Epoch 1499/1500, Loss: 0.8591

3it [12:24, 262.18s/it]

best SSIM: 0.8691576s: 0.8580
best SSIM epoch:  241
Date point:  4
Epoch 1499/1500, Loss: 0.7146

4it [19:26, 325.05s/it]

best SSIM: 0.7366221s: 0.7139
best SSIM epoch:  426
Date point:  5
Epoch 1260/1500, Loss: 0.9092

4it [25:03, 375.82s/it]

Epoch 1262/1500, Loss: 0.9097




KeyboardInterrupt: 