# In [None]:
"""
Phase-Based Video Motion Magnification (Wadhwa et al. 2013)
Complex Steerable Pyramid (steerable filters) + temporal bandpass on PHASE.

Deps:
  pip install opencv-python numpy scipy pyrtools tqdm

Run:
  python phase_based_wadhwa.py --in input.mp4 --out out.mp4 --alpha 20 --fl 0.4 --fh 3.0 --levels 4 --order 3

Notes:
- Works best on stable (tripod) videos.
- Uses YIQ: magnifies Y (luma) only, keeps I/Q to reduce color artifacts.
- Memory-heavy (stores pyramid subbands for all frames). Start with short clips / lower resolution.
"""

import argparse
from dataclasses import dataclass
from typing import Dict, List, Tuple

import cv2
import numpy as np
from scipy.signal import butter, filtfilt
from tqdm import tqdm

import pyrtools as pt


# -------------------- Color space: RGB <-> YIQ --------------------
def rgb_to_yiq(rgb: np.ndarray) -> np.ndarray:
    """
    Convert RGB pixels to NTSC YIQ.

    The last axis of ``rgb`` holds the (R, G, B) channels (main passes one
    (H,W,3) frame scaled to [0,1]). Every pixel is multiplied by the 3x3
    RGB->YIQ matrix, so the result has the same shape with channels (Y, I, Q).
    """
    rgb2yiq = np.array(
        [
            [0.299, 0.587, 0.114],    # Y (luma)
            [0.596, -0.274, -0.322],  # I
            [0.211, -0.523, 0.312],   # Q
        ],
        dtype=np.float32,
    )
    # Row-vector pixels times the transposed matrix == matrix times column pixels.
    return np.matmul(rgb, rgb2yiq.T)


def yiq_to_rgb(yiq: np.ndarray) -> np.ndarray:
    """
    Convert NTSC YIQ pixels back to RGB.

    The last axis of ``yiq`` holds the (Y, I, Q) channels. Each pixel is
    multiplied by the 3x3 YIQ->RGB matrix, so the result keeps the input
    shape with channels (R, G, B). Values are not clipped here.
    """
    yiq2rgb = np.array(
        [
            [1.0, 0.956, 0.621],    # R
            [1.0, -0.272, -0.647],  # G
            [1.0, -1.106, 1.703],   # B
        ],
        dtype=np.float32,
    )
    return np.matmul(yiq, yiq2rgb.T)


# -------------------- Temporal filtering --------------------
def butter_bandpass(fl: float, fh: float, fs: float, order: int = 2):
    """
    Design a digital Butterworth band-pass filter in (b, a) form.

    Args:
        fl: lower pass-band edge in Hz.
        fh: upper pass-band edge in Hz.
        fs: sampling rate in Hz (the video frame rate); must be positive.
        order: Butterworth order passed to ``scipy.signal.butter``. For a
            band-pass design the resulting filter has order ``2 * order``.

    Edges are normalized by the Nyquist frequency (fs / 2) and clamped into
    the open interval (0, 1). An ``fh`` at or above Nyquist is therefore
    silently pulled just below it.

    Returns:
        (b, a) numerator / denominator polynomial coefficients.

    Raises:
        ValueError: if ``fs`` is not positive, or if the normalized band is
            empty (e.g. ``fl >= fh`` or ``fl`` at/above Nyquist).
    """
    if not fs > 0:  # also rejects NaN
        # fs == 0 would otherwise surface as an opaque ZeroDivisionError below.
        raise ValueError(f"Sampling rate must be positive, got fs={fs}")
    nyq = 0.5 * fs
    lo = max(fl / nyq, 1e-6)
    hi = min(fh / nyq, 0.999999)
    if not (0 < lo < hi < 1):
        raise ValueError(f"Invalid bandpass after normalization: lo={lo}, hi={hi}")
    b, a = butter(order, [lo, hi], btype="bandpass")
    return b, a


