# Demixed PLS (dPLS) Tutorial

## What is Demixed PLS?

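**Demixed PLS (dPLS)** combines two ideas. From demixed PCA (dPCA) it takes marginalization: splitting the data into parts that depend on time, on condition, or on their interaction. From Partial Least Squares (PLS) it takes the supervised objective: finding directions in X that covary with a target Y.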
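
To make "marginalization" concrete, here is a minimal NumPy sketch of the usual dPCA-style decomposition for an array shaped `[n_features, n_timepoints, n_conditions]`. The function name `marginalize` is hypothetical, and whether `src.dpca` computes its marginalizations exactly this way is an assumption.

```python
import numpy as np

def marginalize(X):
    """Split X [features, time, conditions] into additive parts (illustrative helper)."""
    Xc = X - X.mean(axis=(1, 2), keepdims=True)   # remove each feature's grand mean
    X_time = Xc.mean(axis=2, keepdims=True)       # depends on time only
    X_cond = Xc.mean(axis=1, keepdims=True)       # depends on condition only
    X_inter = Xc - X_time - X_cond                # what neither marginal explains
    return Xc, {"time": X_time, "condition": X_cond, "interaction": X_inter}

rng = np.random.default_rng(0)
X_demo = rng.standard_normal((3, 50, 4))
Xc_demo, parts = marginalize(X_demo)

# The three parts add back up to the centered data (broadcasting fills the singleton axes)
recon = parts["time"] + parts["condition"] + parts["interaction"]
print(np.allclose(recon, Xc_demo))
```

Each method then looks for components within one of these parts at a time, which is what "per marginalization" means in the comparison that follows.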

## dPCA vs dPLS: Key Differences

| Aspect | dPCA | dPLS |
|--------|------|------|
| **Objective** | Maximize **variance** per marginalization | Maximize **covariance with Y** per marginalization |
| **Supervision** | Unsupervised (no target) | Supervised (uses mental scores) |
| **Question answered** | "What patterns exist in the data?" | "What patterns predict mental state?" |
| **Components** | Explain most variance | Explain most covariance with target |
| **Best for** | Exploratory analysis | Prediction, association finding |

## Mathematical Comparison

```
dPCA:  Find W (unit-norm columns) that maximizes Var(W'X_φ) for each marginalization φ
dPLS:  Find W (unit-norm columns) that maximizes Cov(W'X_φ, Y) for each marginalization φ
```
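
The difference between the two objectives shows up even without marginalization. The following sketch (toy data and variable names are illustrative, not part of the package) builds one high-variance feature unrelated to `y` and one low-variance feature driven by `y`. It then compares the first PCA direction with the first PLS weight vector, which for a single target is proportional to `X'y` on centered data.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 2000
y = rng.standard_normal(n)
x_big = 3.0 * rng.standard_normal(n)              # high variance, unrelated to y
x_small = 0.5 * y + 0.1 * rng.standard_normal(n)  # low variance, driven by y
X = np.column_stack([x_big, x_small])

Xc = X - X.mean(axis=0)
yc = y - y.mean()

# Max-variance direction: top eigenvector of X'X (eigh sorts eigenvalues ascending)
_, evecs = np.linalg.eigh(Xc.T @ Xc)
w_pca = evecs[:, -1]

# Max-covariance direction for a single target: normalized X'y
w_pls = Xc.T @ yc
w_pls /= np.linalg.norm(w_pls)

print("PCA |weights| [x_big, x_small]:", np.round(np.abs(w_pca), 2))
print("PLS |weights| [x_big, x_small]:", np.round(np.abs(w_pls), 2))
```

PCA puts nearly all its weight on the high-variance feature, while the PLS direction favors the feature that actually covaries with `y`.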

## When to Use Which?

- **dPCA**: Explore what patterns exist in gait data across mental states
- **dPLS**: Find gait patterns that specifically predict mental scores

In [None]:
# Setup
import sys
from pathlib import Path
project_root = Path.cwd().parent
sys.path.insert(0, str(project_root))

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from src.dpca import DemixedPCA, DemixedPLS

%matplotlib inline
plt.rcParams['figure.figsize'] = (12, 8)
plt.rcParams['font.size'] = 12
print("Setup complete!")

## 1. Generate Data with Known Mental Score Structure

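We'll create synthetic gait data in which four of the 15 features (`hip_flexion`, `knee_flexion`, `stride_length`, `com_velocity`) are built to track the mental scores, while the remaining features carry only a weak condition trend.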
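
As a quick check of the design used below, the two kinds of condition offsets can be correlated with the mental scores directly, before any time structure or noise is added. This is a small standalone sketch; the variable names are illustrative.

```python
import numpy as np

mental_scores = np.array([5.0, 3.0, 7.0, 6.0, 4.0])
cond_index = np.arange(len(mental_scores))

# Offset for score-related features: a scaled, centered copy of the scores
related_offset = 0.5 * (mental_scores - mental_scores.mean())
# Offset for the remaining features: a linear trend in the condition index
other_offset = 0.1 * (cond_index - 2)

r_related = np.corrcoef(related_offset, mental_scores)[0, 1]
r_other = np.corrcoef(other_offset, mental_scores)[0, 1]
print(f"score-related offset: r = {r_related:.2f}")  # 1.00
print(f"other offset:         r = {r_other:.2f}")    # 0.10
```

So the remaining features are not completely unrelated to the scores: their offset has a weak correlation of 0.1 with them.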

In [None]:
# Generate synthetic data
np.random.seed(42)

n_features = 15
n_timepoints = 100
n_conditions = 5

# Labels
feature_labels = [
    'hip_flexion', 'hip_abduction', 'knee_flexion', 'ankle_dorsiflexion',
    'pelvis_tilt', 'pelvis_obliquity', 'trunk_flexion', 'trunk_rotation',
    'stride_length', 'step_width', 'cadence', 'grf_vertical', 
    'grf_anterior', 'grf_lateral', 'com_velocity'
]
condition_labels = ['neutral', 'anxious', 'relaxed', 'focused', 'fatigued']

# Mental scores (Y) for each condition
# Higher = better wellbeing
mental_scores = np.array([5.0, 3.0, 7.0, 6.0, 4.0])  # neutral, anxious, relaxed, focused, fatigued

print("=== Mental Scores per Condition ===")
for label, score in zip(condition_labels, mental_scores):
    bar = "█" * int(score * 3)
    print(f"  {label:12s}: {score:.1f} {bar}")

In [None]:
# Generate gait data X: [n_features, n_timepoints, n_conditions]
t = np.linspace(0, 2*np.pi, n_timepoints)
data = np.zeros((n_features, n_timepoints, n_conditions))

# Define which features are associated with mental scores
# Features 0, 2, 8, 14 (hip_flexion, knee_flexion, stride_length, com_velocity) will correlate with wellbeing
score_related_features = [0, 2, 8, 14]

for f in range(n_features):
    for c in range(n_conditions):
        # Time component (gait cycle)
        time_component = np.sin(t + f * 0.3) + 0.5 * np.cos(2*t + f * 0.2)
        
        # Condition component
        if f in score_related_features:
            # These features are STRONGLY correlated with mental score
            condition_effect = 0.5 * (mental_scores[c] - mental_scores.mean())
        else:
            # Linear trend in condition index: only weakly correlated with mental score (r = 0.1)
            condition_effect = 0.1 * (c - 2)
        
        # Time-by-condition interaction (10x weaker for features 5 and above)
        interaction = 0.2 * np.sin(t + c * 0.5) * (1 if f < 5 else 0.1)
        
        # Noise
        noise = np.random.randn(n_timepoints) * 0.15
        
        data[f, :, c] = time_component + condition_effect + interaction + noise

print(f"Data shape: {data.shape}")
print(f"  X: [n_features={n_features}, n_timepoints={n_timepoints}, n_conditions={n_conditions}]")
print(f"  Y: [n_conditions={n_conditions}]")
print(f"\nFeatures correlated with mental score: {[feature_labels[i] for i in score_related_features]}")

## 2. Compare dPCA and dPLS

Let's fit both models and compare their results.

In [None]:
# Fit dPCA (unsupervised)
dpca = DemixedPCA(n_components=5, regularizer='auto')
dpca.fit(data, feature_labels=feature_labels)

# Fit dPLS (supervised with mental scores)
dpls = DemixedPLS(n_components=5, regularizer='auto')
dpls.fit(data, mental_scores, feature_labels=feature_labels, condition_labels=condition_labels)

print("=== Model Fitting Complete ===")
print(f"\ndPCA: Unsupervised (no mental scores used)")
print(f"dPLS: Supervised (mental scores = {mental_scores})")

### Compare Feature Importance

**dPCA**: Which features explain the most variance?  
**dPLS**: Which features best predict mental scores?
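
The dPLS panel below is labeled as a VIP (Variable Importance in Projection) score. For reference, here is a minimal NumPy sketch of VIP for single-target PLS fitted with NIPALS-style deflation. The function `pls1_vip` and the toy data are hypothetical, and whether `get_feature_importance` computes VIP this way is an assumption.

```python
import numpy as np

def pls1_vip(X, y, n_components=2):
    """VIP scores from a single-target PLS fit (illustrative, not the package's code)."""
    Xk = X - X.mean(axis=0)
    yk = y - y.mean()
    n_feat = X.shape[1]
    W = np.zeros((n_feat, n_components))   # unit-norm weight vectors
    ss = np.zeros(n_components)            # y-variance explained per component
    for a in range(n_components):
        w = Xk.T @ yk
        w /= np.linalg.norm(w)
        t = Xk @ w                         # scores
        tt = t @ t
        p_load = Xk.T @ t / tt             # X loadings
        q = (yk @ t) / tt                  # y loading
        Xk = Xk - np.outer(t, p_load)      # deflate X
        yk = yk - q * t                    # deflate y
        W[:, a] = w
        ss[a] = q ** 2 * tt
    return np.sqrt(n_feat * (W ** 2 @ ss) / ss.sum())

rng = np.random.default_rng(0)
X_toy = rng.standard_normal((200, 6))
y_toy = 2.0 * X_toy[:, 0] - X_toy[:, 1] + 0.1 * rng.standard_normal(200)
vip = pls1_vip(X_toy, y_toy)
print(np.round(vip, 2))
```

With this definition the squared VIP scores average to 1 across features, which is why values above 1 are commonly read as above-average importance.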

In [None]:
# Compare feature weights
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# dPCA: encoder weights for condition marginalization
ax1 = axes[0]
dpca_weights = dpca._encoder['condition'][0, :]  # First component (note: _encoder is a private attribute)
sorted_idx_pca = np.argsort(np.abs(dpca_weights))[::-1]
sorted_weights_pca = dpca_weights[sorted_idx_pca]
sorted_labels_pca = [feature_labels[i] for i in sorted_idx_pca]

colors_pca = ['#ff6b6b' if i in score_related_features else '#4dabf7' for i in sorted_idx_pca]
ax1.barh(range(len(sorted_weights_pca)), sorted_weights_pca, color=colors_pca, alpha=0.8, edgecolor='black')
ax1.set_yticks(range(len(sorted_weights_pca)))
ax1.set_yticklabels(sorted_labels_pca)
ax1.set_xlabel('Weight')
ax1.set_title('dPCA: Condition PC1 Weights\n(Maximizes VARIANCE)', fontsize=12, fontweight='bold')
ax1.axvline(x=0, color='gray', lw=1)
ax1.grid(True, axis='x', alpha=0.3)

# dPLS: PLS weights for condition marginalization
ax2 = axes[1]
dpls_importance = dpls.get_feature_importance('condition')
dpls_weights = np.array([dpls_importance[f] for f in feature_labels])
sorted_idx_pls = np.argsort(dpls_weights)[::-1]
sorted_weights_pls = dpls_weights[sorted_idx_pls]
sorted_labels_pls = [feature_labels[i] for i in sorted_idx_pls]

colors_pls = ['#ff6b6b' if i in score_related_features else '#4dabf7' for i in sorted_idx_pls]
ax2.barh(range(len(sorted_weights_pls)), sorted_weights_pls, color=colors_pls, alpha=0.8, edgecolor='black')
ax2.set_yticks(range(len(sorted_weights_pls)))
ax2.set_yticklabels(sorted_labels_pls)
ax2.set_xlabel('VIP Score (Importance)')
ax2.set_title('dPLS: Feature Importance\n(Maximizes COVARIANCE with Y)', fontsize=12, fontweight='bold')
ax2.grid(True, axis='x', alpha=0.3)

plt.suptitle('Feature Ranking: dPCA vs dPLS\n(Red = Ground truth score-related features)', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

print("Legend: Red bars = features we designed to correlate with mental scores")
print("\nObservation:")
print("  - dPLS should rank score-related features higher than dPCA")
print("  - dPCA finds features with most variance (not necessarily predictive)")

## 3. Prediction: dPLS Can Predict Mental Scores

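Unlike dPCA, dPLS is designed for prediction. The cell below predicts the same five conditions the model was fitted on, so the reported correlation measures in-sample fit, not generalization to new conditions.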
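
One way to estimate how well such predictions generalize is leave-one-condition-out: refit on four conditions and predict the held-out one. The sketch below uses a one-component NumPy PLS on condition-averaged toy features as a stand-in for `DemixedPLS`, whose cross-validation interface (if any) is not shown in this tutorial. The helper `pls1_fit_predict` and the toy matrix `X_cond` are hypothetical.

```python
import numpy as np

def pls1_fit_predict(X_train, y_train, X_test):
    """One-component PLS regression (illustrative stand-in, not DemixedPLS)."""
    x_mean, y_mean = X_train.mean(axis=0), y_train.mean()
    Xc, yc = X_train - x_mean, y_train - y_mean
    w = Xc.T @ yc
    w /= np.linalg.norm(w)            # direction of maximal covariance with y
    t = Xc @ w                        # training scores
    q = (yc @ t) / (t @ t)            # regress y on the scores
    return y_mean + q * ((X_test - x_mean) @ w)

rng = np.random.default_rng(0)
scores = np.array([5.0, 3.0, 7.0, 6.0, 4.0])
n_cond, n_feat = len(scores), 15
related = [0, 2, 8, 14]

# Toy condition-averaged features, shape [n_conditions, n_features]
X_cond = 0.1 * (np.arange(n_cond)[:, None] - 2) * np.ones((1, n_feat))
X_cond[:, related] = 0.5 * (scores - scores.mean())[:, None]
X_cond += 0.02 * rng.standard_normal(X_cond.shape)

loo_pred = np.empty(n_cond)
for i in range(n_cond):
    train = np.arange(n_cond) != i    # boolean mask that drops condition i
    loo_pred[i] = pls1_fit_predict(X_cond[train], scores[train], X_cond[i:i + 1])[0]

print("held-out predictions:", np.round(loo_pred, 2))
print("LOO correlation:", round(float(np.corrcoef(scores, loo_pred)[0, 1]), 3))
```

With only five conditions, any such estimate is still noisy, but it at least never scores a condition with a model that has seen it.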

In [None]:
# dPLS can predict mental scores
predicted_scores = dpls.predict(data, marginalization='condition')

fig, ax = plt.subplots(figsize=(8, 6))

colors = ['#7f7f7f', '#d62728', '#2ca02c', '#1f77b4', '#ff7f0e']
for i, (label, color) in enumerate(zip(condition_labels, colors)):
    ax.scatter(mental_scores[i], predicted_scores[i], c=color, s=200, 
               edgecolors='black', label=label, zorder=5)

# Add diagonal line
ax.plot([2, 8], [2, 8], 'k--', lw=2, alpha=0.5, label='Perfect prediction')

# Calculate correlation
corr = np.corrcoef(mental_scores, predicted_scores)[0, 1]

ax.set_xlabel('Actual Mental Score', fontsize=12)
ax.set_ylabel('Predicted Mental Score (dPLS)', fontsize=12)
ax.set_title(f'dPLS Prediction of Mental Scores\nCorrelation: r = {corr:.3f}', fontsize=14, fontweight='bold')
ax.legend()
ax.grid(True, alpha=0.3)
ax.set_xlim(2, 8)
ax.set_ylim(2, 8)

plt.tight_layout()
plt.show()

print("dPLS uses gait patterns to predict mental scores.")
print("dPCA cannot do this - it doesn't know about mental scores!")

## 4. Component Comparison: Time vs Condition

Both dPCA and dPLS separate time and condition components, but with different objectives.
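
One way to judge how well a component is demixed is to ask how much of its variance comes from time, from condition, and from their interaction. The sketch below does this for a single component stored as a `[n_timepoints, n_conditions]` array, which is the layout the plotting code in this section appears to assume for slices such as `Z_pca['time'][0]`. The helper `variance_split` is hypothetical.

```python
import numpy as np

def variance_split(z):
    """Fractions of a component's variance due to time, condition, and interaction."""
    zc = z - z.mean()
    z_time = zc.mean(axis=1, keepdims=True)   # average over conditions
    z_cond = zc.mean(axis=0, keepdims=True)   # average over time
    z_inter = zc - z_time - z_cond
    total = np.sum(zc ** 2)
    n_t, n_c = z.shape
    return {
        "time": n_c * np.sum(z_time ** 2) / total,
        "condition": n_t * np.sum(z_cond ** 2) / total,
        "interaction": np.sum(z_inter ** 2) / total,
    }

# Toy component: a shared gait-cycle waveform plus small per-condition offsets
t = np.linspace(0, 2 * np.pi, 100)
offsets = np.array([-1.0, 0.0, 1.0])
z_demo = np.sin(t)[:, None] + 0.1 * offsets[None, :]
fractions = variance_split(z_demo)
print({k: round(float(v), 3) for k, v in fractions.items()})
```

A well-demixed time component should have a time fraction close to 1, and a well-demixed condition component a condition fraction close to 1.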

In [None]:
# Transform data with both methods
Z_pca = dpca.transform(data)
Z_pls = dpls.transform(data)

time_axis = np.linspace(0, 100, n_timepoints)
colors = ['#7f7f7f', '#d62728', '#2ca02c', '#1f77b4', '#ff7f0e']

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# dPCA Time Component
ax1 = axes[0, 0]
for c_idx, (state, color) in enumerate(zip(condition_labels, colors)):
    ax1.plot(time_axis, Z_pca['time'][0, :, c_idx], label=state, color=color, lw=2, alpha=0.7)
ax1.set_xlabel('Gait Cycle (%)')
ax1.set_ylabel('Time PC1')
ax1.set_title('dPCA: Time Component\n(Captures gait cycle variance)', fontsize=12, fontweight='bold')
ax1.legend()
ax1.grid(True, alpha=0.3)

# dPCA Condition Component
ax2 = axes[0, 1]
for c_idx, (state, color) in enumerate(zip(condition_labels, colors)):
    ax2.plot(time_axis, Z_pca['condition'][0, :, c_idx], label=state, color=color, lw=2)
ax2.set_xlabel('Gait Cycle (%)')
ax2.set_ylabel('Condition PC1')
ax2.set_title('dPCA: Condition Component\n(Captures condition variance)', fontsize=12, fontweight='bold')
ax2.legend()
ax2.grid(True, alpha=0.3)

# dPLS Time Component
ax3 = axes[1, 0]
for c_idx, (state, color) in enumerate(zip(condition_labels, colors)):
    ax3.plot(time_axis, Z_pls['time'][0, :, c_idx], label=state, color=color, lw=2, alpha=0.7)
ax3.set_xlabel('Gait Cycle (%)')
ax3.set_ylabel('Time PLS1')
ax3.set_title('dPLS: Time Component\n(Captures time patterns predicting Y)', fontsize=12, fontweight='bold')
ax3.legend()
ax3.grid(True, alpha=0.3)

# dPLS Condition Component
ax4 = axes[1, 1]
for c_idx, (state, color) in enumerate(zip(condition_labels, colors)):
    ax4.plot(time_axis, Z_pls['condition'][0, :, c_idx], label=state, color=color, lw=2)
ax4.set_xlabel('Gait Cycle (%)')
ax4.set_ylabel('Condition PLS1')
ax4.set_title('dPLS: Condition Component\n(Captures condition patterns predicting Y)', fontsize=12, fontweight='bold')
ax4.legend()
ax4.grid(True, alpha=0.3)

plt.suptitle('Demixed Components: dPCA vs dPLS', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## 5. Summary

### dPCA vs dPLS Comparison

| | dPCA | dPLS |
|---|------|------|
| **Objective** | Max variance | Max covariance with Y |
| **Input** | X only | X and Y |
| **Output** | Variance-explaining components | Predictive components |
| **Feature selection** | High-variance features | Predictive features |
| **Can predict Y?** | ❌ No | ✅ Yes |

### When to Use Each

**Use dPCA when:**
- You want to explore data structure
- You don't have target labels/scores
- You want to understand what patterns exist

**Use dPLS when:**
- You have mental scores to predict
- You want to find predictive features
- You need a predictive model

### Data Requirements

Both methods require:
```python
assert X.shape == (n_features, n_timepoints, n_conditions)
```

dPLS additionally requires:
```python
assert Y.shape in [(n_conditions,), (n_conditions, n_targets)]
```
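
A small validation helper can catch shape mismatches before fitting. This is only a sketch: `check_shapes` is a hypothetical function, not part of `src.dpca`.

```python
import numpy as np

def check_shapes(X, Y=None):
    """Validate X [features, time, conditions] and, optionally, Y with one row per condition."""
    X = np.asarray(X)
    if X.ndim != 3:
        raise ValueError(f"X must be 3-D [features, time, conditions], got {X.ndim}-D")
    if Y is not None:
        Y = np.asarray(Y)
        if Y.ndim not in (1, 2) or Y.shape[0] != X.shape[2]:
            raise ValueError(
                f"Y must have {X.shape[2]} rows (one per condition), got shape {Y.shape}"
            )
    return X, Y

X_ok = np.zeros((15, 100, 5))
check_shapes(X_ok, np.zeros(5))        # passes: one score per condition
check_shapes(X_ok, np.zeros((5, 2)))   # passes: two targets per condition
```

Passing a `Y` with four entries for five conditions would raise a `ValueError` that names the expected number of rows.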

### Available in This Package

```python
from src.dpca import DemixedPCA, DemixedPLS

# Unsupervised
dpca = DemixedPCA(n_components=5)
dpca.fit(X)

# Supervised
dpls = DemixedPLS(n_components=5)
dpls.fit(X, Y)
Y_pred = dpls.predict(X_new)
```