In [None]:
import os
from pathlib import Path
from typing import Tuple

# pyplot (not the top-level matplotlib package) is what provides figure(), plot(), show(), ...
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from PIL import Image
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

AVANT TOUTE CHOSE, IL FAUT VOUS ASSURER D'AVOIR UN DATASET DANS LE FORMAT RENDU PAR Prepare_dataset.ipynb

In [None]:
data_dir = Path("../data/final_dataset")
data_dir_noaug = Path("../data/final_dataset_noaug2")

### UTILS METHODS

In [None]:
def get_label(filename: str):
    """Return everything before the first underscore of ``filename``.

    If ``filename`` contains no underscore the whole string is returned.
    """
    label, _, _ = filename.partition("_")
    return label


def get_uuid(filename: str):
    """Return the first two underscore-separated tokens of the file stem.

    Directories and the final extension are dropped first (``Path(...).stem``);
    a stem with a single token is returned unchanged.
    """
    leading_tokens = Path(filename).stem.split("_", 2)[:2]
    return "_".join(leading_tokens)


def build_augmented_path(img_path: Path, base_dir: Path):
    """Rebuild the on-disk location of an image from its file name alone.

    Only the final component of ``img_path`` is used. With the file name split
    on underscores, the returned path is::

        base_dir / "<tok0>_<tok1>" / "<tok1>_<tok2>" / <file name>

    Args:
        img_path: Path or string; any directory part is ignored.
        base_dir: Root directory of the dataset.

    Returns:
        The reconstructed ``Path``.
    """
    filename = Path(img_path).name
    uuid = get_uuid(filename)
    # Second and third tokens of the name, taken before the first dot.
    pres = "_".join(filename.split(".")[0].split("_")[1:3])
    return base_dir / uuid / pres / filename


def make_class_names(dataset):
    """Return the class names of ``dataset`` ordered by their integer index.

    Reads ``dataset.class_to_idx`` (a ``{label: index}`` mapping), inverts it
    and looks up indices ``0 .. n-1`` in order.
    """
    by_index = {}
    for name, index in dataset.class_to_idx.items():
        by_index[index] = name
    return list(map(by_index.__getitem__, range(len(by_index))))


def count_oov_pct(sequences, oov_id):
    """Percentage of tokens equal to ``oov_id`` among the tokens different from 0.

    ``sequences`` is iterated twice (once per count). Returns 0 when no token
    differs from 0.
    """
    n_tokens = len([tid for seq in sequences for tid in seq if tid != 0])
    n_oov = len([tid for seq in sequences for tid in seq if tid == oov_id])
    if n_tokens > 0:
        return n_oov / n_tokens * 100
    return 0

# Class names in index order; the label -> index mapping is derived from that order.
class_names = ["ball", "bike", "dog", "water"]

class_to_idx = {name: position for position, name in enumerate(class_names)}



### TRANSFORMS

In [None]:
from torchvision import transforms

# Preprocessing for the custom CNNs: resize to 300x500, convert to a tensor,
# then normalise each channel with mean 0.5 / std 0.5.
transform = transforms.Compose([
    transforms.Resize(size=(300, 500)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Preprocessing for the ResNet models: resize to 224x224, convert to a tensor,
# then normalise with the per-channel statistics below.
_RESNET_MEAN = [0.485, 0.456, 0.406]
_RESNET_STD = [0.229, 0.224, 0.225]

transform_resnet = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_RESNET_MEAN, std=_RESNET_STD),
])


### IMAGE DATASET MODEL

In [None]:
class ImageCLIPDataset(Dataset):
    """Dataset yielding ``(idx, image, label_index, caption)`` tuples.

    Image locations are rebuilt from each entry of ``imgs`` with
    ``build_augmented_path(img, base_dir)``.

    Args:
        imgs: Iterable of image paths / file names.
        labels: Iterable of string labels, aligned with ``imgs``.
        captions: Iterable of captions aligned with ``imgs``, or ``None``
            (the caption ``"-"`` is then returned for every sample).
        base_dir: Root directory of the dataset on disk.
        transform: Callable applied to the RGB PIL image, or ``None``.
        class_to_idx: Optional ``{label: index}`` mapping. When omitted the
            mapping is built from the sorted labels present in ``labels``,
            which means two splits containing different label sets would get
            different indices; pass a shared mapping to avoid that.
    """

    def __init__(self, imgs, labels, captions, base_dir: Path, transform, class_to_idx=None):
        self.img_paths = [Path(build_augmented_path(img, base_dir)) for img in imgs]
        self.labels = list(labels)
        # Keep None as-is so that __getitem__ can fall back to the "-" caption.
        self.captions = list(captions) if captions is not None else None
        self.transform = transform

        if class_to_idx is None:
            self.classes = sorted(set(self.labels))
            self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}
        else:
            self.class_to_idx = dict(class_to_idx)
            self.classes = sorted(self.class_to_idx, key=self.class_to_idx.get)

    def _load_rgb(self, idx: int):
        """Open image ``idx`` as RGB and close the underlying file right away."""
        with Image.open(self.img_paths[idx]) as raw:
            return raw.convert("RGB")

    def __getitem__(self, idx):
        img = self._load_rgb(idx)

        label = self.class_to_idx[self.labels[idx]]
        caption = self.captions[idx] if self.captions is not None else "-"

        if self.transform:
            img = self.transform(img)

        return idx, img, label, caption

    def __len__(self) -> int:
        return len(self.img_paths)

    def _get_img_path_from_idx(self, idx: int) -> Path:
        """Return the reconstructed path of sample ``idx``."""
        return self.img_paths[idx]

    def _get_caption_from_idx(self, idx: int) -> str:
        """Return the caption of sample ``idx`` (``"-"`` when no captions were given)."""
        if self.captions is None:
            return "-"
        return self.captions[idx]

    def _get_label_from_idx(self, idx: int) -> str:
        """Return the string label of sample ``idx``."""
        return self.labels[idx]

    def _get_img_size(self, idx: int) -> Tuple[int, int]:
        """Return ``(height, width)`` of sample ``idx`` after the Resize steps only.

        Only ``transforms.Resize`` entries found in ``self.transform.transforms``
        are applied; a transform without a ``transforms`` attribute leaves the
        image at its original size.
        """
        img = self._load_rgb(idx)
        for step in getattr(self.transform, "transforms", None) or []:
            if isinstance(step, transforms.Resize):
                img = step(img)
        return img.height, img.width

    
    
    

        


### DATASET LOADING

In [None]:
# Each dataset variant ships a metadata.csv at its root.
metadata_path = data_dir / "metadata.csv"
metadata_path_noaug = data_dir_noaug / "metadata.csv"

df, df_noaug = (pd.read_csv(csv_path) for csv_path in (metadata_path, metadata_path_noaug))

print(df.columns)
print(df.iloc[1])

X, X_noaug = df["image_path"], df_noaug["image_path"]
print(X)

In [None]:

# 70 % train, then the remaining 30 % is halved into test and validation.
# NOTE(review): the split is done row by row; if several rows are augmented
# variants of the same source image they may end up in different splits --
# confirm against Prepare_dataset.ipynb.
df_train, df_temp = train_test_split(df, test_size=0.3, random_state=11)
df_test, df_val = train_test_split(df_temp, test_size=0.5, random_state=11)

df_train_noaug, df_temp_noaug = train_test_split(df_noaug, test_size=0.3, random_state=11)
df_test_noaug, df_val_noaug = train_test_split(df_temp_noaug, test_size=0.5, random_state=11)

# Split sizes and per-class percentages for the augmented dataset.
print(len(df_train), len(df_test), len(df_val))
print(df_train["label"].value_counts(normalize=True) * 100)
print(df_val["label"].value_counts(normalize=True) * 100)
print(df_test["label"].value_counts(normalize=True) * 100)
print("-----\n")

