# Data Structure

We want the data as a dictionary keyed by image ID (the image file name without its `.jpg` extension), where each value holds the image's file path and its caption (the first caption listed for that image). The image paths point into the `Images/` folder of the Flickr8k dataset, and the captions are read from the `Flickr8k.token.txt` and `Flickr8k.lemma.token.txt` files.

## Folder Structure
```
Flicker8k_Dataset/
├── Images/
│   ├── image_files
├── Flickr8k.token.txt
└── Flickr8k.lemma.token.txt
```


In [2]:
import torch
import cv2
import numpy as np
import pandas as pd
from collections import defaultdict, Counter
import torchtext
from torch.utils.data import Dataset


In [5]:
class LoadFlickerData(Dataset):
    """Map-style PyTorch dataset over Flickr8k images and captions.

    After ``load_data()`` runs (``__init__`` calls it), ``self.data`` maps each
    image ID (file name without ``.jpg``) to
    ``{'captions': str, 'image_path': str}``. ``captions`` is the first caption
    read for that image. Items can be fetched by integer position (as
    ``torch.utils.data.DataLoader`` does) or by image ID string.

    Args:
        dataset_path: Root folder containing ``Flickr8k.token.txt``,
            ``Flickr8k.lemma.token.txt`` and an ``Images/`` sub-folder.
        max_seq_length: Fixed length of every encoded caption, including the
            ``<bos>``/``<eos>`` markers. Must be at least 2.
        tokenizer: Callable mapping a caption string to an iterable of tokens.

    Raises:
        ValueError: if ``max_seq_length`` is smaller than 2.
    """

    _SPECIALS = ('<unk>', '<pad>', '<bos>', '<eos>')

    def __init__(self, dataset_path, max_seq_length, tokenizer):
        if max_seq_length < 2:
            # Every encoded caption needs room for at least <bos> and <eos>.
            raise ValueError(f"max_seq_length must be >= 2, got {max_seq_length}")
        self.path = dataset_path  # Root folder of the dataset
        self.data = None  # {image_id: {'captions': str, 'image_path': str}} once loaded
        self.tokenizer = tokenizer  # Tokenizer function: caption str -> tokens
        self.vocab = None  # token -> index dict, built by load_data()
        self.max_seq_length = max_seq_length
        self.image_ids = []  # Positional index -> image_id
        self._index_by_id = {}  # image_id -> positional index
        self._caption_tokens = None  # Cached padded token-id lists, one per image
        self._attention_masks = None  # Cached bool masks, one per image
        self.load_data()

    def load_data(self):
        """Read both caption files, build ``self.data`` and the caption cache.

        Raises:
            OSError: if either caption file cannot be opened.
            ValueError: if a non-blank line has no tab separator.
        """
        path_to_tokens = self.path + '/Flickr8k.token.txt'
        path_to_lemmas = self.path + '/Flickr8k.lemma.token.txt'
        raw_captions = defaultdict(list)

        # Token captions are appended first, then lemmatized ones, so index 0 of
        # each list is the first caption from Flickr8k.token.txt (when present).
        for caption_file in (path_to_tokens, path_to_lemmas):
            with open(caption_file, 'r', encoding='utf-8') as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue  # tolerate blank / trailing lines
                    image_name, caption = line.split('\t', 1)
                    image_id = image_name.split('.')[0]  # Get image id before .jpg
                    raw_captions[image_id].append(caption)

        # Only the first caption per image is kept (same selection as before).
        self.data = {
            image_id: {
                'captions': captions[0],
                'image_path': self.path + '/Images/' + image_id + '.jpg',
            }
            for image_id, captions in raw_captions.items()
        }
        self._build_caption_cache()

    def _build_caption_cache(self):
        """(Re)build the vocabulary and the encoded captions for ``self.data``."""
        self.vocab, self._caption_tokens, self._attention_masks = (
            self.__preprocess_caption__(self.data)
        )
        # load_data() stores one caption per image, so the i-th encoded caption
        # belongs to the i-th image ID.
        self.image_ids = list(self.data)
        self._index_by_id = {image_id: index for index, image_id in enumerate(self.image_ids)}

    def __len__(self):
        """Return the number of images."""
        return len(self.data)

    def __preprocess_image__(self, image_path):
        """Load an image as a (3, 224, 224) float32 RGB tensor of raw 0-255 values.

        Raises:
            FileNotFoundError: if OpenCV cannot read ``image_path``.
        """
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals a missing/unreadable file by returning None.
            raise FileNotFoundError(f"Could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR
        image = cv2.resize(image, (224, 224))
        image = np.transpose(image, (2, 0, 1))  # HWC -> CHW
        return torch.tensor(image, dtype=torch.float32)

    @staticmethod
    def _caption_list(value):
        """Normalise a ``self.data`` value into a list of caption strings.

        Accepts a record dict with a ``'captions'`` entry (what load_data()
        builds), a single caption string, or an iterable of caption strings.
        """
        if isinstance(value, dict):
            value = value['captions']
        if isinstance(value, str):
            return [value]
        return list(value)

    @staticmethod
    def _build_vocab(counter, specials):
        """Return a token -> index dict: specials first, then tokens by
        descending frequency with alphabetical tie-breaking."""
        vocab = {token: index for index, token in enumerate(specials)}
        for token, _ in sorted(counter.items(), key=lambda kv: (-kv[1], kv[0])):
            if token not in vocab:
                vocab[token] = len(vocab)
        return vocab

    def __preprocess_caption__(self, image_data):
        """Build a vocabulary and fixed-length encodings for every caption.

        Args:
            image_data: Mapping whose values are record dicts with a
                ``'captions'`` entry, caption strings, or iterables of captions.

        Returns:
            ``(vocab, tokens_arrays, attention_masks)``:
            - ``vocab``: a ``dict`` from token to index. It replaces the legacy
              ``torchtext.vocab.Vocab(counter, ...)`` constructor, which current
              torchtext no longer provides.
            - ``tokens_arrays``: one list of ``max_seq_length`` ids per caption,
              in iteration order. Each list is ``<bos> ... <eos>`` padded with
              ``<pad>``, or truncated so that it still ends in ``<eos>``.
            - ``attention_masks``: one bool tensor per caption, ``False`` on padding.
        """
        specials = list(self._SPECIALS)
        # Tokenize once. The previous version iterated each record dict and so
        # tokenized its key names instead of the caption text.
        tokenized = [
            list(self.tokenizer(caption))
            for value in image_data.values()
            for caption in self._caption_list(value)
        ]

        counter = Counter()
        for tokens in tokenized:
            counter.update(tokens)
        vocab = self._build_vocab(counter, specials)
        unk, pad, bos, eos = (vocab[s] for s in specials)

        tokens_arrays = []
        attention_masks = []
        for words in tokenized:
            ids = [bos] + [vocab.get(word, unk) for word in words] + [eos]
            mask = torch.ones(self.max_seq_length, dtype=torch.bool)
            if len(ids) <= self.max_seq_length:
                mask[len(ids):] = False  # padded positions are not attended
                ids = ids + [pad] * (self.max_seq_length - len(ids))
            else:
                ids = ids[:self.max_seq_length - 1] + [eos]
            tokens_arrays.append(ids)
            attention_masks.append(mask)

        return vocab, tokens_arrays, attention_masks

    def __getitem__(self, item):
        """Return one sample by integer position or by image ID string.

        The returned dict has keys ``image``, ``caption_tokens`` (long tensor
        of length ``max_seq_length``), ``captions`` (the raw caption string)
        and ``attention_mask`` (bool tensor).

        Raises:
            IndexError / KeyError: for an out-of-range position or unknown ID.
        """
        # Encodings are computed once and reused, not rebuilt on every access.
        if getattr(self, '_caption_tokens', None) is None:
            self._build_caption_cache()

        if isinstance(item, str):
            image_id = item
            index = self._index_by_id[image_id]
        else:
            index = int(item)  # also accepts numpy / 0-d torch integers
            image_id = self.image_ids[index]

        record = self.data[image_id]
        image = self.__preprocess_image__(record['image_path'])

        return {
            "image": image,
            "caption_tokens": torch.tensor(self._caption_tokens[index], dtype=torch.long),
            "captions": record['captions'],
            # Clone so callers cannot mutate the cached mask.
            "attention_mask": self._attention_masks[index].clone(),
        }


In [6]:
# max_seq_length counts the <bos>/<eos> markers. 40 is only a starting value;
# tune it to the caption lengths in your data.
data = LoadFlickerData(
    '/Volumes/Aviral/Flicker8k_Dataset',
    max_seq_length=40,
    tokenizer=lambda x: x.split(),
)
# __init__ already calls load_data(), and the class defines no get_data(),
# so read the loaded {image_id: record} dict directly.
sample_data = data.data