Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: LEAP-930: Autoselect model as prelabeler when its added #5697

Merged
merged 4 commits into from
Apr 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions label_studio/ml/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,15 @@ def perform_create(self, serializer):
ml_backend = serializer.save()
ml_backend.update_state()

project = ml_backend.project

# In case we are adding the model, let's set it as the default
# to obtain predictions. This approach is consistent with uploading
# offline predictions, which would be set automatically.
if project.show_collab_predictions and (project.model_version is None or project.model_version == ''):
project.model_version = ml_backend.title
project.save(update_fields=['model_version'])


@method_decorator(
name='patch',
Expand Down
3 changes: 2 additions & 1 deletion label_studio/projects/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def to_internal_value(self, data):
# FIXME: remake this logic with start_training_on_annotation_update
initial_data = data
data = super().to_internal_value(data)

if 'start_training_on_annotation_update' in initial_data:
data['min_annotations_to_start_training'] = int(initial_data['start_training_on_annotation_update'])

Expand Down Expand Up @@ -186,7 +187,7 @@ def validate_model_version(self, value):
return value

def update(self, instance, validated_data):
if not validated_data.get('show_collab_predictions'):
if validated_data.get('show_collab_predictions') is False:
instance.model_version = ''

return super().update(instance, validated_data)
Expand Down
76 changes: 68 additions & 8 deletions label_studio/tests/ml/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@

from label_studio.tests.utils import make_project, register_ml_backend_mock

ORIG_MODEL_NAME = 'basic_ml_backend'
PROJECT_CONFIG = """<View><Image name="image" value="$image_url"/><Choices name="label"
toName="image"><Choice value="pos"/><Choice value="neg"/></Choices></View>"""


@pytest.fixture
def ml_backend_for_test_api(ml_backend):
Expand All @@ -22,13 +26,72 @@ def mock_gethostbyname(mocker):
mocker.patch('socket.gethostbyname', return_value='321.21.21.21')


@pytest.mark.django_db
def test_ml_backend_set_for_prelabeling(business_client, ml_backend_for_test_api, mock_gethostbyname):
project = make_project(
config=dict(
is_published=True,
label_config=PROJECT_CONFIG,
title='test_ml_backend_creation',
),
user=business_client.user,
)

assert project.model_version == ''

# create ML backend
response = business_client.post(
'/api/ml/',
data={
'project': project.id,
'title': 'ml_backend_title',
'url': 'https://ml_backend_for_test_api',
},
)
assert response.status_code == 201

project.refresh_from_db()
assert project.model_version == 'ml_backend_title'


@pytest.mark.django_db
def test_ml_backend_not_set_for_prelabeling(business_client, ml_backend_for_test_api, mock_gethostbyname):
"""We are not setting it when its already set for another name,
for example when predictions were uploaded before"""

project = make_project(
config=dict(
is_published=True,
label_config=PROJECT_CONFIG,
title='test_ml_backend_creation',
),
user=business_client.user,
)

project.model_version = ORIG_MODEL_NAME
project.save()

# create ML backend
response = business_client.post(
'/api/ml/',
data={
'project': project.id,
'title': 'ml_backend_title',
'url': 'https://ml_backend_for_test_api',
},
)
assert response.status_code == 201

project.refresh_from_db()
assert project.model_version == ORIG_MODEL_NAME


@pytest.mark.django_db
def test_model_version_on_save(business_client, ml_backend_for_test_api, mock_gethostbyname):
project = make_project(
config=dict(
is_published=True,
label_config="""<View><Image name="image" value="$image_url"/><Choices name="label"
toName="image"><Choice value="pos"/><Choice value="neg"/></Choices></View>""",
label_config=PROJECT_CONFIG,
title='test_ml_backend_creation',
),
user=business_client.user,
Expand Down Expand Up @@ -86,8 +149,7 @@ def test_model_version_on_delete(business_client, ml_backend_for_test_api, mock_
project = make_project(
config=dict(
is_published=True,
label_config="""<View><Image name="image" value="$image_url"/><Choices name="label"
toName="image"><Choice value="pos"/><Choice value="neg"/></Choices></View>""",
label_config=PROJECT_CONFIG,
title='test_ml_backend_creation',
),
user=business_client.user,
Expand Down Expand Up @@ -135,8 +197,7 @@ def test_security_write_only_payload(business_client, ml_backend_for_test_api, m
project = make_project(
config=dict(
is_published=True,
label_config="""<View><Image name="image" value="$image_url"/><Choices name="label"
toName="image"><Choice value="pos"/><Choice value="neg"/></Choices></View>""",
label_config=PROJECT_CONFIG,
title='test_ml_backend_creation',
),
user=business_client.user,
Expand Down Expand Up @@ -226,8 +287,7 @@ def test_ml_backend_predict_test_api_post_random_true(business_client):
project = make_project(
config=dict(
is_published=True,
label_config="""<View><Image name="image" value="$image_url"/><Choices name="label"
toName="image"><Choice value="pos"/><Choice value="neg"/></Choices></View>""",
label_config=PROJECT_CONFIG,
title='test_ml_backend_creation',
),
user=business_client.user,
Expand Down
28 changes: 5 additions & 23 deletions label_studio/tests/predictions.model.tavern.yml
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ stages:
request:
data:
project: '{project_pk}'
title: My Testing ML backend
title: "ml_backend"
url: https://test.heartex.mlbackend.com:9090
method: POST
url: '{django_live_url}/api/ml'
Expand All @@ -147,25 +147,7 @@ stages:
response:
status_code: 200
json:
model_version: ''
- name: change_project_model_to_ml_backend
request:
method: PATCH
url: '{django_live_url}/api/projects/{project_pk}'
data:
model_version: "My Testing ML backend"
response:
status_code: 200
json:
model_version: "My Testing ML backend"
- name: check_project_model_after_project_change
request:
method: GET
url: '{django_live_url}/api/projects/{project_pk}'
response:
status_code: 200
json:
model_version: "My Testing ML backend"
model_version: "ml_backend"

---
test_name: model_change_before_ML_added
Expand Down Expand Up @@ -305,15 +287,15 @@ stages:
request:
data:
project: '{project_pk}'
title: My Testing ML backend
title: "ml_backend"
url: https://test.heartex.mlbackend.com:9090
method: POST
url: '{django_live_url}/api/ml'
- name: check_project_model_not_change_after_ml_added_to_empty
- name: check_project_model_change_after_ml_added_to_empty
request:
method: GET
url: '{django_live_url}/api/projects/{project_pk}'
response:
status_code: 200
json:
model_version: ""
model_version: "ml_backend"