# In [5]:
import cv2, zmq, numpy as np, time, threading, queue, traceback, sys, os, math, csv
from collections import deque, defaultdict

# ---------- Config ----------
# Endpoint the SUB socket connects to.
ZMQ_ADDR = "tcp://localhost:5555"
# Topics to subscribe to; each decoded topic is also used as the camera name.
SUB_TOPICS = [b"kreo1", b"kreo2"]
FPS_WINDOW = 1.0        # seconds for fps moving window
# The preview is redrawn at most once every 1/DISPLAY_FPS seconds.
DISPLAY_FPS = 20
VISUALIZE = True     # show tiled view window

# Ids of the fixed tags whose detections are recorded for extrinsic calibration.
STATIC_TAG_IDS = [0,1,2,3]
# World-frame centre of each static tag
# (presumably metres, matching V_MAX / Z_TARGET below — TODO confirm).
TAG_POSITIONS = {
    0: np.array([0.9, 0.0, 0.0], dtype=float),
    1: np.array([0.0, 0.0, 0.0], dtype=float),
    2: np.array([0.9, 0.9, 0.0], dtype=float),
    3: np.array([0.0, 1.2, 0.0], dtype=float)
}
# Edge length of each tag keyed by tag id (same length unit as TAG_POSITIONS,
# since half of it is added to the tag centres to get corner coordinates).
TAG_SIZES = {0: 0.099, 1: 0.096, 2: 0.096, 3: 0.096, 4: 0.096, 5: 0.096}
# Number of frames with at least one detected marker that a camera must
# deliver before its extrinsics are solved.
CALIB_FRAMES = 30
# Directory holding the per-camera 'camera_calibration_<cam>.npz' files.
CALIB_DIR = "../calibration/"

# Predictor settings (tweak these)
# CSV file the predictor (re)creates on start-up and appends one row to per prediction.
PRED_LOG = "predict_log.csv"
# Maximum number of recent ball samples kept for the trajectory fit.
BUFFER_SIZE = 7
# Minimum number of buffered samples before a prediction is attempted.
MIN_POINTS = 4
# Latency allowance (milliseconds) subtracted from the time left until the predicted hit.
LAG_MS = 15.0
V_MAX = 2.0  # m/s, robot max linear speed for feasibility check
# Height at which the fitted z(t) trajectory is intersected to find the hit time.
Z_TARGET = 0.25  # 25 cm

# ---------------- APRILTAG CONFIG ----------------
DICT_TYPE = cv2.aruco.DICT_APRILTAG_36h11
def create_april_detector():
    """Build the ArUco detector used for AprilTag detection.

    Returns:
        A ``cv2.aruco.ArucoDetector`` for the dictionary selected by
        ``DICT_TYPE``, configured with the tuned adaptive-threshold,
        sub-pixel corner-refinement and perimeter settings listed below.
    """
    dictionary = cv2.aruco.getPredefinedDictionary(DICT_TYPE)

    # Attribute name -> value, applied to a default DetectorParameters object.
    tuning = {
        "adaptiveThreshWinSizeMin": 3,
        "adaptiveThreshWinSizeMax": 35,
        "adaptiveThreshWinSizeStep": 2,
        "cornerRefinementMethod": cv2.aruco.CORNER_REFINE_SUBPIX,
        "cornerRefinementWinSize": 7,
        "cornerRefinementMaxIterations": 50,
        "cornerRefinementMinAccuracy": 0.01,
        "minMarkerPerimeterRate": 0.02,
        "maxMarkerPerimeterRate": 6.0,
        "polygonalApproxAccuracyRate": 0.02,
        "adaptiveThreshConstant": 7,
    }
    detector_params = cv2.aruco.DetectorParameters()
    for attr_name, attr_value in tuning.items():
        setattr(detector_params, attr_name, attr_value)

    return cv2.aruco.ArucoDetector(dictionary, detector_params)

# color thresholds
# Inclusive lower/upper bounds applied with cv2.inRange to the frame after
# BGR -> HSV conversion (see hsv_mask_from_vals).
orange_hsvVals = {'hmin': 0,  'smin': 94,  'vmin': 156, 'hmax': 12,  'smax': 255, 'vmax': 255}
purple_hsvVals = {'hmin': 113,'smin': 78,  'vmin': 3,   'hmax': 129, 'smax': 255, 'vmax': 255}

# detector parameters
# Contours whose area lies outside [MIN_AREA, MAX_AREA] are rejected.
MIN_AREA = 150
MAX_AREA = 20000
# Minimum 4*pi*area/perimeter^2 for a contour to qualify by shape
# (sufficiently large blobs bypass this test, see find_candidate_contours).
CIRCULARITY_MIN = 0.25
# Maximum bounding-box width/height ratio accepted.
ASPECT_RATIO_MAX = 2.0
# Only this many of the largest candidates are processed per frame.
MAX_DETECTIONS_PER_CAM = 12

# ---------- Helpers ----------
def fmt_ts(ts):
    """Format a Unix timestamp as local ``YYYY-MM-DD HH:MM:SS.mmm``.

    Args:
        ts: Seconds since the epoch (float).

    Returns:
        The local-time string with the fractional part truncated to
        milliseconds.
    """
    millis = int((ts % 1) * 1000)
    whole_seconds = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(ts))
    return whole_seconds + f".{millis:03d}"

def recv_latest(sub):
    """Drain the socket without blocking and return only the newest message.

    Args:
        sub: Socket object exposing ``recv_multipart``.

    Returns:
        The last multipart message that was pending, or ``None`` when
        nothing was queued. Every older pending message is discarded,
        whichever topic it belonged to.
    """
    newest = None
    try:
        # Keep overwriting until the socket reports that nothing is left.
        while True:
            newest = sub.recv_multipart(flags=zmq.NOBLOCK)
    except zmq.Again:
        pass
    return newest

def update_fps(camera, cam_ts):
    """Record a frame timestamp and return the camera's windowed frame rate.

    Args:
        camera: Key into the module-level ``fps_windows`` mapping.
        cam_ts: Timestamp of the frame that just arrived (seconds).

    Returns:
        Number of timestamps within the last ``FPS_WINDOW`` seconds,
        divided by ``FPS_WINDOW``.
    """
    window = fps_windows[camera]
    window.append(cam_ts)
    # Evict timestamps that have fallen out of the moving window.
    while window:
        if not ((cam_ts - window[0]) > FPS_WINDOW):
            break
        window.popleft()
    return len(window) / FPS_WINDOW

def load_camera_calib(cam_name, calib_dir=None):
    """Load the intrinsic calibration of one camera from its ``.npz`` file.

    Args:
        cam_name: Camera name; the file read is
            ``camera_calibration_<cam_name>.npz``.
        calib_dir: Directory containing the file. Defaults to the
            module-level ``CALIB_DIR`` when omitted.

    Returns:
        ``(camera_matrix, dist_coeffs)`` taken from the archive entries
        ``cameraMatrix`` and ``distCoeffs``.

    Raises:
        FileNotFoundError: If the calibration file does not exist.
        KeyError: If the archive lacks one of the expected entries.
    """
    base_dir = CALIB_DIR if calib_dir is None else calib_dir
    path = os.path.join(base_dir, f'camera_calibration_{cam_name}.npz')
    if not os.path.exists(path):
        raise FileNotFoundError(path)
    # np.load returns an NpzFile that keeps the archive open; the context
    # manager closes it once both arrays have been read into memory.
    with np.load(path) as calib:
        camera_matrix = calib["cameraMatrix"]
        dist_coeffs = calib['distCoeffs']
    print("[INFO] Loaded calibrated camera parameters for", cam_name)
    return camera_matrix, dist_coeffs