def bandpass_filter_time(x: np.ndarray, fl: float, fh: float, fs: float, order: int = 2) -> np.ndarray:
    """
    Zero-phase Butterworth band-pass along axis 0 (time).

    Args:
        x: array whose axis 0 is time, e.g. a (T,H,W) stack of per-pixel
            temporal signals. Every other index is filtered independently.
        fl, fh: pass-band edges in Hz.
        fs: sampling rate in Hz.
        order: Butterworth order forwarded to ``butter_bandpass``.

    Returns:
        The filtered array, with the same shape as ``x``.

    Edge padding ("odd" extension) is normally 3 * (max(len(a), len(b)) - 1)
    samples. ``filtfilt`` rejects inputs with T <= padlen, so for short clips
    the padding is reduced to T - 1 instead of failing.
    """
    b, a = butter_bandpass(fl, fh, fs, order=order)
    n_samples = x.shape[0]
    padlen = min(3 * (max(len(a), len(b)) - 1), max(n_samples - 1, 0))
    # filtfilt runs the filter forward and backward: zero phase, so the
    # band-passed temporal signal is not shifted in time.
    return filtfilt(b, a, x, axis=0, padtype="odd", padlen=padlen)


# -------------------- Steerable pyramid helpers --------------------
@dataclass
class PyramidStack:
    """
    Steerable-pyramid coefficients for every frame, stacked on a leading
    time axis (T).
    """

    # oriented subbands keyed by their pyrtools coefficient key; each value is
    # a (T, h, w) complex stack whose (h, w) depends on the pyramid level
    subband_coeffs: Dict[Tuple, np.ndarray]
    # high-pass residual of each frame, stacked on axis 0 (real or complex,
    # whatever pyrtools produced)
    highpass: np.ndarray
    # low-pass residual of each frame, stacked on axis 0 (real or complex)
    lowpass: np.ndarray
    # the subband keys, in the order they are processed
    keys: List[Tuple]


def build_pyramid(y: np.ndarray, levels: int, order: int):
    """
    Decompose one 2-D image with pyrtools' complex steerable pyramid.

    Args:
        y: the 2-D image (here: one luma frame).
        levels: number of scales (pyrtools ``height``).
        order: steerable filter order; the pyramid has ``order + 1``
            orientations per scale.

    Returns:
        The pyrtools ``SteerablePyramidFreq`` object; its coefficients live
        in ``.pyr_coeffs``.
    """
    return pt.pyramids.SteerablePyramidFreq(
        y,
        height=levels,
        order=order,
        is_complex=True,  # complex bands are needed to read local phase
    )


