In [1]:
import torch
import cv2
import numpy as np
import os
import glob as glob
from xml.etree import ElementTree as et
import import_ipynb
from config import classes, resize, train_dir, valid_dir, batch_size
from torch.utils.data import Dataset, DataLoader
from custom_utils import collate_fn, get_train_transform, get_valid_transform

importing Jupyter notebook from config.ipynb
importing Jupyter notebook from custom_utils.ipynb


In [2]:
class CustomDataset(Dataset):
    """Object-detection dataset of ``.jpg`` images with Pascal-VOC-style XML annotations.

    Every ``<stem>.jpg`` in ``dir_path`` is paired with ``<stem>.xml`` in the same
    directory. Each ``<object>`` element in that XML supplies a class ``<name>`` and a
    ``<bndbox>`` with ``xmin``/``ymin``/``xmax``/``ymax``.

    Without transforms, images are returned as RGB ``float32`` arrays resized to
    ``(height, width)`` and scaled to ``[0, 1]``. Boxes are rescaled to match the
    resized image.

    Args:
        dir_path: Directory containing the images and their XML annotations.
        width: Target width in pixels after resizing.
        height: Target height in pixels after resizing.
        classes: Sequence of class names; a box's label is the index of its name here.
        transforms: Optional callable invoked as
            ``transforms(image=..., bboxes=..., labels=...)`` that returns a dict with
            ``'image'`` and ``'bboxes'`` and, optionally, ``'labels'``.
    """

    def __init__(self, dir_path, width, height, classes, transforms=None):
        self.transforms = transforms
        self.dir_path = dir_path
        self.height = height
        self.width = width
        self.classes = classes

        # Collect image file names (not full paths), sorted so indices are stable.
        # os.path.basename copes with either separator, unlike split(os.path.sep).
        self.image_paths = glob.glob(f"{self.dir_path}/*.jpg")
        self.all_images = sorted(os.path.basename(path) for path in self.image_paths)

    def _parse_annotation(self, annot_file_path, image_width, image_height):
        """Read boxes and labels from one XML file, scaled to the resized image size.

        Returns:
            ``(boxes, labels)``: a float32 tensor of shape ``(N, 4)`` holding
            ``[xmin, ymin, xmax, ymax]`` rows, and an int64 tensor of shape ``(N,)``.
            ``N`` may be 0.

        Raises:
            ValueError: if an object's class name is not in ``self.classes``.
        """
        root = et.parse(annot_file_path).getroot()
        boxes, labels = [], []
        for member in root.findall('object'):
            labels.append(self.classes.index(member.find('name').text))
            bndbox = member.find('bndbox')
            # float() also accepts fractional coordinates; integer strings give the same value as int().
            xmin, ymin, xmax, ymax = (
                float(bndbox.find(tag).text) for tag in ('xmin', 'ymin', 'xmax', 'ymax')
            )
            # The image is resized, so each coordinate is rescaled by the same ratio.
            boxes.append([
                (xmin / image_width) * self.width,
                (ymin / image_height) * self.height,
                (xmax / image_width) * self.width,
                (ymax / image_height) * self.height,
            ])
        # reshape keeps shape (0, 4) for images with no objects, so column indexing still works.
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        labels = torch.as_tensor(labels, dtype=torch.int64)
        return boxes, labels

    def __getitem__(self, idx):
        """Return ``(image, target)`` for the ``idx``-th image (sorted by file name).

        ``target`` contains the following entries; ``N`` may be 0 for images
        without objects:

        * ``boxes`` — float32, shape ``(N, 4)``, as ``[xmin, ymin, xmax, ymax]``.
        * ``labels`` — int64, shape ``(N,)``.
        * ``area`` — float32, shape ``(N,)``, computed from the returned boxes.
        * ``iscrowd`` — int64 zeros, shape ``(N,)``.
        * ``image_id`` — ``tensor([idx])``.

        Raises:
            OSError: if the image file cannot be read.
            ValueError: if an annotation references a class not in ``classes``.
        """
        image_name = self.all_images[idx]
        image_path = os.path.join(self.dir_path, image_name)
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread reports missing/unreadable files by returning None instead of raising.
            raise OSError(f"Could not read image: {image_path}")
        # OpenCV loads BGR; convert to RGB float so the in-place normalisation below is valid.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
        image_resized = cv2.resize(image, (self.width, self.height))
        image_resized /= 255.0

        # The annotation shares the image's stem: "foo.jpg" -> "foo.xml".
        annot_filename = os.path.splitext(image_name)[0] + '.xml'
        annot_file_path = os.path.join(self.dir_path, annot_filename)
        # Box scaling uses the size of the original, un-resized image (rows, cols).
        image_height, image_width = image.shape[:2]
        boxes, labels = self._parse_annotation(annot_file_path, image_width, image_height)

        if self.transforms:
            sample = self.transforms(image=image_resized,
                                     bboxes=boxes,
                                     labels=labels)
            image_resized = sample['image']
            boxes = torch.as_tensor(sample['bboxes'], dtype=torch.float32).reshape(-1, 4)
            # A transform may drop boxes (e.g. ones cropped out); keep labels aligned when
            # it reports them. NOTE(review): if a transform drops boxes without returning
            # 'labels', the lengths will still disagree.
            if 'labels' in sample:
                labels = torch.as_tensor(sample['labels'], dtype=torch.int64)

        # Per-box fields are derived from the final boxes so they all have length N.
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        # No annotation is marked as a "crowd" region.
        iscrowd = torch.zeros((boxes.shape[0],), dtype=torch.int64)
        target = {
            "boxes": boxes,
            "labels": labels,
            "area": area,
            "iscrowd": iscrowd,
            "image_id": torch.tensor([idx]),
        }
        return image_resized, target

    def __len__(self):
        return len(self.all_images)

