From 492b7a32703cc4b71de5a871cea9762efdccd668 Mon Sep 17 00:00:00 2001 From: LIU ZHE YOU Date: Tue, 5 May 2026 11:16:55 +0800 Subject: [PATCH 1/5] Implement execution server and task runner for coordinator protocol - Add `messages_test.go` to test message decoding and encoding functionalities. - Introduce `serde.go` for serialization of various data types to Airflow's format. - Create `serde_test.go` to validate serialization logic and ensure correctness. - Implement `server.go` to handle communication with the supervisor and manage task execution. - Add `task_runner.go` to execute tasks based on received startup details and handle success/failure. --- .../bundle/bundlev1/bundlev1server/server.go | 120 +++++- go-sdk/bundle/bundlev1/registry.go | 87 +++- go-sdk/bundle/bundlev1/task.go | 7 +- go-sdk/go.mod | 2 + go-sdk/go.sum | 4 + go-sdk/pkg/execution/client.go | 183 ++++++++ go-sdk/pkg/execution/comms.go | 128 ++++++ go-sdk/pkg/execution/comms_test.go | 166 ++++++++ go-sdk/pkg/execution/dag_parser.go | 53 +++ go-sdk/pkg/execution/frames.go | 229 ++++++++++ go-sdk/pkg/execution/frames_test.go | 200 +++++++++ go-sdk/pkg/execution/integration_test.go | 358 ++++++++++++++++ go-sdk/pkg/execution/logger.go | 158 +++++++ go-sdk/pkg/execution/logger_test.go | 134 ++++++ go-sdk/pkg/execution/messages.go | 394 ++++++++++++++++++ go-sdk/pkg/execution/messages_test.go | 258 ++++++++++++ go-sdk/pkg/execution/serde.go | 279 +++++++++++++ go-sdk/pkg/execution/serde_test.go | 226 ++++++++++ go-sdk/pkg/execution/server.go | 155 +++++++ go-sdk/pkg/execution/task_runner.go | 117 ++++++ go-sdk/pkg/sdkcontext/keys.go | 8 + go-sdk/sdk/client.go | 2 +- go-sdk/sdk/connection.go | 6 +- 23 files changed, 3240 insertions(+), 34 deletions(-) create mode 100644 go-sdk/pkg/execution/client.go create mode 100644 go-sdk/pkg/execution/comms.go create mode 100644 go-sdk/pkg/execution/comms_test.go create mode 100644 go-sdk/pkg/execution/dag_parser.go create mode 100644 go-sdk/pkg/execution/frames.go create mode 100644 go-sdk/pkg/execution/frames_test.go create mode 100644 go-sdk/pkg/execution/integration_test.go create mode 100644 go-sdk/pkg/execution/logger.go create mode 100644 go-sdk/pkg/execution/logger_test.go create mode 100644 go-sdk/pkg/execution/messages.go create mode 100644 go-sdk/pkg/execution/messages_test.go create mode 100644 go-sdk/pkg/execution/serde.go create mode 100644 go-sdk/pkg/execution/serde_test.go create mode 100644 go-sdk/pkg/execution/server.go create mode 100644 go-sdk/pkg/execution/task_runner.go diff --git a/go-sdk/bundle/bundlev1/bundlev1server/server.go b/go-sdk/bundle/bundlev1/bundlev1server/server.go index 67d212ff06976..e2e795dc1dd60 100644 --- a/go-sdk/bundle/bundlev1/bundlev1server/server.go +++ b/go-sdk/bundle/bundlev1/bundlev1server/server.go @@ -32,9 +32,25 @@ import ( "github.com/apache/airflow/go-sdk/bundle/bundlev1/bundlev1server/impl" "github.com/apache/airflow/go-sdk/pkg/bundles/shared" "github.com/apache/airflow/go-sdk/pkg/config" + "github.com/apache/airflow/go-sdk/pkg/execution" ) -var versionInfo *bool = flag.Bool("bundle-metadata", false, "show the embedded bundle info") +// Flags. The bundle-metadata flag is the existing ADR 0001 introspection +// hook; --comm and --logs select the coordinator-mode protocol added by +// ADR 0003. All three are read by Serve to choose a server mode below. +var ( + versionInfo = flag.Bool("bundle-metadata", false, "show the embedded bundle info") + commAddr = flag.String( + "comm", + "", + "host:port of the supervisor's coordinator comm channel (selects coordinator mode)", + ) + logsAddr = flag.String( + "logs", + "", + "host:port of the supervisor's coordinator logs channel (selects coordinator mode)", + ) +) // ServeOpt is an interface for defining options that can be passed to the // Serve function. Each implementation modifies the ServeConfig being @@ -52,24 +68,30 @@ func (s serveConfigFunc) ApplyServeOpt(in *ServerConfig) error { type ServerConfig struct{} -// Serve is the entrypoint for your bundle, and sets it up ready for Airflow's Go Worker to use +// serveMode tags the protocol the binary will speak this run. +type serveMode int + +const ( + modePlugin serveMode = iota // go-plugin gRPC (existing Edge Worker path) + modeMetadataDump // --bundle-metadata: print BundleInfo JSON + modeCoordinator // --comm/--logs: msgpack-over-IPC (ADR 0003) + modeUsageError // misuse: print usage and exit non-zero +) + +// Serve is the entrypoint for your bundle, and sets it up ready for Airflow's +// Go Worker (go-plugin) or Python supervisor (coordinator protocol) to use. +// +// The mode is decided from CLI flags and process environment, so user code is +// always one line: // -// Zero or more options to configure the server may also be passed. There are no options yet, this is to allow -// future changes without breaking compatibility +// func main() { bundlev1server.Serve(&myBundle{}) } +// +// Zero or more options to configure the server may also be passed. There are +// no options yet; the parameter exists to allow future additions without +// breaking compatibility. func Serve(bundle bundlev1.BundleProvider, opts ...ServeOpt) error { config.SetupViper("") - hcLogger := hclog.New(&hclog.LoggerOptions{ - Level: hclog.Trace, - Output: os.Stderr, - JSONFormat: true, - IncludeLocation: true, - AdditionalLocationOffset: 3, - }) - - log := slog.New(hclogslog.Adapt(hcLogger)) - slog.SetDefault(log) - flag.Parse() serveConfig := &ServerConfig{} @@ -77,16 +99,69 @@ func Serve(bundle bundlev1.BundleProvider, opts ...ServeOpt) error { c.ApplyServeOpt(serveConfig) } + switch decideMode() { + case modeMetadataDump: + return dumpBundleMetadata(bundle) + case modeCoordinator: + // In coordinator mode the supervisor reads the logs channel for + // structured records, so configuring the hclog/stderr default + // logger here is unnecessary — execution.Serve installs its own + // slog handler against the logs socket before any user code runs. + return execution.Serve(bundle, *commAddr, *logsAddr) + case modePlugin: + installPluginLogger() + return servePlugin(bundle) + case modeUsageError: + fmt.Fprintln(os.Stderr, "error: --comm and --logs must be supplied together") + flag.CommandLine.SetOutput(os.Stderr) + flag.Usage() + os.Exit(2) + } + return nil +} + +func decideMode() serveMode { if *versionInfo { - meta := bundle.GetBundleVersion() - data, err := json.MarshalIndent(meta, "", " ") - if err != nil { - return err - } - fmt.Println(string(data)) - return nil + return modeMetadataDump + } + commSet := *commAddr != "" + logsSet := *logsAddr != "" + if commSet && logsSet { + return modeCoordinator + } + if commSet || logsSet { + // Partial use is a hard error per ADR 0003: both flags are + // required, otherwise the supervisor is misconfigured and the + // runtime should fail loudly rather than fall through to + // go-plugin (which would hang on the missing magic-cookie). + return modeUsageError + } + return modePlugin +} + +func dumpBundleMetadata(bundle bundlev1.BundleProvider) error { + meta := bundle.GetBundleVersion() + data, err := json.MarshalIndent(meta, "", " ") + if err != nil { + return err } + fmt.Println(string(data)) + return nil +} + +func installPluginLogger() { + hcLogger := hclog.New(&hclog.LoggerOptions{ + Level: hclog.Trace, + Output: os.Stderr, + JSONFormat: true, + IncludeLocation: true, + AdditionalLocationOffset: 3, + }) + log := slog.New(hclogslog.Adapt(hcLogger)) + slog.SetDefault(log) +} +func servePlugin(bundle bundlev1.BundleProvider) error { pluginConfig := &plugin.ServeConfig{ HandshakeConfig: shared.Handshake, Plugins: plugin.PluginSet{ @@ -99,6 +174,5 @@ func Serve(bundle bundlev1.BundleProvider, opts ...ServeOpt) error { // Likely never returns plugin.Serve(pluginConfig) - return nil } diff --git a/go-sdk/bundle/bundlev1/registry.go b/go-sdk/bundle/bundlev1/registry.go index 8d902efa081f6..c7be26a47592b 100644 --- a/go-sdk/bundle/bundlev1/registry.go +++ b/go-sdk/bundle/bundlev1/registry.go @@ -43,9 +43,38 @@ type ( AddDag(dagId string) Dag } + // TaskInfo describes a registered task. Coordinator-mode DAG parsing uses + // it to render the per-task block of a DagFileParsingResult. + TaskInfo struct { + // ID is the user-visible task id (the function name unless overridden + // via AddTaskWithName). + ID string + // TypeName is the unqualified Go function name (e.g. "extract"). + TypeName string + // PkgPath is the Go package path (e.g. "main", "github.com/x/y"). + PkgPath string + } + + // DagInfo describes a registered dag together with its tasks in + // registration order. + DagInfo struct { + DagID string + Tasks []TaskInfo + } + + // EnumerableBundle exposes the dag/task identity recorded by + // RegisterDags. The default registry implements it; the coordinator-mode + // runtime relies on it for the DAG-parse one-shot. + EnumerableBundle interface { + OrderedDags() []DagInfo + } + registry struct { sync.RWMutex taskFuncMap map[string]map[string]Task + taskInfo map[string]map[string]TaskInfo + dagOrder []string + taskOrder map[string][]string } ) @@ -64,22 +93,38 @@ func (d dagShim) AddTaskWithName(taskId string, fn any) { // Function New creates a new bundle on which Dag and Tasks can be registered func New() Registry { - return ®istry{taskFuncMap: make(map[string]map[string]Task)} + return ®istry{ + taskFuncMap: make(map[string]map[string]Task), + taskInfo: make(map[string]map[string]TaskInfo), + taskOrder: make(map[string][]string), + } +} + +func splitFullName(fullName string) (typeName, pkgPath string) { + // fullName looks like "main.extract" or "github.com/x/y.MyTask"; method + // values get a "-fm" suffix. + lastDot := strings.LastIndex(fullName, ".") + if lastDot < 0 { + return strings.TrimSuffix(fullName, "-fm"), "" + } + return strings.TrimSuffix(fullName[lastDot+1:], "-fm"), fullName[:lastDot] } func getFnName(fn reflect.Value) string { fullName := runtime.FuncForPC(fn.Pointer()).Name() - parts := strings.Split(fullName, ".") - fnName := parts[len(parts)-1] - // Go adds `-fm` suffix to a method names - return strings.TrimSuffix(fnName, "-fm") + name, _ := splitFullName(fullName) + return name } func (r *registry) AddDag(dagId string) Dag { + r.RWMutex.Lock() + defer r.RWMutex.Unlock() if _, exists := r.taskFuncMap[dagId]; exists { panic(fmt.Errorf("Dag %q already exists in bundle", dagId)) } r.taskFuncMap[dagId] = make(map[string]Task) + r.taskInfo[dagId] = make(map[string]TaskInfo) + r.dagOrder = append(r.dagOrder, dagId) return dagShim{dagId, r} } @@ -101,21 +146,28 @@ func (r *registry) registerTaskWithName(dagId, taskId string, fn any) { panic(fmt.Errorf("error registering task %q for DAG %q: %w", taskId, dagId, err)) } + val := reflect.ValueOf(fn) + fullName := runtime.FuncForPC(val.Pointer()).Name() + typeName, pkgPath := splitFullName(fullName) + r.RWMutex.Lock() defer r.RWMutex.Unlock() dagTasks, exists := r.taskFuncMap[dagId] - if !exists { dagTasks = make(map[string]Task) r.taskFuncMap[dagId] = dagTasks + r.taskInfo[dagId] = make(map[string]TaskInfo) + r.dagOrder = append(r.dagOrder, dagId) } - _, exists = dagTasks[taskId] - if exists { + if _, exists := dagTasks[taskId]; exists { panic(fmt.Errorf("taskId %q is already registered for DAG %q", taskId, dagId)) } + dagTasks[taskId] = task + r.taskInfo[dagId][taskId] = TaskInfo{ID: taskId, TypeName: typeName, PkgPath: pkgPath} + r.taskOrder[dagId] = append(r.taskOrder[dagId], taskId) } func (r *registry) LookupTask(dagId, taskId string) (task Task, exists bool) { @@ -129,3 +181,22 @@ func (r *registry) LookupTask(dagId, taskId string) (task Task, exists bool) { task, exists = dagTasks[taskId] return task, exists } + +// OrderedDags returns the registered dags in the order AddDag was called, +// each with its tasks in the order AddTask / AddTaskWithName was called. The +// returned slice is freshly allocated; callers may mutate it freely. +func (r *registry) OrderedDags() []DagInfo { + r.RLock() + defer r.RUnlock() + + out := make([]DagInfo, 0, len(r.dagOrder)) + for _, dagID := range r.dagOrder { + taskIDs := r.taskOrder[dagID] + tasks := make([]TaskInfo, 0, len(taskIDs)) + for _, tid := range taskIDs { + tasks = append(tasks, r.taskInfo[dagID][tid]) + } + out = append(out, DagInfo{DagID: dagID, Tasks: tasks}) + } + return out +} diff --git a/go-sdk/bundle/bundlev1/task.go b/go-sdk/bundle/bundlev1/task.go index 4271f4892bb6f..5277f40681cad 100644 --- a/go-sdk/bundle/bundlev1/task.go +++ b/go-sdk/bundle/bundlev1/task.go @@ -45,7 +45,12 @@ func NewTaskFunction(fn any) (Task, error) { func (f *taskFunction) Execute(ctx context.Context, logger *slog.Logger) error { fnType := f.fn.Type() - sdkClient := sdk.NewClient() + var sdkClient sdk.Client + if injected, ok := ctx.Value(sdkcontext.SdkClientContextKey).(sdk.Client); ok { + sdkClient = injected + } else { + sdkClient = sdk.NewClient() + } reflectArgs := make([]reflect.Value, fnType.NumIn()) for i := range reflectArgs { diff --git a/go-sdk/go.mod b/go-sdk/go.mod index bfb400eee9432..f3bfcd4b0f600 100644 --- a/go-sdk/go.mod +++ b/go-sdk/go.mod @@ -16,6 +16,7 @@ require ( github.com/spf13/pflag v1.0.10 github.com/spf13/viper v1.20.1 github.com/stretchr/testify v1.11.1 + github.com/vmihailenco/msgpack/v5 v5.4.1 google.golang.org/grpc v1.79.3 google.golang.org/protobuf v1.36.10 resty.dev/v3 v3.0.0-beta.2 @@ -38,6 +39,7 @@ require ( github.com/spf13/afero v1.12.0 // indirect github.com/stretchr/objx v0.5.2 // indirect github.com/subosito/gotenv v1.6.0 // indirect + github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect go.opentelemetry.io/otel v1.41.0 // indirect go.opentelemetry.io/otel/trace v1.41.0 // indirect go.uber.org/multierr v1.10.0 // indirect diff --git a/go-sdk/go.sum b/go-sdk/go.sum index 5b7940672b1ca..a275d6b63c8be 100644 --- a/go-sdk/go.sum +++ b/go-sdk/go.sum @@ -114,6 +114,10 @@ github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= +github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8= +github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok= +github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= +github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= go.opentelemetry.io/otel v1.41.0 h1:YlEwVsGAlCvczDILpUXpIpPSL/VPugt7zHThEMLce1c= diff --git a/go-sdk/pkg/execution/client.go b/go-sdk/pkg/execution/client.go new file mode 100644 index 0000000000000..0e047803b6859 --- /dev/null +++ b/go-sdk/pkg/execution/client.go @@ -0,0 +1,183 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package execution + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/apache/airflow/go-sdk/pkg/api" + "github.com/apache/airflow/go-sdk/sdk" +) + +// CoordinatorClient implements sdk.Client by communicating with the Airflow supervisor +// over the comm socket using msgpack-framed IPC instead of HTTP. +type CoordinatorClient struct { + comm *CoordinatorComm + details *StartupDetails +} + +var _ sdk.Client = (*CoordinatorClient)(nil) + +// NewCoordinatorClient creates a new client backed by the comm socket. +func NewCoordinatorClient(comm *CoordinatorComm, details *StartupDetails) *CoordinatorClient { + return &CoordinatorClient{ + comm: comm, + details: details, + } +} + +// GetVariable requests a variable value from the supervisor. +func (c *CoordinatorClient) GetVariable(_ context.Context, key string) (string, error) { + resp, err := c.comm.Communicate(GetVariableMsg{Key: key}.toMap()) + if err != nil { + return "", err + } + + result, err := decodeVariableResult(resp) + if err != nil { + return "", fmt.Errorf("decoding variable result: %w", err) + } + + if result.Value == nil { + return "", fmt.Errorf("%w: %q", sdk.VariableNotFound, key) + } + + switch v := result.Value.(type) { + case string: + return v, nil + default: + // If the value is not a string, marshal it to JSON. + b, err := json.Marshal(v) + if err != nil { + return "", fmt.Errorf("marshaling variable value: %w", err) + } + return string(b), nil + } +} + +// UnmarshalJSONVariable gets a variable and unmarshals its JSON value. +func (c *CoordinatorClient) UnmarshalJSONVariable( + ctx context.Context, + key string, + pointer any, +) error { + val, err := c.GetVariable(ctx, key) + if err != nil { + return err + } + return json.Unmarshal([]byte(val), pointer) +} + +// GetConnection requests a connection from the supervisor. +func (c *CoordinatorClient) GetConnection( + _ context.Context, + connID string, +) (sdk.Connection, error) { + resp, err := c.comm.Communicate(GetConnectionMsg{ConnID: connID}.toMap()) + if err != nil { + return sdk.Connection{}, err + } + + result, err := decodeConnectionResult(resp) + if err != nil { + return sdk.Connection{}, fmt.Errorf("decoding connection result: %w", err) + } + + conn := sdk.Connection{ + ID: result.ConnID, + Type: result.ConnType, + Host: result.Host, + Port: result.Port, + Path: result.Schema, + } + + if result.Login != "" { + login := result.Login + conn.Login = &login + } + if result.Password != "" { + password := result.Password + conn.Password = &password + } + if result.Extra != "" { + conn.Extra = map[string]any{} + if err := json.Unmarshal([]byte(result.Extra), &conn.Extra); err != nil { + return conn, fmt.Errorf("parsing connection extra: %w", err) + } + } + + return conn, nil +} + +// GetXCom requests an XCom value from the supervisor. +func (c *CoordinatorClient) GetXCom( + _ context.Context, + dagId, runId, taskId string, + mapIndex *int, + key string, + _ any, +) (any, error) { + msg := GetXComMsg{ + Key: key, + DagID: dagId, + TaskID: taskId, + RunID: runId, + } + if mapIndex != nil { + msg.MapIndex = mapIndex + } + + resp, err := c.comm.Communicate(msg.toMap()) + if err != nil { + return nil, err + } + + result, err := decodeXComResult(resp) + if err != nil { + return nil, fmt.Errorf("decoding xcom result: %w", err) + } + + return result.Value, nil +} + +// PushXCom sends an XCom value to the supervisor. +func (c *CoordinatorClient) PushXCom( + _ context.Context, + ti api.TaskInstance, + key string, + value any, +) error { + mapIndex := -1 + if ti.MapIndex != nil && *ti.MapIndex != -1 { + mapIndex = *ti.MapIndex + } + + msg := SetXComMsg{ + Key: key, + Value: value, + DagID: ti.DagId, + TaskID: ti.TaskId, + RunID: ti.RunId, + MapIndex: mapIndex, + } + + _, err := c.comm.Communicate(msg.toMap()) + return err +} diff --git a/go-sdk/pkg/execution/comms.go b/go-sdk/pkg/execution/comms.go new file mode 100644 index 0000000000000..81b47caf1600c --- /dev/null +++ b/go-sdk/pkg/execution/comms.go @@ -0,0 +1,128 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package execution + +import ( + "fmt" + "io" + "log/slog" + "sync" + "sync/atomic" +) + +// CoordinatorComm manages bidirectional communication with the Airflow supervisor +// over a length-prefixed msgpack socket connection. +type CoordinatorComm struct { + reader io.Reader + writer io.Writer + nextID atomic.Int32 + logger *slog.Logger + + wmu sync.Mutex // serialises writes + rmu sync.Mutex // serialises reads +} + +// NewCoordinatorComm creates a new communication channel. +func NewCoordinatorComm(reader io.Reader, writer io.Writer, logger *slog.Logger) *CoordinatorComm { + return &CoordinatorComm{ + reader: reader, + writer: writer, + logger: logger, + } +} + +// ReadMessage reads and decodes one frame from the comm socket. +// It returns the raw IncomingFrame with decoded map bodies. +func (c *CoordinatorComm) ReadMessage() (IncomingFrame, error) { + c.rmu.Lock() + defer c.rmu.Unlock() + frame, err := readFrame(c.reader) + if err != nil { + return IncomingFrame{}, fmt.Errorf("reading frame: %w", err) + } + c.logger.Debug("Received frame", "id", frame.ID) + return frame, nil +} + +// SendRequest sends a request frame (2-element: [id, body]) to the supervisor. +func (c *CoordinatorComm) SendRequest(id int, body map[string]any) error { + payload, err := encodeRequest(id, body) + if err != nil { + return fmt.Errorf("encoding request: %w", err) + } + c.logger.Debug("Sending request", "id", id) + c.wmu.Lock() + defer c.wmu.Unlock() + return writeFrame(c.writer, payload) +} + +// Communicate sends a request and waits for the corresponding response. +// This is a synchronous request-response: the caller sends a request and blocks +// until the next frame arrives. The protocol is single-threaded on the comm socket. +// +// If the response contains an error element, it is returned as an ApiError. +// Otherwise, the response body map is returned. +func (c *CoordinatorComm) Communicate(body map[string]any) (map[string]any, error) { + id := int(c.nextID.Add(1) - 1) + + if err := c.SendRequest(id, body); err != nil { + return nil, err + } + + frame, err := c.ReadMessage() + if err != nil { + return nil, err + } + + // Check for error in the response. + if frame.Err != nil { + errResp := decodeErrorResponse(frame.Err) + if errResp != nil { + return nil, &ApiError{ + Err: errResp.Error, + Detail: errResp.Detail, + } + } + } + + // Also check if the body itself is an ErrorResponse. + if frame.Body != nil { + if typ, _ := frame.Body["type"].(string); typ == "ErrorResponse" { + errResp := decodeErrorResponse(frame.Body) + return nil, &ApiError{ + Err: errResp.Error, + Detail: errResp.Detail, + } + } + } + + return frame.Body, nil +} + +// ApiError represents an error returned by the supervisor over the comm socket. +type ApiError struct { + Err string + Detail any +} + +func (e *ApiError) Error() string { + if e.Detail != nil { + return fmt.Sprintf("[%s] %v", e.Err, e.Detail) + } + return e.Err +} diff --git a/go-sdk/pkg/execution/comms_test.go b/go-sdk/pkg/execution/comms_test.go new file mode 100644 index 0000000000000..a5360b7358d22 --- /dev/null +++ b/go-sdk/pkg/execution/comms_test.go @@ -0,0 +1,166 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package execution + +import ( + "bytes" + "io" + "log/slog" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/vmihailenco/msgpack/v5" +) + +func TestCoordinatorCommReadMessage(t *testing.T) { + body := map[string]any{ + "type": "DagFileParseRequest", + "file": "/path/to/dags.go", + } + payload, err := encodeRequest(0, body) + require.NoError(t, err) + + var buf bytes.Buffer + require.NoError(t, writeFrame(&buf, payload)) + + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + comm := NewCoordinatorComm(&buf, io.Discard, logger) + + frame, err := comm.ReadMessage() + require.NoError(t, err) + assert.Equal(t, 0, frame.ID) + assert.Equal(t, "DagFileParseRequest", frame.Body["type"]) +} + +func TestCoordinatorCommSendRequest(t *testing.T) { + var buf bytes.Buffer + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + comm := NewCoordinatorComm(bytes.NewReader(nil), &buf, logger) + + body := map[string]any{ + "type": "GetVariable", + "key": "test_key", + } + err := comm.SendRequest(5, body) + require.NoError(t, err) + + frame, err := readFrame(&buf) + require.NoError(t, err) + assert.Equal(t, 5, frame.ID) + assert.Equal(t, "GetVariable", frame.Body["type"]) + assert.Equal(t, "test_key", frame.Body["key"]) +} + +func TestCoordinatorCommCommunicate(t *testing.T) { + // Prepare a response frame for the mock supervisor to "return". + responseBody := map[string]any{ + "type": "VariableResult", + "key": "my_var", + "value": "my_value", + } + responsePayload, err := encodeRequest(0, responseBody) + require.NoError(t, err) + + var responseBuf bytes.Buffer + require.NoError(t, writeFrame(&responseBuf, responsePayload)) + + var requestBuf bytes.Buffer + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + comm := NewCoordinatorComm(&responseBuf, &requestBuf, logger) + + result, err := comm.Communicate(GetVariableMsg{Key: "my_var"}.toMap()) + require.NoError(t, err) + assert.Equal(t, "VariableResult", result["type"]) + assert.Equal(t, "my_value", result["value"]) + + // Verify the request was sent. + sentFrame, err := readFrame(&requestBuf) + require.NoError(t, err) + assert.Equal(t, "GetVariable", sentFrame.Body["type"]) +} + +func TestCoordinatorCommCommunicateError(t *testing.T) { + // Build a 3-element response frame with an error. + responsePayload := encodeResponseFrame(t, 0, nil, map[string]any{ + "type": "ErrorResponse", + "error": "not_found", + "detail": "Variable not found", + }) + + var responseBuf bytes.Buffer + require.NoError(t, writeFrame(&responseBuf, responsePayload)) + + var requestBuf bytes.Buffer + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + comm := NewCoordinatorComm(&responseBuf, &requestBuf, logger) + + _, err := comm.Communicate(GetVariableMsg{Key: "missing"}.toMap()) + require.Error(t, err) + + apiErr, ok := err.(*ApiError) + require.True(t, ok) + assert.Equal(t, "not_found", apiErr.Err) +} + +func TestCoordinatorCommCommunicateBodyError(t *testing.T) { + // Error can also come in the body element of a 2-element frame. + errorBody := map[string]any{ + "type": "ErrorResponse", + "error": "server_error", + "detail": "internal failure", + } + responsePayload, err := encodeRequest(0, errorBody) + require.NoError(t, err) + + var responseBuf bytes.Buffer + require.NoError(t, writeFrame(&responseBuf, responsePayload)) + + var requestBuf bytes.Buffer + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + comm := NewCoordinatorComm(&responseBuf, &requestBuf, logger) + + _, err = comm.Communicate(GetVariableMsg{Key: "test"}.toMap()) + require.Error(t, err) + + apiErr, ok := err.(*ApiError) + require.True(t, ok) + assert.Equal(t, "server_error", apiErr.Err) +} + +func TestApiErrorFormat(t *testing.T) { + err := &ApiError{Err: "not_found", Detail: "Variable 'x' not found"} + assert.Equal(t, "[not_found] Variable 'x' not found", err.Error()) + + err2 := &ApiError{Err: "server_error"} + assert.Equal(t, "server_error", err2.Error()) +} + +// encodeResponseFrame encodes a 3-element response frame for testing. +func encodeResponseFrame(t *testing.T, id int, body, errBody map[string]any) []byte { + t.Helper() + var buf bytes.Buffer + enc := msgpack.NewEncoder(&buf) + enc.UseCompactInts(true) + + require.NoError(t, enc.EncodeArrayLen(3)) + require.NoError(t, enc.EncodeInt(int64(id))) + require.NoError(t, enc.Encode(body)) + require.NoError(t, enc.Encode(errBody)) + return buf.Bytes() +} diff --git a/go-sdk/pkg/execution/dag_parser.go b/go-sdk/pkg/execution/dag_parser.go new file mode 100644 index 0000000000000..734a5ded77f72 --- /dev/null +++ b/go-sdk/pkg/execution/dag_parser.go @@ -0,0 +1,53 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package execution + +import ( + "github.com/apache/airflow/go-sdk/bundle/bundlev1" +) + +// ParseDags processes a DagFileParseRequest by serialising every dag +// registered on bundle to DagSerialization v3 and returning the result as a +// DagFileParsingResult body. bundle is the materialised registry produced by +// running BundleProvider.RegisterDags. +func ParseDags(bundle bundlev1.Bundle, req *DagFileParseRequest) map[string]any { + fileloc := req.File + bundlePath := req.BundlePath + relativeFileloc := computeRelativeFileloc(fileloc, bundlePath) + + var dags []bundlev1.DagInfo + if enum, ok := bundle.(bundlev1.EnumerableBundle); ok { + dags = enum.OrderedDags() + } + + serializedDags := make([]any, 0, len(dags)) + for _, d := range dags { + serializedDags = append(serializedDags, map[string]any{ + "data": map[string]any{ + "__version": 3, + "dag": SerializeDag(d, fileloc, relativeFileloc), + }, + }) + } + + return map[string]any{ + "type": "DagFileParsingResult", + "fileloc": fileloc, + "serialized_dags": serializedDags, + } +} diff --git a/go-sdk/pkg/execution/frames.go b/go-sdk/pkg/execution/frames.go new file mode 100644 index 0000000000000..7458532d5aa18 --- /dev/null +++ b/go-sdk/pkg/execution/frames.go @@ -0,0 +1,229 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package execution + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + + "github.com/vmihailenco/msgpack/v5" +) + +// IncomingFrame represents a decoded frame received from the comm socket. +type IncomingFrame struct { + ID int + Body map[string]any + Err map[string]any // non-nil only for response frames (3-element arrays) +} + +// encodeRequest encodes a request frame (2-element msgpack array: [id, body]). +func encodeRequest(id int, body map[string]any) ([]byte, error) { + var buf bytes.Buffer + enc := msgpack.NewEncoder(&buf) + enc.UseCompactInts(true) + + if err := enc.EncodeArrayLen(2); err != nil { + return nil, err + } + if err := enc.EncodeInt(int64(id)); err != nil { + return nil, err + } + if err := enc.Encode(body); err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +// writeFrame writes a length-prefixed msgpack payload to the writer. +// Format: [4-byte big-endian length][payload bytes] +func writeFrame(w io.Writer, payload []byte) error { + prefix := make([]byte, 4) + binary.BigEndian.PutUint32(prefix, uint32(len(payload))) + if _, err := w.Write(prefix); err != nil { + return fmt.Errorf("writing length prefix: %w", err) + } + if _, err := w.Write(payload); err != nil { + return fmt.Errorf("writing payload: %w", err) + } + return nil +} + +// readFrame reads one length-prefixed msgpack frame from the reader and decodes it. +func readFrame(r io.Reader) (IncomingFrame, error) { + // Read 4-byte big-endian length prefix. + prefix := make([]byte, 4) + if _, err := io.ReadFull(r, prefix); err != nil { + return IncomingFrame{}, fmt.Errorf("reading length prefix: %w", err) + } + payloadLen := binary.BigEndian.Uint32(prefix) + + // Read the payload. + payload := make([]byte, payloadLen) + if _, err := io.ReadFull(r, payload); err != nil { + return IncomingFrame{}, fmt.Errorf("reading payload (%d bytes): %w", payloadLen, err) + } + + return decodeFrame(payload) +} + +// decodeFrame decodes a msgpack payload into an IncomingFrame. +func decodeFrame(data []byte) (IncomingFrame, error) { + dec := msgpack.NewDecoder(bytes.NewReader(data)) + + arrLen, err := dec.DecodeArrayLen() + if err != nil { + return IncomingFrame{}, fmt.Errorf("decoding array header: %w", err) + } + if arrLen < 2 { + return IncomingFrame{}, fmt.Errorf("unexpected frame arity %d, need at least 2", arrLen) + } + + id64, err := dec.DecodeInt64() + if err != nil { + return IncomingFrame{}, fmt.Errorf("decoding frame id: %w", err) + } + + // Decode the body element. + bodyRaw, err := dec.DecodeInterface() + if err != nil { + return IncomingFrame{}, fmt.Errorf("decoding body: %w", err) + } + body, _ := toStringMap(bodyRaw) + + // For response frames (3-element), decode the error element. + var errMap map[string]any + if arrLen >= 3 { + errRaw, err := dec.DecodeInterface() + if err != nil { + return IncomingFrame{}, fmt.Errorf("decoding error element: %w", err) + } + errMap, _ = toStringMap(errRaw) + } + + return IncomingFrame{ + ID: int(id64), + Body: body, + Err: errMap, + }, nil +} + +// toStringMap converts a decoded interface{} to map[string]any. +// Returns nil, false if the input is nil or not a map. +func toStringMap(v any) (map[string]any, bool) { + if v == nil { + return nil, false + } + switch m := v.(type) { + case map[string]any: + return m, true + case map[any]any: + result := make(map[string]any, len(m)) + for k, val := range m { + result[fmt.Sprint(k)] = val + } + return result, true + default: + return nil, false + } +} + +// mapString extracts a string value from a map. +func mapString(m map[string]any, key string) (string, error) { + v, ok := m[key] + if !ok { + return "", fmt.Errorf("missing key %q", key) + } + s, ok := v.(string) + if !ok { + return "", fmt.Errorf("key %q: expected string, got %T", key, v) + } + return s, nil +} + +// mapIntOr extracts an int value from a map, returning the default if missing. +func mapIntOr(m map[string]any, key string, def int) int { + v, ok := m[key] + if !ok { + return def + } + n, err := toInt(v) + if err != nil { + return def + } + return n +} + +// mapStringOr extracts a string value from a map, returning the default if missing. +func mapStringOr(m map[string]any, key string, def string) string { + v, ok := m[key] + if !ok { + return def + } + s, ok := v.(string) + if !ok { + return def + } + return s +} + +// mapMap extracts a nested map from a map. +func mapMap(m map[string]any, key string) map[string]any { + v, ok := m[key] + if !ok || v == nil { + return nil + } + sub, ok := toStringMap(v) + if !ok { + return nil + } + return sub +} + +// toInt converts various numeric types from msgpack decoding to int. +func toInt(v any) (int, error) { + switch n := v.(type) { + case int: + return n, nil + case int8: + return int(n), nil + case int16: + return int(n), nil + case int32: + return int(n), nil + case int64: + return int(n), nil + case uint: + return int(n), nil + case uint8: + return int(n), nil + case uint16: + return int(n), nil + case uint32: + return int(n), nil + case uint64: + return int(n), nil + case float32: + return int(n), nil + case float64: + return int(n), nil + default: + return 0, fmt.Errorf("expected numeric, got %T", v) + } +} diff --git a/go-sdk/pkg/execution/frames_test.go b/go-sdk/pkg/execution/frames_test.go new file mode 100644 index 0000000000000..e0800dd8191ac --- /dev/null +++ b/go-sdk/pkg/execution/frames_test.go @@ -0,0 +1,200 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package execution + +import ( + "bytes" + "encoding/binary" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/vmihailenco/msgpack/v5" +) + +func TestEncodeRequest(t *testing.T) { + body := map[string]any{ + "type": "GetVariable", + "key": "my_var", + } + + data, err := encodeRequest(42, body) + require.NoError(t, err) + + // Decode and verify structure. + dec := msgpack.NewDecoder(bytes.NewReader(data)) + arrLen, err := dec.DecodeArrayLen() + require.NoError(t, err) + assert.Equal(t, 2, arrLen, "request frame should be 2-element array") + + id, err := dec.DecodeInt64() + require.NoError(t, err) + assert.Equal(t, int64(42), id) + + var decodedBody map[string]any + err = dec.Decode(&decodedBody) + require.NoError(t, err) + assert.Equal(t, "GetVariable", decodedBody["type"]) + assert.Equal(t, "my_var", decodedBody["key"]) +} + +func TestWriteAndReadFrame(t *testing.T) { + body := map[string]any{ + "type": "GetConnection", + "conn_id": "my_db", + } + + payload, err := encodeRequest(7, body) + require.NoError(t, err) + + // Write to buffer with length prefix. + var buf bytes.Buffer + err = writeFrame(&buf, payload) + require.NoError(t, err) + + // Verify length prefix. + prefix := buf.Bytes()[:4] + expectedLen := uint32(len(payload)) + assert.Equal(t, expectedLen, binary.BigEndian.Uint32(prefix)) + + // Read back. + frame, err := readFrame(&buf) + require.NoError(t, err) + assert.Equal(t, 7, frame.ID) + assert.Equal(t, "GetConnection", frame.Body["type"]) + assert.Equal(t, "my_db", frame.Body["conn_id"]) + assert.Nil(t, frame.Err) +} + +func TestDecodeResponseFrame(t *testing.T) { + // Encode a 3-element response frame: [id, body, error] + var buf bytes.Buffer + enc := msgpack.NewEncoder(&buf) + enc.UseCompactInts(true) + + require.NoError(t, enc.EncodeArrayLen(3)) + require.NoError(t, enc.EncodeInt(5)) + require.NoError(t, enc.Encode(map[string]any{ + "type": "ConnectionResult", + "conn_id": "test_conn", + "host": "localhost", + })) + require.NoError(t, enc.Encode(nil)) // no error + + frame, err := decodeFrame(buf.Bytes()) + require.NoError(t, err) + assert.Equal(t, 5, frame.ID) + assert.Equal(t, "ConnectionResult", frame.Body["type"]) + assert.Equal(t, "localhost", frame.Body["host"]) + assert.Nil(t, frame.Err) +} + +func TestDecodeResponseFrameWithError(t *testing.T) { + var buf bytes.Buffer + enc := msgpack.NewEncoder(&buf) + enc.UseCompactInts(true) + + require.NoError(t, enc.EncodeArrayLen(3)) + require.NoError(t, enc.EncodeInt(3)) + require.NoError(t, enc.Encode(nil)) // nil body + require.NoError(t, enc.Encode(map[string]any{ + "type": "ErrorResponse", + "error": "not_found", + "detail": "Variable 'x' not found", + })) + + frame, err := decodeFrame(buf.Bytes()) + require.NoError(t, err) + assert.Equal(t, 3, frame.ID) + assert.Nil(t, frame.Body) + assert.NotNil(t, frame.Err) + assert.Equal(t, "not_found", frame.Err["error"]) +} + +func TestRoundTripMultipleFrames(t *testing.T) { + var buf bytes.Buffer + + // Write two frames. + bodies := []map[string]any{ + {"type": "GetVariable", "key": "v1"}, + {"type": "GetVariable", "key": "v2"}, + } + for i, body := range bodies { + payload, err := encodeRequest(i, body) + require.NoError(t, err) + require.NoError(t, writeFrame(&buf, payload)) + } + + // Read them back. + for i, expected := range bodies { + frame, err := readFrame(&buf) + require.NoError(t, err) + assert.Equal(t, i, frame.ID) + assert.Equal(t, expected["key"], frame.Body["key"]) + } +} + +func TestToStringMap(t *testing.T) { + tests := []struct { + name string + input any + want map[string]any + ok bool + }{ + {"nil", nil, nil, false}, + {"string map", map[string]any{"a": 1}, map[string]any{"a": 1}, true}, + {"any key map", map[any]any{"b": 2}, map[string]any{"b": 2}, true}, + {"not a map", "hello", nil, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, ok := toStringMap(tt.input) + assert.Equal(t, tt.ok, ok) + if tt.ok { + assert.Equal(t, tt.want, got) + } + }) + } +} + +func TestToInt(t *testing.T) { + tests := []struct { + input any + want int + }{ + {int8(42), 42}, + {int16(42), 42}, + {int32(42), 42}, + {int64(42), 42}, + {uint8(42), 42}, + {uint16(42), 42}, + {uint32(42), 42}, + {uint64(42), 42}, + {float32(42.0), 42}, + {float64(42.0), 42}, + {int(42), 42}, + } + for _, tt := range tests { + got, err := toInt(tt.input) + require.NoError(t, err) + assert.Equal(t, tt.want, got) + } + + _, err := toInt("not a number") + assert.Error(t, err) +} diff --git a/go-sdk/pkg/execution/integration_test.go b/go-sdk/pkg/execution/integration_test.go new file mode 100644 index 0000000000000..a7b18256a8c1e --- /dev/null +++ b/go-sdk/pkg/execution/integration_test.go @@ -0,0 +1,358 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package execution + +import ( + "bytes" + "errors" + "io" + "log/slog" + "net" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/apache/airflow/go-sdk/bundle/bundlev1" +) + +// --- Test task functions --- + +func failingTask() error { + return errors.New("task failed intentionally") +} + +func panicTask() error { + panic("something went wrong") +} + +func simpleTask() error { + return nil +} + +// buildBundle wires a bundlev1.Registry from a closure and returns it as a +// bundlev1.Bundle (the materialised registry). +func buildBundle(t *testing.T, register func(bundlev1.Registry)) bundlev1.Bundle { + t.Helper() + reg := bundlev1.New() + register(reg) + return reg +} + +// --- Tests --- + +func TestDagParsing(t *testing.T) { + bundle := buildBundle(t, func(r bundlev1.Registry) { + d := r.AddDag("test_dag") + d.AddTask(simpleTask) + }) + + req := &DagFileParseRequest{ + File: "/bundles/test/main.go", + BundlePath: "/bundles/test", + } + + result := ParseDags(bundle, req) + + assert.Equal(t, "DagFileParsingResult", result["type"]) + assert.Equal(t, "/bundles/test/main.go", result["fileloc"]) + + serializedDags, ok := result["serialized_dags"].([]any) + require.True(t, ok) + require.Len(t, serializedDags, 1) + + dagEntry := serializedDags[0].(map[string]any) + data := dagEntry["data"].(map[string]any) + assert.Equal(t, 3, data["__version"]) + + dagMap := data["dag"].(map[string]any) + assert.Equal(t, "test_dag", dagMap["dag_id"]) + + tt := dagMap["timetable"].(map[string]any) + assert.Equal(t, "airflow.timetables.simple.NullTimetable", tt["__type"]) + + tasks := dagMap["tasks"].([]any) + require.Len(t, tasks, 1) + taskMap := tasks[0].(map[string]any) + assert.Equal(t, "operator", taskMap["__type"]) + taskData := taskMap["__var"].(map[string]any) + assert.Equal(t, "simpleTask", taskData["task_id"]) + assert.Equal(t, "go", taskData["language"]) +} + +func TestDagParsingMultipleDagsPreservesOrder(t *testing.T) { + bundle := buildBundle(t, func(r bundlev1.Registry) { + r.AddDag("dag1").AddTask(simpleTask) + r.AddDag("dag2").AddTask(failingTask) + }) + + req := &DagFileParseRequest{File: "/bundle/main.go", BundlePath: "/bundle"} + result := ParseDags(bundle, req) + + serializedDags := result["serialized_dags"].([]any) + require.Len(t, serializedDags, 2) + + dag1Data := serializedDags[0].(map[string]any)["data"].(map[string]any)["dag"].(map[string]any) + assert.Equal(t, "dag1", dag1Data["dag_id"]) + + dag2Data := serializedDags[1].(map[string]any)["data"].(map[string]any)["dag"].(map[string]any) + assert.Equal(t, "dag2", dag2Data["dag_id"]) +} + +func TestTaskRunnerSuccess(t *testing.T) { + bundle := buildBundle(t, func(r bundlev1.Registry) { + r.AddDag("test_dag").AddTask(simpleTask) + }) + + details := &StartupDetails{ + TI: TaskInstanceInfo{ + ID: "550e8400-e29b-41d4-a716-446655440000", + DagID: "test_dag", + TaskID: "simpleTask", + RunID: "run1", + MapIndex: -1, + }, + BundleInfo: BundleInfoMsg{Name: "test", Version: "1.0"}, + } + + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + comm := NewCoordinatorComm(bytes.NewReader(nil), io.Discard, logger) + + result := RunTask(bundle, details, comm, logger) + assert.Equal(t, "SucceedTask", result["type"]) +} + +func TestTaskRunnerFailure(t *testing.T) { + bundle := buildBundle(t, func(r bundlev1.Registry) { + r.AddDag("test_dag").AddTask(failingTask) + }) + + details := &StartupDetails{ + TI: TaskInstanceInfo{ + ID: "550e8400-e29b-41d4-a716-446655440000", + DagID: "test_dag", + TaskID: "failingTask", + RunID: "run1", + MapIndex: -1, + }, + BundleInfo: BundleInfoMsg{Name: "test", Version: "1.0"}, + } + + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + comm := NewCoordinatorComm(bytes.NewReader(nil), io.Discard, logger) + + result := RunTask(bundle, details, comm, logger) + assert.Equal(t, "TaskState", result["type"]) + assert.Equal(t, "failed", result["state"]) +} + +func TestTaskRunnerTaskNotFound(t *testing.T) { + bundle := buildBundle(t, func(r bundlev1.Registry) { + r.AddDag("test_dag").AddTask(simpleTask) + }) + + details := &StartupDetails{ + TI: TaskInstanceInfo{ + ID: "550e8400-e29b-41d4-a716-446655440000", + DagID: "test_dag", + TaskID: "nonexistent", + RunID: "run1", + }, + BundleInfo: BundleInfoMsg{Name: "test", Version: "1.0"}, + } + + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + comm := NewCoordinatorComm(bytes.NewReader(nil), io.Discard, logger) + + result := RunTask(bundle, details, comm, logger) + assert.Equal(t, "TaskState", result["type"]) + assert.Equal(t, "removed", result["state"]) +} + +func TestTaskRunnerPanic(t *testing.T) { + bundle := buildBundle(t, func(r bundlev1.Registry) { + r.AddDag("test_dag").AddTask(panicTask) + }) + + details := &StartupDetails{ + TI: TaskInstanceInfo{ + ID: "550e8400-e29b-41d4-a716-446655440000", + DagID: "test_dag", + TaskID: "panicTask", + RunID: "run1", + MapIndex: -1, + }, + BundleInfo: BundleInfoMsg{Name: "test", Version: "1.0"}, + } + + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + comm := NewCoordinatorComm(bytes.NewReader(nil), io.Discard, logger) + + result := RunTask(bundle, details, comm, logger) + assert.Equal(t, "TaskState", result["type"]) + assert.Equal(t, "failed", result["state"]) +} + +// --- End-to-end Serve test against a fake supervisor --- + +// fakeProvider implements bundlev1.BundleProvider; it lets a test inject the +// registration closure and a synthetic version. +type fakeProvider struct { + register func(bundlev1.Registry) error +} + +func (f *fakeProvider) GetBundleVersion() bundlev1.BundleInfo { + v := "1.0" + return bundlev1.BundleInfo{Name: "fake", Version: &v} +} + +func (f *fakeProvider) RegisterDags(reg bundlev1.Registry) error { + if f.register == nil { + return nil + } + return f.register(reg) +} + +func startSupervisor( + t *testing.T, +) (commAddr, logsAddr string, commCh, logsCh chan net.Conn, cleanup func()) { + t.Helper() + commLn, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + logsLn, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + commCh = make(chan net.Conn, 1) + logsCh = make(chan net.Conn, 1) + go func() { + c, err := commLn.Accept() + if err == nil { + commCh <- c + } + close(commCh) + }() + go func() { + c, err := logsLn.Accept() + if err == nil { + logsCh <- c + } + close(logsCh) + }() + cleanup = func() { + commLn.Close() + logsLn.Close() + } + return commLn.Addr().String(), logsLn.Addr().String(), commCh, logsCh, cleanup +} + +func TestServeDagFileParseEndToEnd(t *testing.T) { + commAddr, logsAddr, commCh, logsCh, cleanup := startSupervisor(t) + defer cleanup() + + provider := &fakeProvider{ + register: func(r bundlev1.Registry) error { + d := r.AddDag("simple_dag") + d.AddTask(simpleTask) + return nil + }, + } + + done := make(chan error, 1) + go func() { done <- Serve(provider, commAddr, logsAddr) }() + + commConn := <-commCh + require.NotNil(t, commConn) + defer commConn.Close() + logsConn := <-logsCh + require.NotNil(t, logsConn) + defer logsConn.Close() + + // Send DagFileParseRequest as a request frame. + payload, err := encodeRequest(0, map[string]any{ + "type": "DagFileParseRequest", + "file": "/bundle/main.go", + "bundle_path": "/bundle", + }) + require.NoError(t, err) + require.NoError(t, writeFrame(commConn, payload)) + + frame, err := readFrame(commConn) + require.NoError(t, err) + assert.Equal(t, 0, frame.ID) + require.Nil(t, frame.Err) + assert.Equal(t, "DagFileParsingResult", frame.Body["type"]) + + dags := frame.Body["serialized_dags"].([]any) + require.Len(t, dags, 1) + dag := dags[0].(map[string]any)["data"].(map[string]any)["dag"].(map[string]any) + assert.Equal(t, "simple_dag", dag["dag_id"]) + + select { + case err := <-done: + require.NoError(t, err) + case <-time.After(2 * time.Second): + t.Fatal("Serve did not return after parse result") + } +} + +func TestServeStartupDetailsEndToEnd(t *testing.T) { + commAddr, logsAddr, commCh, logsCh, cleanup := startSupervisor(t) + defer cleanup() + + provider := &fakeProvider{ + register: func(r bundlev1.Registry) error { + r.AddDag("dag1").AddTask(simpleTask) + return nil + }, + } + + done := make(chan error, 1) + go func() { done <- Serve(provider, commAddr, logsAddr) }() + + commConn := <-commCh + defer commConn.Close() + logsConn := <-logsCh + defer logsConn.Close() + + payload, err := encodeRequest(0, map[string]any{ + "type": "StartupDetails", + "ti": map[string]any{ + "id": "550e8400-e29b-41d4-a716-446655440000", + "dag_id": "dag1", + "task_id": "simpleTask", + "run_id": "run1", + "try_number": 1, + }, + "bundle_info": map[string]any{"name": "fake", "version": "1.0"}, + }) + require.NoError(t, err) + require.NoError(t, writeFrame(commConn, payload)) + + frame, err := readFrame(commConn) + require.NoError(t, err) + require.Nil(t, frame.Err) + assert.Equal(t, "SucceedTask", frame.Body["type"]) + + select { + case err := <-done: + require.NoError(t, err) + case <-time.After(2 * time.Second): + t.Fatal("Serve did not return after task completion") + } +} diff --git a/go-sdk/pkg/execution/logger.go b/go-sdk/pkg/execution/logger.go new file mode 100644 index 0000000000000..172cbba161f55 --- /dev/null +++ b/go-sdk/pkg/execution/logger.go @@ -0,0 +1,158 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package execution + +import ( + "context" + "encoding/json" + "io" + "log/slog" + "strings" + "sync" + "time" +) + +// SocketLogHandler is an slog.Handler that streams structured JSON log lines +// to the logs TCP socket. Each log entry is a single JSON object followed by +// a newline, matching the Airflow log streaming format. +// +// Key mapping: +// - "event" for the log message (not "msg") +// - "level" in lowercase (not "INFO"/"ERROR") +// - "timestamp" in RFC3339Nano format (not "time") +// - Additional attributes are included as top-level fields +type SocketLogHandler struct { + shared *socketLogHandlerShared + level slog.Level + attrs []slog.Attr + groups []string +} + +// socketLogHandlerShared holds the writer and buffer that must remain shared +// across WithAttrs / WithGroup clones; otherwise the sync.Mutex would be +// copied (which the runtime detector flags as a bug). +type socketLogHandlerShared struct { + mu sync.Mutex + writer io.Writer + buf [][]byte + connected bool +} + +var _ slog.Handler = (*SocketLogHandler)(nil) + +// NewSocketLogHandler creates a new handler. If writer is nil, messages are +// buffered until Connect() is called. +func NewSocketLogHandler(writer io.Writer, level slog.Level) *SocketLogHandler { + shared := &socketLogHandlerShared{} + if writer != nil { + shared.writer = writer + shared.connected = true + } + return &SocketLogHandler{ + shared: shared, + level: level, + } +} + +// Connect sets the writer and flushes any buffered log messages. +func (h *SocketLogHandler) Connect(w io.Writer) { + h.shared.mu.Lock() + defer h.shared.mu.Unlock() + + h.shared.writer = w + h.shared.connected = true + + for _, line := range h.shared.buf { + _, _ = w.Write(line) + } + h.shared.buf = nil +} + +func (h *SocketLogHandler) Enabled(_ context.Context, level slog.Level) bool { + return level >= h.level +} + +func (h *SocketLogHandler) Handle(_ context.Context, r slog.Record) error { + entry := make(map[string]any) + + // Set standard fields. + entry["event"] = r.Message + entry["level"] = strings.ToLower(r.Level.String()) + if !r.Time.IsZero() { + entry["timestamp"] = r.Time.Format(time.RFC3339Nano) + } + + // Apply pre-configured attrs. + for _, a := range h.attrs { + key := h.prefixedKey(a.Key) + entry[key] = a.Value.Any() + } + + // Apply record attrs. + r.Attrs(func(a slog.Attr) bool { + key := h.prefixedKey(a.Key) + entry[key] = a.Value.Any() + return true + }) + + line, err := json.Marshal(entry) + if err != nil { + return err + } + line = append(line, '\n') + + h.shared.mu.Lock() + defer h.shared.mu.Unlock() + + if !h.shared.connected { + h.shared.buf = append(h.shared.buf, line) + return nil + } + + _, err = h.shared.writer.Write(line) + return err +} + +func (h *SocketLogHandler) WithAttrs(attrs []slog.Attr) slog.Handler { + return &SocketLogHandler{ + shared: h.shared, + level: h.level, + attrs: append(append([]slog.Attr{}, h.attrs...), attrs...), + groups: h.groups, + } +} + +func (h *SocketLogHandler) WithGroup(name string) slog.Handler { + if name == "" { + return h + } + return &SocketLogHandler{ + shared: h.shared, + level: h.level, + attrs: h.attrs, + groups: append(append([]string{}, h.groups...), name), + } +} + +// prefixedKey prepends any active group names to the attribute key. +func (h *SocketLogHandler) prefixedKey(key string) string { + if len(h.groups) == 0 { + return key + } + return strings.Join(h.groups, ".") + "." + key +} diff --git a/go-sdk/pkg/execution/logger_test.go b/go-sdk/pkg/execution/logger_test.go new file mode 100644 index 0000000000000..cf925ad13744a --- /dev/null +++ b/go-sdk/pkg/execution/logger_test.go @@ -0,0 +1,134 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package execution + +import ( + "bytes" + "encoding/json" + "log/slog" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSocketLogHandlerBasicOutput(t *testing.T) { + var buf bytes.Buffer + handler := NewSocketLogHandler(&buf, slog.LevelDebug) + logger := slog.New(handler) + + logger.Info("test message", "key1", "val1") + + // Parse the output. + output := buf.String() + assert.True(t, strings.HasSuffix(output, "\n"), "output should end with newline") + + var entry map[string]any + require.NoError(t, json.Unmarshal([]byte(strings.TrimSpace(output)), &entry)) + + assert.Equal(t, "test message", entry["event"]) + assert.Equal(t, "info", entry["level"]) + assert.Equal(t, "val1", entry["key1"]) + assert.Contains(t, entry, "timestamp") +} + +func TestSocketLogHandlerLevelFiltering(t *testing.T) { + var buf bytes.Buffer + handler := NewSocketLogHandler(&buf, slog.LevelWarn) + logger := slog.New(handler) + + logger.Debug("should be filtered") + logger.Info("also filtered") + logger.Warn("should appear") + + lines := strings.Split(strings.TrimSpace(buf.String()), "\n") + assert.Len(t, lines, 1) + + var entry map[string]any + require.NoError(t, json.Unmarshal([]byte(lines[0]), &entry)) + assert.Equal(t, "should appear", entry["event"]) + assert.Equal(t, "warn", entry["level"]) +} + +func TestSocketLogHandlerBuffering(t *testing.T) { + // Create handler without a writer — messages should be buffered. + handler := NewSocketLogHandler(nil, slog.LevelDebug) + logger := slog.New(handler) + + logger.Info("buffered message 1") + logger.Info("buffered message 2") + + // Connect the writer — buffered messages should flush. + var buf bytes.Buffer + handler.Connect(&buf) + + output := buf.String() + lines := strings.Split(strings.TrimSpace(output), "\n") + assert.Len(t, lines, 2) + + // New messages should write directly. + logger.Info("direct message") + lines = strings.Split(strings.TrimSpace(buf.String()), "\n") + assert.Len(t, lines, 3) +} + +func TestSocketLogHandlerWithAttrs(t *testing.T) { + var buf bytes.Buffer + handler := NewSocketLogHandler(&buf, slog.LevelDebug) + logger := slog.New(handler).With("component", "test") + + logger.Info("with attrs") + + var entry map[string]any + require.NoError(t, json.Unmarshal([]byte(strings.TrimSpace(buf.String())), &entry)) + assert.Equal(t, "test", entry["component"]) +} + +func TestSocketLogHandlerWithGroup(t *testing.T) { + var buf bytes.Buffer + handler := NewSocketLogHandler(&buf, slog.LevelDebug) + logger := slog.New(handler).WithGroup("grp") + + logger.Info("grouped", "key", "val") + + var entry map[string]any + require.NoError(t, json.Unmarshal([]byte(strings.TrimSpace(buf.String())), &entry)) + assert.Equal(t, "val", entry["grp.key"]) +} + +func TestSocketLogHandlerKeyMapping(t *testing.T) { + var buf bytes.Buffer + handler := NewSocketLogHandler(&buf, slog.LevelDebug) + logger := slog.New(handler) + + logger.Error("an error occurred") + + var entry map[string]any + require.NoError(t, json.Unmarshal([]byte(strings.TrimSpace(buf.String())), &entry)) + + // Check key mapping: "event" not "msg", "level" lowercase, "timestamp" not "time" + assert.Equal(t, "an error occurred", entry["event"]) + assert.Equal(t, "error", entry["level"]) + _, hasTimestamp := entry["timestamp"] + assert.True(t, hasTimestamp) + _, hasMsg := entry["msg"] + assert.False(t, hasMsg) + _, hasTime := entry["time"] + assert.False(t, hasTime) +} diff --git a/go-sdk/pkg/execution/messages.go b/go-sdk/pkg/execution/messages.go new file mode 100644 index 0000000000000..ae5f6a6c7045c --- /dev/null +++ b/go-sdk/pkg/execution/messages.go @@ -0,0 +1,394 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package execution + +import ( + "fmt" + "time" +) + +// --- Inbound messages (Supervisor -> Runtime) --- + +// DagFileParseRequest is sent by the supervisor to request DAG parsing. +type DagFileParseRequest struct { + File string + BundlePath string +} + +func decodeDagFileParseRequest(m map[string]any) (*DagFileParseRequest, error) { + file, err := mapString(m, "file") + if err != nil { + return nil, err + } + bundlePath := mapStringOr(m, "bundle_path", "") + return &DagFileParseRequest{File: file, BundlePath: bundlePath}, nil +} + +// TaskInstanceInfo holds task instance details from StartupDetails. +type TaskInstanceInfo struct { + ID string + TaskID string + DagID string + RunID string + TryNumber int + DagVersionID string + MapIndex int + ContextCarrier map[string]any +} + +func decodeTaskInstanceInfo(m map[string]any) (TaskInstanceInfo, error) { + if m == nil { + return TaskInstanceInfo{}, fmt.Errorf("nil task instance map") + } + id, err := mapString(m, "id") + if err != nil { + return TaskInstanceInfo{}, fmt.Errorf("ti.id: %w", err) + } + taskID, err := mapString(m, "task_id") + if err != nil { + return TaskInstanceInfo{}, fmt.Errorf("ti.task_id: %w", err) + } + dagID, err := mapString(m, "dag_id") + if err != nil { + return TaskInstanceInfo{}, fmt.Errorf("ti.dag_id: %w", err) + } + runID, err := mapString(m, "run_id") + if err != nil { + return TaskInstanceInfo{}, fmt.Errorf("ti.run_id: %w", err) + } + tryNumber := mapIntOr(m, "try_number", 1) + dagVersionID := mapStringOr(m, "dag_version_id", "") + mapIndex := mapIntOr(m, "map_index", -1) + contextCarrier := mapMap(m, "context_carrier") + + return TaskInstanceInfo{ + ID: id, + TaskID: taskID, + DagID: dagID, + RunID: runID, + TryNumber: tryNumber, + DagVersionID: dagVersionID, + MapIndex: mapIndex, + ContextCarrier: contextCarrier, + }, nil +} + +// BundleInfoMsg holds bundle identification from StartupDetails. +type BundleInfoMsg struct { + Name string + Version string +} + +func decodeBundleInfo(m map[string]any) BundleInfoMsg { + if m == nil { + return BundleInfoMsg{} + } + return BundleInfoMsg{ + Name: mapStringOr(m, "name", ""), + Version: mapStringOr(m, "version", ""), + } +} + +// TIRunContext holds the runtime context for a task instance. +type TIRunContext struct { + LogicalDate *time.Time + DataIntervalStart *time.Time + DataIntervalEnd *time.Time +} + +func decodeTIRunContext(m map[string]any) TIRunContext { + if m == nil { + return TIRunContext{} + } + ctx := TIRunContext{} + if t, err := asTime(m["logical_date"]); err == nil { + ctx.LogicalDate = &t + } + if t, err := asTime(m["data_interval_start"]); err == nil { + ctx.DataIntervalStart = &t + } + if t, err := asTime(m["data_interval_end"]); err == nil { + ctx.DataIntervalEnd = &t + } + return ctx +} + +// StartupDetails is sent by the supervisor to initiate task execution. +type StartupDetails struct { + TI TaskInstanceInfo + DagRelPath string + BundleInfo BundleInfoMsg + StartDate time.Time + TIContext TIRunContext + SentryIntegration string +} + +func decodeStartupDetails(m map[string]any) (*StartupDetails, error) { + tiMap := mapMap(m, "ti") + ti, err := decodeTaskInstanceInfo(tiMap) + if err != nil { + return nil, fmt.Errorf("decoding ti: %w", err) + } + + dagRelPath := mapStringOr(m, "dag_rel_path", "") + bundleInfo := decodeBundleInfo(mapMap(m, "bundle_info")) + + var startDate time.Time + if t, err := asTime(m["start_date"]); err == nil { + startDate = t + } + + tiContext := decodeTIRunContext(mapMap(m, "ti_context")) + sentryIntegration := mapStringOr(m, "sentry_integration", "") + + return &StartupDetails{ + TI: ti, + DagRelPath: dagRelPath, + BundleInfo: bundleInfo, + StartDate: startDate, + TIContext: tiContext, + SentryIntegration: sentryIntegration, + }, nil +} + +// --- Response types (for runtime-initiated requests) --- + +// ConnectionResult is the response to GetConnection. +type ConnectionResult struct { + ConnID string + ConnType string + Host string + Schema string + Login string + Password string + Port int + Extra string +} + +func decodeConnectionResult(m map[string]any) (*ConnectionResult, error) { + return &ConnectionResult{ + ConnID: mapStringOr(m, "conn_id", ""), + ConnType: mapStringOr(m, "conn_type", ""), + Host: mapStringOr(m, "host", ""), + Schema: mapStringOr(m, "schema", ""), + Login: mapStringOr(m, "login", ""), + Password: mapStringOr(m, "password", ""), + Port: mapIntOr(m, "port", 0), + Extra: mapStringOr(m, "extra", ""), + }, nil +} + +// VariableResult is the response to GetVariable. +type VariableResult struct { + Key string + Value any +} + +func decodeVariableResult(m map[string]any) (*VariableResult, error) { + return &VariableResult{ + Key: mapStringOr(m, "key", ""), + Value: m["value"], + }, nil +} + +// XComResult is the response to GetXCom. +type XComResult struct { + Key string + Value any +} + +func decodeXComResult(m map[string]any) (*XComResult, error) { + return &XComResult{ + Key: mapStringOr(m, "key", ""), + Value: m["value"], + }, nil +} + +// ErrorResponse represents an error returned by the supervisor. +type ErrorResponse struct { + Error string + Detail any +} + +func decodeErrorResponse(m map[string]any) *ErrorResponse { + if m == nil { + return nil + } + return &ErrorResponse{ + Error: mapStringOr(m, "error", ""), + Detail: m["detail"], + } +} + +// --- Outbound messages (Runtime -> Supervisor) --- + +// GetConnectionMsg is sent to request a connection from the supervisor. +type GetConnectionMsg struct { + ConnID string +} + +func (m GetConnectionMsg) toMap() map[string]any { + return map[string]any{ + "type": "GetConnection", + "conn_id": m.ConnID, + } +} + +// GetVariableMsg is sent to request a variable from the supervisor. +type GetVariableMsg struct { + Key string +} + +func (m GetVariableMsg) toMap() map[string]any { + return map[string]any{ + "type": "GetVariable", + "key": m.Key, + } +} + +// GetXComMsg is sent to request an XCom value from the supervisor. +type GetXComMsg struct { + Key string + DagID string + TaskID string + RunID string + MapIndex *int + IncludePriorDates bool +} + +func (m GetXComMsg) toMap() map[string]any { + result := map[string]any{ + "type": "GetXCom", + "key": m.Key, + "dag_id": m.DagID, + "task_id": m.TaskID, + "run_id": m.RunID, + "include_prior_dates": m.IncludePriorDates, + } + if m.MapIndex != nil { + result["map_index"] = *m.MapIndex + } + return result +} + +// SetXComMsg is sent to set an XCom value. +type SetXComMsg struct { + Key string + Value any + DagID string + TaskID string + RunID string + MapIndex int + MappedLength *int +} + +func (m SetXComMsg) toMap() map[string]any { + result := map[string]any{ + "type": "SetXCom", + "key": m.Key, + "value": m.Value, + "dag_id": m.DagID, + "task_id": m.TaskID, + "run_id": m.RunID, + "map_index": m.MapIndex, + } + if m.MappedLength != nil { + result["mapped_length"] = *m.MappedLength + } + return result +} + +// SucceedTaskMsg is sent as a terminal message when a task succeeds. +type SucceedTaskMsg struct { + EndDate time.Time + TaskOutlets []any + OutletEvents []any +} + +func (m SucceedTaskMsg) toMap() map[string]any { + taskOutlets := m.TaskOutlets + if taskOutlets == nil { + taskOutlets = []any{} + } + outletEvents := m.OutletEvents + if outletEvents == nil { + outletEvents = []any{} + } + return map[string]any{ + "type": "SucceedTask", + "end_date": m.EndDate.UTC().Format(time.RFC3339), + "task_outlets": taskOutlets, + "outlet_events": outletEvents, + } +} + +// TaskStateMsg is sent as a terminal message for failed/removed/skipped tasks. +type TaskStateMsg struct { + State string // "failed", "removed", "skipped" + EndDate time.Time +} + +func (m TaskStateMsg) toMap() map[string]any { + return map[string]any{ + "type": "TaskState", + "state": m.State, + "end_date": m.EndDate.UTC().Format(time.RFC3339), + } +} + +// --- Message dispatch --- + +// decodeIncomingBody dispatches decoding of a body map based on its "type" field. +func decodeIncomingBody(m map[string]any) (any, error) { + if m == nil { + return nil, nil + } + typ, _ := m["type"].(string) + switch typ { + case "DagFileParseRequest": + return decodeDagFileParseRequest(m) + case "StartupDetails": + return decodeStartupDetails(m) + case "ConnectionResult": + return decodeConnectionResult(m) + case "VariableResult": + return decodeVariableResult(m) + case "XComResult": + return decodeXComResult(m) + case "ErrorResponse": + return decodeErrorResponse(m), nil + default: + return nil, fmt.Errorf("unknown message type %q", typ) + } +} + +// asTime parses a time value that may be a time.Time (from msgpack timestamp ext) +// or a string (ISO 8601 format). +func asTime(v any) (time.Time, error) { + if v == nil { + return time.Time{}, fmt.Errorf("nil time value") + } + switch t := v.(type) { + case time.Time: + return t, nil + case string: + return time.Parse(time.RFC3339Nano, t) + default: + return time.Time{}, fmt.Errorf("expected time, got %T", v) + } +} diff --git a/go-sdk/pkg/execution/messages_test.go b/go-sdk/pkg/execution/messages_test.go new file mode 100644 index 0000000000000..78270dc2f4fd1 --- /dev/null +++ b/go-sdk/pkg/execution/messages_test.go @@ -0,0 +1,258 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package execution + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDecodeDagFileParseRequest(t *testing.T) { + m := map[string]any{ + "type": "DagFileParseRequest", + "file": "/path/to/dags.go", + "bundle_path": "/bundles/my_bundle", + } + + req, err := decodeDagFileParseRequest(m) + require.NoError(t, err) + assert.Equal(t, "/path/to/dags.go", req.File) + assert.Equal(t, "/bundles/my_bundle", req.BundlePath) +} + +func TestDecodeStartupDetails(t *testing.T) { + m := map[string]any{ + "type": "StartupDetails", + "ti": map[string]any{ + "id": "550e8400-e29b-41d4-a716-446655440000", + "task_id": "extract", + "dag_id": "tutorial_dag", + "run_id": "manual__2024-01-15", + "try_number": int64(1), + "dag_version_id": "abc-123", + "map_index": int64(-1), + }, + "dag_rel_path": "dags/tutorial.go", + "bundle_info": map[string]any{ + "name": "example_dags", + "version": "1.0.0", + }, + "start_date": "2024-01-15T10:30:00Z", + "sentry_integration": "", + "ti_context": map[string]any{ + "logical_date": "2024-01-15T00:00:00Z", + "data_interval_start": "2024-01-14T00:00:00Z", + "data_interval_end": "2024-01-15T00:00:00Z", + }, + } + + details, err := decodeStartupDetails(m) + require.NoError(t, err) + + assert.Equal(t, "550e8400-e29b-41d4-a716-446655440000", details.TI.ID) + assert.Equal(t, "extract", details.TI.TaskID) + assert.Equal(t, "tutorial_dag", details.TI.DagID) + assert.Equal(t, "manual__2024-01-15", details.TI.RunID) + assert.Equal(t, 1, details.TI.TryNumber) + assert.Equal(t, -1, details.TI.MapIndex) + assert.Equal(t, "dags/tutorial.go", details.DagRelPath) + assert.Equal(t, "example_dags", details.BundleInfo.Name) + assert.Equal(t, "1.0.0", details.BundleInfo.Version) + assert.NotNil(t, details.TIContext.LogicalDate) +} + +func TestDecodeConnectionResult(t *testing.T) { + m := map[string]any{ + "type": "ConnectionResult", + "conn_id": "my_db", + "conn_type": "postgres", + "host": "db.example.com", + "schema": "mydb", + "login": "user", + "password": "secret", + "port": int64(5432), + "extra": `{"sslmode":"require"}`, + } + + result, err := decodeConnectionResult(m) + require.NoError(t, err) + assert.Equal(t, "my_db", result.ConnID) + assert.Equal(t, "postgres", result.ConnType) + assert.Equal(t, "db.example.com", result.Host) + assert.Equal(t, "mydb", result.Schema) + assert.Equal(t, "user", result.Login) + assert.Equal(t, "secret", result.Password) + assert.Equal(t, 5432, result.Port) +} + +func TestDecodeVariableResult(t *testing.T) { + m := map[string]any{ + "type": "VariableResult", + "key": "my_var", + "value": "hello", + } + + result, err := decodeVariableResult(m) + require.NoError(t, err) + assert.Equal(t, "my_var", result.Key) + assert.Equal(t, "hello", result.Value) +} + +func TestDecodeXComResult(t *testing.T) { + m := map[string]any{ + "type": "XComResult", + "key": "return_value", + "value": map[string]any{"data": "processed"}, + } + + result, err := decodeXComResult(m) + require.NoError(t, err) + assert.Equal(t, "return_value", result.Key) + valMap, ok := result.Value.(map[string]any) + require.True(t, ok) + assert.Equal(t, "processed", valMap["data"]) +} + +func TestDecodeErrorResponseNil(t *testing.T) { + assert.Nil(t, decodeErrorResponse(nil)) +} + +func TestGetConnectionMsgToMap(t *testing.T) { + msg := GetConnectionMsg{ConnID: "my_db"} + m := msg.toMap() + assert.Equal(t, "GetConnection", m["type"]) + assert.Equal(t, "my_db", m["conn_id"]) +} + +func TestGetVariableMsgToMap(t *testing.T) { + msg := GetVariableMsg{Key: "my_var"} + m := msg.toMap() + assert.Equal(t, "GetVariable", m["type"]) + assert.Equal(t, "my_var", m["key"]) +} + +func TestGetXComMsgToMapWithMapIndex(t *testing.T) { + mapIdx := 3 + msg := GetXComMsg{ + Key: "result", + DagID: "dag1", + TaskID: "task1", + RunID: "run1", + MapIndex: &mapIdx, + IncludePriorDates: true, + } + m := msg.toMap() + assert.Equal(t, "GetXCom", m["type"]) + assert.Equal(t, 3, m["map_index"]) + assert.Equal(t, true, m["include_prior_dates"]) +} + +func TestGetXComMsgToMapNilMapIndex(t *testing.T) { + msg := GetXComMsg{Key: "result", DagID: "d", TaskID: "t", RunID: "r"} + m := msg.toMap() + _, hasMapIndex := m["map_index"] + assert.False(t, hasMapIndex) +} + +func TestSetXComMsgToMap(t *testing.T) { + msg := SetXComMsg{ + Key: "output", Value: 42, + DagID: "dag1", TaskID: "task1", RunID: "run1", MapIndex: -1, + } + m := msg.toMap() + assert.Equal(t, "SetXCom", m["type"]) + assert.Equal(t, 42, m["value"]) + _, hasMappedLength := m["mapped_length"] + assert.False(t, hasMappedLength) +} + +func TestSucceedTaskMsgToMap(t *testing.T) { + endDate := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC) + msg := SucceedTaskMsg{EndDate: endDate} + m := msg.toMap() + assert.Equal(t, "SucceedTask", m["type"]) + assert.Equal(t, "2024-01-15T10:30:00Z", m["end_date"]) + assert.Equal(t, []any{}, m["task_outlets"]) + assert.Equal(t, []any{}, m["outlet_events"]) +} + +func TestTaskStateMsgToMap(t *testing.T) { + endDate := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC) + msg := TaskStateMsg{State: "failed", EndDate: endDate} + m := msg.toMap() + assert.Equal(t, "TaskState", m["type"]) + assert.Equal(t, "failed", m["state"]) +} + +func TestDecodeIncomingBodyDispatch(t *testing.T) { + t.Run("DagFileParseRequest", func(t *testing.T) { + body := map[string]any{"type": "DagFileParseRequest", "file": "x", "bundle_path": "y"} + result, err := decodeIncomingBody(body) + require.NoError(t, err) + _, ok := result.(*DagFileParseRequest) + assert.True(t, ok) + }) + + t.Run("ConnectionResult", func(t *testing.T) { + body := map[string]any{"type": "ConnectionResult", "conn_id": "x"} + result, err := decodeIncomingBody(body) + require.NoError(t, err) + _, ok := result.(*ConnectionResult) + assert.True(t, ok) + }) + + t.Run("nil", func(t *testing.T) { + result, err := decodeIncomingBody(nil) + require.NoError(t, err) + assert.Nil(t, result) + }) + + t.Run("unknown type", func(t *testing.T) { + _, err := decodeIncomingBody(map[string]any{"type": "UnknownMsg"}) + assert.Error(t, err) + }) +} + +func TestAsTime(t *testing.T) { + t.Run("from string", func(t *testing.T) { + ts, err := asTime("2024-01-15T10:30:00Z") + require.NoError(t, err) + assert.Equal(t, 2024, ts.Year()) + assert.Equal(t, time.January, ts.Month()) + }) + + t.Run("from time.Time", func(t *testing.T) { + now := time.Now() + ts, err := asTime(now) + require.NoError(t, err) + assert.Equal(t, now, ts) + }) + + t.Run("nil", func(t *testing.T) { + _, err := asTime(nil) + assert.Error(t, err) + }) + + t.Run("wrong type", func(t *testing.T) { + _, err := asTime(42) + assert.Error(t, err) + }) +} diff --git a/go-sdk/pkg/execution/serde.go b/go-sdk/pkg/execution/serde.go new file mode 100644 index 0000000000000..02ec641c89dcb --- /dev/null +++ b/go-sdk/pkg/execution/serde.go @@ -0,0 +1,279 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package execution + +import ( + "fmt" + "path/filepath" + "reflect" + "sort" + "time" + + "github.com/apache/airflow/go-sdk/bundle/bundlev1" +) + +// serializeValue recursively serializes a value with Airflow's type/var encoding. +// This matches Python's BaseSerialization.serialize() output: +// - primitives (string, bool, int, float) pass through unchanged +// - time.Time -> {"__type": "datetime", "__var": epoch_seconds_float} +// - time.Duration -> {"__type": "timedelta", "__var": total_seconds_float} +// - map[string]any -> {"__type": "dict", "__var": {k: serialize(v), ...}} +// - []any -> direct array with each element serialized +func serializeValue(value any) any { + if value == nil { + return nil + } + switch v := value.(type) { + case string, bool: + return v + case int: + return v + case int8: + return int(v) + case int16: + return int(v) + case int32: + return int(v) + case int64: + return v + case float32: + return float64(v) + case float64: + return v + case time.Time: + epochSec := float64(v.Unix()) + float64(v.Nanosecond())/1e9 + return map[string]any{ + "__type": "datetime", + "__var": epochSec, + } + case time.Duration: + return map[string]any{ + "__type": "timedelta", + "__var": v.Seconds(), + } + case map[string]any: + serialized := make(map[string]any, len(v)) + for k, val := range v { + serialized[k] = serializeValue(val) + } + return map[string]any{ + "__type": "dict", + "__var": serialized, + } + case []string: + result := make([]any, len(v)) + for i, item := range v { + result[i] = serializeValue(item) + } + return result + case []any: + result := make([]any, len(v)) + for i, item := range v { + result[i] = serializeValue(item) + } + return result + default: + // Use reflection to handle typed maps and slices that don't match + // the concrete types above (e.g., map[string]map[string][]string). + rv := reflect.ValueOf(v) + switch rv.Kind() { + case reflect.Map: + serialized := make(map[string]any, rv.Len()) + for _, key := range rv.MapKeys() { + serialized[fmt.Sprint(key.Interface())] = serializeValue(rv.MapIndex(key).Interface()) + } + return map[string]any{ + "__type": "dict", + "__var": serialized, + } + case reflect.Slice, reflect.Array: + result := make([]any, rv.Len()) + for i := range result { + result[i] = serializeValue(rv.Index(i).Interface()) + } + return result + default: + return v + } + } +} + +// unwrapTypeEncoding extracts the "__var" part from a type-encoded value. +// In Python's serialize_to_json, non-decorated fields are serialized then unwrapped. +func unwrapTypeEncoding(value any) any { + m, ok := value.(map[string]any) + if !ok { + return value + } + if _, hasType := m["__type"]; !hasType { + return value + } + if v, hasVar := m["__var"]; hasVar { + return v + } + return value +} + +// serializeTimetable converts a schedule string to the Airflow timetable format. +func serializeTimetable(schedule *string) map[string]any { + if schedule == nil { + return map[string]any{ + "__type": "airflow.timetables.simple.NullTimetable", + "__var": map[string]any{}, + } + } + switch *schedule { + case "@once": + return map[string]any{ + "__type": "airflow.timetables.simple.OnceTimetable", + "__var": map[string]any{}, + } + case "@continuous": + return map[string]any{ + "__type": "airflow.timetables.simple.ContinuousTimetable", + "__var": map[string]any{}, + } + default: + return map[string]any{ + "__type": "airflow.timetables.trigger.CronTriggerTimetable", + "__var": map[string]any{ + "expression": *schedule, + "timezone": "UTC", + "interval": 0.0, + "run_immediately": false, + }, + } + } +} + +// serializeTask converts a task to the Airflow serialization format. +func serializeTask(taskID, typeName, pkgPath string, downstream []string) map[string]any { + if typeName == "" { + typeName = taskID + } + if pkgPath == "" { + pkgPath = "main" + } + data := map[string]any{ + "task_id": taskID, + "task_type": typeName, + "_task_module": pkgPath, + "language": "go", + } + if len(downstream) > 0 { + sorted := make([]string, len(downstream)) + copy(sorted, downstream) + sort.Strings(sorted) + data["downstream_task_ids"] = sorted + } + return map[string]any{ + "__type": "operator", + "__var": data, + } +} + +// serializeTaskGroup creates a flat root task group containing all task IDs. +func serializeTaskGroup(taskIDs []string) map[string]any { + children := make(map[string]any, len(taskIDs)) + for _, id := range taskIDs { + children[id] = []any{"operator", id} + } + return map[string]any{ + "_group_id": nil, + "group_display_name": "", + "prefix_group_id": true, + "tooltip": "", + "ui_color": "CornflowerBlue", + "ui_fgcolor": "#000", + "children": children, + "upstream_group_ids": []any{}, + "downstream_group_ids": []any{}, + "upstream_task_ids": []any{}, + "downstream_task_ids": []any{}, + } +} + +// serializeParams converts DAG params to Airflow's serialization format. +func serializeParams(params map[string]any) []any { + if len(params) == 0 { + return []any{} + } + result := make([]any, 0, len(params)) + for k, v := range params { + result = append(result, []any{ + k, + map[string]any{ + "__class": "airflow.sdk.definitions.param.Param", + "default": serializeValue(v), + "description": nil, + "schema": serializeValue(map[string]any{}), + "source": nil, + }, + }) + } + return result +} + +// SerializeDag converts a bundlev1.DagInfo to Airflow DagSerialization v3 +// format. The Go SDK's bundlev1.Dag interface does not (yet) carry per-DAG +// metadata like schedule, start_date, tags, etc., so the encoding emits +// schema defaults for those fields. The optional-field handling below is +// kept (gated on nil checks) so the encoder can grow naturally as the +// bundle metadata surface expands. +func SerializeDag(info bundlev1.DagInfo, fileloc, relativeFileloc string) map[string]any { + taskIDs := make([]string, len(info.Tasks)) + tasks := make([]any, len(info.Tasks)) + for i, t := range info.Tasks { + taskIDs[i] = t.ID + tasks[i] = serializeTask(t.ID, t.TypeName, t.PkgPath, nil) + } + + return map[string]any{ + // Required fields (always present) + "dag_id": info.DagID, + "fileloc": fileloc, + "relative_fileloc": relativeFileloc, + "timezone": "UTC", + "timetable": serializeTimetable(nil), + "tasks": tasks, + "dag_dependencies": []any{}, + "task_group": serializeTaskGroup(taskIDs), + "edge_info": map[string]any{}, + "params": serializeParams(nil), + "deadline": nil, + "allowed_run_types": nil, + } +} + +// computeRelativeFileloc computes the relative file location from the bundle path. +func computeRelativeFileloc(fileloc, bundlePath string) string { + if fileloc == "" { + return "" + } + if bundlePath == "" { + return "." + } + rel, err := filepath.Rel(bundlePath, fileloc) + if err != nil { + return "." + } + if rel == "" { + return "." + } + return rel +} diff --git a/go-sdk/pkg/execution/serde_test.go b/go-sdk/pkg/execution/serde_test.go new file mode 100644 index 0000000000000..afddaf5e700b1 --- /dev/null +++ b/go-sdk/pkg/execution/serde_test.go @@ -0,0 +1,226 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package execution + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/apache/airflow/go-sdk/bundle/bundlev1" +) + +func TestSerializeValuePrimitives(t *testing.T) { + assert.Nil(t, serializeValue(nil)) + assert.Equal(t, "hello", serializeValue("hello")) + assert.Equal(t, true, serializeValue(true)) + assert.Equal(t, 42, serializeValue(42)) + assert.Equal(t, float64(3.14), serializeValue(3.14)) +} + +func TestSerializeValueDatetime(t *testing.T) { + ts := time.Date(2024, 1, 15, 10, 30, 0, 500000000, time.UTC) + result := serializeValue(ts) + m, ok := result.(map[string]any) + require.True(t, ok) + assert.Equal(t, "datetime", m["__type"]) + epochSec := m["__var"].(float64) + expected := float64(ts.Unix()) + 0.5 + assert.InDelta(t, expected, epochSec, 0.001) +} + +func TestSerializeValueTimedelta(t *testing.T) { + dur := 90 * time.Second + result := serializeValue(dur) + m, ok := result.(map[string]any) + require.True(t, ok) + assert.Equal(t, "timedelta", m["__type"]) + assert.Equal(t, 90.0, m["__var"]) +} + +func TestSerializeValueMap(t *testing.T) { + input := map[string]any{ + "key1": "val1", + "key2": 42, + } + result := serializeValue(input) + m, ok := result.(map[string]any) + require.True(t, ok) + assert.Equal(t, "dict", m["__type"]) + inner := m["__var"].(map[string]any) + assert.Equal(t, "val1", inner["key1"]) + assert.Equal(t, 42, inner["key2"]) +} + +func TestSerializeValueSlice(t *testing.T) { + input := []any{"a", 1, true} + result := serializeValue(input) + arr, ok := result.([]any) + require.True(t, ok) + assert.Len(t, arr, 3) + assert.Equal(t, "a", arr[0]) +} + +func TestUnwrapTypeEncoding(t *testing.T) { + wrapped := map[string]any{ + "__type": "datetime", + "__var": 1705313400.5, + } + assert.Equal(t, 1705313400.5, unwrapTypeEncoding(wrapped)) + + assert.Equal(t, "hello", unwrapTypeEncoding("hello")) + assert.Equal(t, 42, unwrapTypeEncoding(42)) +} + +func TestSerializeTimetable(t *testing.T) { + t.Run("nil schedule", func(t *testing.T) { + result := serializeTimetable(nil) + assert.Equal(t, "airflow.timetables.simple.NullTimetable", result["__type"]) + }) + + t.Run("@once", func(t *testing.T) { + s := "@once" + result := serializeTimetable(&s) + assert.Equal(t, "airflow.timetables.simple.OnceTimetable", result["__type"]) + }) + + t.Run("@continuous", func(t *testing.T) { + s := "@continuous" + result := serializeTimetable(&s) + assert.Equal(t, "airflow.timetables.simple.ContinuousTimetable", result["__type"]) + }) + + t.Run("cron expression", func(t *testing.T) { + s := "0 12 * * *" + result := serializeTimetable(&s) + assert.Equal(t, "airflow.timetables.trigger.CronTriggerTimetable", result["__type"]) + v := result["__var"].(map[string]any) + assert.Equal(t, "0 12 * * *", v["expression"]) + assert.Equal(t, "UTC", v["timezone"]) + assert.Equal(t, 0.0, v["interval"]) + assert.Equal(t, false, v["run_immediately"]) + }) +} + +func TestSerializeTask(t *testing.T) { + result := serializeTask("extract", "extract", "main", []string{"transform"}) + assert.Equal(t, "operator", result["__type"]) + data := result["__var"].(map[string]any) + assert.Equal(t, "extract", data["task_id"]) + assert.Equal(t, "extract", data["task_type"]) + assert.Equal(t, "main", data["_task_module"]) + assert.Equal(t, "go", data["language"]) + assert.Equal(t, []string{"transform"}, data["downstream_task_ids"]) +} + +func TestSerializeTaskNoDownstream(t *testing.T) { + result := serializeTask("load", "load", "main", nil) + data := result["__var"].(map[string]any) + _, hasDownstream := data["downstream_task_ids"] + assert.False(t, hasDownstream) +} + +func TestSerializeTaskGroup(t *testing.T) { + result := serializeTaskGroup([]string{"t1", "t2"}) + assert.Nil(t, result["_group_id"]) + assert.Equal(t, true, result["prefix_group_id"]) + assert.Equal(t, "CornflowerBlue", result["ui_color"]) + + children := result["children"].(map[string]any) + assert.Equal(t, []any{"operator", "t1"}, children["t1"]) + assert.Equal(t, []any{"operator", "t2"}, children["t2"]) +} + +func TestSerializeParams(t *testing.T) { + t.Run("empty", func(t *testing.T) { + result := serializeParams(nil) + assert.Equal(t, []any{}, result) + }) + + t.Run("with values", func(t *testing.T) { + result := serializeParams(map[string]any{"key1": "default_val"}) + assert.Len(t, result, 1) + pair := result[0].([]any) + assert.Equal(t, "key1", pair[0]) + paramMap := pair[1].(map[string]any) + assert.Equal(t, "airflow.sdk.definitions.param.Param", paramMap["__class"]) + assert.Equal(t, "default_val", paramMap["default"]) + }) +} + +func TestSerializeDagMinimal(t *testing.T) { + info := bundlev1.DagInfo{DagID: "test_dag"} + result := SerializeDag(info, "/path/to/bundle", ".") + + assert.Equal(t, "test_dag", result["dag_id"]) + assert.Equal(t, "/path/to/bundle", result["fileloc"]) + assert.Equal(t, ".", result["relative_fileloc"]) + assert.Equal(t, "UTC", result["timezone"]) + + tt := result["timetable"].(map[string]any) + assert.Equal(t, "airflow.timetables.simple.NullTimetable", tt["__type"]) + + _, hasDesc := result["description"] + assert.False(t, hasDesc) + _, hasCatchup := result["catchup"] + assert.False(t, hasCatchup) +} + +func TestSerializeDagWithTasks(t *testing.T) { + info := bundlev1.DagInfo{ + DagID: "etl", + Tasks: []bundlev1.TaskInfo{ + {ID: "extract", TypeName: "extract", PkgPath: "main"}, + {ID: "load", TypeName: "load", PkgPath: "main"}, + }, + } + result := SerializeDag(info, "/bundle/main.go", "main.go") + + tasks := result["tasks"].([]any) + require.Len(t, tasks, 2) + first := tasks[0].(map[string]any) + v := first["__var"].(map[string]any) + assert.Equal(t, "extract", v["task_id"]) + assert.Equal(t, "extract", v["task_type"]) + assert.Equal(t, "main", v["_task_module"]) + assert.Equal(t, "go", v["language"]) + + tg := result["task_group"].(map[string]any) + children := tg["children"].(map[string]any) + assert.Contains(t, children, "extract") + assert.Contains(t, children, "load") +} + +func TestComputeRelativeFileloc(t *testing.T) { + tests := []struct { + fileloc string + bundlePath string + want string + }{ + {"", "", ""}, + {"/a/b/c.go", "", "."}, + {"/bundles/my/dags.go", "/bundles/my", "dags.go"}, + {"/bundles/my/sub/dags.go", "/bundles/my", "sub/dags.go"}, + } + for _, tt := range tests { + result := computeRelativeFileloc(tt.fileloc, tt.bundlePath) + assert.Equal(t, tt.want, result, "fileloc=%q bundlePath=%q", tt.fileloc, tt.bundlePath) + } +} diff --git a/go-sdk/pkg/execution/server.go b/go-sdk/pkg/execution/server.go new file mode 100644 index 0000000000000..bee35a028d9de --- /dev/null +++ b/go-sdk/pkg/execution/server.go @@ -0,0 +1,155 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Package execution implements the SDK coordinator-protocol runtime +// (msgpack-over-IPC). It is the second mode of bundlev1server.Serve: when +// the bundle binary is launched with --comm/--logs by the Airflow supervisor +// (Python ExecutableCoordinator), bundlev1server.Serve dispatches here. +// +// The first inbound frame on the comm socket selects between two +// sub-protocols: +// +// - DagFileParseRequest: one-shot, returns DagFileParsingResult and exits. +// - StartupDetails: multi-round task execution. +// +// See go-sdk/adr/0003-coordinator-protocol-msgpack-ipc.md. +package execution + +import ( + "fmt" + "log/slog" + "net" + "sync" + + "github.com/apache/airflow/go-sdk/bundle/bundlev1" +) + +// Serve runs the bundle binary in coordinator mode. It dials the supervisor's +// comm and logs sockets, installs an slog handler that writes JSON-line +// records to the logs connection, and dispatches on the first frame. +// +// Serve returns nil on a clean shutdown (one-shot DAG parse or task execution +// completed); a non-nil error indicates a protocol-level failure (connection +// loss, malformed frames, unknown first message type). +func Serve(provider bundlev1.BundleProvider, commAddr, logsAddr string) error { + if commAddr == "" { + return fmt.Errorf("missing --comm=host:port argument") + } + if logsAddr == "" { + return fmt.Errorf("missing --logs=host:port argument") + } + + // Buffer log records until the logs socket is connected. Anything the + // runtime emits between Connect-time and the first frame still gets + // flushed. + logHandler := NewSocketLogHandler(nil, slog.LevelDebug) + logger := slog.New(logHandler) + slog.SetDefault(logger) + + // Connect to both sockets concurrently so the supervisor can accept them + // in either order. + var commConn, logsConn net.Conn + var commErr, logsErr error + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + commConn, commErr = net.Dial("tcp", commAddr) + }() + go func() { + defer wg.Done() + logsConn, logsErr = net.Dial("tcp", logsAddr) + }() + wg.Wait() + + if commErr != nil { + return fmt.Errorf("connecting to comm socket %s: %w", commAddr, commErr) + } + defer commConn.Close() + if logsErr != nil { + return fmt.Errorf("connecting to logs socket %s: %w", logsAddr, logsErr) + } + defer logsConn.Close() + + logHandler.Connect(logsConn) + logger.Debug("Connected", "comm", commAddr, "logs", logsAddr) + + // Materialise the bundle (RegisterDags) up front. Both protocol paths + // need the registry, and doing it once before the first frame keeps the + // dispatcher simple. + bundle, err := materialiseBundle(provider) + if err != nil { + return fmt.Errorf("registering dags: %w", err) + } + + comm := NewCoordinatorComm(commConn, commConn, logger) + + frame, err := comm.ReadMessage() + if err != nil { + return fmt.Errorf("reading initial message: %w", err) + } + + if frame.Err != nil { + errResp := decodeErrorResponse(frame.Err) + if errResp != nil { + return fmt.Errorf( + "received error from supervisor: [%s] %v", + errResp.Error, + errResp.Detail, + ) + } + } + + body, err := decodeIncomingBody(frame.Body) + if err != nil { + return fmt.Errorf("decoding initial message: %w", err) + } + + switch msg := body.(type) { + case *DagFileParseRequest: + logger.Debug("DAG parsing mode", "file", msg.File) + result := ParseDags(bundle, msg) + if err := comm.SendRequest(frame.ID, result); err != nil { + return fmt.Errorf("sending parse result: %w", err) + } + logger.Debug("DAG parsing complete") + + case *StartupDetails: + logger.Debug("Task execution mode", + "dag_id", msg.TI.DagID, + "task_id", msg.TI.TaskID, + ) + result := RunTask(bundle, msg, comm, logger) + if err := comm.SendRequest(frame.ID, result); err != nil { + return fmt.Errorf("sending task result: %w", err) + } + logger.Debug("Task execution complete") + + default: + return fmt.Errorf("unexpected initial message type: %T", body) + } + + return nil +} + +func materialiseBundle(provider bundlev1.BundleProvider) (bundlev1.Bundle, error) { + reg := bundlev1.New() + if err := provider.RegisterDags(reg); err != nil { + return nil, err + } + return reg, nil +} diff --git a/go-sdk/pkg/execution/task_runner.go b/go-sdk/pkg/execution/task_runner.go new file mode 100644 index 0000000000000..cae7613aa67d9 --- /dev/null +++ b/go-sdk/pkg/execution/task_runner.go @@ -0,0 +1,117 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package execution + +import ( + "context" + "fmt" + "log/slog" + "runtime/debug" + "time" + + "github.com/google/uuid" + + "github.com/apache/airflow/go-sdk/bundle/bundlev1" + "github.com/apache/airflow/go-sdk/pkg/api" + "github.com/apache/airflow/go-sdk/pkg/sdkcontext" + "github.com/apache/airflow/go-sdk/sdk" +) + +// RunTask executes a task based on StartupDetails received from the supervisor. +// +// It looks up the task in the bundle, creates a CoordinatorClient for SDK +// calls, executes the task, and returns a terminal message body +// (SucceedTaskMsg or TaskStateMsg) ready to ship as the final response frame. +// +// The supervisor owns the Execution-API state transitions in coordinator +// mode, so we deliberately bypass worker.ExecuteTaskWorkload (which drives +// Run / UpdateState itself) and only invoke the user's task function. +func RunTask( + bundle bundlev1.Bundle, + details *StartupDetails, + comm *CoordinatorComm, + logger *slog.Logger, +) map[string]any { + task, exists := bundle.LookupTask(details.TI.DagID, details.TI.TaskID) + if !exists { + logger.Error("Task not registered", + "dag_id", details.TI.DagID, + "task_id", details.TI.TaskID, + ) + return TaskStateMsg{State: "removed", EndDate: time.Now().UTC()}.toMap() + } + + client := NewCoordinatorClient(comm, details) + + // taskFunction.sendXcom reads the workload from context to get the task + // instance ids; populate it the same shape the gRPC path uses. + tiUUID, _ := uuid.Parse(details.TI.ID) + mapIndex := details.TI.MapIndex + workload := api.ExecuteTaskWorkload{ + TI: api.TaskInstance{ + Id: tiUUID, + DagId: details.TI.DagID, + RunId: details.TI.RunID, + TaskId: details.TI.TaskID, + TryNumber: details.TI.TryNumber, + MapIndex: &mapIndex, + }, + BundleInfo: api.BundleInfo{ + Name: details.BundleInfo.Name, + Version: &details.BundleInfo.Version, + }, + } + + ctx := context.Background() + ctx = context.WithValue(ctx, sdkcontext.WorkloadContextKey, workload) + ctx = context.WithValue(ctx, sdkcontext.SdkClientContextKey, sdk.Client(client)) + + return executeTask(ctx, task, logger) +} + +// executeTask runs the task and handles success, failure, and panics. +func executeTask( + ctx context.Context, + task bundlev1.Task, + logger *slog.Logger, +) (result map[string]any) { + defer func() { + if r := recover(); r != nil { + logger.Error("Recovered panic in task", + "error", r, + "stack", string(debug.Stack()), + ) + result = TaskStateMsg{ + State: "failed", + EndDate: time.Now().UTC(), + }.toMap() + } + }() + + if err := task.Execute(ctx, logger); err != nil { + logger.Error("Task failed", "error", fmt.Sprintf("%v", err)) + return TaskStateMsg{ + State: "failed", + EndDate: time.Now().UTC(), + }.toMap() + } + + return SucceedTaskMsg{ + EndDate: time.Now().UTC(), + }.toMap() +} diff --git a/go-sdk/pkg/sdkcontext/keys.go b/go-sdk/pkg/sdkcontext/keys.go index 0dbc2c6019487..ad83dfce3bd6e 100644 --- a/go-sdk/pkg/sdkcontext/keys.go +++ b/go-sdk/pkg/sdkcontext/keys.go @@ -22,6 +22,7 @@ type ( apiClientContextKey struct{} workerContextKey struct{} runtimeTIContextKey struct{} + sdkClientContextKey struct{} ) var ( @@ -32,4 +33,11 @@ var ( RuntimeTIContextKey = runtimeTIContextKey{} ApiClientContextKey = apiClientContextKey{} WorkerContextKey = workerContextKey{} + + // SdkClientContextKey, when present, holds an sdk.Client implementation + // that should be injected into task functions instead of constructing a + // default HTTP-backed client. The coordinator-mode runtime uses this to + // route task SDK calls (GetVariable, GetConnection, ...) over the + // supervisor comm socket rather than to the Execution API. + SdkClientContextKey = sdkClientContextKey{} ) diff --git a/go-sdk/sdk/client.go b/go-sdk/sdk/client.go index d2aa1f40351f3..3e10c913d1c3d 100644 --- a/go-sdk/sdk/client.go +++ b/go-sdk/sdk/client.go @@ -84,7 +84,7 @@ func (*client) GetConnection(ctx context.Context, connID string) (Connection, er return Connection{}, err } - return connFromAPIResponse(resp) + return ConnFromAPIResponse(resp) } func (c *client) PushXCom( diff --git a/go-sdk/sdk/connection.go b/go-sdk/sdk/connection.go index 35835c2d523d4..1d0bcf30b6f89 100644 --- a/go-sdk/sdk/connection.go +++ b/go-sdk/sdk/connection.go @@ -110,7 +110,11 @@ func (c Connection) GetURI() *url.URL { return uri } -func connFromAPIResponse(resp *api.ConnectionResponse) (Connection, error) { +// ConnFromAPIResponse converts an Execution-API ConnectionResponse into the +// SDK's Connection type. It is exported so other internal SDK packages (for +// example, the coordinator-mode runtime in bundlev1server/impl/coord) can +// reuse the same conversion. +func ConnFromAPIResponse(resp *api.ConnectionResponse) (Connection, error) { var err error conn := Connection{ ID: resp.ConnId, From 2f81ba69cf1ba8e26da595bb7f276ac07a18fe82 Mon Sep 17 00:00:00 2001 From: LIU ZHE YOU Date: Mon, 11 May 2026 17:46:30 +0800 Subject: [PATCH 2/5] Enhance task and DAG registration with optional specifications for improved configuration flexibility --- go-sdk/bundle/bundlev1/registry.go | 122 ++++++++++++++-- go-sdk/bundle/bundlev1/registry_test.go | 63 ++++++++ go-sdk/example/bundle/main.go | 12 +- go-sdk/pkg/execution/serde.go | 167 +++++++++++++++++++-- go-sdk/pkg/execution/serde_test.go | 184 +++++++++++++++++++++++- 5 files changed, 517 insertions(+), 31 deletions(-) diff --git a/go-sdk/bundle/bundlev1/registry.go b/go-sdk/bundle/bundlev1/registry.go index c7be26a47592b..2bb15b0aa5a71 100644 --- a/go-sdk/bundle/bundlev1/registry.go +++ b/go-sdk/bundle/bundlev1/registry.go @@ -23,6 +23,7 @@ import ( "runtime" "strings" "sync" + "time" "github.com/apache/airflow/go-sdk/pkg/worker" ) @@ -32,15 +33,78 @@ type ( Bundle = worker.Bundle Dag interface { - AddTask(fn any) - AddTaskWithName(taskId string, fn any) + AddTask(fn any, spec ...TaskSpec) + AddTaskWithName(taskId string, fn any, spec ...TaskSpec) } // Registry defines the interface that lets user code add dags and tasks, and extends Bundle for execution // time Registry interface { Bundle - AddDag(dagId string) Dag + AddDag(dagId string, spec ...DagSpec) Dag + } + + // TaskSpec is the optional configuration applied to a task at registration + // time. Every field is optional: a zero value means "unset" and the + // scheduler falls back to its serialization-schema default. The field + // names mirror the keys defined under "operator" in + // airflow-core/src/airflow/serialization/schema.json. + TaskSpec struct { + Queue string + Pool string + PoolSlots int + Retries int + RetryDelay time.Duration + MaxRetryDelay time.Duration + RetryExponentialBackoff float64 + PriorityWeight int + WeightRule string + TriggerRule string + Owner string + ExecutionTimeout time.Duration + Executor string + StartDate time.Time + EndDate time.Time + DependsOnPast bool + WaitForDownstream bool + // DoXComPush, EmailOnFailure, and EmailOnRetry default to true in the + // scheduler. A nil pointer means "unset" so the field is omitted from + // the serialized payload; pass Bool(false) to explicitly opt out. + DoXComPush *bool + EmailOnFailure *bool + EmailOnRetry *bool + DocMD string + MapIndexTemplate string + MaxActiveTisPerDag int + MaxActiveTisPerDagrun int + } + + // DagSpec is the optional configuration applied to a DAG at registration + // time. Every field is optional: a zero value means "unset" and the + // scheduler falls back to its serialization-schema default. The field + // names mirror the keys defined under "dag" in + // airflow-core/src/airflow/serialization/schema.json. + DagSpec struct { + // Schedule is "@once", "@continuous", a cron expression, or "" for + // NullTimetable (no schedule). + Schedule string + Description string + StartDate time.Time + EndDate time.Time + Tags []string + DagDisplayName string + DocMD string + MaxActiveTasks int + MaxActiveRuns int + MaxConsecutiveFailedDagRuns int + DagrunTimeout time.Duration + Catchup bool + FailFast bool + RenderTemplateAsNativeObj bool + DisableBundleVersioning bool + // IsPausedUponCreation has no schema default. nil means "unset"; pass + // Bool(true) or Bool(false) to set it explicitly. + IsPausedUponCreation *bool } // TaskInfo describes a registered task. Coordinator-mode DAG parsing uses @@ -53,12 +117,18 @@ type ( TypeName string // PkgPath is the Go package path (e.g. "main", "github.com/x/y"). PkgPath string + // Spec carries the optional per-task configuration supplied at + // registration. The zero value means "no overrides". + Spec TaskSpec } // DagInfo describes a registered dag together with its tasks in // registration order. DagInfo struct { DagID string + // Spec carries the optional per-dag configuration supplied at + // registration. The zero value means "no overrides". + Spec DagSpec Tasks []TaskInfo } @@ -73,6 +143,7 @@ type ( sync.RWMutex taskFuncMap map[string]map[string]Task taskInfo map[string]map[string]TaskInfo + dagSpec map[string]DagSpec dagOrder []string taskOrder map[string][]string } @@ -83,12 +154,32 @@ type dagShim struct { registry *registry } -func (d dagShim) AddTask(fn any) { - d.registry.registerTask(d.dagId, fn) +func (d dagShim) AddTask(fn any, spec ...TaskSpec) { + d.registry.registerTask(d.dagId, fn, optionalSpec(spec, "AddTask")) } -func (d dagShim) AddTaskWithName(taskId string, fn any) { - d.registry.registerTaskWithName(d.dagId, taskId, fn) +func (d dagShim) AddTaskWithName(taskId string, fn any, spec ...TaskSpec) { + d.registry.registerTaskWithName(d.dagId, taskId, fn, optionalSpec(spec, "AddTaskWithName")) +} + +// Bool returns a pointer to b. Use it for the *bool fields on TaskSpec / +// DagSpec where nil means "leave at schema default": +// +// v1.TaskSpec{DoXComPush: v1.Bool(false)} +func Bool(b bool) *bool { + return &b +} + +func optionalSpec[T any](specs []T, caller string) T { + switch len(specs) { + case 0: + var zero T + return zero + case 1: + return specs[0] + default: + panic(fmt.Errorf("%s accepts at most one spec, got %d", caller, len(specs))) + } } // Function New creates a new bundle on which Dag and Tasks can be registered @@ -96,6 +187,7 @@ func New() Registry { return ®istry{ taskFuncMap: make(map[string]map[string]Task), taskInfo: make(map[string]map[string]TaskInfo), + dagSpec: make(map[string]DagSpec), taskOrder: make(map[string][]string), } } @@ -116,7 +208,8 @@ func getFnName(fn reflect.Value) string { return name } -func (r *registry) AddDag(dagId string) Dag { +func (r *registry) AddDag(dagId string, spec ...DagSpec) Dag { + dagSpec := optionalSpec(spec, "AddDag") r.RWMutex.Lock() defer r.RWMutex.Unlock() if _, exists := r.taskFuncMap[dagId]; exists { @@ -124,11 +217,12 @@ func (r *registry) AddDag(dagId string) Dag { } r.taskFuncMap[dagId] = make(map[string]Task) r.taskInfo[dagId] = make(map[string]TaskInfo) + r.dagSpec[dagId] = dagSpec r.dagOrder = append(r.dagOrder, dagId) return dagShim{dagId, r} } -func (r *registry) registerTask(dagId string, fn any) { +func (r *registry) registerTask(dagId string, fn any, spec TaskSpec) { val := reflect.ValueOf(fn) if val.Kind() != reflect.Func { @@ -137,10 +231,10 @@ func (r *registry) registerTask(dagId string, fn any) { fnName := getFnName(val) - r.registerTaskWithName(dagId, fnName, fn) + r.registerTaskWithName(dagId, fnName, fn, spec) } -func (r *registry) registerTaskWithName(dagId, taskId string, fn any) { +func (r *registry) registerTaskWithName(dagId, taskId string, fn any, spec TaskSpec) { task, err := NewTaskFunction(fn) if err != nil { panic(fmt.Errorf("error registering task %q for DAG %q: %w", taskId, dagId, err)) @@ -150,6 +244,8 @@ func (r *registry) registerTaskWithName(dagId, taskId string, fn any) { fullName := runtime.FuncForPC(val.Pointer()).Name() typeName, pkgPath := splitFullName(fullName) + info := TaskInfo{ID: taskId, TypeName: typeName, PkgPath: pkgPath, Spec: spec} + r.RWMutex.Lock() defer r.RWMutex.Unlock() @@ -166,7 +262,7 @@ func (r *registry) registerTaskWithName(dagId, taskId string, fn any) { } dagTasks[taskId] = task - r.taskInfo[dagId][taskId] = TaskInfo{ID: taskId, TypeName: typeName, PkgPath: pkgPath} + r.taskInfo[dagId][taskId] = info r.taskOrder[dagId] = append(r.taskOrder[dagId], taskId) } @@ -196,7 +292,7 @@ func (r *registry) OrderedDags() []DagInfo { for _, tid := range taskIDs { tasks = append(tasks, r.taskInfo[dagID][tid]) } - out = append(out, DagInfo{DagID: dagID, Tasks: tasks}) + out = append(out, DagInfo{DagID: dagID, Spec: r.dagSpec[dagID], Tasks: tasks}) } return out } diff --git a/go-sdk/bundle/bundlev1/registry_test.go b/go-sdk/bundle/bundlev1/registry_test.go index 25105cd145598..d8c78d638c407 100644 --- a/go-sdk/bundle/bundlev1/registry_test.go +++ b/go-sdk/bundle/bundlev1/registry_test.go @@ -118,3 +118,66 @@ func (s *RegistrySuite) TestAddTask_ErrorReturnType() { _, exists := s.reg.LookupTask("dag1", "errorTask") s.True(exists) } + +func (s *RegistrySuite) TestAddTask_WithSpec() { + s.dag.AddTask(myTask, TaskSpec{Queue: "high_mem", Retries: 3, DoXComPush: Bool(false)}) + enum, ok := s.reg.(EnumerableBundle) + s.Require().True(ok) + dags := enum.OrderedDags() + s.Require().Len(dags, 1) + s.Require().Len(dags[0].Tasks, 1) + got := dags[0].Tasks[0] + s.Equal("myTask", got.ID) + s.Equal("high_mem", got.Spec.Queue) + s.Equal(3, got.Spec.Retries) + s.Require().NotNil(got.Spec.DoXComPush) + s.False(*got.Spec.DoXComPush) +} + +func (s *RegistrySuite) TestAddTaskWithName_WithSpec() { + s.dag.AddTaskWithName("special", myTask, TaskSpec{Queue: "gpu", Pool: "gpu_pool"}) + enum, ok := s.reg.(EnumerableBundle) + s.Require().True(ok) + dags := enum.OrderedDags() + s.Require().Len(dags, 1) + s.Require().Len(dags[0].Tasks, 1) + got := dags[0].Tasks[0] + s.Equal("special", got.ID) + s.Equal("gpu", got.Spec.Queue) + s.Equal("gpu_pool", got.Spec.Pool) +} + +func (s *RegistrySuite) TestAddTask_TooManySpecsPanics() { + s.PanicsWithError("AddTask accepts at most one spec, got 2", func() { + s.dag.AddTask(myTask, TaskSpec{}, TaskSpec{}) + }) +} + +func (s *RegistrySuite) TestAddDag_WithSpec() { + dag2 := s.reg.AddDag( + "dag2", + DagSpec{Schedule: "@daily", Tags: []string{"etl"}, MaxActiveRuns: 4}, + ) + s.NotNil(dag2) + enum, ok := s.reg.(EnumerableBundle) + s.Require().True(ok) + dags := enum.OrderedDags() + s.Require().Len(dags, 2) + var got DagInfo + for _, d := range dags { + if d.DagID == "dag2" { + got = d + break + } + } + s.Equal("dag2", got.DagID) + s.Equal("@daily", got.Spec.Schedule) + s.Equal([]string{"etl"}, got.Spec.Tags) + s.Equal(4, got.Spec.MaxActiveRuns) +} + +func (s *RegistrySuite) TestAddDag_TooManySpecsPanics() { + s.PanicsWithError("AddDag accepts at most one spec, got 2", func() { + s.reg.AddDag("dag3", DagSpec{}, DagSpec{}) + }) +} diff --git a/go-sdk/example/bundle/main.go b/go-sdk/example/bundle/main.go index 5e970da54951b..2692b9b384d44 100644 --- a/go-sdk/example/bundle/main.go +++ b/go-sdk/example/bundle/main.go @@ -45,10 +45,14 @@ func (m *myBundle) GetBundleVersion() v1.BundleInfo { } func (m *myBundle) RegisterDags(dagbag v1.Registry) error { - simpleDag := dagbag.AddDag("simple_dag") - simpleDag.AddTask(extract) - simpleDag.AddTask(transform) - simpleDag.AddTask(load) + simpleDag := dagbag.AddDag("simple_dag", v1.DagSpec{ + Schedule: "@daily", + Description: "Example Go-authored Dag", + Tags: []string{"example", "go-sdk"}, + }) + simpleDag.AddTask(extract, v1.TaskSpec{Queue: "go-task", Retries: 2}) + simpleDag.AddTask(transform, v1.TaskSpec{Queue: "go-task"}) + simpleDag.AddTask(load, v1.TaskSpec{Queue: "go-task"}) return nil } diff --git a/go-sdk/pkg/execution/serde.go b/go-sdk/pkg/execution/serde.go index 02ec641c89dcb..5b42b9b7a352c 100644 --- a/go-sdk/pkg/execution/serde.go +++ b/go-sdk/pkg/execution/serde.go @@ -162,19 +162,22 @@ func serializeTimetable(schedule *string) map[string]any { } // serializeTask converts a task to the Airflow serialization format. -func serializeTask(taskID, typeName, pkgPath string, downstream []string) map[string]any { +func serializeTask(info bundlev1.TaskInfo, downstream []string) map[string]any { + typeName := info.TypeName if typeName == "" { - typeName = taskID + typeName = info.ID } + pkgPath := info.PkgPath if pkgPath == "" { pkgPath = "main" } data := map[string]any{ - "task_id": taskID, + "task_id": info.ID, "task_type": typeName, "_task_module": pkgPath, "language": "go", } + applyTaskSpec(data, info.Spec) if len(downstream) > 0 { sorted := make([]string, len(downstream)) copy(sorted, downstream) @@ -187,6 +190,142 @@ func serializeTask(taskID, typeName, pkgPath string, downstream []string) map[st } } +// applyTaskSpec mirrors Python BaseSerialization's "omit hard-coded default" +// behavior: each TaskSpec field is written into data only when it differs +// from the schema default declared in +// airflow-core/src/airflow/serialization/schema.json. A zero-valued field is +// always considered "unset" and is skipped. +func applyTaskSpec(data map[string]any, s bundlev1.TaskSpec) { + if s.Queue != "" && s.Queue != "default" { + data["queue"] = s.Queue + } + if s.Pool != "" && s.Pool != "default_pool" { + data["pool"] = s.Pool + } + if s.PoolSlots != 0 && s.PoolSlots != 1 { + data["pool_slots"] = s.PoolSlots + } + if s.Retries != 0 { + data["retries"] = s.Retries + } + if s.RetryDelay != 0 && s.RetryDelay != 300*time.Second { + data["retry_delay"] = unwrapTypeEncoding(serializeValue(s.RetryDelay)) + } + if s.MaxRetryDelay != 0 { + data["max_retry_delay"] = unwrapTypeEncoding(serializeValue(s.MaxRetryDelay)) + } + if s.RetryExponentialBackoff != 0 { + data["retry_exponential_backoff"] = s.RetryExponentialBackoff + } + if s.PriorityWeight != 0 && s.PriorityWeight != 1 { + data["priority_weight"] = s.PriorityWeight + } + if s.WeightRule != "" && s.WeightRule != "downstream" { + data["weight_rule"] = s.WeightRule + } + if s.TriggerRule != "" && s.TriggerRule != "all_success" { + data["trigger_rule"] = s.TriggerRule + } + if s.Owner != "" && s.Owner != "airflow" { + data["owner"] = s.Owner + } + if s.ExecutionTimeout != 0 { + data["execution_timeout"] = unwrapTypeEncoding(serializeValue(s.ExecutionTimeout)) + } + if s.Executor != "" { + data["executor"] = s.Executor + } + if !s.StartDate.IsZero() { + data["start_date"] = unwrapTypeEncoding(serializeValue(s.StartDate)) + } + if !s.EndDate.IsZero() { + data["end_date"] = unwrapTypeEncoding(serializeValue(s.EndDate)) + } + if s.DependsOnPast { + data["depends_on_past"] = true + } + if s.WaitForDownstream { + data["wait_for_downstream"] = true + } + // do_xcom_push / email_on_failure / email_on_retry default to true; only + // emit when an explicit false overrides the default. + if s.DoXComPush != nil && !*s.DoXComPush { + data["do_xcom_push"] = false + } + if s.EmailOnFailure != nil && !*s.EmailOnFailure { + data["email_on_failure"] = false + } + if s.EmailOnRetry != nil && !*s.EmailOnRetry { + data["email_on_retry"] = false + } + if s.DocMD != "" { + data["doc_md"] = s.DocMD + } + if s.MapIndexTemplate != "" { + data["map_index_template"] = s.MapIndexTemplate + } + if s.MaxActiveTisPerDag != 0 { + data["max_active_tis_per_dag"] = s.MaxActiveTisPerDag + } + if s.MaxActiveTisPerDagrun != 0 { + data["max_active_tis_per_dagrun"] = s.MaxActiveTisPerDagrun + } +} + +// applyDagSpec writes optional DAG-level fields onto data, omitting any +// field equal to its schema default. See applyTaskSpec for the convention. +func applyDagSpec(data map[string]any, s bundlev1.DagSpec) { + if s.Description != "" { + data["description"] = s.Description + } + if !s.StartDate.IsZero() { + data["start_date"] = unwrapTypeEncoding(serializeValue(s.StartDate)) + } + if !s.EndDate.IsZero() { + data["end_date"] = unwrapTypeEncoding(serializeValue(s.EndDate)) + } + if len(s.Tags) > 0 { + tags := make([]any, len(s.Tags)) + for i, t := range s.Tags { + tags[i] = t + } + data["tags"] = tags + } + if s.DagDisplayName != "" { + data["dag_display_name"] = s.DagDisplayName + } + if s.DocMD != "" { + data["doc_md"] = s.DocMD + } + if s.MaxActiveTasks != 0 && s.MaxActiveTasks != 16 { + data["max_active_tasks"] = s.MaxActiveTasks + } + if s.MaxActiveRuns != 0 && s.MaxActiveRuns != 16 { + data["max_active_runs"] = s.MaxActiveRuns + } + if s.MaxConsecutiveFailedDagRuns != 0 { + data["max_consecutive_failed_dag_runs"] = s.MaxConsecutiveFailedDagRuns + } + if s.DagrunTimeout != 0 { + data["dagrun_timeout"] = unwrapTypeEncoding(serializeValue(s.DagrunTimeout)) + } + if s.Catchup { + data["catchup"] = true + } + if s.FailFast { + data["fail_fast"] = true + } + if s.RenderTemplateAsNativeObj { + data["render_template_as_native_obj"] = true + } + if s.DisableBundleVersioning { + data["disable_bundle_versioning"] = true + } + if s.IsPausedUponCreation != nil { + data["is_paused_upon_creation"] = *s.IsPausedUponCreation + } +} + // serializeTaskGroup creates a flat root task group containing all task IDs. func serializeTaskGroup(taskIDs []string) map[string]any { children := make(map[string]any, len(taskIDs)) @@ -230,26 +369,30 @@ func serializeParams(params map[string]any) []any { } // SerializeDag converts a bundlev1.DagInfo to Airflow DagSerialization v3 -// format. The Go SDK's bundlev1.Dag interface does not (yet) carry per-DAG -// metadata like schedule, start_date, tags, etc., so the encoding emits -// schema defaults for those fields. The optional-field handling below is -// kept (gated on nil checks) so the encoder can grow naturally as the -// bundle metadata surface expands. +// format. Required fields are always present; optional fields from +// info.Spec are emitted only when they differ from their schema default +// (see applyDagSpec). func SerializeDag(info bundlev1.DagInfo, fileloc, relativeFileloc string) map[string]any { taskIDs := make([]string, len(info.Tasks)) tasks := make([]any, len(info.Tasks)) for i, t := range info.Tasks { taskIDs[i] = t.ID - tasks[i] = serializeTask(t.ID, t.TypeName, t.PkgPath, nil) + tasks[i] = serializeTask(t, nil) } - return map[string]any{ + var schedule *string + if info.Spec.Schedule != "" { + s := info.Spec.Schedule + schedule = &s + } + + result := map[string]any{ // Required fields (always present) "dag_id": info.DagID, "fileloc": fileloc, "relative_fileloc": relativeFileloc, "timezone": "UTC", - "timetable": serializeTimetable(nil), + "timetable": serializeTimetable(schedule), "tasks": tasks, "dag_dependencies": []any{}, "task_group": serializeTaskGroup(taskIDs), @@ -258,6 +401,8 @@ func SerializeDag(info bundlev1.DagInfo, fileloc, relativeFileloc string) map[st "deadline": nil, "allowed_run_types": nil, } + applyDagSpec(result, info.Spec) + return result } // computeRelativeFileloc computes the relative file location from the bundle path. diff --git a/go-sdk/pkg/execution/serde_test.go b/go-sdk/pkg/execution/serde_test.go index afddaf5e700b1..6121d24e285e8 100644 --- a/go-sdk/pkg/execution/serde_test.go +++ b/go-sdk/pkg/execution/serde_test.go @@ -120,7 +120,10 @@ func TestSerializeTimetable(t *testing.T) { } func TestSerializeTask(t *testing.T) { - result := serializeTask("extract", "extract", "main", []string{"transform"}) + result := serializeTask( + bundlev1.TaskInfo{ID: "extract", TypeName: "extract", PkgPath: "main"}, + []string{"transform"}, + ) assert.Equal(t, "operator", result["__type"]) data := result["__var"].(map[string]any) assert.Equal(t, "extract", data["task_id"]) @@ -128,15 +131,124 @@ func TestSerializeTask(t *testing.T) { assert.Equal(t, "main", data["_task_module"]) assert.Equal(t, "go", data["language"]) assert.Equal(t, []string{"transform"}, data["downstream_task_ids"]) + _, hasQueue := data["queue"] + assert.False(t, hasQueue, "queue should be omitted when unset") } func TestSerializeTaskNoDownstream(t *testing.T) { - result := serializeTask("load", "load", "main", nil) + result := serializeTask( + bundlev1.TaskInfo{ID: "load", TypeName: "load", PkgPath: "main"}, + nil, + ) data := result["__var"].(map[string]any) _, hasDownstream := data["downstream_task_ids"] assert.False(t, hasDownstream) } +func TestSerializeTaskCustomQueue(t *testing.T) { + result := serializeTask( + bundlev1.TaskInfo{ + ID: "extract", TypeName: "extract", PkgPath: "main", + Spec: bundlev1.TaskSpec{Queue: "high_mem"}, + }, + nil, + ) + data := result["__var"].(map[string]any) + assert.Equal(t, "high_mem", data["queue"]) +} + +func TestSerializeTaskDefaultQueueOmitted(t *testing.T) { + result := serializeTask( + bundlev1.TaskInfo{ + ID: "extract", TypeName: "extract", PkgPath: "main", + Spec: bundlev1.TaskSpec{Queue: "default"}, + }, + nil, + ) + data := result["__var"].(map[string]any) + _, hasQueue := data["queue"] + assert.False(t, hasQueue, "queue=\"default\" matches the schema default and should be omitted") +} + +func TestApplyTaskSpec_EmitsAndOmits(t *testing.T) { + spec := bundlev1.TaskSpec{ + Queue: "gpu", + Pool: "gpu_pool", + PoolSlots: 4, + Retries: 3, + RetryDelay: 60 * time.Second, + MaxRetryDelay: 10 * time.Minute, + RetryExponentialBackoff: 2.0, + PriorityWeight: 5, + WeightRule: "upstream", + TriggerRule: "all_done", + Owner: "data-eng", + ExecutionTimeout: 45 * time.Second, + Executor: "KubernetesExecutor", + DependsOnPast: true, + WaitForDownstream: true, + DoXComPush: bundlev1.Bool(false), + EmailOnFailure: bundlev1.Bool(false), + EmailOnRetry: bundlev1.Bool(false), + DocMD: "## task", + MapIndexTemplate: "{{ task.task_id }}", + MaxActiveTisPerDag: 2, + MaxActiveTisPerDagrun: 1, + } + data := map[string]any{} + applyTaskSpec(data, spec) + + assert.Equal(t, "gpu", data["queue"]) + assert.Equal(t, "gpu_pool", data["pool"]) + assert.Equal(t, 4, data["pool_slots"]) + assert.Equal(t, 3, data["retries"]) + assert.Equal(t, 60.0, data["retry_delay"]) + assert.Equal(t, 600.0, data["max_retry_delay"]) + assert.Equal(t, 2.0, data["retry_exponential_backoff"]) + assert.Equal(t, 5, data["priority_weight"]) + assert.Equal(t, "upstream", data["weight_rule"]) + assert.Equal(t, "all_done", data["trigger_rule"]) + assert.Equal(t, "data-eng", data["owner"]) + assert.Equal(t, 45.0, data["execution_timeout"]) + assert.Equal(t, "KubernetesExecutor", data["executor"]) + assert.Equal(t, true, data["depends_on_past"]) + assert.Equal(t, true, data["wait_for_downstream"]) + assert.Equal(t, false, data["do_xcom_push"]) + assert.Equal(t, false, data["email_on_failure"]) + assert.Equal(t, false, data["email_on_retry"]) + assert.Equal(t, "## task", data["doc_md"]) + assert.Equal(t, "{{ task.task_id }}", data["map_index_template"]) + assert.Equal(t, 2, data["max_active_tis_per_dag"]) + assert.Equal(t, 1, data["max_active_tis_per_dagrun"]) +} + +func TestApplyTaskSpec_OmitsSchemaDefaults(t *testing.T) { + // Values equal to schema defaults must be dropped. + spec := bundlev1.TaskSpec{ + Queue: "default", + Pool: "default_pool", + PoolSlots: 1, + Retries: 0, + RetryDelay: 300 * time.Second, + PriorityWeight: 1, + WeightRule: "downstream", + TriggerRule: "all_success", + Owner: "airflow", + DoXComPush: bundlev1.Bool(true), + EmailOnFailure: bundlev1.Bool(true), + EmailOnRetry: bundlev1.Bool(true), + } + data := map[string]any{} + applyTaskSpec(data, spec) + assert.Empty(t, data, "all fields equal schema defaults; nothing should be emitted") +} + +func TestApplyTaskSpec_EmptySpecNoOp(t *testing.T) { + data := map[string]any{} + applyTaskSpec(data, bundlev1.TaskSpec{}) + assert.Empty(t, data) +} + func TestSerializeTaskGroup(t *testing.T) { result := serializeTaskGroup([]string{"t1", "t2"}) assert.Nil(t, result["_group_id"]) @@ -188,7 +300,10 @@ func TestSerializeDagWithTasks(t *testing.T) { DagID: "etl", Tasks: []bundlev1.TaskInfo{ {ID: "extract", TypeName: "extract", PkgPath: "main"}, - {ID: "load", TypeName: "load", PkgPath: "main"}, + { + ID: "load", TypeName: "load", PkgPath: "main", + Spec: bundlev1.TaskSpec{Queue: "high_mem"}, + }, }, } result := SerializeDag(info, "/bundle/main.go", "main.go") @@ -201,6 +316,11 @@ func TestSerializeDagWithTasks(t *testing.T) { assert.Equal(t, "extract", v["task_type"]) assert.Equal(t, "main", v["_task_module"]) assert.Equal(t, "go", v["language"]) + _, hasQueue := v["queue"] + assert.False(t, hasQueue, "extract has no queue set; field should be omitted") + + second := tasks[1].(map[string]any)["__var"].(map[string]any) + assert.Equal(t, "high_mem", second["queue"]) tg := result["task_group"].(map[string]any) children := tg["children"].(map[string]any) @@ -208,6 +328,64 @@ func TestSerializeDagWithTasks(t *testing.T) { assert.Contains(t, children, "load") } +func TestSerializeDagWithSpec(t *testing.T) { + start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) + info := bundlev1.DagInfo{ + DagID: "etl", + Spec: bundlev1.DagSpec{ + Schedule: "@daily", + Description: "Extract, transform, load", + StartDate: start, + Tags: []string{"prod", "etl"}, + DagDisplayName: "ETL Pipeline", + DocMD: "## ETL", + MaxActiveTasks: 32, + MaxActiveRuns: 4, + MaxConsecutiveFailedDagRuns: 3, + DagrunTimeout: 2 * time.Hour, + Catchup: true, + FailFast: true, + RenderTemplateAsNativeObj: true, + DisableBundleVersioning: true, + IsPausedUponCreation: bundlev1.Bool(true), + }, + } + result := SerializeDag(info, "/bundle/main.go", "main.go") + + tt := result["timetable"].(map[string]any) + assert.Equal(t, "airflow.timetables.trigger.CronTriggerTimetable", tt["__type"]) + v := tt["__var"].(map[string]any) + assert.Equal(t, "@daily", v["expression"]) + + assert.Equal(t, "Extract, transform, load", result["description"]) + assert.Equal(t, []any{"prod", "etl"}, result["tags"]) + assert.Equal(t, "ETL Pipeline", result["dag_display_name"]) + assert.Equal(t, "## ETL", result["doc_md"]) + assert.Equal(t, 32, result["max_active_tasks"]) + assert.Equal(t, 4, result["max_active_runs"]) + assert.Equal(t, 3, result["max_consecutive_failed_dag_runs"]) + assert.Equal(t, (2 * time.Hour).Seconds(), result["dagrun_timeout"]) + assert.Equal(t, true, result["catchup"]) + assert.Equal(t, true, result["fail_fast"]) + assert.Equal(t, true, result["render_template_as_native_obj"]) + assert.Equal(t, true, result["disable_bundle_versioning"]) + assert.Equal(t, true, result["is_paused_upon_creation"]) + + // start_date is a raw epoch number, not the type-wrapped form. + startDate := result["start_date"].(float64) + assert.InDelta(t, float64(start.Unix()), startDate, 0.001) +} + +func TestApplyDagSpec_OmitsSchemaDefaults(t *testing.T) { + spec := bundlev1.DagSpec{ + MaxActiveTasks: 16, + MaxActiveRuns: 16, + } + data := map[string]any{} + applyDagSpec(data, spec) + assert.Empty(t, data, "values equal to schema defaults must be omitted") +} + func TestComputeRelativeFileloc(t *testing.T) { tests := []struct { fileloc string From c3a2584de17bb9700021cb953cb33eb4bd73a470 Mon Sep 17 00:00:00 2001 From: LIU ZHE YOU Date: Tue, 12 May 2026 10:38:10 +0800 Subject: [PATCH 3/5] Support setting downstream at AddTask method --- go-sdk/bundle/bundlev1/registry.go | 56 +++++++++++++--- go-sdk/bundle/bundlev1/registry_test.go | 81 ++++++++++++++++++++---- go-sdk/example/bundle/main.go | 6 +- go-sdk/pkg/execution/integration_test.go | 18 +++--- go-sdk/pkg/execution/serde.go | 14 ++-- go-sdk/pkg/execution/serde_test.go | 52 ++++++++------- go-sdk/pkg/worker/runner_test.go | 4 +- 7 files changed, 166 insertions(+), 65 deletions(-) diff --git a/go-sdk/bundle/bundlev1/registry.go b/go-sdk/bundle/bundlev1/registry.go index 2bb15b0aa5a71..9e302c679332b 100644 --- a/go-sdk/bundle/bundlev1/registry.go +++ b/go-sdk/bundle/bundlev1/registry.go @@ -33,8 +33,14 @@ type ( Bundle = worker.Bundle Dag interface { - AddTask(fn any, spec ...TaskSpec) - AddTaskWithName(taskId string, fn any, spec ...TaskSpec) + // AddTask registers fn as a task in this Dag using fn's Go name as + // the task id. spec carries optional per-task configuration (pass + // TaskSpec{} for defaults). depends lists task ids in the same Dag + // that must run before this one; each must already be registered. + // Pass nil for no dependencies. + AddTask(fn any, spec TaskSpec, depends []string) + // AddTaskWithName is AddTask with an explicit task id. + AddTaskWithName(taskId string, fn any, spec TaskSpec, depends []string) } // Registry defines the interface that lets user code add dags and tasks, and extends Bundle for execution @@ -120,6 +126,10 @@ type ( // Spec carries the optional per-task configuration supplied at // registration. The zero value means "no overrides". Spec TaskSpec + // Downstream lists task ids that depend on this task, populated as + // later tasks declare this id in their AddTask `depends` argument. + // Order is registration order; the serializer sorts before emit. + Downstream []string } // DagInfo describes a registered dag together with its tasks in @@ -154,12 +164,12 @@ type dagShim struct { registry *registry } -func (d dagShim) AddTask(fn any, spec ...TaskSpec) { - d.registry.registerTask(d.dagId, fn, optionalSpec(spec, "AddTask")) +func (d dagShim) AddTask(fn any, spec TaskSpec, depends []string) { + d.registry.registerTask(d.dagId, fn, spec, depends) } -func (d dagShim) AddTaskWithName(taskId string, fn any, spec ...TaskSpec) { - d.registry.registerTaskWithName(d.dagId, taskId, fn, optionalSpec(spec, "AddTaskWithName")) +func (d dagShim) AddTaskWithName(taskId string, fn any, spec TaskSpec, depends []string) { + d.registry.registerTaskWithName(d.dagId, taskId, fn, spec, depends) } // Bool returns a pointer to b. Use it for the *bool fields on TaskSpec / @@ -222,7 +232,7 @@ func (r *registry) AddDag(dagId string, spec ...DagSpec) Dag { return dagShim{dagId, r} } -func (r *registry) registerTask(dagId string, fn any, spec TaskSpec) { +func (r *registry) registerTask(dagId string, fn any, spec TaskSpec, depends []string) { val := reflect.ValueOf(fn) if val.Kind() != reflect.Func { @@ -231,10 +241,15 @@ func (r *registry) registerTask(dagId string, fn any, spec TaskSpec) { fnName := getFnName(val) - r.registerTaskWithName(dagId, fnName, fn, spec) + r.registerTaskWithName(dagId, fnName, fn, spec, depends) } -func (r *registry) registerTaskWithName(dagId, taskId string, fn any, spec TaskSpec) { +func (r *registry) registerTaskWithName( + dagId, taskId string, + fn any, + spec TaskSpec, + depends []string, +) { task, err := NewTaskFunction(fn) if err != nil { panic(fmt.Errorf("error registering task %q for DAG %q: %w", taskId, dagId, err)) @@ -261,6 +276,29 @@ func (r *registry) registerTaskWithName(dagId, taskId string, fn any, spec TaskS panic(fmt.Errorf("taskId %q is already registered for DAG %q", taskId, dagId)) } + // Resolve depends to upstream TaskInfo entries, validating each exists. + // We dedupe so a repeated id in `depends` only records one downstream + // edge on the parent. + seen := make(map[string]bool, len(depends)) + for _, dep := range depends { + if dep == taskId { + panic(fmt.Errorf("task %q cannot depend on itself in DAG %q", taskId, dagId)) + } + if seen[dep] { + continue + } + seen[dep] = true + parent, ok := r.taskInfo[dagId][dep] + if !ok { + panic(fmt.Errorf( + "task %q depends on unknown task %q in DAG %q; register upstream tasks first", + taskId, dep, dagId, + )) + } + parent.Downstream = append(parent.Downstream, taskId) + r.taskInfo[dagId][dep] = parent + } + dagTasks[taskId] = task r.taskInfo[dagId][taskId] = info r.taskOrder[dagId] = append(r.taskOrder[dagId], taskId) diff --git a/go-sdk/bundle/bundlev1/registry_test.go b/go-sdk/bundle/bundlev1/registry_test.go index d8c78d638c407..31f4ce2b7326e 100644 --- a/go-sdk/bundle/bundlev1/registry_test.go +++ b/go-sdk/bundle/bundlev1/registry_test.go @@ -67,14 +67,14 @@ func (s *RegistrySuite) TestAddDag_DuplicatePanics() { } func (s *RegistrySuite) TestAddTask_RegistersAndFindsTask() { - s.dag.AddTask(myTask) + s.dag.AddTask(myTask, TaskSpec{}, nil) task, exists := s.reg.LookupTask("dag1", "myTask") s.True(exists) s.NotNil(task) } func (s *RegistrySuite) TestAddTaskWithName_RegistersAndFindsTask() { - s.dag.AddTaskWithName("special", myTask) + s.dag.AddTaskWithName("special", myTask, TaskSpec{}, nil) task, exists := s.reg.LookupTask("dag1", "special") s.True(exists) s.NotNil(task) @@ -85,20 +85,20 @@ func (s *RegistrySuite) TestAddTaskWithName_RegistersAndFindsTask() { } func (s *RegistrySuite) TestRegisterTaskWithName_DuplicatePanics() { - s.dag.AddTaskWithName("special", myTask) + s.dag.AddTaskWithName("special", myTask, TaskSpec{}, nil) s.PanicsWithError("taskId \"special\" is already registered for DAG \"dag1\"", func() { - s.dag.AddTaskWithName("special", myTask) + s.dag.AddTaskWithName("special", myTask, TaskSpec{}, nil) }) } func (s *RegistrySuite) TestAddTask_NonFuncPanics() { s.PanicsWithError("task fn was a string, not a func", func() { - s.dag.AddTask("not a func") + s.dag.AddTask("not a func", TaskSpec{}, nil) }) } func (s *RegistrySuite) TestAddTaskWithArgs_BindsCorrectArgs() { - s.dag.AddTask(myTaskWithArgs) + s.dag.AddTask(myTaskWithArgs, TaskSpec{}, nil) task, exists := s.reg.LookupTask("dag1", "myTaskWithArgs") s.True(exists) s.NotNil(task) @@ -108,19 +108,19 @@ func (s *RegistrySuite) TestAddTask_InvalidReturnType() { s.PanicsWithError( "error registering task \"NotErrorRet\" for DAG \"dag1\": expected task function github.com/apache/airflow/go-sdk/bundle/bundlev1.NotErrorRet last return value to return error but found int", func() { - s.dag.AddTask(NotErrorRet) + s.dag.AddTask(NotErrorRet, TaskSpec{}, nil) }, ) } func (s *RegistrySuite) TestAddTask_ErrorReturnType() { - s.dag.AddTask(errorTask) + s.dag.AddTask(errorTask, TaskSpec{}, nil) _, exists := s.reg.LookupTask("dag1", "errorTask") s.True(exists) } func (s *RegistrySuite) TestAddTask_WithSpec() { - s.dag.AddTask(myTask, TaskSpec{Queue: "high_mem", Retries: 3, DoXComPush: Bool(false)}) + s.dag.AddTask(myTask, TaskSpec{Queue: "high_mem", Retries: 3, DoXComPush: Bool(false)}, nil) enum, ok := s.reg.(EnumerableBundle) s.Require().True(ok) dags := enum.OrderedDags() @@ -135,7 +135,7 @@ func (s *RegistrySuite) TestAddTask_WithSpec() { } func (s *RegistrySuite) TestAddTaskWithName_WithSpec() { - s.dag.AddTaskWithName("special", myTask, TaskSpec{Queue: "gpu", Pool: "gpu_pool"}) + s.dag.AddTaskWithName("special", myTask, TaskSpec{Queue: "gpu", Pool: "gpu_pool"}, nil) enum, ok := s.reg.(EnumerableBundle) s.Require().True(ok) dags := enum.OrderedDags() @@ -147,9 +147,64 @@ func (s *RegistrySuite) TestAddTaskWithName_WithSpec() { s.Equal("gpu_pool", got.Spec.Pool) } -func (s *RegistrySuite) TestAddTask_TooManySpecsPanics() { - s.PanicsWithError("AddTask accepts at most one spec, got 2", func() { - s.dag.AddTask(myTask, TaskSpec{}, TaskSpec{}) +func (s *RegistrySuite) TestAddTask_DependsRecordsDownstream() { + s.dag.AddTaskWithName("extract", myTask, TaskSpec{}, nil) + s.dag.AddTaskWithName("transform", myTask, TaskSpec{}, []string{"extract"}) + s.dag.AddTaskWithName("load", myTask, TaskSpec{}, []string{"transform"}) + + enum := s.reg.(EnumerableBundle) + tasks := enum.OrderedDags()[0].Tasks + byID := make(map[string]TaskInfo, len(tasks)) + for _, t := range tasks { + byID[t.ID] = t + } + s.Equal([]string{"transform"}, byID["extract"].Downstream) + s.Equal([]string{"load"}, byID["transform"].Downstream) + s.Nil(byID["load"].Downstream) +} + +func (s *RegistrySuite) TestAddTask_FanOutFanIn() { + s.dag.AddTaskWithName("extract", myTask, TaskSpec{}, nil) + s.dag.AddTaskWithName("transform_a", myTask, TaskSpec{}, []string{"extract"}) + s.dag.AddTaskWithName("transform_b", myTask, TaskSpec{}, []string{"extract"}) + s.dag.AddTaskWithName("load", myTask, TaskSpec{}, []string{"transform_a", "transform_b"}) + + enum := s.reg.(EnumerableBundle) + tasks := enum.OrderedDags()[0].Tasks + byID := make(map[string]TaskInfo, len(tasks)) + for _, t := range tasks { + byID[t.ID] = t + } + s.ElementsMatch([]string{"transform_a", "transform_b"}, byID["extract"].Downstream) + s.Equal([]string{"load"}, byID["transform_a"].Downstream) + s.Equal([]string{"load"}, byID["transform_b"].Downstream) +} + +func (s *RegistrySuite) TestAddTask_DependsDuplicatesIgnored() { + s.dag.AddTaskWithName("extract", myTask, TaskSpec{}, nil) + s.dag.AddTaskWithName("load", myTask, TaskSpec{}, []string{"extract", "extract"}) + + enum := s.reg.(EnumerableBundle) + tasks := enum.OrderedDags()[0].Tasks + byID := make(map[string]TaskInfo, len(tasks)) + for _, t := range tasks { + byID[t.ID] = t + } + s.Equal([]string{"load"}, byID["extract"].Downstream) +} + +func (s *RegistrySuite) TestAddTask_DependsUnknownPanics() { + s.PanicsWithError( + `task "load" depends on unknown task "extract" in DAG "dag1"; register upstream tasks first`, + func() { + s.dag.AddTaskWithName("load", myTask, TaskSpec{}, []string{"extract"}) + }, + ) +} + +func (s *RegistrySuite) TestAddTask_DependsOnSelfPanics() { + s.PanicsWithError(`task "self" cannot depend on itself in DAG "dag1"`, func() { + s.dag.AddTaskWithName("self", myTask, TaskSpec{}, []string{"self"}) }) } diff --git a/go-sdk/example/bundle/main.go b/go-sdk/example/bundle/main.go index 2692b9b384d44..fc1f625051971 100644 --- a/go-sdk/example/bundle/main.go +++ b/go-sdk/example/bundle/main.go @@ -50,9 +50,9 @@ func (m *myBundle) RegisterDags(dagbag v1.Registry) error { Description: "Example Go-authored Dag", Tags: []string{"example", "go-sdk"}, }) - simpleDag.AddTask(extract, v1.TaskSpec{Queue: "go-task", Retries: 2}) - simpleDag.AddTask(transform, v1.TaskSpec{Queue: "go-task"}) - simpleDag.AddTask(load, v1.TaskSpec{Queue: "go-task"}) + simpleDag.AddTask(extract, v1.TaskSpec{Queue: "go-task", Retries: 2}, nil) + simpleDag.AddTask(transform, v1.TaskSpec{Queue: "go-task"}, []string{"extract"}) + simpleDag.AddTask(load, v1.TaskSpec{Queue: "go-task"}, []string{"transform"}) return nil } diff --git a/go-sdk/pkg/execution/integration_test.go b/go-sdk/pkg/execution/integration_test.go index a7b18256a8c1e..4bfc0e06fb181 100644 --- a/go-sdk/pkg/execution/integration_test.go +++ b/go-sdk/pkg/execution/integration_test.go @@ -60,7 +60,7 @@ func buildBundle(t *testing.T, register func(bundlev1.Registry)) bundlev1.Bundle func TestDagParsing(t *testing.T) { bundle := buildBundle(t, func(r bundlev1.Registry) { d := r.AddDag("test_dag") - d.AddTask(simpleTask) + d.AddTask(simpleTask, bundlev1.TaskSpec{}, nil) }) req := &DagFileParseRequest{ @@ -98,8 +98,8 @@ func TestDagParsing(t *testing.T) { func TestDagParsingMultipleDagsPreservesOrder(t *testing.T) { bundle := buildBundle(t, func(r bundlev1.Registry) { - r.AddDag("dag1").AddTask(simpleTask) - r.AddDag("dag2").AddTask(failingTask) + r.AddDag("dag1").AddTask(simpleTask, bundlev1.TaskSpec{}, nil) + r.AddDag("dag2").AddTask(failingTask, bundlev1.TaskSpec{}, nil) }) req := &DagFileParseRequest{File: "/bundle/main.go", BundlePath: "/bundle"} @@ -117,7 +117,7 @@ func TestDagParsingMultipleDagsPreservesOrder(t *testing.T) { func TestTaskRunnerSuccess(t *testing.T) { bundle := buildBundle(t, func(r bundlev1.Registry) { - r.AddDag("test_dag").AddTask(simpleTask) + r.AddDag("test_dag").AddTask(simpleTask, bundlev1.TaskSpec{}, nil) }) details := &StartupDetails{ @@ -140,7 +140,7 @@ func TestTaskRunnerSuccess(t *testing.T) { func TestTaskRunnerFailure(t *testing.T) { bundle := buildBundle(t, func(r bundlev1.Registry) { - r.AddDag("test_dag").AddTask(failingTask) + r.AddDag("test_dag").AddTask(failingTask, bundlev1.TaskSpec{}, nil) }) details := &StartupDetails{ @@ -164,7 +164,7 @@ func TestTaskRunnerFailure(t *testing.T) { func TestTaskRunnerTaskNotFound(t *testing.T) { bundle := buildBundle(t, func(r bundlev1.Registry) { - r.AddDag("test_dag").AddTask(simpleTask) + r.AddDag("test_dag").AddTask(simpleTask, bundlev1.TaskSpec{}, nil) }) details := &StartupDetails{ @@ -187,7 +187,7 @@ func TestTaskRunnerTaskNotFound(t *testing.T) { func TestTaskRunnerPanic(t *testing.T) { bundle := buildBundle(t, func(r bundlev1.Registry) { - r.AddDag("test_dag").AddTask(panicTask) + r.AddDag("test_dag").AddTask(panicTask, bundlev1.TaskSpec{}, nil) }) details := &StartupDetails{ @@ -268,7 +268,7 @@ func TestServeDagFileParseEndToEnd(t *testing.T) { provider := &fakeProvider{ register: func(r bundlev1.Registry) error { d := r.AddDag("simple_dag") - d.AddTask(simpleTask) + d.AddTask(simpleTask, bundlev1.TaskSpec{}, nil) return nil }, } @@ -317,7 +317,7 @@ func TestServeStartupDetailsEndToEnd(t *testing.T) { provider := &fakeProvider{ register: func(r bundlev1.Registry) error { - r.AddDag("dag1").AddTask(simpleTask) + r.AddDag("dag1").AddTask(simpleTask, bundlev1.TaskSpec{}, nil) return nil }, } diff --git a/go-sdk/pkg/execution/serde.go b/go-sdk/pkg/execution/serde.go index 5b42b9b7a352c..0d9618c3990eb 100644 --- a/go-sdk/pkg/execution/serde.go +++ b/go-sdk/pkg/execution/serde.go @@ -161,8 +161,10 @@ func serializeTimetable(schedule *string) map[string]any { } } -// serializeTask converts a task to the Airflow serialization format. -func serializeTask(info bundlev1.TaskInfo, downstream []string) map[string]any { +// serializeTask converts a task to the Airflow serialization format. The +// downstream_task_ids slice is read from info.Downstream (populated by the +// registry from each task's `depends` argument) and sorted for stable JSON. +func serializeTask(info bundlev1.TaskInfo) map[string]any { typeName := info.TypeName if typeName == "" { typeName = info.ID @@ -178,9 +180,9 @@ func serializeTask(info bundlev1.TaskInfo, downstream []string) map[string]any { "language": "go", } applyTaskSpec(data, info.Spec) - if len(downstream) > 0 { - sorted := make([]string, len(downstream)) - copy(sorted, downstream) + if len(info.Downstream) > 0 { + sorted := make([]string, len(info.Downstream)) + copy(sorted, info.Downstream) sort.Strings(sorted) data["downstream_task_ids"] = sorted } @@ -377,7 +379,7 @@ func SerializeDag(info bundlev1.DagInfo, fileloc, relativeFileloc string) map[st tasks := make([]any, len(info.Tasks)) for i, t := range info.Tasks { taskIDs[i] = t.ID - tasks[i] = serializeTask(t, nil) + tasks[i] = serializeTask(t) } var schedule *string diff --git a/go-sdk/pkg/execution/serde_test.go b/go-sdk/pkg/execution/serde_test.go index 6121d24e285e8..6ad434e539b10 100644 --- a/go-sdk/pkg/execution/serde_test.go +++ b/go-sdk/pkg/execution/serde_test.go @@ -120,10 +120,10 @@ func TestSerializeTimetable(t *testing.T) { } func TestSerializeTask(t *testing.T) { - result := serializeTask( - bundlev1.TaskInfo{ID: "extract", TypeName: "extract", PkgPath: "main"}, - []string{"transform"}, - ) + result := serializeTask(bundlev1.TaskInfo{ + ID: "extract", TypeName: "extract", PkgPath: "main", + Downstream: []string{"transform"}, + }) assert.Equal(t, "operator", result["__type"]) data := result["__var"].(map[string]any) assert.Equal(t, "extract", data["task_id"]) @@ -135,36 +135,36 @@ func TestSerializeTask(t *testing.T) { assert.False(t, hasQueue, "queue should be omitted when unset") } +func TestSerializeTaskDownstreamSorted(t *testing.T) { + result := serializeTask(bundlev1.TaskInfo{ + ID: "extract", TypeName: "extract", PkgPath: "main", + Downstream: []string{"transform", "audit", "load"}, + }) + data := result["__var"].(map[string]any) + assert.Equal(t, []string{"audit", "load", "transform"}, data["downstream_task_ids"]) +} + func TestSerializeTaskNoDownstream(t *testing.T) { - result := serializeTask( - bundlev1.TaskInfo{ID: "load", TypeName: "load", PkgPath: "main"}, - nil, - ) + result := serializeTask(bundlev1.TaskInfo{ID: "load", TypeName: "load", PkgPath: "main"}) data := result["__var"].(map[string]any) _, hasDownstream := data["downstream_task_ids"] assert.False(t, hasDownstream) } func TestSerializeTaskCustomQueue(t *testing.T) { - result := serializeTask( - bundlev1.TaskInfo{ - ID: "extract", TypeName: "extract", PkgPath: "main", - Spec: bundlev1.TaskSpec{Queue: "high_mem"}, - }, - nil, - ) + result := serializeTask(bundlev1.TaskInfo{ + ID: "extract", TypeName: "extract", PkgPath: "main", + Spec: bundlev1.TaskSpec{Queue: "high_mem"}, + }) data := result["__var"].(map[string]any) assert.Equal(t, "high_mem", data["queue"]) } func TestSerializeTaskDefaultQueueOmitted(t *testing.T) { - result := serializeTask( - bundlev1.TaskInfo{ - ID: "extract", TypeName: "extract", PkgPath: "main", - Spec: bundlev1.TaskSpec{Queue: "default"}, - }, - nil, - ) + result := serializeTask(bundlev1.TaskInfo{ + ID: "extract", TypeName: "extract", PkgPath: "main", + Spec: bundlev1.TaskSpec{Queue: "default"}, + }) data := result["__var"].(map[string]any) _, hasQueue := data["queue"] assert.False(t, hasQueue, "queue=\"default\" matches the schema default and should be omitted") @@ -299,7 +299,10 @@ func TestSerializeDagWithTasks(t *testing.T) { info := bundlev1.DagInfo{ DagID: "etl", Tasks: []bundlev1.TaskInfo{ - {ID: "extract", TypeName: "extract", PkgPath: "main"}, + { + ID: "extract", TypeName: "extract", PkgPath: "main", + Downstream: []string{"load"}, + }, { ID: "load", TypeName: "load", PkgPath: "main", Spec: bundlev1.TaskSpec{Queue: "high_mem"}, @@ -318,9 +321,12 @@ func TestSerializeDagWithTasks(t *testing.T) { assert.Equal(t, "go", v["language"]) _, hasQueue := v["queue"] assert.False(t, hasQueue, "extract has no queue set; field should be omitted") + assert.Equal(t, []string{"load"}, v["downstream_task_ids"]) second := tasks[1].(map[string]any)["__var"].(map[string]any) assert.Equal(t, "high_mem", second["queue"]) + _, hasDownstream := second["downstream_task_ids"] + assert.False(t, hasDownstream, "leaf task has no downstream") tg := result["task_group"].(map[string]any) children := tg["children"].(map[string]any) diff --git a/go-sdk/pkg/worker/runner_test.go b/go-sdk/pkg/worker/runner_test.go index 72673c00df376..f5026e84a2017 100644 --- a/go-sdk/pkg/worker/runner_test.go +++ b/go-sdk/pkg/worker/runner_test.go @@ -161,7 +161,7 @@ func (s *WorkerSuite) TestStartContextErrorTaskDoesntStart() { s.registry.AddDag(testWorkload.TI.DagId).AddTaskWithName(testWorkload.TI.TaskId, func() error { wasCalled = true return nil - }) + }, bundlev1.TaskSpec{}, nil) // Setup the mock s.ti.EXPECT(). @@ -194,7 +194,7 @@ func (s *WorkerSuite) TestTaskHeartbeatsWhileRunning() { s.registry.AddDag(testWorkload.TI.DagId).AddTaskWithName(testWorkload.TI.TaskId, func() error { time.Sleep(time.Second) return nil - }) + }, bundlev1.TaskSpec{}, nil) s.ExpectTaskRun(id) s.ExpectTaskState(id, api.TerminalTIStateSuccess) From 09e99d76d74d94dfb0be01592b87fb43770672ab Mon Sep 17 00:00:00 2001 From: LIU ZHE YOU Date: Tue, 5 May 2026 14:12:01 +0800 Subject: [PATCH 4/5] Add airflow-go-pack for building self-contained Airflow bundles - Added new command `airflow-go-pack` to build a self-contained Airflow bundle from a Go package. - Introduced `inspect` command to print the embedded manifest and optionally the source from a bundle. - Implemented `dump-bundle-spec` functionality to output the bundle specification in JSON format. - Created `bundlefooter` package to manage appending source and metadata to the binary with a defined trailer. - Added tests for bundle footer operations and manifest rendering to ensure correctness. - Updated Justfile for building and packing example DAG bundles. --- .../bundle/bundlev1/bundlev1server/server.go | 100 ++++- go-sdk/cmd/airflow-go-pack/inspect.go | 57 +++ go-sdk/cmd/airflow-go-pack/main.go | 102 +++++ go-sdk/cmd/airflow-go-pack/pack.go | 407 ++++++++++++++++++ go-sdk/cmd/airflow-go-pack/pack_test.go | 70 +++ go-sdk/example/bundle/Justfile | 25 +- go-sdk/go.mod | 4 +- go-sdk/internal/bundlefooter/footer.go | 222 ++++++++++ go-sdk/internal/bundlefooter/footer_test.go | 139 ++++++ 9 files changed, 1121 insertions(+), 5 deletions(-) create mode 100644 go-sdk/cmd/airflow-go-pack/inspect.go create mode 100644 go-sdk/cmd/airflow-go-pack/main.go create mode 100644 go-sdk/cmd/airflow-go-pack/pack.go create mode 100644 go-sdk/cmd/airflow-go-pack/pack_test.go create mode 100644 go-sdk/internal/bundlefooter/footer.go create mode 100644 go-sdk/internal/bundlefooter/footer_test.go diff --git a/go-sdk/bundle/bundlev1/bundlev1server/server.go b/go-sdk/bundle/bundlev1/bundlev1server/server.go index e2e795dc1dd60..8e3612c2b602b 100644 --- a/go-sdk/bundle/bundlev1/bundlev1server/server.go +++ b/go-sdk/bundle/bundlev1/bundlev1server/server.go @@ -22,6 +22,7 @@ import ( "fmt" "log/slog" "os" + "runtime/debug" "github.com/evanphx/go-hclog-slog/hclogslog" "github.com/hashicorp/go-hclog" @@ -35,12 +36,21 @@ import ( "github.com/apache/airflow/go-sdk/pkg/execution" ) +// sdkModulePath is the import path of the SDK module. Used to identify the +// SDK version from the bundle binary's build info dependencies. +const sdkModulePath = "github.com/apache/airflow/go-sdk" + // Flags. The bundle-metadata flag is the existing ADR 0001 introspection // hook; --comm and --logs select the coordinator-mode protocol added by // ADR 0003. All three are read by Serve to choose a server mode below. var ( versionInfo = flag.Bool("bundle-metadata", false, "show the embedded bundle info") - commAddr = flag.String( + dumpSpec = flag.Bool( + "dump-bundle-spec", + false, + "print the bundle spec JSON (sdk + dags) used by airflow-go-pack and exit", + ) + commAddr = flag.String( "comm", "", "host:port of the supervisor's coordinator comm channel (selects coordinator mode)", @@ -74,6 +84,7 @@ type serveMode int const ( modePlugin serveMode = iota // go-plugin gRPC (existing Edge Worker path) modeMetadataDump // --bundle-metadata: print BundleInfo JSON + modeSpecDump // --dump-bundle-spec: print bundle spec JSON (ADR 0002) modeCoordinator // --comm/--logs: msgpack-over-IPC (ADR 0003) modeUsageError // misuse: print usage and exit non-zero ) @@ -102,6 +113,8 @@ func Serve(bundle bundlev1.BundleProvider, opts ...ServeOpt) error { switch decideMode() { case modeMetadataDump: return dumpBundleMetadata(bundle) + case modeSpecDump: + return dumpBundleSpec(bundle) case modeCoordinator: // In coordinator mode the supervisor reads the logs channel for // structured records, so configuring the hclog/stderr default @@ -124,6 +137,9 @@ func decideMode() serveMode { if *versionInfo { return modeMetadataDump } + if *dumpSpec { + return modeSpecDump + } commSet := *commAddr != "" logsSet := *logsAddr != "" if commSet && logsSet { @@ -149,6 +165,88 @@ func dumpBundleMetadata(bundle bundlev1.BundleProvider) error { return nil } +// bundleSpec is the wire shape printed by --dump-bundle-spec. The schema is +// stable per ADR 0002 and consumed by airflow-go-pack to populate the +// bundle's airflow-metadata.yaml at build time. +type bundleSpec struct { + FormatVersion string `json:"format_version"` + SDK bundleSpecSDK `json:"sdk"` + Dags map[string]bundleSpecDag `json:"dags"` +} + +type bundleSpecSDK struct { + Language string `json:"language"` + Version string `json:"version"` +} + +type bundleSpecDag struct { + Tasks []string `json:"tasks"` +} + +// dumpBundleSpec runs the bundle's RegisterDags against an in-memory recorder +// and writes the bundle spec JSON to stdout. It must not start the gRPC +// server or contact any external services; the recorder is the only side +// effect. +func dumpBundleSpec(bundle bundlev1.BundleProvider) error { + reg := bundlev1.New() + if err := bundle.RegisterDags(reg); err != nil { + return fmt.Errorf("registering dags: %w", err) + } + + enum, ok := reg.(bundlev1.EnumerableBundle) + if !ok { + return fmt.Errorf("registry does not implement EnumerableBundle") + } + + spec := bundleSpec{ + FormatVersion: "1.0", + SDK: bundleSpecSDK{ + Language: "go", + Version: sdkVersion(), + }, + Dags: make(map[string]bundleSpecDag), + } + for _, dag := range enum.OrderedDags() { + taskIDs := make([]string, 0, len(dag.Tasks)) + for _, t := range dag.Tasks { + taskIDs = append(taskIDs, t.ID) + } + spec.Dags[dag.DagID] = bundleSpecDag{Tasks: taskIDs} + } + + data, err := json.MarshalIndent(spec, "", " ") + if err != nil { + return err + } + fmt.Println(string(data)) + return nil +} + +// sdkVersion returns the version of the SDK module linked into this binary, +// derived from runtime/debug.ReadBuildInfo. Falls back to "(devel)" when +// build info is unavailable (e.g. tests, bundle binaries built from a local +// replace directive). +func sdkVersion() string { + info, ok := debug.ReadBuildInfo() + if !ok { + return "(devel)" + } + if info.Main.Path == sdkModulePath && info.Main.Version != "" { + return info.Main.Version + } + for _, dep := range info.Deps { + if dep.Path == sdkModulePath { + if dep.Replace != nil && dep.Replace.Version != "" { + return dep.Replace.Version + } + if dep.Version != "" { + return dep.Version + } + } + } + return "(devel)" +} + func installPluginLogger() { hcLogger := hclog.New(&hclog.LoggerOptions{ Level: hclog.Trace, diff --git a/go-sdk/cmd/airflow-go-pack/inspect.go b/go-sdk/cmd/airflow-go-pack/inspect.go new file mode 100644 index 0000000000000..f9cca6d488b22 --- /dev/null +++ b/go-sdk/cmd/airflow-go-pack/inspect.go @@ -0,0 +1,57 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package main + +import ( + "fmt" + + "github.com/spf13/cobra" + + "github.com/apache/airflow/go-sdk/internal/bundlefooter" +) + +func newInspectCmd() *cobra.Command { + var showSource bool + cmd := &cobra.Command{ + Use: "inspect ", + Short: "Print the manifest (and optionally source) embedded in a bundle", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + source, manifest, err := bundlefooter.Read(args[0]) + if err != nil { + return err + } + out := cmd.OutOrStdout() + if showSource { + fmt.Fprintln(out, "# --- source ---") + out.Write(source) + if len(source) > 0 && source[len(source)-1] != '\n' { + fmt.Fprintln(out) + } + fmt.Fprintln(out, "# --- manifest ---") + } + out.Write(manifest) + if len(manifest) > 0 && manifest[len(manifest)-1] != '\n' { + fmt.Fprintln(out) + } + return nil + }, + } + cmd.Flags().BoolVar(&showSource, "source", false, "also print the embedded source file") + return cmd +} diff --git a/go-sdk/cmd/airflow-go-pack/main.go b/go-sdk/cmd/airflow-go-pack/main.go new file mode 100644 index 0000000000000..307c0616377ca --- /dev/null +++ b/go-sdk/cmd/airflow-go-pack/main.go @@ -0,0 +1,102 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Command airflow-go-pack builds a self-contained Airflow bundle from a Go +// package. It runs `go build`, exec's the freshly built binary with +// `--dump-bundle-spec` to obtain the manifest, and appends the source plus +// manifest plus AFBNDL01 trailer to the executable as specified by ADR 0004. +// +// Usage: +// +// go tool airflow-go-pack [./path/to/pkg] [-- ...] +// go tool airflow-go-pack --executable ./build/example --source main.go +// go tool airflow-go-pack inspect ./mybundle +// +// See go-sdk/adr/0002-use-go-tool-directive-for-bundle-packer.md and +// go-sdk/adr/0004-self-contained-executable-bundle.md. +package main + +import ( + "fmt" + "os" + + "github.com/spf13/cobra" +) + +func main() { + if err := newRootCmd().Execute(); err != nil { + fmt.Fprintln(os.Stderr, "error:", err) + os.Exit(1) + } +} + +func newRootCmd() *cobra.Command { + opts := &packOptions{} + + root := &cobra.Command{ + Use: "airflow-go-pack [package]", + Short: "Build a self-contained Airflow bundle from a Go package", + Long: `airflow-go-pack builds a Go bundle binary, queries it for its DAG/task +identity via --dump-bundle-spec, and appends the source plus an +airflow-metadata.yaml manifest plus an AFBNDL01 trailer to the +executable. The result is a single self-contained file that drops into +[executable] bundles_folder. + +By default the packer builds the package in the current directory. Pass +a different package as the positional argument; pass extra go build +flags after a "--" separator. + +Examples: + go tool airflow-go-pack + go tool airflow-go-pack ./cmd/my-bundle -- -trimpath -tags=prod + go tool airflow-go-pack --executable ./build/example --source main.go +`, + Args: cobra.MaximumNArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + // Anything after "--" is forwarded to the internal `go build` + // invocation. ArgsLenAtDash() returns the count of args before + // the dash, or -1 if the dash isn't present. + dashAt := cmd.ArgsLenAtDash() + var pkgArgs, buildArgs []string + if dashAt < 0 { + pkgArgs = args + } else { + pkgArgs = args[:dashAt] + buildArgs = args[dashAt:] + } + opts.pkg = "." + if len(pkgArgs) == 1 { + opts.pkg = pkgArgs[0] + } + opts.buildArgs = buildArgs + return runPack(cmd.OutOrStdout(), cmd.ErrOrStderr(), opts) + }, + } + + root.Flags().StringVar(&opts.source, "source", + "", + "path to the DAG source file (defaults to the file in the target package containing func main)") + root.Flags().StringVar(&opts.executable, "executable", + "", + "pack a pre-built executable instead of running go build") + root.Flags().StringVar(&opts.output, "output", + "", + "output bundle path (defaults to ./)") + + root.AddCommand(newInspectCmd()) + return root +} diff --git a/go-sdk/cmd/airflow-go-pack/pack.go b/go-sdk/cmd/airflow-go-pack/pack.go new file mode 100644 index 0000000000000..16640faa899d3 --- /dev/null +++ b/go-sdk/cmd/airflow-go-pack/pack.go @@ -0,0 +1,407 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package main + +import ( + "bytes" + "encoding/json" + "fmt" + "go/ast" + "go/parser" + "go/token" + "io" + "os" + "os/exec" + "path/filepath" + "runtime" + "sort" + + "gopkg.in/yaml.v3" + + "github.com/apache/airflow/go-sdk/internal/bundlefooter" +) + +// packOptions are the flags accepted by the root pack command. +type packOptions struct { + pkg string // target package (default ".") + source string // override the auto-detected DAG source file + executable string // pack a pre-built binary instead of building + output string // override the default output path + buildArgs []string // forwarded verbatim to `go build` (already includes the leading "--") +} + +// bundleSpec mirrors the JSON printed by --dump-bundle-spec. +type bundleSpec struct { + FormatVersion string `json:"format_version"` + SDK bundleSpecSDK `json:"sdk"` + Dags map[string]bundleSpecDag `json:"dags"` +} + +type bundleSpecSDK struct { + Language string `json:"language"` + Version string `json:"version"` +} + +type bundleSpecDag struct { + Tasks []string `json:"tasks"` +} + +// bundleMetadata mirrors --bundle-metadata's BundleInfo JSON. +type bundleMetadata struct { + Name string `json:"Name"` + Version *string `json:"Version,omitempty"` +} + +func runPack(stdout, stderr io.Writer, opts *packOptions) error { + if opts.executable != "" && len(opts.buildArgs) > 0 { + return fmt.Errorf("--executable is mutually exclusive with go build flags after \"--\"") + } + + sourcePath := opts.source + var execPath string + cleanupExec := func() {} + defer func() { cleanupExec() }() + + if opts.executable != "" { + execPath = opts.executable + if sourcePath == "" { + return fmt.Errorf( + "--executable requires --source: cannot infer the DAG source for a pre-built binary", + ) + } + } else { + discovered, err := discoverMainSource(opts.pkg) + if err != nil { + return fmt.Errorf("locating DAG source file: %w", err) + } + if sourcePath == "" { + sourcePath = discovered + } + + tmp, cleanup, err := buildPackage(stderr, opts.pkg, opts.buildArgs) + if err != nil { + return err + } + execPath = tmp + cleanupExec = cleanup + } + + if _, err := os.Stat(execPath); err != nil { + return fmt.Errorf("executable %s: %w", execPath, err) + } + if _, err := os.Stat(sourcePath); err != nil { + return fmt.Errorf("source file %s: %w", sourcePath, err) + } + + meta, err := readBundleMetadata(execPath) + if err != nil { + return fmt.Errorf("--bundle-metadata: %w", err) + } + if meta.Name == "" { + return fmt.Errorf( + "bundle binary returned an empty BundleInfo.Name; set bundleName at build time", + ) + } + + spec, err := readBundleSpec(execPath) + if err != nil { + return fmt.Errorf("--dump-bundle-spec: %w", err) + } + if len(spec.Dags) == 0 { + return fmt.Errorf("bundle exposes no dags: nothing to pack") + } + for dagID, dag := range spec.Dags { + if len(dag.Tasks) == 0 { + fmt.Fprintf(stderr, "warning: dag %q has no tasks\n", dagID) + } + } + + manifest, err := renderManifest(spec, filepath.Base(sourcePath)) + if err != nil { + return fmt.Errorf("rendering manifest: %w", err) + } + sourceBytes, err := os.ReadFile(sourcePath) + if err != nil { + return fmt.Errorf("reading source file: %w", err) + } + + output := opts.output + if output == "" { + output = defaultOutputPath(meta.Name) + } + + // Copy the executable to the output path before appending so we never + // mutate the build artefact in the temp dir or the user-supplied + // --executable file. + if err := copyFile(execPath, output, 0o755); err != nil { + return fmt.Errorf("writing %s: %w", output, err) + } + if err := bundlefooter.Append(output, sourceBytes, manifest); err != nil { + return err + } + + fmt.Fprintf(stdout, "Wrote bundle %s (sdk=%s/%s, dags=%d)\n", + output, spec.SDK.Language, spec.SDK.Version, len(spec.Dags)) + return nil +} + +func defaultOutputPath(bundleName string) string { + if runtime.GOOS == "windows" { + return bundleName + ".exe" + } + return bundleName +} + +// discoverMainSource locates the file in the given package whose AST contains +// a top-level `func main()`. Returns an error if the package has zero or +// more than one such file, mirroring ADR 0002's discovery contract. +func discoverMainSource(pkg string) (string, error) { + cmd := exec.Command("go", "list", "-f", "{{.Dir}}\n{{range .GoFiles}}{{.}}\n{{end}}", pkg) + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + if err := cmd.Run(); err != nil { + return "", fmt.Errorf("go list %s: %w: %s", pkg, err, stderr.String()) + } + + lines := splitNonEmpty(stdout.String()) + if len(lines) < 2 { + return "", fmt.Errorf("package %s has no Go source files", pkg) + } + dir := lines[0] + files := lines[1:] + + fset := token.NewFileSet() + var matches []string + for _, name := range files { + full := filepath.Join(dir, name) + f, err := parser.ParseFile(fset, full, nil, parser.SkipObjectResolution) + if err != nil { + return "", fmt.Errorf("parsing %s: %w", full, err) + } + if hasMainFunc(f) { + matches = append(matches, full) + } + } + switch len(matches) { + case 0: + return "", fmt.Errorf("no file in package %s defines func main()", pkg) + case 1: + return matches[0], nil + default: + return "", fmt.Errorf( + "multiple files in package %s define func main(): %v; use --source to disambiguate", + pkg, + matches, + ) + } +} + +func hasMainFunc(f *ast.File) bool { + for _, decl := range f.Decls { + fn, ok := decl.(*ast.FuncDecl) + if !ok { + continue + } + if fn.Recv != nil { + continue + } + if fn.Name.Name != "main" { + continue + } + if fn.Type.Params != nil && len(fn.Type.Params.List) != 0 { + continue + } + return true + } + return false +} + +func splitNonEmpty(s string) []string { + var out []string + for _, line := range bytes.Split([]byte(s), []byte("\n")) { + t := bytes.TrimSpace(line) + if len(t) > 0 { + out = append(out, string(t)) + } + } + return out +} + +// buildPackage runs `go build [extraArgs...] -o /bundle ` and +// returns the path to the freshly built executable plus a cleanup function. +// extraArgs is the slice that comes after the "--" separator on the +// airflow-go-pack command line; we drop the leading "--" before forwarding. +func buildPackage(stderr io.Writer, pkg string, extraArgs []string) (string, func(), error) { + tmpDir, err := os.MkdirTemp("", "airflow-go-pack-*") + if err != nil { + return "", nil, fmt.Errorf("creating temp dir: %w", err) + } + cleanup := func() { _ = os.RemoveAll(tmpDir) } + + binName := "bundle" + if runtime.GOOS == "windows" { + binName += ".exe" + } + outPath := filepath.Join(tmpDir, binName) + + args := []string{"build"} + for _, a := range extraArgs { + if a == "--" { + continue + } + args = append(args, a) + } + args = append(args, "-o", outPath, pkg) + + cmd := exec.Command("go", args...) + cmd.Stdout = stderr + cmd.Stderr = stderr + if err := cmd.Run(); err != nil { + cleanup() + return "", nil, fmt.Errorf("go build failed: %w", err) + } + return outPath, cleanup, nil +} + +func readBundleMetadata(execPath string) (bundleMetadata, error) { + out, err := runIntrospect(execPath, "--bundle-metadata") + if err != nil { + return bundleMetadata{}, err + } + var meta bundleMetadata + if err := json.Unmarshal(out, &meta); err != nil { + return bundleMetadata{}, fmt.Errorf("decoding --bundle-metadata JSON: %w", err) + } + return meta, nil +} + +func readBundleSpec(execPath string) (bundleSpec, error) { + out, err := runIntrospect(execPath, "--dump-bundle-spec") + if err != nil { + return bundleSpec{}, err + } + var spec bundleSpec + if err := json.Unmarshal(out, &spec); err != nil { + return bundleSpec{}, fmt.Errorf("decoding --dump-bundle-spec JSON: %w", err) + } + return spec, nil +} + +func runIntrospect(execPath string, flag string) ([]byte, error) { + cmd := exec.Command(execPath, flag) + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + if err := cmd.Run(); err != nil { + return nil, fmt.Errorf("%s %s: %w: %s", execPath, flag, err, stderr.String()) + } + return stdout.Bytes(), nil +} + +// renderManifest serialises the bundle spec as deterministic, sorted-key +// YAML matching the schema in providers/sdk/executable/docs/bundle-spec.rst. +func renderManifest(spec bundleSpec, sourceName string) ([]byte, error) { + if spec.FormatVersion == "" { + spec.FormatVersion = "1.0" + } + + dagIDs := make([]string, 0, len(spec.Dags)) + for id := range spec.Dags { + dagIDs = append(dagIDs, id) + } + sort.Strings(dagIDs) + + dagsNode := &yaml.Node{Kind: yaml.MappingNode} + for _, id := range dagIDs { + tasks := spec.Dags[id].Tasks + taskItems := make([]*yaml.Node, 0, len(tasks)) + for _, t := range tasks { + taskItems = append(taskItems, scalar(t)) + } + dagsNode.Content = append(dagsNode.Content, + scalar(id), + &yaml.Node{ + Kind: yaml.MappingNode, + Content: []*yaml.Node{ + scalar("tasks"), + {Kind: yaml.SequenceNode, Content: taskItems}, + }, + }, + ) + } + + root := &yaml.Node{Kind: yaml.DocumentNode} + manifest := &yaml.Node{ + Kind: yaml.MappingNode, + Content: []*yaml.Node{ + scalar("format_version"), quotedScalar(spec.FormatVersion), + scalar("sdk"), + { + Kind: yaml.MappingNode, + Content: []*yaml.Node{ + scalar("language"), scalar(spec.SDK.Language), + scalar("version"), quotedScalar(spec.SDK.Version), + }, + }, + scalar("source"), scalar(sourceName), + scalar("dags"), dagsNode, + }, + } + root.Content = []*yaml.Node{manifest} + + var buf bytes.Buffer + enc := yaml.NewEncoder(&buf) + enc.SetIndent(2) + if err := enc.Encode(root); err != nil { + return nil, err + } + if err := enc.Close(); err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +func scalar(value string) *yaml.Node { + return &yaml.Node{Kind: yaml.ScalarNode, Value: value} +} + +func quotedScalar(value string) *yaml.Node { + return &yaml.Node{Kind: yaml.ScalarNode, Value: value, Style: yaml.DoubleQuotedStyle} +} + +// copyFile copies src to dst, truncating dst if it already exists. +func copyFile(src, dst string, mode os.FileMode) error { + in, err := os.Open(src) + if err != nil { + return err + } + defer in.Close() + out, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, mode) + if err != nil { + return err + } + if _, err := io.Copy(out, in); err != nil { + out.Close() + return err + } + if err := out.Close(); err != nil { + return err + } + return os.Chmod(dst, mode) +} diff --git a/go-sdk/cmd/airflow-go-pack/pack_test.go b/go-sdk/cmd/airflow-go-pack/pack_test.go new file mode 100644 index 0000000000000..2b229331903a9 --- /dev/null +++ b/go-sdk/cmd/airflow-go-pack/pack_test.go @@ -0,0 +1,70 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package main + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRenderManifest_DeterministicDagOrdering(t *testing.T) { + spec := bundleSpec{ + FormatVersion: "1.0", + SDK: bundleSpecSDK{Language: "go", Version: "0.1.0"}, + Dags: map[string]bundleSpecDag{ + "zeta_dag": {Tasks: []string{"a", "b"}}, + "alpha_dag": {Tasks: []string{"x"}}, + }, + } + + got1, err := renderManifest(spec, "main.go") + require.NoError(t, err) + got2, err := renderManifest(spec, "main.go") + require.NoError(t, err) + + assert.Equal(t, got1, got2, "manifest should be byte-identical for identical input") + + expected := `format_version: "1.0" +sdk: + language: go + version: "0.1.0" +source: main.go +dags: + alpha_dag: + tasks: + - x + zeta_dag: + tasks: + - a + - b +` + assert.Equal(t, expected, string(got1)) +} + +func TestRenderManifest_EmptyDags(t *testing.T) { + spec := bundleSpec{ + FormatVersion: "1.0", + SDK: bundleSpecSDK{Language: "go", Version: "0.1.0"}, + Dags: map[string]bundleSpecDag{}, + } + got, err := renderManifest(spec, "main.go") + require.NoError(t, err) + assert.Contains(t, string(got), "dags: {}") +} diff --git a/go-sdk/example/bundle/Justfile b/go-sdk/example/bundle/Justfile index ca211b304b104..5ed6ef5d34a1a 100644 --- a/go-sdk/example/bundle/Justfile +++ b/go-sdk/example/bundle/Justfile @@ -21,13 +21,32 @@ default: @just --list -# Build the example bundle +# Build the example bundle (raw go build, no footer; for Edge Worker testing) build: @echo "Building example DAG bundle..." go build -o ../../bin/example-dag-bundle . - -# Build with specific name and version +# Build with specific name and version (raw go build, no footer) build-with name="data_processing_example" version="1.0.0": @echo "Building example DAG bundle with name={{name}} version={{version}}..." go build -ldflags="-X 'main.bundleName={{name}}' -X 'main.bundleVersion={{version}}'" -o ../../bin/{{name}}-{{version}} . + +# One-step build + pack. The single `go tool airflow-go-pack` +# invocation runs `go build` internally, queries the binary for its +# DAG/task identity via --dump-bundle-spec, and appends the source plus +# airflow-metadata.yaml plus AFBNDL01 trailer. The output is a single +# self-contained executable bundle, named after BundleInfo.Name and +# written to the current directory. Drop it into [executable] +# bundles_folder to deploy. +pack: + @echo "Packing example DAG bundle..." + go tool airflow-go-pack --output ../../bin/example_dags + +# Pack with extra go build flags forwarded after "--". +pack-release: + @echo "Packing example DAG bundle (release flags)..." + go tool airflow-go-pack --output ../../bin/example_dags -- -trimpath -ldflags="-s -w" + +# Inspect a packed bundle's embedded manifest. +inspect bundle="../../bin/example_dags": + go tool airflow-go-pack inspect {{bundle}} diff --git a/go-sdk/go.mod b/go-sdk/go.mod index f3bfcd4b0f600..cc738e4c3f862 100644 --- a/go-sdk/go.mod +++ b/go-sdk/go.mod @@ -4,6 +4,8 @@ go 1.24.0 toolchain go1.24.6 +tool github.com/apache/airflow/go-sdk/cmd/airflow-go-pack + require ( github.com/cappuccinotm/slogx v1.4.2 github.com/golang-jwt/jwt/v5 v5.3.0 @@ -60,5 +62,5 @@ require ( github.com/samber/slog-http v1.8.2 golang.org/x/sys v0.39.0 // indirect golang.org/x/text v0.32.0 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect + gopkg.in/yaml.v3 v3.0.1 ) diff --git a/go-sdk/internal/bundlefooter/footer.go b/go-sdk/internal/bundlefooter/footer.go new file mode 100644 index 0000000000000..75d518e027d1f --- /dev/null +++ b/go-sdk/internal/bundlefooter/footer.go @@ -0,0 +1,222 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Package bundlefooter implements the AFBNDL01 trailer described in +// ADR 0004 (and providers/sdk/executable/docs/bundle-spec.rst). A bundle +// file is the compiled executable with three appended regions: the source +// bytes, the manifest bytes, and a fixed 32-byte trailer that locates them. +// +// The trailer layout (all little-endian) is: +// +// bytes 0..3 source_len uint32 +// bytes 4..7 metadata_len uint32 +// bytes 8..11 footer_ver uint32 (= 1) +// bytes 12..23 reserved 12 bytes, zero +// bytes 24..31 magic 8 bytes ASCII "AFBNDL01" +package bundlefooter + +import ( + "encoding/binary" + "errors" + "fmt" + "io" + "math" + "os" +) + +const ( + // TrailerSize is the fixed length of the trailer, in bytes. + TrailerSize = 32 + + // FooterVersion is the currently defined trailer-format version. + FooterVersion = 1 + + // MaxRegionSize is the largest source or metadata region this footer + // format can address (uint32 length field). + MaxRegionSize = math.MaxUint32 +) + +// Magic is the 8-byte ASCII tag that identifies a file as a bundle. +var Magic = [8]byte{'A', 'F', 'B', 'N', 'D', 'L', '0', '1'} + +// ErrNotBundle is returned by Read when the file does not end with the +// AFBNDL01 magic. +var ErrNotBundle = errors.New("bundlefooter: not a bundle (magic mismatch)") + +// ErrUnknownVersion is returned by Read when the trailer's footer_ver field +// is something other than FooterVersion. +var ErrUnknownVersion = errors.New("bundlefooter: unknown footer version") + +// Trailer carries the parsed contents of a bundle's 32-byte trailer. +type Trailer struct { + SourceLen uint32 + MetadataLen uint32 + FooterVersion uint32 +} + +// Append writes the source bytes, metadata bytes, and trailer to the end of +// the file at execPath. The file's existing contents (the executable) are +// left intact and its mode bits are preserved. source MAY be nil/empty. +func Append(execPath string, source, metadata []byte) error { + if int64(len(source)) > MaxRegionSize { + return fmt.Errorf( + "bundlefooter: source region too large (%d bytes, max %d)", + len(source), + MaxRegionSize, + ) + } + if int64(len(metadata)) > MaxRegionSize { + return fmt.Errorf( + "bundlefooter: metadata region too large (%d bytes, max %d)", + len(metadata), + MaxRegionSize, + ) + } + + f, err := os.OpenFile(execPath, os.O_RDWR|os.O_APPEND, 0) + if err != nil { + return fmt.Errorf("bundlefooter: opening %s: %w", execPath, err) + } + defer f.Close() + + if len(source) > 0 { + if _, err := f.Write(source); err != nil { + return fmt.Errorf("bundlefooter: writing source region: %w", err) + } + } + if len(metadata) > 0 { + if _, err := f.Write(metadata); err != nil { + return fmt.Errorf("bundlefooter: writing metadata region: %w", err) + } + } + + trailer := encodeTrailer(uint32(len(source)), uint32(len(metadata))) + if _, err := f.Write(trailer[:]); err != nil { + return fmt.Errorf("bundlefooter: writing trailer: %w", err) + } + return nil +} + +// Read parses the trailer of the file at path and returns the embedded +// source and metadata regions. Returns ErrNotBundle if the magic does not +// match (so callers may silently ignore non-bundle files). +func Read(path string) (source, metadata []byte, err error) { + f, err := os.Open(path) + if err != nil { + return nil, nil, fmt.Errorf("bundlefooter: opening %s: %w", path, err) + } + defer f.Close() + + stat, err := f.Stat() + if err != nil { + return nil, nil, fmt.Errorf("bundlefooter: stat %s: %w", path, err) + } + size := stat.Size() + if size < TrailerSize { + return nil, nil, ErrNotBundle + } + + var trailer [TrailerSize]byte + if _, err := f.ReadAt(trailer[:], size-TrailerSize); err != nil { + return nil, nil, fmt.Errorf("bundlefooter: reading trailer: %w", err) + } + + t, err := decodeTrailer(trailer) + if err != nil { + return nil, nil, err + } + + metadataStart := size - TrailerSize - int64(t.MetadataLen) + sourceStart := metadataStart - int64(t.SourceLen) + if sourceStart < 0 { + return nil, nil, fmt.Errorf( + "bundlefooter: trailer reports regions larger than file (source_len=%d metadata_len=%d size=%d)", + t.SourceLen, + t.MetadataLen, + size, + ) + } + if sourceStart == 0 { + return nil, nil, fmt.Errorf("bundlefooter: empty binary region") + } + + if t.SourceLen > 0 { + source = make([]byte, t.SourceLen) + if _, err := f.ReadAt(source, sourceStart); err != nil && !errors.Is(err, io.EOF) { + return nil, nil, fmt.Errorf("bundlefooter: reading source region: %w", err) + } + } + if t.MetadataLen > 0 { + metadata = make([]byte, t.MetadataLen) + if _, err := f.ReadAt(metadata, metadataStart); err != nil && !errors.Is(err, io.EOF) { + return nil, nil, fmt.Errorf("bundlefooter: reading metadata region: %w", err) + } + } + return source, metadata, nil +} + +// IsBundle reports whether the file at path ends with the AFBNDL01 magic. +// It does not validate the trailer beyond the magic check, so a file with a +// matching magic but a corrupt trailer body still returns true. +func IsBundle(path string) (bool, error) { + f, err := os.Open(path) + if err != nil { + return false, err + } + defer f.Close() + + stat, err := f.Stat() + if err != nil { + return false, err + } + if stat.Size() < TrailerSize { + return false, nil + } + + var tail [8]byte + if _, err := f.ReadAt(tail[:], stat.Size()-int64(len(tail))); err != nil { + return false, err + } + return tail == Magic, nil +} + +func encodeTrailer(sourceLen, metadataLen uint32) [TrailerSize]byte { + var t [TrailerSize]byte + binary.LittleEndian.PutUint32(t[0:4], sourceLen) + binary.LittleEndian.PutUint32(t[4:8], metadataLen) + binary.LittleEndian.PutUint32(t[8:12], FooterVersion) + // bytes 12..23 are reserved, zero + copy(t[24:32], Magic[:]) + return t +} + +func decodeTrailer(b [TrailerSize]byte) (Trailer, error) { + var magic [8]byte + copy(magic[:], b[24:32]) + if magic != Magic { + return Trailer{}, ErrNotBundle + } + t := Trailer{ + SourceLen: binary.LittleEndian.Uint32(b[0:4]), + MetadataLen: binary.LittleEndian.Uint32(b[4:8]), + FooterVersion: binary.LittleEndian.Uint32(b[8:12]), + } + if t.FooterVersion != FooterVersion { + return Trailer{}, fmt.Errorf("%w: %d", ErrUnknownVersion, t.FooterVersion) + } + return t, nil +} diff --git a/go-sdk/internal/bundlefooter/footer_test.go b/go-sdk/internal/bundlefooter/footer_test.go new file mode 100644 index 0000000000000..33cbcafffd8ef --- /dev/null +++ b/go-sdk/internal/bundlefooter/footer_test.go @@ -0,0 +1,139 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package bundlefooter + +import ( + "errors" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func writeTempBinary(t *testing.T, contents []byte) string { + t.Helper() + dir := t.TempDir() + path := filepath.Join(dir, "fake-binary") + require.NoError(t, os.WriteFile(path, contents, 0o755)) + return path +} + +func TestAppendAndRead_RoundTrip(t *testing.T) { + binary := []byte("\x7FELFnot-really-an-elf-but-good-enough") + source := []byte("package main\n\nfunc main() {}\n") + metadata := []byte("format_version: \"1.0\"\nsdk:\n language: go\n") + + path := writeTempBinary(t, binary) + require.NoError(t, Append(path, source, metadata)) + + got := mustRead(t, path) + assert.Equal(t, len(binary)+len(source)+len(metadata)+TrailerSize, got.size) + + gotSource, gotMetadata, err := Read(path) + require.NoError(t, err) + assert.Equal(t, source, gotSource) + assert.Equal(t, metadata, gotMetadata) + + ok, err := IsBundle(path) + require.NoError(t, err) + assert.True(t, ok) +} + +func TestAppend_ZeroLengthSource(t *testing.T) { + binary := []byte("\x7FELFstub") + metadata := []byte("manifest") + + path := writeTempBinary(t, binary) + require.NoError(t, Append(path, nil, metadata)) + + source, gotMetadata, err := Read(path) + require.NoError(t, err) + assert.Empty(t, source) + assert.Equal(t, metadata, gotMetadata) +} + +func TestAppend_DeterministicOutput(t *testing.T) { + binary := []byte("\x7FELFstub-binary-bytes") + source := []byte("source") + metadata := []byte("manifest") + + pathA := writeTempBinary(t, binary) + pathB := writeTempBinary(t, binary) + require.NoError(t, Append(pathA, source, metadata)) + require.NoError(t, Append(pathB, source, metadata)) + + a, err := os.ReadFile(pathA) + require.NoError(t, err) + b, err := os.ReadFile(pathB) + require.NoError(t, err) + assert.Equal(t, a, b) +} + +func TestRead_NotBundle(t *testing.T) { + path := writeTempBinary(t, []byte("just a regular file with no footer")) + + _, _, err := Read(path) + require.ErrorIs(t, err, ErrNotBundle) + + ok, err := IsBundle(path) + require.NoError(t, err) + assert.False(t, ok) +} + +func TestRead_TooShort(t *testing.T) { + path := writeTempBinary(t, []byte("hi")) + + _, _, err := Read(path) + require.ErrorIs(t, err, ErrNotBundle) +} + +func TestRead_UnknownVersion(t *testing.T) { + binary := []byte("\x7FELFstub") + source := []byte("src") + metadata := []byte("md") + path := writeTempBinary(t, binary) + require.NoError(t, Append(path, source, metadata)) + + // Mutate the version byte in the trailer. + f, err := os.OpenFile(path, os.O_RDWR, 0) + require.NoError(t, err) + stat, err := f.Stat() + require.NoError(t, err) + // footer_ver lives at bytes 8..11 of the trailer. + versionOffset := stat.Size() - TrailerSize + 8 + _, err = f.WriteAt([]byte{99, 0, 0, 0}, versionOffset) + require.NoError(t, err) + require.NoError(t, f.Close()) + + _, _, err = Read(path) + require.Error(t, err) + assert.True(t, errors.Is(err, ErrUnknownVersion)) +} + +type bundleStat struct { + size int +} + +func mustRead(t *testing.T, path string) bundleStat { + t.Helper() + stat, err := os.Stat(path) + require.NoError(t, err) + return bundleStat{size: int(stat.Size())} +} From 8c5d62d0cdfb4f381176afff6a91cd8f2a213608 Mon Sep 17 00:00:00 2001 From: LIU ZHE YOU Date: Tue, 12 May 2026 14:27:09 +0800 Subject: [PATCH 5/5] Enhance argument validation in airflow-go-pack to allow build flags after "--" and update example bundle to support dynamic DAG naming --- go-sdk/cmd/airflow-go-pack/main.go | 12 +++++++++- go-sdk/cmd/airflow-go-pack/pack_test.go | 29 +++++++++++++++++++++++++ go-sdk/example/bundle/main.go | 14 +++++++----- 3 files changed, 49 insertions(+), 6 deletions(-) diff --git a/go-sdk/cmd/airflow-go-pack/main.go b/go-sdk/cmd/airflow-go-pack/main.go index 307c0616377ca..3be934f18b2f0 100644 --- a/go-sdk/cmd/airflow-go-pack/main.go +++ b/go-sdk/cmd/airflow-go-pack/main.go @@ -65,7 +65,17 @@ Examples: go tool airflow-go-pack ./cmd/my-bundle -- -trimpath -tags=prod go tool airflow-go-pack --executable ./build/example --source main.go `, - Args: cobra.MaximumNArgs(1), + // Only count args BEFORE "--" toward the positional limit; args + // after "--" are forwarded verbatim to `go build` and must not + // inflate the count (e.g. `-- -ldflags "-X main.foo=bar"`). + Args: func(cmd *cobra.Command, args []string) error { + dashAt := cmd.ArgsLenAtDash() + pkgArgs := args + if dashAt >= 0 { + pkgArgs = args[:dashAt] + } + return cobra.MaximumNArgs(1)(cmd, pkgArgs) + }, RunE: func(cmd *cobra.Command, args []string) error { // Anything after "--" is forwarded to the internal `go build` // invocation. ArgsLenAtDash() returns the count of args before diff --git a/go-sdk/cmd/airflow-go-pack/pack_test.go b/go-sdk/cmd/airflow-go-pack/pack_test.go index 2b229331903a9..b6dc9f9280e69 100644 --- a/go-sdk/cmd/airflow-go-pack/pack_test.go +++ b/go-sdk/cmd/airflow-go-pack/pack_test.go @@ -20,6 +20,7 @@ package main import ( "testing" + "github.com/spf13/cobra" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -68,3 +69,31 @@ func TestRenderManifest_EmptyDags(t *testing.T) { require.NoError(t, err) assert.Contains(t, string(got), "dags: {}") } + +// TestRootArgs_AllowsBuildFlagsAfterDoubleDash regression-tests the +// positional-arg validator: forwarded `go build` flags after "--" must not +// be counted against MaximumNArgs(1). The runtime call would otherwise fail +// with `accepts at most 1 arg(s), received N`. +func TestRootArgs_AllowsBuildFlagsAfterDoubleDash(t *testing.T) { + cases := [][]string{ + {"--", "-ldflags", "-X main.dagId=foo"}, + {"./pkg", "--", "-ldflags", "-X main.dagId=foo"}, + {"--", "-trimpath", "-tags=prod"}, + } + for _, argv := range cases { + cmd := newRootCmd() + // Stop the command from actually running; we only want arg validation. + cmd.RunE = func(*cobra.Command, []string) error { return nil } + cmd.SetArgs(argv) + assert.NoError(t, cmd.Execute(), "args=%v should validate", argv) + } +} + +func TestRootArgs_RejectsExtraPositionalBeforeDash(t *testing.T) { + cmd := newRootCmd() + cmd.RunE = func(*cobra.Command, []string) error { return nil } + cmd.SetArgs([]string{"./pkg1", "./pkg2", "--", "-ldflags", "-X main.dagId=foo"}) + err := cmd.Execute() + require.Error(t, err) + assert.Contains(t, err.Error(), "accepts at most 1 arg") +} diff --git a/go-sdk/example/bundle/main.go b/go-sdk/example/bundle/main.go index fc1f625051971..c3cc777a6b511 100644 --- a/go-sdk/example/bundle/main.go +++ b/go-sdk/example/bundle/main.go @@ -29,10 +29,14 @@ import ( "github.com/apache/airflow/go-sdk/sdk" ) -// Set by `-ldflags` at build time +// Set by `-ldflags` at build time. Override dagId to produce the same +// example bundle under a different DAG name — e.g. build once with +// `-X main.dagId=simple_dag` for the pure-Go scenario, and again with +// `-X main.dagId=go_multi_lang` for the Python-stub scenario. var ( bundleName = "example_dags" bundleVersion = "0.0" + dagId = "simple_dag" ) type myBundle struct{} @@ -45,14 +49,14 @@ func (m *myBundle) GetBundleVersion() v1.BundleInfo { } func (m *myBundle) RegisterDags(dagbag v1.Registry) error { - simpleDag := dagbag.AddDag("simple_dag", v1.DagSpec{ + dag := dagbag.AddDag(dagId, v1.DagSpec{ Schedule: "@daily", Description: "Example Go-authored Dag", Tags: []string{"example", "go-sdk"}, }) - simpleDag.AddTask(extract, v1.TaskSpec{Queue: "go-task", Retries: 2}, nil) - simpleDag.AddTask(transform, v1.TaskSpec{Queue: "go-task"}, []string{"extract"}) - simpleDag.AddTask(load, v1.TaskSpec{Queue: "go-task"}, []string{"transform"}) + dag.AddTask(extract, v1.TaskSpec{Queue: "go-task", Retries: 2}, nil) + dag.AddTask(transform, v1.TaskSpec{Queue: "go-task"}, []string{"extract"}) + dag.AddTask(load, v1.TaskSpec{Queue: "go-task"}, []string{"transform"}) return nil }