In [None]:
import os
from threading import Thread  # needed since the denoiser is running in parallel
import queue

import numpy as np
import torch
import torch.optim
from models.skip import skip  # our network

from utils.utils import *  # auxiliary functions
from utils.mine_blur_utils2 import *  # blur functions
from utils.data import Data  # class that holds img, psnr, time
from skimage.metrics import structural_similarity as ssim

from skimage.restoration import denoise_nl_means

from scipy.signal import convolve2d

In [None]:
# Hardware configuration.
# NOTE: if you cannot reproduce the article's exact numbers, set CUDNN to False.
CUDA_FLAG = True
CUDNN = True
if not CUDA_FLAG:
    dtype = torch.FloatTensor
else:
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # expose only the first GPU
    # cuDNN supplies GPU-accelerated kernels for common deep-net operations.
    torch.backends.cudnn.enabled = CUDNN
    # Benchmark mode makes cuDNN time several algorithms the first time it sees an
    # input configuration and then reuse the fastest. This pays off when input sizes
    # stay fixed, but costs time again whenever a new input size shows up.
    torch.backends.cudnn.benchmark = CUDNN
    # torch.backends.cudnn.deterministic = True
    dtype = torch.cuda.FloatTensor

In [None]:
# ---- Degradation model ----
NOISE_SIGMA = 5          # std of the additive Gaussian noise, on a [0, 255] scale
STD_BLUR = 1.6           # blur width forwarded to get_h (presumably the Gaussian std — confirm in get_h)
DIM_FILTER = 21          # blur kernel size forwarded to get_h
BLUR_TYPE = 'gauss_blur'  # the only two options: 'gauss_blur' or 'uniform_blur'
# False: keep the RGB image; PSNR/SSIM are then compared on the Y (luma) channel.
# True:  convert the RGB image to gray scale.
GRAY_SCALE = False
USE_FOURIER = False      # not read by any visible cell of this notebook

# ---- Graph labels ----
X_LABELS = ['Iterations', 'Iterations', 'Iterations']
Y_LABELS = [
    'PSNR between x and net (db)',
    'PSNR with original image (db)',
    'loss',
]

# ---- Algorithm names = keys of data_dict ----
# e.g. data_dict[ORIGINAL].img is the clean image.
ORIGINAL = 'Clean'
CORRUPTED = 'Blurred'
DIP_NLM = 'DIP-TTGV'

In [None]:
def rgb2yuv(rgb):
    """Convert an RGB image to the YUV colour space.

    Args:
        rgb: numpy array of shape (3, H, W) (channel-first) or (H, W, 3)
             (channel-last). A (3, H, 3) array is treated as channel-first.

    Returns:
        YUV array laid out like the input. Each channel is clipped to [0, 1];
        U and V carry a +0.5 offset.

    Raises:
        ValueError: if the array is neither (3, H, W) nor (H, W, 3).
    """
    channels_first = rgb.ndim == 3 and rgb.shape[0] == 3
    channels_last = rgb.ndim == 3 and rgb.shape[2] == 3

    if channels_first:
        r, g, b = rgb[0], rgb[1], rgb[2]
    elif channels_last:
        r, g, b = rgb[..., 0], rgb[..., 1], rgb[..., 2]
    else:
        raise ValueError("不支持的图像形状")

    y = np.clip(0.299 * r + 0.587 * g + 0.114 * b, 0, 1)
    u = np.clip(-0.147 * r - 0.289 * g + 0.436 * b + 0.5, 0, 1)
    v = np.clip(0.615 * r - 0.515 * g - 0.100 * b + 0.5, 0, 1)

    # Stack back along the same axis the colour channels came from.
    return np.stack([y, u, v], axis=0 if channels_first else 2)

def _squeeze_single_channel(img):
    """Drop a singleton channel axis: (1, H, W) or (H, W, 1) -> (H, W); else unchanged."""
    img = np.asarray(img)
    if img.ndim == 3 and img.shape[0] == 1:
        return img[0]
    if img.ndim == 3 and img.shape[-1] == 1:
        return img[..., 0]
    return img


