Skip to content

Commit

Permalink
feat: openapi version 3 support non-auth changes (#932)
Browse files Browse the repository at this point in the history
  • Loading branch information
praneetap committed Jun 3, 2019
1 parent 12899b5 commit 7859f19
Show file tree
Hide file tree
Showing 36 changed files with 4,690 additions and 18 deletions.
1 change: 1 addition & 0 deletions docs/cloudformation_compatibility.rst
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ BinaryMediaTypes All
MinimumCompressionSize All
Cors All
TracingEnabled All
OpenApiVersion None
================================== ======================== ========================


Expand Down
1 change: 1 addition & 0 deletions docs/globals.rst
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ Currently, the following resources and properties are being supported:
AccessLogSetting:
CanarySetting:
TracingEnabled:
OpenApiVersion:
SimpleTable:
# Properties of AWS::Serverless::SimpleTable
Expand Down
17 changes: 13 additions & 4 deletions samtranslator/model/api/api_generator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from collections import namedtuple
from six import string_types
import re

from samtranslator.model.intrinsics import ref
from samtranslator.model.apigateway import (ApiGatewayDeployment, ApiGatewayRestApi,
Expand Down Expand Up @@ -31,7 +32,8 @@ def __init__(self, logical_id, cache_cluster_enabled, cache_cluster_size, variab
definition_body, definition_uri, name, stage_name, endpoint_configuration=None,
method_settings=None, binary_media=None, minimum_compression_size=None, cors=None,
auth=None, gateway_responses=None, access_log_setting=None, canary_setting=None,
tracing_enabled=None, resource_attributes=None, passthrough_resource_attributes=None):
tracing_enabled=None, resource_attributes=None, passthrough_resource_attributes=None,
open_api_version=None):
"""Constructs an API Generator class that generates API Gateway resources
:param logical_id: Logical id of the SAM API Resource
Expand Down Expand Up @@ -70,6 +72,7 @@ def __init__(self, logical_id, cache_cluster_enabled, cache_cluster_size, variab
self.tracing_enabled = tracing_enabled
self.resource_attributes = resource_attributes
self.passthrough_resource_attributes = passthrough_resource_attributes
self.open_api_version = open_api_version

def _construct_rest_api(self):
"""Constructs and returns the ApiGateway RestApi.
Expand All @@ -93,6 +96,11 @@ def _construct_rest_api(self):
raise InvalidResourceException(self.logical_id,
"Specify either 'DefinitionUri' or 'DefinitionBody' property and not both")

if self.open_api_version:
if re.match(SwaggerEditor.get_openapi_versions_supported_regex(), self.open_api_version) is None:
raise InvalidResourceException(
self.logical_id, "The OpenApiVersion value must be of the format 3.0.0")

self._add_cors()
self._add_auth()
self._add_gateway_responses()
Expand Down Expand Up @@ -137,7 +145,7 @@ def _construct_body_s3_dict(self):
body_s3['Version'] = s3_pointer['Version']
return body_s3

def _construct_deployment(self, rest_api):
def _construct_deployment(self, rest_api, open_api_version):
"""Constructs and returns the ApiGateway Deployment.
:param model.apigateway.ApiGatewayRestApi rest_api: the RestApi for this Deployment
Expand All @@ -147,7 +155,8 @@ def _construct_deployment(self, rest_api):
deployment = ApiGatewayDeployment(self.logical_id + 'Deployment',
attributes=self.passthrough_resource_attributes)
deployment.RestApiId = rest_api.get_runtime_attr('rest_api_id')
deployment.StageName = 'Stage'
if not self.open_api_version:
deployment.StageName = 'Stage'

return deployment

Expand Down Expand Up @@ -189,7 +198,7 @@ def to_cloudformation(self):
"""

rest_api = self._construct_rest_api()
deployment = self._construct_deployment(rest_api)
deployment = self._construct_deployment(rest_api, self.open_api_version)

swagger = None
if rest_api.Body is not None:
Expand Down
2 changes: 1 addition & 1 deletion samtranslator/model/apigateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class ApiGatewayDeployment(Resource):
'Description': PropertyType(False, is_str()),
'RestApiId': PropertyType(True, is_str()),
'StageDescription': PropertyType(False, is_type(dict)),
'StageName': PropertyType(True, is_str())
'StageName': PropertyType(False, is_str())
}

