In [1]:
import torch
from torch import nn, optim
from torch.nn import functional as F
import os

In [9]:
from torchvision import datasets, transforms

In [30]:
import torch.utils.data as data
import xml.etree.ElementTree as ET
from PIL import Image
import collections

In [38]:
VOC_CLASSES = (  # always index 0
    'aeroplane', 'bicycle', 'bird', 'boat',
    'bottle', 'bus', 'car', 'cat', 'chair',
    'cow', 'diningtable', 'dog', 'horse',
    'motorbike', 'person', 'pottedplant',
    'sheep', 'sofa', 'train', 'tvmonitor')

In [83]:
class VOCDetection(datasets.VisionDataset):
    """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Detection Dataset.
    Args:
        root (string): Root directory of the VOC Dataset.
        year (string, optional): The dataset year, supports years 2007 to 2012.
        image_set (string, optional): Select the image_set to use, ``train``, ``trainval`` or ``val``
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
            (default: alphabetic indexing of VOC's 20 classes).
        transform (callable, optional): A function/transform that  takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, required): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.
    """

    def __init__(
            self,
            root: str,
            year: str = "2012",
            image_set: str = "train",
            transform= None,
            target_transform= None,
            transforms = None,
    ):
        super(VOCDetection, self).__init__(root, transforms, transform, target_transform)
        self.year = year
        voc_root = os.path.join(self.root, "VOC2012")
        image_dir = os.path.join(voc_root, 'JPEGImages')
        annotation_dir = os.path.join(voc_root, 'Annotations')
        
        if image_set == "train":
            self.train = True
        else:
            self.train = False

        if not os.path.isdir(voc_root):
            raise RuntimeError('Dataset not found or corrupted.')

        splits_dir = os.path.join(voc_root, 'ImageSets/Main')

        split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt')

        with open(os.path.join(split_f), "r") as f:
            file_names = [x.strip() for x in f.readlines()]

        self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
        self.annotations = [os.path.join(annotation_dir, x + ".xml") for x in file_names]
        assert (len(self.images) == len(self.annotations))
    
    def transform_annotation_to_bbox(self, target):
        class_to_ind = dict(zip(VOC_CLASSES, range(len(VOC_CLASSES))))
        
        res = []
        classes = []
        for obj in target["annotation"]["object"]:
            difficult = int(obj['difficult']) == 1
            name = obj["name"].lower().strip()
            bbox = obj['bndbox']
            pts = ['xmin', 'ymin', 'xmax', 'ymax']
            bndbox = []
            for i, pt in enumerate(pts):
                cur_pt = int(bbox[pt]) - 1
#                 cur_pt = cur_pt / width if i % 2 == 0 else cur_pt / height
                bndbox.append(cur_pt)
            label_idx = class_to_ind[name]
            classes.append(label_idx)
            res += [bndbox]  # [xmin, ymin, xmax, ymax, label_ind]

        return res, classes

    def __getitem__(self, index: int):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is a dictionary of the XML tree.
        """
        trans = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.485,0.456,0.406), (0.229,0.224,0.225))
        ])
        img = Image.open(self.images[index]).convert('RGB')
        target = self.parse_voc_xml(ET.parse(self.annotations[index]).getroot())

        img = trans(img)
            
        bndboxes, clses = self.transform_annotation_to_bbox(target)

        return img, bndboxes, clses

    def __len__(self) -> int:
        return len(self.images)

    def parse_voc_xml(self, node):
        voc_dict = {}
        children = list(node)
        if children:
            def_dic= collections.defaultdict(list)
            for dc in map(self.parse_voc_xml, children):
                for ind, v in dc.items():
                    def_dic[ind].append(v)
            if node.tag == 'annotation':
                def_dic['object'] = [def_dic['object']]
            voc_dict = {
                node.tag:
                    {ind: v[0] if len(v) == 1 else v
                     for ind, v in def_dic.items()}
            }
        if node.text:
            text = node.text.strip()
            if not children:
                voc_dict[node.tag] = text
        return voc_dict

In [84]:
train_dataset = VOCDetection("VOCdevkit")
# train_dataloader = data.DataLoader(train_dataset, batch_size=8, shuffle=True, num_w)
val_dataset = VOCDetection("VOCdevkit",image_set="val")

In [86]:
train_dataset[150]

(tensor([[[-1.1247, -1.3644, -1.3130,  ..., -1.2959, -1.2103, -1.3302],
          [-1.4672, -1.5185, -1.1589,  ..., -1.3644, -1.2445, -1.1247],
          [-1.4158, -1.2103, -0.7993,  ..., -1.2274, -1.3815, -1.5528],
          ...,
          [-1.1932, -1.6555, -1.6555,  ..., -0.3883,  0.0741, -0.3369],
          [-1.5699, -1.6384, -1.6042,  ..., -0.1314,  0.7248,  0.4337],
          [-1.5699, -1.6213, -1.7069,  ...,  0.6221, -0.0801, -0.3027]],
 
         [[-0.4951, -0.7752, -0.8102,  ..., -0.7227, -0.5826, -0.7052],
          [-0.8803, -0.9678, -0.7052,  ..., -0.8102, -0.6877, -0.5651],
          [-0.8102, -0.6352, -0.2850,  ..., -0.6702, -0.8277, -1.0028],
          ...,
          [-0.8452, -1.2304, -1.2479,  ..., -0.4601,  0.0476, -0.3375],
          [-1.2479, -1.1779, -1.1954,  ..., -0.2500,  0.6954,  0.4328],
          [-1.1954, -1.1253, -1.2479,  ...,  0.6429, -0.0224, -0.1975]],
 
         [[-0.9330, -1.2119, -1.0898,  ..., -1.0376, -0.9504, -1.0724],
          [-1.3513, -1.3513,