In [40]:
import os
import numpy as np
import h5py
import json
import torch
from scipy.misc import imread, imresize
from tqdm import tqdm
from collections import Counter
from random import seed, choice, sample
import pandas as pd
from nltk.tokenize import word_tokenize
import nltk
nltk.download('punkt')
from random import randint

import warnings
warnings.filterwarnings('ignore')

[nltk_data] Downloading package punkt to /home/as3ek/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [57]:
dataset = pd.read_csv('data/caption_data_0_100.csv') 
image_folder = 'data/images/' 
captions_per_image = 7000
min_word_freq = 5
output_folder = 'data/proc_data_files/'
max_len = 20

train_image_paths = []
train_image_captions = []
val_image_paths = []
val_image_captions = []
test_image_paths = []
test_image_captions = []
num_images_to_train = 20
word_freq = Counter()

In [58]:
for img in dataset['ImageName'][:num_images_to_train]:
    captions = []
    for c in dataset[dataset['ImageName'] == img]['Caption']:
        # Updating word freq
        c = str(c)
        
        tokens = word_tokenize(c)
        tokens = [token.lower() for token in tokens]
        
        word_freq.update(tokens)
        if len(tokens) <= max_len:
            captions.append(tokens)
    
    if len(captions) == 0:
        continue
    
    path = os.path.join(img)
    
    if randint(0, 10) < 9:
        train_image_paths.append(path)
        train_image_captions.append(captions)
    
    else:
        val_image_paths.append(path)
        val_image_captions.append(captions)
        
    if randint(0, 10) < 5:
        test_image_paths.append(path)
        test_image_captions.append(captions)
        
# Sanity check
assert len(train_image_paths) == len(train_image_captions)
assert len(val_image_paths) == len(val_image_captions)
assert len(test_image_paths) == len(test_image_captions)

In [59]:
# Create word map
words = [w for w in word_freq.keys() if word_freq[w] > min_word_freq]
word_map = {k: v + 1 for v, k in enumerate(words)}
word_map['<unk>'] = len(word_map) + 1
word_map['<start>'] = len(word_map) + 1
word_map['<end>'] = len(word_map) + 1
word_map['<pad>'] = 0

In [None]:
# Create a base/root name for all output files
base_filename = 'meme_' + str(captions_per_image) + '_cap_per_img_' + str(min_word_freq) + '_min_word_freq'

# Save word map to a JSON
with open(os.path.join(output_folder, 'WORDMAP_' + base_filename + '.json'), 'w') as j:
    json.dump(word_map, j)
    
# Sample captions for each image, save images to HDF5 file, and captions and their lengths to JSON files
seed(123)
for impaths, imcaps, split in [(train_image_paths, train_image_captions, 'TRAIN'),
                               (val_image_paths, val_image_captions, 'VAL'),
                               (test_image_paths, test_image_captions, 'TEST')]:

    with h5py.File(os.path.join(output_folder, split + '_IMAGES_' + base_filename + '.hdf5'), 'a') as h:
        # Make a note of the number of captions we are sampling per image
        h.attrs['captions_per_image'] = captions_per_image

        # Create dataset inside HDF5 file to store images
        images = h.create_dataset('images', (len(impaths), 3, 256, 256), dtype='uint8')

        print("\nReading %s images and captions, storing to file...\n" % split)

        enc_captions = []
        caplens = []

        for i, path in enumerate(tqdm(impaths)):

            # Sample captions
            if len(imcaps[i]) < captions_per_image:
                captions = imcaps[i] + [choice(imcaps[i]) for _ in range(captions_per_image - len(imcaps[i]))]
            else:
                captions = sample(imcaps[i], k=captions_per_image)

            # Sanity check
            assert len(captions) == captions_per_image

            # Read images
            img = imread(impaths[i])
            if len(img.shape) == 2:
                img = img[:, :, np.newaxis]
                img = np.concatenate([img, img, img], axis=2)
            img = imresize(img, (256, 256))
            img = img.transpose(2, 0, 1)
            assert img.shape == (3, 256, 256)
            assert np.max(img) <= 255

            # Save image to HDF5 file
            images[i] = img

            for j, c in enumerate(captions):
                # Encode captions
                enc_c = [word_map['<start>']] + [word_map.get(word, word_map['<unk>']) for word in c] + [
                    word_map['<end>']] + [word_map['<pad>']] * (max_len - len(c))

                # Find caption lengths
                c_len = len(c) + 2

                enc_captions.append(enc_c)
                caplens.append(c_len)

        # Sanity check
        assert images.shape[0] * captions_per_image == len(enc_captions) == len(caplens)

        # Save encoded captions and their lengths to JSON files
        with open(os.path.join(output_folder, split + '_CAPTIONS_' + base_filename + '.json'), 'w') as j:
            json.dump(enc_captions, j)

        with open(os.path.join(output_folder, split + '_CAPLENS_' + base_filename + '.json'), 'w') as j:
            json.dump(caplens, j)

  6%|▌         | 1/17 [00:00<00:02,  5.94it/s]


Reading TRAIN images and captions, storing to file...



100%|██████████| 17/17 [00:01<00:00, 14.14it/s]
100%|██████████| 3/3 [00:00<00:00, 26.92it/s]


Reading VAL images and captions, storing to file...




 27%|██▋       | 3/11 [00:00<00:00, 28.88it/s]


Reading TEST images and captions, storing to file...



100%|██████████| 11/11 [00:00<00:00, 27.78it/s]
