In [None]:
import cv2
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction, get_prediction

from src.bot_sort import BoTSORT

In [None]:
class ObjectDetector:
    """
    Loads an MMDetection model through SAHI and runs detection on frames.

    Args:
        config_path: Path to the model configuration file.
        checkpoint_path: Path to the model checkpoint (weights).
        use_slicing: If True, `detect` uses SAHI sliced prediction; otherwise
            it runs a single prediction on the whole frame.
        confidence_threshold: Confidence threshold handed to the model
            (default 0.4, the value that was previously hard-coded).
        device: Device string the model is loaded on (default 'cuda:0',
            the value that was previously hard-coded).
    """
    def __init__(self, config_path, checkpoint_path, use_slicing=True,
                 confidence_threshold=0.4, device='cuda:0'):
        self.config_path = config_path
        self.checkpoint_path = checkpoint_path
        self.use_slicing = use_slicing
        self.confidence_threshold = confidence_threshold
        self.device = device
        self._initialize_model()

    def _initialize_model(self):
        """
        Initializes the model using the configuration and checkpoint paths,
        then builds the class-id -> label and class-id -> color lookups.
        """
        self.model = AutoDetectionModel.from_pretrained(
            model_type='mmdet',
            model_path=self.checkpoint_path,
            config_path=self.config_path,
            confidence_threshold=self.confidence_threshold,
            category_mapping=None,
            image_size=None,
            device=self.device
        )

        # Class names and colors come from the metadata stored with the model
        dataset_meta = self.model.model.model.dataset_meta
        class_names = dataset_meta.get("classes")
        palette = dataset_meta.get("palette")

        # Key both lookups by positional class index
        self.class_colors = {i: tuple(color) for i, color in enumerate(palette)}
        self.class_labels = dict(enumerate(class_names))

    def detect(self, rgb_frame, slice_height=720, slice_width=720, overlap_height_ratio=0.1, overlap_width_ratio=0.1):
        """
        Performs object detection on a given frame.

        Args:
            rgb_frame: Image to run detection on.
            slice_height: Slice height used when slicing is enabled.
            slice_width: Slice width used when slicing is enabled.
            overlap_height_ratio: Vertical overlap between slices.
            overlap_width_ratio: Horizontal overlap between slices.

        Returns:
            list: One `(bbox, score, label)` tuple per detection, where `bbox`
            is a 4-element numpy array taken from the prediction's VOC-style
            box, `score` is the prediction score value and `label` is the
            category id.
        """
        if self.use_slicing:
            result = get_sliced_prediction(
                rgb_frame,
                self.model,
                slice_height=slice_height,
                slice_width=slice_width,
                postprocess_match_threshold=0.7,
                postprocess_type="NMM",
                postprocess_class_agnostic=True,
                overlap_height_ratio=overlap_height_ratio,
                overlap_width_ratio=overlap_width_ratio,
                verbose=0
            )
        else:
            result = get_prediction(
                rgb_frame,
                self.model,
            )

        return [
            (np.array(obj.bbox.to_voc_bbox()[:4]), obj.score.value, obj.category.id)
            for obj in result.object_prediction_list
        ]

    def get_class_labels(self):
        """
        Returns the class labels as a dict keyed by class index.
        """
        return self.class_labels

    def get_class_colors(self):
        """
        Returns the class colors as a dict keyed by class index.
        """
        return self.class_colors

class Tracker:
    """
    Thin wrapper that configures a BoTSORT tracker and forwards updates to it.
    """
    def __init__(self):
        # The instance itself is handed to BoTSORT, so every setting has to
        # exist as an attribute before the tracker is constructed.
        settings = {
            'track_high_thresh': 0.6,
            'track_low_thresh': 0.1,
            'new_track_thresh': 0.6,
            'track_buffer': 30,
            'proximity_thresh': 0.5,
            'appearance_thresh': 0.25,
            'with_reid': False,
            'device': 'cuda:0',
            'cmc_method': 'orb',
            'name': 'video',
            'ablation': False,
            'mot20': False,
            'match_thresh': 0.8,
        }
        for option, value in settings.items():
            setattr(self, option, value)

        self.tracker = BoTSORT(self)

    def update(self, detections, frame):
        """
        Feeds the detections of a new frame to the wrapped tracker and
        returns whatever the tracker returns.
        """
        tracked = self.tracker.update(detections, frame)
        return tracked