# Same report for the non-augmented dataset (reads the *_noaug frames).
print(len(df_train_noaug), len(df_test_noaug), len(df_val_noaug))
print(df_train_noaug["label"].value_counts(normalize=True) * 100)
print(df_val_noaug["label"].value_counts(normalize=True) * 100)
print(df_test_noaug["label"].value_counts(normalize=True) * 100)



In [None]:
# Pull the (image path, label, caption) columns out of every split.
_SPLIT_COLUMNS = ("image_path", "label", "caption")

X_train, y_train, train_caption = (df_train[col] for col in _SPLIT_COLUMNS)
X_val, y_val, val_caption = (df_val[col] for col in _SPLIT_COLUMNS)
X_test, y_test, test_caption = (df_test[col] for col in _SPLIT_COLUMNS)

X_train_noaug, y_train_noaug, train_caption_noaug = (df_train_noaug[col] for col in _SPLIT_COLUMNS)
X_val_noaug, y_val_noaug, val_caption_noaug = (df_val_noaug[col] for col in _SPLIT_COLUMNS)
X_test_noaug, y_test_noaug, test_caption_noaug = (df_test_noaug[col] for col in _SPLIT_COLUMNS)


In [None]:
# Give every series a fresh 0..n-1 index (the old index is discarded).
X_train, y_train, train_caption = (
    series.reset_index(drop=True) for series in (X_train, y_train, train_caption)
)
X_val, y_val, val_caption = (
    series.reset_index(drop=True) for series in (X_val, y_val, val_caption)
)
X_test, y_test, test_caption = (
    series.reset_index(drop=True) for series in (X_test, y_test, test_caption)
)

X_train_noaug, y_train_noaug, train_caption_noaug = (
    series.reset_index(drop=True)
    for series in (X_train_noaug, y_train_noaug, train_caption_noaug)
)
X_val_noaug, y_val_noaug, val_caption_noaug = (
    series.reset_index(drop=True)
    for series in (X_val_noaug, y_val_noaug, val_caption_noaug)
)
X_test_noaug, y_test_noaug, test_caption_noaug = (
    series.reset_index(drop=True)
    for series in (X_test_noaug, y_test_noaug, test_caption_noaug)
)

In [None]:
# Augmented dataset, ResNet preprocessing.
train_dataset_resnet = ImageCLIPDataset(X_train, y_train, train_caption, data_dir, transform_resnet)
val_dataset_resnet   = ImageCLIPDataset(X_val, y_val, val_caption, data_dir, transform_resnet)
test_dataset_resnet  = ImageCLIPDataset(X_test, y_test, test_caption, data_dir, transform_resnet)

# Augmented dataset, custom-CNN preprocessing.
train_dataset_custom = ImageCLIPDataset(X_train, y_train, train_caption, data_dir, transform)
val_dataset_custom   = ImageCLIPDataset(X_val, y_val, val_caption, data_dir, transform)
test_dataset_custom  = ImageCLIPDataset(X_test, y_test, test_caption, data_dir, transform)

# Non-augmented dataset, ResNet preprocessing.
train_dataset_resnet_noaug = ImageCLIPDataset(X_train_noaug, y_train_noaug, train_caption_noaug, data_dir_noaug, transform_resnet)
val_dataset_resnet_noaug   = ImageCLIPDataset(X_val_noaug, y_val_noaug, val_caption_noaug, data_dir_noaug, transform_resnet)
test_dataset_resnet_noaug  = ImageCLIPDataset(X_test_noaug, y_test_noaug, test_caption_noaug, data_dir_noaug, transform_resnet)

# Non-augmented dataset, custom-CNN preprocessing. Every argument must come
# from the *_noaug split: paths, labels, captions and the root directory.
train_dataset_custom_noaug = ImageCLIPDataset(X_train_noaug, y_train_noaug, train_caption_noaug, data_dir_noaug, transform)
val_dataset_custom_noaug   = ImageCLIPDataset(X_val_noaug, y_val_noaug, val_caption_noaug, data_dir_noaug, transform)
test_dataset_custom_noaug  = ImageCLIPDataset(X_test_noaug, y_test_noaug, test_caption_noaug, data_dir_noaug, transform)



In [None]:
# Quick visual / numeric sanity check on a couple of training samples.
for sample_idx in (3, 13):
    print(train_dataset_resnet._get_img_size(sample_idx))

img = train_dataset_resnet._get_img_path_from_idx(3)
lg = Image.open(build_augmented_path(img, data_dir))
display(lg)
print(train_dataset_resnet[3])

In [None]:
# Training loaders reshuffle every epoch; validation / test loaders keep dataset order.
_train_loader_kwargs = dict(batch_size=32, shuffle=True)
_eval_loader_kwargs = dict(batch_size=32, shuffle=False)

train_loader_resnet = DataLoader(train_dataset_resnet, **_train_loader_kwargs)
val_loader_resnet   = DataLoader(val_dataset_resnet, **_eval_loader_kwargs)
test_loader_resnet  = DataLoader(test_dataset_resnet, **_eval_loader_kwargs)

train_loader_custom = DataLoader(train_dataset_custom, **_train_loader_kwargs)
val_loader_custom   = DataLoader(val_dataset_custom, **_eval_loader_kwargs)
test_loader_custom  = DataLoader(test_dataset_custom, **_eval_loader_kwargs)

train_loader_resnet_noaug = DataLoader(train_dataset_resnet_noaug, **_train_loader_kwargs)
val_loader_resnet_noaug   = DataLoader(val_dataset_resnet_noaug, **_eval_loader_kwargs)
test_loader_resnet_noaug  = DataLoader(test_dataset_resnet_noaug, **_eval_loader_kwargs)

train_loader_custom_noaug = DataLoader(train_dataset_custom_noaug, **_train_loader_kwargs)
val_loader_custom_noaug   = DataLoader(val_dataset_custom_noaug, **_eval_loader_kwargs)
test_loader_custom_noaug  = DataLoader(test_dataset_custom_noaug, **_eval_loader_kwargs)




### CNN MODEL

In [None]:
class CNNBasic(nn.Module):
    """Three Conv-BatchNorm-ReLU-MaxPool stages followed by a two-layer classifier.

    Args:
        num_classes: Size of the output logits vector.
        input_size: ``(height, width)`` of the images the model will receive.
            It is used once, at construction, to size the first linear layer.
            Defaults to ``(300, 500)``.
    """

    def __init__(self, num_classes=4, input_size=(300, 500)):
        super().__init__()
        self.input_size = tuple(input_size)

        self.features = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1), nn.BatchNorm2d(16), nn.ReLU(), nn.MaxPool2d(2,2),
            nn.Conv2d(16, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(2,2),
            nn.Conv2d(32, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(2,2)
        )

        self.flattened_size = self._get_flattened_size()

        self.classifier = nn.Sequential(
            nn.Linear(self.flattened_size, 256),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(256, num_classes)
        )

    def _get_flattened_size(self):
        """Return the number of features produced by ``self.features`` for one image.

        The probe runs with the feature extractor in eval mode so that the
        BatchNorm running statistics are not updated by the random dummy input;
        the previous train/eval state is restored afterwards.
        """
        was_training = self.features.training
        self.features.eval()
        try:
            with torch.no_grad():
                x = torch.randn(1, 3, *self.input_size)
                x = self.features(x)
                return x.flatten(1).shape[1]
        finally:
            self.features.train(was_training)

    def forward(self, x):
        """Return logits of shape ``[batch, num_classes]`` for images of ``input_size``."""
        x = self.features(x)
        x = x.flatten(1)
        return self.classifier(x)


In [None]:
class CNNBasicV2(nn.Module):
    """Four conv stages, channel dropout, global average pooling and a small MLP head.

    The first three stages end with a 2x2 max-pool (for a 300x500 input the
    feature map goes 150x250 -> 75x125 -> 37x62); the last stage has no pooling
    so some spatial resolution is kept before the global average pool.
    """

    # (in_channels, out_channels, add max-pool?) for each stage, in order.
    _STAGES = ((3, 32, True), (32, 64, True), (64, 128, True), (128, 256, False))

    @staticmethod
    def _conv_block(in_ch, out_ch, pool=True):
        """Conv3x3 (no bias) -> BatchNorm -> ReLU, optionally followed by MaxPool2d(2)."""
        stage = [
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]
        if pool:
            stage += [nn.MaxPool2d(2)]
        return nn.Sequential(*stage)

    def __init__(self, num_classes=4):
        super().__init__()

        self.features = nn.Sequential(
            *[self._conv_block(c_in, c_out, pool=use_pool) for c_in, c_out, use_pool in self._STAGES]
        )

        # Dropout applied on whole feature maps rather than on the dense layers.
        self.dropout = nn.Dropout2d(p=0.2)

        # Global average pooling -> [B, 256, 1, 1]
        self.gap = nn.AdaptiveAvgPool2d((1, 1))

        self.classifier = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.3),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        """Return logits of shape ``[B, num_classes]``."""
        pooled = self.gap(self.dropout(self.features(x)))
        return self.classifier(torch.flatten(pooled, 1))  # flatten -> [B, 256]


