In [None]:
import os
import json

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import cv2
from PIL import Image
#import imageio, skimage

import torch

from collections import Counter

## Example from VirTex

Code adapted from `arch-pre-training/virtex/data/datasets/coco_captions.py`.

In [None]:
# code from virtex/data/datasets/coco_captions.py

from collections import defaultdict
import json
import os
from typing import Dict, List

import cv2
from torch.utils.data import Dataset

In [None]:
# Set to False to import the upstream VirTex implementation instead of the
# local copy defined below.
USE_CUSTOM = True

if not USE_CUSTOM:
    from virtex.data.datasets.coco_captions import CocoCaptionsDataset
else:
    class CocoCaptionsDataset(Dataset):
        r"""
        A PyTorch dataset to read COCO Captions dataset and provide it completely
        unprocessed. This dataset is used by various task-specific datasets
        in :mod:`~virtex.data.datasets` module.

        Args:
            data_root: Path to the COCO dataset root directory.
            split: Name of COCO 2017 split to read. One of ``{"train", "val"}``.
        """

        def __init__(self, data_root: str, split: str):

            # Get paths to image directory and annotation file.
            image_dir = os.path.join(data_root, f"{split}2017")
            annotation_path = os.path.join(
                data_root, "annotations", f"captions_{split}2017.json"
            )
            # Use a context manager so the file handle is closed right away;
            # JSON is UTF-8 by spec, so do not depend on the locale encoding.
            with open(annotation_path, encoding="utf-8") as f:
                captions = json.load(f)

            # Collect list of captions for each image.
            captions_per_image: Dict[int, List[str]] = defaultdict(list)
            for ann in captions["annotations"]:
                captions_per_image[ann["image_id"]].append(ann["caption"])

            # Collect image file for each image (by its ID).
            image_filepaths: Dict[int, str] = {
                im["id"]: os.path.join(image_dir, im["file_name"])
                for im in captions["images"]
            }
            # Keep all annotations in memory. Make a list of tuples, each tuple
            # is ``(image_id, file_path, list[captions])``. Only images that
            # have at least one caption are included.
            self.instances = [
                (im_id, image_filepaths[im_id], im_captions)
                for im_id, im_captions in captions_per_image.items()
            ]

        def __len__(self):
            return len(self.instances)

        def __getitem__(self, idx: int):
            image_id, image_path, captions = self.instances[idx]

            # shape: (height, width, channels), dtype: uint8
            # cv2.imread signals a missing/unreadable file by returning None
            # rather than raising. Keep the best-effort behaviour (report and
            # return ``image=None``), but check explicitly instead of relying
            # on cvtColor failing inside a try/except, which could also leave
            # ``image`` unbound if imread itself raised.
            image = cv2.imread(image_path)
            if image is None:
                print(f"Could not read image: {image_path}")
            else:
                # cv2.imread loads images in BGR order; convert to RGB.
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            return {"image_id": image_id, "image": image, "captions": captions}

Instantiate the dataset on the COCO train split and check how the dataset object behaves.

In [None]:
# Instantiate the dataset on the COCO 2017 train split.
coco_dataset = CocoCaptionsDataset(data_root='../datasets/coco', split='train')

In [None]:
# Preview only the first few (image_id, file_path, captions) tuples; the full
# list has one entry per captioned image and would flood the output.
coco_dataset.instances[:5]

In [None]:
# Index the dataset directly; this dispatches to __getitem__.
coco_dataset[0]

## PubMed Dataset

In [None]:
pubmed_set_dir = "../datasets/ARCH/pubmed_set"  # directory of the ARCH PubMed subset

In [None]:
# Top-level contents of the PubMed subset directory.
os.listdir(path=pubmed_set_dir)

In [None]:
# First five entries of the image folder (os.listdir order is arbitrary).
os.listdir(os.path.join(pubmed_set_dir, 'images'))[:5]

In [None]:
# JSON is UTF-8 by spec; pass the encoding explicitly so loading does not
# depend on the platform's default locale encoding.
with open(os.path.join(pubmed_set_dir, 'captions.json'), 'r', encoding='utf-8') as f:
    pubmed_captions = json.load(f)

In [None]:
# Show the number of entries and a few samples instead of the whole mapping.
len(pubmed_captions), dict(list(pubmed_captions.items())[:3])

