# Projet : Génération de légendes d'images automatique

Etudiant : CHIBAN Amine

Cours : PGM 

# Objectif :  

Créer un modèle capable de générer une légende d'une image quelconque 

# Vu d'ensemble :

- Utiliser un réseaux de neurones à convolution (CNN) pour extraire les features des images
- Utiliser un réseaux de neurones récurents (RNN) pour générer les mots de la description de manière séquentielle
- Ajouter un mécanisme d'attention au réseau de neurones récurents pour améliorer les prédictions 

![title](notebook-img/over.png)

# CNN

Le choix du CNN que nous allons utiliser : VGG16, 16 représente le nombre de couche de notre réseau de neurones à convolution. Une série de 5 bloc (couches de convolution et max pooling) pour l'extraction de features puis 3 couches pour la classification. Comme le montre l'image suivante :
![title](notebook-img/picture.png)
- 3x3 Conv 64 : convolution par 64 filtres de taille 3x3.
- pool/2 : fonction max pooling, pour une fenetre de 2x2. Remplacer une fenetre 2x2 par la valeur du max et donc devient 1x1.
- FC 4096 : Couche de réseau de neurones simple avec 4096 noeuds.

Le temps d'apprentissage du VGG étant très grands pour une base de données suffisante au bon apprentissage des features (plus de 2 jours avec des performances élevées), nous allons utilisée un modéle pré entrainé du VGG disponible sur pytorch. 

Vu que nous n'avons pas besoin de la partie classification, nous allons devoir enveler la dernière brique contenant le réseau de neurones simple (la brique contenant les FC 4096). Nous aurons donc un éxtracteur de features que nous pourons utiliser comme entrées dans la deuxième partie, qui est dédiée à la génération de descriptions à partir des features de l'image.

In [8]:
import torch
from torch import nn
import torchvision
from skimage import io, transform
import matplotlib.pyplot as plt
import numpy as np
import torchvision.transforms as transforms
from skimage.transform import resize

class Encoder(nn.Module):
    """
    Encoder class to transform images into features. We're using the features from the last conv layer.
    So the output will be 14x14x512
    """

    def __init__(self):

        super(Encoder, self).__init__()

        # Imgs must be normalised before being fed to the CNN as per de docs
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        self.transform=transforms.Compose([normalize])
        
        # Using the official VGG16 model pretrained
        vgg = torchvision.models.vgg16(pretrained=True)  
        
        # Remove the classification bloc
        modules = list(vgg.children())[:-1]

        # Remove the last maxpool layer
        modules[0]=nn.Sequential(*list(modules[0].children())[:-1])

        # Save the new model 
        self.vgg = nn.Sequential(*modules)


    def forward(self, images):
        """
        Forward propagation.

        Params:
            - images : tensor of input images (batch_size, 3, 224, 224)
        Return:
            - output : Tensor of features (14,14,512)
        """
        output = self.transform(images.squeeze())
        output = self.vgg(images)  
        output = output.permute(0, 2, 3, 1)  
        return output

# Pre-processing des images

In [1]:
from collections import Counter
import os
import json
import string
from skimage.transform import resize
import skimage.io as io
import numpy as np
import torch

