diff --git a/auth.go b/auth.go index 6a775b3..fbf9b16 100644 --- a/auth.go +++ b/auth.go @@ -1,8 +1,12 @@ package main import ( + "errors" + "fmt" + "log" "net/http" "os" + "strings" "time" "github.com/gin-gonic/gin" @@ -11,8 +15,16 @@ import ( "golang.org/x/crypto/bcrypt" ) +const ( + // TokenExpiration defines how long a JWT token is valid + TokenExpiration = 24 * time.Hour + // MinPasswordLength defines minimum password length + MinPasswordLength = 8 +) + +// Authentication-related request and response structures type RegisterRequest struct { - Username string `json:"username" binding:"required"` + Username string `json:"username" binding:"required,min=3,max=50"` Email string `json:"email" binding:"required,email"` Password string `json:"password" binding:"required,min=6"` } @@ -23,19 +35,51 @@ type LoginRequest struct { } type AuthResponse struct { - Token string `json:"token"` - User User `json:"user"` + Token string `json:"token"` + ExpiresAt time.Time `json:"expiresAt"` + User UserDTO `json:"user"` } -var jwtSecret = os.Getenv("JWT_SECRET") // JWT secret from environment variable +// UserDTO is a Data Transfer Object for User information +type UserDTO struct { + ID uuid.UUID `json:"id"` + Username string `json:"username"` + Email string `json:"email"` +} + +// UserClaims defines the claims in JWT token +type UserClaims struct { + UserID uuid.UUID `json:"userId"` + jwt.RegisteredClaims +} +// convertToUserDTO converts User model to UserDTO +func convertToUserDTO(user User) UserDTO { + return UserDTO{ + ID: user.ID, + Username: user.Username, + Email: user.Email, + } +} + +// registerHandler handles user registration func registerHandler(c *gin.Context) { var req RegisterRequest if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid input", "details": err.Error()}) + return + } + + // Validate password strength + if err := validatePassword(req.Password); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } + // Normalize input + req.Username = strings.TrimSpace(req.Username) + req.Email = strings.TrimSpace(strings.ToLower(req.Email)) + // Check if user exists var existingUser User if result := db.Where("username = ? OR email = ?", req.Username, req.Email).First(&existingUser); result.Error == nil { @@ -43,10 +87,11 @@ func registerHandler(c *gin.Context) { return } - // Hash password + // Hash password with appropriate cost hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost) if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to hash password"}) + log.Printf("Failed to hash password: %v", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to process registration"}) return } @@ -58,66 +103,118 @@ func registerHandler(c *gin.Context) { } if result := db.Create(&user); result.Error != nil { + log.Printf("Failed to create user: %v", result.Error) c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create user"}) return } // Generate JWT token - token, err := generateJWT(user) + token, expiresAt, err := generateJWT(user) if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate token"}) + log.Printf("Failed to generate token: %v", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate authentication token"}) return } c.JSON(http.StatusCreated, AuthResponse{ - Token: token, - User: user, + Token: token, + ExpiresAt: expiresAt, + User: convertToUserDTO(user), }) } +// loginHandler handles user login func loginHandler(c *gin.Context) { var req LoginRequest if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid input", "details": err.Error()}) return } // Find user var user User if result := db.Where("username = ?", req.Username).First(&user); result.Error != nil { + // Use same error message for username not found and password mismatch + // to avoid leaking information about existing usernames c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid credentials"}) return } - // Check password + // Check password with constant-time comparison if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.Password)); err != nil { + // Log failed login attempts for security monitoring + log.Printf("Failed login attempt for user %s from IP %s", req.Username, c.ClientIP()) c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid credentials"}) return } // Generate JWT token - token, err := generateJWT(user) + token, expiresAt, err := generateJWT(user) if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate token"}) + log.Printf("Failed to generate token: %v", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate authentication token"}) return } c.JSON(http.StatusOK, AuthResponse{ - Token: token, - User: user, + Token: token, + ExpiresAt: expiresAt, + User: convertToUserDTO(user), }) } -func generateJWT(user User) (string, error) { - token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ - "sub": user.ID.String(), - "exp": time.Now().Add(time.Hour * 24).Unix(), - }) +// validatePassword checks password strength +func validatePassword(password string) error { + if len(password) < MinPasswordLength { + return fmt.Errorf("password must be at least %d characters long", MinPasswordLength) + } + + // Check for at least one number + hasNumber := false + for _, char := range password { + if char >= '0' && char <= '9' { + hasNumber = true + break + } + } + + if !hasNumber { + return errors.New("password must contain at least one number") + } - return token.SignedString([]byte(jwtSecret)) + return nil } -// Middleware to authenticate requests +// generateJWT creates a new JWT token for the given user +func generateJWT(user User) (string, time.Time, error) { + jwtSecret := []byte(os.Getenv("JWT_SECRET")) + if len(jwtSecret) == 0 { + return "", time.Time{}, errors.New("JWT_SECRET environment variable not set") + } + + expiresAt := time.Now().Add(TokenExpiration) + + claims := UserClaims{ + UserID: user.ID, + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(expiresAt), + IssuedAt: jwt.NewNumericDate(time.Now()), + NotBefore: jwt.NewNumericDate(time.Now()), + Issuer: "hop-backend", + Subject: user.ID.String(), + }, + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + signedToken, err := token.SignedString(jwtSecret) + if err != nil { + return "", time.Time{}, fmt.Errorf("failed to sign token: %w", err) + } + + return signedToken, expiresAt, nil +} + +// authMiddleware verifies JWT token and sets user ID in context func authMiddleware() gin.HandlerFunc { return func(c *gin.Context) { tokenString := c.GetHeader("Authorization") @@ -132,38 +229,32 @@ func authMiddleware() gin.HandlerFunc { tokenString = tokenString[7:] } - token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { - return []byte(jwtSecret), nil + // Parse the JWT token + claims := &UserClaims{} + token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) { + // Validate the signing algorithm + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + return []byte(os.Getenv("JWT_SECRET")), nil }) - if err != nil || !token.Valid { + // Handle token parsing errors + if err != nil { + log.Printf("Token validation error: %v", err) c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token"}) c.Abort() return } - claims, ok := token.Claims.(jwt.MapClaims) - if !ok { - c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token claims"}) - c.Abort() - return - } - - userIDStr, ok := claims["sub"].(string) - if !ok { - c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid user ID in token"}) - c.Abort() - return - } - - userID, err := uuid.Parse(userIDStr) - if err != nil { - c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid user ID format"}) + if !token.Valid { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token"}) c.Abort() return } - c.Set("userID", userID) + // Set the user ID in the context + c.Set("userID", claims.UserID) c.Next() } } diff --git a/chat.go b/chat.go index f5b2d8a..3c27b66 100644 --- a/chat.go +++ b/chat.go @@ -7,20 +7,28 @@ import ( "encoding/json" "fmt" "io" - "strings" "log" "net/http" "os" + "strings" "time" "github.com/gin-gonic/gin" "github.com/google/uuid" ) +// Constants for Redis +const ( + RedisCacheTTL = time.Hour * 24 + ChatCacheKeyFormat = "chat:%s" + MsgCacheKeyFormat = "chat:%s:messages" +) + +// Request and response types type CreateChatRequest struct { - ID uuid.UUID `json:"id"` - Name string `json:"name" binding:"required"` - IsPrivate bool `json:"isPrivate"` + ID uuid.UUID `json:"id"` + Name string `json:"name" binding:"required"` + IsPrivate bool `json:"isPrivate"` } type ChatResponse struct { @@ -35,6 +43,15 @@ type MessageRequest struct { Model string `json:"model" binding:"omitempty"` // Optional, will use OLLAMA_DEFAULT_MODEL if not provided } +type MessageResponse struct { + ID uuid.UUID `json:"id"` + ChatID uuid.UUID `json:"chatId"` + Role string `json:"role"` + Content string `json:"content"` + Timestamp time.Time `json:"timestamp"` +} + +// Ollama API types type OllamaMessage struct { Role string `json:"role"` Content string `json:"content"` @@ -72,182 +89,157 @@ type APIError struct { Type string `json:"type"` } -type MessageResponse struct { - ID uuid.UUID `json:"id"` - ChatID uuid.UUID `json:"chatId"` - Role string `json:"role"` - Content string `json:"content"` - Timestamp time.Time `json:"timestamp"` +// Redis cache helper functions +func getChatCacheKey(chatID uuid.UUID) string { + return fmt.Sprintf(ChatCacheKeyFormat, chatID) } -func listChatsHandler(c *gin.Context) { - userID := c.MustGet("userID").(uuid.UUID) - var chats []Chat - - if err := db.Where("user_id = ?", userID).Find(&chats).Error; err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch chats"}) - return - } - - var response = make([]ChatResponse, 0) - for _, chat := range chats { - response = append(response, ChatResponse{ - ID: chat.ID, - Name: chat.Name, - IsPrivate: chat.IsPrivate, - CreatedAt: chat.CreatedAt, - }) - } - - c.JSON(http.StatusOK, response) +func getMessagesCacheKey(chatID uuid.UUID) string { + return fmt.Sprintf(MsgCacheKeyFormat, chatID.String()) } -func createChatHandler(c *gin.Context) { - var req CreateChatRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - userID := c.MustGet("userID").(uuid.UUID) - chat := Chat{ - ID: req.ID, - Name: req.Name, - IsPrivate: req.IsPrivate, - UserID: userID, +func cacheObject(ctx context.Context, key string, obj interface{}) error { + data, err := json.Marshal(obj) + if err != nil { + return fmt.Errorf("failed to marshal object: %v", err) } - if err := db.Create(&chat).Error; err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create chat"}) - return - } + return rdb.Set(ctx, key, data, RedisCacheTTL).Err() +} - // Cache the chat data in Redis - ctx := context.Background() - chatResponse := ChatResponse{ - ID: chat.ID, - Name: chat.Name, - IsPrivate: chat.IsPrivate, - CreatedAt: chat.CreatedAt, - } - chatJSON, _ := json.Marshal(chatResponse) - chatKey := fmt.Sprintf("chat:%s", chat.ID) - if err := rdb.Set(ctx, chatKey, chatJSON, time.Hour*24).Err(); err != nil { - log.Printf("Failed to cache chat data: %v", err) +func getCachedObject(ctx context.Context, key string, obj interface{}) error { + data, err := rdb.Get(ctx, key).Bytes() + if err != nil { + return err } - c.JSON(http.StatusCreated, chatResponse) + return json.Unmarshal(data, obj) } -func getChatHandler(c *gin.Context) { - chatID, err := uuid.Parse(c.Param("chatId")) +// Message handling functions +func cacheMessage(ctx context.Context, chatID uuid.UUID, msg MessageResponse) error { + msgJSON, err := json.Marshal(msg) if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid chat ID"}) - return + return fmt.Errorf("failed to marshal message: %v", err) } - ctx := context.Background() - chatKey := fmt.Sprintf("chat:%s", chatID) + cacheKey := getMessagesCacheKey(chatID) + pipe := rdb.Pipeline() + pipe.RPush(ctx, cacheKey, msgJSON) + pipe.Expire(ctx, cacheKey, RedisCacheTTL) - // Try to get chat from Redis first - chatData, err := rdb.Get(ctx, chatKey).Bytes() - if err == nil { - var chatResponse ChatResponse - if err := json.Unmarshal(chatData, &chatResponse); err == nil { - c.JSON(http.StatusOK, chatResponse) - return - } - } + _, err = pipe.Exec(ctx) + return err +} - // If not in cache, get from database - var chat Chat - if err := db.First(&chat, "id = ?", chatID).Error; err != nil { - c.JSON(http.StatusNotFound, gin.H{"error": "Chat not found"}) - return +func createAndSaveMessage(ctx context.Context, chatID uuid.UUID, role, content string) (*Message, error) { + message := Message{ + ChatID: chatID, + Role: role, + Content: content, + Timestamp: time.Now(), } - // Cache the chat data - chatResponse := ChatResponse{ - ID: chat.ID, - Name: chat.Name, - IsPrivate: chat.IsPrivate, - CreatedAt: chat.CreatedAt, + // Save to database + if err := db.Create(&message).Error; err != nil { + return nil, fmt.Errorf("failed to save message: %v", err) } - chatJSON, _ := json.Marshal(chatResponse) - if err := rdb.Set(ctx, chatKey, chatJSON, time.Hour*24).Err(); err != nil { - log.Printf("Failed to cache chat data: %v", err) + + // Cache the message + msgResponse := messageToResponse(message) + if err := cacheMessage(ctx, chatID, msgResponse); err != nil { + log.Printf("Failed to cache message: %v", err) } - c.JSON(http.StatusOK, chatResponse) + return &message, nil } -func cacheMessage(ctx context.Context, cacheKey string, msg MessageResponse) error { - msgJSON, err := json.Marshal(msg) - if err != nil { - return fmt.Errorf("failed to marshal message: %v", err) +func getChatMessages(ctx context.Context, chatID uuid.UUID) ([]Message, error) { + cacheKey := getMessagesCacheKey(chatID) + + // Try cache first + messages, err := getCachedMessages(ctx, cacheKey) + if err == nil && len(messages) > 0 { + return messages, nil } - msgPipe := rdb.Pipeline() - msgPipe.RPush(ctx, cacheKey, msgJSON) - msgPipe.Expire(ctx, cacheKey, time.Hour*24) + // Fallback to database + var dbMessages []Message + if err := db.Where("chat_id = ?", chatID).Order("created_at ASC").Find(&dbMessages).Error; err != nil { + return nil, err + } - if _, err := msgPipe.Exec(ctx); err != nil { - return fmt.Errorf("failed to cache message: %v", err) + // Cache results for next time + if err := cacheMessagesFromDB(ctx, cacheKey, dbMessages); err != nil { + log.Printf("Failed to cache messages: %v", err) } - return nil + + return dbMessages, nil } func getCachedMessages(ctx context.Context, cacheKey string) ([]Message, error) { var messages []Message - // Try to get messages from Redis cachedMsgs, err := rdb.LRange(ctx, cacheKey, 0, -1).Result() if err != nil { - return nil, fmt.Errorf("failed to get messages from cache: %v", err) + return nil, err } for _, msgStr := range cachedMsgs { var msgResponse MessageResponse if err := json.Unmarshal([]byte(msgStr), &msgResponse); err != nil { - return nil, fmt.Errorf("failed to unmarshal message: %v", err) + return nil, err } - messages = append(messages, Message{ - ID: msgResponse.ID, - ChatID: msgResponse.ChatID, - Role: msgResponse.Role, - Content: msgResponse.Content, - Timestamp: msgResponse.Timestamp, - }) + messages = append(messages, responseToMessage(msgResponse)) } return messages, nil } func cacheMessagesFromDB(ctx context.Context, cacheKey string, messages []Message) error { - msgPipe := rdb.Pipeline() - msgPipe.Del(ctx, cacheKey) + pipe := rdb.Pipeline() + pipe.Del(ctx, cacheKey) for _, msg := range messages { - msgResponse := MessageResponse{ - ID: msg.ID, - ChatID: msg.ChatID, - Role: msg.Role, - Content: msg.Content, - Timestamp: msg.Timestamp, - } - msgJSON, err := json.Marshal(msgResponse) + msgJSON, err := json.Marshal(messageToResponse(msg)) if err != nil { - log.Printf("Failed to marshal message: %v", err) continue } - msgPipe.RPush(ctx, cacheKey, msgJSON) + pipe.RPush(ctx, cacheKey, msgJSON) + } + + pipe.Expire(ctx, cacheKey, RedisCacheTTL) + _, err := pipe.Exec(ctx) + return err +} + +// Conversion helpers +func messageToResponse(msg Message) MessageResponse { + return MessageResponse{ + ID: msg.ID, + ChatID: msg.ChatID, + Role: msg.Role, + Content: msg.Content, + Timestamp: msg.Timestamp, } +} - msgPipe.Expire(ctx, cacheKey, time.Hour*24) - if _, err := msgPipe.Exec(ctx); err != nil { - return fmt.Errorf("failed to cache messages: %v", err) +func responseToMessage(resp MessageResponse) Message { + return Message{ + ID: resp.ID, + ChatID: resp.ChatID, + Role: resp.Role, + Content: resp.Content, + Timestamp: resp.Timestamp, + } +} + +func messagesToResponses(messages []Message) []MessageResponse { + responses := make([]MessageResponse, len(messages)) + for i, msg := range messages { + responses[i] = messageToResponse(msg) } - return nil + return responses } func convertToOllamaMessages(messages []Message) []OllamaMessage { @@ -261,116 +253,157 @@ func convertToOllamaMessages(messages []Message) []OllamaMessage { return ollamaMessages } -func convertMessagesToResponse(messages []Message) []MessageResponse { - response := make([]MessageResponse, len(messages)) - for i, msg := range messages { - response[i] = MessageResponse{ - ID: msg.ID, - ChatID: msg.ChatID, - Role: msg.Role, - Content: msg.Content, - Timestamp: msg.Timestamp, - } +// Ollama API helpers +func prepareOllamaRequest(messages []OllamaMessage, model string) OllamaChatRequest { + if model == "" { + model = os.Getenv("OLLAMA_DEFAULT_MODEL") + } + + return OllamaChatRequest{ + Model: model, + Messages: messages, + Stream: true, + Temperature: 0.7, } - return response } -func getChatMessagesHandler(c *gin.Context) { - chatID, err := uuid.Parse(c.Param("chatId")) +func getOllamaURL() string { + ollamaHost := os.Getenv("OLLAMA_HOST") + return fmt.Sprintf("https://%s/chat/completions", ollamaHost) +} + +func setupOllamaRequest(messages []OllamaMessage, model string) (*http.Request, error) { + ollamaReq := prepareOllamaRequest(messages, model) + + reqBody, err := json.Marshal(ollamaReq) if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid chat ID"}) - return + return nil, fmt.Errorf("failed to marshal request: %v", err) } - ctx := context.Background() - cacheKey := fmt.Sprintf("chat:%s:messages", chatID.String()) + // Log the request for debugging + log.Printf("🌐 Sending request to Ollama: %s", string(reqBody)) - // Try to get messages from cache - messages, err := getCachedMessages(ctx, cacheKey) - if err == nil && len(messages) > 0 { - c.JSON(http.StatusOK, convertMessagesToResponse(messages)) - return + // Create HTTP request + req, err := http.NewRequest("POST", getOllamaURL(), bytes.NewBuffer(reqBody)) + if err != nil { + return nil, err } - // If not in cache, get from database - var dbMessages []Message - if err := db.Where("chat_id = ?", chatID).Order("created_at DESC").Find(&dbMessages).Error; err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch messages"}) + req.Header.Set("Content-Type", "application/json") + return req, nil +} + +// Chat handlers +func listChatsHandler(c *gin.Context) { + userID := c.MustGet("userID").(uuid.UUID) + var chats []Chat + + if err := db.Where("user_id = ?", userID).Find(&chats).Error; err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch chats"}) return } - // Cache the messages from database - if err := cacheMessagesFromDB(ctx, cacheKey, dbMessages); err != nil { - log.Printf("Failed to cache messages: %v", err) + var response = make([]ChatResponse, 0) + for _, chat := range chats { + response = append(response, ChatResponse{ + ID: chat.ID, + Name: chat.Name, + IsPrivate: chat.IsPrivate, + CreatedAt: chat.CreatedAt, + }) } - c.JSON(http.StatusOK, convertMessagesToResponse(dbMessages)) + c.JSON(http.StatusOK, response) } -func prepareOllamaRequest(messages []OllamaMessage, model string) OllamaChatRequest { - if model == "" { - model = os.Getenv("OLLAMA_DEFAULT_MODEL") +func createChatHandler(c *gin.Context) { + var req CreateChatRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return } - return OllamaChatRequest{ - Model: model, - Messages: messages, - Stream: true, - Temperature: 0.7, + userID := c.MustGet("userID").(uuid.UUID) + chat := Chat{ + ID: req.ID, + Name: req.Name, + IsPrivate: req.IsPrivate, + UserID: userID, } -} -func setupStreamHeaders(c *gin.Context) { - c.Header("Content-Type", "text/plain") - c.Header("Cache-Control", "no-cache") - c.Header("Connection", "keep-alive") - c.Header("Transfer-Encoding", "chunked") + if err := db.Create(&chat).Error; err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create chat"}) + return + } + + // Cache the chat data + chatResponse := ChatResponse{ + ID: chat.ID, + Name: chat.Name, + IsPrivate: chat.IsPrivate, + CreatedAt: chat.CreatedAt, + } + + if err := cacheObject(context.Background(), getChatCacheKey(chat.ID), chatResponse); err != nil { + log.Printf("Failed to cache chat data: %v", err) + } + + c.JSON(http.StatusCreated, chatResponse) } -func logOllamaRequest(ollamaReq OllamaChatRequest) error { - reqBody, err := json.Marshal(ollamaReq) +func getChatHandler(c *gin.Context) { + chatID, err := uuid.Parse(c.Param("chatId")) if err != nil { - return fmt.Errorf("failed to marshal request: %v", err) + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid chat ID"}) + return } - log.Printf("🌐 Sending request to Ollama: %s", string(reqBody)) - return nil -} -func getOllamaURL() string { - ollamaHost := os.Getenv("OLLAMA_HOST") - ollamaURL := fmt.Sprintf("https://%s/chat/completions", ollamaHost) - log.Printf("🔗 Connecting to Ollama at: %s", ollamaURL) - return ollamaURL -} + ctx := context.Background() + cacheKey := getChatCacheKey(chatID) -func createAndCacheMessage(ctx context.Context, chatID uuid.UUID, role, content string) (*Message, error) { - message := Message{ - ChatID: chatID, - Role: role, - Content: content, - Timestamp: time.Now(), + // Try to get chat from Redis first + var chatResponse ChatResponse + if err := getCachedObject(ctx, cacheKey, &chatResponse); err == nil { + c.JSON(http.StatusOK, chatResponse) + return } - // Save to database - if err := db.Create(&message).Error; err != nil { - return nil, fmt.Errorf("failed to save message: %v", err) + // If not in cache, get from database + var chat Chat + if err := db.First(&chat, "id = ?", chatID).Error; err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "Chat not found"}) + return } - // Cache in Redis - cacheKey := fmt.Sprintf("chat:%s:messages", chatID.String()) - msgResponse := MessageResponse{ - ID: message.ID, - ChatID: message.ChatID, - Role: message.Role, - Content: message.Content, - Timestamp: message.Timestamp, + // Cache the chat data + chatResponse = ChatResponse{ + ID: chat.ID, + Name: chat.Name, + IsPrivate: chat.IsPrivate, + CreatedAt: chat.CreatedAt, } - if err := cacheMessage(ctx, cacheKey, msgResponse); err != nil { - log.Printf("Failed to cache message: %v", err) + if err := cacheObject(ctx, cacheKey, chatResponse); err != nil { + log.Printf("Failed to cache chat data: %v", err) } - return &message, nil + c.JSON(http.StatusOK, chatResponse) +} + +func getChatMessagesHandler(c *gin.Context) { + chatID, err := uuid.Parse(c.Param("chatId")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid chat ID"}) + return + } + + messages, err := getChatMessages(context.Background(), chatID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch messages"}) + return + } + + c.JSON(http.StatusOK, messagesToResponses(messages)) } func streamAIResponseHandler(c *gin.Context) { @@ -388,70 +421,61 @@ func streamAIResponseHandler(c *gin.Context) { ctx := context.Background() - cacheKey := fmt.Sprintf("chat:%s:messages", chatID.String()) - var previousMessages []Message - previousMessages, fetchErr := getCachedMessages(ctx, cacheKey) - if fetchErr != nil || len(previousMessages) == 0 { - // Fallback to database - if dbErr := db.Where("chat_id = ?", chatID).Order("created_at ASC").Find(&previousMessages).Error; dbErr != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch chat history"}) - return - } - - // Cache the messages we got from database - if cacheErr := cacheMessagesFromDB(ctx, cacheKey, previousMessages); cacheErr != nil { - log.Printf("Failed to cache messages from DB: %v", cacheErr) - } + // Get previous messages + previousMessages, err := getChatMessages(ctx, chatID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch chat history"}) + return } - // Convert previous messages to Ollama format + // Convert previous messages to Ollama format and add the current message ollamaMessages := convertToOllamaMessages(previousMessages) - - // Add the current message ollamaMessages = append(ollamaMessages, OllamaMessage{ Role: "user", Content: req.Content, }) - // Prepare and send Ollama request - ollamaReq := prepareOllamaRequest(ollamaMessages, req.Model) - setupStreamHeaders(c) + // Save the user message + if _, err := createAndSaveMessage(ctx, chatID, "user", req.Content); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to save user message"}) + return + } + + // Setup HTTP response for streaming + c.Header("Content-Type", "text/plain") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("Transfer-Encoding", "chunked") - // Create channels for response handling + // Setup channels for communication responseChan := make(chan string) errorChan := make(chan error) // Make request to Ollama API go func() { - // Log request - if err := logOllamaRequest(ollamaReq); err != nil { - errorChan <- err - return - } - - // Get Ollama URL and prepare request - ollamaURL := getOllamaURL() - reqBody, err := json.Marshal(ollamaReq) + // Create and send request + request, err := setupOllamaRequest(ollamaMessages, req.Model) if err != nil { - errorChan <- fmt.Errorf("failed to marshal request body: %v", err) + errorChan <- err return } - resp, err := http.Post(ollamaURL, "application/json", bytes.NewBuffer(reqBody)) + client := &http.Client{} + resp, err := client.Do(request) if err != nil { log.Printf("❌ Error connecting to Ollama: %v", err) errorChan <- err return } defer resp.Body.Close() + log.Printf("✅ Connected to Ollama (Status: %s)", resp.Status) - // Create a scanner to read the streaming response line by line + // Process streaming response scanner := bufio.NewScanner(resp.Body) for scanner.Scan() { // Get raw response rawResp := scanner.Text() - log.Printf("📥 Raw response: %s", rawResp) // Skip empty lines if rawResp == "" { @@ -468,6 +492,7 @@ func streamAIResponseHandler(c *gin.Context) { // Skip [DONE] message if jsonData == "[DONE]" { log.Println("✅ Received [DONE] message") + close(responseChan) return } @@ -503,12 +528,6 @@ func streamAIResponseHandler(c *gin.Context) { } }() - // Save and cache the user's message - if _, saveErr := createAndCacheMessage(ctx, chatID, "user", req.Content); saveErr != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to save user message"}) - return - } - // Stream the response var fullResponse string c.Stream(func(w io.Writer) bool { @@ -516,23 +535,20 @@ func streamAIResponseHandler(c *gin.Context) { case token, ok := <-responseChan: if !ok { // Save and cache the complete AI response - _, saveErr := createAndCacheMessage(ctx, chatID, "assistant", fullResponse) - if saveErr != nil { - errorChan <- err - return false + if _, err := createAndSaveMessage(ctx, chatID, "assistant", fullResponse); err != nil { + log.Printf("Failed to save AI response: %v", err) } - return false } - + fullResponse += token fmt.Fprint(w, token) return true - + case err := <-errorChan: fmt.Fprint(w, "Error: "+err.Error()) return false - + case <-c.Request.Context().Done(): return false } diff --git a/main.go b/main.go index 8a2e8c3..7cd826c 100644 --- a/main.go +++ b/main.go @@ -6,6 +6,9 @@ import ( "log" "net/http" "os" + "os/signal" + "strings" + "syscall" "time" "github.com/gin-contrib/cors" @@ -15,13 +18,49 @@ import ( "github.com/joho/godotenv" "gorm.io/driver/postgres" "gorm.io/gorm" + "gorm.io/gorm/logger" ) var ( - db *gorm.DB - rdb *redis.Client + db *gorm.DB + rdb *redis.Client ) +// Configuration holds all application configuration +type Config struct { + Port string + Environment string + FrontendURL string + JWTSecret string + DB DBConfig + Redis RedisConfig + Ollama OllamaConfig +} + +// DBConfig holds database configuration +type DBConfig struct { + Host string + Port string + User string + Password string + Name string + SSLMode string +} + +// RedisConfig holds Redis configuration +type RedisConfig struct { + Host string + Port string + Username string + Password string +} + +// OllamaConfig holds Ollama API configuration +type OllamaConfig struct { + Host string + DefaultModel string +} + // User represents the user model type User struct { ID uuid.UUID `gorm:"type:uuid;primary_key;default:gen_random_uuid()"` @@ -54,79 +93,138 @@ type Message struct { Timestamp time.Time } -func initDB() { - var err error - dsn := fmt.Sprintf("host=%s user=%s password=%s dbname=%s port=%s sslmode=%s", - os.Getenv("DB_HOST"), - os.Getenv("DB_USER"), - os.Getenv("DB_PASSWORD"), - os.Getenv("DB_NAME"), - os.Getenv("DB_PORT"), - os.Getenv("DB_SSLMODE"), +// loadConfig loads configuration from environment variables +func loadConfig() (*Config, error) { + config := &Config{ + Environment: getEnvWithDefault("APP_ENV", "development"), + Port: getEnvWithDefault("PORT", "8080"), + FrontendURL: getEnvWithDefault("FRONTEND_URL", "http://localhost:3000"), + JWTSecret: os.Getenv("JWT_SECRET"), + DB: DBConfig{ + Host: getEnvWithDefault("DB_HOST", "localhost"), + Port: getEnvWithDefault("DB_PORT", "5432"), + User: getEnvWithDefault("DB_USER", "postgres"), + Password: os.Getenv("DB_PASSWORD"), + Name: getEnvWithDefault("DB_NAME", "hop"), + SSLMode: getEnvWithDefault("DB_SSLMODE", "disable"), + }, + Redis: RedisConfig{ + Host: getEnvWithDefault("REDIS_HOST", "localhost"), + Port: getEnvWithDefault("REDIS_PORT", "6379"), + Username: os.Getenv("REDIS_USERNAME"), + Password: os.Getenv("REDIS_PASSWORD"), + }, + Ollama: OllamaConfig{ + Host: getEnvWithDefault("OLLAMA_HOST", "http://localhost:11434"), + DefaultModel: getEnvWithDefault("OLLAMA_DEFAULT_MODEL", "llama3.3"), + }, + } + + // Validate critical configuration + if config.JWTSecret == "" { + return nil, fmt.Errorf("JWT_SECRET environment variable is required") + } + + return config, nil +} + +// getEnvWithDefault returns environment variable value or default if not set +func getEnvWithDefault(key, defaultValue string) string { + value := os.Getenv(key) + if value == "" { + return defaultValue + } + return value +} + +// initDB initializes database connection +func initDB(config *Config) error { + dsn := fmt.Sprintf( + "host=%s user=%s password=%s dbname=%s port=%s sslmode=%s", + config.DB.Host, config.DB.User, config.DB.Password, + config.DB.Name, config.DB.Port, config.DB.SSLMode, ) - db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{}) + + logLevel := logger.Silent + if config.Environment == "development" { + logLevel = logger.Info + } + + gormConfig := &gorm.Config{ + Logger: logger.Default.LogMode(logLevel), + } + + var err error + db, err = gorm.Open(postgres.Open(dsn), gormConfig) + if err != nil { + return fmt.Errorf("failed to connect to database: %w", err) + } + + // Configure connection pool + sqlDB, err := db.DB() if err != nil { - log.Fatal("Failed to connect to database:", err) + return fmt.Errorf("failed to get database connection: %w", err) } - // Auto migrate the schema - db.AutoMigrate(&User{}, &Chat{}, &Message{}) + // Set reasonable connection pool settings + sqlDB.SetMaxIdleConns(10) + sqlDB.SetMaxOpenConns(100) + sqlDB.SetConnMaxLifetime(time.Hour) + + log.Println("✅ Successfully connected to database") + return nil } -func initRedis() { - // Create a context with timeout for Redis operations +// initRedis initializes Redis connection +func initRedis(config *Config) error { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - // Get Redis configuration from environment - redisHost := os.Getenv("REDIS_HOST") - redisPort := os.Getenv("REDIS_PORT") - redisAddr := fmt.Sprintf("%s:%s", redisHost, redisPort) - + redisAddr := fmt.Sprintf("%s:%s", config.Redis.Host, config.Redis.Port) log.Printf("Connecting to Redis at %s...", redisAddr) rdb = redis.NewClient(&redis.Options{ Addr: redisAddr, - Username: os.Getenv("REDIS_USERNAME"), - Password: os.Getenv("REDIS_PASSWORD"), + Username: config.Redis.Username, + Password: config.Redis.Password, DB: 0, - DialTimeout: 5 * time.Second, // Connection timeout - ReadTimeout: 3 * time.Second, // Read timeout - WriteTimeout: 3 * time.Second, // Write timeout - PoolTimeout: 4 * time.Second, // Pool timeout + DialTimeout: 5 * time.Second, + ReadTimeout: 3 * time.Second, + WriteTimeout: 3 * time.Second, + PoolTimeout: 4 * time.Second, + PoolSize: 50, // Maximum number of connections + MinIdleConns: 10, // Minimum number of idle connections }) // Try to ping Redis - _, err := rdb.Ping(ctx).Result() - if err != nil { + if _, err := rdb.Ping(ctx).Result(); err != nil { log.Printf("⚠️ Warning: Failed to connect to Redis: %v", err) log.Println("⚠️ Application will continue without Redis caching") - return + return nil // Non-fatal error } log.Println("✅ Successfully connected to Redis") + return nil } -func setupRouter() *gin.Engine { - r := gin.Default() +// setupRouter configures and returns the Gin router +func setupRouter(config *Config) *gin.Engine { + // Set Gin mode based on environment + if config.Environment != "development" { + gin.SetMode(gin.ReleaseMode) + } - // Configure CORS middleware - config := cors.DefaultConfig() - config.AllowOrigins = []string{os.Getenv("FRONTEND_URL")} - config.AllowMethods = []string{"GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"} - config.AllowHeaders = []string{"Origin", "Content-Type", "Accept", "Authorization"} - config.ExposeHeaders = []string{"Content-Length"} - config.AllowCredentials = true - r.Use(cors.New(config)) + r := gin.New() // Use New() instead of Default() for custom middleware + + // Add middleware + r.Use(gin.Recovery()) + r.Use(requestLoggerMiddleware()) + r.Use(setupCORS(config.FrontendURL)) // Health check endpoint - r.GET("/health", func(c *gin.Context) { - c.JSON(http.StatusOK, gin.H{ - "status": "ok", - }) - }) + r.GET("/health", healthCheckHandler) - // API routes will be added here + // API routes api := r.Group("/api") { // Auth routes @@ -150,25 +248,134 @@ func setupRouter() *gin.Engine { return r } +// requestLoggerMiddleware logs HTTP requests +func requestLoggerMiddleware() gin.HandlerFunc { + return gin.LoggerWithConfig(gin.LoggerConfig{ + Formatter: func(param gin.LogFormatterParams) string { + return fmt.Sprintf("%s - [%s] \"%s %s %s %d %s \"%s\" %s\"\n", + param.ClientIP, + param.TimeStamp.Format(time.RFC1123), + param.Method, + param.Path, + param.Request.Proto, + param.StatusCode, + param.Latency, + param.Request.UserAgent(), + param.ErrorMessage, + ) + }, + }) +} + +// setupCORS configures CORS middleware +func setupCORS(frontendURL string) gin.HandlerFunc { + config := cors.DefaultConfig() + + // Handle multiple origins if provided + origins := strings.Split(frontendURL, ",") + for i := range origins { + origins[i] = strings.TrimSpace(origins[i]) + } + + config.AllowOrigins = origins + config.AllowMethods = []string{"GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"} + config.AllowHeaders = []string{"Origin", "Content-Type", "Accept", "Authorization"} + config.ExposeHeaders = []string{"Content-Length"} + config.AllowCredentials = true + config.MaxAge = 12 * 60 * 60 // 12 hours + + return cors.New(config) +} + +// healthCheckHandler handles health check requests +func healthCheckHandler(c *gin.Context) { + // Check database connection + sqlDB, err := db.DB() + dbStatus := "ok" + if err != nil || sqlDB.Ping() != nil { + dbStatus = "error" + } + + // Check Redis connection + redisStatus := "ok" + if _, err := rdb.Ping(context.Background()).Result(); err != nil { + redisStatus = "error" + } + + c.JSON(http.StatusOK, gin.H{ + "status": "ok", + "db": dbStatus, + "redis": redisStatus, + "version": "1.0.0", + }) +} + func main() { - // Load .env file + // Load .env file in development if err := godotenv.Load(); err != nil { log.Println("Warning: .env file not found, using environment variables") } - // Initialize database connections - initDB() - initRedis() + // Load configuration + config, err := loadConfig() + if err != nil { + log.Fatalf("Failed to load configuration: %v", err) + } - // Setup and run the server - r := setupRouter() - port := os.Getenv("PORT") - if port == "" { - port = "8080" + // Initialize database connection + if err := initDB(config); err != nil { + log.Fatalf("Database initialization failed: %v", err) } - - log.Printf("Server starting on port %s", port) - if err := r.Run(":" + port); err != nil { - log.Fatal("Failed to start server:", err) + + // Initialize Redis connection + if err := initRedis(config); err != nil { + log.Fatalf("Redis initialization failed: %v", err) + } + + // Setup router + router := setupRouter(config) + + // Create server with timeouts + server := &http.Server{ + Addr: ":" + config.Port, + Handler: router, + ReadTimeout: 15 * time.Second, + WriteTimeout: 15 * time.Second, + IdleTimeout: 60 * time.Second, } + + // Start server in a goroutine + go func() { + log.Printf("Server starting on port %s in %s mode", config.Port, config.Environment) + if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { + log.Fatalf("Failed to start server: %v", err) + } + }() + + // Wait for interrupt signal + quit := make(chan os.Signal, 1) + signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) + <-quit + log.Println("Server shutting down...") + + // Create shutdown context with timeout + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + // Gracefully shutdown the server + if err := server.Shutdown(ctx); err != nil { + log.Fatalf("Server forced to shutdown: %v", err) + } + + // Close database connection + if sqlDB, err := db.DB(); err == nil { + sqlDB.Close() + } + + // Close Redis connection + if rdb != nil { + rdb.Close() + } + + log.Println("Server exited properly") }