In [None]:
import os

import cv2
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import scipy.io
from plotly.subplots import make_subplots
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

from denoising_functions import fft_denoising, mask_a_b
from image_processing_utilities.functions import validation_dataset_generator


In [None]:
def l2_samples(samples, samples_gt):
    """Mean squared error averaged over a grid of image pairs.

    Parameters
    ----------
    samples, samples_gt : array-like indexed as ``samples[i][j]``
        Equally sized grids of images (in this notebook, the 4 x 4 sample grids);
        ``samples[i][j]`` is compared with ``samples_gt[i][j]``. Every grid cell is used.

    Returns
    -------
    float
        Mean over the grid of each image pair's pixel-wise MSE.

    Raises
    ------
    ValueError
        If the grid is empty.
    """
    n_rows = len(samples)
    n_cols = len(samples[0]) if n_rows else 0
    if n_rows == 0 or n_cols == 0:
        raise ValueError("l2_samples needs a non-empty grid of images")

    total = 0.0
    for i in range(n_rows):
        for j in range(n_cols):
            # Promote to float before subtracting: with uint8 images, `a - b` and
            # `np.square` wrap modulo 256 and silently corrupt the error.
            diff = (np.asarray(samples[i][j], dtype=np.float64)
                    - np.asarray(samples_gt[i][j], dtype=np.float64))
            total += np.square(diff).mean()
    return total / (n_rows * n_cols)

def ssim_samples(samples, samples_gt, channel_axis=2, data_range=None):
    """Mean SSIM over a grid of image pairs.

    Parameters
    ----------
    samples, samples_gt : array-like indexed as ``samples[i][j]``
        Equally sized grids of images; ``samples[i][j]`` is compared with
        ``samples_gt[i][j]``. Every grid cell is used (not only the first 4 x 4).
    channel_axis : int, optional
        Colour-channel axis of each image, forwarded to ``structural_similarity``
        (default 2, i.e. channels-last images, as before).
    data_range : float, optional
        Forwarded to ``structural_similarity``. ``None`` (default) lets skimage infer it
        from the dtype; recent skimage versions require an explicit value for
        floating-point images.

    Returns
    -------
    float
        Mean SSIM over all grid cells.

    Raises
    ------
    ValueError
        If the grid is empty.
    """
    n_rows = len(samples)
    n_cols = len(samples[0]) if n_rows else 0
    if n_rows == 0 or n_cols == 0:
        raise ValueError("ssim_samples needs a non-empty grid of images")

    scores = [
        ssim(samples[i][j], samples_gt[i][j], channel_axis=channel_axis, data_range=data_range)
        for i in range(n_rows)
        for j in range(n_cols)
    ]
    return sum(scores) / len(scores)

def psnr_samples(samples, samples_gt, data_range=None):
    """Mean PSNR (dB) over a grid of image pairs.

    Parameters
    ----------
    samples, samples_gt : array-like indexed as ``samples[i][j]``
        Equally sized grids of test and reference images; ``samples[i][j]`` is scored
        against ``samples_gt[i][j]``. Every grid cell is used (not only the first 4 x 4).
    data_range : float, optional
        Forwarded to ``peak_signal_noise_ratio``. ``None`` (default) lets skimage infer
        it from the dtype of the reference image.

    Returns
    -------
    float
        Mean PSNR over all grid cells.

    Raises
    ------
    ValueError
        If the grid is empty.
    """
    n_rows = len(samples)
    n_cols = len(samples[0]) if n_rows else 0
    if n_rows == 0 or n_cols == 0:
        raise ValueError("psnr_samples needs a non-empty grid of images")

    # skimage's signature is (image_true, image_test): the ground truth goes first so
    # that the inferred data range comes from the reference, not the processed image.
    scores = [
        psnr(samples_gt[i][j], samples[i][j], data_range=data_range)
        for i in range(n_rows)
        for j in range(n_cols)
    ]
    return sum(scores) / len(scores)

In [None]:
def mask_ellipse(a, b, shape=(256, 256)):
    """Binary elliptical low-pass mask for an ``np.fft.fftshift``-ed spectrum.

    Parameters
    ----------
    a : float
        Semi-axis along rows (axis 0), in frequency bins. Must be non-zero.
    b : float
        Semi-axis along columns (axis 1), in frequency bins. Must be non-zero.
    shape : tuple of int, optional
        Mask size ``(rows, cols)``; defaults to 256 x 256 as before.

    Returns
    -------
    np.ndarray
        Float64 array of ``shape``: 1.0 inside the ellipse, 0.0 outside.
    """
    n_rows, n_cols = shape
    # fftshift puts the zero-frequency (DC) bin at index n // 2 on each axis (128 for
    # n = 256). Centring there keeps DC inside even the smallest mask and makes the mask
    # symmetric about DC, so real images stay (numerically) real after filtering.
    r0, c0 = n_rows // 2, n_cols // 2
    rows, cols = np.ogrid[:n_rows, :n_cols]
    inside = ((rows - r0) / a) ** 2 + ((cols - c0) / b) ** 2 <= 1

    mask = np.zeros(shape)
    mask[inside] = 1
    return mask

