# In [None]:
import json, math, time, traceback
import warnings
from pathlib import Path
import numpy as np
import soundfile as sf
import scipy.signal as sps
from scipy.signal import resample_poly
from tqdm import tqdm
import pyroomacoustics as pra


#  PARAMETERS (3-D simulation variant)
# Input audio file read by main(); mono input is simulated into an array recording.
INPUT_WAV = "test.wav"
# Root output directory; ensure_dirs() creates its sub-directory layout.
OUT_DIR = Path("bilinear_rls_output_mclp_full_3d")
# Target sample rate in Hz (default sr_target of read_and_resample, default sr of simulate_room_from_mono).
SR = 16000
# STFT frame length (samples) and hop size (samples): 50% overlap.
N_FFT = 1024
HOP = N_FFT // 2

# Simulation defaults (default arguments of simulate_room_from_mono; also passed explicitly by main)
SIM_M = 6                 # number of microphones on the circular array
SIM_ROOM_DIM = (6.0, 7.0, 3.0)  # shoebox room size (x, y, z); presumably metres (pyroomacoustics units)
SIM_RT60 = 0.40           # passed as rt60 to simulate_room_from_mono; presumably seconds
SIM_SRC_DISTANCE = 0.5    # distance from array centre to the source, same units as SIM_ROOM_DIM
SIM_SRC_AZ = 30.0         # source azimuth in degrees (angle in the x-y plane from +x)
SIM_SRC_EL = 15.0         # source elevation in degrees above the horizontal plane of the array
SIM_SNR_DB = 10.0         # SNR in dB of the additive white Gaussian noise relative to the mean array power
SIM_MIC_RADIUS = 0.03     # circular array radius, same units as SIM_ROOM_DIM
SIM_MAX_ORDER = 12        # max reflection order handed to pra.ShoeBox (may be lowered by safe_compute_rir)

# MCLP / RLS / post-filter parameters, passed to pipeline_mclp_rls by main().
# NOTE(review): the stages that consume these are not visible in this section.
MCLP_L = 14
MCLP_DELAY = 2
MCLP_ITER = 3
MCLP_REG = 1e-6
RLS_L = 14
RLS_DELTA = 2
RLS_ALPHA = 0.9999
RLS_EPS_REG = 0.1
CLIP_MAG = 1.0            # NOTE(review): not referenced by the visible code
POST_MASK_BETA = 0.6
EPS = 1e-12               # NOTE(review): not referenced by the visible code


# METRIC FUNCTIONS (Unchanged, included for completeness)
def calculate_pesq_stoi(reference, estimate, sr):
    """Compute PESQ and STOI between a reference and an estimated signal.

    Both signals are truncated to their common length. Each metric is
    best-effort: if its optional package (``pesq`` / ``pystoi``) is not
    installed, that score is NaN silently; if the package is present but the
    computation fails, a ``RuntimeWarning`` is emitted and the score is NaN.

    PESQ uses wide-band mode at 16 kHz and narrow-band mode at 8 kHz (the only
    rates PESQ supports); any other rate makes PESQ fail and return NaN.

    Args:
        reference: clean reference signal (1-D sequence / array).
        estimate: processed signal to score.
        sr: sample rate in Hz shared by both signals.

    Returns:
        ``(pesq_score, stoi_score)`` as floats, NaN where unavailable.
    """
    L = min(len(reference), len(estimate))
    pesq_score = float('nan')
    stoi_score = float('nan')
    if L == 0:
        # Nothing to compare; both metrics are undefined.
        return pesq_score, stoi_score

    ref_trimmed = reference[:L]
    est_trimmed = estimate[:L]

    try:
        from pesq import pesq as _pesq_fn  # optional dependency
    except ImportError:
        _pesq_fn = None
    if _pesq_fn is not None:
        # 'wb' is only valid at 16 kHz; 'nb' is the PESQ mode for 8 kHz.
        mode = 'nb' if int(sr) == 8000 else 'wb'
        try:
            pesq_score = float(_pesq_fn(sr, ref_trimmed, est_trimmed, mode))
        except Exception as err:
            warnings.warn(f"PESQ computation failed: {err}", RuntimeWarning)

    try:
        from pystoi.stoi import stoi as _stoi_fn  # optional dependency
    except ImportError:
        _stoi_fn = None
    if _stoi_fn is not None:
        try:
            stoi_score = float(_stoi_fn(ref_trimmed, est_trimmed, sr))
        except Exception as err:
            warnings.warn(f"STOI computation failed: {err}", RuntimeWarning)

    return pesq_score, stoi_score

