In [1]:
from argparse import ArgumentParser
parser = ArgumentParser()
parser.add_argument("--file")
parser.add_argument("--factor")
parser.add_argument("--model_folder") #-> if not specified, create a folder for it
parser.add_argument("--method", default='kcdip') # -> can be chosen

# can be customized
parser.add_argument("--kspace_mse_shape", default='v') # -> can be chosen
parser.add_argument("--kspace_boundary", action="store_true", default=False) #can be chosen
parser.add_argument("--double_arm", default="True") # -> can be chosen
parser.add_argument("--kspace_mse", action="store_true", default=False) #-> fixed

args = parser.parse_args()
    

######  set up ######
configs = dict()

configs['file'] = args.file
configs['model_folder'] = args.model_folder
configs['factor'] = float(args.factor)  #1.25, 1.5, 1.75, 2    
configs['method'] = args.method

import json
with open("default_settings.json", "r") as json_file:
    default_settings = json.load(json_file)

if configs['method'] in ['kcdip','dip','diptv']:
    configs.update(default_settings[configs['method']])

import os, sys, glob
sys.path.append("./utils")

import datetime
import numpy as np
import torch
import torch.optim
import torch.nn.functional as F
from skimage.metrics import structural_similarity as compare_ssim
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark =True
dtype = torch.cuda.FloatTensor
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

###
from utils_unet3D_ver3 import *
from utils_kspace_torch import *
from utils.models import *
from utils.prep import *
from utils.sr_common import *


current_time = datetime.datetime.now().strftime("%H%M%S")
res_dir = f"{configs['model_folder']}/{current_time}"

# save the argument in the output txt
print_log = lambda msg: write_log(msg, f'{res_dir}/log/{current_time}.txt')

args_dict = vars(args)
args_string = "\n".join([f"{key}: {value}" for key, value in args_dict.items()])
print_log(args_string)
print_log("Start!")


for new_dir in [res_dir, f"{res_dir}/model", f"{res_dir}/log"]:
    os.makedirs(new_dir, exist_ok=True)

Task name: 357_T1w_MPR_NORM_3, factor=1.75
file: /work/users/c/c/cctsai/data/BCP_sample/357_T1w_MPR_NORM_3.npy
factor: 1.75
model_folder: dev_mode
input_img: True
double_arm: True
kspace_mse: True
kspace_boundary: False
kbound_weight: 0.0005
kbound_lower: 0.95
kspace_mse_shape: u
kspace_mse_weight: 1.0
kbound_outer_layer: True


In [2]:
TWOARM_RATIO = 1/5 # ratio between unsupervsied and self-supervised
OPT_OVER =  'net'

LR = 0.0001 #learning rate
KSPACE_WEIGHT = 0.0001
KBOUND_WEIGHT = 0.0001*5
KBOUND_LOWER = 0.95
LAMBDA_REG = 0.00001 # net weight regulation
OPTIMIZER = 'adam'
reg_noise_std = 0.03 # random noise regulation


downsampler = lambda img: torch_sinc_downsampler_3D(img, factor)

In [3]:
img_HR_np = np.load(configs["file"])
img_HR_np = nor(img_HR_np)

img_LR_tensor = downsampler(img_HR_np)
img_LR_tensor = torch.clamp(img_LR_tensor, min=0.0, max=1.0)

# given data
imgs = {'orig_np':img_HR_np,
        'LR_np': img_LR_tensor.numpy()
       }

# inputs
img_LR_var =  torch.clone(img_LR_tensor).type(dtype)
img_LR_kspace = to_k_space(img_LR_var)
HR_size = img_HR_np.shape[-1]
LR_size = img_LR_var.shape[-1]

net_input = get_noise_3d(input_depth, INPUT, (HR_size,HR_size,HR_size)).type(dtype).detach()
net_input2 = downsampler(net_input[0][0]).unsqueeze(0).unsqueeze(1)

print_log(f'input1 shape: {net_input.shape}')
print_log(f'input2 shape: {net_input2.shape}')

#net
net = UNet3D(in_channels=1, out_channels=1, trilinear=True, conv_residual=True).cuda()

torch.Size([36, 36, 36])