def build_tag_world_map_from_centers(tag_centers, tag_sizes):
    """Expand tag centres into the world coordinates of their four corners.

    Args:
        tag_centers: Mapping tag id -> centre as a 3-element ndarray.
        tag_sizes: Mapping tag id -> edge length. Ids missing from this
            mapping fall back to the size stored for tag id 1.

    Returns:
        Mapping tag id -> ``(4, 3)`` float64 array of corners ordered
        (-x,+y), (+x,+y), (+x,-y), (-x,-y) around the centre, all at the
        centre's z.
    """
    # Sign pattern of the four corners relative to the tag centre.
    corner_signs = ((-1.0, 1.0), (1.0, 1.0), (1.0, -1.0), (-1.0, -1.0))
    fallback_size = tag_sizes.get(1)

    world_map = {}
    for tag_id, center in tag_centers.items():
        h = float(tag_sizes.get(tag_id, fallback_size)) / 2.0
        offsets = np.array(
            [[sx * h, sy * h, 0.0] for sx, sy in corner_signs],
            dtype=np.float64,
        )
        world_map[tag_id] = (offsets + center.reshape(1, 3)).astype(np.float64)
    return world_map

TAG_WORLD_MAP = build_tag_world_map_from_centers(TAG_POSITIONS, TAG_SIZES)

# ---------- Color masking helpers ----------
def hsv_mask_from_vals(bgr_img, hsvVals):
    """Threshold a BGR image in HSV space and return a smoothed binary mask.

    Args:
        bgr_img: Input image in BGR channel order.
        hsvVals: Dict with keys ``hmin/smin/vmin`` and ``hmax/smax/vmax``.

    Returns:
        The ``cv2.inRange`` mask after a 5x5 median blur.
    """
    hsv_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2HSV)
    lo_keys = ('hmin', 'smin', 'vmin')
    hi_keys = ('hmax', 'smax', 'vmax')
    lo_bound = np.array([hsvVals[k] for k in lo_keys], dtype=np.uint8)
    hi_bound = np.array([hsvVals[k] for k in hi_keys], dtype=np.uint8)
    in_range = cv2.inRange(hsv_img, lo_bound, hi_bound)
    # Median filtering removes isolated speckles from the raw threshold.
    return cv2.medianBlur(in_range, 5)

def postprocess_mask(mask):
    """Clean a binary mask with one morphological open followed by one close.

    Both operations use the same 5x5 elliptical structuring element.

    Args:
        mask: Binary mask image.

    Returns:
        The cleaned mask.
    """
    elem = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    cleaned = mask
    for morph_op in (cv2.MORPH_OPEN, cv2.MORPH_CLOSE):
        cleaned = cv2.morphologyEx(cleaned, morph_op, elem, iterations=1)
    return cleaned

def find_candidate_contours(mask):
    """Extract ball-like external contours from a binary mask.

    A contour is kept when its area is within ``[MIN_AREA, MAX_AREA]``, its
    bounding-box width/height ratio does not exceed ``ASPECT_RATIO_MAX``,
    and it is either circular enough (``CIRCULARITY_MIN``) or large enough
    (half the shorter box side > 5 and area > 2 * ``MIN_AREA``).

    Args:
        mask: Binary mask, or ``None``.

    Returns:
        List of ``(contour, area, (x, y, w, h))`` tuples sorted by area,
        largest first. Empty when ``mask`` is ``None``.
    """
    if mask is None:
        return []

    found, _hierarchy = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    kept = []
    for contour in found:
        contour_area = cv2.contourArea(contour)
        if contour_area > MAX_AREA or contour_area < MIN_AREA:
            continue
        perimeter = cv2.arcLength(contour, True)
        if perimeter <= 0:
            continue

        circularity = 4 * np.pi * contour_area / (perimeter * perimeter)
        bx, by, bw, bh = cv2.boundingRect(contour)
        ratio = float(bw) / float(bh) if bh > 0 else 0.0

        round_enough = circularity >= CIRCULARITY_MIN
        big_enough = 0.5 * min(bw, bh) > 5 and contour_area > (MIN_AREA * 2)
        if not (round_enough or big_enough):
            continue
        if not (ratio <= ASPECT_RATIO_MAX):
            continue
        kept.append((contour, contour_area, (int(bx), int(by), int(bw), int(bh))))

    return sorted(kept, key=lambda entry: entry[1], reverse=True)

def estimate_pose_apriltag(corners, tag_size, cam_mtx, cam_dist):
    """Estimate the pose of a square tag in the camera frame with solvePnP.

    Args:
        corners: The tag's four image corners; reshaped to ``(4, 2)``.
        tag_size: Edge length of the tag.
        cam_mtx: Camera intrinsic matrix.
        cam_dist: Distortion coefficients.

    Returns:
        4x4 homogeneous transform whose rotation/translation map tag-frame
        points into the camera frame.

    Raises:
        RuntimeError: If ``cv2.solvePnP`` reports failure.
    """
    h = tag_size / 2.0
    # Tag corners in the tag's own frame (z = 0 plane), same order as the
    # image corners are expected in.
    model_pts = np.array(
        [[-h, h, 0.0], [h, h, 0.0], [h, -h, 0.0], [-h, -h, 0.0]],
        dtype=np.float32,
    )
    image_pts = corners.reshape(4, 2).astype(np.float32)

    solved, rvec, tvec = cv2.solvePnP(
        model_pts, image_pts, cam_mtx, cam_dist,
        flags=cv2.SOLVEPNP_ITERATIVE
    )
    if not solved:
        raise RuntimeError("solvePnP failed")

    rotation, _jacobian = cv2.Rodrigues(rvec)
    pose = np.eye(4)
    pose[:3, 3] = tvec.reshape(3)
    pose[:3, :3] = rotation
    return pose



