In [3]:
import numpy as np
import matplotlib.pyplot as plt
from numpy.fft import fft2, fftshift, ifft2, ifftshift
from PIL import Image
import ipywidgets as widgets
from IPython.display import display, clear_output
import io

%matplotlib inline

# Global variables to store the current image and its data.
# All four are assigned by handle_image_upload() and read by the demo callbacks.
current_img = None  # PIL image (converted to RGB) of the last upload; None until one arrives
original_img_array = None  # float NumPy copy of current_img's pixel data
h, w = 0, 0  # first two dimensions of original_img_array (rows, columns)

# Function to handle image upload
def handle_image_upload(change):
    """Load the image delivered by the FileUpload widget and rebuild the UI.

    Args:
        change: Trait-change dictionary from the upload widget. ``change['new']``
            holds the uploaded files: a sequence of file dicts (ipywidgets 8.x)
            or a dict keyed by file name (ipywidgets 7.x). Each file dict must
            expose the raw bytes under ``'content'``.

    Side effects:
        Updates the module globals ``current_img``, ``original_img_array``,
        ``h`` and ``w``, clears the cell output and redraws the tabbed
        interface. Does nothing when the change carries no files.
    """
    global current_img, original_img_array, h, w

    uploads = change['new']
    # The observer also fires when the widget value is emptied; indexing an
    # empty value would raise, so there is simply nothing to load.
    if not uploads:
        return

    # ipywidgets 7.x hands over {filename: file_dict}; 8.x a tuple of file dicts.
    if isinstance(uploads, dict):
        uploaded_file = next(iter(uploads.values()))
    else:
        uploaded_file = uploads[0]
    file_content = uploaded_file['content']

    # Decode the bytes into an RGB image and keep a float copy for the FFT demos
    img_data = io.BytesIO(file_content)
    current_img = Image.open(img_data).convert('RGB')
    original_img_array = np.asarray(current_img).astype(float)
    h, w = original_img_array.shape[:2]

    # Clear any previous outputs
    clear_output(wait=True)

    # Recreate the interface with the new image
    create_tabbed_interface()

# Common utility functions
def process_channel(channel, mask=None):
    """Transform one image channel to the centred frequency domain.

    Args:
        channel: 2-D array holding a single colour channel.
        mask: Optional array multiplied element-wise into the shifted
            spectrum; ``None`` leaves the spectrum untouched.

    Returns:
        Tuple ``(spectrum, magnitude, phase)`` where ``spectrum`` is the
        (optionally masked) zero-frequency-centred FFT and the other two are
        its absolute value and angle.
    """
    spectrum = fftshift(fft2(channel))

    # Filtering happens by scaling each frequency bin with the mask value
    if mask is not None:
        spectrum *= mask

    return spectrum, np.abs(spectrum), np.angle(spectrum)

def reconstruct_image(fft_r, fft_g, fft_b):
    """Turn three centred channel spectra back into an 8-bit RGB image.

    Each spectrum is un-shifted, inverse transformed and reduced to its real
    part; the channels are stacked along the last axis, clipped to 0..255 and
    cast to ``uint8``.
    """
    planes = [
        np.real(ifft2(ifftshift(spectrum)))
        for spectrum in (fft_r, fft_g, fft_b)
    ]
    rgb = np.stack(planes, axis=-1)
    return np.clip(rgb, 0, 255).astype(np.uint8)

def display_results(original, spectrum, reconstructed, title=""):
    """Show the original, its log-scaled spectrum and the reconstruction side by side.

    Args:
        original: Image drawn in the left panel.
        spectrum: Magnitude data; plotted as ``log(spectrum + 1)`` with the
            viridis colour map in the middle panel.
        reconstructed: Image drawn in the right panel.
        title: Extra text appended below the middle panel's heading.
    """
    plt.figure(figsize=(18, 8))

    # (image data, panel heading, extra imshow keyword arguments)
    panels = (
        (original, "Original Image", {}),
        (np.log(spectrum + 1), f"Magnitude Spectrum\n{title}", {'cmap': 'viridis'}),
        (reconstructed, "Reconstructed Image", {}),
    )

    for position, (picture, heading, imshow_options) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.imshow(picture, **imshow_options)
        plt.title(heading)
        plt.axis('off')

    plt.tight_layout()
    plt.show()

