Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
19 changes: 13 additions & 6 deletions cmd/nvidia-ctk-installer/container/container.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,15 @@ type Options struct {
// mount.
ExecutablePath string
// EnableCDI indicates whether CDI should be enabled.
EnableCDI bool
RuntimeName string
RuntimeDir string
SetAsDefault bool
RestartMode string
HostRootMount string
EnableCDI bool
EnableNRI bool
RuntimeName string
RuntimeDir string
SetAsDefault bool
RestartMode string
HostRootMount string
NRIPluginIndex string
NRISocket string

ConfigSources []string
}
Expand Down Expand Up @@ -128,6 +131,10 @@ func (o Options) UpdateConfig(cfg engine.Interface) error {
cfg.EnableCDI()
}

if o.EnableNRI {
cfg.EnableNRI()
}

return nil
}

Expand Down
101 changes: 101 additions & 0 deletions cmd/nvidia-ctk-installer/container/runtime/nri/plugin.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
package nri

import (
"context"
"fmt"

"github.com/containerd/nri/pkg/api"
"github.com/containerd/nri/pkg/stub"
"sigs.k8s.io/yaml"

"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
)

const (
// nodeResourceCDIDeviceKey is the prefix of the key used for CDI device annotations.
// parseCDIDevices passes it to getAnnotation as mainKey, so it is the first
// prefix tried.
nodeResourceCDIDeviceKey = "cdi-devices.noderesource.dev"
// nriCDIDeviceKey is a second prefix accepted for CDI device annotations.
// parseCDIDevices passes it to getAnnotation as oldKey, so at each scope
// (container, pod, bare key) it is tried only after nodeResourceCDIDeviceKey.
nriCDIDeviceKey = "cdi-devices.nri.io"
)

// Plugin is the NRI plugin type of this package. Its CreateContainer method
// adds the CDI devices named in pod annotations to the container adjustment.
type Plugin struct {
// logger receives the debug and info messages written while injecting
// CDI devices. It is unexported, so it can only be set from this package.
logger logger.Interface

// Stub is the NRI stub associated with this plugin.
Stub stub.Stub
}

// CreateContainer handles container creation requests.
//
// It builds a fresh container adjustment, lets injectCDIDevices populate it
// from the pod annotations, and returns it. No container updates are ever
// produced. If the injection fails, the error is returned and both other
// results are nil.
func (p *Plugin) CreateContainer(_ context.Context, pod *api.PodSandbox, ctr *api.Container) (*api.ContainerAdjustment, []*api.ContainerUpdate, error) {
	adjustment := new(api.ContainerAdjustment)

	err := p.injectCDIDevices(pod, ctr, adjustment)
	if err != nil {
		return nil, nil, err
	}

	return adjustment, nil, nil
}

// injectCDIDevices adds every CDI device requested for ctr through the pod
// annotations to the container adjustment a.
//
// It returns the error from parseCDIDevices when the annotation cannot be
// decoded. A container without a matching annotation is left untouched and
// nil is returned.
func (p *Plugin) injectCDIDevices(pod *api.PodSandbox, ctr *api.Container, a *api.ContainerAdjustment) error {
	// containerName tolerates a nil pod, so do the same here instead of
	// dereferencing pod unconditionally. Lookups on a nil map are safe.
	var annotations map[string]string
	if pod != nil {
		annotations = pod.Annotations
	}

	devices, err := parseCDIDevices(ctr.Name, annotations)
	if err != nil {
		return err
	}

	// The logger field is unexported, so a Plugin created outside this
	// package (for example as &nri.Plugin{}) carries a nil logger. Skip
	// logging in that case rather than panicking on a nil interface.
	canLog := p.logger != nil
	name := containerName(pod, ctr)

	if len(devices) == 0 {
		if canLog {
			p.logger.Debugf("%s: no CDI devices annotated...", name)
		}
		return nil
	}

	for _, device := range devices {
		a.AddCDIDevice(&api.CDIDevice{Name: device})
		if canLog {
			p.logger.Infof("%s: injected CDI device %q...", name, device)
		}
	}

	return nil
}