In [3]:
def create_train_dataset():
    """Build the training ``CustomDataset``.

    The dataset reads images from the config's ``train_dir``, resizes them to
    ``resize`` x ``resize`` pixels, and applies the transforms returned by
    ``get_train_transform()``.
    """
    return CustomDataset(
        dir_path=train_dir,
        width=resize,
        height=resize,
        classes=classes,
        transforms=get_train_transform(),
    )

In [4]:
def create_valid_dataset():
    """Build the validation ``CustomDataset``.

    The dataset reads images from the config's ``valid_dir``, resizes them to
    ``resize`` x ``resize`` pixels, and applies the transforms returned by
    ``get_valid_transform()``.
    """
    return CustomDataset(
        dir_path=valid_dir,
        width=resize,
        height=resize,
        classes=classes,
        transforms=get_valid_transform(),
    )

In [5]:
def create_train_loader(train_dataset, num_workers=0):
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        collate_fn=collate_fn
    )
    return train_loader

In [6]:
def create_valid_loader(valid_dataset, num_workers=0):
    """Wrap ``valid_dataset`` in a ``DataLoader`` that preserves dataset order.

    The loader uses the config's ``batch_size`` and the project's ``collate_fn``.
    Default collation cannot stack targets whose box counts differ between images.

    Args:
        valid_dataset: Dataset to iterate over.
        num_workers: Number of worker processes; 0 loads in the main process.
    """
    loader_options = dict(
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=collate_fn,
    )
    return DataLoader(valid_dataset, **loader_options)

In [7]:
def visualize_sample(image, target, wait_ms=0):
    """Show ``image`` in an OpenCV window with the boxes and class names from ``target``.

    Args:
        image: H x W x 3 RGB array, as returned by ``CustomDataset`` without transforms.
            It is not modified; drawing happens on a copy.
        target: dict with ``'boxes'`` (rows of ``[xmin, ymin, xmax, ymax]``) and
            ``'labels'`` (indices into the global ``classes``).
        wait_ms: Delay passed to ``cv2.waitKey``. The default of 0 waits for a key
            press. The previous hard-coded 1 ms closed the window almost immediately.
    """
    # The dataset yields RGB but cv2.imshow expects BGR. Reversing the channels into
    # a contiguous copy also keeps the drawing off the caller's array.
    canvas = np.ascontiguousarray(image[..., ::-1])
    for box, label_idx in zip(target['boxes'], target['labels']):
        label = classes[int(label_idx)]
        top_left = (int(box[0]), int(box[1]))
        bottom_right = (int(box[2]), int(box[3]))
        cv2.rectangle(canvas, top_left, bottom_right, (0, 255, 0), 2)
        # Class name is drawn just above the box's top-left corner.
        cv2.putText(
            canvas, label, (int(box[0]), int(box[1] - 5)),
            cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2
        )
    cv2.imshow('Image', canvas)
    cv2.waitKey(wait_ms)
    cv2.destroyAllWindows()

In [8]:
if __name__ == '__main__':
    # Sanity check of the Dataset pipeline: load a few raw (untransformed) samples
    # and draw their boxes. The guard skips this when the notebook is imported as a
    # module rather than run directly.
    dataset = CustomDataset(train_dir, resize, resize, classes)
    print(f"Number of training images: {len(dataset)}")

    NUM_SAMPLES_TO_VISUALIZE = 5
    # Clamp to the dataset size so a directory with fewer images doesn't raise IndexError.
    for i in range(min(NUM_SAMPLES_TO_VISUALIZE, len(dataset))):
        image, target = dataset[i]
        visualize_sample(image, target)

Number of training images: 13299
