diff --git a/aws/stack.py b/aws/stack.py index cb04b4dc..e3e896b4 100644 --- a/aws/stack.py +++ b/aws/stack.py @@ -11,7 +11,6 @@ from aws_cdk import Duration, Stack from aws_cdk import aws_ec2 as ec2 from aws_cdk import aws_ecs as ecs -from aws_cdk import aws_ecs_patterns as ecs_patterns from aws_cdk import aws_efs as efs from aws_cdk import aws_elasticloadbalancingv2 as elbv2 from aws_cdk import aws_iam as iam @@ -188,47 +187,89 @@ def __init__(self, scope: Construct, id: str, **kwargs) -> None: ) ) - # Create fargate service - ecs_service = ecs_patterns.ApplicationMultipleTargetGroupsFargateService( + # Create ALB Fargate service + ecs_service = ecs.FargateService( self, "FargateService", cluster=cluster, task_definition=task_definition, - load_balancers=[ - ecs_patterns.ApplicationLoadBalancerProps( - name="Alb", - domain_name=AWS_ROUTE53__HOSTED_ZONE_NAME, - domain_zone=hosted_zone, - public_load_balancer=True, - listeners=[ - ecs_patterns.ApplicationListenerProps( - name="Listener", port=443, certificate=cert - ) - ], + ) + + service_health_check = elbv2.HealthCheck( + enabled=True, + interval=Duration.seconds(120), + timeout=Duration.seconds(10), + healthy_threshold_count=5, + unhealthy_threshold_count=2, + path="/health", + ) + + # Load balancer + alb = elbv2.ApplicationLoadBalancer( + self, "Alb", vpc=cluster.vpc, internet_facing=True, load_balancer_name="Alb" + ) + + # Main HTTPS listener + listener = alb.add_listener("Listener", port=443, certificates=[cert]) + listener.add_action( + "DefaultAction", action=elbv2.ListenerAction.fixed_response(status_code=200) + ) + + # Redirect HTTP to HTTPS + alb.add_listener( + "HttpListener", + port=80, + default_action=elbv2.ListenerAction.redirect( + port="443", + protocol="HTTPS", + host="#{host}", + path="/#{path}", + query="#{query}", + permanent=True, + ), + ) + + listener.add_targets( + "TargetGroup-ApiContainer", + port=8000, + protocol=elbv2.ApplicationProtocol.HTTP, + priority=10, + health_check=service_health_check, + conditions=[elbv2.ListenerCondition.path_patterns(["/api", "/api/*"])], + targets=[ + ecs_service.load_balancer_target( + container_name="ApiContainer", container_port=8000 ) ], - target_groups=[ - ecs_patterns.ApplicationTargetProps( - container_port=8000, - priority=10, - path_pattern="/api/*", - listener="Listener", - ), - ecs_patterns.ApplicationTargetProps( - container_port=8001, - priority=20, - path_pattern="/runner/*", - listener="Listener", - ), + ) + listener.add_targets( + "TargetGroup-RunnerContainer", + port=8001, + protocol=elbv2.ApplicationProtocol.HTTP, + priority=20, + health_check=service_health_check, + conditions=[ + elbv2.ListenerCondition.path_patterns(["/runner", "/runner/*"]) + ], + targets=[ + ecs_service.load_balancer_target( + container_name="RunnerContainer", container_port=8001 + ) ], ) - listener = ecs_service.load_balancers[0].listeners[0] - listener.add_action( - "DefaultAction", action=elbv2.ListenerAction.fixed_response(status_code=200) + + # Add WAFv2 WebACL to the ALB + + # Define the IP set for VPC's IP range + private_cidr_blocks = [subnet.ipv4_cidr_block for subnet in vpc.private_subnets] + vpc_ip_set = wafv2.CfnIPSet( + self, + "VpcIpSet", + addresses=private_cidr_blocks, + scope="REGIONAL", + ip_address_version="IPV4", ) - # Add WAF to block all traffic not from platform.tracecat.com - # NOTE: Please change this to the domain you deployed Tracecat frontend to web_acl = wafv2.CfnWebACL( self, "WebAcl", @@ -241,6 +282,23 @@ def __init__(self, scope: Construct, id: str, **kwargs) -> None: sampled_requests_enabled=True, ), rules=[ + # New rule for allowing health checks from within VPC + wafv2.CfnWebACL.RuleProperty( + name="AllowHealthChecks", + priority=5, # Set priority appropriately + action=wafv2.CfnWebACL.RuleActionProperty(allow={}), + statement=wafv2.CfnWebACL.StatementProperty( + ip_set_reference_statement=wafv2.CfnWebACL.IPSetReferenceStatementProperty( + arn=vpc_ip_set.attr_arn + ) + ), + visibility_config=wafv2.CfnWebACL.VisibilityConfigProperty( + cloud_watch_metrics_enabled=True, + metric_name="AllowHealthChecksMetric", + sampled_requests_enabled=True, + ), + ), + # Block all traffic by default except for specific domain over HTTPS wafv2.CfnWebACL.RuleProperty( name="AllowSpecificDomainOverHttps", priority=10, @@ -327,7 +385,7 @@ def __init__(self, scope: Construct, id: str, **kwargs) -> None: metric_name="allowSpecificDomainOverHttpsMetric", sampled_requests_enabled=True, ), - ) + ), ], ) @@ -335,7 +393,7 @@ def __init__(self, scope: Construct, id: str, **kwargs) -> None: wafv2.CfnWebACLAssociation( self, "WebAclAssociation", - resource_arn=ecs_service.load_balancer.load_balancer_arn, + resource_arn=alb.load_balancer_arn, web_acl_arn=web_acl.attr_arn, ) @@ -344,8 +402,6 @@ def __init__(self, scope: Construct, id: str, **kwargs) -> None: self, "AliasRecord", record_name=AWS_ROUTE53__HOSTED_ZONE_NAME, - target=route53.RecordTarget.from_alias( - LoadBalancerTarget(ecs_service.load_balancer) - ), + target=route53.RecordTarget.from_alias(LoadBalancerTarget(alb)), zone=hosted_zone, )