Skip to content

Commit

Permalink
Merge pull request #329 from Ilhasoft/feature/update_nlp
Browse files Browse the repository at this point in the history
Feature/update nlp
  • Loading branch information
dyohan9 committed Dec 17, 2019
2 parents 14e5bef + cd5d3bf commit db24e72
Show file tree
Hide file tree
Showing 11 changed files with 327 additions and 137 deletions.
68 changes: 44 additions & 24 deletions bothub/api/v2/nlp/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,20 @@ class RepositoryAuthorizationTrainViewSet(
def retrieve(self, request, *args, **kwargs):
check_auth(request)
repository_authorization = self.get_object()
current_version = repository_authorization.repository.current_version(
str(request.query_params.get("language"))
)
repository_version = request.query_params.get("repository_version")
if repository_version:
current_version = repository_authorization.repository.get_specific_version_id(
repository_version, str(request.query_params.get("language"))
)
else:
current_version = repository_authorization.repository.current_version(
str(request.query_params.get("language"))
)

return Response(
{
"ready_for_train": current_version.ready_for_train,
"current_update_id": current_version.id,
"current_version_id": current_version.id,
"repository_authorization_user_id": repository_authorization.user.id,
"language": current_version.language,
}
Expand All @@ -70,7 +76,7 @@ def retrieve(self, request, *args, **kwargs):
def get_examples(self, request, **kwargs):
check_auth(request)
queryset = get_object_or_404(
RepositoryVersionLanguage, pk=request.query_params.get("update_id")
RepositoryVersionLanguage, pk=request.query_params.get("repository_version")
)

page = self.paginate_queryset(queryset.examples)
Expand All @@ -87,7 +93,7 @@ def get_examples(self, request, **kwargs):
def get_examples_labels(self, request, **kwargs):
check_auth(request)
queryset = get_object_or_404(
RepositoryVersionLanguage, pk=request.query_params.get("update_id")
RepositoryVersionLanguage, pk=request.query_params.get("repository_version")
)

page = self.paginate_queryset(
Expand All @@ -107,7 +113,7 @@ def start_training(self, request, **kwargs):
check_auth(request)

repository = get_object_or_404(
RepositoryVersionLanguage, pk=request.data.get("update_id")
RepositoryVersionLanguage, pk=request.data.get("repository_version")
)

repository.start_training(
Expand All @@ -117,7 +123,7 @@ def start_training(self, request, **kwargs):
return Response(
{
"language": repository.language,
"update_id": repository.id,
"repository_version": repository.id,
"repository_uuid": str(repository.repository_version.repository.uuid),
"intent": repository.intents,
"algorithm": repository.algorithm,
Expand All @@ -141,7 +147,7 @@ def get_entities_and_labels(self, request, **kwargs):
try:
examples = request.data.get("examples")
label_examples_query = request.data.get("label_examples_query")
update_id = request.data.get("update_id")
update_id = request.data.get("repository_version")
except ValueError:
raise exceptions.NotFound()

Expand Down Expand Up @@ -206,7 +212,7 @@ def get_entities_and_labels(self, request, **kwargs):
def train_fail(self, request, **kwargs):
check_auth(request)
repository = get_object_or_404(
RepositoryVersionLanguage, pk=request.data.get("update_id")
RepositoryVersionLanguage, pk=request.data.get("repository_version")
)
repository.train_fail()
return Response({})
Expand All @@ -215,7 +221,7 @@ def train_fail(self, request, **kwargs):
def training_log(self, request, **kwargs):
check_auth(request)
repository = get_object_or_404(
RepositoryVersionLanguage, pk=request.data.get("update_id")
RepositoryVersionLanguage, pk=request.data.get("repository_version")
)
repository.training_log = request.data.get("training_log")
repository.save(update_fields=["training_log"])
Expand All @@ -233,16 +239,21 @@ def retrieve(self, request, *args, **kwargs):
repository = repository_authorization.repository

language = request.query_params.get("language")
repository_version = request.query_params.get("repository_version")

if language == "None" or language is None:
language = str(repository.language)

update = repository.last_trained_update(language)
if repository_version:
update = repository.get_specific_version_id(repository_version, language)
else:
update = repository.last_trained_update(language)

try:
return Response(
{
"update": False if update is None else True,
"update_id": update.id,
"version": False if update is None else True,
"repository_version": update.id,
"language": update.language,
}
)
Expand All @@ -253,7 +264,7 @@ def retrieve(self, request, *args, **kwargs):
def repository_entity(self, request, **kwargs):
check_auth(request)
repository_update = get_object_or_404(
RepositoryVersionLanguage, pk=request.query_params.get("update_id")
RepositoryVersionLanguage, pk=request.query_params.get("repository_version")
)
repository_entity = get_object_or_404(
RepositoryEntity,
Expand Down Expand Up @@ -293,13 +304,22 @@ def retrieve(self, request, *args, **kwargs):
check_auth(request)
repository_authorization = self.get_object()
repository = repository_authorization.repository
update = repository.last_trained_update(
str(request.query_params.get("language"))
)

repository_version = request.query_params.get("repository_version")

if repository_version:
update = repository.get_specific_version_id(
repository_version, str(request.query_params.get("language"))
)
else:
update = repository.last_trained_update(
str(request.query_params.get("language"))
)

return Response(
{
"update": False if update is None else True,
"update_id": update.id,
"repository_version": update.id,
"language": update.language,
"user_id": repository_authorization.user.id,
}
Expand All @@ -309,7 +329,7 @@ def retrieve(self, request, *args, **kwargs):
def evaluations(self, request, **kwargs):
check_auth(request)
repository_update = get_object_or_404(
RepositoryVersionLanguage, pk=request.query_params.get("update_id")
RepositoryVersionLanguage, pk=request.query_params.get("repository_version")
)
evaluations = repository_update.repository_version.repository.evaluations(
language=repository_update.language
Expand Down Expand Up @@ -345,7 +365,7 @@ def evaluations(self, request, **kwargs):
def evaluate_results(self, request, **kwargs):
check_auth(request)
repository_update = get_object_or_404(
RepositoryVersionLanguage, pk=request.data.get("update_id")
RepositoryVersionLanguage, pk=request.data.get("repository_version")
)

intents_score = RepositoryEvaluateResultScore.objects.create(
Expand Down Expand Up @@ -418,7 +438,7 @@ def evaluate_results_score(self, request, **kwargs):
)

repository_update = get_object_or_404(
RepositoryVersionLanguage, pk=request.data.get("update_id")
RepositoryVersionLanguage, pk=request.data.get("repository_version")
)

entity_score = RepositoryEvaluateResultScore.objects.create(
Expand Down Expand Up @@ -484,8 +504,8 @@ def retrieve(self, request, *args, **kwargs):

return Response(
{
"update_id": update.id,
"repository_uuid": update.repository.uuid,
"version_id": update.id,
"repository_uuid": update.repository_version.repository.uuid,
"bot_data": str(bot_data),
"from_aws": aws,
}
Expand Down
6 changes: 6 additions & 0 deletions bothub/api/v2/repository/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,10 +579,16 @@ def create(self, validated_data):
class AnalyzeTextSerializer(serializers.Serializer):
language = serializers.ChoiceField(LANGUAGE_CHOICES, required=True)
text = serializers.CharField(allow_blank=False)
repository_version = serializers.IntegerField(required=False)


class TrainSerializer(serializers.Serializer):
repository_version = serializers.IntegerField(required=False)


class EvaluateSerializer(serializers.Serializer):
language = serializers.ChoiceField(LANGUAGE_CHOICES, required=True)
repository_version = serializers.IntegerField(required=False)


class RepositoryUpdateSerializer(serializers.ModelSerializer):
Expand Down
12 changes: 7 additions & 5 deletions bothub/api/v2/repository/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from .permissions import RepositoryAdminManagerAuthorization
from .permissions import RepositoryExamplePermission
from .permissions import RepositoryPermission
from .serializers import AnalyzeTextSerializer
from .serializers import AnalyzeTextSerializer, TrainSerializer
from .serializers import EvaluateSerializer
from .serializers import RepositoryAuthorizationRoleSerializer
from .serializers import RepositoryAuthorizationSerializer
Expand Down Expand Up @@ -87,21 +87,23 @@ def languagesstatus(self, request, **kwargs):

@action(
detail=True,
methods=["GET"],
methods=["POST"],
url_name="repository-train",
lookup_fields=["uuid"],
)
def train(self, request, **kwargs):
"""
Train current update using Bothub NLP service
"""
if self.lookup_field not in kwargs:
return Response({}, status=403)
repository = self.get_object()
user_authorization = repository.get_user_authorization(request.user)
serializer = TrainSerializer(data=request.data) # pragma: no cover
serializer.is_valid(raise_exception=True) # pragma: no cover
if not user_authorization.can_write:
raise PermissionDenied()
request = repository.request_nlp_train(user_authorization) # pragma: no cover
request = repository.request_nlp_train(
user_authorization, serializer.data
) # pragma: no cover
if request.status_code != status.HTTP_200_OK: # pragma: no cover
raise APIException( # pragma: no cover
{"status_code": request.status_code}, code=request.status_code
Expand Down
4 changes: 2 additions & 2 deletions bothub/api/v2/tests/test_nlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def request(self, token):
"/v2/repository/nlp/authorization/train/start_training/",
json.dumps(
{
"update_id": self.repository_version_language.pk,
"repository_version": self.repository_version_language.pk,
"by_user": self.user.pk,
}
),
Expand Down Expand Up @@ -108,7 +108,7 @@ def request(self, token):
authorization_header = {"HTTP_AUTHORIZATION": "Bearer {}".format(token)}
request = self.factory.post(
"/v2/repository/nlp/authorization/train/train_fail/",
json.dumps({"update_id": self.repository_version_language.pk}),
json.dumps({"repository_version": self.repository_version_language.pk}),
content_type="application/json",
**authorization_header
)
Expand Down
11 changes: 7 additions & 4 deletions bothub/api/v2/tests/test_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -1708,19 +1708,22 @@ def setUp(self):
language=languages.LANGUAGE_EN,
)

def request(self, repository, token):
def request(self, repository, token, data):
authorization_header = {"HTTP_AUTHORIZATION": "Token {}".format(token.key)}
request = self.factory.get(
request = self.factory.post(
"/v2/repository/repository-info/{}/train/".format(str(repository.uuid)),
data,
**authorization_header,
)
response = RepositoryViewSet.as_view({"get": "train"})(request)
response = RepositoryViewSet.as_view({"post": "train"})(
request, uuid=repository.uuid
)
response.render()
content_data = json.loads(response.content)
return (response, content_data)

def test_permission_denied(self):
response, content_data = self.request(self.repository, self.user_token)
response, content_data = self.request(self.repository, self.user_token, {})
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)


Expand Down
5 changes: 5 additions & 0 deletions bothub/common/management/commands/fill_db_using_fake_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ def handle(self, *args, **kwargs):
repository_1.categories.add(categories[1])
repository_1.categories.add(categories[3])

repository_1.current_version()

repository_2 = Repository.objects.create(
owner=user,
name="Repository 2",
Expand All @@ -67,6 +69,8 @@ def handle(self, *args, **kwargs):
repository_2.categories.add(categories[0])
repository_2.categories.add(categories[2])

repository_2.current_version()

for x in range(3, 46):
new_repository = Repository.objects.create(
owner=user,
Expand All @@ -75,6 +79,7 @@ def handle(self, *args, **kwargs):
language=languages.LANGUAGE_EN,
)
new_repository.categories.add(random.choice(categories))
new_repository.current_version()

# Examples

Expand Down
2 changes: 2 additions & 0 deletions bothub/common/migrations/0041_delete_examples_isdeleted.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ def noop(apps, schema_editor): # pragma: no cover

def delete_examples_already_deleted(apps, schema_editor): # pragma: no cover
RepositoryExample = apps.get_model("common", "RepositoryExample")
RepositoryEvaluate = apps.get_model("common", "RepositoryEvaluate")
RepositoryExample.objects.filter(deleted_in__isnull=False).delete()
RepositoryEvaluate.objects.filter(deleted_in__isnull=False).delete()


class Migration(migrations.Migration):
Expand Down
Loading

0 comments on commit db24e72

Please sign in to comment.