Skip to content

Commit

Permalink
Updated NLP
Browse files Browse the repository at this point in the history
  • Loading branch information
dyohan9 committed Dec 16, 2019
1 parent 14e5bef commit 7036997
Show file tree
Hide file tree
Showing 7 changed files with 126 additions and 91 deletions.
51 changes: 31 additions & 20 deletions bothub/api/v2/nlp/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,20 @@ class RepositoryAuthorizationTrainViewSet(
def retrieve(self, request, *args, **kwargs):
check_auth(request)
repository_authorization = self.get_object()
current_version = repository_authorization.repository.current_version(
str(request.query_params.get("language"))
)
repository_version = request.query_params.get("repository_version")
if repository_version:
current_version = repository_authorization.repository.get_specific_version_id(
repository_version, str(request.query_params.get("language"))
)
else:
current_version = repository_authorization.repository.current_version(
str(request.query_params.get("language"))
)

return Response(
{
"ready_for_train": current_version.ready_for_train,
"current_update_id": current_version.id,
"current_version_id": current_version.id,
"repository_authorization_user_id": repository_authorization.user.id,
"language": current_version.language,
}
Expand All @@ -70,7 +76,7 @@ def retrieve(self, request, *args, **kwargs):
def get_examples(self, request, **kwargs):
check_auth(request)
queryset = get_object_or_404(
RepositoryVersionLanguage, pk=request.query_params.get("update_id")
RepositoryVersionLanguage, pk=request.query_params.get("repository_version")
)

page = self.paginate_queryset(queryset.examples)
Expand All @@ -87,7 +93,7 @@ def get_examples(self, request, **kwargs):
def get_examples_labels(self, request, **kwargs):
check_auth(request)
queryset = get_object_or_404(
RepositoryVersionLanguage, pk=request.query_params.get("update_id")
RepositoryVersionLanguage, pk=request.query_params.get("repository_version")
)

page = self.paginate_queryset(
Expand All @@ -107,7 +113,7 @@ def start_training(self, request, **kwargs):
check_auth(request)

repository = get_object_or_404(
RepositoryVersionLanguage, pk=request.data.get("update_id")
RepositoryVersionLanguage, pk=request.data.get("repository_version")
)

repository.start_training(
Expand All @@ -117,7 +123,7 @@ def start_training(self, request, **kwargs):
return Response(
{
"language": repository.language,
"update_id": repository.id,
"repository_version": repository.id,
"repository_uuid": str(repository.repository_version.repository.uuid),
"intent": repository.intents,
"algorithm": repository.algorithm,
Expand All @@ -141,7 +147,7 @@ def get_entities_and_labels(self, request, **kwargs):
try:
examples = request.data.get("examples")
label_examples_query = request.data.get("label_examples_query")
update_id = request.data.get("update_id")
update_id = request.data.get("repository_version")
except ValueError:
raise exceptions.NotFound()

Expand Down Expand Up @@ -206,7 +212,7 @@ def get_entities_and_labels(self, request, **kwargs):
def train_fail(self, request, **kwargs):
check_auth(request)
repository = get_object_or_404(
RepositoryVersionLanguage, pk=request.data.get("update_id")
RepositoryVersionLanguage, pk=request.data.get("repository_version")
)
repository.train_fail()
return Response({})
Expand All @@ -215,7 +221,7 @@ def train_fail(self, request, **kwargs):
def training_log(self, request, **kwargs):
check_auth(request)
repository = get_object_or_404(
RepositoryVersionLanguage, pk=request.data.get("update_id")
RepositoryVersionLanguage, pk=request.data.get("repository_version")
)
repository.training_log = request.data.get("training_log")
repository.save(update_fields=["training_log"])
Expand All @@ -233,16 +239,21 @@ def retrieve(self, request, *args, **kwargs):
repository = repository_authorization.repository

language = request.query_params.get("language")
repository_version = request.query_params.get("repository_version")

if language == "None" or language is None:
language = str(repository.language)

update = repository.last_trained_update(language)
if repository_version:
update = repository.get_specific_version_id(repository_version, language)
else:
update = repository.last_trained_update(language)

try:
return Response(
{
"update": False if update is None else True,
"update_id": update.id,
"version": False if update is None else True,
"repository_version": update.id,
"language": update.language,
}
)
Expand All @@ -253,7 +264,7 @@ def retrieve(self, request, *args, **kwargs):
def repository_entity(self, request, **kwargs):
check_auth(request)
repository_update = get_object_or_404(
RepositoryVersionLanguage, pk=request.query_params.get("update_id")
RepositoryVersionLanguage, pk=request.query_params.get("repository_version")
)
repository_entity = get_object_or_404(
RepositoryEntity,
Expand Down Expand Up @@ -309,7 +320,7 @@ def retrieve(self, request, *args, **kwargs):
def evaluations(self, request, **kwargs):
check_auth(request)
repository_update = get_object_or_404(
RepositoryVersionLanguage, pk=request.query_params.get("update_id")
RepositoryVersionLanguage, pk=request.query_params.get("repository_version")
)
evaluations = repository_update.repository_version.repository.evaluations(
language=repository_update.language
Expand Down Expand Up @@ -345,7 +356,7 @@ def evaluations(self, request, **kwargs):
def evaluate_results(self, request, **kwargs):
check_auth(request)
repository_update = get_object_or_404(
RepositoryVersionLanguage, pk=request.data.get("update_id")
RepositoryVersionLanguage, pk=request.data.get("repository_version")
)

intents_score = RepositoryEvaluateResultScore.objects.create(
Expand Down Expand Up @@ -418,7 +429,7 @@ def evaluate_results_score(self, request, **kwargs):
)

repository_update = get_object_or_404(
RepositoryVersionLanguage, pk=request.data.get("update_id")
RepositoryVersionLanguage, pk=request.data.get("repository_version")
)

entity_score = RepositoryEvaluateResultScore.objects.create(
Expand Down Expand Up @@ -484,8 +495,8 @@ def retrieve(self, request, *args, **kwargs):

return Response(
{
"update_id": update.id,
"repository_uuid": update.repository.uuid,
"version_id": update.id,
"repository_uuid": update.repository_version.repository.uuid,
"bot_data": str(bot_data),
"from_aws": aws,
}
Expand Down
5 changes: 5 additions & 0 deletions bothub/api/v2/repository/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,6 +579,11 @@ def create(self, validated_data):
class AnalyzeTextSerializer(serializers.Serializer):
language = serializers.ChoiceField(LANGUAGE_CHOICES, required=True)
text = serializers.CharField(allow_blank=False)
repository_version = serializers.IntegerField(required=False)


class TrainSerializer(serializers.Serializer):
repository_version = serializers.IntegerField(required=False)


class EvaluateSerializer(serializers.Serializer):
Expand Down
12 changes: 7 additions & 5 deletions bothub/api/v2/repository/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from .permissions import RepositoryAdminManagerAuthorization
from .permissions import RepositoryExamplePermission
from .permissions import RepositoryPermission
from .serializers import AnalyzeTextSerializer
from .serializers import AnalyzeTextSerializer, TrainSerializer
from .serializers import EvaluateSerializer
from .serializers import RepositoryAuthorizationRoleSerializer
from .serializers import RepositoryAuthorizationSerializer
Expand Down Expand Up @@ -87,21 +87,23 @@ def languagesstatus(self, request, **kwargs):

@action(
detail=True,
methods=["GET"],
methods=["POST"],
url_name="repository-train",
lookup_fields=["uuid"],
)
def train(self, request, **kwargs):
"""
Train current update using Bothub NLP service
"""
if self.lookup_field not in kwargs:
return Response({}, status=403)
repository = self.get_object()
user_authorization = repository.get_user_authorization(request.user)
serializer = TrainSerializer(data=request.data) # pragma: no cover
serializer.is_valid(raise_exception=True) # pragma: no cover
if not user_authorization.can_write:
raise PermissionDenied()
request = repository.request_nlp_train(user_authorization) # pragma: no cover
request = repository.request_nlp_train(
user_authorization, serializer.data
) # pragma: no cover
if request.status_code != status.HTTP_200_OK: # pragma: no cover
raise APIException( # pragma: no cover
{"status_code": request.status_code}, code=request.status_code
Expand Down
4 changes: 2 additions & 2 deletions bothub/api/v2/tests/test_nlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def request(self, token):
"/v2/repository/nlp/authorization/train/start_training/",
json.dumps(
{
"update_id": self.repository_version_language.pk,
"repository_version": self.repository_version_language.pk,
"by_user": self.user.pk,
}
),
Expand Down Expand Up @@ -108,7 +108,7 @@ def request(self, token):
authorization_header = {"HTTP_AUTHORIZATION": "Bearer {}".format(token)}
request = self.factory.post(
"/v2/repository/nlp/authorization/train/train_fail/",
json.dumps({"update_id": self.repository_version_language.pk}),
json.dumps({"repository_version": self.repository_version_language.pk}),
content_type="application/json",
**authorization_header
)
Expand Down
11 changes: 7 additions & 4 deletions bothub/api/v2/tests/test_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -1708,19 +1708,22 @@ def setUp(self):
language=languages.LANGUAGE_EN,
)

def request(self, repository, token):
def request(self, repository, token, data):
authorization_header = {"HTTP_AUTHORIZATION": "Token {}".format(token.key)}
request = self.factory.get(
request = self.factory.post(
"/v2/repository/repository-info/{}/train/".format(str(repository.uuid)),
data,
**authorization_header,
)
response = RepositoryViewSet.as_view({"get": "train"})(request)
response = RepositoryViewSet.as_view({"post": "train"})(
request, uuid=repository.uuid
)
response.render()
content_data = json.loads(response.content)
return (response, content_data)

def test_permission_denied(self):
response, content_data = self.request(self.repository, self.user_token)
response, content_data = self.request(self.repository, self.user_token, {})
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)


Expand Down
5 changes: 5 additions & 0 deletions bothub/common/management/commands/fill_db_using_fake_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ def handle(self, *args, **kwargs):
repository_1.categories.add(categories[1])
repository_1.categories.add(categories[3])

repository_1.current_version()

repository_2 = Repository.objects.create(
owner=user,
name="Repository 2",
Expand All @@ -67,6 +69,8 @@ def handle(self, *args, **kwargs):
repository_2.categories.add(categories[0])
repository_2.categories.add(categories[2])

repository_2.current_version()

for x in range(3, 46):
new_repository = Repository.objects.create(
owner=user,
Expand All @@ -75,6 +79,7 @@ def handle(self, *args, **kwargs):
language=languages.LANGUAGE_EN,
)
new_repository.categories.add(random.choice(categories))
new_repository.current_version()

# Examples

Expand Down
Loading

0 comments on commit 7036997

Please sign in to comment.