In [None]:
DEV_MODE = False

if DEV_MODE:
    class TestArg():
        """Stand-in for the parsed command-line arguments when running interactively."""
        def __init__(self, factor):
            self.file = "/work/users/c/c/cctsai/data/BCP_sample/357_T1w_MPR_NORM_3.npy"
            self.factor = factor
            self.model_folder = "dev_mode"
            self.input_img = False
            # mirror the CLI default so the logged argument dump has the same keys in both modes
            self.double_arm = "False"

    args = TestArg(factor=2)
else:
    from argparse import ArgumentParser
    parser = ArgumentParser()
    # Both are used unconditionally below; fail fast with a usage message instead of a
    # later AttributeError/TypeError on None.
    parser.add_argument("--file", required=True, help="path to the HR volume (.npy)")
    parser.add_argument("--factor", required=True, help="downsampling factor, e.g. 1.25, 1.5, 1.75, 2")

    # basic config
    parser.add_argument("--model_folder")
    parser.add_argument("--input_img", action="store_true", default=False)
    parser.add_argument("--double_arm", default="False")

    args = parser.parse_args()

# run name = file name up to its first dot
_name_ = args.file.split("/")[-1].split(".")[0]
model_folder = args.model_folder

factor = float(args.factor)  # e.g. 1.25, 1.5, 1.75, 2
    
import os, sys, glob
sys.path.append("./utils")
# round(), not int(): float error makes e.g. 1.15*100 == 114.99999999999999, which int() truncates to 114
res_dir = f"/work/users/c/c/cctsai/res/{model_folder}/{_name_}/{round(factor * 100)}"

import numpy as np
import torch
import torch.optim
import torch.nn.functional as F
from utils_unet3D_ver3 import *
from utils_kspace_torch import *
from utils.models import *
from utils.prep import *
from utils.sr_common import *

from skimage.metrics import peak_signal_noise_ratio as compare_psnr
from skimage.metrics import structural_similarity as compare_ssim

torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
dtype = torch.cuda.FloatTensor
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

# Output tree: results, model checkpoints, logs. exist_ok=True replaces the racy
# "check exists, then create" pattern (parallel jobs may share parent folders).
for new_dir in (res_dir, f"{res_dir}/model", f"{res_dir}/log"):
    os.makedirs(new_dir, exist_ok=True)

import datetime
current_time = datetime.datetime.now().time()


def print_log(msg):
    """Append `msg` to this run's log file through write_log (from prep.py)."""
    log_file = f'{res_dir}/log/age{_name_}_factor{factor*100}_{current_time}.txt'
    return write_log(msg, log_file)


print_log(f"Task name: {_name_}, factor={factor}")

# record every run argument in the log, one "key: value" per line
arguments_dict = vars(args)
arguments_string = "\n".join(f"{key}: {value}" for key, value in arguments_dict.items())
print_log(arguments_string)

In [None]:
modality = 'T1'

# network input / optimized parameters
input_depth = 1
INPUT = 'noise'
OPT_OVER = 'net'

# optimizer settings
LR = 0.0001
lambda_reg = 0.00001  # weight of the parameter-norm penalty added to the loss in closure()
OPTIMIZER = 'adam'

num_iter = 8200
reg_noise_std = 0  # scale of the random perturbation added to the network input (0 = none)


def downsampler(img):
    """Downsample a volume by `factor` with torch_sinc_downsampler_3D."""
    return torch_sinc_downsampler_3D(img, factor)


def downsampler2(img):
    """Downsample a volume by `factor` squared (used to build the network input)."""
    return torch_sinc_downsampler_3D(img, factor * factor)

In [None]:
# Load the reference HR volume and normalize it with nor() (from the utils imports).
img_HR_np = np.load(args.file)
if DEV_MODE:
    img_HR_np = img_HR_np[:160, :160, :160]  # DEV_MODE: keep at most 160 voxels along the first three axes
img_HR_np = nor(img_HR_np)

# LR target (downsampled by factor) and coarser network input (downsampled by factor**2),
# both clipped to [0, 1].
img_LR_tensor = torch.clamp(downsampler(img_HR_np), min=0.0, max=1.0)
img_LR2_tensor = torch.clamp(downsampler2(img_HR_np), min=0.0, max=1.0)

imgs = dict(orig_np=img_HR_np, LR_np=img_LR_tensor.numpy())

# Copies cast to the configured tensor type: the LR volume is the loss target, the LR2
# volume (with two leading singleton dims added) is the network input.
img_LR_var = img_LR_tensor.clone().type(dtype)
net_input = img_LR2_tensor.clone()[None, None].type(dtype)

print_log(f'input shape: {net_input.shape}')

In [None]:
MSE = torch.nn.MSELoss()
net = asym_UNet3D(
    in_channels=input_depth,
    out_channels=1,
    trilinear=True,
    factor=factor,
).cuda()

