In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import random
from PIL import Image
from tqdm import tqdm
import json

from datasets import load_dataset, Image

import torch
from torchvision import transforms
from torch.utils.data import DataLoader

import sys
sys.path.append("..")
import config

from tokenizer.tokenizer import ByteLevelBPE, TokenizerHF
from dataset.loader import DatasetLoader

In [None]:
# Setup device-agnostic code
device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed every RNG this notebook imports, not only Python's `random`.
# NumPy's global RNG and torch's default generator are independent of
# `random`, so seeding `random` alone leaves torch-/NumPy-driven randomness
# (presumably including the dataloader shuffling used below) different on
# every run.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Batch sizes for the train and test dataloaders, read from the project config.
batch_size_train, batch_size_test = (
    config.BATCH_SIZE_TRAIN,
    config.BATCH_SIZE_TEST,
)

In [None]:
# Build the project's dataset wrapper for the configured dataset and load it.
data_loader = DatasetLoader(
    dataset_type=config.DATASET,
    batch_size_train=batch_size_train,
    batch_size_test=batch_size_test,
    shuffle_test=True,
)
data_loader.load_data()

# Fetch the dataloaders for both splits (train first, then test).
train_dataloader, test_dataloader = (
    data_loader.get_train_dataloader(),
    data_loader.get_test_dataloader(),
)

In [None]:
# dataset stats
print(f"Number of training samples: {len(train_dataloader.dataset)}")
print(f"Number of test samples: {len(test_dataloader.dataset)}")

# plot the first batch of each split (test first, then train)
for dataloader in (test_dataloader, train_dataloader):
    data_iter = iter(dataloader)
    batch = next(data_iter)
    image_tensor = batch['pixel_values']
    captions = batch['description']

    # Size the grid from the batch itself rather than the configured batch
    # size, so a batch with fewer images than configured cannot index past
    # the end of the tensor.
    num_images = len(image_tensor)
    print(image_tensor.shape)  # identical for every sample, so print it once

    # squeeze=False always returns a 2-D array of Axes; with the default, a
    # single-image batch would return a bare Axes and `axs[i]` would fail.
    fig, axs = plt.subplots(1, num_images, figsize=(15, 5), squeeze=False)
    axs = axs[0]
    for i in range(num_images):
        # Move the first axis last for imshow
        # (assumes channel-first (C, H, W) images -- TODO confirm).
        img = image_tensor[i].cpu().permute(1, 2, 0).numpy()
        print('Caption:', captions[i])
        axs[i].imshow(img)
        axs[i].axis('off')
    plt.show()

## Test Tokenizer Implementation on the Dataset

In [None]:
# Build the tokenizer selected by config.TOKENIZER_TYPE.
if config.TOKENIZER_TYPE == config.TokenizerType.HF:
    tokenizer = TokenizerHF() # will take special tokens from config
elif config.TOKENIZER_TYPE == config.TokenizerType.BPE:
    special_tokens = [config.SpecialTokens.PAD, config.SpecialTokens.BOS, config.SpecialTokens.EOS]
    tokenizer = ByteLevelBPE(special_tokens=special_tokens)
    # Load the saved tokenizer data from the configured folder / file prefix.
    tokenizer.load(folder=config.TOKENIZER_DATA_PATH, filename_prefix=config.TOKENIZER_FILENAME_PREFIX)
else:
    # Fail fast with a clear message. Without this branch `tokenizer` is
    # either undefined (NameError on the next line) or, worse, a stale object
    # left over from an earlier run of this cell is silently reused.
    raise ValueError(f"Unsupported tokenizer type: {config.TOKENIZER_TYPE!r}")
pad_idx = tokenizer.get_padding_token_id()
vocab_size = tokenizer.get_vocab_size()
print(f"Tokenizer vocab size: {vocab_size}")