# =============================================================================
# DEMO 1: Compression with Top % Coefficients
# =============================================================================

def compress_channel(channel, percentage):
    """Compress a channel by keeping only the strongest FFT coefficients.

    Args:
        channel: 2-D array holding a single colour channel.
        percentage: Share (0-100) of coefficients to keep, ranked by
            magnitude. Values outside that range are clamped, and at least
            one coefficient (the strongest) is always kept.

    Returns:
        Tuple ``(reconstructed, fft_shifted, compressed_fft, kept)``:
        the real-valued reconstruction, the full centred spectrum, the
        thresholded centred spectrum, and the number of coefficients kept.
        Coefficients tied with the threshold magnitude are all kept, so
        ``kept`` may slightly exceed the requested share.
    """
    rows, cols = channel.shape
    fft_shifted = fftshift(fft2(channel))

    # Rank coefficients by magnitude on a flattened view of the spectrum
    flat_fft = fft_shifted.ravel()
    magnitudes = np.abs(flat_fft)

    # Position of the threshold in the ascending magnitude order. Clamp it to
    # a valid index: percentage <= 0 would otherwise index one past the end
    # (np.partition raises), and percentage >= 100 must keep everything.
    threshold_index = int((1 - percentage / 100.0) * magnitudes.size)
    threshold_index = min(max(threshold_index, 0), magnitudes.size - 1)
    threshold = np.partition(magnitudes, threshold_index)[threshold_index]

    # Zero out everything weaker than the threshold
    mask = magnitudes >= threshold
    compressed_fft = np.where(mask, flat_fft, 0).reshape(rows, cols)

    # Reconstruct
    reconstructed = np.real(ifft2(ifftshift(compressed_fft)))

    return reconstructed, fft_shifted, compressed_fft, np.count_nonzero(mask)

def update_compression_demo(percent):
    """Redraw the compression demo for the given share of kept coefficients.

    Compresses each RGB channel of the uploaded image with
    ``compress_channel``, rebuilds the image from the surviving coefficients
    and shows it next to the original together with the surviving spectrum
    and some size statistics.

    Args:
        percent: Percentage of FFT coefficients to keep per channel.
    """
    if current_img is None:
        print("Please upload an image first.")
        return

    # Compress each channel
    reconstructed_channels = []
    combined_magnitude = 0.0
    used_coeffs = 0
    for channel_index in range(3):
        rec, _, comp_fft, used = compress_channel(
            original_img_array[:, :, channel_index], percent
        )
        reconstructed_channels.append(rec)
        # Keep the RAW magnitude here: display_results() applies the log
        # scaling itself, so taking the log at this point as well would
        # log-scale the spectrum twice (unlike the other demos).
        combined_magnitude = combined_magnitude + np.abs(comp_fft)
        used_coeffs += used

    # Create visualization data
    reconstructed_image = np.stack(reconstructed_channels, axis=-1)
    reconstructed_image = np.clip(reconstructed_image, 0, 255).astype(np.uint8)

    # Calculate compression statistics
    total_coeffs = 3 * h * w
    compression_ratio = total_coeffs / max(used_coeffs, 1)
    size_kb = (used_coeffs * 8) / 1024  # Assuming complex numbers (8 bytes per coefficient)

    # Display results
    title = f"Compressed FFT (Top {percent:.2f}%)\nCompression Ratio: {compression_ratio:.2f}x\nSize: {size_kb:.2f} KB"
    display_results(original_img_array.astype(np.uint8), combined_magnitude, reconstructed_image, title)

# =============================================================================
# DEMO 2: Noise Reduction with Radial Filter
# =============================================================================

