In [3]:
# object_tracking_occlusion_aware.py
# کد بهبود یافته
import cv2
import numpy as np
import itertools
import time
from scipy.optimize import linear_sum_assignment
from collections import deque
from filterpy.kalman import KalmanFilter
from ultralytics import YOLO

# --- Algorithm constants ---

# Detection filtering: minimum YOLO confidence and minimum box area (pixels^2).
CONFIDENCE_THRESHOLD = 0.6
MIN_DETECTION_AREA = 4000

# Association cost = IOU_COST_WEIGHT * (1 - IoU) + APPEARANCE_COST_WEIGHT * (1 - best histogram correlation).
# Pairs whose (1 - IoU) exceeds MAX_DISTANCE_GATE are not considered at all.
IOU_COST_WEIGHT = 0.6
APPEARANCE_COST_WEIGHT = 0.4
MAX_DISTANCE_GATE = 0.9

# Track lifecycle.
MAX_FRAMES_TO_SKIP = 30      # consecutive unmatched frames before a track is dropped
MIN_HITS_TO_CONFIRM = 3      # matches needed before a new track is considered confirmed
MAX_LOST_TRACKER_AGE = 120   # used as the maxlen of the lost-track pool (a count, not an age)

# [Key change] Occlusion-handling parameters.
OCCLUSION_IOU_THRESHOLD = 0.1  # IoU threshold above which an unmatched track counts as occluded
GALLERY_MAX_SIZE = 15          # max appearance histograms kept per track
REID_CORRELATION_THRESHOLD = 0.82  # min histogram correlation to revive a lost track

# Drawing colour (BGR) and HSV histogram bin counts.
COLOR_ACTIVE = (0, 255, 0)
HISTOGRAM_BINS = [8, 8, 8]

def calculate_iou(boxA, boxB):
    """Return the intersection-over-union of two boxes.

    Each box is indexed as ``(x, y, w, h)``: top-left corner followed by width
    and height. Returns 0.0 when the union area is zero.
    """
    left = max(boxA[0], boxB[0])
    top = max(boxA[1], boxB[1])
    right = min(boxA[0] + boxA[2], boxB[0] + boxB[2])
    bottom = min(boxA[1] + boxA[3], boxB[1] + boxB[3])

    # Non-overlapping boxes give a negative extent, clamped to zero.
    overlap_w = max(0, right - left)
    overlap_h = max(0, bottom - top)
    intersection = overlap_w * overlap_h

    union = float(boxA[2] * boxA[3] + boxB[2] * boxB[3] - intersection)
    if union == 0:
        return 0.0
    return intersection / union