### RESNET MODEL

In [None]:
from torchvision.models import resnet18, ResNet18_Weights
class ResNet18EarlyExit(nn.Module):
    """ImageNet-pretrained ResNet-18 with one classification head after each stage.

    ``forward`` evaluates the heads in order and stops at the first one judged
    confident by ``_confident``.

    Args:
        num_classes: Size of every head's output.
        threshold: Softmax confidence above which an exit is taken.
    """

    def __init__(self, num_classes=4, threshold=0.9):
        super().__init__()
        self.threshold = threshold

        base = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)

        self.stem = nn.Sequential(
            base.conv1, base.bn1, base.relu, base.maxpool
        )
        self.layer1 = base.layer1
        self.layer2 = base.layer2
        self.layer3 = base.layer3
        self.layer4 = base.layer4

        self.exit1 = self._make_exit(64, num_classes)
        self.exit2 = self._make_exit(128, num_classes)
        self.exit3 = self._make_exit(256, num_classes)
        self.exit4 = nn.Linear(base.fc.in_features, num_classes)

    def _make_exit(self, channels, num_classes):
        """Build an exit head: global average pool -> flatten -> linear."""
        return nn.Sequential(
            nn.AdaptiveAvgPool2d((1,1)),
            nn.Flatten(),
            nn.Linear(channels, num_classes)
        )

    @staticmethod
    def _pool(x):
        """Global-average-pool a ``[B, C, H, W]`` map into ``[B, C]``."""
        return F.adaptive_avg_pool2d(x, (1, 1)).flatten(1)

    def extract_features(self, x):
        """Return the pooled ``[B, C]`` features after each stage.

        Keys are ``"exit1"`` .. ``"exit4"`` (outputs of layer1 .. layer4).
        """
        features = {}

        x = self.stem(x)
        stages = (self.layer1, self.layer2, self.layer3, self.layer4)
        for position, stage in enumerate(stages, start=1):
            x = stage(x)
            features[f"exit{position}"] = self._pool(x)

        return features

    def forward(self, x):
        """Return ``(logits, exit_index)`` with ``exit_index`` in 1..4.

        Every exit now returns the same tuple shape; previously only the last
        exit returned a tuple while early exits returned bare logits, so callers
        could not handle the output uniformly.
        """
        x = self.stem(x)

        x = self.layer1(x)
        out1 = self.exit1(x)
        if self._confident(out1):
            return out1, 1

        x = self.layer2(x)
        out2 = self.exit2(x)
        if self._confident(out2):
            return out2, 2

        x = self.layer3(x)
        out3 = self.exit3(x)
        if self._confident(out3):
            return out3, 3

        x = self.layer4(x)
        out4 = self.exit4(self._pool(x))
        return out4, 4

    def _confident(self, logits):
        """True when at least one sample's top softmax probability exceeds the threshold.

        NOTE(review): ``.any()`` makes the whole batch exit as soon as a single
        sample is confident, and the check is also active in training mode (so
        deeper stages receive no gradient for that batch). Confirm whether
        ``.all()`` or per-sample routing was intended.
        """
        probs = logits.softmax(dim=1)
        max_conf = probs.max(dim=1).values
        return (max_conf > self.threshold).any()


### GRAD CAM MODEL

In [None]:
class GradCAM:
    """
    Grad-CAM for a given model + target layer.
    Works for CNNs that output logits [B, num_classes].

    Hooks are registered on ``target_layer`` at construction. Call ``remove()``
    when done, or use the instance as a context manager so the hooks are
    removed even if an exception is raised::

        with GradCAM(model, layer) as gc:
            cam, pred, probs = gc(x)
    """
    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer

        self.activations = None
        self.gradients = None

        self._fwd_handle = target_layer.register_forward_hook(self._forward_hook)
        self._bwd_handle = target_layer.register_full_backward_hook(self._backward_hook)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.remove()
        return False

    def _forward_hook(self, module, inp, out):
        self.activations = out.detach()

    def _backward_hook(self, module, grad_in, grad_out):
        # grad_out is a tuple; grad_out[0] shape == activations shape
        self.gradients = grad_out[0].detach()

    def remove(self):
        """Detach both hooks from the target layer."""
        self._fwd_handle.remove()
        self._bwd_handle.remove()

    def __call__(self, x, class_idx=None):
        """
        x: Tensor [1, C, H, W]
        class_idx: int or None -> if None, uses predicted class
        returns: cam (H', W') in [0,1], pred_class, probs

        Raises RuntimeError if the hooks did not capture activations and
        gradients (e.g. the target layer is not used by the forward pass).
        """
        # Drop anything captured by a previous call so stale tensors are never reused.
        self.activations = None
        self.gradients = None

        # Gradients are required even when the caller is inside torch.no_grad().
        with torch.enable_grad():
            self.model.zero_grad(set_to_none=True)
            logits = self.model(x)  # [1, num_classes]
            probs = torch.softmax(logits, dim=1)

            pred_class = int(probs.argmax(dim=1).item())
            target_class = pred_class if class_idx is None else int(class_idx)

            score = logits[0, target_class]
            score.backward(retain_graph=False)

        if self.activations is None or self.gradients is None:
            raise RuntimeError(
                "GradCAM hooks captured no activations/gradients for the target layer"
            )

        # activations: [1, K, H', W'], gradients: [1, K, H', W']
        grads = self.gradients[0]      # [K, H', W']
        acts = self.activations[0]     # [K, H', W']

        # Global-average-pool gradients over spatial dims -> weights [K]
        weights = grads.mean(dim=(1, 2))  # [K]

        # Weighted sum of activations
        cam = (weights[:, None, None] * acts).sum(dim=0)  # [H', W']
        cam = F.relu(cam)

        # Normalize to [0,1]
        cam = cam - cam.min()
        cam = cam / (cam.max() + 1e-8)

        return cam.cpu().numpy(), pred_class, probs.detach().cpu().numpy()[0]

def get_last_conv_layer(model):
    """Return the last ``nn.Conv2d`` found while walking ``model.modules()``.

    For ``CNNBasic`` this is ``model.features[8]``; unlike a fixed index it
    also works for models with a different layout.

    Raises:
        ValueError: if the model contains no ``nn.Conv2d``.
    """
    last_conv = None
    for module in model.modules():
        if isinstance(module, nn.Conv2d):
            last_conv = module
    if last_conv is None:
        raise ValueError("model contains no nn.Conv2d layer")
    return last_conv


### TRAINING METHODS

In [None]:
def ensure_dir(path: Path) -> None:
    """Create ``path`` and any missing parents; do nothing if it is already a directory."""
    if not path.is_dir():
        path.mkdir(parents=True, exist_ok=True)

def make_class_names(dataset):
    """List the labels of ``dataset.class_to_idx`` in increasing index order (0 .. n-1)."""
    lookup = dict((idx, label) for label, idx in dataset.class_to_idx.items())
    names = []
    for position in range(len(lookup)):
        names.append(lookup[position])
    return names


def to_device(x, device):
    """Return ``x.to(device)``, requesting a non-blocking transfer."""
    moved = x.to(device, non_blocking=True)
    return moved


