In [2]:
import datetime
import glob
import json
import os
import sys
from argparse import ArgumentParser

######  command-line arguments ######
parser = ArgumentParser()
# Required: without them the run cannot start (float(None) / np.load(None) would fail later).
parser.add_argument("--file", required=True)      # path of the .npy volume passed to np.load
parser.add_argument("--factor", required=True)    # upsampling factor, e.g. 1.25, 1.5, 1.75, 2
# Output root; results go to <model_folder>/<HHMMSS>. If omitted, the f-string below produces
# a directory literally named "None".
parser.add_argument("--model_folder")
parser.add_argument("--method", default='kcdip')  # 'kcdip' / 'dip' / 'diptv' load presets from default_settings.json

# Loss options. For a preset method, values from default_settings.json take precedence;
# these flags are only used for keys the preset does not define.
parser.add_argument("--kspace_mse_shape", default='v')  # 'v', 'u' or 'i'
parser.add_argument("--kspace_boundary", action="store_true", default=False)
# Compared against the *string* "True" downstream.
# NOTE(review): if a preset stores a JSON boolean here, that comparison is always False — confirm.
parser.add_argument("--double_arm", default="True")
parser.add_argument("--kspace_mse", action="store_true", default=False)

args = parser.parse_args()

######  set up ######
configs = dict()
configs['file'] = args.file
configs['model_folder'] = args.model_folder
configs['factor'] = float(args.factor)  # 1.25, 1.5, 1.75, 2
configs['method'] = args.method

with open("default_settings.json", "r") as json_file:
    default_settings = json.load(json_file)

if configs['method'] in ['kcdip', 'dip', 'diptv']:
    configs.update(default_settings[configs['method']])

# Fill any option the preset did not provide (or all of them, for a method without a preset)
# from the CLI so later cells never hit a KeyError; preset values are left untouched.
for _cli_key in ("kspace_mse_shape", "kspace_boundary", "double_arm", "kspace_mse"):
    configs.setdefault(_cli_key, getattr(args, _cli_key))

sys.path.append("./utils")

import numpy as np
import torch
import torch.optim
import torch.nn.functional as F
from skimage.metrics import structural_similarity as compare_ssim
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
dtype = torch.cuda.FloatTensor
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

###
from utils_unet3D_ver3 import *
from utils_kspace_torch import *
from utils.models import *
from utils.prep import *
from utils.sr_common import *

# Later cells use the bare name `factor`; bind it after the star imports so none can shadow it.
factor = configs['factor']

current_time = datetime.datetime.now().strftime("%H%M%S")
res_dir = f"{configs['model_folder']}/{current_time}"

# Create the output tree BEFORE the first print_log call: the log file lives in {res_dir}/log.
for new_dir in [res_dir, f"{res_dir}/model", f"{res_dir}/log"]:
    os.makedirs(new_dir, exist_ok=True)

# save the arguments in the log file
print_log = lambda msg: write_log(msg, f'{res_dir}/log/{current_time}.txt')

args_dict = vars(args)
args_string = "\n".join([f"{key}: {value}" for key, value in args_dict.items()])
print_log(args_string)
print_log("Start!")

Task name: 357_T1w_MPR_NORM_3, factor=2.0
file: /work/users/c/c/cctsai/data/BCP_sample/357_T1w_MPR_NORM_3.npy
factor: 2
model_folder: dev_mode
input_img: True
double_arm: True
kspace_mse: True
kspace_boundary: True
kbound_weight: 0.0005
kbound_lower: 0.95
kspace_mse_shape: u
kspace_mse_weight: 1.0
kbound_outer_layer: True


In [None]:
# ---- optimisation / loss hyper-parameters ----
TWOARM_RATIO = 1 / 5        # ratio between the unsupervised and self-supervised arms
INPUT = 'noise'             # input type handed to get_noise_3d
OPT_OVER = 'net'            # what get_params collects for the optimiser

LR = 1e-4                   # learning rate
# NOTE(review): presets may also define kspace_mse_weight / kbound_weight / kbound_lower;
# the constants below are what closure() actually uses — confirm this is intended.
KSPACE_WEIGHT = 1e-4        # weight of the k-space MSE term
KBOUND_WEIGHT = 1e-4 * 5    # weight of the k-space boundary term
KBOUND_LOWER = 0.95         # lower band factor for the boundary-shell comparison
LAMBDA_REG = 1e-5           # weight of the parameter-norm regulariser
OPTIMIZER = 'adam'
reg_noise_std = 0.03        # std of the fresh noise added to the network inputs each step


def downsampler(img):
    """Sinc-downsample a 3D volume by the module-level `factor` (looked up at call time)."""
    return torch_sinc_downsampler_3D(img, factor)

