<a href="https://colab.research.google.com/github/Paul-locatelli/projet-detection-avions-paul-omar/blob/main/Model_classificateur.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from google.colab import drive
drive.mount("/content/drive")

import os, glob, pathlib, random, shutil
import xml.etree.ElementTree as ET
from PIL import Image

# Project layout on Google Drive.
ROOT = "/content/drive/MyDrive/Final_product"
DATASET_ROOT = os.path.join(ROOT, "DataSet")
RAW_DIR, OUT_KEYS, META_PATH = (
    os.path.join(DATASET_ROOT, entry) for entry in ("raw", "splits_keys", "dataset_meta.txt")
)

# Crops are written to local disk first (fast), then copied to Drive in one go at the end.
LOCAL_CROPS = "/content/cls_crops"
DRIVE_CROPS = os.path.join(DATASET_ROOT, "cls_crops")

# Start every run from an empty local crop folder so stale crops never survive.
if os.path.exists(LOCAL_CROPS):
    shutil.rmtree(LOCAL_CROPS)
for needed_dir in (LOCAL_CROPS, OUT_KEYS):
    os.makedirs(needed_dir, exist_ok=True)

def norm_stem(p):
    """Return the pairing key for a file path.

    The key is the file stem, lower-cased, with every space, '-' and '_' removed.
    For example, "IMG_01.jpg" and "img-01.xml" map to the same key.
    """
    stem = pathlib.Path(p).stem.lower()
    return stem.translate(str.maketrans("", "", " -_"))

# Image / annotation discovery under RAW_DIR.
IMG_EXTS = (".jpg", ".jpeg", ".png")

def _list_raw_files(exts):
    """Return sorted paths under RAW_DIR (searched recursively) that end in one of `exts`.

    The extension test is case-insensitive. A "*.jpg" glob is case-sensitive on
    Linux (Colab) and would silently skip files such as "IMG_01.JPG".
    """
    candidates = glob.glob(os.path.join(RAW_DIR, "**", "*"), recursive=True)
    return sorted(p for p in candidates if p.lower().endswith(exts))

def _index_by_norm_stem(paths, kind):
    """Map norm_stem(path) -> path, keeping the first path per key in sorted order.

    A path whose normalized stem collides with an earlier one is ignored. The
    collision is reported so that samples do not vanish unnoticed.
    """
    index, shadowed = {}, []
    for p in paths:
        k = norm_stem(p)
        if k in index:
            shadowed.append(p)
        else:
            index[k] = p
    if shadowed:
        print(f"Warning: ignored {len(shadowed)} {kind} file(s) whose normalized name collides "
              f"with an earlier file, e.g. {shadowed[:3]}")
    return index

img_files = _list_raw_files(IMG_EXTS)
xml_files = _list_raw_files((".xml",))

assert len(img_files) > 0, "No images in raw/"
assert len(xml_files) > 0, "No xml in raw/"

img_map = _index_by_norm_stem(img_files, "image")
xml_map = _index_by_norm_stem(xml_files, "xml")

# Keep only samples that have both an image and an annotation.
matched = sorted(set(img_map.keys()) & set(xml_map.keys()))
assert len(matched) > 0, "No matched image/xml pairs."

print("Matched pairs:", len(matched))

def parse_voc(xml_path):
    """Parse a Pascal VOC annotation into a list of labelled boxes.

    Args:
        xml_path: path or file object accepted by `xml.etree.ElementTree.parse`.

    Returns:
        A list of (name, x1, y1, x2, y2) tuples in document order. Names are
        whitespace-stripped and coordinates are truncated to int. An <object> is
        skipped when:
        - it has no <bndbox> (silently);
        - its name or one of its four coordinates is missing (with a warning);
        - its box has zero or negative width or height.
    """
    import warnings  # local import: the cell's import line does not include it

    root = ET.parse(xml_path).getroot()
    objs = []
    for obj in root.findall("object"):
        # Strip surrounding whitespace/newlines so "A1\n" and "A1" are the same class.
        name = (obj.findtext("name") or "").strip()
        bnd = obj.find("bndbox")
        if bnd is None:
            continue
        coords = [bnd.findtext(tag) for tag in ("xmin", "ymin", "xmax", "ymax")]
        if not name or any(c is None or not c.strip() for c in coords):
            # Otherwise this surfaces later as an obscure TypeError (None label / float(None)).
            warnings.warn(f"{xml_path}: skipping <object> with missing name or box coordinate")
            continue
        x1, y1, x2, y2 = (int(float(c)) for c in coords)
        if x2 > x1 and y2 > y1:
            objs.append((name, x1, y1, x2, y2))
    return objs

