diff --git a/internal/modifier/cdi.go b/internal/modifier/cdi.go index 2a73a0543..aefa65c7a 100644 --- a/internal/modifier/cdi.go +++ b/internal/modifier/cdi.go @@ -62,6 +62,7 @@ func NewCDIModifier(logger logger.Interface, cfg *config.Config, image image.CUD return nil, fmt.Errorf("requesting a CDI device with vendor 'runtime.nvidia.com' is not supported when requesting other CDI devices") } if len(automaticDevices) > 0 { + automaticDevices = append(automaticDevices, gatedDevices(image).DeviceRequests()...) automaticModifier, err := newAutomaticCDISpecModifier(logger, cfg, automaticDevices) if err == nil { return automaticModifier, nil @@ -111,6 +112,29 @@ func (c *cdiDeviceRequestor) DeviceRequests() []string { return devices } +type gatedDevices image.CUDA + +// DeviceRequests returns a list of devices that are required for gated devices. +func (g gatedDevices) DeviceRequests() []string { + i := (image.CUDA)(g) + + var devices []string + if i.Getenv("NVIDIA_GDS") == "enabled" { + devices = append(devices, "mode=gds") + } + if i.Getenv("NVIDIA_MOFED") == "enabled" { + devices = append(devices, "mode=mofed") + } + if i.Getenv("NVIDIA_GDRCOPY") == "enabled" { + devices = append(devices, "mode=gdrcopy") + } + if i.Getenv("NVIDIA_NVSWITCH") == "enabled" { + devices = append(devices, "mode=nvswitch") + } + + return devices +} + // filterAutomaticDevices searches for "automatic" device names in the input slice. // "Automatic" devices are a well-defined list of CDI device names which, when requested, // trigger the generation of a CDI spec at runtime. This removes the need to generate a @@ -129,35 +153,48 @@ func filterAutomaticDevices(devices []string) []string { func newAutomaticCDISpecModifier(logger logger.Interface, cfg *config.Config, devices []string) (oci.SpecModifier, error) { logger.Debugf("Generating in-memory CDI specs for devices %v", devices) - var identifiers []string + perModeIdentifiers := make(map[string][]string) + perModeDeviceClass := map[string]string{"auto": automaticDeviceClass} + modes := []string{"auto"} for _, device := range devices { - identifiers = append(identifiers, strings.TrimPrefix(device, automaticDevicePrefix)) + if strings.HasPrefix(device, "mode=") { + modes = append(modes, strings.TrimPrefix(device, "mode=")) + continue + } + perModeIdentifiers["auto"] = append(perModeIdentifiers["auto"], strings.TrimPrefix(device, automaticDevicePrefix)) } - cdilib, err := nvcdi.New( - nvcdi.WithLogger(logger), - nvcdi.WithNVIDIACDIHookPath(cfg.NVIDIACTKConfig.Path), - nvcdi.WithDriverRoot(cfg.NVIDIAContainerCLIConfig.Root), - nvcdi.WithVendor(automaticDeviceVendor), - nvcdi.WithClass(automaticDeviceClass), - ) - if err != nil { - return nil, fmt.Errorf("failed to construct CDI library: %w", err) - } + var modifiers oci.SpecModifiers + for _, mode := range modes { + cdilib, err := nvcdi.New( + nvcdi.WithLogger(logger), + nvcdi.WithNVIDIACDIHookPath(cfg.NVIDIACTKConfig.Path), + nvcdi.WithDriverRoot(cfg.NVIDIAContainerCLIConfig.Root), + nvcdi.WithVendor(automaticDeviceVendor), + nvcdi.WithClass(perModeDeviceClass[mode]), + nvcdi.WithMode(mode), + ) + if err != nil { + return nil, fmt.Errorf("failed to construct CDI library for mode %q: %w", mode, err) + } - spec, err := cdilib.GetSpec(identifiers...) - if err != nil { - return nil, fmt.Errorf("failed to generate CDI spec: %w", err) - } - cdiDeviceRequestor, err := cdi.New( - cdi.WithLogger(logger), - cdi.WithSpec(spec.Raw()), - ) - if err != nil { - return nil, fmt.Errorf("failed to construct CDI modifier: %w", err) + spec, err := cdilib.GetSpec(perModeIdentifiers[mode]...) + if err != nil { + return nil, fmt.Errorf("failed to generate CDI spec for mode %q: %w", mode, err) + } + + cdiDeviceRequestor, err := cdi.New( + cdi.WithLogger(logger), + cdi.WithSpec(spec.Raw()), + ) + if err != nil { + return nil, fmt.Errorf("failed to construct CDI modifier for mode %q: %w", mode, err) + } + + modifiers = append(modifiers, cdiDeviceRequestor) } - return cdiDeviceRequestor, nil + return modifiers, nil } type deduplicatedDeviceRequestor struct { diff --git a/internal/oci/spec.go b/internal/oci/spec.go index 11cac3e7b..1e2c144a7 100644 --- a/internal/oci/spec.go +++ b/internal/oci/spec.go @@ -31,6 +31,12 @@ type SpecModifier interface { Modify(*specs.Spec) error } +// SpecModifiers is a collection of OCI Spec modifiers that can be treated as a +// single modifier. +type SpecModifiers []SpecModifier + +var _ SpecModifier = (SpecModifiers)(nil) + // Spec defines the operations to be performed on an OCI specification // //go:generate moq -rm -fmt=goimports -stub -out spec_mock.go . Spec @@ -57,3 +63,16 @@ func NewSpec(logger logger.Interface, args []string) (Spec, error) { return ociSpec, nil } + +// Modify a spec based on a collection of modifiers. +func (ms SpecModifiers) Modify(s *specs.Spec) error { + for _, m := range ms { + if m == nil { + continue + } + if err := m.Modify(s); err != nil { + return err + } + } + return nil +} diff --git a/pkg/nvcdi/mofed.go b/pkg/nvcdi/gated.go similarity index 59% rename from pkg/nvcdi/mofed.go rename to pkg/nvcdi/gated.go index 2169eac23..853e4d04a 100644 --- a/pkg/nvcdi/mofed.go +++ b/pkg/nvcdi/gated.go @@ -26,23 +26,23 @@ import ( "github.com/NVIDIA/nvidia-container-toolkit/internal/edits" ) -type mofedlib nvcdilib +type gatedlib nvcdilib -var _ deviceSpecGeneratorFactory = (*mofedlib)(nil) +var _ deviceSpecGeneratorFactory = (*gatedlib)(nil) -func (l *mofedlib) DeviceSpecGenerators(...string) (DeviceSpecGenerator, error) { +func (l *gatedlib) DeviceSpecGenerators(...string) (DeviceSpecGenerator, error) { return l, nil } // GetDeviceSpecs returns the CDI device specs for a single all device. -func (l *mofedlib) GetDeviceSpecs() ([]specs.Device, error) { - discoverer, err := discover.NewMOFEDDiscoverer(l.logger, l.driverRoot) +func (l *gatedlib) GetDeviceSpecs() ([]specs.Device, error) { + discoverer, err := l.getModeDiscoverer() if err != nil { - return nil, fmt.Errorf("failed to create MOFED discoverer: %v", err) + return nil, fmt.Errorf("failed to create discoverer for mode %q: %w", l.mode, err) } edits, err := edits.FromDiscoverer(discoverer) if err != nil { - return nil, fmt.Errorf("failed to create container edits for MOFED devices: %v", err) + return nil, fmt.Errorf("failed to create container edits: %w", err) } deviceSpec := specs.Device{ @@ -53,7 +53,22 @@ func (l *mofedlib) GetDeviceSpecs() ([]specs.Device, error) { return []specs.Device{deviceSpec}, nil } +func (l *gatedlib) getModeDiscoverer() (discover.Discover, error) { + switch l.mode { + case ModeGdrcopy: + return discover.NewGDRCopyDiscoverer(l.logger, l.devRoot) + case ModeGds: + return discover.NewGDSDiscoverer(l.logger, l.driverRoot, l.devRoot) + case ModeMofed: + return discover.NewMOFEDDiscoverer(l.logger, l.driverRoot) + case ModeNvswitch: + return discover.NewNvSwitchDiscoverer(l.logger, l.devRoot) + default: + return nil, fmt.Errorf("unrecognized mode") + } +} + // GetCommonEdits generates a CDI specification that can be used for ANY devices -func (l *mofedlib) GetCommonEdits() (*cdi.ContainerEdits, error) { +func (l *gatedlib) GetCommonEdits() (*cdi.ContainerEdits, error) { return edits.FromDiscoverer(discover.None{}) } diff --git a/pkg/nvcdi/gds.go b/pkg/nvcdi/gds.go deleted file mode 100644 index ad7cf6501..000000000 --- a/pkg/nvcdi/gds.go +++ /dev/null @@ -1,59 +0,0 @@ -/** -# Copyright (c) NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -**/ - -package nvcdi - -import ( - "fmt" - - "tags.cncf.io/container-device-interface/pkg/cdi" - "tags.cncf.io/container-device-interface/specs-go" - - "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" - "github.com/NVIDIA/nvidia-container-toolkit/internal/edits" -) - -type gdslib nvcdilib - -var _ deviceSpecGeneratorFactory = (*gdslib)(nil) - -func (l *gdslib) DeviceSpecGenerators(...string) (DeviceSpecGenerator, error) { - return l, nil -} - -// GetDeviceSpecs returns the CDI device specs for a single all device. -func (l *gdslib) GetDeviceSpecs() ([]specs.Device, error) { - discoverer, err := discover.NewGDSDiscoverer(l.logger, l.driverRoot, l.devRoot) - if err != nil { - return nil, fmt.Errorf("failed to create GPUDirect Storage discoverer: %v", err) - } - edits, err := edits.FromDiscoverer(discoverer) - if err != nil { - return nil, fmt.Errorf("failed to create container edits for GPUDirect Storage: %v", err) - } - - deviceSpec := specs.Device{ - Name: "all", - ContainerEdits: *edits.ContainerEdits, - } - - return []specs.Device{deviceSpec}, nil -} - -// GetCommonEdits generates a CDI specification that can be used for ANY devices -func (l *gdslib) GetCommonEdits() (*cdi.ContainerEdits, error) { - return edits.FromDiscoverer(discover.None{}) -} diff --git a/pkg/nvcdi/lib.go b/pkg/nvcdi/lib.go index 4b5b0e26f..781b15507 100644 --- a/pkg/nvcdi/lib.go +++ b/pkg/nvcdi/lib.go @@ -129,16 +129,11 @@ func New(opts ...Option) (Interface, error) { factory = (*nvmllib)(l) case ModeWsl: factory = (*wsllib)(l) - case ModeGds: + case ModeGdrcopy, ModeGds, ModeMofed: if l.class == "" { - l.class = "gds" + l.class = string(l.mode) } - factory = (*gdslib)(l) - case ModeMofed: - if l.class == "" { - l.class = "mofed" - } - factory = (*mofedlib)(l) + factory = (*gatedlib)(l) case ModeImex: if l.class == "" { l.class = classImexChannel diff --git a/pkg/nvcdi/mode.go b/pkg/nvcdi/mode.go index 5b8f0369e..a68170ece 100644 --- a/pkg/nvcdi/mode.go +++ b/pkg/nvcdi/mode.go @@ -33,6 +33,8 @@ const ( ModeWsl = Mode("wsl") // ModeManagement configures the CDI spec generator to generate a management spec. ModeManagement = Mode("management") + // ModeGdrcopy configures the CDI spec generator to generate a GDR Copy spec. + ModeGdrcopy = Mode("gdrcopy") // ModeGds configures the CDI spec generator to generate a GDS spec. ModeGds = Mode("gds") // ModeMofed configures the CDI spec generator to generate a MOFED spec. @@ -40,8 +42,10 @@ const ( // ModeCSV configures the CDI spec generator to generate a spec based on the contents of CSV // mountspec files. ModeCSV = Mode("csv") - // ModeImex configures the CDI spec generated to generate a spec for the available IMEX channels. + // ModeImex configures the CDI spec generator to generate a spec for the available IMEX channels. ModeImex = Mode("imex") + // ModeNvswitch configures the CDI spec generator to generate a spec for the available nvswitch devices. + ModeNvswitch = Mode("nvswitch") ) type modeConstraint interface { @@ -60,12 +64,15 @@ func getModes() modes { validModesOnce.Do(func() { all := []Mode{ ModeAuto, - ModeNvml, - ModeWsl, - ModeManagement, + ModeCSV, + ModeGdrcopy, ModeGds, + ModeImex, + ModeManagement, ModeMofed, - ModeCSV, + ModeNvml, + ModeNvswitch, + ModeWsl, } lookup := make(map[Mode]bool) @@ -103,6 +110,7 @@ func (l *nvcdilib) resolveMode() (rmode Mode) { } defer func() { l.logger.Infof("Auto-detected mode as '%v'", rmode) + l.mode = rmode }() platform := l.infolib.ResolvePlatform()