def create_radial_mask(h, w, radius, x_center, y_center, invert=False, is_hamming=False):
    """Create a circular frequency mask centred at (x_center, y_center).

    Args:
        h: Mask height in pixels.
        w: Mask width in pixels.
        radius: Circle radius in pixels; points with distance <= radius are
            inside. A radius of 0 selects only a pixel lying exactly on the
            centre.
        x_center: Column coordinate of the circle centre.
        y_center: Row coordinate of the circle centre.
        invert: If True, return the complement (high-pass instead of
            low-pass).
        is_hamming: If True, use a Hamming-shaped roll-off inside the circle
            instead of a hard edge: 1.0 at the centre falling to 0.08 at the
            radius, and 0 outside.

    Returns:
        Float array of shape (h, w) with values in [0, 1].
    """
    y, x = np.ogrid[:h, :w]
    dist = np.sqrt((x - x_center)**2 + (y - y_center)**2)
    inside = dist <= radius

    if is_hamming:
        # Normalize distances to 0-1 within the radius. Guard the division:
        # radius == 0 would otherwise produce 0/0 = NaN at the centre pixel.
        if radius > 0:
            normalized_dist = np.clip(dist / radius, 0, 1)
        else:
            normalized_dist = np.zeros_like(dist)

        # Always build the low-pass window first; everything outside the
        # radius is forced to 0.
        window = 0.54 + 0.46 * np.cos(np.pi * normalized_dist)
        mask = np.where(inside, window, 0.0)

        if invert:
            # For high-pass: subtract low-pass from 1
            mask = 1.0 - mask
    else:
        # Standard binary mask (low-pass by default, complement for high-pass)
        mask = np.logical_not(inside) if invert else inside

    return mask.astype(float)

def add_noise(image, noise_level):
    """Return a ``uint8`` copy of *image* with zero-mean Gaussian noise added.

    ``noise_level`` is the standard deviation of the noise; the noisy values
    are clipped to 0..255 before the cast.
    """
    perturbation = np.random.normal(0, noise_level, image.shape)
    clipped = np.clip(image + perturbation, 0, 255)
    return clipped.astype(np.uint8)

def update_noise_reduction_demo(noise_level, radius, x_center, y_center, invert_mask, is_hamming):
    """Redraw the noise + radial filter demo for the current widget values.

    Adds Gaussian noise to the uploaded image, filters every colour channel
    in the frequency domain with a radial mask and shows the noisy input, the
    filtered spectrum and the reconstruction.
    """
    if current_img is None:
        print("Please upload an image first.")
        return

    # Noisy version of the upload and the mask shared by all three channels
    noisy_img_array = add_noise(original_img_array, noise_level)
    radial_mask = create_radial_mask(h, w, radius, x_center, y_center, invert_mask, is_hamming)

    # Filter the R, G and B planes in turn
    per_channel = [
        process_channel(noisy_img_array[:, :, channel_index], radial_mask)
        for channel_index in range(3)
    ]
    (fft_r, mag_r, _), (fft_g, mag_g, _), (fft_b, mag_b, _) = per_channel

    # One spectrum picture for all channels, then back to image space
    combined_magnitude = mag_r + mag_g + mag_b
    reconstructed_image = reconstruct_image(fft_r, fft_g, fft_b)

    # Describe the active filter in the spectrum panel's title
    if invert_mask:
        filter_type = "High-Pass"
    else:
        filter_type = "Low-Pass"
    if is_hamming:
        window_type = " with Hamming Window"
    else:
        window_type = ""
    title = f"Noise Level: {noise_level}\n{filter_type} Filter{window_type}\nRadius: {radius}"

    display_results(noisy_img_array, combined_magnitude, reconstructed_image, title)

# =============================================================================
# DEMO 3: Combined Cross and Quadrant Filters
# =============================================================================

def create_cross_mask(h, w, x_thickness, y_thickness, x_center, y_center):
    """Build an (h, w) float mask that is 1 on a cross and 0 elsewhere.

    The cross consists of a full-width horizontal band around row
    ``y_center`` (about ``y_thickness`` rows tall) and a full-height vertical
    band around column ``x_center`` (about ``x_thickness`` columns wide).
    Bands are clipped to the mask borders.
    """
    half_rows = int(y_thickness / 2)
    half_cols = int(x_thickness / 2)

    # Row range of the horizontal bar and column range of the vertical bar
    row_band = slice(
        max(int(y_center - half_rows), 0),
        min(int(y_center + half_rows + 1), h),
    )
    col_band = slice(
        max(int(x_center - half_cols), 0),
        min(int(x_center + half_cols + 1), w),
    )

    cross = np.zeros((h, w), dtype=float)
    cross[row_band, :] = 1
    cross[:, col_band] = 1
    return cross

