# Multivariate Convolutional Sparse Coding for Electromagnetic Brain Signals

**MVA 2025/2026 - Machine Learning for Time Series**

This notebook reproduces and extends the NeurIPS 2018 paper by Dupré La Tour et al.

## Contents
1. Setup and Imports
2. Experiment 1: Synthetic CSC Validation
3. Experiment 2: Mu-Rhythm Recovery from MEG Data
4. Experiment 3: Multi-Band Frequency Analysis (New)
5. Summary

## 1. Setup and Imports

In [None]:
# Install required packages (uncomment if needed)
# !pip install numpy scipy matplotlib scikit-learn alphacsc mne joblib numba

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy import signal
import mne

# alphacsc library - main CSC implementation
from alphacsc import learn_d_z, BatchCDL
from alphacsc.simulate import simulate_data
from alphacsc.datasets.mne_data import load_data

print("All imports successful!")

## 2. Experiment 1: Synthetic CSC Validation

Before applying CSC to real brain data, we validate it on synthetic signals with known ground truth atoms.

**Goal:** Verify that CSC can recover the true waveform shapes.
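
CSC atoms are only identifiable up to sign, order and small temporal shifts. To judge "recovery" with a number rather than by eye, one option is a permutation- and sign-invariant similarity score. The sketch below is a hypothetical helper (`atom_recovery_score`, not part of alphacsc). It matches true and learned atoms one-to-one by brute force over permutations, which is only practical for small $K$ such as the $K = 2$ used here.

```python
import numpy as np
from itertools import permutations


def atom_recovery_score(d_true, d_hat):
    """Mean matched similarity between true and learned atoms (1.0 = perfect).

    Similarity is the maximum absolute normalized cross-correlation over lags:
    the absolute value ignores sign flips, the max over lags tolerates shifts.
    Returns (score, perm), where perm[i] is the learned atom matched to true atom i.
    """
    n_atoms = len(d_true)

    def ncc(a, b):
        # z-normalize both atoms so the zero-lag correlation of a with itself is 1
        a = (a - a.mean()) / (np.linalg.norm(a - a.mean()) + 1e-12)
        b = (b - b.mean()) / (np.linalg.norm(b - b.mean()) + 1e-12)
        return np.max(np.abs(np.correlate(a, b, mode='full')))

    sim = np.array([[ncc(d_true[i], d_hat[j]) for j in range(n_atoms)]
                    for i in range(n_atoms)])
    # exhaustive one-to-one matching that maximizes total similarity
    best = max(permutations(range(n_atoms)),
               key=lambda p: sum(sim[i, p[i]] for i in range(n_atoms)))
    return float(np.mean([sim[i, best[i]] for i in range(n_atoms)])), best
```

After the CSC cell below, `atom_recovery_score(ds_true, d_hat)` should give a score close to 1 when the learned atoms match the ground truth.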

In [None]:
# Parameters
n_trials = 100      # N: number of signals
n_times = 512       # T: signal length
n_times_atom = 64   # L: atom length
n_atoms = 2         # K: number of atoms
n_iter = 50         # number of alternating z-step / d-step iterations
reg = 0.1           # regularization (lambda)
random_state = 42

print("Generating synthetic data...")
X, ds_true, z_true = simulate_data(
    n_trials, n_times, n_times_atom, n_atoms,
    random_state=random_state
)

# Add noise
rng = np.random.RandomState(random_state)
X += 0.01 * rng.randn(*X.shape)

print(f"Data shape: {X.shape}")
print(f"True atoms shape: {ds_true.shape}")

In [None]:
print("Running Convolutional Sparse Coding...")
pobj, times, d_hat, z_hat, reg_used = learn_d_z(
    X, n_atoms, n_times_atom,
    reg=reg, n_iter=n_iter,
    solver_d_kwargs=dict(factr=100),
    random_state=random_state,
    n_jobs=1, verbose=1
)
print(f"Final objective: {pobj[-1]:.6f}")

In [None]:
# Visualize results
fig, axes = plt.subplots(2, 2, figsize=(12, 8))
t = np.arange(n_times_atom)

# Ground truth
ax = axes[0, 0]
for k in range(n_atoms):
    ax.plot(t, ds_true[k], label=f'Atom {k+1}', linewidth=2)
ax.set_title('Ground Truth Atoms')
ax.legend()
ax.grid(True, alpha=0.3)

# Learned
ax = axes[0, 1]
for k in range(n_atoms):
    ax.plot(t, d_hat[k], label=f'Atom {k+1}', linewidth=2)
ax.set_title('Learned Atoms (CSC)')
ax.legend()
ax.grid(True, alpha=0.3)

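# Comparison: atoms are identifiable only up to permutation and sign, so
# match true and learned atoms one-to-one by absolute correlation first
from scipy.optimize import linear_sum_assignment

ax = axes[1, 0]
colors = plt.cm.tab10.colors
corr = np.corrcoef(ds_true, d_hat)[:n_atoms, n_atoms:]  # corr[k, j] = corr(true k, learned j)
_, match = linear_sum_assignment(-np.abs(corr))         # learned atom matched to each true atom
for k in range(n_atoms):
    j = match[k]
    sign = np.sign(corr[k, j])                          # undo a sign flip of the learned atom
    ax.plot(t, ds_true[k], '--', color=colors[k], label=f'True {k+1}', alpha=0.7)
    ax.plot(t, sign * d_hat[j], '-', color=colors[k], label=f'Learned {j+1}')
ax.set_title('Comparison (dashed=true, solid=learned)')
ax.legend()
ax.grid(True, alpha=0.3)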

# Convergence
ax = axes[1, 1]
ax.semilogy(pobj)
ax.set_xlabel('Iteration')
ax.set_ylabel('Objective')
ax.set_title('Convergence')
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\nExperiment 1 done: check the comparison panel to see how closely the learned atoms match the ground truth.")

## 3. Experiment 2: Mu-Rhythm Recovery from MEG Data

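Now we apply **multivariate CSC with a rank-1 constraint** to real MEG data.

**Key innovation:** The rank-1 constraint factorizes each multichannel atom as $D_k = u_k v_k^T \in \mathbb{R}^{P \times L}$, where:
- $u_k \in \mathbb{R}^P$ = spatial pattern (how strongly each of the $P$ channels expresses the atom)
- $v_k \in \mathbb{R}^L$ = temporal pattern (waveform shape over $L$ samples)

Because $u_k$ is a sensor-space topography, it can be passed to standard dipole fitting to localize the atom's source in the brain.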
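
A toy NumPy sketch of the rank-1 model for a single atom, assuming the usual multivariate CSC convention that each multichannel signal is a sum of sparse activations convolved with atoms. The sizes, waveform and spike positions are made up for illustration. It also shows why the constraint is computationally convenient: convolving $z_k$ with the $P$ rows of $D_k$ is the same as one 1-D convolution $z_k * v_k$ scaled by $u_k$.

```python
import numpy as np

rng = np.random.default_rng(0)
P, L, T = 4, 16, 100                      # channels, atom length, signal length (toy sizes)

u_k = rng.standard_normal(P)
u_k /= np.linalg.norm(u_k)                # spatial pattern, unit norm
v_k = np.sin(2 * np.pi * np.arange(L) / L)  # temporal pattern (one sine period)
D_k = np.outer(u_k, v_k)                  # rank-1 atom D_k = u_k v_k^T, shape (P, L)

# Sparse nonnegative activation with T - L + 1 possible onsets
z_k = np.zeros(T - L + 1)
z_k[[10, 50]] = [1.0, 0.5]

# Multichannel signal z_k * D_k: convolve the activation with every channel row
X_toy = np.array([np.convolve(z_k, D_k[p]) for p in range(P)])  # shape (P, T)

# Same signal from the factorized form: u_k times the 1-D convolution z_k * v_k
X_factored = np.outer(u_k, np.convolve(z_k, v_k))

print(np.linalg.matrix_rank(D_k), X_toy.shape, np.allclose(X_toy, X_factored))
```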

In [None]:
# Parameters
sfreq = 150.0  # Hz
n_atoms = 25   # K (paper used 40)
n_times_atom = int(sfreq * 1.0)  # 1 second
reg = 0.2      # lambda = 0.2 * lambda_max
n_iter = 100
n_jobs = 6

print("Loading MEG data (somatosensory dataset)...")
print("Note: ~600MB download on first run")
X, info = load_data(dataset='somato', epoch=(-2, 4), sfreq=sfreq)
print(f"Data shape: {X.shape} (trials, channels, time)")

In [None]:
print("\nFitting BatchCDL with rank-1 constraint...")
cdl = BatchCDL(
    n_atoms=n_atoms,
    n_times_atom=n_times_atom,
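    rank1=True,                     # KEY: each atom is D_k = u_k v_k^T (spatial x temporal)
    uv_constraint='separate',       # ||u_k|| <= 1 and ||v_k|| <= 1, constrained separately
    D_init='chunk',                 # initialize atoms from random chunks of the data
    lmbd_max="scaled",              # reg is interpreted as a fraction of lambda_max
    reg=reg,
    n_iter=n_iter,
    solver_z="lgcd",                # z-step: locally greedy coordinate descent
    solver_d='alternate_adaptive',  # d-step: alternate between u and v updates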
    verbose=1,
    random_state=0,
    n_jobs=n_jobs
)
cdl.fit(X)
print(f"\nLearned u shape: {cdl.u_hat_.shape}")
print(f"Learned v shape: {cdl.v_hat_.shape}")
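
With `lmbd_max="scaled"`, `reg = 0.2` is read as a fraction of $\lambda_{max}$, the smallest penalty for which the all-zero activation solves the z-step. For the objective $\frac{1}{2}\|X_n - \sum_k z_k^n * D_k\|_2^2 + \lambda \sum_k \|z_k^n\|_1$ with $z \ge 0$, the gradient of the data term at $z = 0$ is minus the correlation of each atom with the data, so $z = 0$ is optimal exactly when every such correlation is at most $\lambda$. The sketch below computes that bound for rank-1 atoms. It assumes nonnegative activations, and `lambda_max_rank1` is a hypothetical helper, not alphacsc's internal function; the library's exact computation (for example, which dictionary it uses) may differ.

```python
import numpy as np


def lambda_max_rank1(X, u, v):
    """Smallest reg for which z = 0 solves the nonnegative CSC z-step.

    X: (n_trials, n_channels, n_times); u: (n_atoms, n_channels);
    v: (n_atoms, n_times_atom). For D_k = u_k v_k^T, correlating D_k with X_n
    equals correlating v_k with the spatially filtered signal u_k^T X_n.
    """
    lmax = -np.inf
    for X_n in X:
        for u_k, v_k in zip(u, v):
            proj = u_k @ X_n  # spatial projection, shape (n_times,)
            lmax = max(lmax, np.correlate(proj, v_k, mode='valid').max())
    return float(lmax)
```

Any `reg` at or above this value makes every activation zero, which is why the regularization is specified as a ratio below 1.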

In [None]:
# Find mu-rhythm atom (strongest power in 8-12 Hz)
mu_band = (8, 12)
best_atom, best_power = None, 0

for i in range(n_atoms):
    v = cdl.v_hat_[i]
    psd = np.abs(np.fft.rfft(v)) ** 2
    freqs = np.fft.rfftfreq(len(v), 1/sfreq)
    mask = (freqs >= mu_band[0]) & (freqs <= mu_band[1])
    mu_power = np.sum(psd[mask])
    if mu_power > best_power:
        best_power = mu_power
        best_atom = i

print(f"Best mu-rhythm candidate: Atom {best_atom}")

In [None]:
# Visualize mu-rhythm (Figure 4 from paper)
u_mu = cdl.u_hat_[best_atom]
v_mu = cdl.v_hat_[best_atom]

fig, axes = plt.subplots(1, 3, figsize=(14, 4))

# Temporal waveform
ax = axes[0]
t = np.arange(len(v_mu)) / sfreq
ax.plot(t, v_mu, 'b-', linewidth=1.5)
ax.set_xlabel('Time (s)')
ax.set_ylabel('Amplitude')
ax.set_title('(a) Temporal Waveform (Mu-rhythm)')
ax.grid(True, alpha=0.3)

# Spatial pattern
ax = axes[1]
mne.viz.plot_topomap(u_mu, info, axes=ax, show=False)
ax.set_title('(b) Spatial Pattern')

# PSD
ax = axes[2]
psd = np.abs(np.fft.rfft(v_mu)) ** 2
psd_db = 10 * np.log10(psd + 1e-10)
freqs = np.fft.rfftfreq(len(v_mu), 1/sfreq)
ax.plot(freqs, psd_db, 'b-')
ax.axvline(10, color='r', linestyle='--', alpha=0.5, label='~10 Hz')
ax.axvline(20, color='orange', linestyle='--', alpha=0.5, label='~20 Hz (harmonic)')
ax.set_xlabel('Frequency (Hz)')
ax.set_ylabel('Power (dB)')
ax.set_title('(c) Power Spectral Density')
ax.set_xlim(0, 30)
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\nExpected: a non-sinusoidal 'comb' waveform in (a) and a harmonic peak near 20 Hz in (c).")

## 4. Experiment 3: Multi-Band Frequency Analysis (New)

**Original contribution:** Examine all atoms across frequency bands and measure non-sinusoidality.

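**Harmonic ratio:** for an atom whose Welch PSD $P(f)$ peaks at $f_0$, $H = \big(\max_{|f - 2f_0| \le 2} P(f) + \max_{|f - 3f_0| \le 2} P(f)\big) / P(f_0)$, with frequencies in Hz, harmonics above the Nyquist frequency skipped, and $H = 0$ when $f_0 < 2$ Hz. A pure sinusoid gives $H \approx 0$; higher values indicate more non-sinusoidal waveforms.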
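
A minimal sanity check of what $H$ measures, on toy waveforms sampled at 150 Hz for exactly 1 s so that 10, 20 and 30 Hz fall on FFT bins. `harmonic_power_ratio` is a simplified, hypothetical variant of the `harmonic_ratio` function defined below: it reads power at the exact harmonic bins instead of taking a ±2 Hz maximum over a Welch PSD, so values on real atoms will differ. The "comb" waveform is an assumed stand-in built from a 10 Hz fundamental plus a 20 Hz harmonic.

```python
import numpy as np

fs = 150.0
time = np.arange(int(fs)) / fs            # 1 s: every integer frequency is an FFT bin
f_fund = 10.0

sine = np.cos(2 * np.pi * f_fund * time)
comb = sine + 0.5 * np.cos(2 * np.pi * 2 * f_fund * time)  # toy non-sinusoidal 10 Hz rhythm


def harmonic_power_ratio(x, fs, f_fund):
    """(Power at 2*f_fund + power at 3*f_fund) / power at f_fund, exact FFT bins."""
    psd = np.abs(np.fft.rfft(x)) ** 2
    freqs = np.fft.rfftfreq(len(x), 1 / fs)

    def power_at(f):
        return psd[np.argmin(np.abs(freqs - f))]

    return (power_at(2 * f_fund) + power_at(3 * f_fund)) / power_at(f_fund)


print(f"sine: H = {harmonic_power_ratio(sine, fs, f_fund):.3f}")
print(f"comb: H = {harmonic_power_ratio(comb, fs, f_fund):.3f}")
```

The sine gives $H = 0$ and the comb gives $H = 0.25$ ($= 0.5^2$): a single non-sinusoidal 10 Hz rhythm already puts power at 20 Hz, inside the beta band.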

In [None]:
# Define frequency bands
FREQ_BANDS = {
    'Theta': (4, 8),
    'Alpha/Mu': (8, 12),
    'Beta': (15, 30)
}

def classify_atom(v, sfreq):
    """Return (band, peak_freq) for temporal atom v.

    band is the FREQ_BANDS entry with the largest summed Welch power. Frequencies
    outside the bands (e.g. 12-15 Hz) are not counted, so peak_freq can fall
    outside the chosen band.
    """
    freqs, psd = signal.welch(v, fs=sfreq, nperseg=min(len(v), 128))
    band_powers = {}
    for name, (f_low, f_high) in FREQ_BANDS.items():
        mask = (freqs >= f_low) & (freqs <= f_high)
        band_powers[name] = np.sum(psd[mask])
    return max(band_powers, key=band_powers.get), freqs[np.argmax(psd)]

def harmonic_ratio(v, sfreq):
    """Non-sinusoidality H: peak PSD within +/-2 Hz of 2*f0 and 3*f0, over PSD at f0.

    f0 is the Welch peak frequency; returns 0.0 when f0 < 2 Hz or the peak is negligible.
    """
    freqs, psd = signal.welch(v, fs=sfreq, nperseg=min(len(v), 128))
    peak_idx = np.argmax(psd)
    f0, p0 = freqs[peak_idx], psd[peak_idx]
    if f0 < 2 or p0 < 1e-10:
        return 0.0
    h_power = 0.0
    for n in [2, 3]:            # 2nd and 3rd harmonics
        fn = n * f0
        if fn < freqs[-1]:      # skip harmonics at or above the Nyquist frequency
            mask = (freqs >= fn - 2) & (freqs <= fn + 2)
            if np.any(mask):
                h_power += np.max(psd[mask])
    return h_power / (p0 + 1e-10)

In [None]:
# Classify all atoms
atom_data = []
for i in range(n_atoms):
    v = cdl.v_hat_[i]
    band, peak_f = classify_atom(v, sfreq)
    H = harmonic_ratio(v, sfreq)
    atom_data.append({'idx': i, 'band': band, 'peak_freq': peak_f, 'H': H})
    print(f"Atom {i:2d}: {band:12s} (f={peak_f:5.1f} Hz, H={H:.3f})")

# Count by band
print("\n--- Summary ---")
for band in FREQ_BANDS:
    count = sum(1 for a in atom_data if a['band'] == band)
    print(f"{band}: {count} atoms")

In [None]:
# Visualize harmonicity
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Scatter: peak freq vs harmonic ratio
ax = axes[0]
for band in FREQ_BANDS:
    atoms = [a for a in atom_data if a['band'] == band]
    if atoms:
        ax.scatter([a['peak_freq'] for a in atoms], [a['H'] for a in atoms], 
                   label=band, s=80, alpha=0.7)
ax.set_xlabel('Peak Frequency (Hz)')
ax.set_ylabel('Harmonic Ratio')
ax.set_title('Peak Frequency vs Non-Sinusoidality')
ax.legend()
ax.grid(True, alpha=0.3)

# Histogram of H values
ax = axes[1]
H_values = [a['H'] for a in atom_data]
ax.hist(H_values, bins=15, edgecolor='black', alpha=0.7)
ax.axvline(np.mean(H_values), color='r', linestyle='--', label=f'Mean: {np.mean(H_values):.3f}')
ax.set_xlabel('Harmonic Ratio')
ax.set_ylabel('Count')
ax.set_title('Distribution of Non-Sinusoidality')
ax.legend()

plt.tight_layout()
plt.show()

## 5. Summary

### Key Findings

1. **CSC recovers ground-truth atoms** on synthetic data, up to sign and permutation (Exp 1)

2. **The mu rhythm has a non-sinusoidal waveform:** its characteristic "comb" shape produces a harmonic at ~20 Hz that could be mistaken for beta activity (Exp 2)

3. **Within-band waveform variety:** even atoms in the same frequency band show different degrees of non-sinusoidality (Exp 3)

4. **The rank-1 constraint enables source localization:** the spatial patterns $u_k$ can be used for dipole fitting (as in the paper; not run in this notebook)

### Implications

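- Power spectra alone do not separate mu from alpha: both peak near 10 Hz, so the spatial pattern and waveform shape are needed to tell them apart
- CSC reveals waveform shapes that frequency-only analysis hides, because power spectra discard phase information
- Harmonic peaks in a PSD can arise from a single non-sinusoidal waveform rather than from independent oscillations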
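
These implications rest on the fact that a power spectrum keeps only the magnitude of each Fourier coefficient and discards its phase. A toy illustration with made-up 10 Hz waveforms (not MEG data): shifting only the phase of the 20 Hz harmonic changes the waveform shape but leaves the power spectrum unchanged, so band power alone cannot distinguish the two.

```python
import numpy as np

fs = 150.0
time = np.arange(int(fs)) / fs            # 1 s, integer number of cycles
theta = 2 * np.pi * 10.0 * time           # phase of the 10 Hz fundamental

# Same harmonic amplitudes; only the phase of the 20 Hz harmonic differs
x_peaky = np.cos(theta) + 0.5 * np.cos(2 * theta)               # tall sharp peaks, shallow troughs
x_skewed = np.cos(theta) + 0.5 * np.cos(2 * theta + np.pi / 2)  # a differently shaped waveform

psd_peaky = np.abs(np.fft.rfft(x_peaky)) ** 2
psd_skewed = np.abs(np.fft.rfft(x_skewed)) ** 2

print("identical power spectra:", np.allclose(psd_peaky, psd_skewed))
print("max waveform difference:", np.max(np.abs(x_peaky - x_skewed)))
```

The first line prints `True`, while the waveforms differ by about 0.7 at some samples: the shape information lives in the phases, which is what a waveform-learning method such as CSC keeps.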