In [None]:
#!/usr/bin/env python3
# ================================================================
#   INSHEP bulk-feature extractor
#   ▪  main() processes the files sequentially; process_batch() offers
#       an optional multi-process path (at most 4 workers)
#   ▪  Streams results straight into CSV  (appends row-by-row)
#   ▪  Self-contained: expects the datasets/ directory one level above
#       the working directory, then run  python fast_extract.py
# ================================================================
import os, math, warnings, csv
from pathlib import Path
from multiprocessing import cpu_count
from concurrent.futures import ProcessPoolExecutor, as_completed

import numpy as np
import pandas as pd
from scipy import signal, stats
from skimage.feature import graycomatrix, graycoprops
from skimage.util import img_as_ubyte

# ──────────────────────────────────────────────────────────────
#  optional pseudo-Zernike moments  (needs  pip install mahotas)
# ──────────────────────────────────────────────────────────────
try:
    import mahotas as mh
    _HAS_MAHOTAS = True
except ImportError:
    warnings.warn("⚠️  mahotas not found – pseudo-Zernike moments will be 0.")
    _HAS_MAHOTAS = False

# ──────────────────────────────────────────────────────────────
#  GLOBAL CONSTANTS
# ──────────────────────────────────────────────────────────────
# Root folder searched recursively for *.dat files, and the output CSV path
# (both are relative to the current working directory).
DATASETS_ROOT   = Path("..") / "datasets"
CSV_PATH        = "INSHEP_features.csv"

# First character of a file's stem → activity label.
ACTIVITY_MAP = {
    "1": "walking",
    "2": "sitting_down",
    "3": "standing_up",
    "4": "pick_object",
    "5": "drink_water",
    "6": "fall",
}

TIME_WINDOW     = 200            # STFT segment length (samples)
OVERLAP_FRAC    = 0.95           # fraction of TIME_WINDOW shared by adjacent segments
PAD_FACTOR      = 4              # FFT length = PAD_FACTOR * TIME_WINDOW
BUTTER_N        = 4              # Butterworth filter order
BUTTER_CUT      = 0.0075         # high-pass cut-off (fraction of Nyquist)
TORSO_V_MAX     = 0.25           # ± m/s
DENSITY_THR_DB  = -3             # dB down from peak for masks

# Column order of the output CSV; every row dict must provide exactly these keys.
FIELDNAMES = [
    "file_id", "activity", "path",
    "mean_entropy", "mean_power", "variance", "stddev",
    "max_vel", "amp_density", "kurtosis", "zernike_moment",
    "periodicity", "mean_torso_power", "pos_neg_ratio",
    "doppler_offset", "main_lobe_width", "auto_correlation",
    "envelope_width", "limb_asymmetry", "limb_power",
    "limb_smoothness", "clean_kurtosis",
    "motion_duration", "doppler_peak_velocity", "doppler_symmetry_index",
    "cepstral_entropy", "range_bin_span", "doppler_bandwidth",
    "skew_val",
    "contrast", "dissimilarity", "homogeneity", "energy",
    "correlation", "ASM",
]

# ──────────────────────────────────────────────────────────────
#  LOW-LEVEL UTILITIES  (top-level → picklable)
# ──────────────────────────────────────────────────────────────

def iq_correction(raw_data):
    """
    Perform I/Q imbalance correction on complex radar data.

    Three steps are applied in order: the mean of each channel is removed,
    both channels are rescaled to the geometric mean of their RMS amplitudes,
    and Q is de-rotated by the phase error estimated from the normalised
    I/Q cross-correlation.

    Args:
        raw_data (np.ndarray): Complex IQ data

    Returns:
        np.ndarray: Corrected complex IQ data. When either channel has zero
        RMS after DC removal there is no amplitude/phase information to
        estimate, so only the DC-removed signal is returned (instead of NaNs).
    """
    # Split into I and Q components and remove each channel's DC offset
    i_data = np.real(raw_data)
    q_data = np.imag(raw_data)
    i_data = i_data - np.mean(i_data)
    q_data = q_data - np.mean(q_data)

    # Per-channel RMS amplitude
    i_amp = np.sqrt(np.mean(i_data**2))
    q_amp = np.sqrt(np.mean(q_data**2))
    if i_amp == 0 or q_amp == 0:
        # A flat channel would produce 0/0 below; skip the gain/phase steps.
        return i_data + 1j * q_data

    # Amplitude correction: bring both channels to the same RMS level
    amp_correction = np.sqrt(i_amp * q_amp)
    i_data = i_data * (amp_correction / i_amp)
    q_data = q_data * (amp_correction / q_amp)

    # Phase imbalance correction. The ratio is mathematically within [-1, 1];
    # clipping protects arcsin from floating-point overshoot.
    iq_corr = np.mean(i_data * q_data)
    phase_error = np.arcsin(np.clip(iq_corr / (i_amp * q_amp), -1.0, 1.0))
    q_data_corr = q_data * np.cos(phase_error) - i_data * np.sin(phase_error)

    return i_data + 1j * q_data_corr

