Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 81 additions & 53 deletions internal/dms/biz/cloudbeaver.go
Original file line number Diff line number Diff line change
Expand Up @@ -568,20 +568,11 @@ type UserList struct {
} `json:"listUsers"`
}

func (cu *CloudbeaverUsecase) createUserIfNotExist(ctx context.Context, cloudbeaverUserId string, dmsUser *User) error {
cloudbeaverUser, exist, err := cu.repo.GetCloudbeaverUserByID(ctx, cloudbeaverUserId)
if err != nil {
return err
}
var reservedCloudbeaverUserId = map[string]struct{}{"admin": {}, "user": {}}

fingerprint := cu.userUsecase.GetUserFingerprint(dmsUser)
if exist && cloudbeaverUser.DMSFingerprint == fingerprint {
return nil
}

reservedCloudbeaverUserId := map[string]struct{}{"admin": {}, "user": {}}
func (cu *CloudbeaverUsecase) createUserIfNotExist(ctx context.Context, cloudbeaverUserId string, dmsUser *User) error {
if _, ok := reservedCloudbeaverUserId[cloudbeaverUserId]; ok {
return fmt.Errorf("this username cannot be used")
return fmt.Errorf("username %s is reserved, cann't be used", cloudbeaverUserId)
}

// 使用管理员身份登录
Expand All @@ -591,7 +582,6 @@ func (cu *CloudbeaverUsecase) createUserIfNotExist(ctx context.Context, cloudbea
}

checkExistReq := cloudbeaver.NewRequest(cu.graphQl.IsUserExistQuery(cloudbeaverUserId))

cloudbeaverUserList := UserList{}
err = graphQLClient.Run(ctx, checkExistReq, &cloudbeaverUserList)
if err != nil {
Expand Down Expand Up @@ -619,6 +609,15 @@ func (cu *CloudbeaverUsecase) createUserIfNotExist(ctx context.Context, cloudbea
if err != nil {
return fmt.Errorf("grant cloudbeaver user failed: %v", err)
}
} else {
cloudbeaverUser, exist, err := cu.repo.GetCloudbeaverUserByID(ctx, cloudbeaverUserId)
if err != nil {
return err
}

if exist && cloudbeaverUser.DMSFingerprint == cu.userUsecase.GetUserFingerprint(dmsUser) {
return nil
}
}

// 设置CloudBeaver用户密码
Expand All @@ -633,9 +632,9 @@ func (cu *CloudbeaverUsecase) createUserIfNotExist(ctx context.Context, cloudbea
return fmt.Errorf("update cloudbeaver user failed: %v", err)
}

cloudbeaverUser = &CloudbeaverUser{
cloudbeaverUser := &CloudbeaverUser{
DMSUserID: dmsUser.UID,
DMSFingerprint: fingerprint,
DMSFingerprint: cu.userUsecase.GetUserFingerprint(dmsUser),
CloudbeaverUserID: cloudbeaverUserId,
}

Expand Down Expand Up @@ -701,17 +700,18 @@ func (cu *CloudbeaverUsecase) connectManagement(ctx context.Context, cloudbeaver
activeDBServices = lastActiveDBServices
}

if err = cu.operateConnection(ctx, activeDBServices, dmsUser.UID); err != nil {
return err
}

cloudbeaverUser, exist, err := cu.repo.GetCloudbeaverUserByID(ctx, cloudbeaverUserId)
if err != nil {
return err
}
if !exist {
return fmt.Errorf("cloudbeaver user: %s not eixst", cloudbeaverUserId)
}

if err = cu.operateConnection(ctx, cloudbeaverUser, dmsUser, activeDBServices); err != nil {
return err
}

if err = cu.grantAccessConnection(ctx, cloudbeaverUser, dmsUser, activeDBServices); err != nil {
return err
}
Expand All @@ -725,11 +725,43 @@ func getDBPrimaryKey(dbUid, purpose, userUid string) string {
return fmt.Sprint(dbUid, ":", purpose, ":", userUid)
}

func (cu *CloudbeaverUsecase) operateConnection(ctx context.Context, activeDBServices []*DBService, userId string) error {
type UserConnectionsResp struct {
Connections []*struct {
Id string `json:"id"`
Template bool `json:"template"`
} `json:"connections"`
}

// 获取用户当前数据库连接ID
func (cu *CloudbeaverUsecase) getUserConnectionIds(ctx context.Context, cloudbeaverUser *CloudbeaverUser, dmsUser *User) ([]string, error) {
client, err := cu.getGraphQLClient(cloudbeaverUser.CloudbeaverUserID, dmsUser.Password)
if err != nil {
return nil, err
}

var userConnectionsResp UserConnectionsResp

variables := map[string]interface{}{"projectId": cloudbeaverProjectId}
err = client.Run(ctx, cloudbeaver.NewRequest(cu.graphQl.GetUserConnectionsQuery(), variables), &userConnectionsResp)
if err != nil {
return nil, err
}

ret := make([]string, 0, len(userConnectionsResp.Connections))
for _, connection := range userConnectionsResp.Connections {
if !connection.Template {
ret = append(ret, connection.Id)
}
}

return ret, nil
}

func (cu *CloudbeaverUsecase) operateConnection(ctx context.Context, cloudbeaverUser *CloudbeaverUser, dmsUser *User, activeDBServices []*DBService) error {
dbServiceMap := map[string]*DBService{}
projectMap := map[string]string{}
for _, service := range activeDBServices {
dbServiceMap[getDBPrimaryKey(service.UID, service.AccountPurpose, userId)] = service
dbServiceMap[getDBPrimaryKey(service.UID, service.AccountPurpose, dmsUser.UID)] = service

project, err := cu.dbServiceUsecase.projectUsecase.GetProject(ctx, service.ProjectUID)
if err != nil {
Expand All @@ -741,33 +773,41 @@ func (cu *CloudbeaverUsecase) operateConnection(ctx context.Context, activeDBSer
}

//获取当前用户所有已创建的连接
cloudbeaverConnections, err := cu.repo.GetCloudbeaverConnectionsByUserId(ctx, userId)
localCloudbeaverConnections, err := cu.repo.GetCloudbeaverConnectionsByUserId(ctx, dmsUser.UID)
if err != nil {
return err
}

// cloudbeaver连接数为空则重置缓存
if userConnectionIds, err := cu.getUserConnectionIds(ctx, cloudbeaverUser, dmsUser); err != nil {
return err
} else if len(userConnectionIds) == 0 {
localCloudbeaverConnections = []*CloudbeaverConnection{}
}

var deleteConnections []*CloudbeaverConnection

cloudbeaverConnectionMap := map[string]*CloudbeaverConnection{}
for _, connection := range cloudbeaverConnections {
for _, connection := range localCloudbeaverConnections {
// 删除用户关联的连接
if connection.DMSUserId == userId {
if connection.DMSUserId == dmsUser.UID {
cloudbeaverConnectionMap[connection.PrimaryKey()] = connection
if _, ok := dbServiceMap[connection.PrimaryKey()]; !ok {
deleteConnections = append(deleteConnections, connection)
}
}
}

createConnections, updateConnections := []*CloudbeaverConnection{}, []*CloudbeaverConnection{}
var createConnections []*CloudbeaverConnection
var updateConnections []*CloudbeaverConnection

for _, dbService := range dbServiceMap {
if cloudbeaverConnection, ok := cloudbeaverConnectionMap[getDBPrimaryKey(dbService.UID, dbService.AccountPurpose, userId)]; ok {
if cloudbeaverConnection, ok := cloudbeaverConnectionMap[getDBPrimaryKey(dbService.UID, dbService.AccountPurpose, dmsUser.UID)]; ok {
if cloudbeaverConnection.DMSDBServiceFingerprint != cu.dbServiceUsecase.GetDBServiceFingerprint(dbService) {
updateConnections = append(updateConnections, &CloudbeaverConnection{DMSDBServiceID: dbService.UID, Purpose: dbService.AccountPurpose, DMSUserId: userId})
updateConnections = append(updateConnections, &CloudbeaverConnection{DMSDBServiceID: dbService.UID, Purpose: dbService.AccountPurpose, DMSUserId: dmsUser.UID})
}
} else {
createConnections = append(createConnections, &CloudbeaverConnection{DMSDBServiceID: dbService.UID, Purpose: dbService.AccountPurpose, DMSUserId: userId})
createConnections = append(createConnections, &CloudbeaverConnection{DMSDBServiceID: dbService.UID, Purpose: dbService.AccountPurpose, DMSUserId: dmsUser.UID})
}
}

Expand All @@ -783,20 +823,20 @@ func (cu *CloudbeaverUsecase) operateConnection(ctx context.Context, activeDBSer

// 同步实例连接信息
for _, createConnection := range createConnections {
if err = cu.createCloudbeaverConnection(ctx, cloudbeaverClient, dbServiceMap[getDBPrimaryKey(createConnection.DMSDBServiceID, createConnection.Purpose, userId)],
projectMap[createConnection.DMSDBServiceID], userId); err != nil {
if err = cu.createCloudbeaverConnection(ctx, cloudbeaverClient, dbServiceMap[getDBPrimaryKey(createConnection.DMSDBServiceID, createConnection.Purpose, dmsUser.UID)],
projectMap[createConnection.DMSDBServiceID], dmsUser.UID); err != nil {
cu.log.Errorf("create connection %v failed: %v", createConnection, err)
}
}

for _, updateConnection := range updateConnections {
if err = cu.updateCloudbeaverConnection(ctx, cloudbeaverClient, updateConnection.CloudbeaverConnectionID, dbServiceMap[getDBPrimaryKey(updateConnection.DMSDBServiceID, updateConnection.Purpose, userId)], projectMap[updateConnection.DMSDBServiceID], userId); err != nil {
if err = cu.updateCloudbeaverConnection(ctx, cloudbeaverClient, updateConnection.CloudbeaverConnectionID, dbServiceMap[getDBPrimaryKey(updateConnection.DMSDBServiceID, updateConnection.Purpose, dmsUser.UID)], projectMap[updateConnection.DMSDBServiceID], dmsUser.UID); err != nil {
cu.log.Errorf("update dnServerId %s to connection failed: %v", updateConnection, err)
}
}

for _, deleteConnection := range deleteConnections {
if err = cu.deleteCloudbeaverConnection(ctx, cloudbeaverClient, deleteConnection.CloudbeaverConnectionID, deleteConnection.DMSDBServiceID, userId, deleteConnection.Purpose); err != nil {
if err = cu.deleteCloudbeaverConnection(ctx, cloudbeaverClient, deleteConnection.CloudbeaverConnectionID, deleteConnection.DMSDBServiceID, dmsUser.UID, deleteConnection.Purpose); err != nil {
cu.log.Errorf("delete connection %v failed: %v", deleteConnection, err)
}
}
Expand Down Expand Up @@ -842,41 +882,29 @@ func (cu *CloudbeaverUsecase) grantAccessConnection(ctx context.Context, cloudbe
for _, dbService := range activeDBServices {
dbServiceIds = append(dbServiceIds, dbService.UID)
}
cloudbeaverConnections, err := cu.repo.GetCloudbeaverConnectionsByUserIdAndDBServiceIds(ctx, dmsUser.UID, dbServiceIds)
localCloudbeaverConnections, err := cu.repo.GetCloudbeaverConnectionsByUserIdAndDBServiceIds(ctx, dmsUser.UID, dbServiceIds)
if err != nil {
return err
}

// 从缓存中获取需要同步的CloudBeaver实例
cloudbeaverConnectionMap := map[string]*CloudbeaverConnection{}
for _, cloudbeaverConnection := range cloudbeaverConnections {
cloudbeaverConnectionMap[cloudbeaverConnection.CloudbeaverConnectionID] = cloudbeaverConnection
}

// 获取用户当前实例列表
connResp := &struct {
Connections []*struct {
Id string `json:"id"`
} `json:"connections"`
}{}

client, err := cu.getGraphQLClient(cloudbeaverUser.CloudbeaverUserID, dmsUser.Password)
if err != nil {
return err
for _, connection := range localCloudbeaverConnections {
cloudbeaverConnectionMap[connection.CloudbeaverConnectionID] = connection
}

err = client.Run(ctx, cloudbeaver.NewRequest(cu.graphQl.GetUserConnectionsQuery(), nil), connResp)
cloudbeaverConnectionIds, err := cu.getUserConnectionIds(ctx, cloudbeaverUser, dmsUser)
if err != nil {
return err
}

if len(connResp.Connections) != len(cloudbeaverConnections) {
return cu.bindUserAccessConnection(ctx, cloudbeaverConnections, cloudbeaverUser.CloudbeaverUserID)
if len(cloudbeaverConnectionIds) != len(localCloudbeaverConnections) {
return cu.bindUserAccessConnection(ctx, localCloudbeaverConnections, cloudbeaverUser.CloudbeaverUserID)
}

for _, connection := range connResp.Connections {
if _, ok := cloudbeaverConnectionMap[connection.Id]; !ok {
return cu.bindUserAccessConnection(ctx, cloudbeaverConnections, cloudbeaverUser.CloudbeaverUserID)
for _, connectionId := range cloudbeaverConnectionIds {
if _, ok := cloudbeaverConnectionMap[connectionId]; !ok {
return cu.bindUserAccessConnection(ctx, localCloudbeaverConnections, cloudbeaverUser.CloudbeaverUserID)
}
}

Expand Down
5 changes: 4 additions & 1 deletion internal/dms/storage/cloudbeaver.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (

"github.com/actiontech/dms/internal/dms/biz"
"github.com/actiontech/dms/internal/dms/storage/model"
"gorm.io/gorm/clause"

utilLog "github.com/actiontech/dms/pkg/dms-common/pkg/log"

Expand Down Expand Up @@ -118,7 +119,9 @@ func (cr *CloudbeaverRepo) GetCloudbeaverConnectionsByUserId(ctx context.Context

func (cr *CloudbeaverRepo) UpdateCloudbeaverConnectionCache(ctx context.Context, u *biz.CloudbeaverConnection) error {
return transaction(cr.log, ctx, cr.db, func(tx *gorm.DB) error {
if err := tx.WithContext(ctx).Save(convertBizCloudbeaverConnection(u)).Error; err != nil {
if err := tx.WithContext(ctx).Clauses(clause.OnConflict{
UpdateAll: true,
}).Create(convertBizCloudbeaverConnection(u)).Error; err != nil {
return fmt.Errorf("failed to update cloudbeaver db Service: %v", err)
}
return nil
Expand Down
1 change: 1 addition & 0 deletions internal/pkg/cloudbeaver/cloudbeaver.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ query getUserConnections (
}
fragment DatabaseConnection on ConnectionInfo {
id
template
}
`
}
Expand Down