def mask_diamond(a, b, shape=(256, 256)):
    """Binary diamond (L1-ball) low-pass mask for an ``np.fft.fftshift``-ed spectrum.

    Parameters
    ----------
    a : float
        Half-extent along rows (axis 0), in frequency bins. Must be non-zero.
    b : float
        Half-extent along columns (axis 1), in frequency bins. Must be non-zero.
    shape : tuple of int, optional
        Mask size ``(rows, cols)``; defaults to 256 x 256 as before.

    Returns
    -------
    np.ndarray
        Float64 array of ``shape``: 1.0 where ``|dr/a| + |dc/b| <= 1``, else 0.0.
    """
    n_rows, n_cols = shape
    # fftshift puts the zero-frequency (DC) bin at index n // 2 on each axis (128 for
    # n = 256). Centring there keeps DC inside even the smallest mask and makes the mask
    # symmetric about DC, so real images stay (numerically) real after filtering.
    r0, c0 = n_rows // 2, n_cols // 2
    rows, cols = np.ogrid[:n_rows, :n_cols]
    inside = np.abs((rows - r0) / a) + np.abs((cols - c0) / b) <= 1

    mask = np.zeros(shape)
    mask[inside] = 1
    return mask

def mask_star(a, b, shape=(256, 256)):
    """Binary star-shaped (L0.5 "ball") low-pass mask for an ``np.fft.fftshift``-ed spectrum.

    Parameters
    ----------
    a : float
        Half-extent along rows (axis 0), in frequency bins. Must be non-zero.
    b : float
        Half-extent along columns (axis 1), in frequency bins. Must be non-zero.
    shape : tuple of int, optional
        Mask size ``(rows, cols)``; defaults to 256 x 256 as before.

    Returns
    -------
    np.ndarray
        Float64 array of ``shape``: 1.0 where ``sqrt|dr/a| + sqrt|dc/b| <= 1``, else 0.0.
    """
    n_rows, n_cols = shape
    # fftshift puts the zero-frequency (DC) bin at index n // 2 on each axis (128 for
    # n = 256). Centring there keeps DC inside even the smallest mask and makes the mask
    # symmetric about DC, so real images stay (numerically) real after filtering.
    r0, c0 = n_rows // 2, n_cols // 2
    rows, cols = np.ogrid[:n_rows, :n_cols]
    inside = np.abs((rows - r0) / a) ** 0.5 + np.abs((cols - c0) / b) ** 0.5 <= 1

    mask = np.zeros(shape)
    mask[inside] = 1
    return mask


In [None]:
dataset, method, mask_type = (
    'SIDD',   # validation dataset passed to validation_dataset_generator
    'FFT',    # denoising method label (not read by the cells shown here)
    'Star',   # mask shape for the grid search: 'Ellipse', 'Diamond' or 'Star'
)

In [None]:
# Load the noisy inputs and their ground-truth references for `dataset`.
# NOTE(review): later cells index these as x[image, crop] and treat each entry as a
# channels-last (H, W, C) image — confirm the layout against validation_dataset_generator.
x_noisy, x_gt = validation_dataset_generator(dataset=dataset)


In [None]:
# Sample Images: a fixed 4 x 4 grid of (image, crop) pairs reused by every evaluation below.
image_samples = [1, 10, 17, 23]
image_crops = [2, 4, 7, 11]


def _gather_grid(stack, rows, cols):
    """Stack ``stack[r, c]`` for every r in ``rows`` (outer axis) and c in ``cols`` (inner axis)."""
    return np.array([[stack[r, c] for c in cols] for r in rows])


samples_noisy = _gather_grid(x_noisy, image_samples, image_crops)
samples_gt = _gather_grid(x_gt, image_samples, image_crops)

In [None]:
# Show 8 of the 16 sampled grid cells as Noisy / GT pairs, two pairs per row of a
# 4 x 4 subplot grid.
display_cells = [(0, 1), (1, 0), (1, 3), (2, 0), (3, 0), (3, 1), (3, 2), (3, 3)]  # (image, crop) grid indices
n_plot_rows = (len(display_cells) + 1) // 2
fig = make_subplots(
    rows=n_plot_rows,
    cols=4,
    subplot_titles=[f"{kind} {k:02d}"
                    for k in range(1, len(display_cells) + 1)
                    for kind in ("Noisy", "GT")],
)

