In [None]:

import os
import pandas as pd
import torch
from torch.nn.utils.rnn import pad_sequence  # pad batch
from torch.utils.data import DataLoader, Dataset
from PIL import Image  # Load img
import torchvision.transforms as transforms
import spacy  # for tokenizer

In [None]:
# Load spaCy's small English pipeline. spacy.load() does not download anything:
# the "en_core_web_sm" package must already be installed
# (e.g. `python -m spacy download en_core_web_sm`).
# Its tokenizer is used by Vocabulary.tokenizer_eng.
spacy_eng = spacy.load("en_core_web_sm")

class Vocabulary:
    """Bidirectional token <-> integer-id mapping built from caption text.

    Ids 0-3 are reserved for the special tokens ``<PAD>``, ``<SOS>``, ``<EOS>``
    and ``<UNK>``. Any other token receives an id only once it has been seen
    at least ``freq_threshold`` times during a ``build_vocabulary`` call.
    """

    def __init__(self, freq_threshold):
        # itos: id -> token, stoi: token -> id. The two dicts are kept as exact inverses.
        self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        self.stoi = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}
        self.freq_threshold = freq_threshold

    def __len__(self):
        """Number of entries in the vocabulary, special tokens included."""
        return len(self.itos)

    @staticmethod
    def tokenizer_eng(text):
        """Split ``text`` with the spaCy English tokenizer and lowercase each token."""
        return [tok.text.lower() for tok in spacy_eng.tokenizer(text)]

    def build_vocabulary(self, sentence_list):
        """Count tokens in ``sentence_list`` and add the frequent ones.

        Counts are kept per call. A token is added the first time its count
        reaches ``freq_threshold``, and tokens are numbered in the order they
        qualify. New ids start at ``len(self)``, so calling this more than once
        extends the vocabulary instead of overwriting existing entries.

        Args:
            sentence_list: iterable of raw sentence strings.
        """
        frequencies = {}
        next_idx = len(self.itos)

        for sentence in sentence_list:
            for word in self.tokenizer_eng(sentence):
                frequencies[word] = frequencies.get(word, 0) + 1
                # `>=` plus the membership check adds each word exactly once and
                # also behaves sensibly when freq_threshold <= 1.
                if frequencies[word] >= self.freq_threshold and word not in self.stoi:
                    self.stoi[word] = next_idx
                    self.itos[next_idx] = word
                    next_idx += 1

    def numericalize(self, text):
        """Convert ``text`` into a list of token ids.

        Tokens missing from the vocabulary map to the ``<UNK>`` id. No
        ``<SOS>``/``<EOS>`` markers are added.
        """
        unk_idx = self.stoi["<UNK>"]
        return [self.stoi.get(token, unk_idx) for token in self.tokenizer_eng(text)]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f7133e58440>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1510, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1474, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/usr/lib/python3.7/multiprocessing/process.py", line 140, in join
    res = self._popen.wait(timeout)
  File "/usr/lib/python3.7/multiprocessing/popen_fork.py", line 45, in wait
    if not wait([self.sentinel], timeout):
  File "/usr/lib/python3.7/multiprocessing/connection.py", line 921, in wait
    ready = selector.select(timeout)
  File "/usr/lib/python3.7/selectors.py", line 415, in select
    fd_event_list = self._selector.poll(timeout)
KeyboardInterrupt: 


In [None]:

