In [1]:
import torch
import torch.utils.data as data
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import numpy as np

BATCH_SIZE = 1

In [2]:
VOC_CLASSES = (  # always index 0
    'aeroplane', 'bicycle', 'bird', 'boat',
    'bottle', 'bus', 'car', 'cat', 'chair',
    'cow', 'diningtable', 'dog', 'horse',
    'motorbike', 'person', 'pottedplant',
    'sheep', 'sofa', 'train', 'tvmonitor')


In [136]:
class VOCAnnotationAnalyzer():
    """
    deal with annotation data (dict)
    
    Arguments:
        cls_to_idx (dict, optional): dictionary lookup of classnames -> indexes
            (default: alphabetic indexing of VOC's 20 classes)
        keep_difficult (bool, optional): keep difficult instances or not
            (default: False)
        height (int): height
        width (int): width
    """
    def __init__(self, cls_to_idx=None, keep_difficult=False):
        self.cls_to_idx = cls_to_idx or dict(zip(VOC_CLASSES, range(len(VOC_CLASSES))))
        self.keep_difficult = keep_difficult
        
    def __call__(self, annotation: dict):
        w = int(annotation['size']['width'])
        h = int(annotation['size']['height'])
        # if img only contains one gt that annotation['object'] is just a dict, not a list
        objects = [annotation['object']] if type(annotation['object']) != list else annotation['object']
        res = [] # [xmin, ymin, xmax, ymax, label]
        for box in objects:
            pts = ['xmin', 'ymin', 'xmax', 'ymax']
            difficult = int(box['difficult'])
            if not self.keep_difficult and difficult:
                continue
            name = box['name']
            bnd = []
            for pt in pts:
                bnd.append(int(box['bndbox'][pt]))
            bnd.append(self.cls_to_idx[name])
            res.append(bnd)
            
        return res
            
        

In [137]:
class VOCDectection(data.Dataset):
    def __init__(self, root, year, image_set,
                 transform=None, 
                 target_transform=VOCAnnotationAnalyzer(),
                 dataset_name='VOC07_12'):
        super(VOCDectection, self).__init__()
        self.datas = datasets.VOCDetection(root, str(year), image_set, download=False)
        self.image_set = image_set
        self.transform = transform
        self.name = dataset_name
        self.target_transform = target_transform # use for annotation
        self.longer_sides = [480, 576, 688, 864, 1200]
        
    def __getitem__(self, index):
        pic, anno = self.datas[index]
        img = pic
        gt = anno
        
        if self.target_transform:
            gt = self.target_transform(gt["annotation"])
        
        h, w = img.size
        if "train" in self.image_set:
            if self.transform is None:
                # follow by paper: randomly horiztontal flip and randomly resize
                if np.random.random() > 0.5: # then flip

                    fliper = transforms.RandomHorizontalFlip(1)
                    img = fliper(img)
                    for box in gt:
                        box[0] = w - box[2]
                        box[2] = w - box[0]
                # then resize
                max_side = self.longer_sides[np.random.randint(5)]
                if (w > h):
                    resizer = transforms.Resize((int(max_side*h/w), max_side))
                    ratio = max_side/w
                else: # h >= w
                    resizer = transforms.Resize((max_side, int(max_side*w/h)))
                    ratio = max_side/h
                img = resizer(img)
                for box in gt:
                    box[0] = int(ratio * box[0])
                    box[1] = int(ratio * box[1])
                    box[2] = int(ratio * box[2])
                    box[3] = int(ratio * box[3])
            else:
                raise NotImplementedError("This dataset can only be compatible with the paper's implementation")
            totensor = transforms.ToTensor()
            img = totensor(img)
            gt = np.array(gt)
        
        if "test" in self.image_set:
            pass
        return img, gt

    def __len__(self):
        return len(self.datas)

In [138]:
voc = VOCDectection("~/data/", 2007, 'trainval')

In [139]:
train_loader = data.DataLoader(voc, 1, shuffle=False)
for img, gt in train_loader:
    print(img.size())
    print(gt)

torch.Size([1, 3, 576, 432])
tensor([[[302, 243, 373, 390,   8],
         [190, 304, 291, 428,   8],
         [277, 223, 339, 344,   8]]])
torch.Size([1, 3, 576, 383])
tensor([[[-192,   57,  576,  380,    6]]])
torch.Size([1, 3, 688, 516])
tensor([[[ 94, 236, 371, 454,  12],
         [206, 194, 315, 390,  14],
         [392, 276, 449, 455,  14],
         [355, 272, 408, 452,  14]]])
torch.Size([1, 3, 480, 319])
tensor([[[149,  93, 336, 259,   6]]])
torch.Size([1, 3, 320, 480])
tensor([[[ 88,  69, 292, 454,   1]]])
torch.Size([1, 3, 576, 436])
tensor([[[102,  74, 334, 238,  14],
         [-46,  93, 483, 403,  12]]])
torch.Size([1, 3, 480, 360])
tensor([[[221,  84, 463, 245,   7],
         [ 10, 108, 255, 248,   7]]])
torch.Size([1, 3, 432, 576])
tensor([[[ 38, 170, 427, 479,   6]]])
torch.Size([1, 3, 322, 480])
tensor([[[  0, 225, 174, 372,  11],
         [201,  34, 322, 462,  14],
         [ 44,  78, 163, 350,  14],
         [ 10, 173, 136, 402,  14]]])
torch.Size([1, 3, 577, 864])
ten

torch.Size([1, 3, 480, 319])
tensor([[[ 56,  72, 352, 255,   3]]])
torch.Size([1, 3, 688, 516])
tensor([[[ 56, 100, 524, 484,  17],
         [594, 141, 682, 198,   6]]])
