Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions cmd/nvidia-container-runtime-hook/container_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ func getDevicesFromEnvvar(containerImage image.CUDA, swarmResourceEnvvars []stri
return containerImage.VisibleDevicesFromEnvVar()
}

func getDevices(hookConfig *HookConfig, image image.CUDA, privileged bool) []string {
func (hookConfig *hookConfig) getDevices(image image.CUDA, privileged bool) []string {
// If enabled, try and get the device list from volume mounts first
if hookConfig.AcceptDeviceListAsVolumeMounts {
devices := image.VisibleDevicesFromMounts()
Expand Down Expand Up @@ -197,7 +197,7 @@ func getMigDevices(image image.CUDA, envvar string) *string {
return &devices
}

func getImexChannels(hookConfig *HookConfig, image image.CUDA, privileged bool) []string {
func (hookConfig *hookConfig) getImexChannels(image image.CUDA, privileged bool) []string {
// If enabled, try and get the IMEX channel list from volume mounts first
if hookConfig.AcceptDeviceListAsVolumeMounts {
devices := image.ImexChannelsFromMounts()
Expand All @@ -217,10 +217,10 @@ func getImexChannels(hookConfig *HookConfig, image image.CUDA, privileged bool)
return nil
}

func (c *HookConfig) getDriverCapabilities(cudaImage image.CUDA, legacyImage bool) image.DriverCapabilities {
func (hookConfig *hookConfig) getDriverCapabilities(cudaImage image.CUDA, legacyImage bool) image.DriverCapabilities {
// We use the default driver capabilities by default. This is filtered to only include the
// supported capabilities
supportedDriverCapabilities := image.NewDriverCapabilities(c.SupportedDriverCapabilities)
supportedDriverCapabilities := image.NewDriverCapabilities(hookConfig.SupportedDriverCapabilities)

capabilities := supportedDriverCapabilities.Intersection(image.DefaultDriverCapabilities)

Expand All @@ -244,10 +244,10 @@ func (c *HookConfig) getDriverCapabilities(cudaImage image.CUDA, legacyImage boo
return capabilities
}

func getNvidiaConfig(hookConfig *HookConfig, image image.CUDA, privileged bool) *nvidiaConfig {
func (hookConfig *hookConfig) getNvidiaConfig(image image.CUDA, privileged bool) *nvidiaConfig {
legacyImage := image.IsLegacy()

devices := getDevices(hookConfig, image, privileged)
devices := hookConfig.getDevices(image, privileged)
if len(devices) == 0 {
// empty devices means this is not a GPU container.
return nil
Expand All @@ -269,7 +269,7 @@ func getNvidiaConfig(hookConfig *HookConfig, image image.CUDA, privileged bool)
log.Panicln("cannot set MIG_MONITOR_DEVICES in non privileged container")
}

imexChannels := getImexChannels(hookConfig, image, privileged)
imexChannels := hookConfig.getImexChannels(image, privileged)

driverCapabilities := hookConfig.getDriverCapabilities(image, legacyImage).String()

Expand All @@ -288,7 +288,7 @@ func getNvidiaConfig(hookConfig *HookConfig, image image.CUDA, privileged bool)
}
}

func getContainerConfig(hook HookConfig) (config containerConfig) {
func (hookConfig *hookConfig) getContainerConfig() (config containerConfig) {
var h HookState
d := json.NewDecoder(os.Stdin)
if err := d.Decode(&h); err != nil {
Expand All @@ -305,7 +305,7 @@ func getContainerConfig(hook HookConfig) (config containerConfig) {
image, err := image.New(
image.WithEnv(s.Process.Env),
image.WithMounts(s.Mounts),
image.WithDisableRequire(hook.DisableRequire),
image.WithDisableRequire(hookConfig.DisableRequire),
)
if err != nil {
log.Panicln(err)
Expand All @@ -316,6 +316,6 @@ func getContainerConfig(hook HookConfig) (config containerConfig) {
Pid: h.Pid,
Rootfs: s.Root.Path,
Image: image,
Nvidia: getNvidiaConfig(&hook, image, privileged),
Nvidia: hookConfig.getNvidiaConfig(image, privileged),
}
}
78 changes: 46 additions & 32 deletions cmd/nvidia-container-runtime-hook/container_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"github.com/opencontainers/runtime-spec/specs-go"
"github.com/stretchr/testify/require"

"github.com/NVIDIA/nvidia-container-toolkit/internal/config"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
)

Expand All @@ -15,7 +16,7 @@ func TestGetNvidiaConfig(t *testing.T) {
description string
env map[string]string
privileged bool
hookConfig *HookConfig
hookConfig *hookConfig
expectedConfig *nvidiaConfig
expectedPanic bool
}{
Expand Down Expand Up @@ -394,8 +395,10 @@ func TestGetNvidiaConfig(t *testing.T) {
image.EnvVarNvidiaDriverCapabilities: "all",
},
privileged: true,
hookConfig: &HookConfig{
SupportedDriverCapabilities: "video,display",
hookConfig: &hookConfig{
Config: &config.Config{
SupportedDriverCapabilities: "video,display",
},
},
expectedConfig: &nvidiaConfig{
Devices: []string{"all"},
Expand All @@ -409,8 +412,10 @@ func TestGetNvidiaConfig(t *testing.T) {
image.EnvVarNvidiaDriverCapabilities: "video,display",
},
privileged: true,
hookConfig: &HookConfig{
SupportedDriverCapabilities: "video,display,compute,utility",
hookConfig: &hookConfig{
Config: &config.Config{
SupportedDriverCapabilities: "video,display,compute,utility",
},
},
expectedConfig: &nvidiaConfig{
Devices: []string{"all"},
Expand All @@ -423,8 +428,10 @@ func TestGetNvidiaConfig(t *testing.T) {
image.EnvVarNvidiaVisibleDevices: "all",
},
privileged: true,
hookConfig: &HookConfig{
SupportedDriverCapabilities: "video,display,utility,compute",
hookConfig: &hookConfig{
Config: &config.Config{
SupportedDriverCapabilities: "video,display,utility,compute",
},
},
expectedConfig: &nvidiaConfig{
Devices: []string{"all"},
Expand All @@ -438,9 +445,11 @@ func TestGetNvidiaConfig(t *testing.T) {
"DOCKER_SWARM_RESOURCE": "GPU1,GPU2",
},
privileged: true,
hookConfig: &HookConfig{
SwarmResource: "DOCKER_SWARM_RESOURCE",
SupportedDriverCapabilities: "video,display,utility,compute",
hookConfig: &hookConfig{
Config: &config.Config{
SwarmResource: "DOCKER_SWARM_RESOURCE",
SupportedDriverCapabilities: "video,display,utility,compute",
},
},
expectedConfig: &nvidiaConfig{
Devices: []string{"GPU1", "GPU2"},
Expand All @@ -454,9 +463,11 @@ func TestGetNvidiaConfig(t *testing.T) {
"DOCKER_SWARM_RESOURCE": "GPU1,GPU2",
},
privileged: true,
hookConfig: &HookConfig{
SwarmResource: "NOT_DOCKER_SWARM_RESOURCE,DOCKER_SWARM_RESOURCE",
SupportedDriverCapabilities: "video,display,utility,compute",
hookConfig: &hookConfig{
Config: &config.Config{
SwarmResource: "NOT_DOCKER_SWARM_RESOURCE,DOCKER_SWARM_RESOURCE",
SupportedDriverCapabilities: "video,display,utility,compute",
},
},
expectedConfig: &nvidiaConfig{
Devices: []string{"GPU1", "GPU2"},
Expand All @@ -470,14 +481,14 @@ func TestGetNvidiaConfig(t *testing.T) {
image.WithEnvMap(tc.env),
)
// Wrap the call to getNvidiaConfig() in a closure.
var config *nvidiaConfig
var cfg *nvidiaConfig
getConfig := func() {
hookConfig := tc.hookConfig
if hookConfig == nil {
defaultConfig, _ := getDefaultHookConfig()
hookConfig = &defaultConfig
hookCfg := tc.hookConfig
if hookCfg == nil {
defaultConfig, _ := config.GetDefault()
hookCfg = &hookConfig{defaultConfig}
}
config = getNvidiaConfig(hookConfig, image, tc.privileged)
cfg = hookCfg.getNvidiaConfig(image, tc.privileged)
}

// For any tests that are expected to panic, make sure they do.
Expand All @@ -491,18 +502,18 @@ func TestGetNvidiaConfig(t *testing.T) {

// And start comparing the test results to the expected results.
if tc.expectedConfig == nil {
require.Nil(t, config, tc.description)
require.Nil(t, cfg, tc.description)
return
}

require.NotNil(t, config, tc.description)
require.NotNil(t, cfg, tc.description)

require.Equal(t, tc.expectedConfig.Devices, config.Devices)
require.Equal(t, tc.expectedConfig.MigConfigDevices, config.MigConfigDevices)
require.Equal(t, tc.expectedConfig.MigMonitorDevices, config.MigMonitorDevices)
require.Equal(t, tc.expectedConfig.DriverCapabilities, config.DriverCapabilities)
require.Equal(t, tc.expectedConfig.Devices, cfg.Devices)
require.Equal(t, tc.expectedConfig.MigConfigDevices, cfg.MigConfigDevices)
require.Equal(t, tc.expectedConfig.MigMonitorDevices, cfg.MigMonitorDevices)
require.Equal(t, tc.expectedConfig.DriverCapabilities, cfg.DriverCapabilities)

require.ElementsMatch(t, tc.expectedConfig.Requirements, config.Requirements)
require.ElementsMatch(t, tc.expectedConfig.Requirements, cfg.Requirements)
})
}
}
Expand Down Expand Up @@ -612,10 +623,11 @@ func TestDeviceListSourcePriority(t *testing.T) {
),
image.WithMounts(tc.mountDevices),
)
hookConfig, _ := getDefaultHookConfig()
hookConfig.AcceptEnvvarUnprivileged = tc.acceptUnprivileged
hookConfig.AcceptDeviceListAsVolumeMounts = tc.acceptMounts
devices = getDevices(&hookConfig, image, tc.privileged)
defaultConfig, _ := config.GetDefault()
cfg := &hookConfig{defaultConfig}
cfg.AcceptEnvvarUnprivileged = tc.acceptUnprivileged
cfg.AcceptDeviceListAsVolumeMounts = tc.acceptMounts
devices = cfg.getDevices(image, tc.privileged)
}

// For all other tests, just grab the devices and check the results
Expand Down Expand Up @@ -940,8 +952,10 @@ func TestGetDriverCapabilities(t *testing.T) {
t.Run(tc.description, func(t *testing.T) {
var capabilities string

c := HookConfig{
SupportedDriverCapabilities: tc.supportedCapabilities,
c := hookConfig{
Config: &config.Config{
SupportedDriverCapabilities: tc.supportedCapabilities,
},
}

image, _ := image.New(
Expand Down
22 changes: 8 additions & 14 deletions cmd/nvidia-container-runtime-hook/hook_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,10 @@ const (
driverPath = "/run/nvidia/driver"
)

// HookConfig : options for the nvidia-container-runtime-hook.
type HookConfig config.Config

func getDefaultHookConfig() (HookConfig, error) {
defaultCfg, err := config.GetDefault()
if err != nil {
return HookConfig{}, err
}

return *(*HookConfig)(defaultCfg), nil
// hookConfig wraps the toolkit config.
// Embedding *config.Config promotes its fields onto this package-local type,
// which allows hook-specific methods to be defined on it (methods cannot be
// declared on a type that belongs to another package).
// The embedded pointer must be non-nil before any promoted field is read.
type hookConfig struct {
*config.Config
}

// loadConfig loads the required paths for the hook config.
Expand Down Expand Up @@ -56,12 +50,12 @@ func loadConfig() (*config.Config, error) {
return config.GetDefault()
}

func getHookConfig() (*HookConfig, error) {
func getHookConfig() (*hookConfig, error) {
cfg, err := loadConfig()
if err != nil {
return nil, fmt.Errorf("failed to load config: %v", err)
}
config := (*HookConfig)(cfg)
config := &hookConfig{cfg}

allSupportedDriverCapabilities := image.SupportedDriverCapabilities
if config.SupportedDriverCapabilities == "all" {
Expand All @@ -79,7 +73,7 @@ func getHookConfig() (*HookConfig, error) {

// getConfigOption returns the toml config option associated with the
// specified struct field.
func (c HookConfig) getConfigOption(fieldName string) string {
func (c hookConfig) getConfigOption(fieldName string) string {
t := reflect.TypeOf(c)
f, ok := t.FieldByName(fieldName)
if !ok {
Expand All @@ -93,7 +87,7 @@ func (c HookConfig) getConfigOption(fieldName string) string {
}

// getSwarmResourceEnvvars returns the swarm resource envvars for the config.
func (c *HookConfig) getSwarmResourceEnvvars() []string {
func (c *hookConfig) getSwarmResourceEnvvars() []string {
if c.SwarmResource == "" {
return nil
}
Expand Down
13 changes: 8 additions & 5 deletions cmd/nvidia-container-runtime-hook/hook_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (

"github.com/stretchr/testify/require"

"github.com/NVIDIA/nvidia-container-toolkit/internal/config"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
)

Expand Down Expand Up @@ -89,10 +90,10 @@ func TestGetHookConfig(t *testing.T) {
}
}

var config HookConfig
var cfg hookConfig
getHookConfig := func() {
c, _ := getHookConfig()
config = *c
cfg = *c
}

if tc.expectedPanic {
Expand All @@ -102,7 +103,7 @@ func TestGetHookConfig(t *testing.T) {

getHookConfig()

require.EqualValues(t, tc.expectedDriverCapabilities, config.SupportedDriverCapabilities)
require.EqualValues(t, tc.expectedDriverCapabilities, cfg.SupportedDriverCapabilities)
})
}
}
Expand Down Expand Up @@ -144,8 +145,10 @@ func TestGetSwarmResourceEnvvars(t *testing.T) {

for i, tc := range testCases {
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
c := &HookConfig{
SwarmResource: tc.value,
c := &hookConfig{
Config: &config.Config{
SwarmResource: tc.value,
},
}

envvars := c.getSwarmResourceEnvvars()
Expand Down
2 changes: 1 addition & 1 deletion cmd/nvidia-container-runtime-hook/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func doPrestart() {
}
cli := hook.NVIDIAContainerCLIConfig

container := getContainerConfig(*hook)
container := hook.getContainerConfig()
nvidia := container.Nvidia
if nvidia == nil {
// Not a GPU container, nothing to do.
Expand Down
Loading