# APRILTAG THREAD
class AprilTagThread(threading.Thread):
    """Worker thread that detects AprilTags in one camera's frames.

    For every ``(frame, ts)`` pulled from ``frame_queue`` the thread:

    1. detects markers and stores the raw result under
       ``detect_cache[cam_name]["tags"]``;
    2. hands the detections to ``calibrator`` and asks it to try solving
       the camera extrinsics;
    3. once extrinsics exist, estimates the world position of every robot
       tag (ids in ``ROBOT_TAG_IDS``) and stores it in
       ``shared_robot_poses[tag_id]``.

    ``lock`` is held for every write to ``detect_cache`` and
    ``shared_robot_poses``. Call ``stop()`` to make ``run()`` exit after the
    current iteration.
    """

    # Tag ids treated as robot markers and localized in world coordinates.
    ROBOT_TAG_IDS = (4, 5)
    # Expected height of a robot tag; only used for a diagnostic print.
    NOMINAL_TAG_HEIGHT = 0.075
    # Deviation from NOMINAL_TAG_HEIGHT beyond which the diagnostic is printed.
    HEIGHT_TOLERANCE = 0.05

    def __init__(self, cam_name, frame_queue, detect_cache, lock, calibrator, shared_robot_poses):
        super().__init__(daemon=True)
        self.cam_name = cam_name
        self.frame_queue = frame_queue
        self.detect_cache = detect_cache
        self.lock = lock
        self.detector = create_april_detector()
        self.stop_flag = False
        self.calibrator = calibrator
        self.shared_robot_poses = shared_robot_poses

    def run(self):
        """Consume frames until ``stop()`` is called; errors are logged, not raised."""
        print(f"[{self.cam_name}] AprilTagThread started")
        while not self.stop_flag:
            try:
                # Short timeout so stop_flag is re-checked regularly.
                frame, ts = self.frame_queue.get(timeout=0.1)
            except queue.Empty:
                continue
            try:
                self._process_frame(frame, ts)
            except Exception as e:
                print(f"[ERROR-{self.cam_name}] AprilTag Detection exception:", e)
                traceback.print_exc()

        print(f"[DETECT-{self.cam_name}] AprilTag Detector thread stopped")

    def _process_frame(self, frame, ts):
        """Detect tags in one frame, publish them and update calibration/poses."""
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        corners, ids, _ = self.detector.detectMarkers(gray)
        det = {"corners": corners, "ids": ids, "ts": ts, "det_time": time.time()}
        with self.lock:
            self.detect_cache.setdefault(self.cam_name, {})["tags"] = det

        if ids is None:
            return
        self.calibrator.add_detection(self.cam_name, ids, corners, ts)
        self.calibrator.try_compute_extrinsic(self.cam_name)
        if self.cam_name not in self.calibrator.extrinsics:
            return

        # A failure while localizing robot tags must not discard the
        # detection result that was already published above.
        try:
            self._update_robot_poses(ids, corners, ts)
        except Exception as e:
            print(f"[{self.cam_name}] robot pose compute error:", e)
            traceback.print_exc()

    def _update_robot_poses(self, ids, corners, ts):
        """Publish a world pose for every detected robot tag."""
        K, dist = self.calibrator.load_intrinsics(self.cam_name)
        for i, idarr in enumerate(ids):
            tid = int(idarr[0])
            if tid not in self.ROBOT_TAG_IDS:
                continue
            pose = self._robot_pose_from_tag(tid, corners[i], K, dist, ts)
            with self.lock:
                self.shared_robot_poses[tid] = pose
            # Printed after releasing the lock so console I/O never stalls
            # the other threads waiting on it.
            print(f"tid: {tid}, pos: {pose['world_pos']}")

    def _robot_pose_from_tag(self, tid, tag_corners, K, dist, ts):
        """Compute the shared-pose record for one robot tag detection."""
        corners_i = np.array(tag_corners).reshape(4, 2)
        T_tag_cam = estimate_pose_apriltag(corners_i, TAG_SIZES[tid], K, dist)
        R_tag_cam = T_tag_cam[:3, :3]
        t_tag_cam = T_tag_cam[:3, 3].reshape(3,)   # tag origin in camera frame

        # 3D positions of the 4 tag corners in the camera frame
        half = TAG_SIZES[tid] / 2.0
        objp = np.array([
            [-half,  half, 0.0],
            [ half,  half, 0.0],
            [ half, -half, 0.0],
            [-half, -half, 0.0]
        ], dtype=np.float64)   # 4x3
        cam_corners = (R_tag_cam @ objp.T).T + t_tag_cam.reshape(1, 3)  # 4x3

        tag_origin_world = self.calibrator.cam_to_world(self.cam_name, t_tag_cam)
        world_corners = self.calibrator.cam_to_world(self.cam_name, cam_corners)  # 4x3
        mean_corner_z = float(np.mean(world_corners[:, 2]))

        # If the corners land below the z = 0 plane on average, mirror the
        # tag origin's height to the positive side.
        corrected_tag_origin_world = tag_origin_world.copy()
        if mean_corner_z < 0:
            corrected_tag_origin_world[2] = abs(corrected_tag_origin_world[2])
            print(f"[WARN] {self.cam_name} tag{tid} mean_corner_z negative ({mean_corner_z:.4f}). Flipping Z to {corrected_tag_origin_world[2]:.4f}")

        nominal_tag_height = self.NOMINAL_TAG_HEIGHT
        if abs(corrected_tag_origin_world[2] - nominal_tag_height) > self.HEIGHT_TOLERANCE:
            # warn but do not overwrite automatically; print for debugging only
            print(f"[INFO] {self.cam_name} tag{tid} measured z {corrected_tag_origin_world[2]:.3f}, nominal {nominal_tag_height:.3f}")

        return {
            'world_pos': corrected_tag_origin_world,
            'cam': self.cam_name,
            'ts': ts,
            'det_time': time.time(),
            'reproj_src': 'estimate_pose_apriltag',
            'mean_corner_z': mean_corner_z,
            'cam_t': t_tag_cam.copy()
        }

    def stop(self):
        """Request the run loop to exit."""
        self.stop_flag = True

# BALL THREAD
class BallThread(threading.Thread):
    """Worker thread that detects orange/purple balls in one camera's frames.

    For every ``(frame, ts)`` pulled from ``frame_queue`` the detections are
    written, under ``lock``, to ``detect_cache[cam_name]["balls"]`` (full
    records) and to ``shared_camera_candidates[cam_name]`` (centroid, area,
    color and timestamp only). Call ``stop()`` to make ``run()`` exit after
    the current iteration.
    """

    def __init__(self, cam_name, frame_queue, detect_cache, lock, shared_camera_candidates):
        super().__init__(daemon=True)
        self.cam_name = cam_name
        self.frame_queue = frame_queue
        self.detect_cache = detect_cache
        self.lock = lock
        self.stop_flag = False
        self.shared_camera_candidates = shared_camera_candidates

    def run(self):
        """Consume frames until ``stop()`` is called; errors are logged, not raised."""
        print(f"[{self.cam_name}] BallThread started")
        while not self.stop_flag:
            try:
                # Short timeout so stop_flag is re-checked regularly.
                frame, ts = self.frame_queue.get(timeout=0.1)
            except queue.Empty:
                continue
            try:
                self._publish(self._detect_balls(frame, ts))
            except Exception as e:
                print(f"[ERROR-{self.cam_name}] ball detection exception:", e)
                traceback.print_exc()
        print(f"[{self.cam_name}] BallDetectorThread stopped")

    def _detect_balls(self, frame, ts):
        """Return the detection records for the largest candidates in ``frame``."""
        mo = hsv_mask_from_vals(frame, orange_hsvVals)
        mp = hsv_mask_from_vals(frame, purple_hsvVals)
        mc = postprocess_mask(cv2.bitwise_or(mo, mp))

        dets = []
        for c, area, (x, y, w, h) in find_candidate_contours(mc)[:MAX_DETECTIONS_PER_CAM]:
            M = cv2.moments(c)
            if M["m00"] != 0:
                cx = int(M["m10"] / M["m00"])
                cy = int(M["m01"] / M["m00"])
            else:
                # Degenerate contour: fall back to the bounding-box centre.
                cx = x + w // 2
                cy = y + h // 2
            dets.append({
                "bbox": (x, y, w, h),
                "centroid": (cx, cy),
                "area": float(area),
                "color": self._classify_color(frame, mo, mp, (x, y, w, h)),
                "ts": ts,
                "det_time": time.time()
            })
        return dets

    def _classify_color(self, frame, mo, mp, bbox):
        """Label a bounding box as "orange", "purple" or "unknown"."""
        x, y, w, h = bbox
        s_or = int(np.count_nonzero(mo[y:y+h, x:x+w]))
        s_pu = int(np.count_nonzero(mp[y:y+h, x:x+w]))
        # Counts are non-negative, so a strict majority is also non-zero.
        if s_or > s_pu:
            return "orange"
        if s_pu > s_or:
            return "purple"

        # Tie between the two masks: decide by the mean hue of the box.
        hsv_roi = cv2.cvtColor(frame[y:y+h, x:x+w], cv2.COLOR_BGR2HSV)
        mean_h = int(np.mean(hsv_roi[:, :, 0]))
        if orange_hsvVals["hmin"] <= mean_h <= orange_hsvVals["hmax"]:
            return "orange"
        if purple_hsvVals["hmin"] <= mean_h <= purple_hsvVals["hmax"]:
            return "purple"
        return "unknown"

    def _publish(self, dets):
        """Store the detections in both shared structures under the lock."""
        candidates = [
            {"centroid": d["centroid"], "area": d["area"], "color": d["color"], "ts": d["ts"]}
            for d in dets
        ]
        with self.lock:
            self.detect_cache.setdefault(self.cam_name, {})["balls"] = dets
            self.shared_camera_candidates[self.cam_name] = candidates

    def stop(self):
        """Request the run loop to exit."""
        self.stop_flag = True

