In [1]:
import numpy as np
from utils import *
from numpy.fft import fft2, ifft2
import matplotlib.pyplot as plt
import numpy as np
import cv2
import hw3_helper_utils
from scipy.ndimage import convolve

In [6]:
def generate_noisy_image(image_path, noise_level=0.02, length=20, angle=30, seed=None):
    """Load a grayscale image, motion-blur it, and add white Gaussian noise.

    Parameters
    ----------
    image_path : str
        Path to an image readable by ``cv2.imread``; it is loaded as grayscale.
    noise_level : float
        Standard deviation of the additive zero-mean Gaussian noise.
    length, angle :
        Forwarded unchanged to ``hw3_helper_utils.create_motion_blur_filter``.
    seed : int or None
        If given, the noise is drawn from ``np.random.default_rng(seed)`` so the
        result is reproducible. If None (default), the global NumPy RNG
        (``np.random.randn``) is used.

    Returns
    -------
    x : ndarray of float64
        The image min-max normalized to [0, 1] (all zeros for a constant image).
    y0 : ndarray
        ``x`` blurred by ``h`` with ``mode="wrap"`` (periodic boundaries).
    y : ndarray
        ``y0 + v``, the blurred and noisy observation.
    h : ndarray
        The blur kernel.
    v : ndarray
        The noise realization that was added.

    Raises
    ------
    FileNotFoundError
        If the image could not be read (``cv2.imread`` returned None).
    """
    raw = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if raw is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {image_path!r}")

    # Work in float64 so the normalization never does integer arithmetic.
    x = raw.astype(np.float64)
    x_min, x_max = x.min(), x.max()
    span = x_max - x_min
    if span > 0:
        x = (x - x_min) / span
    else:
        # Constant image: avoid 0/0 (NaN everywhere) and return zeros instead.
        x = np.zeros_like(x)

    if seed is None:
        v = noise_level * np.random.randn(*x.shape)
    else:
        v = noise_level * np.random.default_rng(seed).standard_normal(x.shape)

    h = hw3_helper_utils.create_motion_blur_filter(length=length, angle=angle)

    # mode="wrap" treats the image as periodic, i.e. a circular convolution.
    y0 = convolve(x, h, mode="wrap")

    y = y0 + v

    return x, y0, y, h, v


