# 0 - Imports/Constants

In [20]:
import os
from dataclasses import dataclass
from collections import Counter

from xml.etree import ElementTree
from xml.etree.ElementTree import ParseError
from glob import glob
from PIL import Image

import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from torchvision.transforms import transforms

import nltk

from tqdm import tqdm

In [2]:
# Dataset location. The sub-directory paths keep their trailing '/' so they also work with
# code that appends file names or glob patterns to them by plain string concatenation.
data_root_dir = '../data/iaprtc12/'
annotation_dir, image_dir = (
    os.path.join(data_root_dir, subdir) for subdir in ('annotations_complete_eng/', 'images/')
)

# Special vocabulary tokens: unknown word, caption start, caption end, padding.
UNKNOWN_TOKEN, START_TOKEN, END_TOKEN, PADDING_TOKEN = '<unk>', '<start>', '<end>', '<pad>'

In [3]:
# Training hyper-parameters, kept in one place so they are easy to find and tune.
hyperparameters = dict(batch_size=32)

# 1 - Data loading 

In [58]:
@dataclass(slots=True, kw_only=True)
class CLEFSample:
    # by Dominik
    image_id: str
    caption: str
    caption_length: torch.CharTensor
    image_path: str
    encoded_caption: torch.IntTensor = None
    image: torch.FloatTensor = None


