# Image classification


In [1]:
# Imports et chemins
import sys
from pathlib import Path
sys.path.append(str(Path('../..').resolve()))
import random
from collections import Counter
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms, models
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, accuracy_score, classification_report
from PIL import Image
from model.image.dataset import SimpleImageFolder


Paramètres


In [None]:
# Experiment parameters
classes = ['Chao', 'Milho', 'Ervas']
image_size = 128  # side length (pixels) every image is resized to
per_class_limit = 800  # maximum number of images kept per class
batch_size = 32
noise_std = 0.0  # 0 for clean images, 0.02-0.05 for white noise
model_name = 'resnet18'  # resnet18/resnet34/vgg16/vgg19/mobilenet_v2
optimizer_name = 'adam'  # adam/rmsprop/adagrad
lr = 1e-3
dropout = 0.3
use_pretrained = True
epochs = 2
base_dir = Path('../data/ImagensTCCRotuladas').resolve()
split_dirs = {
    'train': base_dir / 'Treino',
    # The validation folder is located by prefix: take the first
    # sub-directory (sorted, so the choice is reproducible) whose name starts
    # with "valid", ignoring plain files. None when no such folder exists.
    'val': next(
        (
            p for p in sorted(base_dir.iterdir())
            if p.is_dir() and p.name.lower().startswith('valid')
        ),
        None,
    ),
    'test': base_dir / 'Teste',
}
print('Classes:', classes)
print('per_class_limit:', per_class_limit, 'image_size:', image_size)


Exploration des splits


In [None]:
# Exploration des splits
from typing import Dict, List

def collect_files(split: str, classes):
    """Map each class name to the sorted image paths found for ``split``.

    Files are looked up directly inside ``split_dirs[split] / <class>`` and
    kept when their suffix is .jpg, .jpeg or .png (case-insensitive).
    """
    image_suffixes = {".jpg", ".jpeg", ".png"}
    split_root = split_dirs[split]

    def list_images(class_name):
        # Non-recursive listing of one class folder, filtered by extension.
        candidates = (split_root / class_name).glob("*")
        return sorted(f for f in candidates if f.suffix.lower() in image_suffixes)

    return {class_name: list_images(class_name) for class_name in classes}

# Image paths per class for every split; the validation split is optional.
train_files = collect_files("train", classes)
if split_dirs["val"]:
    val_files = collect_files("val", classes)
else:
    val_files = {}
test_files = collect_files("test", classes)

# Number of images per class in each split (empty dict when the split is absent).
train_counts = {name: len(found) for name, found in train_files.items()}
val_counts = {name: len(found) for name, found in val_files.items()}
test_counts = {name: len(found) for name, found in test_files.items()}
print(train_counts, val_counts, test_counts, sep="\n")

# Size of the smallest training class (0 when there are no classes).
min_train = min(train_counts.values(), default=0)
print("Min train class size:", min_train)


Diagramme en barres par split


In [None]:

import numpy as np

def plot_counts(counts: dict, title: str, color: str):
    """Show a bar chart of ``counts`` with each value written above its bar.

    Args:
        counts: mapping from bar label to bar height.
        title: figure title.
        color: matplotlib color used for every bar.
    """
    names = list(counts)
    heights = [counts[name] for name in names]
    plt.figure(figsize=(6, 3))
    plt.bar(names, heights, color=color)
    plt.title(title)
    plt.xticks(rotation=25)
    # Annotate each bar with its value, anchored at the top of the bar.
    for position, height in enumerate(heights):
        plt.text(position, height, str(height), ha='center', va='bottom', fontsize=8)
    plt.tight_layout()
    plt.show()

# The train chart is always drawn; val/test charts only when counts exist.
plot_counts(train_counts, 'Train', 'tab:blue')
for _counts, _title, _color in (
    (val_counts, 'Val', 'tab:orange'),
    (test_counts, 'Test', 'tab:green'),
):
    if _counts:
        plot_counts(_counts, _title, _color)


