In [4]:
#!/usr/bin/env python3
"""
amplify_all_preproc_fast.py

Fast amplification (Hilbert envelope) for all processed_signals/{acoustic,vibration}/*.preproc.npz

- Default envelope mode: 'hilbert' (fast)
- Overwrites each .preproc.npz safely (writes <base>.tmp.npz then os.replace)
- Tunables at top: ALPHA, SMOOTH_MS, FREQ_BAND, ENVELOPE_MODE

Usage:
    python amplify_all_preproc_fast.py
"""

import os
import glob
import json
import time
from datetime import datetime
import traceback

import numpy as np
from scipy import ndimage, signal
from scipy.signal import butter, filtfilt, hilbert

# Optional: pywt if you intend to use 'cwt' envelope mode
try:
    import pywt
except Exception:
    pywt = None

# ----------------- USER TUNABLES -----------------
PROC_BASE = 'processed_signals'
FOLDERS = ['acoustic', 'vibration']  # subfolders to process (under PROC_BASE)

# Amplification parameters
ALPHA = 6.0                # multiplicative gain via (1 + ALPHA * envelope)
SMOOTH_MS = 8.0            # envelope smoothing (milliseconds)
# Frequency band to detect envelope energy (tune per modality)
FREQ_BAND = (100.0, 5000.0)  # (low_hz, high_hz) or None for automatic

# Envelope method: 'hilbert' (fast), 'stft' (medium), 'cwt' (slow, requires pywt)
ENVELOPE_MODE = 'hilbert'

# CWT settings (only used if ENVELOPE_MODE == 'cwt')
N_SCALES = 128

# STFT settings (only used if ENVELOPE_MODE == 'stft')
STFT_NPERSEG = 512

# Hilbert bandpass filter order
HILBERT_BP_ORDER = 4

# dtype for saving
DTYPE = np.float32

# ----------------- HELPERS -----------------
def safe_load_npz(path):
    try:
        return np.load(path, allow_pickle=True)
    except Exception as e:
        print(f"[ERR] failed loading {path}: {e}")
        return None

def try_parse_meta(meta_raw):
    if meta_raw is None:
        return {}
    try:
        # meta stored as json.dumps -> string or bytes
        if isinstance(meta_raw, np.ndarray):
            val = meta_raw.tolist()
        else:
            val = meta_raw
        if isinstance(val, bytes):
            val = val.decode('utf-8')
        if isinstance(val, str):
            return json.loads(val)
        if isinstance(val, dict):
            return val
    except Exception:
        try:
            return json.loads(meta_raw)
        except Exception:
            pass
    return {}

def smooth_envelope(env, fs, smooth_ms=SMOOTH_MS):
    sigma_samples = max(1.0, (smooth_ms / 1000.0) * fs)
    return ndimage.gaussian_filter1d(env, sigma=sigma_samples)

def bandpass_filter(sig, fs, low, high, order=HILBERT_BP_ORDER):
    ny = 0.5 * fs
    low_n = max(1e-8, low / ny)
    high_n = min(0.9999999, high / ny)
    if low_n >= high_n:
        return sig
    b, a = butter(order, [low_n, high_n], btype='band')
    try:
        return filtfilt(b, a, sig)
    except Exception:
        # fallback to single-pass filter if filtfilt fails
        try:
            return signal.lfilter(b, a, sig)
        except Exception:
            return sig