The keys appear to be string-encoded integer indices from `'0'` to `'3308'`.

In [None]:
# json.load keeps only the last value for a repeated key, so the loaded dict can
# never contain duplicates; this first assertion is a sanity check only.
assert len(pubmed_captions) == len(set(pubmed_captions))

# The keys must be exactly the string indices '0', '1', ..., str(N - 1) -- no gaps.
assert sorted(pubmed_captions) == sorted(map(str, range(len(pubmed_captions))))

These simple checks confirm that the keys are unique and form a contiguous range of string indices with no gaps.

In [None]:
class ArchPubmedCaptionsDataset(Dataset):
    r"""
    A PyTorch dataset to read the ARCH PubMed subset and provide it completely
    unprocessed. Modeled on VirTex's ``CocoCaptionsDataset``.

    Files read::

        {data_root}/pubmed_set/captions.json
        {data_root}/pubmed_set/images/{uuid}.jpg

    ``captions.json`` maps an index to an annotation dict containing (at least)
    the keys ``"uuid"`` and ``"caption"``.

    Args:
        data_root: Path to the ARCH dataset root directory.
        split: Name of ARCH split to read. Currently ignored: the caption file
            has not been split into train/val yet, so the whole set is loaded.
    """

    def __init__(self, data_root: str, split: str = ''):
        # TODO: once the captions are split into train and validation, read
        # ``pubmed_set/images/{split}`` and ``pubmed_set/captions_{split}.json``.

        # Get paths to image directory and annotation file.
        image_dir = os.path.join(data_root, "pubmed_set", "images")
        # Context manager closes the handle; JSON is UTF-8 by spec.
        with open(os.path.join(data_root, "pubmed_set", "captions.json"), encoding="utf-8") as f:
            captions = json.load(f)

        # Collect list of captions for each image, keyed by its uuid.
        # NOTE(review): uuids are presumably strings -- confirm against the data.
        captions_per_image: Dict[str, List[str]] = defaultdict(list)
        for ann in captions.values():
            captions_per_image[ann["uuid"]].append(ann["caption"])

        # Keep all annotations in memory. Make a list of tuples, each tuple
        # is ``(image_id, file_path, list[captions])``; the file path is
        # derived from the uuid.
        self.instances = [
            (im_id, os.path.join(image_dir, f"{im_id}.jpg"), im_captions)
            for im_id, im_captions in captions_per_image.items()
        ]

    def __len__(self):
        return len(self.instances)

    def __getitem__(self, idx: int):
        image_id, image_path, captions = self.instances[idx]

        # shape: (height, width, channels), dtype: uint8
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising;
            # fail with a clear message rather than a cryptic cvtColor error.
            raise FileNotFoundError(f"Could not read image: {image_path}")
        # cv2.imread loads images in BGR (blue, green, red) order; convert to RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        return {"image_id": image_id, "image": image, "captions": captions}

In [None]:
# Load the full PubMed subset (no train/val split exists yet).
arch_pubmed_dataset = ArchPubmedCaptionsDataset(data_root='../datasets/ARCH')

In [None]:
# Preview only the first few (image_id, file_path, captions) tuples.
arch_pubmed_dataset.instances[:5]

In [None]:
# Fetch the first sample via indexing (dispatches to __getitem__).
test_instance = arch_pubmed_dataset[0]

print(test_instance['image_id'])
print(test_instance['image'].shape)

fig, ax = plt.subplots()
ax.imshow(test_instance['image'])
ax.set_title(f"PubMed image {test_instance['image_id']}")
ax.axis('off')
plt.show()

## Books Dataset

In [None]:
books_set_dir = "../datasets/ARCH/books_set"  # directory of the ARCH books subset

In [None]:
# Top-level contents of the books subset directory.
os.listdir(path=books_set_dir)

In [None]:
# First five entries of the image folder (os.listdir order is arbitrary).
os.listdir(os.path.join(books_set_dir, 'images'))[:5]

In [None]:
# JSON is UTF-8 by spec; pass the encoding explicitly so loading does not
# depend on the platform's default locale encoding.
with open(os.path.join(books_set_dir, 'captions.json'), 'r', encoding='utf-8') as f:
    books_captions = json.load(f)

