In [5]:
import os
import json

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import cv2
from PIL import Image
#import imageio, skimage

import torch

from collections import Counter

## Example from VirTex

code from `arch-pre-training/virtex/data/datasets/coco_captions.py`

In [6]:
# code from virtex/data/datasets/coco_captions.py

from collections import defaultdict
import json
import os
from typing import Dict, List

import cv2
from torch.utils.data import Dataset

## Unified Dataset Class for ARCH Dataset

In [7]:
class ArchCaptionsDatasetRaw(Dataset):
    r"""
    A PyTorch dataset to read ARCH dataset and provide it completely
    unprocessed. This dataset is used by various task-specific datasets
    in :mod:`~virtex.data.datasets` module.

    Args:
        data_root: Path to the ARCH dataset root directory.
        source: Name of ARCH source to read. One of ``{"pubmed", "books"}``.
        split:  Name of ARCH split to read. One of ``{"train", "val"}``.
    """

    def __init__(self, data_root: str, source: str, split: str=''):
        assert source in ['pubmed', 'books'], "source should be one of ['pubmed', 'books']"
        assert split in ['train', 'val', ''], "split should be one of ['train', 'val', '']"

#         TODO: change after splitting the caption files into train and validation
#         
#         # Get paths to image directory and annotation file.
#         image_dir = os.path.join(data_root, "pubmed_set/images", f"{split}")
#         captions = json.load(
#             open(os.path.join(data_root, "pubmed_set", f"captions_{split}.json"))
#         )
        
        # Get path to image directory and record the extensions
        image_dir = os.path.join(data_root, f"{source}_set", "images")
        uuids_to_extensions = {
            file_name.split('.')[0]: file_name.split('.')[1]
            for file_name in os.listdir(image_dir)
        }
        
        # Get path to the annotation file
        captions = json.load(
            open(os.path.join(data_root, f"{source}_set", "captions.json"))
        )
        
        # Collect list of uuids and file paths for each caption
        captions_to_uuids: Dict[str, List[str]] = defaultdict(list)
        captions_to_image_filepaths: Dict[str, List[str]] = defaultdict(list)
        for idx, ann in captions.items():
            if ann['uuid'] in uuids_to_extensions.keys():
                # uuids_to_extensions contains only image uuids from the image dir
                # this means that only uuids with exisitng images will be added
                captions_to_uuids[ann['caption']].append(ann['uuid'])
                captions_to_image_filepaths[ann['caption']].append(
                    os.path.join(image_dir, f"{ann['uuid']}.{uuids_to_extensions[ann['uuid']]}")
                )
        #print(captions_per_image)

        # Keep all annotations in memory. Make a list of tuples, each tuple
        # is ``(list[image_id], list[file_path], captions)``.
        self.instances = [
            (captions_to_uuids[caption], captions_to_image_filepaths[caption], caption)
            for caption in captions_to_image_filepaths.keys()
        ]

    def __len__(self):
        return len(self.instances)

    def __getitem__(self, idx: int):
        image_ids, image_paths, caption = self.instances[idx]

        # shape: (height, width, channels), dtype: uint8
        images = [cv2.imread(image_path) for image_path in image_paths]
        # cv2.imread loads images in BGR (blue, green, red) order
        images = [cv2.cvtColor(image, cv2.COLOR_BGR2RGB) for image in images]

        return {"image_ids": image_ids, "images": images, "caption": caption}

In [11]:
arch_books_dataset = ArchCaptionsDatasetRaw(data_root='../datasets/ARCH',
                                         source="books")
arch_books_dataset.instances

