Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions labelbox/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@
from labelbox.schema.asset_metadata import AssetMetadata
from labelbox.schema.webhook import Webhook
from labelbox.schema.prediction import Prediction, PredictionModel
from labelbox.schema.ontology import Ontology
5 changes: 5 additions & 0 deletions labelbox/orm/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class Type(Enum):
Boolean = auto()
ID = auto()
DateTime = auto()
Json = auto()

class EnumType:

Expand Down Expand Up @@ -85,6 +86,10 @@ def DateTime(*args):
def Enum(enum_cls: type, *args):
return Field(Field.EnumType(enum_cls), *args)

@staticmethod
def Json(*args):
return Field(Field.Type.Json, *args)

def __init__(self,
field_type: Union[Type, EnumType],
name,
Expand Down
1 change: 1 addition & 0 deletions labelbox/schema/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@
import labelbox.schema.user
import labelbox.schema.webhook
import labelbox.schema.prediction
import labelbox.schema.ontology
120 changes: 120 additions & 0 deletions labelbox/schema/ontology.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
"""Client side object for interacting with the ontology."""
import abc
from dataclasses import dataclass

from typing import Any, Callable, Dict, List, Optional, Union

from labelbox.orm import query
from labelbox.orm.db_object import DbObject, Updateable, BulkDeletable
from labelbox.orm.model import Entity, Field, Relationship
from labelbox.utils import snake_case, camel_case


@dataclass
class OntologyEntity:
required: bool
name: str


@dataclass
class Option:
label: str
value: str
feature_schema_id: Optional[str] = None
schema_node_id: Optional[str] = None

@classmethod
def from_json(cls, json_dict):
_dict = convert_keys(json_dict, snake_case)
return cls(**_dict)


@dataclass
class Classification(OntologyEntity):
type: str
instructions: str
options: List[Option]
feature_schema_id: Optional[str] = None
schema_node_id: Optional[str] = None

@classmethod
def from_json(cls, json_dict):
_dict = convert_keys(json_dict, snake_case)
_dict['options'] = [
Option.from_json(option) for option in _dict['options']
]
return cls(**_dict)


@dataclass
class Tool(OntologyEntity):
tool: str
color: str
classifications: List[Classification]
feature_schema_id: Optional[str] = None
schema_node_id: Optional[str] = None

@classmethod
def from_json(cls, json_dict):
_dict = convert_keys(json_dict, snake_case)
_dict['classifications'] = [
Classification.from_json(classification)
for classification in _dict['classifications']
]
return cls(**_dict)


class Ontology(DbObject):
""" A ontology specifies which tools and classifications are available
to a project.

NOTE: This is read only for now.

>>> project = client.get_project(name="<project_name>")
>>> ontology = project.ontology()
>>> ontology.normalized

"""

name = Field.String("name")
description = Field.String("description")
updated_at = Field.DateTime("updated_at")
created_at = Field.DateTime("created_at")
normalized = Field.Json("normalized")
object_schema_count = Field.Int("object_schema_count")
classification_schema_count = Field.Int("classification_schema_count")

projects = Relationship.ToMany("Project", True)
created_by = Relationship.ToOne("User", False, "created_by")

def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self._tools: Optional[List[Tool]] = None
self._classifications: Optional[List[Classification]] = None

def tools(self) -> List[Tool]:
if self._tools is None:
self._tools = [
Tool.from_json(tool) for tool in self.normalized['tools']
]
return self._tools # type: ignore

def classifications(self) -> List[Classification]:
if self._classifications is None:
self._classifications = [
Classification.from_json(classification)
for classification in self.normalized['classifications']
]
return self._classifications # type: ignore


def convert_keys(json_dict: Dict[str, Any],
converter: Callable) -> Dict[str, Any]:
if isinstance(json_dict, dict):
return {
converter(key): convert_keys(value, converter)
for key, value in json_dict.items()
}
if isinstance(json_dict, list):
return [convert_keys(ele, converter) for ele in json_dict]
return json_dict
1 change: 1 addition & 0 deletions labelbox/schema/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class Project(DbObject, Updateable, Deletable):
active_prediction_model = Relationship.ToOne("PredictionModel", False,
"active_prediction_model")
predictions = Relationship.ToMany("Prediction", False)
ontology = Relationship.ToOne("Ontology", True)

def create_label(self, **kwargs):
""" Creates a label on this Project.
Expand Down
73 changes: 73 additions & 0 deletions tests/integration/test_ontology.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import unittest
from typing import Any, Dict, List, Union


def sample_ontology() -> Dict[str, Any]:
return {
"tools": [{
"required": False,
"name": "Dog",
"color": "#FF0000",
"tool": "rectangle",
"classifications": []
}],
"classifications": [{
"required":
True,
"instructions":
"This is a question.",
"name":
"this_is_a_question.",
"type":
"radio",
"options": [{
"label": "Yes",
"value": "yes"
}, {
"label": "No",
"value": "no"
}]
}]
}


def test_create_ontology(client, project) -> None:
""" Tests that the ontology that a project was set up with can be grabbed."""
frontend = list(client.get_labeling_frontends())[0]
project.setup(frontend, sample_ontology())
normalized_ontology = project.ontology().normalized

def _remove_schema_ids(
ontology_part: Union[List, Dict[str, Any]]) -> Dict[str, Any]:
""" Recursively scrub the normalized ontology of any schema information."""
removals = {'featureSchemaId', 'schemaNodeId'}

if isinstance(ontology_part, list):
return [_remove_schema_ids(part) for part in ontology_part]
if isinstance(ontology_part, dict):
return {
key: _remove_schema_ids(value)
for key, value in ontology_part.items()
if key not in removals
}
return ontology_part

removed = _remove_schema_ids(normalized_ontology)
assert removed == sample_ontology()

ontology = project.ontology()

tools = ontology.tools()
assert tools
for tool in tools:
assert tool.feature_schema_id
assert tool.schema_node_id

classifications = ontology.classifications()
assert classifications
for classification in classifications:
assert classification.feature_schema_id
assert classification.schema_node_id
for option in classification.options:
assert option.feature_schema_id
assert option.schema_node_id