From 4e8bd338a8581097166d33d05c5630f2c4dd04d4 Mon Sep 17 00:00:00 2001 From: "xianzelin@huoxian.cn" Date: Wed, 29 Dec 2021 15:20:42 +0800 Subject: [PATCH] feature:#355:add api to batch modify hook rule type --- iast/views/engine_hook_rule_status.py | 42 +++++++++++++++++++++------ 1 file changed, 33 insertions(+), 9 deletions(-) diff --git a/iast/views/engine_hook_rule_status.py b/iast/views/engine_hook_rule_status.py index 5c68d13a..de4fdc3c 100644 --- a/iast/views/engine_hook_rule_status.py +++ b/iast/views/engine_hook_rule_status.py @@ -11,19 +11,28 @@ from django.utils.translation import gettext_lazy as _ from iast.utils import extend_schema_with_envcheck, get_response_serializer from rest_framework import serializers +from dongtai.models.hook_type import HookType logger = logging.getLogger('dongtai-webapi') OP_CHOICES = ('enable', 'disable', 'delete') -#SCOPE_CHOICES = ('all',) +SCOPE_CHOICES = ('all',) class EngineHookRuleStatusGetQuerySerializer(serializers.Serializer): - rule_id = serializers.IntegerField(help_text=_("The id of hook rule")) - # rule_type = serializers.IntegerField(help_text=_("The id of hook rule type")) + rule_id = serializers.IntegerField(required=False, + help_text=_("The id of hook rule")) + type = serializers.IntegerField( + required=False, help_text=_("The id of hook rule type")) op = serializers.ChoiceField(OP_CHOICES, + required=False, help_text=_("The state of the hook rule")) -# scope = serializers.ChoiceField(SCOPE_CHOICES, -# help_text=_("The scope of the hook rule")) + scope = serializers.ChoiceField(SCOPE_CHOICES, + required=False, + help_text=_("The scope of the hook rule")) + language_id = serializers.IntegerField(required=False, + help_text=_("The language_id")) + hook_rule_type = serializers.IntegerField( + required=False, help_text=_("The type of hook rule")) class EngineHookRuleStatusPostBodySerializer(serializers.Serializer): @@ -51,7 +60,8 @@ def parse_args(self, request): rule_type = request.query_params.get('type') scope = request.query_params.get('scope') op = request.query_params.get('op') - return rule_id, rule_type, scope, op + return rule_id, rule_type, scope, op, request.query_params.get( + 'language_id'), request.query_params.get('hook_rule_type') @staticmethod def set_strategy_status(strategy_id, strategy_ids, user_id, enable_status): @@ -86,9 +96,13 @@ def check_op(op): response_schema=_GetResponseSerializer, ) def get(self, request): - rule_id, rule_type, scope, op = self.parse_args(request) + rule_id, rule_type, scope, op, hook_rule_type, language_id = self.parse_args( + request) try: - rule_id = int(rule_id) + if rule_id: + rule_id = int(rule_id) + if rule_type: + rule_type = int(rule_type) except: return R.failure(_("Parameter error")) user_id = request.user.id @@ -97,11 +111,21 @@ def get(self, request): op = self.check_op(op) if op is None: return R.failure(msg=_('Operation type does not exist')) - if rule_type is not None and scope == 'all': count = HookStrategy.objects.filter(type__id=rule_type, created_by=user_id).update(enable=op) logger.info(_('Policy type {} operation success, total of {} Policy types').format(rule_type, count)) status = True + if hook_rule_type is not None and language_id is not None and scope == 'all': + users = self.get_auth_users(request.user) + user_ids = (user.id for user in users) + hook_type_ids = HookType.objects.filter( + language_id=language_id, + type=hook_rule_type).values_list('id', flat=True).all() + count = HookStrategy.objects.filter( + type__id__in=hook_type_ids, + created_by__in=user_ids).update(enable=op) + logger.info(_('total of {} Policy types').format(count)) + status = True elif rule_id is not None: status = self.set_strategy_status(strategy_id=rule_id, strategy_ids=None, user_id=user_id, enable_status=op)