def create_quadrant_mask(h, w, quadrants, invert=False):
    """Build an (h, w) float mask that is 1 inside the chosen quadrants.

    Quadrant numbers: 1 = top-left, 2 = top-right, 3 = bottom-left,
    4 = bottom-right, split at ``h//2`` and ``w//2``. Unknown numbers in
    ``quadrants`` are ignored. With ``invert=True`` the selection is
    complemented.
    """
    rows, cols = np.ogrid[:h, :w]
    in_top = rows < h // 2
    in_left = cols < w // 2

    regions = {
        1: in_left & in_top,
        2: ~in_left & in_top,
        3: in_left & ~in_top,
        4: ~in_left & ~in_top,
    }

    # Union of every requested (and known) quadrant
    selected = np.zeros((h, w), dtype=bool)
    for number in quadrants:
        region = regions.get(number)
        if region is not None:
            selected = selected | region

    if invert:
        selected = ~selected

    return selected.astype(float)

def update_combined_filter_demo(filter_type, x_thickness, y_thickness, x_center, y_center, 
                              q1, q2, q3, q4, invert_mask):
    """Redraw the cross/quadrant filter demo for the current widget values.

    ``filter_type == "cross"`` uses the thickness and centre arguments to
    build a cross mask; any other value builds a quadrant mask from the
    ``q1``..``q4`` flags and ``invert_mask``. The mask is applied to every
    colour channel of the uploaded image and the result is displayed.
    """
    if current_img is None:
        print("Please upload an image first.")
        return

    # Pick the mask (and matching title) for the selected filter family
    if filter_type != "cross":
        quadrants = [
            number
            for number, chosen in enumerate((q1, q2, q3, q4), start=1)
            if chosen
        ]
        mask = create_quadrant_mask(h, w, quadrants, invert_mask)
        title = f"Quadrant Filter\nQuadrants: {','.join(map(str, quadrants))}\nInverted: {invert_mask}"
    else:
        mask = create_cross_mask(h, w, x_thickness, y_thickness, x_center, y_center)
        title = f"Cross Filter\nX Width: {x_thickness}, Y Width: {y_thickness}"

    # Filter the R, G and B planes in turn
    per_channel = [
        process_channel(original_img_array[:, :, channel_index], mask)
        for channel_index in range(3)
    ]
    (fft_r, mag_r, _), (fft_g, mag_g, _), (fft_b, mag_b, _) = per_channel

    # One spectrum picture for all channels, then back to image space
    combined_magnitude = mag_r + mag_g + mag_b
    reconstructed_image = reconstruct_image(fft_r, fft_g, fft_b)

    display_results(current_img, combined_magnitude, reconstructed_image, title)

# =============================================================================
# DEMO 4: Wave Texture Generator with K-Space Visualization
# =============================================================================

def generate_wave_texture(size, freq_x, freq_y, wave_type='sin', phase=0):
    """Generate a 2-D wave pattern with values in [-1, 1].

    Args:
        size: ``(rows, columns)`` of the output array.
        freq_x: Number of cycles spanned along the columns (the angle runs
            from 0 to ``2*pi*freq_x`` inclusive).
        freq_y: Number of cycles spanned along the rows.
        wave_type: ``'sin'``, ``'cos'``, ``'square'`` or ``'checker'``; any
            other value falls back to ``'sin'``.
        phase: Phase offset in radians (not used by ``'checker'``).

    Returns:
        Array of shape ``size`` containing the requested pattern.
    """
    col_angles = np.linspace(0, 2*np.pi*freq_x, size[1])
    row_angles = np.linspace(0, 2*np.pi*freq_y, size[0])
    xx, yy = np.meshgrid(col_angles, row_angles)

    # The checkerboard is the only pattern built from separate x/y waves
    if wave_type == 'checker':
        return np.sign(np.sin(xx)) * np.sign(np.sin(yy))

    diagonal = xx + yy + phase
    if wave_type == 'cos':
        return np.cos(diagonal)
    if wave_type == 'square':
        return np.sign(np.sin(diagonal))

    # 'sin' and every unrecognised type
    return np.sin(diagonal)

