Skip to content

Commit

Permalink
Stac 17136 (#48) SessionID is only created for Indirect clients. Impr…
Browse files Browse the repository at this point in the history
…ove test coverage.
  • Loading branch information
CMGRocha committed Jul 28, 2022
1 parent 5562116 commit a92a583
Show file tree
Hide file tree
Showing 7 changed files with 201 additions and 107 deletions.
97 changes: 53 additions & 44 deletions src/main/scala/com/stackstate/pac4j/AkkaHttpSecurity.scala
@@ -1,7 +1,6 @@
package com.stackstate.pac4j

import java.util

import akka.http.scaladsl.common.StrictForm
import akka.http.scaladsl.model.{HttpEntity, HttpHeader, HttpResponse}
import akka.http.scaladsl.server.Directives.{authorize => akkaHttpAuthorize}
Expand All @@ -19,6 +18,7 @@ import org.pac4j.core.profile.UserProfile
import akka.http.scaladsl.util.FastFuture._
import akka.stream.Materializer
import com.stackstate.pac4j.store.SessionStorage
import org.pac4j.core.matching.matcher.DefaultMatchers

import scala.collection.JavaConverters._
import scala.concurrent.{ExecutionContext, Future}
Expand Down Expand Up @@ -48,11 +48,11 @@ object AkkaHttpSecurity {
* credentials.
*/
private def getFormFields(
entity: HttpEntity,
enforceFormEncoding: Boolean,
)(implicit materializer: Materializer,
executionContext: ExecutionContext,
): Future[Seq[(String, String)]] = {
entity: HttpEntity,
enforceFormEncoding: Boolean,
)(implicit materializer: Materializer,
executionContext: ExecutionContext,
): Future[Seq[(String, String)]] = {
Unmarshal(entity)
.to[StrictForm]
.fast
Expand Down Expand Up @@ -113,48 +113,55 @@ class AkkaHttpSecurity(config: Config, sessionStorage: SessionStorage, val sessi
* an AkkaHttpWebContext and any changes to this context are applied when the route returns (e.g. headers/cookies).
*/
def withContext(existingContext: Option[AkkaHttpWebContext] = None, formParams: Map[String, String] = Map.empty): Directive1[AkkaHttpWebContext] =
Directive[Tuple1[AkkaHttpWebContext]] { inner => ctx =>
val akkaWebContext = existingContext.getOrElse(new AkkaHttpWebContext(
request = ctx.request,
formFields = formParams.toSeq,
sessionStorage = sessionStorage,
sessionCookieName = sessionCookieName,
))

inner(Tuple1(akkaWebContext))(ctx).map[RouteResult] {
case Complete(response) => Complete(applyHeadersAndCookiesToResponse(akkaWebContext.getChanges)(response))
case rejection => rejection
}
Directive[Tuple1[AkkaHttpWebContext]] { inner =>
ctx =>
val akkaWebContext = existingContext.getOrElse(new AkkaHttpWebContext(
request = ctx.request,
formFields = formParams.toSeq,
sessionStorage = sessionStorage,
sessionCookieName = sessionCookieName,
))

inner(Tuple1(akkaWebContext))(ctx).map[RouteResult] {
case Complete(response) => Complete(applyHeadersAndCookiesToResponse(akkaWebContext.getChanges)(response))
case rejection => rejection
}
}

def withFormParameters(enforceFormEncoding: Boolean): Directive1[Map[String, String]] =
Directive[Tuple1[Map[String, String]]] { inner => ctx =>
import ctx.materializer
getFormFields(ctx.request.entity, enforceFormEncoding).flatMap { params =>
inner(Tuple1(params.toMap))(ctx)
}
Directive[Tuple1[Map[String, String]]] { inner =>
ctx =>
import ctx.materializer
getFormFields(ctx.request.entity, enforceFormEncoding).flatMap { params =>
inner(Tuple1(params.toMap))(ctx)
}
}

/**
* Authenticate using the provided pac4j configuration. Delivers an AuthenticationRequest which can be used for further authorization
* this does not apply any authorization ofr filtering.
*/
@SuppressWarnings(Array("NullAssignment"))
def withAuthentication(clients: String = null /* Default null, meaning all defined clients */,
def withAuthentication(clients: String = null /* Default null, meaning all defined clients */ ,
multiProfile: Boolean = true,
authorizers: String = ""): Directive1[AuthenticatedRequest] =
withContext().flatMap { akkaWebContext =>
Directive[Tuple1[AuthenticatedRequest]] { inner => ctx =>
// TODO This is a hack to ensure that any underlying Futures are scheduled (and handled in case of errors) from here
// TODO Fix this properly
Future.successful({}).flatMap { _ =>
securityLogic.perform(akkaWebContext, config, (context: AkkaHttpWebContext, profiles: util.Collection[UserProfile], _: AnyRef) => {
val authenticatedRequest =
AuthenticatedRequest(context, profiles.asScala.toList)

inner(Tuple1(authenticatedRequest))(ctx)
}, actionAdapter, clients, authorizers, "", multiProfile)
}
Directive[Tuple1[AuthenticatedRequest]] { inner =>
ctx =>
// TODO This is a hack to ensure that any underlying Futures are scheduled (and handled in case of errors) from here
// TODO Fix this properly
Future.successful({}).flatMap { _ =>
val securityAccessAdapter: SecurityGrantedAccessAdapter[Future[RouteResult], AkkaHttpWebContext] =
(context: AkkaHttpWebContext, profiles: util.Collection[UserProfile], _: AnyRef) => {
val authenticatedRequest = AuthenticatedRequest(context, profiles.asScala.toList)
inner(Tuple1(authenticatedRequest))(ctx)
}

securityLogic.perform(akkaWebContext, config, securityAccessAdapter, actionAdapter, clients, authorizers, DefaultMatchers.SECURITYHEADERS, multiProfile).map {
result =>
result
}
}
}
}

Expand All @@ -169,13 +176,14 @@ class AkkaHttpSecurity(config: Config, sessionStorage: SessionStorage, val sessi
existingContext: Option[AkkaHttpWebContext] = None,
setCsrfCookie: Boolean = true): Route =
withFormParameters(enforceFormEncoding) { formParams =>
withContext(existingContext, formParams) { akkaWebContext => _ =>
callbackLogic.perform(akkaWebContext, config, actionAdapter, defaultUrl, saveInSession, multiProfile, true, defaultClient.orNull).map {
result =>
akkaWebContext.addResponseSessionCookie()
if (setCsrfCookie) akkaWebContext.addResponseCsrfCookie()
result
}
withContext(existingContext, formParams) { akkaWebContext =>
_ =>
callbackLogic.perform(akkaWebContext, config, actionAdapter, defaultUrl, saveInSession, multiProfile, true, defaultClient.orNull).map {
result =>
akkaWebContext.addResponseSessionCookie()
if (setCsrfCookie) akkaWebContext.addResponseCsrfCookie()
result
}
}
}

Expand All @@ -184,8 +192,9 @@ class AkkaHttpSecurity(config: Config, sessionStorage: SessionStorage, val sessi
localLogout: Boolean = true,
destroySession: Boolean = true,
centralLogout: Boolean = false): Route = {
withContext() { akkaWebContext => _ =>
logoutLogic.perform(akkaWebContext, config, actionAdapter, defaultUrl, logoutPatternUrl, localLogout, destroySession, centralLogout)
withContext() { akkaWebContext =>
_ =>
logoutLogic.perform(akkaWebContext, config, actionAdapter, defaultUrl, logoutPatternUrl, localLogout, destroySession, centralLogout)
}
}

Expand Down
71 changes: 39 additions & 32 deletions src/main/scala/com/stackstate/pac4j/AkkaHttpWebContext.scala
@@ -1,15 +1,12 @@
package com.stackstate.pac4j

import java.util.{Optional, UUID}

import akka.http.scaladsl.model.HttpHeader.ParsingResult.{Error, Ok}
import akka.http.scaladsl.model.headers.HttpCookie
import akka.http.scaladsl.model.{ContentType, HttpHeader, HttpRequest}

import com.stackstate.pac4j.authorizer.CsrfCookieAuthorizer
import com.stackstate.pac4j.http.AkkaHttpSessionStore
import com.stackstate.pac4j.store.SessionStorage

import org.pac4j.core.context.{Cookie, WebContext}

import compat.java8.OptionConverters._
Expand All @@ -24,7 +21,8 @@ class AkkaHttpWebContext(val request: HttpRequest,
val formFields: Seq[(String, String)],
private[pac4j] val sessionStorage: SessionStorage,
val sessionCookieName: String)
extends WebContext {
extends WebContext {

import com.stackstate.pac4j.AkkaHttpWebContext._

private var changes = ResponseChanges.empty
Expand All @@ -38,28 +36,37 @@ class AkkaHttpWebContext(val request: HttpRequest,
private lazy val requestParameters =
formFields.toMap ++ request.getUri().query().toMap.asScala

private def newSession() = {
val sessionId = UUID.randomUUID().toString
sessionStorage.createSessionIfNeeded(sessionId)
sessionId
private var sessionId: Option[String] = request.cookies
.find(_.name == sessionCookieName)
.map(_.value)

def getOrCreateSessionId(): String = {
val newSession =
sessionId
.find(session => sessionStorage.sessionExists(session))
.getOrElse {
val sessionId = UUID.randomUUID().toString
sessionStorage.createSessionIfNeeded(sessionId)
sessionId
}

sessionId = Some(newSession)
newSession
}

private[pac4j] var sessionId: String =
request.cookies
.filter(_.name == sessionCookieName) // TODO: collect
.map(_.value)
.find(session => sessionStorage.sessionExists(session))
.getOrElse(newSession())
def getSessionId: Option[String] = {
sessionId
}

private[pac4j] def destroySession() = {
sessionStorage.destroySession(sessionId)
sessionId = newSession()
private[pac4j] def destroySession(): Boolean = {
sessionId.foreach(sessionStorage.destroySession)
sessionId = None
true
}

private[pac4j] def trackSession(session: String) = {
sessionStorage.createSessionIfNeeded(session)
sessionId = session
sessionId = Some(session)
true
}

Expand Down Expand Up @@ -118,7 +125,7 @@ class AkkaHttpWebContext(val request: HttpRequest,
changes = changes.copy(contentType = Some(ct))

case Left(_) =>
throw new IllegalArgumentException(s"Invalid response content type ${contentType}")
throw new IllegalArgumentException(s"Invalid response content type $contentType")
}
}

Expand Down Expand Up @@ -153,14 +160,14 @@ class AkkaHttpWebContext(val request: HttpRequest,
def getChanges: ResponseChanges = changes

def addResponseSessionCookie(): Unit = {
val cookie = new Cookie(sessionCookieName, sessionId)

cookie.setSecure(isSecure)
cookie.setMaxAge(sessionStorage.sessionLifetime.toSeconds.toInt)
cookie.setHttpOnly(true)
cookie.setPath("/")

addResponseCookie(cookie)
getSessionId.foreach { sessionId =>
val cookie = new Cookie(sessionCookieName, sessionId)
cookie.setSecure(isSecure)
cookie.setMaxAge(sessionStorage.sessionLifetime.toSeconds.toInt)
cookie.setHttpOnly(true)
cookie.setPath("/")
addResponseCookie(cookie)
}
}

def sessionCookieIsValid(): Boolean = {
Expand All @@ -179,11 +186,11 @@ object AkkaHttpWebContext {
new AkkaHttpWebContext(request, formFields, sessionStorage, sessionCookieName)

//This class is where all the HTTP response changes are stored so that they can later be applied to an HTTP Request
case class ResponseChanges private (val headers: List[HttpHeader],
val contentType: Option[ContentType],
val content: String,
val cookies: List[HttpCookie],
val attributes: Map[String, AnyRef])
case class ResponseChanges private(headers: List[HttpHeader],
contentType: Option[ContentType],
content: String,
cookies: List[HttpCookie],
attributes: Map[String, AnyRef])

object ResponseChanges {
def empty: ResponseChanges = {
Expand Down
Expand Up @@ -2,24 +2,27 @@ package com.stackstate.pac4j.http

import java.util.Optional
import compat.java8.OptionConverters._

import com.stackstate.pac4j.AkkaHttpWebContext
import org.pac4j.core.context.session.SessionStore

class AkkaHttpSessionStore() extends SessionStore[AkkaHttpWebContext] {
override def getOrCreateSessionId(context: AkkaHttpWebContext): String = context.sessionId
override def getOrCreateSessionId(context: AkkaHttpWebContext): String = context.getOrCreateSessionId()

override def get(context: AkkaHttpWebContext, key: String): Optional[Object] =
context.sessionStorage.getSessionValue(context.sessionId, key).asJava
context.getSessionId match {
case Some(value) => context.sessionStorage.getSessionValue(value, key).asJava
case None => Optional.empty()
}

override def set(context: AkkaHttpWebContext, key: String, value: scala.AnyRef): Unit = {
context.sessionStorage.setSessionValue(context.sessionId, key, value)
context.sessionStorage.setSessionValue(context.getOrCreateSessionId(), key, value)
()
}

override def destroySession(context: AkkaHttpWebContext): Boolean = context.destroySession()

override def getTrackableSession(context: AkkaHttpWebContext): Optional[AnyRef] = Optional.ofNullable(context.sessionId)
override def getTrackableSession(context: AkkaHttpWebContext): Optional[AnyRef] =
context.getSessionId.asInstanceOf[Option[AnyRef]].asJava

override def buildFromTrackableSession(context: AkkaHttpWebContext, trackableSession: scala.Any): Optional[SessionStore[AkkaHttpWebContext]] = {
trackableSession match {
Expand Down
Expand Up @@ -31,11 +31,23 @@ class AkkaHttpActionAdapterTest extends AnyWordSpecLike with Matchers with Scala
HttpEntity(ContentTypes.`application/octet-stream`, ByteString(""))
)
}
"convert 401 to Unauthorized" in withContext { context =>
"convert 401 to Unauthorized for direct clients" in withContext { context =>
AkkaHttpActionAdapter.adapt(UnauthorizedAction.INSTANCE, context).futureValue.response shouldEqual HttpResponse(Unauthorized)
context.getChanges.cookies.map(_.name) shouldBe List()
}
"convert 401 to Unauthorized for indirect clients" in withContext { context =>
context.getSessionStore.set(context, "state", "foo")
AkkaHttpActionAdapter.adapt(UnauthorizedAction.INSTANCE, context).futureValue.response shouldEqual HttpResponse(Unauthorized)
context.getChanges.cookies.map(_.name) shouldBe List(AkkaHttpWebContext.DEFAULT_COOKIE_NAME)
}
"convert 302 to SeeOther (to support login flow)" in withContext { context =>
"convert 302 to SeeOther (to support login flow) (direct client)" in withContext { context =>
val r = AkkaHttpActionAdapter.adapt(new FoundAction("/login"), context).futureValue.response
r.status shouldEqual SeeOther
r.headers.head.value() shouldEqual "/login"
context.getChanges.cookies.map(_.name) shouldBe List()
}
"convert 302 to SeeOther (to support login flow) (indirect client)" in withContext { context =>
context.getSessionStore.set(context, "state", "foo")
val r = AkkaHttpActionAdapter.adapt(new FoundAction("/login"), context).futureValue.response
r.status shouldEqual SeeOther
r.headers.head.value() shouldEqual "/login"
Expand Down

0 comments on commit a92a583

Please sign in to comment.