In [None]:
# Essential imports for frequency domain processing
import numpy as np
import matplotlib.pyplot as plt
import cv2
from skimage import data, filters
from scipy import fft, ndimage
import seaborn as sns
import warnings
# NOTE(review): this silences *every* warning for the rest of the kernel session,
# including NumPy RuntimeWarnings (divide-by-zero / invalid value) that would
# otherwise flag NaNs appearing in filters or spectra. Consider narrowing it to
# specific categories/modules once the noisy source is identified.
warnings.filterwarnings('ignore')

# Set style for better visualizations
# ('seaborn-v0_8' is the name matplotlib >= 3.6 uses for its bundled seaborn style)
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

print("🚀 Ready to explore frequency domain filtering!")
# Plain string literal: there is nothing to interpolate, so no f-prefix.
print("Using NumPy FFT and SciPy for frequency analysis")

# Set random seed for reproducibility (seeds NumPy's legacy global RNG)
np.random.seed(42)


In [None]:
# Frequency domain analysis functions
def compute_fft(image):
    """Return the centred 2-D discrete Fourier transform of ``image``.

    ``fft2`` transforms the last two axes; ``fftshift`` then moves the
    zero-frequency (DC) coefficient from index (0, 0) to the middle of the
    array (it shifts along every axis of the result), so the spectrum lines up
    with the centred filters built elsewhere in this notebook.
    """
    return fft.fftshift(fft.fft2(image))

def compute_magnitude_spectrum(f_shift):
    """Return ``log(|F| + 1)`` of a (complex) spectrum for display.

    The magnitude of FFT coefficients spans many orders of magnitude, so the
    natural log compresses it into a range that an image colormap can show.
    """
    # The +1 offset keeps zero-magnitude bins at 0 instead of -inf.
    return np.log(1 + np.abs(f_shift))

def compute_phase_spectrum(f_shift):
    """Return the phase angle of each spectral coefficient, in radians in [-pi, pi]."""
    phase = np.angle(f_shift, deg=False)
    return phase

def inverse_fft(f_shift):
    """Map a centred spectrum (as produced by ``compute_fft``) back to the spatial domain.

    Returns only the real part of the inverse transform; any imaginary
    component (round-off, or asymmetry introduced by filtering) is dropped.
    """
    # Undo the centring first so ifft2 sees the DC term at index (0, 0).
    unshifted = fft.ifftshift(f_shift)
    spatial = fft.ifft2(unshifted)
    return spatial.real

# Create frequency domain filters
def _centered_distance(shape):
    """Return each (row, col) index's Euclidean distance from the spectrum centre.

    The centre is ``(rows // 2, cols // 2)``, the position of the zero
    frequency after ``fftshift``. The result has shape ``(rows, cols)``.
    """
    rows, cols = shape
    crow, ccol = rows // 2, cols // 2
    # A column vector minus a row vector broadcasts to a (rows, cols) grid.
    u = np.arange(rows).reshape(-1, 1) - crow
    v = np.arange(cols).reshape(1, -1) - ccol
    return np.sqrt(u**2 + v**2)


def create_ideal_lowpass_filter(shape, cutoff_freq):
    """Create an ideal (brick-wall) low-pass filter for a centred spectrum.

    Parameters
    ----------
    shape : tuple of int
        ``(rows, cols)`` of the centred spectrum the filter will multiply.
    cutoff_freq : float
        Radius D0 (in frequency-index units) inside which frequencies pass.

    Returns
    -------
    ndarray of int, shape (rows, cols)
        1 where the distance from the centre is <= ``cutoff_freq``, else 0.
    """
    D = _centered_distance(shape)
    return np.where(D <= cutoff_freq, 1, 0)


def create_ideal_highpass_filter(shape, cutoff_freq):
    """Create an ideal high-pass filter: the complement ``1 - H_lp`` of the ideal low-pass."""
    return 1 - create_ideal_lowpass_filter(shape, cutoff_freq)


def create_gaussian_lowpass_filter(shape, cutoff_freq):
    """Create a Gaussian low-pass filter ``H = exp(-D^2 / (2 * D0^2))``.

    Parameters
    ----------
    shape : tuple of int
        ``(rows, cols)`` of the centred spectrum the filter will multiply.
    cutoff_freq : float
        D0, the Gaussian's standard deviation in frequency-index units; the
        gain is 1 at the centre and ``exp(-1/2)`` (about 0.607) at D = D0.

    Raises
    ------
    ValueError
        If ``cutoff_freq`` is not strictly positive (D0 = 0 would divide by
        zero and put a NaN at the DC bin).
    """
    # `not x > 0` also rejects NaN.
    if not cutoff_freq > 0:
        raise ValueError(f"cutoff_freq must be positive, got {cutoff_freq!r}")
    D = _centered_distance(shape)
    return np.exp(-(D**2) / (2 * cutoff_freq**2))


