diff --git a/cmd/nvidia-ctk/cdi/generate/generate.go b/cmd/nvidia-ctk/cdi/generate/generate.go index d112e51e9..3b4955dd1 100644 --- a/cmd/nvidia-ctk/cdi/generate/generate.go +++ b/cmd/nvidia-ctk/cdi/generate/generate.go @@ -74,6 +74,7 @@ type options struct { } noAllDevice bool + deviceIDs []string // the following are used for dependency injection during spec generation. nvmllib nvml.Interface @@ -240,6 +241,14 @@ func (m command) build() *cli.Command { Destination: &opts.noAllDevice, Sources: cli.EnvVars("NVIDIA_CTK_CDI_GENERATE_NO_ALL_DEVICE"), }, + &cli.StringSliceFlag{ + Name: "device-id", + Aliases: []string{"device-ids", "device", "devices"}, + Usage: "Restrict generation to the specified device identifiers", + Value: []string{"all"}, + Destination: &opts.deviceIDs, + Sources: cli.EnvVars("NVIDIA_CTK_CDI_GENERATE_DEVICE_IDS"), + }, }, } @@ -381,7 +390,7 @@ func (m command) generateSpecs(opts *options) ([]generatedSpecs, error) { return nil, fmt.Errorf("failed to create CDI library: %v", err) } - allDeviceSpecs, err := cdilib.GetDeviceSpecsByID("all") + allDeviceSpecs, err := cdilib.GetDeviceSpecsByID(opts.deviceIDs...) if err != nil { return nil, fmt.Errorf("failed to create device CDI specs: %v", err) } diff --git a/cmd/nvidia-ctk/cdi/generate/generate_test.go b/cmd/nvidia-ctk/cdi/generate/generate_test.go index 5e4940bfc..bff0ed2b0 100644 --- a/cmd/nvidia-ctk/cdi/generate/generate_test.go +++ b/cmd/nvidia-ctk/cdi/generate/generate_test.go @@ -18,6 +18,7 @@ package generate import ( "bytes" + "fmt" "path/filepath" "strings" "testing" @@ -47,6 +48,27 @@ func TestGenerateSpec(t *testing.T) { expectedError error expectedSpec string }{ + { + description: "invalid device id", + options: options{ + format: "yaml", + mode: "nvml", + vendor: "example.com", + class: "device", + deviceIDs: []string{"99"}, + driverRoot: driverRoot, + }, + expectedOptions: options{ + format: "yaml", + mode: "nvml", + vendor: "example.com", + class: "device", + nvidiaCDIHookPath: "/usr/bin/nvidia-cdi-hook", + deviceIDs: []string{"99"}, + driverRoot: driverRoot, + }, + expectedError: fmt.Errorf("failed to create device CDI specs: failed to construct device spec generators: failed to get device handle from index: ERROR_INVALID_ARGUMENT"), + }, { description: "default", options: options{ @@ -452,6 +474,10 @@ containerEdits: for _, tc := range testCases { // Apply overrides for all test cases: tc.options.nvidiaCDIHookPath = "/usr/bin/nvidia-cdi-hook" + if tc.options.deviceIDs == nil { + tc.options.deviceIDs = []string{"all"} + tc.expectedOptions.deviceIDs = []string{"all"} + } t.Run(tc.description, func(t *testing.T) { c := command{ @@ -481,7 +507,7 @@ containerEdits: tc.options.nvmllib = server specs, err := c.generateSpecs(&tc.options) - require.ErrorIs(t, err, tc.expectedError) + require.EqualValues(t, err, tc.expectedError) var buf bytes.Buffer for _, spec := range specs {