# Fourier Transform: Analyzing Signals in the Frequency Domain

## Overview

The Fourier Transform decomposes a signal into its constituent frequencies. In this section, we'll explore the Fourier Transform and its relevance to motion perception and motion energy models.

### What we'll cover:
- The concept of spatial and temporal frequency
- The intuition behind the Fourier Transform
- Implementing and visualizing Fourier Transforms
- The frequency representation of motion
- How the Fourier Transform relates to visual processing in the brain

## Setting Up

Let's import the libraries we'll need for this section.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation
from scipy import signal as sp_signal  # aliased because later cells bind the name `signal` to arrays
from scipy.fft import fft, fft2, fftshift, fftn, ifft, ifft2, ifftn
import sys

# Add the utils package to the path
sys.path.append('../../..')
try:
    from motionenergy.utils import stimuli_generation, visualization
except ImportError:
    print("Note: utils modules not found. This is expected if you haven't implemented them yet.")

# For interactive plots
%matplotlib inline
from IPython.display import HTML, display

# Set plotting style
plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams['figure.figsize'] = (10, 6)
plt.rcParams['font.size'] = 12

## 1. Introduction to Frequency Analysis

To understand motion energy models, we need to think about visual information in terms of frequencies. But what do we mean by "frequency" in the context of visual stimuli?

### Spatial Frequency

Spatial frequency refers to how rapidly a visual pattern changes over space. It's typically measured in cycles per degree of visual angle or cycles per pixel in digital images.

- High spatial frequencies correspond to fine details and sharp edges
- Low spatial frequencies correspond to coarse patterns and gradual changes

### Temporal Frequency

Temporal frequency refers to how rapidly a visual pattern changes over time. It's typically measured in cycles per second (Hz) or cycles per frame in digital videos.

- High temporal frequencies correspond to rapid changes
- Low temporal frequencies correspond to slow changes

Let's visualize some examples to build intuition:

In [None]:
def plot_spatial_frequencies():
    """Visualize gratings with different spatial frequencies."""
    # Define spatial parameters
    x = np.linspace(0, 1, 1000)
    
    # Create gratings with different frequencies
    fig, axes = plt.subplots(4, 1, figsize=(10, 8))
    
    for i, freq in enumerate([1, 2, 5, 10]):
        # Create a sinusoidal grating
        grating = np.sin(2 * np.pi * freq * x)
        
        # Plot the grating
        axes[i].plot(x, grating)
        axes[i].set_title(f'Spatial Frequency: {freq} cycles/unit')
        axes[i].set_ylim(-1.1, 1.1)
        axes[i].set_xlabel('Position')
        axes[i].set_ylabel('Intensity')
        axes[i].grid(True)
    
    plt.tight_layout()
    plt.show()
    
    # Create a 2D visualization to better show spatial frequency
    fig, axes = plt.subplots(2, 2, figsize=(10, 10))
    axes = axes.flatten()
    
    # Generate 2D gratings
    x = np.linspace(0, 1, 100)
    y = np.linspace(0, 1, 100)
    X, Y = np.meshgrid(x, y)
    
    for i, freq in enumerate([1, 2, 5, 10]):
        # Create a 2D sinusoidal grating
        grating = np.sin(2 * np.pi * freq * X)
        
        # Plot the grating
        im = axes[i].imshow(grating, cmap='gray', origin='lower', extent=[0, 1, 0, 1])
        axes[i].set_title(f'Spatial Frequency: {freq} cycles/unit')
        plt.colorbar(im, ax=axes[i])
    
    plt.tight_layout()
    plt.show()

# Plot spatial frequencies
plot_spatial_frequencies()

In [ ]:
def animate_temporal_frequency(freq, duration=2, fps=30):
    """Create an animation of a grating with a specific temporal frequency."""
    # Set up the figure
    fig, ax = plt.subplots(figsize=(8, 3))
    x = np.linspace(0, 1, 1000)
    # Fixed spatial frequency (5 cycles per unit)
    spatial_freq = 5
    line, = ax.plot(x, np.sin(2 * np.pi * spatial_freq * x))
    ax.set_ylim(-1.1, 1.1)
    ax.set_title(f'Temporal Frequency: {freq} Hz, Spatial Frequency: {spatial_freq} cycles/unit')
    ax.set_xlabel('Position')
    ax.set_ylabel('Intensity')
    ax.grid(True)
    
    # Animation function
    def animate(i):
        t = i / fps
        # Phase changes with time based on temporal frequency
        phase = 2 * np.pi * freq * t
        y = np.sin(2 * np.pi * spatial_freq * x + phase)
        line.set_ydata(y)
        return (line,)
    
    # Create the animation
    frames = int(duration * fps)
    anim = animation.FuncAnimation(fig, animate, frames=frames, interval=1000/fps, blit=True)
    
    # Close the static figure so that only the animation is displayed
    plt.close(fig)
    return HTML(anim.to_jshtml())

In [None]:
# Animate a grating drifting at a temporal frequency of 2 Hz
animate_temporal_frequency(2.0)  # 2 Hz

## 2. The Fourier Transform: Mathematical Foundation

The Fourier Transform is based on the idea that a signal can be decomposed into a sum of sinusoids with different frequencies, amplitudes, and phases. This holds exactly for every finite discrete signal, which is the case we work with in practice.

### Mathematical Definition

For a continuous signal $f(t)$, the Fourier Transform $F(\omega)$ is defined as:

$$F(\omega) = \int_{-\infty}^{\infty} f(t) e^{-i\omega t} dt$$

where $\omega$ is the angular frequency and $i$ is the imaginary unit.

For discrete signals, which are what we typically work with in digital data, we use the Discrete Fourier Transform (DFT). For a sequence $x[n]$ of length $N$, the DFT is:

$$X[k] = \sum_{n=0}^{N-1} x[n] e^{-i2\pi kn/N}$$

where $k = 0, 1, 2, \ldots, N-1$.
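
To make the DFT sum concrete, here is a minimal sketch that evaluates it directly and compares the result with NumPy's FFT. The names `naive_dft` and `x_demo` are illustrative and not part of the course code; the direct evaluation costs $O(N^2)$ operations and is only meant for checking the definition on short signals.

```python
import numpy as np

def naive_dft(x):
    """Directly evaluate X[k] = sum_n x[n] * exp(-i 2 pi k n / N)."""
    x = np.asarray(x, dtype=complex)
    N = len(x)
    n = np.arange(N)
    k = n.reshape(-1, 1)
    # Row k of this matrix is the complex sinusoid for frequency index k
    basis = np.exp(-2j * np.pi * k * n / N)
    return basis @ x

rng = np.random.default_rng(0)
x_demo = rng.standard_normal(64)

# The direct sum and the FFT should agree up to floating-point error
dft_matches_fft = np.allclose(naive_dft(x_demo), np.fft.fft(x_demo))
print(dft_matches_fft)
```

Each row of `basis` is one of the sinusoids the signal is compared against, which is the "correlation" view of the transform.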

### Intuition Behind the Fourier Transform

We can think of the Fourier Transform as a "correlation" of the signal with sinusoids of different frequencies. For each frequency, we compute how much of that frequency is present in the signal. The result gives us the amplitude and phase of each frequency component.

Intuitively, the Fourier Transform is like decomposing a musical chord into its constituent notes. Just as a chord can be broken down into individual notes (frequencies) with specific volumes (amplitudes) and timing (phases), any signal can be broken down into sinusoidal components.

Let's visualize this idea by decomposing a complex signal into its frequency components:

In [None]:
def visualize_fourier_intuition():
    """Visualize the intuition behind the Fourier Transform."""
    # Create a signal composed of multiple frequencies
    t = np.linspace(0, 1, 1000, endpoint=False)
    
    # Components at different frequencies
    f1 = 5    # 5 Hz component
    f2 = 10   # 10 Hz component
    f3 = 20   # 20 Hz component
    
    # Signal with three components
    signal = 1.0 * np.sin(2 * np.pi * f1 * t) + \
             0.5 * np.sin(2 * np.pi * f2 * t) + \
             0.25 * np.sin(2 * np.pi * f3 * t)
    
    # Compute the FFT
    fft_result = fft(signal)
    freqs = np.fft.fftfreq(len(t), t[1] - t[0])
    
    # Plot the signal and its components
    fig, axes = plt.subplots(5, 1, figsize=(12, 10))
    
    # Plot the composite signal
    axes[0].plot(t, signal)
    axes[0].set_title('Composite Signal')
    axes[0].set_xlabel('Time (s)')
    axes[0].set_ylabel('Amplitude')
    axes[0].grid(True)
    
    # Plot the individual components
    axes[1].plot(t, 1.0 * np.sin(2 * np.pi * f1 * t))
    axes[1].set_title(f'Component 1: {f1} Hz')
    axes[1].set_xlabel('Time (s)')
    axes[1].set_ylabel('Amplitude')
    axes[1].grid(True)
    
    axes[2].plot(t, 0.5 * np.sin(2 * np.pi * f2 * t))
    axes[2].set_title(f'Component 2: {f2} Hz')
    axes[2].set_xlabel('Time (s)')
    axes[2].set_ylabel('Amplitude')
    axes[2].grid(True)
    
    axes[3].plot(t, 0.25 * np.sin(2 * np.pi * f3 * t))
    axes[3].set_title(f'Component 3: {f3} Hz')
    axes[3].set_xlabel('Time (s)')
    axes[3].set_ylabel('Amplitude')
    axes[3].grid(True)
    
    # Plot the FFT magnitude
    # Only plot positive frequencies up to 30 Hz
    mask = (freqs > 0) & (freqs < 30)
    axes[4].stem(freqs[mask], 2 * np.abs(fft_result[mask]) / len(t))
    axes[4].set_title('Frequency Spectrum (Fourier Transform)')
    axes[4].set_xlabel('Frequency (Hz)')
    axes[4].set_ylabel('Amplitude')
    axes[4].grid(True)
    
    plt.tight_layout()
    plt.show()

# Visualize the Fourier Transform intuition
visualize_fourier_intuition()

### Properties of the Fourier Transform

The Fourier Transform has several important properties that are useful to understand:

1. **Linearity**: The Fourier Transform of a sum of signals is the sum of their Fourier Transforms.
2. **Time Shift**: A shift in time corresponds to a phase change in the frequency domain.
3. **Scaling**: Stretching a signal in time compresses its frequency representation and vice versa.
4. **Convolution Theorem**: Convolution in the time/space domain corresponds to multiplication in the frequency domain.

Let's explore some of these properties with examples:

In [None]:
def demonstrate_fourier_properties():
    """Demonstrate key properties of the Fourier Transform."""
    # Create a signal
    t = np.linspace(0, 1, 1000, endpoint=False)
    signal = np.sin(2 * np.pi * 10 * t) + 0.5 * np.sin(2 * np.pi * 20 * t)
    
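    # Create a shifted signal (time shift property)
    # 0.025 s is a quarter period at 10 Hz and a half period at 20 Hz, so both
    # phases change. A 0.1 s shift would be a whole number of periods of both
    # components and would leave the signal unchanged.
    shift = 0.025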
    t_shifted = t - shift
    signal_shifted = np.sin(2 * np.pi * 10 * t_shifted) + 0.5 * np.sin(2 * np.pi * 20 * t_shifted)
    
    # Create a scaled signal (scaling property)
    scale = 2.0  # Stretch in time
    signal_scaled = np.sin(2 * np.pi * 10 * t / scale) + 0.5 * np.sin(2 * np.pi * 20 * t / scale)
    
    # Compute FFTs
    fft_signal = fft(signal)
    fft_shifted = fft(signal_shifted)
    fft_scaled = fft(signal_scaled)
    
    # Get frequency values
    freqs = np.fft.fftfreq(len(t), t[1] - t[0])
    
    # Plot the signals and their FFTs
    fig, axes = plt.subplots(3, 2, figsize=(14, 10))
    
    # Original signal
    axes[0, 0].plot(t, signal)
    axes[0, 0].set_title('Original Signal')
    axes[0, 0].set_xlabel('Time (s)')
    axes[0, 0].set_ylabel('Amplitude')
    axes[0, 0].grid(True)
    
    # FFT of original signal
    mask = (freqs > 0) & (freqs < 30)
    axes[0, 1].stem(freqs[mask], 2 * np.abs(fft_signal[mask]) / len(t))
    axes[0, 1].set_title('FFT of Original Signal')
    axes[0, 1].set_xlabel('Frequency (Hz)')
    axes[0, 1].set_ylabel('Amplitude')
    axes[0, 1].grid(True)
    
    # Shifted signal
    axes[1, 0].plot(t, signal_shifted)
    axes[1, 0].set_title(f'Time-Shifted Signal (shift = {shift}s)')
    axes[1, 0].set_xlabel('Time (s)')
    axes[1, 0].set_ylabel('Amplitude')
    axes[1, 0].grid(True)
    
    # FFT of shifted signal
    axes[1, 1].stem(freqs[mask], 2 * np.abs(fft_shifted[mask]) / len(t))
    axes[1, 1].set_title('FFT of Shifted Signal (Amplitudes Unchanged, Phase Changed)')
    axes[1, 1].set_xlabel('Frequency (Hz)')
    axes[1, 1].set_ylabel('Amplitude')
    axes[1, 1].grid(True)
    
    # Scaled signal
    axes[2, 0].plot(t, signal_scaled)
    axes[2, 0].set_title(f'Time-Scaled Signal (scale = {scale})')
    axes[2, 0].set_xlabel('Time (s)')
    axes[2, 0].set_ylabel('Amplitude')
    axes[2, 0].grid(True)
    
    # FFT of scaled signal
    axes[2, 1].stem(freqs[mask], 2 * np.abs(fft_scaled[mask]) / len(t))
    axes[2, 1].set_title('FFT of Scaled Signal (Frequencies Compressed)')
    axes[2, 1].set_xlabel('Frequency (Hz)')
    axes[2, 1].set_ylabel('Amplitude')
    axes[2, 1].grid(True)
    
    plt.tight_layout()
    plt.show()

# Demonstrate Fourier Transform properties
demonstrate_fourier_properties()
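
The magnitude plots cannot show the phase change that a time shift produces. As a small numerical check of the time-shift property, the sketch below compares one FFT bin before and after a shift. The sampling rate, frequency, and shift are illustrative values chosen so that the sinusoid falls exactly on a frequency bin.

```python
import numpy as np

fs = 1000                      # samples per second (illustrative)
t = np.arange(fs) / fs         # exactly 1 second, so bin k corresponds to k Hz
f0 = 10                        # frequency of the test sinusoid in Hz
shift = 0.025                  # delay in seconds (a quarter period at 10 Hz)

x = np.sin(2 * np.pi * f0 * t)
x_delayed = np.sin(2 * np.pi * f0 * (t - shift))

X = np.fft.fft(x)
X_delayed = np.fft.fft(x_delayed)

# The magnitude at the signal frequency is unchanged by the delay ...
magnitudes_match = np.isclose(np.abs(X[f0]), np.abs(X_delayed[f0]))
# ... while the phase changes by -2*pi*f0*shift radians
phase_change = np.angle(X_delayed[f0] / X[f0])

print(magnitudes_match, phase_change, -2 * np.pi * f0 * shift)
```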

## 3. Implementing the Fourier Transform

In practice, we typically use the Fast Fourier Transform (FFT) algorithm, which efficiently computes the Discrete Fourier Transform. Let's implement a simple function to compute and visualize the Fourier Transform of a signal.

In [None]:
def compute_and_plot_fft(signal, sampling_rate, max_freq=None):
    """
    Compute and plot the FFT of a signal.
    
    Parameters:
    -----------
    signal : ndarray
        Input signal
    sampling_rate : float
        Sampling rate in Hz
    max_freq : float or None
        Maximum frequency to plot (in Hz)
    """
    # Compute the FFT
    fft_result = fft(signal)
    n = len(signal)
    
    # Compute the frequencies
    freqs = np.fft.fftfreq(n, 1/sampling_rate)
    
    # Compute the single-sided amplitude: double every bin except DC,
    # which has no negative-frequency partner
    magnitude = 2 * np.abs(fft_result) / n
    magnitude[0] /= 2
    
    # Compute the phase of the FFT
    phase = np.angle(fft_result)
    
    # Plot the signal and its FFT
    fig, axes = plt.subplots(3, 1, figsize=(12, 10))
    
    # Plot the time-domain signal
    t = np.arange(n) / sampling_rate
    axes[0].plot(t, signal)
    axes[0].set_title('Time Domain Signal')
    axes[0].set_xlabel('Time (s)')
    axes[0].set_ylabel('Amplitude')
    axes[0].grid(True)
    
    # Plot the magnitude spectrum
    if max_freq is not None:
        mask = (freqs >= 0) & (freqs <= max_freq)
    else:
        mask = freqs >= 0  # Only plot positive frequencies
    
    axes[1].stem(freqs[mask], magnitude[mask])
    axes[1].set_title('Magnitude Spectrum')
    axes[1].set_xlabel('Frequency (Hz)')
    axes[1].set_ylabel('Magnitude')
    axes[1].grid(True)
    
    # Plot the phase spectrum
    axes[2].stem(freqs[mask], phase[mask])
    axes[2].set_title('Phase Spectrum')
    axes[2].set_xlabel('Frequency (Hz)')
    axes[2].set_ylabel('Phase (radians)')
    axes[2].grid(True)
    
    plt.tight_layout()
    plt.show()

# Generate a test signal
sampling_rate = 1000  # 1000 Hz
t = np.arange(0, 1, 1/sampling_rate)  # 1 second of data
f1, f2, f3 = 5, 20, 50  # Frequencies in Hz
signal = 1.0 * np.sin(2 * np.pi * f1 * t) + \
         0.5 * np.sin(2 * np.pi * f2 * t + np.pi/4) + \
         0.25 * np.sin(2 * np.pi * f3 * t + np.pi/2) + \
         0.1 * np.random.randn(len(t))  # Add some noise

# Compute and plot the FFT
compute_and_plot_fft(signal, sampling_rate, max_freq=60)
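
The amplitudes and phases in these plots can also be read directly from the FFT bins. The sketch below uses a noise-free version of the test signal and `np.fft.rfft`, which returns only the non-negative frequencies of a real signal. The variable names are illustrative. Note that the factor of 2 applies to every bin except DC (and the Nyquist bin when the length is even), and that a sine with phase $\phi$ appears in the FFT with angle $\phi - \pi/2$ because the FFT measures phase relative to a cosine.

```python
import numpy as np

fs = 1000
t = np.arange(fs) / fs
x = 1.0 * np.sin(2 * np.pi * 5 * t) + 0.5 * np.sin(2 * np.pi * 20 * t + np.pi / 4)

spectrum = np.fft.rfft(x)                      # non-negative frequencies only
freqs = np.fft.rfftfreq(len(x), d=1 / fs)

# Single-sided amplitude spectrum
amplitude = 2 * np.abs(spectrum) / len(x)
amplitude[0] /= 2                              # DC has no mirror-image bin
if len(x) % 2 == 0:
    amplitude[-1] /= 2                         # neither does the Nyquist bin

def bin_index(freq_hz):
    """Index of the bin closest to freq_hz."""
    return int(round(freq_hz * len(x) / fs))

amp_5, amp_20 = amplitude[bin_index(5)], amplitude[bin_index(20)]
# Convert the FFT angle (cosine reference) to the phase of the sine component
sine_phase_20 = np.angle(spectrum[bin_index(20)]) + np.pi / 2

print(amp_5, amp_20, sine_phase_20)
```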

## 4. The 2D Fourier Transform for Images

For image processing and motion analysis, we often use the 2D Fourier Transform, which decomposes an image into its spatial frequency components.

Let's implement and visualize the 2D Fourier Transform of some example images:

In [ ]:
def compute_and_plot_2d_fft(image, log_scale=True):
    """
    Compute and plot the 2D Fourier Transform of an image.
    
    Parameters:
    -----------
    image : ndarray
        Input image (2D array)
    log_scale : bool
        Whether to use a logarithmic scale for the magnitude spectrum
    """
    # Compute the 2D FFT
    fft_result = fft2(image)
    
    # Shift the zero frequency component to the center
    fft_shifted = fftshift(fft_result)
    
    # Compute the magnitude of the FFT
    magnitude = np.abs(fft_shifted)
    
    # Apply log scale if requested
    if log_scale:
        magnitude = np.log1p(magnitude)  # log(1 + x) to avoid log(0)
    
    # Compute the phase
    phase = np.angle(fft_shifted)
    
    # Create a figure to display the image and its FFT
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    
    # Display the original image
    axes[0].imshow(image, cmap='gray')
    axes[0].set_title('Original Image')
    axes[0].axis('off')
    
    # Display the magnitude spectrum
    im1 = axes[1].imshow(magnitude, cmap='viridis')
    axes[1].set_title('Magnitude Spectrum' + (' (Log Scale)' if log_scale else ''))
    axes[1].axis('off')
    plt.colorbar(im1, ax=axes[1], label='Magnitude' + (' (log scale)' if log_scale else ''))
    
    # Display the phase spectrum
    im2 = axes[2].imshow(phase, cmap='hsv')
    axes[2].set_title('Phase Spectrum')
    axes[2].axis('off')
    plt.colorbar(im2, ax=axes[2], label='Phase (radians)')
    
    plt.tight_layout()
    plt.show()

def create_grating_image(size, spatial_freq, orientation=0.0):
    """
    Create a square sinusoidal grating with values in [-1, 1].

    Stand-in for a helper that this notebook calls but never defines. The
    conventions below are assumptions inferred from those calls:
    spatial_freq is in cycles per image, and orientation is the angle of the
    bars in radians (0 for horizontal bars, pi/2 for vertical bars).
    """
    coords = np.arange(size) / size
    X, Y = np.meshgrid(coords, coords)
    # Luminance varies along the axis perpendicular to the bars
    u = X * np.sin(orientation) - Y * np.cos(orientation)
    return np.sin(2 * np.pi * spatial_freq * u)

In [None]:
# Create and visualize a vertical grating
vertical_grating = create_grating_image(256, 8, orientation=np.pi/2)
compute_and_plot_2d_fft(vertical_grating)

In [None]:
# Create and visualize a diagonal grating
diagonal_grating = create_grating_image(256, 8, orientation=np.pi/4)
compute_and_plot_2d_fft(diagonal_grating)

In [None]:
# Create and visualize a plaid pattern (sum of two gratings)
grating1 = create_grating_image(256, 8, orientation=0)
grating2 = create_grating_image(256, 8, orientation=np.pi/2)
plaid = (grating1 + grating2) / 2  # Average to keep the values in the -1 to 1 range
compute_and_plot_2d_fft(plaid)
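
To connect the magnitude plots to numbers, the following sketch locates the peaks in the 2D spectrum of a grating. It builds its own grating so that it runs on its own; the size and frequency are illustrative. A real-valued sinusoidal grating produces two peaks placed symmetrically about the origin, at plus and minus its spatial frequency.

```python
import numpy as np

size = 64
cycles_x, cycles_y = 8, 0          # 8 cycles across the image width: vertical bars

yy, xx = np.mgrid[0:size, 0:size]  # yy is the row index, xx the column index
grating = np.sin(2 * np.pi * (cycles_x * xx + cycles_y * yy) / size)

magnitude = np.abs(np.fft.fftshift(np.fft.fft2(grating)))

# Frequency axis in cycles per image, ordered to match the shifted spectrum
freq_axis = np.fft.fftshift(np.fft.fftfreq(size, d=1 / size))

rows, cols = np.nonzero(magnitude > 0.5 * magnitude.max())
peaks = sorted(
    (int(round(freq_axis[c])), int(round(freq_axis[r]))) for r, c in zip(rows, cols)
)
print(peaks)  # (fx, fy) pairs: [(-8, 0), (8, 0)]
```

A plaid, being the sum of two gratings, shows both pairs of peaks, which follows from the linearity of the transform.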

## 5. Frequency Domain Filtering

One of the most powerful applications of the Fourier Transform is filtering in the frequency domain. Instead of convolving a signal with a filter kernel in the spatial/temporal domain, we can multiply the signal's Fourier Transform with the filter's frequency response.

This is based on the convolution theorem, which states that convolution in the spatial/temporal domain is equivalent to multiplication in the frequency domain:

$$f * g \Leftrightarrow F \cdot G$$
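
A minimal numerical check of the convolution theorem, assuming two short random signals (the names are illustrative). For the DFT, the theorem holds for circular convolution, where indices wrap around. Ordinary (linear) convolution is obtained by zero-padding both signals to the full output length first.

```python
import numpy as np

rng = np.random.default_rng(1)
f = rng.standard_normal(32)
g = rng.standard_normal(32)
N = len(f)

# Circular convolution computed directly from its definition
direct = np.array(
    [sum(f[m] * g[(n - m) % N] for m in range(N)) for n in range(N)]
)
# The same result via multiplication in the frequency domain
via_fft = np.real(np.fft.ifft(np.fft.fft(f) * np.fft.fft(g)))
circular_match = np.allclose(direct, via_fft)

# Linear convolution: zero-pad to length 2N - 1 before transforming
L = 2 * N - 1
linear_via_fft = np.real(np.fft.ifft(np.fft.fft(f, L) * np.fft.fft(g, L)))
linear_match = np.allclose(linear_via_fft, np.convolve(f, g))

print(circular_match, linear_match)
```

The wrap-around matters for images too: multiplying an image spectrum by a mask, as in the next cell, treats the image as if it were tiled periodically.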

Let's implement and visualize frequency domain filtering on an image:

In [None]:
def frequency_domain_filtering(image, filter_type, cutoff):
    """
    Apply filtering in the frequency domain.
    
    Parameters:
    -----------
    image : ndarray
        Input image (2D array)
    filter_type : str
        Type of filter ('lowpass', 'highpass', or 'bandpass')
    cutoff : float or tuple
        Cutoff as a fraction of the Nyquist frequency (0 to 1), or tuple of (low, high) for bandpass
    """
    # Get image dimensions
    rows, cols = image.shape
    
    # Create a meshgrid for the filter
    crow, ccol = rows // 2, cols // 2  # Center coordinates
    y, x = np.ogrid[-crow:rows-crow, -ccol:cols-ccol]
    
    # Radial frequency as a fraction of the Nyquist frequency along each axis
    # (0 at the center, 1 at the edge midpoints, about sqrt(2) at the corners)
    r = np.sqrt((x/ccol)**2 + (y/crow)**2)
    
    # Create the filter
    if filter_type == 'lowpass':
        # Ideal low-pass filter
        mask = r <= cutoff
        title = f'Low-pass Filter (cutoff={cutoff:.2f})'
    elif filter_type == 'highpass':
        # Ideal high-pass filter
        mask = r >= cutoff
        title = f'High-pass Filter (cutoff={cutoff:.2f})'
    elif filter_type == 'bandpass':
        # Ideal band-pass filter
        low, high = cutoff
        mask = (r >= low) & (r <= high)
        title = f'Band-pass Filter (cutoff={low:.2f}-{high:.2f})'
    else:
        raise ValueError(f"Unknown filter type: {filter_type}")
    
    # Compute the FFT of the image
    f = fft2(image)
    fshift = fftshift(f)
    
    # Apply the filter
    fshift_filtered = fshift * mask
    
    # Inverse FFT to get the filtered image
    f_ishift = np.fft.ifftshift(fshift_filtered)
    img_filtered = np.real(ifft2(f_ishift))
    
    # Visualize the results
    fig, axes = plt.subplots(2, 2, figsize=(12, 12))
    
    # Original image
    axes[0, 0].imshow(image, cmap='gray')
    axes[0, 0].set_title('Original Image')
    axes[0, 0].axis('off')
    
    # Filter in frequency domain
    axes[0, 1].imshow(mask, cmap='gray')
    axes[0, 1].set_title(title)
    axes[0, 1].axis('off')
    
    # Magnitude spectrum after filtering
    magnitude = np.log1p(np.abs(fshift_filtered))
    axes[1, 0].imshow(magnitude, cmap='viridis')
    axes[1, 0].set_title('Filtered Magnitude Spectrum (Log Scale)')
    axes[1, 0].axis('off')
    
    # Filtered image
    axes[1, 1].imshow(img_filtered, cmap='gray')
    axes[1, 1].set_title('Filtered Image')
    axes[1, 1].axis('off')
    
    plt.tight_layout()
    plt.show()
    
    return img_filtered

# Create a test image with multiple spatial frequencies
def create_test_image(size=256):
    """
    Create a test image with various spatial frequencies.
    """
    image = np.zeros((size, size))
    
    # Add gratings with different frequencies
    for freq, amp in [(4, 0.5), (8, 0.3), (16, 0.2), (32, 0.1)]:
        grating = amp * create_grating_image(size, freq, np.random.uniform(0, np.pi))
        image += grating
    
    # Normalize to [0, 1]
    image = (image - np.min(image)) / (np.max(image) - np.min(image))
    
    return image

# Create a test image
test_image = create_test_image(256)

# Apply low-pass filtering
_ = frequency_domain_filtering(test_image, 'lowpass', 0.2)

In [None]:
# Apply high-pass filtering
_ = frequency_domain_filtering(test_image, 'highpass', 0.2)

In [None]:
# Apply band-pass filtering
_ = frequency_domain_filtering(test_image, 'bandpass', (0.1, 0.3))

## 6. The Fourier Transform of Motion

One of the most important applications of the Fourier Transform in motion energy models is analyzing the frequency-domain representation of moving stimuli.

Moving patterns have a specific signature in the spatio-temporal frequency domain. Let's visualize this by creating a moving grating and computing its 3D Fourier Transform (two spatial dimensions and one temporal dimension).

In [None]:
def create_moving_grating(size, frames, spatial_freq, velocity, direction=0):
    """
    Create a moving grating.
    
    Parameters:
    -----------
    size : int
        Size of each frame (size x size)
    frames : int
        Number of frames
    spatial_freq : float
        Spatial frequency in cycles per image width
    velocity : float
        Velocity in pixels per frame
    direction : float
        Direction of motion in radians (0 for rightward, π/2 for upward)
    """
    # Create a 3D array to hold the moving grating
    grating = np.zeros((size, size, frames))
    
    # Create a meshgrid for spatial coordinates (one sample per pixel)
    x = np.arange(size) - size / 2
    y = np.arange(size) - size / 2
    X, Y = np.meshgrid(x, y)
    
    # Generate each frame
    for t in range(frames):
        # The phase advances by 2*pi*spatial_freq/size per pixel travelled, so
        # this moves the pattern by `velocity` pixels per frame
        phase = 2 * np.pi * spatial_freq * velocity * t / size
        
        # Rotate coordinates for direction
        X_rot = X * np.cos(direction) + Y * np.sin(direction)
        
        # Create the grating with phase advancing with time
        grating[:, :, t] = np.sin(2 * np.pi * spatial_freq * X_rot / size - phase)
    
    return grating

def visualize_moving_grating(grating):
    """
    Visualize a moving grating.
    
    Parameters:
    -----------
    grating : ndarray
        3D array (height x width x frames)
    """
    fig, ax = plt.subplots(figsize=(6, 6))
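    # origin='lower' makes +y point up, so direction=pi/2 is displayed as upward motion
    im = ax.imshow(grating[:, :, 0], cmap='gray', vmin=-1, vmax=1, origin='lower')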
    ax.set_title('Moving Grating')
    ax.axis('off')
    
    def update(frame):
        im.set_array(grating[:, :, frame])
        return [im]
    
    ani = animation.FuncAnimation(fig, update, frames=grating.shape[2], interval=50, blit=True)
    # Close the static figure so that only the animation is displayed
    plt.close(fig)
    return HTML(ani.to_jshtml())

def compute_3d_fft(grating):
    """
    Compute the 3D FFT of a moving grating and visualize slices.
    
    Parameters:
    -----------
    grating : ndarray
        3D array (height x width x frames)
    """
    # Compute the 3D FFT
    fft_result = fftn(grating)
    fft_shifted = np.fft.fftshift(fft_result)
    
    # Compute the magnitude
    magnitude = np.abs(fft_shifted)
    magnitude = np.log1p(magnitude)  # log scale for better visualization
    
    # Get the center coordinates
    h, w, d = magnitude.shape
    ch, cw, cd = h // 2, w // 2, d // 2
    
    # Create a figure to visualize the 3D FFT slices
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    
    # X-Y slice (at temporal frequency = 0)
    im1 = axes[0].imshow(magnitude[:, :, cd], cmap='viridis')
    axes[0].set_title('X-Y Slice (Spatial Frequencies)')
    axes[0].set_xlabel('X Frequency')
    axes[0].set_ylabel('Y Frequency')
    plt.colorbar(im1, ax=axes[0])
    
    # X-T slice (at y frequency = 0), transposed so that x frequency runs along
    # the horizontal axis and temporal frequency along the vertical axis
    im2 = axes[1].imshow(magnitude[ch, :, :].T, cmap='viridis')
    axes[1].set_title('X-T Slice (Spatiotemporal Frequencies)')
    axes[1].set_xlabel('X Frequency')
    axes[1].set_ylabel('Temporal Frequency')
    plt.colorbar(im2, ax=axes[1])
    
    # Y-T slice (at x frequency = 0), transposed in the same way
    im3 = axes[2].imshow(magnitude[:, cw, :].T, cmap='viridis')
    axes[2].set_title('Y-T Slice (Spatiotemporal Frequencies)')
    axes[2].set_xlabel('Y Frequency')
    axes[2].set_ylabel('Temporal Frequency')
    plt.colorbar(im3, ax=axes[2])
    
    plt.tight_layout()
    plt.show()

# Create a rightward moving grating
size = 128
frames = 32
spatial_freq = 8
velocity = 2
rightward_grating = create_moving_grating(size, frames, spatial_freq, velocity, direction=0)

# Visualize the moving grating
visualize_moving_grating(rightward_grating)

In [None]:
# Compute and visualize the 3D FFT of the rightward moving grating
compute_3d_fft(rightward_grating)

In [None]:
# Create an upward moving grating
upward_grating = create_moving_grating(size, frames, spatial_freq, velocity, direction=np.pi/2)

# Visualize the moving grating
visualize_moving_grating(upward_grating)

In [None]:
# Compute and visualize the 3D FFT of the upward moving grating
compute_3d_fft(upward_grating)

### Motion in the Frequency Domain

In the spatiotemporal frequency domain, the energy of a pattern translating at constant velocity lies on a tilted plane through the origin. The orientation of this plane corresponds to the direction and speed of motion. This is a key insight for understanding motion energy models.

A few important observations:

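1. A stationary pattern has energy only where the temporal frequency is zero, that is, on the spatial frequency axes.
2. Because the spectrum of a real-valued signal is symmetric about the origin, direction is signaled by a pair of opposite quadrants of the X-T frequency plane, not by its left or right half. Rightward and leftward motion occupy opposite quadrant pairs. With the sign convention of the DFT defined earlier, rightward motion places energy where the spatial and temporal frequencies have opposite signs.
3. The slope of the energy in the X-T frequency plane has a magnitude equal to the speed (temporal frequency = speed × spatial frequency). Faster motion corresponds to a steeper slope.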

**Physical interpretation of tilted planes:**
A pattern moving at constant velocity traces a slanted line in the space-time domain. When transformed to the frequency domain, this becomes a plane perpendicular to the space-time trajectory. This is why constant motion appears as a tilted plane in the frequency domain.
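
As a one-dimensional illustration of this geometry, the sketch below translates a random pattern at a constant integer velocity and checks where the power of its space-time spectrum lies. The setup is an assumption made for the check: the pattern wraps around circularly and the number of frames equals the number of pixels, so the spectrum falls exactly on frequency bins. In that case all power sits on the line $f_t = -v f_x$ (modulo aliasing), the 1D counterpart of the tilted plane.

```python
import numpy as np

n_pixels = n_frames = 64
velocity = 2                     # pixels per frame, rightward

rng = np.random.default_rng(0)
pattern = rng.standard_normal(n_pixels)

# Row t holds the pattern shifted right by velocity * t pixels (circularly)
movie = np.stack([np.roll(pattern, velocity * t) for t in range(n_frames)])

# 2D FFT over (time, space); axis 0 is temporal and axis 1 spatial frequency
power = np.abs(np.fft.fft2(movie)) ** 2

# Frequency indices kt and kx; the predicted line is kt = -velocity * kx (mod N)
KT, KX = np.meshgrid(np.arange(n_frames), np.arange(n_pixels), indexing="ij")
line_mask = (KT + velocity * KX) % n_frames == 0

fraction_on_line = power[line_mask].sum() / power.sum()
print(fraction_on_line)
```

Changing `velocity` changes the slope of the line, and negating it flips the line into the other pair of quadrants.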

This frequency-domain representation of motion has profound implications because it means that directional motion can be detected by filters that are selective for specific orientations in the spatiotemporal frequency domain. The visual system appears to use this principle to detect and analyze motion.

## 7. Connection to Motion Energy Models

How does the Fourier Transform relate to motion energy models?

Motion energy models use quadrature pairs of spatiotemporal filters to detect motion. These filters are designed to be selective for specific spatiotemporal frequencies, which correspond to specific directions and speeds of motion.

In the frequency domain, these filters have localized responses in the spatiotemporal frequency domain, allowing them to selectively respond to motion in specific directions and at specific speeds.

Let's visualize a simple motion energy filter in both the spatiotemporal domain and the frequency domain:

In [None]:
def create_motion_energy_filter(size, frames, spatial_freq, velocity, sigma_space, sigma_time):
    """
    Create a spatiotemporal filter tuned to a specific direction and speed of motion.
    
    Parameters:
    -----------
    size : int
        Spatial size of the filter
    frames : int
        Temporal size of the filter
    spatial_freq : float
        Spatial frequency in cycles per image width
    velocity : float
        Velocity in pixels per frame
    sigma_space : float
        Standard deviation of the spatial Gaussian envelope
    sigma_time : float
        Standard deviation of the temporal Gaussian envelope
    """
    # Create a 3D array to hold the filter
    filter_3d = np.zeros((size, size, frames))
    
    # Create a meshgrid for spatial coordinates
    x = np.linspace(-size/2, size/2, size)
    y = np.linspace(-size/2, size/2, size)
    X, Y = np.meshgrid(x, y)
    
    # Create a meshgrid for temporal coordinates
    t = np.linspace(-frames/2, frames/2, frames)
    
    # Generate the 3D filter
    for i, time in enumerate(t):
        # Phase depends on position and time: the carrier drifts by `velocity` pixels per frame
        phase = 2 * np.pi * spatial_freq * (X - velocity * time) / size
        
        # Spatial and temporal Gaussian envelopes
        space_env = np.exp(-(X**2 + Y**2) / (2 * sigma_space**2))
        time_env = np.exp(-time**2 / (2 * sigma_time**2))
        
        # Combine to form the 3D filter
        filter_3d[:, :, i] = np.cos(phase) * space_env * time_env
    
    return filter_3d

def visualize_motion_energy_filter(filter_3d):
    """
    Visualize a motion energy filter in both the spatiotemporal domain and the frequency domain.
    
    Parameters:
    -----------
    filter_3d : ndarray
        3D filter (height x width x frames)
    """
    # Get the midpoint coordinates
    h, w, d = filter_3d.shape
    mid_h, mid_w, mid_d = h // 2, w // 2, d // 2
    
    # Create a figure to visualize the filter
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    
    # Spatiotemporal domain slices
    # X-Y slice (at middle time)
    im1 = axes[0, 0].imshow(filter_3d[:, :, mid_d], cmap='RdBu', vmin=-1, vmax=1)
    axes[0, 0].set_title('X-Y Slice (Spatial)')
    plt.colorbar(im1, ax=axes[0, 0])
    
    # X-T slice (at middle y)
    im2 = axes[0, 1].imshow(filter_3d[mid_h, :, :], cmap='RdBu', vmin=-1, vmax=1)
    axes[0, 1].set_title('X-T Slice (Space-Time)')
    plt.colorbar(im2, ax=axes[0, 1])
    
    # Y-T slice (at middle x)
    im3 = axes[0, 2].imshow(filter_3d[:, mid_w, :], cmap='RdBu', vmin=-1, vmax=1)
    axes[0, 2].set_title('Y-T Slice (Space-Time)')
    plt.colorbar(im3, ax=axes[0, 2])
    
    # Compute the 3D FFT
    fft_result = fftn(filter_3d)
    fft_shifted = np.fft.fftshift(fft_result)
    
    # Compute the magnitude
    magnitude = np.abs(fft_shifted)
    magnitude = np.log1p(magnitude)  # log scale for better visualization
    
    # Frequency domain slices
    # X-Y slice (at middle temporal frequency)
    im4 = axes[1, 0].imshow(magnitude[:, :, mid_d], cmap='viridis')
    axes[1, 0].set_title('X-Y Slice (Spatial Frequencies)')
    plt.colorbar(im4, ax=axes[1, 0])
    
    # X-T slice (at middle y frequency)
    im5 = axes[1, 1].imshow(magnitude[mid_h, :, :], cmap='viridis')
    axes[1, 1].set_title('X-T Slice (Space-Time Frequencies)')
    plt.colorbar(im5, ax=axes[1, 1])
    
    # Y-T slice (at middle x frequency)
    im6 = axes[1, 2].imshow(magnitude[:, mid_w, :], cmap='viridis')
    axes[1, 2].set_title('Y-T Slice (Space-Time Frequencies)')
    plt.colorbar(im6, ax=axes[1, 2])
    
    plt.tight_layout()
    plt.show()

# Create a motion energy filter tuned to rightward motion
size = 64
frames = 32
spatial_freq = 4  # cycles per image width (0.2 would put less than one cycle inside the filter)
velocity = 1
sigma_space = 10
sigma_time = 8

rightward_filter = create_motion_energy_filter(size, frames, spatial_freq, velocity, sigma_space, sigma_time)

# Visualize the filter
visualize_motion_energy_filter(rightward_filter)

In [None]:
# Create a motion energy filter tuned to leftward motion
leftward_filter = create_motion_energy_filter(size, frames, spatial_freq, -velocity, sigma_space, sigma_time)

# Visualize the filter
visualize_motion_energy_filter(leftward_filter)
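
The filters visualized in these cells use only a cosine carrier, so each one is a single member of a quadrature pair. As a rough sketch of how a pair yields motion energy, the example below works in one spatial dimension plus time, builds an even (cosine) and an odd (sine) filter, and sums their squared responses. All names and parameter values are illustrative assumptions, not the course's implementation.

```python
import numpy as np

N_X, N_T = 64, 32            # samples in space and time (illustrative sizes)
FX, FT = 1 / 16, 1 / 16      # preferred frequencies: cycles/pixel, cycles/frame

# Space-time coordinates centred on the filter; rows index time, columns space
t_axis = np.arange(N_T) - N_T // 2
x_axis = np.arange(N_X) - N_X // 2
T, X = np.meshgrid(t_axis, x_axis, indexing="ij")

def quadrature_pair(direction, sigma_x=8.0, sigma_t=6.0):
    """Even and odd space-time Gabors; direction=+1 rightward, -1 leftward."""
    envelope = np.exp(-X**2 / (2 * sigma_x**2) - T**2 / (2 * sigma_t**2))
    phase = 2 * np.pi * (FX * X - direction * FT * T)
    return envelope * np.cos(phase), envelope * np.sin(phase)

def motion_energy(stimulus, direction):
    """Sum of the squared responses of the two filters in the pair."""
    even, odd = quadrature_pair(direction)
    return np.sum(stimulus * even) ** 2 + np.sum(stimulus * odd) ** 2

def drifting_grating(direction, phase=0.0):
    """Grating drifting at the filters' preferred frequencies."""
    return np.sin(2 * np.pi * (FX * X - direction * FT * T) + phase)

rightward = drifting_grating(+1)
preferred = motion_energy(rightward, +1)    # filter tuned to the stimulus direction
opposite = motion_energy(rightward, -1)     # filter tuned to the opposite direction
shifted = motion_energy(drifting_grating(+1, phase=np.pi / 2), +1)

print(f"preferred / opposite energy: {preferred / opposite:.1f}")
print(f"relative change with stimulus phase: {abs(shifted - preferred) / preferred:.2e}")
```

The energy is large only for the matching direction and barely changes with the phase of the grating, which is the property that squaring and summing a quadrature pair is meant to provide.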

## 8. Connections to Visual Processing in the Brain

The Fourier Transform and frequency-domain analysis have important connections to how the visual system processes information.

### Receptive Fields as Filters

Neurons in the visual cortex have receptive fields that can be modeled as spatiotemporal filters. Simple cells in V1 have receptive fields that resemble Gabor filters, which are localized in both space and frequency. Complex cells can be modeled as combinations of simple cells with different phases, which makes them more invariant to the exact position of features.

### Spatiotemporal Frequency Tuning

Neurons in the visual system are tuned to specific spatiotemporal frequencies. Some neurons prefer high spatial frequencies (fine details), while others prefer low spatial frequencies (coarse patterns). Similarly, some neurons prefer high temporal frequencies (rapid changes), while others prefer low temporal frequencies (slow changes).

### Direction Selectivity

Direction-selective neurons in areas like MT/V5 prefer motion in specific directions. This can be modeled as a preference for specific orientations in the spatiotemporal frequency domain. By combining responses from neurons with different spatiotemporal frequency tuning, the visual system can estimate the direction and speed of motion.

These connections highlight the importance of understanding the Fourier Transform and frequency-domain analysis for modeling motion perception in the brain.

## 9. Summary

In this section, we've explored the Fourier Transform and its applications to motion analysis. Here's a summary of what we've learned:

1. **Frequency Analysis**: We've learned about spatial and temporal frequencies and how they relate to visual stimuli.

2. **The Fourier Transform**: We've explored the mathematical foundation of the Fourier Transform and how it decomposes signals into their frequency components.

3. **Implementation**: We've implemented the Fourier Transform for 1D signals, 2D images, and 3D spatiotemporal data.

4. **Frequency Domain Filtering**: We've seen how filtering can be performed efficiently in the frequency domain using the convolution theorem.

5. **The Frequency Representation of Motion**: We've visualized how motion appears in the spatiotemporal frequency domain and how this relates to motion energy models.

6. **Connections to Visual Processing**: We've discussed how the Fourier Transform and frequency-domain analysis relate to how the visual system processes information.

Understanding the Fourier Transform and frequency-domain analysis is crucial for building and understanding motion energy models, which we'll explore in more detail in later sections.

In the exercises, you'll implement and apply these concepts to gain a deeper understanding of how they work and how they can be used for motion analysis.

## Further Reading

- Smith, S. W. (1997). The Scientist and Engineer's Guide to Digital Signal Processing. California Technical Publishing. [Available online](http://www.dspguide.com/)

- Oppenheim, A. V., & Schafer, R. W. (2009). Discrete-Time Signal Processing (3rd ed.). Prentice Hall.

- Watson, A. B., & Ahumada, A. J. (1985). Model of human visual-motion sensing. Journal of the Optical Society of America A, 2(2), 322-342.

- Adelson, E. H., & Bergen, J. R. (1985). Spatiotemporal energy models for the perception of motion. Journal of the Optical Society of America A, 2(2), 284-299.

- De Valois, R. L., & De Valois, K. K. (1990). Spatial Vision. Oxford University Press.