In [3]:
def preprocess_flickr_data():
    """
    This function is used to generate input files from raw data files. It generates 3 types of files:

    Images : we can either load them raw, vector h5py or generate the features using the encoder
    Captions : We sample a number of captions per image then store them as Json
    Caption lenghts : the lenghts of each captions, usefull to know when to stop the process ?!

    The data is also split into train evaluate and test
    """

    min_frequency = 2
    max_cap_len=20
    output_folder='./processed_data'
    caps_per_img=2

    # Loading split IDs
    train_ids=load_doc('./raw_data/Flickr8k_text/Flickr8k.trainImages.txt')
    train_ids=[x.split('.')[0] for x in train_ids.split('\n')]
    eval_ids=load_doc('./raw_data/Flickr8k_text/Flickr8k.devImages.txt')
    eval_ids=[x.split('.')[0] for x in eval_ids.split('\n')]
    test_ids=load_doc('./raw_data/Flickr8k_text/Flickr8k.testImages.txt')
    test_ids=[x.split('.')[0] for x in test_ids.split('\n')]

    # Generating proccessed images then storing them
    for ID in train_ids[:20]:
        image=proccess_image('./raw_data/Flickr8k_data/'+ID+'.jpg')
        torch.save(image, './processed_data/train_images/'+ID+'.pt')
        print('Train image',ID,'Generated')
    for ID in eval_ids[:20]:
        image=proccess_image('./raw_data/Flickr8k_data/'+ID+'.jpg')
        torch.save(image, './processed_data/eval_images/'+ID+'.pt')
        print('Validation Image',ID,'Generated')
    for ID in test_ids[:20]:
        image=proccess_image('./raw_data/Flickr8k_data/'+ID+'.jpg')
        torch.save(image, './processed_data/test_images/'+ID+'.pt')
        print('Test Image',ID,'Generated')

    # Loading captions
    data=load_doc('./raw_data/Flickr8k_text/Flickr8k.token.txt')
    train_captions,eval_captions,test_captions=load_captions(data,train_ids,eval_ids,test_ids,caps_per_img)
    
    # Generating the wordmap then saving it to file
    wordmap=generate_wordmap([train_captions,eval_captions,test_captions],min_frequency)
    with open(os.path.join(output_folder, 'WORDMAP.json'), 'w') as j:
        json.dump(wordmap, j)

    # Process captions then store in file
    train_captions,eval_captions,test_captions=process_captions([train_captions,eval_captions,test_captions],wordmap,max_cap_len)
    with open(os.path.join(output_folder, 'TRAIN_CAPTIONS.json'), 'w') as j:
        json.dump(train_captions, j)
    with open(os.path.join(output_folder, 'EVAL_CAPTIONS.json'), 'w') as j:
        json.dump(eval_captions, j)
    with open(os.path.join(output_folder, 'TEST_CAPTIONS.json'), 'w') as j:
        json.dump(train_captions, j)

In [4]:
def generate_wordmap(splits,min_frequency):
    word_counter=Counter()
    for split in splits:
        for line in split.values():
            for cap in line:
                word_counter.update(cap.split(' '))
    words= [ x for x  in word_counter.keys() if word_counter[x]>min_frequency]
    wordmap = {k: v + 1 for v, k in enumerate(words)}
    wordmap['<unk>'] = len(wordmap) + 1
    wordmap['<start>'] = len(wordmap) + 1
    wordmap['<end>'] = len(wordmap) + 1
    wordmap['<pad>'] = 0    
    return wordmap

In [5]:
def load_captions(data,train_ids,eval_ids,test_ids,caps_per_img):
    table = str.maketrans('', '', string.punctuation)
    train_captions = {}
    eval_captions = {}
    test_captions = {}
    for line in data.split('\n'):
        tokens = line.split()
        image_id, image_cap = tokens[0], tokens[1:]
        image_id = image_id.split('.')[0]
        image_cap=[w.translate(table) for w in image_cap]
        #image_cap=[w for w in image_cap if len(w)>1]
        image_cap = ' '.join(image_cap).lower()
        if image_id in train_ids:
            if image_id not in train_captions.keys():
                train_captions[image_id] = []
            if (len(train_captions[image_id])<=caps_per_img):
                train_captions[image_id].append(image_cap)

        if image_id in eval_ids:
            if image_id not in eval_captions.keys():
                eval_captions[image_id] = []
            if (len(eval_captions[image_id])<=caps_per_img):
                eval_captions[image_id].append(image_cap)

        if image_id in test_ids:
            if image_id not in test_captions.keys():
                test_captions[image_id] = []
            if (len(test_captions[image_id])<=caps_per_img):
                test_captions[image_id].append(image_cap)

    return train_captions,eval_captions,test_captions

In [6]:
def load_doc(filename):
	"""
    Helper function to load a file as a string
    """
	file = open(filename, 'r')
	text = file.read()
	file.close()
	return text

In [None]:
def proccess_image(path):
    img = io.imread(path)
    if len(img.shape) == 2:
        img = img[:, :, np.newaxis]
        img = np.concatenate([img, img, img], axis=2)
    img = resize(img, (256, 256))
    img = img.transpose(2, 0, 1)
    img=torch.Tensor(img)
    img=img.reshape(1,3,img.shape[1],img.shape[2])
    enc=Encoder()
    return enc.forward(img)

