diff --git a/sqle/api/controller/v1/workflow.go b/sqle/api/controller/v1/workflow.go index 6b7715b1b..071e083ba 100644 --- a/sqle/api/controller/v1/workflow.go +++ b/sqle/api/controller/v1/workflow.go @@ -1367,9 +1367,101 @@ func ReExecuteTaskOnWorkflowV1(c echo.Context) error { if err := controller.BindAndValidateReq(c, req); err != nil { return controller.JSONBaseErrorReq(c, err) } + projectUid, err := dms.GetProjectUIDByName(context.TODO(), c.Param("project_name"), true) + if err != nil { + return controller.JSONBaseErrorReq(c, err) + } + workflowId := c.Param("workflow_id") + taskId := c.Param("task_id") + reExecSqlIds := req.ExecSqlIds + + s := model.GetStorage() + workflow, err := dms.GetWorkflowDetailByWorkflowId(projectUid, workflowId, s.GetWorkflowDetailWithoutInstancesByWorkflowID) + if err != nil { + return controller.JSONBaseErrorReq(c, err) + } + + task, exist, err := s.GetTaskDetailById(taskId) + if err != nil { + return controller.JSONBaseErrorReq(c, err) + } + if !exist { + return controller.JSONBaseErrorReq(c, fmt.Errorf("task is not exist")) + } + + user, err := controller.GetCurrentUser(c, dms.GetUser) + if err != nil { + return controller.JSONBaseErrorReq(c, err) + } + + if err := PrepareForTaskReExecution(c, projectUid, workflow, user, task, reExecSqlIds); err != nil { + return controller.JSONBaseErrorReq(c, err) + } + + err = server.ReExecuteTaskSQLs(workflow, task, reExecSqlIds, user) + if err != nil { + return controller.JSONBaseErrorReq(c, err) + } + return c.JSON(http.StatusOK, controller.NewBaseReq(nil)) } +func PrepareForTaskReExecution(c echo.Context, projectID string, workflow *model.Workflow, user *model.User, task *model.Task, reExecSqlIds []uint) error { + // 只有上线失败的工单可以重新上线sql + if workflow.Record.Status != model.WorkflowStatusExecFailed { + return errors.New(errors.DataInvalid, e.New("workflow status is not exec failed")) + } + + if task.Status != model.TaskStatusExecuteFailed { + return errors.New(errors.DataInvalid, e.New("task status is not execute failed")) + } + + err := CheckCurrentUserCanOperateTasks(c, projectID, workflow, []dmsV1.OpPermissionType{dmsV1.OpPermissionTypeExecuteWorkflow}, []uint{task.ID}) + if err != nil { + return err + } + + for _, record := range workflow.Record.InstanceRecords { + if record.TaskId != task.ID { + continue + } + + for _, u := range strings.Split(record.ExecutionAssignees, ",") { + if u == user.GetIDStr() { + goto CheckReExecSqlIds + } + } + } + + return e.New("you are not allow to execute the task") + +CheckReExecSqlIds: + // 校验reExecSqlIds对应的SQL状态是否都为SQLExecuteStatusFailed + if len(reExecSqlIds) == 0 { + return errors.New(errors.DataInvalid, e.New("re-execute sql ids cannot be empty")) + } + + // 创建一个map用于快速查找ExecuteSQLs中的SQL + execSqlMap := make(map[uint]*model.ExecuteSQL) + for _, execSql := range task.ExecuteSQLs { + execSqlMap[execSql.ID] = execSql + } + + // 检查每个reExecSqlId + for _, sqlId := range reExecSqlIds { + execSql, exists := execSqlMap[sqlId] + if !exists { + return errors.New(errors.DataInvalid, fmt.Errorf("execute sql id %d not found in task", sqlId)) + } + + if execSql.ExecStatus != model.SQLExecuteStatusFailed && execSql.ExecStatus != model.SQLExecuteStatusInitialized { + return errors.New(errors.DataInvalid, fmt.Errorf("execute sql id %d status is %s, only failed or initialized sql can be re-executed", sqlId, execSql.ExecStatus)) + } + } + + return nil +} + type GetWorkflowResV1 struct { controller.BaseRes Data *WorkflowResV1 `json:"data"` diff --git a/sqle/model/task.go b/sqle/model/task.go index 266fb8c25..39930418a 100644 --- a/sqle/model/task.go +++ b/sqle/model/task.go @@ -411,7 +411,7 @@ func (t *Task) HasDoingAudit() bool { func (t *Task) HasDoingExecute() bool { if t.ExecuteSQLs != nil { for _, commitSQL := range t.ExecuteSQLs { - if commitSQL.ExecStatus != SQLExecuteStatusInitialized { + if commitSQL.ExecStatus != SQLExecuteStatusInitialized && commitSQL.ExecStatus != SQLExecuteStatusFailed { return true } } @@ -490,6 +490,38 @@ func (s *Storage) GetTaskDetailById(taskId string) (*Task, bool, error) { return task, true, errors.New(errors.ConnectStorageError, err) } +func (s *Storage) GetTaskDetailByIdWithExecSqlIds(taskId string, execSqlIds []uint) (*Task, bool, error) { + task := &Task{} + + db := s.db.Where("id = ?", taskId). + Preload("RuleTemplate"). + Preload("RollbackSQLs") + + if len(execSqlIds) > 0 { + // 重新执行上线,获取指定需要执行的sql + db = db.Preload("ExecuteSQLs", "id IN (?)", execSqlIds) + } else { + // 未指定则加载所有待执行sql + db = db.Preload("ExecuteSQLs") + } + + err := db.First(task).Error + + if err == gorm.ErrRecordNotFound { + return nil, false, nil + } + return task, true, errors.New(errors.ConnectStorageError, err) +} + +func (s *Storage) GetExecSqlsByTaskIdAndStatus(taskId uint, status []string) ([]*ExecuteSQL, error) { + executeSQLs := []*ExecuteSQL{} + err := s.db.Where("task_id = ? and exec_status IN (?)", taskId, status).Find(&executeSQLs).Error + if err != nil { + return nil, errors.New(errors.ConnectStorageError, err) + } + return executeSQLs, nil +} + func (s *Storage) GetTaskExecuteSQLContent(taskId string) ([]byte, error) { rows, err := s.db.Model(&ExecuteSQL{}).Select("content"). Where("task_id = ?", taskId).Rows() diff --git a/sqle/model/workflow.go b/sqle/model/workflow.go index 634284f27..cd5536031 100644 --- a/sqle/model/workflow.go +++ b/sqle/model/workflow.go @@ -773,6 +773,16 @@ func (s *Storage) UpdateWorkflowExecInstanceRecord(w *Workflow, operateStep *Wor }) } +// UpdateWorkflowExecInstanceRecordForReExecute, 用于重新执行SQL时更新上线状态和执行人 +func (s *Storage) UpdateWorkflowExecInstanceRecordForReExecute(w *Workflow, needExecInstanceRecords []*WorkflowInstanceRecord) error { + return s.Tx(func(tx *gorm.DB) error { + if err := updateWorkflowStatus(tx, w); err != nil { + return err + } + return updateWorkflowInstanceRecordForReExecute(tx, needExecInstanceRecords) + }) +} + func updateWorkflowStatus(tx *gorm.DB, w *Workflow) error { db := tx.Exec("UPDATE workflow_records SET status = ?, current_workflow_step_id = ? WHERE id = ?", w.Record.Status, w.Record.CurrentWorkflowStepId, w.Record.ID) @@ -810,6 +820,17 @@ func updateWorkflowInstanceRecord(tx *gorm.DB, needExecInstanceRecords []*Workfl return nil } +func updateWorkflowInstanceRecordForReExecute(tx *gorm.DB, needExecInstanceRecords []*WorkflowInstanceRecord) error { + for _, inst := range needExecInstanceRecords { + db := tx.Exec("UPDATE workflow_instance_records SET is_sql_executed = ?, execution_user_id = ? WHERE id = ? AND is_sql_executed = 0 AND execution_user_id = 0", + inst.IsSQLExecuted, inst.ExecutionUserId, inst.ID) + if db.Error != nil { + return db.Error + } + } + return nil +} + func updateWorkflowInstanceRecordById(tx *gorm.DB, needExecInstanceRecords []*WorkflowInstanceRecord) error { for _, inst := range needExecInstanceRecords { db := tx.Exec("UPDATE workflow_instance_records SET is_sql_executed = ?, execution_user_id = ? WHERE id = ?", diff --git a/sqle/server/sqled.go b/sqle/server/sqled.go index c63237a6a..897f7c1fd 100644 --- a/sqle/server/sqled.go +++ b/sqle/server/sqled.go @@ -62,7 +62,7 @@ func (s *Sqled) HasTask(taskId string) bool { // addTask receive taskId and action type, using taskId and typ to create an action; // action will be validated, and sent to Sqled.queue. -func (s *Sqled) addTask(projectId string, taskId string, typ int) (*action, error) { +func (s *Sqled) addTask(projectId string, taskId string, typ int, execSqlIds []uint) (*action, error) { var err error var p driver.Plugin var rules []*model.Rule @@ -87,7 +87,7 @@ func (s *Sqled) addTask(projectId string, taskId string, typ int) (*action, erro return action, errors.New(errors.TaskRunning, fmt.Errorf("task is running")) } - task, exist, err := st.GetTaskDetailById(taskId) + task, exist, err := st.GetTaskDetailByIdWithExecSqlIds(taskId, execSqlIds) if err != nil { goto Error } @@ -140,12 +140,21 @@ Error: } func (s *Sqled) AddTask(projectId string, taskId string, typ int) error { - _, err := s.addTask(projectId, taskId, typ) + _, err := s.addTask(projectId, taskId, typ, nil) return err } func (s *Sqled) AddTaskWaitResult(projectId string, taskId string, typ int) (*model.Task, error) { - action, err := s.addTask(projectId, taskId, typ) + action, err := s.addTask(projectId, taskId, typ, nil) + if err != nil { + return nil, err + } + <-action.done + return action.task, action.err +} + +func (s *Sqled) AddTaskWaitResultWithSQLIds(projectId string, taskId string, execSqlIds []uint, typ int) (*model.Task, error) { + action, err := s.addTask(projectId, taskId, typ, execSqlIds) if err != nil { return nil, err } @@ -390,12 +399,13 @@ func (a *action) execute() (err error) { taskStatus = model.TaskStatusExecuteSucceeded } // update task status by sql - for _, sql := range task.ExecuteSQLs { - if sql.ExecStatus == model.SQLExecuteStatusFailed || - sql.ExecStatus == model.SQLExecuteStatusTerminateSucc { - taskStatus = model.TaskStatusExecuteFailed - break - } + // 验证task下所有的sql是否全部成功(工单中允许重新上线部分sql,所以需要验证全部sql是否成功) + failedSqls, queryErr := st.GetExecSqlsByTaskIdAndStatus(task.ID, []string{model.SQLExecuteStatusFailed, model.SQLExecuteStatusTerminateSucc}) + if queryErr != nil { + return queryErr + } + if len(failedSqls) > 0 { + taskStatus = model.TaskStatusExecuteFailed } case terminationErr := <-terminateErrChan: diff --git a/sqle/server/sqled_test.go b/sqle/server/sqled_test.go index 7a1dd3be1..1c28fa83d 100644 --- a/sqle/server/sqled_test.go +++ b/sqle/server/sqled_test.go @@ -331,6 +331,8 @@ func Test_action_execute(t *testing.T) { mockDB, mock, err := sqlmock.New() assert.NoError(t, err) mock.ExpectQuery("SELECT VERSION()").WillReturnRows(sqlmock.NewRows([]string{"VERSION()"}).AddRow("5.7")) + mock.ExpectQuery("SELECT \\* FROM `execute_sql_detail`"). + WillReturnRows(sqlmock.NewRows([]string{"id", "task_id", "exec_status"})) model.InitMockStorage(mockDB) a := getAction(tt.sqls, ActionTypeExecute, d) if err := a.execute(); (err != nil) != tt.wantErr { diff --git a/sqle/server/workflow.go b/sqle/server/workflow.go new file mode 100644 index 000000000..fa0e99374 --- /dev/null +++ b/sqle/server/workflow.go @@ -0,0 +1,75 @@ +package server + +import ( + "context" + "fmt" + "strconv" + "sync" + + "github.com/actiontech/sqle/sqle/common" + "github.com/actiontech/sqle/sqle/dms" + "github.com/actiontech/sqle/sqle/errors" + "github.com/actiontech/sqle/sqle/log" + "github.com/actiontech/sqle/sqle/model" + "github.com/actiontech/sqle/sqle/notification" +) + +func ReExecuteTaskSQLs(workflow *model.Workflow, task *model.Task, execSqlIds []uint, user *model.User) error { + s := model.GetStorage() + l := log.NewEntry() + + instance, exist, err := dms.GetInstancesById(context.Background(), fmt.Sprintf("%d", task.InstanceId)) + if err != nil { + return err + } + if !exist { + return errors.New(errors.DataNotExist, fmt.Errorf("instance is not exist. instanceId=%v", task.InstanceId)) + } + task.Instance = instance + if task.Instance == nil { + return errors.New(errors.DataNotExist, fmt.Errorf("instance is not exist")) + } + + // if instance is not connectable, exec sql must be failed; + // commit action unable to retry, so don't to exec it. + if err = common.CheckInstanceIsConnectable(task.Instance); err != nil { + return errors.New(errors.ConnectRemoteDatabaseError, err) + } + + needExecTaskRecords := make([]*model.WorkflowInstanceRecord, 0, len(workflow.Record.InstanceRecords)) + // update workflow + for _, inst := range workflow.Record.InstanceRecords { + if inst.TaskId != task.ID { + continue + } + inst.IsSQLExecuted = true + inst.ExecutionUserId = user.GetIDStr() + needExecTaskRecords = append(needExecTaskRecords, inst) + } + + workflow.Record.Status = model.WorkflowStatusExecuting + workflow.Record.CurrentWorkflowStepId = 0 + + err = s.UpdateWorkflowExecInstanceRecordForReExecute(workflow, needExecTaskRecords) + if err != nil { + return err + } + var lock sync.Mutex + go func() { + sqledServer := GetSqled() + task, err := sqledServer.AddTaskWaitResultWithSQLIds(string(workflow.ProjectId), strconv.Itoa(int(task.ID)), execSqlIds, ActionTypeExecute) + { + lock.Lock() + updateStatus(s, workflow, l, nil) + lock.Unlock() + } + if err != nil || task.Status == model.TaskStatusExecuteFailed { + go notification.NotifyWorkflow(string(workflow.ProjectId), workflow.WorkflowId, notification.WorkflowNotifyTypeExecuteFail) + } else { + go notification.NotifyWorkflow(string(workflow.ProjectId), workflow.WorkflowId, notification.WorkflowNotifyTypeExecuteSuccess) + } + + }() + + return nil +}