In [1]:
import os
import torch
import pandas as pd
from tqdm import tqdm
from torchvision.transforms.v2 import functional as v2f
from torchvision import io
from torch.nn.utils.rnn import pad_sequence
from collections import Counter

In [6]:
# Read the Flickr8k caption table (one row per image/caption pair) and
# preview its first ten rows as the cell output.
captions_path = "drive/MyDrive/data/flickr8k/captions.txt"
data = pd.read_csv(captions_path)
data.head(n=10)

Unnamed: 0,image,caption
0,1000268201_693b08cb0e.jpg,A child in a pink dress is climbing up a set o...
1,1000268201_693b08cb0e.jpg,A girl going into a wooden building .
2,1000268201_693b08cb0e.jpg,A little girl climbing into a wooden playhouse .
3,1000268201_693b08cb0e.jpg,A little girl climbing the stairs to her playh...
4,1000268201_693b08cb0e.jpg,A little girl in a pink dress going into a woo...
5,1001773457_577c3a7d70.jpg,A black dog and a spotted dog are fighting
6,1001773457_577c3a7d70.jpg,A black dog and a tri-colored dog playing with...
7,1001773457_577c3a7d70.jpg,A black dog and a white dog with brown spots a...
8,1001773457_577c3a7d70.jpg,Two dogs of different breeds looking at each o...
9,1001773457_577c3a7d70.jpg,Two dogs on pavement moving toward each other .


Create Vocab

In [7]:
# Build the word-level vocabulary from the `caption` column of `data`.
#
# `data` is only read here (never rebound), so this cell can be re-run
# without first re-running the cell that loads the CSV.

# A word must occur at least this many times in the corpus to get its own id;
# rarer words are left out and are encoded as "<unk>" by the encoder.
MIN_TOKEN_COUNT = 10

# Count lowercased words over all captions. str.split() with no argument
# splits on any run of whitespace, so repeated spaces never produce an
# empty-string "word" that would otherwise end up in the vocabulary.
word_counts = Counter(
    word
    for caption in data["caption"]
    for word in caption.lower().split()
)

# Special tokens occupy the first ids (0..3); their order defines their ids.
special_tokens = ["<unk>", "<pad>", "<start>", "<end>"]
token2int = {token: index for index, token in enumerate(special_tokens)}

# Frequent words get the next free ids in order of first appearance.
# The membership check keeps a corpus word that happens to be spelled like a
# special token from overwriting that token's reserved id.
for word, count in word_counts.items():
    if count >= MIN_TOKEN_COUNT and word not in token2int:
        token2int[word] = len(token2int)

# Derive the reverse mapping from the forward one so the two cannot diverge.
int2token = {index: token for token, index in token2int.items()}


print(f"Number tokens in vocab: {len(token2int)}")
# torch.save does not create missing parent directories.
os.makedirs("model", exist_ok=True)
torch.save({
    "token2int": token2int,
    "int2token": int2token,
}, "model/vocab.pt")
print("Vocab save!")

Number tokens in vocab: 1967
Vocab save!


Save the dataset (resized images and encoded captions) as a binary file

In [8]:
def txt2vec(txt: str, vocab: "dict[str, int] | None" = None) -> torch.Tensor:
    """Encode a caption as a 1-D tensor of vocabulary ids.

    The caption is lowercased and split on whitespace; each word is replaced
    by its id, with words missing from the vocabulary mapped to the "<unk>"
    id. The result is wrapped in the "<start>" and "<end>" ids.

    Args:
        txt: Caption text.
        vocab: Token-to-id mapping containing at least "<unk>", "<start>"
            and "<end>". Defaults to the notebook-level ``token2int``.

    Returns:
        A ``torch.long`` tensor of length ``number_of_words + 2``.
    """
    if vocab is None:
        # Resolved at call time so the most recently built vocabulary is used.
        vocab = token2int

    unk_id = vocab["<unk>"]
    # split() with no argument collapses runs of whitespace, so repeated,
    # leading or trailing spaces do not turn into empty-string words.
    word_ids = [vocab.get(word, unk_id) for word in txt.lower().split()]

    return torch.tensor(
        [vocab["<start>"], *word_ids, vocab["<end>"]],
        dtype=torch.long,
    )

In [13]:
# Build the image/caption tensors and write them to a single file.
#
# For every distinct image in the caption table:
#   * the image is decoded as 3-channel RGB and resized to `size`;
#   * each of its captions is encoded with txt2vec and the encoded captions
#     are concatenated, in table order, into one 1-D id sequence.
# `x` stacks the images along a new first dimension; `y` pads the per-image
# sequences with the "<pad>" id to a common length (batch dimension first).
captions = pd.read_csv("drive/MyDrive/data/flickr8k/captions.txt")
size = [255, 255] # size images
x, y = [], []

# groupby takes only the key here: the `axis` argument (previously passed
# positionally as 0, its default) is deprecated in recent pandas releases.
for key, value in tqdm(captions.groupby("image"), "Processing"):
    img_path = os.path.join("drive/MyDrive/data/flickr8k/Images", key)
    # Force RGB so every image has the same channel count; otherwise one
    # grayscale/RGBA file would make torch.stack fail after the whole pass.
    img = io.read_image(img_path, mode=io.ImageReadMode.RGB)
    img = v2f.resize(img, size, antialias=True)
    x.append(img)

    # Use a separate name so the `captions` DataFrame is not clobbered
    # while its groups are still being iterated.
    caption_vectors = [txt2vec(caption) for caption in value["caption"]]
    y.append(torch.cat(caption_vectors))


x = torch.stack(x)
y = pad_sequence(y, padding_value=token2int["<pad>"], batch_first=True)
# NOTE: the inputs above use paths relative to the working directory while
# this output path is absolute; they refer to the same folder only when the
# working directory is /content.
torch.save({
    "x": x, "y": y,
}, "/content/drive/MyDrive/data/flickr8k/flickr8k.bin")

Processing: 100%|██████████| 8091/8091 [1:12:00<00:00,  1.87it/s]
