In [1]:
import os

import h5py
import json
import random

import imageio
from PIL import Image
from skimage.transform import resize

import torch
import numpy as np

from tqdm import tqdm
from collections import Counter

# HDF5, and JSON with Encoded Captions, Caption Lengths

In [2]:
class PrepareCOCOData():
    def __init__(self, json_path, image_dir, output_dir, word_count_thresh=5, max_length=16, crop=False):
        """
        NOTE: This class creates COCO input data for image captioning.
        json_path: str, path to json file which has data splits and captions. 
            NOTE: Andrej Karpathy created this file for COCO 2014 dataset. Source: 
            http://cs.stanford.edu/people/karpathy/deepimagesent/caption_datasets.zip
        image_dir: str, directory path to train/val/test images
        output_dir: str, directory path to save model input files
        word_count_thresh: int, words occuring less frequently than this threshold are mapped 
            as <UNK>
        max_length: int, captions longer than this are clipped.
        """
        self.json_path = json_path
        self.image_dir = image_dir
        self.output_dir = output_dir
        self.word_count_thresh = word_count_thresh
        self.max_length = max_length
        self.crop = crop
        self.num_captions = 5
        self.base_filename = 'COCO_' + str(self.word_count_thresh) + '_WordCountThresh'
        
        # For storing word frequency
        self.word_frequency = Counter()
    
    def read_json(self):
        with open(self.json_path, 'r') as f:
            self.json_data = json.load(f)
            
    def read_reshape_image(self, img_path):
        # Read images
        img = imageio.imread(img_path)

        # If image is gray scale then add channels
        if len(img.shape) == 2:
            img = img[:, :, np.newaxis]
            img = np.concatenate([img, img, img], axis=2)
            
        # Resize image and return it
        img = resize(img, (224, 224), mode='constant', anti_aliasing=True)
        img = img.transpose(2, 0, 1)  # PyTorch: [C, W, H]
        return img
    
    def read_reshape_crop_image(self, img_path):
        # Read images
        img = Image.open(img_path)
        
        # Get cropping dimensions
        width, height = img.size
        if width > height:
            left = (width - height) / 2
            right = width - left
            top = 0
            bottom = height
        else:
            top = (height - width) / 2
            bottom = height - top
            left = 0
            right = width
            
        # Crop, resize and normalize image
        img = img.crop((left, top, right, bottom))
        img = img.resize([224, 224], Image.ANTIALIAS)
        img = np.array(img)/255.0
        
        # If image is gray scale then add channels
        if len(img.shape) == 2:
            img = img[:, :, np.newaxis]
            img = np.concatenate([img, img, img], axis=2)
        
        img = img.transpose(2, 0, 1)  # PyTorch: [C, W, H]
        return img
    
    def write_word_to_idx(self):
        words = [w for w, n in self.word_frequency.items() if n > self.word_count_thresh]
        self.word2idx = {w: idx + 1 for idx, w in enumerate(words)}
        self.word2idx['<UNK>'] = len(self.word2idx) + 1
        self.word2idx['<START>'] = len(self.word2idx) + 1
        self.word2idx['<END>'] = len(self.word2idx) + 1
        self.word2idx['<PAD>'] = 0
        
        # Write word to index mapping to a json
        with open(os.path.join(self.output_dir, 'WORD2IDX_' + self.base_filename + '.json'), 'w') as f:
            json.dump(self.word2idx, f)
    
    def process_img_captions(self, captions):
        temp_caps = []
        temp_lens = []
        
        # Encode captions and compute lengths (used by RNN in forward pass)
        for caption in captions:
            start = [self.word2idx['<START>']]
            middle = [self.word2idx.get(word, self.word2idx['<UNK>']) for word in caption]
            end = [self.word2idx['<END>']]
            pad = [self.word2idx['<PAD>']] * (self.max_length - len(caption))
            encoded_caption = start + middle + end + pad
            temp_caps.append(encoded_caption)
            temp_lens.append(len(caption) + 2) # +2 for <START> and <END> and discard <PAD> counts
        return temp_caps, temp_lens
     
    def get_image_paths_and_captions(self):
        # Lists to store image paths and captions
        self.train_image_paths = []
        self.train_image_captions = []
        self.val_image_paths = []
        self.val_image_captions = []
        
        # Read json
        self.read_json()
        
        # Extract information from json and populate lists
        for img in self.json_data['images']:
            captions = []
            for s in img['sentences']:
                
                # Update word frequency
                self.word_frequency.update(s['tokens'])
                
                # Select captions if their length is within max length threshold
                if len(s['tokens']) <= self.max_length:  # 96% of captions have max length 16
                    captions.append(s['tokens'])
                    
            if len(captions) == 0:
                continue
                
            # Generate image path
            img_path = os.path.join(self.image_dir, img['filepath'], img['filename'])
            
            # Populate lists
            if img['split'] in ['train', 'restval']:
                self.train_image_paths.append(img_path)
                self.train_image_captions.append(captions)
            elif img['split'] in ['val']:
                self.val_image_paths.append(img_path)
                self.val_image_captions.append(captions)
            elif img['split'] in ['test']:
                continue  # Not interested in test for now
                
    def process_and_write(self):
        """
        Sample captions for each image, resize image and save images to HDF5 file, and captions
        and their lengths to JSON files.
        """
        split_sets = [(self.train_image_paths, self.train_image_captions, 'TRAIN'),
                      (self.val_image_paths, self.val_image_captions, 'VAL')]
        
        for img_paths, img_caps, split in split_sets:
            PATH = os.path.join(self.output_dir, split + '_IMAGES_' + self.base_filename + '.hdf5')
            print('Processing: {} data'.format(split))
            
            with h5py.File(PATH, 'a') as hf:
                
                # Create dataset inside HDF5 file to store images
                images = hf.create_dataset('images', (len(img_paths), 3, 224, 224), dtype='float')
                
                # List to store encoded captions (Word to index) and caption lengths
                encoded_captions = []
                
                for i, path in enumerate(tqdm(img_paths)):
                    
                    # Sample captions ()
                    if len(img_caps[i]) < self.num_captions:
                        captions = img_caps[i] + [random.choice(img_caps[i]) for _ in 
                                                  range(self.num_captions - len(img_caps[i]))]
                    else:
                        captions = random.sample(img_caps[i], k=self.num_captions)
                    
                    # Save images to HDF5
                    if self.crop:
                        images[i] = self.read_reshape_crop_image(path)
                    else:
                        images[i] = self.read_reshape_image(path)
                    
                    # Process captions
                    temp = []
                    temp_caps, temp_lens = self.process_img_captions(captions)
                    temp = [temp_caps, temp_lens]
                    encoded_captions.append(temp)
                    
                # Save encoded captions and their lengths to JSON files
                PATH = os.path.join(self.output_dir, split + '_CAPTIONS_' + self.base_filename + '.json')
                with open(PATH, 'w') as cf:
                    json.dump(encoded_captions, cf)
                
    def prepare(self):
        
        # Read json
        self.read_json()
        print('--- Done: Read JSON File ---')
        
        # Get image paths and captions
        self.get_image_paths_and_captions()
        print('--- Done: Extracted Image Paths and Captions ---')
        
        # Write word to index mapping
        self.write_word_to_idx()
        print('--- Done: Wrote Word-to-Index JSON ---')
            
        # Process data and write HDF5 and other files
        self.process_and_write()
        print('--- Done: Wrote HDF5 and JSON ---')