class Counter:
    """
    Class to count detected objects crossing a vertical line.

    Args:
        class_labels: Dict mapping class id to class name.
        class_colors: Dict mapping class id to a color tuple.
        line_offset: Horizontal shift, in pixels, of the counting line
            relative to the frame's centre.
        draw_line: If True, `count_objects` draws the counting line on the
            frame it is given.
    """
    def __init__(self, class_labels, class_colors, line_offset=0, draw_line=True):
        self.class_labels = class_labels
        self.class_colors = class_colors
        self.line_offset = line_offset
        # class id -> set of track ids that have crossed the line
        self.crossed_line_ids = {cls_id: set() for cls_id in class_labels}
        # track id -> class assigned when the track was first seen
        self.track_class_mapping = {}
        # track id -> list of (x, y) box centres, one per processed frame
        self.track_trails = {}
        self.draw_line = draw_line

    def count_objects(self, tracked_objects, frame_width, frame):
        """
        Counts objects crossing the vertical line in the frame.

        Args:
            tracked_objects: Iterable of tracks exposing `track_id`, `cls`
                and `tlwh` (read as x, y, width, height).
            frame_width: Frame width in pixels; the line sits at its centre
                plus `line_offset`.
            frame: Image the line is drawn on when `draw_line` is True.

        Returns:
            tuple: `(track_trails, track_class_mapping)`, the live internal
            dicts (not copies).
        """
        vertical_line_x = (frame_width // 2) + self.line_offset

        for t in tracked_objects:
            track_id = int(t.track_id)

            if track_id not in self.track_class_mapping:
                self.track_class_mapping[track_id] = t.cls
                self.track_trails[track_id] = []

            # Count under the class first assigned to the track. Using the
            # per-frame class would register one track under several classes
            # whenever its predicted class changes between frames.
            cls = self.track_class_mapping[track_id]

            bbox = t.tlwh
            center = (int(bbox[0] + bbox[2] // 2), int(bbox[1] + bbox[3] // 2))
            trail = self.track_trails[track_id]
            trail.append(center)

            if len(trail) > 1:
                prev_x = trail[-2][0]
                curr_x = center[0]
                crossed_rightwards = prev_x < vertical_line_x and curr_x >= vertical_line_x
                crossed_leftwards = prev_x > vertical_line_x and curr_x <= vertical_line_x
                if crossed_rightwards or crossed_leftwards:
                    # setdefault keeps an unexpected class id from raising KeyError
                    self.crossed_line_ids.setdefault(cls, set()).add(track_id)

        if self.draw_line:
            cv2.line(frame, (vertical_line_x, 0), (vertical_line_x, frame.shape[0]), (0, 255, 0), 2)

        return self.track_trails, self.track_class_mapping

    def get_counts(self):
        """
        Returns the count of objects that crossed the line.

        Returns:
            dict: Dictionary with class names and their counts.
        """
        return {
            name: len(self.crossed_line_ids.get(cls, ()))
            for cls, name in self.class_labels.items()
        }


class Plotting:
    """
    Class to add a logo, count text and tracking overlays to frames.

    Args:
        logo_path: Path of the logo image. It is read with
            `cv2.IMREAD_UNCHANGED`, so an alpha channel is kept if present.

    Raises:
        FileNotFoundError: If the logo image cannot be read.
    """
    def __init__(self, logo_path):
        self.logo = cv2.imread(logo_path, cv2.IMREAD_UNCHANGED)
        # cv2.imread signals failure by returning None; fail here with a clear
        # message instead of with an obscure error on the first add_logo call.
        if self.logo is None:
            raise FileNotFoundError(f"Could not read logo image: {logo_path}")

    def resize_logo(self, logo, new_width):
        """
        Resizes the logo to the specified width while maintaining aspect ratio.
        """
        h, w = logo.shape[:2]
        scale = new_width / w
        new_height = int(h * scale)
        return cv2.resize(logo, (new_width, new_height), interpolation=cv2.INTER_AREA)

    def add_logo(self, frame, logo_width=235, alpha=0.3, padding=7, top_space=20):
        """
        Adds a logo to the frame at the top right corner with a fixed width.

        The logo's color channels are copied as loaded by `cv2.imread`, so the
        frame should use OpenCV's BGR channel order for the colors to come out
        right. The frame is modified in place and also returned. If the logo
        does not fit inside the frame, the frame is returned untouched.

        Args:
            frame: 3-channel image to draw on.
            logo_width: Width in pixels the logo is resized to.
            alpha: Opacity of the white backing rectangle.
            padding: Margin around the logo in pixels.
            top_space: Extra vertical offset from the top edge in pixels.
        """
        logo_resized = self.resize_logo(self.logo, logo_width)
        if logo_resized.ndim == 2:
            # Single-channel logo: expand so it can be pasted on a color frame
            logo_resized = cv2.cvtColor(logo_resized, cv2.COLOR_GRAY2BGR)
        h, w = logo_resized.shape[:2]

        # Top right corner placement
        x = frame.shape[1] - w - padding
        y = padding + top_space

        # A negative start index would silently address the wrong region and
        # an oversized logo would break the blend below.
        if x < 0 or y < 0 or y + h > frame.shape[0]:
            return frame

        # Semi-transparent white backing behind the logo
        overlay = frame.copy()
        cv2.rectangle(overlay, (x - padding, y - padding), (x + w + padding, y + h + padding), (255, 255, 255), -1)
        cv2.addWeighted(overlay, alpha, frame, 1 - alpha, 0, frame)

        logo_color = logo_resized[..., :3]
        roi = frame[y:y+h, x:x+w]
        if logo_resized.shape[2] >= 4:
            # Per-pixel alpha blend; the mask keeps a trailing axis so it
            # broadcasts over the three color channels.
            alpha_mask = logo_resized[..., 3:4] / 255.0
            roi[:] = roi * (1 - alpha_mask) + logo_color * alpha_mask
        else:
            # No alpha channel: treat the logo as fully opaque
            roi[:] = logo_color
        return frame

    def add_text(self, frame, text_lines, class_labels, class_colors, x_offset, y_offset, alpha=0.3, padding=7, top_space=20):
        """
        Adds text annotations to the frame.

        Each line is expected to look like "<class name>: <count>"; a filled
        circle in the class color is drawn in front of it. Lines whose name is
        not a known class label get a black circle. The frame is modified in
        place and also returned; with no lines it is returned untouched.

        Args:
            frame: Image to draw on.
            text_lines: List of strings to render, one per row.
            class_labels: Dict mapping class id to class name.
            class_colors: Color lookup indexed by class id.
            x_offset: Left edge of the text panel in pixels.
            y_offset: Top edge of the text panel in pixels (before `top_space`).
            alpha: Opacity of the white backing rectangle.
            padding: Margin around the text panel in pixels.
            top_space: Extra vertical offset from the top edge in pixels.
        """
        if not text_lines:
            return frame

        y_offset += top_space
        max_text_width = max(cv2.getTextSize(line, cv2.FONT_HERSHEY_SIMPLEX, 0.8, 2)[0][0] for line in text_lines)
        total_text_height = len(text_lines) * 30
        overlay = frame.copy()
        cv2.rectangle(overlay, (x_offset - padding, y_offset - padding), (x_offset + max_text_width + 50 + padding, y_offset + total_text_height + padding), (255, 255, 255), -1)
        cv2.addWeighted(overlay, alpha, frame, 1 - alpha, 0, frame)

        # Reverse lookup built once; the first id wins for duplicate names
        id_by_name = {}
        for cls_id, name in class_labels.items():
            id_by_name.setdefault(name, cls_id)

        circle_radius = 8
        circle_y_offset = 10
        current_y = y_offset + 30
        for line in text_lines:
            # Split on the last colon so class names containing ':' still match
            class_name = line.rsplit(':', 1)[0].strip()
            if class_name in id_by_name:
                color = class_colors[id_by_name[class_name]]
            else:
                color = (0, 0, 0)
            cv2.circle(frame, (x_offset + 15, current_y - circle_y_offset), circle_radius, color, -1)
            cv2.putText(frame, line, (x_offset + 30, current_y), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 0), 2)
            current_y += 30

        return frame

    def draw_tracks_and_ids(self, frame, tracked_objects, track_trails, track_class_mapping, class_labels, class_colors):
        """
        Draws tracking results and trails on the frame.

        For every track this draws its bounding box, a filled tag holding the
        track id above the box, and one dot per stored trail point, all in the
        color of the track's mapped class. The frame is modified in place and
        also returned.

        Args:
            frame: Image to draw on.
            tracked_objects: Iterable of tracks exposing `track_id` and `tlwh`.
            track_trails: Dict mapping track id to a list of (x, y) points.
            track_class_mapping: Dict mapping track id to class id.
            class_labels: Unused here; kept for interface compatibility.
            class_colors: Color lookup indexed by class id.
        """
        for t in tracked_objects:
            track_id = int(t.track_id)
            cls = track_class_mapping[track_id]
            color = class_colors[cls]

            x1, y1, w, h = map(int, t.tlwh)
            x2, y2 = x1 + w, y1 + h
            label = f'{track_id}'

            cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)

            # Filled tag behind the track id. The text origin is its baseline,
            # so placing it `baseline` pixels above y1 makes the text fill the
            # tag exactly instead of sticking out of its top edge.
            (text_width, text_height), baseline = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.8, 2)
            cv2.rectangle(frame, (x1, y1 - text_height - baseline), (x1 + text_width, y1), color, -1)
            cv2.putText(frame, label, (x1, y1 - baseline), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 0), 2)

            for point in track_trails[track_id]:
                cv2.circle(frame, point, 2, color, -1)

        return frame


