In [1]:
import argparse
import os
import pickle

import yaml
from tqdm import tqdm

from data_pipeline import DatasetsGenerator

In [2]:
# Parse the YAML settings file into `config`; safe_load only builds plain
# Python objects, it never instantiates arbitrary classes.
with open('settings/model_testing.yaml') as settings_file:
    config = yaml.safe_load(settings_file)

In [7]:
from pycocotools.coco import COCO
from torch.utils.data import DataLoader, Dataset
from torchvision.io import ImageReadMode, read_image
from data_pipeline.transform import TransformTesting, TransformTraining

class DatasetFromCocoAnnotations(Dataset):
    '''
    Map-style dataset serving images together with their COCO annotations.

    Args:
        coco: loaded COCO annotation index.
        images_dir: directory that the ``file_name`` of every image record is
            joined onto.
        transform: optional TransformTraining / TransformTesting applied to
            every sample. With any other value (e.g. None) ``__getitem__``
            returns the untransformed sample dictionary.
    '''

    def __init__(self, coco: COCO, images_dir: str,
                 transform: TransformTesting | TransformTraining | None = None) -> None:
        super().__init__()
        self.coco = coco
        self.images_dir = images_dir
        self.transform = transform
        # Positional index -> COCO image record, in the iteration order of coco.imgs.
        self.idx_to_img = {i: coco.loadImgs(ids=[img_id])[0]
                           for i, img_id in enumerate(self.coco.imgs)}

    def __len__(self):
        # Count the same mapping that __getitem__ indexes into.
        return len(self.idx_to_img)

    def __getitem__(self, idx):
        '''
        Returns the sample at position ``idx``.

        If no transform is passed, the sample is a dictionary with:
            - image: the image read from disk in RGB mode
            - landmarks: list with one dictionary per annotation of the image,
              holding the fields of the COCO annotation (without 'iscrowd') plus:
                - center_point: (x, y) of the centre of the bounding box
                - size: (width, height) of the bounding box
            - img_name: path the image was read from

        With a TransformTraining / TransformTesting instance the return value
        is whatever the transform returns for that dictionary.

        Raises whatever ``read_image`` raises when the file cannot be decoded.
        '''
        img_info = self.idx_to_img[idx]
        img_name = os.path.join(self.images_dir, img_info['file_name'])

        image = read_image(img_name, mode=ImageReadMode.RGB)

        # Build fresh landmark dictionaries instead of editing the ones stored in
        # the COCO index: the index is shared by every access to this image, so
        # anything written into its dictionaries (here or by a transform) would
        # still be there the next time the same sample is requested.
        landmarks = []
        for annotation in self.coco.imgToAnns[img_info['id']]:
            # 'iscrowd' is dropped: pointless to keep it
            landmark = {key: value for key, value in annotation.items()
                        if key != 'iscrowd'}
            # bbox is [top left x position, top left y position, width, height]
            bbox = list(annotation['bbox'])
            landmark['bbox'] = bbox
            landmark['center_point'] = (bbox[0] + bbox[2]/2,
                                        bbox[1] + bbox[3]/2)
            landmark['size'] = (bbox[2],
                                bbox[3])
            landmarks.append(landmark)

        sample = {'image': image,
                  "landmarks": landmarks,
                  "img_name": img_name}

        if isinstance(self.transform, (TransformTraining, TransformTesting)):
            # NOTE(review): the original comment says the transform yields
            # "sample, transformed_landmarks, original_sample" - confirm against
            # data_pipeline.transform.
            return self.transform(sample)
        return sample

In [8]:
# Load the COCO annotation index of each base-class split.
train_base = COCO(config['paths']['train_base_annotations_path'])
val_base = COCO(config['paths']['val_base_annotations_path'])
test_base = COCO(config['paths']['test_base_annotations_path'])
images_dir = config['paths']['images_dir']