def _to_channels_first(img):
    """Move a trailing RGB axis to the front: (H, W, 3) -> (3, H, W); (3, H, W) passes through."""
    if img.ndim == 3 and img.shape[0] != 3 and img.shape[-1] == 3:
        return np.moveaxis(img, -1, 0)
    return img


def compare_SSIM(img1, img2, on_y=False, gray_scale=True, data_range=1.0):
    """Compute the structural similarity index (SSIM) between two images.

    Args:
        img1, img2: images to compare. Accepted layouts are (H, W), a
            single-channel (1, H, W) / (H, W, 1), or RGB as (3, H, W) / (H, W, 3).
        on_y: for colour images, compare only the Y (luma) channel.
        gray_scale: treat the inputs as gray-scale images.
        data_range: dynamic range of the pixel values (usually 1.0 or 255).

    Returns:
        The SSIM value. For colour images with on_y=False, this is the mean of
        the per-channel SSIMs.
    """
    # skimage's ssim treats a 3-D array as a volume, so a singleton channel
    # axis (e.g. a gray-scale (1, H, W) image) would make the 7x7 window exceed
    # the image extent. Normalise the layout before dispatching.
    img1 = _to_channels_first(_squeeze_single_channel(img1))
    img2 = _to_channels_first(_squeeze_single_channel(img2))

    is_rgb = img1.ndim == 3 and img1.shape[0] == 3
    if gray_scale or not is_rgb:
        return ssim(img1, img2, data_range=data_range)

    if on_y:
        # Compare only the luma channel of the YUV representation.
        img1_y = rgb2yuv(img1)[0, :, :]
        img2_y = rgb2yuv(img2)[0, :, :]
        return ssim(img1_y, img2_y, data_range=data_range)

    # Average the SSIM of the three colour channels.
    per_channel = [ssim(img1[c], img2[c], data_range=data_range) for c in range(3)]
    return np.mean(per_channel)

In [None]:
def load_imgs_deblurring(fname, blur_type, noise_sigma, STD_BLUR, DIM_FILTER, plot=False):
    """Load an image and synthesise its blurred, noisy observation.

    The image is blurred with the kernel returned by get_h, then Gaussian
    noise with std noise_sigma/255 is added and the result is clipped to [0, 1].

    Args:
        fname: path to the image.
        blur_type: kernel family forwarded to get_h ('gauss_blur' or 'uniform_blur').
        noise_sigma: std of the additive noise, on a [0, 255] scale.
        STD_BLUR: blur width forwarded to get_h.
        DIM_FILTER: kernel size forwarded to get_h.
        plot: if True, plot the clean and corrupted images.

    Returns:
        (data_dict, kernel_torch), where:
            data_dict maps ORIGINAL to Data(clean) and CORRUPTED to
                Data(blurred, psnr, ssim);
            kernel_torch is the blur kernel as a torch tensor.
    """
    img_pil, img_np = load_and_crop_image(fname)
    if GRAY_SCALE:
        img_np = rgb2gray(img_pil)

    kernel_torch = np_to_torch(get_h(blur_type, STD_BLUR, DIM_FILTER))
    noiseless = torch_to_np(blur_th(np_to_torch(img_np), kernel_torch))
    noise = np.random.normal(scale=noise_sigma / 255., size=noiseless.shape)
    blurred = np.clip(noiseless + noise, 0, 1).astype(np.float32)

    metric_kwargs = dict(on_y=(not GRAY_SCALE), gray_scale=GRAY_SCALE)
    ssim_val = compare_SSIM(img_np, blurred, **metric_kwargs)

    data_dict = {ORIGINAL: Data(img_np)}
    data_dict[CORRUPTED] = Data(blurred, compare_PSNR(img_np, blurred, **metric_kwargs), ssim_val)

    if plot:
        plot_dict(data_dict)
    return data_dict, kernel_torch

In [None]:
# Load the clean image and synthesise its blurred + noisy observation.
data_dict, kernel_torch = load_imgs_deblurring(
    'datasets/watercastle.png', BLUR_TYPE, NOISE_SIGMA, STD_BLUR, DIM_FILTER, plot=True)

