In [None]:
import os
import json

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import cv2
from PIL import Image
#import imageio, skimage

import torch

from collections import Counter

In [None]:
# Directory under the ARCH dataset root that is listed below.
ANNOTATIONS_DIR = '../datasets/ARCH/annotations'
# Show which files are present in that directory.
os.listdir(ANNOTATIONS_DIR)

## Example from VirTex

Code from `arch-pre-training/virtex/data/datasets/coco_captions.py`.

In [None]:
# code from virtex/data/datasets/coco_captions.py

from collections import defaultdict
import json
import os
from typing import Dict, List

import cv2
from torch.utils.data import Dataset

In [None]:
# Toggle between the local copy of the class (True) and the packaged virtex one (False).
USE_CUSTOM = True

if not USE_CUSTOM:
    from virtex.data.datasets.coco_captions import CocoCaptionsDataset
else:
    class CocoCaptionsDataset(Dataset):
        r"""
        A PyTorch dataset to read COCO Captions dataset and provide it completely
        unprocessed. This dataset is used by various task-specific datasets
        in :mod:`~virtex.data.datasets` module.

        Args:
            data_root: Path to the COCO dataset root directory. Images are
                expected under ``<data_root>/<split>2017`` and captions in
                ``<data_root>/annotations/captions_<split>2017.json``.
            split: Name of COCO 2017 split to read. One of ``{"train", "val"}``.
        """

        def __init__(self, data_root: str, split: str):

            # Get paths to image directory and annotation file.
            image_dir = os.path.join(data_root, f"{split}2017")
            annotation_path = os.path.join(
                data_root, "annotations", f"captions_{split}2017.json"
            )
            # Use a context manager so the file handle is closed as soon as
            # the JSON has been parsed.
            with open(annotation_path) as annotation_file:
                captions = json.load(annotation_file)

            # Collect list of captions for each image.
            captions_per_image: Dict[int, List[str]] = defaultdict(list)
            for ann in captions["annotations"]:
                captions_per_image[ann["image_id"]].append(ann["caption"])

            # Collect image file for each image (by its ID).
            image_filepaths: Dict[int, str] = {
                im["id"]: os.path.join(image_dir, im["file_name"])
                for im in captions["images"]
            }
            # Keep all annotations in memory. Make a list of tuples, each tuple
            # is ``(image_id, file_path, list[captions])``.
            self.instances = [
                (im_id, image_filepaths[im_id], image_captions)
                for im_id, image_captions in captions_per_image.items()
            ]

        def __len__(self):
            """Return the number of images that have at least one caption."""
            return len(self.instances)

        def __getitem__(self, idx: int):
            """
            Load one instance.

            Returns:
                Dict with ``image_id``, ``image`` (RGB array, or ``None`` when
                the file could not be read) and ``captions``.
            """
            image_id, image_path, captions = self.instances[idx]

            # cv2.imread signals an unreadable or missing file by returning
            # None instead of raising, so check for it explicitly rather than
            # relying on cvtColor failing with an opaque assertion.
            image = cv2.imread(image_path)
            if image is None:
                # Best effort, as before: report the problem and hand back None.
                print(f"Could not read image: {image_path}")
            else:
                # shape: (height, width, channels), dtype: uint8
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            return {"image_id": image_id, "image": image, "captions": captions}

In [None]:
from virtex.data.datasets.arch_captions import ArchCaptionsDatasetRaw

Check how the dataset object behaves.

In [None]:
# Build the COCO reader for the training split; every caption is loaded up front.
coco_dataset = CocoCaptionsDataset(data_root='../datasets/coco', split='train')

In [None]:
# Peek at the first few (image_id, file_path, captions) tuples only; displaying
# the full list would flood the cell output.
coco_dataset.instances[:5]

In [None]:
# Fetch the first instance through the normal subscript protocol.
coco_dataset[0]

## PubMed Dataset

In [None]:
# Root folder of the PubMed subset; the cells below list it and its images/ subfolder.
pubmed_set_dir = '../datasets/ARCH/pubmed_set'

In [None]:
# Show the top-level entries of the PubMed subset folder.
os.listdir(pubmed_set_dir)

In [None]:
# os.listdir returns entries in arbitrary order, so sort before slicing to make
# the preview reproducible across runs; build the path with os.path.join
# instead of string concatenation.
sorted(os.listdir(os.path.join(pubmed_set_dir, 'images')))[:5]

In [None]:
## Old version

# class ArchPubmedCaptionsDataset(Dataset):
#     r"""
#     A PyTorch dataset to read ARCH Pubmed dataset and provide it completely
#     unprocessed. This dataset is used by various task-specific datasets
#     in :mod:`~virtex.data.datasets` module.

