# 0 - Imports/Constants

In [122]:
import os
from dataclasses import dataclass
from collections import Counter

from xml.etree import ElementTree
from xml.etree.ElementTree import ParseError
from glob import glob
from PIL import Image

import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from torchvision.transforms import transforms

import nltk

In [123]:
# Dataset locations. Relative paths resolve against the current working directory.
data_root_dir = '../data/iaprtc12/'
annotation_dir, image_dir = (
    os.path.join(data_root_dir, subdirectory)
    for subdirectory in ('annotations_complete_eng/', 'images/')
)

# Special vocabulary tokens used when building word maps and encoding captions.
UNKNOWN_TOKEN, START_TOKEN, END_TOKEN, PADDING_TOKEN = (
    '<unk>', '<start>', '<end>', '<pad>'
)

In [124]:
# Training hyperparameters, looked up by name when the DataLoader is built.
hyperparameters = dict(batch_size=32)

# 1 - Data loading 

In [125]:
@dataclass(slots=True, kw_only=True)
class CLEFSample:
    image_id: str
    caption: str
    caption_length: torch.Tensor
    image_path: str
    encoded_caption: torch.Tensor = None
    image: torch.FloatTensor = None


class CLEFDataset(Dataset):
    """Image-caption dataset built from XML ``*.eng`` annotation files.

    All work happens eagerly in ``__init__``:
    - the first ``;``-separated part of each annotation's ``DESCRIPTION`` is
      tokenised with NLTK;
    - the referenced image is loaded and converted to a 256x256 RGB tensor;
    - the caption is encoded as word-map indices wrapped in the
      ``START_TOKEN``/``END_TOKEN`` indices.

    ``__getitem__`` returns the fully populated ``CLEFSample``.

    Args:
        annotation_directory: Root searched recursively for ``*.eng`` files.
        image_directory: Directory that each annotation's ``IMAGE`` path
            (with a leading ``images/`` removed) is resolved against.
        number_images: Maximum number of samples to load. ``None`` loads
            every usable annotation.
        word_map: Token -> index mapping to reuse. It must contain
            ``UNKNOWN_TOKEN``, ``START_TOKEN`` and ``END_TOKEN``. When
            ``None``, a new map is built from the loaded captions.
        min_frequency: Minimum occurrence count for a token to receive its
            own index when a new word map is built.
    """

    def __init__(self, annotation_directory: str, image_directory: str, number_images: int | None = 100,
                 word_map: dict | None = None, min_frequency=10) -> None:
        super().__init__()
        captions = self._load_captions(annotation_directory, number_images)
        samples = self._load_images(image_directory, captions)

        if word_map is None:
            word_map = self._create_word_map(samples, min_frequency)
        self.word_map = word_map

        self.samples = self._encode_captions(samples)

    def _load_captions(self, directory: str, number_images: int | None) -> list[CLEFSample]:
        """Parse up to ``number_images`` annotation files found under ``directory``.

        Files are visited in sorted path order, so the selected subset is
        reproducible across runs and platforms. A file is skipped when it:
        - is not well-formed XML, or
        - has a missing or empty ``DESCRIPTION`` or ``IMAGE`` element.
        """
        captions: list[CLEFSample] = []

        # os.path.join works with or without a trailing slash on ``directory``.
        file_pattern = os.path.join(directory, '**', '*.eng')
        for file in sorted(glob(file_pattern, recursive=True)):
            if number_images is not None and len(captions) >= number_images:
                break
            try:
                root = ElementTree.parse(file).getroot()
            except ParseError:
                continue

            # findtext returns None for a missing element and '' for an empty
            # one. Both are unusable, so skip the file instead of crashing.
            description = root.findtext('./DESCRIPTION')
            raw_image_path = root.findtext('./IMAGE')
            if not description or not raw_image_path:
                continue

            # TODO multiple captions: only the first ';'-separated description is used.
            first_caption = description.split(';')[0]
            tokenized_caption = nltk.word_tokenize(first_caption)

            image_path = raw_image_path.removeprefix('images/')
            image_id = image_path.removesuffix('.jpg')

            captions.append(CLEFSample(
                image_id=image_id,
                caption=tokenized_caption,
                # +2 for start and end token. Lengths are counts, so use an integer dtype.
                caption_length=torch.tensor([len(tokenized_caption) + 2], dtype=torch.long),
                image_path=image_path
            ))

        return captions

    def _load_images(self, directory: str, captions: list[CLEFSample]) -> list[CLEFSample]:
        """Load each sample's image as a 3x256x256 float tensor and attach it in place.

        Returns the same sample objects, in order.
        """
        to_tensor = transforms.ToTensor()

        samples: list[CLEFSample] = []
        for sample in captions:
            image_path = os.path.join(directory, sample.image_path)

            # The context manager closes the underlying file handle.
            # Convert to RGB *before* resizing: Pillow resizes palette ('P')
            # images with nearest-neighbour, so converting first gives proper
            # resampling for every input mode.
            with Image.open(image_path) as image:
                rgb_image = image.convert('RGB').resize((256, 256))
            sample.image = to_tensor(rgb_image)
            samples.append(sample)

        return samples

    def _create_word_map(self, samples: list[CLEFSample], min_frequency: int) -> dict:
        """Build a token -> index map from the captions of ``samples``.

        Index layout (k = number of kept tokens):
        - indices 1..k: tokens seen at least ``min_frequency`` times, in
          first-seen order;
        - k+1, k+2, k+3: the unknown, start and end tokens;
        - 0: the padding token.
        """
        word_frequency = Counter()
        for sample in samples:
            word_frequency.update(sample.caption)

        words = [word for word, count in word_frequency.items() if count >= min_frequency]

        word_map = {word: index for index, word in enumerate(words, start=1)}
        word_map[UNKNOWN_TOKEN] = len(word_map) + 1
        word_map[START_TOKEN] = len(word_map) + 1
        word_map[END_TOKEN] = len(word_map) + 1
        # 0 matches the default fill value of torch's pad_sequence.
        word_map[PADDING_TOKEN] = 0

        return word_map

    def _encode_captions(self, samples: list[CLEFSample]) -> list[CLEFSample]:
        """Store ``[<start>, *caption, <end>]`` as an index tensor on each sample (in place)."""
        encoded_samples: list[CLEFSample] = []
        for sample in samples:
            tokens = [START_TOKEN, *sample.caption, END_TOKEN]
            sample.encoded_caption = torch.tensor([self.get_encoded_token(token) for token in tokens])
            encoded_samples.append(sample)
        return encoded_samples

    def get_encoded_token(self, token: str) -> int:
        """Return the index of ``token``, or the unknown-token index if it is not in the map."""
        if token in self.word_map:
            return self.word_map[token]
        return self.word_map[UNKNOWN_TOKEN]

    def __getitem__(self, index):
        return self.samples[index]

    def __len__(self):
        return len(self.samples)