def read_dat(path: Path):
    """Load one *.dat file and return fc [Hz], Tsweep [s], MTI-filtered range-time matrix.

    The file is read as text. Its first four non-blank lines are parsed as
    floats (fc, sweep time in ms, samples per sweep NTS, Bw); every following
    non-blank line is one complex sample written with an ``i`` imaginary
    suffix. Blank lines anywhere in the file are ignored.

    Returns:
        tuple: ``(fc, Tsweep, mti)`` where ``Tsweep`` is the sweep time
        converted to seconds and ``mti`` is the high-pass filtered upper half
        of the range FFT with its first row removed.

    Raises:
        ValueError: if the header is incomplete or a line cannot be parsed.
    """
    with open(path, "r") as f:
        # Drop blank lines so a trailing empty line does not break complex().
        lines = [text for text in (ln.strip() for ln in f) if text]

    if len(lines) < 4:
        raise ValueError(f"{path}: expected 4 header lines, found {len(lines)}")

    fc, Tsweep_ms, NTS, _bw = map(float, lines[:4])   # Bw is read but not used
    Tsweep = Tsweep_ms * 1e-3
    NTS    = int(NTS)
    raw    = np.array([complex(s.replace("i", "j")) for s in lines[4:]])

    # Apply I/Q correction to raw data
    raw_corrected = iq_correction(raw)

    # Keep only whole sweeps and lay them out column-wise: (NTS, n_chirps)
    n_chirps = raw_corrected.size // NTS
    time_mat = raw_corrected[: n_chirps * NTS].reshape((NTS, n_chirps), order="F")

    # FFT along each sweep, then keep the upper half of the shifted spectrum
    rng_fft  = np.fft.fftshift(np.fft.fft(time_mat, axis=0), axes=0)
    rng_half = rng_fft[NTS // 2 :, :]

    # High-pass Butterworth filter along the sweep-to-sweep axis
    b, a     = signal.butter(BUTTER_N, BUTTER_CUT, "high")
    mti      = signal.lfilter(b, a, rng_half, axis=1)

    return fc, Tsweep, mti[1:, :]     # skip leakage bin


def kalman_filter_1d(observed, dt, process_noise, measurement_noise):
    """Smooth a 1-D sequence with a two-state (value, rate) Kalman filter.

    The state is propagated with the transition ``[[1, dt], [0, 1]]`` and only
    the first state component is observed.

    Args:
        observed: Sequence of scalar measurements.
        dt: Time step between consecutive measurements.
        process_noise: Scale factor applied to the process-noise covariance.
        measurement_noise: Measurement noise standard deviation (it is
            squared to form the measurement variance).

    Returns:
        np.ndarray: Filtered first state component, one value per measurement.
        An empty input yields an empty array.
    """
    observed = np.asarray(observed)
    if observed.size == 0:
        # Nothing to initialise the state from.
        return np.array([])

    A = np.array([[1, dt], [0, 1]])
    C = np.array([[1, 0]])
    Q = process_noise * np.array([[dt**4/4, dt**3/2], [dt**3/2, dt**2]])
    R = measurement_noise ** 2
    I2 = np.eye(2)

    x = np.array([[observed[0]], [0]])  # initial state: first sample, zero rate
    P = np.eye(2)
    filtered = []

    for z in observed:
        # Predict
        x = A @ x
        P = A @ P @ A.T + Q

        # Update
        K = P @ C.T @ np.linalg.inv(C @ P @ C.T + R)
        x = x + K @ (z - C @ x)
        P = (I2 - K @ C) @ P

        filtered.append(x[0, 0])

    return np.array(filtered)
def remove_torso_from_spectrogram(Sxx, velocity_axis, torso_velocity_trace, bandwidth=0.4):
    """
    Re-centre each spectrogram column on the torso velocity and blank the
    band around zero velocity.

    Parameters:
    -----------
    Sxx : 2D array
        Input spectrogram (velocity bins along axis 0, frames along axis 1)
    velocity_axis : array
        Velocity value of each row of ``Sxx``
    torso_velocity_trace : array
        Estimated torso velocity, one entry per frame
    bandwidth : float
        Rows with ``|velocity| <= bandwidth`` are zeroed, i.e. this is the
        half-width of the suppression band in the units of ``velocity_axis``

    Returns:
    --------
    Sxx_centered : 2D array
        Copy of ``Sxx`` with every processed column circularly shifted so the
        torso bin sits on the zero-velocity bin, and the suppression band set
        to zero. Columns without a trace entry remain zero.
    """
    centered = np.zeros_like(Sxx)
    zero_bin = np.argmin(np.abs(velocity_axis - 0))
    # The suppression band does not depend on the frame, so build it once.
    in_band = np.abs(velocity_axis) <= bandwidth

    for frame, v_torso in enumerate(torso_velocity_trace):
        # Circular shift that moves the bin nearest the torso velocity onto zero
        torso_bin = np.argmin(np.abs(velocity_axis - v_torso))
        column = np.roll(Sxx[:, frame], shift=zero_bin - torso_bin)

        # Blank the (now centred) torso band; the column is kept shifted
        column[in_band] = 0
        centered[:, frame] = column

    return centered

def detect_envelope(Sxx_dB, velocity_axis, threshold_dB):
    """Return the per-frame lower and upper velocity envelope.

    For every column of ``Sxx_dB`` the first and last rows whose value is
    strictly greater than ``threshold_dB`` are located; the corresponding
    ``velocity_axis`` entries become the lower and upper envelope for that
    frame. Frames with no value above the threshold are left as NaN.

    Returns:
        tuple: ``(lower_envelope, upper_envelope)``, each of length n_frames.
    """
    _, frame_count = Sxx_dB.shape
    lower = np.full(frame_count, np.nan)
    upper = np.full(frame_count, np.nan)

    for frame, column in enumerate(Sxx_dB.T):
        hits = np.flatnonzero(column > threshold_dB)
        if hits.size == 0:
            continue
        lower[frame] = velocity_axis[hits[0]]
        upper[frame] = velocity_axis[hits[-1]]

    return lower, upper


def stft_mag(mti, prf):
    """Sum the STFT magnitudes of range bins 10-30 and return them with the Doppler axis.

    Each of rows 9..29 (0-based) of ``mti`` is turned into a two-sided complex
    spectrogram (Hann window, TIME_WINDOW samples per segment, OVERLAP_FRAC
    overlap, PAD_FACTOR-times zero-padded FFT) sampled at ``prf``. The
    spectrograms are shifted so zero frequency is centred, and their
    magnitudes are added together.

    Returns:
        tuple: ``(magnitude_sum, doppler)`` where ``doppler`` is the centred
        frequency axis matching axis 0 of ``magnitude_sum``.
    """
    seg_len  = TIME_WINDOW
    overlap  = int(round(seg_len * OVERLAP_FRAC))
    fft_len  = PAD_FACTOR * seg_len

    def _centred_magnitude(row):
        # |STFT| of one range bin with the frequency axis fftshifted
        _, _, spec = signal.spectrogram(
            row,
            fs=prf,
            window="hann",
            nperseg=seg_len,
            noverlap=overlap,
            nfft=fft_len,
            mode="complex",
            return_onesided=False,
        )
        return np.abs(np.fft.fftshift(spec, axes=0))

    # bins 10-30  (0-based 9-29), accumulated in ascending order
    total = _centred_magnitude(mti[9, :])
    for rng_bin in range(10, 30):
        total = total + _centred_magnitude(mti[rng_bin, :])

    freq_axis = np.fft.fftshift(np.fft.fftfreq(fft_len, d=1 / prf))
    return total, freq_axis


def binary_mask(db_img, thresh_db):
    """Boolean mask of the entries lying within ``thresh_db`` of the maximum.

    ``thresh_db`` is added to the maximum of ``db_img`` (so a negative value
    means "this many dB below the peak"); entries at or above that level are
    True.
    """
    cutoff = db_img.max() + thresh_db
    return db_img >= cutoff


def pseudo_zernike(img, radius=20, degree=4):
    """Mean absolute Zernike moment of ``img`` (0.0 when mahotas is unavailable).

    The image is zero-padded at the bottom/right to a square, shifted by the
    original minimum and scaled by the original peak-to-peak range, then
    passed to ``mahotas.features.zernike_moments``.

    Args:
        img: 2-D array.
        radius: Radius forwarded to mahotas.
        degree: Maximum moment degree forwarded to mahotas.

    Returns:
        float: Mean of the absolute moment values, or 0.0 without mahotas.
    """
    if not _HAS_MAHOTAS:
        return 0.0
    size   = max(img.shape)
    padder = [(0, size - img.shape[0]), (0, size - img.shape[1])]
    # np.ptp() instead of the ndarray.ptp() method, which NumPy 2.0 removed.
    # The epsilon keeps a constant image from dividing by zero.
    img_n  = (np.pad(img, padder) - img.min()) / (np.ptp(img) + 1e-12)
    return float(np.mean(np.abs(mh.features.zernike_moments(img_n, radius, degree=degree))))

def extract_glcm_features(spectrogram_image):
    """Compute grey-level co-occurrence (GLCM) texture features of an image.

    The absolute value of the input is min-max scaled to [0, 1], converted to
    8-bit, and a symmetric, normalised GLCM is built for distances 1 and 3 at
    angles 0, 45, 90 and 135 degrees. Each property is averaged over all
    distance/angle combinations.

    Args:
        spectrogram_image: 2-D array (real or complex).

    Returns:
        dict: keys ``contrast``, ``dissimilarity``, ``homogeneity``,
        ``energy``, ``correlation`` and ``ASM``.
    """
    magnitude = np.abs(spectrogram_image)
    lowest = magnitude.min()
    span = magnitude.max() - lowest
    if span == 0:
        # A constant image has no dynamic range; map it to a single grey
        # level instead of dividing by zero.
        scaled = np.zeros(magnitude.shape, dtype=float)
    else:
        scaled = (magnitude - lowest) / span
    gray = img_as_ubyte(scaled)

    glcm = graycomatrix(
        gray,
        distances=[1, 3],
        angles=[0, np.pi/4, np.pi/2, 3*np.pi/4],
        levels=256,
        symmetric=True,
        normed=True
    )

    property_names = (
        'contrast', 'dissimilarity', 'homogeneity',
        'energy', 'correlation', 'ASM',
    )
    return {name: graycoprops(glcm, name).mean() for name in property_names}


def extract_features(mti, fc, Tsweep):
    """Compute the scalar feature set for one recording.

    Args:
        mti: 2-D complex array; rows are passed to ``stft_mag`` and summed
            along axis 1 for the range-bin power.
        fc: Carrier frequency, used to convert Doppler frequency to velocity
            via ``v = f * 3e8 / (2 * fc)``.
        Tsweep: Sweep duration; the pulse repetition frequency is ``1 / Tsweep``.

    Returns:
        dict: One entry per feature column of ``FIELDNAMES`` except the
        ``file_id`` / ``activity`` / ``path`` metadata columns.
    """
    prf                = 1.0 / Tsweep
    S, doppler         = stft_mag(mti, prf)
    S2                 = S**2                      # power spectrogram, computed once
    flat               = S2.ravel()

    # ── global statistics of the power spectrogram ────────────────────────
    p                  = flat / (flat.sum() + 1e-12)
    mean_entropy       = float(-(p * np.log(p + 1e-12)).sum())
    mean_power         = float(flat.mean())
    variance           = float(flat.var())
    stddev             = float(math.sqrt(variance))

    v_axis             = doppler * 3e8 / (2 * fc)
    # velocity of the row holding the single strongest spectrogram cell
    vmax               = float(v_axis[np.unravel_index(S.argmax(), S.shape)[0]])
    amp_density        = binary_mask(20 * np.log10(S + 1e-12), DENSITY_THR_DB).mean()
    kurtosis_val       = float(stats.kurtosis(flat, fisher=False))
    z_moment           = pseudo_zernike(S)

    # ── per-frame power and its autocorrelation (non-negative lags only) ──
    pw_sweep           = S2.sum(axis=0)
    acf                = signal.correlate(pw_sweep, pw_sweep, mode="full")[len(pw_sweep)-1 :]
    periodicity        = float(acf[1:].max() / (acf[0] + 1e-12))
    auto_correlation   = float(acf[1] / (acf[0] + 1e-12))  # First lag autocorrelation

    # ── torso band and sign balance ───────────────────────────────────────
    torso_mask         = np.abs(v_axis) <= TORSO_V_MAX
    mean_torso_power   = float(S2[torso_mask, :].mean())

    pos_mask           = v_axis > 0
    neg_mask           = v_axis < 0
    pos_power          = S2[pos_mask, :].sum()
    neg_power          = S2[neg_mask, :].sum()
    pos_neg_ratio      = float(pos_power / (neg_power + 1e-12))

    # power-weighted mean velocity
    weights            = S2.sum(axis=1)
    doppler_offset     = float((v_axis * weights).sum() / (weights.sum() + 1e-12))

    # ── time-averaged Doppler spectrum ────────────────────────────────────
    avg_doppler_spectrum = S2.mean(axis=1)

    row_db             = 20 * np.log10(avg_doppler_spectrum + 1e-12)
    mask               = binary_mask(row_db, DENSITY_THR_DB)
    if mask.any():
        idx            = np.where(mask)[0]
        main_lobe_width = float(v_axis[idx.max()] - v_axis[idx.min()])
    else:
        main_lobe_width = 0.0

    # ── torso removal and envelope features ───────────────────────────────
    # Per-frame peak velocity, smoothed by the Kalman filter
    raw_idx = np.argmax(S, axis=0)
    raw_torso_v = v_axis[raw_idx]
    kalman_torso_v = kalman_filter_1d(raw_torso_v, dt=(1 / prf), process_noise=10000.0, measurement_noise=0.1)

    S2_torso_removed = remove_torso_from_spectrogram(S2, v_axis, kalman_torso_v)
    S2_dB_torso_removed = 20 * np.log10(S2_torso_removed + 1e-12)

    # Envelope threshold: 20 dB above the smallest value of the dB image
    threshold = S2_dB_torso_removed.min() + 20
    lower_env, upper_env = detect_envelope(S2_dB_torso_removed, v_axis, threshold)

    env_span       = upper_env - lower_env
    env_width      = np.nanmean(env_span)
    # NOTE(review): same quantity as env_width but with a NaN-propagating
    # mean, so it is NaN whenever a frame has no envelope — confirm intent.
    env_asymmetry  = float(env_span.mean())
    smoothness     = float(np.std(np.gradient(env_span)))
    limb_power     = float(S2_torso_removed[(v_axis[:, None] >= lower_env) & (v_axis[:, None] <= upper_env)].sum())
    kurtosis_clean = float(stats.kurtosis(S2_torso_removed.ravel(), fisher=False))

    # Motion Duration: Duration of signal above a threshold in the time domain.
    nperseg = TIME_WINDOW
    noverlap = int(round(nperseg * OVERLAP_FRAC))
    time_step = (nperseg - noverlap) / prf
    motion_thresh = pw_sweep.max() * (10**(DENSITY_THR_DB / 10))
    motion_duration = float((pw_sweep >= motion_thresh).sum() * time_step)

    # Doppler Peak Velocity: Velocity at the peak of the time-averaged Doppler spectrum.
    doppler_peak_velocity = float(v_axis[avg_doppler_spectrum.argmax()])

    # Doppler Symmetry Index: Normalized difference between positive and negative Doppler power.
    pos_power_avg = avg_doppler_spectrum[pos_mask].sum()
    neg_power_avg = avg_doppler_spectrum[neg_mask].sum()
    doppler_symmetry_index = float((pos_power_avg - neg_power_avg) / (pos_power_avg + neg_power_avg + 1e-12))

    # Cepstral Entropy: Entropy of the power cepstrum of the average Doppler spectrum.
    log_spec = np.log(avg_doppler_spectrum + 1e-12)
    cepstrum = np.abs(np.fft.irfft(log_spec))**2
    p_cep = cepstrum / (cepstrum.sum() + 1e-12)
    cepstral_entropy = float(-(p_cep * np.log(p_cep + 1e-12)).sum())

    # Range Bin Span: Spread of the signal across range bins.
    range_power = np.sum(np.abs(mti)**2, axis=1)
    range_thresh = range_power.max() * (10**(DENSITY_THR_DB / 10))
    active_bins_mask = range_power >= range_thresh
    if active_bins_mask.any():
        active_indices = np.where(active_bins_mask)[0]
        range_bin_span = float(active_indices.max() - active_indices.min())
    else:
        range_bin_span = 0.0

    # Doppler Bandwidth: Power-weighted standard deviation of the Doppler velocity.
    doppler_variance = (weights * (v_axis - doppler_offset)**2).sum() / (weights.sum() + 1e-12)
    doppler_bandwidth = float(np.sqrt(doppler_variance))

    # Skewness of the Doppler spectrum
    skew_val = float(stats.skew(avg_doppler_spectrum))

    # Texture features: build the GLCM once and reuse it for all six values.
    glcm_feats = extract_glcm_features(S2)

    return dict(
        mean_entropy       = mean_entropy,
        mean_power         = mean_power,
        variance           = variance,
        stddev             = stddev,
        max_vel            = vmax,
        amp_density        = amp_density,
        kurtosis           = kurtosis_val,
        zernike_moment     = z_moment,
        periodicity        = periodicity,
        mean_torso_power   = mean_torso_power,
        pos_neg_ratio      = pos_neg_ratio,
        doppler_offset     = doppler_offset,
        main_lobe_width    = main_lobe_width,
        auto_correlation   = auto_correlation,
        envelope_width     = env_width,
        limb_asymmetry     = env_asymmetry,
        limb_power         = limb_power,
        limb_smoothness    = smoothness,
        clean_kurtosis     = kurtosis_clean,
        motion_duration        = motion_duration,
        doppler_peak_velocity  = doppler_peak_velocity,
        doppler_symmetry_index = doppler_symmetry_index,
        cepstral_entropy       = cepstral_entropy,
        range_bin_span         = range_bin_span,
        doppler_bandwidth      = doppler_bandwidth,
        skew_val           = skew_val,
        contrast           = glcm_feats['contrast'],
        dissimilarity      = glcm_feats['dissimilarity'],
        homogeneity        = glcm_feats['homogeneity'],
        energy             = glcm_feats['energy'],
        correlation        = glcm_feats['correlation'],
        ASM                = glcm_feats['ASM'],
    )


def process_one(path: Path):
    """Load one file, extract its features and attach the metadata columns.

    The returned dict holds the feature values plus ``file_id`` (the file
    stem), ``activity`` (looked up from the stem's first character, falling
    back to "unknown") and ``path``. Any exception raised while reading or
    extracting propagates to the caller.
    """
    fc, Tsweep, mti = read_dat(path)
    row = extract_features(mti, fc, Tsweep)

    stem = path.stem
    label = ACTIVITY_MAP.get(stem[0], "unknown")

    row["file_id"] = stem
    row["activity"] = label
    row["path"] = str(path)
    return row

def process_batch(files, writer, batch_start, total_files):
    """Run ``process_one`` over ``files`` in a process pool and write each row.

    Rows are written through ``writer`` in completion order. A failure for
    one file (while processing it or while writing its row) is reported with
    its traceback and does not stop the remaining files.

    Args:
        files: Iterable of paths to process.
        writer: Object with a ``writerow(dict)`` method.
        batch_start: Offset added to the in-batch counter for progress output.
        total_files: Total shown as the denominator in progress output.
    """
    import traceback

    worker_count = min(4, cpu_count())  # Reduced from 12 to 4

    with ProcessPoolExecutor(max_workers=worker_count) as pool:
        pending = {}
        for path in files:
            pending[pool.submit(process_one, path)] = path

        for done, future in enumerate(as_completed(pending), 1):
            source = pending[future]
            position = batch_start + done
            try:
                row = future.result()
                writer.writerow(row)
                print(f"✓ [{position:>4}/{total_files}] {row['file_id']}")
            except Exception as e:
                print(f"✗ [{position:>4}/{total_files}] {source.name}")
                print(f"Error: {e.__class__.__name__}: {e}")
                print("Traceback:")
                traceback.print_exc()

# ──────────────────────────────────────────────────────────────
#  MAIN — run workers & stream to CSV
# ──────────────────────────────────────────────────────────────
def main():
    """Extract features for every *.dat file under DATASETS_ROOT into CSV_PATH.

    Files are processed sequentially in sorted order. Rows are appended to
    the CSV and flushed one by one, so partial results survive an
    interruption. The header row is written when the CSV does not exist yet
    or is empty. A failure on one file is reported with its traceback and
    the run continues with the next file; Ctrl-C stops the run cleanly.
    """
    import traceback

    all_files = sorted(DATASETS_ROOT.rglob("*.dat"))
    n_files = len(all_files)
    if not n_files:
        print(f"No .dat files found under {DATASETS_ROOT.resolve()}")
        return

    # prepare CSV (append if exists, else create with header); an existing
    # but empty file still needs its header
    needs_header = not os.path.exists(CSV_PATH) or os.path.getsize(CSV_PATH) == 0

    # The context manager closes the file even if writing the header fails.
    with open(CSV_PATH, "a", newline="") as csv_file:
        writer = csv.DictWriter(csv_file, FIELDNAMES)
        if needs_header:
            writer.writeheader()

        try:
            # Process one file at a time
            for i, file_path in enumerate(all_files, 1):
                try:
                    print(f"\nProcessing file {i}/{n_files}: {file_path.name}")
                    # Add detailed logging
                    print("Reading DAT file...")
                    fc, Tsweep, mti = read_dat(file_path)

                    print("Extracting features...")
                    feats = extract_features(mti, fc, Tsweep)

                    print("Adding metadata...")
                    fid = file_path.stem
                    feats.update(
                        file_id=fid,
                        activity=ACTIVITY_MAP.get(fid[0], "unknown"),
                        path=str(file_path)
                    )

                    print("Writing to CSV...")
                    writer.writerow(feats)
                    csv_file.flush()   # keep finished rows on disk
                    print(f"✓ [{i:>4}/{n_files}] {fid}")

                except Exception as e:
                    # Best-effort batch run: report and move on to the next file
                    print(f"✗ [{i:>4}/{n_files}] {file_path.name}")
                    print(f"Error: {e.__class__.__name__}: {e}")
                    traceback.print_exc()

        except KeyboardInterrupt:
            print("\n⚠️  Processing interrupted by user")

    print(f"\n✅  Features saved to {CSV_PATH}")

# ──────────────────────────────────────────────────────────────
if __name__ == "__main__":
    main()


Processing file 1/1754: 1P36A01R01.dat
Reading DAT file...
Extracting features...
Adding metadata...
Writing to CSV...
✓ [   1/1754] 1P36A01R01

Processing file 2/1754: 1P36A01R02.dat
Reading DAT file...
Extracting features...
Adding metadata...
Writing to CSV...
✓ [   2/1754] 1P36A01R02

Processing file 3/1754: 1P36A01R03.dat
Reading DAT file...
Extracting features...
Adding metadata...
Writing to CSV...
✓ [   3/1754] 1P36A01R03

Processing file 4/1754: 1P37A01R01.dat
Reading DAT file...
Extracting features...
Adding metadata...
Writing to CSV...
✓ [   4/1754] 1P37A01R01

Processing file 5/1754: 1P37A01R02.dat
Reading DAT file...
Extracting features...
Adding metadata...
Writing to CSV...
✓ [   5/1754] 1P37A01R02

Processing file 6/1754: 1P37A01R03.dat
Reading DAT file...
Extracting features...
Adding metadata...
Writing to CSV...
✓ [   6/1754] 1P37A01R03

Processing file 7/1754: 1P38A01R01.dat
Reading DAT file...
Extracting features...
Adding metadata...
Writing to CSV...
✓ [   7/17