#     Args:
#         data_root: Path to the ARCH dataset root directory.
#         split: Name of ARCH split to read. One of ``{"train", "val"}``.
#     """

#     def __init__(self, data_root: str, split: str=''):

# #         TODO: change after splitting the caption files into train and validation
# #         
# #         # Get paths to image directory and annotation file.
# #         image_dir = os.path.join(data_root, "pubmed_set/images", f"{split}")
# #         captions = json.load(
# #             open(os.path.join(data_root, "pubmed_set", f"captions_{split}.json"))
# #         )
        
#         # Get paths to image directory and annotation file.
#         image_dir = os.path.join(data_root, "pubmed_set/images")
#         captions = json.load(
#             open(os.path.join(data_root, "pubmed_set", "captions.json"))
#         )
        
#         # Collect list of captions for each image.
#         captions_per_image: Dict[int, List[str]] = defaultdict(list)
#         for idx, ann in captions.items():
#             captions_per_image[ann['uuid']].append(ann['caption'])
#         #print(captions_per_image)

#         # Collect image file for each image (by its ID).
#         image_filepaths: Dict[int, str] = {
#             ann["uuid"]: os.path.join(image_dir, f"{ann['uuid']}.jpg")
#             for idx, ann in captions.items()
#         }
#         # Keep all annotations in memory. Make a list of tuples, each tuple
#         # is ``(image_id, file_path, list[captions])``.
#         self.instances = [
#             (im_id, image_filepaths[im_id], captions_per_image[im_id])
#             for im_id in captions_per_image.keys()
#         ]

#     def __len__(self):
#         return len(self.instances)

#     def __getitem__(self, idx: int):
#         image_id, image_path, captions = self.instances[idx]

#         # shape: (height, width, channels), dtype: uint8
#         image = cv2.imread(image_path)
#         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # cv2.imread loads images in BGR (blue, green, red) order

#         return {"image_id": image_id, "image": image, "captions": captions}
    
    
# # test
# arch_pubmed_dataset = ArchPubmedCaptionsDataset('../datasets/ARCH')
# arch_pubmed_dataset.instances

# test_instance = arch_pubmed_dataset.__getitem__(0)

# print(test_instance['image_id'])
# print(test_instance['image'].shape)


# plt.imshow(test_instance['image'])

In [None]:
class ArchPubmedCaptionsDataset(Dataset):
    r"""
    A PyTorch dataset to read ARCH Pubmed dataset and provide it completely
    unprocessed. Annotations are grouped by caption text, so one instance is
    a caption together with every image that carries it.

    Args:
        data_root: Path to the ARCH dataset root directory. Image files are
            listed from ``<data_root>/pubmed_set/images`` and annotations are
            read from ``<data_root>/annotations/captions_<split>.json``.
        split: Name of ARCH split to read. One of ``{"train", "val", "all"}``.
            The value is interpolated into the annotation file name, so the
            default ``''`` resolves to ``captions_.json``.
    """

    def __init__(self, data_root: str, split: str=''):

        # Record the uuid (file name up to the first dot) of every file in the
        # image directory. Only membership is needed, so a set is enough and
        # entries without an extension no longer raise IndexError.
        image_dir = os.path.join(data_root, "pubmed_set/images")
        available_uuids = {
            file_name.split('.')[0] for file_name in os.listdir(image_dir)
        }

        # Load the annotation file; the context manager closes the handle.
        annotation_path = os.path.join(
            data_root, "annotations", f"captions_{split}.json"
        )
        with open(annotation_path) as annotation_file:
            captions = json.load(annotation_file)

        # Collect list of uuids and file paths for each caption
        captions_to_uuids: Dict[str, List[str]] = defaultdict(list)
        captions_to_image_filepaths: Dict[str, List[str]] = defaultdict(list)
        for ann in captions.values():
            # Skip annotations whose uuid has no file in the image directory,
            # so only uuids with existing images are kept.
            if ann['uuid'] in available_uuids:
                captions_to_uuids[ann['caption']].append(ann['uuid'])
                captions_to_image_filepaths[ann['caption']].append(ann['path'])

        # Keep all annotations in memory. Make a list of tuples, each tuple
        # is ``(list[image_id], list[file_path], caption)``.
        self.instances = [
            (captions_to_uuids[caption], image_paths, caption)
            for caption, image_paths in captions_to_image_filepaths.items()
        ]

    def __len__(self):
        """Return the number of distinct captions that have at least one image."""
        return len(self.instances)

    def __getitem__(self, idx: int):
        """
        Load every image that belongs to the caption at position ``idx``.

        Returns:
            Dict with ``image_ids`` (list of uuids), ``images`` (list of RGB
            arrays, in the same order as the ids) and ``caption``.
        """
        image_ids, image_paths, caption = self.instances[idx]

        # shape: (height, width, channels), dtype: uint8
        # cv2.imread loads images in BGR (blue, green, red) order, so each one
        # is converted to RGB right after it is read.
        images = [
            cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
            for image_path in image_paths
        ]

        return {"image_ids": image_ids, "images": images, "caption": caption}

