Skip to content

Commit

Permalink
Orm: Support TiDB
Browse files Browse the repository at this point in the history
  • Loading branch information
ngaut committed Sep 17, 2015
1 parent c644872 commit c73e039
Show file tree
Hide file tree
Showing 6 changed files with 236 additions and 3 deletions.
3 changes: 3 additions & 0 deletions orm/db_alias.go
Expand Up @@ -32,6 +32,7 @@ const (
DRSqlite // sqlite
DROracle // oracle
DRPostgres // pgsql
DRTiDB // TiDB
)

// database driver string.
Expand All @@ -57,12 +58,14 @@ var (
"mysql": DRMySQL,
"postgres": DRPostgres,
"sqlite3": DRSqlite,
"tidb": DRTiDB,
}
dbBasers = map[DriverType]dbBaser{
DRMySQL: newdbBaseMysql(),
DRSqlite: newdbBaseSqlite(),
DROracle: newdbBaseMysql(),
DRPostgres: newdbBasePostgres(),
DRTiDB: newdbBaseTidb(),
}
)

Expand Down
63 changes: 63 additions & 0 deletions orm/db_tidb.go
@@ -0,0 +1,63 @@
// Copyright 2015 TiDB Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package orm

import (
"fmt"
)

// mysql dbBaser implementation.
type dbBaseTidb struct {
dbBase
}

var _ dbBaser = new(dbBaseTidb)

// get mysql operator.
func (d *dbBaseTidb) OperatorSQL(operator string) string {
return mysqlOperators[operator]
}

// get mysql table field types.
func (d *dbBaseTidb) DbTypes() map[string]string {
return mysqlTypes
}

// show table sql for mysql.
func (d *dbBaseTidb) ShowTablesQuery() string {
return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema = DATABASE()"
}

// show columns sql of table for mysql.
func (d *dbBaseTidb) ShowColumnsQuery(table string) string {
return fmt.Sprintf("SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE FROM information_schema.columns "+
"WHERE table_schema = DATABASE() AND table_name = '%s'", table)
}

// execute sql to check index exist.
func (d *dbBaseTidb) IndexExists(db dbQuerier, table string, name string) bool {
row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+
"WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name)
var cnt int
row.Scan(&cnt)
return cnt > 0
}

// create new mysql dbBaser.
func newdbBaseTidb() dbBaser {
b := new(dbBaseTidb)
b.ins = b
return b
}
10 changes: 10 additions & 0 deletions orm/models_test.go
Expand Up @@ -25,6 +25,7 @@ import (
_ "github.com/go-sql-driver/mysql"
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
_ "github.com/pingcap/tidb"
)

// A slice string field.
Expand Down Expand Up @@ -345,6 +346,7 @@ var (
IsMysql = DBARGS.Driver == "mysql"
IsSqlite = DBARGS.Driver == "sqlite3"
IsPostgres = DBARGS.Driver == "postgres"
IsTidb = DBARGS.Driver == "tidb"
)