# Calibrator Class
class StaticCalibrator:
    """Solves and caches per-camera extrinsics from static-tag detections.

    Detections of the static tags are collected with ``add_detection``.
    After ``CALIB_FRAMES`` frames, ``try_compute_extrinsic`` runs solvePnP
    on the most recent observation of the camera's target tag and stores:

    * ``extrinsics[cam]`` -> ``{"rvec", "tvec", "R"}`` (world -> camera),
    * ``P_cache[cam]``    -> 3x4 projection matrix ``K @ [R | t]``.

    Intrinsics are loaded lazily through ``load_camera_calib`` and cached.
    """

    # Static tag whose corners are used to solve each camera's pose.
    TARGET_TAG_BY_CAM = {"kreo1": 2, "kreo2": 1}

    def __init__(self, tag_world_map, tag_size_map):
        self.tag_world_map = tag_world_map      # tag id -> (4, 3) world corners
        self.tag_size_map = tag_size_map        # tag id -> edge length
        self.obs = defaultdict(list)            # cam -> [(tag id, (4, 2) corners, ts)]
        self.extrinsics = {}
        self.frame_count = defaultdict(int)     # cam -> frames with detections seen
        self.P_cache = {}
        self.K_cache = {}
        self.dist_cache = {}
        # Cameras whose intrinsics failure has already been reported.
        self._intrinsics_warned = set()

    def load_intrinsics(self, cam_name):
        """Return ``(camera_matrix, dist_coeffs)`` for a camera, loading once."""
        if cam_name in self.K_cache:
            return self.K_cache[cam_name], self.dist_cache[cam_name]
        camera_matrix, dist_coeffs = load_camera_calib(cam_name)
        self.K_cache[cam_name] = camera_matrix
        self.dist_cache[cam_name] = dist_coeffs
        return camera_matrix, dist_coeffs

    def add_detection(self, cam_name, ids, corners, ts):
        """Record the static-tag observations of one frame.

        Does nothing when ``ids`` is ``None`` or when the camera is already
        calibrated: observations are only consumed by
        ``try_compute_extrinsic``, so storing more of them afterwards would
        only grow ``obs`` without bound.
        """
        if ids is None:
            return
        if cam_name in self.extrinsics:
            return
        self.frame_count[cam_name] += 1
        for i, idarr in enumerate(ids):
            tid = int(idarr[0])
            if tid in STATIC_TAG_IDS:
                c = np.array(corners[i]).reshape(4, 2).astype(np.float64)
                self.obs[cam_name].append((tid, c, ts))

    def try_compute_extrinsic(self, cam_name):
        """Solve the camera's extrinsics if enough data has been collected.

        Returns:
            ``True`` when extrinsics are available (already or newly
            computed), ``False`` otherwise (too few frames, unknown camera,
            target tag not yet observed, intrinsics unavailable, or
            solvePnP failure).
        """
        if cam_name in self.extrinsics:
            return True
        if self.frame_count.get(cam_name, 0) < CALIB_FRAMES:
            return False
        target_tag = self.TARGET_TAG_BY_CAM.get(cam_name)
        if target_tag is None:
            return False

        # Use the most recent observation of the target tag.
        use_corners = None
        for tid, corners, _ts in reversed(self.obs.get(cam_name, [])):
            if int(tid) == int(target_tag):
                use_corners = corners.reshape(4, 2).astype(np.float64)
                break
        if use_corners is None:
            return False

        try:
            K, dist = self.load_intrinsics(cam_name)
        except Exception as e:
            # Still best-effort, but report the reason once per camera so a
            # missing/corrupt calibration file does not fail silently.
            if cam_name not in self._intrinsics_warned:
                self._intrinsics_warned.add(cam_name)
                print(f"[Calib] cannot load intrinsics for {cam_name}: {e!r}")
            return False

        obj_corners = np.array(self.tag_world_map[target_tag], dtype=np.float64)
        ok, rvec, tvec = cv2.solvePnP(
            obj_corners,
            use_corners,
            K, dist,
            flags=cv2.SOLVEPNP_ITERATIVE
        )
        if not ok:
            return False
        R, _ = cv2.Rodrigues(rvec)
        tvec = tvec.reshape(3, 1)
        self.extrinsics[cam_name] = {"rvec": rvec, "tvec": tvec, "R": R}
        self.P_cache[cam_name] = K @ np.hstack((R, tvec))
        print(f"[Calib] extrinsics computed for {cam_name}: tvec={tvec.ravel()}")
        return True

    def cam_to_world(self, cam_name, X_cam):
        """Transform camera-frame points to world coordinates.

        Applies ``X_world = R.T @ (X_cam - t)``.

        Args:
            cam_name: Camera whose extrinsics to use.
            X_cam: A single point of shape ``(3,)``, or points of shape
                ``(N, 3)`` or ``(3, N)``. A ``(3, 3)`` input is interpreted
                as ``(N, 3)``.

        Returns:
            ``(3,)`` array for a single point, otherwise an ``(N, 3)`` array.

        Raises:
            RuntimeError: If the camera has no extrinsics yet.
            ValueError: If ``X_cam`` has an unsupported shape.
        """
        e = self.extrinsics.get(cam_name)
        if e is None:
            raise RuntimeError("Calibrator: extrinsic not ready for " + cam_name)
        R = e['R']
        t = e['tvec']
        X = np.asarray(X_cam, dtype=np.float64)
        if X.ndim == 1 and X.shape[0] == 3:
            return (R.T @ (X.reshape(3, 1) - t))[:, 0]
        if X.ndim == 2 and X.shape[1] == 3:
            return (R.T @ (X.T - t)).T
        if X.ndim == 2 and X.shape[0] == 3:
            return (R.T @ (X - t)).T
        raise ValueError("Invalid X_cam shape: " + str(X.shape))

    def projection_matrix(self, cam_name):
        """Return the cached 3x4 projection matrix, or ``None`` if not calibrated."""
        return self.P_cache.get(cam_name, None)
    


