# Image Captioning

Observing that people who are blind have relied on (human-based) image captioning services to learn about images they take for nearly a decade, we introduce the first image captioning dataset to represent this real use case. This new dataset, which we call VizWiz-Captions, consists of 39,181 images originating from people who are blind that are each paired with 5 captions. Our proposed challenge addresses the task of predicting a suitable caption given an image. Ultimately, we hope this work will educate more people about the technological needs of blind people while providing an exciting new opportunity for researchers to develop assistive technologies that eliminate accessibility barriers for blind people (https://vizwiz.org/tasks-and-datasets/image-captioning/).

The goal of this challenge is to build a single model, similar to the Show and Tell model (Vinyals et al., https://arxiv.org/pdf/1411.4555.pdf), that achieves reasonable results on this task.

In [None]:
#import zipfile
#def extract_zip(filename):
#    with zipfile.ZipFile(filename, 'r') as zip_ref:
#        zip_ref.extractall('./')
#            
#extract_zip('vizwiz_eval_cap.zip')

#!pip3 install torch torchvision torchaudio
#!pip install tensorboard
#!pip install torchtext
#!pip install -U spacy
#!python -m spacy download en_core_web_sm

In [None]:
from vizwiz_api.vizwiz import VizWiz
from vizwiz_eval_cap.eval import VizWizEvalCap
import matplotlib.pyplot as plt
from PIL import Image
import skimage.io as io
import pylab
import numpy as np

import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms, models
from torchtext.data import get_tokenizer
from torchtext.vocab import Vocab
from torch.nn.utils.rnn import pack_padded_sequence
from torch.utils.tensorboard import SummaryWriter
from collections import Counter

import json, os
from jsonpath_ng import jsonpath, parse
from json import encoder
encoder.FLOAT_REPR = lambda o: format(o, '.3f')

In [None]:
 def get_alloc_dicts(set_name, vizwiz=None):
    # be sure if `vizwiz` is set, that it contains the `set_name` dataset
    if (set_name != 'train') and (set_name != 'val') and (set_name != 'test'):
        raise Exception('only "train", "val" or "test" is a valid `set_name`')
    
    if not isinstance(vizwiz, VizWiz):
        ann_path = './annotations/'+set_name+'.json'
        vizwiz = VizWiz(ann_path, ignore_rejected=True, ignore_precanned=True)
    
    img_path_prefix = './images/'+set_name+'/'
    img_ids_anns = np.unique([vizwiz.anns[i]['image_id'] for i in vizwiz.anns])
    img_ids_imgs = np.unique([vizwiz.imgs[i]['id'] for i in vizwiz.imgs])
    img_ids_with_capitions = np.array([_id for _id in img_ids_imgs if _id in img_ids_anns])
    imgIdx_enumIdx = {imgIdx:idx for idx, imgIdx in enumerate(img_ids_with_capitions)}
    
    imgIdx_imgPath = {vizwiz.imgs[i]['id']:img_path_prefix+vizwiz.imgs[i]['file_name'] for i in vizwiz.imgs if vizwiz.imgs[i]['id'] in img_ids_with_capitions}
    capIdx_imgIdx = {vizwiz.anns[i]['id']:vizwiz.anns[i]['image_id'] for i in vizwiz.anns}
    enumIdx_capIdx = {idx:vizwiz.anns[i]['id'] for idx, i in enumerate(vizwiz.anns)}
    capIdx_cap = {vizwiz.anns[i]['id']:vizwiz.anns[i]['caption'] for i in vizwiz.anns}
    
    def get_img_path(idx):
        capIdx = enumIdx_capIdx[idx]
        imgIdx = capIdx_imgIdx[capIdx]
        imgPath = imgIdx_imgPath[imgIdx]
        return imgPath
        
    return imgIdx_imgPath, capIdx_imgIdx, enumIdx_capIdx, capIdx_cap, get_img_path

In [None]:
# Load each split's annotation file once and pass the resulting VizWiz object
# to get_alloc_dicts, so it does not read the JSON a second time.
ann_train = './annotations/train.json'
vizwiz_train = VizWiz(ann_train, ignore_rejected=True, ignore_precanned=True)
(imgIdx_imgPath_train, capIdx_imgIdx_train, enumIdx_capIdx_train,
 capIdx_cap_train, get_img_path_train) = get_alloc_dicts('train', vizwiz_train)

ann_val = './annotations/val.json'
vizwiz_val = VizWiz(ann_val, ignore_rejected=True, ignore_precanned=True)
(imgIdx_imgPath_val, capIdx_imgIdx_val, enumIdx_capIdx_val,
 capIdx_cap_val, get_img_path_val) = get_alloc_dicts('val', vizwiz_val)

In [None]:
# Sanity check: print the first training caption and plot the image it belongs to.
cap1_idx = enumIdx_capIdx_train[0]
img1_idx = capIdx_imgIdx_train[cap1_idx]
cap1 = capIdx_cap_train[cap1_idx]
print('caption:', cap1)

img1 = Image.open(imgIdx_imgPath_train[img1_idx])
_ = plt.figure(figsize=(7, 5))
_ = plt.imshow(img1)

In [None]:
# Image pipeline for the encoder: squash to 256x256 (aspect ratio is not kept),
# take the central 224x224 crop, convert to a float tensor and standardize each
# RGB channel with the ImageNet statistics used by torchvision's pretrained models.
preprocess = transforms.Compose([
    transforms.Resize(size=(256, 256)),
    transforms.CenterCrop(size=224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

# Converts a (C, H, W) tensor back into a PIL image, e.g. for plotting.
to_pil = transforms.Compose([transforms.ToPILImage()])

In [None]:
# `preprocess` ends with Normalize, so `img1_processed` holds standardized values
# far outside [0, 1]. ToPILImage scales float tensors by 255 and casts them to uint8,
# which corrupts out-of-range values. Undo the normalization (constants must match
# the Normalize in `preprocess`) and clamp, so the preview shows the real
# resized and cropped image the model receives.
img1_processed = preprocess(img1)
img1_processed_pil = to_pil(
    (img1_processed * torch.tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
     + torch.tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)).clamp(0, 1)
)
_ = plt.figure(figsize=(7,5))
_ = plt.imshow(img1_processed_pil)

In [None]:
def build_vocab(capIdx_cap_train, size=10000):
    """Build the tokenizer and a frequency-capped vocabulary from training captions.

    Each caption is lower-cased, a single trailing '.' is removed, and it is
    tokenized with torchtext's ``basic_english`` tokenizer. Only the ``size - 1``
    most frequent tokens are kept so that, together with the ``'<eos>'`` special,
    the vocabulary has at most ``size`` entries.

    Args:
        capIdx_cap_train: mapping caption id -> caption text.
        size: maximum vocabulary size, including ``'<eos>'``.

    Returns:
        ``(tokenizer, vocab)``.
    """
    n_tokens = size - 1  # reserve one slot for the '<eos>' special
    tokenizer = get_tokenizer('basic_english')
    counter = Counter()
    for sentence in capIdx_cap_train.values():
        # endswith() is safe for empty captions; sentence[-1] raised IndexError
        if sentence.endswith('.'):
            sentence = sentence[:-1]
        counter.update(tokenizer(sentence.lower()))

    # Drop every token outside the `n_tokens` most frequent ones. Deleting in place
    # keeps the remaining entries in their original insertion order.
    keep = {token for token, _ in counter.most_common(n_tokens)}
    for token in [t for t in counter if t not in keep]:
        del counter[token]

    # Only '<eos>' is registered as a special, so there is no '<unk>' entry.
    # NOTE(review): process_sentences relies on out-of-vocabulary lookups
    # returning None with this setup (legacy torchtext Vocab API); confirm for
    # the installed torchtext version.
    vocab = Vocab(counter, specials=['<eos>'])

    return tokenizer, vocab

# Build the tokenizer and vocabulary (at most 10000 entries including '<eos>')
# from the training captions only, so validation text does not leak into it.
tokenizer, vocab = build_vocab(capIdx_cap_train, size=10000)

In [None]:
def get_sentence_lengths_quantiles(capIdx_cap, tokenizer, vocab):
    """Return quantiles of caption lengths, counted in in-vocabulary tokens.

    Captions are normalized like in ``process_sentences`` (lower-cased, one
    trailing '.' removed) before tokenizing. Use the result to choose
    ``max_sentence_length``.

    Args:
        capIdx_cap: mapping caption id -> caption text.
        tokenizer: callable splitting a string into tokens.
        vocab: token -> index lookup; falsy results are not counted.

    Returns:
        numpy array with the 50%, 75%, 90%, 95%, 99%, 99.5% and 100% quantiles.
    """
    lengths = []
    for sentence in capIdx_cap.values():
        # endswith() is safe for empty captions; sentence[-1] raised IndexError
        if sentence.endswith('.'):
            sentence = sentence[:-1]
        token_ids = [vocab[token] for token in tokenizer(sentence.lower())]
        # Falsy ids are dropped: words missing from the vocabulary (looked up as
        # None) and, as a side effect, index 0.
        lengths.append(sum(1 for t in token_ids if t))
    return np.quantile(lengths, [.5, .75, .9, .95, .99, .995, 1])

# Uncomment to print the caption-length quantiles (used to choose `max_sentence_length`):
#get_sentence_lengths_quantiles(capIdx_cap_train, tokenizer, vocab)

In [None]:
# Captions are truncated to this many in-vocabulary tokens, and process_sentences
# appends at least one '<eos>', so every encoded caption has
# max_sentence_length + 1 entries. Chosen from get_sentence_lengths_quantiles
# (run it to re-evaluate this number).
max_sentence_length: int = 29

In [None]:
def process_sentences(capIdx_cap, tokenizer, vocab, max_sentence_length):
    """Encode every caption as a fixed-length row of vocabulary indices.

    Each caption is lower-cased, a single trailing '.' is removed, and it is
    tokenized and looked up in ``vocab``. Falsy ids are dropped (words missing
    from the vocabulary come back as None). The result is truncated to
    ``max_sentence_length`` ids and right-padded with ``vocab['<eos>']``.
    Every row therefore ends with at least one '<eos>'.

    Args:
        capIdx_cap: mapping caption id -> caption text; rows follow its iteration order.
        tokenizer: callable splitting a string into tokens.
        vocab: token -> index lookup that contains ``'<eos>'``.
        max_sentence_length: maximum number of word tokens kept per caption.

    Returns:
        int64 tensor of shape ``(len(capIdx_cap), max_sentence_length + 1)``.
    """
    eos_idx = vocab['<eos>']
    row_length = max_sentence_length + 1  # +1 guarantees at least one '<eos>' per row
    rows = []
    for sentence in capIdx_cap.values():
        # endswith() is safe for empty captions; sentence[-1] raised IndexError
        if sentence.endswith('.'):
            sentence = sentence[:-1]
        token_ids = [vocab[token] for token in tokenizer(sentence.lower())]
        token_ids = [t for t in token_ids if t][:max_sentence_length]
        rows.append(token_ids + [eos_idx] * (row_length - len(token_ids)))

    # An explicit dtype/shape keeps the result a 2-D integer tensor on every
    # platform, including for an empty mapping.
    y = np.array(rows, dtype=np.int64).reshape(len(rows), row_length)
    return torch.from_numpy(y)
    
# Encode both splits with the vocabulary built from the training captions:
# one row of max_sentence_length + 1 token indices per caption.
y_train = process_sentences(capIdx_cap_train, tokenizer, vocab, max_sentence_length)
y_val = process_sentences(capIdx_cap_val, tokenizer, vocab, max_sentence_length)

## Model Setup

In [None]:
class ImageCaptioning(Dataset):
    """Map-style dataset of ``(preprocessed_image, encoded_caption)`` pairs.

    Item ``idx`` is row ``idx`` of ``y`` together with the image returned by
    ``get_img_path(idx)``.

    Args:
        get_img_path: callable mapping an item index to an image file path
            (e.g. the ``get_img_path`` returned by ``get_alloc_dicts``).
        preprocess: transform applied to the loaded RGB PIL image.
        y: tensor of encoded captions, one row per caption. It is stored as
            int64 because ``nn.Embedding`` and ``nn.CrossEntropyLoss`` expect
            long indices.
    """

    def __init__(self, get_img_path, preprocess, y):
        self.get_img_path = get_img_path
        self.preprocess = preprocess
        self.y = y.long()

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        # Force 3-channel RGB. Grayscale, palette or RGBA photos would otherwise
        # produce tensors whose channel count does not match a 3-channel
        # Normalize / ResNet input. The `with` block closes the file handle;
        # convert() returns an independent, loaded copy.
        with Image.open(self.get_img_path(idx)) as img:
            rgb = img.convert('RGB')
        return self.preprocess(rgb), self.y[idx]

In [None]:
# One Dataset per split; images are loaded and preprocessed lazily per item.
data_train = ImageCaptioning(get_img_path_train, preprocess, y_train)
data_val = ImageCaptioning(get_img_path_val, preprocess, y_val)

# Use the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Using {device} device')

In [None]:
class ImgCapEncoderCNN(nn.Module):
    """Image encoder: pretrained ResNet18 trunk plus a trainable projection.

    The ResNet18 is used without its last child (the ``fc`` classifier). Its
    flattened features are mapped to ``embedding_dim`` by a linear layer
    followed by BatchNorm1d.
    """

    def __init__(self, embedding_dim):
        super().__init__()
        backbone = models.resnet18(pretrained=True)
        # Attribute names below are state_dict keys; keep them stable so saved
        # checkpoints still load.
        self.resnet = nn.Sequential(*list(backbone.children())[:-1])
        self.linear = nn.Linear(backbone.fc.in_features, embedding_dim)
        self.bn = nn.BatchNorm1d(embedding_dim, momentum=0.01)

    def forward(self, images):
        # No gradients flow through the ResNet trunk.
        # NOTE(review): its BatchNorm layers still update running statistics in
        # train mode; confirm whether the trunk should be kept in eval mode.
        with torch.no_grad():
            features = self.resnet(images)
        flat = features.reshape(features.size(0), -1)  # (batch, n_features)
        return self.bn(self.linear(flat))

class ImgCapDecoderLSTM(nn.Module):
    """Single-layer LSTM caption decoder with its own word embedding.

    The image embedding is the first LSTM input. Each later input is the
    embedding of the previous word: the ground-truth word when ``y`` is given
    (teacher forcing), otherwise the model's own greedy prediction.
    """

    def __init__(self, embedding_dim, vocab_size, max_sentence_length):
        # e.g. embedding_dim=128, vocab_size=10000, max_sentence_length=29
        super(ImgCapDecoderLSTM, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.inv_embedding = nn.Linear(embedding_dim, vocab_size)  # hidden state -> vocabulary logits
        self.lstm = nn.LSTM(embedding_dim, embedding_dim)  # default batch_first=False: (seq, batch, feature)

        self.embedding_dim = embedding_dim
        self.vocab_size = vocab_size
        self.max_sentence_length = max_sentence_length + 1 # +1 to cover the added <eos> at the end of every sentence

    def forward(self, x, y=None):
        """Run ``self.max_sentence_length`` decoding steps.

        Args:
            x: image embeddings of shape (batch, embedding_dim).
            y: optional ground-truth token indices of shape
                (batch, >= self.max_sentence_length) for teacher forcing.

        Returns:
            With ``y``: logits of shape (max_sentence_length, batch, vocab_size),
            time-major.
            Without ``y``: greedy token indices of shape (max_sentence_length, batch).
        """
        out = []
        hidden = None
        for i in range(self.max_sentence_length):
            output, hidden = self.lstm(x.unsqueeze(0), hidden)  # output: (1, batch, embedding_dim)
            logits = self.inv_embedding(output.squeeze(0))  # (batch, vocab_size)
            if y is None:
                # Greedy decoding: best word per sample, i.e. argmax over the
                # vocabulary axis (the old `.max(1)` reduced over the batch axis).
                next_tokens = logits.argmax(-1)
                out.append(next_tokens)
            else:
                next_tokens = y[:, i]  # teacher forcing: feed the true word
                out.append(logits)
            x = self.embedding(next_tokens)  # (batch, embedding_dim)
        return torch.stack(out, 0)
    
class ImgCapNet(nn.Module):
    """Encoder-decoder captioning network: ImgCapEncoderCNN -> ImgCapDecoderLSTM."""

    def __init__(self, embedding_dim, vocab_size, max_sentence_length):
        super().__init__()
        # The attribute names are the state_dict key prefixes of saved checkpoints,
        # so they must not be renamed.
        # maybe implement factory pattern to get models with stored parameters
        self.ImgCapEncoderCNN = ImgCapEncoderCNN(embedding_dim)
        self.ImgCapDecoderLSTM = ImgCapDecoderLSTM(embedding_dim, vocab_size, max_sentence_length)

    def forward(self, x, y=None):
        """Encode the image batch ``x`` and decode captions, teacher-forced by ``y`` if given."""
        image_embedding = self.ImgCapEncoderCNN(x)
        return self.ImgCapDecoderLSTM(image_embedding, y)

In [None]:
def save_iter(json_):
    """Write the step-counter dict to ``iter.json`` in the working directory, overwriting it."""
    counters_path = 'iter.json'
    with open(counters_path, 'w') as out_file:
        json.dump(json_, out_file)
        
def get_iter():
    """Load the step-counter dict from ``iter.json``; return ``{}`` if the file does not exist."""
    counters_path = 'iter.json'
    if not os.path.exists(counters_path):
        return dict()
    with open(counters_path, 'r') as in_file:
        return json.load(in_file)

def get_iter_value(json_, key):
    """Return counter ``key`` from ``json_``, inserting it with value 0 when missing."""
    return json_.setdefault(key, 0)

def save_model(model, model_num):
    """Save ``model.state_dict()`` to ``./models/model_<model_num>.pth``.

    This is the path ``get_model`` loads from. The weights were previously
    written to ``model_weights.pth``, so they were never picked up again.
    The ``./models`` directory is created if it does not exist yet.
    """
    path = './models/model_' + str(model_num) + '.pth'
    os.makedirs(os.path.dirname(path), exist_ok=True)
    torch.save(model.state_dict(), path)
    
def get_model(model, model_num):
    """Load stored weights into ``model`` if they exist, then switch it to eval mode.

    Weights are read from ``./models/model_<model_num>.pth``. They are mapped to
    the CPU so that checkpoints written on a GPU machine also load without CUDA;
    ``load_state_dict`` copies them into the model's parameters wherever those
    live.

    Returns:
        The same ``model`` object (modified in place).
    """
    path = './models/model_' + str(model_num) + '.pth'
    if os.path.exists(path):
        model.load_state_dict(torch.load(path, map_location='cpu'))
    model.eval()
    return model

In [None]:
def train_loop(data_class, batch_size, lr, n_epochs, model_num):
    """Train (or resume training) captioning model number ``model_num``.

    Args:
        data_class: dataset yielding ``(image_tensor, caption_indices)`` pairs.
        batch_size: mini-batch size of the shuffled DataLoader.
        lr: Adam learning rate.
        n_epochs: number of passes over ``data_class``.
        model_num: id used for the checkpoint (via ``get_model``/``save_model``),
            the TensorBoard run directory ``./runs/model_<model_num>`` and the
            step-counter keys stored through ``get_iter``/``save_iter``.

    Per-batch losses are logged as ``batch_loss`` and per-epoch losses as
    ``train_loss``. Counters and weights are saved after the last epoch or when
    the run is interrupted with Ctrl-C (KeyboardInterrupt). Any other exception
    propagates without saving.

    Uses the notebook globals ``vocab``, ``max_sentence_length`` and ``device``.
    """
    vocab_size = len(vocab)
    icn = ImgCapNet(128, vocab_size, max_sentence_length)
    icn = get_model(icn, model_num).to(device)
    dataloader = DataLoader(data_class, batch_size=batch_size, shuffle=True)
    optimizer = torch.optim.Adam(icn.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()

    iter_name = 'iter_' + str(model_num)
    iter_name_batch = 'iter_batch_' + str(model_num)
    iter_json = get_iter()
    n_iter = get_iter_value(iter_json, iter_name)
    n_iter_batch = get_iter_value(iter_json, iter_name_batch)
    writer = SummaryWriter('./runs/model_' + str(model_num))
    try:
        for _ in range(n_epochs):
            train_loss = 0
            for imgs, caps in dataloader:
                # TODO(review): development guard. It aborts on the first batch,
                # so the rest of this loop body is unreachable until removed.
                raise NotImplementedError()
                icn.train()
                imgs = imgs.to(device)
                caps = caps.to(device)  # (batch, seq_len) token indices
                # The decoder stacks its steps time-major: (seq_len, batch, vocab_size).
                logits = icn(imgs, caps)
                # Flatten logits and targets in the same order: all words of
                # sentence 0, then sentence 1, ... Viewing the time-major logits
                # directly would pair them with the wrong targets.
                logits = logits.transpose(0, 1).reshape(-1, vocab_size)
                targets = caps.reshape(-1)
                loss = loss_fn(logits, targets)

                # Backpropagation
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                batch_loss = loss.item()
                # Weight the mean batch loss by how full the batch is, so a
                # shorter last batch contributes proportionally less.
                train_loss += batch_loss * imgs.shape[0] / dataloader.batch_size
                writer.add_scalar('batch_loss', batch_loss, n_iter_batch) # track minibatch_loss
                n_iter_batch += 1
                iter_json[iter_name_batch] = n_iter_batch

            train_loss /= max(len(dataloader), 1)  # guard against an empty dataset
            writer.add_scalar('train_loss', train_loss, n_iter)
            n_iter += 1
            iter_json[iter_name] = n_iter
    except KeyboardInterrupt:
        pass  # an interrupted run is checkpointed exactly like a finished one (below)
    finally:
        writer.close()  # flush pending TensorBoard events, also if an exception propagates

    save_iter(iter_json)
    save_model(icn, model_num)
            
# Train model #1 for 3 epochs on the training split; get_model inside train_loop
# resumes from ./models/model_1.pth if that file exists.
train_loop(data_class=data_train, batch_size=64, lr=.001, n_epochs=3, model_num=1)

In [None]:
def train_loop(dataloader, ImgCapNet, device, writer):
    """Run one training epoch and log loss / accuracy / balanced accuracy to TensorBoard.

    Returns ``(train_loss, correct, bas)``.

    NOTE(review): this cell appears to be copied from a classification notebook
    and is not runnable as written:
    - it redefines ``train_loop``, replacing the earlier
      ``train_loop(data_class, batch_size, lr, n_epochs, model_num)`` for any
      cell executed after this one;
    - the model parameter is named ``ImgCapNet`` (shadowing the class), while
      the body uses ``model``;
    - ``reg``, ``reg_param``, ``optimizer``, ``n_model``, ``n_iter``,
      ``warnings`` and ``balanced_accuracy_score`` are not defined or imported
      in this notebook's visible cells;
    - ``correct`` and ``bas`` are incremented before they are ever assigned,
      which raises UnboundLocalError.
    """
    train_loss = 0
    
    for X, y in dataloader:
        model.train()
        # Compute prediction and loss
        X = X.to(device)
        y = y.to(device)
        logits = model(X, y)
        # NOTE(review): this constructs a CrossEntropyLoss module (logits and y
        # become its first positional constructor arguments) instead of
        # computing a loss; presumably `nn.CrossEntropyLoss()(logits, y)` was
        # intended, with caption logits flattened to (N, vocab_size).
        loss = nn.CrossEntropyLoss(logits, y)
            
        # Optional L1 regularization over all parameters whose name contains 'weight'.
        # NOTE(review): the original note said this runs after the accuracy
        # computation so the logged loss excludes the penalty, but it runs
        # before the accumulation below, so `train_loss` includes it.
        if reg == 'l1':
            l1_reg = torch.tensor(0.).to(device)
            for name, param in model.named_parameters():
                if 'weight' in name:
                    l1_reg += torch.norm(param, 1)
            loss += (reg_param / X.shape[0]) * l1_reg
            
        with torch.no_grad():
            minibatch_ratio = X.shape[0] / dataloader.batch_size # 1 for full batches, smaller for a partial last batch (adjusts its loss weight)
            train_loss += loss.item() * minibatch_ratio # * minibatch_ratio because the loss is a per-batch mean
            pred_classes = logits.argmax(1)
            correct += (pred_classes == y).type(torch.float).sum().item()
            with warnings.catch_warnings(): # catches "y_pred contains classes not in y_true"
                warnings.filterwarnings('ignore', category=UserWarning)
                bas += balanced_accuracy_score(y.cpu(), pred_classes.cpu()) * minibatch_ratio # upscaling because the last mini-batch is likely to have a different size

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    train_loss /= len(dataloader) # average the accumulated train loss over the number of mini-batches
    correct /= len(dataloader.dataset)
    bas /= len(dataloader)
    writer.add_scalar('model_'+str(n_model)+'/loss_train', train_loss, n_iter)
    writer.add_scalar('model_'+str(n_model)+'/accuracy_train', correct, n_iter)
    writer.add_scalar('model_'+str(n_model)+'/balanced_accuracy_train', bas, n_iter)
    return train_loss, correct, bas

26.05.2021: 2h --> Sprechstunde, Challenge & Beschreibung anschauen<br>
28.05.2021: 1h --> Download und Integration der Daten und API<br>
01.06.2021: 5h 30min --> API vertraut machen, Daten laden<br>
02.06.2021: 6h 45min --> Standardisierung der Channels, Laden der Bilder<br>
03.06.2021: 7h 15min --> Erstellen des Vokabulars<br>
04.06.2021: 5h 30min --> Vokabular und Preprocessing der Captions fertiggestellt<br>
05.06.2021: 4h --> Bugfixing & `get_X_idx`-Validierung, ImgCapNet \_\_init\_\_ erstellt<br>
06.06.2021: 2h --> update ImgCapNet Architektur<br>
14.06.2021: 8h --> update der Architektur<br>

Sum: 42h 00min