In [14]:
import xml.etree.ElementTree as ET
import os
import cv2
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from tqdm import tqdm
from PIL import Image
from sklearn.model_selection import train_test_split

In [15]:
# Label name -> integer class id.
# Ids 0-19 are the object categories, 20 is 'background', and 'void' is 255.
# NOTE(review): the names look like the PASCAL VOC categories - confirm.
class_mapping = dict(
    person=0,
    bird=1,
    cat=2,
    cow=3,
    dog=4,
    horse=5,
    sheep=6,
    aeroplane=7,
    bicycle=8,
    boat=9,
    bus=10,
    car=11,
    motorbike=12,
    train=13,
    bottle=14,
    chair=15,
    diningtable=16,
    pottedplant=17,
    sofa=18,
    tvmonitor=19,
    background=20,
    void=255,
)

In [16]:
def read_scribble_xml(xml_file):
    """Load one scribble annotation file.

    Args:
        xml_file: Path (or file object) accepted by ``ET.parse``.

    Returns:
        A 5-tuple ``(filename, width, height, depth, scribbles)`` where
        ``filename`` is the text of the ``filename`` element, ``width``,
        ``height`` and ``depth`` are the integers under ``size``, and
        ``scribbles`` is a list with one ``{'tag': str, 'points': [(x, y), ...]}``
        dict per ``polygon`` element, in document order.

    Raises:
        AttributeError: If a required element is missing.
        ValueError: If a numeric field cannot be parsed as an integer.
    """
    root = ET.parse(xml_file).getroot()

    def _int_at(node, path):
        # Integer value stored as the text of the child element at ``path``.
        return int(node.find(path).text)

    # Image-level metadata.
    filename = root.find('filename').text
    width = _int_at(root, 'size/width')
    height = _int_at(root, 'size/height')
    depth = _int_at(root, 'size/depth')

    # One entry per polygon: its class tag followed by its ordered points.
    scribbles = []
    for polygon in root.findall('polygon'):
        tag = polygon.find('tag').text
        points = []
        for point in polygon.findall('point'):
            points.append((_int_at(point, 'X'), _int_at(point, 'Y')))
        scribbles.append({'tag': tag, 'points': points})

    return filename, width, height, depth, scribbles

In [23]:
# Define dataset
class ScribbleDataset(Dataset):
    """Dataset that pairs each scribble XML annotation with its JPEG image.

    Args:
        image_dir: Directory containing the ``.jpg`` images.
        xml_dir: Directory containing the scribble ``.xml`` files.
        xml_files: XML file names (relative to ``xml_dir``) that make up
            this split; the image for each is looked up in ``image_dir``
            under the same base name with a ``.jpg`` extension.
        transform: Optional callable applied to the sample dict
            ``{'image': ..., 'scribbles': ...}``; its return value is what
            ``__getitem__`` returns.
    """

    def __init__(self, image_dir, xml_dir, xml_files, transform=None):
        self.image_dir = image_dir
        self.xml_files = xml_files
        self.transform = transform
        self.xml_dir = xml_dir

    def __len__(self):
        """Number of annotation files in this split."""
        return len(self.xml_files)

    def __getitem__(self, idx):
        """Return the (optionally transformed) sample at position ``idx``.

        Without a transform the sample is ``{'image': <RGB PIL image>,
        'scribbles': <list of {'tag', 'points'} dicts>}``.
        """
        xml_file = self.xml_files[idx]
        xml_path = os.path.join(self.xml_dir, xml_file)
        # Swap only the extension; str.replace would also rewrite any
        # ".xml" that happens to occur inside the base name.
        image_name = os.path.splitext(xml_file)[0] + ".jpg"
        image_path = os.path.join(self.image_dir, image_name)

        # read_scribble_xml returns five values (including depth); unpacking
        # only four raised ValueError on every access.
        _filename, _width, _height, _depth, scribbles = read_scribble_xml(xml_path)

        # Force three channels so a 3-channel normalisation downstream works
        # for grayscale/RGBA files too; convert() loads the pixel data, so the
        # underlying file handle can be closed right away.
        with Image.open(image_path) as img:
            image = img.convert("RGB")

        sample = {'image': image, 'scribbles': scribbles}

        if self.transform:
            sample = self.transform(sample)

        return sample


In [26]:
class ToTensor(object):
    """Randomly flip a sample horizontally and convert its image to a tensor.

    The flip is applied to the image and the scribble points together, so
    the annotation coordinates keep pointing at the same pixels. (Flipping
    only the image, as ``transforms.RandomHorizontalFlip`` does, leaves the
    scribbles mirrored relative to the picture.)

    Args:
        flip_prob: Probability of applying the horizontal flip. Defaults to
            0.5; pass 0.0 to disable the augmentation (e.g. for evaluation).
    """

    def __init__(self, flip_prob=0.5):
        self.flip_prob = flip_prob
        # Built once instead of on every call.
        self.to_normalized_tensor = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    def __call__(self, sample):
        """Transform ``{'image': PIL image, 'scribbles': [...]}``.

        Returns a new dict with the image as a normalised tensor and the
        scribbles (mirrored if the image was flipped). The input sample and
        its scribble dicts are not modified.
        """
        image, scribbles = sample['image'], sample['scribbles']

        if torch.rand(1).item() < self.flip_prob:
            width = image.width
            image = transforms.functional.hflip(image)
            # A horizontal flip moves pixel column x to column width - 1 - x.
            scribbles = [
                {**scribble,
                 'points': [(width - 1 - x, y) for x, y in scribble['points']]}
                for scribble in scribbles
            ]

        # Convert image to tensor
        image = self.to_normalized_tensor(image)

        return {'image': image, 'scribbles': scribbles}

In [27]:
# Set up dataset and dataloader
xml_dir = "scribble"
image_dir = "train_JPEGImages"
# os.listdir returns entries in arbitrary order and may include entries that
# are not annotations (e.g. ".ipynb_checkpoints"). Keep only the XML files and
# sort them so the seeded split below is the same on every run and machine.
xml_list = sorted(name for name in os.listdir(xml_dir) if name.endswith(".xml"))
transform = transforms.Compose([ToTensor()])
# Hold out 10% of the annotations for validation; random_state fixes the split.
train_data, val_data = train_test_split(xml_list, test_size=0.1, random_state=1)
train_dataset = ScribbleDataset(image_dir=image_dir, xml_dir=xml_dir, xml_files=train_data, transform=transform)
val_dataset = ScribbleDataset(image_dir=image_dir, xml_dir=xml_dir, xml_files=val_data, transform=transform)
# NOTE(review): the default collate function stacks the image tensors and
# requires equal-length sequences, so images of different sizes or samples
# with different numbers of scribbles will likely need a custom collate_fn
# (or a resize step) before batches of 16 can be formed - confirm.
train_image_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_image_dataloader = DataLoader(val_dataset, batch_size=16, shuffle=False)
