<a href="https://colab.research.google.com/github/Paul-locatelli/projet-detection-avions-paul-omar/blob/main/website.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
# =========================
# 0) Install deps
# =========================
# Notebook-only: the leading "!" is an IPython shell escape that runs pip as a
# shell command, so this line is not valid in a plain .py script.
# -q keeps pip's output quiet.
!pip -q install gradio pillow numpy torch torchvision

# =========================
# 1) Imports + Drive mount
# =========================
import os
import numpy as np
from PIL import Image, ImageDraw, ImageFont

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as T
from torchvision.ops import nms
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

import gradio as gr
from google.colab import drive

drive.mount("/content/drive")

# =========================
# 2) Paths on Google Drive
# =========================
# Drive folder that holds both trained checkpoints.
BASE_DIR = "/content/drive/MyDrive/Final_product"

CLASSIFIER_PATH, DETECTOR_PATH = (
    os.path.join(BASE_DIR, checkpoint_name)
    for checkpoint_name in (
        "best_crop_classifier_resnet50.pth",
        "best_faster_rcnn_raw.pth",
    )
)

# Fail fast, before any model is built, when a checkpoint is absent.
if not os.path.exists(CLASSIFIER_PATH):
    raise FileNotFoundError(f"Missing classifier at: {CLASSIFIER_PATH}")
if not os.path.exists(DETECTOR_PATH):
    raise FileNotFoundError(f"Missing detector at: {DETECTOR_PATH}")

# =========================
# 3) Device
# =========================
# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print("Using device:", DEVICE)

# =========================
# 4) Helper: checkpoint parsing
# =========================
def extract_state_dict(ckpt):
    """Pull a model state_dict out of an object returned by ``torch.load``.

    Handles, in order:
      * a whole pickled ``nn.Module`` -> its ``state_dict()``;
      * a dict storing the weights (a dict or an ``nn.Module``) under one of
        the usual keys "model_state_dict", "state_dict", "model", "weights";
      * a dict storing the weights under some other key -> the largest
        nested dict that actually contains tensors (metadata dicts such as
        ``class_to_idx`` hold no tensors and are skipped);
      * a bare (flat) state_dict -> returned unchanged.

    Raises:
        ValueError: if ``ckpt`` is neither a dict nor an ``nn.Module``.
    """
    if isinstance(ckpt, nn.Module):
        return ckpt.state_dict()
    if not isinstance(ckpt, dict):
        raise ValueError("Checkpoint format not supported.")

    for k in ("model_state_dict", "state_dict", "model", "weights"):
        v = ckpt.get(k)
        if isinstance(v, nn.Module):
            return v.state_dict()
        if isinstance(v, dict):
            return v

    def _has_tensors(d):
        return any(torch.is_tensor(x) for x in d.values())

    dict_candidates = [v for v in ckpt.values() if isinstance(v, dict)]
    # Only dicts that hold tensors can be weights; picking purely by size
    # could return a label map instead of the model.
    tensor_dicts = [d for d in dict_candidates if _has_tensors(d)]
    if tensor_dicts:
        return max(tensor_dicts, key=len)
    # The checkpoint itself is a flat state_dict (possibly with small
    # metadata dicts mixed in).
    if _has_tensors(ckpt):
        return ckpt
    # No tensors anywhere: keep the historical "largest nested dict" fallback.
    if dict_candidates:
        return max(dict_candidates, key=len)
    return ckpt

def strip_module(sd):
    """Remove the leading ``module.`` prefix added by ``nn.DataParallel``/DDP.

    Only a *leading* ``module.`` is removed; a ``module.`` occurring later in
    a key is part of the real parameter path and is kept. Returns ``sd``
    itself (same object) when no key carries the prefix.
    """
    prefix = "module."
    if not any(k.startswith(prefix) for k in sd):
        return sd
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in sd.items()
    }

def infer_detector_num_classes(sd):
    """Read the detector's class count from a Faster R-CNN state_dict.

    The class count is the first dimension of the box predictor's
    classification weight. Returns None when that key is absent or its
    value has no ``shape`` attribute.
    """
    key = "roi_heads.box_predictor.cls_score.weight"
    if key not in sd:
        return None
    weight = sd[key]
    if not hasattr(weight, "shape"):
        return None
    return int(weight.shape[0])

# =========================
# 5) Load classifier (ResNet50)
#   Expected: ckpt contains 'class_to_idx' and 'model_state_dict'
# =========================
clf_ckpt = torch.load(CLASSIFIER_PATH, map_location=DEVICE)
if not isinstance(clf_ckpt, dict) or "class_to_idx" not in clf_ckpt:
    raise KeyError("Classifier checkpoint must contain 'class_to_idx'.")

# Invert the label map so argmax indices can be turned back into names.
class_to_idx = clf_ckpt["class_to_idx"]
idx_to_class = {v: k for k, v in class_to_idx.items()}
num_classes = len(idx_to_class)

# Untrained ResNet-50 whose final layer is resized to the saved label set.
clf = torchvision.models.resnet50(weights=None)
clf.fc = nn.Linear(clf.fc.in_features, num_classes)

