diff --git a/cmd/nvidia-ctk/cdi/generate/generate.go b/cmd/nvidia-ctk/cdi/generate/generate.go index 9edd8598f..b2b8cad3f 100644 --- a/cmd/nvidia-ctk/cdi/generate/generate.go +++ b/cmd/nvidia-ctk/cdi/generate/generate.go @@ -73,6 +73,10 @@ type options struct { ignorePatterns []string } + deviceIDs []string + + noAllDevice bool + // the following are used for dependency injection during spec generation. nvmllib nvml.Interface } @@ -232,6 +236,20 @@ func (m command) build() *cli.Command { Destination: &opts.featureFlags, Sources: cli.EnvVars("NVIDIA_CTK_CDI_GENERATE_FEATURE_FLAGS"), }, + &cli.StringSliceFlag{ + Name: "device-id", + Aliases: []string{"device-ids", "device", "devices"}, + Usage: "Restrict generation to the specified device identifiers", + Value: []string{"all"}, + Destination: &opts.deviceIDs, + Sources: cli.EnvVars("NVIDIA_CTK_CDI_GENERATE_DEVICE_IDS"), + }, + &cli.BoolFlag{ + Name: "no-all-device", + Usage: "Don't generate an `all` device for the resultant spec", + Destination: &opts.noAllDevice, + Sources: cli.EnvVars("NVIDIA_CTK_CDI_GENERATE_NO_ALL_DEVICE"), + }, }, } @@ -373,7 +391,7 @@ func (m command) generateSpecs(opts *options) ([]generatedSpecs, error) { return nil, fmt.Errorf("failed to create CDI library: %v", err) } - allDeviceSpecs, err := cdilib.GetDeviceSpecsByID("all") + allDeviceSpecs, err := cdilib.GetDeviceSpecsByID(opts.deviceIDs...) 
if err != nil { return nil, fmt.Errorf("failed to create device CDI specs: %v", err) } @@ -387,13 +405,18 @@ func (m command) generateSpecs(opts *options) ([]generatedSpecs, error) { spec.WithVendor(opts.vendor), spec.WithEdits(*commonEdits.ContainerEdits), spec.WithFormat(opts.format), - spec.WithMergedDeviceOptions( - transform.WithName(allDeviceName), - transform.WithSkipIfExists(true), - ), spec.WithPermissions(0644), } + if !opts.noAllDevice { + commonSpecOptions = append(commonSpecOptions, + spec.WithMergedDeviceOptions( + transform.WithName(allDeviceName), + transform.WithSkipIfExists(true), + ), + ) + } + fullSpec, err := spec.New( append(commonSpecOptions, spec.WithClass(opts.class), diff --git a/cmd/nvidia-ctk/cdi/generate/generate_test.go b/cmd/nvidia-ctk/cdi/generate/generate_test.go index 5e4940bfc..916f885ef 100644 --- a/cmd/nvidia-ctk/cdi/generate/generate_test.go +++ b/cmd/nvidia-ctk/cdi/generate/generate_test.go @@ -452,6 +452,10 @@ containerEdits: for _, tc := range testCases { // Apply overrides for all test cases: tc.options.nvidiaCDIHookPath = "/usr/bin/nvidia-cdi-hook" + if tc.options.deviceIDs == nil { + tc.options.deviceIDs = []string{"all"} + tc.expectedOptions.deviceIDs = []string{"all"} + } t.Run(tc.description, func(t *testing.T) { c := command{ diff --git a/internal/platform-support/tegra/csv.go b/internal/platform-support/tegra/csv.go index edb7fdc48..c622e91b6 100644 --- a/internal/platform-support/tegra/csv.go +++ b/internal/platform-support/tegra/csv.go @@ -17,36 +17,30 @@ package tegra import ( - "fmt" - "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup" "github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/tegra/csv" ) -// newDiscovererFromCSVFiles creates a discoverer for the specified CSV files. A logger is also supplied. 
-// The constructed discoverer is comprised of a list, with each element in the list being associated with a -// single CSV files. -func (o tegraOptions) newDiscovererFromCSVFiles() (discover.Discover, error) { - if len(o.csvFiles) == 0 { - o.logger.Warningf("No CSV files specified") +func (o options) newDiscovererFromMountSpecs() (discover.Discover, error) { + pathsByType := o.MountSpecPathsByType() + if len(pathsByType) == 0 { + o.logger.Warningf("No mount specs specified") return discover.None{}, nil } - targetsByType := getTargetsFromCSVFiles(o.logger, o.csvFiles) - devices := discover.NewCharDeviceDiscoverer( o.logger, o.devRoot, - targetsByType[csv.MountSpecDev], + pathsByType[csv.MountSpecDev], ) directories := discover.NewMounts( o.logger, lookup.NewDirectoryLocator(lookup.WithLogger(o.logger), lookup.WithRoot(o.driverRoot)), o.driverRoot, - targetsByType[csv.MountSpecDir], + pathsByType[csv.MountSpecDir], ) // We create a discoverer for mounted libraries and add additional .so @@ -57,14 +51,14 @@ func (o tegraOptions) newDiscovererFromCSVFiles() (discover.Discover, error) { o.logger, o.symlinkLocator, o.driverRoot, - targetsByType[csv.MountSpecLib], + pathsByType[csv.MountSpecLib], ), "", o.hookCreator, ) // We process the explicitly requested symlinks. - symlinkTargets := o.ignorePatterns.Apply(targetsByType[csv.MountSpecSym]...) + symlinkTargets := pathsByType[csv.MountSpecSym] o.logger.Debugf("Filtered symlink targets: %v", symlinkTargets) symlinks := discover.NewMounts( o.logger, @@ -85,35 +79,34 @@ func (o tegraOptions) newDiscovererFromCSVFiles() (discover.Discover, error) { return d, nil } -// getTargetsFromCSVFiles returns the list of mount specs from the specified CSV files. -// These are aggregated by mount spec type. -// TODO: We use a function variable here to allow this to be overridden for testing. -// This should be properly mocked. 
-var getTargetsFromCSVFiles = func(logger logger.Interface, files []string) map[csv.MountSpecType][]string { - targetsByType := make(map[csv.MountSpecType][]string) - for _, filename := range files { - targets, err := loadCSVFile(logger, filename) - if err != nil { - logger.Warningf("Skipping CSV file %v: %v", filename, err) - continue - } - for _, t := range targets { - targetsByType[t.Type] = append(targetsByType[t.Type], t.Path) - } +// MountSpecsFromCSVFiles returns a MountSpecPathsByTyper for the specified list +// of CSV files. +func MountSpecsFromCSVFiles(logger logger.Interface, csvFiles ...string) MountSpecPathsByTyper { + var tts []MountSpecPathsByTyper + + for _, filename := range csvFiles { + tts = append(tts, &fromCSVFile{logger, filename}) } - return targetsByType + return Merge(tts...) } -// loadCSVFile loads the specified CSV file and returns the list of mount specs -func loadCSVFile(logger logger.Interface, filename string) ([]*csv.MountSpec, error) { +type fromCSVFile struct { + logger logger.Interface + filename string +} + +// MountSpecPathsByType returns mountspecs defined in the specified CSV file. 
+func (t *fromCSVFile) MountSpecPathsByType() MountSpecPathsByType { // Create a discoverer for each file-kind combination - targets, err := csv.NewCSVFileParser(logger, filename).Parse() + targets, err := csv.NewCSVFileParser(t.logger, t.filename).Parse() if err != nil { - return nil, fmt.Errorf("failed to parse CSV file: %v", err) - } - if len(targets) == 0 { - return nil, fmt.Errorf("CSV file is empty") + t.logger.Warningf("failed to parse CSV file %v: %v", t.filename, err) + return nil } - return targets, nil + targetsByType := make(MountSpecPathsByType) + for _, t := range targets { + targetsByType[t.Type] = append(targetsByType[t.Type], t.Path) + } + return targetsByType } diff --git a/internal/platform-support/tegra/csv_test.go b/internal/platform-support/tegra/csv_test.go index 1fcda971b..894ee529b 100644 --- a/internal/platform-support/tegra/csv_test.go +++ b/internal/platform-support/tegra/csv_test.go @@ -24,7 +24,6 @@ import ( "github.com/stretchr/testify/require" "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" - "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup" "github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/tegra/csv" @@ -34,7 +33,7 @@ func TestDiscovererFromCSVFiles(t *testing.T) { logger, _ := testlog.NewNullLogger() testCases := []struct { description string - moutSpecs map[csv.MountSpecType][]string + moutSpecs MountSpecPathsByType ignorePatterns []string symlinkLocator lookup.Locator symlinkChainLocator lookup.Locator @@ -186,19 +185,19 @@ func TestDiscovererFromCSVFiles(t *testing.T) { hookCreator := discover.NewHookCreator() for _, tc := range testCases { t.Run(tc.description, func(t *testing.T) { - defer setGetTargetsFromCSVFiles(tc.moutSpecs)() - - o := tegraOptions{ - logger: logger, - hookCreator: hookCreator, - csvFiles: []string{"dummy"}, - ignorePatterns: tc.ignorePatterns, + o := options{ + logger: logger, + hookCreator: 
hookCreator, + MountSpecPathsByTyper: Filter( + tc.moutSpecs, + Symlinks(tc.ignorePatterns...), + ), symlinkLocator: tc.symlinkLocator, symlinkChainLocator: tc.symlinkChainLocator, resolveSymlink: tc.symlinkResolver, } - d, err := o.newDiscovererFromCSVFiles() + d, err := o.newDiscovererFromMountSpecs() require.ErrorIs(t, err, tc.expectedError) hooks, err := d.Hooks() @@ -212,14 +211,3 @@ func TestDiscovererFromCSVFiles(t *testing.T) { }) } } - -func setGetTargetsFromCSVFiles(override map[csv.MountSpecType][]string) func() { - original := getTargetsFromCSVFiles - getTargetsFromCSVFiles = func(logger logger.Interface, files []string) map[csv.MountSpecType][]string { - return override - } - - return func() { - getTargetsFromCSVFiles = original - } -} diff --git a/internal/platform-support/tegra/mount_specs.go b/internal/platform-support/tegra/mount_specs.go new file mode 100644 index 000000000..db7826514 --- /dev/null +++ b/internal/platform-support/tegra/mount_specs.go @@ -0,0 +1,166 @@ +/** +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package tegra + +import ( + "path/filepath" + "strconv" + "strings" + + "github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/tegra/csv" +) + +// A MountSpecPathsByTyper provides a function to return mount specs paths by +// mount type. 
+// The MountSpecTypes are one of: dev, dir, lib, sym and define how these should +// be included in a container (or represented in the associated CDI spec). +type MountSpecPathsByTyper interface { + MountSpecPathsByType() MountSpecPathsByType +} + +type MountSpecPathsByType map[csv.MountSpecType][]string + +var _ MountSpecPathsByTyper = (MountSpecPathsByType)(nil) + +// MountSpecPathsByType for a variable of type MountSpecPathsByType returns the +// underlying data structure. +// This allows for using this type in functions such as Merge and Filter. +func (m MountSpecPathsByType) MountSpecPathsByType() MountSpecPathsByType { + return m +} + +type merge []MountSpecPathsByTyper + +// Merge combines the MountSpecPathsByType for the specified sources. +func Merge(sources ...MountSpecPathsByTyper) MountSpecPathsByTyper { + return merge(sources) +} + +// MountSpecPathsByType for a set of merged mount specs combines the list of +// paths per type. +func (ts merge) MountSpecPathsByType() MountSpecPathsByType { + targetsByType := make(MountSpecPathsByType) + for _, t := range ts { + if t == nil { + continue + } + for tType, targets := range t.MountSpecPathsByType() { + targetsByType[tType] = append(targetsByType[tType], targets...) + } + } + return targetsByType +} + +type filterMountSpecs struct { + from MountSpecPathsByTyper + remove MountSpecPathsByTyper +} + +// Filter removes the specified MountSpecPaths (by type) from the specified +// set of MountSpecPaths. +// Here the paths in the remove set are treated as patterns, and elements in +// from that match any specified pattern are filtered out. +func Filter(from MountSpecPathsByTyper, remove MountSpecPathsByTyper) MountSpecPathsByTyper { + return filterMountSpecs{ + from: from, + remove: remove, + } +} + +// MountSpecPathsByType for a filter get the mountspecs defined in the source +// and apply the specified per-type filters. 
+func (m filterMountSpecs) MountSpecPathsByType() MountSpecPathsByType { + ms := m.from.MountSpecPathsByType() + if len(ms) == 0 { + return ms + } + + for t, patterns := range m.remove.MountSpecPathsByType() { + paths := ms[t] + if len(paths) == 0 { + continue + } + filtered := ignoreMountSpecPatterns(patterns).Apply(paths...) + ms[t] = filtered + } + + return ms +} + +type stripDeviceNodes struct { + from MountSpecPathsByTyper +} + +// WithoutRegularDeviceNodes creates a MountSpecPathsByTyper which removes +// regular `/dev/nvidia[0-9]+` device nodes from the source. +func WithoutRegularDeviceNodes(from MountSpecPathsByTyper) MountSpecPathsByTyper { + return &stripDeviceNodes{from} +} + +// MountSpecPathsByType returns the source mount specs with regular nvidia +// device nodes removed from the source. +func (d *stripDeviceNodes) MountSpecPathsByType() MountSpecPathsByType { + ms := d.from.MountSpecPathsByType() + if len(ms) == 0 { + return ms + } + + filtered := d.Apply(ms[csv.MountSpecDev]...) + ms[csv.MountSpecDev] = filtered + + return ms +} + +func (d *stripDeviceNodes) Apply(input ...string) []string { + var filtered []string + for _, name := range input { + if d.Match(name) { + continue + } + filtered = append(filtered, name) + } + return filtered +} + +// Match returns true if name is a REGULAR NVIDIA GPU device node. +func (d *stripDeviceNodes) Match(name string) bool { + pattern := "/dev/nvidia*" + if match, _ := filepath.Match(pattern, name); !match { + return false + } + suffix := strings.TrimPrefix(name, "/dev/nvidia") + // Check whether path has the form /dev/nvidia%d + _, err := strconv.Atoi(suffix) + return err == nil +} + +// DeviceNodes creates a set of MountSpecPaths for the specified device nodes. +// These have the MountSpecDev type. +func DeviceNodes(dn ...string) MountSpecPathsByTyper { + return MountSpecPathsByType{ + csv.MountSpecDev: dn, + } +} + +// Symlinks creates a set of MountSpecPaths for the specified symlinks. 
+// These have the MountSpecSym type. +func Symlinks(s ...string) MountSpecPathsByTyper { + return MountSpecPathsByType{ + csv.MountSpecSym: s, + } +} diff --git a/internal/platform-support/tegra/options.go b/internal/platform-support/tegra/options.go new file mode 100644 index 000000000..8a005a80c --- /dev/null +++ b/internal/platform-support/tegra/options.go @@ -0,0 +1,105 @@ +/** +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package tegra + +import ( + "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" + "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" + "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup" +) + +type options struct { + logger logger.Interface + driverRoot string + devRoot string + hookCreator discover.HookCreator + ldconfigPath string + librarySearchPaths []string + + // The following can be overridden for testing + symlinkLocator lookup.Locator + symlinkChainLocator lookup.Locator + // TODO: This should be replaced by a regular mock + resolveSymlink func(string) (string, error) + + MountSpecPathsByTyper +} + +// Option defines a functional option for configuring a Tegra discoverer. +type Option func(*options) + +// WithLogger sets the logger for the discoverer. 
+func WithLogger(logger logger.Interface) Option { + return func(o *options) { + o.logger = logger + } +} + +// WithDriverRoot sets the driver root for the discoverer. +func WithDriverRoot(driverRoot string) Option { + return func(o *options) { + o.driverRoot = driverRoot + } +} + +// WithDevRoot sets the /dev root. +// If this is unset, the driver root is assumed. +func WithDevRoot(devRoot string) Option { + return func(o *options) { + o.devRoot = devRoot + } +} + +// WithHookCreator sets the hook creator for the discoverer. +func WithHookCreator(hookCreator discover.HookCreator) Option { + return func(o *options) { + o.hookCreator = hookCreator + } +} + +// WithLdconfigPath sets the path to the ldconfig program. +func WithLdconfigPath(ldconfigPath string) Option { + return func(o *options) { + o.ldconfigPath = ldconfigPath + } +} + +// WithLibrarySearchPaths sets the library search paths for the discoverer. +func WithLibrarySearchPaths(librarySearchPaths ...string) Option { + return func(o *options) { + o.librarySearchPaths = librarySearchPaths + } +} + +// WithMountSpecsByPath sets the source of MountSpec paths per type. +// If multiple values are supplied, these are merged. +func WithMountSpecsByPath(msfp ...MountSpecPathsByTyper) Option { + return func(o *options) { + o.MountSpecPathsByTyper = Merge(msfp...) + } +} + +// MountSpecPathsByType returns the mount spec paths by type configured for these +// options. +// For an unconfigured MountSpecPathsByTyper no mount specs are returned. 
+func (o options) MountSpecPathsByType() MountSpecPathsByType { + if o.MountSpecPathsByTyper == nil { + return nil + } + return o.MountSpecPathsByTyper.MountSpecPathsByType() +} diff --git a/internal/platform-support/tegra/symlinks.go b/internal/platform-support/tegra/symlinks.go index 822d482fd..00e664a19 100644 --- a/internal/platform-support/tegra/symlinks.go +++ b/internal/platform-support/tegra/symlinks.go @@ -36,7 +36,7 @@ type symlinkHook struct { } // createCSVSymlinkHooks creates a discoverer for a hook that creates required symlinks in the container -func (o tegraOptions) createCSVSymlinkHooks(targets []string) discover.Discover { +func (o options) createCSVSymlinkHooks(targets []string) discover.Discover { return symlinkHook{ logger: o.logger, hookCreator: o.hookCreator, diff --git a/internal/platform-support/tegra/tegra.go b/internal/platform-support/tegra/tegra.go index 6ad774b4e..229af05f5 100644 --- a/internal/platform-support/tegra/tegra.go +++ b/internal/platform-support/tegra/tegra.go @@ -20,34 +20,13 @@ import ( "fmt" "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" - "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup" "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/symlinks" ) -type tegraOptions struct { - logger logger.Interface - csvFiles []string - driverRoot string - devRoot string - hookCreator discover.HookCreator - ldconfigPath string - librarySearchPaths []string - ignorePatterns ignoreMountSpecPatterns - - // The following can be overridden for testing - symlinkLocator lookup.Locator - symlinkChainLocator lookup.Locator - // TODO: This should be replaced by a regular mock - resolveSymlink func(string) (string, error) -} - -// Option defines a functional option for configuring a Tegra discoverer. -type Option func(*tegraOptions) - -// New creates a new tegra discoverer using the supplied options. 
+// New creates a new tegra discoverer using the supplied functional options. func New(opts ...Option) (discover.Discover, error) { - o := &tegraOptions{} + o := &options{} for _, opt := range opts { opt(o) } @@ -75,12 +54,12 @@ func New(opts ...Option) (discover.Discover, error) { o.resolveSymlink = symlinks.Resolve } - csvDiscoverer, err := o.newDiscovererFromCSVFiles() + mountSpecDiscoverer, err := o.newDiscovererFromMountSpecs() if err != nil { return nil, fmt.Errorf("failed to create CSV discoverer: %v", err) } - ldcacheUpdateHook, err := discover.NewLDCacheUpdateHook(o.logger, csvDiscoverer, o.hookCreator, o.ldconfigPath) + ldcacheUpdateHook, err := discover.NewLDCacheUpdateHook(o.logger, mountSpecDiscoverer, o.hookCreator, o.ldconfigPath) if err != nil { return nil, fmt.Errorf("failed to create ldcach update hook discoverer: %v", err) } @@ -95,7 +74,7 @@ func New(opts ...Option) (discover.Discover, error) { ) d := discover.Merge( - csvDiscoverer, + mountSpecDiscoverer, // The ldcacheUpdateHook is added last to ensure that the created symlinks are included ldcacheUpdateHook, tegraSystemMounts, @@ -103,60 +82,3 @@ func New(opts ...Option) (discover.Discover, error) { return d, nil } - -// WithLogger sets the logger for the discoverer. -func WithLogger(logger logger.Interface) Option { - return func(o *tegraOptions) { - o.logger = logger - } -} - -// WithDriverRoot sets the driver root for the discoverer. -func WithDriverRoot(driverRoot string) Option { - return func(o *tegraOptions) { - o.driverRoot = driverRoot - } -} - -// WithDevRoot sets the /dev root. -// If this is unset, the driver root is assumed. -func WithDevRoot(devRoot string) Option { - return func(o *tegraOptions) { - o.devRoot = devRoot - } -} - -// WithCSVFiles sets the CSV files for the discoverer. -func WithCSVFiles(csvFiles []string) Option { - return func(o *tegraOptions) { - o.csvFiles = csvFiles - } -} - -// WithHookCreator sets the hook creator for the discoverer. 
-func WithHookCreator(hookCreator discover.HookCreator) Option { - return func(o *tegraOptions) { - o.hookCreator = hookCreator - } -} - -// WithLdconfigPath sets the path to the ldconfig program -func WithLdconfigPath(ldconfigPath string) Option { - return func(o *tegraOptions) { - o.ldconfigPath = ldconfigPath - } -} - -// WithLibrarySearchPaths sets the library search paths for the discoverer. -func WithLibrarySearchPaths(librarySearchPaths ...string) Option { - return func(o *tegraOptions) { - o.librarySearchPaths = librarySearchPaths - } -} - -// WithIngorePatterns sets patterns to ignore in the CSV files -func WithIngorePatterns(ignorePatterns ...string) Option { - return func(o *tegraOptions) { - o.ignorePatterns = ignoreMountSpecPatterns(ignorePatterns) - } -} diff --git a/pkg/nvcdi/api.go b/pkg/nvcdi/api.go index fce32bc88..14cbdb83f 100644 --- a/pkg/nvcdi/api.go +++ b/pkg/nvcdi/api.go @@ -88,4 +88,8 @@ const ( // FeatureEnableCoherentAnnotations enables the addition of annotations // coherent or non-coherent devices. FeatureEnableCoherentAnnotations = FeatureFlag("enable-coherent-annotations") + + // FeatureDisableMultipleCSVDevices disables the handling of multiple devices + // in CSV mode. 
+ FeatureDisableMultipleCSVDevices = FeatureFlag("disable-multiple-csv-devices") ) diff --git a/pkg/nvcdi/lib-csv.go b/pkg/nvcdi/lib-csv.go index 6380d79dc..d2cc06bbd 100644 --- a/pkg/nvcdi/lib-csv.go +++ b/pkg/nvcdi/lib-csv.go @@ -18,10 +18,17 @@ package nvcdi import ( "fmt" + "slices" + "strconv" + "strings" "tags.cncf.io/container-device-interface/pkg/cdi" "tags.cncf.io/container-device-interface/specs-go" + "github.com/NVIDIA/go-nvlib/pkg/nvlib/device" + "github.com/NVIDIA/go-nvml/pkg/nvml" + "github.com/google/uuid" + "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" "github.com/NVIDIA/nvidia-container-toolkit/internal/edits" "github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/tegra" @@ -29,9 +36,31 @@ import ( type csvlib nvcdilib +type mixedcsvlib nvcdilib + var _ deviceSpecGeneratorFactory = (*csvlib)(nil) +// DeviceSpecGenerators creates a set of generators for the specified set of +// devices. +// If NVML is not available or the disable-multiple-csv-devices feature flag is +// enabled, a single device is assumed. func (l *csvlib) DeviceSpecGenerators(ids ...string) (DeviceSpecGenerator, error) { + if l.featureFlags[FeatureDisableMultipleCSVDevices] { + return l.purecsvDeviceSpecGenerators(ids...) + } + hasNVML, _ := l.infolib.HasNvml() + if !hasNVML { + return l.purecsvDeviceSpecGenerators(ids...) + } + mixed, err := l.mixedDeviceSpecGenerators(ids...) + if err != nil { + l.logger.Warningf("Failed to create mixed CSV spec generator; falling back to pure CSV implementation: %v", err) + return l.purecsvDeviceSpecGenerators(ids...) 
+ } + return mixed, nil +} + +func (l *csvlib) purecsvDeviceSpecGenerators(ids ...string) (DeviceSpecGenerator, error) { for _, id := range ids { switch id { case "all": @@ -40,21 +69,59 @@ func (l *csvlib) DeviceSpecGenerators(ids ...string) (DeviceSpecGenerator, error return nil, fmt.Errorf("unsupported device id: %v", id) } } + g := &csvDeviceGenerator{ + csvlib: l, + index: 0, + uuid: "", + } + return g, nil +} + +func (l *csvlib) mixedDeviceSpecGenerators(ids ...string) (DeviceSpecGenerator, error) { + return (*mixedcsvlib)(l).DeviceSpecGenerators(ids...) +} - return l, nil +// A csvDeviceGenerator generates CDI specs for a device based on a set of +// platform-specific CSV files. +type csvDeviceGenerator struct { + *csvlib + index int + uuid string + onlyDeviceNodes []string + additionalDeviceNodes []string +} + +func (l *csvDeviceGenerator) GetUUID() (string, error) { + return l.uuid, nil } // GetDeviceSpecs returns the CDI device specs for a single device. -func (l *csvlib) GetDeviceSpecs() ([]specs.Device, error) { +func (l *csvDeviceGenerator) GetDeviceSpecs() ([]specs.Device, error) { + mountSpecs := tegra.MountSpecsFromCSVFiles(l.logger, l.csvFiles...) 
+ if len(l.onlyDeviceNodes) > 0 { + mountSpecs = tegra.Merge( + tegra.WithoutRegularDeviceNodes(mountSpecs), + tegra.DeviceNodes(l.onlyDeviceNodes...), + ) + } d, err := tegra.New( tegra.WithLogger(l.logger), tegra.WithDriverRoot(l.driverRoot), tegra.WithDevRoot(l.devRoot), tegra.WithHookCreator(l.hookCreator), tegra.WithLdconfigPath(l.ldconfigPath), - tegra.WithCSVFiles(l.csvFiles), tegra.WithLibrarySearchPaths(l.librarySearchPaths...), - tegra.WithIngorePatterns(l.csvIgnorePatterns...), + tegra.WithMountSpecsByPath( + tegra.Filter( + tegra.Merge( + mountSpecs, + tegra.DeviceNodes(l.additionalDeviceNodes...), + ), + tegra.Merge( + tegra.Symlinks(l.csvIgnorePatterns...), + ), + ), + ), ) if err != nil { return nil, fmt.Errorf("failed to create discoverer for CSV files: %v", err) @@ -64,7 +131,7 @@ func (l *csvlib) GetDeviceSpecs() ([]specs.Device, error) { return nil, fmt.Errorf("failed to create container edits for CSV files: %v", err) } - names, err := l.deviceNamers.GetDeviceNames(0, uuidIgnored{}) + names, err := l.deviceNamers.GetDeviceNames(l.index, l) if err != nil { return nil, fmt.Errorf("failed to get device name: %v", err) } @@ -84,3 +151,145 @@ func (l *csvlib) GetDeviceSpecs() ([]specs.Device, error) { func (l *csvlib) GetCommonEdits() (*cdi.ContainerEdits, error) { return edits.FromDiscoverer(discover.None{}) } + +func (l *mixedcsvlib) DeviceSpecGenerators(ids ...string) (DeviceSpecGenerator, error) { + asNvmlLib := (*nvmllib)(l) + err := asNvmlLib.init() + if err != nil { + return nil, fmt.Errorf("failed to initialize nvml: %w", err) + } + defer asNvmlLib.tryShutdown() + + if slices.Contains(ids, "all") { + ids, err = l.getAllDeviceIndices() + if err != nil { + return nil, fmt.Errorf("failed to get device indices: %w", err) + } + } + + var DeviceSpecGenerators DeviceSpecGenerators + for _, id := range ids { + generator, err := l.deviceSpecGeneratorForId(device.Identifier(id)) + if err != nil { + return nil, fmt.Errorf("failed to create device spec 
generator for device %q: %w", id, err) + } + DeviceSpecGenerators = append(DeviceSpecGenerators, generator) + } + + return DeviceSpecGenerators, nil +} + +func (l *mixedcsvlib) getAllDeviceIndices() ([]string, error) { + numDevices, ret := l.nvmllib.DeviceGetCount() + if ret != nvml.SUCCESS { + return nil, fmt.Errorf("faled to get device count: %v", ret) + } + + var allIndices []string + for index := range numDevices { + allIndices = append(allIndices, fmt.Sprintf("%d", index)) + } + return allIndices, nil +} + +func (l *mixedcsvlib) deviceSpecGeneratorForId(id device.Identifier) (DeviceSpecGenerator, error) { + switch { + case id.IsGpuUUID(), isIntegratedGPUID(id): + uuid := string(id) + device, ret := l.nvmllib.DeviceGetHandleByUUID(uuid) + if ret != nvml.SUCCESS { + return nil, fmt.Errorf("failed to get device handle from UUID %q: %v", uuid, ret) + } + index, ret := device.GetIndex() + if ret != nvml.SUCCESS { + return nil, fmt.Errorf("failed to get device index: %v", ret) + } + return l.csvDeviceSpecGenerator(index, uuid, device) + case id.IsGpuIndex(): + index, err := strconv.Atoi(string(id)) + if err != nil { + return nil, fmt.Errorf("failed to convert device index to an int: %w", err) + } + device, ret := l.nvmllib.DeviceGetHandleByIndex(index) + if ret != nvml.SUCCESS { + return nil, fmt.Errorf("failed to get device handle from index: %v", ret) + } + uuid, ret := device.GetUUID() + if ret != nvml.SUCCESS { + return nil, fmt.Errorf("failed to get UUID: %v", ret) + } + return l.csvDeviceSpecGenerator(index, uuid, device) + case id.IsMigUUID(): + fallthrough + case id.IsMigIndex(): + return nil, fmt.Errorf("generating a CDI spec for MIG id %q is not supported in CSV mode", id) + } + return nil, fmt.Errorf("identifier is not a valid UUID or index: %q", id) +} + +func (l *mixedcsvlib) csvDeviceSpecGenerator(index int, uuid string, device nvml.Device) (DeviceSpecGenerator, error) { + var additionalDeviceNodes []string + isIntegrated, err := 
isIntegratedGPU(device) + if err != nil { + return nil, fmt.Errorf("is-integrated check failed for device (index=%v,uuid=%v)", index, uuid) + } + if !isIntegrated { + additionalDeviceNodes = []string{ + "/dev/nvidia-uvm", + "/dev/nvidia-uvm-tools", + } + } + g := &csvDeviceGenerator{ + csvlib: (*csvlib)(l), + index: index, + uuid: uuid, + onlyDeviceNodes: []string{fmt.Sprintf("/dev/nvidia%d", index)}, + additionalDeviceNodes: additionalDeviceNodes, + } + return g, nil +} + +func isIntegratedGPUID(id device.Identifier) bool { + _, err := uuid.Parse(string(id)) + return err == nil +} + +// isIntegratedGPU checks whether the specified device is an integrated GPU. +// As a proxy we check the PCI info (domain, bus, and device) for the device. +// TODO: This should be replaced by an explicit NVML call once available. +func isIntegratedGPU(d nvml.Device) (bool, error) { + pciInfo, ret := d.GetPciInfo() + if ret == nvml.ERROR_NOT_SUPPORTED { + name, ret := d.GetName() + if ret != nvml.SUCCESS { + return false, fmt.Errorf("failed to get device name: %v", ret) + } + return isIntegratedGPUName(name), nil + } + if ret != nvml.SUCCESS { + return false, fmt.Errorf("failed to get PCI info: %v", ret) + } + + if pciInfo.Domain != 0 { + return false, nil + } + if pciInfo.Bus != 1 { + return false, nil + } + return pciInfo.Device == 0, nil +} + +// isIntegratedGPUName returns true if the specified device name is associated +// with a known iGPU. +// +// TODO: Consider making go-nvlib/pkg/nvlib/info/isIntegratedGPUName public +// instead. 
+func isIntegratedGPUName(name string) bool { + if strings.Contains(name, "(nvgpu)") { + return true + } + if strings.Contains(name, "NVIDIA Thor") { + return true + } + return false +} diff --git a/pkg/nvcdi/namer.go b/pkg/nvcdi/namer.go index 8019f699e..8ebdd33b4 100644 --- a/pkg/nvcdi/namer.go +++ b/pkg/nvcdi/namer.go @@ -105,12 +105,6 @@ type convert struct { nvmlUUIDer } -type uuidIgnored struct{} - -func (m uuidIgnored) GetUUID() (string, error) { - return "", nil -} - type uuidUnsupported struct{} func (m convert) GetUUID() (string, error) {