# ---------- Sliding-window predictor ----------
class SlidingWindowPredictor:
    """Predicts where/when a tracked ball crosses a target height.

    The last ``buffer_size`` samples ``(ts, x, y, z)`` are kept. ``z(t)`` is
    fitted with a quadratic and intersected with ``z_target``; ``x(t)`` and
    ``y(t)`` are fitted with whichever of a linear or quadratic polynomial
    has the smaller residual variance and evaluated at the crossing time.
    ``compute_and_log`` additionally derives the planar velocity the robot
    would need to reach that point and appends one CSV row per prediction.

    Timestamps are compared against ``time.time()``, so samples must use the
    same clock.
    """

    def __init__(self, z_target=Z_TARGET, buffer_size=BUFFER_SIZE, min_points=MIN_POINTS, lag_ms=LAG_MS, v_max=V_MAX, log_path=PRED_LOG):
        self.z_target = float(z_target)
        self.buffer_size = int(buffer_size)
        self.min_points = int(min_points)
        self.lag_s = float(lag_ms) / 1000.0
        self.v_max = float(v_max)
        self.buf = deque(maxlen=self.buffer_size)
        self.robot_pos = (0.0, 0.0)
        self.robot_ts = None
        self.log_path = log_path
        self._init_log()

    def _init_log(self):
        """Create/truncate the CSV log and write its header row."""
        hdr = [
            "now_ts",
            "cur_ts",
            "cur_x","cur_y","cur_z",
            "pred_ts",
            "pred_x","pred_y","pred_z",
            "robot_x","robot_y","robot_ts",
            "vreq_x","vreq_y","vreq_mag",
            "can_reach",
            "sigma_x","sigma_y","sigma_z",
            "notes"
        ]
        with open(self.log_path, "w", newline="") as f:
            w = csv.writer(f)
            w.writerow(hdr)

    def update_robot(self, robot_x, robot_y, robot_ts=None):
        """Store the robot's planar position; ``robot_ts`` defaults to now."""
        self.robot_pos = (float(robot_x), float(robot_y))
        self.robot_ts = (time.time() if robot_ts is None else float(robot_ts))

    def add_point(self, ts, x, y, z):
        """Append one ball sample; the oldest is dropped when the buffer is full."""
        self.buf.append((float(ts), float(x), float(y), float(z)))

    def _fit_poly(self, t, y, order):
        """Least-squares polynomial fit of ``y(t)``.

        Args:
            t: 1D array of relative times (t - t_last).
            y: 1D array of values.
            order: Polynomial order.

        Returns:
            ``(theta, cov_theta, sigma2)``: coefficients (highest power
            first), their covariance (``None`` if the normal matrix is
            singular) and the residual variance (0.0 when there are no
            spare degrees of freedom). ``(None, None, None)`` if the solve
            itself fails.
        """
        A = np.vander(t, N=order+1, increasing=False)
        try:
            theta, residuals, rank, s = np.linalg.lstsq(A, y, rcond=None)
        except Exception:
            return None, None, None
        n = len(y)
        p = order + 1
        if n <= p:
            sigma2 = 0.0
        else:
            # lstsq leaves `residuals` empty for rank-deficient systems.
            ssr = residuals[0] if len(residuals) > 0 else np.sum((y - A.dot(theta))**2)
            sigma2 = ssr / max(1, (n - p))
        XtX = A.T.dot(A)
        try:
            cov_theta = sigma2 * np.linalg.inv(XtX)
        except np.linalg.LinAlgError:
            cov_theta = None
        return theta, cov_theta, sigma2

    @staticmethod
    def _design_row(order, t0):
        """Return ``[t0**order, ..., t0, 1]`` matching the coefficient order."""
        return np.array([t0**i for i in range(order, -1, -1)])

    def _predict_from_poly(self, theta, order, t0):
        """Evaluate the fitted polynomial at relative time ``t0``."""
        return float(theta.dot(self._design_row(order, t0)))

    def _propagate_covariance(self, cov_theta, order, t0):
        """Variance of the prediction at ``t0`` given the coefficient covariance.

        Returns ``None`` without a covariance, NaN when the propagated
        variance is NaN (rather than claiming zero uncertainty), and clamps
        small negative round-off to 0.0.
        """
        if cov_theta is None:
            return None
        H = self._design_row(order, t0)
        var = float(H.dot(cov_theta).dot(H.T))
        if math.isnan(var):
            return float("nan")
        return var if var >= 0.0 else 0.0

    def _fit_adaptive(self, t_rel, values):
        """Fit linear and quadratic models and keep the better one.

        The linear model wins when its residual variance does not exceed the
        quadratic one (plus 1e-9). If only one fit succeeds it is used.

        Returns:
            ``(theta, cov_theta, order)`` or ``None`` when both fits failed.
        """
        theta1, cov1, s2_1 = self._fit_poly(t_rel, values, 1)
        theta2, cov2, s2_2 = self._fit_poly(t_rel, values, 2)
        if theta1 is None and theta2 is None:
            return None
        if theta2 is None:
            return theta1, cov1, 1
        if theta1 is not None and s2_1 <= (s2_2 + 1e-9):
            return theta1, cov1, 1
        return theta2, cov2, 2

    def predict_hit(self):
        """Predict the next crossing of ``z_target``.

        The earliest future root of ``z(t) = z_target`` is used, regardless
        of whether z is rising or falling at that instant.

        Returns:
            ``None`` when fewer than ``min_points`` samples are buffered, a
            fit fails, or no future crossing exists; otherwise a dict with
            ``pred_ts``, ``pred_xy_z``, ``var_xyz`` (NaN where unknown),
            ``last_point`` and ``last_ts``.
        """
        if len(self.buf) < self.min_points:
            return None
        arr = np.array(self.buf)
        t = arr[:,0]
        # Times relative to the newest sample keep the fit well conditioned.
        t_rel = t - t[-1]
        z = arr[:,3]
        # fit quadratic z(t_rel) = a t^2 + b t + c
        order_z = 2
        theta_z, cov_z, sigma2_z = self._fit_poly(t_rel, z, order_z)
        if theta_z is None:
            return None
        a,b,c = theta_z[0], theta_z[1], theta_z[2]
        A = a; B = b; C = c - self.z_target
        if abs(A) < 1e-12:
            # Effectively linear in z.
            if abs(B) < 1e-12:
                return None
            roots = [-C / B]
        else:
            disc = B*B - 4*A*C
            if disc < 0:
                return None
            sqrt_d = math.sqrt(disc)
            roots = [(-B + sqrt_d) / (2*A), (-B - sqrt_d) / (2*A)]
        future_roots = [r for r in roots if r > 1e-6]
        if not future_roots:
            return None
        rel_hit = min(future_roots)
        abs_hit_ts = t[-1] + rel_hit

        fit_x = self._fit_adaptive(t_rel, arr[:,1])
        fit_y = self._fit_adaptive(t_rel, arr[:,2])
        if fit_x is None or fit_y is None:
            return None
        theta_x, cov_x, use_order_x = fit_x
        theta_y, cov_y, use_order_y = fit_y

        pred_x = self._predict_from_poly(theta_x, use_order_x, rel_hit)
        pred_y = self._predict_from_poly(theta_y, use_order_y, rel_hit)
        pred_z = self._predict_from_poly(theta_z, 2, rel_hit)
        var_x = self._propagate_covariance(cov_x, use_order_x, rel_hit)
        var_y = self._propagate_covariance(cov_y, use_order_y, rel_hit)
        var_z = self._propagate_covariance(cov_z, 2, rel_hit)
        return {
            "pred_ts": float(abs_hit_ts),
            "pred_xy_z": (float(pred_x), float(pred_y), float(pred_z)),
            "var_xyz": (var_x if var_x is not None else float("nan"),
                        var_y if var_y is not None else float("nan"),
                        var_z if var_z is not None else float("nan")),
            "last_point": (float(arr[-1,1]), float(arr[-1,2]), float(arr[-1,3])),
            "last_ts": float(arr[-1,0])
        }

    @staticmethod
    def _sigma(var):
        """Standard deviation for a variance; NaN when it is unknown."""
        if var is None or (isinstance(var, float) and math.isnan(var)):
            return float("nan")
        return math.sqrt(var)

    def compute_and_log(self):
        """Predict the hit, evaluate reachability and append a CSV row.

        The time available to the robot is ``pred_ts - now - lag``. When it
        is not positive the required velocity is NaN and ``can_reach`` is
        ``False``.

        Returns:
            ``None`` if no prediction is available, otherwise a dict with
            ``now``, ``last_point``, ``last_ts``, ``pred_ts``,
            ``pred_point``, ``var_xyz``, ``robot_pos``, ``vreq``,
            ``can_reach`` and ``notes``.
        """
        now = time.time()
        pred = self.predict_hit()
        if pred is None:
            return None
        pred_ts = pred["pred_ts"]
        px,py,pz = pred["pred_xy_z"]
        varx,vary,varz = pred["var_xyz"]
        last_x,last_y,last_z = pred["last_point"]
        last_ts = pred["last_ts"]
        time_remaining = pred_ts - now - self.lag_s
        notes = ""
        rx, ry = self.robot_pos
        if time_remaining <= 0:
            vreq_x = float("nan"); vreq_y = float("nan"); vreq_mag = float("nan")
            can_reach = False
            notes = "hit_time_passed_or_too_close"
        else:
            vreq_x = (px - rx) / time_remaining
            vreq_y = (py - ry) / time_remaining
            vreq_mag = math.hypot(vreq_x, vreq_y)
            can_reach = (vreq_mag <= self.v_max)
        row = [
            now,
            last_ts,
            last_x, last_y, last_z,
            pred_ts,
            px, py, pz,
            rx, ry, (self.robot_ts if self.robot_ts is not None else float("nan")),
            vreq_x, vreq_y, vreq_mag,
            bool(can_reach),
            self._sigma(varx),
            self._sigma(vary),
            self._sigma(varz),
            notes
        ]
        with open(self.log_path, "a", newline="") as f:
            w = csv.writer(f)
            w.writerow(row)
        return {
            "now": now,
            "last_point": (last_x,last_y,last_z),
            "last_ts": last_ts,
            "pred_ts": pred_ts,
            "pred_point": (px,py,pz),
            "var_xyz": (varx,vary,varz),
            "robot_pos": (rx,ry),
            "vreq": (vreq_x,vreq_y,vreq_mag),
            "can_reach": can_reach,
            "notes": notes
        }

