In [29]:
# --- Notebook configuration -------------------------------------------------
# Dataset tag: selects the image-path layout further down and is embedded in
# the name of the saved word-map file.
dataset = 'DAMICA'

# JSON file holding the image records (split, filename, tokenised sentences).
karpathy_json_path = 'caption_data/dataset_DAMICA.json'

# Directory prefix joined with each image filename.
image_folder = 'media/images/'

# Only used in the output file name within the cells shown here.
captions_per_image = 1

# A word is kept in the vocabulary only if it occurs strictly more often than this.
min_word_freq = 5

# Directory the word map JSON is written to.
output_folder = 'media/'

# Captions with more than this many tokens are discarded.
max_len = 50

In [30]:
import os
import numpy as np
import h5py
import json
import torch
from scipy.misc import imread, imresize
from tqdm import tqdm
from collections import Counter
from random import seed, choice, sample

In [31]:
# Parse the dataset description into `data`; later cells read data['images'].
with open(karpathy_json_path, mode='r') as json_file:
    data = json.load(json_file)

In [32]:
# Number of image records in the loaded dataset JSON.
len(data['images'])

14586

In [33]:
# One (paths, captions) pair of parallel lists per split; each pair is filled
# in lock-step by the loop over data['images'].
train_image_paths, train_image_captions = [], []
val_image_paths, val_image_captions = [], []
test_image_paths, test_image_captions = [], []

# Token -> number of occurrences across all sentences of all images.
word_freq = Counter()

In [34]:
# Build the per-split image-path / caption lists and the corpus word counts.
#
# The accumulators are created in a separate cell, so without a reset here,
# re-running this cell would append every image a second time and double all
# word counts (which changes which words pass the min_word_freq threshold).
# Clearing in place keeps the cell idempotent while leaving the result of a
# first run unchanged.
for accumulator in (train_image_paths, train_image_captions,
                    val_image_paths, val_image_captions,
                    test_image_paths, test_image_captions,
                    word_freq):
    accumulator.clear()

for img in data['images']:
    captions = []
    for c in img['sentences']:
        # Every sentence contributes to the word counts, including the ones
        # dropped just below for being too long.
        word_freq.update(c['tokens'])
        if len(c['tokens']) <= max_len:
            captions.append(c['tokens'])

    # Skip images left with no usable caption.
    if len(captions) == 0:
        continue

    # Only the 'coco' layout has a per-image sub-directory ('filepath').
    if dataset == 'coco':
        path = os.path.join(image_folder, img['filepath'], img['filename'])
    else:
        path = os.path.join(image_folder, img['filename'])

    # 'restval' images are folded into the training split; any other split
    # value is ignored.
    if img['split'] in {'train', 'restval'}:
        train_image_paths.append(path)
        train_image_captions.append(captions)
    elif img['split'] in {'val'}:
        val_image_paths.append(path)
        val_image_captions.append(captions)
    elif img['split'] in {'test'}:
        test_image_paths.append(path)
        test_image_captions.append(captions)

In [36]:
# Vocabulary: tokens seen strictly more than min_word_freq times.
words = [w for w, count in word_freq.items() if count > min_word_freq]

# Number the kept tokens 1..len(words) in that order; index 0 is left unused here.
word_map = dict(zip(words, range(1, len(words) + 1)))

In [37]:
# Add the special tokens to the vocabulary.
#
# <unk>, <start> and <end> take the three indices directly after the regular
# words; <pad> is index 0. The base index is derived from the number of
# NON-special entries rather than from len(word_map), so running this cell a
# second time reassigns the same indices. With len(word_map) + 1, a re-run
# would give all three tokens the same index, because the keys already exist
# and the dict no longer grows between assignments.
special_tokens = ('<unk>', '<start>', '<end>')
reserved_tokens = set(special_tokens) | {'<pad>'}
n_regular_words = sum(1 for token in word_map if token not in reserved_tokens)

for offset, token in enumerate(special_tokens, start=1):
    word_map[token] = n_regular_words + offset
word_map['<pad>'] = 0

In [39]:
# Common stem for output files, e.g. '<dataset>_<n>_cap_per_img_<m>_min_word_freq'.
base_filename = '{}_{}_cap_per_img_{}_min_word_freq'.format(
    dataset, captions_per_image, min_word_freq)

# Write the token -> index mapping as JSON into the output folder.
word_map_path = os.path.join(output_folder, 'WORDMAP_{}.json'.format(base_filename))
with open(word_map_path, 'w') as j:
    json.dump(word_map, j)

In [41]:
# Encode each caption as a fixed-length list of vocabulary indices:
#   <start>, token indices (unknown words -> <unk>), <end>, then <pad> up to
#   a total length of max_len + 2.
#
# NOTE(review): `captions` is not assigned in this cell. It is whatever an
# earlier cell last bound; in the cells visible here that is the caption list
# of the last image handled by the split-building loop. Confirm this is the
# intended input before relying on enc_captions / caplens.
enc_captions = []
caplens = []
for c in captions:
    # A negative pad count would silently give a row longer than max_len + 2,
    # so fail loudly instead of emitting ragged rows.
    if len(c) > max_len:
        raise ValueError(
            'caption has %d tokens, more than max_len=%d' % (len(c), max_len))

    # Encode captions
    enc_c = ([word_map['<start>']]
             + [word_map.get(word, word_map['<unk>']) for word in c]
             + [word_map['<end>']]
             + [word_map['<pad>']] * (max_len - len(c)))

    # Caption length counts the <start> and <end> markers but not the padding.
    c_len = len(c) + 2

    enc_captions.append(enc_c)
    caplens.append(c_len)

In [57]:
# Length of the second encoded caption (index 1). The encoding above pads every
# row to max_len + 2 entries, i.e. 52 with max_len = 50.
len(enc_captions[1])

52

In [63]:
# Index for the token "d"; falls back to the <unk> index if "d" is not in the vocabulary.
word_map.get("d", word_map['<unk>'])

2630

['a',
 'big',
 'dog',
 'stands',
 'on',
 'his',
 'hand',
 'leg',
 'as',
 'tennis',
 'balls',
 'are',
 'thrown',
 'his',
 'direction']