In [22]:
from models.unet import UNet
import torch

from PIL import Image
import numpy as np

import bm3d

from skimage.metrics import peak_signal_noise_ratio, structural_similarity
import matplotlib.pyplot as plt
from matplotlib import image as mpimg


import glob

import os
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

from tqdm import tqdm

In [4]:
def pil_to_np(img_pil, normalize=False):
    img_np = np.array(img_pil)
    if normalize:
        img_np = img_np / 255

    if len(img_np.shape) == 2:
        return img_np.astype(np.float)
    else:
        return img_np.transpose(2, 0, 1).astype(np.float)

def np_to_pil(img_np, normalize=False):
    if normalize:
        img_np = img_np*255

    img_np = np.clip(img_np, 0, 255)
    img_np = img_np.astype(np.uint8)

    if len(img_np.shape) == 2:
        img_pil = Image.fromarray(img_np)
    else:
        img_pil = Image.fromarray(img_np.transpose(1, 2, 0))

    return img_pil

def np_to_torch(img_np):
    if len(img_np.shape) == 2:
        return torch.Tensor(img_np)[None, None, ...]
    else:
        return torch.Tensor(img_np)[None, ...]

def torch_to_np(img_torch):
    return img_torch.cpu().squeeze().detach().numpy()


In [5]:
def save_hist(x, root):
    x = x.flatten()
    plt.figure()
    n, bins, patches = plt.hist(x, bins=128, density=1)
    plt.savefig(root)
    plt.close()

def save_heatmap(image_np, root):
    cmap = plt.get_cmap('jet')

    rgba_img = cmap(image_np)
    rgb_img = np.delete(rgba_img, 3, 2)
    rgb_img_pil = Image.fromarray((255*rgb_img).astype(np.uint8))
    rgb_img_pil.save(root)

def sample_z(mean):
    eps = mean.clone().normal_()

    return mean + eps

def eval_sigma(num_iter, noise_level):
    if num_iter == 1:
        sigma = noise_level
    else:
        sigma = 5

    return sigma

def save_torch(img_torch, root):
    img_np = torch_to_np(img_torch)
    img_pil = np_to_pil(img_np)
    img_pil.save(root)

In [6]:

def denoising(noise_im, LR=1e-2, sigma=3, rho=1, eta=0.5, total_step=30,
              prob1_iter=500, noise_level=None, result_root=None, f=None):

    input_depth = 3
    latent_dim = 3

    en_net = UNet(input_depth, latent_dim)#.cuda()
    de_net = UNet(latent_dim, input_depth)#.cuda()

    parameters = [p for p in en_net.parameters()] + [p for p in de_net.parameters()]
    optimizer = torch.optim.Adam(parameters, lr=LR)

    l2_loss = torch.nn.MSELoss()#.cuda()

    i0 = np_to_torch(noise_im)#.cuda()
    noise_im_torch = np_to_torch(noise_im)#.cuda()
    i0_til_torch = np_to_torch(noise_im)#.cuda()
    Y = torch.zeros_like(noise_im_torch)#.cuda()

    for i in tqdm(range(total_step)):

################################# sub-problem 1 ###############################

        for i_1 in tqdm(range(prob1_iter)):

            optimizer.zero_grad()

            mean = en_net(noise_im_torch)
            z = sample_z(mean)
            out = de_net(z)

            total_loss =  0.5 * l2_loss(out, noise_im_torch)
            total_loss += 0.5 * (1/sigma**2)*l2_loss(mean, i0)
            total_loss += (rho/2) * l2_loss(i0 + Y, i0_til_torch)

            total_loss.backward()
            optimizer.step()

            with torch.no_grad():
                i0 = ((1/sigma**2)*mean.detach() + rho*(i0_til_torch - Y)) / ((1/sigma**2) + rho)

        with torch.no_grad():

################################# sub-problem 2 ###############################

            i0_np = torch_to_np(i0)
            Y_np = torch_to_np(Y)

            sig = eval_sigma(i+1, noise_level)

            i0_til_np = bm3d.bm3d_rgb(i0_np.transpose(1, 2, 0) + Y_np.transpose(1, 2, 0), sig).transpose(2, 0, 1)
            i0_til_torch = np_to_torch(i0_til_np)#.cuda()