# Class vocabulary: every object name found in the matched annotations, sorted.
label_set = {name for k in matched for name, *_ in parse_voc(xml_map[k])}

classes = sorted(label_set)
print("Num classes:", len(classes))
print("Classes:", classes)

# Deterministic ~80/10/10 split at the image level (test takes the remainder).
# All crops of one image therefore land in the same split. `matched` is shuffled in place.
random.seed(42)
random.shuffle(matched)

n = len(matched)
n_train = int(0.8 * n)
n_val = int(0.1 * n)
val_end = n_train + n_val

train_keys, val_keys, test_keys = matched[:n_train], matched[n_train:val_end], matched[val_end:]

print("Split sizes:", len(train_keys), len(val_keys), len(test_keys))

def write_keys(keys, path):
    """Write `keys` to `path`, one per line with a trailing newline each.

    Any existing file at `path` is overwritten.
    """
    with open(path, "w") as out:
        out.writelines(key + "\n" for key in keys)

# Persist the image keys of each split (train.txt / val.txt / test.txt).
splits = {"train": train_keys, "val": val_keys, "test": test_keys}
for split_name, split_keys in splits.items():
    write_keys(split_keys, os.path.join(OUT_KEYS, split_name + ".txt"))

# Create one folder per (split, class) so that torchvision ImageFolder can read the crops directly.
for split_name in splits:
    for cls_name in classes:
        os.makedirs(os.path.join(LOCAL_CROPS, split_name, cls_name), exist_ok=True)

def safe_crop(img, x1, y1, x2, y2):
    """Crop `img` to the box (x1, y1, x2, y2) after clamping it to the image bounds.

    If clamping leaves an axis with zero or negative extent (e.g. the box lies off
    one edge), that axis alone is widened to 1 pixel. The other axis keeps its
    extent.

    Args:
        img: object exposing `.size` -> (width, height) and `.crop(box)`, such as a PIL image.
        x1, y1, x2, y2: box corners in pixels; they are converted with int().

    Returns:
        Whatever `img.crop` returns for the clamped box.
    """
    w, h = img.size
    x1 = max(0, min(int(x1), w-1))
    y1 = max(0, min(int(y1), h-1))
    x2 = max(0, min(int(x2), w))
    y2 = max(0, min(int(y2), h))
    # Repair each axis independently. Previously, a degenerate y also collapsed a valid x range.
    if x2 <= x1:
        x2 = min(w, x1 + 1)
    if y2 <= y1:
        y2 = min(h, y1 + 1)
    return img.crop((x1, y1, x2, y2))

def make_crops(keys, split_name):
    """Save every annotated object of the given images as an RGB JPEG crop.

    Crops are written to LOCAL_CROPS/<split_name>/<class name>/<key>__<object index>.jpg.
    The class folders must already exist.

    Args:
        keys: image keys present in both `img_map` and `xml_map`.
        split_name: sub-folder of LOCAL_CROPS to write into (e.g. "train").

    Returns:
        Number of crops written.
    """
    count = 0
    for k in keys:
        # The context manager closes the source file right away rather than whenever it is GC'd.
        with Image.open(img_map[k]) as src:
            img = src.convert("RGB")
        for j, (name, x1, y1, x2, y2) in enumerate(parse_voc(xml_map[k])):
            crop = safe_crop(img, x1, y1, x2, y2)
            out_path = os.path.join(LOCAL_CROPS, split_name, name, f"{k}__{j}.jpg")
            crop.save(out_path, quality=90)
            count += 1
    return count

