## Imports

In [69]:
from tensorflow.keras.applications.resnet50 import preprocess_input
from tensorflow.keras.applications import ResNet50
from tensorflow.keras import layers, models, losses
from tensorflow.keras.models import Model
from tensorflow.keras.utils import plot_model
from tensorflow.keras.utils import Sequence
from pycocotools.coco import COCO
from tensorflow.keras.preprocessing.sequence import pad_sequences
import os
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import nltk
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from gensim.models import Word2Vec

### Constantes et variables globales

In [None]:
ANNOTDIR = 'annotations_trainval2014'
DATADIR = 'train2014'
CAPFILE = '{}/annotations/captions_{}.json'.format(ANNOTDIR, DATADIR)
INSTANCEFILE = '{}/annotations/instances_{}.json'.format(ANNOTDIR, DATADIR)
ALLOW_STOPWORD = True  # when False, English stop words are removed during tokenisation
WORD2VEC_PATH = 'word2vec_captions.model'
# MAX_LEN_SEQUENCE is currently unused; an earlier note reported 57 as the maximum observed in the data.
TEXT_VECTOR_SIZE = 100  # Word2Vec embedding dimension (earlier note: VOCAB_SIZE = 24918)
IMAGE_VECTOR_SIZE = 128  # size of the dense image feature vector in caption_model
START_TOKEN = '<sos>'
END_TOKEN = '<eos>'
BATCH_SIZE = 32
EPOCHS = 200
RATIO_TRAIN = 0.8
RATIO_VAL = 0.15
RATIO_TEST = 0.05
SEED = 42

# The notebook samples random images and captions and shuffles splits through
# np.random.*; seeding the global RNG makes a Restart & Run All reproducible.
np.random.seed(SEED)

# Floating-point sums are not exact (0.8 + 0.15 == 0.9500000000000001), so compare with a tolerance.
assert np.isclose(RATIO_TRAIN + RATIO_VAL + RATIO_TEST, 1.0), \
    "RATIO_TRAIN + RATIO_VAL + RATIO_TEST must sum to 1"

coco_captions = COCO(CAPFILE)
coco_instances = COCO(INSTANCEFILE)

if os.path.exists(WORD2VEC_PATH):
    wordvec = Word2Vec.load(WORD2VEC_PATH)
    # Also define vocab_size here, so later cells work without running the training cell.
    vocab_size = len(wordvec.wv.key_to_index)
    print(f"Model loaded from {WORD2VEC_PATH}")
else:
    print(f"No model found at {WORD2VEC_PATH}")


## Dataset

### Visualisation des données

In [None]:

# Pick a random image id from the instances index
imgIds = coco_instances.getImgIds()
print(' Number of images found in instances :', len(imgIds))
randomImgId = np.random.choice(imgIds)
found_img = coco_instances.imgs[randomImgId]
file_name = found_img['file_name']
print(f" Filename : {file_name}")

# The context manager closes the underlying file once the image has been drawn
with Image.open(f'{DATADIR}/{file_name}') as image:
    plt.imshow(image)
    plt.title(f"Image id {randomImgId}")
    plt.axis('off')  # axes are not useful when displaying an image
    plt.show()

# Caption annotation ids of the selected image
annIds = coco_captions.getAnnIds(imgIds=randomImgId)
# Load the caption annotations
anns = coco_captions.loadAnns(annIds)
# Print every caption
print("Captions for the selected image:")
for ann in anns:
    print(f"- {ann['caption']}")

### Word2Vec

#### Prétraitement pour Word2Vec

In [None]:
# Caption tokenisation used to build the Word2Vec training corpus
def process_text(text):
    """Tokenise a caption for Word2Vec.

    The text is lowercased and split with nltk ``word_tokenize``. Punctuation
    tokens are kept; nothing is filtered out by character class. When
    ALLOW_STOPWORD is False, English stop words are dropped. START_TOKEN and
    END_TOKEN are then placed around the tokens.

    Args:
        text: raw caption string.

    Returns:
        list[str]: ``[START_TOKEN, *tokens, END_TOKEN]``.
    """
    tokens = word_tokenize(text.lower())
    if not ALLOW_STOPWORD:
        # Load the stop-word list once per call and use a set for O(1) membership,
        # instead of re-reading and list-scanning it for every token.
        stop_words = set(stopwords.words('english'))
        tokens = [w for w in tokens if w not in stop_words]
    return [START_TOKEN, *tokens, END_TOKEN]

