Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 87 additions & 0 deletions backend/src/chat/chat.controller.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import { Controller, Post, Body, Res, UseGuards } from '@nestjs/common';
import { Response } from 'express';
import { ChatProxyService, ChatService } from './chat.service';
import { ChatRestDto } from './dto/chat-rest.dto';
import { MessageRole } from './message.model';
import { JWTAuthGuard } from '../guard/jwt-auth.guard';
import { ChatGuard } from '../guard/chat.guard';
import { GetAuthToken } from '../decorator/get-auth-token.decorator';

@Controller('api/chat')
@UseGuards(JWTAuthGuard, ChatGuard) // Order matters: JWTAuthGuard sets user object, then ChatGuard uses it
export class ChatController {
constructor(
private readonly chatProxyService: ChatProxyService,
private readonly chatService: ChatService,
) {}

@Post()
async chat(
@Body() chatDto: ChatRestDto,
@Res() res: Response,
@GetAuthToken() userId: string,
) {
try {
// Save user's message first
await this.chatService.saveMessage(
chatDto.chatId,
chatDto.message,
MessageRole.User,
);

if (chatDto.stream) {
// Streaming response
res.setHeader('Content-Type', 'text/event-stream');
res.setHeader('Cache-Control', 'no-cache');
res.setHeader('Connection', 'keep-alive');

const stream = this.chatProxyService.streamChat({
chatId: chatDto.chatId,
message: chatDto.message,
model: chatDto.model,
});

let fullResponse = '';

for await (const chunk of stream) {
if (chunk.choices[0]?.delta?.content) {
const content = chunk.choices[0].delta.content;
fullResponse += content;
res.write(`data: ${JSON.stringify({ content })}\n\n`);
}
}

// Save the complete message
await this.chatService.saveMessage(
chatDto.chatId,
fullResponse,
MessageRole.Assistant,
);

res.write('data: [DONE]\n\n');
res.end();
} else {
// Non-streaming response using chatSync
const response = await this.chatProxyService.chatSync({
chatId: chatDto.chatId,
message: chatDto.message,
model: chatDto.model,
});

// Save the complete message
await this.chatService.saveMessage(
chatDto.chatId,
response,
MessageRole.Assistant,
);

res.json({ content: response });
}
} catch (error) {
console.error('Chat error:', error);
res.status(500).json({
error: 'An error occurred during chat processing',
});
}
}
}
2 changes: 2 additions & 0 deletions backend/src/chat/chat.module.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { Module } from '@nestjs/common';
import { ChatResolver } from './chat.resolver';
import { ChatController } from './chat.controller';
import { ChatProxyService, ChatService } from './chat.service';
import { TypeOrmModule } from '@nestjs/typeorm';
import { User } from 'src/user/user.model';
Expand All @@ -18,6 +19,7 @@ import { UploadModule } from 'src/upload/upload.module';
JwtCacheModule,
UploadModule,
],
controllers: [ChatController],
providers: [
ChatResolver,
ChatProxyService,
Expand Down
13 changes: 10 additions & 3 deletions backend/src/chat/chat.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@ export class ChatProxyService {
);
}

async chatSync(input: ChatInput): Promise<string> {
return await this.models.chatSync({
messages: [{ role: MessageRole.User, content: input.message }],
model: input.model,
});
}

async fetchModelTags(): Promise<string[]> {
return await this.models.fetchModelsName();
}
Expand Down Expand Up @@ -173,14 +180,14 @@ export class ChatService {
}

