/
HeaderDirectives.scala
219 lines (191 loc) · 8.52 KB
/
HeaderDirectives.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
/*
* Copyright (C) 2009-2019 Lightbend Inc. <https://www.lightbend.com>
*/
package akka.http.scaladsl.server
package directives
import akka.http.impl.util._
import akka.http.scaladsl.model._
import akka.http.scaladsl.model.headers._
import scala.reflect.ClassTag
import scala.util.control.NonFatal
/**
* @groupname header Header directives
* @groupprio header 110
*/
trait HeaderDirectives {
  import BasicDirectives._
  import RouteDirectives._

  /**
   * Checks that the request comes from an allowed origin. Extracts the [[Origin]] header and passes if any
   * of its origins is matched by the `allowed` range. If the [[Origin]] header is absent, the request is
   * rejected with a [[MissingHeaderRejection]]. If none of the origins is allowed, the request is rejected
   * with an [[InvalidOriginRejection]] and [[StatusCodes.Forbidden]] status.
   *
   * @param allowed the range of origins that are accepted
   * @group header
   */
  def checkSameOrigin(allowed: HttpOriginRange.Default): Directive0 =
    headerValueByType[Origin](()).flatMap { origin =>
      if (origin.origins.exists(allowed.matches)) pass
      else reject(InvalidOriginRejection(allowed.origins))
    }

  /**
   * Extracts an HTTP header value using the given function.
   *
   * Headers are tried in request order. The first header for which `f` returns `Some` provides the value.
   * If `f` returns `None` for all headers, the request is rejected with an empty rejection set.
   * If `f` throws a non-fatal exception for a header before any match is found, the request is rejected
   * with a [[akka.http.scaladsl.server.MalformedHeaderRejection]] for that header.
   *
   * @param f extractor applied to each request header
   * @group header
   */
  def headerValue[T](f: HttpHeader => Option[T]): Directive1[T] = {
    // Turn exceptions into a Left so that collectFirst stops at the failing header and we can reject with it.
    val protectedF: HttpHeader => Option[Either[Rejection, T]] = header =>
      try f(header).map(Right.apply)
      catch {
        case NonFatal(e) => Some(Left(MalformedHeaderRejection(header.name, e.getMessage.nullAsEmpty, Some(e))))
      }

    extract(_.request.headers.collectFirst(Function.unlift(protectedF))).flatMap {
      case Some(Right(a))        => provide(a)
      case Some(Left(rejection)) => reject(rejection)
      case None                  => reject
    }
  }

  /**
   * Extracts an HTTP header value using the given partial function. If the function is undefined for all
   * headers, the request is rejected with an empty rejection set.
   *
   * @group header
   */
  def headerValuePF[T](pf: PartialFunction[HttpHeader, T]): Directive1[T] = headerValue(pf.lift)

  /**
   * Extracts the value of the first HTTP request header with the given name.
   * If no header with a matching name is found, the request is rejected with a
   * [[akka.http.scaladsl.server.MissingHeaderRejection]].
   *
   * @group header
   */
  def headerValueByName(headerName: Symbol): Directive1[String] = headerValueByName(headerName.name)

  /**
   * Extracts the value of the first HTTP request header with the given name.
   * If no header with a matching name is found, the request is rejected with a
   * [[akka.http.scaladsl.server.MissingHeaderRejection]].
   *
   * @group header
   */
  def headerValueByName(headerName: String): Directive1[String] =
    // Lower-case with Locale.ROOT: header names are ASCII tokens, and the default-locale toLowerCase can
    // produce non-ASCII characters (e.g. Turkish dotless i) that would never match any header.
    headerValue(optionalValue(headerName.toLowerCase(java.util.Locale.ROOT))) | reject(MissingHeaderRejection(headerName))

  /**
   * Extracts the first HTTP request header of the given type.
   * If no header with a matching type is found, the request is rejected with a
   * [[akka.http.scaladsl.server.MissingHeaderRejection]].
   *
   * Custom headers will only be matched by this directive if they extend [[ModeledCustomHeader]]
   * and provide a companion extending [[ModeledCustomHeaderCompanion]].
   *
   * @group header
   */
  def headerValueByType[T](magnet: HeaderMagnet[T]): Directive1[T] =
    headerValuePF(magnet.extractPF) | reject(MissingHeaderRejection(magnet.headerName))

  /**
   * Extracts an optional HTTP header value using the given function.
   * If the given function throws an exception, the request is rejected
   * with a [[akka.http.scaladsl.server.MalformedHeaderRejection]].
   *
   * @group header
   */
  //#optionalHeaderValue
  def optionalHeaderValue[T](f: HttpHeader => Option[T]): Directive1[Option[T]] =
    headerValue(f).map(Some(_): Option[T]).recoverPF {
      case Nil => provide(None)
    }
  //#optionalHeaderValue

  /**
   * Extracts an optional HTTP header value using the given partial function.
   * If the given function throws an exception, the request is rejected
   * with a [[akka.http.scaladsl.server.MalformedHeaderRejection]].
   *
   * @group header
   */
  def optionalHeaderValuePF[T](pf: PartialFunction[HttpHeader, T]): Directive1[Option[T]] =
    optionalHeaderValue(pf.lift)

  /**
   * Extracts the value of the first HTTP request header with the given name, if present.
   *
   * @group header
   */
  def optionalHeaderValueByName(headerName: Symbol): Directive1[Option[String]] =
    optionalHeaderValueByName(headerName.name)

  /**
   * Extracts the value of the first HTTP request header with the given name, if present.
   *
   * @group header
   */
  def optionalHeaderValueByName(headerName: String): Directive1[Option[String]] = {
    // Locale.ROOT keeps the comparison independent of the JVM default locale (see headerValueByName).
    val lowerCaseName = headerName.toLowerCase(java.util.Locale.ROOT)
    extract(_.request.headers.collectFirst {
      case HttpHeader(`lowerCaseName`, value) => value
    })
  }