# Tokenise every caption of every image. raw_captions is a flat list holding one
# token list per caption; it is the Word2Vec training corpus.
count_captions = 0            # number of images visited
count_invidual_captions = 0   # total number of captions
max_captions = 0              # largest number of captions attached to a single image
raw_captions = []
for img_id in imgIds:
    caption_ids = coco_captions.getAnnIds(imgIds=img_id)
    captions_data = coco_captions.loadAnns(caption_ids)
    captions = [process_text(caption['caption']) for caption in captions_data]
    count_invidual_captions += len(captions)
    count_captions += 1
    max_captions = max(max_captions, len(captions))
    raw_captions.extend(captions)

# Longest caption in tokens, START_TOKEN / END_TOKEN included
max_len_captions = max((len(tokens) for tokens in raw_captions), default=0)
print('Attention, statistique avec captions altérés (ajout des tokens de début et de fin)')
print('count_captions :', count_captions)
print('count_invidual_captions :', count_invidual_captions)
if count_captions:
    print(' mean number of caption per image :', count_invidual_captions / count_captions)
print(' max number of captions :', max_captions)
print(' max caption length (tokens) :', max_len_captions)

#### Entraînement de Word2Vec

In [11]:
# Train a skip-gram (sg=1) Word2Vec model on the tokenised captions;
# min_count=1 keeps every token, including START_TOKEN / END_TOKEN.
wordvec = Word2Vec(
    raw_captions,
    sg=1,
    vector_size=TEXT_VECTOR_SIZE,
    window=4,
    min_count=1,
    epochs=100,
    workers=3,
)

# Vocabulary size = number of distinct tokens known to the model
vocab_size = len(wordvec.wv.key_to_index)
print(f"Nombre de mots dans le vocabulaire : {vocab_size}")

#### Sauvegarde de Word2Vec

In [12]:
# Persist the trained Word2Vec model to WORD2VEC_PATH; the config and load cells
# read it back from that path instead of retraining.
wordvec.save(WORD2VEC_PATH)

#### Chargement de Word2Vec

In [5]:
wordvec = Word2Vec.load(WORD2VEC_PATH)
# vocab_size was otherwise only defined by the training cell; test_embedding relies on it.
vocab_size = len(wordvec.wv.key_to_index)

#### Test unitaire de Word2Vec

In [None]:
# Interactive check: list the tokens closest to a user-supplied word.
# The vocabulary was built from lowercased captions, so normalise the input the same way.
# START_TOKEN / END_TOKEN can be queried too: they were added to every training caption.
word = input('Quel mot souhaitez-vous avoir de similaire ? :').strip().lower()

if word in wordvec.wv.key_to_index:
    similar_words = wordvec.wv.most_similar(word)  # gensim default: top 10 by cosine similarity
    print("Mots similaires à '{}':".format(word))
    for similar_word, similarity in similar_words:
        print(f"{similar_word}: {similarity:.4f}")
else:
    # Unknown word: report it instead of letting most_similar raise a KeyError
    print("Désolé, le mot '{}' n'est pas dans le vocabulaire.".format(word))


### Générateur de données

#### Création des générateurs

