Skip to content

Commit

Permalink
AWS Cloud Provider now uses accurate pod/memory values when construct…
Browse files Browse the repository at this point in the history
…ing node objects (#572)
  • Loading branch information
ellistarn committed Jul 29, 2021
1 parent 01e0756 commit 7ac2ea6
Show file tree
Hide file tree
Showing 6 changed files with 114 additions and 130 deletions.
16 changes: 5 additions & 11 deletions pkg/cloudprovider/aws/cloudprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ var (
)

type CloudProvider struct {
nodeAPI *NodeFactory
launchTemplateProvider *LaunchTemplateProvider
subnetProvider *SubnetProvider
instanceTypeProvider *InstanceTypeProvider
Expand All @@ -82,16 +81,16 @@ func NewCloudProvider(ctx context.Context, options cloudprovider.Options) *Cloud
}
logging.FromContext(ctx).Debugf("Using AWS region %s", *sess.Config.Region)
ec2api := ec2.New(sess)
instanceTypeProvider := NewInstanceTypeProvider(ec2api)
return &CloudProvider{
nodeAPI: &NodeFactory{ec2api: ec2api},
launchTemplateProvider: NewLaunchTemplateProvider(
ec2api,
NewAMIProvider(ssm.New(sess), options.ClientSet),
NewSecurityGroupProvider(ec2api),
),
subnetProvider: NewSubnetProvider(ec2api),
instanceTypeProvider: NewInstanceTypeProvider(ec2api),
instanceProvider: &InstanceProvider{ec2api: ec2api},
instanceTypeProvider: instanceTypeProvider,
instanceProvider: &InstanceProvider{ec2api, instanceTypeProvider},
creationQueue: parallel.NewWorkQueue(CreationQPS, CreationBurst),
}
}
Expand Down Expand Up @@ -132,14 +131,9 @@ func (c *CloudProvider) create(ctx context.Context, provisioner *v1alpha3.Provis
return fmt.Errorf("getting launch template, %w", err)
}
// 3. Create instance
instanceID, err := c.instanceProvider.Create(ctx, launchTemplate, packing.InstanceTypeOptions, subnets, constraints.GetCapacityType())
node, err := c.instanceProvider.Create(ctx, launchTemplate, packing.InstanceTypeOptions, subnets, constraints.GetCapacityType())
if err != nil {
return fmt.Errorf("launching instances, %w", err)
}
// 4. Convert to node
node, err := c.nodeAPI.For(ctx, instanceID)
if err != nil {
return fmt.Errorf("constructing node, %w", err)
return fmt.Errorf("launching instance, %w", err)
}
return callback(node)
}
Expand Down
117 changes: 102 additions & 15 deletions pkg/cloudprovider/aws/instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,20 @@ import (
"context"
"fmt"
"strings"
"time"

"github.com/avast/retry-go"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/ec2/ec2iface"
"github.com/awslabs/karpenter/pkg/apis/provisioning/v1alpha3"
"github.com/awslabs/karpenter/pkg/cloudprovider"
"knative.dev/pkg/logging"

"go.uber.org/multierr"
v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/util/sets"
)

Expand All @@ -35,19 +40,69 @@ const (
)

type InstanceProvider struct {
ec2api ec2iface.EC2API
ec2api ec2iface.EC2API
instanceTypeProvider *InstanceTypeProvider
}

