In [None]:
import os
import cv2
import csv
import numpy as np
import pandas as pd
from tqdm import tqdm
from moviepy.editor import VideoFileClip
import imageio
import imageio_ffmpeg as ffmpeg
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction, get_prediction

from botsort.bot_sort import BoTSORT

In [None]:
class Video:
    """Wrapper around a moviepy ``VideoFileClip`` that reports basic metadata
    and writes reduced-FPS and/or down-scaled copies of the video.

    The source clip stored in ``self.clip`` is never replaced: each reduction
    derives a temporary clip from it, so the methods can be called repeatedly
    and ``get_video_info`` keeps describing the original file.
    """

    def __init__(self, video_path):
        """Open ``video_path`` and cache its fps, frame count and frame size."""
        self.video_path = video_path
        self.clip = VideoFileClip(video_path)
        self.fps = self.clip.fps
        # Estimated as fps * duration rather than counted frame by frame.
        self.total_frames = int(self.clip.fps * self.clip.duration)
        self.size = (int(self.clip.size[0]), int(self.clip.size[1]))

    def get_video_info(self):
        """Return a dict with 'fps', 'total_frames', 'duration' and 'size' of the source clip."""
        return {
            'fps': self.fps,
            'total_frames': self.total_frames,
            'duration': round(self.clip.duration, 2),
            'size': self.size
        }

    def _scaled_size(self, scale_factor):
        """Return the source (width, height) multiplied by ``scale_factor``, at least 1x1."""
        return (
            max(1, int(self.size[0] * scale_factor)),
            max(1, int(self.size[1] * scale_factor))
        )

    def _process_frames(self, new_fps, new_size, output_path):
        """Write a copy of the source clip to ``output_path``.

        Args:
            new_fps: Target frame rate; a falsy value keeps the source fps.
            new_size: Target (width, height); a falsy value keeps the source size.
            output_path: Destination file passed to ``write_videofile``.

        Returns:
            The estimated frame count of the written video
            (``int(output_fps * duration)``), mirroring how ``total_frames``
            is estimated in ``__init__``.
        """
        def process_frame(frame):
            return cv2.resize(frame, new_size, interpolation=cv2.INTER_AREA)

        output_fps = new_fps if new_fps else self.fps

        # Work on a derived clip so self.clip keeps matching self.fps / self.size.
        clip = self.clip
        if new_fps:
            clip = clip.set_fps(new_fps)
        if new_size:
            clip = clip.fl_image(process_frame)

        clip.write_videofile(output_path, fps=output_fps)
        return int(output_fps * clip.duration)

    def reduce_fps(self, new_fps, output_path):
        """Write the video at a lower frame rate and return the new video info.

        Raises:
            ValueError: If ``new_fps`` is not positive or not lower than the source fps.
        """
        if new_fps <= 0:
            raise ValueError("New FPS must be greater than 0.")
        if new_fps >= self.fps:
            raise ValueError("New FPS must be lower than the current FPS.")

        new_total_frames = self._process_frames(new_fps, None, output_path)

        return {
            'fps': new_fps,
            'total_frames': new_total_frames,
            'duration': round(self.clip.duration, 2),
            'size': self.size
        }

    def reduce_size(self, scale_factor, output_path):
        """Write a down-scaled copy of the video and return the new video info.

        Raises:
            ValueError: If ``scale_factor`` is not in the open interval (0, 1).
        """
        if scale_factor <= 0:
            raise ValueError("Scale factor must be greater than 0.")
        if scale_factor >= 1:
            raise ValueError("Scale factor must be less than 1.")

        new_size = self._scaled_size(scale_factor)
        new_total_frames = self._process_frames(self.fps, new_size, output_path)

        return {
            'fps': self.fps,
            'total_frames': new_total_frames,
            'duration': round(self.clip.duration, 2),
            'size': new_size
        }

    def reduce_fps_and_size(self, new_fps, scale_factor, output_path):
        """Write a copy with both a lower frame rate and a smaller frame size.

        Raises:
            ValueError: If ``new_fps`` is not in (0, source fps) or
                ``scale_factor`` is not in (0, 1).
        """
        if new_fps <= 0 or scale_factor <= 0:
            raise ValueError("New FPS and scale factor must be greater than 0.")
        if new_fps >= self.fps or scale_factor >= 1:
            raise ValueError("New FPS must be lower and scale factor must be less than 1.")

        new_size = self._scaled_size(scale_factor)
        new_total_frames = self._process_frames(new_fps, new_size, output_path)

        return {
            'fps': new_fps,
            'total_frames': new_total_frames,
            'duration': round(self.clip.duration, 2),
            'size': new_size
        }

