Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 24 additions & 14 deletions internal/platform-support/dgpu/dgpu.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,24 +21,29 @@ import (

"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
"github.com/NVIDIA/nvidia-container-toolkit/internal/nvcaps"
)

// NewForDevice creates a discoverer for the specified Device.
//
// The supplied opts are applied by the package-level new helper, which also
// fills in defaults (such as the logger) for settings that were left unset.
// The resulting options are then used to construct the NVML-based dGPU
// discoverer for d.
func NewForDevice(d device.Device, opts ...Option) (discover.Discover, error) {
	// Option handling lives in a single helper so that NewForDevice and
	// NewForMigDevice apply identical defaults.
	o := new(opts...)

	return o.newNvmlDGPUDiscoverer(&toRequiredInfo{d})
}

// NewForMigDevice creates a discoverer for the specified device and its associated MIG device.
//
// d is used as the parent of mig when the placement information for the MIG
// device is resolved. The supplied opts are applied by the package-level new
// helper before the NVML-based MIG discoverer is constructed.
func NewForMigDevice(d device.Device, mig device.MigDevice, opts ...Option) (discover.Discover, error) {
	o := new(opts...)

	return o.newNvmlMigDiscoverer(
		&toRequiredMigInfo{
			MigDevice: mig,
			parent:    &toRequiredInfo{d},
		},
	)
}

func new(opts ...Option) *options {
o := &options{}
for _, opt := range opts {
opt(o)
Expand All @@ -48,10 +53,15 @@ func NewForMigDevice(d device.Device, mig device.MigDevice, opts ...Option) (dis
o.logger = logger.New()
}

return o.newNvmlMigDiscoverer(
&toRequiredMigInfo{
MigDevice: mig,
parent: &toRequiredInfo{d},
},
)
if o.migCaps == nil {
migCaps, err := nvcaps.NewMigCaps()
if err != nil {
o.logger.Debugf("ignoring error getting MIG capability device paths: %v", err)
o.migCapsError = err
} else {
o.migCaps = migCaps
}
}

return o
}
13 changes: 6 additions & 7 deletions internal/platform-support/dgpu/nvml.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,24 +78,23 @@ type requiredMigInfo interface {
}

func (o *options) newNvmlMigDiscoverer(d requiredMigInfo) (discover.Discover, error) {
gpu, gi, ci, err := d.getPlacementInfo()
if err != nil {
return nil, fmt.Errorf("error getting placement info: %w", err)
if o.migCaps == nil || o.migCapsError != nil {
return nil, fmt.Errorf("error getting MIG capability device paths: %v", o.migCapsError)
}

migCaps, err := nvcaps.NewMigCaps()
gpu, gi, ci, err := d.getPlacementInfo()
if err != nil {
return nil, fmt.Errorf("error getting MIG capability device paths: %v", err)
return nil, fmt.Errorf("error getting placement info: %w", err)
}

giCap := nvcaps.NewGPUInstanceCap(gpu, gi)
giCapDevicePath, err := migCaps.GetCapDevicePath(giCap)
giCapDevicePath, err := o.migCaps.GetCapDevicePath(giCap)
if err != nil {
return nil, fmt.Errorf("failed to get GI cap device path: %v", err)
}

ciCap := nvcaps.NewComputeInstanceCap(gpu, gi, ci)
ciCapDevicePath, err := migCaps.GetCapDevicePath(ciCap)
ciCapDevicePath, err := o.migCaps.GetCapDevicePath(ciCap)
if err != nil {
return nil, fmt.Errorf("failed to get CI cap device path: %v", err)
}
Expand Down
84 changes: 84 additions & 0 deletions internal/platform-support/dgpu/nvml_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/stretchr/testify/require"

"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
"github.com/NVIDIA/nvidia-container-toolkit/internal/nvcaps"
)

// TODO: In order to properly test this, we need a mechanism to inject /
Expand Down Expand Up @@ -85,3 +86,86 @@ func TestNewNvmlDGPUDiscoverer(t *testing.T) {
})
}
}