def collect_pyramids(y_frames: np.ndarray, levels: int, order: int) -> PyramidStack:
    """
    Decompose every frame with the complex steerable pyramid and stack the
    coefficients along a new leading time axis.

    Args:
        y_frames: (T,H,W) luma frames (main passes float32 in [0,1]).
        levels: pyramid height forwarded to ``build_pyramid``.
        order: steerable filter order forwarded to ``build_pyramid``.

    Returns:
        PyramidStack in which:
        - each subband stack is (T, h_k, w_k) complex64;
        - each residual stack is complex64 if pyrtools returned a complex
          residual for the first frame, float32 otherwise.

    Raises:
        ValueError: if ``y_frames`` contains no frames.
        RuntimeError: if the high/low-pass residual keys cannot be found.
    """
    T = y_frames.shape[0]
    if T == 0:
        raise ValueError("collect_pyramids needs at least one frame")

    # The first frame's pyramid fixes the coefficient keys and shapes for all
    # frames. It is reused below as frame 0's data so it is not computed twice.
    first = build_pyramid(y_frames[0], levels, order)
    all_keys = list(first.pyr_coeffs.keys())

    # Residuals are recognised by their key text; every other key is an
    # oriented subband whose phase will be processed.
    sub_keys = [k for k in all_keys if "residual" not in str(k)]
    hi_keys = [k for k in all_keys if "highpass" in str(k)]
    lo_keys = [k for k in all_keys if "lowpass" in str(k)]
    if not hi_keys or not lo_keys:
        raise RuntimeError(f"Could not find residual keys in pyramid. Keys: {all_keys}")
    hi_key, lo_key = hi_keys[0], lo_keys[0]

    def _residual_stack(coeff: np.ndarray) -> np.ndarray:
        # Allocate a (T, ...) stack, keeping residuals complex only if pyrtools made them so.
        dtype = np.complex64 if np.iscomplexobj(coeff) else np.float32
        return np.empty((T, *coeff.shape), dtype=dtype)

    subband_coeffs: Dict[Tuple, np.ndarray] = {
        k: np.empty((T, *first.pyr_coeffs[k].shape), dtype=np.complex64) for k in sub_keys
    }
    highpass = _residual_stack(first.pyr_coeffs[hi_key])
    lowpass = _residual_stack(first.pyr_coeffs[lo_key])

    for t in tqdm(range(T), desc="Building steerable pyramids"):
        pyr = first if t == 0 else build_pyramid(y_frames[t], levels, order)
        for k in sub_keys:
            subband_coeffs[k][t] = pyr.pyr_coeffs[k].astype(np.complex64, copy=False)
        highpass[t] = pyr.pyr_coeffs[hi_key]
        lowpass[t] = pyr.pyr_coeffs[lo_key]

    return PyramidStack(subband_coeffs=subband_coeffs, highpass=highpass, lowpass=lowpass, keys=sub_keys)


def reconstruct_frame_from_coeffs(coeffs: Dict[Tuple, np.ndarray], levels: int, order: int) -> np.ndarray:
    """
    Reconstruct a single frame from a complete pyramid coefficient dict.

    pyrtools reconstructs only from a pyramid object. So a throwaway pyramid
    is built on an all-zero image of the right size, its coefficients are
    replaced with ``coeffs``, and ``recon_pyr`` is called on it.

    Args:
        coeffs: every key of a ``SteerablePyramidFreq`` built with the same
            ``levels``/``order`` (oriented subbands and both residuals),
            mapped to this frame's coefficient array.
        levels: pyramid height used to produce ``coeffs``.
        order: steerable filter order used to produce ``coeffs``.

    Returns:
        The reconstructed image returned by pyrtools' ``recon_pyr``.

    Raises:
        ValueError: if ``coeffs`` is empty.
    """
    if not coeffs:
        raise ValueError("coeffs is empty; nothing to reconstruct")
    # The full image size is that of the largest coefficient array (the
    # high-pass residual is full resolution). Pick it explicitly instead of
    # trusting that the dict's first value happens to be full size.
    image_shape = max((np.shape(v) for v in coeffs.values()), key=lambda s: int(np.prod(s)))
    dummy = pt.pyramids.SteerablePyramidFreq(
        np.zeros(image_shape, dtype=np.float32), height=levels, order=order, is_complex=True
    )
    dummy.pyr_coeffs = coeffs
    return dummy.recon_pyr()