In [None]:
# 3x3 Laplacian-type mask (edge detector, sensitive to noise); its response on
# the observation is what the noise estimator below sums up.
lap_kernel = np.array([[1, -2, 1], [-2, 4, -2], [1, -2, 1]])
# Spatial size of the corrupted observation, stored as (C, H, W).
h = data_dict[CORRUPTED].img.shape[1]
w = data_dict[CORRUPTED].img.shape[2]


def estimate_variance(img):
    """Estimate the standard deviation of additive Gaussian noise in a 2-D image.

    Despite the name, the returned value is a standard deviation, not a variance:
        sigma = sqrt(pi/2) * sum(|img (*) lap_kernel|) / (6 * (H-2) * (W-2))
    (Immerkaer's fast noise estimator).

    Args:
        img: 2-D array (H, W).

    Returns:
        Estimated noise std, in the same units as img.
    """
    response = convolve2d(img, lap_kernel, mode='valid')
    # A 'valid' 3x3 convolution yields an (H-2, W-2) map, so its size is the
    # normalisation count. Taking it from the input (instead of globals)
    # keeps the estimate correct for any image passed in.
    return np.sum(np.abs(response)) * np.sqrt(0.5 * np.pi) / (6 * response.size)


print(data_dict[CORRUPTED].img.shape)
# Estimate the noise std from the first channel, converted to the [0, 255] scale.
# NOTE: this overwrites the configured NOISE_SIGMA, which run_and_plot reads later;
# re-running the loading cell afterwards would synthesise noise with this estimate.
NOISE_SIGMA = estimate_variance(data_dict[CORRUPTED].img[0, :, :]) * 255
print(NOISE_SIGMA)

In [None]:
def get_network_and_input(img_shape, input_depth=32, pad='reflection',
                          upsample_mode='bilinear', use_interpolate=True, align_corners=False,
                          act_fun='LeakyReLU', skip_n33d=128, skip_n33u=128, skip_n11=4,
                          num_scales=5, downsample_mode='stride', INPUT='noise'):  # 'meshgrid'
    """Build the DIP `skip` network and its fixed input tensor.

    Uses the same defaults as the DIP article.

    Args:
        img_shape: shape of the target image, (C, H, W). C sets the number of
            output channels; (H, W) sets the spatial size of the input.
        input_depth: number of channels of the network input.
        skip_n33d, skip_n33u, skip_n11: channel widths of the down / up / skip
            branches, either one int used at every scale or a per-scale list.
        num_scales: number of scales used when a width is given as an int.
        INPUT: kind of input forwarded to get_noise ('noise' or 'meshgrid').
        Remaining arguments are forwarded unchanged to `skip`.

    Returns:
        (net, net_input): the network cast to `dtype`, and its detached input.
    """
    def per_scale(width):
        # An int means "same width at every scale"; a list is used as given.
        return [width] * num_scales if isinstance(width, int) else width

    n_channels = img_shape[0]
    spatial_shape = img_shape[1:]
    net = skip(input_depth, n_channels,
               num_channels_down=per_scale(skip_n33d),
               num_channels_up=per_scale(skip_n33u),
               num_channels_skip=per_scale(skip_n11),
               upsample_mode=upsample_mode,
               use_interpolate=use_interpolate,
               align_corners=align_corners,
               downsample_mode=downsample_mode,
               need_sigmoid=True,
               need_bias=True,
               pad=pad,
               act_fun=act_fun).type(dtype)
    net_input = get_noise(input_depth, INPUT, spatial_shape).type(dtype).detach()
    return net, net_input

In [None]:
# Spatial size (H, W) of the images; the operators below are built at this size.
size = data_dict['Clean'].img.shape
h = size[-2]
w = size[-1]
# Finite-difference stencils (horizontal / vertical) and the identity.
Dh_psf = np.array([[0, 0, 0], [1, -1, 0], [0, 0, 0]])
Dv_psf = np.array([[0, 1, 0], [0, -1, 0], [0, 0, 0]])
Id_psf = np.array([[1]])

# Honour CUDA_FLAG from the configuration cell instead of hard-coding .cuda(),
# so the notebook also runs (and keeps every tensor on one device) on CPU.
OTF_DEVICE = torch.device('cuda' if CUDA_FLAG else 'cpu')

