Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 135 additions & 44 deletions auth.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
package main

import (
"errors"
"fmt"
"log"
"net/http"
"os"
"strings"
"time"

"github.com/gin-gonic/gin"
Expand All @@ -11,8 +15,16 @@ import (
"golang.org/x/crypto/bcrypt"
)

const (
// TokenExpiration defines how long a JWT token is valid
TokenExpiration = 24 * time.Hour
// MinPasswordLength defines minimum password length
MinPasswordLength = 8
)

// Authentication-related request and response structures
type RegisterRequest struct {
Username string `json:"username" binding:"required"`
Username string `json:"username" binding:"required,min=3,max=50"`
Email string `json:"email" binding:"required,email"`
Password string `json:"password" binding:"required,min=6"`
}
Expand All @@ -23,30 +35,63 @@ type LoginRequest struct {
}

type AuthResponse struct {
Token string `json:"token"`
User User `json:"user"`
Token string `json:"token"`
ExpiresAt time.Time `json:"expiresAt"`
User UserDTO `json:"user"`
}

var jwtSecret = os.Getenv("JWT_SECRET") // JWT secret from environment variable
// UserDTO is a Data Transfer Object for User information
type UserDTO struct {
ID uuid.UUID `json:"id"`
Username string `json:"username"`
Email string `json:"email"`
}

// UserClaims defines the claims in JWT token
type UserClaims struct {
UserID uuid.UUID `json:"userId"`
jwt.RegisteredClaims
}

// convertToUserDTO converts User model to UserDTO
func convertToUserDTO(user User) UserDTO {
return UserDTO{
ID: user.ID,
Username: user.Username,
Email: user.Email,
}
}

// registerHandler handles user registration
func registerHandler(c *gin.Context) {
var req RegisterRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid input", "details": err.Error()})
return
}

// Validate password strength
if err := validatePassword(req.Password); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}

// Normalize input
req.Username = strings.TrimSpace(req.Username)
req.Email = strings.TrimSpace(strings.ToLower(req.Email))

// Check if user exists
var existingUser User
if result := db.Where("username = ? OR email = ?", req.Username, req.Email).First(&existingUser); result.Error == nil {
c.JSON(http.StatusConflict, gin.H{"error": "Username or email already exists"})
return
}

// Hash password
// Hash password with appropriate cost
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to hash password"})
log.Printf("Failed to hash password: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to process registration"})
return
}

Expand All @@ -58,66 +103,118 @@ func registerHandler(c *gin.Context) {
}

if result := db.Create(&user); result.Error != nil {
log.Printf("Failed to create user: %v", result.Error)
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create user"})
return
}

// Generate JWT token
token, err := generateJWT(user)
token, expiresAt, err := generateJWT(user)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate token"})
log.Printf("Failed to generate token: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate authentication token"})
return
}

c.JSON(http.StatusCreated, AuthResponse{
Token: token,
User: user,
Token: token,
ExpiresAt: expiresAt,
User: convertToUserDTO(user),
})
}

// loginHandler handles user login
func loginHandler(c *gin.Context) {
var req LoginRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid input", "details": err.Error()})
return
}

// Find user
var user User
if result := db.Where("username = ?", req.Username).First(&user); result.Error != nil {
// Use same error message for username not found and password mismatch
// to avoid leaking information about existing usernames
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid credentials"})
return
}

// Check password
// Check password with constant-time comparison
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.Password)); err != nil {
// Log failed login attempts for security monitoring
log.Printf("Failed login attempt for user %s from IP %s", req.Username, c.ClientIP())
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid credentials"})
return
}

// Generate JWT token
token, err := generateJWT(user)
token, expiresAt, err := generateJWT(user)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate token"})
log.Printf("Failed to generate token: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate authentication token"})
return
}

c.JSON(http.StatusOK, AuthResponse{
Token: token,
User: user,
Token: token,
ExpiresAt: expiresAt,
User: convertToUserDTO(user),
})
}

func generateJWT(user User) (string, error) {
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"sub": user.ID.String(),
"exp": time.Now().Add(time.Hour * 24).Unix(),
})
// validatePassword checks password strength
func validatePassword(password string) error {
if len(password) < MinPasswordLength {
return fmt.Errorf("password must be at least %d characters long", MinPasswordLength)
}

// Check for at least one number
hasNumber := false
for _, char := range password {
if char >= '0' && char <= '9' {
hasNumber = true
break
}
}

if !hasNumber {
return errors.New("password must contain at least one number")
}

return token.SignedString([]byte(jwtSecret))
return nil
}

// Middleware to authenticate requests
// generateJWT creates a new JWT token for the given user
func generateJWT(user User) (string, time.Time, error) {
jwtSecret := []byte(os.Getenv("JWT_SECRET"))
if len(jwtSecret) == 0 {
return "", time.Time{}, errors.New("JWT_SECRET environment variable not set")
}

expiresAt := time.Now().Add(TokenExpiration)

claims := UserClaims{
UserID: user.ID,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(expiresAt),
IssuedAt: jwt.NewNumericDate(time.Now()),
NotBefore: jwt.NewNumericDate(time.Now()),
Issuer: "hop-backend",
Subject: user.ID.String(),
},
}

token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
signedToken, err := token.SignedString(jwtSecret)
if err != nil {
return "", time.Time{}, fmt.Errorf("failed to sign token: %w", err)
}

return signedToken, expiresAt, nil
}

// authMiddleware verifies JWT token and sets user ID in context
func authMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
tokenString := c.GetHeader("Authorization")
Expand All @@ -132,38 +229,32 @@ func authMiddleware() gin.HandlerFunc {
tokenString = tokenString[7:]
}

token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
return []byte(jwtSecret), nil
// Parse the JWT token
claims := &UserClaims{}
token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) {
// Validate the signing algorithm
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return []byte(os.Getenv("JWT_SECRET")), nil
})

if err != nil || !token.Valid {
// Handle token parsing errors
if err != nil {
log.Printf("Token validation error: %v", err)
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token"})
c.Abort()
return
}

claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token claims"})
c.Abort()
return
}

userIDStr, ok := claims["sub"].(string)
if !ok {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid user ID in token"})
c.Abort()
return
}

userID, err := uuid.Parse(userIDStr)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid user ID format"})
if !token.Valid {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token"})
c.Abort()
return
}

c.Set("userID", userID)
// Set the user ID in the context
c.Set("userID", claims.UserID)
c.Next()
}
}
Loading