[(['890e2e79-ab0a-4a2e-9d62-b0b6b3d43884',
   'f12c8088-05a5-41a6-80b8-aa4cfa461236'],
  ['../datasets/ARCH/books_set/images/890e2e79-ab0a-4a2e-9d62-b0b6b3d43884.png',
   '../datasets/ARCH/books_set/images/f12c8088-05a5-41a6-80b8-aa4cfa461236.png'],
  ' A, Spindle cell variant of embryonal rhabdomyosarcoma is characterized by fascicles of eosinophilic spindle cells (B), some of which can show prominent paranuclear vacuolisation, as seen in leiomyosarcoma.'),
 (['9a77b172-74e8-4e64-878f-d26b7c27239f'],
  ['../datasets/ARCH/books_set/images/9a77b172-74e8-4e64-878f-d26b7c27239f.png'],
  ' In the anaplastic variant of embryonal rhabdomyosarcoma, the tumor cells have enlarged hyperchromatic and atypical nuclei. Note the presence of a tripolar mitotic figure.'),
 (['c384e7fc-7b29-4a72-a8b2-0f4f8ff9d536',
   'd8f9e62e-e400-43c7-8202-b5c4033710ba'],
  ['../datasets/ARCH/books_set/images/c384e7fc-7b29-4a72-a8b2-0f4f8ff9d536.png',
   '../datasets/ARCH/books_set/images/d8f9e62e-e400-43c7-8202-b5c

In [12]:
arch_pubmed_dataset = ArchCaptionsDatasetRaw(data_root='../datasets/ARCH',
                                         source="pubmed")
arch_pubmed_dataset.instances

[(['3f93c716-8fc9-42e9-bc29-bec52a51ab4b'],
  ['../datasets/ARCH/pubmed_set/images/3f93c716-8fc9-42e9-bc29-bec52a51ab4b.jpg'],
  'ER expression in tumor tissue. IHC staining, original'),
 (['9fcdf1e1-139c-4b63-bf1a-79d83c71f41a'],
  ['../datasets/ARCH/pubmed_set/images/9fcdf1e1-139c-4b63-bf1a-79d83c71f41a.jpg'],
  'Nuclear expression of TS (brown) in a colon carcinoma'),
 (['00f1ad7a-f4b0-4938-b874-089d40a123ce'],
  ['../datasets/ARCH/pubmed_set/images/00f1ad7a-f4b0-4938-b874-089d40a123ce.jpg'],
  'Nuclear expression of E2F1 (brown) in a colon carcinoma. This is higher magnification of the upper portion of a core shown in an inset (lower left corner)'),
 (['9d3aef30-7c8b-4b78-9acf-ec523f952650'],
  ['../datasets/ARCH/pubmed_set/images/9d3aef30-7c8b-4b78-9acf-ec523f952650.jpg'],
  'Cytoplasmic immunoexpression of PD-L1 in oral squamous cell carcinomas with poorer prognosis (OSCCPP). Immunohistochemistry. Total magnification x100'),
 (['b317d529-3626-49fc-9282-e4f28cf3d1cb'],
  ['../data

## Unified Dataset Class + augmentations and collate function

In [2]:
import random
from typing import Callable, Dict, List

import albumentations as alb
import numpy as np
import torch
from torch.utils.data import Dataset

from virtex.data.tokenizers import SentencePieceBPETokenizer
from virtex.data import transforms as T
from virtex.data.datasets.coco_captions import CocoCaptionsDataset


class ArchCaptioningDatasetExtended(Dataset):
    r"""
    A dataset which provides image-caption (forward and backward) pairs from
    a ARCH Captions annotation file. This is used for pretraining tasks which
    use captions - bicaptioning, forward captioning and token classification.

    Args:
        data_root: Path to dataset directory containing images and annotations.
        source: Name of ARCH source to read. One of ``{"pubmed", "books", "both"}``.
            "both" option results in a concatenation of the datasets from "pubmed" and "books"
        split: Name of ARCH split to read. One of ``{"train", "val"}``.
        tokenizer: Tokenizer which maps word tokens to their integer IDs.
        image_transform: List of image transformations, from either
            `albumentations <https://albumentations.readthedocs.io/en/latest/>`_
            or :mod:`virtex.data.transforms`.
        max_caption_length: Maximum number of tokens to keep in caption tokens.
            Extra tokens will be trimmed from the right end of the token list.
    """

    def __init__(
        self,
        data_root: str,
        source: str,
        split: str,
        tokenizer: SentencePieceBPETokenizer,
        image_transform: Callable = T.DEFAULT_IMAGE_TRANSFORM,
        max_caption_length: int = 30,
    ):
        if source != "both":
            self._dset = ArchCaptionsDatasetRaw(data_root, source, split)
        else:
            d1 = ArchCaptionsDatasetRaw(data_root, "books", split)
            d2 = ArchCaptionsDatasetRaw(data_root, "pubmed", split)
            self._dset = d1.__add__(d2)
        self.image_transform = image_transform
        self.caption_transform = alb.Compose(
            [
                T.NormalizeCaption(),
                T.TokenizeCaption(tokenizer),
                T.TruncateCaptionTokens(max_caption_length),
            ]
        )
        self.padding_idx = tokenizer.token_to_id("<unk>")

    def __len__(self):
        return len(self._dset)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:

        # keys: {"image_id", "image", "captions"}
        instance = self._dset[idx]
        image_id, image, captions = (
            instance["image_id"],
            instance["image"],
            instance["captions"],
        )
        caption = random.choice(captions)

        # Transform image-caption pair and convert image from HWC to CHW format.
        # Pass in caption to image_transform due to paired horizontal flip.
        # Caption won't be tokenized/processed here.
        image_caption = self.image_transform(image=image, caption=caption)
        image, caption = image_caption["image"], image_caption["caption"]
        image = np.transpose(image, (2, 0, 1))

        caption_tokens = self.caption_transform(caption=caption)["caption"]
        return {
            "image_id": torch.tensor(image_id, dtype=torch.long),
            "image": torch.tensor(image, dtype=torch.float),
            "caption_tokens": torch.tensor(caption_tokens, dtype=torch.long),
            "noitpac_tokens": torch.tensor(caption_tokens, dtype=torch.long).flip(0),
            "caption_lengths": torch.tensor(len(caption_tokens), dtype=torch.long),
        }

    def collate_fn(
        self, data: List[Dict[str, torch.Tensor]]
    ) -> Dict[str, torch.Tensor]:

        # Pad `caption_tokens` and `masked_labels` up to this length.
        caption_tokens = torch.nn.utils.rnn.pad_sequence(
            [d["caption_tokens"] for d in data],
            batch_first=True,
            padding_value=self.padding_idx,
        )
        noitpac_tokens = torch.nn.utils.rnn.pad_sequence(
            [d["noitpac_tokens"] for d in data],
            batch_first=True,
            padding_value=self.padding_idx,
        )
        return {
            "image_id": torch.stack([d["image_id"] for d in data], dim=0),
            "image": torch.stack([d["image"] for d in data], dim=0),
            "caption_tokens": caption_tokens,
            "noitpac_tokens": noitpac_tokens,
            "caption_lengths": torch.stack([d["caption_lengths"] for d in data]),
        }


In [4]:
arch_books_dataset_extended = ArchCaptioningDatasetExtended(data_root='../datasets/ARCH', source="books", split="")
len(arch_books_dataset_extended)

TypeError: __init__() missing 1 required positional argument: 'tokenizer'