# 0 - Imports/Constants

In [1]:
import os
from dataclasses import dataclass
from collections import Counter

from xml.etree import ElementTree
from xml.etree.ElementTree import ParseError
from glob import glob
from PIL import Image

import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from torchvision.transforms import transforms

import nltk

from tqdm import tqdm

# adding imports to match the ones from test.ipynb, where we will be sourcing some of the code from
import torch.nn.functional as F
import numpy as np
import json
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from matplotlib.pyplot import figure
import skimage.transform
import argparse

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
data_root_dir = '../data/iaprtc12/'
annotation_dir = os.path.join(data_root_dir, 'annotations_complete_eng/')
image_dir = os.path.join(data_root_dir, 'images/')

UNKNOWN_TOKEN = '<unk>'
START_TOKEN = '<start>'
END_TOKEN = '<end>'
PADDING_TOKEN = '<pad>'

In [3]:
hyperparameters = {
    'batch_size': 32
}

# 1 - Data loading 
This part is mostly done by Dominik, with individual contributions by Maria marked in the code.

In [4]:
@dataclass(slots=True, kw_only=True)
class CLEFSample:
    # by Dominik
    image_id: str
    caption: str
    caption_length: torch.CharTensor
    image_path: str
    encoded_caption: torch.IntTensor = None
    image: torch.FloatTensor = None


class CLEFDataset(Dataset):
    # by Dominik, individual contributions by Maria marked with in-line comments or comments under specific methods
    def __init__(
        self, 
        annotation_directory: str, 
        image_directory: str, 
        number_images=100, 
        word_map: dict = None, 
        min_frequency=10, 
        concat_captions: bool = False  # added by Maria to allow the optional concatenation of multiple captions into one
    ) -> None:
        super(CLEFDataset, self).__init__()
        captions = self._load_captions(annotation_directory, number_images, concat_captions)
        samples = self._load_images(image_directory, captions)

        if word_map == None:
            word_map = self._create_word_map(samples, min_frequency)
        self.word_map = word_map

        self.samples = self._encode_captions(samples)

    def _load_captions(self, directory: str, number_images: int, concat_captions: bool) -> list[CLEFSample]:
        captions: list[CLEFSample] = []

        file_pattern = directory + '**/*.eng'
        for file in glob(file_pattern, recursive=True):
            if len(captions) == number_images:
                break
            try:
                root = ElementTree.parse(file).getroot()
                description = root.find('./DESCRIPTION').text
                # multiple captions option by Maria
                all_captions = description.split(';')
                if concat_captions == True:
                    first_caption = ' and '.join(all_captions[:-1])  # if not -1, then there is a trailing 'and'
                else:
                    first_caption = all_captions[0]
                    
                tokenized_caption = nltk.word_tokenize(first_caption)
                
                image_path = root.find('./IMAGE').text.removeprefix('images/')
                image_id = image_path.removesuffix('.jpg')
                
                # selecting only the captions that include verbs or prepositions (relation words) by Maria
                annotated_caption = nltk.pos_tag(tokenized_caption, tagset='universal')

                va_counter = 0  # for seeing if there is a verb or an adposition in the description
                for tagged_word in annotated_caption:
                    if tagged_word[1] == 'VERB':
                        va_counter += 1
                    elif tagged_word[1] == 'ADP':
                        va_counter += 1
                    else:
                        continue
                
                if va_counter > 0:
                    captions.append(CLEFSample(
                        image_id=image_id,
                        caption=tokenized_caption,
                        # +2 for start and end token
                        caption_length=torch.CharTensor([len(tokenized_caption) + 2]),
                        image_path=image_path
                    ))
                else:
                    continue
                    
            except ParseError:
                continue
        
        print('Captions loaded!')  # added for clarity by Maria

        return captions

    def _load_images(self, directory: str, captions: list[CLEFSample]) -> list[CLEFSample]:
        transform = transforms.ToTensor()

        samples: list[CLEFSample] = []
        for sample in tqdm(captions, desc='Loading images...'):  # tqdm added because Maria is impatient
            image_path = os.path.join(directory, sample.image_path)

            # TODO correct conversion?
            # error-handling added by Maria
            try:
                image = Image.open(image_path).resize((256, 256)).convert('RGB')
                sample.image = transform(image)
                samples.append(sample)
            except FileNotFoundError:
                continue

        print('Images loaded!')  # added for clarity by Maria
        
        return samples

    def _create_word_map(self, samples: list[CLEFSample], min_frequency: int) -> dict:
        word_frequency = Counter()
        for sample in samples:
            word_frequency.update(sample.caption)

        words = [word for word in word_frequency.keys() if word_frequency[word] >= min_frequency]

        word_map = {word: index for index, word in enumerate(words, start=1)}
        word_map[UNKNOWN_TOKEN] = len(word_map) + 1
        word_map[START_TOKEN] = len(word_map) + 1
        word_map[END_TOKEN] = len(word_map) + 1
        word_map[PADDING_TOKEN] = 0

        return word_map

    def _encode_captions(self, samples: list[CLEFSample]) -> list[CLEFSample]:
        encoded_samples: list[CLEFSample] = []
        for sample in samples:
            encoding = [self.get_encoded_token(START_TOKEN), *[self.get_encoded_token(token)
                                                               for token in sample.caption], self.get_encoded_token(END_TOKEN)]
            sample.encoded_caption = torch.IntTensor(encoding)
            encoded_samples.append(sample)
        return encoded_samples

    def get_encoded_token(self, token: str) -> int:
        if token in self.word_map:
            return self.word_map[token]
        else:
            return self.word_map[UNKNOWN_TOKEN]

    def __getitem__(self, index: int) -> CLEFSample:
        return self.samples[index]

    def __len__(self) -> int:
        return len(self.samples)


