In [1]:
DEV_MODE = False

if DEV_MODE:
    class TestArg():
        def __init__(self, file, factor):
            self.file = file
            self.factor = factor
            self.model_folder = "dev_mode"
            self.input_img = False
            self.kspace_mse = False
            self.kspace_boundary = False
            self.kbound_weight = 0.0001*5
            self.kbound_lower = 0.95
            self.kspace_mse_shape = 'u'
            self.kbound_outer_layer = 'False' 
            
    args = TestArg(file="/work/users/c/c/cctsai/data/BCP_sample/357_T1w_MPR_NORM_3.npy", factor=2)
else:
    from argparse import ArgumentParser
    parser = ArgumentParser()
    parser.add_argument("--file")
    parser.add_argument("--factor")
    
    #basic config
    parser.add_argument("--model_folder")
    parser.add_argument("--input_img", action="store_true", default=False)
    parser.add_argument("--double_arm", default="False")
    
    #loss config
    parser.add_argument("--SSIM", action="store_true", default=False)
    parser.add_argument("--kspace_mse", action="store_true", default=False)
    parser.add_argument("--kspace_mse_shape", default='v')
    
    #kspace boundary
    parser.add_argument("--kspace_boundary", action="store_true", default=False)
    parser.add_argument("--kbound_weight", default = 0.0001*5)
    parser.add_argument("--kbound_lower", default = 0.95)
    parser.add_argument("--kbound_outer_layer", default = "True")
    
    # tv_weight
    parser.add_argument("--tv_weight", default = 0.001)
    
    args = parser.parse_args()
    
_name_ = args.file.split("/")[-1].split(".")[0]
model_folder = args.model_folder

factor = float(args.factor)  #1.25, 1.5, 1.75, 2  
tv_weight = float(args.tv_weight) #0.001
    
import os, sys, glob
sys.path.append("./utils")
res_dir = f"/work/users/c/c/cctsai/res/{model_folder}/{_name_}/{int(factor*100)}"

import numpy as np
import torch
import torch.optim
import torch.nn.functional as F
from utils_unet3D_ver3 import *
from utils_kspace_torch import *
from utils.models import *
from utils.prep import *
from utils.sr_common import *

from skimage.metrics import peak_signal_noise_ratio as compare_psnr
from skimage.metrics import structural_similarity as compare_ssim

torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark =True
dtype = torch.cuda.FloatTensor
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

for new_dir in [res_dir, f"{res_dir}/model", f"{res_dir}/log"]:
    if not os.path.exists(new_dir):
        os.makedirs(new_dir)

import datetime
current_time = datetime.datetime.now().time()
print_log = lambda msg: write_log(msg, f'{res_dir}/log/age{_name_}_factor{factor*100}_{current_time}.txt') # from prep.py

print_log(f"Task name: {_name_}, factor={factor}")

# save the argument in the output txt
arguments_dict = vars(args)
arguments_string = "\n".join([f"{key}: {value}" for key, value in arguments_dict.items()])
print_log(arguments_string)

usage: ipykernel_launcher.py [-h] [--file FILE] [--factor FACTOR]
                             [--model_folder MODEL_FOLDER] [--input_img]
                             [--double_arm DOUBLE_ARM] [--SSIM] [--kspace_mse]
                             [--kspace_mse_shape KSPACE_MSE_SHAPE]
                             [--kspace_boundary]
                             [--kbound_weight KBOUND_WEIGHT]
                             [--kbound_lower KBOUND_LOWER]
                             [--kbound_outer_layer KBOUND_OUTER_LAYER]
ipykernel_launcher.py: error: unrecognized arguments: -f /nas/longleaf/home/cctsai/.local/share/jupyter/runtime/kernel-b1113fed-11a7-4788-9a19-626110e87df5.json


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [2]:
modality = 'T1'
twoarm_ratio = 1/5

input_depth = 1
INPUT =     'noise'
OPT_OVER =  'net'

LR = 0.0001
kspace_weight = 0.0001 #stronger from 0.0001
kspace_boundary_weight = float(args.kbound_weight)
KBOUND_LOWER = float(args.kbound_lower)
KBOUND_OUTER_LAYER = args.kbound_outer_layer


lambda_reg = 0.00001
OPTIMIZER = 'adam'

num_iter = 10000
reg_noise_std = 0.03

if args.input_img:
    reg_noise_std = 0

downsampler = lambda img: torch_sinc_downsampler_3D(img, factor)

