<a href="https://colab.research.google.com/github/LastChanceKatze/image-caption-gen/blob/main/img_caption_gen.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# ***Imports***

In [24]:
from os import listdir
import string
from pickle import dump, load
import tensorflow.keras.applications.vgg16 as vgg16
import tensorflow.keras.applications.inception_v3 as inception_v3
from tensorflow.keras.models import Model
from keras.preprocessing.image import load_img, img_to_array
from keras_preprocessing.text import Tokenizer
from keras.utils.np_utils import to_categorical
from keras_preprocessing.sequence import pad_sequences
import numpy as np
import random

In [5]:
# Root folder on the mounted Google Drive; every other path hangs off it.
drive_folder = "/content/drive/MyDrive/DL"

# Files written by the preprocessing section (default outputs of
# preprocess_img_features / preprocess_captions).
img_features_path = drive_folder + "/training_files/img_features.pkl"
captions_filename = drive_folder + "/training_files/captions.txt"

# Flickr8k split files (one image file name per line).
img_train_path = drive_folder + "/Dataset/Flickr8k_text/Flickr_8k.trainImages.txt"
# NOTE: the "test" split used in this notebook is Flickr8k's *dev* split.
img_test_path = drive_folder + "/Dataset/Flickr8k_text/Flickr_8k.devImages.txt"

# ***Preprocessing***

### *Preprocess captions*

In [28]:
def load_captions(filename):
    """
    Load raw captions from the Flickr8k token file and group them per image.

    Each usable line looks like ``<image_name>.jpg#<caption_no>`` followed by a
    tab and the caption text. Blank lines and lines without a tab (i.e. without
    caption text) are skipped.

    :param filename: path to the raw caption token file
    :return: dict mapping image id (file name up to its first '.') to the list
             of that image's captions, in file order
    """
    with open(filename, "r") as file:
        text = file.read()

    mapping = dict()
    for line in text.split("\n"):
        # skip empty / degenerate lines such as the trailing newline
        if len(line) < 2:
            continue

        tokens = line.split("\t")
        # no tab -> no caption text; don't record an empty caption under a bogus id
        if len(tokens) < 2:
            continue

        # first token: "<image>.jpg#<n>", rest: caption (re-joined in case it contained tabs)
        img_id, img_capt = tokens[0], tokens[1:]
        img_id = img_id.split('.')[0]
        img_capt = ' '.join(img_capt)

        # collect all captions of the same image under its id
        mapping.setdefault(img_id, list()).append(img_capt)

    return mapping


def clean_captions(captions):
    """
    Normalise every caption of ``captions`` in place.

    Per word: lowercase it, strip ASCII punctuation, then drop it if it is a
    single character (e.g. a dangling "s" or "a") or contains anything other
    than letters (e.g. digits). The surviving words are re-joined with single
    spaces. The dict's lists are updated in place; nothing is returned.

    :param captions: dict mapping image id -> list of caption strings
    :return: None
    """
    strip_punct = str.maketrans('', '', string.punctuation)

    def _clean_one(raw_caption):
        kept_words = []
        for word in raw_caption.split():
            word = word.lower().translate(strip_punct)
            if len(word) > 1 and word.isalpha():
                kept_words.append(word)
        return ' '.join(kept_words)

    for caption_list in captions.values():
        # slice assignment keeps the same list object the caller holds
        caption_list[:] = [_clean_one(raw) for raw in caption_list]


def save_captions(captions_dict, to_file):
    """
    Save ``captions_dict`` to a text file, one ``<image_id> <caption>`` line per
    caption (lines joined with '\\n', no trailing newline).

    :param captions_dict: dict mapping image id -> list of caption strings
    :param to_file: destination path (overwritten)
    :return: None
    """
    lines = [key + ' ' + caption
             for key, caption_list in captions_dict.items()
             for caption in caption_list]

    # `with` guarantees the file is closed even if the write fails
    with open(to_file, 'w') as file:
        file.write('\n'.join(lines))