In [None]:
# Build the PubMed dataset from the "all" annotation file.
arch_pubmed_dataset = ArchPubmedCaptionsDataset(data_root='../datasets/ARCH', split="all")

# Number of distinct captions, then a short preview of the
# (uuids, file paths, caption) tuples instead of the whole list.
print(len(arch_pubmed_dataset))
arch_pubmed_dataset.instances[:5]

In [None]:
# Inspect the first PubMed instance: caption, image count, then each image.
test_instance = arch_pubmed_dataset[0]

print(test_instance['caption'], '\n')
print("Total images:", len(test_instance['images']), '\n')

# Ids and images are built from the same instance tuple, so walk them in lockstep.
for img_id, img in zip(test_instance['image_ids'], test_instance['images']):
    print(img_id)
    plt.imshow(img)
    plt.show()

In [None]:
# Inspect PubMed instance 19: caption, image count, then each image.
test_instance = arch_pubmed_dataset[19]

print(test_instance['caption'], '\n')
print("Total images:", len(test_instance['images']), '\n')

# Ids and images are built from the same instance tuple, so walk them in lockstep.
for img_id, img in zip(test_instance['image_ids'], test_instance['images']):
    print(img_id)
    plt.imshow(img)
    plt.show()

Check that all the images recorded in the instances exist.

In [None]:
# For every instance, one boolean per recorded image path: does the file exist?
exist_status_list = [
    list(map(os.path.exists, paths))
    for _, paths, _ in arch_pubmed_dataset.instances
]

# Total number of existing image files across all instances.
# 3309 unique uuids, 3309 images, 3309 entries in the captions path
sum(map(sum, exist_status_list))

In [None]:
# Collapse each instance to a single flag: True when all of its images exist.
exist_status_list_compressed = list(map(all, exist_status_list))
# All images exist where they should; the count is the same as the number of unique captions
all(exist_status_list_compressed), sum(exist_status_list_compressed)

## Books Dataset

In [None]:
# Root folder of the Books subset; the cells below list it and its images/ subfolder.
books_set_dir = '../datasets/ARCH/books_set'

In [None]:
# Show the top-level entries of the Books subset folder.
os.listdir(books_set_dir)

In [None]:
# os.listdir returns entries in arbitrary order, so sort before slicing to make
# the preview reproducible across runs; build the path with os.path.join
# instead of string concatenation.
sorted(os.listdir(os.path.join(books_set_dir, 'images')))[:5]

In [None]:
## Old Version

# class ArchBooksCaptionsDataset(Dataset):
#     r"""
#     A PyTorch dataset to read ARCH Books dataset and provide it completely
#     unprocessed. This dataset is used by various task-specific datasets
#     in :mod:`~virtex.data.datasets` module.

#     Args:
#         data_root: Path to the ARCH dataset root directory.
#         split: Name of ARCH split to read. One of ``{"train", "val"}``.
#     """

#     def __init__(self, data_root: str, split: str=''):

# #         TODO: change after splitting the caption files into train and validation
# #         
# #         # Get paths to image directory and annotation file.
# #         image_dir = os.path.join(data_root, "pubmed_set/images", f"{split}")
# #         captions = json.load(
# #             open(os.path.join(data_root, "pubmed_set", f"captions_{split}.json"))
# #         )
        
#         # Get paths to image directory and annotation file.
#         image_dir = os.path.join(data_root, "books_set/images")
#         captions = json.load(
#             open(os.path.join(data_root, "books_set", "captions.json"))
#         )
                
#         # Collect list of captions for each figure.
#         captions_per_figure: Dict[int, List[str]] = defaultdict(list)
#         for idx, ann in captions.items():
#             captions_per_figure[ann['figure_id']].append(ann['caption'])
#         #print(captions_per_image)
        
#         # Collect image file for each image (by its ID).
#         image_filepaths: Dict[int, str] = {
#             ann["uuid"]: os.path.join(image_dir, f"{ann['uuid']}.png")
#             for idx, ann in captions.items()
#         }
            