In [3]:
img_HR_np = np.load(args.file)
if DEV_MODE:
    img_HR_np = img_HR_np[:64, :64, :64]

img_HR_np = nor(img_HR_np)
img_LR_tensor = downsampler(img_HR_np)
img_LR_tensor = torch.clamp(img_LR_tensor, min=0.0, max=1.0)

imgs = {'orig_np':img_HR_np,
        'LR_np': img_LR_tensor.numpy()
       }

# image kspace
img_LR_var =  torch.clone(img_LR_tensor).type(dtype)
img_LR_kspace = to_k_space(img_LR_var)

# shape
img_LR_kspace.shape

torch.Size([32, 32, 32])

In [4]:
#net
net = UNet3D(in_channels=input_depth, out_channels=1, trilinear=True, conv_residual=True).cuda()

# inputs
side=img_HR_np.shape[-1]
if not args.input_img:
    net_input = get_noise_3d(input_depth, INPUT, (side,side,side)).type(dtype).detach()
else:
    upsampled = sinc_upsampler(img_LR_tensor, side, factor)
    net_input = upsampled.unsqueeze(0).unsqueeze(0).type(dtype)

net_input2 = downsampler(net_input[0][0]).unsqueeze(0).unsqueeze(1)
print_log(f'input1 shape: {net_input.shape}')
print_log(f'input2 shape: {net_input2.shape}')

input1 shape: torch.Size([1, 1, 64, 64, 64])
input2 shape: torch.Size([1, 1, 32, 32, 32])


