Skip to content

Commit

Permalink
Make context more effecient (#3007)
Browse files Browse the repository at this point in the history
  • Loading branch information
kddejong committed Jan 23, 2024
1 parent cd31423 commit 8d780b0
Show file tree
Hide file tree
Showing 25 changed files with 344 additions and 249 deletions.
121 changes: 78 additions & 43 deletions src/cfnlint/context/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@
from abc import ABC, abstractmethod
from collections import deque
from dataclasses import InitVar, dataclass, field, fields
from typing import Any, Deque, Dict, Iterable, List, Mapping, Sequence
from typing import Any, Deque, Dict, Iterator, List, Mapping, Sequence, Tuple

from cfnlint.helpers import FUNCTIONS, PSEUDOPARAMS, REGION_PRIMARY
from cfnlint.schema import PROVIDER_SCHEMA_MANAGER, AttributeDict

_PSEUDOPARAMS_NON_REGION = ["AWS::AccountId", "AWS::NoValue", "AWS::StackName"]


@dataclass
class Transforms:
Expand Down Expand Up @@ -50,18 +52,19 @@ class Context:
The conditions being used and their current state
"""

# what region we are processing
region: str = field(init=True, default=REGION_PRIMARY)
# what regions we are processing
regions: Sequence[str] = field(
init=True, default_factory=lambda: list([REGION_PRIMARY])
)

# supported functions at this point in the template
functions: Sequence[str] = field(init=True, default_factory=list)
# As we move down the template this is used to keep track of the
# how the conditions affected the path we are on
# The key is the condition name and the value is if the condition
# was true or false

# path keeps track of the path as we move down the template
# Example: Resources, MyResource, Properties, Name, ...
path: Deque[str] = field(init=True, default_factory=deque)
# value_path is an override of the value if we got it from another place
# like a Parameter default value
value_path: Deque[str] = field(init=True, default_factory=deque)

# cfn-lint Template class
Expand All @@ -70,21 +73,18 @@ class Context:
conditions: Dict[str, "Condition"] = field(init=True, default_factory=dict)
mappings: Dict[str, "Map"] = field(init=True, default_factory=dict)

# other Refs from Fn::Sub
# Combiniation of storing any resolved ref
# and adds in any Refs available from things like Fn::Sub
ref_values: Dict[str, Any] = field(init=True, default_factory=dict)

# Resolved value
# Resolved conditions for reference
resolved_conditions: Mapping[str, bool] = field(init=True, default_factory=dict)

transforms: Transforms = field(init=True, default_factory=lambda: Transforms([]))

def __post_init__(self) -> None:
if self.path is None:
self.path = deque([])
for pseudo_parameter in PSEUDOPARAMS:
self.ref_values[pseudo_parameter] = _get_pseudo_value(
pseudo_parameter, self.region
)

def evolve(self, **kwargs) -> "Context":
"""
Expand All @@ -110,22 +110,60 @@ def evolve(self, **kwargs) -> "Context":

return cls(**kwargs)

def ref_value(self, instance: str) -> Iterator[Tuple[str | List[str], "Context"]]:
# Non regionalized items first
if instance in _PSEUDOPARAMS_NON_REGION:
pseudo_value = _get_pseudo_value(instance)
if pseudo_value is not None:
yield pseudo_value, self.evolve(ref_values={instance: pseudo_value})
return
if instance in self.parameters:
for v, path in self.parameters[instance].ref(self):
yield v, self.evolve(
value_path=deque(["Parameters", instance]) + path,
ref_values={instance: v},
)
return

if instance in self.ref_values:
yield self.ref_values[instance], self
return

# Regionalized values second
if instance in PSEUDOPARAMS:
for region in self.regions:
# We can resolve all region based pseudo values
# as we are now deciding on a region.
yield _get_pseudo_value_by_region(instance, region), self.evolve(
regions=[region],
ref_values={
p: _get_pseudo_value_by_region(p, region)
for p in PSEUDOPARAMS
if p not in _PSEUDOPARAMS_NON_REGION
},
)

@property
def refs(self):
return (
list(self.parameters.keys())
+ list(self.resources.keys())
+ PSEUDOPARAMS
+ list(self.ref_values.keys())
)


def _get_pseudo_value(parameter: str, region: str) -> str | List[str] | None:
def _get_pseudo_value(parameter: str) -> str | List[str] | None:
if parameter == "AWS::AccountId":
return "123456789012"
if parameter == "AWS::StackName":
return "teststack"
return None


def _get_pseudo_value_by_region(parameter: str, region: str) -> str | List[str]:
if parameter == "AWS::NotificationARNs":
return [f"arn:{_get_partition(region)}:sns:{region}:123456789012:notification"]
if parameter == "AWS::NoValue":
return None
if parameter == "AWS::Partition":
return _get_partition(region)
if parameter == "AWS::Region":
Expand All @@ -135,9 +173,6 @@ def _get_pseudo_value(parameter: str, region: str) -> str | List[str] | None:
f"arn:{_get_partition(region)}:cloudformation:{region}"
":123456789012:stack/teststack/51af3dc0-da77-11e4-872e-1234567db123"
)
if parameter == "AWS::StackName":
return "teststack"
# if parameter == "AWS::URLSuffix":
if region in ("cn-north-1", "cn-northwest-1"):
return "amazonaws.com.cn"
else:
Expand All @@ -159,7 +194,7 @@ class _Ref(ABC):
"""

@abstractmethod
def ref(self, context: Context) -> Iterable[Any]:
def ref(self, context: Context) -> Iterator[Any]:
pass


Expand Down Expand Up @@ -210,20 +245,20 @@ def __post_init__(self, parameter) -> None:
self.min_value = parameter.get("MinValue")
self.max_value = parameter.get("MaxValue")

def ref(self, context: Context) -> Iterable[Any]:
def ref(self, context: Context) -> Iterator[Tuple[Any, deque]]:
if self.allowed_values:
for allowed_value in self.allowed_values:
yield allowed_value
for i, allowed_value in enumerate(self.allowed_values):
yield allowed_value, deque(["AllowedValues", i])
return
# assume default is an allowed value so we skip it
if self.default:
yield self.default
if self.default is not None:
yield self.default, deque(["Default"])

if self.min_value:
yield self.min_value
if self.min_value is not None:
yield self.min_value, deque(["MinValue"])

if self.max_value:
yield self.max_value
if self.max_value is not None:
yield self.max_value, deque(["MaxValue"])


@dataclass
Expand All @@ -245,7 +280,7 @@ def __post_init__(self, resource) -> None:
def get_atts(self, region: str = "us-east-1") -> AttributeDict:
return PROVIDER_SCHEMA_MANAGER.get_type_getatts(self.type, region)

def ref(self, context: Context) -> Iterable[Any]:
def ref(self, context: Context) -> Iterator[Any]:
return
yield

Expand Down Expand Up @@ -289,7 +324,7 @@ def __post_init__(self, mapping) -> None:
for k, v in mapping.items():
self.keys[k] = _MappingSecondaryKey(v)

def find_in_map(self, top_key: str, secondary_key: str) -> Iterable[Any]:
def find_in_map(self, top_key: str, secondary_key: str) -> Iterator[Any]:
if top_key not in self.keys:
raise KeyError(top_key)
yield self.keys[top_key].value(secondary_key)
Expand Down Expand Up @@ -373,7 +408,7 @@ def _init_mappings(self, mappings: Any) -> None:
collection: Dict | List[str | Dict] = field(init=False)
output: Dict[str, Any] = field(init=False)

def create_context_for_resources(self, region: str) -> Context:
def create_context_for_resources(self, regions: Sequence[str]) -> Context:
"""
Create a context for a resources
"""
Expand All @@ -383,13 +418,13 @@ def create_context_for_resources(self, region: str) -> Context:
conditions=self.conditions,
transforms=self.transforms,
mappings=self.mappings,
region=region,
regions=regions,
path=deque(["Resources"]),
functions=[],
)

def create_context_for_resource_properties(
self, region: str, resource_name: str
self, regions: Sequence[str], resource_name: str
) -> Context:
"""
Create a context for a resource properties
Expand All @@ -401,12 +436,12 @@ def create_context_for_resource_properties(
conditions=self.conditions,
transforms=self.transforms,
mappings=self.mappings,
region=region,
regions=regions,
path=deque(["Resources", resource_name, "Properties"]),
functions=list(FUNCTIONS),
)

