In [1]:
!pip install bm3d

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting bm3d
  Downloading bm3d-3.0.9-py3-none-any.whl (8.4 MB)
[K     |████████████████████████████████| 8.4 MB 4.7 MB/s 
Installing collected packages: bm3d
Successfully installed bm3d-3.0.9


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class double_conv(nn.Module):
    '''(conv => BN => ReLU) * 2'''

    def __init__(self, in_ch, out_ch):
        super(double_conv, self).__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.conv(x)

        return x

class inconv(nn.Module):

    def __init__(self, in_ch, out_ch):
        super(inconv, self).__init__()
        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x):
        x = self.conv(x)

        return x
    
class down(nn.Module):

    def __init__(self, in_ch, out_ch):
        super(down, self).__init__()

        self.mpconv = nn.Sequential(
            nn.MaxPool2d(2),
            double_conv(in_ch, out_ch)
        )

    def forward(self, x):
        x = self.mpconv(x)

        return x

class up_no_skip(nn.Module):

    def __init__(self, in_ch, out_ch, bilinear=True):
        super(up_no_skip, self).__init__()
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2)
        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x):
        x = self.up(x)
        x = self.conv(x)

        return x


class up(nn.Module):

    def __init__(self, in_ch, out_ch, bilinear=True):
        super(up, self).__init__()
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2)
        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        diffX = x2.size()[2] - x1.size()[2]
        diffY = x2.size()[3] - x1.size()[3]
        x1 = F.pad(x1, (diffY // 2, diffY - diffY // 2, 
                        diffX // 2, diffX - diffX // 2), 'replicate')
        x = torch.cat([x2, x1], dim=1)
        x = self.conv(x)

        return x

class outconv(nn.Module):

    def __init__(self, in_ch, out_ch):
        super(outconv, self).__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 1)

    def forward(self, x):
        x = self.conv(x)

        return x

class UNet(nn.Module):

    def __init__(self, n_channels, n_classes, need_sigmoid=False):
        super(UNet, self).__init__()
        self.inc = inconv(n_channels, 64)
        self.down1 = down(64, 128)
        self.down2 = down(128, 256)
        self.down3 = down(256, 512)
        self.down4 = down(512, 512)
        self.up1 = up(1024, 256)
        self.up2 = up(512, 128)
        self.up3 = up(256, 64)
        self.up4 = up(128, 64)
        self.outc = outconv(64, n_classes)
        self.need_sigmoid = need_sigmoid

    def forward(self, x):
        self.x1 = self.inc(x)
        self.x2 = self.down1(self.x1)
        self.x3 = self.down2(self.x2)
        self.x4 = self.down3(self.x3)
        self.x5 = self.down4(self.x4)
        self.x6 = self.up1(self.x5, self.x4)
        self.x7 = self.up2(self.x6, self.x3)
        self.x8 = self.up3(self.x7, self.x2)
        self.x9 = self.up4(self.x8, self.x1)
        self.y = self.outc(self.x9)
        if self.need_sigmoid:
            self.y = torch.sigmoid(self.y)

        return self.y
    
class UNet5(nn.Module):

    def __init__(self, n_channels, n_classes, need_sigmoid=False):
        super(UNet5, self).__init__()
        self.inc = inconv(n_channels, 64)
        self.down1 = down(64, 128)
        self.down2 = down(128, 256)
        self.down3 = down(256, 512)
        self.down4 = down(512, 1024)
        self.down5 = down(1024, 1024)
        self.up1 = up(2048, 512)
        self.up2 = up(1024, 256)
        self.up3 = up(512, 128)
        self.up4 = up(256, 64)
        self.up5 = up(128, 64)
        self.outc = outconv(64, n_classes)
        self.need_sigmoid = need_sigmoid

    def forward(self, x):
        self.x1 = self.inc(x)
        self.x2 = self.down1(self.x1)
        self.x3 = self.down2(self.x2)
        self.x4 = self.down3(self.x3)
        self.x5 = self.down4(self.x4)
        self.x6 = self.down5(self.x5)
        self.x7 = self.up1(self.x6, self.x5)
        self.x8 = self.up2(self.x7, self.x4)
        self.x9 = self.up3(self.x8, self.x3)
        self.x10 = self.up4(self.x9, self.x2)
        self.x11 = self.up5(self.x10, self.x1)
        self.y = self.outc(self.x11)
        if self.need_sigmoid:
            self.y = torch.sigmoid(self.y)

        return self.y    

    
class UNet3(nn.Module):
    
    def __init__(self, n_channels, n_classes, need_sigmoid=False):
        super(UNet3, self).__init__()
        self.inc = inconv(n_channels, 64)
        self.down1 = down(64, 128)
        self.down2 = down(128, 256)
        self.down3 = down(256, 256)
        self.up1 = up(512, 128)
        self.up2 = up(256, 64)
        self.up3 = up(128, 64)
        self.outc = outconv(64, n_classes)
        self.need_sigmoid = need_sigmoid

    def forward(self, x):
        self.x1 = self.inc(x)
        self.x2 = self.down1(self.x1)
        self.x3 = self.down2(self.x2)
        self.x4 = self.down3(self.x3)
        self.x5 = self.up1(self.x4, self.x3)
        self.x6 = self.up2(self.x5, self.x2)
        self.x7 = self.up3(self.x6, self.x1)
        self.y = self.outc(self.x7)
        if self.need_sigmoid:
            self.y = torch.sigmoid(self.y)

        return self.y
    
class UNet2(nn.Module):
    
    def __init__(self, n_channels, n_classes, need_sigmoid=False):
        super(UNet2, self).__init__()
        self.inc = inconv(n_channels, 64)
        self.down1 = down(64, 128)
        self.down2 = down(128, 128)
        self.up1 = up(256, 64)
        self.up2 = up(128, 64)
        self.outc = outconv(64, n_classes)
        self.need_sigmoid = need_sigmoid

    def forward(self, x):
        self.x1 = self.inc(x)
        self.x2 = self.down1(self.x1)
        self.x3 = self.down2(self.x2)
        self.x4 = self.up1(self.x3, self.x2)
        self.x5 = self.up2(self.x4, self.x1)
        self.y = self.outc(self.x5)
        if self.need_sigmoid:
            self.y = torch.sigmoid(self.y)

        return self.y    

class UNet5_no_skip(nn.Module):

    def __init__(self, n_channels, n_classes, need_sigmoid=False):
        super(UNet5_no_skip, self).__init__()
        self.inc = inconv(n_channels, 64)
        self.down1 = down(64, 128)
        self.down2 = down(128, 256)
        self.down3 = down(256, 512)
        self.down4 = down(512, 1024)
        self.down5 = down(1024, 1024)
        self.up1 = up_no_skip(1024, 1024)
        self.up2 = up_no_skip(1024, 512)
        self.up3 = up_no_skip(512, 256)
        self.up4 = up_no_skip(256, 128)
        self.up5 = up_no_skip(128, 64)
        self.outc = outconv(64, n_classes)
        self.need_sigmoid = need_sigmoid

    def forward(self, x):
        x = self.inc(x)
        x = self.down1(x)
        x = self.down2(x)
        x = self.down3(x)
        x = self.down4(x)
        x = self.down5(x)
        x = self.up1(x)
        x = self.up2(x)
        x = self.up3(x)
        x = self.up4(x)
        x = self.up5(x)
        y = self.outc(x)
        if self.need_sigmoid:
            y = torch.sigmoid(y)

        return y

class UNet_no_skip(nn.Module):

    def __init__(self, n_channels, n_classes, need_sigmoid=False):
        super(UNet_no_skip, self).__init__()
        self.inc = inconv(n_channels, 64)
        self.down1 = down(64, 128)
        self.down2 = down(128, 256)
        self.down3 = down(256, 512)
        self.down4 = down(512, 512)
        self.up1 = up_no_skip(512, 512)
        self.up2 = up_no_skip(512, 256)
        self.up3 = up_no_skip(256, 128)
        self.up4 = up_no_skip(128, 64)
        self.outc = outconv(64, n_classes)
        self.need_sigmoid = need_sigmoid

    def forward(self, x):
        x = self.inc(x)
        x = self.down1(x)
        x = self.down2(x)
        x = self.down3(x)
        x = self.down4(x)
        x = self.up1(x)
        x = self.up2(x)
        x = self.up3(x)
        x = self.up4(x)
        y = self.outc(x)
        if self.need_sigmoid:
            y = torch.sigmoid(y)

        return y
    
class UNet3_no_skip(nn.Module):

    def __init__(self, n_channels, n_classes, need_sigmoid=False):
        super(UNet3_no_skip, self).__init__()
        self.inc = inconv(n_channels, 64)
        self.down1 = down(64, 128)
        self.down2 = down(128, 256)
        self.down3 = down(256, 256)
        self.up1 = up_no_skip(256, 256)
        self.up2 = up_no_skip(256, 128)
        self.up3 = up_no_skip(128, 64)
        self.outc = outconv(64, n_classes)
        self.need_sigmoid = need_sigmoid

    def forward(self, x):
        x = self.inc(x)
        x = self.down1(x)
        x = self.down2(x)
        x = self.down3(x)
        x = self.up1(x)
        x = self.up2(x)
        x = self.up3(x)
        y = self.outc(x)
        if self.need_sigmoid:
            y = torch.sigmoid(y)

        return y
    
class UNet2_no_skip(nn.Module):

    def __init__(self, n_channels, n_classes, need_sigmoid=False):
        super(UNet2_no_skip, self).__init__()
        self.inc = inconv(n_channels, 64)
        self.down1 = down(64, 128)
        self.down2 = down(128, 128)
        self.up1 = up_no_skip(128, 128)
        self.up2 = up_no_skip(128, 64)
        self.outc = outconv(64, n_classes)
        self.need_sigmoid = need_sigmoid

    def forward(self, x):
        x = self.inc(x)
        x = self.down1(x)
        x = self.down2(x)
        x = self.up1(x)
        x = self.up2(x)
        y = self.outc(x)
        if self.need_sigmoid:
            y = torch.sigmoid(y)

        return y


In [3]:
import torch

from PIL import Image
import numpy as np

import bm3d

from skimage.metrics import peak_signal_noise_ratio, structural_similarity
import matplotlib.pyplot as plt
from matplotlib import image as mpimg


import glob

import os
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

from tqdm import tqdm

In [4]:
def pil_to_np(img_pil, normalize=False):
    img_np = np.array(img_pil)
    if normalize:
        img_np = img_np / 255

    if len(img_np.shape) == 2:
        return img_np.astype(np.float)
    else:
        return img_np.transpose(2, 0, 1).astype(np.float)

def np_to_pil(img_np, normalize=False):
    if normalize:
        img_np = img_np*255

    img_np = np.clip(img_np, 0, 255)
    img_np = img_np.astype(np.uint8)

    if len(img_np.shape) == 2:
        img_pil = Image.fromarray(img_np)
    else:
        img_pil = Image.fromarray(img_np.transpose(1, 2, 0))

    return img_pil

def np_to_torch(img_np):
    if len(img_np.shape) == 2:
        return torch.Tensor(img_np)[None, None, ...]
    else:
        return torch.Tensor(img_np)[None, ...]

def torch_to_np(img_torch):
    return img_torch.cpu().squeeze().detach().numpy()


In [5]:
def save_hist(x, root):
    x = x.flatten()
    plt.figure()
    n, bins, patches = plt.hist(x, bins=128, density=1)
    plt.savefig(root)
    plt.close()

def save_heatmap(image_np, root):
    cmap = plt.get_cmap('jet')

    rgba_img = cmap(image_np)
    rgb_img = np.delete(rgba_img, 3, 2)
    rgb_img_pil = Image.fromarray((255*rgb_img).astype(np.uint8))
    rgb_img_pil.save(root)

def sample_z(mean):
    eps = mean.clone().normal_()

    return mean + eps

def eval_sigma(num_iter, noise_level):
    if num_iter == 1:
        sigma = noise_level
    else:
        sigma = 5

    return sigma

def save_torch(img_torch, root):
    img_np = torch_to_np(img_torch)
    img_pil = np_to_pil(img_np)
    img_pil.save(root)

In [6]:
def denoising(noise_im, clean_im, LR=1e-2, sigma=3, rho=1, eta=0.5, total_step=30,
              prob1_iter=500, noise_level=None, result_root=None, f=None):

    sig=3
    r_bm3d = bm3d.bm3d_rgb(noise_im.transpose(1, 2, 0), sig)
    r_bm3d = np.clip(r_bm3d, 0, 255)
    psnr_bm3d = peak_signal_noise_ratio(clean_im.transpose(1, 2, 0), r_bm3d, data_range=255)
    ssim_bm3d = structural_similarity(r_bm3d, clean_im.transpose(1, 2, 0), multichannel=True, data_range=255)

    print('noise level {} '.format(noise_level), file=f, flush=True)
    print('PSNR_BM3D: {}, SSIM_BM3D: {}'.format(psnr_bm3d, ssim_bm3d), file=f, flush=True)

    r_bm3d = Image.fromarray(r_bm3d.astype(np.uint8))
    r_bm3d.save(result_root + 'bm3d_result.png')


    input_depth = 3
    latent_dim = 3

    en_net = UNet(input_depth, latent_dim).cuda()
    de_net = UNet(latent_dim, input_depth).cuda()

    parameters = [p for p in en_net.parameters()] + [p for p in de_net.parameters()]
    optimizer = torch.optim.Adam(parameters, lr=LR)

    l2_loss = torch.nn.MSELoss().cuda()

    i0 = np_to_torch(noise_im).cuda()
    noise_im_torch = np_to_torch(noise_im).cuda()
    i0_til_torch = np_to_torch(noise_im).cuda()
    Y = torch.zeros_like(noise_im_torch).cuda()

    best_psnr = 0
    best_pil = None

    for i in range(total_step):

################################# sub-problem 1 ###############################

        for i_1 in range(prob1_iter):

            optimizer.zero_grad()

            mean = en_net(noise_im_torch)
            z = sample_z(mean)
            out = de_net(z)

            total_loss =  0.5 * l2_loss(out, noise_im_torch)
            total_loss += 0.5 * (1/sigma**2)*l2_loss(mean, i0)
            total_loss += (rho/2) * l2_loss(i0 + Y, i0_til_torch)

            total_loss.backward()
            optimizer.step()

            with torch.no_grad():
                i0 = ((1/sigma**2)*mean.detach() + rho*(i0_til_torch - Y)) / ((1/sigma**2) + rho)

        with torch.no_grad():

################################# sub-problem 2 ###############################

            i0_np = torch_to_np(i0)
            Y_np = torch_to_np(Y)

            sig = eval_sigma(i+1, noise_level)

            i0_til_np = bm3d.bm3d_rgb(i0_np.transpose(1, 2, 0) + Y_np.transpose(1, 2, 0), sig).transpose(2, 0, 1)
            i0_til_torch = np_to_torch(i0_til_np).cuda()

################################# sub-problem 3 ###############################

            Y = Y + eta * (i0 - i0_til_torch)

###############################################################################

            Y_name = 'Y_{:04d}'.format(i) + '.png'
            i0_name = 'i0_num_epoch_{:04d}'.format(i) + '.png'
            mean_name = 'Latent_im_num_epoch_{:04d}'.format(i) + '.png'
            out_name = 'res_of_dec_num_epoch_{:04d}'.format(i) + '.png'
            diff_name = 'Latent_dis_num_epoch_{:04d}'.format(i) + '.png'

            # Y_np = torch_to_np(Y)
            # Y_norm_np = np.sqrt((Y_np*Y_np).sum(0))
            # save_heatmap(Y_norm_np, result_root + Y_name)

            # save_torch(mean, result_root + mean_name)
            # save_torch(out, result_root + out_name)
            # save_torch(i0, result_root + i0_name)


            i0_til_np = torch_to_np(i0_til_torch).clip(0, 255)
            psnr = peak_signal_noise_ratio(clean_im.transpose(1, 2, 0), i0_til_np.transpose(1, 2, 0), data_range=255)
            ssim = structural_similarity(clean_im.transpose(1, 2, 0), i0_til_np.transpose(1, 2, 0), multichannel=True, data_range=255)

            i0_til_pil = np_to_pil(i0_til_np)
            if i % 3 == 0:
              i0_til_pil.save(os.path.join(result_root, '{}'.format(i) + '.jpg'))

            print('Iteration: {:02d}, VAE Loss: {:f}, PSNR: {:f}, SSIM: {:f}'.format(i, total_loss.item(), psnr, ssim), file=f, flush=True)

    best_psnr = psnr
    best_ssim = ssim
    best_pil = i0_til_pil
    best_pil.save(os.path.join(result_root, '{}.jpg'.format(i)))
                

    return best_psnr, best_ssim , psnr_bm3d , ssim_bm3d

# Run new denoising approach for NIND dataset

In [7]:
from google.colab import drive

drive.mount('/content/gdrive/')

Mounted at /content/gdrive/


In [8]:
torch.cuda.is_available()

True

In [9]:
torch.cuda.get_device_name(0)

'Tesla T4'

In [10]:
path = '/content/gdrive/MyDrive/NIND/'
noises = sorted(glob.glob(path + '*noise.jpg'))
cleans = sorted(glob.glob(path + '*clean.jpg'))

LR = 1e-2
sigma = 3
rho = 1
eta = 0.5
total_step = 15
prob1_iter = 20

psnrs = []
ssims = []
bm3d_psnr = []
bm3d_ssim = []

for noise, clean in tqdm(zip(noises, cleans)):
    
    result = path + '/output/nn_bm3d_nind/{}/'.format(noise.split('/')[-1][:-9])
    os.system('mkdir -p ' + result)

    noise_im = Image.open(noise)
    clean_im = Image.open(clean)

    noise_im_np = pil_to_np(noise_im)
    clean_im_np = pil_to_np(clean_im)

    noise_level = 5
    # noise_level = np.mean(estimate_sigma(noise_im_np.transpose(1, 2, 0), multichannel=True))

    with open(result + 'result.txt', 'w') as f:
        psnr, ssim, psnr_bm3d, ssim_bm3d = denoising(noise_im_np, clean_im_np, LR=LR, sigma=sigma,
                                  rho=rho, eta=eta, total_step=total_step,
                                  prob1_iter=prob1_iter, noise_level=noise_level,
                                  result_root=result, f=f)

        psnrs.append(psnr)
        ssims.append(ssim)
        bm3d_psnr.append(psnr_bm3d)
        bm3d_ssim.append(ssim_bm3d)
with open(path + '/output/nn_bm3d_nind/' + 'psnr_ssim.txt', 'w') as f:
    print('AVG PSNR: {}'.format(sum(psnrs)/len(psnrs)), file=f, flush=True)
    print('AVG SSIM: {}'.format(sum(ssims)/len(ssims)), file=f, flush=True)
    print('AVG bm3d psnr =', sum(bm3d_psnr)/len(bm3d_psnr), file=f, flush=True)
    print('AVG bm3d ssim =', sum(bm3d_ssim)/len(bm3d_ssim), file=f, flush=True)

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if __name__ == '__main__':
79it [2:38:36, 120.46s/it]
