In [None]:
import numpy as np
from skimage.restoration import denoise_nl_means, estimate_sigma
from sklearn.datasets import fetch_olivetti_faces

import plotly.graph_objects as go
from plotly.subplots import make_subplots

from cv2 import imwrite

In [None]:
# Fetch the (shuffled) Olivetti faces and keep the first 100 flattened images
faces = fetch_olivetti_faces(shuffle=True)
X = faces.data[:100]

# Corrupt the images with additive zero-mean Gaussian noise (std = 0.1).
# Seeding NumPy's global RNG makes the noise identical on every run.
np.random.seed(0)
noise = np.random.normal(loc=0, scale=0.1, size=X.shape)
X_noisy = np.add(X, noise)


In [None]:
# Show the first five faces: clean originals (row 1) above their noisy versions (row 2)
fig = make_subplots(rows=2, cols=5)

for i in range(5):
    for row, images in enumerate((X, X_noisy), start=1):
        fig.add_trace(
            # Pin the colour range to [0, 1] so both rows share one grey mapping.
            # The noisy images are not clipped, so per-trace autoscaling would
            # stretch their wider value range differently and distort the comparison.
            go.Heatmap(z=images[i].reshape(64, 64), colorscale='gray',
                       zmin=0, zmax=1, showscale=False),
            row=row, col=i + 1,
        )

# Hide ticks and grid lines on every subplot; reverse the y axes so the first
# pixel row of each image is drawn at the top, as in an image viewer.
fig.update_xaxes(showticklabels=False, showgrid=False)
fig.update_yaxes(autorange="reversed", showticklabels=False, showgrid=False)

# Update layout
fig.update_layout(
    width=1000,
    height=400,
    showlegend=False,
    margin=dict(t=10, l=10, r=10, b=10)
)

# Show plot
fig.show()

In [None]:
# Candidate NL-means filtering strengths h: 15/200 .. 200/200 (0.075 to 1.0, step 0.025)
h_vals = np.arange(15, 200 + 1, 5) / 200
# Candidate odd patch sizes: 3, 5, ..., 15
p_vals = np.arange(3, 15 + 1, 2)

In [None]:
# Display the grid of h values that the search below iterates over
h_vals

In [None]:
# Grid-search the NL-means parameters: for every (h, patch_size) pair, denoise
# all noisy images and record the mean squared error against the clean images.
# MSE[i][j] holds the error for h_vals[i] and p_vals[j].
noisy_images = X_noisy.reshape(-1, 64, 64)

# The noise estimate depends only on the image, not on (h, patch_size), so it
# is computed once per image instead of once per image per grid point.
sigma_ests = [estimate_sigma(image) for image in noisy_images]

MSE = []
for h_i in h_vals:
    temp = []
    for p_i in p_vals:
        # Apply non-local means denoising and clip back to the [0, 1] range
        X_denoised = np.array([
            np.clip(
                denoise_nl_means(image, h=h_i, sigma=sigma_est, fast_mode=True,
                                 patch_size=p_i, patch_distance=15),
                0, 1,
            ).reshape(64 * 64, )
            for image, sigma_est in zip(noisy_images, sigma_ests)
        ])
        MSE_val = ((X_denoised - X) ** 2).mean()
        temp.append(MSE_val)
    MSE.append(temp)


In [None]:
MSE = np.array(MSE)  # shape (len(h_vals), len(p_vals))

# Min-max normalise the errors to [0, 1] for the colour scale; a flat grid
# (all errors equal) would otherwise divide by zero.
mse_range = MSE.max() - MSE.min()
if mse_range > 0:
    metrics_normalized = (MSE - MSE.min()) / mse_range
else:
    metrics_normalized = np.zeros_like(MSE)

# Locate the best (h, patch_size) cell from the data rather than hard-coding it
best_h_idx, best_p_idx = np.unravel_index(np.argmin(MSE), MSE.shape)

# Transpose so rows are patch sizes (y axis) and columns are h values (x axis)
fig = go.Figure(data=go.Heatmap(z=metrics_normalized.T, x=h_vals, y=p_vals, colorscale='viridis'))
# Add a star annotation on the grid cell with the minimum MSE
fig.add_annotation(
    x=h_vals[best_h_idx], y=p_vals[best_p_idx],
    text="★ min",
    showarrow=False,
    font=dict(size=10, color="red")
)
fig.update_layout(height=500, width=1000,
    title='Normalized MSE of NL-means denoising',
    xaxis_title='H Values',
    yaxis_title='Patch Size'
)

fig.show()


In [None]:
# Denoise every noisy image with the (h, patch_size) pair that achieved the
# lowest MSE in the grid search, so the final parameters cannot drift out of
# sync with the search. MSE[i][j] corresponds to h_vals[i] and p_vals[j].
mse_grid = np.asarray(MSE)
best_h_idx, best_p_idx = np.unravel_index(np.argmin(mse_grid), mse_grid.shape)
best_h = h_vals[best_h_idx]
best_p = p_vals[best_p_idx]

X_denoised = []

for x_i in X_noisy:
    image = x_i.reshape(64, 64)
    sigma_est = estimate_sigma(image)  # Estimate noise standard deviation

    # Apply non-local means denoising, then clip back to the [0, 1] range
    denoised_img = denoise_nl_means(image, h=best_h, sigma=sigma_est, fast_mode=True, patch_size=best_p, patch_distance=15)
    denoised_img = np.clip(denoised_img, 0, 1)
    X_denoised.append(denoised_img.reshape(64 * 64, ))

X_denoised = np.array(X_denoised)

In [None]:
# Compare the first five faces: original (row 1), noisy (row 2), denoised (row 3)
panels = (X, X_noisy, X_denoised)
fig = make_subplots(rows=len(panels), cols=5)

for i in range(5):
    for row, images in enumerate(panels, start=1):
        fig.add_trace(
            # Pin the colour range to [0, 1] (the range the denoised images are
            # clipped to) so all rows share one grey mapping. The unclipped noisy
            # images would otherwise be autoscaled differently, distorting the comparison.
            go.Heatmap(z=images[i].reshape(64, 64), colorscale='gray',
                       zmin=0, zmax=1, showscale=False),
            row=row, col=i + 1,
        )

# Hide ticks and grid lines on every subplot; reverse the y axes so the first
# pixel row of each image is drawn at the top, as in an image viewer.
fig.update_xaxes(showticklabels=False, showgrid=False)
fig.update_yaxes(autorange="reversed", showticklabels=False, showgrid=False)

# Update layout
fig.update_layout(
    width=1000,
    height=600,
    showlegend=False,
    margin=dict(t=10, l=10, r=10, b=10)
)

# Show plot
fig.show()

In [None]:
# Save the fifth denoised face (index 4) as an 8-bit greyscale PNG.
# Values are already clipped to [0, 1]; round (rather than truncate) when
# mapping them to 0..255.
denoised_u8 = np.round(X_denoised[4].reshape(64, 64) * 255).astype(np.uint8)
# imwrite reports failure through its boolean return value, so check it
if not imwrite("denoised_5.png", denoised_u8):
    raise OSError("failed to write denoised_5.png")