In [5]:
def custom_collate(samples: list[CLEFSample]) -> dict:
    # by Dominik
    image_ids = []
    captions = []
    caption_lengths = []
    encoded_captions = []
    image_paths = []
    images = []

    for sample in samples:
        image_ids.append(sample.image_id)
        captions.append(sample.caption)
        caption_lengths.append(sample.caption_length)
        encoded_captions.append(sample.encoded_caption)
        image_paths.append(sample.image_path)
        images.append(sample.image)
    
    return {
        'image_ids': image_ids,
        'captions': captions,
        'caption_lengths': caption_lengths,
        'encoded_captions': pad_sequence(encoded_captions, batch_first=True),
        'image_paths': image_paths,
        'images': images
    }

In [6]:
dataset = CLEFDataset(annotation_dir, image_dir, number_images=50, min_frequency=1, concat_captions=True)

Captions loaded!


Loading images...: 100%|███████████████████████████████████████████████████████████████| 50/50 [00:00<00:00, 170.65it/s]

Images loaded!





In [10]:
# splitting the dataset by Maria
# remove the last optional argument for random splits, this way the seed is fixed so results are reproducible
# QUESTION: does this need to be done any prettier?
train_set, val_set, test_set = torch.utils.data.random_split(dataset, [0.8, 0.1, 0.1], generator=torch.Generator().manual_seed(25))

In [11]:
# by Dominik
dataloader = DataLoader(
    train_set, 
    hyperparameters['batch_size'], 
    shuffle=True, 
    collate_fn=custom_collate, 
    drop_last=True  # added by Maria since we were told it is good to do so when working with LSTMs in the Machine Learning 2 course
)

In [13]:
#for batch in dataloader:
    #print(batch['captions'])

# 2 - Testing the pretrained model (by Nikolai Ilinykh)
This part is mostly done by Nikolai, adapted for our use by Maria, with individual contributi

In [19]:
### THE FOLLOWING CODE IS TAKEN FROM TEST.IPYNB BY NIKOLAI ILINYKH

