In [None]:
from typing import Callable
import numpy as np
import plotly.graph_objs as go
from plotly.subplots import make_subplots
from skimage import data
from scipy import fftpack, spatial

In [None]:
# Load the scikit-image "gravel" sample image as a NumPy array.
# The rest of the notebook treats it as a single 2-D (grayscale) array.
image = data.gravel()

In [None]:
def plot_images_side_by_side(images, width=1400, height=800):
    """
    Plot a list of 2-D arrays side by side as Plotly heatmaps, one per column.

    Each heatmap gets its own horizontal colorbar centred on, and as wide as,
    its own subplot column. Does nothing if ``images`` is empty.

    Args:
        images: Sequence of 2-D arrays (anything ``go.Heatmap(z=...)`` accepts).
        width: Figure width in pixels.
        height: Figure height in pixels.
    """
    num_images = len(images)
    if num_images == 0:
        # make_subplots rejects cols=0, and there is nothing to draw anyway.
        return

    # One subplot column per image.
    fig = make_subplots(rows=1, cols=num_images)

    for col, img in enumerate(images, start=1):
        # Use the column's real horizontal domain (it already accounts for the
        # spacing between subplots) so the colorbar lines up with its heatmap.
        axis_name = "xaxis" if col == 1 else f"xaxis{col}"
        x_start, x_end = fig.layout[axis_name].domain
        fig.add_trace(
            go.Heatmap(
                z=img,
                showscale=True,
                colorscale='Viridis',
                colorbar=dict(
                    orientation='h',
                    len=x_end - x_start,
                    thickness=20,
                    x=(x_start + x_end) / 2,
                    xanchor='center',
                ),
            ),
            row=1, col=col)

        # Heatmap cells are centred on integer indices, so the pixel edges lie at
        # -0.5 and n - 0.5; a (0, n) range would cut off half of the first
        # row/column and leave an empty half-cell at the far end.
        n_rows, n_cols = np.shape(img)[:2]
        fig.update_yaxes(range=(-0.5, n_rows - 0.5), row=1, col=col)
        fig.update_xaxes(range=(-0.5, n_cols - 0.5), row=1, col=col)

    # Each trace carries its own colorbar, so no shared layout.coloraxis is needed.
    fig.update_layout(width=width, height=height, margin=dict(l=0, r=0, b=0, t=0))

    fig.show()

In [None]:
# Triweight kernel: a smooth, compactly supported alternative to the RBF kernel.
def triweight_kernel_2D(x: np.ndarray, y: np.ndarray, sigma: float) -> np.ndarray:
    """Unnormalised 2-D triweight kernel ``(1 - r^2 / sigma^2)^3``.

    Args:
        x, y: Offsets from the kernel centre (broadcastable arrays).
        sigma: Support radius; points with ``x^2 + y^2 > sigma^2`` get weight 0.

    Returns:
        Kernel weights with the broadcast shape of ``x`` and ``y``.
    """
    radius_sq = x ** 2 + y ** 2
    sigma_sq = sigma ** 2
    support = radius_sq <= sigma_sq
    profile = (1 - radius_sq / sigma_sq) ** 3
    return profile * support

def rbf_kernel_2D(x: np.ndarray, y: np.ndarray, sigma: float) -> np.ndarray:
    """Unnormalised 2-D Gaussian (RBF) kernel ``exp(-(x^2 + y^2) / (2 sigma^2))``.

    Args:
        x, y: Offsets from the kernel centre (broadcastable arrays).
        sigma: Standard deviation of the Gaussian.

    Returns:
        Kernel weights in ``(0, 1]`` with the broadcast shape of ``x`` and ``y``.
    """
    return np.exp(-0.5 * ((x ** 2 + y ** 2) / sigma ** 2))

