In [11]:

from torchvision.datasets import VisionDataset, Flickr30k
import torch
from torch.utils.data import DataLoader, Dataset


import os
from PIL import Image
import numpy as np
import nltk
from collections import Counter
import pickle
from tqdm import tqdm
import re
import torchvision.transforms as transforms
import json
from torchvision.datasets.utils import download_url


In [13]:
# Training-time image pipeline: random 224x224 crop and random horizontal flip,
# followed by tensor conversion and per-channel normalisation.
transform = transforms.Compose([
    # pad_if_needed pads images that are smaller than the crop size instead of
    # letting RandomCrop raise on them.
    transforms.RandomCrop(224, pad_if_needed=True),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    # Per-channel mean / std used for normalisation.
    # NOTE(review): presumably statistics of the training images — confirm the source.
    transforms.Normalize((0.444, 0.421, 0.385),
                         (0.285, 0.277, 0.286)),
])

In [5]:
# Number of samples per mini-batch handed to the DataLoader.
batch_size = 32
# Passed as DataLoader(shuffle=...); False keeps samples in annotation order.
shuffle_img = False
# NOTE(review): defined here but not passed to the DataLoader built in this notebook — confirm whether it should be.
num_workers = 2




In [63]:
def pre_caption(caption,max_words=128):
    """Normalise a raw caption string.

    The caption is lower-cased, each of the characters ``. ! " ( ) * # : ; ~``
    is replaced by a space, runs of two or more whitespace characters are
    collapsed to a single space, and trailing newlines plus surrounding
    spaces are removed. If the result has more than ``max_words``
    space-separated words, only the first ``max_words`` are kept.

    Args:
        caption (str): raw caption text.
        max_words (int): maximum number of words to keep.

    Returns:
        str: the cleaned (and possibly truncated) caption.
    """
    lowered = caption.lower()
    without_punct = re.sub(r"([.!\"()*#:;~])", ' ', lowered)
    collapsed = re.sub(r"\s{2,}", ' ', without_punct)
    cleaned = collapsed.rstrip('\n').strip(' ')

    # Truncate only when the word budget is exceeded.
    words = cleaned.split(' ')
    if len(words) <= max_words:
        return cleaned
    return ' '.join(words[:max_words])

In [125]:
class flickr30k_train(Dataset):
    """Flickr30k training annotations exposed as ``(image, caption)`` samples.

    Every entry of the annotation JSON is one sample. The code reads the
    ``'image'`` (path joined onto ``image_root``), ``'caption'`` and
    ``'image_id'`` keys of each entry.
    """

    def __init__(self, transform, image_root, ann_root, max_words=128, prompt=''):
        '''
        transform (callable): applied to each RGB PIL image in __getitem__
        image_root (string): Root directory of images (e.g. data/)
        ann_root (string): directory to store the annotation file
        max_words (int): word limit passed to pre_caption for every caption
        prompt (string): text prepended to every cleaned caption
        '''
        url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_train.json'
        filename = 'flickr30k_train.json'

        download_url(url,ann_root)

        # Use a context manager so the annotation file handle is closed as
        # soon as the JSON has been parsed.
        with open(os.path.join(ann_root,filename),'r') as ann_file:
            self.annotation = json.load(ann_file)
        self.transform = transform
        self.image_root = image_root
        self.max_words = max_words
        self.prompt = prompt

        # Map every distinct image_id to a consecutive integer, numbered in
        # order of first appearance in the annotation list.
        self.img_ids = {}
        for ann in self.annotation:
            self.img_ids.setdefault(ann['image_id'], len(self.img_ids))

    def __len__(self):
        """Return the number of annotation entries (one per caption sample)."""
        return len(self.annotation)

    def __getitem__(self, index):
        """Load sample ``index`` and return ``(transformed_image, caption)``."""
        ann = self.annotation[index]

        image_path = os.path.join(self.image_root,ann['image'])
        # convert() loads the pixel data, so the file can be closed right after.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')
        image = self.transform(image)

        caption = self.prompt+pre_caption(ann['caption'], self.max_words)

        # self.img_ids[ann['image_id']] holds the integer image index; it is
        # intentionally not part of the returned tuple.
        return image, caption

In [126]:
# Build the training dataset. The annotation file is stored under ./flickr30k;
# image_root is empty, so image paths from the annotations are used as-is.
ds = flickr30k_train(
    transform=transform,
    image_root="",
    ann_root="./flickr30k",
)

Using downloaded and verified file: ./flickr30k/flickr30k_train.json


In [127]:
# Fetch the first sample through normal indexing rather than calling the dunder directly.
image, caption = ds[0]

In [128]:
# Sanity check: the dataset yields the caption as a plain str.
type(caption)

str

In [138]:
def collate_fn(batch):
    """Collate ``(image, caption)`` samples into one batch.

    Args:
        batch: sequence of samples whose item 0 is an image tensor and whose
            item 1 is the caption.

    Returns:
        tuple: ``(imgs, caps)`` where ``imgs`` is the images stacked along a
        new leading dimension and ``caps`` is a list of the captions in the
        same order.
    """
    image_batch = torch.stack([sample[0] for sample in batch])
    captions = []
    for sample in batch:
        captions.append(sample[1])
    return image_batch, captions

