In [1]:
import os
import json

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import cv2
from PIL import Image
#import imageio, skimage

import torch

from collections import Counter

## Example from VirTex

code from `arch-pre-training/virtex/data/datasets/coco_captions.py`

In [2]:
# code from virtex/data/datasets/coco_captions.py

from collections import defaultdict
import json
import os
from typing import Dict, List

import cv2
from torch.utils.data import Dataset

## Unified Dataset Class for ARCH Dataset

In [3]:
class ArchCaptionsDatasetRaw(Dataset):
    r"""
    A PyTorch dataset to read ARCH dataset and provide it completely
    unprocessed. This dataset is used by various task-specific datasets
    in :mod:`~virtex.data.datasets` module.

    Args:
        data_root: Path to the ARCH dataset root directory.
        source: Name of ARCH source to read. One of ``{"pubmed", "books", "both"}``. Default value: "both".
        split:  Name of ARCH split to read. One of ``{"train", "val", "all"}``.
    """

    def __init__(self, data_root: str, source: str='both', split: str=''):
        allowed_source_values = ['pubmed', 'books', 'both']
        assert source in allowed_source_values, f"source should be one of {allowed_source_values}"
        allowed_split_values = ['train', 'val', 'all']
        assert split in allowed_split_values, f"split should be one of {allowed_split_values}"

        # Get path to the annotation file
        captions = json.load(
            open(os.path.join(data_root, "annotations", f"captions_{split}.json"))
        )
        
        # Collect list of uuids and file paths for each caption
        captions_to_uuids: Dict[str, List[str]] = defaultdict(list)
        captions_to_image_filepaths: Dict[str, List[str]] = defaultdict(list)
        for idx, ann in captions.items():
            if (source == "both") or (source == ann['source']):
                # if source="both", then no filtering needed
                # if source is one of the ["books", "pubmed"], LHS=False, RHS will filter the needed captions

                # make a check that the image exist before adding its `uuid` or `path`
                assert os.path.exists(ann['path']), f"{ann['path']} does not exist!"
                
                captions_to_uuids[ann['caption']].append(ann['uuid'])
                captions_to_image_filepaths[ann['caption']].append(ann['path'])
        #print(captions_per_image)

        # Keep all annotations in memory. Make a list of tuples, each tuple
        # is ``(list[image_id], list[file_path], captions)``.
        self.instances = [
            (captions_to_uuids[caption], captions_to_image_filepaths[caption], caption)
            for caption in captions_to_image_filepaths.keys()
        ]

    def __len__(self):
        return len(self.instances)

    def __getitem__(self, idx: int):
        image_ids, image_paths, caption = self.instances[idx]

        # shape: (height, width, channels), dtype: uint8
        images = [cv2.imread(image_path) for image_path in image_paths]
        # cv2.imread loads images in BGR (blue, green, red) order
        images = [cv2.cvtColor(image, cv2.COLOR_BGR2RGB) for image in images]

        return {"image_ids": image_ids, "images": images, "caption": caption}

In [4]:
arch_books_dataset = ArchCaptionsDatasetRaw(data_root='../datasets/ARCH',
                                            source="books",
                                            split='all')
len(arch_books_dataset.instances)

3210

In [5]:
arch_pubmed_dataset = ArchCaptionsDatasetRaw(data_root='../datasets/ARCH',
                                             source="pubmed",
                                             split="all")
len(arch_pubmed_dataset.instances)

3285

In [6]:
arch_dataset_raw = ArchCaptionsDatasetRaw(data_root='../datasets/ARCH',
                                          source="both",
                                          split="train")
len(arch_dataset_raw.instances)

5196

## Unified Dataset Class + augmentations and collate function

In [7]:
import random
from typing import Callable, Dict, List

import albumentations as alb
import numpy as np
import torch
from torch.utils.data import Dataset

from virtex.data.tokenizers import SentencePieceBPETokenizer
from virtex.data import transforms as T
from virtex.data.datasets.coco_captions import CocoCaptionsDataset


