In [27]:
import os
import pandas as pd
import spacy
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import torchvision.transforms as transforms

In [21]:
# Only the tokenizer of this pipeline is used (see Vocabulary.tokenizer_eng).
# spaCy v3 removed the "en" shortcut link, so load the small English pipeline
# by its package name first and fall back to the legacy shortcut on older
# installs where only the link exists.
try:
  spacy_eng=spacy.load("en_core_web_sm")
except OSError:
  spacy_eng=spacy.load("en")

class Vocabulary:
  """Bidirectional token <-> index mapping built from caption text.

  Indices 0-3 are reserved for the special tokens ``<PAD>``, ``<SOS>``,
  ``<EOS>`` and ``<UNK>``. Every other token receives the next free index
  once it has been seen ``freq_threshold`` times in ``build_vocabulary``.
  """

  def __init__(self, freq_threshold):
    """Create a vocabulary that holds only the four special tokens.

    Args:
      freq_threshold: number of occurrences a token needs before it is added
        to the vocabulary. Values below 1 are treated as 1 when building.
    """
    self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
    self.stoi = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}
    self.freq_threshold = freq_threshold

  def __len__(self):
    """Return the number of entries, special tokens included."""
    return len(self.itos)

  @staticmethod
  def tokenizer_eng(text):
    """Split ``text`` with the module-level spaCy tokenizer and lower-case each token."""
    return [tok.text.lower() for tok in spacy_eng.tokenizer(text)]

  def build_vocabulary(self, sentence_list):
    """Add every token that occurs at least ``freq_threshold`` times.

    Tokens are indexed in the order in which they reach the threshold. The
    method may be called more than once: existing entries keep their index
    and new tokens are appended after the current highest index. Frequencies
    are counted per call, not accumulated across calls.

    Args:
      sentence_list: iterable of raw sentence strings.
    """
    frequencies = {}
    # A count starts at 1, so a threshold below 1 could never be matched and
    # would silently produce an empty vocabulary; treat it as "keep all".
    threshold = max(1, self.freq_threshold)
    # Continue after the highest index in use instead of a hard-coded 4 so a
    # repeated call cannot overwrite entries added earlier.
    next_idx = max(self.itos) + 1

    for sentence in sentence_list:
      for word in self.tokenizer_eng(sentence):
        frequencies[word] = frequencies.get(word, 0) + 1

        # `==` (not `>=`) registers the word exactly once per call; the
        # membership test keeps indices stable across repeated calls.
        if frequencies[word] == threshold and word not in self.stoi:
          self.stoi[word] = next_idx
          self.itos[next_idx] = word
          next_idx += 1

  def numericalize(self, text):
    """Convert ``text`` to a list of indices, mapping unknown tokens to ``<UNK>``."""
    unk_idx = self.stoi["<UNK>"]
    return [self.stoi.get(token, unk_idx) for token in self.tokenizer_eng(text)]

In [19]:
class FlickrDataset(Dataset):
  """Image/caption pairs read from a captions CSV plus an image directory.

  The CSV must provide an ``image`` column (file name relative to
  ``root_dir``) and a ``caption`` column. A ``Vocabulary`` is built from all
  captions when the dataset is constructed.
  """

  def __init__(self, root_dir, captions_file, transform=None, freq_threshold=5):
    """Load the captions table and build the vocabulary.

    Args:
      root_dir: directory that contains the image files.
      captions_file: path of the CSV file with ``image`` and ``caption`` columns.
      transform: optional callable applied to each loaded RGB image.
      freq_threshold: forwarded to ``Vocabulary``.
    """
    self.root_dir = root_dir
    self.df = pd.read_csv(captions_file)
    self.transform = transform

    # Keep the two columns at hand so item access does not go through the frame.
    self.imgs = self.df["image"]
    self.captions = self.df["caption"]

    vocabulary = Vocabulary(freq_threshold)
    vocabulary.build_vocabulary(self.captions.tolist())
    self.vocab = vocabulary

  def __len__(self):
    """Return the number of rows in the captions table."""
    return len(self.df)

  def __getitem__(self, index):
    """Return ``(image, caption_ids)`` for row ``index``.

    The image is opened as RGB and passed through ``transform`` when one was
    given. The caption becomes a 1-D tensor of token indices wrapped in
    ``<SOS>`` ... ``<EOS>``.
    """
    caption_text = self.captions[index]
    image_path = os.path.join(self.root_dir, self.imgs[index])
    picture = Image.open(image_path).convert("RGB")
    if self.transform is not None:
      picture = self.transform(picture)

    stoi = self.vocab.stoi
    token_ids = [stoi["<SOS>"], *self.vocab.numericalize(caption_text), stoi["<EOS>"]]
    return picture, torch.tensor(token_ids)

