Skip to content

Commit

Permalink
fixing formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Malyuk committed Apr 9, 2024
1 parent a4256b1 commit c5b3bce
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 28 deletions.
7 changes: 3 additions & 4 deletions label_studio/ml/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,10 @@ def perform_create(self, serializer):

project = ml_backend.project

# In case we are adding the model, let's set it as the default
# to obtain predictions. This approach is consistent with uploading
# In case we are adding the model, let's set it as the default
# to obtain predictions. This approach is consistent with uploading
# offline predictions, which would be set automatically.
if project.show_collab_predictions and \
(project.model_version is None or project.model_version == ""):
if project.show_collab_predictions and (project.model_version is None or project.model_version == ''):
project.model_version = ml_backend.title
project.save(update_fields=['model_version'])

Expand Down
6 changes: 3 additions & 3 deletions label_studio/projects/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def to_internal_value(self, data):
data['expert_instruction'] = bleach.clean(
initial_data['expert_instruction'], tags=SAFE_HTML_TAGS, attributes=SAFE_HTML_ATTRIBUTES
)

return data

class Meta:
Expand Down Expand Up @@ -187,8 +187,8 @@ def validate_model_version(self, value):
return value

def update(self, instance, validated_data):
if validated_data.get("show_collab_predictions") is False:
instance.model_version = ""
if validated_data.get('show_collab_predictions') is False:
instance.model_version = ''

return super().update(instance, validated_data)

Expand Down
15 changes: 7 additions & 8 deletions label_studio/tests/ml/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@

from label_studio.tests.utils import make_project, register_ml_backend_mock

ORIG_MODEL_NAME="basic_ml_backend"
PROJECT_CONFIG="""<View><Image name="image" value="$image_url"/><Choices name="label"
ORIG_MODEL_NAME = 'basic_ml_backend'
PROJECT_CONFIG = """<View><Image name="image" value="$image_url"/><Choices name="label"
toName="image"><Choice value="pos"/><Choice value="neg"/></Choices></View>"""


@pytest.fixture
def ml_backend_for_test_api(ml_backend):
register_ml_backend_mock(
Expand Down Expand Up @@ -48,7 +49,7 @@ def test_ml_backend_set_for_prelabeling(business_client, ml_backend_for_test_api
},
)
assert response.status_code == 201

project.refresh_from_db()
assert project.model_version == 'ml_backend_title'

Expand All @@ -69,7 +70,7 @@ def test_ml_backend_not_set_for_prelabeling(business_client, ml_backend_for_test

project.model_version = ORIG_MODEL_NAME
project.save()

# create ML backend
response = business_client.post(
'/api/ml/',
Expand All @@ -80,10 +81,10 @@ def test_ml_backend_not_set_for_prelabeling(business_client, ml_backend_for_test
},
)
assert response.status_code == 201

project.refresh_from_db()
assert project.model_version == ORIG_MODEL_NAME


@pytest.mark.django_db
def test_model_version_on_save(business_client, ml_backend_for_test_api, mock_gethostbyname):
Expand Down Expand Up @@ -304,5 +305,3 @@ def test_ml_backend_predict_test_api_post_random_true(business_client):
r = response.json()
assert r['url'] == 'http://localhost:8999/predict'
assert r['status'] == 200


19 changes: 6 additions & 13 deletions scripts/update_ml_tutorials.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@

import os
import re
import yaml
from pathlib import Path
from typing import List

import yaml

ML_REPO_PATH = os.getenv('ML_REPO_PATH', 'label-studio-ml-backend/')

Expand All @@ -45,10 +45,7 @@ def parse_readme_file(file_path: str) -> dict:
header = re.findall(r'---\n(.*?)\n---', content, re.DOTALL)
body = re.sub(r'---\n(.*?)\n---', '', content, flags=re.DOTALL)

return {
'header': header[0].strip() if header else '',
'body': body.strip()
}
return {'header': header[0].strip() if header else '', 'body': body.strip()}


def create_tutorial_files():
Expand All @@ -68,10 +65,9 @@ def create_tutorial_files():
f.write(parsed_content['header'])
f.write('\n---\n\n')
f.write(parsed_content['body'])
files_and_headers.append({
'model_name': model_name,
'header': yaml.load(parsed_content['header'], Loader=yaml.FullLoader)
})
files_and_headers.append(
{'model_name': model_name, 'header': yaml.load(parsed_content['header'], Loader=yaml.FullLoader)}
)

update_ml_tutorials_index(files_and_headers)

Expand All @@ -91,10 +87,7 @@ def update_ml_tutorials_index(files_and_headers: List):
for f in files_and_headers:
print('Processing', f['model_name'])
h = f['header'] or {}
card = {
'title': h.get('title') or f['model_name'],
'url': f'/tutorials/{f["model_name"]}.html'
}
card = {'title': h.get('title') or f['model_name'], 'url': f'/tutorials/{f["model_name"]}.html'}
card.update(f['header'] or {})
data['cards'].append(card)

Expand Down

0 comments on commit c5b3bce

Please sign in to comment.