Skip to content

Commit

Permalink
fix: DIA-1061: Start Training should send START_TRAINING event as web…
Browse files Browse the repository at this point in the history
…hook instead of PROJECT_UPDATE (#5761)

#### Change has impacts in these area(s)
_(check all that apply)_
- [ ] Product design
- [ ] Backend (Database)
- [x] Backend (API)
- [ ] Frontend



### Describe the reason for change
Start Training button on Model settings page sends PROJECT_UPDATE event
and it's not possible to distinguish when Start Training really was
called.



#### What does this fix?
ML backend model method fit(..., event, ...) will be able to distinguish
the right training event.


#### What is the current behavior?
ML backend model method fit(..., event, ...) receives PROJECT_UPDATE as
event and this can be triggered on any project change event.
  • Loading branch information
makseq committed Apr 29, 2024
1 parent 0f4a8aa commit 77eed02
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 1 deletion.
21 changes: 21 additions & 0 deletions docs/source/guide/webhook_reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -712,3 +712,24 @@ You must [enable organization-level webhooks](webhooks.html#Enable-organization-
```


### Start Training

This webhook is triggered when a user clicks `Start Training` button on the ML Model card in the Project Settings page.
This event will be sent to the ML Backend and can be caught in the model.fit(event, ...) method:

```
class MyModel(LabelStudioMLBase):
def fit(self, event, *args, **kwargs):
if event == 'START_TRAINING':
...
```

### Webhook payload details

| Key | Type | Description |
| --- | --- |---------------------------------------|
| action | string | Name of the action: `START_TRAINING`. |
| id | integer | ID of the project where training is started. |
| project | JSON dictionary | All fields related to the project that was updated. See the [API documentation for updating a project](/api#operation/api_projects_partial_update). |


2 changes: 1 addition & 1 deletion label_studio/ml/api_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def train(self, project, use_ground_truth=False):
# Identify if feature flag is turned on
if flag_set('ff_back_dev_1417_start_training_mlbackend_webhooks_250122_long', user):
request = {
'action': 'PROJECT_UPDATED',
'action': 'START_TRAINING',
'project': load_func(settings.WEBHOOK_SERIALIZERS['project'])(instance=project).data,
}
return self._request('webhook', request, verbose=False, timeout=TIMEOUT_PREDICT)
Expand Down
50 changes: 50 additions & 0 deletions label_studio/tests/webhooks/test_webhooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,17 @@ def project_webhook(configured_project):
)


@pytest.fixture
def ml_start_training_webhook(configured_project):
organization = configured_project.organization
uri = 'http://0.0.0.0:9090/webhook'
return Webhook.objects.create(
organization=organization,
project=configured_project,
url=uri,
)


@pytest.mark.django_db
def test_run_webhook(setup_project_dialog, organization_webhook):
webhook = organization_webhook
Expand Down Expand Up @@ -390,3 +401,42 @@ def test_webhooks_for_tasks_from_storages(configured_project, business_client, o
assert r.json()['action'] == WebhookAction.TASKS_CREATED
assert 'tasks' in r.json()
assert 'project' in r.json()


@pytest.mark.django_db
def test_start_training_webhook(setup_project_dialog, ml_start_training_webhook, business_client):
"""
1. Setup: The test uses the project_webhook fixture, which assumes that a webhook
is already configured for the project.
2. Mocking the POST Request: The requests_mock.Mocker is used to mock
the POST request to the webhook URL. This is where you expect the START_TRAINING action to be sent.
3. Making the Request: The test makes a POST request to the /api/ml/{id}/train endpoint.
Assertions:
- The response status code is checked to ensure the request was successful.
- It verifies that exactly one request was made to the webhook URL.
- It checks that the request method was POST.
- The request URL and the JSON payload are validated against expected values.
"""
from ml.models import MLBackend

webhook = ml_start_training_webhook
project = webhook.project
ml = MLBackend.objects.create(project=project, url='http://0.0.0.0:9090')

# Mock the POST request to the ML backend train endpoint
with requests_mock.Mocker(real_http=True) as m:
m.register_uri('POST', webhook.url)
response = business_client.post(
f'/api/ml/{ml.id}/train',
data=json.dumps({'action': 'START_TRAINING'}),
content_type='application/json',
)

assert response.status_code == 200
request_history = m.request_history
assert len(request_history) == 1
assert request_history[0].method == 'POST'
assert request_history[0].url == webhook.url
assert 'project' in request_history[0].json()
assert request_history[0].json()['action'] == 'START_TRAINING'

0 comments on commit 77eed02

Please sign in to comment.