In [139]:
# Batch the dataset; collate_fn stacks the images and keeps captions as a list of strings.
dataloader = DataLoader(
    ds,
    batch_size=batch_size,
    shuffle=shuffle_img,
    collate_fn=collate_fn,
)

In [140]:
# Pull the first batch: img is the stacked image tensor, cap the list of captions.
img, cap = next(iter(dataloader))


In [141]:
# Debug dump of one whole batch.
# NOTE(review): this builds a fresh iterator and prints the full image tensor; consider printing shapes instead.
print(next(iter(dataloader)))

(tensor([[[[-1.2414, -1.3928, -1.3790,  ..., -0.7323, -0.8011, -0.4296],
          [-1.4203, -1.5579, -1.3377,  ..., -0.7185, -0.5809, -0.4709],
          [-1.4065, -1.3790, -1.3377,  ..., -0.7323, -0.7323, -0.6635],
          ...,
          [-0.6085, -0.6085,  0.6437,  ...,  1.4693,  1.7582,  1.3454],
          [ 0.3960, -0.0443, -0.2645,  ...,  1.5656,  1.4418,  1.7170],
          [-0.3333, -0.8561, -0.3470,  ...,  1.6894,  1.4142,  1.4280]],

         [[-0.4864, -0.7695, -0.8120,  ...,  0.8019,  0.9010,  1.0426],
          [-1.0385, -1.1376, -0.6563,  ...,  0.7878,  0.9010,  1.0709],
          [-1.2650, -1.0102, -0.7695,  ...,  0.8869,  0.8869,  1.0568],
          ...,
          [ 0.2923,  0.2781,  1.4673,  ...,  1.3824,  1.6231,  1.5381],
          [ 0.4905,  0.6462,  0.5188,  ...,  1.7080,  1.3682,  1.7929],
          [ 0.2923, -0.0900,  0.4480,  ...,  1.8071,  1.3541,  1.5240]],

         [[-1.0445, -1.2227, -1.2639,  ...,  0.9437,  1.1768,  1.2042],
          [-1.2365, -1.3187, 

In [142]:
# Sanity check: captions come back from collate_fn as a Python list.
type(cap)

list

In [143]:
# Sanity check: images come back from collate_fn as a single torch.Tensor.
type(img)

torch.Tensor

In [144]:
# Display the image batch fetched above.
# NOTE(review): this shows the full tensor; img.shape would be a more compact check.
img

tensor([[[[-1.3653, -0.8699, -1.3377,  ...,  1.4555,  1.4005,  1.4968],
          [-1.3653, -1.5441, -1.3515,  ...,  1.5106,  1.5106,  1.3592],
          [-1.3240, -1.4203, -1.3790,  ...,  1.4418,  1.5106,  0.9326],
          ...,
          [ 1.3179,  0.3685, -1.5029,  ..., -1.5441, -1.5166, -1.5579],
          [-0.4296,  0.9464, -0.2507,  ..., -1.3240, -1.2552, -1.5304],
          [-0.2232,  0.9877,  1.1666,  ..., -1.3928, -1.4616, -1.5441]],

         [[-1.1235, -0.6138, -1.2650,  ...,  2.0761,  1.9912,  2.0478],
          [-1.1376, -1.1801, -1.2226,  ...,  2.0195,  2.0903,  2.0761],
          [-1.2367, -1.2226, -1.2084,  ...,  2.0195,  2.0903,  1.5523],
          ...,
          [ 1.6514,  0.8586, -1.3783,  ..., -1.5199, -1.5199, -1.5057],
          [-0.0900,  1.1417,  0.2640,  ..., -1.3075, -1.3358, -1.2792],
          [ 0.2498,  1.2691,  1.5947,  ..., -1.3075, -1.2367, -1.3075]],

         [[-1.1953, -0.7017, -1.2913,  ...,  2.1503,  1.9310,  1.9721],
          [-1.2090, -1.2913, -

In [145]:
# Display the cleaned caption strings of the batch fetched above.
cap

['two young guys with shaggy hair look at their hands while hanging out in the yard',
 'two young, white males are outside near many bushes',
 'two men in green shirts are standing in a yard',
 'a man in a blue shirt standing in a garden',
 'two friends enjoy time spent together',
 'several men in hard hats are operating a giant pulley system',
 'workers look down from up above on a piece of equipment',
 'two men working on a machine wearing hard hats',
 'four men on top of a tall structure',
 'three men on a large rig',
 'a child in a pink dress is climbing up a set of stairs in an entry way',
 'a little girl in a pink dress going into a wooden cabin',
 'a little girl climbing the stairs to her playhouse',
 'a little girl climbing into a wooden playhouse',
 'a girl going into a wooden building',
 'someone in a blue shirt and hat is standing on stair and leaning against a window',
 'a man in a blue shirt is standing on a ladder cleaning a window',
 'a man on a ladder cleans the window 