// parseCDIDevices returns the CDI device names annotated for the container
// named ctr.
//
// The annotation value is located with getAnnotation (trying the
// nodeResourceCDIDeviceKey prefix before the nriCDIDeviceKey one) and decoded
// as a YAML list of strings. A missing or empty annotation yields (nil, nil);
// a value that does not decode yields an error quoting the offending text.
func parseCDIDevices(ctr string, annotations map[string]string) ([]string, error) {
	raw := getAnnotation(annotations, nodeResourceCDIDeviceKey, nriCDIDeviceKey, ctr)
	if len(raw) == 0 {
		return nil, nil
	}

	var devices []string
	err := yaml.Unmarshal(raw, &devices)
	if err != nil {
		return nil, fmt.Errorf("invalid CDI device annotation %q: %w", string(raw), err)
	}

	return devices, nil
}

// getAnnotation returns the value of the first annotation that exists among
// the candidate keys built from mainKey, oldKey and the container name ctr.
//
// Candidates are tried from most to least specific, and within each scope
// mainKey is tried before oldKey:
//
//	<key>/container.<ctr>, <key>/pod, <key>
//
// The value is returned as a byte slice. nil is returned when none of the
// candidate keys is present in annotations.
func getAnnotation(annotations map[string]string, mainKey, oldKey, ctr string) []byte {
	containerSuffix := "/container." + ctr
	candidates := [...]string{
		mainKey + containerSuffix,
		oldKey + containerSuffix,
		mainKey + "/pod",
		oldKey + "/pod",
		mainKey,
		oldKey,
	}

	for i := range candidates {
		value, found := annotations[candidates[i]]
		if !found {
			continue
		}
		return []byte(value)
	}

	return nil
}

// containerName builds the container name used in log messages: the
// container name on its own when pod is nil, otherwise "<pod>/<container>".
func containerName(pod *api.PodSandbox, container *api.Container) string {
	if pod == nil {
		return container.Name
	}
	return pod.Name + "/" + container.Name
}
55 changes: 55 additions & 0 deletions cmd/nvidia-ctk-installer/container/runtime/runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,18 @@
package runtime

import (
"context"
"fmt"
"os"

"github.com/containerd/nri/pkg/stub"
"github.com/urfave/cli/v3"

"github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-ctk-installer/container"
"github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-ctk-installer/container/runtime/containerd"
"github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-ctk-installer/container/runtime/crio"
"github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-ctk-installer/container/runtime/docker"
"github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-ctk-installer/container/runtime/nri"
"github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-ctk-installer/toolkit"
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
)
Expand All @@ -34,6 +38,8 @@ const (
// defaultRuntimeName specifies the NVIDIA runtime to be used as the default runtime if setting the default runtime is enabled
defaultRuntimeName = "nvidia"
defaultHostRootMount = "/host"
defaultNRIPluginIdx = "10"
defaultNRISocket = "/var/run/nri/nri.sock"

runtimeSpecificDefault = "RUNTIME_SPECIFIC_DEFAULT"
)
Expand Down Expand Up @@ -94,6 +100,27 @@ func Flags(opts *Options) []cli.Flag {
Destination: &opts.EnableCDI,
Sources: cli.EnvVars("RUNTIME_ENABLE_CDI"),
},
&cli.BoolFlag{
Name: "enable-nri-in-runtime",
Usage: "Enable NRI in the configured runtime",
Destination: &opts.EnableNRI,
Value: true,
Sources: cli.EnvVars("RUNTIME_ENABLE_NRI"),
},
&cli.StringFlag{
Name: "nri-plugin-index",
Usage: "Specify the plugin index to register to NRI",
Value: defaultNRIPluginIdx,
Destination: &opts.NRIPluginIndex,
Sources: cli.EnvVars("RUNTIME_NRI_PLUGIN_INDEX"),
},
&cli.StringFlag{
Name: "nri-socket",
Usage: "Specify the path to the NRI socket file to register the NRI plugin server",
Value: defaultNRISocket,
Destination: &opts.NRISocket,
Sources: cli.EnvVars("RUNTIME_NRI_SOCKET"),
},
&cli.StringFlag{
Name: "host-root",
Usage: "Specify the path to the host root to be used when restarting the runtime using systemd",
Expand Down Expand Up @@ -250,3 +277,31 @@ func GetLowlevelRuntimePaths(opts *Options, runtime string) ([]string, error) {
return nil, fmt.Errorf("undefined runtime %v", runtime)
}
}

