In [3]:
import os
import numpy as np
import h5py
import json
import torch
#from scipy.misc import imread, imresize # Not supported
from imageio import imread
from PIL import Image
from random import seed, choice, sample
import pickle
import yaml
import sys
from glob import glob

In [20]:
# Source split location and how much of it to keep in the reduced copy.
source_data = "../../data/processed/origin_5_captions_256_hubert"
limited_ratio = 10  # percentage of images to keep
split = "TRAIN"
base_filename = "coco_5_cap_per_img_1_min_word_freq"
# The reduced copy lives next to the source, suffixed with the kept percentage.
target_dir = f"{source_data}_{limited_ratio}%"

In [4]:
def _find_single_file(pattern):
    """Return the single path matching glob `pattern`.

    Raises FileNotFoundError when nothing matches and ValueError when several
    files match (picking `glob(...)[0]` would choose an arbitrary one, since
    glob order is unspecified).
    """
    matches = sorted(glob(pattern))
    if not matches:
        raise FileNotFoundError(f"No file matches {pattern!r}")
    if len(matches) > 1:
        raise ValueError(f"Expected exactly one file matching {pattern!r}, found: {matches}")
    return matches[0]


def load_images(data_path, split):
    """Open one split of a preprocessed caption dataset.

    Looks in `data_path` for exactly one `{split}*.hdf5`, one
    `{split}_CAPTIONS*.json` and one `{split}_CAPLENS*.json` file.

    Args:
        data_path: directory containing the split files.
        split: filename prefix of the split, e.g. "TRAIN".

    Returns:
        (h, captions, caplens): `h` is an `h5py.File` opened read-only and left
        open so its datasets can be read lazily; the caller owns it and should
        close it when done. `captions` and `caplens` are the parsed JSON contents.

    Raises:
        FileNotFoundError: a required file is missing.
        ValueError: a pattern matches more than one file.
    """
    from glob import escape

    # Escape the directory so characters such as '[' in the path are not glob syntax.
    base = escape(data_path)
    image_hdf5 = _find_single_file(os.path.join(base, f"{split}*.hdf5"))
    image_captions = _find_single_file(os.path.join(base, f"{split}_CAPTIONS*.json"))
    image_caplens = _find_single_file(os.path.join(base, f"{split}_CAPLENS*.json"))

    # Load the JSON files before opening the HDF5 file, so a failure here
    # cannot leave an unreferenced open HDF5 handle behind.
    with open(image_captions, "r") as f:
        captions = json.load(f)
    with open(image_caplens, "r") as f:
        caplens = json.load(f)
    h = h5py.File(image_hdf5, 'r')
    return h, captions, caplens

In [29]:
# Open the source split: an HDF5 handle (kept open, read lazily below) plus the flat caption / caption-length lists.
dataset, captions, caplens = load_images(data_path=source_data, split=split)

In [13]:
# Inspect the HDF5 layout: dataset names first, then file-level attribute names.
for keys_view in (dataset.keys(), dataset.attrs.keys()):
    print(keys_view)

<KeysViewHDF5 ['images']>
<KeysViewHDF5 ['captions_per_image']>


In [53]:
# Group the flat caption / caption-length lists into one list per image.
# The flat lists are laid out image by image in consecutive runs of
# `caps_per_img` entries: image k owns flat entries k*5 .. k*5+4.
caps_per_img = 5
assert len(captions) == len(caplens)
# A partial trailing group would otherwise be dropped silently.
assert len(captions) % caps_per_img == 0, (len(captions), caps_per_img)
packed_captions = [captions[k:k + caps_per_img] for k in range(0, len(captions), caps_per_img)]
packed_caplens = [caplens[k:k + caps_per_img] for k in range(0, len(caplens), caps_per_img)]

In [71]:
# Sanity check: packing kept every entry, added none, and put flat entry i
# at packed[i // 5][i % 5], for both the captions and their lengths.
assert len(packed_captions) * 5 == len(captions) == len(caplens)
assert len(packed_caplens) == len(packed_captions)
for i, (cap_gt, len_gt) in enumerate(zip(captions, caplens)):
    img_idx, cap_idx = divmod(i, 5)
    assert packed_captions[img_idx][cap_idx] == cap_gt, i
    assert packed_caplens[img_idx][cap_idx] == len_gt, i