def create_context_for_mappings(self, region: str) -> Context:
def create_context_for_mappings(self, regions: Sequence[str]) -> Context:
"""
Create a context for a resource properties
"""
Expand All @@ -417,12 +452,12 @@ def create_context_for_mappings(self, region: str) -> Context:
conditions={},
transforms=self.transforms,
mappings={},
region=region,
regions=regions,
path=deque(["Mappings"]),
functions=["Fn::Transform"],
)

def create_context_for_outputs(self, region: str) -> Context:
def create_context_for_outputs(self, regions: Sequence[str]) -> Context:
"""
Create a context for a resource properties
"""
Expand All @@ -433,12 +468,12 @@ def create_context_for_outputs(self, region: str) -> Context:
conditions=self.conditions,
transforms=self.transforms,
mappings=self.mappings,
region=region,
regions=regions,
path=deque(["Outputs"]),
functions=[],
)

def create_context_for_conditions(self, region: str) -> Context:
def create_context_for_conditions(self, regions: Sequence[str]) -> Context:
"""
Create a context for a conditions
"""
Expand All @@ -449,12 +484,12 @@ def create_context_for_conditions(self, region: str) -> Context:
conditions=self.conditions,
mappings=self.mappings,
transforms=self.transforms,
region=region,
regions=regions,
path=deque(["Conditions"]),
functions=[],
)

def create_context_for_parameters(self, region: str) -> Context:
def create_context_for_parameters(self, regions: Sequence[str]) -> Context:
"""
Create a context for a conditions
"""
Expand All @@ -465,7 +500,7 @@ def create_context_for_parameters(self, region: str) -> Context:
conditions={},
mappings={},
transforms=self.transforms,
region=region,
regions=regions,
path=deque(["Parameters"]),
functions=[],
)
Loading

0 comments on commit 8d780b0

Please sign in to comment.