In [1]:

from torchvision.datasets import VisionDataset, Flickr30k
import torch
from torch.utils.data import DataLoader, Dataset


import os
from PIL import Image
import numpy as np
import nltk
from collections import Counter
import pickle
from tqdm import tqdm
import re
import torchvision.transforms as transforms
import json
from torchvision.datasets.utils import download_url


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
transform = transforms.Compose([
            transforms.RandomCrop(224),
            transforms.RandomHorizontalFlip(), 
            transforms.ToTensor(), 
            transforms.Normalize((0.444, 0.421, 0.385), 
                                 (0.285, 0.277, 0.286))])

In [4]:
def pre_caption(caption,max_words=128):
    caption = re.sub(
        r"([.!\"()*#:;~])",       
        ' ',
        caption.lower(),
    )
    caption = re.sub(
        r"\s{2,}",
        ' ',
        caption,
    )
    caption = caption.rstrip('\n') 
    caption = caption.strip(' ')

    #truncate caption
    caption_words = caption.split(' ')
    if len(caption_words)>max_words:
        caption = ' '.join(caption_words[:max_words])
            
    return caption


In [6]:
class flickr30k(Dataset):
    def __init__(self, transform, image_root, ann_root, split, max_words=128, prompt=''):        
        '''
        image_root (string): Root directory of images (e.g. data/)
        ann_root (string): directory to store the annotation file
        split (string): one of "train" or "test"
        '''        
        train = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_train.json'
        test = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_test.json'
        filename = 'flickr30k_train.json'

        self.split = split
        assert self.split in ("train","test")

        if self.split == "train":
            url = train
        else:
            url = test

        download_url(url,ann_root)
        
        self.annotation = json.load(open(os.path.join(ann_root,filename),'r'))
        self.transform = transform
        self.image_root = image_root
        self.max_words = max_words
        self.prompt = prompt
        
        self.img_ids = {}  
        n = 0
        for ann in self.annotation:
            img_id = ann['image_id']
            if img_id not in self.img_ids.keys():
                self.img_ids[img_id] = n
                n += 1    
        
    def __len__(self):
        return len(self.annotation)
    
    def __getitem__(self, index):    
        
        ann = self.annotation[index]
        
        image_path = os.path.join(self.image_root,ann['image'])        
        image = Image.open(image_path).convert('RGB')   
        image = self.transform(image)
        
        caption = self.prompt+pre_caption(ann['caption'], self.max_words)
        
        #return image, caption, self.img_ids[ann['image_id']] 
        return image, caption

In [7]:
ds = flickr30k(transform=transform,image_root="",ann_root="./flickr30k",split="train")

Using downloaded and verified file: ./flickr30k/flickr30k_train.json


In [8]:
image,caption = ds.__getitem__(0)

In [12]:
print(type(caption))
print(len(caption))
print(caption)

<class 'str'>
81
two young guys with shaggy hair look at their hands while hanging out in the yard


In [14]:
print(type(image))
print(image.shape)
print(image)

<class 'torch.Tensor'>
torch.Size([3, 224, 224])
tensor([[[-1.5029, -1.4341, -1.5029,  ..., -1.5579, -1.4478, -1.3102],
         [-1.4478, -1.4065, -1.4891,  ..., -1.2277, -1.0488, -1.1451],
         [-1.4065, -1.4753, -1.4616,  ..., -0.7736, -1.0625, -1.3515],
         ...,
         [-1.0763, -0.4158, -0.4158,  ..., -1.3928, -0.5397, -1.4616],
         [-0.8699, -1.0901, -0.0581,  ..., -1.4478, -0.8837, -1.3928],
         [-1.0488, -1.2277, -1.0075,  ..., -1.3653, -1.0213, -0.2645]],

        [[-1.2933, -1.1801, -1.1518,  ..., -1.1235, -1.1942, -0.6279],
         [-1.2933, -1.1942, -1.2226,  ..., -0.8828, -0.1041, -0.0900],
         [-1.2792, -1.1518, -1.3358,  ...,  0.1790,  0.0233, -0.5288],
         ...,
         [-0.9819, -0.6704, -0.6279,  ..., -0.9252, -0.0050, -1.2650],
         [-0.7129, -1.2509, -0.2882,  ..., -0.9252, -0.5430, -0.8545],
         [-0.9960, -1.3358, -1.1093,  ..., -1.0951, -0.8403, -0.1749]],

        [[-1.2639, -1.2776, -1.2776,  ..., -1.3462, -1.2913, -1.071

In [38]:
def collate_test(batch):
    imgs = torch.stack([item[0] for item in batch])
    caps = [item[1] for item in batch]
    #for img, cap in batch:
    
    #    x.append(x_)
    #    caps.append(cap)
    return imgs, caps

In [20]:
batch_size = 2
shuffle_img = False
num_workers = 2


In [33]:
def collate_fn(batch):
    imgs = torch.stack([item[0] for item in batch])
    caps = [item[0] for item in batch]
    imgs = [item[1] for item in batch]
    #for img, cap in batch:
    
    #    x.append(x_)

    #    caps.append(cap)
    return caps, imgs

In [39]:
dataloader = DataLoader(ds,batch_size=batch_size,shuffle=shuffle_img,collate_fn=collate_test)
#dataloader = DataLoader(ds,batch_size=batch_size,shuffle=shuffle_img)

In [40]:
img, cap = next(iter(dataloader))


In [41]:
print(type(cap))
print(len(cap))
print(cap)

<class 'list'>
2
['two young guys with shaggy hair look at their hands while hanging out in the yard', 'two young, white males are outside near many bushes']


In [42]:
print(type(img))
print(img.shape)
#print(img)

<class 'torch.Tensor'>
torch.Size([2, 3, 224, 224])


In [28]:
from transformers import BertTokenizerFast

In [43]:
text_model_name = "prajjwal1/bert-medium"

tokenizer = BertTokenizerFast.from_pretrained(text_model_name)

In [45]:
encoded = tokenizer.batch_encode_plus(cap,add_special_tokens=True,max_length=128,padding="max_length",return_attention_mask=True,return_tensors="pt")

In [46]:
input_ids = encoded['input_ids']
attn_mask = encoded['attention_mask']

In [47]:
input_ids

tensor([[  101,  2048,  2402,  4364,  2007, 25741,  2606,  2298,  2012,  2037,
          2398,  2096,  5689,  2041,  1999,  1996,  4220,   102,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,  