## TODO
* Implement [special tokens](https://en.wikipedia.org/wiki/Byte-pair_encoding):
    - `<PAD>` for padding
    - `<BOS>` for beginning of sentence
    - `<EOS>` for end of sentence
* The tokenizer should first assign indices to the special tokens, then to the rest of the vocabulary.
* During encoding, the tokenizer first tokenizes the text, then:
  * If the sequence is longer than `MAX_TEXT_SEQUENCE_LENGTH - 2`, it truncates it to this length.
  * It adds `<BOS>` at the beginning of the token sequence.
  * It adds `<EOS>` at the end of the token sequence.
  * If the resulting sequence is shorter than `MAX_TEXT_SEQUENCE_LENGTH`, it pads the token sequence with `<PAD>` tokens.
  
```python
# in tokenizer
tokens = tokens[:MAX_TEXT_SEQUENCE_LENGTH - 2]
tokens = [BOS_ID] + tokens + [EOS_ID]
tokens += [PAD_ID] * (MAX_TEXT_SEQUENCE_LENGTH - len(tokens))
```

## Test Updated Tokenizer

In [None]:
# Round-trip check on individual descriptions: tokenize, encode, decode, and
# verify that the decoded text is a prefix of the original description.
samples = 10  # number of batches to check
max_seq_length = 5  # small limit; presumably chosen so truncation is exercised

bos_id = tokenizer.token_to_id(config.SpecialTokens.BOS.value)
eos_id = tokenizer.token_to_id(config.SpecialTokens.EOS.value)
print(f'BOS ID: {bos_id}, EOS ID: {eos_id}, PAD ID: {pad_idx}')

# zip(range(samples), ...) visits exactly `samples` batches. The previous
# `samples < 0` countdown visited samples + 1 batches and mutated the counter.
for _, batch in zip(range(samples), train_dataloader):
    desc_batch = batch['description']
    for desc in desc_batch:
        print('Description: ', desc)
        print('Tokenized: ', tokenizer.tokenize(desc))
        encoded = tokenizer.encode(desc, max_seq_length=max_seq_length, verbose=True)['input_ids']
        # Drop a leading batch axis if encode() returned one.
        if encoded.dim() > 1:
            encoded = encoded.squeeze(0)
        print(f'Encoded: {encoded}')
        print(encoded.shape)
        decoded = tokenizer.decode(encoded)
        print(f'Decoded: {decoded}\n')
        # decode() may return a list; use its first entry in that case.
        if isinstance(decoded, list):
            decoded = decoded[0]
        decoded_stripped = tokenizer.strip(decoded)
        # cut desc to be the same length as decoded_stripped, because the
        # encoded sequence may have been truncated to max_seq_length
        desc_cut = desc[:len(decoded_stripped)]
        assert desc_cut == decoded_stripped, "Decoded text does not match original!"

In [None]:
# test batched encoding and decoding
# Uses `max_seq_length` as currently defined in the notebook.
samples = 10  # number of batches to check

# zip(range(samples), ...) visits exactly `samples` batches. The previous
# `samples < 0` countdown visited samples + 1 batches and mutated the counter.
for _, batch in zip(range(samples), train_dataloader):
    desc = batch['description']
    print('Description: ', desc)
    encoded = tokenizer.encode_batched(desc, max_seq_length=max_seq_length, verbose=False)['input_ids']
    print(f'Encoded: {encoded}')
    print(encoded.shape)
    decoded = tokenizer.decode_batched(encoded)
    # decode_batched must not hand back a tensor.
    assert not isinstance(decoded, torch.Tensor)
    print(f'Decoded: {decoded}\n')
    for i, decoded_i in enumerate(decoded):
        decoded_stripped = tokenizer.strip(decoded_i)
        # cut desc to be the same length as decoded_stripped, because the
        # encoded sequence may have been truncated to max_seq_length
        desc_cut = desc[i][:len(decoded_stripped)]
        assert desc_cut == decoded_stripped, "Decoded text does not match original!"

In [None]:
# test all train and test samples
max_seq_length = config.MAX_TEXT_SEQUENCE_LENGTH

# One loop covers both splits (train first, then test) instead of two
# copy-pasted loops; the split name only appears in the assertion message.
for split_name, dataloader in (("train", train_dataloader), ("test", test_dataloader)):
    for batch in dataloader:
        desc = batch['description']
        encoded = tokenizer.encode_batched(desc, max_seq_length=max_seq_length, verbose=False)['input_ids']
        decoded = tokenizer.decode_batched(encoded)
        # decode_batched must not hand back a tensor.
        assert not isinstance(decoded, torch.Tensor)
        for i, decoded_i in enumerate(decoded):
            decoded_stripped = tokenizer.strip(decoded_i)
            # Compare against a prefix of the original, because the encoded
            # sequence may have been truncated to max_seq_length.
            desc_cut = desc[i][:len(decoded_stripped)]
            assert desc_cut == decoded_stripped, f"Decoded ({split_name}) text does not match original!"