From f3023401da2aaa4ae7b6c173e057ba1503cf0b39 Mon Sep 17 00:00:00 2001 From: Evan Lezar Date: Wed, 19 Nov 2025 12:20:09 +0100 Subject: [PATCH] Use jit-cdi mode for CSV systems Signed-off-by: Evan Lezar --- internal/info/auto.go | 18 ++++++++++++++---- internal/info/auto_test.go | 24 ++++++++++++------------ internal/modifier/cdi.go | 6 ++++-- internal/modifier/csv.go | 20 ++++++++++++-------- internal/runtime/runtime_factory.go | 2 ++ 5 files changed, 44 insertions(+), 26 deletions(-) diff --git a/internal/info/auto.go b/internal/info/auto.go index 3d69ad5c4..2a4a43868 100644 --- a/internal/info/auto.go +++ b/internal/info/auto.go @@ -53,9 +53,10 @@ type RuntimeModeResolver interface { type modeResolver struct { logger logger.Interface // TODO: This only needs to consider the requested devices. - image *image.CUDA - propertyExtractor info.PropertyExtractor - defaultMode RuntimeMode + image *image.CUDA + propertyExtractor info.PropertyExtractor + defaultMode RuntimeMode + forceCSVModeForTegraSystems bool } type Option func(*modeResolver) @@ -66,6 +67,12 @@ func WithDefaultMode(defaultMode RuntimeMode) Option { } } +func WithForceCSVModeForTegraSystems(forceCSVModeForTegraSystems bool) Option { + return func(mr *modeResolver) { + mr.forceCSVModeForTegraSystems = forceCSVModeForTegraSystems + } +} + func WithLogger(logger logger.Interface) Option { return func(mr *modeResolver) { mr.logger = logger @@ -130,7 +137,10 @@ func (m *modeResolver) ResolveRuntimeMode(mode string) (rmode RuntimeMode) { case info.PlatformNVML, info.PlatformWSL: return m.defaultMode case info.PlatformTegra: - return CSVRuntimeMode + if m.forceCSVModeForTegraSystems { + return CSVRuntimeMode + } + return JitCDIRuntimeMode } return m.defaultMode } diff --git a/internal/info/auto_test.go b/internal/info/auto_test.go index f6d99c7e5..15ff59bc3 100644 --- a/internal/info/auto_test.go +++ b/internal/info/auto_test.go @@ -55,34 +55,34 @@ func TestResolveAutoMode(t *testing.T) { expectedMode: "jit-cdi", }, { - description: "non-nvml, non-tegra, nvgpu resolves to csv", + description: "non-nvml, non-tegra, nvgpu resolves to jit-cdi", mode: "auto", info: map[string]bool{ "nvml": false, "tegra": false, "nvgpu": true, }, - expectedMode: "csv", + expectedMode: "jit-cdi", }, { - description: "non-nvml, tegra, non-nvgpu resolves to csv", + description: "non-nvml, tegra, non-nvgpu resolves to jit-cdi", mode: "auto", info: map[string]bool{ "nvml": false, "tegra": true, "nvgpu": false, }, - expectedMode: "csv", + expectedMode: "jit-cdi", }, { - description: "non-nvml, tegra, nvgpu resolves to csv", + description: "non-nvml, tegra, nvgpu resolves to jit-cdi", mode: "auto", info: map[string]bool{ "nvml": false, "tegra": true, "nvgpu": true, }, - expectedMode: "csv", + expectedMode: "jit-cdi", }, { description: "nvml, non-tegra, non-nvgpu resolves to jit-cdi", @@ -95,14 +95,14 @@ func TestResolveAutoMode(t *testing.T) { expectedMode: "jit-cdi", }, { - description: "nvml, non-tegra, nvgpu resolves to csv", + description: "nvml, non-tegra, nvgpu resolves to jit-cdi", mode: "auto", info: map[string]bool{ "nvml": true, "tegra": false, "nvgpu": true, }, - expectedMode: "csv", + expectedMode: "jit-cdi", }, { description: "nvml, tegra, non-nvgpu resolves to jit-cdi", @@ -115,14 +115,14 @@ func TestResolveAutoMode(t *testing.T) { expectedMode: "jit-cdi", }, { - description: "nvml, tegra, nvgpu resolves to csv", + description: "nvml, tegra, nvgpu resolves to jit-cdi", mode: "auto", info: map[string]bool{ "nvml": true, "tegra": true, "nvgpu": true, }, - expectedMode: "csv", + expectedMode: "jit-cdi", }, { description: "cdi devices resolves to cdi", @@ -154,7 +154,7 @@ func TestResolveAutoMode(t *testing.T) { expectedMode: "jit-cdi", }, { - description: "at least one non-cdi device resolves to csv", + description: "at least one non-cdi device resolves to jit-cdi", mode: "auto", envmap: map[string]string{ "NVIDIA_VISIBLE_DEVICES": "nvidia.com/gpu=0,0", @@ -164,7 +164,7 @@ func TestResolveAutoMode(t *testing.T) { "tegra": true, "nvgpu": false, }, - expectedMode: "csv", + expectedMode: "jit-cdi", }, { description: "cdi mount devices resolves to CDI", diff --git a/internal/modifier/cdi.go b/internal/modifier/cdi.go index f3543cc36..60e7a14dd 100644 --- a/internal/modifier/cdi.go +++ b/internal/modifier/cdi.go @@ -65,7 +65,7 @@ func NewCDIModifier(logger logger.Interface, cfg *config.Config, image image.CUD automaticDevices = append(automaticDevices, withUniqueDevices(gatedDevices(image)).DeviceRequests()...) automaticDevices = append(automaticDevices, withUniqueDevices(imexDevices(image)).DeviceRequests()...) - automaticModifier, err := newAutomaticCDISpecModifier(logger, cfg, automaticDevices) + automaticModifier, err := newAutomaticCDISpecModifier(logger, cfg, image, automaticDevices) if err == nil { return automaticModifier, nil } @@ -163,9 +163,10 @@ func filterAutomaticDevices(devices []string) []string { return automatic } -func newAutomaticCDISpecModifier(logger logger.Interface, cfg *config.Config, devices []string) (oci.SpecModifier, error) { +func newAutomaticCDISpecModifier(logger logger.Interface, cfg *config.Config, image image.CUDA, devices []string) (oci.SpecModifier, error) { logger.Debugf("Generating in-memory CDI specs for devices %v", devices) + csvFileList := getCSVFileList(cfg, image) cdiModeIdentifiers := cdiModeIdentfiersFromDevices(devices...) logger.Debugf("Per-mode identifiers: %v", cdiModeIdentifiers) @@ -179,6 +180,7 @@ func newAutomaticCDISpecModifier(logger logger.Interface, cfg *config.Config, de nvcdi.WithClass(cdiModeIdentifiers.deviceClassByMode[mode]), nvcdi.WithMode(mode), nvcdi.WithFeatureFlags(cfg.NVIDIAContainerRuntimeConfig.Modes.JitCDI.NVCDIFeatureFlags...), + nvcdi.WithCSVFiles(csvFileList), ) if err != nil { return nil, fmt.Errorf("failed to construct CDI library for mode %q: %w", mode, err) diff --git a/internal/modifier/csv.go b/internal/modifier/csv.go index c8cf4ead3..cb58f5c1d 100644 --- a/internal/modifier/csv.go +++ b/internal/modifier/csv.go @@ -44,14 +44,7 @@ func NewCSVModifier(logger logger.Interface, cfg *config.Config, container image return nil, fmt.Errorf("requirements not met: %v", err) } - csvFiles, err := csv.GetFileList(cfg.NVIDIAContainerRuntimeConfig.Modes.CSV.MountSpecPath) - if err != nil { - return nil, fmt.Errorf("failed to get list of CSV files: %v", err) - } - - if container.Getenv(image.EnvVarNvidiaRequireJetpack) != "csv-mounts=all" { - csvFiles = csv.BaseFilesOnly(csvFiles) - } + csvFiles := getCSVFileList(cfg, container) cdilib, err := nvcdi.New( nvcdi.WithLogger(logger), @@ -106,3 +99,14 @@ func checkRequirements(logger logger.Interface, image image.CUDA) error { return r.Assert() } + +func getCSVFileList(cfg *config.Config, container image.CUDA) []string { + csvFiles, err := csv.GetFileList(cfg.NVIDIAContainerRuntimeConfig.Modes.CSV.MountSpecPath) + if err != nil { + return nil + } + if container.Getenv(image.EnvVarNvidiaRequireJetpack) != "csv-mounts=all" { + csvFiles = csv.BaseFilesOnly(csvFiles) + } + return csvFiles +} diff --git a/internal/runtime/runtime_factory.go b/internal/runtime/runtime_factory.go index dc6424c36..afebd4797 100644 --- a/internal/runtime/runtime_factory.go +++ b/internal/runtime/runtime_factory.go @@ -141,6 +141,8 @@ func initRuntimeModeAndImage(logger logger.Interface, cfg *config.Config, ociSpe modeResolver := info.NewRuntimeModeResolver( info.WithLogger(logger), info.WithImage(&image), + // TODO: Add a feature flag. + info.WithForceCSVModeForTegraSystems(false), ) mode := modeResolver.ResolveRuntimeMode(cfg.NVIDIAContainerRuntimeConfig.Mode) // We update the mode here so that we can continue passing just the config to other functions.