In [None]:
def _evaluate_and_log(step, gen_LR, total_loss, main_loss1, reg_term):
    """Evaluate the network at iteration `step`, log metrics, and save outputs/checkpoints.

    Called from closure() after backward() has run, so nothing here needs gradients.
    The extra LR -> HR forward pass therefore runs under torch.no_grad(), so no autograd
    graph is built (and kept alive) for a full 3D volume.

    Args:
        step: current iteration counter.
        gen_LR: network output for the current input; gen_LR[0, 0] is the volume that
            closure() compares against the LR target.
        total_loss, main_loss1, reg_term: this step's loss terms. They are logged as plain
            floats so the logging helper never holds on to autograd graphs.
    """
    with torch.no_grad():
        # feed the LR estimate back through the network to obtain the HR estimate
        out_HR = net(gen_LR)

    dip_img = torch2np(out_HR[0][0])
    # k-space replacement variant of the HR estimate (central_replacement_3d from utils)
    dip_img_central_replacement = central_replacement_3d(imgs['orig_np'], dip_img, factor=factor)

    # PSNR: convert the LR estimate to numpy, like the other metric inputs
    gen_LR_np = gen_LR[0, 0].detach().cpu().numpy()
    psnr_SS = volumetric_psnr(imgs['LR_np'], gen_LR_np)  # SS: self-supervised
    psnr_HR = volumetric_psnr(imgs['orig_np'], dip_img)
    psnr_kr = volumetric_psnr(imgs['orig_np'], dip_img_central_replacement)

    # SSIM on float volumes needs an explicit data_range. Without it, skimage either infers
    # the float dtype range [-1, 1] (older releases) or raises (newer releases).
    data_range = float(imgs['orig_np'].max() - imgs['orig_np'].min())
    ssim_index = compare_ssim(imgs['orig_np'], dip_img,
                              channel_axis=None, data_range=data_range)
    ssim_index_kr = compare_ssim(imgs['orig_np'], dip_img_central_replacement,
                                 channel_axis=None, data_range=data_range)

    item_list = [("Iteration %05d", step), ("PSNR_SS %.3f", psnr_SS), ("PSNR_HR %.3f", psnr_HR),
                 ("PSNR_KR %.3f", psnr_kr), ("SSIM %.3f", ssim_index), ("SSIM_KR %.3f", ssim_index_kr),
                 ("Loss %.5f", float(total_loss)), ("img_mse %.5f", float(main_loss1)),
                 ("reg_term %.5f", float(reg_term))]

    output_line = save_info_dict(item_list, info_path=f'{res_dir}/log/info_dict_factor{factor*100}')
    print_log(output_line)

    if step > 2000:
        out_HR_np = torch_to_np(out_HR)
        np.save(f'{res_dir}/HR_volume_{step}.npy', out_HR_np[0])
    if step % 1000 == 0:
        torch.save(net.state_dict(), f'{res_dir}/model/epoch{step}_model_weights.pth')


def closure():
    """Compute the loss for one optimizer step and back-propagate it.

    loss = MSE(net(net_input)[0, 0], img_LR_var) + lambda_reg * (sum of parameter L2 norms)

    Every 100 iterations the network is also evaluated, logged and checkpointed via
    _evaluate_and_log(). Increments the global iteration counter `i` and reassigns the
    global `net_input`.

    Returns:
        The total loss tensor; backward() has already been called on it.
    """
    global i, net_input
    # optional input perturbation; with reg_noise_std == 0 the input is unchanged
    net_input = net_input_saved + (noise.normal_() * reg_noise_std)

    gen_LR = net(net_input)
    main_loss1 = MSE(gen_LR[0, 0], img_LR_var)

    # sum of (non-squared) L2 norms of every parameter tensor
    l2_reg = 0
    for param in net.parameters():
        l2_reg += torch.norm(param)
    reg_term = lambda_reg * l2_reg
    total_loss = main_loss1 + reg_term
    total_loss.backward()

    if i % 100 == 0:
        _evaluate_and_log(i, gen_LR, total_loss, main_loss1, reg_term)
    i += 1
    return total_loss

In [None]:
# Optional resume: set resume_epoch to an iteration whose checkpoint closure() saved as
# f'{res_dir}/model/epoch{N}_model_weights.pth'. None trains from scratch.
resume_epoch = None
if resume_epoch is not None:
    # NOTE: only the network weights are checkpointed; optimizer state is not restored.
    net.load_state_dict(torch.load(f'{res_dir}/model/epoch{resume_epoch}_model_weights.pth'))

net_input_saved = net_input.detach().clone()
noise = net_input.detach().clone()  # buffer that closure() refills with normal_() samples each step

# iteration counter read and incremented by closure(); continue numbering when resuming
i = 0 if resume_epoch is None else resume_epoch
p = get_params(OPT_OVER, net, net_input)
optimize(OPTIMIZER, p, closure, LR, num_iter)