def preprocess_captions(capt_filename=f"{drive_folder}/Dataset/Flickr8k_text/Flickr8k.token.txt",
                        clean_capt_to_file=f"{drive_folder}/training_files/captions.txt"):
    """
    Caption preprocessing pipeline: read the raw Flickr8k token file, clean
    every caption in place, then write one ``<image_id> <caption>`` line per
    caption to ``clean_capt_to_file``.

    :param capt_filename: raw caption token file
    :param clean_capt_to_file: destination for the cleaned captions
    :return: None
    """
    captions_by_image = load_captions(capt_filename)
    clean_captions(captions_by_image)  # mutates captions_by_image, returns None
    save_captions(captions_by_image, clean_capt_to_file)

In [29]:
# Same paths as the function defaults, spelled out so the cell shows what it reads and writes.
preprocess_captions(
    capt_filename=f"{drive_folder}/Dataset/Flickr8k_text/Flickr8k.token.txt",
    clean_capt_to_file=captions_filename,
)

### *Extract image features*

In [None]:
def create_cnn_model_dict(model_types=None):
    """
    Build the pretrained CNN backbones available for feature extraction.

    Each entry maps a model name to a dict with:
      'model'            -- the Keras application model, built with its default arguments
      'target_size'      -- (height, width) images must be resized to
      'preprocess_input' -- the matching Keras preprocessing function

    :param model_types: optional name or iterable of names to build
        ('vgg16', 'inception_v3'). Defaults to all of them, in that order.
        Instantiating a backbone loads its weights, so restricting this avoids
        loading models that will not be used.
    :return: dict keyed by model name
    :raises ValueError: if an unknown model name is requested
    """
    # Constructors are stored uncalled so that only the requested models are instantiated.
    available = {
        'vgg16': (vgg16.VGG16, (224, 224), vgg16.preprocess_input),
        'inception_v3': (inception_v3.InceptionV3, (299, 299), inception_v3.preprocess_input),
    }

    if model_types is None:
        model_types = list(available)
    elif isinstance(model_types, str):
        model_types = [model_types]

    unknown = [name for name in model_types if name not in available]
    if unknown:
        raise ValueError(f"Unknown model type(s) {unknown}; expected some of {list(available)}")

    cnn_model_dict = dict()
    for name in model_types:
        build, target_size, preprocess_input = available[name]
        cnn_model_dict[name] = {
            'model': build(),
            'target_size': target_size,
            'preprocess_input': preprocess_input,
        }
    return cnn_model_dict

In [None]:
def extract_features(images_dir, model_type, cnn_model_dict):
    """
    Run every file in ``images_dir`` through a pretrained CNN and collect the
    output of its penultimate layer.

    :param images_dir: directory that contains only image files
    :param model_type: key into ``cnn_model_dict`` (e.g. 'vgg16', 'inception_v3')
    :param cnn_model_dict: dict whose entries provide 'model', 'target_size'
        and 'preprocess_input'
    :return: dict mapping image id (file name up to its first '.') to the array
             returned by ``model.predict`` for that single image, i.e. with a
             leading batch axis of size 1
    """
    # Loop-invariant lookups, done once instead of per image.
    model_cfg = cnn_model_dict[model_type]
    target_size = model_cfg['target_size']
    preprocess_input = model_cfg['preprocess_input']
    base_model = model_cfg['model']

    # Cut off the last layer: the penultimate layer's output is the image embedding.
    model = Model(inputs=base_model.inputs, outputs=base_model.layers[-2].output)
    model.summary()

    features_dict = dict()
    # sorted() makes the processing order (and progress output) deterministic
    for img_count, name in enumerate(sorted(listdir(images_dir)), start=1):
        image = load_img(f"{images_dir}/{name}", target_size=target_size)
        image = img_to_array(image)
        # add a batch axis: (H, W, C) -> (1, H, W, C)
        image = np.expand_dims(image, axis=0)
        image = preprocess_input(image)
        features_dict[name.split('.')[0]] = model.predict(image, verbose=0)

        print(".", end="")
        if img_count % 200 == 0:
            # end the line of dots first so the count gets its own line
            print()
            print("No. images", img_count)

    return features_dict