In [4]:
def chop_cube(arr, off=48):
    """Crop `off` elements from both ends of each of the first three axes of `arr`.

    Args:
        arr: array-like supporting tuple-of-slices indexing (e.g. a NumPy array) with at
            least three axes for the full 3D crop.
        off: number of elements removed from each side of every cropped axis
            (default 48, the value previously hard-coded).

    Returns:
        The cropped view/array. ``off=0`` returns the input extent unchanged; a naive
        ``arr[0:-0]`` slice would instead be empty.

    Raises:
        ValueError: if ``off`` is negative.
    """
    if off < 0:
        raise ValueError(f"off must be non-negative, got {off}")
    # Build explicit (start, stop) slices from the shape so that off=0 keeps the full axis.
    crop = tuple(slice(off, dim - off) for dim in arr.shape[:3])
    return arr[crop]

img_HR_np = np.load(configs["file"])
img_HR_np = chop_cube(img_HR_np)
img_HR_np = nor(img_HR_np)

# image-domain target on the GPU and its k-space
img_HR_var = torch.from_numpy(img_HR_np).type(dtype)
img_HR_kspace = to_k_space(img_HR_var)

# Sizes — only the last axis is read, so this assumes a cubic volume (TODO confirm).
HR_size = img_HR_np.shape[-1]
SR_size = int(img_HR_np.shape[-1] * factor)

# The UNet below is built with in_channels=1, and net_input[0][0] picks a single channel,
# so the noise input must have exactly one channel. (Previously `input_depth` was never bound
# in any visible cell.)
input_depth = 1

net_input = get_noise_3d(input_depth, INPUT, (SR_size, SR_size, SR_size)).type(dtype).detach()
# Second-arm input: the first input's volume downsampled by `factor`, then given
# leading batch and channel axes via two unsqueezes.
net_input2 = downsampler(net_input[0][0]).unsqueeze(0).unsqueeze(1)

print_log(f'input1 shape: {net_input.shape}')
print_log(f'input2 shape: {net_input2.shape}')

# net
filter_n = 64
net = UNet3D(in_channels=input_depth, out_channels=1, filter_n=filter_n, trilinear=True, conv_residual=True).cuda()

torch.Size([96, 96, 96])

In [6]:
def gen_ksapce_mask(shape, size):
    """Build a (size, size, size) k-space weighting mask.

    Args:
        shape: case-insensitive mask profile.
            'v' -> shell ``k`` (k = 1 .. size//2) gets weight ``k``.
            'u' -> shell ``k`` gets weight ``2**mag[k-1]`` with ``mag = linspace(0, 5, size//2)``.
            'i' -> all ones (no weighting).
            For 'v' and 'u' the mask is divided by its maximum, so the largest weight is 1.
        size: edge length of the cubic mask.

    Returns:
        A float torch tensor of shape (size, size, size) on the CPU.

    Raises:
        ValueError: for an unrecognised ``shape`` (previously this surfaced as an
            UnboundLocalError).

    Notes:
        ``fill_3D_shell(mask, half_size, value)`` is a project helper that presumably writes
        ``value`` on the cube shell at half-width ``half_size`` — confirm against its source.
        The original loop had an ``if half_size == size`` branch that halved the outermost
        weight. ``half_size`` never exceeds ``size // 2``, so that branch could not run and has
        been dropped with no change in behavior. NOTE(review): if halving the outermost shell
        was intended, the condition should have been ``half_size == size // 2``.
    """
    key = shape.lower()
    if key == 'i':
        return torch.ones((size, size, size))
    if key not in ('v', 'u'):
        raise ValueError(f"unknown k-space mask shape {shape!r}; expected 'v', 'u' or 'i'")

    n_shells = size // 2
    if key == 'v':
        # Linear ramp: the weight equals the shell index.
        weights = range(1, n_shells + 1)
    else:
        # Exponential ramp from 2**0 to 2**5 across the shells.
        weights = np.power(2, np.linspace(0, 5, n_shells))

    kspace_mask = torch.zeros((size, size, size))
    for half_size, value in zip(range(1, n_shells + 1), weights):
        kspace_mask = fill_3D_shell(kspace_mask, half_size, value)
    return kspace_mask / kspace_mask.max()

