Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions cli/internal/config/nodebootstrap/nodebootstrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,20 @@ const (

var (
	// r routes the "node-bootstrap" config command to per-target handlers.
	r = configcmd.NewRouter("node-bootstrap", "Generate a node bootstrap config for a remote cloud")

	// Command is the cobra command produced by the node-bootstrap router.
	Command *cobra.Command = r.Command()

	// Values populated from the command-line flags registered in init.
	flagHasGPU                 bool
	flagEnableNvidiaGPURuntime bool
	flagVariant                string
	flagArch                   string
	flagKubeVersion            string
)

// init wires the per-target handlers into the router and registers the
// command-line flags that back the package-level flag variables.
func init() {
	r.Handle("ubuntu", writeUbuntuUserData)
	r.Handle("flex", writeFlexUserData)

	flags := Command.Flags()

	// GPU-related switches.
	flags.BoolVar(&flagHasGPU, "gpu", false, "Indicates whether the node has GPU. This may affect the generated userdata.")
	flags.BoolVar(&flagEnableNvidiaGPURuntime, "nvidia-gpu", false, "Enable Nvidia GPU runtime in containerd configuration.")

	// Binary selection.
	flags.StringVar(&flagArch, "arch", "amd64",
		"CPU architecture for the flex node binary (e.g. amd64, arm64).")
	flags.StringVar(&flagKubeVersion, "k8s-version", "1.33.3",
		"Kubernetes version for the downloaded binaries (e.g. 1.33.3).")

	// Output format.
	variantUsage := fmt.Sprintf("Output variant: %q produces cloud-init YAML user data, %q produces an equivalent standalone bash script.", variantCloudInit, variantScript)
	flags.StringVar(&flagVariant, "variant", variantCloudInit, variantUsage)
}
Expand Down Expand Up @@ -56,7 +62,12 @@ func marshalUserData(ud *cloudinit.UserData, w io.Writer) error {
}

func writeFlexUserData(ctx context.Context, w io.Writer) error {
ud, err := flex.UserData(flagHasGPU, "1.33.3", configcmd.DefaultKubeadmConfig(ctx))
ud, err := flex.UserData(
flex.WithEnableNvidiaGPURuntime(flagEnableNvidiaGPURuntime),
flex.WithArch(flagArch),
flex.WithKubeVersion(flagKubeVersion),
flex.WithKubeadmConfig(configcmd.DefaultKubeadmConfig(ctx)),
)
if err != nil {
return fmt.Errorf("generating flex userdata: %w", err)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,9 @@ func (srv *agentpoolsServer) CreateOrUpdate(ctx context.Context, req *api.Create
// if err != nil {
// return nil, err
// }
userData, err := flex.UserData(false, "1.33.3", kubeadmConfig)
userData, err := flex.UserData(
flex.WithKubeadmConfig(kubeadmConfig),
)
if err != nil {
return nil, err
}
Expand Down
5 changes: 4 additions & 1 deletion plugin/pkg/services/agentpools/nebius/instance/agentpools.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,10 @@ func (srv *agentPoolsServer) CreateOrUpdate(
// TODO: get gpu info from spec (might need to infer from SKU)
hasGPU := strings.Contains(apSpec.GetImageFamily(), "cuda")
// TODO: get the k8s version from spec
ud, err := flex.UserData(hasGPU, "1.33.3", kubeadmConfig)
ud, err := flex.UserData(
flex.WithEnableNvidiaGPURuntime(hasGPU),
flex.WithKubeadmConfig(kubeadmConfig),
)
if err != nil {
return nil, fmt.Errorf("failed to generate userdata: %w", err)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
mkdir -p /tmp/flex
curl -L -o /tmp/flex/aks-flex-node-linux-{{ .Arch }}.tar.gz https://github.com/Azure/AKSFlexNode/releases/download/{{ .Version }}/aks-flex-node-linux-{{ .Arch }}.tar.gz
tar -xzf /tmp/flex/aks-flex-node-linux-{{ .Arch }}.tar.gz -C /tmp/flex
mv /tmp/flex/aks-flex-node-linux-{{ .Arch }} /tmp/flex/aks-flex-node
chmod +x /tmp/flex/aks-flex-node
/tmp/flex/aks-flex-node apply -f /tmp/flex-config.json
rm -rf /tmp/flex
123 changes: 110 additions & 13 deletions plugin/pkg/services/agentpools/userdata/flex/flex.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
package flex

import (
"bytes"
_ "embed"
"encoding/json"
"fmt"
"maps"
"strings"
"text/template"

"github.com/Azure/AKSFlexNode/components/api"
"github.com/Azure/AKSFlexNode/components/cri"
Expand All @@ -17,6 +21,82 @@ import (
"github.com/Azure/aks-flex/plugin/pkg/util/cloudinit"
)

// bootstrapTmpl holds the raw text of assets/bootstrap.sh.tmpl, embedded into
// the binary at build time.
//
//go:embed assets/bootstrap.sh.tmpl
var bootstrapTmpl string

// bootstrapTemplate is the parsed form of bootstrapTmpl. Because it is wrapped
// in template.Must, a template that fails to parse panics during package
// initialization rather than surfacing as an error at render time.
var bootstrapTemplate = template.Must(template.New("bootstrap.sh").Parse(bootstrapTmpl))

// Pinned release tag and default option values for flex node userdata.
const (
// flexNodeVersion is the value rendered as Version in the bootstrap script
// template.
flexNodeVersion = "v0.0.13"
// defaultArch is the Arch that defaultOptions starts from.
defaultArch = "amd64"
// defaultKubeVer is the KubeVersion that defaultOptions starts from.
defaultKubeVer = "1.33.3"
)

// Options configures how the flex node userdata is generated. Values are
// normally set through the With* functional options rather than directly.
type Options struct {
// EnableNvidiaGPURuntime requests the Nvidia runtime in the containerd
// GPU configuration.
EnableNvidiaGPURuntime bool
// KubeVersion is the Kubernetes version for the downloaded binaries. A
// leading "v" is stripped by validate, and an empty value is rejected.
KubeVersion string
// Arch is the CPU architecture of the flex node binary. It must be one of
// the keys of supportedArchs.
Arch string
// KubeadmConfig is the kubeadm join configuration.
KubeadmConfig *kubeadmapi.Config
}

// Option is a functional option for [UserData]. Each Option mutates the
// Options value it is given; UserData applies them in argument order on top of
// the defaults, so a later option overrides an earlier one for the same field.
type Option func(*Options)

// WithEnableNvidiaGPURuntime returns an Option that turns the containerd
// Nvidia GPU runtime on or off.
func WithEnableNvidiaGPURuntime(enable bool) Option {
	return func(opts *Options) {
		opts.EnableNvidiaGPURuntime = enable
	}
}

// WithKubeVersion returns an Option that selects the Kubernetes version for
// the downloaded binaries.
func WithKubeVersion(v string) Option {
	return func(opts *Options) {
		opts.KubeVersion = v
	}
}

// WithArch returns an Option that selects the CPU architecture of the flex
// node binary (e.g. "amd64", "arm64").
func WithArch(arch string) Option {
	return func(opts *Options) {
		opts.Arch = arch
	}
}

// WithKubeadmConfig returns an Option that supplies the kubeadm join
// configuration.
func WithKubeadmConfig(cfg *kubeadmapi.Config) Option {
	return func(opts *Options) {
		opts.KubeadmConfig = cfg
	}
}

// defaultOptions returns a fresh Options carrying the package defaults for
// the Kubernetes version and architecture; every other field is left at its
// zero value.
func defaultOptions() *Options {
	opts := new(Options)
	opts.KubeVersion = defaultKubeVer
	opts.Arch = defaultArch
	return opts
}

// supportedArchs is the set of CPU architectures for which flex node binaries
// are published. It is used as a membership set: looking up an architecture
// that is not listed yields false, which is how validate rejects it.
var supportedArchs = map[string]bool{
"amd64": true,
"arm64": true,
}

// validate performs least-effort validation on the options. This is
// intentionally minimal to catch obvious mistakes for ad-hoc values; callers
// should perform more thorough validation beforehand.
//
// Besides checking, validate normalizes o.KubeVersion in place: surrounding
// whitespace and a single leading "v" are removed, so "v1.33.3" and
// " 1.33.3\n" both become "1.33.3".
//
// It returns an error when o.Arch is not a key of supportedArchs, or when the
// normalized KubeVersion is empty.
func (o *Options) validate() error {
	if !supportedArchs[o.Arch] {
		return fmt.Errorf("unsupported arch %q, supported: amd64, arm64", o.Arch)
	}

	// Normalize before the emptiness check so that whitespace-only input (or a
	// bare "v") is rejected instead of being accepted as a version.
	o.KubeVersion = strings.TrimPrefix(strings.TrimSpace(o.KubeVersion), "v")
	if o.KubeVersion == "" {
		return fmt.Errorf("kube version must not be empty")
	}
	return nil
}

// bootstrapParams holds the template parameters for the bootstrap script.
type bootstrapParams struct {
// Arch is the CPU architecture suffix substituted into the release
// archive name.
Arch string
// Version is the release tag substituted into the download URL.
Version string
}

func flexMetadata[T proto.Message](name string) *api.Metadata {
var zero T
typeName := string(zero.ProtoReflect().Descriptor().FullName())
Expand All @@ -27,12 +107,12 @@ func flexMetadata[T proto.Message](name string) *api.Metadata {
}

func resolveFlexComponentConfigs(
hasGPU bool,
enableNvidiaGPURuntime bool,
kubeVersion string,
kubeadmConfig *kubeadmapi.Config,
) ([]byte, error) {
startCRISpecBuilder := cri.StartContainerdServiceSpec_builder{}
if hasGPU {
if enableNvidiaGPURuntime {
startCRISpecBuilder.GpuConfig = cri.GPUConfig_builder{
NvidiaRuntime: cri.NvidiaRuntime_builder{}.Build(),
}.Build()
Expand Down Expand Up @@ -104,8 +184,33 @@ func resolveFlexComponentConfigs(
return b, nil
}

func UserData(hasGPU bool, kubeVersion string, kubeadmConfig *kubeadmapi.Config) (*cloudinit.UserData, error) {
componentConfigsJSON, err := resolveFlexComponentConfigs(hasGPU, kubeVersion, kubeadmConfig)
func renderBootstrapScript(arch string) (string, error) {
var buf bytes.Buffer
if err := bootstrapTemplate.Execute(&buf, bootstrapParams{
Arch: arch,
Version: flexNodeVersion,
}); err != nil {
return "", fmt.Errorf("rendering bootstrap script: %w", err)
}
return buf.String(), nil
}

func UserData(opts ...Option) (*cloudinit.UserData, error) {
o := defaultOptions()
for _, opt := range opts {
opt(o)
}

if err := o.validate(); err != nil {
return nil, err
}

componentConfigsJSON, err := resolveFlexComponentConfigs(o.EnableNvidiaGPURuntime, o.KubeVersion, o.KubeadmConfig)
if err != nil {
return nil, err
}

bootstrapScript, err := renderBootstrapScript(o.Arch)
if err != nil {
return nil, err
}
Expand All @@ -124,15 +229,7 @@ func UserData(hasGPU bool, kubeVersion string, kubeadmConfig *kubeadmapi.Config)
},
RunCmd: []any{
[]string{"set", "-e"},
strings.Join([]string{
"mkdir -p /tmp/flex",
// TODO: this should be overridable
"curl -L -o /tmp/flex/aks-flex-node-linux-amd64.tar.gz https://github.com/Azure/AKSFlexNode/releases/download/v0.0.12/aks-flex-node-linux-amd64.tar.gz",
"tar -xzf /tmp/flex/aks-flex-node-linux-amd64.tar.gz -C /tmp/flex",
"mv /tmp/flex/aks-flex-node-linux-amd64 /tmp/flex/aks-flex-node",
"chmod +x /tmp/flex/aks-flex-node",
"/tmp/flex/aks-flex-node apply -f /tmp/flex-config.json",
}, "\n"),
bootstrapScript,
},
}

Expand Down
113 changes: 113 additions & 0 deletions plugin/pkg/services/agentpools/userdata/flex/flex_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package flex

import (
"encoding/json"
"strings"
"testing"

kubeadm "github.com/Azure/aks-flex/plugin/pkg/services/agentpools/api/features/kubeadm"
Expand All @@ -23,3 +24,115 @@ func Test_resolveFlexComponentConfigs_basic(t *testing.T) {
t.Fatalf("failed to unmarshal generated config: %v", err)
}
}

// TestUserData_defaults checks that userdata generated with only a kubeadm
// config mentions the default amd64 architecture.
func TestUserData_defaults(t *testing.T) {
	cfg := kubeadm.Config_builder{}.Build()

	userData, err := UserData(WithKubeadmConfig(cfg))
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	raw, err := userData.Marshal()
	if err != nil {
		t.Fatalf("failed to marshal userdata: %v", err)
	}

	// With no arch option, the rendered userdata should reference amd64.
	if rendered := string(raw); !strings.Contains(rendered, "amd64") {
		t.Error("expected default arch amd64 in userdata")
	}
}

// TestUserData_arm64 checks that selecting arm64 yields userdata that mentions
// arm64 and never amd64.
func TestUserData_arm64(t *testing.T) {
	cfg := kubeadm.Config_builder{}.Build()

	userData, err := UserData(WithArch("arm64"), WithKubeadmConfig(cfg))
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	raw, err := userData.Marshal()
	if err != nil {
		t.Fatalf("failed to marshal userdata: %v", err)
	}

	rendered := string(raw)
	mentionsARM := strings.Contains(rendered, "arm64")
	mentionsAMD := strings.Contains(rendered, "amd64")

	if !mentionsARM {
		t.Error("expected arm64 in userdata")
	}
	if mentionsAMD {
		t.Error("unexpected amd64 in userdata when arm64 was specified")
	}
}

// TestUserData_invalidArch checks that an architecture outside the supported
// set is rejected with an "unsupported arch" error.
func TestUserData_invalidArch(t *testing.T) {
	cfg := kubeadm.Config_builder{}.Build()

	_, err := UserData(WithArch("mips64"), WithKubeadmConfig(cfg))

	switch {
	case err == nil:
		t.Fatal("expected error for unsupported arch")
	case !strings.Contains(err.Error(), "unsupported arch"):
		t.Errorf("unexpected error message: %v", err)
	}
}

// TestUserData_invalidKubeVersion checks that an empty Kubernetes version is
// rejected with a "must not be empty" error.
func TestUserData_invalidKubeVersion(t *testing.T) {
	cfg := kubeadm.Config_builder{}.Build()

	_, err := UserData(WithKubeVersion(""), WithKubeadmConfig(cfg))

	switch {
	case err == nil:
		t.Fatal("expected error for empty kube version")
	case !strings.Contains(err.Error(), "must not be empty"):
		t.Errorf("unexpected error message: %v", err)
	}
}

// TestUserData_trimsLeadingV checks that a "v"-prefixed Kubernetes version is
// accepted and appears in the userdata without the prefix.
func TestUserData_trimsLeadingV(t *testing.T) {
	cfg := kubeadm.Config_builder{}.Build()

	userData, err := UserData(WithKubeVersion("v1.33.3"), WithKubeadmConfig(cfg))
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	raw, err := userData.Marshal()
	if err != nil {
		t.Fatalf("failed to marshal userdata: %v", err)
	}

	rendered := string(raw)
	hasPrefixed := strings.Contains(rendered, "v1.33.3")
	hasBare := strings.Contains(rendered, "1.33.3")

	if hasPrefixed {
		t.Error("expected leading 'v' to be trimmed from kube version")
	}
	if !hasBare {
		t.Error("expected kube version 1.33.3 in userdata")
	}
}

// TestUserData_preReleaseKubeVersion checks that a pre-release Kubernetes
// version string passes option validation.
func TestUserData_preReleaseKubeVersion(t *testing.T) {
	cfg := kubeadm.Config_builder{}.Build()

	if _, err := UserData(WithKubeVersion("1.33.0-rc.1"), WithKubeadmConfig(cfg)); err != nil {
		t.Fatalf("expected pre-release kube version to be valid, got: %v", err)
	}
}