In [5]:
def gen_ksapce_mask(shape, size):
    if shape.lower()=='v':
        kspace_mask = torch.zeros((size,size,size))
        for half_size in range(1, size//2+1):
            if half_size ==size:
                value = half_size//2
            else:
                value = half_size
            kspace_mask = fill_3D_shell(kspace_mask, half_size, value)

        kspace_mask = kspace_mask/kspace_mask.max()
    
    elif shape.lower()=='u':
        kspace_mask = torch.zeros((size,size,size))
        mag = np.linspace(0,5,size//2)
        kweight = np.power(2, mag)  
        for half_size in range(1, size//2+1):
            if half_size ==size:
                value = kweight[half_size-1]/2
            else:
                value = kweight[half_size-1]
            kspace_mask = fill_3D_shell(kspace_mask, half_size, value)

        kspace_mask = kspace_mask/kspace_mask.max()
        
    elif shape.lower()=='i':
        kspace_mask = torch.ones((size,size,size))
    return kspace_mask

In [6]:
if configs["kspace_mse"] or configs["kspace_boundary"]:
    assert configs["kspace_mse_shape"].lower() in ['u','v','i']
    
    kspace_mask = gen_ksapce_mask(configs["kspace_mse_shape"], LR_size).type(dtype)
    kspace_mask2 = central_crop_3D(kspace_mask, factor)
    kspace_mask2 = kspace_mask2.type(dtype)

    def kspace_loss(z_pred, z_true, kspace_mask=kspace_mask):
        z_diff = z_pred - z_true

        # Calculate the absolute value of the difference (real, imag) and square each element
        z_abs_sq = torch.square(torch.abs(z_diff))
        z_abs_sq = z_abs_sq*kspace_mask

        # Calculate the mean squared error (MSE) loss in complex number space
        mse_loss = torch.mean(z_abs_sq)
        return mse_loss


def charbonnier_loss(prediction, target, epsilon=1e-6):
    error = torch.sqrt((prediction - target)**2 + epsilon**2)
    loss = torch.mean(error)
    return loss


def kboundary_loss_fn(LR_kspace, HR_kspace, factor, kbound_lower):
    bd_idx = LR_kspace.shape[-1]
    
    HR_kspace = HR_kspace/(factor**3)
    HR_kspace_abs = torch.abs(HR_kspace)
    LR_kspace_abs = torch.abs(LR_kspace)

    HR_shell_mean = get_3D_shell(HR_kspace_abs, bd_idx).mean()
    LR_shell_mean = get_3D_shell(LR_kspace_abs, bd_idx-1).mean()

    # CHANGED TO MSE, FROM L1 LOSS
    diff = torch.pow(HR_shell_mean - LR_shell_mean, 2)
    first_loss = diff if((HR_shell_mean > 0.99*LR_shell_mean) or (HR_shell_mean < kbound_lower*LR_shell_mean)) else 0
    loss = first_loss

    return loss, first_loss

In [7]:
def closure():
    global i, net_input, net_input2

    if reg_noise_std > 0:
        net_input = net_input_saved + (noise.normal_() * reg_noise_std)
        net_input2 = net_input2_saved + (noise2.normal_() * reg_noise_std)

    out_HR = net(net_input)
    out_HR_kspace = to_k_space(out_HR[0][0])
    out_LR_kspace = central_crop_3D(out_HR_kspace, factor)/(factor**3)
    out_LR = inv_fft(out_LR_kspace)
    out_LR = torch.clamp(out_LR, min=0.0, max=1.0)

    main_loss1 = charbonnier_loss(out_LR, img_LR_var)
    total_loss = main_loss1

    if configs['double_arm'] == "True":
        out_HR2 = net(net_input2)
        out_HR2 = torch.clamp(out_HR2, min=0.0, max=1.0)
        out_HR2_kspace = to_k_space(out_HR2[0][0])
        main_loss2 = charbonnier_loss(out_HR2, img_LR_var)
        total_loss = total_loss + (TWOARM_RATIO*main_loss2)

    # TV loss
    if configs['method'] == 'diptv':
        tvloss = tv_weight*TVLoss3D(out_HR)
        total_loss += tvloss
    
    # Compute L2 regularization term for the last layer
    l2_reg = 0
    for param in net.parameters():
        l2_reg += torch.norm(param)        
    total_loss = total_loss+LAMBDA_REG*l2_reg
    
    """kspace loss"""
    #kspace_mse
    if configs["kspace_mse"]:
        kspace_mse1 = KSPACE_WEIGHT*kspace_loss(out_LR_kspace, img_LR_kspace)
        total_loss = total_loss + kspace_mse1 
        
        if configs["double_arm"] =="True":
            kspace_mse2 = KSPACE_WEIGHT*kspace_loss(out_HR2_kspace, img_LR_kspace)
            total_loss = total_loss + (TWOARM_RATIO*kspace_mse2)
        
    if configs["kspace_boundary"]:
        kboundary_total_loss, kboundary_first_loss = kboundary_loss_fn(img_LR_kspace, out_HR_kspace, factor, KBOUND_LOWER)
        kspace_boundary_loss = KBOUND_WEIGHT*kboundary_total_loss
        total_loss = total_loss + kspace_boundary_loss
    total_loss.backward()

    if i % 100 == 0:
        # kspace_replacement
        dip_img = out_HR[0][0].detach().cpu().numpy()
        dip_img_central_replacement = central_replacement_3d(imgs['orig_np'], dip_img, factor=factor)

        # psnr
        psnr_LR = volumetric_psnr(imgs['LR_np'], out_LR)
        psnr_HR = volumetric_psnr(imgs['orig_np'], dip_img)
        psnr_kr = volumetric_psnr(imgs['orig_np'], dip_img_central_replacement)
        ssim_index = compare_ssim(imgs['orig_np'], dip_img, channel_axis=None)
        ssim_index_kr = compare_ssim(imgs['orig_np'], dip_img_central_replacement, channel_axis=None)

       #1001 newly add 
        item_list = [("Iteration %05d", i), ("PSNR_LR %.3f", psnr_LR), ("PSNR_HR %.3f", psnr_HR),
             ("PSNR_KR %.3f", psnr_kr), ("SSIM %.3f", ssim_index), ("SSIM_KR %.3f", ssim_index_kr),
             ("Loss %.5f", total_loss), ("img_mse %.5f", main_loss1), 
             ("reg_term %.5f", LAMBDA_REG*l2_reg)]
        
        if configs["double_arm"] =="True":
            psnr_LR2 = volumetric_psnr(imgs['LR_np'], out_HR2[0][0]) #  HR2 is actually at the dimension of LR
            item_list += [("sec_img_mse %.5f", TWOARM_RATIO*main_loss2), ("PSNR_LR2 %.3f", psnr_LR2)]
        
        if configs["kspace_mse"]:
            item_list += [("kspace_mse %.5f", kspace_mse1)]
            if configs["double_arm"]=="True":
                item_list += [("sec_kspace_mse %.5f", TWOARM_RATIO*kspace_mse2)]
                
        if configs["kspace_boundary"]:
            item_list += [("kspace_boundary %.5f", kspace_boundary_loss)]

        output_line = save_info_dict(item_list, info_path=f'{res_dir}/log/info_dict_factor{factor*100}')
        print_log(output_line)
        
 
        if i>2000:
            out_HR_np = torch_to_np(out_HR)
            np.save(f'{res_dir}/HR_volume_{i}.npy', out_HR_np[0])
        if i%1000 ==0:
            torch.save(net.state_dict(), f'{res_dir}/model/epoch{i}_model_weights.pth')
    i += 1
    return total_loss

In [None]:
net_input_saved = net_input.detach().clone()
net_input2_saved = net_input2.detach().clone()
noise = net_input.detach().clone()
noise2 = net_input2.detach().clone()

i = 0
p = get_params(OPT_OVER, net, net_input)
optimize(OPTIMIZER, p, closure, LR, configs['num_iter'])

Starting optimization with ADAM


  ssim_index = compare_ssim(imgs['orig_np'], dip_img, channel_axis=None)
  ssim_index_kr = compare_ssim(imgs['orig_np'], dip_img_central_replacement, channel_axis=None)


Iteration 00000  PSNR_LR 6.387  PSNR_HR 6.917  PSNR_KR 33.519  SSIM 0.071  SSIM_KR 0.920  Loss 0.58058  img_mse 0.44578  reg_term 0.00146  sec_img_mse 0.08530  PSNR_LR2 6.874  kspace_mse 0.04034  sec_kspace_mse 0.00770  
Iteration 00100  PSNR_LR 12.359  PSNR_HR 12.900  PSNR_KR 37.406  SSIM 0.196  SSIM_KR 0.975  Loss 0.27173  img_mse 0.21721  reg_term 0.00146  sec_img_mse 0.04251  PSNR_LR2 12.758  kspace_mse 0.00869  sec_kspace_mse 0.00186  
Iteration 00200  PSNR_LR 13.939  PSNR_HR 14.459  PSNR_KR 37.650  SSIM 0.218  SSIM_KR 0.977  Loss 0.22370  img_mse 0.18002  reg_term 0.00145  sec_img_mse 0.03506  PSNR_LR2 14.315  kspace_mse 0.00592  sec_kspace_mse 0.00124  
Iteration 00300  PSNR_LR 15.439  PSNR_HR 15.940  PSNR_KR 37.482  SSIM 0.238  SSIM_KR 0.978  Loss 0.18703  img_mse 0.15104  reg_term 0.00145  sec_img_mse 0.02938  PSNR_LR2 15.771  kspace_mse 0.00425  sec_kspace_mse 0.00091  
Iteration 00400  PSNR_LR 16.819  PSNR_HR 17.298  PSNR_KR 37.202  SSIM 0.258  SSIM_KR 0.977  Loss 0.15763  i