class ObjectTracker:
    """A single object track: constant-velocity Kalman filter plus appearance gallery.

    Kalman state (8x1): ``[cx, cy, w, h, d_cx, d_cy, d_w, d_h]``, i.e. box centre
    and size followed by their rates of change. Measurements (4x1) are
    ``[cx, cy, w, h]``. Lifecycle states used: 'TENTATIVE', 'CONFIRMED', 'OCCLUDED'.
    """

    def __init__(self, tracker_id, bbox, frame, dt):
        """Create a new track from its first detection.

        Args:
            tracker_id: Identifier for this track.
            bbox: First detection as ``(x, y, w, h)`` (top-left corner + size).
            frame: BGR image the detection was taken from (used for the histogram).
            dt: Time step used to build the initial transition matrix.
        """
        self.id = tracker_id
        self.kf = KalmanFilter(dim_x=8, dim_z=4)
        # Identity transition; update_F writes dt into the position/velocity coupling.
        self.kf.F = np.eye(8)
        # Only centre and size are observed; the rates are hidden states.
        self.kf.H = np.zeros((4, 8))
        for i in range(4):
            self.kf.H[i, i] = 1
        # Scale process noise more strongly on the rate terms than on position/size.
        self.kf.Q[4:, 4:] *= 10.0
        self.kf.Q[:4, :4] *= 1.0
        self.kf.R *= 5.0
        self.kf.P *= 10.0
        self.update_F(dt)
        self.histogram_gallery = deque(maxlen=GALLERY_MAX_SIZE)
        self.re_activate(bbox, frame)
        # A brand-new track must collect MIN_HITS_TO_CONFIRM matches before confirmation.
        self.state = 'TENTATIVE'

    def update_F(self, dt):
        """Write ``dt`` into the transition matrix so each quantity advances by rate * dt."""
        for i in range(4):
            self.kf.F[i, i + 4] = dt

    @staticmethod
    def extract_histogram(frame, bbox):
        """Return a flattened HSV colour histogram of ``bbox`` inside ``frame``.

        The histogram is normalised in place with ``cv2.normalize``'s default (L2)
        norm. Returns None when the box covers no pixels of the frame.
        """
        x, y, w, h = map(int, bbox)
        # Clip to the image: a negative start index would otherwise wrap around
        # to the opposite edge in NumPy slicing and select the wrong region.
        x0, y0 = max(x, 0), max(y, 0)
        x1, y1 = min(x + w, frame.shape[1]), min(y + h, frame.shape[0])
        if x1 <= x0 or y1 <= y0:
            return None
        roi = frame[y0:y1, x0:x1]
        hsv_roi = cv2.cvtColor(roi, cv2.COLOR_BGR2HSV)
        hist = cv2.calcHist([hsv_roi], [0, 1, 2], None, HISTOGRAM_BINS, [0, 180, 0, 256, 0, 256])
        cv2.normalize(hist, hist)
        return hist.flatten()

    def predict(self, dt):
        """Advance the filter by ``dt`` and return the predicted box as ``(x, y, w, h)``.

        The miss counter is not incremented while the track is OCCLUDED, so an
        occluded track does not age.
        """
        self.update_F(dt)
        self.kf.predict()
        if self.state != 'OCCLUDED':
            self.frames_since_update += 1
        cx, cy, w, h = self.kf.x[:4, 0]
        return (cx - w / 2, cy - h / 2, w, h)

    def update(self, bbox, frame):
        """Correct the filter with a matched detection and extend the appearance gallery.

        A TENTATIVE track is promoted to CONFIRMED only once it has at least
        MIN_HITS_TO_CONFIRM hits; CONFIRMED and OCCLUDED tracks become CONFIRMED.
        """
        self.frames_since_update = 0
        self.hits += 1
        # Forcing CONFIRMED on every match would make MIN_HITS_TO_CONFIRM ineffective.
        if self.state != 'TENTATIVE' or self.hits >= MIN_HITS_TO_CONFIRM:
            self.state = 'CONFIRMED'
        x, y, w, h = bbox
        measurement = np.array([x + w / 2, y + h / 2, w, h]).reshape((4, 1))
        self.kf.update(measurement)
        self.last_bbox = bbox
        new_hist = ObjectTracker.extract_histogram(frame, bbox)
        if new_hist is not None:
            self.histogram_gallery.append(new_hist)

    def re_activate(self, bbox, frame):
        """(Re)initialise the track from a detection.

        Resets the counters, sets the Kalman state to the box centre/size with zero
        rates, and restarts the appearance gallery from this detection. Leaves the
        state CONFIRMED; ``__init__`` overrides it to TENTATIVE for new tracks.
        """
        self.frames_since_update = 0
        self.hits = 1
        self.state = 'CONFIRMED'
        x, y, w, h = bbox
        self.kf.x = np.array([x + w / 2, y + h / 2, w, h, 0, 0, 0, 0]).reshape((8, 1))
        self.last_bbox = bbox
        self.histogram_gallery.clear()
        new_hist = ObjectTracker.extract_histogram(frame, bbox)
        if new_hist is not None:
            self.histogram_gallery.append(new_hist)
        
# --- Main program ---