# One dataset per split. Every split only declares base classes (the category
# ids of its own annotation file) and no novel classes.
dataset_base_train = DatasetFromCocoAnnotations(
    train_base,
    images_dir,
    TransformTraining(config, base_classes=list(train_base.cats), novel_classes=[]),
)
# The validation split goes through the training transform as well.
dataset_base_val = DatasetFromCocoAnnotations(
    val_base,
    images_dir,
    TransformTraining(config, base_classes=list(val_base.cats), novel_classes=[]),
)
dataset_base_test = DatasetFromCocoAnnotations(
    test_base,
    images_dir,
    TransformTesting(config, base_classes=list(test_base.cats), novel_classes=[]),
)

loading annotations into memory...
Done (t=0.77s)
creating index...
index created!
loading annotations into memory...
Done (t=0.07s)
creating index...
index created!
loading annotations into memory...
Done (t=0.27s)
creating index...
index created!


In [12]:
# Smoke test: load every training sample once and keep track of the ones that
# raise, so the offending files can be inspected without copying indices by hand.
failed_train = {}  # sample index -> (image path, error message)
for i in tqdm(range(len(dataset_base_train))):
    try:
        result = dataset_base_train[i]
    except Exception as e:
        # Deliberately broad: the goal is to list every broken sample, not to stop.
        failed_path = os.path.join(dataset_base_train.images_dir,
                                   dataset_base_train.idx_to_img[i]['file_name'])
        failed_train[i] = (failed_path, str(e))
        # tqdm.write prints without breaking the progress bar, unlike print().
        tqdm.write(f"Error at index {i}: {e}")

 10%|█         | 17435/170553 [04:54<58:35, 43.56it/s]  

Error at index 17427: Unsupported color conversion request


 89%|████████▉ | 152107/170553 [45:17<06:14, 49.26it/s] 

Error at index 152104: Unsupported color conversion request


100%|██████████| 170553/170553 [50:28<00:00, 56.32it/s] 


In [13]:
# Smoke test: load every validation sample once and keep track of the ones that
# raise, so the offending files can be inspected without copying indices by hand.
failed_val = {}  # sample index -> (image path, error message)
for i in tqdm(range(len(dataset_base_val))):
    try:
        result = dataset_base_val[i]
    except Exception as e:
        # Deliberately broad: the goal is to list every broken sample, not to stop.
        failed_path = os.path.join(dataset_base_val.images_dir,
                                   dataset_base_val.idx_to_img[i]['file_name'])
        failed_val[i] = (failed_path, str(e))
        # tqdm.write prints without breaking the progress bar, unlike print().
        tqdm.write(f"Error at index {i}: {e}")

100%|██████████| 21216/21216 [11:51<00:00, 29.82it/s]


In [14]:
# Smoke test: load every test sample once and keep track of the ones that
# raise, so the offending files can be inspected without copying indices by hand.
failed_test = {}  # sample index -> (image path, error message)
for i in tqdm(range(len(dataset_base_test))):
    try:
        result = dataset_base_test[i]
    except Exception as e:
        # Deliberately broad: the goal is to list every broken sample, not to stop.
        failed_path = os.path.join(dataset_base_test.images_dir,
                                   dataset_base_test.idx_to_img[i]['file_name'])
        failed_test[i] = (failed_path, str(e))
        # tqdm.write prints without breaking the progress bar, unlike print().
        tqdm.write(f"Error at index {i}: {e}")

100%|██████████| 21543/21543 [10:34<00:00, 33.94it/s]


In [None]:
# Path of the training image that raised "Unsupported color conversion request"
# at index 17427 during the full pass over the training dataset.
os.path.join(
    dataset_base_train.images_dir,
    dataset_base_train.idx_to_img[17427]['file_name'],
)

'../data/train_val_images/Aves/Pyrocephalus rubinus/225f3aadbddd26da2cf4cc87e74e8ed3.jpg'

In [None]:
# Path of the training image that raised "Unsupported color conversion request"
# at index 152104 during the full pass over the training dataset.
os.path.join(
    dataset_base_train.images_dir,
    dataset_base_train.idx_to_img[152104]['file_name'],
)

'../data/train_val_images/Aves/Sturnus vulgaris/17ce5d50647b3217a2e24ec523f81378.jpg'