def magnify_phase(pstack: PyramidStack, alpha: float, fl: float, fh: float, fs: float, tf_order: int) -> PyramidStack:
    """
    Phase-based magnification per Wadhwa et al. (2013), applied to every
    oriented subband of ``pstack``:

      phase_t  = angle(c_t)
      delta_t  = unwrap_t(phase_t) - unwrap_t(phase_0)   # unwrap along time
      delta_bp = bandpass(delta_t)                        # zero-phase Butterworth
      c'_t     = c_t * exp(i * alpha * delta_bp)

    The last line equals |c_t| * exp(i * (phase_t + alpha * delta_bp)).
    Rotating c_t directly has two advantages:
    - it avoids materializing a separate amplitude array;
    - it avoids a float32 amplitude/phase round trip, so coefficients are
      left unchanged when alpha == 0.

    Args:
        pstack: per-frame pyramid coefficients (axis 0 = time).
        alpha: phase amplification factor.
        fl, fh: temporal pass-band edges in Hz.
        fs: frame rate in Hz.
        tf_order: Butterworth order for the temporal filter.

    Returns:
        A new PyramidStack with magnified complex64 subbands. Its highpass and
        lowpass arrays are the same objects as in ``pstack`` (not copied).
    """
    out = PyramidStack(
        subband_coeffs={},
        highpass=pstack.highpass,
        lowpass=pstack.lowpass,
        keys=pstack.keys,
    )

    for k in tqdm(pstack.keys, desc="Filtering + magnifying phase (per subband)"):
        c = pstack.subband_coeffs[k]  # (T,h,w) complex

        # Unwrap along time so the per-pixel phase signal is continuous,
        # then express it relative to the first frame.
        phase_u = np.unwrap(np.angle(c).astype(np.float32), axis=0)
        delta = phase_u - phase_u[0:1]
        del phase_u

        delta_bp = bandpass_filter_time(delta, fl=fl, fh=fh, fs=fs, order=tf_order).astype(np.float32)
        del delta

        # Rotate each coefficient by the amplified band-passed phase.
        out.subband_coeffs[k] = (c * np.exp(1j * (alpha * delta_bp))).astype(np.complex64, copy=False)

    return out


# -------------------- Video I/O --------------------
def read_video_rgb(path: str) -> Tuple[np.ndarray, float]:
    """
    Decode every frame of a video file into RGB floats.

    Args:
        path: video file to open with OpenCV.

    Returns:
        (frames, fps):
        - frames is a (T,H,W,3) float32 RGB array; decoded frame values are
          divided by 255.
        - fps is the value OpenCV reports for ``CAP_PROP_FPS``, as a float.
          NOTE(review): it may be 0.0 when unavailable; callers should validate it.

    Raises:
        FileNotFoundError: if OpenCV cannot open ``path``.
        RuntimeError: if no frames could be decoded.
    """
    cap = cv2.VideoCapture(path)
    try:
        if not cap.isOpened():
            raise FileNotFoundError(f"Could not open: {path}")
        fps = cap.get(cv2.CAP_PROP_FPS)
        frames = []
        while True:
            ok, bgr = cap.read()
            if not ok:
                break
            # Keep frames in their decoded dtype while reading. A single float
            # conversion at the end avoids holding two float copies of the
            # whole clip (the per-frame list plus the stacked array).
            frames.append(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
    finally:
        # Always free the capture, even if decoding/conversion raises.
        cap.release()
    if not frames:
        raise RuntimeError("No frames read.")
    video = np.stack(frames, axis=0).astype(np.float32)
    del frames
    video /= 255.0
    return video, float(fps)


def write_video_rgb(path: str, rgb_frames: np.ndarray, fps: float):
    """
    Encode RGB float frames to a video file with OpenCV (``mp4v`` FourCC).

    Args:
        path: output file path.
        rgb_frames: (T,H,W,3) RGB frames. Values are clipped to [0,1] and
            rounded to 8-bit before writing.
        fps: frame rate written into the output.

    Raises:
        RuntimeError: if OpenCV cannot open a writer for ``path``.
    """
    _, height, width, _ = rgb_frames.shape
    writer = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))
    try:
        if not writer.isOpened():
            raise RuntimeError(f"Could not open writer: {path}")
        for frame in rgb_frames:
            # +0.5 before the uint8 cast rounds to nearest instead of truncating.
            rgb8 = (np.clip(frame, 0.0, 1.0) * 255.0 + 0.5).astype(np.uint8)
            writer.write(cv2.cvtColor(rgb8, cv2.COLOR_RGB2BGR))
    finally:
        # Release even when a conversion/write raises, so the output is closed.
        writer.release()