  /**
   * Extracts the optional HTTP request header of the given type.
   *
   * Custom headers will only be matched by this directive if they extend [[ModeledCustomHeader]]
   * and provide a companion extending [[ModeledCustomHeaderCompanion]].
   *
   * @group header
   */
  def optionalHeaderValueByType[T <: HttpHeader](magnet: HeaderMagnet[T]): Directive1[Option[T]] =
    optionalHeaderValuePF(magnet.extractPF)

  // Extractor yielding the value of a header whose name, as produced by the HttpHeader extractor
  // (presumably its lower-cased name — verify against HttpHeader.unapply), equals `lowerCaseName`.
  private def optionalValue(lowerCaseName: String): HttpHeader => Option[String] = {
    case HttpHeader(`lowerCaseName`, value) => Some(value)
    case _                                  => None
  }
}

/** Standalone instance of [[HeaderDirectives]], so the directives can be imported instead of mixed in. */
object HeaderDirectives extends HeaderDirectives
trait HeaderMagnet[T] {

  /** Runtime class of the header type `T` selected by this magnet. */
  def runtimeClass: Class[T]

  /** `ClassTag` for the header type `T` selected by this magnet. */
  def classTag: ClassTag[T]

  /**
   * Partial function selecting a request header as a `T`.
   *
   * Depending on the implementation, the header is selected either by testing its runtime class or by
   * matching its name. Generic type arguments of `T` are erased at runtime and are therefore not checked.
   */
  def extractPF: PartialFunction[HttpHeader, T]

  /**
   * Name of the selected header, as reported in rejections such as `MissingHeaderRejection`.
   * By default it is derived from `runtimeClass`; implementations may override it.
   */
  def headerName: String = ModeledCompanion.nameFromClass(runtimeClass)
}
object HeaderMagnet extends LowPriorityHeaderMagnetImplicits {

  /**
   * Magnet for a [[ModeledCustomHeader]] type, built from its `ClassTag` and companion.
   *
   * Any header for which `is(companion.lowercaseName)` holds is converted to `T` by passing its string
   * value to `companion.apply`. The reported header name is `companion.name`. `H` is unused.
   */
  implicit def fromClassTagForModeledCustomHeader[T <: ModeledCustomHeader[T], H <: ModeledCustomHeaderCompanion[T]](tag: ClassTag[T], companion: ModeledCustomHeaderCompanion[T]): HeaderMagnet[T] =
    new HeaderMagnet[T] {
      override def classTag: ClassTag[T] = tag
      override def runtimeClass: Class[T] = tag.runtimeClass.asInstanceOf[Class[T]]
      override def headerName: String = companion.name
      override def extractPF: PartialFunction[HttpHeader, T] = {
        case header if header.is(companion.lowercaseName) => companion(header.value)
      }
    }

  /** Same as [[fromClassTagForModeledCustomHeader]], taking a plain `Class` instead of a `ClassTag`. */
  implicit def fromClassForModeledCustomHeader[T <: ModeledCustomHeader[T], H <: ModeledCustomHeaderCompanion[T]](clazz: Class[T], companion: ModeledCustomHeaderCompanion[T]): HeaderMagnet[T] =
    fromClassTagForModeledCustomHeader[T, H](ClassTag[T](clazz), companion)

  /**
   * Lets `()` select a [[ModeledCustomHeader]] by type whenever an implicit companion is in scope.
   *
   * Such custom headers need name-based extraction. Headers the parser already emits with the right type
   * are covered by the lower-priority `fromUnitNormalHeader`.
   */
  implicit def fromUnitForModeledCustomHeader[T <: ModeledCustomHeader[T], H <: ModeledCustomHeaderCompanion[T]](u: Unit)(implicit tag: ClassTag[T], companion: ModeledCustomHeaderCompanion[T]): HeaderMagnet[T] =
    fromClassTagForModeledCustomHeader[T, H](tag, companion)
}
trait LowPriorityHeaderMagnetImplicits {

  /** Magnet for a regular (scaladsl) header type, identified by its `ClassTag`. */
  implicit def fromClassTagNormalHeader[T <: HttpHeader](tag: ClassTag[T]): HeaderMagnet[T] =
    byRuntimeClass(tag)

  /** Lets `()` select a regular header type using the implicitly available `ClassTag`. */
  implicit def fromUnitNormalHeader[T <: HttpHeader](u: Unit)(implicit tag: ClassTag[T]): HeaderMagnet[T] =
    byRuntimeClass(tag)

  /** Magnet for a regular (scaladsl) header type, identified by its `Class`. */
  implicit def fromClassNormalHeader[T <: HttpHeader](clazz: Class[T]): HeaderMagnet[T] =
    byRuntimeClass(ClassTag[T](clazz))

  /** Magnet for a javadsl header type, identified by its `Class`. */
  implicit def fromClassNormalJavaHeader[T <: akka.http.javadsl.model.HttpHeader](clazz: Class[T]): HeaderMagnet[T] =
    byRuntimeClass(ClassTag[T](clazz))

  // Shared implementation: selects any header whose runtime class is the tagged class or a subclass of it.
  private def byRuntimeClass[T](tag: ClassTag[T]): HeaderMagnet[T] =
    new HeaderMagnet[T] {
      val classTag: ClassTag[T] = tag
      val runtimeClass: Class[T] = tag.runtimeClass.asInstanceOf[Class[T]]
      val extractPF: PartialFunction[HttpHeader, T] = {
        case header if runtimeClass.isAssignableFrom(header.getClass) => header.asInstanceOf[T]
      }
    }
}