// Create an instance given the constraints.
// instanceTypeOptions should be sorted by priority for spot capacity type.
// If spot is not used, the instanceTypeOptions are not required to be sorted
// instanceTypes should be sorted by priority for spot capacity type.
// If spot is not used, the instanceTypes are not required to be sorted
// because we are using ec2 fleet's lowest-price OD allocation strategy
func (p *InstanceProvider) Create(ctx context.Context,
launchTemplate *LaunchTemplate,
instanceTypeOptions []cloudprovider.InstanceType,
instanceTypes []cloudprovider.InstanceType,
subnets []*ec2.Subnet,
capacityType string,
) (*string, error) {
) (*v1.Node, error) {
// 1. Launch Instance
id, err := p.launchInstance(ctx, launchTemplate, instanceTypes, subnets, capacityType)
if err != nil {
return nil, err
}
// 2. Get Instance with backoff retry since EC2 is eventually consistent
instance := &ec2.Instance{}
if err := retry.Do(
func() (err error) { return p.getInstance(ctx, id, instance) },
retry.Delay(1*time.Second),
retry.Attempts(3),
); err != nil {
return nil, err
}
logging.FromContext(ctx).Infof("Launched instance: %s, type: %s, zone: %s, hostname: %s",
aws.StringValue(instance.InstanceId),
aws.StringValue(instance.InstanceType),
aws.StringValue(instance.Placement.AvailabilityZone),
aws.StringValue(instance.PrivateDnsName),
)
// 3. Convert Instance to Node
node, err := p.instanceToNode(ctx, instance, instanceTypes)
if err != nil {
return nil, err
}
return node, nil
}

func (p *InstanceProvider) Terminate(ctx context.Context, node *v1.Node) error {
id, err := getInstanceID(node)
if err != nil {
return fmt.Errorf("getting instance ID for node %s, %w", node.Name, err)
}
if _, err = p.ec2api.TerminateInstancesWithContext(ctx, &ec2.TerminateInstancesInput{
InstanceIds: []*string{id},
}); err != nil {
if aerr, ok := err.(awserr.Error); ok && aerr.Code() == EC2InstanceIDNotFoundErrCode {
return nil
}
return fmt.Errorf("terminating instance %s, %w", node.Name, err)
}
return nil
}