In [76]:
# Pick which source images go into the limited subset (sampled without replacement).
# Derived from the data and `limited_ratio` instead of hard-coded 3420 / 342,
# and seeded so Restart & Run All reproduces the same subset.
SUBSET_SEED = 0
n_images = dataset["images"].shape[0]
# Each image must have exactly one packed caption group, or ids would misalign.
assert n_images == len(packed_captions), (n_images, len(packed_captions))
# Integer arithmetic avoids float truncation (e.g. int(0.29 * 100) == 28).
n_keep = int(n_images * limited_ratio // 100)
ids = np.random.default_rng(SUBSET_SEED).choice(n_images, size=n_keep, replace=False)

In [80]:
origin_len = dataset["images"].shape[0]
# Assumes square images stored as (N, 3, H, W), the layout the output dataset is created with below — TODO confirm for other sources.
resolution = dataset["images"].shape[-1]
# Read from the source file instead of hard-coding 5, so the output attr matches the source.
captions_per_image = int(dataset.attrs["captions_per_image"])
# Integer arithmetic avoids float truncation, e.g. int((29 / 100) * 100) == 28.
current_len = int(origin_len * limited_ratio // 100)
# The sampled ids must cover exactly the number of images we are about to write.
assert len(ids) == current_len, (len(ids), current_len)

In [82]:
# Create the output directory. exist_ok avoids the check-then-create race and
# makes the cell safe to re-run.
os.makedirs(target_dir, exist_ok=True)


In [86]:
images_path = os.path.join(target_dir, split + '_IMAGES_' + base_filename + '.hdf5')
captions_path = os.path.join(target_dir, split + '_CAPTIONS_' + base_filename + '.json')
caplens_path = os.path.join(target_dir, split + '_CAPLENS_' + base_filename + '.json')

# Mode 'w' (not 'a'): re-running the cell overwrites the previous output instead of
# failing in create_dataset because 'images' already exists.
with h5py.File(images_path, 'w') as h:
    # Make a note of the number of captions we are sampling per image
    h.attrs['captions_per_image'] = captions_per_image

    # Create dataset inside HDF5 file to store images
    images = h.create_dataset('images', (current_len, 3, resolution, resolution), dtype='uint8')

    print("\nReading %s images and captions, storing to file...\n" % split)

    src_images = dataset['images']
    # Separate output names so this cell does not clobber the global `caplens`
    # loaded from the source split (which earlier cells depend on).
    enc_captions = []
    enc_caplens = []

    for i in range(current_len):
        src_idx = ids[i]
        packed_caption = packed_captions[src_idx]
        packed_caplen = packed_caplens[src_idx]

        # Sanity check: every selected image carries exactly captions_per_image captions.
        assert len(packed_caption) == len(packed_caplen) == captions_per_image, src_idx

        # Save image to HDF5 file; its captions keep the same order in the JSON outputs.
        images[i] = src_images[src_idx]
        enc_captions.extend(packed_caption)
        enc_caplens.extend(packed_caplen)

    # Sanity check: number of images x captions per image == total number of captions
    assert images.shape[0] * captions_per_image == len(enc_captions) == len(enc_caplens)

# Save encoded captions and their lengths to JSON files
with open(captions_path, 'w') as j:
    json.dump(enc_captions, j)

with open(caplens_path, 'w') as j:
    json.dump(enc_caplens, j)


Reading TRAIN images and captions, storing to file...

[102, 1, 99, 83, 50, 83, 23, 24, 84, 25, 41, 77, 71, 9, 10, 35, 10, 27, 18, 80, 28, 40, 66, 57, 58, 64, 65, 17, 18, 39, 28, 66, 82, 91, 63, 64, 65, 60, 4, 21, 66, 4, 5, 50, 13, 91, 17, 4, 5, 21, 25, 41, 33, 65, 4, 5, 21, 7, 62, 87, 93, 16, 86, 98, 80, 97, 68, 2, 79, 43, 44, 45, 84, 25, 3, 11, 30, 5, 13, 57, 14, 85, 15, 93, 11, 4, 5, 13, 25, 26, 27, 12, 6, 22, 12, 6, 34, 57, 58, 26, 71, 81, 70, 29, 48, 49, 50, 49, 50, 49, 50, 75, 103, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[102, 1, 99, 50, 95, 99, 50, 90, 50, 90, 50, 83, 50, 23, 84, 25, 46, 57, 58, 41, 33, 64, 65, 17, 18, 39, 80, 39, 40, 66, 82, 93, 64, 65, 4, 12, 5, 13, 8, 58, 59, 4, 5, 21, 7, 41, 33, 65, 60, 85, 46, 87, 81, 87, 93, 86, 98, 80, 28, 29, 22, 68, 5, 2, 79, 43, 44, 85, 3, 11, 30, 5, 13, 14, 15, 93, 17, 4, 5, 13, 25, 26, 27, 4, 12, 34, 57, 58, 26, 42, 71, 81, 38, 70, 45, 5, 21, 48, 73, 75, 10