def train_one_epoch(model, loader, optimizer, criterion, device, epoch, epochs):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Args:
        model: Module to train; it may return logits or a tuple whose first
            element is the logits (as handled in ``train_one_experiment``).
        loader: Iterable of ``(idx, inputs, labels, caption)`` batches.
        optimizer, criterion: Optimiser and loss used for the update.
        device: Device the inputs and labels are moved to.
        epoch, epochs: Zero-based epoch index and total, used for the progress bar.

    Returns:
        Sum of batch losses divided by the number of batches (0.0 for an empty loader).
    """
    model.train()
    running_loss = 0.0

    loop = tqdm(
        enumerate(loader),
        total=len(loader),
        desc=f"Epoch {epoch+1}/{epochs} [TRAIN]"
    )

    for i, (_, inputs, labels, _) in loop:
        inputs = to_device(inputs, device)
        labels = to_device(labels, device)

        optimizer.zero_grad(set_to_none=True)
        outputs = model(inputs)
        # Early-exit models return (logits, exit_index); plain models return logits.
        logits = outputs[0] if isinstance(outputs, tuple) else outputs
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        loop.set_postfix(loss=running_loss / (i + 1))

    # max(..., 1) avoids a ZeroDivisionError when the loader yields no batch.
    avg_loss = running_loss / max(len(loader), 1)
    return avg_loss


def _show_mistake(dataset, class_names, sample_idx, true_i, pred_i):
    """Display the image of one misclassified sample and print its details."""
    img_path = dataset._get_img_path_from_idx(sample_idx)
    # The context manager closes the file once the RGB copy has been displayed.
    with Image.open(img_path) as raw:
        display(raw.convert('RGB'))
    caption = dataset._get_caption_from_idx(sample_idx)

    print("\n--- Mauvaise prédiction ---")
    print(f"dataset_idx : {sample_idx}")
    print(f"image      : {img_path}")
    print(f"vrai label : {class_names[true_i]} ({true_i})")
    print(f"prédit     : {class_names[pred_i]} ({pred_i})")
    print(f"caption    : {caption}")


@torch.no_grad()
def evaluate(
    model,
    loader,
    criterion,
    device,
    epoch,
    epochs,
    num_classes,
    dataset=None,
    class_names=None,
    caption_fn=None,
    n_mistakes_to_print=5
):
    """Evaluate ``model`` on ``loader`` and optionally display a few mistakes.

    Args:
        model: Module returning logits, or a tuple whose first element is the logits.
        loader: Iterable of ``(idxs, inputs, labels, captions)`` batches.
        criterion: Loss applied to ``(logits, labels)``.
        device: Device the inputs and labels are moved to.
        epoch, epochs: Only used for the progress-bar description.
        num_classes: Accepted for interface compatibility; not used here.
        dataset: Dataset exposing ``_get_img_path_from_idx`` and
            ``_get_caption_from_idx``. Mistakes are shown only when both
            ``dataset`` and ``class_names`` are given.
        class_names: Sequence mapping a class index to its display name.
        caption_fn: Accepted for interface compatibility; not used here.
        n_mistakes_to_print: Maximum number of misclassified samples displayed.

    Returns:
        ``(avg_loss, labels, preds, probs, accuracy)`` where the three middle
        items are NumPy arrays and ``avg_loss`` is the mean batch loss.
    """
    model.eval()
    running_loss = 0.0

    all_labels = []
    all_preds = []
    all_probs = []

    mistakes_printed = 0
    can_print_mistakes = (dataset is not None) and (class_names is not None)

    loop = tqdm(loader, desc=f"Epoch {epoch+1}/{epochs} [EVAL]")

    for idxs, inputs, labels, captions in loop:
        inputs = inputs.to(device)
        labels = labels.to(device)

        outputs = model(inputs)
        # Early-exit models return (logits, exit_index); plain models return logits.
        logits = outputs[0] if isinstance(outputs, tuple) else outputs
        loss = criterion(logits, labels)
        running_loss += loss.item()

        probs = torch.softmax(logits, dim=1)
        preds = probs.argmax(dim=1)

        # Accumulate on the CPU so device memory does not grow with the dataset size.
        all_labels.append(labels.cpu())
        all_preds.append(preds.cpu())
        all_probs.append(probs.cpu())

        if can_print_mistakes and mistakes_printed < n_mistakes_to_print:
            wrong_positions = torch.where((preds != labels).cpu())[0].tolist()
            remaining = n_mistakes_to_print - mistakes_printed
            for pos in wrong_positions[:remaining]:
                _show_mistake(
                    dataset,
                    class_names,
                    sample_idx=int(idxs[pos].item()),  # index in the dataset
                    true_i=int(labels[pos].item()),
                    pred_i=int(preds[pos].item()),
                )
                mistakes_printed += 1

    avg_loss = running_loss / len(loader)

    all_labels = torch.cat(all_labels)
    all_preds = torch.cat(all_preds)
    all_probs = torch.cat(all_probs)

    accuracy = (all_preds == all_labels).float().mean().item()

    return (
        avg_loss,
        all_labels.numpy(),
        all_preds.numpy(),
        all_probs.numpy(),
        accuracy
    )

@torch.no_grad()
def evaluate_while_training(
    model,
    loader,
    criterion,
    device,
    epoch,
    epochs,
):
    """Silent evaluation used inside the training loop.

    Delegates to ``evaluate`` without a dataset or class names, so no
    misclassified sample is displayed; this replaces a copy of the same loop.

    Returns:
        ``(avg_loss, labels, preds, probs, accuracy)`` exactly as ``evaluate``.
    """
    # num_classes is a required argument of evaluate() but is not used by it.
    return evaluate(model, loader, criterion, device, epoch, epochs, num_classes=None)


In [None]:
def fit(
    model,
    train_loader,
    test_loader,
    num_classes=4,
    epochs=50,
    lr=1e-3,
    momentum=0.9,
    patience=5,
    models_dir=Path("../models"),
    model_saving_name="best_model.pth",
    criterion=None,
    class_names=None,
):
    """Train ``model`` with SGD, early stopping and best-checkpoint restore.

    After every epoch the model is evaluated on ``test_loader``; the state dict
    with the lowest loss is saved to ``models_dir / model_saving_name`` and
    training stops after ``patience`` epochs without improvement. The best
    checkpoint is then reloaded, evaluated once more, and the loss curves,
    confusion matrix and ROC curves are plotted.

    NOTE(review): ``test_loader`` drives both early stopping and the final
    report, so the reported metrics are not from held-out data unless a
    separate loader is evaluated afterwards.

    Args:
        model: Module to train.
        train_loader, test_loader: Loaders of ``(idx, inputs, labels, caption)`` batches.
        num_classes: Number of classes (used for the report and ROC curves).
        epochs, lr, momentum, patience: Training hyper-parameters.
        models_dir, model_saving_name: Where the best checkpoint is written.
        criterion: Loss; defaults to ``nn.CrossEntropyLoss()``.
        class_names: Optional display names, indexed by class id.

    Returns:
        ``(model, history)`` where ``history`` holds ``train_losses``,
        ``test_losses``, ``best_test_loss`` and ``test_accuracies``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    if criterion is None:
        criterion = nn.CrossEntropyLoss()

    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

    models_dir = Path(models_dir)
    ensure_dir(models_dir)
    best_model_path = models_dir / model_saving_name

    # Fix the report rows to 0..num_classes-1 so a class missing from one
    # epoch's predictions does not shift or drop rows.
    report_labels = list(range(num_classes))
    report_names = (
        class_names
        if class_names is not None and len(class_names) == num_classes
        else None
    )

    best_test_loss = float("inf")
    patience_counter = 0
    checkpoint_saved = False  # True once this run has written best_model_path

    train_losses = []
    test_losses = []
    test_accuracies = []

    for epoch in range(epochs):
        train_loss = train_one_epoch(
            model, train_loader, optimizer, criterion, device, epoch, epochs
        )
        train_losses.append(train_loss)
        print(f"\nEpoch {epoch+1} - Average TRAIN loss: {train_loss:.4f}")

        test_loss, y_true, y_pred, y_prob, acc = evaluate_while_training(
            model, test_loader, criterion, device, epoch, epochs
        )
        test_losses.append(test_loss)
        test_accuracies.append(acc)

        print(f"Epoch {epoch+1} - Average TEST loss: {test_loss:.4f}")
        print(f"Epoch {epoch+1} - TEST accuracy: {acc:.4f}")
        print("\nClassification Report:")
        print(classification_report(
            y_true, y_pred,
            labels=report_labels,
            target_names=report_names,
            digits=3,
            zero_division=0,
        ))

        # Early stopping + save best
        if test_loss < best_test_loss:
            best_test_loss = test_loss
            patience_counter = 0
            torch.save(model.state_dict(), best_model_path)
            checkpoint_saved = True
            print(f"Nouveau meilleur modèle sauvegardé (test loss: {best_test_loss:.4f})")
        else:
            patience_counter += 1
            print(f"Pas d'amélioration ({patience_counter}/{patience})")
            if patience_counter >= patience:
                print(f"\nEarly stopping déclenché après {epoch+1} époques")
                print(f"Meilleur test loss: {best_test_loss:.4f}")
                break

    if checkpoint_saved:
        print(f"\nChargement du meilleur modèle (test loss: {best_test_loss:.4f})")
        model.load_state_dict(torch.load(best_model_path, map_location=device))
    else:
        # Happens with epochs <= 0 or when the loss never compares below inf
        # (e.g. NaN). Loading here would pick up a file from a previous run.
        print("\nNo checkpoint was saved during this run; keeping the current weights")

    final_test_loss, y_true, y_pred, y_prob, acc = evaluate(
        model, test_loader, criterion, device, epoch=0, epochs=1, num_classes=num_classes
    )
    print(f"Best model - TEST loss: {final_test_loss:.4f}")
    print(f"Best model - TEST accuracy: {acc:.4f}")

    plot_losses(train_losses, test_losses)
    plot_confusion_matrix(y_true, y_pred, class_names=class_names, title="Confusion Matrix (Best Model)")
    plot_multiclass_roc(y_true, y_prob, num_classes=num_classes, class_names=class_names, title="ROC (Best Model, OvR)")
    return model, {
        "train_losses": train_losses,
        "test_losses": test_losses,
        "best_test_loss": best_test_loss,
        "test_accuracies": test_accuracies,
    }