runtime_attrs = {
Expand Down
6 changes: 4 additions & 2 deletions samtranslator/model/sam_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,8 @@ class SamApi(SamResourceMacro):
'GatewayResponses': PropertyType(False, is_type(dict)),
'AccessLogSetting': PropertyType(False, is_type(dict)),
'CanarySetting': PropertyType(False, is_type(dict)),
'TracingEnabled': PropertyType(False, is_type(bool))
'TracingEnabled': PropertyType(False, is_type(bool)),
'OpenApiVersion': PropertyType(False, is_str())
}

referable_properties = {
Expand Down Expand Up @@ -483,7 +484,8 @@ def to_cloudformation(self, **kwargs):
canary_setting=self.CanarySetting,
tracing_enabled=self.TracingEnabled,
resource_attributes=self.resource_attributes,
passthrough_resource_attributes=self.get_passthrough_resource_attributes())
passthrough_resource_attributes=self.get_passthrough_resource_attributes(),
open_api_version=self.OpenApiVersion)

rest_api, deployment, stage, permissions = api_generator.to_cloudformation()

Expand Down
41 changes: 39 additions & 2 deletions samtranslator/plugins/globals/globals.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from samtranslator.public.sdk.resource import SamResourceType
from samtranslator.public.intrinsics import is_intrinsics
from samtranslator.swagger.swagger import SwaggerEditor
import re


class Globals(object):
Expand All @@ -11,6 +13,9 @@ class Globals(object):
# Key of the dictionary containing Globals section in SAM template
_KEYWORD = "Globals"
_RESOURCE_PREFIX = "AWS::Serverless::"
_OPENAPIVERSION = "OpenApiVersion"
_API_TYPE = "AWS::Serverless::Api"
_MANAGE_SWAGGER = "__MANAGE_SWAGGER"

supported_properties = {
# Everything on Serverless::Function except Role, Policies, FunctionName, Events
Expand Down Expand Up @@ -53,7 +58,8 @@ class Globals(object):
"GatewayResponses",
"AccessLogSetting",
"CanarySetting",
"TracingEnabled"
"TracingEnabled",
"OpenApiVersion"
],