In [None]:
# Visual samples: one row per class, up to n_per_class training images each
n_per_class = 3
# squeeze=False keeps `axes` two-dimensional even with a single class or a
# single column, so axes[row, col] is always a valid index.
fig, axes = plt.subplots(
    len(classes),
    n_per_class,
    figsize=(3 * n_per_class, 3 * len(classes)),
    squeeze=False,
)
# Hide the frame of every cell up front, including cells that stay empty
# because a class has fewer than n_per_class images.
for ax in axes.ravel():
    ax.axis("off")
for row, cls in enumerate(classes):
    files = train_files.get(cls, [])[:n_per_class]
    for col, path in enumerate(files):
        ax = axes[row, col]
        # The context manager closes the file once the pixels are drawn.
        with Image.open(path) as img:
            ax.imshow(img)
        ax.set_title(cls)
plt.tight_layout(); plt.show()


In [None]:
# Loaders (balance + bruit blanc optionnel)
class AddWhiteNoise:
    """Transform that adds zero-mean Gaussian noise and clamps to [0, 1]."""

    def __init__(self, std: float = noise_std):
        # The default is the value noise_std had when this class was defined.
        self.std = std

    def __call__(self, tensor):
        # Noise has the same shape as the input and is scaled by self.std.
        perturbed = tensor + self.std * torch.randn_like(tensor)
        return perturbed.clamp(0.0, 1.0)

def build_loader(split: str, balance: bool):
    """Build the DataLoader for one dataset split.

    Args:
        split: key of ``split_dirs`` ('train', 'val' or 'test').
        balance: when true, every non-empty class is randomly undersampled
            (fixed seed) to the size of the smallest non-empty class. The
            size is measured after ``per_class_limit`` has been applied, so
            classes end up equally represented even when the limit is larger
            than the smallest class.

    Returns:
        A DataLoader over a SimpleImageFolder; shuffled only for 'train'.

    Raises:
        ValueError: if no directory is configured for ``split``.
    """
    split_path = split_dirs[split]
    if split_path is None:
        # e.g. the optional validation folder was not found on disk
        raise ValueError(f"no directory configured for split {split!r}")

    image_suffixes = {'.jpg', '.jpeg', '.png'}
    rng = random.Random(42)  # fixed seed keeps the balanced subset reproducible

    # One sorted, capped file list per class, in the order of `classes`.
    per_class = []
    for cls in classes:
        files = sorted(
            p for p in (split_path / cls).glob("*")
            if p.suffix.lower() in image_suffixes
        )
        if per_class_limit:
            files = files[:per_class_limit]  # keeps the first N in sorted order
        per_class.append(files)

    if balance:
        # Balance on what is actually available in this split rather than on
        # per_class_limit alone. Empty classes are ignored when computing the
        # target so one missing folder does not empty the whole split.
        target = min((len(files) for files in per_class if files), default=0)
        for files in per_class:
            rng.shuffle(files)
            del files[target:]

    paths, labels = [], []
    for label, files in enumerate(per_class):
        paths.extend(files)
        labels.extend([label] * len(files))

    tfms = [transforms.Resize((image_size, image_size)), transforms.ToTensor()]
    if split == 'train' and noise_std:
        # White noise is an augmentation: applied to training images only.
        tfms.append(AddWhiteNoise(noise_std))
    transform = transforms.Compose(tfms)
    ds = SimpleImageFolder(paths, labels, image_size=image_size)
    # NOTE(review): SimpleImageFolder is project code not visible here; this
    # presumably replaces the transform it applies per item - confirm.
    ds.transform = transform
    return DataLoader(ds, batch_size=batch_size, shuffle=(split == 'train'))

# The train loader is requested with balance=True, the val loader without.
train_loader = build_loader('train', balance=True)
val_loader = build_loader('val', balance=False)
print(f"Train size {len(train_loader.dataset)}")
print(f"Val size {len(val_loader.dataset)}")


In [None]:
# Backbones and optimizers
available_models = ['resnet18', 'resnet34', 'vgg16', 'vgg19', 'mobilenet_v2']
# Explicit check instead of `assert`: assertions are skipped when Python runs
# with -O, which would let an unsupported name through.
if model_name not in available_models:
    raise ValueError(
        f"unknown model_name {model_name!r}; expected one of {available_models}"
    )

