Skip to content

Commit 1d9924b

Browse files
authored
Merge pull request #162 from awu-labelbox/awu/mea-create-model
Add create model and create model run
2 parents 32da99f + e313b65 commit 1d9924b

File tree

5 files changed

+71
-6
lines changed

5 files changed

+71
-6
lines changed

labelbox/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,4 @@
1919
from labelbox.schema.role import Role, ProjectRole
2020
from labelbox.schema.invite import Invite, InviteLimit
2121
from labelbox.schema.model_run import ModelRun
22+
from labelbox.schema.model import Model

labelbox/client.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -533,3 +533,19 @@ def get_models(self, where=None):
533533
An iterable of Models (typically a PaginatedCollection).
534534
"""
535535
return self._get_all(Model, where, filter_deleted=False)
536+
537+
def create_model(self, name, ontology_id):
538+
""" Creates a Model object on the server.
539+
540+
>>> model = client.create_model(<model_name>, <ontology_id>)
541+
542+
Args:
543+
name (string): Name of the model
544+
ontology_id (string): ID of the related ontology
545+
Returns:
546+
A new Model object.
547+
Raises:
548+
InvalidAttributeError: If the Model type does not contain
549+
any of the attribute names given in kwargs.
550+
"""
551+
return self._create(Model, {"name": name, "ontology_id": ontology_id})

labelbox/schema/__init__.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,17 @@
33
import labelbox.schema.benchmark
44
import labelbox.schema.data_row
55
import labelbox.schema.dataset
6+
import labelbox.schema.invite
67
import labelbox.schema.label
78
import labelbox.schema.labeling_frontend
9+
import labelbox.schema.model
10+
import labelbox.schema.model_run
11+
import labelbox.schema.ontology
812
import labelbox.schema.organization
13+
import labelbox.schema.prediction
914
import labelbox.schema.project
1015
import labelbox.schema.review
16+
import labelbox.schema.role
1117
import labelbox.schema.task
1218
import labelbox.schema.user
1319
import labelbox.schema.webhook
14-
import labelbox.schema.prediction
15-
import labelbox.schema.ontology
16-
import labelbox.schema.invite
17-
import labelbox.schema.role
18-
import labelbox.schema.model_run

labelbox/schema/model.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1+
from labelbox.orm import query
12
from labelbox.orm.db_object import DbObject
2-
from labelbox.orm.model import Field, Relationship
3+
from labelbox.orm.model import Entity, Field, Relationship
34

45

56
class Model(DbObject):
@@ -16,3 +17,26 @@ class Model(DbObject):
1617

1718
model_runs = Relationship.ToMany("ModelRun", False)
1819
ontology = Relationship.ToOne("Ontology", False)
20+
21+
def create_model_run(self, name):
22+
""" Creates a model run belonging to this model.
23+
24+
Args:
25+
name (string): The name for the model run.
26+
Returns:
27+
ModelRun, the created model run.
28+
"""
29+
name_param = "name"
30+
model_id_param = "modelId"
31+
ModelRun = Entity.ModelRun
32+
query_str = """mutation CreateModelRunPyApi($%s: String!, $%s: ID!) {
33+
createModelRun(data: {name: $%s, modelId: $%s}) {%s}}""" % (
34+
name_param, model_id_param,
35+
name_param, model_id_param,
36+
query.results_query_part(ModelRun)
37+
)
38+
res = self.client.execute(query_str, {
39+
name_param: name,
40+
model_id_param: self.uid
41+
})
42+
return ModelRun(self.client, res["createModelRun"])

tests/integration/test_model.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from labelbox import Model
2+
3+
def test_model(client, configured_project, rand_gen):
4+
before = list(client.get_models())
5+
for m in before:
6+
assert isinstance(m, Model)
7+
8+
ontology = configured_project.ontology
9+
10+
data = {"name": rand_gen(str), "ontology_id": ontology.uid}
11+
model = client.create_model(data["name"], data["ontology_id"])
12+
assert model.name == data["name"]
13+
assert model.ontology.id == data["ontology_id"]
14+
15+
after = list(client.get_models())
16+
assert len(after) == len(before) + 1
17+
assert model in after
18+
19+
model = client.get_model(model.uid)
20+
assert model.name == data["name"]
21+
assert model.ontology.id == data["ontology_id"]
22+
23+

0 commit comments

Comments
 (0)