# MMD Embedding Drift Detection - Live Demo
## Detecting Semantic Shifts in BERT Embedding Space

**Detector:** MMD (Maximum Mean Discrepancy)  
**Purpose:** Detect embedding space drift between baseline and current distributions  
**Algorithm:** Statistical two-sample test using kernel functions  
**Runtime:** ~20 seconds

---

## Setup

In [None]:
import sys
sys.path.append('..')

import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from fed_drift.drift_detection import MMDDriftDetector

# Set random seed for reproducibility
np.random.seed(42)

print("✅ Setup complete!")
print("📦 Imported MMDDriftDetector from fed_drift.drift_detection")

## Scenario: BERT Embeddings from AG News

We'll simulate BERT-tiny embeddings (128 dimensions) for two scenarios:
- **Baseline:** Normal AG News text embeddings
- **Drifted:** Embeddings after vocabulary shift (synonym replacement)

In production, these would be actual BERT embeddings. For the demo, we generate synthetic Gaussian embeddings that mimic the drift pattern: the drifted sample has a shifted mean and a larger variance than the baseline.

In [None]:
# Generate synthetic BERT-like embeddings
embedding_dim = 128
n_samples = 100

# Baseline embeddings: centered around origin with moderate spread
baseline_embeddings = np.random.randn(n_samples, embedding_dim) * 0.5

# Drifted embeddings: shifted mean + increased variance (vocabulary drift effect)
drift_shift = np.random.randn(embedding_dim) * 0.3  # Systematic shift
drifted_embeddings = np.random.randn(n_samples, embedding_dim) * 0.7 + drift_shift

print(f"📊 Generated BERT-like embeddings:")
print(f"   Embedding dimension: {embedding_dim}")
print(f"   Baseline samples: {n_samples}")
print(f"   Drifted samples: {n_samples}")
print(f"\n   Baseline mean norm: {np.linalg.norm(baseline_embeddings.mean(axis=0)):.4f}")
print(f"   Drifted mean norm: {np.linalg.norm(drifted_embeddings.mean(axis=0)):.4f}")
print(f"   Mean shift (L2 norm): {np.linalg.norm(drifted_embeddings.mean(axis=0) - baseline_embeddings.mean(axis=0)):.4f}")

## Run MMD Test

In [None]:
# Initialize MMD detector with p-value threshold of 0.05
detector = MMDDriftDetector(p_val_threshold=0.05, n_permutations=100)

print("⏱️  Running MMD drift test...\n")

# Run test
p_value = detector.test(baseline_embeddings, drifted_embeddings)
drift_detected = p_value < 0.05

print(f"🧪 MMD Test Results:")
print(f"   P-value: {p_value:.6f}")
print(f"   Threshold: 0.05")
print(f"   Decision: {'🚨 DRIFT DETECTED' if drift_detected else '✅ NO DRIFT'}")
print(f"\n   Interpretation:")
if drift_detected:
    # A p-value is not a confidence level, so report the rejection at the chosen threshold instead
    print(f"   Reject the null hypothesis of identical distributions at the 0.05 level (p = {p_value:.4f}).")
    print(f"   Embedding space has shifted significantly - vocabulary drift likely occurred.")
else:
    print(f"   Cannot reject null hypothesis - distributions appear similar.")
    print(f"   No significant embedding drift detected.")

## Visualization: Embedding Space

In [None]:
# Create comprehensive visualization
fig = plt.figure(figsize=(16, 6))

# Plot 1: PCA projection (2D)
ax1 = plt.subplot(131)
pca = PCA(n_components=2)
baseline_2d = pca.fit_transform(baseline_embeddings)
drifted_2d = pca.transform(drifted_embeddings)

ax1.scatter(baseline_2d[:, 0], baseline_2d[:, 1], alpha=0.6, s=50, 
           c='blue', edgecolors='darkblue', linewidth=0.5, label='Baseline')
ax1.scatter(drifted_2d[:, 0], drifted_2d[:, 1], alpha=0.6, s=50, 
           c='red', edgecolors='darkred', linewidth=0.5, label='Drifted')

# Plot centroids
ax1.scatter(baseline_2d.mean(axis=0)[0], baseline_2d.mean(axis=0)[1], 
           marker='X', s=300, c='blue', edgecolors='black', linewidth=2, label='Baseline Centroid')
ax1.scatter(drifted_2d.mean(axis=0)[0], drifted_2d.mean(axis=0)[1], 
           marker='X', s=300, c='red', edgecolors='black', linewidth=2, label='Drifted Centroid')

ax1.set_xlabel('PCA Component 1', fontweight='bold')
ax1.set_ylabel('PCA Component 2', fontweight='bold')
ax1.set_title('Embedding Space (PCA Projection)', fontweight='bold')
ax1.legend(fontsize=9)
ax1.grid(alpha=0.3, linestyle='--')

# Plot 2: P-value visualization
ax2 = plt.subplot(132)
bars = ax2.bar(['Baseline vs\nDrifted'], [p_value], color=['red' if drift_detected else 'green'], 
               alpha=0.7, edgecolor='black', linewidth=2)
ax2.axhline(y=0.05, color='orange', linestyle='--', linewidth=2, label='Threshold (0.05)')
ax2.set_ylabel('P-value', fontweight='bold')
ax2.set_title(f'MMD Test Result\n{"DRIFT" if drift_detected else "NO DRIFT"}', fontweight='bold')
ax2.set_ylim([0, max(0.1, p_value * 1.2)])
ax2.legend(fontsize=9)
ax2.grid(alpha=0.3, linestyle='--', axis='y')

# Add p-value text
ax2.text(0, p_value + 0.005, f'p={p_value:.6f}', ha='center', va='bottom', 
        fontweight='bold', fontsize=11, bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

# Plot 3: Distance between distributions
ax3 = plt.subplot(133)
baseline_mean = baseline_embeddings.mean(axis=0)
drifted_mean = drifted_embeddings.mean(axis=0)
euclidean_dist = np.linalg.norm(drifted_mean - baseline_mean)
cosine_sim = np.dot(baseline_mean, drifted_mean) / (np.linalg.norm(baseline_mean) * np.linalg.norm(drifted_mean))

metrics = ['Euclidean\nDistance', 'Cosine\nSimilarity']
values = [euclidean_dist, cosine_sim]
colors = ['red', 'blue']

bars = ax3.bar(metrics, values, color=colors, alpha=0.7, edgecolor='black', linewidth=2)
ax3.set_ylabel('Value', fontweight='bold')
ax3.set_title('Distribution Distance Metrics', fontweight='bold')
ax3.grid(alpha=0.3, linestyle='--', axis='y')

# Add value labels
for i, (bar, value) in enumerate(zip(bars, values)):
    ax3.text(bar.get_x() + bar.get_width()/2, value + 0.02, f'{value:.4f}', 
            ha='center', va='bottom', fontweight='bold', fontsize=10)

plt.tight_layout()
plt.show()

print(f"\n📊 Visualization complete!")
print(f"\n📏 Distance Metrics:")
print(f"   Euclidean distance: {euclidean_dist:.4f}")
print(f"   Cosine similarity: {cosine_sim:.4f}")
print(f"   PCA variance explained: {pca.explained_variance_ratio_.sum():.2%}")

## Key Observations

✅ **MMD flagged the simulated embedding drift**
- Permutation-based two-sample test that yields a p-value
- Targets semantic shift, which may not show up in raw accuracy
- Applied directly to the 128-dimensional embeddings, with no dimensionality reduction

✅ **Algorithm Characteristics**
- Non-parametric (no distribution assumptions)
- Kernel-based distance metric
- Permutation testing for significance
- Scales to high dimensions (128D embeddings)

✅ **Integration in Federated Learning**
- Runs on server with aggregated embeddings
- Clients send embeddings with evaluation metrics
- Complements client-side ADWIN and Evidently
- Catches drift in semantic space

## How MMD Works

**Intuition:** MMD measures the distance between two probability distributions by comparing their mean embeddings in a reproducing kernel Hilbert space (RKHS).

**Mathematical Definition:**
```
MMD²(P, Q) = E[k(x, x')] - 2E[k(x, y)] + E[k(y, y')]
where:
  x, x' ~ P (baseline distribution)
  y, y' ~ Q (current distribution)
  k = kernel function (typically RBF)
```
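
To make the definition concrete, here is a minimal NumPy sketch of an unbiased estimate of MMD² with an RBF kernel. It is an illustration only, not the code inside `MMDDriftDetector`; the helper names (`squared_distances`, `rbf_kernel`, `median_bandwidth`, `mmd2_unbiased`) and the median-heuristic bandwidth are assumptions made for this example, and the drifted sample uses a constant mean shift for simplicity.

```python
import numpy as np

def squared_distances(a, b):
    """Pairwise squared Euclidean distances between rows of a and rows of b."""
    sq = np.sum(a**2, axis=1)[:, None] + np.sum(b**2, axis=1)[None, :] - 2.0 * a @ b.T
    return np.maximum(sq, 0.0)  # clip tiny negatives caused by floating-point error

def rbf_kernel(a, b, sigma):
    """Gaussian RBF kernel: k(a_i, b_j) = exp(-||a_i - b_j||^2 / (2 * sigma^2))."""
    return np.exp(-squared_distances(a, b) / (2.0 * sigma**2))

def median_bandwidth(x, y):
    """Median heuristic (an assumption here): sigma = median pairwise distance of the pooled sample."""
    z = np.vstack([x, y])
    upper = np.triu_indices(len(z), k=1)  # each unordered pair once, no self-pairs
    return float(np.sqrt(np.median(squared_distances(z, z)[upper])))

def mmd2_unbiased(x, y, sigma):
    """Unbiased estimate of E[k(x, x')] - 2E[k(x, y)] + E[k(y, y')]."""
    m, n = len(x), len(y)
    k_xx = rbf_kernel(x, x, sigma)
    k_yy = rbf_kernel(y, y, sigma)
    k_xy = rbf_kernel(x, y, sigma)
    # x and x' are distinct draws, so the diagonal (self-pairs) is excluded
    term_xx = (k_xx.sum() - np.trace(k_xx)) / (m * (m - 1))
    term_yy = (k_yy.sum() - np.trace(k_yy)) / (n * (n - 1))
    return term_xx - 2.0 * k_xy.mean() + term_yy

rng = np.random.default_rng(0)
x = rng.normal(0.0, 0.5, size=(100, 128))        # baseline-like sample
y_same = rng.normal(0.0, 0.5, size=(100, 128))   # same distribution as x
y_shift = rng.normal(0.3, 0.7, size=(100, 128))  # shifted mean, larger spread

mmd_same = mmd2_unbiased(x, y_same, median_bandwidth(x, y_same))
mmd_shift = mmd2_unbiased(x, y_shift, median_bandwidth(x, y_shift))
print(f"MMD^2 (same distribution): {mmd_same:.5f}")
print(f"MMD^2 (shifted):           {mmd_shift:.5f}")
```

The unbiased estimate can come out slightly negative when the two samples share a distribution; this is expected and is one reason the raw statistic is turned into a p-value rather than compared against zero.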

**Algorithm Steps:**
1. Compute MMD statistic between baseline and current embeddings
2. Generate null distribution via permutation testing:
   - Randomly shuffle combined samples
   - Split into two groups
   - Compute MMD for shuffled data
   - Repeat N times (e.g., 100 permutations)
3. P-value = fraction of permutations with MMD ≥ observed MMD
4. Reject null hypothesis if p-value < threshold
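
The four steps above can be sketched in plain NumPy as follows. This is a simplified illustration under stated assumptions, not the `MMDDriftDetector` implementation: the function name `mmd_permutation_test`, the median-heuristic bandwidth, and the use of the biased (V-statistic) form of MMD² are all choices made for this example.

```python
import numpy as np

def mmd_permutation_test(x, y, n_permutations=100, seed=0):
    """Return (p_value, observed_mmd2) for an RBF-kernel MMD permutation test."""
    rng = np.random.default_rng(seed)
    z = np.vstack([x, y])
    m, total = len(x), len(z)

    # Kernel matrix on the pooled sample, computed once and reused for every permutation
    sq_norms = np.sum(z**2, axis=1)
    sq_dists = np.maximum(sq_norms[:, None] + sq_norms[None, :] - 2.0 * z @ z.T, 0.0)
    sigma2 = np.median(sq_dists[np.triu_indices(total, k=1)])  # median heuristic (assumption)
    k = np.exp(-sq_dists / (2.0 * sigma2))

    def statistic(idx):
        # Biased MMD^2 for the split: first m indices form one group, the rest the other
        ix, iy = idx[:m], idx[m:]
        return (k[np.ix_(ix, ix)].mean()
                - 2.0 * k[np.ix_(ix, iy)].mean()
                + k[np.ix_(iy, iy)].mean())

    # Step 1: observed statistic on the original grouping
    observed = statistic(np.arange(total))
    # Step 2: null distribution from random re-splits of the pooled sample
    null = np.array([statistic(rng.permutation(total)) for _ in range(n_permutations)])
    # Step 3: fraction of permuted statistics at least as large as the observed one
    p_value = float(np.mean(null >= observed))
    return p_value, float(observed)

rng = np.random.default_rng(42)
baseline = rng.normal(0.0, 0.5, size=(100, 128))
drifted = rng.normal(0.3, 0.7, size=(100, 128))
same = rng.normal(0.0, 0.5, size=(100, 128))

p_drift, mmd_drift = mmd_permutation_test(baseline, drifted)
p_same, mmd_same = mmd_permutation_test(baseline, same)

# Step 4: compare each p-value with the threshold
print(f"drifted: p={p_drift:.3f}, drift={p_drift < 0.05}")
print(f"same:    p={p_same:.3f}, drift={p_same < 0.05}")
```

With this definition the p-value can be exactly 0 when no permuted statistic reaches the observed one. A common variant adds one to both the count and the number of permutations so the smallest reportable p-value is 1/(N+1).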

**Advantages:**
- Works with any kernel function
- No parametric assumptions
- Handles high-dimensional data
- Statistical guarantees: the permutation test controls the false positive rate at the chosen threshold (Gretton et al., 2012)

**References:**
- Gretton et al. (2012): "A Kernel Two-Sample Test"
- Implementation: Alibi-Detect library

## Comparison: No Drift Scenario

As a control, run the same test on two samples drawn from the same distribution, where MMD should not flag drift:

In [None]:
# Generate two samples from the SAME distribution
baseline_test = np.random.randn(n_samples, embedding_dim) * 0.5
nodrift_test = np.random.randn(n_samples, embedding_dim) * 0.5

# Run MMD test
p_value_nodrift = detector.test(baseline_test, nodrift_test)
drift_detected_nodrift = p_value_nodrift < 0.05

print(f"🧪 No-Drift Control Test:")
print(f"   P-value: {p_value_nodrift:.6f}")
print(f"   Threshold: 0.05")
print(f"   Decision: {'🚨 FALSE POSITIVE!' if drift_detected_nodrift else '✅ CORRECTLY NO DRIFT'}")
print(f"\n   Both samples come from the same distribution, so MMD should usually NOT detect drift.")
print(f"   At a 0.05 threshold, about 5% of such tests are still expected to flag drift by chance.")
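
A single control run cannot establish the false positive rate on its own. One way to estimate it is to repeat the no-drift test many times and count how often it rejects. The sketch below does this with a small stand-alone NumPy permutation test rather than `MMDDriftDetector`; the function names, the reduced sample size and dimension (chosen to keep the loop fast), and the median-heuristic bandwidth are assumptions for illustration.

```python
import numpy as np

def mmd_p_value(x, y, rng, n_permutations=100):
    """Permutation p-value for biased RBF-kernel MMD^2. Assumes len(x) == len(y)."""
    m = len(x)
    z = np.vstack([x, y])
    sq_norms = np.sum(z**2, axis=1)
    sq_dists = np.maximum(sq_norms[:, None] + sq_norms[None, :] - 2.0 * z @ z.T, 0.0)
    sigma2 = np.median(sq_dists[np.triu_indices(2 * m, k=1)])  # median heuristic (assumption)
    k = np.exp(-sq_dists / (2.0 * sigma2))

    # With equal group sizes, biased MMD^2 = s^T K s / m^2,
    # where s is +1 for the first group and -1 for the second.
    signs = np.concatenate([np.ones(m), -np.ones(m)])
    observed = signs @ k @ signs / m**2

    # Shuffling the sign vector is equivalent to re-splitting the pooled sample
    perms = np.array([rng.permutation(signs) for _ in range(n_permutations)])
    null = np.sum((perms @ k) * perms, axis=1) / m**2
    return float(np.mean(null >= observed))

def false_positive_rate(n_trials=200, n_samples=40, dim=16, alpha=0.05, seed=0):
    """Fraction of same-distribution trials in which the test rejects at level alpha."""
    rng = np.random.default_rng(seed)
    rejections = 0
    for _ in range(n_trials):
        a = rng.normal(0.0, 0.5, size=(n_samples, dim))
        b = rng.normal(0.0, 0.5, size=(n_samples, dim))
        rejections += mmd_p_value(a, b, rng) < alpha
    return rejections / n_trials

fpr = false_positive_rate()
print(f"Empirical false positive rate at alpha=0.05: {fpr:.3f}")
```

If the test is well calibrated, the printed rate should land near the threshold, up to sampling noise from the finite number of trials.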

## Real-World Application

In the actual federated system:

1. **Client Side:**
   - During evaluation, collect BERT embeddings for sample texts
   - Send embeddings to server along with metrics

2. **Server Side:**
   - Aggregate embeddings from all clients
   - Maintain baseline embeddings from early rounds
   - Run MMD test each round: `MMD(baseline, current)`
   - If p < 0.05: trigger mitigation

3. **Advantages:**
   - Catches semantic drift not visible in metrics
   - Aggregated view across all clients
   - Complements local ADWIN and Evidently
   - Statistical rigor with p-value
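
The server-side flow described above might look like the following sketch. Everything here is hypothetical scaffolding: `ServerDriftMonitor`, its parameters, and the result dictionary are invented for illustration, and `mean_shift_permutation_test` is a lightweight stand-in used only to keep the example self-contained. In the real system the `test_fn` argument would be something like `detector.test` from the notebook.

```python
import numpy as np

class ServerDriftMonitor:
    """Hypothetical server-side helper: builds a baseline, then tests each round."""

    def __init__(self, test_fn, baseline_rounds=2, p_val_threshold=0.05):
        self.test_fn = test_fn                  # callable(baseline, current) -> p-value
        self.baseline_rounds = baseline_rounds  # early rounds used as the baseline
        self.p_val_threshold = p_val_threshold
        self._baseline_parts = []
        self.rounds_seen = 0

    def update(self, client_embeddings):
        """client_embeddings: list of (n_i, dim) arrays, one per client."""
        current = np.vstack(client_embeddings)  # aggregate across clients
        self.rounds_seen += 1
        if self.rounds_seen <= self.baseline_rounds:
            self._baseline_parts.append(current)
            return {"round": self.rounds_seen, "p_value": None, "drift": False}
        baseline = np.vstack(self._baseline_parts)
        p_value = float(self.test_fn(baseline, current))
        return {"round": self.rounds_seen, "p_value": p_value,
                "drift": p_value < self.p_val_threshold}

def mean_shift_permutation_test(baseline, current, n_permutations=100, seed=0):
    """Stand-in for an MMD test: permutation test on the distance between sample means."""
    rng = np.random.default_rng(seed)
    z = np.vstack([baseline, current])
    m = len(baseline)
    observed = np.linalg.norm(z[:m].mean(axis=0) - z[m:].mean(axis=0))
    null = np.empty(n_permutations)
    for i in range(n_permutations):
        idx = rng.permutation(len(z))
        null[i] = np.linalg.norm(z[idx[:m]].mean(axis=0) - z[idx[m:]].mean(axis=0))
    return float(np.mean(null >= observed))

rng = np.random.default_rng(7)
monitor = ServerDriftMonitor(test_fn=mean_shift_permutation_test)

results = []
for round_shift in [0.0, 0.0, 0.0, 0.5]:  # the last round simulates drift
    clients = [rng.normal(round_shift, 0.5, size=(30, 32)) for _ in range(3)]
    results.append(monitor.update(clients))

for r in results:
    print(r)
```

Injecting the test as a callable keeps the bookkeeping (baseline collection, thresholding) separate from the statistic, so the stand-in can be swapped for the MMD detector without touching the monitor.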

---
## Summary

MMD provides **global embedding space monitoring** that complements local detectors:

| Detector | Level | What it Catches | When it Runs |
|----------|-------|-----------------|-------------|
| ADWIN | Client | Performance degradation | Every round |
| Evidently | Client | Data distribution shift | Every round |
| **MMD** | **Server** | **Embedding space drift** | **Every round** |

Together, these three detectors provide **comprehensive drift coverage** across all aspects of the federated learning system.