In [3]:
# Prepare data for training (NOTE: This takes 45 mins to 1 hour)
json_path = './dataset_coco.json' # Andrej Karpathy Splits
image_dir = './dataset/'
output_dir = './data/'
coco = PrepareCOCOData(json_path, image_dir, output_dir, crop=False)

coco.prepare()

--- Done: Read JSON File ---


  0%|          | 3/113287 [00:00<1:20:09, 23.56it/s]

--- Done: Extracted Image Paths and Captions ---
--- Done: Wrote Word-to-Index JSON ---
Processing: TRAIN data


100%|██████████| 5000/5000 [02:49<00:00, 29.54it/s]0it/s]


--- Done: Wrote HDF5 and JSON ---


# JSON with Image Path, Encoded Captions, Caption Lengths (Takes Forever)

In [None]:
class PrepareDataCOCO(object):
    def __init__(self, json_path, image_dir, output_dir, word_count_thresh=5, max_length=16):
        """
        NOTE: This class creates COCO input data for image captioning.
        json_path: str, path to json file which has data splits and captions. 
            NOTE: Andrej Karpathy created this file for COCO 2014 dataset. Source: 
            http://cs.stanford.edu/people/karpathy/deepimagesent/caption_datasets.zip
        image_dir: str, directory path to train/val/test images
        output_dir: str, directory path to save model input files
        word_count_thresh: int, words occuring less frequently than this threshold are mapped 
            as <UNK>
        max_length: int, captions longer than this are clipped.
        """
        self.json_path = json_path
        self.image_dir = image_dir
        self.output_dir = output_dir
        self.word_count_thresh = word_count_thresh
        self.max_length = max_length
        self.num_captions = 5
        self.base_filename = 'COCO_' + str(self.word_count_thresh) + '_WordCountThresh'
        
        # For storing word frequency
        self.word_frequency = Counter()
        
    def read_json(self):
        with open(self.json_path, 'r') as f:
            self.json_data = json.load(f)
            
    def get_image_paths_and_captions(self):
        # Lists to store image paths and captions
        self.train_image_paths = []
        self.train_image_captions = []
        self.val_image_paths = []
        self.val_image_captions = []
        
        # Read json
        self.read_json()
        
        # Extract information from json and populate lists
        for img in self.json_data['images']:
            captions = []
            for s in img['sentences']:
                # Update word frequency
                self.word_frequency.update(s['tokens'])
                
                # Select captions if their length is within max length threshold
                if len(s['tokens']) <= self.max_length:  # 96% of captions have max length 16
                    captions.append(s['tokens'])
                    
            if len(captions) == 0:
                continue
                
            # Generate image path
            img_path = os.path.join(self.image_dir, img['filepath'], img['filename'])
            
            # Populate lists
            if img['split'] in ['train', 'restval']:
                self.train_image_paths.append(img_path)
                self.train_image_captions.append(captions)
            elif img['split'] in ['val']:
                self.val_image_paths.append(img_path)
                self.val_image_captions.append(captions)
            elif img['split'] in ['test']:
                continue  # Not interested in test for now
                
    def write_word_to_idx(self):
        words = [w for w, n in self.word_frequency.items() if n > self.word_count_thresh]
        self.word2idx = {w: idx + 1 for idx, w in enumerate(words)}  # idx is 1-indexed
        self.word2idx['<UNK>'] = len(self.word2idx) + 1
        self.word2idx['<START>'] = len(self.word2idx) + 1
        self.word2idx['<END>'] = len(self.word2idx) + 1
        self.word2idx['<PAD>'] = len(self.word2idx) + 1
        
        # Write word to index mapping to a json
        with open(os.path.join(self.output_dir, 'WORD2IDX_' + self.base_filename + '.json'), 'w') as f:
            json.dump(self.word2idx, f)
            
    def process_img_captions(self, captions):
        temp_caps = []
        temp_lens = []
        # Encode captions and compute lengths (used by RNN in forward pass)
        for caption in captions:
            start = [self.word2idx['<START>']]
            middle = [self.word2idx.get(word, self.word2idx['<UNK>']) for word in caption]
            end = [self.word2idx['<END>']]
            pad = [self.word2idx['<PAD>']] * (self.max_length - len(caption))
            encoded_caption = start + middle + end + pad
            temp_caps.append(encoded_caption)
            temp_lens.append(len(caption) + 2) # +2 for <START> and <END> and discard <PAD> counts
        return temp_caps, temp_lens
            
    def process_and_write(self):
        """
        Sample captions for each image, resize image and save images to HDF5 file, and captions
        and their lengths to JSON files.
        """
        split_sets = [(self.train_image_paths, self.train_image_captions, 'TRAIN'),
                      (self.val_image_paths, self.val_image_captions, 'VAL')]
        
        for img_paths, img_caps, split in split_sets:
            data = []
            
            for i, img_path in enumerate(tqdm(img_paths)):
                
                img_cap_data = {}
                img_cap_data['image_path'] = img_path
                
                # Number of captions per image is in range [5, 7], after removing more than threshold needs to sample
                if len(img_caps[i]) < self.num_captions:
                    captions = img_caps[i] + [random.choice(img_caps[i]) for _ in range(self.num_captions - len(img_caps[i]))]        
                else:
                    captions = random.sample(img_caps[i], k=self.num_captions)
                
                # Encode captions for the current image
                temp_caps, temp_lens = self.process_img_captions(captions)
                        
                img_cap_data['captions'] = temp_caps
                img_cap_data['caption_lengths'] = temp_lens
                
                data.append(img_cap_data)
                
                # Save encoded captions and their lengths to JSON files
                PATH = os.path.join(self.output_dir, split + '_DATA_' + self.base_filename + '.json')
                with open(PATH, 'w') as cf:
                    json.dump(data, cf)
                    
    def prepare(self):
        
        # Read json
        self.read_json()
        print('--- Done: Read JSON File ---')
        
        # Get image paths and captions
        self.get_image_paths_and_captions()
        print('--- Done: Extracted Image Paths and Captions ---')
        
        # Write word to index mapping
        self.write_word_to_idx()
        print('--- Done: Wrote Word-to-Index JSON ---')
            
        # Process data and write HDF5 and other files
        self.process_and_write()
        print('--- Done: Wrote JSON ---')