#         # Collect list of images and image paths for each figure.
#         images_per_figure: Dict[int, List[str]] = defaultdict(list)
#         image_filepaths_per_figure: Dict[int, List[str]] = defaultdict(list)
#         for idx, ann in captions.items():
#             images_per_figure[ann['figure_id']].append(ann['uuid'])
#             image_filepaths_per_figure[ann['figure_id']].append(image_filepaths[ann["uuid"]])
#         #print(captions_per_image)
            
        
#         # Keep all annotations in memory. Make a list of tuples, each tuple
#         # is ``(figure_id, list[img_ids], list[img_file_paths], list[captions])``.
#         self.instances = [
#             (figure_id, images_per_figure[figure_id],
#              image_filepaths_per_figure[figure_id], captions_per_figure[figure_id])
#             for figure_id in captions_per_figure.keys()
#         ]

#     def __len__(self):
#         return len(self.instances)

#     def __getitem__(self, idx: int):
#         figure_id, image_ids, image_paths, captions = self.instances[idx]
        
#         images = []
#         for image_path in image_paths:
#             # shape: (height, width, channels), dtype: uint8
#             image = cv2.imread(image_path)
#             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
#             images.append(image)
        
#         return {"figure_id": figure_id, "image_ids": image_ids, "images": images, "captions": captions}


## test

# arch_books_dataset = ArchBooksCaptionsDataset('../datasets/ARCH')
# arch_books_dataset.instances

# test_instance = arch_books_dataset.__getitem__(9)
# print(test_instance.keys())

# print('figure_id:', test_instance['figure_id'])
# print('image_ids:', test_instance['image_ids'])

# for image in test_instance['images']:
#     plt.imshow(image)
#     plt.show()
    
# print('captions:\n', '-'*80, '\n', '\n\n'.join(test_instance['captions']))

In [None]:
class ArchBooksCaptionsDataset(Dataset):
    r"""
    A PyTorch dataset to read ARCH Books dataset and provide it completely
    unprocessed. Annotations are grouped by caption text, so one instance is
    a caption together with every image that carries it.

    Args:
        data_root: Path to the ARCH dataset root directory. Image files are
            listed from ``<data_root>/books_set/images`` and annotations are
            read from ``<data_root>/annotations/captions_<split>.json``.
        split: Name of ARCH split to read. One of ``{"train", "val", "all"}``.
            The value is interpolated into the annotation file name, so the
            default ``''`` resolves to ``captions_.json``.
    """

    def __init__(self, data_root: str, split: str=''):

        # Record the uuid (file name up to the first dot) of every file in the
        # image directory. Only membership is needed, so a set is enough and
        # entries without an extension no longer raise IndexError.
        image_dir = os.path.join(data_root, "books_set/images")
        available_uuids = {
            file_name.split('.')[0] for file_name in os.listdir(image_dir)
        }

        # Load the annotation file; the context manager closes the handle.
        annotation_path = os.path.join(
            data_root, "annotations", f"captions_{split}.json"
        )
        with open(annotation_path) as annotation_file:
            captions = json.load(annotation_file)

        # Collect list of uuids and file paths for each caption
        captions_to_uuids: Dict[str, List[str]] = defaultdict(list)
        captions_to_image_filepaths: Dict[str, List[str]] = defaultdict(list)
        for ann in captions.values():
            # Skip annotations whose uuid has no file in the image directory,
            # so only uuids with existing images are kept.
            if ann['uuid'] in available_uuids:
                captions_to_uuids[ann['caption']].append(ann['uuid'])
                captions_to_image_filepaths[ann['caption']].append(ann['path'])

        # Keep all annotations in memory. Make a list of tuples, each tuple
        # is ``(list[image_id], list[file_path], caption)``.
        self.instances = [
            (captions_to_uuids[caption], image_paths, caption)
            for caption, image_paths in captions_to_image_filepaths.items()
        ]

    def __len__(self):
        """Return the number of distinct captions that have at least one image."""
        return len(self.instances)

    def __getitem__(self, idx: int):
        """
        Load every image that belongs to the caption at position ``idx``.

        Returns:
            Dict with ``image_ids`` (list of uuids), ``images`` (list of RGB
            arrays, in the same order as the ids) and ``caption``.
        """
        image_ids, image_paths, caption = self.instances[idx]

        # shape: (height, width, channels), dtype: uint8
        # cv2.imread loads images in BGR (blue, green, red) order, so each one
        # is converted to RGB right after it is read.
        images = [
            cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
            for image_path in image_paths
        ]

        return {"image_ids": image_ids, "images": images, "caption": caption}