class FlickrDataset(Dataset):
    """Map-style dataset pairing images with numericalized captions.

    Each row of ``captions_file`` is one sample. An image that has several
    caption rows therefore appears once per row.
    """

    def __init__(self, root_dir, captions_file, transform=None, freq_threshold=5):
        """
        Args:
            root_dir: directory containing the files named in the ``image`` column.
            captions_file: CSV file with ``image`` and ``caption`` columns.
            transform: optional callable applied to each RGB PIL image.
            freq_threshold: minimum count for a caption token to enter the vocabulary.
        """
        self.root_dir = root_dir
        self.df = pd.read_csv(captions_file)
        self.transform = transform

        # Image file names and their captions, row-aligned.
        self.imgs = self.df["image"]
        self.captions = self.df["caption"]

        # Build the vocabulary from every caption in the file.
        self.vocab = Vocabulary(freq_threshold)
        self.vocab.build_vocabulary(self.captions.tolist())

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        """Return ``(image, caption_ids)`` for row ``index``.

        ``caption_ids`` is a 1-D integer tensor ``[<SOS>, token ids..., <EOS>]``.
        ``image`` is the RGB PIL image, or whatever ``transform`` returns for it.
        """
        # Positional lookup, so indexing does not depend on the DataFrame's index labels.
        caption = self.captions.iloc[index]
        img_id = self.imgs.iloc[index]

        # Close the file handle deterministically. convert() returns a new,
        # fully loaded image that stays valid after the file is closed.
        with Image.open(os.path.join(self.root_dir, img_id)) as raw_img:
            img = raw_img.convert("RGB")

        if self.transform is not None:
            img = self.transform(img)

        numericalized_caption = [self.vocab.stoi["<SOS>"]]
        numericalized_caption += self.vocab.numericalize(caption)
        numericalized_caption.append(self.vocab.stoi["<EOS>"])

        return img, torch.tensor(numericalized_caption)



In [None]:

class MyCollate:
    """DataLoader ``collate_fn`` that batches images and pads variable-length captions."""

    def __init__(self, pad_idx, batch_first=False):
        """
        Args:
            pad_idx: token id used to pad shorter captions.
            batch_first: if False (default), targets are shaped (max_len, batch);
                if True, they are shaped (batch, max_len).
        """
        self.pad_idx = pad_idx
        self.batch_first = batch_first

    def __call__(self, batch):
        """Turn a list of ``(image, caption_ids)`` samples into batched tensors.

        Returns:
            imgs: the images stacked along a new leading batch dimension.
            targets: captions padded with ``pad_idx`` to the longest length;
                the layout is controlled by ``batch_first``.
        """
        imgs = torch.stack([item[0] for item in batch], dim=0)
        targets = pad_sequence(
            [item[1] for item in batch],
            batch_first=self.batch_first,
            padding_value=self.pad_idx,
        )
        return imgs, targets

In [None]:

def get_loader(
    root_folder,
    annotation_file,
    transform,
    batch_size=32,
    num_workers=8,
    shuffle=True,
    pin_memory=True,
    freq_threshold=5,
):
    """Build a FlickrDataset and a DataLoader that pads captions per batch.

    Args:
        root_folder: directory containing the image files.
        annotation_file: captions CSV passed to FlickrDataset.
        transform: callable applied to each image.
        batch_size, num_workers, shuffle, pin_memory: forwarded to DataLoader.
        freq_threshold: minimum token count for the dataset's vocabulary.

    Returns:
        ``(loader, dataset)``. The dataset is returned so callers can use ``dataset.vocab``.
    """
    dataset = FlickrDataset(
        root_folder, annotation_file, transform=transform, freq_threshold=freq_threshold
    )

    # The padding id comes from the dataset's vocabulary (FlickrDataset.vocab).
    pad_idx = dataset.vocab.stoi["<PAD>"]

    loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=shuffle,
        pin_memory=pin_memory,
        collate_fn=MyCollate(pad_idx=pad_idx),
    )

    return loader, dataset


In [None]:
%%time
# Colab only: mount Google Drive at /content/drive so the Flickr8k image
# folder and captions file under /content/drive/MyDrive/... are readable.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
CPU times: user 43.1 ms, sys: 7.34 ms, total: 50.4 ms
Wall time: 2.74 s


<__main__.Vocabulary object at 0x7f70b7136110>


In [None]:

if __name__ == "__main__":
    # Resize every image to 224x224, then convert it to a tensor.
    transform = transforms.Compose(
        [transforms.Resize((224, 224)), transforms.ToTensor()]
    )

    loader, dataset = get_loader(
        "/content/drive/MyDrive/Colab Notebooks/ithome/Flickr8k/Images/",
        "/content/drive/MyDrive/Colab Notebooks/ithome/Flickr8k/captions.txt",
        transform=transform,
    )
    print(len(dataset.vocab))

    # DataLoader has no `.head()`; pull a single batch to sanity-check the shapes.
    imgs, captions = next(iter(loader))
    print(imgs.shape)
    print(captions.shape)

2994


  cpuset_checked))


AttributeError: ignored