diff --git a/go/tasks/plugins/webapi/agent/integration_test.go b/go/tasks/plugins/webapi/agent/integration_test.go index 0aeed67f5..f66c5f733 100644 --- a/go/tasks/plugins/webapi/agent/integration_test.go +++ b/go/tasks/plugins/webapi/agent/integration_test.go @@ -48,12 +48,15 @@ func (m *MockClient) CreateTask(_ context.Context, createTaskRequest *admin.Crea return &admin.CreateTaskResponse{ResourceMeta: []byte{1, 2, 3, 4}}, nil } -func (m *MockClient) GetTask(_ context.Context, _ *admin.GetTaskRequest, _ ...grpc.CallOption) (*admin.GetTaskResponse, error) { - return &admin.GetTaskResponse{Resource: &admin.Resource{State: admin.State_SUCCEEDED, Outputs: &flyteIdlCore.LiteralMap{ - Literals: map[string]*flyteIdlCore.Literal{ - "arr": coreutils.MustMakeLiteral([]interface{}{[]interface{}{"a", "b"}, []interface{}{1, 2}}), - }, - }}}, nil +func (m *MockClient) GetTask(_ context.Context, req *admin.GetTaskRequest, _ ...grpc.CallOption) (*admin.GetTaskResponse, error) { + if req.GetTaskType() == "bigquery_query_job_task" { + return &admin.GetTaskResponse{Resource: &admin.Resource{State: admin.State_SUCCEEDED, Outputs: &flyteIdlCore.LiteralMap{ + Literals: map[string]*flyteIdlCore.Literal{ + "arr": coreutils.MustMakeLiteral([]interface{}{[]interface{}{"a", "b"}, []interface{}{1, 2}}), + }, + }}}, nil + } + return &admin.GetTaskResponse{Resource: &admin.Resource{State: admin.State_SUCCEEDED}}, nil } func (m *MockClient) DeleteTask(_ context.Context, _ *admin.DeleteTaskRequest, _ ...grpc.CallOption) (*admin.DeleteTaskResponse, error) { @@ -113,6 +116,11 @@ func TestEndToEnd(t *testing.T) { phase := tests.RunPluginEndToEndTest(t, plugin, &template, inputs, nil, nil, iter) assert.Equal(t, true, phase.Phase().IsSuccess()) + + template.Type = "spark_job" + phase = tests.RunPluginEndToEndTest(t, plugin, &template, inputs, nil, nil, iter) + assert.Equal(t, true, phase.Phase().IsSuccess()) + }) t.Run("failed to create a job", func(t *testing.T) { @@ -251,7 +259,7 @@ func getTaskContext(t *testing.T) *pluginCoreMocks.TaskExecutionContext { func newMockAgentPlugin() webapi.PluginEntry { return webapi.PluginEntry{ ID: "agent-service", - SupportedTaskTypes: []core.TaskType{"bigquery_query_job_task"}, + SupportedTaskTypes: []core.TaskType{"bigquery_query_job_task", "spark_job"}, PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) { return &MockPlugin{ Plugin{ diff --git a/go/tasks/plugins/webapi/agent/plugin.go b/go/tasks/plugins/webapi/agent/plugin.go index d97d0c186..0153d9dc1 100644 --- a/go/tasks/plugins/webapi/agent/plugin.go +++ b/go/tasks/plugins/webapi/agent/plugin.go @@ -7,7 +7,6 @@ import ( "fmt" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" - "github.com/flyteorg/flytestdlib/config" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" @@ -19,8 +18,11 @@ import ( "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/template" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/ioutils" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/webapi" + "github.com/flyteorg/flytestdlib/config" + "github.com/flyteorg/flytestdlib/logger" "github.com/flyteorg/flytestdlib/promutils" "google.golang.org/grpc" ) @@ -176,17 +178,38 @@ func (p Plugin) Status(ctx context.Context, taskCtx webapi.StatusContext) (phase case admin.State_RETRYABLE_FAILURE: return core.PhaseInfoRetryableFailure(pluginErrors.TaskFailedWithError, "failed to run the job", taskInfo), nil case admin.State_SUCCEEDED: - if resource.Outputs != nil { - err := taskCtx.OutputWriter().Put(ctx, ioutils.NewInMemoryOutputReader(resource.Outputs, nil, nil)) - if err != nil { - return core.PhaseInfoUndefined, err - } + err = writeOutput(ctx, taskCtx, resource) + if err != nil { + logger.Errorf(ctx, "Failed to write output with err %s", err.Error()) + return core.PhaseInfoUndefined, err } return core.PhaseInfoSuccess(taskInfo), nil } return core.PhaseInfoUndefined, pluginErrors.Errorf(core.SystemErrorCode, "unknown execution phase [%v].", resource.State) } +func writeOutput(ctx context.Context, taskCtx webapi.StatusContext, resource *ResourceWrapper) error { + taskTemplate, err := taskCtx.TaskReader().Read(ctx) + if err != nil { + return err + } + + if taskTemplate.Interface == nil || taskTemplate.Interface.Outputs == nil || taskTemplate.Interface.Outputs.Variables == nil { + logger.Debugf(ctx, "The task declares no outputs. Skipping writing the outputs.") + return nil + } + + var opReader io.OutputReader + if resource.Outputs != nil { + logger.Debugf(ctx, "Agent returned an output") + opReader = ioutils.NewInMemoryOutputReader(resource.Outputs, nil, nil) + } else { + logger.Debugf(ctx, "Agent didn't return any output, assuming file based outputs.") + opReader = ioutils.NewRemoteFileOutputReader(ctx, taskCtx.DataStore(), taskCtx.OutputWriter(), taskCtx.MaxDatasetSizeBytes()) + } + return taskCtx.OutputWriter().Put(ctx, opReader) +} + func getFinalAgent(taskType string, cfg *Config) (*Agent, error) { if id, exists := cfg.AgentForTaskTypes[taskType]; exists { if agent, exists := cfg.Agents[id]; exists { diff --git a/go/tasks/plugins/webapi/agent/plugin_test.go b/go/tasks/plugins/webapi/agent/plugin_test.go index 180a0d6e6..bf9e25e20 100644 --- a/go/tasks/plugins/webapi/agent/plugin_test.go +++ b/go/tasks/plugins/webapi/agent/plugin_test.go @@ -42,7 +42,7 @@ func TestPlugin(t *testing.T) { assert.Equal(t, plugin.cfg.ResourceConstraints, constraints) }) - t.Run("tet newAgentPlugin", func(t *testing.T) { + t.Run("test newAgentPlugin", func(t *testing.T) { p := newAgentPlugin() assert.NotNil(t, p) assert.Equal(t, "agent-service", p.ID) diff --git a/tests/end_to_end.go b/tests/end_to_end.go index dac473dfd..341bbe1c6 100644 --- a/tests/end_to_end.go +++ b/tests/end_to_end.go @@ -92,6 +92,7 @@ func RunPluginEndToEndTest(t *testing.T, executor pluginCore.Plugin, template *i outputWriter.OnGetOutputPath().Return(basePrefix + "/outputs.pb") outputWriter.OnGetCheckpointPrefix().Return("/checkpoint") outputWriter.OnGetPreviousCheckpointsPrefix().Return("/prev") + outputWriter.OnPutMatch(mock.Anything, mock.Anything).Return(nil) outputWriter.OnPutMatch(mock.Anything, mock.Anything).Return(nil).Run(func(args mock.Arguments) { or := args.Get(1).(io.OutputReader)