// StartNRIPlugin creates an NRI plugin stub for the configured socket and
// plugin index and starts it.
//
// The socket path is taken from opts.NRISocket, falling back to
// defaultNRISocket when that is empty; the path must exist. The plugin index
// is taken from opts.NRIPluginIndex.
//
// The function returns as soon as the stub has been started; it does not wait
// for the plugin to exit. The returned Plugin carries the started stub in its
// Stub field so that the caller can stop it later (see StopNRIPlugin).
func StartNRIPlugin(ctx context.Context, opts *Options) (*nri.Plugin, error) {
	nriSocketPath := defaultNRISocket
	if len(opts.NRISocket) > 0 {
		nriSocketPath = opts.NRISocket
	}

	if _, err := os.Stat(nriSocketPath); err != nil {
		return nil, fmt.Errorf("failed to find valid nri socket in %s: %w", nriSocketPath, err)
	}

	p := &nri.Plugin{}
	pluginOpts := []stub.Option{
		stub.WithPluginIdx(opts.NRIPluginIndex),
		stub.WithSocketPath(nriSocketPath),
	}

	pluginStub, err := stub.New(p, pluginOpts...)
	if err != nil {
		return nil, fmt.Errorf("failed to initialise plugin at %s: %w", nriSocketPath, err)
	}
	p.Stub = pluginStub

	// Use Start rather than Run: Run also waits for the plugin to exit, which
	// would keep this function from returning the plugin while it is running
	// and leave the caller without a handle to stop it.
	if err := p.Stub.Start(ctx); err != nil {
		return nil, fmt.Errorf("failed to start plugin at %s: %w", nriSocketPath, err)
	}

	return p, nil
}

// StopNRIPlugin stops the stub of the given plugin.
//
// It is a no-op when plugin is nil or has no stub, so it is safe to call even
// if StartNRIPlugin failed or was never invoked.
func StopNRIPlugin(plugin *nri.Plugin) {
	if plugin == nil || plugin.Stub == nil {
		return
	}
	plugin.Stub.Stop()
}
33 changes: 29 additions & 4 deletions cmd/nvidia-ctk-installer/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"path/filepath"
"syscall"

"github.com/containerd/nri/pkg/stub"
"github.com/urfave/cli/v3"
"golang.org/x/sys/unix"