# ---------- Triangulation pipeline ----------

# Triangulation helpers
def undistort_points_to_normalized(K, dist, pts):
    """Remove lens distortion from a set of image points.

    Args:
        K: Camera intrinsic matrix.
        dist: Distortion coefficients.
        pts: Sequence of ``(x, y)`` image points; may be empty.

    Returns:
        ``(N, 2)`` float64 array of undistorted points; ``(0, 2)`` for empty
        input.

    NOTE(review): ``P=K`` is passed to ``cv2.undistortPoints``, which per
    the OpenCV documentation re-projects the result with ``K`` — so despite
    the function name the output appears to be in pixel coordinates, not
    normalized ones. Confirm before relying on either interpretation.
    """
    if len(pts) == 0:
        return np.zeros((0, 2), dtype=np.float64)
    # cv2.undistortPoints expects an (N, 1, 2) array.
    stacked = np.array(pts, dtype=np.float64).reshape(-1, 1, 2)
    return cv2.undistortPoints(stacked, K, dist, P=K).reshape(-1, 2)

def triangulate_pair(P1, P2, pt1, pt2):
    """Triangulate one 3D point from a pair of 2D observations.

    Args:
        P1, P2: 3x4 projection matrices of the two cameras.
        pt1, pt2: ``(x, y)`` image points in the respective cameras.

    Returns:
        ``(3,)`` float64 array: the homogeneous solution of
        ``cv2.triangulatePoints`` divided by its fourth component.
    """
    col1 = np.array(pt1, dtype=np.float64).reshape(2, 1)
    col2 = np.array(pt2, dtype=np.float64).reshape(2, 1)
    homog = cv2.triangulatePoints(P1, P2, col1, col2)
    xyz = homog[:3, 0]
    w = homog[3, 0]
    return (xyz / w).astype(np.float64)

def reprojection_error(P, X_world, observed_xy):
    """Distance between a projected 3D point and its observed image point.

    Args:
        P: 3x4 projection matrix.
        X_world: 3D point (3 values).
        observed_xy: Observed ``(x, y)`` image point.

    Returns:
        Euclidean distance (float) between ``P`` applied to ``X_world``
        (after the perspective divide) and ``observed_xy``.
    """
    # Homogeneous column vector [X, Y, Z, 1]^T.
    homog = np.ones((4, 1), dtype=np.float64)
    homog[:3, 0] = X_world
    projected = P @ homog
    predicted_xy = projected[:2, 0] / projected[2, 0]
    residual = predicted_xy - np.array(observed_xy, dtype=np.float64)
    return float(np.linalg.norm(residual))

def match_and_triangulate(camera_candidates, calibrator, max_pairs=128, reproj_thresh_px=6.0, depth_min=0.03, depth_max=10.0):
    """Triangulate ball candidates across a camera pair and rank the matches.

    Every combination of the (up to 8) largest candidates of each camera is
    triangulated. A combination is kept when the point lies between
    ``depth_min`` and ``depth_max`` in front of both cameras and the summed
    reprojection error does not exceed ``2 * reproj_thresh_px``.

    Args:
        camera_candidates: Mapping camera name -> list of dicts with at
            least ``'centroid'`` (and optionally ``'area'``). The pair
            ``('kreo1', 'kreo2')`` is preferred; otherwise the first two
            cameras of the mapping are used.
        calibrator: Object providing ``projection_matrix(cam)`` and
            ``extrinsics[cam]`` with keys ``'R'`` and ``'tvec'``.
        max_pairs: Maximum number of accepted matches.
        reproj_thresh_px: Per-camera reprojection error budget in pixels.
        depth_min, depth_max: Accepted depth range in each camera frame.

    Returns:
        List of dicts (``pt``, ``reproj_err``, ``pair``, ``cam_pixels``,
        ``depths``) sorted by ascending ``reproj_err``; empty when fewer
        than two cameras are present, a projection matrix is missing, or
        nothing matches.

    NOTE(review): centroids are used as detected (no lens undistortion)
    while the projection matrices model an ideal pinhole camera — confirm
    whether the points should be undistorted first.
    """
    cams = list(camera_candidates.keys())
    if len(cams) < 2:
        return []
    if 'kreo1' in camera_candidates and 'kreo2' in camera_candidates:
        c1, c2 = 'kreo1', 'kreo2'
    else:
        c1, c2 = cams[0], cams[1]

    def _largest(cands):
        # Largest blobs first; a missing area counts as 1.
        return sorted(cands, key=lambda x: -x.get('area', 1))[:8]

    cand1 = _largest(camera_candidates[c1])
    cand2 = _largest(camera_candidates[c2])
    P1 = calibrator.projection_matrix(c1)
    P2 = calibrator.projection_matrix(c2)
    if P1 is None or P2 is None:
        return []
    if not cand1 or not cand2:
        return []

    # Loop-invariant camera models, looked up once instead of per pair.
    e1 = calibrator.extrinsics[c1]
    R1, t1 = e1['R'], e1['tvec']
    e2 = calibrator.extrinsics[c2]
    R2, t2 = e2['R'], e2['tvec']
    max_total_err = 2 * reproj_thresh_px

    results = []
    # A single flat iteration lets the max_pairs cap end the whole search.
    for a, b in ((a, b) for a in cand1 for b in cand2):
        if len(results) >= max_pairs:
            break
        x1 = a['centroid']
        x2 = b['centroid']
        try:
            Xw = triangulate_pair(P1, P2, x1, x2)
        except Exception:
            continue
        Xw_col = Xw.reshape(3, 1)
        depth1 = float((R1 @ Xw_col + t1)[2, 0])
        depth2 = float((R2 @ Xw_col + t2)[2, 0])
        # Written as `not (in range)` so NaN depths are rejected as well.
        if not (depth_min < depth1 < depth_max and depth_min < depth2 < depth_max):
            continue
        tot = reprojection_error(P1, Xw, x1) + reprojection_error(P2, Xw, x2)
        if tot > max_total_err:
            continue
        results.append({'pt': Xw, 'reproj_err': tot, 'pair': (c1, c2), 'cam_pixels': (x1, x2), 'depths': (depth1, depth2)})

    results.sort(key=lambda r: r['reproj_err'])
    return results