class VideoProcessor:
    """
    Class to process videos with object detection and tracking.

    Args:
        detector: Object with `detect(rgb_frame)` returning
            `(bbox, score, label)` tuples.
        tracker: Object with `update(detections, frame)` returning tracks.
        counter: Object providing `count_objects`, `get_counts`,
            `class_labels` and `class_colors`.
        plotting: Object providing `draw_tracks_and_ids`, `add_text` and
            `add_logo`.
    """
    CSV_COLUMNS = ["frame_id", "track_id", "x", "y", "w", "h", "class_id"]

    def __init__(self, detector, tracker, counter, plotting):
        self.detector = detector
        self.tracker = tracker
        self.counter = counter
        self.plotting = plotting

    def process_video(self, video_path, output_path, save_video=True, save_csv=True, csv_output_path=None):
        """
        Processes the input video and saves the output video and/or CSV file.

        Args:
            video_path: Path of the video to read.
            output_path: Path of the annotated video (used when `save_video`).
            save_video: If True, write the annotated frames to `output_path`.
            save_csv: If True, write one CSV row per track per frame.
            csv_output_path: Path of the CSV file; required when `save_csv`.

        Raises:
            ValueError: If `save_csv` is True but `csv_output_path` is None.
            IOError: If the input video cannot be opened.
        """
        if save_csv and csv_output_path is None:
            raise ValueError("csv_output_path must be provided when save_csv=True")

        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            cap.release()
            raise IOError(f"Could not open video: {video_path}")

        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        out = None
        csv_data = [] if save_csv else None

        try:
            # Initialize video writer if required
            if save_video:
                fourcc = cv2.VideoWriter_fourcc(*'mp4v')
                out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height))

            # Some sources report no frame count; let tqdm run without a total
            progress_total = total_frames if total_frames > 0 else None
            with tqdm(total=progress_total, desc="Processing video", unit="frame") as pbar:
                while True:
                    ret, frame = cap.read()
                    if not ret:
                        break

                    # Read after cap.read(), so this is the position following
                    # the frame just decoded
                    frame_id = int(cap.get(cv2.CAP_PROP_POS_FRAMES))

                    # OpenCV decodes to BGR; detection and the class-colored
                    # overlays work on an RGB copy
                    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

                    detections = self.detector.detect(frame_rgb)
                    if detections:
                        output_results = np.array([[*bbox, score, label] for bbox, score, label in detections])
                    else:
                        # Keep a 2-D (0, 6) shape when nothing was detected
                        output_results = np.empty((0, 6))

                    tracked_objects = self.tracker.update(output_results, frame)

                    # Count objects (also draws the counting line when enabled)
                    track_trails, track_class_mapping = self.counter.count_objects(tracked_objects, frame_width, frame_rgb)

                    # Draw tracking results and trails on the frame
                    frame_rgb = self.plotting.draw_tracks_and_ids(frame_rgb, tracked_objects, track_trails, track_class_mapping, self.counter.class_labels, self.counter.class_colors)

                    # Get counts and add text to the frame
                    counts = self.counter.get_counts()
                    text_lines = [f"{cls}: {count}" for cls, count in counts.items()]
                    frame_rgb = self.plotting.add_text(frame_rgb, text_lines, self.counter.class_labels, self.counter.class_colors, x_offset=10, y_offset=10)

                    # Back to BGR for the writer. The logo goes on after the
                    # conversion: cv2.imread loads it in BGR order, so pasting
                    # it on the RGB frame would swap its red and blue channels
                    # in the written video.
                    frame_bgr = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR)
                    frame_bgr = self.plotting.add_logo(frame_bgr)

                    # Collect rows now; the file is written once at the end
                    if save_csv:
                        self._append_csv_rows(csv_data, frame_id, tracked_objects)

                    if out is not None:
                        out.write(frame_bgr)

                    pbar.update(1)
        finally:
            # Release handles and flush collected rows even if a frame failed
            cap.release()
            if out is not None:
                out.release()
            if save_csv:
                self._write_csv(csv_data, csv_output_path)

    def _append_csv_rows(self, csv_data, frame_id, tracked_objects):
        """
        Appends one row per tracked object of this frame to `csv_data`.
        """
        for obj in tracked_objects:
            bbox = obj.tlwh
            csv_data.append([frame_id, obj.track_id, round(bbox[0], 2), round(bbox[1], 2), round(bbox[2], 2), round(bbox[3], 2), int(obj.cls)])

    def _write_csv(self, csv_data, csv_output_path):
        """
        Writes all collected rows to `csv_output_path`, replacing the file.
        """
        df = pd.DataFrame(csv_data, columns=self.CSV_COLUMNS)
        df.to_csv(csv_output_path, index=False)

    def _save_csv(self, csv_data, frame_id, tracked_objects, csv_output_path):
        """
        Collects tracking data for CSV output and saves it to a CSV file.

        Kept for compatibility: appends this frame's rows to `csv_data` and
        rewrites the whole file. `process_video` instead collects rows per
        frame and writes once, because rewriting per frame makes the I/O cost
        grow with the square of the video length.
        """
        self._append_csv_rows(csv_data, frame_id, tracked_objects)
        self._write_csv(csv_data, csv_output_path)