In [133]:
class DatasetGenerator(Sequence):
    """Keras ``Sequence`` producing (images, caption indices) batches from COCO.

    The image ids returned by ``coco_instances.getImgIds()`` are split by position
    into train / val / test according to RATIO_TRAIN, RATIO_VAL and RATIO_TEST.
    Each batch holds up to BATCH_SIZE images and, for each image, one randomly
    chosen caption. Images are converted to RGB, resized to IMAGE_SIZE and passed
    through ResNet50 ``preprocess_input``. Each caption is encoded as Word2Vec
    vocabulary indices and right-padded with 0 to the longest caption of the batch.

    NOTE(review): 0 is the padding value but is also a valid index in
    ``wordvec.wv.index_to_key`` (presumably the most frequent token, since gensim
    sorts its vocabulary by frequency), so padding cannot be distinguished from
    that word. Fixing this needs an index shift agreed with the decoding and
    embedding code.
    NOTE(review): ``caption_model`` takes embedded text of shape
    (batch, T, TEXT_VECTOR_SIZE), while this class yields integer indices, and
    ``fit`` would read the returned 2-tuple as (x, y). Confirm how the two are
    meant to be wired together.
    """

    IMAGE_SIZE = (224, 224)  # (width, height) given to PIL resize; ResNet50 input used by caption_model

    def _getsplit(self, ensemble):
        """Return the ``[start, stop)`` bounds of split `ensemble` within ``self.imgIds``.

        Raises:
            ValueError: if `ensemble` is not 'train', 'val' or 'test'.
        """
        n_images = len(self.imgIds)
        train_end = int(RATIO_TRAIN * n_images)
        val_end = int((RATIO_TRAIN + RATIO_VAL) * n_images)
        bounds = {
            'train': (0, train_end),
            'val': (train_end, val_end),
            'test': (val_end, n_images),
        }
        if ensemble not in bounds:
            raise ValueError(f"ensemble must be one of {tuple(bounds)}, got {ensemble!r}")
        return bounds[ensemble]

    def _clean_text(self, text):
        """Tokenise a caption the same way as the module-level ``process_text``.

        Steps: lowercase, nltk ``word_tokenize`` (punctuation tokens kept), drop
        English stop words when ALLOW_STOPWORD is False, then wrap the tokens
        with START_TOKEN / END_TOKEN.
        """
        tokens = word_tokenize(text.lower())
        if not ALLOW_STOPWORD:
            # Load the stop-word list once per caption as a set, not once per token as a list
            stop_words = set(stopwords.words('english'))
            tokens = [w for w in tokens if w not in stop_words]
        return [START_TOKEN, *tokens, END_TOKEN]

    def __init__(self, ensemble):
        """Build the generator for split `ensemble` ('train', 'val' or 'test')."""
        # Keras 3's PyDataset (aliased as Sequence there) expects its constructor to run;
        # with older Keras this resolves to object.__init__ and does nothing.
        super().__init__()
        self.ensemble = ensemble
        # All image ids, in the order returned by the instances index
        self.imgIds = coco_instances.getImgIds()
        start, stop = self._getsplit(ensemble)
        self.ids = self.imgIds[start:stop]
        # Caption annotation ids of every image in the split, looked up once
        self.captions_ids = {img_id: coco_captions.getAnnIds(imgIds=img_id) for img_id in self.ids}

    def __len__(self):
        """Number of batches per epoch; the last batch may be smaller."""
        return int(np.ceil(len(self.ids) / BATCH_SIZE))

    def __getitem__(self, index):
        """Return batch `index` as ``(images, captions)``.

        images: ``preprocess_input`` output for an array of shape
            (batch, 224, 224, 3) with the default IMAGE_SIZE.
        captions: integer array (batch, longest caption in the batch), right-padded with 0.

        Raises:
            IndexError: if `index` is outside ``[0, len(self))``. Out-of-range or
                negative indices would otherwise silently produce an empty batch.
        """
        if not 0 <= index < len(self):
            raise IndexError(f"batch index {index} out of range [0, {len(self)})")
        batch_ids = self.ids[index * BATCH_SIZE:(index + 1) * BATCH_SIZE]
        vocab = wordvec.wv.key_to_index
        batch_images = []
        batch_captions = []
        for img_id in batch_ids:
            file_name = coco_instances.imgs[img_id]['file_name']
            # Convert to RGB before resizing: grayscale or palette images are then
            # resampled as RGB and always give 3 channels. `with` closes the file handle.
            with Image.open(f'{DATADIR}/{file_name}') as image:
                batch_images.append(np.array(image.convert('RGB').resize(self.IMAGE_SIZE)))
            # Pick one of the image's captions at random. The caption is read by direct
            # dict access (original author's note: the COCO API call was unreliable here).
            chosen_id = np.random.choice(self.captions_ids[img_id])
            tokens = self._clean_text(coco_captions.anns[chosen_id]['caption'])
            # Tokens missing from the Word2Vec vocabulary are dropped
            batch_captions.append([vocab[token] for token in tokens if token in vocab])
        batch_images = preprocess_input(np.array(batch_images))
        # With maxlen=None, pad_sequences pads to the longest caption in the batch
        batch_captions = pad_sequences(batch_captions, padding='post')
        return batch_images, batch_captions

    def on_epoch_end(self):
        """Shuffle the image order between epochs (applies to every split)."""
        self.ids = np.random.permutation(self.ids)

# One generator per split; each one holds its own slice of the image ids.
train_generator, val_generator, test_generator = (
    DatasetGenerator(split) for split in ('train', 'val', 'test')
)


#### Test unitaire du générateur de données

In [None]:
generator = train_generator
# Fetch one random batch; valid batch indices are 0 .. len(generator) - 1
batch_index = np.random.randint(len(generator))
images, captions = generator[batch_index]
print(f"Images shape: {images.shape}")
print(f"Captions shape: {captions.shape}")

# Pick one sample of the batch to display
sample_index = np.random.randint(0, images.shape[0])
selected_image = images[sample_index]
selected_caption = captions[sample_index]