async updateChatTitle(
upateChatTitleInput: UpdateChatTitleInput,
updateChatTitleInput: UpdateChatTitleInput,
): Promise<Chat> {
const chat = await this.chatRepository.findOne({
where: { id: upateChatTitleInput.chatId, isDeleted: false },
where: { id: updateChatTitleInput.chatId, isDeleted: false },
});
new Logger('chat').log('chat', chat);
if (chat) {
chat.title = upateChatTitleInput.title;
chat.title = updateChatTitleInput.title;
chat.updatedAt = new Date();
return await this.chatRepository.save(chat);
}
Expand Down
16 changes: 16 additions & 0 deletions backend/src/chat/dto/chat-rest.dto.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import { IsString, IsBoolean, IsOptional } from 'class-validator';

export class ChatRestDto {
@IsString()
chatId: string;

@IsString()
message: string;

@IsString()
model: string;

@IsBoolean()
@IsOptional()
stream?: boolean = false;
}
112 changes: 38 additions & 74 deletions backend/src/guard/chat.guard.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import {
CanActivate,
ExecutionContext,
UnauthorizedException,
ContextType,
} from '@nestjs/common';
import { GqlExecutionContext } from '@nestjs/graphql';
import { JwtService } from '@nestjs/jwt';
Expand All @@ -16,29 +17,44 @@ export class ChatGuard implements CanActivate {
) {}

async canActivate(context: ExecutionContext): Promise<boolean> {
const gqlContext = GqlExecutionContext.create(context);
const request = gqlContext.getContext().req;
// Determine if this is a GraphQL or REST request
const contextType = context.getType();
let chatId: string;
let user: any;

// Extract the authorization header
const authHeader = request.headers.authorization;
if (!authHeader || !authHeader.startsWith('Bearer ')) {
throw new UnauthorizedException('Authorization token is missing');
if (contextType === 'http') {
// REST request (only for chat stream endpoint)
const request = context.switchToHttp().getRequest();
user = request.user;
chatId = request.body?.chatId;
} else if (contextType === ('graphql' as ContextType)) {
// GraphQL request (for all other chat operations)
const gqlContext = GqlExecutionContext.create(context);
const { req } = gqlContext.getContext();
user = req.user;

const args = gqlContext.getArgs();
chatId =
args.chatId || args.input?.chatId || args.updateChatTitleInput?.chatId;

// Allow chat creation mutation which doesn't require a chatId
const info = gqlContext.getInfo();
if (info.operation.name.value === 'createChat') {
return true;
}
}

// Decode the token to get user information
const token = authHeader.split(' ')[1];
let user: any;
try {
user = this.jwtService.verify(token);
} catch (error) {
throw new UnauthorizedException('Invalid token');
// Common validation for both REST and GraphQL
if (!user) {
throw new UnauthorizedException('User not found');
}

// Extract chatId from the request arguments
const args = gqlContext.getArgs();
const { chatId } = args;
// Skip chat validation for operations that don't require a chatId
if (!chatId) {
return true;
}

// check if the user is part of the chat
// Verify chat ownership for both types of requests
const chat = await this.chatService.getChatWithUser(chatId);
if (!chat) {
throw new UnauthorizedException('Chat not found');
Expand All @@ -54,54 +70,6 @@ export class ChatGuard implements CanActivate {
}
}

// @Injectable()
// export class MessageGuard implements CanActivate {
// constructor(
// private readonly chatService: ChatService, // Inject ChatService to fetch chat details
// private readonly jwtService: JwtService, // JWT Service to verify tokens
// ) {}

// async canActivate(context: ExecutionContext): Promise<boolean> {
// const gqlContext = GqlExecutionContext.create(context);
// const request = gqlContext.getContext().req;

// // Extract the authorization header
// const authHeader = request.headers.authorization;
// if (!authHeader || !authHeader.startsWith('Bearer ')) {
// throw new UnauthorizedException('Authorization token is missing');
// }

// // Decode the token to get user information
// const token = authHeader.split(' ')[1];
// let user: any;
// try {
// user = this.jwtService.verify(token);
// } catch (error) {
// throw new UnauthorizedException('Invalid token');
// }

// // Extract chatId from the request arguments
// const args = gqlContext.getArgs();
// const { messageId } = args;

// // Fetch the message and its associated chat
// const message = await this.chatService.getMessageById(messageId);
// if (!message) {
// throw new UnauthorizedException('Message not found');
// }

// // Ensure that the user is part of the chat the message belongs to
// const chat = message.chat;
// if (chat.user.id !== user.userId) {
// throw new UnauthorizedException(
// 'User is not authorized to access this message',
// );
// }

// return true;
// }
// }

@Injectable()
export class ChatSubscriptionGuard implements CanActivate {
constructor(
Expand All @@ -110,12 +78,9 @@ export class ChatSubscriptionGuard implements CanActivate {
) {}

async canActivate(context: ExecutionContext): Promise<boolean> {
const gqlContext = GqlExecutionContext.create(context);

// For WebSocket context: get token from connectionParams
const token = gqlContext
.getContext()
.connectionParams?.authorization?.split(' ')[1];
const wsContext = context.switchToWs();
const client = wsContext.getClient();
const token = client.handshake?.auth?.token?.split(' ')[1];

if (!token) {
throw new UnauthorizedException('Authorization token is missing');
Expand All @@ -128,9 +93,8 @@ export class ChatSubscriptionGuard implements CanActivate {
throw new UnauthorizedException('Invalid token');
}

// Extract chatId from the subscription arguments
const args = gqlContext.getArgs();
const { chatId } = args;
const data = wsContext.getData();
const { chatId } = data;

// Check if the user is part of the chat
const chat = await this.chatService.getChatWithUser(chatId);
Expand Down
Loading
Loading