### RESNET TRAINING METHODS

In [None]:
def configure_finetuning(
    model,
    mode: str,
    lr_backbone=1e-4,
    lr_head=1e-3,
    weight_decay=1e-5,
):
    """
    Set ``requires_grad`` on a ResNet-like model and build the matching Adam optimizer.

    The model must expose ``fc`` and ``layer4`` sub-modules.

    mode:
        - "freeze_all"      : only ``fc`` is trained (lr_head)
        - "freeze_until_l3" : ``layer4`` (lr_backbone) and ``fc`` (lr_head) are trained
        - "freeze_semantic" : everything except ``layer4`` is trained;
                              ``fc.*`` uses lr_head, the rest lr_backbone
        - "unfreeze_all"    : every parameter is trained in one group at lr_backbone
                              (NOTE(review): lr_head is not used in this mode -- confirm)

    All parameters are frozen first, so an unknown mode raises ValueError with
    the model left fully frozen.
    """

    def set_trainable(parameters, flag):
        for param in parameters:
            param.requires_grad = flag

    set_trainable(model.parameters(), False)

    if mode == "freeze_all":
        set_trainable(model.fc.parameters(), True)
        params = [{"params": model.fc.parameters(), "lr": lr_head}]

    elif mode == "freeze_until_l3":
        # layer4 and fc are the only trainable parts
        set_trainable(model.layer4.parameters(), True)
        set_trainable(model.fc.parameters(), True)
        params = [
            {"params": model.layer4.parameters(), "lr": lr_backbone},
            {"params": model.fc.parameters(),     "lr": lr_head},
        ]

    elif mode == "freeze_semantic":
        set_trainable(model.parameters(), True)
        set_trainable(model.layer4.parameters(), False)

        trainable = [(name, p) for name, p in model.named_parameters() if p.requires_grad]
        head_params = [p for name, p in trainable if name.startswith("fc.")]
        backbone_params = [p for name, p in trainable if not name.startswith("fc.")]
        params = [
            {"params": backbone_params, "lr": lr_backbone},
            {"params": head_params,     "lr": lr_head},
        ]

    elif mode == "unfreeze_all":
        set_trainable(model.parameters(), True)
        params = [{"params": model.parameters(), "lr": lr_backbone}]

    else:
        raise ValueError(f"Mode inconnu : {mode}")

    return optim.Adam(params, weight_decay=weight_decay)



In [None]:
def train_one_experiment(
    model,
    train_loader,
    optimizer,
    criterion,
    device,
    epochs=5
):
    """Train ``model`` for ``epochs`` passes over ``train_loader``.

    Batches are ``(idx, inputs, labels, caption)``. When the model returns a
    tuple, its first element is used as the logits. The mean batch loss is
    printed after every epoch; nothing is returned.
    """
    model.to(device)

    for epoch_no in range(1, epochs + 1):
        model.train()
        batch_losses = []

        progress = tqdm(train_loader, desc=f"TRAIN epoch {epoch_no}")
        for _, inputs, labels, _ in progress:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            if isinstance(outputs, tuple):
                outputs = outputs[0]

            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_losses.append(loss.item())

        mean_loss = sum(batch_losses, 0.0) / len(train_loader)
        print(
            f"Epoch {epoch_no}: "
            f"Train loss = {mean_loss:.4f}"
        )


In [None]:
@torch.no_grad()
def extract_resnet_layer_embeddings(model, loader, device, max_batches=None):
    """
    Collect globally average-pooled features after each of the four stages.

    ``model`` must expose ``conv1``, ``bn1``, ``relu``, ``maxpool`` and
    ``layer1`` .. ``layer4``; ``loader`` yields ``(idx, x, y, caption)`` batches.
    At most ``max_batches`` batches are read when it is not None.

    Returns:
      feats: dict { "layer1": [N, C1], "layer2": [N, C2], "layer3": [N, C3], "layer4": [N, C4] }
      labels: [N]
    (both as NumPy arrays)
    """
    model.eval()
    layer_names = ("layer1", "layer2", "layer3", "layer4")
    pooled = {name: [] for name in layer_names}
    targets = []

    for batch_idx, (_, images, target, _) in enumerate(loader):
        if max_batches is not None and batch_idx >= max_batches:
            break

        target = target.to(device)

        # Stem: conv -> bn -> relu -> maxpool
        hidden = model.maxpool(model.relu(model.bn1(model.conv1(images.to(device)))))

        # Run the stages in order, pooling each output to [B, C].
        for name in layer_names:
            hidden = getattr(model, name)(hidden)
            pooled[name].append(F.adaptive_avg_pool2d(hidden, (1, 1)).flatten(1).cpu())

        targets.append(target.cpu())

    feats = {name: torch.cat(chunks, dim=0).numpy() for name, chunks in pooled.items()}
    all_labels = torch.cat(targets, dim=0).numpy()

    return feats, all_labels


### PLOT METHODS

In [None]:
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    ConfusionMatrixDisplay,
    roc_curve,
    auc
)
from sklearn.preprocessing import label_binarize
def plot_losses(train_losses, test_losses):
    """Plot the per-epoch training and test losses on a single figure.

    Args:
        train_losses: Sequence of training losses, one per epoch.
        test_losses: Sequence of test losses, one per epoch.
    """
    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(train_losses, label="Train loss")
    ax.plot(test_losses, label="Test loss")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.set_title("Training vs Test Loss")
    ax.grid(True, alpha=0.3)
    ax.legend()
    fig.tight_layout()
    plt.show()