def caption_image_beam_search(encoder, decoder, image_path, word_map, beam_size=3):
    """
    Reads an image and captions it with beam search.

    :param encoder: encoder model
    :param decoder: decoder model
    :param image_path: path to image
    :param word_map: word map
    :param beam_size: number of sequences to consider at each decode-step
    :return: caption, weights for visualization
    """

    k = beam_size
    vocab_size = len(word_map)
    
    
    img = Image.open(image_path)
    img = img.resize((256, 256))
    img = np.transpose(img, (2, 0, 1))
    #img = img / 255.
    img = torch.FloatTensor(img).to(device)
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    transform = transforms.Compose([normalize])
    image = transform(img)  # (3, 256, 256)

    # Encode
    image = image.unsqueeze(0)  # (1, 3, 256, 256)
    encoder_out = encoder(image)  # (1, enc_image_size, enc_image_size, encoder_dim)
    enc_image_size = encoder_out.size(1)
    encoder_dim = encoder_out.size(3)

    # Flatten encoding
    encoder_out = encoder_out.view(1, -1, encoder_dim)  # (1, num_pixels, encoder_dim)
    num_pixels = encoder_out.size(1)

    # We'll treat the problem as having a batch size of k
    encoder_out = encoder_out.expand(k, num_pixels, encoder_dim)  # (k, num_pixels, encoder_dim)

    # Tensor to store top k previous words at each step; now they're just 
    k_prev_words = torch.LongTensor([[word_map['']]] * k).to(device)  # (k, 1)

    # Tensor to store top k sequences; now they're just 
    seqs = k_prev_words  # (k, 1)

    # Tensor to store top k sequences' scores; now they're just 0
    top_k_scores = torch.zeros(k, 1).to(device)  # (k, 1)

    # Tensor to store top k sequences' alphas; now they're just 1s
    seqs_alpha = torch.ones(k, 1, enc_image_size, enc_image_size).to(device)  # (k, 1, enc_image_size, enc_image_size)

    # Lists to store completed sequences, their alphas and scores
    complete_seqs = list()
    complete_seqs_alpha = list()
    complete_seqs_scores = list()

    # Start decoding
    step = 1
    h, c = decoder.init_hidden_state(encoder_out)

    # s is a number less than or equal to k, because sequences are removed from this process once they hit 
    while True:

        embeddings = decoder.embedding(k_prev_words).squeeze(1)  # (s, embed_dim)

        awe, alpha = decoder.attention(encoder_out, h)  # (s, encoder_dim), (s, num_pixels)

        alpha = alpha.view(-1, enc_image_size, enc_image_size)  # (s, enc_image_size, enc_image_size)

        gate = decoder.sigmoid(decoder.f_beta(h))  # gating scalar, (s, encoder_dim)
        awe = gate * awe

        h, c = decoder.decode_step(torch.cat([embeddings, awe], dim=1), (h, c))  # (s, decoder_dim)

        scores = decoder.fc(h)  # (s, vocab_size)
        scores = F.log_softmax(scores, dim=1)

        # Add
        scores = top_k_scores.expand_as(scores) + scores  # (s, vocab_size)

        # For the first step, all k points will have the same scores (since same k previous words, h, c)
        if step == 1:
            top_k_scores, top_k_words = scores[0].topk(k, 0, True, True)  # (s)
        else:
            # Unroll and find top scores, and their unrolled indices
            top_k_scores, top_k_words = scores.view(-1).topk(k, 0, True, True)  # (s)

        # Convert unrolled indices to actual indices of scores
        prev_word_inds = top_k_words / vocab_size  # (s)
        next_word_inds = top_k_words % vocab_size  # (s)

        
        
        
        # Add new words to sequences, alphas

        prev_word_inds = prev_word_inds.cpu()
        prev_word_inds = prev_word_inds.long()        

        seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1)  # (s, step+1)
        
        
        seqs_alpha = torch.cat([seqs_alpha[prev_word_inds], alpha[prev_word_inds].unsqueeze(1)],
                               dim=1)  # (s, step+1, enc_image_size, enc_image_size)

        # Which sequences are incomplete (didn't reach )?
        incomplete_inds = [ind for ind, next_word in enumerate(next_word_inds) if
                           next_word != word_map[' ']]
        complete_inds = list(set(range(len(next_word_inds))) - set(incomplete_inds))

        # Set aside complete sequences
        if len(complete_inds) > 0:
            complete_seqs.extend(seqs[complete_inds].tolist())
            complete_seqs_alpha.extend(seqs_alpha[complete_inds].tolist())
            complete_seqs_scores.extend(top_k_scores[complete_inds])
        k -= len(complete_inds)  # reduce beam length accordingly

        # Proceed with incomplete sequences
        if k == 0:
            break
        seqs = seqs[incomplete_inds]
        seqs_alpha = seqs_alpha[incomplete_inds]
        h = h[prev_word_inds[incomplete_inds]]
        c = c[prev_word_inds[incomplete_inds]]
        encoder_out = encoder_out[prev_word_inds[incomplete_inds]]
        top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1)
        k_prev_words = next_word_inds[incomplete_inds].unsqueeze(1)

        # Break if things have been going on too long
        if step > 50:
            break
        step += 1

    i = complete_seqs_scores.index(max(complete_seqs_scores))
    seq = complete_seqs[i]
    alphas = complete_seqs_alpha[i]

    return seq, alphas