class CLEFDataset(Dataset):
    """Image-captioning dataset built from the IAPR TC-12 English annotation files.

    All work happens eagerly in ``__init__``:

    1. parse ``*.eng`` annotation files (in sorted order) until ``number_images`` captions
       containing at least one verb or adposition ("relation word") have been collected,
    2. load each referenced image, convert it to RGB and resize it to ``image_size``;
       samples whose image is missing or unreadable are dropped,
    3. build a word map from the captions unless one is passed in,
    4. encode every caption as ``<start> token ... token <end>`` indices.

    Because the ``number_images`` limit is applied to captions *before* images are
    loaded, the dataset can end up with fewer than ``number_images`` samples.

    Attributes:
        word_map: token -> index mapping used for encoding.
        samples: the loaded and encoded ``CLEFSample`` objects.
    """
    # by Dominik, individual contributions by Maria marked with in-line comments or comments under specific methods

    # universal POS tags that count as relation words (by Maria)
    RELATION_TAGS = frozenset({'VERB', 'ADP'})

    def __init__(
        self,
        annotation_directory: str,
        image_directory: str,
        number_images=100,
        word_map: dict = None,
        min_frequency=10,
        concat_captions: bool = False,  # added by Maria to allow the optional concatenation of multiple captions into one
        image_size: tuple[int, int] = (256, 256),
    ) -> None:
        """
        Args:
            annotation_directory: directory searched recursively for ``*.eng`` files.
            image_directory: directory the annotations' ``IMAGE`` paths are resolved against.
            number_images: maximum number of captions to collect; ``None`` means no limit.
            word_map: existing token -> index mapping to reuse; built from the loaded
                captions when ``None``.
            min_frequency: minimum count for a word to enter a newly built word map.
            concat_captions: join all ';'-separated captions with ' and ' instead of
                using only the first one.
            image_size: (width, height) every image is resized to.
        """
        super().__init__()
        captions = self._load_captions(annotation_directory, number_images, concat_captions)
        samples = self._load_images(image_directory, captions, image_size)

        if word_map is None:
            word_map = self._create_word_map(samples, min_frequency)
        self.word_map = word_map

        self.samples = self._encode_captions(samples)

    def _load_captions(self, directory: str, number_images: int, concat_captions: bool) -> list[CLEFSample]:
        """Parse annotation files into caption-only samples (no image attached yet).

        Files are visited in sorted order because ``glob`` returns them in a
        filesystem-dependent order, which would make the selected subset differ between
        machines. Files that are not valid XML, or that lack ``DESCRIPTION``/``IMAGE``
        text, are skipped. Captions without any relation word are also skipped.
        """
        captions: list[CLEFSample] = []

        file_pattern = os.path.join(directory, '**', '*.eng')
        for file in sorted(glob(file_pattern, recursive=True)):
            if number_images is not None and len(captions) >= number_images:
                break

            parsed = self._parse_annotation(file, concat_captions)
            if parsed is None:
                continue
            image_path, tokenized_caption = parsed

            # selecting only the captions that include verbs or prepositions (relation words) by Maria
            if not self._has_relation_word(tokenized_caption):
                continue

            captions.append(CLEFSample(
                image_id=image_path.removesuffix('.jpg'),
                caption=tokenized_caption,
                # +2 for start and end token. int64 instead of int8 (CharTensor): int8 overflows
                # (and raises) for captions longer than 125 tokens, which concatenated captions can reach.
                caption_length=torch.LongTensor([len(tokenized_caption) + 2]),
                image_path=image_path,
            ))

        print('Captions loaded!')  # added for clarity by Maria

        return captions

    @staticmethod
    def _parse_annotation(file: str, concat_captions: bool) -> tuple[str, list[str]] | None:
        """Return ``(image_path, tokenized_caption)`` for one annotation file, or ``None`` if unusable."""
        try:
            root = ElementTree.parse(file).getroot()
        except ParseError:
            return None

        # findtext returns None for a missing element and '' for an empty one; skip both
        description = root.findtext('./DESCRIPTION')
        image_path = root.findtext('./IMAGE')
        if not description or not image_path:
            return None

        # multiple captions option by Maria
        # Descriptions hold ';'-separated captions, often with a trailing ';'. Dropping empty
        # pieces (instead of always slicing off the last one) keeps the final caption when there
        # is no trailing ';' and never produces an empty caption.
        all_captions = [part.strip() for part in description.split(';') if part.strip()]
        if not all_captions:
            return None
        caption = ' and '.join(all_captions) if concat_captions else all_captions[0]

        return image_path.strip().removeprefix('images/'), nltk.word_tokenize(caption)

    @classmethod
    def _has_relation_word(cls, tokens: list[str]) -> bool:
        """True if any token is tagged (universal tagset) as a verb or an adposition."""
        return any(tag in cls.RELATION_TAGS for _, tag in nltk.pos_tag(tokens, tagset='universal'))

    def _load_images(
        self,
        directory: str,
        captions: list[CLEFSample],
        image_size: tuple[int, int] = (256, 256),
    ) -> list[CLEFSample]:
        """Attach an image tensor to every sample whose image file can be read.

        Images are converted to RGB *before* resizing, because Pillow always uses
        nearest-neighbour resampling for palette ('P') and bilevel ('1') images.
        Missing or unreadable images are counted and skipped instead of aborting the load.
        """
        transform = transforms.ToTensor()

        samples: list[CLEFSample] = []
        skipped = 0
        for sample in tqdm(captions, desc='Loading images...'):  # tqdm added because Maria is impatient
            image_path = os.path.join(directory, sample.image_path)

            # error-handling added by Maria; OSError covers missing files (FileNotFoundError)
            # as well as corrupt or unrecognised ones (PIL's UnidentifiedImageError)
            try:
                # the context manager closes the underlying file deterministically
                with Image.open(image_path) as image:
                    sample.image = transform(image.convert('RGB').resize(image_size))
            except OSError:
                skipped += 1
                continue
            samples.append(sample)

        if skipped:
            print(f'Skipped {skipped} missing or unreadable images.')
        print('Images loaded!')  # added for clarity by Maria

        return samples

    def _create_word_map(self, samples: list[CLEFSample], min_frequency: int) -> dict:
        """Map every word occurring at least ``min_frequency`` times to an index.

        Words get indices 1..n in first-occurrence order, followed by <unk>, <start> and
        <end>. Index 0 is reserved for <pad>, which is also the default padding value of
        ``pad_sequence``.
        """
        word_frequency = Counter()
        for sample in samples:
            word_frequency.update(sample.caption)

        words = [word for word in word_frequency.keys() if word_frequency[word] >= min_frequency]

        word_map = {word: index for index, word in enumerate(words, start=1)}
        word_map[UNKNOWN_TOKEN] = len(word_map) + 1
        word_map[START_TOKEN] = len(word_map) + 1
        word_map[END_TOKEN] = len(word_map) + 1
        word_map[PADDING_TOKEN] = 0

        return word_map

    def _encode_captions(self, samples: list[CLEFSample]) -> list[CLEFSample]:
        """Set ``encoded_caption`` on every sample (in place) to ``<start> tokens... <end>`` indices."""
        encoded_samples: list[CLEFSample] = []
        for sample in samples:
            encoding = [
                self.get_encoded_token(START_TOKEN),
                *(self.get_encoded_token(token) for token in sample.caption),
                self.get_encoded_token(END_TOKEN),
            ]
            sample.encoded_caption = torch.IntTensor(encoding)
            encoded_samples.append(sample)
        return encoded_samples

    def get_encoded_token(self, token: str) -> int:
        """Return the word-map index of ``token``, or the <unk> index for out-of-vocabulary tokens."""
        if token in self.word_map:
            return self.word_map[token]
        return self.word_map[UNKNOWN_TOKEN]

    def __getitem__(self, index: int) -> CLEFSample:
        return self.samples[index]

    def __len__(self) -> int:
        return len(self.samples)


