Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 60 additions & 23 deletions internal/modifier/cdi.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ func NewCDIModifier(logger logger.Interface, cfg *config.Config, image image.CUD
return nil, fmt.Errorf("requesting a CDI device with vendor 'runtime.nvidia.com' is not supported when requesting other CDI devices")
}
if len(automaticDevices) > 0 {
automaticDevices = append(automaticDevices, gatedDevices(image).DeviceRequests()...)
automaticModifier, err := newAutomaticCDISpecModifier(logger, cfg, automaticDevices)
if err == nil {
return automaticModifier, nil
Expand Down Expand Up @@ -111,6 +112,29 @@ func (c *cdiDeviceRequestor) DeviceRequests() []string {
return devices
}

// gatedDevices is a view of a CUDA image used to derive the device requests
// for optional features that are gated behind environment variables.
type gatedDevices image.CUDA

// DeviceRequests returns a "mode=<name>" request for every gated feature whose
// controlling variable, as reported by the image's Getenv, is set to "enabled".
// The result is nil when no gated feature is enabled.
func (g gatedDevices) DeviceRequests() []string {
	// The order of this table fixes the order of the returned requests.
	gates := [...]struct {
		envvar  string
		request string
	}{
		{"NVIDIA_GDS", "mode=gds"},
		{"NVIDIA_MOFED", "mode=mofed"},
		{"NVIDIA_GDRCOPY", "mode=gdrcopy"},
		{"NVIDIA_NVSWITCH", "mode=nvswitch"},
	}

	cuda := image.CUDA(g)

	var requests []string
	for _, gate := range gates {
		if cuda.Getenv(gate.envvar) != "enabled" {
			continue
		}
		requests = append(requests, gate.request)
	}
	return requests
}

// filterAutomaticDevices searches for "automatic" device names in the input slice.
// "Automatic" devices are a well-defined list of CDI device names which, when requested,
// trigger the generation of a CDI spec at runtime. This removes the need to generate a
Expand All @@ -129,35 +153,48 @@ func filterAutomaticDevices(devices []string) []string {
func newAutomaticCDISpecModifier(logger logger.Interface, cfg *config.Config, devices []string) (oci.SpecModifier, error) {
logger.Debugf("Generating in-memory CDI specs for devices %v", devices)

var identifiers []string
perModeIdentifiers := make(map[string][]string)
perModeDeviceClass := map[string]string{"auto": automaticDeviceClass}
modes := []string{"auto"}
for _, device := range devices {
identifiers = append(identifiers, strings.TrimPrefix(device, automaticDevicePrefix))
if strings.HasPrefix(device, "mode=") {
modes = append(modes, strings.TrimPrefix(device, "mode="))
continue
}
perModeIdentifiers["auto"] = append(perModeIdentifiers["auto"], strings.TrimPrefix(device, automaticDevicePrefix))
}

cdilib, err := nvcdi.New(
nvcdi.WithLogger(logger),
nvcdi.WithNVIDIACDIHookPath(cfg.NVIDIACTKConfig.Path),
nvcdi.WithDriverRoot(cfg.NVIDIAContainerCLIConfig.Root),
nvcdi.WithVendor(automaticDeviceVendor),
nvcdi.WithClass(automaticDeviceClass),
)
if err != nil {
return nil, fmt.Errorf("failed to construct CDI library: %w", err)
}
var modifiers oci.SpecModifiers
for _, mode := range modes {
cdilib, err := nvcdi.New(
nvcdi.WithLogger(logger),
nvcdi.WithNVIDIACDIHookPath(cfg.NVIDIACTKConfig.Path),
nvcdi.WithDriverRoot(cfg.NVIDIAContainerCLIConfig.Root),
nvcdi.WithVendor(automaticDeviceVendor),
nvcdi.WithClass(perModeDeviceClass[mode]),
Copy link

Copilot AI Aug 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

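Looking up perModeDeviceClass[mode] for a key that is not in the map does not panic in Go, but it silently returns the zero value (an empty string). The map only contains the "auto" key, while modes can include "gds", "mofed", "gdrcopy", and "nvswitch", so nvcdi.WithClass is called with an empty class for those modes, which may lead to unexpected behavior unless the class is defaulted elsewhere.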

Copilot uses AI. Check for mistakes.
nvcdi.WithMode(mode),
)
if err != nil {
return nil, fmt.Errorf("failed to construct CDI library for mode %q: %w", mode, err)
}

spec, err := cdilib.GetSpec(identifiers...)
if err != nil {
return nil, fmt.Errorf("failed to generate CDI spec: %w", err)
}
cdiDeviceRequestor, err := cdi.New(
cdi.WithLogger(logger),
cdi.WithSpec(spec.Raw()),
)
if err != nil {
return nil, fmt.Errorf("failed to construct CDI modifier: %w", err)
spec, err := cdilib.GetSpec(perModeIdentifiers[mode]...)
if err != nil {
return nil, fmt.Errorf("failed to generate CDI spec for mode %q: %w", mode, err)
}

cdiDeviceRequestor, err := cdi.New(
cdi.WithLogger(logger),
cdi.WithSpec(spec.Raw()),
)
if err != nil {
return nil, fmt.Errorf("failed to construct CDI modifier for mode %q: %w", mode, err)
}

modifiers = append(modifiers, cdiDeviceRequestor)
}

return cdiDeviceRequestor, nil
return modifiers, nil
}

type deduplicatedDeviceRequestor struct {
Expand Down
19 changes: 19 additions & 0 deletions internal/oci/spec.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@ type SpecModifier interface {
Modify(*specs.Spec) error
}

// SpecModifiers is a collection of OCI Spec modifiers that can be treated as a
// single modifier.
type SpecModifiers []SpecModifier

// Compile-time check that SpecModifiers satisfies the SpecModifier interface.
var _ SpecModifier = (SpecModifiers)(nil)

// Spec defines the operations to be performed on an OCI specification
//
//go:generate moq -rm -fmt=goimports -stub -out spec_mock.go . Spec
Expand All @@ -57,3 +63,16 @@ func NewSpec(logger logger.Interface, args []string) (Spec, error) {

return ociSpec, nil
}

// Modify applies every modifier in the collection to the supplied spec, in
// slice order. Nil entries are ignored, and the first error returned by a
// modifier is returned immediately without running the remaining modifiers.
func (ms SpecModifiers) Modify(s *specs.Spec) error {
	for i := range ms {
		modifier := ms[i]
		if modifier != nil {
			if err := modifier.Modify(s); err != nil {
				return err
			}
		}
	}
	return nil
}
31 changes: 23 additions & 8 deletions pkg/nvcdi/mofed.go → pkg/nvcdi/gated.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,23 +26,23 @@ import (
"github.com/NVIDIA/nvidia-container-toolkit/internal/edits"
)

type mofedlib nvcdilib
type gatedlib nvcdilib

var _ deviceSpecGeneratorFactory = (*mofedlib)(nil)
var _ deviceSpecGeneratorFactory = (*gatedlib)(nil)

func (l *mofedlib) DeviceSpecGenerators(...string) (DeviceSpecGenerator, error) {
func (l *gatedlib) DeviceSpecGenerators(...string) (DeviceSpecGenerator, error) {
return l, nil
}

// GetDeviceSpecs returns the CDI device specs for a single "all" device.
func (l *mofedlib) GetDeviceSpecs() ([]specs.Device, error) {
discoverer, err := discover.NewMOFEDDiscoverer(l.logger, l.driverRoot)
func (l *gatedlib) GetDeviceSpecs() ([]specs.Device, error) {
discoverer, err := l.getModeDiscoverer()
if err != nil {
return nil, fmt.Errorf("failed to create MOFED discoverer: %v", err)
return nil, fmt.Errorf("failed to create discoverer for mode %q: %w", l.mode, err)
}
edits, err := edits.FromDiscoverer(discoverer)
if err != nil {
return nil, fmt.Errorf("failed to create container edits for MOFED devices: %v", err)
return nil, fmt.Errorf("failed to create container edits: %w", err)
}

deviceSpec := specs.Device{
Expand All @@ -53,7 +53,22 @@ func (l *mofedlib) GetDeviceSpecs() ([]specs.Device, error) {
return []specs.Device{deviceSpec}, nil
}

func (l *gatedlib) getModeDiscoverer() (discover.Discover, error) {
switch l.mode {
case ModeGdrcopy:
return discover.NewGDRCopyDiscoverer(l.logger, l.devRoot)
case ModeGds:
return discover.NewGDSDiscoverer(l.logger, l.driverRoot, l.devRoot)
case ModeMofed:
return discover.NewMOFEDDiscoverer(l.logger, l.driverRoot)
case ModeNvswitch:
return discover.NewNvSwitchDiscoverer(l.logger, l.devRoot)
default:
return nil, fmt.Errorf("unrecognized mode")
Copy link

Copilot AI Aug 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The error message "unrecognized mode" is not helpful for debugging. It should include the actual mode value that was unrecognized.

Suggested change
return nil, fmt.Errorf("unrecognized mode")
return nil, fmt.Errorf("unrecognized mode: %q", l.mode)

Copilot uses AI. Check for mistakes.
Copy link

Copilot AI Aug 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The error message "unrecognized mode" is not helpful as it doesn't specify which mode was unrecognized. Consider including the actual mode value in the error message.

Suggested change
return nil, fmt.Errorf("unrecognized mode")
return nil, fmt.Errorf("unrecognized mode %q", l.mode)

Copilot uses AI. Check for mistakes.
}
}

// GetCommonEdits generates a CDI specification that can be used for ANY devices
func (l *mofedlib) GetCommonEdits() (*cdi.ContainerEdits, error) {
func (l *gatedlib) GetCommonEdits() (*cdi.ContainerEdits, error) {
return edits.FromDiscoverer(discover.None{})
}
59 changes: 0 additions & 59 deletions pkg/nvcdi/gds.go

This file was deleted.

11 changes: 3 additions & 8 deletions pkg/nvcdi/lib.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,16 +129,11 @@ func New(opts ...Option) (Interface, error) {
factory = (*nvmllib)(l)
case ModeWsl:
factory = (*wsllib)(l)
case ModeGds:
case ModeGdrcopy, ModeGds, ModeMofed:
if l.class == "" {
l.class = "gds"
l.class = string(l.mode)
}
factory = (*gdslib)(l)
case ModeMofed:
if l.class == "" {
l.class = "mofed"
}
factory = (*mofedlib)(l)
factory = (*gatedlib)(l)
case ModeImex:
if l.class == "" {
l.class = classImexChannel
Expand Down
18 changes: 13 additions & 5 deletions pkg/nvcdi/mode.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,19 @@ const (
ModeWsl = Mode("wsl")
// ModeManagement configures the CDI spec generator to generate a management spec.
ModeManagement = Mode("management")
// ModeGdrcopy configures the CDI spec generator to generate a GDR Copy spec.
ModeGdrcopy = Mode("gdrcopy")
// ModeGds configures the CDI spec generator to generate a GDS spec.
ModeGds = Mode("gds")
// ModeMofed configures the CDI spec generator to generate a MOFED spec.
ModeMofed = Mode("mofed")
// ModeCSV configures the CDI spec generator to generate a spec based on the contents of CSV
// mountspec files.
ModeCSV = Mode("csv")
// ModeImex configures the CDI spec generated to generate a spec for the available IMEX channels.
// ModeImex configures the CDI spec generator to generate a spec for the available IMEX channels.
ModeImex = Mode("imex")
// ModeNvswitch configures the CDI spec generator to generate a spec for the available nvswitch devices.
ModeNvswitch = Mode("nvswitch")
)

type modeConstraint interface {
Expand All @@ -60,12 +64,15 @@ func getModes() modes {
validModesOnce.Do(func() {
all := []Mode{
ModeAuto,
ModeNvml,
ModeWsl,
ModeManagement,
ModeCSV,
ModeGdrcopy,
ModeGds,
ModeImex,
ModeManagement,
ModeMofed,
ModeCSV,
ModeNvml,
ModeNvswitch,
ModeWsl,
}
lookup := make(map[Mode]bool)

Expand Down Expand Up @@ -103,6 +110,7 @@ func (l *nvcdilib) resolveMode() (rmode Mode) {
}
defer func() {
l.logger.Infof("Auto-detected mode as '%v'", rmode)
l.mode = rmode
}()

platform := l.infolib.ResolvePlatform()
Expand Down