def make_kernel(size: int) -> Callable[[float], Callable[[Callable[[np.ndarray, np.ndarray, float], np.ndarray]], np.ndarray]]:
    """Curried factory for a normalised, centred 2-D kernel on an integer grid.

    Usage: ``make_kernel(size)(sigma)(kernel_func)`` evaluates
    ``kernel_func(x, y, sigma)`` on integer offsets ``-(size // 2) .. size // 2``
    along both axes and divides by the sum so the weights add up to 1.

    The grid is symmetric about 0, so it is ``size x size`` for odd ``size`` and
    ``(size + 1) x (size + 1)`` for even ``size``; an even-width grid cannot be
    centred on a pixel.

    Args:
        size: Nominal kernel width in pixels.

    Returns:
        A function of ``sigma`` returning a function of ``kernel_func`` that
        returns the normalised kernel array.

    Raises:
        ValueError: (from the innermost call) if the kernel sums to zero.
    """
    # Note: `-size // 2` would parse as `(-size) // 2`, which is off-centre for
    # odd sizes; compute the half-width once and mirror it instead.
    half = size // 2
    offsets = np.arange(-half, half + 1)
    x, y = np.meshgrid(offsets, offsets)

    def set_sigma(sigma) -> Callable[[Callable[[np.ndarray, np.ndarray, float], np.ndarray]], np.ndarray]:

        def eval_kernel(kernel_func: Callable[[np.ndarray, np.ndarray, float], np.ndarray]):
            # Evaluate the kernel function on the grid of offsets.
            kernel = kernel_func(x, y, sigma)

            # Normalise so the weights sum to 1; a zero sum cannot be normalised.
            total = np.sum(kernel)
            if total == 0:
                raise ValueError("kernel sums to zero and cannot be normalised")
            return kernel / total

        return eval_kernel

    return set_sigma

def get_kernel_density_function(num_centers, sigma, kernel_function=rbf_kernel_2D, rng=None):
    """Build a function that renders a random kernel-density map for an image.

    The returned ``kernel_density(image)`` places ``num_centers`` kernel centres
    uniformly at random on the image's pixel grid. It evaluates ``kernel_function``
    around each centre, truncated to a disc of radius ``sigma``, averages the
    contributions and clips the result to ``[0, 1]``. Only the image's height and
    width are used; its pixel values are ignored.

    Args:
        num_centers: Number of random kernel centres to place.
        sigma: Width passed to ``kernel_function``; also the radius beyond
            which each kernel's contribution is zeroed.
        kernel_function: Callable ``(dx, dy, sigma) -> ndarray`` evaluated on
            the per-pixel offsets from a centre.
        rng: Optional seed or ``np.random.Generator`` for reproducible centres.
            When ``None`` (default), NumPy's legacy global random state is used,
            so ``np.random.seed`` controls the result.

    Returns:
        Callable mapping an image (at least 2-D) to a float density array of
        shape ``image.shape[:2]``.
    """
    # Build the generator once so successive calls draw fresh but reproducible centres.
    generator = None if rng is None else np.random.default_rng(rng)

    def kernel_density(image):
        height, width = np.shape(image)[:2]

        # Pixel coordinates, each of shape (height, width): x = column, y = row.
        x, y = np.meshgrid(np.arange(width), np.arange(height))

        # Random centres as (x, y) pairs inside the image bounds.
        if generator is None:
            centers = np.random.randint(0, (width, height), size=(num_centers, 2))
        else:
            centers = generator.integers(0, (width, height), size=(num_centers, 2))

        density = np.zeros((height, width), dtype=float)
        # Accumulate one centre at a time rather than materialising a
        # (num_centers, height, width) distance array.
        for cx, cy in centers:
            dx = x - cx
            dy = y - cy
            # Truncate every kernel to a disc of radius sigma (this matters for
            # kernels with unbounded support such as the RBF).
            within_radius = np.sqrt(dx ** 2 + dy ** 2) < sigma
            density += kernel_function(dx, dy, sigma) * within_radius
        if num_centers > 0:
            density /= num_centers

        # Clip the density values to [0, 1].
        return np.clip(density, 0, 1)

    return kernel_density

In [None]:
def show_frequencies(image):
    """Return the log-scaled magnitude spectrum of ``image``, ready for display.

    Despite the name, nothing is plotted here. The function computes the 2-D FFT
    and moves the zero-frequency term to the centre with ``fftshift``. It returns
    ``log10(1 + |F|)`` so that both weak and strong frequencies remain visible.
    """
    centred_spectrum = fftpack.fftshift(fftpack.fft2(image))
    return np.log10(1 + np.abs(centred_spectrum))