SamResourceType.SimpleTable.value: [
Expand All @@ -68,7 +74,7 @@ def __init__(self, template):
:param dict template: SAM template to be parsed
"""
self.supported_resource_section_names = ([x.replace(self._RESOURCE_PREFIX, "")
for x in self.supported_properties.keys()])
for x in self.supported_properties.keys()])
# Sort the names for stability in list ordering
self.supported_resource_section_names.sort()

Expand Down Expand Up @@ -107,6 +113,36 @@ def del_section(cls, template):
if cls._KEYWORD in template:
del template[cls._KEYWORD]

@classmethod
def fix_openapi_definitions(cls, template):
"""
Helper method to postprocess the resources to make sure the swagger doc version matches
the one specified on the resource with flag OpenApiVersion.
This is done postprocess in globals because, the implicit api plugin runs before globals, \
and at that point the global flags aren't applied on each resource, so we do not know \
whether OpenApiVersion flag is specified. Running the globals plugin before implicit api \
was a risky change, so we decided to postprocess the openapi version here.
To make sure we don't modify customer defined swagger, we also check for __MANAGE_SWAGGER flag.
:param dict template: SAM template
:return: Modified SAM template with corrected swagger doc matching the OpenApiVersion.
"""
resources = template["Resources"]

for _, resource in resources.items():
if ("Type" in resource) and (resource["Type"] == cls._API_TYPE):
properties = resource["Properties"]
if (cls._OPENAPIVERSION in properties) and (cls._MANAGE_SWAGGER in properties) and \
(re.match(SwaggerEditor.get_openapi_version_3_regex(),
properties[cls._OPENAPIVERSION]) is not None):
if "DefinitionBody" in properties:
definition_body = properties['DefinitionBody']
definition_body['openapi'] = properties[cls._OPENAPIVERSION]
if definition_body.get('swagger'):
del definition_body['swagger']

def _parse(self, globals_dict):
"""
Takes a SAM template as input and parses the Globals section
Expand Down Expand Up @@ -398,6 +434,7 @@ class InvalidGlobalsSectionException(Exception):
Attributes:
message -- explanation of the error
"""

def __init__(self, logical_id, message):
self._logical_id = logical_id
self._message = message
Expand Down
6 changes: 6 additions & 0 deletions samtranslator/plugins/globals/globals_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

from samtranslator.plugins.globals.globals import Globals, InvalidGlobalsSectionException

_API_RESOURCE = "AWS::Serverless::Api"


class GlobalsPlugin(BasePlugin):
"""
Expand Down Expand Up @@ -38,3 +40,7 @@ def on_before_transform_template(self, template_dict):

# Remove the Globals section from template if necessary
Globals.del_section(template_dict)

# If there was a global openApiVersion flag, check and convert swagger
# to the right version
Globals.fix_openapi_definitions(template_dict)
23 changes: 19 additions & 4 deletions samtranslator/swagger/swagger.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import copy
import re
from six import string_types

from samtranslator.model.intrinsics import ref
Expand Down Expand Up @@ -554,10 +555,14 @@ def is_valid(data):
:param dict data: Data to be validated
:return: True, if data is a Swagger
"""
return bool(data) and \
isinstance(data, dict) and \
bool(data.get("swagger")) and \
isinstance(data.get('paths'), dict)

if bool(data) and isinstance(data, dict) and isinstance(data.get('paths'), dict):
if bool(data.get("swagger")):
return True
elif bool(data.get("openapi")):
return re.search(SwaggerEditor.get_openapi_version_3_regex(), data["openapi"]) is not None
return False
return False

@staticmethod
def gen_skeleton():
Expand Down Expand Up @@ -595,3 +600,13 @@ def _normalize_method_name(method):
return SwaggerEditor._X_ANY_METHOD
else:
return method

@staticmethod
def get_openapi_versions_supported_regex():
openapi_version_supported_regex = r"\A[2-3](\.\d)(\.\d)?$"
return openapi_version_supported_regex

@staticmethod
def get_openapi_version_3_regex():
openapi_version_3_regex = r"\A3(\.\d)(\.\d)?$"
return openapi_version_3_regex
51 changes: 51 additions & 0 deletions tests/model/test_sam_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
from samtranslator.intrinsics.resolver import IntrinsicsResolver
from samtranslator.model import InvalidResourceException
from samtranslator.model.lambda_ import LambdaFunction, LambdaVersion
from samtranslator.model.apigateway import ApiGatewayRestApi
from samtranslator.model.apigateway import ApiGatewayDeployment
from samtranslator.model.sam_resources import SamFunction
from samtranslator.model.sam_resources import SamApi


class TestCodeUri(TestCase):
Expand Down Expand Up @@ -70,3 +73,51 @@ def test_with_version_description(self):
cfnResources = function.to_cloudformation(**self.kwargs)
generateFunctionVersion = [x for x in cfnResources if isinstance(x, LambdaVersion)]
self.assertEqual(generateFunctionVersion[0].Description, test_description)

class TestOpenApi(TestCase):
kwargs = {
'intrinsics_resolver': IntrinsicsResolver({}),
'event_resources': [],
'managed_policy_map': {
"foo": "bar"
}
}

@patch('boto3.session.Session.region_name', 'ap-southeast-1')
def test_with_open_api_3_no_stage(self):
api = SamApi("foo")
api.OpenApiVersion = "3.0"

resources = api.to_cloudformation(**self.kwargs)
deployment = [x for x in resources if isinstance(x, ApiGatewayDeployment)]

self.assertEqual(deployment.__len__(), 1)
self.assertEqual(deployment[0].StageName, None)

@patch('boto3.session.Session.region_name', 'ap-southeast-1')
def test_with_open_api_2_no_stage(self):
api = SamApi("foo")
api.OpenApiVersion = "3.0"

resources = api.to_cloudformation(**self.kwargs)
deployment = [x for x in resources if isinstance(x, ApiGatewayDeployment)]

self.assertEqual(deployment.__len__(), 1)
self.assertEqual(deployment[0].StageName, None)

@patch('boto3.session.Session.region_name', 'ap-southeast-1')
def test_with_open_api_bad_value(self):
api = SamApi("foo")
api.OpenApiVersion = "5.0"
with pytest.raises(InvalidResourceException):
api.to_cloudformation(**self.kwargs)

@patch('boto3.session.Session.region_name', 'ap-southeast-1')
def test_with_swagger_no_stage(self):
api = SamApi("foo")

resources = api.to_cloudformation(**self.kwargs)
deployment = [x for x in resources if isinstance(x, ApiGatewayDeployment)]

self.assertEqual(deployment.__len__(), 1)
self.assertEqual(deployment[0].StageName, "Stage")
Loading

0 comments on commit 7859f19

Please sign in to comment.