def plot_confusion_matrix(y_true, y_pred, class_names=None, title="Confusion Matrix"):
    """Plot the confusion matrix of ``y_pred`` against ``y_true``.

    Args:
        y_true, y_pred: Integer class indices.
        class_names: Optional display names indexed by class id. When given,
            the matrix always has ``len(class_names)`` rows/columns, even if a
            class is absent from both arrays, so the tick labels always match.
        title: Figure title.
    """
    labels = list(range(len(class_names))) if class_names is not None else None
    cm = confusion_matrix(y_true, y_pred, labels=labels)
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)
    disp.plot(cmap="Blues", values_format="d", xticks_rotation=45)
    disp.ax_.set_title(title)
    disp.figure_.tight_layout()
    plt.show()


def plot_multiclass_roc(y_true, y_prob, num_classes, class_names=None, title="ROC (One-vs-Rest)"):
    """
    Plot one-vs-rest ROC curves, one per class.

    y_true: [N] int labels
    y_prob: [N, C] probabilities

    Classes for which ``y_true`` holds only positives or only negatives have no
    defined ROC curve; they are skipped with a printed message.
    """
    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob)
    # One-hot encode by comparison: always [N, num_classes], including for 2 classes.
    y_true_bin = (y_true[:, None] == np.arange(num_classes)[None, :]).astype(int)

    fig, ax = plt.subplots(figsize=(8, 6))
    for c in range(num_classes):
        name = class_names[c] if class_names is not None else f"Class {c}"
        n_positive = int(y_true_bin[:, c].sum())
        if n_positive == 0 or n_positive == len(y_true):
            print(f"ROC skipped for {name}: needs both positive and negative samples")
            continue
        fpr, tpr, _ = roc_curve(y_true_bin[:, c], y_prob[:, c])
        roc_auc = auc(fpr, tpr)
        ax.plot(fpr, tpr, label=f"{name} (AUC={roc_auc:.3f})")

    # Diagonal = chance level
    ax.plot([0, 1], [0, 1], linestyle="--")
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.set_title(title)
    ax.grid(True, alpha=0.3)
    ax.legend()
    fig.tight_layout()
    plt.show()

In [None]:
from sklearn.manifold import TSNE

def plot_tsne(model, loader, class_names, max_samples=500):
    """Plot a 2-D t-SNE of the features returned by ``model.extract_features``.

    One figure is drawn per key ``"exit1"`` .. ``"exit4"``.

    Args:
        model: Module exposing ``extract_features(imgs) -> dict`` with those keys.
        loader: Iterable of ``(idx, imgs, y, caption)`` batches.
        class_names: Display names indexed by class id.
        max_samples: Batches are read until more than this many samples were seen.
    """
    # Run on whatever device the model already lives on, so inputs and
    # weights can never end up on different devices.
    try:
        device = next(model.parameters()).device
    except StopIteration:
        device = torch.device("cpu")

    model.eval()
    feats = { "exit1": [], "exit2": [], "exit3": [], "exit4": [] }
    labels = []
    n_seen = 0

    with torch.no_grad():
        for idx, imgs, y, caption in loader:
            features = model.extract_features(imgs.to(device))
            for k in feats:
                feats[k].append(features[k].cpu().numpy())

            labels.append(y.numpy())
            n_seen += imgs.size(0)
            if n_seen > max_samples:
                break

    Y = np.concatenate(labels, axis=0)
    # t-SNE needs perplexity < number of samples (scikit-learn default is 30).
    perplexity = min(30, max(len(Y) - 1, 1))

    for k in feats:
        X = np.concatenate(feats[k], axis=0)

        X_2d = TSNE(
            n_components=2, learning_rate="auto", init="pca", perplexity=perplexity
        ).fit_transform(X)

        fig, ax = plt.subplots(figsize=(6, 5))
        for i, class_name in enumerate(class_names):
            pts = X_2d[Y == i]
            ax.scatter(pts[:, 0], pts[:, 1], s=12, label=class_name)

        ax.set_title(f"t-SNE — {k}")
        ax.legend()
        plt.show()

In [None]:
def denorm_05(t):
    """Undo a mean-0.5 / std-0.5 normalisation and clamp the result to [0, 1]."""
    # [3,H,W] in [-1,1] -> [0,1]
    rescaled = t * 0.5 + 0.5
    return torch.clamp(rescaled, 0, 1)

def show_cam_overlay(img_tensor, cam, title="Grad-CAM"):
    """
    Show the image next to the same image with the CAM heat-map blended on top.

    img_tensor: [3,H,W] (normalized), cam: [H',W'] from GradCAM
    """
    rgb = denorm_05(img_tensor).permute(1, 2, 0).cpu().numpy()  # [H,W,3]

    # Upsample the CAM to the image resolution ([1,1,H',W'] -> [H,W]).
    heatmap = F.interpolate(
        torch.tensor(cam).unsqueeze(0).unsqueeze(0),
        size=rgb.shape[:2],
        mode="bilinear",
        align_corners=False,
    )[0, 0].cpu().numpy()

    plt.figure(figsize=(10, 4))
    panels = [("Image", None), (title, heatmap)]
    for position, (panel_title, overlay) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.imshow(rgb)
        if overlay is not None:
            plt.imshow(overlay, alpha=0.45)
        plt.axis("off")
        plt.title(panel_title)

    plt.tight_layout()
    plt.show()


In [None]:
def plot_tsne_one(feat_2d, labels, class_names, title):
    """Scatter-plot a 2-D embedding, one colour (and legend entry) per class.

    Args:
        feat_2d: array indexed as ``feat_2d[mask, column]``; columns 0 and 1
            hold the coordinates.
        labels: per-sample class indices, compared element-wise with each
            class index to build the selection mask.
        class_names: legend labels; position ``i`` names class ``i``.
        title: figure title.
    """
    plt.figure(figsize=(6, 5))

    for class_idx, class_label in enumerate(class_names):
        is_member = labels == class_idx
        xs = feat_2d[is_member, 0]
        ys = feat_2d[is_member, 1]
        plt.scatter(xs, ys, s=10, label=class_label, alpha=0.7)

    plt.title(title)
    plt.legend(markerscale=2, fontsize=8)
    plt.tight_layout()
    plt.show()

def plot_tsne_layers(model, loader, device, class_names, mode_name="", max_batches=None,
                     perplexity=30, n_iter=1000, random_state=0):
    """Plot a t-SNE projection of the embeddings of ResNet ``layer1`` .. ``layer4``.

    Args:
        model, loader, device, max_batches: forwarded unchanged to
            ``extract_resnet_layer_embeddings``.
        class_names: legend labels; position ``i`` names class ``i``.
        mode_name: free text inserted in every figure title.
        perplexity: requested t-SNE perplexity; lowered automatically when
            fewer samples are available.
        n_iter: number of t-SNE optimisation iterations.
        random_state: seed handed to t-SNE.

    Raises:
        ValueError: if a layer provides fewer than two samples.
    """
    import inspect  # local import: only needed to probe the TSNE signature

    feats, labels = extract_resnet_layer_embeddings(model, loader, device, max_batches=max_batches)

    # The iteration-count argument is called ``n_iter`` in older scikit-learn
    # releases and ``max_iter`` in newer ones; use whichever this version accepts.
    tsne_params = inspect.signature(TSNE).parameters
    iter_kwarg = "max_iter" if "max_iter" in tsne_params else "n_iter"

    for layer_name in ["layer1", "layer2", "layer3", "layer4"]:
        X = feats[layer_name]

        n_samples = len(X)
        if n_samples < 2:
            raise ValueError(f"t-SNE needs at least two samples for {layer_name}, got {n_samples}")

        # t-SNE requires perplexity < n_samples (matters when max_batches is small).
        layer_perplexity = min(perplexity, n_samples - 1)

        # t-SNE
        tsne = TSNE(
            n_components=2,
            perplexity=layer_perplexity,
            init="pca",
            learning_rate="auto",
            random_state=random_state,
            **{iter_kwarg: n_iter}
        )
        X2 = tsne.fit_transform(X)

        plot_tsne_one(
            X2, labels, class_names,
            title=f"t-SNE {mode_name} — {layer_name} (N={len(labels)})"
        )

