<a href="https://colab.research.google.com/github/Paul-locatelli/projet-detection-avions-paul-omar/blob/main/Model_detector.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive so the project folder under MyDrive is reachable from Colab.
from google.colab import drive
drive.mount("/content/drive")

import os, glob, pathlib, random, shutil
import xml.etree.ElementTree as ET
from PIL import Image

# Project layout on Drive.
ROOT = "/content/drive/MyDrive/Final_product"
DATASET_ROOT = os.path.join(ROOT, "DataSet")
RAW_DIR = os.path.join(DATASET_ROOT, "raw")  # original images + Pascal VOC XML files

# Output folders written by this cell: classifier crops, per-split image copies, split key lists.
OUT_CROPS, OUT_SPLIT1N, OUT_KEYS = (
    os.path.join(DATASET_ROOT, sub) for sub in ("cls_crops", "split1n", "splits_keys")
)

for d in (OUT_CROPS, OUT_SPLIT1N, OUT_KEYS):
    os.makedirs(d, exist_ok=True)

def norm_stem(p):
    """Return the pairing key for path *p*.

    The key is the file stem (name without its last suffix), lower-cased, with
    spaces, hyphens and underscores removed, e.g. ``"My-Img_01.JPG"`` -> ``"myimg01"``.
    Images and XML files that share a key are treated as one annotated sample.
    """
    return pathlib.Path(p).stem.lower().translate(str.maketrans("", "", " -_"))

# Collect raw images and VOC annotations, searching RAW_DIR recursively.
img_files = sorted(
    path
    for ext in ("jpg", "jpeg", "png")
    for path in glob.glob(os.path.join(RAW_DIR, "**", f"*.{ext}"), recursive=True)
)
xml_files = sorted(glob.glob(os.path.join(RAW_DIR, "**", "*.xml"), recursive=True))

assert img_files, "No raw images found under DataSet/raw"
assert xml_files, "No raw xml found under DataSet/raw"

# Index each file by its normalized stem; on a key collision the first path
# (in sorted order) is kept and later ones are ignored.
img_map = {}
for p in img_files:
    img_map.setdefault(norm_stem(p), p)

xml_map = {}
for p in xml_files:
    xml_map.setdefault(norm_stem(p), p)

# Only keys present on both sides form a usable (image, annotation) pair.
matched = sorted(img_map.keys() & xml_map.keys())
assert len(matched) > 0, "No matched image/xml pairs. Filenames don't align."

print("Raw images:", len(img_map))
print("Raw xml:", len(xml_map))
print("Matched pairs:", len(matched))

def parse_voc(xml_path):
    """Read the objects of one Pascal VOC annotation file.

    Returns a list with one ``(name, xmin, ymin, xmax, ymax)`` tuple per
    ``<object>`` that has a ``<bndbox>``; objects without a box are skipped.
    ``name`` is the raw ``<name>`` text (``None`` if the tag is absent).
    Coordinates are parsed as floats, then truncated with ``int()``.
    A missing coordinate tag raises ``TypeError``.
    """
    coord_tags = ("xmin", "ymin", "xmax", "ymax")
    annotation = ET.parse(xml_path).getroot()
    results = []
    for node in annotation.findall("object"):
        label = node.findtext("name")
        box_node = node.find("bndbox")
        if box_node is None:
            continue
        coords = tuple(int(float(box_node.findtext(tag))) for tag in coord_tags)
        results.append((label, *coords))
    return results

# Gather every class name that appears in the matched annotations, and count
# XML files that yield no object with a bounding box.
label_set = set()
empty_ann = 0
for k in matched:
    objs = parse_voc(xml_map[k])
    if not objs:
        empty_ann += 1
    label_set.update(obj[0] for obj in objs if obj[0] is not None)

# Sorted (lexicographic) order defines the class indices used downstream.
classes = sorted(label_set)
class_to_idx = dict(zip(classes, range(len(classes))))

print("Detected classes:", len(classes))
print("Empty-annotation xml:", empty_ann)
print("Classes:", classes)

# Reproducible image-level 80/10/10 split: seeded in-place shuffle of the matched keys.
random.seed(42)
random.shuffle(matched)

n = len(matched)
n_train, n_val = int(0.8 * n), int(0.1 * n)
val_end = n_train + n_val

train_keys = matched[:n_train]
val_keys = matched[n_train:val_end]
test_keys = matched[val_end:]  # the remainder (absorbs rounding) goes to test

print("Split sizes:", len(train_keys), len(val_keys), len(test_keys))

def write_keys(keys, path):
    """Write *keys* to *path*, one per line, each newline-terminated (overwrites the file)."""
    with open(path, "w") as out:
        out.writelines(key + "\n" for key in keys)

# Persist the split so later cells (and later sessions) reuse exactly the same keys.
os.makedirs(OUT_KEYS, exist_ok=True)
for split_file, split_list in (("train.txt", train_keys), ("val.txt", val_keys), ("test.txt", test_keys)):
    write_keys(split_list, os.path.join(OUT_KEYS, split_file))

# Per split: one crop folder per class, plus a flat folder for the copied images.
for split in ("train", "val", "test"):
    split_dirs = [os.path.join(OUT_CROPS, split, c) for c in classes]
    split_dirs.append(os.path.join(OUT_SPLIT1N, split, "images"))
    for folder in split_dirs:
        os.makedirs(folder, exist_ok=True)

def safe_crop(img, x1, y1, x2, y2):
    """Crop *img* to the box ``(x1, y1, x2, y2)`` after clamping it to the image.

    Coordinates are converted with ``int()`` and clamped so the crop lies inside
    the image. If an axis is degenerate after clamping (end <= start), only that
    axis is widened to a single pixel; the other axis keeps its clamped extent.
    For a non-empty image the resulting crop is therefore never empty.

    Args:
        img: object exposing ``size`` as ``(width, height)`` and a
            ``crop((left, upper, right, lower))`` method (e.g. a PIL image).
        x1, y1, x2, y2: box corners in pixel coordinates.

    Returns:
        Whatever ``img.crop`` returns for the clamped box.
    """
    w, h = img.size
    # Keep each start strictly inside the image so that start + 1 is a valid end.
    x1 = max(0, min(int(x1), w - 1))
    y1 = max(0, min(int(y1), h - 1))
    x2 = max(0, min(int(x2), w))
    y2 = max(0, min(int(y2), h))
    # Repair each axis independently: a degenerate x-range must not collapse a
    # valid y-range to one pixel (and vice versa).
    if x2 <= x1:
        x2 = min(w, x1 + 1)
    if y2 <= y1:
        y2 = min(h, y1 + 1)
    return img.crop((x1, y1, x2, y2))

def process_split(keys, split_name, copy_images=True):
    """Export one split: optionally copy its raw images and save per-object crops.

    For every key:
      * if *copy_images*, the raw image is copied into
        ``OUT_SPLIT1N/<split_name>/images/``, unless a file of the same name is
        already there;
      * every annotated object whose class is in ``class_to_idx`` is cropped
        with ``safe_crop`` and saved as JPEG (quality 95) to
        ``OUT_CROPS/<split_name>/<class>/<image stem>__<object index>.jpg``.

    Uses the module-level ``img_map``, ``xml_map``, ``class_to_idx``,
    ``OUT_SPLIT1N`` and ``OUT_CROPS``. The class folders must already exist.

    NOTE: crops are overwritten on every run, but files left by an earlier run
    (e.g. with a different split) are never removed. Clear the output folders
    before re-exporting with a new split to avoid cross-split leakage.

    Args:
        keys: normalized stems to export (keys of ``img_map`` / ``xml_map``).
        split_name: sub-folder name, e.g. ``"train"``.
        copy_images: when False, only the crops are written.

    Returns:
        ``(saved_crops, saved_imgs)``: the number of crops written and the number
        of images newly copied during this call.
    """
    saved_crops = 0
    saved_imgs = 0
    for k in keys:
        img_path = img_map[k]
        xml_path = xml_map[k]

        if copy_images:
            dst_img = os.path.join(OUT_SPLIT1N, split_name, "images", os.path.basename(img_path))
            if not os.path.exists(dst_img):
                shutil.copy2(img_path, dst_img)
                saved_imgs += 1

        # Decode inside a context manager so the file handle is released
        # deterministically; convert() returns a separate in-memory image.
        with Image.open(img_path) as src:
            img = src.convert("RGB")

        stem = pathlib.Path(img_path).stem
        for j, (name, x1, y1, x2, y2) in enumerate(parse_voc(xml_path)):
            if name not in class_to_idx:
                continue
            crop = safe_crop(img, x1, y1, x2, y2)
            out_path = os.path.join(OUT_CROPS, split_name, name, f"{stem}__{j}.jpg")
            crop.save(out_path, quality=95)
            saved_crops += 1
    return saved_crops, saved_imgs

# Export every split, copying images and writing classifier crops.
for split_name, keys in [("train", train_keys), ("val", val_keys), ("test", test_keys)]:
    crops_n, imgs_n = process_split(keys, split_name, copy_images=True)
    print(split_name, "| crops:", crops_n, "| images copied:", imgs_n)

# Small key=value metadata file; later cells read the class order back from it.
meta_path = os.path.join(DATASET_ROOT, "dataset_meta.txt")
meta_lines = [
    "classes=" + ",".join(classes),
    "num_classes=" + str(len(classes)),
    "matched_pairs=" + str(len(matched)),
]
with open(meta_path, "w") as f:
    f.write("".join(line + "\n" for line in meta_lines))

print("Crops folder:", OUT_CROPS)
print("Split images folder:", OUT_SPLIT1N)
print("Keys folder:", OUT_KEYS)
print("Meta:", meta_path)



Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Raw images: 1331
Raw xml: 1331
Matched pairs: 1331
Detected classes: 20
Empty-annotation xml: 0
Classes: ['A1', 'A10', 'A11', 'A12', 'A13', 'A14', 'A15', 'A16', 'A17', 'A18', 'A19', 'A2', 'A20', 'A3', 'A4', 'A5', 'A6', 'A7', 'A8', 'A9']
Split sizes: 1064 133 134
train | crops: 6220 | images copied: 1064
val | crops: 876 | images copied: 133
test | crops: 774 | images copied: 134
Crops folder: /content/drive/MyDrive/Final_product/DataSet/cls_crops
Split images folder: /content/drive/MyDrive/Final_product/DataSet/split1n
Keys folder: /content/drive/MyDrive/Final_product/DataSet/splits_keys
Meta: /content/drive/MyDrive/Final_product/DataSet/dataset_meta.txt


In [None]:
import os, glob, pathlib
import xml.etree.ElementTree as ET
from PIL import Image

import torch
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
import torchvision.transforms as T

# Input locations produced by the export cell.
ROOT = "/content/drive/MyDrive/Final_product"
DATASET_ROOT = os.path.join(ROOT, "DataSet")
RAW_DIR = os.path.join(DATASET_ROOT, "raw")
KEYS_DIR = os.path.join(DATASET_ROOT, "splits_keys")
META_PATH = os.path.join(DATASET_ROOT, "dataset_meta.txt")

# Where the best detector checkpoint is written.
MODELS_DIR = os.path.join(ROOT, "models")
os.makedirs(MODELS_DIR, exist_ok=True)
BEST_PATH = os.path.join(MODELS_DIR, "best_faster_rcnn_raw.pth")

TRAIN_TXT, VAL_TXT, TEST_TXT = (
    os.path.join(KEYS_DIR, f"{split}.txt") for split in ("train", "val", "test")
)

# Fail early if the export cell has not been run.
assert os.path.isdir(RAW_DIR), f"Missing {RAW_DIR}"
assert all(os.path.exists(txt) for txt in (TRAIN_TXT, VAL_TXT, TEST_TXT)), "Missing splits_keys/*.txt"
assert os.path.exists(META_PATH), f"Missing {META_PATH}"

for label, value in (("RAW_DIR:", RAW_DIR), ("KEYS_DIR:", KEYS_DIR), ("META:", META_PATH), ("SAVE:", BEST_PATH)):
    print(label, value)


RAW_DIR: /content/drive/MyDrive/Final_product/DataSet/raw
KEYS_DIR: /content/drive/MyDrive/Final_product/DataSet/splits_keys
META: /content/drive/MyDrive/Final_product/DataSet/dataset_meta.txt
SAVE: /content/drive/MyDrive/Final_product/models/best_faster_rcnn_raw.pth


In [None]:
# Recover the class list saved by the export cell (line "classes=A,B,...").
# Reusing the stored order keeps label indices identical to the export.
with open(META_PATH, "r") as f:
    lines = f.read().splitlines()

classes_line = [line for line in lines if line.startswith("classes=")][0]
class_csv = classes_line.partition("=")[2]
classes = list(filter(None, class_csv.split(",")))  # drop empty entries

class_to_idx = {name: idx for idx, name in enumerate(classes)}

print("Num classes:", len(classes))
print("Classes:", classes)


Num classes: 20
Classes: ['A1', 'A10', 'A11', 'A12', 'A13', 'A14', 'A15', 'A16', 'A17', 'A18', 'A19', 'A2', 'A20', 'A3', 'A4', 'A5', 'A6', 'A7', 'A8', 'A9']


In [None]:
def norm_stem_from_path(p):
    """Return the pairing key for path *p*.

    The key is the lower-cased file stem with every space, hyphen and
    underscore removed. This is the same key the export cell used to pair
    images with XML files.
    """
    stem = pathlib.Path(p).stem.lower()
    return "".join(ch for ch in stem if ch not in " -_")

# Re-index the raw files exactly as the export cell did.
img_files = sorted(
    found
    for pattern in ("*.jpg", "*.jpeg", "*.png")
    for found in glob.glob(os.path.join(RAW_DIR, "**", pattern), recursive=True)
)
xml_files = sorted(glob.glob(os.path.join(RAW_DIR, "**", "*.xml"), recursive=True))

assert img_files, "No images in raw/"
assert xml_files, "No xml in raw/"

# Key every file by its normalized stem; the first path (sorted order) wins on collisions.
img_map, xml_map = {}, {}
for p in img_files:
    img_map.setdefault(norm_stem_from_path(p), p)
for p in xml_files:
    xml_map.setdefault(norm_stem_from_path(p), p)

print("Raw images:", len(img_map))
print("Raw xml:", len(xml_map))

# Sanity check: the pair count should match what the export cell reported (1331 in the recorded run).
inter = len(img_map.keys() & xml_map.keys())
print("Intersection img/xml:", inter)
assert inter > 0


Raw images: 1331
Raw xml: 1331
Intersection img/xml: 1331


In [None]:
def read_keys(p):
    """Return the non-blank lines of file *p*, stripped of surrounding whitespace."""
    with open(p, "r") as f:
        stripped = (line.strip() for line in f)
        return [key for key in stripped if key]

# Reuse the exact split written by the export cell.
train_keys, val_keys, test_keys = map(read_keys, (TRAIN_TXT, VAL_TXT, TEST_TXT))

print("Split sizes:", len(train_keys), len(val_keys), len(test_keys))


Split sizes: 1064 133 134


In [None]:
def parse_voc(xml_path):
    """Read the objects of one Pascal VOC annotation file.

    Returns a list with one ``(name, xmin, ymin, xmax, ymax)`` tuple per
    ``<object>`` that has a ``<bndbox>``; objects without a box are skipped.
    ``name`` is the raw ``<name>`` text (``None`` if the tag is absent).
    Coordinates are parsed as floats, then truncated with ``int()``.
    A missing coordinate tag raises ``TypeError``.
    """
    annotation = ET.parse(xml_path).getroot()
    results = []
    for node in annotation.findall("object"):
        label = node.findtext("name")
        box_node = node.find("bndbox")
        if box_node is None:
            continue
        xmin, ymin, xmax, ymax = (
            int(float(box_node.findtext(tag))) for tag in ("xmin", "ymin", "xmax", "ymax")
        )
        results.append((label, xmin, ymin, xmax, ymax))
    return results


In [None]:
class VOCDatasetRAW(Dataset):
    """Detection dataset that loads raw images and their Pascal VOC XML on the fly.

    Each item is ``(image_tensor, target)``. ``target`` is a dict with:
      * ``boxes``: float32 tensor of shape ``(N, 4)``, rows ``[x1, y1, x2, y2]``;
      * ``labels``: int64 tensor of shape ``(N,)``, class index + 1, so that 0
        stays reserved for background;
      * ``image_id``: ``tensor([idx])``;
      * ``iscrowd``: all-zero int64 tensor of shape ``(N,)``;
      * ``area``: float32 box areas of shape ``(N,)``.
    Objects whose name is not in ``class_to_idx``, or whose box has non-positive
    width or height, are dropped, so an image may end up with ``N == 0``.

    Args:
        keys: normalized stems selecting the images of this split.
        img_map: key -> image path.
        xml_map: key -> VOC XML path.
        class_to_idx: class name -> 0-based class index.
    """

    def __init__(self, keys, img_map, xml_map, class_to_idx):
        self.keys = keys
        self.img_map = img_map
        self.xml_map = xml_map
        self.class_to_idx = class_to_idx
        self.tf = T.ToTensor()

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, idx):
        k = self.keys[idx]
        img_path = self.img_map[k]
        xml_path = self.xml_map[k]

        # Release the file handle as soon as the pixels are decoded;
        # convert() returns a separate in-memory image.
        with Image.open(img_path) as src:
            img = src.convert("RGB")

        boxes, labels = [], []
        for name, x1, y1, x2, y2 in parse_voc(xml_path):
            if name not in self.class_to_idx:
                continue
            if x2 <= x1 or y2 <= y1:  # degenerate box (zero or negative extent)
                continue
            boxes.append([x1, y1, x2, y2])
            labels.append(self.class_to_idx[name] + 1)  # +1: label 0 is background

        # reshape(-1, 4) also yields the required (0, 4) shape when no box survived.
        boxes = torch.tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        labels = torch.tensor(labels, dtype=torch.int64)
        # Valid for N == 0 too: produces an empty float32 tensor.
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])

        target = {
            "boxes": boxes,
            "labels": labels,
            "image_id": torch.tensor([idx]),
            "iscrowd": torch.zeros((labels.shape[0],), dtype=torch.int64),
            "area": area,
        }
        return self.tf(img), target

def collate_fn(batch):
    """Split a list of ``(image, target)`` pairs into ``(images, targets)`` lists.

    Samples are kept as lists instead of being stacked into one tensor; raw
    images may differ in size, and the detector is called with a list of images
    and a list of target dicts.
    """
    images, targets = map(list, zip(*batch))
    return images, targets


In [None]:
train_ds = VOCDatasetRAW(train_keys, img_map, xml_map, class_to_idx)
val_ds = VOCDatasetRAW(val_keys, img_map, xml_map, class_to_idx)

# Shared loader settings; only the training split is shuffled.
loader_kwargs = dict(batch_size=2, num_workers=0, collate_fn=collate_fn)
train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)

print("Train batches:", len(train_loader))
print("Val batches  :", len(val_loader))


Train batches: 532
Val batches  : 67


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device:", device)

# Label 0 is background, so the predictor needs one extra output.
num_classes = len(classes) + 1  # + background

# Pretrained Faster R-CNN (ResNet-50 + FPN); swap the box predictor for one sized to our classes.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")
in_feat = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_feat, num_classes)
model.to(device)

params = list(filter(lambda prm: prm.requires_grad, model.parameters()))
optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)
# StepLR: multiply the learning rate by 0.1 every 3 scheduler steps.
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

print("RCNN ready with", num_classes, "classes (incl background)")


device: cuda
Downloading: "https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth" to /root/.cache/torch/hub/checkpoints/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 160M/160M [00:00<00:00, 209MB/s]


RCNN ready with 21 classes (incl background)


In [None]:
def train_epoch(loader):
    """Run one optimisation pass over *loader* and return the mean loss per batch.

    Uses the module-level ``model``, ``optimizer`` and ``device``. In training
    mode the detector returns a dict of losses; they are summed into a single
    scalar before back-propagation. An empty loader returns 0.0.
    """
    model.train()
    running = 0.0
    for images, batch_targets in loader:
        images = [im.to(device) for im in images]
        batch_targets = [{key: val.to(device) for key, val in tgt.items()} for tgt in batch_targets]

        losses = model(images, batch_targets)
        batch_loss = sum(losses.values())

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        running += batch_loss.item()
    return running / max(1, len(loader))

@torch.no_grad()
def val_epoch(loader):
    """Return the mean summed detection loss per batch over *loader*, without gradients.

    The model is deliberately put in ``train()`` mode: torchvision detection
    models only return the loss dict in training mode (in eval mode they return
    detections). Gradients are disabled by the decorator and no optimizer step
    is taken here.

    NOTE(review): this assumes the backbone uses frozen batch-norm (torchvision's
    default for pretrained detection weights). Otherwise BN running statistics
    would be updated on validation data. Proposal sampling in train mode is also
    random, so this loss is somewhat noisy between runs.
    """
    model.train()
    batch_losses = []
    for images, batch_targets in loader:
        images = [im.to(device) for im in images]
        batch_targets = [{key: val.to(device) for key, val in tgt.items()} for tgt in batch_targets]
        batch_losses.append(sum(model(images, batch_targets).values()).item())
    return sum(batch_losses) / max(1, len(loader))

EPOCHS = 10
best_val = 1e9

print("\nüöÄ DETECTOR TRAINING START\n")

for e in range(1, EPOCHS + 1):
    tr = train_epoch(train_loader)
    va = val_epoch(val_loader)
    lr_scheduler.step()  # one scheduler step per epoch

    print(f"epoch {e}/{EPOCHS} | train_loss={tr:.4f} | val_loss={va:.4f}")

    # Written as "not <" so that a NaN loss never counts as an improvement.
    if not va < best_val:
        continue

    best_val = va
    # class_to_idx is 0-based; the detector's labels are these indices + 1 (0 = background).
    checkpoint = {
        "model_state_dict": model.state_dict(),
        "classes": classes,
        "class_to_idx": class_to_idx,
    }
    torch.save(checkpoint, BEST_PATH)
    print("‚úÖ saved best ->", BEST_PATH)

print("\nüèÅ DONE")
print("Best val loss:", best_val)
print("Detector saved:", BEST_PATH)


Objectif  
L’objectif de cette partie est de mettre en place un modèle de détection d’aéronefs capable de localiser et d’identifier les avions présents dans une image. Cette étape constitue la première brique du pipeline global, avant la phase de classification fine basée sur des crops.

Données utilisées  
Le détecteur est entraîné à partir des données brutes non modifiées (RAW). Les images originales et leurs annotations au format Pascal VOC (XML) sont stockées dans le dossier Final_product/DataSet/raw/.  
Le jeu de données contient 1331 images annotées, chacune associée à un fichier XML décrivant les classes et les coordonnées des bounding boxes. Le nombre total de classes distinctes est de 20 (A1 à A20).

Préparation des données  
Les images et les fichiers XML sont appariés à l’aide du nom de fichier (stem), après normalisation (passage en minuscules, suppression des espaces, tirets et underscores) afin d’éviter les problèmes liés aux différences de casse ou de séparateurs.  
Seules les images disposant d’un fichier d’annotation XML correspondant sont utilisées pour l’entraînement du détecteur, ce qui garantit un apprentissage entièrement supervisé.

Un découpage train / validation / test est effectué au niveau image (mélange aléatoire avec une graine fixe, seed = 42) afin d’éviter toute fuite de données. La répartition est la suivante :  
– 80 % des images pour l’entraînement (1064 images)  
– 10 % pour la validation (133 images)  
– 10 % pour le test (134 images)  

Les clés correspondant à chaque split sont enregistrées dans des fichiers texte séparés, ce qui permet de reproduire exactement le même découpage ultérieurement.

Dataset PyTorch pour la détection  
Un dataset personnalisé est implémenté afin de charger dynamiquement les images RAW et leurs annotations XML.  
Pour chaque image, les bounding boxes et les labels sont extraits à partir du fichier XML et convertis dans un format compatible avec Faster R-CNN.  
La classe 0 est réservée au fond (background), et les classes réelles sont indexées à partir de 1, conformément aux conventions de torchvision. Les bounding boxes invalides (largeur ou hauteur nulle ou négative) sont filtrées automatiquement.

Modèle de détection  
Le modèle utilisé est Faster R-CNN avec un backbone ResNet-50 et un Feature Pyramid Network (FPN). Les poids du modèle (backbone compris) sont pré-entraînés sur le jeu de données COCO, ce qui permet d’accélérer la convergence et d’améliorer les performances.  
La tête de prédiction des boîtes (classification et régression) est remplacée afin de correspondre exactement aux 20 classes du jeu de données, auxquelles s’ajoute la classe background.

Entraînement  
L’entraînement est réalisé à l’aide de l’optimiseur SGD, avec un learning rate initial de 0.005, un momentum de 0.9 et un weight decay de 5e-4.  
Un scheduler de type StepLR est utilisé afin de diviser le learning rate par 10 (gamma = 0.1) tous les trois epochs.  
Le batch size est fixé à 2, ce qui est adapté à l’architecture Faster R-CNN et aux contraintes de mémoire GPU.

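À chaque epoch, le modèle est entraîné sur le jeu d’apprentissage puis évalué sur le jeu de validation. La loss de validation est calculée sans gradient, mais avec le modèle en mode entraînement, car les modèles de détection de torchvision ne renvoient les pertes que dans ce mode. Le modèle présentant la meilleure loss de validation est automatiquement sauvegardé.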

Sauvegarde et sortie du modèle  
Le meilleur modèle de détection est sauvegardé dans le fichier Final_product/models/best_faster_rcnn_raw.pth.  
Ce fichier contient les poids du modèle ainsi que les informations nécessaires à l’inférence, notamment la liste des classes et le dictionnaire de correspondance classe–indice. Ces indices commencent à 0, alors que les labels prédits par le détecteur sont décalés de +1 (0 = background).

Rôle dans le pipeline global  
Le détecteur constitue la première étape du pipeline. Il permet de localiser les aéronefs dans une image et de générer des bounding boxes précises.  
Ces bounding boxes sont ensuite utilisées pour extraire des crops, qui sont transmis à un second modèle de classification basé sur ResNet-50. Cette séparation entre détection et classification fine permet d’améliorer la robustesse et la précision globale du système, en particulier pour des classes visuellement proches.
