In [66]:
import io
import os
import itertools
import weakref
from typing import Any, Dict, List, Set
import logging
from collections import OrderedDict
import PIL

import torch
from fvcore.nn.precise_bn import get_bn_modules


from detectron2.data import detection_utils as utils
from detectron2.data import transforms as T
import detectron2.utils.comm as comm
from detectron2.utils.logger import setup_logger
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.data import build_detection_train_loader
from detectron2.evaluation import COCOEvaluator, LVISEvaluator, verify_results
from detectron2.solver.build import maybe_add_gradient_clipping
from detectron2.modeling import build_model

# from detectron2.data import transforms as T


from diffusioninst import (
    DiffusionInstDatasetMapper,
    add_diffusioninst_config,
    DiffusionInstWithTTA,
)
from diffusioninst.util.model_ema import (
    add_model_ema_configs,
    may_build_model_ema,
    may_get_ema_checkpointer,
    EMAHook,
    apply_model_ema_and_restore,
    EMADetectionCheckpointer,
)
from torch.utils.data import Dataset
from datadings.reader import MsgpackReader
from torchvision.transforms import Compose, Resize

In [72]:
from pathlib import Path
import numpy as np
from detectron2.structures import BoxMode


class TorchMsgpackDataset(Dataset):
    def __init__(
        self,
        root_path="/home/raushan/dataset/",
        split="train",
        transforms=None,
        image_size=256,
    ):
        self.root_path = Path(root_path)
        self.crop_gen = None
        self.tfm_gens = []
        self.data_reader = MsgpackReader(self.root_path / f"publaynet-{split}.msgpack")
        if transforms is not None:
            self.transforms = transforms
        else:
            self.transforms = Compose([Resize(image_size)])

    def __getitem__(self, index):
        sample = self.data_reader[index]
        # sample["objects"]["bbox_mode"] = BoxMode.XYXY_ABS
        sample["image"] = PIL.Image.open(io.BytesIO(sample["image"]["bytes"]))
        sample["image"] = self.transforms(sample["image"])
        ## convert pil image to numpy array
        sample["image"] = np.array(sample["image"])
        ## change shape
        image = sample["image"]
        image_shape = sample["image"].shape
        print(len(sample["objects"]))
        print("iamge shape", image_shape)
        annos = sample["objects"]
        for obj in annos:
            obj["bbox_mode"] = BoxMode.XYXY_ABS

        # if np.random.rand() > 0.5:
        #     image, transforms = T.apply_transform_gens(self.tfm_gens, image)
        # else:
        #     image, transforms = T.apply_transform_gens(
        #         self.tfm_gens[:-1] + self.crop_gen + self.tfm_gens[-1:], image
        #     )
        ## change this properly.
        # annos = [
        #     utils.transform_instance_annotations(obj, self.transforms, image_shape)
        #     for obj in sample.pop("objects")
        #     if obj.get("iscrowd", 0) == 0
        # ]
        instances = utils.annotations_to_instances(
            annos, (image_shape[0], image_shape[1]), mask_format="bitmask"
        )
        sample["instances"] = utils.filter_empty_instances(instances)
        return sample

    def __len__(self):
        return len(self.data_reader)


dataset = TorchMsgpackDataset(
    root_path="/home/raushan/dataset/",
    split="train",
    transforms=None,
    image_size=256,
)
# print(dataset[0])

11
iamge shape (256, 256, 3)
{'image_id': 428223, 'image_width': 601, 'image_height': 792, 'image_file_path': '/ds-sds//documents/publaynet/publaynet/test/PMC3382231_00001.jpg', 'image': array([[[255, 255, 255],
        [255, 255, 255],
        [255, 255, 255],
        ...,
        [255, 255, 255],
        [255, 255, 255],
        [255, 255, 255]],

       [[255, 255, 255],
        [255, 255, 255],
        [255, 255, 255],
        ...,
        [255, 255, 255],
        [255, 255, 255],
        [255, 255, 255]],

       [[255, 255, 255],
        [255, 255, 255],
        [255, 255, 255],
        ...,
        [255, 255, 255],
        [255, 255, 255],
        [255, 255, 255]],

       ...,

       [[255, 255, 255],
        [255, 255, 255],
        [255, 255, 255],
        ...,
        [255, 255, 255],
        [255, 255, 255],
        [255, 255, 255]],

       [[255, 255, 255],
        [255, 255, 255],
        [255, 255, 255],
        ...,
        [255, 255, 255],
        [255, 255, 255],
  

In [27]:
from datetime import datetime
import random
import numpy as np


def seed_all_rng(seed=None):
    """
    Set the random seed for the RNG in torch, numpy and python.

    Args:
        seed (int): if None, will use a strong random seed.
    """
    if seed is None:
        seed = (
            os.getpid()
            + int(datetime.now().strftime("%S%f"))
            + int.from_bytes(os.urandom(2), "big")
        )
        logger = logging.getLogger(__name__)
        logger.info("Using a generated random seed {}".format(seed))
    np.random.seed(seed)
    torch.manual_seed(seed)
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)


def trivial_batch_collator(batch):
    """
    A batch collator that does nothing.
    """
    return batch


def worker_init_reset_seed(worker_id):
    initial_seed = torch.initial_seed() % 2**31
    seed_all_rng(initial_seed + worker_id)

In [28]:
collate_fn = None

from pathlib import Path


def build_train_loader(cls, cfg):
    from torch.utils.data import DataLoader

    # read msgpack
    dataset = TorchMsgpackDataset(
        root_path="/home/raushan/dataset/",
        split="train",
        transforms=None,
        image_size=256,
    )

    return DataLoader(
        dataset,
        batch_size=8,
        drop_last=True,
        num_workers=2,
        collate_fn=trivial_batch_collator if collate_fn is None else collate_fn,
        worker_init_fn=worker_init_reset_seed,
    )

In [30]:
def main():
    dataset = TorchMsgpackDataset(
        root_path="/home/raushan/dataset/",
        split="test",
        transforms=None,
        image_size=256,
    )
    dataset[0]


if __name__ == "__main__":
    main()

KeyError: 'images'