def create_backbone(name: str, num_classes: int, dropout: float, use_pretrained: bool):
    """Instantiate a torchvision backbone whose last layer outputs ``num_classes`` logits.

    Args:
        name: 'resnet18', 'resnet34', 'vgg16', 'vgg19' or 'mobilenet_v2'.
        num_classes: output size of the replacement linear layer.
        dropout: probability of the Dropout placed before the new linear
            layer. Used for the ResNet and MobileNetV2 heads only; the VGG
            head is replaced by a bare Linear layer.
        use_pretrained: load the DEFAULT torchvision weights when true,
            otherwise build the model with ``weights=None``.

    Returns:
        The model with its final classification layer replaced.

    Raises:
        ValueError: if ``name`` is not one of the names listed above.
    """
    # name -> (weights enum in torchvision.models, how the head is replaced)
    catalog = {
        'resnet18': ('ResNet18_Weights', 'fc'),
        'resnet34': ('ResNet34_Weights', 'fc'),
        'vgg16': ('VGG16_Weights', 'classifier_plain'),
        'vgg19': ('VGG19_Weights', 'classifier_plain'),
        'mobilenet_v2': ('MobileNet_V2_Weights', 'classifier_dropout'),
    }
    if name not in catalog:
        raise ValueError('backbone not handled')

    weights_enum, head_kind = catalog[name]
    # The constructor in torchvision.models has the same name as the backbone.
    weights = getattr(models, weights_enum).DEFAULT if use_pretrained else None
    net = getattr(models, name)(weights=weights)

    if head_kind == 'fc':
        width = net.fc.in_features
        net.fc = nn.Sequential(nn.Dropout(dropout), nn.Linear(width, num_classes))
        return net

    # VGG / MobileNetV2: only the last layer of `classifier` is swapped.
    width = net.classifier[-1].in_features
    head = nn.Linear(width, num_classes)
    if head_kind == 'classifier_dropout':
        head = nn.Sequential(nn.Dropout(dropout), head)
    net.classifier[-1] = head
    return net

def make_optimizer(name: str, params, lr: float):
    """Create the torch optimizer selected by ``name`` (case-insensitive).

    Args:
        name: 'adam', 'rmsprop' or 'adagrad'.
        params: parameters handed to the optimizer constructor.
        lr: learning rate.

    Raises:
        ValueError: if ``name`` is not one of the supported optimizers.
    """
    optimizer_classes = {
        'adam': torch.optim.Adam,
        'rmsprop': torch.optim.RMSprop,
        'adagrad': torch.optim.Adagrad,
    }
    optimizer_cls = optimizer_classes.get(name.lower())
    if optimizer_cls is None:
        raise ValueError('optimizer not handled')
    return optimizer_cls(params, lr=lr)

# Keep a `device` already defined in this namespace; otherwise use CUDA when
# it is available and fall back to the CPU.
if 'device' not in globals():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = create_backbone(
    model_name,
    num_classes=len(classes),
    dropout=dropout,
    use_pretrained=use_pretrained,
)
model = model.to(device)
opt = make_optimizer(optimizer_name, model.parameters(), lr=lr)
loss_fn = nn.CrossEntropyLoss()


In [None]:
# Entraînement
from tqdm.auto import tqdm
# Per-epoch metrics; each key starts with its own empty list.
history = {metric: [] for metric in ('train_loss', 'val_acc')}

def evaluate(loader):
    """Run the notebook-level ``model`` over ``loader`` without gradients.

    The model is switched to eval mode and left in that mode.

    Returns:
        ``(labels, preds)``: two flat Python lists with the true class index
        and the argmax prediction of every sample, in loader order.
    """
    model.eval()
    y_true, y_pred = [], []
    with torch.no_grad():
        for batch_x, batch_y in loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            scores = model(batch_x)
            # Predicted class = index of the largest logit of each row.
            y_pred += scores.argmax(dim=1).cpu().tolist()
            y_true += batch_y.cpu().tolist()
    return y_true, y_pred

