# Paper Replication: de Cheveigné & Simon (2008)

**"Denoising based on spatial filtering"** - *J Neurosci Methods*

This notebook replicates key figures from the foundational paper on denoising source separation (DSS) for enhancing evoked responses.

## Key Results to Replicate
- **Figure 1**: Power distribution across DSS components
- **Figure 2**: RMS evoked response before/after denoising
- **Figure 3**: Component topographies and spatial filters

In [None]:
import matplotlib.pyplot as plt
import mne
import numpy as np
from mne.datasets import sample

from mne_denoise.dss import DSS, TrialAverageBias

# Reproducibility: seed NumPy's legacy global random generator.
np.random.seed(42)
# Use a light grid-based matplotlib theme for every figure below.
plt.style.use("seaborn-v0_8-whitegrid")

## 1. Load Auditory Evoked Data

The original paper used auditory MEG data. We use MNE's sample dataset, which contains both auditory and visual stimulation; only the auditory trials are analyzed here.

In [None]:
# --- Read the MNE "sample" recording ---------------------------------------
data_path = sample.data_path()
raw_fname = data_path / "MEG" / "sample" / "sample_audvis_raw.fif"
raw = mne.io.read_raw_fif(raw_fname, preload=True, verbose=False)

# Band-pass 1-40 Hz, approximating the preprocessing described in the paper.
raw.filter(1, 40, verbose=False)

# Only the two auditory trigger codes are epoched (the analogue of the
# paper's noise-burst stimuli).
events = mne.find_events(raw, stim_channel="STI 014", verbose=False)
event_id = {"auditory/left": 1, "auditory/right": 2}

# Magnetometer epochs from -200 ms to +500 ms, baseline-corrected on the
# pre-stimulus interval; epochs exceeding the 4e-12 T magnetometer
# rejection threshold are dropped.
epochs = mne.Epochs(
    raw,
    events,
    event_id,
    picks="mag",  # magnetometers stand in for the paper's axial gradiometers
    tmin=-0.2,
    tmax=0.5,
    baseline=(None, 0),
    reject=dict(mag=4e-12),
    preload=True,
    verbose=False,
)

# Cap the analysis at 100 trials, comparable to the paper's trial count.
n_trials = min(100, len(epochs))
epochs = epochs[:n_trials]

print(f"Epochs: {len(epochs)} trials, {len(epochs.ch_names)} channels")
print(f"Time: {epochs.tmin:.2f} to {epochs.tmax:.2f} s")

## 2. Apply DSS with Trial-Average Bias

Following the paper's algorithm (Section II):
1. Normalize channels
2. PCA whitening
3. Apply bias (trial averaging)
4. Second PCA on biased data
5. Order components by reproducibility

In [None]:
# MNE hands back epoched data laid out as (n_epochs, n_channels, n_times).
data = epochs.get_data()
n_epochs, n_channels, n_times = data.shape
print(f"Data shape: {data.shape}")

# The DSS implementation is fed (n_channels, n_times, n_epochs): move the
# epoch axis from the front to the back, keeping channel/time order.
data_dss = np.moveaxis(data, 0, -1)

# Trial-average bias: components are ranked by how much of their power
# survives averaging across epochs. Keep as many components as channels.
bias = TrialAverageBias()
dss = DSS(bias=bias, n_components=n_channels)
dss.fit(data_dss)

# Component time courses; the original note gives their layout as
# (n_components, n_times*n_epochs) -- TODO confirm against mne_denoise.
sources = dss.transform(data_dss)
eigenvalues = dss.eigenvalues_

print(f"\nDSS Components: {sources.shape[0]}")
print(f"Eigenvalues (reproducibility): {eigenvalues[:5].round(3)}")

## Figure 1: Power Distribution Across Components

Replicating Figure 1 from the paper:
- (a) % power per component before (black) and after (red) averaging
- (b) Cumulative power retained vs component cutoff
- (c) Reproducibility per component (ratio of evoked to total power)

In [None]:
# Reshape sources to (n_components, n_times, n_epochs). This assumes
# dss.transform flattens time and epochs in C order, matching the
# (n_channels, n_times, n_epochs) layout passed to it -- TODO confirm.
sources_3d = sources.reshape(sources.shape[0], n_times, n_epochs)

# Before averaging: total power = variance over all samples of all trials
power_total = np.var(sources_3d, axis=(1, 2))
power_total_pct = 100 * power_total / power_total.sum()

# After averaging: evoked power = variance of the trial-averaged time course
sources_avg = sources_3d.mean(axis=2)
power_evoked = np.var(sources_avg, axis=1)
power_evoked_pct = 100 * power_evoked / power_evoked.sum()

# Reproducibility: ratio of evoked to total power. Use a guarded division
# rather than adding a fixed epsilon: an absolute 1e-12 is not
# scale-invariant and would dominate the denominator if the sources carry
# small physical units. Components with zero total power get 0.
reproducibility = np.divide(
    power_evoked,
    power_total,
    out=np.zeros_like(power_total, dtype=float),
    where=power_total > 0,
)

# Axis/metric limits, clipped so that fewer components never index past the end
n_comp = len(power_total_pct)
x = np.arange(1, n_comp + 1)
n_show = min(20, n_comp)  # components shown on each x-axis
n_top = min(10, n_comp)  # cutoff used for the "top components" metrics

# Create Figure 1
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# (a) Power per component
ax = axes[0]
ax.bar(x - 0.2, power_total_pct, 0.4, label="Total (raw)", color="black", alpha=0.7)
ax.bar(
    x + 0.2, power_evoked_pct, 0.4, label="Evoked (averaged)", color="red", alpha=0.7
)
ax.set_xlabel("Component")
ax.set_ylabel("% Power")
ax.set_title("(a) Power per component")
ax.legend()
ax.set_xlim([0, n_show + 1])

# (b) Cumulative power; the dashed line marks the paper's ~96% reference
ax = axes[1]
cum_total = np.cumsum(power_total_pct)
cum_evoked = np.cumsum(power_evoked_pct)
ax.plot(x, cum_total, "k-", linewidth=2, label="Total power")
ax.plot(x, cum_evoked, "r-", linewidth=2, label="Evoked power")
ax.axhline(96, color="gray", linestyle="--", alpha=0.5)
ax.axvline(n_top, color="gray", linestyle=":", alpha=0.5)
ax.set_xlabel("Components retained")
ax.set_ylabel("% Power retained")
ax.set_title("(b) Cumulative power")
ax.legend()
ax.set_xlim([0, n_show + 1])
ax.set_ylim([0, 105])

# (c) Reproducibility per component
ax = axes[2]
ax.bar(x, 100 * reproducibility, color="blue", alpha=0.7)
ax.set_xlabel("Component")
ax.set_ylabel("% Reproducible")
ax.set_title("(c) Reproducibility (evoked/total)")
ax.set_xlim([0, n_show + 1])

plt.suptitle(
    "Figure 1: Power distribution across DSS components\n(Replicating de Cheveigné & Simon 2008)",
    fontsize=12,
)
plt.tight_layout()
plt.savefig("paper2_figure1.png", dpi=150)
plt.show()

# Print key metrics from paper
print("\n--- Metrics (compare with paper) ---")
print(f"Top {n_top} components capture {cum_evoked[n_top - 1]:.1f}% of evoked power")
print(f"Top {n_top} components capture {cum_total[n_top - 1]:.1f}% of total power")
print(f"Component 1 reproducibility: {100 * reproducibility[0]:.1f}%")

## Figure 2: RMS Evoked Response

Replicating Figure 2:
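- (a) RMS evoked response before and after DSS denoising, with ±2 bootstrap SD bands
- (b) Trial-averaged time course of the first (most reproducible) DSS component
- The M100 auditory response should stand out with improved SNR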

In [None]:
def compute_rms_with_bootstrap(data_3d, n_bootstrap=100, *, seed=42):
    """Compute the RMS of the trial average and its bootstrap spread.

    Epochs are averaged first; the RMS is then taken across channels at each
    time point. Epochs are resampled with replacement ``n_bootstrap`` times
    to estimate the standard deviation of that RMS curve.

    Parameters
    ----------
    data_3d : array_like, shape (n_channels, n_times, n_epochs)
        Epoched data.
    n_bootstrap : int
        Number of bootstrap resamples; must be at least 1.
    seed : int | None
        Seed passed to ``np.random.default_rng``. The default (42) keeps the
        resampling reproducible across calls.

    Returns
    -------
    rms : ndarray, shape (n_times,)
        RMS across channels of the epoch-averaged data.
    rms_std : ndarray, shape (n_times,)
        Standard deviation of the RMS curve across bootstrap resamples.

    Raises
    ------
    ValueError
        If ``data_3d`` is not 3-D, contains no epochs, or ``n_bootstrap`` < 1.
    """
    data_3d = np.asarray(data_3d)
    if data_3d.ndim != 3:
        raise ValueError(
            f"data_3d must be 3-D (n_channels, n_times, n_epochs), got ndim={data_3d.ndim}"
        )
    n_ep = data_3d.shape[2]
    if n_ep == 0:
        raise ValueError("data_3d must contain at least one epoch")
    if n_bootstrap < 1:
        # Otherwise the std of an empty stack silently returns a scalar NaN.
        raise ValueError(f"n_bootstrap must be >= 1, got {n_bootstrap}")

    def _rms_of_epoch_average(epoch_data):
        # Average over epochs, then RMS over channels -> (n_times,)
        return np.sqrt(np.mean(epoch_data.mean(axis=2) ** 2, axis=0))

    rms = _rms_of_epoch_average(data_3d)

    # Resample whole epochs (with replacement) and recompute the RMS curve
    rng = np.random.default_rng(seed)
    rms_boots = np.array(
        [
            _rms_of_epoch_average(data_3d[:, :, rng.choice(n_ep, n_ep, replace=True)])
            for _ in range(n_bootstrap)
        ]
    )
    return rms, rms_boots.std(axis=0)


# RMS of the original (undenoised) evoked response, with bootstrap spread
rms_original, rms_std_original = compute_rms_with_bootstrap(data_dss)

# DSS-denoised data: project back to sensor space keeping only the top
# components (clipped in case fewer components are available).
n_keep = min(10, sources.shape[0])
cleaned = dss.inverse_transform(sources[:n_keep])
cleaned_3d = cleaned.reshape(n_channels, n_times, n_epochs)
rms_denoised, rms_std_denoised = compute_rms_with_bootstrap(cleaned_3d)

# Plot Figure 2
times = epochs.times * 1000  # Convert to ms

fig, axes = plt.subplots(1, 2, figsize=(14, 4))

# (a) RMS before vs after denoising; shaded bands are +/- 2 bootstrap SD
ax = axes[0]
ax.plot(times, rms_original, "r-", linewidth=1.5, label="Before DSS")
ax.fill_between(
    times,
    rms_original - 2 * rms_std_original,
    rms_original + 2 * rms_std_original,
    alpha=0.2,
    color="red",
)
# Legend text is derived from n_keep so it cannot drift from the computation
ax.plot(times, rms_denoised, "b-", linewidth=1.5, label=f"After DSS ({n_keep} comp.)")
ax.fill_between(
    times,
    rms_denoised - 2 * rms_std_denoised,
    rms_denoised + 2 * rms_std_denoised,
    alpha=0.2,
    color="blue",
)
ax.axvline(0, color="k", linestyle="--", alpha=0.5)
ax.axvline(100, color="green", linestyle=":", alpha=0.5, label="~M100")
ax.set_xlabel("Time (ms)")
ax.set_ylabel("RMS Field")
ax.set_title("(a) RMS evoked response")
ax.legend()

# (b) First DSS component (most reproducible), averaged over epochs
ax = axes[1]
comp1 = sources_3d[0]  # (n_times, n_epochs)
comp1_avg = comp1.mean(axis=1)

# Bootstrap the epoch average of component 1 (seeded for reproducibility)
rng = np.random.default_rng(42)
boots = [
    comp1[:, rng.choice(n_epochs, n_epochs, replace=True)].mean(axis=1)
    for _ in range(100)
]
comp1_std = np.std(boots, axis=0)

ax.plot(times, comp1_avg, "b-", linewidth=2)
ax.fill_between(
    times, comp1_avg - 2 * comp1_std, comp1_avg + 2 * comp1_std, alpha=0.2, color="blue"
)
ax.axvline(0, color="k", linestyle="--", alpha=0.5)
ax.set_xlabel("Time (ms)")
ax.set_ylabel("Amplitude")
ax.set_title(f"(b) First DSS component (λ={eigenvalues[0]:.2f})")

plt.suptitle(
    "Figure 2: RMS evoked response before/after denoising\n(Replicating de Cheveigné & Simon 2008)",
    fontsize=12,
)
plt.tight_layout()
plt.savefig("paper2_figure2.png", dpi=150)
plt.show()

## Figure 3: Component Topographies

Replicating Figure 3:
- Field distribution (spatial patterns) of top DSS components
- Spatial filters (how to extract each component)

In [None]:
# Get patterns (topographies) and filters. Layouts follow the original
# notes -- TODO confirm against the mne_denoise API.
patterns = dss.patterns_  # (n_channels, n_components)
filters = dss.filters_  # (n_components, n_channels)

# Channel info (sensor positions) for the magnetometers used in the fit
picks = mne.pick_types(epochs.info, meg="mag")
info = mne.pick_info(epochs.info, picks)

# Plot up to 4 leading components; squeeze=False keeps `axes` 2-D even
# when only one column is needed.
n_topo = min(4, patterns.shape[1])
fig, axes = plt.subplots(2, n_topo, figsize=(3.5 * n_topo, 7), squeeze=False)

for i in range(n_topo):
    # Top row: patterns (field maps)
    ax = axes[0, i]
    pattern = patterns[:, i]
    mne.viz.plot_topomap(pattern, info, axes=ax, show=False, contours=0)
    ax.set_title(f"Pattern {i + 1}\nλ={eigenvalues[i]:.2f}")

    # Bottom row: spatial filters
    ax = axes[1, i]
    filt = filters[i]
    mne.viz.plot_topomap(filt, info, axes=ax, show=False, contours=0)
    ax.set_title(f"Filter {i + 1}")

axes[0, 0].set_ylabel("Field Maps\n(Patterns)", fontsize=11)
axes[1, 0].set_ylabel("Spatial Filters", fontsize=11)

plt.suptitle(
    "Figure 3: Component topographies\n(Replicating de Cheveigné & Simon 2008)",
    fontsize=12,
)
plt.tight_layout()
plt.savefig("paper2_figure3.png", dpi=150)
plt.show()

print("Note: Top component should show auditory cortex pattern (bilateral temporal)")

## Summary & Comparison with Paper

Key findings from de Cheveigné & Simon (2008):
- Top ~10 components capture ~96% of evoked power
- First component is ~60% reproducible
- DSS dramatically improves reliability of evoked response

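Our replication should show similar trends, although exact values will differ because the dataset, sensor type (magnetometers) and trial count differ from the paper's.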

In [None]:
# Number of leading components summarised (clipped to what is available)
n_top = min(10, len(cum_evoked))

print("=" * 50)
print("REPLICATION SUMMARY")
print("=" * 50)
print(f"\nData: {n_epochs} trials, {n_channels} channels")
print("\nFigure 1 metrics:")
print(f"  Top {n_top} components capture {cum_evoked[n_top - 1]:.1f}% of evoked power")
print("  (Paper reports ~96%)")
print(f"  Component 1 reproducibility: {100 * reproducibility[0]:.1f}%")
print("  (Paper reports ~60%)")

# Figures 2 and 3 are not checked programmatically, so phrase them as things
# to inspect rather than as confirmed results.
print("\nFigure 2: Check for an RMS improvement around the M100 (~100 ms)")
print("\nFigure 3: Check that the top component shows an auditory (bilateral temporal) topography")

print("\n" + "=" * 50)
print("Saved figures: paper2_figure1.png, paper2_figure2.png, paper2_figure3.png")