In [None]:
## import modules

import os

import torch
from torchvision.transforms import v2

import matplotlib.pyplot as plt
import importlib
import numpy as np

In [None]:
## import dataset and transform modules from the Modules package

import Modules.COCODataset as COCODataset
import Modules.ImageStackTransform as ImageStackTransform

In [None]:
## input configurations

# Annotation JSON file that is handed to the dataset as `annotation_file_path`.
src_annotation_file_path = r"./data/annotations_trainval2014/annotations/instances_train2014.json"

# Category names that are handed to the dataset as `category_names`.
category_names = [
    r"cat",
]

In [None]:
## test data set and transforms

# Reload both modules so that edits made to them on disk take effect
# without restarting the kernel.
importlib.reload(COCODataset)
importlib.reload(ImageStackTransform)

# Individual augmentation steps, listed in the order they are composed.
crop_step = ImageStackTransform.RandomCrop(
    size = (256,256),
    pad_if_needed = True,
    padding_mode = "reflect",
)
horizontal_flip_step = ImageStackTransform.RandomHorizontalFlip(p = 0.5)
vertical_flip_step = ImageStackTransform.RandomVerticalFlip(p = 0.5)

# Common transform pipeline passed to the dataset as `common_transform`.
# NOTE(review): presumably applied jointly to image and target — confirm in COCOSegDataset.
common_transforms = v2.Compose([
    crop_step,
    horizontal_flip_step,
    vertical_flip_step,
])

# Options for the COCO segmentation dataset, collected in one place.
dataset_options = dict(
    annotation_file_path = src_annotation_file_path,
    category_names = category_names,
    common_transform = common_transforms,
    color_categories = False,
    split_segmentations = False,
)

# Build the COCO segmentation dataset.
coco_dataset = COCODataset.COCOSegDataset(**dataset_options)

In [None]:
## check data and label
# Index of the dataset sample to inspect.
check_idx = 0

# Fetch one (data, label) pair and convert both tensors to NumPy arrays.
check_data, check_label = coco_dataset[check_idx]
check_data = check_data.numpy()
check_label = check_label.numpy()

# imshow expects the channel axis last. np.moveaxis is NumPy's recommended
# replacement for the legacy np.rollaxis; for a 3-D array
# moveaxis(a, 0, -1) yields the same view as rollaxis(a, 0, 3).
# assumes both arrays are 3-D with the channel axis first — TODO confirm against COCOSegDataset
check_data_hwc = np.moveaxis(check_data, 0, -1)
check_label_hwc = np.moveaxis(check_label, 0, -1)

# Explicit figure/axes handles instead of the implicit pyplot state machine.
fig, (ax_data, ax_label) = plt.subplots(1, 2)

ax_data.imshow(check_data_hwc)
ax_data.set_title("Data")

ax_label.imshow(check_label_hwc, cmap = "tab20c")
ax_label.set_title("Target")

# Pixel coordinates carry no information here, so hide the ticks on both panels.
for ax in (ax_data, ax_label):
    ax.set_xticks([])
    ax.set_yticks([])

fig.tight_layout()
plt.show()