Skip to content

Commit

Permalink
Merge pull request #50 from heartexlabs/fix/gh1115-category-list
Browse files Browse the repository at this point in the history
[fix] Fix github issue 1115 - different category ids for same label c…
  • Loading branch information
makseq committed Sep 7, 2021
2 parents 39d308d + 4866d9a commit 1ea3768
Showing 1 changed file with 34 additions and 3 deletions.
37 changes: 34 additions & 3 deletions label_studio_converter/converter.py
Expand Up @@ -418,7 +418,7 @@ def convert_to_coco(self, input_data, output_dir, output_image_dir=None, is_dir=
output_image_dir = os.path.join(output_dir, 'images')
os.makedirs(output_image_dir, exist_ok=True)
images, categories, annotations = [], [], []
category_name_to_id = {}
categories, category_name_to_id = self._get_labels()
data_key = self._data_keys[0]
item_iterator = self.iter_from_dir(input_data) if is_dir else self.iter_from_json_file(input_data)
for item_idx, item in enumerate(item_iterator):
Expand Down Expand Up @@ -470,14 +470,14 @@ def convert_to_coco(self, input_data, output_dir, output_image_dir=None, is_dir=
})
first = False

if category_name not in category_name_to_id:
'''if category_name not in category_name_to_id:
category_id = len(categories)
category_name_to_id[category_name] = category_id
categories.append({
'id': category_id,
'name': category_name,
'supercategory': category_name
})
})'''
category_id = category_name_to_id[category_name]

annotation_id = len(annotations)
Expand Down Expand Up @@ -730,3 +730,34 @@ def create_child_node(doc, tag, attr, parent_node):

with io.open(xml_filepath, mode='w', encoding='utf8') as fout:
doc.writexml(fout, addindent='' * 4, newl='\n', encoding='utf-8')

def _get_labels(self):
labels = set()
categories = list()
category_name_to_id = dict()

for name, info in self._schema.items():
labels |= set(info['labels'])
attrs = info['labels_attrs']
for label in attrs:
if attrs[label].get('category'):
categories.append({
'id': attrs[label].get('category'),
'name': label
})
category_name_to_id[label] = attrs[label].get('category')
labels_to_add = set(labels) - set(list(category_name_to_id.keys()))
labels_to_add = sorted(list(labels_to_add))
idx = 0
while idx in list(category_name_to_id.values()):
idx += 1
for label in labels_to_add:
categories.append({
'id': idx,
'name': label
})
category_name_to_id[label] = idx
idx += 1
while idx in list(category_name_to_id.values()):
idx += 1
return categories, category_name_to_id

0 comments on commit 1ea3768

Please sign in to comment.