diff --git a/backend/src/chat/chat.controller.ts b/backend/src/chat/chat.controller.ts new file mode 100644 index 00000000..c1a5e6ef --- /dev/null +++ b/backend/src/chat/chat.controller.ts @@ -0,0 +1,87 @@ +import { Controller, Post, Body, Res, UseGuards } from '@nestjs/common'; +import { Response } from 'express'; +import { ChatProxyService, ChatService } from './chat.service'; +import { ChatRestDto } from './dto/chat-rest.dto'; +import { MessageRole } from './message.model'; +import { JWTAuthGuard } from '../guard/jwt-auth.guard'; +import { ChatGuard } from '../guard/chat.guard'; +import { GetAuthToken } from '../decorator/get-auth-token.decorator'; + +@Controller('api/chat') +@UseGuards(JWTAuthGuard, ChatGuard) // Order matters: JWTAuthGuard sets user object, then ChatGuard uses it +export class ChatController { + constructor( + private readonly chatProxyService: ChatProxyService, + private readonly chatService: ChatService, + ) {} + + @Post() + async chat( + @Body() chatDto: ChatRestDto, + @Res() res: Response, + @GetAuthToken() userId: string, + ) { + try { + // Save user's message first + await this.chatService.saveMessage( + chatDto.chatId, + chatDto.message, + MessageRole.User, + ); + + if (chatDto.stream) { + // Streaming response + res.setHeader('Content-Type', 'text/event-stream'); + res.setHeader('Cache-Control', 'no-cache'); + res.setHeader('Connection', 'keep-alive'); + + const stream = this.chatProxyService.streamChat({ + chatId: chatDto.chatId, + message: chatDto.message, + model: chatDto.model, + }); + + let fullResponse = ''; + + for await (const chunk of stream) { + if (chunk.choices[0]?.delta?.content) { + const content = chunk.choices[0].delta.content; + fullResponse += content; + res.write(`data: ${JSON.stringify({ content })}\n\n`); + } + } + + // Save the complete message + await this.chatService.saveMessage( + chatDto.chatId, + fullResponse, + MessageRole.Assistant, + ); + + res.write('data: [DONE]\n\n'); + res.end(); + } else { + // Non-streaming response using chatSync + const response = await this.chatProxyService.chatSync({ + chatId: chatDto.chatId, + message: chatDto.message, + model: chatDto.model, + }); + + // Save the complete message + await this.chatService.saveMessage( + chatDto.chatId, + response, + MessageRole.Assistant, + ); + + res.json({ content: response }); + } + } catch (error) { + console.error('Chat error:', error); + res.status(500).json({ + error: 'An error occurred during chat processing', + }); + } + } +} diff --git a/backend/src/chat/chat.module.ts b/backend/src/chat/chat.module.ts index e8c120e5..df69a066 100644 --- a/backend/src/chat/chat.module.ts +++ b/backend/src/chat/chat.module.ts @@ -1,5 +1,6 @@ import { Module } from '@nestjs/common'; import { ChatResolver } from './chat.resolver'; +import { ChatController } from './chat.controller'; import { ChatProxyService, ChatService } from './chat.service'; import { TypeOrmModule } from '@nestjs/typeorm'; import { User } from 'src/user/user.model'; @@ -18,6 +19,7 @@ import { UploadModule } from 'src/upload/upload.module'; JwtCacheModule, UploadModule, ], + controllers: [ChatController], providers: [ ChatResolver, ChatProxyService, diff --git a/backend/src/chat/chat.service.ts b/backend/src/chat/chat.service.ts index a1575437..366b4e20 100644 --- a/backend/src/chat/chat.service.ts +++ b/backend/src/chat/chat.service.ts @@ -32,6 +32,13 @@ export class ChatProxyService { ); } + async chatSync(input: ChatInput): Promise { + return await this.models.chatSync({ + messages: [{ role: MessageRole.User, content: input.message }], + model: input.model, + }); + } + async fetchModelTags(): Promise { return await this.models.fetchModelsName(); } @@ -173,14 +180,14 @@ export class ChatService { } async updateChatTitle( - upateChatTitleInput: UpdateChatTitleInput, + updateChatTitleInput: UpdateChatTitleInput, ): Promise { const chat = await this.chatRepository.findOne({ - where: { id: upateChatTitleInput.chatId, isDeleted: false }, + where: { id: updateChatTitleInput.chatId, isDeleted: false }, }); new Logger('chat').log('chat', chat); if (chat) { - chat.title = upateChatTitleInput.title; + chat.title = updateChatTitleInput.title; chat.updatedAt = new Date(); return await this.chatRepository.save(chat); } diff --git a/backend/src/chat/dto/chat-rest.dto.ts b/backend/src/chat/dto/chat-rest.dto.ts new file mode 100644 index 00000000..05f798a9 --- /dev/null +++ b/backend/src/chat/dto/chat-rest.dto.ts @@ -0,0 +1,16 @@ +import { IsString, IsBoolean, IsOptional } from 'class-validator'; + +export class ChatRestDto { + @IsString() + chatId: string; + + @IsString() + message: string; + + @IsString() + model: string; + + @IsBoolean() + @IsOptional() + stream?: boolean = false; +} diff --git a/backend/src/guard/chat.guard.ts b/backend/src/guard/chat.guard.ts index 5a33dd32..69654e64 100644 --- a/backend/src/guard/chat.guard.ts +++ b/backend/src/guard/chat.guard.ts @@ -3,6 +3,7 @@ import { CanActivate, ExecutionContext, UnauthorizedException, + ContextType, } from '@nestjs/common'; import { GqlExecutionContext } from '@nestjs/graphql'; import { JwtService } from '@nestjs/jwt'; @@ -16,29 +17,44 @@ export class ChatGuard implements CanActivate { ) {} async canActivate(context: ExecutionContext): Promise { - const gqlContext = GqlExecutionContext.create(context); - const request = gqlContext.getContext().req; + // Determine if this is a GraphQL or REST request + const contextType = context.getType(); + let chatId: string; + let user: any; - // Extract the authorization header - const authHeader = request.headers.authorization; - if (!authHeader || !authHeader.startsWith('Bearer ')) { - throw new UnauthorizedException('Authorization token is missing'); + if (contextType === 'http') { + // REST request (only for chat stream endpoint) + const request = context.switchToHttp().getRequest(); + user = request.user; + chatId = request.body?.chatId; + } else if (contextType === ('graphql' as ContextType)) { + // GraphQL request (for all other chat operations) + const gqlContext = GqlExecutionContext.create(context); + const { req } = gqlContext.getContext(); + user = req.user; + + const args = gqlContext.getArgs(); + chatId = + args.chatId || args.input?.chatId || args.updateChatTitleInput?.chatId; + + // Allow chat creation mutation which doesn't require a chatId + const info = gqlContext.getInfo(); + if (info.operation.name.value === 'createChat') { + return true; + } } - // Decode the token to get user information - const token = authHeader.split(' ')[1]; - let user: any; - try { - user = this.jwtService.verify(token); - } catch (error) { - throw new UnauthorizedException('Invalid token'); + // Common validation for both REST and GraphQL + if (!user) { + throw new UnauthorizedException('User not found'); } - // Extract chatId from the request arguments - const args = gqlContext.getArgs(); - const { chatId } = args; + // Skip chat validation for operations that don't require a chatId + if (!chatId) { + return true; + } - // check if the user is part of the chat + // Verify chat ownership for both types of requests const chat = await this.chatService.getChatWithUser(chatId); if (!chat) { throw new UnauthorizedException('Chat not found'); @@ -54,54 +70,6 @@ export class ChatGuard implements CanActivate { } } -// @Injectable() -// export class MessageGuard implements CanActivate { -// constructor( -// private readonly chatService: ChatService, // Inject ChatService to fetch chat details -// private readonly jwtService: JwtService, // JWT Service to verify tokens -// ) {} - -// async canActivate(context: ExecutionContext): Promise { -// const gqlContext = GqlExecutionContext.create(context); -// const request = gqlContext.getContext().req; - -// // Extract the authorization header -// const authHeader = request.headers.authorization; -// if (!authHeader || !authHeader.startsWith('Bearer ')) { -// throw new UnauthorizedException('Authorization token is missing'); -// } - -// // Decode the token to get user information -// const token = authHeader.split(' ')[1]; -// let user: any; -// try { -// user = this.jwtService.verify(token); -// } catch (error) { -// throw new UnauthorizedException('Invalid token'); -// } - -// // Extract chatId from the request arguments -// const args = gqlContext.getArgs(); -// const { messageId } = args; - -// // Fetch the message and its associated chat -// const message = await this.chatService.getMessageById(messageId); -// if (!message) { -// throw new UnauthorizedException('Message not found'); -// } - -// // Ensure that the user is part of the chat the message belongs to -// const chat = message.chat; -// if (chat.user.id !== user.userId) { -// throw new UnauthorizedException( -// 'User is not authorized to access this message', -// ); -// } - -// return true; -// } -// } - @Injectable() export class ChatSubscriptionGuard implements CanActivate { constructor( @@ -110,12 +78,9 @@ export class ChatSubscriptionGuard implements CanActivate { ) {} async canActivate(context: ExecutionContext): Promise { - const gqlContext = GqlExecutionContext.create(context); - - // For WebSocket context: get token from connectionParams - const token = gqlContext - .getContext() - .connectionParams?.authorization?.split(' ')[1]; + const wsContext = context.switchToWs(); + const client = wsContext.getClient(); + const token = client.handshake?.auth?.token?.split(' ')[1]; if (!token) { throw new UnauthorizedException('Authorization token is missing'); @@ -128,9 +93,8 @@ export class ChatSubscriptionGuard implements CanActivate { throw new UnauthorizedException('Invalid token'); } - // Extract chatId from the subscription arguments - const args = gqlContext.getArgs(); - const { chatId } = args; + const data = wsContext.getData(); + const { chatId } = data; // Check if the user is part of the chat const chat = await this.chatService.getChatWithUser(chatId); diff --git a/backend/src/guard/jwt-auth.guard.ts b/backend/src/guard/jwt-auth.guard.ts index 5739d0b9..40324515 100644 --- a/backend/src/guard/jwt-auth.guard.ts +++ b/backend/src/guard/jwt-auth.guard.ts @@ -4,6 +4,7 @@ import { ExecutionContext, UnauthorizedException, Logger, + ContextType, } from '@nestjs/common'; import { GqlExecutionContext } from '@nestjs/graphql'; import { JwtService } from '@nestjs/jwt'; @@ -21,45 +22,71 @@ export class JWTAuthGuard implements CanActivate { async canActivate(context: ExecutionContext): Promise { this.logger.debug('Starting JWT authentication process'); - const gqlContext = GqlExecutionContext.create(context); - const { req } = gqlContext.getContext(); + let request; + const contextType = context.getType(); + this.logger.debug(`Context Type: ${contextType}`); + + if (contextType === 'http') { + request = context.switchToHttp().getRequest(); + this.logger.debug( + `HTTP Request Headers: ${JSON.stringify(request.headers)}`, + ); + } else if (contextType === ('graphql' as ContextType)) { + // GraphQL API + const gqlContext = GqlExecutionContext.create(context); + const { req } = gqlContext.getContext(); + request = req; + this.logger.debug('GraphQL request detected'); + } try { - const token = this.extractTokenFromHeader(req); + const token = this.extractTokenFromHeader(request); + this.logger.debug(`Extracted Token: ${token}`); const payload = await this.verifyToken(token); + this.logger.debug(`Token Verified. Payload: ${JSON.stringify(payload)}`); const isTokenValid = await this.jwtCacheService.isTokenStored(token); + this.logger.debug(`Token stored in cache: ${isTokenValid}`); + if (!isTokenValid) { + this.logger.warn('Token has been invalidated'); throw new UnauthorizedException('Token has been invalidated'); } - req.user = payload; + request.user = payload; + this.logger.debug('User successfully authenticated'); return true; } catch (error) { + this.logger.error(`Authentication failed: ${error.message}`); + if (error instanceof UnauthorizedException) { throw error; } - this.logger.error('Authentication failed:', error); + throw new UnauthorizedException('Invalid authentication token'); } } private extractTokenFromHeader(req: any): string { const authHeader = req.headers.authorization; + this.logger.debug(`Authorization Header: ${authHeader}`); if (!authHeader) { + this.logger.warn('Authorization header is missing'); throw new UnauthorizedException('Authorization header is missing'); } const [type, token] = authHeader.split(' '); if (type !== 'Bearer') { + this.logger.warn('Invalid authorization header format'); throw new UnauthorizedException('Invalid authorization header format'); } if (!token) { + this.logger.warn('Token is missing'); throw new UnauthorizedException('Token is missing'); } @@ -68,14 +95,18 @@ export class JWTAuthGuard implements CanActivate { private async verifyToken(token: string): Promise { try { + this.logger.debug(`Verifying Token: ${token}`); return await this.jwtService.verifyAsync(token); } catch (error) { if (error.name === 'TokenExpiredError') { + this.logger.warn('Token has expired'); throw new UnauthorizedException('Token has expired'); } if (error.name === 'JsonWebTokenError') { + this.logger.warn('Invalid token'); throw new UnauthorizedException('Invalid token'); } + this.logger.error(`Token verification failed: ${error.message}`); throw error; } } diff --git a/backend/src/interceptor/LoggingInterceptor.ts b/backend/src/interceptor/LoggingInterceptor.ts index e05e9a55..bbd55dd9 100644 --- a/backend/src/interceptor/LoggingInterceptor.ts +++ b/backend/src/interceptor/LoggingInterceptor.ts @@ -4,30 +4,73 @@ import { ExecutionContext, CallHandler, Logger, + ContextType, } from '@nestjs/common'; import { Observable } from 'rxjs'; import { GqlExecutionContext } from '@nestjs/graphql'; @Injectable() export class LoggingInterceptor implements NestInterceptor { - private readonly logger = new Logger('GraphQLRequest'); + private readonly logger = new Logger('RequestLogger'); intercept(context: ExecutionContext, next: CallHandler): Observable { + const contextType = context.getType(); + this.logger.debug(`Intercepting request, Context Type: ${contextType}`); + + if (contextType === ('graphql' as ContextType)) { + return this.handleGraphQLRequest(context, next); + } else if (contextType === 'http') { + return this.handleRestRequest(context, next); + } else { + this.logger.warn('Unknown request type, skipping logging.'); + return next.handle(); + } + } + + private handleGraphQLRequest( + context: ExecutionContext, + next: CallHandler, + ): Observable { const ctx = GqlExecutionContext.create(context); - const { operation, fieldName } = ctx.getInfo(); + const info = ctx.getInfo(); + if (!info) { + this.logger.warn( + 'GraphQL request detected, but ctx.getInfo() is undefined.', + ); + return next.handle(); + } + + const { operation, fieldName } = info; let variables = ''; + try { - variables = ctx.getContext().req.body.variables; + variables = JSON.stringify(ctx.getContext()?.req?.body?.variables ?? {}); } catch (error) { - variables = ''; + variables = '{}'; } this.logger.log( - `${operation.operation.toUpperCase()} \x1B[33m${fieldName}\x1B[39m${ - variables ? ` Variables: ${JSON.stringify(variables)}` : '' + `[GraphQL] ${operation.operation.toUpperCase()} \x1B[33m${fieldName}\x1B[39m${ + variables ? ` Variables: ${variables}` : '' }`, ); return next.handle(); } + + private handleRestRequest( + context: ExecutionContext, + next: CallHandler, + ): Observable { + const httpContext = context.switchToHttp(); + const request = httpContext.getRequest(); + + const { method, url, body } = request; + + this.logger.log( + `[REST] ${method.toUpperCase()} ${url} Body: ${JSON.stringify(body)}`, + ); + + return next.handle(); + } } diff --git a/frontend/next.config.mjs b/frontend/next.config.mjs index b2810022..9617083d 100644 --- a/frontend/next.config.mjs +++ b/frontend/next.config.mjs @@ -9,9 +9,8 @@ const nextConfig = { // Fixes npm packages that depend on `fs` module if (!isServer) { config.resolve.fallback = { - ...config.resolve.fallback, // if you miss it, all the other options in fallback, specified - // by next.js will be dropped. Doesn't make much sense, but how it is - fs: false, // the solution + ...config.resolve.fallback, + fs: false, module: false, perf_hooks: false, }; @@ -36,6 +35,16 @@ const nextConfig = { }, ], }, + + // Add proxy configuration for API + async rewrites() { + return [ + { + source: '/api/:path*', + destination: 'http://localhost:8080/api/:path*', + }, + ]; + }, }; export default nextConfig; diff --git a/frontend/src/hooks/useChatStream.ts b/frontend/src/hooks/useChatStream.ts index 30c41680..eb67ad3e 100644 --- a/frontend/src/hooks/useChatStream.ts +++ b/frontend/src/hooks/useChatStream.ts @@ -1,28 +1,10 @@ -import { useState, useCallback } from 'react'; -import { useMutation, useSubscription } from '@apollo/client'; -import { CHAT_STREAM, CREATE_CHAT, TRIGGER_CHAT } from '@/graphql/request'; +import { useState, useCallback, useEffect } from 'react'; +import { useMutation } from '@apollo/client'; +import { CREATE_CHAT } from '@/graphql/request'; import { Message } from '@/const/MessageType'; import { toast } from 'sonner'; -import { useRouter } from 'next/navigation'; import { logger } from '@/app/log/logger'; -enum StreamStatus { - IDLE = 'IDLE', - STREAMING = 'STREAMING', - DONE = 'DONE', -} - -interface ChatInput { - chatId: string; - message: string; - model: string; -} - -interface SubscriptionState { - enabled: boolean; - variables: { - input: ChatInput; - } | null; -} +import { useAuthContext } from '@/providers/AuthProvider'; interface UseChatStreamProps { chatId: string; @@ -32,131 +14,148 @@ interface UseChatStreamProps { selectedModel: string; } -export function useChatStream({ +export const useChatStream = ({ chatId, input, setInput, setMessages, selectedModel, -}: UseChatStreamProps) { +}: UseChatStreamProps) => { const [loadingSubmit, setLoadingSubmit] = useState(false); - const [streamStatus, setStreamStatus] = useState( - StreamStatus.IDLE - ); const [currentChatId, setCurrentChatId] = useState(chatId); + const { token } = useAuthContext(); - const [subscription, setSubscription] = useState({ - enabled: false, - variables: null, - }); + // Use useEffect to handle new chat event and cleanup + useEffect(() => { + const updateChatId = () => { + setCurrentChatId(''); + setMessages([]); // Clear messages for new chat + }; - const updateChatId = () => { - setCurrentChatId(''); - }; + // Only add event listener when we want to create a new chat + if (!chatId) { + window.addEventListener('newchat', updateChatId); + } - window.addEventListener('newchat', updateChatId); + // Cleanup + return () => { + window.removeEventListener('newchat', updateChatId); + }; + }, [chatId, setMessages]); - const [triggerChat] = useMutation(TRIGGER_CHAT, { - onCompleted: () => { - setStreamStatus(StreamStatus.STREAMING); - }, - onError: () => { - setStreamStatus(StreamStatus.IDLE); - finishChatResponse(); - }, - }); + // Update currentChatId when chatId prop changes + useEffect(() => { + setCurrentChatId(chatId); + }, [chatId]); const [createChat] = useMutation(CREATE_CHAT, { onCompleted: async (data) => { const newChatId = data.createChat.id; setCurrentChatId(newChatId); - await startChatStream(newChatId, input); + await handleChatResponse(newChatId, input); window.history.pushState({}, '', `/chat?id=${newChatId}`); logger.info(`new chat: ${newChatId}`); }, onError: () => { toast.error('Failed to create chat'); - setStreamStatus(StreamStatus.IDLE); setLoadingSubmit(false); }, }); - useSubscription(CHAT_STREAM, { - skip: !subscription.enabled || !subscription.variables, - variables: subscription.variables, - onSubscriptionData: ({ subscriptionData }) => { - const chatStream = subscriptionData?.data?.chatStream; - if (!chatStream) return; - - if (streamStatus === StreamStatus.STREAMING && loadingSubmit) { - setLoadingSubmit(false); - } - - if (chatStream.status === StreamStatus.DONE) { - setStreamStatus(StreamStatus.DONE); - finishChatResponse(); - return; - } + const startChatStream = async ( + targetChatId: string, + message: string, + model: string, + stream: boolean = false // Default to non-streaming for better performance + ): Promise => { + if (!token) { + throw new Error('Not authenticated'); + } - const content = chatStream.choices?.[0]?.delta?.content; - - if (content) { - setMessages((prev) => { - const lastMsg = prev[prev.length - 1]; - if (lastMsg?.role === 'assistant') { - return [ - ...prev.slice(0, -1), - { ...lastMsg, content: lastMsg.content + content }, - ]; - } else { - return [ - ...prev, - { - id: chatStream.id, - role: 'assistant', - content, - createdAt: new Date(chatStream.created * 1000).toISOString(), - }, - ]; - } - }); - } + const response = await fetch('/api/chat', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}`, + }, + body: JSON.stringify({ + chatId: targetChatId, + message, + model, + stream, + }), + }); - if (chatStream.choices?.[0]?.finishReason === 'stop') { - setStreamStatus(StreamStatus.DONE); - finishChatResponse(); - } - }, - onError: (error) => { - logger.info(error); - toast.error('Connection error. Please try again.'); - setStreamStatus(StreamStatus.IDLE); - finishChatResponse(); - }, - }); + if (!response.ok) { + throw new Error( + `Network response was not ok: ${response.status} ${response.statusText}` + ); + } + // TODO: Handle streaming responses properly + // if (stream) { + // // For streaming responses, aggregate the streamed content + // let fullContent = ''; + // const reader = response.body?.getReader(); + // if (!reader) { + // throw new Error('No reader available'); + // } + + // while (true) { + // const { done, value } = await reader.read(); + // if (done) break; + + // const text = new TextDecoder().decode(value); + // const lines = text.split('\n\n'); + + // for (const line of lines) { + // if (line.startsWith('data: ')) { + // const data = line.slice(5); + // if (data === '[DONE]') break; + // try { + // const { content } = JSON.parse(data); + // if (content) { + // fullContent += content; + // } + // } catch (e) { + // console.error('Error parsing SSE data:', e); + // } + // } + // } + // } + // return fullContent; + // } else { + // // For non-streaming responses, return the content directly + // const data = await response.json(); + // return data.content; + // } + + const data = await response.json(); + return data.content; + }; - const startChatStream = async (targetChatId: string, message: string) => { + const handleChatResponse = async (targetChatId: string, message: string) => { try { - const input: ChatInput = { - chatId: targetChatId, + setInput(''); + const response = await startChatStream( + targetChatId, message, - model: selectedModel, - }; - logger.info(input); + selectedModel + ); + + setMessages((prev) => [ + ...prev, + { + id: `${targetChatId}/${prev.length}`, + role: 'assistant', + content: response, + createdAt: new Date().toISOString(), + }, + ]); - setInput(''); - setStreamStatus(StreamStatus.STREAMING); - setSubscription({ - enabled: true, - variables: { input }, - }); - - await new Promise((resolve) => setTimeout(resolve, 100)); - await triggerChat({ variables: { input } }); + setLoadingSubmit(false); } catch (err) { - toast.error('Failed to start chat'); - setStreamStatus(StreamStatus.IDLE); - finishChatResponse(); + toast.error('Failed to get chat response' + err); + setLoadingSubmit(false); } }; @@ -192,21 +191,10 @@ export function useChatStream({ return; } } else { - await startChatStream(currentChatId, content); + await handleChatResponse(currentChatId, content); } }; - const finishChatResponse = useCallback(() => { - setLoadingSubmit(false); - setSubscription({ - enabled: false, - variables: null, - }); - if (streamStatus === StreamStatus.DONE) { - setStreamStatus(StreamStatus.IDLE); - } - }, [streamStatus]); - const handleInputChange = useCallback( (e: React.ChangeEvent) => { setInput(e.target.value); @@ -215,23 +203,19 @@ export function useChatStream({ ); const stop = useCallback(() => { - if (streamStatus === StreamStatus.STREAMING) { - setSubscription({ - enabled: false, - variables: null, - }); - setStreamStatus(StreamStatus.IDLE); + if (loadingSubmit) { setLoadingSubmit(false); toast.info('Message generation stopped'); } - }, [streamStatus]); + }, [loadingSubmit]); return { loadingSubmit, handleSubmit, handleInputChange, stop, - isStreaming: streamStatus === StreamStatus.STREAMING, + isStreaming: loadingSubmit, currentChatId, + startChatStream, }; -} +};