# 📘 Notebook Plan - `04_dataset_and_dataloader.ipynb`

| Step | Objective |
| ---- | --------- |
| 1    | 📁 Load the `.pt` features and the aligned captions (`.token.txt`) |
| 2    | 🔄 Build the enriched list of `(feature_path, caption)` pairs |
| 3    | 🧱 Define an `ImageCaptionDataset` class (inherits from `torch.utils.data.Dataset`) |
| 4    | 🧩 Create a custom `collate_fn` (dynamic padding of the captions) |
| 5    | 📦 Create the training/test/validation `DataLoader`s (with `torch.utils.data.random_split`) |
| 6    | 🧪 Visualize a batch (image\_id, tokenized caption, tensor shapes) |
| 7    | ✅ Check dimensions, vocabulary size, etc. |


### 🧠 Built-in notes:
- ✅ All captions (5 per image) are associated with every image (original + augmented).

- ✅ This yields more samples while keeping the dataset simple to handle.

- ✅ The `collate_fn` will use the tokenizer (loaded from `tokenizer.pkl`) to encode and pad the captions.

In [1]:
# 📦 Imports
from pathlib import Path
from collections import defaultdict
import torch
import pickle

# 📁 Paths
features_dir = Path("../data/processed/features_resnet_global")
captions_file = "../data/raw/Flickr8k_text/Flickr8k.token.txt"
tokenizer_path = "../data/vocab/tokenizer.pkl"

# ✅ Load the available features
feature_files = list(features_dir.glob("*.pt"))
# Base image ids (augmented copies share the id of their source image).
# A set keeps the membership test in the loop below O(1).
feature_ids = {f.stem.split("_aug")[0] for f in feature_files}

print(f"🧠 Number of features loaded: {len(feature_files)}")

# 📖 Build the dictionary {image_id: [captions]}
captions_dict = defaultdict(list)
with open(captions_file, "r") as f:
    for line in f:
        try:
            # Each line has the form "<image>.jpg#<n>\t<caption>"
            image_tag, caption = line.strip().split("\t")
            image_id = image_tag.split("#")[0].split(".")[0]
            if image_id in feature_ids:
                captions_dict[image_id].append(caption.strip())
        except ValueError:
            # Raised when the line does not split into exactly two fields
            print(f"⛔ Malformed line: {line!r}")

print(f"✅ Captions aligned for {len(captions_dict)} images")


🧠 Number of features loaded: 32364
✅ Captions aligned for 8091 images
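
The parsing step above can be tried without the dataset on disk. The sketch below runs the same split logic on a few in-memory lines. The sample lines, the ids, and the helper name `parse_caption_line` are made up for illustration and are not part of the notebook.

```python
from collections import defaultdict


def parse_caption_line(line):
    """Return (image_id, caption), or None if the line is malformed."""
    parts = line.rstrip("\n").split("\t", 1)
    if len(parts) != 2:
        return None
    image_tag, caption = parts
    # "img_001.jpg#0" -> "img_001"
    image_id = image_tag.split("#")[0].split(".")[0]
    return image_id, caption.strip()


# Illustrative lines (invented for this example, not taken from Flickr8k)
sample_lines = [
    "img_001.jpg#0\tA dog runs on the grass .\n",
    "img_001.jpg#1\tA brown dog is running .\n",
    "img_002.jpg#0\tTwo kids play football .\n",
    "this line has no tab\n",
]
available_ids = {"img_001"}  # ids for which a feature file is assumed to exist

captions = defaultdict(list)
skipped = 0
for line in sample_lines:
    parsed = parse_caption_line(line)
    if parsed is None:
        skipped += 1
        continue
    image_id, caption = parsed
    if image_id in available_ids:
        captions[image_id].append(caption)

print(dict(captions))
print("skipped:", skipped)
```

Only captions whose image id has a matching feature file are kept, which is why `img_002` does not appear in the result.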


## 🧩 Step 2 – Building the (features, caption) pairs

### 🎯 Objective:
Associate each `.pt` feature with **a single caption** (picked at random among the 5 available) to form a list of `(feature_path, caption)` pairs usable by the `Dataset`.

### 🔧 Why this choice:
- Each image (and its augmented versions) gets **only one caption**, which avoids duplicates (otherwise the features would be duplicated needlessly).

- More captions can be added later (e.g. `dataset.expand_with_5_captions()`), but a single one is enough to get started.
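
As a rough illustration of the "more captions later" option, the sketch below pairs every feature file with every caption of its source image. The function name `expand_with_all_captions` and the file names are hypothetical. The only thing taken from the notebook is the convention that augmented files contain `_aug` in their stem.

```python
from pathlib import Path


def expand_with_all_captions(feature_files, captions_dict):
    """Pair every feature file with every caption of its source image."""
    pairs = []
    for feature_file in feature_files:
        feature_file = Path(feature_file)
        image_id = feature_file.stem.split("_aug")[0]
        for caption in captions_dict.get(image_id, []):
            pairs.append((feature_file, caption))
    return pairs


# Hypothetical file names; the exact augmentation suffix is an assumption
feature_files = [Path("img_001.pt"), Path("img_001_aug1.pt"), Path("img_002.pt")]
captions_dict = {"img_001": ["a dog runs", "a dog plays"]}

pairs = expand_with_all_captions(feature_files, captions_dict)
print(len(pairs))  # 2 files x 2 captions = 4 pairs (img_002 has no caption)
```

With 5 captions per image this multiplies the number of samples by 5 compared with the single-caption list, at the cost of reloading the same feature tensor more often.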

In [2]:
import random

# 📄 Build the list of (feature_path, caption) pairs
dataset_pairs = []

for feature_file in feature_files:
    image_id = feature_file.stem.split("_aug")[0]

    if image_id in captions_dict:
        # 🎯 Random caption among the 5 available
        caption = random.choice(captions_dict[image_id])
        dataset_pairs.append((feature_file, caption))

print(f"✅ (feature, caption) pairs built: {len(dataset_pairs)}")


✅ (feature, caption) pairs built: 32364


## 🧠 Step 3 – The `ImageCaptionDataset` class

### 🎯 Objective:
Create a class inheriting from `torch.utils.data.Dataset` that:

- loads the `.pt` features (already extracted),

- encodes the associated caption (with the tokenizer),

- returns a pair `(features_tensor, encoded_caption_tensor)`.

In [3]:
import re
from pathlib import Path

import torch
from torch.utils.data import Dataset

class ImageCaptionDataset(Dataset):
    def __init__(self, features_dir, captions_dict, tokenizer, max_length=37):
        self.features_dir = Path(features_dir)
        self.captions = captions_dict  # dict[image_id] = caption (a single string)
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.image_ids = list(self.captions.keys())

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        image_id = self.image_ids[idx]
        feature_path = self.features_dir / f"{image_id}.pt"
        features = torch.load(feature_path)

        caption = self.captions[image_id]

        # ✨ Clean the caption (keep lowercase letters and spaces only)
        caption = caption.lower()
        caption = re.sub(r"[^a-z ]", "", caption)

        # ✨ Encode
        encoded = self.tokenizer.encode(caption)
        encoded = encoded[:self.max_length]  # 🔪 Truncate if too long
        # Right-pad to max_length so that every sample has the same shape
        # and the batch can be built with torch.stack
        encoded += [self.tokenizer.pad_token_id] * (self.max_length - len(encoded))

        return features, torch.tensor(encoded, dtype=torch.long)



### 📌 Notes:
- `max_length` defaults to 37 (the value obtained when running `03_vocab_building.ipynb`).

- Padding on the right keeps the handling in `collate_fn` simpler.

## 🧱 Step 4 – Custom `collate_fn`

### 🎯 Objective:
Create a `collate_fn` that the `DataLoader` will use to:

- stack the features cleanly (`[batch_size, 2048]`)

- stack the already-padded text sequences (`[batch_size, max_len]`)

- return a batch `(X, y)` ready to be processed by the model

In [4]:
def collate_fn(batch):
    """
    Batch = list of tuples: (features, encoded_caption)
    """
    features_batch = torch.stack([item[0] for item in batch])          # (B, 2048)
    captions_batch = torch.stack([item[1] for item in batch])          # (B, max_len)

    return features_batch, captions_batch


### 🧾 Step 5 – Building the `Dataset` and the `DataLoader`

#### 📦 1. New `ImageCaptionDataset` `Dataset`

In [5]:
from torch.utils.data import Dataset
from pathlib import Path
import torch
import re

def clean_caption(caption):
    caption = caption.lower()
    # Keep apostrophes so that contractions survive (it's, don't);
    # drop everything except lowercase letters, digits, apostrophes and whitespace
    caption = re.sub(r"[^a-z0-9'\s]", "", caption)
    caption = re.sub(r"\s+", " ", caption)  # collapse runs of whitespace
    return caption.strip()

class ImageCaptionDataset(Dataset):
    def __init__(self, features_dir, captions_dict, tokenizer, max_length=37):
        self.features_dir = Path(features_dir)
        self.captions_dict = captions_dict
        self.tokenizer = tokenizer
        self.max_length = max_length

        self.samples = []
        for image_id, captions in captions_dict.items():
            for caption in captions:
                self.samples.append((image_id, caption))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        image_id, caption = self.samples[idx]

        # Clean the caption before encoding
        caption = clean_caption(caption)

        # Load the features (tensor of shape [2048])
        feature_path = self.features_dir / f"{image_id}.pt"
        features = torch.load(feature_path)

        # Encode the caption
        encoded = self.tokenizer.encode(caption)
        encoded = encoded[:self.max_length]  # 🔪 Truncate if too long

        return features, torch.tensor(encoded, dtype=torch.long)


#### 🔧 2. `collate_fn` for dynamic padding

In [6]:
def collate_fn(batch):
    features, captions = zip(*batch)

    # Stack the features -> [batch, 2048]
    features = torch.stack(features)

    # Right-pad every caption to the longest one in the batch
    lengths = [len(c) for c in captions]
    max_len = max(lengths)
    # torch.zeros fills with 0, which assumes that the <pad> token id is 0
    padded_captions = torch.zeros(len(captions), max_len, dtype=torch.long)

    for i, cap in enumerate(captions):
        padded_captions[i, :len(cap)] = cap

    return features, padded_captions, torch.tensor(lengths)
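
The same dynamic padding can also be written with `torch.nn.utils.rnn.pad_sequence`. The sketch below is one possible variant, not the notebook's implementation. The factory `make_collate_fn` and its `pad_id` parameter are assumptions, introduced so that the padding value does not have to be hard-coded to 0.

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def make_collate_fn(pad_id=0):
    """Build a collate_fn that right-pads captions with `pad_id`."""
    def collate(batch):
        features, captions = zip(*batch)
        lengths = torch.tensor([len(c) for c in captions])
        # batch_first=True gives a [batch, max_len] tensor
        padded = pad_sequence(list(captions), batch_first=True, padding_value=pad_id)
        return torch.stack(features), padded, lengths
    return collate


# Dummy batch: two feature vectors with captions of length 3 and 5
batch = [
    (torch.zeros(2048), torch.tensor([1, 7, 2])),
    (torch.ones(2048), torch.tensor([1, 4, 9, 5, 2])),
]
features, padded, lengths = make_collate_fn(pad_id=0)(batch)
print(features.shape, padded.shape, lengths.tolist())
```

In the notebook the factory could be called as `make_collate_fn(tokenizer.pad_token_id)` and the result passed to the `DataLoader` through its `collate_fn` argument.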


#### 🚀 3. Creating the DataLoader

In [7]:
from collections import defaultdict
from pathlib import Path

captions_file = "../data/raw/Flickr8k_text/Flickr8k.token.txt"
features_dir = Path("../data/processed/features_resnet_global")

extracted_ids = {f.stem.split("_aug")[0] for f in features_dir.glob("*.pt")}

aligned_captions = defaultdict(list)

with open(captions_file, "r") as f:
    for line in f:
        try:
            image_tag, caption = line.strip().split('\t')
            image_id = image_tag.split('#')[0].split('.')[0]
            if image_id in extracted_ids:
                aligned_captions[image_id].append(caption.strip())
        except ValueError:
            # Skip lines that do not split into exactly (image_tag, caption)
            continue


In [8]:
# 🧠 Tokenizer class copied from 03_ (pickle needs the class defined here to load the object)
class Tokenizer:
    def __init__(self, word2idx):
        self.word2idx = word2idx
        self.idx2word = {idx: word for word, idx in word2idx.items()}
        self.pad_token = "<pad>"
        self.start_token = "<start>"
        self.end_token = "<end>"
        self.unk_token = "<unk>"
        self.pad_token_id = self.word2idx[self.pad_token]
        self.start_token_id = self.word2idx[self.start_token]
        self.end_token_id = self.word2idx[self.end_token]
        self.unk_token_id = self.word2idx[self.unk_token]
        self.vocab_size = len(self.word2idx)

    def encode(self, caption, add_special_tokens=True):
        tokens = caption.strip().split()
        token_ids = [self.word2idx.get(token, self.unk_token_id) for token in tokens]
        if add_special_tokens:
            return [self.start_token_id] + token_ids + [self.end_token_id]
        return token_ids

    def decode(self, token_ids, remove_special_tokens=True):
        words = [self.idx2word.get(idx, self.unk_token) for idx in token_ids]
        if remove_special_tokens:
            words = [w for w in words if w not in [self.pad_token, self.start_token, self.end_token]]
        return " ".join(words)

# ✅ Load the pickled object
import pickle

with open("../data/vocab/tokenizer.pkl", "rb") as f:
    tokenizer = pickle.load(f)

print("🧠 Tokenizer loaded:", type(tokenizer))

🧠 Tokenizer loaded: <class '__main__.Tokenizer'>


In [9]:
from torch.utils.data import DataLoader

dataset = ImageCaptionDataset(
    features_dir="../data/processed/features_resnet_global",
    captions_dict=aligned_captions,
    tokenizer=tokenizer,
    max_length=37
)

dataloader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=collate_fn
)

print(f"✅ DataLoader ready with {len(dataloader)} batches")


✅ DataLoader ready with 1265 batches
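
The plan at the top of this notebook lists training/validation/test loaders built with `torch.utils.data.random_split`. A minimal sketch of that step is shown here on a stand-in dataset. The 80/10/10 proportions, the seed, and the `TensorDataset` stand-in are assumptions chosen for the example.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Stand-in dataset of 100 samples; in the notebook this would be `dataset`
stand_in = TensorDataset(torch.randn(100, 2048), torch.randint(0, 50, (100, 10)))

# Assumed 80/10/10 split, computed with integers so the lengths sum to len(stand_in)
n_total = len(stand_in)
n_train = n_total * 8 // 10
n_val = n_total // 10
n_test = n_total - n_train - n_val

# A seeded generator makes the split reproducible from one run to the next
generator = torch.Generator().manual_seed(42)
train_set, val_set, test_set = random_split(
    stand_in, [n_train, n_val, n_test], generator=generator
)

train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
val_loader = DataLoader(val_set, batch_size=32, shuffle=False)
test_loader = DataLoader(test_set, batch_size=32, shuffle=False)

print(len(train_set), len(val_set), len(test_set))
```

With the real dataset, each `DataLoader` would also receive `collate_fn=collate_fn`. Note that `random_split` works at the sample level, so captions of the same image can end up in different subsets. Splitting the image ids first and building one dataset per subset would likely avoid that overlap.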


#### 👀 4. (Optional) Visualizing a batch

In [10]:
test_caption = "a man is climbing"
print("Encoded :", tokenizer.encode(test_caption))
print("Decoded :", tokenizer.decode(tokenizer.encode(test_caption)))


Encoded : [1, 4, 12, 8, 120, 2]
Decoded : a man is climbing
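
To see how out-of-vocabulary words behave in the same round trip, here is a small self-contained sketch with a toy vocabulary. The ids below are illustrative and are not the ones stored in `tokenizer.pkl`. The `encode` and `decode` functions mirror the logic of the `Tokenizer` class above in simplified form.

```python
# Toy vocabulary (illustrative ids, not those of the real tokenizer)
word2idx = {"<pad>": 0, "<start>": 1, "<end>": 2, "<unk>": 3,
            "a": 4, "man": 5, "is": 6, "climbing": 7}
idx2word = {idx: word for word, idx in word2idx.items()}
SPECIAL = {"<pad>", "<start>", "<end>"}


def encode(caption):
    # Unknown words fall back to the <unk> id
    ids = [word2idx.get(tok, word2idx["<unk>"]) for tok in caption.split()]
    return [word2idx["<start>"]] + ids + [word2idx["<end>"]]


def decode(token_ids):
    words = [idx2word.get(i, "<unk>") for i in token_ids]
    return " ".join(w for w in words if w not in SPECIAL)


ids = encode("a man is surfing")
print(ids)          # [1, 4, 5, 6, 3, 2]
print(decode(ids))  # a man is <unk>
```

The word `surfing` is absent from the toy vocabulary, so it is encoded as `<unk>` and cannot be recovered when decoding.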


In [11]:
import random

# 🧠 Load one batch
features_batch, captions_batch, lengths = next(iter(dataloader))

print("✅ Batch loaded")
print("Features shape :", features_batch.shape)  # [batch_size, 2048]
print("Captions shape :", captions_batch.shape)  # [batch_size, max_seq_len]

# 🔁 Display 5 randomly chosen examples (drawn with replacement)
for i in range(5):
    idx = random.randint(0, len(captions_batch) - 1)
    caption_ids = captions_batch[idx][:lengths[idx]].tolist()
    decoded = tokenizer.decode(caption_ids)

    print(f"\n🖼️ Sample {i+1}:")
    print(f"  ➤ Vector shape : {features_batch[idx].shape}")
    print(f"  ➤ Caption (decoded) : {decoded}")


✅ Batch loaded
Features shape : torch.Size([32, 2048])
Captions shape : torch.Size([32, 24])

🖼️ Sample 1:
  ➤ Vector shape : torch.Size([2048])
  ➤ Caption (decoded) : a white and brown spotted dog runs along the snow to catch a ball

🖼️ Sample 2:
  ➤ Vector shape : torch.Size([2048])
  ➤ Caption (decoded) : two small dogs run across the green grass

🖼️ Sample 3:
  ➤ Vector shape : torch.Size([2048])
  ➤ Caption (decoded) : a brown dog runs in the grass with one ear up

🖼️ Sample 4:
  ➤ Vector shape : torch.Size([2048])
  ➤ Caption (decoded) : a tan dog is standing in front of some plants

🖼️ Sample 5:
  ➤ Vector shape : torch.Size([2048])
  ➤ Caption (decoded) : man with no shirt and <unk> on back airborne with skateboard in hand
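
The last step of the plan (checking dimensions and vocabulary size) could be sketched as a small assertion helper. The function name `check_batch`, its parameters, and the dummy tensors are assumptions made for this example. The expected feature dimension of 2048 comes from the shapes printed above.

```python
import torch


def check_batch(features, captions, lengths, vocab_size, pad_id=0, feature_dim=2048):
    """Raise AssertionError if a collated batch is inconsistent."""
    batch_size = features.shape[0]
    assert tuple(features.shape) == (batch_size, feature_dim)
    assert captions.shape[0] == batch_size and lengths.shape[0] == batch_size
    assert captions.dtype == torch.long
    # The batch is padded to its longest caption
    assert captions.shape[1] == int(lengths.max())
    # Every token id lies inside the vocabulary
    assert int(captions.min()) >= 0 and int(captions.max()) < vocab_size
    # Only padding is allowed after the true length of each caption
    for row, n in zip(captions, lengths.tolist()):
        assert bool(torch.all(row[n:] == pad_id))
    return True


# Dummy batch standing in for next(iter(dataloader))
features = torch.zeros(2, 2048)
captions = torch.tensor([[1, 5, 2, 0], [1, 6, 7, 2]])
lengths = torch.tensor([3, 4])
print(check_batch(features, captions, lengths, vocab_size=10))
```

On a real batch the call would look like `check_batch(features_batch, captions_batch, lengths, tokenizer.vocab_size, tokenizer.pad_token_id)`.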


In [12]:
original = "It's a beautiful day, isn't it?"
cleaned = clean_caption(original)
print(cleaned)


it's a beautiful day isn't it