In [59]:
def custom_collate(samples: list[CLEFSample]) -> dict:
    """Collate a list of ``CLEFSample`` objects into one batch dictionary.

    Every field is gathered into a plain list in sample order. The only exception is
    ``encoded_captions``: it is right-padded into a single ``(batch, max_len)`` tensor by
    ``pad_sequence``, whose default padding value 0 is the <pad> index of the word map.
    """
    # by Dominik
    batch = list(samples)  # materialise once so the comprehensions below can each iterate it
    return {
        'image_ids': [sample.image_id for sample in batch],
        'captions': [sample.caption for sample in batch],
        'caption_lengths': [sample.caption_length for sample in batch],
        'encoded_captions': pad_sequence([sample.encoded_caption for sample in batch], batch_first=True),
        'image_paths': [sample.image_path for sample in batch],
        'images': [sample.image for sample in batch],
    }

In [62]:
# Build the dataset from up to 50 captions that pass the relation-word filter,
# concatenating all captions per image. min_frequency=1 puts every word of the loaded
# captions into the word map.
dataset = CLEFDataset(
    annotation_directory=annotation_dir,
    image_directory=image_dir,
    number_images=50,
    min_frequency=1,
    concat_captions=True,
)

Captions loaded!


Loading images...: 100%|███████████████████████████████████████████████████████████████| 50/50 [00:00<00:00, 217.74it/s]

Images loaded!





In [75]:
# Split the dataset 90 % / 10 % into train and test sets (by Maria).
# The generator is seeded, so the split is identical on every run; drop the `generator`
# argument to get a different random split each time.
# NOTE(review): `dataset` built its word map from *all* samples before this split, so
# words that occur only in the test set are still in the vocabulary. TODO: consider
# building the word map from the training samples only.
split_generator = torch.Generator().manual_seed(25)
train_set, test_set = torch.utils.data.random_split(dataset, [0.9, 0.1], generator=split_generator)

In [78]:
# Training loader (by Dominik): shuffled batches assembled by custom_collate.
# drop_last=True (added by Maria) discards the final incomplete batch, so every batch has
# exactly hyperparameters['batch_size'] samples. With a small train_set this can leave out
# a sizeable share of the data in each epoch.
dataloader = DataLoader(
    dataset=train_set,
    batch_size=hyperparameters['batch_size'],
    shuffle=True,
    collate_fn=custom_collate,
    drop_last=True,
)

In [79]:
# Sanity check: print the tokenized captions of every batch the loader yields.
for batch in iter(dataloader):
    batch_captions = batch['captions']
    print(batch_captions)

[['Tourists', 'are', 'visiting', 'an', 'old', 'people', "'s", 'home', 'and', 'each', 'tourist', 'is', 'leading', 'one', 'old', 'lady', ',', 'arm', 'in', 'arm', 'and', 'the', 'old', 'ladies', 'are', 'wearing', 'traditional', 'dresses', 'and', 'a', 'hat', 'and', 'the', 'room', 'is', 'decorated', 'with', 'balloons', 'and', 'there', 'is', 'a', 'picture', 'on', 'the', 'wall', 'in', 'the', 'background', ',', 'and', 'also', 'an', 'old', 'lady', 'carrying', 'a', 'green', 'balloon'], ['Close', 'up', 'picture', 'of', 'the', 'Itaipu', 'Dam', 'and', 'backwater', 'is', 'flowing', 'off', 'in', 'the', 'foreground'], ['a', 'woman', 'is', 'sitting', 'at', 'table', 'in', 'the', 'middle', 'of', 'a', 'schoolyard', 'and', 'there', 'are', 'many', 'books', 'on', 'the', 'table', 'and', 'two', 'people', 'are', 'sitting', 'next', 'to', 'her', ',', 'six', 'others', 'are', 'standing', 'about', 'ten', 'metres', 'away', 'from', 'the', 'table'], ['A', 'tourist', 'group', 'is', 'visiting', 'a', 'school', 'and', 'the'