diff --git a/backend/src/chat/__tests__/chat-isolation.spec.ts b/backend/src/chat/__tests__/chat-isolation.spec.ts new file mode 100644 index 00000000..1fe760c7 --- /dev/null +++ b/backend/src/chat/__tests__/chat-isolation.spec.ts @@ -0,0 +1,114 @@ +// chat.service.spec.ts +import { Test, TestingModule } from '@nestjs/testing'; +import { ChatService } from '../chat.service'; +import { getRepositoryToken } from '@nestjs/typeorm'; +import { Chat } from '../chat.model'; +import { User } from 'src/user/user.model'; +import { Message, MessageRole } from 'src/chat/message.model'; +import { Repository } from 'typeorm'; +import { TypeOrmModule } from '@nestjs/typeorm'; +import { UserResolver } from 'src/user/user.resolver'; +import { AuthService } from 'src/auth/auth.service'; +import { UserService } from 'src/user/user.service'; +import { JwtService } from '@nestjs/jwt'; +import { JwtCacheService } from 'src/auth/jwt-cache.service'; +import { ConfigService } from '@nestjs/config'; +import { Menu } from 'src/auth/menu/menu.model'; +import { Role } from 'src/auth/role/role.model'; +import { RegisterUserInput } from 'src/user/dto/register-user.input'; +import { NewChatInput } from '../dto/chat.input'; +import { ModelProvider } from 'src/common/model-provider'; +import { HttpService } from '@nestjs/axios'; +import { MessageInterface } from 'src/common/model-provider/types'; + +describe('ChatService', () => { + let chatService: ChatService; + let userResolver: UserResolver; + let userService: UserService; + let mockedChatService: jest.Mocked>; + let modelProvider: ModelProvider; + let user: User; + let userid = '1'; + + beforeAll(async () => { + const module: TestingModule = await Test.createTestingModule({ + imports: [ + TypeOrmModule.forRoot({ + type: 'sqlite', + database: ':memory:', + synchronize: true, + entities: [Chat, User, Menu, Role], + }), + TypeOrmModule.forFeature([Chat, User, Menu, Role]), + ], + providers: [ + ChatService, + AuthService, + UserService, + UserResolver, + JwtService, + JwtCacheService, + ConfigService, + ], + }).compile(); + chatService = module.get(ChatService); + userService = module.get(UserService); + userResolver = module.get(UserResolver); + + modelProvider = ModelProvider.getInstance(); + mockedChatService = module.get(getRepositoryToken(Chat)); + }); + it('should excute curd in chat service', async () => { + try { + user = await userResolver.registerUser({ + username: 'testuser', + password: 'securepassword', + email: 'testuser@example.com', + } as RegisterUserInput); + userid = user.id; + } catch (error) {} + const chat = await chatService.createChat(userid, { + title: 'test', + } as NewChatInput); + const chatId = chat.id; + console.log(await chatService.getChatHistory(chatId)); + + console.log( + await chatService.saveMessage( + chatId, + 'Hello, this is a test message.', + MessageRole.User, + ), + ); + console.log( + await chatService.saveMessage( + chatId, + 'Hello, hello, im gpt.', + MessageRole.Model, + ), + ); + + console.log( + await chatService.saveMessage( + chatId, + 'write me the system prompt', + MessageRole.User, + ), + ); + + const history = await chatService.getChatHistory(chatId); + const messages = history.map((message) => { + return { + role: message.role, + content: message.content, + } as MessageInterface; + }); + console.log(history); + console.log( + await modelProvider.chatSync({ + model: 'gpt-4o', + messages, + }), + ); + }); +}); diff --git a/backend/src/chat/chat.model.ts b/backend/src/chat/chat.model.ts index 5fce9e47..a5e2fe82 100644 --- a/backend/src/chat/chat.model.ts +++ b/backend/src/chat/chat.model.ts @@ -38,8 +38,8 @@ export class Chat extends SystemBaseModel { @Column({ nullable: true }) title: string; - @Field(() => [Message], { nullable: true }) - @OneToMany(() => Message, (message) => message.chat, { cascade: true }) + @Field({ nullable: true }) + @Column('simple-json', { nullable: true, default: '[]' }) messages: Message[]; @ManyToOne(() => User, (user) => user.chats) diff --git a/backend/src/chat/chat.module.ts b/backend/src/chat/chat.module.ts index 149d3bfe..c2068bb3 100644 --- a/backend/src/chat/chat.module.ts +++ b/backend/src/chat/chat.module.ts @@ -12,6 +12,7 @@ import { ChatGuard } from '../guard/chat.guard'; import { AuthModule } from '../auth/auth.module'; import { UserService } from 'src/user/user.service'; import { PubSub } from 'graphql-subscriptions'; +import { ModelProvider } from 'src/common/model-provider'; @Module({ imports: [ @@ -30,6 +31,6 @@ import { PubSub } from 'graphql-subscriptions'; useValue: new PubSub(), }, ], - exports: [ChatService, ChatGuard], + exports: [ChatService, ChatGuard, ModelProvider], }) export class ChatModule {} diff --git a/backend/src/chat/chat.resolver.ts b/backend/src/chat/chat.resolver.ts index 57fc8bfd..e9e9a0fa 100644 --- a/backend/src/chat/chat.resolver.ts +++ b/backend/src/chat/chat.resolver.ts @@ -103,16 +103,6 @@ export class ChatResolver { const user = await this.userService.getUserChats(userId); return user ? user.chats : []; } - - @JWTAuth() - @Query(() => Message, { nullable: true }) - async getMessageDetail( - @GetUserIdFromToken() userId: string, - @Args('messageId') messageId: string, - ): Promise { - return this.chatService.getMessageById(messageId); - } - // To do: message need a update resolver @JWTAuth() diff --git a/backend/src/chat/chat.service.ts b/backend/src/chat/chat.service.ts index 8f5bc492..68f91c33 100644 --- a/backend/src/chat/chat.service.ts +++ b/backend/src/chat/chat.service.ts @@ -16,11 +16,11 @@ import { ModelProvider } from 'src/common/model-provider'; @Injectable() export class ChatProxyService { private readonly logger = new Logger('ChatProxyService'); - private models: ModelProvider; - constructor(private httpService: HttpService) { - this.models = ModelProvider.getInstance(); - } + constructor( + private httpService: HttpService, + private readonly models: ModelProvider, + ) {} streamChat( input: ChatInput, @@ -40,32 +40,32 @@ export class ChatService { private chatRepository: Repository, @InjectRepository(User) private userRepository: Repository, - @InjectRepository(Message) - private messageRepository: Repository, ) {} async getChatHistory(chatId: string): Promise { const chat = await this.chatRepository.findOne({ where: { id: chatId, isDeleted: false }, - relations: ['messages'], }); + console.log(chat); if (chat && chat.messages) { // Sort messages by createdAt in ascending order chat.messages = chat.messages .filter((message) => !message.isDeleted) - .sort((a, b) => a.createdAt.getTime() - b.createdAt.getTime()); + .map((message) => { + if (!(message.createdAt instanceof Date)) { + message.createdAt = new Date(message.createdAt); + } + return message; + }) + .sort((a, b) => { + return a.createdAt.getTime() - b.createdAt.getTime(); + }); } return chat ? chat.messages : []; } - async getMessageById(messageId: string): Promise { - return await this.messageRepository.findOne({ - where: { id: messageId, isDeleted: false }, - }); - } - async getChatDetails(chatId: string): Promise { const chat = await this.chatRepository.findOne({ where: { id: chatId, isDeleted: false }, @@ -111,12 +111,6 @@ export class ChatService { chat.isActive = false; await this.chatRepository.save(chat); - // Soft delete all associated messages - await this.messageRepository.update( - { chat: { id: chatId }, isDeleted: false }, - { isDeleted: true, isActive: false }, - ); - return true; } return false; @@ -125,13 +119,8 @@ export class ChatService { async clearChatHistory(chatId: string): Promise { const chat = await this.chatRepository.findOne({ where: { id: chatId, isDeleted: false }, - relations: ['messages'], }); if (chat) { - await this.messageRepository.update( - { chat: { id: chatId }, isDeleted: false }, - { isDeleted: true, isActive: false }, - ); chat.updatedAt = new Date(); await this.chatRepository.save(chat); return true; @@ -161,21 +150,24 @@ export class ChatService { ): Promise { // Find the chat instance const chat = await this.chatRepository.findOne({ where: { id: chatId } }); - //if the chat id not exist, dont save this messages - if (!chat) { - return null; - } - // Create a new message associated with the chat - const message = this.messageRepository.create({ + const message = { + id: `${chat.id}/${chat.messages.length}`, content: messageContent, role: role, - chat, createdAt: new Date(), - }); - + updatedAt: new Date(), + isActive: true, + isDeleted: false, + }; + //if the chat id not exist, dont save this messages + if (!chat) { + return null; + } + chat.messages.push(message); + await this.chatRepository.save(chat); // Save the message to the database - return await this.messageRepository.save(message); + return message; } async getChatWithUser(chatId: string): Promise { diff --git a/backend/src/chat/message.model.ts b/backend/src/chat/message.model.ts index b4adf4e6..05efb478 100644 --- a/backend/src/chat/message.model.ts +++ b/backend/src/chat/message.model.ts @@ -17,8 +17,8 @@ import { Chat } from 'src/chat/chat.model'; import { SystemBaseModel } from 'src/system-base-model/system-base.model'; export enum MessageRole { - User = 'User', - Model = 'Model', + User = 'user', + Model = 'assistant', } registerEnumType(MessageRole, { @@ -43,8 +43,4 @@ export class Message extends SystemBaseModel { @Field({ nullable: true }) @Column({ nullable: true }) modelId?: string; - - @ManyToOne(() => Chat, (chat) => chat.messages) - @JoinColumn({ name: 'chatId' }) - chat: Chat; } diff --git a/backend/src/common/model-provider/index.ts b/backend/src/common/model-provider/index.ts index 9991c96c..0e81e594 100644 --- a/backend/src/common/model-provider/index.ts +++ b/backend/src/common/model-provider/index.ts @@ -1,11 +1,8 @@ import { Logger } from '@nestjs/common'; import { HttpService } from '@nestjs/axios'; import { Subject, Subscription } from 'rxjs'; - -export interface ModelProviderConfig { - endpoint: string; - defaultModel?: string; -} +import { MessageRole } from 'src/chat/message.model'; +import { ChatInput, ModelProviderConfig } from './types'; export interface CustomAsyncIterableIterator extends AsyncIterator { [Symbol.asyncIterator](): AsyncIterableIterator; @@ -54,11 +51,7 @@ export class ModelProvider { /** * Synchronous chat method that returns a complete response */ - async chatSync( - input: ChatInput | string, - model: string, - chatId?: string, - ): Promise { + async chatSync(input: ChatInput): Promise { while (this.currentRequests >= this.concurrentLimit) { await new Promise((resolve) => setTimeout(resolve, 100)); } @@ -70,8 +63,6 @@ export class ModelProvider { `Starting request ${requestId}. Active: ${this.currentRequests}/${this.concurrentLimit}`, ); - const normalizedInput = this.normalizeChatInput(input); - let resolvePromise: (value: string) => void; let rejectPromise: (error: any) => void; @@ -113,7 +104,7 @@ export class ModelProvider { promise, }); - this.processRequest(normalizedInput, model, chatId, requestId, stream); + this.processRequest(input, requestId, stream); return promise; } @@ -125,7 +116,7 @@ export class ModelProvider { model: string, chatId?: string, ): CustomAsyncIterableIterator { - const chatInput = this.normalizeChatInput(input); + const chatInput = this.normalizeChatInput(input, model); const selectedModel = model || this.config.defaultModel; if (!selectedModel) { @@ -156,8 +147,6 @@ export class ModelProvider { private async processRequest( input: ChatInput, - model: string, - chatId: string | undefined, requestId: string, stream: Subject, ) { @@ -165,14 +154,10 @@ export class ModelProvider { try { const response = await this.httpService - .post( - `${this.config.endpoint}/chat/completion`, - this.createRequestPayload(input, model, chatId), - { - responseType: 'stream', - headers: { 'Content-Type': 'application/json' }, - }, - ) + .post(`${this.config.endpoint}/chat/completion`, input, { + responseType: 'stream', + headers: { 'Content-Type': 'application/json' }, + }) .toPromise(); let buffer = ''; @@ -434,8 +419,21 @@ export class ModelProvider { ); } - private normalizeChatInput(input: ChatInput | string): ChatInput { - return typeof input === 'string' ? { content: input } : input; + private normalizeChatInput( + input: ChatInput | string, + model: string, + ): ChatInput { + return typeof input === 'string' + ? { + model, + messages: [ + { + content: input, + role: MessageRole.User, + }, + ], + } + : input; } public async fetchModelsName() { @@ -468,17 +466,6 @@ export class ChatCompletionChunk { status: StreamStatus; } -export interface ChatInput { - content: string; - attachments?: Array<{ - type: string; - content: string | Buffer; - name?: string; - }>; - contextLength?: number; - temperature?: number; -} - class ChatCompletionDelta { content?: string; } diff --git a/backend/src/common/model-provider/types.ts b/backend/src/common/model-provider/types.ts index 8c649851..61984d3a 100644 --- a/backend/src/common/model-provider/types.ts +++ b/backend/src/common/model-provider/types.ts @@ -1,3 +1,5 @@ +import { MessageRole } from 'src/chat/message.model'; + export interface ModelChatStreamConfig { endpoint: string; model?: string; @@ -5,3 +7,18 @@ export interface ModelChatStreamConfig { export type CustomAsyncIterableIterator = AsyncIterator & { [Symbol.asyncIterator](): AsyncIterableIterator; }; + +export interface ModelProviderConfig { + endpoint: string; + defaultModel?: string; +} + +export interface MessageInterface { + content: string; + role: MessageRole; +} + +export interface ChatInput { + model: string; + messages: MessageInterface[]; +} diff --git a/llm-server/package.json b/llm-server/package.json index 6b67f9a4..dd5f7f28 100644 --- a/llm-server/package.json +++ b/llm-server/package.json @@ -5,7 +5,7 @@ "type": "module", "scripts": { "start": "node --experimental-specifier-resolution=node --loader ts-node/esm src/main.ts", - "dev": "nodemon --watch 'src/**/*.ts' --exec 'node --experimental-specifier-resolution=node --loader ts-node/esm' src/main.ts", + "dev": "nodemon --watch \"src/**/*.ts\" --exec \"node --experimental-specifier-resolution=node --loader ts-node/esm\" src/main.ts", "dev:backend": "pnpm dev", "build": "tsc", "serve": "node --experimental-specifier-resolution=node dist/main.js", diff --git a/llm-server/src/llm-provider.ts b/llm-server/src/llm-provider.ts index 0982ebdf..dc53a22c 100644 --- a/llm-server/src/llm-provider.ts +++ b/llm-server/src/llm-provider.ts @@ -11,10 +11,6 @@ import { import { ModelProvider } from './model/model-provider'; export interface ChatMessageInput { - content: string; -} - -export interface ChatMessage { role: string; content: string; } diff --git a/llm-server/src/main.ts b/llm-server/src/main.ts index 80ec476a..bb2c8e72 100644 --- a/llm-server/src/main.ts +++ b/llm-server/src/main.ts @@ -1,4 +1,4 @@ -import { Logger } from '@nestjs/common'; +import { Logger, Module } from '@nestjs/common'; import { ChatMessageInput, LLMProvider } from './llm-provider'; import express, { Express, Request, Response } from 'express'; import { GenerateMessageParams } from './types'; @@ -27,23 +27,16 @@ export class App { private async handleChatRequest(req: Request, res: Response): Promise { this.logger.log('Received chat request.'); try { - this.logger.debug(JSON.stringify(req.body)); - const { content, model } = req.body as ChatMessageInput & { - model: string; - }; + const input = req.body as GenerateMessageParams; + const model = input.model; this.logger.log(`Received chat request for model: ${model}`); - const params: GenerateMessageParams = { - model: model || 'gpt-3.5-turbo', - message: content, - role: 'user', - }; - this.logger.debug(`Request content: "${content}"`); + this.logger.debug(`Request messages: "${input.messages}"`); res.setHeader('Content-Type', 'text/event-stream'); res.setHeader('Cache-Control', 'no-cache'); res.setHeader('Connection', 'keep-alive'); this.logger.debug('Response headers set for streaming.'); - await this.llmProvider.generateStreamingResponse(params, res); + await this.llmProvider.generateStreamingResponse(input, res); } catch (error) { this.logger.error('Error in chat endpoint:', error); res.status(500).json({ error: 'Internal server error' }); diff --git a/llm-server/src/model/llama-model-provider.ts b/llm-server/src/model/llama-model-provider.ts index 77fa89ae..6aafc787 100644 --- a/llm-server/src/model/llama-model-provider.ts +++ b/llm-server/src/model/llama-model-provider.ts @@ -36,7 +36,7 @@ export class LlamaModelProvider extends ModelProvider { } async generateStreamingResponse( - { model, message, role = 'user' }: GenerateMessageParams, + { model, messages }: GenerateMessageParams, res: Response, ): Promise { this.logger.log('Generating streaming response with Llama...'); @@ -50,13 +50,10 @@ export class LlamaModelProvider extends ModelProvider { // Get the system prompt based on the model const systemPrompt = systemPrompts['codefox-basic']?.systemPrompt || ''; - const messages = [ - { role: 'system', content: systemPrompt }, - { role: role as 'user' | 'system' | 'assistant', content: message }, - ]; + const allMessage = [{ role: 'system', content: systemPrompt }, ...messages]; // Convert messages array to a single formatted string for Llama - const formattedPrompt = messages + const formattedPrompt = allMessage .map(({ role, content }) => `${role}: ${content}`) .join('\n'); diff --git a/llm-server/src/model/openai-model-provider.ts b/llm-server/src/model/openai-model-provider.ts index fc845a69..85d7bd5d 100644 --- a/llm-server/src/model/openai-model-provider.ts +++ b/llm-server/src/model/openai-model-provider.ts @@ -89,7 +89,10 @@ export class OpenAIModelProvider { private async processRequest(request: QueuedRequest): Promise { const { params, res, retries } = request; - const { model, message, role = 'user' } = params; + const { model, messages } = params as { + model: string; + messages: ChatCompletionMessageParam[]; + }; this.logger.log(`Processing request (attempt ${retries + 1})`); const startTime = Date.now(); @@ -103,14 +106,14 @@ export class OpenAIModelProvider { const systemPrompt = systemPrompts[this.options.systemPromptKey]?.systemPrompt || ''; - const messages: ChatCompletionMessageParam[] = [ + const allMessages: ChatCompletionMessageParam[] = [ { role: 'system', content: systemPrompt }, - { role: role as 'user' | 'system' | 'assistant', content: message }, + ...messages, ]; - + console.log(allMessages); const stream = await this.openai.chat.completions.create({ model, - messages, + messages: allMessages, stream: true, }); @@ -171,8 +174,7 @@ export class OpenAIModelProvider { error: errorResponse, params: { model, - messageLength: message.length, - role, + messageLength: messages.length, }, }); diff --git a/llm-server/src/types.ts b/llm-server/src/types.ts index 2f1bc1df..679996a9 100644 --- a/llm-server/src/types.ts +++ b/llm-server/src/types.ts @@ -1,7 +1,11 @@ +interface MessageInterface { + role: string; + content: string; +} + export interface GenerateMessageParams { model: string; // Model to use, e.g., 'gpt-3.5-turbo' - message: string; // User's message or query - role?: 'user' | 'system' | 'assistant' | 'tool' | 'function'; // Optional role + messages: MessageInterface[]; // User's message or query } // types.ts