Import necessary libraries

In [1]:
import matplotlib.pyplot as plt
import numpy as np
import ipywidgets as widgets
from ipywidgets import interact

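Function to calculate the output size of a convolution operation: $O = \left\lfloor \frac{I - F + 2P}{S} \right\rfloor + 1$, where $I$ is the input size, $F$ the filter size, $P$ the padding, and $S$ the stride.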

In [2]:
def calculate_output_size(input_size: int, filter_size: int, stride: int, padding: int) -> int:
    """Return the convolution output size along one spatial dimension.

    Implements ``floor((input_size - filter_size + 2 * padding) / stride) + 1``.

    Args:
        input_size: Size of the input along the dimension.
        filter_size: Size of the filter along the dimension.
        stride: Step between successive filter positions; must be >= 1.
        padding: Padding added to each side of the input.

    Returns:
        The output size, always >= 1 for valid arguments.

    Raises:
        ValueError: If ``stride`` < 1, or if the filter is larger than the
            padded input (the formula would otherwise yield a size <= 0).
    """
    if stride < 1:
        raise ValueError(f"stride must be >= 1, got {stride}")
    padded_size = input_size + 2 * padding
    if filter_size > padded_size:
        raise ValueError(
            f"filter_size ({filter_size}) exceeds padded input size ({padded_size})"
        )
    # Floor division: a trailing partial window that does not fit is dropped.
    return (padded_size - filter_size) // stride + 1

Function to visualize the convolution process

In [5]:
def _draw_size_panel(ax, size, title, cmap, fontsize):
    """Draw a filled ``size x size`` square on ``ax`` labelled with its dimensions."""
    ax.set_title(title)
    # A constant array has vmin == vmax, which matplotlib normalizes to the
    # bottom (near-white) end of the colormap; pinning the range gives a
    # visible mid-tone fill instead.
    ax.imshow(np.ones((size, size)), cmap=cmap, vmin=0, vmax=2)
    ax.set_xticks([])
    ax.set_yticks([])
    # Axes-relative coordinates keep the label centred for any ``size``.
    ax.text(0.5, 0.5, f"{size}x{size}", fontsize=fontsize, color='black',
            ha='center', va='center', transform=ax.transAxes)


def visualize_convolution(input_size, filter_size, stride, padding):
    """Display input, filter and output feature-map sizes side by side.

    Intended as an ``ipywidgets.interact`` callback. Combinations for which
    the convolution formula gives no valid output (stride < 1, or a filter
    larger than the padded input) print a message instead of plotting.

    Args:
        input_size: Side length of the square input.
        filter_size: Side length of the square filter.
        stride: Step between successive filter positions.
        padding: Padding added to each side of the input.
    """
    padded_size = input_size + 2 * padding
    # Slider combinations can produce impossible configurations; report them
    # rather than letting np.ones fail on a non-positive dimension.
    if stride < 1:
        print(f"Invalid configuration: stride must be >= 1, got {stride}.")
        return
    if filter_size > padded_size:
        print(
            f"Invalid configuration: filter size {filter_size} exceeds "
            f"padded input size {padded_size} (input {input_size} + 2 x padding {padding})."
        )
        return

    output_size = calculate_output_size(input_size, filter_size, stride, padding)

    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    _draw_size_panel(axes[0], input_size, "Input Image", 'Blues', 14)
    _draw_size_panel(axes[1], filter_size, "Filter", 'Oranges', 12)
    _draw_size_panel(axes[2], output_size, "Output Feature Map", 'Greens', 14)

    # Mathtext ignores literal spaces, so words are set in \mathrm with "\ ";
    # the floor brackets mirror the integer division used in the computation.
    fig.suptitle(
        f"Convolution Operation: Input {input_size}x{input_size} → Output {output_size}x{output_size}\n"
        f"Filter: {filter_size}x{filter_size}, Stride: {stride}, Padding: {padding}\n\n"
        r"$\mathrm{Output\ Size} = \left\lfloor \frac{\mathrm{Input\ Size} - \mathrm{Filter\ Size}"
        r" + 2 \times \mathrm{Padding}}{\mathrm{Stride}} \right\rfloor + 1$",
        fontsize=16, y=1.05,
    )

    fig.tight_layout()
    plt.show()

# Interactive widget: each slider drives one argument of visualize_convolution.
# continuous_update=False redraws only when a slider is released, so dragging
# does not rebuild the three-panel figure for every intermediate value.
# The trailing ';' suppresses the echoed function repr that interact() returns.
interact(
    visualize_convolution,
    input_size=widgets.IntSlider(min=4, max=256, step=4, value=32, description="Input Size",
                                 continuous_update=False),
    filter_size=widgets.IntSlider(min=2, max=15, step=1, value=3, description="Filter Size",
                                  continuous_update=False),
    stride=widgets.IntSlider(min=1, max=10, step=1, value=1, description="Stride",
                             continuous_update=False),
    padding=widgets.IntSlider(min=0, max=10, step=1, value=0, description="Padding",
                              continuous_update=False),
);

interactive(children=(IntSlider(value=32, description='Input Size', max=256, min=4, step=4), IntSlider(value=3…