In [None]:
# ============================================================================
# IMPORTS AND SETUP
# ============================================================================
import sys
from pathlib import Path
from typing import Tuple, Optional, List, Dict, Any

import numpy as np
from numpy.typing import NDArray
import matplotlib.pyplot as plt
import scipy.signal
from scipy.signal import butter, filtfilt

# Add src to path for local imports
sys.path.insert(0, str(Path.cwd().parents[2]))

from src.colors import COLORS

# Color aliases
PRIMARY_BLUE = COLORS["signal_1"]      # Sky Blue
PRIMARY_RED = COLORS["signal_2"]       # Rose Pink
PRIMARY_GREEN = COLORS["signal_3"]     # Sage Green
SECONDARY_ORANGE = COLORS["signal_4"]  # Golden
SECONDARY_PURPLE = COLORS["high_sync"] # Purple
SUBJECT_1 = COLORS["signal_1"]         # For hyperscanning
SUBJECT_2 = COLORS["signal_2"]         # For hyperscanning

# Sampling frequency
fs = 256  # Hz

# Random seed for reproducibility
np.random.seed(42)

print("✓ Imports successful!")
print(f"NumPy version: {np.__version__}")

---

# C02: Connectivity Matrices

## From Pairs to Networks

**Duration**: ~55 minutes

**Prerequisites**: C01 (Volume Conduction), B02 (Working with Phase)

---

## Learning Objectives

By the end of this notebook, you will be able to:

1. 🧩 Understand connectivity matrices as organized pairwise measurements
2. 🔧 Construct connectivity matrices from multi-channel data
3. 📊 Visualize connectivity with heatmaps and circular plots
4. 🧠 Handle hyperscanning matrices (within + between participants)
5. 📈 Extract network-level summary statistics

---

## 1. Introduction: From Pairs to Networks

So far in this workshop, we've focused on connectivity between **two signals**, that is, a single pair of electrodes. We computed PLV, coherence, or correlation between signal A and signal B.

But real EEG data has **many channels**:
- Clinical EEG: 19-21 electrodes
- Research EEG: 32, 64, or 128+ electrodes
- **Hyperscanning**: 2 participants × n electrodes = even more!

To understand brain networks, we need to analyze **all pairs systematically**. This is where **connectivity matrices** come in.

### What is a Connectivity Matrix?

A connectivity matrix is simply an organized way to store pairwise connectivity values:

- **Rows and columns** represent channels (electrodes)
- **Entry (i, j)** contains the connectivity between channel i and channel j
- The result is a **square matrix**: n_channels × n_channels

This organized structure is the foundation for:
- **Network neuroscience** and graph theory analysis
- Comparing connectivity across conditions, participants, or groups
- Identifying "hubs" (highly connected regions)
- Statistical analysis of connectivity patterns
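To make the idea concrete, here is a minimal sketch that takes a handful of pairwise values stored in a dictionary and arranges them into a square NumPy array. The channel names and PLV numbers are made up for illustration.

```python
import numpy as np

# Hypothetical pairwise PLV values (illustrative numbers, not real data)
channels = ["F3", "F4", "C3"]
pair_values = {("F3", "F4"): 0.72, ("F3", "C3"): 0.35, ("F4", "C3"): 0.41}

# Map each channel name to its row/column index
index = {name: k for k, name in enumerate(channels)}

# Start with NaN so the diagonal (self-connectivity) stays undefined
matrix = np.full((len(channels), len(channels)), np.nan)
for (a, b), value in pair_values.items():
    i, j = index[a], index[b]
    matrix[i, j] = value  # entry (i, j)
    matrix[j, i] = value  # mirror entry, since PLV is undirected

print(matrix)
```

Each row then holds one channel's connectivity profile, which is what later network analyses operate on.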

---

## 2. Anatomy of a Connectivity Matrix

Let's understand the structure of a connectivity matrix in detail:

### Key Properties

| Property | Description |
|----------|-------------|
| **Shape** | n_channels × n_channels (square) |
| **Diagonal** | M[i, i] = self-connectivity (often NaN or 1) |
| **Symmetry** | For undirected metrics: M[i, j] = M[j, i] |
| **Values** | Depend on metric (0-1 for PLV, -1 to +1 for correlation) |

### Symmetric vs Asymmetric

- **Symmetric metrics** (most common): PLV, coherence, correlation
  - "Connectivity from A to B" = "Connectivity from B to A"
  - Matrix is symmetric: M = M.T
  
- **Asymmetric metrics** (directed): Granger causality, transfer entropy
  - "A causes B" ≠ "B causes A"
  - Matrix is NOT symmetric
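A quick way to tell which kind of matrix you have is to compare it with its transpose. The sketch below uses `np.allclose(..., equal_nan=True)` so that a NaN diagonal does not break the comparison. `is_symmetric` is a hypothetical helper, and the "directed" values are illustrative only, assuming the common convention that row i is the source and column j the target.

```python
import numpy as np

# Undirected example: a symmetric PLV-like matrix with NaN diagonal
plv = np.array([[np.nan, 0.8, 0.3],
                [0.8, np.nan, 0.5],
                [0.3, 0.5, np.nan]])

# Directed example: hypothetical row i -> column j influence values
# (illustrative numbers only, not computed from a real Granger analysis)
directed = np.array([[np.nan, 0.6, 0.1],
                     [0.2, np.nan, 0.4],
                     [0.1, 0.4, np.nan]])


def is_symmetric(m: np.ndarray, atol: float = 1e-10) -> bool:
    """True if m equals its transpose; NaN entries are compared as equal."""
    return bool(np.allclose(m, m.T, atol=atol, equal_nan=True))


print(is_symmetric(plv))       # symmetric: M[i, j] == M[j, i] everywhere
print(is_symmetric(directed))  # not symmetric: directed[0, 1] != directed[1, 0]
```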

### The Diagonal

The diagonal entries M[i, i] represent "self-connectivity": the connectivity of a channel with itself. This is:
- **Uninformative** for most metrics (the PLV of a signal with itself is always 1)
- Usually set to **NaN** to exclude from analysis
- Sometimes set to **0** or **1** depending on convention
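The diagonal convention matters as soon as you summarize the matrix. In this small illustrative example (values made up), filling the diagonal with 1 inflates the mean, while setting it to NaN and using `np.nanmean` averages only the off-diagonal entries.

```python
import numpy as np

# Illustrative PLV values for 4 channels; the diagonal 0.0 is a placeholder
m = np.array([[0.0, 0.4, 0.2, 0.3],
              [0.4, 0.0, 0.5, 0.1],
              [0.2, 0.5, 0.0, 0.3],
              [0.3, 0.1, 0.3, 0.0]])

diag_one = m.copy()
np.fill_diagonal(diag_one, 1.0)      # convention: self-PLV = 1
diag_nan = m.copy()
np.fill_diagonal(diag_nan, np.nan)   # convention: exclude self-connectivity

print(np.mean(diag_one))     # inflated by the four 1.0 entries
print(np.nanmean(diag_nan))  # mean over the 12 off-diagonal entries only
```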

In [None]:
# ============================================================================
# VISUALIZATION 1: Structure of a Connectivity Matrix
# ============================================================================

fig, ax = plt.subplots(figsize=(10, 8))

# Create example matrix
channel_names = ['F3', 'F4', 'C3', 'C4', 'P3', 'P4']
n = len(channel_names)

# Generate example values (symmetric)
np.random.seed(42)
example_matrix = np.random.uniform(0.2, 0.8, (n, n))
example_matrix = (example_matrix + example_matrix.T) / 2  # Make symmetric
np.fill_diagonal(example_matrix, np.nan)  # Diagonal = NaN

# Plot heatmap
im = ax.imshow(example_matrix, cmap='viridis', vmin=0, vmax=1)

# Add colorbar
cbar = plt.colorbar(im, ax=ax, shrink=0.8)
cbar.set_label('Connectivity (PLV)', fontsize=11)

# Labels
ax.set_xticks(range(n))
ax.set_yticks(range(n))
ax.set_xticklabels(channel_names, fontsize=11)
ax.set_yticklabels(channel_names, fontsize=11)
ax.set_xlabel('Channel j', fontsize=12)
ax.set_ylabel('Channel i', fontsize=12)
ax.set_title('Structure of a Connectivity Matrix', fontsize=14, fontweight='bold')

# Annotate values
for i in range(n):
    for j in range(n):
        if i == j:
            # NaN cells show the white background, so use gray text
            ax.text(j, i, 'NaN', ha='center', va='center', fontsize=9, color='gray')
        else:
            # Viridis is dark for low values and light for high values
            ax.text(j, i, f'{example_matrix[i, j]:.2f}', ha='center', va='center', 
                    fontsize=9, color='black' if example_matrix[i, j] > 0.5 else 'white')

# Highlight diagonal
for i in range(n):
    # Gray edges: white ones would vanish against the transparent NaN cells
    rect = plt.Rectangle((i-0.5, i-0.5), 1, 1, fill=False, 
                          edgecolor='gray', linewidth=2, linestyle='--')
    ax.add_patch(rect)

# Add annotations
ax.annotate('Diagonal\n(self-connectivity)', xy=(0, 0), xytext=(-2.5, 1),
            fontsize=10, ha='center',
            arrowprops=dict(arrowstyle='->', color='black', lw=1.5),
            color='black', bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

ax.annotate('M[i,j] = M[j,i]\n(symmetric)', xy=(4, 1), xytext=(7, 0),
            fontsize=10, ha='center',
            arrowprops=dict(arrowstyle='->', color='black', lw=1.5),
            color='black', bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

plt.tight_layout()
plt.show()

print(f"Matrix shape: {example_matrix.shape}")
print(f"Number of channels: {n}")
print(f"Number of unique pairs: {n * (n - 1) // 2}")

---

## 3. Computing a Connectivity Matrix

Now let's implement the algorithm to compute a connectivity matrix from multi-channel data.

### The Algorithm

1. **Initialize** an n × n matrix with NaN or zeros
2. **Loop** over all unique pairs (i, j) where i < j
3. **Compute** the connectivity metric for each pair
4. **Fill** both M[i, j] and M[j, i] (for symmetric metrics)
5. **Set** diagonal to NaN
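The five steps map directly onto a short loop. The sketch below is a hypothetical, metric-agnostic helper (`pairwise_matrix` is not part of this notebook) that accepts any symmetric pairwise function; Pearson correlation via `np.corrcoef` is used only as an easy-to-check example metric.

```python
from itertools import combinations
from typing import Callable

import numpy as np


def pairwise_matrix(data: np.ndarray,
                    pair_fn: Callable[[np.ndarray, np.ndarray], float]) -> np.ndarray:
    """Fill a symmetric n x n matrix by applying pair_fn to each unique pair (i < j)."""
    n = data.shape[0]
    matrix = np.full((n, n), np.nan)          # step 1; diagonal stays NaN (step 5)
    for i, j in combinations(range(n), 2):    # step 2: unique pairs with i < j
        value = pair_fn(data[i], data[j])     # step 3: compute the metric
        matrix[i, j] = matrix[j, i] = value   # step 4: mirror for symmetric metrics
    return matrix


# Example with Pearson correlation as the pairwise metric
rng = np.random.default_rng(0)
x = rng.standard_normal((4, 500))
corr = pairwise_matrix(x, lambda a, b: float(np.corrcoef(a, b)[0, 1]))
print(np.round(corr, 2))
```

Because the loop only visits pairs with i < j, the metric is evaluated n(n-1)/2 times rather than n² times.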

### Efficiency Consideration

The number of unique pairs grows quickly:

| Channels | Unique Pairs | Formula |
|----------|--------------|---------|
| 6 | 15 | 6×5/2 |
| 19 | 171 | 19×18/2 |
| 64 | 2,016 | 64×63/2 |
| 128 | 8,128 | 128×127/2 |
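The counts in the table are binomial coefficients, C(n, 2) = n(n-1)/2, which Python's `math.comb` computes directly. The same arithmetic splits a two-participant hyperscanning montage (64 channels each, as an assumed example) into within- and between-participant pairs.

```python
from math import comb

# Unique pairs within one set of n channels: n(n-1)/2 = C(n, 2)
for n in (6, 19, 64, 128):
    print(n, comb(n, 2))

# Hyperscanning with two participants of n_per channels each
n_per = 64
within_each = comb(n_per, 2)   # pairs inside one participant's montage
between = n_per * n_per        # every channel of A paired with every channel of B
total = comb(2 * n_per, 2)     # all unique pairs in the combined 128-channel matrix
print(within_each, between, total)
```

The total decomposes as two within-participant blocks plus one between-participant block.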

For hyperscanning (2 × 64 channels), the between-participant block alone holds 64 × 64 = **4,096** pairs, and the full 128-channel matrix has **8,128** unique pairs!
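When the per-pair loop becomes slow, one common alternative for PLV (sketched here as an assumption, not the notebook's own implementation) is to compute all pairs at once. With unit phasors z = exp(iφ) stacked as an (n_channels, n_samples) array, entry (i, j) of z @ z^H sums exp(i(φ_i - φ_j)) over time, so its magnitude divided by the number of samples is the PLV. The function name is hypothetical, and the input is assumed to be already band-pass filtered.

```python
import numpy as np
from scipy.signal import hilbert


def plv_matrix_vectorized(data_filtered: np.ndarray) -> np.ndarray:
    """PLV for all channel pairs at once; input shape (n_channels, n_samples)."""
    phases = np.angle(hilbert(data_filtered, axis=-1))  # instantaneous phase per channel
    z = np.exp(1j * phases)                              # unit phasors
    n_samples = data_filtered.shape[-1]
    # (z @ z^H)[i, j] = sum_t exp(i * (phi_i - phi_j)); magnitude / T gives the PLV
    plv = np.abs(z @ z.conj().T) / n_samples
    np.fill_diagonal(plv, np.nan)                        # self-connectivity is uninformative
    return plv


rng = np.random.default_rng(1)
x = rng.standard_normal((5, 1000))  # stand-in for band-pass filtered data
plv = plv_matrix_vectorized(x)
print(plv.shape)
```

This trades the Python loop for one matrix product (roughly n² × n_samples complex multiply-adds) and yields the same values as the pairwise formula, up to floating-point rounding.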

In [None]:
# ============================================================================
# FUNCTIONS 1-3: Helper Functions for Connectivity Matrix
# ============================================================================

def get_n_pairs(n_channels: int) -> int:
    """
    Calculate the number of unique channel pairs.
    
    Parameters
    ----------
    n_channels : int
        Number of channels.
        
    Returns
    -------
    int
        Number of unique pairs: n(n-1)/2
    """
    return n_channels * (n_channels - 1) // 2


def get_pair_indices(n_channels: int) -> List[Tuple[int, int]]:
    """
    Get list of all unique channel pair indices.
    
    Parameters
    ----------
    n_channels : int
        Number of channels.
        
    Returns
    -------
    List[Tuple[int, int]]
        List of (i, j) tuples where i < j.
    """
    pairs = []
    for i in range(n_channels):
        for j in range(i + 1, n_channels):
            pairs.append((i, j))
    return pairs


def compute_plv_pair(
    signal_1: NDArray[np.floating],
    signal_2: NDArray[np.floating]
) -> float:
    """
    Compute Phase Locking Value between two signals.
    
    Parameters
    ----------
    signal_1 : NDArray[np.floating]
        First signal.
    signal_2 : NDArray[np.floating]
        Second signal.
        
    Returns
    -------
    float
        PLV value between 0 and 1.
    """
    # Get instantaneous phases
    analytic_1 = scipy.signal.hilbert(signal_1)
    analytic_2 = scipy.signal.hilbert(signal_2)
    
    phase_1 = np.angle(analytic_1)
    phase_2 = np.angle(analytic_2)
    
    # Compute phase difference
    phase_diff = phase_1 - phase_2
    
    # PLV = |mean(exp(i * phase_diff))|
    plv = np.abs(np.mean(np.exp(1j * phase_diff)))
    
    return float(plv)


print("Helper functions defined:")
print(f"• get_n_pairs(n_channels) → number of unique pairs")
print(f"• get_pair_indices(n_channels) → list of (i, j) pairs")
print(f"• compute_plv_pair(signal_1, signal_2) → PLV value")
print()
print(f"Example: 6 channels → {get_n_pairs(6)} pairs")
print(f"Example: 64 channels → {get_n_pairs(64)} pairs")

In [None]:
# ============================================================================
# FUNCTION 4: Compute Full Connectivity Matrix
# ============================================================================

def bandpass_filter(
    data: NDArray[np.floating],
    lowcut: float,
    highcut: float,
    fs: int,
    order: int = 4
) -> NDArray[np.floating]:
    """
    Apply bandpass filter to data.
    
    Parameters
    ----------
    data : NDArray[np.floating]
        Input data, shape (n_channels, n_samples) or (n_samples,).
    lowcut : float
        Low cutoff frequency in Hz.
    highcut : float
        High cutoff frequency in Hz.
    fs : int
        Sampling frequency in Hz.
    order : int, optional
        Filter order. Default is 4.
        
    Returns
    -------
    NDArray[np.floating]
        Filtered data, same shape as input.
    """
    nyq = 0.5 * fs
    low = lowcut / nyq
    high = highcut / nyq
    b, a = butter(order, [low, high], btype='band')
    
    if data.ndim == 1:
        return filtfilt(b, a, data)
    else:
        return np.array([filtfilt(b, a, ch) for ch in data])


def compute_connectivity_matrix(
    data: NDArray[np.floating],
    fs: int,
    band: Tuple[float, float],
    metric: str = "plv"
) -> NDArray[np.floating]:
    """
    Compute pairwise connectivity matrix from multi-channel data.
    
    Parameters
    ----------
    data : NDArray[np.floating]
        Multi-channel data, shape (n_channels, n_samples).
    fs : int
        Sampling frequency in Hz.
    band : Tuple[float, float]
        Frequency band (low, high) in Hz.
    metric : str, optional
        Connectivity metric. Currently only "plv" supported. Default is "plv".
        
    Returns
    -------
    NDArray[np.floating]
        Connectivity matrix, shape (n_channels, n_channels).
        Diagonal is NaN. Matrix is symmetric for PLV.
    """
    # Validate the metric up front so bad input fails before any filtering
    if metric != "plv":
        raise ValueError(f"Unknown metric: {metric}")
    
    n_channels = data.shape[0]
    
    # Filter data to frequency band
    data_filtered = bandpass_filter(data, band[0], band[1], fs)
    
    # Initialize matrix with NaN (the diagonal stays NaN)
    matrix = np.full((n_channels, n_channels), np.nan)
    
    # Compute connectivity for all unique pairs (i < j)
    for i, j in get_pair_indices(n_channels):
        value = compute_plv_pair(data_filtered[i], data_filtered[j])
        
        # Fill both entries (PLV is symmetric)
        matrix[i, j] = value
        matrix[j, i] = value
    
    return matrix


print("Main function defined:")
print("• compute_connectivity_matrix(data, fs, band, metric) → n×n matrix")

In [None]:
# ============================================================================
# VISUALIZATION 2: Compute and Display a Real Connectivity Matrix
# ============================================================================

# Generate synthetic 6-channel data with realistic connectivity structure
np.random.seed(42)
duration = 5.0  # seconds
n_samples = int(duration * fs)
t = np.arange(n_samples) / fs

# Create 3 INDEPENDENT sources with different random phase dynamics
# This ensures cross-cluster PLV is LOW (no shared source)
source_1 = np.sin(2 * np.pi * 10 * t + np.cumsum(0.1 * np.random.randn(n_samples)))
source_2 = np.sin(2 * np.pi * 10 * t + np.cumsum(0.1 * np.random.randn(n_samples)))
source_3 = np.sin(2 * np.pi * 10 * t + np.cumsum(0.1 * np.random.randn(n_samples)))

# Mix sources into 6 channels
# Channels 0-1: share source_1 (high PLV expected)
# Channels 2-3: share source_2 (high PLV expected)
# Channels 4-5: share source_3 (high PLV expected)
# Cross-cluster: independent sources → low PLV
noise_level = 0.5
data = np.zeros((6, n_samples))
data[0] = source_1 + noise_level * np.random.randn(n_samples)
data[1] = source_1 + noise_level * np.random.randn(n_samples)
data[2] = source_2 + noise_level * np.random.randn(n_samples)
data[3] = source_2 + noise_level * np.random.randn(n_samples)
data[4] = source_3 + noise_level * np.random.randn(n_samples)
data[5] = source_3 + noise_level * np.random.randn(n_samples)

# Compute connectivity matrix in alpha band
alpha_band = (8, 13)
conn_matrix = compute_connectivity_matrix(data, fs, alpha_band, metric="plv")

# Plot
fig, ax = plt.subplots(figsize=(9, 7))

channel_names = ['Ch1', 'Ch2', 'Ch3', 'Ch4', 'Ch5', 'Ch6']
im = ax.imshow(conn_matrix, cmap='viridis', vmin=0, vmax=1)

cbar = plt.colorbar(im, ax=ax, shrink=0.8)
cbar.set_label('PLV', fontsize=11)

ax.set_xticks(range(6))
ax.set_yticks(range(6))
ax.set_xticklabels(channel_names, fontsize=11)
ax.set_yticklabels(channel_names, fontsize=11)
ax.set_xlabel('Channel', fontsize=12)
ax.set_ylabel('Channel', fontsize=12)
ax.set_title('PLV Connectivity Matrix (Alpha Band: 8-13 Hz)', fontsize=14, fontweight='bold')

# Annotate values
for i in range(6):
    for j in range(6):
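        if i == j:
            # NaN cells show the white background, so use gray text
            ax.text(j, i, 'NaN', ha='center', va='center', fontsize=9, color='gray')
        else:
            val = conn_matrix[i, j]
            # Dark text on light (high) viridis cells, light text on dark cells
            ax.text(j, i, f'{val:.2f}', ha='center', va='center', 
                    fontsize=9, color='black' if val > 0.5 else 'white')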

# Highlight clusters with boxes
for start in [0, 2, 4]:
    rect = plt.Rectangle((start-0.5, start-0.5), 2, 2, fill=False,
                          edgecolor=PRIMARY_GREEN, linewidth=3)
    ax.add_patch(rect)

ax.text(6.5, 0.5, 'Cluster 1\n(Ch1-Ch2)', fontsize=10, color=PRIMARY_GREEN, va='center')
ax.text(6.5, 2.5, 'Cluster 2\n(Ch3-Ch4)', fontsize=10, color=PRIMARY_GREEN, va='center')
ax.text(6.5, 4.5, 'Cluster 3\n(Ch5-Ch6)', fontsize=10, color=PRIMARY_GREEN, va='center')

plt.tight_layout()
plt.show()

print("Channels sharing the same source show high PLV (bright yellow)!")
print("Cross-cluster connectivity is lower (darker colors).")

---

## 4. Visualizing Connectivity Matrices

The heatmap is the most common way to visualize connectivity matrices. Let's create a reusable plotting function that masks the diagonal, labels every channel, and can optionally annotate each cell with its value.

In [None]:
# ============================================================================
# FUNCTION 5: Plot Connectivity Matrix
# ============================================================================

def plot_connectivity_matrix(
    matrix: NDArray[np.floating],
    channel_names: Optional[List[str]] = None,
    ax: Optional[plt.Axes] = None,
    cmap: str = "viridis",
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
    mask_diagonal: bool = True,
    title: Optional[str] = None,
    show_values: bool = False
) -> plt.Axes:
    """
    Plot connectivity matrix as a heatmap.
    
    Parameters
    ----------
    matrix : NDArray[np.floating]
        Connectivity matrix, shape (n_channels, n_channels).
    channel_names : Optional[List[str]], optional
        Channel labels. Default is None (uses indices).
    ax : Optional[plt.Axes], optional
        Matplotlib axes. If None, creates new figure.
    cmap : str, optional
        Colormap name. Default is "viridis".
    vmin : Optional[float], optional
        Minimum value for colormap. Default is None (auto).
    vmax : Optional[float], optional
        Maximum value for colormap. Default is None (auto).
    mask_diagonal : bool, optional
        Whether to mask diagonal with gray. Default is True.
    title : Optional[str], optional
        Plot title. Default is None.
    show_values : bool, optional
        Whether to annotate cells with values. Default is False.
        
    Returns
    -------
    plt.Axes
        The matplotlib axes with the plot.
    """
    n_channels = matrix.shape[0]
    
    if ax is None:
        fig, ax = plt.subplots(figsize=(8, 7))
    
    if channel_names is None:
        channel_names = [str(i) for i in range(n_channels)]
    
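    # Mask the diagonal by setting it to NaN
    plot_matrix = matrix.copy()
    if mask_diagonal:
        np.fill_diagonal(plot_matrix, np.nan)
    
    # Copy the colormap so NaN cells (e.g., the masked diagonal) render light gray
    cmap_obj = plt.get_cmap(cmap).copy()
    cmap_obj.set_bad(color='lightgray')
    
    # Plot heatmap
    im = ax.imshow(plot_matrix, cmap=cmap_obj, vmin=vmin, vmax=vmax, aspect='equal')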
    
    # Colorbar
    cbar = plt.colorbar(im, ax=ax, shrink=0.8)
    cbar.set_label('Connectivity', fontsize=11)
    
    # Labels
    ax.set_xticks(range(n_channels))
    ax.set_yticks(range(n_channels))
    ax.set_xticklabels(channel_names, fontsize=10, rotation=45, ha='right')
    ax.set_yticklabels(channel_names, fontsize=10)
    
    if title:
        ax.set_title(title, fontsize=13, fontweight='bold')
    
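    # Show values if requested (skipped for large matrices to avoid clutter)
    if show_values and n_channels <= 10:
        for i in range(n_channels):
            for j in range(n_channels):
                val = plot_matrix[i, j]
                if np.isnan(val):
                    continue
                # Pick text color from the luminance of the cell's actual color
                r, g, b, _ = im.cmap(im.norm(val))
                luminance = 0.299 * r + 0.587 * g + 0.114 * b
                color = 'black' if luminance > 0.5 else 'white'
                ax.text(j, i, f'{val:.2f}', ha='center', va='center',
                        fontsize=8, color=color)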
    
    return ax


print("Function defined: plot_connectivity_matrix()")

In [None]:
# ============================================================================
# VISUALIZATION 3: Colormap Comparison
# ============================================================================

fig, axes = plt.subplots(2, 2, figsize=(14, 12))

cmaps = ['viridis', 'plasma', 'RdBu_r', 'cividis']
titles = [
    'Viridis (sequential)',
    'Plasma (sequential)', 
    'RdBu (diverging - for signed metrics)',
    'Cividis (colorblind-friendly)'
]

for ax, cmap, title in zip(axes.flat, cmaps, titles):
    plot_connectivity_matrix(conn_matrix, channel_names, ax=ax, 
                            cmap=cmap, vmin=0, vmax=1, title=title)

plt.tight_layout()
plt.show()

print("Tips for choosing colormaps:")
print("• Sequential (viridis, plasma): for metrics in [0, 1] like PLV")
print("• Diverging (RdBu): for signed metrics like correlation [-1, 1]")
print("• Cividis: accessible for colorblind viewers")

---

## 5. Circular Connectivity Plots

Heatmaps show all values but can hide **network structure**. Circular (chord) plots offer an alternative view:

- Channels arranged in a **circle**
- **Lines** connect pairs whose connectivity exceeds a chosen threshold
- Line **thickness/color** indicates strength
- Better for seeing **hubs** (highly connected nodes) and **patterns**
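Both views start from the same step: deciding which pairs count as "strong". A minimal sketch of that step (the helper names are hypothetical and the matrix values illustrative) turns a thresholded matrix into a sorted edge list, which is what a circular plot draws, and counts supra-threshold connections per node, a simple indicator of candidate hubs.

```python
import numpy as np


def edges_above_threshold(matrix: np.ndarray, names: list, threshold: float) -> list:
    """Return (name_i, name_j, value) for unique pairs with value >= threshold, strongest first."""
    i_idx, j_idx = np.triu_indices(matrix.shape[0], k=1)  # unique pairs, i < j
    values = matrix[i_idx, j_idx]
    keep = values >= threshold  # comparisons with NaN are False, so NaN pairs drop out
    edges = [(names[i], names[j], float(v))
             for i, j, v in zip(i_idx[keep], j_idx[keep], values[keep])]
    return sorted(edges, key=lambda e: e[2], reverse=True)


def node_degree(matrix: np.ndarray, threshold: float) -> np.ndarray:
    """Number of supra-threshold connections per node."""
    adjacency = matrix >= threshold      # NaN entries become False
    np.fill_diagonal(adjacency, False)   # never count self-connections
    return adjacency.sum(axis=1)


names = ["A", "B", "C", "D"]
m = np.array([[np.nan, 0.9, 0.6, 0.2],
              [0.9, np.nan, 0.7, 0.8],
              [0.6, 0.7, np.nan, 0.3],
              [0.2, 0.8, 0.3, np.nan]])
print(edges_above_threshold(m, names, 0.5))
print(node_degree(m, 0.5))  # node B has the most strong connections
```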

In [None]:
# ============================================================================
# FUNCTION 6: Circular Connectivity Plot
# ============================================================================

def plot_circular_connectivity(
    matrix: NDArray[np.floating],
    channel_names: List[str],
    threshold: Optional[float] = None,
    ax: Optional[plt.Axes] = None,
    linewidth_scale: float = 3.0,
    node_colors: Optional[List[str]] = None,
    title: Optional[str] = None
) -> plt.Axes:
    """
    Plot connectivity as a circular graph.
    
    Parameters
    ----------
    matrix : NDArray[np.floating]
        Connectivity matrix, shape (n_channels, n_channels).
    channel_names : List[str]
        Channel labels.
    threshold : Optional[float], optional
        Connections at or above this value are drawn in color; weaker ones
        are drawn as thin grey lines. Default is None (all in color).
    ax : Optional[plt.Axes], optional
        Matplotlib axes. If None, creates new figure.
    linewidth_scale : float, optional
        Scale factor for line width. Default is 3.0.
    node_colors : Optional[List[str]], optional
        Colors for each node. Default is None (uses primary blue).
    title : Optional[str], optional
        Plot title. Default is None.
        
    Returns
    -------
    plt.Axes
        The matplotlib axes with the plot.
    """
    n_channels = len(channel_names)
    
    if ax is None:
        fig, ax = plt.subplots(figsize=(10, 10), subplot_kw={'projection': 'polar'})
    
    if node_colors is None:
        node_colors = [PRIMARY_BLUE] * n_channels
    
    # Calculate node positions (evenly spaced around circle)
    angles = np.linspace(0, 2 * np.pi, n_channels, endpoint=False)
    
    # Plot nodes, with labels just outside the circle
    for angle, name, color in zip(angles, channel_names, node_colors):
        ax.scatter(angle, 1, s=300, c=color, zorder=5, edgecolors='white', linewidths=2)
        ax.text(angle, 1.15, name, ha='center', va='center', fontsize=11, fontweight='bold')
    
    # Plot connections: strong ones in color on top, weak ones in light grey underneath
    for i in range(n_channels):
        for j in range(i + 1, n_channels):
            value = matrix[i, j]
            if np.isnan(value):
                continue
            
            # Draw arc between nodes
            angle_i, angle_j = angles[i], angles[j]
            
            # Curved arc: radius dips from 1 at the nodes to 0.5 at the midpoint
            n_points = 50
            t_vals = np.linspace(0, 1, n_points)
            r_vals = 1 - 0.5 * np.sin(np.pi * t_vals)  # Curve inward
            angle_vals = angle_i + t_vals * (angle_j - angle_i)
            
            # Adjust for shortest path
            if abs(angle_j - angle_i) > np.pi:
                if angle_j > angle_i:
                    angle_vals = angle_i + t_vals * (angle_j - 2*np.pi - angle_i)
                else:
                    angle_vals = angle_i + t_vals * (angle_j + 2*np.pi - angle_i)
            
            # Determine if connection is strong (above threshold)
            is_strong = threshold is None or value >= threshold
            
            if is_strong:
                # Strong connections: colored with variable width/alpha
                lw = value * linewidth_scale
                alpha = 0.3 + 0.7 * value
                color = SECONDARY_PURPLE
                zorder = 2
            else:
                # Weak connections: light grey, thin, subtle
                lw = 0.8
                alpha = 0.3
                color = '#CCCCCC'
                zorder = 1
            
            ax.plot(angle_vals, r_vals, color=color, 
                   linewidth=lw, alpha=alpha, zorder=zorder)
    
    # Clean up polar plot
    ax.set_ylim(0, 1.3)
    ax.set_yticks([])
    ax.set_xticks([])
    ax.spines['polar'].set_visible(False)
    
    if title:
        ax.set_title(title, fontsize=13, fontweight='bold', pad=20)
    
    return ax


print("Function defined: plot_circular_connectivity()")

In [None]:
# ============================================================================
# VISUALIZATION 4: Circular Plot Example
# ============================================================================

fig, ax = plt.subplots(figsize=(10, 10), subplot_kw={'projection': 'polar'})

# Color nodes by simulated cluster (colors could also encode regions or hemispheres)
node_colors = [PRIMARY_BLUE, PRIMARY_BLUE, PRIMARY_RED, PRIMARY_RED, PRIMARY_GREEN, PRIMARY_GREEN]

plot_circular_connectivity(
    conn_matrix, 
    channel_names,
    threshold=0.5,  # Only show strong connections
    ax=ax,
    node_colors=node_colors,
    title='Circular Connectivity Plot (threshold > 0.5)'
)

plt.tight_layout()
plt.show()

print("Strong connections (PLV ≥ 0.5) appear as colored arcs; weaker ones as thin grey lines.")
print("Node colors could represent brain regions or hemispheres.")

In [None]:
# ============================================================================
# VISUALIZATION 5: Heatmap vs Circular - Side by Side
# ============================================================================

fig = plt.figure(figsize=(16, 7))

# Left: Heatmap
ax1 = fig.add_subplot(121)
plot_connectivity_matrix(conn_matrix, channel_names, ax=ax1, 
                        vmin=0, vmax=1, title='Heatmap View', show_values=True)

# Right: Circular
ax2 = fig.add_subplot(122, projection='polar')
node_colors = [PRIMARY_BLUE, PRIMARY_BLUE, PRIMARY_RED, PRIMARY_RED, PRIMARY_GREEN, PRIMARY_GREEN]
plot_circular_connectivity(conn_matrix, channel_names, threshold=0.5,
                          ax=ax2, node_colors=node_colors, title='Circular View')

plt.tight_layout()
plt.show()

print("Choose your visualization based on your message:")
print("• Heatmap: Show exact values, compare all pairs")
print("• Circular: Show network structure, identify clusters")

---

## 6. Matrix Validation and Sanity Checks

Before analyzing a connectivity matrix, always validate it! Common issues include:

- **Broken symmetry** (for symmetric metrics)
- **Values out of range** (e.g., PLV > 1)
- **Unexpected NaN values** (besides diagonal)
- **All zeros or all ones** (computation problem)
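When a check fails, it helps to know where. The sketch below (hypothetical helpers, illustrative values) locates asymmetric pairs and off-diagonal NaN values in a deliberately corrupted 3 × 3 matrix.

```python
import numpy as np


def find_asymmetric_pairs(matrix: np.ndarray, atol: float = 1e-10) -> list:
    """Return (i, j) index pairs, i < j, where matrix[i, j] and matrix[j, i] differ."""
    diff = np.abs(matrix - matrix.T)
    i_idx, j_idx = np.triu_indices(matrix.shape[0], k=1)
    bad = diff[i_idx, j_idx] > atol  # NaN differences compare False and are skipped
    return list(zip(i_idx[bad].tolist(), j_idx[bad].tolist()))


def find_unexpected_nan(matrix: np.ndarray) -> list:
    """Return (i, j) positions of NaN values outside the diagonal."""
    off_diagonal_nan = np.isnan(matrix) & ~np.eye(matrix.shape[0], dtype=bool)
    return [tuple(idx) for idx in np.argwhere(off_diagonal_nan).tolist()]


m = np.array([[np.nan, 0.5, 0.2],
              [0.5, np.nan, 0.4],
              [0.2, 0.4, np.nan]])
m[0, 2] = 0.9     # break symmetry on purpose
m[1, 2] = np.nan  # introduce an unexpected NaN
print(find_asymmetric_pairs(m))
print(find_unexpected_nan(m))
```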

In [None]:
# ============================================================================
# FUNCTIONS 7-8: Validation and Statistics
# ============================================================================

def validate_connectivity_matrix(
    matrix: NDArray[np.floating],
    metric: str = "plv",
    tolerance: float = 1e-10
) -> Dict[str, Any]:
    """
    Validate connectivity matrix properties.
    
    Parameters
    ----------
    matrix : NDArray[np.floating]
        Connectivity matrix to validate.
    metric : str, optional
        Expected metric type. Default is "plv".
    tolerance : float, optional
        Tolerance for symmetry check. Default is 1e-10.
        
    Returns
    -------
    Dict[str, Any]
        Validation results with keys:
        - is_square: bool
        - is_symmetric: bool
        - in_range: bool
        - diagonal_is_nan: bool
        - has_unexpected_nan: bool
        - issues: List[str]
    """
    results = {
        "is_square": False,
        "is_symmetric": False,
        "in_range": False,
        "diagonal_is_nan": False,
        "has_unexpected_nan": False,
        "issues": []
    }
    
    # Check that the matrix is 2D and square
    if matrix.ndim != 2 or matrix.shape[0] != matrix.shape[1]:
        results["issues"].append(f"Matrix is not square: {matrix.shape}")
        return results
    results["is_square"] = True
    
    n = matrix.shape[0]
    
    # Check symmetry (ignoring NaN diagonal)
    matrix_no_diag = matrix.copy()
    np.fill_diagonal(matrix_no_diag, 0)
    matrix_t_no_diag = matrix_no_diag.T
    
    if np.allclose(matrix_no_diag, matrix_t_no_diag, atol=tolerance, equal_nan=True):
        results["is_symmetric"] = True
    else:
        results["issues"].append("Matrix is not symmetric")
    
    # Check diagonal is NaN
    diagonal = np.diag(matrix)
    if np.all(np.isnan(diagonal)):
        results["diagonal_is_nan"] = True
    else:
        results["issues"].append("Diagonal contains non-NaN values")
    
    # Check value range
    off_diag = matrix[~np.eye(n, dtype=bool)]
    off_diag_valid = off_diag[~np.isnan(off_diag)]
    
    if metric == "plv":
        expected_range = (0, 1)
    elif metric == "correlation":
        expected_range = (-1, 1)
    else:
        expected_range = (-np.inf, np.inf)
    
    if len(off_diag_valid) > 0:
        if np.min(off_diag_valid) >= expected_range[0] and np.max(off_diag_valid) <= expected_range[1]:
            results["in_range"] = True
        else:
            results["issues"].append(f"Values out of range {expected_range}: [{np.min(off_diag_valid):.3f}, {np.max(off_diag_valid):.3f}]")
    
    # Check for unexpected NaN (any NaN outside the diagonal)
    n_unexpected_nan = int(np.sum(np.isnan(off_diag)))
    if n_unexpected_nan > 0:
        results["has_unexpected_nan"] = True
        results["issues"].append(f"Found {n_unexpected_nan} unexpected NaN values")
    
    return results


def get_matrix_statistics(
    matrix: NDArray[np.floating],
    exclude_diagonal: bool = True
) -> Dict[str, Any]:
    """
    Compute summary statistics of connectivity matrix.
    
    Parameters
    ----------
    matrix : NDArray[np.floating]
        Connectivity matrix.
    exclude_diagonal : bool, optional
        Whether to exclude diagonal from statistics. Default is True.
        
    Returns
    -------
    Dict[str, Any]
        Statistics: mean, std, min, max, median (floats) and n_values (int).
    """
    if exclude_diagonal:
        n = matrix.shape[0]
        values = matrix[~np.eye(n, dtype=bool)]
    else:
        values = matrix.flatten()
    
    # Remove NaN
    values = values[~np.isnan(values)]
    
    return {
        "mean": float(np.mean(values)),
        "std": float(np.std(values)),
        "min": float(np.min(values)),
        "max": float(np.max(values)),
        "median": float(np.median(values)),
        "n_values": len(values)
    }


print("Validation functions defined:")
print("• validate_connectivity_matrix(matrix, metric) → validation report")
print("• get_matrix_statistics(matrix) → summary stats")

In [None]:
# ============================================================================
# VISUALIZATION 6: Validation Demo
# ============================================================================

# Validate our computed matrix
print("=" * 60)
print("Validating our connectivity matrix...")
print("=" * 60)

validation = validate_connectivity_matrix(conn_matrix, metric="plv")
stats = get_matrix_statistics(conn_matrix)

print("\nValidation Results:")
print(f"  ✓ Is square: {validation['is_square']}")
print(f"  ✓ Is symmetric: {validation['is_symmetric']}")
print(f"  ✓ Values in range [0,1]: {validation['in_range']}")
print(f"  ✓ Diagonal is NaN: {validation['diagonal_is_nan']}")

if validation['issues']:
    print(f"\n  ⚠ Issues found: {validation['issues']}")
else:
    print(f"\n  ✓ No issues found!")

print("\nMatrix Statistics:")
print(f"  Mean connectivity: {stats['mean']:.3f}")
print(f"  Std deviation: {stats['std']:.3f}")
print(f"  Range: [{stats['min']:.3f}, {stats['max']:.3f}]")
print(f"  Median: {stats['median']:.3f}")
print(f"  Number of values: {stats['n_values']}")

---

## 7. Extracting the Upper Triangle

For **symmetric matrices** (like PLV), half the values are redundant:
- $M[i,j] = M[j,i]$ for all pairs
- The **upper triangle** (where $i < j$) contains all unique information
- The **diagonal** is typically NaN (self-connectivity is meaningless)

**Why extract the upper triangle?**

1. **Avoid double-counting** in statistics (mean, correlation with behavior, etc.)
2. **Save memory** when storing many matrices
3. **Simplify comparisons** between conditions
4. **Standard input format** for many statistical tests and machine-learning pipelines, which expect one vector of values per matrix
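For a symmetric matrix, using every off-diagonal entry does not change the mean, but it doubles the number of "observations", which can make standard errors and test statistics look more precise than they are. The sketch below, on random illustrative matrices, shows this and also stacks the upper triangles of several matrices into one (n_matrices, n_pairs) array, a layout many statistical tools expect.

```python
import numpy as np

rng = np.random.default_rng(7)
n_channels, n_matrices = 5, 3

# Build three random symmetric matrices with NaN diagonals (illustrative data)
matrices = []
for _ in range(n_matrices):
    a = rng.uniform(0, 1, (n_channels, n_channels))
    m = (a + a.T) / 2
    np.fill_diagonal(m, np.nan)
    matrices.append(m)

iu = np.triu_indices(n_channels, k=1)           # unique pairs (i < j)
features = np.stack([m[iu] for m in matrices])  # shape (n_matrices, n_pairs)
print(features.shape)                           # (3, 10)

# Full off-diagonal vs upper triangle: same mean, twice as many values
full = matrices[0][~np.eye(n_channels, dtype=bool)]
print(full.size, features[0].size)                  # 20 10
print(np.isclose(full.mean(), features[0].mean()))  # True
```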

NumPy provides convenient functions:
- `np.triu_indices(n, k=1)`: indices of the upper triangle (k=1 excludes the diagonal)
- `np.triu(matrix, k=1)`: upper triangle kept, zeros elsewhere (same shape as the input)
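The two functions behave differently, which matters when the diagonal is NaN. In this small example (illustrative values), indexing with `np.triu_indices` returns only the three unique values, whereas `np.triu` keeps the 3 × 3 shape and fills the diagonal and lower triangle with zeros, which would drag down any statistic computed over the whole array.

```python
import numpy as np

m = np.array([[np.nan, 0.6, 0.2],
              [0.6, np.nan, 0.4],
              [0.2, 0.4, np.nan]])

# Row and column indices of the strict upper triangle
iu = np.triu_indices(3, k=1)
print(iu)
print(m[iu])  # the three unique values

# np.triu keeps the full shape; the NaN diagonal and lower triangle become 0
tri = np.triu(m, k=1)
print(tri)
print(tri.mean())    # diluted by the six zeros
print(m[iu].mean())  # mean of the unique values only
```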

In [None]:
# ============================================================================
# FUNCTIONS 9-10: Upper Triangle Extraction and Reconstruction
# ============================================================================

def get_upper_triangle_values(
    matrix: NDArray[np.floating],
    k: int = 1
) -> NDArray[np.floating]:
    """
    Extract upper triangle values from a matrix.
    
    Parameters
    ----------
    matrix : NDArray[np.floating]
        Square matrix, shape (n, n).
    k : int, optional
        Diagonal offset. k=1 excludes the main diagonal (default).
        k=0 includes the diagonal.
        
    Returns
    -------
    NDArray[np.floating]
        1D array of upper triangle values.
        
    Notes
    -----
    For a symmetric matrix, this extracts all unique values.
    Number of values = n(n-1)/2 when k=1.
    """
    n = matrix.shape[0]
    indices = np.triu_indices(n, k=k)
    return matrix[indices]


def upper_triangle_to_matrix(
    values: NDArray[np.floating],
    n_channels: int,
    fill_diagonal: float = np.nan
) -> NDArray[np.floating]:
    """
    Reconstruct symmetric matrix from upper triangle values.
    
    Parameters
    ----------
    values : NDArray[np.floating]
        1D array of upper triangle values.
    n_channels : int
        Number of channels (matrix will be n_channels × n_channels).
    fill_diagonal : float, optional
        Value to fill diagonal. Default is NaN.
        
    Returns
    -------
    NDArray[np.floating]
        Symmetric matrix, shape (n_channels, n_channels).
        
    Raises
    ------
    ValueError
        If number of values doesn't match expected n(n-1)/2.
    """
    expected_n_values = n_channels * (n_channels - 1) // 2
    if len(values) != expected_n_values:
        raise ValueError(
            f"Expected {expected_n_values} values for {n_channels} channels, "
            f"got {len(values)}"
        )
    
    # Create empty matrix
    matrix = np.zeros((n_channels, n_channels))
    
    # Fill upper triangle
    indices = np.triu_indices(n_channels, k=1)
    matrix[indices] = values
    
    # Make symmetric
    matrix = matrix + matrix.T
    
    # Fill diagonal
    np.fill_diagonal(matrix, fill_diagonal)
    
    return matrix


print("Upper triangle functions defined:")
print("• get_upper_triangle_values(matrix, k) → 1D array of unique values")
print("• upper_triangle_to_matrix(values, n_channels) → reconstructed matrix")

In [None]:
# ============================================================================
# VISUALIZATION 7: Upper Triangle Extraction Demo
# ============================================================================

# Use our computed connectivity matrix
print("=" * 60)
print("Upper Triangle Extraction")
print("=" * 60)

# Extract upper triangle
upper_values = get_upper_triangle_values(conn_matrix)
print(f"\nOriginal matrix shape: {conn_matrix.shape}")
print(f"Number of unique pairs: {len(upper_values)}")
print(f"Expected: {get_n_pairs(len(channel_names))} = n(n-1)/2 = {len(channel_names)}×{len(channel_names) - 1}/2")

print(f"\nUpper triangle values:")
for idx, (i, j) in enumerate(get_pair_indices(len(channel_names))):
    print(f"  {channel_names[i]:>3} ↔ {channel_names[j]:<3}: {upper_values[idx]:.3f}")

# Reconstruct matrix
reconstructed = upper_triangle_to_matrix(upper_values, len(channel_names))

# Verify the roundtrip on the FULL matrix: this also checks that the lower
# triangle mirrors the upper one (equal_nan=True treats the NaN diagonals as equal)
is_identical = np.allclose(conn_matrix, reconstructed, equal_nan=True)

print(f"\n✓ Roundtrip successful: {is_identical}")

# Visualize
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Original matrix
ax1 = axes[0]
im1 = ax1.imshow(conn_matrix, cmap='viridis', vmin=0, vmax=1)
ax1.set_xticks(range(len(channel_names)))
ax1.set_yticks(range(len(channel_names)))
ax1.set_xticklabels(channel_names)
ax1.set_yticklabels(channel_names)
ax1.set_title("Original Matrix", fontweight='bold')
plt.colorbar(im1, ax=ax1, shrink=0.8)

# Upper triangle highlighted
ax2 = axes[1]
# Mask the lower triangle AND the diagonal (k=0), leaving only i < j visible
mask = np.tril(np.ones_like(conn_matrix, dtype=bool), k=0)
masked_matrix = np.ma.masked_where(mask, conn_matrix)
im2 = ax2.imshow(masked_matrix, cmap='viridis', vmin=0, vmax=1)
ax2.set_xticks(range(len(channel_names)))
ax2.set_yticks(range(len(channel_names)))
ax2.set_xticklabels(channel_names)
ax2.set_yticklabels(channel_names)
ax2.set_title("Upper Triangle Only", fontweight='bold')
# Add grey for masked area
ax2.imshow(np.where(mask, 0.8, np.nan), cmap='gray', vmin=0, vmax=1)
plt.colorbar(im2, ax=ax2, shrink=0.8)

# Reconstructed matrix
ax3 = axes[2]
im3 = ax3.imshow(reconstructed, cmap='viridis', vmin=0, vmax=1)
ax3.set_xticks(range(len(channel_names)))
ax3.set_yticks(range(len(channel_names)))
ax3.set_xticklabels(channel_names)
ax3.set_yticklabels(channel_names)
ax3.set_title("Reconstructed Matrix", fontweight='bold')
plt.colorbar(im3, ax=ax3, shrink=0.8)

plt.tight_layout()
plt.show()

print("\n→ Upper triangle contains all unique information for symmetric matrices")

## Section 8: Channel Grouping and Region Averaging

Individual electrode pairs can be **noisy**. Often, we're more interested in connectivity between **brain regions** than individual electrodes.

**Example applications:**
- "Frontal-to-parietal" connectivity in working memory
- "Left-to-right hemisphere" coupling during bimanual coordination
- "Motor-to-motor" synchronization in hyperscanning

**Approach:**
1. Define channel groups (frontal, central, parietal, etc.)
2. Average connectivity within each region pair
3. Result: a smaller, more robust region × region matrix

This is especially valuable when:
- Individual electrodes have high noise
- You want to reduce multiple comparisons
- Your hypothesis is at the region level
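The averaging step can also be written without explicit loops. The sketch below is an alternative formulation (not the notebook's `compute_region_connectivity`), assuming a symmetric matrix with a NaN diagonal: `np.ix_` selects the full block of rows from one region and columns from another, and `np.nanmean` skips the NaN self-connections inside within-region blocks. The channel groups and values are hypothetical.

```python
import numpy as np

def region_mean(matrix: np.ndarray, idx_a: list, idx_b: list) -> float:
    """Mean connectivity between two channel groups, ignoring NaN entries."""
    block = matrix[np.ix_(idx_a, idx_b)]  # every (a, b) pair as a sub-block
    return float(np.nanmean(block))

# Hypothetical 4-channel matrix with a NaN diagonal
m = np.array([
    [np.nan, 0.8,    0.2,    0.3],
    [0.8,    np.nan, 0.1,    0.4],
    [0.2,    0.1,    np.nan, 0.9],
    [0.3,    0.4,    0.9,    np.nan],
])
groups = {"front": [0, 1], "back": [2, 3]}
names = list(groups)

region = np.array([[region_mean(m, groups[a], groups[b]) for b in names] for a in names])
print(names)
print(np.round(region, 2))
```

Here front-front averages to 0.8, back-back to 0.9, and front-back to 0.25 (the mean of 0.2, 0.3, 0.1 and 0.4). A region containing a single channel has an all-NaN within-block, so `np.nanmean` returns NaN (with a RuntimeWarning), matching the NaN the loop version produces.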

In [None]:
# ============================================================================
# FUNCTION 9: Channel Grouping and Region Connectivity
# ============================================================================

def define_channel_groups(
    channel_names: List[str],
    group_definitions: Dict[str, List[str]]
) -> Dict[str, List[int]]:
    """
    Map channel group names to their indices.
    
    Parameters
    ----------
    channel_names : List[str]
        List of all channel names.
    group_definitions : Dict[str, List[str]]
        Mapping of group names to channel names.
        e.g., {"frontal": ["F3", "Fz", "F4"], "parietal": ["P3", "Pz", "P4"]}
        
    Returns
    -------
    Dict[str, List[int]]
        Mapping of group names to channel indices.
        
    Raises
    ------
    ValueError
        If a channel name in group_definitions is not found.
    """
    result = {}
    for group_name, channels in group_definitions.items():
        indices = []
        for ch in channels:
            if ch not in channel_names:
                raise ValueError(f"Channel '{ch}' not found in channel_names")
            indices.append(channel_names.index(ch))
        result[group_name] = indices
    return result


def compute_region_connectivity(
    matrix: NDArray[np.floating],
    channel_groups: Dict[str, List[int]]
) -> Tuple[NDArray[np.floating], List[str]]:
    """
    Compute average connectivity between brain regions.
    
    Parameters
    ----------
    matrix : NDArray[np.floating]
        Full connectivity matrix, shape (n_channels, n_channels).
    channel_groups : Dict[str, List[int]]
        Mapping of group names to channel indices.
        
    Returns
    -------
    Tuple[NDArray[np.floating], List[str]]
        - Region connectivity matrix, shape (n_regions, n_regions)
        - List of region names
        
    Notes
    -----
    - Diagonal = mean connectivity WITHIN a region
    - Off-diagonal = mean connectivity BETWEEN regions
    """
    region_names = list(channel_groups.keys())
    n_regions = len(region_names)
    
    region_matrix = np.zeros((n_regions, n_regions))
    
    for i, region_i in enumerate(region_names):
        for j, region_j in enumerate(region_names):
            indices_i = channel_groups[region_i]
            indices_j = channel_groups[region_j]
            
            # Get all pairwise values between these regions
            values = []
            for idx_i in indices_i:
                for idx_j in indices_j:
                    if i == j and idx_i == idx_j:
                        # Skip self-connections within same region
                        continue
                    val = matrix[idx_i, idx_j]
                    if not np.isnan(val):
                        values.append(val)
            
            if values:
                region_matrix[i, j] = np.mean(values)
            else:
                region_matrix[i, j] = np.nan
    
    return region_matrix, region_names


print("Region connectivity functions defined:")
print("• define_channel_groups(channel_names, group_definitions) → indices")
print("• compute_region_connectivity(matrix, groups) → (region_matrix, names)")

In [None]:
# ============================================================================
# VISUALIZATION 8: Region Averaging Demo
# ============================================================================

# For this demo, let's define 3 regions from our 6 channels
# Our channel_names are ['Ch1', 'Ch2', 'Ch3', 'Ch4', 'Ch5', 'Ch6']
# These correspond to our 3 clusters:
# - Cluster 1 (Ch1, Ch2): share source_1
# - Cluster 2 (Ch3, Ch4): share source_2
# - Cluster 3 (Ch5, Ch6): share source_3

group_definitions = {
    "Cluster1": ["Ch1", "Ch2"],
    "Cluster2": ["Ch3", "Ch4"],
    "Cluster3": ["Ch5", "Ch6"]
}

channel_groups = define_channel_groups(channel_names, group_definitions)

print("=" * 60)
print("Region Averaging")
print("=" * 60)
print("\nChannel groups:")
for region, indices in channel_groups.items():
    channels = [channel_names[i] for i in indices]
    print(f"  {region}: {channels} (indices: {indices})")

# Compute region connectivity
region_matrix, region_names = compute_region_connectivity(conn_matrix, channel_groups)

print(f"\nRegion connectivity matrix:")
print(f"  Shape: {region_matrix.shape} (reduced from {conn_matrix.shape})")

# Visualize side by side
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Original matrix with region boxes
ax1 = axes[0]
im1 = ax1.imshow(conn_matrix, cmap='viridis', vmin=0, vmax=1)
ax1.set_xticks(range(len(channel_names)))
ax1.set_yticks(range(len(channel_names)))
ax1.set_xticklabels(channel_names)
ax1.set_yticklabels(channel_names)
ax1.set_title("Full Channel Matrix (6×6)", fontweight='bold', fontsize=12)
plt.colorbar(im1, ax=ax1, shrink=0.8, label='PLV')

# Draw region boxes
region_colors = [PRIMARY_BLUE, PRIMARY_GREEN, PRIMARY_RED]
for idx, (region, indices) in enumerate(channel_groups.items()):
    start = min(indices)
    end = max(indices)
    rect = plt.Rectangle(
        (start - 0.5, start - 0.5), 
        end - start + 1, end - start + 1,
        fill=False, edgecolor=region_colors[idx], linewidth=3, linestyle='--'
    )
    ax1.add_patch(rect)

# Region matrix
ax2 = axes[1]
im2 = ax2.imshow(region_matrix, cmap='viridis', vmin=0, vmax=1)
ax2.set_xticks(range(len(region_names)))
ax2.set_yticks(range(len(region_names)))
ax2.set_xticklabels(region_names, fontsize=11)
ax2.set_yticklabels(region_names, fontsize=11)
ax2.set_title("Region Matrix (3×3)", fontweight='bold', fontsize=12)
plt.colorbar(im2, ax=ax2, shrink=0.8, label='Mean PLV')

# Add values to cells
for i in range(len(region_names)):
    for j in range(len(region_names)):
        val = region_matrix[i, j]
        if not np.isnan(val):
            text_color = 'black' if val > 0.5 else 'white'  # viridis is bright (yellow) at high values
            ax2.text(j, i, f'{val:.2f}', ha='center', va='center', 
                    fontsize=12, fontweight='bold', color=text_color)

plt.tight_layout()
plt.show()

print("\n→ Region averaging reduces 15 unique channel pairs to 6 unique region entries (3 within + 3 between)")

---

## Section 9: Hyperscanning Matrices - The Big Picture

In **hyperscanning**, we record from **two participants simultaneously**. This creates a special matrix structure.

With $n$ channels per participant, the **full hyperscanning matrix** is $2n \times 2n$:

```
         │  P1 channels  │  P2 channels  │
─────────┼───────────────┼───────────────┤
P1 ch.   │   Within-P1   │    Between    │
─────────┼───────────────┼───────────────┤
P2 ch.   │   Between.T   │   Within-P2   │
─────────┴───────────────┴───────────────┘
```

**Four quadrants:**
1. **Within-P1** (top-left): Connectivity within Participant 1 → ⚠️ affected by volume conduction
2. **Within-P2** (bottom-right): Connectivity within Participant 2 → ⚠️ affected by volume conduction
3. **Between** (top-right): P1 channels × P2 channels → ✅ **No volume conduction** (the two heads are physically separate)
4. **Between.T** (bottom-left): Transpose of the between block (redundant for symmetric metrics like PLV)

The **between-participant block** is where inter-brain synchrony lives, and it is free from volume conduction! It is not free from every confound, though: shared stimuli or common environmental noise can still produce between-participant coupling.
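The quadrant layout maps directly onto array slices. The helper below is a hypothetical sketch (not one of the notebook's functions) that assumes a symmetric $2n \times 2n$ matrix with Participant 1's channels first.

```python
import numpy as np

def split_hyperscanning_matrix(full: np.ndarray, n: int) -> dict:
    """Split a (2n, 2n) hyperscanning matrix into its four quadrants."""
    if full.shape != (2 * n, 2 * n):
        raise ValueError(f"Expected shape {(2 * n, 2 * n)}, got {full.shape}")
    return {
        "within_p1": full[:n, :n],  # top-left
        "between": full[:n, n:],    # top-right: rows = P1, columns = P2
        "between_T": full[n:, :n],  # bottom-left
        "within_p2": full[n:, n:],  # bottom-right
    }

# Hypothetical symmetric 6 x 6 matrix (n = 3 channels per participant)
rng = np.random.default_rng(1)
a = rng.uniform(0, 1, size=(6, 6))
full = (a + a.T) / 2
np.fill_diagonal(full, np.nan)

q = split_hyperscanning_matrix(full, n=3)
print(q["between"].shape)                           # (3, 3)
print(np.allclose(q["between_T"], q["between"].T))  # True
```

Basic slices like these are NumPy views, so writing into `q["between"]` would also modify `full`; call `.copy()` (as the notebook's `extract_between_participant_matrix` helper in Section 10 does) when an independent array is needed.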

In [None]:
# ============================================================================
# VISUALIZATION 9: Hyperscanning Matrix Structure (Schematic)
# ============================================================================

fig, ax = plt.subplots(figsize=(10, 10))

# Create schematic matrix
n_per_participant = 4  # Simplified for visualization
n_total = 2 * n_per_participant

# Create example values for each quadrant
schematic = np.zeros((n_total, n_total))

# Within-P1 (top-left) - higher values (volume conduction)
schematic[:n_per_participant, :n_per_participant] = 0.7
# Within-P2 (bottom-right) - higher values (volume conduction)
schematic[n_per_participant:, n_per_participant:] = 0.7
# Between (off-diagonal blocks) - lower values (true coupling)
schematic[:n_per_participant, n_per_participant:] = 0.4
schematic[n_per_participant:, :n_per_participant] = 0.4

# Set diagonal to NaN
np.fill_diagonal(schematic, np.nan)

# Plot
im = ax.imshow(schematic, cmap='viridis', vmin=0, vmax=1)

# Add quadrant labels
ax.text(n_per_participant/2 - 0.5, n_per_participant/2 - 0.5, 
        'Within-P1\nVol. Cond.', 
        ha='center', va='center', fontsize=14, fontweight='bold', color='black')
ax.text(n_total - n_per_participant/2 - 0.5, n_total - n_per_participant/2 - 0.5, 
        'Within-P2\nVol. Cond.', 
        ha='center', va='center', fontsize=14, fontweight='bold', color='black')
ax.text(n_total - n_per_participant/2 - 0.5, n_per_participant/2 - 0.5, 
        'Between\nNo Vol. Cond.', 
        ha='center', va='center', fontsize=14, fontweight='bold', color='black')
ax.text(n_per_participant/2 - 0.5, n_total - n_per_participant/2 - 0.5, 
        'Between.T\nNo Vol. Cond.', 
        ha='center', va='center', fontsize=14, fontweight='bold', color='black')

# Add dividing lines
ax.axhline(n_per_participant - 0.5, color='white', linewidth=3)
ax.axvline(n_per_participant - 0.5, color='white', linewidth=3)

# Labels
p1_labels = ['P1-Ch1', 'P1-Ch2', 'P1-Ch3', 'P1-Ch4']
p2_labels = ['P2-Ch1', 'P2-Ch2', 'P2-Ch3', 'P2-Ch4']
all_labels = p1_labels + p2_labels

ax.set_xticks(range(n_total))
ax.set_yticks(range(n_total))
ax.set_xticklabels(all_labels, rotation=45, ha='right')
ax.set_yticklabels(all_labels)

# Color-code axis labels by participant
for i, label in enumerate(ax.get_xticklabels()):
    label.set_color(SUBJECT_1 if i < n_per_participant else SUBJECT_2)
for i, label in enumerate(ax.get_yticklabels()):
    label.set_color(SUBJECT_1 if i < n_per_participant else SUBJECT_2)

ax.set_title("Hyperscanning Matrix Structure (2n × 2n)", fontsize=14, fontweight='bold', pad=15)
plt.colorbar(im, ax=ax, shrink=0.8, label='Connectivity')

plt.tight_layout()
plt.show()

print("The BETWEEN block is the key to hyperscanning analysis!")

## Section 10: Computing Hyperscanning Connectivity

Now let's implement the computation for hyperscanning data.

**Input:**
- `data_p1`: Participant 1's data, shape `(n_channels, n_samples)`
- `data_p2`: Participant 2's data, shape `(n_channels, n_samples)`

**Output options:**
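1. **Full matrix** (2n × 2n): All within and between connections
2. **Between matrix only** (n × n): Just P1↔P2 connections
3. **Separate matrices**: `within_p1`, `within_p2`, `between`

**Important note:** The between-participant matrix is **NOT symmetric** in general:
- $M[i,j]$ = connectivity between P1 channel $i$ and P2 channel $j$
- $M[j,i]$ = connectivity between P1 channel $j$ and P2 channel $i$, which is a *different* pair of electrodes, so $M[i,j] \neq M[j,i]$ in general
- PLV itself is non-directional ($\text{PLV}(x, y) = \text{PLV}(y, x)$), which is why the full $2n \times 2n$ matrix is still symmetric: its bottom-left block is simply the transpose of the between block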
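A small numerical check makes the distinction concrete. The sketch below uses a bare-bones Hilbert-phase PLV (band-pass filtering is skipped for brevity, and the `plv` helper is a hypothetical stand-in for the notebook's `compute_plv_pair`) on toy data where only P1-ch0 and P2-ch1 share a 10 Hz source.

```python
import numpy as np
from scipy.signal import hilbert

def plv(x: np.ndarray, y: np.ndarray) -> float:
    """Phase-locking value from Hilbert phases (no band-pass filtering here)."""
    dphi = np.angle(hilbert(x)) - np.angle(hilbert(y))
    return float(np.abs(np.mean(np.exp(1j * dphi))))

fs_demo = 256
t = np.arange(5 * fs_demo) / fs_demo
rng = np.random.default_rng(0)
shared = np.sin(2 * np.pi * 10 * t)

# Hypothetical toy data: only P1-ch0 and P2-ch1 carry the shared source
p1 = np.vstack([shared + 0.3 * rng.standard_normal(t.size), rng.standard_normal(t.size)])
p2 = np.vstack([rng.standard_normal(t.size), shared + 0.3 * rng.standard_normal(t.size)])

# Between block: rows = P1 channels, columns = P2 channels
B = np.array([[plv(p1[i], p2[j]) for j in range(2)] for i in range(2)])
print(f"B[0, 1] = {B[0, 1]:.2f}  (P1-ch0 with P2-ch1: shared source)")
print(f"B[1, 0] = {B[1, 0]:.2f}  (P1-ch1 with P2-ch0: noise only)")

# PLV is non-directional: each pair gives the same value in either order
print(np.isclose(plv(p1[0], p2[1]), plv(p2[1], p1[0])))  # True

# Assemble the full 2n x 2n matrix; the bottom-left block is B.T
W1 = np.array([[np.nan, plv(p1[0], p1[1])], [plv(p1[1], p1[0]), np.nan]])
W2 = np.array([[np.nan, plv(p2[0], p2[1])], [plv(p2[1], p2[0]), np.nan]])
full = np.block([[W1, B], [B.T, W2]])
print(np.allclose(B, B.T))                        # False
print(np.allclose(full, full.T, equal_nan=True))  # True
```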

In [None]:
# ============================================================================
# FUNCTION 10: Hyperscanning Connectivity
# ============================================================================

def compute_hyperscanning_connectivity(
    data_p1: NDArray[np.floating],
    data_p2: NDArray[np.floating],
    fs: float,
    band: Tuple[float, float],
    metric: str = "plv"
) -> Dict[str, NDArray[np.floating]]:
    """
    Compute connectivity matrices for hyperscanning data.
    
    Parameters
    ----------
    data_p1 : NDArray[np.floating]
        Participant 1 data, shape (n_channels, n_samples).
    data_p2 : NDArray[np.floating]
        Participant 2 data, shape (n_channels, n_samples).
    fs : float
        Sampling frequency in Hz.
    band : Tuple[float, float]
        Frequency band (low, high) in Hz.
    metric : str, optional
        Connectivity metric. Default is "plv".
        
    Returns
    -------
    Dict[str, NDArray[np.floating]]
        Dictionary with keys:
        - "within_p1": (n_ch, n_ch) connectivity within P1
        - "within_p2": (n_ch, n_ch) connectivity within P2
        - "between": (n_ch, n_ch) connectivity between P1 and P2 (rows = P1, columns = P2)
        - "full": (2*n_ch, 2*n_ch) complete hyperscanning matrix
    """
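    n_ch_p1, n_samples_p1 = data_p1.shape
    n_ch_p2, n_samples_p2 = data_p2.shape
    
    if n_ch_p1 != n_ch_p2:
        raise ValueError(
            f"Both participants must have same number of channels. "
            f"Got {n_ch_p1} and {n_ch_p2}."
        )
    if n_samples_p1 != n_samples_p2:
        raise ValueError(
            f"Both participants must have the same number of samples. "
            f"Got {n_samples_p1} and {n_samples_p2}."
        )
    if metric != "plv":
        # The between-participant block below is always computed with PLV,
        # so refuse other metrics instead of silently mixing measures
        raise NotImplementedError(
            f"Between-participant connectivity is only implemented for 'plv', got '{metric}'."
        )
    
    n_ch = n_ch_p1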
    
    # Compute within-participant connectivity
    within_p1 = compute_connectivity_matrix(data_p1, fs, band, metric)
    within_p2 = compute_connectivity_matrix(data_p2, fs, band, metric)
    
    # Compute between-participant connectivity
    # Filter all data first
    data_p1_filt = np.array([bandpass_filter(ch, band[0], band[1], fs) for ch in data_p1])
    data_p2_filt = np.array([bandpass_filter(ch, band[0], band[1], fs) for ch in data_p2])
    
    between = np.zeros((n_ch, n_ch))
    for i in range(n_ch):
        for j in range(n_ch):
            between[i, j] = compute_plv_pair(data_p1_filt[i], data_p2_filt[j])
    
    # Build full matrix
    n_total = 2 * n_ch
    full = np.zeros((n_total, n_total))
    
    # Fill quadrants
    full[:n_ch, :n_ch] = within_p1              # Top-left
    full[n_ch:, n_ch:] = within_p2              # Bottom-right
    full[:n_ch, n_ch:] = between                 # Top-right
    full[n_ch:, :n_ch] = between.T               # Bottom-left
    
    return {
        "within_p1": within_p1,
        "within_p2": within_p2,
        "between": between,
        "full": full
    }


def extract_between_participant_matrix(
    full_matrix: NDArray[np.floating],
    n_channels_per_participant: int
) -> NDArray[np.floating]:
    """
    Extract the between-participant block from a full hyperscanning matrix.
    
    Parameters
    ----------
    full_matrix : NDArray[np.floating]
        Full hyperscanning matrix, shape (2n, 2n).
    n_channels_per_participant : int
        Number of channels per participant.
        
    Returns
    -------
    NDArray[np.floating]
        Between-participant matrix, shape (n, n).
        Rows = P1 channels, Columns = P2 channels.
    """
    n = n_channels_per_participant
    return full_matrix[:n, n:].copy()


print("Hyperscanning functions defined:")
print("• compute_hyperscanning_connectivity(data_p1, data_p2, fs, band)")
print("• extract_between_participant_matrix(full_matrix, n_per_participant)")

In [None]:
# ============================================================================
# VISUALIZATION 10: Hyperscanning Connectivity Example
# ============================================================================

# Generate synthetic hyperscanning data
np.random.seed(42)
n_channels_hyper = 4
n_samples_hyper = int(5 * fs)  # 5 seconds
t_hyper = np.arange(n_samples_hyper) / fs

# Create shared and independent sources
# Shared source: both participants will have some coupling on specific channels
shared_phase = 2 * np.pi * 10 * t_hyper + np.cumsum(0.05 * np.random.randn(n_samples_hyper))
shared_source = np.sin(shared_phase)

# Independent sources for each participant
def create_independent_source(freq: float, n_samples: int, phase_noise: float = 0.1) -> NDArray:
    """Create a source with random phase dynamics."""
    t = np.arange(n_samples) / fs
    phase = 2 * np.pi * freq * t + np.cumsum(phase_noise * np.random.randn(n_samples))
    return np.sin(phase)

# Participant 1 data
data_p1_hyper = np.zeros((n_channels_hyper, n_samples_hyper))
data_p1_hyper[0] = shared_source + 0.5 * create_independent_source(10, n_samples_hyper)  # Some shared
data_p1_hyper[1] = create_independent_source(10, n_samples_hyper)  # Independent
data_p1_hyper[2] = create_independent_source(10, n_samples_hyper)  # Independent
data_p1_hyper[3] = shared_source + 0.5 * create_independent_source(10, n_samples_hyper)  # Some shared

# Participant 2 data
data_p2_hyper = np.zeros((n_channels_hyper, n_samples_hyper))
data_p2_hyper[0] = shared_source + 0.5 * create_independent_source(10, n_samples_hyper)  # Coupled with P1-Ch0 and P1-Ch3
data_p2_hyper[1] = create_independent_source(10, n_samples_hyper)  # Independent
data_p2_hyper[2] = create_independent_source(10, n_samples_hyper)  # Independent
data_p2_hyper[3] = shared_source + 0.5 * create_independent_source(10, n_samples_hyper)  # Coupled with P1-Ch0 and P1-Ch3

# Add noise
noise_level_hyper = 0.3
data_p1_hyper += noise_level_hyper * np.random.randn(*data_p1_hyper.shape)
data_p2_hyper += noise_level_hyper * np.random.randn(*data_p2_hyper.shape)

# Compute hyperscanning connectivity
hyper_results = compute_hyperscanning_connectivity(
    data_p1_hyper, data_p2_hyper, fs, alpha_band
)

print("=" * 60)
print("Hyperscanning Connectivity Computed")
print("=" * 60)
print(f"\nData shapes: P1={data_p1_hyper.shape}, P2={data_p2_hyper.shape}")
print(f"\nMatrix shapes:")
print(f"  Within-P1: {hyper_results['within_p1'].shape}")
print(f"  Within-P2: {hyper_results['within_p2'].shape}")
print(f"  Between:   {hyper_results['between'].shape}")
print(f"  Full:      {hyper_results['full'].shape}")

# Visualize the full matrix
fig, ax = plt.subplots(figsize=(10, 10))

im = ax.imshow(hyper_results['full'], cmap='viridis', vmin=0, vmax=1)

# Add quadrant dividers
n_ch = n_channels_hyper
ax.axhline(n_ch - 0.5, color='white', linewidth=2)
ax.axvline(n_ch - 0.5, color='white', linewidth=2)

# Labels
p1_labels = [f'P1-Ch{i}' for i in range(n_ch)]
p2_labels = [f'P2-Ch{i}' for i in range(n_ch)]
all_labels = p1_labels + p2_labels

ax.set_xticks(range(2 * n_ch))
ax.set_yticks(range(2 * n_ch))
ax.set_xticklabels(all_labels, rotation=45, ha='right')
ax.set_yticklabels(all_labels)

# Color axis labels
for i, label in enumerate(ax.get_xticklabels()):
    label.set_color(SUBJECT_1 if i < n_ch else SUBJECT_2)
    label.set_fontweight('bold')
for i, label in enumerate(ax.get_yticklabels()):
    label.set_color(SUBJECT_1 if i < n_ch else SUBJECT_2)
    label.set_fontweight('bold')

ax.set_title("Full Hyperscanning Matrix", fontsize=14, fontweight='bold')
plt.colorbar(im, ax=ax, shrink=0.8, label='PLV')

plt.tight_layout()
plt.show()

print("\n→ Notice the higher values in the four corners of the between block: P1-Ch0/Ch3 ↔ P2-Ch0/Ch3 all carry the same shared source")

## Section 11: Visualizing Hyperscanning Connectivity

For hyperscanning data, we need specialized visualizations that emphasize the **between-participant** connectivity.

**Approaches:**
1. **Full matrix with quadrant highlighting**: shows everything and marks the between block
2. **Between-matrix only**: focuses on inter-brain synchrony
3. **Circular plot with two groups**: P1 channels on one side, P2 on the other

The circular plot is particularly intuitive: inter-brain connections cross between the two participant groups.
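With a threshold, the circular plot draws in color exactly those between-matrix entries at or above the cutoff. Listing the same entries as text can be handy when reporting results; the helper `strong_between_edges` and the channel names below are hypothetical. Note that all $n \times n$ entries are scanned, including the diagonal, because in the between matrix the entries (1, 2) and (2, 1) are different electrode pairs.

```python
import numpy as np

def strong_between_edges(between, names_p1, names_p2, threshold):
    """Return (P1 channel, P2 channel, value) for entries >= threshold, strongest first."""
    rows, cols = np.nonzero(between >= threshold)
    edges = [(names_p1[i], names_p2[j], float(between[i, j])) for i, j in zip(rows, cols)]
    return sorted(edges, key=lambda e: e[2], reverse=True)

# Hypothetical 3 x 3 between matrix (rows = P1, columns = P2)
between = np.array([
    [0.85, 0.10, 0.20],
    [0.15, 0.30, 0.60],
    [0.05, 0.55, 0.25],
])
names = ["Fz", "Cz", "Pz"]

for ch1, ch2, value in strong_between_edges(between, names, names, threshold=0.5):
    print(f"P1-{ch1} ↔ P2-{ch2}: {value:.2f}")
# P1-Fz ↔ P2-Fz: 0.85
# P1-Cz ↔ P2-Pz: 0.60
# P1-Pz ↔ P2-Cz: 0.55
```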

In [None]:
# ============================================================================
# FUNCTION 11: Hyperscanning Visualization Functions
# ============================================================================

def plot_hyperscanning_matrix(
    full_matrix: NDArray[np.floating],
    channel_names_p1: List[str],
    channel_names_p2: List[str],
    ax: Optional[plt.Axes] = None,
    highlight_between: bool = True,
    cmap: str = 'viridis',
    title: Optional[str] = None
) -> plt.Axes:
    """
    Plot full hyperscanning matrix with quadrant annotations.
    
    Parameters
    ----------
    full_matrix : NDArray[np.floating]
        Full hyperscanning matrix, shape (2n, 2n).
    channel_names_p1 : List[str]
        Channel names for Participant 1.
    channel_names_p2 : List[str]
        Channel names for Participant 2.
    ax : Optional[plt.Axes], optional
        Matplotlib axes. If None, creates new figure.
    highlight_between : bool, optional
        Whether to highlight the between-participant block. Default True.
    cmap : str, optional
        Colormap. Default is 'viridis'.
    title : Optional[str], optional
        Plot title.
        
    Returns
    -------
    plt.Axes
        The matplotlib axes with the plot.
    """
    n_ch = len(channel_names_p1)
    
    if ax is None:
        fig, ax = plt.subplots(figsize=(10, 10))
    
    im = ax.imshow(full_matrix, cmap=cmap, vmin=0, vmax=1)
    
    # Add dividing lines
    ax.axhline(n_ch - 0.5, color='white', linewidth=2)
    ax.axvline(n_ch - 0.5, color='white', linewidth=2)
    
    # Highlight between block
    if highlight_between:
        rect = plt.Rectangle(
            (n_ch - 0.5, -0.5), n_ch, n_ch,
            fill=False, edgecolor=PRIMARY_GREEN, linewidth=3, linestyle='--'
        )
        ax.add_patch(rect)
    
    # Labels
    all_labels = [f'P1-{ch}' for ch in channel_names_p1] + [f'P2-{ch}' for ch in channel_names_p2]
    ax.set_xticks(range(2 * n_ch))
    ax.set_yticks(range(2 * n_ch))
    ax.set_xticklabels(all_labels, rotation=45, ha='right')
    ax.set_yticklabels(all_labels)
    
    # Color labels by participant
    for i, label in enumerate(ax.get_xticklabels()):
        label.set_color(SUBJECT_1 if i < n_ch else SUBJECT_2)
    for i, label in enumerate(ax.get_yticklabels()):
        label.set_color(SUBJECT_1 if i < n_ch else SUBJECT_2)
    
    plt.colorbar(im, ax=ax, shrink=0.8, label='PLV')
    
    if title:
        ax.set_title(title, fontsize=14, fontweight='bold')
    
    return ax


def plot_hyperscanning_circular(
    between_matrix: NDArray[np.floating],
    channel_names_p1: List[str],
    channel_names_p2: List[str],
    threshold: Optional[float] = None,
    ax: Optional[plt.Axes] = None,
    linewidth_scale: float = 3.0,
    title: Optional[str] = None
) -> plt.Axes:
    """
    Circular plot for hyperscanning with P1 on left, P2 on right.
    
    Parameters
    ----------
    between_matrix : NDArray[np.floating]
        Between-participant matrix, shape (n, n).
        Rows = P1 channels, Columns = P2 channels.
    channel_names_p1 : List[str]
        Channel names for Participant 1.
    channel_names_p2 : List[str]
        Channel names for Participant 2.
    threshold : Optional[float], optional
        Highlight only connections at or above this value (weaker ones are drawn in light grey). Default None highlights all.
    ax : Optional[plt.Axes], optional
        Polar axes. If None, creates new figure.
    linewidth_scale : float, optional
        Scale factor for line width. Default 3.0.
    title : Optional[str], optional
        Plot title.
        
    Returns
    -------
    plt.Axes
        The matplotlib polar axes with the plot.
    """
    n_ch = len(channel_names_p1)
    
    if ax is None:
        fig, ax = plt.subplots(figsize=(12, 10), subplot_kw={'projection': 'polar'})
    
    # P1 nodes on the left (0.7π to 1.3π), P2 nodes on the right (-0.3π to 0.3π)
    angles_p1 = np.linspace(np.pi * 0.7, np.pi * 1.3, n_ch)
    angles_p2 = np.linspace(-np.pi * 0.3, np.pi * 0.3, n_ch)
    
    # Plot P1 nodes (left side)
    for i, (angle, name) in enumerate(zip(angles_p1, channel_names_p1)):
        ax.scatter(angle, 1, s=400, c=SUBJECT_1, zorder=5, edgecolors='white', linewidths=2)
        ax.text(angle, 1.2, f'P1-{name}', ha='center', va='center', fontsize=10, 
                fontweight='bold', color=SUBJECT_1)
    
    # Plot P2 nodes (right side)
    for i, (angle, name) in enumerate(zip(angles_p2, channel_names_p2)):
        ax.scatter(angle, 1, s=400, c=SUBJECT_2, zorder=5, edgecolors='white', linewidths=2)
        ax.text(angle, 1.2, f'P2-{name}', ha='center', va='center', fontsize=10, 
                fontweight='bold', color=SUBJECT_2)
    
    # Plot connections between P1 and P2
    for i in range(n_ch):
        for j in range(n_ch):
            value = between_matrix[i, j]
            if np.isnan(value):
                continue
            
            angle_i = angles_p1[i]
            angle_j = angles_p2[j]
            
            # Create arc
            n_points = 50
            t_vals = np.linspace(0, 1, n_points)
            r_vals = 1 - 0.4 * np.sin(np.pi * t_vals)
            angle_vals = angle_i + t_vals * (angle_j - angle_i)
            
            # Determine if strong connection
            is_strong = threshold is None or value >= threshold
            
            if is_strong:
                lw = value * linewidth_scale
                alpha = 0.4 + 0.6 * value
                color = SECONDARY_PURPLE
                zorder = 2
            else:
                lw = 0.8
                alpha = 0.2
                color = '#CCCCCC'
                zorder = 1
            
            ax.plot(angle_vals, r_vals, color=color, linewidth=lw, alpha=alpha, zorder=zorder)
    
    # Clean up
    ax.set_ylim(0, 1.4)
    ax.set_yticks([])
    ax.set_xticks([])
    ax.spines['polar'].set_visible(False)
    
    if title:
        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
    
    return ax


print("Hyperscanning visualization functions defined:")
print("• plot_hyperscanning_matrix(full_matrix, ch_p1, ch_p2, ...)")
print("• plot_hyperscanning_circular(between_matrix, ch_p1, ch_p2, ...)")

In [None]:
# ============================================================================
# VISUALIZATION 11: Hyperscanning Visualizations
# ============================================================================

# Channel names for our hyperscanning example
ch_names_p1 = [f'Ch{i}' for i in range(n_channels_hyper)]
ch_names_p2 = [f'Ch{i}' for i in range(n_channels_hyper)]

# Extract between matrix
between_matrix = hyper_results['between']

print("=" * 60)
print("Between-Participant Connectivity")
print("=" * 60)
print(f"\nBetween matrix shape: {between_matrix.shape}")
print(f"  Rows = P1 channels, Columns = P2 channels")
print(f"\nHighest connections:")
# Find top 3 connections
flat_indices = np.argsort(between_matrix.flatten())[::-1][:3]
for idx in flat_indices:
    i, j = np.unravel_index(idx, between_matrix.shape)
    print(f"  P1-Ch{i} ↔ P2-Ch{j}: PLV = {between_matrix[i, j]:.3f}")

# Create visualizations
fig, axes = plt.subplots(1, 2, figsize=(16, 7))

# Between-participant matrix only
ax1 = axes[0]
im = ax1.imshow(between_matrix, cmap='viridis', vmin=0, vmax=1)
ax1.set_xticks(range(n_channels_hyper))
ax1.set_yticks(range(n_channels_hyper))
ax1.set_xticklabels([f'P2-Ch{i}' for i in range(n_channels_hyper)], fontsize=11)
ax1.set_yticklabels([f'P1-Ch{i}' for i in range(n_channels_hyper)], fontsize=11)
ax1.set_xlabel('Participant 2 Channels', fontsize=12, color=SUBJECT_2)
ax1.set_ylabel('Participant 1 Channels', fontsize=12, color=SUBJECT_1)
ax1.set_title('Between-Participant Matrix', fontsize=13, fontweight='bold')
plt.colorbar(im, ax=ax1, shrink=0.8, label='PLV')

# Add values
for i in range(n_channels_hyper):
    for j in range(n_channels_hyper):
        val = between_matrix[i, j]
        color = 'white' if val > 0.5 else 'black'
        ax1.text(j, i, f'{val:.2f}', ha='center', va='center', fontsize=10, color=color)

# Circular hyperscanning plot: swap the Cartesian axes for a polar one
axes[1].remove()  # Free the slot first, then add the polar axes in its place
ax2 = fig.add_subplot(1, 2, 2, projection='polar')

plot_hyperscanning_circular(
    between_matrix,
    ch_names_p1,
    ch_names_p2,
    threshold=0.5,
    ax=ax2,
    title='Inter-Brain Connectivity'
)

plt.tight_layout()
plt.show()

print("\n→ The coupled channels (Ch0↔Ch0, Ch3↔Ch3) show the strongest connections!")
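The top-3 search above flattens the matrix and sorts it in descending order. That works for the between-participant block, which has no NaN entries, but if a matrix contains NaN (for example, the NaN diagonal of a within-participant matrix), `np.argsort(...)[::-1]` ranks the NaN entries first, because NumPy sorts NaN to the end. The sketch below is a hypothetical NaN-safe helper (`top_k_connections` is not part of this notebook); it treats NaN as `-inf` and uses `np.argpartition` so only the k candidates are fully sorted.

```python
import numpy as np
from numpy.typing import NDArray


def top_k_connections(matrix: NDArray[np.floating], k: int = 3) -> list[tuple[int, int, float]]:
    """Return up to k largest finite entries as (row, col, value), strongest first.

    Hypothetical helper: NaN entries are mapped to -inf so they never rank.
    """
    flat = np.where(np.isnan(matrix), -np.inf, matrix).ravel()
    k = min(k, flat.size)
    if k == 0:
        return []
    # argpartition places the k largest values in the last k slots (unordered)
    top = np.argpartition(flat, -k)[-k:]
    # Sort only those k candidates, largest first
    top = top[np.argsort(flat[top])[::-1]]
    result = []
    for idx in top:
        if not np.isfinite(flat[idx]):
            continue  # skip former NaN entries
        i, j = np.unravel_index(idx, matrix.shape)
        result.append((int(i), int(j), float(flat[idx])))
    return result


# Toy between-participant block with one NaN entry (values invented for illustration)
demo = np.array([[0.9, 0.2, np.nan],
                 [0.1, 0.4, 0.3],
                 [0.2, 0.1, 0.8]])
for i, j, val in top_k_connections(demo, k=3):
    print(f"P1-Ch{i} ↔ P2-Ch{j}: PLV = {val:.3f}")
```

For small matrices the speed gain of `argpartition` over a full sort is irrelevant; the point is that NaN entries can never appear among the "strongest" connections.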

---

## Section 12: Global Connectivity Metrics

Sometimes we need to summarize an entire matrix with a **single value**:

**Common global metrics:**
- **Mean connectivity**: Average of all unique off-diagonal values (for a between-participant block, of every cell)
- **Connection density**: Proportion of connections above a threshold
- **Hyperscanning ratio**: Mean between-participant connectivity divided by mean within-participant connectivity

These are useful for:
- Comparing conditions (task vs. rest)
- Correlating with behavior (reaction time, performance)
- Statistical group comparisons

In [None]:
# ============================================================================
# FUNCTION 12: Global Connectivity Metrics
# ============================================================================

def compute_global_connectivity(
    matrix: NDArray[np.floating],
    exclude_diagonal: bool = True
) -> float:
    """
    Compute mean connectivity (global connectivity).
    
    Parameters
    ----------
    matrix : NDArray[np.floating]
        Connectivity matrix.
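    exclude_diagonal : bool, optional
        If True (default), average only the upper triangle (k=1), i.e. each
        unique pair of a symmetric matrix once. Pass False for a
        between-participant block, where every cell is a distinct P1-P2 pair.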
        
    Returns
    -------
    float
        Mean connectivity value.
    """
    if exclude_diagonal:
        # Get upper triangle values (excludes diagonal)
        values = get_upper_triangle_values(matrix, k=1)
    else:
        values = matrix.flatten()
    
    # Remove NaN values
    values = values[~np.isnan(values)]
    return float(np.mean(values))


def compute_connection_density(
    matrix: NDArray[np.floating],
    threshold: float,
    exclude_diagonal: bool = True
) -> float:
    """
    Compute proportion of connections exceeding threshold.
    
    Parameters
    ----------
    matrix : NDArray[np.floating]
        Connectivity matrix.
    threshold : float
        Connectivity threshold.
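    exclude_diagonal : bool, optional
        If True (default), use only the upper triangle (k=1), which suits
        symmetric within-participant matrices. Pass False for a
        between-participant block, whose diagonal and lower triangle hold
        distinct P1-P2 pairs.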
        
    Returns
    -------
    float
        Proportion of connections above threshold (0 to 1).
    """
    if exclude_diagonal:
        values = get_upper_triangle_values(matrix, k=1)
    else:
        values = matrix.flatten()
    
    values = values[~np.isnan(values)]
    return float(np.mean(values > threshold))


def compute_hyperscanning_ratio(
    within_mean: float,
    between_mean: float
) -> float:
    """
    Compute ratio of between to within connectivity.
    
    Parameters
    ----------
    within_mean : float
        Mean within-participant connectivity.
    between_mean : float
        Mean between-participant connectivity.
        
    Returns
    -------
    float
        Ratio (between / within). 
        > 1 indicates stronger inter-brain than intra-brain connectivity.
    """
    if within_mean == 0:
        return np.inf if between_mean > 0 else 0.0
    return between_mean / within_mean


print("Global metric functions defined:")
print("• compute_global_connectivity(matrix) → mean PLV")
print("• compute_connection_density(matrix, threshold) → proportion above threshold")
print("• compute_hyperscanning_ratio(within, between) → inter/intra ratio")
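Both helpers default to `exclude_diagonal=True`, which keeps only the upper triangle. That is right for symmetric within-participant matrices but not for the between-participant block: its diagonal holds genuine inter-brain pairs (P1-Ch i ↔ P2-Ch i), and `between[2, 0]` (P1-Ch2 ↔ P2-Ch0) is a different pair from `between[0, 2]`. The toy example below (values invented for illustration) shows how the upper-triangle shortcut can hide exactly the coupled pairs.

```python
import numpy as np

# Toy between-participant block: rows = P1 channels, cols = P2 channels.
# Only the same-channel pairs on the diagonal are strongly coupled.
between = np.array([
    [0.9, 0.1, 0.2],
    [0.1, 0.2, 0.1],
    [0.3, 0.1, 0.8],
])
threshold = 0.5

# Upper triangle (k=1) keeps 3 of the 9 pairs and drops the diagonal entirely
upper_only = between[np.triu_indices_from(between, k=1)]
print("upper-triangle density:", np.mean(upper_only > threshold))  # 0.0

# Using every cell counts all 9 distinct P1-P2 pairs
print("full-block density:", np.mean(between > threshold))  # 2/9 ≈ 0.222
```

With the notebook's helpers, this corresponds to passing `exclude_diagonal=False` whenever the input is a between-participant block.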

In [None]:
# ============================================================================
# VISUALIZATION 12: Global Metrics Comparison
# ============================================================================

# Compute global metrics for our hyperscanning data
within_p1_mean = compute_global_connectivity(hyper_results['within_p1'])
within_p2_mean = compute_global_connectivity(hyper_results['within_p2'])
between_mean = compute_global_connectivity(hyper_results['between'], exclude_diagonal=False)

# Average within-participant
within_avg = (within_p1_mean + within_p2_mean) / 2

# Hyperscanning ratio
hyper_ratio = compute_hyperscanning_ratio(within_avg, between_mean)

# Connection density at different thresholds
densities = {}
for threshold in [0.3, 0.5, 0.7]:
    densities[threshold] = {
        'within_p1': compute_connection_density(hyper_results['within_p1'], threshold),
        'within_p2': compute_connection_density(hyper_results['within_p2'], threshold),
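        # Between block is not symmetric and its diagonal is meaningful: use all cells
        'between': compute_connection_density(hyper_results['between'], threshold, exclude_diagonal=False)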
    }

print("=" * 60)
print("Global Connectivity Metrics")
print("=" * 60)
print(f"\nMean Connectivity:")
print(f"  Within P1: {within_p1_mean:.3f}")
print(f"  Within P2: {within_p2_mean:.3f}")
print(f"  Between:   {between_mean:.3f}")
print(f"\nHyperscanning Ratio (between/within): {hyper_ratio:.3f}")
if hyper_ratio > 1:
    print("  → Inter-brain > Intra-brain connectivity")
else:
    print("  → Intra-brain ≥ Inter-brain connectivity")

# Visualization: Bar chart
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Mean connectivity comparison
ax1 = axes[0]
categories = ['Within P1', 'Within P2', 'Between']
values = [within_p1_mean, within_p2_mean, between_mean]
colors = [SUBJECT_1, SUBJECT_2, SECONDARY_PURPLE]
bars = ax1.bar(categories, values, color=colors, edgecolor='white', linewidth=2)
ax1.set_ylabel('Mean PLV', fontsize=12)
ax1.set_title('Mean Connectivity Comparison', fontsize=13, fontweight='bold')
ax1.set_ylim(0, 1)
ax1.axhline(0.5, color='gray', linestyle='--', alpha=0.5, label='Reference')

# Add value labels
for bar, val in zip(bars, values):
    ax1.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02, 
             f'{val:.2f}', ha='center', fontsize=11, fontweight='bold')

# Connection density at threshold=0.5
ax2 = axes[1]
threshold = 0.5
x = np.arange(3)
width = 0.6
density_values = [densities[threshold]['within_p1'], 
                  densities[threshold]['within_p2'], 
                  densities[threshold]['between']]
bars = ax2.bar(x, density_values, width, color=colors, edgecolor='white', linewidth=2)
ax2.set_xticks(x)
ax2.set_xticklabels(categories)
ax2.set_ylabel('Connection Density', fontsize=12)
ax2.set_title(f'Proportion of Connections > {threshold}', fontsize=13, fontweight='bold')
ax2.set_ylim(0, 1)

# Add value labels
for bar, val in zip(bars, density_values):
    ax2.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02, 
             f'{val:.0%}', ha='center', fontsize=11, fontweight='bold')

plt.tight_layout()
plt.show()

print("\n→ These metrics can be compared across conditions or correlated with behavior")
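The density panel fixes the threshold at 0.5, but any single threshold is a choice. A quick robustness check is to sweep several thresholds and see whether a within-versus-between difference holds across them. Below is a minimal sketch of such a sweep; `density_curve` is a hypothetical helper (not defined elsewhere in this notebook), and the input values are uniform random placeholders rather than real PLVs.

```python
import numpy as np


def density_curve(values: np.ndarray, thresholds: np.ndarray) -> np.ndarray:
    """Fraction of connectivity values strictly above each threshold.

    `values` should already be the set of unique pairs: the upper triangle of a
    within-participant matrix, or every cell of a between-participant block.
    """
    values = values[~np.isnan(values)]
    # Broadcasting: (n_thresholds, 1) vs (1, n_values) -> (n_thresholds, n_values)
    above = values[np.newaxis, :] > thresholds[:, np.newaxis]
    return np.mean(above, axis=1)


rng = np.random.default_rng(0)
vals = rng.uniform(0, 1, size=200)  # placeholder values, not real PLVs
thresholds = np.linspace(0, 1, 11)
curve = density_curve(vals, thresholds)
for t, d in zip(thresholds, curve):
    print(f"threshold {t:.1f}: density {d:.2f}")
```

Plotting one curve per block (within P1, within P2, between) on shared axes makes it visible whether conclusions depend on the particular threshold.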

---

## Summary

In this notebook, we learned how to organize, compute, and visualize **connectivity matrices** for multi-channel EEG and hyperscanning data.

### Key Concepts

| Concept | Description |
|---------|-------------|
| **Connectivity Matrix** | n×n array storing pairwise connectivity values |
| **Symmetry** | For undirected metrics (PLV, coherence), M[i,j] = M[j,i] |
| **Diagonal** | Typically set to NaN (self-connectivity is trivially maximal, e.g. PLV = 1, and carries no information) |
| **Upper Triangle** | Contains all unique values for symmetric matrices |
| **Region Averaging** | Reduces noise by averaging within brain regions |
| **Hyperscanning Matrix** | 2n×2n with within and between blocks |
| **Between Block** | P1↔P2 connectivity: not affected by volume conduction, since the signals come from separate heads |
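To make the block layout concrete, here is a minimal sketch that assembles a 2n×2n hyperscanning matrix from its three blocks with `np.block`. It assumes the ordering used in this notebook's descriptions (P1 channels first, then P2) and an undirected metric such as PLV, so the lower-left block is the transpose of the between block; all values are random placeholders.

```python
import numpy as np

n = 3  # channels per participant (toy size)
rng = np.random.default_rng(1)


def random_symmetric(n: int, rng: np.random.Generator) -> np.ndarray:
    """Toy within-participant matrix: symmetric, NaN diagonal."""
    a = rng.uniform(0, 1, size=(n, n))
    m = (a + a.T) / 2
    np.fill_diagonal(m, np.nan)
    return m


within_p1 = random_symmetric(n, rng)
within_p2 = random_symmetric(n, rng)
between = rng.uniform(0, 1, size=(n, n))  # rows = P1, cols = P2; not symmetric

# Layout (assumed): [[P1-P1, P1-P2], [P2-P1, P2-P2]]
full = np.block([[within_p1, between],
                 [between.T, within_p2]])

# Recover the blocks by slicing
print(full.shape)                                  # (6, 6)
print(np.allclose(full[:n, n:], between))          # True
print(np.allclose(full[n:, :n], between.T))        # True
# The assembled matrix is symmetric even though the between block is not
print(np.allclose(full, full.T, equal_nan=True))   # True
```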

### Functions Created

**Matrix Operations:**
- `get_n_pairs(n)`: Number of unique channel pairs
- `get_pair_indices(n)`: List of (i, j) pairs
- `compute_connectivity_matrix()`: Compute full matrix
- `get_upper_triangle_values()`: Extract unique values
- `upper_triangle_to_matrix()`: Reconstruct from values

**Region Analysis:**
- `define_channel_groups()`: Map channels to regions
- `compute_region_connectivity()`: Average by region

**Hyperscanning:**
- `compute_hyperscanning_connectivity()`: Full hyperscanning analysis
- `extract_between_participant_matrix()`: Get inter-brain block

**Visualization:**
- `plot_connectivity_matrix()`: Heatmap visualization
- `plot_circular_connectivity()`: Network diagram
- `plot_hyperscanning_matrix()`: Annotated hyperscanning heatmap
- `plot_hyperscanning_circular()`: Two-brain circular plot

**Global Metrics:**
- `compute_global_connectivity()`: Mean connectivity
- `compute_connection_density()`: Proportion above threshold
- `compute_hyperscanning_ratio()`: Between/within ratio

**Validation:**
- `validate_connectivity_matrix()`: Check matrix properties
- `get_matrix_statistics()`: Summary statistics
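The upper-triangle round trip behind `get_upper_triangle_values()` and `upper_triangle_to_matrix()` can be sketched in a few lines of NumPy. The functions below, `triu_values` and `matrix_from_triu`, are hypothetical standalone versions written for illustration; they are not the notebook's own implementations and may differ from them in ordering or NaN handling.

```python
import numpy as np


def triu_values(matrix: np.ndarray) -> np.ndarray:
    """Unique off-diagonal values of a symmetric n×n matrix, read row by row."""
    return matrix[np.triu_indices_from(matrix, k=1)]


def matrix_from_triu(values: np.ndarray, n: int) -> np.ndarray:
    """Rebuild a symmetric n×n matrix (NaN diagonal) from its upper-triangle values."""
    expected = n * (n - 1) // 2  # number of unique channel pairs
    if values.size != expected:
        raise ValueError(f"expected {expected} values for n={n}, got {values.size}")
    m = np.full((n, n), np.nan)
    iu = np.triu_indices(n, k=1)
    m[iu] = values
    m[(iu[1], iu[0])] = values  # mirror into the lower triangle
    return m


n = 4
vals = np.array([0.8, 0.4, 0.2, 0.6, 0.3, 0.7])  # n*(n-1)/2 = 6 pairs
m = matrix_from_triu(vals, n)
print(np.array_equal(triu_values(m), vals))  # True
```

Storing only the n(n-1)/2 upper-triangle values halves memory and avoids counting each pair twice in statistics.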

---

## Discussion Questions

1. **Scale considerations**: You have 64-channel EEG from two participants (128 channels total). How many unique between-participant channel pairs are there? If each PLV computation takes 0.01 seconds, how long would the full analysis take?

2. **High connectivity everywhere**: Your connectivity matrix shows high values everywhere (mean PLV = 0.85). What might cause this? Is it necessarily a problem?

3. **Visualization choice**: When would you prefer a circular connectivity plot over a heatmap? When would the heatmap be better?

4. **Why separate matrices?**: In hyperscanning, why might we want to analyze the between-participant matrix separately from the within-participant matrices?

5. **Comparing conditions**: You're comparing connectivity between a "cooperation" and "competition" condition. Would you compare full matrices, between-participant matrices, or global metrics? What are the trade-offs?

---

## Next Steps

In the upcoming notebooks, we will:
- **C03**: Learn about statistical significance testing for connectivity values
- **D01-D03**: Explore information-theoretic approaches (entropy, mutual information, transfer entropy)
- **F01-G03**: Dive deep into specific connectivity metrics (coherence, PLV, PLI, wPLI)