Skip to content
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Generated by Django 4.1.13 on 2024-04-25 11:28

from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
('application', '0003_application_icon'),
]

operations = [
migrations.AddField(
model_name='applicationaccesstoken',
name='show_source',
field=models.BooleanField(default=False, verbose_name='是否显示知识来源'),
),
]
1 change: 1 addition & 0 deletions apps/application/models/api_key_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class ApplicationAccessToken(AppModelMixin):
white_list = ArrayField(verbose_name="白名单列表",
base_field=models.CharField(max_length=128, blank=True)
, default=list)
show_source = models.BooleanField(default=False, verbose_name="是否显示知识来源")

class Meta:
db_table = "application_access_token"
Expand Down
17 changes: 13 additions & 4 deletions apps/application/serializers/application_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from common.constants.authentication_type import AuthenticationType
from common.db.search import get_dynamics_model, native_search, native_page_search
from common.db.sql_execute import select_list
from common.exception.app_exception import AppApiException, NotFound404
from common.exception.app_exception import AppApiException, NotFound404, AppUnauthorizedFailed
from common.field.common import UploadedImageField
from common.util.field_message import ErrMessage
from common.util.file_util import get_file_content
Expand Down Expand Up @@ -170,7 +170,9 @@ class AccessTokenEditSerializer(serializers.Serializer):
white_list = serializers.ListSerializer(required=False, child=serializers.CharField(required=True,
error_messages=ErrMessage.char(
"白名单")),
error_messages=ErrMessage.list("白名单列表"))
error_messages=ErrMessage.list("白名单列表")),
show_source = serializers.BooleanField(required=False,
error_messages=ErrMessage.boolean("是否显示知识来源"))

def edit(self, instance: Dict, with_valid=True):
if with_valid:
Expand All @@ -190,6 +192,8 @@ def edit(self, instance: Dict, with_valid=True):
application_access_token.white_active = instance.get("white_active")
if 'white_list' in instance and instance.get('white_list') is not None:
application_access_token.white_list = instance.get('white_list')
if 'show_source' in instance and instance.get('show_source') is not None:
application_access_token.show_source = instance.get('show_source')
application_access_token.save()
return self.one(with_valid=False)

Expand All @@ -210,7 +214,8 @@ def one(self, with_valid=True):
"is_active": application_access_token.is_active,
'access_num': application_access_token.access_num,
'white_active': application_access_token.white_active,
'white_list': application_access_token.white_list
'white_list': application_access_token.white_list,
'show_source': application_access_token.show_source
}

class Authentication(serializers.Serializer):
Expand Down Expand Up @@ -474,8 +479,12 @@ def profile(self, with_valid=True):
self.is_valid()
application_id = self.data.get("application_id")
application = QuerySet(Application).get(id=application_id)
application_access_token = QuerySet(ApplicationAccessToken).filter(application_id=application.id).first()
if application_access_token is None:
raise AppUnauthorizedFailed(500, "非法用户")
return ApplicationSerializer.Query.reset_application(
ApplicationSerializer.ApplicationModel(application).data)
{**ApplicationSerializer.ApplicationModel(application).data,
'show_source': application_access_token.show_source})

def edit(self, instance: Dict, with_valid=True):
if with_valid:
Expand Down
21 changes: 16 additions & 5 deletions apps/application/serializers/chat_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from rest_framework import serializers

from application.models import Chat, Application, ApplicationDatasetMapping, VoteChoices, ChatRecord
from application.models.api_key_model import ApplicationAccessToken
from application.serializers.application_serializers import ModelDatasetAssociation, DatasetSettingSerializer, \
ModelSettingSerializer
from application.serializers.chat_message_serializers import ChatInfo
Expand Down Expand Up @@ -277,17 +278,27 @@ class Meta:
class ChatRecordSerializer(serializers.Serializer):
class Operate(serializers.Serializer):
chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话id"))

application_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("应用id"))
chat_record_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话记录id"))

def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
application_access_token = QuerySet(ApplicationAccessToken).filter(
application_id=self.data.get('application_id')).first()
if application_access_token is None:
raise AppApiException(500, '不存在的应用认证信息')
if not application_access_token.show_source:
raise AppApiException(500, '未开启显示知识来源')

def get_chat_record(self):
chat_record_id = self.data.get('chat_record_id')
chat_id = self.data.get('chat_id')
chat_info: ChatInfo = chat_cache.get(chat_id)
chat_record_list = [chat_record for chat_record in chat_info.chat_record_list if
chat_record.id == uuid.UUID(chat_record_id)]
if chat_record_list is not None and len(chat_record_list):
return chat_record_list[-1]
if chat_info is not None:
chat_record_list = [chat_record for chat_record in chat_info.chat_record_list if
chat_record.id == uuid.UUID(chat_record_id)]
if chat_record_list is not None and len(chat_record_list):
return chat_record_list[-1]
return QuerySet(ChatRecord).filter(id=chat_record_id, chat_id=chat_id).first()

