-
Notifications
You must be signed in to change notification settings - Fork 439
Add support for gated modifications jit-cdi mode #1230
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
8085d40
4ac1158
6412bca
447ff50
3fcc351
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -62,6 +62,7 @@ func NewCDIModifier(logger logger.Interface, cfg *config.Config, image image.CUD | |
| return nil, fmt.Errorf("requesting a CDI device with vendor 'runtime.nvidia.com' is not supported when requesting other CDI devices") | ||
| } | ||
| if len(automaticDevices) > 0 { | ||
| automaticDevices = append(automaticDevices, gatedDevices(image).DeviceRequests()...) | ||
| automaticModifier, err := newAutomaticCDISpecModifier(logger, cfg, automaticDevices) | ||
| if err == nil { | ||
| return automaticModifier, nil | ||
|
|
@@ -111,6 +112,29 @@ func (c *cdiDeviceRequestor) DeviceRequests() []string { | |
| return devices | ||
| } | ||
|
|
||
| type gatedDevices image.CUDA | ||
|
|
||
| // DeviceRequests returns a list of devices that are required for gated devices. | ||
| func (g gatedDevices) DeviceRequests() []string { | ||
| i := (image.CUDA)(g) | ||
|
|
||
| var devices []string | ||
| if i.Getenv("NVIDIA_GDS") == "enabled" { | ||
| devices = append(devices, "mode=gds") | ||
| } | ||
| if i.Getenv("NVIDIA_MOFED") == "enabled" { | ||
| devices = append(devices, "mode=mofed") | ||
| } | ||
| if i.Getenv("NVIDIA_GDRCOPY") == "enabled" { | ||
| devices = append(devices, "mode=gdrcopy") | ||
| } | ||
| if i.Getenv("NVIDIA_NVSWITCH") == "enabled" { | ||
| devices = append(devices, "mode=nvswitch") | ||
| } | ||
|
|
||
| return devices | ||
| } | ||
|
|
||
| // filterAutomaticDevices searches for "automatic" device names in the input slice. | ||
| // "Automatic" devices are a well-defined list of CDI device names which, when requested, | ||
| // trigger the generation of a CDI spec at runtime. This removes the need to generate a | ||
|
|
@@ -129,35 +153,48 @@ func filterAutomaticDevices(devices []string) []string { | |
| func newAutomaticCDISpecModifier(logger logger.Interface, cfg *config.Config, devices []string) (oci.SpecModifier, error) { | ||
| logger.Debugf("Generating in-memory CDI specs for devices %v", devices) | ||
|
|
||
| var identifiers []string | ||
| perModeIdentifiers := make(map[string][]string) | ||
| perModeDeviceClass := map[string]string{"auto": automaticDeviceClass} | ||
| modes := []string{"auto"} | ||
| for _, device := range devices { | ||
| identifiers = append(identifiers, strings.TrimPrefix(device, automaticDevicePrefix)) | ||
| if strings.HasPrefix(device, "mode=") { | ||
| modes = append(modes, strings.TrimPrefix(device, "mode=")) | ||
| continue | ||
| } | ||
| perModeIdentifiers["auto"] = append(perModeIdentifiers["auto"], strings.TrimPrefix(device, automaticDevicePrefix)) | ||
| } | ||
|
|
||
| cdilib, err := nvcdi.New( | ||
| nvcdi.WithLogger(logger), | ||
| nvcdi.WithNVIDIACDIHookPath(cfg.NVIDIACTKConfig.Path), | ||
| nvcdi.WithDriverRoot(cfg.NVIDIAContainerCLIConfig.Root), | ||
| nvcdi.WithVendor(automaticDeviceVendor), | ||
| nvcdi.WithClass(automaticDeviceClass), | ||
| ) | ||
| if err != nil { | ||
| return nil, fmt.Errorf("failed to construct CDI library: %w", err) | ||
| } | ||
| var modifiers oci.SpecModifiers | ||
| for _, mode := range modes { | ||
| cdilib, err := nvcdi.New( | ||
| nvcdi.WithLogger(logger), | ||
| nvcdi.WithNVIDIACDIHookPath(cfg.NVIDIACTKConfig.Path), | ||
| nvcdi.WithDriverRoot(cfg.NVIDIAContainerCLIConfig.Root), | ||
| nvcdi.WithVendor(automaticDeviceVendor), | ||
| nvcdi.WithClass(perModeDeviceClass[mode]), | ||
|
||
| nvcdi.WithMode(mode), | ||
| ) | ||
| if err != nil { | ||
| return nil, fmt.Errorf("failed to construct CDI library for mode %q: %w", mode, err) | ||
| } | ||
|
|
||
| spec, err := cdilib.GetSpec(identifiers...) | ||
| if err != nil { | ||
| return nil, fmt.Errorf("failed to generate CDI spec: %w", err) | ||
| } | ||
| cdiDeviceRequestor, err := cdi.New( | ||
| cdi.WithLogger(logger), | ||
| cdi.WithSpec(spec.Raw()), | ||
| ) | ||
| if err != nil { | ||
| return nil, fmt.Errorf("failed to construct CDI modifier: %w", err) | ||
| spec, err := cdilib.GetSpec(perModeIdentifiers[mode]...) | ||
| if err != nil { | ||
| return nil, fmt.Errorf("failed to generate CDI spec for mode %q: %w", mode, err) | ||
| } | ||
|
|
||
| cdiDeviceRequestor, err := cdi.New( | ||
| cdi.WithLogger(logger), | ||
| cdi.WithSpec(spec.Raw()), | ||
| ) | ||
| if err != nil { | ||
| return nil, fmt.Errorf("failed to construct CDI modifier for mode %q: %w", mode, err) | ||
| } | ||
|
|
||
| modifiers = append(modifiers, cdiDeviceRequestor) | ||
| } | ||
|
|
||
| return cdiDeviceRequestor, nil | ||
| return modifiers, nil | ||
| } | ||
|
|
||
| type deduplicatedDeviceRequestor struct { | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -26,23 +26,23 @@ import ( | |||||||||
| "github.com/NVIDIA/nvidia-container-toolkit/internal/edits" | ||||||||||
| ) | ||||||||||
|
|
||||||||||
| type mofedlib nvcdilib | ||||||||||
| type gatedlib nvcdilib | ||||||||||
|
|
||||||||||
| var _ deviceSpecGeneratorFactory = (*mofedlib)(nil) | ||||||||||
| var _ deviceSpecGeneratorFactory = (*gatedlib)(nil) | ||||||||||
|
|
||||||||||
| func (l *mofedlib) DeviceSpecGenerators(...string) (DeviceSpecGenerator, error) { | ||||||||||
| func (l *gatedlib) DeviceSpecGenerators(...string) (DeviceSpecGenerator, error) { | ||||||||||
| return l, nil | ||||||||||
| } | ||||||||||
|
|
||||||||||
| // GetDeviceSpecs returns the CDI device specs for a single all device. | ||||||||||
| func (l *mofedlib) GetDeviceSpecs() ([]specs.Device, error) { | ||||||||||
| discoverer, err := discover.NewMOFEDDiscoverer(l.logger, l.driverRoot) | ||||||||||
| func (l *gatedlib) GetDeviceSpecs() ([]specs.Device, error) { | ||||||||||
| discoverer, err := l.getModeDiscoverer() | ||||||||||
| if err != nil { | ||||||||||
| return nil, fmt.Errorf("failed to create MOFED discoverer: %v", err) | ||||||||||
| return nil, fmt.Errorf("failed to create discoverer for mode %q: %w", l.mode, err) | ||||||||||
| } | ||||||||||
| edits, err := edits.FromDiscoverer(discoverer) | ||||||||||
| if err != nil { | ||||||||||
| return nil, fmt.Errorf("failed to create container edits for MOFED devices: %v", err) | ||||||||||
| return nil, fmt.Errorf("failed to create container edits: %w", err) | ||||||||||
| } | ||||||||||
|
|
||||||||||
| deviceSpec := specs.Device{ | ||||||||||
|
|
@@ -53,7 +53,22 @@ func (l *mofedlib) GetDeviceSpecs() ([]specs.Device, error) { | |||||||||
| return []specs.Device{deviceSpec}, nil | ||||||||||
| } | ||||||||||
|
|
||||||||||
| func (l *gatedlib) getModeDiscoverer() (discover.Discover, error) { | ||||||||||
| switch l.mode { | ||||||||||
| case ModeGdrcopy: | ||||||||||
| return discover.NewGDRCopyDiscoverer(l.logger, l.devRoot) | ||||||||||
| case ModeGds: | ||||||||||
| return discover.NewGDSDiscoverer(l.logger, l.driverRoot, l.devRoot) | ||||||||||
| case ModeMofed: | ||||||||||
| return discover.NewMOFEDDiscoverer(l.logger, l.driverRoot) | ||||||||||
| case ModeNvswitch: | ||||||||||
| return discover.NewNvSwitchDiscoverer(l.logger, l.devRoot) | ||||||||||
| default: | ||||||||||
| return nil, fmt.Errorf("unrecognized mode") | ||||||||||
|
||||||||||
| return nil, fmt.Errorf("unrecognized mode") | |
| return nil, fmt.Errorf("unrecognized mode: %q", l.mode) |
Copilot
AI
Aug 8, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The error message "unrecognized mode" is not helpful as it doesn't specify which mode was unrecognized. Consider including the actual mode value in the error message.
| return nil, fmt.Errorf("unrecognized mode") | |
| return nil, fmt.Errorf("unrecognized mode %q", l.mode) |
This file was deleted.
Uh oh!
There was an error while loading. Please reload this page.