# ---------- ZMQ subscriber setup ----------
ctx = zmq.Context()
sub = ctx.socket(zmq.SUB)
sub.connect(ZMQ_ADDR)
# NOTE(review): the options below are applied after connect(). The libzmq
# documentation states that most socket options only affect subsequent
# connects, and that ZMQ_CONFLATE does not support multipart messages
# (this script reads with recv_multipart). Confirm the intended effect
# before reordering or removing any of these calls.
sub.setsockopt(zmq.RCVHWM, 1)
sub.setsockopt(zmq.CONFLATE, 1)
# LINGER=0: per the libzmq docs, pending messages are discarded at close.
sub.setsockopt(zmq.LINGER, 0)

# flush existing
# This drain runs before any SUBSCRIBE filter is installed below.
flushed = 0
while True:
    try:
        sub.recv_multipart(flags=zmq.NOBLOCK)
        flushed += 1
    except zmq.Again:
        break
if flushed > 0:
    print(f"[Subscriber] Flushed {flushed} stale messages.")
# Subscribe to one topic per camera.
for t in SUB_TOPICS:
    sub.setsockopt(zmq.SUBSCRIBE, t)

# per-camera structures
class _FrameFanout:
    """Latest-frame slot that hands every frame to each consumer thread.

    A single shared ``queue.Queue`` would give each frame to only one of the
    detector threads reading from it. This object keeps one size-1 queue per
    consumer and exposes the ``get_nowait`` / ``put_nowait`` calls the
    producer uses to replace a stale frame with a fresh one.
    """

    def __init__(self, n_consumers):
        self.consumer_queues = [queue.Queue(maxsize=1) for _ in range(n_consumers)]

    def get_nowait(self):
        """Discard frames the consumers have not picked up yet.

        Returns the last discarded item; raises ``queue.Empty`` when every
        consumer queue was already empty.
        """
        dropped = None
        found = False
        for q in self.consumer_queues:
            try:
                dropped = q.get_nowait()
                found = True
            except queue.Empty:
                pass
        if not found:
            raise queue.Empty
        return dropped

    def put_nowait(self, item):
        """Offer ``item`` to every consumer; raises ``queue.Full`` if any slot was occupied."""
        any_full = False
        for q in self.consumer_queues:
            try:
                q.put_nowait(item)
            except queue.Full:
                any_full = True
        if any_full:
            raise queue.Full


frames = {}
fps_windows = defaultdict(deque)
# Two consumers per camera: the AprilTag thread and the ball thread.
frame_queues = {t.decode(): _FrameFanout(2) for t in SUB_TOPICS}
detect_cache = {}
detect_lock = threading.Lock()
tag_threads={}
ball_threads={}

# shared outputs
calibrator = StaticCalibrator(TAG_WORLD_MAP,TAG_SIZES)
shared_camera_candidates = defaultdict(list)
shared_robot_poses = {}
last_triangulated_ball = None

# Instantiate threads for each camera
for t in SUB_TOPICS:
    cam_name = t.decode()
    # Each detector gets its own queue so both see every published frame.
    tag_queue, ball_queue = frame_queues[cam_name].consumer_queues
    tag_threads[cam_name] = AprilTagThread(cam_name,tag_queue,detect_cache,detect_lock,calibrator,shared_robot_poses)
    ball_threads[cam_name] = BallThread(cam_name,ball_queue,detect_cache,detect_lock,shared_camera_candidates)
    tag_threads[cam_name].start()
    ball_threads[cam_name].start()

print("[Subscriber] connected, waiting for frames... (Press ESC to exit)")

last_show = time.time()
predictor = SlidingWindowPredictor(z_target=Z_TARGET, buffer_size=BUFFER_SIZE, min_points=MIN_POINTS, lag_ms=LAG_MS, v_max=V_MAX, log_path=PRED_LOG)

