Skip to content

Commit 90140cc

Browse files
committed
[KYUUBI #2550] Fix swagger does not show the request/response schema issue
### _Why are the changes needed?_ To close #2550 Refer the comments: #1658 (comment) ### _How was this patch tested?_ - [ ] Add some test cases that check the changes thoroughly including negative and positive cases if possible - [x] Add screenshots for manual tests if appropriate <img width="1253" alt="image" src="https://user-images.githubusercontent.com/6757692/166884297-a1d3654e-e229-4650-b656-8e38a1036af8.png"> <img width="1458" alt="image" src="https://user-images.githubusercontent.com/6757692/166808692-4fed9977-affd-4847-9ebb-dcc430405b0f.png"> - [x] [Run test](https://kyuubi.apache.org/docs/latest/develop_tools/testing.html#running-tests) locally before make a pull request Closes #2551 from turboFei/swagger_schema. Closes #2550 9da1e25 [Fei Wang] use arrary schema 032eae6 [Fei Wang] set schema a698f12 [Fei Wang] add commment ed4c52b [Fei Wang] fix scala style 418984f [Fei Wang] add service loader meta inf d2338f8 [Fei Wang] reformat b3a96e4 [Fei Wang] skip scala check ff10703 [Fei Wang] copy from SwaggerScalaModelConverter Authored-by: Fei Wang <fwang12@ebay.com> Signed-off-by: Fei Wang <fwang12@ebay.com>
1 parent 1932ad7 commit 90140cc

File tree

5 files changed

+391
-27
lines changed

5 files changed

+391
-27
lines changed
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
org.apache.kyuubi.server.api.SwaggerScalaModelConverter
Lines changed: 324 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,324 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.kyuubi.server.api
19+
20+
// scalastyle:off
21+
import java.lang.annotation.Annotation
22+
import java.lang.reflect.ParameterizedType
23+
import java.util.Iterator
24+
25+
import scala.language.existentials
26+
import scala.util.Try
27+
import scala.util.control.NonFatal
28+
29+
import com.fasterxml.jackson.databind.`type`.ReferenceType
30+
import com.fasterxml.jackson.databind.JavaType
31+
import com.fasterxml.jackson.module.scala.{DefaultScalaModule, JsonScalaEnumeration}
32+
import com.fasterxml.jackson.module.scala.introspect.{BeanIntrospector, PropertyDescriptor}
33+
import io.swagger.v3.core.converter._
34+
import io.swagger.v3.core.jackson.ModelResolver
35+
import io.swagger.v3.core.util.{Json, PrimitiveType}
36+
import io.swagger.v3.oas.annotations.Parameter
37+
import io.swagger.v3.oas.annotations.media.{ArraySchema, Schema => SchemaAnnotation}
38+
import io.swagger.v3.oas.models.media.Schema
39+
import org.slf4j.LoggerFactory
40+
41+
/**
42+
* Copied from https://github.com/swagger-akka-http/swagger-scala-module
43+
*/
44+
class AnnotatedTypeForOption extends AnnotatedType
45+
46+
object SwaggerScalaModelConverter {
47+
val objectMapper = Json.mapper().registerModule(DefaultScalaModule)
48+
}
49+
50+
class SwaggerScalaModelConverter extends ModelResolver(SwaggerScalaModelConverter.objectMapper) {
51+
SwaggerScalaModelConverter
52+
53+
private val logger = LoggerFactory.getLogger(classOf[SwaggerScalaModelConverter])
54+
private val EnumClass = classOf[scala.Enumeration]
55+
private val OptionClass = classOf[scala.Option[_]]
56+
private val IterableClass = classOf[scala.collection.Iterable[_]]
57+
private val SetClass = classOf[scala.collection.Set[_]]
58+
private val BigDecimalClass = classOf[BigDecimal]
59+
private val BigIntClass = classOf[BigInt]
60+
private val ProductClass = classOf[Product]
61+
private val AnyClass = classOf[Any]
62+
63+
override def resolve(
64+
`type`: AnnotatedType,
65+
context: ModelConverterContext,
66+
chain: Iterator[ModelConverter]): Schema[_] = {
67+
val javaType = _mapper.constructType(`type`.getType)
68+
val cls = javaType.getRawClass
69+
70+
matchScalaPrimitives(`type`, cls).getOrElse {
71+
// Unbox scala options
72+
val annotatedOverrides = getRequiredSettings(`type`)
73+
if (_isOptional(`type`, cls)) {
74+
val baseType =
75+
if (annotatedOverrides.headOption.getOrElse(false)) new AnnotatedType()
76+
else new AnnotatedTypeForOption()
77+
resolve(nextType(baseType, `type`, javaType), context, chain)
78+
} else if (!annotatedOverrides.headOption.getOrElse(true)) {
79+
resolve(nextType(new AnnotatedTypeForOption(), `type`, javaType), context, chain)
80+
} else if (isCaseClass(cls)) {
81+
caseClassSchema(cls, `type`, context, chain).getOrElse(None.orNull)
82+
} else if (chain.hasNext) {
83+
val nextResolved = Option(chain.next().resolve(`type`, context, chain))
84+
nextResolved match {
85+
case Some(property) => {
86+
if (isIterable(cls)) {
87+
property.setRequired(null)
88+
property.setProperties(null)
89+
}
90+
setRequired(`type`)
91+
property
92+
}
93+
case None => None.orNull
94+
}
95+
} else {
96+
None.orNull
97+
}
98+
}
99+
}
100+
101+
private def caseClassSchema(
102+
cls: Class[_],
103+
`type`: AnnotatedType,
104+
context: ModelConverterContext,
105+
chain: Iterator[ModelConverter]): Option[Schema[_]] = {
106+
if (chain.hasNext) {
107+
Option(chain.next().resolve(`type`, context, chain)).map { schema =>
108+
val introspector = BeanIntrospector(cls)
109+
introspector.properties.foreach { property =>
110+
getPropertyAnnotations(property) match {
111+
case Seq() => {
112+
val propertyClass = getPropertyClass(property)
113+
val optionalFlag = isOption(propertyClass)
114+
if (optionalFlag && schema.getRequired != null && schema.getRequired.contains(
115+
property.name)) {
116+
schema.getRequired.remove(property.name)
117+
} else if (!optionalFlag) {
118+
addRequiredItem(schema, property.name)
119+
}
120+
}
121+
case annotations => {
122+
val required = getRequiredSettings(annotations).headOption
123+
.getOrElse(!isOption(getPropertyClass(property)))
124+
if (required) addRequiredItem(schema, property.name)
125+
}
126+
}
127+
}
128+
schema
129+
}
130+
} else {
131+
None
132+
}
133+
}
134+
135+
private def getRequiredSettings(annotatedType: AnnotatedType): Seq[Boolean] =
136+
annotatedType match {
137+
case _: AnnotatedTypeForOption => Seq.empty
138+
case _ => getRequiredSettings(nullSafeList(annotatedType.getCtxAnnotations))
139+
}
140+
141+
private def getRequiredSettings(annotations: Seq[Annotation]): Seq[Boolean] = {
142+
annotations.collect {
143+
case p: Parameter => p.required()
144+
case s: SchemaAnnotation => s.required()
145+
case a: ArraySchema => a.arraySchema().required()
146+
}
147+
}
148+
149+
private def matchScalaPrimitives(
150+
`type`: AnnotatedType,
151+
nullableClass: Class[_]): Option[Schema[_]] = {
152+
val annotations = Option(`type`.getCtxAnnotations).map(_.toSeq).getOrElse(Seq.empty)
153+
annotations.collectFirst { case ann: SchemaAnnotation => ann } match {
154+
case Some(_) => None
155+
case _ => {
156+
annotations.collectFirst { case ann: JsonScalaEnumeration => ann } match {
157+
case Some(enumAnnotation: JsonScalaEnumeration) => {
158+
val pt = enumAnnotation.value().getGenericSuperclass.asInstanceOf[ParameterizedType]
159+
val args = pt.getActualTypeArguments
160+
val cls = args(0).asInstanceOf[Class[_]]
161+
val sp: Schema[String] =
162+
PrimitiveType.STRING.createProperty().asInstanceOf[Schema[String]]
163+
setRequired(`type`)
164+
try {
165+
val mainClass = getMainClass(cls)
166+
val valueMethods = mainClass.getMethods.toSeq.filter { m =>
167+
m.getDeclaringClass != EnumClass &&
168+
m.getReturnType.getName == "scala.Enumeration$Value" && m.getParameterCount == 0
169+
}
170+
val enumValues = valueMethods.map(_.invoke(None.orNull))
171+
enumValues.foreach { v =>
172+
sp.addEnumItemObject(v.toString)
173+
}
174+
} catch {
175+
case NonFatal(t) => logger.warn(s"Failed to get values for enum ${cls.getName}", t)
176+
}
177+
Some(sp)
178+
}
179+
case _ => {
180+
Option(nullableClass).flatMap { cls =>
181+
if (cls == BigDecimalClass) {
182+
val dp = PrimitiveType.DECIMAL.createProperty()
183+
setRequired(`type`)
184+
Some(dp)
185+
} else if (cls == BigIntClass) {
186+
val ip = PrimitiveType.INT.createProperty()
187+
setRequired(`type`)
188+
Some(ip)
189+
} else {
190+
None
191+
}
192+
}
193+
}
194+
}
195+
}
196+
}
197+
}
198+
199+
private def getMainClass(clazz: Class[_]): Class[_] = {
200+
val cname = clazz.getName
201+
if (cname.endsWith("$")) {
202+
Try(Class.forName(cname.substring(0, cname.length - 1))).getOrElse(clazz)
203+
} else {
204+
clazz
205+
}
206+
}
207+
208+
private def _isOptional(annotatedType: AnnotatedType, cls: Class[_]): Boolean = {
209+
annotatedType.getType match {
210+
case _: ReferenceType if isOption(cls) => true
211+
case _ => false
212+
}
213+
}
214+
215+
private def underlyingJavaType(annotatedType: AnnotatedType, javaType: JavaType): JavaType = {
216+
annotatedType.getType match {
217+
case rt: ReferenceType => rt.getContentType
218+
case _ => javaType
219+
}
220+
}
221+
222+
private def nextType(
223+
baseType: AnnotatedType,
224+
`type`: AnnotatedType,
225+
javaType: JavaType): AnnotatedType = {
226+
baseType.`type`(underlyingJavaType(`type`, javaType))
227+
.ctxAnnotations(`type`.getCtxAnnotations)
228+
.parent(`type`.getParent)
229+
.schemaProperty(`type`.isSchemaProperty)
230+
.name(`type`.getName)
231+
.propertyName(`type`.getPropertyName)
232+
.resolveAsRef(`type`.isResolveAsRef)
233+
.jsonViewAnnotation(`type`.getJsonViewAnnotation)
234+
.skipOverride(`type`.isSkipOverride)
235+
}
236+
237+
override def _isOptionalType(propType: JavaType): Boolean = {
238+
isOption(propType.getRawClass) || super._isOptionalType(propType)
239+
}
240+
241+
override def _isSetType(cls: Class[_]): Boolean = {
242+
val setInterfaces = cls.getInterfaces.find { interface =>
243+
interface == SetClass
244+
}
245+
setInterfaces.isDefined || super._isSetType(cls)
246+
}
247+
248+
private def setRequired(annotatedType: AnnotatedType): Unit = annotatedType match {
249+
case _: AnnotatedTypeForOption => // not required
250+
case _ => {
251+
val required = getRequiredSettings(annotatedType).headOption.getOrElse(true)
252+
if (required) {
253+
Option(annotatedType.getParent).foreach { parent =>
254+
Option(annotatedType.getPropertyName).foreach { n =>
255+
addRequiredItem(parent, n)
256+
}
257+
}
258+
}
259+
}
260+
}
261+
262+
private def getPropertyClass(property: PropertyDescriptor): Class[_] = {
263+
property.param match {
264+
case Some(constructorParameter) => {
265+
val types = constructorParameter.constructor.getParameterTypes
266+
if (constructorParameter.index > types.size) {
267+
AnyClass
268+
} else {
269+
types(constructorParameter.index)
270+
}
271+
}
272+
case _ => property.field match {
273+
case Some(field) => field.getType
274+
case _ => property.setter match {
275+
case Some(setter) if setter.getParameterCount == 1 => {
276+
setter.getParameterTypes()(0)
277+
}
278+
case _ => property.beanSetter match {
279+
case Some(setter) if setter.getParameterCount == 1 => {
280+
setter.getParameterTypes()(0)
281+
}
282+
case _ => AnyClass
283+
}
284+
}
285+
}
286+
}
287+
}
288+
289+
private def getPropertyAnnotations(property: PropertyDescriptor): Seq[Annotation] = {
290+
property.param match {
291+
case Some(constructorParameter) => {
292+
val types = constructorParameter.constructor.getParameterAnnotations
293+
if (constructorParameter.index > types.size) {
294+
Seq.empty
295+
} else {
296+
types(constructorParameter.index).toSeq
297+
}
298+
}
299+
case _ => property.field match {
300+
case Some(field) => field.getAnnotations.toSeq
301+
case _ => property.setter match {
302+
case Some(setter) if setter.getParameterCount == 1 => {
303+
setter.getAnnotations().toSeq
304+
}
305+
case _ => property.beanSetter match {
306+
case Some(setter) if setter.getParameterCount == 1 => {
307+
setter.getAnnotations().toSeq
308+
}
309+
case _ => Seq.empty
310+
}
311+
}
312+
}
313+
}
314+
}
315+
316+
private def isOption(cls: Class[_]): Boolean = cls == OptionClass
317+
private def isIterable(cls: Class[_]): Boolean = IterableClass.isAssignableFrom(cls)
318+
private def isCaseClass(cls: Class[_]): Boolean = ProductClass.isAssignableFrom(cls)
319+
320+
private def nullSafeList[T](array: Array[T]): List[T] = Option(array) match {
321+
case None => List.empty[T]
322+
case Some(arr) => arr.toList
323+
}
324+
}

kyuubi-server/src/main/scala/org/apache/kyuubi/server/api/v1/BatchesResource.scala

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import javax.ws.rs.core.{MediaType, Response}
2222

2323
import scala.util.control.NonFatal
2424

25-
import io.swagger.v3.oas.annotations.media.Content
25+
import io.swagger.v3.oas.annotations.media.{Content, Schema}
2626
import io.swagger.v3.oas.annotations.responses.ApiResponse
2727
import io.swagger.v3.oas.annotations.tags.Tag
2828
import org.apache.hive.service.rpc.thrift.TProtocolVersion
@@ -43,7 +43,8 @@ private[v1] class BatchesResource extends ApiRequestContext with Logging {
4343
@ApiResponse(
4444
responseCode = "200",
4545
content = Array(new Content(
46-
mediaType = MediaType.APPLICATION_JSON)),
46+
mediaType = MediaType.APPLICATION_JSON,
47+
schema = new Schema(implementation = classOf[Batch]))),
4748
description = "create and open a batch session")
4849
@POST
4950
@Consumes(Array(MediaType.APPLICATION_JSON))
@@ -64,7 +65,8 @@ private[v1] class BatchesResource extends ApiRequestContext with Logging {
6465
@ApiResponse(
6566
responseCode = "200",
6667
content = Array(new Content(
67-
mediaType = MediaType.APPLICATION_JSON)),
68+
mediaType = MediaType.APPLICATION_JSON,
69+
schema = new Schema(implementation = classOf[Batch]))),
6870
description = "get the batch info via batch id")
6971
@GET
7072
@Path("{batchId}")

0 commit comments

Comments
 (0)