def si_sdr(reference, estimate, eps=1e-8):
    """Scale-invariant signal-to-distortion ratio (SI-SDR) in dB.

    Both signals are truncated to their common length and made zero-mean.
    The reference is then optimally scaled onto the estimate (the projection),
    and the ratio of projection energy to residual energy is returned in dB.

    ``eps`` regularises the projection coefficient and both energies. A
    perfect or silent estimate therefore gives a finite value instead of
    ``inf``/``nan`` from a 0-denominator.

    Args:
        reference: clean reference signal (array-like).
        estimate: estimated signal (array-like).
        eps: small regulariser.

    Returns:
        SI-SDR in dB as a float, or NaN if either signal is empty.
    """
    L = min(len(reference), len(estimate))
    if L == 0:
        return float('nan')
    ref = np.asarray(reference[:L], dtype=np.float64)
    est = np.asarray(estimate[:L], dtype=np.float64)
    ref = ref - ref.mean()
    est = est - est.mean()
    # Optimal scale of the reference onto the estimate.
    alpha = np.sum(ref * est) / (np.sum(ref ** 2) + eps)
    proj = alpha * ref
    noise = est - proj
    # eps on both energies keeps the ratio finite when either is zero.
    ratio = (np.sum(proj ** 2) + eps) / (np.sum(noise ** 2) + eps)
    return float(10.0 * np.log10(ratio))

# PLOTTING AND AUDIO RENDERING HELPERS (Unchanged, simplified I/O)
def save_spec(x, name, sr, out_dir, vmin=-90, vmax=-40):
    """Save a log-magnitude spectrogram of ``x`` to ``<out_dir>/results/spec_<name>.png``.

    The STFT uses the module settings: a periodic Hann window of ``N_FFT``
    samples and a hop of ``HOP`` samples. The image is plotted in dB with the
    colour range ``[vmin, vmax]``.

    NOTE(review): uses ``plt`` (matplotlib.pyplot), which is not imported in
    this section. Confirm it is imported elsewhere (e.g. an earlier notebook
    cell).

    Args:
        x: 1-D signal.
        name: label used for the plot title and the file name.
        sr: sample rate in Hz.
        out_dir: output root (str or Path); ``results/`` is created if missing.
        vmin, vmax: colour-scale limits in dB.

    Returns:
        Path of the written PNG.
    """
    results_dir = Path(out_dir) / "results"
    results_dir.mkdir(parents=True, exist_ok=True)
    out = results_dir / f"spec_{name}.png"

    win = sps.windows.hann(N_FFT, sym=False)
    # noverlap is the number of overlapping samples (N_FFT - HOP), matching the
    # STFT in pipeline_mclp_rls for any choice of HOP.
    _, _, S = sps.stft(x, fs=sr, window=win, nperseg=N_FFT, noverlap=N_FFT - HOP,
                       boundary='zeros', padded=True)
    Sdb = 20 * np.log10(np.abs(S) + 1e-12)

    fig = plt.figure(figsize=(10, 4))
    try:
        plt.imshow(np.flipud(Sdb), aspect='auto',
                   extent=[0, Sdb.shape[1] * HOP / sr, 0, sr / 2],
                   vmin=vmin, vmax=vmax)
        plt.title(name)
        plt.ylabel('Frequency (Hz)')
        plt.xlabel('Time (s)')
        plt.colorbar(label='dB')
        plt.tight_layout()
        plt.savefig(out)
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)
    return out

def render_audio(path, name, sr):
    """Print a header for ``name`` and the path where its audio file was saved.

    Despite the name, nothing is played or rendered. ``sr`` is accepted for
    interface compatibility but is not used.
    """
    # The speaker emoji was previously stored mis-encoded (UTF-8 bytes read as cp1252).
    print(f"\n--- 🔊 {name} ---")
    print(f"File saved to: {path}")

