# SGN Architecture: Multi-Mic Echo Cancellation & Noise Suppression

This notebook demonstrates the SGN (Speech Enhancement Network) architecture with visual explanations.

**Companion Files:**
- `5_SGN_Architecture_Formulas_Dimensions.md` - Complete mathematical formulation
- `5_SGN_Feature_Evolution_Visual_Guide.md` - Visual explanations of feature evolution

**Model Specifications:**
- Parameters: 5.5M
- Complexity: ~0.5 GMAC/s
- FFT size: 320, Hop: 160
- 2-mic input + reference signal
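A 320-point FFT gives 320 // 2 + 1 = 161 one-sided frequency bins per frame, which is where the 161 in the dimension flow comes from. Below is a minimal sketch of the STFT step. It assumes 16 kHz audio (the sample rate is not stated here, though a 160-sample hop is then 10 ms) and a Hann window (also an assumption). The helper name `stft_3ch` is hypothetical.

```python
import torch

# Assumption: 16 kHz audio, so the 320-sample window is 20 ms and the 160-sample hop is 10 ms.
SAMPLE_RATE = 16_000
N_FFT, HOP = 320, 160


def stft_3ch(mics_and_ref: torch.Tensor) -> torch.Tensor:
    """mics_and_ref: [B, 3, T] real audio (mic 1, mic 2, reference).

    Returns the complex STFT as [B, 3, K, F], with F = N_FFT // 2 + 1 = 161 bins.
    """
    B, C, T = mics_and_ref.shape
    window = torch.hann_window(N_FFT)  # assumption: Hann analysis window
    spec = torch.stft(
        mics_and_ref.reshape(B * C, T),
        n_fft=N_FFT,
        hop_length=HOP,
        window=window,
        return_complex=True,
    )  # [B*C, F, K]
    n_bins, n_frames = spec.shape[-2], spec.shape[-1]
    return spec.reshape(B, C, n_bins, n_frames).transpose(-1, -2)  # [B, C, K, F]


x = torch.randn(4, 3, SAMPLE_RATE)  # 1 s of random audio per channel
X = stft_3ch(x)
print(X.shape)  # torch.Size([4, 3, 101, 161])
```

With `center=True` (the PyTorch default), a 1 s clip yields 1 + 16000 // 160 = 101 frames.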

## Setup

Install required packages and import libraries.

In [None]:
# Uncomment to install packages
# !pip install torch torchaudio numpy matplotlib scipy librosa soundfile

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from scipy import signal
import warnings
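warnings.filterwarnings('ignore')  # silences ALL warnings; disable this while debugging

# Fix random seeds so runs are reproducible
torch.manual_seed(42)
np.random.seed(42)

# Use the GPU when available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")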

## Architecture Components

Refer to the companion markdown files for detailed explanations:
- Mathematical formulas and dimensions
- Visual guide showing feature evolution
- Why each component helps

This notebook gives a simplified, high-level walkthrough of the pipeline for demonstration. It is not a full implementation.

In [None]:
print("SGN Architecture Overview:")
print("="*60)
print("1. STFT Preprocessing")
print("   - Converts time-domain to frequency-domain")
print("   - Input: 2 mics + reference signal")
print("")
print("2. Rotation Layer")
print("   - Spatial feature enhancement (learned beamforming)")
print("   - 2 mics -> 8 channels")
print("   - Enhances source separation")
print("")
print("3. Concatenation with Delayed Reference")
print("   - Adds echo cancellation context")
print("   - 2 past frames (20ms lookback)")
print("")
print("4. BiLSTM")
print("   - Temporal context (past + future)")
print("   - Tracks non-stationary noise")
print("   - Speech continuity modeling")
print("")
print("5. Dual LSTM Branches")
print("   - Branch 1: Echo Cancellation")
print("   - Branch 2: Noise Suppression")
print("   - Task specialization")
print("")
print("6. FC Layers")
print("   - Non-linear mapping to mask space")
print("   - Feature refinement")
print("")
print("7. Filter Block")
print("   - Adaptive mask generation")
print("   - Frequency-selective suppression")
print("")
print("8. ISTFT Post-processing")
print("   - Converts back to time-domain")
print("   - Clean speech output")
print("="*60)
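Step 3 grows each channel's feature from 640 to 1280 values. That is consistent with appending two past reference frames of 320 features each (640 + 2 × 320). Below is a minimal sketch of that stacking, resting on two assumptions that the text does not state. First, the two frames are k-1 and k-2, zero-padded at the start. Second, the same reference context is copied to all 8 spatial channels. The function name `stack_delayed_ref` is hypothetical.

```python
import torch
import torch.nn.functional as F


def stack_delayed_ref(rot: torch.Tensor, ref: torch.Tensor, n_past: int = 2) -> torch.Tensor:
    """Append the n_past previous reference frames to every spatial channel.

    rot: [B, K, C, D_rot] rotation-layer output (e.g. C=8, D_rot=640)
    ref: [B, K, D_ref] per-frame reference features (e.g. D_ref=320)
    returns: [B, K, C, D_rot + n_past * D_ref]
    """
    B, K, C, _ = rot.shape
    # Zero-pad n_past frames at the start of the time axis (dim 1).
    ref_padded = F.pad(ref, (0, 0, n_past, 0))  # [B, K + n_past, D_ref]
    # The slice for lag d holds ref[k - d] at frame k (zeros where k - d < 0).
    lags = [ref_padded[:, n_past - d : n_past - d + K] for d in range(1, n_past + 1)]
    past = torch.cat(lags, dim=-1)  # [B, K, n_past * D_ref]
    # Assumption: the same reference context is shared by every spatial channel.
    past = past.unsqueeze(2).expand(B, K, C, past.shape[-1])
    return torch.cat([rot, past], dim=-1)


rot = torch.randn(2, 50, 8, 640)
ref = torch.randn(2, 50, 320)
out = stack_delayed_ref(rot, ref)
print(out.shape)  # torch.Size([2, 50, 8, 1280])
```

At 10 ms per hop (assuming 16 kHz), two past frames give the 20 ms lookback quoted in the overview.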

## Key Insights

### Why Rotation Layer?
- **Spatial mixing**: Creates linear combinations of mic signals
- **Learned beamforming**: Optimized from data, not geometry
- **Feature expansion**: 2 mics -> 8 channels for more capacity
- **Better separation**: Speech and noise become more separable
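Here is one way to read "learned beamforming". Suppose each mic's 320 features are the real and imaginary parts of 160 frequency bins (an assumption). Then a per-bin phase shift is just a 2 × 2 rotation of each (Re, Im) pair. A single real 640 × 640 linear map over the concatenated mics can therefore express classic align-and-sum beamforming. The sketch below hand-builds such a matrix for a hypothetical source. A trained rotation layer would learn its weights from data instead, and nothing here is taken from the SGN model itself.

```python
import numpy as np

F_BINS = 160  # assumption: 160 complex bins per mic, stored as 320 reals [Re..., Im...]


def to_real(z: np.ndarray) -> np.ndarray:
    """Complex [F] -> real [2F] laid out as [Re..., Im...]."""
    return np.concatenate([z.real, z.imag])


def phase_shift_matrix(theta: np.ndarray) -> np.ndarray:
    """Real [2F, 2F] matrix equivalent to multiplying bin f by exp(1j * theta[f])."""
    c, s = np.diag(np.cos(theta)), np.diag(np.sin(theta))
    return np.block([[c, -s], [s, c]])


rng = np.random.default_rng(0)
mic1 = rng.standard_normal(F_BINS) + 1j * rng.standard_normal(F_BINS)
theta = rng.uniform(-np.pi, np.pi, F_BINS)  # hypothetical inter-mic phase per bin
mic2 = mic1 * np.exp(-1j * theta)           # same source, phase-delayed at mic 2

eye, rot = np.eye(2 * F_BINS), phase_shift_matrix(theta)
# 640 x 640 weights acting on X_concat = [mic1; mic2]:
# top half = align-and-sum beam (passes the source), bottom half = difference beam (nulls it).
W = 0.5 * np.block([[eye, rot], [eye, -rot]])
y = W @ np.concatenate([to_real(mic1), to_real(mic2)])
```

The top half of `y` reproduces the source and the bottom half is zero. That spatial contrast is the kind the list above credits with making speech and noise more separable.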

### Why BiLSTM?
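- **Full temporal context**: Past (forward) + Future (backward). The backward direction needs future frames, so a streaming deployment pays for it with lookahead latency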
- **Speech continuity**: Uses surrounding frames to understand current frame
- **Non-stationary tracking**: Adapts to time-varying noise
- **Echo path modeling**: Tracks changing acoustic conditions

### Why Dual Branches?
- **Task decomposition**: EC and NS have different characteristics
- **Reduces interference**: Separate objectives are less likely to conflict in shared weights
- **Specialization**: Each branch learns task-specific patterns
- **Multi-task learning**: Shared BiLSTM + specialized branches
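The shared-trunk, two-head pattern can be sketched as two separately parameterised LSTMs that read the same BiLSTM features. Each branch's loss then updates only its own weights plus the shared trunk. Below is a minimal sketch with these assumptions:

- the branches run over time, with the 8 spatial channels folded into the batch;
- each branch has 320 units, matching the dimension flow;
- the class and method names are hypothetical.

How the real model merges the branches is not specified here.

```python
import torch
import torch.nn as nn


class DualBranch(nn.Module):
    """Hypothetical sketch: two task-specific LSTMs fed by the same shared features."""

    def __init__(self, dim: int = 320):
        super().__init__()
        self.ec_lstm = nn.LSTM(dim, dim, batch_first=True)  # echo-cancellation branch
        self.ns_lstm = nn.LSTM(dim, dim, batch_first=True)  # noise-suppression branch

    @staticmethod
    def _over_frames(lstm: nn.LSTM, x: torch.Tensor) -> torch.Tensor:
        # Assumption: run over frames K, folding the C spatial channels into the batch.
        B, K, C, D = x.shape
        y, _ = lstm(x.permute(0, 2, 1, 3).reshape(B * C, K, D))
        return y.reshape(B, C, K, -1).permute(0, 2, 1, 3)  # back to [B, K, C, D]

    def forward(self, shared: torch.Tensor):
        # shared: [B, K, C, D] output of the shared BiLSTM
        return self._over_frames(self.ec_lstm, shared), self._over_frames(self.ns_lstm, shared)


torch.manual_seed(0)
branches = DualBranch()
shared = torch.randn(2, 30, 8, 320)
with torch.no_grad():
    ec, ns = branches(shared)
print(ec.shape, ns.shape)  # torch.Size([2, 30, 8, 320]) torch.Size([2, 30, 8, 320])
```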

### Why Adaptive Filtering?
- **Frequency-selective**: Each frequency bin is controlled independently
- **Learned mask**: Non-linear, context-aware suppression
- **Soft masking**: Smooth transitions, better quality
- **Data-driven alternative to Wiener**: No explicit noise model; suppression behaviour is learned from training data

## Dimension Flow Summary

```
Notation: B = batch size, T = samples, K = STFT frames

Input Audio (2 mics + ref):
  [B, T] x 3

STFT:
  [B, K, 161] x 3

Pre-processing:
  [B, K, 2, 320]

Rotation Layer:
  [B, K, 8, 640]
  - Spatial enhancement
  - 2 mics -> 8 channels

Concat with Delayed Ref:
  [B, K, 8, 1280]
  - Added 2 past frames
  - Echo context

BiLSTM:
  [B, K, 8, 320]
  - Temporal context
  - Compression: 1280 -> 320

Dual LSTM Branches:
  Branch 1: [B, K, 8, 320]
  Branch 2: [B, K, 8, 320]
  - Task specialization

FC Layers:
  FC1: [B, K, 8, 640]
  FC2: [B, K, 8, 320]
  - Feature refinement

Filter Block:
  Mask: [B, K, 161]
  - Adaptive suppression

ISTFT:
  Output: [B, T]
  - Clean speech
```
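To make the dimension flow concrete, the toy module below reproduces every shape in the table from random weights. It is a hypothetical sketch, not the SGN implementation. The text leaves several details open, and each is a labelled guess in the docstring and comments:

- how 161 bins become 320 features;
- how the rotation produces 8 channels;
- which axis the LSTMs run over;
- how the two branches are merged;
- which mic the mask is applied to.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

N_FFT, HOP, N_BINS = 320, 160, 161  # from the model specifications


class ToySGN(nn.Module):
    """Hypothetical shape-level sketch of the dimension flow; NOT the reference SGN.

    Assumed (not stated in the text): the Nyquist bin is dropped so 160 complex bins
    become 320 reals per frame; the rotation is 8 per-channel 640x640 maps; every LSTM
    runs over frames with the 8 channels folded into the batch; FC1 takes the two branch
    outputs concatenated; the mask head flattens 8 x 320 features into 161 gains.
    """

    def __init__(self, ch: int = 8, n_past: int = 2):
        super().__init__()
        self.ch, self.n_past = ch, n_past
        self.rot_w = nn.Parameter(torch.randn(ch, 640, 640) / 640 ** 0.5)
        self.rot_b = nn.Parameter(torch.zeros(ch, 640))
        self.bilstm = nn.LSTM(640 + n_past * 320, 160, batch_first=True, bidirectional=True)
        self.ec = nn.LSTM(320, 320, batch_first=True)
        self.ns = nn.LSTM(320, 320, batch_first=True)
        self.fc1 = nn.Linear(640, 640)
        self.fc2 = nn.Linear(640, 320)
        self.mask_head = nn.Linear(ch * 320, N_BINS)

    @staticmethod
    def frame_feats(spec: torch.Tensor) -> torch.Tensor:
        s = spec[..., :160]  # complex [B, K, 161] -> drop Nyquist -> [B, K, 160]
        return torch.cat([s.real, s.imag], dim=-1)  # real [B, K, 320]

    @staticmethod
    def over_frames(lstm: nn.LSTM, x: torch.Tensor) -> torch.Tensor:
        B, K, C, D = x.shape  # run the LSTM along K for each channel independently
        y, _ = lstm(x.permute(0, 2, 1, 3).reshape(B * C, K, D))
        return y.reshape(B, C, K, -1).permute(0, 2, 1, 3)

    def forward(self, audio: torch.Tensor):
        B, _, T = audio.shape  # audio: [B, 3, T] = (mic1, mic2, ref)
        win = torch.hann_window(N_FFT, device=audio.device)
        spec = torch.stft(audio.reshape(B * 3, T), N_FFT, hop_length=HOP,
                          window=win, return_complex=True)
        spec = spec.reshape(B, 3, N_BINS, -1).transpose(-1, -2)  # [B, 3, K, 161]
        K = spec.shape[2]
        mics = torch.stack([self.frame_feats(spec[:, i]) for i in range(2)], dim=2)  # [B, K, 2, 320]
        x = mics.reshape(B, K, 640)  # X_concat
        h = torch.einsum("cij,bkj->bkci", self.rot_w, x) + self.rot_b  # rotation: [B, K, 8, 640]
        ref = F.pad(self.frame_feats(spec[:, 2]), (0, 0, self.n_past, 0))
        past = torch.cat([ref[:, self.n_past - d:self.n_past - d + K]
                          for d in range(1, self.n_past + 1)], dim=-1)  # frames k-1, k-2
        h = torch.cat([h, past.unsqueeze(2).expand(-1, -1, self.ch, -1)], dim=-1)  # [B, K, 8, 1280]
        h = self.over_frames(self.bilstm, h)  # BiLSTM: [B, K, 8, 320]
        h = torch.cat([self.over_frames(self.ec, h), self.over_frames(self.ns, h)], dim=-1)
        h = self.fc2(torch.relu(self.fc1(h)))  # FC1: [..., 640], FC2: [..., 320]
        mask = torch.sigmoid(self.mask_head(h.reshape(B, K, -1)))  # [B, K, 161]
        clean = mask * spec[:, 0]  # assumption: the mask is applied to mic 1
        out = torch.istft(clean.transpose(1, 2), N_FFT, hop_length=HOP, window=win, length=T)
        return mask, out


model = ToySGN()
with torch.no_grad():
    mask, out = model(torch.randn(1, 3, 16000))
print(mask.shape, out.shape)  # torch.Size([1, 101, 161]) torch.Size([1, 16000])
```

With these guesses the module has about 7.8M parameters, more than the 5.5M in the specifications. That suggests the real layer layout differs from the assumptions made here.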

## Mathematical Highlights

### Rotation Layer
$$\mathbf{Y}_{\text{rot}} = \mathbf{W}_{\text{rot}} \cdot \mathbf{X}_{\text{concat}} + \mathbf{b}_{\text{rot}}$$

Where:
- $\mathbf{W}_{\text{rot}} \in \mathbb{R}^{640 \times 640}$: Learned weight matrix
- $\mathbf{X}_{\text{concat}} \in \mathbb{R}^{640}$: Concatenated 2-mic input for one frame (2 × 320 features)
- Creates 8 spatial channels
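The text gives a single 640 × 640 matrix, but the dimension flow shows 8 output channels of 640 values each. One reading that satisfies both is 8 independent affine maps $\mathbf{Y}_c = \mathbf{W}_c \mathbf{X}_{\text{concat}} + \mathbf{b}_c$, each with $\mathbf{W}_c \in \mathbb{R}^{640 \times 640}$. This reading is assumed here rather than stated in the text. Below is a minimal PyTorch sketch; the class name is hypothetical.

```python
import torch
import torch.nn as nn


class RotationLayer(nn.Module):
    """Hypothetical: n_out independent affine maps Y_c = W_c @ X_concat + b_c."""

    def __init__(self, n_out: int = 8, dim: int = 640):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(n_out, dim, dim) / dim ** 0.5)  # W_c
        self.bias = nn.Parameter(torch.zeros(n_out, dim))                       # b_c

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.flatten(start_dim=2)  # [B, K, 2, 320] -> X_concat: [B, K, 640]
        # y[b, k, c, i] = sum_j W[c, i, j] * x[b, k, j] + b[c, i]
        return torch.einsum("cij,bkj->bkci", self.weight, x) + self.bias  # [B, K, 8, 640]


torch.manual_seed(0)
layer = RotationLayer()
x = torch.randn(2, 5, 2, 320)
with torch.no_grad():
    y = layer(x)
print(y.shape)  # torch.Size([2, 5, 8, 640])
```

Under this reading the rotation alone holds 8 × 640 × 640 ≈ 3.3M weights, a large share of the stated 5.5M parameters.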

### BiLSTM
$$\mathbf{h}_t^{\text{bi}} = [\mathbf{h}_t^{(\rightarrow)}; \mathbf{h}_t^{(\leftarrow)}]$$

Where:
- $\mathbf{h}_t^{(\rightarrow)}$: Forward LSTM (past context)
- $\mathbf{h}_t^{(\leftarrow)}$: Backward LSTM (future context)
- Concatenated for full temporal awareness
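PyTorch's `nn.LSTM(..., bidirectional=True)` returns the concatenation $[\mathbf{h}_t^{(\rightarrow)}; \mathbf{h}_t^{(\leftarrow)}]$ at each step, with the forward half first. The sketch below assumes 160 units per direction, so the output is the 320-dim feature from the dimension flow. It perturbs only the last frame. The forward half ignores this future input while the backward half does not, which is why a time-axis BiLSTM needs lookahead when streaming.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
H = 160  # assumption: 160 units per direction, so [h_fwd; h_bwd] is 320-dim
bilstm = nn.LSTM(input_size=1280, hidden_size=H, batch_first=True, bidirectional=True)

x = torch.randn(1, 20, 1280)  # one sequence of 20 frames
x_mod = x.clone()
x_mod[:, -1] += 1.0           # change only the last (future-most) frame

with torch.no_grad():
    y, _ = bilstm(x)          # [1, 20, 320] = [h_fwd ; h_bwd] per frame
    y_mod, _ = bilstm(x_mod)

# Compare all frames except the perturbed one.
fwd_changed = (y[:, :-1, :H] - y_mod[:, :-1, :H]).abs().max().item()
bwd_changed = (y[:, :-1, H:] - y_mod[:, :-1, H:]).abs().max().item()
print(fwd_changed, bwd_changed)  # forward half is unaffected; backward half changes
```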

### Adaptive Mask
$$\mathbf{M}(k, f) = \sigma(\text{NN}(\mathbf{X}_{\text{noisy}}, \mathbf{R}_{\text{ref}}))$$

Where:
- $\mathbf{M}(k, f) \in [0, 1]$: Gain mask
- $\sigma$: Sigmoid activation
- Learned from training data rather than derived from a statistical noise model

### Clean Speech Estimation
$$\hat{\mathbf{S}}_{\text{clean}}(k, f) = \mathbf{M}(k, f) \cdot \mathbf{X}_{\text{noisy}}(k, f)$$
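Because $\mathbf{M}$ is real and $\mathbf{X}_{\text{noisy}}$ is complex, the product scales each bin's magnitude and keeps the noisy phase. Below is a minimal sketch of the mask-and-resynthesise step. It assumes a Hann window (not stated) and uses `mask_logits` as a hypothetical stand-in for the network output.

```python
import torch

N_FFT, HOP = 320, 160
WINDOW = torch.hann_window(N_FFT)  # assumption: Hann analysis/synthesis window


def apply_mask(noisy: torch.Tensor, mask_logits: torch.Tensor) -> torch.Tensor:
    """noisy: [B, T] waveform; mask_logits: [B, K, 161] raw scores (hypothetical).

    Computes S_hat = sigmoid(logits) * X_noisy per (frame, bin), then inverts the STFT.
    """
    X = torch.stft(noisy, N_FFT, hop_length=HOP, window=WINDOW, return_complex=True)  # [B, 161, K]
    M = torch.sigmoid(mask_logits).transpose(1, 2)  # [B, 161, K], real gains in [0, 1]
    return torch.istft(M * X, N_FFT, hop_length=HOP, window=WINDOW, length=noisy.shape[-1])


x = torch.randn(2, 16000)
K = 1 + x.shape[-1] // HOP  # frame count with torch.stft's default center=True
keep = apply_mask(x, torch.full((2, K, 161), 20.0))   # sigmoid(20) ~ 1: near-identity
drop = apply_mask(x, torch.full((2, K, 161), -20.0))  # sigmoid(-20) ~ 0: near-silence
```

A mask of all ones returns the input almost exactly, so the STFT/ISTFT pair itself introduces no suppression. Any change in the output comes from the mask.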

## Comparison: Traditional vs SGN

### Traditional Wiener Filter
$$M_{\text{Wiener}}(k, f) = \frac{|S(k, f)|^2}{|S(k, f)|^2 + |N(k, f)|^2}$$

**Limitations:**
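- Requires estimates of the speech and noise power spectra
- MMSE-optimal overall only under Gaussian assumptions (otherwise only the best linear estimator)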
- Linear processing
- Stationarity assumption
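The first limitation shows up directly in code: the gain can only be computed once the speech and noise power spectra are known. Below is a minimal NumPy sketch that uses an "oracle" (true) S and N on random stand-in data. A practical Wiener filter would replace these with estimates, for example a noise power spectrum tracked during speech pauses.

```python
import numpy as np


def wiener_gain(speech_psd: np.ndarray, noise_psd: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """M(k, f) = |S|^2 / (|S|^2 + |N|^2); eps avoids 0/0 in silent bins."""
    return speech_psd / (speech_psd + noise_psd + eps)


rng = np.random.default_rng(0)
shape = (100, 161)  # K frames x 161 bins, as in the dimension flow
S = rng.standard_normal(shape) + 1j * rng.standard_normal(shape)          # stand-in speech STFT
N = 0.5 * (rng.standard_normal(shape) + 1j * rng.standard_normal(shape))  # stand-in noise STFT

# Oracle gain: uses the true S and N, which a real system never has.
M = wiener_gain(np.abs(S) ** 2, np.abs(N) ** 2)
S_hat = M * (S + N)
```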

### SGN Learned Mask
$$M_{\text{SGN}}(k, f) = \sigma(\text{NN}(\mathbf{X}_{\text{noisy}}, \mathbf{R}_{\text{ref}}))$$

**Advantages:**
- No explicit noise model or stationarity assumption
- Non-linear relationships
- Multi-microphone spatial info
- Temporal context from LSTM
- Handles non-stationary noise
- Learns from data

## Feature Evolution Visualization

### Stage 1: Input (Raw Audio)
- **Complexity**: Low
- **Information**: Mixed signal (speech + echo + noise)
- **Separability**: Poor

### Stage 2: STFT (Spectrograms)
- **Complexity**: Medium
- **Information**: Frequency-domain representation
- **Separability**: Better (frequency separation)

### Stage 3: Rotation (Spatial Enhancement)
- **Complexity**: High
- **Information**: 8 spatial channels
- **Separability**: Much better (spatial + spectral)

### Stage 4: BiLSTM (Temporal Context)
- **Complexity**: Very High
- **Information**: Spatio-temporal features
- **Separability**: Excellent (spatial + spectral + temporal)

### Stage 5: Dual Branches (Task Specialization)
- **Complexity**: Specialized
- **Information**: Task-specific features
- **Separability**: Optimal (EC vs NS)

### Stage 6: Mask (Decision)
- **Complexity**: Simple
- **Information**: Soft per-bin gains in [0, 1], often close to 0 or 1
- **Separability**: Explicit per-bin decision (0 = suppress, 1 = keep)

### Stage 7: Output (Clean Speech)
- **Complexity**: Low
- **Information**: Separated signal
- **Quality**: Enhanced

## Conclusion

The SGN architecture effectively combines:
1. **Spatial processing** (Rotation layer)
2. **Temporal processing** (BiLSTM)
3. **Spectral processing** (STFT, Mask)
4. **Task decomposition** (Dual branches)

This multi-domain approach enables:
- Superior noise suppression
- Effective echo cancellation
- Adaptation to non-stationary conditions
- Real-time-capable compute budget (~0.5 GMAC/s)
- Compact model (5.5M parameters)

**For detailed mathematical formulations and visual explanations, please refer to:**
- `5_SGN_Architecture_Formulas_Dimensions.md`
- `5_SGN_Feature_Evolution_Visual_Guide.md`