### ERROR VISUALIZATION

In [None]:

def show_5_mistakes_with_gradcam(
    model,
    loader,
    device,
    class_names,
    target_layer,          # for CNNBasic: model.features[8]
    dataset=None,
    n=5
):
    """Print and visualise (with Grad-CAM) the first ``n`` misclassified samples.

    Args:
        model: classifier; switched to eval mode.
        loader: iterable yielding ``(idxs, inputs, labels, captions)`` batches.
        device: device the batches are moved to.
        class_names: list mapping a class index to its display name.
        target_layer: layer handed to ``GradCAM`` together with the model.
        dataset: optional dataset used to resolve the image path through
            ``_get_img_path_from_idx``; ``"N/A"`` is printed when omitted.
        n: maximum number of mistakes to display.

    The Grad-CAM map is computed for the *predicted* class. The GradCAM object
    is always released through ``remove()``, even if an error interrupts the loop.
    """
    model.eval()
    if n <= 0:
        return

    cam_extractor = GradCAM(model, target_layer)
    shown = 0

    try:
        for idxs, inputs, labels, captions in loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            logits = model(inputs)
            probs = torch.softmax(logits, dim=1)
            preds = probs.argmax(dim=1)

            mism = (preds != labels).detach().cpu()
            for pos in torch.where(mism)[0].tolist():
                sample_idx = int(idxs[pos].item())
                true_i = int(labels[pos].item())
                pred_i = int(preds[pos].item())

                true_name = class_names[true_i]
                pred_name = class_names[pred_i]
                caption = captions[pos]

                img_path = dataset._get_img_path_from_idx(sample_idx) if dataset is not None else "N/A"

                print("\n--- Mauvaise prédiction ---")
                print(f"dataset_idx : {sample_idx}")
                print(f"image      : {img_path}")
                print(f"vrai label : {true_name} ({true_i})")
                print(f"prédit     : {pred_name} ({pred_i})")
                print(f"caption    : {caption}")

                # Fresh single-sample leaf tensor so gradients can flow back to it.
                x1 = inputs[pos:pos+1].detach()
                x1.requires_grad_(True)

                cam_map, _, _ = cam_extractor(x1, class_idx=pred_i)

                show_cam_overlay(
                    inputs[pos].detach().cpu(),
                    cam_map,
                    title=f"Grad-CAM (true={true_name} pred={pred_name})"
                )

                shown += 1
                # Stop immediately: no further forward passes once n are shown.
                if shown >= n:
                    return
    finally:
        cam_extractor.remove()


### CNN TRAINING

In [None]:
# Class names ordered by label index, taken from the training dataset.
class_names = make_class_names(train_dataset_custom)
# A bare expression in the middle of a cell is not displayed: print explicitly.
print(class_names)

# Train the basic CNN; fit() is given the directory and file name under which
# the model is saved.
model = CNNBasic(num_classes=4)
trained_model, history = fit(
    model=model,
    train_loader=train_loader_custom,
    test_loader=test_loader_custom,
    num_classes=4,
    epochs=50,
    lr=1e-3,
    momentum=0.9,
    patience=5,
    models_dir=Path("../models"),
    model_saving_name="best_model_2.pth",
    criterion=nn.CrossEntropyLoss(),
    class_names=class_names,
)


### CNN EVAL ON NON-AUGMENTED DATASET

In [None]:
# Reload the checkpoint written by the first training run into a fresh CNNBasic.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
best_model_path = Path("../models") / "best_model_2.pth"

model_loaded = CNNBasic(num_classes=4).to(device)
model_loaded.load_state_dict(torch.load(best_model_path, map_location=device))
criterion = nn.CrossEntropyLoss()

# NOTE(review): the loader below is the non-augmented *train* loader although
# the message printed afterwards says "TEST" — confirm this is intended.
final_test_loss, y_true, y_pred, y_prob, acc = evaluate(
    model_loaded,
    train_loader_custom_noaug,
    criterion,
    device,
    class_names=class_names,
    epoch=0,
    epochs=1,
    num_classes=4,
    dataset=train_dataset_custom_noaug,
)
print(f"Best model - TEST loss: {final_test_loss:.4f}")
print(f"Accuracy : {acc}\n")

# Diagnostic figures built from the predictions collected above.
plot_confusion_matrix(
    y_true, y_pred, class_names=class_names, title="Confusion Matrix (Best Model)"
)
plot_multiclass_roc(
    y_true, y_prob, num_classes=4, class_names=class_names, title="ROC (Best Model, OvR)"
)

In [None]:
# Grad-CAM on up to five misclassified samples of the non-augmented train set,
# hooked on block 8 of the CNN feature extractor.
target_layer = model_loaded.features[8]
show_5_mistakes_with_gradcam(
    model_loaded,
    train_loader_custom_noaug,
    device,
    class_names,
    target_layer,
    dataset=train_dataset_custom_noaug,
    n=5,
)

### CNN BASIC SECOND TRAINING

In [None]:
# Second training stage: start from the weights of the first run and continue
# training on the non-augmented loaders.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
best_model_path = Path("../models") / "best_model_2.pth"

model_loaded = CNNBasic(num_classes=4).to(device)
model_loaded.load_state_dict(torch.load(best_model_path, map_location=device))

trained_model, history = fit(
    # model and data
    model=model_loaded,
    train_loader=train_loader_custom_noaug,
    test_loader=test_loader_custom_noaug,
    num_classes=4,
    class_names=class_names,
    # optimisation settings
    criterion=nn.CrossEntropyLoss(),
    epochs=50,
    lr=1e-3,
    momentum=0.9,
    patience=5,
    # saving location passed to fit()
    models_dir=Path("../models"),
    model_saving_name="best_model_v2_2train.pth",
)



### CNN EVAL SECOND TRAINING

In [None]:
# Reload the checkpoint of the second training stage and score it on the
# non-augmented validation loader.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
best_model_path = Path("../models") / "best_model_v2_2train.pth"

model_loaded = CNNBasic(num_classes=4).to(device)
model_loaded.load_state_dict(torch.load(best_model_path, map_location=device))
criterion = nn.CrossEntropyLoss()

final_test_loss, y_true, y_pred, y_prob, acc = evaluate(
    model_loaded,
    val_loader_custom_noaug,
    criterion,
    device,
    class_names=class_names,
    epoch=0,
    epochs=1,
    num_classes=4,
    dataset=val_dataset_custom_noaug,
)
print(f"Best model - TEST loss: {final_test_loss:.4f}")
print(f"Accuracy : {acc}\n")

# Diagnostic figures built from the predictions collected above.
plot_confusion_matrix(
    y_true, y_pred, class_names=class_names, title="Confusion Matrix (Best Model)"
)
plot_multiclass_roc(
    y_true, y_prob, num_classes=4, class_names=class_names, title="ROC (Best Model, OvR)"
)

In [None]:
# Grad-CAM on up to five misclassified validation samples, hooked on block 8
# of the CNN feature extractor.
target_layer = model_loaded.features[8]
show_5_mistakes_with_gradcam(
    model_loaded,
    val_loader_custom_noaug,
    device,
    class_names,
    target_layer,
    dataset=val_dataset_custom_noaug,
    n=5,
)

### RESNET TRAINING

In [None]:
from torchvision.models import resnet18, ResNet18_Weights
num_classes = 4
NUM_EPOCHS_RESNET = 11  # number of passes over train_loader_resnet

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ImageNet-pretrained ResNet-18 with a new classification head of num_classes outputs.
model_resnet = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
model_resnet.fc = nn.Linear(model_resnet.fc.in_features, num_classes)
model_resnet.avgpool = nn.AdaptiveAvgPool2d((1, 1))
model_resnet = model_resnet.to(device)

# The optimizer is built once, after the model has been moved to its device.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model_resnet.parameters(), lr=1e-4, weight_decay=1e-5)