# ----------------- ENVELOPE METHODS -----------------
def envelope_via_hilbert(channels, fs, freq_band=None, smooth_ms=SMOOTH_MS):
    """
    Fast envelope builder:
      - bandpass per channel (freq_band)
      - analytic signal (Hilbert) -> magnitude
      - sum across channels, smooth, normalize to 0..1
    """
    C, N = channels.shape
    total_env = np.zeros(N, dtype=np.float64)
    if freq_band is None:
        low = 20.0
        high = min(0.45 * fs, fs / 2.0)
    else:
        low, high = float(freq_band[0]), float(freq_band[1])

    for c in range(C):
        sig = channels[c].astype(np.float64)
        if sig.size == 0:
            continue
        # quick skip: constant or zero channel
        if np.all(sig == 0) or np.nanstd(sig) == 0:
            continue
        try:
            bp = bandpass_filter(sig, fs, low, high, order=HILBERT_BP_ORDER)
            analytic = hilbert(bp)
            env = np.abs(analytic)
        except Exception:
            # fallback: absolute detrended signal smoothed
            env = np.abs(signal.detrend(sig))
            env = ndimage.gaussian_filter1d(env, sigma=max(1.0, (smooth_ms / 1000.0) * fs))
        total_env += env

    env_sm = smooth_envelope(total_env, fs, smooth_ms=smooth_ms)
    env_sm = env_sm - np.min(env_sm)
    mx = np.max(env_sm)
    if mx > 0:
        env_norm = (env_sm / (mx + 1e-12)).astype(np.float32)
    else:
        env_norm = env_sm.astype(np.float32)
    return env_norm

def envelope_via_stft(channels, fs, freq_band=None, smooth_ms=SMOOTH_MS, nperseg=STFT_NPERSEG):
    """
    Coarse STFT-based envelope per channel, mapped back to sample grid.
    """
    C, N = channels.shape
    total_env = np.zeros(N, dtype=np.float64)
    noverlap = int(0.5 * nperseg)
    sample_times = np.linspace(0, (N - 1) / fs, N)
    for c in range(C):
        sig = channels[c].astype(np.float64)
        if sig.size == 0 or np.nanstd(sig) == 0:
            continue
        try:
            f, t_seg, Sxx = signal.spectrogram(sig, fs=fs, nperseg=nperseg, noverlap=noverlap, mode='magnitude', scaling='density')
            if freq_band is None:
                mask = np.ones_like(f, dtype=bool)
            else:
                f_low, f_high = float(freq_band[0]), float(freq_band[1])
                mask = (f >= f_low) & (f <= f_high)
                if mask.sum() == 0:
                    mask = np.ones_like(f, dtype=bool)
            energy = np.sum(Sxx[mask, :], axis=0)
            energy_interp = np.interp(sample_times, t_seg, energy)
            total_env += energy_interp
        except Exception:
            # fallback quickly to hilbert
            return envelope_via_hilbert(channels, fs, freq_band=freq_band, smooth_ms=smooth_ms)
    env_sm = smooth_envelope(total_env, fs, smooth_ms=smooth_ms)
    env_sm = env_sm - np.min(env_sm)
    mx = np.max(env_sm)
    if mx > 0:
        env_norm = (env_sm / (mx + 1e-12)).astype(np.float32)
    else:
        env_norm = env_sm.astype(np.float32)
    return env_norm

def envelope_via_cwt(channels, fs, freq_band=None, n_scales=N_SCALES, smooth_ms=SMOOTH_MS):
    """
    Slow CWT-based envelope (kept for reference). Requires pywt.
    """
    if pywt is None:
        raise RuntimeError("pywt not available for CWT envelope mode.")
    C, N = channels.shape
    total_env = np.zeros(N, dtype=np.float64)
    fc = pywt.central_frequency('morl')
    fmax = min(0.45 * fs, fs / 2.0)
    fmin = 1.0 if (freq_band is None) else max(1.0, float(freq_band[0]))
    freqs = np.logspace(np.log10(fmin), np.log10(max(fmin * 1.01, fmax)), num=n_scales)
    dt = 1.0 / fs
    scales = fc / (freqs * dt)
    for c in range(C):
        sig = channels[c].astype(np.float64)
        if sig.size == 0 or np.nanstd(sig) == 0:
            continue
        coef, _ = pywt.cwt(sig, scales, 'morl', sampling_period=dt)
        power = np.abs(coef)
        med = np.median(power, axis=1) + 1e-12
        power_scaled = power / med[:, None]
        if freq_band is None:
            mask = np.ones_like(freqs, dtype=bool)
        else:
            f_low, f_high = float(freq_band[0]), float(freq_band[1])
            mask = (freqs >= f_low) & (freqs <= f_high)
            if mask.sum() == 0:
                mask = np.ones_like(freqs, dtype=bool)
        env_c = np.sum(power_scaled[mask, :], axis=0)
        total_env += env_c
    env_sm = smooth_envelope(total_env, fs, smooth_ms=smooth_ms)
    env_sm = env_sm - np.min(env_sm)
    mx = np.max(env_sm)
    if mx > 0:
        env_norm = (env_sm / (mx + 1e-12)).astype(np.float32)
    else:
        env_norm = env_sm.astype(np.float32)
    return env_norm