for k, (img_idx, crop_idx) in enumerate(display_cells):
    plot_row, first_col = k // 2 + 1, 2 * (k % 2) + 1
    fig.add_trace(go.Image(z=samples_noisy[img_idx][crop_idx]), row=plot_row, col=first_col)   # Noisy image
    fig.add_trace(go.Image(z=samples_gt[img_idx][crop_idx]), row=plot_row, col=first_col + 1)  # Ground truth

# Hide ticks, grid and zero lines on every subplot
fig.update_xaxes(showticklabels=False, showgrid=False, zeroline=False)
fig.update_yaxes(showticklabels=False, showgrid=False, zeroline=False)

fig.update_layout(
    height=1000,
    width=1200,
    title_text="Noisy and Ground Truth Images",
    margin=dict(l=20, r=20, t=80, b=20),  # tight margins
)
fig.show()

In [None]:
# Candidate mask half-extents (in frequency bins) for the grid search: 1..49 on each axis.
A = list(range(1, 50))
B = list(A)

In [None]:
print('ORIGINAL VALUES\n')

# Baseline: score the untouched noisy crops against the ground truth.
for label, metric_fn in (('L2: ', l2_samples), ('SSIM: ', ssim_samples), ('PSNR: ', psnr_samples)):
    print(label, metric_fn(samples_noisy, samples_gt))

In [None]:
# Grid search over the mask half-extents (a, b) for the selected `mask_type`.
# Each pair is scored by 1 - mean SSIM between the FFT-filtered noisy grid and the
# ground truth; metrics[i, j] holds the score for (A[i], B[j]) and the lowest wins.
mask_builders = {'Ellipse': mask_ellipse, 'Diamond': mask_diamond, 'Star': mask_star}
if mask_type not in mask_builders:
    # Fail loudly instead of silently falling back to the Star mask.
    raise ValueError(f"Unknown mask_type {mask_type!r}; expected one of {list(mask_builders)}")
build_mask = mask_builders[mask_type]

best_loss = np.inf  # any finite score beats this, whatever its scale
best_a = None
best_b = None
best_index = None

metrics = np.zeros((len(A), len(B)))
for i, a in enumerate(A):
    for j, b in enumerate(B):
        mask = build_mask(a, b)
        # NOTE(review): `fft_samples` is not defined in, or imported by, this notebook
        # (the import cell brings in `fft_denoising` and `mask_a_b`) — confirm where it
        # comes from, otherwise this cell fails on a fresh kernel.
        test = fft_samples(samples_noisy, mask)

        avg_loss = 1 - ssim_samples(test, samples_gt)
        metrics[i, j] = avg_loss
        if avg_loss < best_loss:
            best_index = [i, j]
            best_a = a
            best_b = b
            best_loss = avg_loss
    # One progress line per `a` rather than one per (a, b) pair.
    print(f'a={a}: best so far a={best_a}, b={best_b}, 1-SSIM={best_loss:.4f}')

print(best_a, best_b)

In [None]:
# Min-max normalise the (1 - SSIM) grid to [0, 1] for display; a flat grid would
# otherwise divide by zero and render as all-NaN.
metrics_spread = metrics.max() - metrics.min()
if metrics_spread > 0:
    metrics_normalized = (metrics - metrics.min()) / metrics_spread
else:
    metrics_normalized = np.zeros_like(metrics)

# metrics is indexed [a_index, b_index]; transpose so A runs along x and B along y.
fig = go.Figure(data=go.Heatmap(
    z=metrics_normalized.T,
    x=A,
    y=B,
    colorscale='viridis',
    colorbar=dict(title=dict(text='normalised 1 - SSIM')),
))
# Mark the best (lowest-loss) pair found by the grid search
fig.add_annotation(
    x=best_a, y=best_b,
    text="★ min",
    showarrow=False,
    font=dict(size=10, color="red"),
)
fig.update_layout(
    height=800,
    width=800,
    title_text=f'{mask_type} mask: normalised 1 - SSIM over (A, B)',
    xaxis_title='A',
    yaxis_title='B',
)
fig.show()


In [None]:
# NOTE(review): hard-coded override — this discards the (best_a, best_b) found by the
# grid search; drop this cell to evaluate the searched optimum instead.
best_a = best_b = 5

In [None]:
# Rebuild the mask at the chosen (best_a, best_b) with the shared helper, filter the
# whole 4 x 4 grid, and score it with the same metrics as the noisy baseline.
# NOTE(review): the grid search scored masks from mask_ellipse / mask_diamond / mask_star,
# while this cell uses mask_a_b — confirm both build the same mask for a given shape/a/b.
mask = mask_a_b(samples_noisy[0, 0], best_a, best_b, shape=mask_type)
denoised = fft_samples(samples_noisy, mask)