torch.Size([1, 3, 576, 432])
tensor([[[124,  95, 307, 238,   2]]])
torch.Size([1, 3, 1200, 799])
tensor([[[ 112,  115,  686,  297,   14],
         [-396,   38, 1195,  799,    6]]])
torch.Size([1, 3, 1200, 936])
tensor([[[ 249,   81, 1070,  936,    6],
         [ 163,  468,  290,  691,    6]]])
torch.Size([1, 3, 576, 432])
tensor([[[352, 261, 437, 344,  19],
         [225, 164, 355, 425,  14]]])
torch.Size([1, 3, 688, 594])
tensor([[[ 71,  30, 423, 451,  14],
         [ 35, 148, 627, 544,  13]]])
torch.Size([1, 3, 1200, 900])
tensor([[[  14,  369,  885,  835,   13],
         [-170,  117, 1070,  888,   14]]])
torch.Size([1, 3, 1200, 926])
tensor([[[ 213,  129, 1200,  926,   14]]])
torch.Size([1, 3, 1200, 799])
tensor([[[  55,   67, 1152,  681,    6],
         [ 244,    7,  386,  120,   14],
         [1053,   55, 1200,  4

tensor([[[ 67, 136, 463, 317,   6]]])
torch.Size([1, 3, 576, 432])
tensor([[[380,  12, 523, 423,   2],
         [ 63, 107, 192, 427,   2]]])
torch.Size([1, 3, 864, 601])
tensor([[[-262,  216,  864,  312,    6],
         [-191,  207,  793,  383,    3]]])
torch.Size([1, 3, 576, 446])
tensor([[[207,  36, 390, 413,  14]]])
torch.Size([1, 3, 480, 360])
tensor([[[-98, 216, 458, 326,   6],
         [ 56, 220, 303, 268,   6],
         [276, 231,  83, 360,   6],
         [183, 234, 176, 352,  14],
         [155, 262, 204, 352,  13]]])
torch.Size([1, 3, 576, 434])
tensor([[[ 17,  44, 576, 322,   6]]])
torch.Size([1, 3, 432, 576])
tensor([[[ 94, 110, 278, 274,  19]]])
torch.Size([1, 3, 799, 1200])
tensor([[[ 489,  556,  710, 1200,    1],
         [ 424,  189,  775, 1137,   14]]])
torch.Size([1, 3, 688, 242])
tensor([[[  -8,   49,  250,  118,    9],
         [-173,   63,  415,  119,    9],
         [-280,   57,  522,  119,    9]]])
torch.Size([1, 3, 864, 648])
tensor([[[  3,   5, 864, 648,  15],
 

torch.Size([1, 3, 864, 513])
tensor([[[  1,  25, 794, 513,  11]]])
torch.Size([1, 3, 576, 384])
tensor([[[-191,   99,  576,  382,   10]]])
torch.Size([1, 3, 576, 432])
tensor([[[298,  77, 542, 428,   8]]])
torch.Size([1, 3, 576, 383])
tensor([[[ 10, 115, 372, 266,   6]]])
torch.Size([1, 3, 688, 516])
tensor([[[-123,    4,  639,  516,    9]]])
torch.Size([1, 3, 864, 302])
tensor([[[338, 133, 520, 285,  13],
         [381,  81, 489, 278,  14]]])
torch.Size([1, 3, 360, 480])
tensor([[[272, 241, 207, 301,   9],
         [248, 193, 231, 283,   9]]])
torch.Size([1, 3, 575, 864])
tensor([[[489, 247, 374, 864,  14],
         [293, 330, 570, 864,  14]]])
torch.Size([1, 3, 576, 432])
tensor([[[ 18,  32, 523, 362,   6]]])
torch.Size([1, 3, 688, 516])
tensor([[[-55, 122, 571, 460,  12],
         [277, 181, 238, 316,  14]]])
torch.Size([1, 3, 480, 360])
tensor([[[-70, 150, 430, 344,   6],
         [264,  58,  95, 149,   3]]])
torch.Size([1, 3, 864, 648])
tensor([[[ 29, 190, 618, 644,  14]]])
torch.

KeyboardInterrupt: 

In [51]:
next(enumerate(train_loader))

0
0
1
0
1


(0, [tensor([[[[0.0431, 0.0471, 0.0510,  ..., 0.6392, 0.6353, 0.6353],
            [0.0431, 0.0471, 0.0510,  ..., 0.6392, 0.6353, 0.6353],
            [0.0353, 0.0392, 0.0471,  ..., 0.6353, 0.6353, 0.6353],
            ...,
            [0.1804, 0.1922, 0.2078,  ..., 0.2667, 0.2706, 0.2706],
            [0.1922, 0.2078, 0.2353,  ..., 0.2667, 0.2706, 0.2706],
            [0.1922, 0.2078, 0.2353,  ..., 0.2667, 0.2706, 0.2706]],
  
           [[0.0353, 0.0392, 0.0510,  ..., 0.7294, 0.7255, 0.7255],
            [0.0353, 0.0392, 0.0510,  ..., 0.7294, 0.7255, 0.7255],
            [0.0314, 0.0353, 0.0471,  ..., 0.7333, 0.7294, 0.7294],
            ...,
            [0.0549, 0.0588, 0.0667,  ..., 0.3255, 0.3294, 0.3294],
            [0.0667, 0.0706, 0.0824,  ..., 0.3255, 0.3294, 0.3294],
            [0.0667, 0.0706, 0.0824,  ..., 0.3255, 0.3294, 0.3294]],
  
           [[0.0392, 0.0431, 0.0510,  ..., 0.7529, 0.7490, 0.7490],
            [0.0392, 0.0431, 0.0510,  ..., 0.7529, 0.7490, 0.7490],
   