In [None]:
from argparse import ArgumentParser

# Command-line interface. This notebook is meant to be exported and run as a script:
# parse_args() reads sys.argv.
parser = ArgumentParser()
# Path to the high-resolution volume (loaded with np.load below). Required: there is no default,
# and a missing value would otherwise only fail later inside np.load(None).
parser.add_argument("--file", required=True)
# Super-resolution factor, e.g. 1.25, 1.5, 1.75, 2. Parsed as float here so a missing or
# malformed value is reported by argparse instead of as a TypeError from float(None).
parser.add_argument("--factor", type=float, required=True)
# Output root; a timestamped run folder is created inside it.
# NOTE(review): when omitted this stays None and the run folder becomes "None/<HHMMSS>".
parser.add_argument("--model_folder")  # -> if not specified, create a folder for it
parser.add_argument("--method", default='kcdip')  # -> can be chosen

# can be customized
parser.add_argument("--kspace_mse_shape", default='v')  # -> can be chosen
parser.add_argument("--kspace_boundary", action="store_true", default=False)  # can be chosen
# NOTE(review): this is a plain string option, not a bool flag ("False" is a non-empty string).
parser.add_argument("--double_arm", default="True")  # -> can be chosen
parser.add_argument("--kspace_mse", action="store_true", default=False)  # -> fixed

args = parser.parse_args()
    

######  set up ######
# Gather the run settings parsed from the command line into one plain dict
# that the later cells read from.
configs = dict(
    file=args.file,
    model_folder=args.model_folder,
    factor=float(args.factor),  # 1.25, 1.5, 1.75, 2
    method=args.method,
)

import os, sys, glob

# Set before numpy/torch are imported: this variable is consulted by the Intel OpenMP
# runtime, which those libraries may load at import time, so assigning it afterwards
# can be too late to take effect.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
# Make the helper modules imported below importable (presumably they live in ./utils).
sys.path.append("./utils")

import datetime
import numpy as np
import torch
import torch.optim
import torch.nn.functional as F
from skimage.metrics import structural_similarity as compare_ssim

# Let cuDNN pick the fastest convolution algorithms for the (fixed) input shapes.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
# Tensor type used with .type(dtype) below to move tensors onto the GPU.
dtype = torch.cuda.FloatTensor

###
# NOTE(review): these star imports supply many names used later (e.g. nor, write_log,
# asym_UNet3D, optimize); explicit imports would make their origin traceable.
from utils_unet3D_ver3 import *
from utils_kspace_torch import *
from utils.models import *
from utils.prep import *
from utils.sr_common import *


# Each run writes into its own sub-folder of the output root, named by wall-clock time.
# NOTE(review): the name is only HHMMSS, so runs started at the same time on different
# days share (and, with exist_ok=True, silently reuse) the same folder.
current_time = datetime.datetime.now().strftime("%H%M%S")
res_dir = f"{configs['model_folder']}/{current_time}"

# Create the run folders *before* the first log write: the log file lives in
# `{res_dir}/log`, which does not exist yet on a fresh run.
for new_dir in [res_dir, f"{res_dir}/model", f"{res_dir}/log"]:
    os.makedirs(new_dir, exist_ok=True)


def print_log(msg):
    """Write `msg` to this run's log file `{res_dir}/log/{current_time}.txt` via `write_log`."""
    write_log(msg, f'{res_dir}/log/{current_time}.txt')


# save the argument in the output txt
args_dict = vars(args)
args_string = "\n".join(f"{key}: {value}" for key, value in args_dict.items())
print_log(args_string)
print_log("Start!")

In [None]:
# ----- optimisation hyper-parameters -----
OPT_OVER = 'net'        # passed to get_params() when building the parameter list

LR = 0.0001             # learning rate given to optimize()
LAMBDA_REG = 0.00001    # weight of the parameter-norm penalty added in closure()
OPTIMIZER = 'adam'

num_iter = 8200         # number of optimisation iterations
reg_noise_std = 0       # std of the noise added to the network input each step (0 = none)

# Super-resolution factor from the command line. Bound explicitly here because the
# downsamplers below, the network constructor and closure() all read the global
# `factor`, which was otherwise never assigned in this notebook (it only worked if a
# stale kernel variable happened to exist).
factor = configs['factor']


def downsampler(img):
    """Low-resolution version of `img` at `factor` (via torch_sinc_downsampler_3D)."""
    return torch_sinc_downsampler_3D(img, factor)


def downsampler2(img):
    """Low-resolution version of `img` at `factor * factor` (via torch_sinc_downsampler_3D)."""
    return torch_sinc_downsampler_3D(img, factor*factor)

In [None]:
# ----- data: load the HR volume and synthesise two lower-resolution versions -----
img_HR_np = nor(np.load(configs['file']))

# First-level low-resolution volume (the training target), clamped to [0, 1].
img_LR_tensor = torch.clamp(downsampler(img_HR_np), min=0.0, max=1.0)

# Second-level low-resolution volume (the network input), clamped to [0, 1].
img_LR2_tensor = torch.clamp(downsampler2(img_HR_np), min=0.0, max=1.0)

