In [None]:
# -*- coding:utf-8 -*-
"""
Description: image caption generation model training
Author: allocator
"""

In [None]:
from keras.layers import Dense
import string

In [None]:
def load_txt(filename):
    """Read a text file and return its entire contents as one string.

    Args:
        filename: path of the file to read (opened in text mode).

    Returns:
        The file's full contents.
    """
    with open(filename, 'r') as handle:
        contents = handle.read()
    return contents

In [None]:
def load_description(text):
    """Parse ``<image id> <description words...>`` lines into a dict.

    Each line of ``text`` is split on whitespace. The first token is the
    image id, and the remaining tokens are re-joined with single spaces to
    form that image's description. Lines with fewer than two tokens (blank
    lines, or an id with no description) are skipped. When an id appears
    more than once, only its first description is kept.

    Args:
        text: raw text with one id/description pair per line.

    Returns:
        dict mapping image id -> description string (empty if no line
        contained a description).
    """
    # Build the mapping once, outside the loop, so every line contributes.
    img_dict = {}
    for line in text.split('\n'):
        tokens = line.split()
        # a line holding only an id (or nothing at all) has no description
        if len(tokens) < 2:
            continue
        img_id, img_des_token = tokens[0], tokens[1:]
        # keep the first description seen for each id
        if img_id not in img_dict:
            img_dict[img_id] = ' '.join(img_des_token)
    # Return only after every line has been processed.
    return img_dict

In [None]:
def clean_descriptions(description):
    """Normalise every description in ``description`` in place.

    For each description:
    - split it on whitespace;
    - lowercase each word;
    - remove every character in ``string.punctuation``;
    - drop words of length 1 or less (e.g. 'a', the 's' left from "'s",
      or words made empty by punctuation removal).

    The surviving words are re-joined with single spaces.

    Args:
        description: dict mapping image id -> description string. Its
            values are overwritten in place.

    Returns:
        The same dict object, after all of its entries have been cleaned.
    """
    # translation table that deletes all ASCII punctuation characters
    trans_table = str.maketrans('', '', string.punctuation)
    for key, desc in description.items():
        # lowercase, then strip punctuation
        words = [word.lower().translate(trans_table) for word in desc.split()]
        # Keep only words longer than one character. The original `in` here
        # was a typo for `if` and raised TypeError.
        description[key] = ' '.join(word for word in words if len(word) > 1)
    # Return after the loop so every entry is cleaned, not just the first.
    return description

In [None]:
# Exercise the load -> parse -> clean helpers end to end on the token file
# and report how much data came through.
filename = 'img_des_tokens.txt'
doc = load_txt(filename)
descriptions = load_description(doc)
print(' length of the description is %d ' % len(descriptions))  # entries in the parsed dict

clean_desc = clean_descriptions(descriptions)
# vocabulary = every distinct whitespace-separated word across all cleaned descriptions
vocabulary = {word for desc in clean_desc.values() for word in desc.split()}
print(' the vocabulary size is %d ' % len(vocabulary))

In [None]:
# save the description
def save_descriptions(filename, descriptions):
    """Write descriptions to ``filename``, one ``<id> <description>`` per line.

    Lines are separated by '\\n', with no trailing newline. Any existing
    file is overwritten.

    Args:
        filename: destination path.
        descriptions: dict mapping image id (str) -> description (str).
    """
    # Build the whole payload before opening the file, so a bad entry
    # cannot leave a truncated file behind.
    payload = '\n'.join([img_id + ' ' + text for img_id, text in descriptions.items()])
    with open(filename, 'w') as out_file:
        out_file.write(payload)