# Frequency responses of the stencils at image size (psf2otf comes from the utils star imports).
Id_DFT = torch.from_numpy(psf2otf(Id_psf, [h, w])).to(OTF_DEVICE)
Dh_DFT = torch.from_numpy(psf2otf(Dh_psf, [h, w])).to(OTF_DEVICE)
Dv_DFT = torch.from_numpy(psf2otf(Dv_psf, [h, w])).to(OTF_DEVICE)

# Complex conjugates of the difference responses.
DhT_DFT = torch.conj(Dh_DFT)
DvT_DFT = torch.conj(Dv_DFT)


# Divergence of the vector field (zx, zy).
def div(zx, zy):
    """Return the horizontal difference of zx plus the vertical difference of zy.

    Both differences come from D with the global Dh_DFT / Dv_DFT operators.
    NOTE(review): a TV-dual divergence is usually -D^T (cf. the otherwise
    unused DhT_DFT / DvT_DFT); this uses D itself — confirm that is intended.
    """
    zx_h, _ = D(zx, Dh_DFT, Dv_DFT)
    _, zy_v = D(zy, Dh_DFT, Dv_DFT)
    return zx_h + zy_v

In [None]:
def _sym_grad_norm(v11, v12, v21, v22):
    """Pointwise Frobenius norm of the symmetrised gradient eps(v).

    eps(v) = [[v11, (v12+v21)/2], [(v12+v21)/2, v22]], where vij is the j-th
    output of D applied to the i-th component of v.
    """
    off_diag = 1/2 * (v12 + v21)
    return torch.sqrt(torch.pow(v11, 2) + torch.pow(off_diag, 2) + torch.pow(off_diag, 2) + torch.pow(v22, 2))


def _ttgv_denominator(u1, u2, v1, v2, v11, v12, v21, v22, beta, alpha1, alpha2):
    """beta + alpha1*|grad(u) - v| + alpha2*|eps(v)|, the shared denominator of the adaptive weights."""
    grad_residual = torch.sqrt(torch.pow(u1 - v1, 2) + torch.pow(u2 - v2, 2))
    return beta + alpha1 * grad_residual + alpha2 * _sym_grad_norm(v11, v12, v21, v22)


def _project(components, radius):
    """Rescale the field `components` so its pointwise norm is at most `radius`.

    Returns detached copies, matching how the dual variables are stored.
    """
    norm = torch.sqrt(sum(torch.pow(c, 2) for c in components))
    scale = torch.clamp(norm / radius, min=1)
    return [(c / scale).detach().clone() for c in components]


