# Name - Parikshit Sahu
## CodeAlpha Internship Task 4
## Object Detection and Tracking

# In [1]:
import torch
import torchvision
import cv2
import numpy as np
from scipy.spatial import distance

class CentroidTracker:
    """Track objects across frames by matching bounding-box centroids.

    Each call to :meth:`update` pairs the centroids of the new detections with
    the centroids of the objects already being tracked. It takes the closest
    object/detection pairs first (greedy nearest-neighbour matching).

    - Matched objects take the new centroid.
    - Unmatched objects accumulate a "disappeared" counter and are
      deregistered once it exceeds ``max_disappeared``.
    - Unmatched detections are registered under fresh IDs.
    """

    def __init__(self, max_disappeared=50, max_distance=None):
        """Create an empty tracker.

        Args:
            max_disappeared: Number of consecutive unmatched updates an object
                survives. It is deregistered once its counter exceeds this.
            max_distance: Optional maximum Euclidean distance between an
                object's centroid and a detection's centroid (same units as
                the box coordinates) for the two to be matched. Pairs farther
                apart are treated as a lost object plus a new object. ``None``
                (the default) disables the limit.
        """
        self.nextObjectID = 0  # Next object ID to assign
        self.objects = {}  # Maps object IDs to centroids
        self.disappeared = {}  # Maps object IDs to consecutive unmatched updates
        self.max_disappeared = max_disappeared
        self.max_distance = max_distance

    def register(self, centroid):
        """Start tracking ``centroid`` under the next free object ID."""
        self.objects[self.nextObjectID] = centroid
        self.disappeared[self.nextObjectID] = 0
        self.nextObjectID += 1

    def deregister(self, objectID):
        """Stop tracking ``objectID`` (raises KeyError if it is not tracked)."""
        del self.objects[objectID]
        del self.disappeared[objectID]

    def _mark_disappeared(self, objectIDs):
        """Bump the disappeared counter of each ID; drop IDs past the limit."""
        for objectID in objectIDs:
            self.disappeared[objectID] += 1
            if self.disappeared[objectID] > self.max_disappeared:
                self.deregister(objectID)

    def update(self, rects):
        """Update the tracks with one frame's detections.

        Args:
            rects: Sequence of ``(startX, startY, endX, endY)`` boxes.

        Returns:
            The live ``{objectID: centroid}`` dict (``self.objects``). Each
            centroid is an integer ``(x, y)`` numpy array.
        """
        if len(rects) == 0:
            # Nothing detected: every tracked object goes unmatched this frame.
            self._mark_disappeared(list(self.disappeared))
            return self.objects

        # Integer centroid of every box (int() truncates toward zero).
        input_centroids = np.zeros((len(rects), 2), dtype="int")
        for i, (startX, startY, endX, endY) in enumerate(rects):
            input_centroids[i] = (
                int((startX + endX) / 2.0),
                int((startY + endY) / 2.0),
            )

        if not self.objects:
            for centroid in input_centroids:
                self.register(centroid)
            return self.objects

        objectIDs = list(self.objects)
        object_centroids = np.array(list(self.objects.values()))

        # D[r, c] is the distance from tracked object r to detection c.
        D = distance.cdist(object_centroids, input_centroids)
        # Visit objects in order of how close their nearest detection is. Each
        # object is paired with that nearest detection unless either side has
        # already been claimed.
        rows = D.min(axis=1).argsort()
        cols = D.argmin(axis=1)[rows]

        used_rows, used_cols = set(), set()
        for row, col in zip(rows, cols):
            if row in used_rows or col in used_cols:
                continue
            if self.max_distance is not None and D[row, col] > self.max_distance:
                # Too far apart to be the same object.
                continue
            objectID = objectIDs[row]
            self.objects[objectID] = input_centroids[col]
            self.disappeared[objectID] = 0
            used_rows.add(row)
            used_cols.add(col)

        unused_rows = set(range(D.shape[0])) - used_rows
        unused_cols = set(range(D.shape[1])) - used_cols

        # Objects without a detection this frame.
        self._mark_disappeared([objectIDs[row] for row in unused_rows])
        # Detections without an object become new tracks.
        for col in unused_cols:
            self.register(input_centroids[col])

        return self.objects