# Crop every split. Keys are disjoint image ids, so crops of one image never span splits.
c1 = make_crops(train_keys, "train")
c2 = make_crops(val_keys, "val")
c3 = make_crops(test_keys, "test")
print("Crops counts:", c1, c2, c3, "| total:", c1+c2+c3)

# Plain key=value metadata. UTF-8 is explicit so class names round-trip regardless of locale.
with open(META_PATH, "w", encoding="utf-8") as f:
    f.write("classes=" + ",".join(classes) + "\n")
    f.write("num_classes=" + str(len(classes)) + "\n")
    f.write("matched_pairs=" + str(len(matched)) + "\n")

print("Meta saved:", META_PATH)

# Replace the Drive copy wholesale with the fresh local crops.
# One bulk copy is faster than writing each crop to Drive individually.
if os.path.exists(DRIVE_CROPS):
    shutil.rmtree(DRIVE_CROPS)
shutil.copytree(LOCAL_CROPS, DRIVE_CROPS)

print("✅ Crops copied to Drive:", DRIVE_CROPS)
print("✅ Keys:", OUT_KEYS)


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Matched pairs: 1331
Num classes: 20
Classes: ['A1', 'A10', 'A11', 'A12', 'A13', 'A14', 'A15', 'A16', 'A17', 'A18', 'A19', 'A2', 'A20', 'A3', 'A4', 'A5', 'A6', 'A7', 'A8', 'A9']
Split sizes: 1064 133 134
Crops counts: 6220 876 774 | total: 7870
Meta saved: /content/drive/MyDrive/Final_product/DataSet/dataset_meta.txt
✅ Crops copied to Drive: /content/drive/MyDrive/Final_product/DataSet/cls_crops
✅ Keys: /content/drive/MyDrive/Final_product/DataSet/splits_keys


In [None]:
import os

def remove_empty_class_dirs(root_dir):
    """Delete class sub-folders of `root_dir` that contain no image files.

    A folder that holds only non-image files (e.g. ".DS_Store") counts as empty.
    It is removed together with those files, because `os.rmdir` fails on a
    non-empty folder. Recent torchvision ImageFolder versions reject class
    folders without valid images.

    Args:
        root_dir: split directory laid out as root_dir/<class>/<files>.

    Returns:
        Names of the removed class folders, in sorted order.
    """
    import shutil  # local import: this cell only imports `os`

    image_exts = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp")
    removed = []
    for cls in sorted(os.listdir(root_dir)):
        cls_dir = os.path.join(root_dir, cls)
        if not os.path.isdir(cls_dir):
            continue
        if not any(f.lower().endswith(image_exts) for f in os.listdir(cls_dir)):
            shutil.rmtree(cls_dir)
            removed.append(cls)
    return removed

# This cell runs before the DataLoader cell defines VAL_DIR/TEST_DIR, so derive them here
# from the crop folder written by the preprocessing cell (DRIVE_CROPS == CROPS_DIR later).
VAL_DIR = os.path.join(DRIVE_CROPS, "val")
TEST_DIR = os.path.join(DRIVE_CROPS, "test")

val_removed  = remove_empty_class_dirs(VAL_DIR)
test_removed = remove_empty_class_dirs(TEST_DIR)

print("Removed empty classes in val :", val_removed)
print("Removed empty classes in test:", test_removed)


Removed empty classes in val : []
Removed empty classes in test: []


In [None]:
import os
import torch
import torchvision.transforms as T
import torchvision.datasets as datasets
from torch.utils.data import DataLoader

# Crop dataset written by the preprocessing cell: cls_crops/<split>/<class>/*.jpg
ROOT = "/content/drive/MyDrive/Final_product"
CROPS_DIR = os.path.join(ROOT, "DataSet", "cls_crops")

TRAIN_DIR, VAL_DIR, TEST_DIR = (os.path.join(CROPS_DIR, split) for split in ("train", "val", "test"))

