---

## Table of Contents

1. [Introduction – The Problem with FFT](#1-introduction--the-problem-with-fft)
2. [Intuition – Windows in Time](#2-intuition--windows-in-time)
3. [Short-Time Fourier Transform (STFT)](#3-short-time-fourier-transform-stft)
4. [Limitations of STFT](#4-limitations-of-stft)
5. [Introduction to Wavelets](#5-introduction-to-wavelets)
6. [The Morlet Wavelet](#6-the-morlet-wavelet)
7. [Wavelet Convolution](#7-wavelet-convolution)
8. [Time-Frequency Representation](#8-time-frequency-representation)
9. [Choosing Wavelet Parameters](#9-choosing-wavelet-parameters)
10. [Extracting Phase from Wavelets](#10-extracting-phase-from-wavelets)
11. [Wavelet vs Hilbert Approach](#11-wavelet-vs-hilbert-approach)
12. [Edge Effects](#12-edge-effects)
13. [Practical Application – Event-Related Time-Frequency](#13-practical-application--event-related-time-frequency)
14. [Hyperscanning Application – Time-Resolved Connectivity](#14-hyperscanning-application--time-resolved-connectivity)
15. [Exercises](#15-exercises)
16. [Summary](#16-summary)
17. [Discussion Questions](#17-discussion-questions)

In [None]:
# ============================================================================
# IMPORTS AND SETUP
# ============================================================================
import sys
from pathlib import Path
from typing import Tuple, Optional, Union

import numpy as np
from numpy.typing import NDArray
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
from scipy.signal import hilbert, stft, spectrogram
from scipy.fft import fft, ifft, fftfreq

# Add src to path for local imports
sys.path.insert(0, str(Path.cwd().parents[2]))

from src.colors import COLORS
from src.filtering import bandpass_filter

# Color aliases for convenience - following style guide
PRIMARY_BLUE = COLORS["signal_1"]      # Sky Blue
PRIMARY_RED = COLORS["signal_2"]       # Rose Pink
PRIMARY_GREEN = COLORS["signal_3"]     # Sage Green
SECONDARY_PURPLE = COLORS["high_sync"] # Purple
SECONDARY_ORANGE = COLORS["signal_4"]  # Golden (used as orange)
ACCENT_PURPLE = COLORS["signal_5"]     # Lavender
ACCENT_GOLD = COLORS["signal_4"]       # Golden

# Sampling frequency (standard EEG)
fs = 256  # Hz

# Set random seed for reproducibility
np.random.seed(42)

print("✓ Imports successful!")
print(f"NumPy version: {np.__version__}")

---

## 1. Introduction – The Problem with FFT

The **Fast Fourier Transform (FFT)** is a powerful tool that decomposes a signal into its frequency components. However, it has a fundamental limitation:

**FFT tells us WHICH frequencies are present, but not WHEN.**

This is problematic for EEG analysis because:

- Neural oscillations are **non-stationary**: they come and go
- Alpha bursts appear and disappear over seconds
- Cognitive states change during an experiment
- Social interactions involve dynamic synchronization

A single FFT summarizes the whole recording at once. Its spectrum describes the signal faithfully at every moment only if the signal is **stationary** (same statistics throughout), which is rarely true for brain signals.

**Solution**: We need methods that provide BOTH time AND frequency information. This is **time-frequency analysis**.
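You can check that the magnitude spectrum carries no timing information by playing the same segments in a different order. The sketch below uses plain NumPy and a hypothetical two-segment signal (5 Hz, then 15 Hz). It reverses the signal in time and shows that the magnitude spectrum is unchanged. The FFT magnitude alone therefore cannot tell which segment came first. Strictly speaking, the order is still encoded in the FFT's phase, but the phase is hard to read directly.

```python
import numpy as np

fs = 256                              # Hz
t = np.arange(0, 2.0, 1 / fs)         # 2 s of samples
first_half = t < 1.0

# Hypothetical signal: 5 Hz during the first second, 15 Hz during the second
forward = np.where(first_half,
                   np.sin(2 * np.pi * 5 * t),
                   np.sin(2 * np.pi * 15 * t))
backward = forward[::-1]              # same content, opposite order in time

# One-sided magnitude spectra of both versions
mag_forward = np.abs(np.fft.rfft(forward))
mag_backward = np.abs(np.fft.rfft(backward))

print("Identical magnitude spectra:", np.allclose(mag_forward, mag_backward))
print("Identical signals:          ", np.allclose(forward, backward))
```

Reversing a real signal only conjugates its DFT, up to a linear phase factor. The magnitudes are therefore exactly equal, even though the two signals clearly differ in time.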

In [None]:
# ============================================================================
# VISUALIZATION 1: FFT Loses Temporal Information
# ============================================================================

# Create a non-stationary signal: frequency changes over time
duration = 9.0  # seconds
t = np.arange(0, duration, 1/fs)

# Three segments with different frequencies
# 0-3s: 5 Hz, 3-6s: 15 Hz, 6-9s: 10 Hz
segment1 = np.sin(2 * np.pi * 5 * t[t < 3])
segment2 = np.sin(2 * np.pi * 15 * t[(t >= 3) & (t < 6)])
segment3 = np.sin(2 * np.pi * 10 * t[t >= 6])

# Concatenate
signal_nonstat = np.concatenate([segment1, segment2, segment3])

# Add small noise
signal_nonstat += 0.1 * np.random.randn(len(signal_nonstat))

# Compute FFT
n = len(signal_nonstat)
fft_result = fft(signal_nonstat)
frequencies = fftfreq(n, 1/fs)
magnitude = np.abs(fft_result) / n

# Only positive frequencies
pos_mask = frequencies >= 0
freq_pos = frequencies[pos_mask]
mag_pos = 2 * magnitude[pos_mask]  # Double for one-sided

# Create figure
fig, axes = plt.subplots(2, 1, figsize=(14, 8))

# Top: Time domain
axes[0].plot(t, signal_nonstat, color=PRIMARY_BLUE, linewidth=0.8)
axes[0].axvline(x=3, color='red', linestyle='--', linewidth=1.5, alpha=0.7)
axes[0].axvline(x=6, color='red', linestyle='--', linewidth=1.5, alpha=0.7)
axes[0].annotate('5 Hz', xy=(1.5, 1.2), ha='center', fontsize=12, fontweight='bold', color=COLORS["theta"])
axes[0].annotate('15 Hz', xy=(4.5, 1.2), ha='center', fontsize=12, fontweight='bold', color=COLORS["beta"])
axes[0].annotate('10 Hz', xy=(7.5, 1.2), ha='center', fontsize=12, fontweight='bold', color=COLORS["alpha"])
axes[0].set_xlabel('Time (s)', fontsize=12)
axes[0].set_ylabel('Amplitude', fontsize=12)
axes[0].set_title('Time Domain: We Can See WHEN Each Frequency Occurs', fontsize=13, fontweight='bold')
axes[0].set_xlim(0, 9)
axes[0].grid(True, alpha=0.3)

# Bottom: Frequency domain (FFT)
axes[1].plot(freq_pos, mag_pos, color=SECONDARY_PURPLE, linewidth=1.5)
axes[1].axvline(x=5, color=COLORS["theta"], linestyle='--', linewidth=2, label='5 Hz')
axes[1].axvline(x=10, color=COLORS["alpha"], linestyle='--', linewidth=2, label='10 Hz')
axes[1].axvline(x=15, color=COLORS["beta"], linestyle='--', linewidth=2, label='15 Hz')
axes[1].set_xlabel('Frequency (Hz)', fontsize=12)
axes[1].set_ylabel('Magnitude', fontsize=12)
axes[1].set_title('FFT Spectrum: Shows WHICH Frequencies, But Not WHEN!', fontsize=13, fontweight='bold', color='red')
axes[1].set_xlim(0, 30)
axes[1].legend(loc='upper right')
axes[1].grid(True, alpha=0.3)

# Add annotation
axes[1].annotate('All 3 frequencies visible,\nbut timing is lost!', 
                 xy=(20, mag_pos.max()*0.7), fontsize=11, ha='center',
                 bbox=dict(boxstyle='round', facecolor='lightyellow', edgecolor='orange'))

plt.tight_layout()
plt.show()

print("→ FFT reveals all three frequencies (5, 10, 15 Hz)")
print("→ But we can't tell that 5 Hz was first, 15 Hz second, 10 Hz last!")
print("→ For non-stationary signals, FFT alone is insufficient.")

---

## 2. Intuition – Windows in Time

The solution is intuitive: **compute the FFT on SHORT windows of the signal**, then slide the window through time.

This gives us the frequency content of each window, so we can follow how the spectrum changes over time!

However, there's a fundamental **trade-off**:

| Window Size | Time Resolution | Frequency Resolution |
|-------------|-----------------|---------------------|
| **Short** | Good (precise timing) | Poor (frequencies blur together) |
| **Long** | Poor (timing uncertain) | Good (frequencies well-separated) |

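This is the **time-frequency uncertainty principle** (also called the Gabor limit). It is the signal-processing counterpart of Heisenberg's uncertainty principle:

$$\Delta t \cdot \Delta f \geq \frac{1}{4\pi}$$

Here $\Delta t$ and $\Delta f$ are the RMS widths (standard deviations) of the signal's energy in time and in frequency, with frequency in Hz. Only a Gaussian-shaped pulse reaches the bound exactly.

**You cannot have perfect resolution in BOTH time AND frequency simultaneously.**

This is not a limitation of our methods; it is a fundamental property of signals!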
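The constant can be made concrete. For RMS widths (standard deviations of the energy distributions $|x(t)|^2$ and $|X(f)|^2$, with $f$ in Hz), the bound is $1/(4\pi) \approx 0.0796$. Only a Gaussian pulse reaches it. The sketch below checks this numerically. The pulse width `sigma` is an arbitrary illustrative value, and `rms_width` is a hypothetical helper.

```python
import numpy as np

def rms_width(axis, energy):
    """Standard deviation of a (non-normalized) energy distribution along `axis`."""
    p = energy / energy.sum()
    mean = np.sum(p * axis)
    return np.sqrt(np.sum(p * (axis - mean) ** 2))

fs = 256.0                        # Hz
t = np.arange(-4, 4, 1 / fs)      # 8 s of samples centered on 0
sigma = 0.1                       # s, illustrative pulse width
pulse = np.exp(-t**2 / (2 * sigma**2))

# Centered spectrum and matching frequency axis
spectrum = np.fft.fftshift(np.fft.fft(pulse))
freqs = np.fft.fftshift(np.fft.fftfreq(len(t), 1 / fs))

delta_t = rms_width(t, np.abs(pulse) ** 2)
delta_f = rms_width(freqs, np.abs(spectrum) ** 2)

print(f"delta_t = {delta_t * 1000:.1f} ms, delta_f = {delta_f:.2f} Hz")
print(f"product = {delta_t * delta_f:.4f}  (bound 1/(4*pi) = {1 / (4 * np.pi):.4f})")
```

Making `sigma` smaller shrinks `delta_t` and widens `delta_f` by the same factor, so the product stays at the bound.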

In [None]:
# ============================================================================
# VISUALIZATION 2: Time-Frequency Trade-off with Different Window Sizes
# ============================================================================

from scipy.signal import spectrogram as scipy_spectrogram

# Use the non-stationary signal from before
window_sizes = [0.2, 1.0, 3.0]  # seconds
titles = ['Short Window (0.2s)\nGood time, poor frequency', 
          'Medium Window (1.0s)\nBalanced',
          'Long Window (3.0s)\nGood frequency, poor time']

fig, axes = plt.subplots(1, 3, figsize=(16, 5))

for idx, (win_sec, title) in enumerate(zip(window_sizes, titles)):
    nperseg = int(win_sec * fs)
    
    # Compute spectrogram
    f, t_spec, Sxx = scipy_spectrogram(signal_nonstat, fs=fs, nperseg=nperseg, 
                                        noverlap=nperseg//2)
    
    # Plot
    im = axes[idx].pcolormesh(t_spec, f, 10*np.log10(Sxx + 1e-10), 
                               shading='gouraud', cmap='viridis')
    axes[idx].set_ylim(0, 25)
    axes[idx].set_xlabel('Time (s)', fontsize=11)
    axes[idx].set_ylabel('Frequency (Hz)', fontsize=11)
    axes[idx].set_title(title, fontsize=11, fontweight='bold')
    
    # Mark true frequencies
    axes[idx].axhline(y=5, color='white', linestyle='--', alpha=0.5)
    axes[idx].axhline(y=10, color='white', linestyle='--', alpha=0.5)
    axes[idx].axhline(y=15, color='white', linestyle='--', alpha=0.5)
    
    # Mark transitions
    axes[idx].axvline(x=3, color='red', linestyle='--', alpha=0.5)
    axes[idx].axvline(x=6, color='red', linestyle='--', alpha=0.5)
    
    plt.colorbar(im, ax=axes[idx], label='Power (dB)')

plt.suptitle('Visualization 2: The Time-Frequency Trade-off', 
             fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()

print("Observations:")
print("- Short window (left): Transitions are sharp, but frequencies blur vertically")
print("- Medium window (center): Reasonable balance")
print("- Long window (right): Frequencies are distinct, but transitions are smeared")
print("\n→ This is the Heisenberg uncertainty principle in action!")

---

## 3. Short-Time Fourier Transform (STFT)

The **Short-Time Fourier Transform (STFT)** formalizes the windowed FFT approach:

1. Apply a **window function** (Hann, Hamming) to a segment of the signal
2. Compute FFT of the windowed segment
3. Slide the window forward and repeat
4. Result: 2D matrix of complex coefficients (frequency × time)

**Key parameters**:
- `nperseg`: Window length in samples → sets the frequency resolution (bin spacing $\Delta f = f_s / \text{nperseg}$)
- `noverlap`: Overlap between windows → sets the time step between successive windows, $(\text{nperseg} - \text{noverlap}) / f_s$ seconds
- `window`: Window type (tapers such as Hann reduce spectral leakage)

The **spectrogram** is the squared magnitude of STFT: $|\text{STFT}(t, f)|^2$
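The four steps can be written out directly in NumPy. This is a minimal sketch, not a replacement for `scipy.signal.stft`. The helper name `manual_stft` is hypothetical. It uses NumPy's symmetric Hann window (`np.hanning`) where SciPy defaults to a periodic one, and it skips SciPy's boundary padding and amplitude scaling. Its values and frame positions therefore differ from SciPy's. It does show where the resolution comes from:

- Frequency bins are $f_s / \text{nperseg}$ Hz apart.
- Frames are $(\text{nperseg} - \text{noverlap}) / f_s$ seconds apart.

```python
import numpy as np

def manual_stft(x, fs, nperseg, noverlap):
    """Hypothetical minimal STFT: window -> FFT -> slide (no boundary padding)."""
    hop = nperseg - noverlap
    window = np.hanning(nperseg)                       # 1. taper each segment
    n_frames = 1 + (len(x) - nperseg) // hop
    frames = np.stack([x[i * hop:i * hop + nperseg] * window
                       for i in range(n_frames)], axis=1)
    spectra = np.fft.rfft(frames, axis=0)              # 2. FFT of every windowed segment
    freqs = np.fft.rfftfreq(nperseg, 1 / fs)           # bin spacing = fs / nperseg
    times = (np.arange(n_frames) * hop + nperseg / 2) / fs  # 3. window centers (slide)
    return freqs, times, spectra                       # 4. complex (frequency x time) matrix

fs = 256
t = np.arange(0, 2.0, 1 / fs)
x = np.sin(2 * np.pi * 10 * t)
freqs, times, spectra = manual_stft(x, fs, nperseg=256, noverlap=128)

print("matrix shape (freq x time):", spectra.shape)
print("bin spacing:", freqs[1] - freqs[0], "Hz; frame step:", times[1] - times[0], "s")
print("peak frequency per frame:", freqs[np.argmax(np.abs(spectra), axis=0)])
```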

In [None]:
# ============================================================================
# FUNCTION 1: compute_stft
# ============================================================================

def compute_stft(
    signal: NDArray[np.floating],
    fs: float,
    nperseg: int = 256,
    noverlap: Optional[int] = None,
    window: str = "hann"
) -> Tuple[NDArray[np.floating], NDArray[np.floating], NDArray[np.complexfloating]]:
    """
    Compute Short-Time Fourier Transform of a signal.
    
    Parameters
    ----------
    signal : NDArray[np.floating]
        Input signal.
    fs : float
        Sampling frequency in Hz.
    nperseg : int, optional
        Window length in samples (default: 256).
    noverlap : int, optional
        Overlap between windows (default: nperseg // 2).
    window : str, optional
        Window type (default: 'hann').
    
    Returns
    -------
    frequencies : NDArray[np.floating]
        Frequency values in Hz.
    times : NDArray[np.floating]
        Time values in seconds.
    stft_matrix : NDArray[np.complexfloating]
        Complex STFT matrix (frequency × time).
    """
    if noverlap is None:
        noverlap = nperseg // 2
    
    frequencies, times, stft_matrix = stft(signal, fs=fs, nperseg=nperseg,
                                            noverlap=noverlap, window=window)
    
    return frequencies, times, stft_matrix

In [None]:
# ============================================================================
# FUNCTION 2: compute_spectrogram
# ============================================================================

def compute_spectrogram(
    signal: NDArray[np.floating],
    fs: float,
    nperseg: int = 256,
    noverlap: Optional[int] = None,
    window: str = "hann"
) -> Tuple[NDArray[np.floating], NDArray[np.floating], NDArray[np.floating]]:
    """
    Compute power spectrogram (squared magnitude of STFT).
    
    Parameters
    ----------
    signal : NDArray[np.floating]
        Input signal.
    fs : float
        Sampling frequency in Hz.
    nperseg : int, optional
        Window length in samples (default: 256).
    noverlap : int, optional
        Overlap between windows (default: nperseg // 2).
    window : str, optional
        Window type (default: 'hann').
    
    Returns
    -------
    frequencies : NDArray[np.floating]
        Frequency values in Hz.
    times : NDArray[np.floating]
        Time values in seconds.
    power : NDArray[np.floating]
        Power spectrogram (frequency × time).
    """
    frequencies, times, stft_matrix = compute_stft(signal, fs, nperseg, 
                                                    noverlap, window)
    power = np.abs(stft_matrix) ** 2
    
    return frequencies, times, power

In [None]:
# ============================================================================
# VISUALIZATION 3: STFT Spectrogram
# ============================================================================

# Compute spectrogram with balanced parameters
nperseg_stft = int(1.0 * fs)  # 1-second window
freqs_stft, times_stft, power_stft = compute_spectrogram(signal_nonstat, fs, 
                                                          nperseg=nperseg_stft)

# Create figure
fig, axes = plt.subplots(2, 1, figsize=(14, 9), gridspec_kw={'height_ratios': [1, 2]})

# Top: Time domain signal
axes[0].plot(t, signal_nonstat, color=PRIMARY_BLUE, linewidth=0.8)
axes[0].axvline(x=3, color='red', linestyle='--', linewidth=1.5, alpha=0.7)
axes[0].axvline(x=6, color='red', linestyle='--', linewidth=1.5, alpha=0.7)
axes[0].set_ylabel('Amplitude', fontsize=11)
axes[0].set_title('Non-Stationary Signal', fontsize=12, fontweight='bold')
axes[0].set_xlim(0, 9)
axes[0].grid(True, alpha=0.3)

# Bottom: Spectrogram
im = axes[1].pcolormesh(times_stft, freqs_stft, 10*np.log10(power_stft + 1e-10),
                         shading='gouraud', cmap='viridis')
axes[1].set_xlabel('Time (s)', fontsize=12)
axes[1].set_ylabel('Frequency (Hz)', fontsize=12)
axes[1].set_title('STFT Spectrogram: Now We See BOTH Time AND Frequency!', 
                  fontsize=12, fontweight='bold', color='green')
axes[1].set_ylim(0, 25)

# Mark transitions
axes[1].axvline(x=3, color='red', linestyle='--', linewidth=1.5, alpha=0.7)
axes[1].axvline(x=6, color='red', linestyle='--', linewidth=1.5, alpha=0.7)

# Mark frequencies
axes[1].axhline(y=5, color='white', linestyle=':', alpha=0.5)
axes[1].axhline(y=10, color='white', linestyle=':', alpha=0.5)
axes[1].axhline(y=15, color='white', linestyle=':', alpha=0.5)

# Annotate
axes[1].annotate('5 Hz', xy=(1.5, 5), color='white', fontweight='bold', fontsize=11, ha='center')
axes[1].annotate('15 Hz', xy=(4.5, 15), color='white', fontweight='bold', fontsize=11, ha='center')
axes[1].annotate('10 Hz', xy=(7.5, 10), color='white', fontweight='bold', fontsize=11, ha='center')

cbar = plt.colorbar(im, ax=axes[1])
cbar.set_label('Power (dB)', fontsize=11)

plt.tight_layout()
plt.show()

print("✓ The spectrogram clearly shows:")
print("  - 5 Hz from 0-3s")
print("  - 15 Hz from 3-6s")
print("  - 10 Hz from 6-9s")
print("→ We now have BOTH time and frequency information!")

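## 4. Limitations of STFT

The STFT analyzes every frequency with the same window, so every time-frequency "tile" has the same size. Each tile spans roughly $\Delta t = \text{nperseg} / f_s$ in time and $\Delta f = f_s / \text{nperseg}$ in frequency. This creates a conflict:

- A window long enough to separate nearby low frequencies is needlessly long for brief high-frequency events.
- A window short enough to time those events blurs the low frequencies together.

A chirp whose frequency rises from 5 Hz to 40 Hz makes this visible.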

In [None]:
# Visualization 4: STFT resolution trade-off comparison
# Same signal analyzed with different window sizes

# Create a chirp signal (frequency increases over time)
duration = 4.0  # Longer signal to accommodate large windows
t = np.linspace(0, duration, int(fs * duration), endpoint=False)
# Chirp from 5 Hz to 40 Hz
chirp = np.sin(2 * np.pi * (5 * t + (40 - 5) / (2 * duration) * t**2))
chirp += np.random.randn(len(t)) * 0.1  # Small noise

# Compare different window sizes (adjusted for fs=256)
window_sizes = [32, 128, 512]  # in samples
window_labels = ['32 samples\n(High time res)', '128 samples\n(Balanced)', '512 samples\n(High freq res)']

fig, axes = plt.subplots(2, 3, figsize=(12, 6))

# Top row: Spectrograms
for idx, (nperseg, label) in enumerate(zip(window_sizes, window_labels)):
    noverlap = nperseg // 2 if nperseg > 2 else 0
    f, t_spec, Sxx = spectrogram(chirp, fs=fs, nperseg=nperseg, noverlap=noverlap)
    
    # Limit to 0-50 Hz
    freq_mask = f <= 50
    
    ax = axes[0, idx]
    im = ax.pcolormesh(t_spec, f[freq_mask], 10 * np.log10(Sxx[freq_mask] + 1e-10),
                       shading='gouraud', cmap='viridis')
    
    # Overlay true frequency trajectory
    true_freq = 5 + (40 - 5) / duration * t_spec
    ax.plot(t_spec, true_freq, color=SECONDARY_ORANGE, linewidth=2, 
            linestyle='--', label='True frequency')
    
    ax.set_ylabel('Frequency (Hz)' if idx == 0 else '')
    ax.set_xlabel('Time (s)')
    ax.set_title(label, fontsize=11)
    ax.set_ylim([0, 50])
    if idx == 0:
        ax.legend(loc='upper left', fontsize=8)  # show the 'True frequency' label once

# Bottom row: Time-frequency resolution boxes
for idx, (nperseg, label) in enumerate(zip(window_sizes, window_labels)):
    ax = axes[1, idx]
    
    # Calculate resolution
    time_res = nperseg / fs  # seconds
    freq_res = fs / nperseg  # Hz
    
    # Draw resolution boxes at different frequencies
    frequencies = [10, 20, 30]
    times = [1.0, 2.0, 3.0]
    
    for f_center, t_center in zip(frequencies, times):
        # All boxes have same size (STFT limitation)
        rect = plt.Rectangle((t_center - time_res/2, f_center - freq_res/2),
                             time_res, freq_res, 
                             fill=False, edgecolor=PRIMARY_BLUE, linewidth=2)
        ax.add_patch(rect)
    
    ax.set_xlim([0, 4])
    ax.set_ylim([0, 50])
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Frequency (Hz)' if idx == 0 else '')
    ax.set_title(f'Δt={time_res*1000:.0f}ms, Δf={freq_res:.1f}Hz', fontsize=10)
    ax.grid(True, alpha=0.3)

fig.suptitle('STFT Resolution Trade-off: Fixed Time-Frequency Boxes', fontsize=13, fontweight='bold')
plt.tight_layout()
plt.show()

print("📊 Key insight: In STFT, all frequencies use the same resolution box!")
print("   This is inefficient: high frequencies need good time resolution,")
print("   while low frequencies need good frequency resolution.")
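The inefficiency can be put in numbers. With a fixed window of $T$ seconds, an oscillation at $f$ Hz fits $f \cdot T$ cycles into each window. A single window length therefore gives very different amounts of data per frequency. The sketch below reuses the three window lengths from the figure above ($f_s = 256$ Hz) and a few illustrative frequencies from the chirp's 5-40 Hz range.

```python
import numpy as np

fs = 256                                    # Hz, as in this notebook
window_sizes = np.array([32, 128, 512])     # samples, as in the figure above
test_freqs = np.array([5, 10, 20, 40])      # Hz, spanning the chirp

window_sec = window_sizes / fs              # window length in seconds
cycles = np.outer(window_sec, test_freqs)   # rows: windows, columns: frequencies

for n, sec, row in zip(window_sizes, window_sec, cycles):
    summary = ", ".join(f"{f} Hz: {c:.2f} cycles" for f, c in zip(test_freqs, row))
    print(f"{n:>4} samples ({sec * 1000:.0f} ms) -> {summary}")
```

The shortest window holds less than one cycle of a 5 Hz oscillation. The longest holds 80 cycles at 40 Hz, far more than needed to identify that frequency.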

## 5. Introduction to Wavelets 🌊

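A **wavelet** is a brief wave-like oscillation that:
- Starts at (or near) zero amplitude
- Grows in amplitude
- Returns to zero
- Is localized in time (finite, or effectively finite, duration)

Unlike sines and cosines, which extend infinitely, wavelets are **localized**. Some families (e.g., Haar) are strictly zero outside a finite interval. Gaussian-based ones such as Morlet decay so quickly that they are effectively finite.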

### Why Wavelets for EEG?

| Fourier (Sines) | Wavelets |
|-----------------|----------|
| Infinite duration | Finite duration |
| Perfect frequency localization | Good frequency localization |
| No time localization | Good time localization |
| Same resolution everywhere | Adaptive resolution |

### The Multi-Resolution Principle

Wavelets solve the resolution trade-off elegantly:
- **High frequencies** → Short wavelets → Good time resolution
- **Low frequencies** → Long wavelets → Good frequency resolution

This matches what we need in EEG:
- Fast gamma bursts need precise timing
- Slow alpha oscillations need precise frequency
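For the Morlet wavelet introduced in the next section, this adaptation takes a simple form. Assume the usual parameterization by a fixed number of cycles $n$:

- In time, $\sigma_t = n / (2\pi f)$.
- In frequency, $\sigma_f = 1/(2\pi\sigma_t) = f/n$ (the standard deviation of the wavelet's Gaussian amplitude spectrum).

The temporal width then shrinks, and the spectral width grows, in proportion to frequency, while $\sigma_f / f$ stays constant. The numbers below are a sketch for $n = 5$ and a few illustrative EEG-band frequencies.

```python
import numpy as np

n_cycles = 5
freqs = np.array([4.0, 10.0, 20.0, 40.0])   # Hz: theta, alpha, beta, gamma examples

sigma_t = n_cycles / (2 * np.pi * freqs)    # temporal std of the Gaussian envelope (s)
sigma_f = 1 / (2 * np.pi * sigma_t)         # spectral std (Hz), equals freqs / n_cycles

for f, st, sf in zip(freqs, sigma_t, sigma_f):
    print(f"{f:5.1f} Hz: sigma_t = {st * 1000:6.1f} ms, sigma_f = {sf:4.1f} Hz, "
          f"sigma_f / f = {sf / f:.2f}")
```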

In [None]:
# Visualization 5: Gallery of common wavelets
# Show different wavelet families and their properties

fig, axes = plt.subplots(2, 3, figsize=(12, 6))

# Parameters
n_points = 256
t_wavelet = np.linspace(-4, 4, n_points)  # same ±4 a.u. grid that create_morlet/create_ricker use below

# Row 1: Different wavelet types
# Morlet wavelet (complex - we show real part)
def create_morlet(n_points: int, width: float, w: float = 5.0) -> NDArray:
    """Create a complex Morlet wavelet."""
    t = np.linspace(-width, width, n_points)
    gaussian = np.exp(-t**2 / 2)
    oscillation = np.exp(1j * w * t)
    return gaussian * oscillation

morlet_wav = create_morlet(n_points, 4, w=5.0)

axes[0, 0].plot(t_wavelet, np.real(morlet_wav), color=PRIMARY_BLUE, linewidth=2, label='Real')
axes[0, 0].plot(t_wavelet, np.imag(morlet_wav), color=PRIMARY_RED, linewidth=2, alpha=0.7, label='Imaginary')
axes[0, 0].plot(t_wavelet, np.abs(morlet_wav), color=PRIMARY_GREEN, linewidth=2, linestyle='--', label='Envelope')
axes[0, 0].set_title('Morlet Wavelet', fontsize=11, fontweight='bold')
axes[0, 0].legend(loc='upper right', fontsize=8)
axes[0, 0].set_xlabel('Time (a.u.)')

# Mexican hat (Ricker) wavelet - manual implementation
def create_ricker(n_points: int, sigma: float) -> NDArray:
    """Create a Ricker (Mexican hat) wavelet."""
    t = np.linspace(-4, 4, n_points)
    A = 2 / (np.sqrt(3 * sigma) * np.pi**0.25)
    return A * (1 - (t/sigma)**2) * np.exp(-t**2 / (2 * sigma**2))

ricker_wav = create_ricker(n_points, 1.0)
axes[0, 1].plot(t_wavelet, ricker_wav, color=SECONDARY_PURPLE, linewidth=2)
axes[0, 1].set_title('Mexican Hat (Ricker)', fontsize=11, fontweight='bold')
axes[0, 1].set_xlabel('Time (a.u.)')

# Haar wavelet (simple step function)
haar = np.zeros(n_points)
haar[n_points//4:n_points//2] = 1
haar[n_points//2:3*n_points//4] = -1
axes[0, 2].plot(t_wavelet, haar, color=SECONDARY_ORANGE, linewidth=2)
axes[0, 2].set_title('Haar Wavelet', fontsize=11, fontweight='bold')
axes[0, 2].set_xlabel('Time (a.u.)')

# Row 2: Morlet at different frequencies (scales)
frequencies = [5, 15, 30]  # Hz
colors = [PRIMARY_BLUE, PRIMARY_GREEN, PRIMARY_RED]

for idx, (freq, color) in enumerate(zip(frequencies, colors)):
    # Scale wavelet duration inversely with frequency:
    # a fixed number of cycles means higher frequency = shorter wavelet.
    # Same convention as Section 6: sigma_t = n_cycles / (2*pi*f)
    n_cycles = 5
    sigma_t = n_cycles / (2 * np.pi * freq)
    half_width = 4 * sigma_t
    t_wav = np.linspace(-half_width, half_width, n_points)
    
    # Create Morlet-like (real) wavelet: Gaussian envelope times cosine
    gaussian_env = np.exp(-t_wav**2 / (2 * sigma_t**2))
    wavelet = gaussian_env * np.cos(2 * np.pi * freq * t_wav)
    
    axes[1, idx].plot(t_wav, wavelet, color=color, linewidth=2)
    axes[1, idx].fill_between(t_wav, -gaussian_env, gaussian_env, alpha=0.2, color=color)
    axes[1, idx].set_title(f'{freq} Hz Morlet\n(σ_t = {sigma_t*1000:.0f} ms)', fontsize=10)
    axes[1, idx].set_xlabel('Time (s)')
    axes[1, idx].set_xlim([-0.7, 0.7])

# Common formatting
for ax in axes.flat:
    ax.set_ylabel('Amplitude')
    ax.axhline(y=0, color='gray', linestyle='-', alpha=0.3)
    ax.grid(True, alpha=0.3)

fig.suptitle('Wavelet Gallery: Different Types and Multi-Resolution Property', fontsize=13, fontweight='bold')
plt.tight_layout()
plt.show()

print("🌊 Key insight: Higher frequency wavelets are SHORTER!")
print("   This is the multi-resolution property that makes wavelets ideal for EEG.")

## 6. The Morlet Wavelet: Our Tool of Choice 🎯

For EEG analysis, the **complex Morlet wavelet** is the standard. It consists of:

1. **A cosine wave** at the target frequency (real part)
2. **A sine wave** at the target frequency (imaginary part)  
3. **A Gaussian envelope** that tapers the oscillation

$$\psi(t, f) = A \cdot e^{-\frac{t^2}{2\sigma_t^2}} \cdot e^{i 2\pi f t}$$

Where:
- $A$ is a normalization constant
- $\sigma_t$ is the temporal standard deviation (controls width)
- $f$ is the center frequency
- $i = \sqrt{-1}$ (imaginary unit)

### The n_cycles Parameter

The width of the Gaussian envelope is often expressed in **number of cycles**:

$$\sigma_t = \frac{n_{cycles}}{2\pi f}$$

- **More cycles** → Better frequency resolution, worse time resolution
- **Fewer cycles** → Better time resolution, worse frequency resolution

Typical values: 3-7 cycles (5-7 common for EEG)
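In the frequency domain, the Gaussian envelope becomes another Gaussian centered on $f$. Its amplitude spectrum has standard deviation $\sigma_f = 1/(2\pi\sigma_t) = f / n_{cycles}$. The power spectrum is the square of the amplitude spectrum, so its full width at half maximum is $2\sqrt{\ln 2}\,\sigma_f \approx 1.67\, f / n_{cycles}$.

The sketch below checks this numerically for one unnormalized wavelet built from the formula above. The heavy zero-padding (`n_fft`) only makes the measured width precise.

```python
import numpy as np

fs = 256.0
freq = 10.0
n_cycles = 5.0

# Complex Morlet from the formula above (normalization does not affect the width)
sigma_t = n_cycles / (2 * np.pi * freq)
t = np.arange(-1.0, 1.0, 1 / fs)
wavelet = np.exp(-t**2 / (2 * sigma_t**2)) * np.exp(2j * np.pi * freq * t)

# Power spectrum on a fine frequency grid (zero-padding only interpolates)
n_fft = 2**16
power = np.abs(np.fft.fft(wavelet, n=n_fft)) ** 2
f_axis = np.fft.fftfreq(n_fft, 1 / fs)
power /= power.max()

above_half = f_axis[power >= 0.5]
fwhm_measured = above_half.max() - above_half.min()
sigma_f = freq / n_cycles
fwhm_theory = 2 * np.sqrt(np.log(2)) * sigma_f

print(f"sigma_f = {sigma_f:.2f} Hz")
print(f"FWHM measured = {fwhm_measured:.3f} Hz, theory = {fwhm_theory:.3f} Hz")
```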

In [None]:
# Function 3: Create Morlet wavelet

def create_morlet_wavelet(
    frequency: float,
    fs: float,
    n_cycles: float = 5.0,
    return_time: bool = False
) -> Union[NDArray[np.complex128], Tuple[NDArray[np.complex128], NDArray[np.float64]]]:
    """
    Create a complex Morlet wavelet for a given frequency.
    
    The Morlet wavelet is a complex exponential modulated by a Gaussian
    envelope, making it ideal for time-frequency analysis.
    
    Parameters
    ----------
    frequency : float
        Center frequency of the wavelet in Hz.
    fs : float
        Sampling frequency in Hz.
    n_cycles : float, optional
        Number of cycles in the wavelet. Controls the trade-off between
        time and frequency resolution. Default is 5.0.
    return_time : bool, optional
        If True, also return the time vector. Default is False.
        
    Returns
    -------
    wavelet : ndarray of complex128
        Complex Morlet wavelet, normalized to unit energy.
    time : ndarray of float64, optional
        Time vector in seconds (only if return_time=True).
        
    Notes
    -----
    The wavelet extends 4 * sigma_t on each side of its center, where
    sigma_t = n_cycles / (2 * pi * frequency). Time points are spaced
    exactly 1 / fs apart, so the wavelet lies on the signal's sampling grid.
    
    Examples
    --------
    >>> wavelet = create_morlet_wavelet(10, 256, n_cycles=5)
    >>> len(wavelet)  # odd length, centered on t = 0
    165
    """
    # Calculate temporal standard deviation
    sigma_t = n_cycles / (2 * np.pi * frequency)
    
    # Wavelet half-width: 4 sigma on each side captures >99.99% of energy
    wavelet_duration = 4 * sigma_t
    
    # Create time vector on the exact sampling grid (step = 1/fs), centered on 0.
    # A linspace with a truncated sample count would use a step slightly
    # different from 1/fs and shift the wavelet's effective frequency.
    half_n = int(np.ceil(wavelet_duration * fs))
    time = np.arange(-half_n, half_n + 1) / fs
    
    # Create Gaussian envelope
    gaussian = np.exp(-time**2 / (2 * sigma_t**2))
    
    # Create complex sinusoid
    sinusoid = np.exp(2j * np.pi * frequency * time)
    
    # Combine to form Morlet wavelet
    wavelet = gaussian * sinusoid
    
    # Normalize to unit energy
    wavelet = wavelet / np.sqrt(np.sum(np.abs(wavelet)**2))
    
    if return_time:
        return wavelet, time
    return wavelet


# Test the function
test_wavelet, test_time = create_morlet_wavelet(10, 256, n_cycles=5, return_time=True)
print(f"✓ Created Morlet wavelet at 10 Hz")
print(f"  - Length: {len(test_wavelet)} samples")
print(f"  - Duration: {test_time[-1] - test_time[0]:.3f} s")
print(f"  - Energy: {np.sum(np.abs(test_wavelet)**2):.4f} (should be ~1.0)")

In [None]:
# Visualization 6: Morlet wavelet components

wavelet, time = create_morlet_wavelet(10, 256, n_cycles=5, return_time=True)

fig, axes = plt.subplots(2, 2, figsize=(12, 6))

# Real part (cosine)
axes[0, 0].plot(time, np.real(wavelet), color=PRIMARY_BLUE, linewidth=2)
axes[0, 0].fill_between(time, 0, np.real(wavelet), alpha=0.3, color=PRIMARY_BLUE)
axes[0, 0].set_title('Real Part (Cosine Component)', fontsize=11)
axes[0, 0].set_ylabel('Amplitude')
axes[0, 0].axhline(y=0, color='gray', linestyle='-', alpha=0.3)

# Imaginary part (sine)
axes[0, 1].plot(time, np.imag(wavelet), color=PRIMARY_RED, linewidth=2)
axes[0, 1].fill_between(time, 0, np.imag(wavelet), alpha=0.3, color=PRIMARY_RED)
axes[0, 1].set_title('Imaginary Part (Sine Component)', fontsize=11)
axes[0, 1].axhline(y=0, color='gray', linestyle='-', alpha=0.3)

# Magnitude (envelope)
axes[1, 0].plot(time, np.abs(wavelet), color=PRIMARY_GREEN, linewidth=2)
axes[1, 0].fill_between(time, 0, np.abs(wavelet), alpha=0.3, color=PRIMARY_GREEN)
axes[1, 0].set_title('Magnitude (Gaussian Envelope)', fontsize=11)
axes[1, 0].set_xlabel('Time (s)')
axes[1, 0].set_ylabel('Amplitude')

# Phase
phase = np.angle(wavelet)
axes[1, 1].plot(time, phase, color=SECONDARY_PURPLE, linewidth=2)
axes[1, 1].axhline(y=0, color='gray', linestyle='-', alpha=0.3)
axes[1, 1].set_title('Phase (Linear, Wrapped to ±π)', fontsize=11)
axes[1, 1].set_xlabel('Time (s)')
axes[1, 1].set_ylabel('Phase (radians)')
axes[1, 1].set_yticks([-np.pi, -np.pi/2, 0, np.pi/2, np.pi])
axes[1, 1].set_yticklabels(['-π', '-π/2', '0', 'π/2', 'π'])

for ax in axes.flat:
    ax.grid(True, alpha=0.3)

fig.suptitle('Anatomy of a Complex Morlet Wavelet (10 Hz, 5 cycles)', 
             fontsize=13, fontweight='bold')
plt.tight_layout()
plt.show()

print("🎯 The complex Morlet gives us:")
print("   - Amplitude via |wavelet| (the Gaussian envelope)")
print("   - Phase via angle(wavelet) (linear phase = constant frequency)")

In [None]:
# Visualization 7: Effect of n_cycles on wavelet properties

fig, axes = plt.subplots(2, 3, figsize=(12, 6))

n_cycles_values = [3, 5, 7]
frequency = 10  # Hz

# Top row: Wavelets in time domain
for idx, n_cycles in enumerate(n_cycles_values):
    wavelet, time = create_morlet_wavelet(frequency, 256, n_cycles=n_cycles, return_time=True)
    
    ax = axes[0, idx]
    ax.plot(time, np.real(wavelet), color=PRIMARY_BLUE, linewidth=2, label='Real')
    ax.plot(time, np.abs(wavelet), color=PRIMARY_GREEN, linewidth=2, linestyle='--', label='Envelope')
    
    # Mark ±2 sigma_t around the center (the full wavelet spans ±4 sigma_t)
    sigma_t = n_cycles / (2 * np.pi * frequency)
    ax.axvline(x=-2*sigma_t, color='gray', linestyle=':', alpha=0.7)
    ax.axvline(x=2*sigma_t, color='gray', linestyle=':', alpha=0.7)
    
    ax.set_title(f'n_cycles = {n_cycles}\n(σ_t = {sigma_t*1000:.1f} ms)', fontsize=10)
    ax.set_xlabel('Time (s)')
    if idx == 0:
        ax.set_ylabel('Amplitude')
        ax.legend(fontsize=8)
    ax.grid(True, alpha=0.3)
    ax.set_xlim([-0.4, 0.4])

# Bottom row: Frequency response (FFT of wavelet)
for idx, n_cycles in enumerate(n_cycles_values):
    wavelet, time = create_morlet_wavelet(frequency, 256, n_cycles=n_cycles, return_time=True)
    
    # Compute FFT
    n_fft = 1024
    fft_wavelet = np.fft.fft(wavelet, n=n_fft)
    freqs = np.fft.fftfreq(n_fft, 1/256)
    
    # Keep positive frequencies
    pos_mask = freqs >= 0
    freqs_pos = freqs[pos_mask]
    power = np.abs(fft_wavelet[pos_mask])**2
    power = power / power.max()  # Normalize
    
    ax = axes[1, idx]
    ax.plot(freqs_pos, power, color=SECONDARY_PURPLE, linewidth=2)
    ax.fill_between(freqs_pos, 0, power, alpha=0.3, color=SECONDARY_PURPLE)
    ax.axvline(x=frequency, color='gray', linestyle='--', alpha=0.7)
    
    # Calculate frequency resolution (FWHM)
    half_max = 0.5
    above_half = freqs_pos[power > half_max]
    if len(above_half) > 0:
        fwhm = above_half[-1] - above_half[0]
        ax.set_title(f'Freq. resolution: FWHM ≈ {fwhm:.1f} Hz', fontsize=10)
    
    ax.set_xlabel('Frequency (Hz)')
    if idx == 0:
        ax.set_ylabel('Power (normalized)')
    ax.set_xlim([0, 30])
    ax.grid(True, alpha=0.3)

fig.suptitle(f'Effect of n_cycles on Time-Frequency Resolution (f = {frequency} Hz)', 
             fontsize=13, fontweight='bold')
plt.tight_layout()
plt.show()

print("📊 Trade-off summary:")
print("   - Low n_cycles (3): Short wavelet, poor frequency resolution")
print("   - High n_cycles (7): Long wavelet, excellent frequency resolution")
print("   - Common choice: 5-7 cycles for EEG analysis")

## 7. Wavelet Convolution: How It Works ⚙️

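To extract time-frequency information, we **convolve** the signal with the wavelet at each frequency of interest:

$$W(t, f) = (s * \psi_f)(t) = \int s(\tau)\, \psi(t - \tau, f)\, d\tau$$

Where:
- $s(t)$ is our signal
- $\psi(t, f)$ is the complex Morlet wavelet at frequency $f$ (written $\psi_f$ when viewed as a function of time alone)
- $*$ denotes convolution

The Morlet wavelet is conjugate-symmetric: $\psi(-t, f) = \psi^*(t, f)$, where $^*$ denotes the complex conjugate. This convolution is therefore identical to the cross-correlation $\int s(\tau)\, \psi^*(\tau - t, f)\, d\tau$, the form in which the continuous wavelet transform is usually defined.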

### The Convolution Process

1. **Slide** the wavelet along the signal
2. At each time point, **multiply** signal × wavelet
3. **Sum** the products ‚Üí This gives the wavelet coefficient

The result $W(t, f)$ is a **complex number** at each time-frequency point:
- **Magnitude** $|W(t, f)|$ → Amplitude at that time-frequency point (its square $|W(t, f)|^2$ is the **power**)
- **Phase** $\angle W(t, f)$ → Phase at that time-frequency point
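A small sanity check makes this reading of magnitude and phase concrete. Convolve a pure cosine $A\cos(2\pi f t + \phi)$ with a Morlet wavelet at the same frequency. Away from the edges, the result is approximately

$$W(t) \approx \tfrac{A}{2}\, G\, e^{i(2\pi f t + \phi)},$$

where $G$ is the sum of the Gaussian envelope's samples. The angle of $W$ therefore tracks the cosine's phase.

The sketch below assumes a different normalization from the unit-energy one used elsewhere in this notebook: it scales the envelope to sum to 1, so that $2|W|$ recovers $A$ directly. The test amplitude and phase are arbitrary. The convolution uses `scipy.signal.fftconvolve`.

```python
import numpy as np
from scipy.signal import fftconvolve

fs = 256.0
freq = 10.0
n_cycles = 5.0
amp, phi = 2.0, 0.7                  # arbitrary test amplitude and phase (rad)

t = np.arange(0, 2.0, 1 / fs)
s = amp * np.cos(2 * np.pi * freq * t + phi)

# Morlet wavelet on the sampling grid; envelope scaled to sum to 1 (assumption)
sigma_t = n_cycles / (2 * np.pi * freq)
half_n = int(np.ceil(4 * sigma_t * fs))
tw = np.arange(-half_n, half_n + 1) / fs
envelope = np.exp(-tw**2 / (2 * sigma_t**2))
envelope /= envelope.sum()
psi = envelope * np.exp(2j * np.pi * freq * tw)

W = fftconvolve(s, psi, mode="same")   # complex result, same length as s

mid = len(t) // 2                      # a sample far from both edges
est_amp = 2 * np.abs(W[mid])
est_phase = np.angle(W[mid])
true_phase = np.angle(np.exp(1j * (2 * np.pi * freq * t[mid] + phi)))

print(f"amplitude: true {amp:.3f}, estimated {est_amp:.3f}")
print(f"phase:     true {true_phase:.3f} rad, estimated {est_phase:.3f} rad")
```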

### Efficient Implementation

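Sliding and multiplying directly is slow: it takes about $N \times M$ multiplications for a signal of $N$ samples and a wavelet of $M$ samples. Instead, we use the **convolution theorem**:

$$\mathcal{F}\{s * \psi\} = \mathcal{F}\{s\} \cdot \mathcal{F}\{\psi\}$$

Convolution in time = Multiplication in frequency → with the FFT this costs on the order of $(N + M)\log(N + M)$ operations, which is much faster for long signals. Both sequences must first be zero-padded to at least $N + M - 1$ samples. Otherwise, the product of their FFTs gives a *circular* convolution, in which the end of the result wraps around onto its beginning.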
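The convolution theorem, and the need for zero-padding, can be checked on a small example. The sketch below compares a direct "multiply and sum" loop against two FFT versions:

- an FFT product padded to $N + M - 1$ samples, which gives the exact linear convolution;
- an unpadded FFT product, which wraps around.

The helper `direct_convolution` is hypothetical, and the signal and kernel are random test data.

```python
import numpy as np

rng = np.random.default_rng(0)
s = rng.standard_normal(100)                                 # test signal, N = 100
w = rng.standard_normal(21) + 1j * rng.standard_normal(21)   # complex test kernel, M = 21

def direct_convolution(signal, kernel):
    """Full linear convolution by explicit shift-multiply-sum (O(N*M))."""
    out = np.zeros(len(signal) + len(kernel) - 1, dtype=complex)
    for i, value in enumerate(signal):
        out[i:i + len(kernel)] += value * kernel
    return out

n_conv = len(s) + len(w) - 1                  # 120 samples of linear convolution
n_fft = int(2 ** np.ceil(np.log2(n_conv)))    # next power of two: 128

# Padded FFT product: exact linear convolution
fft_conv = np.fft.ifft(np.fft.fft(s, n_fft) * np.fft.fft(w, n_fft))[:n_conv]

# Unpadded FFT product (length N): circular convolution, the tail wraps onto the start
circular = np.fft.ifft(np.fft.fft(s) * np.fft.fft(w, len(s)))

reference = direct_convolution(s, w)
print("padded FFT matches direct sum:  ", np.allclose(fft_conv, reference))
print("unpadded FFT matches direct sum:", np.allclose(circular, reference[:len(s)]))
```

Only the first $M - 1$ samples of the circular result are corrupted by wrap-around. From sample $M - 1$ onward it agrees with the linear convolution.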

In [None]:
# Function 4: Wavelet convolution (single frequency)

def wavelet_convolution(
    signal: NDArray[np.float64],
    wavelet: NDArray[np.complex128],
    mode: str = 'same'
) -> NDArray[np.complex128]:
    """
    Convolve a signal with a complex wavelet using FFT for efficiency.
    
    Parameters
    ----------
    signal : ndarray of float64
        Input signal (1D array).
    wavelet : ndarray of complex128
        Complex wavelet (e.g., Morlet wavelet).
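    mode : str, optional
        Convolution mode. 'same' returns output with the same length as the
        signal (centered on the wavelet); 'full' returns the complete linear
        convolution of length len(signal) + len(wavelet) - 1. Default is 'same'.
        
    Returns
    -------
    result : ndarray of complex128
        Complex-valued convolution result. The magnitude gives the amplitude
        (its square gives power), and the angle gives the instantaneous phase.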
        
    Notes
    -----
    Uses FFT-based convolution (convolution theorem) for efficiency:
    conv(s, w) = ifft(fft(s) * fft(w))
    
    Examples
    --------
    >>> signal = np.sin(2 * np.pi * 10 * np.linspace(0, 1, 256))
    >>> wavelet = create_morlet_wavelet(10, 256, n_cycles=5)
    >>> result = wavelet_convolution(signal, wavelet)
    >>> power = np.abs(result) ** 2
    """
    # Determine FFT size (next power of 2 for efficiency)
    n_signal = len(signal)
    n_wavelet = len(wavelet)
    n_conv = n_signal + n_wavelet - 1
    n_fft = int(2 ** np.ceil(np.log2(n_conv)))
    
    # FFT of signal and wavelet
    signal_fft = np.fft.fft(signal, n=n_fft)
    wavelet_fft = np.fft.fft(wavelet, n=n_fft)
    
    # Multiply in frequency domain (convolution theorem)
    result_fft = signal_fft * wavelet_fft
    
    # Inverse FFT
    result = np.fft.ifft(result_fft)
    
    # Trim to match 'same' mode
    if mode == 'same':
        # Remove half the wavelet length from each end
        start = (n_wavelet - 1) // 2
        result = result[start:start + n_signal]
    
    return result


# Test the function
test_signal = np.sin(2 * np.pi * 10 * np.linspace(0, 1, 256))
test_wavelet = create_morlet_wavelet(10, 256, n_cycles=5)
test_result = wavelet_convolution(test_signal, test_wavelet)

print("✓ Wavelet convolution completed")
print(f"  - Input signal length: {len(test_signal)}")
print(f"  - Wavelet length: {len(test_wavelet)}")
print(f"  - Output length: {len(test_result)} (same as input)")

In [None]:
# Visualization 8: Wavelet convolution step by step

# Create signal with a 10 Hz burst in the middle
duration = 2.0
t = np.linspace(0, duration, int(fs * duration), endpoint=False)
signal = np.zeros_like(t)

# Add burst between 0.5-1.5s
burst_mask = (t >= 0.5) & (t <= 1.5)
signal[burst_mask] = np.sin(2 * np.pi * 10 * t[burst_mask])
signal += np.random.randn(len(t)) * 0.1  # Small noise

# Create wavelet at 10 Hz
wavelet, wavelet_time = create_morlet_wavelet(10, fs, n_cycles=5, return_time=True)

# Perform convolution
result = wavelet_convolution(signal, wavelet)

fig, axes = plt.subplots(4, 1, figsize=(12, 8), sharex=True)

# Original signal
axes[0].plot(t, signal, color=PRIMARY_BLUE, linewidth=1)
axes[0].axvspan(0.5, 1.5, alpha=0.2, color=SECONDARY_ORANGE, label='10 Hz burst')
axes[0].set_ylabel('Amplitude')
axes[0].set_title('Original Signal (10 Hz burst from 0.5-1.5s)', fontsize=11)
axes[0].legend(loc='upper right')

# Real part of convolution result
axes[1].plot(t, np.real(result), color=PRIMARY_BLUE, linewidth=1)
axes[1].set_ylabel('Real part')
axes[1].set_title('Real Part of Wavelet Convolution', fontsize=11)

# Imaginary part of convolution result
axes[2].plot(t, np.imag(result), color=PRIMARY_RED, linewidth=1)
axes[2].set_ylabel('Imaginary part')
axes[2].set_title('Imaginary Part of Wavelet Convolution', fontsize=11)

# Power (magnitude squared)
power = np.abs(result) ** 2
axes[3].plot(t, power, color=PRIMARY_GREEN, linewidth=2)
axes[3].fill_between(t, 0, power, alpha=0.3, color=PRIMARY_GREEN)
axes[3].axvspan(0.5, 1.5, alpha=0.2, color=SECONDARY_ORANGE)
axes[3].set_ylabel('Power')
axes[3].set_xlabel('Time (s)')
axes[3].set_title('Power = |convolution|² (Detects the 10 Hz burst!)', fontsize=11)

for ax in axes:
    ax.grid(True, alpha=0.3)

fig.suptitle('Wavelet Convolution: From Signal to Time-Frequency Power', 
             fontsize=13, fontweight='bold')
plt.tight_layout()
plt.show()

print("🎯 The power trace correctly identifies WHEN the 10 Hz activity is present!")
print("   This is the key advantage of wavelet analysis over standard FFT.")

## 8. Full Wavelet Transform: Multiple Frequencies 🌈

To get a complete time-frequency representation, we:
1. Create wavelets for each frequency of interest
2. Convolve the signal with each wavelet
3. Stack the results into a 2D matrix (frequency × time)

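This gives us a **time-frequency map** (scalogram) similar to the STFT spectrogram, 
but with multi-resolution properties: with a fixed `n_cycles`, every wavelet spans the same number of cycles, so it is short in time at high frequencies and long at low frequencies.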
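The multi-resolution behavior can be quantified. Assuming a Gaussian envelope with temporal standard deviation σ_t = n_cycles / (2πf) (the expression used for time resolution elsewhere in this notebook), the spectral standard deviation is σ_f = 1 / (2πσ_t) = f / n_cycles. A quick calculation for a few illustrative frequencies with n_cycles = 5:

```python
import numpy as np

n_cycles = 5.0
freqs = np.array([4.0, 10.0, 20.0, 40.0])      # example frequencies (Hz)

sigma_t = n_cycles / (2 * np.pi * freqs)       # temporal SD of the Gaussian envelope (s)
sigma_f = freqs / n_cycles                     # spectral SD (Hz), equal to 1 / (2*pi*sigma_t)

for f, st, sf in zip(freqs, sigma_t, sigma_f):
    print(f"{f:5.1f} Hz: sigma_t = {st * 1000:6.1f} ms, sigma_f = {sf:4.1f} Hz")

# The product is identical at every frequency: only the balance shifts
print(np.allclose(sigma_t * sigma_f, 1 / (2 * np.pi)))   # True
```

Time spread shrinks and frequency spread grows in proportion to f, which is exactly the "better time resolution at high frequencies, better frequency resolution at low frequencies" pattern.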

In [None]:
# Function 5: Full wavelet transform

def compute_wavelet_transform(
    signal: NDArray[np.float64],
    frequencies: NDArray[np.float64],
    fs: float,
    n_cycles: Union[float, NDArray[np.float64]] = 5.0
) -> NDArray[np.complex128]:
    """
    Compute the continuous wavelet transform using complex Morlet wavelets.
    
    Parameters
    ----------
    signal : ndarray of float64
        Input signal (1D array).
    frequencies : ndarray of float64
        Array of frequencies to analyze (in Hz).
    fs : float
        Sampling frequency in Hz.
    n_cycles : float or ndarray, optional
        Number of cycles for the wavelets. Can be a single value or
        an array with one value per frequency. Default is 5.0.
        
    Returns
    -------
    tfr : ndarray of complex128
        Complex time-frequency representation with shape (n_frequencies, n_times).
        
    Notes
    -----
    The magnitude squared of the result gives power, and the angle gives phase.
    
    Examples
    --------
    >>> signal = np.sin(2 * np.pi * 10 * np.linspace(0, 1, 256))
    >>> frequencies = np.arange(5, 40, 1)
    >>> tfr = compute_wavelet_transform(signal, frequencies, 256)
    >>> power = np.abs(tfr) ** 2
    """
    n_times = len(signal)
    n_freqs = len(frequencies)
    
    # Handle n_cycles (scalar or array)
    if np.isscalar(n_cycles):
        n_cycles_array = np.full(n_freqs, n_cycles)
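    else:
        n_cycles_array = np.asarray(n_cycles, dtype=float)
        if n_cycles_array.shape != (n_freqs,):
            raise ValueError("n_cycles must be a scalar or have one value per frequency")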
    
    # Initialize output
    tfr = np.zeros((n_freqs, n_times), dtype=np.complex128)
    
    # Convolve with each wavelet
    for idx, (freq, nc) in enumerate(zip(frequencies, n_cycles_array)):
        wavelet = create_morlet_wavelet(freq, fs, n_cycles=nc)
        tfr[idx, :] = wavelet_convolution(signal, wavelet)
    
    return tfr


# Function 6: Compute wavelet power

def compute_wavelet_power(
    signal: NDArray[np.float64],
    frequencies: NDArray[np.float64],
    fs: float,
    n_cycles: Union[float, NDArray[np.float64]] = 5.0,
    baseline: Optional[Tuple[float, float]] = None,
    baseline_mode: str = 'ratio'
) -> NDArray[np.float64]:
    """
    Compute time-frequency power using wavelet transform.
    
    Parameters
    ----------
    signal : ndarray of float64
        Input signal (1D array).
    frequencies : ndarray of float64
        Array of frequencies to analyze (in Hz).
    fs : float
        Sampling frequency in Hz.
    n_cycles : float or ndarray, optional
        Number of cycles for the wavelets. Default is 5.0.
    baseline : tuple of float, optional
        Baseline period as (start, end) in seconds, measured from the first
        sample of `signal` (t = 0). If provided, power is normalized relative
        to the mean power in this window.
    baseline_mode : str, optional
        How to normalize: 'ratio' (divide by baseline), 'zscore', or 'percent'.
        Default is 'ratio'.
        
    Returns
    -------
    power : ndarray of float64
        Time-frequency power with shape (n_frequencies, n_times).
        
    Examples
    --------
    >>> signal = np.sin(2 * np.pi * 10 * np.linspace(0, 1, 256))
    >>> freqs = np.arange(5, 40, 1)
    >>> power = compute_wavelet_power(signal, freqs, 256)
    """
    # Compute wavelet transform
    tfr = compute_wavelet_transform(signal, frequencies, fs, n_cycles)
    
    # Get power (magnitude squared)
    power = np.abs(tfr) ** 2
    
    # Apply baseline normalization if requested
    if baseline is not None:
        times = np.arange(len(signal)) / fs
        baseline_mask = (times >= baseline[0]) & (times <= baseline[1])
        baseline_power = power[:, baseline_mask].mean(axis=1, keepdims=True)
        
        if baseline_mode == 'ratio':
            power = power / baseline_power
        elif baseline_mode == 'zscore':
            baseline_std = power[:, baseline_mask].std(axis=1, keepdims=True)
            power = (power - baseline_power) / baseline_std
        elif baseline_mode == 'percent':
            power = (power - baseline_power) / baseline_power * 100
        else:
            raise ValueError(f"Unknown baseline_mode: {baseline_mode!r}")
    
    return power


# Test the functions
test_signal = np.sin(2 * np.pi * 10 * np.linspace(0, 2, 512))
test_freqs = np.arange(5, 30, 1)
test_tfr = compute_wavelet_transform(test_signal, test_freqs, 256)
test_power = compute_wavelet_power(test_signal, test_freqs, 256)

print("✓ Wavelet transform computed")
print(f"  - Signal length: {len(test_signal)} samples")
print(f"  - Frequencies: {len(test_freqs)} ({test_freqs[0]}-{test_freqs[-1]} Hz)")
print(f"  - TFR shape: {test_tfr.shape} (frequencies × times)")

In [None]:
# Function 7: Plot time-frequency representation

def plot_time_frequency(
    power: NDArray[np.float64],
    times: NDArray[np.float64],
    frequencies: NDArray[np.float64],
    ax: Optional[plt.Axes] = None,
    cmap: str = 'viridis',
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
    log_scale: bool = False,
    colorbar: bool = True,
    title: str = ''
) -> plt.Axes:
    """
    Plot a time-frequency power representation.
    
    Parameters
    ----------
    power : ndarray of float64
        Time-frequency power matrix (n_frequencies × n_times).
    times : ndarray of float64
        Time vector in seconds.
    frequencies : ndarray of float64
        Frequency vector in Hz.
    ax : matplotlib Axes, optional
        Axes to plot on. If None, creates new figure.
    cmap : str, optional
        Colormap name. Default is 'viridis'.
    vmin, vmax : float, optional
        Color scale limits.
    log_scale : bool, optional
        If True, apply log10 to power. Default is False.
    colorbar : bool, optional
        If True, add colorbar. Default is True.
    title : str, optional
        Plot title.
        
    Returns
    -------
    ax : matplotlib Axes
        The axes with the plot.
    """
    if ax is None:
        fig, ax = plt.subplots(figsize=(10, 5))
    
    # Apply log scale if requested
    plot_power = np.log10(power + 1e-10) if log_scale else power
    
    # Create plot
    im = ax.pcolormesh(times, frequencies, plot_power, 
                       shading='gouraud', cmap=cmap, 
                       vmin=vmin, vmax=vmax)
    
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Frequency (Hz)')
    ax.set_title(title)
    
    if colorbar:
        cbar = plt.colorbar(im, ax=ax)
        cbar.set_label('Power (log10)' if log_scale else 'Power')
    
    return ax


# Visualization 9: STFT vs Wavelet comparison
# Create a signal with multiple frequency bursts

duration = 3.0
t = np.linspace(0, duration, int(fs * duration), endpoint=False)
signal_multi = np.zeros_like(t)

# Add different frequency bursts at different times
# 8 Hz (alpha) burst at 0.5-1.0s
mask1 = (t >= 0.5) & (t <= 1.0)
signal_multi[mask1] += np.sin(2 * np.pi * 8 * t[mask1]) * 2

# 20 Hz (beta) burst at 1.0-1.5s  
mask2 = (t >= 1.0) & (t <= 1.5)
signal_multi[mask2] += np.sin(2 * np.pi * 20 * t[mask2]) * 1.5

# 35 Hz (gamma) burst at 1.5-2.0s
mask3 = (t >= 1.5) & (t <= 2.0)
signal_multi[mask3] += np.sin(2 * np.pi * 35 * t[mask3])

# Add noise
signal_multi += np.random.randn(len(t)) * 0.2

# Compute STFT
f_stft, t_stft, Sxx = spectrogram(signal_multi, fs=fs, nperseg=256, noverlap=192)

# Compute wavelet
frequencies_wav = np.arange(2, 50, 0.5)
power_wav = compute_wavelet_power(signal_multi, frequencies_wav, fs, n_cycles=5)
times_wav = np.arange(len(signal_multi)) / fs

fig, axes = plt.subplots(3, 1, figsize=(12, 8))

# Original signal
axes[0].plot(t, signal_multi, color=PRIMARY_BLUE, linewidth=0.8)
axes[0].axvspan(0.5, 1.0, alpha=0.2, color=PRIMARY_RED, label='8 Hz')
axes[0].axvspan(1.0, 1.5, alpha=0.2, color=PRIMARY_GREEN, label='20 Hz')
axes[0].axvspan(1.5, 2.0, alpha=0.2, color=SECONDARY_PURPLE, label='35 Hz')
axes[0].set_xlabel('Time (s)')
axes[0].set_ylabel('Amplitude')
axes[0].set_title('Signal with Three Frequency Bursts', fontsize=11)
axes[0].legend(loc='upper right')
axes[0].set_xlim([0, duration])

# STFT spectrogram
freq_mask = f_stft <= 50
im1 = axes[1].pcolormesh(t_stft, f_stft[freq_mask], 
                         10 * np.log10(Sxx[freq_mask] + 1e-10),
                         shading='gouraud', cmap='viridis')
axes[1].set_ylabel('Frequency (Hz)')
axes[1].set_title('STFT Spectrogram (fixed resolution)', fontsize=11)
plt.colorbar(im1, ax=axes[1], label='Power (dB)')

# Wavelet scalogram
im2 = axes[2].pcolormesh(times_wav, frequencies_wav, 
                         10 * np.log10(power_wav + 1e-10),
                         shading='gouraud', cmap='viridis')
axes[2].set_xlabel('Time (s)')
axes[2].set_ylabel('Frequency (Hz)')
axes[2].set_title('Wavelet Scalogram (multi-resolution)', fontsize=11)
plt.colorbar(im2, ax=axes[2], label='Power (dB)')

for ax in axes[1:]:
    ax.set_xlim([0, duration])
    # Mark true burst times and frequencies
    for time_start, time_end, freq in [(0.5, 1.0, 8), (1.0, 1.5, 20), (1.5, 2.0, 35)]:
        ax.hlines(y=freq, xmin=time_start, xmax=time_end, colors='white',
                  linestyles='--', alpha=0.5, linewidth=1)

fig.suptitle('STFT vs Wavelet: Multi-Resolution Advantage', fontsize=13, fontweight='bold')
plt.tight_layout()
plt.show()

print("📊 Compare the time-frequency resolution:")
print("   - STFT: Same time resolution at all frequencies")
print("   - Wavelet: Better time resolution at high frequencies, better frequency resolution at low")

## 9. Adaptive n_cycles: Frequency-Dependent Resolution 🔧

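While a fixed `n_cycles` works, a common refinement is to **let n_cycles grow with frequency**:

- **Low frequencies**: Each cycle is already long, so a few cycles (e.g., 3) keep the wavelet short enough to localize events in time
- **High frequencies**: Each cycle is short, so more cycles (e.g., 10) sharpen frequency resolution while the wavelet stays brief in time

Common approaches:

1. **Linear scaling**: `n_cycles = freq / 2` (e.g., 5 cycles at 10 Hz, 20 cycles at 40 Hz)
2. **Logarithmic scaling**: `n_cycles = k * log2(freq)`
3. **Bounded linear**: `n_cycles = max(min_cycles, min(freq / 2, max_cycles))`

For a Gaussian envelope the product of time and frequency spread is fixed ($\sigma_t \sigma_f = 1/(2\pi)$), so no choice of `n_cycles` beats the uncertainty limit; the goal is to choose, frequency by frequency, where on that trade-off to sit.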
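Plugging the bounded linear rule into σ_t = n_cycles / (2πf) shows what it does in practice. The helper below is a hypothetical stand-in written for this illustration, with min_cycles = 3 and max_cycles = 10:

```python
import numpy as np


def bounded_linear_cycles(freqs, min_cycles=3.0, max_cycles=10.0):
    """Hypothetical helper: n_cycles = freq / 2, clipped to [min_cycles, max_cycles]."""
    return np.clip(np.asarray(freqs, dtype=float) / 2.0, min_cycles, max_cycles)


freqs = np.arange(4.0, 50.0, 1.0)
n_cycles = bounded_linear_cycles(freqs)
sigma_t = n_cycles / (2 * np.pi * freqs)        # temporal SD (s)

# Between 6 and 20 Hz the clip is inactive, so sigma_t = (f/2) / (2*pi*f) = 1/(4*pi)
unclipped = (freqs >= 6) & (freqs <= 20)
print(np.allclose(sigma_t[unclipped], 1 / (4 * np.pi)))   # True: about 80 ms at every f
# Outside that range the clip takes over and sigma_t again falls with frequency
print(sigma_t[0] > sigma_t[unclipped][0] > sigma_t[-1])   # True
```

In other words, between 6 and 20 Hz this scheme gives every frequency the same ~80 ms temporal spread, much like an STFT with a fixed window; the clipping bounds restore the wavelet-style trade-off only outside that band.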

In [None]:
# Function 8: Compute adaptive n_cycles

def compute_adaptive_cycles(
    frequencies: NDArray[np.float64],
    min_cycles: float = 3.0,
    max_cycles: float = 10.0,
    scaling: str = 'linear'
) -> NDArray[np.float64]:
    """
    Compute frequency-adaptive number of cycles for wavelet analysis.
    
    Parameters
    ----------
    frequencies : ndarray of float64
        Array of frequencies in Hz.
    min_cycles : float, optional
        Minimum number of cycles. Default is 3.0.
    max_cycles : float, optional
        Maximum number of cycles. Default is 10.0.
    scaling : str, optional
        Scaling method: 'linear' or 'log'. Default is 'linear'.
        
    Returns
    -------
    n_cycles : ndarray of float64
        Array of n_cycles values, one per frequency.
        
    Notes
    -----
    Linear scaling: n_cycles = freq / 2, bounded by min/max.
    Log scaling: n_cycles = 2 * log2(freq), bounded by min/max.
    """
    frequencies = np.asarray(frequencies)
    
    if scaling == 'linear':
        n_cycles = frequencies / 2.0
    elif scaling == 'log':
        n_cycles = np.log2(frequencies) * 2
    else:
        raise ValueError(f"Unknown scaling: {scaling}")
    
    # Apply bounds
    n_cycles = np.clip(n_cycles, min_cycles, max_cycles)
    
    return n_cycles


# Visualization 10: Fixed vs adaptive n_cycles comparison
frequencies = np.arange(4, 50, 1)
n_cycles_fixed = np.full_like(frequencies, 5.0, dtype=float)
n_cycles_adaptive = compute_adaptive_cycles(frequencies, min_cycles=3, max_cycles=10)

fig, axes = plt.subplots(2, 2, figsize=(12, 8))

# Plot n_cycles strategies
axes[0, 0].plot(frequencies, n_cycles_fixed, color=PRIMARY_BLUE, 
                linewidth=2, label='Fixed (5 cycles)')
axes[0, 0].plot(frequencies, n_cycles_adaptive, color=PRIMARY_RED, 
                linewidth=2, label='Adaptive')
axes[0, 0].set_xlabel('Frequency (Hz)')
axes[0, 0].set_ylabel('n_cycles')
axes[0, 0].set_title('n_cycles Strategies', fontsize=11)
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)

# Time resolution (sigma_t)
sigma_t_fixed = n_cycles_fixed / (2 * np.pi * frequencies)
sigma_t_adaptive = n_cycles_adaptive / (2 * np.pi * frequencies)

axes[0, 1].plot(frequencies, sigma_t_fixed * 1000, color=PRIMARY_BLUE, 
                linewidth=2, label='Fixed')
axes[0, 1].plot(frequencies, sigma_t_adaptive * 1000, color=PRIMARY_RED, 
                linewidth=2, label='Adaptive')
axes[0, 1].set_xlabel('Frequency (Hz)')
axes[0, 1].set_ylabel('Time resolution σ_t (ms)')
axes[0, 1].set_title('Time Resolution vs Frequency', fontsize=11)
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)

# Compute spectrograms with chirp signal
duration = 2.0
t = np.linspace(0, duration, int(fs * duration), endpoint=False)
chirp = np.sin(2 * np.pi * (5 * t + (45 - 5) / (2 * duration) * t**2))
chirp += np.random.randn(len(t)) * 0.1

# Fixed n_cycles
power_fixed = compute_wavelet_power(chirp, frequencies.astype(float), fs, n_cycles=5.0)
times = np.arange(len(chirp)) / fs

# Adaptive n_cycles
power_adaptive = compute_wavelet_power(chirp, frequencies.astype(float), fs, 
                                        n_cycles=n_cycles_adaptive)

# Plot spectrograms
im1 = axes[1, 0].pcolormesh(times, frequencies, 10 * np.log10(power_fixed + 1e-10),
                            shading='gouraud', cmap='viridis')
axes[1, 0].plot(times, 5 + (45 - 5) / duration * times, color='white', 
                linestyle='--', linewidth=2, label='True frequency')
axes[1, 0].set_xlabel('Time (s)')
axes[1, 0].set_ylabel('Frequency (Hz)')
axes[1, 0].set_title('Fixed n_cycles = 5', fontsize=11)
plt.colorbar(im1, ax=axes[1, 0], label='Power (dB)')

im2 = axes[1, 1].pcolormesh(times, frequencies, 10 * np.log10(power_adaptive + 1e-10),
                            shading='gouraud', cmap='viridis')
axes[1, 1].plot(times, 5 + (45 - 5) / duration * times, color='white', 
                linestyle='--', linewidth=2, label='True frequency')
axes[1, 1].set_xlabel('Time (s)')
axes[1, 1].set_ylabel('Frequency (Hz)')
axes[1, 1].set_title('Adaptive n_cycles (3-10)', fontsize=11)
plt.colorbar(im2, ax=axes[1, 1], label='Power (dB)')

fig.suptitle('Fixed vs Adaptive n_cycles: Effect on Time-Frequency Resolution', 
             fontsize=13, fontweight='bold')
plt.tight_layout()
plt.show()

print("📊 Adaptive n_cycles (3-10, increasing with frequency) provides:")
print("   - Better time precision at low frequencies (fewer than 5 cycles below 10 Hz)")
print("   - Better frequency precision at high frequencies (more than 5 cycles above 10 Hz)")

## 10. Extracting Phase from Wavelets 📐

One powerful feature of complex Morlet wavelets: they give us **instantaneous phase** at each frequency!

From the wavelet transform result $W(t, f)$:

$$\text{Phase}(t, f) = \operatorname{atan2}\big(\text{Im}(W), \text{Re}(W)\big) = \angle W(t, f)$$

Here $\operatorname{atan2}$ is the four-quadrant arctangent, with values in $(-\pi, \pi]$.

This is exactly what we need for **phase-based connectivity metrics** like PLV!
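Why the four-quadrant version matters: the ratio Im/Re discards the signs of the two parts, so coefficients in opposite quadrants look identical to a plain arctangent. A small check with one illustrative coefficient per quadrant (`np.angle` computes the four-quadrant angle):

```python
import numpy as np

# One complex coefficient in each quadrant
w = np.array([1 + 1j, -1 + 1j, -1 - 1j, 1 - 1j])

naive = np.arctan(w.imag / w.real)            # only covers (-pi/2, pi/2)
four_quadrant = np.arctan2(w.imag, w.real)    # covers (-pi, pi]

print(naive / np.pi)           # values: 0.25, -0.25, 0.25, -0.25 (quadrants collapse)
print(four_quadrant / np.pi)   # values: 0.25, 0.75, -0.75, -0.25
print(np.allclose(four_quadrant, np.angle(w)))   # True
```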

### Wavelet Phase vs Hilbert Phase

Both methods extract phase, but:

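| Aspect | Hilbert Transform | Wavelet Transform |
|--------|-------------------|-------------------|
| Frequencies | One band per run | Many frequencies at once |
| Pre-filtering | Requires a band-pass filter first | None needed; each wavelet acts as a band-pass filter |
| Time resolution | Set by the band-pass filter design | Set by `n_cycles`; varies with frequency |
| Speed | Faster for a single band | Better suited to full time-frequency maps |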
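The pre-filtering row can be checked on a toy signal. This sketch assumes `fs = 256` Hz, a 10 Hz + 30 Hz mixture, a 4th-order Butterworth 8-12 Hz band-pass applied with `scipy.signal.sosfiltfilt` (illustrative settings, not the notebook's `bandpass_filter`), and an inline 10 Hz, 5-cycle Morlet wavelet. It measures each method's phase error against the phase of the 10 Hz component, away from the edges:

```python
import numpy as np
from scipy.signal import butter, hilbert, sosfiltfilt

fs = 256.0                                   # assumed sampling rate (Hz)
t = np.arange(0, 4, 1 / fs)
mix = np.sin(2 * np.pi * 10 * t) + np.sin(2 * np.pi * 30 * t)   # 10 Hz + 30 Hz

# Reference: analytic phase of the 10 Hz component alone (sin = cos shifted by -pi/2)
ref = 2 * np.pi * 10 * t - np.pi / 2


def phase_error(phase):
    """Mean absolute wrapped difference from the 10 Hz reference, skipping 1 s at each end."""
    inner = (t >= 1) & (t < 3)
    return np.mean(np.abs(np.angle(np.exp(1j * (phase - ref)))[inner]))


# Hilbert without filtering: the phase reflects the sum of both components
err_raw = phase_error(np.angle(hilbert(mix)))

# Hilbert after a zero-phase 8-12 Hz band-pass (assumed filter settings)
sos = butter(4, [8, 12], btype='bandpass', fs=fs, output='sos')
err_filtered = phase_error(np.angle(hilbert(sosfiltfilt(sos, mix))))

# Wavelet at 10 Hz on the raw mixture: the Gaussian envelope does the band-limiting
sigma_t = 5 / (2 * np.pi * 10)
half = int(np.ceil(3 * sigma_t * fs))
tw = np.arange(-half, half + 1) / fs
psi = np.exp(2j * np.pi * 10 * tw) * np.exp(-tw**2 / (2 * sigma_t**2))
err_wavelet = phase_error(np.angle(np.convolve(mix, psi, mode='same')))

print(f"Hilbert, unfiltered:  {err_raw:.3f} rad")
print(f"Hilbert, band-passed: {err_filtered:.3f} rad")
print(f"Wavelet, unfiltered:  {err_wavelet:.3f} rad")
```

Hilbert on the raw mixture reports the phase of the sum rather than of the 10 Hz rhythm; after band-passing it agrees with the reference, as does the wavelet applied to the unfiltered signal.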

In [None]:
# Function 9: Compute wavelet phase

def compute_wavelet_phase(
    signal: NDArray[np.float64],
    frequencies: NDArray[np.float64],
    fs: float,
    n_cycles: Union[float, NDArray[np.float64]] = 5.0
) -> NDArray[np.float64]:
    """
    Compute instantaneous phase at multiple frequencies using wavelets.
    
    Parameters
    ----------
    signal : ndarray of float64
        Input signal (1D array).
    frequencies : ndarray of float64
        Array of frequencies to analyze (in Hz).
    fs : float
        Sampling frequency in Hz.
    n_cycles : float or ndarray, optional
        Number of cycles for the wavelets. Default is 5.0.
        
    Returns
    -------
    phase : ndarray of float64
        Phase values in radians, shape (n_frequencies, n_times).
        Values are in (-π, π].
    """
    tfr = compute_wavelet_transform(signal, frequencies, fs, n_cycles)
    return np.angle(tfr)


# Visualization 11: Wavelet phase extraction

# Create a 10 Hz sine wave
duration = 1.0
t = np.linspace(0, duration, int(fs * duration), endpoint=False)
signal_10hz = np.sin(2 * np.pi * 10 * t)

# Extract phase using wavelet at 10 Hz
freqs_phase = np.array([10.0])
wavelet_phase = compute_wavelet_phase(signal_10hz, freqs_phase, fs, n_cycles=5)
wavelet_phase_10hz = wavelet_phase[0, :]  # Take the 10 Hz row

# For comparison: Hilbert phase (from previous notebooks)
from scipy.signal import hilbert
analytic = hilbert(signal_10hz)
hilbert_phase = np.angle(analytic)

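# True phase: sin(x) = cos(x - pi/2), and both methods report phase relative
# to a cosine, so the expected phase is 2*pi*10*t - pi/2, wrapped to (-pi, pi]
true_phase = np.angle(np.exp(1j * (2 * np.pi * 10 * t - np.pi / 2)))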

fig, axes = plt.subplots(3, 1, figsize=(12, 7), sharex=True)

# Signal
axes[0].plot(t, signal_10hz, color=PRIMARY_BLUE, linewidth=2)
axes[0].set_ylabel('Amplitude')
axes[0].set_title('Original 10 Hz Sine Wave', fontsize=11)
axes[0].grid(True, alpha=0.3)

# Phase comparison
axes[1].plot(t, true_phase, color='gray', linewidth=2, label='True phase', alpha=0.7)
axes[1].plot(t, wavelet_phase_10hz, color=PRIMARY_RED, linewidth=2, 
             linestyle='--', label='Wavelet phase')
axes[1].plot(t, hilbert_phase, color=PRIMARY_GREEN, linewidth=2, 
             linestyle=':', label='Hilbert phase')
axes[1].set_ylabel('Phase (rad)')
axes[1].set_title('Phase Comparison: Wavelet vs Hilbert', fontsize=11)
axes[1].set_yticks([-np.pi, -np.pi/2, 0, np.pi/2, np.pi])
axes[1].set_yticklabels(['-π', '-π/2', '0', 'π/2', 'π'])
axes[1].legend(loc='upper right')
axes[1].grid(True, alpha=0.3)

# Phase difference (wavelet - hilbert)
phase_diff = np.angle(np.exp(1j * (wavelet_phase_10hz - hilbert_phase)))
axes[2].plot(t, phase_diff, color=SECONDARY_PURPLE, linewidth=2)
axes[2].axhline(y=0, color='gray', linestyle='--', alpha=0.5)
axes[2].set_xlabel('Time (s)')
axes[2].set_ylabel('Δ Phase (rad)')
axes[2].set_title('Phase Difference (Wavelet - Hilbert)', fontsize=11)
axes[2].set_ylim([-0.5, 0.5])
axes[2].grid(True, alpha=0.3)

fig.suptitle('Wavelet Phase Extraction vs Hilbert Transform', fontsize=13, fontweight='bold')
plt.tight_layout()
plt.show()

print("📐 Both methods give nearly identical phase for narrowband signals!")
print("   The differences concentrate near the edges, where the wavelet overlaps zero padding.")

## 11. Edge Effects: The Wavelet Challenge ⚠️

Wavelet analysis has a significant **edge effect problem**:

- Wavelets extend beyond signal boundaries at the start and end
- The wavelet "sees" zeros (or other padding) instead of real data
- This creates **artifacts** in the first and last portions of the result

### How Many Samples Are Affected?

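The edge effect extends on each side by roughly $n_\sigma$ standard deviations of the wavelet's Gaussian envelope:
$$N_{edge} = \left\lceil n_\sigma \, \sigma_t \, f_s \right\rceil = \left\lceil \frac{n_\sigma \, n_{cycles} \, f_s}{2\pi f} \right\rceil$$

Where:
- $n_\sigma$ = number of standard deviations treated as affected (e.g., 3)
- $\sigma_t = n_{cycles} / (2\pi f)$ = temporal standard deviation of the wavelet
- $n_{cycles}$ = number of wavelet cycles
- $f_s$ = sampling frequency
- $f$ = frequency of interest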

**Lower frequencies = longer wavelets = more edge effects!**

### Solutions

1. **Exclude edges**: Remove affected samples from analysis
2. **Mirror padding**: Reflect signal at boundaries
3. **Collect extra data**: Record beyond your analysis window

In [None]:
# Function 10: Compute edge samples

def compute_edge_samples(
    frequency: float,
    fs: float,
    n_cycles: float = 5.0,
    n_sigma: float = 3.0
) -> int:
    """
    Compute the number of samples affected by edge effects.
    
    Parameters
    ----------
    frequency : float
        Frequency of the wavelet in Hz.
    fs : float
        Sampling frequency in Hz.
    n_cycles : float, optional
        Number of cycles in the wavelet. Default is 5.0.
    n_sigma : float, optional
        Number of sigma (standard deviations) to consider. Default is 3.0.
        
    Returns
    -------
    n_edge : int
        Number of samples affected by edge effects on each side.
        
    Notes
    -----
    The wavelet extends n_sigma * sigma_t on each side, where
    sigma_t = n_cycles / (2 * pi * frequency).
    """
    sigma_t = n_cycles / (2 * np.pi * frequency)
    edge_duration = n_sigma * sigma_t
    n_edge = int(np.ceil(edge_duration * fs))
    return n_edge


# Visualization 12: Edge effects demonstration

# Create a clean signal
duration = 2.0
t = np.linspace(0, duration, int(fs * duration), endpoint=False)
clean_signal = np.sin(2 * np.pi * 10 * t)  # 10 Hz throughout

# Compute wavelet power at different frequencies
frequencies = np.array([5, 10, 20, 40])
n_cycles = 5.0

fig, axes = plt.subplots(len(frequencies) + 1, 1, figsize=(12, 10), sharex=True)

# Original signal
axes[0].plot(t, clean_signal, color=PRIMARY_BLUE, linewidth=1)
axes[0].set_ylabel('Amplitude')
axes[0].set_title('Original Signal (constant 10 Hz)', fontsize=11)
axes[0].grid(True, alpha=0.3)

# Power at each frequency
for idx, freq in enumerate(frequencies):
    # Compute power
    power = compute_wavelet_power(clean_signal, np.array([freq]), fs, n_cycles=n_cycles)
    power_1d = power[0, :]
    
    # Compute edge samples
    n_edge = compute_edge_samples(freq, fs, n_cycles=n_cycles)
    edge_time = n_edge / fs
    
    ax = axes[idx + 1]
    ax.plot(t, power_1d, color=PRIMARY_GREEN, linewidth=1.5)
    
    # Shade edge regions
    ax.axvspan(0, edge_time, alpha=0.3, color=PRIMARY_RED, label='Edge effects')
    ax.axvspan(duration - edge_time, duration, alpha=0.3, color=PRIMARY_RED)
    
    # Mark valid region
    ax.axvline(x=edge_time, color=PRIMARY_RED, linestyle='--', alpha=0.7)
    ax.axvline(x=duration - edge_time, color=PRIMARY_RED, linestyle='--', alpha=0.7)
    
    ax.set_ylabel('Power')
    ax.set_title(f'{freq} Hz wavelet: {n_edge} edge samples ({edge_time*1000:.0f} ms) per side', 
                 fontsize=10)
    ax.grid(True, alpha=0.3)
    
    if idx == 0:
        ax.legend(loc='upper right')

axes[-1].set_xlabel('Time (s)')

fig.suptitle('Edge Effects: Lower Frequencies Are More Affected', fontsize=13, fontweight='bold')
plt.tight_layout()
plt.show()

# Print edge samples for common EEG bands
print("📊 Edge samples for common EEG frequency bands (n_cycles=5):")
print("   (with 3σ criterion)")
bands = {'Delta (2 Hz)': 2, 'Theta (6 Hz)': 6, 'Alpha (10 Hz)': 10, 
         'Beta (20 Hz)': 20, 'Gamma (40 Hz)': 40}
for band, freq in bands.items():
    n_edge = compute_edge_samples(freq, fs, n_cycles=5)
    print(f"   {band}: {n_edge} samples ({n_edge/fs*1000:.0f} ms)")

## 12. Application: Event-Related Time-Frequency Analysis 🧠

In EEG research, we often analyze brain responses to **events** (stimuli, actions, etc.).

Time-frequency analysis reveals:
- **Event-Related Synchronization (ERS)**: Power increase at specific frequencies
- **Event-Related Desynchronization (ERD)**: Power decrease

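Time-frequency analysis also captures activity that is not phase-locked to the event: oscillations whose phase varies from trial to trial cancel out in simple time-domain averaging (ERPs), but their power survives when power is computed per trial and then averaged.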
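A minimal simulation of this point, under illustrative assumptions (`fs = 256` Hz, 50 trials, a 10 Hz burst from 0 to 0.5 s whose phase is random on each trial, an inline Morlet wavelet with a unit-sum envelope). Averaging single-trial power keeps the burst; computing power from the averaged signal (the ERP) largely loses it, because the random phases cancel in the mean. This is the distinction often described as induced versus evoked activity.

```python
import numpy as np

rng = np.random.default_rng(1)
fs = 256.0                                     # assumed sampling rate (Hz)
t = np.arange(-1.0, 1.0, 1 / fs)
n_trials = 50

# Induced 10 Hz burst from 0 to 0.5 s, with a random phase on every trial
burst = (t >= 0) & (t <= 0.5)
trials = np.zeros((n_trials, t.size))
for k in range(n_trials):
    phi = rng.uniform(0, 2 * np.pi)
    trials[k, burst] = np.sin(2 * np.pi * 10 * t[burst] + phi)

# 10 Hz, 5-cycle complex Morlet with a unit-sum envelope (assumed normalization)
sigma_t = 5 / (2 * np.pi * 10)
half = int(np.ceil(3 * sigma_t * fs))
tw = np.arange(-half, half + 1) / fs
envelope = np.exp(-tw**2 / (2 * sigma_t**2))
psi = np.exp(2j * np.pi * 10 * tw) * envelope / envelope.sum()


def power_10hz(x):
    """Wavelet power at 10 Hz for a single time series."""
    return np.abs(np.convolve(x, psi, mode='same')) ** 2


erp = trials.mean(axis=0)                                      # average first (ERP)
evoked = power_10hz(erp)                                       # power of the average
induced = np.mean([power_10hz(tr) for tr in trials], axis=0)   # average of single-trial power

i_mid = np.argmin(np.abs(t - 0.25))                            # middle of the burst
print(f"power of the ERP at 0.25 s:        {evoked[i_mid]:.4f}")
print(f"mean single-trial power at 0.25 s: {induced[i_mid]:.4f}")
```

With a unit-sum envelope, a unit-amplitude sine yields a coefficient magnitude of about 0.5, so the single-trial average lands near 0.25, while the ERP's power is only a small fraction of that.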

In [None]:
# Visualization 13: Simulated event-related time-frequency

# Simulate an EEG epoch around an event at t=0
np.random.seed(42)

epoch_duration = 2.0  # -1s to +1s around event
t_epoch = np.linspace(-1, 1, int(fs * epoch_duration), endpoint=False)
n_samples = len(t_epoch)

# Create simulated EEG with event-related modulation
eeg_epoch = np.zeros(n_samples)

# Background oscillations (always present)
eeg_epoch += 0.5 * np.sin(2 * np.pi * 10 * t_epoch)  # Alpha
eeg_epoch += 0.3 * np.sin(2 * np.pi * 6 * t_epoch)   # Theta

# Event-related modulations:
# 1. Alpha suppression (ERD) after event (0-0.5s)
alpha_suppression = np.zeros_like(t_epoch)
mask_erd = (t_epoch >= 0) & (t_epoch <= 0.5)
alpha_suppression[mask_erd] = -0.7 * np.sin(2 * np.pi * 10 * t_epoch[mask_erd])
eeg_epoch += alpha_suppression

# 2. Gamma burst (ERS) after event: Gaussian envelope centered at 0.2 s
#    (sigma = 50 ms), so most of its energy falls within 0.1-0.3 s
gamma_envelope = np.exp(-((t_epoch - 0.2)**2) / (2 * 0.05**2))
gamma_burst = gamma_envelope * np.sin(2 * np.pi * 40 * t_epoch)
eeg_epoch += gamma_burst

# Add white Gaussian noise (flat spectrum; not pink noise)
noise = np.random.randn(n_samples) * 0.3
eeg_epoch += noise

# Compute time-frequency representation
frequencies = np.arange(4, 50, 0.5)
n_cycles_adaptive = compute_adaptive_cycles(frequencies, min_cycles=3, max_cycles=8)
power = compute_wavelet_power(eeg_epoch, frequencies, fs, n_cycles=n_cycles_adaptive)

# Baseline normalize
baseline_mask = (t_epoch >= -0.5) & (t_epoch <= -0.1)
baseline_power = power[:, baseline_mask].mean(axis=1, keepdims=True)
power_normalized = (power - baseline_power) / baseline_power * 100  # Percent change

fig, axes = plt.subplots(3, 1, figsize=(12, 8))

# EEG signal
axes[0].plot(t_epoch, eeg_epoch, color=PRIMARY_BLUE, linewidth=0.8)
axes[0].axvline(x=0, color=PRIMARY_RED, linestyle='--', linewidth=2, label='Event')
axes[0].axvspan(-0.5, -0.1, alpha=0.2, color='gray', label='Baseline')
axes[0].set_ylabel('Amplitude (µV)')
axes[0].set_title('Simulated EEG Epoch', fontsize=11)
axes[0].legend(loc='upper right')
axes[0].grid(True, alpha=0.3)

# Raw power
im1 = axes[1].pcolormesh(t_epoch, frequencies, 10 * np.log10(power + 1e-10),
                         shading='gouraud', cmap='viridis')
axes[1].axvline(x=0, color='white', linestyle='--', linewidth=2)
axes[1].set_ylabel('Frequency (Hz)')
axes[1].set_title('Time-Frequency Power (raw)', fontsize=11)
plt.colorbar(im1, ax=axes[1], label='Power (dB)')

# Baseline-normalized power
im2 = axes[2].pcolormesh(t_epoch, frequencies, power_normalized,
                         shading='gouraud', cmap='RdBu_r', vmin=-100, vmax=100)
axes[2].axvline(x=0, color='black', linestyle='--', linewidth=2)
axes[2].set_xlabel('Time (s)')
axes[2].set_ylabel('Frequency (Hz)')
axes[2].set_title('Event-Related Power Change (% from baseline)', fontsize=11)
cbar = plt.colorbar(im2, ax=axes[2], label='% change')

# Annotate
axes[2].annotate('Gamma ERS', xy=(0.2, 40), fontsize=10, color='white',
                 ha='center', va='center', fontweight='bold')
axes[2].annotate('Alpha ERD', xy=(0.25, 10), fontsize=10, color='white',
                 ha='center', va='center', fontweight='bold')

fig.suptitle('Event-Related Time-Frequency Analysis', fontsize=13, fontweight='bold')
plt.tight_layout()
plt.show()

print("🧠 This analysis reveals:")
print("   - Alpha ERD (blue): Suppression of 10 Hz after event")
print("   - Gamma ERS (red): Burst of 40 Hz during processing")

## 13. Preview: Wavelets for Hyperscanning Connectivity 👥

In hyperscanning, wavelets enable powerful **time-resolved connectivity analysis**:

1. **Time-resolved PLV**: Track phase synchronization over time
2. **Time-frequency coherence**: Coherence at each time-frequency point
3. **Cross-frequency coupling**: Phase-amplitude relationships

This is a preview of what we'll explore in depth in Module G!
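One property is worth keeping in mind when reading time-resolved PLV: it measures how *consistent* the phase difference is, not whether it is zero, so a constant anti-phase relationship counts as perfectly locked. A minimal check with random phases (sample size and seed are arbitrary):

```python
import numpy as np


def plv(phase_a, phase_b):
    """Phase-locking value: length of the mean unit vector of the phase differences."""
    return np.abs(np.mean(np.exp(1j * (phase_a - phase_b))))


rng = np.random.default_rng(0)
phase_a = rng.uniform(-np.pi, np.pi, 1000)      # arbitrary reference phases

locked_zero = plv(phase_a, phase_a)                             # identical phases
locked_anti = plv(phase_a, phase_a + np.pi)                     # constant pi offset
unlocked = plv(phase_a, rng.uniform(-np.pi, np.pi, 1000))       # independent phases

print(f"zero lag: {locked_zero:.3f}, anti-phase: {locked_anti:.3f}, random: {unlocked:.3f}")
```

Both locked cases give a PLV of 1; only the independent phases drive it toward 0.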

In [None]:
# Visualization 14: Time-resolved inter-brain synchronization preview

# Simulate two participants' EEG with varying synchronization
np.random.seed(42)

duration = 4.0
t = np.linspace(0, duration, int(fs * duration), endpoint=False)

# Participant 1: Base oscillation at 10 Hz
phase1 = 2 * np.pi * 10 * t
eeg1 = np.sin(phase1)

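# Participant 2: Initially desynchronized, becomes synchronized during task
# Desync (0-1s): random-walk phase drift, so the phase difference never settles
# Transition (1-2s): offset ramps linearly to 0
# Sync (2-4s): same phase as P1

desync = t < 1
transition = (t >= 1) & (t < 2)
phase_offset = np.zeros_like(t)
# A constant offset (even pi) would still be perfectly phase-locked (PLV = 1),
# so the desync period uses a random walk instead
phase_offset[desync] = np.pi + np.cumsum(np.random.randn(desync.sum())) * 0.3
# Ramp from the (wrapped) last desync offset down to 0
start_offset = np.angle(np.exp(1j * phase_offset[desync][-1]))
phase_offset[transition] = start_offset * (1 - (t[transition] - 1))
# After t=2, offset is 0 (synchronized)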

phase2 = phase1 + phase_offset
eeg2 = np.sin(phase2)

# Add noise
eeg1 += np.random.randn(len(t)) * 0.2
eeg2 += np.random.randn(len(t)) * 0.2

# Extract phase using wavelets
freq_target = np.array([10.0])
phase1_wav = compute_wavelet_phase(eeg1, freq_target, fs, n_cycles=5)[0, :]
phase2_wav = compute_wavelet_phase(eeg2, freq_target, fs, n_cycles=5)[0, :]

# Compute instantaneous phase difference
phase_diff = np.angle(np.exp(1j * (phase1_wav - phase2_wav)))

# Compute time-resolved PLV (sliding window)
window_size = int(0.5 * fs)  # 500 ms window
step_size = int(0.05 * fs)  # 50 ms step

plv_times = []
plv_values = []

for start in range(0, len(t) - window_size + 1, step_size):
    end = start + window_size
    window_diff = phase_diff[start:end]
    
    # PLV = magnitude of mean phase difference vector
    plv = np.abs(np.mean(np.exp(1j * window_diff)))
    
    center_time = t[start + window_size // 2]
    plv_times.append(center_time)
    plv_values.append(plv)

plv_times = np.array(plv_times)
plv_values = np.array(plv_values)

fig, axes = plt.subplots(4, 1, figsize=(12, 10), sharex=True)

# EEG signals
axes[0].plot(t, eeg1, color=PRIMARY_BLUE, linewidth=0.8, label='Person 1')
axes[0].plot(t, eeg2, color=PRIMARY_RED, linewidth=0.8, alpha=0.7, label='Person 2')
axes[0].set_ylabel('Amplitude')
axes[0].set_title('Simulated EEG from Two Participants', fontsize=11)
axes[0].legend(loc='upper right')
axes[0].grid(True, alpha=0.3)

# Mark periods (refresh the legend afterwards so the period labels are included)
for ax in axes:
    ax.axvspan(0, 1, alpha=0.1, color=PRIMARY_RED, label='Desync' if ax is axes[0] else None)
    ax.axvspan(1, 2, alpha=0.1, color=SECONDARY_ORANGE, label='Transition' if ax is axes[0] else None)
    ax.axvspan(2, 4, alpha=0.1, color=PRIMARY_GREEN, label='Sync' if ax is axes[0] else None)
axes[0].legend(loc='upper right', fontsize=8)

# Phase of each participant
axes[1].plot(t, phase1_wav, color=PRIMARY_BLUE, linewidth=1, label='P1 phase')
axes[1].plot(t, phase2_wav, color=PRIMARY_RED, linewidth=1, alpha=0.7, label='P2 phase')
axes[1].set_ylabel('Phase (rad)')
axes[1].set_title('Wavelet-Extracted Phase at 10 Hz', fontsize=11)
axes[1].set_yticks([-np.pi, 0, np.pi])
axes[1].set_yticklabels(['-π', '0', 'π'])
axes[1].legend(loc='upper right')
axes[1].grid(True, alpha=0.3)

# Phase difference
axes[2].plot(t, phase_diff, color=SECONDARY_PURPLE, linewidth=0.8)
axes[2].axhline(y=0, color='gray', linestyle='--', alpha=0.5)
axes[2].set_ylabel('Phase diff (rad)')
axes[2].set_title('Phase Difference (P1 - P2)', fontsize=11)
axes[2].set_yticks([-np.pi, 0, np.pi])
axes[2].set_yticklabels(['-π', '0', 'π'])
axes[2].grid(True, alpha=0.3)

# Time-resolved PLV
axes[3].plot(plv_times, plv_values, color=PRIMARY_GREEN, linewidth=2)
axes[3].fill_between(plv_times, 0, plv_values, alpha=0.3, color=PRIMARY_GREEN)
axes[3].axhline(y=0.5, color='gray', linestyle='--', alpha=0.5)
axes[3].set_xlabel('Time (s)')
axes[3].set_ylabel('PLV')
axes[3].set_title('Time-Resolved Phase Locking Value', fontsize=11)
axes[3].set_ylim([0, 1])
axes[3].grid(True, alpha=0.3)

fig.suptitle('Wavelets Enable Time-Resolved Inter-Brain Synchronization Analysis', 
             fontsize=13, fontweight='bold')
plt.tight_layout()
plt.show()

print("👥 This preview shows how wavelets enable:")
print("   - Time-resolved tracking of inter-brain synchronization")
print("   - Estimating when participants synchronize (resolution limited by the 500 ms PLV window)")
print("   - A foundation for hyperscanning connectivity metrics")
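
The sliding-window PLV loop above can also be written without an explicit Python loop. The following is a sketch that assumes NumPy ≥ 1.20, which provides `numpy.lib.stride_tricks.sliding_window_view`. `sliding_plv` is a hypothetical helper. It includes every full window, so it may return one more window than a `range(0, len(t) - window_size, step_size)` loop:

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def sliding_plv(phase_diff, fs, win_s=0.5, step_s=0.05):
    """Time-resolved PLV over sliding windows of a phase-difference series."""
    win, step = int(win_s * fs), int(step_s * fs)
    unit_vectors = np.exp(1j * phase_diff)                     # phase differences on the unit circle
    windows = sliding_window_view(unit_vectors, win)[::step]   # (n_windows, win) view, no copy
    plv = np.abs(windows.mean(axis=1))                         # length of the mean vector per window
    centers = (np.arange(len(plv)) * step + win // 2) / fs     # window centers in seconds
    return centers, plv

fs = 256
t = np.arange(0, 4, 1 / fs)
centers, plv_const = sliding_plv(np.full(t.size, np.pi), fs)   # constant anti-phase offset
_, plv_drift = sliding_plv(2 * np.pi * 10 * t, fs)             # offset completing 10 turns per second
print(plv_const.min(), plv_drift.max())
```

The constant π offset gives PLV = 1 because anti-phase is still perfectly phase-locked. In the second case the offset completes whole turns within each 500 ms window, so PLV drops to essentially 0.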

## 14. Exercises 📝

Now it's your turn to practice! Complete the following exercises to solidify your understanding.

### Exercise 1: STFT Parameter Exploration 🔍

Create a signal whose 15 Hz component appears only in the second half, and analyze it with three different STFT window sizes.
Compare the spectrograms and explain the trade-offs.

In [None]:
# Exercise 1: STFT parameter exploration
# TODO: Create a 2-second signal with 15 Hz oscillation that appears only in second half
# TODO: Compute STFT with nperseg = 64, 256, 512
# TODO: Plot the three spectrograms and compare

# Your code here:
# ---------------

# Create signal
duration_ex1 = 2.0
t_ex1 = np.linspace(0, duration_ex1, int(fs * duration_ex1), endpoint=False)

# 15 Hz appears only from t=1.0 to t=2.0
signal_ex1 = np.zeros_like(t_ex1)
mask_ex1 = t_ex1 >= 1.0
signal_ex1[mask_ex1] = np.sin(2 * np.pi * 15 * t_ex1[mask_ex1])
signal_ex1 += np.random.randn(len(t_ex1)) * 0.1

# TODO: Compute and plot spectrograms...

In [None]:
# Solution Exercise 1

window_sizes_ex1 = [64, 256, 512]

fig, axes = plt.subplots(1, 3, figsize=(14, 4))

for idx, nperseg in enumerate(window_sizes_ex1):
    f_ex1, t_ex1_spec, Sxx_ex1 = spectrogram(signal_ex1, fs=fs, nperseg=nperseg, noverlap=nperseg//2)
    
    freq_mask_ex1 = f_ex1 <= 40
    
    ax = axes[idx]
    im = ax.pcolormesh(t_ex1_spec, f_ex1[freq_mask_ex1], 
                       10 * np.log10(Sxx_ex1[freq_mask_ex1] + 1e-10),
                       shading='gouraud', cmap='viridis')
    
    ax.axhline(y=15, color='white', linestyle='--', alpha=0.7)
    ax.axvline(x=1.0, color='white', linestyle=':', alpha=0.7)
    
    time_res = nperseg / fs
    freq_res = fs / nperseg
    ax.set_title(f'nperseg={nperseg}\nΔt={time_res*1000:.0f}ms, Δf={freq_res:.1f}Hz', fontsize=10)
    ax.set_xlabel('Time (s)')
    if idx == 0:
        ax.set_ylabel('Frequency (Hz)')

fig.suptitle('Exercise 1: STFT Window Size Comparison', fontsize=12, fontweight='bold')
plt.tight_layout()
plt.show()

print("✅ Exercise 1 Solution:")
print(f"   - Small window (64): good time resolution, blurry frequency (Δf={fs/64:.1f} Hz)")
print(f"   - Medium window (256): balanced (Δf={fs/256:.1f} Hz)")
print(f"   - Large window (512): sharp frequency (Δf={fs/512:.1f} Hz), poor time resolution")
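
The Δt and Δf values in the panel titles follow from Δt = nperseg / fs and Δf = fs / nperseg. Their product is therefore always 1, so shrinking one necessarily grows the other. Here Δf is the FFT bin spacing; the effective bandwidth also depends on the taper. The small sketch below tabulates both quantities. `fs_demo = 256` is an assumed sampling rate for illustration; at that rate nperseg = 64 gives Δf = 4 Hz:

```python
fs_demo = 256.0  # assumed sampling rate (Hz) for this illustration

# nperseg -> (window length in s, frequency-bin spacing in Hz)
resolution = {n: (n / fs_demo, fs_demo / n) for n in (64, 256, 512)}

for n, (dt, df) in resolution.items():
    print(f"nperseg={n:4d}: Δt = {dt * 1000:6.1f} ms, Δf = {df:5.2f} Hz, Δt·Δf = {dt * df:.1f}")
```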

### Exercise 2: Create Your Own Morlet Wavelet 🌊

Write code to create a Morlet wavelet at 20 Hz with 7 cycles.
Visualize its real part, imaginary part, and envelope.
Calculate its duration and frequency resolution.

In [None]:
# Exercise 2: Create Morlet wavelet
# TODO: Use create_morlet_wavelet() to create a 20 Hz wavelet with 7 cycles
# TODO: Plot real, imaginary, and envelope
# TODO: Calculate sigma_t and estimate frequency resolution

# Your code here:
# ---------------

freq_ex2 = 20  # Hz
n_cycles_ex2 = 7

# TODO: Create wavelet and plot...

In [None]:
# Solution Exercise 2

wavelet_ex2, time_ex2 = create_morlet_wavelet(freq_ex2, fs, n_cycles=n_cycles_ex2, return_time=True)

# Calculate parameters
sigma_t_ex2 = n_cycles_ex2 / (2 * np.pi * freq_ex2)
sigma_f_ex2 = 1 / (2 * np.pi * sigma_t_ex2)  # Frequency resolution
fwhm_f_ex2 = 2.355 * sigma_f_ex2  # Full width at half maximum

fig, axes = plt.subplots(1, 3, figsize=(14, 4))

# Real part
axes[0].plot(time_ex2, np.real(wavelet_ex2), color=PRIMARY_BLUE, linewidth=2)
axes[0].fill_between(time_ex2, 0, np.real(wavelet_ex2), alpha=0.3, color=PRIMARY_BLUE)
axes[0].set_title('Real Part (Cosine)', fontsize=11)
axes[0].set_xlabel('Time (s)')
axes[0].set_ylabel('Amplitude')
axes[0].grid(True, alpha=0.3)

# Imaginary part
axes[1].plot(time_ex2, np.imag(wavelet_ex2), color=PRIMARY_RED, linewidth=2)
axes[1].fill_between(time_ex2, 0, np.imag(wavelet_ex2), alpha=0.3, color=PRIMARY_RED)
axes[1].set_title('Imaginary Part (Sine)', fontsize=11)
axes[1].set_xlabel('Time (s)')
axes[1].grid(True, alpha=0.3)

# Envelope
axes[2].plot(time_ex2, np.abs(wavelet_ex2), color=PRIMARY_GREEN, linewidth=2)
axes[2].fill_between(time_ex2, 0, np.abs(wavelet_ex2), alpha=0.3, color=PRIMARY_GREEN)
axes[2].axvline(x=-sigma_t_ex2, color='gray', linestyle='--', alpha=0.7)
axes[2].axvline(x=sigma_t_ex2, color='gray', linestyle='--', alpha=0.7, label='±σ_t')
axes[2].set_title('Envelope (Gaussian)', fontsize=11)
axes[2].set_xlabel('Time (s)')
axes[2].legend()
axes[2].grid(True, alpha=0.3)

fig.suptitle(f'Exercise 2: Morlet Wavelet at {freq_ex2} Hz, {n_cycles_ex2} cycles', 
             fontsize=12, fontweight='bold')
plt.tight_layout()
plt.show()

print("✅ Exercise 2 Solution:")
print(f"   - Wavelet duration: {time_ex2[-1] - time_ex2[0]:.3f} s")
print(f"   - Temporal resolution (σ_t): {sigma_t_ex2*1000:.1f} ms")
print(f"   - Frequency resolution (σ_f): {sigma_f_ex2:.2f} Hz")
print(f"   - FWHM in frequency: {fwhm_f_ex2:.2f} Hz")
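
The frequency-domain numbers rest on the Gaussian envelope. A Gaussian with standard deviation σ_t in time has a Gaussian amplitude spectrum with σ_f = 1/(2πσ_t) = f/n_cycles. The FWHM of that spectrum is 2√(2 ln 2)·σ_f ≈ 2.355·σ_f. The sketch below checks this numerically from the zero-padded FFT of a Morlet built directly from that definition. It does not use `create_morlet_wavelet()`, whose normalization may differ, but the FWHM does not depend on normalization. `morlet_spectral_fwhm` is a hypothetical helper:

```python
import numpy as np

def morlet_spectral_fwhm(freq, n_cycles, fs=1000.0, pad_factor=16):
    """Measure the FWHM (Hz) of a complex Morlet's amplitude spectrum via a zero-padded FFT."""
    sigma_t = n_cycles / (2 * np.pi * freq)
    t = np.arange(-5 * sigma_t, 5 * sigma_t, 1 / fs)
    psi = np.exp(-t**2 / (2 * sigma_t**2)) * np.exp(2j * np.pi * freq * t)
    n_fft = pad_factor * t.size                    # zero-padding gives a finer frequency grid
    amplitude = np.abs(np.fft.fft(psi, n=n_fft))
    f_axis = np.fft.fftfreq(n_fft, d=1 / fs)
    above_half = f_axis[amplitude >= amplitude.max() / 2]
    return above_half.max() - above_half.min()

freq, n_cycles = 20.0, 7
predicted = 2 * np.sqrt(2 * np.log(2)) * freq / n_cycles   # 2.355 * sigma_f, sigma_f = f / n_cycles
measured = morlet_spectral_fwhm(freq, n_cycles)
print(f"predicted FWHM = {predicted:.2f} Hz, measured = {measured:.2f} Hz")
```

The measured value is slightly quantized by the FFT grid (about 0.1 Hz spacing here).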

### Exercise 3: Compare Wavelet vs Hilbert Phase 📐

Create a 10 Hz sine wave with a phase jump at the midpoint.
Extract phase using both wavelet transform and Hilbert transform.
Compare how each method handles the phase discontinuity.

In [None]:
# Exercise 3: Wavelet vs Hilbert phase comparison
# TODO: Create a 10 Hz signal with a π/2 phase jump at t=0.5
# TODO: Extract phase using wavelet (at 10 Hz)
# TODO: Extract phase using Hilbert
# TODO: Plot and compare

# Your code here:
# ---------------

duration_ex3 = 1.0
t_ex3 = np.linspace(0, duration_ex3, int(fs * duration_ex3), endpoint=False)

# Signal with phase jump
phase_ex3 = 2 * np.pi * 10 * t_ex3
phase_ex3[t_ex3 >= 0.5] += np.pi / 2  # Add π/2 phase jump at midpoint
signal_ex3 = np.sin(phase_ex3)

# TODO: Extract phases and compare...

In [None]:
# Solution Exercise 3

# Extract phases
wavelet_phase_ex3 = compute_wavelet_phase(signal_ex3, np.array([10.0]), fs, n_cycles=5)[0, :]
hilbert_phase_ex3 = np.angle(hilbert(signal_ex3))

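# Reference phase: hilbert() (and convolution with a complex Morlet e^{i2πft}) returns
# cosine-referenced phase, so for sin(φ) = cos(φ - π/2) the expected phase is φ - π/2,
# wrapped to [-π, π]
true_phase_ex3 = np.angle(np.exp(1j * (phase_ex3 - np.pi / 2)))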

fig, axes = plt.subplots(3, 1, figsize=(12, 7), sharex=True)

# Signal
axes[0].plot(t_ex3, signal_ex3, color=PRIMARY_BLUE, linewidth=1.5)
axes[0].axvline(x=0.5, color=PRIMARY_RED, linestyle='--', linewidth=2, label='Phase jump')
axes[0].set_ylabel('Amplitude')
axes[0].set_title('Signal with π/2 Phase Jump at t=0.5s', fontsize=11)
axes[0].legend(loc='upper right')
axes[0].grid(True, alpha=0.3)

# Phase comparison
axes[1].plot(t_ex3, true_phase_ex3, color='gray', linewidth=2, alpha=0.5, label='True phase')
axes[1].plot(t_ex3, wavelet_phase_ex3, color=PRIMARY_RED, linewidth=1.5, 
             linestyle='--', label='Wavelet')
axes[1].plot(t_ex3, hilbert_phase_ex3, color=PRIMARY_GREEN, linewidth=1.5, 
             linestyle=':', label='Hilbert')
axes[1].axvline(x=0.5, color='gray', linestyle='--', alpha=0.5)
axes[1].set_ylabel('Phase (rad)')
axes[1].set_title('Phase Extraction Comparison', fontsize=11)
axes[1].set_yticks([-np.pi, -np.pi/2, 0, np.pi/2, np.pi])
axes[1].set_yticklabels(['-π', '-π/2', '0', 'π/2', 'π'])
axes[1].legend(loc='upper right')
axes[1].grid(True, alpha=0.3)

# Zoom around phase jump
axes[2].plot(t_ex3, wavelet_phase_ex3, color=PRIMARY_RED, linewidth=2, label='Wavelet')
axes[2].plot(t_ex3, hilbert_phase_ex3, color=PRIMARY_GREEN, linewidth=2, 
             linestyle='--', label='Hilbert')
axes[2].axvline(x=0.5, color='gray', linestyle='--', alpha=0.5)
axes[2].set_xlim([0.4, 0.6])
axes[2].set_xlabel('Time (s)')
axes[2].set_ylabel('Phase (rad)')
axes[2].set_title('Zoom: Phase Around Jump', fontsize=11)
axes[2].legend(loc='upper right')
axes[2].grid(True, alpha=0.3)

fig.suptitle('Exercise 3: Wavelet vs Hilbert Phase at Discontinuity', 
             fontsize=12, fontweight='bold')
plt.tight_layout()
plt.show()

print("✅ Exercise 3 Solution:")
print("   - Both methods follow the π/2 phase jump away from the discontinuity")
print("   - Hilbert (applied to the unfiltered signal) changes phase more abruptly at the jump")
print("   - The wavelet phase transitions gradually, over roughly the wavelet's duration")
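
When you compare an estimate against a ground-truth phase, check the phase convention. The analytic signal from `scipy.signal.hilbert` measures phase relative to a cosine. For sin(φ) = cos(φ - π/2), it therefore returns φ - π/2, not φ. The quick numerical check below uses a 10 Hz sine. The sampling rate fs = 256 Hz is an assumption; with it, 1 s holds exactly 10 cycles, so the FFT-based Hilbert transform has no edge error:

```python
import numpy as np
from scipy.signal import hilbert

fs = 256.0
t = np.arange(0, 1, 1 / fs)
phi = 2 * np.pi * 10 * t                        # argument of the sine
est = np.angle(hilbert(np.sin(phi)))            # analytic-signal phase

errors = {}
for label, ref in [("phi", phi), ("phi - pi/2", phi - np.pi / 2)]:
    wrapped = np.angle(np.exp(1j * (est - ref)))   # difference wrapped to [-pi, pi]
    errors[label] = np.abs(wrapped).max()
    print(f"max |error| vs {label}: {errors[label]:.3f} rad")
```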

### Exercise 4: Edge Effects Analysis ⚠️

Analyze how many samples should be excluded at the edges for:
- 4 Hz (delta) wavelet
- 10 Hz (alpha) wavelet  
- 30 Hz (beta) wavelet

All with n_cycles=5. Visualize the "valid" region for each.

In [None]:
# Exercise 4: Edge effects analysis
# TODO: Calculate edge samples for 4, 10, and 30 Hz wavelets
# TODO: Visualize the valid region for a 2-second signal

# Your code here:
# ---------------

frequencies_ex4 = [4, 10, 30]
n_cycles_ex4 = 5
duration_ex4 = 2.0

# TODO: Use compute_edge_samples() and visualize...

In [None]:
# Solution Exercise 4

fig, ax = plt.subplots(figsize=(12, 4))

n_samples_ex4 = int(fs * duration_ex4)
colors_ex4 = [PRIMARY_BLUE, PRIMARY_GREEN, PRIMARY_RED]

for idx, (freq, color) in enumerate(zip(frequencies_ex4, colors_ex4)):
    n_edge = compute_edge_samples(freq, fs, n_cycles=n_cycles_ex4)
    edge_time = n_edge / fs
    
    # Draw bar showing valid region
    y_pos = idx
    ax.barh(y_pos, duration_ex4, height=0.6, color='lightgray', alpha=0.5)
    ax.barh(y_pos, duration_ex4 - 2 * edge_time, left=edge_time, 
            height=0.6, color=color, alpha=0.7, label=f'{freq} Hz: {edge_time*1000:.0f} ms edges')
    
    # Mark edge regions
    ax.axvline(x=edge_time, color=color, linestyle='--', alpha=0.7)
    ax.axvline(x=duration_ex4 - edge_time, color=color, linestyle='--', alpha=0.7)
    
    # Annotate
    valid_duration = duration_ex4 - 2 * edge_time
    ax.text(duration_ex4 / 2, y_pos, f'Valid: {valid_duration:.2f}s ({valid_duration/duration_ex4*100:.1f}%)',
            ha='center', va='center', fontsize=10, fontweight='bold')

ax.set_yticks([0, 1, 2])
ax.set_yticklabels([f'{f} Hz' for f in frequencies_ex4])
ax.set_xlabel('Time (s)')
ax.set_title(f'Valid Analysis Region for 2s Signal (n_cycles={n_cycles_ex4})', fontsize=12, fontweight='bold')
ax.legend(loc='upper right')
ax.set_xlim([0, duration_ex4])
ax.grid(True, alpha=0.3, axis='x')

plt.tight_layout()
plt.show()

print("✅ Exercise 4 Solution:")
for freq in frequencies_ex4:
    n_edge = compute_edge_samples(freq, fs, n_cycles=n_cycles_ex4)
    edge_time = n_edge / fs
    valid_time = duration_ex4 - 2 * edge_time
    print(f"   {freq:2d} Hz: {n_edge:3d} edge samples ({edge_time*1000:.0f}ms), "
          f"valid region: {valid_time:.2f}s ({valid_time/duration_ex4*100:.1f}%)")
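
The edge rule used here is N_edge = n_σ · σ_t · f_s, with σ_t = n_cycles / (2πf). The standalone sketch below implements that arithmetic. It assumes n_σ = 3, rounding up, and fs = 256 Hz. `edge_samples` is a hypothetical stand-in; the notebook's `compute_edge_samples()` may use a different n_σ or rounding:

```python
import numpy as np

def edge_samples(freq, fs, n_cycles=5, n_sigma=3.0):
    """Samples at each edge lying within n_sigma temporal standard deviations of the wavelet."""
    sigma_t = n_cycles / (2 * np.pi * freq)
    return int(np.ceil(n_sigma * sigma_t * fs))

fs_demo, duration = 256, 2.0          # assumed sampling rate (Hz) and signal length (s)
n_total = int(fs_demo * duration)
for freq in (4, 10, 30):
    n_edge = edge_samples(freq, fs_demo)
    valid_fraction = max(n_total - 2 * n_edge, 0) / n_total
    print(f"{freq:2d} Hz: {n_edge:3d} samples per edge, {valid_fraction:.0%} of {duration:.0f} s usable")
```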

### Exercise 5: Event-Related Power Analysis 🧠

Create a simulated EEG with:
- Continuous 10 Hz (alpha) oscillation
- Beta (25 Hz) burst appearing 200-400 ms after a simulated event

Compute the time-frequency representation and identify the event-related modulations.

In [None]:
# Exercise 5: Event-related power analysis
# TODO: Create epoch from -0.5s to 1s (event at t=0)
# TODO: Add continuous alpha (10 Hz)
# TODO: Add beta burst (25 Hz) from 0.2-0.4s
# TODO: Compute time-frequency power and baseline-normalize

# Your code here:
# ---------------

np.random.seed(123)
t_ex5 = np.linspace(-0.5, 1.0, int(fs * 1.5), endpoint=False)

# TODO: Build signal and analyze...

In [None]:
# Solution Exercise 5

# Create signal
eeg_ex5 = np.zeros_like(t_ex5)

# Continuous alpha (10 Hz)
eeg_ex5 += 1.0 * np.sin(2 * np.pi * 10 * t_ex5)

# Beta burst (25 Hz): Gaussian envelope centered at 0.3 s (σ = 50 ms), mostly within 0.2-0.4 s
beta_envelope_ex5 = np.exp(-((t_ex5 - 0.3)**2) / (2 * 0.05**2))
eeg_ex5 += 0.8 * beta_envelope_ex5 * np.sin(2 * np.pi * 25 * t_ex5)

# Add noise
eeg_ex5 += np.random.randn(len(t_ex5)) * 0.3

# Compute time-frequency power
freqs_ex5 = np.arange(5, 40, 0.5)
power_ex5 = compute_wavelet_power(eeg_ex5, freqs_ex5, fs, 
                                   n_cycles=compute_adaptive_cycles(freqs_ex5, 3, 7))

# Baseline normalize
baseline_mask_ex5 = (t_ex5 >= -0.4) & (t_ex5 <= -0.1)
baseline_power_ex5 = power_ex5[:, baseline_mask_ex5].mean(axis=1, keepdims=True)
power_norm_ex5 = (power_ex5 - baseline_power_ex5) / baseline_power_ex5 * 100

fig, axes = plt.subplots(2, 1, figsize=(12, 6))

# Signal
axes[0].plot(t_ex5, eeg_ex5, color=PRIMARY_BLUE, linewidth=0.8)
axes[0].axvline(x=0, color=PRIMARY_RED, linestyle='--', linewidth=2, label='Event')
axes[0].axvspan(-0.4, -0.1, alpha=0.2, color='gray', label='Baseline')
axes[0].axvspan(0.2, 0.4, alpha=0.2, color=SECONDARY_ORANGE, label='Beta burst')
axes[0].set_ylabel('Amplitude')
axes[0].set_title('Simulated EEG Epoch', fontsize=11)
axes[0].legend(loc='upper right')
axes[0].grid(True, alpha=0.3)

# Time-frequency
im = axes[1].pcolormesh(t_ex5, freqs_ex5, power_norm_ex5,
                        shading='gouraud', cmap='RdBu_r', vmin=-100, vmax=100)
axes[1].axvline(x=0, color='black', linestyle='--', linewidth=2)
axes[1].axhline(y=10, color='white', linestyle=':', alpha=0.5)
axes[1].axhline(y=25, color='white', linestyle=':', alpha=0.5)
axes[1].set_xlabel('Time (s)')
axes[1].set_ylabel('Frequency (Hz)')
axes[1].set_title('Event-Related Power Change (% from baseline)', fontsize=11)
plt.colorbar(im, ax=axes[1], label='% change')

# Annotate
axes[1].annotate('Beta ERS', xy=(0.3, 25), fontsize=10, color='white',
                 ha='center', va='center', fontweight='bold',
                 bbox=dict(boxstyle='round', facecolor='black', alpha=0.5))

fig.suptitle('Exercise 5: Event-Related Time-Frequency Analysis', fontsize=12, fontweight='bold')
plt.tight_layout()
plt.show()

print("✅ Exercise 5 Solution:")
print("   - Continuous alpha shows ~0% change: it is present in both baseline and post-event periods")
print("   - Beta ERS clearly visible at 25 Hz, 200-400 ms post-event")
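
Percent change is one of several common baseline normalizations. Two others are decibels (10·log10 of the power ratio) and z-scoring against the baseline distribution; which one to use is partly a matter of convention. The sketch below implements all three for a (frequencies × times) power array. `baseline_normalize` and its `mode` names are hypothetical and are not part of the notebook's API:

```python
import numpy as np

def baseline_normalize(power, times, baseline=(-0.4, -0.1), mode="percent"):
    """Normalize a (n_freqs, n_times) power array relative to a baseline time window."""
    mask = (times >= baseline[0]) & (times <= baseline[1])
    base = power[:, mask]
    mean = base.mean(axis=1, keepdims=True)
    if mode == "percent":
        return (power - mean) / mean * 100.0
    if mode == "db":
        return 10.0 * np.log10(power / mean)
    if mode == "zscore":
        return (power - mean) / base.std(axis=1, keepdims=True)
    raise ValueError(f"unknown mode: {mode!r}")

times = np.linspace(-0.5, 1.0, 300, endpoint=False)
power = np.ones((2, times.size))
power[1, times > 0.2] = 2.0                 # row 1: power doubles after 0.2 s
pct = baseline_normalize(power, times, mode="percent")
db = baseline_normalize(power, times, mode="db")
print(pct[1, -1], db[1, -1])                # doubling: +100 % and about +3.01 dB
```

Decibel values treat increases and decreases symmetrically: halving power gives about -3 dB, whereas in percent terms it is -50 %. Z-scoring requires a baseline with nonzero variance.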

### Exercise 6: Time-Resolved PLV Calculation 👥

Create two simulated EEG signals that start desynchronized and become synchronized.
Use wavelet-based phase extraction to compute time-resolved PLV.
Identify the moment when synchronization emerges.

In [None]:
# Exercise 6: Time-resolved PLV
# TODO: Create two 10 Hz signals
# TODO: Signal 2 has random phase offset for first 2s, then synchronized
# TODO: Extract phases with wavelets
# TODO: Compute time-resolved PLV with 500ms sliding window

# Your code here:
# ---------------

np.random.seed(42)
duration_ex6 = 4.0
t_ex6 = np.linspace(0, duration_ex6, int(fs * duration_ex6), endpoint=False)

# TODO: Create signals and compute PLV...

In [None]:
# Solution Exercise 6

# Create two signals
freq_ex6 = 10  # Hz

# Signal 1: Clean 10 Hz
eeg1_ex6 = np.sin(2 * np.pi * freq_ex6 * t_ex6) + np.random.randn(len(t_ex6)) * 0.2

# Signal 2: Random phase offset first 2s, then synchronized
phase_offset_ex6 = np.zeros_like(t_ex6)
# Random walk phase offset for first 2s
random_phase = np.cumsum(np.random.randn(sum(t_ex6 < 2)) * 0.1)
random_phase = random_phase - random_phase.mean()  # Center around 0
phase_offset_ex6[t_ex6 < 2] = random_phase

eeg2_ex6 = np.sin(2 * np.pi * freq_ex6 * t_ex6 + phase_offset_ex6) + np.random.randn(len(t_ex6)) * 0.2

# Extract phases using wavelet
phase1_ex6 = compute_wavelet_phase(eeg1_ex6, np.array([freq_ex6]), fs, n_cycles=5)[0, :]
phase2_ex6 = compute_wavelet_phase(eeg2_ex6, np.array([freq_ex6]), fs, n_cycles=5)[0, :]

# Compute phase difference
phase_diff_ex6 = np.angle(np.exp(1j * (phase1_ex6 - phase2_ex6)))

# Compute time-resolved PLV
window_size_ex6 = int(0.5 * fs)  # 500 ms
step_size_ex6 = int(0.05 * fs)  # 50 ms

plv_times_ex6 = []
plv_values_ex6 = []

for start in range(0, len(t_ex6) - window_size_ex6, step_size_ex6):
    end = start + window_size_ex6
    window_diff = phase_diff_ex6[start:end]
    
    # PLV
    plv = np.abs(np.mean(np.exp(1j * window_diff)))
    
    center_time = t_ex6[start + window_size_ex6 // 2]
    plv_times_ex6.append(center_time)
    plv_values_ex6.append(plv)

plv_times_ex6 = np.array(plv_times_ex6)
plv_values_ex6 = np.array(plv_values_ex6)

# Find synchronization threshold crossing
sync_threshold = 0.7
sync_onset_idx = np.where(plv_values_ex6 > sync_threshold)[0]
if len(sync_onset_idx) > 0:
    sync_onset_time = plv_times_ex6[sync_onset_idx[0]]
else:
    sync_onset_time = None

fig, axes = plt.subplots(3, 1, figsize=(12, 8), sharex=True)

# Signals
axes[0].plot(t_ex6, eeg1_ex6, color=PRIMARY_BLUE, linewidth=0.8, label='Person 1')
axes[0].plot(t_ex6, eeg2_ex6 - 3, color=PRIMARY_RED, linewidth=0.8, label='Person 2 (offset)')
axes[0].axvline(x=2.0, color='gray', linestyle='--', alpha=0.7, label='Sync start')
axes[0].set_ylabel('Amplitude')
axes[0].set_title('Two Simulated EEG Signals', fontsize=11)
axes[0].legend(loc='upper right')
axes[0].grid(True, alpha=0.3)

# Phase difference
axes[1].plot(t_ex6, phase_diff_ex6, color=SECONDARY_PURPLE, linewidth=0.8)
axes[1].axvline(x=2.0, color='gray', linestyle='--', alpha=0.7)
axes[1].axhline(y=0, color='gray', linestyle='-', alpha=0.3)
axes[1].set_ylabel('Phase diff (rad)')
axes[1].set_title('Phase Difference (Wavelet-Extracted)', fontsize=11)
axes[1].set_yticks([-np.pi, 0, np.pi])
axes[1].set_yticklabels(['-π', '0', 'π'])
axes[1].grid(True, alpha=0.3)

# PLV
axes[2].plot(plv_times_ex6, plv_values_ex6, color=PRIMARY_GREEN, linewidth=2)
axes[2].fill_between(plv_times_ex6, 0, plv_values_ex6, alpha=0.3, color=PRIMARY_GREEN)
axes[2].axhline(y=sync_threshold, color=SECONDARY_ORANGE, linestyle='--', 
                label=f'Threshold ({sync_threshold})')
axes[2].axvline(x=2.0, color='gray', linestyle='--', alpha=0.7)
if sync_onset_time is not None:
    axes[2].axvline(x=sync_onset_time, color=PRIMARY_RED, linestyle='-', linewidth=2,
                    label=f'Sync detected: {sync_onset_time:.2f}s')
axes[2].set_xlabel('Time (s)')
axes[2].set_ylabel('PLV')
axes[2].set_title('Time-Resolved Phase Locking Value', fontsize=11)
axes[2].set_ylim([0, 1])
axes[2].legend(loc='upper left')
axes[2].grid(True, alpha=0.3)

fig.suptitle('Exercise 6: Time-Resolved PLV Analysis', fontsize=12, fontweight='bold')
plt.tight_layout()
plt.show()

print("✅ Exercise 6 Solution:")
print("   - Synchronization ground truth: t = 2.0 s")
if sync_onset_time is not None:
    print(f"   - Synchronization detected (PLV > {sync_threshold}): t = {sync_onset_time:.2f} s")
print("   - PLV separates the desynchronized and synchronized periods")
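
Taking the first window above threshold is sensitive to a single chance excursion during the desynchronized period. One common refinement is to require PLV to stay above threshold for a minimum number of consecutive windows. The sketch below implements that rule; `sustained_onset` and `min_windows` are hypothetical names:

```python
import numpy as np

def sustained_onset(times, values, threshold, min_windows=5):
    """First time at which `values` stays above `threshold` for `min_windows` windows, else None."""
    above = np.asarray(values) > threshold
    run = 0
    for i, is_above in enumerate(above):
        run = run + 1 if is_above else 0
        if run == min_windows:
            return times[i - min_windows + 1]   # start of the sustained run
    return None

times = np.arange(10) * 0.05
plv = np.array([0.2, 0.8, 0.3, 0.4, 0.75, 0.8, 0.9, 0.85, 0.9, 0.95])
print(sustained_onset(times, plv, 0.7, min_windows=3))   # the brief spike at index 1 is ignored
```

The choice of `min_windows` trades detection latency against false alarms. With overlapping windows (e.g., 500 ms windows in 50 ms steps), consecutive PLV values are strongly correlated, so the run length should span more than a few steps.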

## 15. Summary 📋

### What We Learned

In this notebook, we explored **wavelets and time-frequency analysis**, essential tools for understanding dynamic brain activity:

1. **Limitations of FFT**: Standard Fourier analysis assumes stationarity and cannot localize events in time

2. **Short-Time Fourier Transform (STFT)**: 
   - Slides a window along the signal
   - Trade-off: a single window length fixes the same time-frequency resolution at all frequencies
   - Heisenberg-Gabor uncertainty: Δt × Δf ≥ 1/(4π) (Δt and Δf as standard deviations)

3. **Wavelets**: 
   - Compact oscillations localized in time
   - **Multi-resolution**: short wavelets for high frequencies, long for low
   - Ideal for non-stationary EEG signals

4. **Complex Morlet Wavelet**:
   - Gaussian-enveloped complex sinusoid
   - n_cycles parameter controls time-frequency trade-off
   - Provides both power (|W|²) and phase (∠W)

5. **Wavelet Convolution**:
   - Efficient via FFT (convolution theorem)
   - Output at each time-frequency point is complex

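6. **Adaptive n_cycles**: Increase n_cycles with frequency to balance temporal precision at low frequencies against frequency precision at high frequencies

7. **Edge Effects**: Lower frequencies use longer wavelets, so they lose more samples at the boundaries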

8. **Applications**:
   - Event-related power (ERS/ERD)
   - Time-resolved connectivity (PLV)
   - Foundation for hyperscanning analysis
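
The 1/(4π) bound refers to the standard deviations of the energy densities |ψ(t)|² and |Ψ(f)|². A Gaussian-enveloped wavelet attains the bound with equality. The numerical sketch below builds the wavelet directly from the Morlet formula; the sampling rate, frequency, and n_cycles are arbitrary illustrative choices:

```python
import numpy as np

def density_std(axis, density):
    """Standard deviation of `axis` weighted by a non-negative density."""
    p = density / density.sum()
    mean = np.sum(axis * p)
    return np.sqrt(np.sum((axis - mean) ** 2 * p))

fs, freq, n_cycles = 1000.0, 10.0, 7
sigma_t = n_cycles / (2 * np.pi * freq)
t = np.arange(-8 * sigma_t, 8 * sigma_t, 1 / fs)
psi = np.exp(-t**2 / (2 * sigma_t**2)) * np.exp(2j * np.pi * freq * t)

n_fft = 8 * t.size                                      # zero-pad for a finer frequency grid
spectrum = np.fft.fft(psi, n=n_fft)
f_axis = np.fft.fftfreq(n_fft, d=1 / fs)

delta_t = density_std(t, np.abs(psi) ** 2)              # ≈ sigma_t / sqrt(2)
delta_f = density_std(f_axis, np.abs(spectrum) ** 2)    # ≈ 1 / (2 * pi * sigma_t * sqrt(2))
print(delta_t * delta_f, 1 / (4 * np.pi))
```

Changing `n_cycles` moves resolution between time and frequency, but the product stays at 1/(4π).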

### Key Equations

| Concept | Equation |
|---------|----------|
| Morlet wavelet | $\psi(t, f) = A e^{-t^2/(2\sigma_t^2)} e^{i2\pi ft}$ |
| Temporal std | $\sigma_t = n_{cycles}/(2\pi f)$ |
| Edge samples | $N_{edge} = n_\sigma \cdot \sigma_t \cdot f_s$ |
| Wavelet power | $P(t, f) = \lvert W(t, f) \rvert^2$ |
| Wavelet phase | $\phi(t, f) = \angle W(t, f)$ |
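
The table leaves the amplitude A open. One common choice is unit energy, ∫|ψ|² dt = 1. Because ∫e^{-t²/σ_t²} dt = σ_t√π, this gives A = (σ_t√π)^(-1/2). Unit energy is an assumption here; the notebook's `create_morlet_wavelet()` may normalize differently. The sketch below builds ψ from the table's formulas and checks the energy numerically. `morlet` is a hypothetical helper:

```python
import numpy as np

def morlet(freq, fs, n_cycles=7, n_sigma=5.0):
    """Complex Morlet psi(t) = A exp(-t^2 / (2 sigma_t^2)) exp(i 2 pi f t) with unit-energy A."""
    sigma_t = n_cycles / (2 * np.pi * freq)            # temporal std, as in the table
    t = np.arange(-n_sigma * sigma_t, n_sigma * sigma_t, 1 / fs)
    A = (sigma_t * np.sqrt(np.pi)) ** -0.5             # makes the integral of |psi|^2 equal 1
    psi = A * np.exp(-t**2 / (2 * sigma_t**2)) * np.exp(2j * np.pi * freq * t)
    return t, psi

fs = 1000.0
t, psi = morlet(10.0, fs=fs)
energy = np.sum(np.abs(psi) ** 2) / fs                 # Riemann sum with dt = 1 / fs
print(f"energy ≈ {energy:.4f}")
```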

### Key Functions Implemented

| Function | Purpose |
|----------|---------|
| `compute_stft()` | Short-Time Fourier Transform |
| `compute_spectrogram()` | STFT power spectrogram |
| `create_morlet_wavelet()` | Generate complex Morlet wavelet |
| `wavelet_convolution()` | FFT-based wavelet convolution |
| `compute_wavelet_transform()` | Full TFR at multiple frequencies |
| `compute_wavelet_power()` | Time-frequency power with baseline |
| `plot_time_frequency()` | Visualize TFR |
| `compute_adaptive_cycles()` | Frequency-dependent n_cycles |
| `compute_wavelet_phase()` | Extract phase at each frequency |
| `compute_edge_samples()` | Calculate affected edge samples |

## 16. Discussion Questions 💬

1. **Resolution Trade-offs**: Why is the Heisenberg-Gabor uncertainty principle particularly problematic for EEG analysis? How do wavelets help mitigate this issue?

2. **Choosing n_cycles**: A researcher wants to analyze fast gamma oscillations (40-80 Hz) for precise event timing. Should they use high or low n_cycles? What about for slow theta rhythms (4-8 Hz)?

3. **STFT vs Wavelets**: When might STFT be preferred over wavelet analysis? Consider computational cost, interpretability, and the nature of the signal.

4. **Edge Effects in Practice**: You have 30-second EEG epochs and want to analyze theta (6 Hz) oscillations with 5 cycles. How much valid data do you have? What strategies could increase usable data?

5. **Phase vs Power**: A hyperscanning study finds increased inter-brain PLV but no change in power. What does this tell us about the neural interaction?

6. **Baseline Normalization**: Why is baseline normalization important for event-related time-frequency analysis? When might you NOT want to use it?

7. **Computational Considerations**: For a 64-channel EEG with 1-hour recording, estimate the memory needed for wavelet transform from 1-100 Hz (1 Hz resolution). What optimization strategies could help?

---

### Looking Ahead 🔮

This notebook completes Module B on Phase and Amplitude! You now have the tools to:
- Extract power at any time-frequency point
- Track phase dynamics across frequencies
- Analyze event-related modulations

**Next up in Module C**: We'll tackle connectivity concepts including volume conduction, connectivity matrices, and statistical significance testing.

**In Module F-G**: We'll apply wavelets to compute time-resolved hyperscanning metrics like coherence and PLV!