# Load Faster R-CNN model
def load_fasterrcnn_model():
    """Build a pretrained Faster R-CNN (ResNet-50 FPN) in inference mode.

    On torchvision >= 0.13 this uses the ``weights=`` API, because
    ``pretrained=True`` is deprecated there. Older releases that lack the
    weights enums fall back to ``pretrained=True``. Either path selects the
    same default pretrained weights, which torchvision may download on first
    use.

    Returns:
        The detection model, switched to ``eval()`` mode.
    """
    detection = torchvision.models.detection
    weights_enum = getattr(detection, "FasterRCNN_ResNet50_FPN_Weights", None)
    if weights_enum is not None:
        model = detection.fasterrcnn_resnet50_fpn(weights=weights_enum.DEFAULT)
    else:  # torchvision < 0.13 has no weights enums
        model = detection.fasterrcnn_resnet50_fpn(pretrained=True)
    model.eval()
    return model

# Detect objects using Faster R-CNN
def detect_objects(frame, model, detection_threshold=0.5, is_bgr=True):
    """Run a torchvision detection model on one frame and keep confident boxes.

    Args:
        frame: ``H x W x 3`` image array, e.g. a frame from
            ``cv2.VideoCapture.read``.
        model: torchvision detection model in eval mode, such as the one
            returned by ``load_fasterrcnn_model``.
        detection_threshold: Boxes whose score is not strictly greater than
            this are discarded.
        is_bgr: ``True`` (default) if ``frame`` is in OpenCV's BGR channel
            order. It is then converted to RGB, which torchvision's pretrained
            detectors expect.

    Returns:
        List of ``[x1, y1, x2, y2]`` numpy arrays, one per kept detection.
    """
    if is_bgr:
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    # HWC array -> CHW float tensor (scaled to [0, 1] for uint8 input).
    frame_tensor = torchvision.transforms.functional.to_tensor(frame)

    # Feed the image on the same device as the model's weights.
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device("cpu")

    with torch.no_grad():
        prediction = model([frame_tensor.to(device)])[0]

    # .cpu() makes .numpy() work even when the model runs on a GPU.
    boxes = prediction["boxes"].cpu().numpy()
    scores = prediction["scores"].cpu().numpy()
    return [box for box, score in zip(boxes, scores) if score > detection_threshold]

# Draw each tracked object's ID label and centroid marker
def draw_tracks(frame, tracked_objects, color=(0, 255, 0)):
    """Annotate ``frame`` with each tracked object's ID and centroid.

    Args:
        frame: Image array to draw on. OpenCV draws into it in place.
        tracked_objects: Mapping of object ID to an ``(x, y)`` centroid, e.g.
            the dict returned by ``CentroidTracker.update``.
        color: Colour of the label and marker, in the frame's channel order.
            Defaults to green.

    Returns:
        The same ``frame`` object, now annotated.
    """
    for objectID, centroid in tracked_objects.items():
        # OpenCV drawing calls want plain Python ints for pixel coordinates;
        # centroids may be numpy scalars.
        x, y = int(centroid[0]), int(centroid[1])
        cv2.putText(frame, f"ID {objectID}", (x - 10, y - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
        cv2.circle(frame, (x, y), 4, color, -1)
    return frame

def main(source=0):
    """Run detection and tracking on a video source until 'q' is pressed.

    Args:
        source: Anything ``cv2.VideoCapture`` accepts. This can be a camera
            index (default ``0``, the first webcam) or a video file path.

    Raises:
        RuntimeError: If the video source cannot be opened.
    """
    # Load Faster R-CNN and initialize the centroid tracker.
    model = load_fasterrcnn_model()
    tracker = CentroidTracker()

    cap = cv2.VideoCapture(source)
    if not cap.isOpened():
        cap.release()
        raise RuntimeError(f"Could not open video source {source!r}")

    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break  # end of stream or failed read

            detections = detect_objects(frame, model)
            tracked_objects = tracker.update(detections)
            frame = draw_tracks(frame, tracked_objects)

            cv2.imshow('Object Detection and Tracking', frame)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # Free the capture device and close windows even if a frame fails.
        cap.release()
        cv2.destroyAllWindows()

if __name__ == "__main__":
    # Start the live demo only when run as a script, not when imported.
    main()


