Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion internal/sql_workbench/service/sql_workbench_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -872,6 +872,12 @@ func (sqlWorkbenchService *SqlWorkbenchService) buildDatasourceBaseInfo(ctx cont
return nil, err
}

return sqlWorkbenchService.fillDatasourceBaseInfo(datasourceName, dbService, environmentID)
}

// fillDatasourceBaseInfo 根据 dbService 字段填充 datasourceBaseInfo,
// 不包含外部 IO(不查 project / 不连 DB),便于单元测试覆盖 DBType 分支逻辑。
func (sqlWorkbenchService *SqlWorkbenchService) fillDatasourceBaseInfo(datasourceName string, dbService *biz.DBService, environmentID int64) (*datasourceBaseInfo, error) {
baseInfo := &datasourceBaseInfo{
Name: datasourceName,
Type: sqlWorkbenchService.convertDBType(dbService.DBType),
Expand All @@ -892,6 +898,16 @@ func (sqlWorkbenchService *SqlWorkbenchService) buildDatasourceBaseInfo(ctx cont
baseInfo.DefaultSchema, baseInfo.Properties, baseInfo.JDBCParams = buildMongoDatasourceOptions(dbService)
}

// DB2 特殊处理:从 AdditionalParams.database_name 取默认 schema 透传到 ODC
if dbService.DBType == "DB2" {
databaseNameParam := dbService.AdditionalParams.GetParam("database_name")
if databaseNameParam == nil || databaseNameParam.Value == "" {
return nil, fmt.Errorf("DB2 数据源 %s 缺少 AdditionalParam database_name,请在数据源 AdditionalParams 中补充", dbService.Name)
}
databaseName := databaseNameParam.Value
baseInfo.DefaultSchema = &databaseName
}

return baseInfo, nil
}

Expand Down Expand Up @@ -946,7 +962,8 @@ func (sqlWorkbenchService *SqlWorkbenchService) buildUpdateDatasourceRequest(ctx
func (sqlWorkbenchService *SqlWorkbenchService) convertDBType(dmsDBType string) string {
// 这里需要根据实际的数据库类型映射关系进行转换
// ODC ConnectType 枚举值: OB_MYSQL, OB_ORACLE, ORACLE, MYSQL, ODP_SHARDING_OB_MYSQL,
// DORIS, POSTGRESQL, HIVE, DM, TIDB, SQL_SERVER, MONGODB, GAUSSDB 等
// DORIS, POSTGRESQL, HIVE, DM, TIDB, SQL_SERVER, MONGODB, GAUSSDB, DB2 等
// 其余调用创建数据源接口会直接失败
switch dmsDBType {
case "MySQL":
return "MYSQL"
Expand Down Expand Up @@ -976,6 +993,8 @@ func (sqlWorkbenchService *SqlWorkbenchService) convertDBType(dmsDBType string)
return "MYSQL"
case "MongoDB":
return "MONGODB"
case "DB2":
return "DB2"
default:
return dmsDBType
}
Expand Down
105 changes: 105 additions & 0 deletions internal/sql_workbench/service/sql_workbench_service_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package sql_workbench

import (
"strings"
"testing"

"github.com/actiontech/dms/internal/dms/biz"
Expand All @@ -27,6 +28,7 @@ func Test_convertDBType(t *testing.T) {
"PolarDB For MySQL": {input: "PolarDB For MySQL", expected: "MYSQL"},
"GaussDB": {input: "GaussDB", expected: "GAUSSDB"},
"MongoDB": {input: "MongoDB", expected: "MONGODB"},
"DB2": {input: "DB2", expected: "DB2"},
"Unknown passthrough": {input: "UnknownDB", expected: "UnknownDB"},
}
for name, tc := range cases {
Expand Down Expand Up @@ -58,6 +60,7 @@ func Test_SupportDBType(t *testing.T) {
"PolarDB For MySQL supported": {input: pkgConst.DBTypePolarDBForMySQL, expected: true},
"GaussDB supported": {input: pkgConst.DBTypeGaussDB, expected: true},
"GaussDBForMySQL unsupported": {input: pkgConst.DBTypeGaussDBForMySQL, expected: false},
"DB2 unsupported": {input: pkgConst.DBTypeDB2, expected: false},
"empty string unsupported": {input: pkgConst.DBType(""), expected: false},
"unknown type unsupported": {input: pkgConst.DBType("UnknownDBType"), expected: false},
}
Expand Down Expand Up @@ -137,3 +140,105 @@ func Test_buildMongoDatasourceOptions_tlsOnly(t *testing.T) {
}
}

// Test_buildDatasourceBaseInfo_DB2 覆盖 buildDatasourceBaseInfo 中 DB2 / 回归 4 组 case:
//
// (a) DB2 正例:AdditionalParam database_name=testdb → baseInfo.DefaultSchema=="testdb"
// (b) DB2 负例:缺 database_name → 返回 err 且 err 含 "database_name"
// (c) MySQL 回归:DefaultSchema == nil 且无 err
// (d) Oracle 回归:ServiceName != nil 且无 err
//
// 通过 fillDatasourceBaseInfo(无 IO helper)进行 mock-only 单测,避免触达 projectUsecase / DB。
func Test_buildDatasourceBaseInfo_DB2(t *testing.T) {
svc := &SqlWorkbenchService{}
const envID = int64(1)
const datasourceName = "proj:ds"

cases := map[string]struct {
dbService *biz.DBService
expectErr bool
expectErrSubstr string
expectDefaultSchema *string
expectServiceName *string
}{
"DB2 happy path": {
dbService: &biz.DBService{
Name: "db2-1",
DBType: "DB2",
AdditionalParams: pkgParams.Params{
{Key: "database_name", Value: "testdb"},
},
},
expectErr: false,
expectDefaultSchema: strPtr("testdb"),
expectServiceName: nil,
},
"DB2 missing database_name": {
dbService: &biz.DBService{
Name: "db2-2",
DBType: "DB2",
AdditionalParams: pkgParams.Params{},
},
expectErr: true,
expectErrSubstr: "database_name",
},
"MySQL regression": {
dbService: &biz.DBService{
Name: "mysql-1",
DBType: "MySQL",
AdditionalParams: pkgParams.Params{},
},
expectErr: false,
expectDefaultSchema: nil,
expectServiceName: nil,
},
"Oracle regression": {
dbService: &biz.DBService{
Name: "oracle-1",
DBType: "Oracle",
AdditionalParams: pkgParams.Params{
{Key: "service_name", Value: "ORCL"},
},
},
expectErr: false,
expectDefaultSchema: nil,
expectServiceName: strPtr("ORCL"),
},
}

for name, tc := range cases {
t.Run(name, func(t *testing.T) {
got, err := svc.fillDatasourceBaseInfo(datasourceName, tc.dbService, envID)
if tc.expectErr {
if err == nil {
t.Fatalf("expected error, got nil; baseInfo=%+v", got)
}
if tc.expectErrSubstr != "" && !strings.Contains(err.Error(), tc.expectErrSubstr) {
t.Errorf("error %q does not contain %q", err.Error(), tc.expectErrSubstr)
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got == nil {
t.Fatalf("expected non-nil baseInfo")
}
// DefaultSchema 对比
if (got.DefaultSchema == nil) != (tc.expectDefaultSchema == nil) {
t.Errorf("DefaultSchema nil mismatch: got=%v, want=%v", got.DefaultSchema, tc.expectDefaultSchema)
} else if got.DefaultSchema != nil && tc.expectDefaultSchema != nil && *got.DefaultSchema != *tc.expectDefaultSchema {
t.Errorf("DefaultSchema = %q, want %q", *got.DefaultSchema, *tc.expectDefaultSchema)
}
// ServiceName 对比
if (got.ServiceName == nil) != (tc.expectServiceName == nil) {
t.Errorf("ServiceName nil mismatch: got=%v, want=%v", got.ServiceName, tc.expectServiceName)
} else if got.ServiceName != nil && tc.expectServiceName != nil && *got.ServiceName != *tc.expectServiceName {
t.Errorf("ServiceName = %q, want %q", *got.ServiceName, *tc.expectServiceName)
}
})
}
}

func strPtr(s string) *string {
return &s
}