Afin de lancer la fonction suivante, il faut avoir le dataset flickr 8k avec les 5 Captions pour chacune des images, déposée dans un dossier row_data suivant le chemin indiqué dans la fonction preprocess_flickr_data().

In [None]:
preprocess_flickr_data()

En vue du temps nécessaire pour traiter toutes les images, et pour l'entrainement du LSTM et le mécanisme d'attention plus loin, j'ai décidé d'utiliser 20 images seulement.
Au niveau du code de la fonction preprocess_flickr_data(), j'ai mis : for ID in train_ids[:20]. Et donc on ne prend que 20 images du dataset d'apprentissage.

Nous aurons donc comme résultats un dossier nomé processed_data contenant : 
- WORDMAP.json : le vocabulaire que notre systeme connait
- TRAIN_CAPTIONS.json : les captions pour les images d'entrainement avec l'indice de chaque mot de le WORDMAP
- EVAL_CAPTIONS.json : les captions pour les images d'évalutation avec l'indice de chaque mot de le WORDMAP
- TEST_CAPTIONS.json : les captions pour les images de test avec l'indice de chaque mot de le WORDMAP
- eval_images : images du dataset d'évaluation encodées
- train_images : images du dataset d'entrainement encodées
- test_images : images du dataset de test encodées

# Mécanisme d'attention 

In [2]:
import torch
from torch import nn
import torchvision
from skimage import io, transform
import matplotlib.pyplot as plt
import numpy as np

class Attention(nn.Module):
    """
    Implements the attention model. 
    It is based on 1 linear layer folowed by tanh then a softmax to compute the weights,
    Then outputs the features*weights , weights.
    """

    def __init__(self, lstm_len, attention_len, features_len=512):
        """
        Params:
            - features_len : size of the images feature vector we're using 512
            - lstm_len : size of the lstm network (it's the size of the hidden and cell states also)
            - attention_len : size of the attention network (Not sure how it affects the learning yet)
        """
        super(Attention, self).__init__()

        # Simple linear layer to project the encoder features
        self.image_layer = nn.Linear(features_len, 1) 
        # Simple linear layer to project the hidden state 
        self.hidden_layer = nn.Linear(lstm_len, 1)  
        # A Tanh layer
        self.tanh = nn.Tanh()
        # Simple linear layer to merge the two vectors
        self.merge_layer = nn.Linear(attention_len, 1)  
        # Softmax layer to compure weights
        self.softmax = nn.Softmax(dim=1)  

    def forward(self, image_features, hidden_state): 
        """
        Forward propagation. Computes the attention weights using the equation

                        Si=tanh(Wc*C+Wx*Xi) , weights=softmax(si)

        Params:
            - image_features : vector of image features (batch_size,14,14,512)
            - hidden_state : vector Ht-1 previous hidden state (batch_size,lstm_len)

        Output:
            - attention_features : weights*features (batchsize,lstm_len)
            - attention_weights : weights computed with softmax (batchsize,nb_image_sections=14*14)
        """
        features = self.image_layer(image_features)  
        hidden = self.hidden_layer(hidden_state)  
        merged = features + hidden.unsqueeze(1)
        tanh = self.tanh(merged)
        attention_weights = self.softmax(tanh).squeeze(2) 
        attention_features = (image_features * attention_weights.unsqueeze(2)).sum(dim=1) 

        return attention_features, attention_weights

# LSTM

Fonction pour loader les données

In [3]:
import torch
from torch.utils.data import Dataset
import json
import os