class ArchCaptioningDatasetExtended(Dataset):
    r"""
    A dataset which provides image-caption (forward and backward) pairs from
    a ARCH Captions annotation file. This is used for pretraining tasks which
    use captions - bicaptioning, forward captioning and token classification.

    Args:
        data_root: Path to dataset directory containing images and annotations.
        source: Name of ARCH source to read. One of ``{"pubmed", "books", "both"}``.
            "both" option results in a concatenation of the datasets from "pubmed" and "books"
        split: Name of ARCH split to read. One of ``{"train", "val", "all"}``.
        tokenizer: Tokenizer which maps word tokens to their integer IDs.
        image_transform: List of image transformations, from either
            `albumentations <https://albumentations.readthedocs.io/en/latest/>`_
            or :mod:`virtex.data.transforms`.
        max_caption_length: Maximum number of tokens to keep in caption tokens.
            Extra tokens will be trimmed from the right end of the token list.
    """

    def __init__(
        self,
        data_root: str,
        source: str,
        split: str,
        tokenizer: SentencePieceBPETokenizer,
        image_transform: Callable = T.DEFAULT_IMAGE_TRANSFORM,
        max_caption_length: int = 30,
    ):
        self._dset = ArchCaptionsDatasetRaw(data_root=data_root, source=source, split=split)
        self.image_transform = image_transform
        self.caption_transform = alb.Compose(
            [
                T.NormalizeCaption(),
                T.TokenizeCaption(tokenizer),
                T.TruncateCaptionTokens(max_caption_length),
            ]
        )
        self.padding_idx = tokenizer.token_to_id("<unk>")

    def __len__(self):
        return len(self._dset)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:

        # keys: {"image_id", "image", "captions"}
        instance = self._dset[idx]
        image_ids, images, caption = (
            instance["image_ids"],
            instance["images"],
            instance["caption"],
        )
        
        # list of np.arrays -> torch tensor
        # transformation needs to be done before putting into tensor
        # (need the same size for each image to put into tensor)
        # 
        # TODO: think how to do apply image_transform to each image before putting into tensor
        images = torch.stack([torch.from_numpy(image) for image in images], dim=0)
        
        print(image_ids)
        print(images)
        print(caption)

        # Transform image-caption pair and convert image from HWC to CHW format.
        # Pass in caption to image_transform due to paired horizontal flip.
        # Caption won't be tokenized/processed here.
        images_caption = self.image_transform(image=images, caption=caption)
        images, caption = images_caption["images"], image_caption["caption"]
        image = np.transpose(image, (2, 0, 1))

        caption_tokens = self.caption_transform(caption=caption)["caption"]
        return {
            "images_ids": torch.tensor(image_ids, dtype=torch.str),
            "images": torch.tensor(image, dtype=torch.float),
            "caption_tokens": torch.tensor(caption_tokens, dtype=torch.long),
            "noitpac_tokens": torch.tensor(caption_tokens, dtype=torch.long).flip(0),
            "caption_lengths": torch.tensor(len(caption_tokens), dtype=torch.long),
        }

    def collate_fn(
        self, data: List[Dict[str, torch.Tensor]]
    ) -> Dict[str, torch.Tensor]:

        # Pad `caption_tokens` and `masked_labels` up to this length.
        caption_tokens = torch.nn.utils.rnn.pad_sequence(
            [d["caption_tokens"] for d in data],
            batch_first=True,
            padding_value=self.padding_idx,
        )
        noitpac_tokens = torch.nn.utils.rnn.pad_sequence(
            [d["noitpac_tokens"] for d in data],
            batch_first=True,
            padding_value=self.padding_idx,
        )
        return {
            "image_id": torch.stack([d["image_ids"] for d in data], dim=0),
            "image": torch.stack([d["images"] for d in data], dim=0),
            "caption_tokens": caption_tokens,
            "noitpac_tokens": noitpac_tokens,
            "caption_lengths": torch.stack([d["caption_lengths"] for d in data]),
        }


In [8]:
arch_tokenizer = SentencePieceBPETokenizer("../datasets/vocab/arch_10k.model")

arch_dataset_extended = ArchCaptioningDatasetExtended(data_root='../datasets/ARCH',
                                                      source="both", split="train",
                                                      tokenizer=arch_tokenizer)
len(arch_dataset_extended)

5196

In [9]:
arch_dataset_extended.__getitem__(0)

RuntimeError: stack expects each tensor to be equal size, but got [671, 907, 3] at entry 0 and [674, 910, 3] at entry 1