# BPNet LSTM Experiments

This notebook summarizes the experiments we ran for the BPNet project. We compare two implementations:

- **LSTM Simple** – the original 2-layer, 128-hidden LSTM stack (no scheduler or weight decay).
- **LSTM Tuned** – the larger 3-layer, 256-hidden LSTM with scheduling/regularization knobs.

For each model we train on different dataset sizes (e.g., 20k vs. 50k segments) and visualize how the validation metrics evolve. The notebook scans the `runs/` directory for training logs, so rerunning it after adding a new run refreshes the tables and plots.

## 1. Imports & Utility Helpers

In [None]:
from pathlib import Path
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

plt.style.use('seaborn-v0_8')
ROOT = Path('..').resolve()
RUNS_ROOT = ROOT / 'runs'
print(f'Notebook root: {ROOT}')

## 2. Load Metrics from All Runs

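Each training run writes a `metrics.csv` and usually a `config.json`. We parse both to recover the model type, the dataset slice (inferred from the `train_mat` path in the config), and the per-epoch statistics. Runs without a `config.json` are labelled `lstm_simple` with an `unknown` dataset.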
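To exercise the loader without real training logs, a minimal sketch can write one synthetic run in the expected layout to a temporary directory. The column names mirror the ones the loading cell reads; the run name `lstm_simple_demo`, the `train_mat` value `oscar_train.mat`, the helper `write_synthetic_run`, and every number are hypothetical placeholders.

```python
import json
import tempfile
from pathlib import Path

import pandas as pd


def write_synthetic_run(runs_root: Path, name: str, train_mat: str, n_epochs: int = 3) -> Path:
    """Write a fake run (metrics.csv + config.json) with columns the loader reads."""
    log_dir = runs_root / name
    log_dir.mkdir(parents=True, exist_ok=True)
    epochs = range(1, n_epochs + 1)
    metrics = pd.DataFrame({
        'epoch': list(epochs),
        'lr': [1e-3] * n_epochs,
        # Placeholder values: loss falls and correlation rises each epoch
        'train_loss': [1.0 / e for e in epochs],
        'val_loss': [1.2 / e for e in epochs],
        'val_corr': [0.5 + 0.1 * e for e in epochs],
    })
    metrics.to_csv(log_dir / 'metrics.csv', index=False)
    # Only the key the notebook inspects is written here
    (log_dir / 'config.json').write_text(json.dumps({'train_mat': train_mat}))
    return log_dir


tmp_root = Path(tempfile.mkdtemp()) / 'runs'
demo_dir = write_synthetic_run(tmp_root, 'lstm_simple_demo', 'oscar_train.mat')
print(sorted(p.name for p in demo_dir.iterdir()))
```

With the notebook's rules, this run would be labelled `lstm_simple` (the directory name contains `simple`) and `20k segments` (the path contains `oscar`).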

In [None]:
def infer_dataset(train_mat: str) -> str:
    name = train_mat.lower()
    if '50k' in name:
        return '50k segments'
    if 'x10' in name or '200k' in name:
        return '200k segments'
    if 'train_subset' in name:
        return 'full subset mat'
    if 'oscar' in name:
        return '20k segments'
    return 'unknown'


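def infer_model(log_dir: Path, config: dict) -> str:
    # Match only on the path below runs/, so a parent folder whose name
    # happens to contain 'simple' cannot relabel every run
    log_name = log_dir.relative_to(RUNS_ROOT).as_posix().lower()
    # Runs without a config.json are treated as the baseline model
    if 'simple' in log_name or not config:
        return 'lstm_simple'
    return 'lstm_tuned'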

records = []
# sorted() makes the row order reproducible across filesystems
for metrics_path in sorted(RUNS_ROOT.rglob('metrics.csv')):
    log_dir = metrics_path.parent
    config_path = log_dir / 'config.json'
    config = {}
    if config_path.exists():
        with open(config_path) as f:
            config = json.load(f)
    df = pd.read_csv(metrics_path)
    train_mat = config.get('train_mat', '')
    dataset = infer_dataset(train_mat)
    model_type = infer_model(log_dir, config)
    for _, row in df.iterrows():
        records.append({
            'log_dir': str(log_dir.relative_to(ROOT)),
            'model': model_type,
            'dataset': dataset,
            'epoch': int(row['epoch']),
            'lr': row.get('lr'),
            'train_loss': row['train_loss'],
            'train_mae': row.get('train_mae'),
            'train_rmse': row.get('train_rmse'),
            'train_corr': row.get('train_corr'),
            'val_loss': row.get('val_loss'),
            'val_mae': row.get('val_mae'),
            'val_rmse': row.get('val_rmse'),
            'val_corr': row.get('val_corr'),
        })

metrics_df = pd.DataFrame(records)
if metrics_df.empty:
    raise RuntimeError('No metrics.csv files found. Make sure training runs have logged metrics.')

metrics_df.head()
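The row-by-row `iterrows` loop is easy to read but slow on long runs. A vectorized sketch of the same per-file step uses `DataFrame.reindex`, which turns any missing metric column into a NaN column. It is a close rather than exact equivalent: the loop stores `None` for missing metrics and raises if `train_loss` is absent, while this version fills NaN. The helper name `tidy_metrics` is hypothetical.

```python
import pandas as pd

# Columns the notebook keeps from each metrics.csv, in the same order as the loop
METRIC_COLS = ['epoch', 'lr', 'train_loss', 'train_mae', 'train_rmse', 'train_corr',
               'val_loss', 'val_mae', 'val_rmse', 'val_corr']


def tidy_metrics(df: pd.DataFrame, log_dir: str, model: str, dataset: str) -> pd.DataFrame:
    """Vectorized version of the per-row records loop for a single run."""
    out = df.reindex(columns=METRIC_COLS)  # absent metrics become NaN columns
    out['epoch'] = out['epoch'].astype(int)
    # Prepend run-level labels (inserted in reverse so log_dir ends up first)
    out.insert(0, 'dataset', dataset)
    out.insert(0, 'model', model)
    out.insert(0, 'log_dir', log_dir)
    return out


raw = pd.DataFrame({'epoch': [1, 2], 'train_loss': [0.9, 0.7], 'val_loss': [1.0, 0.8]})
tidy = tidy_metrics(raw, 'runs/demo', 'lstm_simple', '20k segments')
print(tidy.columns.tolist()[:5])
```

Per-run frames built this way can be combined with `pd.concat(frames, ignore_index=True)` instead of accumulating a list of dicts.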

## 3. Best-epoch Summary

We aggregate the best validation loss and corresponding correlation for each (model, dataset, run) combination.

In [None]:
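keys = ['log_dir', 'model', 'dataset']
valid = metrics_df.dropna(subset=['val_loss'])
# Take each run's best-loss epoch so the correlation comes from that same epoch
# (a separate max(val_corr) could pick a value from a different epoch)
best_rows = valid.loc[valid.groupby(keys)['val_loss'].idxmin()]
summary = (best_rows[keys + ['epoch', 'val_loss', 'val_corr']]
           .rename(columns={'epoch': 'best_epoch',
                            'val_loss': 'best_val_loss',
                            'val_corr': 'val_corr_at_best'})
           .merge(valid.groupby(keys, as_index=False).agg(epochs=('epoch', 'max')), on=keys)
           .sort_values('best_val_loss')
           .reset_index(drop=True))
summary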
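Taking `min` of the loss and `max` of the correlation in separate aggregations can combine numbers from different epochs. A toy log with made-up values shows the difference between that and selecting each run's best-loss row with `idxmin`.

```python
import pandas as pd

# Toy log for one run (hypothetical numbers): loss is lowest at epoch 3,
# but correlation peaks one epoch earlier
toy = pd.DataFrame({
    'log_dir': ['runs/demo'] * 4,
    'epoch': [1, 2, 3, 4],
    'val_loss': [0.50, 0.42, 0.40, 0.41],
    'val_corr': [0.70, 0.81, 0.79, 0.78],
})

# Independent aggregation: min loss and max corr may come from different epochs
independent = toy.groupby('log_dir').agg(best_val_loss=('val_loss', 'min'),
                                         best_val_corr=('val_corr', 'max'))
# Paired selection: every column comes from the best-loss epoch
paired = toy.loc[toy.groupby('log_dir')['val_loss'].idxmin(), ['epoch', 'val_loss', 'val_corr']]

print(independent)
print(paired)
```

Here the independent table reports a correlation of 0.81 next to the epoch-3 loss, whereas the paired row reports 0.79, the value actually observed at epoch 3.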

## 4. Validation Curves

The following plots compare how validation loss and correlation evolve for each run. Line color identifies the run (`log_dir`) and line style the dataset size.

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
sns.lineplot(data=metrics_df, x='epoch', y='val_loss', hue='log_dir', style='dataset', ax=axes[0])
axes[0].set_title('Validation Loss per Run')
axes[0].set_ylabel('Val Loss (MSE)')
axes[0].grid(alpha=0.3)

sns.lineplot(data=metrics_df, x='epoch', y='val_corr', hue='log_dir', style='dataset', ax=axes[1])
axes[1].set_title('Validation Correlation per Run')
axes[1].set_ylabel('Pearson r')
axes[1].grid(alpha=0.3)
plt.tight_layout()
plt.show()
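The right-hand panel is labelled Pearson r. The training code that produces `val_corr` is not part of this notebook, so whether it is pooled over all samples or averaged per segment is an assumption here. A minimal NumPy sketch of both variants (the helper names `pearson_r` and `mean_segment_r` are hypothetical) shows why the distinction matters when predictions carry per-segment offsets.

```python
import numpy as np


def pearson_r(x: np.ndarray, y: np.ndarray) -> float:
    """Pearson correlation of two equally shaped arrays, flattened."""
    x = np.asarray(x, dtype=float).ravel()
    y = np.asarray(y, dtype=float).ravel()
    xc, yc = x - x.mean(), y - y.mean()
    return float((xc @ yc) / np.sqrt((xc @ xc) * (yc @ yc)))


def mean_segment_r(targets: np.ndarray, preds: np.ndarray) -> float:
    """Average of per-segment correlations; axis 0 indexes segments."""
    return float(np.mean([pearson_r(t, p) for t, p in zip(targets, preds)]))


rng = np.random.default_rng(0)
targets = rng.normal(size=(4, 1, 250))  # (segment, channel, time), synthetic
# Each predicted segment gets its own constant offset: the per-segment r ignores
# it, while the pooled r treats it as unexplained variance
preds = targets + rng.normal(scale=0.3, size=targets.shape) + np.arange(4).reshape(4, 1, 1)
print(f'pooled r = {pearson_r(targets, preds):.3f}, '
      f'mean per-segment r = {mean_segment_r(targets, preds):.3f}')
```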

## 5. Sample Predictions

For qualitative inspection we load the most recent `epoch_*_samples.npz` file from the best run (lowest validation loss) and plot the first few target/prediction pairs.

In [None]:
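import re


def _epoch_number(path: Path) -> int:
    """Parse N from 'epoch_N_samples.npz'; names without a number sort first."""
    match = re.search(r'epoch_(\d+)_samples', path.name)
    return int(match.group(1)) if match else -1


def preview_samples(log_dir: Path, max_pairs: int = 3):
    # Numeric sort: a plain sorted() would put 'epoch_10' before 'epoch_9'
    sample_files = sorted(log_dir.glob('epoch_*_samples.npz'), key=_epoch_number)
    if not sample_files:
        print(f'No sample files for {log_dir}')
        return
    print(f'Loading {sample_files[-1].name}')
    # Arrays are indexed as (segment, channel, time)
    with np.load(sample_files[-1]) as data:
        targets, preds = data['target_bp'], data['pred_bp']
    pairs = min(max_pairs, targets.shape[0])
    # squeeze=False always returns a 2-D grid of axes, even for a single row
    fig, axes = plt.subplots(pairs, 1, figsize=(10, 2 * pairs), squeeze=False)
    t = np.arange(targets.shape[-1])
    for i, ax in enumerate(axes[:, 0]):
        ax.plot(t, targets[i, 0], label='Target')
        ax.plot(t, preds[i, 0], label='Predicted', alpha=0.8)
        ax.set_title(f'Segment {i + 1}')
        ax.set_xlabel('Sample idx')
        ax.set_ylabel('BP (norm)')
        ax.grid(alpha=0.3)
    axes[0, 0].legend()
    plt.tight_layout()
    plt.show()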

best_log = summary.sort_values('best_val_loss').iloc[0]['log_dir']
print(f'Previewing samples for {best_log}')
preview_samples(ROOT / best_log)
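Sorting `epoch_*_samples.npz` names as plain strings stops matching epoch order once epochs reach two digits. This sketch writes three dummy sample files to a temporary directory and compares lexicographic order with picking the highest parsed epoch number. The `(segments, channels, time)` shape, the epochs, and the helper `latest_epoch_file` are illustrative assumptions.

```python
import re
import tempfile
from pathlib import Path

import numpy as np


def latest_epoch_file(log_dir: Path) -> Path:
    """Return the epoch_*_samples.npz with the highest epoch number."""
    def epoch_number(p: Path) -> int:
        m = re.search(r'epoch_(\d+)_samples', p.name)
        return int(m.group(1)) if m else -1
    return max(log_dir.glob('epoch_*_samples.npz'), key=epoch_number)


demo_dir = Path(tempfile.mkdtemp())
rng = np.random.default_rng(0)
for epoch in (2, 9, 10):
    # Assumed layout (segments, channels, time), consistent with targets[i, 0]
    target = rng.normal(size=(5, 1, 125))
    np.savez(demo_dir / f'epoch_{epoch}_samples.npz', target_bp=target, pred_bp=0.9 * target)

names = sorted(p.name for p in demo_dir.glob('*.npz'))
print(names[-1])                          # epoch_9_samples.npz (string order)
print(latest_epoch_file(demo_dir).name)   # epoch_10_samples.npz (numeric order)
```

Zero-padding epoch numbers when saving (e.g. `epoch_010`) would also make string order match epoch order.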

## 6. Architecture Reference

| Model | Conv Layers | LSTM Hidden Size / Layers | Scheduler & Weight Decay |
|-------|-------------|---------------------------|---------------------------|
| `lstm_simple` | 2 conv layers (32 filters, kernel 7) | 128 hidden units × 2 layers, dropout 0.1 | None |
| `lstm_tuned`  | Same conv front-end | 256 hidden units × 3 layers, dropout 0.2 | Adam + weight decay 1e-4 + ReduceLROnPlateau |

Both models share the same per-segment normalization and MSE training objective. The tuned variant adds capacity (a wider, deeper LSTM with more dropout) and optimization controls, while the simple variant is the baseline used earlier in the project.
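To put "adds capacity" in numbers, the recurrent parameter counts of the two stacks can be worked out by hand. This sketch assumes PyTorch's `nn.LSTM` layout (two bias vectors per layer; Keras-style layers with a single bias come out slightly smaller), a unidirectional LSTM, and an LSTM input size equal to the 32 conv filters. Conv and output-head parameters are excluded, and `lstm_params` is a hypothetical helper.

```python
def lstm_params(input_size: int, hidden: int, layers: int) -> int:
    """Parameter count of a unidirectional stacked LSTM, PyTorch nn.LSTM convention.

    Each layer has weight_ih (4h x in), weight_hh (4h x h) and two bias
    vectors (b_ih, b_hh) of length 4h; layers after the first take h inputs.
    """
    total = 0
    in_size = input_size
    for _ in range(layers):
        total += 4 * hidden * (in_size + hidden) + 8 * hidden
        in_size = hidden
    return total


CONV_CHANNELS = 32  # assumed LSTM input size: output channels of the conv front-end
simple = lstm_params(CONV_CHANNELS, hidden=128, layers=2)
tuned = lstm_params(CONV_CHANNELS, hidden=256, layers=3)
print(f'lstm_simple: {simple:,} LSTM params')
print(f'lstm_tuned:  {tuned:,} LSTM params ({tuned / simple:.1f}x)')
```

Under these assumptions the tuned stack has 1,349,632 recurrent parameters versus 215,040 for the simple one, roughly 6.3 times as many.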

## 7. Takeaways

- The tuned LSTM consistently reaches lower validation loss and higher validation correlation when trained on 20k segments, but it relies on the scheduler and weight decay to train stably.
- Scaling the dataset (e.g., to 50k segments or more) shifts the validation-loss curves downward for both models, showing the benefit of exposing the LSTM to more subjects.
- The sample visualizations remain useful for sanity-checking: look for systematic phase shifts or amplitude bias when you iterate on preprocessing.
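The first takeaway leans on `ReduceLROnPlateau`. Its core rule is short enough to sketch in plain Python, which also shows what to look for in a logged `lr` column: flat stretches followed by step drops. This is a simplified re-implementation for illustration (mode `'min'`, with no threshold, cooldown, or `min_lr`), not PyTorch's class; `plateau_schedule`, the `factor` and `patience` values, and the losses are hypothetical, since the project's actual scheduler settings are not recorded here.

```python
def plateau_schedule(val_losses, lr=1e-3, factor=0.1, patience=2):
    """Simplified ReduceLROnPlateau (mode='min'): return the lr in effect after each epoch.

    Multiplies lr by `factor` once the loss has failed to improve on its best
    value for more than `patience` consecutive epochs, then resets the counter.
    """
    best = float('inf')
    bad_epochs = 0
    lrs = []
    for loss in val_losses:
        if loss < best:
            best, bad_epochs = loss, 0
        else:
            bad_epochs += 1
        if bad_epochs > patience:
            lr *= factor
            bad_epochs = 0
        lrs.append(lr)
    return lrs


# Hypothetical validation losses: improvement stalls after epoch 3
losses = [1.0, 0.8, 0.7, 0.71, 0.72, 0.70, 0.69, 0.69]
print(plateau_schedule(losses))
```

With `patience=2`, the third non-improving epoch (epoch 6) triggers the drop from 1e-3 to 1e-4.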

Feel free to rerun training with different dataset slices; once their logs land in `runs/`, rerunning this notebook adds them to the tables and plots.