Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix #60 - Support for images with multiple objects in Dataset class #62

Merged
merged 3 commits into from
Oct 23, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 36 additions & 19 deletions detecto/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __init__(self, label_data, image_folder=None, transform=None):
the XML label files or a CSV file containing the label data.
If a CSV file, the file should have the following columns in
order: ``filename``, ``width``, ``height``, ``class``, ``xmin``,
``ymin``, ``xmax``, and ``ymax``. See
``ymin``, ``xmax``, ``ymax``, and ``image_id``. See
:func:`detecto.utils.xml_to_csv` to generate CSV files in this
format from XML label files.
:type label_data: str
Expand Down Expand Up @@ -136,7 +136,8 @@ def __init__(self, label_data, image_folder=None, transform=None):

# Returns the length of this dataset
def __len__(self):
return len(self._csv)
# number of entries == number of unique image_ids in csv.
return len(self._csv['image_id'].unique().tolist())

# Is what allows you to index the dataset, e.g. dataset[0]
# dataset[index] returns a tuple containing the image and the targets dict
Expand All @@ -145,22 +146,29 @@ def __getitem__(self, idx):
idx = idx.tolist()

# Read in the image from the file name in the 0th column
img_name = os.path.join(self._root_dir, self._csv.iloc[idx, 0])
object_entries = self._csv.loc[self._csv['image_id'] == idx]

img_name = os.path.join(self._root_dir, object_entries.iloc[0, 0])
image = read_image(img_name)

# Read in xmin, ymin, xmax, and ymax
box = self._csv.iloc[idx, 4:]
box = torch.tensor(box).view(1, 4)
boxes = []
labels = []
for object_idx, row in object_entries.iterrows():
# Read in xmin, ymin, xmax, and ymax
box = self._csv.iloc[object_idx, 4:8]
boxes.append(box)
# Read in the label
label = self._csv.iloc[object_idx, 3]
labels.append(label)

# Read in the label
label = self._csv.iloc[idx, 3]
boxes = torch.tensor(boxes).view(-1, 4)

targets = {'boxes': box, 'labels': label}
targets = {'boxes': boxes, 'labels': labels}

# Perform transformations
if self.transform:
width = self._csv.loc[idx, 'width']
height = self._csv.loc[idx, 'height']
width = object_entries.iloc[0, 1]
height = object_entries.iloc[0, 2]

# Apply the transforms manually to be able to deal with
# transforms like Resize or RandomHorizontalFlip
Expand Down Expand Up @@ -189,15 +197,20 @@ def __getitem__(self, idx):
if isinstance(t, transforms.RandomHorizontalFlip):
if random.random() < random_flip:
image = transforms.RandomHorizontalFlip(1)(image)
# Flip box's x-coordinates
box[0, 0] = width - box[0, 0]
box[0, 2] = width - box[0, 2]
box[0, 0], box[0, 2] = box[0, (2, 0)]
for idx, box in enumerate(targets['boxes']):
# Flip box's x-coordinates
box[0] = width - box[0]
box[2] = width - box[2]
box[[0,2]] = box[[2,0]]
targets['boxes'][idx] = box
else:
image = t(image)

# Scale down box if necessary
targets['boxes'] = (box / scale_factor).long()
if scale_factor != 1.0:
for idx, box in enumerate(targets['boxes']):
box = (box / scale_factor).long()
targets['boxes'][idx] = box

return image, targets

Expand Down Expand Up @@ -329,6 +342,7 @@ def predict(self, images):

return results[0] if is_single_image else results


