# Denoising Methods - BM3D

In [None]:
from cv2 import imwrite

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pywt
import cv2
import numpy as np
from sklearn.datasets import fetch_olivetti_faces
from sklearn.model_selection import train_test_split
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
from bm3d import bm3d, BM3DStages
import plotly.graph_objects as go

In [None]:
def SSIM_Batch(X, X_true, k=16):
    """Return the mean SSIM between paired rows of two image batches.

    Row ``i`` of ``X`` is reshaped to a ``k`` x ``k`` image and compared with
    row ``i`` of ``X_true`` (reshaped the same way); the per-pair SSIM values
    are averaged.

    Parameters
    ----------
    X : numpy.ndarray
        Batch of images to score (e.g. noisy or denoised estimates), one
        flattened ``k * k`` image per row.
    X_true : numpy.ndarray
        Reference images, paired row-by-row with ``X``.
    k : int, default 16
        Side length of each square image.

    Returns
    -------
    float
        SSIM averaged over all image pairs.

    Raises
    ------
    ValueError
        If the batch is empty or ``X`` and ``X_true`` have a different
        number of rows.
    """
    m = X.shape[0]
    if m == 0:
        raise ValueError("SSIM_Batch needs at least one image pair")
    if X_true.shape[0] != m:
        raise ValueError(
            f"X has {m} images but X_true has {X_true.shape[0]}; "
            "batches must be paired row-by-row"
        )

    # The dynamic range is a property of the reference images. Values above 2
    # indicate 8-bit style [0, 255] data; otherwise assume [0, 1]. Always pass
    # an explicit number: skimage rejects data_range=None for float images
    # (and older versions silently assumed a range of 2).
    data_range = 255.0 if X_true.max() > 2 else 1.0

    total = 0.0
    for estimate, reference in zip(X, X_true):
        total += ssim(
            estimate.reshape((k, k)),
            reference.reshape((k, k)),
            data_range=data_range,
        )
    return total / m


In [None]:
# Load the Olivetti faces dataset (shuffled) and keep the first 100 samples;
# each row of X is one flattened 64x64 face image.
faces = fetch_olivetti_faces(shuffle=True)
X = faces.data[:100]

# Corrupt every pixel with additive Gaussian noise (mean 0, std 0.1).
# The global NumPy seed is fixed so the noise realization is reproducible.
np.random.seed(0)
noise = np.random.normal(loc=0, scale=0.1, size=X.shape)
X_noisy = np.add(X, noise)


In [None]:
# 2x5 grid: row 1 shows the first five clean faces, row 2 the same faces
# after noise was added.
fig = make_subplots(rows=2, cols=5)

for i in range(5):
    # Within each column, add the clean face first and then the noisy one.
    for row, images in enumerate((X, X_noisy), start=1):
        fig.add_trace(
            go.Heatmap(z=images[i].reshape(64, 64), colorscale='gray', showscale=False),
            row=row, col=i + 1,
        )

# Hide tick labels and grid lines on every subplot, and reverse the y-axes so
# image row 0 is drawn at the top.
fig.update_xaxes(showticklabels=False, showgrid=False)
fig.update_yaxes(autorange="reversed", showticklabels=False, showgrid=False)

fig.update_layout(
    width=1000,
    height=400,
    showlegend=False,
    margin=dict(t=10, l=10, r=10, b=10),
)

fig.show()

## BM3D

In [None]:
# Noise standard deviation handed to BM3D as a scalar sigma_psd. It must
# match the std of the Gaussian noise added to the faces (0.1).
bm3d_sigma = 0.1

X_denoised_first = []  # first (hard-thresholding) stage only
X_denoised_all = []    # full BM3D pipeline (all stages)
for x_i in X_noisy:
    image = x_i.reshape(64, 64)

    # Apply BM3D denoising - first stage (hard thresholding) only
    basic_estimate = bm3d(image, sigma_psd=bm3d_sigma, stage_arg=BM3DStages.HARD_THRESHOLDING)
    X_denoised_first.append(basic_estimate.ravel())

    # Apply BM3D denoising - both stages
    full_estimate = bm3d(image, sigma_psd=bm3d_sigma, stage_arg=BM3DStages.ALL_STAGES)
    X_denoised_all.append(full_estimate.ravel())

# Stack into (n_images, 64 * 64) arrays, matching the layout of X / X_noisy.
X_denoised_first = np.array(X_denoised_first)
X_denoised_all = np.array(X_denoised_all)


In [None]:
# 4x5 grid, one column per face. Rows, top to bottom: clean image, noisy
# image, BM3D first-stage estimate, full BM3D estimate.
fig = make_subplots(rows=4, cols=5)

for i in range(5):
    # Within each column, add the four versions of face i from top to bottom.
    for row, images in enumerate((X, X_noisy, X_denoised_first, X_denoised_all), start=1):
        fig.add_trace(
            go.Heatmap(z=images[i].reshape(64, 64), colorscale='gray', showscale=False),
            row=row, col=i + 1,
        )

# Hide tick labels and grid lines on every subplot, and reverse the y-axes so
# image row 0 is drawn at the top.
fig.update_xaxes(showticklabels=False, showgrid=False)
fig.update_yaxes(autorange="reversed", showticklabels=False, showgrid=False)

fig.update_layout(
    width=1000,
    height=800,
    showlegend=False,
    margin=dict(t=10, l=10, r=10, b=10),
)

fig.show()

In [None]:
# Mean SSIM of the first-stage (hard-thresholding) BM3D estimates vs. the clean faces
SSIM_Batch(X=X_denoised_first, X_true=X, k=64)

In [None]:
# Mean SSIM of the full (all-stages) BM3D estimates vs. the clean faces
SSIM_Batch(X=X_denoised_all, X_true=X, k=64)

In [None]:
# Baseline: mean SSIM of the noisy (undenoised) faces vs. the clean faces
SSIM_Batch(X=X_noisy, X_true=X, k=64)

In [None]:
# Save a zoomed crop of the fully denoised face at index 3: the top-right
# 32x32 quadrant (rows 0-31, columns 32-63) of the 64x64 image.
zoom = X_denoised_all[3].reshape(64, 64)[:32, 32:]
# Intensities are treated as [0, 1], but denoised values can fall outside that
# range (the noisy input does). Clip before scaling so out-of-range pixels
# saturate instead of wrapping around in the uint8 cast, and round to the
# nearest 8-bit level instead of truncating.
zoom_u8 = np.round(np.clip(zoom, 0.0, 1.0) * 255).astype(np.uint8)
# cv2.imwrite signals failure by returning False rather than raising.
if not imwrite("denoised_as_4_zoom.png", zoom_u8):
    raise OSError("cv2.imwrite failed to write denoised_as_4_zoom.png")