Skip to content

Commit

Permalink
More unmangling. More working on #16 and #18
Browse files Browse the repository at this point in the history
  • Loading branch information
andresvia committed Oct 19, 2020
1 parent b4bf6a7 commit bd7db8e
Show file tree
Hide file tree
Showing 5 changed files with 223 additions and 39 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# go-awsecs

[![godoc reference](http://img.shields.io/badge/godoc-reference-blue.svg)](https://pkg.go.dev/github.com/Autodesk/go-awsecs)

[![travis ci](https://api.travis-ci.org/Autodesk/go-awsecs.svg?branch=master)](https://travis-ci.org/Autodesk/go-awsecs)

[![coverage status](https://coveralls.io/repos/github/Autodesk/go-awsecs/badge.svg?branch=master)](https://coveralls.io/github/Autodesk/go-awsecs?branch=master)
Expand Down
22 changes: 10 additions & 12 deletions asg.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,18 +175,16 @@ func drainingContainerInstanceIsDrained(ECSAPI ecs.ECS, clusterName, containerIn
}

func findDrainingContainerInstance(output *ecs.DescribeContainerInstancesOutput, containerInstanceID string) error {
for _, containerInstance := range output.ContainerInstances {
containerInstanceArn := *containerInstance.ContainerInstanceArn
parsedArn, err := arn.Parse(containerInstanceArn)
if err != nil {
return err
}
err = checkDrainingContainerInstance(containerInstance, parsedArn, containerInstanceID)
if err != nil {
return err
}
if len(output.ContainerInstances) == 0 {
return ErrContainerInstanceNotFound
}
return backoff.Permanent(errors.New("container instance not found"))
containerInstance := output.ContainerInstances[0]
containerInstanceArn := *containerInstance.ContainerInstanceArn
parsedArn, err := arn.Parse(containerInstanceArn)
if err != nil {
return err
}
return checkDrainingContainerInstance(containerInstance, parsedArn, containerInstanceID)
}

func checkDrainingContainerInstance(containerInstance *ecs.ContainerInstance, parsedArn arn.ARN, containerInstanceID string) error {
Expand All @@ -206,7 +204,7 @@ func checkDrainingContainerInstance(containerInstance *ecs.ContainerInstance, pa
}
return nil
}
return nil
return ErrContainerInstanceNotFound
}

func drainAll(ASAPI autoscaling.AutoScaling, ECSAPI ecs.ECS, EC2API ec2.EC2, instances []ecsEC2Instance, asgName, clusterName string) error {
Expand Down
2 changes: 1 addition & 1 deletion asg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ func TestCheckDrainingContainerInstance(t *testing.T) {
},
{
name: "Not matching container instance ID",
wantErr: false,
wantErr: true,
args: args{
containerInstance: &ecs.ContainerInstance{},
parsedArn: arn.ARN{
Expand Down
76 changes: 50 additions & 26 deletions ecs.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ var (
ErrServiceNotFound = errors.New("the service does not exist")
// ErrServiceDeletedAfterUpdate service was updated and then deleted elsewhere
ErrServiceDeletedAfterUpdate = backoff.Permanent(errors.New("the service was deleted after the update"))
// ErrContainerInstanceNotFound the container instance was removed from the cluster elsewhere
ErrContainerInstanceNotFound = backoff.Permanent(errors.New("container instance not found"))
)

var (
Expand Down Expand Up @@ -165,42 +167,64 @@ func copyTaskDef(api ecs.ECS, taskdef string, imageMap map[string]string, envMap
if err != nil {
return "", err
}
arn := tdNew.TaskDefinition.TaskDefinitionArn
return *arn, nil
taskDefinitionArn := tdNew.TaskDefinition.TaskDefinitionArn
return *taskDefinitionArn, nil
}

// TODO: add coverage
func alterService(api ecs.ECS, cluster, service string, imageMap map[string]string, envMaps map[string]map[string]string, secretMaps map[string]map[string]string, logopts map[string]map[string]map[string]string, logsecrets map[string]map[string]map[string]string, taskRole string, desiredCount *int64, taskdef string) (ecs.Service, ecs.Service, error) {
output, err := api.DescribeServices(&ecs.DescribeServicesInput{Cluster: aws.String(cluster), Services: []*string{aws.String(service)}})
if err != nil {
return ecs.Service{}, ecs.Service{}, err
}
for _, svc := range output.Services {
clusterArn := *svc.ClusterArn
parsedClusterArn, err := arn.Parse(clusterArn)
copyTaskDefinitionAction := func(sourceTaskDefinition string) (string, error) {
return copyTaskDef(api, sourceTaskDefinition, imageMap, envMaps, secretMaps, logopts, logsecrets, taskRole)
}
updateAction := func(newTaskDefinition *string, desiredCount *int64) (*ecs.UpdateServiceOutput, error) {
updateServiceInput := &ecs.UpdateServiceInput{
Cluster: aws.String(cluster),
Service: aws.String(service),
TaskDefinition: newTaskDefinition,
DesiredCount: desiredCount,
ForceNewDeployment: aws.Bool(true),
}
return api.UpdateService(updateServiceInput)
}
return findAndUpdateService(output, cluster, service, taskdef, desiredCount, copyTaskDefinitionAction, updateAction)
}

func findAndUpdateService(output *ecs.DescribeServicesOutput, cluster, service, taskDefinition string, desiredCount *int64, copyTdAction func(string) (string, error), updateSvcAction func(*string, *int64) (*ecs.UpdateServiceOutput, error)) (ecs.Service, ecs.Service, error) {
if len(output.Services) == 0 {
return ecs.Service{}, ecs.Service{}, ErrServiceNotFound
}
svc := output.Services[0]
clusterArn := *svc.ClusterArn
parsedClusterArn, err := arn.Parse(clusterArn)
if err != nil {
return ecs.Service{}, ecs.Service{}, err
}
return updateService(parsedClusterArn, svc, cluster, service, taskDefinition, desiredCount, copyTdAction, updateSvcAction)
}

func updateService(parsedClusterArn arn.ARN, svc *ecs.Service, cluster, service, td string, desiredCount *int64, copyTdAction func(string) (string, error), updateSvcAction func(*string, *int64) (*ecs.UpdateServiceOutput, error)) (ecs.Service, ecs.Service, error) {
clusterNameFound := strings.TrimPrefix(parsedClusterArn.Resource, "cluster/")
serviceNameFound := *svc.ServiceName
if clusterNameFound == cluster && serviceNameFound == service {
srcTaskDef := svc.TaskDefinition
if td != "" {
srcTaskDef = &td
}
newTd, err := copyTdAction(*srcTaskDef)
if err != nil {
return ecs.Service{}, ecs.Service{}, err
return *svc, ecs.Service{}, err
}
clusterNameFound := strings.TrimPrefix(parsedClusterArn.Resource, "cluster/")
serviceNameFound := *svc.ServiceName
if clusterNameFound == cluster && serviceNameFound == service {
srcTaskDef := svc.TaskDefinition
if taskdef != "" {
srcTaskDef = &taskdef
}
newTd, err := copyTaskDef(api, *srcTaskDef, imageMap, envMaps, secretMaps, logopts, logsecrets, taskRole)
if err != nil {
return *svc, ecs.Service{}, err
}
if desiredCount == nil {
desiredCount = svc.DesiredCount
}
updated, err := api.UpdateService(&ecs.UpdateServiceInput{Cluster: aws.String(cluster), Service: aws.String(service), TaskDefinition: aws.String(newTd), DesiredCount: desiredCount, ForceNewDeployment: aws.Bool(true)})
if err != nil {
return *svc, ecs.Service{}, err
}
return *svc, *updated.Service, nil
if desiredCount == nil {
desiredCount = svc.DesiredCount
}
updated, err := updateSvcAction(aws.String(newTd), desiredCount)
if err != nil {
return *svc, ecs.Service{}, err
}
return *svc, *updated.Service, nil
}
return ecs.Service{}, ecs.Service{}, ErrServiceNotFound
}
Expand Down
160 changes: 160 additions & 0 deletions ecs_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
package awsecs

import (
"errors"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/arn"
"github.com/aws/aws-sdk-go/service/ecs"
"reflect"
"testing"
)

func TestUpdateService(t *testing.T) {
type args struct {
parsedClusterArn arn.ARN
svc *ecs.Service
cluster string
service string
td string
desiredCount *int64
copyTdAction func(string) (string, error)
updateSvcAction func(*string, *int64) (*ecs.UpdateServiceOutput, error)
}
tests := []struct {
name string
wantErr bool
beforeUpdate ecs.Service
afterUpdate ecs.Service
args args
}{
{
name: "On copy error I want error",
wantErr: true,
beforeUpdate: ecs.Service{
ServiceName: aws.String("my-service"),
},
afterUpdate: ecs.Service{},
args: args{
parsedClusterArn: arn.ARN{
Resource: "cluster/my-cluster",
},
svc: &ecs.Service{
ServiceName: aws.String("my-service"),
},
cluster: "my-cluster",
service: "my-service",
td: "task:1",
desiredCount: aws.Int64(1),
copyTdAction: func(string) (string, error) {
return "", errors.New("failed to copy")
},
updateSvcAction: nil,
},
},
{
name: "On update error I want error",
wantErr: true,
beforeUpdate: ecs.Service{
ServiceName: aws.String("my-service"),
},
afterUpdate: ecs.Service{},
args: args{
parsedClusterArn: arn.ARN{
Resource: "cluster/my-cluster",
},
svc: &ecs.Service{
ServiceName: aws.String("my-service"),
},
cluster: "my-cluster",
service: "my-service",
td: "task:1",
desiredCount: aws.Int64(1),
copyTdAction: func(string) (string, error) {
return "task:2", nil
},
updateSvcAction: func(*string, *int64) (*ecs.UpdateServiceOutput, error) {
return nil, errors.New("failed to update")
},
},
},
{
name: "On non matching cluster I want error",
wantErr: true,
beforeUpdate: ecs.Service{},
afterUpdate: ecs.Service{},
args: args{
parsedClusterArn: arn.ARN{
Resource: "cluster/my-cluster",
},
svc: &ecs.Service{
ServiceName: aws.String("my-service"),
},
cluster: "my-other-cluster",
service: "my-service",
},
},
{
name: "On non matching service I want error",
wantErr: true,
beforeUpdate: ecs.Service{},
afterUpdate: ecs.Service{},
args: args{
parsedClusterArn: arn.ARN{
Resource: "cluster/my-cluster",
},
svc: &ecs.Service{
ServiceName: aws.String("my-service"),
},
cluster: "my-cluster",
service: "my-other-service",
},
},
{
name: "Check before and after update",
wantErr: false,
beforeUpdate: ecs.Service{
ServiceName: aws.String("my-service"),
},
afterUpdate: ecs.Service{
TaskDefinition: aws.String("task:2"),
},
args: args{
parsedClusterArn: arn.ARN{
Resource: "cluster/my-cluster",
},
svc: &ecs.Service{
ServiceName: aws.String("my-service"),
},
cluster: "my-cluster",
service: "my-service",
td: "task:1",
desiredCount: nil,
copyTdAction: func(s string) (string, error) {
return "task:2", nil
},
updateSvcAction: func(s *string, i *int64) (*ecs.UpdateServiceOutput, error) {
return &ecs.UpdateServiceOutput{
Service: &ecs.Service{
TaskDefinition: aws.String("task:2"),
},
}, nil
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, got1, err := updateService(tt.args.parsedClusterArn, tt.args.svc, tt.args.cluster, tt.args.service, tt.args.td, tt.args.desiredCount, tt.args.copyTdAction, tt.args.updateSvcAction)
if (err != nil) != tt.wantErr {
t.Errorf("updateService() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.beforeUpdate) {
t.Errorf("updateService() got = %v, want %v", got, tt.beforeUpdate)
}
if !reflect.DeepEqual(got1, tt.afterUpdate) {
t.Errorf("updateService() got1 = %v, want %v", got1, tt.afterUpdate)
}
})
}
}

0 comments on commit bd7db8e

Please sign in to comment.