In [23]:
class MyCollate:
  """Collate function that batches images and pads captions to equal length."""

  def __init__(self, pad_idx):
    """Remember the index used to fill captions shorter than the longest one."""
    self.pad_idx = pad_idx

  def __call__(self, batch):
    """Merge a list of ``(image, caption)`` samples into two batch tensors.

    Returns:
      A pair ``(images, captions)``. ``images`` joins the per-sample image
      tensors along a new leading dimension. ``captions`` is padded with
      ``pad_idx`` and is time-major (``batch_first=False``), i.e. its first
      dimension is the longest caption length and its second is the batch.
    """
    image_rows = []
    caption_seqs = []
    for sample in batch:
      # Give every image a leading batch axis so the rows can be concatenated.
      image_rows.append(sample[0].unsqueeze(0))
      caption_seqs.append(sample[1])

    stacked_images = torch.cat(image_rows, dim=0)
    padded_captions = pad_sequence(caption_seqs, batch_first=False, padding_value=self.pad_idx)
    return stacked_images, padded_captions

def get_loader(root_folder, annotation_file, transform, batch_size=32, num_workers=8, shuffle=True, pin_memory=True, freq_threshold=5):
  """Build a ``DataLoader`` over a ``FlickrDataset``.

  Args:
    root_folder: directory that contains the image files.
    annotation_file: captions CSV passed to ``FlickrDataset``.
    transform: callable applied to every loaded image.
    batch_size: number of samples per batch.
    num_workers: requested number of loader worker processes. The value is
      capped at the number of CPUs available to this process, because asking
      for more only oversubscribes the host.
    shuffle: whether to reshuffle the samples every epoch.
    pin_memory: forwarded to ``DataLoader``.
    freq_threshold: minimum token frequency forwarded to ``FlickrDataset``.

  Returns:
    The configured ``DataLoader``. Batches are produced by ``MyCollate`` with
    the dataset's ``<PAD>`` index as padding value.
  """
  dataset = FlickrDataset(root_folder, annotation_file, transform=transform, freq_threshold=freq_threshold)

  # Prefer the affinity mask (it reflects container/cpuset limits); fall back
  # to the plain CPU count on platforms without sched_getaffinity.
  try:
    usable_cpus = len(os.sched_getaffinity(0))
  except AttributeError:
    usable_cpus = os.cpu_count() or 1
  num_workers = min(num_workers, usable_cpus)

  pad_idx = dataset.vocab.stoi["<PAD>"]
  loader = DataLoader(
      dataset=dataset,
      batch_size=batch_size,
      num_workers=num_workers,
      shuffle=shuffle,
      pin_memory=pin_memory,
      collate_fn=MyCollate(pad_idx=pad_idx),
  )

  return loader

In [6]:
%cd "/content/drive/MyDrive/Pytorch"

/content/drive/MyDrive/Pytorch


In [7]:
ls

[0m[01;34mflickr8k[0m/


In [37]:
def main():
  """Iterate once over the Flickr8k loader and print the shape of every batch.

  Images are resized to 224x224 and converted to tensors before batching.
  For each batch the image tensor shape is printed first, then the caption
  tensor shape.
  """
  preprocessing = transforms.Compose([
      transforms.Resize((224,224)),
      transforms.ToTensor(),
  ])

  batches = get_loader("flickr8k/images/", annotation_file="flickr8k/captions.txt", transform=preprocessing)

  # The batch index is not needed, so iterate over the loader directly.
  for image_batch, caption_batch in batches:
    print(image_batch.shape)
    print(caption_batch.shape)

In [38]:
# Entry point: run the shape-printing demo when this code executes as the
# main module (which is the case when the notebook cell is run).
if __name__=="__main__":
  main()

  cpuset_checked))
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f3ca8a059e0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f3ca8a059e0>
    assert self._parent_pid == os.getpid(), 'can only test a child process'
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f3ca8a059e0>
Traceback (most recent call last):
AssertionError: can only test a child process
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1328, in __del__
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-pa

torch.Size([32, 3, 224, 224])
torch.Size([25, 32])
torch.Size([32, 3, 224, 224])
torch.Size([24, 32])
torch.Size([32, 3, 224, 224])
torch.Size([28, 32])
torch.Size([32, 3, 224, 224])
torch.Size([36, 32])
torch.Size([32, 3, 224, 224])
torch.Size([31, 32])
torch.Size([32, 3, 224, 224])
torch.Size([22, 32])
torch.Size([32, 3, 224, 224])
torch.Size([23, 32])
torch.Size([32, 3, 224, 224])
torch.Size([23, 32])
torch.Size([32, 3, 224, 224])
torch.Size([23, 32])
torch.Size([32, 3, 224, 224])
torch.Size([27, 32])
torch.Size([32, 3, 224, 224])
torch.Size([28, 32])
torch.Size([32, 3, 224, 224])
torch.Size([26, 32])
torch.Size([32, 3, 224, 224])
torch.Size([28, 32])
torch.Size([32, 3, 224, 224])
torch.Size([22, 32])


KeyboardInterrupt: ignored