################################# sub-problem 3 ###############################

            Y = Y + eta * (i0 - i0_til_torch)

###############################################################################

            Y_name = 'Y_{:04d}'.format(i) + '.png'
            i0_name = 'i0_num_epoch_{:04d}'.format(i) + '.png'
            mean_name = 'Latent_im_num_epoch_{:04d}'.format(i) + '.png'
            out_name = 'res_of_dec_num_epoch_{:04d}'.format(i) + '.png'
            diff_name = 'Latent_dis_num_epoch_{:04d}'.format(i) + '.png'

            Y_np = torch_to_np(Y)
            Y_norm_np = np.sqrt((Y_np*Y_np).sum(0))
            save_heatmap(Y_norm_np, result_root + Y_name)

            save_torch(mean, result_root + mean_name)
            save_torch(out, result_root + out_name)
            save_torch(i0, result_root + i0_name)


            i0_til_np = torch_to_np(i0_til_torch).clip(0, 255)
            i0_til_pil = np_to_pil(i0_til_np)
            i0_til_pil.save(os.path.join(result_root, '{}'.format(i) + '.png'))

    return i0_til_np

###############################################################################


# Run new denoising approach

In [7]:
path = './data/own_test_image/'
noises = sorted(glob.glob(path + '*.jpg'))

LR = 1e-2
sigma = 3
rho = 1
eta = 0.5
total_step = 30
prob1_iter = 10

for noise in noises:
    result = './output/own_test_image/{}/'.format(noise.split('/')[-1][:-9])
    os.system('mkdir -p ' + result)

    noise_im = Image.open(noise)

    noise_im_np = pil_to_np(noise_im)

    noise_level = 25
    # noise_level = np.mean(estimate_sigma(noise_im_np.transpose(1, 2, 0), multichannel=True))

    with open(result + 'result.txt', 'w') as f:
        _ = denoising(noise_im_np, LR=LR, sigma=sigma,
                                  rho=rho, eta=eta, total_step=total_step,
                                  prob1_iter=prob1_iter, noise_level=noise_level,
                                  result_root=result, f=f)

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  return img_np.transpose(2, 0, 1).astype(np.float)
  0%|          | 0/30 [00:00<?, ?it/s]
  0%|          | 0/10 [00:00<?, ?it/s][A
 10%|█         | 1/10 [00:05<00:48,  5.37s/it][A
 20%|██        | 2/10 [00:10<00:41,  5.19s/it][A
 30%|███       | 3/10 [00:15<00:35,  5.08s/it][A
 40%|████      | 4/10 [00:20<00:30,  5.08s/it][A
 50%|█████     | 5/10 [00:25<00:25,  5.03s/it][A
 60%|██████    | 6/10 [00:30<00:20,  5.02s/it][A
 70%|███████   | 7/10 [00:35<00:14,  4.98s/it][A
 80%|████████  | 8/10 [00:40<00:10,  5.01s/it][A
 90%|█████████ | 9/10 [00:45<00:05,  5.16s/it][A
100%|██████████| 10/10 [00:51<00:00,  5.14s/it][A
  3%|▎         | 1/30 [01:03<30:38, 63.39s/it]
  0%|          | 0/10 [00:00<?, ?it/s][A
 10%|█         | 1/10 [00:04<00:44,  4.92s/it][A
 20%|██        | 2/10 [00:09<00:39,  4.90s/it][A
 30%|███       | 3/10 [00:14<00:35,  5.02s/it][A
 40%|

# Run bm3d denoising approach

In [14]:
path = './data/own_test_image/'
noises = sorted(glob.glob(path + '*.jpg'))
noise = noises[0]
result = './output/own_test_image/{}/'.format(noise.split('/')[-1][:-9])
os.system('mkdir -p ' + result)

noise_im = Image.open(noise)

noise_im_np = pil_to_np(noise_im)

noise_level = 25

r_bm3d = bm3d.bm3d_rgb(noise_im_np.transpose(1, 2, 0), noise_level)
r_bm3d = np.clip(r_bm3d, 0, 255)

r_bm3d = Image.fromarray(r_bm3d.astype(np.uint8))
r_bm3d.save(result + 'bm3d_result.png')

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  return img_np.transpose(2, 0, 1).astype(np.float)