In [None]:
# # Prepare data for training (NOTE: This takes few hours. TOO Slow!!!
# json_path = '/home/ankoor/caption/dataset/captions/dataset_coco.json'
# image_dir = '/home/ankoor/caption/dataset/'
# output_dir = '/home/ankoor/caption/data/'
# coco = PrepareDataCOCO(json_path, image_dir, output_dir)
# coco.prepare()

# Scratch

### JSON Format

```
{'dataset': 'coco',
 'images': [{'filename': 'img_0.jpg',
             'imgid': 0,
             'sentences': [{'imgid': 0,
                            'raw': 'Two dogs playing in show.'
                            'sentid': 0,
                            'tokens': ['two', 'dogs', 'playing', 'in', 'snow']},
                            {'imgid': 0,
                            'raw': '...'
                            'sentid': ...,
                            'tokens': [...]},
                            {'imgid': 0,
                            'raw': 'Two dogs running in show.'
                            'sentid': 4,
                            'tokens': ['two', 'dogs', 'running', 'in', 'snow']}],
             'sentids': [0, 1, 2, 3, 4],
             'split': 'train'},
             {'filename': 'img_1.jpg',
             'imgid': 1,
             'sentences': [{...}, ...],
             'sentids': [0, 1, 2, 3, 4],
             'split': 'test'},                    
             {...}]
}  
```