def create_gaussian_highpass_filter(shape, cutoff_freq):
    """Create a Gaussian high-pass filter ``1 - H_lp``; raises ValueError if ``cutoff_freq <= 0``."""
    return 1 - create_gaussian_lowpass_filter(shape, cutoff_freq)


def create_bandpass_filter(shape, low_cutoff, high_cutoff):
    """Create a Gaussian band-pass filter as (low-pass at ``high_cutoff``) * (high-pass at ``low_cutoff``).

    The DC term is always 0 and the peak gain is strictly below 1, because the
    high-pass factor never reaches 1.

    Raises
    ------
    ValueError
        If ``low_cutoff >= high_cutoff`` (an inverted or empty band), or if
        either cutoff is not strictly positive.
    """
    if low_cutoff >= high_cutoff:
        raise ValueError(
            f"low_cutoff ({low_cutoff!r}) must be smaller than high_cutoff ({high_cutoff!r})"
        )
    lowpass = create_gaussian_lowpass_filter(shape, high_cutoff)
    highpass = create_gaussian_highpass_filter(shape, low_cutoff)
    return lowpass * highpass

# Load the camera test image and convert it to float64
test_image = np.asarray(data.camera(), dtype=float)

# Forward transform, then derive the magnitude and phase views used for plotting
f_shift = compute_fft(test_image)
magnitude_spectrum, phase_spectrum = (
    compute_magnitude_spectrum(f_shift),
    compute_phase_spectrum(f_shift),
)

# Visualize FFT analysis: top row = image and its spectra, bottom row = filters
fig, axes = plt.subplots(2, 3, figsize=(18, 12))
fig.suptitle('Fourier Transform Analysis', fontsize=16, fontweight='bold')

# (data, colormap, title) for each top-row panel; 'hsv' is cyclic, matching phase wrap-around
spectrum_panels = [
    (test_image, 'gray', 'Original Image'),
    (magnitude_spectrum, 'hot', 'Magnitude Spectrum (Log Scale)'),
    (phase_spectrum, 'hsv', 'Phase Spectrum'),
]
for panel_ax, (panel_data, panel_cmap, panel_title) in zip(axes[0], spectrum_panels):
    panel_ax.imshow(panel_data, cmap=panel_cmap)
    panel_ax.set_title(panel_title)
    panel_ax.axis('off')

# Create and visualize filters
# NOTE(review): `cutoff_frequencies` and `filter_types` are not used in this cell;
# kept in case later cells reference them — TODO confirm, otherwise remove.
cutoff_frequencies = [30, 50, 100]
filter_types = ['Low-pass', 'High-pass', 'Band-pass']

# Single source of truth for the demo cutoffs, so the titles cannot drift from the filters.
LOWPASS_CUTOFF = 50
BAND_LOW_CUTOFF, BAND_HIGH_CUTOFF = 20, 80

ideal_lowpass = create_ideal_lowpass_filter(test_image.shape, LOWPASS_CUTOFF)
gaussian_lowpass = create_gaussian_lowpass_filter(test_image.shape, LOWPASS_CUTOFF)
bandpass = create_bandpass_filter(test_image.shape, BAND_LOW_CUTOFF, BAND_HIGH_CUTOFF)

filters_to_show = [ideal_lowpass, gaussian_lowpass, bandpass]
filter_names = [
    f'Ideal Low-pass (D₀={LOWPASS_CUTOFF})',
    f'Gaussian Low-pass (D₀={LOWPASS_CUTOFF})',
    f'Band-pass ({BAND_LOW_CUTOFF}-{BAND_HIGH_CUTOFF})',
]

for i, (filt, name) in enumerate(zip(filters_to_show, filter_names)):
    # Pin the colour scale to the [0, 1] gain range so the panels are comparable:
    # by default imshow stretches each filter to its own min/max, which would make
    # the band-pass filter (peak gain < 1) look as bright as a unity-gain filter.
    axes[1, i].imshow(filt, cmap='gray', vmin=0, vmax=1)
    axes[1, i].set_title(name)
    axes[1, i].axis('off')

plt.tight_layout()
plt.show()

# One-line-per-fact summary of the analysis above
summary_lines = (
    "🔍 Frequency Domain Analysis Complete!",
    f"Image shape: {test_image.shape}",
    f"FFT shape: {f_shift.shape}",
    f"Magnitude spectrum range: [{magnitude_spectrum.min():.2f}, {magnitude_spectrum.max():.2f}]",
    f"Phase spectrum range: [{phase_spectrum.min():.2f}, {phase_spectrum.max():.2f}]",
)
for summary_line in summary_lines:
    print(summary_line)
