In [10]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import string

import torch


In [11]:
# Load the caption table from disk. The preview in the next cell shows one row
# per (image file name, caption) pair, so an image can appear on several rows.
df = pd.read_csv('captions.txt')

In [12]:
# Preview the first rows of the raw caption table.
df.head()

Unnamed: 0,image,caption
0,1000268201_693b08cb0e.jpg,A child in a pink dress is climbing up a set o...
1,1000268201_693b08cb0e.jpg,A girl going into a wooden building .
2,1000268201_693b08cb0e.jpg,A little girl climbing into a wooden playhouse .
3,1000268201_693b08cb0e.jpg,A little girl climbing the stairs to her playh...
4,1000268201_693b08cb0e.jpg,A little girl in a pink dress going into a woo...


In [13]:
def preprocess(caption):
    """Normalise a raw caption into a space-joined token string.

    The caption is split on whitespace, lower-cased and stripped of ASCII
    punctuation. Tokens that end up one character long or shorter (e.g. a
    stray 'a' or 's', or a token that was only punctuation) and tokens that
    contain anything other than letters are discarded. The surviving tokens
    are wrapped in '<start>' / '<end>' markers.

    Args:
        caption: raw caption text (a ``str``).

    Returns:
        The cleaned tokens joined by single spaces, beginning with
        '<start>' and ending with '<end>'.
    """
    # translation table that deletes every ASCII punctuation character
    strip_punct = str.maketrans('', '', string.punctuation)

    # lower-case first, then drop punctuation, one token at a time
    cleaned = (tok.lower().translate(strip_punct) for tok in caption.split())
    # keep only purely alphabetic tokens longer than one character
    kept = [tok for tok in cleaned if len(tok) > 1 and tok.isalpha()]

    return ' '.join(['<start>', *kept, '<end>'])


In [14]:
# Clean every raw caption and store the result alongside the original text.
df['caption_preprocessed'] = df['caption'].map(preprocess)

In [15]:
# Preview again to check the newly added 'caption_preprocessed' column.
df.head()

Unnamed: 0,image,caption,caption_preprocessed
0,1000268201_693b08cb0e.jpg,A child in a pink dress is climbing up a set o...,<start> child in pink dress is climbing up set...
1,1000268201_693b08cb0e.jpg,A girl going into a wooden building .,<start> girl going into wooden building <end>
2,1000268201_693b08cb0e.jpg,A little girl climbing into a wooden playhouse .,<start> little girl climbing into wooden playh...
3,1000268201_693b08cb0e.jpg,A little girl climbing the stairs to her playh...,<start> little girl climbing the stairs to her...
4,1000268201_693b08cb0e.jpg,A little girl in a pink dress going into a woo...,<start> little girl in pink dress going into w...


In [16]:
# Ids 0-3 are reserved for the special tokens.
word_to_idx = {'<pad>': 0, '<start>': 1, '<end>': 2, '<unknown>': 3}

# Assign ids in order of first appearance. Ids are dense, so the next free id
# is always the current size of the mapping; setdefault leaves known words
# (including '<start>' / '<end>') untouched.
for caption in df['caption_preprocessed']:
    for word in caption.split():
        word_to_idx.setdefault(word, len(word_to_idx))

In [17]:
# Longest preprocessed caption, counted in tokens (the '<start>' / '<end>'
# markers are part of the string and therefore counted). Falls back to 0 when
# there are no captions at all.
caption_lengths = (len(caption.split()) for caption in df['caption_preprocessed'])
max_len = max(caption_lengths, default=0)

print('max_len:', max_len)

max_len: 34


In [18]:
# Vocabulary size, counting the four special tokens registered up front.
len(word_to_idx)

8767

In [19]:
# Distinct image file names, in order of first appearance in the table.
# dict.fromkeys de-duplicates in a single pass while preserving insertion
# order, which avoids the quadratic cost of testing `image not in image_list`
# against an ever-growing list for every row.
image_list = list(dict.fromkeys(df['image']))

In [20]:
# Write one image file name per line; a name's line position matches its
# position in image_list. The context manager guarantees the handle is closed
# even if the write fails, and writelines avoids building one large string by
# repeated concatenation.
with open('image_list.txt', 'w') as f:
    f.writelines(image + "\n" for image in image_list)

In [21]:
# Map each file name to its position in image_list once, instead of running a
# linear image_list.index(...) scan for every row of the table.
# setdefault keeps the first position of a name, matching list.index().
image_to_idx = {}
for idx, image in enumerate(image_list):
    image_to_idx.setdefault(image, idx)

# X[k] is the image_list position of the image that caption row k describes.
X = torch.tensor([image_to_idx[image] for image in df['image']])

In [22]:
# Save the tensor of per-caption image indices to disk.
torch.save(X, 'X.pt')

In [23]:
# Encode each preprocessed caption as the list of its tokens' vocabulary ids.
y = [
    [word_to_idx[token] for token in caption.split()]
    for caption in df['caption_preprocessed']
]


In [24]:
# Left-pad every encoded caption with the '<pad>' id so all rows have max_len
# entries; a non-positive deficit yields no padding.
pad_idx = word_to_idx['<pad>']
y = [[pad_idx] * (max_len - len(seq)) + seq for seq in y]

In [25]:
# Convert the padded id lists into a single tensor (one row per caption).
# NOTE: this rebinds `y` from a list of lists to a tensor, so the list-based
# cells above must not be re-run after this one without rebuilding `y`.
y = torch.tensor(y)

In [26]:
# Save the padded caption-id tensor to disk.
torch.save(y, 'y.pt')

In [27]:
# Dump the vocabulary as "word,index" lines in insertion (index) order.
# The context manager closes the file even if the write fails, and writelines
# streams the lines rather than concatenating one large string in a loop.
with open('word_to_idx.txt', 'w') as f:
    f.writelines(word + ',' + str(idx) + '\n' for word, idx in word_to_idx.items())