In [7]:
if configs["kspace_mse"] or configs["kspace_boundary"]:
    assert configs["kspace_mse_shape"].lower() in ['u', 'v', 'i']

    # kspace_loss multiplies the mask with differences of HR-sized spectra
    # (closure() passes out_HR_kspace and img_HR_kspace), so the mask is built at HR_size.
    # Previously this read `LR_size`, which no visible cell defines.
    kspace_mask = gen_ksapce_mask(configs["kspace_mse_shape"], HR_size).type(dtype)
    # Centre crop of the mask by `factor`. NOTE(review): not used by any visible cell — confirm
    # it is still needed.
    kspace_mask2 = central_crop_3D(kspace_mask, factor)
    kspace_mask2 = kspace_mask2.type(dtype)

    def kspace_loss(z_pred, z_true, kspace_mask=kspace_mask):
        """Mask-weighted mean squared magnitude of the difference between two spectra.

        Args:
            z_pred, z_true: k-space tensors of the same shape as ``kspace_mask``
                (presumably complex FFT output of to_k_space — confirm).
            kspace_mask: per-frequency weights, bound at definition time to the mask above.

        Returns:
            Scalar tensor ``mean(|z_pred - z_true|**2 * kspace_mask)``.
        """
        z_diff = z_pred - z_true

        # |diff|^2 (for complex inputs torch.abs is the modulus), weighted per frequency.
        z_abs_sq = torch.square(torch.abs(z_diff))
        z_abs_sq = z_abs_sq * kspace_mask

        # Mean over all voxels, masked ones included.
        mse_loss = torch.mean(z_abs_sq)
        return mse_loss


def charbonnier_loss(prediction, target, epsilon=1e-6):
    """Charbonnier (smooth L1) loss.

    Returns the mean over all elements of ``sqrt((prediction - target)**2 + epsilon**2)``.
    ``epsilon`` keeps the square root differentiable where prediction == target.
    """
    residual = prediction - target
    return torch.sqrt(residual ** 2 + epsilon ** 2).mean()


def kboundary_loss_fn(LR_kspace, HR_kspace, factor, kbound_lower, *, kbound_upper=0.99):
    """Penalise a k-space boundary-shell magnitude that falls outside a band.

    The parameter names are historical. ``LR_kspace`` is the reference spectrum whose last-axis
    length fixes the shell index. ``HR_kspace`` is the larger predicted spectrum, scaled by
    ``1 / factor**3`` before comparison.

    With ``bd_idx = LR_kspace.shape[-1]``, compare:
        HR = mean |HR_kspace / factor**3| on shell ``bd_idx``
        LR = mean |LR_kspace|             on shell ``bd_idx - 1``
    (shells are taken with the project helper ``get_3D_shell``).

    Args:
        LR_kspace: reference k-space tensor.
        HR_kspace: predicted k-space tensor.
        factor: upsampling factor used for the ``1 / factor**3`` scaling.
        kbound_lower: lower band factor; no penalty while ``HR >= kbound_lower * LR``.
        kbound_upper: upper band factor (keyword-only, default 0.99, the previously hard-coded
            value); no penalty while ``HR <= kbound_upper * LR``.

    Returns:
        ``(loss, first_loss)``, two references to the same value. The value is
        ``(HR - LR)**2`` when HR lies outside ``[kbound_lower * LR, kbound_upper * LR]``,
        and the int ``0`` otherwise.
    """
    bd_idx = LR_kspace.shape[-1]

    HR_kspace_abs = torch.abs(HR_kspace / (factor ** 3))
    LR_kspace_abs = torch.abs(LR_kspace)

    HR_shell_mean = get_3D_shell(HR_kspace_abs, bd_idx).mean()
    LR_shell_mean = get_3D_shell(LR_kspace_abs, bd_idx - 1).mean()

    # Squared error (was L1 in an earlier version).
    diff = torch.pow(HR_shell_mean - LR_shell_mean, 2)
    out_of_band = (HR_shell_mean > kbound_upper * LR_shell_mean) or \
                  (HR_shell_mean < kbound_lower * LR_shell_mean)
    first_loss = diff if out_of_band else 0

    return first_loss, first_loss