# Fail fast, naming the offending path, if a split folder is missing.
for split_dir in (TRAIN_DIR, VAL_DIR, TEST_DIR):
    assert os.path.isdir(split_dir), f"Missing {split_dir}"

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print("device:", device)

# Training augmentation: fixed resize, then a random crop of 75-100% of the area, a flip,
# and mild colour noise. Normalization uses the standard ImageNet mean/std.
train_tf = T.Compose([
    T.Resize((256,256)),
    T.RandomResizedCrop(224, scale=(0.75, 1.0)),
    T.RandomHorizontalFlip(p=0.5),
    T.RandomApply([T.ColorJitter(0.2,0.2,0.15,0.02)], p=0.7),
    T.RandomGrayscale(p=0.05),
    T.ToTensor(),
    T.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]),
])

# Evaluation: deterministic resize of the whole crop, with the same normalization.
eval_tf = T.Compose([
    T.Resize((224,224)),
    T.ToTensor(),
    T.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]),
])

def _align_to_train_classes(ds, ref_class_to_idx):
    """Re-index an ImageFolder's targets onto the training set's class indices.

    ImageFolder numbers classes by the sorted sub-folders of its own root. If
    val/ or test/ lacks a class folder (e.g. after remove_empty_class_dirs), every
    later class would shift down one index and stop matching the model's logits.
    This rewrites `samples`/`imgs`/`targets` in place, adopts the reference
    `classes`/`class_to_idx`, and returns `ds`.

    Raises:
        ValueError: if `ds` contains a class that the reference mapping lacks.
    """
    unknown = sorted(set(ds.classes) - set(ref_class_to_idx))
    if unknown:
        raise ValueError(f"{ds.root}: classes not present in train: {unknown}")
    remap = {ds.class_to_idx[c]: ref_class_to_idx[c] for c in ds.classes}
    ds.samples = [(path, remap[target]) for path, target in ds.samples]
    ds.imgs = ds.samples
    ds.targets = [target for _, target in ds.samples]
    ds.classes = sorted(ref_class_to_idx, key=ref_class_to_idx.get)
    ds.class_to_idx = dict(ref_class_to_idx)
    return ds

train_ds = datasets.ImageFolder(TRAIN_DIR, transform=train_tf)

classes = train_ds.classes
class_to_idx = train_ds.class_to_idx
num_classes = len(classes)

val_ds   = _align_to_train_classes(datasets.ImageFolder(VAL_DIR,  transform=eval_tf), class_to_idx)
test_ds  = _align_to_train_classes(datasets.ImageFolder(TEST_DIR, transform=eval_tf), class_to_idx)

print("Num classes:", num_classes)
print("Classes:", classes)

BATCH = 64 if use_cuda else 32
train_loader = DataLoader(train_ds, batch_size=BATCH, shuffle=True,  num_workers=0, pin_memory=use_cuda)
val_loader   = DataLoader(val_ds,   batch_size=BATCH, shuffle=False, num_workers=0, pin_memory=use_cuda)
test_loader  = DataLoader(test_ds,  batch_size=BATCH, shuffle=False, num_workers=0, pin_memory=use_cuda)

print("train/val/test crops:", len(train_ds), len(val_ds), len(test_ds))


device: cuda
Num classes: 20
Classes: ['A1', 'A10', 'A11', 'A12', 'A13', 'A14', 'A15', 'A16', 'A17', 'A18', 'A19', 'A2', 'A20', 'A3', 'A4', 'A5', 'A6', 'A7', 'A8', 'A9']
train/val/test crops: 6220 876 774


In [None]:
import os

def remove_empty_class_dirs(root_dir):
    """Delete class sub-folders of `root_dir` that contain no image files.

    A folder that holds only non-image files (e.g. ".DS_Store") counts as empty.
    It is removed together with those files, because `os.rmdir` fails on a
    non-empty folder. Recent torchvision ImageFolder versions reject class
    folders without valid images.

    Args:
        root_dir: split directory laid out as root_dir/<class>/<files>.

    Returns:
        Names of the removed class folders, in sorted order.
    """
    import shutil  # local import: this cell only imports `os`

    image_exts = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp")
    removed = []
    for cls in sorted(os.listdir(root_dir)):
        cls_dir = os.path.join(root_dir, cls)
        if not os.path.isdir(cls_dir):
            continue
        if not any(f.lower().endswith(image_exts) for f in os.listdir(cls_dir)):
            shutil.rmtree(cls_dir)
            removed.append(cls)
    return removed

