Skip to content

Commit

Permalink
Update GH config
Browse files Browse the repository at this point in the history
  • Loading branch information
cchantep committed Mar 18, 2021
1 parent 0be0bbb commit 9adab7d
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 53 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/scala.yml
Expand Up @@ -16,4 +16,4 @@ jobs:
with:
java-version: 1.8
- name: Run tests
run: sbt test
run: sbt +testOnly
2 changes: 1 addition & 1 deletion build.sbt
Expand Up @@ -4,7 +4,7 @@ inThisBuild(List(
organization := "com.stackstate",
scalaVersion := "2.12.13",
crossScalaVersions := Seq(scalaVersion.value, "2.13.5"),
version := "0.6.1",
version := "0.6.2-SNAPSHOT",
libraryDependencies ++= {
// Silencer
val silencerVersion = "1.7.3"
Expand Down
4 changes: 2 additions & 2 deletions project/Dependencies.scala
Expand Up @@ -2,8 +2,8 @@ import sbt._

object Dependencies {
val scalacheckVersion = "1.14.3"
val akkaHttpVersion = "10.1.9"
val akkaStreamsVersion = "2.5.25"
val akkaHttpVersion = "10.2.0"
val akkaStreamsVersion = "2.6.9"
val pac4jVersion = "4.4.0"
val scalaTestVersion = "3.2.2"

Expand Down
25 changes: 18 additions & 7 deletions src/main/scala/com/stackstate/pac4j/AkkaHttpSecurity.scala
Expand Up @@ -47,15 +47,19 @@ object AkkaHttpSecurity {
* If the request proceeds, other ways (e.g. basic auth) are assumed to be configured in pac4j in order to pass
* credentials.
*/
private def getFormFields(entity: HttpEntity, enforceFormEncoding: Boolean)(implicit materializer: Materializer,
executionContext: ExecutionContext): Future[Seq[(String, String)]] = {
private def getFormFields(
entity: HttpEntity,
enforceFormEncoding: Boolean,
)(implicit materializer: Materializer,
executionContext: ExecutionContext,
): Future[Seq[(String, String)]] = {
Unmarshal(entity)
.to[StrictForm]
.fast
.flatMap { form =>
val fields = form.fields.collect {
case (name, field) if name.nonEmpty =>
Unmarshal(field).to[String].map(fieldString => (name, fieldString))
Unmarshal(field).to[String].map(fieldString => name -> fieldString)
}
Future.sequence(fields)
}
Expand Down Expand Up @@ -110,8 +114,13 @@ class AkkaHttpSecurity(config: Config, sessionStorage: SessionStorage, val sessi
*/
def withContext(existingContext: Option[AkkaHttpWebContext] = None, formParams: Map[String, String] = Map.empty): Directive1[AkkaHttpWebContext] =
Directive[Tuple1[AkkaHttpWebContext]] { inner => ctx =>
val akkaWebContext =
existingContext.getOrElse(AkkaHttpWebContext(ctx.request, formParams.toSeq, sessionStorage, sessionCookieName = sessionCookieName))
val akkaWebContext = existingContext.getOrElse(new AkkaHttpWebContext(
request = ctx.request,
formFields = formParams.toSeq,
sessionStorage = sessionStorage,
sessionCookieName = sessionCookieName,
))

inner(Tuple1(akkaWebContext))(ctx).map[RouteResult] {
case Complete(response) => Complete(applyHeadersAndCookiesToResponse(akkaWebContext.getChanges)(response))
case rejection => rejection
Expand All @@ -138,9 +147,11 @@ class AkkaHttpSecurity(config: Config, sessionStorage: SessionStorage, val sessi
Directive[Tuple1[AuthenticatedRequest]] { inner => ctx =>
// TODO This is a hack to ensure that any underlying Futures are scheduled (and handled in case of errors) from here
// TODO Fix this properly
Future.successful(()).flatMap { _ =>
Future.successful({}).flatMap { _ =>
securityLogic.perform(akkaWebContext, config, (context: AkkaHttpWebContext, profiles: util.Collection[UserProfile], _: AnyRef) => {
val authenticatedRequest = AuthenticatedRequest(context, profiles.asScala.toList)
val authenticatedRequest =
AuthenticatedRequest(context, profiles.asScala.toList)

inner(Tuple1(authenticatedRequest))(ctx)
}, actionAdapter, clients, authorizers, "", multiProfile)
}
Expand Down
87 changes: 46 additions & 41 deletions src/main/scala/com/stackstate/pac4j/AkkaHttpWebContext.scala
Expand Up @@ -5,24 +5,27 @@ import java.util.{Optional, UUID}
import akka.http.scaladsl.model.HttpHeader.ParsingResult.{Error, Ok}
import akka.http.scaladsl.model.headers.HttpCookie
import akka.http.scaladsl.model.{ContentType, HttpHeader, HttpRequest}

import com.stackstate.pac4j.authorizer.CsrfCookieAuthorizer
import com.stackstate.pac4j.http.AkkaHttpSessionStore
import com.stackstate.pac4j.store.SessionStorage

import org.pac4j.core.context.{Cookie, WebContext}

import compat.java8.OptionConverters._
import scala.collection.JavaConverters._

/**
*
* The AkkaHttpWebContext is responsible for wrapping an HTTP request and stores changes that are produced by pac4j and
* need to be applied to an HTTP response.
*/
case class AkkaHttpWebContext(request: HttpRequest,
formFields: Seq[(String, String)],
private[pac4j] val sessionStorage: SessionStorage,
sessionCookieName: String)
extends WebContext {
* The AkkaHttpWebContext is responsible for wrapping an HTTP request
* and stores changes that are produced by pac4j
* and need to be applied to an HTTP response.
*/
final class AkkaHttpWebContext(
val request: HttpRequest,
val formFields: Seq[(String, String)],
private[pac4j] val sessionStorage: SessionStorage,
val sessionCookieName: String
) extends WebContext {
import com.stackstate.pac4j.AkkaHttpWebContext._

private var changes = ResponseChanges.empty
Expand All @@ -33,7 +36,8 @@ case class AkkaHttpWebContext(request: HttpRequest,
}.asJavaCollection

//Request parameters are composed of form fields and the query part of the uri. Stored in a lazy val in order to only compute it once
private lazy val requestParameters = formFields.toMap ++ request.getUri().query().toMap.asScala
private lazy val requestParameters =
formFields.toMap ++ request.getUri().query().toMap.asScala

private def newSession() = {
val sessionId = UUID.randomUUID().toString
Expand All @@ -43,7 +47,7 @@ case class AkkaHttpWebContext(request: HttpRequest,

private[pac4j] var sessionId: String =
request.cookies
.filter(_.name == sessionCookieName)
.filter(_.name == sessionCookieName) // TODO: collect
.map(_.value)
.find(session => sessionStorage.sessionExists(session))
.getOrElse(newSession())
Expand Down Expand Up @@ -94,7 +98,8 @@ case class AkkaHttpWebContext(request: HttpRequest,
}

// Avoid adding duplicate headers, Pac4J expects to overwrite headers like `Location`
changes = changes.copy(headers = header :: changes.headers.filter(_.name != name))
changes = changes.copy(
headers = header :: changes.headers.filter(_.name != name))
}

@com.github.ghik.silencer.silent("mapValues")
Expand All @@ -113,8 +118,10 @@ case class AkkaHttpWebContext(request: HttpRequest,
ContentType.parse(contentType) match {
case Right(ct) =>
changes = changes.copy(contentType = Some(ct))

case Left(_) =>
throw new IllegalArgumentException("Invalid response content type " + contentType)
throw new IllegalArgumentException(
s"Invalid response content type ${contentType}")
}
}

Expand All @@ -127,47 +134,37 @@ case class AkkaHttpWebContext(request: HttpRequest,
}

override def getRequestHeader(name: String): Optional[String] = {
request.headers.find(_.name().toLowerCase() == name.toLowerCase).map(_.value).asJava
request.headers.find(_.name().
toLowerCase() == name.toLowerCase).map(_.value).asJava
}

override def getScheme: String = {
request.getUri().getScheme
}
lazy val getScheme: String = request.getUri().getScheme

override def isSecure: Boolean = {
val scheme = request.getUri().getScheme.toLowerCase
def isSecure: Boolean = getScheme.toLowerCase == "https"

scheme == "https"
}
override def getRequestMethod: String = request.method.value

override def getRequestMethod: String = {
request.method.value
}
override def getServerPort: Int = request.getUri().getPort

override def getServerPort: Int = {
request.getUri().getPort
}
override def setRequestAttribute(name: String, value: scala.AnyRef): Unit =
changes = changes.copy(
attributes = changes.attributes ++ Map[String, AnyRef](name -> value))

override def setRequestAttribute(name: String, value: scala.AnyRef): Unit = {
changes = changes.copy(attributes = changes.attributes ++ Map[String, AnyRef](name -> value))
}

override def getRequestAttribute(name: String): Optional[AnyRef] = {
override def getRequestAttribute(name: String): Optional[AnyRef] =
changes.attributes.get(name).asJava
}

def getContentType: Option[ContentType] = {
changes.contentType
}
def getContentType: Option[ContentType] = changes.contentType

def getChanges: ResponseChanges = changes

def addResponseSessionCookie(): Unit = {
val cookie = new Cookie(sessionCookieName, sessionId)

cookie.setSecure(isSecure)
cookie.setMaxAge(sessionStorage.sessionLifetime.toSeconds.toInt)
cookie.setHttpOnly(true)
cookie.setPath("/")

addResponseCookie(cookie)
}

Expand All @@ -178,13 +175,21 @@ case class AkkaHttpWebContext(request: HttpRequest,
}

object AkkaHttpWebContext {
def apply(
request: HttpRequest,
formFields: Seq[(String, String)],
sessionStorage: SessionStorage,
sessionCookieName: String
): AkkaHttpWebContext = new AkkaHttpWebContext(
request, formFields, sessionStorage, sessionCookieName)

//This class is where all the HTTP response changes are stored so that they can later be applied to an HTTP Request
case class ResponseChanges private (headers: List[HttpHeader],
contentType: Option[ContentType],
content: String,
cookies: List[HttpCookie],
attributes: Map[String, AnyRef])
case class ResponseChanges private (
val headers: List[HttpHeader],
val contentType: Option[ContentType],
val content: String,
val cookies: List[HttpCookie],
val attributes: Map[String, AnyRef])

object ResponseChanges {
def empty: ResponseChanges = {
Expand Down
Expand Up @@ -156,7 +156,7 @@ class AkkaHttpSecurityTest extends AnyWordSpecLike with Matchers with ScalatestR
} ~> check {
status shouldEqual StatusCodes.OK
responseAs[String] shouldBe "called!"
headers.size shouldBe 1
headers.size shouldBe 2
header("Set-Cookie").get.value() shouldBe "MyCookie=MyValue; Max-Age=100; Path=/; Secure; HttpOnly"
}
}
Expand Down

0 comments on commit 9adab7d

Please sign in to comment.