Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion cmd/nvidia-ctk/cdi/generate/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ type options struct {
}

noAllDevice bool
deviceIDs []string

// the following are used for dependency injection during spec generation.
nvmllib nvml.Interface
Expand Down Expand Up @@ -240,6 +241,14 @@ func (m command) build() *cli.Command {
Destination: &opts.noAllDevice,
Sources: cli.EnvVars("NVIDIA_CTK_CDI_GENERATE_NO_ALL_DEVICE"),
},
&cli.StringSliceFlag{
Name: "device-id",
Aliases: []string{"device-ids", "device", "devices"},
Usage: "Restrict generation to the specified device identifiers",
Value: []string{"all"},
Destination: &opts.deviceIDs,
Sources: cli.EnvVars("NVIDIA_CTK_CDI_GENERATE_DEVICE_IDS"),
},
},
}

Expand Down Expand Up @@ -381,7 +390,7 @@ func (m command) generateSpecs(opts *options) ([]generatedSpecs, error) {
return nil, fmt.Errorf("failed to create CDI library: %v", err)
}

allDeviceSpecs, err := cdilib.GetDeviceSpecsByID("all")
allDeviceSpecs, err := cdilib.GetDeviceSpecsByID(opts.deviceIDs...)
if err != nil {
return nil, fmt.Errorf("failed to create device CDI specs: %v", err)
}
Expand Down
28 changes: 27 additions & 1 deletion cmd/nvidia-ctk/cdi/generate/generate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package generate

import (
"bytes"
"fmt"
"path/filepath"
"strings"
"testing"
Expand Down Expand Up @@ -47,6 +48,27 @@ func TestGenerateSpec(t *testing.T) {
expectedError error
expectedSpec string
}{
{
description: "invalid device id",
options: options{
format: "yaml",
mode: "nvml",
vendor: "example.com",
class: "device",
deviceIDs: []string{"99"},
driverRoot: driverRoot,
},
expectedOptions: options{
format: "yaml",
mode: "nvml",
vendor: "example.com",
class: "device",
nvidiaCDIHookPath: "/usr/bin/nvidia-cdi-hook",
deviceIDs: []string{"99"},
driverRoot: driverRoot,
},
expectedError: fmt.Errorf("failed to create device CDI specs: failed to construct device spec generators: failed to get device handle from index: ERROR_INVALID_ARGUMENT"),
},
{
description: "default",
options: options{
Expand Down Expand Up @@ -452,6 +474,10 @@ containerEdits:
for _, tc := range testCases {
// Apply overrides for all test cases:
tc.options.nvidiaCDIHookPath = "/usr/bin/nvidia-cdi-hook"
if tc.options.deviceIDs == nil {
tc.options.deviceIDs = []string{"all"}
tc.expectedOptions.deviceIDs = []string{"all"}
}

t.Run(tc.description, func(t *testing.T) {
c := command{
Expand Down Expand Up @@ -481,7 +507,7 @@ containerEdits:
tc.options.nvmllib = server

specs, err := c.generateSpecs(&tc.options)
require.ErrorIs(t, err, tc.expectedError)
require.EqualValues(t, err, tc.expectedError)

var buf bytes.Buffer
for _, spec := range specs {
Expand Down