# Show the number of entries and a few samples instead of the whole mapping.
len(books_captions), dict(list(books_captions.items())[:3])

In [None]:
class ArchBooksCaptionsDataset(Dataset):
    r"""
    A PyTorch dataset to read the ARCH books subset and provide it completely
    unprocessed. Unlike the PubMed subset, one instance here is a *figure*:
    all annotations sharing a ``figure_id`` are grouped together, so a figure
    may contain several images.

    Files read::

        {data_root}/books_set/captions.json
        {data_root}/books_set/images/{uuid}.png

    ``captions.json`` maps an index to an annotation dict containing (at least)
    the keys ``"figure_id"``, ``"uuid"`` and ``"caption"``.

    Args:
        data_root: Path to the ARCH dataset root directory.
        split: Name of ARCH split to read. Currently ignored: the caption file
            has not been split into train/val yet, so the whole set is loaded.
    """

    def __init__(self, data_root: str, split: str = ''):
        # TODO: once the captions are split into train and validation, read
        # ``books_set/images/{split}`` and ``books_set/captions_{split}.json``.

        # Get paths to image directory and annotation file.
        image_dir = os.path.join(data_root, "books_set", "images")
        # Context manager closes the handle; JSON is UTF-8 by spec.
        with open(os.path.join(data_root, "books_set", "captions.json"), encoding="utf-8") as f:
            captions = json.load(f)

        # Group annotations by figure in a single pass. Each annotation appends
        # its caption, image uuid and image path to its figure, so the three
        # per-figure lists stay index-aligned.
        # NOTE(review): figure_id/uuid types are not verified; annotated as str.
        # NOTE(review): if all images of a figure share one caption, it appears
        # once per image in ``captions`` -- consider deduplicating.
        captions_per_figure: Dict[str, List[str]] = defaultdict(list)
        images_per_figure: Dict[str, List[str]] = defaultdict(list)
        image_filepaths_per_figure: Dict[str, List[str]] = defaultdict(list)
        for ann in captions.values():
            figure_id = ann["figure_id"]
            captions_per_figure[figure_id].append(ann["caption"])
            images_per_figure[figure_id].append(ann["uuid"])
            image_filepaths_per_figure[figure_id].append(
                os.path.join(image_dir, f"{ann['uuid']}.png")
            )

        # Keep all annotations in memory. Make a list of tuples, each tuple
        # is ``(figure_id, list[img_ids], list[img_file_paths], list[captions])``.
        self.instances = [
            (figure_id, images_per_figure[figure_id],
             image_filepaths_per_figure[figure_id], figure_captions)
            for figure_id, figure_captions in captions_per_figure.items()
        ]

    def __len__(self):
        return len(self.instances)

    def __getitem__(self, idx: int):
        figure_id, image_ids, image_paths, captions = self.instances[idx]

        images = []
        for image_path in image_paths:
            # shape: (height, width, channels), dtype: uint8
            image = cv2.imread(image_path)
            if image is None:
                # cv2.imread signals failure by returning None instead of
                # raising; fail with a clear message.
                raise FileNotFoundError(f"Could not read image: {image_path}")
            # cv2.imread loads images in BGR order; convert to RGB.
            images.append(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

        return {"figure_id": figure_id, "image_ids": image_ids, "images": images, "captions": captions}

In [None]:
arch_books_dataset = ArchBooksCaptionsDataset('../datasets/ARCH')
# Preview only the first few (figure_id, image_ids, image_paths, captions) tuples.
arch_books_dataset.instances[:5]

In [None]:
# Fetch an example figure (index 9) via indexing (dispatches to __getitem__).
test_instance = arch_books_dataset[9]
print(test_instance.keys())

print('figure_id:', test_instance['figure_id'])
print('image_ids:', test_instance['image_ids'])

# One figure per image; close each after showing it to release memory.
for image_id, image in zip(test_instance['image_ids'], test_instance['images']):
    fig, ax = plt.subplots()
    ax.imshow(image)
    ax.set_title(f"image {image_id}")
    ax.axis('off')
    plt.show()
    plt.close(fig)

print('captions:\n' + '-' * 80 + '\n' + '\n\n'.join(test_instance['captions']))