print('Denoised VALUES\n')
for label, metric_fn in (('L2: ', l2_samples), ('SSIM: ', ssim_samples), ('PSNR: ', psnr_samples)):
    print(label, metric_fn(denoised, samples_gt))

In [None]:
# Scratch experiment: rebuild `mask` as an Ellipse (regardless of `mask_type`) for the
# single-channel denoise_fft check below; this overwrites the mask from the previous cell.
mask = mask_a_b(samples_noisy[0, 0], best_a, best_b, shape='Ellipse')


In [None]:
# Inspect the current mask.
mask

In [None]:
# Channel 0 of grid cell (0, 0) as a single-channel test image.
test = samples_noisy[0][0][:, :, 0]
test

In [None]:
# Grayscale view of the noisy single-channel test image (px comes from the import cell).
fig = px.imshow(test, color_continuous_scale='gray',
                title='Noisy test image (grid cell 0, 0; channel 0)')
fig.show()


In [None]:
def denoise_fft(image: np.ndarray,
                mask: np.ndarray) -> np.ndarray:
    """Filter an 8-bit image by masking its centred 2-D Fourier spectrum.

    Parameters
    ----------
    image : np.ndarray
        2-D ``(H, W)`` image or channels-last ``(H, W, C)`` image, values expected in
        [0, 255]. The FFT is taken over the first two (spatial) axes only.
    mask : np.ndarray
        Spectrum weights laid out like ``np.fft.fftshift`` output (DC at
        ``(H // 2, W // 2)``). A 2-D ``(H, W)`` mask is broadcast over the channels of a
        colour image.

    Returns
    -------
    np.ndarray
        ``uint8`` image with the same shape as ``image``.
    """
    image = np.asarray(image)
    mask = np.asarray(mask)
    spatial_axes = (0, 1)
    if mask.ndim < image.ndim:
        # Append singleton channel axes so an (H, W) mask broadcasts over (H, W, C).
        mask = mask.reshape(mask.shape + (1,) * (image.ndim - mask.ndim))

    # Transform to the frequency domain and move DC to the centre
    spectrum = np.fft.fftshift(np.fft.fft2(image, axes=spatial_axes), axes=spatial_axes)
    # Apply the mask, undo the shift and go back to the spatial domain
    filtered = np.fft.ifft2(np.fft.ifftshift(spectrum * mask, axes=spatial_axes),
                            axes=spatial_axes)
    # Keep the real part (np.abs would fold negative ringing back to positive values) and
    # round before the cast so FFT round-off such as 99.9999999 does not truncate to 99.
    return np.clip(np.rint(filtered.real), 0, 255).astype(np.uint8)

In [None]:
# Mean mask value (the fraction of spectrum bins kept, if the mask is binary).
np.mean(mask)

In [None]:
# Filter the single-channel test image with the current mask.
test_out = denoise_fft(image=test, mask=mask)
test_out

In [None]:
# Grayscale view of the FFT-filtered test image (px comes from the import cell).
fig = px.imshow(test_out, color_continuous_scale='gray',
                title='FFT-filtered test image')
fig.show()


In [None]:
# Compare Noisy / Denoised / GT for 8 of the 16 sampled grid cells, two triplets per
# row of a 4 x 6 subplot grid.
display_cells = [(0, 1), (1, 0), (1, 3), (2, 0), (3, 0), (3, 1), (3, 2), (3, 3)]  # (image, crop) grid indices
n_plot_rows = (len(display_cells) + 1) // 2
fig = make_subplots(
    rows=n_plot_rows,
    cols=6,
    subplot_titles=[f"{kind} {k:02d}"
                    for k in range(1, len(display_cells) + 1)
                    for kind in ("Noisy", "Denoised", "GT")],
)

for k, (img_idx, crop_idx) in enumerate(display_cells):
    plot_row, first_col = k // 2 + 1, 3 * (k % 2) + 1
    # Columns within a triplet: noisy, denoised, ground truth
    for offset, grid in enumerate((samples_noisy, denoised, samples_gt)):
        fig.add_trace(go.Image(z=grid[img_idx][crop_idx]), row=plot_row, col=first_col + offset)

# Hide ticks, grid and zero lines on every subplot
fig.update_xaxes(showticklabels=False, showgrid=False, zeroline=False)
fig.update_yaxes(showticklabels=False, showgrid=False, zeroline=False)

fig.update_layout(
    height=1000,
    width=1200,
    margin=dict(l=40, r=40, t=40, b=20),  # tight margins
)
fig.show()

In [None]:
# One grid cell — row 0 (image_samples[0]), column 3 (image_crops[3]) — for a closer look.
sample_noisy = samples_noisy[0][3]
sample_gt = samples_gt[0][3]
sample_denoised = denoised[0][3]