class FlickrDataset(Dataset):
    """
    Helper class to load up the data.
    """

    def __init__(self,data_folder,split,caps_per_img):

        self.split = split
        self.data_folder=data_folder
        self.caps_per_img=caps_per_img

        # Load captions
        with open(os.path.join(data_folder, self.split + '_CAPTIONS.json'), 'r') as j:
            self.captions = json.load(j)

        # Length of the dataset
        self.dataset_size = len(self.captions)*caps_per_img


    def __getitem__(self, i):

        if self.split=='DEV':
            self.split='TRAIN'
            
        ID=list(self.captions.keys())[i//self.caps_per_img]
        img_path=self.data_folder+'/'+self.split+'_images/'+ID+'.pt'
        img=torch.load(img_path)
        
        caption = torch.Tensor(self.captions[ID]["caps"][i%2]).to(dtype=torch.int16)
        caplen = torch.Tensor([self.captions[ID]["caplens"][i%2]]).to(dtype=torch.int16)

        return img, caption,caplen

    def __len__(self):
        return self.dataset_size

Decoder

In [4]:
import torch
from torch import nn
import torchvision
from skimage import io, transform
import matplotlib.pyplot as plt
import numpy as np

class Decoder(nn.Module):
    """
    Implements the decoder model. 
    it's based on an lstm and uses attention.
    """

    def __init__(self, attention_len, embedding_len, features_len, wordmap_len, lstm_len=512):
        """
        Params:
            - attention_len: size of attention network
            - embedding_len: size of the embedding vector
            - features_len: size of the images features (14,14,512)
            - wordmap_len: number of words in the vocabulary
            - lstm_len: size of the lstm network
        """
        super(Decoder, self).__init__()

        self.lstm_len = lstm_len
        self.attention_len = attention_len
        self.embedding_len = embedding_len
        self.features_len = features_len
        self.wordmap_len = wordmap_len

        # Our attention model
        self.attention = Attention(lstm_len, features_len, attention_len)
        # Embedding model, it transforms each word vector into an embedding vector
        self.embedding = nn.Embedding(wordmap_len, embedding_len)  
        # LSTM model. We use LSTMCell and implement the loop manualy to use attention
        self.lstm = nn.LSTMCell(embedding_len + lstm_len, features_len, bias=True) 
        # A simple linear layer to compute the vocabulary scores from the hidden state
        self.scoring_layer = nn.Linear(features_len, wordmap_len)

    def init_hidden_state(self, image_features):
        """
        Initialize the hidden state and cell state for each sentence.
        Using zero tensors for now, might change later.

        Params:
            - image_features : vector of image features (batch_size,14,14,512)

        Output:
            - Hidden state : vector for initial hidden state
            - Cell state : vector for initial cell state
        """
        h = torch.zeros((image_features.shape[0],image_features.shape[-1]))  # (batch_size, features_len)
        c = torch.zeros((image_features.shape[0],image_features.shape[-1]))
        return h, c

    def forward(self, image_features, caps, caplens):
        """
        Forward propagation.

        Params:
            - image_features : vector of image features (batch_size,14,14,512)
            - caps: image captions (batch_size, max_caption_length)
            - caplens: image caption lengths (batch_size, 1)
        Output:
            - scores : scores for each word (batch_size, wordmap_len)
            - caps_sorted : a sorted list of caps by lenghts.
            - decode lengths : caplens - 1
            - weights : attention weights
            - sort indices : can be used later
        """

        batch_size = image_features.size(0)
        lstm_len = image_features.size(-1)
        wordmap_len = self.wordmap_len

        # Flatten image 14*14 -> 196
        image_features = image_features.view(batch_size, -1, lstm_len)
        num_pixels = image_features.size(1)

        # Sort the sentences by decreasing lenght, so we can decode only the k first sentences that haven't
        # reached <end> yet. We can use index to select them, but this is cleaner
        caplens, sort_ind = caplens.squeeze(1).sort(dim=0, descending=True)
        image_features = image_features[sort_ind]
        caps = caps[sort_ind]
        
        # Transforms the captions into embedding vectors
        embeddings = self.embedding(caps)  
        # Initialize LSTM  hidden and cell state
        h, c = self.init_hidden_state(image_features)  

        # We won't decode at the <end> position, since we've finished generating as soon as we generate <end>
        # So, decoding lengths are actual lengths - 1
        decode_lengths = (caplens - 1).tolist()

        # Result tensors
        scores = torch.zeros(batch_size, max(decode_lengths), self.wordmap_len)
        weights = torch.zeros(batch_size, max(decode_lengths), num_pixels)

        # At each time-step, decode by
        # attention-weighing the encoder's output based on the decoder's previous hidden state output
        # then generate a new word in the decoder with the previous word and the attention weighted encoding
        for t in range(max(decode_lengths)):

            # At the moment t we only decode the sentences that haven't reached <end>, so the K first sentences
            # Since they are sorted by lenght we can just use [:K]
            k = sum([l > t for l in decode_lengths])

            # We first generate the attention weighted images. Alpha is the weights of the attention model.
            attention_encoding, alpha = self.attention(image_features[:k],h[:k])

            # Concatenate Previous word + features
            decode_input=torch.cat([embeddings[:k, t, :],attention_encoding], dim=1)
            
            # We run the LSTM cell using the decode imput and (hidden,cell) states
            h, c = self.lstm(decode_input, (h[:k],c[:k]) ) 

            # The hidden state is transformed into vocabulary scores by a simple linear layer
            score = self.scoring_layer(h)  # (k, wordmap_len)

            # Finaly we store the scores and weights
            scores[:k, t, :] = score
            weights[:k, t, :] = alpha

        return scores, caps, decode_lengths, weights, sort_ind


In [5]:
import time
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence
import matplotlib.pyplot as plt

In [None]:
# Model parameters
data_folder = './processed_data'  
embedding_len = 512  
attention_len = 512  
lstm_len = 512  
caps_per_image=2
batch_size = 32
learning_rate = 4e-4  

# Read word map
word_map_file = os.path.join(data_folder, 'WORDMAP.json')
with open(word_map_file, 'r') as j:
    word_map = json.load(j)

# Decoder model
decoder = Decoder(attention_len=attention_len,
                    embedding_len=embedding_len,
                    features_len=lstm_len,
                    wordmap_len=len(word_map))
# We need an optimiser to update the model weights
grad_params=filter(lambda p: p.requires_grad, decoder.parameters())
decoder_optimizer = torch.optim.Adam(params=grad_params, lr=learning_rate)

# Loss function
criterion = nn.CrossEntropyLoss()

# Dataloaders are wrappers around datasets that help woth the learning.
# Its not mandatory but its usefull so we might as well use it
dataset=FlickrDataset(data_folder, 'DEV', caps_per_image)
data_loader = torch.utils.data.DataLoader(
            FlickrDataset(data_folder, 'DEV', caps_per_image),
            batch_size=5, )

losses=[]
weights=[]
predictions=[]
print('--------------------------------------------- LEARNING STARTED -------------------------------------------')
for epoch in range(1,15):
    print( '----------------------------------- Epoch',epoch,'----------------------------------')
    for i,(imgs,caps,caplens) in enumerate(data_loader):
        #print( '----------------------------------- Batch',i,'----------------------------------')

        scores, caps_sorted, decode_lengths, alphas, sort_ind=decoder.forward(imgs, caps.to(dtype=torch.int64), caplens)
        # Remove the <start> word
        targets = caps_sorted[:, 1:]

        prediction=''
        for line in scores[0]:
            idx=line.argmax()
            for key in word_map.keys():
                if word_map[key]==idx:
                    prediction+=' '+str(key)
        caption=''
        for line in caps[sort_ind[0], 1:caplens[sort_ind[0]]+1]:
            for key in word_map.keys():
                if word_map[key]==line:
                    caption+=' '+str(key)

        # Remove timesteps that we didn't decode at, or are pads
        # pack_padded_sequence is an easy trick to do this
        scores, _ = pack_padded_sequence(scores, decode_lengths, batch_first=True)
        targets, _ = pack_padded_sequence(targets, decode_lengths, batch_first=True)

        # Calculate loss
        loss = criterion(scores, targets)
        #print('loss : ',loss.item())
        decoder_optimizer.zero_grad()
        losses.append(loss)
        loss.backward()
        decoder_optimizer.step()
    predictions.append(prediction)
    weights.append(alphas)    
    #print('--------------------------------------------------------------')
    print('Prediction : ',prediction)
    print('Truth : ',caption)
    #print('--------------------------------------------------------------')    

--------------------------------------------- LEARNING STARTED -------------------------------------------
----------------------------------- Epoch 1 ----------------------------------
Prediction :   a a dog a a a a a a a
Truth :   a brown dog wearing a black collar running across the beach
----------------------------------- Epoch 2 ----------------------------------
Prediction :   a a a a a a a a a a
Truth :   a brown dog wearing a black collar running across the beach
----------------------------------- Epoch 3 ----------------------------------
Prediction :   a a dog a a a a a a a
Truth :   a brown dog wearing a black collar running across the beach
----------------------------------- Epoch 4 ----------------------------------
Prediction :   a dog dog dog a dog dog dog a the
Truth :   a brown dog wearing a black collar running across the beach
----------------------------------- Epoch 5 ----------------------------------
Prediction :   a dog dog is a dog dog a a the
Truth :   a br