def train_via_admm(net, net_input, kernel_torch, y, noise_lev, tau, org_img=None,
                   plot_array={}, algorithm_name="", admm_iter=5000, save_path="",
                   LR=0.001, tao=0.001, beta=2, alpha1=1, alpha2=10, gama=0.03, w1=0.005, sigma1=0.03,
                   rou=0.0008, LR_x=None, noise_factor=0.033,
                   threshold=40, threshold_step=0.01, increase_reg=0.033):
    """Deblur y by training a DIP network coupled, via ADMM, to a TTGV-regularised image u.

    Each iteration:
      1. Takes one Adam step on the network. The loss is
         ||blur(out) - y||^2 + rou * ||out - (u - t)||^2, both as MSE.
      2. Takes a semi-implicit step on u towards out + t, using the divergence of p.
      3. Projects p and q (the dual variables of the TTGV terms) onto balls
         whose radii are adaptive weights.
      4. Updates the auxiliary field v.
      5. Updates the ADMM multiplier: t += out - u.

    Args:
        net: the network to train.
        net_input: fixed network input; fresh noise is added every iteration.
        kernel_torch: blur kernel, applied with blur_th.
        y: observed blurred + noisy image (numpy).
        noise_lev: noise std on a [0, 1] scale.
        tau: safety factor on noise_lev in the stopping ratio.
        org_img: clean image for PSNR/SSIM monitoring, or None.
        plot_array: iterations at which intermediate images are plotted.
        algorithm_name: label shown in the progress line.
        admm_iter: number of ADMM iterations.
        LR: Adam learning rate of the network.
        tao: step size of the u update.
        beta, alpha1, alpha2: parameters of the adaptive TTGV weights.
        gama, w1, sigma1: step sizes of the p, v and q updates.
        rou: penalty weight coupling the network output and u.
        noise_factor: std of the noise added to the network input.
        save_path, LR_x, threshold, threshold_step, increase_reg: accepted for
            backward compatibility; not used.

    Returns:
        (avg, list_psnr, list_ssim, list_stopping):
            avg: exponential moving average of the network outputs.
            list_psnr, list_ssim: per-iteration PSNR/SSIM of avg; these are
                filled only when org_img is given.
            list_stopping: per-iteration ||blur(out) - y|| / (tau*noise_lev*sqrt(N-1)).
                A value near 1 means the residual has reached the expected noise norm.
    """
    list_psnr = []
    list_ssim = []
    list_stopping = []
    metric_kw = dict(on_y=(not GRAY_SCALE), gray_scale=GRAY_SCALE)

    mse = torch.nn.MSELoss().type(dtype)
    # The network input is perturbed with fresh noise every iteration.
    net_input_saved = net_input.detach().clone()
    noise = net_input.detach().clone()
    if org_img is not None:
        psnr_y = compare_PSNR(org_img, y, **metric_kw)  # PSNR of the observation
        ssim_y = compare_SSIM(org_img, y, **metric_kw)  # SSIM of the observation
    optimizer = torch.optim.Adam(net.parameters(), lr=LR)

    y_torch = np_to_torch(y).type(dtype)
    u = 0 * y_torch                    # regularised image (ADMM split of the network output)
    t = 0 * y_torch                    # scaled ADMM multiplier
    v1, v2 = 0 * y_torch, 0 * y_torch  # auxiliary vector field of TGV
    p1, p2 = 0 * y_torch, 0 * y_torch  # dual variable of the grad(u) - v term
    q11, q12 = 0 * y_torch, 0 * y_torch
    q21 = 0 * y_torch
    q22 = 0 * y_torch  # previously missing -> UnboundLocalError in the first v update
    # EMA initialisation inherited from DeepRED; its influence decays as 0.99**i.
    avg = np.rint(y)

    # Loop-invariant normaliser of the stopping ratio.
    rho = tau * noise_lev * np.sqrt(y.size - 1)

    [v11, v12] = D(v1, Dh_DFT, Dv_DFT)
    [v21, v22] = D(v2, Dh_DFT, Dv_DFT)
    for i in range(1, 1 + admm_iter):
        # step 1: update the network
        optimizer.zero_grad()
        net_input = net_input_saved + (noise.normal_() * noise_factor)
        out = net(net_input)
        out_np = torch_to_np(out)

        loss_y = mse(blur_th(out, kernel_torch), y_torch)
        # The u step pulls u towards out + t and the multiplier update is t += out - u,
        # so the matching coupling term for the network is ||out - (u - t)||^2.
        loss_x = mse(out, (u - t).type(dtype))
        total_loss = loss_y + rou * loss_x
        total_loss.backward()
        optimizer.step()

        # Steps 2-5 are plain variable updates; no autograd graph is needed.
        with torch.no_grad():
            # step 2: update u, then take its gradient
            u = ((u + tao * div(p1, p2) + tao * rou * (out + t)) / (1 + rou * tao)).detach().clone()
            [u1, u2] = D(u, Dh_DFT, Dv_DFT)

            # step 3: update p (weight computed with the current v)
            cigema = ((beta + 1) * alpha1) / _ttgv_denominator(
                u1, u2, v1, v2, v11, v12, v21, v22, beta, alpha1, alpha2)
            p1, p2 = _project([p1 + gama * (u1 - v1), p2 + gama * (u2 - v2)], cigema)

            # step 4: update v and its gradient eps(v)
            v1 = (v1 + w1 * (p1 + div(q11, q12))).detach().clone()
            v2 = (v2 + w1 * (p2 + div(q21, q22))).detach().clone()
            [v11, v12] = D(v1, Dh_DFT, Dv_DFT)
            [v21, v22] = D(v2, Dh_DFT, Dv_DFT)

            # step 5: update q (weight computed with the new v)
            yita = ((beta + 1) * alpha2) / _ttgv_denominator(
                u1, u2, v1, v2, v11, v12, v21, v22, beta, alpha1, alpha2)
            q11, q12, q21, q22 = _project([q11 + sigma1 * v11,
                                           q12 + sigma1 * 1/2 * (v12 + v21),
                                           q21 + sigma1 * 1/2 * (v12 + v21),
                                           q22 + sigma1 * v22], yita)

            # step 6: ADMM multiplier
            t = (t - u + out).detach().clone()

        # Exponential moving average of the network outputs.
        avg = avg * .99 + out_np * .01

        stopping = np.sqrt(np.sum(np.square(torch_to_np(blur_th(out.data, kernel_torch)) - y))) / rho
        list_stopping.append(stopping)

        # Progress report
        psnr_noisy = compare_PSNR(out_np, y, **metric_kw)
        if org_img is not None:
            psnr_net = compare_PSNR(org_img, out_np, **metric_kw)
            psnr_avg = compare_PSNR(org_img, avg, **metric_kw)
            ssim_net = compare_SSIM(org_img, out_np, **metric_kw)
            ssim_avg = compare_SSIM(org_img, avg, **metric_kw)
            list_psnr.append(psnr_avg)
            list_ssim.append(ssim_avg)
            print('\r', algorithm_name, '%04d/%04d Loss %f' % (i, admm_iter, total_loss.item()),
                  'psnrs: y: %.2f psnr_noisy: %.2f net: %.2f avg: %.2f' % (psnr_y, psnr_noisy, psnr_net, psnr_avg),
                  'ssim: y: %.4f net: %.4f avg: %.4f' % (ssim_y, ssim_net, ssim_avg),
                  'params: stopping: %.2f' % (stopping), end='')
            if i in plot_array:
                tmp_dict = {'Clean': Data(org_img),
                            'Noisy': Data(y, psnr_y, ssim_y),
                            'Net': Data(out_np, psnr_net, ssim_net),
                            'avg': Data(avg, psnr_avg, ssim_avg),
                            }
                plot_dict(tmp_dict)
        else:
            print('\r', algorithm_name, 'iteration %04d/%04d Loss %f' % (i, admm_iter, total_loss.item()), end='')

    return avg, list_psnr, list_ssim, list_stopping

