# Pallet + Box Detection/Segmentation/Tracking Pipeline (Colab-Ready)

This notebook implements a **single end-to-end production pipeline**, optimized for Google Colab (Free/Pro), built on:
- **Grounding DINO** (`IDEA-Research/grounding-dino-base`) for text-conditioned detection.
- **SAM ViT-B** for segmentation.
- **Custom centroid tracker** for persistent IDs across frames.

It produces:
1. Annotated output video.
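2. `counts.json` with per-frame tracked object IDs and counts (only `box` detections are tracked and counted; `pallet` detections are drawn but not counted).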
3. Interactive analytics section with Matplotlib + Plotly.

## BLOCK 1 — Environment Setup

In [None]:
# =========================
# BLOCK 1 — Environment Setup
# =========================
import subprocess
from pathlib import Path

def run_cmd(cmd: str) -> None:
    """Echo a shell command and execute it, failing loudly on a non-zero exit.

    Args:
        cmd: Full command line; it is interpreted by the shell.

    Raises:
        subprocess.CalledProcessError: if the command exits with a non-zero status.
    """
    print(f"[CMD] {cmd}")
    subprocess.run(cmd, check=True, shell=True)

run_cmd("pip -q install --upgrade pip")
run_cmd("pip -q install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121")
run_cmd("pip -q install transformers accelerate opencv-python tqdm matplotlib plotly imageio pandas")
run_cmd("pip -q install git+https://github.com/facebookresearch/segment-anything.git")

# Local checkout of the SAM repository. NOTE(review): BLOCK 3 imports `segment_anything`,
# presumably from the pip install above; nothing in this notebook adds this clone to sys.path.
if not Path('/content/segment-anything').exists():
    run_cmd('git clone -q https://github.com/facebookresearch/segment-anything.git /content/segment-anything')

SAM_CKPT_URL = 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth'
sam_ckpt = Path('/content/sam_vit_b_01ec64.pth')
if not sam_ckpt.exists():
    # Download to a temporary name and rename only after wget succeeds. Otherwise an
    # interrupted download would leave a truncated checkpoint that the exists() check
    # above would happily accept on every rerun.
    tmp_ckpt = sam_ckpt.with_name(sam_ckpt.name + '.part')
    run_cmd(f'wget -q -O {tmp_ckpt} {SAM_CKPT_URL}')
    tmp_ckpt.replace(sam_ckpt)

from google.colab import drive
drive.mount('/content/drive', force_remount=True)
print('[INFO] Setup complete.')

## BLOCK 2 — Global Configuration

In [None]:
# =========================
# BLOCK 2 — Global Configuration
# =========================
import json
import gc
import torch
from pathlib import Path

PROJECT_ROOT = Path('/content/drive/MyDrive/pallet_cv_project')
INPUT_VIDEO = PROJECT_ROOT / 'input' / 'drone_video.mp4'
OUTPUT_DIR = PROJECT_ROOT / 'results'
SAM_CHECKPOINT = Path('/content/sam_vit_b_01ec64.pth')

# Pipeline knobs.
RESIZE_FACTOR = 1.0            # scale applied to each frame before inference and output (1.0 = native)
PROCESS_EVERY_N_FRAMES = 1     # run detection/tracking every N-th frame; skipped frames reuse tracker state
LOG_EVERY_N_FRAMES = 25        # progress-print interval, in frames
DRAW_MASK_OVERLAY = False      # blend SAM masks into the annotated video
BOX_THRESHOLD = 0.30           # Grounding DINO box-confidence threshold
TEXT_THRESHOLD = 0.25          # Grounding DINO text-token threshold

# Fail fast on settings that would otherwise only crash deep inside the video loop
# (modulo by zero in the frame-skip/log checks, or non-positive output frame sizes).
if RESIZE_FACTOR <= 0:
    raise ValueError(f'RESIZE_FACTOR must be > 0, got {RESIZE_FACTOR}')
if int(PROCESS_EVERY_N_FRAMES) < 1:
    raise ValueError(f'PROCESS_EVERY_N_FRAMES must be >= 1, got {PROCESS_EVERY_N_FRAMES}')
if int(LOG_EVERY_N_FRAMES) < 1:
    raise ValueError(f'LOG_EVERY_N_FRAMES must be >= 1, got {LOG_EVERY_N_FRAMES}')

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
GPU_NAME = torch.cuda.get_device_name(0) if DEVICE == 'cuda' else 'CPU'
GPU_MEM_GB = (torch.cuda.get_device_properties(0).total_memory / (1024**3)) if DEVICE == 'cuda' else 0

# Both flags are enabled only on GPUs with more than 16 GB of VRAM.
# NOTE(review): ALLOW_FULL_RES is only printed below; no later block in this notebook reads it.
ALLOW_FULL_RES = GPU_MEM_GB > 16
USE_AMP = GPU_MEM_GB > 16 and DEVICE == 'cuda'

if DEVICE == 'cuda':
    # Let cuDNN pick the fastest kernels and allow TF32 matmuls for throughput.
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.set_float32_matmul_precision('high')

OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
print(f'[INFO] DEVICE={DEVICE} | GPU={GPU_NAME} | VRAM={GPU_MEM_GB:.1f} GB')
print(f'[INFO] Input video: {INPUT_VIDEO}')
print(f'[INFO] Output dir : {OUTPUT_DIR}')
print(f'[INFO] ALLOW_FULL_RES={ALLOW_FULL_RES} | USE_AMP={USE_AMP}')

## BLOCK 3 — Model Loading

In [None]:
# =========================
# BLOCK 3 — Model Loading
# =========================
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
from segment_anything import sam_model_registry, SamPredictor

DINO_MODEL_ID = 'IDEA-Research/grounding-dino-base'

try:
    # Check up front so a missing/failed download gives an actionable message
    # instead of an opaque checkpoint-loading traceback.
    if not SAM_CHECKPOINT.exists():
        raise FileNotFoundError(f'SAM checkpoint not found at {SAM_CHECKPOINT}; re-run BLOCK 1 to download it.')

    print('[INFO] Loading Grounding DINO...')
    dino_processor = AutoProcessor.from_pretrained(DINO_MODEL_ID)
    dino_model = AutoModelForZeroShotObjectDetection.from_pretrained(DINO_MODEL_ID).to(DEVICE).eval()

    print('[INFO] Loading SAM ViT-B...')
    sam = sam_model_registry['vit_b'](checkpoint=str(SAM_CHECKPOINT))
    sam.to(device=DEVICE)
    sam.eval()
    # The predictor wraps the SAM model; later blocks call set_image()/predict() on it.
    sam_predictor = SamPredictor(sam)
    print('[INFO] Models loaded successfully.')
except Exception as e:
    # Chain the original exception so the failing loader's traceback is preserved.
    raise RuntimeError(f'Model loading failed: {e}') from e

## BLOCK 4 — Utility Functions

In [None]:
# =========================
# BLOCK 4 — Utility Functions
# =========================
import cv2
import numpy as np
import inspect
from dataclasses import dataclass
from typing import List, Dict, Tuple


def cleanup_cuda():
    """Run Python garbage collection and, on CUDA, release cached allocator blocks."""
    gc.collect()
    if DEVICE != 'cuda':
        return
    torch.cuda.empty_cache()


def detect_objects(image_bgr: np.ndarray, prompt: str,
                   box_threshold: float = BOX_THRESHOLD,
                   text_threshold: float = TEXT_THRESHOLD) -> Tuple[np.ndarray, List[str], np.ndarray]:
    """Run Grounding DINO on one BGR frame for a free-text prompt.

    Compatible across transformers versions whose post-processing takes either
    ``box_threshold`` or ``threshold``, and whose results expose matched phrases as
    ``text_labels`` and/or ``labels``.

    Args:
        image_bgr: Frame in OpenCV BGR channel order.
        prompt: Text query such as ``'pallet'``. It is lower-cased and terminated with
            a period before being sent to the processor (the query format Grounding
            DINO expects).
        box_threshold: Minimum box confidence to keep a detection.
        text_threshold: Minimum token confidence used when assigning text labels.

    Returns:
        ``(boxes, labels, scores)``. ``boxes`` is an ``(N, 4)`` float array of
        ``(x1, y1, x2, y2)`` coordinates scaled to ``image_bgr``'s size. ``labels`` is
        a list of matched phrases and may be empty. ``scores`` is an ``(N,)`` array.
    """
    # Hugging Face's Grounding DINO docs require lower-case queries that end with '.';
    # a bare 'pallet' does not follow that format.
    text = prompt.strip().lower()
    if not text.endswith('.'):
        text += '.'

    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    # torch.autocast replaces the deprecated torch.cuda.amp.autocast; USE_AMP is only
    # ever True on CUDA, so the CPU branch is always a disabled no-op context.
    amp_device = 'cuda' if DEVICE == 'cuda' else 'cpu'
    with torch.no_grad():
        inputs = dino_processor(images=image_rgb, text=text, return_tensors='pt').to(DEVICE)
        with torch.autocast(device_type=amp_device, enabled=USE_AMP):
            outputs = dino_model(**inputs)

        # (height, width) so boxes come back in this frame's pixel coordinates.
        target_sizes = torch.tensor([image_rgb.shape[:2]], device=DEVICE)

        # Compatibility: different transformers versions expose different argument names.
        pp_fn = dino_processor.post_process_grounded_object_detection
        pp_sig = inspect.signature(pp_fn)
        kwargs = {
            'outputs': outputs,
            'input_ids': inputs.input_ids,
            'target_sizes': target_sizes,
        }
        if 'box_threshold' in pp_sig.parameters:
            kwargs['box_threshold'] = box_threshold
        elif 'threshold' in pp_sig.parameters:
            kwargs['threshold'] = box_threshold

        if 'text_threshold' in pp_sig.parameters:
            kwargs['text_threshold'] = text_threshold

        results = pp_fn(**kwargs)[0]

    boxes = results['boxes'].detach().cpu().numpy() if len(results['boxes']) else np.empty((0, 4), dtype=np.float32)
    # Newer transformers releases return the phrases under 'text_labels'; prefer it when present.
    if 'text_labels' in results:
        labels = list(results['text_labels'])
    elif 'labels' in results:
        labels = list(results['labels'])
    else:
        labels = []
    scores = results['scores'].detach().cpu().numpy() if len(results['scores']) else np.empty((0,), dtype=np.float32)
    del inputs, outputs, results, target_sizes
    cleanup_cuda()
    return boxes, labels, scores


def segment_boxes(image_bgr: np.ndarray, boxes_xyxy: np.ndarray) -> List[np.ndarray]:
    """Prompt SAM with each box and return one mask per box.

    Args:
        image_bgr: Frame in OpenCV BGR channel order.
        boxes_xyxy: Boxes as ``(x1, y1, x2, y2)`` in pixel coordinates of ``image_bgr``.

    Returns:
        SAM's single-output mask for each box, cast to ``uint8``, in the same order as
        ``boxes_xyxy``. Returns an empty list when no boxes are given.
    """
    if len(boxes_xyxy) == 0:
        return []

    rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    # set_image() is called once; every box prompt below reuses that image.
    sam_predictor.set_image(rgb)

    def _mask_for(box) -> np.ndarray:
        # Ask SAM for one mask per box (multimask_output=False) and keep only that mask.
        box_prompt = np.array(box, dtype=np.float32)[None, :]
        predicted, _scores, _logits = sam_predictor.predict(
            point_coords=None,
            point_labels=None,
            box=box_prompt,
            multimask_output=False,
        )
        return predicted[0].astype(np.uint8)

    result = [_mask_for(b) for b in boxes_xyxy]
    del rgb
    cleanup_cuda()
    return result


def filter_masks(masks: List[np.ndarray], min_area: int = 120, max_area_ratio: float = 0.8) -> List[np.ndarray]:
    """Keep only masks whose pixel area lies in a plausible range.

    Args:
        masks: Masks whose sum is treated as their area; the image size is taken from
            the first mask.
        min_area: Smallest accepted area (inclusive).
        max_area_ratio: Largest accepted area as a fraction of the image area (inclusive).

    Returns:
        The accepted masks, in their original order.
    """
    if not masks:
        return []
    height, width = masks[0].shape[:2]
    upper_bound = int(height * width * max_area_ratio)
    kept = []
    for mask in masks:
        area = int(mask.sum())
        if min_area <= area <= upper_bound:
            kept.append(mask)
    return kept


def get_centroid(mask: np.ndarray):
    """Return the mean ``(x, y)`` of the mask's positive pixels, truncated to ints.

    Returns ``None`` when the mask has no positive pixels.
    """
    row_idx, col_idx = np.nonzero(mask > 0)
    if col_idx.size == 0:
        return None
    return int(col_idx.mean()), int(row_idx.mean())


@dataclass
class TrackState:
    """Mutable per-track state held by ``CentroidTracker``."""
    # Last matched (x, y) centroid.
    centroid: Tuple[int, int]
    # Number of consecutive updates in which this track was not matched.
    disappeared: int = 0


class CentroidTracker:
    """Nearest-centroid tracker that assigns persistent integer IDs.

    On each ``update`` it does three things:
    - Match incoming centroids to existing tracks by Euclidean distance (globally greedy,
      closest pairs first, gated by ``max_distance``).
    - Register unmatched centroids as new tracks.
    - Drop tracks that stay unmatched for more than ``max_disappeared`` consecutive updates.
    """

    def __init__(self, max_disappeared: int = 15, max_distance: float = 80.0):
        """
        Args:
            max_disappeared: Consecutive misses tolerated before a track is removed.
            max_distance: Maximum centroid displacement (pixels) accepted as a match.
        """
        self.next_id = 0
        self.objects: Dict[int, TrackState] = {}
        self.max_disappeared = max_disappeared
        self.max_distance = max_distance

    def _register(self, centroid):
        # Store plain Python ints so centroids are safe for OpenCV drawing and JSON.
        self.objects[self.next_id] = TrackState(centroid=(int(centroid[0]), int(centroid[1])), disappeared=0)
        self.next_id += 1

    def _deregister(self, object_id):
        self.objects.pop(object_id, None)

    def _mark_missing(self, object_id):
        """Count one miss for a track and remove it once it exceeds ``max_disappeared``."""
        state = self.objects[object_id]
        state.disappeared += 1
        if state.disappeared > self.max_disappeared:
            self._deregister(object_id)

    def update(self, input_centroids: List[Tuple[int, int]]):
        """Associate this frame's centroids with tracks.

        Args:
            input_centroids: ``(x, y)`` centroids detected in the current frame.

        Returns:
            ``self.objects``, mapping track ID to ``TrackState``. It includes tracks that
            are currently unmatched but have not yet expired.
        """
        if len(input_centroids) == 0:
            for object_id in list(self.objects):
                self._mark_missing(object_id)
            return self.objects

        if not self.objects:
            for c in input_centroids:
                self._register(c)
            return self.objects

        object_ids = list(self.objects.keys())
        object_centroids = np.asarray([self.objects[i].centroid for i in object_ids], dtype=np.float64)
        new_centroids = np.asarray(input_centroids, dtype=np.float64).reshape(-1, 2)
        D = np.linalg.norm(object_centroids[:, None, :] - new_centroids[None, :, :], axis=2)
        n_rows, n_cols = D.shape

        # Globally greedy assignment: visit every (track, detection) pair from closest to
        # farthest. A track whose nearest detection is already taken can still claim its
        # next-nearest free detection. A per-track argmin would instead leave that track
        # unmatched and register a spurious new ID for the free detection.
        used_rows, used_cols = set(), set()
        for flat_idx in np.argsort(D, axis=None, kind='stable'):
            row, col = divmod(int(flat_idx), n_cols)
            if D[row, col] > self.max_distance:
                break  # pairs are sorted, so every remaining pair is farther still
            if row in used_rows or col in used_cols:
                continue
            state = self.objects[object_ids[row]]
            state.centroid = (int(input_centroids[col][0]), int(input_centroids[col][1]))
            state.disappeared = 0
            used_rows.add(row)
            used_cols.add(col)
            if len(used_rows) == n_rows or len(used_cols) == n_cols:
                break

        for row in range(n_rows):
            if row not in used_rows:
                self._mark_missing(object_ids[row])

        for col in range(n_cols):
            if col not in used_cols:
                self._register(input_centroids[col])

        return self.objects

## BLOCK 5 — Optimized Video Processing Pipeline

In [None]:
# =========================
# BLOCK 5 — Optimized Video Processing Pipeline
# =========================
from tqdm import tqdm

video_in = str(INPUT_VIDEO)
video_out = str(OUTPUT_DIR / 'annotated_output.mp4')
json_out = str(OUTPUT_DIR / 'counts.json')

cap = cv2.VideoCapture(video_in)
if not cap.isOpened():
    raise FileNotFoundError(f'Cannot open video: {video_in}')

fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
orig_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
orig_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

proc_w = int(orig_w * RESIZE_FACTOR)
proc_h = int(orig_h * RESIZE_FACTOR)
if proc_w <= 0 or proc_h <= 0:
    cap.release()
    raise ValueError('Invalid RESIZE_FACTOR generated non-positive dimensions.')

writer = cv2.VideoWriter(video_out, cv2.VideoWriter_fourcc(*'mp4v'), fps, (proc_w, proc_h))
if not writer.isOpened():
    # VideoWriter construction does not raise on failure and write() reports nothing,
    # so without this check the run would finish with no usable output video.
    cap.release()
    raise RuntimeError(f'Cannot open video writer for: {video_out}')

tracker = CentroidTracker(max_disappeared=20, max_distance=100.0)
frame_records = []
# Guard the modulo checks below against 0/negative settings (ZeroDivisionError).
process_every = max(1, int(PROCESS_EVERY_N_FRAMES))
log_every = max(1, int(LOG_EVERY_N_FRAMES))


def _emit_frame(frame, idx, ids):
    """Append the per-frame record, stamp the count on ``frame`` and write it to the output video."""
    frame_records.append({'frame_idx': idx, 'unique_ids': ids, 'count': len(ids)})
    cv2.putText(frame, f'Count: {len(ids)}', (15, 30), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 255, 255), 2)
    writer.write(frame)


if num_frames <= 0:
    print('[WARN] CAP_PROP_FRAME_COUNT returned 0/unknown. Falling back to streaming loop.')

print(f'[INFO] Processing video @ {fps:.2f} FPS | size {orig_w}x{orig_h} -> {proc_w}x{proc_h}')

frame_idx = 0
pbar = tqdm(total=num_frames if num_frames > 0 else None, desc='Video Processing')

try:
    while True:
        ok, frame = cap.read()
        if not ok:
            break

        # The writer was opened for exactly (proc_w, proc_h). Resize whenever the decoded
        # frame differs: this covers RESIZE_FACTOR != 1.0 and also containers whose reported
        # size does not match the decoded frames (e.g. possibly rotation metadata).
        if frame.shape[:2] != (proc_h, proc_w):
            frame = cv2.resize(frame, (proc_w, proc_h), interpolation=cv2.INTER_AREA)

        # Skip-frame logic for speed: reuse the current tracker state, but still export a per-frame record.
        if frame_idx % process_every != 0:
            _emit_frame(frame, frame_idx, sorted(tracker.objects.keys()))
            frame_idx += 1
            pbar.update(1)
            continue

        try:
            pallet_boxes, _, _ = detect_objects(frame, prompt='pallet')
            box_boxes, _, _ = detect_objects(frame, prompt='box')

            # Only 'box' detections are segmented, tracked and counted; pallets are drawn only.
            masks = filter_masks(segment_boxes(frame, box_boxes))
            centroids = [c for c in (get_centroid(m) for m in masks) if c is not None]
            tracked = tracker.update(centroids)
            active_ids = sorted(tracked.keys())

            if DRAW_MASK_OVERLAY and len(masks) > 0:
                overlay = frame.copy()
                for m in masks:
                    sel = m.astype(bool)
                    overlay[sel] = (0.5 * overlay[sel] + 0.5 * np.array([0, 255, 0], dtype=np.uint8)).astype(np.uint8)
                frame = cv2.addWeighted(overlay, 0.35, frame, 0.65, 0)

            for bx in pallet_boxes:
                x1, y1, x2, y2 = map(int, bx)
                cv2.rectangle(frame, (x1, y1), (x2, y2), (255, 200, 0), 2)
                cv2.putText(frame, 'Pallet', (x1, max(0, y1 - 8)), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 200, 0), 2)

            for bx in box_boxes:
                x1, y1, x2, y2 = map(int, bx)
                cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 220, 0), 1)

            for object_id, state in tracked.items():
                cx, cy = (int(v) for v in state.centroid)
                cv2.circle(frame, (cx, cy), 4, (0, 0, 255), -1)
                cv2.putText(frame, f'ID {object_id}', (cx + 5, cy - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)

        except Exception as e:
            # Keep pipeline alive and keep exporting a row for this frame.
            print(f'[WARN] Frame {frame_idx}: {e}')
            active_ids = sorted(tracker.objects.keys())

        _emit_frame(frame, frame_idx, active_ids)

        if frame_idx % log_every == 0:
            if num_frames > 0:
                print(f'[INFO] frame={frame_idx}/{num_frames} | tracked={len(tracker.objects)}')
            else:
                print(f'[INFO] frame={frame_idx} | tracked={len(tracker.objects)}')

        cleanup_cuda()
        frame_idx += 1
        pbar.update(1)
finally:
    # Always release handles so the MP4 is finalized even if the loop is interrupted
    # (e.g. KeyboardInterrupt or a Colab runtime stop).
    pbar.close()
    cap.release()
    writer.release()
    cleanup_cuda()

print(f'[INFO] Annotated video saved to: {video_out}')
print(f'[INFO] Records exported from pipeline: {len(frame_records)}')


## BLOCK 6 — JSON Export

In [None]:
# =========================
# BLOCK 6 — JSON Export
# =========================
# Persist the per-frame tracking records built by the video pipeline.
with open(json_out, 'w', encoding='utf-8') as fh:
    json.dump(frame_records, fh, indent=2)

exported_count = len(frame_records)
print(f'[INFO] JSON saved to: {json_out}')
print(f'[INFO] Total exported frames: {exported_count}')
if not exported_count:
    print('[WARN] No frames were exported. Verify video path/codec and inspect frame read logs.')


## BLOCK 7 — Interactive Analytics Playground

In [None]:
# =========================
# BLOCK 7 — Interactive Analytics Playground
# =========================
import pandas as pd
import matplotlib.pyplot as plt
import plotly.express as px

with open(json_out, 'r', encoding='utf-8') as fh:
    data = json.load(fh)

df = pd.DataFrame(data)
if df.empty:
    raise ValueError('No records found in counts.json')

# Static overview: tracked-object count for every exported frame.
_, ax = plt.subplots(figsize=(14, 4))
ax.plot(df['frame_idx'], df['count'], color='tab:blue', linewidth=1.5)
ax.set_title('Tracked Object Count per Frame')
ax.set_xlabel('Frame Index')
ax.set_ylabel('Count')
ax.grid(alpha=0.3)
plt.tight_layout()
plt.show()

# Long format: one row per (frame, object ID) pair. Frames with no IDs explode to NaN and are dropped.
id_df = (
    df[['frame_idx', 'unique_ids']]
    .explode('unique_ids')
    .dropna(subset=['unique_ids'])
    .rename(columns={'unique_ids': 'object_id'})
    .astype({'frame_idx': int, 'object_id': int})
    .assign(presence=1)
    .reset_index(drop=True)
)

if id_df.empty:
    print('[INFO] No active IDs to visualize.')
else:
    fig = px.scatter(id_df, x='frame_idx', y='object_id', color='object_id',
                     title='Object ID Presence Over Time', opacity=0.8, height=520)
    fig.update_traces(marker=dict(size=6))
    fig.update_layout(xaxis_title='Frame Index', yaxis_title='Tracked ID', showlegend=False)
    fig.show()