From 398b5121ddfd7f9a7ffb45b17f81af0a3ffea919 Mon Sep 17 00:00:00 2001 From: Brandon Wagner Date: Mon, 20 Jul 2020 15:43:04 -0500 Subject: [PATCH 1/4] accept memory and gpu memory in GiB instead of MiB --- README.md | 53 ++++++++++++------------- cmd/examples/example1.go | 8 ++-- cmd/main.go | 8 ++-- pkg/cli/cli.go | 58 ++++++++++++++++++++-------- pkg/cli/flags.go | 45 ++++++++++++++++++++- pkg/cli/flags_test.go | 18 +++++++++ pkg/cli/types.go | 29 +++++++++++--- pkg/cli/types_test.go | 12 ++++++ pkg/selector/aggregates.go | 6 +-- pkg/selector/outputs/outputs.go | 12 +++--- pkg/selector/outputs/outputs_test.go | 12 ++++++ pkg/selector/selector.go | 20 +++++++++- pkg/selector/selector_test.go | 2 +- pkg/selector/types.go | 15 +++++-- 14 files changed, 224 insertions(+), 74 deletions(-) diff --git a/README.md b/README.md index 67f2843..1bc64cb 100644 --- a/README.md +++ b/README.md @@ -68,7 +68,7 @@ $ export AWS_REGION="us-east-1" **Find Instance Types with 4 GiB of memory, 2 vcpus, and runs on the x86_64 CPU architecture** ``` -$ ec2-instance-selector --memory 4096 --vcpus 2 --cpu-architecture x86_64 -r us-east-1 +$ ec2-instance-selector --memory 4 --vcpus 2 --cpu-architecture x86_64 -r us-east-1 c5.large c5d.large t2.medium @@ -93,26 +93,27 @@ r5n.24xlarge **Short Table Output** ``` -$ ec2-instance-selector --memory 4096 --vcpus 2 --cpu-architecture x86_64 -r us-east-1 -o table -Instance Type VCPUs Mem (MiB) +$ ec2-instance-selector --memory 4 --vcpus 2 --cpu-architecture x86_64 -r us-east-1 -o table +Instance Type VCPUs Mem (GiB) ------------- ----- --------- -c5.large 2 4096 -c5d.large 2 4096 -t2.medium 2 4096 -t3.medium 2 4096 -t3a.medium 2 4096 +c5.large 2 4.000 +c5d.large 2 4.000 +t2.medium 2 4.000 +t3.medium 2 4.000 +t3a.medium 2 4.000 ``` **Wide Table Output** ``` -$ ec2-instance-selector --memory 4096 --vcpus 2 --cpu-architecture x86_64 -r us-east-1 -o table-wide -Instance Type VCPUs Mem (MiB) Hypervisor Current Gen Hibernation Support CPU Arch Network Performance ENIs GPUs -------------- ----- --------- ---------- ----------- ------------------- -------- ------------------- ---- ---- -c5.large 2 4096 nitro true true x86_64 Up to 10 Gigabit 3 0 -c5d.large 2 4096 nitro true false x86_64 Up to 10 Gigabit 3 0 -t2.medium 2 4096 xen true true i386, x86_64 Low to Moderate 3 0 -t3.medium 2 4096 nitro true false x86_64 Up to 5 Gigabit 3 0 -t3a.medium 2 4096 nitro true false x86_64 Up to 5 Gigabit 3 0 +$ ec2-instance-selector --memory 4 --vcpus 2 --cpu-architecture x86_64 -r us-east-1 -o table-wide +Instance Type VCPUs Mem (GiB) Hypervisor Current Gen Hibernation Support CPU Arch Network Performance ENIs GPUs +------------- ----- --------- ---------- ----------- ------------------- -------- ------------------- ---- ---- +c5.large 2 4.000 nitro true true x86_64 Up to 10 Gigabit 3 0 +c5a.large 2 4.000 nitro true false x86_64 Up to 10 Gigabit 3 0 +c5d.large 2 4.000 nitro true false x86_64 Up to 10 Gigabit 3 0 +t2.medium 2 4.000 xen true true i386, x86_64 Low to Moderate 3 0 +t3.medium 2 4.000 nitro true false x86_64 Up to 5 Gigabit 3 0 +t3a.medium 2 4.000 nitro true false x86_64 Up to 5 Gigabit 3 0 ``` **All CLI Options** @@ -144,17 +145,17 @@ Filter Flags: --deny-list string List of instance types which should be excluded w/ regex syntax (Example: m[1-2]\.*) -e, --ena-support Instance types where ENA is supported or required -f, --fpga-support FPGA instance types - --gpu-memory-total int Number of GPUs' total memory in MiB (Example: 4096) (sets --gpu-memory-total-min and -max to the same value) - --gpu-memory-total-max int Maximum Number of GPUs' total memory in MiB (Example: 4096) If --gpu-memory-total-min is not specified, the lower bound will be 0 - --gpu-memory-total-min int Minimum Number of GPUs' total memory in MiB (Example: 4096) If --gpu-memory-total-max is not specified, the upper bound will be infinity + --gpu-memory-total float Number of GPUs' total memory in GiB (Example: 4) (sets --gpu-memory-total-min and -max to the same value) + --gpu-memory-total-max float Maximum Number of GPUs' total memory in GiB (Example: 4) If --gpu-memory-total-min is not specified, the lower bound will be 0 + --gpu-memory-total-min float Minimum Number of GPUs' total memory in GiB (Example: 4) If --gpu-memory-total-max is not specified, the upper bound will be infinity -g, --gpus int Total Number of GPUs (Example: 4) (sets --gpus-min and -max to the same value) --gpus-max int Maximum Total Number of GPUs (Example: 4) If --gpus-min is not specified, the lower bound will be 0 --gpus-min int Minimum Total Number of GPUs (Example: 4) If --gpus-max is not specified, the upper bound will be infinity --hibernation-support Hibernation supported --hypervisor string Hypervisor: [xen or nitro] - -m, --memory int Amount of Memory available in MiB (Example: 4096) (sets --memory-min and -max to the same value) - --memory-max int Maximum Amount of Memory available in MiB (Example: 4096) If --memory-min is not specified, the lower bound will be 0 - --memory-min int Minimum Amount of Memory available in MiB (Example: 4096) If --memory-max is not specified, the upper bound will be infinity + -m, --memory float Amount of Memory available in GiB (Example: 4) (sets --memory-min and -max to the same value) + --memory-max float Maximum Amount of Memory available in GiB (Example: 4) If --memory-min is not specified, the lower bound will be 0 + --memory-min float Minimum Amount of Memory available in GiB (Example: 4) If --memory-max is not specified, the upper bound will be infinity --network-interfaces int Number of network interfaces (ENIs) that can be attached to the instance (sets --network-interfaces-min and -max to the same value) --network-interfaces-max int Maximum Number of network interfaces (ENIs) that can be attached to the instance If --network-interfaces-min is not specified, the lower bound will be 0 --network-interfaces-min int Minimum Number of network interfaces (ENIs) that can be attached to the instance If --network-interfaces-max is not specified, the upper bound will be infinity @@ -221,10 +222,10 @@ func main() { LowerBound: 2, UpperBound: 4, } - // Instantiate an int range filter to specify min and max memory in MiB - memoryRange := selector.IntRangeFilter{ - LowerBound: 1024, - UpperBound: 4096, + // Instantiate a float64 range filter to specify min and max memory in GiB + memoryRange := selector.Float64RangeFilter{ + LowerBound: 1.0, + UpperBound: 4.0, } // Create a string for the CPU Architecture so that it can be passed as a pointer // when creating the Filter struct diff --git a/cmd/examples/example1.go b/cmd/examples/example1.go index 4fdb43b..5f2c05a 100644 --- a/cmd/examples/example1.go +++ b/cmd/examples/example1.go @@ -27,10 +27,10 @@ func main() { LowerBound: 2, UpperBound: 4, } - // Instantiate an int range filter to specify min and max memory in MiB - memoryRange := selector.IntRangeFilter{ - LowerBound: 1024, - UpperBound: 4096, + // Instantiate a float64 range filter to specify min and max memory in GiB + memoryRange := selector.Float64RangeFilter{ + LowerBound: 1.0, + UpperBound: 4.0, } // Create a string for the CPU Architecture so that it can be passed as a pointer // when creating the Filter struct diff --git a/cmd/main.go b/cmd/main.go index 77ead80..1c921f3 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -116,11 +116,11 @@ Full docs can be found at github.com/aws/amazon-` + binName // Filter Flags - These will be grouped at the top of the help flags cli.IntMinMaxRangeFlags(vcpus, cli.StringMe("c"), nil, "Number of vcpus available to the instance type.") - cli.IntMinMaxRangeFlags(memory, cli.StringMe("m"), nil, "Amount of Memory available in MiB (Example: 4096)") + cli.Float64MinMaxRangeFlags(memory, cli.StringMe("m"), nil, "Amount of Memory available in GiB (Example: 4)") cli.RatioFlag(vcpusToMemoryRatio, nil, nil, "The ratio of vcpus to memory in MiB. (Example: 1:2)") cli.StringFlag(cpuArchitecture, cli.StringMe("a"), nil, "CPU architecture [x86_64, i386, or arm64]", nil) cli.IntMinMaxRangeFlags(gpus, cli.StringMe("g"), nil, "Total Number of GPUs (Example: 4)") - cli.IntMinMaxRangeFlags(gpuMemoryTotal, nil, nil, "Number of GPUs' total memory in MiB (Example: 4096)") + cli.Float64MinMaxRangeFlags(gpuMemoryTotal, nil, nil, "Number of GPUs' total memory in GiB (Example: 4)") cli.StringFlag(placementGroupStrategy, nil, nil, "Placement group strategy: [cluster, partition, spread]", nil) cli.StringFlag(usageClass, cli.StringMe("u"), nil, "Usage class: [spot or on-demand]", nil) cli.StringFlag(rootDeviceType, nil, nil, "Supported root device types: [ebs or instance-store]", nil) @@ -189,11 +189,11 @@ Full docs can be found at github.com/aws/amazon-` + binName filters := selector.Filters{ VCpusRange: cli.IntRangeMe(flags[vcpus]), - MemoryRange: cli.IntRangeMe(flags[memory]), + MemoryRange: cli.Float64RangeMe(flags[memory]), VCpusToMemoryRatio: cli.Float64Me(flags[vcpusToMemoryRatio]), CPUArchitecture: cli.StringMe(flags[cpuArchitecture]), GpusRange: cli.IntRangeMe(flags[gpus]), - GpuMemoryRange: cli.IntRangeMe(flags[gpuMemoryTotal]), + GpuMemoryRange: cli.Float64RangeMe(flags[gpuMemoryTotal]), PlacementGroupStrategy: cli.StringMe(flags[placementGroupStrategy]), UsageClass: cli.StringMe(flags[usageClass]), RootDeviceType: cli.StringMe(flags[rootDeviceType]), diff --git a/pkg/cli/cli.go b/pkg/cli/cli.go index 8180730..4cf8731 100644 --- a/pkg/cli/cli.go +++ b/pkg/cli/cli.go @@ -37,12 +37,12 @@ func New(binaryName string, shortUsage string, longUsage, examples string, run r Run: run, } return CommandLineInterface{ - Command: cmd, - Flags: map[string]interface{}{}, - nilDefaults: map[string]bool{}, - intRangeFlags: map[string]bool{}, - validators: map[string]validator{}, - suiteFlags: pflag.NewFlagSet("suite", pflag.ExitOnError), + Command: cmd, + Flags: map[string]interface{}{}, + nilDefaults: map[string]bool{}, + rangeFlags: map[string]bool{}, + validators: map[string]validator{}, + suiteFlags: pflag.NewFlagSet("suite", pflag.ExitOnError), } } @@ -164,6 +164,10 @@ func (cl *CommandLineInterface) SetUntouchedFlagValuesToNil() error { if reflect.ValueOf(*v).IsZero() { cl.Flags[f.Name] = nil } + case *float64: + if reflect.ValueOf(*v).IsZero() { + cl.Flags[f.Name] = nil + } case *string: if reflect.ValueOf(*v).IsZero() { cl.Flags[f.Name] = nil @@ -188,29 +192,51 @@ func (cl *CommandLineInterface) SetUntouchedFlagValuesToNil() error { return nil } -// ProcessRangeFilterFlags sets min and max to the appropriate 0 or maxInt bounds based on the 3-tuple that a user specifies for base flag, min, and/or max +// ProcessRangeFilterFlags sets min and max to the appropriate 0 or max bounds based on the 3-tuple that a user specifies for base flag, min, and/or max func (cl *CommandLineInterface) ProcessRangeFilterFlags() error { - for flagName := range cl.intRangeFlags { + for flagName := range cl.rangeFlags { rangeHelperMin := fmt.Sprintf("%s-%s", flagName, "min") rangeHelperMax := fmt.Sprintf("%s-%s", flagName, "max") if cl.Flags[flagName] != nil { if cl.Flags[rangeHelperMin] != nil || cl.Flags[rangeHelperMax] != nil { return fmt.Errorf("error: --%s and --%s cannot be set when using --%s", rangeHelperMin, rangeHelperMax, flagName) } - cl.Flags[rangeHelperMin] = cl.IntMe(cl.Flags[flagName]) - cl.Flags[rangeHelperMax] = cl.IntMe(cl.Flags[flagName]) + cl.Flags[rangeHelperMin] = cl.Flags[flagName] + cl.Flags[rangeHelperMax] = cl.Flags[flagName] } if cl.Flags[rangeHelperMin] == nil && cl.Flags[rangeHelperMax] == nil { continue - } else if cl.Flags[rangeHelperMin] == nil { - cl.Flags[rangeHelperMin] = cl.IntMe(0) + } + + if cl.Flags[rangeHelperMin] == nil { + switch cl.Flags[rangeHelperMax].(type) { + case *int: + cl.Flags[rangeHelperMin] = cl.IntMe(0) + case *float64: + cl.Flags[rangeHelperMin] = cl.Float64Me(0) + } } else if cl.Flags[rangeHelperMax] == nil { - cl.Flags[rangeHelperMax] = cl.IntMe(maxInt) + switch cl.Flags[rangeHelperMin].(type) { + case *int: + cl.Flags[rangeHelperMax] = cl.IntMe(maxInt) + case *float64: + cl.Flags[rangeHelperMax] = cl.Float64Me(maxFloat64) + } } - cl.Flags[flagName] = &selector.IntRangeFilter{ - LowerBound: *cl.IntMe(cl.Flags[rangeHelperMin]), - UpperBound: *cl.IntMe(cl.Flags[rangeHelperMax]), + + switch cl.Flags[rangeHelperMin].(type) { + case *int: + cl.Flags[flagName] = &selector.IntRangeFilter{ + LowerBound: *cl.IntMe(cl.Flags[rangeHelperMin]), + UpperBound: *cl.IntMe(cl.Flags[rangeHelperMax]), + } + case *float64: + cl.Flags[flagName] = &selector.Float64RangeFilter{ + LowerBound: *cl.Float64Me(cl.Flags[rangeHelperMin]), + UpperBound: *cl.Float64Me(cl.Flags[rangeHelperMax]), + } } + } return nil } diff --git a/pkg/cli/flags.go b/pkg/cli/flags.go index ea9e178..1d43848 100644 --- a/pkg/cli/flags.go +++ b/pkg/cli/flags.go @@ -2,6 +2,7 @@ package cli import ( "fmt" + "math" "regexp" "strconv" "strings" @@ -10,7 +11,8 @@ import ( ) const ( - maxInt = int(^uint(0) >> 1) + maxFloat64 = math.MaxFloat64 + maxInt = int(^uint(0) >> 1) ) // RatioFlag creates and registers a flag accepting a Ratio @@ -51,6 +53,11 @@ func (cl *CommandLineInterface) IntMinMaxRangeFlags(name string, shorthand *stri cl.IntMinMaxRangeFlagOnFlagSet(cl.Command.Flags(), name, shorthand, defaultValue, description) } +// Float64MinMaxRangeFlags creates and registers a min, max, and helper flag each accepting a Float64 +func (cl *CommandLineInterface) Float64MinMaxRangeFlags(name string, shorthand *string, defaultValue *float64, description string) { + cl.Float64MinMaxRangeFlagOnFlagSet(cl.Command.Flags(), name, shorthand, defaultValue, description) +} + // IntFlag creates and registers a flag accepting an Integer func (cl *CommandLineInterface) IntFlag(name string, shorthand *string, defaultValue *int, description string) { cl.IntFlagOnFlagSet(cl.Command.Flags(), name, shorthand, defaultValue, description) @@ -167,7 +174,28 @@ func (cl *CommandLineInterface) IntMinMaxRangeFlagOnFlagSet(flagSet *pflag.FlagS } return nil } - cl.intRangeFlags[name] = true + cl.rangeFlags[name] = true +} + +// Float64MinMaxRangeFlagOnFlagSet creates and registers a min, max, and helper flag each accepting a Float64 +func (cl *CommandLineInterface) Float64MinMaxRangeFlagOnFlagSet(flagSet *pflag.FlagSet, name string, shorthand *string, defaultValue *float64, description string) { + cl.Float64FlagOnFlagSet(flagSet, name, shorthand, defaultValue, fmt.Sprintf("%s (sets --%s-min and -max to the same value)", description, name)) + cl.Float64FlagOnFlagSet(flagSet, name+"-min", nil, nil, fmt.Sprintf("Minimum %s If --%s-max is not specified, the upper bound will be infinity", description, name)) + cl.Float64FlagOnFlagSet(flagSet, name+"-max", nil, nil, fmt.Sprintf("Maximum %s If --%s-min is not specified, the lower bound will be 0", description, name)) + cl.validators[name] = func(val interface{}) error { + if cl.Flags[name+"-min"] == nil || cl.Flags[name+"-max"] == nil { + return nil + } + minArg := name + "-min" + maxArg := name + "-max" + minVal := cl.Flags[minArg].(*float64) + maxVal := cl.Flags[maxArg].(*float64) + if *minVal > *maxVal { + return fmt.Errorf("Invalid input for --%s and --%s. %s must be less than or equal to %s", minArg, maxArg, minArg, maxArg) + } + return nil + } + cl.rangeFlags[name] = true } // IntFlagOnFlagSet creates and registers a flag accepting an Integer @@ -183,6 +211,19 @@ func (cl *CommandLineInterface) IntFlagOnFlagSet(flagSet *pflag.FlagSet, name st cl.Flags[name] = flagSet.Int(name, *defaultValue, description) } +// Float64FlagOnFlagSet creates and registers a flag accepting a Float64 +func (cl *CommandLineInterface) Float64FlagOnFlagSet(flagSet *pflag.FlagSet, name string, shorthand *string, defaultValue *float64, description string) { + if defaultValue == nil { + cl.nilDefaults[name] = true + defaultValue = cl.Float64Me(0.0) + } + if shorthand != nil { + cl.Flags[name] = flagSet.Float64P(name, string(*shorthand), *defaultValue, description) + return + } + cl.Flags[name] = flagSet.Float64(name, *defaultValue, description) +} + // StringFlagOnFlagSet creates and registers a flag accepting a String and a validator function. // The validator function is provided so that more complex flags can be created from a string input. func (cl *CommandLineInterface) StringFlagOnFlagSet(flagSet *pflag.FlagSet, name string, shorthand *string, defaultValue *string, description string, validationFn validator) { diff --git a/pkg/cli/flags_test.go b/pkg/cli/flags_test.go index b243fdc..1d50ce5 100644 --- a/pkg/cli/flags_test.go +++ b/pkg/cli/flags_test.go @@ -151,6 +151,24 @@ func TestIntMinMaxRangeFlags(t *testing.T) { h.Assert(t, ok, "Should contain %s flag w/ no shorthand", flagName) } +func TestFloat64MinMaxRangeFlags(t *testing.T) { + cli := getTestCLI() + flagName := "test-float-min-max-range" + cli.Float64MinMaxRangeFlags(flagName, cli.StringMe("t"), nil, "Test Min Max Range") + _, ok := cli.Flags[flagName] + _, minOk := cli.Flags[flagName+"-min"] + _, maxOk := cli.Flags[flagName+"-max"] + h.Assert(t, len(cli.Flags) == 3, "Should contain 3 flags") + h.Assert(t, ok, "Should contain %s flag", flagName) + h.Assert(t, minOk, "Should contain %s flag", flagName) + h.Assert(t, maxOk, "Should contain %s flag", flagName) + + cli = getTestCLI() + cli.Float64MinMaxRangeFlags(flagName, nil, nil, "Test Min Max Range") + h.Assert(t, len(cli.Flags) == 3, "Should contain 3 flags w/ no shorthand") + h.Assert(t, ok, "Should contain %s flag w/ no shorthand", flagName) +} + func TestRegexFlag(t *testing.T) { cli := getTestCLI() for _, flagFn := range []func(string, *string, *string, string){cli.RegexFlag} { diff --git a/pkg/cli/types.go b/pkg/cli/types.go index e08c089..ed2cd91 100644 --- a/pkg/cli/types.go +++ b/pkg/cli/types.go @@ -52,12 +52,12 @@ type validator = func(val interface{}) error // CommandLineInterface is a type to group CLI funcs and state type CommandLineInterface struct { - Command *cobra.Command - Flags map[string]interface{} - nilDefaults map[string]bool - intRangeFlags map[string]bool - validators map[string]validator - suiteFlags *pflag.FlagSet + Command *cobra.Command + Flags map[string]interface{} + nilDefaults map[string]bool + rangeFlags map[string]bool + validators map[string]validator + suiteFlags *pflag.FlagSet } // Float64Me takes an interface and returns a pointer to a float64 value @@ -117,6 +117,23 @@ func (*CommandLineInterface) IntRangeMe(i interface{}) *selector.IntRangeFilter } } +// Float64RangeMe takes an interface and returns a pointer to a Float64RangeFilter value +// If the underlying interface kind is not Float64RangeFilter or *Float64RangeFilter then nil is returned +func (*CommandLineInterface) Float64RangeMe(i interface{}) *selector.Float64RangeFilter { + if i == nil { + return nil + } + switch v := i.(type) { + case *selector.Float64RangeFilter: + return v + case selector.Float64RangeFilter: + return &v + default: + log.Printf("%s cannot be converted to a Float64Range", i) + return nil + } +} + // StringMe takes an interface and returns a pointer to a string value // If the underlying interface kind is not string or *string then nil is returned func (*CommandLineInterface) StringMe(i interface{}) *string { diff --git a/pkg/cli/types_test.go b/pkg/cli/types_test.go index 6d6c21e..d6b52c5 100644 --- a/pkg/cli/types_test.go +++ b/pkg/cli/types_test.go @@ -107,6 +107,18 @@ func TestIntRangeMe(t *testing.T) { h.Assert(t, val == nil, "Should return nil if nil is passed in") } +func TestFloat64RangeMe(t *testing.T) { + cli := getTestCLI() + float64RangeVal := selector.Float64RangeFilter{LowerBound: 1.0, UpperBound: 2.0} + val := cli.Float64RangeMe(float64RangeVal) + h.Assert(t, *val == float64RangeVal, "Should return %s from passed in float64 range value", float64RangeVal) + val = cli.Float64RangeMe(&float64RangeVal) + h.Assert(t, *val == float64RangeVal, "Should return %s from passed in range pointer", float64RangeVal) + val = cli.Float64RangeMe(true) + h.Assert(t, val == nil, "Should return nil from other data type passed in") + val = cli.Float64RangeMe(nil) + h.Assert(t, val == nil, "Should return nil if nil is passed in") +} func TestRegexMe(t *testing.T) { cli := getTestCLI() regexVal, err := regexp.Compile("c4.*") diff --git a/pkg/selector/aggregates.go b/pkg/selector/aggregates.go index ff0e3d1..4bca5d2 100644 --- a/pkg/selector/aggregates.go +++ b/pkg/selector/aggregates.go @@ -62,9 +62,9 @@ func (itf Selector) TransformBaseInstanceType(filters Filters) (Filters, error) filters.GpusRange = &IntRangeFilter{LowerBound: gpuCount, UpperBound: gpuCount} } if filters.MemoryRange == nil { - lowerBound := int(float64(*instanceTypeInfo.MemoryInfo.SizeInMiB) * AggregateLowPercentile) - upperBound := int(float64(*instanceTypeInfo.MemoryInfo.SizeInMiB) * AggregateHighPercentile) - filters.MemoryRange = &IntRangeFilter{LowerBound: lowerBound, UpperBound: upperBound} + lowerBound := (float64(*instanceTypeInfo.MemoryInfo.SizeInMiB) * AggregateLowPercentile) / 1024 + upperBound := (float64(*instanceTypeInfo.MemoryInfo.SizeInMiB) * AggregateHighPercentile) / 1024 + filters.MemoryRange = &Float64RangeFilter{LowerBound: lowerBound, UpperBound: upperBound} } if filters.VCpusRange == nil { lowerBound := int(float64(*instanceTypeInfo.VCpuInfo.DefaultVCpus) * AggregateLowPercentile) diff --git a/pkg/selector/outputs/outputs.go b/pkg/selector/outputs/outputs.go index 8a2a4b8..30ecf95 100644 --- a/pkg/selector/outputs/outputs.go +++ b/pkg/selector/outputs/outputs.go @@ -164,7 +164,7 @@ func TableOutputShort(instanceTypeInfoSlice []*ec2.InstanceTypeInfo) []string { headers := []interface{}{ "Instance Type", "VCPUs", - "Mem (MiB)", + "Mem (GiB)", } separators := []interface{}{} @@ -177,10 +177,10 @@ func TableOutputShort(instanceTypeInfoSlice []*ec2.InstanceTypeInfo) []string { fmt.Fprintf(w, "\n"+headerFormat, separators...) for _, instanceTypeInfo := range instanceTypeInfoSlice { - fmt.Fprintf(w, "\n%s\t%d\t%d\t", + fmt.Fprintf(w, "\n%s\t%d\t%.3f\t", *instanceTypeInfo.InstanceType, *instanceTypeInfo.VCpuInfo.DefaultVCpus, - *instanceTypeInfo.MemoryInfo.SizeInMiB, + float64(*instanceTypeInfo.MemoryInfo.SizeInMiB)/1024.0, ) } w.Flush() @@ -242,10 +242,10 @@ func TableOutputWide(instanceTypeInfoSlice []*ec2.InstanceTypeInfo) []string { } } - fmt.Fprintf(w, "\n%s\t%d\t%d\t%s\t%t\t%t\t%s\t%s\t%d\t%d\t%d\t%s\t", + fmt.Fprintf(w, "\n%s\t%d\t%.3f\t%s\t%t\t%t\t%s\t%s\t%d\t%d\t%.2f\t%s\t", *instanceTypeInfo.InstanceType, *instanceTypeInfo.VCpuInfo.DefaultVCpus, - *instanceTypeInfo.MemoryInfo.SizeInMiB, + float64(*instanceTypeInfo.MemoryInfo.SizeInMiB)/1024.0, *hypervisor, *instanceTypeInfo.CurrentGeneration, *instanceTypeInfo.HibernationSupported, @@ -253,7 +253,7 @@ func TableOutputWide(instanceTypeInfoSlice []*ec2.InstanceTypeInfo) []string { *instanceTypeInfo.NetworkInfo.NetworkPerformance, *instanceTypeInfo.NetworkInfo.MaximumNetworkInterfaces, gpus, - gpuMemory, + float64(gpuMemory)/1024.0, strings.Join(gpuType, ", "), ) } diff --git a/pkg/selector/outputs/outputs_test.go b/pkg/selector/outputs/outputs_test.go index e04d0d9..683aa05 100644 --- a/pkg/selector/outputs/outputs_test.go +++ b/pkg/selector/outputs/outputs_test.go @@ -117,3 +117,15 @@ func TestTableOutputWide(t *testing.T) { h.Assert(t, strings.Contains(outputStr, "Moderate"), "wide table should include network performance") h.Assert(t, strings.Contains(outputStr, "NVIDIA K520"), "wide table should include GPU Info") } + +func TestTableOutput_MBtoGB(t *testing.T) { + instanceTypes := getInstanceTypes(t, "g2_2xlarge.json") + instanceTypeOut := outputs.TableOutputWide(instanceTypes) + outputStr := strings.Join(instanceTypeOut, "") + h.Assert(t, strings.Contains(outputStr, "15.000"), "table should include 15.000 GB of memory") + h.Assert(t, strings.Contains(outputStr, "4.00"), "wide table should include 4.00 GB of gpu memory") + + instanceTypeOut = outputs.TableOutputShort(instanceTypes) + outputStr = strings.Join(instanceTypeOut, "") + h.Assert(t, strings.Contains(outputStr, "15.000"), "table should include 15.000 GB of memory") +} diff --git a/pkg/selector/selector.go b/pkg/selector/selector.go index eabfa6c..3b16731 100644 --- a/pkg/selector/selector.go +++ b/pkg/selector/selector.go @@ -16,6 +16,7 @@ package selector import ( "fmt" + "math" "reflect" "regexp" "sort" @@ -178,8 +179,8 @@ func (itf Selector) rawFilter(filters Filters) ([]*ec2.InstanceTypeInfo, error) rootDeviceType: {filters.RootDeviceType, instanceTypeInfo.SupportedRootDeviceTypes}, hibernationSupported: {filters.HibernationSupported, instanceTypeInfo.HibernationSupported}, vcpusRange: {filters.VCpusRange, instanceTypeInfo.VCpuInfo.DefaultVCpus}, - memoryRange: {filters.MemoryRange, instanceTypeInfo.MemoryInfo.SizeInMiB}, - gpuMemoryRange: {filters.GpuMemoryRange, getTotalGpuMemory(instanceTypeInfo.GpuInfo)}, + memoryRange: {gibToMibRange(filters.MemoryRange), instanceTypeInfo.MemoryInfo.SizeInMiB}, + gpuMemoryRange: {gibToMibRange(filters.GpuMemoryRange), getTotalGpuMemory(instanceTypeInfo.GpuInfo)}, gpusRange: {filters.GpusRange, getTotalGpusCount(instanceTypeInfo.GpuInfo)}, placementGroupStrategy: {filters.PlacementGroupStrategy, instanceTypeInfo.PlacementGroupInfo.SupportedStrategies}, hypervisor: {filters.Hypervisor, instanceTypeInfo.Hypervisor}, @@ -391,3 +392,18 @@ func isInAllowList(allowRegex *regexp.Regexp, instanceTypeName string) bool { } return allowRegex.MatchString(instanceTypeName) } + +func gibToMibRange(gbRange *Float64RangeFilter) *IntRangeFilter { + if gbRange == nil { + return nil + } + mbRangeFilter := IntRangeFilter{ + LowerBound: int(gbRange.LowerBound * 1024), + } + if gbRange.UpperBound == math.MaxFloat64 { + mbRangeFilter.UpperBound = math.MaxInt32 + return &mbRangeFilter + } + mbRangeFilter.UpperBound = int(gbRange.UpperBound * 1024) + return &mbRangeFilter +} diff --git a/pkg/selector/selector_test.go b/pkg/selector/selector_test.go index 6344b41..e768db5 100644 --- a/pkg/selector/selector_test.go +++ b/pkg/selector/selector_test.go @@ -243,7 +243,7 @@ func TestFilterVerbose_Gpus(t *testing.T) { } filters := selector.Filters{ GpusRange: &selector.IntRangeFilter{LowerBound: 8, UpperBound: 8}, - GpuMemoryRange: &selector.IntRangeFilter{LowerBound: 131072, UpperBound: 131072}, + GpuMemoryRange: &selector.Float64RangeFilter{LowerBound: 128.0, UpperBound: 128.0}, } results, err := itf.FilterVerbose(filters) h.Ok(t, err) diff --git a/pkg/selector/types.go b/pkg/selector/types.go index ecd0bd1..2de1dfb 100644 --- a/pkg/selector/types.go +++ b/pkg/selector/types.go @@ -47,6 +47,13 @@ type IntRangeFilter struct { LowerBound int } +// Float64RangeFilter holds an upper and lower bound float64 +// The lower and upper bound are used to range filter resource specs +type Float64RangeFilter struct { + UpperBound float64 + LowerBound float64 +} + // filterPair holds a tuple of the passed in filter value and the instance resource spec value type filterPair struct { filterValue interface{} @@ -111,8 +118,8 @@ type Filters struct { // GpusRange filter is a range of acceptable GPU count available to an EC2 instance type GpusRange *IntRangeFilter - // GpuMemoryRange filter is a range of acceptable GPU memory available to an EC2 instance type in aggreagte across all GPUs. - GpuMemoryRange *IntRangeFilter + // GpuMemoryRange filter is a range of acceptable GPU memory in Gibibytes (GiB) available to an EC2 instance type in aggreagte across all GPUs. + GpuMemoryRange *Float64RangeFilter // HibernationSupported denotes whether EC2 hibernate is supported // Possible values are: true or false @@ -125,8 +132,8 @@ type Filters struct { // MaxResults is the maximum number of instance types to return that match the filter criteria MaxResults *int - // MemoryRange filter is a range of acceptable DRAM memory in Mebibytes (MiB) for the instance type - MemoryRange *IntRangeFilter + // MemoryRange filter is a range of acceptable DRAM memory in Gibibytes (GiB) for the instance type + MemoryRange *Float64RangeFilter // NetworkInterfaces filter is a range of the number of ENI attachments an instance type can support NetworkInterfaces *IntRangeFilter From 0e05ac204864032b0570f71019352eb1e2c3b4a4 Mon Sep 17 00:00:00 2001 From: Brandon Wagner Date: Thu, 23 Jul 2020 17:24:30 -0500 Subject: [PATCH 2/4] implement byte quantity pkg --- pkg/bytequantity/bytequantity.go | 152 ++++++++++++++++++++++++++ pkg/bytequantity/bytequantity_test.go | 118 ++++++++++++++++++++ pkg/selector/selector_test.go | 2 - 3 files changed, 270 insertions(+), 2 deletions(-) create mode 100644 pkg/bytequantity/bytequantity.go create mode 100644 pkg/bytequantity/bytequantity_test.go diff --git a/pkg/bytequantity/bytequantity.go b/pkg/bytequantity/bytequantity.go new file mode 100644 index 0000000..6579dde --- /dev/null +++ b/pkg/bytequantity/bytequantity.go @@ -0,0 +1,152 @@ +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package bytequantity + +import ( + "fmt" + "math" + "regexp" + "strconv" + "strings" +) + +const ( + /// Examples: 1mb, 1 gb, 1.0tb, 1mib, 2g, 2.001 t + byteQuantityRegex = `^([0-9]+\.?[0-9]{0,3})[ ]?(mi?b?|gi?b?|ti?b?)?$` + mib = "MiB" + gib = "GiB" + tib = "TiB" + gbConvert = 1 << 10 + tbConvert = gbConvert << 10 + maxGiB = math.MaxUint64 / gbConvert + maxTiB = math.MaxUint64 / tbConvert +) + +// ByteQuantity is a data type representing a byte quantity +type ByteQuantity struct { + Quantity uint64 +} + +// ParseToByteQuantity parses a string representation of a byte quantity to a ByteQuantity type +func ParseToByteQuantity(byteQuantityStr string) (ByteQuantity, error) { + bqRegexp := regexp.MustCompile(byteQuantityRegex) + matches := bqRegexp.FindStringSubmatch(strings.ToLower(byteQuantityStr)) + if len(matches) < 2 { + return ByteQuantity{}, fmt.Errorf("%s is not a valid byte quantity", byteQuantityStr) + } + + quantityStr := matches[1] + unit := gib + if len(matches) > 2 && matches[2] != "" { + unit = matches[2] + } + quantity := uint64(0) + switch strings.ToLower(string(unit[0])) { + //mib + case "m": + inputDecSplit := strings.Split(quantityStr, ".") + if len(inputDecSplit) == 2 { + d, err := strconv.Atoi(inputDecSplit[1]) + if err != nil { + return ByteQuantity{}, err + } + if d != 0 { + return ByteQuantity{}, fmt.Errorf("cannot accept floating point MB value, only integers are accepted") + } + } + // need error here so that this quantity doesn't bind in the local scope + var err error + quantity, err = strconv.ParseUint(inputDecSplit[0], 10, 64) + if err != nil { + return ByteQuantity{}, err + } + //gib + case "g": + quantityDec, err := strconv.ParseFloat(quantityStr, 10) + if err != nil { + return ByteQuantity{}, err + } + if quantityDec > maxGiB { + return ByteQuantity{}, fmt.Errorf("error GiB value is too large") + } + quantity = uint64(quantityDec * gbConvert) + //tib + case "t": + quantityDec, err := strconv.ParseFloat(quantityStr, 10) + if err != nil { + return ByteQuantity{}, err + } + if quantityDec > maxTiB { + return ByteQuantity{}, fmt.Errorf("error TiB value is too large") + } + quantity = uint64(quantityDec * tbConvert) + default: + return ByteQuantity{}, fmt.Errorf("error unit %s is not supported", unit) + } + + return ByteQuantity{ + Quantity: quantity, + }, nil +} + +// FromTiB returns a byte quantity of the passed in tebibytes quantity +func FromTiB(tib uint64) ByteQuantity { + return ByteQuantity{ + Quantity: tib * tbConvert, + } +} + +// FromGiB returns a byte quantity of the passed in gibibytes quantity +func FromGiB(gib uint64) ByteQuantity { + return ByteQuantity{ + Quantity: gib * gbConvert, + } +} + +// FromMiB returns a byte quantity of the passed in mebibytes quantity +func FromMiB(mib uint64) ByteQuantity { + return ByteQuantity{ + Quantity: mib, + } +} + +// StringMiB returns a byte quantity in a mebibytes string representation +func (bq ByteQuantity) StringMiB() string { + return fmt.Sprintf("%.0f %s", bq.MiB(), mib) +} + +// StringGiB returns a byte quantity in a gibibytes string representation +func (bq ByteQuantity) StringGiB() string { + return fmt.Sprintf("%.3f %s", bq.GiB(), gib) +} + +// StringTiB returns a byte quantity in a tebibytes string representation +func (bq ByteQuantity) StringTiB() string { + return fmt.Sprintf("%.3f %s", bq.TiB(), tib) +} + +// MiB returns a byte quantity in mebibytes +func (bq ByteQuantity) MiB() float64 { + return float64(bq.Quantity) +} + +// GiB returns a byte quantity in gibibytes +func (bq ByteQuantity) GiB() float64 { + return float64(bq.Quantity) * 1 / gbConvert +} + +// TiB returns a byte quantity in tebibytes +func (bq ByteQuantity) TiB() float64 { + return float64(bq.Quantity) * 1 / tbConvert +} diff --git a/pkg/bytequantity/bytequantity_test.go b/pkg/bytequantity/bytequantity_test.go new file mode 100644 index 0000000..43aedf8 --- /dev/null +++ b/pkg/bytequantity/bytequantity_test.go @@ -0,0 +1,118 @@ +package bytequantity_test + +import ( + "fmt" + "testing" + + "github.com/aws/amazon-ec2-instance-selector/pkg/bytequantity" + h "github.com/aws/amazon-ec2-instance-selector/pkg/test" +) + +func TestParseToByteQuantity(t *testing.T) { + + for _, testQuantity := range []string{"10mb", "10 mb", "10.0 mb", "10.0mb", "10m", "10mib", "10 M", "10.000 MiB"} { + expectationVal := uint64(10) + bq, err := bytequantity.ParseToByteQuantity(testQuantity) + h.Ok(t, err) + h.Assert(t, bq.Quantity == expectationVal, "quantity should have been %d, got %d instead on string %s", expectationVal, bq.Quantity, testQuantity) + } + + for _, testQuantity := range []string{"4", "4.0", "4gb", "4 gb", "4.0 gb", "4.0gb", "4g", "4gib", "4 G", "4.000 GiB"} { + expectationVal := uint64(4096) + bq, err := bytequantity.ParseToByteQuantity(testQuantity) + h.Ok(t, err) + h.Assert(t, bq.Quantity == expectationVal, "quantity should have been %d, got %d instead on string %s", expectationVal, bq.Quantity, testQuantity) + } + + for _, testQuantity := range []string{"109tb", "109 tb", "109.0 tb", "109.0tb", "109t", "109tib", "109 T", "109.000 TiB"} { + expectationVal := uint64(114294784) + bq, err := bytequantity.ParseToByteQuantity(testQuantity) + h.Ok(t, err) + h.Assert(t, bq.Quantity == expectationVal, "quantity should have been %d, got %d instead on string %s", expectationVal, bq.Quantity, testQuantity) + } + + expectationVal := uint64(1025) + testQuantity := "1.001 gb" + bq, err := bytequantity.ParseToByteQuantity(testQuantity) + h.Ok(t, err) + h.Assert(t, bq.Quantity == expectationVal, "quantity should have been %d, got %d instead on string %s", expectationVal, bq.Quantity, testQuantity) + + // Only supports 3 decimal places + bq, err = bytequantity.ParseToByteQuantity("109.0001") + h.Nok(t, err) + + // Only support decimals on GiB and TiB + bq, err = bytequantity.ParseToByteQuantity("109.001 mib") + h.Nok(t, err) + + // Overflow a uint64 + overflow := "18446744073709551616" + bq, err = bytequantity.ParseToByteQuantity(fmt.Sprintf("%s mib", overflow)) + h.Nok(t, err) + + bq, err = bytequantity.ParseToByteQuantity(fmt.Sprintf("%s gib", overflow)) + h.Nok(t, err) + + bq, err = bytequantity.ParseToByteQuantity(fmt.Sprintf("%s tib", overflow)) + h.Nok(t, err) + + // Unit not supported + bq, err = bytequantity.ParseToByteQuantity("1 NS") + h.Nok(t, err) +} + +func TestStringGiB(t *testing.T) { + expectedVal := "0.098 GiB" + testVal := uint64(100) + bq := bytequantity.ByteQuantity{Quantity: testVal} + h.Assert(t, bq.StringGiB() == expectedVal, "%d MiB should equal %s, instead got %s", testVal, expectedVal, bq.StringGiB()) + + expectedVal = "1.000 GiB" + testVal = uint64(1024) + bq = bytequantity.ByteQuantity{Quantity: 1024} + h.Assert(t, bq.StringGiB() == expectedVal, "%d MiB should equal %s, instead got %s", testVal, expectedVal, bq.StringGiB()) +} + +func TestStringTiB(t *testing.T) { + expectedVal := "1.000 TiB" + testVal := uint64(1048576) + bq := bytequantity.ByteQuantity{Quantity: testVal} + h.Assert(t, bq.StringTiB() == expectedVal, "%d MiB should equal %s, instead got %s", testVal, expectedVal, bq.StringTiB()) + + expectedVal = "0.005 TiB" + testVal = uint64(5240) + bq = bytequantity.ByteQuantity{Quantity: testVal} + h.Assert(t, bq.StringTiB() == expectedVal, "%d MiB should equal %s, instead got %s", testVal, expectedVal, bq.StringTiB()) +} + +func TestStringMiB(t *testing.T) { + expectedVal := "1 MiB" + testVal := uint64(1) + bq := bytequantity.ByteQuantity{Quantity: testVal} + h.Assert(t, bq.StringMiB() == expectedVal, "%d MiB should equal %s, instead got %s", testVal, expectedVal, bq.StringMiB()) + + expectedVal = "2 MiB" + testVal = uint64(2) + bq = bytequantity.ByteQuantity{Quantity: testVal} + h.Assert(t, bq.StringMiB() == expectedVal, "%d MiB should equal %s, instead got %s", testVal, expectedVal, bq.StringMiB()) +} + +func TestFromMiB(t *testing.T) { + expectedVal := uint64(1) + bq := bytequantity.FromMiB(expectedVal) + h.Assert(t, bq.MiB() == float64(expectedVal), "%d MiB should equal %d, instead got %s", expectedVal, expectedVal, bq.StringMiB()) +} + +func TestFromGiB(t *testing.T) { + expectedVal := float64(1.0) + testVal := uint64(1) + bq := bytequantity.FromGiB(testVal) + h.Assert(t, bq.GiB() == expectedVal, "%d GiB should equal %d, instead got %s", expectedVal, expectedVal, bq.StringGiB()) +} + +func TestFromTiB(t *testing.T) { + expectedVal := float64(1.0) + testVal := uint64(1) + bq := bytequantity.FromTiB(testVal) + h.Assert(t, bq.TiB() == expectedVal, "%d TiB should equal %d, instead got %s", expectedVal, expectedVal, bq.StringTiB()) +} diff --git a/pkg/selector/selector_test.go b/pkg/selector/selector_test.go index e768db5..5c3a493 100644 --- a/pkg/selector/selector_test.go +++ b/pkg/selector/selector_test.go @@ -18,7 +18,6 @@ import ( "errors" "fmt" "io/ioutil" - "log" "regexp" "strconv" "testing" @@ -279,7 +278,6 @@ func TestFilter_MoreFilters(t *testing.T) { } results, err := itf.Filter(filters) h.Ok(t, err) - log.Println(results) h.Assert(t, len(results) == 1, "Should only return 1 instance type with 2 vcpus") h.Assert(t, results[0] == "t3.micro", "Should return t3.micro, got %s instead", results[0]) } From fb40900c6f1e975a837f180cc127cd9a142ab0ed Mon Sep 17 00:00:00 2001 From: Brandon Wagner Date: Thu, 23 Jul 2020 17:25:17 -0500 Subject: [PATCH 3/4] update cli and selector pkg to byte quantity for memory filters --- cmd/examples/example1.go | 9 +- cmd/main.go | 8 +- pkg/cli/cli.go | 62 +++++--- pkg/cli/cli_test.go | 163 +++++++++++++++++++--- pkg/cli/flags.go | 137 +++++++++++------- pkg/cli/flags_test.go | 25 +++- pkg/cli/types.go | 34 ++++- pkg/cli/types_test.go | 18 +-- pkg/selector/aggregates.go | 7 +- pkg/selector/comparators.go | 13 +- pkg/selector/comparators_internal_test.go | 52 ++++--- pkg/selector/selector.go | 43 +++--- pkg/selector/selector_test.go | 11 +- pkg/selector/types.go | 20 ++- 14 files changed, 428 insertions(+), 174 deletions(-) diff --git a/cmd/examples/example1.go b/cmd/examples/example1.go index 5f2c05a..88106f0 100644 --- a/cmd/examples/example1.go +++ b/cmd/examples/example1.go @@ -3,6 +3,7 @@ package main import ( "fmt" + "github.com/aws/amazon-ec2-instance-selector/pkg/bytequantity" "github.com/aws/amazon-ec2-instance-selector/pkg/selector" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/session" @@ -27,10 +28,10 @@ func main() { LowerBound: 2, UpperBound: 4, } - // Instantiate a float64 range filter to specify min and max memory in GiB - memoryRange := selector.Float64RangeFilter{ - LowerBound: 1.0, - UpperBound: 4.0, + // Instantiate a byte quantity range filter to specify min and max memory in GiB + memoryRange := selector.ByteQuantityRangeFilter{ + LowerBound: bytequantity.FromGiB(2), + UpperBound: bytequantity.FromGiB(4), } // Create a string for the CPU Architecture so that it can be passed as a pointer // when creating the Filter struct diff --git a/cmd/main.go b/cmd/main.go index 1c921f3..d93310c 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -116,11 +116,11 @@ Full docs can be found at github.com/aws/amazon-` + binName // Filter Flags - These will be grouped at the top of the help flags cli.IntMinMaxRangeFlags(vcpus, cli.StringMe("c"), nil, "Number of vcpus available to the instance type.") - cli.Float64MinMaxRangeFlags(memory, cli.StringMe("m"), nil, "Amount of Memory available in GiB (Example: 4)") + cli.ByteQuantityMinMaxRangeFlags(memory, cli.StringMe("m"), nil, "Amount of Memory available in GiB (Example: 4)") cli.RatioFlag(vcpusToMemoryRatio, nil, nil, "The ratio of vcpus to memory in MiB. (Example: 1:2)") cli.StringFlag(cpuArchitecture, cli.StringMe("a"), nil, "CPU architecture [x86_64, i386, or arm64]", nil) cli.IntMinMaxRangeFlags(gpus, cli.StringMe("g"), nil, "Total Number of GPUs (Example: 4)") - cli.Float64MinMaxRangeFlags(gpuMemoryTotal, nil, nil, "Number of GPUs' total memory in GiB (Example: 4)") + cli.ByteQuantityMinMaxRangeFlags(gpuMemoryTotal, nil, nil, "Number of GPUs' total memory in GiB (Example: 4)") cli.StringFlag(placementGroupStrategy, nil, nil, "Placement group strategy: [cluster, partition, spread]", nil) cli.StringFlag(usageClass, cli.StringMe("u"), nil, "Usage class: [spot or on-demand]", nil) cli.StringFlag(rootDeviceType, nil, nil, "Supported root device types: [ebs or instance-store]", nil) @@ -189,11 +189,11 @@ Full docs can be found at github.com/aws/amazon-` + binName filters := selector.Filters{ VCpusRange: cli.IntRangeMe(flags[vcpus]), - MemoryRange: cli.Float64RangeMe(flags[memory]), + MemoryRange: cli.ByteQuantityRangeMe(flags[memory]), VCpusToMemoryRatio: cli.Float64Me(flags[vcpusToMemoryRatio]), CPUArchitecture: cli.StringMe(flags[cpuArchitecture]), GpusRange: cli.IntRangeMe(flags[gpus]), - GpuMemoryRange: cli.Float64RangeMe(flags[gpuMemoryTotal]), + GpuMemoryRange: cli.ByteQuantityRangeMe(flags[gpuMemoryTotal]), PlacementGroupStrategy: cli.StringMe(flags[placementGroupStrategy]), UsageClass: cli.StringMe(flags[usageClass]), RootDeviceType: cli.StringMe(flags[rootDeviceType]), diff --git a/pkg/cli/cli.go b/pkg/cli/cli.go index 4cf8731..d078545 100644 --- a/pkg/cli/cli.go +++ b/pkg/cli/cli.go @@ -20,6 +20,7 @@ import ( "reflect" "strings" + "github.com/aws/amazon-ec2-instance-selector/pkg/bytequantity" "github.com/aws/amazon-ec2-instance-selector/pkg/selector" "github.com/spf13/cobra" "github.com/spf13/pflag" @@ -42,6 +43,7 @@ func New(binaryName string, shortUsage string, longUsage, examples string, run r nilDefaults: map[string]bool{}, rangeFlags: map[string]bool{}, validators: map[string]validator{}, + processors: map[string]processor{}, suiteFlags: pflag.NewFlagSet("suite", pflag.ExitOnError), } } @@ -52,24 +54,23 @@ func (cl *CommandLineInterface) ParseFlags() (map[string]interface{}, error) { // Remove Suite Flags so that args only include Config and Filter Flags cl.Command.SetArgs(removeIntersectingArgs(cl.suiteFlags)) // This parses Config and Filter flags only - err := cl.Command.Execute() - if err != nil { + if err := cl.Command.Execute(); err != nil { return nil, err } + // Remove Config and Filter flags so that only suite flags are parsed - err = cl.suiteFlags.Parse(removeIntersectingArgs(cl.Command.Flags())) - if err != nil { + if err := cl.suiteFlags.Parse(removeIntersectingArgs(cl.Command.Flags())); err != nil { return nil, err } + // Add suite flags to Command flagset so that other processing can occur // This has to be done after usage is printed so that the flagsets can be grouped properly when printed cl.Command.Flags().AddFlagSet(cl.suiteFlags) - err = cl.SetUntouchedFlagValuesToNil() - if err != nil { + if err := cl.SetUntouchedFlagValuesToNil(); err != nil { return nil, err } - err = cl.ProcessRangeFilterFlags() - if err != nil { + + if err := cl.ProcessFlags(); err != nil { return nil, err } return cl.Flags, nil @@ -82,13 +83,29 @@ func (cl *CommandLineInterface) ParseAndValidateFlags() (map[string]interface{}, if err != nil { return nil, err } - err = cl.ValidateFlags() - if err != nil { + if err := cl.ValidateFlags(); err != nil { return nil, err } return flags, nil } +// ProcessFlags iterates through any registered processors and executes them +// Processors are executed before validators +func (cl *CommandLineInterface) ProcessFlags() error { + for flagName, processorFn := range cl.processors { + if processorFn == nil { + continue + } + if err := processorFn(cl.Flags[flagName]); err != nil { + return err + } + } + if err := cl.ProcessRangeFilterFlags(); err != nil { + return err + } + return nil +} + // ValidateFlags iterates through any registered validators and executes them func (cl *CommandLineInterface) ValidateFlags() error { for flagName, validationFn := range cl.validators { @@ -164,8 +181,8 @@ func (cl *CommandLineInterface) SetUntouchedFlagValuesToNil() error { if reflect.ValueOf(*v).IsZero() { cl.Flags[f.Name] = nil } - case *float64: - if reflect.ValueOf(*v).IsZero() { + case *bytequantity.ByteQuantity: + if v.Quantity == 0 { cl.Flags[f.Name] = nil } case *string: @@ -212,15 +229,19 @@ func (cl *CommandLineInterface) ProcessRangeFilterFlags() error { switch cl.Flags[rangeHelperMax].(type) { case *int: cl.Flags[rangeHelperMin] = cl.IntMe(0) - case *float64: - cl.Flags[rangeHelperMin] = cl.Float64Me(0) + case *bytequantity.ByteQuantity: + cl.Flags[rangeHelperMin] = cl.ByteQuantityMe(bytequantity.ByteQuantity{Quantity: 0}) + default: + return fmt.Errorf("Unable to set %s", rangeHelperMax) } } else if cl.Flags[rangeHelperMax] == nil { switch cl.Flags[rangeHelperMin].(type) { case *int: cl.Flags[rangeHelperMax] = cl.IntMe(maxInt) - case *float64: - cl.Flags[rangeHelperMax] = cl.Float64Me(maxFloat64) + case *bytequantity.ByteQuantity: + cl.Flags[rangeHelperMax] = cl.ByteQuantityMe(bytequantity.ByteQuantity{Quantity: maxUint64}) + default: + return fmt.Errorf("Unable to set %s", rangeHelperMin) } } @@ -230,13 +251,12 @@ func (cl *CommandLineInterface) ProcessRangeFilterFlags() error { LowerBound: *cl.IntMe(cl.Flags[rangeHelperMin]), UpperBound: *cl.IntMe(cl.Flags[rangeHelperMax]), } - case *float64: - cl.Flags[flagName] = &selector.Float64RangeFilter{ - LowerBound: *cl.Float64Me(cl.Flags[rangeHelperMin]), - UpperBound: *cl.Float64Me(cl.Flags[rangeHelperMax]), + case *bytequantity.ByteQuantity: + cl.Flags[flagName] = &selector.ByteQuantityRangeFilter{ + LowerBound: *cl.ByteQuantityMe(cl.Flags[rangeHelperMin]), + UpperBound: *cl.ByteQuantityMe(cl.Flags[rangeHelperMax]), } } - } return nil } diff --git a/pkg/cli/cli_test.go b/pkg/cli/cli_test.go index 9ec0bde..f97e175 100644 --- a/pkg/cli/cli_test.go +++ b/pkg/cli/cli_test.go @@ -15,10 +15,14 @@ package cli_test import ( "fmt" + "math" "os" + "reflect" "testing" + "github.com/aws/amazon-ec2-instance-selector/pkg/bytequantity" "github.com/aws/amazon-ec2-instance-selector/pkg/cli" + "github.com/aws/amazon-ec2-instance-selector/pkg/selector" h "github.com/aws/amazon-ec2-instance-selector/pkg/test" "github.com/spf13/cobra" ) @@ -168,6 +172,98 @@ func TestParseFlags_IntRangeErr(t *testing.T) { h.Nok(t, err) } +func TestParseFlags_ByteQuantityRange(t *testing.T) { + flagName := "test-flag" + flagMinArg := fmt.Sprintf("%s-%s", flagName, "min") + flagMaxArg := fmt.Sprintf("%s-%s", flagName, "max") + flagArg := fmt.Sprintf("--%s", flagName) + + // Root set Min and Max to the same val + cli := getTestCLI() + cli.ByteQuantityMinMaxRangeFlags(flagName, nil, nil, "Test") + os.Args = []string{"ec2-instance-selector", flagArg, "5"} + flags, err := cli.ParseFlags() + h.Ok(t, err) + flagMinOutput := flags[flagMinArg].(*bytequantity.ByteQuantity) + flagMaxOutput := flags[flagMaxArg].(*bytequantity.ByteQuantity) + h.Assert(t, flagMinOutput.GiB() == 5.0 && flagMaxOutput.GiB() == 5.0, "Flag %s min and max should have been parsed to the same number", flagArg) + + // Min is set to a val and max is set to maxInt + cli = getTestCLI() + cli.ByteQuantityMinMaxRangeFlags(flagName, nil, nil, "Test") + os.Args = []string{"ec2-instance-selector", "--" + flagMinArg, "5"} + flags, err = cli.ParseFlags() + h.Ok(t, err) + flagMinOutput = flags[flagMinArg].(*bytequantity.ByteQuantity) + flagMaxOutput = flags[flagMaxArg].(*bytequantity.ByteQuantity) + h.Assert(t, flagMinOutput.GiB() == 5.0 && flagMaxOutput.Quantity == math.MaxUint64, "Flag %s min should have been parsed from cmdline and max set to maxInt", flagArg) + + // Max is set to a val and min is set to 0 + cli = getTestCLI() + cli.ByteQuantityMinMaxRangeFlags(flagName, nil, nil, "Test") + os.Args = []string{"ec2-instance-selector", "--" + flagMaxArg, "50"} + flags, err = cli.ParseFlags() + h.Ok(t, err) + flagMinOutput = flags[flagMinArg].(*bytequantity.ByteQuantity) + flagMaxOutput = flags[flagMaxArg].(*bytequantity.ByteQuantity) + h.Assert(t, flagMinOutput.Quantity == 0 && flagMaxOutput.GiB() == 50.0, "Flag %s max should have been parsed from cmdline and min set to 0", flagArg) + + // Min and Max are set to separate values + cli = getTestCLI() + cli.ByteQuantityMinMaxRangeFlags(flagName, nil, nil, "Test") + os.Args = []string{"ec2-instance-selector", "--" + flagMinArg, "10", "--" + flagMaxArg, "500"} + flags, err = cli.ParseFlags() + h.Ok(t, err) + flagMinOutput = flags[flagMinArg].(*bytequantity.ByteQuantity) + flagMaxOutput = flags[flagMaxArg].(*bytequantity.ByteQuantity) + h.Assert(t, flagMinOutput.GiB() == 10.0 && flagMaxOutput.GiB() == 500.0, "Flag %s max and min should have been parsed from cmdline", flagArg) + flagType := reflect.TypeOf(flags[flagName]) + bqRangeFilterType := reflect.TypeOf(&selector.ByteQuantityRangeFilter{}) + h.Assert(t, flagType == bqRangeFilterType, "%s should be of type %v, instead got %v", flagArg, bqRangeFilterType, flagType) +} + +func TestParseAndValidateFlags_ByteQuantityRange(t *testing.T) { + flagName := "test-flag" + flagMinArg := fmt.Sprintf("%s-%s", flagName, "min") + flagMaxArg := fmt.Sprintf("%s-%s", flagName, "max") + flagArg := fmt.Sprintf("--%s", flagName) + + // Root set Min and Max to the same val + cli := getTestCLI() + cli.ByteQuantityMinMaxRangeFlags(flagName, nil, nil, "Test") + os.Args = []string{"ec2-instance-selector", flagArg, "5"} + _, err := cli.ParseAndValidateFlags() + h.Ok(t, err) + + // Min is set to a val and max is set to maxInt + cli = getTestCLI() + cli.ByteQuantityMinMaxRangeFlags(flagName, nil, nil, "Test") + os.Args = []string{"ec2-instance-selector", "--" + flagMinArg, "5"} + _, err = cli.ParseAndValidateFlags() + h.Ok(t, err) + + // Max is set to a val and min is set to 0 + cli = getTestCLI() + cli.ByteQuantityMinMaxRangeFlags(flagName, nil, nil, "Test") + os.Args = []string{"ec2-instance-selector", "--" + flagMaxArg, "50"} + _, err = cli.ParseAndValidateFlags() + h.Ok(t, err) + + // Min and Max are set to separate values + cli = getTestCLI() + cli.ByteQuantityMinMaxRangeFlags(flagName, nil, nil, "Test") + os.Args = []string{"ec2-instance-selector", "--" + flagMinArg, "10", "--" + flagMaxArg, "500"} + _, err = cli.ParseFlags() + h.Ok(t, err) + + // no args + cli = getTestCLI() + cli.ByteQuantityMinMaxRangeFlags(flagName, nil, nil, "Test") + os.Args = []string{"ec2-instance-selector"} + _, err = cli.ParseAndValidateFlags() + h.Ok(t, err) +} + func TestParseFlags_RootErr(t *testing.T) { cli := getTestCLI() os.Args = []string{"ec2-instance-selector", "--test", "test"} @@ -244,30 +340,26 @@ func TestParseFlags_UntouchedFlags(t *testing.T) { func TestParseFlags_UntouchedFlagsAllTypes(t *testing.T) { cli := getTestCLI() - flagName := "test-flag" - ratioName := flagName + "-ratio" - configName := flagName + "-config" - suiteName := flagName + "-suite" - flagArg := fmt.Sprintf("--%s", flagName) - ratioArg := fmt.Sprintf("--%s", ratioName) - configArg := fmt.Sprintf("--%s", configName) - suiteArg := fmt.Sprintf("--%s", suiteName) + intName := "int" + ratioName := "ratio" + byteQName := "bq" + configName := "config" + suiteName := "suite" - cli.IntFlag(flagName, nil, nil, "Test Filter Flag") + cli.IntFlag(intName, nil, nil, "Test Filter Flag") cli.RatioFlag(ratioName, nil, nil, "Test Ratio Flag") + cli.ByteQuantityFlag(byteQName, nil, nil, "Test Byte Quantity Flag") cli.ConfigStringFlag(configName, nil, nil, "Test Config Flag", nil) cli.SuiteBoolFlag(suiteName, nil, nil, "Test Suite Flag") os.Args = []string{"ec2-instance-selector"} flags, err := cli.ParseFlags() h.Ok(t, err) - val, ok := flags[flagName] - ratioVal, ratioOk := flags[ratioName] - configVal, configOk := flags[configName] - suiteVal, suiteOk := flags[suiteName] - h.Assert(t, ok && ratioOk && configOk && suiteOk, "Flags %s, %s, %s should exist for all types in flags map", flagArg, ratioArg, configArg, suiteArg) - h.Assert(t, val == nil && ratioVal == nil && configVal == nil && suiteVal == nil, - "Flag %s, %s, %s should be set to nil when not explicitly set", flagArg, ratioArg, configArg, suiteArg) + for _, name := range []string{intName, ratioName, byteQName, configName, suiteName} { + val, ok := flags[name] + h.Assert(t, ok, "Flag %s should exist in flags map", "--"+name) + h.Assert(t, val == nil, "Flag %s should be set to nil when not explicitly set", "--"+name) + } } func TestParseAndValidateFlags_Err(t *testing.T) { @@ -282,18 +374,33 @@ func TestParseAndValidateFlags_Err(t *testing.T) { h.Nok(t, err) } -func TestParseAndValidateFlags(t *testing.T) { +func TestParseAndValidateFlags_ByteQuantityErr(t *testing.T) { cli := getTestCLI() flagName := "test-flag" flagArg := fmt.Sprintf("--%s", flagName) flagMin := flagArg + "-min" flagMax := flagArg + "-max" - cli.IntMinMaxRangeFlags(flagName, nil, nil, "Test with validation") + cli.ByteQuantityMinMaxRangeFlags(flagName, nil, nil, "Test with validation") os.Args = []string{"ec2-instance-selector", flagMin, "5", flagMax, "1"} _, err := cli.ParseAndValidateFlags() h.Nok(t, err) } +func TestParseAndValidateFlags(t *testing.T) { + cli := getTestCLI() + flagName := "test-flag" + flagArg := fmt.Sprintf("--%s", flagName) + flagMin := flagArg + "-min" + flagMax := flagArg + "-max" + cli.IntMinMaxRangeFlags(flagName, nil, nil, "Test with validation") + os.Args = []string{"ec2-instance-selector", flagMin, "1", flagMax, "5"} + flags, err := cli.ParseAndValidateFlags() + h.Ok(t, err) + flagType := reflect.TypeOf(flags[flagName]) + intRangeFilterType := reflect.TypeOf(&selector.IntRangeFilter{}) + h.Assert(t, flagType == intRangeFilterType, "%s should be of type %v, instead got %v", flagArg, intRangeFilterType, flagType) +} + func TestParseAndValidateRegexFlag(t *testing.T) { flagName := "test-regex-flag" flagArg := fmt.Sprintf("--%s", flagName) @@ -313,3 +420,23 @@ func TestParseAndValidateRegexFlag(t *testing.T) { _, err = cli.ParseAndValidateFlags() h.Nok(t, err) } + +func TestParseAndValidateByteQuantityFlag(t *testing.T) { + flagName := "test-bq-flag" + flagArg := fmt.Sprintf("--%s", flagName) + + cli := getTestCLI() + cli.ByteQuantityFlag(flagName, nil, nil, "Test with validation") + os.Args = []string{"ec2-instance-selector", flagArg, "450"} + flags, err := cli.ParseAndValidateFlags() + h.Ok(t, err) + h.Assert(t, len(flags) == 1, "1 flag should have been processed") + _, err = cli.ParseAndValidateFlags() + h.Ok(t, err) + + cli = getTestCLI() + cli.ByteQuantityFlag(flagName, nil, nil, "Test with validation") + os.Args = []string{"ec2-instance-selector", flagArg, "(("} + _, err = cli.ParseAndValidateFlags() + h.Nok(t, err) +} diff --git a/pkg/cli/flags.go b/pkg/cli/flags.go index 1d43848..77da494 100644 --- a/pkg/cli/flags.go +++ b/pkg/cli/flags.go @@ -7,12 +7,13 @@ import ( "strconv" "strings" + "github.com/aws/amazon-ec2-instance-selector/pkg/bytequantity" "github.com/spf13/pflag" ) const ( - maxFloat64 = math.MaxFloat64 - maxInt = int(^uint(0) >> 1) + maxInt = int(^uint(0) >> 1) + maxUint64 = math.MaxUint64 ) // RatioFlag creates and registers a flag accepting a Ratio @@ -53,9 +54,14 @@ func (cl *CommandLineInterface) IntMinMaxRangeFlags(name string, shorthand *stri cl.IntMinMaxRangeFlagOnFlagSet(cl.Command.Flags(), name, shorthand, defaultValue, description) } -// Float64MinMaxRangeFlags creates and registers a min, max, and helper flag each accepting a Float64 -func (cl *CommandLineInterface) Float64MinMaxRangeFlags(name string, shorthand *string, defaultValue *float64, description string) { - cl.Float64MinMaxRangeFlagOnFlagSet(cl.Command.Flags(), name, shorthand, defaultValue, description) +// ByteQuantityMinMaxRangeFlags creates and registers a min, max, and helper flag each accepting a byte quantity like 512mb +func (cl *CommandLineInterface) ByteQuantityMinMaxRangeFlags(name string, shorthand *string, defaultValue *bytequantity.ByteQuantity, description string) { + cl.ByteQuantityMinMaxRangeFlagOnFlagSet(cl.Command.Flags(), name, shorthand, defaultValue, description) +} + +// ByteQuantityFlag creates and registers a flag accepting a byte quantity like 512mb +func (cl *CommandLineInterface) ByteQuantityFlag(name string, shorthand *string, defaultValue *bytequantity.ByteQuantity, description string) { + cl.ByteQuantityFlagOnFlagSet(cl.Command.Flags(), name, shorthand, defaultValue, description) } // IntFlag creates and registers a flag accepting an Integer @@ -66,7 +72,7 @@ func (cl *CommandLineInterface) IntFlag(name string, shorthand *string, defaultV // StringFlag creates and registers a flag accepting a String and a validator function. // The validator function is provided so that more complex flags can be created from a string input. func (cl *CommandLineInterface) StringFlag(name string, shorthand *string, defaultValue *string, description string, validationFn validator) { - cl.StringFlagOnFlagSet(cl.Command.Flags(), name, shorthand, defaultValue, description, validationFn) + cl.StringFlagOnFlagSet(cl.Command.Flags(), name, shorthand, defaultValue, description, nil, validationFn) } // StringSliceFlag creates and registers a flag accepting a list of strings. @@ -92,7 +98,7 @@ func (cl *CommandLineInterface) BoolFlag(name string, shorthand *string, default // ConfigStringFlag creates and registers a flag accepting a String for configuration purposes. // Config flags will be grouped at the bottom in the output of --help func (cl *CommandLineInterface) ConfigStringFlag(name string, shorthand *string, defaultValue *string, description string, validationFn validator) { - cl.StringFlagOnFlagSet(cl.Command.PersistentFlags(), name, shorthand, defaultValue, description, validationFn) + cl.StringFlagOnFlagSet(cl.Command.PersistentFlags(), name, shorthand, defaultValue, description, nil, validationFn) } // ConfigStringSliceFlag creates and registers a flag accepting a list of strings. @@ -128,7 +134,7 @@ func (cl *CommandLineInterface) SuiteBoolFlag(name string, shorthand *string, de // SuiteStringFlag creates and registers a flag accepting a string for aggreagate filters. // Suite flags will be grouped in the middle of the output --help func (cl *CommandLineInterface) SuiteStringFlag(name string, shorthand *string, defaultValue *string, description string, validationFn validator) { - cl.StringFlagOnFlagSet(cl.suiteFlags, name, shorthand, defaultValue, description, validationFn) + cl.StringFlagOnFlagSet(cl.suiteFlags, name, shorthand, defaultValue, description, nil, validationFn) } // SuiteStringOptionsFlag creates and registers a flag accepting a string and valid options for use in validation. @@ -177,20 +183,20 @@ func (cl *CommandLineInterface) IntMinMaxRangeFlagOnFlagSet(flagSet *pflag.FlagS cl.rangeFlags[name] = true } -// Float64MinMaxRangeFlagOnFlagSet creates and registers a min, max, and helper flag each accepting a Float64 -func (cl *CommandLineInterface) Float64MinMaxRangeFlagOnFlagSet(flagSet *pflag.FlagSet, name string, shorthand *string, defaultValue *float64, description string) { - cl.Float64FlagOnFlagSet(flagSet, name, shorthand, defaultValue, fmt.Sprintf("%s (sets --%s-min and -max to the same value)", description, name)) - cl.Float64FlagOnFlagSet(flagSet, name+"-min", nil, nil, fmt.Sprintf("Minimum %s If --%s-max is not specified, the upper bound will be infinity", description, name)) - cl.Float64FlagOnFlagSet(flagSet, name+"-max", nil, nil, fmt.Sprintf("Maximum %s If --%s-min is not specified, the lower bound will be 0", description, name)) +// ByteQuantityMinMaxRangeFlagOnFlagSet creates and registers a min, max, and helper flag each accepting a ByteQuantity like 5mb or 12gb +func (cl *CommandLineInterface) ByteQuantityMinMaxRangeFlagOnFlagSet(flagSet *pflag.FlagSet, name string, shorthand *string, defaultValue *bytequantity.ByteQuantity, description string) { + cl.ByteQuantityFlagOnFlagSet(flagSet, name, shorthand, defaultValue, fmt.Sprintf("%s (sets --%s-min and -max to the same value)", description, name)) + cl.ByteQuantityFlagOnFlagSet(flagSet, name+"-min", nil, nil, fmt.Sprintf("Minimum %s If --%s-max is not specified, the upper bound will be infinity", description, name)) + cl.ByteQuantityFlagOnFlagSet(flagSet, name+"-max", nil, nil, fmt.Sprintf("Maximum %s If --%s-min is not specified, the lower bound will be 0", description, name)) cl.validators[name] = func(val interface{}) error { if cl.Flags[name+"-min"] == nil || cl.Flags[name+"-max"] == nil { return nil } minArg := name + "-min" maxArg := name + "-max" - minVal := cl.Flags[minArg].(*float64) - maxVal := cl.Flags[maxArg].(*float64) - if *minVal > *maxVal { + minVal := cl.Flags[name+"-min"].(*bytequantity.ByteQuantity).MiB() + maxVal := cl.Flags[name+"-max"].(*bytequantity.ByteQuantity).MiB() + if minVal > maxVal { return fmt.Errorf("Invalid input for --%s and --%s. %s must be less than or equal to %s", minArg, maxArg, minArg, maxArg) } return nil @@ -198,6 +204,47 @@ func (cl *CommandLineInterface) Float64MinMaxRangeFlagOnFlagSet(flagSet *pflag.F cl.rangeFlags[name] = true } +// ByteQuantityFlagOnFlagSet creates and registers a flag accepting a ByteQuantity +func (cl *CommandLineInterface) ByteQuantityFlagOnFlagSet(flagSet *pflag.FlagSet, name string, shorthand *string, defaultValue *bytequantity.ByteQuantity, description string) { + invalidInputMsg := fmt.Sprintf("Invalid input for --%s. A valid example is 16gb. ", name) + byteQuantityProcessor := func(val interface{}) error { + if val == nil { + return nil + } + switch byteQuantityInput := val.(type) { + case *string: + bq, err := bytequantity.ParseToByteQuantity(*byteQuantityInput) + if err != nil { + return fmt.Errorf(invalidInputMsg+"Can't parse byte quantity %s.", *byteQuantityInput) + } + cl.Flags[name] = &bq + case *bytequantity.ByteQuantity: + return nil + default: + return fmt.Errorf(invalidInputMsg + "Input type is unsupported.") + } + return nil + } + byteQuantityValidator := func(val interface{}) error { + if val == nil { + return nil + } + switch val.(type) { + case *bytequantity.ByteQuantity: + return nil + default: + return fmt.Errorf(invalidInputMsg + "Processing failed.") + } + } + var stringDefaultValue *string + if defaultValue != nil { + stringDefaultValue = cl.StringMe(defaultValue.StringGiB()) + } else { + stringDefaultValue = nil + } + cl.StringFlagOnFlagSet(flagSet, name, shorthand, stringDefaultValue, description, byteQuantityProcessor, byteQuantityValidator) +} + // IntFlagOnFlagSet creates and registers a flag accepting an Integer func (cl *CommandLineInterface) IntFlagOnFlagSet(flagSet *pflag.FlagSet, name string, shorthand *string, defaultValue *int, description string) { if defaultValue == nil { @@ -211,32 +258,19 @@ func (cl *CommandLineInterface) IntFlagOnFlagSet(flagSet *pflag.FlagSet, name st cl.Flags[name] = flagSet.Int(name, *defaultValue, description) } -// Float64FlagOnFlagSet creates and registers a flag accepting a Float64 -func (cl *CommandLineInterface) Float64FlagOnFlagSet(flagSet *pflag.FlagSet, name string, shorthand *string, defaultValue *float64, description string) { - if defaultValue == nil { - cl.nilDefaults[name] = true - defaultValue = cl.Float64Me(0.0) - } - if shorthand != nil { - cl.Flags[name] = flagSet.Float64P(name, string(*shorthand), *defaultValue, description) - return - } - cl.Flags[name] = flagSet.Float64(name, *defaultValue, description) -} - // StringFlagOnFlagSet creates and registers a flag accepting a String and a validator function. // The validator function is provided so that more complex flags can be created from a string input. -func (cl *CommandLineInterface) StringFlagOnFlagSet(flagSet *pflag.FlagSet, name string, shorthand *string, defaultValue *string, description string, validationFn validator) { +func (cl *CommandLineInterface) StringFlagOnFlagSet(flagSet *pflag.FlagSet, name string, shorthand *string, defaultValue *string, description string, processorFn processor, validationFn validator) { if defaultValue == nil { cl.nilDefaults[name] = true defaultValue = cl.StringMe("") } if shorthand != nil { cl.Flags[name] = flagSet.StringP(name, string(*shorthand), *defaultValue, description) - cl.validators[name] = validationFn } else { cl.Flags[name] = flagSet.String(name, *defaultValue, description) } + cl.processors[name] = processorFn cl.validators[name] = validationFn } @@ -254,7 +288,7 @@ func (cl *CommandLineInterface) StringOptionsFlagOnFlagSet(flagSet *pflag.FlagSe } return fmt.Errorf("error %s must be one of: %s", name, strings.Join(validOpts, ", ")) } - cl.StringFlagOnFlagSet(flagSet, name, shorthand, defaultValue, description, validationFn) + cl.StringFlagOnFlagSet(flagSet, name, shorthand, defaultValue, description, nil, validationFn) } // StringSliceFlagOnFlagSet creates and registers a flag accepting a String Slice. @@ -272,33 +306,36 @@ func (cl *CommandLineInterface) StringSliceFlagOnFlagSet(flagSet *pflag.FlagSet, // RegexFlagOnFlagSet creates and registers a flag accepting a string slice of regular expressions. func (cl *CommandLineInterface) RegexFlagOnFlagSet(flagSet *pflag.FlagSet, name string, shorthand *string, defaultValue *string, description string) { - if defaultValue == nil { - cl.nilDefaults[name] = true - defaultValue = cl.StringMe("") - } - if shorthand != nil { - cl.Flags[name] = flagSet.StringP(name, string(*shorthand), *defaultValue, description) - } else { - cl.Flags[name] = flagSet.String(name, *defaultValue, description) - } - cl.validators[name] = func(val interface{}) error { + invalidInputMsg := fmt.Sprintf("Invalid regex input for --%s. ", name) + regexProcessor := func(val interface{}) error { if val == nil { return nil } - regexStringVal := "" switch v := val.(type) { case *string: - regexStringVal = *v + regexVal, err := regexp.Compile(*v) + if err != nil { + return fmt.Errorf(invalidInputMsg + "Unable to compile the regex.") + } + cl.Flags[name] = regexVal case *regexp.Regexp: return nil default: - return fmt.Errorf("Invalid regex input for --%s", name) - } - regexVal, err := regexp.Compile(regexStringVal) - if err != nil { - return fmt.Errorf("Invalid regex input for --%s", name) + return fmt.Errorf(invalidInputMsg + "Input type is unsupported.") } - cl.Flags[name] = regexVal + return nil } + regexValidator := func(val interface{}) error { + if val == nil { + return nil + } + switch val.(type) { + case *regexp.Regexp: + return nil + default: + return fmt.Errorf(invalidInputMsg + "Processing failed.") + } + } + cl.StringFlagOnFlagSet(flagSet, name, shorthand, defaultValue, description, regexProcessor, regexValidator) } diff --git a/pkg/cli/flags_test.go b/pkg/cli/flags_test.go index 1d50ce5..3b25195 100644 --- a/pkg/cli/flags_test.go +++ b/pkg/cli/flags_test.go @@ -17,6 +17,7 @@ import ( "fmt" "testing" + "github.com/aws/amazon-ec2-instance-selector/pkg/bytequantity" h "github.com/aws/amazon-ec2-instance-selector/pkg/test" ) @@ -151,10 +152,10 @@ func TestIntMinMaxRangeFlags(t *testing.T) { h.Assert(t, ok, "Should contain %s flag w/ no shorthand", flagName) } -func TestFloat64MinMaxRangeFlags(t *testing.T) { +func TestByteQuantityMinMaxRangeFlags(t *testing.T) { cli := getTestCLI() - flagName := "test-float-min-max-range" - cli.Float64MinMaxRangeFlags(flagName, cli.StringMe("t"), nil, "Test Min Max Range") + flagName := "test-bq-min-max-range" + cli.ByteQuantityMinMaxRangeFlags(flagName, cli.StringMe("t"), nil, "Test Min Max Range") _, ok := cli.Flags[flagName] _, minOk := cli.Flags[flagName+"-min"] _, maxOk := cli.Flags[flagName+"-max"] @@ -164,11 +165,27 @@ func TestFloat64MinMaxRangeFlags(t *testing.T) { h.Assert(t, maxOk, "Should contain %s flag", flagName) cli = getTestCLI() - cli.Float64MinMaxRangeFlags(flagName, nil, nil, "Test Min Max Range") + cli.ByteQuantityMinMaxRangeFlags(flagName, nil, nil, "Test Min Max Range") h.Assert(t, len(cli.Flags) == 3, "Should contain 3 flags w/ no shorthand") h.Assert(t, ok, "Should contain %s flag w/ no shorthand", flagName) } +func TestByteQuantityFlag(t *testing.T) { + cli := getTestCLI() + for _, flagFn := range []func(string, *string, *bytequantity.ByteQuantity, string){cli.ByteQuantityFlag} { + flagName := "test-bq-flag" + flagFn(flagName, cli.StringMe("t"), nil, "Test Byte Quantity") + _, ok := cli.Flags[flagName] + h.Assert(t, ok, "Should contain %s flag", flagName) + h.Assert(t, len(cli.Flags) == 1, "Should contain 1 flag") + + cli = getTestCLI() + flagFn(flagName, nil, nil, "Test Byte Quantity") + h.Assert(t, len(cli.Flags) == 1, "Should contain 3 flags w/ no shorthand") + h.Assert(t, ok, "Should contain %s flag w/ no shorthand", flagName) + } +} + func TestRegexFlag(t *testing.T) { cli := getTestCLI() for _, flagFn := range []func(string, *string, *string, string){cli.RegexFlag} { diff --git a/pkg/cli/types.go b/pkg/cli/types.go index ed2cd91..a6efd74 100644 --- a/pkg/cli/types.go +++ b/pkg/cli/types.go @@ -18,6 +18,7 @@ import ( "log" "regexp" + "github.com/aws/amazon-ec2-instance-selector/pkg/bytequantity" "github.com/aws/amazon-ec2-instance-selector/pkg/selector" "github.com/spf13/cobra" "github.com/spf13/pflag" @@ -50,6 +51,9 @@ Global Flags: // validator defines the function for providing validation on a flag type validator = func(val interface{}) error +// processor defines the function for providing mutating processing on a flag +type processor = func(val interface{}) error + // CommandLineInterface is a type to group CLI funcs and state type CommandLineInterface struct { Command *cobra.Command @@ -57,6 +61,7 @@ type CommandLineInterface struct { nilDefaults map[string]bool rangeFlags map[string]bool validators map[string]validator + processors map[string]processor suiteFlags *pflag.FlagSet } @@ -117,19 +122,19 @@ func (*CommandLineInterface) IntRangeMe(i interface{}) *selector.IntRangeFilter } } -// Float64RangeMe takes an interface and returns a pointer to a Float64RangeFilter value -// If the underlying interface kind is not Float64RangeFilter or *Float64RangeFilter then nil is returned -func (*CommandLineInterface) Float64RangeMe(i interface{}) *selector.Float64RangeFilter { +// ByteQuantityRangeMe takes an interface and returns a pointer to a ByteQuantityRangeFilter value +// If the underlying interface kind is not ByteQuantityRangeFilter or *ByteQuantityRangeFilter then nil is returned +func (*CommandLineInterface) ByteQuantityRangeMe(i interface{}) *selector.ByteQuantityRangeFilter { if i == nil { return nil } switch v := i.(type) { - case *selector.Float64RangeFilter: + case *selector.ByteQuantityRangeFilter: return v - case selector.Float64RangeFilter: + case selector.ByteQuantityRangeFilter: return &v default: - log.Printf("%s cannot be converted to a Float64Range", i) + log.Printf("%s cannot be converted to a ByteQuantityRange", i) return nil } } @@ -201,3 +206,20 @@ func (*CommandLineInterface) RegexMe(i interface{}) *regexp.Regexp { return nil } } + +// ByteQuantityMe takes an interface and returns a pointer to a regex +// If the underlying interface kind is not bytequantity.ByteQuantity or *bytequantity.ByteQuantity then nil is returned +func (*CommandLineInterface) ByteQuantityMe(i interface{}) *bytequantity.ByteQuantity { + if i == nil { + return nil + } + switch v := i.(type) { + case *bytequantity.ByteQuantity: + return v + case bytequantity.ByteQuantity: + return &v + default: + log.Printf("%s cannot be converted to a byte quantity", i) + return nil + } +} diff --git a/pkg/cli/types_test.go b/pkg/cli/types_test.go index d6b52c5..a9bd2b3 100644 --- a/pkg/cli/types_test.go +++ b/pkg/cli/types_test.go @@ -18,6 +18,7 @@ import ( "regexp" "testing" + "github.com/aws/amazon-ec2-instance-selector/pkg/bytequantity" "github.com/aws/amazon-ec2-instance-selector/pkg/selector" h "github.com/aws/amazon-ec2-instance-selector/pkg/test" ) @@ -107,16 +108,17 @@ func TestIntRangeMe(t *testing.T) { h.Assert(t, val == nil, "Should return nil if nil is passed in") } -func TestFloat64RangeMe(t *testing.T) { +func TestByteQuantityRangeMe(t *testing.T) { cli := getTestCLI() - float64RangeVal := selector.Float64RangeFilter{LowerBound: 1.0, UpperBound: 2.0} - val := cli.Float64RangeMe(float64RangeVal) - h.Assert(t, *val == float64RangeVal, "Should return %s from passed in float64 range value", float64RangeVal) - val = cli.Float64RangeMe(&float64RangeVal) - h.Assert(t, *val == float64RangeVal, "Should return %s from passed in range pointer", float64RangeVal) - val = cli.Float64RangeMe(true) + bq1 := bytequantity.ByteQuantity{Quantity: 1} + bqRangeVal := selector.ByteQuantityRangeFilter{LowerBound: bq1, UpperBound: bq1} + val := cli.ByteQuantityRangeMe(bqRangeVal) + h.Assert(t, *val == bqRangeVal, "Should return %s from passed in byte quantity range value", bqRangeVal) + val = cli.ByteQuantityRangeMe(&bqRangeVal) + h.Assert(t, *val == bqRangeVal, "Should return %s from passed in range pointer", bqRangeVal) + val = cli.ByteQuantityRangeMe(true) h.Assert(t, val == nil, "Should return nil from other data type passed in") - val = cli.Float64RangeMe(nil) + val = cli.ByteQuantityRangeMe(nil) h.Assert(t, val == nil, "Should return nil if nil is passed in") } func TestRegexMe(t *testing.T) { diff --git a/pkg/selector/aggregates.go b/pkg/selector/aggregates.go index 4bca5d2..35ee501 100644 --- a/pkg/selector/aggregates.go +++ b/pkg/selector/aggregates.go @@ -4,6 +4,7 @@ import ( "fmt" "regexp" + "github.com/aws/amazon-ec2-instance-selector/pkg/bytequantity" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/ec2" ) @@ -62,9 +63,9 @@ func (itf Selector) TransformBaseInstanceType(filters Filters) (Filters, error) filters.GpusRange = &IntRangeFilter{LowerBound: gpuCount, UpperBound: gpuCount} } if filters.MemoryRange == nil { - lowerBound := (float64(*instanceTypeInfo.MemoryInfo.SizeInMiB) * AggregateLowPercentile) / 1024 - upperBound := (float64(*instanceTypeInfo.MemoryInfo.SizeInMiB) * AggregateHighPercentile) / 1024 - filters.MemoryRange = &Float64RangeFilter{LowerBound: lowerBound, UpperBound: upperBound} + lowerBound := bytequantity.ByteQuantity{Quantity: uint64(float64(*instanceTypeInfo.MemoryInfo.SizeInMiB) * AggregateLowPercentile)} + upperBound := bytequantity.ByteQuantity{Quantity: uint64(float64(*instanceTypeInfo.MemoryInfo.SizeInMiB) * AggregateHighPercentile)} + filters.MemoryRange = &ByteQuantityRangeFilter{LowerBound: lowerBound, UpperBound: upperBound} } if filters.VCpusRange == nil { lowerBound := int(float64(*instanceTypeInfo.VCpuInfo.DefaultVCpus) * AggregateLowPercentile) diff --git a/pkg/selector/comparators.go b/pkg/selector/comparators.go index c2c4dc3..b1156c7 100644 --- a/pkg/selector/comparators.go +++ b/pkg/selector/comparators.go @@ -66,15 +66,18 @@ func isSupportedWithRangeInt64(instanceTypeValue *int64, target *IntRangeFilter) return int(*instanceTypeValue) >= target.LowerBound && int(*instanceTypeValue) <= target.UpperBound } -func isSupportedWithFloat64(instanceTypeValue *float64, target *float64) bool { +func isSupportedWithRangeUint64(instanceTypeValue *int64, target *Uint64RangeFilter) bool { if target == nil { return true - } - if instanceTypeValue == nil { + } else if instanceTypeValue == nil && target.LowerBound == 0 && target.UpperBound == 0 { + return true + } else if instanceTypeValue == nil { return false } - // compare up to values' two decimal floor - return math.Floor(*instanceTypeValue*100)/100 == math.Floor(*target*100)/100 + if target.UpperBound > math.MaxInt64 { + target.UpperBound = math.MaxInt64 + } + return uint64(*instanceTypeValue) >= target.LowerBound && uint64(*instanceTypeValue) <= target.UpperBound } func isSupportedWithBool(instanceTypeValue *bool, target *bool) bool { diff --git a/pkg/selector/comparators_internal_test.go b/pkg/selector/comparators_internal_test.go index 2eca13b..350621f 100644 --- a/pkg/selector/comparators_internal_test.go +++ b/pkg/selector/comparators_internal_test.go @@ -14,6 +14,7 @@ package selector import ( + "math" "testing" h "github.com/aws/amazon-ec2-instance-selector/pkg/test" @@ -148,41 +149,50 @@ func TestIsSupportedWithRangeInt64_SourceNilTarget0(t *testing.T) { h.Assert(t, isSupported == true, "IntRangeFilter should match with 0 target and nil source") } -func TestIsSupportedWithFloat64_Supported(t *testing.T) { - isSupported := isSupportedWithFloat64(aws.Float64(0.33), aws.Float64(0.33)) - h.Assert(t, isSupported == true, "Float64 comparison should match exactly with 2 decimal places") +// uint64 + +func TestIsSupportedWithRangeUint64_SupportedExact(t *testing.T) { + target := IntRangeFilter{LowerBound: 4, UpperBound: 4} + isSupported := isSupportedWithRangeInt64(aws.Int64(4), &target) + h.Assert(t, isSupported == true, "IntRangeFilter should match exactly") } -func TestIsSupportedWithFloat64_SupportedTruncatedDecPlacesExact(t *testing.T) { - isSupported := isSupportedWithFloat64(aws.Float64(0.3322), aws.Float64(0.3322)) - h.Assert(t, isSupported == true, "Float64 comparison should match exactly with 4 decimal places") +func TestIsSupportedWithRangeUint64_SupportedAround(t *testing.T) { + target := Uint64RangeFilter{LowerBound: 2, UpperBound: 6} + isSupported := isSupportedWithRangeUint64(aws.Int64(4), &target) + h.Assert(t, isSupported == true, "UintRangeFilter should match with lower and upper bound around the desired source") } -func TestIsSupportedWithFloat64_SupportedTruncatedDecPlaces(t *testing.T) { - isSupported := isSupportedWithFloat64(aws.Float64(0.3399), aws.Float64(0.3311)) - h.Assert(t, isSupported == true, "Float64 comparison should match when truncating to 2 decimal places") +func TestIsSupportedWithRangeUint64_Nil(t *testing.T) { + target := Uint64RangeFilter{LowerBound: 2, UpperBound: 6} + isSupported := isSupportedWithRangeUint64(nil, &target) + h.Assert(t, isSupported == false, "Uint64RangeFilter should NOT match with nil source") } -func TestIsSupportedWithFloat64_Unsupported(t *testing.T) { - isSupported := isSupportedWithFloat64(aws.Float64(0.4), aws.Float64(0.3399)) - h.Assert(t, isSupported == false, "Float64 comparison should NOT match") +func TestIsSupportedWithRangeUint64_NilTarget(t *testing.T) { + isSupported := isSupportedWithRangeUint64(aws.Int64(4), nil) + h.Assert(t, isSupported == true, "Uint64RangeFilter should match with nil target") } -func TestIsSupportedWithFloat64_SourceNil(t *testing.T) { - isSupported := isSupportedWithFloat64(nil, aws.Float64(0.3399)) - h.Assert(t, isSupported == false, "Float64 comparison should NOT match with nil source") +func TestIsSupportedWithRangeUint64_BothNil(t *testing.T) { + isSupported := isSupportedWithRangeUint64(nil, nil) + h.Assert(t, isSupported == true, "Uint64RangeFilter should match with nil target and nil source") } -func TestIsSupportedWithFloat64_TargetNil(t *testing.T) { - isSupported := isSupportedWithFloat64(aws.Float64(0.3399), nil) - h.Assert(t, isSupported == true, "Float64 comparison should match with nil target") +func TestIsSupportedWithRangeUint64_SourceNilTarget0(t *testing.T) { + target := Uint64RangeFilter{LowerBound: 0, UpperBound: 0} + isSupported := isSupportedWithRangeUint64(nil, &target) + h.Assert(t, isSupported == true, "Uint64RangeFilter should match with 0 target and nil source") } -func TestIsSupportedWithFloat64_BothNil(t *testing.T) { - isSupported := isSupportedWithFloat64(nil, nil) - h.Assert(t, isSupported == true, "Float64 comparison should match with nil target and source") +func TestIsSupportedWithRangeUint64_Overflow(t *testing.T) { + target := Uint64RangeFilter{LowerBound: 0, UpperBound: math.MaxUint64} + isSupported := isSupportedWithRangeUint64(aws.Int64(4), &target) + h.Assert(t, isSupported == true, "Uint64RangeFilter should match with 0 - MAX target and source 4") } +// bools + func TestSupportSyntaxToBool_Supported(t *testing.T) { isSupported := supportSyntaxToBool(aws.String("supported")) h.Assert(t, *isSupported == true, "Supported should evaluate to true") diff --git a/pkg/selector/selector.go b/pkg/selector/selector.go index 3b16731..28ca456 100644 --- a/pkg/selector/selector.go +++ b/pkg/selector/selector.go @@ -16,7 +16,6 @@ package selector import ( "fmt" - "math" "reflect" "regexp" "sort" @@ -179,8 +178,8 @@ func (itf Selector) rawFilter(filters Filters) ([]*ec2.InstanceTypeInfo, error) rootDeviceType: {filters.RootDeviceType, instanceTypeInfo.SupportedRootDeviceTypes}, hibernationSupported: {filters.HibernationSupported, instanceTypeInfo.HibernationSupported}, vcpusRange: {filters.VCpusRange, instanceTypeInfo.VCpuInfo.DefaultVCpus}, - memoryRange: {gibToMibRange(filters.MemoryRange), instanceTypeInfo.MemoryInfo.SizeInMiB}, - gpuMemoryRange: {gibToMibRange(filters.GpuMemoryRange), getTotalGpuMemory(instanceTypeInfo.GpuInfo)}, + memoryRange: {filters.MemoryRange, instanceTypeInfo.MemoryInfo.SizeInMiB}, + gpuMemoryRange: {filters.GpuMemoryRange, getTotalGpuMemory(instanceTypeInfo.GpuInfo)}, gpusRange: {filters.GpusRange, getTotalGpusCount(instanceTypeInfo.GpuInfo)}, placementGroupStrategy: {filters.PlacementGroupStrategy, instanceTypeInfo.PlacementGroupInfo.SupportedStrategies}, hypervisor: {filters.Hypervisor, instanceTypeInfo.Hypervisor}, @@ -291,10 +290,27 @@ func (itf Selector) executeFilters(filterToInstanceSpecMapping map[string]filter default: return false, fmt.Errorf(invalidInstanceSpecTypeMsg) } - case *float64: + case *ByteQuantityRangeFilter: + mibRange := Uint64RangeFilter{ + LowerBound: filter.LowerBound.Quantity, + UpperBound: filter.UpperBound.Quantity, + } switch iSpec := instanceSpec.(type) { - case *float64: - if !isSupportedWithFloat64(iSpec, filter) { + case *int: + var iSpec64 *int64 + if iSpec != nil { + iSpecVal := int64(*iSpec) + iSpec64 = &iSpecVal + } + if !isSupportedWithRangeUint64(iSpec64, &mibRange) { + return false, nil + } + case *int64: + mibRange := Uint64RangeFilter{ + LowerBound: filter.LowerBound.Quantity, + UpperBound: filter.UpperBound.Quantity, + } + if !isSupportedWithRangeUint64(iSpec, &mibRange) { return false, nil } default: @@ -392,18 +408,3 @@ func isInAllowList(allowRegex *regexp.Regexp, instanceTypeName string) bool { } return allowRegex.MatchString(instanceTypeName) } - -func gibToMibRange(gbRange *Float64RangeFilter) *IntRangeFilter { - if gbRange == nil { - return nil - } - mbRangeFilter := IntRangeFilter{ - LowerBound: int(gbRange.LowerBound * 1024), - } - if gbRange.UpperBound == math.MaxFloat64 { - mbRangeFilter.UpperBound = math.MaxInt32 - return &mbRangeFilter - } - mbRangeFilter.UpperBound = int(gbRange.UpperBound * 1024) - return &mbRangeFilter -} diff --git a/pkg/selector/selector_test.go b/pkg/selector/selector_test.go index 5c3a493..d380138 100644 --- a/pkg/selector/selector_test.go +++ b/pkg/selector/selector_test.go @@ -22,6 +22,7 @@ import ( "strconv" "testing" + "github.com/aws/amazon-ec2-instance-selector/pkg/bytequantity" "github.com/aws/amazon-ec2-instance-selector/pkg/selector" h "github.com/aws/amazon-ec2-instance-selector/pkg/test" "github.com/aws/aws-sdk-go/aws" @@ -240,9 +241,14 @@ func TestFilterVerbose_Gpus(t *testing.T) { itf := selector.Selector{ EC2: ec2Mock, } + gpuMemory, err := bytequantity.ParseToByteQuantity("128g") + h.Ok(t, err) filters := selector.Filters{ - GpusRange: &selector.IntRangeFilter{LowerBound: 8, UpperBound: 8}, - GpuMemoryRange: &selector.Float64RangeFilter{LowerBound: 128.0, UpperBound: 128.0}, + GpusRange: &selector.IntRangeFilter{LowerBound: 8, UpperBound: 8}, + GpuMemoryRange: &selector.ByteQuantityRangeFilter{ + LowerBound: gpuMemory, + UpperBound: gpuMemory, + }, } results, err := itf.FilterVerbose(filters) h.Ok(t, err) @@ -449,7 +455,6 @@ func TestRetrieveInstanceTypesSupportedInAZs_Intersection(t *testing.T) { } results, err := itf.RetrieveInstanceTypesSupportedInLocations([]string{"us-east-2a", "us-east-2b"}) h.Ok(t, err) - fmt.Println(results) h.Assert(t, len(results) == 3, "Should return instance types that are included in both files") // Check reversed zones to ensure order does not matter diff --git a/pkg/selector/types.go b/pkg/selector/types.go index 2de1dfb..e898e56 100644 --- a/pkg/selector/types.go +++ b/pkg/selector/types.go @@ -17,6 +17,7 @@ import ( "encoding/json" "regexp" + "github.com/aws/amazon-ec2-instance-selector/pkg/bytequantity" "github.com/aws/aws-sdk-go/service/ec2" "github.com/aws/aws-sdk-go/service/ec2/ec2iface" ) @@ -47,11 +48,18 @@ type IntRangeFilter struct { LowerBound int } -// Float64RangeFilter holds an upper and lower bound float64 +// Uint64RangeFilter holds an upper and lower bound uint64 // The lower and upper bound are used to range filter resource specs -type Float64RangeFilter struct { - UpperBound float64 - LowerBound float64 +type Uint64RangeFilter struct { + UpperBound uint64 + LowerBound uint64 +} + +// ByteQuantityRangeFilter holds an upper and lower bound byte quantity +// The lower and upper bound are used to range filter resource specs +type ByteQuantityRangeFilter struct { + UpperBound bytequantity.ByteQuantity + LowerBound bytequantity.ByteQuantity } // filterPair holds a tuple of the passed in filter value and the instance resource spec value @@ -119,7 +127,7 @@ type Filters struct { GpusRange *IntRangeFilter // GpuMemoryRange filter is a range of acceptable GPU memory in Gibibytes (GiB) available to an EC2 instance type in aggreagte across all GPUs. - GpuMemoryRange *Float64RangeFilter + GpuMemoryRange *ByteQuantityRangeFilter // HibernationSupported denotes whether EC2 hibernate is supported // Possible values are: true or false @@ -133,7 +141,7 @@ type Filters struct { MaxResults *int // MemoryRange filter is a range of acceptable DRAM memory in Gibibytes (GiB) for the instance type - MemoryRange *Float64RangeFilter + MemoryRange *ByteQuantityRangeFilter // NetworkInterfaces filter is a range of the number of ENI attachments an instance type can support NetworkInterfaces *IntRangeFilter From 666964e89a9bd52dd4c240e7e3f5f7b665eef17f Mon Sep 17 00:00:00 2001 From: Brandon Wagner Date: Fri, 24 Jul 2020 10:43:52 -0500 Subject: [PATCH 4/4] fixes based on pr comments --- cmd/main.go | 4 ++-- pkg/bytequantity/bytequantity.go | 3 ++- pkg/cli/types.go | 2 +- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/cmd/main.go b/cmd/main.go index d93310c..f40bb23 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -116,11 +116,11 @@ Full docs can be found at github.com/aws/amazon-` + binName // Filter Flags - These will be grouped at the top of the help flags cli.IntMinMaxRangeFlags(vcpus, cli.StringMe("c"), nil, "Number of vcpus available to the instance type.") - cli.ByteQuantityMinMaxRangeFlags(memory, cli.StringMe("m"), nil, "Amount of Memory available in GiB (Example: 4)") + cli.ByteQuantityMinMaxRangeFlags(memory, cli.StringMe("m"), nil, "Amount of Memory available (Example: 4 GiB)") cli.RatioFlag(vcpusToMemoryRatio, nil, nil, "The ratio of vcpus to memory in MiB. (Example: 1:2)") cli.StringFlag(cpuArchitecture, cli.StringMe("a"), nil, "CPU architecture [x86_64, i386, or arm64]", nil) cli.IntMinMaxRangeFlags(gpus, cli.StringMe("g"), nil, "Total Number of GPUs (Example: 4)") - cli.ByteQuantityMinMaxRangeFlags(gpuMemoryTotal, nil, nil, "Number of GPUs' total memory in GiB (Example: 4)") + cli.ByteQuantityMinMaxRangeFlags(gpuMemoryTotal, nil, nil, "Number of GPUs' total memory (Example: 4 GiB)") cli.StringFlag(placementGroupStrategy, nil, nil, "Placement group strategy: [cluster, partition, spread]", nil) cli.StringFlag(usageClass, cli.StringMe("u"), nil, "Usage class: [spot or on-demand]", nil) cli.StringFlag(rootDeviceType, nil, nil, "Supported root device types: [ebs or instance-store]", nil) diff --git a/pkg/bytequantity/bytequantity.go b/pkg/bytequantity/bytequantity.go index 6579dde..b442a2c 100644 --- a/pkg/bytequantity/bytequantity.go +++ b/pkg/bytequantity/bytequantity.go @@ -38,7 +38,8 @@ type ByteQuantity struct { Quantity uint64 } -// ParseToByteQuantity parses a string representation of a byte quantity to a ByteQuantity type +// ParseToByteQuantity parses a string representation of a byte quantity to a ByteQuantity type. +// A unit can be appended such as 16 GiB. If no unit is appended, GiB is assumed. func ParseToByteQuantity(byteQuantityStr string) (ByteQuantity, error) { bqRegexp := regexp.MustCompile(byteQuantityRegex) matches := bqRegexp.FindStringSubmatch(strings.ToLower(byteQuantityStr)) diff --git a/pkg/cli/types.go b/pkg/cli/types.go index a6efd74..09d2c11 100644 --- a/pkg/cli/types.go +++ b/pkg/cli/types.go @@ -207,7 +207,7 @@ func (*CommandLineInterface) RegexMe(i interface{}) *regexp.Regexp { } } -// ByteQuantityMe takes an interface and returns a pointer to a regex +// ByteQuantityMe takes an interface and returns a pointer to a ByteQuantity // If the underlying interface kind is not bytequantity.ByteQuantity or *bytequantity.ByteQuantity then nil is returned func (*CommandLineInterface) ByteQuantityMe(i interface{}) *bytequantity.ByteQuantity { if i == nil {