From 487485fbee2c9403b4f7df5713360aeb69f70a0d Mon Sep 17 00:00:00 2001 From: NarwhalChen Date: Sun, 5 Jan 2025 18:53:45 -0800 Subject: [PATCH 1/5] feat: adding chat isolation --- .../__tests__/test.chat-isolation.spec.ts | 97 +++++++++++++++++++ backend/src/chat/chat.model.ts | 4 +- backend/src/chat/chat.module.ts | 3 +- backend/src/chat/chat.resolver.ts | 10 -- backend/src/chat/chat.service.ts | 62 ++++++------ backend/src/chat/message.model.ts | 8 +- backend/src/common/model-provider/index.ts | 19 ++-- backend/src/common/model-provider/types.ts | 18 ++++ backend/src/main.ts | 2 +- llm-server/package.json | 2 +- llm-server/src/llm-provider.ts | 5 +- llm-server/src/main.ts | 17 +--- llm-server/src/model/llama-model-provider.ts | 8 +- llm-server/src/model/openai-model-provider.ts | 13 ++- llm-server/src/types.ts | 8 +- 15 files changed, 178 insertions(+), 98 deletions(-) create mode 100644 backend/src/chat/__tests__/test.chat-isolation.spec.ts diff --git a/backend/src/chat/__tests__/test.chat-isolation.spec.ts b/backend/src/chat/__tests__/test.chat-isolation.spec.ts new file mode 100644 index 00000000..fd4edafd --- /dev/null +++ b/backend/src/chat/__tests__/test.chat-isolation.spec.ts @@ -0,0 +1,97 @@ +// chat.service.spec.ts +import { Test, TestingModule } from '@nestjs/testing'; +import { ChatService } from '../chat.service'; +import { getRepositoryToken } from '@nestjs/typeorm'; +import { Chat } from '../chat.model'; +import { User } from 'src/user/user.model'; +import { Message, MessageRole } from 'src/chat/message.model'; +import { Repository } from 'typeorm'; +import { TypeOrmModule } from '@nestjs/typeorm'; +import { UserResolver } from 'src/user/user.resolver'; +import { AuthService } from 'src/auth/auth.service'; +import { UserService } from 'src/user/user.service'; +import { JwtService } from '@nestjs/jwt'; +import { JwtCacheService } from 'src/auth/jwt-cache.service'; +import { ConfigService } from '@nestjs/config'; +import { Menu } from 'src/auth/menu/menu.model'; +import { Role } from 'src/auth/role/role.model'; +import { RegisterUserInput } from 'src/user/dto/register-user.input'; +import { NewChatInput } from '../dto/chat.input'; +import { ModelProvider} from 'src/common/model-provider'; +import { HttpService } from '@nestjs/axios'; +import { MessageInterface } from 'src/common/model-provider/types'; + +describe('ChatService', () => { + let chatService: ChatService; + let userResolver: UserResolver; + let userService: UserService; + let mockedChatService: jest.Mocked>; + let modelProvider: ModelProvider; + let user: User; + let userid='1'; + + beforeAll(async()=>{ + const module: TestingModule = await Test.createTestingModule({ + imports:[ + TypeOrmModule.forRoot({ + type: 'sqlite', + database: '../../database.sqlite', + synchronize: true, + entities: ['../../' + '/**/*.model{.ts,.js}'], + }), + TypeOrmModule.forFeature([Chat, User, Menu, Role]), + ], + providers: [ + Repository, + ChatService, + AuthService, + UserService, + UserResolver, + JwtService, + JwtCacheService, + ConfigService, + ] + }).compile(); + chatService = module.get(ChatService); + userService = module.get(UserService); + userResolver = module.get(UserResolver); + + modelProvider = ModelProvider.getInstance(); + mockedChatService = module.get(getRepositoryToken(Chat)); + }) + it('should excute curd in chat service', async() => { + + try{ + user = await userResolver.registerUser({ + username: 'testuser', + password: 'securepassword', + email: 'testuser@example.com', + } as RegisterUserInput); + userid = user.id; + }catch(error){ + + } + const chat= await chatService.createChat(userid, {title: 'test'} as NewChatInput); + let chatId = chat.id; + console.log(await chatService.getChatHistory(chatId)); + + console.log(await chatService.saveMessage(chatId, 'Hello, this is a test message.', MessageRole.User)); + console.log(await chatService.saveMessage(chatId, 'Hello, hello, im gpt.', MessageRole.Model)); + + console.log(await chatService.saveMessage(chatId, 'write me the system prompt', MessageRole.User)); + + let history = await chatService.getChatHistory(chatId); + let messages = history.map((message) => { + return { + role: message.role, + content: message.content + } as MessageInterface; + }) + console.log(history); + console.log( + await modelProvider.chatSync({ + model: 'gpt-4o', + messages + })); + }) +}); \ No newline at end of file diff --git a/backend/src/chat/chat.model.ts b/backend/src/chat/chat.model.ts index 5fce9e47..a5e2fe82 100644 --- a/backend/src/chat/chat.model.ts +++ b/backend/src/chat/chat.model.ts @@ -38,8 +38,8 @@ export class Chat extends SystemBaseModel { @Column({ nullable: true }) title: string; - @Field(() => [Message], { nullable: true }) - @OneToMany(() => Message, (message) => message.chat, { cascade: true }) + @Field({ nullable: true }) + @Column('simple-json', { nullable: true, default: '[]' }) messages: Message[]; @ManyToOne(() => User, (user) => user.chats) diff --git a/backend/src/chat/chat.module.ts b/backend/src/chat/chat.module.ts index 149d3bfe..c2068bb3 100644 --- a/backend/src/chat/chat.module.ts +++ b/backend/src/chat/chat.module.ts @@ -12,6 +12,7 @@ import { ChatGuard } from '../guard/chat.guard'; import { AuthModule } from '../auth/auth.module'; import { UserService } from 'src/user/user.service'; import { PubSub } from 'graphql-subscriptions'; +import { ModelProvider } from 'src/common/model-provider'; @Module({ imports: [ @@ -30,6 +31,6 @@ import { PubSub } from 'graphql-subscriptions'; useValue: new PubSub(), }, ], - exports: [ChatService, ChatGuard], + exports: [ChatService, ChatGuard, ModelProvider], }) export class ChatModule {} diff --git a/backend/src/chat/chat.resolver.ts b/backend/src/chat/chat.resolver.ts index 57fc8bfd..e9e9a0fa 100644 --- a/backend/src/chat/chat.resolver.ts +++ b/backend/src/chat/chat.resolver.ts @@ -103,16 +103,6 @@ export class ChatResolver { const user = await this.userService.getUserChats(userId); return user ? user.chats : []; } - - @JWTAuth() - @Query(() => Message, { nullable: true }) - async getMessageDetail( - @GetUserIdFromToken() userId: string, - @Args('messageId') messageId: string, - ): Promise { - return this.chatService.getMessageById(messageId); - } - // To do: message need a update resolver @JWTAuth() diff --git a/backend/src/chat/chat.service.ts b/backend/src/chat/chat.service.ts index 8f5bc492..d5282724 100644 --- a/backend/src/chat/chat.service.ts +++ b/backend/src/chat/chat.service.ts @@ -16,10 +16,9 @@ import { ModelProvider } from 'src/common/model-provider'; @Injectable() export class ChatProxyService { private readonly logger = new Logger('ChatProxyService'); - private models: ModelProvider; - constructor(private httpService: HttpService) { - this.models = ModelProvider.getInstance(); + constructor(private httpService: HttpService, private readonly models: ModelProvider) { + } streamChat( @@ -39,33 +38,34 @@ export class ChatService { @InjectRepository(Chat) private chatRepository: Repository, @InjectRepository(User) - private userRepository: Repository, - @InjectRepository(Message) - private messageRepository: Repository, + private userRepository: Repository ) {} async getChatHistory(chatId: string): Promise { const chat = await this.chatRepository.findOne({ where: { id: chatId, isDeleted: false }, - relations: ['messages'], }); + console.log(chat); + if (chat && chat.messages) { // Sort messages by createdAt in ascending order chat.messages = chat.messages .filter((message) => !message.isDeleted) - .sort((a, b) => a.createdAt.getTime() - b.createdAt.getTime()); + .map((message) => { + if (!(message.createdAt instanceof Date)) { + message.createdAt = new Date(message.createdAt); + } + return message; + }) + .sort((a, b) => { + return a.createdAt.getTime() - b.createdAt.getTime(); + }); } return chat ? chat.messages : []; } - async getMessageById(messageId: string): Promise { - return await this.messageRepository.findOne({ - where: { id: messageId, isDeleted: false }, - }); - } - async getChatDetails(chatId: string): Promise { const chat = await this.chatRepository.findOne({ where: { id: chatId, isDeleted: false }, @@ -111,12 +111,6 @@ export class ChatService { chat.isActive = false; await this.chatRepository.save(chat); - // Soft delete all associated messages - await this.messageRepository.update( - { chat: { id: chatId }, isDeleted: false }, - { isDeleted: true, isActive: false }, - ); - return true; } return false; @@ -125,13 +119,8 @@ export class ChatService { async clearChatHistory(chatId: string): Promise { const chat = await this.chatRepository.findOne({ where: { id: chatId, isDeleted: false }, - relations: ['messages'], }); if (chat) { - await this.messageRepository.update( - { chat: { id: chatId }, isDeleted: false }, - { isDeleted: true, isActive: false }, - ); chat.updatedAt = new Date(); await this.chatRepository.save(chat); return true; @@ -161,21 +150,24 @@ export class ChatService { ): Promise { // Find the chat instance const chat = await this.chatRepository.findOne({ where: { id: chatId } }); + + const message = { + id: `${chat.id}/${chat.messages.length}`, + content: messageContent, + role: role, + createdAt: new Date(), + updatedAt: new Date(), + isActive: true, + isDeleted: false, + }; //if the chat id not exist, dont save this messages if (!chat) { return null; } - - // Create a new message associated with the chat - const message = this.messageRepository.create({ - content: messageContent, - role: role, - chat, - createdAt: new Date(), - }); - + chat.messages.push(message); + await this.chatRepository.save(chat); // Save the message to the database - return await this.messageRepository.save(message); + return message; } async getChatWithUser(chatId: string): Promise { diff --git a/backend/src/chat/message.model.ts b/backend/src/chat/message.model.ts index b4adf4e6..05efb478 100644 --- a/backend/src/chat/message.model.ts +++ b/backend/src/chat/message.model.ts @@ -17,8 +17,8 @@ import { Chat } from 'src/chat/chat.model'; import { SystemBaseModel } from 'src/system-base-model/system-base.model'; export enum MessageRole { - User = 'User', - Model = 'Model', + User = 'user', + Model = 'assistant', } registerEnumType(MessageRole, { @@ -43,8 +43,4 @@ export class Message extends SystemBaseModel { @Field({ nullable: true }) @Column({ nullable: true }) modelId?: string; - - @ManyToOne(() => Chat, (chat) => chat.messages) - @JoinColumn({ name: 'chatId' }) - chat: Chat; } diff --git a/backend/src/common/model-provider/index.ts b/backend/src/common/model-provider/index.ts index 9991c96c..78843f13 100644 --- a/backend/src/common/model-provider/index.ts +++ b/backend/src/common/model-provider/index.ts @@ -1,11 +1,9 @@ import { Logger } from '@nestjs/common'; import { HttpService } from '@nestjs/axios'; import { Subject, Subscription } from 'rxjs'; +import { MessageRole } from 'src/chat/message.model'; +import { LLMInterface, ModelProviderConfig } from './types'; -export interface ModelProviderConfig { - endpoint: string; - defaultModel?: string; -} export interface CustomAsyncIterableIterator extends AsyncIterator { [Symbol.asyncIterator](): AsyncIterableIterator; @@ -55,9 +53,7 @@ export class ModelProvider { * Synchronous chat method that returns a complete response */ async chatSync( - input: ChatInput | string, - model: string, - chatId?: string, + input: LLMInterface, ): Promise { while (this.currentRequests >= this.concurrentLimit) { await new Promise((resolve) => setTimeout(resolve, 100)); @@ -70,7 +66,6 @@ export class ModelProvider { `Starting request ${requestId}. Active: ${this.currentRequests}/${this.concurrentLimit}`, ); - const normalizedInput = this.normalizeChatInput(input); let resolvePromise: (value: string) => void; let rejectPromise: (error: any) => void; @@ -113,7 +108,7 @@ export class ModelProvider { promise, }); - this.processRequest(normalizedInput, model, chatId, requestId, stream); + this.processRequest(input, requestId, stream); return promise; } @@ -155,9 +150,7 @@ export class ModelProvider { } private async processRequest( - input: ChatInput, - model: string, - chatId: string | undefined, + input: LLMInterface, requestId: string, stream: Subject, ) { @@ -167,7 +160,7 @@ export class ModelProvider { const response = await this.httpService .post( `${this.config.endpoint}/chat/completion`, - this.createRequestPayload(input, model, chatId), + input, { responseType: 'stream', headers: { 'Content-Type': 'application/json' }, diff --git a/backend/src/common/model-provider/types.ts b/backend/src/common/model-provider/types.ts index 8c649851..ebe69d14 100644 --- a/backend/src/common/model-provider/types.ts +++ b/backend/src/common/model-provider/types.ts @@ -1,3 +1,5 @@ +import { MessageRole } from "src/chat/message.model"; + export interface ModelChatStreamConfig { endpoint: string; model?: string; @@ -5,3 +7,19 @@ export interface ModelChatStreamConfig { export type CustomAsyncIterableIterator = AsyncIterator & { [Symbol.asyncIterator](): AsyncIterableIterator; }; + +export interface ModelProviderConfig { + endpoint: string; + defaultModel?: string; +} + +export interface MessageInterface { + content: string; + role: MessageRole; +} + +export interface LLMInterface { + model: string; + messages: MessageInterface[]; +} + diff --git a/backend/src/main.ts b/backend/src/main.ts index ea049c9e..6893788e 100644 --- a/backend/src/main.ts +++ b/backend/src/main.ts @@ -17,7 +17,7 @@ async function bootstrap() { 'Access-Control-Allow-Credentials', ], }); - await downloadAllModels(); + // await downloadAllModels(); await app.listen(process.env.PORT ?? 3000); } diff --git a/llm-server/package.json b/llm-server/package.json index 6b67f9a4..dd5f7f28 100644 --- a/llm-server/package.json +++ b/llm-server/package.json @@ -5,7 +5,7 @@ "type": "module", "scripts": { "start": "node --experimental-specifier-resolution=node --loader ts-node/esm src/main.ts", - "dev": "nodemon --watch 'src/**/*.ts' --exec 'node --experimental-specifier-resolution=node --loader ts-node/esm' src/main.ts", + "dev": "nodemon --watch \"src/**/*.ts\" --exec \"node --experimental-specifier-resolution=node --loader ts-node/esm\" src/main.ts", "dev:backend": "pnpm dev", "build": "tsc", "serve": "node --experimental-specifier-resolution=node dist/main.js", diff --git a/llm-server/src/llm-provider.ts b/llm-server/src/llm-provider.ts index 0982ebdf..b4a8d142 100644 --- a/llm-server/src/llm-provider.ts +++ b/llm-server/src/llm-provider.ts @@ -10,11 +10,8 @@ import { } from './types'; import { ModelProvider } from './model/model-provider'; -export interface ChatMessageInput { - content: string; -} -export interface ChatMessage { +export interface ChatMessageInput { role: string; content: string; } diff --git a/llm-server/src/main.ts b/llm-server/src/main.ts index 80ec476a..bb2c8e72 100644 --- a/llm-server/src/main.ts +++ b/llm-server/src/main.ts @@ -1,4 +1,4 @@ -import { Logger } from '@nestjs/common'; +import { Logger, Module } from '@nestjs/common'; import { ChatMessageInput, LLMProvider } from './llm-provider'; import express, { Express, Request, Response } from 'express'; import { GenerateMessageParams } from './types'; @@ -27,23 +27,16 @@ export class App { private async handleChatRequest(req: Request, res: Response): Promise { this.logger.log('Received chat request.'); try { - this.logger.debug(JSON.stringify(req.body)); - const { content, model } = req.body as ChatMessageInput & { - model: string; - }; + const input = req.body as GenerateMessageParams; + const model = input.model; this.logger.log(`Received chat request for model: ${model}`); - const params: GenerateMessageParams = { - model: model || 'gpt-3.5-turbo', - message: content, - role: 'user', - }; - this.logger.debug(`Request content: "${content}"`); + this.logger.debug(`Request messages: "${input.messages}"`); res.setHeader('Content-Type', 'text/event-stream'); res.setHeader('Cache-Control', 'no-cache'); res.setHeader('Connection', 'keep-alive'); this.logger.debug('Response headers set for streaming.'); - await this.llmProvider.generateStreamingResponse(params, res); + await this.llmProvider.generateStreamingResponse(input, res); } catch (error) { this.logger.error('Error in chat endpoint:', error); res.status(500).json({ error: 'Internal server error' }); diff --git a/llm-server/src/model/llama-model-provider.ts b/llm-server/src/model/llama-model-provider.ts index 77fa89ae..ce09ec78 100644 --- a/llm-server/src/model/llama-model-provider.ts +++ b/llm-server/src/model/llama-model-provider.ts @@ -36,7 +36,7 @@ export class LlamaModelProvider extends ModelProvider { } async generateStreamingResponse( - { model, message, role = 'user' }: GenerateMessageParams, + { model, messages}: GenerateMessageParams, res: Response, ): Promise { this.logger.log('Generating streaming response with Llama...'); @@ -50,13 +50,13 @@ export class LlamaModelProvider extends ModelProvider { // Get the system prompt based on the model const systemPrompt = systemPrompts['codefox-basic']?.systemPrompt || ''; - const messages = [ + const allMessage = [ { role: 'system', content: systemPrompt }, - { role: role as 'user' | 'system' | 'assistant', content: message }, + ...messages, ]; // Convert messages array to a single formatted string for Llama - const formattedPrompt = messages + const formattedPrompt = allMessage .map(({ role, content }) => `${role}: ${content}`) .join('\n'); diff --git a/llm-server/src/model/openai-model-provider.ts b/llm-server/src/model/openai-model-provider.ts index fc845a69..f7e2d54f 100644 --- a/llm-server/src/model/openai-model-provider.ts +++ b/llm-server/src/model/openai-model-provider.ts @@ -89,7 +89,7 @@ export class OpenAIModelProvider { private async processRequest(request: QueuedRequest): Promise { const { params, res, retries } = request; - const { model, message, role = 'user' } = params; + const { model, messages} = params as {model:string, messages:ChatCompletionMessageParam[]}; this.logger.log(`Processing request (attempt ${retries + 1})`); const startTime = Date.now(); @@ -103,14 +103,14 @@ export class OpenAIModelProvider { const systemPrompt = systemPrompts[this.options.systemPromptKey]?.systemPrompt || ''; - const messages: ChatCompletionMessageParam[] = [ + const allMessages: ChatCompletionMessageParam[] = [ { role: 'system', content: systemPrompt }, - { role: role as 'user' | 'system' | 'assistant', content: message }, + ...messages, ]; - + console.log(allMessages); const stream = await this.openai.chat.completions.create({ model, - messages, + messages: allMessages, stream: true, }); @@ -171,8 +171,7 @@ export class OpenAIModelProvider { error: errorResponse, params: { model, - messageLength: message.length, - role, + messageLength: messages.length, }, }); diff --git a/llm-server/src/types.ts b/llm-server/src/types.ts index 2f1bc1df..679996a9 100644 --- a/llm-server/src/types.ts +++ b/llm-server/src/types.ts @@ -1,7 +1,11 @@ +interface MessageInterface { + role: string; + content: string; +} + export interface GenerateMessageParams { model: string; // Model to use, e.g., 'gpt-3.5-turbo' - message: string; // User's message or query - role?: 'user' | 'system' | 'assistant' | 'tool' | 'function'; // Optional role + messages: MessageInterface[]; // User's message or query } // types.ts From c2df443c456f4ed68130dda99885cca5b55f0542 Mon Sep 17 00:00:00 2001 From: "autofix-ci[bot]" <114827586+autofix-ci[bot]@users.noreply.github.com> Date: Mon, 6 Jan 2025 03:02:31 +0000 Subject: [PATCH 2/5] [autofix.ci] apply automated fixes --- .../__tests__/test.chat-isolation.spec.ts | 74 ++++++++++++------- backend/src/chat/chat.service.ts | 14 ++-- backend/src/common/model-provider/index.ts | 18 ++--- backend/src/common/model-provider/types.ts | 3 +- llm-server/src/llm-provider.ts | 1 - llm-server/src/model/llama-model-provider.ts | 7 +- llm-server/src/model/openai-model-provider.ts | 7 +- 7 files changed, 66 insertions(+), 58 deletions(-) diff --git a/backend/src/chat/__tests__/test.chat-isolation.spec.ts b/backend/src/chat/__tests__/test.chat-isolation.spec.ts index fd4edafd..7495ec58 100644 --- a/backend/src/chat/__tests__/test.chat-isolation.spec.ts +++ b/backend/src/chat/__tests__/test.chat-isolation.spec.ts @@ -17,7 +17,7 @@ import { Menu } from 'src/auth/menu/menu.model'; import { Role } from 'src/auth/role/role.model'; import { RegisterUserInput } from 'src/user/dto/register-user.input'; import { NewChatInput } from '../dto/chat.input'; -import { ModelProvider} from 'src/common/model-provider'; +import { ModelProvider } from 'src/common/model-provider'; import { HttpService } from '@nestjs/axios'; import { MessageInterface } from 'src/common/model-provider/types'; @@ -28,11 +28,11 @@ describe('ChatService', () => { let mockedChatService: jest.Mocked>; let modelProvider: ModelProvider; let user: User; - let userid='1'; + let userid = '1'; - beforeAll(async()=>{ + beforeAll(async () => { const module: TestingModule = await Test.createTestingModule({ - imports:[ + imports: [ TypeOrmModule.forRoot({ type: 'sqlite', database: '../../database.sqlite', @@ -50,48 +50,66 @@ describe('ChatService', () => { JwtService, JwtCacheService, ConfigService, - ] + ], }).compile(); chatService = module.get(ChatService); userService = module.get(UserService); userResolver = module.get(UserResolver); - + modelProvider = ModelProvider.getInstance(); mockedChatService = module.get(getRepositoryToken(Chat)); - }) - it('should excute curd in chat service', async() => { - - try{ + }); + it('should excute curd in chat service', async () => { + try { user = await userResolver.registerUser({ username: 'testuser', password: 'securepassword', email: 'testuser@example.com', } as RegisterUserInput); userid = user.id; - }catch(error){ - - } - const chat= await chatService.createChat(userid, {title: 'test'} as NewChatInput); - let chatId = chat.id; + } catch (error) {} + const chat = await chatService.createChat(userid, { + title: 'test', + } as NewChatInput); + const chatId = chat.id; console.log(await chatService.getChatHistory(chatId)); - - console.log(await chatService.saveMessage(chatId, 'Hello, this is a test message.', MessageRole.User)); - console.log(await chatService.saveMessage(chatId, 'Hello, hello, im gpt.', MessageRole.Model)); - - console.log(await chatService.saveMessage(chatId, 'write me the system prompt', MessageRole.User)); - let history = await chatService.getChatHistory(chatId); - let messages = history.map((message) => { + console.log( + await chatService.saveMessage( + chatId, + 'Hello, this is a test message.', + MessageRole.User, + ), + ); + console.log( + await chatService.saveMessage( + chatId, + 'Hello, hello, im gpt.', + MessageRole.Model, + ), + ); + + console.log( + await chatService.saveMessage( + chatId, + 'write me the system prompt', + MessageRole.User, + ), + ); + + const history = await chatService.getChatHistory(chatId); + const messages = history.map((message) => { return { role: message.role, - content: message.content + content: message.content, } as MessageInterface; - }) + }); console.log(history); console.log( await modelProvider.chatSync({ model: 'gpt-4o', - messages - })); - }) -}); \ No newline at end of file + messages, + }), + ); + }); +}); diff --git a/backend/src/chat/chat.service.ts b/backend/src/chat/chat.service.ts index d5282724..68f91c33 100644 --- a/backend/src/chat/chat.service.ts +++ b/backend/src/chat/chat.service.ts @@ -17,9 +17,10 @@ import { ModelProvider } from 'src/common/model-provider'; export class ChatProxyService { private readonly logger = new Logger('ChatProxyService'); - constructor(private httpService: HttpService, private readonly models: ModelProvider) { - - } + constructor( + private httpService: HttpService, + private readonly models: ModelProvider, + ) {} streamChat( input: ChatInput, @@ -38,7 +39,7 @@ export class ChatService { @InjectRepository(Chat) private chatRepository: Repository, @InjectRepository(User) - private userRepository: Repository + private userRepository: Repository, ) {} async getChatHistory(chatId: string): Promise { @@ -46,7 +47,6 @@ export class ChatService { where: { id: chatId, isDeleted: false }, }); console.log(chat); - if (chat && chat.messages) { // Sort messages by createdAt in ascending order @@ -150,13 +150,13 @@ export class ChatService { ): Promise { // Find the chat instance const chat = await this.chatRepository.findOne({ where: { id: chatId } }); - + const message = { id: `${chat.id}/${chat.messages.length}`, content: messageContent, role: role, createdAt: new Date(), - updatedAt: new Date(), + updatedAt: new Date(), isActive: true, isDeleted: false, }; diff --git a/backend/src/common/model-provider/index.ts b/backend/src/common/model-provider/index.ts index 78843f13..d0f03705 100644 --- a/backend/src/common/model-provider/index.ts +++ b/backend/src/common/model-provider/index.ts @@ -4,7 +4,6 @@ import { Subject, Subscription } from 'rxjs'; import { MessageRole } from 'src/chat/message.model'; import { LLMInterface, ModelProviderConfig } from './types'; - export interface CustomAsyncIterableIterator extends AsyncIterator { [Symbol.asyncIterator](): AsyncIterableIterator; } @@ -52,9 +51,7 @@ export class ModelProvider { /** * Synchronous chat method that returns a complete response */ - async chatSync( - input: LLMInterface, - ): Promise { + async chatSync(input: LLMInterface): Promise { while (this.currentRequests >= this.concurrentLimit) { await new Promise((resolve) => setTimeout(resolve, 100)); } @@ -66,7 +63,6 @@ export class ModelProvider { `Starting request ${requestId}. Active: ${this.currentRequests}/${this.concurrentLimit}`, ); - let resolvePromise: (value: string) => void; let rejectPromise: (error: any) => void; @@ -158,14 +154,10 @@ export class ModelProvider { try { const response = await this.httpService - .post( - `${this.config.endpoint}/chat/completion`, - input, - { - responseType: 'stream', - headers: { 'Content-Type': 'application/json' }, - }, - ) + .post(`${this.config.endpoint}/chat/completion`, input, { + responseType: 'stream', + headers: { 'Content-Type': 'application/json' }, + }) .toPromise(); let buffer = ''; diff --git a/backend/src/common/model-provider/types.ts b/backend/src/common/model-provider/types.ts index ebe69d14..bdfb66ec 100644 --- a/backend/src/common/model-provider/types.ts +++ b/backend/src/common/model-provider/types.ts @@ -1,4 +1,4 @@ -import { MessageRole } from "src/chat/message.model"; +import { MessageRole } from 'src/chat/message.model'; export interface ModelChatStreamConfig { endpoint: string; @@ -22,4 +22,3 @@ export interface LLMInterface { model: string; messages: MessageInterface[]; } - diff --git a/llm-server/src/llm-provider.ts b/llm-server/src/llm-provider.ts index b4a8d142..dc53a22c 100644 --- a/llm-server/src/llm-provider.ts +++ b/llm-server/src/llm-provider.ts @@ -10,7 +10,6 @@ import { } from './types'; import { ModelProvider } from './model/model-provider'; - export interface ChatMessageInput { role: string; content: string; diff --git a/llm-server/src/model/llama-model-provider.ts b/llm-server/src/model/llama-model-provider.ts index ce09ec78..6aafc787 100644 --- a/llm-server/src/model/llama-model-provider.ts +++ b/llm-server/src/model/llama-model-provider.ts @@ -36,7 +36,7 @@ export class LlamaModelProvider extends ModelProvider { } async generateStreamingResponse( - { model, messages}: GenerateMessageParams, + { model, messages }: GenerateMessageParams, res: Response, ): Promise { this.logger.log('Generating streaming response with Llama...'); @@ -50,10 +50,7 @@ export class LlamaModelProvider extends ModelProvider { // Get the system prompt based on the model const systemPrompt = systemPrompts['codefox-basic']?.systemPrompt || ''; - const allMessage = [ - { role: 'system', content: systemPrompt }, - ...messages, - ]; + const allMessage = [{ role: 'system', content: systemPrompt }, ...messages]; // Convert messages array to a single formatted string for Llama const formattedPrompt = allMessage diff --git a/llm-server/src/model/openai-model-provider.ts b/llm-server/src/model/openai-model-provider.ts index f7e2d54f..85d7bd5d 100644 --- a/llm-server/src/model/openai-model-provider.ts +++ b/llm-server/src/model/openai-model-provider.ts @@ -89,7 +89,10 @@ export class OpenAIModelProvider { private async processRequest(request: QueuedRequest): Promise { const { params, res, retries } = request; - const { model, messages} = params as {model:string, messages:ChatCompletionMessageParam[]}; + const { model, messages } = params as { + model: string; + messages: ChatCompletionMessageParam[]; + }; this.logger.log(`Processing request (attempt ${retries + 1})`); const startTime = Date.now(); @@ -105,7 +108,7 @@ export class OpenAIModelProvider { systemPrompts[this.options.systemPromptKey]?.systemPrompt || ''; const allMessages: ChatCompletionMessageParam[] = [ { role: 'system', content: systemPrompt }, - ...messages, + ...messages, ]; console.log(allMessages); const stream = await this.openai.chat.completions.create({ From 3e12625bc8e5ac1d6330d83f6ed79cff4336695b Mon Sep 17 00:00:00 2001 From: NarwhalChen Date: Sun, 5 Jan 2025 21:58:07 -0800 Subject: [PATCH 3/5] resolve conversation for commit 487485f --- ...olation.spec.ts => chat-isolation.spec.ts} | 0 backend/src/common/model-provider/index.ts | 26 +++++++------------ backend/src/common/model-provider/types.ts | 2 +- backend/src/main.ts | 2 +- 4 files changed, 11 insertions(+), 19 deletions(-) rename backend/src/chat/__tests__/{test.chat-isolation.spec.ts => chat-isolation.spec.ts} (100%) diff --git a/backend/src/chat/__tests__/test.chat-isolation.spec.ts b/backend/src/chat/__tests__/chat-isolation.spec.ts similarity index 100% rename from backend/src/chat/__tests__/test.chat-isolation.spec.ts rename to backend/src/chat/__tests__/chat-isolation.spec.ts diff --git a/backend/src/common/model-provider/index.ts b/backend/src/common/model-provider/index.ts index 78843f13..35a6c9ed 100644 --- a/backend/src/common/model-provider/index.ts +++ b/backend/src/common/model-provider/index.ts @@ -2,7 +2,7 @@ import { Logger } from '@nestjs/common'; import { HttpService } from '@nestjs/axios'; import { Subject, Subscription } from 'rxjs'; import { MessageRole } from 'src/chat/message.model'; -import { LLMInterface, ModelProviderConfig } from './types'; +import { ChatInput, ModelProviderConfig } from './types'; export interface CustomAsyncIterableIterator extends AsyncIterator { @@ -53,7 +53,7 @@ export class ModelProvider { * Synchronous chat method that returns a complete response */ async chatSync( - input: LLMInterface, + input: ChatInput, ): Promise { while (this.currentRequests >= this.concurrentLimit) { await new Promise((resolve) => setTimeout(resolve, 100)); @@ -120,7 +120,7 @@ export class ModelProvider { model: string, chatId?: string, ): CustomAsyncIterableIterator { - const chatInput = this.normalizeChatInput(input); + const chatInput = this.normalizeChatInput(input, model); const selectedModel = model || this.config.defaultModel; if (!selectedModel) { @@ -150,7 +150,7 @@ export class ModelProvider { } private async processRequest( - input: LLMInterface, + input: ChatInput, requestId: string, stream: Subject, ) { @@ -427,8 +427,11 @@ export class ModelProvider { ); } - private normalizeChatInput(input: ChatInput | string): ChatInput { - return typeof input === 'string' ? { content: input } : input; + private normalizeChatInput(input: ChatInput | string, model: string): ChatInput { + return typeof input === 'string' ? { model, messages:[{ + content: input, + role: MessageRole.User, + }] } : input; } public async fetchModelsName() { @@ -461,17 +464,6 @@ export class ChatCompletionChunk { status: StreamStatus; } -export interface ChatInput { - content: string; - attachments?: Array<{ - type: string; - content: string | Buffer; - name?: string; - }>; - contextLength?: number; - temperature?: number; -} - class ChatCompletionDelta { content?: string; } diff --git a/backend/src/common/model-provider/types.ts b/backend/src/common/model-provider/types.ts index ebe69d14..fadb6f77 100644 --- a/backend/src/common/model-provider/types.ts +++ b/backend/src/common/model-provider/types.ts @@ -18,7 +18,7 @@ export interface MessageInterface { role: MessageRole; } -export interface LLMInterface { +export interface ChatInput { model: string; messages: MessageInterface[]; } diff --git a/backend/src/main.ts b/backend/src/main.ts index 6893788e..ea049c9e 100644 --- a/backend/src/main.ts +++ b/backend/src/main.ts @@ -17,7 +17,7 @@ async function bootstrap() { 'Access-Control-Allow-Credentials', ], }); - // await downloadAllModels(); + await downloadAllModels(); await app.listen(process.env.PORT ?? 3000); } From 8b1a1c9bf888b3cf6763b6838b579006b88cb70f Mon Sep 17 00:00:00 2001 From: NarwhalChen Date: Sun, 5 Jan 2025 22:59:43 -0800 Subject: [PATCH 4/5] feat: changing test to in memory mode --- backend/src/chat/__tests__/chat-isolation.spec.ts | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/backend/src/chat/__tests__/chat-isolation.spec.ts b/backend/src/chat/__tests__/chat-isolation.spec.ts index fd4edafd..7ff9d024 100644 --- a/backend/src/chat/__tests__/chat-isolation.spec.ts +++ b/backend/src/chat/__tests__/chat-isolation.spec.ts @@ -35,14 +35,13 @@ describe('ChatService', () => { imports:[ TypeOrmModule.forRoot({ type: 'sqlite', - database: '../../database.sqlite', + database: ':memory:', synchronize: true, - entities: ['../../' + '/**/*.model{.ts,.js}'], + entities: [Chat, User, Menu, Role], }), TypeOrmModule.forFeature([Chat, User, Menu, Role]), ], providers: [ - Repository, ChatService, AuthService, UserService, From 5b6bd1715a4be57f512a050fdd888bfdbb7dd7bb Mon Sep 17 00:00:00 2001 From: "autofix-ci[bot]" <114827586+autofix-ci[bot]@users.noreply.github.com> Date: Mon, 6 Jan 2025 07:02:33 +0000 Subject: [PATCH 5/5] [autofix.ci] apply automated fixes --- backend/src/common/model-provider/index.ts | 24 ++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/backend/src/common/model-provider/index.ts b/backend/src/common/model-provider/index.ts index 3cefee76..0e81e594 100644 --- a/backend/src/common/model-provider/index.ts +++ b/backend/src/common/model-provider/index.ts @@ -51,9 +51,7 @@ export class ModelProvider { /** * Synchronous chat method that returns a complete response */ - async chatSync( - input: ChatInput, - ): Promise { + async chatSync(input: ChatInput): Promise { while (this.currentRequests >= this.concurrentLimit) { await new Promise((resolve) => setTimeout(resolve, 100)); } @@ -421,11 +419,21 @@ export class ModelProvider { ); } - private normalizeChatInput(input: ChatInput | string, model: string): ChatInput { - return typeof input === 'string' ? { model, messages:[{ - content: input, - role: MessageRole.User, - }] } : input; + private normalizeChatInput( + input: ChatInput | string, + model: string, + ): ChatInput { + return typeof input === 'string' + ? { + model, + messages: [ + { + content: input, + role: MessageRole.User, + }, + ], + } + : input; } public async fetchModelsName() {