In [14]:
import cv2
import numpy as np
from collections import deque
import matplotlib.pyplot as plt
from ai_edge_litert.interpreter import Interpreter

# ==== Config ====
# Paths are relative to the directory the notebook/script is run from.
TFLITE_MODEL_PATH = "build/csrnet_mobile_B_float16.tflite"
VIDEO_PATH = "inference/test_videos/crowd_video_test1.mp4"
# Maximum number of recent density maps averaged together for smoothing.
QUEUE_SIZE = 10
# Inference runs only on frames whose 1-based index is a multiple of this
# value, so 1 means "process every frame".
FRAME_SKIP = 1
# When True, each processed frame is shown with a density heat-map overlay.
VISUALIZE = False

# ==== Load TFLite Model ====
interpreter = Interpreter(model_path=TFLITE_MODEL_PATH)
interpreter.allocate_tensors()
input_details, output_details = (
    interpreter.get_input_details(),
    interpreter.get_output_details(),
)
input_shape = input_details[0]['shape']
# Layout heuristic: a size-3 second axis is taken to mean the channel axis
# comes right after the batch axis (NCHW); anything else is treated as NHWC.
is_channels_first = input_shape[1] == 3

# ==== Preprocessing ====
def preprocess(frame):
    """Convert a video frame into the model's normalized input tensor.

    The frame is resized to the spatial size reported by the loaded model's
    input tensor (falling back to 512x512 if a reported dimension is not
    positive), scaled to [0, 1], normalized per channel with the mean/std
    constants below, and given a leading batch axis.

    Args:
        frame: Three-dimensional image array with the channel axis last,
            e.g. a frame returned by ``cv2.VideoCapture.read``.

    Returns:
        A float32 array shaped (1, 3, H, W) when ``is_channels_first`` is
        true, otherwise (1, H, W, 3).
    """
    # Take the target size from the model instead of hard-coding it, so a
    # model exported at a different resolution does not fail at set_tensor.
    if is_channels_first:
        height, width = int(input_shape[2]), int(input_shape[3])
    else:
        height, width = int(input_shape[1]), int(input_shape[2])
    if height <= 0 or width <= 0:
        height, width = 512, 512

    # float32 constants keep the arithmetic in float32; subtracting plain
    # Python lists would promote the whole image to a float64 temporary.
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)

    # cv2.resize takes the target size as (width, height).
    img = cv2.resize(frame, (width, height)).astype(np.float32) / 255.0
    # NOTE(review): these look like the ImageNet statistics in RGB order,
    # while OpenCV frames are BGR. Channel order is left untouched here to
    # preserve existing behavior -- confirm what the model was trained on.
    img = (img - mean) / std

    if is_channels_first:
        img = np.transpose(img, (2, 0, 1))  # HWC -> CHW
    return np.expand_dims(img, axis=0).astype(np.float32)

# ==== Temporal smoothing of per-frame density maps ====
# Constant subtracted from the summed density to obtain the reported count.
# NOTE(review): presumably an empirical bias correction for this model/video;
# confirm its origin before reusing it on other inputs.
COUNT_OFFSET = 90

# Sliding window holding at most QUEUE_SIZE of the most recent density maps.
density_queue = deque(maxlen=QUEUE_SIZE)

cap = cv2.VideoCapture(VIDEO_PATH)
if not cap.isOpened():
    # Without this check a bad path makes the loop exit silently on the
    # first read, producing no output and no hint about what went wrong.
    raise RuntimeError(f"Could not open video: {VIDEO_PATH}")
frame_count = 0

try:
    while True:
        ret, frame = cap.read()
        if not ret:
            break

        frame_count += 1
        if frame_count % FRAME_SKIP != 0:
            continue

        input_tensor = preprocess(frame)
        interpreter.set_tensor(input_details[0]['index'], input_tensor)
        interpreter.invoke()
        output = interpreter.get_tensor(output_details[0]['index'])
        # Drop all size-1 axes from the model output.
        density_map = output.squeeze()
        density_queue.append(density_map)

        # Element-wise mean over the maps currently in the window.
        smoothed_density = np.mean(np.stack(density_queue), axis=0)
        predicted_count = smoothed_density.sum() - COUNT_OFFSET

        print(f"Frame {frame_count}: Estimated Count = {predicted_count:.2f}")

        if VISUALIZE:
            peak = smoothed_density.max()
            if peak > 0:
                scaled = smoothed_density / peak * 255
            else:
                # An all-zero (or all-negative) map would otherwise divide
                # by zero or flip signs; show it as an empty heat map.
                scaled = np.zeros_like(smoothed_density)
            # Clip before the uint8 cast so negative densities do not wrap.
            vis = np.clip(scaled, 0, 255).astype(np.uint8)
            vis = cv2.applyColorMap(vis, cv2.COLORMAP_JET)
            vis = cv2.resize(vis, (frame.shape[1], frame.shape[0]))
            overlay = cv2.addWeighted(frame, 0.6, vis, 0.4, 0)
            overlay_rgb = cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB)

            fig = plt.figure(figsize=(8, 6))
            plt.imshow(overlay_rgb)
            plt.title(f"Frame {frame_count}: Count = {predicted_count:.1f}")
            plt.axis('off')
            plt.show()
            # Close explicitly so figures do not pile up across frames.
            plt.close(fig)
finally:
    # Release the capture even if the loop is interrupted (e.g. Ctrl-C /
    # kernel interrupt) or inference raises.
    cap.release()

Frame 1: Estimated Count = 31.53
Frame 2: Estimated Count = 33.67
Frame 3: Estimated Count = 33.18
Frame 4: Estimated Count = 34.69
Frame 5: Estimated Count = 34.21
Frame 6: Estimated Count = 33.93
Frame 7: Estimated Count = 33.54
Frame 8: Estimated Count = 33.48
Frame 9: Estimated Count = 33.14
Frame 10: Estimated Count = 32.55
Frame 11: Estimated Count = 32.73
Frame 12: Estimated Count = 32.12
Frame 13: Estimated Count = 31.53
Frame 14: Estimated Count = 30.00
Frame 15: Estimated Count = 28.79
Frame 16: Estimated Count = 27.61
Frame 17: Estimated Count = 26.72
Frame 18: Estimated Count = 26.60
Frame 19: Estimated Count = 26.08
Frame 20: Estimated Count = 26.54
Frame 21: Estimated Count = 27.20
Frame 22: Estimated Count = 27.73
Frame 23: Estimated Count = 29.26
Frame 24: Estimated Count = 30.80
Frame 25: Estimated Count = 32.67
Frame 26: Estimated Count = 34.18
Frame 27: Estimated Count = 35.36
Frame 28: Estimated Count = 35.14
Frame 29: Estimated Count = 35.65
Frame 30: Estimated Cou

KeyboardInterrupt: 