try:
    while True:
        parts=recv_latest(sub)
        if parts is None: continue
        topic=parts[0]; cam=topic.decode()
        if len(parts)>=3:
            ts_part=parts[1]; jpg_part=parts[2]
        else:
            ts_part=None; jpg_part=parts[1]
        recv_t=time.time()
        try: cam_ts=float(ts_part.decode()) if ts_part else recv_t
        except: cam_ts=recv_t
        img=cv2.imdecode(np.frombuffer(jpg_part,np.uint8),cv2.IMREAD_COLOR)
        if img is None: continue
        fps=update_fps(cam,cam_ts)
        frames[cam]={"img":img,"cam_ts":cam_ts,"fps":fps}
        fq=frame_queues[cam]
        try: fq.get_nowait()
        except: pass
        try: fq.put_nowait((img.copy(),cam_ts))
        except: pass
        with detect_lock:
            ready = all(c in calibrator.extrinsics for c in [t.decode() for t in SUB_TOPICS])
        if ready:
            with detect_lock:
                tri_results = match_and_triangulate(shared_camera_candidates, calibrator)
            if len(tri_results) > 0:
                best = tri_results[0]
                Xw = best['pt'].ravel()
                now_tri_ts = time.time()
                last_triangulated_ball = {'pos': Xw.copy(), 'err': best['reproj_err'], 'ts': now_tri_ts, 'pair': best['pair'], 'depths':best['depths']}
                # add to sliding predictor
                predictor.add_point(now_tri_ts, float(Xw[0]), float(Xw[1]), float(Xw[2]))
                # update robot pos from shared_robot_poses if available (choose first tag entry)
                with detect_lock:
                    if shared_robot_poses:
                        # pick the first available tag's world_pos
                        first_tid = next(iter(shared_robot_poses.keys()))
                        rpos = shared_robot_poses[first_tid].get('world_pos', None)
                        if rpos is not None:
                            predictor.update_robot(float(rpos[0]), float(rpos[1]), robot_ts=shared_robot_poses[first_tid].get('ts', None))
                    else:
                        # keep previous robot pos
                        pass
                out = predictor.compute_and_log()
                if out is not None:
                    # print summary
                    px,py,pz = out['pred_point']
                    vreq = out['vreq']
                    can = out['can_reach']
                    print(f"[PRED] now={fmt_ts(out['now'])} pred_ts={fmt_ts(out['pred_ts'])} pred_xy=({px:.3f},{py:.3f}) vreq={vreq[2]:.3f} can_reach={can}")
        # visualization
        if all(k in frames for k in [t.decode() for t in SUB_TOPICS]):
            cams=[t.decode() for t in SUB_TOPICS]
            L=frames[cams[0]]; R=frames[cams[1]]
            drift_ms=abs(L["cam_ts"]-R["cam_ts"])*1000.0
            def overlay(F,cam_name):
                """Return an annotated copy of one camera frame for the tiled view.

                F is the per-camera frame record; its "img", "fps" and "cam_ts"
                entries are read here. cam_name selects the detect_cache entry
                whose tag and ball detections are drawn. All drawing happens on
                a copy of F["img"], which is what gets returned.
                """
                font = cv2.FONT_HERSHEY_SIMPLEX
                canvas = F["img"].copy()

                def put(text, origin, scale, colour, thickness):
                    # Thin wrapper so every label shares the same font and target image.
                    cv2.putText(canvas, text, origin, font, scale, colour, thickness)

                # Fixed header: camera name, measured FPS, camera timestamp.
                put(f"{cam_name}", (10, 20), 0.6, (0,32,20), 2)
                put(f"FPS:{F['fps']:.1f}", (10, 46), 0.6, (14,117,5), 2)
                put(f"cam_ts:{fmt_ts(F['cam_ts'])}", (10, 72), 0.5, (5,12,117), 1)

                # Box colour per ball class; any other class uses the fallback colour.
                palette = {"orange": (0,200,255), "purple": (200,0,200)}

                # Cached detections and shared poses are read under detect_lock
                # for the whole drawing pass.
                with detect_lock:
                    cached = detect_cache.get(cam_name,{})

                    if "tags" in cached:
                        tag_block = cached["tags"]
                        if tag_block["ids"] is not None:
                            cv2.aruco.drawDetectedMarkers(canvas, tag_block["corners"], tag_block["ids"])

                    if "balls" in cached:
                        for idx, det in enumerate(cached["balls"]):
                            bx, by, bw, bh = det["bbox"]
                            centre_x, centre_y = det["centroid"]
                            label = det["color"]
                            outline = palette.get(label, (0,200,200))
                            cv2.rectangle(canvas, (bx,by), (bx+bw,by+bh), outline, 2)
                            cv2.circle(canvas, (centre_x,centre_y), 4, (0,0,255), -1)
                            put(f"{label}:{idx}", (bx,by-6), 0.5, outline, 2)

                    # World-frame read-outs start below the header and stack downwards.
                    row_y = 80
                    if shared_robot_poses:
                        for tid, info in shared_robot_poses.items():
                            world = info.get('world_pos')
                            if world is None:
                                continue
                            put(f"Tag{tid} world: {world[0]:.3f},{world[1]:.3f},{world[2]:.3f}",
                                (10, row_y), 0.5, (255,255,0), 1)
                            row_y += 18

                    # The triangulated ball position is shown on the first camera only.
                    if cam_name == cams[0] and last_triangulated_ball is not None:
                        ball_xyz = last_triangulated_ball['pos']
                        put(f"Ball world: {ball_xyz[0]:.3f},{ball_xyz[1]:.3f},{ball_xyz[2]:.3f} err:{last_triangulated_ball['err']:.2f}",
                            (10, row_y), 0.5, (0,255,255), 1)
                return canvas
            if VISUALIZE and (time.time()-last_show)>(1.0/DISPLAY_FPS):
                # Refresh the window at most DISPLAY_FPS times per second.
                last_show=time.time()
                left_im=overlay(L,cams[0])
                right_im=overlay(R,cams[1])
                # np.hstack requires both images to have the same height. Only the
                # right image is resized, so it must take the LEFT image's size;
                # sizing it to max(left, right) height made hstack raise whenever
                # the right frame was taller than the left one.
                h=left_im.shape[0]
                # cv2.resize takes dsize as (width, height).
                right_res=cv2.resize(right_im,(left_im.shape[1],h))
                tile=np.hstack([left_im,right_res])
                cv2.putText(tile,f"Drift:{drift_ms:.1f}ms",(10,20),cv2.FONT_HERSHEY_SIMPLEX,0.6,(0,0,255),2)
                cv2.putText(tile, f"Host now: {fmt_ts(time.time())}", (10, 44), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255,255,255), 1)
                cv2.imshow("Combined Detection",tile)
            elif not VISUALIZE:
                # Headless mode: print a single status line that is rewritten in
                # place (leading "\r") instead of opening a window.
                with detect_lock:
                    tag_summary = []
                    parts_status = []
                    for c in cams:
                        block = detect_cache.get(c,{})
                        # Tag summary: detected ids, "No tags" when the detector ran
                        # but found none, or "No detection data" when nothing is cached.
                        if "tags" in block and block["tags"]["ids"] is not None:
                            ids=block["tags"]["ids"]
                            tag_summary.append(f"{c}: {ids.flatten().tolist()}")
                        elif "tags" in block:
                            tag_summary.append(f"{c}: No tags")
                        else:
                            tag_summary.append(f"{c}: No detection data")
                        # Ball summary: per-colour counts; only orange and purple are reported.
                        if "balls" in block:
                            counts = { "orange":0, "purple":0, "unknown":0 }
                            for d in block["balls"]:
                                counts[d.get("color","unknown")] = counts.get(d.get("color","unknown"),0) + 1
                            parts_status.append(f"{c}: Orange: {counts['orange']} Purple: {counts['purple']}")
                        else:
                            parts_status.append(f"{c}:NoBall")    
                status = (
                    f"Drift {drift_ms:.1f} ms | "
                    f"{cams[0]} ts: {fmt_ts(L['cam_ts'])} | "
                    f"{cams[1]} ts: {fmt_ts(R['cam_ts'])} | "
                    f"Host now: {fmt_ts(time.time())} | "
                    f"{cams[0]} FPS: {L['fps']:.1f} | "
                    f"{cams[1]} FPS: {R['fps']:.1f} | "
                    f"Tags:" + ",".join(tag_summary)
                )
                # Trailing padding overwrites leftovers from a longer previous line.
                sys.stdout.write("\r" + status + " " * 20 + " | ".join(parts_status) + " " * 20)
                sys.stdout.flush()
        if cv2.waitKey(1)&0xFF==27:
            break

except KeyboardInterrupt:
    pass

finally:
    for t in tag_threads.values(): t.stop()
    for t in ball_threads.values(): t.stop()
    time.sleep(0.1)
    cv2.destroyAllWindows()
    sub.close()
    ctx.term()
    print("Exit clean.")





[kreo1] AprilTagThread started
[kreo1] BallThread started
[kreo2] AprilTagThread started
[kreo2] BallThread started
[Subscriber] connected, waiting for frames... (Press ESC to exit)
[INFO] Loaded calibrated camera parameters for kreo2
[Calib] extrinsics computed for kreo2: tvec=[-0.36441388  0.67253774  2.31151051]
[INFO] Loaded calibrated camera parameters for kreo1
[Calib] extrinsics computed for kreo1: tvec=[ 0.19112282 -0.1047647   2.91082942]
[INFO] kreo1 tag4 measured z 0.018, nominal 0.075
tid: 4, pos: [0.2563772  0.37775576 0.01796831]
[PRED] now=2025-11-18 19:37:35.813 pred_ts=2025-11-18 19:37:36.224 pred_xy=(1.129,-0.492) vreq=3.116 can_reach=False
[WARN] kreo1 tag4 mean_corner_z negative (-0.0065). Flipping Z to 0.0065
[INFO] kreo1 tag4 measured z 0.006, nominal 0.075
tid: 4, pos: [0.24723829 0.3632867  0.00646951]
[PRED] now=2025-11-18 19:37:35.846 pred_ts=2025-11-18 19:37:36.831 pred_xy=(1.035,-0.049) vreq=0.917 can_reach=True
[PRED] now=2025-11-18 19:37:35.857 pred_ts=202