func (p *InstanceProvider) launchInstance(ctx context.Context,
launchTemplate *LaunchTemplate,
instanceTypeOptions []cloudprovider.InstanceType,
subnets []*ec2.Subnet,
capacityType string) (*string, error) {
// 1. Construct override options.
var overrides []*ec2.FleetLaunchTemplateOverridesRequest
for i, instanceType := range instanceTypeOptions {
Expand Down Expand Up @@ -107,22 +162,54 @@ func (p *InstanceProvider) Create(ctx context.Context,
return createFleetOutput.Instances[0].InstanceIds[0], nil
}

func (p *InstanceProvider) Terminate(ctx context.Context, node *v1.Node) error {
id, err := getInstanceID(node)
func (p *InstanceProvider) getInstance(ctx context.Context, id *string, instance *ec2.Instance) error {
describeInstancesOutput, err := p.ec2api.DescribeInstancesWithContext(ctx, &ec2.DescribeInstancesInput{InstanceIds: []*string{id}})
if aerr, ok := err.(awserr.Error); ok && aerr.Code() == EC2InstanceIDNotFoundErrCode {
return aerr
}
if err != nil {
return fmt.Errorf("getting instance ID for node %s, %w", node.Name, err)
return fmt.Errorf("failed to describe ec2 instances, %w", err)
}
if _, err = p.ec2api.TerminateInstancesWithContext(ctx, &ec2.TerminateInstancesInput{
InstanceIds: []*string{id},
}); err != nil {
if aerr, ok := err.(awserr.Error); ok && aerr.Code() == EC2InstanceIDNotFoundErrCode {
return nil
}
return fmt.Errorf("terminating instance %s, %w", node.Name, err)
if len(describeInstancesOutput.Reservations) != 1 {
return fmt.Errorf("expected a single instance reservation, got %d", len(describeInstancesOutput.Reservations))
}
if len(describeInstancesOutput.Reservations[0].Instances) != 1 {
return fmt.Errorf("expected a single instance, got %d", len(describeInstancesOutput.Reservations[0].Instances))
}
*instance = *describeInstancesOutput.Reservations[0].Instances[0]
if len(aws.StringValue(instance.PrivateDnsName)) == 0 {
return fmt.Errorf("expected PrivateDnsName to be set")
}
return nil
}

func (p *InstanceProvider) instanceToNode(ctx context.Context, instance *ec2.Instance, instanceTypes []cloudprovider.InstanceType) (*v1.Node, error) {
for _, instanceType := range instanceTypes {
if instanceType.Name() == aws.StringValue(instance.InstanceType) {
return &v1.Node{
ObjectMeta: metav1.ObjectMeta{
Name: aws.StringValue(instance.PrivateDnsName),
},
Spec: v1.NodeSpec{
ProviderID: fmt.Sprintf("aws:///%s/%s", aws.StringValue(instance.Placement.AvailabilityZone), aws.StringValue(instance.InstanceId)),
},
Status: v1.NodeStatus{
Allocatable: v1.ResourceList{
v1.ResourcePods: *instanceType.Pods(),
v1.ResourceCPU: *instanceType.CPU(),
v1.ResourceMemory: *instanceType.Memory(),
},
NodeInfo: v1.NodeSystemInfo{
Architecture: aws.StringValue(instance.Architecture),
OperatingSystem: v1alpha3.OperatingSystemLinux,
},
},
}, nil
}
}
return nil, fmt.Errorf("unrecognized instance type %s", aws.StringValue(instance.InstanceType))
}

func getInstanceID(node *v1.Node) (*string, error) {
id := strings.Split(node.Spec.ProviderID, "/")
if len(id) < 5 {
Expand Down
2 changes: 1 addition & 1 deletion pkg/cloudprovider/aws/instancetypes.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func NewInstanceTypeProvider(ec2api ec2iface.EC2API) *InstanceTypeProvider {
}
}

// Get instance types that are available per availability zone
// Get all instance types that are available per availability zone
func (p *InstanceTypeProvider) Get(ctx context.Context) ([]cloudprovider.InstanceType, error) {
var instanceTypes []cloudprovider.InstanceType
if cached, ok := p.cache.Get(allInstanceTypesKey); ok {
Expand Down
97 changes: 0 additions & 97 deletions pkg/cloudprovider/aws/node.go

This file was deleted.

7 changes: 3 additions & 4 deletions pkg/cloudprovider/aws/suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,20 +60,19 @@ func TestAPIs(t *testing.T) {
var _ = BeforeSuite(func() {
launchTemplateCache = cache.New(CacheTTL, CacheCleanupInterval)
fakeEC2API = &fake.EC2API{}
instanceTypeProvider := NewInstanceTypeProvider(fakeEC2API)
env = test.NewEnvironment(ctx, func(e *test.Environment) {
clientSet := kubernetes.NewForConfigOrDie(e.Config)

cloudProvider := &CloudProvider{
nodeAPI: &NodeFactory{fakeEC2API},
launchTemplateProvider: &LaunchTemplateProvider{
fakeEC2API,
NewAMIProvider(&fake.SSMAPI{}, clientSet),
NewSecurityGroupProvider(fakeEC2API),
launchTemplateCache,
},
subnetProvider: NewSubnetProvider(fakeEC2API),
instanceTypeProvider: NewInstanceTypeProvider(fakeEC2API),
instanceProvider: &InstanceProvider{fakeEC2API},
instanceTypeProvider: instanceTypeProvider,
instanceProvider: &InstanceProvider{fakeEC2API, instanceTypeProvider},
creationQueue: parallel.NewWorkQueue(CreationQPS, CreationBurst),
}
registry.RegisterOrDie(cloudProvider)
Expand Down
5 changes: 3 additions & 2 deletions pkg/controllers/allocation/bind.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,13 @@ func (b *Binder) Bind(ctx context.Context, node *v1.Node, pods []*v1.Pod) error
}

// 4. Bind pods
logging.FromContext(ctx).Infof("Binding %d pod(s) to node %s", len(pods), node.Name)
errs := make([]error, len(pods))
workqueue.ParallelizeUntil(ctx, len(pods), len(pods), func(index int) {
errs[index] = b.bind(ctx, node, pods[index])
})
return multierr.Combine(errs...)
err := multierr.Combine(errs...)
logging.FromContext(ctx).Infof("Bound %d pod(s) to node %s", len(pods)-len(multierr.Errors(err)), node.Name)
return err
}

func (b *Binder) bind(ctx context.Context, node *v1.Node, pod *v1.Pod) error {
Expand Down

0 comments on commit 7ac2ea6

Please sign in to comment.