# In [None]:
import numpy as np
import json
import matplotlib.pyplot as plt
from numpy.fft import fft2, ifft2, fftshift

# Load the blurred observation and the point-spread function (PSF) from JSON
# (replace the paths with your actual data locations). Each file presumably
# holds a 2-D nested list of numbers -- wiener_deconv_plot unpacks a 2-D
# shape from both arrays. JSON text is UTF-8 by specification, so decode it
# explicitly instead of relying on the platform's locale encoding.
with open('blurred.json', 'r', encoding='utf-8') as f:
    blurred = np.array(json.load(f), dtype=np.float32)

with open('psf.json', 'r', encoding='utf-8') as f:
    psf = np.array(json.load(f), dtype=np.float32)

# Wiener Deconvolution with intermediate steps plotting
def _show_image(title, image):
    """Display *image* in its own 6x6-inch grayscale figure with a colorbar."""
    plt.figure(figsize=(6, 6))
    plt.title(title)
    plt.imshow(image, cmap='gray')
    plt.colorbar()
    plt.show()


def wiener_deconv_plot(blurred, psf, K=0.001, *, plot=True):
    """Restore *blurred* by Wiener deconvolution with *psf*.

    The restoration is computed in the frequency domain as
    ``F_hat = conj(H) * G / (|H|**2 + K)``, where ``G`` and ``H`` are the
    2-D FFTs of the zero-padded image and PSF.

    Both inputs are zero-padded to twice the image size in each dimension
    so that the circular convolution implied by the DFT does not wrap the
    image onto itself; the result is cropped back to the original size.

    Parameters
    ----------
    blurred : 2-D numpy.ndarray
        Observed (blurred) image of shape ``(M, N)``.
    psf : 2-D numpy.ndarray
        Point-spread function of shape ``(s, t)``; it must fit inside
        ``(2*M, 2*N)``. Its centre is taken to be element ``(s//2, t//2)``.
        The PSF is not normalised here -- presumably it already sums to 1.
    K : float, optional
        Regularisation constant added to ``|H|**2`` in the denominator.
        It bounds the gain where ``|H|`` is small; ``K = 0`` reduces to the
        plain inverse filter ``1 / H``.
    plot : bool, keyword-only, optional
        When True (the default), show a figure for every intermediate
        step. Pass False to compute the restoration without plotting.

    Returns
    -------
    numpy.ndarray
        Real-valued restored image of shape ``(M, N)``.
    """
    M, N = blurred.shape
    P, Q = 2 * M, 2 * N

    # Step 1: zero-pad the blurred image to (P, Q).
    fp = np.zeros((P, Q), dtype=np.float32)
    fp[:M, :N] = blurred
    if plot:
        _show_image("Zero-padded Blurred Image", fp)

    # Step 2: zero-pad the PSF into the top-left corner of a (P, Q) array.
    s, t = psf.shape
    pad_psf = np.zeros((P, Q), dtype=np.float32)
    pad_psf[:s, :t] = psf
    if plot:
        _show_image("Zero-padded PSF", pad_psf)

    # Step 3: circularly shift the PSF so its centre element (s//2, t//2)
    # lands on (0, 0); any other placement translates the restored image.
    # The parentheses matter: -s//2 parses as (-s)//2, which for odd s is
    # one larger in magnitude than s//2 and shifts the result by a pixel.
    pad_psf = np.roll(pad_psf, (-(s // 2), -(t // 2)), axis=(0, 1))
    if plot:
        _show_image("Centered PSF after roll", pad_psf)

    # Step 4: transform both padded arrays to the frequency domain.
    F = fft2(fp)
    H = fft2(pad_psf)
    if plot:
        # log1p compresses the large dynamic range of spectral magnitudes.
        _show_image("Magnitude of FFT of Blurred Image", np.log1p(np.abs(F)))
        _show_image("Magnitude of FFT of PSF", np.log1p(np.abs(H)))

    # Step 5: apply the Wiener filter conj(H) / (|H|^2 + K).
    F_prime = F * np.conj(H) / (np.abs(H) ** 2 + K)
    if plot:
        _show_image(
            "Magnitude of Wiener Filtered FFT", np.log1p(np.abs(F_prime))
        )

    # Step 6: back to the spatial domain. Both padded inputs are real, so
    # the spectrum is Hermitian and any imaginary part is round-off only.
    output = np.real(ifft2(F_prime))
    if plot:
        _show_image("Restored Image (before cropping)", output)

    # Step 7: crop back to the original image size.
    output = output[:M, :N]
    if plot:
        _show_image("Restored Image (final)", output)

    return output

# Execute the deconvolution on the loaded data; each intermediate stage
# is displayed as it is computed. K is the regularisation constant added
# to |H|^2 in the Wiener denominator.
restored = wiener_deconv_plot(blurred, psf, K=1e-3)