In [None]:
# Build the Books dataset from the "all" annotation file.
arch_books_dataset = ArchBooksCaptionsDataset(data_root='../datasets/ARCH', split="all")
# Number of distinct captions, then a short preview of the
# (uuids, file paths, caption) tuples instead of the whole list.
print(len(arch_books_dataset))
arch_books_dataset.instances[:5]

In [None]:
# Inspect the first Books instance: caption, image count, then each image.
test_instance = arch_books_dataset[0]

print(test_instance['caption'], '\n')
print("Total images:", len(test_instance['images']), '\n')

# Ids and images are built from the same instance tuple, so walk them in lockstep.
for img_id, img in zip(test_instance['image_ids'], test_instance['images']):
    print(img_id)
    plt.imshow(img)
    plt.show()

In [None]:
# Inspect Books instance 5: caption, image count, then each image.
test_instance = arch_books_dataset[5]

print(test_instance['caption'], '\n')
print("Total images:", len(test_instance['images']), '\n')

# Ids and images are built from the same instance tuple, so walk them in lockstep.
for img_id, img in zip(test_instance['image_ids'], test_instance['images']):
    print(img_id)
    plt.imshow(img)
    plt.show()

In [None]:
# Inspect Books instance 18: caption, image count, then each image.
test_instance = arch_books_dataset[18]

print(test_instance['caption'], '\n')
print("Total images:", len(test_instance['images']), '\n')

# Ids and images are built from the same instance tuple, so walk them in lockstep.
for img_id, img in zip(test_instance['image_ids'], test_instance['images']):
    print(img_id)
    plt.imshow(img)
    plt.show()

## Unified Dataset Class for ARCH Dataset

Once happy, I moved the dataset class to `virtex/data/datasets/arch_captions.py`

In [None]:
from virtex.data.datasets.arch_captions import ArchCaptionsDatasetRaw

In [None]:
# Print the packaged class's documentation to check its constructor arguments.
help(ArchCaptionsDatasetRaw)

### Books Subset

In [None]:
# Books-only view of the unified ARCH dataset class.
arch_books_dataset = ArchCaptionsDatasetRaw(
    data_root='../datasets/ARCH', split="all", source="books"
)
# Number of instances held in memory.
len(arch_books_dataset.instances)

In [None]:
# Inspect instance 0 of the books-only dataset.
test_instance = arch_books_dataset[0]

print(test_instance['caption'], '\n')
print("Total images:", len(test_instance['images']), '\n')

# Show every image of the instance next to its id.
image_ids = test_instance['image_ids']
for position in range(len(image_ids)):
    print(image_ids[position])
    plt.imshow(test_instance['images'][position])
    plt.show()

### PubMed Subset

In [None]:
# PubMed-only view of the unified ARCH dataset class.
arch_pubmed_dataset = ArchCaptionsDatasetRaw(
    data_root='../datasets/ARCH', split="all", source="pubmed"
)
# Number of instances held in memory.
len(arch_pubmed_dataset.instances)

In [None]:
# Inspect instance 19 of the pubmed-only dataset.
test_instance = arch_pubmed_dataset[19]

print(test_instance['caption'], '\n')
print("Total images:", len(test_instance['images']), '\n')

# Show every image of the instance next to its id.
image_ids = test_instance['image_ids']
for position in range(len(image_ids)):
    print(image_ids[position])
    plt.imshow(test_instance['images'][position])
    plt.show()

### Both Sets Together

In [None]:
# Combined view: both sources loaded through the unified class.
arch_dataset = ArchCaptionsDatasetRaw(
    data_root='../datasets/ARCH', split="all", source="both"
)
# Total number of instances in the combined dataset.
len(arch_dataset)

In [None]:
# Same as the 0th example in the books dataset.
test_instance = arch_dataset[0]

print(test_instance['caption'], '\n')
print("Total images:", len(test_instance['images']), '\n')

# Show every image of the instance next to its id.
image_ids = test_instance['image_ids']
for position in range(len(image_ids)):
    print(image_ids[position])
    plt.imshow(test_instance['images'][position])
    plt.show()

In [None]:
# Same as the 19th example in the pubmed dataset:
# there are 3210 examples in the books set and the pubmed set is concatenated
# after it, so 3210 + 19 = 3229 gives the 19th example in the pubmed set.
test_instance = arch_dataset[3229]
# test_instance = arch_dataset.__getitem__(1298)

print(test_instance['caption'], '\n')
print("Total images:", len(test_instance['images']), '\n')

# Show every image of the instance next to its id.
image_ids = test_instance['image_ids']
for position in range(len(image_ids)):
    print(image_ids[position])
    plt.imshow(test_instance['images'][position])
    plt.show()