def build_envelope_from_channels_fast(channels, fs, freq_band=None, mode=None):
    if mode is None:
        mode = ENVELOPE_MODE
    if mode == 'hilbert':
        return envelope_via_hilbert(channels, fs, freq_band=freq_band, smooth_ms=SMOOTH_MS)
    elif mode == 'stft':
        return envelope_via_stft(channels, fs, freq_band=freq_band, smooth_ms=SMOOTH_MS)
    elif mode == 'cwt':
        return envelope_via_cwt(channels, fs, freq_band=freq_band, n_scales=N_SCALES, smooth_ms=SMOOTH_MS)
    else:
        # fallback
        return envelope_via_hilbert(channels, fs, freq_band=freq_band, smooth_ms=SMOOTH_MS)

# ----------------- MAIN FILE PROCESSING -----------------
def process_one_file(path, alpha=ALPHA, freq_band=FREQ_BAND, mode=ENVELOPE_MODE):
    data = safe_load_npz(path)
    if data is None:
        return False

    if 'times' not in data or 'channels' not in data:
        print(f"[ERR] missing keys in {path}")
        return False

    try:
        times = data['times'].astype(np.float64)
        channels = data['channels'].astype(np.float64)  # (C, N)
    except Exception as e:
        print(f"[ERR] cannot read times/channels from {path}: {e}")
        return False

    ch_names_raw = data.get('channel_names', None)
    if ch_names_raw is not None:
        try:
            ch_names = [c.decode('utf-8') if isinstance(c, bytes) else str(c) for c in ch_names_raw.tolist()]
        except Exception:
            ch_names = list(map(str, ch_names_raw))
    else:
        ch_names = [f'ch{i}' for i in range(channels.shape[0])]

    meta_raw = data.get('meta', None)
    meta = try_parse_meta(meta_raw)

    fs = float(meta.get('target_fs', 25600.0))
    N = channels.shape[1]

    # Build envelope (fast)
    try:
        env = build_envelope_from_channels_fast(channels, fs, freq_band=freq_band, mode=mode)
    except Exception as e:
        print(f"[WARN] envelope build failed for {path} with mode={mode}: {e}")
        # fallback to hilbert
        if mode != 'hilbert':
            try:
                env = build_envelope_from_channels_fast(channels, fs, freq_band=freq_band, mode='hilbert')
            except Exception as e2:
                print(f"[ERR] fallback hilbert also failed: {e2}")
                return False
        else:
            return False

    # Ensure envelope matches channel length
    if env.shape[0] != N:
        try:
            env = np.interp(np.linspace(0, 1, N), np.linspace(0, 1, env.shape[0]), env)
        except Exception:
            # fallback zeros
            env = np.zeros(N, dtype=np.float32)

    # Amplify multiplicatively
    factor = (1.0 + float(alpha) * env)  # shape (N,)
    channels_new = (channels * factor[None, :]).astype(DTYPE)

    # Update meta
    amplify_info = {
        'amplified': True,
        'amplify_params': {
            'alpha': float(alpha),
            'smooth_ms': float(SMOOTH_MS),
            'freq_band': tuple(freq_band) if freq_band is not None else None,
            'envelope_mode': mode,
            'n_scales': int(N_SCALES) if mode == 'cwt' else None
        },
        'amplify_time_utc': datetime.utcnow().isoformat() + 'Z'
    }
    meta_out = dict(meta) if isinstance(meta, dict) else {}
    meta_out['amplify'] = amplify_info

    # Save safely: write to <base>.tmp.npz then os.replace
    try:
        base, ext = os.path.splitext(path)
        if ext == '':
            ext = '.npz'
            base = path
        tmp_path = base + '.tmp' + ext  # e.g., file.preproc.tmp.npz

        # write tmp file
        np.savez_compressed(tmp_path,
                            times=times.astype(DTYPE),
                            channels=channels_new,
                            channel_names=np.array(ch_names),
                            meta=json.dumps(meta_out))

        # sanity: ensure tmp exists
        if not os.path.exists(tmp_path):
            print(f"[ERR] temp file {tmp_path} not found after save for {path}")
            return False

        # atomic replace
        try:
            os.replace(tmp_path, path)
        except Exception as e:
            # On Windows this can fail if file locked; try a retry
            print(f"[WARN] os.replace failed first attempt for {path}: {e}. Retrying after short sleep...")
            time.sleep(0.25)
            try:
                os.replace(tmp_path, path)
            except Exception as e2:
                # cleanup tmp and fail
                print(f"[ERR] os.replace failed again for {path}: {e2}")
                try:
                    if os.path.exists(tmp_path):
                        os.remove(tmp_path)
                except Exception:
                    pass
                return False

        return True

    except Exception as e:
        print(f"[ERR] failed to write {path}: {e}")
        traceback.print_exc()
        # cleanup if tmp exists
        try:
            if 'tmp_path' in locals() and os.path.exists(tmp_path):
                os.remove(tmp_path)
        except Exception:
            pass
        return False