var (
Expand All @@ -364,13 +366,15 @@ Default DB Drivers.
mysql: https://github.com/go-sql-driver/mysql
sqlite3: https://github.com/mattn/go-sqlite3
postgres: https://github.com/lib/pq
tidb: https://github.com/pingcap/tidb
usage:
go get -u github.com/astaxie/beego/orm
go get -u github.com/go-sql-driver/mysql
go get -u github.com/mattn/go-sqlite3
go get -u github.com/lib/pq
go get -u github.com/pingcap/tidb
#### MySQL
mysql -u root -e 'create database orm_test;'
Expand All @@ -390,6 +394,12 @@ psql -c 'create database orm_test;' -U postgres
export ORM_DRIVER=postgres
export ORM_SOURCE="user=postgres dbname=orm_test sslmode=disable"
go test -v github.com/astaxie/beego/orm
#### TiDB
export ORM_DRIVER=tidb
export ORM_SOURCE='memory://test'
go test -v github.com/astaxie/beego/orm
`)
os.Exit(2)
}
Expand Down
10 changes: 7 additions & 3 deletions orm/orm_test.go
Expand Up @@ -702,7 +702,7 @@ func TestOperators(t *testing.T) {

var shouldNum int

if IsSqlite {
if IsSqlite || IsTidb {
shouldNum = 2
} else {
shouldNum = 0
Expand Down Expand Up @@ -740,7 +740,7 @@ func TestOperators(t *testing.T) {
throwFail(t, err)
throwFail(t, AssertIs(num, 1))

if IsSqlite {
if IsSqlite || IsTidb {
shouldNum = 1
} else {
shouldNum = 0
Expand All @@ -758,7 +758,7 @@ func TestOperators(t *testing.T) {
throwFail(t, err)
throwFail(t, AssertIs(num, 2))

if IsSqlite {
if IsSqlite || IsTidb {
shouldNum = 2
} else {
shouldNum = 0
Expand Down Expand Up @@ -986,6 +986,10 @@ func TestValuesFlat(t *testing.T) {
}

func TestRelatedSel(t *testing.T) {
if IsTidb {
// Skip it. TiDB does not support relation now.
return
}
qs := dORM.QueryTable("user")
num, err := qs.Filter("profile__age", 28).Count()
throwFail(t, err)
Expand Down
2 changes: 2 additions & 0 deletions orm/qb.go
Expand Up @@ -48,6 +48,8 @@ type QueryBuilder interface {
func NewQueryBuilder(driver string) (qb QueryBuilder, err error) {
if driver == "mysql" {
qb = new(MySQLQueryBuilder)
} else if driver == "mysql" {
qb = new(MySQLQueryBuilder)
} else if driver == "postgres" {
err = errors.New("postgres query builder is not supported yet")
} else if driver == "sqlite" {
Expand Down
151 changes: 151 additions & 0 deletions orm/qb_tidb.go
@@ -0,0 +1,151 @@
// Copyright 2015 TiDB Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package orm

import (
"fmt"
"strconv"
"strings"
)

type TiDBQueryBuilder struct {
Tokens []string
}

func (qb *TiDBQueryBuilder) Select(fields ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "SELECT", strings.Join(fields, CommaSpace))
return qb
}

func (qb *TiDBQueryBuilder) From(tables ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "FROM", strings.Join(tables, CommaSpace))
return qb
}

func (qb *TiDBQueryBuilder) InnerJoin(table string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "INNER JOIN", table)
return qb
}

func (qb *TiDBQueryBuilder) LeftJoin(table string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "LEFT JOIN", table)
return qb
}

func (qb *TiDBQueryBuilder) RightJoin(table string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "RIGHT JOIN", table)
return qb
}

func (qb *TiDBQueryBuilder) On(cond string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "ON", cond)
return qb
}

func (qb *TiDBQueryBuilder) Where(cond string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "WHERE", cond)
return qb
}

func (qb *TiDBQueryBuilder) And(cond string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "AND", cond)
return qb
}

func (qb *TiDBQueryBuilder) Or(cond string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "OR", cond)
return qb
}

func (qb *TiDBQueryBuilder) In(vals ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "IN", "(", strings.Join(vals, CommaSpace), ")")
return qb
}

func (qb *TiDBQueryBuilder) OrderBy(fields ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "ORDER BY", strings.Join(fields, CommaSpace))
return qb
}

func (qb *TiDBQueryBuilder) Asc() QueryBuilder {
qb.Tokens = append(qb.Tokens, "ASC")
return qb
}

func (qb *TiDBQueryBuilder) Desc() QueryBuilder {
qb.Tokens = append(qb.Tokens, "DESC")
return qb
}

func (qb *TiDBQueryBuilder) Limit(limit int) QueryBuilder {
qb.Tokens = append(qb.Tokens, "LIMIT", strconv.Itoa(limit))
return qb
}

func (qb *TiDBQueryBuilder) Offset(offset int) QueryBuilder {
qb.Tokens = append(qb.Tokens, "OFFSET", strconv.Itoa(offset))
return qb
}

func (qb *TiDBQueryBuilder) GroupBy(fields ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "GROUP BY", strings.Join(fields, CommaSpace))
return qb
}

func (qb *TiDBQueryBuilder) Having(cond string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "HAVING", cond)
return qb
}

func (qb *TiDBQueryBuilder) Update(tables ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "UPDATE", strings.Join(tables, CommaSpace))
return qb
}

func (qb *TiDBQueryBuilder) Set(kv ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "SET", strings.Join(kv, CommaSpace))
return qb
}

func (qb *TiDBQueryBuilder) Delete(tables ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "DELETE")
if len(tables) != 0 {
qb.Tokens = append(qb.Tokens, strings.Join(tables, CommaSpace))
}
return qb
}

func (qb *TiDBQueryBuilder) InsertInto(table string, fields ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "INSERT INTO", table)
if len(fields) != 0 {
fieldsStr := strings.Join(fields, CommaSpace)
qb.Tokens = append(qb.Tokens, "(", fieldsStr, ")")
}
return qb
}

func (qb *TiDBQueryBuilder) Values(vals ...string) QueryBuilder {
valsStr := strings.Join(vals, CommaSpace)
qb.Tokens = append(qb.Tokens, "VALUES", "(", valsStr, ")")
return qb
}

func (qb *TiDBQueryBuilder) Subquery(sub string, alias string) string {
return fmt.Sprintf("(%s) AS %s", sub, alias)
}

func (qb *TiDBQueryBuilder) String() string {
return strings.Join(qb.Tokens, " ")
}

0 comments on commit c73e039

Please sign in to comment.