In [23]:
import torch
import torchvision.transforms as transforms
import torch.utils.data as data
import os
import pickle
import numpy as np
import nltk
from PIL import Image
from build_vocab import Vocabulary
from pycocotools.coco import COCO


class CocoDataset(data.Dataset):
    """Map-style dataset over COCO caption annotations.

    Items are indexed by annotation id (not by image id): item ``i`` is the
    image referenced by the i-th annotation together with that annotation's
    caption. Usable with ``torch.utils.data.DataLoader``.
    """

    def __init__(self, root, json, vocab, transform=None):
        """Remember where the images live and index the annotation file.

        Args:
            root: directory that contains the image files.
            json: path of the COCO annotation file handed to ``COCO``.
            vocab: vocabulary wrapper; called with a token, returns its id.
            transform: optional callable applied to the loaded PIL image.
        """
        self.root = root
        self.coco = COCO(json)
        # One dataset entry per annotation key.
        self.ids = list(self.coco.anns.keys())
        self.vocab = vocab
        self.transform = transform

    def __getitem__(self, index):
        """Load the (image, caption) pair for the annotation at ``index``.

        Returns:
            image: the RGB image, passed through ``self.transform`` when one
                was given (otherwise the PIL image itself).
            target: 1-D ``torch.Tensor`` holding the id of ``<start>``, the
                ids of the lower-cased caption tokens, then the id of ``<end>``.
        """
        annotation = self.coco.anns[self.ids[index]]
        caption_text = annotation['caption']
        # The annotation only stores the image id; the file name comes from
        # the image record returned by loadImgs.
        file_name = self.coco.loadImgs(annotation['image_id'])[0]['file_name']

        image = Image.open(os.path.join(self.root, file_name)).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)

        # Tokenise the lower-cased caption and map every token to its id,
        # framed by the start/end markers.
        words = nltk.tokenize.word_tokenize(str(caption_text).lower())
        word_ids = [self.vocab('<start>')]
        word_ids += [self.vocab(word) for word in words]
        word_ids.append(self.vocab('<end>'))
        return image, torch.Tensor(word_ids)

    def __len__(self):
        """Number of annotations (i.e. image/caption pairs) in the dataset."""
        return len(self.ids)

def collate_fn(data):
    """Build a padded mini-batch from (image, caption) samples.

    Captions have different lengths, so the default collate function cannot
    stack them. This version orders the samples by caption length (longest
    first) and right-pads every caption with zeros up to the longest one.

    Args:
        data: iterable of (image, caption) pairs.
            - image: torch tensor of shape (3, H, W); all images in the batch
              must share one shape so they can be stacked.
            - caption: 1-D torch tensor of variable length.

    Returns:
        images: torch tensor of shape (batch_size, 3, H, W).
        targets: long tensor of shape (batch_size, padded_length), where
            padded_length is the longest caption length in the batch.
        lengths: list with the unpadded length of each caption, in the same
            (descending) order as the rows of ``targets``.

    Raises:
        ValueError: if ``data`` contains no samples.
    """
    # sorted() rather than data.sort(): the caller's list is left untouched
    # and tuples / other iterables are accepted as well.
    batch = sorted(data, key=lambda sample: len(sample[1]), reverse=True)
    if not batch:
        raise ValueError("collate_fn received an empty batch")
    images, captions = zip(*batch)

    # Merge images (tuple of 3-D tensors -> one 4-D tensor).
    images = torch.stack(images, 0)

    # Merge captions (tuple of 1-D tensors -> one zero-padded 2-D tensor).
    lengths = [len(cap) for cap in captions]
    targets = torch.zeros(len(captions), max(lengths), dtype=torch.long)
    for row, cap in enumerate(captions):
        # Overwrite the leading part of the row; the tail keeps its zeros.
        targets[row, :lengths[row]] = cap
    return images, targets, lengths



# 定义参数

In [24]:
dataset_root = '/Volumes/SD/Dataset/'
vocab_path = dataset_root+'coco/vocab.pkl'   # path of the pickled vocabulary wrapper
image_dir = dataset_root+'coco/resized2014'  # directory of the resized images
caption_path = dataset_root+'coco/annotations/captions_train2014.json'     # path of the train annotation json file

batch_size = 128
num_workers = 2
# Side length of the random crop. It was referenced below without ever being
# defined (NameError on a fresh kernel); 224 matches the [3, 224, 224] image
# shape this notebook reports.
crop_size = 224

# Load vocabulary wrapper
# NOTE: pickle.load can execute arbitrary code; only load a vocab file you trust.
with open(vocab_path, 'rb') as f:
    vocab = pickle.load(f)

# Image preprocessing, normalization for the pretrained resnet
transform = transforms.Compose([
    transforms.RandomCrop(crop_size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406),
                         (0.229, 0.224, 0.225))])

# get_loader

In [25]:
coco = CocoDataset(root=image_dir,
                   json=caption_path,
                   vocab=vocab,
                   transform=transform)

# Data loader for COCO dataset
# This will return (images, captions, lengths) for each iteration.
# images: a tensor of shape (batch_size, 3, 224, 224).
# captions: a tensor of shape (batch_size, padded_length).
# lengths: a list indicating valid length for each caption. length is (batch_size).

# batch_size / num_workers come from the parameter cell rather than being
# repeated here as literals, so there is a single place to tune them.
data_loader = torch.utils.data.DataLoader(dataset=coco,
                                          batch_size=batch_size,
                                          shuffle=True,
                                          num_workers=num_workers,
                                          collate_fn=collate_fn)

loading annotations into memory...
Done (t=1.71s)
creating index...
index created!


# 测试 collate_fn 方法

In [33]:
# Build ten (image, caption) samples by hand, following the same steps as
# CocoDataset.__getitem__, to have a small batch for trying out collate_fn.
coco = COCO(caption_path)
# Materialise the annotation ids once; rebuilding this list inside the loop
# would copy every key of the annotation index on each iteration.
ann_ids = list(coco.anns.keys())
# NOTE: `data` rebinds the name used for the `torch.utils.data` import alias;
# the name is kept because the following cell reads `data`.
data = []
for index in range(10):
    ann_id = ann_ids[index]                         # id of the index-th annotation
    caption = coco.anns[ann_id]['caption']          # caption text of that annotation
    img_id = coco.anns[ann_id]['image_id']          # id of the image it describes
    path = coco.loadImgs(img_id)[0]['file_name']    # file name looked up from the image id

    image = Image.open(os.path.join(image_dir, path)).convert('RGB')
    image = transform(image)

    # Convert caption (string) to word ids.
    tokens = nltk.tokenize.word_tokenize(str(caption).lower())   # lower-case and split into tokens
    caption = []
    caption.append(vocab('<start>'))                        # frame the caption with start/end markers
    caption.extend([vocab(token) for token in tokens])      # map each token to its id
    caption.append(vocab('<end>'))
    target = torch.Tensor(caption)

    data.append([image, target])

loading annotations into memory...
Done (t=0.83s)
creating index...
index created!


In [41]:
# Exercise collate_fn itself on the hand-built samples instead of repeating a
# copy of its body here, so this cell really tests the function the
# DataLoader uses.
images, targets, lengths = collate_fn(data)

# Display the stacked image batch shape: (number of samples, 3, height, width).
images.shape

torch.Size([10, 3, 224, 224])