In [5]:
def gen_ksapce_mask(shape, size):
    if shape.lower()=='v':
        kspace_mask = torch.zeros((size,size,size))
        for half_size in range(1, size//2+1):
            if half_size ==size:
                value = half_size//2
            else:
                value = half_size
            kspace_mask = fill_3D_shell(kspace_mask, half_size, value)

        kspace_mask = kspace_mask/kspace_mask.max()
    
    elif shape.lower()=='u':
        kspace_mask = torch.zeros((size,size,size))
        mag = np.linspace(0,5,size//2)
        kweight = np.power(2, mag)  
        for half_size in range(1, size//2+1):
            if half_size ==size:
                value = kweight[half_size-1]/2
            else:
                value = kweight[half_size-1]
            kspace_mask = fill_3D_shell(kspace_mask, half_size, value)

        kspace_mask = kspace_mask/kspace_mask.max()
        
    elif shape.lower()=='i':
        kspace_mask = torch.ones((size,size,size))
    return kspace_mask

In [6]:
size = img_LR_var.shape[-1]

if args.kspace_mse or args.kspace_boundary:
    assert args.kspace_mse_shape.lower() in ['u','v','i']
    
    kspace_mask = gen_ksapce_mask(args.kspace_mse_shape, size).type(dtype)
    kspace_mask2 = central_crop_3D(kspace_mask, factor)
    #kspace_mask2 = kspace_mask2/kspace_mask2.max()
    kspace_mask2 = kspace_mask2.type(dtype)

    def kspace_loss(z_pred, z_true, kspace_mask=kspace_mask):
        # = out_LR_kspace, img_LR_kspace
        z_diff = z_pred - z_true

        # Calculate the absolute value of the difference (real, imag) and square each element
        z_abs_sq = torch.square(torch.abs(z_diff))
        z_abs_sq = z_abs_sq*kspace_mask

        # Calculate the mean squared error (MSE) loss in complex number space
        mse_loss = torch.mean(z_abs_sq)
        return mse_loss

# def sum_abs(tensor):
#     return torch.sum(torch.abs(tensor))


def charbonnier_loss(prediction, target, epsilon=1e-6):
    error = torch.sqrt((prediction - target)**2 + epsilon**2)
    loss = torch.mean(error)
    return loss


def kboundary_loss_fn(LR_kspace, HR_kspace, factor, kbound_lower, kbound_outer_layer):
    bd_idx = 16 if DEV_MODE else 48
    
    HR_kspace = HR_kspace/(factor**3)
    HR_kspace_abs = torch.abs(HR_kspace)
    LR_kspace_abs = torch.abs(LR_kspace)

    HR_shell_mean = get_3D_shell(HR_kspace_abs, bd_idx).mean()
    LR_shell_mean = get_3D_shell(LR_kspace_abs, bd_idx-1).mean()

    # CHANGED TO MSE, FROM L1 LOSS
    diff = torch.pow(HR_shell_mean - LR_shell_mean, 2)
    first_loss = diff if((HR_shell_mean > 0.99*LR_shell_mean) or (HR_shell_mean < kbound_lower*LR_shell_mean)) else 0
    loss = first_loss

    if kbound_outer_layer=='True':
        #loss_weight_decay_rate = 0.95
        for layer in range(bd_idx+1, bd_idx+2):
            previous_shell_mean = get_3D_shell(HR_kspace_abs, layer-1).mean()
            current_shell_mean = get_3D_shell(HR_kspace_abs, layer).mean()
            diff = torch.pow(previous_shell_mean - current_shell_mean, 2)
            loss += diff if ((current_shell_mean > 0.99*previous_shell_mean) or (current_shell_mean < kbound_lower*previous_shell_mean)) else 0

    return loss, first_loss

In [7]:
#total loss
#total_loss = main_loss1 + kspace_mse1 + twoarm_ratio*(main_loss2 + kspace_mse2) + kspace_boundary_loss + lambda_reg*l2_reg

def closure():
    global i, net_input, net_input2

    if reg_noise_std > 0:
        net_input = net_input_saved + (noise.normal_() * reg_noise_std)
        net_input2 = net_input2_saved + (noise2.normal_() * reg_noise_std)

    out_HR = net(net_input)
    out_HR_kspace = to_k_space(out_HR[0][0])
    out_LR_kspace = central_crop_3D(out_HR_kspace, factor)/(factor**3)
    out_LR = inv_fft(out_LR_kspace)
    out_LR = torch.clamp(out_LR, min=0.0, max=1.0)

    main_loss1 = charbonnier_loss(out_LR, img_LR_var)
    total_loss = main_loss1

#     if args.double_arm=="True":
#         out_HR2 = net(net_input2)
#         out_HR2 = torch.clamp(out_HR2, min=0.0, max=1.0)
#         out_HR2_kspace = to_k_space(out_HR2[0][0])
#         main_loss2 = charbonnier_loss(out_HR2, img_LR_var)
#         total_loss = total_loss + (twoarm_ratio*main_loss2)

    # Compute L2 regularization term for the last layer
    l2_reg = 0
    for param in net.parameters():
        l2_reg += torch.norm(param)
    total_loss = total_loss+lambda_reg*l2_reg
    
    tvloss = tv_weight*TVLoss3D(out_HR)
    total_loss += tvloss
    
    """kspace loss"""
    #kspace_mse
#     if args.kspace_mse:
#         kspace_mse1 = kspace_weight*kspace_loss(out_LR_kspace, img_LR_kspace)
#         total_loss = total_loss + kspace_mse1 
        
#         if args.double_arm=="True":
#             kspace_mse2 = kspace_weight*kspace_loss(out_HR2_kspace, img_LR_kspace)
#             total_loss = total_loss + (twoarm_ratio*kspace_mse2)
        
#     if args.kspace_boundary:
#         kboundary_total_loss, kboundary_first_loss = kboundary_loss_fn(img_LR_kspace, out_HR_kspace, factor, KBOUND_LOWER, KBOUND_OUTER_LAYER)
#         kspace_boundary_loss = kspace_boundary_weight*kboundary_total_loss
#         total_loss = total_loss + kspace_boundary_loss
    total_loss.backward()

    if i % 100 == 0:
        # kspace_replacement
        dip_img = out_HR[0][0].detach().cpu().numpy()
        dip_img_central_replacement = central_replacement_3d(imgs['orig_np'], dip_img, factor=factor)

        # psnr
        psnr_LR = volumetric_psnr(imgs['LR_np'], out_LR)
        psnr_HR = volumetric_psnr(imgs['orig_np'], dip_img)
        psnr_kr = volumetric_psnr(imgs['orig_np'], dip_img_central_replacement)
        ssim_index = compare_ssim(imgs['orig_np'], dip_img, channel_axis=None)
        ssim_index_kr = compare_ssim(imgs['orig_np'], dip_img_central_replacement, channel_axis=None)

       #1001 newly add 
        item_list = [("Iteration %05d", i), ("PSNR_LR %.3f", psnr_LR), ("PSNR_HR %.3f", psnr_HR),
             ("PSNR_KR %.3f", psnr_kr), ("SSIM %.3f", ssim_index), ("SSIM_KR %.3f", ssim_index_kr),
             ("Loss %.5f", total_loss), ("img_mse %.5f", main_loss1), 
             ("reg_term %.5f", lambda_reg*l2_reg)]
        
        item_list += [("tvloss %.5f", tvloss)]
#         if args.double_arm=="True":
#             psnr_LR2 = volumetric_psnr(imgs['LR_np'], out_HR2[0][0]) #  HR2 is actually at the dimension of LR
#             item_list += [("sec_img_mse %.5f", twoarm_ratio*main_loss2), ("PSNR_LR2 %.3f", psnr_LR2)]
        
        if args.kspace_mse:
            item_list += [("kspace_mse %.5f", kspace_mse1)]
#             if args.double_arm=="True":
#                 item_list += [("sec_kspace_mse %.5f", twoarm_ratio*kspace_mse2)]
                
        if args.kspace_boundary:
            item_list += [("kspace_boundary %.5f", kspace_boundary_loss)]

        output_line = save_info_dict(item_list, info_path=f'{res_dir}/log/info_dict_factor{factor*100}')
        print_log(output_line)
        
#         print_log('Iteration %05d  PSNR_LR %.3f  PSNR_HR %.3f  PSNR_KR %.3f | SSIM %.3f  SSIM_KR %.3f | Loss %.5f  Img_mse %.5f  kspace_mse %.5f  kspace_boundary %.5f' % 
#                   (i, psnr_LR, psnr_HR, psnr_kr, ssim_index, ssim_index_kr, total_loss, main_loss1, kspace_mse1, kspace_boundary_loss))
 
        if i>2000:
            out_HR_np = torch_to_np(out_HR)
            np.save(f'{res_dir}/HR_volume_{i}.npy', out_HR_np[0])
        if i%1000 ==0:
            torch.save(net.state_dict(), f'{res_dir}/model/epoch{i}_model_weights.pth')
    i += 1
    return total_loss

In [None]:
#epo_cont_ =
#model_weights = torch.load(f'{res_dir}/epo{epo_cont_}model_weights.pth')
#net.load_state_dict(model_weights)

net_input_saved = net_input.detach().clone()
net_input2_saved = net_input2.detach().clone()
noise = net_input.detach().clone()
noise2 = net_input2.detach().clone()

i = 0
p = get_params(OPT_OVER, net, net_input)
optimize(OPTIMIZER, p, closure, LR, num_iter)

Starting optimization with ADAM


  ssim_index = compare_ssim(imgs['orig_np'], dip_img, channel_axis=None)
  ssim_index_kr = compare_ssim(imgs['orig_np'], dip_img_central_replacement, channel_axis=None)


Iteration 00000  PSNR_LR 9.305  PSNR_HR 9.316  PSNR_KR 24.334  SSIM 0.039  SSIM_KR 0.387  Loss 0.31295  img_mse 0.31126  reg_term 0.00146  tvloss 0.00022  
Iteration 00100  PSNR_LR 17.509  PSNR_HR 17.402  PSNR_KR 34.847  SSIM 0.213  SSIM_KR 0.965  Loss 0.12126  img_mse 0.11978  reg_term 0.00146  tvloss 0.00002  
Iteration 00200  PSNR_LR 19.557  PSNR_HR 19.369  PSNR_KR 35.275  SSIM 0.268  SSIM_KR 0.971  Loss 0.09687  img_mse 0.09539  reg_term 0.00145  tvloss 0.00002  
Iteration 00300  PSNR_LR 20.886  PSNR_HR 20.554  PSNR_KR 34.761  SSIM 0.305  SSIM_KR 0.971  Loss 0.08062  img_mse 0.07915  reg_term 0.00145  tvloss 0.00002  
Iteration 00400  PSNR_LR 22.701  PSNR_HR 22.283  PSNR_KR 36.163  SSIM 0.357  SSIM_KR 0.976  Loss 0.06600  img_mse 0.06452  reg_term 0.00145  tvloss 0.00002  
Iteration 00500  PSNR_LR 24.162  PSNR_HR 23.490  PSNR_KR 35.979  SSIM 0.401  SSIM_KR 0.976  Loss 0.05571  img_mse 0.05424  reg_term 0.00145  tvloss 0.00002  
Iteration 00600  PSNR_LR 25.587  PSNR_HR 24.684  PSNR_

In [None]:
out_HR = net(net_input)
out_LR = downsampler(out_HR[0][0])
dip_img = out_HR[0][0].detach().cpu().numpy()
dip_img_central_replacement = central_replacement_3d(img_HR_np, dip_img, factor=factor)

print_log('Final iter performance')
print_log(str(volumetric_psnr(img_HR_np, out_HR[0][0])))
print_log(str(volumetric_psnr(img_HR_np, dip_img_central_replacement)))