def one(self, with_valid=True):
Expand Down
2 changes: 2 additions & 0 deletions apps/application/swagger_api/application_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,8 @@ def get_request_body_api():
'white_list': openapi.Schema(type=openapi.TYPE_ARRAY,
items=openapi.Schema(type=openapi.TYPE_STRING), title="白名单列表",
description="白名单列表"),
'show_source': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否显示知识来源",
description="是否显示知识来源"),
}
)

Expand Down
3 changes: 2 additions & 1 deletion apps/application/views/chat_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,8 @@ class Operate(APIView):
tags=["应用/对话日志"]
)
@has_permissions(
ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY],
ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY,
RoleConstants.APPLICATION_ACCESS_TOKEN],
[lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
dynamic_tag=keywords.get('application_id'))])
)
Expand Down
25 changes: 15 additions & 10 deletions ui/src/components/ai-chat/index.vue
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
<img src="@/assets/icon_robot.svg" style="width: 75%" alt="" />
</AppAvatar>
</div>

<div class="content">
<div class="flex" v-if="!item.answer_text">
<el-card
Expand All @@ -72,7 +73,9 @@

<el-card v-else shadow="always" class="dialog-card">
<MdRenderer :source="item.answer_text"></MdRenderer>
<div v-if="(id && !props.appId && item.write_ed) || log">
<div
v-if="(id && item.write_ed) || (props.data?.show_source && item.write_ed) || log"
>
<el-divider> <el-text type="info">知识来源</el-text> </el-divider>
<div class="mb-8">
<el-space wrap>
Expand Down Expand Up @@ -271,7 +274,7 @@ function openParagraph(row: any, id?: string) {
}

function quickProblemHandle(val: string) {
if (!props.log && !loading.value) {
if (!props.log && !loading.value && props.data?.name && props.data?.model_id) {
// inputValue.value = val
// nextTick(() => {
// quickInputRef.value?.focus()
Expand Down Expand Up @@ -488,7 +491,7 @@ function chatMessage(chat?: any, problem?: string, re_chat?: boolean) {
}
})
.then(() => {
return !props.appId && getSourceDetail(chat)
return (id || props.data?.show_source) && getSourceDetail(chat)
})
.finally(() => {
ChatManagement.close(chat.id)
Expand All @@ -505,14 +508,16 @@ function regenerationChart(item: chatType) {
}

function getSourceDetail(row: any) {
logApi.getRecordDetail(id, chartOpenId.value, row.record_id, loading).then((res) => {
const exclude_keys = ['answer_text', 'id']
Object.keys(res.data).forEach((key) => {
if (!exclude_keys.includes(key)) {
row[key] = res.data[key]
}
logApi
.getRecordDetail(id || props.appId, chartOpenId.value, row.record_id, loading)
.then((res) => {
const exclude_keys = ['answer_text', 'id']
Object.keys(res.data).forEach((key) => {
if (!exclude_keys.includes(key)) {
row[key] = res.data[key]
}
})
})
})
return true
}

Expand Down
7 changes: 7 additions & 0 deletions ui/src/views/application-overview/component/LimitDialog.vue
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
<template>
<el-dialog title="访问限制" v-model="dialogVisible">
<el-form label-position="top" ref="limitFormRef" :model="form">
<el-form-item label="显示知识来源" @click.prevent>
<el-switch size="small" v-model="form.show_source"></el-switch>
</el-form-item>
<el-form-item label="客户端提问限制">
<el-input-number
v-model="form.access_num"
Expand Down Expand Up @@ -51,6 +54,7 @@ const emit = defineEmits(['refresh'])

const limitFormRef = ref()
const form = ref<any>({
show_source: false,
access_num: 0,
white_active: true,
white_list: ''
Expand All @@ -62,6 +66,7 @@ const loading = ref(false)
watch(dialogVisible, (bool) => {
if (!bool) {
form.value = {
show_source: false,
access_num: 0,
white_active: true,
white_list: ''
Expand All @@ -70,6 +75,7 @@ watch(dialogVisible, (bool) => {
})

const open = (data: any) => {
form.value.show_source = data.show_source
form.value.access_num = data.access_num
form.value.white_active = data.white_active
form.value.white_list = data.white_list?.length ? data.white_list?.join('\n') : ''
Expand All @@ -81,6 +87,7 @@ const submit = async (formEl: FormInstance | undefined) => {
await formEl.validate((valid, fields) => {
if (valid) {
const obj = {
show_source: form.value.show_source,
white_list: form.value.white_list ? form.value.white_list.split('\n') : [],
white_active: form.value.white_active,
access_num: form.value.access_num
Expand Down