# Numpy references used by the evaluation metrics inside closure().
imgs = dict(orig_np=img_HR_np, LR_np=img_LR_tensor.numpy())

# Copies converted to `dtype`: the LR target, and the LR2 input with
# batch and channel axes prepended.
img_LR_var = img_LR_tensor.clone().type(dtype)
net_input = img_LR2_tensor.clone()[None, None].type(dtype)

# asymmetrical UNet for upsampling by `factor`
net = asym_UNet3D(in_channels=1, out_channels=1, trilinear=True, factor=factor).cuda()
MSE = torch.nn.MSELoss()

print_log(f'input shape: {net_input.shape}')

In [None]:
def _for_log(value):
    """Return `value` detached from the autograd graph if it is a tensor, else unchanged.

    Logged values are detached so that, if the logging helper keeps them around,
    they do not keep the training graph (and its GPU memory) alive.
    """
    return value.detach() if torch.is_tensor(value) else value


def closure():
    """Run one optimisation step; every 100 iterations also evaluate and checkpoint.

    The network is trained self-supervised: it maps `net_input` (the LR2 volume)
    to the resolution of the LR target `img_LR_var`. On evaluation steps it is
    applied a second time to that output to obtain the HR estimate, which is
    compared against the original volume.

    Reads/writes the globals `i` (iteration counter) and `net_input`.

    Returns:
        The total loss tensor (data term + parameter-norm penalty), returned to
        the caller (`optimize`).
    """
    global i, net_input
    # With reg_noise_std == 0 this is just the saved input; otherwise the `noise`
    # buffer is refilled in place (normal_()) and added, scaled by reg_noise_std.
    net_input = net_input_saved + (noise.normal_() * reg_noise_std)

    gen_LR = net(net_input)
    main_loss1 = MSE(gen_LR[0, 0], img_LR_var)

    # Penalty on the sum of per-parameter L2 norms (the norms are not squared).
    l2_reg = 0
    for param in net.parameters():
        l2_reg += torch.norm(param)
    reg_term = LAMBDA_REG * l2_reg
    total_loss = main_loss1 + reg_term
    total_loss.backward()

    if i % 100 == 0:
        # Evaluation only: no gradients are needed, so no_grad avoids building and
        # holding an autograd graph for this extra full-resolution forward pass.
        # Side effects of the forward itself (whatever mode `net` is in) are unchanged.
        with torch.no_grad():
            gen_LR_eval = gen_LR.detach()
            # kspace_replacement: apply the network again to get the HR estimate.
            out_HR = net(gen_LR_eval)

        dip_img = torch2np(out_HR[0][0])
        dip_img_central_replacement = central_replacement_3d(imgs['orig_np'], dip_img, factor=factor)

        # PSNR_SS (SS: self-supervised) compares the LR reconstruction to the LR target;
        # the HR metrics compare the estimate, raw and after central replacement, to the original.
        psnr_SS = volumetric_psnr(imgs['LR_np'], gen_LR_eval[0, 0])
        psnr_HR = volumetric_psnr(imgs['orig_np'], dip_img)
        psnr_kr = volumetric_psnr(imgs['orig_np'], dip_img_central_replacement)
        ssim_index = compare_ssim(imgs['orig_np'], dip_img, channel_axis=None)
        ssim_index_kr = compare_ssim(imgs['orig_np'], dip_img_central_replacement, channel_axis=None)

        item_list = [
            ("Iteration %05d", i),
            ("PSNR_SS %.3f", psnr_SS),
            ("PSNR_HR %.3f", psnr_HR),
            ("PSNR_KR %.3f", psnr_kr),
            ("SSIM %.3f", ssim_index),
            ("SSIM_KR %.3f", ssim_index_kr),
            ("Loss %.5f", _for_log(total_loss)),
            ("img_mse %.5f", _for_log(main_loss1)),
            ("reg_term %.5f", _for_log(reg_term)),
        ]

        output_line = save_info_dict(item_list, info_path=f'{res_dir}/log/info_dict_factor{factor*100}')
        print_log(output_line)

        # From iteration 2100 on, every evaluation step also saves the HR estimate.
        if i > 2000:
            out_HR_np = torch_to_np(out_HR)
            np.save(f'{res_dir}/HR_volume_{i}.npy', out_HR_np[0])
        # Model weights are checkpointed every 1000 iterations (including iteration 0).
        if i % 1000 == 0:
            torch.save(net.state_dict(), f'{res_dir}/model/epoch{i}_model_weights.pth')
    i += 1
    return total_loss

In [None]:
# ----- run the optimisation -----
i = 0  # global iteration counter, advanced by closure()

# Frozen copy of the network input; closure() rebuilds `net_input` from it every step.
net_input_saved = net_input.detach().clone()
# Same-shaped buffer that closure() refills in place via noise.normal_().
noise = net_input_saved.clone()

p = get_params(OPT_OVER, net, net_input)
optimize(OPTIMIZER, p, closure, LR, num_iter)