def compute_k_space(image):
    """Return the log-scaled, centred FFT magnitude of an image.

    A 3-D input is first averaged over its last axis to obtain a single
    grayscale plane. The result is ``log(|FFT| + 1)`` with the zero
    frequency shifted to the centre.
    """
    is_multichannel = len(image.shape) == 3
    gray = np.mean(image, axis=2) if is_multichannel else image

    # Log scale keeps the dominant DC term from hiding everything else
    centred_spectrum = fftshift(fft2(gray))
    return np.log(np.abs(centred_spectrum) + 1)

def update_wave_generator(size, freq_x, freq_y, wave_type, phase):
    """Redraw the wave texture and its k-space picture for the widget values.

    Args:
        size: Edge length in pixels of the square texture.
        freq_x: Horizontal frequency passed to ``generate_wave_texture``.
        freq_y: Vertical frequency passed to ``generate_wave_texture``.
        wave_type: Pattern name passed to ``generate_wave_texture``.
        phase: Phase offset passed to ``generate_wave_texture``.
    """
    # Generate the wave texture
    wave = generate_wave_texture((size, size), freq_x, freq_y, wave_type, phase)

    # Normalize to 0-255. A constant wave (e.g. both frequencies at 0) has
    # zero value range; dividing by it would yield inf/NaN, so such a wave is
    # rendered as a uniform black texture instead.
    lowest = wave.min()
    value_range = wave.max() - lowest
    if value_range > 0:
        wave_normalized = ((wave - lowest) * (255.0 / value_range)).astype(np.uint8)
    else:
        wave_normalized = np.zeros(wave.shape, dtype=np.uint8)
    wave_texture = np.stack([wave_normalized]*3, axis=-1)  # Make it RGB

    # Compute k-space representation
    k_space = compute_k_space(wave_texture)

    # Display the results
    plt.figure(figsize=(16, 8))

    # Wave pattern
    plt.subplot(1, 2, 1)
    plt.imshow(wave_texture, cmap='gray')
    plt.title(f"Wave Texture\nType: {wave_type}, Freq X: {freq_x}, Freq Y: {freq_y}")
    plt.axis('off')

    # K-space representation
    plt.subplot(1, 2, 2)
    plt.imshow(k_space, cmap='viridis')
    plt.title("K-Space (Frequency Domain)")
    plt.axis('off')

    plt.tight_layout()
    plt.show()

# =============================================================================
# Widget Creation Functions
# =============================================================================

def display_compression_demo():
    """Build the compression demo: one percentage slider above its live output."""
    percent_slider = widgets.FloatSlider(
        description='Top % Coeff:',
        value=5.0,
        min=0.01,
        max=10.0,
        step=0.01,
        continuous_update=True,
    )

    # Re-render the demo whenever the slider moves
    controls = {'percent': percent_slider}
    rendered = widgets.interactive_output(update_compression_demo, controls)

    return widgets.VBox([percent_slider, rendered])