def ensure_dirs(root):
    """Create the standard output sub-directories under ``root``.

    Creates ``array_signals``, ``clean_speech``, ``ori_files``, ``doa_labels``
    and ``results``, along with any missing parents. Existing directories are
    left untouched, so repeated calls are harmless.
    """
    base = Path(root)
    for subdir in ("array_signals", "clean_speech", "ori_files", "doa_labels", "results"):
        target = base / subdir
        target.mkdir(parents=True, exist_ok=True)

def circular_array_positions(M=6, radius=0.03, center=(2.5,2.0,1.2), yaw_deg=0.0):
    """Return the 3-D coordinates of a horizontal uniform circular array.

    The ``M`` microphones are spaced evenly on a circle of ``radius`` around
    ``center``. Mic 0 starts on the +x axis and the ring is rotated by
    ``yaw_deg`` degrees about the vertical axis. All mics share the z
    coordinate of ``center``.

    Returns:
        ndarray of shape ``(M, 3)`` with one ``(x, y, z)`` row per microphone.
    """
    cx, cy, cz = center
    angles = np.linspace(0, 2*math.pi, M, endpoint=False)
    yaw = math.radians(yaw_deg)
    cos_y, sin_y = math.cos(yaw), math.sin(yaw)
    rotation = np.array([[cos_y, -sin_y], [sin_y, cos_y]])
    ring = radius * np.vstack([np.cos(angles), np.sin(angles)])  # (2, M) unrotated x/y offsets
    xy = rotation @ ring
    return np.column_stack([cx + xy[0], cy + xy[1], np.full_like(angles, cz)])

