In [1]:
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2

from taco_dataset import CoCoDatasetForYOLO
from config import Config

cfg = Config()

# Side length the image is resized/padded to before the random crop.
# Presumably cfg.scale > 1 so the crop below has room to move around (random
# translation) — confirm against Config.
_scaled_size = int(cfg.IMAGE_SIZE * cfg.scale)

train_transforms = A.Compose(
    [
        # Geometry: fit the longest side to _scaled_size, pad to a square of that
        # size with a constant border, then take a random IMAGE_SIZE x IMAGE_SIZE crop.
        A.LongestMaxSize(max_size=_scaled_size),
        A.PadIfNeeded(
            min_height=_scaled_size,
            min_width=_scaled_size,
            border_mode=cv2.BORDER_CONSTANT,
        ),
        A.RandomCrop(width=cfg.IMAGE_SIZE, height=cfg.IMAGE_SIZE),
        # NOTE(review): hue=0.6 is above the [0, 0.5] range torchvision's ColorJitter
        # accepts; confirm the installed albumentations version tolerates it and that
        # such a strong hue shift is intended.
        A.ColorJitter(brightness=0.6, contrast=0.6, saturation=0.6, hue=0.6, p=0.4),
        # OneOf with p=1.0 always applies exactly one of the two transforms; inside
        # OneOf the per-transform p values act as relative selection weights
        # (equal here, so a 50/50 pick).
        A.OneOf(
            [
                A.ShiftScaleRotate(
                    rotate_limit=20, p=0.5, border_mode=cv2.BORDER_CONSTANT
                ),
                A.Affine(shear=15, p=0.5, mode=cv2.BORDER_CONSTANT),
            ],
            p=1.0,
        ),
        A.HorizontalFlip(p=0.5),
        # Low-probability photometric noise.
        A.Blur(p=0.1),
        A.CLAHE(p=0.1),
        A.Posterize(p=0.1),
        A.ToGray(p=0.1),
        A.ChannelShuffle(p=0.05),
        # mean=0, std=1 with max_pixel_value=255 just rescales pixels to [0, 1].
        A.Normalize(mean=[0, 0, 0], std=[1, 1, 1], max_pixel_value=255,),
        ToTensorV2(),
    ],
    # YOLO-format boxes here carry the class id as a trailing 5th element (see the
    # error tuple below), so no separate label_fields are passed.
    #
    # Known issue: floating-point rounding in the bbox conversion can produce a
    # coordinate slightly outside [0, 1], e.g.
    #   Expected y_min for bbox (0.5355392156862745, -0.00015318627450980338, 0.6523692810457516, 0.1803002450980392, 0) to be in the range [0.0, 1.0], got -0.00015318627450980338.
    # A workaround that patches the library itself is described at
    # https://github.com/albumentations-team/albumentations/issues/459#issuecomment-734454278
    bbox_params=A.BboxParams(format="yolo", min_visibility=0.4, label_fields=[],),
)

# Fraction of samples assigned to the training split; the remainder is held out.
train_percentage = 0.8

# Full dataset with the training augmentations attached. S, B and C are presumably
# the YOLO grid size, boxes per cell and number of classes — confirm in taco_dataset.
train_dataset = CoCoDatasetForYOLO(
    root=cfg.DATASET_PATH,
    annFile=cfg.anns_file_path,
    transform=train_transforms,
    S=cfg.SPLIT_SIZE,
    B=cfg.NUM_BOXES,
    C=cfg.NUM_CLASSES,
)




loading annotations into memory...
Done (t=0.06s)
creating index...
index created!


In [3]:
# Baseline: a purely random train/test split of the dataset indices.
# TODO(review): randperm is unseeded, so this split changes on every run; pass a
# seeded torch.Generator if the split needs to be reproducible.
indices = torch.randperm(len(train_dataset))
test_size = round(len(train_dataset) * (1 - train_percentage))

# Slice at an explicit boundary rather than with `indices[:-test_size]`:
# when test_size == 0, `-0 == 0`, so the negative-index form would give the train
# split nothing and the test split everything.
split_at = len(indices) - test_size
random_train_indices = indices[:split_at]
random_test_indices = indices[split_at:]

print(random_train_indices)
print(len(random_train_indices))
print()
print(random_test_indices)
print(len(random_test_indices))

tensor([ 885,  757,  734,  ..., 1001, 1481,  657])
1200

tensor([1110,  843,  459,  859, 1441,  610,  367, 1158,  914,  753, 1274, 1100,
        1136, 1004,  190,  678,  195,  428,  477,  575,  191,  426,  158,  518,
          50,  670, 1146, 1102,  910, 1485,  235,  541, 1337,  975,  967,  996,
        1360,   22, 1458, 1488,  144,   36,  916,  635,  258, 1256,   69,  789,
         504,   76, 1223,    5,  475, 1478,  441, 1265, 1436,  708,  418,  581,
         314, 1065,  177, 1350,  765,  693,  147, 1096,  481, 1277, 1407,  570,
         291,  207, 1371,  983,  345,  107, 1036, 1127,  860,  677,  183, 1116,
         528, 1206, 1200,  126,  936,   90, 1069,  883,  278,  168,  811,  222,
        1428, 1203,  440,  768,  869, 1252,  170,  468,   10,  707, 1020,  982,
        1406, 1182,  795, 1193, 1258, 1013, 1320,  301, 1139,  752,  306,  928,
         239,  663, 1168,  648, 1291,  866,  961, 1064,  870, 1317,   92,  938,
        1474, 1394,  265, 1166,  445,  288,  321,  508,  727,  

In [5]:
from utils import get_stratified_indices

# Stratified alternative to the random split above; get_stratified_indices presumably
# balances class proportions across the two subsets — see utils to confirm.
train_indices, test_indices = get_stratified_indices(
    cfg.anns_file_path, len(train_dataset), train_percentage
)

for _position, _subset in enumerate((train_indices, test_indices)):
    if _position:
        print()  # blank line between the train and test listings
    print(_subset)
    print(len(_subset))

tensor([1120,  839,  819,  ...,  546,  443,  480])
1199

tensor([1119,  817,  818, 1142,  870, 1096, 1103,  608, 1153,   84,  795,  408,
          69, 1253,  792,  793, 1199, 1307,  308,  709, 1073,   24, 1039,  712,
        1088, 1365, 1427,  156,  166,  862, 1191, 1247, 1135,  713,  410,  223,
        1190, 1430,  469,  224,   65,  135,  312,  428,  600,  483,  132,   50,
         614,  197,  331,  383, 1036,  262,  669, 1186, 1038,  432,  263, 1368,
        1050, 1465, 1242, 1330,  666,  365,  535,   36,  220,  566,  875,  228,
         730,  122,  523,  161, 1463,  849,  461,  957,  389,  830, 1321,  708,
         610,  260,  863, 1206,  169,  214,  729, 1363,  683,  121,  366,  314,
        1320,  107, 1294,  892, 1127,  941,  943,  991,  964,  939,  981,  778,
         788,  336, 1010,  886,  973,  907,  947, 1332,  915,  972,  133, 1003,
         636,  894,  978,  921, 1335,  893,  970,  984, 1196,  932,  899,  237,
          98,   99,  271, 1383,  477,  745,  268,  602,  173, 1