Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: DIA-1061: Start Training should send START_TRAINING event as webhook instead of PROJECT_UPDATE #5761

Merged
merged 16 commits into from
Apr 29, 2024
21 changes: 21 additions & 0 deletions docs/source/guide/webhook_reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -712,3 +712,24 @@ You must [enable organization-level webhooks](webhooks.html#Enable-organization-
```


### Start Training

This webhook is triggered when a user clicks `Start Training` button on the ML Model card in the Project Settings page.
This event will be sent to the ML Backend and can be caught in the model.fit(event, ...) method:

```
class MyModel(LabelStudioMLBase):
def fit(self, event, *args, **kwargs):
if event == 'START_TRAINING':
...
```

### Webhook payload details

| Key | Type | Description |
| --- | --- |---------------------------------------|
| action | string | Name of the action: `START_TRAINING`. |
| id | integer | ID of the project where training is started. |
| project | JSON dictionary | All fields related to the project that was updated. See the [API documentation for updating a project](/api#operation/api_projects_partial_update). |


2 changes: 1 addition & 1 deletion label_studio/ml/api_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def train(self, project, use_ground_truth=False):
# Identify if feature flag is turned on
if flag_set('ff_back_dev_1417_start_training_mlbackend_webhooks_250122_long', user):
request = {
'action': 'PROJECT_UPDATED',
'action': 'START_TRAINING',
'project': load_func(settings.WEBHOOK_SERIALIZERS['project'])(instance=project).data,
}
return self._request('webhook', request, verbose=False, timeout=TIMEOUT_PREDICT)
Expand Down
50 changes: 50 additions & 0 deletions label_studio/tests/webhooks/test_webhooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,17 @@ def project_webhook(configured_project):
)


@pytest.fixture
def ml_start_training_webhook(configured_project):
organization = configured_project.organization
uri = 'http://0.0.0.0:9090/webhook'
return Webhook.objects.create(
organization=organization,
project=configured_project,
url=uri,
)


@pytest.mark.django_db
def test_run_webhook(setup_project_dialog, organization_webhook):
webhook = organization_webhook
Expand Down Expand Up @@ -390,3 +401,42 @@ def test_webhooks_for_tasks_from_storages(configured_project, business_client, o
assert r.json()['action'] == WebhookAction.TASKS_CREATED
assert 'tasks' in r.json()
assert 'project' in r.json()


@pytest.mark.django_db
def test_start_training_webhook(setup_project_dialog, ml_start_training_webhook, business_client):
"""
1. Setup: The test uses the project_webhook fixture, which assumes that a webhook
is already configured for the project.
2. Mocking the POST Request: The requests_mock.Mocker is used to mock
the POST request to the webhook URL. This is where you expect the START_TRAINING action to be sent.
3. Making the Request: The test makes a POST request to the /api/ml/{id}/train endpoint.

Assertions:
- The response status code is checked to ensure the request was successful.
- It verifies that exactly one request was made to the webhook URL.
- It checks that the request method was POST.
- The request URL and the JSON payload are validated against expected values.
"""
from ml.models import MLBackend

webhook = ml_start_training_webhook
project = webhook.project
ml = MLBackend.objects.create(project=project, url='http://0.0.0.0:9090')

# Mock the POST request to the ML backend train endpoint
with requests_mock.Mocker(real_http=True) as m:
m.register_uri('POST', webhook.url)
response = business_client.post(
f'/api/ml/{ml.id}/train',
data=json.dumps({'action': 'START_TRAINING'}),
content_type='application/json',
)

assert response.status_code == 200
request_history = m.request_history
assert len(request_history) == 1
assert request_history[0].method == 'POST'
assert request_history[0].url == webhook.url
assert 'project' in request_history[0].json()
assert request_history[0].json()['action'] == 'START_TRAINING'