In [4]:
# Read JSON file
json_path = './dataset_coco.json'
with open(json_path) as f:
    json_data = json.load(f)

In [5]:
filepaths = []
splits = []
sfreq = Counter()

for i, img in enumerate(json_data['images']):
    filepaths.append(img['filepath'])
    splits.append(img['split'])
    sfreq[img['split']] = 1
    if img['split'] in sfreq.keys():
        sfreq[img['split']] += 1
    
print(set(filepaths))
print(set(splits))

sfreq.update(splits)
print(sfreq)

{'val2014', 'train2014'}
{'val', 'test', 'restval', 'train'}
Counter({'train': 82785, 'restval': 30506, 'val': 5002, 'test': 5002})


In [6]:
# Check content
for i, img in enumerate(json_data['images']):
    print(img)
    if i == 0:
        break

{'sentences': [{'tokens': ['a', 'man', 'with', 'a', 'red', 'helmet', 'on', 'a', 'small', 'moped', 'on', 'a', 'dirt', 'road'], 'imgid': 0, 'raw': 'A man with a red helmet on a small moped on a dirt road. ', 'sentid': 770337}, {'tokens': ['man', 'riding', 'a', 'motor', 'bike', 'on', 'a', 'dirt', 'road', 'on', 'the', 'countryside'], 'imgid': 0, 'raw': 'Man riding a motor bike on a dirt road on the countryside.', 'sentid': 771687}, {'tokens': ['a', 'man', 'riding', 'on', 'the', 'back', 'of', 'a', 'motorcycle'], 'imgid': 0, 'raw': 'A man riding on the back of a motorcycle.', 'sentid': 772707}, {'tokens': ['a', 'dirt', 'path', 'with', 'a', 'young', 'person', 'on', 'a', 'motor', 'bike', 'rests', 'to', 'the', 'foreground', 'of', 'a', 'verdant', 'area', 'with', 'a', 'bridge', 'and', 'a', 'background', 'of', 'cloud', 'wreathed', 'mountains'], 'imgid': 0, 'raw': 'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreat

In [7]:
# Check split types
splits = []
for img in json_data['images']:
    split = img['split']
    splits.append(split)
print(set(splits))

{'val', 'test', 'restval', 'train'}


In [8]:
# Check number of captions per image
num_captions = []
for img in json_data['images']:
    num = len(img['sentids'])
    num_captions.append(num)
print(set(num_captions))

{5, 6, 7}


In [9]:
# Check minimum/maximum caption length
min_len = []
max_len = []

for img in json_data['images']:
    temp = []
    for ss in img['sentences']:
        temp.append(len(ss['tokens']))
    min_len.append(min(temp))
    max_len.append(max(temp))

print('Min caption length: ', min(max_len))
print('Max caption length: ', max(max_len))  # 49 is large number!

Min caption length:  8
Max caption length:  49


In [10]:
# Distribution of lengths
sent_lengths = {}
for img in json_data['images']:
    for ss in img['sentences']:
        tokens = ss['tokens']
        n_tokens = len(tokens)
        sent_lengths[n_tokens] = sent_lengths.get(n_tokens, 0) + 1
        
max_length = max(sent_lengths.keys())
print('Max length of sentence in raw data: ', max_length)

sum_length = sum(sent_lengths.values())
percents = []
for i in range(max_length):
    pct = sent_lengths.get(i, 0)*100.0/sum_length
    percents.append(pct)
    print('Length: {} \tCount: {} \tPercent: {}'.format(i, sent_lengths.get(i, 0), pct))

Max length of sentence in raw data:  49
Length: 0 	Count: 0 	Percent: 0.0
Length: 1 	Count: 0 	Percent: 0.0
Length: 2 	Count: 0 	Percent: 0.0
Length: 3 	Count: 0 	Percent: 0.0
Length: 4 	Count: 0 	Percent: 0.0
Length: 5 	Count: 1 	Percent: 0.00016213578223218817
Length: 6 	Count: 14 	Percent: 0.0022699009512506343
Length: 7 	Count: 4851 	Percent: 0.7865206796083448
Length: 8 	Count: 101387 	Percent: 16.438460553174863
Length: 9 	Count: 134531 	Percent: 21.812288919478508
Length: 10 	Count: 132558 	Percent: 21.4923950211344
Length: 11 	Count: 95206 	Percent: 15.436299283197707
Length: 12 	Count: 60590 	Percent: 9.823807045448282
Length: 13 	Count: 35233 	Percent: 5.712530015386686
Length: 14 	Count: 20016 	Percent: 3.2453098171594785
Length: 15 	Count: 11476 	Percent: 1.8606702368965915
Length: 16 	Count: 6922 	Percent: 1.1223038846112066
Length: 17 	Count: 4313 	Percent: 0.6992916287674276
Length: 18 	Count: 2755 	Percent: 0.4466840800496784
Length: 19 	Count: 1913 	Percent: 0.31016575

In [11]:
# Best length
n = 15
print(sum(percents[:n+1])) 
# 96.6 percet of captions have length 15. So using this to create data

96.61071360821835


In [12]:
# Check word frequency
word_freq = Counter()
for img in json_data['images']:
    temp = []
    for ss in img['sentences']:
        word_freq.update(ss['tokens'])

        
# Some stats
print('Most common 20 words: ', word_freq.most_common(20))
print('Total words: ', sum(word_freq.values()))

Most common 20 words:  [('a', 1019785), ('on', 224758), ('of', 212689), ('the', 206178), ('in', 191793), ('with', 161216), ('and', 146755), ('is', 102390), ('man', 75957), ('to', 71183), ('sitting', 55190), ('an', 51987), ('two', 50467), ('at', 44506), ('standing', 44297), ('people', 43707), ('are', 42776), ('next', 38867), ('white', 37898), ('woman', 35372)]
Total words:  6454115


In [13]:
# Map all words that occur <= 5 times to a special UNK token [ruotianluo/ImageCaptioning.pytorch]
count_thresh = 5
bad_words = [w for w, n in word_freq.items() if n <= count_thresh]
vocab = [w for w, n in word_freq.items() if n > count_thresh]

word2idx = {w: idx + 1 for idx, w in enumerate(vocab)}
word2idx['<UNK>'] = len(word2idx) + 1
word2idx['<BEG>'] = len(word2idx) + 1
word2idx['<END>'] = len(word2idx) + 1
word2idx['<PAD>'] = len(word2idx) + 1

print('Number of bad words: ', len(bad_words))
print('Percent of bad words: {}/{} = {}'.format(len(bad_words), len(word_freq), len(bad_words)*100.0/len(word_freq)))
print('Number of words in vocabulary: ', len(vocab))

Number of bad words:  18443
Percent of bad words: 18443/27929 = 66.0353038060797
Number of words in vocabulary:  9486


In [14]:
img_caps = [['a', 'woman', 'wearing', 'a', 'net', 'on', 'her', 'head', 'cutting', 'a', 'cake'], 
            ['a', 'woman', 'cutting', 'a', 'large', 'white', 'sheet', 'cake'], 
            ['a', 'woman', 'wearing', 'a', 'hair', 'net', 'cutting', 'a', 'large', 'sheet', 'cake'], 
            ['there', 'is', 'a', 'woman', 'that', 'is', 'cutting', 'a', 'white', 'cake'], 
            ['a', 'woman', 'marking', 'a', 'cake', 'with', 'the', 'back', 'of', 'a', 'chefs', 'knife']]

num_captions = 5
max_len = 16

if len(img_caps) < num_captions:
    caps = img_caps + [random.choice(img_caps) for _ in range(num_captions - len(img_caps))]
else:
    caps = random.sample(img_caps, k=num_captions)

for cap in caps:
    print(cap)
    
# Encode captions
temp_cap = []
temp_len = []
for c in cap:
    start = [word2idx['<BEG>']]
    middle = [word2idx.get(w, word2idx['<UNK>']) for w in c]
    end = [word2idx['<END>']]
    pad = [word2idx['<PAD>']] * (max_len - len(c))
    ecap = start + middle + end + pad
    temp_cap.append(ecap)
    temp_len.append(len(c) + 2)
    
print(temp_cap)
print(temp_len)

['a', 'woman', 'wearing', 'a', 'net', 'on', 'her', 'head', 'cutting', 'a', 'cake']
['there', 'is', 'a', 'woman', 'that', 'is', 'cutting', 'a', 'white', 'cake']
['a', 'woman', 'cutting', 'a', 'large', 'white', 'sheet', 'cake']
['a', 'woman', 'marking', 'a', 'cake', 'with', 'the', 'back', 'of', 'a', 'chefs', 'knife']
['a', 'woman', 'wearing', 'a', 'hair', 'net', 'cutting', 'a', 'large', 'sheet', 'cake']
[[9488, 7186, 9489, 9490, 9490, 9490, 9490, 9490, 9490, 9490, 9490, 9490, 9490, 9490, 9490, 9490, 9490, 9490], [9488, 2633, 4674, 3962, 7186, 4342, 9489, 9490, 9490, 9490, 9490, 9490, 9490, 9490, 9490, 9490, 9490, 9490], [9488, 2633, 6628, 7186, 2224, 72, 4342, 2834, 9489, 9490, 9490, 9490, 9490, 9490, 9490, 9490, 9490, 9490], [9488, 7186, 9489, 9490, 9490, 9490, 9490, 9490, 9490, 9490, 9490, 9490, 9490, 9490, 9490, 9490, 9490, 9490], [9488, 201, 7186, 72, 2224, 9489, 9490, 9490, 9490, 9490, 9490, 9490, 9490, 9490, 9490, 9490, 9490, 9490], [9488, 4342, 6628, 1851, 9489, 9490, 9490, 9490, 

In [15]:
def process_img_captions(self, captions):
        temp_caps = []
        temp_lens = []
        # Encode captions and compute lengths (used by RNN in forward pass)
        for caption in enumerate(captions):
            start = [self.word2idx['<START>']]
            middle = [self.word2idx.get(word, self.word2idx['<UNK>']) for word in caption]
            end = [self.word2idx['<END>']]
            pad = [self.word2idx['<PAD>']] * (self.max_length - len(caption))
            encoded_caption = start + middle + end + pad
            temp_caps.append(encoded_caption)
            temp_lens.append(len(caption) + 2) # +2 for <START> and <END> and discard <PAD> counts
        return temp_caps, temp_lens

In [16]:
captions = [['a', 'woman', 'cutting', 'a', 'large', 'white', 'sheet', 'cake'], 
            ['a', 'woman', 'wearing', 'a', 'hair', 'net', 'cutting', 'a', 'large', 'sheet', 'cake'], 
            ['a', 'woman', 'marking', 'a', 'cake', 'with', 'the', 'back', 'of', 'a', 'chefs', 'knife'], 
            ['there', 'is', 'a', 'woman', 'that', 'is', 'cutting', 'a', 'white', 'cake'], 
            ['a', 'woman', 'wearing', 'a', 'net', 'on', 'her', 'head', 'cutting', 'a', 'cake']]

temp_caps = []
temp_lens = []
for caption in captions:
    start = [word2idx['<BEG>']]
    middle = [word2idx.get(word, word2idx['<UNK>']) for word in caption]
    end = [word2idx['<END>']]
    pad = [word2idx['<PAD>']] * (16 - len(caption))
    encoded_caption = start + middle + end + pad
    temp_caps.append(encoded_caption)
    temp_lens.append(len(caption) + 2)
    
print(temp_caps)
print(temp_lens)

[[9488, 7186, 1381, 8178, 7186, 9171, 9338, 2523, 7274, 9489, 9490, 9490, 9490, 9490, 9490, 9490, 9490, 9490], [9488, 7186, 1381, 8868, 7186, 7211, 643, 8178, 7186, 9171, 2523, 7274, 9489, 9490, 9490, 9490, 9490, 9490], [9488, 7186, 1381, 387, 7186, 7274, 1794, 534, 337, 6737, 7186, 1995, 1686, 9489, 9490, 9490, 9490, 9490], [9488, 7963, 7056, 7186, 1381, 4904, 7056, 8178, 7186, 9338, 7274, 9489, 9490, 9490, 9490, 9490, 9490, 9490], [9488, 7186, 1381, 8868, 7186, 643, 2251, 4183, 7196, 8178, 7186, 7274, 9489, 9490, 9490, 9490, 9490, 9490]]
[10, 13, 14, 12, 13]