# Train for `epochs` passes, recording the mean training loss and the
# validation accuracy after each one.
for epoch in range(epochs):
    model.train()  # evaluate() leaves the model in eval mode
    running = 0.0
    for x, y in tqdm(train_loader):
        x, y = x.to(device), y.to(device)
        # Clear gradients BEFORE the backward pass so every step is
        # self-contained: gradients left by an interrupted earlier run of
        # this cell can no longer leak into the first update.
        opt.zero_grad()
        logits = model(x)
        loss = loss_fn(logits, y)
        loss.backward()
        opt.step()
        running += loss.item()
    labels, preds = evaluate(val_loader)
    val_acc = accuracy_score(labels, preds)
    # max(1, ...) avoids a division by zero when the train loader is empty.
    history['train_loss'].append(running / max(1, len(train_loader)))
    history['val_acc'].append(val_acc)
    print(f"Epoch {epoch+1}/{epochs} - loss={history['train_loss'][-1]:.4f} val_acc={val_acc:.3f}")


In [None]:
# Training curves + validation confusion matrix
fig, axes = plt.subplots(1, 2, figsize=(10, 4))
axes[0].plot(history['train_loss'], label='train loss')
axes[0].set_title('Loss')
axes[1].plot(history['val_acc'], label='val acc')
axes[1].set_title('Val acc')
for ax in axes: ax.legend(); ax.grid(True)
plt.tight_layout(); plt.show()

labels, preds = evaluate(val_loader)
# Pin the label set to every class index: without it the matrix only covers
# classes that occur in labels/preds, and would no longer match
# display_labels when a class is absent from the validation split.
cm = confusion_matrix(labels, preds, labels=list(range(len(classes))))
ConfusionMatrixDisplay(cm, display_labels=classes).plot(xticks_rotation=45)
plt.title('Val confusion')
plt.tight_layout(); plt.show()


In [None]:
# Validation + test reports
# Explicit class indices keep the reports and the matrix aligned with
# `classes` even when a class never occurs in a split; without them
# classification_report rejects target_names of a different length.
class_ids = list(range(len(classes)))

val_labels, val_preds = evaluate(val_loader)
print('VAL:', classification_report(val_labels, val_preds, labels=class_ids, target_names=classes, digits=3))

test_loader = build_loader('test', balance=False)
test_labels, test_preds = evaluate(test_loader)
print('TEST:', classification_report(test_labels, test_preds, labels=class_ids, target_names=classes, digits=3))
cm_test = confusion_matrix(test_labels, test_preds, labels=class_ids)
ConfusionMatrixDisplay(cm_test, display_labels=classes).plot(xticks_rotation=45)
plt.title('Test confusion')
plt.tight_layout(); plt.show()


Commentaire: sauvegarde des artefacts dans model/registry


In [None]:
# Save the run artifacts (weights, settings, reports, figures) in model/registry
from datetime import datetime, timezone
import json
# datetime.utcnow() is deprecated since Python 3.12; a timezone-aware UTC
# "now" formats to the same identifier.
run_id = datetime.now(timezone.utc).strftime('%Y%m%d_%H%M%S')
run_dir = Path('../model/registry') / f'image_{run_id}_notebook'
run_dir.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), run_dir / 'model.pt')
summary = {
    'classes': classes,
    'model_name': model_name,
    'optimizer': optimizer_name,
    'lr': lr,
    'dropout': dropout,
    'use_pretrained': use_pretrained,
    'image_size': image_size,
    'per_class_limit': per_class_limit,
    'noise_std': noise_std,
    'epochs': epochs,
    'val_samples': len(val_labels),
    'test_samples': len(test_labels),
}
(run_dir / 'summary.json').write_text(json.dumps(summary, indent=2), encoding='utf-8')
# Explicit class indices: the reports stay aligned with `classes` even when a
# class is missing from a split (target_names must match the label count).
report_labels = list(range(len(classes)))
(run_dir / 'val_report.txt').write_text(
    classification_report(val_labels, val_preds, labels=report_labels, target_names=classes, digits=3),
    encoding='utf-8',
)
(run_dir / 'test_report.txt').write_text(
    classification_report(test_labels, test_preds, labels=report_labels, target_names=classes, digits=3),
    encoding='utf-8',
)
# Training curves, re-drawn off-screen and closed right after saving.
fig, axes = plt.subplots(1, 2, figsize=(10, 4))
axes[0].plot(history['train_loss']); axes[0].set_title('Loss')
axes[1].plot(history['val_acc']); axes[1].set_title('Val acc')
for ax in axes: ax.grid(True)
fig.savefig(run_dir / 'curves.png', dpi=160, bbox_inches='tight'); plt.close(fig)
# Confusion matrices computed by the earlier cells (cm: validation, cm_test: test).
ConfusionMatrixDisplay(cm, display_labels=classes).plot(xticks_rotation=45)
plt.title('Val confusion'); plt.savefig(run_dir / 'val_confusion.png', dpi=160, bbox_inches='tight'); plt.close()
ConfusionMatrixDisplay(cm_test, display_labels=classes).plot(xticks_rotation=45)
plt.title('Test confusion'); plt.savefig(run_dir / 'test_confusion.png', dpi=160, bbox_inches='tight'); plt.close()
print('Artifacts saved to', run_dir)