# Prune class folders without images from the evaluation splits (val first, then test).
val_removed, test_removed = (remove_empty_class_dirs(split_dir) for split_dir in (VAL_DIR, TEST_DIR))

for label, removed in (("Removed empty classes in val :", val_removed),
                       ("Removed empty classes in test:", test_removed)):
    print(label, removed)


In [None]:
from google.colab import drive
drive.mount("/content/drive")

import os
import torch
import torchvision.transforms as T
import torchvision.datasets as datasets
from torch.utils.data import DataLoader

# Repeats the crop-path and device setup so the following cells only depend on this one.
ROOT = "/content/drive/MyDrive/Final_product"
CROPS_DIR = os.path.join(ROOT, "DataSet", "cls_crops")

TRAIN_DIR, VAL_DIR, TEST_DIR = (os.path.join(CROPS_DIR, name) for name in ("train", "val", "test"))

for split_dir in (TRAIN_DIR, VAL_DIR, TEST_DIR):
    assert os.path.isdir(split_dir)

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print("device:", device)


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
device: cuda


In [None]:
def remove_empty_class_dirs(root_dir):
    """Delete class sub-folders of `root_dir` that contain no image files.

    A folder that holds only non-image files (e.g. ".DS_Store") counts as empty.
    It is removed together with those files, because `os.rmdir` fails on a
    non-empty folder. Recent torchvision ImageFolder versions reject class
    folders without valid images.

    Args:
        root_dir: split directory laid out as root_dir/<class>/<files>.

    Returns:
        Names of the removed class folders, in sorted order. The order is
        deterministic, unlike raw os.listdir order.
    """
    import shutil  # local import: no earlier import in this cell provides it

    image_exts = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp")
    removed = []
    for cls in sorted(os.listdir(root_dir)):
        cls_dir = os.path.join(root_dir, cls)
        if not os.path.isdir(cls_dir):
            continue
        if not any(f.lower().endswith(image_exts) for f in os.listdir(cls_dir)):
            shutil.rmtree(cls_dir)
            removed.append(cls)
    return removed

# Prune and report each evaluation split in turn (val first, then test).
removed_val = remove_empty_class_dirs(VAL_DIR)
print("Removed empty classes in val :", removed_val)
removed_test = remove_empty_class_dirs(TEST_DIR)
print("Removed empty classes in test:", removed_test)


Removed empty classes in val : []
Removed empty classes in test: []


In [None]:
# Training augmentation: fixed resize, then a random crop of 75-100% of the area, a flip,
# and mild colour noise. Normalization uses the standard ImageNet mean/std.
train_tf = T.Compose([
    T.Resize((256,256)),
    T.RandomResizedCrop(224, scale=(0.75, 1.0)),
    T.RandomHorizontalFlip(0.5),
    T.RandomApply([T.ColorJitter(0.2,0.2,0.15,0.02)], p=0.7),
    T.RandomGrayscale(0.05),
    T.ToTensor(),
    T.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]),
])

# Evaluation: deterministic resize of the whole crop, with the same normalization.
eval_tf = T.Compose([
    T.Resize((224,224)),
    T.ToTensor(),
    T.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]),
])