In [None]:
class Detector:
    """Object detector built on a SAHI ``AutoDetectionModel`` of type 'mmdet'.

    The model is loaded only when both ``config_path`` and ``checkpoint_path``
    are given. Until then ``model`` is None and the class label / color
    mappings are empty dicts.
    """

    def __init__(self, config_path, checkpoint_path, use_slicing=True, confidence_threshold=0.4, device='cuda:0'):
        """Store the settings and load the model if both paths are provided.

        Args:
            config_path: Model config file passed to SAHI as ``config_path``.
            checkpoint_path: Model weights passed to SAHI as ``model_path``.
            use_slicing: If True, ``detect`` uses sliced prediction.
            confidence_threshold: Confidence threshold passed to the model.
            device: Device string passed to the model (default 'cuda:0').
        """
        self.config_path = config_path
        self.checkpoint_path = checkpoint_path
        self.use_slicing = use_slicing
        self.confidence_threshold = confidence_threshold
        self.device = device
        # Defaults so the getters work even when no model is loaded.
        self.model = None
        self.class_labels = {}
        self.class_colors = {}
        if config_path and checkpoint_path:
            self._initialize_model()

    def _initialize_model(self):
        """Load the model and build the id -> label and id -> color mappings."""
        self.model = AutoDetectionModel.from_pretrained(
            model_type='mmdet',
            model_path=self.checkpoint_path,
            config_path=self.config_path,
            confidence_threshold=self.confidence_threshold,
            category_mapping=None,
            image_size=None,
            device=self.device
        )
        dataset_meta = self.model.model.model.dataset_meta
        # Either entry may be absent from the metadata; fall back to empty.
        labels = dataset_meta.get("classes") or ()
        palette = dataset_meta.get("palette") or ()
        self.class_labels = dict(enumerate(labels))
        self.class_colors = {i: tuple(color) for i, color in enumerate(palette)}

    def detect(self, bgr_frame, slice_height=720, slice_width=720, overlap_height_ratio=0.1, overlap_width_ratio=0.1):
        """Run detection on one BGR frame.

        Args:
            bgr_frame: Image in BGR channel order; converted to RGB before inference.
            slice_height, slice_width: Slice size used when ``use_slicing`` is True.
            overlap_height_ratio, overlap_width_ratio: Slice overlap ratios
                used when ``use_slicing`` is True.

        Returns:
            A list of ``(bbox, score, label)`` tuples where ``bbox`` is a numpy
            array of the four values from ``to_voc_bbox()``, ``score`` is the
            prediction score and ``label`` is the category id.

        Raises:
            RuntimeError: If no model has been loaded.
        """
        if self.model is None:
            raise RuntimeError("Detector model is not initialized; provide config_path and checkpoint_path.")

        # Convert BGR frame to RGB
        rgb_frame = cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB)

        if self.use_slicing:
            result = get_sliced_prediction(
                rgb_frame,
                self.model,
                slice_height=slice_height,
                slice_width=slice_width,
                postprocess_match_threshold=0.7,
                postprocess_type="NMM",
                postprocess_class_agnostic=True,
                overlap_height_ratio=overlap_height_ratio,
                overlap_width_ratio=overlap_width_ratio,
                verbose=0
            )
        else:
            result = get_prediction(rgb_frame, self.model)

        detections = []
        for obj in result.object_prediction_list:
            bbox = obj.bbox.to_voc_bbox()
            detections.append((np.array([bbox[i] for i in range(4)]), obj.score.value, obj.category.id))
        return detections

    def get_class_labels(self):
        """Return the class id -> label mapping (empty if no model is loaded)."""
        return self.class_labels

    def get_class_colors(self):
        """Return the class id -> color tuple mapping (empty if no model is loaded)."""
        return self.class_colors

