In [None]:
import asyncio
import time
import cv2
import nats
import numpy as np
from nats.aio.msg import Msg
import nest_asyncio
import random
from asyncio import CancelledError
from dataclasses import dataclass

from project.generated.project.common.proto.Inference_pb2 import InferenceList
from project.common import profiler
from project.common.config_class.profiler import ProfilerConfig
from project.generated.project.common.proto.Image_pb2 import ImageMessage
from project.generated.project.common.proto.Inference_pb2 import Inference

# Patch asyncio so event loops can be nested. This lets asyncio code run inside
# the notebook kernel's already-running loop (the cells below use top-level
# `await`).
nest_asyncio.apply()

In [None]:
@dataclass()
class ImageWithParams:
    """An image bundled with the camera parameters that belong to it.

    NOTE(review): ``camera_matrix`` and ``dist_coeff`` look like OpenCV-style
    camera intrinsics and lens distortion coefficients. Confirm against the
    code that builds these objects.
    """

    # Pixel data of the captured image.
    frame: np.ndarray
    # Camera matrix associated with ``frame`` (presumably 3x3 intrinsics; TODO confirm).
    camera_matrix: np.ndarray
    # Distortion coefficients associated with ``frame``.
    dist_coeff: np.ndarray

In [None]:
def process_image(frame_one: ImageWithParams, frame_two: ImageWithParams, inference: Inference):
    """Placeholder for processing a pair of frames together with one inference.

    Not implemented yet: it ignores its arguments and returns ``None``.

    Args:
        frame_one: First frame with its camera parameters.
        frame_two: Second frame with its camera parameters.
        inference: Detection associated with the frames.
    """
    # TODO: implement; intentionally a no-op for now.
    return None

In [None]:
def render_detections(
    frame: np.ndarray,
    inference: Inference,
    color: tuple = (0, 255, 0),
    thickness: int = 2,
):
    """Draw the bounding boxes and label of one detection onto ``frame`` in place.

    ``inference.bounding_box`` is read as a flat sequence of numbers, consumed
    four at a time as ``(x1, y1, x2, y2)`` corner coordinates. Each complete
    group of four is drawn as a rectangle with a ``"<class_name>: <confidence>"``
    label. A trailing partial group (fewer than four values) is skipped instead
    of raising ``ValueError`` during unpacking.

    Args:
        frame: Image array that OpenCV draws into; it is modified in place.
        inference: Detection providing ``bounding_box``, ``class_name`` and
            ``confidence``.
        color: Colour tuple passed to OpenCV for both the box and the label
            (OpenCV channel order is BGR). Defaults to green.
        thickness: Line thickness for both the box and the label. Defaults to 2.

    Returns:
        None. The drawing is the side effect on ``frame``.
    """
    # The label depends only on the inference, not on the individual box.
    label = f"{inference.class_name}: {inference.confidence:.2f}"
    coords = inference.bounding_box

    # `len(coords) - 3` stops before a trailing group of fewer than four values.
    for i in range(0, len(coords) - 3, 4):
        x1, y1, x2, y2 = map(int, coords[i:i + 4])

        cv2.rectangle(frame, (x1, y1), (x2, y2), color=color, thickness=thickness)

        # Put the label just above the box. If that would push it off the top
        # of the image, put it just inside the box's top edge instead.
        label_y = y1 - 10 if y1 >= 25 else y1 + 15
        cv2.putText(
            frame,
            label,
            (x1, label_y),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.5,
            color,
            thickness,
        )

cap = cv2.VideoCapture(0)
nt = await nats.connect("nats://localhost:4222")

queue = asyncio.Queue()
image_id_map = {}

total_time_per_image = 0
total_images = 0
last_reset_time = 40

max_frame_age = 0.1
cur_frame_clean_time = 40

async def on_message(msg: Msg):
    await queue.put(InferenceList.FromString(msg.data))

await nt.subscribe("recognition/image_output", cb=on_message)


try:
    while True:
        if cur_frame_clean_time > 30:
            cur_frame_clean_time = 0
            for image_id in image_id_map:
                if time.time() - image_id_map[image_id]["timestamp"] > max_frame_age:
                    image_id_map.pop(image_id)
            
        if abs(time.time() - last_reset_time) > 5:
            total_time_per_image = 0
            total_images = 0
            last_reset_time = time.time()

        total_images += 1
            
        ret, frame = cap.read()
        if not ret:
            continue

        _, compressed_image = cv2.imencode(".jpg", frame)
        
        image_id = random.randint(0, 1000000)

        msg = ImageMessage(
            image=compressed_image.tobytes(),
            camera_name="camera0",
            is_gray=False,
            id=image_id,
            height=frame.shape[0],
            width=frame.shape[1],
            timestamp=int(time.time() * 1000),
        )

        image_id_map[image_id] = {"frame": frame, "timestamp": time.time()}

        await nt.publish("recognition/image_input", msg.SerializeToString())
        await nt.flush()

        if not queue.empty() and image_id in image_id_map:
            inference = await queue.get()
            for inference in inference.inferences:
                render_detections(image_id_map[image_id]["frame"], inference)
                
            cv2.imshow("frame", image_id_map[image_id]["frame"])
            cv2.waitKey(1)
            total_time_per_image += time.time() - image_id_map[image_id]["timestamp"]
            image_id_map.pop(image_id)
        
        time.sleep(total_time_per_image / total_images if total_images > 0 else 0.02)
except KeyboardInterrupt as e:
    print("Exiting...")
except CancelledError as e:
    print("Exiting...")
finally:
    cap.release()
    cv2.destroyAllWindows()