def save_img_features(img_features, to_file):
    """
    Pickle ``img_features`` to ``to_file`` (overwriting it).

    The file is closed, and therefore flushed, before returning. The original
    version left the handle open, so the write was only completed when the
    handle was garbage-collected.

    :param img_features: picklable object, typically dict image id -> features
    :param to_file: destination path
    :return: None
    """
    with open(to_file, "wb") as file:
        dump(img_features, file)

def preprocess_img_features(images_dir=f"{drive_folder}/Dataset/Flickr8k_Dataset/Flicker8k_Dataset",
                            to_file=f"{drive_folder}/training_files/img_features.pkl",
                            model_type="inception_v3"):
    """
    Extract CNN features for every image in ``images_dir`` and pickle them.

    ``model_type`` used to be a non-default parameter placed after defaulted
    ones, which is a SyntaxError. It now has a default.

    :param images_dir: directory with the Flickr8k images
    :param to_file: destination pickle. The default is the same path the
        training section loads (``img_features_path``), so pass a distinct file
        when extracting with a backbone other than the one used for training.
    :param model_type: 'vgg16' or 'inception_v3'
    :return: None
    """
    cnn_model_dict = create_cnn_model_dict()
    features = extract_features(images_dir, model_type, cnn_model_dict)
    print("No. features", len(features))
    save_img_features(features, to_file)

In [None]:
# Explicit backbone so the output file name (inc_v3) and the model used always agree.
preprocess_img_features(to_file=f"{drive_folder}/training_files/img_features_inc_v3.pkl",
                        model_type="inception_v3")