In [None]:
class Tracker:
    """Thin wrapper that builds a ``BoTSORT`` tracker from keyword settings.

    Every setting is stored as an attribute of this object and the object
    itself is handed to ``BoTSORT``.
    NOTE(review): presumably BoTSORT reads these attributes as its
    configuration namespace; confirm against the botsort package.
    """

    def __init__(self,
                 track_high_thresh=0.6,
                 track_low_thresh=0.1,
                 new_track_thresh=0.7, 
                 track_buffer=15,
                 proximity_thresh=0.4,
                 cmc_method='sparseOptFlow',
                 match_thresh=0.7):
        """Store all settings on ``self`` and create the underlying tracker."""
        settings = {
            'track_high_thresh': track_high_thresh,
            'track_low_thresh': track_low_thresh,
            'new_track_thresh': new_track_thresh,
            'track_buffer': track_buffer,
            'proximity_thresh': proximity_thresh,
            'cmc_method': cmc_method,
            'match_thresh': match_thresh,
        }
        for name, value in settings.items():
            setattr(self, name, value)
        # The settings must be in place before BoTSORT receives this object.
        self.tracker = BoTSORT(self)

    def update(self, detections, frame):
        """Forward one frame's detections to the tracker and return its result."""
        tracked = self.tracker.update(detections, frame)
        return tracked