In [8]:
def closure():
    """Run one optimisation step: forward both arms, build the total loss, backprop, and log.

    Called repeatedly by ``optimize(...)``. Reads and rebinds the module-level globals ``i``
    (iteration counter), ``net_input`` and ``net_input2`` (re-noised each step). Every 100
    iterations it logs metrics. Past iteration 2000 it also saves the SR volume at those
    logging steps, and every 1000 iterations it saves the model weights.

    Returns:
        The total loss tensor, after ``.backward()`` has been called on it.
    """
    global i, net_input, net_input2
    torch.cuda.empty_cache()

    if reg_noise_std > 0:
        # Fresh noise around the frozen inputs; `noise`/`noise2` are reused buffers refilled in place.
        net_input = net_input_saved + (noise.normal_() * reg_noise_std)
        net_input2 = net_input2_saved + (noise2.normal_() * reg_noise_std)

    # Arm 1: SR prediction -> k-space -> centre crop scaled by 1/factor**3 -> HR-size image.
    out_SR = net(net_input)
    out_SR_kspace = to_k_space(out_SR[0][0])
    out_HR_kspace = central_crop_3D(out_SR_kspace, factor) / (factor ** 3)
    out_HR = torch.clamp(inv_fft(out_HR_kspace), min=0.0, max=1.0)

    main_loss1 = charbonnier_loss(out_HR, img_HR_var)
    total_loss = main_loss1

    double_arm = configs["double_arm"] == "True"
    if double_arm:
        # Arm 2: the network applied to the downsampled input, compared directly with the target.
        out_HR2 = torch.clamp(net(net_input2), min=0.0, max=1.0)
        out_HR2_kspace = to_k_space(out_HR2[0][0])
        main_loss2 = charbonnier_loss(out_HR2, img_HR_var)
        # TWOARM_RATIO is the constant defined in the hyper-parameter cell (`twoarm_ratio` was undefined).
        total_loss = total_loss + (TWOARM_RATIO * main_loss2)

    # Regulariser: sum of the (unsquared) L2 norms of ALL network parameters, not just the last layer.
    l2_reg = sum(torch.norm(param) for param in net.parameters())
    total_loss = total_loss + LAMBDA_REG * l2_reg

    # k-space MSE
    if configs["kspace_mse"]:
        kspace_mse1 = KSPACE_WEIGHT * kspace_loss(out_HR_kspace, img_HR_kspace)
        total_loss = total_loss + kspace_mse1

        if double_arm:
            kspace_mse2 = KSPACE_WEIGHT * kspace_loss(out_HR2_kspace, img_HR_kspace)
            total_loss = total_loss + (TWOARM_RATIO * kspace_mse2)

    # k-space boundary term. The call matches kboundary_loss_fn's 4-argument signature; an extra
    # (and undefined) KBOUND_OUTER_LAYER argument would raise TypeError.
    if configs["kspace_boundary"]:
        kboundary_total_loss, kboundary_first_loss = kboundary_loss_fn(
            img_HR_kspace, out_SR_kspace, factor, KBOUND_LOWER)
        kspace_boundary_loss = KBOUND_WEIGHT * kboundary_total_loss
        total_loss = total_loss + kspace_boundary_loss

    total_loss.backward()

    if i % 100 == 0:
        psnr_hR = volumetric_psnr(img_HR_np, out_HR)
        # Note: the "img_mse" labels are historical; the image term is the Charbonnier loss.
        item_list = [("Iteration %05d", i), ("PSNR_hR %.3f", psnr_hR),
                     ("Loss %.5f", total_loss), ("img_mse %.5f", main_loss1),
                     ("reg_term %.5f", LAMBDA_REG * l2_reg)]

        if double_arm:
            psnr_hR2 = volumetric_psnr(img_HR_np, out_HR2[0][0])
            item_list += [("sec_img_mse %.5f", TWOARM_RATIO * main_loss2), ("PSNR_hR2 %.3f", psnr_hR2)]

        if configs["kspace_mse"]:
            item_list += [("kspace_mse %.5f", kspace_mse1)]
            if double_arm:
                item_list += [("sec_kspace_mse %.5f", TWOARM_RATIO * kspace_mse2)]

        if configs["kspace_boundary"]:
            item_list += [("kspace_boundary %.5f", kspace_boundary_loss)]

        output_line = save_info_dict(item_list, info_path=f'{res_dir}/log/info_dict_factor{factor*100}')
        print_log(output_line)

        if i > 2000:
            np.save(f'{res_dir}/SR_volume_{i}.npy', torch_to_np(out_SR)[0])
        if i % 1000 == 0:
            torch.save(net.state_dict(), f'{res_dir}/model/epoch{i}_model_weights.pth')
    i += 1
    return total_loss

In [None]:
# Frozen copies of both inputs: closure() adds fresh noise around these every step.
net_input_saved, net_input2_saved = (t.detach().clone() for t in (net_input, net_input2))
# Scratch buffers with the inputs' shape/dtype/device; closure() refills them with normal_().
noise, noise2 = (t.detach().clone() for t in (net_input, net_input2))

i = 0  # iteration counter, advanced by closure()
p = get_params(OPT_OVER, net, net_input)
# NOTE(review): `num_iter` is not bound in any visible cell — confirm where it is defined.
optimize(OPTIMIZER, p, closure, LR, num_iter)

Starting optimization with ADAM
Iteration 00000  PSNR_hR 11.633  Loss 8.71903  img_mse 0.25997  reg_term 0.00146  sec_img_mse 0.05155  PSNR_hR2 11.749  kspace_mse 0.27104  sec_kspace_mse 0.08181  kspace_boundary 8.05321  
Iteration 00100  PSNR_hR 19.170  Loss 0.48350  img_mse 0.10582  reg_term 0.00146  sec_img_mse 0.02304  PSNR_hR2 16.752  kspace_mse 0.14008  sec_kspace_mse 0.03447  kspace_boundary 0.17863  