Model: "model_8"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_13 (InputLayer)           [(None, 299, 299, 3) 0                                            
__________________________________________________________________________________________________
conv2d_188 (Conv2D)             (None, 149, 149, 32) 864         input_13[0][0]                   
__________________________________________________________________________________________________
batch_normalization_188 (BatchN (None, 149, 149, 32) 96          conv2d_188[0][0]                 
__________________________________________________________________________________________________
activation_188 (Activation)     (None, 149, 149, 32) 0           batch_normalization_188[0][0]    
____________________________________________________________________________________________

# ***Load preprocessed data***

In [6]:
def load_img_ids(filename):
    """
    Read a split file with one image file name per line and return the image
    ids (each name up to its first '.'), in file order, skipping empty lines.

    :param filename: path to the split file
    :return: list of image id strings
    """
    with open(filename, "r") as handle:
        contents = handle.read()
    return [entry.split('.')[0] for entry in contents.split("\n") if entry]

In [7]:
def load_img_features(img_features, train_ids, test_ids):
    """
    Load pickled image features and split them into train and test dicts.

    :param img_features: path to a pickle holding a dict image id -> features.
        Only unpickle files you produced yourself: ``pickle.load`` can execute
        arbitrary code.
    :param train_ids: iterable of image ids for the training split
    :param test_ids: iterable of image ids for the test split
    :return: (train_features, test_features), each a dict keyed by image id
    :raises KeyError: if any requested id has no entry in the pickle
    """
    with open(img_features, "rb") as file:
        features = load(file)

    def _select(ids, split_name):
        # Build the feature dict for one split, reporting all missing ids at once.
        ids = list(ids)
        missing = [img_id for img_id in ids if img_id not in features]
        if missing:
            raise KeyError(f"{len(missing)} {split_name} image id(s) have no features, "
                           f"e.g. {missing[:5]}")
        return {img_id: features[img_id] for img_id in ids}

    return _select(train_ids, "train"), _select(test_ids, "test")

def load_clean_captions(filename, dataset):
    """
    Load cleaned captions and keep only those whose image id is in ``dataset``.

    Each line of ``filename`` is ``<image_id> <caption words...>``. Every kept
    caption is wrapped as ``'startseq <caption> endseq'``. Blank lines are
    skipped.

    :param filename: path to the cleaned captions file
    :param dataset: collection of image ids to keep
    :return: dict mapping image id -> list of wrapped captions, in file order
    """
    with open(filename, 'r') as file:
        text = file.read()

    # set lookup is O(1); callers pass the (thousands-long) list from load_img_ids
    wanted_ids = set(dataset)

    captions = dict()
    for line in text.split('\n'):
        tokens = line.split()
        # a blank line (e.g. a trailing newline) would otherwise fail on tokens[0]
        if not tokens:
            continue

        img_id, img_caption = tokens[0], tokens[1:]
        if img_id in wanted_ids:
            # add startseq at the beginning and endseq at the end of each caption
            caption = 'startseq ' + ' '.join(img_caption) + ' endseq'
            captions.setdefault(img_id, list()).append(caption)

    return captions

In [13]:
def load_train_test(img_features_path, captions_path, train_ids_path, test_ids_path):
    """
    Load train and test image features together with their cleaned captions.

    Uses its arguments rather than the notebook-level path globals of the same
    purpose. Previously it ignored ``captions_path``, ``train_ids_path`` and
    ``test_ids_path`` entirely.

    :param img_features_path: pickle of image id -> features
    :param captions_path: cleaned captions file (``<image_id> <caption>`` lines)
    :param train_ids_path: split file listing the training image names
    :param test_ids_path: split file listing the test image names
    :return: (train_features, train_captions, test_features, test_captions)
    """
    img_train_ids = load_img_ids(train_ids_path)
    img_test_ids = load_img_ids(test_ids_path)

    train_features, test_features = load_img_features(img_features_path, img_train_ids, img_test_ids)

    train_captions = load_clean_captions(captions_path, img_train_ids)
    test_captions = load_clean_captions(captions_path, img_test_ids)

    print("Train images: ", len(train_features))
    print("Train captions: ", len(train_captions))
    print("Test images: ", len(test_features))
    print("Test captions: ", len(test_captions))

    return train_features, train_captions, test_features, test_captions


In [30]:
# Load cached image features and cleaned, startseq/endseq-wrapped captions for both splits.
(train_features, train_captions,
 test_features, test_captions) = load_train_test(
    img_features_path, captions_filename, img_train_path, img_test_path)

Train images:  6000
Train captions:  6000
Test images:  1000
Test captions:  1000


# ***Prepare data for model fitting***

In [18]:
def to_lines(captions):
    """
    Flatten a captions dict into one list of caption strings.

    :param captions: dict mapping image id -> list of captions
    :return: list of every caption, grouped by key in dict iteration order
    """
    # A real comprehension instead of one used only for its append side effects.
    return [caption for caption_list in captions.values() for caption in caption_list]


def create_tokenizer(captions):
    """
    Fit a fresh Keras ``Tokenizer`` on every caption in ``captions``.

    :param captions: dict mapping image id -> list of caption strings
    :return: the fitted Tokenizer
    """
    fitted = Tokenizer()
    fitted.fit_on_texts(to_lines(captions))
    return fitted


def calc_max_length(captions):
    """
    Return the word count (whitespace-separated) of the longest caption.

    :param captions: dict mapping image id -> list of caption strings
    :return: int, the largest number of words in any caption
    :raises ValueError: if ``captions`` contains no captions at all
    """
    word_counts = (len(text.split())
                   for texts in captions.values()
                   for text in texts)
    return max(word_counts)

In [31]:
tokenizer = create_tokenizer(train_captions)
# Keras' Tokenizer numbers words from 1, so index 0 is never a word (it is the
# value pad_sequences pads with); vocab_size is later used as num_classes for
# the one-hot next-word targets, hence the +1.
vocab_size = len(tokenizer.word_index) + 1
print("Vocabulary size: ", vocab_size)

# Longest training caption in words (startseq/endseq included); word prefixes are padded to this.
max_length = calc_max_length(train_captions)
print("Max caption length: ", max_length)

Vocabulary size:  7579
Max caption length:  34


In [21]:
def create_sequences(image, caption_list, tokenizer, max_length, vocab_size):
    """
    Expand captions into next-word training samples for one image.

    Every caption is encoded with ``tokenizer``. For each split point k >= 1 the
    first k word ids form the input prefix, padded to ``max_length``, and word
    k becomes the target, one-hot encoded over ``vocab_size`` classes. The
    image is repeated once per sample.

    :param image: image feature vector, repeated for every sample
    :param caption_list: caption strings for this image
    :param tokenizer: fitted Tokenizer
    :param max_length: length every prefix is padded to
    :param vocab_size: number of classes for the one-hot target
    :return: (images, padded_prefixes, one_hot_targets), three parallel lists
    """
    images_out, prefixes_out, targets_out = [], [], []
    for caption in caption_list:
        word_ids = tokenizer.texts_to_sequences([caption])[0]
        for split_at in range(1, len(word_ids)):
            prefix = pad_sequences([word_ids[:split_at]], maxlen=max_length)[0]
            target = to_categorical([word_ids[split_at]], num_classes=vocab_size)[0]
            images_out.append(image)
            prefixes_out.append(prefix)
            targets_out.append(target)
    return images_out, prefixes_out, targets_out

In [22]:
def data_generator(images, captions, tokenizer, max_length, batch_size, random_seed, vocab_size):
    """
    Endlessly yield batches built from consecutive groups of images.

    A batch covers up to ``batch_size`` *images*, not sequences. Every caption
    of every image in the group is expanded by ``create_sequences`` into
    (image, padded word prefix, one-hot next word) samples. Images are visited
    in the key order of ``captions``; after the last image the generator wraps
    around to the first. ``captions`` itself is never modified.

    :param images: dict image id -> feature array; ``images[id][0]`` is used as
        the image vector (the arrays from extract_features have a leading batch
        axis of size 1)
    :param captions: dict image id -> list of caption strings
    :param tokenizer: fitted Tokenizer used to encode captions
    :param max_length: length word prefixes are padded to
    :param batch_size: number of images per batch (>= 1)
    :param random_seed: seed of this generator's private RNG, which shuffles each
        image's captions (the global ``random`` state is left untouched)
    :param vocab_size: number of classes of the one-hot next-word targets
    :return: generator yielding ``[[images, word_prefixes, next_words]]``, a
        one-element list holding three numpy arrays
    :raises ValueError: on the first ``next()`` if ``captions`` is empty or
        ``batch_size`` < 1 (either would loop forever on empty batches)
    """
    # NOTE(review): Keras `Model.fit` documents generators yielding
    # `(inputs, targets)` tuples; this nested-list format is kept for
    # compatibility with existing callers. Confirm how training consumes it.
    img_ids = list(captions.keys())
    if not img_ids:
        raise ValueError("captions is empty; nothing to generate")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Private RNG: reproducible without reseeding the module-global `random`.
    rng = random.Random(random_seed)

    count = 0
    while True:
        if count >= len(img_ids):
            count = 0

        in_img_batch, in_seq_batch, out_word_batch = list(), list(), list()
        for img_id in img_ids[count:count + batch_size]:
            img = images[img_id][0]
            # shuffle a copy so the caller's caption lists are not mutated
            captions_list = list(captions[img_id])
            rng.shuffle(captions_list)
            in_img, in_seq, out_word = create_sequences(img, captions_list, tokenizer, max_length, vocab_size)
            in_img_batch.extend(in_img)
            in_seq_batch.extend(in_seq)
            out_word_batch.extend(out_word)

        count += batch_size
        yield [[np.array(in_img_batch), np.array(in_seq_batch), np.array(out_word_batch)]]

In [32]:
# Pull one batch (a single image's caption sequences) to sanity-check array shapes.
generator = data_generator(train_features, train_captions, tokenizer, max_length,
                           batch_size=1, random_seed=10, vocab_size=vocab_size)
# named `sample_batch` rather than `input`, which would shadow the built-in
sample_batch = next(generator)
in_imgs, in_seqs, out_words = sample_batch[0]
print(in_imgs.shape)
print(in_seqs.shape)
print(out_words.shape)

(47, 4096)
(47, 34)
(47, 7579)