In [None]:
# Input and output locations for the run below.
# All paths are absolute and machine-specific; adjust them before running elsewhere.
config_path = '/home/jupyter/tomato_weight/ddod_r152_Solanumlycopersicum13.py'  # detector config, passed to ObjectDetector
checkpoint_path = '/home/jupyter/tomato_weight/best_coco_bbox_mAP_50_epoch_1176.pth'  # detector checkpoint, passed to ObjectDetector
logo_path = '/home/jupyter/logo.png'  # logo image, passed to Plotting
video_path = '/home/jupyter/tomato_3.mp4'  # input video
output_path = '/home/jupyter/tomato_output.mp4'  # annotated output video
csv_output_path = '/home/jupyter/tomato_output.csv'  # per-frame tracking CSV

In [None]:
# Build the pipeline stages: detection, tracking, line counting and overlays
detector = ObjectDetector(
    config_path=config_path,
    checkpoint_path=checkpoint_path,
    use_slicing=True,
)
tracker = Tracker()
counter = Counter(
    detector.get_class_labels(),
    detector.get_class_colors(),
    line_offset=0,
    draw_line=True,
)
plotting = Plotting(logo_path)

video_processor = VideoProcessor(detector, tracker, counter, plotting)

# Run the whole video, writing both the annotated video and the tracking CSV
video_processor.process_video(
    video_path=video_path,
    output_path=output_path,
    save_video=True,
    save_csv=True,
    csv_output_path=csv_output_path,
)