clf_sd = clf_ckpt.get("model_state_dict", None)
if clf_sd is None:
    clf_sd = extract_state_dict(clf_ckpt)
# Same treatment as the detector: keys saved from DataParallel start with
# "module.", which strict=False would otherwise silently ignore entirely.
clf_sd = strip_module(clf_sd)
_clf_load = clf.load_state_dict(clf_sd, strict=False)
# strict=False tolerates mismatched keys; report them instead of silently
# running a partly random network.
if _clf_load.missing_keys or _clf_load.unexpected_keys:
    print(
        "WARNING: classifier checkpoint only partially matched the model: "
        f"{len(_clf_load.missing_keys)} missing keys "
        f"(e.g. {_clf_load.missing_keys[:3]}), "
        f"{len(_clf_load.unexpected_keys)} unexpected keys "
        f"(e.g. {_clf_load.unexpected_keys[:3]})."
    )

clf.to(DEVICE).eval()

# NOTE(review): no Normalize() here; this must match the preprocessing used
# at training time — confirm against the training code.
tf_clf = T.Compose([T.Resize((224, 224)), T.ToTensor()])

@torch.no_grad()
def classify_crop(crop_pil):
    """Classify one PIL crop with the ResNet-50 classifier.

    Returns:
        (class_name, probability) of the class with the highest softmax score.
    """
    batch = tf_clf(crop_pil.convert("RGB"))[None].to(DEVICE)
    class_probs = torch.softmax(clf(batch)[0], dim=0).cpu().numpy()
    best = int(class_probs.argmax())
    return idx_to_class[best], float(class_probs[best])

# =========================
# 6) Load detector (Faster R-CNN ResNet50 FPN)
# =========================
det_ckpt = torch.load(DETECTOR_PATH, map_location=DEVICE)
det_sd = strip_module(extract_state_dict(det_ckpt))

# Read the class count from the checkpoint so the predictor head matches it.
num_det_classes = infer_detector_num_classes(det_sd)
if num_det_classes is None:
    raise ValueError(
        "Detector checkpoint not recognized as torchvision Faster R-CNN state_dict.\n"
        "If your detector is not FasterRCNN, tell me what model you trained."
    )

# Untrained Faster R-CNN with its box predictor resized to the checkpoint.
detector = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=None, weights_backbone=None)
in_features = detector.roi_heads.box_predictor.cls_score.in_features
detector.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_det_classes)

_det_load = detector.load_state_dict(det_sd, strict=False)
# strict=False tolerates mismatched keys; report them instead of silently
# running a partly random detector.
if _det_load.missing_keys or _det_load.unexpected_keys:
    print(
        "WARNING: detector checkpoint only partially matched the model: "
        f"{len(_det_load.missing_keys)} missing keys "
        f"(e.g. {_det_load.missing_keys[:3]}), "
        f"{len(_det_load.unexpected_keys)} unexpected keys "
        f"(e.g. {_det_load.unexpected_keys[:3]})."
    )
detector.to(DEVICE).eval()

to_tensor = T.ToTensor()

@torch.no_grad()
def detect(img_pil, det_conf=0.30, iou_nms=0.50, max_boxes=50):
    """Run the Faster R-CNN detector and post-filter its boxes.

    Args:
        img_pil: PIL image. It is converted to RGB here, so RGBA/grayscale
            inputs do not produce a 4- or 1-channel tensor.
        det_conf: minimum detector score for a box to be kept.
        iou_nms: IoU threshold for an extra class-agnostic NMS pass over
            the remaining boxes.
        max_boxes: keep at most this many boxes, highest scores first.
            NOTE(review): torchvision's default ``box_detections_per_img``
            is 100, so values above 100 likely have no effect — confirm.

    Returns:
        (boxes, scores): float32 arrays of shape (N, 4) with
        (x1, y1, x2, y2) pixel coordinates, and shape (N,).
    """
    x = to_tensor(img_pil.convert("RGB")).to(DEVICE)
    pred = detector([x])[0]
    boxes = pred["boxes"].detach().cpu()
    scores = pred["scores"].detach().cpu()

    keep = scores >= float(det_conf)
    boxes = boxes[keep]
    scores = scores[keep]

    if len(boxes) == 0:
        return np.zeros((0, 4), dtype=np.float32), np.zeros((0,), dtype=np.float32)

    # nms returns indices sorted by decreasing score, so the slice keeps the
    # top-k; clamp at 0 so a negative value can't slice from the end.
    keep_idx = nms(boxes, scores, float(iou_nms))[: max(0, int(max_boxes))]
    return boxes[keep_idx].numpy().astype(np.float32), scores[keep_idx].numpy().astype(np.float32)

# =========================
# 7) Drawing utils
# =========================
ORANGE = (255, 130, 0)  # RGB colour for box outlines and label backgrounds

def get_font():
    """Return an 18 px DejaVu Sans TrueType font, or None if it can't load.

    Callers treat None as "no TrueType font"; Pillow then uses its default
    font when drawing.
    """
    try:
        return ImageFont.truetype("DejaVuSans.ttf", 18)
    except (OSError, ImportError):
        # OSError: font file not found / unreadable.
        # ImportError: Pillow was built without FreeType support.
        # A bare except here would also swallow KeyboardInterrupt.
        return None

