# Tensor-Train Volterra for Full MIMO Systems

This notebook demonstrates **Tensor-Train (TT) decomposition** for full Multi-Input Multi-Output (MIMO) Volterra systems.

## The Curse of Dimensionality in Volterra Models

A full $N$-th order Volterra model for a MIMO system with $I$ inputs, $O$ outputs, and memory $M$ requires:
$$
\text{Parameters} = O \cdot I^N \cdot M^N
$$

**Example:** For $I=3$ inputs, $O=2$ outputs, $N=3$ order, $M=10$ memory:
- Full model: $2 \times 3^3 \times 10^3 = 54{,}000$ parameters
- Storage grows exponentially with the order $N$ (and as $M^N$ in the memory length $M$)

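To make the growth concrete, the count above can be tabulated for increasing order. This is a minimal sketch; `full_volterra_params` is a hypothetical helper written for this illustration and is not part of the `volterra` package.

```python
def full_volterra_params(n_inputs: int, n_outputs: int, order: int, memory: int) -> int:
    """Parameter count of the order-N kernel: O * I^N * M^N."""
    return n_outputs * (n_inputs * memory) ** order


# Same I, O, M as the example above; only the order N varies.
for order in range(1, 6):
    count = full_volterra_params(n_inputs=3, n_outputs=2, order=order, memory=10)
    print(f"N={order}: {count:,} parameters")
```

Each extra order multiplies the count by $I \cdot M = 30$ in this configuration.
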
## Tensor-Train Decomposition

TT decomposition factorizes the Volterra kernel tensor into a **product of low-rank cores**:
$$
\mathcal{H}(m_1, \ldots, m_N) = \mathbf{G}_1[m_1] \cdot \mathbf{G}_2[m_2] \cdots \mathbf{G}_N[m_N]
$$

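where each core $\mathbf{G}_k[m_k]$ is a matrix of size $r_{k-1} \times r_k$ (with $r_0 = r_N = 1$), so the product is a scalar. For a MIMO system, each index $m_k$ runs over all $M \cdot I$ (input, lag) pairs, which is where the factor $I$ in the parameter count below comes from.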

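The product formula can be checked numerically on a small example. The sketch below assumes each core is stored as a NumPy array of shape $(r_{k-1}, M, r_k)$; that layout, the sizes, and the helper `tt_entry` are illustrative choices, not the storage format of `TTVolterraMIMO`.

```python
import numpy as np

rng = np.random.default_rng(0)
N, M = 3, 4                       # order and mode size (illustrative values)
ranks = [1, 2, 2, 1]              # r_0 .. r_N with r_0 = r_N = 1

# Core k has shape (r_{k-1}, M, r_k); cores[k][:, m, :] plays the role of G_k[m].
cores = [rng.standard_normal((ranks[k], M, ranks[k + 1])) for k in range(N)]


def tt_entry(cores, index):
    """Evaluate H(m_1, ..., m_N) as a product of core slices."""
    result = np.eye(1)
    for core, m in zip(cores, index):
        result = result @ core[:, m, :]
    return result.item()          # the final product is 1 x 1


# Reference: contract all cores into the full M x M x M tensor.
full = np.einsum("aib,bjc,ckd->ijk", *cores)

print(tt_entry(cores, (1, 3, 0)), full[1, 3, 0])
```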
**Benefits:**
- **Parameters**: $O(N \cdot M \cdot I \cdot r^2)$ per output instead of $O(I^N \cdot M^N)$
- **Example**: With $r=3$: $3 \times 10 \times 3 \times 3^2 = 810$ parameters per output, against $3^3 \times 10^3 = 27{,}000$ for the full kernel (about 33× fewer)
- **Scalable**: Works for high orders ($N > 5$) and long memory ($M > 20$)

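The $N \cdot M \cdot I \cdot r^2$ figure treats every core as $r \times r$, but the two boundary cores are thinner because $r_0 = r_N = 1$. The sketch below computes the exact storage for one output, assuming a mode size of $M \cdot I$ and one TT per output; `tt_params` is a hypothetical helper.

```python
def tt_params(ranks, mode_size):
    """Exact TT storage: sum over cores of r_{k-1} * mode_size * r_k."""
    return sum(r_prev * mode_size * r_next
               for r_prev, r_next in zip(ranks[:-1], ranks[1:]))


M, I, N, r = 10, 3, 3, 3
exact = tt_params([1, r, r, 1], mode_size=M * I)   # 90 + 270 + 90
bound = N * M * I * r**2                           # uniform-rank upper bound
full = (I * M) ** N                                # full kernel, one output

print(exact, bound, full)   # 450 810 27000
```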
---

## Setup

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy import signal

from volterra import TTVolterraMIMO

np.random.seed(456)

plt.rcParams['figure.figsize'] = (12, 4)
plt.rcParams['font.size'] = 10

---

## 1. Generate MIMO Nonlinear System Data

We'll create a 2-input, 2-output system with:
- **Linear crosstalk**: each output depends on both inputs
- **Nonlinear interactions**: quadratic, cubic, and cross-input product terms
- **Memory effects**: a second-order IIR filter on each output

In [None]:
# System configuration
I, O = 2, 2  # 2 inputs, 2 outputs
fs = 48000
duration = 0.5
n_samples = int(fs * duration)

# Generate 2 input signals (bandlimited noise in partially overlapping bands)
x1_white = np.random.randn(n_samples)
x2_white = np.random.randn(n_samples)

# Input 1: 100-2000 Hz
sos1 = signal.butter(6, [100, 2000], btype='bandpass', fs=fs, output='sos')
x1 = signal.sosfilt(sos1, x1_white)
x1 = x1 / np.std(x1) * 0.25

# Input 2: 1000-4000 Hz
sos2 = signal.butter(6, [1000, 4000], btype='bandpass', fs=fs, output='sos')
x2 = signal.sosfilt(sos2, x2_white)
x2 = x2 / np.std(x2) * 0.25

# Combine into MIMO input: shape (n_samples, I=2)
x = np.column_stack([x1, x2])

print(f"Input shape: {x.shape}")
print(f"Input 1 RMS: {np.sqrt(np.mean(x1**2)):.4f}")
print(f"Input 2 RMS: {np.sqrt(np.mean(x2**2)):.4f}")

In [None]:
# Define MIMO nonlinear system
def mimo_system(x):
    """
    2-input, 2-output nonlinear system.
    
    Output 1: Primarily driven by input 1, with weak coupling from input 2
    Output 2: Primarily driven by input 2, with weak coupling from input 1
    """
    x1, x2 = x[:, 0], x[:, 1]
    
    # Output 1: linear + quadratic + cubic, with crosstalk
    y1_nl = (
        0.7 * x1 +                 # Linear (main input)
        0.2 * x2 +                 # Linear crosstalk
        0.1 * x1**2 +              # Quadratic
        0.05 * x1**3 +             # Cubic
        0.03 * x1 * x2             # Cross-input interaction
    )
    
    # Output 2: different coefficients, reversed crosstalk
    y2_nl = (
        0.6 * x2 +                 # Linear (main input)
        0.25 * x1 +                # Linear crosstalk
        0.12 * x2**2 +             # Quadratic
        0.04 * x2**3 +             # Cubic
        0.02 * x1 * x2             # Cross-input interaction
    )
    
    # Apply memory (different IIR filters for each output)
    b1, a1 = [0.2, -0.38, 0.18], [1.0, -1.9, 0.94]
    b2, a2 = [0.18, -0.35, 0.17], [1.0, -1.85, 0.92]
    
    y1 = signal.lfilter(b1, a1, y1_nl)
    y2 = signal.lfilter(b2, a2, y2_nl)
    
    return np.column_stack([y1, y2])

# Generate outputs
y_clean = mimo_system(x)

# Add noise
noise = np.random.randn(n_samples, O) * 0.01
y = y_clean + noise

print(f"\nOutput shape: {y.shape}")
print(f"Output 1 RMS: {np.sqrt(np.mean(y[:, 0]**2)):.4f}")
print(f"Output 2 RMS: {np.sqrt(np.mean(y[:, 1]**2)):.4f}")
print(f"SNR (output 1): {10 * np.log10(np.mean(y_clean[:, 0]**2) / np.mean(noise[:, 0]**2)):.1f} dB")
print(f"SNR (output 2): {10 * np.log10(np.mean(y_clean[:, 1]**2) / np.mean(noise[:, 1]**2)):.1f} dB")

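Before choosing a memory length, it helps to look at how quickly the IIR filters forget their input. The sketch below uses only NumPy, with the coefficients copied from `mimo_system`. It reports the pole radii (both must be below 1 for stability) and the share of filter 1's impulse-response energy that falls in the first 8 taps. The helper `impulse_response` and the 8-tap window are illustrative assumptions.

```python
import numpy as np

# Coefficients copied from mimo_system above.
b1, a1 = [0.2, -0.38, 0.18], [1.0, -1.9, 0.94]
a2 = [1.0, -1.85, 0.92]


def impulse_response(b, a, n):
    """Run the difference equation on a unit impulse (assumes a[0] == 1)."""
    h = np.zeros(n)
    for t in range(n):
        acc = b[t] if t < len(b) else 0.0
        for k in range(1, len(a)):
            if t - k >= 0:
                acc -= a[k] * h[t - k]
        h[t] = acc
    return h


for name, a in [("filter 1", a1), ("filter 2", a2)]:
    radius = np.abs(np.roots(a)).max()
    print(f"{name}: largest pole radius = {radius:.3f}")

h1 = impulse_response(b1, a1, 2000)
frac = np.sum(h1[:8] ** 2) / np.sum(h1 ** 2)
print(f"Energy in first 8 taps of filter 1: {100 * frac:.1f}%")
```
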
In [None]:
# Visualize MIMO signals
fig, axes = plt.subplots(2, 2, figsize=(15, 8))

n_plot = 2000
t_ms = np.arange(n_plot) / fs * 1000

# Inputs
axes[0, 0].plot(t_ms, x[:n_plot, 0], label='Input 1', alpha=0.7)
axes[0, 0].plot(t_ms, x[:n_plot, 1], label='Input 2', alpha=0.7)
axes[0, 0].set_xlabel('Time (ms)')
axes[0, 0].set_ylabel('Amplitude')
axes[0, 0].set_title('Input Signals (2 channels)')
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)

# Outputs
axes[0, 1].plot(t_ms, y[:n_plot, 0], label='Output 1', alpha=0.7)
axes[0, 1].plot(t_ms, y[:n_plot, 1], label='Output 2', alpha=0.7)
axes[0, 1].set_xlabel('Time (ms)')
axes[0, 1].set_ylabel('Amplitude')
axes[0, 1].set_title('Output Signals (2 channels)')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)

# Spectra: Input 1
f, Pxx1 = signal.welch(x[:, 0], fs=fs, nperseg=2048)
f, Pxx2 = signal.welch(x[:, 1], fs=fs, nperseg=2048)
axes[1, 0].semilogy(f / 1000, Pxx1, label='Input 1', alpha=0.7)
axes[1, 0].semilogy(f / 1000, Pxx2, label='Input 2', alpha=0.7)
axes[1, 0].set_xlabel('Frequency (kHz)')
axes[1, 0].set_ylabel('PSD (V²/Hz)')
axes[1, 0].set_title('Input Spectra')
axes[1, 0].set_xlim([0, 6])
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3)

# Spectra: Outputs
f, Pyy1 = signal.welch(y[:, 0], fs=fs, nperseg=2048)
f, Pyy2 = signal.welch(y[:, 1], fs=fs, nperseg=2048)
axes[1, 1].semilogy(f / 1000, Pyy1, label='Output 1', alpha=0.7)
axes[1, 1].semilogy(f / 1000, Pyy2, label='Output 2', alpha=0.7)
axes[1, 1].set_xlabel('Frequency (kHz)')
axes[1, 1].set_ylabel('PSD (V²/Hz)')
axes[1, 1].set_title('Output Spectra (with harmonics)')
axes[1, 1].set_xlim([0, 6])
axes[1, 1].legend()
axes[1, 1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

---

## 2. Fit Tensor-Train Volterra Model

We'll use `TTVolterraMIMO` with fixed, user-specified TT ranks (the effect of the rank choice is explored in Section 5).

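To see what a TT-Volterra model computes, here is a minimal sketch of the forward pass for a single order-$N$ term and a single output. It is an assumption-laden illustration, not the implementation inside `TTVolterraMIMO`: the helpers `delay_embed` and `tt_forward`, the core layout $(r_{k-1}, M \cdot I, r_k)$, and the demo sizes are all hypothetical. Every core is contracted with the same delay-embedded regressor vector, and the output has $T - M + 1$ samples because the first $M - 1$ samples lack a full history.

```python
import numpy as np


def delay_embed(x, memory):
    """Stack lagged copies of each input: columns are x_i(t - lag) for all i, lag."""
    T, n_inputs = x.shape
    rows = T - memory + 1                     # first M-1 samples lack full history
    cols = [x[memory - 1 - lag: memory - 1 - lag + rows, i]
            for i in range(n_inputs) for lag in range(memory)]
    return np.column_stack(cols)              # shape (T - M + 1, M * I)


def tt_forward(cores, U):
    """Order-N term: contract every core with the same regressor vector u_t."""
    state = np.ones((U.shape[0], 1))          # r_0 = 1
    for core in cores:                        # core shape (r_prev, M*I, r_next)
        step = np.einsum("tj,ajb->tab", U, core)   # per-sample r_prev x r_next matrix
        state = np.einsum("ta,tab->tb", state, step)
    return state[:, 0]                        # r_N = 1


rng = np.random.default_rng(1)
x_demo = rng.standard_normal((50, 2))
M_demo, ranks = 4, [1, 3, 3, 1]
cores = [rng.standard_normal((ranks[k], M_demo * 2, ranks[k + 1])) for k in range(3)]

U = delay_embed(x_demo, M_demo)
y_demo = tt_forward(cores, U)
print(U.shape, y_demo.shape)   # (47, 8) (47,)
```

The cost per sample scales with the core sizes rather than with the $(M \cdot I)^N$ entries of the full kernel, which is never formed.
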
In [None]:
# Split data
n_train = int(0.7 * n_samples)
x_train, x_test = x[:n_train], x[n_train:]
y_train, y_test = y[:n_train], y[n_train:]

print(f"Training samples: {n_train}")
print(f"Testing samples: {len(x_test)}")
print(f"Training data shapes: x={x_train.shape}, y={y_train.shape}")

In [None]:
# Create TT-Volterra model
tt_model = TTVolterraMIMO(
    memory_length=8,
    order=3,
    ranks=[1, 3, 3, 1],  # TT ranks (length = order + 1)
    max_iter=30,
    lambda_reg=1e-5
)

print("TT-Volterra Model Configuration:")
print(f"  Memory length: {tt_model.memory_length}")
print(f"  Nonlinearity order: {tt_model.order}")
print(f"  TT ranks: {tt_model.ranks}")
print(f"  Max iterations: {tt_model.max_iter}")
print(f"  Regularization: {tt_model.lambda_reg}")

# Estimate number of parameters per output (upper bound: the boundary cores are smaller)
M, N, I_in = tt_model.memory_length, tt_model.order, x_train.shape[1]
r = tt_model.ranks[1]  # Assume uniform rank
n_params_tt = N * M * I_in * r**2
n_params_full = I_in**N * M**N

print(f"\nParameter comparison:")
print(f"  TT-Volterra: ~{n_params_tt} parameters")
print(f"  Full Volterra: {n_params_full} parameters")
print(f"  Compression ratio: {n_params_full / n_params_tt:.1f}×")

In [None]:
# Fit the model
import time

print("Fitting TT-Volterra model...\n")
start_time = time.time()

tt_model.fit(x_train, y_train)

fit_time = time.time() - start_time

print(f"\nModel fitted in {fit_time:.2f} seconds")
print(f"Total parameters (per output): {tt_model.total_parameters()}")

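TT models are commonly fitted by alternating least squares (ALS): all cores but one are frozen, which makes the output linear in the free core, so each update reduces to a regularized least-squares solve. The sketch below shows one such update for the middle core of an order-3 term. It assumes this is roughly what `fit` does internally, which the notebook does not confirm. All names and sizes are illustrative, and random vectors stand in for delay-embedded inputs. Because the same regressor enters every mode, the design matrix can be rank-deficient, which is one reason a small `lambda_reg` helps.

```python
import numpy as np

rng = np.random.default_rng(2)
T, J, r1, r2 = 400, 4, 2, 2                  # samples, mode size, TT ranks (illustrative)
U = rng.standard_normal((T, J))              # stand-in for delay-embedded regressors
G1 = rng.standard_normal((1, J, r1))
G2_true = rng.standard_normal((r1, J, r2))
G3 = rng.standard_normal((r2, J, 1))

left = U @ G1[0]                             # (T, r1): first core contracted with u_t
right = U @ G3[:, :, 0].T                    # (T, r2): last core contracted with u_t
y = np.einsum("ta,tj,ajb,tb->t", left, U, G2_true, right)

# With G1 and G3 frozen, y is linear in vec(G2): one design row per sample.
A = np.einsum("ta,tj,tb->tajb", left, U, right).reshape(T, -1)

# Ridge solve via an augmented least-squares system (lambda_reg is the penalty weight).
lambda_reg = 1e-6
A_aug = np.vstack([A, np.sqrt(lambda_reg) * np.eye(A.shape[1])])
y_aug = np.concatenate([y, np.zeros(A.shape[1])])
g2, *_ = np.linalg.lstsq(A_aug, y_aug, rcond=None)

rel_err = np.linalg.norm(A @ g2 - y) / np.linalg.norm(y)
print(f"relative fit error after one core update: {rel_err:.2e}")
```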
---

## 3. Evaluate Model Performance

Let's evaluate each output separately.

In [None]:
# Predict on test set
y_test_pred = tt_model.predict(x_test)

# Trim ground truth to match prediction length
M = tt_model.memory_length
y_test_trimmed = y_test[M - 1:]

print(f"Prediction shape: {y_test_pred.shape}")
print(f"Ground truth (trimmed) shape: {y_test_trimmed.shape}")

# Compute NMSE for each output
def compute_nmse(y_true, y_pred):
    mse = np.mean((y_true - y_pred) ** 2)
    signal_power = np.mean(y_true ** 2)
    nmse_db = 10 * np.log10(mse / signal_power)
    return nmse_db

print("\nTest NMSE (per output):")
for o in range(O):
    nmse = compute_nmse(y_test_trimmed[:, o], y_test_pred[:, o])
    print(f"  Output {o + 1}: {nmse:.2f} dB")

# Overall NMSE
nmse_overall = compute_nmse(y_test_trimmed.ravel(), y_test_pred.ravel())
print(f"\nOverall NMSE: {nmse_overall:.2f} dB")

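To calibrate the dB figures, it helps to evaluate the same NMSE definition on errors of known size. The sketch below uses a synthetic sine wave. `nmse_db` mirrors `compute_nmse` above, and the signal is purely illustrative.

```python
import numpy as np


def nmse_db(y_true, y_pred):
    """Normalized mean squared error in dB (same definition as compute_nmse)."""
    return 10 * np.log10(np.mean((y_true - y_pred) ** 2) / np.mean(y_true ** 2))


t = np.linspace(0, 1, 1000)
y_ref = np.sin(2 * np.pi * 5 * t)

print(nmse_db(y_ref, 0.9 * y_ref))            # 10% amplitude error: about -20 dB
print(nmse_db(y_ref, 0.99 * y_ref))           # 1% amplitude error: about -40 dB
print(nmse_db(y_ref, np.zeros_like(y_ref)))   # predicting zero: 0 dB
```

Every 20 dB corresponds to a tenfold reduction in RMS error relative to the signal.
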
In [None]:
# Visualize predictions for both outputs
fig, axes = plt.subplots(2, 3, figsize=(18, 10))

n_plot = 1500
t_ms = np.arange(n_plot) / fs * 1000

for o in range(O):
    # Time domain comparison
    axes[o, 0].plot(t_ms, y_test_trimmed[:n_plot, o], label='True', alpha=0.7, linewidth=1.5)
    axes[o, 0].plot(t_ms, y_test_pred[:n_plot, o], label='TT predicted', alpha=0.7, linewidth=1.5, linestyle='--')
    axes[o, 0].set_xlabel('Time (ms)')
    axes[o, 0].set_ylabel('Amplitude')
    nmse = compute_nmse(y_test_trimmed[:, o], y_test_pred[:, o])
    axes[o, 0].set_title(f'Output {o + 1} Prediction (NMSE: {nmse:.2f} dB)')
    axes[o, 0].legend()
    axes[o, 0].grid(True, alpha=0.3)
    
    # Prediction error
    error = y_test_trimmed[:n_plot, o] - y_test_pred[:n_plot, o]
    axes[o, 1].plot(t_ms, error, alpha=0.7, color='red')
    axes[o, 1].axhline(0, color='black', linestyle='--', linewidth=0.8)
    axes[o, 1].set_xlabel('Time (ms)')
    axes[o, 1].set_ylabel('Error')
    axes[o, 1].set_title(f'Output {o + 1} Error (RMS: {np.sqrt(np.mean(error**2)):.4f})')
    axes[o, 1].grid(True, alpha=0.3)
    
    # Scatter: predicted vs. true
    axes[o, 2].scatter(y_test_trimmed[::10, o], y_test_pred[::10, o], alpha=0.3, s=2)
    y_range = [y_test_trimmed[:, o].min(), y_test_trimmed[:, o].max()]
    axes[o, 2].plot(y_range, y_range, 'r--', linewidth=2, label='Perfect')
    axes[o, 2].set_xlabel('True output')
    axes[o, 2].set_ylabel('Predicted output')
    axes[o, 2].set_title(f'Output {o + 1}: Predicted vs. True')
    axes[o, 2].legend()
    axes[o, 2].grid(True, alpha=0.3)
    axes[o, 2].axis('equal')

plt.tight_layout()
plt.show()

---

## 4. Analyze TT Cores and Model Structure

Let's examine the learned TT cores to understand the model's internal structure.

In [None]:
# Access learned TT cores for first output
cores_output_0 = tt_model.get_cores(output_idx=0)

print(f"Number of TT cores (output 0): {len(cores_output_0)}")
print("TT core shapes:")
for k, core in enumerate(cores_output_0):
    print(f"  Core {k + 1}: {core.shape}")

# Count total parameters
total_params = sum(core.size for core in cores_output_0)
print(f"\nTotal parameters (output 0): {total_params}")

In [None]:
# Visualize TT core magnitudes
fig, axes = plt.subplots(1, len(cores_output_0), figsize=(15, 4))

for k, core in enumerate(cores_output_0):
    # Reshape core for visualization: (r_{k-1}, memory × inputs, r_k)
    # For simplicity, show the Frobenius norm of each (lag, input) slice
    core_norms = np.linalg.norm(core.reshape(core.shape[0], -1, core.shape[-1]), axis=(0, 2))
    
    axes[k].bar(range(len(core_norms)), core_norms, alpha=0.7, edgecolor='black')
    axes[k].set_xlabel('Mode index (memory × input)')
    axes[k].set_ylabel('Frobenius norm')
    axes[k].set_title(f'Core {k + 1} (mode {k + 1})')
    axes[k].grid(True, alpha=0.3, axis='y')

plt.suptitle('TT Core Magnitudes (Output 0)', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

print("\nInterpretation:")
print("  - Taller bars mark the (lag, input) entries that a core weights most heavily")
print("  - Core k corresponds to the k-th lag argument m_k of the kernel, not to the k-th nonlinearity order")
print("  - Cores are unique only up to invertible transforms between neighbours, so compare profiles rather than absolute scales")

---

## 5. Rank Selection: Impact on Performance

How do TT ranks affect model performance? Let's compare different rank configurations.

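A useful fact when reasoning about rank: the $k$-th TT rank of a tensor equals the matrix rank of its $k$-th unfolding (the first $k$ modes as rows, the rest as columns). The sketch below builds a small tensor from random cores with ranks `[1, 2, 2, 1]` and recovers those ranks from the unfoldings. The sizes and the random cores are illustrative.

```python
import numpy as np

rng = np.random.default_rng(3)
M, ranks = 5, [1, 2, 2, 1]
cores = [rng.standard_normal((ranks[k], M, ranks[k + 1])) for k in range(3)]
kernel = np.einsum("aib,bjc,ckd->ijk", *cores)      # full 5 x 5 x 5 tensor

# k-th unfolding: first k modes as rows, remaining modes as columns.
unfold_1 = kernel.reshape(M, M * M)
unfold_2 = kernel.reshape(M * M, M)

print(np.linalg.matrix_rank(unfold_1), np.linalg.matrix_rank(unfold_2))   # 2 2
print(np.linalg.svd(unfold_1, compute_uv=False).round(3))
```

For a measured kernel the trailing singular values are small rather than zero, and how fast they decay indicates how much accuracy each extra unit of rank can buy.
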
In [None]:
# Test different TT ranks
rank_configs = [
    [1, 1, 1, 1],  # Rank-1 (fully separable kernel)
    [1, 2, 2, 1],  # Rank-2
    [1, 3, 3, 1],  # Rank-3
    [1, 4, 4, 1],  # Rank-4
    [1, 5, 5, 1],  # Rank-5
]

results = []

print("Testing different TT ranks...\n")

for ranks in rank_configs:
    model = TTVolterraMIMO(
        memory_length=8,
        order=3,
        ranks=ranks,
        max_iter=20,
        lambda_reg=1e-5
    )
    
    # Fit and predict
    start = time.time()
    model.fit(x_train, y_train)
    fit_time = time.time() - start
    
    y_pred = model.predict(x_test)
    
    # Evaluate
    nmse = compute_nmse(y_test_trimmed.ravel(), y_pred.ravel())
    n_params = model.total_parameters()
    
    results.append({
        'ranks': ranks,
        'max_rank': max(ranks),
        'n_params': n_params,
        'nmse_db': nmse,
        'fit_time': fit_time
    })
    
    print(f"Ranks {ranks}: {n_params:4d} params, NMSE = {nmse:6.2f} dB, time = {fit_time:.2f}s")

# Plot results
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

max_ranks = [r['max_rank'] for r in results]
n_params = [r['n_params'] for r in results]
nmses = [r['nmse_db'] for r in results]
fit_times = [r['fit_time'] for r in results]

# NMSE vs. parameters
axes[0].plot(n_params, nmses, 'o-', markersize=10, linewidth=2)
for i, r in enumerate(results):
    axes[0].text(n_params[i], nmses[i] + 0.5, f"r={r['max_rank']}", 
                ha='center', fontsize=10, fontweight='bold')
axes[0].set_xlabel('Number of Parameters')
axes[0].set_ylabel('Test NMSE (dB)')
axes[0].set_title('Rank vs. Performance Trade-off')
axes[0].axhline(-20, color='green', linestyle='--', linewidth=1.5, alpha=0.5, label='Excellent')
axes[0].grid(True, alpha=0.3)
axes[0].legend()

# Fit time vs. rank
axes[1].bar(range(len(max_ranks)), fit_times, alpha=0.7, edgecolor='black')
axes[1].set_xticks(range(len(max_ranks)))
axes[1].set_xticklabels([f"r={r}" for r in max_ranks])
axes[1].set_xlabel('Max TT Rank')
axes[1].set_ylabel('Fit Time (seconds)')
axes[1].set_title('Computational Cost vs. Rank')
axes[1].grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

print("\nKey observations:")
print("  - Rank 1 = separable kernel (a product of filtered inputs): fast but may underfit")
print("  - Higher ranks capture more complex interactions")
print("  - Diminishing returns: rank 3-4 often sufficient for many systems")
print("  - Fit time grows with rank; compare the measured times with the ~r^2 per-iteration scaling")

---

## Summary

In this notebook, we:

1. **Generated MIMO nonlinear system data** (2 inputs, 2 outputs)
2. **Fitted a Tensor-Train Volterra model** with fixed, user-specified TT ranks
3. **Evaluated performance** for each output channel
4. **Analyzed TT core structure** to understand learned representations
5. **Compared different TT ranks** to guide hyperparameter selection

### When to use TT-Volterra:
- ✅ **High-dimensional MIMO** (many inputs/outputs)
- ✅ **High order or long memory** (N > 3 or M > 10)
- ✅ **Parameter efficiency critical** (embedded systems, real-time)
- ✅ **Full cross-input interactions** needed (not just diagonal)
- ❌ **Low-dimensional SISO** → use MP or GMP (simpler, faster)
- ❌ **Very large datasets** → TT-ALS can be slow (use stochastic methods)

### Practical recommendations:
1. **Start with rank 2-3** for most applications
2. **Increase rank if underfitting** (poor NMSE)
3. **Use regularization** (`lambda_reg > 0`) to prevent overfitting
4. **Monitor convergence** via `fit_info_` diagnostics
5. **Consider rank adaptation** (TT-MALS) for automatic rank selection

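Recommendations 1 and 2 can be turned into a simple selection rule: among the ranks tried, keep the cheapest model whose NMSE is within a small tolerance of the best. The sketch below assumes result dictionaries shaped like those collected in Section 5. `pick_smallest_adequate`, the 0.5 dB tolerance, and the numbers in `demo_results` are placeholders for illustration, not measured results.

```python
def pick_smallest_adequate(results, tol_db=0.5):
    """Return the cheapest entry whose NMSE is within tol_db of the best one."""
    best = min(r["nmse_db"] for r in results)
    adequate = [r for r in results if r["nmse_db"] <= best + tol_db]
    return min(adequate, key=lambda r: r["n_params"])


# Placeholder numbers for illustration only, not measured results.
demo_results = [
    {"max_rank": 1, "n_params": 48, "nmse_db": -12.0},
    {"max_rank": 2, "n_params": 128, "nmse_db": -21.0},
    {"max_rank": 3, "n_params": 240, "nmse_db": -21.3},
    {"max_rank": 4, "n_params": 384, "nmse_db": -21.4},
]

print(pick_smallest_adequate(demo_results)["max_rank"])   # 2
```

In practice the NMSE used for this choice should come from a validation split held out from the training data, so that the test set stays untouched for the final report.
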
### Computational complexity:
- **TT-Volterra**: $O(N \cdot M \cdot I \cdot r^2 \cdot T)$ per iteration
- **Full Volterra**: $O(I^N \cdot M^N \cdot T)$, which quickly becomes impractical as $N$ grows
- **Memory Polynomial**: $O(M \cdot N \cdot T)$, the fastest option but with limited expressiveness

### Next steps:
- **Notebook 03**: Automatic model selection (MP vs GMP vs TT-Full)
- **Notebook 04**: Real-world application (instrument + room pipeline)