def _align_to_train_classes(ds, ref_class_to_idx):
    """Re-index an ImageFolder's targets onto the training set's class indices.

    ImageFolder numbers classes by the sorted sub-folders of its own root. If
    val/ or test/ lacks a class folder (e.g. after remove_empty_class_dirs), every
    later class would shift down one index and stop matching the model's logits.
    This rewrites `samples`/`imgs`/`targets` in place, adopts the reference
    `classes`/`class_to_idx`, and returns `ds`.

    Raises:
        ValueError: if `ds` contains a class that the reference mapping lacks.
    """
    unknown = sorted(set(ds.classes) - set(ref_class_to_idx))
    if unknown:
        raise ValueError(f"{ds.root}: classes not present in train: {unknown}")
    remap = {ds.class_to_idx[c]: ref_class_to_idx[c] for c in ds.classes}
    ds.samples = [(path, remap[target]) for path, target in ds.samples]
    ds.imgs = ds.samples
    ds.targets = [target for _, target in ds.samples]
    ds.classes = sorted(ref_class_to_idx, key=ref_class_to_idx.get)
    ds.class_to_idx = dict(ref_class_to_idx)
    return ds

train_ds = datasets.ImageFolder(TRAIN_DIR, transform=train_tf)

classes = train_ds.classes
class_to_idx = train_ds.class_to_idx
num_classes = len(classes)

val_ds   = _align_to_train_classes(datasets.ImageFolder(VAL_DIR,  transform=eval_tf), class_to_idx)
test_ds  = _align_to_train_classes(datasets.ImageFolder(TEST_DIR, transform=eval_tf), class_to_idx)

print("Num classes:", num_classes)
print("Classes:", classes)

BATCH = 64 if use_cuda else 32
train_loader = DataLoader(train_ds, batch_size=BATCH, shuffle=True,  num_workers=0, pin_memory=use_cuda)
val_loader   = DataLoader(val_ds,   batch_size=BATCH, shuffle=False, num_workers=0, pin_memory=use_cuda)
test_loader  = DataLoader(test_ds,  batch_size=BATCH, shuffle=False, num_workers=0, pin_memory=use_cuda)

print("Crops train/val/test:", len(train_ds), len(val_ds), len(test_ds))


Num classes: 20
Classes: ['A1', 'A10', 'A11', 'A12', 'A13', 'A14', 'A15', 'A16', 'A17', 'A18', 'A19', 'A2', 'A20', 'A3', 'A4', 'A5', 'A6', 'A7', 'A8', 'A9']
Crops train/val/test: 6220 876 774


In [None]:
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
from torch.amp import autocast, GradScaler

# ResNet-50 with torchvision's default pretrained weights. Every pretrained parameter is
# frozen first, so the first FREEZE_EPOCHS epochs only train the new classification head.
clf = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
clf.requires_grad_(False)

# The replacement head is created after freezing, so its parameters stay trainable.
clf.fc = nn.Linear(clf.fc.in_features, num_classes)
clf.to(device)

criterion = nn.CrossEntropyLoss(label_smoothing=0.10)
# Phase-1 optimizer: head parameters only. The training loop builds a new one on unfreeze.
optimizer = optim.AdamW(clf.fc.parameters(), lr=3e-4, weight_decay=1e-4)

EPOCHS, FREEZE_EPOCHS, PATIENCE = 20, 2, 4

scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
# Loss scaling is used only for CUDA mixed precision. None selects the plain fp32 path.
scaler = GradScaler("cuda") if use_cuda else None

@torch.no_grad()
def eval_acc(loader):
    """Return the top-1 accuracy of the global `clf` over `loader`.

    Switches `clf` to eval mode. The training loop calls `clf.train()` again at
    the start of every epoch.

    Raises:
        ValueError: if `loader` yields no samples, since the accuracy would be undefined.
    """
    clf.eval()
    correct, total = 0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        out = clf(x)
        correct += (out.argmax(1) == y).sum().item()
        total += y.size(0)
    if total == 0:
        raise ValueError("eval_acc: loader produced no samples")
    return correct / total

# Checkpoint destination. The "models" folder is created up front: nothing earlier in the
# notebook creates it, and torch.save does not create missing parent directories.
BEST_CKPT_PATH = os.path.join(ROOT, "models", "best_crop_classifier_resnet50.pth")
os.makedirs(os.path.dirname(BEST_CKPT_PATH), exist_ok=True)