In [126]:
def custom_collate(samples: list[CLEFSample], padding_value: int = 0) -> dict:
    """Collate ``CLEFSample`` objects into a dict of per-field batches.

    Only ``encoded_captions`` is merged into a single tensor. Each caption is
    right-padded with ``padding_value`` to the longest caption in the batch,
    giving shape (batch, max_len). Every other field is a list with one entry
    per sample, in input order.

    Args:
        samples: Samples returned by ``CLEFDataset.__getitem__``.
        padding_value: Index written into padded positions. The default 0 is
            the index ``CLEFDataset`` assigns to ``PADDING_TOKEN`` when it
            builds its own word map. With a custom word map, pass
            ``word_map[PADDING_TOKEN]``, e.g. via ``functools.partial``.
    """
    return {
        'image_ids': [sample.image_id for sample in samples],
        'captions': [sample.caption for sample in samples],
        'caption_lengths': [sample.caption_length for sample in samples],
        'encoded_captions': pad_sequence(
            [sample.encoded_caption for sample in samples],
            batch_first=True,
            padding_value=padding_value,
        ),
        'image_paths': [sample.image_path for sample in samples],
        'images': [sample.image for sample in samples],
    }

In [127]:
# min_frequency=1: every token seen in the loaded captions gets its own index.
dataset = CLEFDataset(
    annotation_directory=annotation_dir,
    image_directory=image_dir,
    min_frequency=1,
)

In [128]:
# Shuffled mini-batches; custom_collate pads the encoded captions per batch.
dataloader = DataLoader(
    dataset,
    batch_size=hyperparameters['batch_size'],
    shuffle=True,
    collate_fn=custom_collate,
)

In [None]:
# Smoke check: print the padded caption-index tensor of every batch.
for current_batch in dataloader:
    padded_captions = current_batch['encoded_captions']
    print(padded_captions)