# APGCC Inference and Visualization Notebook

This notebook is adapted from `predict__v2.py`. It allows for running inference on images or videos, visualizing the detected points, and projecting them onto a 2D map.

## 1. Imports and Environment Setup

In [None]:
import os
os.environ['KMP_DUPLICATE_LIB_OK']='TRUE'
import sys
import torch
import numpy as np
import cv2
import scipy.ndimage
from PIL import Image
import matplotlib.pyplot as plt
import torchvision.transforms as standard_transforms
from glob import glob
from typing import List

# --- Project-specific imports ---
# Add the project root to the Python path to allow for absolute imports
if os.path.abspath(os.getcwd()) not in sys.path:
    sys.path.append(os.path.abspath(os.getcwd()))

from apgcc.tracker import Track, PointTracker
from config import cfg, merge_from_file
from models import build_model

## 2. Configuration

Set the paths to your configuration file, model weights, and input data here. You can also adjust the detection threshold and other parameters.

In [None]:
class Args:
    """Notebook-wide settings: model config, weights, input source, output location and threshold."""

    # YAML file describing the model configuration
    config_file = "./configs/SHHA_test.yml"
    # Checkpoint holding the trained model weights
    weights = "./output/SHHA_best.pth"
    # Input source: one image, a directory of images, or a video file
    input = "C:/Users/abdulna/OneDrive - KAUST/codes/APGCC/apgcc/test_vid/different_lighting_output/input_1.mp4"
    # Where visualized results are written. If None, results are shown on screen.
    output_dir = "./pred_output/video_results"
    # Minimum confidence a point needs to be kept as a detection
    threshold = 0.5


args = Args()

## 3. 3D Projection Setup (Camera Calibration)

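This section sets up a geometric projection based on the camera intrinsics and extrinsics. Each detected pixel is back-projected as a ray and intersected with the world plane at a specified height (e.g., average person height) to obtain its real-world (X, Y) coordinates. The intrinsic and extrinsic values below are examples; replace them with your own calibration results.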

In [None]:
# --- Grab one frame so the principal point (cx, cy) can be placed at the image centre ---
VIDEO_PATH = args.input
cap_calib = cv2.VideoCapture(VIDEO_PATH)
ret_calib, frame_calib = cap_calib.read()
cap_calib.release()  # only the first frame is needed here

if not ret_calib:
    raise ValueError("Failed to read video frame for camera setup")

h, w = frame_calib.shape[0], frame_calib.shape[1]

# --- Step 1: Intrinsic matrix (example values, replace with your calibration results) ---
# NOTE(review): 50/36 and 50/24 look like a 50 mm focal length on a 36 mm x 24 mm sensor - confirm.
K = np.array(
    [
        [(50*w)/36, 0, w / 2],   # fx, 0, cx
        [0, (50*h)/24, h / 2],   # 0, fy, cy
        [0, 0, 1],
    ],
    dtype=np.float32,
)
K_inv = np.linalg.inv(K)

# --- Step 2: Extrinsics (rotation + translation, from your calibration) ---
# Example pose: rotation of 0.3 rad about the x axis and translation (0, 5, 10).
R_vec = np.array([0.3, 0, 0])              # Rotation vector (e.g., from solvePnP)
R, _rodrigues_jacobian = cv2.Rodrigues(R_vec)  # Rotation vector -> 3x3 rotation matrix
t = np.array([0, 5, 10], dtype=np.float32).reshape(3, 1)  # Translation as a column vector
R_T = R.transpose()
# Camera centre in world coordinates, kept as a (3, 1) column vector.
cam_center = -R_T @ t

# --- Step 3: Function to project pixel to a given plane (Z = const) ---
def backproject_to_plane(u, v, plane_height, K_inv, R_T, cam_center):
    """Intersect the viewing ray of pixel (u, v) with the world plane Z = plane_height.

    Args:
        u, v: Pixel coordinates.
        plane_height: Z value of the target plane in world coordinates.
        K_inv: Inverse of the 3x3 camera intrinsic matrix.
        R_T: Transpose of the 3x3 rotation matrix (camera -> world rotation).
        cam_center: Camera centre in world coordinates; may be a flat
            3-vector or a (3, 1) column vector.

    Returns:
        A 1-D array ``[X, Y]`` with the world coordinates of the intersection.
    """
    # Flatten the camera centre: adding a (3, 1) column to the 1-D ray would
    # broadcast to a (3, 3) matrix and mix up the components of the result.
    origin = np.asarray(cam_center, dtype=float).reshape(-1)

    # Pixel in homogeneous form -> ray direction in camera coordinates
    pixel_h = np.array([u, v, 1.0])
    ray_cam = K_inv @ pixel_h

    # Rotate the ray into world coordinates
    ray_world = R_T @ ray_cam

    # Solve origin_z + s * ray_z = plane_height for the ray parameter s
    s = (plane_height - origin[2]) / ray_world[2]
    world_point = origin + s * ray_world
    return world_point[:2]  # Return only (X, Y)

def project_to_map(points, plane_height=1.7):
    """Map image points to scaled world-plane coordinates.

    Each point is back-projected with the module-level camera parameters
    (K_inv, R_T, cam_center) onto the plane at ``plane_height``; the result is
    shifted by +1 in X and multiplied by 4 for display purposes.

    Returns an empty array when ``points`` is empty, otherwise an array with
    one (X, Y) row per input point.
    """
    if len(points) == 0:
        return np.array([])

    def _to_map(point):
        # Only the first two entries are used, so extra fields (e.g. an ID) are ignored
        u, v = point[:2]
        wp = backproject_to_plane(u, v, plane_height, K_inv, R_T, cam_center)
        # Offset and scale factor chosen for better visualization
        return np.array([wp[0] + 1, wp[1]]) * 4

    return np.array([_to_map(p) for p in points])

def draw_on_map(map_points, size=(w, h), scale=20):
    """Render map points as red dots on a white canvas, show it, and return it.

    Args:
        map_points: Iterable of (x, y) map coordinates.
        size: Canvas size as (width, height) in pixels.
        scale: Pixels per map unit; x is additionally shifted by +10 units.

    Returns:
        The BGR canvas as a uint8 array of shape (height, width, 3).
    """
    width, height = size[0], size[1]
    canvas = np.full((height, width, 3), 255, dtype=np.uint8)

    for x, y in map_points:
        col = int((x + 10) * scale)
        row = int(y * scale)
        # Skip anything that falls outside the canvas
        if not (0 <= col < width and 0 <= row < height):
            continue
        cv2.circle(canvas, (col, row), 5, (0, 0, 255), -1)  # filled red dot (BGR)

    # Show the map and let the GUI event loop run briefly
    cv2.imshow("Projected 2D Map", canvas)
    cv2.waitKey(1)
    return canvas

def draw_heatmap(map_points, size=(w, h), scale=20, writer=None, sigma=30):
    """Accumulate map points into a smoothed density heatmap, show it, and return it.

    Args:
        map_points: Iterable of (x, y) map coordinates.
        size: Heatmap size as (width, height) in pixels.
        scale: Pixels per map unit; x is additionally shifted by +10 units.
        writer: Optional object with a ``write(frame)`` method (e.g. a
            ``cv2.VideoWriter``) that receives the coloured heatmap.
        sigma: Standard deviation of the Gaussian smoothing, in pixels.

    Returns:
        The colour-mapped heatmap (JET colormap) as a uint8 image.
    """
    heatmap = np.zeros((size[1], size[0]), dtype=np.float32)
    for x, y in map_points:
        px, py = int((x+10) * scale), int(y * scale)
        if 0 <= px < size[0] and 0 <= py < size[1]:
            # Flip vertically so larger y is drawn higher up. The "- 1" keeps
            # py == 0 on the last valid row instead of indexing one past the end.
            heatmap[size[1] - 1 - py, px] += 1.0

    # Smooth the heatmap and normalise to [0, 1]; skipped when no point landed inside
    if heatmap.max() > 0:
        heatmap = scipy.ndimage.gaussian_filter(heatmap, sigma=sigma)
        heatmap = np.clip(heatmap / heatmap.max(), 0, 1)

    heatmap_color = cv2.applyColorMap((heatmap * 255).astype(np.uint8), cv2.COLORMAP_JET)

    if writer is not None:
        writer.write(heatmap_color)

    # Show the heatmap and let the GUI event loop run briefly
    cv2.imshow("Projected Heatmap", heatmap_color)
    cv2.waitKey(1)
    return heatmap_color

## 4. Helper Functions

These are utility functions for preprocessing, inference, visualization, and point merging.

In [22]:
def max_by_axis_pad(the_list: List[List[int]]) -> List[int]:
    """Return the per-axis maximum of several shapes, padded for batching.

    Args:
        the_list: Non-empty list of shapes; each shape needs at least three
            entries (the last two padded axes are indices 1 and 2).

    Returns:
        A new list with the element-wise maximum of all shapes, where the
        entries at index 1 and 2 are rounded up to a multiple of 128.
        The input lists are left untouched.
    """
    # Copy the first shape so the caller's list is not modified in place
    maxes = list(the_list[0])
    for sublist in the_list[1:]:
        for index, item in enumerate(sublist):
            maxes[index] = max(maxes[index], item)
    # Pad to a multiple of 128, mimicking training behavior
    block = 128
    for i in range(2):
        maxes[i+1] = ((maxes[i+1] - 1) // block + 1) * block
    return maxes

def nested_tensor_from_tensor_list(tensor_list: List[torch.Tensor]):
    """
    Stack 3-D tensors of differing sizes into one zero-padded batch tensor.

    The batch shape comes from max_by_axis_pad; every input is copied into the
    top-left corner of its slot and the remainder stays zero. Raises
    ValueError when the first tensor is not 3-dimensional.
    """
    first = tensor_list[0]
    if first.ndim != 3:
        raise ValueError('not supported')

    padded_shape = max_by_axis_pad([list(t.shape) for t in tensor_list])
    batch = torch.zeros(
        [len(tensor_list)] + padded_shape,
        dtype=first.dtype,
        device=first.device,
    )
    for slot, source in zip(batch, tensor_list):
        dims = source.shape
        slot[: dims[0], : dims[1], : dims[2]].copy_(source)
    return batch

def preprocess_image(image_input, cfg):
    """
    Loads and preprocesses a single image for model inference.

    Args:
        image_input: A file path (str) or a PIL Image object.
        cfg: Configuration object; ``cfg.DATALOADER.UPPER_BOUNDER`` is the
            maximum allowed longest side (-1 disables this limit).

    Returns:
        (display_img, img_tensor): the RGB PIL image at the size fed to the
        model, and the normalized image tensor.

    Steps:
    - Converts to RGB
    - Scales if the image is too large (mimicking validation logic)
    - Converts to a Tensor and normalizes
    """
    if isinstance(image_input, str):
        img = Image.open(image_input).convert('RGB')
    else:
        img = image_input.convert('RGB')

    # --- Mimic validation scaling from dataset.py ---
    # The longest side is read straight from the PIL image; there is no need
    # to convert the whole image to a tensor just to inspect its dimensions.
    max_size = max(img.size)
    scale = 1.0
    upper_bound = cfg.DATALOADER.UPPER_BOUNDER

    if upper_bound != -1 and max_size > upper_bound:
        scale = upper_bound / max_size
    elif max_size > 2560:  # A reasonable default from the original codebase
        scale = 2560 / max_size

    # --- Define transforms ---
    transform_list = []
    if scale != 1.0:
        new_w = int(img.width * scale)
        new_h = int(img.height * scale)
        transform_list.append(standard_transforms.Resize((new_h, new_w)))

    transform_list.extend([
        standard_transforms.ToTensor(),
        # Fixed per-channel mean/std normalization
        standard_transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    transform = standard_transforms.Compose(transform_list)

    # The display image is resized to the same size as the model input so that
    # predicted point coordinates line up with it.
    display_img = img
    if scale != 1.0:
        display_img = display_img.resize((new_w, new_h))

    img_tensor = transform(img)
    return display_img, img_tensor

@torch.no_grad()
def run_inference(model, image_tensor, threshold):
    """
    Run the model on one preprocessed image and keep the confident points.

    The model is switched to eval mode, the tensor is moved to the model's
    device and wrapped in a padded batch of size one. Points whose softmax
    score for class index 1 exceeds ``threshold`` are returned as a NumPy
    array together with their number.
    """
    model.eval()
    target_device = next(model.parameters()).device

    batch = nested_tensor_from_tensor_list([image_tensor.to(target_device)])
    predictions = model(batch)

    # Softmax over the last axis, then take class index 1 of the only batch element
    class_probs = torch.nn.functional.softmax(predictions['pred_logits'], -1)
    scores = class_probs[:, :, 1][0]
    candidates = predictions['pred_points'][0]

    keep = scores > threshold
    kept_points = candidates[keep].detach().cpu().numpy()
    return kept_points, len(kept_points)

def visualize_results(image_to_display, points, count, image_path, output_dir=None):
    """
    Plot the image with its predicted points using matplotlib.

    With ``output_dir`` set, the figure is saved there as ``pred_<file name>``
    (the directory is created if missing) and closed; otherwise it is shown
    on screen.
    """
    plt.figure(figsize=(12, 8))
    plt.imshow(image_to_display)
    plt.title(f'Predicted Count: {count}')

    if len(points) > 0:
        first = points[0]
        # Entries that are lists/tuples with more than two fields are treated as
        # tracked records whose first field is the point itself
        carries_extra_fields = isinstance(first, (list, tuple)) and len(first) > 2
        if carries_extra_fields:
            coords = np.array([record[0] for record in points])
        else:
            coords = np.array(points)
        plt.scatter(coords[:, 0], coords[:, 1], c='red', s=15, marker='o', alpha=0.8, edgecolors='none')

    plt.axis('off')

    if not output_dir:
        plt.show()
        return

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    output_path = os.path.join(output_dir, f"pred_{os.path.basename(image_path)}")
    plt.savefig(output_path, bbox_inches='tight', pad_inches=0.1)
    print(f"Result saved to {output_path}")
    plt.close()  # Release the figure's memory

def merge_close_points(points, threshold):
    """Greedily collapse points lying closer than ``threshold`` to an earlier point.

    Points are visited in order. Each point not yet absorbed becomes a
    representative and absorbs every later, still-free point whose Euclidean
    distance to it is below ``threshold``.

    Returns an empty list for empty input, otherwise the list of
    representative points (as NumPy rows) in their original order.
    """
    pts = np.array(points)
    total = len(pts)
    if total == 0:
        return []

    absorbed = [False] * total
    representatives = []

    for idx, anchor in enumerate(pts):
        if absorbed[idx]:
            continue
        absorbed[idx] = True
        # The first point of each cluster stands in for the whole cluster
        representatives.append(anchor)

        for later in range(idx + 1, total):
            if absorbed[later]:
                continue
            if np.linalg.norm(anchor - pts[later]) < threshold:
                absorbed[later] = True

    return representatives

## 5. Model Loading

This cell sets up the device (GPU or CPU), builds the model architecture based on the config, and loads the pre-trained weights.

In [23]:
# --- Resolve the configuration: merge the file into the defaults when one is given ---
cfg_ = merge_from_file(cfg, args.config_file) if args.config_file != "" else cfg

# --- Build the model on the GPU when one is available, otherwise on the CPU ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
model = build_model(cfg=cfg_, training=False)
model.to(device)

# --- Restore the trained weights ---
if not os.path.exists(args.weights):
    raise FileNotFoundError(f"Weight file not found: {args.weights}")

# The checkpoint is read onto the CPU first; load_state_dict copies it into the model
state_dict = torch.load(args.weights, map_location='cpu')
model.load_state_dict(state_dict)
print("Model weights loaded successfully.")

Using device: cuda
### VGG16: last_pool= False
ultra_pe
auxiliary anchors: (pos, neg, lambda, kwargs)  [2, 2] [2, 8] {'pos_coef': 1.0, 'neg_coef': 1.0, 'pos_loc': 0.0002, 'neg_loc': 0.0002}
Model weights loaded successfully.


  state_dict = torch.load(args.weights, map_location='cpu')


## 6. Main Processing

This is the main execution block. It checks if the input is a video or image(s) and processes them accordingly.

In [24]:
# --- Find Input Files ---
# The input is treated as a video when it is an existing file with a known video extension;
# anything else is handled as an image file or a directory of images.
input_path = args.input
is_video = os.path.isfile(input_path) and input_path.lower().endswith(('.mp4', '.avi', '.mov'))

if is_video:
    # --- Process Video File ---
    video_path = input_path
    print(f"Processing video: {os.path.basename(video_path)}")
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise IOError(f"Could not open video: {video_path}")

    out = None
    # try/finally guarantees that the capture, the writer and the windows are released
    # even when the loop is interrupted (e.g. KeyboardInterrupt); otherwise the output
    # video would be left unfinalised.
    try:
        if args.output_dir:
            os.makedirs(args.output_dir, exist_ok=True)
            output_video_path = os.path.join(args.output_dir, os.path.splitext(os.path.basename(video_path))[0] + "_output.mp4")
            fourcc = cv2.VideoWriter_fourcc(*'mp4v') # Use a widely supported codec
            fps = cap.get(cv2.CAP_PROP_FPS)
            frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            # Output is twice as wide: the annotated frame and the 2D map side by side
            out = cv2.VideoWriter(output_video_path, fourcc, fps, (frame_width * 2, frame_height))

        frame_count = 0
        tracker = PointTracker(distance_threshold=100) # Simple distance-based tracker (currently unused below)
        collected_points = []
        displayed_points = np.empty((0, 2))
        window = 2  # Number of frames to collect points before merging

        # Create named windows that can be resized
        cv2.namedWindow("Combined Output", cv2.WINDOW_NORMAL)
        cv2.namedWindow("Projected Heatmap", cv2.WINDOW_NORMAL)

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            frame_count += 1
            print(f"Processing frame {frame_count}")

            pil_img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            display_img, img_tensor = preprocess_image(pil_img, cfg_)
            points, count = run_inference(model, img_tensor, args.threshold)

            # preprocess_image may downscale large frames, in which case the predicted
            # points live in the downscaled image. Map them back to original-frame
            # pixels, which is what the drawing and the camera projection below use.
            if len(points) > 0 and display_img.size != pil_img.size:
                points[:, 0] = points[:, 0] * (frame.shape[1] / display_img.width)
                points[:, 1] = points[:, 1] * (frame.shape[0] / display_img.height)

            # Collect points over a window of frames and then merge to stabilize
            # NOTE(review): on frames where frame_count % window == 0 the current
            # detections are not added to collected_points - confirm this is intended.
            if len(displayed_points) == 0:
                displayed_points = np.array(points)
            if frame_count % window:
                collected_points.extend(points.tolist())
            else:
                displayed_points = merge_close_points(collected_points, threshold=5)
                collected_points = []

            # Track points (optional, using simple tracker for now)
            # tracked_points = tracker.update(points)

            # --- Visualization ---
            map_coords = project_to_map(displayed_points, plane_height=1.7) # Assume avg person height is 1.7m
            points_map = draw_on_map(map_coords, size=(w,h))
            heatmap = draw_heatmap(map_coords, size=(w,h))

            # Draw points on the original frame
            for point in displayed_points:
                x, y = int(point[0]), int(point[1])
                cv2.circle(frame, (x, y), 3, (0, 0, 255), -1)  # Red circles

            cv2.putText(frame, f"Count: {len(displayed_points)}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)

            # Combine frame and 2D map side by side for display and saving
            # (the heatmap is shown in its own window by draw_heatmap)
            combined_frame = np.concatenate((frame, points_map), axis=1)

            cv2.imshow('Combined Output', combined_frame)
            # Pressing 'q' stops processing early
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break

            if out is not None:
                out.write(combined_frame)
    finally:
        cap.release()
        if out is not None:
            out.release()
        cv2.destroyAllWindows()

    if out is not None:
        print(f"Video with predictions saved to {output_video_path}")

else:
    # --- Process Image(s) ---
    if os.path.isdir(input_path):
        image_files = []
        supported_extensions = ['*.jpg', '*.jpeg', '*.png']
        for ext in supported_extensions:
            image_files.extend(glob(os.path.join(input_path, ext)))
        image_files = sorted(image_files)
    elif os.path.isfile(input_path):
        image_files = [input_path]
    else:
        raise FileNotFoundError(f"Input not found: {input_path}")

    print(f"Found {len(image_files)} image(s) to process.")

    for image_path in image_files:
        print(f"\nProcessing: {os.path.basename(image_path)}")
        display_img, img_tensor = preprocess_image(image_path, cfg_)
        points, count = run_inference(model, img_tensor, args.threshold)
        visualize_results(display_img, points, count, image_path, args.output_dir)


Processing video: input_1.mp4
Processing frame 1
Processing frame 2
Processing frame 3
Processing frame 4
Processing frame 5
Processing frame 6
Processing frame 7
Processing frame 8
Processing frame 9
Processing frame 10
Processing frame 11
Processing frame 12
Processing frame 13
Processing frame 14
Processing frame 15
Processing frame 16
Processing frame 17
Processing frame 18
Processing frame 19
Processing frame 20
Processing frame 21
Processing frame 22
Processing frame 23
Processing frame 24
Processing frame 25
Processing frame 26
Processing frame 27
Processing frame 28
Processing frame 29
Processing frame 30
Processing frame 31
Processing frame 32
Processing frame 33
Processing frame 34
Processing frame 35
Processing frame 36
Processing frame 37
Processing frame 38
Processing frame 39
Processing frame 40
Processing frame 41
Processing frame 42


KeyboardInterrupt: 