def display_noise_reduction_demo():
    """Build the noise + radial filter demo widget.

    Slider ranges are derived from the current image size (module globals
    ``h`` and ``w``); fixed fallback ranges are used while no image is loaded.
    """
    # Ranges that depend on the image dimensions
    shortest_side = min(h, w)
    radius_limit = shortest_side//2 if shortest_side > 0 else 100
    x_start, x_limit = (w//2, w-1) if w > 0 else (100, 200)
    y_start, y_limit = (h//2, h-1) if h > 0 else (100, 200)

    noise_slider = widgets.FloatSlider(
        description='Noise Level:', value=10.0, min=0.0, max=1000.0, step=0.1,
        continuous_update=True,
    )
    radius_slider = widgets.FloatSlider(
        description='Radius:', value=50.0, min=1.0, max=radius_limit, step=1.0,
        continuous_update=True,
    )
    x_center_slider = widgets.FloatSlider(
        description='X Center:', value=x_start, min=0, max=x_limit, step=1.0,
        continuous_update=True,
    )
    y_center_slider = widgets.FloatSlider(
        description='Y Center:', value=y_start, min=0, max=y_limit, step=1.0,
        continuous_update=True,
    )
    invert_toggle = widgets.Checkbox(description='High-Pass Filter', value=False)
    hamming_toggle = widgets.Checkbox(description='Apply Hamming Window', value=False)

    # Map each callback argument to the control that supplies it
    controls = {
        'noise_level': noise_slider,
        'radius': radius_slider,
        'x_center': x_center_slider,
        'y_center': y_center_slider,
        'invert_mask': invert_toggle,
        'is_hamming': hamming_toggle,
    }
    rendered = widgets.interactive_output(update_noise_reduction_demo, controls)

    return widgets.VBox([*controls.values(), rendered])

def display_combined_filter_demo():
    """Build the cross/quadrant filter demo widget.

    A radio button switches between the two filter families; only the
    controls belonging to the selected family are visible. Slider ranges are
    derived from the current image size (module globals ``h`` and ``w``) with
    fixed fallbacks while no image is loaded.
    """
    filter_type = widgets.RadioButtons(
        description='Filter Type:',
        options=['cross', 'quadrant'],
        value='cross',
    )

    # --- Cross filter controls -------------------------------------------
    x_start, x_limit = (w//2, w-1) if w > 0 else (100, 200)
    y_start, y_limit = (h//2, h-1) if h > 0 else (100, 200)

    x_thickness_slider = widgets.FloatSlider(
        description='X Width:', value=10.0, min=1.0,
        max=w//4 if w > 0 else 100, step=1.0, continuous_update=True,
    )
    y_thickness_slider = widgets.FloatSlider(
        description='Y Width:', value=10.0, min=1.0,
        max=h//4 if h > 0 else 100, step=1.0, continuous_update=True,
    )
    x_center_slider = widgets.FloatSlider(
        description='X Center:', value=x_start, min=0, max=x_limit, step=1.0,
        continuous_update=True,
    )
    y_center_slider = widgets.FloatSlider(
        description='Y Center:', value=y_start, min=0, max=y_limit, step=1.0,
        continuous_update=True,
    )

    # --- Quadrant filter controls ----------------------------------------
    q1_checkbox = widgets.Checkbox(description="Quadrant 1 (Top-left)", value=False)
    q2_checkbox = widgets.Checkbox(description="Quadrant 2 (Top-right)", value=False)
    q3_checkbox = widgets.Checkbox(description="Quadrant 3 (Bottom-left)", value=False)
    q4_checkbox = widgets.Checkbox(description="Quadrant 4 (Bottom-right)", value=False)
    invert_toggle = widgets.Checkbox(description="Invert Selection", value=False)

    cross_controls = (
        x_thickness_slider, y_thickness_slider, x_center_slider, y_center_slider,
    )
    quadrant_controls = (
        q1_checkbox, q2_checkbox, q3_checkbox, q4_checkbox, invert_toggle,
    )

    def update_ui(change):
        """Show the controls of the selected filter family and hide the rest."""
        cross_selected = change['new'] == 'cross'
        for control in cross_controls:
            control.layout.visibility = 'visible' if cross_selected else 'hidden'
        for control in quadrant_controls:
            control.layout.visibility = 'hidden' if cross_selected else 'visible'

    # Start in cross mode, then follow the radio button
    update_ui({'new': 'cross'})
    filter_type.observe(update_ui, names='value')

    # Map each callback argument to the control that supplies it
    controls = {
        'filter_type': filter_type,
        'x_thickness': x_thickness_slider,
        'y_thickness': y_thickness_slider,
        'x_center': x_center_slider,
        'y_center': y_center_slider,
        'q1': q1_checkbox,
        'q2': q2_checkbox,
        'q3': q3_checkbox,
        'q4': q4_checkbox,
        'invert_mask': invert_toggle,
    }
    rendered = widgets.interactive_output(update_combined_filter_demo, controls)

    return widgets.VBox([filter_type, *cross_controls, *quadrant_controls, rendered])

def display_wave_generator_demo():
    """Build the wave generator demo: heading, five controls and live output."""
    heading = widgets.HTML("<h3>Wave Texture Generator with K-Space Visualization</h3>")

    size_slider = widgets.IntSlider(
        description='Size:', value=256, min=32, max=1024, step=1,
        continuous_update=True,
    )
    freq_x_slider = widgets.FloatSlider(
        description='Freq X:', value=1.0, min=0, max=100.0, step=0.1,
        continuous_update=True,
    )
    freq_y_slider = widgets.FloatSlider(
        description='Freq Y:', value=1.0, min=0, max=100.0, step=0.1,
        continuous_update=True,
    )
    wave_type_dropdown = widgets.Dropdown(
        description='Wave Type:',
        options=['sin', 'cos', 'square', 'checker'],
        value='sin',
    )
    phase_slider = widgets.FloatSlider(
        description='Phase:', value=0.0, min=0.0, max=2*np.pi, step=0.1,
        continuous_update=True,
    )

    # Map each callback argument to the control that supplies it
    controls = {
        'size': size_slider,
        'freq_x': freq_x_slider,
        'freq_y': freq_y_slider,
        'wave_type': wave_type_dropdown,
        'phase': phase_slider,
    }
    rendered = widgets.interactive_output(update_wave_generator, controls)

    return widgets.VBox([heading, *controls.values(), rendered])

def create_tabbed_interface():
    """Create and display the tab widget that hosts the upload page and all demos.

    Demo content is built lazily: whenever a tab is selected, every demo
    container is emptied and only the selected one is repopulated. The three
    image-based demos show a hint instead while no image has been uploaded.
    """
    # Upload page
    upload = widgets.FileUpload(
        description='Select Image',
        accept='image/*',
        multiple=False,
    )
    upload.observe(handle_image_upload, names='value')
    upload_container = widgets.VBox([widgets.Label("Upload an image to begin:"), upload])

    # Empty placeholders, filled on demand by on_tab_selected
    wave_container = widgets.VBox()
    compression_container = widgets.VBox()
    noise_filter_container = widgets.VBox()
    combined_filter_container = widgets.VBox()

    tab = widgets.Tab()
    tab.children = [
        upload_container,
        wave_container,
        compression_container,
        noise_filter_container,
        combined_filter_container,
    ]

    tab_titles = (
        'Upload Image',
        'Wave Generator',
        'FFT Compression',
        'Noise & Filtering',
        'Cross/Quadrant Filters',
    )
    for position, tab_title in enumerate(tab_titles):
        tab.set_title(position, tab_title)

    display(tab)

    # Tabs that need an uploaded image: index -> (placeholder, widget builder)
    image_demos = {
        2: (compression_container, display_compression_demo),
        3: (noise_filter_container, display_noise_reduction_demo),
        4: (combined_filter_container, display_combined_filter_demo),
    }

    def on_tab_selected(change):
        """Populate the newly selected tab and empty all other demo tabs."""
        selected_index = change['new']

        # Everything but the upload page (index 0) is rebuilt from scratch
        for container in tab.children[1:]:
            container.children = []

        if selected_index == 1:
            # The wave generator works without an uploaded image
            wave_container.children = [display_wave_generator_demo()]
        elif selected_index in image_demos:
            container, build_demo = image_demos[selected_index]
            if current_img is None:
                container.children = [widgets.HTML("Please upload or generate an image first.")]
            else:
                container.children = [build_demo()]

    tab.observe(on_tab_selected, names='selected_index')

    # Initialise the state for the first (upload) tab
    on_tab_selected({'new': 0})

# Run the interface: build and display the tab widget as soon as this cell executes.
# handle_image_upload() calls the same function again after every upload.
create_tabbed_interface()



Tab(children=(VBox(children=(Label(value='Upload an image to begin:'), FileUpload(value=(), accept='image/*', …