# ----------------- MAIN -----------------
def main():
    total = 0
    succeeded = 0
    failed = 0
    start = time.time()

    print(f"[START] ENVELOPE_MODE={ENVELOPE_MODE}  ALPHA={ALPHA}  SMOOTH_MS={SMOOTH_MS}  FREQ_BAND={FREQ_BAND}")

    for folder in FOLDERS:
        folder_path = os.path.join(PROC_BASE, folder)
        if not os.path.isdir(folder_path):
            print(f"[WARN] folder missing: {folder_path}")
            continue
        files = sorted(glob.glob(os.path.join(folder_path, '*.preproc.npz')))
        print(f"[INFO] Found {len(files)} files in {folder_path}")
        for p in files:
            total += 1
            print(f"[{total}] Processing: {p}")
            ok = process_one_file(p, alpha=ALPHA, freq_band=FREQ_BAND, mode=ENVELOPE_MODE)
            if ok:
                succeeded += 1
                print(f"  -> OK (overwritten)")
            else:
                failed += 1
                print(f"  -> FAILED")

    elapsed = time.time() - start
    print(f"\nDone. processed={total}, succeeded={succeeded}, failed={failed}, elapsed={elapsed:.1f}s")

if __name__ == '__main__':
    main()


[START] ENVELOPE_MODE=hilbert  ALPHA=6.0  SMOOTH_MS=8.0  FREQ_BAND=(100.0, 5000.0)
[INFO] Found 5 files in processed_signals\acoustic
[1] Processing: processed_signals\acoustic\0Nm_BPFI_03.preproc.npz


  'amplify_time_utc': datetime.utcnow().isoformat() + 'Z'


[ERR] failed to write processed_signals\acoustic\0Nm_BPFI_03.preproc.npz: [WinError 2] The system cannot find the file specified: 'processed_signals\\acoustic\\0Nm_BPFI_03.preproc.npz.tmp' -> 'processed_signals\\acoustic\\0Nm_BPFI_03.preproc.npz'
  -> FAILED
[2] Processing: processed_signals\acoustic\0Nm_BPFI_10.preproc.npz
[ERR] failed to write processed_signals\acoustic\0Nm_BPFI_10.preproc.npz: [WinError 2] The system cannot find the file specified: 'processed_signals\\acoustic\\0Nm_BPFI_10.preproc.npz.tmp' -> 'processed_signals\\acoustic\\0Nm_BPFI_10.preproc.npz'
  -> FAILED
[3] Processing: processed_signals\acoustic\0Nm_BPFO_03.preproc.npz
[ERR] failed to write processed_signals\acoustic\0Nm_BPFO_03.preproc.npz: [WinError 2] The system cannot find the file specified: 'processed_signals\\acoustic\\0Nm_BPFO_03.preproc.npz.tmp' -> 'processed_signals\\acoustic\\0Nm_BPFO_03.preproc.npz'
  -> FAILED
[4] Processing: processed_signals\acoustic\0Nm_BPFO_10.preproc.npz
[ERR] failed to write 

KeyboardInterrupt: 