## 1. Introduction: Beyond Correlation

In previous modules, we explored connectivity measures like phase-based metrics (PLV, PLI) and amplitude correlations. These capture **linear** or **periodic** relationships between signals.

But neural relationships can be **nonlinear**!

### A Motivating Example

Consider two signals X and Y, where X is symmetric around zero and Y is large whenever |X| is large, regardless of whether X is positive or negative:

- **Correlation ≈ 0**: there is no *linear* trend, because large positive and large negative X values both go with large Y, and their contributions cancel out
- **But there IS a relationship!** Y clearly depends on X

### Enter Mutual Information

**Mutual Information (MI)** captures **any** statistical dependency between two variables:

> *"How much does knowing X tell us about Y?"*

MI is:
- **General**: Detects linear AND nonlinear relationships
- **Symmetric**: MI(X, Y) = MI(Y, X)
- **Non-negative**: MI ≥ 0, with MI = 0 only when X and Y are independent

The trade-off: MI is more general than correlation, but it is much harder to estimate reliably from a finite number of samples.

> 💡 **Key insight**: MI detects relationships that correlation misses.

## 2. Intuition: Shared Information

Before the math, let's build intuition.

### Thought Experiment: Two Weather Stations

- **Station A** records temperature in Paris
- **Station B** records temperature in Lyon

If you know Paris is at 25°C, you can guess Lyon's temperature better than you could without that information. The two cities share weather patterns!

**Mutual information** quantifies this shared information, i.e. how much uncertainty about one variable is removed by observing the other:
- If X and Y are **independent**: knowing X tells us nothing about Y → **MI = 0**
- If X **determines** Y completely: knowing X removes ALL uncertainty about Y → **MI = H(Y)** (the maximum, since MI can never exceed H(Y))
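The two extreme cases can be checked exactly with discrete variables, because all the entropies come straight from a joint probability table. What follows is a minimal sketch. The helpers `entropy_bits` and `mi_from_joint` are illustrative names of our own and are not part of `src.information`.

```python
import numpy as np

def entropy_bits(p):
    """Shannon entropy (bits) of a probability array; zero entries are skipped."""
    p = np.asarray(p, dtype=float).ravel()
    p = p[p > 0]
    return float(-np.sum(p * np.log2(p)))

def mi_from_joint(p_xy):
    """I(X;Y) = H(X) + H(Y) - H(X,Y) from a joint table (rows: x, columns: y)."""
    return entropy_bits(p_xy.sum(axis=1)) + entropy_bits(p_xy.sum(axis=0)) - entropy_bits(p_xy)

# Case 1: X and Y independent, each uniform over 4 values, so p(x, y) = p(x) p(y)
p_indep = np.outer(np.full(4, 0.25), np.full(4, 0.25))

# Case 2: X uniform over {0, 1, 2, 3} and Y = X mod 2, so X fully determines Y
p_det = np.zeros((4, 2))
for x_val in range(4):
    p_det[x_val, x_val % 2] = 0.25

mi_indep = mi_from_joint(p_indep)
mi_det = mi_from_joint(p_det)
h_y_det = entropy_bits(p_det.sum(axis=0))

print(f"Independent: I(X;Y) = {mi_indep:.3f} bits")
print(f"Y = X mod 2: I(X;Y) = {mi_det:.3f} bits, H(Y) = {h_y_det:.3f} bits")
```

The first case gives 0 bits. In the second, I(X;Y) equals H(Y) = 1 bit even though H(X) = 2 bits, because Y only has one bit of information to share.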

### Key Properties

- MI is **symmetric**: MI(X, Y) = MI(Y, X)
- MI measures "how much information is common to both variables"
- MI = 0 ↔ statistical independence

In [None]:
# Imports
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Circle, Wedge
from matplotlib.collections import PatchCollection
import matplotlib.patches as mpatches
from numpy.typing import NDArray
from typing import Tuple, Optional, Dict, List
from scipy import stats
import sys
sys.path.append("../../..")

from src.colors import COLORS
from src.plotting import configure_plots
from src.information import (
    compute_entropy_discrete,
    compute_entropy_continuous,
    compute_entropy_from_counts,
    optimal_n_bins
)

configure_plots()

In [None]:
# Visualization 1: Different relationships (correlation vs MI)

np.random.seed(42)
n_samples = 500

# Generate different relationships
x_base = np.random.randn(n_samples)

# 1. Independent
y_independent = np.random.randn(n_samples)

# 2. Linear relationship
y_linear = 0.8 * x_base + 0.6 * np.random.randn(n_samples)

# 3. Nonlinear (quadratic): correlation ≈ 0 but dependent!
y_quadratic = x_base**2 + 0.3 * np.random.randn(n_samples)

# Compute correlations
corr_indep = np.corrcoef(x_base, y_independent)[0, 1]
corr_linear = np.corrcoef(x_base, y_linear)[0, 1]
corr_quad = np.corrcoef(x_base, y_quadratic)[0, 1]

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Plot 1: Independent
axes[0].scatter(x_base, y_independent, alpha=0.5, s=20, color=COLORS["signal_1"])
axes[0].set_xlabel("X", fontsize=12)
axes[0].set_ylabel("Y", fontsize=12)
axes[0].set_title(f"Independent\nCorr = {corr_indep:.3f}", fontsize=12, fontweight="bold")
axes[0].text(0.05, 0.95, "MI ≈ 0", transform=axes[0].transAxes, fontsize=11,
             fontweight="bold", va="top", bbox=dict(boxstyle="round", facecolor="white", alpha=0.8))

# Plot 2: Linear
axes[1].scatter(x_base, y_linear, alpha=0.5, s=20, color=COLORS["signal_2"])
axes[1].set_xlabel("X", fontsize=12)
axes[1].set_ylabel("Y", fontsize=12)
axes[1].set_title(f"Linear Relationship\nCorr = {corr_linear:.3f}", fontsize=12, fontweight="bold")
axes[1].text(0.05, 0.95, "MI > 0", transform=axes[1].transAxes, fontsize=11,
             fontweight="bold", va="top", bbox=dict(boxstyle="round", facecolor="white", alpha=0.8))

# Plot 3: Quadratic (nonlinear)
axes[2].scatter(x_base, y_quadratic, alpha=0.5, s=20, color=COLORS["signal_3"])
axes[2].set_xlabel("X", fontsize=12)
axes[2].set_ylabel("Y", fontsize=12)
axes[2].set_title(f"Quadratic (Nonlinear)\nCorr = {corr_quad:.3f}", fontsize=12, fontweight="bold")
axes[2].text(0.05, 0.95, "MI > 0 !", transform=axes[2].transAxes, fontsize=11,
             fontweight="bold", va="top", color="red",
             bbox=dict(boxstyle="round", facecolor="white", alpha=0.8))

plt.suptitle("MI Captures Relationships That Correlation Misses", fontsize=14, fontweight="bold", y=1.02)
plt.tight_layout()
plt.show()

print("\n📊 Key observation:")
print(f"   • Quadratic relationship: Correlation = {corr_quad:.3f} (nearly zero!)")
print("   • But Y clearly depends on X, and MI will detect this!")

## 3. The Entropy Venn Diagram

The relationship between entropy and mutual information is beautifully captured by a **Venn diagram**.

### The Diagram

Imagine two overlapping circles:
- **Circle X**: Total entropy H(X)
- **Circle Y**: Total entropy H(Y)
- **Overlap**: Mutual Information I(X; Y)
- **X only** (left crescent): H(X|Y), the uncertainty about X given Y
- **Y only** (right crescent): H(Y|X), the uncertainty about Y given X
- **Union** (both circles): H(X, Y), the joint entropy

### Key Relationships

$$I(X; Y) = H(X) + H(Y) - H(X, Y)$$

$$I(X; Y) = H(X) - H(X|Y) = H(Y) - H(Y|X)$$

MI = "what's shared" = "uncertainty reduced by knowing the other variable"

In [None]:
# Visualization 2: Entropy Venn Diagram

fig, ax = plt.subplots(figsize=(12, 8))

# Circle parameters
r = 1.5
offset = 0.9

# Draw circles
circle_x = plt.Circle((-offset, 0), r, fill=False, color=COLORS["signal_1"], linewidth=3)
circle_y = plt.Circle((offset, 0), r, fill=False, color=COLORS["signal_2"], linewidth=3)

# Fill regions with alpha
circle_x_fill = plt.Circle((-offset, 0), r, alpha=0.3, color=COLORS["signal_1"])
circle_y_fill = plt.Circle((offset, 0), r, alpha=0.3, color=COLORS["signal_2"])

ax.add_patch(circle_x_fill)
ax.add_patch(circle_y_fill)
ax.add_patch(circle_x)
ax.add_patch(circle_y)

# Labels
ax.text(-offset - 0.9, 0, "H(X|Y)", fontsize=14, fontweight="bold", ha="center", va="center")
ax.text(offset + 0.9, 0, "H(Y|X)", fontsize=14, fontweight="bold", ha="center", va="center")
ax.text(0, 0, "I(X;Y)", fontsize=16, fontweight="bold", ha="center", va="center",
        bbox=dict(boxstyle="round", facecolor="white", alpha=0.9))

# Circle labels
ax.text(-offset, r + 0.3, "H(X)", fontsize=14, fontweight="bold", ha="center", color=COLORS["signal_1"])
ax.text(offset, r + 0.3, "H(Y)", fontsize=14, fontweight="bold", ha="center", color=COLORS["signal_2"])

# Joint entropy brace/label
ax.annotate("", xy=(-offset - r, -r - 0.5), xytext=(offset + r, -r - 0.5),
            arrowprops=dict(arrowstyle="<->", color="black", lw=2))
ax.text(0, -r - 0.8, "H(X, Y) = Joint Entropy", fontsize=12, fontweight="bold", ha="center")

ax.set_xlim(-3.5, 3.5)
ax.set_ylim(-3, 3)
ax.set_aspect("equal")
ax.axis("off")
ax.set_title("The Information Venn Diagram", fontsize=16, fontweight="bold", pad=20)

# Add formulas
formulas = [
    r"$I(X;Y) = H(X) + H(Y) - H(X,Y)$",
    r"$I(X;Y) = H(X) - H(X|Y)$",
    r"$I(X;Y) = H(Y) - H(Y|X)$"
]
for i, formula in enumerate(formulas):
    ax.text(3.2, 1.5 - i * 0.6, formula, fontsize=11, ha="left", va="center")

plt.tight_layout()
plt.show()

print("\n💡 The overlap (I(X;Y)) represents SHARED information.")
print("   More overlap = more mutual information = stronger dependency.")

In [None]:
# Visualization 3: Three cases of dependency

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

cases = [
    ("Independent\nI(X;Y) = 0", 2.5, 0),      # No overlap
    ("Partially Dependent\nI(X;Y) moderate", 1.0, 0.5),  # Some overlap
    ("Fully Dependent\nI(X;Y) = H(Y)", 0, 0.8)   # One inside other
]

for ax, (title, offset, scale_y) in zip(axes, cases):
    r_x = 1.2
    r_y = 1.2 * (1 - scale_y * 0.5) if scale_y > 0 else 1.2
    
    circle_x = plt.Circle((-offset/2, 0), r_x, alpha=0.4, color=COLORS["signal_1"])
    circle_y = plt.Circle((offset/2, 0), r_y, alpha=0.4, color=COLORS["signal_2"])
    circle_x_line = plt.Circle((-offset/2, 0), r_x, fill=False, color=COLORS["signal_1"], linewidth=2)
    circle_y_line = plt.Circle((offset/2, 0), r_y, fill=False, color=COLORS["signal_2"], linewidth=2)
    
    ax.add_patch(circle_x)
    ax.add_patch(circle_y)
    ax.add_patch(circle_x_line)
    ax.add_patch(circle_y_line)
    
    ax.set_xlim(-3, 3)
    ax.set_ylim(-2, 2)
    ax.set_aspect("equal")
    ax.axis("off")
    ax.set_title(title, fontsize=12, fontweight="bold")
    
    # Labels
    ax.text(-offset/2, -1.7, "X", fontsize=12, ha="center", fontweight="bold", color=COLORS["signal_1"])
    ax.text(offset/2 if offset > 0 else 0, -1.7 if offset > 0 else -1.0, "Y", 
            fontsize=12, ha="center", fontweight="bold", color=COLORS["signal_2"])

plt.suptitle("How MI Reflects Dependency", fontsize=14, fontweight="bold", y=1.02)
plt.tight_layout()
plt.show()

## 4. Joint Entropy

**Joint entropy** H(X, Y) measures the uncertainty about the **pair** (X, Y) together.

### Definition

For discrete variables:

$$H(X, Y) = -\sum_{x}\sum_{y} p(x, y) \log p(x, y)$$

Where $p(x, y)$ is the **joint probability distribution**.

### Properties

- **Subadditivity**: $H(X, Y) \leq H(X) + H(Y)$
- Equality holds if and only if X and Y are **independent**
- **Lower bound** (discrete variables): $H(X, Y) \geq \max(H(X), H(Y))$
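Both bounds can be checked numerically on a random discrete joint distribution. In the sketch below, `entropy_bits` is our own helper. The joint entropy of the random table should fall between the two bounds. Replacing the table with the product of its marginals, which is the independent case, should turn the upper bound into an equality.

```python
import numpy as np

def entropy_bits(p):
    """Shannon entropy (bits) of a probability array; zero entries are skipped."""
    p = np.asarray(p, dtype=float).ravel()
    p = p[p > 0]
    return float(-np.sum(p * np.log2(p)))

rng = np.random.default_rng(0)

# A random 5x4 joint distribution p(x, y): rows index x, columns index y
p_joint = rng.random((5, 4))
p_joint /= p_joint.sum()

h_x_rand = entropy_bits(p_joint.sum(axis=1))
h_y_rand = entropy_bits(p_joint.sum(axis=0))
h_xy_rand = entropy_bits(p_joint)
print(f"max(H(X), H(Y)) = {max(h_x_rand, h_y_rand):.3f} <= H(X,Y) = {h_xy_rand:.3f} "
      f"<= H(X) + H(Y) = {h_x_rand + h_y_rand:.3f}")

# Independent version: joint = outer product of the same marginals -> equality
p_prod = np.outer(p_joint.sum(axis=1), p_joint.sum(axis=0))
h_xy_prod = entropy_bits(p_prod)
print(f"Product of marginals: H(X,Y) = {h_xy_prod:.3f}, H(X) + H(Y) = {h_x_rand + h_y_rand:.3f}")
```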

### For Continuous Signals

We need **2D binning**:
1. Create a 2D histogram of (x, y) pairs
2. Normalize to get joint probability
3. Compute entropy of this 2D distribution

In [None]:
def compute_joint_histogram(
    x: NDArray[np.float64],
    y: NDArray[np.float64],
    n_bins: int = 20
) -> Tuple[NDArray[np.float64], NDArray[np.float64], NDArray[np.float64]]:
    """
    Compute 2D histogram for joint distribution.
    
    Parameters
    ----------
    x : NDArray[np.float64]
        First signal.
    y : NDArray[np.float64]
        Second signal.
    n_bins : int, optional
        Number of bins per dimension. Default is 20.
    
    Returns
    -------
    Tuple[NDArray, NDArray, NDArray]
        (histogram_2d, x_edges, y_edges)
    """
    hist_2d, x_edges, y_edges = np.histogram2d(x, y, bins=n_bins)
    return hist_2d, x_edges, y_edges


def compute_joint_entropy(
    x: NDArray[np.float64],
    y: NDArray[np.float64],
    n_bins: int = 20,
    base: float = 2.0
) -> float:
    """
    Compute joint entropy H(X, Y) via 2D binning.
    
    Parameters
    ----------
    x : NDArray[np.float64]
        First signal.
    y : NDArray[np.float64]
        Second signal.
    n_bins : int, optional
        Number of bins per dimension. Default is 20.
    base : float, optional
        Logarithm base. Default is 2 (bits).
    
    Returns
    -------
    float
        Joint entropy H(X, Y).
    """
    # Compute 2D histogram
    hist_2d, _, _ = compute_joint_histogram(x, y, n_bins)
    
    # Normalize to get joint probability
    joint_prob = hist_2d / np.sum(hist_2d)
    
    # Flatten and remove zeros
    p = joint_prob.flatten()
    p = p[p > 0]
    
    # Compute entropy
    if base == np.e:
        entropy = -np.sum(p * np.log(p))
    else:
        entropy = -np.sum(p * np.log(p) / np.log(base))
    
    return float(entropy)

In [None]:
# Visualization 4: Joint distribution heatmap with marginals

np.random.seed(42)

# Generate correlated Gaussian signals
n_samples = 2000
correlation = 0.7
x = np.random.randn(n_samples)
y = correlation * x + np.sqrt(1 - correlation**2) * np.random.randn(n_samples)

n_bins = 30

# Create figure with marginals
fig = plt.figure(figsize=(10, 10))
gs = fig.add_gridspec(3, 3, width_ratios=[0.2, 1, 0.05], height_ratios=[0.2, 1, 0.05],
                      hspace=0.05, wspace=0.05)

# Main 2D histogram
ax_main = fig.add_subplot(gs[1, 1])
hist_2d, x_edges, y_edges, im = ax_main.hist2d(x, y, bins=n_bins, cmap="viridis")
ax_main.set_xlabel("X", fontsize=12)
ax_main.set_ylabel("Y", fontsize=12)

# Colorbar
ax_cbar = fig.add_subplot(gs[1, 2])
plt.colorbar(im, cax=ax_cbar, label="Count")

# Top marginal (X)
ax_top = fig.add_subplot(gs[0, 1], sharex=ax_main)
ax_top.hist(x, bins=n_bins, color=COLORS["signal_1"], edgecolor="white", alpha=0.8)
ax_top.set_ylabel("Count")
ax_top.tick_params(labelbottom=False)
ax_top.set_title("Marginal X", fontsize=11)

# Left marginal (Y)
ax_left = fig.add_subplot(gs[1, 0], sharey=ax_main)
ax_left.hist(y, bins=n_bins, orientation="horizontal", color=COLORS["signal_2"], 
             edgecolor="white", alpha=0.8)
ax_left.set_xlabel("Count")
ax_left.tick_params(labelleft=False)
ax_left.invert_xaxis()
ax_left.set_title("Marginal Y", fontsize=11, rotation=90, x=-0.3, y=0.5)

# Compute entropies
H_x, _ = compute_entropy_continuous(x, n_bins=n_bins)
H_y, _ = compute_entropy_continuous(y, n_bins=n_bins)
H_xy = compute_joint_entropy(x, y, n_bins=n_bins)

plt.suptitle(f"Joint Distribution (r = {correlation})\n" +
             f"H(X) = {H_x:.2f}, H(Y) = {H_y:.2f}, H(X,Y) = {H_xy:.2f} bits",
             fontsize=14, fontweight="bold", y=1.02)

plt.show()

print("\n📊 Entropy Analysis:")
print(f"   H(X) = {H_x:.3f} bits")
print(f"   H(Y) = {H_y:.3f} bits")
print(f"   H(X) + H(Y) = {H_x + H_y:.3f} bits")
print(f"   H(X, Y) = {H_xy:.3f} bits")
print("   → H(X,Y) < H(X) + H(Y) because X and Y are dependent!")

## 5. Conditional Entropy

**Conditional entropy** H(X|Y) measures the remaining uncertainty about X **after** we observe Y.

### Definition

$$H(X|Y) = H(X, Y) - H(Y)$$

Or equivalently:

$$H(X|Y) = -\sum_{x,y} p(x, y) \log p(x|y)$$

### Interpretation

- **H(X|Y) = 0**: Y completely determines X (no remaining uncertainty)
- **H(X|Y) = H(X)**: Y tells us nothing about X (X and Y independent)
- In between: Y partially reduces uncertainty about X

### Connection to MI

$$I(X; Y) = H(X) - H(X|Y)$$

MI = initial uncertainty minus remaining uncertainty = **uncertainty reduced by knowing Y**
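The sketch below checks numerically that the sum definition and the chain rule H(X|Y) = H(X, Y) - H(Y) agree on a random discrete joint table. The helper and variable names are our own. p(x|y) is obtained by dividing each column of p(x, y) by the corresponding p(y).

```python
import numpy as np

def entropy_bits(p):
    """Shannon entropy (bits) of a probability array; zero entries are skipped."""
    p = np.asarray(p, dtype=float).ravel()
    p = p[p > 0]
    return float(-np.sum(p * np.log2(p)))

rng = np.random.default_rng(1)
p_joint_c = rng.random((4, 6))          # rows: x, columns: y (random entries are > 0 in practice)
p_joint_c /= p_joint_c.sum()

p_y_marg = p_joint_c.sum(axis=0)        # p(y), shape (6,)
p_x_given_y = p_joint_c / p_y_marg      # p(x|y): each column divided by its p(y)

# Definition: H(X|Y) = -sum_{x,y} p(x, y) log2 p(x|y)
h_cond_direct = float(-np.sum(p_joint_c * np.log2(p_x_given_y)))

# Chain rule: H(X|Y) = H(X, Y) - H(Y)
h_cond_chain = entropy_bits(p_joint_c) - entropy_bits(p_y_marg)

# Uncertainty reduced by observing Y: I(X;Y) = H(X) - H(X|Y)
h_x_marg = entropy_bits(p_joint_c.sum(axis=1))
mi_reduction = h_x_marg - h_cond_chain

print(f"H(X|Y) direct = {h_cond_direct:.4f} bits, via chain rule = {h_cond_chain:.4f} bits")
print(f"H(X) = {h_x_marg:.4f} bits -> I(X;Y) = {mi_reduction:.4f} bits")
```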

In [None]:
def compute_conditional_entropy(
    x: NDArray[np.float64],
    y: NDArray[np.float64],
    n_bins: int = 20,
    base: float = 2.0
) -> float:
    """
    Compute conditional entropy H(X|Y).
    
    Parameters
    ----------
    x : NDArray[np.float64]
        First signal (the one we're uncertain about).
    y : NDArray[np.float64]
        Second signal (the one we condition on).
    n_bins : int, optional
        Number of bins per dimension. Default is 20.
    base : float, optional
        Logarithm base. Default is 2 (bits).
    
    Returns
    -------
    float
        Conditional entropy H(X|Y) = H(X,Y) - H(Y).
    """
    H_xy = compute_joint_entropy(x, y, n_bins, base)
    H_y, _ = compute_entropy_continuous(y, n_bins=n_bins)
    
    # Convert H_y to same base if needed
    if base != 2.0:
        H_y = H_y * np.log(2) / np.log(base)
    
    return H_xy - H_y

In [None]:
# Visualization 5: Entropy decomposition

# Use same signals from before
H_x_given_y = compute_conditional_entropy(x, y, n_bins=n_bins)
H_y_given_x = compute_conditional_entropy(y, x, n_bins=n_bins)
MI = H_x - H_x_given_y  # I(X;Y) = H(X) - H(X|Y)

fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Left: H(X) decomposition
ax = axes[0]
bar_width = 0.5
ax.bar([0], [H_x_given_y], bar_width, label=f"H(X|Y) = {H_x_given_y:.2f}", color=COLORS["signal_1"], alpha=0.7)
ax.bar([0], [MI], bar_width, bottom=[H_x_given_y], label=f"I(X;Y) = {MI:.2f}", color=COLORS["signal_3"], alpha=0.7)
ax.axhline(H_x, color="black", linestyle="--", linewidth=2)
ax.text(0.6, H_x, f"H(X) = {H_x:.2f}", fontsize=11, va="center")
ax.set_xlim(-0.5, 1.5)
ax.set_ylim(0, H_x * 1.2)
ax.set_xticks([0])
ax.set_xticklabels(["Entropy of X"])
ax.set_ylabel("Entropy (bits)", fontsize=12)
ax.set_title("H(X) = H(X|Y) + I(X;Y)", fontsize=12, fontweight="bold")
ax.legend(loc="upper right")

# Right: H(Y) decomposition
ax = axes[1]
ax.bar([0], [H_y_given_x], bar_width, label=f"H(Y|X) = {H_y_given_x:.2f}", color=COLORS["signal_2"], alpha=0.7)
ax.bar([0], [MI], bar_width, bottom=[H_y_given_x], label=f"I(X;Y) = {MI:.2f}", color=COLORS["signal_3"], alpha=0.7)
ax.axhline(H_y, color="black", linestyle="--", linewidth=2)
ax.text(0.6, H_y, f"H(Y) = {H_y:.2f}", fontsize=11, va="center")
ax.set_xlim(-0.5, 1.5)
ax.set_ylim(0, H_y * 1.2)
ax.set_xticks([0])
ax.set_xticklabels(["Entropy of Y"])
ax.set_ylabel("Entropy (bits)", fontsize=12)
ax.set_title("H(Y) = H(Y|X) + I(X;Y)", fontsize=12, fontweight="bold")
ax.legend(loc="upper right")

plt.suptitle("Entropy Decomposition: MI is the Shared Part", fontsize=14, fontweight="bold", y=1.02)
plt.tight_layout()
plt.show()

print(f"\n📊 The same I(X;Y) = {MI:.3f} bits appears in BOTH decompositions!")
print("   This is the 'shared information': the overlap in the Venn diagram.")

## 6. Mutual Information: The Formula

Now we can formally define mutual information.

### Definition 1: Via Joint and Marginal Distributions

$$I(X; Y) = \sum_{x,y} p(x,y) \log \frac{p(x,y)}{p(x)p(y)}$$

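This is the Kullback-Leibler divergence between the joint distribution $p(x,y)$ and the product of marginals $p(x)p(y)$. The product of marginals is what the joint distribution would be if X and Y were independent, so MI is zero exactly when the two coincide.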

### Definition 2: Via Entropies

$$I(X; Y) = H(X) + H(Y) - H(X, Y)$$

### Definition 3: Via Conditional Entropy

$$I(X; Y) = H(X) - H(X|Y) = H(Y) - H(Y|X)$$

All three are equivalent!

### Properties

- **Non-negative**: $I(X; Y) \geq 0$ always
- **Zero iff independent**: $I(X; Y) = 0 \Leftrightarrow$ X and Y are statistically independent
- **Symmetric**: $I(X; Y) = I(Y; X)$
- **Self-information**: $I(X; X) = H(X)$
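The three definitions, along with the symmetry and self-information properties, can be checked directly on a random discrete table. This is a minimal sketch using our own helper names. `mi_kl_bits` implements Definition 1 as a sum over the non-zero cells of p(x, y), following the convention that 0 · log(0 / q) = 0.

```python
import numpy as np

def entropy_bits(p):
    """Shannon entropy (bits) of a probability array; zero entries are skipped."""
    p = np.asarray(p, dtype=float).ravel()
    p = p[p > 0]
    return float(-np.sum(p * np.log2(p)))

def mi_kl_bits(p_xy):
    """Definition 1: KL divergence between p(x, y) and p(x) p(y), in bits."""
    p_prod = np.outer(p_xy.sum(axis=1), p_xy.sum(axis=0))
    mask = p_xy > 0                              # empty cells contribute nothing
    return float(np.sum(p_xy[mask] * np.log2(p_xy[mask] / p_prod[mask])))

def mi_entropy_bits(p_xy):
    """Definition 2: H(X) + H(Y) - H(X, Y)."""
    return entropy_bits(p_xy.sum(axis=1)) + entropy_bits(p_xy.sum(axis=0)) - entropy_bits(p_xy)

rng = np.random.default_rng(2)
p_tab = rng.random((3, 5))
p_tab /= p_tab.sum()

mi_def1 = mi_kl_bits(p_tab)
mi_def2 = mi_entropy_bits(p_tab)
mi_swapped = mi_kl_bits(p_tab.T)                 # symmetry: I(Y;X)

# Self-information: the joint of (X, X) is diagonal, with p(x) on the diagonal
p_x_only = p_tab.sum(axis=1)
mi_self = mi_kl_bits(np.diag(p_x_only))

print(f"Definition 1 (KL form):      {mi_def1:.4f} bits")
print(f"Definition 2 (entropy form): {mi_def2:.4f} bits")
print(f"I(Y;X) = {mi_swapped:.4f} bits, I(X;X) = {mi_self:.4f} vs H(X) = {entropy_bits(p_x_only):.4f}")
```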

In [None]:
def compute_mutual_information(
    x: NDArray[np.float64],
    y: NDArray[np.float64],
    n_bins: int = 20,
    base: float = 2.0
) -> float:
    """
    Compute mutual information I(X; Y).
    
    Uses the formula: I(X;Y) = H(X) + H(Y) - H(X,Y)
    
    Parameters
    ----------
    x : NDArray[np.float64]
        First signal.
    y : NDArray[np.float64]
        Second signal.
    n_bins : int, optional
        Number of bins per dimension. Default is 20.
    base : float, optional
        Logarithm base. Default is 2 (bits).
    
    Returns
    -------
    float
        Mutual information I(X; Y).
    """
    # Compute individual entropies
    H_x, _ = compute_entropy_continuous(x, n_bins=n_bins)
    H_y, _ = compute_entropy_continuous(y, n_bins=n_bins)
    
    # Convert to specified base if needed
    if base != 2.0:
        H_x = H_x * np.log(2) / np.log(base)
        H_y = H_y * np.log(2) / np.log(base)
    
    # Compute joint entropy
    H_xy = compute_joint_entropy(x, y, n_bins, base)
    
    # MI = H(X) + H(Y) - H(X,Y)
    mi = H_x + H_y - H_xy
    
    # Ensure non-negative (can be slightly negative due to estimation)
    return max(0.0, float(mi))

In [None]:
# Verify our MI calculation with different formulas

MI_formula1 = H_x + H_y - H_xy  # Definition 2
MI_formula2 = H_x - H_x_given_y  # Definition 3a
MI_formula3 = H_y - H_y_given_x  # Definition 3b
MI_function = compute_mutual_information(x, y, n_bins=n_bins)

print("📊 Verification: All formulas give the same MI")
print("=" * 50)
print(f"  H(X) + H(Y) - H(X,Y)  = {MI_formula1:.4f} bits")
print(f"  H(X) - H(X|Y)         = {MI_formula2:.4f} bits")
print(f"  H(Y) - H(Y|X)         = {MI_formula3:.4f} bits")
print(f"  compute_mutual_information() = {MI_function:.4f} bits")
print("=" * 50)
print("\n✓ All formulas are equivalent!")

In [None]:
# Visualization 6: MI vs correlation strength

np.random.seed(42)
n_samples = 1000
correlations = np.linspace(0, 0.99, 20)
mi_values = []

for corr in correlations:
    x_temp = np.random.randn(n_samples)
    y_temp = corr * x_temp + np.sqrt(1 - corr**2) * np.random.randn(n_samples)
    mi = compute_mutual_information(x_temp, y_temp, n_bins=20)
    mi_values.append(mi)

fig, ax = plt.subplots(figsize=(10, 6))

ax.plot(correlations, mi_values, color=COLORS["signal_1"], linewidth=2.5, marker="o", markersize=6)
ax.set_xlabel("Correlation (r)", fontsize=12)
ax.set_ylabel("Mutual Information (bits)", fontsize=12)
ax.set_title("MI Increases with Statistical Dependency", fontsize=14, fontweight="bold")
ax.grid(True, alpha=0.3)
ax.set_xlim(0, 1)
ax.set_ylim(0, max(mi_values) * 1.1)

# Add annotation
ax.annotate("Independent\n(r=0, true MI=0)", xy=(0.05, mi_values[1]), xytext=(0.2, 0.3),
            fontsize=10, arrowprops=dict(arrowstyle="->", color="black"))
ax.annotate("Strong dependency\n(high r, high MI)", xy=(0.9, mi_values[-2]), xytext=(0.6, mi_values[-2] * 0.7),
            fontsize=10, arrowprops=dict(arrowstyle="->", color="black"))

plt.tight_layout()
plt.show()

print("\n💡 For Gaussian variables, MI and correlation are related:")
print("   I(X;Y) = -0.5 × log₂(1 - r²) for jointly Gaussian X, Y")
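
The closed-form Gaussian result provides a reference value for checking a histogram estimator. The sketch below compares it with a plug-in estimate built on `np.histogram2d`, using our own `binned_mi_bits` helper rather than `compute_mutual_information`. With many samples the two values should be close. Any remaining gap reflects binning resolution and finite-sample bias (see Section 9).

```python
import numpy as np

def gaussian_mi_bits(r):
    """Exact MI (bits) of a bivariate Gaussian with correlation coefficient r."""
    return -0.5 * np.log2(1.0 - r**2)

def binned_mi_bits(a, b, bins=20):
    """Plug-in MI (bits) from a 2D histogram; marginals come from the same table."""
    counts, _, _ = np.histogram2d(a, b, bins=bins)
    p_ab = counts / counts.sum()
    p_a = p_ab.sum(axis=1, keepdims=True)      # shape (bins, 1)
    p_b = p_ab.sum(axis=0, keepdims=True)      # shape (1, bins)
    mask = p_ab > 0
    return float(np.sum(p_ab[mask] * np.log2(p_ab[mask] / (p_a * p_b)[mask])))

rng = np.random.default_rng(0)
r_true = 0.7
g1 = rng.standard_normal(20_000)
g2 = r_true * g1 + np.sqrt(1 - r_true**2) * rng.standard_normal(20_000)

mi_exact = gaussian_mi_bits(r_true)
mi_binned_est = binned_mi_bits(g1, g2, bins=20)
print(f"Analytic MI for r = {r_true}: {mi_exact:.3f} bits")
print(f"Binned estimate (20 bins):  {mi_binned_est:.3f} bits")
```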

## 7. The Key Advantage: Nonlinear Relationships

Correlation only measures *linear* dependence. Below we compare correlation and MI on linear, quadratic and sinusoidal relationships.

In [None]:
def generate_nonlinear_relationship(
    n_samples: int,
    relationship: str = "quadratic",
    noise_level: float = 0.2,
    seed: Optional[int] = None
) -> Tuple[NDArray[np.float64], NDArray[np.float64]]:
    """
    Generate X, Y with specified nonlinear relationship.
    
    Parameters
    ----------
    n_samples : int
        Number of samples to generate.
    relationship : str, optional
        Type of relationship: "linear", "quadratic", "sinusoidal", 
        "absolute", "circular". Default is "quadratic".
    noise_level : float, optional
        Standard deviation of additive noise. Default is 0.2.
    seed : int, optional
        Random seed for reproducibility.
    
    Returns
    -------
    Tuple[NDArray, NDArray]
        (x, y) signal pair.
    """
    if seed is not None:
        np.random.seed(seed)
    
    x = np.random.uniform(-2, 2, n_samples)
    noise = noise_level * np.random.randn(n_samples)
    
    if relationship == "linear":
        y = 0.8 * x + noise
    elif relationship == "quadratic":
        y = x**2 + noise
    elif relationship == "sinusoidal":
        y = np.sin(2 * np.pi * x / 2) + noise
    elif relationship == "absolute":
        y = np.abs(x) + noise
    elif relationship == "circular":
        # Noisy circle of radius 2: |y| is set by x, the sign of y is random
        y = np.random.choice([-1, 1], n_samples) * np.sqrt(4 - x**2) + noise
    else:
        raise ValueError(f"Unknown relationship: {relationship}")
    
    return x, y

In [None]:
# Visualization 7: MI vs Correlation, the key comparison

relationships = ["linear", "quadratic", "sinusoidal"]
n_samples = 1000

fig, axes = plt.subplots(2, 3, figsize=(15, 10))

results = []

for idx, rel in enumerate(relationships):
    x, y = generate_nonlinear_relationship(n_samples, rel, noise_level=0.3, seed=42)
    
    # Compute metrics
    corr = np.corrcoef(x, y)[0, 1]
    mi = compute_mutual_information(x, y, n_bins=20)
    results.append((rel, corr, mi))
    
    # Top row: scatter plots
    colors_rel = [COLORS["signal_1"], COLORS["signal_2"], COLORS["signal_3"]]
    axes[0, idx].scatter(x, y, alpha=0.4, s=15, color=colors_rel[idx])
    axes[0, idx].set_xlabel("X", fontsize=11)
    axes[0, idx].set_ylabel("Y", fontsize=11)
    axes[0, idx].set_title(f"{rel.capitalize()}\nCorr = {corr:.3f}, MI = {mi:.3f}", 
                           fontsize=12, fontweight="bold")
    axes[0, idx].grid(True, alpha=0.3)

# Bottom row: bar chart comparison
x_pos = np.arange(len(relationships))
width = 0.35

corrs = [r[1] for r in results]
mis = [r[2] for r in results]

axes[1, 0].bar(x_pos - width/2, np.abs(corrs), width, label="|Correlation|", color=COLORS["signal_4"], alpha=0.8)
axes[1, 0].bar(x_pos + width/2, mis, width, label="MI (bits)", color=COLORS["signal_5"], alpha=0.8)
axes[1, 0].set_xticks(x_pos)
axes[1, 0].set_xticklabels([r[0].capitalize() for r in results])
axes[1, 0].set_ylabel("Value", fontsize=11)
axes[1, 0].set_title("Comparison: |Correlation| vs MI", fontsize=12, fontweight="bold")
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3, axis="y")

# Highlight the key insight
axes[1, 1].text(0.5, 0.7, "KEY INSIGHT", fontsize=18, fontweight="bold", ha="center", va="center",
                transform=axes[1, 1].transAxes)
axes[1, 1].text(0.5, 0.5, "Quadratic: Correlation ≈ 0\nSinusoidal: weak correlation\nbut MI > 0 for both!", 
                fontsize=14, ha="center", va="center", transform=axes[1, 1].transAxes,
                bbox=dict(boxstyle="round", facecolor=COLORS["signal_3"], alpha=0.3))
axes[1, 1].text(0.5, 0.2, "MI detects these\nnonlinear relationships!", 
                fontsize=12, ha="center", va="center", transform=axes[1, 1].transAxes,
                style="italic")
axes[1, 1].axis("off")

# Summary table
axes[1, 2].axis("off")
table_data = [["Relationship", "|Corr|", "MI"]]
for rel, corr, mi in results:
    table_data.append([rel.capitalize(), f"{abs(corr):.3f}", f"{mi:.3f}"])

table = axes[1, 2].table(cellText=table_data, loc="center", cellLoc="center",
                          colWidths=[0.4, 0.3, 0.3])
table.auto_set_font_size(False)
table.set_fontsize(11)
table.scale(1.2, 1.8)

# Style header row
for i in range(3):
    table[(0, i)].set_facecolor(COLORS["signal_1"])
    table[(0, i)].set_text_props(color="white", fontweight="bold")

plt.suptitle("MI Captures Nonlinear Relationships That Correlation Misses", 
             fontsize=14, fontweight="bold", y=1.02)
plt.tight_layout()
plt.show()

print("\n📊 Key observation:")
print("   • Linear: Both correlation and MI detect the relationship")
print("   • Quadratic: Correlation ≈ 0, but MI clearly shows dependency!")
print("   • Sinusoidal: Correlation is weak and understates the dependency; MI captures it")

## 8. Normalized Mutual Information

Raw MI depends on the entropies of the variables, which makes it hard to compare across different signals.

### Normalization Options

| Name | Formula | Range |
|------|---------|-------|
| Geometric | $\frac{I(X;Y)}{\sqrt{H(X) \cdot H(Y)}}$ | [0, 1] |
| Max | $\frac{I(X;Y)}{\max(H(X), H(Y))}$ | [0, 1] |
| Min | $\frac{I(X;Y)}{\min(H(X), H(Y))}$ | [0, 1] |
| Arithmetic | $\frac{2 \cdot I(X;Y)}{H(X) + H(Y)}$ | [0, 1] |

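**Normalized MI = 1** means perfect dependency. The exact condition depends on the normalization. The **min** version reaches 1 as soon as one variable determines the other. The geometric, max and arithmetic versions reach 1 only when each variable determines the other, so that $I(X;Y) = H(X) = H(Y)$.

**Normalized MI = 0** means independence.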
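To see how the choice of normalization matters, take X uniform over four values and Y = X mod 2. Here X determines Y but not the other way round, so H(X) = 2 bits, H(Y) = 1 bit and I(X;Y) = 1 bit. The small sketch below evaluates the four formulas from the table for these values. The function name is our own.

```python
import numpy as np

def normalized_mi_variants(mi, h_x, h_y):
    """The four normalizations from the table (inputs in the same log base)."""
    return {
        "geometric": mi / np.sqrt(h_x * h_y),
        "max": mi / max(h_x, h_y),
        "min": mi / min(h_x, h_y),
        "arithmetic": 2 * mi / (h_x + h_y),
    }

# X uniform over 4 values, Y = X mod 2: H(X) = 2, H(Y) = 1, I(X;Y) = 1 (bits)
nmi_example = normalized_mi_variants(mi=1.0, h_x=2.0, h_y=1.0)
for name, value in nmi_example.items():
    print(f"{name:>10}: {value:.3f}")
```

Only the min normalization reaches 1 here. The others give 0.707 (geometric), 0.5 (max) and 0.667 (arithmetic), and would reach 1 only if Y also determined X.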

In [None]:
def compute_normalized_mi(
    x: NDArray[np.float64],
    y: NDArray[np.float64],
    n_bins: int = 20,
    normalization: str = "geometric"
) -> float:
    """
    Compute normalized mutual information (range 0-1).
    
    Parameters
    ----------
    x : NDArray[np.float64]
        First signal.
    y : NDArray[np.float64]
        Second signal.
    n_bins : int, optional
        Number of bins per dimension. Default is 20.
    normalization : str, optional
        Normalization method: "geometric", "max", "min", "arithmetic".
        Default is "geometric".
    
    Returns
    -------
    float
        Normalized MI in range [0, 1].
    """
    mi = compute_mutual_information(x, y, n_bins)
    H_x, _ = compute_entropy_continuous(x, n_bins=n_bins)
    H_y, _ = compute_entropy_continuous(y, n_bins=n_bins)
    
    if normalization == "geometric":
        denom = np.sqrt(H_x * H_y)
    elif normalization == "max":
        denom = max(H_x, H_y)
    elif normalization == "min":
        denom = min(H_x, H_y)
    elif normalization == "arithmetic":
        denom = (H_x + H_y) / 2
    else:
        raise ValueError(f"Unknown normalization: {normalization}")
    
    if denom == 0:
        return 0.0
    
    return min(1.0, mi / denom)  # Clip to [0, 1]

In [None]:
# Visualization 8: Raw MI vs Normalized MI

np.random.seed(42)
n_samples = 1000

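# Two pairs with (almost) the same correlation, r ≈ 0.7, but different binned entropy.
# Rescaling alone would not work: when the bins span each signal's range (as
# np.histogram2d does), multiplying a signal by a constant leaves its binned
# entropy unchanged. The low-entropy pair therefore uses a heavy-tailed Laplace
# distribution, whose tails stretch the range so most samples fall in a few central bins.

# High entropy signals (Gaussian)
x_high = np.random.randn(n_samples)
y_high = 0.7 * x_high + 0.71 * np.random.randn(n_samples)

# Low entropy signals (more peaked, heavy-tailed Laplace distribution)
x_low = np.random.laplace(size=n_samples)
y_low = 0.7 * x_low + 0.71 * np.random.laplace(size=n_samples)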

# Compute metrics
mi_high = compute_mutual_information(x_high, y_high, n_bins=20)
mi_low = compute_mutual_information(x_low, y_low, n_bins=20)
nmi_high = compute_normalized_mi(x_high, y_high, n_bins=20)
nmi_low = compute_normalized_mi(x_low, y_low, n_bins=20)
corr_high = np.corrcoef(x_high, y_high)[0, 1]
corr_low = np.corrcoef(x_low, y_low)[0, 1]

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Scatter plots
axes[0].scatter(x_high, y_high, alpha=0.3, s=10, color=COLORS["signal_1"], label="High entropy")
axes[0].scatter(x_low, y_low, alpha=0.5, s=10, color=COLORS["signal_2"], label="Low entropy")
axes[0].set_xlabel("X", fontsize=11)
axes[0].set_ylabel("Y", fontsize=11)
axes[0].set_title("Same Correlation, Different Entropy", fontsize=12, fontweight="bold")
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Raw MI comparison
x_pos = [0, 1]
axes[1].bar(x_pos, [mi_high, mi_low], color=[COLORS["signal_1"], COLORS["signal_2"]], alpha=0.8)
axes[1].set_xticks(x_pos)
axes[1].set_xticklabels(["High Entropy", "Low Entropy"])
axes[1].set_ylabel("MI (bits)", fontsize=11)
axes[1].set_title(f"Raw MI\nHigh={mi_high:.3f}, Low={mi_low:.3f}", 
                  fontsize=12, fontweight="bold")
axes[1].grid(True, alpha=0.3, axis="y")

# Normalized MI comparison
axes[2].bar(x_pos, [nmi_high, nmi_low], color=[COLORS["signal_1"], COLORS["signal_2"]], alpha=0.8)
axes[2].set_xticks(x_pos)
axes[2].set_xticklabels(["High Entropy", "Low Entropy"])
axes[2].set_ylabel("Normalized MI", fontsize=11)
axes[2].set_title(f"Normalized MI\nHigh={nmi_high:.3f}, Low={nmi_low:.3f}", 
                  fontsize=12, fontweight="bold")
axes[2].set_ylim(0, 1)
axes[2].grid(True, alpha=0.3, axis="y")

plt.suptitle("Why Normalize MI?", fontsize=14, fontweight="bold", y=1.02)
plt.tight_layout()
plt.show()

print(f"\n📊 Both have similar correlation: {corr_high:.3f} vs {corr_low:.3f}")
print(f"   Raw MI:        {mi_high:.3f} vs {mi_low:.3f}")
print(f"   Normalized MI: {nmi_high:.3f} vs {nmi_low:.3f}")

## 9. Estimation Challenges

MI estimation from finite samples faces several challenges.

### Challenge 1: Binning Choice
- **Too few bins**: Underestimate MI (lose resolution)
- **Too many bins**: Overestimate MI (sparse sampling bias)
- Optimal depends on sample size and relationship

### Challenge 2: Positive Bias
- MI estimates are **biased upward**
- Even **independent** signals show positive MI due to finite sampling!
- More bins = more bias
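For the plug-in (histogram) estimator there is a well-known first-order approximation of this bias. For independent variables with B_x and B_y occupied bins and N samples, the expected estimate is about (B_x - 1)(B_y - 1) / (2N) nats. This follows from applying the Miller-Madow entropy correction to H(X), H(Y) and H(X, Y). The sketch below compares that approximation with the average estimate over repeated independent draws. It uses our own `binned_mi_bits` helper and uniform data, so that all bins are well populated. The approximation becomes unreliable once many cells are empty, for example with 20 × 20 bins and only 500 samples.

```python
import numpy as np

def binned_mi_bits(a, b, bins):
    """Plug-in MI (bits) from a 2D histogram; marginals come from the same table."""
    counts, _, _ = np.histogram2d(a, b, bins=bins)
    p_ab = counts / counts.sum()
    p_a = p_ab.sum(axis=1, keepdims=True)
    p_b = p_ab.sum(axis=0, keepdims=True)
    mask = p_ab > 0
    return float(np.sum(p_ab[mask] * np.log2(p_ab[mask] / (p_a * p_b)[mask])))

def approx_bias_bits(n_samples, bins_x, bins_y):
    """First-order expected plug-in MI for independent data, converted from nats to bits."""
    return (bins_x - 1) * (bins_y - 1) / (2 * n_samples * np.log(2))

rng = np.random.default_rng(0)
n_obs, n_trials = 2000, 200
bias_check = {}
for n_b in (5, 10):
    # Independent uniform signals: the true MI is exactly 0
    estimates = [binned_mi_bits(rng.uniform(size=n_obs), rng.uniform(size=n_obs), n_b)
                 for _ in range(n_trials)]
    bias_check[n_b] = (float(np.mean(estimates)), approx_bias_bits(n_obs, n_b, n_b))
    print(f"{n_b:>2} bins: mean estimate = {bias_check[n_b][0]:.4f} bits, "
          f"approximation = {bias_check[n_b][1]:.4f} bits")
```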

### Challenge 3: Computational Cost
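- A 2D histogram has bins² cells, so it needs many more samples than a 1D histogram to be well populated
- For n channels there are n(n-1)/2 unique pairs to compute (MI is symmetric), so the cost grows as O(n²)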
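Because MI is symmetric, a channel-by-channel MI matrix only needs the n(n-1)/2 upper-triangle pairs, and the lower triangle is a mirror image. The minimal sketch below assumes data shaped (n_channels, n_samples) and uses a plug-in histogram estimator of our own. `binned_mi_bits` and `pairwise_mi_matrix` are hypothetical names.

```python
import numpy as np
from itertools import combinations

def binned_mi_bits(a, b, bins=10):
    """Plug-in MI (bits) from a 2D histogram; marginals come from the same table."""
    counts, _, _ = np.histogram2d(a, b, bins=bins)
    p_ab = counts / counts.sum()
    p_a = p_ab.sum(axis=1, keepdims=True)
    p_b = p_ab.sum(axis=0, keepdims=True)
    mask = p_ab > 0
    return float(np.sum(p_ab[mask] * np.log2(p_ab[mask] / (p_a * p_b)[mask])))

def pairwise_mi_matrix(data, bins=10):
    """MI between every pair of rows of `data` (shape: n_channels x n_samples).

    Only the n(n-1)/2 pairs with i < j are estimated; each value is mirrored
    to (j, i). The diagonal (self-information) is not computed and stays 0.
    """
    n_ch = data.shape[0]
    mi_mat = np.zeros((n_ch, n_ch))
    for i, j in combinations(range(n_ch), 2):
        mi_mat[i, j] = mi_mat[j, i] = binned_mi_bits(data[i], data[j], bins)
    return mi_mat

rng = np.random.default_rng(0)
channels = rng.standard_normal((6, 1000))
channels[1] += channels[0]                     # make channels 0 and 1 dependent
mi_matrix = pairwise_mi_matrix(channels)
n_pairs = len(list(combinations(range(6), 2)))

print(f"{n_pairs} unique pairs for 6 channels")
print(np.round(mi_matrix, 3))
```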

### Solutions
- Adaptive binning rules
- **Surrogate-based bias correction**
- KNN-based estimators (more advanced)
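As an example of rule-based (adaptive) binning, NumPy's `np.histogram_bin_edges` implements several standard rules, including Sturges and Freedman-Diaconis. The sketch below shows how the suggested number of bins grows with the sample size. These rules were designed for 1D density histograms rather than for MI estimation, so treat them as a starting point only. How `optimal_n_bins` in `src.information` chooses bins is not shown here and may differ.

```python
import numpy as np

rng = np.random.default_rng(0)
bin_choices = {}
for n_obs in (100, 1000, 10_000):
    sample = rng.standard_normal(n_obs)
    n_sturges = len(np.histogram_bin_edges(sample, bins="sturges")) - 1
    n_fd = len(np.histogram_bin_edges(sample, bins="fd")) - 1
    bin_choices[n_obs] = (n_sturges, n_fd)
    print(f"N = {n_obs:>6}: Sturges -> {n_sturges:>2} bins, Freedman-Diaconis -> {n_fd:>3} bins")
```

Sturges depends only on N (about log2(N) + 1 bins). Freedman-Diaconis also uses the interquartile range of the data and grows faster with N.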

In [None]:
# Visualization 9: Bias demonstration

np.random.seed(42)
n_samples = 500

# Generate INDEPENDENT signals
x_indep = np.random.randn(n_samples)
y_indep = np.random.randn(n_samples)

# Compute MI with different bin counts
bin_counts = [5, 10, 15, 20, 30, 50, 75, 100]
mi_values_bias = []

for n_bins in bin_counts:
    mi = compute_mutual_information(x_indep, y_indep, n_bins=n_bins)
    mi_values_bias.append(mi)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Left: MI vs bins for independent signals
axes[0].plot(bin_counts, mi_values_bias, color=COLORS["signal_1"], linewidth=2.5, marker="o", markersize=8)
axes[0].axhline(0, color=COLORS["grid"], linestyle="--", linewidth=2, label="True MI = 0")
axes[0].set_xlabel("Number of Bins", fontsize=12)
axes[0].set_ylabel("Estimated MI (bits)", fontsize=12)
axes[0].set_title("Bias: Independent Signals Show Positive MI!", fontsize=12, fontweight="bold")
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Add annotation
axes[0].annotate("More bins = more bias!", xy=(75, mi_values_bias[-2]), 
                 xytext=(50, mi_values_bias[-2] + 0.1),
                 fontsize=11, arrowprops=dict(arrowstyle="->", color="black"))

# Right: scatter plot showing they ARE independent
axes[1].scatter(x_indep, y_indep, alpha=0.4, s=20, color=COLORS["signal_2"])
axes[1].set_xlabel("X", fontsize=12)
axes[1].set_ylabel("Y", fontsize=12)
corr_indep = np.corrcoef(x_indep, y_indep)[0, 1]
axes[1].set_title(f"These ARE Independent!\nCorr = {corr_indep:.3f}", fontsize=12, fontweight="bold")
axes[1].grid(True, alpha=0.3)

plt.suptitle("The Positive Bias Problem in MI Estimation", fontsize=14, fontweight="bold", y=1.02)
plt.tight_layout()
plt.show()

print("\n⚠️ WARNING: MI estimated from independent signals is NOT zero!")
print(f"   With 20 bins: MI = {mi_values_bias[3]:.4f} bits")
print(f"   With 100 bins: MI = {mi_values_bias[-1]:.4f} bits")
print("   This is BIAS from finite sampling: we need to correct for it!")

## 10. Surrogate Testing for MI

Just like in C03 (Statistical Significance), we use **surrogates** to:
1. Test if MI is significantly different from what we'd expect by chance
2. Estimate and correct for bias

### Procedure

1. Compute observed MI
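2. Generate N surrogates by **shuffling** one signal (this destroys the dependency while preserving the marginal distribution)
3. Compute MI for each surrogate
4. **P-value** = proportion of surrogates ≥ observed MI
5. **Bias correction**: MI_corrected = MI_observed - mean(MI_surrogates)

Random shuffling also destroys any autocorrelation in the shuffled signal. For smooth, strongly autocorrelated time series, the resulting null distribution can be too narrow and produce too many false positives. Circular time-shift surrogates, which preserve each signal's temporal structure, are a common alternative.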

In [None]:
def mi_significance_test(
    x: NDArray[np.float64],
    y: NDArray[np.float64],
    n_bins: int = 20,
    n_surrogates: int = 200,
    seed: Optional[int] = None
) -> Dict[str, float]:
    """
    Significance test for mutual information using surrogates.
    
    Parameters
    ----------
    x : NDArray[np.float64]
        First signal.
    y : NDArray[np.float64]
        Second signal.
    n_bins : int, optional
        Number of bins. Default is 20.
    n_surrogates : int, optional
        Number of surrogate samples. Default is 200.
    seed : int, optional
        Random seed for reproducibility.
    
    Returns
    -------
    Dict[str, float]
        Dictionary with:
        - "mi_observed": Raw MI value
        - "mi_corrected": Bias-corrected MI
        - "pvalue": P-value from surrogate test
        - "null_mean": Mean of null distribution
        - "null_std": Std of null distribution
    """
    rng = np.random.default_rng(seed)  # local generator: leaves NumPy's global state untouched
    
    # Observed MI
    mi_observed = compute_mutual_information(x, y, n_bins)
    
    # Generate surrogates
    mi_surrogates = []
    for _ in range(n_surrogates):
        # Shuffle one signal to destroy the dependency
        y_shuffled = rng.permutation(y)
        mi_surr = compute_mutual_information(x, y_shuffled, n_bins)
        mi_surrogates.append(mi_surr)
    
    mi_surrogates = np.array(mi_surrogates)
    
    # Statistics
    null_mean = np.mean(mi_surrogates)
    null_std = np.std(mi_surrogates)
    
    # P-value: proportion of surrogates >= observed
    pvalue = np.mean(mi_surrogates >= mi_observed)
    
    # Bias-corrected MI
    mi_corrected = mi_observed - null_mean
    
    return {
        "mi_observed": float(mi_observed),
        "mi_corrected": float(mi_corrected),
        "pvalue": float(pvalue),
        "null_mean": float(null_mean),
        "null_std": float(null_std)
    }
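The variant below extends `mi_significance_test` with two options. It can build circular-shift surrogates, which rotate `y` by a random offset and so keep its autocorrelation. It also reports the (1 + k)/(1 + N) permutation p-value, which counts the observed value as one more member of the null and therefore can never be exactly zero. To stay self-contained, it uses a small histogram helper, `_plugin_mi_bits`, as a hypothetical stand-in for the notebook's `compute_mutual_information`. The function names and the minimum shift (10% of the signal length) are assumptions of this sketch.

```python
import numpy as np
from typing import Dict, Optional


def _plugin_mi_bits(x: np.ndarray, y: np.ndarray, n_bins: int = 20) -> float:
    """Hypothetical stand-in for compute_mutual_information: plug-in MI in bits."""
    joint, _, _ = np.histogram2d(x, y, bins=n_bins)
    pxy = joint / joint.sum()
    px = pxy.sum(axis=1, keepdims=True)  # marginal of x, shape (n_bins, 1)
    py = pxy.sum(axis=0, keepdims=True)  # marginal of y, shape (1, n_bins)
    nz = pxy > 0                         # skip empty cells (0 * log 0 = 0)
    return float(np.sum(pxy[nz] * np.log2(pxy[nz] / (px * py)[nz])))


def mi_surrogate_test_variant(
    x: np.ndarray,
    y: np.ndarray,
    n_bins: int = 20,
    n_surrogates: int = 200,
    method: str = "shuffle",
    seed: Optional[int] = None,
) -> Dict[str, float]:
    """Surrogate test for MI with shuffle or circular-shift surrogates (sketch)."""
    rng = np.random.default_rng(seed)
    n = len(y)
    mi_observed = _plugin_mi_bits(x, y, n_bins)
    null = np.empty(n_surrogates)
    for k in range(n_surrogates):
        if method == "shuffle":
            y_surr = rng.permutation(y)                 # destroys dependency AND autocorrelation
        elif method == "circular":
            shift = rng.integers(n // 10, n - n // 10)  # avoid near-zero shifts
            y_surr = np.roll(y, shift)                  # keeps y's autocorrelation
        else:
            raise ValueError(f"unknown method: {method!r}")
        null[k] = _plugin_mi_bits(x, y_surr, n_bins)
    # Count the observed statistic as one more member of the null distribution
    pvalue = (1 + np.sum(null >= mi_observed)) / (1 + n_surrogates)
    return {
        "mi_observed": mi_observed,
        "mi_corrected": mi_observed - float(null.mean()),
        "pvalue": float(pvalue),
    }
```

For white-noise signals, both surrogate types behave similarly. The choice matters mainly for smooth, autocorrelated signals, where circular shifts give a more conservative null. With N surrogates, the smallest reachable p-value is 1/(N + 1).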

In [None]:
# Visualization 10: Surrogate testing demonstration

np.random.seed(42)
n_samples = 500

# Case 1: Correlated signals (should be significant)
x_corr = np.random.randn(n_samples)
y_corr = 0.6 * x_corr + 0.8 * np.random.randn(n_samples)

# Case 2: Independent signals (should NOT be significant)
x_indep = np.random.randn(n_samples)
y_indep = np.random.randn(n_samples)

# Run significance tests
result_corr = mi_significance_test(x_corr, y_corr, n_bins=20, n_surrogates=500, seed=42)
result_indep = mi_significance_test(x_indep, y_indep, n_bins=20, n_surrogates=500, seed=42)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Generate surrogate distributions for visualization
np.random.seed(42)
surr_corr = [compute_mutual_information(x_corr, np.random.permutation(y_corr), 20) for _ in range(500)]
surr_indep = [compute_mutual_information(x_indep, np.random.permutation(y_indep), 20) for _ in range(500)]

# Left: Correlated signals
axes[0].hist(surr_corr, bins=30, color=COLORS["signal_1"], alpha=0.7, density=True, label="Null distribution")
axes[0].axvline(result_corr["mi_observed"], color="red", linewidth=3, linestyle="-", 
                label=f"Observed MI = {result_corr['mi_observed']:.3f}")
axes[0].axvline(result_corr["null_mean"], color=COLORS["grid"], linewidth=2, linestyle="--",
                label=f"Null mean = {result_corr['null_mean']:.3f}")
axes[0].set_xlabel("MI (bits)", fontsize=12)
axes[0].set_ylabel("Density", fontsize=12)
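sig_corr = "SIGNIFICANT" if result_corr["pvalue"] < 0.05 else "not significant"
axes[0].set_title(f"Correlated Signals\np = {result_corr['pvalue']:.4f} ({sig_corr})", 
                  fontsize=12, fontweight="bold", color="green" if result_corr["pvalue"] < 0.05 else "gray")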
axes[0].legend(fontsize=9)

# Right: Independent signals
axes[1].hist(surr_indep, bins=30, color=COLORS["signal_2"], alpha=0.7, density=True, label="Null distribution")
axes[1].axvline(result_indep["mi_observed"], color="red", linewidth=3, linestyle="-",
                label=f"Observed MI = {result_indep['mi_observed']:.3f}")
axes[1].axvline(result_indep["null_mean"], color=COLORS["grid"], linewidth=2, linestyle="--",
                label=f"Null mean = {result_indep['null_mean']:.3f}")
axes[1].set_xlabel("MI (bits)", fontsize=12)
axes[1].set_ylabel("Density", fontsize=12)
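sig_indep = "SIGNIFICANT" if result_indep["pvalue"] < 0.05 else "not significant"
axes[1].set_title(f"Independent Signals\np = {result_indep['pvalue']:.4f} ({sig_indep})", 
                  fontsize=12, fontweight="bold", color="green" if result_indep["pvalue"] < 0.05 else "gray")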
axes[1].legend(fontsize=9)

plt.suptitle("Surrogate Testing for MI Significance", fontsize=14, fontweight="bold", y=1.02)
plt.tight_layout()
plt.show()

print("\n📊 Results Summary:")
print("-" * 60)
print("Correlated signals:")
print(f"  MI_observed = {result_corr['mi_observed']:.4f}, MI_corrected = {result_corr['mi_corrected']:.4f}")
print(f"  p-value = {result_corr['pvalue']:.4f} → {'SIGNIFICANT' if result_corr['pvalue'] < 0.05 else 'not significant'}")
print("\nIndependent signals:")
print(f"  MI_observed = {result_indep['mi_observed']:.4f}, MI_corrected = {result_indep['mi_corrected']:.4f}")
print(f"  p-value = {result_indep['pvalue']:.4f} → {'SIGNIFICANT' if result_indep['pvalue'] < 0.05 else 'not significant'}")

## 11. MI for Time Series ‚Äî Dynamic Analysis

Neural signals are **time series**, not static samples. We can compute MI in different ways:

### Option 1: Global MI
Treat each time point as a sample. Simple, assumes stationarity.

### Option 2: Sliding Window MI
Compute MI in short windows → get MI over time. Captures **dynamic changes** in coupling.

### Option 3: Time-Lagged MI
MI between X(t) and Y(t + τ). Can reveal **delayed relationships**.

This is a preview of **Transfer Entropy** (D03)!

In [None]:
def compute_mi_sliding_window(
    x: NDArray[np.float64],
    y: NDArray[np.float64],
    window_samples: int,
    step_samples: int,
    n_bins: int = 15
) -> Tuple[NDArray[np.float64], NDArray[np.float64]]:
    """
    Compute MI in sliding windows over time.
    
    Parameters
    ----------
    x : NDArray[np.float64]
        First time series.
    y : NDArray[np.float64]
        Second time series.
    window_samples : int
        Window size in samples.
    step_samples : int
        Step size in samples.
    n_bins : int, optional
        Number of bins for MI estimation. Default is 15.
    
    Returns
    -------
    Tuple[NDArray, NDArray]
        (window_centers, mi_values)
    """
    n_samples = len(x)
    centers = []
    mi_values = []
    
    for start in range(0, n_samples - window_samples + 1, step_samples):
        end = start + window_samples
        mi = compute_mutual_information(x[start:end], y[start:end], n_bins)
        centers.append((start + end) / 2)
        mi_values.append(mi)
    
    return np.array(centers), np.array(mi_values)


def compute_mi_lagged(
    x: NDArray[np.float64],
    y: NDArray[np.float64],
    max_lag_samples: int,
    n_bins: int = 20
) -> Tuple[NDArray[np.float64], NDArray[np.float64]]:
    """
    Compute MI as function of time lag.
    
    MI(X(t), Y(t+lag)) for various lags.
    
    Parameters
    ----------
    x : NDArray[np.float64]
        First time series.
    y : NDArray[np.float64]
        Second time series.
    max_lag_samples : int
        Maximum lag in samples (both positive and negative).
    n_bins : int, optional
        Number of bins. Default is 20.
    
    Returns
    -------
    Tuple[NDArray, NDArray]
        (lags, mi_values)
    """
    lags = np.arange(-max_lag_samples, max_lag_samples + 1)
    mi_values = []
    
    for lag in lags:
        if lag < 0:
            # Y leads X
            x_aligned = x[-lag:]
            y_aligned = y[:lag]
        elif lag > 0:
            # X leads Y
            x_aligned = x[:-lag]
            y_aligned = y[lag:]
        else:
            x_aligned = x
            y_aligned = y
        
        mi = compute_mutual_information(x_aligned, y_aligned, n_bins)
        mi_values.append(mi)
    
    return lags, np.array(mi_values)
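The two lag helpers in this notebook use opposite sign conventions, so a quick synthetic check is worthwhile. The sketch below re-implements the convention of `compute_mi_lagged`, MI(X(t), Y(t + lag)), with a hypothetical self-contained histogram helper. It then checks that when Y copies X with a delay of 5 samples, the peak sits at lag = +5 (X leads Y). The signal parameters are illustrative assumptions.

```python
import numpy as np


def _plugin_mi_bits(x, y, n_bins=16):
    """Hypothetical stand-in for compute_mutual_information: plug-in MI in bits."""
    joint, _, _ = np.histogram2d(x, y, bins=n_bins)
    pxy = joint / joint.sum()
    px = pxy.sum(axis=1, keepdims=True)
    py = pxy.sum(axis=0, keepdims=True)
    nz = pxy > 0
    return float(np.sum(pxy[nz] * np.log2(pxy[nz] / (px * py)[nz])))


def lagged_mi_check(x, y, max_lag, n_bins=16):
    """MI(X(t), Y(t + lag)) for lag in [-max_lag, max_lag] (compute_mi_lagged convention)."""
    lags = np.arange(-max_lag, max_lag + 1)
    mi = np.empty(len(lags))
    for i, lag in enumerate(lags):
        if lag > 0:
            xa, ya = x[:-lag], y[lag:]   # pairs X(t) with Y(t + lag)
        elif lag < 0:
            xa, ya = x[-lag:], y[:lag]   # pairs X(t) with Y(t - |lag|)
        else:
            xa, ya = x, y
        mi[i] = _plugin_mi_bits(xa, ya, n_bins)
    return lags, mi


rng = np.random.default_rng(0)
delay = 5
x_demo = rng.standard_normal(4000)
y_demo = np.roll(x_demo, delay) + 0.2 * rng.standard_normal(4000)  # y[t] ≈ x[t - delay]
lags_demo, mi_demo = lagged_mi_check(x_demo, y_demo, max_lag=20)
peak_lag = int(lags_demo[np.argmax(mi_demo)])
print(f"Peak at lag {peak_lag} samples")
```

Here the peak appears at +5. With `compute_time_lagged_mi`, the same pair of signals would peak at -5, because that helper uses MI(X(t), Y(t - lag)).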

In [None]:
# Visualization 11: Time-varying MI

np.random.seed(42)
fs = 256
duration = 15  # seconds
t = np.arange(0, duration, 1/fs)
n_samples = len(t)

# Create signals with time-varying coupling
# 0-5s: independent
# 5-10s: coupled
# 10-15s: independent again

x = np.random.randn(n_samples)
y = np.zeros(n_samples)

# Independent phase 1 (0-5s)
idx1 = t < 5
y[idx1] = np.random.randn(np.sum(idx1))

# Coupled phase (5-10s)
idx2 = (t >= 5) & (t < 10)
y[idx2] = 0.7 * x[idx2] + 0.71 * np.random.randn(np.sum(idx2))

# Independent phase 2 (10-15s)
idx3 = t >= 10
y[idx3] = np.random.randn(np.sum(idx3))

# Compute sliding window MI
window_sec = 2  # 2 second window
step_sec = 0.25  # 250ms step
window_samples = int(window_sec * fs)
step_samples = int(step_sec * fs)

centers, mi_time = compute_mi_sliding_window(x, y, window_samples, step_samples, n_bins=15)
time_centers = centers / fs

fig, axes = plt.subplots(3, 1, figsize=(14, 10), sharex=True)

# Plot signals
axes[0].plot(t, x, color=COLORS["signal_1"], linewidth=0.5, alpha=0.8, label="X")
axes[0].set_ylabel("X", fontsize=12)
axes[0].set_title("Signal X", fontsize=12, fontweight="bold")
axes[0].legend(loc="upper right")

axes[1].plot(t, y, color=COLORS["signal_2"], linewidth=0.5, alpha=0.8, label="Y")
axes[1].set_ylabel("Y", fontsize=12)
axes[1].set_title("Signal Y (coupled to X during 5-10s)", fontsize=12, fontweight="bold")
axes[1].legend(loc="upper right")

# Highlight coupling period
for ax in axes[:2]:
    ax.axvspan(5, 10, alpha=0.2, color=COLORS["signal_3"], label="Coupled period")

# Plot MI over time
axes[2].plot(time_centers, mi_time, color=COLORS["signal_3"], linewidth=2.5)
axes[2].axvspan(5, 10, alpha=0.2, color=COLORS["signal_3"])
axes[2].set_xlabel("Time (s)", fontsize=12)
axes[2].set_ylabel("MI (bits)", fontsize=12)
axes[2].set_title("Sliding Window MI: Detects Dynamic Coupling!", fontsize=12, fontweight="bold")
axes[2].grid(True, alpha=0.3)

# Add annotations
axes[2].annotate("Independent", xy=(2.5, np.mean(mi_time[:10])), fontsize=11, ha="center")
axes[2].annotate("COUPLED", xy=(7.5, np.max(mi_time)), fontsize=11, ha="center", fontweight="bold", color="red")
axes[2].annotate("Independent", xy=(12.5, np.mean(mi_time[-10:])), fontsize=11, ha="center")

plt.suptitle("Time-Resolved MI Reveals Dynamic Changes in Coupling", fontsize=14, fontweight="bold", y=1.01)
plt.tight_layout()
plt.show()

print("\n📊 MI clearly increases during the coupled period (5-10s)!")
print("   This shows MI can track DYNAMIC changes in statistical dependency.")

In [None]:
def compute_time_lagged_mi(
    x: np.ndarray,
    y: np.ndarray,
    max_lag: int,
    n_bins: int = 20
) -> Tuple[np.ndarray, np.ndarray]:
    """
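    Compute MI between two signals at different time lags.
    
    Note: this computes MI(X(t), Y(t - lag)), the opposite sign convention to
    `compute_mi_lagged` above, so here a peak at negative lag means X leads Y.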
    
    Parameters
    ----------
    x : np.ndarray
        First signal.
    y : np.ndarray
        Second signal.
    max_lag : int
        Maximum lag in samples (both positive and negative).
    n_bins : int
        Number of bins for histogram estimation.
    
    Returns
    -------
    lags : np.ndarray
        Array of lag values (negative = X leads, positive = Y leads).
    mi_values : np.ndarray
        MI values at each lag.
    """
    lags = np.arange(-max_lag, max_lag + 1)
    mi_values = np.zeros(len(lags))
    
    for i, lag in enumerate(lags):
        if lag < 0:
            x_shifted = x[:lag]  # X leads
            y_shifted = y[-lag:]
        elif lag > 0:
            x_shifted = x[lag:]  # Y leads
            y_shifted = y[:-lag]
        else:
            x_shifted = x
            y_shifted = y
        
        mi_values[i] = compute_mutual_information(x_shifted, y_shifted, n_bins)
    
    return lags, mi_values


# Visualization 12: Time-lagged MI can reveal directionality

np.random.seed(42)
fs = 256
duration = 10
t = np.arange(0, duration, 1/fs)

# Create signals with X leading Y by ~20ms (5 samples at 256 Hz)
x = np.random.randn(len(t))
# Low-pass filter to create temporal structure
from scipy.ndimage import gaussian_filter1d
x = gaussian_filter1d(x, sigma=5)

# Y follows X with a delay
delay_samples = 5
y = np.zeros_like(x)
y[delay_samples:] = 0.8 * x[:-delay_samples] + 0.4 * np.random.randn(len(x) - delay_samples)

# Compute time-lagged MI
max_lag = 50  # ~200ms
lags, mi_lagged = compute_time_lagged_mi(x, y, max_lag, n_bins=20)
lags_ms = lags * 1000 / fs  # Convert to ms

# Find peak lag
peak_idx = np.argmax(mi_lagged)
peak_lag_ms = lags_ms[peak_idx]

fig, axes = plt.subplots(2, 1, figsize=(12, 8))

# Time series
axes[0].plot(t[:500], x[:500], color=COLORS["signal_1"], linewidth=1.5, label="X (driver)")
axes[0].plot(t[:500], y[:500], color=COLORS["signal_2"], linewidth=1.5, label="Y (follower)")
axes[0].set_xlabel("Time (s)", fontsize=12)
axes[0].set_ylabel("Amplitude", fontsize=12)
axes[0].set_title("X Drives Y with ~20ms Delay", fontsize=12, fontweight="bold")
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Time-lagged MI
axes[1].plot(lags_ms, mi_lagged, color=COLORS["signal_3"], linewidth=2.5)
axes[1].axvline(x=0, color=COLORS["grid"], linestyle="--", alpha=0.7, label="Zero lag")
axes[1].axvline(x=peak_lag_ms, color="red", linestyle="-", linewidth=2, 
                label=f"Peak: {peak_lag_ms:.1f} ms")
axes[1].fill_between(lags_ms, mi_lagged, alpha=0.3, color=COLORS["signal_3"])
axes[1].set_xlabel("Lag (ms): negative = X leads", fontsize=12)
axes[1].set_ylabel("MI (bits)", fontsize=12)
axes[1].set_title("Time-Lagged MI Reveals Directionality", fontsize=12, fontweight="bold")
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

expected_delay = delay_samples * 1000 / fs
print(f"\n🎯 Expected delay: {expected_delay:.1f} ms")
print(f"   Detected peak lag: {peak_lag_ms:.1f} ms (negative sign = X leads)")
print("\n💡 The peak at NEGATIVE lag indicates that X leads Y.")
print("   This is a simple form of 'directionality' analysis: it shows temporal precedence,")
print("   not causation (see Transfer Entropy in D03).")

---

## 12. MI Connectivity Matrix 🔗

Just like we computed **entropy for multiple signals** in D01, we can compute MI between **all pairs** of signals to build a **connectivity matrix**.

This is essential for hyperscanning where we want to measure statistical dependencies between multiple EEG channels!

In [None]:
def compute_mi_matrix(
    signals: np.ndarray,
    n_bins: int = 20,
    normalize: bool = True
) -> np.ndarray:
    """
    Compute MI connectivity matrix for multiple signals.
    
    Parameters
    ----------
    signals : np.ndarray
        2D array of shape (n_channels, n_samples).
    n_bins : int
        Number of bins for histogram estimation.
    normalize : bool
        If True, normalize MI to [0, 1] range.
    
    Returns
    -------
    mi_matrix : np.ndarray
        Symmetric MI matrix of shape (n_channels, n_channels).
    """
    n_channels = signals.shape[0]
    mi_matrix = np.zeros((n_channels, n_channels))
    
    for i in range(n_channels):
        for j in range(i + 1, n_channels):
            mi = compute_mutual_information(signals[i], signals[j], n_bins)
            
            if normalize:
                # Normalized MI - compute_entropy_continuous returns (entropy, n_bins)
                h_i, _ = compute_entropy_continuous(signals[i], n_bins)
                h_j, _ = compute_entropy_continuous(signals[j], n_bins)
                if h_i > 0 and h_j > 0:
                    mi = 2 * mi / (h_i + h_j)
            
            mi_matrix[i, j] = mi
            mi_matrix[j, i] = mi
    
    return mi_matrix


# Visualization 13: MI connectivity matrix

np.random.seed(42)
n_channels = 8
n_samples = 2048

# Create signals with cluster structure
# Cluster 1: channels 0, 1, 2 (coupled)
# Cluster 2: channels 4, 5, 6 (coupled)
# Channels 3 and 7: independent

signals = np.random.randn(n_channels, n_samples)

# Add coupling within clusters
base_1 = np.random.randn(n_samples)
base_2 = np.random.randn(n_samples)

signals[0] += 2 * base_1
signals[1] += 2 * base_1 + 0.5 * np.random.randn(n_samples)
signals[2] += 2 * base_1 + 0.5 * np.random.randn(n_samples)

signals[4] += 2 * base_2
signals[5] += 2 * base_2 + 0.5 * np.random.randn(n_samples)
signals[6] += 2 * base_2 + 0.5 * np.random.randn(n_samples)

# Compute MI matrix
mi_matrix = compute_mi_matrix(signals, n_bins=20, normalize=True)

# Create labels
channel_labels = [f"Ch{i}" for i in range(n_channels)]

fig, ax = plt.subplots(figsize=(10, 8))

im = ax.imshow(mi_matrix, cmap="RdYlBu_r", vmin=0, vmax=1)
cbar = plt.colorbar(im, ax=ax, label="Normalized MI", shrink=0.8)

# Add text annotations
for i in range(n_channels):
    for j in range(n_channels):
        if i != j:
            text = ax.text(j, i, f"{mi_matrix[i, j]:.2f}",
                          ha="center", va="center", fontsize=9,
                          color="white" if mi_matrix[i, j] > 0.5 else "black")

ax.set_xticks(range(n_channels))
ax.set_yticks(range(n_channels))
ax.set_xticklabels(channel_labels)
ax.set_yticklabels(channel_labels)
ax.set_xlabel("Channel", fontsize=12)
ax.set_ylabel("Channel", fontsize=12)
ax.set_title("MI Connectivity Matrix: Two Clusters Detected!", fontsize=14, fontweight="bold")

# Add cluster annotations
ax.add_patch(plt.Rectangle((-0.5, -0.5), 3, 3, fill=False, 
                           edgecolor=COLORS["signal_1"], linewidth=3, label="Cluster 1"))
ax.add_patch(plt.Rectangle((3.5, 3.5), 3, 3, fill=False, 
                           edgecolor=COLORS["signal_2"], linewidth=3, label="Cluster 2"))

plt.tight_layout()
plt.show()

print("\n🔗 The MI matrix recovers the simulated connectivity structure:")
print("   - Cluster 1: Ch0, Ch1, Ch2 (high within-cluster MI)")
print("   - Cluster 2: Ch4, Ch5, Ch6 (high within-cluster MI)")
print("   - Ch3 and Ch7: independent nodes")

---

## 13. Application: Hyperscanning Inter-Brain MI 🧠↔️🧠

In **hyperscanning**, we record EEG from **two or more people** simultaneously. MI can measure **information sharing** between their brain signals!

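**Key insight**: Unlike correlation, MI can capture:
- Non-linear coupling between brain signals
- Dependencies arising from complex social interactions, whatever their functional form
- Implicit communication patterns that need not show up as linear correlation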

In [None]:
# Visualization 14: Inter-brain MI in hyperscanning scenario

np.random.seed(42)
fs = 256
duration = 60  # 60 seconds
t = np.arange(0, duration, 1/fs)
n_samples = len(t)

# Simulate EEG from two subjects (3 channels each)
# Scenario: Cooperative task with phases
# 0-20s: Independent (baseline)
# 20-40s: Cooperative task (inter-brain coupling)
# 40-60s: Independent (rest)

n_channels_per_subject = 3
channel_names = ["Fz", "Cz", "Pz"]

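# Generate base signals (alpha band, 8-12 Hz)
from scipy.signal import butter, sosfiltfilt

def generate_eeg_alpha(n_samples: int, fs: int) -> np.ndarray:
    """Generate simulated alpha-band EEG: 8-12 Hz band-pass filtered noise plus white noise.

    A pure 10 Hz sine with a random phase would make independently generated
    segments differ only by a constant phase offset, so "independent" baseline
    signals would still show spurious MI. Filtered noise avoids this.
    """
    sos = butter(4, [8, 12], btype="bandpass", fs=fs, output="sos")
    alpha = sosfiltfilt(sos, np.random.randn(n_samples))
    alpha /= np.std(alpha)  # unit-variance alpha component
    noise = np.random.randn(n_samples) * 0.5
    return alpha + noise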

# Subject 1 signals
subject1 = np.zeros((n_channels_per_subject, n_samples))
for ch in range(n_channels_per_subject):
    subject1[ch] = generate_eeg_alpha(n_samples, fs)

# Subject 2 signals - coupled during task phase
subject2 = np.zeros((n_channels_per_subject, n_samples))
for ch in range(n_channels_per_subject):
    # Independent phases
    subject2[ch, :20*fs] = generate_eeg_alpha(20*fs, fs)
    subject2[ch, 40*fs:] = generate_eeg_alpha(20*fs, fs)
    
    # Coupled phase - share some common information
    coupled_base = subject1[ch, 20*fs:40*fs]
    subject2[ch, 20*fs:40*fs] = (0.6 * coupled_base + 
                                  0.8 * generate_eeg_alpha(20*fs, fs))

# Compute inter-brain MI in sliding windows
window_samples = 5 * fs  # 5 second window
step_samples = 1 * fs    # 1 second step

def compute_interbrain_mi_timecourse(s1: np.ndarray, s2: np.ndarray, 
                                      window: int, step: int) -> Tuple[np.ndarray, np.ndarray]:
    """Compute mean inter-brain MI over time."""
    n_channels = s1.shape[0]
    n_windows = (s1.shape[1] - window) // step + 1
    
    times = np.zeros(n_windows)
    mi_timecourse = np.zeros(n_windows)
    
    for w in range(n_windows):
        start = w * step
        end = start + window
        times[w] = (start + end) / 2 / fs
        
        # Compute MI for all inter-brain pairs and average
        mi_sum = 0
        n_pairs = 0
        for i in range(n_channels):
            for j in range(n_channels):
                mi = compute_mutual_information(s1[i, start:end], 
                                               s2[j, start:end], n_bins=15)
                mi_sum += mi
                n_pairs += 1
        
        mi_timecourse[w] = mi_sum / n_pairs
    
    return times, mi_timecourse

times, mi_timecourse = compute_interbrain_mi_timecourse(subject1, subject2, 
                                                         window_samples, step_samples)

# Plot results
fig, axes = plt.subplots(3, 1, figsize=(14, 10))

# Subject 1 EEG (one channel)
ax = axes[0]
ax.plot(t, subject1[1], color=COLORS["signal_1"], linewidth=0.5, alpha=0.8)
ax.set_ylabel("Subject 1\nCz", fontsize=12)
ax.set_title("Subject 1 EEG", fontsize=12, fontweight="bold")
ax.axvspan(20, 40, alpha=0.2, color=COLORS["signal_3"])
ax.set_xlim([0, 60])

# Subject 2 EEG (one channel)
ax = axes[1]
ax.plot(t, subject2[1], color=COLORS["signal_2"], linewidth=0.5, alpha=0.8)
ax.set_ylabel("Subject 2\nCz", fontsize=12)
ax.set_title("Subject 2 EEG", fontsize=12, fontweight="bold")
ax.axvspan(20, 40, alpha=0.2, color=COLORS["signal_3"], label="Cooperative task")
ax.legend(loc="upper right")
ax.set_xlim([0, 60])

# Inter-brain MI
ax = axes[2]
ax.plot(times, mi_timecourse, color=COLORS["signal_3"], linewidth=2.5)
ax.fill_between(times, mi_timecourse, alpha=0.3, color=COLORS["signal_3"])
ax.axvspan(20, 40, alpha=0.2, color=COLORS["signal_3"])
ax.axhline(y=np.mean(mi_timecourse[:15]), color=COLORS["grid"], linestyle="--", 
           label="Baseline level")
ax.set_xlabel("Time (s)", fontsize=12)
ax.set_ylabel("Mean Inter-Brain MI", fontsize=12)
ax.set_title("Inter-Brain MI Increases During Cooperation!", fontsize=12, fontweight="bold")
ax.legend()
ax.grid(True, alpha=0.3)
ax.set_xlim([0, 60])

plt.suptitle("Hyperscanning: MI Detects Inter-Brain Coupling During Social Interaction",
             fontsize=14, fontweight="bold", y=1.01)
plt.tight_layout()
plt.show()

# Summary statistics, using only windows that lie entirely inside one phase
# (window w spans [w, w + 5) s: rest windows w < 15 and w > 40, task windows 20 <= w <= 35)
mi_baseline = np.mean(np.concatenate([mi_timecourse[:15], mi_timecourse[-15:]]))
mi_task = np.mean(mi_timecourse[20:36])
increase = ((mi_task - mi_baseline) / mi_baseline) * 100

print("\n📊 Inter-Brain MI Analysis:")
print(f"   Baseline MI: {mi_baseline:.4f} bits")
print(f"   Task MI:     {mi_task:.4f} bits")
print(f"   Increase:    {increase:.1f}%")
print("\n🧠 This demonstrates how MI can track inter-brain coupling during social tasks!")

---

## 14. Exercises 📝

Test your understanding of Mutual Information!

In [None]:
# Exercise 1: How MI and correlation respond to different relationship types
# ===========================================================================
# Create three pairs of signals:
# 1. Linear relationship: Y = 2*X + noise
# 2. Quadratic relationship: Y = X^2 + noise
# 3. Periodic ("Circular" below) relationship: Y = sin(X) + cos(X) + noise, with X uniform on [-π, π]
#
# Compute MI and correlation for each. Which metric captures non-linear dependencies better?

np.random.seed(42)
n = 1000

x = np.random.randn(n)

# Linear
y_linear = 2 * x + 0.5 * np.random.randn(n)

# Quadratic
y_quadratic = x**2 + 0.5 * np.random.randn(n)

# Circular (use x in radians)
x_rad = np.random.uniform(-np.pi, np.pi, n)
y_circular = np.sin(x_rad) + np.cos(x_rad) + 0.3 * np.random.randn(n)

# Compute MI and correlation for each pair
results_ex1 = []

for name, x_sig, y_sig in [("Linear", x, y_linear), 
                            ("Quadratic", x, y_quadratic), 
                            ("Circular", x_rad, y_circular)]:
    mi = compute_mutual_information(x_sig, y_sig, n_bins=20)
    corr = np.abs(np.corrcoef(x_sig, y_sig)[0, 1])
    results_ex1.append({"Relationship": name, "MI": mi, "|Correlation|": corr})

# Display results
print("üìä Exercise 1: MI vs Correlation for Different Relationships")
print("=" * 60)
print(f"{'Relationship':<15} {'MI (bits)':<15} {'|Correlation|':<15}")
print("-" * 60)
for r in results_ex1:
    print(f"{r['Relationship']:<15} {r['MI']:<15.4f} {r['|Correlation|']:<15.4f}")
print("-" * 60)
print("\n💡 Key insight:")
print("   - Correlation captures LINEAR relationships well")
print("   - MI captures ALL relationships (linear AND nonlinear)")
print("   - Quadratic: correlation ≈ 0, but MI is HIGH!")

In [None]:
# Exercise 2: Effect of binning on MI estimation
# ================================================
# Using two coupled signals, compute MI with different numbers of bins:
# [5, 10, 20, 50, 100, 200]
#
# Plot MI as a function of number of bins. What do you observe?

np.random.seed(42)
n = 1000
x = np.random.randn(n)
y = 0.5 * x + 0.87 * np.random.randn(n)  # True correlation ≈ 0.5 (0.5 / √(0.5² + 0.87²) ≈ 0.498)

n_bins_list = [5, 10, 20, 50, 100, 200]
mi_values_ex2 = []

for n_bins in n_bins_list:
    mi = compute_mutual_information(x, y, n_bins)
    mi_values_ex2.append(mi)

# Plot MI vs n_bins
fig, ax = plt.subplots(figsize=(10, 6))

ax.plot(n_bins_list, mi_values_ex2, "o-", color=COLORS["signal_1"], 
        linewidth=2, markersize=10)
ax.axhline(y=mi_values_ex2[2], color=COLORS["grid"], linestyle="--", 
           label=f"Reference (20 bins): {mi_values_ex2[2]:.4f}")

ax.set_xlabel("Number of Bins", fontsize=12)
ax.set_ylabel("Estimated MI (bits)", fontsize=12)
ax.set_title("Exercise 2: Effect of Binning on MI Estimation", fontsize=14, fontweight="bold")
ax.legend()
ax.grid(True, alpha=0.3)
ax.set_xscale("log")

# Add annotations
for i, (nb, mi) in enumerate(zip(n_bins_list, mi_values_ex2)):
    ax.annotate(f"{mi:.3f}", (nb, mi), textcoords="offset points", 
                xytext=(0, 10), ha="center", fontsize=9)

plt.tight_layout()
plt.show()

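print("\n💡 Key observations:")
print("   - Too few bins (5): MI is UNDERESTIMATED (discretization is too coarse)")
print("   - Too many bins (200): upward BIAS grows (the 2D histogram is sparse)")
print("   - Reference: for this Gaussian pair the true MI is -0.5*log2(1 - rho^2) ≈ 0.21 bits;")
print("     which bin count gets closest depends on the sample size")
print("\n   Rules of thumb such as n_bins ≈ √(n_samples) or Sturges' formula are 1D starting points,")
print("   not guarantees for the 2D histogram used by MI")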
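For this Gaussian pair, the true MI has a closed form, $I(X;Y) = -\tfrac{1}{2}\log_2(1-\rho^2)$. With $\rho = 0.5/\sqrt{0.5^2 + 0.87^2} \approx 0.498$, this gives about 0.206 bits, an objective reference that the 20-bin line lacks. The sketch below repeats the histogram estimate for several sample sizes. It uses a hypothetical self-contained helper in place of `compute_mutual_information`, and the sample sizes and bin counts are illustrative choices. Because the plug-in bias shrinks roughly as 1/N, the bin count whose estimate lands closest to the true value shifts as more data become available.

```python
import numpy as np


def _plugin_mi_bits(x, y, n_bins):
    """Hypothetical stand-in for compute_mutual_information: plug-in MI in bits."""
    joint, _, _ = np.histogram2d(x, y, bins=n_bins)
    pxy = joint / joint.sum()
    px = pxy.sum(axis=1, keepdims=True)
    py = pxy.sum(axis=0, keepdims=True)
    nz = pxy > 0
    return float(np.sum(pxy[nz] * np.log2(pxy[nz] / (px * py)[nz])))


def gaussian_mi_bits(rho: float) -> float:
    """Exact MI (bits) of a bivariate Gaussian with correlation rho."""
    return float(-0.5 * np.log2(1.0 - rho**2))


rho_true = 0.5 / np.sqrt(0.5**2 + 0.87**2)  # y = 0.5*x + 0.87*noise, as in Exercise 2
mi_true = gaussian_mi_bits(rho_true)
print(f"Analytical MI: {mi_true:.4f} bits")

rng = np.random.default_rng(0)
sweep = {}
for n in [500, 5000, 50000]:
    x_n = rng.standard_normal(n)
    y_n = 0.5 * x_n + 0.87 * rng.standard_normal(n)
    sweep[n] = {b: _plugin_mi_bits(x_n, y_n, b) for b in [5, 10, 20, 50]}
    row = "  ".join(f"{b:>2} bins: {mi:.3f}" for b, mi in sweep[n].items())
    print(f"n = {n:>5}: {row}")
```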

In [None]:
# Exercise 3: Conditional MI
# ============================
# Conditional MI measures: I(X; Y | Z) = H(X|Z) + H(Y|Z) - H(X,Y|Z)
# This tells us how much information X and Y share BEYOND what Z provides.
#
# Create three signals:
# - Z: Common driver signal
# - X = Z + noise_1
# - Y = Z + noise_2
#
# Compare I(X; Y) with the expected behavior when conditioning on Z.

np.random.seed(42)
n = 2000

z = np.random.randn(n)  # Common driver
x = z + 0.3 * np.random.randn(n)
y = z + 0.3 * np.random.randn(n)

# Unconditional MI between X and Y
mi_xy = compute_mutual_information(x, y, n_bins=20)

# Correlation between X, Y, and Z
corr_xy = np.corrcoef(x, y)[0, 1]
corr_xz = np.corrcoef(x, z)[0, 1]
corr_yz = np.corrcoef(y, z)[0, 1]

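# For demonstration: compute MI after linearly "regressing out" Z.
# For jointly Gaussian data this (in the population) equals I(X; Y | Z); in general it is
# only an approximation of conditional MI, because it removes linear effects of Z only.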
x_residual = x - np.polyval(np.polyfit(z, x, 1), z)
y_residual = y - np.polyval(np.polyfit(z, y, 1), z)
mi_xy_given_z_approx = compute_mutual_information(x_residual, y_residual, n_bins=20)

print("📊 Exercise 3: Conditional MI (Detecting Spurious Correlations)")
print("=" * 65)
print("\nCorrelations:")
print(f"   r(X, Y) = {corr_xy:.4f}  ← High! But is it genuine?")
print(f"   r(X, Z) = {corr_xz:.4f}  ← X follows Z")
print(f"   r(Y, Z) = {corr_yz:.4f}  ← Y follows Z")

print("\nMutual Information:")
print(f"   I(X; Y)     = {mi_xy:.4f} bits  ← Unconditional MI")
print(f"   I(X; Y | Z) ≈ {mi_xy_given_z_approx:.4f} bits  ← After removing Z influence")

print("\n💡 Key insight:")
print("   X and Y appear highly dependent (high I(X;Y))")
print("   But this is because BOTH depend on Z!")
print("   After conditioning on Z, the dependency (almost) disappears.")
print("\n🔍 This is the 'confounding variable' problem!")
print("   Conditional MI helps detect when apparent dependencies are spurious.")
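The residual trick above removes only the linear effects of Z. A fully binned alternative uses the identity $I(X;Y \mid Z) = H(X,Z) + H(Y,Z) - H(X,Y,Z) - H(Z)$. The sketch below uses equal-frequency (quantile) bins with 8 bins per variable, which are choices of this sketch rather than the notebook's equal-width histograms, and all helper names are hypothetical. The joint histogram now has 8³ cells, so it needs many more samples than the 2D case, and the sketch uses 20,000 synthetic samples. Coarse Z bins leave a little within-bin dependence, so the estimate is small but not exactly zero.

```python
import numpy as np


def _discretize(v: np.ndarray, n_bins: int) -> np.ndarray:
    """Equal-frequency bin labels 0..n_bins-1 (quantile edges)."""
    inner_edges = np.quantile(v, np.linspace(0, 1, n_bins + 1)[1:-1])
    return np.digitize(v, inner_edges)


def _joint_entropy_bits(*labels: np.ndarray) -> float:
    """Plug-in joint entropy (bits) of one or more integer label arrays."""
    _, counts = np.unique(np.stack(labels, axis=1), axis=0, return_counts=True)
    p = counts / counts.sum()
    return float(-np.sum(p * np.log2(p)))


def binned_mi(x, y, n_bins=8):
    """I(X;Y) = H(X) + H(Y) - H(X,Y), plug-in estimates."""
    xd, yd = _discretize(x, n_bins), _discretize(y, n_bins)
    return _joint_entropy_bits(xd) + _joint_entropy_bits(yd) - _joint_entropy_bits(xd, yd)


def binned_cmi(x, y, z, n_bins=8):
    """I(X;Y|Z) = H(X,Z) + H(Y,Z) - H(X,Y,Z) - H(Z), plug-in estimates."""
    xd, yd, zd = (_discretize(v, n_bins) for v in (x, y, z))
    return (_joint_entropy_bits(xd, zd) + _joint_entropy_bits(yd, zd)
            - _joint_entropy_bits(xd, yd, zd) - _joint_entropy_bits(zd))


rng = np.random.default_rng(0)
n_cmi = 20000
z_c = rng.standard_normal(n_cmi)  # common driver
x_c = z_c + 0.3 * rng.standard_normal(n_cmi)
y_c = z_c + 0.3 * rng.standard_normal(n_cmi)
mi_c = binned_mi(x_c, y_c)
cmi_c = binned_cmi(x_c, y_c, z_c)
print(f"I(X;Y)   = {mi_c:.3f} bits")
print(f"I(X;Y|Z) = {cmi_c:.3f} bits")
```

Unlike the residual approach, this estimate does not assume that Z acts linearly. The cost is that the curse of dimensionality arrives quickly as more conditioning variables are added.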

---

## 15. Summary 📋

### Key Concepts

| Concept | Formula | Meaning |
|---------|---------|---------|
| **Joint Entropy** | $H(X,Y) = -\sum p(x,y) \log_2 p(x,y)$ | Total uncertainty of both variables |
| **Conditional Entropy** | $H(Y \mid X) = H(X,Y) - H(X)$ | Uncertainty in Y given X |
| **Mutual Information** | $I(X;Y) = H(X) + H(Y) - H(X,Y)$ | Shared information |
| **Normalized MI** | $NMI = \frac{2 \cdot I(X;Y)}{H(X) + H(Y)}$ | Bounded [0, 1] |
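A tiny discrete example makes the formulas in this table concrete. The sketch below computes each quantity directly from a joint probability table. The two tables, a perfect copy and an independent pair, are illustrative choices, and the function names are hypothetical.

```python
import numpy as np


def entropy_bits(p) -> float:
    """Shannon entropy in bits of a probability array (zeros are skipped)."""
    p = np.asarray(p, dtype=float).ravel()
    p = p[p > 0]
    return float(-np.sum(p * np.log2(p)))


def summary_quantities(pxy) -> dict:
    """Joint/conditional entropy, MI and normalized MI from a joint pmf p(x, y)."""
    pxy = np.asarray(pxy, dtype=float)
    h_x = entropy_bits(pxy.sum(axis=1))  # sum over columns -> marginal of X (rows)
    h_y = entropy_bits(pxy.sum(axis=0))  # sum over rows -> marginal of Y (columns)
    h_xy = entropy_bits(pxy)
    mi = h_x + h_y - h_xy
    return {"H(X,Y)": h_xy, "H(Y|X)": h_xy - h_x, "I(X;Y)": mi, "NMI": 2 * mi / (h_x + h_y)}


copy_pair = summary_quantities([[0.5, 0.0], [0.0, 0.5]])            # Y is a copy of X
independent_pair = summary_quantities([[0.25, 0.25], [0.25, 0.25]])  # X, Y independent
print(copy_pair)
print(independent_pair)
```

For the copy, knowing X removes all uncertainty about Y: H(Y|X) = 0, I(X;Y) = 1 bit and NMI = 1. For the independent pair, I(X;Y) = 0 and H(Y|X) = H(Y) = 1 bit.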

### MI vs Correlation

| Property | Correlation | Mutual Information |
|----------|-------------|-------------------|
| Range | [-1, 1] | [0, ∞) |
| Linear relationships | ✓ | ✓ |
| Non-linear relationships | ✗ | ✓ |
| Interpretation | Direction + strength | Information shared |
| Estimation | Simple | Requires binning/KNN |

### Key Takeaways

1. **MI captures ALL dependencies**: Unlike correlation, MI is zero only under independence, so in principle (given enough data) it detects any statistical relationship
2. **Symmetric but not directional**: $I(X;Y) = I(Y;X)$; use Transfer Entropy for directionality
3. **Estimation matters**: Too few bins → underestimate, too many → bias/variance issues
4. **Use surrogates**: Always validate significance with shuffled surrogates
5. **Time-varying MI**: Sliding windows reveal dynamic coupling changes
6. **Well suited to hyperscanning**: Captures complex inter-brain dependencies

---

## 16. Discussion & Next Steps 🚀

### Discussion Questions

1. **Why might MI be preferred over correlation for EEG analysis?**
   - Neural communication often involves non-linear dynamics
   - Phase-amplitude coupling is inherently non-linear
   - Information theory provides interpretable units (bits)

2. **What are the limitations of histogram-based MI estimation?**
   - Curse of dimensionality for multivariate data
   - Bin size selection is somewhat arbitrary
   - May require many samples for reliable estimates

3. **How does MI relate to other connectivity metrics?**
   - Coherence: captures linear frequency-specific dependencies
   - Phase-Locking Value: captures phase synchronization
   - MI: captures all statistical dependencies

### Next Steps

In the next notebook (**D03 - Transfer Entropy**), we'll learn:
- How to measure **directed** information flow
- The concept of **causal** coupling
- Applications to detecting leader-follower dynamics in hyperscanning

### Further Reading

- Cover, T. M., & Thomas, J. A. (2006). *Elements of Information Theory*
- Kraskov, A., et al. (2004). Estimating mutual information. *Physical Review E*
- Jeong, J., et al. (2001). Mutual information analysis of EEG. *Clinical Neurophysiology*

---

**Estimated time**: 70 minutes

**Prerequisites completed**: D01 (Entropy and Information)

**Next notebook**: D03 - Transfer Entropy