// TestNewNvmlMIGDiscoverer checks the devices, hooks and mounts reported by the
// discoverer that NewForMigDevice builds from a mocked parent device, a mocked
// MIG device and an injected set of MIG capabilities.
func TestNewNvmlMIGDiscoverer(t *testing.T) {
	logger, _ := testlog.NewNullLogger()

	nvmllib := &mock.Interface{}
	devicelib := device.New(
		nvmllib,
	)

	testCases := []struct {
		description     string
		mig             *mock.Device
		parent          nvml.Device
		migCaps         nvcaps.MigCaps
		expectedError   error
		expectedDevices []discover.Device
		expectedHooks   []discover.Hook
		expectedMounts  []discover.Mount
	}{
		{
			// A non-empty description gives the subtest a readable name.
			description: "MIG device with GI and CI capabilities",
			mig: &mock.Device{
				IsMigDeviceHandleFunc: func() (bool, nvml.Return) {
					return true, nvml.SUCCESS
				},
				GetGpuInstanceIdFunc: func() (int, nvml.Return) {
					return 1, nvml.SUCCESS
				},
				GetComputeInstanceIdFunc: func() (int, nvml.Return) {
					return 2, nvml.SUCCESS
				},
			},
			parent: &mock.Device{
				GetMinorNumberFunc: func() (int, nvml.Return) {
					return 3, nvml.SUCCESS
				},
				GetPciInfoFunc: func() (nvml.PciInfo, nvml.Return) {
					// Copy the bus ID string into the fixed-size array
					// used by nvml.PciInfo.
					var busID [32]int8
					for i, b := range []byte("00000000:45:00:00") {
						busID[i] = int8(b)
					}
					info := nvml.PciInfo{
						BusId: busID,
					}
					return info, nvml.SUCCESS
				},
			},
			// The keys match the parent minor number (3), GPU instance ID (1)
			// and compute instance ID (2) returned by the mocks above.
			migCaps: nvcaps.MigCaps{
				"gpu3/gi1/access":     31,
				"gpu3/gi1/ci2/access": 312,
			},
			expectedDevices: nil,
			expectedMounts:  nil,
			expectedHooks:   []discover.Hook{},
		},
	}
	for _, tc := range testCases {
		t.Run(tc.description, func(t *testing.T) {
			// The MIG mock resolves its parent to the parent mock of the
			// same test case.
			tc.mig.GetDeviceHandleFromMigDeviceHandleFunc = func() (nvml.Device, nvml.Return) {
				return tc.parent, nvml.SUCCESS
			}
			parent, err := devicelib.NewDevice(tc.parent)
			require.NoError(t, err)

			mig, err := devicelib.NewMigDevice(tc.mig)
			require.NoError(t, err)

			d, err := NewForMigDevice(parent, mig,
				WithLogger(logger),
				WithMIGCaps(tc.migCaps),
			)
			require.ErrorIs(t, err, tc.expectedError)
			if err != nil {
				// The expected error was returned; d must not be used when
				// construction failed.
				return
			}

			// The errors returned by the discoverer are not asserted here;
			// only the returned values are compared.
			devices, _ := d.Devices()
			require.EqualValues(t, tc.expectedDevices, devices)
			hooks, _ := d.Hooks()
			require.EqualValues(t, tc.expectedHooks, hooks)
			mounts, _ := d.Mounts()
			require.EqualValues(t, tc.expectedMounts, mounts)
		})
	}
}
13 changes: 13 additions & 0 deletions internal/platform-support/dgpu/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,18 @@ package dgpu

import (
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
"github.com/NVIDIA/nvidia-container-toolkit/internal/nvcaps"
)

// options holds the settings that are applied through Option functions and
// used when constructing the discoverers in this package.
type options struct {
// logger receives diagnostic output; a default is created with
// logger.New() when none is supplied.
logger logger.Interface
// devRoot is presumably the root under which device nodes are located —
// TODO confirm against its users.
devRoot string
// nvidiaCDIHookPath is the path set through WithNVIDIACDIHookPath.
nvidiaCDIHookPath string

// migCaps stores the MIG capabilities for the system.
// If MIG is not available, this is nil.
migCaps nvcaps.MigCaps
// migCapsError records the error returned when loading the MIG
// capabilities failed. It is reported when a MIG discoverer is requested.
migCapsError error
}

// Option is a function that modifies the options used to construct a
// discoverer.
type Option func(*options)
Expand All @@ -48,3 +54,10 @@ func WithNVIDIACDIHookPath(path string) Option {
l.nvidiaCDIHookPath = path
}
}

// WithMIGCaps returns an Option that stores the supplied MIG capabilities in
// the options it is applied to.
func WithMIGCaps(migCaps nvcaps.MigCaps) Option {
	setMIGCaps := func(o *options) {
		o.migCaps = migCaps
	}
	return setMIGCaps
}