# Start from -inf rather than 0.0 so the first epoch is always checkpointed, even at 0%
# validation accuracy. The evaluation cell then always has a file to load.
best_val_acc = float("-inf")
bad_epochs = 0  # consecutive epochs without a new best val accuracy (early stopping)

print("\n🚀 CLASSIFIER TRAINING START\n")

for epoch in range(1, EPOCHS+1):

    # Phase 2: after FREEZE_EPOCHS head-only epochs, fine-tune every layer. This uses a lower
    # LR and a new cosine schedule spanning the remaining EPOCHS - FREEZE_EPOCHS epochs.
    if epoch == FREEZE_EPOCHS + 1:
        for p in clf.parameters():
            p.requires_grad = True
        optimizer = optim.AdamW(clf.parameters(), lr=1e-4, weight_decay=1e-4)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS - epoch + 1)
        print("🔓 backbone unfrozen")

    # NOTE(review): during the head-only phase, train() still lets the frozen backbone's
    # BatchNorm layers update their running statistics. Put them in eval() if that is unwanted.
    clf.train()
    correct, total = 0, 0

    for x,y in train_loader:
        x,y = x.to(device), y.to(device)
        optimizer.zero_grad(set_to_none=True)

        if scaler is not None:
            # Mixed-precision step: the loss is scaled before backward and unscaled in step().
            with autocast("cuda"):
                out = clf(x)
                loss = criterion(out, y)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            out = clf(x)
            loss = criterion(out, y)
            loss.backward()
            optimizer.step()

        correct += (out.argmax(1) == y).sum().item()
        total += y.size(0)

    train_acc = correct / total
    val_acc = eval_acc(val_loader)
    scheduler.step()  # one scheduler step per epoch

    print(f"Epoch {epoch:02d} | train_acc={train_acc:.4f} | val_acc={val_acc:.4f}")

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        bad_epochs = 0
        # Store the class mapping next to the weights so inference can decode predictions.
        torch.save({
            "model_state_dict": clf.state_dict(),
            "classes": classes,
            "class_to_idx": class_to_idx
        }, BEST_CKPT_PATH)
        print("✅ saved best")
    else:
        bad_epochs += 1

    if bad_epochs >= PATIENCE:
        print("⏹ early stopping")
        break

print("Best val acc:", best_val_acc)



üöÄ CLASSIFIER TRAINING START



ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.


KeyboardInterrupt



In [None]:
import torch
import torchvision.models as models
import torch.nn as nn

MODEL_PATH = os.path.join(ROOT, "models", "best_crop_classifier_resnet50.pth")
# weights_only=True: the training cell saves only a state dict, a list of class names and a
# str->int dict. The restricted unpickler handles those, and a tampered file cannot run code.
ckpt = torch.load(MODEL_PATH, map_location=device, weights_only=True)

# test_ds labels follow test_ds.class_to_idx. Each test class must map to the output index
# the model was trained with; otherwise the accuracy below would be silently wrong.
ckpt_class_to_idx = ckpt["class_to_idx"]
mismatched = sorted(c for c, i in test_ds.class_to_idx.items() if ckpt_class_to_idx.get(c) != i)
if mismatched:
    raise ValueError(f"Test label indices differ from the checkpoint for classes: {mismatched}")

# Rebuild the architecture without pretrained weights; the saved state dict replaces them.
model = models.resnet50(weights=None)
model.fc = nn.Linear(model.fc.in_features, len(ckpt["classes"]))
model.load_state_dict(ckpt["model_state_dict"])
model.to(device)
model.eval()

@torch.no_grad()
def test_acc(loader):
    """Return the top-1 accuracy of the loaded `model` over `loader`.

    Raises:
        ValueError: if `loader` yields no samples.
    """
    correct, total = 0, 0
    for x,y in loader:
        x,y = x.to(device), y.to(device)
        out = model(x)
        correct += (out.argmax(1) == y).sum().item()
        total += y.size(0)
    if total == 0:
        raise ValueError("test_acc: loader produced no samples")
    return correct / total

acc = test_acc(test_loader)
print("✅ Test accuracy:", acc)
