In [31]:
import glob
import hashlib
import json
import math
import os
import random
import copy
import shutil
import time
from itertools import repeat
import seaborn as sn
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from pathlib import Path
from scipy import ndimage
import cv2
import logging
import numpy as np
import torch
import torch.nn.functional as F
from PIL import ExifTags, Image, ImageOps, ImageDraw, ImageFont
from torch.utils.data import DataLoader, Dataset, dataloader, distributed
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

import xml.etree.ElementTree as ET
from typing import Dict, Any
import collections
import transforms as T
from tool.utils import *
from engine import train_one_epoch, evaluate
from metric import *

"""
    data util
"""
class load_voc():
    """In-memory loader for Pascal VOC detection data.

    The constructor collects the image paths of the requested split together
    with the matching annotation (XML) paths; ``get_dataset`` then reads every
    image and annotation into Python lists.

    Args:
        root_dir: Directory that contains the per-split folders (each holding
            ``JPEGImages`` and ``Annotations`` sub-folders).
        set_name: One of ``'VOC2007'``, ``'VOC2007_test'``, ``'VOC2012'`` or
            ``'VOC2007+2012'`` (the latter concatenates VOC2007 and VOC2012).

    Raises:
        ValueError: If ``set_name`` is not one of the supported splits.

    Attributes set by the constructor:
        classes: The 20 VOC class names; a name's list index is its class id.
        img_files: Image paths, sorted by name within each split folder.
        label_files: Annotation paths, aligned index-by-index with ``img_files``.
    """

    # Splits that live in a single folder directly under root_dir.
    _SINGLE_SETS = ('VOC2007', 'VOC2007_test', 'VOC2012')
    # Pseudo-split that concatenates several folders, in this order.
    _COMBINED_SETS = {'VOC2007+2012': ('VOC2007', 'VOC2012')}

    def __init__(self, root_dir, set_name='VOC2007_test'):
        self.root_dir = root_dir
        self.set_name = set_name

        self.classes = [
            "aeroplane",
            "bicycle",
            "bird",
            "boat",
            "bottle",
            "bus",
            "car",
            "cat",
            "chair",
            "cow",
            "diningtable",
            "dog",
            "horse",
            "motorbike",
            "person",
            "pottedplant",
            "sheep",
            "sofa",
            "train",
            "tvmonitor",
        ]

        if self.set_name in self._SINGLE_SETS:
            set_dirs = (self.set_name,)
        elif self.set_name in self._COMBINED_SETS:
            set_dirs = self._COMBINED_SETS[self.set_name]
        else:
            # Fail here instead of with an AttributeError later in get_dataset().
            supported = list(self._SINGLE_SETS) + list(self._COMBINED_SETS)
            raise ValueError(
                'Unknown set_name {!r}; expected one of {}'.format(self.set_name, supported)
            )

        self.img_files, self.label_files = [], []
        for set_dir in set_dirs:
            img_files, label_files = self._list_files(os.path.join(self.root_dir, set_dir))
            self.img_files += img_files
            self.label_files += label_files

    @staticmethod
    def _list_files(set_path):
        """Return (image paths, annotation paths) for one split folder."""
        img_dir = os.path.join(set_path, 'JPEGImages')
        label_dir = os.path.join(set_path, 'Annotations')

        # glob returns files in arbitrary order; sort so runs are reproducible.
        # Escape the directory so glob metacharacters in the path are literal.
        img_files = sorted(glob.glob(os.path.join(glob.escape(img_dir), '*.jpg')))

        # Build the annotation path from the file stem rather than with
        # str.replace on the whole path, which would also rewrite any 'jpg' or
        # 'JPEGImages' substring that happens to occur in root_dir.
        label_files = [
            os.path.join(label_dir, os.path.splitext(os.path.basename(img_file))[0] + '.xml')
            for img_file in img_files
        ]
        return img_files, label_files

    def parse_voc_xml(self, node: ET.Element) -> Dict[str, Any]:
        """Recursively convert an XML element into a nested dictionary.

        Leaf elements become ``{tag: stripped_text}``. Elements with children
        become ``{tag: {child_tag: value}}``, where repeated child tags are
        collected into a list. For the ``annotation`` element the ``object``
        entry is always a list (possibly empty), even with a single object.
        """
        voc_dict: Dict[str, Any] = {}
        children = list(node)
        if children:
            def_dic: Dict[str, Any] = collections.defaultdict(list)
            for dc in map(self.parse_voc_xml, children):
                for ind, v in dc.items():
                    def_dic[ind].append(v)
            if node.tag == "annotation":
                # Wrap once more so the single-element unwrapping below leaves
                # "object" as a list regardless of how many objects exist.
                def_dic["object"] = [def_dic["object"]]
            voc_dict = {node.tag: {ind: v[0] if len(v) == 1 else v for ind, v in def_dic.items()}}
        if node.text and not children:
            voc_dict[node.tag] = node.text.strip()
        return voc_dict

    def get_annotations(self, annot_file):
        """Read one VOC annotation file into an ``(N, 5)`` float array.

        Each row is ``[class_index, xmin, ymin, xmax, ymax]`` where
        ``class_index`` is the position of the object's name in
        ``self.classes``. A file without objects yields shape ``(0, 5)``.

        Raises:
            ValueError: If an object's name is not in ``self.classes``.
        """
        target = self.parse_voc_xml(ET.parse(annot_file).getroot())
        objects = target['annotation']['object']

        # Allocate once instead of growing the array with np.append per object.
        annotations = np.zeros((len(objects), 5))
        for row, obj in enumerate(objects):
            box = obj['bndbox']
            annotations[row, 0] = self.classes.index(obj['name'])
            # XML values are strings; convert explicitly.
            annotations[row, 1:] = [float(box[key]) for key in ('xmin', 'ymin', 'xmax', 'ymax')]

        return annotations

    def get_dataset(self):
        """Load every image and its annotations into memory.

        Returns:
            ``(img_all, label_all)``: a list of RGB images as numpy arrays and
            a list of annotation arrays (see ``get_annotations``), both in the
            order of ``self.img_files``.
        """
        print('We found {} files... Read images/labels...'.format(len(self.img_files)))
        since = time.time()
        img_all, label_all = [], []
        for idx, (img_file, label_file) in enumerate(zip(self.img_files, self.label_files)):
            if idx % 1000 == 0 and idx != 0:
                time_elapsed = time.time() - since
                print('\t Current idx {}... complete in {:.0f}m {:.0f}s'.format(idx, time_elapsed // 60, time_elapsed % 60))
            # Context manager closes the underlying file handle deterministically.
            with Image.open(img_file) as img:
                img_all.append(np.array(img.convert('RGB')))
            label_all.append(self.get_annotations(label_file))

        return img_all, label_all

In [32]:
# Dataset location and split. The root defaults to the original NAS path but can be
# overridden with the VOC_ROOT environment variable so the notebook runs on other machines.
VOC_ROOT = os.environ.get('VOC_ROOT', '/nasdata2/khj/objectdetection/segmentation/VOCdevkit/')
VOC_SET = 'VOC2007_test'  # or VOC2007, VOC2012, VOC2007+2012

voc_loader = load_voc(root_dir=VOC_ROOT, set_name=VOC_SET)

In [33]:
# Read every image and its annotation array into memory (two parallel lists).
# Progress is printed every 1000 files.
img_all, label_all = voc_loader.get_dataset()

We found 4952 files... Read images/labels...
	 Current idx 1000... complete in 0m 6s
	 Current idx 2000... complete in 0m 12s
	 Current idx 3000... complete in 0m 18s
	 Current idx 4000... complete in 0m 25s


In [34]:
# Sanity check of the loaded data: number of images, array shape of the first
# image, and the annotation rows ([class_index, xmin, ymin, xmax, ymax]) for it.
first_img, first_label = img_all[0], label_all[0]
print(len(img_all))
print(first_img.shape)

print(first_label)

4952
(500, 353, 3)
[[ 11.  48. 240. 195. 371.]
 [ 14.   8.  12. 352. 498.]]
