# Image Captioning

Observing that people who are blind have relied on (human-based) image captioning services to learn about images they take for nearly a decade, we introduce the first image captioning dataset to represent this real use case. This new dataset, which we call VizWiz-Captions, consists of 39,181 images originating from people who are blind that are each paired with 5 captions. Our proposed challenge addresses the task of predicting a suitable caption given an image. Ultimately, we hope this work will educate more people about the technological needs of blind people while providing an exciting new opportunity for researchers to develop assistive technologies that eliminate accessibility barriers for blind people (https://vizwiz.org/tasks-and-datasets/image-captioning/).

The goal of this challenge is to build a single model, similar to the Show and Tell model (https://arxiv.org/pdf/1411.4555.pdf), that achieves reasonable results on this task.

In [None]:
from vizwiz_api.vizwiz import VizWiz
from vizwiz_eval_cap.eval import VizWizEvalCap
import matplotlib.pyplot as plt
from PIL import Image
import skimage.io as io
import pylab
import numpy as np

import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms

import json
from jsonpath_ng import jsonpath, parse
from json import encoder
# Intended to write floats with 3 decimals when dumping JSON.
# NOTE(review): Python 3's json encoder appears not to consult FLOAT_REPR, so this
# is likely a no-op - confirm before relying on it for result files.
encoder.FLOAT_REPR = lambda o: format(o, '.3f')

In [None]:
# install spacy: !pip install spacy
# install 'en_core_web_sm': !python -m spacy download en_core_web_sm

In [None]:
 def get_alloc_dicts(set_name, vizwiz=None):
    # be sure if `vizwiz` is set, that it contains the `set_name` dataset
    if (set_name != 'train') and (set_name != 'val') and (set_name != 'test'):
        raise Exception('only "train", "val" or "test" is a valid `set_name`')
    
    if not isinstance(vizwiz, VizWiz):
        ann_path = './annotations/'+set_name+'.json'
        vizwiz = VizWiz(ann_path, ignore_rejected=True, ignore_precanned=True)
    
    img_path_prefix = './images/'+set_name+'/'
    img_ids_anns = np.unique([vizwiz.anns[i]['image_id'] for i in vizwiz.anns])
    img_ids_imgs = np.unique([vizwiz.imgs[i]['id'] for i in vizwiz.imgs])
    img_ids_with_capitions = np.array([_id for _id in img_ids_imgs if _id in img_ids_anns])
    imgIdx_enumIdx = {imgIdx:idx for idx, imgIdx in enumerate(img_ids_with_capitions)}
    
    imgIdx_imgPath = {vizwiz.imgs[i]['id']:img_path_prefix+vizwiz.imgs[i]['file_name'] for i in vizwiz.imgs if vizwiz.imgs[i]['id'] in img_ids_with_capitions}
    capIdx_imgIdx = {vizwiz.anns[i]['id']:vizwiz.anns[i]['image_id'] for i in vizwiz.anns}
    enumIdx_capIdx = {idx:vizwiz.anns[i]['id'] for idx, i in enumerate(vizwiz.anns)}
    capIdx_cap = {vizwiz.anns[i]['id']:vizwiz.anns[i]['caption'] for i in vizwiz.anns}
    
    def get_X_idx(idx):
        capIdx = enumIdx_capIdx[idx]
        imgIdx = capIdx_imgIdx[capIdx]
        X_idx = imgIdx_enumIdx[imgIdx]
        return X_idx
        
    return imgIdx_imgPath, capIdx_imgIdx, enumIdx_capIdx, capIdx_cap, get_X_idx

In [None]:
# Parse each annotation file once and reuse the VizWiz object for the lookup tables.
ann_train = './annotations/train.json'
vizwiz_train = VizWiz(ann_train, ignore_rejected=True, ignore_precanned=True)
(imgIdx_imgPath_train, capIdx_imgIdx_train, enumIdx_capIdx_train,
 capIdx_cap_train, get_X_idx_train) = get_alloc_dicts('train', vizwiz=vizwiz_train)

ann_val = './annotations/val.json'
vizwiz_val = VizWiz(ann_val, ignore_rejected=True, ignore_precanned=True)
(imgIdx_imgPath_val, capIdx_imgIdx_val, enumIdx_capIdx_val,
 capIdx_cap_val, get_X_idx_val) = get_alloc_dicts('val', vizwiz=vizwiz_val)

In [None]:
# Print the first training caption and display the image it belongs to.
cap1_idx = enumIdx_capIdx_train[0]
cap1 = capIdx_cap_train[cap1_idx]
print('caption:', cap1)

img1_idx = capIdx_imgIdx_train[cap1_idx]
img1_path = imgIdx_imgPath_train[img1_idx]
img1 = Image.open(img1_path)
_ = plt.figure(figsize=(7, 5))
_ = plt.imshow(img1)

In [None]:
# Every image is resized to this (height, width) before it is stored in memory.
resize_shape = 128, 128
resize_step = transforms.Resize(resize_shape)
resizer = transforms.Compose([resize_step])

# Preview the resized version of the first sample.
img1_resized = resizer(img1)
_ = plt.figure(figsize=(7, 5))
_ = plt.imshow(img1_resized)

In [None]:
def load_train_channel_means_and_sigmas():
    """Read the stored per-channel training means and standard deviations.

    The files are looked up under './images/' with a name suffix built from the
    global ``resize_shape`` (e.g. 'train_means_128x128.npy'), matching the files
    written by ``calc_train_channel_means_and_sigmas``.

    Returns:
        ``(train_channel_means, train_channels_sigmas)`` as loaded by ``np.load``.
    """
    suffix = str(resize_shape[0]) + 'x' + str(resize_shape[1]) + '.npy'
    with open('./images/train_means_' + suffix, 'rb') as means_file:
        means = np.load(means_file)
    with open('./images/train_sigmas_' + suffix, 'rb') as sigmas_file:
        sigmas = np.load(sigmas_file)
    return means, sigmas

def load_imgs(imgIdx_imgPath, resizer, standardized=False):
    """Load, resize and optionally standardize all images into one array.

    Loading all images takes a while and a lot of RAM, so the output array is
    preallocated once and filled image by image.

    Args:
        imgIdx_imgPath: image id -> image file path. Row k of the result is the
            k-th path in this mapping's iteration order.
        resizer: callable applied to every (RGB) PIL image, e.g. the torchvision
            ``transforms.Compose`` defined above.
        standardized: if True, every image is standardized channel-wise with the
            stored training means/sigmas and cast to float32.

    Returns:
        Array of shape (n_images, H, W, 3), where H and W come from ``resizer``.
        It is uint8 when not standardized and float32 when standardized. An
        empty mapping yields an empty 1-D array (float32 if standardized).
    """
    if standardized:
        train_channel_means, train_channels_sigmas = load_train_channel_means_and_sigmas()

    n_imgs = len(imgIdx_imgPath)
    imgs_tensor = None
    for row, img_idx in enumerate(imgIdx_imgPath):
        # `with` closes the file handle. convert('RGB') guarantees 3 channels, so
        # grayscale/RGBA/CMYK images can be stacked and standardized per channel.
        with Image.open(imgIdx_imgPath[img_idx]) as img:
            img_resized = np.asarray(resizer(img.convert('RGB')))
        if standardized:  # element-wise standardization to avoid RAM issues
            img_resized = ((img_resized - train_channel_means) / train_channels_sigmas).astype(np.float32)
        if imgs_tensor is None:
            # Preallocate the result once. Collecting a dict and copying it with
            # np.array afterwards briefly doubled peak memory.
            imgs_tensor = np.empty((n_imgs,) + img_resized.shape, dtype=img_resized.dtype)
        imgs_tensor[row] = img_resized

    if imgs_tensor is None:
        return np.array([], dtype=np.float32 if standardized else None)
    return imgs_tensor

# Training images, resized and standardized with the stored training statistics.
X_train = load_imgs(imgIdx_imgPath_train, resizer=resizer, standardized=True)
#X_val = load_imgs(imgIdx_imgPath_val, resizer, standardized=True)

In [None]:
def calc_train_channel_means_and_sigmas(X_train):
    """Compute and store the per-channel means and standard deviations of X_train.

    Statistics are taken over axes 0-2, i.e. per channel of the channels-last
    image array built by ``load_imgs``. The squared deviations are accumulated in
    chunks of 1000 images so only one chunk is promoted to float64 at a time.
    Results are saved to './images/train_means_<h>x<w>.npy' and
    './images/train_sigmas_<h>x<w>.npy', using the global ``resize_shape``.

    Returns:
        ``(train_channel_means, train_channels_sigmas)``, the population
        (ddof=0) statistics per channel.
    """
    chunk_size = 1000
    means = X_train.mean(axis=(0, 1, 2))

    n_imgs, height, width = X_train.shape[0], X_train.shape[1], X_train.shape[2]
    squared_dev_sum = 0
    for start in range(0, n_imgs, chunk_size):
        chunk = X_train[start:start + chunk_size]
        squared_dev_sum += np.sum((chunk - means) ** 2, axis=(0, 1, 2))
    sigmas = np.sqrt(squared_dev_sum / (n_imgs * height * width))

    suffix = str(resize_shape[0]) + 'x' + str(resize_shape[1]) + '.npy'
    with open('./images/train_means_' + suffix, 'wb') as means_file:
        np.save(means_file, means)
    with open('./images/train_sigmas_' + suffix, 'wb') as sigmas_file:
        np.save(sigmas_file, sigmas)
    return means, sigmas
        
# Uncomment to (re)compute and store the per-channel training means and standard deviations
#X_train = load_imgs(imgIdx_imgPath_train, resizer, standardized=False)
#calc_train_channel_means_and_sigmas(X_train)

In [None]:
# FastText word embedding!!!

In [None]:
from torchtext.data import get_tokenizer
from collections import Counter
from torchtext.vocab import Vocab

def build_vocab(capIdx_cap_train, size=10000):
    """Create the spaCy tokenizer and a vocabulary of the most frequent caption tokens.

    Each caption is lower-cased and one trailing '.' is removed before it is
    tokenized with the 'en_core_web_sm' spaCy model.

    Args:
        capIdx_cap_train: mapping caption id -> caption string (training split).
        size: number of most frequent tokens to keep (passed to ``most_common``).

    Returns:
        ``(tokenizer, vocab)``. ``vocab`` is built with the single special
        '<eos>' plus at most ``size`` tokens. No '<unk>' special is registered.
        The rest of this file expects out-of-vocabulary lookups to return None
        (presumably torchtext's legacy ``Vocab`` behaviour; confirm for the
        installed torchtext version, since ``Vocab(counter, ...)`` changed in 0.10+).
    """
    tokenizer = get_tokenizer('spacy', language='en_core_web_sm')
    counter = Counter()
    for i in capIdx_cap_train:
        sentence = capIdx_cap_train[i]
        # endswith() instead of sentence[-1]: an empty caption no longer raises IndexError.
        if sentence.endswith('.'):
            sentence = sentence[:-1]
        counter.update(tokenizer(sentence.lower()))

    # Keep only the `size` most frequent tokens, preserving insertion order.
    top_tokens = {token for token, _ in counter.most_common(size)}
    counter = Counter({token: count for token, count in counter.items() if token in top_tokens})

    vocab = Vocab(counter, specials=['<eos>'])

    return tokenizer, vocab

# The vocabulary is built from the training captions only.
tokenizer, vocab = build_vocab(capIdx_cap_train=capIdx_cap_train, size=10000)

In [None]:
def get_sentence_lengths_quantiles(capIdx_cap, tokenizer, vocab):
    """Return quantiles of caption lengths, counted in kept vocabulary indices.

    Preprocessing matches ``process_sentences``: lower-case, drop one trailing
    '.', tokenize, map through ``vocab`` and drop falsy indices (None for
    out-of-vocabulary tokens, and also 0).

    Args:
        capIdx_cap: mapping caption id -> caption string.
        tokenizer: callable turning a string into a list of tokens.
        vocab: object supporting ``vocab[token]`` lookups.

    Returns:
        numpy array with the 50/75/90/95/99/99.5/100 % length quantiles.
    """
    lengths = []
    for i in capIdx_cap:
        sentence = capIdx_cap[i]
        # endswith() instead of sentence[-1]: an empty caption no longer raises IndexError.
        if sentence.endswith('.'):
            sentence = sentence[:-1]
        token_ids = [vocab[token] for token in tokenizer(sentence.lower())]
        lengths.append(sum(1 for token_id in token_ids if token_id))
    return np.quantile(lengths, [.5, .75, .9, .95, .99, .995, 1])

# Uncomment to show quantiles of the caption lengths (in kept vocabulary tokens)
#get_sentence_lengths_quantiles(capIdx_cap_train, tokenizer, vocab)

In [None]:
# Captions are truncated to this many vocabulary tokens; process_sentences then
# pads each one with '<eos>' to max_sentence_length + 1 entries.
# Chosen from the output of get_sentence_lengths_quantiles on the training captions.
max_sentence_length = 29 # perform get_sentence_lengths_quantiles to evaluate this number

In [None]:
def process_sentences(capIdx_cap, tokenizer, vocab, max_sentence_length):
    """Encode captions as fixed-length tensors of vocabulary indices.

    Each caption is lower-cased, stripped of one trailing '.', tokenized and
    mapped through ``vocab``. Falsy indices are dropped: None for
    out-of-vocabulary tokens, and also 0. The result is cut to
    ``max_sentence_length`` indices and right-padded with ``vocab['<eos>']``.

    Args:
        capIdx_cap: mapping caption id -> caption string. Output order follows
            its iteration order.
        tokenizer: callable turning a string into a list of tokens.
        vocab: object supporting ``vocab[token]`` lookups, including '<eos>'.
        max_sentence_length: maximum number of kept tokens per caption.

    Returns:
        list of 1-D integer tensors, each of length ``max_sentence_length + 1``,
        so every caption ends with at least one '<eos>'.
    """
    eos_idx = vocab['<eos>']
    y = []
    for i in capIdx_cap:
        sentence = capIdx_cap[i]
        # endswith() instead of sentence[-1]: an empty caption no longer raises IndexError.
        if sentence.endswith('.'):
            sentence = sentence[:-1]
        token_ids = [vocab[token] for token in tokenizer(sentence.lower())]
        token_ids = [token_id for token_id in token_ids if token_id][:max_sentence_length]
        n_pad = max_sentence_length - len(token_ids) + 1  # +1: at least one '<eos>'
        y.append(torch.tensor(token_ids + [eos_idx] * n_pad))

    return y
    
# Both splits are encoded with the tokenizer and vocabulary built from the training split.
y_train = process_sentences(capIdx_cap_train, tokenizer=tokenizer, vocab=vocab,
                            max_sentence_length=max_sentence_length)
y_val = process_sentences(capIdx_cap_val, tokenizer=tokenizer, vocab=vocab,
                          max_sentence_length=max_sentence_length)

## Model Setup

In [None]:
class ImageCaptioning(Dataset):
    """Map-style dataset yielding one (image, caption) pair per caption.

    Args:
        X: numpy array of all images, converted to a tensor without copying.
        y: sequence of encoded captions, indexed by caption position.
        get_X_idx: callable mapping a caption position to the row of its image in X.

    NOTE(review): images from load_imgs are channels-last (H, W, C), while
    nn.Conv2d expects channels-first; callers presumably need to permute.
    """

    def __init__(self, X, y, get_X_idx):
        self.X, self.y, self.get_X_idx = torch.from_numpy(X), y, get_X_idx

    def __len__(self):
        # One sample per caption; several captions share the same image.
        return len(self.y)

    def __getitem__(self, idx):
        return self.X[self.get_X_idx(idx)], self.y[idx]

In [None]:
# Wrap the preprocessed training data; batches are reshuffled every epoch.
data_train = ImageCaptioning(X=X_train, y=y_train, get_X_idx=get_X_idx_train)
#data_val = ImageCaptioning(X_val, y_val, get_X_idx_val)

train_loader = DataLoader(dataset=data_train, batch_size=64, shuffle=True)
#val_loader = DataLoader(data_val, batch_size=64, shuffle=True)

def out_size(n, p, f, s):
    """Return the spatial output size of a convolution or pooling layer.

    Args:
        n: input height/width.
        p: padding on each side.
        f: filter/kernel size.
        s: stride.

    Uses floor division, as nn.Conv2d / nn.MaxPool2d do with dilation 1 and the
    default ceil_mode=False. The previous ``/`` returned fractional sizes such as
    29.5 for out_size(59, 0, 2, 2).
    """
    return (n + 2 * p - f) // s + 1

# Prefer the GPU when one is visible to PyTorch.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print('Using {} device'.format(device))

In [None]:
class ImgCapNet(nn.Module):
    """Image-captioning network (work in progress).

    Only the convolutional image encoder is wired into ``forward`` so far. The
    word embedding, LSTM and output layers are created but not used yet.

    Args:
        vocab_size: number of word indices for the embedding and output layers.
            Defaults to 10000 for backward compatibility. Note that
            ``build_vocab(..., size=10000)`` also registers '<eos>', so the
            vocabulary has presumably 10001 entries; pass ``len(vocab)``.
    """

    def __init__(self, vocab_size=10000):
        super().__init__()
        # Encoder; sizes for a 3 x 128 x 128 input (see out_size):
        self.conv1 = nn.Conv2d(3, 8, 5)    # 124 x 124 x 8  -> pool -> 62 x 62 x 8
        self.pool = nn.MaxPool2d(2)        # height & width halved (floored), depth unchanged
        self.conv2 = nn.Conv2d(8, 12, 4)   # 59 x 59 x 12   -> pool -> 29 x 29 x 12
        self.conv3 = nn.Conv2d(12, 20, 3)  # 27 x 27 x 20   -> pool -> 13 x 13 x 20
        self.linear = nn.Linear(13 * 13 * 20, 128)  # was 14*14*20, which did not match

        # Decoder (not used in forward yet).
        self.embedding = nn.Embedding(vocab_size, 128)  # input is given by a list of indices
        # NOTE(review): LSTM input size 300 != 128-d embedding (+128-d image features = 256);
        # presumably meant for 300-d FastText vectors - TODO confirm.
        self.lstm = nn.LSTM(300, 300-128)
        self.output = nn.Linear(300-128, vocab_size)
        self.softmax = nn.Softmax(dim=1)
        self.argmax = torch.argmax  # `nn.argmax` does not exist and raised AttributeError
        # CrossEntropyLoss applies log-softmax itself: feed it raw logits, not softmax output.
        self.loss = nn.CrossEntropyLoss()

    def forward(self, x):
        """Encode a batch of images into 128-d feature vectors.

        Args:
            x: float tensor of shape (batch, 3, 128, 128), channels first as
                nn.Conv2d expects. NOTE(review): the image arrays built by
                load_imgs are channels-last, so presumably ``x.permute(0, 3, 1, 2)``
                is needed first.

        Returns:
            Tensor of shape (batch, 128): a linear projection of the CNN features.
        """
        # TODO: decoder (embedding + LSTM + output) still to do.
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        return self.linear(x)

26.05.2021: 2h --> Sprechstunde, Challenge & Beschreibung anschauen<br>
28.05.2021: 1h --> Download und Integration der Daten und API<br>
01.06.2021: 5h 30min --> Mit der API vertraut machen, Daten laden<br>
02.06.2021: 6h 45min --> Standardisierung der Channels, Laden der Bilder<br>
03.06.2021: 7h 15min --> Erstellen des Vokabulars<br>
04.06.2021: 5h 30min --> Vokabular und Preprocessing der Captions fertiggestellt<br>
05.06.2021: 4h --> Bugfixing & `get_X_idx`-Validierung, ImgCapNet \_\_init\_\_ erstellt<br>

Sum: 32h 00min