def _best_gallery_correlation(gallery, det_hist):
    """Return the highest HISTCMP_CORREL score of ``det_hist`` against ``gallery``.

    Returns None when the gallery is empty or the detection has no histogram.
    """
    if not gallery or det_hist is None:
        return None
    return max(cv2.compareHist(gal_hist, det_hist, cv2.HISTCMP_CORREL) for gal_hist in gallery)


# Cost placed in gated-out (tracker, detection) cells; real pair costs are far smaller.
GATED_COST = 1e6
VIDEO_PATH = "./assets/footage/person4.mp4"

# Load the model before opening the capture so a load failure cannot leak the capture.
model = YOLO("yolov8n.pt")
cap = cv2.VideoCapture(VIDEO_PATH)
if not cap.isOpened():
    print(f"Warning: could not open video source {VIDEO_PATH}")
fps = cap.get(cv2.CAP_PROP_FPS)
# Target per-frame display delay in ms (0 when the container reports no FPS).
frame_delay = int(1000 / fps) if fps > 0 else 0
active_trackers = []
# Pool of dropped tracks available for re-identification; maxlen bounds how many are kept.
lost_trackers = deque(maxlen=MAX_LOST_TRACKER_AGE)
id_counter = itertools.count()
prev_time = time.time()

try:
    while cap.isOpened():
        start_time_processing = time.time()
        dt = start_time_processing - prev_time
        prev_time = start_time_processing
        ret, frame = cap.read()
        if not ret:
            break

        # 1) Detect class 0 (person in the default COCO-trained weights) and keep
        #    confident, large-enough boxes converted from xyxy to (x, y, w, h).
        results = model(frame, verbose=False, classes=[0])[0]
        detections = [
            (int(b[0]), int(b[1]), int(b[2] - b[0]), int(b[3] - b[1]))
            for r in results.boxes
            if r.conf[0] > CONFIDENCE_THRESHOLD
            and (r.xyxy[0][2] - r.xyxy[0][0]) * (r.xyxy[0][3] - r.xyxy[0][1]) > MIN_DETECTION_AREA
            for b in r.xyxy
        ]
        # Appearance descriptors depend only on the detection: compute each once per frame.
        det_hists = [ObjectTracker.extract_histogram(frame, det) for det in detections]
        predicted_boxes = [t.predict(dt) for t in active_trackers]

        # 2) Associate active tracks with detections (gated IoU + appearance cost).
        matched_trk_indices, matched_det_indices = set(), set()
        if detections and active_trackers:
            cost_matrix = np.full((len(active_trackers), len(detections)), GATED_COST)
            for t_idx, tracker in enumerate(active_trackers):
                for d_idx, det in enumerate(detections):
                    iou_cost = 1 - calculate_iou(predicted_boxes[t_idx], det)
                    if iou_cost > MAX_DISTANCE_GATE:
                        continue
                    corr = _best_gallery_correlation(tracker.histogram_gallery, det_hists[d_idx])
                    hist_cost = 1.0 if corr is None else 1 - corr
                    cost_matrix[t_idx, d_idx] = (IOU_COST_WEIGHT * iou_cost) + (APPEARANCE_COST_WEIGHT * hist_cost)

            trk_indices, det_indices = linear_sum_assignment(cost_matrix)
            for t_idx, d_idx in zip(trk_indices, det_indices):
                # Reject assignments the solver could only make through gated cells.
                if cost_matrix[t_idx, d_idx] < GATED_COST / 10:
                    active_trackers[t_idx].update(detections[d_idx], frame)
                    matched_trk_indices.add(t_idx)
                    matched_det_indices.add(d_idx)

        # 3) Occlusion handling: an unmatched track whose prediction overlaps a
        #    matched track's box is marked OCCLUDED; otherwise it leaves OCCLUDED.
        unmatched_trk_indices = set(range(len(active_trackers))) - matched_trk_indices
        for t_idx in unmatched_trk_indices:
            tracker = active_trackers[t_idx]
            overlaps_matched = any(
                calculate_iou(predicted_boxes[t_idx], active_trackers[m_idx].last_bbox) > OCCLUSION_IOU_THRESHOLD
                for m_idx in matched_trk_indices
            )
            if overlaps_matched:
                tracker.state = 'OCCLUDED'
            elif tracker.state == 'OCCLUDED':
                tracker.state = 'CONFIRMED'

        # 4) Re-identification: match leftover detections to lost tracks by appearance.
        #    A single sorted list keeps matrix columns and detection indices aligned.
        unmatched_det_list = sorted(set(range(len(detections))) - matched_det_indices)
        if unmatched_det_list and lost_trackers:
            reid_cost_matrix = np.ones((len(lost_trackers), len(unmatched_det_list)))
            for lt_idx, lost_tracker in enumerate(lost_trackers):
                for ud_idx, d_idx in enumerate(unmatched_det_list):
                    corr = _best_gallery_correlation(lost_tracker.histogram_gallery, det_hists[d_idx])
                    reid_cost_matrix[lt_idx, ud_idx] = 1 - (0.0 if corr is None else corr)

            lt_indices, ud_indices = linear_sum_assignment(reid_cost_matrix)
            revived_lt_indices = set()
            for lt_idx, ud_idx in zip(lt_indices, ud_indices):
                if reid_cost_matrix[lt_idx, ud_idx] < (1 - REID_CORRELATION_THRESHOLD):
                    d_idx = unmatched_det_list[ud_idx]
                    lost_tracker = lost_trackers[lt_idx]
                    lost_tracker.re_activate(detections[d_idx], frame)
                    active_trackers.append(lost_tracker)
                    revived_lt_indices.add(lt_idx)
                    matched_det_indices.add(d_idx)
            if revived_lt_indices:
                lost_trackers = deque(
                    (lt for i, lt in enumerate(lost_trackers) if i not in revived_lt_indices),
                    maxlen=MAX_LOST_TRACKER_AGE,
                )

        # 5) Start new tracks for detections that are still unassigned.
        for d_idx in sorted(set(range(len(detections))) - matched_det_indices):
            active_trackers.append(ObjectTracker(next(id_counter), detections[d_idx], frame, dt))

        # 6) Prune stale tracks and draw confirmed tracks matched in this frame.
        next_active_trackers = []
        for tracker in active_trackers:
            # Occluded tracks are kept regardless of how long they have been unmatched.
            if tracker.frames_since_update <= MAX_FRAMES_TO_SKIP or tracker.state == 'OCCLUDED':
                next_active_trackers.append(tracker)
                if tracker.state == 'CONFIRMED' and tracker.frames_since_update == 0:
                    x1, y1, w, h = map(int, tracker.last_bbox)
                    cv2.rectangle(frame, (x1, y1), (x1 + w, y1 + h), COLOR_ACTIVE, 2)
                    cv2.putText(frame, f"ID {tracker.id}", (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, COLOR_ACTIVE, 2)
            elif tracker.state != 'TENTATIVE':
                # Only tracks that were ever confirmed enter the re-ID pool; a track
                # that never left TENTATIVE is likely a spurious detection.
                lost_trackers.append(tracker)
        active_trackers = next_active_trackers

        fps_proc = 1 / dt if dt > 0 else 0
        cv2.putText(frame, f"FPS: {int(fps_proc)}", (20, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
        cv2.imshow("Occlusion-Aware Tracking", frame)

        # Wait whatever is left of the frame budget (at least 1 ms so the window refreshes).
        processing_time_ms = (time.time() - start_time_processing) * 1000
        wait_time = max(1, frame_delay - int(processing_time_ms))
        if cv2.waitKey(wait_time) & 0xFF == ord('q'):
            break
finally:
    # Release resources even when the loop is interrupted (e.g. KeyboardInterrupt
    # in a notebook); otherwise the capture and the display window stay open.
    cap.release()
    cv2.destroyAllWindows()

KeyboardInterrupt: 