Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 5 additions & 14 deletions sqle/api/controller/v1/pipeline.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@ package v1
import (
"context"
"fmt"
v1 "github.com/actiontech/dms/pkg/dms-common/api/dms/v1"
"github.com/actiontech/sqle/sqle/errors"
"net/http"
"strconv"

"github.com/actiontech/sqle/sqle/errors"

"github.com/actiontech/sqle/sqle/api/controller"
"github.com/actiontech/sqle/sqle/dms"
"github.com/actiontech/sqle/sqle/server/pipeline"
Expand Down Expand Up @@ -234,19 +234,10 @@ func GetPipelines(c echo.Context) error {
if err != nil {
return errors.New(errors.ConnectStorageError, fmt.Errorf("check get pipelines failed: %v", err))
}
userId := ""
if !userPermission.CanViewProject() {
userId = user.GetIDStr()
}
rangeDatasourceIds := make([]string, 0)
viewPipelinePermission := userPermission.GetOnePermission(v1.OpPermissionViewPipeline)
if viewPipelinePermission != nil {
userId = ""
rangeDatasourceIds = viewPipelinePermission.RangeUids
}
// 4. 获取存储对象并查询流水线列表

// 3. 获取存储对象并查询流水线列表
var pipelineSvc pipeline.PipelineSvc
count, pipelineList, err := pipelineSvc.GetPipelineList(limit, offset, req.FuzzySearchNameDesc, projectUid, userId, rangeDatasourceIds)
count, pipelineList, err := pipelineSvc.GetPipelineListWithPermission(limit, offset, req.FuzzySearchNameDesc, projectUid, userPermission, user.GetIDStr())
if err != nil {
return controller.JSONBaseErrorReq(c, err)
}
Expand Down
65 changes: 57 additions & 8 deletions sqle/model/pipline.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,22 +101,50 @@ func isValidAuditMethod(a string) bool {
return false
}

func (s *Storage) GetPipelineList(projectID ProjectUID, fuzzySearchContent string, limit, offset uint32, userId string, rangeDatasourceIds []string) ([]*Pipeline, uint64, error) {
func (s *Storage) GetPipelineList(projectID ProjectUID, fuzzySearchContent string, limit, offset uint32, userId string, rangeDatasourceIds []string, canViewAll bool) ([]*Pipeline, uint64, error) {
var count int64
var pipelines []*Pipeline
query := s.db.Model(&Pipeline{}).Where("project_uid = ?", projectID)
if userId != "" {
query = query.Where("create_user_id = ? OR create_user_id IS NULL", userId)
}

// 1. 模糊搜索
if fuzzySearchContent != "" {
query = query.Where("name LIKE ? OR description LIKE ?", "%"+fuzzySearchContent+"%", "%"+fuzzySearchContent+"%")
}
if len(rangeDatasourceIds) > 0 {
query = query.Joins("JOIN pipeline_nodes ON pipelines.id = pipeline_nodes.pipeline_id").
Where("pipeline_nodes.instance_id IN (?)", rangeDatasourceIds).
Group("pipelines.id")

// 2. 权限过滤
if !canViewAll {
if len(rangeDatasourceIds) > 0 {
// 有数据源权限的用户可以看到:
// 1. 包含权限范围内数据源的流水线(通过LEFT JOIN匹配)
// 2. 自己创建的所有流水线
// 3. 所有节点都是离线节点的流水线(通过NOT EXISTS检查)
query = query.
Joins("LEFT JOIN pipeline_nodes ON pipelines.id = pipeline_nodes.pipeline_id").
Where(`
pipeline_nodes.instance_id IN (?) OR
pipelines.create_user_id = ? OR
NOT EXISTS (
SELECT 1 FROM pipeline_nodes pn2
WHERE pn2.pipeline_id = pipelines.id
AND pn2.instance_id != 0
)`, rangeDatasourceIds, userId).
Group("pipelines.id") // 去重,因为LEFT JOIN可能产生重复记录
} else if userId != "" {
// 普通用户只能看到:
// 1. 自己创建的流水线
// 2. 所有节点都是离线节点的流水线
query = query.Where(`
create_user_id = ? OR
NOT EXISTS (
SELECT 1 FROM pipeline_nodes pn
WHERE pn.pipeline_id = pipelines.id
AND pn.instance_id != 0
)`, userId)
}
}
// canViewAll = true 时不添加任何过滤条件

// 3. 统计和分页查询
err := query.Count(&count).Error
if err != nil {
return pipelines, uint64(count), errors.New(errors.ConnectStorageError, err)
Expand Down Expand Up @@ -169,6 +197,27 @@ func (s *Storage) GetPipelineNodesByInstanceId(instanceID uint64) ([]*PipelineNo
return nodes, nil
}

// GetPipelineNodesInBatch 批量获取多个流水线的节点
func (s *Storage) GetPipelineNodesInBatch(pipelineIDs []uint) (map[uint][]*PipelineNode, error) {
if len(pipelineIDs) == 0 {
return make(map[uint][]*PipelineNode), nil
}

var nodes []*PipelineNode
err := s.db.Model(PipelineNode{}).Where("pipeline_id IN (?)", pipelineIDs).Find(&nodes).Error
if err != nil {
return nil, errors.New(errors.ConnectStorageError, err)
}

// 按pipeline_id分组
nodeMap := make(map[uint][]*PipelineNode)
for _, node := range nodes {
nodeMap[node.PipelineID] = append(nodeMap[node.PipelineID], node)
}

return nodeMap, nil
}

func (s *Storage) CreatePipeline(pipeline *Pipeline, nodes []*PipelineNode) error {
return s.Tx(func(txDB *gorm.DB) error {
// 4.1 保存 Pipeline 到数据库
Expand Down
55 changes: 48 additions & 7 deletions sqle/server/pipeline/pipeline.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (

"github.com/actiontech/sqle/sqle/errors"

v1 "github.com/actiontech/dms/pkg/dms-common/api/dms/v1"
dmsCommonJwt "github.com/actiontech/dms/pkg/dms-common/api/jwt"
"github.com/actiontech/sqle/sqle/api/controller"
scannerCmd "github.com/actiontech/sqle/sqle/cmd/scannerd/command"
Expand Down Expand Up @@ -235,19 +236,59 @@ func (svc PipelineSvc) GetPipeline(projectUID string, pipelineID uint) (*Pipelin
return svc.toPipeline(modelPipeline, modelPiplineNodes), nil
}

func (svc PipelineSvc) GetPipelineList(limit, offset uint32, fuzzySearchNameDesc string, projectUID string, userId string, rangeDatasourceIds []string) (count uint64, pipelines []*Pipeline, err error) {
// GetPipelineListWithPermission 根据用户权限获取流水线列表
func (svc PipelineSvc) GetPipelineListWithPermission(limit, offset uint32, fuzzySearchNameDesc string, projectUID string, userPermission *dms.UserPermission, userId string) (count uint64, pipelines []*Pipeline, err error) {
s := model.GetStorage()
modelPipelines, count, err := s.GetPipelineList(model.ProjectUID(projectUID), fuzzySearchNameDesc, limit, offset, userId, rangeDatasourceIds)

// 根据用户权限确定查询参数
var queryUserId string
var rangeDatasourceIds []string
var canViewAll bool

// 权限判断逻辑
if userPermission.IsAdmin() || userPermission.IsProjectAdmin() {
// 超级管理员或项目管理员:可以查看所有流水线
canViewAll = true
} else if viewPipelinePermission := userPermission.GetOnePermission(v1.OpPermissionViewPipeline); viewPipelinePermission != nil {
// 拥有"查看流水线"权限的普通用户:可以查看指定数据源相关的流水线 + 自己创建的所有流水线
queryUserId = userId
rangeDatasourceIds = viewPipelinePermission.RangeUids
canViewAll = false
} else {
// 普通用户:只能查看自己创建的流水线
queryUserId = userId
rangeDatasourceIds = nil
canViewAll = false
}

// 执行数据库查询
modelPipelines, count, err := s.GetPipelineList(model.ProjectUID(projectUID), fuzzySearchNameDesc, limit, offset, queryUserId, rangeDatasourceIds, canViewAll)
if err != nil {
return 0, nil, err
}

// 转换为服务层对象
pipelines = make([]*Pipeline, 0, len(modelPipelines))
if len(modelPipelines) == 0 {
return count, pipelines, nil
}

// 收集所有pipeline ID
pipelineIDs := make([]uint, 0, len(modelPipelines))
for _, mp := range modelPipelines {
pipelineIDs = append(pipelineIDs, mp.ID)
}

// 批量获取所有节点
nodesMap, err := s.GetPipelineNodesInBatch(pipelineIDs)
if err != nil {
return 0, nil, err
}

// 组装结果
for _, modelPipeline := range modelPipelines {
modelPiplineNodes, err := s.GetPipelineNodes(modelPipeline.ID)
if err != nil {
return 0, nil, err
}
pipelines = append(pipelines, svc.toPipeline(modelPipeline, modelPiplineNodes))
nodes := nodesMap[modelPipeline.ID]
pipelines = append(pipelines, svc.toPipeline(modelPipeline, nodes))
}
return count, pipelines, nil
}
Expand Down
Loading