diff --git a/docs/guides/additional-tags.md b/docs/guides/additional-tags.md new file mode 100644 index 00000000..bf5501ab --- /dev/null +++ b/docs/guides/additional-tags.md @@ -0,0 +1,38 @@ +# Additional Tags + +The AWS Gateway API Controller automatically applies some tags to resources it creates. In addition, you can use annotations to specify additional tags. + +The `application-networking.k8s.aws/tags` annotation specifies additional tags that will be applied to AWS resources created. + +## Supported Resources + +- **HTTPRoute** - Tags applied to VPC Lattice Services, Listeners, Rules, Target Groups, and Service Network Service Associations +- **ServiceExport** - Tags applied to VPC Lattice Target Groups +- **AccessLogPolicy** - Tags applied to VPC Lattice Access Log Subscriptions +- **VpcAssociationPolicy** - Tags applied to VPC Lattice Service Network VPC Associations + +## Usage + +Add comma separated key=value pairs to the annotation: + +```yaml +apiVersion: gateway.networking.k8s.io/v1 +kind: HTTPRoute +metadata: + name: inventory-route + annotations: + application-networking.k8s.aws/tags: "Environment=Production,Team=Backend" +spec: + # ... rest of spec +``` + +```yaml +apiVersion: application-networking.k8s.aws/v1alpha1 +kind: ServiceExport +metadata: + name: payment-service + annotations: + application-networking.k8s.aws/tags: "Environment=Production,Service=Payment" +spec: + # ... rest of spec +``` diff --git a/pkg/aws/cloud.go b/pkg/aws/cloud.go index b2837453..9562f58a 100644 --- a/pkg/aws/cloud.go +++ b/pkg/aws/cloud.go @@ -48,6 +48,10 @@ type Cloud interface { // check ownership and acquire if it is not owned by anyone. TryOwn(ctx context.Context, arn string) (bool, error) TryOwnFromTags(ctx context.Context, arn string, tags services.Tags) (bool, error) + + // MergeTags creates a new tag map by merging baseTags and additionalTags. + // BaseTags will override additionalTags for any duplicate keys. + MergeTags(baseTags services.Tags, additionalTags services.Tags) services.Tags } // NewCloud constructs new Cloud implementation. @@ -144,6 +148,17 @@ func (c *defaultCloud) DefaultTagsMergedWith(tags services.Tags) services.Tags { return newTags } +func (c *defaultCloud) MergeTags(baseTags services.Tags, additionalTags services.Tags) services.Tags { + result := make(services.Tags) + if additionalTags != nil { + maps.Copy(result, additionalTags) + } + if baseTags != nil { + maps.Copy(result, baseTags) + } + return result +} + func (c *defaultCloud) getTags(ctx context.Context, arn string) (services.Tags, error) { tagsReq := &vpclattice.ListTagsForResourceInput{ResourceArn: &arn} resp, err := c.lattice.ListTagsForResourceWithContext(ctx, tagsReq) diff --git a/pkg/aws/cloud_mocks.go b/pkg/aws/cloud_mocks.go index bd8fe67b..26293d05 100644 --- a/pkg/aws/cloud_mocks.go +++ b/pkg/aws/cloud_mocks.go @@ -106,6 +106,20 @@ func (mr *MockCloudMockRecorder) Lattice() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Lattice", reflect.TypeOf((*MockCloud)(nil).Lattice)) } +// MergeTags mocks base method. +func (m *MockCloud) MergeTags(arg0, arg1 map[string]*string) map[string]*string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "MergeTags", arg0, arg1) + ret0, _ := ret[0].(map[string]*string) + return ret0 +} + +// MergeTags indicates an expected call of MergeTags. +func (mr *MockCloudMockRecorder) MergeTags(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MergeTags", reflect.TypeOf((*MockCloud)(nil).MergeTags), arg0, arg1) +} + // Tagging mocks base method. func (m *MockCloud) Tagging() services.Tagging { m.ctrl.T.Helper() diff --git a/pkg/aws/services/tagging.go b/pkg/aws/services/tagging.go index 22c59991..518326cb 100644 --- a/pkg/aws/services/tagging.go +++ b/pkg/aws/services/tagging.go @@ -4,6 +4,7 @@ import ( "context" "fmt" + "github.com/aws/aws-application-networking-k8s/pkg/k8s" "github.com/aws/aws-application-networking-k8s/pkg/utils" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/session" @@ -34,6 +35,9 @@ type Tagging interface { // Finds one resource that matches the given set of tags. FindResourcesByTags(ctx context.Context, resourceType ResourceType, tags Tags) ([]string, error) + + // Updates tags for a given resource ARN + UpdateTags(ctx context.Context, resourceArn string, newTags Tags) error } type defaultTagging struct { @@ -165,3 +169,72 @@ func convertTagsToFilter(tags Tags) []*taggingapi.TagFilter { } return filters } + +func (t *defaultTagging) UpdateTags(ctx context.Context, resourceArn string, newTags Tags) error { + existingTags, err := t.GetTagsForArns(ctx, []string{resourceArn}) + if err != nil { + return fmt.Errorf("failed to get existing tags: %w", err) + } + + currentTags := k8s.GetNonAWSManagedTags(existingTags[resourceArn]) + filteredNewTags := k8s.GetNonAWSManagedTags(newTags) + + tagsToAdd, tagsToRemove := k8s.CalculateTagDifference(currentTags, filteredNewTags) + + if len(tagsToRemove) > 0 { + _, err := t.UntagResourcesWithContext(ctx, &taggingapi.UntagResourcesInput{ + ResourceARNList: []*string{aws.String(resourceArn)}, + TagKeys: tagsToRemove, + }) + if err != nil { + return fmt.Errorf("failed to remove tags: %w", err) + } + } + + if len(tagsToAdd) > 0 { + _, err := t.TagResourcesWithContext(ctx, &taggingapi.TagResourcesInput{ + ResourceARNList: []*string{aws.String(resourceArn)}, + Tags: tagsToAdd, + }) + if err != nil { + return fmt.Errorf("failed to add/update tags: %w", err) + } + } + + return nil +} + +func (t *latticeTagging) UpdateTags(ctx context.Context, resourceArn string, newTags Tags) error { + existingTags, err := t.ListTagsForResourceWithContext(ctx, &vpclattice.ListTagsForResourceInput{ + ResourceArn: aws.String(resourceArn), + }) + if err != nil { + return fmt.Errorf("failed to get existing tags: %w", err) + } + + currentTags := k8s.GetNonAWSManagedTags(existingTags.Tags) + filteredNewTags := k8s.GetNonAWSManagedTags(newTags) + + tagsToAdd, tagsToRemove := k8s.CalculateTagDifference(currentTags, filteredNewTags) + + if len(tagsToRemove) > 0 { + _, err := t.UntagResourceWithContext(ctx, &vpclattice.UntagResourceInput{ + ResourceArn: aws.String(resourceArn), + TagKeys: tagsToRemove, + }) + if err != nil { + return fmt.Errorf("failed to remove tags: %w", err) + } + } + + if len(tagsToAdd) > 0 { + _, err := t.TagResourceWithContext(ctx, &vpclattice.TagResourceInput{ + ResourceArn: aws.String(resourceArn), + Tags: tagsToAdd, + }) + if err != nil { + return fmt.Errorf("failed to add/update tags: %w", err) + } + } + return nil +} diff --git a/pkg/aws/services/tagging_mocks.go b/pkg/aws/services/tagging_mocks.go index 0e85d2e3..1f452a19 100644 --- a/pkg/aws/services/tagging_mocks.go +++ b/pkg/aws/services/tagging_mocks.go @@ -63,3 +63,17 @@ func (mr *MockTaggingMockRecorder) GetTagsForArns(arg0, arg1 interface{}) *gomoc mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTagsForArns", reflect.TypeOf((*MockTagging)(nil).GetTagsForArns), arg0, arg1) } + +// UpdateTags mocks base method. +func (m *MockTagging) UpdateTags(arg0 context.Context, arg1 string, arg2 map[string]*string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateTags", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateTags indicates an expected call of UpdateTags. +func (mr *MockTaggingMockRecorder) UpdateTags(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateTags", reflect.TypeOf((*MockTagging)(nil).UpdateTags), arg0, arg1, arg2) +} diff --git a/pkg/aws/services/tagging_test.go b/pkg/aws/services/tagging_test.go index 47c88617..2ec02123 100644 --- a/pkg/aws/services/tagging_test.go +++ b/pkg/aws/services/tagging_test.go @@ -429,3 +429,126 @@ func Test_latticeTagging_FindResourcesByTags(t *testing.T) { }) } } + +func TestLatticeTagging_UpdateTags(t *testing.T) { + ctx := context.TODO() + tests := []struct { + name string + resourceArn string + existingTags Tags + newTags Tags + expectedTagCalls int + expectedUntagCalls int + expectError bool + description string + }{ + { + name: "nil new tags removes all existing additional tags", + resourceArn: "arn:aws:vpc-lattice:us-west-2:123456789:service/svc-123", + existingTags: Tags{ + "Environment": aws.String("Dev"), + "Project": aws.String("MyApp"), + "application-networking.k8s.aws/ManagedBy": aws.String("123456789/cluster/vpc-123"), + }, + newTags: nil, + expectedTagCalls: 0, + expectedUntagCalls: 1, + expectError: false, + description: "should remove all additional tags when newTags is nil", + }, + { + name: "add new tags when no existing additional tags", + resourceArn: "arn:aws:vpc-lattice:us-west-2:123456789:service/svc-123", + existingTags: Tags{ + "application-networking.k8s.aws/ManagedBy": aws.String("123456789/cluster/vpc-123"), + }, + newTags: Tags{ + "Environment": aws.String("Dev"), + "Project": aws.String("MyApp"), + }, + expectedTagCalls: 1, + expectedUntagCalls: 0, + expectError: false, + description: "should add new tags when no existing additional tags", + }, + { + name: "update existing additional tags", + resourceArn: "arn:aws:vpc-lattice:us-west-2:123456789:service/svc-123", + existingTags: Tags{ + "Environment": aws.String("Dev"), + "Project": aws.String("OldApp"), + "application-networking.k8s.aws/ManagedBy": aws.String("123456789/cluster/vpc-123"), + }, + newTags: Tags{ + "Environment": aws.String("Prod"), + "Project": aws.String("NewApp"), + }, + expectedTagCalls: 1, + expectedUntagCalls: 0, + expectError: false, + description: "should update changed additional tag values", + }, + { + name: "no changes needed", + resourceArn: "arn:aws:vpc-lattice:us-west-2:123456789:service/svc-123", + existingTags: Tags{ + "Environment": aws.String("Dev"), + "Project": aws.String("MyApp"), + "application-networking.k8s.aws/ManagedBy": aws.String("123456789/cluster/vpc-123"), + }, + newTags: Tags{ + "Environment": aws.String("Dev"), + "Project": aws.String("MyApp"), + }, + expectedTagCalls: 0, + expectedUntagCalls: 0, + expectError: false, + description: "should not make API calls when no changes needed", + }, + { + name: "filters out AWS managed tags from new tags", + resourceArn: "arn:aws:vpc-lattice:us-west-2:123456789:service/svc-123", + existingTags: Tags{}, + newTags: Tags{ + "application-networking.k8s.aws/ManagedBy": aws.String("test-override"), + "application-networking.k8s.aws/RouteType": aws.String("http"), + }, + expectedTagCalls: 0, + expectedUntagCalls: 0, + expectError: false, + description: "should filter out AWS managed tags from new tags, resulting in no API calls", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := gomock.NewController(t) + mockLattice := NewMockLattice(c) + + lt := &latticeTagging{ + Lattice: mockLattice, + } + + mockLattice.EXPECT().ListTagsForResourceWithContext(ctx, gomock.Any()). + Return(&vpclattice.ListTagsForResourceOutput{Tags: tt.existingTags}, nil).Times(1) + + if tt.expectedUntagCalls > 0 { + mockLattice.EXPECT().UntagResourceWithContext(ctx, gomock.Any()). + Return(nil, nil).Times(tt.expectedUntagCalls) + } + + if tt.expectedTagCalls > 0 { + mockLattice.EXPECT().TagResourceWithContext(ctx, gomock.Any()). + Return(nil, nil).Times(tt.expectedTagCalls) + } + + err := lt.UpdateTags(ctx, tt.resourceArn, tt.newTags) + + if tt.expectError { + assert.Error(t, err, tt.description) + } else { + assert.NoError(t, err, tt.description) + } + }) + } +} diff --git a/pkg/controllers/accesslogpolicy_controller.go b/pkg/controllers/accesslogpolicy_controller.go index deea9442..0dc650c7 100644 --- a/pkg/controllers/accesslogpolicy_controller.go +++ b/pkg/controllers/accesslogpolicy_controller.go @@ -42,6 +42,7 @@ import ( "github.com/aws/aws-application-networking-k8s/pkg/aws" "github.com/aws/aws-application-networking-k8s/pkg/aws/services" "github.com/aws/aws-application-networking-k8s/pkg/config" + "github.com/aws/aws-application-networking-k8s/pkg/controllers/predicates" "github.com/aws/aws-application-networking-k8s/pkg/deploy" "github.com/aws/aws-application-networking-k8s/pkg/gateway" "github.com/aws/aws-application-networking-k8s/pkg/k8s" @@ -95,7 +96,7 @@ func RegisterAccessLogPolicyController( } builder := ctrl.NewControllerManagedBy(mgr). - For(&anv1alpha1.AccessLogPolicy{}, pkg_builder.WithPredicates(predicate.GenerationChangedPredicate{})). + For(&anv1alpha1.AccessLogPolicy{}, pkg_builder.WithPredicates(predicate.Or(predicate.GenerationChangedPredicate{}, predicates.AdditionalTagsAnnotationChangedPredicate))). Watches(&gwv1.Gateway{}, handler.EnqueueRequestsFromMapFunc(r.findImpactedAccessLogPolicies), pkg_builder.WithPredicates(predicate.GenerationChangedPredicate{})). Watches(&gwv1.HTTPRoute{}, handler.EnqueueRequestsFromMapFunc(r.findImpactedAccessLogPolicies), pkg_builder.WithPredicates(predicate.GenerationChangedPredicate{})). Watches(&gwv1.GRPCRoute{}, handler.EnqueueRequestsFromMapFunc(r.findImpactedAccessLogPolicies), pkg_builder.WithPredicates(predicate.GenerationChangedPredicate{})). diff --git a/pkg/controllers/predicates/additionaltags_predicate.go b/pkg/controllers/predicates/additionaltags_predicate.go new file mode 100644 index 00000000..0c4ecb23 --- /dev/null +++ b/pkg/controllers/predicates/additionaltags_predicate.go @@ -0,0 +1,31 @@ +package predicates + +import ( + "sigs.k8s.io/controller-runtime/pkg/event" + "sigs.k8s.io/controller-runtime/pkg/predicate" + + "github.com/aws/aws-application-networking-k8s/pkg/k8s" +) + +var AdditionalTagsAnnotationChangedPredicate = predicate.Funcs{ + UpdateFunc: func(e event.UpdateEvent) bool { + oldAnnotations := e.ObjectOld.GetAnnotations() + newAnnotations := e.ObjectNew.GetAnnotations() + + oldAdditionalTags := getAdditionalTagsAnnotation(oldAnnotations) + newAdditionalTags := getAdditionalTagsAnnotation(newAnnotations) + + return oldAdditionalTags != newAdditionalTags + }, + CreateFunc: func(e event.CreateEvent) bool { + annotations := e.Object.GetAnnotations() + return getAdditionalTagsAnnotation(annotations) != "" + }, +} + +func getAdditionalTagsAnnotation(annotations map[string]string) string { + if annotations == nil { + return "" + } + return annotations[k8s.TagsAnnotationKey] +} diff --git a/pkg/controllers/predicates/additionaltags_predicate_test.go b/pkg/controllers/predicates/additionaltags_predicate_test.go new file mode 100644 index 00000000..d9bc29ac --- /dev/null +++ b/pkg/controllers/predicates/additionaltags_predicate_test.go @@ -0,0 +1,248 @@ +package predicates + +import ( + "testing" + + "github.com/stretchr/testify/assert" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "sigs.k8s.io/controller-runtime/pkg/event" + gwv1 "sigs.k8s.io/gateway-api/apis/v1" + + "github.com/aws/aws-application-networking-k8s/pkg/k8s" +) + +func TestAdditionalTagsAnnotationChangedPredicate_Update(t *testing.T) { + predicate := AdditionalTagsAnnotationChangedPredicate + + tests := []struct { + name string + oldRoute *gwv1.HTTPRoute + newRoute *gwv1.HTTPRoute + expected bool + }{ + { + name: "additional tags annotation added should trigger reconcile", + oldRoute: &gwv1.HTTPRoute{ + ObjectMeta: metav1.ObjectMeta{ + Generation: 1, + }, + }, + newRoute: &gwv1.HTTPRoute{ + ObjectMeta: metav1.ObjectMeta{ + Generation: 1, + Annotations: map[string]string{ + k8s.TagsAnnotationKey: "Environment=Dev,Project=MyApp", + }, + }, + }, + expected: true, + }, + { + name: "additional tags annotation removed should trigger reconcile", + oldRoute: &gwv1.HTTPRoute{ + ObjectMeta: metav1.ObjectMeta{ + Generation: 1, + Annotations: map[string]string{ + k8s.TagsAnnotationKey: "Environment=Dev,Project=MyApp", + }, + }, + }, + newRoute: &gwv1.HTTPRoute{ + ObjectMeta: metav1.ObjectMeta{ + Generation: 1, + }, + }, + expected: true, + }, + { + name: "additional tags annotation value changed should trigger reconcile", + oldRoute: &gwv1.HTTPRoute{ + ObjectMeta: metav1.ObjectMeta{ + Generation: 1, + Annotations: map[string]string{ + k8s.TagsAnnotationKey: "Environment=Dev,Project=MyApp", + }, + }, + }, + newRoute: &gwv1.HTTPRoute{ + ObjectMeta: metav1.ObjectMeta{ + Generation: 1, + Annotations: map[string]string{ + k8s.TagsAnnotationKey: "Environment=Prod,Project=MyApp", + }, + }, + }, + expected: true, + }, + { + name: "additional tags annotation unchanged should not trigger reconcile", + oldRoute: &gwv1.HTTPRoute{ + ObjectMeta: metav1.ObjectMeta{ + Generation: 1, + Annotations: map[string]string{ + k8s.TagsAnnotationKey: "Environment=Dev,Project=MyApp", + }, + }, + }, + newRoute: &gwv1.HTTPRoute{ + ObjectMeta: metav1.ObjectMeta{ + Generation: 1, + Annotations: map[string]string{ + k8s.TagsAnnotationKey: "Environment=Dev,Project=MyApp", + }, + }, + }, + expected: false, + }, + { + name: "other annotation changes should not trigger reconcile", + oldRoute: &gwv1.HTTPRoute{ + ObjectMeta: metav1.ObjectMeta{ + Generation: 1, + Annotations: map[string]string{ + "other-annotation": "value1", + k8s.TagsAnnotationKey: "Environment=Dev", + }, + }, + }, + newRoute: &gwv1.HTTPRoute{ + ObjectMeta: metav1.ObjectMeta{ + Generation: 1, + Annotations: map[string]string{ + "other-annotation": "value2", + k8s.TagsAnnotationKey: "Environment=Dev", + }, + }, + }, + expected: false, + }, + { + name: "no annotations to no annotations should not trigger reconcile", + oldRoute: &gwv1.HTTPRoute{ + ObjectMeta: metav1.ObjectMeta{ + Generation: 1, + }, + }, + newRoute: &gwv1.HTTPRoute{ + ObjectMeta: metav1.ObjectMeta{ + Generation: 1, + }, + }, + expected: false, + }, + { + name: "empty additional tags to empty additional tags should not trigger reconcile", + oldRoute: &gwv1.HTTPRoute{ + ObjectMeta: metav1.ObjectMeta{ + Generation: 1, + Annotations: map[string]string{ + k8s.TagsAnnotationKey: "", + }, + }, + }, + newRoute: &gwv1.HTTPRoute{ + ObjectMeta: metav1.ObjectMeta{ + Generation: 1, + Annotations: map[string]string{ + k8s.TagsAnnotationKey: "", + }, + }, + }, + expected: false, + }, + { + name: "empty additional tags to non-empty should trigger reconcile", + oldRoute: &gwv1.HTTPRoute{ + ObjectMeta: metav1.ObjectMeta{ + Generation: 1, + Annotations: map[string]string{ + k8s.TagsAnnotationKey: "", + }, + }, + }, + newRoute: &gwv1.HTTPRoute{ + ObjectMeta: metav1.ObjectMeta{ + Generation: 1, + Annotations: map[string]string{ + k8s.TagsAnnotationKey: "Environment=Dev", + }, + }, + }, + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + updateEvent := event.UpdateEvent{ + ObjectOld: tt.oldRoute, + ObjectNew: tt.newRoute, + } + + result := predicate.Update(updateEvent) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestAdditionalTagsAnnotationChangedPredicate_Create(t *testing.T) { + predicate := AdditionalTagsAnnotationChangedPredicate + + tests := []struct { + name string + route *gwv1.HTTPRoute + expected bool + }{ + { + name: "create with additional tags annotation should trigger reconcile", + route: &gwv1.HTTPRoute{ + ObjectMeta: metav1.ObjectMeta{ + Annotations: map[string]string{ + k8s.TagsAnnotationKey: "Environment=Dev,Project=MyApp", + }, + }, + }, + expected: true, + }, + { + name: "create without additional tags annotation should not trigger reconcile", + route: &gwv1.HTTPRoute{ + ObjectMeta: metav1.ObjectMeta{}, + }, + expected: false, + }, + { + name: "create with empty additional tags annotation should not trigger reconcile", + route: &gwv1.HTTPRoute{ + ObjectMeta: metav1.ObjectMeta{ + Annotations: map[string]string{ + k8s.TagsAnnotationKey: "", + }, + }, + }, + expected: false, + }, + { + name: "create with other annotations but no additional tags should not trigger reconcile", + route: &gwv1.HTTPRoute{ + ObjectMeta: metav1.ObjectMeta{ + Annotations: map[string]string{ + "other-annotation": "value", + }, + }, + }, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + createEvent := event.CreateEvent{ + Object: tt.route, + } + + result := predicate.Create(createEvent) + assert.Equal(t, tt.expected, result) + }) + } +} diff --git a/pkg/controllers/route_controller.go b/pkg/controllers/route_controller.go index c0a78658..326d44d0 100644 --- a/pkg/controllers/route_controller.go +++ b/pkg/controllers/route_controller.go @@ -21,6 +21,7 @@ import ( "fmt" "sigs.k8s.io/controller-runtime/pkg/controller" + "sigs.k8s.io/controller-runtime/pkg/predicate" "github.com/pkg/errors" corev1 "k8s.io/api/core/v1" @@ -115,7 +116,7 @@ func RegisterAllRouteControllers( svcImportEventHandler := eventhandlers.NewServiceImportEventHandler(log, mgrClient) builder := ctrl.NewControllerManagedBy(mgr). - For(routeInfo.gatewayApiType, builder.WithPredicates(predicates.NewRouteChangedPredicate())). + For(routeInfo.gatewayApiType, builder.WithPredicates(predicate.Or(predicates.NewRouteChangedPredicate(), predicates.AdditionalTagsAnnotationChangedPredicate))). Watches(&gwv1.Gateway{}, gwEventHandler). Watches(&corev1.Service{}, svcEventHandler.MapToRoute(routeInfo.routeType)). Watches(&anv1alpha1.ServiceImport{}, svcImportEventHandler.MapToRoute(routeInfo.routeType)). diff --git a/pkg/controllers/route_controller_test.go b/pkg/controllers/route_controller_test.go index f509c9fd..56f44c1a 100644 --- a/pkg/controllers/route_controller_test.go +++ b/pkg/controllers/route_controller_test.go @@ -194,6 +194,7 @@ func TestRouteReconciler_ReconcileCreates(t *testing.T) { }).AnyTimes() mockCloud.EXPECT().DefaultTags().Return(mocks.Tags{}).AnyTimes() mockCloud.EXPECT().DefaultTagsMergedWith(gomock.Any()).Return(mocks.Tags{}).AnyTimes() + mockCloud.EXPECT().MergeTags(gomock.Any(), gomock.Any()).Return(mocks.Tags{}).AnyTimes() // we expect a fair number of lattice calls mockLattice.EXPECT().ListTargetsAsList(gomock.Any(), gomock.Any()).Return( diff --git a/pkg/controllers/vpcassociationpolicy_controller.go b/pkg/controllers/vpcassociationpolicy_controller.go index f6cac4a9..97bbc38b 100644 --- a/pkg/controllers/vpcassociationpolicy_controller.go +++ b/pkg/controllers/vpcassociationpolicy_controller.go @@ -13,6 +13,7 @@ import ( anv1alpha1 "github.com/aws/aws-application-networking-k8s/pkg/apis/applicationnetworking/v1alpha1" pkg_aws "github.com/aws/aws-application-networking-k8s/pkg/aws" "github.com/aws/aws-application-networking-k8s/pkg/aws/services" + "github.com/aws/aws-application-networking-k8s/pkg/controllers/predicates" deploy "github.com/aws/aws-application-networking-k8s/pkg/deploy/lattice" "github.com/aws/aws-application-networking-k8s/pkg/k8s" policy "github.com/aws/aws-application-networking-k8s/pkg/k8s/policyhelper" @@ -49,7 +50,7 @@ func RegisterVpcAssociationPolicyController(log gwlog.Logger, cloud pkg_aws.Clou } b := ctrl.NewControllerManagedBy(mgr). - For(&anv1alpha1.VpcAssociationPolicy{}, builder.WithPredicates(predicate.GenerationChangedPredicate{})) + For(&anv1alpha1.VpcAssociationPolicy{}, builder.WithPredicates(predicate.Or(predicate.GenerationChangedPredicate{}, predicates.AdditionalTagsAnnotationChangedPredicate))) ph.AddWatchers(b, &gwv1.Gateway{}) return b.Complete(controller) } @@ -106,7 +107,10 @@ func (c *vpcAssociationPolicyReconciler) upsert(ctx context.Context, k8sPolicy * str := string(sg) return &str }) - snva, err := c.manager.UpsertVpcAssociation(ctx, snName, sgIds) + + additionalTags := k8s.GetAdditionalTagsFromAnnotations(ctx, k8sPolicy) + + snva, err := c.manager.UpsertVpcAssociation(ctx, snName, sgIds, additionalTags) if err != nil { return err } diff --git a/pkg/deploy/lattice/access_log_subscription_manager.go b/pkg/deploy/lattice/access_log_subscription_manager.go index e107a2f1..fb6fd144 100644 --- a/pkg/deploy/lattice/access_log_subscription_manager.go +++ b/pkg/deploy/lattice/access_log_subscription_manager.go @@ -51,6 +51,8 @@ func (m *defaultAccessLogSubscriptionManager) Create( lattice.AccessLogPolicyTagKey: aws.String(accessLogSubscription.Spec.ALPNamespacedName.String()), }) + tags = m.cloud.MergeTags(tags, accessLogSubscription.Spec.AdditionalTags) + createALSInput := &vpclattice.CreateAccessLogSubscriptionInput{ ResourceIdentifier: sourceArn, DestinationArn: &accessLogSubscription.Spec.DestinationArn, @@ -148,6 +150,11 @@ func (m *defaultAccessLogSubscriptionManager) Update( } updateALSOutput, err := vpcLatticeSess.UpdateAccessLogSubscriptionWithContext(ctx, updateALSInput) if err == nil { + err = m.cloud.Tagging().UpdateTags(ctx, *updateALSOutput.Arn, accessLogSubscription.Spec.AdditionalTags) + if err != nil { + return nil, fmt.Errorf("failed to update tags for access log subscription %s: %w", *updateALSOutput.Arn, err) + } + return &lattice.AccessLogSubscriptionStatus{ Arn: *updateALSOutput.Arn, }, nil diff --git a/pkg/deploy/lattice/access_log_subscription_manager_test.go b/pkg/deploy/lattice/access_log_subscription_manager_test.go index 4bca73aa..4fb2f64e 100644 --- a/pkg/deploy/lattice/access_log_subscription_manager_test.go +++ b/pkg/deploy/lattice/access_log_subscription_manager_test.go @@ -49,7 +49,8 @@ func TestAccessLogSubscriptionManager(t *testing.T) { defer c.Finish() ctx := context.TODO() mockLattice := services.NewMockLattice(c) - cloud := an_aws.NewDefaultCloud(mockLattice, TestCloudConfig) + mockTagging := services.NewMockTagging(c) + cloud := an_aws.NewDefaultCloudWithTagging(mockLattice, mockTagging, TestCloudConfig) expectedTags := cloud.DefaultTagsMergedWith(services.Tags{ lattice.AccessLogPolicyTagKey: aws.String(accessLogPolicyNamespacedName.String()), }) @@ -325,6 +326,8 @@ func TestAccessLogSubscriptionManager(t *testing.T) { mockLattice.EXPECT().FindServiceNetwork(ctx, sourceName).Return(serviceNetworkInfo, nil) mockLattice.EXPECT().UpdateAccessLogSubscriptionWithContext(ctx, updateALSInput).Return(updateALSOutput, nil) + mockTagging.EXPECT().UpdateTags(ctx, accessLogSubscriptionArn, gomock.Any()).Return(nil) + mgr := NewAccessLogSubscriptionManager(gwlog.FallbackLogger, cloud) resp, err := mgr.Update(ctx, accessLogSubscription) assert.Nil(t, err) @@ -527,3 +530,115 @@ func TestAccessLogSubscriptionManager(t *testing.T) { assert.Nil(t, err) }) } + +func Test_AccessLogSubscriptionManager_WithAdditionalTags_Create(t *testing.T) { + c := gomock.NewController(t) + defer c.Finish() + ctx := context.TODO() + mockLattice := services.NewMockLattice(c) + mockTagging := services.NewMockTagging(c) + cloud := an_aws.NewDefaultCloudWithTagging(mockLattice, mockTagging, TestCloudConfig) + + accessLogSubscription := &lattice.AccessLogSubscription{ + Spec: lattice.AccessLogSubscriptionSpec{ + SourceType: lattice.ServiceNetworkSourceType, + SourceName: sourceName, + DestinationArn: s3DestinationArn, + ALPNamespacedName: accessLogPolicyNamespacedName, + EventType: core.CreateEvent, + AdditionalTags: services.Tags{ + "Environment": &[]string{"Test"}[0], + "Project": &[]string{"ALSManager"}[0], + }, + }, + } + + serviceNetworkInfo := &services.ServiceNetworkInfo{ + SvcNetwork: vpclattice.ServiceNetworkSummary{ + Arn: aws.String(serviceNetworkArn), + Name: aws.String(sourceName), + }, + } + + baseTags := cloud.DefaultTagsMergedWith(services.Tags{ + lattice.AccessLogPolicyTagKey: aws.String(accessLogPolicyNamespacedName.String()), + }) + expectedTags := cloud.MergeTags(baseTags, accessLogSubscription.Spec.AdditionalTags) + + mockLattice.EXPECT().FindServiceNetwork(ctx, sourceName).Return(serviceNetworkInfo, nil) + mockLattice.EXPECT().CreateAccessLogSubscriptionWithContext(ctx, gomock.Any()).DoAndReturn( + func(ctx context.Context, input *vpclattice.CreateAccessLogSubscriptionInput, opts ...interface{}) (*vpclattice.CreateAccessLogSubscriptionOutput, error) { + assert.Equal(t, expectedTags, input.Tags, "ALS tags should include additional tags") + assert.Equal(t, serviceNetworkArn, *input.ResourceIdentifier) + assert.Equal(t, s3DestinationArn, *input.DestinationArn) + + return &vpclattice.CreateAccessLogSubscriptionOutput{ + Arn: aws.String(accessLogSubscriptionArn), + }, nil + }) + + mgr := NewAccessLogSubscriptionManager(gwlog.FallbackLogger, cloud) + resp, err := mgr.Create(ctx, accessLogSubscription) + assert.Nil(t, err) + assert.Equal(t, accessLogSubscriptionArn, resp.Arn) +} + +func Test_AccessLogSubscriptionManager_WithAdditionalTags_Update(t *testing.T) { + c := gomock.NewController(t) + defer c.Finish() + ctx := context.TODO() + mockLattice := services.NewMockLattice(c) + mockTagging := services.NewMockTagging(c) + cloud := an_aws.NewDefaultCloudWithTagging(mockLattice, mockTagging, TestCloudConfig) + + accessLogSubscription := &lattice.AccessLogSubscription{ + Spec: lattice.AccessLogSubscriptionSpec{ + SourceType: lattice.ServiceNetworkSourceType, + SourceName: sourceName, + DestinationArn: s3DestinationArn, + ALPNamespacedName: accessLogPolicyNamespacedName, + EventType: core.UpdateEvent, + AdditionalTags: services.Tags{ + "Environment": &[]string{"Prod"}[0], + "Project": &[]string{"ALSUpdate"}[0], + }, + }, + Status: &lattice.AccessLogSubscriptionStatus{ + Arn: accessLogSubscriptionArn, + }, + } + + serviceNetworkInfo := &services.ServiceNetworkInfo{ + SvcNetwork: vpclattice.ServiceNetworkSummary{ + Arn: aws.String(serviceNetworkArn), + Name: aws.String(sourceName), + }, + } + + getALSInput := &vpclattice.GetAccessLogSubscriptionInput{ + AccessLogSubscriptionIdentifier: aws.String(accessLogSubscriptionArn), + } + getALSOutput := &vpclattice.GetAccessLogSubscriptionOutput{ + Arn: aws.String(accessLogSubscriptionArn), + ResourceArn: aws.String(serviceNetworkArn), + DestinationArn: aws.String(s3DestinationArn), + } + updateALSInput := &vpclattice.UpdateAccessLogSubscriptionInput{ + AccessLogSubscriptionIdentifier: aws.String(accessLogSubscriptionArn), + DestinationArn: aws.String(s3DestinationArn), + } + updateALSOutput := &vpclattice.UpdateAccessLogSubscriptionOutput{ + Arn: aws.String(accessLogSubscriptionArn), + } + + mockLattice.EXPECT().GetAccessLogSubscriptionWithContext(ctx, getALSInput).Return(getALSOutput, nil) + mockLattice.EXPECT().FindServiceNetwork(ctx, sourceName).Return(serviceNetworkInfo, nil) + mockLattice.EXPECT().UpdateAccessLogSubscriptionWithContext(ctx, updateALSInput).Return(updateALSOutput, nil) + + mockTagging.EXPECT().UpdateTags(ctx, accessLogSubscriptionArn, accessLogSubscription.Spec.AdditionalTags).Return(nil) + + mgr := NewAccessLogSubscriptionManager(gwlog.FallbackLogger, cloud) + resp, err := mgr.Update(ctx, accessLogSubscription) + assert.Nil(t, err) + assert.Equal(t, accessLogSubscriptionArn, resp.Arn) +} diff --git a/pkg/deploy/lattice/listener_manager.go b/pkg/deploy/lattice/listener_manager.go index 1688d95d..0776f3f3 100644 --- a/pkg/deploy/lattice/listener_manager.go +++ b/pkg/deploy/lattice/listener_manager.go @@ -76,6 +76,12 @@ func (d *defaultListenerManager) Upsert( Id: aws.StringValue(latticeListenerSummary.Id), ServiceId: latticeSvcId, } + + err = d.cloud.Tagging().UpdateTags(ctx, aws.StringValue(latticeListenerSummary.Arn), modelListener.Spec.AdditionalTags) + if err != nil { + return model.ListenerStatus{}, fmt.Errorf("failed to update tags for listener %s due to %s", aws.StringValue(latticeListenerSummary.Id), err) + } + if modelListener.Spec.Protocol != vpclattice.ListenerProtocolTlsPassthrough { // The only mutable field for lattice listener is defaultAction, for non-TLS_PASSTHROUGH listener, the defaultAction is always the FixedResponse 404. Don't need to update. return existingListenerStatus, nil @@ -96,6 +102,8 @@ func (d *defaultListenerManager) Upsert( func (d *defaultListenerManager) create(ctx context.Context, latticeSvcId string, modelListener *model.Listener, defaultAction *vpclattice.RuleAction) ( model.ListenerStatus, error) { + tags := d.cloud.MergeTags(d.cloud.DefaultTags(), modelListener.Spec.AdditionalTags) + listenerInput := vpclattice.CreateListenerInput{ ClientToken: nil, DefaultAction: defaultAction, @@ -103,7 +111,7 @@ func (d *defaultListenerManager) create(ctx context.Context, latticeSvcId string Port: aws.Int64(modelListener.Spec.Port), Protocol: aws.String(modelListener.Spec.Protocol), ServiceIdentifier: aws.String(latticeSvcId), - Tags: d.cloud.DefaultTags(), + Tags: tags, } resp, err := d.cloud.Lattice().CreateListenerWithContext(ctx, &listenerInput) diff --git a/pkg/deploy/lattice/listener_manager_test.go b/pkg/deploy/lattice/listener_manager_test.go index 99b4e6d1..25ce656a 100644 --- a/pkg/deploy/lattice/listener_manager_test.go +++ b/pkg/deploy/lattice/listener_manager_test.go @@ -132,7 +132,8 @@ func Test_UpsertListener_DoNotNeedToUpdateExistingHTTPAndHTTPSListener(t *testin } mockLattice := mocks.NewMockLattice(c) - cloud := pkg_aws.NewDefaultCloud(mockLattice, TestCloudConfig) + mockTagging := mocks.NewMockTagging(c) + cloud := pkg_aws.NewDefaultCloudWithTagging(mockLattice, mockTagging, TestCloudConfig) mockLattice.EXPECT().ListListenersWithContext(ctx, gomock.Any()).Return( &vpclattice.ListListenersOutput{Items: []*vpclattice.ListenerSummary{ { @@ -143,6 +144,8 @@ func Test_UpsertListener_DoNotNeedToUpdateExistingHTTPAndHTTPSListener(t *testin }, }}, nil) + mockTagging.EXPECT().UpdateTags(ctx, "existing-arn", gomock.Any()).Return(nil) + mockLattice.EXPECT().GetListenerWithContext(ctx, gomock.Any()).Times(0) mockLattice.EXPECT().UpdateListenerWithContext(ctx, gomock.Any()).Times(0) @@ -251,7 +254,8 @@ func Test_UpsertListener_Update_TLS_PASSTHROUGHListener(t *testing.T) { defer c.Finish() ctx := context.TODO() mockLattice := mocks.NewMockLattice(c) - cloud := pkg_aws.NewDefaultCloud(mockLattice, TestCloudConfig) + mockTagging := mocks.NewMockTagging(c) + cloud := pkg_aws.NewDefaultCloudWithTagging(mockLattice, mockTagging, TestCloudConfig) for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ml := &model.Listener{ @@ -275,6 +279,8 @@ func Test_UpsertListener_Update_TLS_PASSTHROUGHListener(t *testing.T) { }, }}, nil) + mockTagging.EXPECT().UpdateTags(ctx, "existing-arn", gomock.Any()).Return(nil) + mockLattice.EXPECT().GetListenerWithContext(ctx, &vpclattice.GetListenerInput{ ServiceIdentifier: aws.String("svc-id"), ListenerIdentifier: aws.String("existing-listener-id"), @@ -653,3 +659,179 @@ func Test_ListenerManager_getLatticeListenerDefaultAction_TLS_PASSTHROUGH_Listen }) } } + +func Test_ListenerManager_WithAdditionalTags_Create(t *testing.T) { + c := gomock.NewController(t) + defer c.Finish() + ctx := context.TODO() + mockLattice := mocks.NewMockLattice(c) + mockTagging := mocks.NewMockTagging(c) + cloud := pkg_aws.NewDefaultCloudWithTagging(mockLattice, mockTagging, TestCloudConfig) + + ml := &model.Listener{ + Spec: model.ListenerSpec{ + Protocol: vpclattice.ListenerProtocolHttp, + Port: 8080, + DefaultAction: &model.DefaultAction{ + FixedResponseStatusCode: aws.Int64(404), + }, + AdditionalTags: mocks.Tags{ + "Environment": &[]string{"Test"}[0], + "Project": &[]string{"ListenerManager"}[0], + }, + }, + } + + ms := &model.Service{ + Status: &model.ServiceStatus{Id: "svc-id"}, + } + + mockLattice.EXPECT().ListListenersWithContext(ctx, gomock.Any()).Return( + &vpclattice.ListListenersOutput{Items: []*vpclattice.ListenerSummary{}}, nil) + + expectedTags := cloud.MergeTags(cloud.DefaultTags(), ml.Spec.AdditionalTags) + + mockLattice.EXPECT().CreateListenerWithContext(ctx, gomock.Any()).DoAndReturn( + func(ctx context.Context, input *vpclattice.CreateListenerInput, opts ...interface{}) (*vpclattice.CreateListenerOutput, error) { + assert.Equal(t, expectedTags, input.Tags, "Listener tags should include additional tags") + assert.Equal(t, int64(8080), *input.Port) + assert.Equal(t, "HTTP", *input.Protocol) + + return &vpclattice.CreateListenerOutput{ + Arn: aws.String("listener-arn"), + Id: aws.String("listener-id"), + Name: aws.String("listener-name"), + }, nil + }) + + lm := NewListenerManager(gwlog.FallbackLogger, cloud) + status, err := lm.Upsert(ctx, ml, ms) + assert.Nil(t, err) + assert.Equal(t, "listener-arn", status.ListenerArn) + assert.Equal(t, "listener-id", status.Id) +} + +func Test_ListenerManager_WithAdditionalTags_UpdateHTTP(t *testing.T) { + // Test case: update existing HTTP listener with additional tags (no action update needed) + c := gomock.NewController(t) + defer c.Finish() + ctx := context.TODO() + mockLattice := mocks.NewMockLattice(c) + mockTagging := mocks.NewMockTagging(c) + cloud := pkg_aws.NewDefaultCloudWithTagging(mockLattice, mockTagging, TestCloudConfig) + + ml := &model.Listener{ + Spec: model.ListenerSpec{ + Protocol: vpclattice.ListenerProtocolHttp, + Port: 8080, + DefaultAction: &model.DefaultAction{ + FixedResponseStatusCode: aws.Int64(404), + }, + AdditionalTags: mocks.Tags{ + "Environment": &[]string{"Prod"}[0], + "Project": &[]string{"ListenerUpdate"}[0], + }, + }, + } + + ms := &model.Service{ + Status: &model.ServiceStatus{Id: "svc-id"}, + } + + mockLattice.EXPECT().ListListenersWithContext(ctx, gomock.Any()).Return( + &vpclattice.ListListenersOutput{Items: []*vpclattice.ListenerSummary{ + { + Arn: aws.String("existing-arn"), + Id: aws.String("existing-id"), + Name: aws.String("existing-name"), + Port: aws.Int64(8080), + }, + }}, nil) + + mockTagging.EXPECT().UpdateTags(ctx, "existing-arn", ml.Spec.AdditionalTags).Return(nil) + + // No UpdateListener call expected for HTTP listeners (only tags are updated) + mockLattice.EXPECT().UpdateListenerWithContext(ctx, gomock.Any()).Times(0) + mockLattice.EXPECT().GetListenerWithContext(ctx, gomock.Any()).Times(0) + + lm := NewListenerManager(gwlog.FallbackLogger, cloud) + status, err := lm.Upsert(ctx, ml, ms) + assert.Nil(t, err) + assert.Equal(t, "existing-arn", status.ListenerArn) + assert.Equal(t, "existing-id", status.Id) +} + +func Test_ListenerManager_WithAdditionalTags_UpdateTLSPassthrough(t *testing.T) { + // Test case: update existing TLS_PASSTHROUGH listener with additional tags and action update + c := gomock.NewController(t) + defer c.Finish() + ctx := context.TODO() + mockLattice := mocks.NewMockLattice(c) + mockTagging := mocks.NewMockTagging(c) + cloud := pkg_aws.NewDefaultCloudWithTagging(mockLattice, mockTagging, TestCloudConfig) + + ml := &model.Listener{ + Spec: model.ListenerSpec{ + Protocol: vpclattice.ListenerProtocolTlsPassthrough, + Port: 443, + DefaultAction: &model.DefaultAction{ + Forward: &model.RuleAction{ + TargetGroups: []*model.RuleTargetGroup{ + { + LatticeTgId: "tg-id-1", + Weight: 100, + }, + }, + }, + }, + AdditionalTags: mocks.Tags{ + "Environment": &[]string{"Staging"}[0], + "Project": &[]string{"TLSListener"}[0], + }, + }, + } + + ms := &model.Service{ + Status: &model.ServiceStatus{Id: "svc-id"}, + } + + // Existing TLS_PASSTHROUGH listener found + mockLattice.EXPECT().ListListenersWithContext(ctx, gomock.Any()).Return( + &vpclattice.ListListenersOutput{Items: []*vpclattice.ListenerSummary{ + { + Arn: aws.String("existing-tls-arn"), + Id: aws.String("existing-tls-id"), + Name: aws.String("existing-tls-name"), + Port: aws.Int64(443), + Protocol: aws.String(vpclattice.ListenerProtocolTlsPassthrough), + }, + }}, nil) + + mockTagging.EXPECT().UpdateTags(ctx, "existing-tls-arn", ml.Spec.AdditionalTags).Return(nil) + + mockLattice.EXPECT().GetListenerWithContext(ctx, gomock.Any()).Return( + &vpclattice.GetListenerOutput{ + DefaultAction: &vpclattice.RuleAction{ + Forward: &vpclattice.ForwardAction{ + TargetGroups: []*vpclattice.WeightedTargetGroup{ + { + TargetGroupIdentifier: aws.String("old-tg-id"), + Weight: aws.Int64(100), + }, + }, + }, + }, + }, nil) + + // Mock UpdateListener call (action is different, so update needed) + mockLattice.EXPECT().UpdateListenerWithContext(ctx, gomock.Any()).Return( + &vpclattice.UpdateListenerOutput{ + Id: aws.String("existing-tls-id"), + }, nil) + + lm := NewListenerManager(gwlog.FallbackLogger, cloud) + status, err := lm.Upsert(ctx, ml, ms) + assert.Nil(t, err) + assert.Equal(t, "existing-tls-arn", status.ListenerArn) + assert.Equal(t, "existing-tls-id", status.Id) +} diff --git a/pkg/deploy/lattice/rule_manager.go b/pkg/deploy/lattice/rule_manager.go index 4232b82e..63cbcc7a 100644 --- a/pkg/deploy/lattice/rule_manager.go +++ b/pkg/deploy/lattice/rule_manager.go @@ -192,9 +192,9 @@ func (r *defaultRuleManager) Upsert( } if matchingRule == nil { - return r.create(ctx, currentLatticeRules, latticeRuleFromModel, latticeServiceId, latticeListenerId) + return r.create(ctx, currentLatticeRules, latticeRuleFromModel, latticeServiceId, latticeListenerId, modelRule) } else { - return r.updateIfNeeded(ctx, latticeRuleFromModel, matchingRule, latticeServiceId, latticeListenerId) + return r.updateIfNeeded(ctx, latticeRuleFromModel, matchingRule, latticeServiceId, latticeListenerId, modelRule) } } @@ -204,6 +204,7 @@ func (r *defaultRuleManager) updateIfNeeded( matchingRule *vpclattice.GetRuleOutput, latticeSvcId string, latticeListenerId string, + modelRule *model.Rule, ) (model.RuleStatus, error) { updatedRuleStatus := model.RuleStatus{ Name: aws.StringValue(matchingRule.Name), @@ -214,6 +215,11 @@ func (r *defaultRuleManager) updateIfNeeded( Priority: aws.Int64Value(matchingRule.Priority), } + err := r.cloud.Tagging().UpdateTags(ctx, aws.StringValue(matchingRule.Arn), modelRule.Spec.AdditionalTags) + if err != nil { + return model.RuleStatus{}, fmt.Errorf("failed to update tags for rule %s: %w", aws.StringValue(matchingRule.Id), err) + } + // we already validated Match, if Action is also the same then no updates required updateNeeded := !reflect.DeepEqual(ruleToUpdate.Action, matchingRule.Action) if !updateNeeded { @@ -234,7 +240,7 @@ func (r *defaultRuleManager) updateIfNeeded( Priority: ruleToUpdate.Priority, } - _, err := r.cloud.Lattice().UpdateRuleWithContext(ctx, &uri) + _, err = r.cloud.Lattice().UpdateRuleWithContext(ctx, &uri) if err != nil { return model.RuleStatus{}, fmt.Errorf("failed UpdateRule %d for %s, %s due to %s", ruleToUpdate.Priority, latticeListenerId, latticeSvcId, err) @@ -250,6 +256,7 @@ func (r *defaultRuleManager) create( ruleToCreate *vpclattice.GetRuleOutput, latticeSvcId string, latticeListenerId string, + modelRule *model.Rule, ) (model.RuleStatus, error) { // when we create a rule, we just pick an available priority so we can // successfully create the rule. After all rules are created, we update @@ -261,6 +268,8 @@ func (r *defaultRuleManager) create( } ruleToCreate.Priority = aws.Int64(priority) + tags := r.cloud.MergeTags(r.cloud.DefaultTags(), modelRule.Spec.AdditionalTags) + cri := vpclattice.CreateRuleInput{ Action: ruleToCreate.Action, ServiceIdentifier: aws.String(latticeSvcId), @@ -268,7 +277,7 @@ func (r *defaultRuleManager) create( Match: ruleToCreate.Match, Name: ruleToCreate.Name, Priority: ruleToCreate.Priority, - Tags: r.cloud.DefaultTags(), + Tags: tags, } res, err := r.cloud.Lattice().CreateRuleWithContext(ctx, &cri) diff --git a/pkg/deploy/lattice/rule_manager_test.go b/pkg/deploy/lattice/rule_manager_test.go index 3d210bbd..44196846 100644 --- a/pkg/deploy/lattice/rule_manager_test.go +++ b/pkg/deploy/lattice/rule_manager_test.go @@ -2,6 +2,8 @@ package lattice import ( "context" + "testing" + pkg_aws "github.com/aws/aws-application-networking-k8s/pkg/aws" mocks "github.com/aws/aws-application-networking-k8s/pkg/aws/services" model "github.com/aws/aws-application-networking-k8s/pkg/model/lattice" @@ -10,7 +12,6 @@ import ( "github.com/aws/aws-sdk-go/service/vpclattice" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" - "testing" ) func Test_Create(t *testing.T) { @@ -144,6 +145,9 @@ func Test_Create(t *testing.T) { }) t.Run("test update method match", func(t *testing.T) { + mockTagging := mocks.NewMockTagging(c) + cloudWithTagging := pkg_aws.NewDefaultCloudWithTagging(mockLattice, mockTagging, TestCloudConfig) + mockLattice.EXPECT().GetRulesAsList(ctx, gomock.Any()).Return( []*vpclattice.GetRuleOutput{ { @@ -162,6 +166,8 @@ func Test_Create(t *testing.T) { }, }, nil) + mockTagging.EXPECT().UpdateTags(ctx, "existing-arn", gomock.Any()).Return(nil) + mockLattice.EXPECT().UpdateRuleWithContext(ctx, gomock.Any()).Return( &vpclattice.UpdateRuleOutput{ Arn: aws.String("existing-arn"), @@ -169,13 +175,16 @@ func Test_Create(t *testing.T) { Name: aws.String("existing-name"), }, nil) - rm := NewRuleManager(gwlog.FallbackLogger, cloud) + rm := NewRuleManager(gwlog.FallbackLogger, cloudWithTagging) ruleStatus, err := rm.Upsert(ctx, r, l, svc) assert.Nil(t, err) assert.Equal(t, "existing-arn", ruleStatus.Arn) }) t.Run("test update path match", func(t *testing.T) { + mockTagging := mocks.NewMockTagging(c) + cloudWithTagging := pkg_aws.NewDefaultCloudWithTagging(mockLattice, mockTagging, TestCloudConfig) + mockLattice.EXPECT().GetRulesAsList(ctx, gomock.Any()).Return( []*vpclattice.GetRuleOutput{ { @@ -200,6 +209,8 @@ func Test_Create(t *testing.T) { }, }, nil) + mockTagging.EXPECT().UpdateTags(ctx, "existing-arn", gomock.Any()).Return(nil) + mockLattice.EXPECT().UpdateRuleWithContext(ctx, gomock.Any()).Return( &vpclattice.UpdateRuleOutput{ Arn: aws.String("existing-arn"), @@ -207,13 +218,16 @@ func Test_Create(t *testing.T) { Name: aws.String("existing-name"), }, nil) - rm := NewRuleManager(gwlog.FallbackLogger, cloud) + rm := NewRuleManager(gwlog.FallbackLogger, cloudWithTagging) ruleStatus, err := rm.Upsert(ctx, r2, l, svc) assert.Nil(t, err) assert.Equal(t, "existing-arn", ruleStatus.Arn) }) t.Run("test update - nothing to do", func(t *testing.T) { + mockTagging := mocks.NewMockTagging(c) + cloudWithTagging := pkg_aws.NewDefaultCloudWithTagging(mockLattice, mockTagging, TestCloudConfig) + mockLattice.EXPECT().GetRulesAsList(ctx, gomock.Any()).Return( []*vpclattice.GetRuleOutput{ { @@ -239,7 +253,9 @@ func Test_Create(t *testing.T) { }, }, nil) // <-- should be an exact match, no update required - rm := NewRuleManager(gwlog.FallbackLogger, cloud) + mockTagging.EXPECT().UpdateTags(ctx, "existing-arn", gomock.Any()).Return(nil) + + rm := NewRuleManager(gwlog.FallbackLogger, cloudWithTagging) ruleStatus, err := rm.Upsert(ctx, r, l, svc) assert.Nil(t, err) assert.Equal(t, "existing-arn", ruleStatus.Arn) @@ -399,3 +415,213 @@ func Test_UpdatePriorities(t *testing.T) { err := rm.UpdatePriorities(ctx, "svc-id", "l-id", rules) assert.Nil(t, err) } + +func Test_RuleManager_WithAdditionalTags_Create(t *testing.T) { + c := gomock.NewController(t) + defer c.Finish() + ctx := context.TODO() + mockLattice := mocks.NewMockLattice(c) + mockTagging := mocks.NewMockTagging(c) + cloud := pkg_aws.NewDefaultCloudWithTagging(mockLattice, mockTagging, TestCloudConfig) + + svc := &model.Service{ + Status: &model.ServiceStatus{Id: "svc-id"}, + } + + l := &model.Listener{ + Spec: model.ListenerSpec{ + Port: 80, + Protocol: "HTTP", + }, + Status: &model.ListenerStatus{Id: "listener-id"}, + } + + r := &model.Rule{ + Spec: model.RuleSpec{ + Priority: 1, + Method: "POST", + Action: model.RuleAction{ + TargetGroups: []*model.RuleTargetGroup{ + { + LatticeTgId: "tg-id", + Weight: 1, + }, + }, + }, + AdditionalTags: mocks.Tags{ + "Environment": &[]string{"Test"}[0], + "Project": &[]string{"RuleManager"}[0], + }, + }, + } + + mockLattice.EXPECT().GetRulesAsList(ctx, gomock.Any()).Return([]*vpclattice.GetRuleOutput{}, nil) + + expectedTags := cloud.MergeTags(cloud.DefaultTags(), r.Spec.AdditionalTags) + + mockLattice.EXPECT().CreateRuleWithContext(ctx, gomock.Any()).DoAndReturn( + func(ctx context.Context, input *vpclattice.CreateRuleInput, i ...interface{}) (*vpclattice.CreateRuleOutput, error) { + assert.Equal(t, expectedTags, input.Tags, "Rule tags should include additional tags") + + return &vpclattice.CreateRuleOutput{ + Arn: aws.String("arn"), + Id: aws.String("id"), + Name: aws.String("name"), + }, nil + }) + + rm := NewRuleManager(gwlog.FallbackLogger, cloud) + ruleStatus, err := rm.Upsert(ctx, r, l, svc) + assert.Nil(t, err) + assert.Equal(t, "arn", ruleStatus.Arn) +} + +func Test_RuleManager_WithAdditionalTags_Update(t *testing.T) { + c := gomock.NewController(t) + defer c.Finish() + ctx := context.TODO() + mockLattice := mocks.NewMockLattice(c) + mockTagging := mocks.NewMockTagging(c) + cloud := pkg_aws.NewDefaultCloudWithTagging(mockLattice, mockTagging, TestCloudConfig) + + svc := &model.Service{ + Status: &model.ServiceStatus{Id: "svc-id"}, + } + + l := &model.Listener{ + Spec: model.ListenerSpec{ + Port: 80, + Protocol: "HTTP", + }, + Status: &model.ListenerStatus{Id: "listener-id"}, + } + + r := &model.Rule{ + Spec: model.RuleSpec{ + Priority: 1, + Method: "POST", + Action: model.RuleAction{ + TargetGroups: []*model.RuleTargetGroup{ + { + LatticeTgId: "tg-id", + Weight: 1, + }, + }, + }, + AdditionalTags: mocks.Tags{ + "Environment": &[]string{"Prod"}[0], + "Project": &[]string{"RuleUpdate"}[0], + }, + }, + } + + mockLattice.EXPECT().GetRulesAsList(ctx, gomock.Any()).Return( + []*vpclattice.GetRuleOutput{ + { + Id: aws.String("existing-id"), + Arn: aws.String("existing-arn"), + Match: &vpclattice.RuleMatch{ + HttpMatch: &vpclattice.HttpMatch{ + Method: aws.String("POST"), + }, + }, + Action: &vpclattice.RuleAction{ + FixedResponse: &vpclattice.FixedResponseAction{}, // Different action will trigger update + }, + Name: aws.String("existing-name"), + Priority: aws.Int64(1), + }, + }, nil) + + mockTagging.EXPECT().UpdateTags(ctx, "existing-arn", r.Spec.AdditionalTags).Return(nil) + + mockLattice.EXPECT().UpdateRuleWithContext(ctx, gomock.Any()).Return( + &vpclattice.UpdateRuleOutput{ + Arn: aws.String("existing-arn"), + Id: aws.String("existing-id"), + Name: aws.String("existing-name"), + }, nil) + + rm := NewRuleManager(gwlog.FallbackLogger, cloud) + ruleStatus, err := rm.Upsert(ctx, r, l, svc) + assert.Nil(t, err) + assert.Equal(t, "existing-arn", ruleStatus.Arn) +} + +func Test_RuleManager_WithAdditionalTags_UpdateNoActionChange(t *testing.T) { + // Test case: update existing rule with additional tags but no action change + c := gomock.NewController(t) + defer c.Finish() + ctx := context.TODO() + mockLattice := mocks.NewMockLattice(c) + mockTagging := mocks.NewMockTagging(c) + cloud := pkg_aws.NewDefaultCloudWithTagging(mockLattice, mockTagging, TestCloudConfig) + + svc := &model.Service{ + Status: &model.ServiceStatus{Id: "svc-id"}, + } + + l := &model.Listener{ + Spec: model.ListenerSpec{ + Port: 80, + Protocol: "HTTP", + }, + Status: &model.ListenerStatus{Id: "listener-id"}, + } + + r := &model.Rule{ + Spec: model.RuleSpec{ + Priority: 1, + Method: "POST", + Action: model.RuleAction{ + TargetGroups: []*model.RuleTargetGroup{ + { + LatticeTgId: "tg-id", + Weight: 1, + }, + }, + }, + AdditionalTags: mocks.Tags{ + "Environment": &[]string{"Staging"}[0], + "Project": &[]string{"RuleNoUpdate"}[0], + }, + }, + } + + // Existing rule with exact match (no action update needed) + mockLattice.EXPECT().GetRulesAsList(ctx, gomock.Any()).Return( + []*vpclattice.GetRuleOutput{ + { + Id: aws.String("existing-id"), + Arn: aws.String("existing-arn"), + Match: &vpclattice.RuleMatch{ + HttpMatch: &vpclattice.HttpMatch{ + Method: aws.String("POST"), + }, + }, + Action: &vpclattice.RuleAction{ + Forward: &vpclattice.ForwardAction{ + TargetGroups: []*vpclattice.WeightedTargetGroup{ + { + TargetGroupIdentifier: aws.String("tg-id"), + Weight: aws.Int64(1), + }, + }, + }, + }, + Name: aws.String("existing-name"), + Priority: aws.Int64(1), + }, + }, nil) + + // Mock UpdateTags call for additional tags (should still be called even if no action update) + mockTagging.EXPECT().UpdateTags(ctx, "existing-arn", r.Spec.AdditionalTags).Return(nil) + + // No UpdateRule call expected since action matches + mockLattice.EXPECT().UpdateRuleWithContext(ctx, gomock.Any()).Times(0) + + rm := NewRuleManager(gwlog.FallbackLogger, cloud) + ruleStatus, err := rm.Upsert(ctx, r, l, svc) + assert.Nil(t, err) + assert.Equal(t, "existing-arn", ruleStatus.Arn) +} diff --git a/pkg/deploy/lattice/service_manager.go b/pkg/deploy/lattice/service_manager.go index 91fbca5b..223b4521 100644 --- a/pkg/deploy/lattice/service_manager.go +++ b/pkg/deploy/lattice/service_manager.go @@ -62,7 +62,7 @@ func (m *defaultServiceManager) createServiceAndAssociate(ctx context.Context, s // Only create associations if service networks are specified (not standalone) if len(svc.Spec.ServiceNetworkNames) > 0 { for _, snName := range svc.Spec.ServiceNetworkNames { - err = m.createAssociation(ctx, createSvcResp.Id, snName) + err = m.createAssociation(ctx, createSvcResp.Id, snName, svc) if err != nil { return ServiceInfo{}, err } @@ -76,16 +76,18 @@ func (m *defaultServiceManager) createServiceAndAssociate(ctx context.Context, s return svcInfo, nil } -func (m *defaultServiceManager) createAssociation(ctx context.Context, svcId *string, snName string) error { +func (m *defaultServiceManager) createAssociation(ctx context.Context, svcId *string, snName string, svc *Service) error { snInfo, err := m.cloud.Lattice().FindServiceNetwork(ctx, snName) if err != nil { return err } + tags := m.cloud.MergeTags(m.cloud.DefaultTags(), svc.Spec.AdditionalTags) + assocReq := &CreateSnSvcAssocReq{ ServiceIdentifier: svcId, ServiceNetworkIdentifier: snInfo.SvcNetwork.Id, - Tags: m.cloud.DefaultTags(), + Tags: tags, } assocResp, err := m.cloud.Lattice().CreateServiceNetworkServiceAssociationWithContext(ctx, assocReq) if err != nil { @@ -104,9 +106,11 @@ func (m *defaultServiceManager) createAssociation(ctx context.Context, svcId *st func (m *defaultServiceManager) newCreateSvcReq(svc *Service) *CreateSvcReq { svcName := svc.LatticeServiceName() + tags := m.cloud.MergeTags(m.cloud.DefaultTagsMergedWith(svc.Spec.ToTags()), svc.Spec.AdditionalTags) + req := &vpclattice.CreateServiceInput{ Name: &svcName, - Tags: m.cloud.DefaultTagsMergedWith(svc.Spec.ToTags()), + Tags: tags, } if svc.Spec.CustomerDomainName != "" { @@ -170,6 +174,11 @@ func (m *defaultServiceManager) checkAndUpdateTags(ctx context.Context, svc *Ser } func (m *defaultServiceManager) updateServiceAndAssociations(ctx context.Context, svc *Service, svcSum *SvcSummary) (ServiceInfo, error) { + err := m.cloud.Tagging().UpdateTags(ctx, aws.StringValue(svcSum.Arn), svc.Spec.AdditionalTags) + if err != nil { + return ServiceInfo{}, fmt.Errorf("failed to update tags for service %s: %w", aws.StringValue(svcSum.Id), err) + } + if svc.Spec.CustomerCertARN != "" { updReq := &UpdateSvcReq{ CertificateArn: aws.String(svc.Spec.CustomerCertARN), @@ -181,7 +190,7 @@ func (m *defaultServiceManager) updateServiceAndAssociations(ctx context.Context } } - err := m.updateAssociations(ctx, svc, svcSum) + err = m.updateAssociations(ctx, svc, svcSum) if err != nil { return ServiceInfo{}, err } @@ -215,17 +224,24 @@ func (m *defaultServiceManager) updateAssociations(ctx context.Context, svc *Ser return err } - toCreate, toDelete, err := associationsDiff(svc, assocs) + toCreate, toDelete, toUpdate, err := associationsDiff(svc, assocs) if err != nil { return err } for _, snName := range toCreate { - err := m.createAssociation(ctx, svcSum.Id, snName) + err := m.createAssociation(ctx, svcSum.Id, snName, svc) if err != nil { return err } } + for _, assoc := range toUpdate { + err := m.cloud.Tagging().UpdateTags(ctx, aws.StringValue(assoc.Arn), svc.Spec.AdditionalTags) + if err != nil { + return fmt.Errorf("failed to update tags for association %s: %w", aws.StringValue(assoc.Arn), err) + } + } + for _, assoc := range toDelete { isManaged, err := m.cloud.IsArnManaged(ctx, *assoc.Arn) if err != nil { @@ -265,9 +281,10 @@ func handleCreateAssociationResp(resp *CreateSnSvcAssocResp) error { // compare current sn-svc associations with new ones, // returns 2 slices: toCreate with SN names and toDelete with current associations // if assoc should be created but current state is in deletion we should retry -func associationsDiff(svc *Service, curAssocs []*SnSvcAssocSummary) ([]string, []SnSvcAssocSummary, error) { +func associationsDiff(svc *Service, curAssocs []*SnSvcAssocSummary) ([]string, []SnSvcAssocSummary, []*SnSvcAssocSummary, error) { toCreate := []string{} toDelete := []SnSvcAssocSummary{} + toUpdate := []*SnSvcAssocSummary{} // create two Sets and find Difference New-Old->toCreate and Old-New->toDelete newSet := map[string]bool{} @@ -283,12 +300,14 @@ func associationsDiff(svc *Service, curAssocs []*SnSvcAssocSummary) ([]string, [ oldSn, ok := oldSet[newSn] if !ok { toCreate = append(toCreate, newSn) + } else { + toUpdate = append(toUpdate, &oldSn) } // assoc should exists but in deletion state, will retry later to re-create // TODO: we should have something more lightweight, retrying full reconciliation looks to heavy if aws.StringValue(oldSn.Status) == vpclattice.ServiceNetworkServiceAssociationStatusDeleteInProgress { - return nil, nil, fmt.Errorf("%w: want to associate sn: %s to svc: %s, but status is: %s", + return nil, nil, nil, fmt.Errorf("%w: want to associate sn: %s to svc: %s, but status is: %s", lattice_runtime.NewRetryError(), newSn, svc.LatticeServiceName(), *oldSn.Status) } // TODO: if assoc in failed state, may be we should try to re-create? @@ -301,7 +320,7 @@ func associationsDiff(svc *Service, curAssocs []*SnSvcAssocSummary) ([]string, [ } } - return toCreate, toDelete, nil + return toCreate, toDelete, toUpdate, nil } func (m *defaultServiceManager) deleteAllAssociations(ctx context.Context, svc *SvcSummary) error { diff --git a/pkg/deploy/lattice/service_manager_test.go b/pkg/deploy/lattice/service_manager_test.go index bbd0a37b..fb56864c 100644 --- a/pkg/deploy/lattice/service_manager_test.go +++ b/pkg/deploy/lattice/service_manager_test.go @@ -22,8 +22,9 @@ func TestServiceManagerInteg(t *testing.T) { defer c.Finish() mockLattice := mocks.NewMockLattice(c) + mockTagging := mocks.NewMockTagging(c) cfg := pkg_aws.CloudConfig{VpcId: "vpc-id", AccountId: "account-id"} - cl := pkg_aws.NewDefaultCloud(mockLattice, cfg) + cl := pkg_aws.NewDefaultCloudWithTagging(mockLattice, mockTagging, cfg) ctx := context.Background() m := NewServiceManager(gwlog.FallbackLogger, cl) @@ -141,6 +142,8 @@ func TestServiceManagerInteg(t *testing.T) { }). Times(1) // for service only + mockTagging.EXPECT().UpdateTags(ctx, "svc-arn", gomock.Any()).Return(nil) + // 3 associations exist in lattice: keep, delete, and foreign mockLattice.EXPECT(). ListServiceNetworkServiceAssociationsAsList(gomock.Any(), gomock.Any()). @@ -158,6 +161,8 @@ func TestServiceManagerInteg(t *testing.T) { }). Times(1) + mockTagging.EXPECT().UpdateTags(ctx, "sn-keep-arn", gomock.Any()).Return(nil) + // return managed by gateway controller tags for all associations except for foreign and foreign ram mockLattice.EXPECT().ListTagsForResourceWithContext(gomock.Any(), gomock.Any()). DoAndReturn(func(_ context.Context, req *vpclattice.ListTagsForResourceInput, _ ...interface{}) (*vpclattice.ListTagsForResourceOutput, error) { @@ -254,6 +259,8 @@ func TestServiceManagerInteg(t *testing.T) { Tags: svc.Spec.ToTags(), })).Times(1) + mockTagging.EXPECT().UpdateTags(ctx, "svc-arn", gomock.Any()).Return(nil) + mockLattice.EXPECT().ListServiceNetworkServiceAssociationsAsList(gomock.Any(), gomock.Any()).Times(1) mockLattice.EXPECT(). CreateServiceNetworkServiceAssociationWithContext(gomock.Any(), gomock.Any()). @@ -351,6 +358,8 @@ func TestServiceManagerInteg(t *testing.T) { }). Times(1) // for service only + mockTagging.EXPECT().UpdateTags(ctx, "standalone-svc-arn", gomock.Any()).Return(nil) + // no associations exist for standalone service mockLattice.EXPECT(). ListServiceNetworkServiceAssociationsAsList(gomock.Any(), gomock.Any()). @@ -522,7 +531,7 @@ func TestSnSvcAssocsDiff(t *testing.T) { ServiceNetworkNames: []string{"sn"}, }} assocs := []*SnSvcAssocSummary{{ServiceNetworkName: aws.String("sn")}} - c, d, err := associationsDiff(svc, assocs) + c, d, _, err := associationsDiff(svc, assocs) assert.Nil(t, err) assert.Equal(t, 0, len(c)) assert.Equal(t, 0, len(d)) @@ -533,7 +542,7 @@ func TestSnSvcAssocsDiff(t *testing.T) { ServiceNetworkNames: []string{"sn1", "sn2"}, }} assocs := []*SnSvcAssocSummary{} - c, d, _ := associationsDiff(svc, assocs) + c, d, _, _ := associationsDiff(svc, assocs) assert.Equal(t, 2, len(c)) assert.Equal(t, 0, len(d)) }) @@ -544,7 +553,7 @@ func TestSnSvcAssocsDiff(t *testing.T) { {ServiceNetworkName: aws.String("sn1")}, {ServiceNetworkName: aws.String("sn2")}, } - c, d, _ := associationsDiff(svc, assocs) + c, d, _, _ := associationsDiff(svc, assocs) assert.Equal(t, 0, len(c)) assert.Equal(t, 2, len(d)) }) @@ -557,7 +566,7 @@ func TestSnSvcAssocsDiff(t *testing.T) { {ServiceNetworkName: aws.String("sn1")}, {ServiceNetworkName: aws.String("sn4")}, } - c, d, _ := associationsDiff(svc, assocs) + c, d, _, _ := associationsDiff(svc, assocs) assert.Equal(t, 2, len(c)) assert.Equal(t, 1, len(d)) }) @@ -570,9 +579,273 @@ func TestSnSvcAssocsDiff(t *testing.T) { ServiceNetworkName: aws.String("sn"), Status: aws.String(vpclattice.ServiceNetworkServiceAssociationStatusDeleteInProgress), }} - _, _, err := associationsDiff(svc, assocs) + _, _, _, err := associationsDiff(svc, assocs) var requeueNeededAfter *lattice_runtime.RequeueNeededAfter assert.True(t, errors.As(err, &requeueNeededAfter)) }) } + +func Test_ServiceManager_WithAdditionalTags_CreateService(t *testing.T) { + c := gomock.NewController(t) + defer c.Finish() + + mockLattice := mocks.NewMockLattice(c) + mockTagging := mocks.NewMockTagging(c) + cfg := pkg_aws.CloudConfig{VpcId: "vpc-id", AccountId: "account-id"} + cl := pkg_aws.NewDefaultCloudWithTagging(mockLattice, mockTagging, cfg) + ctx := context.Background() + m := NewServiceManager(gwlog.FallbackLogger, cl) + + svc := &Service{ + Spec: model.ServiceSpec{ + ServiceTagFields: model.ServiceTagFields{ + RouteName: "svc-with-tags", + RouteNamespace: "ns", + RouteType: core.HttpRouteType, + }, + ServiceNetworkNames: []string{"sn"}, + AdditionalTags: mocks.Tags{ + "Environment": &[]string{"Test"}[0], + "Project": &[]string{"ServiceManager"}[0], + }, + }, + } + + mockLattice.EXPECT(). + FindService(gomock.Any(), gomock.Any()). + Return(nil, mocks.NewNotFoundError("", "")). + Times(1) + + expectedServiceTags := cl.MergeTags(cl.DefaultTagsMergedWith(svc.Spec.ToTags()), svc.Spec.AdditionalTags) + + mockLattice.EXPECT(). + CreateServiceWithContext(gomock.Any(), gomock.Any()). + DoAndReturn( + func(_ context.Context, req *CreateSvcReq, _ ...interface{}) (*CreateSvcResp, error) { + assert.Equal(t, svc.LatticeServiceName(), *req.Name) + assert.Equal(t, expectedServiceTags, req.Tags, "Service tags should include additional tags") + return &CreateSvcResp{ + Arn: aws.String("arn"), + DnsEntry: &vpclattice.DnsEntry{DomainName: aws.String("dns")}, + Id: aws.String("svc-id"), + }, nil + }). + Times(1) + + expectedAssocTags := cl.MergeTags(cl.DefaultTags(), svc.Spec.AdditionalTags) + + mockLattice.EXPECT(). + CreateServiceNetworkServiceAssociationWithContext(gomock.Any(), gomock.Any()). + DoAndReturn( + func(_ context.Context, req *CreateSnSvcAssocReq, _ ...interface{}) (*CreateSnSvcAssocResp, error) { + assert.Equal(t, "sn-id", *req.ServiceNetworkIdentifier) + assert.Equal(t, expectedAssocTags, req.Tags, "Association tags should include additional tags") + return &CreateSnSvcAssocResp{ + Status: aws.String(vpclattice.ServiceNetworkServiceAssociationStatusActive), + }, nil + }). + Times(1) + + mockLattice.EXPECT(). + FindServiceNetwork(gomock.Any(), gomock.Any()). + DoAndReturn( + func(ctx context.Context, name string) (*mocks.ServiceNetworkInfo, error) { + return &mocks.ServiceNetworkInfo{ + SvcNetwork: vpclattice.ServiceNetworkSummary{ + Arn: aws.String("sn-arn"), + Id: aws.String("sn-id"), + Name: aws.String(name), + }, + Tags: nil, + }, nil + }). + Times(1) + + status, err := m.Upsert(ctx, svc) + assert.Nil(t, err) + assert.Equal(t, "arn", status.Arn) +} + +func Test_ServiceManager_WithAdditionalTags_UpdateService(t *testing.T) { + c := gomock.NewController(t) + defer c.Finish() + + mockLattice := mocks.NewMockLattice(c) + mockTagging := mocks.NewMockTagging(c) + cfg := pkg_aws.CloudConfig{VpcId: "vpc-id", AccountId: "account-id"} + cl := pkg_aws.NewDefaultCloudWithTagging(mockLattice, mockTagging, cfg) + ctx := context.Background() + m := NewServiceManager(gwlog.FallbackLogger, cl) + + svc := &Service{ + Spec: model.ServiceSpec{ + ServiceTagFields: model.ServiceTagFields{ + RouteName: "svc-update", + RouteNamespace: "ns", + RouteType: core.HttpRouteType, + }, + ServiceNetworkNames: []string{"sn-keep"}, + AdditionalTags: mocks.Tags{ + "Environment": &[]string{"Prod"}[0], + "Project": &[]string{"UpdateTest"}[0], + }, + }, + } + + mockLattice.EXPECT(). + FindService(gomock.Any(), gomock.Any()). + Return(&vpclattice.ServiceSummary{ + Arn: aws.String("svc-arn"), + Id: aws.String("svc-id"), + Name: aws.String(svc.LatticeServiceName()), + }, nil). + Times(1) + + mockLattice.EXPECT().ListTagsForResourceWithContext(gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, req *vpclattice.ListTagsForResourceInput, _ ...interface{}) (*vpclattice.ListTagsForResourceOutput, error) { + return &vpclattice.ListTagsForResourceOutput{ + Tags: cl.DefaultTagsMergedWith(svc.Spec.ToTags()), + }, nil + }). + Times(1) + + mockTagging.EXPECT().UpdateTags(ctx, "svc-arn", svc.Spec.AdditionalTags).Return(nil) + + mockLattice.EXPECT(). + ListServiceNetworkServiceAssociationsAsList(gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, req *ListSnSvcAssocsReq) ([]*SnSvcAssocSummary, error) { + return []*SnSvcAssocSummary{ + { + Arn: aws.String("assoc-arn"), + Id: aws.String("assoc-id"), + ServiceNetworkName: aws.String("sn-keep"), + Status: aws.String(vpclattice.ServiceNetworkServiceAssociationStatusActive), + }, + }, nil + }). + Times(1) + + mockTagging.EXPECT().UpdateTags(ctx, "assoc-arn", svc.Spec.AdditionalTags).Return(nil) + + status, err := m.Upsert(ctx, svc) + assert.Nil(t, err) + assert.Equal(t, "svc-arn", status.Arn) +} + +func Test_ServiceManager_WithAdditionalTags_UpdateAssociations(t *testing.T) { + c := gomock.NewController(t) + defer c.Finish() + + mockLattice := mocks.NewMockLattice(c) + mockTagging := mocks.NewMockTagging(c) + cfg := pkg_aws.CloudConfig{VpcId: "vpc-id", AccountId: "account-id"} + cl := pkg_aws.NewDefaultCloudWithTagging(mockLattice, mockTagging, cfg) + ctx := context.Background() + m := NewServiceManager(gwlog.FallbackLogger, cl) + + svc := &Service{ + Spec: model.ServiceSpec{ + ServiceTagFields: model.ServiceTagFields{ + RouteName: "complex-svc", + RouteNamespace: "ns", + RouteType: core.HttpRouteType, + }, + ServiceNetworkNames: []string{"sn-keep", "sn-create"}, + AdditionalTags: mocks.Tags{ + "Environment": &[]string{"Complex"}[0], + "Project": &[]string{"UpdateAssocTest"}[0], + }, + }, + } + + mockLattice.EXPECT(). + FindService(gomock.Any(), gomock.Any()). + Return(&vpclattice.ServiceSummary{ + Arn: aws.String("svc-arn"), + Id: aws.String("svc-id"), + Name: aws.String(svc.LatticeServiceName()), + }, nil). + Times(1) + + mockLattice.EXPECT().ListTagsForResourceWithContext(gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, req *vpclattice.ListTagsForResourceInput, _ ...interface{}) (*vpclattice.ListTagsForResourceOutput, error) { + return &vpclattice.ListTagsForResourceOutput{ + Tags: cl.DefaultTagsMergedWith(svc.Spec.ToTags()), + }, nil + }). + Times(1) + + mockTagging.EXPECT().UpdateTags(ctx, "svc-arn", svc.Spec.AdditionalTags).Return(nil) + + mockLattice.EXPECT(). + ListServiceNetworkServiceAssociationsAsList(gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, req *ListSnSvcAssocsReq) ([]*SnSvcAssocSummary, error) { + return []*SnSvcAssocSummary{ + { + Arn: aws.String("sn-keep-arn"), + Id: aws.String("sn-keep-id"), + ServiceNetworkName: aws.String("sn-keep"), + Status: aws.String(vpclattice.ServiceNetworkServiceAssociationStatusActive), + }, + { + Arn: aws.String("sn-delete-arn"), + Id: aws.String("sn-delete-id"), + ServiceNetworkName: aws.String("sn-delete"), + Status: aws.String(vpclattice.ServiceNetworkServiceAssociationStatusActive), + }, + }, nil + }). + Times(1) + + mockTagging.EXPECT().UpdateTags(ctx, "sn-keep-arn", svc.Spec.AdditionalTags).Return(nil) + + mockLattice.EXPECT().ListTagsForResourceWithContext(gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, req *vpclattice.ListTagsForResourceInput, _ ...interface{}) (*vpclattice.ListTagsForResourceOutput, error) { + return &vpclattice.ListTagsForResourceOutput{ + Tags: cl.DefaultTags(), + }, nil + }). + Times(1) + + expectedAssocTags := cl.MergeTags(cl.DefaultTags(), svc.Spec.AdditionalTags) + mockLattice.EXPECT(). + CreateServiceNetworkServiceAssociationWithContext(gomock.Any(), gomock.Any()). + DoAndReturn( + func(_ context.Context, req *CreateSnSvcAssocReq, _ ...interface{}) (*CreateSnSvcAssocResp, error) { + assert.Equal(t, "sn-create-id", *req.ServiceNetworkIdentifier) + assert.Equal(t, expectedAssocTags, req.Tags, "New association should include additional tags") + return &CreateSnSvcAssocResp{ + Status: aws.String(vpclattice.ServiceNetworkServiceAssociationStatusActive), + }, nil + }). + Times(1) + + mockLattice.EXPECT(). + DeleteServiceNetworkServiceAssociationWithContext(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn( + func(_ context.Context, req *DelSnSvcAssocReq, _ ...interface{}) (*DelSnSvcAssocResp, error) { + assert.Equal(t, "sn-delete-arn", *req.ServiceNetworkServiceAssociationIdentifier) + return &DelSnSvcAssocResp{}, nil + }). + Times(1) + + mockLattice.EXPECT(). + FindServiceNetwork(gomock.Any(), gomock.Any()). + DoAndReturn( + func(ctx context.Context, name string) (*mocks.ServiceNetworkInfo, error) { + return &mocks.ServiceNetworkInfo{ + SvcNetwork: vpclattice.ServiceNetworkSummary{ + Arn: aws.String(name + "-arn"), + Id: aws.String(name + "-id"), + Name: aws.String(name), + }, + Tags: nil, + }, nil + }). + Times(1) + + status, err := m.Upsert(ctx, svc) + assert.Nil(t, err) + assert.Equal(t, "svc-arn", status.Arn) +} diff --git a/pkg/deploy/lattice/service_network_manager.go b/pkg/deploy/lattice/service_network_manager.go index 844222be..26d8d222 100644 --- a/pkg/deploy/lattice/service_network_manager.go +++ b/pkg/deploy/lattice/service_network_manager.go @@ -22,7 +22,7 @@ import ( //go:generate mockgen -destination service_network_manager_mock.go -package lattice github.com/aws/aws-application-networking-k8s/pkg/deploy/lattice ServiceNetworkManager type ServiceNetworkManager interface { - UpsertVpcAssociation(ctx context.Context, snName string, sgIds []*string) (string, error) + UpsertVpcAssociation(ctx context.Context, snName string, sgIds []*string, additionalTags services.Tags) (string, error) DeleteVpcAssociation(ctx context.Context, snName string) error CreateOrUpdate(ctx context.Context, serviceNetwork *model.ServiceNetwork) (model.ServiceNetworkStatus, error) @@ -40,7 +40,7 @@ type defaultServiceNetworkManager struct { cloud pkg_aws.Cloud } -func (m *defaultServiceNetworkManager) UpsertVpcAssociation(ctx context.Context, snName string, sgIds []*string) (string, error) { +func (m *defaultServiceNetworkManager) UpsertVpcAssociation(ctx context.Context, snName string, sgIds []*string, additionalTags services.Tags) (string, error) { sn, err := m.cloud.Lattice().FindServiceNetwork(ctx, snName) if err != nil { return "", err @@ -60,17 +60,19 @@ func (m *defaultServiceNetworkManager) UpsertVpcAssociation(ctx context.Context, return "", services.NewConflictError("snva", snName, fmt.Sprintf("Found existing vpc association not owned by controller: %s", *snva.Arn)) } - _, err = m.updateServiceNetworkVpcAssociation(ctx, &sn.SvcNetwork, sgIds, snva.Id) + _, err = m.updateServiceNetworkVpcAssociation(ctx, &sn.SvcNetwork, sgIds, snva.Id, additionalTags) if err != nil { return "", err } return *snva.Arn, nil } else { + tags := m.cloud.MergeTags(m.cloud.DefaultTags(), additionalTags) + req := vpclattice.CreateServiceNetworkVpcAssociationInput{ ServiceNetworkIdentifier: sn.SvcNetwork.Id, VpcIdentifier: &config.VpcID, SecurityGroupIds: sgIds, - Tags: m.cloud.DefaultTags(), + Tags: tags, } resp, err := m.cloud.Lattice().CreateServiceNetworkVpcAssociationWithContext(ctx, &req) if err != nil { @@ -223,13 +225,19 @@ func (m *defaultServiceNetworkManager) CreateOrUpdate(ctx context.Context, servi return model.ServiceNetworkStatus{ServiceNetworkARN: serviceNetworkArn, ServiceNetworkID: serviceNetworkId}, nil } -func (m *defaultServiceNetworkManager) updateServiceNetworkVpcAssociation(ctx context.Context, existingSN *vpclattice.ServiceNetworkSummary, sgIds []*string, existingSnvaId *string) (model.ServiceNetworkStatus, error) { +func (m *defaultServiceNetworkManager) updateServiceNetworkVpcAssociation(ctx context.Context, existingSN *vpclattice.ServiceNetworkSummary, sgIds []*string, existingSnvaId *string, additionalTags services.Tags) (model.ServiceNetworkStatus, error) { snva, err := m.cloud.Lattice().GetServiceNetworkVpcAssociationWithContext(ctx, &vpclattice.GetServiceNetworkVpcAssociationInput{ ServiceNetworkVpcAssociationIdentifier: existingSnvaId, }) if err != nil { return model.ServiceNetworkStatus{}, err } + + err = m.cloud.Tagging().UpdateTags(ctx, aws.StringValue(snva.Arn), additionalTags) + if err != nil { + return model.ServiceNetworkStatus{}, fmt.Errorf("failed to update tags for service network vpc association %s: %w", aws.StringValue(snva.Id), err) + } + sgIdsEqual := securityGroupIdsEqual(sgIds, snva.SecurityGroupIds) if sgIdsEqual { // desiredSN's security group ids are same with snva's security group ids, don't need to update diff --git a/pkg/deploy/lattice/service_network_manager_mock.go b/pkg/deploy/lattice/service_network_manager_mock.go index ff6fa00c..548b4c64 100644 --- a/pkg/deploy/lattice/service_network_manager_mock.go +++ b/pkg/deploy/lattice/service_network_manager_mock.go @@ -65,16 +65,16 @@ func (mr *MockServiceNetworkManagerMockRecorder) DeleteVpcAssociation(arg0, arg1 } // UpsertVpcAssociation mocks base method. -func (m *MockServiceNetworkManager) UpsertVpcAssociation(arg0 context.Context, arg1 string, arg2 []*string) (string, error) { +func (m *MockServiceNetworkManager) UpsertVpcAssociation(arg0 context.Context, arg1 string, arg2 []*string, arg3 map[string]*string) (string, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpsertVpcAssociation", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "UpsertVpcAssociation", arg0, arg1, arg2, arg3) ret0, _ := ret[0].(string) ret1, _ := ret[1].(error) return ret0, ret1 } // UpsertVpcAssociation indicates an expected call of UpsertVpcAssociation. -func (mr *MockServiceNetworkManagerMockRecorder) UpsertVpcAssociation(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockServiceNetworkManagerMockRecorder) UpsertVpcAssociation(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertVpcAssociation", reflect.TypeOf((*MockServiceNetworkManager)(nil).UpsertVpcAssociation), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertVpcAssociation", reflect.TypeOf((*MockServiceNetworkManager)(nil).UpsertVpcAssociation), arg0, arg1, arg2, arg3) } diff --git a/pkg/deploy/lattice/service_network_manager_test.go b/pkg/deploy/lattice/service_network_manager_test.go index baddc7f4..a9455e09 100644 --- a/pkg/deploy/lattice/service_network_manager_test.go +++ b/pkg/deploy/lattice/service_network_manager_test.go @@ -25,7 +25,8 @@ func Test_CreateOrUpdateServiceNetwork_SnNotExist_NeedToAssociate(t *testing.T) defer c.Finish() ctx := context.TODO() mockLattice := mocks.NewMockLattice(c) - cloud := pkg_aws.NewDefaultCloud(mockLattice, TestCloudConfig) + mockTagging := mocks.NewMockTagging(c) + cloud := pkg_aws.NewDefaultCloudWithTagging(mockLattice, mockTagging, TestCloudConfig) snCreateInput := model.ServiceNetwork{ Spec: model.ServiceNetworkSpec{ @@ -80,7 +81,8 @@ func Test_CreateOrUpdateServiceNetwork_ListFailed(t *testing.T) { defer c.Finish() ctx := context.TODO() mockLattice := mocks.NewMockLattice(c) - cloud := pkg_aws.NewDefaultCloud(mockLattice, TestCloudConfig) + mockTagging := mocks.NewMockTagging(c) + cloud := pkg_aws.NewDefaultCloudWithTagging(mockLattice, mockTagging, TestCloudConfig) snCreateInput := model.ServiceNetwork{ Spec: model.ServiceNetworkSpec{ @@ -503,7 +505,8 @@ func Test_defaultServiceNetworkManager_CreateOrUpdate_SnExists_SnvaExists_Update defer c.Finish() ctx := context.TODO() mockLattice := mocks.NewMockLattice(c) - cloud := pkg_aws.NewDefaultCloud(mockLattice, TestCloudConfig) + mockTagging := mocks.NewMockTagging(c) + cloud := pkg_aws.NewDefaultCloudWithTagging(mockLattice, mockTagging, TestCloudConfig) mockLattice.EXPECT().FindServiceNetwork(ctx, gomock.Any()).Return( &mocks.ServiceNetworkInfo{ SvcNetwork: item, @@ -521,6 +524,9 @@ func Test_defaultServiceNetworkManager_CreateOrUpdate_SnExists_SnvaExists_Update mockLattice.EXPECT().ListTagsForResourceWithContext(ctx, gomock.Any()).Return(&vpclattice.ListTagsForResourceOutput{ Tags: cloud.DefaultTags(), }, nil) + + mockTagging.EXPECT().UpdateTags(ctx, gomock.Any(), gomock.Any()).Return(nil) + mockLattice.EXPECT().CreateServiceNetworkServiceAssociationWithContext(ctx, gomock.Any()).MaxTimes(0) mockLattice.EXPECT().UpdateServiceNetworkVpcAssociationWithContext(ctx, gomock.Any()).Return(&vpclattice.UpdateServiceNetworkVpcAssociationOutput{ Arn: &snArn, @@ -530,7 +536,7 @@ func Test_defaultServiceNetworkManager_CreateOrUpdate_SnExists_SnvaExists_Update }, nil) snMgr := NewDefaultServiceNetworkManager(gwlog.FallbackLogger, cloud) - resp, err := snMgr.UpsertVpcAssociation(ctx, name, securityGroupIds) + resp, err := snMgr.UpsertVpcAssociation(ctx, name, securityGroupIds, nil) assert.Equal(t, err, nil) assert.Equal(t, resp, snArn) @@ -564,13 +570,15 @@ func Test_defaultServiceNetworkManager_CreateOrUpdate_SnExists_SnvaExists_Securi defer c.Finish() ctx := context.TODO() mockLattice := mocks.NewMockLattice(c) - cloud := pkg_aws.NewDefaultCloud(mockLattice, TestCloudConfig) + mockTagging := mocks.NewMockTagging(c) + cloud := pkg_aws.NewDefaultCloudWithTagging(mockLattice, mockTagging, TestCloudConfig) mockLattice.EXPECT().FindServiceNetwork(ctx, gomock.Any()).Return( &mocks.ServiceNetworkInfo{ SvcNetwork: item, Tags: nil, }, nil) mockLattice.EXPECT().GetServiceNetworkVpcAssociationWithContext(ctx, gomock.Any()).Return(&vpclattice.GetServiceNetworkVpcAssociationOutput{ + Arn: &snvaArn, ServiceNetworkArn: &snArn, ServiceNetworkId: &snId, ServiceNetworkName: &name, @@ -582,11 +590,14 @@ func Test_defaultServiceNetworkManager_CreateOrUpdate_SnExists_SnvaExists_Securi mockLattice.EXPECT().ListTagsForResourceWithContext(ctx, gomock.Any()).Return(&vpclattice.ListTagsForResourceOutput{ Tags: cloud.DefaultTags(), }, nil) + + mockTagging.EXPECT().UpdateTags(ctx, gomock.Any(), gomock.Any()).Return(nil) + mockLattice.EXPECT().CreateServiceNetworkServiceAssociationWithContext(ctx, gomock.Any()).Times(0) mockLattice.EXPECT().UpdateServiceNetworkVpcAssociationWithContext(ctx, gomock.Any()).Times(0) snMgr := NewDefaultServiceNetworkManager(gwlog.FallbackLogger, cloud) - resp, err := snMgr.UpsertVpcAssociation(ctx, name, securityGroupIds) + resp, err := snMgr.UpsertVpcAssociation(ctx, name, securityGroupIds, nil) assert.Equal(t, err, nil) assert.Equal(t, resp, snArn) @@ -630,7 +641,7 @@ func Test_defaultServiceNetworkManager_CreateOrUpdate_SnExists_SnvaCreateInProgr mockLattice.EXPECT().UpdateServiceNetworkVpcAssociationWithContext(ctx, gomock.Any()).Times(0) snMgr := NewDefaultServiceNetworkManager(gwlog.FallbackLogger, cloud) - _, err := snMgr.UpsertVpcAssociation(ctx, name, securityGroupIds) + _, err := snMgr.UpsertVpcAssociation(ctx, name, securityGroupIds, nil) var requeueNeededAfter *lattice_runtime.RequeueNeededAfter assert.True(t, errors.As(err, &requeueNeededAfter)) @@ -663,13 +674,15 @@ func Test_defaultServiceNetworkManager_CreateOrUpdate_SnExists_SnvaExists_Cannot defer c.Finish() ctx := context.TODO() mockLattice := mocks.NewMockLattice(c) - cloud := pkg_aws.NewDefaultCloud(mockLattice, TestCloudConfig) + mockTagging := mocks.NewMockTagging(c) + cloud := pkg_aws.NewDefaultCloudWithTagging(mockLattice, mockTagging, TestCloudConfig) mockLattice.EXPECT().FindServiceNetwork(ctx, gomock.Any()).Return( &mocks.ServiceNetworkInfo{ SvcNetwork: item, Tags: nil, }, nil) mockLattice.EXPECT().GetServiceNetworkVpcAssociationWithContext(ctx, gomock.Any()).Return(&vpclattice.GetServiceNetworkVpcAssociationOutput{ + Arn: &snvaArn, ServiceNetworkArn: &snArn, ServiceNetworkId: &snId, ServiceNetworkName: &name, @@ -681,12 +694,131 @@ func Test_defaultServiceNetworkManager_CreateOrUpdate_SnExists_SnvaExists_Cannot mockLattice.EXPECT().ListTagsForResourceWithContext(ctx, gomock.Any()).Return(&vpclattice.ListTagsForResourceOutput{ Tags: cloud.DefaultTags(), }, nil) + + mockTagging.EXPECT().UpdateTags(ctx, gomock.Any(), gomock.Any()).Return(nil) + mockLattice.EXPECT().CreateServiceNetworkServiceAssociationWithContext(ctx, gomock.Any()).Times(0) updateSNVAError := errors.New("InvalidParameterException SecurityGroupIds cannot be empty") mockLattice.EXPECT().UpdateServiceNetworkVpcAssociationWithContext(ctx, gomock.Any()).Return(&vpclattice.UpdateServiceNetworkVpcAssociationOutput{}, updateSNVAError) snMgr := NewDefaultServiceNetworkManager(gwlog.FallbackLogger, cloud) - _, err := snMgr.UpsertVpcAssociation(ctx, name, []*string{}) + _, err := snMgr.UpsertVpcAssociation(ctx, name, []*string{}, nil) assert.Equal(t, err, updateSNVAError) } + +func Test_UpsertVpcAssociation_WithAdditionalTags_ExistingAssociation(t *testing.T) { + securityGroupIds := []*string{aws.String("sg-123456789"), aws.String("sg-987654321")} + snId := "12345678912345678912" + snArn := "12345678912345678912" + snvaArn := "12345678912345678912" + name := "test" + vpcId := config.VpcID + item := vpclattice.ServiceNetworkSummary{ + Arn: &snArn, + Id: &snId, + Name: &name, + } + + status := vpclattice.ServiceNetworkVpcAssociationStatusActive + items := vpclattice.ServiceNetworkVpcAssociationSummary{ + ServiceNetworkArn: &snArn, + ServiceNetworkId: &snId, + ServiceNetworkName: &snId, + Status: &status, + VpcId: &config.VpcID, + Arn: &snvaArn, + } + statusServiceNetworkVPCOutput := []*vpclattice.ServiceNetworkVpcAssociationSummary{&items} + + c := gomock.NewController(t) + defer c.Finish() + ctx := context.TODO() + mockLattice := mocks.NewMockLattice(c) + mockTagging := mocks.NewMockTagging(c) + cloud := pkg_aws.NewDefaultCloudWithTagging(mockLattice, mockTagging, TestCloudConfig) + + mockLattice.EXPECT().FindServiceNetwork(ctx, gomock.Any()).Return( + &mocks.ServiceNetworkInfo{ + SvcNetwork: item, + Tags: nil, + }, nil) + mockLattice.EXPECT().GetServiceNetworkVpcAssociationWithContext(ctx, gomock.Any()).Return(&vpclattice.GetServiceNetworkVpcAssociationOutput{ + Arn: &snvaArn, + ServiceNetworkArn: &snArn, + ServiceNetworkId: &snId, + ServiceNetworkName: &name, + Status: aws.String(vpclattice.ServiceNetworkVpcAssociationStatusActive), + VpcId: &vpcId, + SecurityGroupIds: securityGroupIds, + }, nil) + mockLattice.EXPECT().ListServiceNetworkVpcAssociationsAsList(ctx, gomock.Any()).Return(statusServiceNetworkVPCOutput, nil) + mockLattice.EXPECT().ListTagsForResourceWithContext(ctx, gomock.Any()).Return(&vpclattice.ListTagsForResourceOutput{ + Tags: cloud.DefaultTags(), + }, nil) + + additionalTags := mocks.Tags{ + "Environment": &[]string{"Test"}[0], + "Project": &[]string{"SNManager"}[0], + } + + mockTagging.EXPECT().UpdateTags(ctx, snvaArn, additionalTags).Return(nil) + + mockLattice.EXPECT().UpdateServiceNetworkVpcAssociationWithContext(ctx, gomock.Any()).Times(0) + + snMgr := NewDefaultServiceNetworkManager(gwlog.FallbackLogger, cloud) + resp, err := snMgr.UpsertVpcAssociation(ctx, name, securityGroupIds, additionalTags) + + assert.Equal(t, err, nil) + assert.Equal(t, resp, snvaArn) +} + +func Test_UpsertVpcAssociation_WithAdditionalTags_NoExistingAssociation(t *testing.T) { + securityGroupIds := []*string{aws.String("sg-123456789"), aws.String("sg-987654321")} + snId := "12345678912345678912" + snArn := "12345678912345678912" + name := "test" + item := vpclattice.ServiceNetworkSummary{ + Arn: &snArn, + Id: &snId, + Name: &name, + } + + c := gomock.NewController(t) + defer c.Finish() + ctx := context.TODO() + mockLattice := mocks.NewMockLattice(c) + mockTagging := mocks.NewMockTagging(c) + cloud := pkg_aws.NewDefaultCloudWithTagging(mockLattice, mockTagging, TestCloudConfig) + + mockLattice.EXPECT().FindServiceNetwork(ctx, gomock.Any()).Return( + &mocks.ServiceNetworkInfo{ + SvcNetwork: item, + Tags: nil, + }, nil) + + mockLattice.EXPECT().ListServiceNetworkVpcAssociationsAsList(ctx, gomock.Any()).Return([]*vpclattice.ServiceNetworkVpcAssociationSummary{}, nil) + + additionalTags := mocks.Tags{ + "Environment": &[]string{"Prod"}[0], + "Project": &[]string{"CreateTest"}[0], + } + + expectedTags := cloud.MergeTags(cloud.DefaultTags(), additionalTags) + + mockLattice.EXPECT().CreateServiceNetworkVpcAssociationWithContext(ctx, gomock.Any()).DoAndReturn( + func(ctx context.Context, input *vpclattice.CreateServiceNetworkVpcAssociationInput, opts ...interface{}) (*vpclattice.CreateServiceNetworkVpcAssociationOutput, error) { + assert.Equal(t, expectedTags, input.Tags, "Tags should include both default and additional tags") + + return &vpclattice.CreateServiceNetworkVpcAssociationOutput{ + Arn: &snArn, + Status: aws.String(vpclattice.ServiceNetworkVpcAssociationStatusActive), + }, nil + }) + + snMgr := NewDefaultServiceNetworkManager(gwlog.FallbackLogger, cloud) + resp, err := snMgr.UpsertVpcAssociation(ctx, name, securityGroupIds, additionalTags) + + assert.Equal(t, err, nil) + assert.Equal(t, resp, snArn) +} diff --git a/pkg/deploy/lattice/target_group_manager.go b/pkg/deploy/lattice/target_group_manager.go index a948d6eb..6f958f29 100644 --- a/pkg/deploy/lattice/target_group_manager.go +++ b/pkg/deploy/lattice/target_group_manager.go @@ -103,6 +103,8 @@ func (s *defaultTargetGroupManager) create(ctx context.Context, modelTg *model.T createInput.Tags[model.K8SRouteNamespaceKey] = &modelTg.Spec.K8SRouteNamespace } + createInput.Tags = s.awsCloud.MergeTags(createInput.Tags, modelTg.Spec.AdditionalTags) + lattice := s.awsCloud.Lattice() resp, err := lattice.CreateTargetGroupWithContext(ctx, &createInput) if err != nil { @@ -130,6 +132,11 @@ func (s *defaultTargetGroupManager) create(ctx context.Context, modelTg *model.T func (s *defaultTargetGroupManager) update(ctx context.Context, targetGroup *model.TargetGroup, latticeTg *vpclattice.GetTargetGroupOutput) (model.TargetGroupStatus, error) { healthCheckConfig := targetGroup.Spec.HealthCheckConfig + err := s.awsCloud.Tagging().UpdateTags(ctx, aws.StringValue(latticeTg.Arn), targetGroup.Spec.AdditionalTags) + if err != nil { + return model.TargetGroupStatus{}, fmt.Errorf("failed to update tags for target group %s: %w", aws.StringValue(latticeTg.Id), err) + } + if healthCheckConfig == nil { s.log.Debugf(ctx, "HealthCheck is empty. Resetting to default settings") healthCheckConfig = &vpclattice.HealthCheckConfig{} diff --git a/pkg/deploy/lattice/target_group_manager_test.go b/pkg/deploy/lattice/target_group_manager_test.go index 4c93844e..109a9f87 100644 --- a/pkg/deploy/lattice/target_group_manager_test.go +++ b/pkg/deploy/lattice/target_group_manager_test.go @@ -74,6 +74,12 @@ func Test_CreateTargetGroup_TGNotExist_Active(t *testing.T) { tgSpec.K8SRouteNamespace = "default" tgSpec.Type = model.TargetGroupTypeIP } + + tgSpec.AdditionalTags = map[string]*string{ + "Environment": &[]string{"Test"}[0], + "Project": &[]string{"TGManager"}[0], + } + tgCreateInput := model.TargetGroup{ ResourceMeta: core.ResourceMeta{}, Spec: tgSpec, @@ -95,6 +101,9 @@ func Test_CreateTargetGroup_TGNotExist_Active(t *testing.T) { expectedTags[model.K8SRouteNamespaceKey] = &tgSpec.K8SRouteNamespace } + expectedTags["Environment"] = &[]string{"Test"}[0] + expectedTags["Project"] = &[]string{"TGManager"}[0] + mockTagging.EXPECT().FindResourcesByTags(ctx, gomock.Any(), gomock.Any()).Return(nil, nil) mockLattice.EXPECT().CreateTargetGroupWithContext(ctx, gomock.Any()).DoAndReturn( func(ctx context.Context, input *vpclattice.CreateTargetGroupInput, arg3 ...interface{}) (*vpclattice.CreateTargetGroupOutput, error) { @@ -217,6 +226,10 @@ func Test_CreateTargetGroup_TGActive_UpdateHealthCheck(t *testing.T) { Protocol: vpclattice.TargetGroupProtocolHttps, ProtocolVersion: vpclattice.TargetGroupProtocolVersionHttp1, HealthCheckConfig: tt.healthCheckConfig, + AdditionalTags: map[string]*string{ + "Environment": &[]string{"Test"}[0], + "Project": &[]string{"UpdateTest"}[0], + }, } tgCreateInput := model.TargetGroup{ @@ -239,6 +252,8 @@ func Test_CreateTargetGroup_TGActive_UpdateHealthCheck(t *testing.T) { mockTagging.EXPECT().FindResourcesByTags(ctx, gomock.Any(), gomock.Any()).Return([]string{arn}, nil) mockLattice.EXPECT().GetTargetGroupWithContext(ctx, gomock.Any()).Return(&tgOutput, nil) + mockTagging.EXPECT().UpdateTags(ctx, arn, tgSpec.AdditionalTags).Return(nil) + if tt.wantErr { mockLattice.EXPECT().UpdateTargetGroupWithContext(ctx, gomock.Any()).Return(nil, errors.New("error")) } else { @@ -307,6 +322,9 @@ func Test_CreateTargetGroup_TGActive_HealthCheckSame(t *testing.T) { mockTagging.EXPECT().FindResourcesByTags(ctx, gomock.Any(), gomock.Any()).Return([]string{"arn"}, nil) mockLattice.EXPECT().GetTargetGroupWithContext(ctx, gomock.Any()).Return(&tgOutput, nil) + + mockTagging.EXPECT().UpdateTags(ctx, gomock.Any(), gomock.Any()).Return(nil) + mockLattice.EXPECT().UpdateTargetGroupWithContext(ctx, gomock.Any()).Times(0) tgManager := NewTargetGroupManager(gwlog.FallbackLogger, cloud, nil) @@ -1218,6 +1236,8 @@ func Test_update_ServiceExportWithPolicy_Integration(t *testing.T) { mockTagging := mocks.NewMockTagging(c) cloud := pkg_aws.NewDefaultCloudWithTagging(mockLattice, mockTagging, TestCloudConfig) + mockTagging.EXPECT().UpdateTags(ctx, gomock.Any(), gomock.Any()).Return(nil) + // Since we don't have a real k8s k8sClient in this test, we'll test the case where // no k8sClient is available (which should fall back to default behavior) tgManager := NewTargetGroupManager(gwlog.FallbackLogger, cloud, nil) @@ -1435,6 +1455,8 @@ func Test_update_ServiceExportWithPolicyResolution(t *testing.T) { }, } + mockTagging.EXPECT().UpdateTags(ctx, gomock.Any(), gomock.Any()).Return(nil) + tgManager := NewTargetGroupManager(gwlog.FallbackLogger, cloud, k8sClient) if tt.expectUpdate { @@ -1617,6 +1639,8 @@ func Test_update_BackwardsCompatibility(t *testing.T) { }, } + mockTagging.EXPECT().UpdateTags(ctx, gomock.Any(), gomock.Any()).Return(nil) + tgManager := NewTargetGroupManager(gwlog.FallbackLogger, cloud, k8sClient) if tt.expectUpdate { diff --git a/pkg/gateway/model_build_access_log_subscription.go b/pkg/gateway/model_build_access_log_subscription.go index b2c6badb..bdd2803a 100644 --- a/pkg/gateway/model_build_access_log_subscription.go +++ b/pkg/gateway/model_build_access_log_subscription.go @@ -103,6 +103,9 @@ func (t *accessLogSubscriptionModelBuildTask) run(ctx context.Context) error { ALPNamespacedName: t.accessLogPolicy.GetNamespacedName(), EventType: eventType, } + + alsSpec.AdditionalTags = k8s.GetAdditionalTagsFromAnnotations(ctx, t.accessLogPolicy) + t.accessLogSubscription = model.NewAccessLogSubscription(t.stack, alsSpec, status) err = t.stack.AddResource(t.accessLogSubscription) if err != nil { diff --git a/pkg/gateway/model_build_access_log_subscription_test.go b/pkg/gateway/model_build_access_log_subscription_test.go index 580a8deb..c953a49f 100644 --- a/pkg/gateway/model_build_access_log_subscription_test.go +++ b/pkg/gateway/model_build_access_log_subscription_test.go @@ -350,3 +350,71 @@ func Test_BuildAccessLogSubscription(t *testing.T) { assert.Equal(t, tt.expectedError, err, tt.description) } } + +func Test_BuildAccessLogSubscription_WithAndWithoutAdditionalTagsAnnotation(t *testing.T) { + ctx := context.TODO() + scheme := runtime.NewScheme() + clientgoscheme.AddToScheme(scheme) + client := testclient.NewClientBuilder().WithScheme(scheme).Build() + modelBuilder := NewAccessLogSubscriptionModelBuilder(gwlog.FallbackLogger, client) + + tests := []struct { + name string + input *anv1alpha1.AccessLogPolicy + expectedAdditionalTags map[string]*string + description string + }{ + { + name: "AccessLogPolicy with additional tags annotation", + input: &anv1alpha1.AccessLogPolicy{ + ObjectMeta: apimachineryv1.ObjectMeta{ + Namespace: namespace, + Name: name, + Annotations: map[string]string{ + "application-networking.k8s.aws/tags": "Environment=Prod,Project=AccessLogTest,Team=Platform", + }, + }, + Spec: anv1alpha1.AccessLogPolicySpec{ + DestinationArn: aws.String(s3DestinationArn), + TargetRef: &gwv1alpha2.NamespacedPolicyTargetReference{ + Kind: gatewayKind, + Name: name, + }, + }, + }, + expectedAdditionalTags: map[string]*string{ + "Environment": &[]string{"Prod"}[0], + "Project": &[]string{"AccessLogTest"}[0], + "Team": &[]string{"Platform"}[0], + }, + description: "should set additional tags from AccessLogPolicy annotations in access log subscription spec", + }, + { + name: "AccessLogPolicy without additional tags annotation", + input: &anv1alpha1.AccessLogPolicy{ + ObjectMeta: apimachineryv1.ObjectMeta{ + Namespace: namespace, + Name: name, + }, + Spec: anv1alpha1.AccessLogPolicySpec{ + DestinationArn: aws.String(s3DestinationArn), + TargetRef: &gwv1alpha2.NamespacedPolicyTargetReference{ + Kind: gatewayKind, + Name: name, + }, + }, + }, + expectedAdditionalTags: nil, + description: "should have nil additional tags when no annotation present in access log subscription spec", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, als, err := modelBuilder.Build(ctx, tt.input) + assert.NoError(t, err, tt.description) + + assert.Equal(t, tt.expectedAdditionalTags, als.Spec.AdditionalTags, tt.description) + }) + } +} diff --git a/pkg/gateway/model_build_lattice_service.go b/pkg/gateway/model_build_lattice_service.go index 3acd90e7..07a821aa 100644 --- a/pkg/gateway/model_build_lattice_service.go +++ b/pkg/gateway/model_build_lattice_service.go @@ -128,6 +128,8 @@ func (t *latticeServiceModelBuildTask) buildLatticeService(ctx context.Context) }, } + spec.AdditionalTags = k8s.GetAdditionalTagsFromAnnotations(ctx, t.route.K8sObject()) + // Check if standalone mode is enabled for this route standalone, err := t.isStandaloneMode(ctx) if err != nil { diff --git a/pkg/gateway/model_build_lattice_service_test.go b/pkg/gateway/model_build_lattice_service_test.go index 8b1e5e91..2545802a 100644 --- a/pkg/gateway/model_build_lattice_service_test.go +++ b/pkg/gateway/model_build_lattice_service_test.go @@ -1249,3 +1249,115 @@ func Test_latticeServiceModelBuildTask_isStandaloneMode(t *testing.T) { }) } } + +func Test_LatticeServiceModelBuild_HTTPRouteWithAndWithoutAdditionalTagsAnnotation(t *testing.T) { + vpcLatticeGatewayClass := gwv1.GatewayClass{ + ObjectMeta: metav1.ObjectMeta{ + Name: "gwClass", + }, + Spec: gwv1.GatewayClassSpec{ + ControllerName: config.LatticeGatewayControllerName, + }, + } + + vpcLatticeGateway := gwv1.Gateway{ + ObjectMeta: metav1.ObjectMeta{ + Name: "gateway1", + Namespace: "default", + }, + Spec: gwv1.GatewaySpec{ + GatewayClassName: gwv1.ObjectName(vpcLatticeGatewayClass.Name), + }, + } + + namespacePtr := func(ns string) *gwv1.Namespace { + p := gwv1.Namespace(ns) + return &p + } + + tests := []struct { + name string + route core.Route + expectedAdditionalTags k8s.Tags + description string + }{ + { + name: "HTTPRoute with additional tags annotation", + route: core.NewHTTPRoute(gwv1.HTTPRoute{ + ObjectMeta: metav1.ObjectMeta{ + Name: "service-with-tags", + Namespace: "default", + Annotations: map[string]string{ + k8s.TagsAnnotationKey: "Environment=Prod,Project=ServiceTest,Team=Platform", + }, + }, + Spec: gwv1.HTTPRouteSpec{ + CommonRouteSpec: gwv1.CommonRouteSpec{ + ParentRefs: []gwv1.ParentReference{ + { + Name: gwv1.ObjectName(vpcLatticeGateway.Name), + Namespace: namespacePtr(vpcLatticeGateway.Namespace), + }, + }, + }, + }, + }), + expectedAdditionalTags: k8s.Tags{ + "Environment": &[]string{"Prod"}[0], + "Project": &[]string{"ServiceTest"}[0], + "Team": &[]string{"Platform"}[0], + }, + description: "should set additional tags from HTTPRoute annotations in service spec", + }, + { + name: "HTTPRoute without additional tags annotation", + route: core.NewHTTPRoute(gwv1.HTTPRoute{ + ObjectMeta: metav1.ObjectMeta{ + Name: "service-no-tags", + Namespace: "default", + }, + Spec: gwv1.HTTPRouteSpec{ + CommonRouteSpec: gwv1.CommonRouteSpec{ + ParentRefs: []gwv1.ParentReference{ + { + Name: gwv1.ObjectName(vpcLatticeGateway.Name), + Namespace: namespacePtr(vpcLatticeGateway.Namespace), + }, + }, + }, + }, + }), + expectedAdditionalTags: nil, + description: "should have nil additional tags when no annotation present in service spec", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.TODO() + + k8sSchema := runtime.NewScheme() + clientgoscheme.AddToScheme(k8sSchema) + gwv1.Install(k8sSchema) + gwv1alpha2.Install(k8sSchema) + k8sClient := testclient.NewClientBuilder().WithScheme(k8sSchema).Build() + + assert.NoError(t, k8sClient.Create(ctx, vpcLatticeGatewayClass.DeepCopy())) + assert.NoError(t, k8sClient.Create(ctx, vpcLatticeGateway.DeepCopy())) + + stack := core.NewDefaultStack(core.StackID(k8s.NamespacedName(tt.route.K8sObject()))) + + task := &latticeServiceModelBuildTask{ + log: gwlog.FallbackLogger, + route: tt.route, + stack: stack, + client: k8sClient, + } + + svc, err := task.buildLatticeService(ctx) + assert.NoError(t, err, tt.description) + + assert.Equal(t, tt.expectedAdditionalTags, svc.Spec.AdditionalTags, tt.description) + }) + } +} diff --git a/pkg/gateway/model_build_listener.go b/pkg/gateway/model_build_listener.go index 16c9cf4a..2867c827 100644 --- a/pkg/gateway/model_build_listener.go +++ b/pkg/gateway/model_build_listener.go @@ -121,6 +121,8 @@ func (t *latticeServiceModelBuildTask) buildListeners(ctx context.Context, stack DefaultAction: defaultAction, } + spec.AdditionalTags = k8s.GetAdditionalTagsFromAnnotations(ctx, t.route.K8sObject()) + modelListener, err := model.NewListener(t.stack, spec) if err != nil { return err diff --git a/pkg/gateway/model_build_listener_test.go b/pkg/gateway/model_build_listener_test.go index 03a994b8..0736f9d1 100644 --- a/pkg/gateway/model_build_listener_test.go +++ b/pkg/gateway/model_build_listener_test.go @@ -562,3 +562,153 @@ func Test_ListenerModelBuild(t *testing.T) { }) } } + +func Test_ListenerModelBuild_HTTPRouteWithAndWithoutAdditionalTagsAnnotation(t *testing.T) { + var sectionName gwv1.SectionName = "my-gw-listener" + var serviceKind gwv1.Kind = "Service" + var backendRef = gwv1.BackendRef{ + BackendObjectReference: gwv1.BackendObjectReference{ + Name: "targetgroup1", + Kind: &serviceKind, + }, + } + + vpcLatticeGatewayClass := gwv1.GatewayClass{ + ObjectMeta: metav1.ObjectMeta{ + Name: "gwClass", + }, + Spec: gwv1.GatewayClassSpec{ + ControllerName: config.LatticeGatewayControllerName, + }, + } + + vpcLatticeGateway := gwv1.Gateway{ + ObjectMeta: metav1.ObjectMeta{ + Name: "gateway1", + Namespace: "default", + }, + Spec: gwv1.GatewaySpec{ + GatewayClassName: gwv1.ObjectName(vpcLatticeGatewayClass.Name), + Listeners: []gwv1.Listener{ + { + Port: 80, + Protocol: "HTTP", + Name: sectionName, + }, + }, + }, + } + + tests := []struct { + name string + route core.Route + expectedAdditionalTags k8s.Tags + description string + }{ + { + name: "HTTPRoute with additional tags annotation", + route: core.NewHTTPRoute(gwv1.HTTPRoute{ + ObjectMeta: metav1.ObjectMeta{ + Name: "route-with-tags", + Namespace: "default", + Annotations: map[string]string{ + k8s.TagsAnnotationKey: "Environment=Prod,Project=ListenerTest,Team=Platform", + }, + }, + Spec: gwv1.HTTPRouteSpec{ + CommonRouteSpec: gwv1.CommonRouteSpec{ + ParentRefs: []gwv1.ParentReference{ + { + Name: gwv1.ObjectName(vpcLatticeGateway.Name), + SectionName: §ionName, + }, + }, + }, + Rules: []gwv1.HTTPRouteRule{ + { + BackendRefs: []gwv1.HTTPBackendRef{ + { + BackendRef: backendRef, + }, + }, + }, + }, + }, + }), + expectedAdditionalTags: k8s.Tags{ + "Environment": &[]string{"Prod"}[0], + "Project": &[]string{"ListenerTest"}[0], + "Team": &[]string{"Platform"}[0], + }, + description: "should set additional tags from HTTPRoute annotations in listener spec", + }, + { + name: "HTTPRoute without additional tags annotation", + route: core.NewHTTPRoute(gwv1.HTTPRoute{ + ObjectMeta: metav1.ObjectMeta{ + Name: "route-no-tags", + Namespace: "default", + }, + Spec: gwv1.HTTPRouteSpec{ + CommonRouteSpec: gwv1.CommonRouteSpec{ + ParentRefs: []gwv1.ParentReference{ + { + Name: gwv1.ObjectName(vpcLatticeGateway.Name), + SectionName: §ionName, + }, + }, + }, + Rules: []gwv1.HTTPRouteRule{ + { + BackendRefs: []gwv1.HTTPBackendRef{ + { + BackendRef: backendRef, + }, + }, + }, + }, + }, + }), + expectedAdditionalTags: nil, + description: "should have nil additional tags when no annotation present in listener spec", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := gomock.NewController(t) + defer c.Finish() + ctx := context.TODO() + + k8sSchema := runtime.NewScheme() + clientgoscheme.AddToScheme(k8sSchema) + gwv1.Install(k8sSchema) + anv1alpha1.Install(k8sSchema) + k8sClient := testclient.NewClientBuilder().WithScheme(k8sSchema).Build() + + assert.NoError(t, k8sClient.Create(ctx, vpcLatticeGatewayClass.DeepCopy())) + assert.NoError(t, k8sClient.Create(ctx, vpcLatticeGateway.DeepCopy())) + + mockBrTgBuilder := NewMockBackendRefTargetGroupModelBuilder(c) + stack := core.NewDefaultStack(core.StackID(k8s.NamespacedName(tt.route.K8sObject()))) + + task := &latticeServiceModelBuildTask{ + log: gwlog.FallbackLogger, + route: tt.route, + client: k8sClient, + stack: stack, + brTgBuilder: mockBrTgBuilder, + } + + err := task.buildListeners(ctx, "svc-id") + assert.NoError(t, err, tt.description) + + var resListener []*model.Listener + stack.ListResources(&resListener) + assert.Equal(t, 1, len(resListener), "Expected exactly one listener") + + actualListener := resListener[0] + assert.Equal(t, tt.expectedAdditionalTags, actualListener.Spec.AdditionalTags, tt.description) + }) + } +} diff --git a/pkg/gateway/model_build_rule.go b/pkg/gateway/model_build_rule.go index 1241d2a2..6d7b9415 100644 --- a/pkg/gateway/model_build_rule.go +++ b/pkg/gateway/model_build_rule.go @@ -10,6 +10,7 @@ import ( "k8s.io/apimachinery/pkg/types" anv1alpha1 "github.com/aws/aws-application-networking-k8s/pkg/apis/applicationnetworking/v1alpha1" + "github.com/aws/aws-application-networking-k8s/pkg/k8s" "github.com/aws/aws-application-networking-k8s/pkg/model/core" "github.com/aws/aws-application-networking-k8s/pkg/utils" @@ -130,6 +131,8 @@ func (t *latticeServiceModelBuildTask) buildRules(ctx context.Context, stackList TargetGroups: ruleTgList, } + ruleSpec.AdditionalTags = k8s.GetAdditionalTagsFromAnnotations(ctx, t.route.K8sObject()) + // don't bother adding rules on delete, these will be removed automatically with the owning route/lattice service // target groups will still be present and removed as needed if t.route.DeletionTimestamp().IsZero() { diff --git a/pkg/gateway/model_build_rule_test.go b/pkg/gateway/model_build_rule_test.go index 3f2d1fad..f6a56f49 100644 --- a/pkg/gateway/model_build_rule_test.go +++ b/pkg/gateway/model_build_rule_test.go @@ -1739,6 +1739,136 @@ func Test_RuleModelBuild(t *testing.T) { } } +func Test_RuleModelBuild_WithAndWithoutAdditionalTagsAnnotation(t *testing.T) { + var httpSectionName gwv1.SectionName = "http" + var serviceKind gwv1.Kind = "Service" + var weight1 = int32(10) + + var backendRef1 = gwv1.BackendRef{ + BackendObjectReference: gwv1.BackendObjectReference{ + Name: "targetgroup1", + Kind: &serviceKind, + }, + Weight: &weight1, + } + + tests := []struct { + name string + route core.Route + expectedAdditionalTags k8s.Tags + description string + }{ + { + name: "HTTPRoute with additional tags annotation", + route: core.NewHTTPRoute(gwv1.HTTPRoute{ + ObjectMeta: apimachineryv1.ObjectMeta{ + Name: "route-with-tags", + Namespace: "default", + Annotations: map[string]string{ + k8s.TagsAnnotationKey: "Environment=Prod,Project=RuleTest,Team=Backend", + }, + }, + Spec: gwv1.HTTPRouteSpec{ + CommonRouteSpec: gwv1.CommonRouteSpec{ + ParentRefs: []gwv1.ParentReference{ + { + Name: "gw1", + SectionName: &httpSectionName, + }, + }, + }, + Rules: []gwv1.HTTPRouteRule{ + { + BackendRefs: []gwv1.HTTPBackendRef{ + { + BackendRef: backendRef1, + }, + }, + }, + }, + }, + }), + expectedAdditionalTags: k8s.Tags{ + "Environment": &[]string{"Prod"}[0], + "Project": &[]string{"RuleTest"}[0], + "Team": &[]string{"Backend"}[0], + }, + description: "should set additional tags from HTTPRoute annotations in rule spec", + }, + { + name: "HTTPRoute without additional tags annotation", + route: core.NewHTTPRoute(gwv1.HTTPRoute{ + ObjectMeta: apimachineryv1.ObjectMeta{ + Name: "route-no-tags", + Namespace: "default", + }, + Spec: gwv1.HTTPRouteSpec{ + CommonRouteSpec: gwv1.CommonRouteSpec{ + ParentRefs: []gwv1.ParentReference{ + { + Name: "gw1", + SectionName: &httpSectionName, + }, + }, + }, + Rules: []gwv1.HTTPRouteRule{ + { + BackendRefs: []gwv1.HTTPBackendRef{ + { + BackendRef: backendRef1, + }, + }, + }, + }, + }, + }), + expectedAdditionalTags: nil, + description: "should have nil additional tags when no annotation present in rule spec", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := gomock.NewController(t) + defer c.Finish() + ctx := context.TODO() + + k8sSchema := runtime.NewScheme() + k8sSchema.AddKnownTypes(anv1alpha1.SchemeGroupVersion, &anv1alpha1.ServiceImport{}) + clientgoscheme.AddToScheme(k8sSchema) + k8sClient := testclient.NewClientBuilder().WithScheme(k8sSchema).Build() + + svc := corev1.Service{ + ObjectMeta: apimachineryv1.ObjectMeta{ + Name: string(backendRef1.Name), + Namespace: "default", + }, + Status: corev1.ServiceStatus{}, + } + assert.NoError(t, k8sClient.Create(ctx, svc.DeepCopy())) + stack := core.NewDefaultStack(core.StackID(k8s.NamespacedName(tt.route.K8sObject()))) + + task := &latticeServiceModelBuildTask{ + log: gwlog.FallbackLogger, + route: tt.route, + stack: stack, + client: k8sClient, + brTgBuilder: &dummyTgBuilder{}, + } + + err := task.buildRules(ctx, "listener-id") + assert.NoError(t, err, tt.description) + + var resRules []*model.Rule + stack.ListResources(&resRules) + assert.Equal(t, 1, len(resRules), "Expected exactly one rule") + + actualRule := resRules[0] + assert.Equal(t, tt.expectedAdditionalTags, actualRule.Spec.AdditionalTags, tt.description) + }) + } +} + func validateEqual(t *testing.T, expectedRules []model.RuleSpec, actualRules []*model.Rule) { assert.Equal(t, len(expectedRules), len(actualRules)) assert.Equal(t, len(expectedRules), len(actualRules)) diff --git a/pkg/gateway/model_build_targetgroup.go b/pkg/gateway/model_build_targetgroup.go index 35a12f4d..e872c347 100644 --- a/pkg/gateway/model_build_targetgroup.go +++ b/pkg/gateway/model_build_targetgroup.go @@ -225,6 +225,8 @@ func (t *svcExportTargetGroupModelBuildTask) buildTargetGroupForExportedPort(ctx spec.K8SServiceNamespace = t.serviceExport.Namespace spec.K8SProtocolVersion = protocolVersion + spec.AdditionalTags = k8s.GetAdditionalTagsFromAnnotations(ctx, t.serviceExport) + stackTG, err := model.NewTargetGroup(t.stack, spec) if err != nil { return nil, err @@ -302,6 +304,8 @@ func (t *svcExportTargetGroupModelBuildTask) buildTargetGroup(ctx context.Contex spec.K8SServiceNamespace = t.serviceExport.Namespace spec.K8SProtocolVersion = protocolVersion + spec.AdditionalTags = k8s.GetAdditionalTagsFromAnnotations(ctx, t.serviceExport) + stackTG, err := model.NewTargetGroup(t.stack, spec) if err != nil { return nil, err @@ -479,6 +483,8 @@ func (t *backendRefTargetGroupModelBuildTask) buildTargetGroupSpec(ctx context.C spec.K8SRouteNamespace = t.route.Namespace() spec.K8SProtocolVersion = protocolVersion + spec.AdditionalTags = k8s.GetAdditionalTagsFromAnnotations(ctx, t.route.K8sObject()) + return spec, nil } diff --git a/pkg/gateway/model_build_targetgroup_test.go b/pkg/gateway/model_build_targetgroup_test.go index 5eefc356..b5db1955 100644 --- a/pkg/gateway/model_build_targetgroup_test.go +++ b/pkg/gateway/model_build_targetgroup_test.go @@ -924,3 +924,353 @@ func Test_buildTargetGroupIpAddressType(t *testing.T) { }) } } + +func Test_TGModelByServiceExportBuild_AdditionalTags(t *testing.T) { + config.VpcID = "vpc-id" + config.ClusterName = "cluster-name" + + tests := []struct { + name string + svcExport *anv1alpha1.ServiceExport + svc *corev1.Service + expectedAdditionalTags k8s.Tags + description string + }{ + { + name: "ServiceExport with additional tags annotation", + svcExport: &anv1alpha1.ServiceExport{ + ObjectMeta: metav1.ObjectMeta{ + Name: "export-with-tags", + Namespace: "ns1", + Annotations: map[string]string{ + k8s.TagsAnnotationKey: "Environment=Dev,Project=MyApp,Team=Platform", + }, + }, + Spec: anv1alpha1.ServiceExportSpec{ + ExportedPorts: []anv1alpha1.ExportedPort{ + { + Port: 80, + RouteType: "HTTP", + }, + }, + }, + }, + svc: &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "export-with-tags", + Namespace: "ns1", + }, + Spec: corev1.ServiceSpec{ + Ports: []corev1.ServicePort{ + {Port: 80}, + }, + IPFamilies: []corev1.IPFamily{ + corev1.IPv4Protocol, + }, + }, + }, + expectedAdditionalTags: k8s.Tags{ + "Environment": &[]string{"Dev"}[0], + "Project": &[]string{"MyApp"}[0], + "Team": &[]string{"Platform"}[0], + }, + description: "should set additional tags from ServiceExport annotations", + }, + { + name: "ServiceExport without additional tags annotation", + svcExport: &anv1alpha1.ServiceExport{ + ObjectMeta: metav1.ObjectMeta{ + Name: "export-no-tags", + Namespace: "ns1", + }, + Spec: anv1alpha1.ServiceExportSpec{ + ExportedPorts: []anv1alpha1.ExportedPort{ + { + Port: 80, + RouteType: "HTTP", + }, + }, + }, + }, + svc: &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "export-no-tags", + Namespace: "ns1", + }, + Spec: corev1.ServiceSpec{ + Ports: []corev1.ServicePort{ + {Port: 80}, + }, + IPFamilies: []corev1.IPFamily{ + corev1.IPv4Protocol, + }, + }, + }, + expectedAdditionalTags: nil, + description: "should have nil additional tags when no annotation present", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.TODO() + + k8sSchema := runtime.NewScheme() + clientgoscheme.AddToScheme(k8sSchema) + anv1alpha1.Install(k8sSchema) + k8sClient := testclient.NewClientBuilder().WithScheme(k8sSchema).Build() + + assert.NoError(t, k8sClient.Create(ctx, tt.svc.DeepCopy())) + + builder := NewSvcExportTargetGroupBuilder(gwlog.FallbackLogger, k8sClient) + + stack, err := builder.Build(ctx, tt.svcExport) + assert.Nil(t, err, tt.description) + + var resTargetGroups []*model.TargetGroup + err = stack.ListResources(&resTargetGroups) + assert.Nil(t, err) + assert.Equal(t, 1, len(resTargetGroups)) + + stackTg := resTargetGroups[0] + assert.Equal(t, tt.expectedAdditionalTags, stackTg.Spec.AdditionalTags, tt.description) + }) + } +} + +func Test_TGModelByServiceExportBuildLegacy_AdditionalTags(t *testing.T) { + config.VpcID = "vpc-id" + config.ClusterName = "cluster-name" + + tests := []struct { + name string + svcExport *anv1alpha1.ServiceExport + svc *corev1.Service + expectedAdditionalTags k8s.Tags + description string + }{ + { + name: "ServiceExport with additional tags annotation", + svcExport: &anv1alpha1.ServiceExport{ + ObjectMeta: metav1.ObjectMeta{ + Name: "legacy-with-tags", + Namespace: "ns1", + Annotations: map[string]string{ + "application-networking.k8s.aws/port": "80", + k8s.TagsAnnotationKey: "Environment=Legacy,Project=TestApp", + }, + }, + }, + svc: &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "legacy-with-tags", + Namespace: "ns1", + }, + Spec: corev1.ServiceSpec{ + Ports: []corev1.ServicePort{ + {Port: 80}, + }, + IPFamilies: []corev1.IPFamily{ + corev1.IPv4Protocol, + }, + }, + }, + expectedAdditionalTags: k8s.Tags{ + "Environment": &[]string{"Legacy"}[0], + "Project": &[]string{"TestApp"}[0], + }, + description: "should set additional tags from ServiceExport annotations in legacy mode", + }, + { + name: "ServiceExport without additional tags annotation", + svcExport: &anv1alpha1.ServiceExport{ + ObjectMeta: metav1.ObjectMeta{ + Name: "legacy-no-tags", + Namespace: "ns1", + Annotations: map[string]string{ + "application-networking.k8s.aws/port": "80", + }, + }, + }, + svc: &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "legacy-no-tags", + Namespace: "ns1", + }, + Spec: corev1.ServiceSpec{ + Ports: []corev1.ServicePort{ + {Port: 80}, + }, + IPFamilies: []corev1.IPFamily{ + corev1.IPv4Protocol, + }, + }, + }, + expectedAdditionalTags: nil, + description: "should have nil additional tags in legacy mode when no tags annotation", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.TODO() + + k8sSchema := runtime.NewScheme() + clientgoscheme.AddToScheme(k8sSchema) + anv1alpha1.Install(k8sSchema) + k8sClient := testclient.NewClientBuilder().WithScheme(k8sSchema).Build() + + assert.NoError(t, k8sClient.Create(ctx, tt.svc.DeepCopy())) + + builder := NewSvcExportTargetGroupBuilder(gwlog.FallbackLogger, k8sClient) + + stack, err := builder.Build(ctx, tt.svcExport) + assert.Nil(t, err, tt.description) + + var resTargetGroups []*model.TargetGroup + err = stack.ListResources(&resTargetGroups) + assert.Nil(t, err) + assert.Equal(t, 1, len(resTargetGroups)) + + stackTg := resTargetGroups[0] + assert.Equal(t, tt.expectedAdditionalTags, stackTg.Spec.AdditionalTags, tt.description) + }) + } +} + +func Test_TGModelByHTTPRouteBuild_AdditionalTags(t *testing.T) { + config.VpcID = "vpc-id" + config.ClusterName = "cluster-name" + + namespacePtr := func(ns string) *gwv1.Namespace { + p := gwv1.Namespace(ns) + return &p + } + + kindPtr := func(k string) *gwv1.Kind { + p := gwv1.Kind(k) + return &p + } + + tests := []struct { + name string + route core.Route + expectedAdditionalTags k8s.Tags + description string + }{ + { + name: "HTTPRoute with additional tags annotation", + route: core.NewHTTPRoute(gwv1.HTTPRoute{ + ObjectMeta: metav1.ObjectMeta{ + Name: "route-with-tags", + Namespace: "ns1", + Annotations: map[string]string{ + k8s.TagsAnnotationKey: "Environment=Prod,Project=TestApp,Team=Backend", + }, + }, + Spec: gwv1.HTTPRouteSpec{ + CommonRouteSpec: gwv1.CommonRouteSpec{ + ParentRefs: []gwv1.ParentReference{ + { + Name: "gateway1", + Namespace: namespacePtr("ns1"), + }, + }, + }, + Rules: []gwv1.HTTPRouteRule{ + { + BackendRefs: []gwv1.HTTPBackendRef{ + { + BackendRef: gwv1.BackendRef{ + BackendObjectReference: gwv1.BackendObjectReference{ + Name: "service1", + Namespace: namespacePtr("ns1"), + Kind: kindPtr("Service"), + }, + }, + }, + }, + }, + }, + }, + }), + expectedAdditionalTags: k8s.Tags{ + "Environment": &[]string{"Prod"}[0], + "Project": &[]string{"TestApp"}[0], + "Team": &[]string{"Backend"}[0], + }, + description: "should set additional tags from HTTPRoute annotations", + }, + { + name: "HTTPRoute without additional tags annotation", + route: core.NewHTTPRoute(gwv1.HTTPRoute{ + ObjectMeta: metav1.ObjectMeta{ + Name: "route-no-tags", + Namespace: "ns1", + }, + Spec: gwv1.HTTPRouteSpec{ + CommonRouteSpec: gwv1.CommonRouteSpec{ + ParentRefs: []gwv1.ParentReference{ + { + Name: "gateway1", + Namespace: namespacePtr("ns1"), + }, + }, + }, + Rules: []gwv1.HTTPRouteRule{ + { + BackendRefs: []gwv1.HTTPBackendRef{ + { + BackendRef: gwv1.BackendRef{ + BackendObjectReference: gwv1.BackendObjectReference{ + Name: "service2", + Namespace: namespacePtr("ns1"), + Kind: kindPtr("Service"), + }, + }, + }, + }, + }, + }, + }, + }), + expectedAdditionalTags: nil, + description: "should have nil additional tags when no annotation present", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.TODO() + + k8sSchema := runtime.NewScheme() + clientgoscheme.AddToScheme(k8sSchema) + anv1alpha1.Install(k8sSchema) + gwv1.Install(k8sSchema) + k8sClient := testclient.NewClientBuilder().WithScheme(k8sSchema).Build() + + stack := core.NewDefaultStack(core.StackID(k8s.NamespacedName(tt.route.K8sObject()))) + + rule := tt.route.Spec().Rules()[0] + httpBackendRef := rule.BackendRefs()[0] + + svc := corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: string(httpBackendRef.Name()), + Namespace: string(*httpBackendRef.Namespace()), + }, + Spec: corev1.ServiceSpec{ + IPFamilies: []corev1.IPFamily{corev1.IPv4Protocol}, + }, + } + assert.NoError(t, k8sClient.Create(ctx, svc.DeepCopy())) + + builder := NewBackendRefTargetGroupBuilder(gwlog.FallbackLogger, k8sClient) + + _, stackTg, err := builder.Build(ctx, tt.route, httpBackendRef, stack) + assert.Nil(t, err, tt.description) + + assert.Equal(t, tt.expectedAdditionalTags, stackTg.Spec.AdditionalTags, tt.description) + }) + } +} diff --git a/pkg/k8s/utils.go b/pkg/k8s/utils.go index 5fd1b5f4..5ba0a93e 100644 --- a/pkg/k8s/utils.go +++ b/pkg/k8s/utils.go @@ -3,10 +3,12 @@ package k8s import ( "context" "fmt" + "regexp" "strings" "github.com/aws/aws-application-networking-k8s/pkg/config" "github.com/aws/aws-application-networking-k8s/pkg/model/core" + "github.com/aws/aws-sdk-go/aws" apierrors "k8s.io/apimachinery/pkg/api/errors" "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/apimachinery/pkg/types" @@ -23,8 +25,25 @@ const ( // Standalone annotation controls whether VPC Lattice services are created // without automatic service network association StandaloneAnnotation = AnnotationPrefix + "standalone" + + // AWS reserved prefix that cannot be used in tag keys + AWSReservedPrefix = "aws:" + + // Additional tags + TagsAnnotationKey = AnnotationPrefix + "tags" + + // AWS tag validation limits + maxTagKeyLength = 128 + maxTagValueLength = 256 + maxTagCount = 50 ) +var ( + tagPattern = regexp.MustCompile(`^([\p{L}\p{Z}\p{N}_.:\/=+\-@]*)$`) +) + +type Tags = map[string]*string + // NamespacedName returns the namespaced name for k8s objects func NamespacedName(obj client.Object) types.NamespacedName { return types.NamespacedName{ @@ -313,3 +332,94 @@ func GetStandaloneModeForRoute(ctx context.Context, c client.Client, route core. return false, nil } + +func GetAdditionalTagsFromAnnotations(ctx context.Context, obj client.Object) Tags { + if obj == nil || obj.GetAnnotations() == nil { + return nil + } + + annotations := obj.GetAnnotations() + tagValue, exists := annotations[TagsAnnotationKey] + if !exists || tagValue == "" { + return nil + } + + additionalTags := ParseTagsFromAnnotation(tagValue) + filteredTags := GetNonAWSManagedTags(additionalTags) + + if len(filteredTags) == 0 { + return nil + } + return filteredTags +} + +func ParseTagsFromAnnotation(annotationValue string) Tags { + tags := make(Tags) + if annotationValue == "" { + return tags + } + + for pair := range strings.SplitSeq(annotationValue, ",") { + pair = strings.TrimSpace(pair) + if pair == "" { + continue + } + + parts := strings.SplitN(pair, "=", 2) + if len(parts) != 2 { + continue + } + + key := strings.TrimSpace(parts[0]) + value := strings.TrimSpace(parts[1]) + + if key == "" || value == "" || + len(key) > maxTagKeyLength || + len(value) > maxTagValueLength || + !tagPattern.MatchString(key) || + !tagPattern.MatchString(value) { + continue + } + + if _, exists := tags[key]; exists { + continue + } + + if len(tags) >= maxTagCount { + break + } + + tags[key] = aws.String(value) + } + return tags +} + +func CalculateTagDifference(currentTags Tags, desiredTags Tags) (tagsToAdd Tags, tagsToRemove []*string) { + tagsToAdd = make(Tags) + tagsToRemove = make([]*string, 0) + + for key := range currentTags { + if _, exists := desiredTags[key]; !exists { + tagsToRemove = append(tagsToRemove, aws.String(key)) + } + } + + for key, value := range desiredTags { + if currentValue, exists := currentTags[key]; !exists || *currentValue != *value { + tagsToAdd[key] = value + } + } + + return tagsToAdd, tagsToRemove +} + +func GetNonAWSManagedTags(tags Tags) Tags { + nonAWSManagedTags := make(Tags) + for key, value := range tags { + if strings.HasPrefix(key, AnnotationPrefix) || strings.HasPrefix(strings.ToLower(key), AWSReservedPrefix) { + continue + } + nonAWSManagedTags[key] = value + } + return nonAWSManagedTags +} diff --git a/pkg/k8s/utils_test.go b/pkg/k8s/utils_test.go index 482097c3..e5d13497 100644 --- a/pkg/k8s/utils_test.go +++ b/pkg/k8s/utils_test.go @@ -2,9 +2,11 @@ package k8s import ( "context" + "strings" "testing" "github.com/aws/aws-application-networking-k8s/pkg/model/core" + "github.com/aws/aws-sdk-go/aws" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -963,3 +965,493 @@ func TestGetStandaloneModeForRouteWithValidation(t *testing.T) { }) } } + +func TestParseTagsFromAnnotation(t *testing.T) { + tests := []struct { + name string + annotation string + expected Tags + description string + }{ + { + name: "empty annotation", + annotation: "", + expected: Tags{}, + description: "should return empty map for empty annotation", + }, + { + name: "multiple tags", + annotation: "Environment=Dev,Project=MyApp,Team=Platform", + expected: Tags{"Environment": aws.String("Dev"), "Project": aws.String("MyApp"), "Team": aws.String("Platform")}, + description: "should parse multiple tags correctly", + }, + { + name: "tags with spaces", + annotation: "Environment = Dev , Project = MyApp", + expected: Tags{"Environment": aws.String("Dev"), "Project": aws.String("MyApp")}, + description: "should handle spaces around keys and values", + }, + { + name: "invalid tag format", + annotation: "Environment,Project=MyApp", + expected: Tags{"Project": aws.String("MyApp")}, + description: "should skip invalid tag format and parse valid ones", + }, + { + name: "empty key or value", + annotation: "=Dev,Project=,Team=Platform", + expected: Tags{"Team": aws.String("Platform")}, + description: "should skip tags with empty keys or values", + }, + { + name: "trailing comma", + annotation: "Environment=Dev,Project=MyApp,", + expected: Tags{"Environment": aws.String("Dev"), "Project": aws.String("MyApp")}, + description: "should handle trailing comma gracefully", + }, + { + name: "whitespace only pairs", + annotation: "Environment=Dev, ,Project=MyApp", + expected: Tags{"Environment": aws.String("Dev"), "Project": aws.String("MyApp")}, + description: "should skip whitespace-only pairs", + }, + { + name: "tag key too long", + annotation: strings.Repeat("a", 129) + "=value,Project=MyApp", + expected: Tags{"Project": aws.String("MyApp")}, + description: "should skip tags with keys longer than 128 characters", + }, + { + name: "tag value too long", + annotation: "Environment=Dev,key=" + strings.Repeat("v", 257), + expected: Tags{"Environment": aws.String("Dev")}, + description: "should skip tags with values longer than 256 characters", + }, + { + name: "duplicate keys", + annotation: "Environment=Dev,Project=App1,Environment=Prod,Team=Platform", + expected: Tags{"Environment": aws.String("Dev"), "Project": aws.String("App1"), "Team": aws.String("Platform")}, + description: "should keep first occurrence of duplicate keys", + }, + { + name: "invalid characters in key", + annotation: "Env#ironment=Dev,Project=MyApp", + expected: Tags{"Project": aws.String("MyApp")}, + description: "should skip tags with invalid characters in key", + }, + { + name: "invalid characters in value", + annotation: "Environment=Dev,Project=My$App", + expected: Tags{"Environment": aws.String("Dev")}, + description: "should skip tags with invalid characters in value", + }, + { + name: "more than 50 tags should keep 50 tags", + annotation: func() string { + var pairs []string + for i := 1; i <= 60; i++ { + pairs = append(pairs, "key"+string(rune(i/10+48))+string(rune(i%10+48))+"=value"+string(rune(i/10+48))+string(rune(i%10+48))) + } + return strings.Join(pairs, ",") + }(), + expected: func() Tags { + tags := make(Tags) + for i := 1; i <= 50; i++ { + key := "key" + string(rune(i/10+48)) + string(rune(i%10+48)) + value := "value" + string(rune(i/10+48)) + string(rune(i%10+48)) + tags[key] = aws.String(value) + } + return tags + }(), + description: "should limit to 50 tags maximum", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ParseTagsFromAnnotation(tt.annotation) + assert.Equal(t, tt.expected, result, tt.description) + }) + } +} + +func TestGetNonAWSManagedTags(t *testing.T) { + tests := []struct { + name string + tags Tags + expected Tags + description string + }{ + { + name: "nil tags", + tags: nil, + expected: Tags{}, + description: "should return empty map for nil input", + }, + { + name: "empty tags", + tags: Tags{}, + expected: Tags{}, + description: "should return empty map for empty input", + }, + { + name: "only additional tags", + tags: Tags{ + "Environment": aws.String("Dev"), + "Project": aws.String("MyApp"), + }, + expected: Tags{ + "Environment": aws.String("Dev"), + "Project": aws.String("MyApp"), + }, + description: "should return all tags when no AWS managed tags present", + }, + { + name: "only AWS managed tags", + tags: Tags{ + "application-networking.k8s.aws/ManagedBy": aws.String("123456789/cluster/vpc-123"), + "application-networking.k8s.aws/RouteType": aws.String("http"), + "application-networking.k8s.aws/RouteName": aws.String("test-route"), + }, + expected: Tags{}, + description: "should return empty map when only AWS managed tags present", + }, + { + name: "mixed tags", + tags: Tags{ + "Environment": aws.String("Dev"), + "Project": aws.String("MyApp"), + "application-networking.k8s.aws/ManagedBy": aws.String("123456789/cluster/vpc-123"), + "application-networking.k8s.aws/RouteType": aws.String("http"), + "Team": aws.String("Platform"), + }, + expected: Tags{ + "Environment": aws.String("Dev"), + "Project": aws.String("MyApp"), + "Team": aws.String("Platform"), + }, + description: "should filter out AWS managed tags and keep additional tags", + }, + { + name: "AWS reserved tags lowercase", + tags: Tags{ + "aws:cloudformation:stack-name": aws.String("my-stack"), + "aws:region": aws.String("us-west-2"), + "Environment": aws.String("Dev"), + }, + expected: Tags{ + "Environment": aws.String("Dev"), + }, + description: "should filter out lowercase aws: prefixed tags", + }, + { + name: "AWS reserved tags uppercase", + tags: Tags{ + "AWS:CloudFormation:StackName": aws.String("my-stack"), + "AWS:Region": aws.String("us-west-2"), + "Environment": aws.String("Dev"), + }, + expected: Tags{ + "Environment": aws.String("Dev"), + }, + description: "should filter out uppercase AWS: prefixed tags", + }, + { + name: "AWS reserved tags mixed case", + tags: Tags{ + "Aws:Service": aws.String("ec2"), + "aWs:Resource": aws.String("instance"), + "Environment": aws.String("Dev"), + }, + expected: Tags{ + "Environment": aws.String("Dev"), + }, + description: "should filter out mixed case aws: prefixed tags", + }, + { + name: "tags that start with aws but not aws:", + tags: Tags{ + "awesome": aws.String("value"), + "aws-region": aws.String("us-west-2"), + "Environment": aws.String("Dev"), + }, + expected: Tags{ + "awesome": aws.String("value"), + "aws-region": aws.String("us-west-2"), + "Environment": aws.String("Dev"), + }, + description: "should keep tags that start with 'aws' but not 'aws:'", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := GetNonAWSManagedTags(tt.tags) + assert.Equal(t, tt.expected, result, tt.description) + }) + } +} + +func TestCalculateTagDifference(t *testing.T) { + tests := []struct { + name string + currentTags Tags + desiredTags Tags + expectedToAdd Tags + expectedToRemove []string + description string + }{ + { + name: "both nil", + currentTags: nil, + desiredTags: nil, + expectedToAdd: Tags{}, + expectedToRemove: []string{}, + description: "should handle nil inputs gracefully", + }, + { + name: "current nil, desired has tags", + currentTags: nil, + desiredTags: Tags{ + "Environment": aws.String("Dev"), + "Project": aws.String("MyApp"), + }, + expectedToAdd: Tags{ + "Environment": aws.String("Dev"), + "Project": aws.String("MyApp"), + }, + expectedToRemove: []string{}, + description: "should add all desired tags when current is nil", + }, + { + name: "current has tags, desired nil", + currentTags: Tags{ + "Environment": aws.String("Dev"), + "Project": aws.String("MyApp"), + }, + desiredTags: nil, + expectedToAdd: Tags{}, + expectedToRemove: []string{"Environment", "Project"}, + description: "should remove all current tags when desired is nil", + }, + { + name: "no changes needed", + currentTags: Tags{ + "Environment": aws.String("Dev"), + "Project": aws.String("MyApp"), + }, + desiredTags: Tags{ + "Environment": aws.String("Dev"), + "Project": aws.String("MyApp"), + }, + expectedToAdd: Tags{}, + expectedToRemove: []string{}, + description: "should return empty when tags are identical", + }, + { + name: "add new tags", + currentTags: Tags{ + "Environment": aws.String("Dev"), + }, + desiredTags: Tags{ + "Environment": aws.String("Dev"), + "Project": aws.String("MyApp"), + "Team": aws.String("Platform"), + }, + expectedToAdd: Tags{ + "Project": aws.String("MyApp"), + "Team": aws.String("Platform"), + }, + expectedToRemove: []string{}, + description: "should add new tags", + }, + { + name: "remove tags", + currentTags: Tags{ + "Environment": aws.String("Dev"), + "Project": aws.String("MyApp"), + "Team": aws.String("Platform"), + }, + desiredTags: Tags{ + "Environment": aws.String("Dev"), + }, + expectedToAdd: Tags{}, + expectedToRemove: []string{"Project", "Team"}, + description: "should remove unwanted tags", + }, + { + name: "update tag values", + currentTags: Tags{ + "Environment": aws.String("Dev"), + "Project": aws.String("OldApp"), + }, + desiredTags: Tags{ + "Environment": aws.String("Prod"), + "Project": aws.String("NewApp"), + }, + expectedToAdd: Tags{ + "Environment": aws.String("Prod"), + "Project": aws.String("NewApp"), + }, + expectedToRemove: []string{}, + description: "should update changed tag values", + }, + { + name: "mixed operations", + currentTags: Tags{ + "Environment": aws.String("Dev"), + "Project": aws.String("OldApp"), + "OldTag": aws.String("OldValue"), + }, + desiredTags: Tags{ + "Environment": aws.String("Prod"), + "Project": aws.String("OldApp"), + "NewTag": aws.String("NewValue"), + }, + expectedToAdd: Tags{ + "Environment": aws.String("Prod"), + "NewTag": aws.String("NewValue"), + }, + expectedToRemove: []string{"OldTag"}, + description: "should handle add, update, and remove operations", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tagsToAdd, tagsToRemove := CalculateTagDifference(tt.currentTags, tt.desiredTags) + + assert.Equal(t, tt.expectedToAdd, tagsToAdd, tt.description) + + removeStrings := make([]string, len(tagsToRemove)) + for i, tag := range tagsToRemove { + if tag != nil { + removeStrings[i] = *tag + } + } + assert.ElementsMatch(t, tt.expectedToRemove, removeStrings, tt.description) + }) + } +} + +func TestGetAdditionalTagsFromAnnotations(t *testing.T) { + tests := []struct { + name string + obj client.Object + expected Tags + description string + }{ + { + name: "nil object", + obj: nil, + expected: nil, + description: "should return nil for nil object", + }, + { + name: "object with no annotations", + obj: &gwv1.HTTPRoute{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-route", + Namespace: "default", + }, + }, + expected: nil, + description: "should return nil when annotations map is nil", + }, + { + name: "object with empty annotations", + obj: &gwv1.HTTPRoute{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-route", + Namespace: "default", + Annotations: map[string]string{}, + }, + }, + expected: nil, + description: "should return nil when annotations map is empty", + }, + { + name: "object without tags annotation", + obj: &gwv1.HTTPRoute{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-route", + Namespace: "default", + Annotations: map[string]string{ + "other-annotation": "value", + }, + }, + }, + expected: nil, + description: "should return nil when tags annotation is not present", + }, + { + name: "object with empty tags annotation", + obj: &gwv1.HTTPRoute{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-route", + Namespace: "default", + Annotations: map[string]string{ + TagsAnnotationKey: "", + }, + }, + }, + expected: nil, + description: "should return nil when tags annotation is empty", + }, + { + name: "object with valid tags annotation", + obj: &gwv1.HTTPRoute{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-route", + Namespace: "default", + Annotations: map[string]string{ + TagsAnnotationKey: "Environment=Dev,Project=MyApp", + }, + }, + }, + expected: Tags{ + "Environment": aws.String("Dev"), + "Project": aws.String("MyApp"), + }, + description: "should parse and return valid tags", + }, + { + name: "object with tags containing AWS managed tags", + obj: &gwv1.HTTPRoute{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-route", + Namespace: "default", + Annotations: map[string]string{ + TagsAnnotationKey: "Environment=Dev,application-networking.k8s.aws/ManagedBy=test-override,Project=MyApp", + }, + }, + }, + expected: Tags{ + "Environment": aws.String("Dev"), + "Project": aws.String("MyApp"), + }, + description: "should filter out AWS managed tags and return only additional tags", + }, + { + name: "object with only AWS managed tags", + obj: &gwv1.HTTPRoute{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-route", + Namespace: "default", + Annotations: map[string]string{ + TagsAnnotationKey: "application-networking.k8s.aws/ManagedBy=test-override,application-networking.k8s.aws/RouteType=http", + }, + }, + }, + expected: nil, + description: "should return nil when only AWS managed tags are present", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := GetAdditionalTagsFromAnnotations(context.Background(), tt.obj) + assert.Equal(t, tt.expected, result, tt.description) + }) + } +} diff --git a/pkg/model/lattice/accesslogsubscription.go b/pkg/model/lattice/accesslogsubscription.go index 8c2c32f1..8c34d2ba 100644 --- a/pkg/model/lattice/accesslogsubscription.go +++ b/pkg/model/lattice/accesslogsubscription.go @@ -6,6 +6,7 @@ import ( "k8s.io/apimachinery/pkg/types" "github.com/aws/aws-application-networking-k8s/pkg/aws" + "github.com/aws/aws-application-networking-k8s/pkg/aws/services" "github.com/aws/aws-application-networking-k8s/pkg/model/core" ) @@ -30,6 +31,7 @@ type AccessLogSubscriptionSpec struct { DestinationArn string ALPNamespacedName types.NamespacedName EventType core.EventType + AdditionalTags services.Tags `json:"additionaltags,omitempty"` } type AccessLogSubscriptionStatus struct { diff --git a/pkg/model/lattice/listener.go b/pkg/model/lattice/listener.go index f8b23d3c..2d5163fe 100644 --- a/pkg/model/lattice/listener.go +++ b/pkg/model/lattice/listener.go @@ -5,6 +5,7 @@ import ( "github.com/aws/aws-sdk-go/service/vpclattice" + "github.com/aws/aws-application-networking-k8s/pkg/aws/services" "github.com/aws/aws-application-networking-k8s/pkg/model/core" ) @@ -27,6 +28,7 @@ type ListenerSpec struct { Port int64 `json:"port"` Protocol string `json:"protocol"` DefaultAction *DefaultAction `json:"defaultaction"` + AdditionalTags services.Tags `json:"additionaltags,omitempty"` } type DefaultAction struct { diff --git a/pkg/model/lattice/rule.go b/pkg/model/lattice/rule.go index e37aad37..00aaf9f4 100644 --- a/pkg/model/lattice/rule.go +++ b/pkg/model/lattice/rule.go @@ -5,6 +5,7 @@ import ( "github.com/aws/aws-sdk-go/service/vpclattice" + "github.com/aws/aws-application-networking-k8s/pkg/aws/services" "github.com/aws/aws-application-networking-k8s/pkg/model/core" ) @@ -30,6 +31,7 @@ type RuleSpec struct { Priority int64 `json:"priority"` Action RuleAction `json:"action"` CreateTime time.Time `json:"createtime"` + AdditionalTags services.Tags `json:"additionaltags,omitempty"` } type RuleAction struct { diff --git a/pkg/model/lattice/service.go b/pkg/model/lattice/service.go index 75a47f6a..961b0658 100644 --- a/pkg/model/lattice/service.go +++ b/pkg/model/lattice/service.go @@ -15,9 +15,10 @@ type Service struct { type ServiceSpec struct { ServiceTagFields - ServiceNetworkNames []string `json:"servicenetworkhnames"` - CustomerDomainName string `json:"customerdomainname"` - CustomerCertARN string `json:"customercertarn"` + ServiceNetworkNames []string `json:"servicenetworkhnames"` + CustomerDomainName string `json:"customerdomainname"` + CustomerCertARN string `json:"customercertarn"` + AdditionalTags services.Tags `json:"additionaltags,omitempty"` } type ServiceStatus struct { diff --git a/pkg/model/lattice/targetgroup.go b/pkg/model/lattice/targetgroup.go index 9dcfa161..ed025f47 100644 --- a/pkg/model/lattice/targetgroup.go +++ b/pkg/model/lattice/targetgroup.go @@ -7,6 +7,7 @@ import ( "github.com/aws/aws-sdk-go/service/vpclattice" "github.com/aws/aws-application-networking-k8s/pkg/aws" + "github.com/aws/aws-application-networking-k8s/pkg/aws/services" "github.com/aws/aws-application-networking-k8s/pkg/model/core" "github.com/aws/aws-application-networking-k8s/pkg/utils" ) @@ -44,6 +45,7 @@ type TargetGroupSpec struct { IpAddressType string `json:"ipaddresstype"` HealthCheckConfig *vpclattice.HealthCheckConfig `json:"healthcheckconfig"` TargetGroupTagFields + AdditionalTags services.Tags `json:"additionaltags,omitempty"` } type TargetGroupTagFields struct { K8SClusterName string `json:"k8sclustername"` diff --git a/test/pkg/test/framework.go b/test/pkg/test/framework.go index 69b3af3d..a940e6d2 100644 --- a/test/pkg/test/framework.go +++ b/test/pkg/test/framework.go @@ -141,7 +141,7 @@ func NewFramework(ctx context.Context, log gwlog.Logger, testNamespace string) * namespace: testNamespace, controllerRuntimeConfig: controllerRuntimeConfig, } - framework.Cloud = anaws.NewDefaultCloud(framework.LatticeClient, cloudConfig) + framework.Cloud = anaws.NewDefaultCloudWithTagging(framework.LatticeClient, framework.TaggingClient, cloudConfig) framework.DefaultTags = framework.Cloud.DefaultTags() SetDefaultEventuallyTimeout(5 * time.Minute) SetDefaultEventuallyPollingInterval(10 * time.Second) diff --git a/test/suites/integration/access_log_policy_test.go b/test/suites/integration/access_log_policy_test.go index 3a6ccf98..7f9da1be 100644 --- a/test/suites/integration/access_log_policy_test.go +++ b/test/suites/integration/access_log_policy_test.go @@ -29,6 +29,7 @@ import ( gwv1alpha2 "sigs.k8s.io/gateway-api/apis/v1alpha2" anv1alpha1 "github.com/aws/aws-application-networking-k8s/pkg/apis/applicationnetworking/v1alpha1" + pkg_aws "github.com/aws/aws-application-networking-k8s/pkg/aws" "github.com/aws/aws-application-networking-k8s/pkg/aws/services" "github.com/aws/aws-application-networking-k8s/pkg/config" "github.com/aws/aws-application-networking-k8s/pkg/model/core" @@ -1275,6 +1276,91 @@ var _ = Describe("Access Log Policy", Ordered, func() { }).Should(Succeed()) }) + It("AccessLogPolicy with additional tags creates Access Log Subscription with additional tags", func() { + accessLogPolicy := &anv1alpha1.AccessLogPolicy{ + ObjectMeta: metav1.ObjectMeta{ + Name: k8sResourceName + "-additional-tags", + Namespace: k8snamespace, + Annotations: map[string]string{ + "application-networking.k8s.aws/tags": "Environment=Dev,Project=MyApp,Team=Platform,CostCenter=12345", + }, + }, + Spec: anv1alpha1.AccessLogPolicySpec{ + DestinationArn: aws.String(bucketArn), + TargetRef: &gwv1alpha2.NamespacedPolicyTargetReference{ + Group: gwv1.GroupName, + Kind: "Gateway", + Name: gwv1alpha2.ObjectName(testGateway.Name), + Namespace: (*gwv1alpha2.Namespace)(aws.String(k8snamespace)), + }, + }, + } + testFramework.ExpectCreated(ctx, accessLogPolicy) + + Eventually(func(g Gomega) { + listALSOutput, err := testFramework.LatticeClient.ListAccessLogSubscriptions(&vpclattice.ListAccessLogSubscriptionsInput{ + ResourceIdentifier: testServiceNetwork.Arn, + }) + g.Expect(err).To(BeNil()) + g.Expect(len(listALSOutput.Items)).To(BeNumerically(">", 0)) + + var targetALS *vpclattice.AccessLogSubscriptionSummary + for _, als := range listALSOutput.Items { + if *als.DestinationArn == bucketArn { + targetALS = als + break + } + } + g.Expect(targetALS).ToNot(BeNil()) + + alsTags, err := testFramework.Cloud.Tagging().GetTagsForArns(ctx, []string{*targetALS.Arn}) + g.Expect(err).To(BeNil()) + alsTagsMap := alsTags[*targetALS.Arn] + + g.Expect(alsTagsMap).To(HaveKeyWithValue("Environment", aws.String("Dev"))) + g.Expect(alsTagsMap).To(HaveKeyWithValue("Project", aws.String("MyApp"))) + g.Expect(alsTagsMap).To(HaveKeyWithValue("Team", aws.String("Platform"))) + g.Expect(alsTagsMap).To(HaveKeyWithValue("CostCenter", aws.String("12345"))) + }).Should(Succeed()) + + err := testFramework.Get(ctx, client.ObjectKeyFromObject(accessLogPolicy), accessLogPolicy) + Expect(err).To(BeNil()) + + accessLogPolicy.Annotations["application-networking.k8s.aws/tags"] = "Environment=Prod,Project=MyApp,Team=Platform,application-networking.k8s.aws/ManagedBy=test-override" + + testFramework.ExpectUpdated(ctx, accessLogPolicy) + + Eventually(func(g Gomega) { + listALSOutput, err := testFramework.LatticeClient.ListAccessLogSubscriptions(&vpclattice.ListAccessLogSubscriptionsInput{ + ResourceIdentifier: testServiceNetwork.Arn, + }) + g.Expect(err).To(BeNil()) + g.Expect(len(listALSOutput.Items)).To(BeNumerically(">", 0)) + + // Find the access log subscription for our policy + var targetALS *vpclattice.AccessLogSubscriptionSummary + for _, als := range listALSOutput.Items { + if *als.DestinationArn == bucketArn { + targetALS = als + break + } + } + g.Expect(targetALS).ToNot(BeNil()) + + alsTags, err := testFramework.Cloud.Tagging().GetTagsForArns(ctx, []string{*targetALS.Arn}) + g.Expect(err).To(BeNil()) + alsTagsMap := alsTags[*targetALS.Arn] + + g.Expect(alsTagsMap).To(HaveKeyWithValue("Environment", aws.String("Prod"))) + g.Expect(alsTagsMap).To(HaveKeyWithValue("Project", aws.String("MyApp"))) + g.Expect(alsTagsMap).To(HaveKeyWithValue("Team", aws.String("Platform"))) + g.Expect(alsTagsMap).ToNot(HaveKey("CostCenter")) + + g.Expect(alsTagsMap).ToNot(HaveKeyWithValue(pkg_aws.TagManagedBy, aws.String("test-override"))) + g.Expect(alsTagsMap).To(HaveKeyWithValue(pkg_aws.TagManagedBy, aws.String(fmt.Sprintf("%s/%s/%s", testFramework.Cloud.Config().AccountId, testFramework.Cloud.Config().ClusterName, testFramework.Cloud.Config().VpcId)))) + }).Should(Succeed()) + }) + AfterEach(func() { // Delete Access Log Policies in test namespace alps := &anv1alpha1.AccessLogPolicyList{} diff --git a/test/suites/integration/additional_tags_test.go b/test/suites/integration/additional_tags_test.go new file mode 100644 index 00000000..92312a24 --- /dev/null +++ b/test/suites/integration/additional_tags_test.go @@ -0,0 +1,351 @@ +package integration + +import ( + "fmt" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/vpclattice" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + + anv1alpha1 "github.com/aws/aws-application-networking-k8s/pkg/apis/applicationnetworking/v1alpha1" + pkg_aws "github.com/aws/aws-application-networking-k8s/pkg/aws" + "github.com/aws/aws-application-networking-k8s/pkg/model/core" + "github.com/aws/aws-application-networking-k8s/test/pkg/test" + appsv1 "k8s.io/api/apps/v1" + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "sigs.k8s.io/controller-runtime/pkg/client" + gwv1 "sigs.k8s.io/gateway-api/apis/v1" +) + +var _ = Describe("Additional Tags test", Ordered, func() { + var ( + httpDeployment1 *appsv1.Deployment + httpSvc1 *v1.Service + httpRoute *gwv1.HTTPRoute + ) + + It("Set up HTTPRoute with additional tags", func() { + httpDeployment1, httpSvc1 = testFramework.NewHttpApp(test.HTTPAppOptions{Name: "additional-tags-http", Namespace: k8snamespace}) + httpRoute = testFramework.NewHttpRoute(testGateway, httpSvc1, "Service") + httpRoute.Annotations = map[string]string{ + "application-networking.k8s.aws/tags": "Environment=Dev,Project=MyApp,Team=Platform,CostCenter=12345", + } + + testFramework.ExpectCreated(ctx, + httpRoute, + httpSvc1, + httpDeployment1, + ) + }) + + It("Verify additional tags on HTTPRoute VPC Lattice resources", func() { + Eventually(func(g Gomega) { + route, _ := core.NewRoute(httpRoute) + vpcLatticeService := testFramework.GetVpcLatticeService(ctx, route) + + serviceTags, err := testFramework.Cloud.Tagging().GetTagsForArns(ctx, []string{*vpcLatticeService.Arn}) + g.Expect(err).To(BeNil()) + serviceTagsMap := serviceTags[*vpcLatticeService.Arn] + + g.Expect(serviceTagsMap).To(HaveKeyWithValue("Environment", aws.String("Dev"))) + g.Expect(serviceTagsMap).To(HaveKeyWithValue("Project", aws.String("MyApp"))) + g.Expect(serviceTagsMap).To(HaveKeyWithValue("Team", aws.String("Platform"))) + g.Expect(serviceTagsMap).To(HaveKeyWithValue("CostCenter", aws.String("12345"))) + + tgSummary := testFramework.GetTargetGroup(ctx, httpSvc1) + tg, err := testFramework.LatticeClient.GetTargetGroup(&vpclattice.GetTargetGroupInput{ + TargetGroupIdentifier: aws.String(*tgSummary.Id), + }) + g.Expect(err).To(BeNil()) + + tgTags, err := testFramework.Cloud.Tagging().GetTagsForArns(ctx, []string{*tg.Arn}) + g.Expect(err).To(BeNil()) + tgTagsMap := tgTags[*tg.Arn] + + g.Expect(tgTagsMap).To(HaveKeyWithValue("Environment", aws.String("Dev"))) + g.Expect(tgTagsMap).To(HaveKeyWithValue("Project", aws.String("MyApp"))) + g.Expect(tgTagsMap).To(HaveKeyWithValue("Team", aws.String("Platform"))) + g.Expect(tgTagsMap).To(HaveKeyWithValue("CostCenter", aws.String("12345"))) + + listeners, err := testFramework.LatticeClient.ListListeners(&vpclattice.ListListenersInput{ + ServiceIdentifier: vpcLatticeService.Id, + }) + g.Expect(err).To(BeNil()) + g.Expect(len(listeners.Items)).To(BeNumerically(">", 0)) + + for _, listener := range listeners.Items { + listenerTags, err := testFramework.Cloud.Tagging().GetTagsForArns(ctx, []string{*listener.Arn}) + g.Expect(err).To(BeNil()) + listenerTagsMap := listenerTags[*listener.Arn] + + g.Expect(listenerTagsMap).To(HaveKeyWithValue("Environment", aws.String("Dev"))) + g.Expect(listenerTagsMap).To(HaveKeyWithValue("Project", aws.String("MyApp"))) + g.Expect(listenerTagsMap).To(HaveKeyWithValue("Team", aws.String("Platform"))) + g.Expect(listenerTagsMap).To(HaveKeyWithValue("CostCenter", aws.String("12345"))) + } + + for _, listener := range listeners.Items { + rules, err := testFramework.LatticeClient.ListRules(&vpclattice.ListRulesInput{ + ServiceIdentifier: vpcLatticeService.Id, + ListenerIdentifier: listener.Id, + }) + g.Expect(err).To(BeNil()) + + for _, rule := range rules.Items { + if rule.IsDefault != nil && *rule.IsDefault { + continue + } + + ruleTags, err := testFramework.Cloud.Tagging().GetTagsForArns(ctx, []string{*rule.Arn}) + g.Expect(err).To(BeNil()) + ruleTagsMap := ruleTags[*rule.Arn] + + g.Expect(ruleTagsMap).To(HaveKeyWithValue("Environment", aws.String("Dev"))) + g.Expect(ruleTagsMap).To(HaveKeyWithValue("Project", aws.String("MyApp"))) + g.Expect(ruleTagsMap).To(HaveKeyWithValue("Team", aws.String("Platform"))) + g.Expect(ruleTagsMap).To(HaveKeyWithValue("CostCenter", aws.String("12345"))) + } + } + + associations, err := testFramework.LatticeClient.ListServiceNetworkServiceAssociations(&vpclattice.ListServiceNetworkServiceAssociationsInput{ + ServiceIdentifier: vpcLatticeService.Id, + }) + g.Expect(err).To(BeNil()) + g.Expect(len(associations.Items)).To(BeNumerically(">", 0)) + + for _, association := range associations.Items { + associationTags, err := testFramework.Cloud.Tagging().GetTagsForArns(ctx, []string{*association.Arn}) + g.Expect(err).To(BeNil()) + associationTagsMap := associationTags[*association.Arn] + + g.Expect(associationTagsMap).To(HaveKeyWithValue("Environment", aws.String("Dev"))) + g.Expect(associationTagsMap).To(HaveKeyWithValue("Project", aws.String("MyApp"))) + g.Expect(associationTagsMap).To(HaveKeyWithValue("Team", aws.String("Platform"))) + g.Expect(associationTagsMap).To(HaveKeyWithValue("CostCenter", aws.String("12345"))) + } + }).Within(1 * time.Minute).Should(Succeed()) + }) + + It("Update HTTPRoute additional tags and verify AWS managed tags cannot be overridden", func() { + err := testFramework.Get(ctx, client.ObjectKeyFromObject(httpRoute), httpRoute) + Expect(err).To(BeNil()) + + httpRoute.Annotations = map[string]string{ + "application-networking.k8s.aws/tags": "Environment=Prod,Project=MyApp,Team=Platform,application-networking.k8s.aws/ManagedBy=test-override", + } + testFramework.ExpectUpdated(ctx, httpRoute) + + Eventually(func(g Gomega) { + route, _ := core.NewRoute(httpRoute) + vpcLatticeService := testFramework.GetVpcLatticeService(ctx, route) + + serviceTags, err := testFramework.Cloud.Tagging().GetTagsForArns(ctx, []string{*vpcLatticeService.Arn}) + g.Expect(err).To(BeNil()) + serviceTagsMap := serviceTags[*vpcLatticeService.Arn] + + g.Expect(serviceTagsMap).To(HaveKeyWithValue("Environment", aws.String("Prod"))) + g.Expect(serviceTagsMap).To(HaveKeyWithValue("Project", aws.String("MyApp"))) + g.Expect(serviceTagsMap).To(HaveKeyWithValue("Team", aws.String("Platform"))) + g.Expect(serviceTagsMap).ToNot(HaveKey("CostCenter")) + + g.Expect(serviceTagsMap).ToNot(HaveKeyWithValue(pkg_aws.TagManagedBy, aws.String("test-override"))) + g.Expect(serviceTagsMap).To(HaveKeyWithValue(pkg_aws.TagManagedBy, aws.String(fmt.Sprintf("%s/%s/%s", testFramework.Cloud.Config().AccountId, testFramework.Cloud.Config().ClusterName, testFramework.Cloud.Config().VpcId)))) + + tgSummary := testFramework.GetTargetGroup(ctx, httpSvc1) + tg, err := testFramework.LatticeClient.GetTargetGroup(&vpclattice.GetTargetGroupInput{ + TargetGroupIdentifier: aws.String(*tgSummary.Id), + }) + g.Expect(err).To(BeNil()) + + tgTags, err := testFramework.Cloud.Tagging().GetTagsForArns(ctx, []string{*tg.Arn}) + g.Expect(err).To(BeNil()) + tgTagsMap := tgTags[*tg.Arn] + + g.Expect(tgTagsMap).To(HaveKeyWithValue("Environment", aws.String("Prod"))) + g.Expect(tgTagsMap).To(HaveKeyWithValue("Project", aws.String("MyApp"))) + g.Expect(tgTagsMap).To(HaveKeyWithValue("Team", aws.String("Platform"))) + g.Expect(tgTagsMap).ToNot(HaveKey("CostCenter")) + + g.Expect(tgTagsMap).ToNot(HaveKeyWithValue(pkg_aws.TagManagedBy, aws.String("test-override"))) + g.Expect(tgTagsMap).To(HaveKeyWithValue(pkg_aws.TagManagedBy, aws.String(fmt.Sprintf("%s/%s/%s", testFramework.Cloud.Config().AccountId, testFramework.Cloud.Config().ClusterName, testFramework.Cloud.Config().VpcId)))) + + listeners, err := testFramework.LatticeClient.ListListeners(&vpclattice.ListListenersInput{ + ServiceIdentifier: vpcLatticeService.Id, + }) + g.Expect(err).To(BeNil()) + g.Expect(len(listeners.Items)).To(BeNumerically(">", 0)) + + for _, listener := range listeners.Items { + listenerTags, err := testFramework.Cloud.Tagging().GetTagsForArns(ctx, []string{*listener.Arn}) + g.Expect(err).To(BeNil()) + listenerTagsMap := listenerTags[*listener.Arn] + + g.Expect(listenerTagsMap).To(HaveKeyWithValue("Environment", aws.String("Prod"))) + g.Expect(listenerTagsMap).To(HaveKeyWithValue("Project", aws.String("MyApp"))) + g.Expect(listenerTagsMap).To(HaveKeyWithValue("Team", aws.String("Platform"))) + g.Expect(listenerTagsMap).ToNot(HaveKey("CostCenter")) + + g.Expect(listenerTagsMap).ToNot(HaveKeyWithValue(pkg_aws.TagManagedBy, aws.String("test-override"))) + g.Expect(listenerTagsMap).To(HaveKeyWithValue(pkg_aws.TagManagedBy, aws.String(fmt.Sprintf("%s/%s/%s", testFramework.Cloud.Config().AccountId, testFramework.Cloud.Config().ClusterName, testFramework.Cloud.Config().VpcId)))) + } + + for _, listener := range listeners.Items { + rules, err := testFramework.LatticeClient.ListRules(&vpclattice.ListRulesInput{ + ServiceIdentifier: vpcLatticeService.Id, + ListenerIdentifier: listener.Id, + }) + g.Expect(err).To(BeNil()) + + for _, rule := range rules.Items { + if rule.IsDefault != nil && *rule.IsDefault { + continue + } + + ruleTags, err := testFramework.Cloud.Tagging().GetTagsForArns(ctx, []string{*rule.Arn}) + g.Expect(err).To(BeNil()) + ruleTagsMap := ruleTags[*rule.Arn] + + g.Expect(ruleTagsMap).To(HaveKeyWithValue("Environment", aws.String("Prod"))) + g.Expect(ruleTagsMap).To(HaveKeyWithValue("Project", aws.String("MyApp"))) + g.Expect(ruleTagsMap).To(HaveKeyWithValue("Team", aws.String("Platform"))) + g.Expect(ruleTagsMap).ToNot(HaveKey("CostCenter")) + + g.Expect(ruleTagsMap).ToNot(HaveKeyWithValue(pkg_aws.TagManagedBy, aws.String("test-override"))) + g.Expect(ruleTagsMap).To(HaveKeyWithValue(pkg_aws.TagManagedBy, aws.String(fmt.Sprintf("%s/%s/%s", testFramework.Cloud.Config().AccountId, testFramework.Cloud.Config().ClusterName, testFramework.Cloud.Config().VpcId)))) + } + } + + associations, err := testFramework.LatticeClient.ListServiceNetworkServiceAssociations(&vpclattice.ListServiceNetworkServiceAssociationsInput{ + ServiceIdentifier: vpcLatticeService.Id, + }) + g.Expect(err).To(BeNil()) + g.Expect(len(associations.Items)).To(BeNumerically(">", 0)) + + for _, association := range associations.Items { + associationTags, err := testFramework.Cloud.Tagging().GetTagsForArns(ctx, []string{*association.Arn}) + g.Expect(err).To(BeNil()) + associationTagsMap := associationTags[*association.Arn] + + g.Expect(associationTagsMap).To(HaveKeyWithValue("Environment", aws.String("Prod"))) + g.Expect(associationTagsMap).To(HaveKeyWithValue("Project", aws.String("MyApp"))) + g.Expect(associationTagsMap).To(HaveKeyWithValue("Team", aws.String("Platform"))) + g.Expect(associationTagsMap).ToNot(HaveKey("CostCenter")) + + g.Expect(associationTagsMap).ToNot(HaveKeyWithValue(pkg_aws.TagManagedBy, aws.String("test-override"))) + g.Expect(associationTagsMap).To(HaveKeyWithValue(pkg_aws.TagManagedBy, aws.String(fmt.Sprintf("%s/%s/%s", testFramework.Cloud.Config().AccountId, testFramework.Cloud.Config().ClusterName, testFramework.Cloud.Config().VpcId)))) + } + }).Within(1 * time.Minute).Should(Succeed()) + }) + + It("Cleanup HTTPRoute resources", func() { + testFramework.ExpectDeletedThenNotFound(ctx, + httpRoute, + httpDeployment1, + httpSvc1, + ) + }) + + var ( + serviceExportDeployment1 *appsv1.Deployment + serviceExportSvc1 *v1.Service + serviceExport1 *anv1alpha1.ServiceExport + ) + + It("Set up ServiceExport with additional tags", func() { + serviceExportDeployment1, serviceExportSvc1 = testFramework.NewHttpApp(test.HTTPAppOptions{Name: "additional-tags-serviceexport", Namespace: k8snamespace}) + + serviceExport1 = &anv1alpha1.ServiceExport{ + TypeMeta: metav1.TypeMeta{ + APIVersion: "application-networking.k8s.aws/v1alpha1", + Kind: "ServiceExport", + }, + ObjectMeta: metav1.ObjectMeta{ + Name: serviceExportSvc1.Name, + Namespace: serviceExportSvc1.Namespace, + Annotations: map[string]string{ + "application-networking.k8s.aws/federation": "amazon-vpc-lattice", + "application-networking.k8s.aws/tags": "Environment=Dev,Project=MyApp,Team=Platform,CostCenter=12345", + }, + }, + Spec: anv1alpha1.ServiceExportSpec{ + ExportedPorts: []anv1alpha1.ExportedPort{ + { + Port: serviceExportSvc1.Spec.Ports[0].Port, + RouteType: "HTTP", + }, + }, + }, + } + + testFramework.ExpectCreated(ctx, + serviceExport1, + serviceExportSvc1, + serviceExportDeployment1, + ) + }) + + It("Verify additional tags on ServiceExport VPC Lattice resources", func() { + Eventually(func(g Gomega) { + tgSummary := testFramework.GetTargetGroup(ctx, serviceExportSvc1) + g.Expect(tgSummary).ToNot(BeNil()) + + tg, err := testFramework.LatticeClient.GetTargetGroup(&vpclattice.GetTargetGroupInput{ + TargetGroupIdentifier: aws.String(*tgSummary.Id), + }) + g.Expect(err).To(BeNil()) + + tgTags, err := testFramework.Cloud.Tagging().GetTagsForArns(ctx, []string{*tg.Arn}) + g.Expect(err).To(BeNil()) + tgTagsMap := tgTags[*tg.Arn] + + g.Expect(tgTagsMap).To(HaveKeyWithValue("Environment", aws.String("Dev"))) + g.Expect(tgTagsMap).To(HaveKeyWithValue("Project", aws.String("MyApp"))) + g.Expect(tgTagsMap).To(HaveKeyWithValue("Team", aws.String("Platform"))) + g.Expect(tgTagsMap).To(HaveKeyWithValue("CostCenter", aws.String("12345"))) + + g.Expect(tgTagsMap).To(HaveKeyWithValue(pkg_aws.TagManagedBy, aws.String(fmt.Sprintf("%s/%s/%s", testFramework.Cloud.Config().AccountId, testFramework.Cloud.Config().ClusterName, testFramework.Cloud.Config().VpcId)))) + }).Within(1 * time.Minute).Should(Succeed()) + }) + + It("Update ServiceExport additional tags and verify AWS managed tags cannot be overridden", func() { + err := testFramework.Get(ctx, client.ObjectKeyFromObject(serviceExport1), serviceExport1) + Expect(err).To(BeNil()) + + serviceExport1.Annotations = map[string]string{ + "application-networking.k8s.aws/federation": "amazon-vpc-lattice", + "application-networking.k8s.aws/tags": "Environment=Prod,Project=MyApp,Team=Platform,application-networking.k8s.aws/ManagedBy=test-override", + } + testFramework.ExpectUpdated(ctx, serviceExport1) + + Eventually(func(g Gomega) { + tgSummary := testFramework.GetTargetGroup(ctx, serviceExportSvc1) + tg, err := testFramework.LatticeClient.GetTargetGroup(&vpclattice.GetTargetGroupInput{ + TargetGroupIdentifier: aws.String(*tgSummary.Id), + }) + g.Expect(err).To(BeNil()) + + tgTags, err := testFramework.Cloud.Tagging().GetTagsForArns(ctx, []string{*tg.Arn}) + g.Expect(err).To(BeNil()) + tgTagsMap := tgTags[*tg.Arn] + + g.Expect(tgTagsMap).To(HaveKeyWithValue("Environment", aws.String("Prod"))) + g.Expect(tgTagsMap).To(HaveKeyWithValue("Project", aws.String("MyApp"))) + g.Expect(tgTagsMap).To(HaveKeyWithValue("Team", aws.String("Platform"))) + g.Expect(tgTagsMap).ToNot(HaveKey("CostCenter")) + + g.Expect(tgTagsMap).ToNot(HaveKeyWithValue(pkg_aws.TagManagedBy, aws.String("test-override"))) + g.Expect(tgTagsMap).To(HaveKeyWithValue(pkg_aws.TagManagedBy, aws.String(fmt.Sprintf("%s/%s/%s", testFramework.Cloud.Config().AccountId, testFramework.Cloud.Config().ClusterName, testFramework.Cloud.Config().VpcId)))) + }).Within(1 * time.Minute).Should(Succeed()) + }) + + It("Cleanup ServiceExport resources", func() { + testFramework.ExpectDeletedThenNotFound(ctx, + serviceExport1, + serviceExportDeployment1, + serviceExportSvc1, + ) + }) +}) diff --git a/test/suites/integration/vpc_association_policy_test.go b/test/suites/integration/vpc_association_policy_test.go index ce6cf460..7d7f3b34 100644 --- a/test/suites/integration/vpc_association_policy_test.go +++ b/test/suites/integration/vpc_association_policy_test.go @@ -1,6 +1,7 @@ package integration import ( + "fmt" "time" "github.com/aws/aws-sdk-go/aws" @@ -13,6 +14,7 @@ import ( "k8s.io/apimachinery/pkg/types" "github.com/aws/aws-application-networking-k8s/pkg/apis/applicationnetworking/v1alpha1" + pkg_aws "github.com/aws/aws-application-networking-k8s/pkg/aws" "github.com/aws/aws-application-networking-k8s/pkg/config" "github.com/aws/aws-application-networking-k8s/test/pkg/test" gwv1 "sigs.k8s.io/gateway-api/apis/v1" @@ -69,11 +71,14 @@ var _ = Describe("Test vpc association policy", Serial, Ordered, func() { Expect(err).To(BeNil()) }) - It("Create the VpcAssociationPolicy with a SecurityGroupId, expecting the ServiceNetworkVpcAssociation with a security group", func() { + It("Create the VpcAssociationPolicy with a SecurityGroupId and additional tags, expecting the ServiceNetworkVpcAssociation with a security group and additional tags", func() { vpcAssociationPolicy = &v1alpha1.VpcAssociationPolicy{ ObjectMeta: metav1.ObjectMeta{ Name: "test-vpc-association-policy", Namespace: k8snamespace, + Annotations: map[string]string{ + "application-networking.k8s.aws/tags": "Environment=Dev,Project=MyApp,Team=Platform,CostCenter=12345", + }, }, Spec: v1alpha1.VpcAssociationPolicySpec{ TargetRef: &gwv1alpha2.NamespacedPolicyTargetReference{ @@ -83,6 +88,7 @@ var _ = Describe("Test vpc association policy", Serial, Ordered, func() { Namespace: lo.ToPtr(gwv1.Namespace(k8snamespace)), }, SecurityGroupIds: []v1alpha1.SecurityGroupId{sgId}, + AssociateWithVpc: lo.ToPtr(true), }, } testFramework.ExpectCreated(ctx, vpcAssociationPolicy) @@ -91,12 +97,55 @@ var _ = Describe("Test vpc association policy", Serial, Ordered, func() { associated, snva, err := testFramework.IsVpcAssociatedWithServiceNetwork(ctx, test.CurrentClusterVpcId, testServiceNetwork) g.Expect(err).To(BeNil()) g.Expect(associated).To(BeTrue()) + output, err := testFramework.LatticeClient.GetServiceNetworkVpcAssociationWithContext(ctx, &vpclattice.GetServiceNetworkVpcAssociationInput{ ServiceNetworkVpcAssociationIdentifier: snva.Id, }) g.Expect(err).To(BeNil()) g.Expect(output.SecurityGroupIds).To(HaveLen(1)) g.Expect(*output.SecurityGroupIds[0]).To(Equal(string(sgId))) + + snvaArn := aws.StringValue(snva.Arn) + + tagsMap, err := testFramework.Cloud.Tagging().GetTagsForArns(ctx, []string{snvaArn}) + g.Expect(err).To(BeNil()) + + tags := tagsMap[snvaArn] + g.Expect(tags).To(HaveKeyWithValue("Environment", aws.String("Dev"))) + g.Expect(tags).To(HaveKeyWithValue("Project", aws.String("MyApp"))) + g.Expect(tags).To(HaveKeyWithValue("Team", aws.String("Platform"))) + g.Expect(tags).To(HaveKeyWithValue("CostCenter", aws.String("12345"))) + g.Expect(tags).To(HaveKeyWithValue(pkg_aws.TagManagedBy, aws.String(fmt.Sprintf("%s/%s/%s", testFramework.Cloud.Config().AccountId, testFramework.Cloud.Config().ClusterName, testFramework.Cloud.Config().VpcId)))) + }).WithTimeout(5 * time.Minute).Should(Succeed()) + }) + + It("Update VpcAssociationPolicy with additional tags changed and attempting to override aws managed tags", func() { + testFramework.Get(ctx, types.NamespacedName{ + Namespace: vpcAssociationPolicy.Namespace, + Name: vpcAssociationPolicy.Name, + }, vpcAssociationPolicy) + + vpcAssociationPolicy.ObjectMeta.Annotations["application-networking.k8s.aws/tags"] = "Environment=Prod,Project=MyApp-v2,Team=DevOps,application-networking.k8s.aws/ManagedBy=test-override" + + testFramework.ExpectUpdated(ctx, vpcAssociationPolicy) + + Eventually(func(g Gomega) { + associated, snva, err := testFramework.IsVpcAssociatedWithServiceNetwork(ctx, test.CurrentClusterVpcId, testServiceNetwork) + g.Expect(err).To(BeNil()) + g.Expect(associated).To(BeTrue()) + + snvaArn := aws.StringValue(snva.Arn) + testFramework.Log.Infof(ctx, "ServiceNetworkVpcAssociation ARN: %s", snvaArn) + tagsMap, err := testFramework.Cloud.Tagging().GetTagsForArns(ctx, []string{snvaArn}) + g.Expect(err).To(BeNil()) + tags := tagsMap[snvaArn] + g.Expect(tags).To(HaveKeyWithValue("Environment", aws.String("Prod"))) + g.Expect(tags).To(HaveKeyWithValue("Project", aws.String("MyApp-v2"))) + g.Expect(tags).To(HaveKeyWithValue("Team", aws.String("DevOps"))) + g.Expect(tags).ToNot(HaveKey("CostCenter")) + + g.Expect(tags).ToNot(HaveKeyWithValue(pkg_aws.TagManagedBy, aws.String("test-override"))) + g.Expect(tags).To(HaveKeyWithValue(pkg_aws.TagManagedBy, aws.String(fmt.Sprintf("%s/%s/%s", testFramework.Cloud.Config().AccountId, testFramework.Cloud.Config().ClusterName, testFramework.Cloud.Config().VpcId)))) }).WithTimeout(5 * time.Minute).Should(Succeed()) })