def visualize_att(image_path, seq, alphas, rev_word_map, smooth=False):
    """
    Visualizes caption with weights at every word.

    Adapted from paper authors' repo: https://github.com/kelvinxu/arctic-captions/blob/master/alpha_visualization.ipynb

    :param image_path: path to image that has been captioned
    :param seq: caption
    :param alphas: weights
    :param rev_word_map: reverse word mapping, i.e. ix2word
    :param smooth: smooth weights?
    """
    
    figure(figsize=(10, 8), dpi=80)
    
    image = Image.open(image_path)
    image = image.resize([24 * 38, 24 * 38], Image.LANCZOS)

    words = [rev_word_map[ind] for ind in seq]

    for t in range(len(words)):
        if t > 50:
            break
        plt.subplot(np.ceil(len(words) / 5.), 5, t + 1)

        plt.text(0, 1, '%s' % (words[t]), color='black', backgroundcolor='white', fontsize=12)
        plt.imshow(image)
        current_alpha = alphas[t, :]
        if smooth:
            alpha = skimage.transform.pyramid_expand(current_alpha.numpy(), upscale=12, sigma=8)
        else:
            alpha = skimage.transform.resize(current_alpha.numpy(), [24 * 38, 24 * 38])
        if t == 0:
            plt.imshow(alpha, alpha=0)
        else:
            plt.imshow(alpha, alpha=0.6)
        plt.set_cmap(cm.Greys_r)
        plt.axis('off')
    plt.show()

In [20]:
### THIS IS ALSO FROM NIKOLAI, ANY CHANGES THAT ARE INTRODUCED WILL BE MARKED BY COMMENTS

model_path = '../data/BEST_checkpoint_flickr8k_5_10.pth.tar'  # model path updated
img = '../data/iaprtc12/images/00/25.jpg'  # temporary image path

word_map = '../data/wordmap_flickr8k_5_10.json'  # wordmap path updated
beam_size = 3
smooth = False

# Load model
checkpoint = torch.load(model_path, map_location=str(device))
decoder = checkpoint['decoder']
decoder = decoder.to(device)
decoder.eval()
encoder = checkpoint['encoder']
encoder = encoder.to(device)
encoder.eval()

# Load word map (word2ix)
with open(word_map, 'r') as j:
    word_map = json.load(j)
rev_word_map = {v: k for k, v in word_map.items()}  # ix2word

# Encode, decode with attention and beam search
seq, alphas = caption_image_beam_search(encoder, decoder, img, word_map, beam_size)
alphas = torch.FloatTensor(alphas)

# Visualize caption and attention of best sequence
visualize_att(img, seq, alphas, rev_word_map, smooth)

KeyError: ' '