In [None]:
def convolve(image: np.ndarray, kernel_density: np.ndarray) -> np.ndarray:
    """ Circularly convolves an image with a kernel using the FFT.
    Args:
        image (np.ndarray): 2-D input image.
        kernel_density (np.ndarray): 2-D convolution kernel. It may be smaller
            than ``image``; it is then zero-padded (bottom/right) to the image's
            shape, so its element ``[0, 0]`` acts as the kernel origin.
    Returns:
        np.ndarray: Real part of the circular convolution, same shape as ``image``.
    Raises:
        ValueError: If ``kernel_density`` is not 2-D in its last axes or is larger
            than ``image`` along either of the last two axes.
    """
    target_shape = np.shape(image)[-2:]
    kernel_shape = np.shape(kernel_density)[-2:]
    if len(kernel_shape) != 2 or any(k > t for k, t in zip(kernel_shape, target_shape)):
        raise ValueError(
            f"kernel of shape {np.shape(kernel_density)} does not fit in image of shape {np.shape(image)}")

    # Compute the FFT of the image and of the kernel, zero-padded to the image size.
    image_fft = fftpack.fft2(image)
    density_fft = fftpack.fft2(kernel_density, shape=target_shape)

    # Pointwise product in frequency space == circular convolution in space.
    filtered_fft = image_fft * density_fft

    # Inverse FFT; drop the numerically-negligible imaginary part for real inputs.
    filtered_image = fftpack.ifft2(filtered_fft).real

    return filtered_image

def filter_frequencies(image: np.ndarray, kernel_density: np.ndarray, centered: bool = False) -> np.ndarray:
    """ Weights the 2-D spectrum of an image by a mask and transforms back.
    Args:
        image (np.ndarray): 2-D input image.
        kernel_density (np.ndarray): Frequency-domain weights, broadcastable to
            the image's FFT. By default they are laid out like the raw ``fft2``
            output, i.e. with the zero frequency at ``[0, 0]``.
        centered (bool): If True, ``kernel_density`` is instead laid out like an
            ``fftshift``-ed spectrum (zero frequency in the middle, as returned
            by ``show_frequencies``). It is un-shifted with ``ifftshift`` before
            use. Defaults to False.
    Returns:
        np.ndarray: Real part of the inverse FFT of the weighted spectrum.
    """
    weights = fftpack.ifftshift(kernel_density) if centered else kernel_density

    # Weight the frequencies of the image.
    filtered_fft = fftpack.fft2(image) * weights

    # Weights that are not Hermitian-symmetric yield a complex result; only the
    # real part is kept.
    filtered_image = fftpack.ifft2(filtered_fft).real

    return filtered_image

In [None]:
# Compute a random kernel density and filter the image with it, one pass per
# parameter set. "mu" is the number of random kernel centres and "sigma" the
# kernel radius in pixels.
kernel_params = [
    dict(mu=89, sigma=21*4),
    dict(mu=34, sigma=21*4),
    dict(mu=21, sigma=34*3),
    dict(mu=21, sigma=55*2),
    dict(mu=21, sigma=89),
]

# Kernel centres are drawn from NumPy's global RNG; seed it so every run
# produces the same densities and figures.
np.random.seed(0)

# Each pass filters the output of the previous one. Keep that result in its own
# variable so `image` remains the unfiltered input and re-running this cell
# starts from scratch instead of compounding the filtering.
filtered_image = image
for parameters in kernel_params:
    density_function = get_kernel_density_function(
        num_centers=parameters["mu"],
        sigma=parameters["sigma"],
        kernel_function=triweight_kernel_2D,
    )
    kernel_density = density_function(filtered_image)
    convol_image = convolve(image=filtered_image, kernel_density=kernel_density)
    filtered_image = filter_frequencies(image=filtered_image, kernel_density=kernel_density)
    frequencies = show_frequencies(filtered_image)
    plot_images_side_by_side(images=[filtered_image, frequencies, convol_image, kernel_density])