# Invert ResNet50 preprocess_input ("caffe" mode: RGB -> BGR, then subtraction of the
# ImageNet per-channel BGR mean) so the image is shown with its original colours.
imagenet_bgr_mean = np.array([103.939, 116.779, 123.68])
selected_image = np.clip(selected_image + imagenet_bgr_mean, 0, 255).astype('uint8')[..., ::-1]

# Map indices back to words. Stop at END_TOKEN: what follows is padding, and padding
# cannot be filtered by value because 0 is also a real vocabulary index.
index_to_key = wordvec.wv.index_to_key
selected_caption_words = []
for token_index in selected_caption:
    token = index_to_key[token_index]
    if token == END_TOKEN:
        break
    if token != START_TOKEN:
        selected_caption_words.append(token)
selected_caption_str = ' '.join(selected_caption_words)

# Display the image with its caption
fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(selected_image)
ax.set_title(f"Caption: {selected_caption_str}")
ax.axis('off')  # axes add nothing to an image display
plt.show()


## Modèle

### Test de la couche d'embedding

In [None]:
def test_embedding(wordvec, word):
    """Check that a frozen Keras Embedding built from Word2Vec reproduces its vectors.

    Args:
        wordvec: trained gensim Word2Vec model.
        word: token to look up.

    Returns:
        bool | None: the result of ``np.allclose`` between the gensim vector and
        the Keras embedding output, or None when `word` is not in the vocabulary.
    """
    if word not in wordvec.wv.key_to_index:
        print(f"Le mot '{word}' n'est pas dans le vocabulaire.")
        return None
    word_index = wordvec.wv.key_to_index[word]
    embedding = wordvec.wv.get_vector(word)
    print(f"Index du mot '{word}' dans le vocabulaire : {word_index}")
    print(f"Embedding du mot '{word}' (wordvec) : {embedding[0:5]}..")

    # Size the layer from the model passed in rather than the global vocab_size,
    # which exists only if a particular earlier cell has been run.
    vectors = wordvec.wv.vectors
    model = models.Sequential()
    model.add(layers.Embedding(input_dim=vectors.shape[0], output_dim=vectors.shape[1],
                               weights=[vectors], trainable=False))
    # Look up the word through the Keras layer
    embedded_word = model.predict(np.array([[word_index]]))
    print(f"Embedding du mot '{word}' (calculé) : {embedded_word[0][0][0:5]}..")

    same = np.allclose(embedding, embedded_word[0][0])
    print("Les deux embeddings sont-ils égaux ? :", same)
    return same

# The vocabulary is lowercased, so normalise the typed word before the lookup
test_embedding(wordvec, input('Quel mot souhaitez-vous tester ? :').strip().lower())

In [None]:

def caption_model():
    '''Build and compile the image + partial-caption -> next-word-vector model.

    inputs :
        image : (batch_size, 224, 224, 3), ResNet50-preprocessed images
        text : (batch_size, None, TEXT_VECTOR_SIZE), already-embedded tokens;
               time steps whose features are all 0.0 are masked out
    outputs :
        output : (batch_size, TEXT_VECTOR_SIZE), predicted word vector, trained
                 with a cosine-similarity loss
    '''
    # Image branch: ResNet50 backbone with ImageNet weights and no classifier head
    base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
    x = layers.Flatten()(base_model.output)
    image_data = layers.Dense(IMAGE_VECTOR_SIZE)(x)  # projected image features

    # Text branch
    text_input = layers.Input(shape=(None, TEXT_VECTOR_SIZE))
    # Masking makes the LSTM skip all-zero (padded) time steps. This is essential
    # for variable-length captions.
    x = layers.Masking(mask_value=0.0)(text_input)
    text_data = layers.LSTM(512)(x)

    context = layers.Concatenate()([image_data, text_data])
    # Linear output: Word2Vec target vectors are signed, and a ReLU output
    # (non-negative only) could not point in most of their directions.
    output = layers.Dense(TEXT_VECTOR_SIZE)(context)

    model = Model(inputs=[base_model.input, text_input], outputs=[output], name='caption_model')
    # Keras CosineSimilarity loss is the negated cosine similarity, so minimising it aligns
    # the vectors. 'accuracy' is meaningless for continuous vector targets, so track
    # cosine similarity instead.
    model.compile(loss=losses.CosineSimilarity(), optimizer='adam', metrics=['cosine_similarity'])
    return model

# Build and compile the captioning model, then print its layer / parameter summary
model = caption_model()
model.summary()

In [None]:
# Save a diagram of the model graph (with tensor shapes and layer names) to caption_model.png.
# NOTE(review): plot_model needs pydot and Graphviz installed — confirm in the environment.
plot_model(model, to_file='caption_model.png', show_shapes=True, show_layer_names=True)