def my_wiener_filter(y: np.ndarray, h: np.ndarray, K: float) -> np.ndarray:
    """Restore ``y`` with a frequency-domain Wiener filter for blur kernel ``h``.

    The filter is ``G = conj(H) / (|H|^2 + 1/K)``, so a larger ``K`` means less
    regularization (closer to a pure inverse filter). The blur is modelled as a
    circular convolution with the kernel centered at ``(L // 2, P // 2)``, the
    same centering ``scipy.ndimage.convolve(..., mode="wrap")`` uses.

    Parameters
    ----------
    y : np.ndarray
        2-D degraded image of shape (M, N).
    h : np.ndarray
        2-D blur kernel of shape (L, P), with L <= M and P <= N.
    K : float
        Regularization parameter; ``1 / K`` is added to ``|H|^2``.

    Returns
    -------
    np.ndarray
        Restored image of shape (M, N), clipped to [0, 1].

    Raises
    ------
    ValueError
        If the kernel is larger than the image in either dimension.
    """
    M, N = y.shape
    L, P = h.shape
    if L > M or P > N:
        raise ValueError(f"kernel shape {h.shape} exceeds image shape {y.shape}")

    # Zero-pad the kernel to the image size and move its center tap to (0, 0),
    # so the FFT models a centered convolution instead of a shifted one.
    # The parentheses matter: `-L // 2` is floor(-L / 2), which is one too far
    # for odd L (including size-1 axes) and would translate the result a pixel.
    h_padded = np.zeros((M, N))
    h_padded[:L, :P] = h
    h_padded = np.roll(h_padded, (-(L // 2), -(P // 2)), axis=(0, 1))

    Y = fft2(y)
    H = fft2(h_padded)

    G = np.conj(H) / (np.abs(H) ** 2 + 1 / K)
    x_hat = np.real(ifft2(Y * G))

    # Assumes the true image intensities lie in [0, 1].
    return np.clip(x_hat, 0, 1)


def calculate_mse(x, x_hat):
    """Return the mean squared error between reference ``x`` and estimate ``x_hat``.

    The difference is taken element-wise (NumPy broadcasting applies) and the
    squared residuals are averaged over every element.
    """
    residual = x - x_hat
    return np.mean(np.square(residual))


def optimize_k(x, y, h, k_range):
    """Grid-search the Wiener parameter K that minimizes MSE against ``x``.

    Parameters
    ----------
    x : ndarray
        Ground-truth image used to score each restoration.
    y : ndarray
        Degraded image passed to ``my_wiener_filter``.
    h : ndarray
        Blur kernel passed to ``my_wiener_filter``.
    k_range : iterable of float
        Candidate K values; it is consumed exactly once.

    Returns
    -------
    tuple
        ``(best_k, best_mse, mse_values)``:
        ``best_k`` is the first candidate attaining the smallest MSE,
        ``best_mse`` is that MSE, and ``mse_values`` lists the MSE of every
        candidate in iteration order.

    Raises
    ------
    ValueError
        If ``k_range`` is empty, or if no candidate yields an MSE smaller than
        infinity (e.g. all NaN), so there is no meaningful best K.
    """
    k_values = list(k_range)
    if not k_values:
        raise ValueError("k_range must contain at least one K value")

    best_k = None
    best_mse = float("inf")
    mse_values = []

    for k in k_values:
        x_hat = my_wiener_filter(y, h, k)
        mse = calculate_mse(x, x_hat)
        mse_values.append(mse)

        # Strict `<` keeps the first K among ties; NaN never compares smaller.
        if mse < best_mse:
            best_mse = mse
            best_k = k

    if best_k is None:
        raise ValueError("no K value in k_range produced a finite MSE")

    return best_k, best_mse, mse_values

In [None]:
# Experiment 1: checkerboard, mild noise, 20-px motion blur at 30 degrees.
image_path = "data/checkerboard.tif"

# Seed the global NumPy RNG (used by generate_noisy_image for the noise) so the
# noise realization, and therefore the optimal K, is reproducible on re-run.
np.random.seed(0)

x, y0, y, h, v = generate_noisy_image(
    image_path=image_path, noise_level=0.02, length=20, angle=30
)

# 100 candidate K values, logarithmically spaced from 10^-3 to 10^3
k_range = np.logspace(-3, 3, 100)

# Find the optimal K by exhaustive search over k_range
best_k, best_mse, mse_values = optimize_k(x, y, h, k_range)

print(f"Best K: {best_k}")
print(f"Best MSE: {best_mse}")

# MSE over the whole K grid
fig, ax = plt.subplots(figsize=(10, 6))
ax.semilogx(k_range, mse_values)
ax.set_xlabel("K value")
ax.set_ylabel("Mean Squared Error")
ax.set_title("MSE vs K for Wiener Filter")
ax.grid(True)
plt.show()

# Same curve, zoomed in, with the optimal K highlighted
fig, ax = plt.subplots(figsize=(12, 8))
ax.semilogx(k_range, mse_values, "b-", label="MSE curve")
ax.semilogx(best_k, best_mse, "ro", markersize=10, label="Optimal K")
ax.annotate(
    f"Optimal: (K={best_k:.4f}, MSE={best_mse:.4e})",
    xy=(best_k, best_mse),
    xytext=(best_k * 1.5, best_mse * 1.1),
    arrowprops=dict(facecolor="black", shrink=0.05),
    bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="black", lw=1),
)
ax.set_xlabel("K value")
ax.set_ylabel("Mean Squared Error")
ax.set_title("MSE vs K for Wiener Filter (with Optimal K)")
ax.grid(True)
ax.legend()
# Focus on the region around the minimum: cap at the max MSE or twice the best MSE
ax.set_ylim(0, min(max(mse_values), best_mse * 2))
plt.show()

# Restore the image with the best K and compare side by side
x_hat_optimal = my_wiener_filter(y, h, best_k)

fig, axs = plt.subplots(1, 3, figsize=(15, 5))
panels = [
    (x, "Original Image"),
    (y, "Blurred and Noisy Image"),
    (x_hat_optimal, f"Deblurred Image (K={best_k:.4f})"),
]
for ax, (img, title) in zip(axs, panels):
    ax.imshow(img, cmap="gray")
    ax.set_title(title)
    ax.axis("off")
plt.show()

In [None]:
# Experiment 2: cameraman, heavy noise, 10-px horizontal motion blur.
image_path = "data/cameraman.tif"

# Seed the global NumPy RNG (used by generate_noisy_image for the noise) so the
# noise realization, and therefore the optimal K, is reproducible on re-run.
np.random.seed(0)

x, y0, y, h, v = generate_noisy_image(
    image_path=image_path, noise_level=0.2, length=10, angle=0
)

# 100 candidate K values, logarithmically spaced from 10^-3 to 10^3
k_range = np.logspace(-3, 3, 100)

# Find the optimal K by exhaustive search over k_range
best_k, best_mse, mse_values = optimize_k(x, y, h, k_range)

print(f"Best K: {best_k}")
print(f"Best MSE: {best_mse}")

# Restore the image with the best K
x_hat_optimal = my_wiener_filter(y, h, best_k)

# One figure: MSE curves on the top row, images on the bottom row
fig = plt.figure(figsize=(20, 12))
fig.suptitle("Wiener Filter Optimization and Results", fontsize=16)

# Top-left: MSE over the whole K grid
ax_curve = fig.add_subplot(231)
ax_curve.semilogx(k_range, mse_values)
ax_curve.set_xlabel("K value")
ax_curve.set_ylabel("Mean Squared Error")
ax_curve.set_title("MSE vs K")
ax_curve.grid(True)

# Top-middle: zoomed MSE curve with the optimal K highlighted
ax_best = fig.add_subplot(232)
ax_best.semilogx(k_range, mse_values, "b-", label="MSE curve")
ax_best.semilogx(best_k, best_mse, "ro", markersize=10, label="Optimal K")
ax_best.annotate(
    f"Optimal: (K={best_k:.4f}, MSE={best_mse:.4e})",
    xy=(best_k, best_mse),
    xytext=(best_k * 1.5, best_mse * 1.1),
    arrowprops=dict(facecolor="black", shrink=0.05),
    bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="black", lw=1),
)
ax_best.set_xlabel("K value")
ax_best.set_ylabel("Mean Squared Error")
ax_best.set_title("MSE vs K (with Optimal K)")
ax_best.grid(True)
ax_best.legend()
# Cap at the max MSE or twice the best MSE to focus on the minimum
ax_best.set_ylim(0, min(max(mse_values), best_mse * 2))

# Bottom row: original, degraded, and restored images
image_panels = [
    (234, x, "Original Image"),
    (235, y, "Blurred and Noisy Image"),
    (236, x_hat_optimal, f"Deblurred Image (K={best_k:.4f})"),
]
for position, img, title in image_panels:
    ax_img = fig.add_subplot(position)
    ax_img.imshow(img, cmap="gray")
    ax_img.set_title(title)
    ax_img.axis("off")

plt.tight_layout()
plt.show()