# -------------------- Main pipeline --------------------
def main():
    """
    CLI entry point: phase-based motion magnification of a video's luma.

    Steps:
    1. Read RGB frames and convert them to YIQ.
    2. Decompose Y of every frame with a complex steerable pyramid.
    3. Temporally band-pass and amplify the subband phases.
    4. Reconstruct Y frame by frame.
    5. Recombine with the untouched I/Q channels, convert back to RGB and
       write the output video.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--in", dest="inp", required=True)
    ap.add_argument("--out", dest="out", required=True)
    ap.add_argument("--alpha", type=float, default=20.0, help="Magnification factor for phase band")
    ap.add_argument("--fl", type=float, default=0.4, help="Low cutoff (Hz)")
    ap.add_argument("--fh", type=float, default=3.0, help="High cutoff (Hz)")
    ap.add_argument("--levels", type=int, default=4, help="Pyramid levels (scales)")
    ap.add_argument("--order", type=int, default=3, help="Steerable pyramid order (orientations = order+1)")
    ap.add_argument("--tf_order", type=int, default=2, help="Butterworth temporal filter order")
    ap.add_argument(
        "--fps",
        type=float,
        default=None,
        help="Override the frame rate (Hz) reported by the input video",
    )
    args = ap.parse_args()

    rgb, fps = read_video_rgb(args.inp)
    if args.fps is not None:
        fps = args.fps
    if not fps > 0:
        # The temporal filter is specified in Hz, so a missing/zero frame rate
        # must be reported rather than crash deep inside filter design.
        ap.error(f"frame rate must be positive (got {fps}); pass --fps")
    T = rgb.shape[0]

    # RGB -> YIQ, frame by frame
    yiq = np.empty_like(rgb, dtype=np.float32)
    for t in range(T):
        yiq[t] = rgb_to_yiq(rgb[t])
    del rgb  # the RGB input is not needed again; free it before the pyramid stage
    y = yiq[..., 0].astype(np.float32)  # (T,H,W) luma, copied out of yiq

    # Build pyramid stacks on Y and magnify the subband phases
    pstack = collect_pyramids(y, levels=args.levels, order=args.order)
    pstack_mag = magnify_phase(
        pstack, alpha=args.alpha, fl=args.fl, fh=args.fh, fs=fps, tf_order=args.tf_order
    )
    # Only the key list and the original residuals are needed from pstack;
    # dropping it lets the unmagnified subband stacks be freed.
    keys, highpass, lowpass = pstack.keys, pstack.highpass, pstack.lowpass
    del pstack

    # Reconstruction needs every coefficient key, residuals included. Take the
    # full key set once from a template pyramid instead of rebuilding one per frame.
    template = build_pyramid(y[0], args.levels, args.order)
    hi_keys = [k for k in template.pyr_coeffs if "residual_highpass" in str(k)]
    lo_keys = [k for k in template.pyr_coeffs if "residual_lowpass" in str(k)]

    y_mag = np.empty_like(y, dtype=np.float32)
    for t in tqdm(range(T), desc="Reconstructing frames"):
        coeffs_t = dict(template.pyr_coeffs)
        for k in keys:
            coeffs_t[k] = pstack_mag.subband_coeffs[k][t]
        # residuals are carried over unmodified from the original frame
        for k in hi_keys:
            coeffs_t[k] = highpass[t]
        for k in lo_keys:
            coeffs_t[k] = lowpass[t]
        recon = reconstruct_frame_from_coeffs(coeffs_t, levels=args.levels, order=args.order)
        y_mag[t] = recon.astype(np.float32)

    # Swap in the magnified luma (I/Q untouched), then convert to RGB in place.
    yiq[..., 0] = y_mag
    rgb_out = yiq  # reuse the buffer; each frame is converted independently
    for t in range(T):
        rgb_out[t] = yiq_to_rgb(rgb_out[t])

    write_video_rgb(args.out, rgb_out, fps=fps)


if __name__ == "__main__":
    main()
