diff --git a/src/main/java/com/iemr/common/constant/Constants.java b/src/main/java/com/iemr/common/constant/Constants.java index c6f98e02..ebe7d772 100644 --- a/src/main/java/com/iemr/common/constant/Constants.java +++ b/src/main/java/com/iemr/common/constant/Constants.java @@ -11,5 +11,6 @@ public class Constants { public static final String HOLD = "Hold"; public static final String NOT_READY = "Not Ready"; public static final String AUX = "Aux"; + public static final String JWT_TOKEN = "Jwttoken"; } diff --git a/src/main/java/com/iemr/common/controller/users/IEMRAdminController.java b/src/main/java/com/iemr/common/controller/users/IEMRAdminController.java index 072bf88d..f46b9653 100644 --- a/src/main/java/com/iemr/common/controller/users/IEMRAdminController.java +++ b/src/main/java/com/iemr/common/controller/users/IEMRAdminController.java @@ -34,6 +34,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Value; import org.springframework.data.redis.core.RedisTemplate; import org.springframework.http.HttpStatus; import org.springframework.http.ResponseEntity; @@ -44,6 +45,7 @@ import com.google.gson.JsonObject; import com.google.gson.JsonParser; import com.iemr.common.config.encryption.SecurePassword; +import com.iemr.common.constant.Constants; import com.iemr.common.data.users.LoginSecurityQuestions; import com.iemr.common.data.users.M_Role; import com.iemr.common.data.users.ServiceRoleScreenMapping; @@ -56,6 +58,7 @@ import com.iemr.common.service.users.IEMRAdminUserService; import com.iemr.common.utils.CookieUtil; import com.iemr.common.utils.JwtUtil; +import com.iemr.common.utils.TokenDenylist; import com.iemr.common.utils.encryption.AESUtil; import com.iemr.common.utils.exception.IEMRException; import com.iemr.common.utils.mapper.InputMapper; @@ -80,6 +83,8 @@ public class IEMRAdminController { @Autowired private JwtUtil jwtUtil; @Autowired + private TokenDenylist tokenDenylist; + @Autowired private CookieUtil cookieUtil; @Autowired private RedisTemplate redisTemplate; @@ -923,8 +928,6 @@ private void deleteSessionObject(String key) { } } - - @CrossOrigin() @Operation(summary = "Force log out") @RequestMapping(value = "/forceLogout", method = RequestMethod.POST, produces = MediaType.APPLICATION_JSON, headers = "Authorization") @@ -934,8 +937,27 @@ public String forceLogout(@RequestBody ForceLogoutRequestModel request, HttpServ // Perform the force logout logic iemrAdminUserServiceImpl.forceLogout(request); - // Extract and invalidate JWT token cookie dynamically from the request - invalidateJwtCookie(httpRequest, response); + // Extract token from cookies or headers + String token = getJwtTokenFromCookies(httpRequest); + if (token == null) { + response.setStatus(HttpServletResponse.SC_BAD_REQUEST); + outputResponse.setError(new RuntimeException("No JWT token found in request")); + return outputResponse.toString(); + } + + // Validate the token: Check if it is expired or in the deny list + Claims claims = jwtUtil.validateToken(token); + if (claims.isEmpty() || claims.getExpiration() == null || claims.getId() == null) { // If token is either expired or in the deny list, return 401 Unauthorized + response.setStatus(HttpServletResponse.SC_UNAUTHORIZED); + outputResponse.setError(new RuntimeException("Token is expired or has been logged out")); + return outputResponse.toString(); + } + + // Extract the jti (JWT ID) and expiration time from the validated claims + String jti = claims.getId(); // jti is in the 'id' field of claims + long expirationTime = claims.getExpiration().getTime(); // Use expiration from claims + long ttlMillis = expirationTime - System.currentTimeMillis(); + tokenDenylist.addTokenToDenylist(jti, ttlMillis); // Set the response message outputResponse.setResponse("Success"); @@ -944,31 +966,17 @@ public String forceLogout(@RequestBody ForceLogoutRequestModel request, HttpServ } return outputResponse.toString(); } - - private void invalidateJwtCookie(HttpServletRequest request, HttpServletResponse response) { - // Get the cookies from the incoming request - Cookie[] cookies = request.getCookies(); + private String getJwtTokenFromCookies(HttpServletRequest request) { + Cookie[] cookies = request.getCookies(); if (cookies != null) { for (Cookie cookie : cookies) { - // Check if the cookie name matches "Jwttoken" (case-sensitive) - if (cookie.getName().equalsIgnoreCase("Jwttoken")) { - // Invalidate the JWT token cookie by setting the value to null and max age to 0 - cookie.setValue(null); - cookie.setMaxAge(0); // Expire the cookie immediately - cookie.setPath(cookie.getPath()); // Ensure the path matches the cookie's original path - cookie.setHttpOnly(true); // Secure the cookie so it can't be accessed via JS - cookie.setSecure(true); // Only send over HTTPS if you're using secure connections - cookie.setAttribute("SameSite", "Strict"); - // Add the invalidated cookie back to the response - response.addCookie(cookie); - break; // If we found the JWT cookie, no need to continue looping + if (cookie.getName().equalsIgnoreCase(Constants.JWT_TOKEN)) { + return cookie.getValue(); } } - } else { - // Log or handle the case when no cookies are found in the request - logger.warn("No cookies found in the request."); } + return null; } diff --git a/src/main/java/com/iemr/common/utils/JwtUtil.java b/src/main/java/com/iemr/common/utils/JwtUtil.java index 56e49549..0a8829dc 100644 --- a/src/main/java/com/iemr/common/utils/JwtUtil.java +++ b/src/main/java/com/iemr/common/utils/JwtUtil.java @@ -1,105 +1,177 @@ package com.iemr.common.utils; -import java.util.Date; -import java.util.UUID; -import java.util.function.Function; +import io.jsonwebtoken.*; +import io.jsonwebtoken.security.Keys; +import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Component; -import io.jsonwebtoken.Claims; -import io.jsonwebtoken.ExpiredJwtException; -import io.jsonwebtoken.Jwts; -import io.jsonwebtoken.MalformedJwtException; -import io.jsonwebtoken.UnsupportedJwtException; -import io.jsonwebtoken.security.Keys; -import io.jsonwebtoken.security.SignatureException; - import javax.crypto.SecretKey; +import java.util.Date; +import java.util.UUID; +import java.util.function.Function; @Component public class JwtUtil { - @Value("${jwt.secret}") - private String SECRET_KEY; - - @Value("${jwt.access.expiration}") - private long ACCESS_EXPIRATION_TIME; - - @Value("${jwt.refresh.expiration}") - private long REFRESH_EXPIRATION_TIME; - - private SecretKey getSigningKey() { - if (SECRET_KEY == null || SECRET_KEY.isEmpty()) { - throw new IllegalStateException("JWT secret key is not set in application.properties"); - } - return Keys.hmacShaKeyFor(SECRET_KEY.getBytes()); - } - - public String generateToken(String username, String userId) { - return buildToken(username, userId, "access", ACCESS_EXPIRATION_TIME); - } - - public String generateRefreshToken(String username, String userId) { - return buildToken(username, userId, "refresh", REFRESH_EXPIRATION_TIME); - } - - private String buildToken(String username, String userId, String tokenType, long expiration) { - return Jwts.builder() - .subject(username) - .claim("userId", userId) - .claim("token_type", tokenType) - .id(UUID.randomUUID().toString()) - .issuedAt(new Date()) - .expiration(new Date(System.currentTimeMillis() + expiration)) - .signWith(getSigningKey()) - .compact(); - } - - public Claims validateToken(String token) { - try { - return Jwts.parser() - .verifyWith(getSigningKey()) - .build() - .parseSignedClaims(token) - .getPayload(); - - } catch (ExpiredJwtException ex) { - // Handle expired token specifically if needed - } catch (UnsupportedJwtException | MalformedJwtException | SignatureException | IllegalArgumentException ex) { - // Log specific error types - } - return null; - } - - public T getClaimFromToken(String token, Function claimsResolver) { - final Claims claims = getAllClaimsFromToken(token); - return claimsResolver.apply(claims); - } - - public Claims getAllClaimsFromToken(String token) { - return Jwts.parser() + @Value("${jwt.secret}") + private String SECRET_KEY; + + @Value("${jwt.access.expiration}") + private long ACCESS_EXPIRATION_TIME; + + @Value("${jwt.refresh.expiration}") + private long REFRESH_EXPIRATION_TIME; + + @Autowired + private TokenDenylist tokenDenylist; + + private SecretKey getSigningKey() { + if (SECRET_KEY == null || SECRET_KEY.isEmpty()) { + throw new IllegalStateException("JWT secret key is not set in application.properties"); + } + return Keys.hmacShaKeyFor(SECRET_KEY.getBytes()); + } + + /** + * Generate an access token. + * + * @param username the username of the user + * @param userId the user ID + * @return the generated JWT access token + */ + public String generateToken(String username, String userId) { + return buildToken(username, userId, "access", ACCESS_EXPIRATION_TIME); + } + + /** + * Generate a refresh token. + * + * @param username the username of the user + * @param userId the user ID + * @return the generated JWT refresh token + */ + public String generateRefreshToken(String username, String userId) { + return buildToken(username, userId, "refresh", REFRESH_EXPIRATION_TIME); + } + + /** + * Build a JWT token with the specified parameters. + * + * @param username the username of the user + * @param userId the user ID + * @param tokenType the type of the token (access or refresh) + * @param expiration the expiration time of the token in milliseconds + * @return the generated JWT token + */ + private String buildToken(String username, String userId, String tokenType, long expiration) { + return Jwts.builder() + .subject(username) + .claim("userId", userId) + .claim("token_type", tokenType) + .id(UUID.randomUUID().toString()) + .issuedAt(new Date()) + .expiration(new Date(System.currentTimeMillis() + expiration)) + .signWith(getSigningKey()) + .compact(); + } + + /** + * Validate the JWT token, checking if it is expired and if it's blacklisted + * @param token the JWT token + * @return Claims if valid, null if invalid (expired or denylisted) + */ + public Claims validateToken(String token) { + // Check if the token is blacklisted (invalidated by force logout) + if (tokenDenylist.isTokenDenylisted(getJtiFromToken(token))) { + return null; // Token is denylisted, so return null + } + + // Check if the token is expired + if (isTokenExpired(token)) { + return null; // Token is expired, so return null + } + + // If token is not blacklisted and not expired, verify the token signature and return claims + try { + return Jwts.parser().verifyWith(getSigningKey()).build().parseSignedClaims(token).getPayload(); + } catch (ExpiredJwtException ex) { + + return null; // Token is expired, so return null + } catch (UnsupportedJwtException | MalformedJwtException | SignatureException | IllegalArgumentException ex) { + return null; // Return null for any other JWT-related issue (invalid format, bad signature, etc.) + } + } + + /** + * Check if the JWT token is expired + * @param token the JWT token + * @return true if expired, false otherwise + */ + private boolean isTokenExpired(String token) { + Date expirationDate = getAllClaimsFromToken(token).getExpiration(); + return expirationDate.before(new Date()); + } + + /** + * Extract claims from the token + * @param token the JWT token + * @return all claims from the token + */ + public Claims getAllClaimsFromToken(String token) { + return Jwts.parser() .verifyWith(getSigningKey()) .build() .parseSignedClaims(token) .getPayload(); - } - - - public long getRefreshTokenExpiration() { - return REFRESH_EXPIRATION_TIME; - } - - public String getUserIdFromToken(String token) { - return getAllClaimsFromToken(token).get("userId", String.class); - } - - // Additional helper methods - public String getJtiFromToken(String token) { - return getAllClaimsFromToken(token).getId(); - } - public String getUsernameFromToken(String token) { - return getAllClaimsFromToken(token).getSubject(); - } -} \ No newline at end of file + } + + /** + * Extract a specific claim from the token using a function + * @param token the JWT token + * @param claimsResolver the function to extract the claim + * @param the type of the claim + * @return the extracted claim + */ + public T getClaimFromToken(String token, Function claimsResolver) { + final Claims claims = getAllClaimsFromToken(token); + return claimsResolver.apply(claims); + } + + /** + * Get the JWT ID (JTI) from the token + * @param token the JWT token + * @return the JWT ID + */ + public String getJtiFromToken(String token) { + return getAllClaimsFromToken(token).getId(); + } + + /** + * Get the username from the token + * @param token the JWT token + * @return the username + */ + public String getUsernameFromToken(String token) { + return getAllClaimsFromToken(token).getSubject(); + } + + /** + * Get the user ID from the token + * @param token the JWT token + * @return the user ID + */ + public String getUserIdFromToken(String token) { + return getAllClaimsFromToken(token).get("userId", String.class); + } + + /** + * Get the expiration time of the refresh token + * @return the expiration time in milliseconds + */ + public long getRefreshTokenExpiration() { + return REFRESH_EXPIRATION_TIME; + } +} diff --git a/src/main/java/com/iemr/common/utils/TokenDenylist.java b/src/main/java/com/iemr/common/utils/TokenDenylist.java new file mode 100644 index 00000000..660499af --- /dev/null +++ b/src/main/java/com/iemr/common/utils/TokenDenylist.java @@ -0,0 +1,63 @@ +package com.iemr.common.utils; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.data.redis.core.RedisTemplate; +import org.springframework.stereotype.Component; + +import java.util.concurrent.TimeUnit; + +@Component +public class TokenDenylist { + private final Logger logger = LoggerFactory.getLogger(this.getClass().getName()); + + private static final String PREFIX = "denied_"; + + @Autowired + private RedisTemplate redisTemplate; + + private String getKey(String jti) { + return PREFIX + jti; + } + + // Add a token's jti to the denylist with expiration time + public void addTokenToDenylist(String jti, Long expirationTime) { + if (jti == null || jti.trim().isEmpty()) { + return; + } + if (expirationTime == null || expirationTime <= 0) { + throw new IllegalArgumentException("Expiration time must be positive"); + } + + try { + String key = getKey(jti); // Use helper method to get the key + redisTemplate.opsForValue().set(key, " ", expirationTime, TimeUnit.MILLISECONDS); + } catch (Exception e) { + throw new RuntimeException("Failed to denylist token", e); + } + } + + // Check if a token's jti is in the denylist + public boolean isTokenDenylisted(String jti) { + if (jti == null || jti.trim().isEmpty()) { + return false; + } + try { + String key = getKey(jti); // Use helper method to get the key + return Boolean.TRUE.equals(redisTemplate.hasKey(key)); + } catch (Exception e) { + logger.error("Failed to check denylist status for jti: " + jti, e); + // In case of Redis failure, consider the token as not denylisted to avoid blocking all requests + return false; + } + } + + // Remove a token's jti from the denylist (Redis) + public void removeTokenFromDenylist(String jti) { + if (jti != null && !jti.trim().isEmpty()) { + String key = getKey(jti); // Use helper method to get the key + redisTemplate.delete(key); + } + } +}