In [None]:

# Mount Google Drive at /content/drive so files stored there (e.g. model
# checkpoints) can be read from this Colab runtime.
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [None]:
import torch
import torch.nn as nn
from torchvision import models
from pathlib import Path

# Pick the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

class CSRNet(nn.Module):
    """VGG16-BN front end followed by a dilated-convolution back end.

    The network maps an image batch to a single-channel output map
    (presumably a crowd-density map, given the model name).

    Args:
        pretrained: When True (default, matching the previous behaviour) the
            VGG16-BN trunk is initialised from torchvision's ImageNet weights,
            which triggers a large download on first use. Pass False when a
            complete checkpoint is loaded right after construction, so the
            download is skipped.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        vgg_weights = models.VGG16_BN_Weights.IMAGENET1K_V1 if pretrained else None
        vgg = models.vgg16_bn(weights=vgg_weights)

        # Keep only the first 33 modules of the VGG16-BN feature extractor.
        self.frontend = nn.Sequential(*list(vgg.features.children())[:33])
        # Dilated 3x3 convolutions (padding == dilation keeps the spatial size),
        # finishing with a 1x1 convolution down to one output channel.
        self.backend = nn.Sequential(
            nn.Conv2d(512, 512, 3, padding=2, dilation=2), nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, 3, padding=2, dilation=2), nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, 3, padding=2, dilation=2), nn.ReLU(inplace=True),
            nn.Conv2d(512, 256, 3, padding=2, dilation=2), nn.ReLU(inplace=True),
            nn.Conv2d(256, 128, 3, padding=1),            nn.ReLU(inplace=True),
            nn.Conv2d(128, 1,   1)
        )

    def forward(self, x):
        """Run the front end then the back end and return the raw back-end output."""
        x = self.frontend(x)
        x = self.backend(x)
        return x

# ⭐ Use your real checkpoint path here
WEIGHTS_PATH = Path("/content/drive/MyDrive/deepvision/checkpoints_partB/partB_best.pth")
assert WEIGHTS_PATH.exists(), f"Weights not found: {WEIGHTS_PATH}"

model = CSRNet().to(device)
# The file is consumed directly as a state_dict, so restrict unpickling to
# tensors and plain containers rather than allowing arbitrary pickled objects.
state_dict = torch.load(WEIGHTS_PATH, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
# Inference mode: freezes BatchNorm statistics used by the VGG16-BN front end.
model.eval()

print("✅ Loaded weights from:", WEIGHTS_PATH)


Using device: cuda
Downloading: "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth" to /root/.cache/torch/hub/checkpoints/vgg16_bn-6c64b313.pth


100%|██████████| 528M/528M [00:02<00:00, 199MB/s]


✅ Loaded weights from: /content/drive/MyDrive/deepvision/checkpoints_partB/partB_best.pth


In [None]:
# Colab-ready: YOLOv8 video annotator with dashboard-style summary (Gradio 5.50.0 compatible)
!pip install -q ultralytics==8.1.20
!pip install -q matplotlib
!pip install -q gradio

import os, cv2, tempfile, time, numpy as np, matplotlib.pyplot as plt
import gradio as gr
import torch

# ---------------- CONFIG ----------------
# Inference device, YOLO checkpoint name, and the output frame size in pixels.
DEVICE = "cpu" if not torch.cuda.is_available() else "cuda"
YOLO_MODEL = "yolov8n.pt"
OUT_W = 640
OUT_H = 480

print("Device:", DEVICE, "| Gradio:", gr.__version__)
print("Security: forcing torch.load(weights_only=False) to load YOLO (trusted checkpoint)")

# ---------------- LOAD YOLO (weights_only=False, trusted checkpoint) ----------------
# NOTE(security): weights_only=False lets torch.load unpickle arbitrary objects.
# That is acceptable only because the YOLO checkpoint is treated as trusted;
# never point YOLO_MODEL at a file from an untrusted source.
orig_load = torch.load
def _forced_load(*a, **k):
    """Delegate to the original torch.load with weights_only forced to False."""
    k["weights_only"] = False
    return orig_load(*a, **k)
torch.load = _forced_load
try:
    from ultralytics import YOLO
    yolo = YOLO(YOLO_MODEL)
finally:
    # Always undo the monkey-patch, even when the import or model load fails,
    # so every later torch.load call keeps its normal behaviour.
    torch.load = orig_load
print("YOLO loaded ✔")

# ------------- helper functions -------------
def run_yolo(frame_rgb, conf=0.35):
    """Detect people in a single RGB frame using the module-level ``yolo`` model.

    Args:
        frame_rgb: Frame as an RGB image array (callers convert from OpenCV BGR).
        conf: Minimum detection confidence forwarded to the model.

    Returns:
        A list of ``(x1, y1, x2, y2, confidence)`` tuples, one per detection
        whose class name is "person". Coordinates are ints, confidence a float.
    """
    # Ultralytics interprets raw numpy input as BGR (OpenCV convention), so the
    # RGB frame is converted back first; feeding RGB directly makes the model
    # see swapped red/blue channels.
    frame_bgr = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR)
    # imgsz is (height, width). Passing (width, height) made the predictor
    # letterbox the landscape frame into a portrait canvas at reduced scale.
    res = yolo.predict(source=frame_bgr, imgsz=(OUT_H, OUT_W), conf=conf, verbose=False, device=DEVICE)
    preds = res[0]

    detections = getattr(preds, "boxes", None)
    if detections is None:
        return []

    boxes = []
    for box, cls, confv in zip(detections.xyxy.cpu().numpy(),
                               detections.cls.cpu().numpy(),
                               detections.conf.cpu().numpy()):
        # Keep only detections labelled "person" by the model's class table.
        if preds.names[int(cls)] != "person":
            continue
        x1, y1, x2, y2 = map(int, box[:4])
        boxes.append((x1, y1, x2, y2, float(confv)))
    return boxes

def annotate(frame_rgb, boxes, count):
    """Return a copy of ``frame_rgb`` with boxes, scores and the total count drawn.

    Args:
        frame_rgb: Image array to draw on; the input itself is not modified.
        boxes: Iterable of ``(x1, y1, x2, y2, confidence)`` tuples.
        count: Value rendered in the "Count: ..." overlay.

    Returns:
        The annotated copy of the frame.
    """
    box_colour = (0, 255, 0)
    text_colour = (255, 255, 255)
    font = cv2.FONT_HERSHEY_SIMPLEX

    canvas = frame_rgb.copy()
    for left, top, right, bottom, score in boxes:
        cv2.rectangle(canvas, (left, top), (right, bottom), box_colour, 2)
        # Keep the score label inside the image when the box touches the top edge.
        label_y = max(12, top - 6)
        cv2.putText(canvas, f"{score:.2f}", (left, label_y), font, 0.5, box_colour, 1)

    cv2.putText(canvas, f"Count: {count}", (10, 35), font, 1, text_colour, 2)
    return canvas

def _save_counts_plot(counts):
    """Plot the per-processed-frame counts to a temporary PNG and return its path."""
    # Close the handle immediately: only the reserved path is needed.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        plot_path = tmp.name

    # Explicit figure/axes instead of the global pyplot state, so concurrent
    # requests cannot draw into each other's figure.
    fig, ax = plt.subplots(figsize=(10, 3))
    try:
        if counts:
            ax.plot(counts, marker='o', linewidth=1)
        ax.set_title("People Count per Processed Frame")
        ax.set_xlabel("Processed frame index")
        ax.set_ylabel("Count")
        ax.grid(True)
        fig.tight_layout()
        fig.savefig(plot_path, dpi=150)
    finally:
        plt.close(fig)
    return plot_path


def process_video_file(video_file, conf, skip, alert_thr):
    """Annotate a video with person detections and summarise the counts.

    Every ``skip``-th frame is run through ``run_yolo`` and written with boxes,
    the count and (when ``count >= alert_thr``) an overcrowding banner. The
    remaining frames are only resized and written unchanged.

    Args:
        video_file: Path of the input video, as accepted by ``cv2.VideoCapture``.
        conf: Detection confidence threshold forwarded to ``run_yolo``.
        skip: Process every N-th frame; values below 1 are treated as 1.
        alert_thr: Person count at or above which the banner is drawn.

    Returns:
        ``(out_path, plot_path, summary)`` where ``out_path`` is the annotated
        MP4, ``plot_path`` a PNG of counts over processed frames, and
        ``summary`` a dict with keys ``frames_total``, ``frames_processed``,
        ``avg_count``, ``max_count`` and ``time_s`` (wall-clock seconds).
    """
    # Guard the modulo below: skip == 0 would raise ZeroDivisionError.
    skip = max(1, int(skip))
    started = time.perf_counter()

    # Reserve the output path and close the handle so OpenCV can write to it.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        out_path = tmp.name

    counts = []
    frame_idx = 0
    cap = cv2.VideoCapture(video_file)
    writer = None
    try:
        # Some containers report 0 FPS; fall back to a sane default.
        fps = cap.get(cv2.CAP_PROP_FPS) or 20.0
        writer = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (OUT_W, OUT_H))

        while True:
            ok, frame_bgr = cap.read()
            if not ok:
                break
            frame_idx += 1
            if frame_idx % skip != 0:
                # Skipped frames are still written so the output keeps its duration.
                writer.write(cv2.resize(frame_bgr, (OUT_W, OUT_H)))
                continue

            rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
            rgb = cv2.resize(rgb, (OUT_W, OUT_H))

            boxes = run_yolo(rgb, conf)
            num = len(boxes)
            counts.append(num)

            annotated_rgb = annotate(rgb, boxes, num)
            annotated_bgr = cv2.cvtColor(annotated_rgb, cv2.COLOR_RGB2BGR)

            if num >= alert_thr:
                cv2.putText(annotated_bgr, "!!! OVERCROWDED !!!", (10,75),
                            cv2.FONT_HERSHEY_SIMPLEX, 1, (0,0,255), 3)

            writer.write(annotated_bgr)
    finally:
        # Release the capture and writer even if detection or drawing raised.
        cap.release()
        if writer is not None:
            writer.release()

    plot_path = _save_counts_plot(counts)

    # Empty-safe aggregates: np.mean / np.max would fail or warn on [].
    summary = {
        "frames_total": frame_idx,
        "frames_processed": len(counts),
        "avg_count": float(np.mean(counts)) if counts else 0.0,
        "max_count": int(np.max(counts)) if counts else 0,
        "time_s": time.perf_counter() - started
    }

    return out_path, plot_path, summary

# -------------- UI styling HTML for stat tiles --------------
def stat_html(num, label, accent="#FF7A1A"):
    """Build the HTML snippet for one dashboard stat tile.

    Args:
        num: Value shown in large type (inserted with its default formatting).
        label: Caption shown beneath the value.
        accent: CSS colour used for the value text.

    Returns:
        The tile markup as a string.
    """
    tile_lines = [
        "",
        '    <div style="display:inline-block; background:#111827; padding:14px 18px; margin:6px; border-radius:8px; min-width:140px;">',
        f'      <div style="font-size:22px; font-weight:700; color:{accent};">{num}</div>',
        f'      <div style="color:#9CA3AF; font-size:12px;">{label}</div>',
        "    </div>",
        "    ",
    ]
    return "\n".join(tile_lines)

# -------------- Wrapper for UI processing --------------
def ui_process(video_file, conf, skip, alert_thr):
    """Click handler: process the uploaded video and build all UI outputs.

    Args:
        video_file: Path of the uploaded video, or a falsy value if none.
        conf: Confidence slider value (cast to float).
        skip: "Process every N frames" slider value (cast to int).
        alert_thr: Overcrowd threshold slider value (cast to int).

    Returns:
        A 6-tuple ``(video_path, plot_path, summary_markdown, frames_tile_html,
        avg_tile_html, max_tile_html)``; the first two are None when no file
        was uploaded.
    """
    if not video_file:
        empty_tiles = (
            stat_html(0, "Frames processed"),
            stat_html(0.00, "Avg count"),
            stat_html(0, "Max count"),
        )
        return (None, None, "No file uploaded", *empty_tiles)

    out_vid, plot, summary = process_video_file(video_file, float(conf), int(skip), int(alert_thr))

    frames_total = summary['frames_total']
    frames_processed = summary['frames_processed']
    avg_count = summary['avg_count']
    max_count = summary['max_count']

    # Markdown summary, assembled line by line (ends with a trailing newline).
    md = "\n".join([
        "### Summary",
        "",
        f"- **Frames total:** {frames_total}",
        f"- **Frames processed:** {frames_processed}",
        f"- **Average count:** **{avg_count:.2f}**",
        f"- **Max count:** **{max_count}**",
        "",
        "*Processing produced an annotated MP4 and counts-over-time plot.*",
        "",
    ])

    # Stat tiles shown in the dashboard column.
    stat_frames = stat_html(frames_processed, "Frames processed")
    stat_avg = stat_html(f"{avg_count:.2f}", "Average count")
    stat_max = stat_html(max_count, "Max count")
    return out_vid, plot, md, stat_frames, stat_avg, stat_max

# ---------------- Build Gradio UI ----------------
# Layout: a title/subtitle, then a two-column row (inputs on the left, stat
# tiles + plot + summary on the right), then the annotated video underneath.
# Components are placed in creation order inside each context manager.
with gr.Blocks() as demo:
    gr.HTML("<div style='font-size:26px; font-weight:700; color:#fff; margin-bottom:4px;'>Crowd Counting — Annotated Video</div>")
    gr.HTML("<div style='color:#9CA3AF; margin-bottom:12px;'>Upload a video and click Process. Stats update in the dashboard on the right.</div>")

    with gr.Row():
        # Left column: video upload and processing controls.
        with gr.Column(scale=2):
            vid_in = gr.Video(label="Upload video (mp4/mov)")
            conf = gr.Slider(0.1, 0.9, value=0.35, label="YOLO Confidence")
            skip = gr.Slider(1, 10, value=1, step=1, label="Process every N frames")
            alert_thr = gr.Slider(1, 50, value=8, step=1, label="Overcrowd Threshold")
            process_btn = gr.Button("Process", variant="primary")

        # Right column: dashboard outputs.
        with gr.Column(scale=1):
            # stat tiles area (HTML components that will be updated)
            st_frames = gr.HTML(stat_html(0, "Frames processed"), label="Frames")
            st_avg = gr.HTML(stat_html(0.00, "Avg count"), label="Avg")
            st_max = gr.HTML(stat_html(0, "Max count"), label="Max")
            gr.Markdown("### Results")
            out_plot = gr.Image(label="Counts over time")
            out_summary = gr.Markdown("### Summary\n\n_No results yet._")

    out_video = gr.Video(label="Annotated video (MP4)")

    # The outputs list must match the order of the 6-tuple returned by ui_process.
    process_btn.click(fn=ui_process, inputs=[vid_in, conf, skip, alert_thr],
                      outputs=[out_video, out_plot, out_summary, st_frames, st_avg, st_max])

# share=True exposes the app on a public gradio.live URL in addition to the notebook.
demo.launch(share=True)


Device: cuda | Gradio: 5.50.0
Security: forcing torch.load(weights_only=False) to load YOLO (trusted checkpoint)
YOLO loaded ✔
Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://3a78e26b62b67b4b24.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