In [None]:
def run_and_plot(name, plot_checkpoints={}):
    """Reconstruct the corrupted image, store the result in data_dict[name] and plot everything.

    Args:
        name: key under which the reconstruction is stored in data_dict.
        plot_checkpoints: iterations at which train_via_admm plots intermediate images.

    Returns:
        (denoised_img, list_psnr, list_ssim, list_stopping) from train_via_admm.
    """
    observed = data_dict[CORRUPTED].img
    noise_lev = NOISE_SIGMA / 255
    # Keep tau at 1 if you trust the noise estimate produced by estimate_variance.
    tau = 1
    net, net_input = get_network_and_input(img_shape=observed.shape)
    denoised_img, list_psnr, list_ssim, list_stopping = train_via_admm(
        net, net_input, kernel_torch, observed, noise_lev, tau,
        plot_array=plot_checkpoints, algorithm_name=name,
        org_img=data_dict[ORIGINAL].img)

    clean = data_dict[ORIGINAL].img
    metric_kw = dict(on_y=(not GRAY_SCALE), gray_scale=GRAY_SCALE)
    final_ssim = compare_SSIM(clean, denoised_img, **metric_kw)
    final_psnr = compare_PSNR(clean, denoised_img, **metric_kw)
    data_dict[name] = Data(denoised_img, final_psnr, final_ssim)
    plot_dict(data_dict)

    return denoised_img, list_psnr, list_ssim, list_stopping


# Iterations at which intermediate reconstructions are plotted.
plot_checkpoints = {1, 10, 50, 100, 250, 500, 2000, 3500, 5000}
denoised_img, list_psnr, list_ssim, list_stopping = run_and_plot(
    DIP_NLM, plot_checkpoints=plot_checkpoints)