def draw_tag(draw, font, x1, y1, text):
    """Draw ``text`` on an orange label anchored at a box's top-left corner.

    The label normally sits just above the box's top edge. When there is
    not enough room above (the box touches the image top), it is drawn just
    below the top edge instead, so the background always covers the text.

    Args:
        draw: ``ImageDraw.Draw`` for the output image.
        font: font from ``get_font()``; when None, the text size is estimated.
        x1, y1: top-left corner of the box in pixels.
        text: label text.
    """
    pad_x, pad_y = 8, 4
    if font is not None:
        left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
        tw, th = right - left, bottom - top
    else:
        # Rough size estimate for Pillow's fallback font.
        left = top = 0
        tw, th = max(10, 9 * len(text)), 18

    label_h = th + 2 * pad_y
    # Prefer above the box; fall back to inside it near the image top.
    y_top = y1 - label_h if y1 - label_h >= 0 else y1
    draw.rectangle([x1, y_top, x1 + tw + 2 * pad_x, y_top + label_h], fill=ORANGE)
    # textbbox's (left, top) is the gap between the draw origin and the
    # glyphs; subtract it so the padding is even on every side.
    draw.text((x1 + pad_x - left, y_top + pad_y - top), text, fill=(255, 255, 255), font=font)

# =========================
# 8) Main pipeline: upload image -> output image with bbox + labels
# =========================
def run_pipeline(img, det_conf, iou_nms, max_boxes):
    """Detect objects, classify each crop, and draw labelled boxes.

    Args:
        img: PIL image from the Gradio input, or None when nothing was uploaded.
        det_conf, iou_nms, max_boxes: forwarded to ``detect``.

    Returns:
        (annotated image or None, details text) for the two Gradio outputs.
    """
    if img is None:
        return None, "Upload an image."

    img = img.convert("RGB")
    W, H = img.size
    font = get_font()

    out = img.copy()
    d = ImageDraw.Draw(out)

    boxes, dscores = detect(img, det_conf=det_conf, iou_nms=iou_nms, max_boxes=max_boxes)
    if len(boxes) == 0:
        return out, "No detections. Try lowering detector confidence."

    # Pass 1: clamp each box, classify its crop (taken from the clean image,
    # not the annotated one), and collect the results.
    results = []
    for (bx1, by1, bx2, by2), ds in zip(boxes, dscores):
        # PIL crop boxes are end-exclusive, so the right/bottom edge may be
        # W/H; clamping them to W-1/H-1 would drop the last column/row.
        x1 = int(max(0, min(W - 1, bx1)))
        y1 = int(max(0, min(H - 1, by1)))
        x2 = int(max(0, min(W, bx2)))
        y2 = int(max(0, min(H, by2)))
        if x2 - x1 < 2 or y2 - y1 < 2:
            continue

        cls, cls_p = classify_crop(img.crop((x1, y1, x2, y2)))
        tag = f"{cls} {cls_p:.2f} | det {float(ds):.2f}"
        results.append((x1, y1, x2, y2, tag))

    if not results:
        return out, "Detections existed but none were drawable."

    # Pass 2: draw every outline before any label, so a later box's outline
    # never paints over an earlier box's label.
    for x1, y1, x2, y2, _ in results:
        d.rectangle([x1, y1, x2, y2], outline=ORANGE, width=4)
    for x1, y1, _, _, tag in results:
        draw_tag(d, font, x1, y1, tag)

    lines = [
        f"Box {i}: {tag}  (x1={x1}, y1={y1}, x2={x2}, y2={y2})"
        for i, (x1, y1, x2, y2, tag) in enumerate(results, start=1)
    ]
    return out, "\n".join(lines)

# =========================
# 9) Website (Gradio public URL)
# =========================
with gr.Blocks() as demo:
    # Title text: the em dash was previously stored mis-encoded ("â€”").
    gr.Markdown("# Final Product — Detector + BBox + Classifier (Public Website)")

    with gr.Row():
        inp = gr.Image(type="pil", label="Upload image")
        out = gr.Image(type="pil", label="Result (BBox + labels)")

    with gr.Row():
        det_conf = gr.Slider(0.01, 0.99, value=0.30, step=0.01, label="Detector confidence")
        iou_nms  = gr.Slider(0.10, 0.95, value=0.50, step=0.05, label="NMS IoU")
        # NOTE(review): the detector likely returns at most 100 boxes
        # (torchvision default), so values above 100 may have no effect.
        max_boxes = gr.Slider(1, 200, value=50, step=1, label="Max boxes")

    status = gr.Textbox(label="Details", lines=10)
    btn = gr.Button("Run", variant="primary")
    btn.click(run_pipeline, inputs=[inp, det_conf, iou_nms, max_boxes], outputs=[out, status])

# share=True publishes a temporary public gradio.live URL.
demo.launch(share=True)


Mounted at /content/drive
Using device: cuda
Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://eb47f515ecc9f4bac8.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