In [None]:
class Counter:
    """Counts, per class, the distinct tracks whose center crosses a vertical line.

    The line sits at ``frame_width // 2 + line_offset``. Each track's center
    history is kept in ``track_trails`` and the class seen on its first
    appearance in ``track_class_mapping``.
    """

    def __init__(self, class_labels, class_colors, line_offset=0, draw_line=True):
        """Store the class mappings and start with empty per-class crossing sets."""
        self.class_labels = class_labels
        self.class_colors = class_colors
        self.line_offset = line_offset
        self.draw_line = draw_line
        self.track_class_mapping = {}
        self.track_trails = {}
        # One set of track ids per known class id.
        self.crossed_line_ids = {}
        for cls_id in class_labels:
            self.crossed_line_ids[cls_id] = set()

    def count_objects(self, tracked_objects, frame_width, frame):
        """Update trails and crossing sets from one frame's tracked objects.

        Each object must expose ``track_id``, ``cls`` and ``tlwh``. When
        ``draw_line`` is set, the counting line is drawn onto ``frame``.

        Returns:
            The ``(track_trails, track_class_mapping)`` dicts held by this counter.
        """
        line_x = (frame_width // 2) + self.line_offset

        for obj in tracked_objects:
            tid = int(obj.track_id)
            obj_cls = obj.cls

            # First sighting: remember the class and start an empty trail.
            if tid not in self.track_class_mapping:
                self.track_class_mapping[tid] = obj_cls
                self.track_trails[tid] = []

            box = obj.tlwh
            center_x = int(box[0] + box[2] // 2)
            center_y = int(box[1] + box[3] // 2)

            trail = self.track_trails[tid]
            trail.append((center_x, center_y))
            if len(trail) <= 1:
                continue

            previous_x = trail[-2][0]
            crossed_rightward = previous_x < line_x <= center_x
            crossed_leftward = center_x <= line_x < previous_x
            if crossed_rightward or crossed_leftward:
                self.crossed_line_ids[obj_cls].add(tid)

        if self.draw_line:
            top = (line_x, 0)
            bottom = (line_x, frame.shape[0])
            cv2.line(frame, top, bottom, (0, 255, 0), 2)

        return self.track_trails, self.track_class_mapping

    def get_counts(self):
        """Return a dict mapping each class label to its number of crossed tracks."""
        counts = {}
        for cls_id in self.class_labels:
            counts[self.class_labels[cls_id]] = len(self.crossed_line_ids[cls_id])
        return counts

In [None]:
class Plotting:
    """Drawing helpers for the counting video: logo, count panel, class color
    dots, bounding boxes with track ids, and track trails.

    All methods draw in place on the frame they are given.
    """

    def __init__(self, logo_path):
        """Load the logo image from ``logo_path``.

        ``cv2.imread`` returns None instead of raising when the file is missing
        or unreadable; in that case ``add_logo`` leaves frames untouched.
        """
        self.logo = cv2.imread(logo_path, cv2.IMREAD_UNCHANGED)
        # logo_width -> float alpha mask of the resized logo, computed once per width.
        self._alpha_cache = {}

    def _logo_alpha(self, logo_width):
        """Return the opacity mask (values in [0, 1]) of the logo resized to ``logo_width``."""
        alpha_mask = self._alpha_cache.get(logo_width)
        if alpha_mask is None:
            scale = logo_width / self.logo.shape[1]
            logo_height = max(1, int(self.logo.shape[0] * scale))
            logo_resized = cv2.resize(self.logo, (logo_width, logo_height), interpolation=cv2.INTER_AREA)
            if logo_resized.ndim == 3 and logo_resized.shape[2] in (2, 4):
                # The last channel is the alpha channel.
                alpha_mask = logo_resized[..., -1] / 255.0
            else:
                # No alpha channel: treat the whole logo rectangle as opaque.
                alpha_mask = np.ones(logo_resized.shape[:2], dtype=np.float64)
            self._alpha_cache[logo_width] = alpha_mask
        return alpha_mask

    def add_logo(self, frame, logo_padding=25, logo_width=235):
        """Blend the logo, recolored to white, into the bottom-right corner of ``frame``.

        The logo's own colors are discarded; only its alpha mask is used to
        blend pure white into the first three channels of the frame.

        Args:
            frame: Writable image array of at least three channels.
            logo_padding: Gap in pixels between the logo and the frame edges.
            logo_width: Width in pixels the logo is resized to.

        Returns:
            ``frame``, unchanged when no logo was loaded or it does not fit.
        """
        if self.logo is None:
            return frame

        alpha_mask = self._logo_alpha(logo_width)
        h, w = alpha_mask.shape[:2]
        x = frame.shape[1] - w - logo_padding
        y = frame.shape[0] - h - logo_padding
        if x < 0 or y < 0:
            # Logo plus padding is larger than the frame; negative origins would
            # wrap around in slicing, so skip drawing.
            return frame

        weight = alpha_mask[..., np.newaxis]
        region = frame[y:y + h, x:x + w, :3]
        frame[y:y + h, x:x + w, :3] = region * (1 - weight) + 255.0 * weight
        return frame

    def add_class_count(self, frame, text_lines, alpha=0.3, padding=15, bottom_space=15, side_space=15):
        """Draw a translucent white rounded panel with one text line per entry.

        Args:
            frame: Image drawn on in place.
            text_lines: Lines to render; an empty list leaves the frame untouched.
            alpha: Opacity of the white panel.
            padding, bottom_space, side_space: Layout offsets in pixels.

        Returns:
            ``frame``.
        """
        if not text_lines:
            return frame

        font, font_scale, thickness = cv2.FONT_HERSHEY_SIMPLEX, 0.8, 2
        line_spacing = 10
        white = (255, 255, 255)

        x_offset = padding + side_space
        y_offset = padding + bottom_space
        max_text_width = max(cv2.getTextSize(line, font, font_scale, thickness)[0][0] for line in text_lines)
        text_height = cv2.getTextSize(text_lines[0], font, font_scale, thickness)[0][1]
        total_text_height = len(text_lines) * text_height + (len(text_lines) - 1) * line_spacing
        overlay = frame.copy()

        rect_x1 = x_offset - padding
        rect_y1 = y_offset - padding
        rect_x2 = x_offset + max_text_width + 50 + padding
        rect_y2 = y_offset + total_text_height + 2 * padding
        radius = 15

        # Two overlapping rectangles leave the four corners free for the circles.
        cv2.rectangle(overlay, (rect_x1 + radius, rect_y1), (rect_x2 - radius, rect_y2), white, -1)
        cv2.rectangle(overlay, (rect_x1, rect_y1 + radius), (rect_x2, rect_y2 - radius), white, -1)

        # Draw the rounded corners
        for corner_x in (rect_x1 + radius, rect_x2 - radius):
            for corner_y in (rect_y1 + radius, rect_y2 - radius):
                cv2.circle(overlay, (corner_x, corner_y), radius, white, -1)

        cv2.addWeighted(overlay, alpha, frame, 1 - alpha, 0, frame)

        # Text is drawn after blending so it stays fully opaque.
        current_y = y_offset + padding + text_height
        for line in text_lines:
            cv2.putText(frame, line, (x_offset + 30, current_y), font, font_scale, (0, 0, 0), thickness)
            current_y += text_height + line_spacing
        return frame

    def add_circles(self, frame, text_lines, class_labels, class_colors, padding=15, bottom_space=15, side_space=15):
        """Draw a filled dot in the class color next to each "<label>: <count>" line.

        Args:
            frame: Image drawn on in place.
            text_lines: Lines of the form "<label>: <count>".
            class_labels: Mapping of class id -> label.
            class_colors: Mapping of class id -> color tuple.
            padding, bottom_space, side_space: Layout offsets in pixels.

        Returns:
            ``frame``. Lines whose label or color is unknown get no dot.
        """
        # Reverse lookup built once; the first id wins if a label is duplicated.
        label_to_id = {}
        for cls_id, label in class_labels.items():
            label_to_id.setdefault(label, cls_id)

        circle_radius = 8
        x_offset = padding + side_space
        y_offset = padding + bottom_space + 22  # Adjusted to better align with text baseline
        for line in text_lines:
            # Split on the last colon so labels that contain ':' stay intact.
            class_name = line.rsplit(':', 1)[0].strip()
            text_height = cv2.getTextSize(class_name, cv2.FONT_HERSHEY_SIMPLEX, 0.8, 2)[0][1]
            color = class_colors.get(label_to_id.get(class_name))
            if color is not None:
                circle_y_offset = text_height // 2 - circle_radius // 2
                cv2.circle(frame, (x_offset + 15, y_offset + circle_y_offset), circle_radius, color, -1)
            y_offset += text_height + 10  # 10 is the line spacing
        return frame

    def add_bbox_and_track_id(self, frame, bbox, track_id, color):
        """Draw a box from ``bbox`` (x, y, w, h) and a filled id tag above its top-left corner."""
        x1, y1, w, h = map(int, bbox)
        x2, y2 = x1 + w, y1 + h
        label = f'{track_id}'
        cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
        (text_width, text_height), baseline = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.8, 2)
        cv2.rectangle(frame, (x1, y1 - text_height - baseline), (x1 + text_width, y1), color, -1)
        cv2.putText(frame, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 0), 2)

    def add_trails(self, frame, track_trails, track_id, color):
        """Draw a small filled dot at every recorded center of ``track_id``."""
        for point in track_trails[track_id]:
            cv2.circle(frame, point, 2, color, -1)

In [None]:
class Processor:
    """Drives the pipeline: optional video reduction, per-frame detection
    written to CSV, then tracking, line counting and an annotated video.

    Outputs go to an 'output' directory under the current working directory;
    the logo is read from 'logo.png' in the current working directory.
    """

    # Progress-bar layout shared by the detection and tracking loops.
    _PROGRESS_FORMAT = "{l_bar}{bar}| {n_fmt}/{total_fmt} frames [{elapsed}<{remaining}, {rate_fmt}]"

    def __init__(self):
        """Create the output directory and reset the pipeline state."""
        current_dir = os.getcwd()
        self.output_dir = os.path.join(current_dir, 'output')
        os.makedirs(self.output_dir, exist_ok=True)
        self.logo_path = os.path.join(current_dir, 'logo.png')
        self.detector = None
        self.modified_video_path = None

    def print_video_info(self, info, title):
        """Print the fps, frame count, duration and size from ``info`` under ``title``."""
        print(f"{title} Video:")
        print(f"fps: {info['fps']}")
        print(f"total frames: {info['total_frames']}")
        print(f"duration: {info['duration']}")
        print(f"size: {info['size']}")
        print()

    def modify_video(self, video, reduce_fps=None, reduce_size=None):
        """Write a reduced copy of ``video`` if requested and record which file to process.

        Args:
            video: Object exposing ``video_path``, ``get_video_info`` and the
                ``reduce_fps`` / ``reduce_size`` / ``reduce_fps_and_size`` methods.
            reduce_fps: Target fps, or a falsy value to keep the frame rate.
            reduce_size: Scale factor, or a falsy value to keep the frame size.

        Returns:
            ``(info, modified)``: the info dict of the video to process and
            whether a new file was written. ``self.modified_video_path`` is
            set to that video's path.
        """
        modified = True
        if reduce_fps and reduce_size:
            output_path = os.path.join(self.output_dir, 'reduced_fps_and_size.mp4')
            new_info = video.reduce_fps_and_size(reduce_fps, reduce_size, output_path)
        elif reduce_fps:
            output_path = os.path.join(self.output_dir, 'reduced_fps.mp4')
            new_info = video.reduce_fps(reduce_fps, output_path)
        elif reduce_size:
            output_path = os.path.join(self.output_dir, 'reduced_size.mp4')
            new_info = video.reduce_size(reduce_size, output_path)
        else:
            # Nothing requested: process the original file as is.
            output_path = video.video_path
            new_info = video.get_video_info()
            modified = False

        self.modified_video_path = output_path
        return new_info, modified

    def run_detector(self, video_path, config_path, checkpoint_path, use_slicing=True, confidence_threshold=0.4, reduce_fps=None, reduce_size=None):
        """Detect objects on every frame and write them to 'detections.csv'.

        The video is first reduced according to ``reduce_fps`` / ``reduce_size``.
        One CSV row (frame_id, x1, y1, x2, y2, score, label) is written per
        detection, with frame ids starting at 1. The detector is kept in
        ``self.detector`` for ``run_tracker``.
        """
        # Initialize Video class
        video = Video(video_path)
        info = video.get_video_info()
        self.print_video_info(info, "Original")

        # Modify the video if needed
        new_info, modified = self.modify_video(video, reduce_fps, reduce_size)
        if modified:
            self.print_video_info(new_info, "Modified")

        csv_output_path = os.path.join(self.output_dir, 'detections.csv')
        self.detector = Detector(config_path, checkpoint_path, use_slicing, confidence_threshold)
        total_frames = new_info['total_frames']

        reader = VideoFileClip(self.modified_video_path)
        try:
            with open(csv_output_path, mode='w', newline='') as file:
                writer = csv.writer(file)
                writer.writerow(["frame_id", "x1", "y1", "x2", "y2", "score", "label"])  # Write the header

                with tqdm(total=total_frames, desc="Detecting", unit="frame", bar_format=self._PROGRESS_FORMAT) as pbar:
                    for frame_id, frame in enumerate(reader.iter_frames(), start=1):
                        # The frame is converted to BGR because Detector.detect takes a BGR frame.
                        bgr_frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
                        detections = self.detector.detect(bgr_frame)
                        for bbox, score, label in detections:
                            writer.writerow([frame_id, *bbox, score, label])  # Write each detection result immediately
                        pbar.update(1)
        finally:
            # Release the clip even if detection fails midway.
            reader.close()

        print(f"Detections saved to {csv_output_path}")

    def run_tracker(self,
                    line_offset=0,
                    draw_line=True,
                    track_high_thresh=0.6,
                    track_low_thresh=0.1,
                    new_track_thresh=0.7,
                    track_buffer=15,
                    proximity_thresh=0.4,
                    cmc_method='sparseOptFlow',
                    match_thresh=0.7,
                    use_tracker_only=False,
                    detections_path=None, video_path=None, config_path=None, checkpoint_path=None):
        """Track the saved detections, count line crossings and write the results.

        Writes 'counting_video.mp4' (annotated frames) and 'tracker_output.csv'
        (one row per tracked object per frame) to the output directory.

        Args:
            line_offset, draw_line: Passed to ``Counter``.
            track_high_thresh ... match_thresh: Passed to ``Tracker``.
            use_tracker_only: If True, skip the state left by ``run_detector``
                and use ``detections_path``, ``video_path``, ``config_path``
                and ``checkpoint_path`` instead; all four are then required.
                If False, those four arguments are ignored.

        Raises:
            ValueError: If the inputs required for the chosen mode are missing.
        """
        if use_tracker_only:
            if not detections_path or not video_path or not config_path or not checkpoint_path:
                raise ValueError("When use_tracker_only is True, detections_path, video_path, config_path, and checkpoint_path must be provided.")

            # Initialize the detector to access class labels and colors
            self.detector = Detector(config_path, checkpoint_path, use_slicing=False, confidence_threshold=0.0)
            self.modified_video_path = video_path
            csv_input_path = detections_path
        else:
            if not self.modified_video_path or self.detector is None:
                raise ValueError("Run the detector first to set the modified video path.")
            csv_input_path = os.path.join(self.output_dir, 'detections.csv')

        video_output_path = os.path.join(self.output_dir, 'counting_video.mp4')
        tracker_csv_output_path = os.path.join(self.output_dir, 'tracker_output.csv')

        # Load the detections before opening any video resource.
        detections = pd.read_csv(csv_input_path)
        frame_groups = detections.groupby('frame_id')

        tracker = Tracker(track_high_thresh, track_low_thresh, new_track_thresh, track_buffer, proximity_thresh, cmc_method, match_thresh)
        counter = Counter(self.detector.get_class_labels(), self.detector.get_class_colors(), line_offset, draw_line)
        plotting = Plotting(self.logo_path)

        # List to hold tracking data for CSV
        tracking_data = []

        reader = VideoFileClip(self.modified_video_path)
        try:
            frame_width = int(reader.size[0])
            fps = reader.fps
            # Estimated from this reader, so no second clip has to be opened.
            total_frames = int(reader.fps * reader.duration)

            # Open video writer
            writer = imageio.get_writer(video_output_path, fps=fps)
            try:
                with tqdm(total=total_frames, desc="Tracking", unit="frame", bar_format=self._PROGRESS_FORMAT) as pbar:
                    for frame_id, frame in enumerate(reader.iter_frames(), start=1):
                        frame = frame.copy()  # Create a writable copy of the frame
                        if frame_id in frame_groups.groups:
                            frame_detections = frame_groups.get_group(frame_id)
                            detection_array = np.array(
                                frame_detections[['x1', 'y1', 'x2', 'y2', 'score', 'label']].values.tolist()
                            )
                        else:
                            # Frames without detections still go through the tracker.
                            detection_array = np.empty((0, 6))

                        tracked_objects = tracker.update(detection_array, frame)
                        track_trails, track_class_mapping = counter.count_objects(tracked_objects, frame_width, frame)

                        for t in tracked_objects:
                            bbox = t.tlwh
                            track_id = int(t.track_id)
                            cls = track_class_mapping[track_id]
                            color = counter.class_colors[cls]
                            plotting.add_bbox_and_track_id(frame, bbox, track_id, color)
                            plotting.add_trails(frame, track_trails, track_id, color)

                            # Add tracking data to list
                            x, y, w, h = bbox
                            tracking_data.append([frame_id, track_id, x, y, w, h, cls])

                        counts = counter.get_counts()
                        text_lines = [f"{cls}: {count}" for cls, count in counts.items()]
                        frame = plotting.add_class_count(frame, text_lines)
                        frame = plotting.add_circles(frame, text_lines, counter.class_labels, counter.class_colors)
                        frame = plotting.add_logo(frame)

                        writer.append_data(frame)
                        pbar.update(1)
            finally:
                # Close the writer even if a frame fails.
                writer.close()
        finally:
            reader.close()

        # Save tracking data to CSV
        tracking_df = pd.DataFrame(tracking_data, columns=["frame_id", "track_id", "x", "y", "w", "h", "class_id"])
        tracking_df.to_csv(tracker_csv_output_path, index=False)
        print(f"Tracked video saved to {video_output_path}")
        print(f"Tracker output saved to {tracker_csv_output_path}")

In [None]:
# Input video plus the detection model's config and checkpoint files
video_path = '/home/jupyter/videos/GX010067.MP4'
config_path = '/home/jupyter/weights/tomato/ddod_r152_Solanumlycopersicum13.py'
checkpoint_path = '/home/jupyter/weights/tomato/best_coco_bbox_mAP_50_epoch_1176.pth'

# Pipeline driver; constructing it creates ./output under the current working directory
processor = Processor()

In [None]:
# Detect on the full-rate, full-size video (no reduction) and save the detections CSV
processor.run_detector(
    video_path,
    config_path,
    checkpoint_path,
    use_slicing=True,
    confidence_threshold=0.35,
    reduce_fps=None,
    reduce_size=None,
)

In [None]:
# Track and count using the state left by the detector cell above.
# With use_tracker_only=False the four path arguments are not used by run_tracker.
processor.run_tracker(
    line_offset=0,
    draw_line=True,
    track_high_thresh=0.6,
    track_low_thresh=0.1,
    new_track_thresh=0.7,
    track_buffer=15,
    proximity_thresh=0.4,
    cmc_method='sparseOptFlow',
    match_thresh=0.7,
    use_tracker_only=False,
    detections_path=None,
    video_path=video_path,
    config_path=config_path,
    checkpoint_path=checkpoint_path,
)