In [22]:
import os
import pandas as pd
import spacy
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image

In [14]:
!python -m spacy download en

[38;5;3m⚠ As of spaCy v3.0, shortcuts like 'en' are deprecated. Please use the
full pipeline package name 'en_core_web_sm' instead.[0m
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting en-core-web-sm==3.3.0
  Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.3.0/en_core_web_sm-3.3.0-py3-none-any.whl (12.8 MB)
[K     |████████████████████████████████| 12.8 MB 7.0 MB/s 
[38;5;2m✔ Download and installation successful[0m
You can now load the package via spacy.load('en_core_web_sm')


In [16]:
# Load the small English spaCy pipeline; Vocabulary.tokenizer_eng only uses its tokenizer.
spacy_en = spacy.load("en_core_web_sm")

In [12]:
# Peek at the raw caption table: each image appears once per caption it has.
pd.read_csv(filepath_or_buffer='/content/data/flickr8k/captions.txt')

Unnamed: 0,image,caption
0,1000268201_693b08cb0e.jpg,A child in a pink dress is climbing up a set o...
1,1000268201_693b08cb0e.jpg,A girl going into a wooden building .
2,1000268201_693b08cb0e.jpg,A little girl climbing into a wooden playhouse .
3,1000268201_693b08cb0e.jpg,A little girl climbing the stairs to her playh...
4,1000268201_693b08cb0e.jpg,A little girl in a pink dress going into a woo...
...,...,...
40450,997722733_0cb5439472.jpg,A man in a pink shirt climbs a rock face
40451,997722733_0cb5439472.jpg,A man is rock climbing high in the air .
40452,997722733_0cb5439472.jpg,A person in a red shirt climbing up a rock fac...
40453,997722733_0cb5439472.jpg,A rock climber in a red shirt .


In [19]:
class Vocabulary:
  """Bidirectional token <-> index mapping built from a list of captions.

  Indices 0-3 are reserved for the special tokens <PAD>, <SOS>, <EOS> and
  <UNK>. Every token that occurs at least ``frequency_threshold`` times in the
  sentences passed to ``build_vocabulary`` gets the next free index, in the
  order in which it reaches the threshold.

  Args:
    frequency_threshold: minimum number of occurrences a token needs before
      it is added. Values <= 1 admit every token.
  """

  SPECIAL_TOKENS = ('<PAD>', '<SOS>', '<EOS>', '<UNK>')

  def __init__(self, frequency_threshold):
    # index -> token and token -> index; the special tokens occupy 0..3.
    self.itos = dict(enumerate(self.SPECIAL_TOKENS))
    self.stoi = {token: idx for idx, token in self.itos.items()}
    self.frequency_threshold = frequency_threshold

  def __len__(self):
    """Number of tokens in the vocabulary, special tokens included."""
    return len(self.itos)

  @staticmethod
  def tokenizer_eng(text):
    """Tokenize ``text`` with the global spaCy tokenizer ``spacy_en`` and lowercase each token."""
    return [tok.text.lower() for tok in spacy_en.tokenizer(text)]

  def build_vocabulary(self, sentence_list):
    """Add every token of ``sentence_list`` whose count reaches the threshold.

    Counts are local to this call. New tokens are numbered from the current
    vocabulary size, so calling this more than once extends the vocabulary
    instead of overwriting earlier entries.

    Args:
      sentence_list: iterable of raw caption strings.
    """
    frequencies = {}
    for sentence in sentence_list:
      for word in self.tokenizer_eng(sentence):
        frequencies[word] = frequencies.get(word, 0) + 1
        # The `not in self.stoi` guard keeps each token's first index stable
        # and prevents re-adding a token on a later call.
        if word not in self.stoi and frequencies[word] >= self.frequency_threshold:
          idx = len(self.itos)
          self.stoi[word] = idx
          self.itos[idx] = word

  def numericalize(self, text):
    """Tokenize ``text`` and map each token to its index, using <UNK> for unknown tokens."""
    unk_idx = self.stoi['<UNK>']
    return [self.stoi.get(token, unk_idx) for token in self.tokenizer_eng(text)]


In [20]:
class FlickerDataset(Dataset):
  """Map-style dataset of (image, numericalized caption) pairs read from a caption CSV.

  Each row of ``caption_file`` is one sample, so an image with several
  captions appears once per caption.

  Args:
    root_dir: directory containing the image files named in the ``image`` column.
    caption_file: CSV with at least ``image`` and ``caption`` columns.
    transform: optional callable applied to the RGB PIL image.
    frequency_threshold: minimum token count for inclusion in the vocabulary.
  """

  def __init__(self, root_dir, caption_file, transform=None, frequency_threshold=5):
    self.root_dir = root_dir
    self.df = pd.read_csv(caption_file)

    self.imgs = self.df['image']
    self.captions = self.df['caption']
    self.transform = transform

    # The vocabulary is built from every caption in the file.
    self.vocab = Vocabulary(frequency_threshold)
    self.vocab.build_vocabulary(self.captions.to_list())

  def __len__(self):
    return len(self.df)

  def __getitem__(self, index):
    """Return ``(img, caption_ids)`` for the row at position ``index``.

    ``caption_ids`` is a 1-D tensor ``[<SOS>, token ids..., <EOS>]``.
    """
    # Positional lookup, so this stays correct even if the frame's index is not 0..n-1.
    caption = self.captions.iloc[index]
    img_id = self.imgs.iloc[index]
    img_path = os.path.join(self.root_dir, img_id)
    # The context manager releases the file handle; convert() returns an independent image.
    with Image.open(img_path) as raw_img:
      img = raw_img.convert('RGB')

    if self.transform is not None:
      img = self.transform(img)

    token_ids = [self.vocab.stoi['<SOS>']]  # start-of-sentence marker
    token_ids += self.vocab.numericalize(caption)
    token_ids.append(self.vocab.stoi['<EOS>'])  # end-of-sentence marker
    return img, torch.tensor(token_ids)


In [21]:
# collate_fn lets DataLoader customize how samples are merged into each batch,
# e.g. padding captions so that each batch can have its own (variable) sequence length.
class MyCollate:
  """Batch builder for DataLoader's ``collate_fn``.

  Args:
    pad_idx: token index used to fill the shorter captions of a batch.

  Calling an instance on a list of ``(image, caption)`` tensor pairs returns
  the images concatenated along a new leading batch dimension, plus the
  captions padded by ``pad_sequence`` with its default ``batch_first=False``,
  i.e. shaped ``(max_len, batch_size)``.
  """

  def __init__(self, pad_idx):
    self.pad_idx = pad_idx

  def __call__(self, batch):
    images = torch.cat([sample[0].unsqueeze(0) for sample in batch], dim=0)
    padded = pad_sequence([sample[1] for sample in batch], padding_value=self.pad_idx)
    return images, padded

In [24]:
# Resize every image to a fixed 224x224 so batches can be concatenated, then convert to a tensor.
transform = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
])
dataset = FlickerDataset(
    root_dir="/content/data/flickr8k/images",
    caption_file='/content/data/flickr8k/captions.txt',
    transform=transform,
)

In [50]:
# Index used to pad the shorter captions within each batch.
pad_idx = dataset.vocab.stoi['<PAD>']
print(pad_idx)
# A seeded generator makes the shuffled batch order reproducible across
# Restart & Run All.
loader = DataLoader(
    dataset,
    batch_size=3,
    shuffle=True,
    collate_fn=MyCollate(pad_idx),
    generator=torch.Generator().manual_seed(42),
)

0


In [49]:
# Decode one batch back into tokens to sanity-check numericalization and padding.
img, caption = next(iter(loader))
# Padded sequence-first: rows are positions, columns are the captions of the batch.
print(caption.shape)
# Iterate over however many captions the batch actually holds instead of a hard-coded 3.
for c in range(caption.shape[1]):
  text = [dataset.vocab.itos[i] for i in caption[:, c].tolist()]
  print(text)

torch.Size([13, 3])
['<SOS>', 'a', 'man', 'surfing', 'a', 'wave', '.', '<EOS>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>']
torch.Size([13, 3])
['<SOS>', 'the', 'lady', 'on', 'the', 'porch', 'is', 'wearing', 'a', 'brown', 'jacket', '.', '<EOS>']
torch.Size([13, 3])
['<SOS>', 'a', 'man', 'on', 'a', 'bmx', 'bike', 'in', 'midair', '<EOS>', '<PAD>', '<PAD>', '<PAD>']