best_val_loss = float("inf")
best_epoch = -1
save_path = "best-model-resnet.pth"

for epoch in range(NUM_EPOCHS_RESNET):
    # ----- training pass -----
    model_resnet.train()
    running_loss = 0.0

    for _idx, inputs, labels, _caption in tqdm(train_loader_resnet, desc=f"TRAIN {epoch+1}"):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        logits = model_resnet(inputs)
        loss = criterion(logits, labels)

        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Mean of the per-batch losses.
    train_loss = running_loss / len(train_loader_resnet)
    print(f"Epoch {epoch+1}: Train loss = {train_loss:.4f}")

    # ----- validation pass -----
    model_resnet.eval()
    val_loss = 0.0
    correct, total = 0, 0

    with torch.no_grad():
        for _idx, inputs, labels, _caption in tqdm(val_loader_resnet, desc=f"VAL {epoch+1}"):
            inputs, labels = inputs.to(device), labels.to(device)

            logits = model_resnet(inputs)
            loss = criterion(logits, labels)

            val_loss += loss.item()
            preds = logits.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    val_loss /= len(val_loader_resnet)
    val_acc = correct / total

    print(f"Epoch {epoch+1}: Val loss = {val_loss:.4f} | Val acc = {val_acc:.4f}")

    # ===== SAVE BEST =====
    # Keep the checkpoint with the lowest validation loss seen so far.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_epoch = epoch + 1

        torch.save({
            "epoch": best_epoch,
            "model_state_dict": model_resnet.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "val_loss": best_val_loss,
        }, save_path)

        print(f"Best model saved at epoch {best_epoch} with val_loss={best_val_loss:.4f}")

print(f"\nTraining finished. Best epoch = {best_epoch}, best val_loss = {best_val_loss:.4f}")


### RESNET EVAL

In [None]:

# After the training loop model_resnet holds the weights of the LAST epoch.
# Restore the checkpoint with the lowest validation loss so that the
# "Best model" figures below really describe the best model.
if Path(save_path).exists():
    best_checkpoint = torch.load(save_path, map_location=device)
    model_resnet.load_state_dict(best_checkpoint["model_state_dict"])
else:
    print(f"Checkpoint {save_path} not found: evaluating the current weights instead.")

final_test_loss, y_true, y_pred, y_prob, acc = evaluate(
    model_resnet, test_loader_resnet_noaug, criterion, device, epoch=0, epochs=1, num_classes=num_classes
)
print(f"Best model - TEST loss: {final_test_loss:.4f}")
print(f"Accuracy : {acc}\n")
plot_confusion_matrix(y_true, y_pred, class_names=class_names, title="Confusion Matrix (Best Model)")
plot_multiclass_roc(y_true, y_prob, num_classes=4, class_names=class_names, title="ROC (Best Model, OvR)")

### RESNET LAYER FINETUNE

In [None]:
def make_class_names_resnet(dataset):
    """Return the class names of ``dataset`` ordered by class index.

    ``dataset.class_to_idx`` ({label_str: idx}) is inverted, the inverted
    mapping is printed, and the labels for indices 0 .. len-1 are returned.
    """
    idx_to_class = {}
    for label, class_idx in dataset.class_to_idx.items():
        idx_to_class[class_idx] = label
    print(idx_to_class)

    ordered_names = []
    for position in range(len(idx_to_class)):
        ordered_names.append(idx_to_class[position])
    return ordered_names


# Class names ordered by label index, taken from the ResNet dataset.
class_name_resnet = make_class_names_resnet(train_dataset_resnet_noaug)

# Fine-tuning strategies, each handed to configure_finetuning() below.
modes = [
    "freeze_all",
    "freeze_until_l3",
    "freeze_semantic",
    "unfreeze_all",
]

for mode in modes:
    print(f"Training mode: {mode}")

    # Every mode starts again from the same pretrained weights and a new head.
    model_resnet_freeze = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
    model_resnet_freeze.fc = nn.Linear(model_resnet_freeze.fc.in_features, num_classes)

    optimizer = configure_finetuning(model_resnet_freeze, mode=mode)

    train_one_experiment(
        model=model_resnet_freeze,
        train_loader=train_loader_resnet,
        optimizer=optimizer,
        criterion=criterion,
        device=device,
        epochs=5
    )

    # Use the ResNet class names for both visualisations of this model
    # (the t-SNE call previously used the CNN `class_names`).
    plot_tsne_layers(
        model=model_resnet_freeze,
        loader=val_loader_resnet,
        device=device,
        class_names=class_name_resnet,
        mode_name=mode,
        max_batches=10,
        perplexity=30,
        random_state=11
    )
    target_layer = model_resnet_freeze.layer4[-1]

    # NOTE(review): the loader/dataset below come from the custom-CNN pipeline
    # while this model was trained on train_loader_resnet — confirm that both
    # pipelines use the same preprocessing, or switch to the ResNet loader.
    show_5_mistakes_with_gradcam(
        model=model_resnet_freeze,
        loader=train_loader_custom_noaug,
        device=device,
        class_names=class_name_resnet,
        target_layer=target_layer,
        dataset=train_dataset_custom_noaug,
        n=5
    )


### RESNET LAYER 3 CLASSIFICATION

In [None]:
class ResNet18Layer3Classifier(nn.Module):
    """ResNet-18 cut after ``layer3`` and followed by a small classification head.

    The stem and the first three residual stages are reused from a
    ``resnet18`` instance; ``layer4`` and the original ``fc`` are not copied.
    Global average pooling and a linear layer produce the logits.
    """

    def __init__(self, num_classes=4, weights=ResNet18_Weights.IMAGENET1K_V1):
        super().__init__()
        base = resnet18(weights=weights)

        # Backbone up to layer3, registered under the same attribute names as
        # in the base network.
        for part in ("conv1", "bn1", "relu", "maxpool", "layer1", "layer2", "layer3"):
            setattr(self, part, getattr(base, part))

        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        # 256 input features: assumed to be the channel width of the layer3
        # output — TODO confirm against the torchvision ResNet-18 definition.
        self.fc = nn.Linear(256, num_classes)

    def forward(self, x):
        """Return the class logits for a batch of images ``x``."""
        backbone = (
            self.conv1, self.bn1, self.relu, self.maxpool,
            self.layer1, self.layer2, self.layer3,
        )
        for stage in backbone:
            x = stage(x)

        pooled = self.pool(x)
        return self.fc(pooled.flatten(1))


In [None]:
# Train the layer3-truncated ResNet on the ResNet training loader.
model_l3 = ResNet18Layer3Classifier(num_classes=4).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model_l3.parameters(), lr=1e-4, weight_decay=1e-5)

for epoch in range(11):
    model_l3.train()
    running_loss = 0.0

    progress = tqdm(train_loader_resnet, desc=f"TRAIN {epoch+1}")
    for idx, inputs, labels, caption in progress:
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Standard step: reset gradients, forward, backward, update.
        optimizer.zero_grad()
        logits = model_l3(inputs)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Mean of the per-batch losses of this epoch.
    mean_train_loss = running_loss / len(train_loader_resnet)
    print(f"Epoch {epoch+1}: Train loss = {mean_train_loss:.4f}")


In [None]:
# Score the layer3 classifier.
# NOTE(review): the loader is the non-augmented *train* loader although the
# message printed below says "TEST" — confirm this is intended.
final_test_loss, y_true, y_pred, y_prob, acc = evaluate(
    model_l3,
    train_loader_resnet_noaug,
    criterion,
    device,
    epoch=0,
    epochs=1,
    num_classes=num_classes,
)
print(f"Best model - TEST loss: {final_test_loss:.4f}")
print(f"Accuracy : {acc}\n")

# Diagnostic figures built from the predictions collected above.
plot_confusion_matrix(
    y_true, y_pred, class_names=class_names, title="Confusion Matrix (Best Model)"
)
plot_multiclass_roc(
    y_true, y_prob, num_classes=4, class_names=class_names, title="ROC (Best Model, OvR)"
)