def read_and_resample(path, sr_target=SR):
    """Read an audio file, resample it to ``sr_target`` and pad it to at least ``N_FFT`` samples.

    Args:
        path: audio file path (str or Path).
        sr_target: desired sample rate in Hz. ``None`` keeps the file's native rate.

    Returns:
        ``(data, sr)``. ``data`` is 1-D for mono files and ``(n_samples, n_channels)``
        otherwise. ``sr`` is the rate of the returned data, which equals
        ``sr_target`` whenever resampling took place. Signals shorter than
        ``N_FFT`` samples are zero-padded at the end so an STFT frame fits.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
    """
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"{path} not found.")
    data, sr = sf.read(str(path))
    data = np.asarray(data)

    if sr_target is not None and int(sr) != int(sr_target):
        # Rational polyphase resampling along time (axis 0). Reducing the
        # up/down ratio by its GCD keeps the anti-aliasing filter short.
        g = math.gcd(int(sr), int(sr_target))
        data = resample_poly(data, int(sr_target) // g, int(sr) // g, axis=0)
        sr = int(sr_target)

    # Ensure minimal length for STFT (works for mono and multichannel).
    if data.shape[0] < N_FFT:
        pad = N_FFT - data.shape[0]
        pad_width = [(0, pad)] + [(0, 0)] * (data.ndim - 1)
        data = np.pad(data, pad_width, mode='constant')
    return data, sr

def safe_compute_rir(room, max_retries=3):
    """Call ``room.compute_rir()``, retrying up to ``max_retries`` times.

    After each failed attempt, the room's ``max_order`` attribute (if it has
    one) is lowered by 2, never below 0, before retrying. Errors raised while
    lowering it are ignored.

    Returns:
        True as soon as one call succeeds, False if every attempt raised.
    """
    for _ in range(max_retries):
        try:
            room.compute_rir()
        except Exception:
            try:
                has_order = hasattr(room, "max_order")
                if has_order:
                    reduced = room.max_order - 2
                    room.max_order = reduced if reduced > 0 else 0
            except Exception:
                pass
        else:
            return True
    return False

#  Simulation function modified for 3D
def _sim_source_position(center, distance, az_deg, el_deg):
    """Cartesian source position at ``distance`` from ``center`` in direction (az, el) degrees."""
    az = math.radians(az_deg)
    el = math.radians(el_deg)
    return [float(center[0] + distance * math.cos(az) * math.cos(el)),
            float(center[1] + distance * math.sin(az) * math.cos(el)),
            float(center[2] + distance * math.sin(el))]


def _sim_absorption(rt60, room_dim, default=0.5):
    """Energy absorption that yields ``rt60`` via pra.inverse_sabine, or ``default`` if that fails."""
    try:
        e_absorption, _ = pra.inverse_sabine(rt60, list(room_dim))
    except Exception as err:  # e.g. ValueError when rt60 is unreachable for this room
        warnings.warn(f"inverse_sabine failed for rt60={rt60} ({err}); "
                      f"using absorption={default}", RuntimeWarning)
        return float(default)
    return float(e_absorption)


def _sim_direct_path(clean_signal, micpos, src_pos, sr, c=343.0):
    """Direct-path only rendering (integer delay + 1/distance gain) used as a fallback."""
    T = len(clean_signal)
    Y = np.zeros((T, micpos.shape[0]), dtype=np.float32)
    src = np.asarray(src_pos, dtype=np.float64)
    for m, pos in enumerate(micpos):
        dist = np.linalg.norm(src - pos)
        delay = int(round((dist / c) * sr))
        att = 1.0 / max(0.5, dist)
        if delay < T:
            Y[delay:, m] = clean_signal[:T - delay] * att
    return Y


def simulate_room_from_mono(out_dir, clean_signal, sr=SR,
                            room_dim=SIM_ROOM_DIM, rt60=SIM_RT60,
                            mic_radius=SIM_MIC_RADIUS, src_distance=SIM_SRC_DISTANCE,
                            src_az_deg=SIM_SRC_AZ, src_el_deg=SIM_SRC_EL,
                            snr_db=SIM_SNR_DB, max_order=SIM_MAX_ORDER, M=SIM_M,
                            seed=None):
    """Simulate a noisy M-channel circular-array recording of a mono signal.

    Steps:
    1. Place the array horizontally at the room centre, at height 1.2.
    2. Put the source ``src_distance`` away in direction
       (``src_az_deg``, ``src_el_deg``).
    3. Render the recording with pyroomacoustics. Wall absorption is derived
       from ``rt60`` with ``pra.inverse_sabine``, falling back to 0.5. If the
       simulation fails, a direct-path-only rendering is used and a warning
       is emitted.
    4. Add white Gaussian noise at ``snr_db`` relative to the mean array power.

    Files written under ``out_dir``:
    - clean_speech/clean_sim.wav
    - array_signals/array_sim.wav
    - ori_files/ori_sim.csv (mic positions)
    - doa_labels/doa_sim.csv (az, el)
    - metadata.json

    Args:
        out_dir: output root directory.
        clean_signal: 1-D clean source signal.
        sr: sample rate in Hz.
        room_dim: (x, y, z) room size.
        rt60: target reverberation time used to choose absorption.
        mic_radius: array radius.
        src_distance: source distance from the array centre.
        src_az_deg, src_el_deg: source direction in degrees.
        snr_db: additive-noise SNR in dB.
        max_order: reflection order passed to pra.ShoeBox.
        M: number of microphones.
        seed: optional seed for reproducible noise. ``None`` uses the global
            ``np.random`` state.

    Returns:
        The metadata dict that was written to ``metadata.json``.
    """
    out_dir = Path(out_dir)
    ensure_dirs(out_dir)
    arr_dir = out_dir / "array_signals"
    clean_dir = out_dir / "clean_speech"
    ori_dir = out_dir / "ori_files"
    doa_dir = out_dir / "doa_labels"
    clean_name = "clean_sim.wav"
    sf.write(str(clean_dir / clean_name), clean_signal.astype('float32'), sr)

    rx, ry, _ = room_dim
    center = (rx / 2.0, ry / 2.0, 1.2)
    micpos = circular_array_positions(M=M, radius=mic_radius, center=center)
    src_pos = _sim_source_position(center, src_distance, src_az_deg, src_el_deg)

    np.savetxt(str(ori_dir / "ori_sim.csv"), micpos, delimiter=",", fmt="%.6f")
    # Azimuth and elevation, one value per line.
    np.savetxt(str(doa_dir / "doa_sim.csv"), np.array([src_az_deg, src_el_deg]), delimiter=",", fmt="%.6f")

    absorption = _sim_absorption(rt60, room_dim)

    try:
        room = pra.ShoeBox(list(room_dim), fs=sr, materials=pra.Material(absorption), max_order=max_order)
        room.add_microphone_array(pra.MicrophoneArray(micpos.T, sr))
        room.add_source(src_pos, signal=clean_signal)
        if not safe_compute_rir(room, max_retries=3):
            raise RuntimeError("compute_rir failed -> fallback")
        room.simulate()
        Y = np.asarray(room.mic_array.signals, dtype=np.float32).T  # (n_samples, M)
    except Exception as err:
        warnings.warn(f"pyroomacoustics simulation failed ({err}); "
                      "using direct-path fallback", RuntimeWarning)
        Y = _sim_direct_path(clean_signal, micpos, src_pos, sr)

    # Additive white Gaussian noise at the requested SNR.
    sig_power = np.mean(Y ** 2)
    noise_power = sig_power / max(10.0 ** (snr_db / 10.0), 1e-9)
    if seed is None:
        white = np.random.randn(*Y.shape)
    else:
        white = np.random.default_rng(seed).standard_normal(Y.shape)
    noise = white.astype('float32') * np.sqrt(noise_power)
    Y_noisy = Y + noise
    # NOTE(review): no clipping/normalisation is applied before writing.

    arr_name = "array_sim.wav"
    sf.write(str(arr_dir / arr_name), Y_noisy.astype('float32'), sr)

    meta = {"clean_wav": str((clean_dir / clean_name).resolve()),
            "array_wav": str((arr_dir / arr_name).resolve()),
            "ori_csv": str((ori_dir / "ori_sim.csv").resolve()),
            "doa_csv": str((doa_dir / "doa_sim.csv").resolve()),
            "sr": sr, "src_pos": src_pos, "mic_positions": micpos.tolist(),
            "src_az": src_az_deg, "src_el": src_el_deg,
            "absorption": absorption}
    with open(out_dir / "metadata.json", "w") as f:
        json.dump(meta, f, indent=2)
    return meta

# (MCLP or Wiener functions remain unchanged)

# Steering vector function modified for 3D
def compute_steer_farfield(micpos, freq_axis, doa_az_deg, doa_el_deg, ref_idx=0, c=343.0):
    """Far-field (plane-wave) steering vectors for a 3-D array, relative to a reference mic.

    The unit vector pointing from the array toward the source is
    ``d = [cos(az)cos(el), sin(az)cos(el), sin(el)]``. A mic at offset ``p``
    from the reference mic is ``p . d`` closer to the source, so the wavefront
    reaches it earlier. Its delay relative to the reference is therefore
    ``tau = -(p . d) / c``.

    Under the ``exp(-j*omega*t)`` DFT convention of ``scipy.signal.stft``, a
    delay ``tau`` multiplies the spectrum by ``exp(-j*omega*tau)``. The steering
    entry is thus ``exp(+j*omega*(p . d)/c)``.

    Args:
        micpos: ``(M, 3)`` microphone coordinates, in the length unit of ``c``.
        freq_axis: ``(K,)`` frequencies in Hz.
        doa_az_deg: azimuth in degrees (x-y plane, from +x).
        doa_el_deg: elevation in degrees above the x-y plane.
        ref_idx: index of the reference microphone.
        c: speed of sound (default 343.0, i.e. m/s for metre coordinates).

    Returns:
        ``(K, M)`` complex128 array; column ``ref_idx`` is all ones.
    """
    positions = np.asarray(micpos, dtype=np.float64)
    freqs = np.asarray(freq_axis, dtype=np.float64)

    az = math.radians(doa_az_deg)
    el = math.radians(doa_el_deg)
    d = np.array([math.cos(az) * math.cos(el),
                  math.sin(az) * math.cos(el),
                  math.sin(el)])

    rel = positions - positions[ref_idx]   # (M, 3) offsets from the reference mic
    tau = -(rel @ d) / c                   # (M,) arrival delay w.r.t. reference, seconds
    steer = np.exp(-1j * 2.0 * np.pi * freqs[:, None] * tau[None, :])

    # Normalise to the reference mic (already 1 by construction; kept explicit).
    steer = steer / steer[:, [ref_idx]]
    return steer


#  Pipeline modified for 3D DOA
def pipeline_mclp_rls(out_dir,
                            mclp_L=MCLP_L, mclp_delay=MCLP_DELAY, mclp_iter=MCLP_ITER, mclp_reg=MCLP_REG,
                            rls_L=RLS_L, Delta=RLS_DELTA, alpha=RLS_ALPHA, eps_reg=RLS_EPS_REG,
                            n_fft=N_FFT, hop=HOP, post_mask_beta=POST_MASK_BETA):
    """Load an array session from ``out_dir`` and prepare it for dereverberation.

    Steps:
    1. Read ``metadata.json``, the multichannel wav and the mic-position CSV.
    2. Compute the per-channel STFT (periodic Hann window, ``n_fft`` /
       ``hop``).
    3. Derive the target DOA. It is computed from the oracle ``src_pos`` when
       present; otherwise the metadata ``src_az``/``src_el`` are used, and
       the SIM_* defaults as a last resort.
    4. Build 3-D far-field steering vectors.

    NOTE(review): the MCLP / bilinear RLS / Wiener / ISTFT stages are not
    present in this section. The ``mclp_*``, ``rls_L``, ``Delta``, ``alpha``,
    ``eps_reg`` and ``post_mask_beta`` arguments are accepted but unused here.

    Returns:
        dict of diagnostics and intermediate results. Keys: sr, n_samples,
        n_channels, n_freq, n_frames, doa_az_deg, doa_el_deg, doa_source,
        f_axis, t_frames, stft ``(M, K, T)``, steering ``(K, M)``.
    """
    out_dir = Path(out_dir)
    with open(out_dir / "metadata.json", "r") as f:
        meta = json.load(f)

    sr = int(meta["sr"])
    arr_wav = Path(meta["array_wav"])
    ori_csv = Path(meta["ori_csv"])

    y_multi, file_sr = sf.read(str(arr_wav))
    y_multi = np.asarray(y_multi)
    if y_multi.ndim == 1:
        # Single-channel file: treat as (n_samples, 1).
        y_multi = y_multi[:, None]
    if int(file_sr) != sr:
        warnings.warn(f"{arr_wav} is {file_sr} Hz but metadata says {sr} Hz; "
                      "using the metadata rate", RuntimeWarning)
    n_samples, M = y_multi.shape

    micpos = np.loadtxt(str(ori_csv), delimiter=",", ndmin=2)  # (M, 3)
    src_pos = meta.get("src_pos", None)

    # Per-channel STFT, shape (M, K, T).
    win = sps.windows.hann(n_fft, sym=False)
    stft_kw = dict(fs=sr, window=win, nperseg=n_fft, noverlap=n_fft - hop,
                   boundary='zeros', padded=True)
    f_axis, t_frames, S0 = sps.stft(y_multi[:, 0], **stft_kw)
    K = len(f_axis)
    T = S0.shape[1]
    S_all = np.zeros((M, K, T), dtype=np.complex128)
    S_all[0] = S0
    for ch in range(1, M):
        _, _, S_ch = sps.stft(y_multi[:, ch], **stft_kw)
        if S_ch.shape[1] < T:
            S_ch = np.pad(S_ch, ((0, 0), (0, T - S_ch.shape[1])), mode='constant')
        S_all[ch] = S_ch[:, :T]

    # DOA: metadata labels first, then SIM defaults; overridden by the oracle position.
    doa_az_deg = float(meta.get("src_az", SIM_SRC_AZ))
    doa_el_deg = float(meta.get("src_el", SIM_SRC_EL))
    doa_source = "metadata"
    if src_pos is not None:
        try:
            vec = np.asarray(src_pos, dtype=np.float64) - micpos.mean(axis=0)
            az = (math.degrees(math.atan2(vec[1], vec[0])) + 360.0) % 360.0
            el = math.degrees(math.atan2(vec[2], math.hypot(vec[0], vec[1])))
        except (TypeError, ValueError, IndexError) as err:
            warnings.warn(f"could not derive DOA from src_pos ({err}); "
                          "using metadata angles", RuntimeWarning)
        else:
            # Assign both angles together so they never come from different sources.
            doa_az_deg, doa_el_deg, doa_source = az, el, "oracle"
            print(f"[pipeline] derived DOA from src_pos (oracle): Az={doa_az_deg:.2f} deg, El={doa_el_deg:.2f} deg")

    steering = compute_steer_farfield(micpos, f_axis, doa_az_deg, doa_el_deg, ref_idx=0)

    # TODO: MCLP dereverberation, bilinear RLS, Wiener post-mask and ISTFT
    # stages belong here; they are not part of this section.

    diagnostics = {
        "sr": sr,
        "n_samples": int(n_samples),
        "n_channels": int(M),
        "n_freq": int(K),
        "n_frames": int(T),
        "doa_az_deg": float(doa_az_deg),
        "doa_el_deg": float(doa_el_deg),
        "doa_source": doa_source,
        "f_axis": f_axis,
        "t_frames": t_frames,
        "stft": S_all,
        "steering": steering,
    }
    return diagnostics


# MAIN EXECUTION
def main():
    """Run the end-to-end demo on ``INPUT_WAV``.

    Steps:
    1. Load and resample the input to ``SR``.
    2. Prepare an array recording. Input with at least ``SIM_M`` channels is
       used directly, with synthetic mic positions and DOA labels. Anything
       else is down-mixed to mono and simulated with
       ``simulate_room_from_mono``.
    3. Run ``pipeline_mclp_rls`` on the result.

    Returns:
        The pipeline's return value, or None if the input file is missing or
        the pipeline raised (the traceback is printed).
    """
    try:
        data, sr = read_and_resample(INPUT_WAV, SR)
    except FileNotFoundError:
        print(f"Error: {INPUT_WAV} not found. Please rename your input file to 'test.wav'.")
        return None

    out_dir = Path(OUT_DIR)
    ensure_dirs(out_dir)

    if data.ndim == 2 and data.shape[1] >= SIM_M:
        # Multichannel input: use it as the array recording directly.
        n_ch = data.shape[1]
        arr_dst = out_dir/"array_signals"/Path(INPUT_WAV).name
        sf.write(str(arr_dst), data.astype('float32'), sr)
        micpos = circular_array_positions(M=n_ch, center=(3.0, 3.5, 1.2))
        np.savetxt(str(out_dir/"ori_files"/"ori_from_input.csv"), micpos, delimiter=",", fmt="%.6f")
        np.savetxt(str(out_dir/"doa_labels"/"doa_from_input.csv"), np.array([SIM_SRC_AZ, SIM_SRC_EL]), delimiter=",", fmt="%.6f")
        # NOTE(review): clean_from_input.wav is referenced but never written here.
        meta = {"clean_wav": str((out_dir/"clean_speech"/"clean_from_input.wav").resolve()),
                "array_wav": str(arr_dst.resolve()),
                "ori_csv": str((out_dir/"ori_files"/"ori_from_input.csv").resolve()),
                "doa_csv": str((out_dir/"doa_labels"/"doa_from_input.csv").resolve()),
                "sr": sr, "mic_positions": micpos.tolist(),
                "src_az": SIM_SRC_AZ, "src_el": SIM_SRC_EL}
        with open(out_dir/"metadata.json","w") as f:
            json.dump(meta, f, indent=2)
    else:
        if data.ndim == 2:
            # Down-mix (samples, channels) by averaging. Flattening would
            # interleave the channels into one garbled, longer signal.
            mono = data.mean(axis=1)
        else:
            mono = data.flatten()
        print(f"Input is mono. Simulating array with {SIM_M} channels, El={SIM_SRC_EL:.1f}deg.")
        meta = simulate_room_from_mono(out_dir, mono, sr=sr,
                                        room_dim=SIM_ROOM_DIM, rt60=SIM_RT60,
                                        mic_radius=SIM_MIC_RADIUS, src_distance=SIM_SRC_DISTANCE,
                                        src_az_deg=SIM_SRC_AZ, src_el_deg=SIM_SRC_EL,
                                        snr_db=SIM_SNR_DB, max_order=SIM_MAX_ORDER, M=SIM_M)
        print("Simulation metadata saved.")

    print("Running pipeline (MCLP all-ch + Bilinear RLS + Wiener)...")
    try:
        return pipeline_mclp_rls(out_dir,
                                 mclp_L=MCLP_L, mclp_delay=MCLP_DELAY, mclp_iter=MCLP_ITER, mclp_reg=MCLP_REG,
                                 rls_L=RLS_L, Delta=RLS_DELTA, alpha=RLS_ALPHA, eps_reg=RLS_EPS_REG,
                                 n_fft=N_FFT, hop=HOP, post_mask_beta=POST_MASK_BETA)
    except Exception as e:
        # Top-level boundary: report the failure with its traceback.
        print("\nPipeline failed:", e)
        traceback.print_exc()
        return None

if __name__ == "__main__":
    main()