Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: LSDV-4695: Add Project relation to Prediction model to improve overall performance #4629

Merged
merged 6 commits into from Aug 22, 2023
7 changes: 6 additions & 1 deletion label_studio/core/old_ls_migration.py
Expand Up @@ -69,7 +69,12 @@ def _migrate_tasks(project_path, project):
# migrate predictions
predictions_data = task_data.get('predictions', [])
for prediction in predictions_data:
task_prediction = Prediction(result=prediction['result'], task=task, score=prediction.get('score'))
task_prediction = Prediction(
result=prediction['result'],
task=task,
score=prediction.get('score'),
project=task.project,
)
with suppress_autotime(task_prediction, ['created_at']):
task_prediction.created_at = datetime.datetime.fromtimestamp(
prediction['created_at'], tz=datetime.datetime.now().astimezone().tzinfo
Expand Down
3 changes: 3 additions & 0 deletions label_studio/data_import/api.py
Expand Up @@ -327,7 +327,9 @@ class ImportPredictionsAPI(generics.CreateAPIView):
def create(self, request, *args, **kwargs):
# check project permissions
project = self.get_object()
dredivaris marked this conversation as resolved.
Show resolved Hide resolved

tasks_ids = set(Task.objects.filter(project=project).values_list('id', flat=True))

logger.debug(f'Importing {len(self.request.data)} predictions to project {project} with {len(tasks_ids)} tasks')
predictions = []
for item in self.request.data:
Expand All @@ -337,6 +339,7 @@ def create(self, request, *args, **kwargs):
f'from project {project} tasks')
predictions.append(Prediction(
task_id=item['task'],
project_id=project.id,
result=Prediction.prepare_prediction_result(item.get('result'), project),
score=item.get('score'),
model_version=item.get('model_version', 'undefined')
Expand Down
1 change: 1 addition & 0 deletions label_studio/ml/models.py
Expand Up @@ -262,6 +262,7 @@ def predict_one_task(self, task, check_state=True):
score=safe_float(score),
model_version=self.model_version,
task_id=task_id,
project=task.project,
cluster=prediction_response.get('cluster'),
neighbors=prediction_response.get('neighbors'),
mislabeling=safe_float(prediction_response.get('mislabeling', 0)),
Expand Down
15 changes: 14 additions & 1 deletion label_studio/tasks/functions.py
Expand Up @@ -12,7 +12,7 @@
from data_export.serializers import ExportDataSerializer
from organizations.models import Organization
from projects.models import Project
from tasks.models import Task, Annotation
from tasks.models import Task, Annotation, Prediction
from data_export.mixins import ExportMixin


Expand Down Expand Up @@ -137,3 +137,16 @@ def fill_annotations_project():

logger.info('Finished filling project field for Annotation model')


def _fill_predictions_project(project_id):
    """Set the denormalized ``project`` FK on every prediction of one project.

    Runs a single bulk UPDATE; no model ``save()`` hooks fire.
    """
    predictions = Prediction.objects.filter(task__project_id=project_id)
    predictions.update(project_id=project_id)


def fill_predictions_project():
    """Backfill ``Prediction.project`` for every project in the database.

    Schedules one job per project (synchronous or asynchronous depending on
    deployment, via ``start_job_async_or_sync``) so each bulk UPDATE stays
    scoped to a single project.
    """
    logger.info('Start filling project field for Prediction model')
    # values_list avoids materializing full Project instances when only the
    # primary key is needed; iterator() keeps memory flat for many projects.
    for project_id in Project.objects.values_list('id', flat=True).iterator():
        start_job_async_or_sync(_fill_predictions_project, project_id)

    logger.info('Finished filling project field for Prediction model')

20 changes: 20 additions & 0 deletions label_studio/tasks/migrations/0041_prediction_project.py
@@ -0,0 +1,20 @@
# Generated by Django 3.2.19 on 2023-08-10 22:43
from django.db import migrations, models
import django.db.models.deletion


class Migration(migrations.Migration):
    """Add a nullable denormalized ``project`` FK to ``Prediction``.

    Nullable so the column can be added without rewriting existing rows;
    the data backfill happens in the follow-up migration (0042).
    """

    dependencies = [
        ('projects', '0024_merge_0023_merge_20230512_1333_0023_projectreimport'),
        ('tasks', '0040_auto_20230628_1101'),
    ]

    operations = [
        migrations.AddField(
            model_name='prediction',
            name='project',
            # CASCADE mirrors the Task -> Project relation: deleting a project
            # removes its predictions.
            field=models.ForeignKey(null=True, on_delete=django.db.models.deletion.CASCADE, related_name='predictions', to='projects.project'),
        ),

    ]
22 changes: 22 additions & 0 deletions label_studio/tasks/migrations/0042_auto_20230810_2304.py
@@ -0,0 +1,22 @@
# Generated by Django 3.2.19 on 2023-08-10 23:04

from django.db import migrations
from tasks.functions import fill_predictions_project

def forward(apps, schema_editor):
    """Backfill ``Prediction.project`` from each prediction's task.

    Uses historical models via ``apps.get_model`` instead of importing and
    calling ``tasks.functions.fill_predictions_project``: importing
    application code in a data migration runs the *current* model
    definitions (and may enqueue background jobs), which breaks whenever the
    codebase drifts from the schema state this migration expects.
    """
    Prediction = apps.get_model('tasks', 'Prediction')
    Project = apps.get_model('projects', 'Project')
    # One bulk UPDATE per project keeps each statement (and its locks) small.
    for project_id in Project.objects.values_list('id', flat=True).iterator():
        Prediction.objects.filter(task__project_id=project_id).update(project_id=project_id)


def backwards(apps, schema_editor):
    """Reverse data migration: intentionally a no-op.

    Reversing the schema migration (0041) simply drops the ``project``
    column, so there is no data to restore here.
    """
    return None


class Migration(migrations.Migration):
    """Data migration: populate the ``Prediction.project`` column added in 0041."""

    # Non-atomic: the backfill may touch a very large number of rows, so it
    # must not run inside a single wrapping transaction.
    atomic = False
    dependencies = [
        ('tasks', '0041_prediction_project'),
    ]

    operations = [
        migrations.RunPython(forward, backwards),
    ]
5 changes: 5 additions & 0 deletions label_studio/tasks/models.py
Expand Up @@ -865,6 +865,7 @@ class Prediction(models.Model):
task = models.ForeignKey(
"tasks.Task", on_delete=models.CASCADE, related_name="predictions"
)
project = models.ForeignKey('projects.Project', on_delete=models.CASCADE, related_name='predictions', null=True)
created_at = models.DateTimeField(_("created at"), auto_now_add=True)
updated_at = models.DateTimeField(_("updated at"), auto_now=True)

Expand Down Expand Up @@ -941,6 +942,10 @@ def update_task(self):
self.task.save(update_fields=update_fields)

def save(self, *args, **kwargs):
if self.project_id is None and self.task_id:
logger.warning('project_id is not set for prediction, project_id being set in save method')
self.project_id = Task.objects.only('project_id').get(pk=self.task_id).project_id

# "result" data can come in different forms - normalize them to JSON
self.result = self.prepare_prediction_result(self.result, self.task.project)
# set updated_at field of task to now()
Expand Down
3 changes: 2 additions & 1 deletion label_studio/tests/test_export.py
Expand Up @@ -142,7 +142,8 @@ def test_export_with_predictions(
Annotation.objects.create(task=task, result=result, completed_by=business_client.admin)
if predictions:
for task in tasks:
Prediction.objects.create(task=task, result=predictions['result'], score=predictions['score'])
Prediction.objects.create(task=task, project=task.project, result=predictions['result'],
score=predictions['score'])

r = business_client.get(f'/api/projects/{configured_project.id}/results/', data={
'finished': finished,
Expand Down
4 changes: 3 additions & 1 deletion label_studio/tests/test_next_task.py
Expand Up @@ -449,7 +449,9 @@ def count(self):

for task, prediction, annotation in zip(tasks, predictions, annotations):
task = make_task(task, project)
Prediction.objects.create(task=task, model_version=project.model_version, **prediction)
Prediction.objects.create(
task=task, project=task.project, model_version=project.model_version, **prediction
)
if annotation is not None:
completed_by = any_client.annotator if num_annotators == 1 else annotator2_client.annotator
Annotation.objects.create(task=task, completed_by=completed_by, project=project, **annotation)
Expand Down
2 changes: 1 addition & 1 deletion label_studio/tests/test_predictions.py
Expand Up @@ -653,7 +653,7 @@ def test_predictions_with_partially_predicted_tasks(
if annotation is not None:
Annotation.objects.create(task=task_obj, **annotation)
if prediction is not None:
Prediction.objects.create(task=task_obj, **prediction)
Prediction.objects.create(task=task_obj, project=task_obj.project, **prediction)

# run prediction
with requests_mock.Mocker() as m:
Expand Down
6 changes: 3 additions & 3 deletions label_studio/tests/utils.py
Expand Up @@ -246,9 +246,9 @@ def make_annotation(config, task_id):


def make_prediction(config, task_id):
    """Create a ``Prediction`` for *task_id*, filling the denormalized project FK.

    ``config`` is expanded into ``Prediction`` field keyword arguments
    (e.g. ``result``, ``score``, ``model_version``).
    """
    from tasks.models import Prediction, Task

    # Fetch only the FK column — the full Task row is not needed here.
    project_id = Task.objects.only('project_id').get(pk=task_id).project_id
    return Prediction.objects.create(task_id=task_id, project_id=project_id, **config)


def make_annotator(config, project, login=False, client=None):
Expand Down