def predict_top(self, images):
"""Takes in an image or list of images and returns the top
scoring predictions for each detected label in each image.
Expand Down Expand Up @@ -568,9 +582,12 @@ def load(file, classes):
# Converts all string labels in a list of target dicts to
# their corresponding int mappings
def _convert_to_int_labels(self, targets):
for target in targets:
# Convert string labels to integer mapping
target['labels'] = torch.tensor(self._int_mapping[target['labels']]).view(1)
for idx, target in enumerate(targets):
# get all string labels for objects in a single image
labels_array = target['labels']
# convert string labels to their integer class indices
labels_int_array = [self._int_mapping[class_name] for class_name in labels_array]
target['labels'] = torch.tensor(labels_int_array)

# Sends all images and targets to the same device as the model
def _to_device(self, images, targets):
Expand Down
22 changes: 11 additions & 11 deletions detecto/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@
def test_dataset():
# Test the format of the values returned by indexing the dataset
dataset = get_dataset()
assert len(dataset) == 2
assert len(dataset) == 1 # there is only one image in the dataset (label.xml)
assert isinstance(dataset[0][0], torch.Tensor)
assert isinstance(dataset[0][1], dict)
assert dataset[0][0].shape == (3, 1080, 1720)
assert 'boxes' in dataset[0][1] and 'labels' in dataset[0][1]
assert dataset[0][1]['boxes'].shape == (1, 4)
assert dataset[0][1]['labels'] == 'start_tick'
assert dataset[0][1]['boxes'].shape == (2, 4)
assert dataset[0][1]['labels'] == ['start_tick', 'start_gate']

transform = transforms.Compose([
transforms.ToPILImage(),
Expand All @@ -29,20 +29,20 @@ def test_dataset():

# Test that the transforms are properly applied
dataset = get_dataset(transform=transform)
assert dataset[1][0].shape == (3, 108, 172)
assert torch.all(dataset[1][1]['boxes'][0] == torch.tensor([6, 41, 171, 107]))
assert dataset[0][0].shape == (3, 108, 172)
assert torch.all(dataset[0][1]['boxes'][1] == torch.tensor([6, 41, 171, 107]))

# Test works when given an XML folder
path = os.path.dirname(__file__)
input_folder = os.path.join(path, 'static')

dataset = Dataset(input_folder, input_folder)
assert len(dataset) == 2
assert len(dataset) == 1
assert dataset[0][0].shape == (3, 1080, 1720)
assert 'boxes' in dataset[0][1] and 'labels' in dataset[0][1]

dataset = Dataset(input_folder)
assert len(dataset) == 2
assert len(dataset) == 1
assert dataset[0][0].shape == (3, 1080, 1720)
assert 'boxes' in dataset[0][1] and 'labels' in dataset[0][1]

Expand Down Expand Up @@ -75,14 +75,14 @@ def test_dataloader():
iterations += 1

assert isinstance(data, tuple)
assert len(data) == 2
assert len(data) == 2 # data[0] = image tensor, data[1] = targets dictionary
assert isinstance(data[0], list)
assert len(data[0]) == 2
assert len(data[0]) == 1 # only one image in data[0] since label.xml contains one image only.

assert isinstance(data[0][0], torch.Tensor)
assert isinstance(data[0][1], torch.Tensor)
assert 'boxes' in dataset[0][1] and 'labels' in dataset[1][1]
assert 'boxes' in data[1][0] and 'labels' in data[1][0]

assert 'boxes' in dataset[0][1] and 'labels' in dataset[0][1]
assert iterations == 1


Expand Down
9 changes: 6 additions & 3 deletions detecto/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def xml_to_csv(xml_folder, output_file=None):
"""Converts a folder of XML label files into a pandas DataFrame and/or
CSV file, which can then be used to create a :class:`detecto.core.Dataset`
object. Each XML file should correspond to an image and contain the image
name, image size, and the names and bounding boxes of the objects in the
name, image size, and the names and bounding boxes of the objects in the
image, if any. An ``image_id`` is assigned automatically, one per XML file. Extraneous data in the XML files will simply be ignored.
See :download:`here <../_static/example.xml>` for an example XML file.
For an image labeling tool that produces XML files in this format,
Expand All @@ -249,6 +249,7 @@ def xml_to_csv(xml_folder, output_file=None):
"""

xml_list = []
image_id = 0
# Loop through every XML file
for xml_file in glob(xml_folder + '/*.xml'):
tree = ET.parse(xml_file)
Expand All @@ -266,11 +267,13 @@ def xml_to_csv(xml_folder, output_file=None):

# Add image file name, image size, label, and box coordinates to CSV file
row = (filename, width, height, label, int(float(box[0].text)),
int(float(box[1].text)), int(float(box[2].text)), int(float(box[3].text)))
int(float(box[1].text)), int(float(box[2].text)), int(float(box[3].text)), image_id)
xml_list.append(row)

image_id += 1

# Save as a CSV file
column_names = ['filename', 'width', 'height', 'class', 'xmin', 'ymin', 'xmax', 'ymax']
column_names = ['filename', 'width', 'height', 'class', 'xmin', 'ymin', 'xmax', 'ymax', 'image_id']
xml_df = pd.DataFrame(xml_list, columns=column_names)

if output_file is not None:
Expand Down