In [None]:
# Class probabilities for a few test images of the first class
# Accept the same extensions as the data loaders (case-insensitive) and sort
# the listing so the same three images are picked on every run.
image_suffixes = {'.jpg', '.jpeg', '.png'}
sample_images = sorted(
    p for p in (split_dirs['test'] / classes[0]).glob('*')
    if p.suffix.lower() in image_suffixes
)[:3]
# Resize + ToTensor, built once instead of once per image.
preprocess = transforms.Compose([transforms.Resize((image_size, image_size)), transforms.ToTensor()])
softmax = nn.Softmax(dim=1)
model.eval()
for img_path in sample_images:
    # The context manager closes the file once the tensor has been built.
    with Image.open(img_path) as img:
        tensor = preprocess(img.convert('RGB')).unsqueeze(0).to(device)
    with torch.no_grad():
        probs = softmax(model(tensor)).squeeze(0)
    # k == number of classes: every class is listed, most probable first.
    top_probs, top_idx = torch.topk(probs, k=len(classes))
    print(img_path.name, [(classes[i], float(p)) for p, i in zip(top_probs.tolist(), top_idx.tolist())])


Commentaire: LIME


In [None]:
# LIME: superpixel explanation of the prediction for the first sample image
try:
    from lime import lime_image
    import numpy as np
    from skimage.segmentation import mark_boundaries
except ImportError:
    # Optional dependencies: the cell is skipped when they are missing.
    print('lime ou scikit-image non install?s; saute cette cellule')
else:
    explainer = lime_image.LimeImageExplainer()
    img_path = next(iter(sample_images), None)
    if img_path is not None:
        img = Image.open(img_path).convert('RGB').resize((image_size, image_size))
        img_np = np.array(img)

        def predict_fn(xs):
            """Return class probabilities for a batch of images.

            The batch is indexed as (N, H, W, C) and moved to channels-first;
            pixel values are divided by 255 - presumably 0-255 inputs, as
            produced from ``img_np`` (TODO confirm what LIME passes).
            """
            batch = torch.tensor(xs).permute(0, 3, 1, 2).float().div(255.0)
            batch = batch.to(device)
            with torch.no_grad():
                scores = torch.softmax(model(batch), dim=1)
            return scores.cpu().numpy()

        exp = explainer.explain_instance(
            img_np,
            predict_fn,
            top_labels=1,
            hide_color=0,
            num_samples=200,
        )
        # Keep the 5 superpixels that support the top predicted label.
        temp, mask = exp.get_image_and_mask(
            exp.top_labels[0],
            positive_only=True,
            num_features=5,
            hide_rest=False,
        )
        plt.figure(figsize=(4, 4))
        plt.imshow(mark_boundaries(temp, mask))
        plt.title(f'LIME: {img_path.name}')
        plt.axis('off')
        plt.show()
    else:
        print('Pas d\'image test dispo')
