# 🧠 Notebook outline: `05_training.ipynb`

| Step | Goal                                                                                        |
| ---- | ------------------------------------------------------------------------------------------- |
| 1    | ⚙️ Import libraries, configure the device, define the hyperparameters                       |
| 2    | 📁 Load the `.pt` features and the cleaned captions (`cleaned_captions_dict.pkl`)           |
| 3    | 🧠 Load the tokenizer (`tokenizer.pkl`) + display the vocabulary                            |
| 4    | 🔀 Split the IDs into `train` / `val` / `test` (80/10/10)                                   |
| 5    | 🔢 Build the `train_captions_dict`, etc. dictionaries                                       |
| 6    | 🧱 Instantiate an `ImageCaptionDataset` for each split                                      |
| 7    | 📦 Create the `DataLoader`s for train / val / test                                          |
| 8    | 👁️ Visualize a batch (features, tokenized and decoded caption)                              |
| 9    | 🧠 Define the `DecoderWithAttention` (embedding + LSTM + additive attention)                |
| 10   | 📉 Define the loss (`CrossEntropy` ignoring padding) and the optimizer (`Adam`, scheduler)  |
| 11   | 🔁 Training loop with validation at every epoch                                             |
| 12   | 🧪 Compute BLEU-3 and ROUGE-L after every epoch                                             |
| 13   | 📈 Plot the curves (loss, BLEU, LR)                                                         |
| 14   | 💾 Save the best model (`state_dict`) + export the metrics                                  |
| 15   | 🎯 Final evaluation on the test set + caption generation + BLEU/ROUGE                       |


## Step 1 – 📁 Loading the tokenizer and the aligned captions

In [1]:
# ✅ Tokenizer class (must be defined before pickle.load: pickle stores only a reference to the class, not its code)
class Tokenizer:
    def __init__(self, word2idx):
        self.word2idx = word2idx
        self.idx2word = {idx: word for word, idx in word2idx.items()}

        self.pad_token = "<pad>"
        self.start_token = "<start>"
        self.end_token = "<end>"
        self.unk_token = "<unk>"

        self.pad_token_id = self.word2idx[self.pad_token]
        self.start_token_id = self.word2idx[self.start_token]
        self.end_token_id = self.word2idx[self.end_token]
        self.unk_token_id = self.word2idx[self.unk_token]

        self.vocab_size = len(self.word2idx)

    def encode(self, caption, add_special_tokens=True):
        tokens = caption.strip().split()
        token_ids = [self.word2idx.get(token, self.unk_token_id) for token in tokens]

        if add_special_tokens:
            return [self.start_token_id] + token_ids + [self.end_token_id]
        else:
            return token_ids

    def decode(self, token_ids, remove_special_tokens=True):
        words = [self.idx2word.get(idx, self.unk_token) for idx in token_ids]
        if remove_special_tokens:
            words = [w for w in words if w not in [self.pad_token, self.start_token, self.end_token]]
        return " ".join(words)


In [2]:
from pathlib import Path
import pickle
import torch
import json

# 📁 Paths
features_dir = Path("../data/processed/features_resnet_global")
captions_file = "../data/raw/Flickr8k_text/Flickr8k.token.txt"
tokenizer_path = "../data/vocab/tokenizer.pkl"

# 🧠 Load the tokenizer
with open(tokenizer_path, "rb") as f:
    tokenizer = pickle.load(f)

# 🔢 Vocabulary size
print("🧠 Vocab size :", tokenizer.vocab_size)

# 📖 Load the aligned captions (output of notebook 04_)
captions_dict_path = Path("../data/processed/aligned_captions.json")

with open(captions_dict_path, "r") as f:
    aligned_captions = json.load(f)

print(f"✅ Aligned captions for {len(aligned_captions)} images")


🧠 Vocab size : 1204
✅ Aligned captions for 8091 images


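## Step 2 – 🔄 Building the enriched list of (image_id, caption) pairs

We build the full list of `(augmented_image_id, caption)` pairs from the `aligned_captions.json` dictionary: every caption of an image is paired with the original features and with each of the three augmented versions (`_aug0`, `_aug1`, `_aug2`).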

In [3]:
from pathlib import Path
import json

# 📖 Load the aligned captions (output of notebook 04_)
captions_dict_path = Path("../data/processed/aligned_captions.json")

with open(captions_dict_path, "r") as f:
    captions_dict = json.load(f)

# ✅ Check one image ID and its 5 captions
example_id = next(iter(captions_dict))
print(f"📝 Captions for {example_id} :")
for cap in captions_dict[example_id]:
    print(" ➤", cap)

# 🔄 Build the augmented IDs: original features + 3 augmented versions, each paired with every caption
augmentations = ["", "_aug0", "_aug1", "_aug2"]
full_pairs = []

for image_id, captions in captions_dict.items():
    for suffix in augmentations:
        full_id = image_id + suffix
        for caption in captions:
            full_pairs.append((full_id, caption))

print(f"\n✅ Total number of (features, caption) pairs: {len(full_pairs):,}")


📝 Captions for 1000268201_693b08cb0e :
 ➤ A child in a pink dress is climbing up a set of stairs in an entry way .
 ➤ A girl going into a wooden building .
 ➤ A little girl climbing into a wooden playhouse .
 ➤ A little girl climbing the stairs to her playhouse .
 ➤ A little girl in a pink dress going into a wooden cabin .

✅ Total number of (features, caption) pairs: 161,820
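As a quick consistency check, the pair count follows directly from the loop above, assuming every image has exactly 5 captions (which the printed total confirms): 8,091 images × 4 feature versions × 5 captions. The same arithmetic also predicts the sizes of the 80/10/10 split used further down.

```python
# Worked count of (features, caption) pairs produced by the loop above
n_images = 8091                                  # images in aligned_captions.json
augmentations = ["", "_aug0", "_aug1", "_aug2"]  # original features + 3 augmented versions
captions_per_image = 5                           # Flickr8k: 5 captions per image

n_pairs = n_images * len(augmentations) * captions_per_image
print(f"{n_pairs:,}")  # 161,820

# 80/10/10 split computed the same way as in the notebook (test takes the remainder)
train_size = int(0.8 * n_pairs)
val_size = int(0.1 * n_pairs)
test_size = n_pairs - train_size - val_size
print(train_size, val_size, test_size)  # 129456 16182 16182
```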


## Step 3 – 🧱 Defining the `ImageCaptionDataset` class

This class inherits from `torch.utils.data.Dataset` and:

- loads the `.pt` feature files,

- cleans and encodes the captions with the `Tokenizer`,

- truncates them to `max_length` tokens.

In [4]:
from pathlib import Path
from torch.utils.data import Dataset
import torch
import re

def clean_caption(caption):
    caption = caption.lower()
    caption = re.sub(r"[^a-zA-Z0-9'\s]", "", caption)  # keep letters, digits, apostrophes and whitespace
    caption = re.sub(r"\s+", " ", caption)             # collapse repeated whitespace
    return caption.strip()

class ImageCaptionDataset(Dataset):
    def __init__(self, pairs, features_dir, tokenizer, max_length=37):
        self.pairs = pairs
        self.features_dir = Path(features_dir)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        image_id, caption = self.pairs[idx]

        # 📦 Load the precomputed features (on CPU, whatever device they were saved from)
        feature_path = self.features_dir / f"{image_id}.pt"
        features = torch.load(feature_path, map_location="cpu")

        # 🔡 Clean and encode the caption (<start> ... <end>)
        caption = clean_caption(caption)
        encoded = self.tokenizer.encode(caption)

        # 🔪 Truncate to max_length while keeping <end> as the last token
        if len(encoded) > self.max_length:
            encoded = encoded[:self.max_length - 1] + [self.tokenizer.end_token_id]
        encoded_tensor = torch.tensor(encoded, dtype=torch.long)

        return features, encoded_tensor


In [5]:
# 🧪 Quick test
dataset = ImageCaptionDataset(pairs=full_pairs,
                              features_dir="../data/processed/features_resnet_global",
                              tokenizer=tokenizer)

print("📦 Dataset length :", len(dataset))

features, encoded_caption = dataset[0]
print("📐 Feature vector shape :", features.shape)
print("🧾 Encoded caption :", encoded_caption.tolist())
print("🧾 Decoded caption :", tokenizer.decode(encoded_caption.tolist()))


📦 Dataset length : 161820
📐 Feature vector shape : torch.Size([2048])
🧾 Encoded caption : [1, 4, 43, 5, 4, 91, 171, 8, 120, 54, 4, 397, 13, 394, 5, 29, 3, 694, 2]
🧾 Decoded caption : a child in a pink dress is climbing up a set of stairs in an <unk> way


## 🧩 Step 4 – Custom collate function

| Goal | Let PyTorch batch variable-length captions by padding them |
| ---- | ---------------------------------------------------------- |


In [6]:
from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch):
    """
    Stack a batch of (features, caption) tuples, padding the captions to a common length.
    """
    features, captions = zip(*batch)  # batch = list of (features, caption) tuples

    features = torch.stack(features)          # [batch_size, 2048]
    lengths = [len(cap) for cap in captions]  # true lengths, before padding
    captions_padded = pad_sequence(captions, batch_first=True, padding_value=tokenizer.pad_token_id)  # [batch_size, max_len]

    return features, captions_padded, lengths


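### 📌 Why it matters:
- Captions do not all have the same length.

- Stacking them into a single batch tensor requires a uniform size, so `pad_sequence` pads every caption to the longest one in the batch.

- The true `lengths` are kept as well: the decoder uses them to know, at each step `t`, how many sequences are still being decoded (`batch_size_t`). The padded positions themselves are ignored by the loss through `ignore_index`.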
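A toy example of what the collate step and the decoder do with lengths. The token IDs and `PAD_ID = 0` are illustrative assumptions (the notebook takes the pad ID from `tokenizer.pad_token_id`). The last lines count how many captions are still active at each decoding step, which is what the decoder's `batch_size_t` computes, and show the descending-length order its `[:batch_size_t]` slicing relies on.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

PAD_ID = 0  # assumption for the demo; the notebook uses tokenizer.pad_token_id

# Three toy captions of different lengths (<start>=1 ... <end>=2)
captions = [torch.tensor([1, 4, 43, 2]), torch.tensor([1, 4, 2]), torch.tensor([1, 7, 8, 9, 10, 2])]
lengths = [len(cap) for cap in captions]
padded = pad_sequence(captions, batch_first=True, padding_value=PAD_ID)
print(padded.shape, lengths)  # torch.Size([3, 6]) [4, 3, 6]

# No prediction is made from <end>, so each caption needs length - 1 decoding steps
decode_lengths = [l - 1 for l in lengths]
active = [sum(l > t for l in decode_lengths) for t in range(max(decode_lengths))]
print(active)  # [3, 3, 2, 1, 1]

# Keeping only rows [:batch_size_t] at step t is correct only if rows are sorted by decreasing length
order = sorted(range(len(lengths)), key=lambda i: lengths[i], reverse=True)
print(order, [lengths[i] for i in order])  # [2, 0, 1] [6, 4, 3]
```

In the unsorted toy batch, `[:2]` at step 2 would keep the 4-token and 3-token captions and drop the 6-token one, which is exactly the caption that still needs decoding.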

## 📦 Step 5 – Creating the training / validation / test `DataLoaders`
| Goal | Split the dataset into three parts and create the matching `DataLoaders` |
| ---- | ------------------------------------------------------------------------ |


In [7]:
from torch.utils.data import random_split, DataLoader

# 📏 Split proportions
train_ratio = 0.8
val_ratio = 0.1
test_ratio = 0.1
total_size = len(dataset)

train_size = int(train_ratio * total_size)
val_size = int(val_ratio * total_size)
test_size = total_size - train_size - val_size  # the remainder

# ✂️ Split (seeded generator so the split is reproducible across runs; 42 is an arbitrary seed)
generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size], generator=generator)
print(f"📊 Split → Train: {len(train_dataset)} | Val: {len(val_dataset)} | Test: {len(test_dataset)}")

# 🧪 DataLoaders
batch_size = 64

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)


📊 Split → Train: 129456 | Val: 16182 | Test: 16182


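### 🧠 Notes
- `shuffle=True` for the training loader only.

- The same `collate_fn` is used by all three DataLoaders.

- `batch_size=64` here; adjust it to the available GPU memory.

- `random_split` splits *pairs*, not images: the 5 captions and the 4 feature versions (original + 3 augmentations) of one image can land in different splits, so validation and test scores are partly measured on images seen during training.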
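A minimal sketch of an image-level split, closer to the outline's "split the IDs" step: base image IDs are shuffled once and every pair inherits the split of its image. It assumes augmented IDs always end in `_augN`, as generated in Step 2; `split_pairs_by_image` and the seed are hypothetical, illustrative choices.

```python
import random
import re

def split_pairs_by_image(pairs, train_ratio=0.8, val_ratio=0.1, seed=42):
    """Hypothetical helper: split (image_id, caption) pairs so that all captions
    and augmented versions of one image fall into the same split."""
    def base_id(image_id):
        return re.sub(r"_aug\d+$", "", image_id)  # "123_abc_aug1" -> "123_abc"

    base_ids = sorted({base_id(image_id) for image_id, _ in pairs})
    random.Random(seed).shuffle(base_ids)  # reproducible shuffle of images, not pairs

    n_train = int(train_ratio * len(base_ids))
    n_val = int(val_ratio * len(base_ids))
    train_ids = set(base_ids[:n_train])
    val_ids = set(base_ids[n_train:n_train + n_val])

    train_pairs, val_pairs, test_pairs = [], [], []
    for image_id, caption in pairs:
        b = base_id(image_id)
        if b in train_ids:
            train_pairs.append((image_id, caption))
        elif b in val_ids:
            val_pairs.append((image_id, caption))
        else:
            test_pairs.append((image_id, caption))
    return train_pairs, val_pairs, test_pairs


# Toy usage: 20 images x 4 feature versions x 2 captions = 160 pairs
toy_pairs = [(f"img{i}_x{suffix}", f"caption {k}")
             for i in range(20) for suffix in ["", "_aug0", "_aug1", "_aug2"] for k in range(2)]
train_pairs, val_pairs, test_pairs = split_pairs_by_image(toy_pairs)
print(len(train_pairs), len(val_pairs), len(test_pairs))  # 128 16 16
```

Each of the three lists could then feed its own `ImageCaptionDataset`. A further option is to keep only the original (unsuffixed) features in validation and test, so evaluation is not repeated over augmented copies of the same image.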

## 🧪 Step 6 – Visualizing a full batch
| Goal | Display 5 examples (features, tokenized and decoded captions) as a sanity check |
| ---- | ------------------------------------------------------------------------------- |

In [8]:
# 🔁 Fetch one batch from the DataLoader
features_batch, captions_batch, lengths = next(iter(train_loader))

print("✅ Batch loaded")
print("📐 Feature vector shape :", features_batch.shape)   # [batch_size, 2048]
print("🧾 Captions shape :", captions_batch.shape)         # [batch_size, max_len]

# 🔍 Show a few examples
print("\n🎯 Batch examples:\n")
for i in range(5):
    encoded_caption = captions_batch[i][:lengths[i]].tolist()  # cut at the true length (drop padding)
    decoded_caption = tokenizer.decode(encoded_caption)

    print(f"🖼️ Example {i+1} :")
    print("  ➤ Encoded caption :", encoded_caption)
    print("  ➤ Decoded caption :", decoded_caption)
    print()


✅ Batch loaded
📐 Feature vector shape : torch.Size([64, 2048])
🧾 Captions shape : torch.Size([64, 19])

🎯 Batch examples:

🖼️ Example 1 :
  ➤ Encoded caption : [1, 4, 56, 59, 13, 32, 62, 5, 4, 262, 13, 25, 2]
  ➤ Decoded caption : a large group of dogs walking in a body of water

🖼️ Example 2 :
  ➤ Encoded caption : [1, 49, 232, 24, 116, 5, 48, 13, 29, 3, 223, 2]
  ➤ Decoded caption : three older people stand in front of an <unk> sign

🖼️ Example 3 :
  ➤ Encoded caption : [1, 4, 20, 5, 4, 297, 136, 3, 378, 2]
  ➤ Decoded caption : a girl in a forest carrying <unk> gear

🖼️ Example 4 :
  ➤ Encoded caption : [1, 4, 12, 22, 4, 31, 38, 972, 7, 4, 222, 257, 2]
  ➤ Decoded caption : a man wearing a blue shirt crouches on a rocky cliff

🖼️ Example 5 :
  ➤ Encoded caption : [1, 4, 27, 20, 5, 4, 91, 171, 11, 172, 5, 61, 173, 2]
  ➤ Decoded caption : a young girl in a pink dress with something in her hand



### 🧠 What to check
- Captions should not start or end with `<unk>` (this was fixed earlier).

- Every encoding must start with `<start>` and end with `<end>` (IDs 1 and 2).

- `lengths[i]` must not include the padding (which is the case here).
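The checklist above can be turned into assertions. A minimal sketch, assuming `<start>`=1 and `<end>`=2 (as printed in Step 7) and `<pad>`=0 (an assumption; pass `tokenizer.pad_token_id` in the notebook); `check_batch` is a hypothetical helper. A caption cut by truncation before its `<end>` token fails the second check.

```python
import torch

def check_batch(captions_padded, lengths, start_id=1, end_id=2, pad_id=0):
    """Hypothetical helper: assert the formatting rules for one padded batch of captions."""
    for row, length in zip(captions_padded.tolist(), lengths):
        assert row[0] == start_id, "caption must start with <start>"
        assert row[length - 1] == end_id, "caption must end with <end>"
        assert pad_id not in row[:length], "lengths must not count padding"
        assert all(tok == pad_id for tok in row[length:]), "only <pad> after <end>"
    return True

# Toy batch: two captions, the first one followed by two <pad> tokens
batch = torch.tensor([[1, 4, 43, 2, 0, 0],
                      [1, 7, 8, 9, 10, 2]])
print(check_batch(batch, [4, 6]))  # True

# A caption truncated without its <end> token fails the check
truncated = torch.tensor([[1, 4, 43, 5]])
try:
    check_batch(truncated, [4])
except AssertionError as err:
    print("check failed:", err)  # check failed: caption must end with <end>
```

In the notebook this would be called as `check_batch(captions_batch, lengths, pad_id=tokenizer.pad_token_id)`.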

## ✅ Step 7 – Checking dimensions and metadata

| Goal | Validate the vocabulary size, batch size, max length, etc. before modeling |
| ---- | -------------------------------------------------------------------------- |


In [9]:
# 📏 Dimension checks
print("🔢 Vocabulary size :", tokenizer.vocab_size)
print("📐 Feature size (per image) :", features_batch.shape[1])
print("📏 Max caption length in this batch :", captions_batch.shape[1])
print("🧮 Batch size :", features_batch.shape[0])

# 🧾 Quick check of the special-token IDs
print("🔍 Token with id 1 :", tokenizer.idx2word.get(1))
print("🔍 Token with id 2 :", tokenizer.idx2word.get(2))


🔢 Vocabulary size : 1204
📐 Feature size (per image) : 2048
📏 Max caption length in this batch : 19
🧮 Batch size : 64
🔍 Token with id 1 : <start>
🔍 Token with id 2 : <end>


### ✅ This completes the data preparation and the loaders:
At this point we should have:

- working `train_loader`, `val_loader` and `test_loader` 🎯

- a clean, indexed vocabulary 🧠

- well-formed captions: `<start> ... <end>`

- image features ready to be fed into a model

## Step 4 – 🧠 Defining the CNN → LSTM model with attention

| Module                 | Description                                                                                                                                                            |
| ---------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `EncoderCNN`           | Projects the extracted features (2048-dim) into a lower-dimensional space (e.g. 256); not instantiated in this notebook, where the 2048-dim features go straight to the decoder |
| `Attention`            | Computes the attention weights over the positions of the image features                                                                                               |
| `DecoderWithAttention` | LSTM conditioned on the attention context + embeddings + next-token prediction                                                                                        |


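### 🧠 Step 4.1 – `Attention`: the attention mechanism
**Goal**:
Attention lets the model **focus on different parts of the features** at each word-generation step.
Our features are global here (one 2048-dim vector per image, no spatial grid). The Bahdanau-style score still depends on the LSTM state, but the softmax runs over a single score, so `alpha` is always exactly 1 and the module returns the features unchanged. Attention only becomes selective when each image is described by several regions (e.g. a spatial feature map).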

In [10]:
import torch.nn as nn

class Attention(nn.Module):
    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super(Attention, self).__init__()
        self.encoder_att = nn.Linear(encoder_dim, attention_dim)
        self.decoder_att = nn.Linear(decoder_dim, attention_dim)
        self.full_att = nn.Linear(attention_dim, 1)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)

    def forward(self, encoder_out, decoder_hidden):
        """
        encoder_out: [batch_size, encoder_dim]     → encoded image features
        decoder_hidden: [batch_size, decoder_dim]  → current LSTM hidden state

        returns:
            - attention_weighted_encoding: [batch_size, encoder_dim]
            - alpha: [batch_size, 1]
        """
        att1 = self.encoder_att(encoder_out)             # [batch_size, attention_dim]
        att2 = self.decoder_att(decoder_hidden)          # [batch_size, attention_dim]
        att = self.full_att(self.relu(att1 + att2))      # [batch_size, 1]
        alpha = self.softmax(att)                        # [batch_size, 1]; softmax over one score → always 1.0
        attention_weighted_encoding = encoder_out * alpha  # [batch_size, encoder_dim]
        return attention_weighted_encoding, alpha
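To see why this attention is inert on global features, and what it does with several regions, here is a hedged sketch of the same additive score computed over `R` regions. It assumes spatial features of shape `[batch_size, R, encoder_dim]` (e.g. a flattened 7×7 ResNet map), which this notebook does not extract; `SpatialAttention` is a hypothetical name. With `R = 1` the softmax runs over a single score and every weight is exactly 1.

```python
import torch
import torch.nn as nn

class SpatialAttention(nn.Module):
    """Hypothetical additive (Bahdanau-style) attention over R image regions."""
    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super().__init__()
        self.encoder_att = nn.Linear(encoder_dim, attention_dim)
        self.decoder_att = nn.Linear(decoder_dim, attention_dim)
        self.full_att = nn.Linear(attention_dim, 1)

    def forward(self, encoder_out, decoder_hidden):
        att1 = self.encoder_att(encoder_out)                          # [B, R, A]
        att2 = self.decoder_att(decoder_hidden).unsqueeze(1)          # [B, 1, A], broadcast over regions
        scores = self.full_att(torch.relu(att1 + att2)).squeeze(2)    # [B, R]
        alpha = torch.softmax(scores, dim=1)                          # weights sum to 1 over the regions
        context = (encoder_out * alpha.unsqueeze(2)).sum(dim=1)       # [B, encoder_dim]
        return context, alpha

torch.manual_seed(0)
B, R, E, D, A = 4, 49, 2048, 512, 256
att = SpatialAttention(E, D, A)
context, alpha = att(torch.randn(B, R, E), torch.randn(B, D))
print(context.shape, alpha.shape)  # torch.Size([4, 2048]) torch.Size([4, 49])

# Global-feature case (R = 1): softmax over a single score is always 1
_, alpha_global = att(torch.randn(B, 1, E), torch.randn(B, D))
print(torch.allclose(alpha_global, torch.ones_like(alpha_global)))  # True
```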


### 🧠 Step 4.2 – `DecoderWithAttention` (LSTM + attention)

**Goal**:
Use an LSTM to generate the words one at a time, with an attention step at every time step.

In [11]:
class DecoderWithAttention(nn.Module):
    def __init__(self, attention_dim, embed_dim, decoder_dim, vocab_size, encoder_dim=2048, dropout=0.5):
        super(DecoderWithAttention, self).__init__()
        self.attention = Attention(encoder_dim, decoder_dim, attention_dim)

        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.dropout = nn.Dropout(dropout)
        self.decode_step = nn.LSTMCell(embed_dim + encoder_dim, decoder_dim, bias=True)
        self.init_h = nn.Linear(encoder_dim, decoder_dim)
        self.init_c = nn.Linear(encoder_dim, decoder_dim)
        self.f_beta = nn.Linear(decoder_dim, encoder_dim)
        self.sigmoid = nn.Sigmoid()
        self.fc = nn.Linear(decoder_dim, vocab_size)

        self.init_weights()

    def init_weights(self):
        """Weight initialization"""
        self.embedding.weight.data.uniform_(-0.1, 0.1)
        self.fc.bias.data.fill_(0)
        self.fc.weight.data.uniform_(-0.1, 0.1)

    def init_hidden_state(self, encoder_out):
        """Initialize h and c from the encoded features"""
        h = self.init_h(encoder_out)  # [batch_size, decoder_dim]
        c = self.init_c(encoder_out)
        return h, c

    def forward(self, encoder_out, encoded_captions, caption_lengths):
        """
        encoder_out: [batch_size, encoder_dim]           → extracted features
        encoded_captions: [batch_size, max_len]          → target captions
        caption_lengths: [batch_size] or [batch_size, 1] → true lengths

        returns:
            - predictions (logits): [batch_size, max_len - 1, vocab_size],
              in the same row order as the input batch
        """
        batch_size = encoder_out.size(0)
        vocab_size = self.fc.out_features

        # 🔐 Safety: convert caption_lengths to a Tensor if needed
        if isinstance(caption_lengths, list):
            caption_lengths = torch.tensor(caption_lengths, dtype=torch.long, device=encoder_out.device)
        if caption_lengths.dim() == 2:  # [batch_size, 1] → [batch_size]
            caption_lengths = caption_lengths.squeeze(1)

        # 🔃 Sort by decreasing length: the [:batch_size_t] slicing below assumes that
        # the sequences still being decoded at step t are the first rows of the batch
        caption_lengths, sort_idx = caption_lengths.sort(descending=True)
        encoder_out = encoder_out[sort_idx]
        encoded_captions = encoded_captions[sort_idx]

        embeddings = self.embedding(encoded_captions)  # [batch_size, max_len, embed_dim]
        h, c = self.init_hidden_state(encoder_out)     # h, c: [batch_size, decoder_dim]

        # No prediction is made from <end>, hence length - 1 decoding steps
        decode_lengths = (caption_lengths - 1).tolist()
        predictions = torch.zeros(batch_size, max(decode_lengths), vocab_size, device=encoder_out.device)

        for t in range(max(decode_lengths)):
            batch_size_t = sum(l > t for l in decode_lengths)  # sequences still active at step t
            attn_weighted_encoding, _ = self.attention(encoder_out[:batch_size_t], h[:batch_size_t])

            input_lstm = torch.cat([embeddings[:batch_size_t, t, :], attn_weighted_encoding], dim=1)
            h, c = self.decode_step(input_lstm, (h[:batch_size_t], c[:batch_size_t]))  # LSTMCell

            preds = self.fc(self.dropout(h))  # [batch_size_t, vocab_size]
            predictions[:batch_size_t, t, :] = preds

        # ↩️ Restore the original batch order so predictions line up with the targets
        return predictions[sort_idx.argsort()]
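At inference there is no ground-truth caption to feed in, so the decoder has to consume its own predictions. A minimal greedy-decoding sketch, assuming the per-step computation of `DecoderWithAttention.forward` above (embedding + attention context → `LSTMCell` → `fc`). `TinyDecoder` is a hypothetical, small stand-in exposing the same attribute names so the block runs on its own; the trained `decoder` could be passed instead, and the returned IDs decoded with `tokenizer.decode`.

```python
import torch
import torch.nn as nn

class TinyDecoder(nn.Module):
    """Hypothetical stand-in with the same attribute names as DecoderWithAttention."""
    def __init__(self, vocab_size=12, embed_dim=8, decoder_dim=16, encoder_dim=10):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.decode_step = nn.LSTMCell(embed_dim + encoder_dim, decoder_dim)
        self.init_h = nn.Linear(encoder_dim, decoder_dim)
        self.init_c = nn.Linear(encoder_dim, decoder_dim)
        self.fc = nn.Linear(decoder_dim, vocab_size)

    def init_hidden_state(self, encoder_out):
        return self.init_h(encoder_out), self.init_c(encoder_out)

    def attention(self, encoder_out, decoder_hidden):
        # Global features: alpha is always 1, so the context is the features themselves
        return encoder_out, torch.ones(encoder_out.size(0), 1)


@torch.no_grad()
def greedy_decode(decoder, features, start_id, end_id, max_len=20):
    """Generate token IDs step by step, feeding back the argmax at each step."""
    decoder.eval()  # disables dropout
    batch_size = features.size(0)
    h, c = decoder.init_hidden_state(features)
    tokens = torch.full((batch_size,), start_id, dtype=torch.long, device=features.device)
    finished = torch.zeros(batch_size, dtype=torch.bool, device=features.device)
    steps = []
    for _ in range(max_len):
        context, _ = decoder.attention(features, h)
        lstm_input = torch.cat([decoder.embedding(tokens), context], dim=1)
        h, c = decoder.decode_step(lstm_input, (h, c))
        tokens = decoder.fc(h).argmax(dim=1)  # greedy choice: most likely next token
        steps.append(tokens)
        finished |= tokens == end_id
        if finished.all():                    # every caption has produced <end>
            break
    sequences = torch.stack(steps, dim=1).tolist()
    # Cut each caption at its first <end> (excluded)
    return [seq[:seq.index(end_id)] if end_id in seq else seq for seq in sequences]


torch.manual_seed(0)
toy = TinyDecoder()
generated = greedy_decode(toy, torch.randn(3, 10), start_id=1, end_id=2, max_len=15)
print(generated)  # 3 lists of token IDs (untrained weights, so the IDs are arbitrary)
```

Greedy search keeps only the single best token at each step; beam search is the usual, more expensive alternative.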



In [12]:
# 📦 Required imports
import torch
import torch.nn.functional as F
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer

# ⚙️ BLEU-3 (n-grams up to 3, uniform weights), one reference per hypothesis
def compute_bleu(references, hypotheses, n=3):
    smoothie = SmoothingFunction().method4
    scores = []
    for ref, hyp in zip(references, hypotheses):
        ref_tokens = ref.lower().split()
        hyp_tokens = hyp.lower().split()
        weights = tuple((1. / n for _ in range(n)))
        score = sentence_bleu([ref_tokens], hyp_tokens, weights=weights, smoothing_function=smoothie)
        scores.append(score)
    return sum(scores) / len(scores)

# ⚙️ ROUGE-L
def compute_rouge(references, hypotheses):
    scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
    scores = []
    for ref, hyp in zip(references, hypotheses):
        score = scorer.score(ref, hyp)['rougeL'].fmeasure
        scores.append(score)
    return sum(scores) / len(scores)

# 🧠 Initialization
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
decoder = DecoderWithAttention(
    attention_dim=256,
    embed_dim=256,
    decoder_dim=512,
    vocab_size=tokenizer.vocab_size,
    dropout=0.5
).to(device)

# 🧪 Test on one batch
features, captions, lengths = next(iter(train_loader))
features, captions = features.to(device), captions.to(device)

print("📐 Feature shape :", features.shape)
print("🧾 Captions shape :", captions.shape)

# 🔁 Forward pass (no encoder: the features are precomputed)
outputs = decoder(features, captions, lengths)  # [B, max_len - 1, vocab_size]
print("✅ Logits shape :", outputs.shape)

# 🎯 BLEU/ROUGE on one batch for debugging (untrained model, teacher-forced argmax)
predictions = outputs.argmax(-1).detach().cpu().tolist()
references = captions[:, 1:].detach().cpu().tolist()  # skip <start>

decoded_preds = [tokenizer.decode(p) for p in predictions]
decoded_refs = [tokenizer.decode(r) for r in references]

bleu = compute_bleu(decoded_refs, decoded_preds, n=3)
rouge = compute_rouge(decoded_refs, decoded_preds)

print(f"📊 BLEU-3 : {bleu:.4f} | ROUGE-L : {rouge:.4f}")


📐 Feature shape : torch.Size([64, 2048])
🧾 Captions shape : torch.Size([64, 25])
✅ Logits shape : torch.Size([64, 24, 1204])
📊 BLEU-3 : 0.0024 | ROUGE-L : 0.0081
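`compute_rouge` relies on `rouge_score`; ROUGE-L itself is an F-measure over the longest common subsequence (LCS) of the two token lists. A dependency-free sketch for intuition, without the stemming that `use_stemmer=True` adds, so its values can differ from `rouge_score` (here `runs` and `running` do not match).

```python
def lcs_length(a, b):
    """Length of the longest common subsequence of two token lists (dynamic programming)."""
    dp = [[0] * (len(b) + 1) for _ in range(len(a) + 1)]
    for i, tok_a in enumerate(a, start=1):
        for j, tok_b in enumerate(b, start=1):
            if tok_a == tok_b:
                dp[i][j] = dp[i - 1][j - 1] + 1
            else:
                dp[i][j] = max(dp[i - 1][j], dp[i][j - 1])
    return dp[len(a)][len(b)]

def rouge_l_f1(reference, hypothesis):
    """ROUGE-L F-measure: harmonic mean of LCS-precision and LCS-recall (no stemming)."""
    ref, hyp = reference.lower().split(), hypothesis.lower().split()
    lcs = lcs_length(ref, hyp)
    if lcs == 0:
        return 0.0
    precision, recall = lcs / len(hyp), lcs / len(ref)
    return 2 * precision * recall / (precision + recall)

score = rouge_l_f1("a dog runs on the grass", "a dog is running on grass")
print(round(score, 4))  # 0.6667 (LCS = "a dog on grass": 4 of 6 tokens on each side)
```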


In [13]:
from tqdm import tqdm
import os
import json
import torch
from datetime import datetime
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import ReduceLROnPlateau

def train_model(decoder, train_loader, val_loader, tokenizer, criterion, optimizer, num_epochs=10, patience=3):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    decoder.to(device)

    # 📁 Timestamped output directory
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    save_dir = f"outputs/{timestamp}"
    os.makedirs(save_dir, exist_ok=True)
    best_model_path = os.path.join(save_dir, "decoder.pt")
    checkpoint_path = os.path.join(save_dir, "checkpoint.pt")
    metrics_path = os.path.join(save_dir, "metrics.json")
    plot_path = os.path.join(save_dir, "plot.png")

    train_losses, val_bleus, val_rouges = [], [], []
    best_bleu = 0
    patience_counter = 0
    scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=1)

    print(f"🚀 Training started on {device}")
    epoch_bar = tqdm(range(num_epochs), desc="📆 Epochs")

    for epoch in epoch_bar:
        decoder.train()
        total_loss = 0.0

        batch_bar = tqdm(train_loader, desc=f"🧠 Training Epoch {epoch+1}", leave=True, position=1)
        for features, captions, lengths in batch_bar:
            features, captions = features.to(device), captions.to(device)
            optimizer.zero_grad()

            outputs = decoder(features, captions, lengths)
            targets = captions[:, 1:]
            outputs = outputs.view(-1, outputs.shape[-1])
            targets = targets.reshape(-1)

            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

            batch_bar.set_postfix(loss=loss.item())

        avg_loss = total_loss / len(train_loader)
        train_losses.append(avg_loss)

        # 🔍 Evaluation
        decoder.eval()
        all_preds, all_refs = [], []

        val_bar = tqdm(val_loader, desc=f"🔍 Evaluating Epoch {epoch+1}", leave=True, position=2)
        with torch.no_grad():
            for features, captions, lengths in val_bar:
                features, captions = features.to(device), captions.to(device)
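                # Teacher forcing: the ground-truth prefix is fed in, so these scores are
                # usually optimistic compared with free (autoregressive) generation
                outputs = decoder(features, captions, lengths)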
                preds = outputs.argmax(-1).detach().cpu().tolist()
                refs = captions[:, 1:].detach().cpu().tolist()

                decoded_preds = [tokenizer.decode(p) for p in preds]
                decoded_refs = [tokenizer.decode(r) for r in refs]
                all_preds.extend(decoded_preds)
                all_refs.extend(decoded_refs)

        bleu = compute_bleu(all_refs, all_preds, n=3)
        rouge = compute_rouge(all_refs, all_preds)

        val_bleus.append(bleu)
        val_rouges.append(rouge)
        scheduler.step(bleu)

        # ⏹️ Console summary
        print(f"\n📊 Epoch {epoch+1}/{num_epochs} — Loss: {avg_loss:.4f} | BLEU-3: {bleu:.4f} | ROUGE-L: {rouge:.4f}\n")

        # 💾 Save the best model (selected on validation BLEU-3)
        if bleu > best_bleu:
            best_bleu = bleu
            patience_counter = 0
            torch.save(decoder.state_dict(), best_model_path)
            torch.save({
                "epoch": epoch,
                "decoder_state_dict": decoder.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "train_losses": train_losses,
                "val_bleus": val_bleus,
                "val_rouges": val_rouges,
                "tokenizer": tokenizer.word2idx
            }, checkpoint_path)
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("⛔ Early stopping triggered.")
                break

    # 💾 Save the metrics
    with open(metrics_path, "w") as f:
        json.dump({
            "train_loss": train_losses,
            "val_bleu": val_bleus,
            "val_rouge": val_rouges
        }, f)

    # 📊 Matplotlib plot (loss and scores share the same y-axis)
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label="Train Loss")
    plt.plot(val_bleus, label="BLEU-3")
    plt.plot(val_rouges, label="ROUGE-L")
    plt.xlabel("Epoch")
    plt.ylabel("Score")
    plt.title("Training Progress")
    plt.legend()
    plt.grid(True)
    plt.savefig(plot_path)
    plt.close()

    return train_losses, val_bleus, val_rouges
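`train_model` lowers the learning rate with `ReduceLROnPlateau(mode='max', factor=0.5, patience=1)` driven by validation BLEU-3. A small runnable trace of that schedule on made-up scores (the BLEU values are illustrative, not results): the LR is halved only once the number of consecutive epochs without improvement exceeds `patience`, i.e. after two flat epochs here.

```python
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

# A single dummy parameter is enough to drive the optimizer/scheduler pair
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=1e-4)
scheduler = ReduceLROnPlateau(optimizer, mode="max", factor=0.5, patience=1)

fake_bleus = [0.10, 0.20, 0.20, 0.20, 0.20]  # illustrative validation scores
lrs = []
for bleu in fake_bleus:
    scheduler.step(bleu)  # same call as in train_model
    lrs.append(optimizer.param_groups[0]["lr"])
print(lrs)  # [0.0001, 0.0001, 0.0001, 5e-05, 5e-05]
```

With `patience=3` for early stopping in the same loop, the LR is usually halved at least once before training stops on a BLEU plateau.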


In [14]:
import matplotlib.pyplot as plt

def plot_training_curves(train_losses, val_bleus, val_rouges):
    epochs = range(1, len(train_losses) + 1)

    plt.figure(figsize=(16, 5))

    # 📉 Train Loss
    plt.subplot(1, 3, 1)
    plt.plot(epochs, train_losses, label="Train Loss", marker='o')
    plt.title("Training Loss")  # plain text: the default matplotlib font has no emoji glyphs
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.grid(True)

    # 🎯 BLEU-3
    plt.subplot(1, 3, 2)
    plt.plot(epochs, val_bleus, label="BLEU-3", color="green", marker='o')
    plt.title("BLEU-3 Score")
    plt.xlabel("Epoch")
    plt.ylabel("BLEU-3")
    plt.grid(True)

    # 📝 ROUGE-L
    plt.subplot(1, 3, 3)
    plt.plot(epochs, val_rouges, label="ROUGE-L", color="orange", marker='o')
    plt.title("ROUGE-L Score")
    plt.xlabel("Epoch")
    plt.ylabel("ROUGE-L")
    plt.grid(True)

    plt.tight_layout()
    plt.show()


In [15]:
import torch
import torch.nn as nn
import torch.optim as optim

# 🔧 Model hyperparameters
embed_dim = 256
attention_dim = 256
decoder_dim = 512
vocab_size = tokenizer.vocab_size
dropout = 0.5

# 🧠 Model initialization
decoder = DecoderWithAttention(
    attention_dim=attention_dim,
    embed_dim=embed_dim,
    decoder_dim=decoder_dim,
    vocab_size=vocab_size,
    encoder_dim=2048,  # ← depends on the upstream backbone (ResNet global features: 2048)
    dropout=dropout
)

# ⚙️ Optimizer and loss (padded positions are ignored)
optimizer = optim.Adam(decoder.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

# 🚀 Launch training
train_losses, val_bleus, val_rouges = train_model(
    decoder=decoder,
    train_loader=train_loader,
    val_loader=val_loader,
    tokenizer=tokenizer,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=5,      # ← can be increased
    patience=3         # ← early stopping if validation BLEU-3 stops improving
)


🚀 Training started on cuda


📆 Epochs:   0%|          | 0/5 [00:00<?, ?it/s]
🧠 Training Epoch 1:   0%|          | 5/2023 [00:01<11:59,  2.80it/s, loss=6.46]

KeyboardInterrupt: 