Expand Down Expand Up @@ -70,7 +71,8 @@ func main() {
// app is the receiver type for the installer's methods in this file
// (build, Run, shutdown, startNRIPluginServer, ...).
type app struct {
// logger is used for the informational, warning and error messages
// emitted by the app's methods.
logger logger.Interface

toolkit *toolkit.Installer
// pluginStub is the NRI plugin stub stored by startNRIPluginServer;
// shutdown calls Stop on it when it is non-nil.
pluginStub stub.Stub
toolkit *toolkit.Installer
}

// NewApp creates the CLI app for the specified options.
Expand All @@ -93,8 +95,8 @@ func (a app) build() *cli.Command {
Before: func(ctx context.Context, cmd *cli.Command) (context.Context, error) {
return ctx, a.Before(cmd, &options)
},
Action: func(_ context.Context, cmd *cli.Command) error {
return a.Run(cmd, &options)
Action: func(ctx context.Context, cmd *cli.Command) error {
return a.Run(ctx, cmd, &options)
},
Flags: []cli.Flag{
&cli.BoolFlag{
Expand Down Expand Up @@ -194,7 +196,7 @@ func (a *app) validateFlags(c *cli.Command, o *options) error {
// Run installs the NVIDIA Container Toolkit and updates the requested runtime.
// If the application is run as a daemon, the application waits and unconfigures
// the runtime on termination.
func (a *app) Run(c *cli.Command, o *options) error {
func (a *app) Run(ctx context.Context, c *cli.Command, o *options) error {
err := a.initialize(o.pidFile)
if err != nil {
return fmt.Errorf("unable to initialize: %v", err)
Expand Down Expand Up @@ -222,6 +224,14 @@ func (a *app) Run(c *cli.Command, o *options) error {
}

if !o.noDaemon {
if o.runtimeOptions.EnableNRI {
go func() {
err = a.startNRIPluginServer(ctx, &o.runtimeOptions)
if err != nil {
a.logger.Errorf("unable to start runtime plugin server: %v", err)
}
}()
}
err = a.waitForSignal()
if err != nil {
return fmt.Errorf("unable to wait for signal: %v", err)
Expand Down Expand Up @@ -290,6 +300,11 @@ func (a *app) waitForSignal() error {
func (a *app) shutdown(pidFile string) {
a.logger.Infof("Shutting Down")

if a.pluginStub != nil {
a.logger.Infof("Stopping NRI plugin server...")
a.pluginStub.Stop()
}

err := os.Remove(pidFile)
if err != nil {
a.logger.Warningf("Unable to remove pidfile: %v", err)
Expand Down Expand Up @@ -327,3 +342,13 @@ func (a *app) resolvePackageType(hostRoot string, packageType string) (rPackageT

return "deb", nil
}

// startNRIPluginServer starts the NRI plugin through runtime.StartNRIPlugin
// and records its stub in a.pluginStub, which shutdown uses to stop it.
//
// It returns an error when the plugin cannot be started or when no plugin is
// returned; a.pluginStub is left unchanged in both cases.
//
// NOTE(review): Run invokes this from a goroutine while shutdown reads
// a.pluginStub — confirm whether that access needs synchronisation.
func (a *app) startNRIPluginServer(ctx context.Context, opts *runtime.Options) error {
	a.logger.Info("Starting NRI Plugin server...")

	plugin, err := runtime.StartNRIPlugin(ctx, opts)
	if err != nil {
		return fmt.Errorf("unable to setup NRI plugin server: %w", err)
	}
	// Handled separately from the error case so that a nil error is never
	// wrapped with %w, which would render as "%!w(<nil>)".
	if plugin == nil {
		return fmt.Errorf("unable to setup NRI plugin server: no plugin returned")
	}

	a.pluginStub = plugin.Stub
	return nil
}
1 change: 1 addition & 0 deletions cmd/nvidia-ctk-installer/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,7 @@ version = 2
"--pid-file=" + filepath.Join(testRoot, "toolkit.pid"),
"--restart-mode=none",
"--toolkit-source-root=" + filepath.Join(artifactRoot, "deb"),
"--enable-nri-in-runtime=false",
}

err := app.Run(context.Background(), append(testArgs, tc.args...))
Expand Down
14 changes: 11 additions & 3 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ go 1.25.0
require (
github.com/NVIDIA/go-nvlib v0.8.1
github.com/NVIDIA/go-nvml v0.13.0-1
github.com/containerd/nri v0.10.1-0.20251120153915-7d8611f87ad7
github.com/google/uuid v1.6.0
github.com/moby/sys/mountinfo v0.7.2
github.com/moby/sys/reexec v0.1.0
Expand All @@ -19,24 +20,31 @@ require (
github.com/urfave/cli/v3 v3.6.1
golang.org/x/mod v0.30.0
golang.org/x/sys v0.38.0
sigs.k8s.io/yaml v1.4.0
tags.cncf.io/container-device-interface v1.0.2-0.20251114135136-1b24d969689f
tags.cncf.io/container-device-interface/specs-go v1.0.0
)

require (
cyphar.com/go-pathrs v0.2.1 // indirect
github.com/containerd/log v0.1.0 // indirect
github.com/containerd/ttrpc v1.2.7 // indirect
github.com/cyphar/filepath-securejoin v0.6.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/fsnotify/fsnotify v1.7.0 // indirect
github.com/golang/protobuf v1.5.3 // indirect
github.com/hashicorp/errwrap v1.1.0 // indirect
github.com/kr/pretty v0.3.1 // indirect
github.com/knqyf263/go-plugin v0.9.0 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/moby/sys/capability v0.4.0 // indirect
github.com/opencontainers/cgroups v0.0.4 // indirect
github.com/opencontainers/runtime-tools v0.9.1-0.20251114084447-edf4cb3d2116 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rogpeppe/go-internal v1.11.0 // indirect
github.com/tetratelabs/wazero v1.9.0 // indirect
github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20230731190214-cbb8c96f2d6d // indirect
google.golang.org/grpc v1.57.1 // indirect
google.golang.org/protobuf v1.36.5 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
sigs.k8s.io/yaml v1.4.0 // indirect
)
Loading
Loading