diff --git a/.github/workflows/github-actions.yml b/.github/workflows/github-actions.yml
index 506fe2484..3062edddb 100644
--- a/.github/workflows/github-actions.yml
+++ b/.github/workflows/github-actions.yml
@@ -38,7 +38,7 @@ jobs:
# If you want to matrix build , you can append the following list.
matrix:
go_version:
- - 1.16
+ - 1.18
os:
- ubuntu-latest
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index b71067d33..091ac3268 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -20,7 +20,7 @@
name: Release
on:
release:
- types: [created]
+ types: [ created ]
jobs:
releases-matrix:
@@ -29,9 +29,9 @@ jobs:
strategy:
matrix:
# build and publish in parallel: linux/386, linux/amd64, windows/386, windows/amd64, darwin/amd64
- goos: [linux, windows]
- goarch: ["386", amd64, arm]
- exclude:
+ goos: [ linux, windows ]
+ goarch: [ "386", amd64, arm ]
+ exclude:
- goarch: "arm"
goos: windows
@@ -42,6 +42,6 @@ jobs:
github_token: ${{ secrets.GITHUB_TOKEN }}
goos: ${{ matrix.goos }}
goarch: ${{ matrix.goarch }}
- goversion: "https://golang.org/dl/go1.16.15.linux-amd64.tar.gz"
+ goversion: "https://go.dev/dl/go1.18.3.linux-amd64.tar.gz"
project_path: "./cmd/arana"
binary_name: "arana"
diff --git a/.github/workflows/reviewdog.yml b/.github/workflows/reviewdog.yml
index 39ac5fbff..6a7425b32 100644
--- a/.github/workflows/reviewdog.yml
+++ b/.github/workflows/reviewdog.yml
@@ -30,4 +30,5 @@ jobs:
- name: golangci-lint
uses: reviewdog/action-golangci-lint@v2
with:
- go_version: "1.16"
+ go_version: "1.18"
+ golangci_lint_version: "v1.46.2" # use latest version by default
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 2845a2862..503044c5a 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -19,6 +19,6 @@
# See https://pre-commit.com/hooks.html for more hooks
repos:
- repo: http://github.com/golangci/golangci-lint
- rev: v1.42.1
+ rev: v1.46.2
hooks:
- id: golangci-lint
diff --git a/Dockerfile b/Dockerfile
index 2b20788aa..f4dc11141 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,5 +1,22 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
# builder layer
-FROM golang:1.16-alpine AS builder
+FROM golang:1.18-alpine AS builder
RUN apk add --no-cache upx
diff --git a/Makefile b/Makefile
index 2cd0ab391..c67702973 100644
--- a/Makefile
+++ b/Makefile
@@ -38,6 +38,9 @@ docker-build:
integration-test:
@go clean -testcache
go test -tags integration -v ./test/...
+ go test -tags integration-db_tbl -v ./integration_test/scene/db_tbl/...
+ go test -tags integration-db -v ./integration_test/scene/db/...
+ go test -tags integration-tbl -v ./integration_test/scene/tbl/...
clean:
@rm -rf coverage.txt
@@ -50,4 +53,4 @@ prepareLic:
.PHONY: license
license: prepareLic
- $(GO_LICENSE_CHECKER) -v -a -r -i vendor $(LICENSE_DIR)/license.txt . go && [[ -z `git status -s` ]]
\ No newline at end of file
+ $(GO_LICENSE_CHECKER) -v -a -r -i vendor $(LICENSE_DIR)/license.txt . go && [[ -z `git status -s` ]]
diff --git a/README.md b/README.md
index a91e6d5cd..e8898d9ca 100644
--- a/README.md
+++ b/README.md
@@ -1,48 +1,87 @@
-# arana
-[![LICENSE](https://img.shields.io/badge/license-Apache--2.0-blue.svg)](https://github.com/arana-db/arana/blob/master/LICENSE)
+# Arana
+
+
+
![](https://raw.githubusercontent.com/arana-db/arana/master/docs/pics/arana-main.png)
+
+
+`Arana` is a Cloud Native Database Proxy. It can be deployed as a Database mesh sidecar. It provides transparent data access capabilities,
+when using `arana`, user doesn't need to care about the `sharding` details of database, they can use it just like a single `MySQL` database.
+
+## Overview
+
+[![License](https://img.shields.io/badge/license-Apache--2.0-blue.svg)](https://github.com/arana-db/arana/blob/master/LICENSE)
[![codecov](https://codecov.io/gh/arana-db/arana/branch/master/graph/badge.svg)](https://codecov.io/gh/arana-db/arana)
+[![Go Report Card](https://goreportcard.com/badge/github.com/arana-db/arana)](https://goreportcard.com/report/github.com/arana-db/arana)
+[![Release](https://img.shields.io/github/v/release/arana-db/arana)](https://img.shields.io/github/v/release/arana-db/arana)
+[![Docker Pulls](https://img.shields.io/docker/pulls/aranadb/arana)](https://img.shields.io/docker/pulls/aranadb/arana)
+| **Stargazers Over Time** | **Contributors Over Time** |
+|:-----------------------------------------------------------------------------------------------------------------:|:------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|
+| [![Stargazers over time](https://starchart.cc/arana-db/arana.svg)](https://starchart.cc/arana-db/arana) | [![Contributor over time](https://contributor-graph-api.apiseven.com/contributors-svg?chart=contributorOverTime&repo=arana-db/arana)](https://contributor-graph-api.apiseven.com/contributors-svg?chart=contributorOverTime&repo=arana-db/arana) |
-![](./docs/pics/arana-logo.png)
## Introduction | [中文](https://github.com/arana-db/arana/blob/master/README_CN.md)
-Arana is a db proxy. It can be deployed as a sidecar.
+First, `Arana` is a Cloud Native Database Proxy. It provides transparent data access capabilities, when using `arana`,
+user doesn't need to care about the `sharding` details of database, they can use it just like a single `MySQL` database.
+`Arana` also provide abilities of `Multi Tenant`, `Distributed transaction`, `Shadow database`, `SQL Audit`, `Data encrypt / decrypt`
+and so on. Through simple config, user can use these abilities provided by `arana` directly.
+
+Second, `Arana` can also be deployed as a Database mesh sidecar. As a Database mesh sidecar, arana switches data access from
+client mode to proxy mode, which greatly optimizes the startup speed of applications. It provides the ability to manage database
+traffic, it takes up very little container resources, doesn't affect the performance of application services in the container, but
+provides all the capabilities of proxy.
## Architecture
+
+
## Features
-| feature | complete |
-| -- | -- |
-| single db proxy | √ |
-| read write splitting | × |
-| tracing | × |
-| metrics | × |
-| sql audit | × |
-| sharding | × |
-| multi tenant | × |
+| **Feature** | **Complete** |
+|:-----------------------:|:------------:|
+| Single DB Proxy | √ |
+| Read Write Splitting | √ |
+| Sharding | √ |
+| Multi Tenant | √ |
+| Distributed Primary Key | WIP |
+| Distributed Transaction | WIP |
+| Shadow Table | WIP |
+| Database Mesh | WIP |
+| Tracing / Metrics | WIP |
+| SQL Audit | Roadmap |
+| Data encrypt / decrypt | Roadmap |
## Getting started
+Please reference this link [Getting Started](https://github.com/arana-db/arana/discussions/172)
+
```
arana start -c ${configFilePath}
```
### Prerequisites
-+ MySQL server 5.7+
++ Go 1.18+
++ MySQL Server 5.7+
## Design and implementation
## Roadmap
## Built With
-- [tidb](https://github.com/pingcap/tidb) - The sql parser used
+
+- [TiDB](https://github.com/pingcap/tidb) - The SQL parser used
## Contact
+Arana Chinese Community Meeting Time: **Every Saturday At 9:00PM GMT+8**
+
+
+
## Contributing
+Thanks for your help improving the project! We are so happy to have you! We have a contributing guide to help you get involved in the Arana project.
+
## License
Arana software is licenced under the Apache License Version 2.0. See the [LICENSE](https://github.com/arana-db/arana/blob/master/LICENSE) file for details.
diff --git a/README_CN.md b/README_CN.md
index ca05c93ba..a957e4604 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -2,7 +2,7 @@
[![LICENSE](https://img.shields.io/badge/license-Apache--2.0-blue.svg)](https://github.com/arana-db/arana/blob/master/LICENSE)
[![codecov](https://codecov.io/gh/arana-db/arana/branch/master/graph/badge.svg)](https://codecov.io/gh/arana-db/arana)
-![](./docs/pics/arana-logo.png)
+![](./docs/pics/arana-main.png)
## 简介 | [English](https://github.com/arana-db/arana/blob/master/README.md)
diff --git a/README_ZH.md b/README_ZH.md
index 96d9c4684..ca91e1874 100644
--- a/README_ZH.md
+++ b/README_ZH.md
@@ -2,7 +2,7 @@
[![LICENSE](https://img.shields.io/badge/license-Apache--2.0-blue.svg)](https://github.com/arana-db/arana/blob/master/LICENSE)
[![codecov](https://codecov.io/gh/arana-db/arana/branch/master/graph/badge.svg)](https://codecov.io/gh/arana-db/arana)
-![](./docs/pics/arana-logo.png)
+![](./docs/pics/arana-main.png)
## 简介
diff --git a/cmd/main.go b/cmd/main.go
index e0da5f65c..ad35125d5 100644
--- a/cmd/main.go
+++ b/cmd/main.go
@@ -27,9 +27,7 @@ import (
_ "github.com/arana-db/arana/cmd/tools"
)
-var (
- Version = "0.1.0"
-)
+var Version = "0.1.0"
func main() {
rootCommand := &cobra.Command{
diff --git a/cmd/start/start.go b/cmd/start/start.go
index f6ee9de5f..28d232147 100644
--- a/cmd/start/start.go
+++ b/cmd/start/start.go
@@ -49,7 +49,7 @@ const slogan = `
/ _ | / _ \/ _ | / |/ / _ |
/ __ |/ , _/ __ |/ / __ |
/_/ |_/_/|_/_/ |_/_/|_/_/ |_|
-High performance, powerful DB Mesh.
+Arana, A High performance & Powerful DB Mesh sidecar.
_____________________________________________
`
diff --git a/conf/config.yaml b/conf/config.yaml
index 21721426d..5caeab8b1 100644
--- a/conf/config.yaml
+++ b/conf/config.yaml
@@ -40,40 +40,76 @@ data:
type: mysql
sql_max_limit: -1
tenant: arana
- conn_props:
- capacity: 10
- max_capacity: 20
- idle_timeout: 60
groups:
- name: employees_0000
nodes:
- - name: arana-node-1
+ - name: node0
host: arana-mysql
port: 3306
username: root
password: "123456"
- database: employees
+ database: employees_0000
weight: r10w10
- labels:
- zone: shanghai
- conn_props:
- readTimeout: "1s"
- writeTimeout: "1s"
- parseTime: true
- loc: Local
- charset: utf8mb4,utf8
-
+ parameters:
+ maxAllowedPacket: 256M
+ - name: node0_r_0
+ host: arana-mysql
+ port: 3306
+ username: root
+ password: "123456"
+ database: employees_0000_r
+ weight: r0w0
+ parameters:
+ maxAllowedPacket: 256M
+ - name: employees_0001
+ nodes:
+ - name: node1
+ host: arana-mysql
+ port: 3306
+ username: root
+ password: "123456"
+ database: employees_0001
+ weight: r10w10
+ parameters:
+ maxAllowedPacket: 256M
+ - name: employees_0002
+ nodes:
+ - name: node2
+ host: arana-mysql
+ port: 3306
+ username: root
+ password: "123456"
+ database: employees_0002
+ weight: r10w10
+ parameters:
+ maxAllowedPacket: 256M
+ - name: employees_0003
+ nodes:
+ - name: node3
+ host: arana-mysql
+ port: 3306
+ username: root
+ password: "123456"
+ database: employees_0003
+ weight: r10w10
+ parameters:
+ maxAllowedPacket: 256M
sharding_rule:
tables:
- name: employees.student
allow_full_scan: true
db_rules:
+ - column: uid
+ type: scriptExpr
+ expr: parseInt($value % 32 / 8)
tbl_rules:
- column: uid
+ type: scriptExpr
expr: $value % 32
+ step: 32
topology:
- db_pattern: employees_0000
- tbl_pattern: student_${0000...0031}
+ db_pattern: employees_${0000..0003}
+ tbl_pattern: student_${0000..0031}
attributes:
sqlMaxLimit: -1
diff --git a/docker-compose.yaml b/docker-compose.yaml
index c66cf692d..7a1548fc1 100644
--- a/docker-compose.yaml
+++ b/docker-compose.yaml
@@ -18,7 +18,7 @@
version: "3"
services:
mysql:
- image: mysql:8.0
+ image: mysql:5.7
container_name: arana-mysql
networks:
- local
@@ -41,7 +41,7 @@ services:
arana:
build: .
container_name: arana
- image: aranadb/arana:latest
+ image: aranadb/arana:master
networks:
- local
ports:
diff --git a/docs/pics/arana-architecture.png b/docs/pics/arana-architecture.png
new file mode 100644
index 000000000..775fa8dcb
Binary files /dev/null and b/docs/pics/arana-architecture.png differ
diff --git a/docs/pics/arana-blue.png b/docs/pics/arana-blue.png
new file mode 100644
index 000000000..90d31b8c4
Binary files /dev/null and b/docs/pics/arana-blue.png differ
diff --git a/docs/pics/arana-db-blue.png b/docs/pics/arana-db-blue.png
new file mode 100644
index 000000000..93f3d4014
Binary files /dev/null and b/docs/pics/arana-db-blue.png differ
diff --git a/docs/pics/arana-db-v0.2.sketch b/docs/pics/arana-db-v0.2.sketch
new file mode 100644
index 000000000..a97994827
Binary files /dev/null and b/docs/pics/arana-db-v0.2.sketch differ
diff --git a/docs/pics/arana-main.png b/docs/pics/arana-main.png
new file mode 100644
index 000000000..559c90dd4
Binary files /dev/null and b/docs/pics/arana-main.png differ
diff --git a/docs/pics/dingtalk-group.jpeg b/docs/pics/dingtalk-group.jpeg
new file mode 100644
index 000000000..638128143
Binary files /dev/null and b/docs/pics/dingtalk-group.jpeg differ
diff --git a/go.mod b/go.mod
index 7d19bd53a..23d592959 100644
--- a/go.mod
+++ b/go.mod
@@ -1,16 +1,15 @@
module github.com/arana-db/arana
-go 1.16
+go 1.18
require (
- github.com/arana-db/parser v0.2.1
+ github.com/arana-db/parser v0.2.3
github.com/bwmarrin/snowflake v0.3.0
github.com/cespare/xxhash/v2 v2.1.2
github.com/dop251/goja v0.0.0-20220422102209-3faab1d8f20e
github.com/dubbogo/gost v1.12.3
github.com/go-playground/validator/v10 v10.10.1
github.com/go-sql-driver/mysql v1.6.0
- github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/golang/mock v1.5.0
github.com/hashicorp/golang-lru v0.5.4
github.com/lestrrat-go/strftime v1.0.5
@@ -19,21 +18,105 @@ require (
github.com/olekukonko/tablewriter v0.0.5
github.com/pkg/errors v0.9.1
github.com/prometheus/client_golang v1.11.0
- github.com/prometheus/common v0.28.0 // indirect
github.com/spf13/cobra v1.2.1
- github.com/stretchr/testify v1.7.0
+ github.com/stretchr/testify v1.7.1
github.com/testcontainers/testcontainers-go v0.12.0
github.com/tidwall/gjson v1.14.0
go.etcd.io/etcd/api/v3 v3.5.1
go.etcd.io/etcd/client/v3 v3.5.0
go.etcd.io/etcd/server/v3 v3.5.0-alpha.0
+ go.opentelemetry.io/otel v1.7.0
+ go.opentelemetry.io/otel/trace v1.7.0
go.uber.org/atomic v1.9.0
- go.uber.org/multierr v1.7.0
go.uber.org/zap v1.19.1
- golang.org/x/crypto v0.0.0-20220427172511-eb4f295cb31f // indirect
+ golang.org/x/exp v0.0.0-20220613132600-b0d781184e0d
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c
- golang.org/x/sys v0.0.0-20220429233432-b5fbb4746d32 // indirect
+ gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b
+)
+
+require (
+ github.com/Azure/go-ansiterm v0.0.0-20170929234023-d6e3b3328b78 // indirect
+ github.com/Microsoft/go-winio v0.4.17-0.20210211115548-6eac466e5fa3 // indirect
+ github.com/Microsoft/hcsshim v0.8.16 // indirect
+ github.com/aliyun/alibaba-cloud-sdk-go v1.61.18 // indirect
+ github.com/beorn7/perks v1.0.1 // indirect
+ github.com/buger/jsonparser v1.1.1 // indirect
+ github.com/cenkalti/backoff v2.2.1+incompatible // indirect
+ github.com/containerd/cgroups v0.0.0-20210114181951-8a68de567b68 // indirect
+ github.com/containerd/containerd v1.5.0-beta.4 // indirect
+ github.com/coreos/go-semver v0.3.0 // indirect
+ github.com/coreos/go-systemd/v22 v22.3.2 // indirect
+ github.com/davecgh/go-spew v1.1.1 // indirect
+ github.com/dlclark/regexp2 v1.4.1-0.20201116162257-a2a8dda75c91 // indirect
+ github.com/docker/distribution v2.7.1+incompatible // indirect
+ github.com/docker/docker v20.10.11+incompatible // indirect
+ github.com/docker/go-connections v0.4.0 // indirect
+ github.com/docker/go-units v0.4.0 // indirect
+ github.com/dustin/go-humanize v1.0.0 // indirect
+ github.com/form3tech-oss/jwt-go v3.2.2+incompatible // indirect
+ github.com/go-errors/errors v1.0.1 // indirect
+ github.com/go-logr/logr v1.2.3 // indirect
+ github.com/go-logr/stdr v1.2.2 // indirect
+ github.com/go-playground/locales v0.14.0 // indirect
+ github.com/go-playground/universal-translator v0.18.0 // indirect
+ github.com/go-sourcemap/sourcemap v2.1.3+incompatible // indirect
+ github.com/gogo/protobuf v1.3.2 // indirect
+ github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
+ github.com/golang/protobuf v1.5.2 // indirect
+ github.com/google/btree v1.0.0 // indirect
+ github.com/google/uuid v1.3.0 // indirect
+ github.com/gopherjs/gopherjs v0.0.0-20190910122728-9d188e94fb99 // indirect
+ github.com/gorilla/websocket v1.4.2 // indirect
+ github.com/grpc-ecosystem/go-grpc-middleware v1.2.2 // indirect
+ github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 // indirect
+ github.com/grpc-ecosystem/grpc-gateway v1.16.0 // indirect
+ github.com/inconshreveable/mousetrap v1.0.0 // indirect
+ github.com/jmespath/go-jmespath v0.4.0 // indirect
+ github.com/jonboulle/clockwork v0.2.2 // indirect
+ github.com/json-iterator/go v1.1.11 // indirect
+ github.com/leodido/go-urn v1.2.1 // indirect
+ github.com/magiconair/properties v1.8.5 // indirect
+ github.com/mattn/go-runewidth v0.0.9 // indirect
+ github.com/matttproud/golang_protobuf_extensions v1.0.2-0.20181231171920-c182affec369 // indirect
+ github.com/moby/sys/mount v0.2.0 // indirect
+ github.com/moby/sys/mountinfo v0.5.0 // indirect
+ github.com/moby/term v0.0.0-20201216013528-df9cb8a40635 // indirect
+ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
+ github.com/modern-go/reflect2 v1.0.1 // indirect
+ github.com/morikuni/aec v0.0.0-20170113033406-39771216ff4c // indirect
+ github.com/opencontainers/go-digest v1.0.0 // indirect
+ github.com/opencontainers/image-spec v1.0.1 // indirect
+ github.com/opencontainers/runc v1.0.2 // indirect
+ github.com/pingcap/errors v0.11.5-0.20210425183316-da1aaba5fb63 // indirect
+ github.com/pingcap/log v0.0.0-20210625125904-98ed8e2eb1c7 // indirect
+ github.com/pmezard/go-difflib v1.0.0 // indirect
+ github.com/prometheus/client_model v0.2.0 // indirect
+ github.com/prometheus/common v0.28.0 // indirect
+ github.com/prometheus/procfs v0.6.0 // indirect
+ github.com/sirupsen/logrus v1.8.1 // indirect
+ github.com/soheilhy/cmux v0.1.5-0.20210205191134-5ec6847320e5 // indirect
+ github.com/spf13/pflag v1.0.5 // indirect
+ github.com/tidwall/match v1.1.1 // indirect
+ github.com/tidwall/pretty v1.2.0 // indirect
+ github.com/tmc/grpc-websocket-proxy v0.0.0-20201229170055-e5319fda7802 // indirect
+ github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2 // indirect
+ go.etcd.io/bbolt v1.3.5 // indirect
+ go.etcd.io/etcd/client/pkg/v3 v3.5.0 // indirect
+ go.etcd.io/etcd/client/v2 v2.305.0 // indirect
+ go.etcd.io/etcd/pkg/v3 v3.5.0-alpha.0 // indirect
+ go.etcd.io/etcd/raft/v3 v3.5.0-alpha.0 // indirect
+ go.opencensus.io v0.23.0 // indirect
+ go.uber.org/multierr v1.6.0 // indirect
+ golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3 // indirect
+ golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2 // indirect
+ golang.org/x/sys v0.0.0-20220128215802-99c3d69c2c27 // indirect
+ golang.org/x/text v0.3.7 // indirect
+ golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 // indirect
google.golang.org/genproto v0.0.0-20211104193956-4c6863e31247 // indirect
google.golang.org/grpc v1.42.0 // indirect
- gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b
+ google.golang.org/protobuf v1.27.1 // indirect
+ gopkg.in/ini.v1 v1.62.0 // indirect
+ gopkg.in/natefinch/lumberjack.v2 v2.0.0 // indirect
+ gopkg.in/yaml.v2 v2.4.0 // indirect
+ sigs.k8s.io/yaml v1.2.0 // indirect
)
diff --git a/go.sum b/go.sum
index a2d502ace..c3cc775e1 100644
--- a/go.sum
+++ b/go.sum
@@ -94,8 +94,8 @@ github.com/aliyun/alibaba-cloud-sdk-go v1.61.18/go.mod h1:v8ESoHo4SyHmuB4b1tJqDH
github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY=
github.com/apache/thrift v0.12.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ=
github.com/apache/thrift v0.13.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ=
-github.com/arana-db/parser v0.2.1 h1:a885V+OIABmqYHYbLP2QWZbn+/TE0mZJd8dafWY7F6Y=
-github.com/arana-db/parser v0.2.1/go.mod h1:y4hxIPieC5T26aoNd44XiWXNunC03kUQW0CI3NKaYTk=
+github.com/arana-db/parser v0.2.3 h1:zLZcx0/oidlHnw/GZYE78NuvwQkHUv2Xtrm2IwyZasA=
+github.com/arana-db/parser v0.2.3/go.mod h1:/XA29bplweWSEAjgoM557ZCzhBilSawUlHcZFjOeDAc=
github.com/armon/circbuf v0.0.0-20150827004946-bbbad097214e/go.mod h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o=
github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmVTwzkszR9V5SSuryQ31EELlFMUz1kKyl939pY=
github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8=
@@ -136,7 +136,6 @@ github.com/cenkalti/backoff v2.2.1+incompatible/go.mod h1:90ReRw6GdpyfrHakVjL/QH
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
github.com/certifi/gocertifi v0.0.0-20191021191039-0944d244cd40 h1:xvUo53O5MRZhVMJAxWCJcS5HHrqAiAG9SJ1LpMu6aAI=
github.com/certifi/gocertifi v0.0.0-20191021191039-0944d244cd40/go.mod h1:sGbDF6GwGcLpkNXPUTkMRoywsNa/ol15pxFe6ERfguA=
-github.com/cespare/xxhash v1.1.0 h1:a6HrQnmkObjyL+Gs60czilIUGqrzKutQD6XZog3p+ko=
github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc=
github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE=
@@ -244,7 +243,6 @@ github.com/coreos/go-semver v0.3.0 h1:wkHLiw0WNATZnSG7epLsujiMCgPAc9xhjJ4tgnAxmf
github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk=
github.com/coreos/go-systemd v0.0.0-20161114122254-48702e0da86b/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
github.com/coreos/go-systemd v0.0.0-20180511133405-39ca1b05acc7/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
-github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e h1:Wf6HqHfScWJN9/ZjdUKyjop4mf3Qdd+1TvvltAvM3m8=
github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
github.com/coreos/go-systemd/v22 v22.0.0/go.mod h1:xO0FLkIi5MaZafQlIrOotqXZ90ih+1atmu1JpKERPPk=
github.com/coreos/go-systemd/v22 v22.1.0/go.mod h1:xO0FLkIi5MaZafQlIrOotqXZ90ih+1atmu1JpKERPPk=
@@ -352,6 +350,11 @@ github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V
github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A=
github.com/go-logr/logr v0.1.0/go.mod h1:ixOQHD9gLJUVQQ2ZOR7zLEifBX6tGkNJF4QyIY7sIas=
github.com/go-logr/logr v0.2.0/go.mod h1:z6/tIYblkpsD+a4lm/fGIIU9mZ+XfAiaFtq7xTgseGU=
+github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
+github.com/go-logr/logr v1.2.3 h1:2DntVwHkVopvECVRSlL5PSo9eG+cAkDCuckLubN+rq0=
+github.com/go-logr/logr v1.2.3/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
+github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
+github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/go-openapi/jsonpointer v0.19.2/go.mod h1:3akKfEdA7DF1sugOqz1dVQHBcuDBPKZGEoHC/NkiQRg=
github.com/go-openapi/jsonpointer v0.19.3/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg=
@@ -444,8 +447,8 @@ github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
-github.com/google/go-cmp v0.5.7 h1:81/ik6ipDQS2aGcBfIN5dHDB36BwrStyeAQquSYCV4o=
github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE=
+github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/gofuzz v1.1.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs=
@@ -471,8 +474,9 @@ github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+
github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk=
github.com/googleapis/gnostic v0.4.1/go.mod h1:LRhVm6pbyptWbWbuZ38d1eyptfvIytN3ir6b65WBswg=
-github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 h1:EGx4pi6eqNxGaHF6qqu48+N2wcFQ5qg5FXgOdqsJ5d8=
github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY=
+github.com/gopherjs/gopherjs v0.0.0-20190910122728-9d188e94fb99 h1:twflg0XRTjwKpxb/jFExr4HGq6on2dEOmnL6FV+fgPw=
+github.com/gopherjs/gopherjs v0.0.0-20190910122728-9d188e94fb99/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY=
github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg=
github.com/gorilla/handlers v0.0.0-20150720190736-60c7bfde3e33/go.mod h1:Qkdc/uu4tH4g6mTK6auzZ766c4CA0Ng8+o/OAirnOIQ=
github.com/gorilla/mux v1.6.2/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs=
@@ -536,8 +540,11 @@ github.com/j-keck/arping v0.0.0-20160618110441-2cf9dc699c56/go.mod h1:ymszkNOg6t
github.com/jehiah/go-strftime v0.0.0-20171201141054-1d33003b3869/go.mod h1:cJ6Cj7dQo+O6GJNiMx+Pa94qKj+TG8ONdKHgMNIyyag=
github.com/jmespath/go-jmespath v0.0.0-20160202185014-0b12d6b521d8/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k=
github.com/jmespath/go-jmespath v0.0.0-20160803190731-bd40a432e4c7/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k=
-github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af h1:pmfjZENx5imkbgOkpRUYLnmbU7UEFbjtDA2hxJ1ichM=
github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k=
+github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg=
+github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo=
+github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8=
+github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U=
github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo=
github.com/jonboulle/clockwork v0.2.2 h1:UOGuzwb1PwsrDAObMuhUnj0p5ULPj8V/xJ7Kx9qUBdQ=
github.com/jonboulle/clockwork v0.2.2/go.mod h1:Pkfl5aHPm1nk2H9h0bjmnJD/BcgbGXUBGnn1kMkgxc8=
@@ -647,7 +654,6 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8m
github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw=
-github.com/nacos-group/nacos-sdk-go v1.0.8 h1:8pEm05Cdav9sQgJSv5kyvlgfz0SzFUUGI3pWX6SiSnM=
github.com/nacos-group/nacos-sdk-go v1.0.8/go.mod h1:hlAPn3UdzlxIlSILAyOXKxjFSvDJ9oLzTJ9hLAK1KzA=
github.com/nacos-group/nacos-sdk-go/v2 v2.0.1 h1:jEZjqdCDSt6ZFtl628UUwON21GxwJ+lEN/PDamQOzgU=
github.com/nacos-group/nacos-sdk-go/v2 v2.0.1/go.mod h1:SlhyCAv961LcZ198XpKfPEQqlJWt2HkL1fDLas0uy/w=
@@ -860,8 +866,9 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
-github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
+github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY=
+github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw=
github.com/syndtr/gocapability v0.0.0-20170704070218-db04d3cc01c8/go.mod h1:hkRG7XYTFWNJGYcbNJQlaLq0fg1yr4J4t/NcTQtrfww=
github.com/syndtr/gocapability v0.0.0-20180916011248-d98352740cb2/go.mod h1:hkRG7XYTFWNJGYcbNJQlaLq0fg1yr4J4t/NcTQtrfww=
@@ -915,7 +922,6 @@ go.etcd.io/bbolt v1.3.3/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU=
go.etcd.io/bbolt v1.3.5 h1:XAzx9gjCb0Rxj7EoqcClPD1d5ZBxZJk0jbuoPHenBt0=
go.etcd.io/bbolt v1.3.5/go.mod h1:G5EMThwa9y8QZGBClrRx5EY+Yw9kAhnjy3bSjsnlVTQ=
go.etcd.io/etcd v0.0.0-20191023171146-3cf2f69b5738/go.mod h1:dnLIgRNXwCJa5e+c6mIZCrds/GIG4ncV9HhK5PX7jPg=
-go.etcd.io/etcd v0.5.0-alpha.5.0.20200910180754-dd1b699fc489 h1:1JFLBqwIgdyHN1ZtgjTBwO+blA6gVOmZurpiMEsETKo=
go.etcd.io/etcd v0.5.0-alpha.5.0.20200910180754-dd1b699fc489/go.mod h1:yVHk9ub3CSBatqGNg7GRmsnfLWtoW60w4eDYfh7vHDg=
go.etcd.io/etcd/api/v3 v3.5.0-alpha.0/go.mod h1:mPcW6aZJukV6Aa81LSKpBjQXTWlXB5r74ymPoSWa3Sw=
go.etcd.io/etcd/api/v3 v3.5.0/go.mod h1:cbVKeC6lCfl7j/8jBhAK6aIYO9XOjdptoxU/nLQcPvs=
@@ -946,6 +952,10 @@ go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk=
go.opencensus.io v0.23.0 h1:gqCw0LfLxScz8irSi8exQc7fyQ0fKQU/qnC/X8+V/1M=
go.opencensus.io v0.23.0/go.mod h1:XItmlyltB5F7CS4xOC1DcqMoFqwtC6OG2xF7mCv7P7E=
+go.opentelemetry.io/otel v1.7.0 h1:Z2lA3Tdch0iDcrhJXDIlC94XE+bxok1F9B+4Lz/lGsM=
+go.opentelemetry.io/otel v1.7.0/go.mod h1:5BdUoMIz5WEs0vt0CUEMtSSaTSHBBVwrhnz7+nrD5xk=
+go.opentelemetry.io/otel/trace v1.7.0 h1:O37Iogk1lEkMRXewVtZ1BBTVn5JEp8GrJvP92bJqC6o=
+go.opentelemetry.io/otel/trace v1.7.0/go.mod h1:fzLSB9nqR2eXzxPXb2JW9IKE+ScyXA48yyE4TNvoHqU=
go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI=
go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE=
go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE=
@@ -960,9 +970,8 @@ go.uber.org/goleak v1.1.11-0.20210813005559-691160354723/go.mod h1:cwTWslyiVhfpK
go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0=
go.uber.org/multierr v1.3.0/go.mod h1:VgVr7evmIr6uPjLBxg28wmKNXyqE9akIJ5XnfpiKl+4=
go.uber.org/multierr v1.5.0/go.mod h1:FeouvMocqHpRaaGuG9EjoKcStLC43Zu/fmqdUMPcKYU=
+go.uber.org/multierr v1.6.0 h1:y6IPFStTAIT5Ytl7/XYmHvzXQ7S3g/IeZW9hyZ5thw4=
go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU=
-go.uber.org/multierr v1.7.0 h1:zaiO/rmgFjbmCXdSYJWQcdvOCsthmdaHfr3Gm2Kx4Ec=
-go.uber.org/multierr v1.7.0/go.mod h1:7EAYxJLBy9rStEaz58O2t4Uvip6FSURkq8/ppBp95ak=
go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9Ejo0C68/HhF8uaILCdgjnY+goOA=
go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q=
go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q=
@@ -987,9 +996,8 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20200728195943-123391ffb6de/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20201002170205-7f63de1d35b0/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
+golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3 h1:0es+/5331RGQPcXlMfP+WrnIIS6dNnNRe0WB02W0F4M=
golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
-golang.org/x/crypto v0.0.0-20220427172511-eb4f295cb31f h1:OeJjE6G4dgCY4PIXvIRQbE8+RX+uXZyGhUy/ksMGJoc=
-golang.org/x/crypto v0.0.0-20220427172511-eb4f295cb31f/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/exp v0.0.0-20181106170214-d68db9428509/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
@@ -1002,6 +1010,8 @@ golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u0
golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM=
golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU=
golang.org/x/exp v0.0.0-20200331195152-e8c3332aa8e5/go.mod h1:4M0jN8W1tt0AVLNr8HDosyJCDCDuyL9N9+3m7wDWgKw=
+golang.org/x/exp v0.0.0-20220613132600-b0d781184e0d h1:vtUKgx8dahOomfFzLREU8nSv25YHnTgLBn4rDnWZdU0=
+golang.org/x/exp v0.0.0-20220613132600-b0d781184e0d/go.mod h1:Kr81I6Kryrl9sr8s2FK3vxD90NdsKWRuOIl2O4CvYbA=
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
@@ -1199,9 +1209,8 @@ golang.org/x/sys v0.0.0-20210816074244-15123e1e1f71/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20211025201205-69cdffdb9359/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211109184856-51b60fd695b3/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220111092808-5a964db01320/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.0.0-20220128215802-99c3d69c2c27 h1:XDXtA5hveEEV8JB2l7nhMTp3t3cHp9ZpwcdjqyEWLlo=
golang.org/x/sys v0.0.0-20220128215802-99c3d69c2c27/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.0.0-20220429233432-b5fbb4746d32 h1:Js08h5hqB5xyWR789+QqueR6sDE8mk+YvpETZ+F6X9Y=
-golang.org/x/sys v0.0.0-20220429233432-b5fbb4746d32/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
@@ -1285,7 +1294,6 @@ golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
-golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/api v0.0.0-20160322025152-9bf6e6e569ff/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0=
google.golang.org/api v0.3.1/go.mod h1:6wY9I6uQWHQ8EM57III9mq/AjF+i8G65rmVagqKMtkk=
diff --git a/integration_test/config/db/config.yaml b/integration_test/config/db/config.yaml
new file mode 100644
index 000000000..d9eaf3d96
--- /dev/null
+++ b/integration_test/config/db/config.yaml
@@ -0,0 +1,98 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+kind: ConfigMap
+apiVersion: "1.0"
+metadata:
+ name: arana-config
+data:
+ listeners:
+ - protocol_type: mysql
+ server_version: 5.7.0
+ socket_address:
+ address: 0.0.0.0
+ port: 13306
+
+ tenants:
+ - name: arana
+ users:
+ - username: arana
+ password: "123456"
+ - username: dksl
+ password: "123456"
+
+ clusters:
+ - name: employees
+ type: mysql
+ sql_max_limit: -1
+ tenant: arana
+ groups:
+ - name: employees_0000
+ nodes:
+ - name: node0
+ host: arana-mysql
+ port: 3306
+ username: root
+ password: "123456"
+ database: employees_0000
+ weight: r10w10
+ - name: employees_0001
+ nodes:
+ - name: node1
+ host: arana-mysql
+ port: 3306
+ username: root
+ password: "123456"
+ database: employees_0001
+ weight: r10w10
+ - name: employees_0002
+ nodes:
+ - name: node2
+ host: arana-mysql
+ port: 3306
+ username: root
+ password: "123456"
+ database: employees_0002
+ weight: r10w10
+ - name: employees_0003
+ nodes:
+ - name: node3
+ host: arana-mysql
+ port: 3306
+ username: root
+ password: "123456"
+ database: employees_0003
+ weight: r10w10
+
+ sharding_rule:
+ tables:
+ - name: employees.student
+ allow_full_scan: true
+ db_rules:
+ - column: uid
+ type: scriptExpr
+ expr: parseInt($value % 32 / 8)
+ step: 32
+ tbl_rules:
+ - column: uid
+ type: scriptExpr
+ expr: parseInt(0)
+ topology:
+ db_pattern: employees_${0000...0003}
+ tbl_pattern: student_0000
+ attributes:
+ sqlMaxLimit: -1
diff --git a/integration_test/config/db/data.yaml b/integration_test/config/db/data.yaml
new file mode 100644
index 000000000..112cb3c18
--- /dev/null
+++ b/integration_test/config/db/data.yaml
@@ -0,0 +1,30 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+kind: DataSet
+metadata:
+ tables:
+ - name: "order"
+ columns:
+ - name: "name"
+ type: "string"
+ - name: "value"
+ type: "string"
+data:
+ - name: "order"
+ value:
+ - ["test", "test1"]
diff --git a/integration_test/config/db/expected.yaml b/integration_test/config/db/expected.yaml
new file mode 100644
index 000000000..e16b99a92
--- /dev/null
+++ b/integration_test/config/db/expected.yaml
@@ -0,0 +1,30 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+kind: excepted
+metadata:
+ tables:
+ - name: "order"
+ columns:
+ - name: "name"
+ type: "string"
+ - name: "value"
+ type: "string"
+data:
+ - name: "order"
+ value:
+ - ["test", "test1"]
diff --git a/integration_test/config/db_tbl/config.yaml b/integration_test/config/db_tbl/config.yaml
new file mode 100644
index 000000000..4b421d229
--- /dev/null
+++ b/integration_test/config/db_tbl/config.yaml
@@ -0,0 +1,104 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+kind: ConfigMap
+apiVersion: "1.0"
+metadata:
+ name: arana-config
+data:
+ listeners:
+ - protocol_type: mysql
+ server_version: 5.7.0
+ socket_address:
+ address: 0.0.0.0
+ port: 13306
+
+ tenants:
+ - name: arana
+ users:
+ - username: arana
+ password: "123456"
+ - username: dksl
+ password: "123456"
+
+ clusters:
+ - name: employees
+ type: mysql
+ sql_max_limit: -1
+ tenant: arana
+ groups:
+ - name: employees_0000
+ nodes:
+ - name: node0
+ host: arana-mysql
+ port: 3306
+ username: root
+ password: "123456"
+ database: employees_0000
+ weight: r10w10
+ - name: node0_r_0
+ host: arana-mysql
+ port: 3306
+ username: root
+ password: "123456"
+ database: employees_0000_r
+ weight: r0w0
+ - name: employees_0001
+ nodes:
+ - name: node1
+ host: arana-mysql
+ port: 3306
+ username: root
+ password: "123456"
+ database: employees_0001
+ weight: r10w10
+ - name: employees_0002
+ nodes:
+ - name: node2
+ host: arana-mysql
+ port: 3306
+ username: root
+ password: "123456"
+ database: employees_0002
+ weight: r10w10
+ - name: employees_0003
+ nodes:
+ - name: node3
+ host: arana-mysql
+ port: 3306
+ username: root
+ password: "123456"
+ database: employees_0003
+ weight: r10w10
+
+ sharding_rule:
+ tables:
+ - name: employees.student
+ allow_full_scan: true
+ db_rules:
+ - column: uid
+ type: scriptExpr
+ expr: parseInt($value % 32 / 8)
+ tbl_rules:
+ - column: uid
+ type: scriptExpr
+ expr: $value % 32
+ topology:
+ db_pattern: employees_${0000..0003}
+ tbl_pattern: student_${0000..0031}
+ attributes:
+ sqlMaxLimit: -1
diff --git a/integration_test/config/db_tbl/data.yaml b/integration_test/config/db_tbl/data.yaml
new file mode 100644
index 000000000..112cb3c18
--- /dev/null
+++ b/integration_test/config/db_tbl/data.yaml
@@ -0,0 +1,30 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+kind: DataSet
+metadata:
+ tables:
+ - name: "order"
+ columns:
+ - name: "name"
+ type: "string"
+ - name: "value"
+ type: "string"
+data:
+ - name: "order"
+ value:
+ - ["test", "test1"]
diff --git a/integration_test/config/db_tbl/expected.yaml b/integration_test/config/db_tbl/expected.yaml
new file mode 100644
index 000000000..542d37822
--- /dev/null
+++ b/integration_test/config/db_tbl/expected.yaml
@@ -0,0 +1,30 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+kind: excepted
+metadata:
+ tables:
+ - name: "sequence"
+ columns:
+ - name: "name"
+ type: "string"
+ - name: "value"
+ type: "string"
+data:
+ - name: "sequence"
+ value:
+ - ["1", "2"]
diff --git a/integration_test/config/tbl/config.yaml b/integration_test/config/tbl/config.yaml
new file mode 100644
index 000000000..3eb3564bb
--- /dev/null
+++ b/integration_test/config/tbl/config.yaml
@@ -0,0 +1,70 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+kind: ConfigMap
+apiVersion: "1.0"
+metadata:
+ name: arana-config
+data:
+ listeners:
+ - protocol_type: mysql
+ server_version: 5.7.0
+ socket_address:
+ address: 0.0.0.0
+ port: 13306
+
+ tenants:
+ - name: arana
+ users:
+ - username: arana
+ password: "123456"
+ - username: dksl
+ password: "123456"
+
+ clusters:
+ - name: employees
+ type: mysql
+ sql_max_limit: -1
+ tenant: arana
+ groups:
+ - name: employees_0000
+ nodes:
+ - name: node0
+ host: arana-mysql
+ port: 3306
+ username: root
+ password: "123456"
+ database: employees_0000
+ weight: r10w10
+
+ sharding_rule:
+ tables:
+ - name: employees.student
+ allow_full_scan: true
+ db_rules:
+ - column: uid
+ type: scriptExpr
+ expr: parseInt(0)
+ tbl_rules:
+ - column: uid
+ type: scriptExpr
+ expr: $value % 32
+ topology:
+ db_pattern: employees_0000
+ tbl_pattern: student_${0000..0031}
+ attributes:
+ sqlMaxLimit: -1
diff --git a/integration_test/config/tbl/data.yaml b/integration_test/config/tbl/data.yaml
new file mode 100644
index 000000000..112cb3c18
--- /dev/null
+++ b/integration_test/config/tbl/data.yaml
@@ -0,0 +1,30 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+kind: DataSet
+metadata:
+ tables:
+ - name: "order"
+ columns:
+ - name: "name"
+ type: "string"
+ - name: "value"
+ type: "string"
+data:
+ - name: "order"
+ value:
+ - ["test", "test1"]
diff --git a/integration_test/config/tbl/expected.yaml b/integration_test/config/tbl/expected.yaml
new file mode 100644
index 000000000..e16b99a92
--- /dev/null
+++ b/integration_test/config/tbl/expected.yaml
@@ -0,0 +1,30 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+kind: excepted
+metadata:
+ tables:
+ - name: "order"
+ columns:
+ - name: "name"
+ type: "string"
+ - name: "value"
+ type: "string"
+data:
+ - name: "order"
+ value:
+ - ["test", "test1"]
diff --git a/integration_test/scene/db/integration_test.go b/integration_test/scene/db/integration_test.go
new file mode 100644
index 000000000..c7188fcba
--- /dev/null
+++ b/integration_test/scene/db/integration_test.go
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package test
+
+import (
+ "strings"
+ "testing"
+)
+
+import (
+ _ "github.com/go-sql-driver/mysql"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/suite"
+)
+
+import (
+ "github.com/arana-db/arana/test"
+)
+
+type IntegrationSuite struct {
+ *test.MySuite
+}
+
+func TestSuite(t *testing.T) {
+ su := test.NewMySuite(
+ test.WithMySQLServerAuth("root", "123456"),
+ test.WithMySQLDatabase("employees"),
+ test.WithConfig("../integration_test/config/db/config.yaml"),
+ test.WithScriptPath("../integration_test/scripts/db"),
+ test.WithTestCasePath("../../testcase/casetest.yaml"),
+ // WithDevMode(), // NOTICE: UNCOMMENT IF YOU WANT TO DEBUG LOCAL ARANA SERVER!!!
+ )
+ suite.Run(t, &IntegrationSuite{su})
+}
+
+func (s *IntegrationSuite) TestDBScene() {
+ var (
+ db = s.DB()
+ t = s.T()
+ )
+ tx, err := db.Begin()
+ assert.NoError(t, err, "should begin a new tx")
+
+ cases := s.TestCases()
+ for _, sqlCase := range cases.ExecCases {
+ for _, sense := range sqlCase.Sense {
+ if strings.TrimSpace(sense) == "db" {
+ params := strings.Split(sqlCase.Parameters, ",")
+ args := make([]interface{}, 0, len(params))
+ for _, param := range params {
+ k, _ := test.GetValueByType(param)
+ args = append(args, k)
+ }
+
+ // Execute sql
+ result, err := tx.Exec(sqlCase.SQL, args...)
+ assert.NoError(t, err, "exec not right")
+ err = sqlCase.ExpectedResult.CompareRow(result)
+ assert.NoError(t, err, err)
+ }
+ }
+ }
+
+ for _, sqlCase := range cases.QueryRowCases {
+ for _, sense := range sqlCase.Sense {
+ if strings.TrimSpace(sense) == "db" {
+ params := strings.Split(sqlCase.Parameters, ",")
+ args := make([]interface{}, 0, len(params))
+ for _, param := range params {
+ k, _ := test.GetValueByType(param)
+ args = append(args, k)
+ }
+
+ result := tx.QueryRow(sqlCase.SQL, args...)
+ err = sqlCase.ExpectedResult.CompareRow(result)
+ assert.NoError(t, err, err)
+ }
+ }
+ }
+
+ for _, sqlCase := range cases.QueryRowsCases {
+ s.LoadExpectedDataSetPath(sqlCase.ExpectedResult.Value)
+ for _, sense := range sqlCase.Sense {
+ if strings.TrimSpace(sense) == "db" {
+ params := strings.Split(sqlCase.Parameters, ",")
+ args := make([]interface{}, 0, len(params))
+ for _, param := range params {
+ k, _ := test.GetValueByType(param)
+ args = append(args, k)
+ }
+
+ result, err := db.Query(sqlCase.SQL, args...)
+ assert.NoError(t, err, err)
+ err = sqlCase.ExpectedResult.CompareRows(result, s.ExpectedDataset())
+ assert.NoError(t, err, err)
+ }
+ }
+ }
+}
diff --git a/integration_test/scene/db_tbl/integration_test.go b/integration_test/scene/db_tbl/integration_test.go
new file mode 100644
index 000000000..e7878729a
--- /dev/null
+++ b/integration_test/scene/db_tbl/integration_test.go
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package test
+
+import (
+ "strings"
+ "testing"
+)
+
+import (
+ _ "github.com/go-sql-driver/mysql" // register mysql
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/suite"
+)
+
+import (
+ "github.com/arana-db/arana/test"
+)
+
+type IntegrationSuite struct {
+ *test.MySuite
+}
+
+func TestSuite(t *testing.T) {
+ su := test.NewMySuite(
+ test.WithMySQLServerAuth("root", "123456"),
+ test.WithMySQLDatabase("employees"),
+ test.WithConfig("../integration_test/config/db_tbl/config.yaml"),
+ test.WithScriptPath("../integration_test/scripts/db_tbl"),
+ test.WithTestCasePath("../../testcase/casetest.yaml"),
+ // WithDevMode(), // NOTICE: UNCOMMENT IF YOU WANT TO DEBUG LOCAL ARANA SERVER!!!
+ )
+ suite.Run(t, &IntegrationSuite{su})
+}
+
+func (s *IntegrationSuite) TestDBTBLScene() {
+ var (
+ db = s.DB()
+ t = s.T()
+ )
+ tx, err := db.Begin()
+ assert.NoError(t, err, "should begin a new tx")
+
+ cases := s.TestCases()
+ for _, sqlCase := range cases.ExecCases {
+ for _, sense := range sqlCase.Sense {
+ if strings.Compare(strings.TrimSpace(sense), "db_tbl") == 1 {
+ params := strings.Split(sqlCase.Parameters, ",")
+ args := make([]interface{}, 0, len(params))
+ for _, param := range params {
+ k, _ := test.GetValueByType(param)
+ args = append(args, k)
+ }
+
+ // Execute sql
+ result, err := tx.Exec(sqlCase.SQL, args...)
+ assert.NoError(t, err, "exec not right")
+ err = sqlCase.ExpectedResult.CompareRow(result)
+ assert.NoError(t, err, err)
+ }
+ }
+ }
+
+ for _, sqlCase := range cases.QueryRowCases {
+ for _, sense := range sqlCase.Sense {
+ if strings.Compare(strings.TrimSpace(sense), "db_tbl") == 1 {
+ params := strings.Split(sqlCase.Parameters, ",")
+ args := make([]interface{}, 0, len(params))
+ for _, param := range params {
+ k, _ := test.GetValueByType(param)
+ args = append(args, k)
+ }
+
+ result := tx.QueryRow(sqlCase.SQL, args...)
+ err = sqlCase.ExpectedResult.CompareRow(result)
+ assert.NoError(t, err, err)
+ }
+ }
+ }
+}
diff --git a/integration_test/scene/tbl/integration_test.go b/integration_test/scene/tbl/integration_test.go
new file mode 100644
index 000000000..cc95a5853
--- /dev/null
+++ b/integration_test/scene/tbl/integration_test.go
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package test
+
+import (
+ "strings"
+ "testing"
+)
+
+import (
+ _ "github.com/go-sql-driver/mysql" // register mysql
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/suite"
+)
+
+import (
+ "github.com/arana-db/arana/test"
+)
+
+type IntegrationSuite struct {
+ *test.MySuite
+}
+
+func TestSuite(t *testing.T) {
+ su := test.NewMySuite(
+ test.WithMySQLServerAuth("root", "123456"),
+ test.WithMySQLDatabase("employees"),
+ test.WithConfig("../integration_test/config/tbl/config.yaml"),
+ test.WithScriptPath("../integration_test/scripts/tbl"),
+ test.WithTestCasePath("../../testcase/casetest.yaml"),
+ // WithDevMode(), // NOTICE: UNCOMMENT IF YOU WANT TO DEBUG LOCAL ARANA SERVER!!!
+ )
+ suite.Run(t, &IntegrationSuite{su})
+}
+
+func (s *IntegrationSuite) TestDBTBLScene() {
+ var (
+ db = s.DB()
+ t = s.T()
+ )
+ tx, err := db.Begin()
+ assert.NoError(t, err, "should begin a new tx")
+
+ cases := s.TestCases()
+ for _, sqlCase := range cases.ExecCases {
+ for _, sense := range sqlCase.Sense {
+ if strings.Compare(strings.TrimSpace(sense), "tbl") == 1 {
+ params := strings.Split(sqlCase.Parameters, ",")
+ args := make([]interface{}, 0, len(params))
+ for _, param := range params {
+ k, _ := test.GetValueByType(param)
+ args = append(args, k)
+ }
+
+ // Execute sql
+ result, err := tx.Exec(sqlCase.SQL, args...)
+ assert.NoError(t, err, "exec not right")
+ err = sqlCase.ExpectedResult.CompareRow(result)
+ assert.NoError(t, err, err)
+ }
+ }
+ }
+
+ for _, sqlCase := range cases.QueryRowCases {
+ for _, sense := range sqlCase.Sense {
+ if strings.Compare(strings.TrimSpace(sense), "tbl") == 1 {
+ params := strings.Split(sqlCase.Parameters, ",")
+ args := make([]interface{}, 0, len(params))
+ for _, param := range params {
+ k, _ := test.GetValueByType(param)
+ args = append(args, k)
+ }
+
+ result := tx.QueryRow(sqlCase.SQL, args...)
+ err = sqlCase.ExpectedResult.CompareRow(result)
+ assert.NoError(t, err, err)
+ }
+ }
+ }
+
+}
diff --git a/integration_test/scripts/db/init.sql b/integration_test/scripts/db/init.sql
new file mode 100644
index 000000000..744c67674
--- /dev/null
+++ b/integration_test/scripts/db/init.sql
@@ -0,0 +1,113 @@
+--
+-- Licensed to the Apache Software Foundation (ASF) under one or more
+-- contributor license agreements. See the NOTICE file distributed with
+-- this work for additional information regarding copyright ownership.
+-- The ASF licenses this file to You under the Apache License, Version 2.0
+-- (the "License"); you may not use this file except in compliance with
+-- the License. You may obtain a copy of the License at
+--
+-- http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+--
+
+-- Sample employee database
+-- See changelog table for details
+-- Copyright (C) 2007,2008, MySQL AB
+--
+-- Original data created by Fusheng Wang and Carlo Zaniolo
+-- http://www.cs.aau.dk/TimeCenter/software.htm
+-- http://www.cs.aau.dk/TimeCenter/Data/employeeTemporalDataSet.zip
+--
+-- Current schema by Giuseppe Maxia
+-- Data conversion from XML to relational by Patrick Crews
+--
+-- This work is licensed under the
+-- Creative Commons Attribution-Share Alike 3.0 Unported License.
+-- To view a copy of this license, visit
+-- http://creativecommons.org/licenses/by-sa/3.0/ or send a letter to
+-- Creative Commons, 171 Second Street, Suite 300, San Francisco,
+-- California, 94105, USA.
+--
+-- DISCLAIMER
+-- To the best of our knowledge, this data is fabricated, and
+-- it does not correspond to real people.
+-- Any similarity to existing people is purely coincidental.
+--
+
+CREATE DATABASE IF NOT EXISTS employees_0000 CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
+
+USE employees_0000;
+
+SELECT 'CREATING DATABASE STRUCTURE' as 'INFO';
+
+DROP TABLE IF EXISTS dept_emp,
+ dept_manager,
+ titles,
+ salaries,
+ employees,
+ departments;
+
+/*!50503 set default_storage_engine = InnoDB */;
+/*!50503 select CONCAT('storage engine: ', @@default_storage_engine) as INFO */;
+
+CREATE TABLE employees (
+ emp_no INT NOT NULL,
+ birth_date DATE NOT NULL,
+ first_name VARCHAR(14) NOT NULL,
+ last_name VARCHAR(16) NOT NULL,
+ gender ENUM ('M','F') NOT NULL,
+ hire_date DATE NOT NULL,
+ PRIMARY KEY (emp_no)
+);
+
+CREATE TABLE departments (
+ dept_no CHAR(4) NOT NULL,
+ dept_name VARCHAR(40) NOT NULL,
+ PRIMARY KEY (dept_no),
+ UNIQUE KEY (dept_name)
+);
+
+CREATE TABLE dept_manager (
+ emp_no INT NOT NULL,
+ dept_no CHAR(4) NOT NULL,
+ from_date DATE NOT NULL,
+ to_date DATE NOT NULL,
+ FOREIGN KEY (emp_no) REFERENCES employees (emp_no) ON DELETE CASCADE,
+ FOREIGN KEY (dept_no) REFERENCES departments (dept_no) ON DELETE CASCADE,
+ PRIMARY KEY (emp_no,dept_no)
+);
+
+CREATE TABLE dept_emp (
+ emp_no INT NOT NULL,
+ dept_no CHAR(4) NOT NULL,
+ from_date DATE NOT NULL,
+ to_date DATE NOT NULL,
+ FOREIGN KEY (emp_no) REFERENCES employees (emp_no) ON DELETE CASCADE,
+ FOREIGN KEY (dept_no) REFERENCES departments (dept_no) ON DELETE CASCADE,
+ PRIMARY KEY (emp_no,dept_no)
+);
+
+CREATE TABLE titles (
+ emp_no INT NOT NULL,
+ title VARCHAR(50) NOT NULL,
+ from_date DATE NOT NULL,
+ to_date DATE,
+ FOREIGN KEY (emp_no) REFERENCES employees (emp_no) ON DELETE CASCADE,
+ PRIMARY KEY (emp_no,title, from_date)
+)
+;
+
+CREATE TABLE salaries (
+ emp_no INT NOT NULL,
+ salary INT NOT NULL,
+ from_date DATE NOT NULL,
+ to_date DATE NOT NULL,
+ FOREIGN KEY (emp_no) REFERENCES employees (emp_no) ON DELETE CASCADE,
+ PRIMARY KEY (emp_no, from_date)
+)
+;
diff --git a/integration_test/scripts/db/sequence.sql b/integration_test/scripts/db/sequence.sql
new file mode 100644
index 000000000..0df0add1c
--- /dev/null
+++ b/integration_test/scripts/db/sequence.sql
@@ -0,0 +1,31 @@
+--
+-- Licensed to the Apache Software Foundation (ASF) under one or more
+-- contributor license agreements. See the NOTICE file distributed with
+-- this work for additional information regarding copyright ownership.
+-- The ASF licenses this file to You under the Apache License, Version 2.0
+-- (the "License"); you may not use this file except in compliance with
+-- the License. You may obtain a copy of the License at
+--
+-- http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+--
+
+CREATE DATABASE IF NOT EXISTS employees_0000 CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
+
+CREATE TABLE IF NOT EXISTS `employees_0000`.`sequence`
+(
+ `id` BIGINT UNSIGNED NOT NULL AUTO_INCREMENT,
+ `name` VARCHAR(64) NOT NULL,
+ `value` BIGINT NOT NULL,
+ `step` INT NOT NULL DEFAULT 10000,
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_name` (`name`)
+) ENGINE = InnoDB
+ DEFAULT CHARSET = utf8mb4;
diff --git a/integration_test/scripts/db/sharding.sql b/integration_test/scripts/db/sharding.sql
new file mode 100644
index 000000000..1f22d4036
--- /dev/null
+++ b/integration_test/scripts/db/sharding.sql
@@ -0,0 +1,83 @@
+--
+-- Licensed to the Apache Software Foundation (ASF) under one or more
+-- contributor license agreements. See the NOTICE file distributed with
+-- this work for additional information regarding copyright ownership.
+-- The ASF licenses this file to You under the Apache License, Version 2.0
+-- (the "License"); you may not use this file except in compliance with
+-- the License. You may obtain a copy of the License at
+--
+-- http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+--
+
+CREATE DATABASE IF NOT EXISTS employees_0000 CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
+CREATE DATABASE IF NOT EXISTS employees_0001 CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
+CREATE DATABASE IF NOT EXISTS employees_0002 CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
+CREATE DATABASE IF NOT EXISTS employees_0003 CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
+
+CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0000`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0001`.`student_0000`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0002`.`student_0000`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0003`.`student_0000`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+INSERT INTO employees_0000.student_0000 VALUES (1, 1, 'scott', 95, 'nc_scott', 0, 16, NOW(), NOW());
diff --git a/integration_test/scripts/db_tbl/init.sql b/integration_test/scripts/db_tbl/init.sql
new file mode 100644
index 000000000..744c67674
--- /dev/null
+++ b/integration_test/scripts/db_tbl/init.sql
@@ -0,0 +1,113 @@
+--
+-- Licensed to the Apache Software Foundation (ASF) under one or more
+-- contributor license agreements. See the NOTICE file distributed with
+-- this work for additional information regarding copyright ownership.
+-- The ASF licenses this file to You under the Apache License, Version 2.0
+-- (the "License"); you may not use this file except in compliance with
+-- the License. You may obtain a copy of the License at
+--
+-- http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+--
+
+-- Sample employee database
+-- See changelog table for details
+-- Copyright (C) 2007,2008, MySQL AB
+--
+-- Original data created by Fusheng Wang and Carlo Zaniolo
+-- http://www.cs.aau.dk/TimeCenter/software.htm
+-- http://www.cs.aau.dk/TimeCenter/Data/employeeTemporalDataSet.zip
+--
+-- Current schema by Giuseppe Maxia
+-- Data conversion from XML to relational by Patrick Crews
+--
+-- This work is licensed under the
+-- Creative Commons Attribution-Share Alike 3.0 Unported License.
+-- To view a copy of this license, visit
+-- http://creativecommons.org/licenses/by-sa/3.0/ or send a letter to
+-- Creative Commons, 171 Second Street, Suite 300, San Francisco,
+-- California, 94105, USA.
+--
+-- DISCLAIMER
+-- To the best of our knowledge, this data is fabricated, and
+-- it does not correspond to real people.
+-- Any similarity to existing people is purely coincidental.
+--
+
+CREATE DATABASE IF NOT EXISTS employees_0000 CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
+
+USE employees_0000;
+
+SELECT 'CREATING DATABASE STRUCTURE' as 'INFO';
+
+DROP TABLE IF EXISTS dept_emp,
+ dept_manager,
+ titles,
+ salaries,
+ employees,
+ departments;
+
+/*!50503 set default_storage_engine = InnoDB */;
+/*!50503 select CONCAT('storage engine: ', @@default_storage_engine) as INFO */;
+
+CREATE TABLE employees (
+ emp_no INT NOT NULL,
+ birth_date DATE NOT NULL,
+ first_name VARCHAR(14) NOT NULL,
+ last_name VARCHAR(16) NOT NULL,
+ gender ENUM ('M','F') NOT NULL,
+ hire_date DATE NOT NULL,
+ PRIMARY KEY (emp_no)
+);
+
+CREATE TABLE departments (
+ dept_no CHAR(4) NOT NULL,
+ dept_name VARCHAR(40) NOT NULL,
+ PRIMARY KEY (dept_no),
+ UNIQUE KEY (dept_name)
+);
+
+CREATE TABLE dept_manager (
+ emp_no INT NOT NULL,
+ dept_no CHAR(4) NOT NULL,
+ from_date DATE NOT NULL,
+ to_date DATE NOT NULL,
+ FOREIGN KEY (emp_no) REFERENCES employees (emp_no) ON DELETE CASCADE,
+ FOREIGN KEY (dept_no) REFERENCES departments (dept_no) ON DELETE CASCADE,
+ PRIMARY KEY (emp_no,dept_no)
+);
+
+CREATE TABLE dept_emp (
+ emp_no INT NOT NULL,
+ dept_no CHAR(4) NOT NULL,
+ from_date DATE NOT NULL,
+ to_date DATE NOT NULL,
+ FOREIGN KEY (emp_no) REFERENCES employees (emp_no) ON DELETE CASCADE,
+ FOREIGN KEY (dept_no) REFERENCES departments (dept_no) ON DELETE CASCADE,
+ PRIMARY KEY (emp_no,dept_no)
+);
+
+CREATE TABLE titles (
+ emp_no INT NOT NULL,
+ title VARCHAR(50) NOT NULL,
+ from_date DATE NOT NULL,
+ to_date DATE,
+ FOREIGN KEY (emp_no) REFERENCES employees (emp_no) ON DELETE CASCADE,
+ PRIMARY KEY (emp_no,title, from_date)
+)
+;
+
+CREATE TABLE salaries (
+ emp_no INT NOT NULL,
+ salary INT NOT NULL,
+ from_date DATE NOT NULL,
+ to_date DATE NOT NULL,
+ FOREIGN KEY (emp_no) REFERENCES employees (emp_no) ON DELETE CASCADE,
+ PRIMARY KEY (emp_no, from_date)
+)
+;
diff --git a/integration_test/scripts/db_tbl/sequence.sql b/integration_test/scripts/db_tbl/sequence.sql
new file mode 100644
index 000000000..0df0add1c
--- /dev/null
+++ b/integration_test/scripts/db_tbl/sequence.sql
@@ -0,0 +1,31 @@
+--
+-- Licensed to the Apache Software Foundation (ASF) under one or more
+-- contributor license agreements. See the NOTICE file distributed with
+-- this work for additional information regarding copyright ownership.
+-- The ASF licenses this file to You under the Apache License, Version 2.0
+-- (the "License"); you may not use this file except in compliance with
+-- the License. You may obtain a copy of the License at
+--
+-- http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+--
+
+CREATE DATABASE IF NOT EXISTS employees_0000 CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
+
+CREATE TABLE IF NOT EXISTS `employees_0000`.`sequence`
+(
+ `id` BIGINT UNSIGNED NOT NULL AUTO_INCREMENT,
+ `name` VARCHAR(64) NOT NULL,
+ `value` BIGINT NOT NULL,
+ `step` INT NOT NULL DEFAULT 10000,
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_name` (`name`)
+) ENGINE = InnoDB
+ DEFAULT CHARSET = utf8mb4;
diff --git a/integration_test/scripts/db_tbl/sharding.sql b/integration_test/scripts/db_tbl/sharding.sql
new file mode 100644
index 000000000..e98faa26f
--- /dev/null
+++ b/integration_test/scripts/db_tbl/sharding.sql
@@ -0,0 +1,633 @@
+--
+-- Licensed to the Apache Software Foundation (ASF) under one or more
+-- contributor license agreements. See the NOTICE file distributed with
+-- this work for additional information regarding copyright ownership.
+-- The ASF licenses this file to You under the Apache License, Version 2.0
+-- (the "License"); you may not use this file except in compliance with
+-- the License. You may obtain a copy of the License at
+--
+-- http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+--
+
+CREATE DATABASE IF NOT EXISTS employees_0000 CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
+CREATE DATABASE IF NOT EXISTS employees_0001 CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
+CREATE DATABASE IF NOT EXISTS employees_0002 CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
+CREATE DATABASE IF NOT EXISTS employees_0003 CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
+
+CREATE DATABASE IF NOT EXISTS employees_0000_r CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
+
+CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0000`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0001`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0002`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0003`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0004`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0005`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0006`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0007`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0001`.`student_0008`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0001`.`student_0009`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0001`.`student_0010`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0001`.`student_0011`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0001`.`student_0012`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0001`.`student_0013`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0001`.`student_0014`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0001`.`student_0015`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0002`.`student_0016`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0002`.`student_0017`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0002`.`student_0018`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0002`.`student_0019`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0002`.`student_0020`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0002`.`student_0021`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0002`.`student_0022`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0002`.`student_0023`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0003`.`student_0024`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0003`.`student_0025`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0003`.`student_0026`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0003`.`student_0027`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0003`.`student_0028`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0003`.`student_0029`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0003`.`student_0030`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0003`.`student_0031`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0000_r`.`student_0000`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`),
+ KEY `nickname` (`nickname`)
+ ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0000_r`.`student_0001`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`),
+ KEY `nickname` (`nickname`)
+ ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0000_r`.`student_0002`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`),
+ KEY `nickname` (`nickname`)
+ ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0000_r`.`student_0003`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`),
+ KEY `nickname` (`nickname`)
+ ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0000_r`.`student_0004`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`),
+ KEY `nickname` (`nickname`)
+ ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0000_r`.`student_0005`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`),
+ KEY `nickname` (`nickname`)
+ ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0000_r`.`student_0006`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`),
+ KEY `nickname` (`nickname`)
+ ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0000_r`.`student_0007`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`),
+ KEY `nickname` (`nickname`)
+ ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+INSERT INTO employees_0000.student_0001 VALUES (1, 1, 'scott', 95, 'nc_scott', 0, 16, NOW(), NOW());
diff --git a/integration_test/scripts/tbl/init.sql b/integration_test/scripts/tbl/init.sql
new file mode 100644
index 000000000..744c67674
--- /dev/null
+++ b/integration_test/scripts/tbl/init.sql
@@ -0,0 +1,113 @@
+--
+-- Licensed to the Apache Software Foundation (ASF) under one or more
+-- contributor license agreements. See the NOTICE file distributed with
+-- this work for additional information regarding copyright ownership.
+-- The ASF licenses this file to You under the Apache License, Version 2.0
+-- (the "License"); you may not use this file except in compliance with
+-- the License. You may obtain a copy of the License at
+--
+-- http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+--
+
+-- Sample employee database
+-- See changelog table for details
+-- Copyright (C) 2007,2008, MySQL AB
+--
+-- Original data created by Fusheng Wang and Carlo Zaniolo
+-- http://www.cs.aau.dk/TimeCenter/software.htm
+-- http://www.cs.aau.dk/TimeCenter/Data/employeeTemporalDataSet.zip
+--
+-- Current schema by Giuseppe Maxia
+-- Data conversion from XML to relational by Patrick Crews
+--
+-- This work is licensed under the
+-- Creative Commons Attribution-Share Alike 3.0 Unported License.
+-- To view a copy of this license, visit
+-- http://creativecommons.org/licenses/by-sa/3.0/ or send a letter to
+-- Creative Commons, 171 Second Street, Suite 300, San Francisco,
+-- California, 94105, USA.
+--
+-- DISCLAIMER
+-- To the best of our knowledge, this data is fabricated, and
+-- it does not correspond to real people.
+-- Any similarity to existing people is purely coincidental.
+--
+
+CREATE DATABASE IF NOT EXISTS employees_0000 CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
+
+USE employees_0000;
+
+SELECT 'CREATING DATABASE STRUCTURE' as 'INFO';
+
+DROP TABLE IF EXISTS dept_emp,
+ dept_manager,
+ titles,
+ salaries,
+ employees,
+ departments;
+
+/*!50503 set default_storage_engine = InnoDB */;
+/*!50503 select CONCAT('storage engine: ', @@default_storage_engine) as INFO */;
+
+CREATE TABLE employees (
+ emp_no INT NOT NULL,
+ birth_date DATE NOT NULL,
+ first_name VARCHAR(14) NOT NULL,
+ last_name VARCHAR(16) NOT NULL,
+ gender ENUM ('M','F') NOT NULL,
+ hire_date DATE NOT NULL,
+ PRIMARY KEY (emp_no)
+);
+
+CREATE TABLE departments (
+ dept_no CHAR(4) NOT NULL,
+ dept_name VARCHAR(40) NOT NULL,
+ PRIMARY KEY (dept_no),
+ UNIQUE KEY (dept_name)
+);
+
+CREATE TABLE dept_manager (
+ emp_no INT NOT NULL,
+ dept_no CHAR(4) NOT NULL,
+ from_date DATE NOT NULL,
+ to_date DATE NOT NULL,
+ FOREIGN KEY (emp_no) REFERENCES employees (emp_no) ON DELETE CASCADE,
+ FOREIGN KEY (dept_no) REFERENCES departments (dept_no) ON DELETE CASCADE,
+ PRIMARY KEY (emp_no,dept_no)
+);
+
+CREATE TABLE dept_emp (
+ emp_no INT NOT NULL,
+ dept_no CHAR(4) NOT NULL,
+ from_date DATE NOT NULL,
+ to_date DATE NOT NULL,
+ FOREIGN KEY (emp_no) REFERENCES employees (emp_no) ON DELETE CASCADE,
+ FOREIGN KEY (dept_no) REFERENCES departments (dept_no) ON DELETE CASCADE,
+ PRIMARY KEY (emp_no,dept_no)
+);
+
+CREATE TABLE titles (
+ emp_no INT NOT NULL,
+ title VARCHAR(50) NOT NULL,
+ from_date DATE NOT NULL,
+ to_date DATE,
+ FOREIGN KEY (emp_no) REFERENCES employees (emp_no) ON DELETE CASCADE,
+ PRIMARY KEY (emp_no,title, from_date)
+)
+;
+
+CREATE TABLE salaries (
+ emp_no INT NOT NULL,
+ salary INT NOT NULL,
+ from_date DATE NOT NULL,
+ to_date DATE NOT NULL,
+ FOREIGN KEY (emp_no) REFERENCES employees (emp_no) ON DELETE CASCADE,
+ PRIMARY KEY (emp_no, from_date)
+)
+;
diff --git a/integration_test/scripts/tbl/sequence.sql b/integration_test/scripts/tbl/sequence.sql
new file mode 100644
index 000000000..0df0add1c
--- /dev/null
+++ b/integration_test/scripts/tbl/sequence.sql
@@ -0,0 +1,31 @@
+--
+-- Licensed to the Apache Software Foundation (ASF) under one or more
+-- contributor license agreements. See the NOTICE file distributed with
+-- this work for additional information regarding copyright ownership.
+-- The ASF licenses this file to You under the Apache License, Version 2.0
+-- (the "License"); you may not use this file except in compliance with
+-- the License. You may obtain a copy of the License at
+--
+-- http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+--
+
+CREATE DATABASE IF NOT EXISTS employees_0000 CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
+
+CREATE TABLE IF NOT EXISTS `employees_0000`.`sequence`
+(
+ `id` BIGINT UNSIGNED NOT NULL AUTO_INCREMENT,
+ `name` VARCHAR(64) NOT NULL,
+ `value` BIGINT NOT NULL,
+ `step` INT NOT NULL DEFAULT 10000,
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_name` (`name`)
+) ENGINE = InnoDB
+ DEFAULT CHARSET = utf8mb4;
diff --git a/integration_test/scripts/tbl/sharding.sql b/integration_test/scripts/tbl/sharding.sql
new file mode 100644
index 000000000..8a00efcb7
--- /dev/null
+++ b/integration_test/scripts/tbl/sharding.sql
@@ -0,0 +1,500 @@
+--
+-- Licensed to the Apache Software Foundation (ASF) under one or more
+-- contributor license agreements. See the NOTICE file distributed with
+-- this work for additional information regarding copyright ownership.
+-- The ASF licenses this file to You under the Apache License, Version 2.0
+-- (the "License"); you may not use this file except in compliance with
+-- the License. You may obtain a copy of the License at
+--
+-- http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+--
+
+CREATE DATABASE IF NOT EXISTS employees_0000 CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
+
+CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0000`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0001`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0002`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0003`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0004`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0005`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0006`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0007`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0008`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0009`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0010`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0011`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0012`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0013`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0014`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0015`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0016`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0017`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0018`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0019`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0020`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0021`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0022`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0023`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0024`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0025`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0026`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0027`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0028`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0029`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0030`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0031`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+INSERT INTO employees_0000.student_0001 VALUES (1, 1, 'scott', 95, 'nc_scott', 0, 16, NOW(), NOW());
diff --git a/integration_test/testcase/casetest.yaml b/integration_test/testcase/casetest.yaml
new file mode 100644
index 000000000..0324d827a
--- /dev/null
+++ b/integration_test/testcase/casetest.yaml
@@ -0,0 +1,45 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+kind: dml
+query_rows_cases:
+ - sql: "SELECT name, value FROM sequence WHERE name= ?"
+ parameters: "1:string"
+ sense:
+ - db
+ expected:
+ type: "file"
+ value: "../../config/db/expected.yaml"
+query_row_cases:
+ - sql: "SELECT COUNT(1) FROM sequence WHERE name=?"
+ parameters: "1:string"
+ sense:
+ - db
+ - tbl
+ expected:
+ type: "valueInt"
+ value: "1"
+exec_cases:
+ - sql: "INSERT INTO sequence(name,value,modified_at) VALUES(?,?,NOW())"
+ parameters: "1:string, 2:string"
+ sense:
+ - db
+ - tbl
+ expected:
+ type: "rowAffect"
+ value: 1
diff --git a/justfile b/justfile
index 0dfff1acf..eb4f510f9 100644
--- a/justfile
+++ b/justfile
@@ -7,7 +7,7 @@ default:
@just --list
run:
- @go run ./cmd/... start -c ./conf/bootstrap.yaml
+ @go run ./example/local_server
cli:
@mycli -h127.0.0.1 -P13306 -udksl employees -p123456
@@ -16,4 +16,4 @@ cli-raw:
fix:
@imports-formatter .
- @license-eye header fix
\ No newline at end of file
+ @license-eye header fix
diff --git a/pkg/boot/boot.go b/pkg/boot/boot.go
index 56126b4de..0b30863d5 100644
--- a/pkg/boot/boot.go
+++ b/pkg/boot/boot.go
@@ -30,7 +30,7 @@ import (
"github.com/arana-db/arana/pkg/proto/rule"
"github.com/arana-db/arana/pkg/runtime"
"github.com/arana-db/arana/pkg/runtime/namespace"
- "github.com/arana-db/arana/pkg/runtime/optimize"
+ _ "github.com/arana-db/arana/pkg/schema"
"github.com/arana-db/arana/pkg/security"
"github.com/arana-db/arana/pkg/util/log"
)
@@ -129,5 +129,5 @@ func buildNamespace(ctx context.Context, provider Discovery, cluster string) (*n
}
initCmds = append(initCmds, namespace.UpdateRule(&ru))
- return namespace.New(cluster, optimize.GetOptimizer(), initCmds...), nil
+ return namespace.New(cluster, initCmds...), nil
}
diff --git a/pkg/boot/discovery.go b/pkg/boot/discovery.go
index d76331c5e..669dfad1b 100644
--- a/pkg/boot/discovery.go
+++ b/pkg/boot/discovery.go
@@ -176,7 +176,6 @@ func (fp *discovery) GetCluster(ctx context.Context, cluster string) (*Cluster,
}
func (fp *discovery) ListTenants(ctx context.Context) ([]string, error) {
-
cfg, err := fp.c.Load()
if err != nil {
return nil, err
@@ -269,7 +268,7 @@ func (fp *discovery) ListTables(ctx context.Context, cluster string) ([]string,
}
var tables []string
- for tb, _ := range fp.loadTables(cfg, cluster) {
+ for tb := range fp.loadTables(cfg, cluster) {
tables = append(tables, tb)
}
sort.Strings(tables)
@@ -325,6 +324,7 @@ func (fp *discovery) GetTable(ctx context.Context, cluster, tableName string) (*
var (
keys map[string]struct{}
dbSharder, tbSharder map[string]rule.ShardComputer
+ dbSteps, tbSteps map[string]int
)
for _, it := range table.DbRules {
var shd rule.ShardComputer
@@ -337,8 +337,12 @@ func (fp *discovery) GetTable(ctx context.Context, cluster, tableName string) (*
if keys == nil {
keys = make(map[string]struct{})
}
+ if dbSteps == nil {
+ dbSteps = make(map[string]int)
+ }
dbSharder[it.Column] = shd
keys[it.Column] = struct{}{}
+ dbSteps[it.Column] = it.Step
}
for _, it := range table.TblRules {
@@ -352,8 +356,12 @@ func (fp *discovery) GetTable(ctx context.Context, cluster, tableName string) (*
if keys == nil {
keys = make(map[string]struct{})
}
+ if tbSteps == nil {
+ tbSteps = make(map[string]int)
+ }
tbSharder[it.Column] = shd
keys[it.Column] = struct{}{}
+ tbSteps[it.Column] = it.Step
}
for k := range keys {
@@ -366,7 +374,9 @@ func (fp *discovery) GetTable(ctx context.Context, cluster, tableName string) (*
Computer: shd,
Stepper: rule.DefaultNumberStepper,
}
- if dbBegin >= 0 && dbEnd >= 0 {
+ if s, ok := dbSteps[k]; ok && s > 0 {
+ dbMetadata.Steps = s
+ } else if dbBegin >= 0 && dbEnd >= 0 {
dbMetadata.Steps = 1 + dbEnd - dbBegin
}
}
@@ -375,14 +385,20 @@ func (fp *discovery) GetTable(ctx context.Context, cluster, tableName string) (*
Computer: shd,
Stepper: rule.DefaultNumberStepper,
}
- if tbBegin >= 0 && tbEnd >= 0 {
+ if s, ok := tbSteps[k]; ok && s > 0 {
+ tbMetadata.Steps = s
+ } else if tbBegin >= 0 && tbEnd >= 0 {
tbMetadata.Steps = 1 + tbEnd - tbBegin
}
}
vt.SetShardMetadata(k, dbMetadata, tbMetadata)
tpRes := make(map[int][]int)
- rng, _ := tbMetadata.Stepper.Ascend(0, tbMetadata.Steps)
+ step := tbMetadata.Steps
+ if dbMetadata.Steps > step {
+ step = dbMetadata.Steps
+ }
+ rng, _ := tbMetadata.Stepper.Ascend(0, step)
for rng.HasNext() {
var (
seed = rng.Next()
@@ -472,7 +488,7 @@ var (
func getTopologyRegexp() *regexp.Regexp {
_regexpTopologyOnce.Do(func() {
- _regexpTopology = regexp.MustCompile(`\${(?P[0-9]+)\.\.\.(?P[0-9]+)}`)
+ _regexpTopology = regexp.MustCompile(`\${(?P\d+)\.{2,}(?P\d+)}`)
})
return _regexpTopology
}
@@ -492,9 +508,7 @@ func parseTopology(input string) (format string, begin, end int, err error) {
return
}
- var (
- beginStr, endStr string
- )
+ var beginStr, endStr string
for i := 1; i < len(mats[0]); i++ {
switch getTopologyRegexp().SubexpNames()[i] {
case "begin":
@@ -518,27 +532,29 @@ func parseTopology(input string) (format string, begin, end int, err error) {
func toSharder(input *config.Rule) (rule.ShardComputer, error) {
var (
computer rule.ShardComputer
- method string
mod int
err error
)
if mat := getRuleExprRegexp().FindStringSubmatch(input.Expr); len(mat) == 3 {
- method = mat[1]
mod, _ = strconv.Atoi(mat[2])
}
- switch method {
- case string(rrule.ModShard):
+ switch rrule.ShardType(input.Type) {
+ case rrule.ModShard:
computer = rrule.NewModShard(mod)
- case string(rrule.HashMd5Shard):
+ case rrule.HashMd5Shard:
computer = rrule.NewHashMd5Shard(mod)
- case string(rrule.HashBKDRShard):
+ case rrule.HashBKDRShard:
computer = rrule.NewHashBKDRShard(mod)
- case string(rrule.HashCrc32Shard):
+ case rrule.HashCrc32Shard:
computer = rrule.NewHashCrc32Shard(mod)
- default:
+ case rrule.FunctionExpr:
+ computer, err = rrule.NewExprShardComputer(input.Expr, input.Column)
+ case rrule.ScriptExpr:
computer, err = rrule.NewJavascriptShardComputer(input.Expr)
+ default:
+ panic(fmt.Errorf("error config, unsupport shard type: %s", input.Type))
}
return computer, err
}
diff --git a/pkg/config/api.go b/pkg/config/api.go
index 7bd8bb591..08105ac0f 100644
--- a/pkg/config/api.go
+++ b/pkg/config/api.go
@@ -53,7 +53,7 @@ const (
)
var (
- slots map[string]StoreOperate = make(map[string]StoreOperate)
+ slots = make(map[string]StoreOperate)
storeOperate StoreOperate
)
@@ -66,7 +66,6 @@ func GetStoreOperate() (StoreOperate, error) {
}
func Init(name string, options map[string]interface{}) error {
-
s, exist := slots[name]
if !exist {
return fmt.Errorf("StoreOperate solt=[%s] not exist", name)
@@ -77,7 +76,7 @@ func Init(name string, options map[string]interface{}) error {
return storeOperate.Init(options)
}
-//Register register store plugin
+// Register register store plugin
func Register(s StoreOperate) {
if _, ok := slots[s.Name()]; ok {
panic(fmt.Errorf("StoreOperate=[%s] already exist", s.Name()))
@@ -86,22 +85,22 @@ func Register(s StoreOperate) {
slots[s.Name()] = s
}
-//StoreOperate config storage related plugins
+// StoreOperate config storage related plugins
type StoreOperate interface {
io.Closer
- //Init plugin initialization
+ // Init plugin initialization
Init(options map[string]interface{}) error
- //Save save a configuration data
+ // Save save a configuration data
Save(key PathKey, val []byte) error
- //Get get a configuration
+ // Get get a configuration
Get(key PathKey) ([]byte, error)
- //Watch Monitor changes of the key
+ // Watch Monitor changes of the key
Watch(key PathKey) (<-chan []byte, error)
- //Name plugin name
+ // Name plugin name
Name() string
}
diff --git a/pkg/config/config.go b/pkg/config/config.go
index 5531b5531..e1343975b 100644
--- a/pkg/config/config.go
+++ b/pkg/config/config.go
@@ -36,6 +36,7 @@ import (
)
import (
+ "github.com/arana-db/arana/pkg/util/env"
"github.com/arana-db/arana/pkg/util/log"
)
@@ -136,7 +137,9 @@ func (c *Center) LoadContext(ctx context.Context) (*Configuration, error) {
c.confHolder.Store(cfg)
out, _ := yaml.Marshal(cfg)
- log.Debugf("load configuration:\n%s", string(out))
+ if env.IsDevelopEnvironment() {
+ log.Infof("load configuration:\n%s", string(out))
+ }
}
val = c.confHolder.Load()
@@ -255,7 +258,6 @@ func (c *Center) PersistContext(ctx context.Context) error {
}
for k, v := range ConfigKeyMapping {
-
if err := c.storeOperate.Save(k, []byte(gjson.GetBytes(configJson, v).String())); err != nil {
return err
}
diff --git a/pkg/config/etcd/etcd.go b/pkg/config/etcd/etcd.go
index caf1fda31..5c8bbacef 100644
--- a/pkg/config/etcd/etcd.go
+++ b/pkg/config/etcd/etcd.go
@@ -35,6 +35,7 @@ import (
import (
"github.com/arana-db/arana/pkg/config"
+ "github.com/arana-db/arana/pkg/util/env"
"github.com/arana-db/arana/pkg/util/log"
)
@@ -79,6 +80,10 @@ func (c *storeOperate) Get(key config.PathKey) ([]byte, error) {
return nil, err
}
+ if env.IsDevelopEnvironment() {
+ log.Infof("[ConfigCenter][etcd] load config content : %#v", v)
+ }
+
return []byte(v), nil
}
diff --git a/pkg/config/file/file.go b/pkg/config/file/file.go
index aa2bc276e..a8cac1abb 100644
--- a/pkg/config/file/file.go
+++ b/pkg/config/file/file.go
@@ -27,18 +27,19 @@ import (
import (
"github.com/pkg/errors"
-
"github.com/tidwall/gjson"
-
"gopkg.in/yaml.v3"
)
import (
"github.com/arana-db/arana/pkg/config"
"github.com/arana-db/arana/pkg/constants"
+ "github.com/arana-db/arana/pkg/util/env"
"github.com/arana-db/arana/pkg/util/log"
)
+var configFilenameList = []string{"config.yaml", "config.yml"}
+
func init() {
config.Register(&storeOperate{})
}
@@ -96,7 +97,9 @@ func (s *storeOperate) initCfgJsonMap(val string) {
s.cfgJson[k] = gjson.Get(val, v).String()
}
- log.Debugf("[ConfigCenter][File] load config content : %#v", s.cfgJson)
+ if env.IsDevelopEnvironment() {
+ log.Infof("[ConfigCenter][File] load config content : %#v", s.cfgJson)
+ }
}
func (s *storeOperate) Save(key config.PathKey, val []byte) error {
@@ -108,7 +111,7 @@ func (s *storeOperate) Get(key config.PathKey) ([]byte, error) {
return val, nil
}
-//Watch TODO change notification through file inotify mechanism
+// Watch TODO change notification through file inotify mechanism
func (s *storeOperate) Watch(key config.PathKey) (<-chan []byte, error) {
defer s.lock.Unlock()
@@ -164,13 +167,11 @@ func (s *storeOperate) readFromFile(path string, cfg *config.Configuration) erro
func (s *storeOperate) searchDefaultConfigFile() (string, bool) {
var p string
for _, it := range constants.GetConfigSearchPathList() {
- p = filepath.Join(it, "config.yaml")
- if _, err := os.Stat(p); err == nil {
- return p, true
- }
- p = filepath.Join(it, "config.yml")
- if _, err := os.Stat(p); err == nil {
- return p, true
+ for _, filename := range configFilenameList {
+ p = filepath.Join(it, filename)
+ if _, err := os.Stat(p); err == nil {
+ return p, true
+ }
}
}
return "", false
diff --git a/pkg/config/model.go b/pkg/config/model.go
index d1e77869e..d160ce285 100644
--- a/pkg/config/model.go
+++ b/pkg/config/model.go
@@ -20,10 +20,13 @@ package config
import (
"bytes"
"encoding/json"
+ "fmt"
"io"
"os"
"regexp"
"strconv"
+ "strings"
+ "time"
)
import (
@@ -78,31 +81,25 @@ type (
Type DataSourceType `yaml:"type" json:"type"`
SqlMaxLimit int `default:"-1" yaml:"sql_max_limit" json:"sql_max_limit,omitempty"`
Tenant string `yaml:"tenant" json:"tenant"`
- ConnProps *ConnProp `yaml:"conn_props" json:"conn_props,omitempty"`
Groups []*Group `yaml:"groups" json:"groups"`
}
- ConnProp struct {
- Capacity int `yaml:"capacity" json:"capacity,omitempty"` // connection pool capacity
- MaxCapacity int `yaml:"max_capacity" json:"max_capacity,omitempty"` // max connection pool capacity
- IdleTimeout int `yaml:"idle_timeout" json:"idle_timeout,omitempty"` // close backend direct connection after idle_timeout
- }
-
Group struct {
Name string `yaml:"name" json:"name"`
Nodes []*Node `yaml:"nodes" json:"nodes"`
}
Node struct {
- Name string `validate:"required" yaml:"name" json:"name"`
- Host string `validate:"required" yaml:"host" json:"host"`
- Port int `validate:"required" yaml:"port" json:"port"`
- Username string `validate:"required" yaml:"username" json:"username"`
- Password string `validate:"required" yaml:"password" json:"password"`
- Database string `validate:"required" yaml:"database" json:"database"`
- ConnProps map[string]string `yaml:"conn_props" json:"conn_props,omitempty"`
- Weight string `default:"r10w10" yaml:"weight" json:"weight"`
- Labels map[string]string `yaml:"labels" json:"labels,omitempty"`
+ Name string `validate:"required" yaml:"name" json:"name"`
+ Host string `validate:"required" yaml:"host" json:"host"`
+ Port int `validate:"required" yaml:"port" json:"port"`
+ Username string `validate:"required" yaml:"username" json:"username"`
+ Password string `validate:"required" yaml:"password" json:"password"`
+ Database string `validate:"required" yaml:"database" json:"database"`
+ Parameters ParametersMap `yaml:"parameters" json:"parameters"`
+ ConnProps map[string]interface{} `yaml:"conn_props" json:"conn_props,omitempty"`
+ Weight string `default:"r10w10" yaml:"weight" json:"weight"`
+ Labels map[string]string `yaml:"labels" json:"labels,omitempty"`
}
ShardingRule struct {
@@ -163,7 +160,9 @@ type (
Rule struct {
Column string `validate:"required" yaml:"column" json:"column"`
+ Type string `validate:"required" yaml:"type" json:"type"`
Expr string `validate:"required" yaml:"expr" json:"expr"`
+ Step int `yaml:"step" json:"step"`
}
Topology struct {
@@ -172,6 +171,19 @@ type (
}
)
+type ParametersMap map[string]string
+
+func (pm *ParametersMap) String() string {
+ sBuff := strings.Builder{}
+ for k, v := range *pm {
+ sBuff.WriteString(k)
+ sBuff.WriteString("=")
+ sBuff.WriteString(v)
+ sBuff.WriteString("&")
+ }
+ return strings.TrimRight(sBuff.String(), "&")
+}
+
// Decoder decodes configuration.
type Decoder struct {
reader io.Reader
@@ -262,3 +274,62 @@ func Validate(cfg *Configuration) error {
v := validator.New()
return v.Struct(cfg)
}
+
+// GetConnPropCapacity parses the capacity of backend connection pool, return default value if failed.
+func GetConnPropCapacity(connProps map[string]interface{}, defaultValue int) int {
+ capacity, ok := connProps["capacity"]
+ if !ok {
+ return defaultValue
+ }
+ n, _ := strconv.Atoi(fmt.Sprint(capacity))
+ if n < 1 {
+ return defaultValue
+ }
+ return n
+}
+
+// GetConnPropMaxCapacity parses the max capacity of backend connection pool, return default value if failed.
+func GetConnPropMaxCapacity(connProps map[string]interface{}, defaultValue int) int {
+ var (
+ maxCapacity interface{}
+ ok bool
+ )
+
+ if maxCapacity, ok = connProps["max_capacity"]; !ok {
+ if maxCapacity, ok = connProps["maxCapacity"]; !ok {
+ return defaultValue
+ }
+ }
+ n, _ := strconv.Atoi(fmt.Sprint(maxCapacity))
+ if n < 1 {
+ return defaultValue
+ }
+ return n
+}
+
+// GetConnPropIdleTime parses the idle time of backend connection pool, return default value if failed.
+func GetConnPropIdleTime(connProps map[string]interface{}, defaultValue time.Duration) time.Duration {
+ var (
+ idleTime interface{}
+ ok bool
+ )
+
+ if idleTime, ok = connProps["idle_time"]; !ok {
+ if idleTime, ok = connProps["idleTime"]; !ok {
+ return defaultValue
+ }
+ }
+
+ s := fmt.Sprint(idleTime)
+ d, _ := time.ParseDuration(s)
+ if d > 0 {
+ return d
+ }
+
+ n, _ := strconv.Atoi(s)
+ if n < 1 {
+ return defaultValue
+ }
+
+ return time.Duration(n) * time.Second
+}
diff --git a/pkg/config/model_test.go b/pkg/config/model_test.go
index baba8a461..188d72931 100644
--- a/pkg/config/model_test.go
+++ b/pkg/config/model_test.go
@@ -64,10 +64,6 @@ func TestDataSourceClustersConf(t *testing.T) {
assert.Equal(t, DBMySQL, dataSourceCluster.Type)
assert.Equal(t, -1, dataSourceCluster.SqlMaxLimit)
assert.Equal(t, "arana", dataSourceCluster.Tenant)
- assert.NotNil(t, dataSourceCluster.ConnProps)
- assert.Equal(t, 10, dataSourceCluster.ConnProps.Capacity)
- assert.Equal(t, 20, dataSourceCluster.ConnProps.MaxCapacity)
- assert.Equal(t, 60, dataSourceCluster.ConnProps.IdleTimeout)
assert.Equal(t, 1, len(dataSourceCluster.Groups))
group := dataSourceCluster.Groups[0]
@@ -112,22 +108,22 @@ func TestShardingRuleConf(t *testing.T) {
func TestUnmarshalTextForProtocolTypeNil(t *testing.T) {
var protocolType ProtocolType
- var text = []byte("http")
+ text := []byte("http")
err := protocolType.UnmarshalText(text)
assert.Nil(t, err)
assert.Equal(t, Http, protocolType)
}
func TestUnmarshalTextForUnrecognizedProtocolType(t *testing.T) {
- var protocolType = Http
- var text = []byte("PostgreSQL")
+ protocolType := Http
+ text := []byte("PostgreSQL")
err := protocolType.UnmarshalText(text)
assert.Error(t, err)
}
func TestUnmarshalText(t *testing.T) {
- var protocolType = Http
- var text = []byte("mysql")
+ protocolType := Http
+ text := []byte("mysql")
err := protocolType.UnmarshalText(text)
assert.Nil(t, err)
assert.Equal(t, MySQL, protocolType)
diff --git a/pkg/config/nacos/nacos.go b/pkg/config/nacos/nacos.go
index b4aeb18ea..b9a354005 100644
--- a/pkg/config/nacos/nacos.go
+++ b/pkg/config/nacos/nacos.go
@@ -34,6 +34,8 @@ import (
import (
"github.com/arana-db/arana/pkg/config"
+ "github.com/arana-db/arana/pkg/util/env"
+ "github.com/arana-db/arana/pkg/util/log"
)
const (
@@ -52,7 +54,7 @@ func init() {
config.Register(&storeOperate{})
}
-//StoreOperate config storage related plugins
+// StoreOperate config storage related plugins
type storeOperate struct {
groupName string
client config_client.IConfigClient
@@ -63,7 +65,7 @@ type storeOperate struct {
cancelList []context.CancelFunc
}
-//Init plugin initialization
+// Init plugin initialization
func (s *storeOperate) Init(options map[string]interface{}) error {
s.lock = &sync.RWMutex{}
s.cfgLock = &sync.RWMutex{}
@@ -96,7 +98,6 @@ func (s *storeOperate) initNacosClient(options map[string]interface{}) error {
ClientConfig: &clientConfig,
},
)
-
if err != nil {
return err
}
@@ -170,9 +171,8 @@ func (s *storeOperate) loadDataFromServer() error {
return nil
}
-//Save save a configuration data
+// Save save a configuration data
func (s *storeOperate) Save(key config.PathKey, val []byte) error {
-
_, err := s.client.PublishConfig(vo.ConfigParam{
Group: s.groupName,
DataId: string(key),
@@ -182,16 +182,20 @@ func (s *storeOperate) Save(key config.PathKey, val []byte) error {
return err
}
-//Get get a configuration
+// Get get a configuration
func (s *storeOperate) Get(key config.PathKey) ([]byte, error) {
defer s.cfgLock.RUnlock()
s.cfgLock.RLock()
val := []byte(s.confMap[key])
+
+ if env.IsDevelopEnvironment() {
+ log.Infof("[ConfigCenter][nacos] load config content : %#v", string(val))
+ }
return val, nil
}
-//Watch Monitor changes of the key
+// Watch Monitor changes of the key
func (s *storeOperate) Watch(key config.PathKey) (<-chan []byte, error) {
defer s.lock.Unlock()
s.lock.Lock()
@@ -217,12 +221,12 @@ func (s *storeOperate) Watch(key config.PathKey) (<-chan []byte, error) {
return rec, nil
}
-//Name plugin name
+// Name plugin name
func (s *storeOperate) Name() string {
return "nacos"
}
-//Close do close storeOperate
+// Close do close storeOperate
func (s *storeOperate) Close() error {
return nil
}
@@ -250,7 +254,6 @@ func (s *storeOperate) newWatcher(key config.PathKey, client config_client.IConf
s.receivers[config.PathKey(dataId)].ch <- []byte(content)
},
})
-
if err != nil {
return nil, err
}
diff --git a/pkg/config/nacos/nacos_test.go b/pkg/config/nacos/nacos_test.go
index e693c0a1c..0e4a5442b 100644
--- a/pkg/config/nacos/nacos_test.go
+++ b/pkg/config/nacos/nacos_test.go
@@ -72,7 +72,6 @@ type TestNacosClient struct {
}
func newNacosClient() *TestNacosClient {
-
client := &TestNacosClient{
listeners: make(map[string][]vo.Listener),
ch: make(chan vo.ConfigParam, 16),
@@ -82,7 +81,6 @@ func newNacosClient() *TestNacosClient {
go client.doLongPoll()
return client
-
}
func (client *TestNacosClient) doLongPoll() {
@@ -151,7 +149,7 @@ func (client *TestNacosClient) ListenConfig(params vo.ConfigParam) (err error) {
return nil
}
-//CancelListenConfig use to cancel listen config change
+// CancelListenConfig use to cancel listen config change
// dataId require
// group require
// tenant ==>nacos.namespace optional
diff --git a/pkg/constants/env.go b/pkg/constants/env.go
index 6830f0fac..60bd3a876 100644
--- a/pkg/constants/env.go
+++ b/pkg/constants/env.go
@@ -24,8 +24,9 @@ import (
// Environments
const (
- EnvBootstrapPath = "ARANA_BOOTSTRAP_PATH" // bootstrap file path, eg: /etc/arana/bootstrap.yaml
- EnvConfigPath = "ARANA_CONFIG_PATH" // config file path, eg: /etc/arana/config.yaml
+ EnvBootstrapPath = "ARANA_BOOTSTRAP_PATH" // bootstrap file path, eg: /etc/arana/bootstrap.yaml
+ EnvConfigPath = "ARANA_CONFIG_PATH" // config file path, eg: /etc/arana/config.yaml
+ EnvDevelopEnvironment = "ARANA_DEV" // config dev environment
)
// GetConfigSearchPathList returns the default search path list of configuration.
diff --git a/pkg/constants/mysql/type.go b/pkg/constants/mysql/type.go
index 343379fce..b729cca38 100644
--- a/pkg/constants/mysql/type.go
+++ b/pkg/constants/mysql/type.go
@@ -60,14 +60,6 @@ const (
FieldTypeBit
)
-const (
- FieldTypeUint8 FieldType = iota + 0x85
- FieldTypeUint16
- FieldTypeUint24
- FieldTypeUint32
- FieldTypeUint64
-)
-
const (
FieldTypeJSON FieldType = iota + 0xf5
FieldTypeNewDecimal
@@ -118,48 +110,14 @@ var mysqlToType = map[int64]FieldType{
255: FieldTypeGeometry,
}
-// modifyType modifies the vitess type based on the
-// mysql flag. The function checks specific flags based
-// on the type. This allows us to ignore stray flags
-// that MySQL occasionally sets.
-func modifyType(typ FieldType, flags int64) FieldType {
- switch typ {
- case FieldTypeTiny:
- if uint(flags)&UnsignedFlag != 0 {
- return FieldTypeUint8
- }
- return FieldTypeTiny
- case FieldTypeShort:
- if uint(flags)&UnsignedFlag != 0 {
- return FieldTypeUint16
- }
- return FieldTypeShort
- case FieldTypeLong:
- if uint(flags)&UnsignedFlag != 0 {
- return FieldTypeUint32
- }
- return FieldTypeLong
- case FieldTypeLongLong:
- if uint(flags)&UnsignedFlag != 0 {
- return FieldTypeUint64
- }
- return FieldTypeLongLong
- case FieldTypeInt24:
- if uint(flags)&UnsignedFlag != 0 {
- return FieldTypeUint24
- }
- return FieldTypeInt24
- }
- return typ
-}
-
// MySQLToType computes the vitess type from mysql type and flags.
-func MySQLToType(mysqlType, flags int64) (typ FieldType, err error) {
+func MySQLToType(mysqlType, flags int64) (FieldType, error) {
+ _ = flags
result, ok := mysqlToType[mysqlType]
if !ok {
return 0, fmt.Errorf("unsupported type: %d", mysqlType)
}
- return modifyType(result, flags), nil
+ return result, nil
}
// typeToMySQL is the reverse of MysqlToType.
@@ -168,19 +126,14 @@ var typeToMySQL = map[FieldType]struct {
flags int64
}{
FieldTypeTiny: {typ: 1},
- FieldTypeUint8: {typ: 1, flags: int64(UnsignedFlag)},
FieldTypeShort: {typ: 2},
- FieldTypeUint16: {typ: 2, flags: int64(UnsignedFlag)},
FieldTypeLong: {typ: 3},
- FieldTypeUint32: {typ: 3, flags: int64(UnsignedFlag)},
FieldTypeFloat: {typ: 4},
FieldTypeDouble: {typ: 5},
FieldTypeNULL: {typ: 6, flags: int64(BinaryFlag)},
FieldTypeTimestamp: {typ: 7},
FieldTypeLongLong: {typ: 8},
- FieldTypeUint64: {typ: 8, flags: int64(UnsignedFlag)},
FieldTypeInt24: {typ: 9},
- FieldTypeUint24: {typ: 9, flags: int64(UnsignedFlag)},
FieldTypeDate: {typ: 10, flags: int64(BinaryFlag)},
FieldTypeTime: {typ: 11, flags: int64(BinaryFlag)},
FieldTypeDateTime: {typ: 12, flags: int64(BinaryFlag)},
diff --git a/pkg/constants/mysql/type_test.go b/pkg/constants/mysql/type_test.go
index fa46ba001..2ed3dbdea 100644
--- a/pkg/constants/mysql/type_test.go
+++ b/pkg/constants/mysql/type_test.go
@@ -34,11 +34,8 @@ func TestMySQLToType(t *testing.T) {
{0, 0, FieldTypeDecimal},
{0, 32, FieldTypeDecimal},
{1, 0, FieldTypeTiny},
- {1, 32, FieldTypeUint8},
{2, 0, FieldTypeShort},
- {2, 32, FieldTypeUint16},
{3, 0, FieldTypeLong},
- {3, 32, FieldTypeUint32},
{4, 0, FieldTypeFloat},
{4, 32, FieldTypeFloat},
{5, 0, FieldTypeDouble},
@@ -48,9 +45,7 @@ func TestMySQLToType(t *testing.T) {
{7, 0, FieldTypeTimestamp},
{7, 32, FieldTypeTimestamp},
{8, 0, FieldTypeLongLong},
- {8, 32, FieldTypeUint64},
{9, 0, FieldTypeInt24},
- {9, 32, FieldTypeUint24},
{10, 0, FieldTypeDate},
{10, 32, FieldTypeDate},
{11, 0, FieldTypeTime},
@@ -108,18 +103,14 @@ func TestTypeToMySQL(t *testing.T) {
expectedFlags int64
}{
{FieldTypeTiny, int64(1), int64(0)},
- {FieldTypeUint8, int64(1), int64(UnsignedFlag)},
{FieldTypeShort, int64(2), int64(0)},
{FieldTypeLong, int64(3), int64(0)},
- {FieldTypeUint32, int64(3), int64(UnsignedFlag)},
{FieldTypeFloat, int64(4), int64(0)},
{FieldTypeDouble, int64(5), int64(0)},
{FieldTypeNULL, int64(6), int64(BinaryFlag)},
{FieldTypeTimestamp, int64(7), int64(0)},
{FieldTypeLongLong, int64(8), int64(0)},
- {FieldTypeUint64, int64(8), int64(UnsignedFlag)},
{FieldTypeInt24, int64(9), int64(0)},
- {FieldTypeUint24, int64(9), int64(UnsignedFlag)},
{FieldTypeDate, int64(10), int64(BinaryFlag)},
{FieldTypeTime, int64(11), int64(BinaryFlag)},
{FieldTypeDateTime, int64(12), int64(BinaryFlag)},
diff --git a/pkg/dataset/chain.go b/pkg/dataset/chain.go
new file mode 100644
index 000000000..54fb6b1fd
--- /dev/null
+++ b/pkg/dataset/chain.go
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package dataset
+
+import (
+ "github.com/arana-db/arana/pkg/proto"
+)
+
+type pipeOption []func(proto.Dataset) proto.Dataset
+
+func Filter(predicate PredicateFunc) Option {
+ return func(option *pipeOption) {
+ *option = append(*option, func(prev proto.Dataset) proto.Dataset {
+ return FilterDataset{
+ Dataset: prev,
+ Predicate: predicate,
+ }
+ })
+ }
+}
+
+func Map(generateFields FieldsFunc, transform TransformFunc) Option {
+ return func(option *pipeOption) {
+ *option = append(*option, func(dataset proto.Dataset) proto.Dataset {
+ return &TransformDataset{
+ Dataset: dataset,
+ FieldsGetter: generateFields,
+ Transform: transform,
+ }
+ })
+ }
+}
+
+func GroupReduce(groups []OrderByItem, generateFields FieldsFunc, reducer func() Reducer) Option {
+ return func(option *pipeOption) {
+ *option = append(*option, func(dataset proto.Dataset) proto.Dataset {
+ return &GroupDataset{
+ Dataset: dataset,
+ keys: groups,
+ reducer: reducer,
+ fieldFunc: generateFields,
+ }
+ })
+ }
+}
+
+type Option func(*pipeOption)
+
+func Pipe(root proto.Dataset, options ...Option) proto.Dataset {
+ var o pipeOption
+ for _, it := range options {
+ it(&o)
+ }
+
+ next := root
+ for _, it := range o {
+ next = it(next)
+ }
+
+ return next
+}
diff --git a/pkg/dataset/dataset.go b/pkg/dataset/dataset.go
new file mode 100644
index 000000000..e93b2ce7d
--- /dev/null
+++ b/pkg/dataset/dataset.go
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package dataset
+
+import (
+ "github.com/arana-db/arana/pkg/proto"
+)
+
+// PeekableDataset represents a peekable dataset.
+type PeekableDataset interface {
+ proto.Dataset
+ // Peek peeks the next row, but will not consume it.
+ Peek() (proto.Row, error)
+}
+
+type RandomAccessDataset interface {
+ PeekableDataset
+ // Len returns the length of sub-datasets.
+ Len() int
+ // PeekN peeks the next row with specified index.
+ PeekN(index int) (proto.Row, error)
+ // SetNextN force sets the next index of row.
+ SetNextN(index int) error
+}
diff --git a/pkg/dataset/filter.go b/pkg/dataset/filter.go
new file mode 100644
index 000000000..4393209ba
--- /dev/null
+++ b/pkg/dataset/filter.go
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package dataset
+
+import (
+ "github.com/arana-db/arana/pkg/proto"
+)
+
+var _ proto.Dataset = (*FilterDataset)(nil)
+
+type PredicateFunc func(proto.Row) bool
+
+type FilterDataset struct {
+ proto.Dataset
+ Predicate PredicateFunc
+}
+
+func (f FilterDataset) Next() (proto.Row, error) {
+ if f.Predicate == nil {
+ return f.Dataset.Next()
+ }
+
+ row, err := f.Dataset.Next()
+ if err != nil {
+ return nil, err
+ }
+
+ if !f.Predicate(row) {
+ return f.Next()
+ }
+
+ return row, nil
+}
diff --git a/pkg/dataset/filter_test.go b/pkg/dataset/filter_test.go
new file mode 100644
index 000000000..6fac08c37
--- /dev/null
+++ b/pkg/dataset/filter_test.go
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package dataset
+
+import (
+ "database/sql"
+ "fmt"
+ "io"
+ "testing"
+)
+
+import (
+ "github.com/stretchr/testify/assert"
+)
+
+import (
+ consts "github.com/arana-db/arana/pkg/constants/mysql"
+ "github.com/arana-db/arana/pkg/mysql"
+ "github.com/arana-db/arana/pkg/mysql/rows"
+ "github.com/arana-db/arana/pkg/proto"
+)
+
+func TestFilter(t *testing.T) {
+ fields := []proto.Field{
+ mysql.NewField("id", consts.FieldTypeLong),
+ mysql.NewField("name", consts.FieldTypeVarChar),
+ mysql.NewField("gender", consts.FieldTypeLong),
+ }
+ root := &VirtualDataset{
+ Columns: fields,
+ }
+
+ for i := 0; i < 10; i++ {
+ root.Rows = append(root.Rows, rows.NewTextVirtualRow(fields, []proto.Value{
+ int64(i),
+ fmt.Sprintf("fake-name-%d", i),
+ int64(i & 1), // 0=female,1=male
+ }))
+ }
+
+ filtered := Pipe(root, Filter(func(row proto.Row) bool {
+ dest := make([]proto.Value, len(fields))
+ _ = row.Scan(dest)
+ var gender sql.NullInt64
+ _ = gender.Scan(dest[2])
+ assert.True(t, gender.Valid)
+ return gender.Int64 == 1
+ }))
+
+ for {
+ next, err := filtered.Next()
+ if err == io.EOF {
+ break
+ }
+ assert.NoError(t, err)
+
+ dest := make([]proto.Value, len(fields))
+ _ = next.Scan(dest)
+ assert.Equal(t, "1", fmt.Sprint(dest[2]))
+
+ t.Logf("id=%v, name=%v, gender=%v\n", dest[0], dest[1], dest[2])
+ }
+}
diff --git a/pkg/dataset/fuse.go b/pkg/dataset/fuse.go
new file mode 100644
index 000000000..820f92d54
--- /dev/null
+++ b/pkg/dataset/fuse.go
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package dataset
+
+import (
+ "io"
+)
+
+import (
+ "github.com/pkg/errors"
+)
+
+import (
+ "github.com/arana-db/arana/pkg/proto"
+)
+
+var _ proto.Dataset = (*FuseableDataset)(nil)
+
+type GenerateFunc func() (proto.Dataset, error)
+
+type FuseableDataset struct {
+ fields []proto.Field
+ current proto.Dataset
+ generators []GenerateFunc
+}
+
+func (fu *FuseableDataset) Close() error {
+ if fu.current == nil {
+ return nil
+ }
+ if err := fu.current.Close(); err != nil {
+ return errors.WithStack(err)
+ }
+ return nil
+}
+
+func (fu *FuseableDataset) Fields() ([]proto.Field, error) {
+ return fu.fields, nil
+}
+
+func (fu *FuseableDataset) Next() (proto.Row, error) {
+ if fu.current == nil {
+ return nil, io.EOF
+ }
+
+ var (
+ next proto.Row
+ err error
+ )
+
+ if next, err = fu.current.Next(); errors.Is(err, io.EOF) {
+ if err = fu.nextDataset(); err != nil {
+ return nil, err
+ }
+ return fu.Next()
+ }
+
+ if err != nil {
+ return nil, err
+ }
+
+ return next, nil
+}
+
+func (fu *FuseableDataset) nextDataset() error {
+ var err error
+ if err = fu.current.Close(); err != nil {
+ return errors.Wrap(err, "failed to close previous fused dataset")
+ }
+ fu.current = nil
+
+ if len(fu.generators) < 1 {
+ return io.EOF
+ }
+
+ gen := fu.generators[0]
+ fu.generators[0] = nil
+ fu.generators = fu.generators[1:]
+
+ if fu.current, err = gen(); err != nil {
+ return errors.Wrap(err, "failed to close previous fused dataset")
+ }
+
+ return nil
+}
+
+func (fu *FuseableDataset) ToParallel() RandomAccessDataset {
+ generators := make([]GenerateFunc, len(fu.generators)+1)
+ copy(generators[1:], fu.generators)
+ streams := make([]*peekableDataset, len(fu.generators)+1)
+ streams[0] = &peekableDataset{Dataset: fu.current}
+ result := ¶llelDataset{
+ fields: fu.fields,
+ generators: generators,
+ streams: streams,
+ }
+ return result
+}
+
+func Fuse(first GenerateFunc, others ...GenerateFunc) (proto.Dataset, error) {
+ current, err := first()
+ if err != nil {
+ return nil, errors.Wrap(err, "failed to fuse datasets")
+ }
+
+ fields, err := current.Fields()
+ if err != nil {
+ defer func() {
+ _ = current.Close()
+ }()
+ return nil, errors.WithStack(err)
+ }
+
+ return &FuseableDataset{
+ fields: fields,
+ current: current,
+ generators: others,
+ }, nil
+}
diff --git a/pkg/dataset/fuse_test.go b/pkg/dataset/fuse_test.go
new file mode 100644
index 000000000..91cf862f1
--- /dev/null
+++ b/pkg/dataset/fuse_test.go
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package dataset
+
+import (
+ "fmt"
+ "io"
+ "testing"
+)
+
+import (
+ "github.com/stretchr/testify/assert"
+)
+
+import (
+ consts "github.com/arana-db/arana/pkg/constants/mysql"
+ "github.com/arana-db/arana/pkg/mysql"
+ "github.com/arana-db/arana/pkg/mysql/rows"
+ "github.com/arana-db/arana/pkg/proto"
+)
+
+func TestFuse(t *testing.T) {
+ fields := []proto.Field{
+ mysql.NewField("id", consts.FieldTypeLong),
+ mysql.NewField("name", consts.FieldTypeVarChar),
+ }
+
+ generate := func(offset, length int) proto.Dataset {
+ d := &VirtualDataset{
+ Columns: fields,
+ }
+
+ for i := offset; i < offset+length; i++ {
+ d.Rows = append(d.Rows, rows.NewTextVirtualRow(fields, []proto.Value{
+ int64(i),
+ fmt.Sprintf("fake-name-%d", i),
+ }))
+ }
+
+ return d
+ }
+
+ fuse, err := Fuse(
+ func() (proto.Dataset, error) {
+ return generate(0, 3), nil
+ },
+ func() (proto.Dataset, error) {
+ return generate(3, 4), nil
+ },
+ func() (proto.Dataset, error) {
+ return generate(7, 3), nil
+ },
+ )
+
+ assert.NoError(t, err)
+
+ var seq int
+ for {
+ next, err := fuse.Next()
+ if err == io.EOF {
+ break
+ }
+ assert.NoError(t, err)
+ values := make([]proto.Value, len(fields))
+ _ = next.Scan(values)
+ assert.Equal(t, fmt.Sprint(seq), fmt.Sprint(values[0]))
+
+ t.Logf("next: id=%v, name=%v\n", values[0], values[1])
+
+ seq++
+ }
+}
diff --git a/pkg/dataset/group_reduce.go b/pkg/dataset/group_reduce.go
new file mode 100644
index 000000000..cc1a86962
--- /dev/null
+++ b/pkg/dataset/group_reduce.go
@@ -0,0 +1,332 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package dataset
+
+import (
+ "fmt"
+ "io"
+ "reflect"
+ "strings"
+ "sync"
+)
+
+import (
+ "github.com/pkg/errors"
+)
+
+import (
+ "github.com/arana-db/arana/pkg/merge"
+ "github.com/arana-db/arana/pkg/mysql/rows"
+ "github.com/arana-db/arana/pkg/proto"
+ "github.com/arana-db/arana/pkg/util/log"
+)
+
+var _ proto.Dataset = (*GroupDataset)(nil)
+
+// Reducer represents the way to reduce rows.
+type Reducer interface {
+ // Reduce reduces next row.
+ Reduce(next proto.Row) error
+ // Row returns the result row.
+ Row() proto.Row
+}
+
+type AggregateItem struct {
+ agg merge.Aggregator
+ idx int
+}
+
+type AggregateReducer struct {
+ AggItems map[int]merge.Aggregator
+ currentRow proto.Row
+ Fields []proto.Field
+}
+
+func NewGroupReducer(aggFuncMap map[int]func() merge.Aggregator, fields []proto.Field) *AggregateReducer {
+ aggItems := make(map[int]merge.Aggregator)
+ for idx, f := range aggFuncMap {
+ aggItems[idx] = f()
+ }
+ return &AggregateReducer{
+ AggItems: aggItems,
+ currentRow: nil,
+ Fields: fields,
+ }
+}
+
+func (gr *AggregateReducer) Reduce(next proto.Row) error {
+ var (
+ values = make([]proto.Value, len(gr.Fields))
+ result = make([]proto.Value, len(gr.Fields))
+ )
+ err := next.Scan(values)
+ if err != nil {
+ return err
+ }
+
+ for idx, aggregator := range gr.AggItems {
+ aggregator.Aggregate([]interface{}{values[idx]})
+ }
+
+ for i := 0; i < len(values); i++ {
+ if gr.AggItems[i] == nil {
+ result[i] = values[i]
+ } else {
+ aggResult, ok := gr.AggItems[i].GetResult()
+ if !ok {
+ return errors.New("can not aggregate value")
+ }
+ result[i] = aggResult
+ }
+ }
+
+ if next.IsBinary() {
+ gr.currentRow = rows.NewBinaryVirtualRow(gr.Fields, result)
+ } else {
+ gr.currentRow = rows.NewTextVirtualRow(gr.Fields, result)
+ }
+ return nil
+}
+
+func (gr *AggregateReducer) Row() proto.Row {
+ return gr.currentRow
+}
+
+type GroupDataset struct {
+ // Should be an orderedDataset
+ proto.Dataset
+ keys []OrderByItem
+
+ fieldFunc FieldsFunc
+ actualFieldsOnce sync.Once
+ actualFields []proto.Field
+ actualFieldsFailure error
+
+ keyIndexes []int
+ keyIndexesOnce sync.Once
+ keyIndexesFailure error
+
+ reducer func() Reducer
+
+ buf proto.Row
+ eof bool
+}
+
+func (gd *GroupDataset) Close() error {
+ return gd.Dataset.Close()
+}
+
+func (gd *GroupDataset) Fields() ([]proto.Field, error) {
+ gd.actualFieldsOnce.Do(func() {
+ if gd.fieldFunc == nil {
+ gd.actualFields, gd.actualFieldsFailure = gd.Dataset.Fields()
+ return
+ }
+
+ defer func() {
+ gd.fieldFunc = nil
+ }()
+
+ fields, err := gd.Dataset.Fields()
+ if err != nil {
+ gd.actualFieldsFailure = err
+ return
+ }
+ gd.actualFields = gd.fieldFunc(fields)
+ })
+
+ return gd.actualFields, gd.actualFieldsFailure
+}
+
+func (gd *GroupDataset) Next() (proto.Row, error) {
+ if gd.eof {
+ return nil, io.EOF
+ }
+
+ indexes, err := gd.getKeyIndexes()
+ if err != nil {
+ return nil, errors.WithStack(err)
+ }
+
+ var (
+ rowsChan = make(chan proto.Row, 1)
+ errChan = make(chan error, 1)
+ )
+
+ go func() {
+ defer close(rowsChan)
+ gd.consumeUntilDifferent(indexes, rowsChan, errChan)
+ }()
+
+ reducer := gd.reducer()
+
+L:
+ for {
+ select {
+ case next, ok := <-rowsChan:
+ if !ok {
+ break L
+ }
+ if err = reducer.Reduce(next); err != nil {
+ break L
+ }
+ case err = <-errChan:
+ break L
+ }
+ }
+
+ if err != nil {
+ return nil, err
+ }
+
+ return reducer.Row(), nil
+}
+
+func (gd *GroupDataset) consumeUntilDifferent(indexes []int, rowsChan chan<- proto.Row, errChan chan<- error) {
+ var (
+ next proto.Row
+ err error
+ )
+
+ for {
+ next, err = gd.Dataset.Next()
+ if errors.Is(err, io.EOF) {
+ gd.eof = true
+
+ if buf, ok := gd.popBuf(); ok {
+ rowsChan <- buf
+ }
+ break
+ }
+
+ if err != nil {
+ errChan <- err
+ break
+ }
+
+ prev, ok := gd.popBuf()
+ gd.buf = next
+ if !ok {
+ log.Debugf("begin next group: %s", gd.toDebugStr(next))
+ continue
+ }
+
+ ok, err = gd.isSameGroup(indexes, prev, next)
+
+ if err != nil {
+ errChan <- err
+ break
+ }
+
+ rowsChan <- prev
+
+ if !ok {
+ log.Debugf("begin next group: %s", gd.toDebugStr(next))
+ break
+ }
+ }
+}
+
+func (gd *GroupDataset) toDebugStr(next proto.Row) string {
+ var (
+ display []string
+ fields, _ = gd.Dataset.Fields()
+ indexes, _ = gd.getKeyIndexes()
+ dest = make([]proto.Value, len(fields))
+ )
+
+ _ = next.Scan(dest)
+ for _, it := range indexes {
+ display = append(display, fmt.Sprintf("%s:%v", fields[it].Name(), dest[it]))
+ }
+
+ return fmt.Sprintf("[%s]", strings.Join(display, ","))
+}
+
+func (gd *GroupDataset) isSameGroup(indexes []int, prev, next proto.Row) (bool, error) {
+ var (
+ fields, _ = gd.Dataset.Fields()
+ err error
+ )
+
+ // TODO: reduce scan times, maybe cache it.
+ var (
+ dest0 = make([]proto.Value, len(fields))
+ dest1 = make([]proto.Value, len(fields))
+ )
+ if err = prev.Scan(dest0); err != nil {
+ return false, errors.WithStack(err)
+ }
+ if err = next.Scan(dest1); err != nil {
+ return false, errors.WithStack(err)
+ }
+
+ equal := true
+ for _, index := range indexes {
+ // TODO: how to compare equality more effectively?
+ if !reflect.DeepEqual(dest0[index], dest1[index]) {
+ equal = false
+ break
+ }
+ }
+
+ return equal, nil
+}
+
+func (gd *GroupDataset) popBuf() (ret proto.Row, ok bool) {
+ if gd.buf != nil {
+ ret, gd.buf, ok = gd.buf, nil, true
+ }
+ return
+}
+
+// getKeyIndexes computes and holds the indexes of group keys.
+func (gd *GroupDataset) getKeyIndexes() ([]int, error) {
+ gd.keyIndexesOnce.Do(func() {
+ var (
+ fields []proto.Field
+ err error
+ )
+
+ if fields, err = gd.Dataset.Fields(); err != nil {
+ gd.keyIndexesFailure = err
+ return
+ }
+ gd.keyIndexes = make([]int, 0, len(gd.keys))
+ for _, key := range gd.keys {
+ idx := -1
+ for i := 0; i < len(fields); i++ {
+ if fields[i].Name() == key.Column {
+ idx = i
+ break
+ }
+ }
+ if idx == -1 {
+ gd.keyIndexesFailure = fmt.Errorf("cannot find group field '%+v'", key)
+ return
+ }
+ gd.keyIndexes = append(gd.keyIndexes, idx)
+ }
+ })
+
+ if gd.keyIndexesFailure != nil {
+ return nil, gd.keyIndexesFailure
+ }
+
+ return gd.keyIndexes, nil
+}
diff --git a/pkg/dataset/group_reduce_test.go b/pkg/dataset/group_reduce_test.go
new file mode 100644
index 000000000..4ae082f9c
--- /dev/null
+++ b/pkg/dataset/group_reduce_test.go
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package dataset
+
+import (
+ "database/sql"
+ "fmt"
+ "sort"
+ "testing"
+)
+
+import (
+ consts "github.com/arana-db/arana/pkg/constants/mysql"
+ "github.com/arana-db/arana/pkg/mysql"
+ vrows "github.com/arana-db/arana/pkg/mysql/rows"
+ "github.com/arana-db/arana/pkg/proto"
+ "github.com/arana-db/arana/pkg/util/rand2"
+)
+
+type sortBy struct {
+ rows [][]proto.Value
+ less func(a, b []proto.Value) bool
+}
+
+func (s sortBy) Len() int {
+ return len(s.rows)
+}
+
+func (s sortBy) Less(i, j int) bool {
+ return s.less(s.rows[i], s.rows[j])
+}
+
+func (s sortBy) Swap(i, j int) {
+ s.rows[i], s.rows[j] = s.rows[j], s.rows[i]
+}
+
+type fakeReducer struct {
+ fields []proto.Field
+ gender sql.NullInt64
+ cnt int64
+}
+
+func (fa *fakeReducer) Reduce(next proto.Row) error {
+ if !fa.gender.Valid {
+ gender, _ := next.(proto.KeyedRow).Get("gender")
+ fa.gender.Int64, fa.gender.Valid = gender.(int64), true
+ }
+ fa.cnt++
+ return nil
+}
+
+func (fa *fakeReducer) Row() proto.Row {
+ return vrows.NewTextVirtualRow(fa.fields, []proto.Value{
+ fa.gender,
+ fa.cnt,
+ })
+}
+
+func TestGroupReduce(t *testing.T) {
+ fields := []proto.Field{
+ mysql.NewField("id", consts.FieldTypeLong),
+ mysql.NewField("name", consts.FieldTypeVarChar),
+ mysql.NewField("gender", consts.FieldTypeLong),
+ }
+
+ var origin VirtualDataset
+ origin.Columns = fields
+
+ var rows [][]proto.Value
+ for i := 0; i < 1000; i++ {
+ rows = append(rows, []proto.Value{
+ int64(i),
+ fmt.Sprintf("Fake %d", i),
+ rand2.Int63n(2),
+ })
+ }
+
+ s := sortBy{
+ rows: rows,
+ less: func(a, b []proto.Value) bool {
+ return a[2].(int64) < b[2].(int64)
+ },
+ }
+ sort.Sort(s)
+
+ for _, it := range rows {
+ origin.Rows = append(origin.Rows, vrows.NewTextVirtualRow(fields, it))
+ }
+
+ actualFields := []proto.Field{
+ fields[2],
+ mysql.NewField("amount", consts.FieldTypeLong),
+ }
+
+ // Simulate: SELECT gender,COUNT(*) AS amount FROM xxx WHERE ... GROUP BY gender
+ groups := []OrderByItem{{"gender", true}}
+ p := Pipe(&origin,
+ GroupReduce(
+ groups,
+ func(fields []proto.Field) []proto.Field {
+ return actualFields
+ },
+ func() Reducer {
+ return &fakeReducer{
+ fields: actualFields,
+ }
+ },
+ ),
+ )
+
+ for {
+ next, err := p.Next()
+ if err != nil {
+ break
+ }
+ v := make([]proto.Value, len(actualFields))
+ _ = next.Scan(v)
+ t.Logf("next: gender=%v, amount=%v\n", v[0], v[1])
+ }
+}
diff --git a/pkg/dataset/ordered.go b/pkg/dataset/ordered.go
new file mode 100644
index 000000000..262155808
--- /dev/null
+++ b/pkg/dataset/ordered.go
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package dataset
+
+import (
+ "container/heap"
+ "io"
+)
+
+import (
+ "github.com/arana-db/arana/pkg/proto"
+)
+
+type orderedDataset struct {
+ dataset RandomAccessDataset
+ queue *PriorityQueue
+ firstRow bool
+}
+
+func NewOrderedDataset(dataset RandomAccessDataset, items []OrderByItem) proto.Dataset {
+ return &orderedDataset{
+ dataset: dataset,
+ queue: NewPriorityQueue(make([]*RowItem, 0), items),
+ firstRow: true,
+ }
+}
+
+func (or *orderedDataset) Close() error {
+ return or.dataset.Close()
+}
+
+func (or *orderedDataset) Fields() ([]proto.Field, error) {
+ return or.dataset.Fields()
+}
+
+func (or *orderedDataset) Next() (proto.Row, error) {
+ if or.firstRow {
+ n := or.dataset.Len()
+ for i := 0; i < n; i++ {
+ or.dataset.SetNextN(i)
+ row, err := or.dataset.Next()
+ if err == io.EOF {
+ continue
+ } else if err != nil {
+ return nil, err
+ }
+ or.queue.Push(&RowItem{
+ row: row.(proto.KeyedRow),
+ streamIdx: i,
+ })
+ }
+ or.firstRow = false
+ }
+ if or.queue.Len() == 0 {
+ return nil, io.EOF
+ }
+ data := heap.Pop(or.queue)
+
+ item := data.(*RowItem)
+ or.dataset.SetNextN(item.streamIdx)
+ nextRow, err := or.dataset.Next()
+ if err == io.EOF {
+ return item.row, nil
+ } else if err != nil {
+ return nil, err
+ }
+ heap.Push(or.queue, &RowItem{
+ row: nextRow.(proto.KeyedRow),
+ streamIdx: item.streamIdx,
+ })
+
+ return item.row, nil
+}
diff --git a/pkg/dataset/ordered_test.go b/pkg/dataset/ordered_test.go
new file mode 100644
index 000000000..bd25ee0fe
--- /dev/null
+++ b/pkg/dataset/ordered_test.go
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package dataset
+
+import (
+ "testing"
+)
+
+import (
+ "github.com/golang/mock/gomock"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestOrderedDataset(t *testing.T) {
+ ctrl := gomock.NewController(t)
+ defer ctrl.Finish()
+
+ pd := generateFakeParallelDataset(ctrl, 0, 2, 2, 2, 1, 1)
+ items := []OrderByItem{
+ {
+ Column: "id",
+ Desc: false,
+ },
+ {
+ Column: "gender",
+ Desc: true,
+ },
+ }
+
+ od := NewOrderedDataset(pd, items)
+ var pojo fakePojo
+
+ row, err := od.Next()
+ assert.NoError(t, err)
+ assert.NoError(t, scanPojo(row, &pojo))
+ t.Logf("next: %#v\n", pojo)
+ assert.Equal(t, int64(0), pojo.ID)
+
+ row, err = od.Next()
+ assert.NoError(t, err)
+ assert.NoError(t, scanPojo(row, &pojo))
+ t.Logf("next: %#v\n", pojo)
+ assert.Equal(t, int64(1), pojo.ID)
+
+ row, err = od.Next()
+ assert.NoError(t, err)
+ assert.NoError(t, scanPojo(row, &pojo))
+ t.Logf("next: %#v\n", pojo)
+ assert.Equal(t, int64(1), pojo.ID)
+
+ row, err = od.Next()
+ assert.NoError(t, err)
+ assert.NoError(t, scanPojo(row, &pojo))
+ t.Logf("next: %#v\n", pojo)
+ assert.Equal(t, int64(2), pojo.ID)
+
+ row, err = od.Next()
+ assert.NoError(t, err)
+ assert.NoError(t, scanPojo(row, &pojo))
+ t.Logf("next: %#v\n", pojo)
+ assert.Equal(t, int64(3), pojo.ID)
+}
diff --git a/pkg/dataset/parallel.go b/pkg/dataset/parallel.go
new file mode 100644
index 000000000..c3202c191
--- /dev/null
+++ b/pkg/dataset/parallel.go
@@ -0,0 +1,237 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package dataset
+
+import (
+ "io"
+ "sync"
+)
+
+import (
+ "github.com/pkg/errors"
+
+ uatomic "go.uber.org/atomic"
+
+ "golang.org/x/sync/errgroup"
+)
+
+import (
+ "github.com/arana-db/arana/pkg/proto"
+ "github.com/arana-db/arana/pkg/util/log"
+)
+
+var (
+ _ PeekableDataset = (*peekableDataset)(nil)
+ _ RandomAccessDataset = (*parallelDataset)(nil)
+)
+
+type peekableDataset struct {
+ proto.Dataset
+ mu sync.Mutex
+ next proto.Row
+ err error
+}
+
+func (pe *peekableDataset) Next() (proto.Row, error) {
+ var (
+ next proto.Row
+ err error
+ )
+
+ pe.mu.Lock()
+ defer pe.mu.Unlock()
+
+ if next, err, pe.next, pe.err = pe.next, pe.err, nil, nil; next != nil || err != nil {
+ return next, err
+ }
+
+ return pe.Dataset.Next()
+}
+
+func (pe *peekableDataset) Peek() (proto.Row, error) {
+ var (
+ next proto.Row
+ err error
+ )
+
+ pe.mu.Lock()
+ defer pe.mu.Unlock()
+
+ if next, err = pe.next, pe.err; next != nil || err != nil {
+ return next, err
+ }
+
+ pe.next, pe.err = pe.Dataset.Next()
+
+ return pe.next, pe.err
+}
+
+type parallelDataset struct {
+ mu sync.RWMutex
+ fields []proto.Field
+ generators []GenerateFunc
+ streams []*peekableDataset
+ seq uatomic.Uint32
+}
+
+func (pa *parallelDataset) getStream(i int) (*peekableDataset, error) {
+ pa.mu.RLock()
+ if stream := pa.streams[i]; stream != nil {
+ pa.mu.RUnlock()
+ return stream, nil
+ }
+
+ pa.mu.RUnlock()
+
+ pa.mu.Lock()
+ defer pa.mu.Unlock()
+
+ if stream := pa.streams[i]; stream != nil {
+ return stream, nil
+ }
+
+ d, err := pa.generators[i]()
+ if err != nil {
+ return nil, errors.WithStack(err)
+ }
+ pa.streams[i] = &peekableDataset{Dataset: d}
+ pa.generators[i] = nil
+ return pa.streams[i], nil
+}
+
+func (pa *parallelDataset) Peek() (proto.Row, error) {
+ i := pa.seq.Load() % uint32(pa.Len())
+ s, err := pa.getStream(int(i))
+ if err != nil {
+ return nil, errors.WithStack(err)
+ }
+ return s.Peek()
+}
+
+func (pa *parallelDataset) PeekN(index int) (proto.Row, error) {
+ if index < 0 || index >= pa.Len() {
+ return nil, errors.Errorf("index out of range: index=%d, length=%d", index, pa.Len())
+ }
+ s, err := pa.getStream(index)
+ if err != nil {
+ return nil, errors.WithStack(err)
+ }
+ return s.Peek()
+}
+
+func (pa *parallelDataset) Close() error {
+ var g errgroup.Group
+ for i := 0; i < len(pa.streams); i++ {
+ i := i
+ g.Go(func() (err error) {
+ if pa.streams[i] != nil {
+ err = pa.streams[i].Close()
+ }
+
+ if err != nil {
+ log.Errorf("failed to close dataset#%d: %v", i, err)
+ }
+ return
+ })
+ }
+ return g.Wait()
+}
+
+func (pa *parallelDataset) Fields() ([]proto.Field, error) {
+ return pa.fields, nil
+}
+
+func (pa *parallelDataset) Next() (proto.Row, error) {
+ var (
+ s *peekableDataset
+ next proto.Row
+ err error
+ )
+ for j := 0; j < pa.Len(); j++ {
+ i := (pa.seq.Inc() - 1) % uint32(pa.Len())
+
+ if s, err = pa.getStream(int(i)); err != nil {
+ return nil, errors.WithStack(err)
+ }
+
+ next, err = s.Next()
+ if err == nil {
+ break
+ }
+ if errors.Is(err, io.EOF) {
+ err = io.EOF
+ continue
+ }
+
+ return nil, errors.WithStack(err)
+ }
+
+ return next, err
+}
+
+func (pa *parallelDataset) Len() int {
+ return len(pa.streams)
+}
+
+func (pa *parallelDataset) SetNextN(index int) error {
+ if index < 0 || index >= pa.Len() {
+ return errors.Errorf("index out of range: index=%d, length=%d", index, pa.Len())
+ }
+ pa.seq.Store(uint32(index))
+ return nil
+}
+
+// Parallel creates a thread-safe dataset, which can be random-accessed in parallel.
+func Parallel(first GenerateFunc, others ...GenerateFunc) (RandomAccessDataset, error) {
+ current, err := first()
+ if err != nil {
+ return nil, errors.Wrap(err, "failed to create parallel datasets")
+ }
+
+ fields, err := current.Fields()
+ if err != nil {
+ defer func() {
+ _ = current.Close()
+ }()
+ return nil, errors.WithStack(err)
+ }
+
+ generators := make([]GenerateFunc, len(others)+1)
+ for i := 0; i < len(others); i++ {
+ if others[i] == nil {
+ return nil, errors.Errorf("nil dataset detected, index is %d", i+1)
+ }
+ generators[i+1] = others[i]
+ }
+
+ streams := make([]*peekableDataset, len(others)+1)
+ streams[0] = &peekableDataset{Dataset: current}
+
+ return ¶llelDataset{
+ fields: fields,
+ generators: generators,
+ streams: streams,
+ }, nil
+}
+
+// Peekable converts a dataset to a peekable one.
+func Peekable(origin proto.Dataset) PeekableDataset {
+ return &peekableDataset{
+ Dataset: origin,
+ }
+}
diff --git a/pkg/dataset/parallel_test.go b/pkg/dataset/parallel_test.go
new file mode 100644
index 000000000..a8bc50f21
--- /dev/null
+++ b/pkg/dataset/parallel_test.go
@@ -0,0 +1,251 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package dataset
+
+import (
+ "container/list"
+ "database/sql"
+ "fmt"
+ "io"
+ "math"
+ "sort"
+ "sync"
+ "testing"
+)
+
+import (
+ "github.com/golang/mock/gomock"
+
+ "github.com/stretchr/testify/assert"
+)
+
+import (
+ consts "github.com/arana-db/arana/pkg/constants/mysql"
+ "github.com/arana-db/arana/pkg/mysql"
+ "github.com/arana-db/arana/pkg/mysql/rows"
+ "github.com/arana-db/arana/pkg/proto"
+ "github.com/arana-db/arana/pkg/util/rand2"
+ "github.com/arana-db/arana/testdata"
+)
+
+type fakePojo struct {
+ ID int64
+ Name string
+ Gender int64
+}
+
+func scanPojo(row proto.Row, dest *fakePojo) error {
+ s := make([]proto.Value, 3)
+ if err := row.Scan(s); err != nil {
+ return err
+ }
+
+ var (
+ id sql.NullInt64
+ name sql.NullString
+ gender sql.NullInt64
+ )
+ _, _, _ = id.Scan(s[0]), name.Scan(s[1]), gender.Scan(s[2])
+
+ dest.ID = id.Int64
+ dest.Name = name.String
+ dest.Gender = gender.Int64
+
+ return nil
+}
+
+func generateFakeParallelDataset(ctrl *gomock.Controller, pairs ...int) RandomAccessDataset {
+ fields := []proto.Field{
+ mysql.NewField("id", consts.FieldTypeLong),
+ mysql.NewField("name", consts.FieldTypeVarChar),
+ mysql.NewField("gender", consts.FieldTypeLong),
+ }
+
+ getFakeDataset := func(offset, n int) proto.Dataset {
+ l := list.New()
+ for i := 0; i < n; i++ {
+ r := rows.NewTextVirtualRow(fields, []proto.Value{
+ int64(offset + i),
+ fmt.Sprintf("Fake %d", offset+i),
+ rand2.Int63n(2),
+ })
+ l.PushBack(r)
+ }
+
+ ds := testdata.NewMockDataset(ctrl)
+
+ ds.EXPECT().Close().Return(nil).AnyTimes()
+ ds.EXPECT().Fields().Return(fields, nil).AnyTimes()
+ ds.EXPECT().Next().
+ DoAndReturn(func() (proto.Row, error) {
+ head := l.Front()
+ if head == nil {
+ return nil, io.EOF
+ }
+ l.Remove(head)
+ return head.Value.(proto.Row), nil
+ }).
+ AnyTimes()
+
+ return ds
+ }
+
+ var gens []GenerateFunc
+
+ for i := 1; i < len(pairs); i += 2 {
+ offset := pairs[i-1]
+ n := pairs[i]
+ gens = append(gens, func() (proto.Dataset, error) {
+ return getFakeDataset(offset, n), nil
+ })
+ }
+
+ ret, _ := Parallel(gens[0], gens[1:]...)
+ return ret
+}
+
+func TestParallelDataset(t *testing.T) {
+ ctrl := gomock.NewController(t)
+ defer ctrl.Finish()
+
+ pd := generateFakeParallelDataset(ctrl, 0, 3, 3, 3, 6, 3)
+
+ var pojo fakePojo
+
+ row, err := pd.PeekN(0)
+ assert.NoError(t, err)
+ assert.NoError(t, scanPojo(row, &pojo))
+ assert.Equal(t, int64(0), pojo.ID)
+
+ row, err = pd.PeekN(1)
+ assert.NoError(t, err)
+ assert.NoError(t, scanPojo(row, &pojo))
+ assert.Equal(t, int64(3), pojo.ID)
+
+ row, err = pd.PeekN(2)
+ assert.NoError(t, err)
+ assert.NoError(t, scanPojo(row, &pojo))
+ assert.Equal(t, int64(6), pojo.ID)
+
+ err = pd.SetNextN(1)
+ assert.NoError(t, err)
+
+ row, err = pd.Peek()
+ assert.NoError(t, err)
+ assert.NoError(t, scanPojo(row, &pojo))
+ assert.Equal(t, int64(3), pojo.ID)
+
+ row, err = pd.Next()
+ assert.NoError(t, err)
+ assert.NoError(t, scanPojo(row, &pojo))
+ assert.Equal(t, int64(3), pojo.ID)
+
+ var cnt int
+
+ _ = scanPojo(row, &pojo)
+ t.Logf("first: %#v\n", pojo)
+ cnt++
+
+ for {
+ row, err = pd.Next()
+ if err != nil {
+ break
+ }
+
+ _ = scanPojo(row, &pojo)
+ t.Logf("next: %#v\n", pojo)
+ cnt++
+ }
+
+ assert.Equal(t, 9, cnt)
+}
+
+func TestParallelDataset_SortBy(t *testing.T) {
+ ctrl := gomock.NewController(t)
+ defer ctrl.Finish()
+
+ pairs := []int{
+ 5, 8, // offset, size
+ 0, 3,
+ 2, 4,
+ 6, 2,
+ }
+
+ var (
+ d = generateFakeParallelDataset(ctrl, pairs...)
+
+ ids sort.IntSlice
+ )
+
+ for {
+ bingo := -1
+
+ var (
+ wg sync.WaitGroup
+ lock sync.Mutex
+ min int64 = math.MaxInt64
+ )
+
+ wg.Add(d.Len())
+
+ // search the next minimum value in parallel
+ for i := 0; i < d.Len(); i++ {
+ i := i
+ go func() {
+ defer wg.Done()
+
+ r, err := d.PeekN(i)
+ if err == io.EOF {
+ return
+ }
+
+ var pojo fakePojo
+ err = scanPojo(r, &pojo)
+ assert.NoError(t, err)
+
+ lock.Lock()
+ defer lock.Unlock()
+ if pojo.ID < min {
+ min = pojo.ID
+ bingo = i
+ }
+ }()
+ }
+
+ wg.Wait()
+
+ if bingo == -1 {
+ break
+ }
+
+ err := d.SetNextN(bingo)
+ assert.NoError(t, err)
+
+ var pojo fakePojo
+
+ r, err := d.Next()
+ assert.NoError(t, err)
+ err = scanPojo(r, &pojo)
+ assert.NoError(t, err)
+
+ t.Logf("next: %#v\n", pojo)
+ ids = append(ids, int(pojo.ID))
+ }
+
+ assert.True(t, sort.IsSorted(ids), "values should be sorted")
+}
diff --git a/pkg/dataset/priority_queue.go b/pkg/dataset/priority_queue.go
new file mode 100644
index 000000000..d3483e1c2
--- /dev/null
+++ b/pkg/dataset/priority_queue.go
@@ -0,0 +1,182 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package dataset
+
+import (
+ "container/heap"
+ "fmt"
+ "strconv"
+ "time"
+)
+
+import (
+ "golang.org/x/exp/constraints"
+)
+
+import (
+ "github.com/arana-db/arana/pkg/proto"
+)
+
+type OrderByValue struct {
+ OrderValues map[string]interface{}
+}
+
+type RowItem struct {
+ row proto.KeyedRow
+ streamIdx int
+}
+
+type OrderByItem struct {
+ Column string
+ Desc bool
+}
+
+type PriorityQueue struct {
+ rows []*RowItem
+ orderByItems []OrderByItem
+}
+
+func NewPriorityQueue(rows []*RowItem, orderByItems []OrderByItem) *PriorityQueue {
+ pq := &PriorityQueue{
+ rows: rows,
+ orderByItems: orderByItems,
+ }
+ heap.Init(pq)
+ return pq
+}
+
+func (pq *PriorityQueue) Len() int {
+ return len(pq.rows)
+}
+
+func (pq *PriorityQueue) Less(i, j int) bool {
+ orderValues1 := &OrderByValue{
+ OrderValues: make(map[string]interface{}),
+ }
+ orderValues2 := &OrderByValue{
+ OrderValues: make(map[string]interface{}),
+ }
+ if i >= len(pq.rows) || j >= len(pq.rows) {
+ return false
+ }
+ row1 := pq.rows[i]
+ row2 := pq.rows[j]
+ for _, item := range pq.orderByItems {
+ val1, _ := row1.row.Get(item.Column)
+ val2, _ := row2.row.Get(item.Column)
+ orderValues1.OrderValues[item.Column] = val1
+ orderValues2.OrderValues[item.Column] = val2
+ }
+ return compare(orderValues1, orderValues2, pq.orderByItems) < 0
+}
+
+func (pq *PriorityQueue) Swap(i, j int) {
+ pq.rows[i], pq.rows[j] = pq.rows[j], pq.rows[i]
+}
+
+func (pq *PriorityQueue) Push(x interface{}) {
+ item := x.(*RowItem)
+ pq.rows = append(pq.rows, item)
+ pq.update()
+}
+
+func (pq *PriorityQueue) Pop() interface{} {
+ old := *pq
+ n := len(old.rows)
+ if n == 0 {
+ return nil
+ }
+ item := old.rows[n-1]
+ pq.rows = old.rows[0 : n-1]
+ return item
+}
+
+func (pq *PriorityQueue) update() {
+ heap.Fix(pq, pq.Len()-1)
+}
+
+func compare(a *OrderByValue, b *OrderByValue, orderByItems []OrderByItem) int {
+ for _, item := range orderByItems {
+ compare := compareTo(a.OrderValues[item.Column], b.OrderValues[item.Column], item.Desc)
+ if compare == 0 {
+ continue
+ }
+ return compare
+ }
+ return 0
+}
+
+func compareTo(a, b interface{}, desc bool) int {
+ if a == nil && b == nil {
+ return 0
+ }
+ if a == nil {
+ return -1
+ }
+ if b == nil {
+ return 1
+ }
+ // TODO Deal with case sensitive.
+ var (
+ result = 0
+ )
+ switch a.(type) {
+ case string:
+ result = compareValue(fmt.Sprintf("%v", a), fmt.Sprintf("%v", b))
+ case int8, int16, int32, int64:
+ a, _ := strconv.ParseInt(fmt.Sprintf("%v", a), 10, 64)
+ b, _ := strconv.ParseInt(fmt.Sprintf("%v", b), 10, 64)
+ result = compareValue(a, b)
+ case uint8, uint16, uint32, uint64:
+ a, _ := strconv.ParseUint(fmt.Sprintf("%v", a), 10, 64)
+ b, _ := strconv.ParseUint(fmt.Sprintf("%v", b), 10, 64)
+ result = compareValue(a, b)
+ case float32, float64:
+ a, _ := strconv.ParseFloat(fmt.Sprintf("%v", a), 64)
+ b, _ := strconv.ParseFloat(fmt.Sprintf("%v", b), 64)
+ result = compareValue(a, b)
+ case time.Time:
+ result = compareTime(a.(time.Time), b.(time.Time))
+ }
+ if desc {
+ return -1 * result
+ }
+ return result
+}
+
+func compareValue[T constraints.Ordered](a, b T) int {
+ switch {
+ case a > b:
+ return 1
+ case a < b:
+ return -1
+ default:
+ return 0
+ }
+}
+
+func compareTime(a, b time.Time) int {
+ switch {
+ case a.After(b):
+ return 1
+ case a.Before(b):
+ return -1
+ default:
+ return 0
+ }
+}
diff --git a/pkg/dataset/priority_queue_test.go b/pkg/dataset/priority_queue_test.go
new file mode 100644
index 000000000..e41d3b77c
--- /dev/null
+++ b/pkg/dataset/priority_queue_test.go
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package dataset
+
+import (
+ "container/heap"
+ "database/sql"
+ "testing"
+)
+
+import (
+ "github.com/stretchr/testify/assert"
+)
+
+import (
+ consts "github.com/arana-db/arana/pkg/constants/mysql"
+ "github.com/arana-db/arana/pkg/mysql"
+ "github.com/arana-db/arana/pkg/mysql/rows"
+ "github.com/arana-db/arana/pkg/proto"
+)
+
+func TestPriorityQueue(t *testing.T) {
+ fields := []proto.Field{
+ mysql.NewField("id", consts.FieldTypeLong),
+ mysql.NewField("score", consts.FieldTypeLong),
+ }
+ items := []OrderByItem{
+ {"id", false},
+ {"score", true},
+ }
+
+ r1 := &RowItem{rows.NewTextVirtualRow(fields, []proto.Value{
+ int64(1),
+ int64(80),
+ }), 1}
+ r2 := &RowItem{rows.NewTextVirtualRow(fields, []proto.Value{
+ int64(2),
+ int64(75),
+ }), 1}
+ r3 := &RowItem{rows.NewTextVirtualRow(fields, []proto.Value{
+ int64(1),
+ int64(90),
+ }), 1}
+ r4 := &RowItem{rows.NewTextVirtualRow(fields, []proto.Value{
+ int64(3),
+ int64(85),
+ }), 1}
+ pq := NewPriorityQueue([]*RowItem{
+ r1, r2, r3, r4,
+ }, items)
+
+ assertScorePojoEquals(t, fakeScorePojo{
+ id: int64(1),
+ score: int64(90),
+ }, heap.Pop(pq).(*RowItem).row)
+
+ assertScorePojoEquals(t, fakeScorePojo{
+ id: int64(1),
+ score: int64(80),
+ }, heap.Pop(pq).(*RowItem).row)
+
+ assertScorePojoEquals(t, fakeScorePojo{
+ id: int64(2),
+ score: int64(75),
+ }, heap.Pop(pq).(*RowItem).row)
+
+ assertScorePojoEquals(t, fakeScorePojo{
+ id: int64(3),
+ score: int64(85),
+ }, heap.Pop(pq).(*RowItem).row)
+}
+
+func assertScorePojoEquals(t *testing.T, expected fakeScorePojo, actual proto.Row) {
+ var pojo fakeScorePojo
+ err := scanScorePojo(actual, &pojo)
+ assert.NoError(t, err)
+ assert.Equal(t, expected, pojo)
+}
+
+type fakeScorePojo struct {
+ id int64
+ score int64
+}
+
+func scanScorePojo(row proto.Row, dest *fakeScorePojo) error {
+ s := make([]proto.Value, 2)
+ if err := row.Scan(s); err != nil {
+ return err
+ }
+
+ var (
+ id sql.NullInt64
+ score sql.NullInt64
+ )
+ _, _ = id.Scan(s[0]), score.Scan(s[1])
+
+ dest.id = id.Int64
+ dest.score = score.Int64
+
+ return nil
+}
diff --git a/pkg/dataset/transform.go b/pkg/dataset/transform.go
new file mode 100644
index 000000000..43d4a2276
--- /dev/null
+++ b/pkg/dataset/transform.go
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package dataset
+
+import (
+ "sync"
+)
+
+import (
+ "github.com/pkg/errors"
+)
+
+import (
+ "github.com/arana-db/arana/pkg/proto"
+)
+
+var _ proto.Dataset = (*TransformDataset)(nil)
+
+type (
+ FieldsFunc func([]proto.Field) []proto.Field
+ TransformFunc func(proto.Row) (proto.Row, error)
+)
+
+type TransformDataset struct {
+ proto.Dataset
+ FieldsGetter FieldsFunc
+ Transform TransformFunc
+
+ actualFieldsOnce sync.Once
+ actualFields []proto.Field
+ actualFieldsFailure error
+}
+
+func (td *TransformDataset) Fields() ([]proto.Field, error) {
+ td.actualFieldsOnce.Do(func() {
+ origin, err := td.Dataset.Fields()
+ if err != nil {
+ td.actualFieldsFailure = err
+ return
+ }
+ if td.FieldsGetter == nil {
+ td.actualFields = origin
+ return
+ }
+
+ td.actualFields = td.FieldsGetter(origin)
+ })
+
+ return td.actualFields, td.actualFieldsFailure
+}
+
+func (td *TransformDataset) Next() (proto.Row, error) {
+ if td.Transform == nil {
+ return td.Dataset.Next()
+ }
+
+ var (
+ row proto.Row
+ err error
+ )
+
+ if row, err = td.Dataset.Next(); err != nil {
+ return nil, err
+ }
+
+ if row, err = td.Transform(row); err != nil {
+ return nil, errors.Wrap(err, "failed to transform dataset")
+ }
+
+ return row, nil
+}
diff --git a/pkg/dataset/transform_test.go b/pkg/dataset/transform_test.go
new file mode 100644
index 000000000..e5a714813
--- /dev/null
+++ b/pkg/dataset/transform_test.go
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package dataset
+
+import (
+ "fmt"
+ "io"
+ "testing"
+)
+
+import (
+ "github.com/stretchr/testify/assert"
+)
+
+import (
+ consts "github.com/arana-db/arana/pkg/constants/mysql"
+ "github.com/arana-db/arana/pkg/mysql"
+ "github.com/arana-db/arana/pkg/mysql/rows"
+ "github.com/arana-db/arana/pkg/proto"
+ "github.com/arana-db/arana/pkg/util/rand2"
+)
+
+func TestTransform(t *testing.T) {
+ fields := []proto.Field{
+ mysql.NewField("id", consts.FieldTypeLong),
+ mysql.NewField("name", consts.FieldTypeVarChar),
+ mysql.NewField("level", consts.FieldTypeLong),
+ }
+
+ root := &VirtualDataset{
+ Columns: fields,
+ }
+
+ for i := 0; i < 10; i++ {
+ root.Rows = append(root.Rows, rows.NewTextVirtualRow(fields, []proto.Value{
+ int64(i),
+ fmt.Sprintf("fake-name-%d", i),
+ rand2.Int63n(10),
+ }))
+ }
+
+ transformed := Pipe(root, Map(func(fields []proto.Field) []proto.Field {
+ return fields
+ }, func(row proto.Row) (proto.Row, error) {
+ dest := make([]proto.Value, len(fields))
+ _ = row.Scan(dest)
+ dest[2] = int64(100)
+ return rows.NewBinaryVirtualRow(fields, dest), nil
+ }))
+
+ for {
+ next, err := transformed.Next()
+ if err == io.EOF {
+ break
+ }
+
+ assert.NoError(t, err)
+
+ dest := make([]proto.Value, len(fields))
+ _ = next.Scan(dest)
+
+ assert.Equal(t, "100", fmt.Sprint(dest[2]))
+ }
+}
diff --git a/pkg/mysql/result_test.go b/pkg/dataset/virtual.go
similarity index 59%
rename from pkg/mysql/result_test.go
rename to pkg/dataset/virtual.go
index ff3a287da..de73474da 100644
--- a/pkg/mysql/result_test.go
+++ b/pkg/dataset/virtual.go
@@ -15,36 +15,40 @@
* limitations under the License.
*/
-package mysql
+package dataset
import (
- "testing"
+ "io"
)
import (
- "github.com/stretchr/testify/assert"
+ "github.com/arana-db/arana/pkg/proto"
)
-func TestLastInsertId(t *testing.T) {
- result := createResult()
- insertId, err := result.LastInsertId()
- assert.Equal(t, uint64(2000), insertId)
- assert.Nil(t, err)
+var _ proto.Dataset = (*VirtualDataset)(nil)
+
+type VirtualDataset struct {
+ Columns []proto.Field
+ Rows []proto.Row
+}
+
+func (cu *VirtualDataset) Close() error {
+ return nil
}
-func TestRowsAffected(t *testing.T) {
- result := createResult()
- affectedRows, err := result.RowsAffected()
- assert.Equal(t, uint64(10), affectedRows)
- assert.Nil(t, err)
+func (cu *VirtualDataset) Fields() ([]proto.Field, error) {
+ return cu.Columns, nil
}
-func createResult() *Result {
- result := &Result{
- Fields: nil,
- AffectedRows: uint64(10),
- InsertId: uint64(2000),
- Rows: nil,
+func (cu *VirtualDataset) Next() (proto.Row, error) {
+ if len(cu.Rows) < 1 {
+ return nil, io.EOF
}
- return result
+
+ next := cu.Rows[0]
+
+ cu.Rows[0] = nil
+ cu.Rows = cu.Rows[1:]
+
+ return next, nil
}
diff --git a/pkg/executor/redirect.go b/pkg/executor/redirect.go
index 74fc62d35..3758c9f66 100644
--- a/pkg/executor/redirect.go
+++ b/pkg/executor/redirect.go
@@ -29,14 +29,18 @@ import (
"github.com/arana-db/parser/ast"
"github.com/pkg/errors"
+
+ "go.opentelemetry.io/otel"
+ "go.opentelemetry.io/otel/trace"
)
import (
mConstants "github.com/arana-db/arana/pkg/constants/mysql"
"github.com/arana-db/arana/pkg/metrics"
- "github.com/arana-db/arana/pkg/mysql"
mysqlErrors "github.com/arana-db/arana/pkg/mysql/errors"
"github.com/arana-db/arana/pkg/proto"
+ "github.com/arana-db/arana/pkg/proto/hint"
+ "github.com/arana-db/arana/pkg/resultx"
"github.com/arana-db/arana/pkg/runtime"
rcontext "github.com/arana-db/arana/pkg/runtime/context"
"github.com/arana-db/arana/pkg/security"
@@ -44,6 +48,8 @@ import (
)
var (
+ Tracer = otel.Tracer("Executor")
+
errMissingTx = stdErrors.New("no transaction found")
errNoDatabaseSelected = mysqlErrors.NewSQLError(mConstants.ERNoDb, mConstants.SSNoDatabaseSelected, "No database selected")
)
@@ -121,6 +127,12 @@ func (executor *RedirectExecutor) ExecuteFieldList(ctx *proto.Context) ([]proto.
return nil, errors.WithStack(err)
}
+ if vt, ok := rt.Namespace().Rule().VTable(table); ok {
+ if _, atomTable, exist := vt.Topology().Render(0, 0); exist {
+ table = atomTable
+ }
+ }
+
db := rt.Namespace().DB0(ctx.Context)
if db == nil {
return nil, errors.New("cannot get physical backend connection")
@@ -130,6 +142,10 @@ func (executor *RedirectExecutor) ExecuteFieldList(ctx *proto.Context) ([]proto.
}
func (executor *RedirectExecutor) ExecutorComQuery(ctx *proto.Context) (proto.Result, uint16, error) {
+ var span trace.Span
+ ctx.Context, span = Tracer.Start(ctx.Context, "ExecutorComQuery")
+ defer span.End()
+
var (
schemaless bool // true if schema is not specified
err error
@@ -138,10 +154,20 @@ func (executor *RedirectExecutor) ExecutorComQuery(ctx *proto.Context) (proto.Re
p := parser.New()
query := ctx.GetQuery()
start := time.Now()
- act, err := p.ParseOneStmt(query, "", "")
+ act, hts, err := p.ParseOneStmtHints(query, "", "")
if err != nil {
- return nil, 0, err
+ return nil, 0, errors.WithStack(err)
}
+
+ var hints []*hint.Hint
+ for _, next := range hts {
+ var h *hint.Hint
+ if h, err = hint.Parse(next); err != nil {
+ return nil, 0, err
+ }
+ hints = append(hints, h)
+ }
+
metrics.ParserDuration.Observe(time.Since(start).Seconds())
log.Debugf("ComQuery: %s", query)
@@ -157,6 +183,7 @@ func (executor *RedirectExecutor) ExecutorComQuery(ctx *proto.Context) (proto.Re
}
ctx.Stmt = &proto.Stmt{
+ Hints: hints,
StmtNode: act,
}
@@ -181,7 +208,7 @@ func (executor *RedirectExecutor) ExecutorComQuery(ctx *proto.Context) (proto.Re
var tx proto.Tx
if tx, err = rt.Begin(ctx); err == nil {
executor.putTx(ctx, tx)
- res = &mysql.Result{}
+ res = resultx.New()
}
}
case *ast.CommitStmt:
@@ -231,14 +258,12 @@ func (executor *RedirectExecutor) ExecutorComQuery(ctx *proto.Context) (proto.Re
}
case *ast.ShowStmt:
allowSchemaless := func(stmt *ast.ShowStmt) bool {
- if stmt.Tp == ast.ShowDatabases {
- return true
- }
- if stmt.Tp == ast.ShowVariables {
+ switch stmt.Tp {
+ case ast.ShowDatabases, ast.ShowVariables, ast.ShowTopology:
return true
+ default:
+ return false
}
-
- return false
}
if !schemaless || allowSchemaless(stmt) { // only SHOW DATABASES is allowed in schemaless mode
@@ -246,24 +271,10 @@ func (executor *RedirectExecutor) ExecutorComQuery(ctx *proto.Context) (proto.Re
} else {
err = errNoDatabaseSelected
}
- case *ast.TruncateTableStmt:
- if schemaless {
- err = errNoDatabaseSelected
- } else {
- res, warn, err = rt.Execute(ctx)
- }
- case *ast.DropTableStmt:
- if schemaless {
- err = errNoDatabaseSelected
- } else {
- res, warn, err = rt.Execute(ctx)
- }
- case *ast.ExplainStmt:
- if schemaless {
- err = errNoDatabaseSelected
- } else {
- res, warn, err = rt.Execute(ctx)
- }
+ case *ast.TruncateTableStmt, *ast.DropTableStmt, *ast.ExplainStmt, *ast.DropIndexStmt, *ast.CreateIndexStmt:
+ res, warn, err = executeStmt(ctx, schemaless, rt)
+ case *ast.DropTriggerStmt:
+ res, warn, err = rt.Execute(ctx)
default:
if schemaless {
err = errNoDatabaseSelected
@@ -276,7 +287,6 @@ func (executor *RedirectExecutor) ExecutorComQuery(ctx *proto.Context) (proto.Re
res, warn, err = rt.Execute(ctx)
}
}
-
}
executor.doPostFilter(ctx, res)
@@ -284,6 +294,13 @@ func (executor *RedirectExecutor) ExecutorComQuery(ctx *proto.Context) (proto.Re
return res, warn, err
}
+func executeStmt(ctx *proto.Context, schemaless bool, rt runtime.Runtime) (proto.Result, uint16, error) {
+ if schemaless {
+ return nil, 0, errNoDatabaseSelected
+ }
+ return rt.Execute(ctx)
+}
+
func (executor *RedirectExecutor) ExecutorComStmtExecute(ctx *proto.Context) (proto.Result, uint16, error) {
var (
executable proto.Executable
diff --git a/pkg/merge/aggregator.go b/pkg/merge/aggregator.go
index 984cda1be..169ab5836 100644
--- a/pkg/merge/aggregator.go
+++ b/pkg/merge/aggregator.go
@@ -23,6 +23,5 @@ import (
type Aggregator interface {
Aggregate(values []interface{})
-
GetResult() (*gxbig.Decimal, bool)
}
diff --git a/pkg/merge/aggregator/init.go b/pkg/merge/aggregator/init.go
new file mode 100644
index 000000000..e68452a3c
--- /dev/null
+++ b/pkg/merge/aggregator/init.go
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aggregator
+
+import (
+ "fmt"
+)
+
+import (
+ "github.com/arana-db/arana/pkg/merge"
+)
+
+var aggregatorMap = make(map[string]func() merge.Aggregator)
+
+func init() {
+ aggregatorMap["MAX"] = func() merge.Aggregator { return &MaxAggregator{} }
+ aggregatorMap["MIN"] = func() merge.Aggregator { return &MinAggregator{} }
+ aggregatorMap["COUNT"] = func() merge.Aggregator { return &AddAggregator{} }
+ aggregatorMap["SUM"] = func() merge.Aggregator { return &AddAggregator{} }
+}
+
+func GetAggFromName(name string) func() merge.Aggregator {
+ if agg, ok := aggregatorMap[name]; ok {
+ return agg
+ }
+ panic(fmt.Errorf("aggregator %s not support yet", name))
+}
diff --git a/pkg/merge/aggregator/load_agg.go b/pkg/merge/aggregator/load_agg.go
new file mode 100644
index 000000000..4c416044e
--- /dev/null
+++ b/pkg/merge/aggregator/load_agg.go
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aggregator
+
+import (
+ "github.com/arana-db/arana/pkg/merge"
+ "github.com/arana-db/arana/pkg/runtime/ast"
+)
+
+func LoadAggs(fields []ast.SelectElement) map[int]func() merge.Aggregator {
+ var aggMap = make(map[int]func() merge.Aggregator)
+ enter := func(i int, n *ast.AggrFunction) {
+ if n == nil {
+ return
+ }
+ aggMap[i] = GetAggFromName(n.Name())
+ }
+
+ for i, field := range fields {
+ if field == nil {
+ continue
+ }
+ if f, ok := field.(*ast.SelectElementFunction); ok {
+ enter(i, f.Function().(*ast.AggrFunction))
+ }
+ }
+
+ return aggMap
+}
diff --git a/pkg/merge/aggregator/max_aggregator.go b/pkg/merge/aggregator/max_aggregator.go
index 2ae8c5de3..e37e5e6fb 100644
--- a/pkg/merge/aggregator/max_aggregator.go
+++ b/pkg/merge/aggregator/max_aggregator.go
@@ -22,7 +22,7 @@ import (
)
type MaxAggregator struct {
- //max decimal.Decimal
+ // max decimal.Decimal
max *gxbig.Decimal
init bool
}
diff --git a/pkg/merge/impl/group_by/group_by_stream_merge_rows.go b/pkg/merge/impl/group/group_stream_merge_rows.go
similarity index 92%
rename from pkg/merge/impl/group_by/group_by_stream_merge_rows.go
rename to pkg/merge/impl/group/group_stream_merge_rows.go
index 8e52c530c..48ce59cfd 100644
--- a/pkg/merge/impl/group_by/group_by_stream_merge_rows.go
+++ b/pkg/merge/impl/group/group_stream_merge_rows.go
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package group_by
+package group
import (
"container/heap"
@@ -118,16 +118,16 @@ func (s *GroupByStreamMergeRows) merge() proto.Row {
}
}
- row := testdata.NewMockRow(gomock.NewController(nil))
+ row := testdata.NewMockKeyedRow(gomock.NewController(nil))
for _, sel := range s.stmt.Selects {
if _, ok := aggrMap[sel.Column]; ok {
res, _ := aggrMap[sel.Column].GetResult()
// TODO use row encode() to build a new row result
val, _ := res.ToInt()
- row.EXPECT().GetColumnValue(sel.Column).Return(val, nil).AnyTimes()
+ row.EXPECT().Get(sel.Column).Return(val, nil).AnyTimes()
} else {
- res, _ := currentRow.GetColumnValue(sel.Column)
- row.EXPECT().GetColumnValue(sel.Column).Return(res, nil).AnyTimes()
+ res, _ := currentRow.(proto.KeyedRow).Get(sel.Column)
+ row.EXPECT().Get(sel.Column).Return(res, nil).AnyTimes()
}
}
return row
@@ -136,7 +136,7 @@ func (s *GroupByStreamMergeRows) merge() proto.Row {
// todo not support Avg method yet
func (s *GroupByStreamMergeRows) aggregate(aggrMap map[string]merge.Aggregator, row proto.Row) {
for k, v := range aggrMap {
- val, err := row.GetColumnValue(k)
+ val, err := row.(proto.KeyedRow).Get(k)
if err != nil {
panic(err.Error())
}
@@ -145,7 +145,7 @@ func (s *GroupByStreamMergeRows) aggregate(aggrMap map[string]merge.Aggregator,
}
func (s *GroupByStreamMergeRows) hasNext() bool {
- //s.currentMergedRow = nil
+ // s.currentMergedRow = nil
s.currentRow = nil
if s.queue.Len() == 0 {
return false
diff --git a/pkg/merge/impl/group_by/group_by_stream_merge_rows_test.go b/pkg/merge/impl/group/group_stream_merge_rows_test.go
similarity index 87%
rename from pkg/merge/impl/group_by/group_by_stream_merge_rows_test.go
rename to pkg/merge/impl/group/group_stream_merge_rows_test.go
index 2ff735488..e589fda32 100644
--- a/pkg/merge/impl/group_by/group_by_stream_merge_rows_test.go
+++ b/pkg/merge/impl/group/group_stream_merge_rows_test.go
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package group_by
+package group
import (
"testing"
@@ -44,7 +44,6 @@ type (
)
func TestGroupByStreamMergeRows(t *testing.T) {
-
stmt := MergeRowStatement{
OrderBys: []merge.OrderByItem{
{
@@ -77,13 +76,16 @@ func TestGroupByStreamMergeRows(t *testing.T) {
if row == nil {
break
}
- v1, _ := row.GetColumnValue(countScore)
- v2, _ := row.GetColumnValue(age)
+ v1, _ := row.(proto.KeyedRow).Get(countScore)
+ v2, _ := row.(proto.KeyedRow).Get(age)
res = append(res, student{countScore: v1.(int64), age: v2.(int64)})
}
assert.Equal(t, []student{
- {countScore: 175, age: 81}, {countScore: 160, age: 70}, {countScore: 75, age: 68},
- {countScore: 143, age: 60}, {countScore: 70, age: 40},
+ {countScore: 175, age: 81},
+ {countScore: 160, age: 70},
+ {countScore: 75, age: 68},
+ {countScore: 143, age: 60},
+ {countScore: 70, age: 40},
}, res)
}
@@ -98,9 +100,9 @@ func buildMergeRows(t *testing.T, vals [][]student) []*merge.MergeRows {
func buildMergeRow(t *testing.T, vals []student) *merge.MergeRows {
rows := make([]proto.Row, 0)
for _, val := range vals {
- row := testdata.NewMockRow(gomock.NewController(t))
+ row := testdata.NewMockKeyedRow(gomock.NewController(t))
for k, v := range val {
- row.EXPECT().GetColumnValue(k).Return(v, nil).AnyTimes()
+ row.EXPECT().Get(k).Return(v, nil).AnyTimes()
}
rows = append(rows, row)
}
diff --git a/pkg/merge/impl/group_by/group_by_value.go b/pkg/merge/impl/group/group_value.go
similarity index 96%
rename from pkg/merge/impl/group_by/group_by_value.go
rename to pkg/merge/impl/group/group_value.go
index 747583e5b..991c61422 100644
--- a/pkg/merge/impl/group_by/group_by_value.go
+++ b/pkg/merge/impl/group/group_value.go
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package group_by
+package group
import (
"fmt"
@@ -41,7 +41,7 @@ func buildGroupValues(groupByColumns []string, row proto.Row) []interface{} {
values := make([]interface{}, 0)
for _, column := range groupByColumns {
- value, err := row.GetColumnValue(column)
+ value, err := row.(proto.KeyedRow).Get(column)
if err != nil {
panic("get column value error:" + err.Error())
}
diff --git a/pkg/merge/merge_rows.go b/pkg/merge/merge_rows.go
index 53bf60ed5..80c353ca7 100644
--- a/pkg/merge/merge_rows.go
+++ b/pkg/merge/merge_rows.go
@@ -47,6 +47,10 @@ func (s *MergeRows) Next() proto.Row {
return result
}
+func (s *MergeRows) HasNext() bool {
+ return s.currentRowIndex < len(s.rows)-1
+}
+
func (s *MergeRows) GetCurrentRow() proto.Row {
if len(s.rows) == 0 || s.currentRowIndex > len(s.rows) {
return nil
diff --git a/pkg/merge/merge_rows_test.go b/pkg/merge/merge_rows_test.go
index 0bf705926..3ed14c374 100644
--- a/pkg/merge/merge_rows_test.go
+++ b/pkg/merge/merge_rows_test.go
@@ -51,8 +51,8 @@ func TestGetCurrentRow(t *testing.T) {
if row == nil {
break
}
- v1, _ := rows.GetCurrentRow().GetColumnValue(score)
- v2, _ := rows.GetCurrentRow().GetColumnValue(age)
+ v1, _ := rows.GetCurrentRow().(proto.KeyedRow).Get(score)
+ v2, _ := rows.GetCurrentRow().(proto.KeyedRow).Get(age)
res = append(res, student{score: v1.(int64), age: v2.(int64)})
}
@@ -62,9 +62,9 @@ func TestGetCurrentRow(t *testing.T) {
func buildMergeRow(t *testing.T, vals []student) *MergeRows {
rows := make([]proto.Row, 0)
for _, val := range vals {
- row := testdata.NewMockRow(gomock.NewController(t))
+ row := testdata.NewMockKeyedRow(gomock.NewController(t))
for k, v := range val {
- row.EXPECT().GetColumnValue(k).Return(v, nil).AnyTimes()
+ row.EXPECT().Get(k).Return(v, nil).AnyTimes()
}
rows = append(rows, row)
}
diff --git a/pkg/merge/priority_queue.go b/pkg/merge/priority_queue.go
index a6b75ae0d..014cc5a9c 100644
--- a/pkg/merge/priority_queue.go
+++ b/pkg/merge/priority_queue.go
@@ -22,6 +22,10 @@ import (
"fmt"
)
+import (
+ "github.com/arana-db/arana/pkg/proto"
+)
+
type PriorityQueue struct {
results []*MergeRows
orderByItems []OrderByItem
@@ -47,8 +51,11 @@ func (pq *PriorityQueue) Len() int {
func (pq *PriorityQueue) Less(i, j int) bool {
for _, item := range pq.orderByItems {
- val1, _ := pq.results[i].GetCurrentRow().GetColumnValue(item.Column)
- val2, _ := pq.results[j].GetCurrentRow().GetColumnValue(item.Column)
+ rowi := pq.results[i].GetCurrentRow().(proto.KeyedRow)
+ rowj := pq.results[j].GetCurrentRow().(proto.KeyedRow)
+
+ val1, _ := rowi.Get(item.Column)
+ val2, _ := rowj.Get(item.Column)
if val1 != val2 {
if item.Desc {
return fmt.Sprintf("%v", val1) > fmt.Sprintf("%v", val2)
diff --git a/pkg/merge/priority_queue_test.go b/pkg/merge/priority_queue_test.go
index caadf636c..9e4840f2e 100644
--- a/pkg/merge/priority_queue_test.go
+++ b/pkg/merge/priority_queue_test.go
@@ -26,6 +26,10 @@ import (
"github.com/stretchr/testify/assert"
)
+import (
+ "github.com/arana-db/arana/pkg/proto"
+)
+
func TestPriorityQueue(t *testing.T) {
rows := buildMergeRows(t, [][]student{
{{score: 85, age: 72}, {score: 75, age: 70}, {score: 65, age: 50}},
@@ -43,18 +47,25 @@ func TestPriorityQueue(t *testing.T) {
break
}
row := heap.Pop(&queue).(*MergeRows)
- v1, _ := row.GetCurrentRow().GetColumnValue(score)
- v2, _ := row.GetCurrentRow().GetColumnValue(age)
+ v1, _ := row.GetCurrentRow().(proto.KeyedRow).Get(score)
+ v2, _ := row.GetCurrentRow().(proto.KeyedRow).Get(age)
res = append(res, student{score: v1.(int64), age: v2.(int64)})
if row.Next() != nil {
queue.Push(row)
}
}
assert.Equal(t, []student{
- {score: 90, age: 72}, {score: 85, age: 72}, {score: 85, age: 70},
- {score: 78, age: 60}, {score: 75, age: 80}, {score: 75, age: 70},
- {score: 75, age: 68}, {score: 70, age: 40}, {score: 65, age: 80},
- {score: 65, age: 50}, {score: 60, age: 40},
+ {score: 90, age: 72},
+ {score: 85, age: 72},
+ {score: 85, age: 70},
+ {score: 78, age: 60},
+ {score: 75, age: 80},
+ {score: 75, age: 70},
+ {score: 75, age: 68},
+ {score: 70, age: 40},
+ {score: 65, age: 80},
+ {score: 65, age: 50},
+ {score: 60, age: 40},
}, res)
}
@@ -75,18 +86,26 @@ func TestPriorityQueue2(t *testing.T) {
break
}
row := heap.Pop(&queue).(*MergeRows)
- v1, _ := row.GetCurrentRow().GetColumnValue(score)
- v2, _ := row.GetCurrentRow().GetColumnValue(age)
+
+ v1, _ := row.GetCurrentRow().(proto.KeyedRow).Get(score)
+ v2, _ := row.GetCurrentRow().(proto.KeyedRow).Get(age)
res = append(res, student{score: v1.(int64), age: v2.(int64)})
if row.Next() != nil {
queue.Push(row)
}
}
assert.Equal(t, []student{
- {score: 60, age: 40}, {score: 65, age: 50}, {score: 65, age: 80},
- {score: 70, age: 40}, {score: 75, age: 68}, {score: 75, age: 70},
- {score: 75, age: 80}, {score: 78, age: 60}, {score: 85, age: 70},
- {score: 85, age: 72}, {score: 90, age: 72},
+ {score: 60, age: 40},
+ {score: 65, age: 50},
+ {score: 65, age: 80},
+ {score: 70, age: 40},
+ {score: 75, age: 68},
+ {score: 75, age: 70},
+ {score: 75, age: 80},
+ {score: 78, age: 60},
+ {score: 85, age: 70},
+ {score: 85, age: 72},
+ {score: 90, age: 72},
}, res)
}
diff --git a/pkg/metrics/metrics.go b/pkg/metrics/metrics.go
index 691898ad4..97a1a8463 100644
--- a/pkg/metrics/metrics.go
+++ b/pkg/metrics/metrics.go
@@ -27,7 +27,7 @@ var (
Subsystem: "parser",
Name: "duration_seconds",
Help: "histogram of processing time (s) in parse SQL.",
- Buckets: prometheus.ExponentialBuckets(0.00004, 2, 25), //40us ~ 11min
+ Buckets: prometheus.ExponentialBuckets(0.00004, 2, 25), // 40us ~ 11min
})
OptimizeDuration = prometheus.NewHistogram(prometheus.HistogramOpts{
@@ -35,7 +35,7 @@ var (
Subsystem: "optimizer",
Name: "duration_seconds",
Help: "histogram of processing time (s) in optimizer.",
- Buckets: prometheus.ExponentialBuckets(0.00004, 2, 25), //40us ~ 11min
+ Buckets: prometheus.ExponentialBuckets(0.00004, 2, 25), // 40us ~ 11min
})
ExecuteDuration = prometheus.NewHistogram(prometheus.HistogramOpts{
@@ -43,7 +43,7 @@ var (
Subsystem: "executor",
Name: "duration_seconds",
Help: "histogram of processing time (s) in execute.",
- Buckets: prometheus.ExponentialBuckets(0.0001, 2, 30), //100us ~ 15h,
+ Buckets: prometheus.ExponentialBuckets(0.0001, 2, 30), // 100us ~ 15h,
})
)
diff --git a/pkg/mysql/client.go b/pkg/mysql/client.go
index cb5679360..307df3839 100644
--- a/pkg/mysql/client.go
+++ b/pkg/mysql/client.go
@@ -28,7 +28,6 @@ import (
"math/big"
"net"
"net/url"
- "strconv"
"strings"
"time"
)
@@ -37,6 +36,7 @@ import (
"github.com/arana-db/arana/pkg/constants/mysql"
err2 "github.com/arana-db/arana/pkg/mysql/errors"
"github.com/arana-db/arana/pkg/proto"
+ "github.com/arana-db/arana/pkg/util/bytefmt"
"github.com/arana-db/arana/pkg/util/log"
"github.com/arana-db/arana/third_party/pools"
)
@@ -414,10 +414,11 @@ func parseDSNParams(cfg *Config, params string) (err error) {
return
}
case "maxAllowedPacket":
- cfg.MaxAllowedPacket, err = strconv.Atoi(value)
+ byteSize, err := bytefmt.ToBytes(value)
if err != nil {
- return
+ return err
}
+ cfg.MaxAllowedPacket = int(byteSize)
default:
// lazy init
if cfg.Params == nil {
@@ -642,7 +643,7 @@ func (conn *BackendConnection) parseInitialHandshakePacket(data []byte) (uint32,
if !ok {
return 0, nil, "", err2.NewSQLError(mysql.CRMalformedPacket, mysql.SSUnknownSQLState, "parseInitialHandshakePacket: packet has no capability flags (lower 2 bytes)")
}
- var capabilities = uint32(capLower)
+ capabilities := uint32(capLower)
// The packet can end here.
if pos == len(data) {
@@ -727,7 +728,7 @@ func (conn *BackendConnection) parseInitialHandshakePacket(data []byte) (uint32,
// Returns a SQLError.
func (conn *BackendConnection) writeHandshakeResponse41(capabilities uint32, scrambledPassword []byte, plugin string) error {
// Build our flags.
- var flags = mysql.CapabilityClientLongPassword |
+ flags := mysql.CapabilityClientLongPassword |
mysql.CapabilityClientLongFlag |
mysql.CapabilityClientProtocol41 |
mysql.CapabilityClientTransactions |
@@ -747,16 +748,15 @@ func (conn *BackendConnection) writeHandshakeResponse41(capabilities uint32, scr
// FIXME(alainjobart) add multi statement.
- length :=
- 4 + // Client capability flags.
- 4 + // Max-packet size.
- 1 + // Character set.
- 23 + // Reserved.
- lenNullString(conn.conf.User) +
- // length of scrambled password is handled below.
- len(scrambledPassword) +
- 21 + // "mysql_native_password" string.
- 1 // terminating zero.
+ length := 4 + // Client capability flags.
+ 4 + // Max-packet size.
+ 1 + // Character set.
+ 23 + // Reserved.
+ lenNullString(conn.conf.User) +
+ // length of scrambled password is handled below.
+ len(scrambledPassword) +
+ 21 + // "mysql_native_password" string.
+ 1 // terminating zero.
// Add the DB name if the server supports it.
if conn.conf.DBName != "" && (capabilities&mysql.CapabilityClientConnectWithDB != 0) {
@@ -858,6 +858,7 @@ func (conn *BackendConnection) WriteComSetOption(operation uint16) error {
return nil
}
+// WriteComFieldList https://dev.mysql.com/doc/internals/en/com-field-list.html
func (conn *BackendConnection) WriteComFieldList(table string, wildcard string) error {
conn.c.sequence = 0
length := lenNullString(table) + lenNullString(wildcard)
@@ -871,10 +872,9 @@ func (conn *BackendConnection) WriteComFieldList(table string, wildcard string)
pos = writeByte(data, 0, mysql.ComFieldList)
if len(wildcard) > 0 {
pos = writeNullString(data, pos, table)
- writeNullString(data, pos, wildcard)
- } else {
- pos = writeEOFString(data, pos, table)
writeEOFString(data, pos, wildcard)
+ } else {
+ writeNullString(data, pos, table)
}
if err := conn.c.writeEphemeralPacket(); err != nil {
@@ -886,7 +886,7 @@ func (conn *BackendConnection) WriteComFieldList(table string, wildcard string)
func (conn *BackendConnection) readResultSetHeaderPacket() (affectedRows, lastInsertID uint64, colNumber int, more bool, warnings uint16, err error) {
// Get the result.
- affectedRows, lastInsertID, colNumber, more, warning, err := conn.ReadComQueryResponse()
+ affectedRows, lastInsertID, colNumber, more, warning, err := conn.readComQueryResponse()
if err != nil {
return affectedRows, lastInsertID, colNumber, more, warning, err
}
@@ -908,174 +908,18 @@ func (conn *BackendConnection) readResultSetColumnsPacket(colNumber int) (column
}
// ReadQueryRow returns iterator, and the line reads the results set
-func (conn *BackendConnection) ReadQueryRow() (iter *IterRow, affectedRows, lastInsertID uint64, more bool, warnings uint16, err error) {
- iterRow := &IterRow{BackendConnection: conn,
- Row: &Row{
- Content: []byte{}, ResultSet: &ResultSet{
- Columns: make([]proto.Field, 0),
- }},
- hasNext: true,
- }
- affectedRows, lastInsertID, colNumber, more, warning, err := conn.readResultSetHeaderPacket()
- if err != nil {
- return iterRow, affectedRows, lastInsertID, more, warning, err
- }
- if colNumber == 0 {
- // OK packet, means no results. Just use the numbers.
- return iterRow, affectedRows, lastInsertID, more, warning, nil
- }
-
- // Read column headers. One packet per column.
- // Build the fields.
- columns := make([]proto.Field, colNumber)
- for i := 0; i < colNumber; i++ {
- field := &Field{}
- columns[i] = field
- if err = conn.ReadColumnDefinition(field, i); err != nil {
- return
- }
- }
-
- iterRow.Row.ResultSet.Columns = columns
-
- if conn.capabilities&mysql.CapabilityClientDeprecateEOF == 0 {
- // EOF is only present here if it's not deprecated.
- data, err := conn.c.readEphemeralPacket()
- if err != nil {
- return nil, affectedRows, lastInsertID, more, warning, err2.NewSQLError(mysql.CRServerLost, mysql.SSUnknownSQLState, "%v", err)
- }
- if isEOFPacket(data) {
-
- // This is what we expect.
- // Warnings and status flags are ignored.
- conn.c.recycleReadPacket()
- // goto: read row loop
-
- } else if isErrorPacket(data) {
- defer conn.c.recycleReadPacket()
- return nil, affectedRows, lastInsertID, more, warning, ParseErrorPacket(data)
- } else {
- defer conn.c.recycleReadPacket()
- return nil, affectedRows, lastInsertID, more, warning, fmt.Errorf("unexpected packet after fields: %v", data)
- }
- }
-
- return iterRow, affectedRows, lastInsertID, more, warnings, err
+func (conn *BackendConnection) ReadQueryRow() *RawResult {
+ return newResult(conn)
}
// ReadQueryResult gets the result from the last written query.
-func (conn *BackendConnection) ReadQueryResult(wantFields bool) (result *Result, more bool, warnings uint16, err error) {
- // Get the result.
- affectedRows, lastInsertID, colNumber, more, warnings, err := conn.ReadComQueryResponse()
- if err != nil {
- return nil, false, 0, err
- }
-
- if colNumber == 0 {
- // OK packet, means no results. Just use the numbers.
- return &Result{
- AffectedRows: affectedRows,
- InsertId: lastInsertID,
- }, more, warnings, nil
- }
-
- result = &Result{
- Fields: make([]proto.Field, colNumber),
- }
-
- // Read column headers. One packet per column.
- // Build the fields.
- for i := 0; i < colNumber; i++ {
- field := &Field{}
- result.Fields[i] = field
-
- if wantFields {
- if err := conn.ReadColumnDefinition(field, i); err != nil {
- return nil, false, 0, err
- }
- } else {
- if err := conn.ReadColumnDefinitionType(field, i); err != nil {
- return nil, false, 0, err
- }
- }
- }
-
- if conn.capabilities&mysql.CapabilityClientDeprecateEOF == 0 {
- // EOF is only present here if it's not deprecated.
- data, err := conn.c.readEphemeralPacket()
- if err != nil {
- return nil, false, 0, err2.NewSQLError(mysql.CRServerLost, mysql.SSUnknownSQLState, "%v", err)
- }
- if isEOFPacket(data) {
-
- // This is what we expect.
- // Warnings and status flags are ignored.
- conn.c.recycleReadPacket()
- // goto: read row loop
-
- } else if isErrorPacket(data) {
- defer conn.c.recycleReadPacket()
- return nil, false, 0, ParseErrorPacket(data)
- } else {
- defer conn.c.recycleReadPacket()
- return nil, false, 0, fmt.Errorf("unexpected packet after fields: %v", data)
- }
- }
-
- // read each row until EOF or OK packet.
- for {
- data, err := conn.c.ReadPacket()
- if err != nil {
- return nil, false, 0, err
- }
-
- if isEOFPacket(data) {
- // Strip the partial Fields before returning.
- if !wantFields {
- result.Fields = nil
- }
- result.AffectedRows = uint64(len(result.Rows))
-
- // The deprecated EOF packets change means that this is either an
- // EOF packet or an OK packet with the EOF type code.
- if conn.capabilities&mysql.CapabilityClientDeprecateEOF == 0 {
- warnings, more, err = parseEOFPacket(data)
- if err != nil {
- return nil, false, 0, err
- }
- } else {
- var statusFlags uint16
- _, _, statusFlags, warnings, err = parseOKPacket(data)
- if err != nil {
- return nil, false, 0, err
- }
- more = (statusFlags & mysql.ServerMoreResultsExists) != 0
- }
- return result, more, warnings, nil
-
- } else if isErrorPacket(data) {
- // Error packet.
- return nil, false, 0, ParseErrorPacket(data)
- }
-
- //// Check we're not over the limit before we add more.
- //if len(result.Rows) == maxrows {
- // if err := conn.DrainResults(); err != nil {
- // return nil, false, 0, err
- // }
- // return nil, false, 0, err2.NewSQLError(mysql.ERVitessMaxRowsExceeded, mysql.SSUnknownSQLState, "Row count exceeded %d")
- //}
-
- // Regular row.
- row, err := conn.parseRow(data, result.Fields)
- if err != nil {
- return nil, false, 0, err
- }
- result.Rows = append(result.Rows, row)
- }
+func (conn *BackendConnection) ReadQueryResult(wantFields bool) proto.Result {
+ ret := conn.ReadQueryRow()
+ ret.setWantFields(wantFields)
+ return ret
}
-func (conn *BackendConnection) ReadComQueryResponse() (affectedRows uint64, lastInsertID uint64, status int, more bool, warnings uint16, err error) {
+func (conn *BackendConnection) readComQueryResponse() (affectedRows uint64, lastInsertID uint64, status int, more bool, warnings uint16, err error) {
data, err := conn.c.readEphemeralPacket()
if err != nil {
return 0, 0, 0, false, 0, err2.NewSQLError(mysql.CRServerLost, mysql.SSUnknownSQLState, "%v", err)
@@ -1107,7 +951,7 @@ func (conn *BackendConnection) ReadComQueryResponse() (affectedRows uint64, last
}
// ReadColumnDefinition reads the next Column Definition packet.
-// Returns a SQLError.
+// Returns a SQLError. https://dev.mysql.com/doc/internals/en/com-query-response.html#column-definition
func (conn *BackendConnection) ReadColumnDefinition(field *Field, index int) error {
colDef, err := conn.c.readEphemeralPacket()
if err != nil {
@@ -1188,20 +1032,20 @@ func (conn *BackendConnection) ReadColumnDefinition(field *Field, index int) err
}
field.decimals = decimals
- //if more Content, command was field list
- if len(colDef) > pos+8 {
- //length of default value lenenc-int
- field.defaultValueLength, pos, ok = readUint64(colDef, pos)
- if !ok {
- return err2.NewSQLError(mysql.CRMalformedPacket, mysql.SSUnknownSQLState, "extracting col %v default value failed", index)
- }
+ // filter [0x00][0x00]
+ pos += 2
+
+ // if more Content, command was field list
+ if len(colDef) > pos {
+ // length of default value lenenc-int
+ field.defaultValueLength, pos = readComFieldListDefaultValueLength(colDef, pos)
if pos+int(field.defaultValueLength) > len(colDef) {
return err2.NewSQLError(mysql.CRMalformedPacket, mysql.SSUnknownSQLState, "extracting col %v default value failed", index)
}
- //default value string[$len]
- field.defaultValue = colDef[pos:(pos + int(field.defaultValueLength))]
+ // default value string[$len]
+ field.defaultValue = append(field.defaultValue, colDef[pos:(pos+int(field.defaultValueLength))]...)
}
return nil
}
@@ -1281,18 +1125,6 @@ func (conn *BackendConnection) ReadColumnDefinitionType(field *Field, index int)
return nil
}
-// parseRow parses an individual row.
-// Returns a SQLError.
-func (conn *BackendConnection) parseRow(data []byte, fields []proto.Field) (proto.Row, error) {
- row := &Row{
- Content: data,
- ResultSet: &ResultSet{
- Columns: fields,
- },
- }
- return row, nil
-}
-
// DrainResults will read all packets for a result set and ignore them.
func (conn *BackendConnection) DrainResults() error {
for {
@@ -1303,9 +1135,12 @@ func (conn *BackendConnection) DrainResults() error {
if isEOFPacket(data) {
conn.c.recycleReadPacket()
return nil
- } else if isErrorPacket(data) {
- defer conn.c.recycleReadPacket()
- return ParseErrorPacket(data)
+ }
+
+ if isErrorPacket(data) {
+ err = ParseErrorPacket(data)
+ conn.c.recycleReadPacket()
+ return err
}
conn.c.recycleReadPacket()
}
@@ -1317,7 +1152,7 @@ func (conn *BackendConnection) ReadColumnDefinitions() ([]proto.Field, error) {
for {
field := &Field{}
err := conn.ReadColumnDefinition(field, i)
- if err == io.EOF {
+ if errors.Is(err, io.EOF) {
return result, nil
}
if err != nil {
@@ -1328,59 +1163,9 @@ func (conn *BackendConnection) ReadColumnDefinitions() ([]proto.Field, error) {
}
}
-// Execute executes a query and returns the result.
-// Returns a SQLError. Depending on the transport used, the error
-// returned might be different for the same condition:
-//
-// 1. if the server closes the connection when no command is in flight:
-//
-// 1.1 unix: WriteComQuery will fail with a 'broken pipe', and we'll
-// return CRServerGone(2006).
-//
-// 1.2 tcp: WriteComQuery will most likely work, but ReadComQueryResponse
-// will fail, and we'll return CRServerLost(2013).
-//
-// This is because closing a TCP socket on the server side sends
-// a FIN to the client (telling the client the server is done
-// writing), but on most platforms doesn't send a RST. So the
-// client has no idea it can't write. So it succeeds writing Content, which
-// *then* triggers the server to send a RST back, received a bit
-// later. By then, the client has already started waiting for
-// the response, and will just return a CRServerLost(2013).
-// So CRServerGone(2006) will almost never be seen with TCP.
-//
-// 2. if the server closes the connection when a command is in flight,
-// ReadComQueryResponse will fail, and we'll return CRServerLost(2013).
-func (conn *BackendConnection) Execute(query string, wantFields bool) (result *Result, err error) {
- result, _, err = conn.ExecuteMulti(query, wantFields)
- return
-}
-
-// ExecuteMulti is for fetching multiple results from a multi-statement result.
-// It returns an additional 'more' flag. If it is set, you must fetch the additional
-// results using ReadQueryResult.
-func (conn *BackendConnection) ExecuteMulti(query string, wantFields bool) (result *Result, more bool, err error) {
- defer func() {
- if err != nil {
- if sqlerr, ok := err.(*err2.SQLError); ok {
- sqlerr.Query = query
- }
- }
- }()
-
- // Send the query as a COM_QUERY packet.
- if err = conn.WriteComQuery(query); err != nil {
- return nil, false, err
- }
-
- result, more, _, err = conn.ReadQueryResult(wantFields)
- return
-}
-
// ExecuteWithWarningCountIterRow is for fetching results and a warning count
-// Note: In a future iteration this should be abolished and merged into the
-// Execute API.
-func (conn *BackendConnection) ExecuteWithWarningCountIterRow(query string) (res *Result, warnings uint16, err error) {
+// Note: In a future iteration this should be abolished and merged into the Execute API.
+func (conn *BackendConnection) ExecuteWithWarningCountIterRow(query string) (result proto.Result, err error) {
defer func() {
if err != nil {
if sqlErr, ok := err.(*err2.SQLError); ok {
@@ -1391,26 +1176,22 @@ func (conn *BackendConnection) ExecuteWithWarningCountIterRow(query string) (res
// Send the query as a COM_QUERY packet.
if err = conn.WriteComQuery(query); err != nil {
- return nil, 0, err
+ return
}
- iterTextRow, affectedRows, lastInsertID, _, warnings, err := conn.ReadQueryRow()
- iterRow := &TextIterRow{iterTextRow}
+ res := conn.ReadQueryRow()
+ res.setTextProtocol()
+ res.setWantFields(true)
+
+ result = res
- res = &Result{
- AffectedRows: affectedRows,
- InsertId: lastInsertID,
- Fields: iterRow.Fields(),
- Rows: []proto.Row{iterRow},
- DataChan: make(chan proto.Row, 1),
- }
return
}
// ExecuteWithWarningCount is for fetching results and a warning count
// Note: In a future iteration this should be abolished and merged into the
// Execute API.
-func (conn *BackendConnection) ExecuteWithWarningCount(query string, wantFields bool) (result *Result, warnings uint16, err error) {
+func (conn *BackendConnection) ExecuteWithWarningCount(query string, wantFields bool) (result proto.Result, err error) {
defer func() {
if err != nil {
if sqlErr, ok := err.(*err2.SQLError); ok {
@@ -1421,49 +1202,42 @@ func (conn *BackendConnection) ExecuteWithWarningCount(query string, wantFields
// Send the query as a COM_QUERY packet.
if err = conn.WriteComQuery(query); err != nil {
- return nil, 0, err
+ return
}
- result, _, warnings, err = conn.ReadQueryResult(wantFields)
+ result = conn.ReadQueryResult(wantFields)
+
return
}
-func (conn *BackendConnection) PrepareExecuteArgs(query string, args []interface{}) (result *Result, warnings uint16, err error) {
+func (conn *BackendConnection) PrepareExecuteArgs(query string, args []interface{}) (proto.Result, error) {
stmt, err := conn.prepare(query)
if err != nil {
- return nil, 0, err
+ return nil, err
}
return stmt.execArgs(args)
}
-func (conn *BackendConnection) PrepareQueryArgsIterRow(query string, data []interface{}) (result *Result, warnings uint16, err error) {
- stmt, err := conn.prepare(query)
- if err != nil {
- return nil, 0, err
- }
- return stmt.queryArgsIterRow(data)
-}
-
-func (conn *BackendConnection) PrepareQueryArgs(query string, data []interface{}) (result *Result, warnings uint16, err error) {
+func (conn *BackendConnection) PrepareQueryArgs(query string, data []interface{}) (proto.Result, error) {
stmt, err := conn.prepare(query)
if err != nil {
- return nil, 0, err
+ return nil, err
}
return stmt.queryArgs(data)
}
-func (conn *BackendConnection) PrepareExecute(query string, data []byte) (result *Result, warnings uint16, err error) {
+func (conn *BackendConnection) PrepareExecute(query string, data []byte) (proto.Result, error) {
stmt, err := conn.prepare(query)
if err != nil {
- return nil, 0, err
+ return nil, err
}
return stmt.exec(data)
}
-func (conn *BackendConnection) PrepareQuery(query string, data []byte) (Result *Result, warnings uint16, err error) {
+func (conn *BackendConnection) PrepareQuery(query string, data []byte) (proto.Result, error) {
stmt, err := conn.prepare(query)
if err != nil {
- return nil, 0, err
+ return nil, err
}
return stmt.query(data)
}
diff --git a/pkg/mysql/client_test.go b/pkg/mysql/client_test.go
index aeb11e9cd..4eea2c377 100644
--- a/pkg/mysql/client_test.go
+++ b/pkg/mysql/client_test.go
@@ -304,35 +304,38 @@ func TestPrepare(t *testing.T) {
assert.Equal(t, 0, stmt.paramCount)
}
-func TestReadComQueryResponse(t *testing.T) {
- dsn := "admin:123456@tcp(127.0.0.1:3306)/pass?allowAllFiles=true&allowCleartextPasswords=true"
- cfg, _ := ParseDSN(dsn)
- conn := &BackendConnection{conf: cfg}
- conn.c = newConn(new(mockConn))
- buf := make([]byte, 13)
- buf[0] = 9
- buf[4] = mysql.OKPacket
- buf[5] = 1
- buf[6] = 1
- conn.c.conn.(*mockConn).data = buf
- affectedRows, lastInsertID, _, _, _, err := conn.ReadComQueryResponse()
- assert.NoError(t, err)
- assert.Equal(t, uint64(0x1), affectedRows)
- assert.Equal(t, uint64(0x1), lastInsertID)
-}
+//func TestReadComQueryResponse(t *testing.T) {
+// dsn := "admin:123456@tcp(127.0.0.1:3306)/pass?allowAllFiles=true&allowCleartextPasswords=true"
+// cfg, _ := ParseDSN(dsn)
+// conn := &BackendConnection{conf: cfg}
+// conn.c = newConn(new(mockConn))
+// buf := make([]byte, 13)
+// buf[0] = 9
+// buf[4] = mysql.OKPacket
+// buf[5] = 1
+// buf[6] = 1
+// conn.c.conn.(*mockConn).data = buf
+// affectedRows, lastInsertID, _, _, _, err := conn.ReadComQueryResponse()
+// assert.NoError(t, err)
+// assert.Equal(t, uint64(0x1), affectedRows)
+// assert.Equal(t, uint64(0x1), lastInsertID)
+//}
func TestReadColumnDefinition(t *testing.T) {
dsn := "admin:123456@tcp(127.0.0.1:3306)/pass?allowAllFiles=true&allowCleartextPasswords=true"
cfg, _ := ParseDSN(dsn)
conn := &BackendConnection{conf: cfg}
conn.c = newConn(new(mockConn))
- buf := make([]byte, 100)
- buf[0] = 96
- buf[4] = 3
+ buf := make([]byte, 80)
+ buf[0] = 0x4d
+ buf[1] = 0x00
+ buf[2] = 0x00
+ buf[3] = 0x00
+ buf[4] = 0x03 // catalog
buf[5] = 'd'
buf[6] = 'e'
buf[7] = 'f'
- buf[8] = 8
+ buf[8] = 0x08 // schema
buf[9] = 't'
buf[10] = 'e'
buf[11] = 's'
@@ -341,7 +344,7 @@ func TestReadColumnDefinition(t *testing.T) {
buf[14] = 'a'
buf[15] = 's'
buf[16] = 'e'
- buf[17] = 9
+ buf[17] = 0x09 // table
buf[18] = 't'
buf[19] = 'e'
buf[20] = 's'
@@ -351,18 +354,60 @@ func TestReadColumnDefinition(t *testing.T) {
buf[24] = 'b'
buf[25] = 'l'
buf[26] = 'e'
- buf[28] = 4
- buf[29] = 'n'
- buf[30] = 'a'
- buf[31] = 'm'
- buf[32] = 'e'
- buf[37] = 255
- buf[41] = 15
- buf[45] = 4
- buf[53] = 'u'
- buf[54] = 's'
- buf[55] = 'e'
- buf[56] = 'r'
+ buf[27] = 0x09 // org table
+ buf[28] = 't'
+ buf[29] = 'e'
+ buf[30] = 's'
+ buf[31] = 't'
+ buf[32] = 't'
+ buf[33] = 'a'
+ buf[34] = 'b'
+ buf[35] = 'l'
+ buf[36] = 'e'
+ buf[37] = 0x04 // name
+ buf[38] = 'n'
+ buf[39] = 'a'
+ buf[40] = 'm'
+ buf[41] = 'e'
+ buf[42] = 0x04 // org name
+ buf[43] = 'n'
+ buf[44] = 'a'
+ buf[45] = 'm'
+ buf[46] = 'e'
+ buf[47] = 0x0c
+ buf[48] = 0x3f // character set
+ buf[49] = 0x00
+ buf[50] = 0x13 // column length
+ buf[51] = 0x00
+ buf[52] = 0x00
+ buf[53] = 0x00
+ buf[54] = 0x0c // field type
+ buf[55] = 0x81 // flags
+ buf[56] = 0x00
+ buf[57] = 0x00 // decimals
+ buf[58] = 0x00 // decimals
+ buf[59] = 0x00 // decimals
+ buf[60] = 0x13 // default
+ buf[61] = '0'
+ buf[62] = '0'
+ buf[63] = '0'
+ buf[64] = '0'
+ buf[65] = '-'
+ buf[66] = '0'
+ buf[67] = '0'
+ buf[68] = '-'
+ buf[69] = '0'
+ buf[70] = '0'
+ buf[71] = ' '
+ buf[72] = '0'
+ buf[73] = '0'
+ buf[74] = ':'
+ buf[75] = '0'
+ buf[76] = '0'
+ buf[77] = ':'
+ buf[78] = '0'
+ buf[79] = '0'
+
conn.c.conn.(*mockConn).data = buf
field := &Field{}
err := conn.ReadColumnDefinition(field, 0)
@@ -370,10 +415,10 @@ func TestReadColumnDefinition(t *testing.T) {
assert.Equal(t, "testtable", field.table)
assert.Equal(t, "testbase", field.database)
assert.Equal(t, "name", field.name)
- assert.Equal(t, mysql.FieldTypeVarChar, field.fieldType)
- assert.Equal(t, uint64(0x4), field.defaultValueLength)
- assert.Equal(t, "user", string(field.defaultValue))
- assert.Equal(t, uint32(255), field.columnLength)
+ assert.Equal(t, mysql.FieldTypeDateTime, field.fieldType)
+ assert.Equal(t, uint64(0x13), field.defaultValueLength)
+ assert.Equal(t, "0000-00-00 00:00:00", string(field.defaultValue))
+ assert.Equal(t, uint32(19), field.columnLength)
}
func TestReadColumnDefinitionType(t *testing.T) {
diff --git a/pkg/mysql/conn.go b/pkg/mysql/conn.go
index 5f91fc7f7..e5a9c015f 100644
--- a/pkg/mysql/conn.go
+++ b/pkg/mysql/conn.go
@@ -48,6 +48,9 @@ const (
// connBufferSize is how much we buffer for reading and
// writing. It is also how much we allocate for ephemeral buffers.
connBufferSize = 16 * 1024
+
+ // packetHeaderSize is the first four bytes of a packet
+ packetHeaderSize int = 4
)
// Constants for how ephemeral buffers were used for reading / writing.
@@ -426,6 +429,53 @@ func (c *Conn) ReadPacket() ([]byte, error) {
return result, err
}
+func (c *Conn) writePacketForFieldList(data []byte) error {
+ index := 0
+ dataLength := len(data) - packetHeaderSize
+
+ w, unget := c.getWriter()
+ defer unget()
+
+ var header [packetHeaderSize]byte
+ for {
+ // toBeSent is capped to MaxPacketSize.
+ toBeSent := dataLength
+ if toBeSent > mysql.MaxPacketSize {
+ toBeSent = mysql.MaxPacketSize
+ }
+
+ // Write the body.
+ if n, err := w.Write(data[index : index+toBeSent+packetHeaderSize]); err != nil {
+ return errors.Wrapf(err, "Write(header) failed")
+ } else if n != (toBeSent + packetHeaderSize) {
+ return errors.Wrapf(err, "Write(packet) returned a short write: %v < %v", n, toBeSent+packetHeaderSize)
+ }
+
+ // Update our state.
+ c.sequence++
+ dataLength -= toBeSent
+ if dataLength == 0 {
+ if toBeSent == mysql.MaxPacketSize {
+ // The packet we just sent had exactly
+ // MaxPacketSize size, we need to
+ // send a zero-size packet too.
+ header[0] = 0
+ header[1] = 0
+ header[2] = 0
+ header[3] = c.sequence
+ if n, err := w.Write(data[index : index+toBeSent+packetHeaderSize]); err != nil {
+ return errors.Wrapf(err, "Write(header) failed")
+ } else if n != (toBeSent + packetHeaderSize) {
+ return errors.Wrapf(err, "Write(packet) returned a short write: %v < %v", n, (toBeSent + packetHeaderSize))
+ }
+ c.sequence++
+ }
+ return nil
+ }
+ index += toBeSent
+ }
+}
+
// writePacket writes a packet, possibly cutting it into multiple
// chunks. Note this is not very efficient, as the client probably
// has to build the []byte and that makes a memory copy.
@@ -475,7 +525,7 @@ func (c *Conn) writePacket(data []byte) error {
if packetLength == mysql.MaxPacketSize {
// The packet we just sent had exactly
// MaxPacketSize size, we need to
- // sent a zero-size packet too.
+ // send a zero-size packet too.
header[0] = 0
header[1] = 0
header[2] = 0
@@ -512,7 +562,7 @@ func (c *Conn) writeEphemeralPacket() error {
switch c.currentEphemeralPolicy {
case ephemeralWrite:
if err := c.writePacket(*c.currentEphemeralBuffer); err != nil {
- return errors.Wrapf(err, "conn %v", c.ID())
+ return errors.WithStack(errors.Wrapf(err, "conn %v", c.ID()))
}
case ephemeralUnused, ephemeralRead:
// Programming error.
@@ -522,7 +572,7 @@ func (c *Conn) writeEphemeralPacket() error {
return nil
}
-// recycleWritePacket recycles the write packet. It needs to be called
+// recycleWritePacket recycles write packet. It needs to be called
// after writeEphemeralPacket was called.
func (c *Conn) recycleWritePacket() {
if c.currentEphemeralPolicy != ephemeralWrite {
@@ -660,6 +710,18 @@ func (c *Conn) writeErrorPacketFromError(err error) error {
return c.writeErrorPacket(mysql.ERUnknownError, mysql.SSUnknownSQLState, "unknown error: %v", err)
}
+func (c *Conn) buildEOFPacket(flags uint16, warnings uint16) []byte {
+ data := make([]byte, 9)
+ pos := 0
+ data[pos] = 0x05
+ pos += 3
+ pos = writeLenEncInt(data, pos, uint64(c.sequence))
+ pos = writeByte(data, pos, mysql.EOFPacket)
+ pos = writeUint16(data, pos, flags)
+ _ = writeUint16(data, pos, warnings)
+ return data
+}
+
// writeEOFPacket writes an EOF packet, through the buffer, and
// doesn't flush (as it is used as part of a query result).
func (c *Conn) writeEOFPacket(flags uint16, warnings uint16) error {
@@ -677,7 +739,7 @@ func (c *Conn) writeEOFPacket(flags uint16, warnings uint16) error {
// Packet parsing methods, for generic packets.
//
-// isEOFPacket determines whether or not a Content packet is a "true" EOF. DO NOT blindly compare the
+// isEOFPacket determines whether a Content packet is a "true" EOF. DO NOT blindly compare the
// first byte of a packet to EOFPacket as you might do for other packet types, as 0xfe is overloaded
// as a first byte.
//
diff --git a/pkg/mysql/conn_test.go b/pkg/mysql/conn_test.go
index 0b49c1eca..214688126 100644
--- a/pkg/mysql/conn_test.go
+++ b/pkg/mysql/conn_test.go
@@ -67,6 +67,7 @@ func (m *mockConn) Read(b []byte) (n int, err error) {
m.read += n
return
}
+
func (m *mockConn) Write(b []byte) (n int, err error) {
if m.closed {
return 0, errConnClosed
@@ -86,22 +87,28 @@ func (m *mockConn) Write(b []byte) (n int, err error) {
}
return
}
+
func (m *mockConn) Close() error {
m.closed = true
return nil
}
+
func (m *mockConn) LocalAddr() net.Addr {
return m.laddr
}
+
func (m *mockConn) RemoteAddr() net.Addr {
return m.raddr
}
+
func (m *mockConn) SetDeadline(t time.Time) error {
return nil
}
+
func (m *mockConn) SetReadDeadline(t time.Time) error {
return nil
}
+
func (m *mockConn) SetWriteDeadline(t time.Time) error {
return nil
}
diff --git a/pkg/mysql/execute_handle.go b/pkg/mysql/execute_handle.go
index 5d1974cfd..f150b5d57 100644
--- a/pkg/mysql/execute_handle.go
+++ b/pkg/mysql/execute_handle.go
@@ -29,6 +29,7 @@ import (
"github.com/arana-db/arana/pkg/constants/mysql"
"github.com/arana-db/arana/pkg/mysql/errors"
"github.com/arana-db/arana/pkg/proto"
+ "github.com/arana-db/arana/pkg/proto/hint"
"github.com/arana-db/arana/pkg/security"
"github.com/arana-db/arana/pkg/util/log"
)
@@ -70,28 +71,38 @@ func (l *Listener) handleQuery(c *Conn, ctx *proto.Context) error {
c.startWriterBuffering()
defer func() {
if err := c.endWriterBuffering(); err != nil {
- log.Errorf("conn %v: flush() failed: %v", c.ID(), err)
+ log.Errorf("conn %v: flush() failed: %v", ctx.ConnectionID, err)
}
}()
c.recycleReadPacket()
- result, warn, err := l.executor.ExecutorComQuery(ctx)
- if err != nil {
+
+ var (
+ result proto.Result
+ err error
+ warn uint16
+ )
+
+ if result, warn, err = l.executor.ExecutorComQuery(ctx); err != nil {
+ log.Errorf("executor com_query error %v: %v", ctx.ConnectionID, err)
if wErr := c.writeErrorPacketFromError(err); wErr != nil {
- log.Errorf("Error writing query error to client %v: %v", l.connectionID, wErr)
+ log.Errorf("Error writing query error to client %v: %v", ctx.ConnectionID, wErr)
return wErr
}
return nil
}
- if cr, ok := result.(*proto.CloseableResult); ok {
- result = cr.Result
- defer func() {
- _ = cr.Close()
- }()
+ var ds proto.Dataset
+ if ds, err = result.Dataset(); err != nil {
+ log.Errorf("get dataset error %v: %v", ctx.ConnectionID, err)
+ if wErr := c.writeErrorPacketFromError(err); wErr != nil {
+ log.Errorf("Error writing query error to client %v: %v", ctx.ConnectionID, wErr)
+ return wErr
+ }
+ return nil
}
- if len(result.GetFields()) == 0 {
+ if ds == nil {
// A successful callback with no fields means that this was a
// DML or other write-only operation.
//
@@ -104,10 +115,15 @@ func (l *Listener) handleQuery(c *Conn, ctx *proto.Context) error {
)
return c.writeOKPacket(affected, insertId, c.StatusFlags, warn)
}
- if err = c.writeFields(l.capabilities, result); err != nil {
+
+ fields, _ := ds.Fields()
+
+ if err = c.writeFields(l.capabilities, fields); err != nil {
+ log.Errorf("write fields error %v: %v", ctx.ConnectionID, err)
return err
}
- if err = c.writeRowChan(result); err != nil {
+ if err = c.writeDataset(ds); err != nil {
+ log.Errorf("write dataset error %v: %v", ctx.ConnectionID, err)
return err
}
if err = c.writeEndResult(l.capabilities, false, 0, 0, warn); err != nil {
@@ -128,17 +144,37 @@ func (l *Listener) handleFieldList(c *Conn, ctx *proto.Context) error {
return wErr
}
}
- return c.writeFields(l.capabilities, &Result{Fields: fields})
+
+ // Combine the fields into a package to send
+ var des []byte
+ for _, field := range fields {
+ fld := field.(*Field)
+ des = append(des, c.DefColumnDefinition(fld)...)
+ }
+
+ des = append(des, c.buildEOFPacket(0, 2)...)
+
+ if err = c.writePacketForFieldList(des); err != nil {
+ return err
+ }
+
+ return nil
}
func (l *Listener) handleStmtExecute(c *Conn, ctx *proto.Context) error {
c.startWriterBuffering()
defer func() {
if err := c.endWriterBuffering(); err != nil {
- log.Errorf("conn %v: flush() failed: %v", c.ID(), err)
+ log.Errorf("conn %v: flush() failed: %v", ctx.ConnectionID, err)
}
}()
- stmtID, _, err := c.parseComStmtExecute(&l.stmts, ctx.Data)
+
+ var (
+ stmtID uint32
+ err error
+ )
+
+ stmtID, _, err = c.parseComStmtExecute(&l.stmts, ctx.Data)
c.recycleReadPacket()
if stmtID != uint32(0) {
@@ -154,7 +190,7 @@ func (l *Listener) handleStmtExecute(c *Conn, ctx *proto.Context) error {
if err != nil {
if wErr := c.writeErrorPacketFromError(err); wErr != nil {
// If we can't even write the error, we're done.
- log.Error("Error writing query error to client %v: %v", l.connectionID, wErr)
+ log.Error("Error writing query error to client %v: %v", ctx.ConnectionID, wErr)
return wErr
}
return nil
@@ -163,23 +199,29 @@ func (l *Listener) handleStmtExecute(c *Conn, ctx *proto.Context) error {
prepareStmt, _ := l.stmts.Load(stmtID)
ctx.Stmt = prepareStmt.(*proto.Stmt)
- result, warn, err := l.executor.ExecutorComStmtExecute(ctx)
- if err != nil {
+ var (
+ result proto.Result
+ warn uint16
+ )
+
+ if result, warn, err = l.executor.ExecutorComStmtExecute(ctx); err != nil {
if wErr := c.writeErrorPacketFromError(err); wErr != nil {
- log.Errorf("Error writing query error to client %v: %v, executor error: %v", l.connectionID, wErr, err)
+ log.Errorf("Error writing query error to client %v: %v, executor error: %v", ctx.ConnectionID, wErr, err)
return wErr
}
return nil
}
- if cr, ok := result.(*proto.CloseableResult); ok {
- result = cr.Result
- defer func() {
- _ = cr.Close()
- }()
+ var ds proto.Dataset
+ if ds, err = result.Dataset(); err != nil {
+ if wErr := c.writeErrorPacketFromError(err); wErr != nil {
+ log.Errorf("Error writing query error to client %v: %v, executor error: %v", ctx.ConnectionID, wErr, err)
+ return wErr
+ }
+ return nil
}
- if len(result.GetFields()) == 0 {
+ if ds == nil {
// A successful callback with no fields means that this was a
// DML or other write-only operation.
//
@@ -190,10 +232,17 @@ func (l *Listener) handleStmtExecute(c *Conn, ctx *proto.Context) error {
lastInsertId, _ := result.LastInsertId()
return c.writeOKPacket(affected, lastInsertId, c.StatusFlags, warn)
}
- if err = c.writeFields(l.capabilities, result); err != nil {
+
+ defer func() {
+ _ = ds.Close()
+ }()
+
+ fields, _ := ds.Fields()
+
+ if err = c.writeFields(l.capabilities, fields); err != nil {
return err
}
- if err = c.writeBinaryRowChan(result); err != nil {
+ if err = c.writeDatasetBinary(ds); err != nil {
return err
}
if err = c.writeEndResult(l.capabilities, false, 0, 0, warn); err != nil {
@@ -207,7 +256,7 @@ func (l *Listener) handlePrepare(c *Conn, ctx *proto.Context) error {
query := string(ctx.Data[1:])
c.recycleReadPacket()
- // Popoulate PrepareData
+ // Populate PrepareData
statementID := l.statementID.Inc()
stmt := &proto.Stmt{
@@ -215,7 +264,7 @@ func (l *Listener) handlePrepare(c *Conn, ctx *proto.Context) error {
PrepareStmt: query,
}
p := parser.New()
- act, err := p.ParseOneStmt(stmt.PrepareStmt, "", "")
+ act, hts, err := p.ParseOneStmtHints(stmt.PrepareStmt, "", "")
if err != nil {
log.Errorf("Conn %v: Error parsing prepared statement: %v", c, err)
if wErr := c.writeErrorPacketFromError(err); wErr != nil {
@@ -223,6 +272,18 @@ func (l *Listener) handlePrepare(c *Conn, ctx *proto.Context) error {
return wErr
}
}
+
+ for _, it := range hts {
+ var h *hint.Hint
+ if h, err = hint.Parse(it); err != nil {
+ if wErr := c.writeErrorPacketFromError(err); wErr != nil {
+ log.Errorf("Conn %v: Error writing prepared statement error: %v", c, wErr)
+ return wErr
+ }
+ }
+ stmt.Hints = append(stmt.Hints, h)
+ }
+
stmt.StmtNode = act
paramsCount := uint16(strings.Count(query, "?"))
@@ -260,7 +321,7 @@ func (l *Listener) handleSetOption(c *Conn, ctx *proto.Context) error {
case 1:
l.capabilities &^= mysql.CapabilityClientMultiStatements
default:
- log.Errorf("Got unhandled packet (ComSetOption default) from client %v, returning error: %v", l.connectionID, ctx.Data)
+ log.Errorf("Got unhandled packet (ComSetOption default) from client %v, returning error: %v", ctx.ConnectionID, ctx.Data)
if err := c.writeErrorPacket(mysql.ERUnknownComError, mysql.SSUnknownComError, "error handling packet: %v", ctx.Data); err != nil {
log.Errorf("Error writing error packet to client: %v", err)
return err
@@ -271,7 +332,7 @@ func (l *Listener) handleSetOption(c *Conn, ctx *proto.Context) error {
return err
}
}
- log.Errorf("Got unhandled packet (ComSetOption else) from client %v, returning error: %v", l.connectionID, ctx.Data)
+ log.Errorf("Got unhandled packet (ComSetOption else) from client %v, returning error: %v", ctx.ConnectionID, ctx.Data)
if err := c.writeErrorPacket(mysql.ERUnknownComError, mysql.SSUnknownComError, "error handling packet: %v", ctx.Data); err != nil {
log.Errorf("Error writing error packet to client: %v", err)
return err
diff --git a/pkg/mysql/fields.go b/pkg/mysql/fields.go
index 0eab340e9..987d3111c 100644
--- a/pkg/mysql/fields.go
+++ b/pkg/mysql/fields.go
@@ -19,6 +19,7 @@ package mysql
import (
"database/sql"
+ "math"
"reflect"
)
@@ -61,19 +62,62 @@ type Field struct {
defaultValue []byte
}
+func (mf *Field) FieldType() mysql.FieldType {
+ return mf.fieldType
+}
+
+func (mf *Field) DecimalSize() (int64, int64, bool) {
+ decimals := int64(mf.decimals)
+ switch mf.fieldType {
+ case mysql.FieldTypeDecimal, mysql.FieldTypeNewDecimal:
+ if decimals > 0 {
+ return int64(mf.length) - 2, decimals, true
+ }
+ return int64(mf.length) - 1, decimals, true
+ case mysql.FieldTypeTimestamp, mysql.FieldTypeDateTime, mysql.FieldTypeTime:
+ return decimals, decimals, true
+ case mysql.FieldTypeFloat, mysql.FieldTypeDouble:
+ if decimals == 0x1f {
+ return math.MaxInt64, math.MaxInt64, true
+ }
+ return math.MaxInt64, decimals, true
+ }
+
+ return 0, 0, false
+}
+
+func (mf *Field) Length() (length int64, ok bool) {
+ length = int64(mf.length)
+ ok = true
+ return
+}
+
+func (mf *Field) Nullable() (nullable, ok bool) {
+ nullable, ok = mf.flags&mysql.NotNullFlag == 0, true
+ return
+}
+
func NewField(name string, filedType mysql.FieldType) *Field {
return &Field{name: name, fieldType: filedType}
}
+func (mf *Field) Name() string {
+ return mf.name
+}
+
+func (mf *Field) OriginName() string {
+ return mf.orgName
+}
+
func (mf *Field) TableName() string {
return mf.table
}
-func (mf *Field) DataBaseName() string {
+func (mf *Field) DatabaseName() string {
return mf.database
}
-func (mf *Field) TypeDatabaseName() string {
+func (mf *Field) DatabaseTypeName() string {
switch mf.fieldType {
case mysql.FieldTypeBit:
return "BIT"
@@ -157,7 +201,7 @@ func (mf *Field) TypeDatabaseName() string {
}
}
-func (mf *Field) scanType() reflect.Type {
+func (mf *Field) ScanType() reflect.Type {
switch mf.fieldType {
case mysql.FieldTypeTiny:
if mf.flags&mysql.NotNullFlag != 0 {
diff --git a/pkg/mysql/fields_test.go b/pkg/mysql/fields_test.go
index 1d2651034..fecdeec37 100644
--- a/pkg/mysql/fields_test.go
+++ b/pkg/mysql/fields_test.go
@@ -36,7 +36,7 @@ func TestTableName(t *testing.T) {
func TestDataBaseName(t *testing.T) {
field := createDefaultField()
- assert.Equal(t, "db_arana", field.DataBaseName())
+ assert.Equal(t, "db_arana", field.DatabaseName())
}
func TestTypeDatabaseName(t *testing.T) {
@@ -83,7 +83,7 @@ func TestTypeDatabaseName(t *testing.T) {
}
for _, unit := range unitTests {
field := createField(unit.field, mysql.Collations[unit.collation])
- assert.Equal(t, unit.expected, field.TypeDatabaseName())
+ assert.Equal(t, unit.expected, field.DatabaseTypeName())
}
}
diff --git a/pkg/mysql/iterator.go b/pkg/mysql/iterator.go
deleted file mode 100644
index 27b50c750..000000000
--- a/pkg/mysql/iterator.go
+++ /dev/null
@@ -1,412 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package mysql
-
-import (
- "encoding/binary"
- "fmt"
- "math"
- "time"
-)
-
-import (
- "github.com/arana-db/arana/pkg/constants/mysql"
- "github.com/arana-db/arana/pkg/mysql/errors"
- "github.com/arana-db/arana/pkg/proto"
-)
-
-// Iter is used to iterate output results
-type Iter interface {
- Next() (bool, error)
-}
-
-// IterRow implementation of Iter
-type IterRow struct {
- *BackendConnection
- *Row
- hasNext bool
-}
-
-// TextIterRow is iterator for text protocol result set
-type TextIterRow struct {
- *IterRow
-}
-
-// BinaryIterRow is iterator for binary protocol result set
-type BinaryIterRow struct {
- *IterRow
-}
-
-func (iterRow *IterRow) Next() (bool, error) {
- // read one row
- data, err := iterRow.c.ReadPacket()
- if err != nil {
- iterRow.hasNext = false
- return iterRow.hasNext, err
- }
-
- if isEOFPacket(data) {
- iterRow.hasNext = false
- // The deprecated EOF packets change means that this is either an
- // EOF packet or an OK packet with the EOF type code.
- if iterRow.capabilities&mysql.CapabilityClientDeprecateEOF == 0 {
- _, _, err = parseEOFPacket(data)
- if err != nil {
- return iterRow.hasNext, err
- }
- } else {
- _, _, _, _, err = parseOKPacket(data)
- if err != nil {
- return iterRow.hasNext, err
- }
- }
- return iterRow.hasNext, nil
-
- } else if isErrorPacket(data) {
- // Error packet.
- iterRow.hasNext = false
- return iterRow.hasNext, ParseErrorPacket(data)
- }
-
- iterRow.Content = data
- return iterRow.hasNext, nil
-}
-
-func (iterRow *IterRow) Decode() ([]*proto.Value, error) {
- return nil, nil
-}
-
-func (rows *TextIterRow) Decode() ([]*proto.Value, error) {
- dest := make([]*proto.Value, len(rows.ResultSet.Columns))
-
- // RowSet Packet
- var val []byte
- var isNull bool
- var n int
- var err error
- pos := 0
-
- for i := 0; i < len(rows.ResultSet.Columns); i++ {
- field := rows.ResultSet.Columns[i].(*Field)
-
- // Read bytes and convert to string
- val, isNull, n, err = readLengthEncodedString(rows.Content[pos:])
- dest[i] = &proto.Value{
- Typ: field.fieldType,
- Flags: field.flags,
- Len: n,
- Val: val,
- Raw: val,
- }
- pos += n
- if err == nil {
- if !isNull {
- switch field.fieldType {
- case mysql.FieldTypeTimestamp, mysql.FieldTypeDateTime,
- mysql.FieldTypeDate, mysql.FieldTypeNewDate:
- dest[i].Val, err = parseDateTime(
- val,
- time.Local,
- )
- if err == nil {
- continue
- }
- default:
- continue
- }
- } else {
- dest[i].Val = nil
- continue
- }
- }
- return nil, err // err != nil
- }
-
- return dest, nil
-}
-
-func (rows *BinaryIterRow) Decode() ([]*proto.Value, error) {
- dest := make([]*proto.Value, len(rows.ResultSet.Columns))
-
- if rows.Content[0] != mysql.OKPacket {
- return nil, errors.NewSQLError(mysql.CRMalformedPacket, mysql.SSUnknownSQLState, "read binary rows (%v) failed", rows)
- }
-
- // NULL-bitmap, [(column-count + 7 + 2) / 8 bytes]
- pos := 1 + (len(dest)+7+2)>>3
- nullMask := rows.Content[1:pos]
-
- for i := 0; i < len(rows.ResultSet.Columns); i++ {
- // Field is NULL
- // (byte >> bit-pos) % 2 == 1
- if ((nullMask[(i+2)>>3] >> uint((i+2)&7)) & 1) == 1 {
- dest[i] = nil
- continue
- }
-
- field := rows.ResultSet.Columns[i].(*Field)
- // Convert to byte-coded string
- // TODO Optimize storage space based on the length of data types
- mysqlType, _ := mysql.TypeToMySQL(field.fieldType)
- switch mysql.FieldType(mysqlType) {
- case mysql.FieldTypeNULL:
- dest[i] = &proto.Value{
- Typ: field.fieldType,
- Flags: field.flags,
- Len: 1,
- Val: nil,
- Raw: nil,
- }
- continue
-
- // Numeric Types
- case mysql.FieldTypeTiny:
- if field.flags&mysql.UnsignedFlag != 0 {
- dest[i] = &proto.Value{
- Typ: field.fieldType,
- Flags: field.flags,
- Len: 1,
- Val: int64(rows.Content[pos]),
- Raw: rows.Content[pos : pos+1],
- }
- } else {
- dest[i] = &proto.Value{
- Typ: field.fieldType,
- Flags: field.flags,
- Len: 1,
- Val: int64(int8(rows.Content[pos])),
- Raw: rows.Content[pos : pos+1],
- }
- }
- pos++
- continue
-
- case mysql.FieldTypeShort, mysql.FieldTypeYear:
- if field.flags&mysql.UnsignedFlag != 0 {
- dest[i] = &proto.Value{
- Typ: field.fieldType,
- Flags: field.flags,
- Len: 2,
- Val: int64(binary.LittleEndian.Uint16(rows.Content[pos : pos+2])),
- Raw: rows.Content[pos : pos+1],
- }
- } else {
- dest[i] = &proto.Value{
- Typ: field.fieldType,
- Flags: field.flags,
- Len: 2,
- Val: int64(int16(binary.LittleEndian.Uint16(rows.Content[pos : pos+2]))),
- Raw: rows.Content[pos : pos+1],
- }
- }
- pos += 2
- continue
-
- case mysql.FieldTypeInt24, mysql.FieldTypeLong:
- if field.flags&mysql.UnsignedFlag != 0 {
- dest[i] = &proto.Value{
- Typ: field.fieldType,
- Flags: field.flags,
- Len: 4,
- Val: int64(binary.LittleEndian.Uint32(rows.Content[pos : pos+4])),
- Raw: rows.Content[pos : pos+4],
- }
- } else {
- dest[i] = &proto.Value{
- Typ: field.fieldType,
- Flags: field.flags,
- Len: 4,
- Val: int64(int32(binary.LittleEndian.Uint32(rows.Content[pos : pos+4]))),
- Raw: rows.Content[pos : pos+4],
- }
- }
- pos += 4
- continue
-
- case mysql.FieldTypeLongLong:
- if field.flags&mysql.UnsignedFlag != 0 {
- val := binary.LittleEndian.Uint64(rows.Content[pos : pos+8])
- if val > math.MaxInt64 {
- dest[i] = &proto.Value{
- Typ: field.fieldType,
- Flags: field.flags,
- Len: 8,
- Val: uint64ToString(val),
- Raw: rows.Content[pos : pos+8],
- }
- } else {
- dest[i] = &proto.Value{
- Typ: field.fieldType,
- Flags: field.flags,
- Len: 8,
- Val: int64(val),
- Raw: rows.Content[pos : pos+8],
- }
- }
- } else {
- dest[i] = &proto.Value{
- Typ: field.fieldType,
- Flags: field.flags,
- Len: 8,
- Val: int64(binary.LittleEndian.Uint64(rows.Content[pos : pos+8])),
- Raw: rows.Content[pos : pos+8],
- }
- }
- pos += 8
- continue
-
- case mysql.FieldTypeFloat:
- dest[i] = &proto.Value{
- Typ: field.fieldType,
- Flags: field.flags,
- Len: 4,
- Val: math.Float32frombits(binary.LittleEndian.Uint32(rows.Content[pos : pos+4])),
- Raw: rows.Content[pos : pos+4],
- }
- pos += 4
- continue
-
- case mysql.FieldTypeDouble:
- dest[i] = &proto.Value{
- Typ: field.fieldType,
- Flags: field.flags,
- Len: 8,
- Val: math.Float64frombits(binary.LittleEndian.Uint64(rows.Content[pos : pos+8])),
- Raw: rows.Content[pos : pos+8],
- }
- pos += 8
- continue
-
- // Length coded Binary Strings
- case mysql.FieldTypeDecimal, mysql.FieldTypeNewDecimal, mysql.FieldTypeVarChar,
- mysql.FieldTypeBit, mysql.FieldTypeEnum, mysql.FieldTypeSet, mysql.FieldTypeTinyBLOB,
- mysql.FieldTypeMediumBLOB, mysql.FieldTypeLongBLOB, mysql.FieldTypeBLOB,
- mysql.FieldTypeVarString, mysql.FieldTypeString, mysql.FieldTypeGeometry, mysql.FieldTypeJSON:
- var val interface{}
- var isNull bool
- var n int
- var err error
- val, isNull, n, err = readLengthEncodedString(rows.Content[pos:])
- dest[i] = &proto.Value{
- Typ: field.fieldType,
- Flags: field.flags,
- Len: n,
- Val: val,
- Raw: rows.Content[pos : pos+n],
- }
- pos += n
- if err == nil {
- if !isNull {
- continue
- } else {
- dest[i].Val = nil
- continue
- }
- }
- return nil, err
-
- case
- mysql.FieldTypeDate, mysql.FieldTypeNewDate, // Date YYYY-MM-DD
- mysql.FieldTypeTime, // Time [-][H]HH:MM:SS[.fractal]
- mysql.FieldTypeTimestamp, mysql.FieldTypeDateTime: // Timestamp YYYY-MM-DD HH:MM:SS[.fractal]
-
- num, isNull, n := readLengthEncodedInteger(rows.Content[pos:])
- pos += n
-
- var val interface{}
- var err error
- switch {
- case isNull:
- dest[i] = nil
- continue
- case field.fieldType == mysql.FieldTypeTime:
- // database/sql does not support an equivalent to TIME, return a string
- var dstlen uint8
- switch decimals := field.decimals; decimals {
- case 0x00, 0x1f:
- dstlen = 8
- case 1, 2, 3, 4, 5, 6:
- dstlen = 8 + 1 + decimals
- default:
- return nil, fmt.Errorf(
- "protocol error, illegal decimals architecture.Value %d",
- field.decimals,
- )
- }
- val, err = formatBinaryTime(rows.Content[pos:pos+int(num)], dstlen)
- dest[i] = &proto.Value{
- Typ: field.fieldType,
- Flags: field.flags,
- Len: n,
- Val: val,
- Raw: rows.Content[pos : pos+n],
- }
- default:
- val, err = parseBinaryDateTime(num, rows.Content[pos:], time.Local)
- dest[i] = &proto.Value{
- Typ: field.fieldType,
- Flags: field.flags,
- Len: int(num),
- Val: val,
- Raw: rows.Content[pos : pos+int(num)],
- }
- if err == nil {
- break
- }
-
- var dstlen uint8
- if field.fieldType == mysql.FieldTypeDate {
- dstlen = 10
- } else {
- switch decimals := field.decimals; decimals {
- case 0x00, 0x1f:
- dstlen = 19
- case 1, 2, 3, 4, 5, 6:
- dstlen = 19 + 1 + decimals
- default:
- return nil, fmt.Errorf(
- "protocol error, illegal decimals architecture.Value %d",
- field.decimals,
- )
- }
- }
- val, err = formatBinaryDateTime(rows.Content[pos:pos+int(num)], dstlen)
- dest[i] = &proto.Value{
- Typ: field.fieldType,
- Flags: field.flags,
- Len: int(num),
- Val: val,
- Raw: rows.Content[pos : pos+n],
- }
- }
-
- if err == nil {
- pos += int(num)
- continue
- } else {
- return nil, err
- }
-
- // Please report if this happens!
- default:
- return nil, fmt.Errorf("unknown field type %d", field.fieldType)
- }
- }
-
- return dest, nil
-}
diff --git a/pkg/mysql/result.go b/pkg/mysql/result.go
index 99d86a9c9..359e96dbb 100644
--- a/pkg/mysql/result.go
+++ b/pkg/mysql/result.go
@@ -18,33 +18,324 @@
package mysql
import (
+ "io"
+ "sync"
+)
+
+import (
+ "github.com/pkg/errors"
+)
+
+import (
+ "github.com/arana-db/arana/pkg/constants/mysql"
+ merrors "github.com/arana-db/arana/pkg/mysql/errors"
"github.com/arana-db/arana/pkg/proto"
+ "github.com/arana-db/arana/pkg/util/log"
+)
+
+const (
+ _flagTextMode uint8 = 1 << iota
+ _flagWantFields
)
-type Result struct {
- Fields []proto.Field // Columns information
- AffectedRows uint64
- InsertId uint64
- Rows []proto.Row
- DataChan chan proto.Row
+var (
+ _ proto.Result = (*RawResult)(nil)
+ _ proto.Dataset = (*RawResult)(nil)
+)
+
+type RawResult struct {
+ flag uint8
+
+ c *BackendConnection
+
+ lastInsertID, affectedRows uint64
+ colNumber int
+ more bool
+ warnings uint16
+ fields []proto.Field
+
+ preflightOnce sync.Once
+ preflightFailure error
+ flightOnce sync.Once
+ flightFailure error
+ postFlightOnce sync.Once
+ postFlightFailure error
+
+ closeFunc func() error
+ closeOnce sync.Once
+ closeFailure error
+
+ eof bool
+}
+
+func (rr *RawResult) Discard() (err error) {
+ defer func() {
+ _ = rr.Close()
+ }()
+
+ if err = rr.postFlight(); err != nil {
+ return errors.Wrapf(err, "failed to discard mysql result")
+ }
+
+ if rr.colNumber < 1 {
+ return
+ }
+
+ for {
+ _, err = rr.nextRowData()
+ if err != nil {
+ return
+ }
+ }
+}
+
+func (rr *RawResult) Warn() (uint16, error) {
+ if err := rr.preflight(); err != nil {
+ return 0, err
+ }
+ return rr.warnings, nil
+}
+
+func (rr *RawResult) Close() error {
+ rr.closeOnce.Do(func() {
+ if rr.closeFunc == nil {
+ return
+ }
+ rr.closeFailure = rr.closeFunc()
+ if rr.closeFailure != nil {
+ log.Errorf("failed to close mysql result: %v", rr.closeFailure)
+ }
+ })
+ return rr.closeFailure
+}
+
+func (rr *RawResult) Fields() ([]proto.Field, error) {
+ if err := rr.flight(); err != nil {
+ return nil, err
+ }
+ return rr.fields, nil
+}
+
+func (rr *RawResult) Next() (row proto.Row, err error) {
+ defer func() {
+ if err != nil {
+ _ = rr.Close()
+ }
+ }()
+
+ var data []byte
+ if data, err = rr.nextRowData(); err != nil {
+ return
+ }
+
+ if rr.flag&_flagTextMode != 0 {
+ row = TextRow{
+ fields: rr.fields,
+ raw: data,
+ }
+ } else {
+ row = BinaryRow{
+ fields: rr.fields,
+ raw: data,
+ }
+ }
+ return
+}
+
+func (rr *RawResult) Dataset() (proto.Dataset, error) {
+ if err := rr.postFlight(); err != nil {
+ return nil, err
+ }
+
+ if len(rr.fields) < 1 {
+ return nil, nil
+ }
+
+ return rr, nil
+}
+
+func (rr *RawResult) preflight() (err error) {
+ defer func() {
+ if err != nil || rr.colNumber < 1 {
+ _ = rr.Close()
+ }
+ }()
+
+ rr.preflightOnce.Do(func() {
+ rr.affectedRows, rr.lastInsertID, rr.colNumber, rr.more, rr.warnings, rr.preflightFailure = rr.c.readResultSetHeaderPacket()
+ })
+ err = rr.preflightFailure
+ return
+}
+
+func (rr *RawResult) flight() (err error) {
+ if err = rr.preflight(); err != nil {
+ return err
+ }
+
+ defer func() {
+ if err != nil {
+ _ = rr.Close()
+ }
+ }()
+
+ rr.flightOnce.Do(func() {
+ if rr.colNumber < 1 {
+ return
+ }
+ columns := make([]proto.Field, rr.colNumber)
+ for i := 0; i < rr.colNumber; i++ {
+ var field Field
+
+ if rr.flag&_flagWantFields != 0 {
+ rr.flightFailure = rr.c.ReadColumnDefinition(&field, i)
+ } else {
+ rr.flightFailure = rr.c.ReadColumnDefinitionType(&field, i)
+ }
+
+ if rr.flightFailure != nil {
+ return
+ }
+ columns[i] = &field
+ }
+
+ rr.fields = columns
+ })
+
+ err = rr.flightFailure
+
+ return
+}
+
+func (rr *RawResult) postFlight() (err error) {
+ if err = rr.flight(); err != nil {
+ return err
+ }
+
+ defer func() {
+ if err != nil {
+ _ = rr.Close()
+ }
+ }()
+
+ rr.postFlightOnce.Do(func() {
+ if rr.colNumber < 1 {
+ return
+ }
+
+ if rr.c.capabilities&mysql.CapabilityClientDeprecateEOF != 0 {
+ return
+ }
+
+ var data []byte
+
+ // EOF is only present here if it's not deprecated.
+ if data, rr.postFlightFailure = rr.c.c.readEphemeralPacket(); rr.postFlightFailure != nil {
+ rr.postFlightFailure = merrors.NewSQLError(mysql.CRServerLost, mysql.SSUnknownSQLState, "%v", rr.postFlightFailure)
+ return
+ }
+
+ if isEOFPacket(data) {
+ // This is what we expect.
+ // Warnings and status flags are ignored.
+ rr.c.c.recycleReadPacket()
+ // goto: read row loop
+ return
+ }
+
+ defer rr.c.c.recycleReadPacket()
+
+ if isErrorPacket(data) {
+ rr.postFlightFailure = ParseErrorPacket(data)
+ } else {
+ rr.postFlightFailure = errors.Errorf("unexpected packet after fields: %v", data)
+ }
+ })
+
+ err = rr.postFlightFailure
+
+ return
+}
+
+func (rr *RawResult) LastInsertId() (uint64, error) {
+ if err := rr.preflight(); err != nil {
+ return 0, err
+ }
+ return rr.lastInsertID, nil
+}
+
+func (rr *RawResult) RowsAffected() (uint64, error) {
+ if err := rr.preflight(); err != nil {
+ return 0, err
+ }
+ return rr.affectedRows, nil
+}
+
+func (rr *RawResult) nextRowData() (data []byte, err error) {
+ if rr.eof {
+ err = io.EOF
+ return
+ }
+
+ if data, err = rr.c.c.readPacket(); err != nil {
+ return
+ }
+
+ switch {
+ case isEOFPacket(data):
+ rr.eof = true
+ // The deprecated EOF packets change means that this is either an
+ // EOF packet or an OK packet with the EOF type code.
+ if rr.c.capabilities&mysql.CapabilityClientDeprecateEOF == 0 {
+ if _, rr.more, err = parseEOFPacket(data); err != nil {
+ return
+ }
+ } else {
+ var statusFlags uint16
+ if _, _, statusFlags, _, err = parseOKPacket(data); err != nil {
+ return
+ }
+ rr.more = statusFlags&mysql.ServerMoreResultsExists != 0
+ }
+ data = nil
+ err = io.EOF
+ case isErrorPacket(data):
+ rr.eof = true
+ data = nil
+ err = ParseErrorPacket(data)
+ }
+
+ // TODO: Check we're not over the limit before we add more.
+ //if len(result.Rows) == maxrows {
+ // if err := conn.DrainResults(); err != nil {
+ // return nil, false, 0, err
+ // }
+ // return nil, false, 0, err2.NewSQLError(mysql.ERVitessMaxRowsExceeded, mysql.SSUnknownSQLState, "Row count exceeded %d")
+ //}
+
+ return
}
-func (res *Result) GetFields() []proto.Field {
- return res.Fields
+func (rr *RawResult) setTextProtocol() {
+ rr.flag |= _flagTextMode
}
-func (res *Result) GetRows() []proto.Row {
- return res.Rows
+func (rr *RawResult) setBinaryProtocol() {
+ rr.flag &= ^_flagTextMode
}
-func (res *Result) LastInsertId() (uint64, error) {
- return res.InsertId, nil
+func (rr *RawResult) setWantFields(b bool) {
+ if b {
+ rr.flag |= _flagWantFields
+ } else {
+ rr.flag &= ^_flagWantFields
+ }
}
-func (res *Result) RowsAffected() (uint64, error) {
- return res.AffectedRows, nil
+func (rr *RawResult) SetCloser(closer func() error) {
+ rr.closeFunc = closer
}
-func (res *Result) GetDataChan() chan proto.Row {
- return res.DataChan
+func newResult(c *BackendConnection) *RawResult {
+ return &RawResult{c: c}
}
diff --git a/pkg/mysql/rows.go b/pkg/mysql/rows.go
index ecf3b3875..edd02cc5c 100644
--- a/pkg/mysql/rows.go
+++ b/pkg/mysql/rows.go
@@ -18,243 +18,92 @@
package mysql
import (
- "bytes"
"encoding/binary"
- "fmt"
+ "io"
"math"
+ "strconv"
"time"
)
+import (
+ "github.com/pkg/errors"
+)
+
import (
"github.com/arana-db/arana/pkg/constants/mysql"
- "github.com/arana-db/arana/pkg/mysql/errors"
+ mysqlErrors "github.com/arana-db/arana/pkg/mysql/errors"
"github.com/arana-db/arana/pkg/proto"
- "github.com/arana-db/arana/pkg/util/log"
)
-type ResultSet struct {
- Columns []proto.Field
- ColumnNames []string
-}
-
-type Row struct {
- Content []byte
- ResultSet *ResultSet
-}
+var (
+ _ proto.KeyedRow = (*BinaryRow)(nil)
+ _ proto.KeyedRow = (*TextRow)(nil)
+)
type BinaryRow struct {
- Row
+ fields []proto.Field
+ raw []byte
}
-type TextRow struct {
- Row
-}
-
-func (row *Row) Columns() []string {
- if row.ResultSet.ColumnNames != nil {
- return row.ResultSet.ColumnNames
- }
-
- columns := make([]string, len(row.ResultSet.Columns))
- if row.Content != nil {
- for i := range columns {
- field := row.ResultSet.Columns[i].(*Field)
- if tableName := field.table; len(tableName) > 0 {
- columns[i] = tableName + "." + field.name
- } else {
- columns[i] = field.name
- }
- }
- } else {
- for i := range columns {
- field := row.ResultSet.Columns[i].(*Field)
- columns[i] = field.name
+func (bi BinaryRow) Get(name string) (proto.Value, error) {
+ idx := -1
+ for i, it := range bi.fields {
+ if it.Name() == name {
+ idx = i
+ break
}
}
-
- row.ResultSet.ColumnNames = columns
- return columns
-}
-
-func (row *Row) Fields() []proto.Field {
- return row.ResultSet.Columns
-}
-
-func (row *Row) Data() []byte {
- return row.Content
-}
-
-func (row *Row) Encode(values []*proto.Value, columns []proto.Field, columnNames []string) proto.Row {
- var bf bytes.Buffer
- row.ResultSet = &ResultSet{
- Columns: columns,
- ColumnNames: columnNames,
+ if idx == -1 {
+ return nil, errors.Errorf("no such field '%s' found", name)
}
- for _, val := range values {
- bf.Write(val.Raw)
+ dest := make([]proto.Value, len(bi.fields))
+ if err := bi.Scan(dest); err != nil {
+ return nil, errors.WithStack(err)
}
- row.Content = bf.Bytes()
- return row
-}
-func (row *Row) Decode() ([]*proto.Value, error) {
- return nil, nil
+ return dest[idx], nil
}
-func (row *Row) GetColumnValue(column string) (interface{}, error) {
- values, err := row.Decode()
- if err != nil {
- return nil, err
- }
- for _, value := range values {
- if string(value.Raw) == column {
- return value.Val, nil
- }
- }
- return nil, nil
+func (bi BinaryRow) Fields() []proto.Field {
+ return bi.fields
}
-func (rows *TextRow) Encode(row []*proto.Value, fields []proto.Field, columnNames []string) proto.Row {
- var val []byte
-
- for i := 0; i < len(fields); i++ {
- field := fields[i].(*Field)
- switch field.fieldType {
- case mysql.FieldTypeTimestamp, mysql.FieldTypeDateTime,
- mysql.FieldTypeDate, mysql.FieldTypeNewDate:
- data, err := appendDateTime(row[i].Raw, row[i].Val.(time.Time))
- if err != nil {
- log.Errorf("appendDateTime fail, val=%+v, err=%v", &row[i], err)
- return nil
- }
- val = append(val, data...)
- default:
- val = PutLengthEncodedString(row[i].Raw)
- }
- }
-
- rows.ResultSet = &ResultSet{
- Columns: fields,
- ColumnNames: columnNames,
+func NewBinaryRow(fields []proto.Field, raw []byte) BinaryRow {
+ return BinaryRow{
+ fields: fields,
+ raw: raw,
}
-
- rows.Content = val
- return rows
}
-func (rows *TextRow) Decode() ([]*proto.Value, error) {
- dest := make([]*proto.Value, len(rows.ResultSet.Columns))
-
- // RowSet Packet
- var val []byte
- var isNull bool
- var n int
- var err error
- pos := 0
-
- for i := 0; i < len(rows.ResultSet.Columns); i++ {
- field := rows.ResultSet.Columns[i].(*Field)
-
- // Read bytes and convert to string
- val, isNull, n, err = readLengthEncodedString(rows.Content[pos:])
- dest[i] = &proto.Value{
- Typ: field.fieldType,
- Flags: field.flags,
- Len: n,
- Val: val,
- Raw: val,
- }
- pos += n
- if err == nil {
- if !isNull {
- switch field.fieldType {
- case mysql.FieldTypeTimestamp, mysql.FieldTypeDateTime,
- mysql.FieldTypeDate, mysql.FieldTypeNewDate:
- dest[i].Val, err = parseDateTime(
- val,
- time.Local,
- )
- if err == nil {
- continue
- }
- default:
- continue
- }
- } else {
- dest[i].Val = nil
- continue
- }
- }
- return nil, err // err != nil
- }
-
- return dest, nil
+func (bi BinaryRow) IsBinary() bool {
+ return true
}
-func (rows *BinaryRow) Encode(row []*proto.Value, fields []proto.Field, columnNames []string) proto.Row {
- length := 0
- nullBitMapLen := (len(fields) + 7 + 2) / 8
- for _, val := range row {
- if val != nil && val.Val != nil {
- l, err := val2MySQLLen(val)
- if err != nil {
- return nil
- }
- length += l
- }
- }
-
- length += nullBitMapLen + 1
-
- Data := *bufPool.Get(length)
- pos := 0
-
- pos = writeByte(Data, pos, 0x00)
-
- for i := 0; i < nullBitMapLen; i++ {
- pos = writeByte(Data, pos, 0x00)
- }
-
- for i, val := range row {
- if val == nil || val.Val == nil {
- bytePos := (i+2)/8 + 1
- bitPos := (i + 2) % 8
- Data[bytePos] |= 1 << uint(bitPos)
- } else {
- v, err := val2MySQL(val)
- if err != nil {
- return nil
- }
- pos += copy(Data[pos:], v)
- }
- }
-
- if pos != length {
- return nil
- }
+func (bi BinaryRow) Length() int {
+ return len(bi.raw)
+}
- rows.ResultSet = &ResultSet{
- Columns: fields,
- ColumnNames: columnNames,
+func (bi BinaryRow) WriteTo(w io.Writer) (n int64, err error) {
+ var wrote int
+ wrote, err = w.Write(bi.raw)
+ if err != nil {
+ return
}
-
- rows.Content = Data
- return rows
+ n += int64(wrote)
+ return
}
-func (rows *BinaryRow) Decode() ([]*proto.Value, error) {
- dest := make([]*proto.Value, len(rows.ResultSet.Columns))
-
- if rows.Content[0] != mysql.OKPacket {
- return nil, errors.NewSQLError(mysql.CRMalformedPacket, mysql.SSUnknownSQLState, "read binary rows (%v) failed", rows)
+func (bi BinaryRow) Scan(dest []proto.Value) error {
+ if bi.raw[0] != mysql.OKPacket {
+ return mysqlErrors.NewSQLError(mysql.CRMalformedPacket, mysql.SSUnknownSQLState, "read binary rows (%v) failed", bi)
}
// NULL-bitmap, [(column-count + 7 + 2) / 8 bytes]
pos := 1 + (len(dest)+7+2)>>3
- nullMask := rows.Content[1:pos]
+ nullMask := bi.raw[1:pos]
- for i := 0; i < len(rows.ResultSet.Columns); i++ {
+ for i := 0; i < len(bi.fields); i++ {
// Field is NULL
// (byte >> bit-pos) % 2 == 1
if ((nullMask[(i+2)>>3] >> uint((i+2)&7)) & 1) == 1 {
@@ -262,134 +111,64 @@ func (rows *BinaryRow) Decode() ([]*proto.Value, error) {
continue
}
- field := rows.ResultSet.Columns[i].(*Field)
+ field := bi.fields[i].(*Field)
// Convert to byte-coded string
- switch field.fieldType {
+ // TODO Optimize storage space based on the length of data types
+ mysqlType, _ := mysql.TypeToMySQL(field.fieldType)
+ switch mysql.FieldType(mysqlType) {
case mysql.FieldTypeNULL:
- dest[i] = &proto.Value{
- Typ: field.fieldType,
- Flags: field.flags,
- Len: 1,
- Val: nil,
- Raw: nil,
- }
+ dest[i] = nil
continue
// Numeric Types
case mysql.FieldTypeTiny:
if field.flags&mysql.UnsignedFlag != 0 {
- dest[i] = &proto.Value{
- Typ: field.fieldType,
- Flags: field.flags,
- Len: 1,
- Val: int64(rows.Content[pos]),
- Raw: rows.Content[pos : pos+1],
- }
+ dest[i] = int64(bi.raw[pos])
} else {
- dest[i] = &proto.Value{
- Typ: field.fieldType,
- Flags: field.flags,
- Len: 1,
- Val: int64(int8(rows.Content[pos])),
- Raw: rows.Content[pos : pos+1],
- }
+ dest[i] = int64(int8(bi.raw[pos]))
}
pos++
continue
case mysql.FieldTypeShort, mysql.FieldTypeYear:
if field.flags&mysql.UnsignedFlag != 0 {
- dest[i] = &proto.Value{
- Typ: field.fieldType,
- Flags: field.flags,
- Len: 2,
- Val: int64(binary.LittleEndian.Uint16(rows.Content[pos : pos+2])),
- Raw: rows.Content[pos : pos+1],
- }
+ dest[i] = int64(binary.LittleEndian.Uint16(bi.raw[pos : pos+2]))
} else {
- dest[i] = &proto.Value{
- Typ: field.fieldType,
- Flags: field.flags,
- Len: 2,
- Val: int64(int16(binary.LittleEndian.Uint16(rows.Content[pos : pos+2]))),
- Raw: rows.Content[pos : pos+1],
- }
+ dest[i] = int64(int16(binary.LittleEndian.Uint16(bi.raw[pos : pos+2])))
}
pos += 2
continue
case mysql.FieldTypeInt24, mysql.FieldTypeLong:
if field.flags&mysql.UnsignedFlag != 0 {
- dest[i] = &proto.Value{
- Typ: field.fieldType,
- Flags: field.flags,
- Len: 4,
- Val: int64(binary.LittleEndian.Uint32(rows.Content[pos : pos+4])),
- Raw: rows.Content[pos : pos+4],
- }
+ dest[i] = int64(binary.LittleEndian.Uint32(bi.raw[pos : pos+4]))
} else {
- dest[i] = &proto.Value{
- Typ: field.fieldType,
- Flags: field.flags,
- Len: 4,
- Val: int64(int32(binary.LittleEndian.Uint32(rows.Content[pos : pos+4]))),
- Raw: rows.Content[pos : pos+4],
- }
+ dest[i] = int64(int32(binary.LittleEndian.Uint32(bi.raw[pos : pos+4])))
}
pos += 4
continue
case mysql.FieldTypeLongLong:
if field.flags&mysql.UnsignedFlag != 0 {
- val := binary.LittleEndian.Uint64(rows.Content[pos : pos+8])
+ val := binary.LittleEndian.Uint64(bi.raw[pos : pos+8])
if val > math.MaxInt64 {
- dest[i] = &proto.Value{
- Typ: field.fieldType,
- Flags: field.flags,
- Len: 8,
- Val: uint64ToString(val),
- Raw: rows.Content[pos : pos+8],
- }
+ dest[i] = uint64ToString(val)
} else {
- dest[i] = &proto.Value{
- Typ: field.fieldType,
- Flags: field.flags,
- Len: 8,
- Val: int64(val),
- Raw: rows.Content[pos : pos+8],
- }
+ dest[i] = int64(val)
}
} else {
- dest[i] = &proto.Value{
- Typ: field.fieldType,
- Flags: field.flags,
- Len: 8,
- Val: int64(binary.LittleEndian.Uint64(rows.Content[pos : pos+8])),
- Raw: rows.Content[pos : pos+8],
- }
+ dest[i] = int64(binary.LittleEndian.Uint64(bi.raw[pos : pos+8]))
}
pos += 8
continue
case mysql.FieldTypeFloat:
- dest[i] = &proto.Value{
- Typ: field.fieldType,
- Flags: field.flags,
- Len: 4,
- Val: math.Float32frombits(binary.LittleEndian.Uint32(rows.Content[pos : pos+4])),
- Raw: rows.Content[pos : pos+4],
- }
+ dest[i] = math.Float32frombits(binary.LittleEndian.Uint32(bi.raw[pos : pos+4]))
pos += 4
continue
case mysql.FieldTypeDouble:
- dest[i] = &proto.Value{
- Typ: field.fieldType,
- Flags: field.flags,
- Len: 8,
- Val: math.Float64frombits(binary.LittleEndian.Uint64(rows.Content[pos : pos+8])),
- Raw: rows.Content[pos : pos+8],
- }
+ dest[i] = math.Float64frombits(binary.LittleEndian.Uint64(bi.raw[pos : pos+8]))
pos += 8
continue
@@ -402,31 +181,25 @@ func (rows *BinaryRow) Decode() ([]*proto.Value, error) {
var isNull bool
var n int
var err error
- val, isNull, n, err = readLengthEncodedString(rows.Content[pos:])
- dest[i] = &proto.Value{
- Typ: field.fieldType,
- Flags: field.flags,
- Len: n,
- Val: val,
- Raw: rows.Content[pos : pos+n],
- }
+ val, isNull, n, err = readLengthEncodedString(bi.raw[pos:])
+ dest[i] = val
pos += n
if err == nil {
if !isNull {
continue
} else {
- dest[i].Val = nil
+ dest[i] = nil
continue
}
}
- return nil, err
+ return err
case
mysql.FieldTypeDate, mysql.FieldTypeNewDate, // Date YYYY-MM-DD
mysql.FieldTypeTime, // Time [-][H]HH:MM:SS[.fractal]
mysql.FieldTypeTimestamp, mysql.FieldTypeDateTime: // Timestamp YYYY-MM-DD HH:MM:SS[.fractal]
- num, isNull, n := readLengthEncodedInteger(rows.Content[pos:])
+ num, isNull, n := readLengthEncodedInteger(bi.raw[pos:])
pos += n
var val interface{}
@@ -444,29 +217,14 @@ func (rows *BinaryRow) Decode() ([]*proto.Value, error) {
case 1, 2, 3, 4, 5, 6:
dstlen = 8 + 1 + decimals
default:
- return nil, fmt.Errorf(
- "protocol error, illegal decimals architecture.Value %d",
- field.decimals,
- )
- }
- val, err = formatBinaryTime(rows.Content[pos:pos+int(num)], dstlen)
- dest[i] = &proto.Value{
- Typ: field.fieldType,
- Flags: field.flags,
- Len: n,
- Val: val,
- Raw: rows.Content[pos : pos+n],
+ return errors.Errorf("protocol error, illegal decimals architecture.V %d", field.decimals)
}
+ val, err = formatBinaryTime(bi.raw[pos:pos+int(num)], dstlen)
+ dest[i] = val
default:
- val, err = parseBinaryDateTime(num, rows.Content[pos:], time.Local)
- dest[i] = &proto.Value{
- Typ: field.fieldType,
- Flags: field.flags,
- Len: n,
- Val: val,
- Raw: rows.Content[pos : pos+n],
- }
+ val, err = parseBinaryDateTime(num, bi.raw[pos:], time.Local)
if err == nil {
+ dest[i] = val
break
}
@@ -480,34 +238,141 @@ func (rows *BinaryRow) Decode() ([]*proto.Value, error) {
case 1, 2, 3, 4, 5, 6:
dstlen = 19 + 1 + decimals
default:
- return nil, fmt.Errorf(
- "protocol error, illegal decimals architecture.Value %d",
- field.decimals,
- )
+ return errors.Errorf("protocol error, illegal decimals architecture.V %d", field.decimals)
}
}
- val, err = formatBinaryDateTime(rows.Content[pos:pos+int(num)], dstlen)
- dest[i] = &proto.Value{
- Typ: field.fieldType,
- Flags: field.flags,
- Len: n,
- Val: val,
- Raw: rows.Content[pos : pos+n],
+ val, err = formatBinaryDateTime(bi.raw[pos:pos+int(num)], dstlen)
+ if err != nil {
+ return errors.WithStack(err)
}
+ dest[i] = val
}
if err == nil {
pos += int(num)
continue
- } else {
- return nil, err
}
+ return err
+
// Please report if this happens!
default:
- return nil, fmt.Errorf("unknown field type %d", field.fieldType)
+ return errors.Errorf("unknown field type %d", field.fieldType)
+ }
+ }
+
+ return nil
+}
+
+type TextRow struct {
+ fields []proto.Field
+ raw []byte
+}
+
+func (te TextRow) Get(name string) (proto.Value, error) {
+ idx := -1
+ for i, it := range te.fields {
+ if it.Name() == name {
+ idx = i
+ break
+ }
+ }
+ if idx == -1 {
+ return nil, errors.Errorf("no such field '%s' found", name)
+ }
+
+ dest := make([]proto.Value, len(te.fields))
+ if err := te.Scan(dest); err != nil {
+ return nil, errors.WithStack(err)
+ }
+
+ return dest[idx], nil
+}
+
+func (te TextRow) Fields() []proto.Field {
+ return te.fields
+}
+
+func NewTextRow(fields []proto.Field, raw []byte) TextRow {
+ return TextRow{
+ fields: fields,
+ raw: raw,
+ }
+}
+
+func (te TextRow) IsBinary() bool {
+ return false
+}
+
+func (te TextRow) Length() int {
+ return len(te.raw)
+}
+
+func (te TextRow) WriteTo(w io.Writer) (n int64, err error) {
+ var wrote int
+ wrote, err = w.Write(te.raw)
+ if err != nil {
+ return
+ }
+ n += int64(wrote)
+ return
+}
+
+func (te TextRow) Scan(dest []proto.Value) error {
+ // RowSet Packet
+ // var val []byte
+ // var isNull bool
+ var (
+ n int
+ err error
+ pos int
+ isNull bool
+ )
+
+ // TODO: support parseTime
+
+ var (
+ b []byte
+ loc = time.Local
+ )
+
+ for i := 0; i < len(te.fields); i++ {
+ b, isNull, n, err = readLengthEncodedString(te.raw[pos:])
+ if err != nil {
+ return errors.WithStack(err)
+ }
+ pos += n
+
+ if isNull {
+ dest[i] = nil
+ continue
+ }
+
+ switch te.fields[i].(*Field).fieldType {
+ case mysql.FieldTypeString, mysql.FieldTypeVarString, mysql.FieldTypeVarChar:
+ dest[i] = string(b)
+ case mysql.FieldTypeTiny, mysql.FieldTypeShort, mysql.FieldTypeLong,
+ mysql.FieldTypeInt24, mysql.FieldTypeLongLong, mysql.FieldTypeYear:
+ if te.fields[i].(*Field).flags&mysql.UnsignedFlag > 0 {
+ dest[i], err = strconv.ParseUint(string(b), 10, 64)
+ } else {
+ dest[i], err = strconv.ParseInt(string(b), 10, 64)
+ }
+ if err != nil {
+ return errors.WithStack(err)
+ }
+ case mysql.FieldTypeFloat, mysql.FieldTypeDouble, mysql.FieldTypeNewDecimal, mysql.FieldTypeDecimal:
+ if dest[i], err = strconv.ParseFloat(string(b), 64); err != nil {
+ return errors.WithStack(err)
+ }
+ case mysql.FieldTypeTimestamp, mysql.FieldTypeDateTime, mysql.FieldTypeDate, mysql.FieldTypeNewDate:
+ if dest[i], err = parseDateTime(b, loc); err != nil {
+ return errors.WithStack(err)
+ }
+ default:
+ dest[i] = b
}
}
- return dest, nil
+ return nil
}
diff --git a/pkg/mysql/rows/codec.go b/pkg/mysql/rows/codec.go
new file mode 100644
index 000000000..2dfdb8456
--- /dev/null
+++ b/pkg/mysql/rows/codec.go
@@ -0,0 +1,395 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package rows
+
+import (
+ "bytes"
+ "math"
+ "time"
+)
+
+import (
+ "github.com/pkg/errors"
+)
+
+import (
+ "github.com/arana-db/arana/pkg/util/bytesconv"
+)
+
+const _day = time.Hour * 24
+
+type ValueWriter interface {
+ WriteString(s string) (int64, error)
+ WriteUint64(v uint64) (int64, error)
+ WriteUint32(v uint32) (int64, error)
+ WriteUint16(v uint16) (int64, error)
+ WriteUint8(v uint8) (int64, error)
+ WriteFloat64(f float64) (int64, error)
+ WriteFloat32(f float32) (int64, error)
+ WriteDate(t time.Time) (int64, error)
+ WriteDateTime(t time.Time) (int64, error)
+ WriteDuration(d time.Duration) (int64, error)
+}
+
+var _ ValueWriter = (*BinaryValueWriter)(nil)
+
+type BinaryValueWriter bytes.Buffer
+
+func (bw *BinaryValueWriter) Bytes() []byte {
+ return (*bytes.Buffer)(bw).Bytes()
+}
+
+func (bw *BinaryValueWriter) Reset() {
+ (*bytes.Buffer)(bw).Reset()
+}
+
+func (bw *BinaryValueWriter) WriteDuration(d time.Duration) (n int64, err error) {
+ b := bw.buffer()
+ if d == 0 {
+ if err = b.WriteByte(0); err != nil {
+ return
+ }
+ n++
+ return
+ }
+
+ var (
+ length uint8
+ neg byte
+ )
+
+ if d < 0 {
+ neg = 1
+ d *= -1
+ }
+
+ var (
+ days = uint32(d / _day)
+ hours = uint8(d % _day / time.Hour)
+ minutes = uint8(d % time.Hour / time.Minute)
+ secs = uint8(d % time.Minute / time.Second)
+ msec = uint32(d % time.Millisecond / time.Microsecond)
+ )
+
+ if msec == 0 {
+ length = 8
+ } else {
+ length = 12
+ }
+
+ if err = b.WriteByte(length); err != nil {
+ return
+ }
+ if err = b.WriteByte(neg); err != nil {
+ return
+ }
+ if _, err = bw.WriteUint32(days); err != nil {
+ return
+ }
+ if err = b.WriteByte(hours); err != nil {
+ return
+ }
+ if err = b.WriteByte(minutes); err != nil {
+ return
+ }
+ if err = b.WriteByte(secs); err != nil {
+ return
+ }
+
+ if length == 12 {
+ if _, err = bw.WriteUint32(msec); err != nil {
+ return
+ }
+ }
+
+ n = int64(length) + 1
+
+ return
+}
+
+func (bw *BinaryValueWriter) WriteDate(t time.Time) (int64, error) {
+ if t.IsZero() {
+ return bw.writeTimeN(t, 0)
+ }
+ return bw.writeTimeN(t, 4)
+}
+
+func (bw *BinaryValueWriter) WriteDateTime(t time.Time) (int64, error) {
+ if t.IsZero() {
+ return bw.writeTimeN(t, 0)
+ }
+
+ if t.Nanosecond()/int(time.Microsecond) == 0 {
+ return bw.writeTimeN(t, 7)
+ }
+
+ return bw.writeTimeN(t, 11)
+}
+
+func (bw *BinaryValueWriter) writeTimeN(t time.Time, l int) (n int64, err error) {
+ var (
+ year = uint16(t.Year())
+ month = uint8(t.Month())
+ day = uint8(t.Day())
+ hour = uint8(t.Hour())
+ minute = uint8(t.Minute())
+ sec = uint8(t.Second())
+ )
+
+ if _, err = bw.WriteUint8(byte(l)); err != nil {
+ return
+ }
+
+ switch l {
+ case 0:
+ case 4:
+ if _, err = bw.WriteUint16(year); err != nil {
+ return
+ }
+ if _, err = bw.WriteUint8(month); err != nil {
+ return
+ }
+ if _, err = bw.WriteUint8(day); err != nil {
+ return
+ }
+ case 7:
+ if _, err = bw.WriteUint16(year); err != nil {
+ return
+ }
+ if _, err = bw.WriteUint8(month); err != nil {
+ return
+ }
+ if _, err = bw.WriteUint8(day); err != nil {
+ return
+ }
+ if _, err = bw.WriteUint8(hour); err != nil {
+ return
+ }
+ if _, err = bw.WriteUint8(minute); err != nil {
+ return
+ }
+ if _, err = bw.WriteUint8(sec); err != nil {
+ return
+ }
+ case 11:
+ if _, err = bw.WriteUint16(year); err != nil {
+ return
+ }
+ if _, err = bw.WriteUint8(month); err != nil {
+ return
+ }
+ if _, err = bw.WriteUint8(day); err != nil {
+ return
+ }
+ if _, err = bw.WriteUint8(hour); err != nil {
+ return
+ }
+ if _, err = bw.WriteUint8(minute); err != nil {
+ return
+ }
+ if _, err = bw.WriteUint8(sec); err != nil {
+ return
+ }
+ msec := uint32(int64(t.Nanosecond()) / int64(time.Microsecond))
+ if _, err = bw.WriteUint32(msec); err != nil {
+ return
+ }
+ default:
+ err = errors.Errorf("illegal time length %d", l)
+ return
+ }
+
+ n = int64(l) + 1
+ return
+}
+
+func (bw *BinaryValueWriter) WriteFloat32(f float32) (int64, error) {
+ n := math.Float32bits(f)
+ return bw.WriteUint32(n)
+}
+
+func (bw *BinaryValueWriter) WriteFloat64(f float64) (int64, error) {
+ n := math.Float64bits(f)
+ return bw.WriteUint64(n)
+}
+
+func (bw *BinaryValueWriter) WriteUint8(v uint8) (int64, error) {
+ if err := bw.buffer().WriteByte(v); err != nil {
+ return 0, err
+ }
+ return 1, nil
+}
+
+func (bw *BinaryValueWriter) WriteString(s string) (int64, error) {
+ return bw.writeLenEncString(s)
+}
+
+func (bw *BinaryValueWriter) WriteUint16(v uint16) (n int64, err error) {
+ if err = bw.buffer().WriteByte(byte(v)); err != nil {
+ return
+ }
+ n++
+ if err = bw.buffer().WriteByte(byte(v >> 8)); err != nil {
+ return
+ }
+ n++
+ return
+}
+
+func (bw *BinaryValueWriter) WriteUint32(v uint32) (n int64, err error) {
+ b := bw.buffer()
+ if err = b.WriteByte(byte(v)); err != nil {
+ return
+ }
+ if err = b.WriteByte(byte(v >> 8)); err != nil {
+ return
+ }
+ if err = b.WriteByte(byte(v >> 16)); err != nil {
+ return
+ }
+ if err = b.WriteByte(byte(v >> 24)); err != nil {
+ return
+ }
+ n += 4
+ return
+}
+
+func (bw *BinaryValueWriter) WriteUint64(v uint64) (n int64, err error) {
+ b := bw.buffer()
+ if err = b.WriteByte(byte(v)); err != nil {
+ return
+ }
+ if err = b.WriteByte(byte(v >> 8)); err != nil {
+ return
+ }
+ if err = b.WriteByte(byte(v >> 16)); err != nil {
+ return
+ }
+ if err = b.WriteByte(byte(v >> 24)); err != nil {
+ return
+ }
+ if err = b.WriteByte(byte(v >> 32)); err != nil {
+ return
+ }
+ if err = b.WriteByte(byte(v >> 40)); err != nil {
+ return
+ }
+ if err = b.WriteByte(byte(v >> 48)); err != nil {
+ return
+ }
+ if err = b.WriteByte(byte(v >> 56)); err != nil {
+ return
+ }
+ n += 8
+ return
+}
+
+func (bw *BinaryValueWriter) writeLenEncString(s string) (n int64, err error) {
+ var wrote int64
+
+ if wrote, err = bw.writeLenEncInt(uint64(len(s))); err != nil {
+ return
+ }
+ n += wrote
+
+ if wrote, err = bw.writeEOFString(s); err != nil {
+ return
+ }
+ n += wrote
+
+ return
+}
+
+func (bw *BinaryValueWriter) writeEOFString(s string) (n int64, err error) {
+ var wrote int
+ if wrote, err = bw.buffer().Write(bytesconv.StringToBytes(s)); err != nil {
+ return
+ }
+ n += int64(wrote)
+ return
+}
+
+func (bw *BinaryValueWriter) writeLenEncInt(i uint64) (n int64, err error) {
+ b := bw.buffer()
+ switch {
+ case i < 251:
+ if err = b.WriteByte(byte(i)); err == nil {
+ n++
+ }
+ case i < 1<<16:
+ if err = b.WriteByte(0xfc); err != nil {
+ return
+ }
+ if err = b.WriteByte(byte(i)); err != nil {
+ return
+ }
+ if err = b.WriteByte(byte(i >> 8)); err != nil {
+ return
+ }
+ n += 3
+ case i < 1<<24:
+ if err = b.WriteByte(0xfd); err != nil {
+ return
+ }
+ if err = b.WriteByte(byte(i)); err != nil {
+ return
+ }
+ if err = b.WriteByte(byte(i >> 8)); err != nil {
+ return
+ }
+ if err = b.WriteByte(byte(i >> 16)); err != nil {
+ return
+ }
+ n += 4
+ default:
+ if err = b.WriteByte(0xfe); err != nil {
+ return
+ }
+ if err = b.WriteByte(byte(i)); err != nil {
+ return
+ }
+ if err = b.WriteByte(byte(i >> 8)); err != nil {
+ return
+ }
+ if err = b.WriteByte(byte(i >> 16)); err != nil {
+ return
+ }
+ if err = b.WriteByte(byte(i >> 24)); err != nil {
+ return
+ }
+ if err = b.WriteByte(byte(i >> 32)); err != nil {
+ return
+ }
+ if err = b.WriteByte(byte(i >> 40)); err != nil {
+ return
+ }
+ if err = b.WriteByte(byte(i >> 48)); err != nil {
+ return
+ }
+ if err = b.WriteByte(byte(i >> 56)); err != nil {
+ return
+ }
+ n += 9
+ }
+
+ return
+}
+
+func (bw *BinaryValueWriter) buffer() *bytes.Buffer {
+ return (*bytes.Buffer)(bw)
+}
diff --git a/pkg/mysql/rows/codec_test.go b/pkg/mysql/rows/codec_test.go
new file mode 100644
index 000000000..ed3eab371
--- /dev/null
+++ b/pkg/mysql/rows/codec_test.go
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package rows
+
+import (
+ "testing"
+ "time"
+)
+
+import (
+ "github.com/stretchr/testify/assert"
+)
+
+func TestBinaryValueWriter(t *testing.T) {
+ var w BinaryValueWriter
+
+ n, err := w.WriteString("foo")
+ assert.NoError(t, err)
+ assert.Equal(t, int64(4), n)
+ assert.Equal(t, []byte{0x03, 0x66, 0x6f, 0x6f}, w.Bytes())
+
+ w.Reset()
+ n, err = w.WriteUint64(1)
+ assert.NoError(t, err)
+ assert.Equal(t, int64(8), n)
+ assert.Equal(t, []byte{0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, w.Bytes())
+
+ w.Reset()
+ n, err = w.WriteUint32(1)
+ assert.NoError(t, err)
+ assert.Equal(t, int64(4), n)
+ assert.Equal(t, []byte{0x01, 0x00, 0x00, 0x00}, w.Bytes())
+
+ w.Reset()
+ n, err = w.WriteUint16(1)
+ assert.NoError(t, err)
+ assert.Equal(t, int64(2), n)
+ assert.Equal(t, []byte{0x01, 0x00}, w.Bytes())
+
+ w.Reset()
+ n, err = w.WriteFloat64(10.2)
+ assert.NoError(t, err)
+ assert.Equal(t, int64(8), n)
+ assert.Equal(t, []byte{0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x24, 0x40}, w.Bytes())
+
+ w.Reset()
+ n, err = w.WriteFloat32(10.2)
+ assert.NoError(t, err)
+ assert.Equal(t, int64(4), n)
+ assert.Equal(t, []byte{0x33, 0x33, 0x23, 0x41}, w.Bytes())
+
+ // -120d 19:27:30.000 001
+ w.Reset()
+ n, err = w.WriteDuration(-(120*24*time.Hour + 19*time.Hour + 27*time.Minute + 30*time.Second + 1*time.Microsecond))
+ assert.NoError(t, err)
+ assert.Equal(t, int64(13), n)
+ assert.Equal(t, []byte{0x0c, 0x01, 0x78, 0x00, 0x00, 0x00, 0x13, 0x1b, 0x1e, 0x01, 0x00, 0x00, 0x00}, w.Bytes())
+
+ // -120d 19:27:30
+ w.Reset()
+ n, err = w.WriteDuration(-(120*24*time.Hour + 19*time.Hour + 27*time.Minute + 30*time.Second))
+ assert.NoError(t, err)
+ assert.Equal(t, int64(9), n)
+ assert.Equal(t, []byte{0x08, 0x01, 0x78, 0x00, 0x00, 0x00, 0x13, 0x1b, 0x1e}, w.Bytes())
+
+ // 0d 00:00:00
+ w.Reset()
+ n, err = w.WriteDuration(0)
+ assert.NoError(t, err)
+ assert.Equal(t, int64(1), n)
+ assert.Equal(t, []byte{0x00}, w.Bytes())
+}
diff --git a/pkg/mysql/rows/virtual_row.go b/pkg/mysql/rows/virtual_row.go
new file mode 100644
index 000000000..67198fc9c
--- /dev/null
+++ b/pkg/mysql/rows/virtual_row.go
@@ -0,0 +1,289 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package rows
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "sync"
+ "time"
+)
+
+import (
+ gxbig "github.com/dubbogo/gost/math/big"
+
+ perrors "github.com/pkg/errors"
+)
+
+import (
+ consts "github.com/arana-db/arana/pkg/constants/mysql"
+ "github.com/arana-db/arana/pkg/mysql"
+ "github.com/arana-db/arana/pkg/proto"
+ "github.com/arana-db/arana/pkg/util/bufferpool"
+ "github.com/arana-db/arana/pkg/util/bytesconv"
+)
+
+var (
+ _ VirtualRow = (*binaryVirtualRow)(nil)
+ _ VirtualRow = (*textVirtualRow)(nil)
+)
+
+var errScanDiffSize = errors.New("cannot scan to dest with different length")
+
+// IsScanDiffSizeErr returns true if target error is caused by scanning values with different size.
+func IsScanDiffSizeErr(err error) bool {
+ return perrors.Is(err, errScanDiffSize)
+}
+
+// VirtualRow represents virtual row which is created manually.
+type VirtualRow interface {
+ proto.KeyedRow
+ // Values returns all values of current row.
+ Values() []proto.Value
+}
+
+type baseVirtualRow struct {
+ lengthOnce sync.Once
+ length int64
+
+ fields []proto.Field
+ cells []proto.Value
+}
+
+func (b *baseVirtualRow) Get(name string) (proto.Value, error) {
+ idx := -1
+ for i, it := range b.fields {
+ if it.Name() == name {
+ idx = i
+ break
+ }
+ }
+
+ if idx == -1 {
+ return nil, perrors.Errorf("no such field '%s' found", name)
+ }
+
+ return b.cells[idx], nil
+}
+
+func (b *baseVirtualRow) Scan(dest []proto.Value) error {
+ if len(dest) != len(b.cells) {
+ return perrors.WithStack(errScanDiffSize)
+ }
+ copy(dest, b.cells)
+ return nil
+}
+
+func (b *baseVirtualRow) Values() []proto.Value {
+ return b.cells
+}
+
+type binaryVirtualRow baseVirtualRow
+
+func (vi *binaryVirtualRow) Get(name string) (proto.Value, error) {
+ return (*baseVirtualRow)(vi).Get(name)
+}
+
+func (vi *binaryVirtualRow) Fields() []proto.Field {
+ return vi.fields
+}
+
+func (vi *binaryVirtualRow) Values() []proto.Value {
+ return (*baseVirtualRow)(vi).Values()
+}
+
+func (vi *binaryVirtualRow) IsBinary() bool {
+ return true
+}
+
+func (vi *binaryVirtualRow) WriteTo(w io.Writer) (n int64, err error) {
+ // https://dev.mysql.com/doc/internals/en/null-bitmap.html
+
+ // NULL-bitmap, [(column-count + 7 + 2) / 8 bytes]
+ const offset = 2
+ nullBitmapLen := (len(vi.cells) + 7 + offset) >> 3
+
+ b := bufferpool.Get()
+ defer bufferpool.Put(b)
+
+ b.Grow(1 + nullBitmapLen)
+
+ b.WriteByte(0x00)
+ for i := 0; i < nullBitmapLen; i++ {
+ b.WriteByte(0x00)
+ }
+
+ bw := (*BinaryValueWriter)(b)
+
+ var next proto.Value
+ for i := 0; i < len(vi.cells); i++ {
+ next = vi.cells[i]
+ if next == nil {
+ var (
+ bytePos = (i+2)/8 + 1
+ bitPos = (i + 2) % 8
+ )
+ b.Bytes()[bytePos] |= 1 << bitPos
+ }
+
+ switch val := next.(type) {
+ case uint64:
+ _, err = bw.WriteUint64(val)
+ case int64:
+ _, err = bw.WriteUint64(uint64(val))
+ case uint32:
+ _, err = bw.WriteUint32(val)
+ case int32:
+ _, err = bw.WriteUint32(uint32(val))
+ case uint16:
+ _, err = bw.WriteUint16(val)
+ case int16:
+ _, err = bw.WriteUint16(uint16(val))
+ case string:
+ _, err = bw.WriteString(val)
+ case float32:
+ _, err = bw.WriteFloat32(val)
+ case float64:
+ _, err = bw.WriteFloat64(val)
+ case time.Time:
+ _, err = bw.WriteDateTime(val)
+ case time.Duration:
+ _, err = bw.WriteDuration(val)
+ case *gxbig.Decimal:
+ _, err = bw.WriteString(val.String())
+ default:
+ err = perrors.Errorf("unknown value type %T", val)
+ }
+
+ if err != nil {
+ return
+ }
+ }
+
+ return b.WriteTo(w)
+}
+
+func (vi *binaryVirtualRow) Length() int {
+ vi.lengthOnce.Do(func() {
+ n, _ := vi.WriteTo(io.Discard)
+ vi.length = n
+ })
+ return int(vi.length)
+}
+
+func (vi *binaryVirtualRow) Scan(dest []proto.Value) error {
+ return (*baseVirtualRow)(vi).Scan(dest)
+}
+
+type textVirtualRow baseVirtualRow
+
+func (t *textVirtualRow) Get(name string) (proto.Value, error) {
+ return (*baseVirtualRow)(t).Get(name)
+}
+
+func (t *textVirtualRow) Fields() []proto.Field {
+ return t.fields
+}
+
+func (t *textVirtualRow) WriteTo(w io.Writer) (n int64, err error) {
+ bf := bufferpool.Get()
+ defer bufferpool.Put(bf)
+
+ for i := 0; i < len(t.fields); i++ {
+ var (
+ field = t.fields[i].(*mysql.Field)
+ cell = t.cells[i]
+ )
+
+ if cell == nil {
+ if err = bf.WriteByte(0xfb); err != nil {
+ err = perrors.WithStack(err)
+ return
+ }
+ continue
+ }
+
+ switch field.FieldType() {
+ case consts.FieldTypeTimestamp, consts.FieldTypeDateTime, consts.FieldTypeDate, consts.FieldTypeNewDate:
+ var b []byte
+ if b, err = mysql.AppendDateTime(b, t.cells[i].(time.Time)); err != nil {
+ err = perrors.WithStack(err)
+ return
+ }
+ if _, err = (*BinaryValueWriter)(bf).WriteString(bytesconv.BytesToString(b)); err != nil {
+ err = perrors.WithStack(err)
+ }
+ default:
+ var s string
+ switch val := cell.(type) {
+ case fmt.Stringer:
+ s = val.String()
+ default:
+ s = fmt.Sprint(cell)
+ }
+ if _, err = (*BinaryValueWriter)(bf).WriteString(s); err != nil {
+ err = perrors.WithStack(err)
+ return
+ }
+ }
+ }
+
+ n, err = bf.WriteTo(w)
+
+ return
+}
+
+func (t *textVirtualRow) IsBinary() bool {
+ return false
+}
+
+func (t *textVirtualRow) Length() int {
+ t.lengthOnce.Do(func() {
+ t.length, _ = t.WriteTo(io.Discard)
+ })
+ return int(t.length)
+}
+
+func (t *textVirtualRow) Scan(dest []proto.Value) error {
+ return (*baseVirtualRow)(t).Scan(dest)
+}
+
+func (t *textVirtualRow) Values() []proto.Value {
+ return (*baseVirtualRow)(t).Values()
+}
+
+// NewBinaryVirtualRow creates a virtual row with binary-protocol.
+func NewBinaryVirtualRow(fields []proto.Field, cells []proto.Value) VirtualRow {
+ return (*binaryVirtualRow)(newBaseVirtualRow(fields, cells))
+}
+
+// NewTextVirtualRow creates a virtual row with text-protocol.
+func NewTextVirtualRow(fields []proto.Field, cells []proto.Value) VirtualRow {
+ return (*textVirtualRow)(newBaseVirtualRow(fields, cells))
+}
+
+func newBaseVirtualRow(fields []proto.Field, cells []proto.Value) *baseVirtualRow {
+ if len(fields) != len(cells) {
+ panic(fmt.Sprintf("the lengths of fields and cells are doesn't match!"))
+ }
+ return &baseVirtualRow{
+ fields: fields,
+ cells: cells,
+ }
+}
diff --git a/pkg/mysql/rows/virtual_row_test.go b/pkg/mysql/rows/virtual_row_test.go
new file mode 100644
index 000000000..602250c82
--- /dev/null
+++ b/pkg/mysql/rows/virtual_row_test.go
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package rows
+
+import (
+ "bytes"
+ "database/sql"
+ "testing"
+ "time"
+)
+
+import (
+ "github.com/stretchr/testify/assert"
+)
+
+import (
+ consts "github.com/arana-db/arana/pkg/constants/mysql"
+ "github.com/arana-db/arana/pkg/mysql"
+ "github.com/arana-db/arana/pkg/proto"
+)
+
+func TestNew(t *testing.T) {
+ fields := []proto.Field{
+ mysql.NewField("name", consts.FieldTypeString),
+ mysql.NewField("uid", consts.FieldTypeLongLong),
+ mysql.NewField("created_at", consts.FieldTypeDateTime),
+ }
+
+ var (
+ row VirtualRow
+ b bytes.Buffer
+ now = time.Unix(time.Now().Unix(), 0)
+ values = []proto.Value{"foobar", int64(1), now}
+ )
+
+ t.Run("Binary", func(t *testing.T) {
+ b.Reset()
+
+ row = NewBinaryVirtualRow(fields, values)
+ _, err := row.WriteTo(&b)
+ assert.NoError(t, err)
+
+ br := mysql.NewBinaryRow(fields, b.Bytes())
+ cells := make([]proto.Value, len(fields))
+ err = br.Scan(cells)
+ assert.NoError(t, err)
+
+ var (
+ name sql.NullString
+ uid sql.NullInt64
+ createdAt sql.NullTime
+ )
+
+ _ = name.Scan(cells[0])
+ _ = uid.Scan(cells[1])
+ _ = createdAt.Scan(cells[2])
+
+ t.Log("name:", name.String)
+ t.Log("uid:", uid.Int64)
+ t.Log("created_at:", createdAt.Time)
+
+ assert.Equal(t, "foobar", name.String)
+ assert.Equal(t, int64(1), uid.Int64)
+ assert.Equal(t, now, createdAt.Time)
+ })
+
+ t.Run("Text", func(t *testing.T) {
+ b.Reset()
+
+ row = NewTextVirtualRow(fields, values)
+ _, err := row.WriteTo(&b)
+ assert.NoError(t, err)
+
+ cells := make([]proto.Value, len(fields))
+
+ tr := mysql.NewTextRow(fields, b.Bytes())
+ err = tr.Scan(cells)
+ assert.NoError(t, err)
+
+ var (
+ name sql.NullString
+ uid sql.NullInt64
+ createdAt sql.NullTime
+ )
+
+ _ = name.Scan(cells[0])
+ _ = uid.Scan(cells[1])
+ _ = createdAt.Scan(cells[2])
+
+ t.Log("name:", name.String)
+ t.Log("uid:", uid.Int64)
+ t.Log("created_at:", createdAt.Time)
+
+ assert.Equal(t, "foobar", name.String)
+ assert.Equal(t, int64(1), uid.Int64)
+ assert.Equal(t, now, createdAt.Time)
+ })
+}
diff --git a/pkg/mysql/rows_test.go b/pkg/mysql/rows_test.go
deleted file mode 100644
index 5e7f5a57b..000000000
--- a/pkg/mysql/rows_test.go
+++ /dev/null
@@ -1,155 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package mysql
-
-import (
- "testing"
-)
-
-import (
- "github.com/stretchr/testify/assert"
-)
-
-import (
- "github.com/arana-db/arana/pkg/constants/mysql"
- "github.com/arana-db/arana/pkg/proto"
-)
-
-func TestFields(t *testing.T) {
- row := &Row{
- Content: createContent(),
- ResultSet: createResultSet(),
- }
- fields := row.Fields()
- assert.Equal(t, 3, len(fields))
- assert.Equal(t, "db_arana", fields[0].DataBaseName())
- assert.Equal(t, "t_order", fields[1].TableName())
- assert.Equal(t, "DECIMAL", fields[2].TypeDatabaseName())
-}
-
-func TestData(t *testing.T) {
- row := &Row{
- Content: createContent(),
- ResultSet: createResultSet(),
- }
- content := row.Data()
- assert.Equal(t, 3, len(content))
- assert.Equal(t, byte('1'), content[0])
- assert.Equal(t, byte('2'), content[1])
- assert.Equal(t, byte('3'), content[2])
-}
-
-func TestColumnsForColumnNames(t *testing.T) {
- row := &Row{
- Content: createContent(),
- ResultSet: createResultSet(),
- }
- columns := row.Columns()
- assert.Equal(t, "t_order.id", columns[0])
- assert.Equal(t, "t_order.order_id", columns[1])
- assert.Equal(t, "t_order.order_amount", columns[2])
-}
-
-func TestColumnsForColumns(t *testing.T) {
- row := &Row{
- Content: createContent(),
- ResultSet: &ResultSet{
- Columns: createColumns(),
- ColumnNames: nil,
- },
- }
- columns := row.Columns()
- assert.Equal(t, "t_order.id", columns[0])
- assert.Equal(t, "t_order.order_id", columns[1])
- assert.Equal(t, "t_order.order_amount", columns[2])
-}
-
-func TestDecodeForRow(t *testing.T) {
- row := &Row{
- Content: createContent(),
- ResultSet: createResultSet(),
- }
- val, err := row.Decode()
- //TODO row.Decode() is empty.
- assert.Nil(t, val)
- assert.Nil(t, err)
-}
-
-func TestEncodeForRow(t *testing.T) {
- values := make([]*proto.Value, 0, 3)
- names := []string{
- "id", "order_id", "order_amount",
- }
- fields := createColumns()
- for i := 0; i < 3; i++ {
- values = append(values, &proto.Value{Raw: []byte{byte(i)}, Len: 1})
- }
-
- row := &Row{}
- r := row.Encode(values, fields, names)
- assert.Equal(t, 3, len(r.Columns()))
- assert.Equal(t, 3, len(r.Fields()))
- assert.Equal(t, []byte{byte(0), byte(1), byte(2)}, r.Data())
-
-}
-
-func createContent() []byte {
- result := []byte{
- '1', '2', '3',
- }
- return result
-}
-
-func createResultSet() *ResultSet {
-
- result := &ResultSet{
- Columns: createColumns(),
- ColumnNames: createColumnNames(),
- }
-
- return result
-}
-
-func createColumns() []proto.Field {
- result := []proto.Field{
- &Field{
- database: "db_arana",
- table: "t_order",
- name: "id",
- fieldType: mysql.FieldTypeLong,
- }, &Field{
- database: "db_arana",
- table: "t_order",
- name: "order_id",
- fieldType: mysql.FieldTypeLong,
- }, &Field{
- database: "db_arana",
- table: "t_order",
- name: "order_amount",
- fieldType: mysql.FieldTypeDecimal,
- },
- }
- return result
-}
-
-func createColumnNames() []string {
- result := []string{
- "t_order.id", "t_order.order_id", "t_order.order_amount",
- }
- return result
-}
diff --git a/pkg/mysql/server.go b/pkg/mysql/server.go
index b41af0614..a79c6dedf 100644
--- a/pkg/mysql/server.go
+++ b/pkg/mysql/server.go
@@ -33,7 +33,7 @@ import (
import (
_ "github.com/arana-db/parser/test_driver"
- err2 "github.com/pkg/errors"
+ perrors "github.com/pkg/errors"
"go.uber.org/atomic"
)
@@ -157,11 +157,10 @@ func (l *Listener) handle(conn net.Conn, connectionID uint32) {
if x := recover(); x != nil {
log.Errorf("mysql_server caught panic:\n%v", x)
}
-
conn.Close()
l.executor.ConnectionClose(&proto.Context{
Context: context.Background(),
- ConnectionID: l.connectionID,
+ ConnectionID: c.ConnectionID,
})
}()
@@ -182,8 +181,8 @@ func (l *Listener) handle(conn net.Conn, connectionID uint32) {
for {
c.sequence = 0
- data, err := c.readEphemeralPacket()
- if err != nil {
+ var data []byte
+ if data, err = c.readEphemeralPacket(); err != nil {
// Don't log EOF errors. They cause too much spam.
if err != io.EOF && !strings.Contains(err.Error(), "use of closed network connection") {
log.Errorf("Error reading packet from %s: %v", c, err)
@@ -197,10 +196,16 @@ func (l *Listener) handle(conn net.Conn, connectionID uint32) {
Context: context.Background(),
Schema: c.Schema,
Tenant: c.Tenant,
- ConnectionID: l.connectionID,
+ ConnectionID: c.ConnectionID,
Data: content,
}
+
if err = l.ExecuteCommand(c, ctx); err != nil {
+ if err == io.EOF {
+ log.Debugf("the connection#%d of remote client %s requests quit", c.ConnectionID, c.conn.(*net.TCPConn).RemoteAddr())
+ } else {
+ log.Errorf("failed to execute command: %v", err)
+ }
return
}
}
@@ -273,20 +278,19 @@ func (l *Listener) writeHandshakeV10(c *Conn, enableTLS bool, salt []byte) error
capabilities |= mysql.CapabilityClientSSL
}
- length :=
- 1 + // protocol version
- lenNullString(l.conf.ServerVersion) +
- 4 + // connection ID
- 8 + // first part of salt Content
- 1 + // filler byte
- 2 + // capability flags (lower 2 bytes)
- 1 + // character set
- 2 + // status flag
- 2 + // capability flags (upper 2 bytes)
- 1 + // length of auth plugin Content
- 10 + // reserved (0)
- 13 + // auth-plugin-Content
- lenNullString(mysql.MysqlNativePassword) // auth-plugin-name
+ length := 1 + // protocol version
+ lenNullString(l.conf.ServerVersion) +
+ 4 + // connection ID
+ 8 + // first part of salt Content
+ 1 + // filler byte
+ 2 + // capability flags (lower 2 bytes)
+ 1 + // character set
+ 2 + // status flag
+ 2 + // capability flags (upper 2 bytes)
+ 1 + // length of auth plugin Content
+ 10 + // reserved (0)
+ 13 + // auth-plugin-Content
+ lenNullString(mysql.MysqlNativePassword) // auth-plugin-name
data := c.startEphemeralPacket(length)
pos := 0
@@ -334,7 +338,7 @@ func (l *Listener) writeHandshakeV10(c *Conn, enableTLS bool, salt []byte) error
// Sanity check.
if pos != len(data) {
- return err2.Errorf("error building Handshake packet: got %v bytes expected %v", pos, len(data))
+ return perrors.Errorf("error building Handshake packet: got %v bytes expected %v", pos, len(data))
}
if err := c.writeEphemeralPacket(); err != nil {
@@ -359,10 +363,10 @@ func (l *Listener) parseClientHandshakePacket(firstTime bool, data []byte) (*han
// Client flags, 4 bytes.
clientFlags, pos, ok := readUint32(data, pos)
if !ok {
- return nil, err2.New("parseClientHandshakePacket: can't read client flags")
+ return nil, perrors.New("parseClientHandshakePacket: can't read client flags")
}
if clientFlags&mysql.CapabilityClientProtocol41 == 0 {
- return nil, err2.New("parseClientHandshakePacket: only support protocol 4.1")
+ return nil, perrors.New("parseClientHandshakePacket: only support protocol 4.1")
}
// Remember a subset of the capabilities, so we can use them
@@ -381,13 +385,13 @@ func (l *Listener) parseClientHandshakePacket(firstTime bool, data []byte) (*han
// See doc.go for more information.
_, pos, ok = readUint32(data, pos)
if !ok {
- return nil, err2.New("parseClientHandshakePacket: can't read maxPacketSize")
+ return nil, perrors.New("parseClientHandshakePacket: can't read maxPacketSize")
}
// Character set. Need to handle it.
characterSet, pos, ok := readByte(data, pos)
if !ok {
- return nil, err2.New("parseClientHandshakePacket: can't read characterSet")
+ return nil, perrors.New("parseClientHandshakePacket: can't read characterSet")
}
l.characterSet = characterSet
@@ -407,7 +411,7 @@ func (l *Listener) parseClientHandshakePacket(firstTime bool, data []byte) (*han
// username
username, pos, ok := readNullString(data, pos)
if !ok {
- return nil, err2.New("parseClientHandshakePacket: can't read username")
+ return nil, perrors.New("parseClientHandshakePacket: can't read username")
}
// auth-response can have three forms.
@@ -416,29 +420,29 @@ func (l *Listener) parseClientHandshakePacket(firstTime bool, data []byte) (*han
var l uint64
l, pos, ok = readLenEncInt(data, pos)
if !ok {
- return nil, err2.New("parseClientHandshakePacket: can't read auth-response variable length")
+ return nil, perrors.New("parseClientHandshakePacket: can't read auth-response variable length")
}
authResponse, pos, ok = readBytesCopy(data, pos, int(l))
if !ok {
- return nil, err2.New("parseClientHandshakePacket: can't read auth-response")
+ return nil, perrors.New("parseClientHandshakePacket: can't read auth-response")
}
} else if clientFlags&mysql.CapabilityClientSecureConnection != 0 {
var l byte
l, pos, ok = readByte(data, pos)
if !ok {
- return nil, err2.New("parseClientHandshakePacket: can't read auth-response length")
+ return nil, perrors.New("parseClientHandshakePacket: can't read auth-response length")
}
authResponse, pos, ok = readBytesCopy(data, pos, int(l))
if !ok {
- return nil, err2.New("parseClientHandshakePacket: can't read auth-response")
+ return nil, perrors.New("parseClientHandshakePacket: can't read auth-response")
}
} else {
a := ""
a, pos, ok = readNullString(data, pos)
if !ok {
- return nil, err2.New("parseClientHandshakePacket: can't read auth-response")
+ return nil, perrors.New("parseClientHandshakePacket: can't read auth-response")
}
authResponse = []byte(a)
}
@@ -449,7 +453,7 @@ func (l *Listener) parseClientHandshakePacket(firstTime bool, data []byte) (*han
dbname := ""
dbname, pos, ok = readNullString(data, pos)
if !ok {
- return nil, err2.New("parseClientHandshakePacket: can't read dbname")
+ return nil, perrors.New("parseClientHandshakePacket: can't read dbname")
}
schemaName = dbname
}
@@ -459,7 +463,7 @@ func (l *Listener) parseClientHandshakePacket(firstTime bool, data []byte) (*han
if clientFlags&mysql.CapabilityClientPluginAuth != 0 {
authMethod, pos, ok = readNullString(data, pos)
if !ok {
- return nil, err2.New("parseClientHandshakePacket: can't read authMethod")
+ return nil, perrors.New("parseClientHandshakePacket: can't read authMethod")
}
}
@@ -542,7 +546,7 @@ func (l *Listener) ExecuteCommand(c *Conn, ctx *proto.Context) error {
case mysql.ComQuit:
// https://dev.mysql.com/doc/internals/en/com-quit.html
c.recycleReadPacket()
- return err2.New("ComQuit")
+ return io.EOF
case mysql.ComInitDB:
return l.handleInitDB(c, ctx)
case mysql.ComQuery:
@@ -582,7 +586,7 @@ func parseConnAttrs(data []byte, pos int) (map[string]string, int, error) {
attrLen, pos, ok := readLenEncInt(data, pos)
if !ok {
- return nil, 0, err2.Errorf("parseClientHandshakePacket: can't read connection attributes variable length")
+ return nil, 0, perrors.Errorf("parseClientHandshakePacket: can't read connection attributes variable length")
}
var attrLenRead uint64
@@ -593,27 +597,27 @@ func parseConnAttrs(data []byte, pos int) (map[string]string, int, error) {
var keyLen byte
keyLen, pos, ok = readByte(data, pos)
if !ok {
- return nil, 0, err2.Errorf("parseClientHandshakePacket: can't read connection attribute key length")
+ return nil, 0, perrors.Errorf("parseClientHandshakePacket: can't read connection attribute key length")
}
attrLenRead += uint64(keyLen) + 1
var connAttrKey []byte
connAttrKey, pos, ok = readBytesCopy(data, pos, int(keyLen))
if !ok {
- return nil, 0, err2.Errorf("parseClientHandshakePacket: can't read connection attribute key")
+ return nil, 0, perrors.Errorf("parseClientHandshakePacket: can't read connection attribute key")
}
var valLen byte
valLen, pos, ok = readByte(data, pos)
if !ok {
- return nil, 0, err2.Errorf("parseClientHandshakePacket: can't read connection attribute value length")
+ return nil, 0, perrors.Errorf("parseClientHandshakePacket: can't read connection attribute value length")
}
attrLenRead += uint64(valLen) + 1
var connAttrVal []byte
connAttrVal, pos, ok = readBytesCopy(data, pos, int(valLen))
if !ok {
- return nil, 0, err2.Errorf("parseClientHandshakePacket: can't read connection attribute value")
+ return nil, 0, perrors.Errorf("parseClientHandshakePacket: can't read connection attribute value")
}
attrs[string(connAttrKey[:])] = string(connAttrVal[:])
@@ -657,7 +661,7 @@ func (c *Conn) parseComStmtExecute(stmts *sync.Map, data []byte) (uint32, byte,
if !ok {
return 0, 0, errors.NewSQLError(mysql.CRMalformedPacket, mysql.SSUnknownSQLState, "reading statement ID failed")
}
- //prepare, ok := stmts[stmtID]
+ // prepare, ok := stmts[stmtID]
prepare, ok := stmts.Load(stmtID)
if !ok {
return 0, 0, errors.NewSQLError(mysql.CRCommandsOutOfSync, mysql.SSUnknownSQLState, "statement ID is not found from record")
@@ -740,27 +744,15 @@ func (c *Conn) parseStmtArgs(data []byte, typ mysql.FieldType, pos int) (interfa
case mysql.FieldTypeTiny:
val, pos, ok := readByte(data, pos)
return int64(int8(val)), pos, ok
- case mysql.FieldTypeUint8:
- val, pos, ok := readByte(data, pos)
- return int64(int8(val)), pos, ok
- case mysql.FieldTypeUint16:
- val, pos, ok := readUint16(data, pos)
- return int64(int16(val)), pos, ok
case mysql.FieldTypeShort, mysql.FieldTypeYear:
val, pos, ok := readUint16(data, pos)
return int64(int16(val)), pos, ok
- case mysql.FieldTypeUint24, mysql.FieldTypeUint32:
- val, pos, ok := readUint32(data, pos)
- return int64(val), pos, ok
case mysql.FieldTypeInt24, mysql.FieldTypeLong:
val, pos, ok := readUint32(data, pos)
return int64(int32(val)), pos, ok
case mysql.FieldTypeFloat:
val, pos, ok := readUint32(data, pos)
- return math.Float32frombits(uint32(val)), pos, ok
- case mysql.FieldTypeUint64:
- val, pos, ok := readUint64(data, pos)
- return val, pos, ok
+ return math.Float32frombits(val), pos, ok
case mysql.FieldTypeLongLong:
val, pos, ok := readUint64(data, pos)
return int64(val), pos, ok
@@ -962,6 +954,57 @@ func (c *Conn) parseStmtArgs(data []byte, typ mysql.FieldType, pos int) (interfa
}
}
+func (c *Conn) DefColumnDefinition(field *Field) []byte {
+ length := 4 +
+ lenEncStringSize("def") +
+ lenEncStringSize(field.database) +
+ lenEncStringSize(field.table) +
+ lenEncStringSize(field.orgTable) +
+ lenEncStringSize(field.name) +
+ lenEncStringSize(field.orgName) +
+ 1 + // length of fixed length fields
+ 2 + // character set
+ 4 + // column length
+ 1 + // type
+ 2 + // flags
+ 1 + // decimals
+ 2 + // filler
+ lenEncStringSize(string(field.defaultValue)) // default value
+
+ // Get the type and the flags back. If the Field contains
+ // non-zero flags, we use them. Otherwise, use the flags we
+ // derive from the type.
+ typ, flags := mysql.TypeToMySQL(field.fieldType)
+ if field.flags != 0 {
+ flags = int64(field.flags)
+ }
+
+ data := make([]byte, length)
+ writeLenEncInt(data, 0, uint64(length-4))
+ writeLenEncInt(data, 3, uint64(c.sequence))
+ c.sequence++
+ pos := 4
+
+ pos = writeLenEncString(data, pos, "def") // Always same.
+ pos = writeLenEncString(data, pos, field.database)
+ pos = writeLenEncString(data, pos, field.table)
+ pos = writeLenEncString(data, pos, field.orgTable)
+ pos = writeLenEncString(data, pos, field.name)
+ pos = writeLenEncString(data, pos, field.orgName)
+ pos = writeByte(data, pos, 0x0c)
+ pos = writeUint16(data, pos, field.charSet)
+ pos = writeUint32(data, pos, field.columnLength)
+ pos = writeByte(data, pos, byte(typ))
+ pos = writeUint16(data, pos, uint16(flags))
+ pos = writeByte(data, pos, byte(field.decimals))
+ pos = writeUint16(data, pos, uint16(0x0000))
+ if len(field.defaultValue) > 0 {
+ writeLenEncString(data, pos, string(field.defaultValue))
+ }
+
+ return data
+}
+
func (c *Conn) writeColumnDefinition(field *Field) error {
length := 4 + // lenEncStringSize("def")
lenEncStringSize(field.database) +
@@ -1011,11 +1054,7 @@ func (c *Conn) writeColumnDefinition(field *Field) error {
// writeFields writes the fields of a Result. It should be called only
// if there are valid Columns in the result.
-func (c *Conn) writeFields(capabilities uint32, result proto.Result) error {
- var (
- fields = result.GetFields()
- )
-
+func (c *Conn) writeFields(capabilities uint32, fields []proto.Field) error {
// Send the number of fields first.
if err := c.sendColumnCount(uint64(len(fields))); err != nil {
return err
@@ -1039,86 +1078,35 @@ func (c *Conn) writeFields(capabilities uint32, result proto.Result) error {
return nil
}
-func (c *Conn) writeRow(row []*proto.Value) error {
- length := 0
- for _, val := range row {
- if val == nil || val.Val == nil {
- length++
- } else {
- l := len(val.Raw)
- length += lenEncIntSize(uint64(l)) + l
- }
- }
-
- data := c.startEphemeralPacket(length)
- pos := 0
- for _, val := range row {
- if val == nil || val.Val == nil {
- pos = writeByte(data, pos, mysql.NullValue)
- } else {
- l := len(val.Raw)
- pos = writeLenEncInt(data, pos, uint64(l))
- pos += copy(data[pos:], val.Raw)
- }
- }
-
- if pos != length {
- return err2.Errorf("packet row: got %v bytes but expected %v", pos, length)
+func (c *Conn) writeRow(row proto.Row) error {
+ var bf bytes.Buffer
+ n, err := row.WriteTo(&bf)
+ if err != nil {
+ return perrors.WithStack(err)
}
+ data := c.startEphemeralPacket(int(n))
+ copy(data, bf.Bytes())
return c.writeEphemeralPacket()
}
-// writeRowIter sends the rows of a Result.
-func (c *Conn) writeRowIter(result proto.Result) error {
- row := result.GetRows()[0]
- rowIter := row.(*TextIterRow)
+func (c *Conn) writeDataset(ds proto.Dataset) error {
var (
- has bool
- err error
- values []*proto.Value
+ row proto.Row
+ err error
)
- for has, err = rowIter.Next(); has && err == nil; has, err = rowIter.Next() {
- if values, err = rowIter.Decode(); err != nil {
- return err
- }
- if err = c.writeRow(values); err != nil {
- return err
- }
- }
- return err
-}
-
-// writeRowChan sends the rows of a Result.
-func (c *Conn) writeRowChan(result proto.Result) error {
- res := result.(*Result)
- for row := range res.GetDataChan() {
- textRow := row.(*TextIterRow)
- values, err := textRow.Decode()
- if err != nil {
- return err
- }
- if err = c.writeRow(values); err != nil {
- return err
+ for {
+ row, err = ds.Next()
+ if perrors.Is(err, io.EOF) {
+ return nil
}
- }
- return nil
-}
-
-// writeRows sends the rows of a Result.
-func (c *Conn) writeRows(result proto.Result) error {
- for _, row := range result.GetRows() {
- r := row.(*Row)
- textRow := TextRow{*r}
- values, err := textRow.Decode()
if err != nil {
return err
}
- if err := c.writeRow(values); err != nil {
+ if err = c.writeRow(row); err != nil {
return err
}
}
- return nil
}
// writeEndResult concludes the sending of a Result.
@@ -1144,7 +1132,7 @@ func (c *Conn) writeEndResult(capabilities uint32, more bool, affectedRows, last
return nil
}
-// writePrepare writes a prepare query response to the wire.
+// writePrepare writes a prepared query response to the wire.
func (c *Conn) writePrepare(capabilities uint32, prepare *proto.Stmt) error {
paramsCount := prepare.ParamsCount
@@ -1186,473 +1174,22 @@ func (c *Conn) writePrepare(capabilities uint32, prepare *proto.Stmt) error {
return nil
}
-func (c *Conn) writeBinaryRow(fields []proto.Field, row []*proto.Value) error {
- length := 0
- nullBitMapLen := (len(fields) + 7 + 2) / 8
- for _, val := range row {
- if val != nil && val.Val != nil {
- l, err := val2MySQLLen(val)
- if err != nil {
- return fmt.Errorf("internal value %v get MySQL value length error: %v", val, err)
- }
- length += l
- }
- }
-
- length += nullBitMapLen + 1
-
- Data := c.startEphemeralPacket(length)
- pos := 0
-
- pos = writeByte(Data, pos, 0x00)
-
- for i := 0; i < nullBitMapLen; i++ {
- pos = writeByte(Data, pos, 0x00)
- }
-
- for i, val := range row {
- if val == nil || val.Val == nil {
- bytePos := (i+2)/8 + 1
- bitPos := (i + 2) % 8
- Data[bytePos] |= 1 << uint(bitPos)
- } else {
- v, err := val2MySQL(val)
- if err != nil {
- c.recycleWritePacket()
- return fmt.Errorf("internal value %v to MySQL value error: %v", val, err)
- }
- pos += copy(Data[pos:], v)
- }
- }
-
- if pos != length {
- return fmt.Errorf("internal error packet row: got %v bytes but expected %v", pos, length)
- }
-
- return c.writeEphemeralPacket()
-}
-
-func (c *Conn) writeBinaryRowIter(result proto.Result) error {
- row := result.GetRows()[0]
- rowIter := row.(*BinaryIterRow)
+func (c *Conn) writeDatasetBinary(result proto.Dataset) error {
var (
- has bool
- err error
- values []*proto.Value
+ row proto.Row
+ err error
)
- for has, err = rowIter.Next(); has && err == nil; has, err = rowIter.Next() {
- if values, err = rowIter.Decode(); err != nil {
- return err
- }
- if err = c.writeBinaryRow(rowIter.Fields(), values); err != nil {
- return err
- }
- }
- return err
-}
-// writeRowChan sends the rows of a Result.
-func (c *Conn) writeBinaryRowChan(result proto.Result) error {
- res := result.(*Result)
- for row := range res.GetDataChan() {
- r := row.(*BinaryIterRow)
- if err := c.writePacket(r.Data()); err != nil {
- return err
+ for {
+ row, err = result.Next()
+ if perrors.Is(err, io.EOF) {
+ return nil
}
- }
- return nil
-}
-
-// writeTextToBinaryRows sends the rows of a Result with binary form.
-func (c *Conn) writeTextToBinaryRows(result proto.Result) error {
- for _, row := range result.GetRows() {
- r := row.(*Row)
- textRow := TextRow{*r}
- values, err := textRow.Decode()
if err != nil {
return err
}
- if err := c.writeBinaryRow(result.GetFields(), values); err != nil {
+ if err = c.writeRow(row); err != nil {
return err
}
}
- return nil
-}
-
-func val2MySQL(v *proto.Value) ([]byte, error) {
- var out []byte
- pos := 0
- if v == nil {
- return out, nil
- }
-
- switch v.Typ {
- case mysql.FieldTypeNULL:
- // no-op
- case mysql.FieldTypeTiny:
- val, err := strconv.ParseInt(fmt.Sprint(v.Val), 10, 8)
- if err != nil {
- return []byte{}, err
- }
- out = make([]byte, 1)
- writeByte(out, pos, uint8(val))
- case mysql.FieldTypeUint8:
- val, err := strconv.ParseUint(fmt.Sprint(v.Val), 10, 8)
- if err != nil {
- return []byte{}, err
- }
- out = make([]byte, 1)
- writeByte(out, pos, uint8(val))
- case mysql.FieldTypeUint16:
- val, err := strconv.ParseUint(fmt.Sprint(v.Val), 10, 16)
- if err != nil {
- return []byte{}, err
- }
- out = make([]byte, 2)
- writeUint16(out, pos, uint16(val))
- case mysql.FieldTypeShort, mysql.FieldTypeYear:
- val, err := strconv.ParseInt(fmt.Sprint(v.Val), 10, 16)
- if err != nil {
- return []byte{}, err
- }
- out = make([]byte, 2)
- writeUint16(out, pos, uint16(val))
- case mysql.FieldTypeUint24, mysql.FieldTypeUint32:
- val, err := strconv.ParseUint(fmt.Sprint(v.Val), 10, 32)
- if err != nil {
- return []byte{}, err
- }
- out = make([]byte, 4)
- writeUint32(out, pos, uint32(val))
- case mysql.FieldTypeInt24, mysql.FieldTypeLong:
- val, err := strconv.ParseInt(fmt.Sprint(v.Val), 10, 32)
- if err != nil {
- return []byte{}, err
- }
- out = make([]byte, 4)
- writeUint32(out, pos, uint32(val))
- case mysql.FieldTypeFloat:
- val, err := strconv.ParseFloat(fmt.Sprint(v.Val), 32)
- if err != nil {
- return []byte{}, err
- }
- bits := math.Float32bits(float32(val))
- out = make([]byte, 4)
- writeUint32(out, pos, bits)
- case mysql.FieldTypeUint64:
- val, err := strconv.ParseUint(fmt.Sprint(v.Val), 10, 64)
- if err != nil {
- return []byte{}, err
- }
- out = make([]byte, 8)
- writeUint64(out, pos, uint64(val))
- case mysql.FieldTypeLongLong:
- val, err := strconv.ParseInt(fmt.Sprint(v.Val), 10, 64)
- if err != nil {
- return []byte{}, err
- }
- out = make([]byte, 8)
- writeUint64(out, pos, uint64(val))
- case mysql.FieldTypeDouble:
- val, err := strconv.ParseFloat(fmt.Sprint(v.Val), 64)
- if err != nil {
- return []byte{}, err
- }
- bits := math.Float64bits(val)
- out = make([]byte, 8)
- writeUint64(out, pos, bits)
- case mysql.FieldTypeTimestamp, mysql.FieldTypeDate, mysql.FieldTypeDateTime:
- if len(v.Raw) > 19 {
- out = make([]byte, 1+11)
- out[pos] = 0x0b
- pos++
- year, err := strconv.ParseUint(string(v.Raw[0:4]), 10, 16)
- if err != nil {
- return []byte{}, err
- }
- month, err := strconv.ParseUint(string(v.Raw[5:7]), 10, 8)
- if err != nil {
- return []byte{}, err
- }
- day, err := strconv.ParseUint(string(v.Raw[8:10]), 10, 8)
- if err != nil {
- return []byte{}, err
- }
- hour, err := strconv.ParseUint(string(v.Raw[11:13]), 10, 8)
- if err != nil {
- return []byte{}, err
- }
- minute, err := strconv.ParseUint(string(v.Raw[14:16]), 10, 8)
- if err != nil {
- return []byte{}, err
- }
- second, err := strconv.ParseUint(string(v.Raw[17:19]), 10, 8)
- if err != nil {
- return []byte{}, err
- }
- val := make([]byte, 6)
- count := copy(val, v.Raw[20:])
- for i := 0; i < (6 - count); i++ {
- val[count+i] = 0x30
- }
- microSecond, err := strconv.ParseUint(string(val), 10, 32)
- if err != nil {
- return []byte{}, err
- }
- pos = writeUint16(out, pos, uint16(year))
- pos = writeByte(out, pos, byte(month))
- pos = writeByte(out, pos, byte(day))
- pos = writeByte(out, pos, byte(hour))
- pos = writeByte(out, pos, byte(minute))
- pos = writeByte(out, pos, byte(second))
- writeUint32(out, pos, uint32(microSecond))
- } else if len(v.Raw) > 10 {
- out = make([]byte, 1+7)
- out[pos] = 0x07
- pos++
- year, err := strconv.ParseUint(string(v.Raw[0:4]), 10, 16)
- if err != nil {
- return []byte{}, err
- }
- month, err := strconv.ParseUint(string(v.Raw[5:7]), 10, 8)
- if err != nil {
- return []byte{}, err
- }
- day, err := strconv.ParseUint(string(v.Raw[8:10]), 10, 8)
- if err != nil {
- return []byte{}, err
- }
- hour, err := strconv.ParseUint(string(v.Raw[11:13]), 10, 8)
- if err != nil {
- return []byte{}, err
- }
- minute, err := strconv.ParseUint(string(v.Raw[14:16]), 10, 8)
- if err != nil {
- return []byte{}, err
- }
- second, err := strconv.ParseUint(string(v.Raw[17:]), 10, 8)
- if err != nil {
- return []byte{}, err
- }
- pos = writeUint16(out, pos, uint16(year))
- pos = writeByte(out, pos, byte(month))
- pos = writeByte(out, pos, byte(day))
- pos = writeByte(out, pos, byte(hour))
- pos = writeByte(out, pos, byte(minute))
- writeByte(out, pos, byte(second))
- } else if len(v.Raw) > 0 {
- out = make([]byte, 1+4)
- out[pos] = 0x04
- pos++
- year, err := strconv.ParseUint(string(v.Raw[0:4]), 10, 16)
- if err != nil {
- return []byte{}, err
- }
- month, err := strconv.ParseUint(string(v.Raw[5:7]), 10, 8)
- if err != nil {
- return []byte{}, err
- }
- day, err := strconv.ParseUint(string(v.Raw[8:]), 10, 8)
- if err != nil {
- return []byte{}, err
- }
- pos = writeUint16(out, pos, uint16(year))
- pos = writeByte(out, pos, byte(month))
- writeByte(out, pos, byte(day))
- } else {
- out = make([]byte, 1)
- out[pos] = 0x00
- }
- case mysql.FieldTypeTime:
- if string(v.Raw) == "00:00:00" {
- out = make([]byte, 1)
- out[pos] = 0x00
- } else if strings.Contains(string(v.Raw), ".") {
- out = make([]byte, 1+12)
- out[pos] = 0x0c
- pos++
-
- sub1 := strings.Split(string(v.Raw), ":")
- if len(sub1) != 3 {
- err := fmt.Errorf("incorrect time value, ':' is not found")
- return []byte{}, err
- }
- sub2 := strings.Split(sub1[2], ".")
- if len(sub2) != 2 {
- err := fmt.Errorf("incorrect time value, '.' is not found")
- return []byte{}, err
- }
-
- var total []byte
- if strings.HasPrefix(sub1[0], "-") {
- out[pos] = 0x01
- total = []byte(sub1[0])
- total = total[1:]
- } else {
- out[pos] = 0x00
- total = []byte(sub1[0])
- }
- pos++
-
- h, err := strconv.ParseUint(string(total), 10, 32)
- if err != nil {
- return []byte{}, err
- }
-
- days := uint32(h) / 24
- hours := uint32(h) % 24
- minute := sub1[1]
- second := sub2[0]
- microSecond := sub2[1]
-
- minutes, err := strconv.ParseUint(minute, 10, 8)
- if err != nil {
- return []byte{}, err
- }
-
- seconds, err := strconv.ParseUint(second, 10, 8)
- if err != nil {
- return []byte{}, err
- }
- pos = writeUint32(out, pos, uint32(days))
- pos = writeByte(out, pos, byte(hours))
- pos = writeByte(out, pos, byte(minutes))
- pos = writeByte(out, pos, byte(seconds))
-
- val := make([]byte, 6)
- count := copy(val, microSecond)
- for i := 0; i < (6 - count); i++ {
- val[count+i] = 0x30
- }
- microSeconds, err := strconv.ParseUint(string(val), 10, 32)
- if err != nil {
- return []byte{}, err
- }
- writeUint32(out, pos, uint32(microSeconds))
- } else if len(v.Raw) > 0 {
- out = make([]byte, 1+8)
- out[pos] = 0x08
- pos++
-
- sub1 := strings.Split(string(v.Raw), ":")
- if len(sub1) != 3 {
- err := fmt.Errorf("incorrect time value, ':' is not found")
- return []byte{}, err
- }
-
- var total []byte
- if strings.HasPrefix(sub1[0], "-") {
- out[pos] = 0x01
- total = []byte(sub1[0])
- total = total[1:]
- } else {
- out[pos] = 0x00
- total = []byte(sub1[0])
- }
- pos++
-
- h, err := strconv.ParseUint(string(total), 10, 32)
- if err != nil {
- return []byte{}, err
- }
-
- days := uint32(h) / 24
- hours := uint32(h) % 24
- minute := sub1[1]
- second := sub1[2]
-
- minutes, err := strconv.ParseUint(minute, 10, 8)
- if err != nil {
- return []byte{}, err
- }
-
- seconds, err := strconv.ParseUint(second, 10, 8)
- if err != nil {
- return []byte{}, err
- }
- pos = writeUint32(out, pos, uint32(days))
- pos = writeByte(out, pos, byte(hours))
- pos = writeByte(out, pos, byte(minutes))
- writeByte(out, pos, byte(seconds))
- } else {
- err := fmt.Errorf("incorrect time value")
- return []byte{}, err
- }
- case mysql.FieldTypeDecimal, mysql.FieldTypeNewDecimal, mysql.FieldTypeVarChar, mysql.FieldTypeTinyBLOB,
- mysql.FieldTypeMediumBLOB, mysql.FieldTypeLongBLOB, mysql.FieldTypeBLOB, mysql.FieldTypeVarString,
- mysql.FieldTypeString, mysql.FieldTypeGeometry, mysql.FieldTypeJSON, mysql.FieldTypeBit,
- mysql.FieldTypeEnum, mysql.FieldTypeSet:
- l := len(v.Raw)
- length := lenEncIntSize(uint64(l)) + l
- out = make([]byte, length)
- pos = writeLenEncInt(out, pos, uint64(l))
- copy(out[pos:], v.Raw)
- default:
- out = make([]byte, len(v.Raw))
- copy(out, v.Raw)
- }
- return out, nil
-}
-
-func val2MySQLLen(v *proto.Value) (int, error) {
- var length int
- var err error
- if v == nil {
- return 0, nil
- }
-
- switch v.Typ {
- case mysql.FieldTypeNULL:
- length = 0
- case mysql.FieldTypeTiny, mysql.FieldTypeUint8:
- length = 1
- case mysql.FieldTypeUint16, mysql.FieldTypeShort, mysql.FieldTypeYear:
- length = 2
- case mysql.FieldTypeUint24, mysql.FieldTypeUint32, mysql.FieldTypeInt24, mysql.FieldTypeLong, mysql.FieldTypeFloat:
- length = 4
- case mysql.FieldTypeUint64, mysql.FieldTypeLongLong, mysql.FieldTypeDouble:
- length = 8
- case mysql.FieldTypeTimestamp, mysql.FieldTypeDate, mysql.FieldTypeDateTime:
- if len(v.Raw) > 19 {
- length = 12
- } else if len(v.Raw) > 10 {
- length = 8
- } else if len(v.Raw) > 0 {
- length = 5
- } else {
- length = 1
- }
- case mysql.FieldTypeTime:
- if string(v.Raw) == "00:00:00" {
- length = 1
- } else if strings.Contains(string(v.Raw), ".") {
- length = 13
- } else if len(v.Raw) > 0 {
- length = 9
- } else {
- err = fmt.Errorf("incorrect time value")
- }
- case mysql.FieldTypeDecimal, mysql.FieldTypeNewDecimal, mysql.FieldTypeVarChar, mysql.FieldTypeTinyBLOB,
- mysql.FieldTypeMediumBLOB, mysql.FieldTypeLongBLOB, mysql.FieldTypeBLOB, mysql.FieldTypeVarString,
- mysql.FieldTypeString, mysql.FieldTypeGeometry, mysql.FieldTypeJSON, mysql.FieldTypeBit,
- mysql.FieldTypeEnum, mysql.FieldTypeSet:
- l := len(v.Raw)
- length = lenEncIntSize(uint64(l)) + l
- default:
- length = len(v.Raw)
- }
- if err != nil {
- return 0, err
- }
- return length, nil
-}
-
-func (c *Conn) writeBinaryRows(result proto.Result) error {
- for _, row := range result.GetRows() {
- r := row.(*Row)
- if err := c.writePacket(r.Data()); err != nil {
- return err
- }
- }
- return nil
}
diff --git a/pkg/mysql/statement.go b/pkg/mysql/statement.go
index fe2a3ef5a..398523c35 100644
--- a/pkg/mysql/statement.go
+++ b/pkg/mysql/statement.go
@@ -323,12 +323,12 @@ func (stmt *BackendStatement) writeExecutePacket(args []interface{}) error {
paramTypes[i+i+1] = 0x00
var a [64]byte
- var b = a[:0]
+ b := a[:0]
if v.IsZero() {
b = append(b, "0000-00-00"...)
} else {
- b, err = appendDateTime(b, v.In(bc.conf.Loc))
+ b, err = AppendDateTime(b, v.In(bc.conf.Loc))
if err != nil {
return err
}
@@ -357,65 +357,40 @@ func (stmt *BackendStatement) writeExecutePacket(args []interface{}) error {
return bc.c.writePacket(data[4:])
}
-func (stmt *BackendStatement) execArgs(args []interface{}) (*Result, uint16, error) {
+func (stmt *BackendStatement) execArgs(args []interface{}) (proto.Result, error) {
err := stmt.writeExecutePacket(args)
if err != nil {
- return nil, 0, err
+ return nil, err
}
- affectedRows, lastInsertID, colNumber, _, warnings, err := stmt.conn.ReadComQueryResponse()
- if err != nil {
- return nil, 0, err
- }
-
- if colNumber > 0 {
- // columns
- if err = stmt.conn.DrainResults(); err != nil {
- return nil, 0, err
- }
- // rows
- if err = stmt.conn.DrainResults(); err != nil {
- return nil, 0, err
- }
- }
-
- return &Result{
- AffectedRows: affectedRows,
- InsertId: lastInsertID,
- }, warnings, nil
+ //if colNumber > 0 {
+ // // columns
+ // if err = stmt.conn.DrainResults(); err != nil {
+ // return nil, 0, err
+ // }
+ // // rows
+ // if err = stmt.conn.DrainResults(); err != nil {
+ // return nil, 0, err
+ // }
+ //}
+
+ return stmt.conn.ReadQueryRow(), nil
}
-// queryArgsIterRow is iterator for binary protocol result set
-func (stmt *BackendStatement) queryArgsIterRow(args []interface{}) (*Result, uint16, error) {
+func (stmt *BackendStatement) queryArgs(args []interface{}) (proto.Result, error) {
err := stmt.writeExecutePacket(args)
if err != nil {
- return nil, 0, err
+ return nil, err
}
- result, affectedRows, lastInsertID, _, warnings, err := stmt.conn.ReadQueryRow()
-
- iterRow := &BinaryIterRow{result}
+ res := stmt.conn.ReadQueryRow()
+ res.setWantFields(true)
+ res.setBinaryProtocol()
- return &Result{
- AffectedRows: affectedRows,
- InsertId: lastInsertID,
- Fields: iterRow.Fields(),
- Rows: []proto.Row{iterRow},
- DataChan: make(chan proto.Row, 1),
- }, warnings, err
+ return res, nil
}
-func (stmt *BackendStatement) queryArgs(args []interface{}) (*Result, uint16, error) {
- err := stmt.writeExecutePacket(args)
- if err != nil {
- return nil, 0, err
- }
-
- result, _, warnings, err := stmt.conn.ReadQueryResult(true)
- return result, warnings, err
-}
-
-func (stmt *BackendStatement) exec(args []byte) (*Result, uint16, error) {
+func (stmt *BackendStatement) exec(args []byte) (proto.Result, error) {
args[1] = byte(stmt.id)
args[2] = byte(stmt.id >> 8)
args[3] = byte(stmt.id >> 16)
@@ -426,33 +401,29 @@ func (stmt *BackendStatement) exec(args []byte) (*Result, uint16, error) {
err := stmt.conn.c.writePacket(args)
if err != nil {
- return nil, 0, err
+ return nil, err
}
stmt.conn.c.recycleWritePacket()
- affectedRows, lastInsertID, colNumber, _, warnings, err := stmt.conn.ReadComQueryResponse()
- if err != nil {
- return nil, 0, err
- }
-
- if colNumber > 0 {
- // columns
- if err = stmt.conn.DrainResults(); err != nil {
- return nil, 0, err
- }
- // rows
- if err = stmt.conn.DrainResults(); err != nil {
- return nil, 0, err
- }
- }
-
- return &Result{
- AffectedRows: affectedRows,
- InsertId: lastInsertID,
- }, warnings, nil
+ res := stmt.conn.ReadQueryRow()
+
+ //if colNumber > 0 {
+ // // columns
+ // if err = stmt.conn.DrainResults(); err != nil {
+ // return nil, 0, err
+ // }
+ // // rows
+ // if err = stmt.conn.DrainResults(); err != nil {
+ // return nil, 0, err
+ // }
+ //}
+
+ res.setBinaryProtocol()
+ res.setWantFields(true)
+ return res, nil
}
-func (stmt *BackendStatement) query(args []byte) (*Result, uint16, error) {
+func (stmt *BackendStatement) query(args []byte) (proto.Result, error) {
args[1] = byte(stmt.id)
args[2] = byte(stmt.id >> 8)
args[3] = byte(stmt.id >> 16)
@@ -463,9 +434,8 @@ func (stmt *BackendStatement) query(args []byte) (*Result, uint16, error) {
err := stmt.conn.c.writePacket(args)
if err != nil {
- return nil, 0, err
+ return nil, err
}
- result, _, warnings, err := stmt.conn.ReadQueryResult(true)
- return result, warnings, err
+ return stmt.conn.ReadQueryResult(true), nil
}
diff --git a/pkg/mysql/thead/thead.go b/pkg/mysql/thead/thead.go
new file mode 100644
index 000000000..824a2066b
--- /dev/null
+++ b/pkg/mysql/thead/thead.go
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package thead
+
+import (
+ consts "github.com/arana-db/arana/pkg/constants/mysql"
+ "github.com/arana-db/arana/pkg/mysql"
+ "github.com/arana-db/arana/pkg/proto"
+)
+
+var (
+ Topology = Thead{
+ Col{Name: "id", FieldType: consts.FieldTypeLongLong},
+ Col{Name: "group_name", FieldType: consts.FieldTypeVarString},
+ Col{Name: "table_name", FieldType: consts.FieldTypeVarString},
+ }
+ Database = Thead{
+ Col{Name: "Database", FieldType: consts.FieldTypeVarString},
+ }
+)
+
+type Col struct {
+ Name string
+ FieldType consts.FieldType
+}
+
+type Thead []Col
+
+func (t Thead) ToFields() []proto.Field {
+ columns := make([]proto.Field, len(t))
+ for i := 0; i < len(t); i++ {
+ columns[i] = mysql.NewField(t[i].Name, t[i].FieldType)
+ }
+ return columns
+}
diff --git a/pkg/mysql/utils.go b/pkg/mysql/utils.go
index 63c0bb76c..2f6ec435e 100644
--- a/pkg/mysql/utils.go
+++ b/pkg/mysql/utils.go
@@ -295,7 +295,7 @@ func parseBinaryDateTime(num uint64, data []byte, loc *time.Location) (driver.Va
return nil, fmt.Errorf("invalid DATETIME packet length %d", num)
}
-func appendDateTime(buf []byte, t time.Time) ([]byte, error) {
+func AppendDateTime(buf []byte, t time.Time) ([]byte, error) {
year, month, day := t.Date()
hour, min, sec := t.Clock()
nsec := t.Nanosecond()
@@ -335,14 +335,11 @@ func appendDateTime(buf []byte, t time.Time) ([]byte, error) {
localBuf[19] = '.'
// milli second
- localBuf[20], localBuf[21], localBuf[22] =
- digits01[nsec100000000], digits10[nsec1000000], digits01[nsec1000000]
+ localBuf[20], localBuf[21], localBuf[22] = digits01[nsec100000000], digits10[nsec1000000], digits01[nsec1000000]
// micro second
- localBuf[23], localBuf[24], localBuf[25] =
- digits10[nsec10000], digits01[nsec10000], digits10[nsec100]
+ localBuf[23], localBuf[24], localBuf[25] = digits10[nsec10000], digits01[nsec10000], digits10[nsec100]
// nano second
- localBuf[26], localBuf[27], localBuf[28] =
- digits01[nsec100], digits10[nsec1], digits01[nsec1]
+ localBuf[26], localBuf[27], localBuf[28] = digits01[nsec100], digits10[nsec1], digits01[nsec1]
// trim trailing zeros
n := len(localBuf)
@@ -600,6 +597,11 @@ func skipLengthEncodedString(b []byte) (int, error) {
return n, io.EOF
}
+func readComFieldListDefaultValueLength(data []byte, pos int) (uint64, int) {
+ l, _, n := readLengthEncodedInteger(data[pos:])
+ return l, pos + n
+}
+
// returns the number read, whether the value is NULL and the number of bytes read
func readLengthEncodedInteger(b []byte) (uint64, bool, int) {
// See issue #349
@@ -632,7 +634,7 @@ func readLengthEncodedInteger(b []byte) (uint64, bool, int) {
return uint64(b[0]), false, 1
}
-// encodes a uint64 value and appends it to the given bytes slice
+// appendLengthEncodedInteger encodes an uint64 value and appends it to the given bytes slice
func appendLengthEncodedInteger(b []byte, n uint64) []byte {
switch {
case n <= 250:
@@ -1079,7 +1081,7 @@ func convertAssignRows(dest, src interface{}) error {
i64, err := strconv.ParseInt(s, 10, dv.Type().Bits())
if err != nil {
err = strconvErr(err)
- return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
+ return fmt.Errorf("converting driver.V type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
}
dv.SetInt(i64)
return nil
@@ -1091,7 +1093,7 @@ func convertAssignRows(dest, src interface{}) error {
u64, err := strconv.ParseUint(s, 10, dv.Type().Bits())
if err != nil {
err = strconvErr(err)
- return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
+ return fmt.Errorf("converting driver.V type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
}
dv.SetUint(u64)
return nil
@@ -1103,7 +1105,7 @@ func convertAssignRows(dest, src interface{}) error {
f64, err := strconv.ParseFloat(s, dv.Type().Bits())
if err != nil {
err = strconvErr(err)
- return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
+ return fmt.Errorf("converting driver.V type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
}
dv.SetFloat(f64)
return nil
@@ -1121,7 +1123,7 @@ func convertAssignRows(dest, src interface{}) error {
}
}
- return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, dest)
+ return fmt.Errorf("unsupported Scan, storing driver.V type %T into type %T", src, dest)
}
func cloneBytes(b []byte) []byte {
@@ -1361,8 +1363,10 @@ func PutLengthEncodedInt(n uint64) []byte {
return []byte{0xfd, byte(n), byte(n >> 8), byte(n >> 16)}
case n <= 0xffffffffffffffff:
- return []byte{0xfe, byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24),
- byte(n >> 32), byte(n >> 40), byte(n >> 48), byte(n >> 56)}
+ return []byte{
+ 0xfe, byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24),
+ byte(n >> 32), byte(n >> 40), byte(n >> 48), byte(n >> 56),
+ }
}
return nil
}
diff --git a/pkg/proto/data.go b/pkg/proto/data.go
index 074ec4e8d..f54d3a6da 100644
--- a/pkg/proto/data.go
+++ b/pkg/proto/data.go
@@ -15,82 +15,87 @@
* limitations under the License.
*/
+//go:generate mockgen -destination=../../testdata/mock_data.go -package=testdata . Field,Row,KeyedRow,Dataset,Result
package proto
import (
- "sync"
+ "io"
+ "reflect"
)
-import (
- "github.com/arana-db/parser/ast"
-)
-
-import (
- "github.com/arana-db/arana/pkg/constants/mysql"
-)
-
-// CloseableResult is a temporary solution for chan-based result.
-// Deprecated: TODO, should be removed in the future
-type CloseableResult struct {
- once sync.Once
- Result
- Closer func() error
-}
-
-func (c *CloseableResult) Close() error {
- var err error
- c.once.Do(func() {
- if c.Closer != nil {
- err = c.Closer()
- }
- })
- return err
-}
-
type (
- Value struct {
- Typ mysql.FieldType
- Flags uint
- Len int
- Val interface{}
- Raw []byte
+ // Field contains the name and type of column, it follows sql.ColumnType.
+ Field interface {
+ // Name returns the name or alias of the column.
+ Name() string
+
+ // DecimalSize returns the scale and precision of a decimal type.
+ // If not applicable or if not supported ok is false.
+ DecimalSize() (precision, scale int64, ok bool)
+
+ // ScanType returns a Go type suitable for scanning into using Rows.Scan.
+ // If a driver does not support this property ScanType will return
+ // the type of empty interface.
+ ScanType() reflect.Type
+
+ // Length returns the column type length for variable length column types such
+ // as text and binary field types. If the type length is unbounded the value will
+ // be math.MaxInt64 (any database limits will still apply).
+ // If the column type is not variable length, such as an int, or if not supported
+ // by the driver ok is false.
+ Length() (length int64, ok bool)
+
+ // Nullable reports whether the column may be null.
+ // If a driver does not support this property ok will be false.
+ Nullable() (nullable, ok bool)
+
+ // DatabaseTypeName returns the database system name of the column type. If an empty
+ // string is returned, then the driver type name is not supported.
+ // Consult your driver documentation for a list of driver data types. Length specifiers
+ // are not included.
+ // Common type names include "VARCHAR", "TEXT", "NVARCHAR", "DECIMAL", "BOOL",
+ // "INT", and "BIGINT".
+ DatabaseTypeName() string
}
- Field interface {
- TableName() string
+ // Value represents the cell value of Row.
+ Value interface{}
- DataBaseName() string
+ // Row represents a row data from a result set.
+ Row interface {
+ io.WriterTo
+ IsBinary() bool
- TypeDatabaseName() string
- }
+ // Length returns the length of Row.
+ Length() int
- // Row is an iterator over an executed query's results.
- Row interface {
- // Columns returns the names of the columns. The number of
- // columns of the result is inferred from the length of the
- // slice. If a particular column name isn't known, an empty
- // string should be returned for that entry.
- Columns() []string
+ // Scan scans the Row to values.
+ Scan(dest []Value) error
+ }
+ // KeyedRow represents row with fields.
+ KeyedRow interface {
+ Row
+ // Fields returns the fields of row.
Fields() []Field
+ // Get returns the value of column name.
+ Get(name string) (Value, error)
+ }
- // Data returns the result in bytes.
- Data() []byte
-
- Decode() ([]*Value, error)
+ Dataset interface {
+ io.Closer
- GetColumnValue(column string) (interface{}, error)
+ // Fields returns the fields of Dataset.
+ Fields() ([]Field, error)
- Encode(values []*Value, columns []Field, columnNames []string) Row
+ // Next returns the next row.
+ Next() (Row, error)
}
// Result is the result of a query execution.
Result interface {
- // GetFields returns the fields.
- GetFields() []Field
-
- // GetRows returns the rows.
- GetRows() []Row
+ // Dataset returns the Dataset.
+ Dataset() (Dataset, error)
// LastInsertId returns the database's auto-generated ID
// after, for example, an INSERT into a table with primary
@@ -101,15 +106,4 @@ type (
// query.
RowsAffected() (uint64, error)
}
-
- // Stmt is a buffer used for store prepare statement meta data
- Stmt struct {
- StatementID uint32
- PrepareStmt string
- ParamsCount uint16
- ParamsType []int32
- ColumnNames []string
- BindVars map[string]interface{}
- StmtNode ast.StmtNode
- }
)
diff --git a/pkg/proto/hint/hint.go b/pkg/proto/hint/hint.go
new file mode 100644
index 000000000..5a1adb877
--- /dev/null
+++ b/pkg/proto/hint/hint.go
@@ -0,0 +1,190 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package hint
+
+import (
+ "bufio"
+ "bytes"
+ "strings"
+)
+
+import (
+ "github.com/pkg/errors"
+)
+
+import (
+ "github.com/arana-db/arana/pkg/runtime/misc"
+)
+
+const (
+ _ Type = iota
+ TypeMaster // force route to master node
+ TypeSlave // force route to slave node
+ TypeRoute // custom route
+ TypeFullScan // enable full-scan
+ TypeDirect // direct route
+)
+
+var _hintTypes = [...]string{
+ TypeMaster: "MASTER",
+ TypeSlave: "SLAVE",
+ TypeRoute: "ROUTE",
+ TypeFullScan: "FULLSCAN",
+ TypeDirect: "DIRECT",
+}
+
+// KeyValue represents a pair of key and value.
+type KeyValue struct {
+ K string // key (optional)
+ V string // value
+}
+
+// Type represents the type of Hint.
+type Type uint8
+
+// String returns the display string.
+func (tp Type) String() string {
+ return _hintTypes[tp]
+}
+
+// Hint represents a Hint, a valid Hint should include type and input kv pairs.
+//
+// Follow the format below:
+// - without inputs: YOUR_HINT()
+// - with non-keyed inputs: YOUR_HINT(foo,bar,quz)
+// - with keyed inputs: YOUR_HINT(x=foo,y=bar,z=quz)
+//
+type Hint struct {
+ Type Type
+ Inputs []KeyValue
+}
+
+// String returns the display string.
+func (h Hint) String() string {
+ var sb strings.Builder
+ sb.WriteString(h.Type.String())
+
+ if len(h.Inputs) < 1 {
+ sb.WriteString("()")
+ return sb.String()
+ }
+
+ sb.WriteByte('(')
+
+ writeKv := func(p KeyValue) {
+ if key := p.K; len(key) > 0 {
+ sb.WriteString(key)
+ sb.WriteByte('=')
+ }
+ sb.WriteString(p.V)
+ }
+
+ writeKv(h.Inputs[0])
+ for i := 1; i < len(h.Inputs); i++ {
+ sb.WriteByte(',')
+ writeKv(h.Inputs[i])
+ }
+
+ sb.WriteByte(')')
+ return sb.String()
+}
+
+// Parse parses Hint from an input string.
+func Parse(s string) (*Hint, error) {
+ var (
+ tpStr string
+ tp Type
+ )
+
+ offset := strings.Index(s, "(")
+ if offset == -1 {
+ tpStr = s
+ } else {
+ tpStr = s[:offset]
+ }
+
+ for i, v := range _hintTypes {
+ if strings.EqualFold(tpStr, v) {
+ tp = Type(i)
+ break
+ }
+ }
+
+ if tp == 0 {
+ return nil, errors.Errorf("hint: invalid input '%s'", s)
+ }
+
+ if offset == -1 {
+ return &Hint{Type: tp}, nil
+ }
+
+ end := strings.LastIndex(s, ")")
+ if end == -1 {
+ return nil, errors.Errorf("hint: invalid input '%s'", s)
+ }
+
+ s = s[offset+1 : end]
+
+ scanner := bufio.NewScanner(strings.NewReader(s))
+ scanner.Split(scanComma)
+
+ var kvs []KeyValue
+
+ for scanner.Scan() {
+ text := scanner.Text()
+
+ // split kv by '='
+ i := strings.Index(text, "=")
+ if i == -1 {
+ // omit blank text
+ if misc.IsBlank(text) {
+ continue
+ }
+ kvs = append(kvs, KeyValue{V: strings.TrimSpace(text)})
+ } else {
+ var (
+ k = strings.TrimSpace(text[:i])
+ v = strings.TrimSpace(text[i+1:])
+ )
+ // omit blank key/value
+ if misc.IsBlank(k) || misc.IsBlank(v) {
+ continue
+ }
+ kvs = append(kvs, KeyValue{K: k, V: v})
+ }
+ }
+
+ if err := scanner.Err(); err != nil {
+ return nil, errors.Wrapf(err, "hint: invalid input '%s'", s)
+ }
+
+ return &Hint{Type: tp, Inputs: kvs}, nil
+}
+
+func scanComma(data []byte, atEOF bool) (advance int, token []byte, err error) {
+ if atEOF && len(data) == 0 {
+ return 0, nil, nil
+ }
+ if i := bytes.IndexByte(data, ','); i >= 0 {
+ return i + 1, data[0:i], nil
+ }
+ if atEOF {
+ return len(data), data, nil
+ }
+ return 0, nil, nil
+}
diff --git a/pkg/proto/hint/hint_test.go b/pkg/proto/hint/hint_test.go
new file mode 100644
index 000000000..aa812cc68
--- /dev/null
+++ b/pkg/proto/hint/hint_test.go
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package hint
+
+import (
+ "testing"
+)
+
+import (
+ "github.com/stretchr/testify/assert"
+)
+
+func TestParse(t *testing.T) {
+ type tt struct {
+ input string
+ output string
+ pass bool
+ }
+
+ for _, next := range []tt{
+ {"route( foo , bar , qux )", "ROUTE(foo,bar,qux)", true},
+ {"master", "MASTER()", true},
+ {"slave", "SLAVE()", true},
+ {"not_exist_hint(1,2,3)", "", false},
+ {"route(,,,)", "ROUTE()", true},
+ {"fullscan()", "FULLSCAN()", true},
+ {"route(foo=111,bar=222,qux=333,)", "ROUTE(foo=111,bar=222,qux=333)", true},
+ } {
+ t.Run(next.input, func(t *testing.T) {
+ res, err := Parse(next.input)
+ if next.pass {
+ assert.NoError(t, err)
+ assert.Equal(t, next.output, res.String())
+ } else {
+ assert.Error(t, err)
+ }
+ })
+ }
+}
diff --git a/pkg/proto/interface.go b/pkg/proto/interface.go
index 3c9d70673..677abc05a 100644
--- a/pkg/proto/interface.go
+++ b/pkg/proto/interface.go
@@ -20,6 +20,7 @@ package proto
import (
"context"
"encoding/json"
+ "sort"
)
import (
@@ -45,9 +46,7 @@ type (
Listener interface {
SetExecutor(executor Executor)
-
Listen()
-
Close()
}
@@ -74,35 +73,22 @@ type (
// Executor
Executor interface {
AddPreFilter(filter PreFilter)
-
AddPostFilter(filter PostFilter)
-
GetPreFilters() []PreFilter
-
GetPostFilters() []PostFilter
-
ProcessDistributedTransaction() bool
-
InLocalTransaction(ctx *Context) bool
-
InGlobalTransaction(ctx *Context) bool
-
ExecuteUseDB(ctx *Context) error
-
ExecuteFieldList(ctx *Context) ([]Field, error)
-
ExecutorComQuery(ctx *Context) (Result, uint16, error)
-
ExecutorComStmtExecute(ctx *Context) (Result, uint16, error)
-
ConnectionClose(ctx *Context)
}
ResourceManager interface {
GetMasterResourcePool(name string) *pools.ResourcePool
-
GetSlaveResourcePool(name string) *pools.ResourcePool
-
GetMetaResourcePool(name string) *pools.ResourcePool
}
)
@@ -118,3 +104,23 @@ func (c Context) GetQuery() string {
}
return bytesconv.BytesToString(c.Data[1:])
}
+
+func (c Context) GetArgs() []interface{} {
+ if c.Stmt == nil || len(c.Stmt.BindVars) < 1 {
+ return nil
+ }
+
+ var (
+ keys = make([]string, 0, len(c.Stmt.BindVars))
+ args = make([]interface{}, 0, len(c.Stmt.BindVars))
+ )
+
+ for k := range c.Stmt.BindVars {
+ keys = append(keys, k)
+ }
+ sort.Strings(keys)
+ for _, k := range keys {
+ args = append(args, c.Stmt.BindVars[k])
+ }
+ return args
+}
diff --git a/pkg/proto/rule/range.go b/pkg/proto/rule/range.go
index 2a1ce8c29..b884918a2 100644
--- a/pkg/proto/rule/range.go
+++ b/pkg/proto/rule/range.go
@@ -42,9 +42,7 @@ const (
Ustr
)
-var (
- DefaultNumberStepper = Stepper{U: Unum, N: 1}
-)
+var DefaultNumberStepper = Stepper{U: Unum, N: 1}
// Range represents a value range.
type Range interface {
diff --git a/pkg/proto/rule/range_test.go b/pkg/proto/rule/range_test.go
index ab57a8b93..0489e12b7 100644
--- a/pkg/proto/rule/range_test.go
+++ b/pkg/proto/rule/range_test.go
@@ -37,7 +37,7 @@ func TestStepper_Date_After(t *testing.T) {
U: Uday,
}
t.Log(daySt.String())
- testTime := time.Date(2021, 1, 17, 17, 45, 04, 0, time.UTC)
+ testTime := time.Date(2021, 1, 17, 17, 45, 0o4, 0, time.UTC)
hour, err := hourSt.After(testTime)
assert.NoError(t, err)
assert.Equal(t, 19, hour.(time.Time).Hour())
diff --git a/pkg/proto/rule/topology.go b/pkg/proto/rule/topology.go
index df15bc5a6..e3f77f436 100644
--- a/pkg/proto/rule/topology.go
+++ b/pkg/proto/rule/topology.go
@@ -18,62 +18,154 @@
package rule
import (
+ "math"
"sort"
+ "sync"
)
// Topology represents the topology of databases and tables.
type Topology struct {
+ mu sync.RWMutex
dbRender, tbRender func(int) string
- idx map[int][]int
+ idx sync.Map // map[int][]int
}
// Len returns the length of database and table.
func (to *Topology) Len() (dbLen int, tblLen int) {
- dbLen = len(to.idx)
- for _, v := range to.idx {
- tblLen += len(v)
- }
+ to.idx.Range(func(_, value any) bool {
+ dbLen++
+ tblLen += len(value.([]int))
+ return true
+ })
return
}
// SetTopology sets the topology.
func (to *Topology) SetTopology(db int, tables ...int) {
- if to.idx == nil {
- to.idx = make(map[int][]int)
- }
-
if len(tables) < 1 {
- delete(to.idx, db)
+ to.idx.Delete(db)
return
}
clone := make([]int, len(tables))
copy(clone, tables)
sort.Ints(clone)
- to.idx[db] = clone
+ to.idx.Store(db, clone)
}
// SetRender sets the database/table name render.
func (to *Topology) SetRender(dbRender, tbRender func(int) string) {
+ to.mu.Lock()
to.dbRender, to.tbRender = dbRender, tbRender
+ to.mu.Unlock()
}
// Render renders the name of database and table from indexes.
func (to *Topology) Render(dbIdx, tblIdx int) (string, string, bool) {
+ to.mu.RLock()
+ defer to.mu.RUnlock()
+
if to.tbRender == nil || to.dbRender == nil {
return "", "", false
}
return to.dbRender(dbIdx), to.tbRender(tblIdx), true
}
+func (to *Topology) EnumerateDatabases() []string {
+ to.mu.RLock()
+ render := to.dbRender
+ to.mu.RUnlock()
+
+ var keys []string
+
+ to.idx.Range(func(key, _ any) bool {
+ keys = append(keys, render(key.(int)))
+ return true
+ })
+
+ sort.Strings(keys)
+
+ return keys
+}
+
+func (to *Topology) Enumerate() DatabaseTables {
+ to.mu.RLock()
+ dbRender, tbRender := to.dbRender, to.tbRender
+ to.mu.RUnlock()
+
+ dt := make(DatabaseTables)
+ to.Each(func(dbIdx, tbIdx int) bool {
+ d := dbRender(dbIdx)
+ t := tbRender(tbIdx)
+ dt[d] = append(dt[d], t)
+ return true
+ })
+
+ return dt
+}
+
// Each enumerates items in current Topology.
func (to *Topology) Each(onEach func(dbIdx, tbIdx int) (ok bool)) bool {
- for d, v := range to.idx {
+ done := true
+ to.idx.Range(func(key, value any) bool {
+ var (
+ d = key.(int)
+ v = value.([]int)
+ )
for _, t := range v {
if !onEach(d, t) {
+ done = false
return false
}
}
+ return true
+ })
+
+ return done
+}
+
+func (to *Topology) Smallest() (db, tb string, ok bool) {
+ to.mu.RLock()
+ dbRender, tbRender := to.dbRender, to.tbRender
+ to.mu.RUnlock()
+
+ smallest := [2]int{math.MaxInt64, math.MaxInt64}
+ to.idx.Range(func(key, value any) bool {
+ if d := key.(int); d < smallest[0] {
+ smallest[0] = d
+ if t := value.([]int); len(t) > 0 {
+ smallest[1] = t[0]
+ }
+ }
+ return true
+ })
+
+ if smallest[0] != math.MaxInt64 || smallest[1] != math.MaxInt64 {
+ db, tb, ok = dbRender(smallest[0]), tbRender(smallest[1]), true
+ }
+
+ return
+}
+
+func (to *Topology) Largest() (db, tb string, ok bool) {
+ to.mu.RLock()
+ dbRender, tbRender := to.dbRender, to.tbRender
+ to.mu.RUnlock()
+
+ largest := [2]int{math.MinInt64, math.MinInt64}
+ to.idx.Range(func(key, value any) bool {
+ if d := key.(int); d > largest[0] {
+ largest[0] = d
+ if t := value.([]int); len(t) > 0 {
+ largest[1] = t[len(t)-1]
+ }
+ }
+ return true
+ })
+
+ if largest[0] != math.MinInt64 || largest[1] != math.MinInt64 {
+ db, tb, ok = dbRender(largest[0]), tbRender(largest[1]), true
}
- return true
+
+ return
}
diff --git a/pkg/proto/rule/topology_test.go b/pkg/proto/rule/topology_test.go
index 637f438f6..f62dcfe14 100644
--- a/pkg/proto/rule/topology_test.go
+++ b/pkg/proto/rule/topology_test.go
@@ -33,24 +33,18 @@ func TestLen(t *testing.T) {
assert.Equal(t, 6, tblLen)
}
-func TestSetTopologyForIdxNil(t *testing.T) {
- topology := &Topology{
- idx: nil,
- }
+func TestSetTopology(t *testing.T) {
+ var topology Topology
+
topology.SetTopology(2, 2, 3, 4)
- for each := range topology.idx {
- assert.Equal(t, 2, each)
- assert.Equal(t, 3, len(topology.idx[each]))
- }
dbLen, tblLen := topology.Len()
assert.Equal(t, 1, dbLen)
assert.Equal(t, 3, tblLen)
}
-func TestSetTopologyForIdxNotNil(t *testing.T) {
- topology := &Topology{
- idx: map[int][]int{0: []int{1, 2, 3}},
- }
+func TestSetTopologyNoConflict(t *testing.T) {
+ var topology Topology
+ topology.SetTopology(0, 1, 2, 3)
topology.SetTopology(1, 4, 5, 6)
dbLen, tblLen := topology.Len()
assert.Equal(t, 2, dbLen)
@@ -58,14 +52,12 @@ func TestSetTopologyForIdxNotNil(t *testing.T) {
}
func TestSetTopologyForTablesLessThanOne(t *testing.T) {
- topology := &Topology{
- idx: map[int][]int{0: []int{1, 2, 3}, 1: []int{4, 5, 6}},
- }
+ var topology Topology
+
+ topology.SetTopology(0, 1, 2, 3)
+ topology.SetTopology(1, 4, 5, 6)
topology.SetTopology(1)
- for each := range topology.idx {
- assert.Equal(t, 0, each)
- assert.Equal(t, 3, len(topology.idx[each]))
- }
+
dbLen, tblLen := topology.Len()
assert.Equal(t, 1, dbLen)
assert.Equal(t, 3, tblLen)
@@ -104,6 +96,36 @@ func TestTopology_Each(t *testing.T) {
t.Logf("on each: %d,%d\n", dbIdx, tbIdx)
return true
})
+
+ assert.False(t, topology.Each(func(dbIdx, tbIdx int) bool {
+ return false
+ }))
+}
+
+func TestTopology_Enumerate(t *testing.T) {
+ topology := createTopology()
+ shards := topology.Enumerate()
+ assert.Greater(t, shards.Len(), 0)
+}
+
+func TestTopology_EnumerateDatabases(t *testing.T) {
+ topology := createTopology()
+ dbs := topology.EnumerateDatabases()
+ assert.Greater(t, len(dbs), 0)
+}
+
+func TestTopology_Largest(t *testing.T) {
+ topology := createTopology()
+ db, tb, ok := topology.Largest()
+ assert.True(t, ok)
+ t.Logf("largest: %s.%s\n", db, tb)
+}
+
+func TestTopology_Smallest(t *testing.T) {
+ topology := createTopology()
+ db, tb, ok := topology.Smallest()
+ assert.True(t, ok)
+ t.Logf("smallest: %s.%s\n", db, tb)
}
func createTopology() *Topology {
@@ -114,7 +136,10 @@ func createTopology() *Topology {
tbRender: func(i int) string {
return fmt.Sprintf("%s:%d", "tbRender", i)
},
- idx: map[int][]int{0: []int{1, 2, 3}, 1: []int{4, 5, 6}},
}
+
+ result.SetTopology(0, 1, 2, 3)
+ result.SetTopology(1, 4, 5, 6)
+
return result
}
diff --git a/pkg/proto/runtime.go b/pkg/proto/runtime.go
index db87fb475..58d387c43 100644
--- a/pkg/proto/runtime.go
+++ b/pkg/proto/runtime.go
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-//go:generate mockgen -destination=../../testdata/mock_runtime.go -package=testdata . VConn,Plan,Optimizer,DB,SchemaLoader
+//go:generate mockgen -destination=../../testdata/mock_runtime.go -package=testdata . VConn,Plan,Optimizer,DB
package proto
import (
@@ -24,10 +24,6 @@ import (
"time"
)
-import (
- "github.com/arana-db/parser/ast"
-)
-
const (
PlanTypeQuery PlanType = iota // QUERY
PlanTypeExec // EXEC
@@ -56,7 +52,7 @@ type (
// Optimizer represents a sql statement optimizer which can be used to create QueryPlan or ExecPlan.
Optimizer interface {
// Optimize optimizes the sql with arguments then returns a Plan.
- Optimize(ctx context.Context, conn VConn, stmt ast.StmtNode, args ...interface{}) (Plan, error)
+ Optimize(ctx context.Context) (Plan, error)
}
// Weight represents the read/write weight info.
@@ -106,6 +102,7 @@ type (
// Tx represents transaction.
Tx interface {
Executable
+ VConn
// ID returns the unique transaction id.
ID() int64
// Commit commits current transaction.
@@ -113,8 +110,4 @@ type (
// Rollback rollbacks current transaction.
Rollback(ctx context.Context) (Result, uint16, error)
}
-
- SchemaLoader interface {
- Load(ctx context.Context, conn VConn, schema string, tables []string) map[string]*TableMetadata
- }
)
diff --git a/pkg/proto/metadata.go b/pkg/proto/schema.go
similarity index 74%
rename from pkg/proto/metadata.go
rename to pkg/proto/schema.go
index 3da651c42..6ab374661 100644
--- a/pkg/proto/metadata.go
+++ b/pkg/proto/schema.go
@@ -15,9 +15,11 @@
* limitations under the License.
*/
+//go:generate mockgen -destination=../../testdata/mock_schema.go -package=testdata . SchemaLoader
package proto
import (
+ "context"
"strings"
)
@@ -66,3 +68,29 @@ type ColumnMetadata struct {
type IndexMetadata struct {
Name string
}
+
+var _defaultSchemaLoader SchemaLoader
+
+func RegisterSchemaLoader(l SchemaLoader) {
+ _defaultSchemaLoader = l
+}
+
+func LoadSchemaLoader() SchemaLoader {
+ cur := _defaultSchemaLoader
+ if cur == nil {
+ return noopSchemaLoader{}
+ }
+ return cur
+}
+
+// SchemaLoader represents a schema discovery.
+type SchemaLoader interface {
+ // Load loads the schema.
+ Load(ctx context.Context, schema string, table []string) (map[string]*TableMetadata, error)
+}
+
+type noopSchemaLoader struct{}
+
+func (n noopSchemaLoader) Load(_ context.Context, _ string, _ []string) (map[string]*TableMetadata, error) {
+ return nil, nil
+}
diff --git a/pkg/proto/schema_manager/loader.go b/pkg/proto/schema_manager/loader.go
deleted file mode 100644
index 6cef03e50..000000000
--- a/pkg/proto/schema_manager/loader.go
+++ /dev/null
@@ -1,238 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package schema_manager
-
-import (
- "context"
- "fmt"
- "io"
- "strings"
-)
-
-import (
- "github.com/arana-db/arana/pkg/mysql"
- "github.com/arana-db/arana/pkg/proto"
- rcontext "github.com/arana-db/arana/pkg/runtime/context"
- "github.com/arana-db/arana/pkg/util/log"
-)
-
-const (
- orderByOrdinalPosition = " ORDER BY ORDINAL_POSITION"
- tableMetadataNoOrder = "SELECT TABLE_NAME, COLUMN_NAME, DATA_TYPE, COLUMN_KEY, EXTRA, COLLATION_NAME, ORDINAL_POSITION FROM information_schema.columns WHERE TABLE_SCHEMA=database()"
- tableMetadataSQL = tableMetadataNoOrder + orderByOrdinalPosition
- tableMetadataSQLInTables = tableMetadataNoOrder + " AND TABLE_NAME IN (%s)" + orderByOrdinalPosition
- indexMetadataSQL = "SELECT TABLE_NAME, INDEX_NAME FROM information_schema.statistics WHERE TABLE_SCHEMA=database() AND TABLE_NAME IN (%s)"
-)
-
-type SimpleSchemaLoader struct{}
-
-func (l *SimpleSchemaLoader) Load(ctx context.Context, conn proto.VConn, schema string, tables []string) map[string]*proto.TableMetadata {
- ctx = rcontext.WithRead(rcontext.WithDirect(ctx))
- var (
- tableMetadataMap = make(map[string]*proto.TableMetadata, len(tables))
- indexMetadataMap map[string][]*proto.IndexMetadata
- columnMetadataMap map[string][]*proto.ColumnMetadata
- )
- columnMetadataMap = l.LoadColumnMetadataMap(ctx, conn, schema, tables)
- if columnMetadataMap != nil {
- indexMetadataMap = l.LoadIndexMetadata(ctx, conn, schema, tables)
- }
-
- for tableName, columns := range columnMetadataMap {
- tableMetadataMap[tableName] = proto.NewTableMetadata(tableName, columns, indexMetadataMap[tableName])
- }
-
- return tableMetadataMap
-}
-
-func (l *SimpleSchemaLoader) LoadColumnMetadataMap(ctx context.Context, conn proto.VConn, schema string, tables []string) map[string][]*proto.ColumnMetadata {
- resultSet, err := conn.Query(ctx, schema, getColumnMetadataSQL(tables))
- if err != nil {
- return nil
- }
- if closer, ok := resultSet.(io.Closer); ok {
- defer func() {
- _ = closer.Close()
- }()
- }
-
- result := make(map[string][]*proto.ColumnMetadata, 0)
- if err != nil {
- log.Errorf("Load ColumnMetadata error when call db: %v", err)
- return nil
- }
- if resultSet == nil {
- log.Error("Load ColumnMetadata error because the result is nil")
- return nil
- }
-
- row := resultSet.GetRows()[0]
- var rowIter mysql.Iter
- switch r := row.(type) {
- case *mysql.BinaryIterRow:
- rowIter = r
- case *mysql.TextIterRow:
- rowIter = r
- }
-
- var (
- has bool
- rowValues []*proto.Value
- )
- for has, err = rowIter.Next(); has && err == nil; has, err = rowIter.Next() {
- if rowValues, err = row.Decode(); err != nil {
- return nil
- }
- tableName := convertInterfaceToStrNullable(rowValues[0].Val)
- columnName := convertInterfaceToStrNullable(rowValues[1].Val)
- dataType := convertInterfaceToStrNullable(rowValues[2].Val)
- columnKey := convertInterfaceToStrNullable(rowValues[3].Val)
- extra := convertInterfaceToStrNullable(rowValues[4].Val)
- collationName := convertInterfaceToStrNullable(rowValues[5].Val)
- ordinalPosition := convertInterfaceToStrNullable(rowValues[6].Val)
- result[tableName] = append(result[tableName], &proto.ColumnMetadata{
- Name: columnName,
- DataType: dataType,
- Ordinal: ordinalPosition,
- PrimaryKey: strings.EqualFold("PRI", columnKey),
- Generated: strings.EqualFold("auto_increment", extra),
- CaseSensitive: columnKey != "" && !strings.HasSuffix(collationName, "_ci"),
- })
- }
-
- //for _, row := range resultSet.GetRows() {
- // var innerRow mysql.Row
- // switch r := row.(type) {
- // case *mysql.BinaryRow:
- // innerRow = r.Row
- // case *mysql.Row:
- // innerRow = *r
- // case *mysql.TextRow:
- // innerRow = r.Row
- // }
- // textRow := mysql.TextRow{Row: innerRow}
- // rowValues, err := textRow.Decode()
- // if err != nil {
- // //logger.Errorf("Load ColumnMetadata error when decode text row: %v", err)
- // return nil
- // }
- // tableName := convertInterfaceToStrNullable(rowValues[0].Val)
- // columnName := convertInterfaceToStrNullable(rowValues[1].Val)
- // dataType := convertInterfaceToStrNullable(rowValues[2].Val)
- // columnKey := convertInterfaceToStrNullable(rowValues[3].Val)
- // extra := convertInterfaceToStrNullable(rowValues[4].Val)
- // collationName := convertInterfaceToStrNullable(rowValues[5].Val)
- // ordinalPosition := convertInterfaceToStrNullable(rowValues[6].Val)
- // result[tableName] = append(result[tableName], &proto.ColumnMetadata{
- // Name: columnName,
- // DataType: dataType,
- // Ordinal: ordinalPosition,
- // PrimaryKey: strings.EqualFold("PRI", columnKey),
- // Generated: strings.EqualFold("auto_increment", extra),
- // CaseSensitive: columnKey != "" && !strings.HasSuffix(collationName, "_ci"),
- // })
- //}
- return result
-}
-
-func convertInterfaceToStrNullable(value interface{}) string {
- if value != nil {
- return string(value.([]byte))
- }
- return ""
-}
-
-func (l *SimpleSchemaLoader) LoadIndexMetadata(ctx context.Context, conn proto.VConn, schema string, tables []string) map[string][]*proto.IndexMetadata {
- resultSet, err := conn.Query(ctx, schema, getIndexMetadataSQL(tables))
- if err != nil {
- return nil
- }
-
- if closer, ok := resultSet.(io.Closer); ok {
- defer func() {
- _ = closer.Close()
- }()
- }
-
- result := make(map[string][]*proto.IndexMetadata, 0)
-
- row := resultSet.GetRows()[0]
- var rowIter mysql.Iter
- switch r := row.(type) {
- case *mysql.BinaryIterRow:
- rowIter = r
- case *mysql.TextIterRow:
- rowIter = r
- }
-
- var (
- has bool
- rowValues []*proto.Value
- )
- for has, err = rowIter.Next(); has && err == nil; has, err = rowIter.Next() {
- if rowValues, err = row.Decode(); err != nil {
- return nil
- }
- tableName := convertInterfaceToStrNullable(rowValues[0].Val)
- indexName := convertInterfaceToStrNullable(rowValues[1].Val)
- result[tableName] = append(result[tableName], &proto.IndexMetadata{Name: indexName})
- }
-
- //for _, row := range resultSet.GetRows() {
- // var innerRow mysql.Row
- // switch r := row.(type) {
- // case *mysql.BinaryRow:
- // innerRow = r.Row
- // case *mysql.Row:
- // innerRow = *r
- // case *mysql.TextRow:
- // innerRow = r.Row
- // }
- // textRow := mysql.TextRow{Row: innerRow}
- // rowValues, err := textRow.Decode()
- // if err != nil {
- // log.Errorf("Load ColumnMetadata error when decode text row: %v", err)
- // return nil
- // }
- // tableName := convertInterfaceToStrNullable(rowValues[0].Val)
- // indexName := convertInterfaceToStrNullable(rowValues[1].Val)
- // result[tableName] = append(result[tableName], &proto.IndexMetadata{Name: indexName})
- //}
-
- return result
-}
-
-func getIndexMetadataSQL(tables []string) string {
- tableParamList := make([]string, 0, len(tables))
- for _, table := range tables {
- tableParamList = append(tableParamList, "'"+table+"'")
- }
- return fmt.Sprintf(indexMetadataSQL, strings.Join(tableParamList, ","))
-}
-
-func getColumnMetadataSQL(tables []string) string {
- if len(tables) == 0 {
- return tableMetadataSQL
- }
- tableParamList := make([]string, len(tables))
- for i, table := range tables {
- tableParamList[i] = "'" + table + "'"
- }
- // TODO use strings.Builder in the future
- return fmt.Sprintf(tableMetadataSQLInTables, strings.Join(tableParamList, ","))
-}
diff --git a/pkg/proto/stmt.go b/pkg/proto/stmt.go
new file mode 100644
index 000000000..d37c14e19
--- /dev/null
+++ b/pkg/proto/stmt.go
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package proto
+
+import (
+ "github.com/arana-db/parser/ast"
+)
+
+import (
+ "github.com/arana-db/arana/pkg/proto/hint"
+)
+
+// Stmt is a buffer used for store prepare statement metadata.
+type Stmt struct {
+ StatementID uint32
+ PrepareStmt string
+ ParamsCount uint16
+ ParamsType []int32
+ ColumnNames []string
+ BindVars map[string]interface{}
+ Hints []*hint.Hint
+ StmtNode ast.StmtNode
+}
diff --git a/pkg/resultx/resultx.go b/pkg/resultx/resultx.go
new file mode 100644
index 000000000..d6b6266d6
--- /dev/null
+++ b/pkg/resultx/resultx.go
@@ -0,0 +1,157 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package resultx
+
+import (
+ "github.com/arana-db/arana/pkg/proto"
+)
+
+var (
+ _ proto.Result = (*emptyResult)(nil) // contains nothing
+ _ proto.Result = (*slimResult)(nil) // only contains rows-affected and last-insert-id, design for exec
+ _ proto.Result = (*dsResult)(nil) // only contains dataset, design for query
+ _ proto.Result = (*fullResult)(nil) // contains all
+)
+
+type option struct {
+ ds proto.Dataset
+ id, affected uint64
+}
+
+// Option represents the option to create a result.
+type Option func(*option)
+
+// WithLastInsertID specify the last-insert-id for the result to be created.
+func WithLastInsertID(id uint64) Option {
+ return func(o *option) {
+ o.id = id
+ }
+}
+
+// WithRowsAffected specify the rows-affected for the result to be created.
+func WithRowsAffected(n uint64) Option {
+ return func(o *option) {
+ o.affected = n
+ }
+}
+
+// WithDataset specify the dataset for the result to be created.
+func WithDataset(d proto.Dataset) Option {
+ return func(o *option) {
+ o.ds = d
+ }
+}
+
+// New creates a result from some options.
+func New(options ...Option) proto.Result {
+ var o option
+ for _, it := range options {
+ it(&o)
+ }
+
+ // When execute EXEC, no need to specify dataset.
+ if o.ds == nil {
+ if o.id == 0 && o.affected == 0 {
+ return emptyResult{}
+ }
+ return slimResult{o.id, o.affected}
+ }
+
+ // When execute QUERY, only dataset is required.
+ if o.id == 0 && o.affected == 0 {
+ return dsResult{ds: o.ds}
+ }
+
+ // should never happen
+ return fullResult{
+ ds: o.ds,
+ id: o.id,
+ affected: o.affected,
+ }
+}
+
+type emptyResult struct{}
+
+func (n emptyResult) Dataset() (proto.Dataset, error) {
+ return nil, nil
+}
+
+func (n emptyResult) LastInsertId() (uint64, error) {
+ return 0, nil
+}
+
+func (n emptyResult) RowsAffected() (uint64, error) {
+ return 0, nil
+}
+
+type slimResult [2]uint64 // [lastInsertId,rowsAffected]
+
+func (h slimResult) Dataset() (proto.Dataset, error) {
+ return nil, nil
+}
+
+func (h slimResult) LastInsertId() (uint64, error) {
+ return h[0], nil
+}
+
+func (h slimResult) RowsAffected() (uint64, error) {
+ return h[1], nil
+}
+
+type fullResult struct {
+ ds proto.Dataset
+ id uint64
+ affected uint64
+}
+
+func (f fullResult) Dataset() (proto.Dataset, error) {
+ return f.ds, nil
+}
+
+func (f fullResult) LastInsertId() (uint64, error) {
+ return f.id, nil
+}
+
+func (f fullResult) RowsAffected() (uint64, error) {
+ return f.affected, nil
+}
+
+type dsResult struct {
+ ds proto.Dataset
+}
+
+func (d dsResult) Dataset() (proto.Dataset, error) {
+ return d.ds, nil
+}
+
+func (d dsResult) LastInsertId() (uint64, error) {
+ return 0, nil
+}
+
+func (d dsResult) RowsAffected() (uint64, error) {
+ return 0, nil
+}
+
+func Drain(result proto.Result) {
+ if d, _ := result.Dataset(); d != nil {
+ defer func() {
+ _ = d.Close()
+ }()
+ }
+ return
+}
diff --git a/pkg/runtime/ast/alter_table.go b/pkg/runtime/ast/alter_table.go
index a60f1596c..eab6fd8dd 100644
--- a/pkg/runtime/ast/alter_table.go
+++ b/pkg/runtime/ast/alter_table.go
@@ -179,5 +179,5 @@ func (at *AlterTableStatement) CntParams() int {
}
func (at *AlterTableStatement) Mode() SQLType {
- return SalterTable
+ return SQLTypeAlterTable
}
diff --git a/pkg/runtime/ast/ast.go b/pkg/runtime/ast/ast.go
index 5c6f7e897..6a02d7da5 100644
--- a/pkg/runtime/ast/ast.go
+++ b/pkg/runtime/ast/ast.go
@@ -34,20 +34,19 @@ import (
)
import (
+ "github.com/arana-db/arana/pkg/proto/hint"
"github.com/arana-db/arana/pkg/runtime/cmp"
"github.com/arana-db/arana/pkg/runtime/logical"
)
-var (
- _opcode2comparison = map[opcode.Op]cmp.Comparison{
- opcode.EQ: cmp.Ceq,
- opcode.NE: cmp.Cne,
- opcode.LT: cmp.Clt,
- opcode.GT: cmp.Cgt,
- opcode.LE: cmp.Clte,
- opcode.GE: cmp.Cgte,
- }
-)
+var _opcode2comparison = map[opcode.Op]cmp.Comparison{
+ opcode.EQ: cmp.Ceq,
+ opcode.NE: cmp.Cne,
+ opcode.LT: cmp.Clt,
+ opcode.GT: cmp.Cgt,
+ opcode.LE: cmp.Clte,
+ opcode.GE: cmp.Cgte,
+}
type (
parseOption struct {
@@ -98,7 +97,7 @@ func FromStmtNode(node ast.StmtNode) (Statement, error) {
}
switch tgt := result.(type) {
case *ShowColumns:
- return &DescribeStatement{Table: tgt.tableName}, nil
+ return &DescribeStatement{Table: tgt.TableName, Column: tgt.Column}, nil
default:
return &ExplainStatement{tgt: tgt}, nil
}
@@ -108,11 +107,52 @@ func FromStmtNode(node ast.StmtNode) (Statement, error) {
return cc.convDropTableStmt(stmt), nil
case *ast.AlterTableStmt:
return cc.convAlterTableStmt(stmt), nil
+ case *ast.DropIndexStmt:
+ return cc.convDropIndexStmt(stmt), nil
+ case *ast.DropTriggerStmt:
+ return cc.convDropTrigger(stmt), nil
+ case *ast.CreateIndexStmt:
+ return cc.convCreateIndexStmt(stmt), nil
default:
return nil, errors.Errorf("unimplement: stmt type %T!", stmt)
}
}
+func (cc *convCtx) convDropIndexStmt(stmt *ast.DropIndexStmt) *DropIndexStatement {
+ var tableName TableName
+ if db := stmt.Table.Schema.O; len(db) > 0 {
+ tableName = append(tableName, db)
+ }
+ tableName = append(tableName, stmt.Table.Name.O)
+ return &DropIndexStatement{
+ IfExists: stmt.IfExists,
+ IndexName: stmt.IndexName,
+ Table: tableName,
+ }
+}
+
+func (cc *convCtx) convCreateIndexStmt(stmt *ast.CreateIndexStmt) *CreateIndexStatement {
+ var tableName TableName
+ if db := stmt.Table.Schema.O; len(db) > 0 {
+ tableName = append(tableName, db)
+ }
+ tableName = append(tableName, stmt.Table.Name.O)
+
+ keys := make([]*IndexPartSpec, len(stmt.IndexPartSpecifications))
+ for i, k := range stmt.IndexPartSpecifications {
+ keys[i] = &IndexPartSpec{
+ Column: cc.convColumn(k.Column),
+ Expr: toExpressionNode(cc.convExpr(k.Expr)),
+ }
+ }
+
+ return &CreateIndexStatement{
+ Table: tableName,
+ IndexName: stmt.IndexName,
+ Keys: keys,
+ }
+}
+
func (cc *convCtx) convAlterTableStmt(stmt *ast.AlterTableStmt) *AlterTableStatement {
var tableName TableName
if db := stmt.Table.Schema.O; len(db) > 0 {
@@ -282,7 +322,7 @@ func (cc *convCtx) convConstraint(c *ast.Constraint) *Constraint {
}
func (cc *convCtx) convDropTableStmt(stmt *ast.DropTableStmt) *DropTableStatement {
- var tables = make([]*TableName, len(stmt.Tables))
+ tables := make([]*TableName, len(stmt.Tables))
for i, table := range stmt.Tables {
tables[i] = &TableName{
table.Name.String(),
@@ -506,6 +546,23 @@ func (cc *convCtx) convInsertStmt(stmt *ast.InsertStmt) Statement {
}
}
+ if stmt.Select != nil {
+ switch v := stmt.Select.(type) {
+ case *ast.SelectStmt:
+ return &InsertSelectStatement{
+ baseInsertStatement: &bi,
+ sel: cc.convSelectStmt(v),
+ duplicatedUpdates: updates,
+ }
+ case *ast.SetOprStmt:
+ return &InsertSelectStatement{
+ baseInsertStatement: &bi,
+ unionSel: cc.convUnionStmt(v),
+ duplicatedUpdates: updates,
+ }
+ }
+ }
+
return &InsertStatement{
baseInsertStatement: &bi,
values: values,
@@ -520,12 +577,30 @@ func (cc *convCtx) convTruncateTableStmt(node *ast.TruncateTableStmt) Statement
}
func (cc *convCtx) convShowStmt(node *ast.ShowStmt) Statement {
+ toIn := func(node *ast.ShowStmt) (string, bool) {
+ if node.DBName == "" {
+ return "", false
+ }
+ return node.DBName, true
+ }
+ toFrom := func(node *ast.ShowStmt) (FromTable, bool) {
+ if node.Table == nil {
+ return "", false
+ }
+ return FromTable(node.Table.Name.String()), true
+ }
toWhere := func(node *ast.ShowStmt) (ExpressionNode, bool) {
if node.Where == nil {
return nil, false
}
return toExpressionNode(cc.convExpr(node.Where)), true
}
+ toShowLike := func(node *ast.ShowStmt) (PredicateNode, bool) {
+ if node.Pattern == nil {
+ return nil, false
+ }
+ return cc.convPatternLikeExpr(node.Pattern), true
+ }
toLike := func(node *ast.ShowStmt) (string, bool) {
if node.Pattern == nil {
return "", false
@@ -535,15 +610,23 @@ func (cc *convCtx) convShowStmt(node *ast.ShowStmt) Statement {
toBaseShow := func() *baseShow {
var bs baseShow
- if like, ok := toLike(node); ok {
+ if like, ok := toShowLike(node); ok {
bs.filter = like
} else if where, ok := toWhere(node); ok {
bs.filter = where
+ } else if in, ok := toIn(node); ok {
+ bs.filter = in
+ } else if from, ok := toFrom(node); ok {
+ bs.filter = from
}
return &bs
}
switch node.Tp {
+ case ast.ShowTopology:
+ return &ShowTopology{baseShow: toBaseShow()}
+ case ast.ShowOpenTables:
+ return &ShowOpenTables{baseShow: toBaseShow()}
case ast.ShowTables:
return &ShowTables{baseShow: toBaseShow()}
case ast.ShowDatabases:
@@ -555,7 +638,7 @@ func (cc *convCtx) convShowStmt(node *ast.ShowStmt) Statement {
}
case ast.ShowIndex:
ret := &ShowIndex{
- tableName: []string{node.Table.Name.O},
+ TableName: []string{node.Table.Name.O},
}
if where, ok := toWhere(node); ok {
ret.where = where
@@ -563,7 +646,10 @@ func (cc *convCtx) convShowStmt(node *ast.ShowStmt) Statement {
return ret
case ast.ShowColumns:
ret := &ShowColumns{
- tableName: []string{node.Table.Name.O},
+ TableName: []string{node.Table.Name.O},
+ }
+ if node.Column != nil {
+ ret.Column = node.Column.Name.O
}
if node.Extended {
ret.flag |= scFlagExtended
@@ -599,28 +685,46 @@ func convInsertColumns(columnNames []*ast.ColumnName) []string {
}
// Parse parses the SQL string to Statement.
-func Parse(sql string, options ...ParseOption) (Statement, error) {
+func Parse(sql string, options ...ParseOption) ([]*hint.Hint, Statement, error) {
var o parseOption
for _, it := range options {
it(&o)
}
p := parser.New()
- s, err := p.ParseOneStmt(sql, o.charset, o.collation)
+ s, hintStrs, err := p.ParseOneStmtHints(sql, o.charset, o.collation)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ stmt, err := FromStmtNode(s)
if err != nil {
- return nil, err
+ return nil, nil, err
+ }
+
+ if len(hintStrs) < 1 {
+ return nil, stmt, nil
+ }
+
+ hints := make([]*hint.Hint, 0, len(hintStrs))
+ for _, it := range hintStrs {
+ var h *hint.Hint
+ if h, err = hint.Parse(it); err != nil {
+ return nil, nil, errors.WithStack(err)
+ }
+ hints = append(hints, h)
}
- return FromStmtNode(s)
+ return hints, stmt, nil
}
// MustParse parses the SQL string to Statement, panic if failed.
-func MustParse(sql string) Statement {
- stmt, err := Parse(sql)
+func MustParse(sql string) ([]*hint.Hint, Statement) {
+ hints, stmt, err := Parse(sql)
if err != nil {
panic(err.Error())
}
- return stmt
+ return hints, stmt
}
type convCtx struct {
@@ -981,9 +1085,7 @@ func (cc *convCtx) convAggregateFuncExpr(node *ast.AggregateFuncExpr) PredicateN
}
func (cc *convCtx) convFuncCallExpr(expr *ast.FuncCallExpr) PredicateNode {
- var (
- fnName = strings.ToUpper(expr.FnName.O)
- )
+ fnName := strings.ToUpper(expr.FnName.O)
// NOTICE: tidb-parser cannot process CONVERT('foobar' USING utf8).
// It should be a CastFunc, but now will be parsed as a FuncCall.
@@ -1101,13 +1203,13 @@ func (cc *convCtx) toArg(arg ast.ExprNode) *FunctionArg {
func (cc *convCtx) convPatternLikeExpr(expr *ast.PatternLikeExpr) PredicateNode {
var (
- left = cc.convExpr(expr.Expr)
- right = cc.convExpr(expr.Pattern)
+ left, _ = cc.convExpr(expr.Expr).(PredicateNode)
+ right, _ = cc.convExpr(expr.Pattern).(PredicateNode)
)
return &LikePredicateNode{
Not: expr.Not,
- Left: left.(PredicateNode),
- Right: right.(PredicateNode),
+ Left: left,
+ Right: right,
}
}
@@ -1199,7 +1301,6 @@ func (cc *convCtx) convValueExpr(expr ast.ValueExpr) PredicateNode {
default:
if val == nil {
atom = &ConstantExpressionAtom{Inner: Null{}}
-
} else {
atom = &ConstantExpressionAtom{Inner: val}
}
@@ -1283,9 +1384,7 @@ func (cc *convCtx) convBinaryOperationExpr(expr *ast.BinaryOperationExpr) interf
Right: right.(*AtomPredicateNode).A,
}}
case opcode.EQ, opcode.NE, opcode.GT, opcode.GE, opcode.LT, opcode.LE:
- var (
- op = _opcode2comparison[expr.Op]
- )
+ op := _opcode2comparison[expr.Op]
if !isColumnAtom(left.(PredicateNode)) && isColumnAtom(right.(PredicateNode)) {
// do reverse:
@@ -1389,6 +1488,15 @@ func (cc *convCtx) convTableName(val *ast.TableName, tgt *TableSourceNode) {
tgt.partitions = partitions
}
+func (cc *convCtx) convDropTrigger(stmt *ast.DropTriggerStmt) *DropTriggerStatement {
+ var tableName TableName
+ if db := stmt.Trigger.Schema.O; len(db) > 0 {
+ tableName = append(tableName, db)
+ }
+ tableName = append(tableName, stmt.Trigger.Name.O)
+ return &DropTriggerStatement{Table: tableName, IfExists: stmt.IfExists}
+}
+
func toExpressionNode(src interface{}) ExpressionNode {
if src == nil {
return nil
diff --git a/pkg/runtime/ast/ast_test.go b/pkg/runtime/ast/ast_test.go
index 90b4a5651..50a10a51d 100644
--- a/pkg/runtime/ast/ast_test.go
+++ b/pkg/runtime/ast/ast_test.go
@@ -46,39 +46,39 @@ func TestParse(t *testing.T) {
"select * from student where uid = !0",
} {
t.Run(sql, func(t *testing.T) {
- stmt, err = Parse(sql)
+ _, stmt, err = Parse(sql)
assert.NoError(t, err)
t.Log("stmt:", stmt)
})
}
// 1. select statement
- stmt, err = Parse("select * from student as foo where `name` = if(1>2, 1, 2) order by age")
+ _, stmt, err = Parse("select * from student as foo where `name` = if(1>2, 1, 2) order by age")
assert.NoError(t, err, "parse+conv ast failed")
t.Logf("stmt:%+v", stmt)
// 2. delete statement
- deleteStmt, err := Parse("delete from student as foo where `name` = if(1>2, 1, 2)")
+ _, deleteStmt, err := Parse("delete from student as foo where `name` = if(1>2, 1, 2)")
assert.NoError(t, err, "parse+conv ast failed")
t.Logf("stmt:%+v", deleteStmt)
// 3. insert statements
- insertStmtWithSetClause, err := Parse("insert into sink set a=77, b='88'")
+ _, insertStmtWithSetClause, err := Parse("insert into sink set a=77, b='88'")
assert.NoError(t, err, "parse+conv ast failed")
t.Logf("stmt:%+v", insertStmtWithSetClause)
- insertStmtWithValues, err := Parse("insert into sink values(1, '2')")
+ _, insertStmtWithValues, err := Parse("insert into sink values(1, '2')")
assert.NoError(t, err, "parse+conv ast failed")
t.Logf("stmt:%+v", insertStmtWithValues)
- insertStmtWithOnDuplicateUpdates, err := Parse(
+ _, insertStmtWithOnDuplicateUpdates, err := Parse(
"insert into sink (a, b) values(1, '2') on duplicate key update a=a+1",
)
assert.NoError(t, err, "parse+conv ast failed")
t.Logf("stmt:%+v", insertStmtWithOnDuplicateUpdates)
// 4. update statement
- updateStmt, err := Parse(
+ _, updateStmt, err := Parse(
"update source set a=a+1, b=b+2 where a>1 order by a limit 5",
)
assert.NoError(t, err, "parse+conv ast failed")
@@ -98,7 +98,7 @@ func TestParse_UnionStmt(t *testing.T) {
{"select id,uid,name,nickname from student where uid in (?,?,?) union all select id,uid,name,nickname from tb_user where uid in (?,?,?)", "SELECT `id`,`uid`,`name`,`nickname` FROM `student` WHERE `uid` IN (?,?,?) UNION ALL SELECT `id`,`uid`,`name`,`nickname` FROM `tb_user` WHERE `uid` IN (?,?,?)"},
} {
t.Run(next.input, func(t *testing.T) {
- stmt, err := Parse(next.input)
+ _, stmt, err := Parse(next.input)
assert.NoError(t, err, "should parse ok")
assert.IsType(t, (*UnionSelectStatement)(nil), stmt, "should be union statement")
@@ -106,9 +106,7 @@ func TestParse_UnionStmt(t *testing.T) {
assert.NoError(t, err, "should restore ok")
assert.Equal(t, next.expect, actual)
})
-
}
-
}
func TestParse_SelectStmt(t *testing.T) {
@@ -162,7 +160,7 @@ func TestParse_SelectStmt(t *testing.T) {
{"select null as pkid", "SELECT NULL AS `pkid`"},
} {
t.Run(next.input, func(t *testing.T) {
- stmt, err := Parse(next.input)
+ _, stmt, err := Parse(next.input)
assert.NoError(t, err, "should parse ok")
assert.IsType(t, (*SelectStatement)(nil), stmt, "should be select statement")
@@ -171,7 +169,6 @@ func TestParse_SelectStmt(t *testing.T) {
assert.Equal(t, next.expect, actual)
})
}
-
}
func TestParse_DeleteStmt(t *testing.T) {
@@ -185,7 +182,7 @@ func TestParse_DeleteStmt(t *testing.T) {
{"delete low_priority quick ignore from student where id = 1", "DELETE LOW_PRIORITY QUICK IGNORE FROM `student` WHERE `id` = 1"},
} {
t.Run(it.input, func(t *testing.T) {
- stmt, err := Parse(it.input)
+ _, stmt, err := Parse(it.input)
assert.NoError(t, err)
assert.IsType(t, (*DeleteStatement)(nil), stmt, "should be delete statement")
@@ -206,7 +203,7 @@ func TestParse_DescribeStatement(t *testing.T) {
{"desc foobar", "DESC `foobar`"},
} {
t.Run(it.input, func(t *testing.T) {
- stmt, err := Parse(it.input)
+ _, stmt, err := Parse(it.input)
assert.NoError(t, err)
assert.IsType(t, (*DescribeStatement)(nil), stmt, "should be describe statement")
@@ -228,7 +225,12 @@ func TestParse_ShowStatement(t *testing.T) {
{"show databases", (*ShowDatabases)(nil), "SHOW DATABASES"},
{"show databases like '%foo%'", (*ShowDatabases)(nil), "SHOW DATABASES LIKE '%foo%'"},
{"show databases where name = 'foobar'", (*ShowDatabases)(nil), "SHOW DATABASES WHERE `name` = 'foobar'"},
+ {"show open tables", (*ShowOpenTables)(nil), "SHOW OPEN TABLES"},
+ {"show open tables in foobar", (*ShowOpenTables)(nil), "SHOW OPEN TABLES IN `foobar`"},
+ {"show open tables like '%foo%'", (*ShowOpenTables)(nil), "SHOW OPEN TABLES LIKE '%foo%'"},
+ {"show open tables where name = 'foo'", (*ShowOpenTables)(nil), "SHOW OPEN TABLES WHERE `name` = 'foo'"},
{"show tables", (*ShowTables)(nil), "SHOW TABLES"},
+ {"show tables in foobar", (*ShowTables)(nil), "SHOW TABLES IN `foobar`"},
{"show tables like '%foo%'", (*ShowTables)(nil), "SHOW TABLES LIKE '%foo%'"},
{"show tables where name = 'foo'", (*ShowTables)(nil), "SHOW TABLES WHERE `name` = 'foo'"},
{"sHow indexes from foo", (*ShowIndex)(nil), "SHOW INDEXES FROM `foo`"},
@@ -238,7 +240,7 @@ func TestParse_ShowStatement(t *testing.T) {
{"show create table `foo`", (*ShowCreate)(nil), "SHOW CREATE TABLE `foo`"},
} {
t.Run(it.input, func(t *testing.T) {
- stmt, err := Parse(it.input)
+ _, stmt, err := Parse(it.input)
assert.NoError(t, err)
assert.IsTypef(t, it.expectTyp, stmt, "should be %T", it.expectTyp)
@@ -247,11 +249,10 @@ func TestParse_ShowStatement(t *testing.T) {
assert.Equal(t, it.expect, actual)
})
}
-
}
func TestParse_ExplainStmt(t *testing.T) {
- stmt, err := Parse("explain select * from student where uid = 1")
+ _, stmt, err := Parse("explain select * from student where uid = 1")
assert.NoError(t, err)
assert.IsType(t, (*ExplainStatement)(nil), stmt)
s := MustRestoreToString(RestoreDefault, stmt)
@@ -286,7 +287,7 @@ func TestParseMore(t *testing.T) {
for _, sql := range tbls {
t.Run(sql, func(t *testing.T) {
- _, err := Parse(sql)
+ _, _, err := Parse(sql)
assert.NoError(t, err)
})
}
@@ -303,7 +304,7 @@ func TestParse_UpdateStmt(t *testing.T) {
{"update low_priority student set nickname = ? where id = 1 limit 1", "UPDATE LOW_PRIORITY `student` SET `nickname` = ? WHERE `id` = 1 LIMIT 1"},
} {
t.Run(it.input, func(t *testing.T) {
- stmt, err := Parse(it.input)
+ _, stmt, err := Parse(it.input)
assert.NoError(t, err)
assert.IsTypef(t, (*UpdateStatement)(nil), stmt, "should be update statement")
@@ -311,7 +312,6 @@ func TestParse_UpdateStmt(t *testing.T) {
assert.NoError(t, err, "should restore ok")
assert.Equal(t, it.expect, actual)
})
-
}
}
@@ -337,7 +337,7 @@ func TestParse_InsertStmt(t *testing.T) {
},
} {
t.Run(it.input, func(t *testing.T) {
- stmt, err := Parse(it.input)
+ _, stmt, err := Parse(it.input)
assert.NoError(t, err)
assert.IsTypef(t, (*InsertStatement)(nil), stmt, "should be insert statement")
@@ -347,10 +347,38 @@ func TestParse_InsertStmt(t *testing.T) {
})
}
+ for _, it := range []tt{
+ {
+ "insert into student select * from student_tmp",
+ "INSERT INTO `student` SELECT * FROM `student_tmp`",
+ },
+ {
+ "insert into student(id,name) select emp_no, name from employees limit 10,2",
+ "INSERT INTO `student`(`id`, `name`) SELECT `emp_no`,`name` FROM `employees` LIMIT 10,2",
+ },
+ {
+ "insert into student(id,name) select emp_no, name from employees on duplicate key update version=version+1,modified_at=NOW()",
+ "INSERT INTO `student`(`id`, `name`) SELECT `emp_no`,`name` FROM `employees` ON DUPLICATE KEY UPDATE `version` = `version`+1, `modified_at` = NOW()",
+ },
+ {
+ "insert student select id, score from student_tmp union select id * 10, score * 10 from student_tmp",
+ "INSERT INTO `student` SELECT `id`,`score` FROM `student_tmp` UNION SELECT `id`*10,`score`*10 FROM `student_tmp`",
+ },
+ } {
+ t.Run(it.input, func(t *testing.T) {
+ _, stmt, err := Parse(it.input)
+ assert.NoError(t, err)
+ assert.IsTypef(t, (*InsertSelectStatement)(nil), stmt, "should be insert-select statement")
+
+ actual, err := RestoreToString(RestoreDefault, stmt.(Restorer))
+ assert.NoError(t, err, "should restore ok")
+ assert.Equal(t, it.expect, actual)
+ })
+ }
}
func TestRestoreCount(t *testing.T) {
- stmt := MustParse("select count(1)")
+ _, stmt := MustParse("select count(1)")
sel := stmt.(*SelectStatement)
var sb strings.Builder
_ = sel.Restore(RestoreDefault, &sb, nil)
@@ -358,7 +386,7 @@ func TestRestoreCount(t *testing.T) {
}
func TestQuote(t *testing.T) {
- stmt := MustParse("select `a``bc`")
+ _, stmt := MustParse("select `a``bc`")
sel := stmt.(*SelectStatement)
var sb strings.Builder
_ = sel.Restore(RestoreDefault, &sb, nil)
@@ -403,7 +431,7 @@ func TestParse_AlterTableStmt(t *testing.T) {
},
} {
t.Run(it.input, func(t *testing.T) {
- stmt, err := Parse(it.input)
+ _, stmt, err := Parse(it.input)
assert.NoError(t, err)
assert.IsTypef(t, (*AlterTableStatement)(nil), stmt, "should be alter table statement")
@@ -412,5 +440,15 @@ func TestParse_AlterTableStmt(t *testing.T) {
assert.Equal(t, it.expect, actual)
})
}
+}
+func TestParse_DescStmt(t *testing.T) {
+ _, stmt := MustParse("desc student id")
+ // In MySQL, the case of "desc student 'id'" will be parsed successfully,
+ // but in arana, it will get an error by tidb parser.
+ desc := stmt.(*DescribeStatement)
+ var sb strings.Builder
+ _ = desc.Restore(RestoreDefault, &sb, nil)
+ t.Logf(sb.String())
+ assert.Equal(t, "DESC `student` `id`", sb.String())
}
diff --git a/pkg/runtime/ast/create_index.go b/pkg/runtime/ast/create_index.go
new file mode 100644
index 000000000..050713bc0
--- /dev/null
+++ b/pkg/runtime/ast/create_index.go
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package ast
+
+import (
+ "strings"
+)
+
+var (
+ _ Statement = (*CreateIndexStatement)(nil)
+ _ Restorer = (*CreateIndexStatement)(nil)
+)
+
+type CreateIndexStatement struct {
+ IndexName string
+ Table TableName
+ Keys []*IndexPartSpec
+}
+
+func (c *CreateIndexStatement) CntParams() int {
+ return 0
+}
+
+func (c *CreateIndexStatement) Restore(flag RestoreFlag, sb *strings.Builder, args *[]int) error {
+ sb.WriteString("CREATE INDEX ")
+ sb.WriteString(c.IndexName)
+ if len(c.Table) == 0 {
+ return nil
+ }
+ sb.WriteString(" ON ")
+ if err := c.Table.Restore(flag, sb, args); err != nil {
+ return err
+ }
+
+ sb.WriteString(" (")
+ for i, k := range c.Keys {
+ if i != 0 {
+ sb.WriteString(", ")
+ }
+ if err := k.Restore(flag, sb, args); err != nil {
+ return err
+ }
+ }
+ sb.WriteString(")")
+ return nil
+}
+
+func (c *CreateIndexStatement) Validate() error {
+ return nil
+}
+
+func (c *CreateIndexStatement) Mode() SQLType {
+ return SQLTypeCreateIndex
+}
diff --git a/pkg/runtime/ast/delete.go b/pkg/runtime/ast/delete.go
index 90987637a..b8b38f05c 100644
--- a/pkg/runtime/ast/delete.go
+++ b/pkg/runtime/ast/delete.go
@@ -109,7 +109,7 @@ func (ds *DeleteStatement) CntParams() int {
}
func (ds *DeleteStatement) Mode() SQLType {
- return Sdelete
+ return SQLTypeDelete
}
func (ds *DeleteStatement) IsLowPriority() bool {
diff --git a/pkg/runtime/ast/describe.go b/pkg/runtime/ast/describe.go
index 3107e2831..3523ff482 100644
--- a/pkg/runtime/ast/describe.go
+++ b/pkg/runtime/ast/describe.go
@@ -33,7 +33,7 @@ var (
// DescribeStatement represents mysql describe statement. see https://dev.mysql.com/doc/refman/8.0/en/describe.html
type DescribeStatement struct {
Table TableName
- column string
+ Column string
}
// Restore implements Restorer.
@@ -42,9 +42,9 @@ func (d *DescribeStatement) Restore(flag RestoreFlag, sb *strings.Builder, args
if err := d.Table.Restore(flag, sb, args); err != nil {
return errors.WithStack(err)
}
- if len(d.column) > 0 {
+ if len(d.Column) > 0 {
sb.WriteByte(' ')
- WriteID(sb, d.column)
+ WriteID(sb, d.Column)
}
return nil
@@ -59,14 +59,7 @@ func (d *DescribeStatement) CntParams() int {
}
func (d *DescribeStatement) Mode() SQLType {
- return Squery
-}
-
-func (d *DescribeStatement) Column() (string, bool) {
- if len(d.column) > 0 {
- return d.column, true
- }
- return "", false
+ return SQLTypeDescribe
}
// ExplainStatement represents mysql explain statement. see https://dev.mysql.com/doc/refman/8.0/en/explain.html
@@ -95,5 +88,5 @@ func (e *ExplainStatement) CntParams() int {
}
func (e *ExplainStatement) Mode() SQLType {
- return Squery
+ return SQLTypeSelect
}
diff --git a/pkg/runtime/ast/drop_index.go b/pkg/runtime/ast/drop_index.go
new file mode 100644
index 000000000..58918a21b
--- /dev/null
+++ b/pkg/runtime/ast/drop_index.go
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package ast
+
+import (
+ "strings"
+)
+
+var (
+ _ Statement = (*DropIndexStatement)(nil)
+ _ Restorer = (*DropIndexStatement)(nil)
+)
+
+type DropIndexStatement struct {
+ IfExists bool
+ IndexName string
+ Table TableName
+}
+
+func (d *DropIndexStatement) CntParams() int {
+ return 0
+}
+
+func (d *DropIndexStatement) Restore(flag RestoreFlag, sb *strings.Builder, args *[]int) error {
+ sb.WriteString("DROP INDEX ")
+ if d.IfExists {
+ sb.WriteString("IF EXISTS")
+ }
+ sb.WriteString(d.IndexName)
+ if len(d.Table) == 0 {
+ return nil
+ }
+ sb.WriteString(" ON ")
+ return d.Table.Restore(flag, sb, args)
+}
+
+func (d *DropIndexStatement) Validate() error {
+ return nil
+}
+
+func (d *DropIndexStatement) Mode() SQLType {
+ return SQLTypeDropIndex
+}
diff --git a/pkg/runtime/ast/drop_table.go b/pkg/runtime/ast/drop_table.go
index 1bf3b63cb..1b2778711 100644
--- a/pkg/runtime/ast/drop_table.go
+++ b/pkg/runtime/ast/drop_table.go
@@ -55,5 +55,5 @@ func (d DropTableStatement) Validate() error {
}
func (d DropTableStatement) Mode() SQLType {
- return SdropTable
+ return SQLTypeDropTable
}
diff --git a/pkg/runtime/ast/function.go b/pkg/runtime/ast/function.go
index 6c7402983..c297d8fd6 100644
--- a/pkg/runtime/ast/function.go
+++ b/pkg/runtime/ast/function.go
@@ -48,8 +48,10 @@ const (
Fpasswd
)
-type FunctionArgType uint8
-type FunctionType uint8
+type (
+ FunctionArgType uint8
+ FunctionType uint8
+)
func (f FunctionType) String() string {
switch f {
@@ -535,9 +537,7 @@ func (cd *ConvertDataType) Parse(s string) error {
return errors.Errorf("invalid cast string '%s'", s)
}
- var (
- name, first, second, suffix string
- )
+ var name, first, second, suffix string
for i := 1; i < len(keys); i++ {
sub := subs[i]
switch keys[i] {
diff --git a/pkg/runtime/ast/insert.go b/pkg/runtime/ast/insert.go
index e173a12f8..683a14af1 100644
--- a/pkg/runtime/ast/insert.go
+++ b/pkg/runtime/ast/insert.go
@@ -146,7 +146,7 @@ type ReplaceStatement struct {
}
func (r *ReplaceStatement) Restore(flag RestoreFlag, sb *strings.Builder, args *[]int) error {
- //TODO implement me
+ // TODO implement me
panic("implement me")
}
@@ -159,7 +159,7 @@ func (r *ReplaceStatement) Values() [][]ExpressionNode {
}
func (r *ReplaceStatement) Mode() SQLType {
- return Sreplace
+ return SQLTypeReplace
}
func (r *ReplaceStatement) CntParams() int {
@@ -343,7 +343,7 @@ func (is *InsertStatement) CntParams() int {
}
func (is *InsertStatement) Mode() SQLType {
- return Sinsert
+ return SQLTypeInsert
}
type ReplaceSelectStatement struct {
@@ -352,7 +352,7 @@ type ReplaceSelectStatement struct {
}
func (r *ReplaceSelectStatement) Restore(flag RestoreFlag, sb *strings.Builder, args *[]int) error {
- //TODO implement me
+ // TODO implement me
panic("implement me")
}
@@ -369,21 +369,84 @@ func (r *ReplaceSelectStatement) CntParams() int {
}
func (r *ReplaceSelectStatement) Mode() SQLType {
- return Sreplace
+ return SQLTypeReplace
}
type InsertSelectStatement struct {
*baseInsertStatement
- sel *SelectStatement
+ duplicatedUpdates []*UpdateElement
+ sel *SelectStatement
+ unionSel *UnionSelectStatement
}
func (is *InsertSelectStatement) Restore(flag RestoreFlag, sb *strings.Builder, args *[]int) error {
- //TODO implement me
- panic("implement me")
+ sb.WriteString("INSERT ")
+
+ // write priority
+ if is.IsLowPriority() {
+ sb.WriteString("LOW_PRIORITY ")
+ } else if is.IsHighPriority() {
+ sb.WriteString("HIGH_PRIORITY ")
+ } else if is.IsDelayed() {
+ sb.WriteString("DELAYED ")
+ }
+
+ if is.IsIgnore() {
+ sb.WriteString("IGNORE ")
+ }
+
+ sb.WriteString("INTO ")
+
+ if err := is.Table().Restore(flag, sb, args); err != nil {
+ return errors.WithStack(err)
+ }
+
+ if len(is.columns) > 0 {
+ sb.WriteByte('(')
+ WriteID(sb, is.columns[0])
+ for i := 1; i < len(is.columns); i++ {
+ sb.WriteString(", ")
+ WriteID(sb, is.columns[i])
+ }
+ sb.WriteString(") ")
+ } else {
+ sb.WriteByte(' ')
+ }
+
+ if is.sel != nil {
+ if err := is.sel.Restore(flag, sb, args); err != nil {
+ return errors.WithStack(err)
+ }
+ }
+
+ if is.unionSel != nil {
+ if err := is.unionSel.Restore(flag, sb, args); err != nil {
+ return errors.WithStack(err)
+ }
+ }
+
+ if len(is.duplicatedUpdates) > 0 {
+ sb.WriteString(" ON DUPLICATE KEY UPDATE ")
+
+ if err := is.duplicatedUpdates[0].Restore(flag, sb, args); err != nil {
+ return errors.WithStack(err)
+ }
+ for i := 1; i < len(is.duplicatedUpdates); i++ {
+ sb.WriteString(", ")
+ if err := is.duplicatedUpdates[i].Restore(flag, sb, args); err != nil {
+ return errors.WithStack(err)
+ }
+ }
+ }
+
+ return nil
}
func (is *InsertSelectStatement) Validate() error {
- return nil
+ if is.unionSel != nil {
+ return is.unionSel.Validate()
+ }
+ return is.sel.Validate()
}
func (is *InsertSelectStatement) Select() *SelectStatement {
@@ -391,9 +454,12 @@ func (is *InsertSelectStatement) Select() *SelectStatement {
}
func (is *InsertSelectStatement) CntParams() int {
+ if is.unionSel != nil {
+ return is.unionSel.CntParams()
+ }
return is.sel.CntParams()
}
func (is *InsertSelectStatement) Mode() SQLType {
- return Sinsert
+ return SQLTypeInsertSelect
}
diff --git a/pkg/runtime/ast/model.go b/pkg/runtime/ast/model.go
index f6bfcb5b9..cac08dcb0 100644
--- a/pkg/runtime/ast/model.go
+++ b/pkg/runtime/ast/model.go
@@ -310,10 +310,22 @@ func (ln *LimitNode) SetLimitVar() {
ln.flag |= flagLimitLimitVar
}
+func (ln *LimitNode) UnsetOffsetVar() {
+ ln.flag &= ^flagLimitOffsetVar
+}
+
+func (ln *LimitNode) UnsetLimitVar() {
+ ln.flag &= ^flagLimitLimitVar
+}
+
func (ln *LimitNode) SetHasOffset() {
ln.flag |= flagLimitHasOffset
}
+func (ln *LimitNode) UnsetHasOffset() {
+ ln.flag &= ^flagLimitHasOffset
+}
+
func (ln *LimitNode) HasOffset() bool {
return ln.flag&flagLimitHasOffset != 0
}
diff --git a/pkg/runtime/ast/predicate.go b/pkg/runtime/ast/predicate.go
index 64dd379a4..5a26ef993 100644
--- a/pkg/runtime/ast/predicate.go
+++ b/pkg/runtime/ast/predicate.go
@@ -74,8 +74,10 @@ func (l *LikePredicateNode) InTables(tables map[string]struct{}) error {
}
func (l *LikePredicateNode) Restore(flag RestoreFlag, sb *strings.Builder, args *[]int) error {
- if err := l.Left.Restore(flag, sb, args); err != nil {
- return errors.WithStack(err)
+ if l.Left != nil {
+ if err := l.Left.Restore(flag, sb, args); err != nil {
+ return errors.WithStack(err)
+ }
}
if l.Not {
@@ -83,9 +85,10 @@ func (l *LikePredicateNode) Restore(flag RestoreFlag, sb *strings.Builder, args
} else {
sb.WriteString(" LIKE ")
}
-
- if err := l.Right.Restore(flag, sb, args); err != nil {
- return errors.WithStack(err)
+ if l.Right != nil {
+ if err := l.Right.Restore(flag, sb, args); err != nil {
+ return errors.WithStack(err)
+ }
}
return nil
diff --git a/pkg/runtime/ast/proto.go b/pkg/runtime/ast/proto.go
index c5109dcdb..4923a714b 100644
--- a/pkg/runtime/ast/proto.go
+++ b/pkg/runtime/ast/proto.go
@@ -22,15 +22,29 @@ import (
)
const (
- _ SQLType = iota
- Squery // QUERY
- Sdelete // DELETE
- Supdate // UPDATE
- Sinsert // INSERT
- Sreplace // REPLACE
- Struncate // TRUNCATE
- SdropTable // DROP TABLE
- SalterTable // ALTER TABLE
+ _ SQLType = iota
+ SQLTypeSelect // SELECT
+ SQLTypeDelete // DELETE
+ SQLTypeUpdate // UPDATE
+ SQLTypeInsert // INSERT
+ SQLTypeInsertSelect // INSERT SELECT
+ SQLTypeReplace // REPLACE
+ SQLTypeTruncate // TRUNCATE
+ SQLTypeDropTable // DROP TABLE
+ SQLTypeAlterTable // ALTER TABLE
+ SQLTypeDropIndex // DROP INDEX
+ SQLTypeShowDatabases // SHOW DATABASES
+ SQLTypeShowTables // SHOW TABLES
+ SQLTypeShowOpenTables // SHOW OPEN TABLES
+ SQLTypeShowIndex // SHOW INDEX
+ SQLTypeShowColumns // SHOW COLUMNS
+ SQLTypeShowCreate // SHOW CREATE
+ SQLTypeShowVariables // SHOW VARIABLES
+ SQLTypeShowTopology // SHOW TOPOLOGY
+ SQLTypeDescribe // DESCRIBE
+ SQLTypeUnion // UNION
+ SQLTypeDropTrigger // DROP TRIGGER
+ SQLTypeCreateIndex // CREATE INDEX
)
type RestoreFlag uint32
@@ -45,14 +59,27 @@ type Restorer interface {
}
var _sqlTypeNames = [...]string{
- Squery: "QUERY",
- Sdelete: "DELETE",
- Supdate: "UPDATE",
- Sinsert: "INSERT",
- Sreplace: "REPLACE",
- Struncate: "TRUNCATE",
- SdropTable: "DROP TABLE",
- SalterTable: "ALTER TABLE",
+ SQLTypeSelect: "SELECT",
+ SQLTypeDelete: "DELETE",
+ SQLTypeUpdate: "UPDATE",
+ SQLTypeInsert: "INSERT",
+ SQLTypeInsertSelect: "INSERT SELECT",
+ SQLTypeReplace: "REPLACE",
+ SQLTypeTruncate: "TRUNCATE",
+ SQLTypeDropTable: "DROP TABLE",
+ SQLTypeAlterTable: "ALTER TABLE",
+ SQLTypeDropIndex: "DROP INDEX",
+ SQLTypeShowDatabases: "SHOW DATABASES",
+ SQLTypeShowTables: "SHOW TABLES",
+ SQLTypeShowOpenTables: "SHOW OPEN TABLES",
+ SQLTypeShowIndex: "SHOW INDEX",
+ SQLTypeShowColumns: "SHOW COLUMNS",
+ SQLTypeShowCreate: "SHOW CREATE",
+ SQLTypeShowVariables: "SHOW VARIABLES",
+ SQLTypeDescribe: "DESCRIBE",
+ SQLTypeUnion: "UNION",
+ SQLTypeDropTrigger: "DROP TRIGGER",
+ SQLTypeCreateIndex: "CREATE INDEX",
}
// SQLType represents the type of SQL.
diff --git a/pkg/runtime/ast/select.go b/pkg/runtime/ast/select.go
index af0ebff1a..4366bffff 100644
--- a/pkg/runtime/ast/select.go
+++ b/pkg/runtime/ast/select.go
@@ -252,7 +252,7 @@ func (ss *SelectStatement) Validate() error {
}
func (ss *SelectStatement) Mode() SQLType {
- return Squery
+ return SQLTypeSelect
}
func (ss *SelectStatement) CntParams() int {
diff --git a/pkg/runtime/ast/show.go b/pkg/runtime/ast/show.go
index 85b3f33e1..fdde805a0 100644
--- a/pkg/runtime/ast/show.go
+++ b/pkg/runtime/ast/show.go
@@ -28,12 +28,20 @@ import (
var (
_ Statement = (*ShowTables)(nil)
+ _ Statement = (*ShowOpenTables)(nil)
_ Statement = (*ShowCreate)(nil)
_ Statement = (*ShowDatabases)(nil)
_ Statement = (*ShowColumns)(nil)
_ Statement = (*ShowIndex)(nil)
+ _ Statement = (*ShowTopology)(nil)
)
+type FromTable string
+
+func (f FromTable) String() string {
+ return string(f)
+}
+
type baseShow struct {
filter interface{} // ExpressionNode or string
}
@@ -41,11 +49,16 @@ type baseShow struct {
func (bs *baseShow) Restore(flag RestoreFlag, sb *strings.Builder, args *[]int) error {
switch val := bs.filter.(type) {
case string:
- sb.WriteString(" LIKE ")
- sb.WriteByte('\'')
+ sb.WriteString(" IN ")
+ sb.WriteByte('`')
sb.WriteString(val)
- sb.WriteByte('\'')
+ sb.WriteByte('`')
return nil
+ case FromTable:
+ sb.WriteString(val.String())
+ return nil
+ case PredicateNode:
+ return val.Restore(flag, sb, nil)
case ExpressionNode:
sb.WriteString(" WHERE ")
return val.Restore(flag, sb, args)
@@ -58,6 +71,7 @@ func (bs *baseShow) Like() (string, bool) {
v, ok := bs.filter.(string)
return v, ok
}
+
func (bs *baseShow) Where() (ExpressionNode, bool) {
v, ok := bs.filter.(ExpressionNode)
return v, ok
@@ -67,14 +81,14 @@ func (bs *baseShow) CntParams() int {
return 0
}
-func (bs *baseShow) Mode() SQLType {
- return Squery
-}
-
type ShowDatabases struct {
*baseShow
}
+func (s ShowDatabases) Mode() SQLType {
+ return SQLTypeShowDatabases
+}
+
func (s ShowDatabases) Restore(flag RestoreFlag, sb *strings.Builder, args *[]int) error {
sb.WriteString("SHOW DATABASES")
if err := s.baseShow.Restore(flag, sb, args); err != nil {
@@ -91,6 +105,10 @@ type ShowTables struct {
*baseShow
}
+func (s ShowTables) Mode() SQLType {
+ return SQLTypeShowTables
+}
+
func (s ShowTables) Restore(flag RestoreFlag, sb *strings.Builder, args *[]int) error {
sb.WriteString("SHOW TABLES")
if err := s.baseShow.Restore(flag, sb, args); err != nil {
@@ -103,6 +121,42 @@ func (s ShowTables) Validate() error {
return nil
}
+type ShowTopology struct {
+ *baseShow
+}
+
+func (s ShowTopology) Mode() SQLType {
+ return SQLTypeShowTopology
+}
+
+func (s ShowTopology) Restore(flag RestoreFlag, sb *strings.Builder, args *[]int) error {
+ return s.baseShow.Restore(flag, sb, args)
+}
+
+func (s ShowTopology) Validate() error {
+ return nil
+}
+
+type ShowOpenTables struct {
+ *baseShow
+}
+
+func (s ShowOpenTables) Mode() SQLType {
+ return SQLTypeShowOpenTables
+}
+
+func (s ShowOpenTables) Restore(flag RestoreFlag, sb *strings.Builder, args *[]int) error {
+ sb.WriteString("SHOW OPEN TABLES")
+ if err := s.baseShow.Restore(flag, sb, args); err != nil {
+ return errors.WithStack(err)
+ }
+ return nil
+}
+
+func (s ShowOpenTables) Validate() error {
+ return nil
+}
+
const (
_ ShowCreateType = iota
ShowCreateTypeTable
@@ -139,6 +193,13 @@ type ShowCreate struct {
tgt string
}
+func (s *ShowCreate) ResetTable(table string) *ShowCreate {
+ ret := new(ShowCreate)
+ *ret = *s
+ ret.tgt = table
+ return ret
+}
+
func (s *ShowCreate) Restore(flag RestoreFlag, sb *strings.Builder, args *[]int) error {
sb.WriteString("SHOW CREATE ")
sb.WriteString(s.typ.String())
@@ -164,11 +225,11 @@ func (s *ShowCreate) CntParams() int {
}
func (s *ShowCreate) Mode() SQLType {
- return Squery
+ return SQLTypeShowCreate
}
type ShowIndex struct {
- tableName TableName
+ TableName TableName
where ExpressionNode
}
@@ -179,7 +240,7 @@ func (s *ShowIndex) Validate() error {
func (s *ShowIndex) Restore(flag RestoreFlag, sb *strings.Builder, args *[]int) error {
sb.WriteString("SHOW INDEXES FROM ")
- _ = s.tableName.Restore(flag, sb, args)
+ _ = s.TableName.Restore(flag, sb, args)
if where, ok := s.Where(); ok {
sb.WriteString(" WHERE ")
@@ -191,10 +252,6 @@ func (s *ShowIndex) Restore(flag RestoreFlag, sb *strings.Builder, args *[]int)
return nil
}
-func (s *ShowIndex) TableName() TableName {
- return s.tableName
-}
-
func (s *ShowIndex) Where() (ExpressionNode, bool) {
if s.where != nil {
return s.where, true
@@ -210,7 +267,7 @@ func (s *ShowIndex) CntParams() int {
}
func (s *ShowIndex) Mode() SQLType {
- return Squery
+ return SQLTypeShowIndex
}
type showColumnsFlag uint8
@@ -224,8 +281,9 @@ const (
type ShowColumns struct {
flag showColumnsFlag
- tableName TableName
+ TableName TableName
like sql.NullString
+ Column string
}
func (sh *ShowColumns) IsFull() bool {
@@ -247,7 +305,7 @@ func (sh *ShowColumns) Restore(flag RestoreFlag, sb *strings.Builder, args *[]in
}
sb.WriteString("COLUMNS FROM ")
- if err := sh.tableName.Restore(flag, sb, args); err != nil {
+ if err := sh.TableName.Restore(flag, sb, args); err != nil {
return errors.WithStack(err)
}
@@ -274,7 +332,7 @@ func (sh *ShowColumns) Validate() error {
}
func (sh *ShowColumns) Table() TableName {
- return sh.tableName
+ return sh.TableName
}
func (sh *ShowColumns) CntParams() int {
@@ -282,7 +340,7 @@ func (sh *ShowColumns) CntParams() int {
}
func (sh *ShowColumns) Mode() SQLType {
- return Squery
+ return SQLTypeShowColumns
}
func (sh *ShowColumns) Full() bool {
@@ -342,5 +400,5 @@ func (s *ShowVariables) CntParams() int {
}
func (s *ShowVariables) Mode() SQLType {
- return Squery
+ return SQLTypeShowVariables
}
diff --git a/pkg/runtime/ast/table_source.go b/pkg/runtime/ast/table_source.go
index 48ed4c121..e8d9351dd 100644
--- a/pkg/runtime/ast/table_source.go
+++ b/pkg/runtime/ast/table_source.go
@@ -37,9 +37,7 @@ type TableSourceNode struct {
func (t *TableSourceNode) ResetTableName(newTableName string) bool {
switch source := t.source.(type) {
case TableName:
- var (
- newSource = make(TableName, len(source))
- )
+ newSource := make(TableName, len(source))
copy(newSource, source)
newSource[len(newSource)-1] = newTableName
t.source = newSource
diff --git a/pkg/runtime/ast/trigger.go b/pkg/runtime/ast/trigger.go
new file mode 100644
index 000000000..31c5b40b1
--- /dev/null
+++ b/pkg/runtime/ast/trigger.go
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package ast
+
+import (
+ "strings"
+)
+
+var (
+ _ Statement = (*DropTriggerStatement)(nil)
+)
+
+type DropTriggerStatement struct {
+ IfExists bool
+ Table TableName
+}
+
+func (d DropTriggerStatement) CntParams() int {
+ return 0
+}
+
+func (d DropTriggerStatement) Mode() SQLType {
+ return SQLTypeDropTrigger
+}
+
+func (d DropTriggerStatement) Restore(flag RestoreFlag, sb *strings.Builder, args *[]int) error {
+ sb.WriteString("DROP TRIGGER ")
+ if d.IfExists {
+ sb.WriteString("IF EXISTS ")
+ }
+ return d.Table.Restore(flag, sb, args)
+}
+
+func (d DropTriggerStatement) Validate() error { return nil }
diff --git a/pkg/runtime/ast/truncate.go b/pkg/runtime/ast/truncate.go
index aa5ea87c9..86db2114c 100644
--- a/pkg/runtime/ast/truncate.go
+++ b/pkg/runtime/ast/truncate.go
@@ -25,9 +25,7 @@ import (
"github.com/pkg/errors"
)
-var (
- _ Statement = (*TruncateStatement)(nil)
-)
+var _ Statement = (*TruncateStatement)(nil)
// TruncateStatement represents mysql describe statement. see https://dev.mysql.com/doc/refman/8.0/en/truncate-table.html
type TruncateStatement struct {
@@ -64,5 +62,5 @@ func (stmt *TruncateStatement) CntParams() int {
}
func (stmt *TruncateStatement) Mode() SQLType {
- return Struncate
+ return SQLTypeTruncate
}
diff --git a/pkg/runtime/ast/union.go b/pkg/runtime/ast/union.go
index 57ce0cbc0..61a7b07f3 100644
--- a/pkg/runtime/ast/union.go
+++ b/pkg/runtime/ast/union.go
@@ -112,7 +112,7 @@ func (u *UnionSelectStatement) OrderBy() OrderByNode {
}
func (u *UnionSelectStatement) Mode() SQLType {
- return Squery
+ return SQLTypeUnion
}
func (u *UnionSelectStatement) First() *SelectStatement {
diff --git a/pkg/runtime/ast/update.go b/pkg/runtime/ast/update.go
index 93c855f3e..26a2038b7 100644
--- a/pkg/runtime/ast/update.go
+++ b/pkg/runtime/ast/update.go
@@ -143,5 +143,5 @@ func (u *UpdateStatement) CntParams() int {
}
func (u *UpdateStatement) Mode() SQLType {
- return Supdate
+ return SQLTypeUpdate
}
diff --git a/pkg/runtime/context/context.go b/pkg/runtime/context/context.go
index bd7b11284..67894172d 100644
--- a/pkg/runtime/context/context.go
+++ b/pkg/runtime/context/context.go
@@ -23,7 +23,7 @@ import (
import (
"github.com/arana-db/arana/pkg/proto"
- "github.com/arana-db/arana/pkg/proto/rule"
+ "github.com/arana-db/arana/pkg/proto/hint"
)
const (
@@ -34,12 +34,13 @@ const (
type (
keyFlag struct{}
- keyRule struct{}
keySequence struct{}
keySql struct{}
keyNodeLabel struct{}
keySchema struct{}
keyDefaultDBGroup struct{}
+ keyTenant struct{}
+ keyHints struct{}
)
type cFlag uint8
@@ -60,14 +61,9 @@ func WithSQL(ctx context.Context, sql string) context.Context {
return context.WithValue(ctx, keySql{}, sql)
}
-// WithDBGroup binds the default db.
-func WithDBGroup(ctx context.Context, group string) context.Context {
- return context.WithValue(ctx, keyDefaultDBGroup{}, group)
-}
-
-// WithRule binds a rule.
-func WithRule(ctx context.Context, ru *rule.Rule) context.Context {
- return context.WithValue(ctx, keyRule{}, ru)
+// WithTenant binds the tenant.
+func WithTenant(ctx context.Context, tenant string) context.Context {
+ return context.WithValue(ctx, keyTenant{}, tenant)
}
func WithSchema(ctx context.Context, data string) context.Context {
@@ -89,6 +85,11 @@ func WithRead(ctx context.Context) context.Context {
return context.WithValue(ctx, keyFlag{}, _flagRead|getFlag(ctx))
}
+// WithHints binds the hints.
+func WithHints(ctx context.Context, hints []*hint.Hint) context.Context {
+ return context.WithValue(ctx, keyHints{}, hints)
+}
+
// Sequencer extracts the sequencer.
func Sequencer(ctx context.Context) proto.Sequencer {
s, ok := ctx.Value(keySequence{}).(proto.Sequencer)
@@ -98,24 +99,15 @@ func Sequencer(ctx context.Context) proto.Sequencer {
return s
}
-// DBGroup extracts the db.
-func DBGroup(ctx context.Context) string {
- db, ok := ctx.Value(keyDefaultDBGroup{}).(string)
+// Tenant extracts the tenant.
+func Tenant(ctx context.Context) string {
+ db, ok := ctx.Value(keyTenant{}).(string)
if !ok {
return ""
}
return db
}
-// Rule extracts the rule.
-func Rule(ctx context.Context) *rule.Rule {
- ru, ok := ctx.Value(keyRule{}).(*rule.Rule)
- if !ok {
- return nil
- }
- return ru
-}
-
// IsRead returns true if this is a read operation
func IsRead(ctx context.Context) bool {
return hasFlag(ctx, _flagRead)
@@ -154,6 +146,15 @@ func NodeLabel(ctx context.Context) string {
return ""
}
+// Hints extracts the hints.
+func Hints(ctx context.Context) []*hint.Hint {
+ hints, ok := ctx.Value(keyHints{}).([]*hint.Hint)
+ if !ok {
+ return nil
+ }
+ return hints
+}
+
func hasFlag(ctx context.Context, flag cFlag) bool {
return getFlag(ctx)&flag != 0
}
diff --git a/pkg/runtime/function/function_test.go b/pkg/runtime/function/function_test.go
index 64ff195da..b17326a98 100644
--- a/pkg/runtime/function/function_test.go
+++ b/pkg/runtime/function/function_test.go
@@ -104,7 +104,7 @@ func BenchmarkEval(b *testing.B) {
}
func mustGetMathAtom() *ast.MathExpressionAtom {
- stmt, err := ast.Parse("select * from t where a = 1 + if(?,1,0)")
+ _, stmt, err := ast.Parse("select * from t where a = 1 + if(?,1,0)")
if err != nil {
panic(err.Error())
}
diff --git a/pkg/runtime/function/vm.go b/pkg/runtime/function/vm.go
index 847a401e4..7fecaf34d 100644
--- a/pkg/runtime/function/vm.go
+++ b/pkg/runtime/function/vm.go
@@ -52,9 +52,7 @@ var scripts embed.FS
var _decimalRegex = regexp.MustCompile(`^(?P[+\-])?(?P[0-9])+(?P\.[0-9]+)$`)
-var (
- freeList = make(chan *VM, 16)
-)
+var freeList = make(chan *VM, 16)
const FuncUnary = "__unary"
diff --git a/pkg/runtime/function/vm_test.go b/pkg/runtime/function/vm_test.go
index 3b41e65e4..ff9388e9f 100644
--- a/pkg/runtime/function/vm_test.go
+++ b/pkg/runtime/function/vm_test.go
@@ -58,7 +58,6 @@ func TestScripts_Time(t *testing.T) {
v, err = testFunc("DATE_FORMAT", now, "%Y-%m-%d")
assert.NoError(t, err)
t.Log("DATE_FORMAT:", v)
-
}
func TestTime_Month(t *testing.T) {
diff --git a/pkg/runtime/namespace/namespace.go b/pkg/runtime/namespace/namespace.go
index fe0b08aaa..ceca48015 100644
--- a/pkg/runtime/namespace/namespace.go
+++ b/pkg/runtime/namespace/namespace.go
@@ -76,8 +76,7 @@ type (
name string // the name of Namespace
- rule atomic.Value // *rule.Rule
- optimizer proto.Optimizer
+ rule atomic.Value // *rule.Rule
// datasource map, eg: employee_0001 -> [mysql-a,mysql-b,mysql-c], ... employee_0007 -> [mysql-x,mysql-y,mysql-z]
dss atomic.Value // map[string][]proto.DB
@@ -91,12 +90,11 @@ type (
)
// New creates a Namespace.
-func New(name string, optimizer proto.Optimizer, commands ...Command) *Namespace {
+func New(name string, commands ...Command) *Namespace {
ns := &Namespace{
- name: name,
- optimizer: optimizer,
- cmds: make(chan Command, 1),
- done: make(chan struct{}),
+ name: name,
+ cmds: make(chan Command, 1),
+ done: make(chan struct{}),
}
ns.dss.Store(make(map[string][]proto.DB)) // init empty map
ns.rule.Store(&rule.Rule{}) // init empty rule
@@ -153,7 +151,6 @@ func (ns *Namespace) DB(ctx context.Context, group string) proto.DB {
for _, db := range exist {
wrList = append(wrList, int(db.Weight().R))
}
-
} else if rcontext.IsWrite(ctx) {
for _, db := range exist {
wrList = append(wrList, int(db.Weight().W))
@@ -166,9 +163,54 @@ func (ns *Namespace) DB(ctx context.Context, group string) proto.DB {
return exist[target]
}
-// Optimizer returns the optimizer.
-func (ns *Namespace) Optimizer() proto.Optimizer {
- return ns.optimizer
+// DBMaster returns a master DB, returns nil if nothing selected.
+func (ns *Namespace) DBMaster(_ context.Context, group string) proto.DB {
+ // use weight manager to select datasource
+ dss := ns.dss.Load().(map[string][]proto.DB)
+ exist, ok := dss[group]
+ if !ok {
+ return nil
+ }
+ // master weight w>0 && r>0
+ for _, db := range exist {
+ if db.Weight().W > 0 && db.Weight().R > 0 {
+ return db
+ }
+ }
+ return nil
+}
+
+// DBSlave returns a slave DB, returns nil if nothing selected.
+func (ns *Namespace) DBSlave(_ context.Context, group string) proto.DB {
+ // use weight manager to select datasource
+ dss := ns.dss.Load().(map[string][]proto.DB)
+ exist, ok := dss[group]
+ if !ok {
+ return nil
+ }
+ var (
+ target = 0
+ wrList = make([]int, 0, len(exist))
+ readDBList = make([]proto.DB, 0, len(exist))
+ )
+ // slave weight w==0 && r>=0
+ for _, db := range exist {
+ if db.Weight().W != 0 {
+ continue
+ }
+ // r==0 has high priority
+ if db.Weight().R == 0 {
+ return db
+ }
+ if db.Weight().R > 0 {
+ wrList = append(wrList, int(db.Weight().R))
+ readDBList = append(readDBList, db)
+ }
+ }
+ if len(wrList) != 0 {
+ target = selector.NewWeightRandomSelector(wrList).GetDataSourceNo()
+ }
+ return readDBList[target]
}
// Rule returns the sharding rule.
diff --git a/pkg/runtime/namespace/namespace_test.go b/pkg/runtime/namespace/namespace_test.go
index d579c5ecd..5fb72f347 100644
--- a/pkg/runtime/namespace/namespace_test.go
+++ b/pkg/runtime/namespace/namespace_test.go
@@ -44,11 +44,7 @@ func TestRegister(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
- opt := testdata.NewMockOptimizer(ctrl)
-
- const (
- name = "employees"
- )
+ const name = "employees"
getDB := func(i int) proto.DB {
db := testdata.NewMockDB(ctrl)
@@ -58,7 +54,7 @@ func TestRegister(t *testing.T) {
return db
}
- err := Register(New(name, opt, UpsertDB(getGroup(0), getDB(1))))
+ err := Register(New(name, UpsertDB(getGroup(0), getDB(1))))
assert.NoError(t, err, "should register namespace ok")
defer func() {
@@ -89,8 +85,6 @@ func TestGetDBByWeight(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
- opt := testdata.NewMockOptimizer(ctrl)
-
const (
name = "account"
)
@@ -104,7 +98,7 @@ func TestGetDBByWeight(t *testing.T) {
}
// when doing read operation, db 3 is the max
// when doing write operation, db 2 is the max
- err := Register(New(name, opt,
+ err := Register(New(name,
UpsertDB(getGroup(0), getDB(1, 9, 1)),
UpsertDB(getGroup(0), getDB(2, 10, 5)),
UpsertDB(getGroup(0), getDB(3, 3, 10)),
diff --git a/pkg/runtime/optimize/dal/show_columns.go b/pkg/runtime/optimize/dal/show_columns.go
new file mode 100644
index 000000000..ec39d30f2
--- /dev/null
+++ b/pkg/runtime/optimize/dal/show_columns.go
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package dal
+
+import (
+ "context"
+)
+
+import (
+ "github.com/arana-db/arana/pkg/proto"
+ "github.com/arana-db/arana/pkg/runtime/ast"
+ "github.com/arana-db/arana/pkg/runtime/optimize"
+ "github.com/arana-db/arana/pkg/runtime/plan/dal"
+)
+
+func init() {
+ optimize.Register(ast.SQLTypeShowColumns, optimizeShowColumns)
+}
+
+func optimizeShowColumns(_ context.Context, o *optimize.Optimizer) (proto.Plan, error) {
+ stmt := o.Stmt.(*ast.ShowColumns)
+
+ vts := o.Rule.VTables()
+ vtName := []string(stmt.TableName)[0]
+ ret := &dal.ShowColumnsPlan{Stmt: stmt}
+ ret.BindArgs(o.Args)
+
+ if vTable, ok := vts[vtName]; ok {
+ _, tblName, _ := vTable.Topology().Smallest()
+ ret.Table = tblName
+ }
+
+ return ret, nil
+}
diff --git a/pkg/runtime/optimize/dal/show_create.go b/pkg/runtime/optimize/dal/show_create.go
new file mode 100644
index 000000000..c926187a7
--- /dev/null
+++ b/pkg/runtime/optimize/dal/show_create.go
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package dal
+
+import (
+ "context"
+)
+
+import (
+ "github.com/pkg/errors"
+)
+
+import (
+ "github.com/arana-db/arana/pkg/proto"
+ "github.com/arana-db/arana/pkg/runtime/ast"
+ "github.com/arana-db/arana/pkg/runtime/optimize"
+ "github.com/arana-db/arana/pkg/runtime/plan/dal"
+)
+
+func init() {
+ optimize.Register(ast.SQLTypeShowCreate, optimizeShowCreate)
+}
+
+func optimizeShowCreate(_ context.Context, o *optimize.Optimizer) (proto.Plan, error) {
+ stmt := o.Stmt.(*ast.ShowCreate)
+
+ if stmt.Type() != ast.ShowCreateTypeTable {
+ return nil, errors.Errorf("not support SHOW CREATE %s", stmt.Type().String())
+ }
+
+ var (
+ ret = dal.NewShowCreatePlan(stmt)
+ table = stmt.Target()
+ )
+ ret.BindArgs(o.Args)
+
+ if vt, ok := o.Rule.VTable(table); ok {
+ // sharding
+ topology := vt.Topology()
+ if d, t, ok := topology.Render(0, 0); ok {
+ ret.Database = d
+ ret.Table = t
+ } else {
+ return nil, errors.Errorf("failed to render table:%s ", table)
+ }
+ } else {
+ ret.Table = table
+ }
+
+ return ret, nil
+}
diff --git a/pkg/runtime/optimize/dal/show_databases.go b/pkg/runtime/optimize/dal/show_databases.go
new file mode 100644
index 000000000..cc59628da
--- /dev/null
+++ b/pkg/runtime/optimize/dal/show_databases.go
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package dal
+
+import (
+ "context"
+)
+
+import (
+ "github.com/arana-db/arana/pkg/proto"
+ "github.com/arana-db/arana/pkg/runtime/ast"
+ "github.com/arana-db/arana/pkg/runtime/optimize"
+ "github.com/arana-db/arana/pkg/runtime/plan/dal"
+)
+
+func init() {
+ optimize.Register(ast.SQLTypeShowDatabases, optimizeShowDatabases)
+}
+
+func optimizeShowDatabases(_ context.Context, o *optimize.Optimizer) (proto.Plan, error) {
+ ret := &dal.ShowDatabasesPlan{Stmt: o.Stmt.(*ast.ShowDatabases)}
+ ret.BindArgs(o.Args)
+ return ret, nil
+}
diff --git a/pkg/runtime/optimize/dal/show_index.go b/pkg/runtime/optimize/dal/show_index.go
new file mode 100644
index 000000000..e75d93d1a
--- /dev/null
+++ b/pkg/runtime/optimize/dal/show_index.go
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package dal
+
+import (
+ "context"
+)
+
+import (
+ "github.com/arana-db/arana/pkg/proto"
+ "github.com/arana-db/arana/pkg/proto/rule"
+ "github.com/arana-db/arana/pkg/runtime/ast"
+ "github.com/arana-db/arana/pkg/runtime/optimize"
+ "github.com/arana-db/arana/pkg/runtime/plan/dal"
+)
+
+func init() {
+ optimize.Register(ast.SQLTypeShowIndex, optimizeShowIndex)
+}
+
+func optimizeShowIndex(_ context.Context, o *optimize.Optimizer) (proto.Plan, error) {
+ stmt := o.Stmt.(*ast.ShowIndex)
+
+ ret := &dal.ShowIndexPlan{Stmt: stmt}
+ ret.BindArgs(o.Args)
+
+ vt, ok := o.Rule.VTable(stmt.TableName.Suffix())
+ if !ok {
+ return ret, nil
+ }
+
+ shards := rule.DatabaseTables{}
+
+ topology := vt.Topology()
+ if d, t, ok := topology.Render(0, 0); ok {
+ shards[d] = append(shards[d], t)
+ }
+ ret.Shards = shards
+ return ret, nil
+}
diff --git a/pkg/runtime/optimize/dal/show_open_tables.go b/pkg/runtime/optimize/dal/show_open_tables.go
new file mode 100644
index 000000000..e60e1c226
--- /dev/null
+++ b/pkg/runtime/optimize/dal/show_open_tables.go
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package dal
+
+import (
+ "context"
+)
+
+import (
+ "github.com/arana-db/arana/pkg/proto"
+ "github.com/arana-db/arana/pkg/runtime/ast"
+ rcontext "github.com/arana-db/arana/pkg/runtime/context"
+ "github.com/arana-db/arana/pkg/runtime/namespace"
+ "github.com/arana-db/arana/pkg/runtime/optimize"
+ "github.com/arana-db/arana/pkg/runtime/plan/dal"
+ "github.com/arana-db/arana/pkg/runtime/plan/dml"
+ "github.com/arana-db/arana/pkg/security"
+ "github.com/arana-db/arana/pkg/transformer"
+)
+
+func init() {
+ optimize.Register(ast.SQLTypeShowOpenTables, optimizeShowOpenTables)
+}
+
+func optimizeShowOpenTables(ctx context.Context, o *optimize.Optimizer) (proto.Plan, error) {
+ var invertedIndex map[string]string
+ for logicalTable, v := range o.Rule.VTables() {
+ t := v.Topology()
+ t.Each(func(x, y int) bool {
+ if _, phyTable, ok := t.Render(x, y); ok {
+ if invertedIndex == nil {
+ invertedIndex = make(map[string]string)
+ }
+ invertedIndex[phyTable] = logicalTable
+ }
+ return true
+ })
+ }
+
+ stmt := o.Stmt.(*ast.ShowOpenTables)
+
+ clusters := security.DefaultTenantManager().GetClusters(rcontext.Tenant(ctx))
+ plans := make([]proto.Plan, 0, len(clusters))
+ for _, cluster := range clusters {
+ ns := namespace.Load(cluster)
+ // 配置里原子库 都需要执行一次
+ groups := ns.DBGroups()
+ for i := 0; i < len(groups); i++ {
+ ret := dal.NewShowOpenTablesPlan(stmt)
+ ret.BindArgs(o.Args)
+ ret.SetInvertedShards(invertedIndex)
+ ret.SetDatabase(groups[i])
+ plans = append(plans, ret)
+ }
+ }
+
+ unionPlan := &dml.UnionPlan{
+ Plans: plans,
+ }
+
+ aggregate := &dml.AggregatePlan{
+ Plan: unionPlan,
+ Combiner: transformer.NewCombinerManager(),
+ AggrLoader: transformer.LoadAggrs(nil),
+ }
+
+ return aggregate, nil
+}
diff --git a/pkg/runtime/optimize/dal/show_tables.go b/pkg/runtime/optimize/dal/show_tables.go
new file mode 100644
index 000000000..778a7f784
--- /dev/null
+++ b/pkg/runtime/optimize/dal/show_tables.go
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package dal
+
+import (
+ "context"
+)
+
+import (
+ "github.com/arana-db/arana/pkg/proto"
+ "github.com/arana-db/arana/pkg/runtime/ast"
+ "github.com/arana-db/arana/pkg/runtime/optimize"
+ "github.com/arana-db/arana/pkg/runtime/plan/dal"
+)
+
+func init() {
+ optimize.Register(ast.SQLTypeShowTables, optimizeShowTables)
+}
+
+func optimizeShowTables(_ context.Context, o *optimize.Optimizer) (proto.Plan, error) {
+ stmt := o.Stmt.(*ast.ShowTables)
+ var invertedIndex map[string]string
+ for logicalTable, v := range o.Rule.VTables() {
+ t := v.Topology()
+ t.Each(func(x, y int) bool {
+ if _, phyTable, ok := t.Render(x, y); ok {
+ if invertedIndex == nil {
+ invertedIndex = make(map[string]string)
+ }
+ invertedIndex[phyTable] = logicalTable
+ }
+ return true
+ })
+ }
+
+ ret := dal.NewShowTablesPlan(stmt)
+ ret.BindArgs(o.Args)
+ ret.SetInvertedShards(invertedIndex)
+ return ret, nil
+}
diff --git a/pkg/runtime/optimize/dal/show_topology.go b/pkg/runtime/optimize/dal/show_topology.go
new file mode 100644
index 000000000..a03d0d5d6
--- /dev/null
+++ b/pkg/runtime/optimize/dal/show_topology.go
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package dal
+
+import (
+ "context"
+)
+
+import (
+ "github.com/arana-db/arana/pkg/proto"
+ "github.com/arana-db/arana/pkg/runtime/ast"
+ "github.com/arana-db/arana/pkg/runtime/optimize"
+ "github.com/arana-db/arana/pkg/runtime/plan/dal"
+)
+
+func init() {
+ optimize.Register(ast.SQLTypeShowTopology, optimizeShowTopology)
+}
+
+func optimizeShowTopology(_ context.Context, o *optimize.Optimizer) (proto.Plan, error) {
+ rule := o.Rule
+ stmt := o.Stmt.(*ast.ShowTopology)
+ ret := dal.NewShowTopologyPlan(stmt)
+ ret.BindArgs(o.Args)
+ ret.SetRule(rule)
+ return ret, nil
+}
diff --git a/pkg/runtime/optimize/dal/show_variables.go b/pkg/runtime/optimize/dal/show_variables.go
new file mode 100644
index 000000000..5cd6aedbf
--- /dev/null
+++ b/pkg/runtime/optimize/dal/show_variables.go
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package dal
+
+import (
+ "context"
+)
+
+import (
+ "github.com/arana-db/arana/pkg/proto"
+ "github.com/arana-db/arana/pkg/runtime/ast"
+ "github.com/arana-db/arana/pkg/runtime/optimize"
+ "github.com/arana-db/arana/pkg/runtime/plan/dal"
+)
+
+func init() {
+ optimize.Register(ast.SQLTypeShowVariables, optimizeShowVariables)
+}
+
+func optimizeShowVariables(_ context.Context, o *optimize.Optimizer) (proto.Plan, error) {
+ ret := dal.NewShowVariablesPlan(o.Stmt.(*ast.ShowVariables))
+ ret.BindArgs(o.Args)
+ return ret, nil
+}
diff --git a/pkg/runtime/optimize/ddl/alter_table.go b/pkg/runtime/optimize/ddl/alter_table.go
new file mode 100644
index 000000000..9891e9fe7
--- /dev/null
+++ b/pkg/runtime/optimize/ddl/alter_table.go
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package ddl
+
+import (
+ "context"
+)
+
+import (
+ "github.com/arana-db/arana/pkg/proto"
+ "github.com/arana-db/arana/pkg/proto/rule"
+ "github.com/arana-db/arana/pkg/runtime/ast"
+ "github.com/arana-db/arana/pkg/runtime/optimize"
+ "github.com/arana-db/arana/pkg/runtime/plan/ddl"
+)
+
+func init() {
+ optimize.Register(ast.SQLTypeAlterTable, optimizeAlterTable)
+}
+
+func optimizeAlterTable(_ context.Context, o *optimize.Optimizer) (proto.Plan, error) {
+ var (
+ stmt = o.Stmt.(*ast.AlterTableStatement)
+ ret = ddl.NewAlterTablePlan(stmt)
+ table = stmt.Table
+ vt *rule.VTable
+ ok bool
+ )
+ ret.BindArgs(o.Args)
+
+ // non-sharding update
+ if vt, ok = o.Rule.VTable(table.Suffix()); !ok {
+ return ret, nil
+ }
+
+ //TODO alter table table or column to new name , should update sharding info
+
+ // exit if full-scan is disabled
+ if !vt.AllowFullScan() {
+ return nil, optimize.ErrDenyFullScan
+ }
+
+ // sharding
+ ret.Shards = vt.Topology().Enumerate()
+ return ret, nil
+}
diff --git a/pkg/runtime/optimize/ddl/create_index.go b/pkg/runtime/optimize/ddl/create_index.go
new file mode 100644
index 000000000..f29a17db8
--- /dev/null
+++ b/pkg/runtime/optimize/ddl/create_index.go
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package ddl
+
+import (
+ "context"
+)
+
+import (
+ "github.com/arana-db/arana/pkg/proto"
+ "github.com/arana-db/arana/pkg/runtime/ast"
+ "github.com/arana-db/arana/pkg/runtime/optimize"
+ "github.com/arana-db/arana/pkg/runtime/plan/ddl"
+)
+
+func init() {
+ optimize.Register(ast.SQLTypeCreateIndex, optimizeCreateIndex)
+}
+
+func optimizeCreateIndex(_ context.Context, o *optimize.Optimizer) (proto.Plan, error) {
+ stmt := o.Stmt.(*ast.CreateIndexStatement)
+ ret := ddl.NewCreateIndexPlan(stmt)
+ vt, ok := o.Rule.VTable(stmt.Table.Suffix())
+
+ // table shard
+ if !ok {
+ return ret, nil
+ }
+
+ ret.SetShard(vt.Topology().Enumerate())
+ return ret, nil
+}
diff --git a/pkg/runtime/optimize/ddl/drop_index.go b/pkg/runtime/optimize/ddl/drop_index.go
new file mode 100644
index 000000000..e5dd7d0ce
--- /dev/null
+++ b/pkg/runtime/optimize/ddl/drop_index.go
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package ddl
+
+import (
+ "context"
+)
+
+import (
+ "github.com/arana-db/arana/pkg/proto"
+ "github.com/arana-db/arana/pkg/runtime/ast"
+ "github.com/arana-db/arana/pkg/runtime/optimize"
+ "github.com/arana-db/arana/pkg/runtime/plan"
+ "github.com/arana-db/arana/pkg/runtime/plan/ddl"
+)
+
+func init() {
+ optimize.Register(ast.SQLTypeDropIndex, optimizeDropIndex)
+}
+
+func optimizeDropIndex(_ context.Context, o *optimize.Optimizer) (proto.Plan, error) {
+ stmt := o.Stmt.(*ast.DropIndexStatement)
+ //table shard
+
+ shard, err := o.ComputeShards(stmt.Table, nil, o.Args)
+ if err != nil {
+ return nil, err
+ }
+ if len(shard) == 0 {
+ return plan.Transparent(stmt, o.Args), nil
+ }
+
+ shardPlan := ddl.NewDropIndexPlan(stmt)
+ shardPlan.SetShard(shard)
+ shardPlan.BindArgs(o.Args)
+ return shardPlan, nil
+}
diff --git a/pkg/runtime/optimize/ddl/drop_table.go b/pkg/runtime/optimize/ddl/drop_table.go
new file mode 100644
index 000000000..1d179b8d4
--- /dev/null
+++ b/pkg/runtime/optimize/ddl/drop_table.go
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package ddl
+
+import (
+ "context"
+)
+
+import (
+ "github.com/arana-db/arana/pkg/proto"
+ "github.com/arana-db/arana/pkg/proto/rule"
+ "github.com/arana-db/arana/pkg/runtime/ast"
+ "github.com/arana-db/arana/pkg/runtime/optimize"
+ "github.com/arana-db/arana/pkg/runtime/plan"
+ "github.com/arana-db/arana/pkg/runtime/plan/ddl"
+ "github.com/arana-db/arana/pkg/runtime/plan/dml"
+)
+
+func init() {
+ optimize.Register(ast.SQLTypeDropTable, optimizeDropTable)
+}
+
+func optimizeDropTable(_ context.Context, o *optimize.Optimizer) (proto.Plan, error) {
+ stmt := o.Stmt.(*ast.DropTableStatement)
+ //table shard
+ var shards []rule.DatabaseTables
+ //tables not shard
+ noShardStmt := ast.NewDropTableStatement()
+ for _, table := range stmt.Tables {
+ shard, err := o.ComputeShards(*table, nil, o.Args)
+ if err != nil {
+ return nil, err
+ }
+ if shard == nil {
+ noShardStmt.Tables = append(noShardStmt.Tables, table)
+ continue
+ }
+ shards = append(shards, shard)
+ }
+
+ shardPlan := ddl.NewDropTablePlan(stmt)
+ shardPlan.BindArgs(o.Args)
+ shardPlan.SetShards(shards)
+
+ if len(noShardStmt.Tables) == 0 {
+ return shardPlan, nil
+ }
+
+ noShardPlan := plan.Transparent(noShardStmt, o.Args)
+
+ return &dml.UnionPlan{
+ Plans: []proto.Plan{
+ noShardPlan, shardPlan,
+ },
+ }, nil
+}
diff --git a/pkg/runtime/optimize/ddl/drop_trigger.go b/pkg/runtime/optimize/ddl/drop_trigger.go
new file mode 100644
index 000000000..74e1d1de4
--- /dev/null
+++ b/pkg/runtime/optimize/ddl/drop_trigger.go
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package ddl
+
+import (
+ "context"
+)
+
+import (
+ "github.com/arana-db/arana/pkg/proto"
+ "github.com/arana-db/arana/pkg/proto/rule"
+ "github.com/arana-db/arana/pkg/runtime/ast"
+ "github.com/arana-db/arana/pkg/runtime/optimize"
+ "github.com/arana-db/arana/pkg/runtime/plan/ddl"
+)
+
+func init() {
+ optimize.Register(ast.SQLTypeDropTrigger, optimizeTrigger)
+}
+
+func optimizeTrigger(_ context.Context, o *optimize.Optimizer) (proto.Plan, error) {
+ shards := rule.DatabaseTables{}
+ for _, table := range o.Rule.VTables() {
+ shards = table.Topology().Enumerate()
+ break
+ }
+
+ ret := &ddl.DropTriggerPlan{Stmt: o.Stmt.(*ast.DropTriggerStatement), Shards: shards}
+ ret.BindArgs(o.Args)
+ return ret, nil
+}
diff --git a/pkg/runtime/optimize/ddl/truncate.go b/pkg/runtime/optimize/ddl/truncate.go
new file mode 100644
index 000000000..6e3cbbc77
--- /dev/null
+++ b/pkg/runtime/optimize/ddl/truncate.go
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package ddl
+
+import (
+ "context"
+)
+
+import (
+ "github.com/pkg/errors"
+)
+
+import (
+ "github.com/arana-db/arana/pkg/proto"
+ "github.com/arana-db/arana/pkg/runtime/ast"
+ "github.com/arana-db/arana/pkg/runtime/optimize"
+ "github.com/arana-db/arana/pkg/runtime/plan"
+ "github.com/arana-db/arana/pkg/runtime/plan/ddl"
+)
+
+func init() {
+ optimize.Register(ast.SQLTypeTruncate, optimizeTruncate)
+}
+
+func optimizeTruncate(_ context.Context, o *optimize.Optimizer) (proto.Plan, error) {
+ stmt := o.Stmt.(*ast.TruncateStatement)
+ shards, err := o.ComputeShards(stmt.Table, nil, o.Args)
+ if err != nil {
+ return nil, errors.Wrap(err, "failed to optimize TRUNCATE statement")
+ }
+
+ if shards == nil {
+ return plan.Transparent(stmt, o.Args), nil
+ }
+
+ ret := ddl.NewTruncatePlan(stmt)
+ ret.BindArgs(o.Args)
+ ret.SetShards(shards)
+
+ return ret, nil
+}
diff --git a/pkg/runtime/optimize/dml/delete.go b/pkg/runtime/optimize/dml/delete.go
new file mode 100644
index 000000000..c9ee9c2ce
--- /dev/null
+++ b/pkg/runtime/optimize/dml/delete.go
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package dml
+
+import (
+ "context"
+)
+
+import (
+ "github.com/pkg/errors"
+)
+
+import (
+ "github.com/arana-db/arana/pkg/proto"
+ "github.com/arana-db/arana/pkg/runtime/ast"
+ "github.com/arana-db/arana/pkg/runtime/optimize"
+ "github.com/arana-db/arana/pkg/runtime/plan"
+ "github.com/arana-db/arana/pkg/runtime/plan/dml"
+)
+
+func init() {
+ optimize.Register(ast.SQLTypeDelete, optimizeDelete)
+}
+
+func optimizeDelete(ctx context.Context, o *optimize.Optimizer) (proto.Plan, error) {
+ var (
+ stmt = o.Stmt.(*ast.DeleteStatement)
+ )
+
+ shards, err := o.ComputeShards(stmt.Table, stmt.Where, o.Args)
+ if err != nil {
+ return nil, errors.Wrap(err, "failed to optimize DELETE statement")
+ }
+
+ // TODO: delete from a child sharding-table directly
+
+ if shards == nil {
+ return plan.Transparent(stmt, o.Args), nil
+ }
+
+ ret := dml.NewSimpleDeletePlan(stmt)
+ ret.BindArgs(o.Args)
+ ret.SetShards(shards)
+
+ return ret, nil
+}
diff --git a/pkg/runtime/optimize/dml/insert.go b/pkg/runtime/optimize/dml/insert.go
new file mode 100644
index 000000000..3684dcd67
--- /dev/null
+++ b/pkg/runtime/optimize/dml/insert.go
@@ -0,0 +1,215 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package dml
+
+import (
+ "context"
+)
+
+import (
+ "github.com/pkg/errors"
+)
+
+import (
+ "github.com/arana-db/arana/pkg/proto"
+ "github.com/arana-db/arana/pkg/proto/rule"
+ "github.com/arana-db/arana/pkg/runtime/ast"
+ "github.com/arana-db/arana/pkg/runtime/cmp"
+ rcontext "github.com/arana-db/arana/pkg/runtime/context"
+ "github.com/arana-db/arana/pkg/runtime/optimize"
+ "github.com/arana-db/arana/pkg/runtime/plan/dml"
+)
+
+func init() {
+ optimize.Register(ast.SQLTypeInsert, optimizeInsert)
+ optimize.Register(ast.SQLTypeInsertSelect, optimizeInsertSelect)
+}
+
+func optimizeInsert(ctx context.Context, o *optimize.Optimizer) (proto.Plan, error) {
+ ret := dml.NewSimpleInsertPlan()
+ ret.BindArgs(o.Args)
+
+ var (
+ stmt = o.Stmt.(*ast.InsertStatement)
+ vt *rule.VTable
+ ok bool
+ )
+
+ if vt, ok = o.Rule.VTable(stmt.Table().Suffix()); !ok { // insert into non-sharding table
+ ret.Put("", stmt)
+ return ret, nil
+ }
+
+ // TODO: handle multiple shard keys.
+
+ bingo := -1
+ // check existing shard columns
+ for i, col := range stmt.Columns() {
+ if _, _, ok = vt.GetShardMetadata(col); ok {
+ bingo = i
+ break
+ }
+ }
+
+ if bingo < 0 {
+ return nil, errors.Wrap(optimize.ErrNoShardKeyFound, "failed to insert")
+ }
+
+ //check on duplicated key update
+ for _, upd := range stmt.DuplicatedUpdates() {
+ if upd.Column.Suffix() == stmt.Columns()[bingo] {
+ return nil, errors.New("do not support update sharding key")
+ }
+ }
+
+ var (
+ sharder = (*optimize.Sharder)(o.Rule)
+ left = ast.ColumnNameExpressionAtom(make([]string, 1))
+ filter = &ast.PredicateExpressionNode{
+ P: &ast.BinaryComparisonPredicateNode{
+ Left: &ast.AtomPredicateNode{
+ A: left,
+ },
+ Op: cmp.Ceq,
+ },
+ }
+ slots = make(map[string]map[string][]int) // (db,table,valuesIndex)
+ )
+
+ // reset filter
+ resetFilter := func(column string, value ast.ExpressionNode) {
+ left[0] = column
+ filter.P.(*ast.BinaryComparisonPredicateNode).Right = value.(*ast.PredicateExpressionNode).P
+ }
+
+ for i, values := range stmt.Values() {
+ value := values[bingo]
+ resetFilter(stmt.Columns()[bingo], value)
+
+ shards, _, err := sharder.Shard(stmt.Table(), filter, o.Args...)
+
+ if err != nil {
+ return nil, errors.WithStack(err)
+ }
+
+ if shards.Len() != 1 {
+ return nil, errors.Wrap(optimize.ErrNoShardKeyFound, "failed to insert")
+ }
+
+ var (
+ db string
+ table string
+ )
+
+ for k, v := range shards {
+ db = k
+ table = v[0]
+ break
+ }
+
+ if _, ok = slots[db]; !ok {
+ slots[db] = make(map[string][]int)
+ }
+ slots[db][table] = append(slots[db][table], i)
+ }
+
+ _, tb0, _ := vt.Topology().Smallest()
+
+ for db, slot := range slots {
+ for table, indexes := range slot {
+ // clone insert stmt without values
+ newborn := ast.NewInsertStatement(ast.TableName{table}, stmt.Columns())
+ newborn.SetFlag(stmt.Flag())
+ newborn.SetDuplicatedUpdates(stmt.DuplicatedUpdates())
+
+ // collect values with same table
+ values := make([][]ast.ExpressionNode, 0, len(indexes))
+ for _, i := range indexes {
+ values = append(values, stmt.Values()[i])
+ }
+ newborn.SetValues(values)
+
+ rewriteInsertStatement(ctx, newborn, tb0)
+ ret.Put(db, newborn)
+ }
+ }
+
+ return ret, nil
+}
+
+func optimizeInsertSelect(_ context.Context, o *optimize.Optimizer) (proto.Plan, error) {
+ stmt := o.Stmt.(*ast.InsertSelectStatement)
+
+ ret := dml.NewInsertSelectPlan()
+
+ ret.BindArgs(o.Args)
+
+ if _, ok := o.Rule.VTable(stmt.Table().Suffix()); !ok { // insert into non-sharding table
+ ret.Batch[""] = stmt
+ return ret, nil
+ }
+
+ // TODO: handle shard keys.
+
+ return nil, errors.New("not support insert-select into sharding table")
+}
+
+func rewriteInsertStatement(ctx context.Context, stmt *ast.InsertStatement, tb string) error {
+ metadatas, err := proto.LoadSchemaLoader().Load(ctx, rcontext.Schema(ctx), []string{tb})
+ if err != nil {
+ return errors.WithStack(err)
+ }
+ metadata := metadatas[tb]
+ if metadata == nil || len(metadata.ColumnNames) == 0 {
+ return errors.Errorf("optimize: cannot get metadata of `%s`.`%s`", rcontext.Schema(ctx), tb)
+ }
+
+ if len(metadata.ColumnNames) == len(stmt.Columns()) {
+ // User had explicitly specified every value
+ return nil
+ }
+ columnsMetadata := metadata.Columns
+
+ for _, colName := range stmt.Columns() {
+ if columnsMetadata[colName].PrimaryKey && columnsMetadata[colName].Generated {
+ // User had explicitly specified auto-generated primary key column
+ return nil
+ }
+ }
+
+ pkColName := ""
+ for name, column := range columnsMetadata {
+ if column.PrimaryKey && column.Generated {
+ pkColName = name
+ break
+ }
+ }
+ if len(pkColName) < 1 {
+ // There's no auto-generated primary key column
+ return nil
+ }
+
+ // TODO rewrite columns and add distributed primary key
+ //stmt.SetColumns(append(stmt.Columns(), pkColName))
+ // append value of distributed primary key
+ //newValues := stmt.Values()
+ //for _, newValue := range newValues {
+ // newValue = append(newValue, )
+ //}
+ return nil
+}
diff --git a/pkg/runtime/optimize/dml/select.go b/pkg/runtime/optimize/dml/select.go
new file mode 100644
index 000000000..d9cf57eb9
--- /dev/null
+++ b/pkg/runtime/optimize/dml/select.go
@@ -0,0 +1,415 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package dml
+
+import (
+ "context"
+ "strings"
+)
+
+import (
+ "github.com/pkg/errors"
+)
+
+import (
+ "github.com/arana-db/arana/pkg/dataset"
+ "github.com/arana-db/arana/pkg/merge/aggregator"
+ "github.com/arana-db/arana/pkg/proto"
+ "github.com/arana-db/arana/pkg/proto/rule"
+ "github.com/arana-db/arana/pkg/runtime/ast"
+ rcontext "github.com/arana-db/arana/pkg/runtime/context"
+ "github.com/arana-db/arana/pkg/runtime/optimize"
+ "github.com/arana-db/arana/pkg/runtime/plan/dml"
+ "github.com/arana-db/arana/pkg/transformer"
+ "github.com/arana-db/arana/pkg/util/log"
+)
+
+const (
+ _bypass uint32 = 1 << iota
+ _supported
+)
+
+func init() {
+ optimize.Register(ast.SQLTypeSelect, optimizeSelect)
+}
+
+func optimizeSelect(ctx context.Context, o *optimize.Optimizer) (proto.Plan, error) {
+ stmt := o.Stmt.(*ast.SelectStatement)
+
+ // overwrite stmt limit x offset y. eg `select * from student offset 100 limit 5` will be
+ // `select * from student offset 0 limit 100+5`
+ originOffset, newLimit := overwriteLimit(stmt, &o.Args)
+ if stmt.HasJoin() {
+ return optimizeJoin(o, stmt)
+ }
+ flag := getSelectFlag(o.Rule, stmt)
+ if flag&_supported == 0 {
+ return nil, errors.Errorf("unsupported sql: %s", rcontext.SQL(ctx))
+ }
+
+ if flag&_bypass != 0 {
+ if len(stmt.From) > 0 {
+ err := rewriteSelectStatement(ctx, stmt, stmt.From[0].TableName().Suffix())
+ if err != nil {
+ return nil, err
+ }
+ }
+ ret := &dml.SimpleQueryPlan{Stmt: stmt}
+ ret.BindArgs(o.Args)
+ return ret, nil
+ }
+
+ var (
+ shards rule.DatabaseTables
+ fullScan bool
+ err error
+ vt = o.Rule.MustVTable(stmt.From[0].TableName().Suffix())
+ )
+
+ if shards, fullScan, err = (*optimize.Sharder)(o.Rule).Shard(stmt.From[0].TableName(), stmt.Where, o.Args...); err != nil {
+ return nil, errors.Wrap(err, "calculate shards failed")
+ }
+
+ log.Debugf("compute shards: result=%s, isFullScan=%v", shards, fullScan)
+
+ // return error if full-scan is disabled
+ if fullScan && !vt.AllowFullScan() {
+ return nil, errors.WithStack(optimize.ErrDenyFullScan)
+ }
+
+ toSingle := func(db, tbl string) (proto.Plan, error) {
+ _, tb0, _ := vt.Topology().Smallest()
+ if err := rewriteSelectStatement(ctx, stmt, tb0); err != nil {
+ return nil, err
+ }
+ ret := &dml.SimpleQueryPlan{
+ Stmt: stmt,
+ Database: db,
+ Tables: []string{tbl},
+ }
+ ret.BindArgs(o.Args)
+
+ return ret, nil
+ }
+
+ // Go through first table if no shards matched.
+ // For example:
+ // SELECT ... FROM xxx WHERE a > 8 and a < 4
+ if shards.IsEmpty() {
+ var (
+ db0, tbl0 string
+ ok bool
+ )
+ if db0, tbl0, ok = vt.Topology().Render(0, 0); !ok {
+ return nil, errors.Errorf("cannot compute minimal topology from '%s'", stmt.From[0].TableName().Suffix())
+ }
+
+ return toSingle(db0, tbl0)
+ }
+
+ // Handle single shard
+ if shards.Len() == 1 {
+ var db, tbl string
+ for k, v := range shards {
+ db = k
+ tbl = v[0]
+ }
+ return toSingle(db, tbl)
+ }
+
+ // Handle multiple shards
+
+ if shards.IsFullScan() { // expand all shards if all shards matched
+ shards = vt.Topology().Enumerate()
+ }
+
+ plans := make([]proto.Plan, 0, len(shards))
+ for k, v := range shards {
+ next := &dml.SimpleQueryPlan{
+ Database: k,
+ Tables: v,
+ Stmt: stmt,
+ }
+ next.BindArgs(o.Args)
+ plans = append(plans, next)
+ }
+
+ if len(plans) > 0 {
+ _, tb, _ := vt.Topology().Smallest()
+ if err = rewriteSelectStatement(ctx, stmt, tb); err != nil {
+ return nil, errors.WithStack(err)
+ }
+ }
+
+ var tmpPlan proto.Plan
+ tmpPlan = &dml.UnionPlan{
+ Plans: plans,
+ }
+
+ if stmt.Limit != nil {
+ tmpPlan = &dml.LimitPlan{
+ ParentPlan: tmpPlan,
+ OriginOffset: originOffset,
+ OverwriteLimit: newLimit,
+ }
+ }
+
+ orderByItems := optimizeOrderBy(stmt)
+
+ if stmt.OrderBy != nil {
+ tmpPlan = &dml.OrderPlan{
+ ParentPlan: tmpPlan,
+ OrderByItems: orderByItems,
+ }
+ }
+
+ convertOrderByItems := func(origins []*ast.OrderByItem) []dataset.OrderByItem {
+ var result = make([]dataset.OrderByItem, 0, len(origins))
+ for _, origin := range origins {
+ var columnName string
+ if cn, ok := origin.Expr.(ast.ColumnNameExpressionAtom); ok {
+ columnName = cn.Suffix()
+ }
+ result = append(result, dataset.OrderByItem{
+ Column: columnName,
+ Desc: origin.Desc,
+ })
+ }
+ return result
+ }
+ if stmt.GroupBy != nil {
+ return &dml.GroupPlan{
+ Plan: tmpPlan,
+ AggItems: aggregator.LoadAggs(stmt.Select),
+ OrderByItems: convertOrderByItems(stmt.OrderBy),
+ }, nil
+ } else {
+ // TODO: refactor groupby/orderby/aggregate plan to a unified plan
+ return &dml.AggregatePlan{
+ Plan: tmpPlan,
+ Combiner: transformer.NewCombinerManager(),
+ AggrLoader: transformer.LoadAggrs(stmt.Select),
+ }, nil
+ }
+}
+
+//optimizeJoin ony support a join b in one db
+func optimizeJoin(o *optimize.Optimizer, stmt *ast.SelectStatement) (proto.Plan, error) {
+ join := stmt.From[0].Source().(*ast.JoinNode)
+
+ compute := func(tableSource *ast.TableSourceNode) (database, alias string, shardList []string, err error) {
+ table := tableSource.TableName()
+ if table == nil {
+ err = errors.New("must table, not statement or join node")
+ return
+ }
+ alias = tableSource.Alias()
+ database = table.Prefix()
+
+ shards, err := o.ComputeShards(table, nil, o.Args)
+ if err != nil {
+ return
+ }
+ //table no shard
+ if shards == nil {
+ shardList = append(shardList, table.Suffix())
+ return
+ }
+ //table shard more than one db
+ if len(shards) > 1 {
+ err = errors.New("not support more than one db")
+ return
+ }
+
+ for k, v := range shards {
+ database = k
+ shardList = v
+ }
+
+ if alias == "" {
+ alias = table.Suffix()
+ }
+
+ return
+ }
+
+ dbLeft, aliasLeft, shardLeft, err := compute(join.Left)
+ if err != nil {
+ return nil, err
+ }
+ dbRight, aliasRight, shardRight, err := compute(join.Right)
+
+ if err != nil {
+ return nil, err
+ }
+
+ if dbLeft != "" && dbRight != "" && dbLeft != dbRight {
+ return nil, errors.New("not support more than one db")
+ }
+
+ joinPan := &dml.SimpleJoinPlan{
+ Left: &dml.JoinTable{
+ Tables: shardLeft,
+ Alias: aliasLeft,
+ },
+ Join: join,
+ Right: &dml.JoinTable{
+ Tables: shardRight,
+ Alias: aliasRight,
+ },
+ Stmt: o.Stmt.(*ast.SelectStatement),
+ }
+ joinPan.BindArgs(o.Args)
+
+ return joinPan, nil
+}
+
+func getSelectFlag(ru *rule.Rule, stmt *ast.SelectStatement) (flag uint32) {
+ switch len(stmt.From) {
+ case 1:
+ from := stmt.From[0]
+ tn := from.TableName()
+
+ if tn == nil { // only FROM table supported now
+ return
+ }
+
+ flag |= _supported
+
+ if len(tn) > 1 {
+ switch strings.ToLower(tn.Prefix()) {
+ case "mysql", "information_schema":
+ flag |= _bypass
+ return
+ }
+ }
+ if !ru.Has(tn.Suffix()) {
+ flag |= _bypass
+ }
+ case 0:
+ flag |= _bypass
+ flag |= _supported
+ }
+ return
+}
+
+func optimizeOrderBy(stmt *ast.SelectStatement) []dataset.OrderByItem {
+ if stmt == nil || stmt.OrderBy == nil {
+ return nil
+ }
+ result := make([]dataset.OrderByItem, 0, len(stmt.OrderBy))
+ for _, node := range stmt.OrderBy {
+ column, _ := node.Expr.(ast.ColumnNameExpressionAtom)
+ item := dataset.OrderByItem{
+ Column: column[0],
+ Desc: node.Desc,
+ }
+ result = append(result, item)
+ }
+ return result
+}
+
+func overwriteLimit(stmt *ast.SelectStatement, args *[]interface{}) (originOffset, overwriteLimit int64) {
+ if stmt == nil || stmt.Limit == nil {
+ return 0, 0
+ }
+
+ offset := stmt.Limit.Offset()
+ limit := stmt.Limit.Limit()
+
+ // SELECT * FROM student where uid = ? limit ? offset ?
+ var offsetIndex int64
+ var limitIndex int64
+
+ if stmt.Limit.IsOffsetVar() {
+ offsetIndex = offset
+ offset = (*args)[offsetIndex].(int64)
+
+ if !stmt.Limit.IsLimitVar() {
+ limit = stmt.Limit.Limit()
+ *args = append(*args, limit)
+ limitIndex = int64(len(*args) - 1)
+ }
+ }
+ originOffset = offset
+
+ if stmt.Limit.IsLimitVar() {
+ limitIndex = limit
+ limit = (*args)[limitIndex].(int64)
+
+ if !stmt.Limit.IsOffsetVar() {
+ *args = append(*args, int64(0))
+ offsetIndex = int64(len(*args) - 1)
+ }
+ }
+
+ if stmt.Limit.IsLimitVar() || stmt.Limit.IsOffsetVar() {
+ if !stmt.Limit.IsLimitVar() {
+ stmt.Limit.SetLimitVar()
+ stmt.Limit.SetLimit(limitIndex)
+ }
+ if !stmt.Limit.IsOffsetVar() {
+ stmt.Limit.SetOffsetVar()
+ stmt.Limit.SetOffset(offsetIndex)
+ }
+
+ newLimitVar := limit + offset
+ overwriteLimit = newLimitVar
+ (*args)[limitIndex] = newLimitVar
+ (*args)[offsetIndex] = int64(0)
+ return
+ }
+
+ stmt.Limit.SetOffset(0)
+ stmt.Limit.SetLimit(offset + limit)
+ overwriteLimit = offset + limit
+ return
+}
+
+func rewriteSelectStatement(ctx context.Context, stmt *ast.SelectStatement, tb string) error {
+ // todo db 计算逻辑&tb shard 的计算逻辑
+ var starExpand = false
+ if len(stmt.Select) == 1 {
+ if _, ok := stmt.Select[0].(*ast.SelectElementAll); ok {
+ starExpand = true
+ }
+ }
+
+ if !starExpand {
+ return nil
+ }
+
+ if len(tb) < 1 {
+ tb = stmt.From[0].TableName().Suffix()
+ }
+ metadatas, err := proto.LoadSchemaLoader().Load(ctx, rcontext.Schema(ctx), []string{tb})
+ if err != nil {
+ return errors.WithStack(err)
+ }
+ metadata := metadatas[tb]
+ if metadata == nil || len(metadata.ColumnNames) == 0 {
+ return errors.Errorf("optimize: cannot get metadata of `%s`.`%s`", rcontext.Schema(ctx), tb)
+ }
+ selectElements := make([]ast.SelectElement, len(metadata.Columns))
+ for i, column := range metadata.ColumnNames {
+ selectElements[i] = ast.NewSelectElementColumn([]string{column}, "")
+ }
+ stmt.Select = selectElements
+
+ return nil
+}
diff --git a/pkg/runtime/optimize/dml/update.go b/pkg/runtime/optimize/dml/update.go
new file mode 100644
index 000000000..53bdf4632
--- /dev/null
+++ b/pkg/runtime/optimize/dml/update.go
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package dml
+
+import (
+ "context"
+)
+
+import (
+ "github.com/pkg/errors"
+)
+
+import (
+ "github.com/arana-db/arana/pkg/proto"
+ "github.com/arana-db/arana/pkg/proto/rule"
+ "github.com/arana-db/arana/pkg/runtime/ast"
+ "github.com/arana-db/arana/pkg/runtime/optimize"
+ "github.com/arana-db/arana/pkg/runtime/plan"
+ "github.com/arana-db/arana/pkg/runtime/plan/dml"
+)
+
+func init() {
+ optimize.Register(ast.SQLTypeUpdate, optimizeUpdate)
+}
+
+func optimizeUpdate(_ context.Context, o *optimize.Optimizer) (proto.Plan, error) {
+ var (
+ stmt = o.Stmt.(*ast.UpdateStatement)
+ table = stmt.Table
+ vt *rule.VTable
+ ok bool
+ )
+
+ // non-sharding update
+ if vt, ok = o.Rule.VTable(table.Suffix()); !ok {
+ ret := dml.NewUpdatePlan(stmt)
+ ret.BindArgs(o.Args)
+ return ret, nil
+ }
+
+ //check update sharding key
+ for _, element := range stmt.Updated {
+ if _, _, ok := vt.GetShardMetadata(element.Column.Suffix()); ok {
+ return nil, errors.New("do not support update sharding key")
+ }
+ }
+
+ var (
+ shards rule.DatabaseTables
+ fullScan = true
+ err error
+ )
+
+ // compute shards
+ if where := stmt.Where; where != nil {
+ sharder := (*optimize.Sharder)(o.Rule)
+ if shards, fullScan, err = sharder.Shard(table, where, o.Args...); err != nil {
+ return nil, errors.Wrap(err, "failed to update")
+ }
+ }
+
+ // exit if full-scan is disabled
+ if fullScan && !vt.AllowFullScan() {
+ return nil, optimize.ErrDenyFullScan
+ }
+
+ // must be empty shards (eg: update xxx set ... where 1 = 2 and uid = 1)
+ if shards.IsEmpty() {
+ return plan.AlwaysEmptyExecPlan{}, nil
+ }
+
+ // compute all sharding tables
+ if shards.IsFullScan() {
+ // compute all tables
+ shards = vt.Topology().Enumerate()
+ }
+
+ ret := dml.NewUpdatePlan(stmt)
+ ret.BindArgs(o.Args)
+ ret.SetShards(shards)
+
+ return ret, nil
+}
diff --git a/pkg/runtime/optimize/optimizer.go b/pkg/runtime/optimize/optimizer.go
index 45c341368..a23ed7f58 100644
--- a/pkg/runtime/optimize/optimizer.go
+++ b/pkg/runtime/optimize/optimizer.go
@@ -19,700 +19,106 @@ package optimize
import (
"context"
- stdErrors "errors"
- "strings"
+ "errors"
)
import (
"github.com/arana-db/parser/ast"
- "github.com/pkg/errors"
+ perrors "github.com/pkg/errors"
+
+ "go.opentelemetry.io/otel"
)
import (
"github.com/arana-db/arana/pkg/proto"
+ "github.com/arana-db/arana/pkg/proto/hint"
"github.com/arana-db/arana/pkg/proto/rule"
- "github.com/arana-db/arana/pkg/proto/schema_manager"
rast "github.com/arana-db/arana/pkg/runtime/ast"
- "github.com/arana-db/arana/pkg/runtime/cmp"
rcontext "github.com/arana-db/arana/pkg/runtime/context"
- "github.com/arana-db/arana/pkg/runtime/plan"
- "github.com/arana-db/arana/pkg/transformer"
"github.com/arana-db/arana/pkg/util/log"
)
-var _ proto.Optimizer = (*optimizer)(nil)
+var _ proto.Optimizer = (*Optimizer)(nil)
+
+var Tracer = otel.Tracer("optimize")
// errors group
var (
- errNoRuleFound = stdErrors.New("no rule found")
- errDenyFullScan = stdErrors.New("the full-scan query is not allowed")
- errNoShardKeyFound = stdErrors.New("no shard key found")
+ ErrNoRuleFound = errors.New("optimize: no rule found")
+ ErrDenyFullScan = errors.New("optimize: the full-scan query is not allowed")
+ ErrNoShardKeyFound = errors.New("optimize: no shard key found")
)
// IsNoShardKeyFoundErr returns true if target error is caused by NO-SHARD-KEY-FOUND
func IsNoShardKeyFoundErr(err error) bool {
- return errors.Is(err, errNoShardKeyFound)
+ return perrors.Is(err, ErrNoShardKeyFound)
}
// IsNoRuleFoundErr returns true if target error is caused by NO-RULE-FOUND.
func IsNoRuleFoundErr(err error) bool {
- return errors.Is(err, errNoRuleFound)
+ return perrors.Is(err, ErrNoRuleFound)
}
// IsDenyFullScanErr returns true if target error is caused by DENY-FULL-SCAN.
func IsDenyFullScanErr(err error) bool {
- return errors.Is(err, errDenyFullScan)
-}
-
-func GetOptimizer() proto.Optimizer {
- return optimizer{
- schemaLoader: &schema_manager.SimpleSchemaLoader{},
- }
-}
-
-type optimizer struct {
- schemaLoader proto.SchemaLoader
-}
-
-func (o *optimizer) SetSchemaLoader(schemaLoader proto.SchemaLoader) {
- o.schemaLoader = schemaLoader
-}
-
-func (o *optimizer) SchemaLoader() proto.SchemaLoader {
- return o.schemaLoader
+ return perrors.Is(err, ErrDenyFullScan)
}
-func (o optimizer) Optimize(ctx context.Context, conn proto.VConn, stmt ast.StmtNode, args ...interface{}) (plan proto.Plan, err error) {
- defer func() {
- if rec := recover(); rec != nil {
- err = errors.Errorf("cannot analyze sql %s", rcontext.SQL(ctx))
- log.Errorf("optimize panic: sql=%s, rec=%v", rcontext.SQL(ctx), rec)
- }
- }()
-
- var rstmt rast.Statement
- if rstmt, err = rast.FromStmtNode(stmt); err != nil {
- return nil, errors.Wrap(err, "optimize failed")
- }
- return o.doOptimize(ctx, conn, rstmt, args...)
-}
-
-func (o optimizer) doOptimize(ctx context.Context, conn proto.VConn, stmt rast.Statement, args ...interface{}) (proto.Plan, error) {
- switch t := stmt.(type) {
- case *rast.ShowDatabases:
- return o.optimizeShowDatabases(ctx, t, args)
- case *rast.SelectStatement:
- return o.optimizeSelect(ctx, conn, t, args)
- case *rast.InsertStatement:
- return o.optimizeInsert(ctx, conn, t, args)
- case *rast.DeleteStatement:
- return o.optimizeDelete(ctx, t, args)
- case *rast.UpdateStatement:
- return o.optimizeUpdate(ctx, conn, t, args)
- case *rast.ShowTables:
- return o.optimizeShowTables(ctx, t, args)
- case *rast.TruncateStatement:
- return o.optimizeTruncate(ctx, t, args)
- case *rast.DropTableStatement:
- return o.optimizeDropTable(ctx, t, args)
- case *rast.ShowVariables:
- return o.optimizeShowVariables(ctx, t, args)
- case *rast.DescribeStatement:
- return o.optimizeDescribeStatement(ctx, t, args)
- case *rast.AlterTableStatement:
- return o.optimizeAlterTable(ctx, t, args)
- }
-
- //TODO implement all statements
- panic("implement me")
-}
-
-const (
- _bypass uint32 = 1 << iota
- _supported
+var (
+ _handlers = make(map[rast.SQLType]Processor)
)
-func (o optimizer) optimizeAlterTable(ctx context.Context, stmt *rast.AlterTableStatement, args []interface{}) (proto.Plan, error) {
- var (
- ret = plan.NewAlterTablePlan(stmt)
- ru = rcontext.Rule(ctx)
- table = stmt.Table
- vt *rule.VTable
- ok bool
- )
- ret.BindArgs(args)
-
- // non-sharding update
- if vt, ok = ru.VTable(table.Suffix()); !ok {
- return ret, nil
- }
-
- //TODO alter table table or column to new name , should update sharding info
-
- // exit if full-scan is disabled
- if !vt.AllowFullScan() {
- return nil, errDenyFullScan
- }
-
- // sharding
- shards := rule.DatabaseTables{}
- topology := vt.Topology()
- topology.Each(func(dbIdx, tbIdx int) bool {
- if d, t, ok := topology.Render(dbIdx, tbIdx); ok {
- shards[d] = append(shards[d], t)
- }
- return true
- })
- ret.Shards = shards
- return ret, nil
-}
-
-func (o optimizer) optimizeDropTable(ctx context.Context, stmt *rast.DropTableStatement, args []interface{}) (proto.Plan, error) {
- ru := rcontext.Rule(ctx)
- //table shard
- var shards []rule.DatabaseTables
- //tables not shard
- noShardStmt := rast.NewDropTableStatement()
- for _, table := range stmt.Tables {
- shard, err := o.computeShards(ru, *table, nil, args)
- if err != nil {
- return nil, err
- }
- if shard == nil {
- noShardStmt.Tables = append(noShardStmt.Tables, table)
- continue
- }
- shards = append(shards, shard)
- }
-
- shardPlan := plan.NewDropTablePlan(stmt)
- shardPlan.BindArgs(args)
- shardPlan.SetShards(shards)
-
- if len(noShardStmt.Tables) == 0 {
- return shardPlan, nil
- }
-
- noShardPlan := plan.Transparent(noShardStmt, args)
-
- return &plan.UnionPlan{
- Plans: []proto.Plan{
- noShardPlan, shardPlan,
- },
- }, nil
+func Register(t rast.SQLType, h Processor) {
+ _handlers[t] = h
}
-func (o optimizer) getSelectFlag(ctx context.Context, stmt *rast.SelectStatement) (flag uint32) {
- switch len(stmt.From) {
- case 1:
- from := stmt.From[0]
- tn := from.TableName()
+type Processor = func(ctx context.Context, o *Optimizer) (proto.Plan, error)
- if tn == nil { // only FROM table supported now
- return
- }
-
- flag |= _supported
-
- if len(tn) > 1 {
- switch strings.ToLower(tn.Prefix()) {
- case "mysql", "information_schema":
- flag |= _bypass
- return
- }
- }
- if !rcontext.Rule(ctx).Has(tn.Suffix()) {
- flag |= _bypass
- }
- case 0:
- flag |= _bypass
- flag |= _supported
- }
- return
-}
-
-func (o optimizer) optimizeShowDatabases(ctx context.Context, stmt *rast.ShowDatabases, args []interface{}) (proto.Plan, error) {
- ret := &plan.ShowDatabasesPlan{Stmt: stmt}
- ret.BindArgs(args)
- return ret, nil
+type Optimizer struct {
+ Rule *rule.Rule
+ Hints []*hint.Hint
+ Stmt rast.Statement
+ Args []interface{}
}
-func (o optimizer) optimizeSelect(ctx context.Context, conn proto.VConn, stmt *rast.SelectStatement, args []interface{}) (proto.Plan, error) {
- var ru *rule.Rule
- if ru = rcontext.Rule(ctx); ru == nil {
- return nil, errors.WithStack(errNoRuleFound)
- }
- if stmt.HasJoin() {
- return o.optimizeJoin(ctx, conn, stmt, args)
- }
- flag := o.getSelectFlag(ctx, stmt)
- if flag&_supported == 0 {
- return nil, errors.Errorf("unsupported sql: %s", rcontext.SQL(ctx))
- }
-
- if flag&_bypass != 0 {
- if len(stmt.From) > 0 {
- err := o.rewriteSelectStatement(ctx, conn, stmt, rcontext.DBGroup(ctx), stmt.From[0].TableName().Suffix())
- if err != nil {
- return nil, err
- }
- }
- ret := &plan.SimpleQueryPlan{Stmt: stmt}
- ret.BindArgs(args)
- return ret, nil
- }
-
+func NewOptimizer(rule *rule.Rule, hints []*hint.Hint, stmt ast.StmtNode, args []interface{}) (proto.Optimizer, error) {
var (
- shards rule.DatabaseTables
- fullScan bool
- err error
- vt = ru.MustVTable(stmt.From[0].TableName().Suffix())
+ rstmt rast.Statement
+ err error
)
-
- if shards, fullScan, err = (*Sharder)(ru).Shard(stmt.From[0].TableName(), stmt.Where, args...); err != nil {
- return nil, errors.Wrap(err, "calculate shards failed")
- }
-
- log.Debugf("compute shards: result=%s, isFullScan=%v", shards, fullScan)
-
- // return error if full-scan is disabled
- if fullScan && !vt.AllowFullScan() {
- return nil, errors.WithStack(errDenyFullScan)
- }
-
- toSingle := func(db, tbl string) (proto.Plan, error) {
- if err := o.rewriteSelectStatement(ctx, conn, stmt, db, tbl); err != nil {
- return nil, err
- }
- ret := &plan.SimpleQueryPlan{
- Stmt: stmt,
- Database: db,
- Tables: []string{tbl},
- }
- ret.BindArgs(args)
-
- return ret, nil
- }
-
- // Go through first table if no shards matched.
- // For example:
- // SELECT ... FROM xxx WHERE a > 8 and a < 4
- if shards.IsEmpty() {
- var (
- db0, tbl0 string
- ok bool
- )
- if db0, tbl0, ok = vt.Topology().Render(0, 0); !ok {
- return nil, errors.Errorf("cannot compute minimal topology from '%s'", stmt.From[0].TableName().Suffix())
- }
-
- return toSingle(db0, tbl0)
- }
-
- // Handle single shard
- if shards.Len() == 1 {
- var db, tbl string
- for k, v := range shards {
- db = k
- tbl = v[0]
- }
- return toSingle(db, tbl)
- }
-
- // Handle multiple shards
-
- if shards.IsFullScan() { // expand all shards if all shards matched
- // init shards
- shards = rule.DatabaseTables{}
- // compute all tables
- topology := vt.Topology()
- topology.Each(func(dbIdx, tbIdx int) bool {
- if d, t, ok := topology.Render(dbIdx, tbIdx); ok {
- shards[d] = append(shards[d], t)
- }
- return true
- })
- }
-
- plans := make([]proto.Plan, 0, len(shards))
- for k, v := range shards {
- next := &plan.SimpleQueryPlan{
- Database: k,
- Tables: v,
- Stmt: stmt,
- }
- next.BindArgs(args)
- plans = append(plans, next)
- }
-
- if len(plans) > 0 {
- tempPlan := plans[0].(*plan.SimpleQueryPlan)
- if err = o.rewriteSelectStatement(ctx, conn, stmt, tempPlan.Database, tempPlan.Tables[0]); err != nil {
- return nil, err
- }
- }
-
- unionPlan := &plan.UnionPlan{
- Plans: plans,
- }
-
- // TODO: order/groupBy/aggregate
- aggregate := &plan.AggregatePlan{
- UnionPlan: unionPlan,
- Combiner: transformer.NewCombinerManager(),
- AggrLoader: transformer.LoadAggrs(stmt.Select),
- }
-
- return aggregate, nil
-}
-
-//optimizeJoin ony support a join b in one db
-func (o optimizer) optimizeJoin(ctx context.Context, conn proto.VConn, stmt *rast.SelectStatement, args []interface{}) (proto.Plan, error) {
-
- var ru *rule.Rule
- if ru = rcontext.Rule(ctx); ru == nil {
- return nil, errors.WithStack(errNoRuleFound)
- }
-
- join := stmt.From[0].Source().(*rast.JoinNode)
-
- compute := func(tableSource *rast.TableSourceNode) (database, alias string, shardList []string, err error) {
- table := tableSource.TableName()
- if table == nil {
- err = errors.New("must table, not statement or join node")
- return
- }
- alias = tableSource.Alias()
- database = table.Prefix()
-
- shards, err := o.computeShards(ru, table, nil, args)
- if err != nil {
- return
- }
- //table no shard
- if shards == nil {
- shardList = append(shardList, table.Suffix())
- return
- }
- //table shard more than one db
- if len(shards) > 1 {
- err = errors.New("not support more than one db")
- return
- }
-
- for k, v := range shards {
- database = k
- shardList = v
- }
-
- if alias == "" {
- alias = table.Suffix()
- }
-
- return
- }
-
- dbLeft, aliasLeft, shardLeft, err := compute(join.Left)
- if err != nil {
- return nil, err
- }
- dbRight, aliasRight, shardRight, err := compute(join.Right)
-
- if err != nil {
- return nil, err
- }
-
- if dbLeft != "" && dbRight != "" && dbLeft != dbRight {
- return nil, errors.New("not support more than one db")
- }
-
- joinPan := &plan.SimpleJoinPlan{
- Left: &plan.JoinTable{
- Tables: shardLeft,
- Alias: aliasLeft,
- },
- Join: join,
- Right: &plan.JoinTable{
- Tables: shardRight,
- Alias: aliasRight,
- },
- Stmt: stmt,
- }
- joinPan.BindArgs(args)
-
- return joinPan, nil
-}
-
-func (o optimizer) optimizeUpdate(ctx context.Context, conn proto.VConn, stmt *rast.UpdateStatement, args []interface{}) (proto.Plan, error) {
- var (
- ru = rcontext.Rule(ctx)
- table = stmt.Table
- vt *rule.VTable
- ok bool
- )
-
- // non-sharding update
- if vt, ok = ru.VTable(table.Suffix()); !ok {
- ret := plan.NewUpdatePlan(stmt)
- ret.BindArgs(args)
- return ret, nil
- }
-
- var (
- shards rule.DatabaseTables
- fullScan = true
- err error
- )
-
- // compute shards
- if where := stmt.Where; where != nil {
- sharder := (*Sharder)(ru)
- if shards, fullScan, err = sharder.Shard(table, where, args...); err != nil {
- return nil, errors.Wrap(err, "failed to update")
- }
- }
-
- // exit if full-scan is disabled
- if fullScan && !vt.AllowFullScan() {
- return nil, errDenyFullScan
- }
-
- // must be empty shards (eg: update xxx set ... where 1 = 2 and uid = 1)
- if shards.IsEmpty() {
- return plan.AlwaysEmptyExecPlan{}, nil
- }
-
- // compute all sharding tables
- if shards.IsFullScan() {
- // init shards
- shards = rule.DatabaseTables{}
- // compute all tables
- topology := vt.Topology()
- topology.Each(func(dbIdx, tbIdx int) bool {
- if d, t, ok := topology.Render(dbIdx, tbIdx); ok {
- shards[d] = append(shards[d], t)
- }
- return true
- })
- }
-
- ret := plan.NewUpdatePlan(stmt)
- ret.BindArgs(args)
- ret.SetShards(shards)
-
- return ret, nil
-}
-
-func (o optimizer) optimizeInsert(ctx context.Context, conn proto.VConn, stmt *rast.InsertStatement, args []interface{}) (proto.Plan, error) {
- var (
- ru = rcontext.Rule(ctx)
- ret = plan.NewSimpleInsertPlan()
- )
-
- ret.BindArgs(args)
-
- var (
- vt *rule.VTable
- ok bool
- )
-
- if vt, ok = ru.VTable(stmt.Table().Suffix()); !ok { // insert into non-sharding table
- ret.Put("", stmt)
- return ret, nil
- }
-
- // TODO: handle multiple shard keys.
-
- bingo := -1
- // check existing shard columns
- for i, col := range stmt.Columns() {
- if _, _, ok = vt.GetShardMetadata(col); ok {
- bingo = i
- break
- }
- }
-
- if bingo < 0 {
- return nil, errors.Wrap(errNoShardKeyFound, "failed to insert")
- }
-
- var (
- sharder = (*Sharder)(ru)
- left = rast.ColumnNameExpressionAtom(make([]string, 1))
- filter = &rast.PredicateExpressionNode{
- P: &rast.BinaryComparisonPredicateNode{
- Left: &rast.AtomPredicateNode{
- A: left,
- },
- Op: cmp.Ceq,
- },
- }
- slots = make(map[string]map[string][]int) // (db,table,valuesIndex)
- )
-
- // reset filter
- resetFilter := func(column string, value rast.ExpressionNode) {
- left[0] = column
- filter.P.(*rast.BinaryComparisonPredicateNode).Right = value.(*rast.PredicateExpressionNode).P
- }
-
- for i, values := range stmt.Values() {
- value := values[bingo]
- resetFilter(stmt.Columns()[bingo], value)
-
- shards, _, err := sharder.Shard(stmt.Table(), filter, args...)
-
- if err != nil {
- return nil, errors.WithStack(err)
- }
-
- if shards.Len() != 1 {
- return nil, errors.Wrap(errNoShardKeyFound, "failed to insert")
- }
-
- var (
- db string
- table string
- )
-
- for k, v := range shards {
- db = k
- table = v[0]
- break
- }
-
- if _, ok = slots[db]; !ok {
- slots[db] = make(map[string][]int)
- }
- slots[db][table] = append(slots[db][table], i)
- }
-
- for db, slot := range slots {
- for table, indexes := range slot {
- // clone insert stmt without values
- newborn := rast.NewInsertStatement(rast.TableName{table}, stmt.Columns())
- newborn.SetFlag(stmt.Flag())
- newborn.SetDuplicatedUpdates(stmt.DuplicatedUpdates())
-
- // collect values with same table
- values := make([][]rast.ExpressionNode, 0, len(indexes))
- for _, i := range indexes {
- values = append(values, stmt.Values()[i])
- }
- newborn.SetValues(values)
-
- o.rewriteInsertStatement(ctx, conn, newborn, db, table)
- ret.Put(db, newborn)
- }
- }
-
- return ret, nil
-}
-
-func (o optimizer) optimizeDelete(ctx context.Context, stmt *rast.DeleteStatement, args []interface{}) (proto.Plan, error) {
- ru := rcontext.Rule(ctx)
- shards, err := o.computeShards(ru, stmt.Table, stmt.Where, args)
- if err != nil {
- return nil, errors.Wrap(err, "failed to optimize DELETE statement")
- }
-
- // TODO: delete from a child sharding-table directly
-
- if shards == nil {
- return plan.Transparent(stmt, args), nil
+ if rstmt, err = rast.FromStmtNode(stmt); err != nil {
+ return nil, perrors.Wrap(err, "optimize failed")
}
- ret := plan.NewSimpleDeletePlan(stmt)
- ret.BindArgs(args)
- ret.SetShards(shards)
-
- return ret, nil
+ return &Optimizer{
+ Rule: rule,
+ Hints: hints,
+ Stmt: rstmt,
+ Args: args,
+ }, nil
}
-func (o optimizer) optimizeShowTables(ctx context.Context, stmt *rast.ShowTables, args []interface{}) (proto.Plan, error) {
- vts := rcontext.Rule(ctx).VTables()
- databaseTablesMap := make(map[string]rule.DatabaseTables, len(vts))
- for tableName, vt := range vts {
- shards := rule.DatabaseTables{}
- // compute all tables
- topology := vt.Topology()
- topology.Each(func(dbIdx, tbIdx int) bool {
- if d, t, ok := topology.Render(dbIdx, tbIdx); ok {
- shards[d] = append(shards[d], t)
- }
- return true
- })
- databaseTablesMap[tableName] = shards
- }
-
- tmpPlanData := make(map[string]plan.DatabaseTable)
- for showTableName, databaseTables := range databaseTablesMap {
- for databaseName, shardingTables := range databaseTables {
- for _, shardingTable := range shardingTables {
- tmpPlanData[shardingTable] = plan.DatabaseTable{
- Database: databaseName,
- TableName: showTableName,
- }
- }
+func (o *Optimizer) Optimize(ctx context.Context) (plan proto.Plan, err error) {
+ ctx, span := Tracer.Start(ctx, "Optimize")
+ defer func() {
+ span.End()
+ if rec := recover(); rec != nil {
+ err = perrors.Errorf("cannot analyze sql %s", rcontext.SQL(ctx))
+ log.Errorf("optimize panic: sql=%s, rec=%v", rcontext.SQL(ctx), rec)
}
- }
-
- ret := plan.NewShowTablesPlan(stmt)
- ret.BindArgs(args)
- ret.SetAllShards(tmpPlanData)
- return ret, nil
-}
-
-func (o optimizer) optimizeTruncate(ctx context.Context, stmt *rast.TruncateStatement, args []interface{}) (proto.Plan, error) {
- ru := rcontext.Rule(ctx)
- shards, err := o.computeShards(ru, stmt.Table, nil, args)
- if err != nil {
- return nil, errors.Wrap(err, "failed to optimize TRUNCATE statement")
- }
-
- if shards == nil {
- return plan.Transparent(stmt, args), nil
- }
-
- ret := plan.NewTruncatePlan(stmt)
- ret.BindArgs(args)
- ret.SetShards(shards)
-
- return ret, nil
-}
-
-func (o optimizer) optimizeShowVariables(ctx context.Context, stmt *rast.ShowVariables, args []interface{}) (proto.Plan, error) {
- ret := plan.NewShowVariablesPlan(stmt)
- ret.BindArgs(args)
- return ret, nil
-}
-
-func (o optimizer) optimizeDescribeStatement(ctx context.Context, stmt *rast.DescribeStatement, args []interface{}) (proto.Plan, error) {
- vts := rcontext.Rule(ctx).VTables()
- vtName := []string(stmt.Table)[0]
- ret := plan.NewDescribePlan(stmt)
- ret.BindArgs(args)
+ }()
- if vTable, ok := vts[vtName]; ok {
- shards := rule.DatabaseTables{}
- // compute all tables
- topology := vTable.Topology()
- topology.Each(func(dbIdx, tbIdx int) bool {
- if d, t, ok := topology.Render(dbIdx, tbIdx); ok {
- shards[d] = append(shards[d], t)
- }
- return true
- })
- dbName, tblName := shards.Smallest()
- ret.Database = dbName
- ret.Table = tblName
+ h, ok := _handlers[o.Stmt.Mode()]
+ if !ok {
+ return nil, perrors.Errorf("optimize: no handler found for '%s'", o.Stmt.Mode())
}
- return ret, nil
+ return h(ctx, o)
}
-func (o optimizer) computeShards(ru *rule.Rule, table rast.TableName, where rast.ExpressionNode, args []interface{}) (rule.DatabaseTables, error) {
+func (o *Optimizer) ComputeShards(table rast.TableName, where rast.ExpressionNode, args []interface{}) (rule.DatabaseTables, error) {
+ ru := o.Rule
vt, ok := ru.VTable(table.Suffix())
if !ok {
return nil, nil
@@ -720,14 +126,14 @@ func (o optimizer) computeShards(ru *rule.Rule, table rast.TableName, where rast
shards, fullScan, err := (*Sharder)(ru).Shard(table, where, args...)
if err != nil {
- return nil, errors.Wrap(err, "calculate shards failed")
+ return nil, perrors.Wrapf(err, "optimize: cannot calculate shards of table '%s'", table.Suffix())
}
- log.Debugf("compute shards: result=%s, isFullScan=%v", shards, fullScan)
+ //log.Debugf("compute shards: result=%s, isFullScan=%v", shards, fullScan)
// return error if full-scan is disabled
if fullScan && !vt.AllowFullScan() {
- return nil, errors.WithStack(errDenyFullScan)
+ return nil, perrors.WithStack(ErrDenyFullScan)
}
if shards.IsEmpty() {
@@ -735,86 +141,9 @@ func (o optimizer) computeShards(ru *rule.Rule, table rast.TableName, where rast
}
if len(shards) == 0 {
- // init shards
- shards = rule.DatabaseTables{}
// compute all tables
- topology := vt.Topology()
- topology.Each(func(dbIdx, tbIdx int) bool {
- if d, t, ok := topology.Render(dbIdx, tbIdx); ok {
- shards[d] = append(shards[d], t)
- }
- return true
- })
+ shards = vt.Topology().Enumerate()
}
return shards, nil
}
-
-func (o optimizer) rewriteSelectStatement(ctx context.Context, conn proto.VConn, stmt *rast.SelectStatement,
- db, tb string) error {
- // todo db 计算逻辑&tb shard 的计算逻辑
- var starExpand = false
- if len(stmt.Select) == 1 {
- if _, ok := stmt.Select[0].(*rast.SelectElementAll); ok {
- starExpand = true
- }
- }
- if starExpand {
- if len(tb) < 1 {
- tb = stmt.From[0].TableName().Suffix()
- }
- metaData := o.schemaLoader.Load(ctx, conn, db, []string{tb})[tb]
- if metaData == nil || len(metaData.ColumnNames) == 0 {
- return errors.Errorf("can not get metadata for db:%s and table:%s", db, tb)
- }
- selectElements := make([]rast.SelectElement, len(metaData.Columns))
- for i, column := range metaData.ColumnNames {
- selectElements[i] = rast.NewSelectElementColumn([]string{column}, "")
- }
- stmt.Select = selectElements
- }
-
- return nil
-}
-
-func (o optimizer) rewriteInsertStatement(ctx context.Context, conn proto.VConn, stmt *rast.InsertStatement,
- db, tb string) error {
- metaData := o.schemaLoader.Load(ctx, conn, db, []string{tb})[tb]
- if metaData == nil || len(metaData.ColumnNames) == 0 {
- return errors.Errorf("can not get metadata for db:%s and table:%s", db, tb)
- }
-
- if len(metaData.ColumnNames) == len(stmt.Columns()) {
- // User had explicitly specified every value
- return nil
- }
- columnsMetadata := metaData.Columns
-
- for _, colName := range stmt.Columns() {
- if columnsMetadata[colName].PrimaryKey && columnsMetadata[colName].Generated {
- // User had explicitly specified auto-generated primary key column
- return nil
- }
- }
-
- pkColName := ""
- for name, column := range columnsMetadata {
- if column.PrimaryKey && column.Generated {
- pkColName = name
- break
- }
- }
- if len(pkColName) < 1 {
- // There's no auto-generated primary key column
- return nil
- }
-
- // TODO rewrite columns and add distributed primary key
- //stmt.SetColumns(append(stmt.Columns(), pkColName))
- // append value of distributed primary key
- //newValues := stmt.Values()
- //for _, newValue := range newValues {
- // newValue = append(newValue, )
- //}
- return nil
-}
diff --git a/pkg/runtime/optimize/optimizer_test.go b/pkg/runtime/optimize/optimizer_test.go
index 80569ce62..00034dbd8 100644
--- a/pkg/runtime/optimize/optimizer_test.go
+++ b/pkg/runtime/optimize/optimizer_test.go
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package optimize
+package optimize_test
import (
"context"
@@ -33,10 +33,14 @@ import (
)
import (
- "github.com/arana-db/arana/pkg/mysql"
"github.com/arana-db/arana/pkg/proto"
"github.com/arana-db/arana/pkg/proto/rule"
- rcontext "github.com/arana-db/arana/pkg/runtime/context"
+ "github.com/arana-db/arana/pkg/resultx"
+ . "github.com/arana-db/arana/pkg/runtime/optimize"
+ _ "github.com/arana-db/arana/pkg/runtime/optimize/dal"
+ _ "github.com/arana-db/arana/pkg/runtime/optimize/ddl"
+ _ "github.com/arana-db/arana/pkg/runtime/optimize/dml"
+ _ "github.com/arana-db/arana/pkg/runtime/optimize/utility"
"github.com/arana-db/arana/testdata"
)
@@ -49,21 +53,25 @@ func TestOptimizer_OptimizeSelect(t *testing.T) {
conn.EXPECT().Query(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
DoAndReturn(func(ctx context.Context, db string, sql string, args ...interface{}) (proto.Result, error) {
t.Logf("fake query: db=%s, sql=%s, args=%v\n", db, sql, args)
- return &mysql.Result{}, nil
+
+ ds := testdata.NewMockDataset(ctrl)
+ ds.EXPECT().Fields().Return([]proto.Field{}, nil).AnyTimes()
+
+ return resultx.New(resultx.WithDataset(ds)), nil
}).
AnyTimes()
var (
- sql = "select id, uid from student where uid in (?,?,?)"
- ctx = context.Background()
- rule = makeFakeRule(ctrl, 8)
- opt optimizer
+ sql = "select id, uid from student where uid in (?,?,?)"
+ ctx = context.Background()
+ ru = makeFakeRule(ctrl, 8)
)
p := parser.New()
stmt, _ := p.ParseOneStmt(sql, "", "")
-
- plan, err := opt.Optimize(rcontext.WithRule(ctx, rule), conn, stmt, 1, 2, 3)
+ opt, err := NewOptimizer(ru, nil, stmt, []interface{}{1, 2, 3})
+ assert.NoError(t, err)
+ plan, err := opt.Optimize(ctx)
assert.NoError(t, err)
_, _ = plan.ExecIn(ctx, conn)
@@ -117,27 +125,34 @@ func TestOptimizer_OptimizeInsert(t *testing.T) {
t.Logf("fake exec: db='%s', sql=\"%s\", args=%v\n", db, sql, args)
fakeId++
- return &mysql.Result{
- AffectedRows: uint64(strings.Count(sql, "?")),
- InsertId: fakeId,
- }, nil
+ return resultx.New(
+ resultx.WithRowsAffected(uint64(strings.Count(sql, "?"))),
+ resultx.WithLastInsertID(fakeId),
+ ), nil
}).
AnyTimes()
- loader.EXPECT().Load(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(fakeStudentMetadata).Times(2)
+ loader.EXPECT().Load(gomock.Any(), gomock.Any(), gomock.Any()).Return(fakeStudentMetadata, nil).Times(2)
+
+ oldLoader := proto.LoadSchemaLoader()
+ proto.RegisterSchemaLoader(loader)
+ defer proto.RegisterSchemaLoader(oldLoader)
var (
- ctx = context.Background()
- rule = makeFakeRule(ctrl, 8)
- opt = optimizer{schemaLoader: loader}
+ ctx = context.Background()
+ ru = makeFakeRule(ctrl, 8)
)
t.Run("sharding", func(t *testing.T) {
+
sql := "insert into student(name,uid,age) values('foo',?,18),('bar',?,19),('qux',?,17)"
p := parser.New()
stmt, _ := p.ParseOneStmt(sql, "", "")
- plan, err := opt.Optimize(rcontext.WithRule(ctx, rule), conn, stmt, 8, 9, 16) // 8,16 -> fake_db_0000, 9 -> fake_db_0001
+ opt, err := NewOptimizer(ru, nil, stmt, []interface{}{8, 9, 16})
+ assert.NoError(t, err)
+
+ plan, err := opt.Optimize(ctx) // 8,16 -> fake_db_0000, 9 -> fake_db_0001
assert.NoError(t, err)
res, err := plan.ExecIn(ctx, conn)
@@ -155,7 +170,10 @@ func TestOptimizer_OptimizeInsert(t *testing.T) {
p := parser.New()
stmt, _ := p.ParseOneStmt(sql, "", "")
- plan, err := opt.Optimize(rcontext.WithRule(ctx, rule), conn, stmt, 1)
+ opt, err := NewOptimizer(ru, nil, stmt, []interface{}{1})
+ assert.NoError(t, err)
+
+ plan, err := opt.Optimize(ctx)
assert.NoError(t, err)
res, err := plan.ExecIn(ctx, conn)
@@ -166,7 +184,6 @@ func TestOptimizer_OptimizeInsert(t *testing.T) {
lastInsertId, _ := res.LastInsertId()
assert.Equal(t, fakeId, lastInsertId)
})
-
}
func TestOptimizer_OptimizeAlterTable(t *testing.T) {
@@ -174,23 +191,21 @@ func TestOptimizer_OptimizeAlterTable(t *testing.T) {
defer ctrl.Finish()
conn := testdata.NewMockVConn(ctrl)
- loader := testdata.NewMockSchemaLoader(ctrl)
conn.EXPECT().Exec(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
DoAndReturn(func(ctx context.Context, db string, sql string, args ...interface{}) (proto.Result, error) {
t.Logf("fake exec: db='%s', sql=\"%s\", args=%v\n", db, sql, args)
- return &mysql.Result{}, nil
+ return resultx.New(), nil
}).AnyTimes()
var (
- ctx = context.Background()
- opt = optimizer{schemaLoader: loader}
- ru rule.Rule
- tab rule.VTable
- topo rule.Topology
+ ctx = context.Background()
+ ru rule.Rule
+ tab rule.VTable
+ topology rule.Topology
)
- topo.SetRender(func(_ int) string {
+ topology.SetRender(func(_ int) string {
return "fake_db"
}, func(i int) string {
return fmt.Sprintf("student_%04d", i)
@@ -199,8 +214,8 @@ func TestOptimizer_OptimizeAlterTable(t *testing.T) {
for i := 0; i < 8; i++ {
tables = append(tables, i)
}
- topo.SetTopology(0, tables...)
- tab.SetTopology(&topo)
+ topology.SetTopology(0, tables...)
+ tab.SetTopology(&topology)
tab.SetAllowFullScan(true)
ru.SetVTable("student", &tab)
@@ -210,12 +225,14 @@ func TestOptimizer_OptimizeAlterTable(t *testing.T) {
p := parser.New()
stmt, _ := p.ParseOneStmt(sql, "", "")
- plan, err := opt.Optimize(rcontext.WithRule(ctx, &ru), conn, stmt)
+ opt, err := NewOptimizer(&ru, nil, stmt, nil)
assert.NoError(t, err)
- _, err = plan.ExecIn(ctx, conn)
+ plan, err := opt.Optimize(ctx)
assert.NoError(t, err)
+ _, err = plan.ExecIn(ctx, conn)
+ assert.NoError(t, err)
})
t.Run("non-sharding", func(t *testing.T) {
@@ -224,10 +241,61 @@ func TestOptimizer_OptimizeAlterTable(t *testing.T) {
p := parser.New()
stmt, _ := p.ParseOneStmt(sql, "", "")
- plan, err := opt.Optimize(rcontext.WithRule(ctx, &ru), conn, stmt)
+ opt, err := NewOptimizer(&ru, nil, stmt, nil)
+ assert.NoError(t, err)
+
+ plan, err := opt.Optimize(ctx)
assert.NoError(t, err)
_, err = plan.ExecIn(ctx, conn)
assert.NoError(t, err)
})
}
+
+func TestOptimizer_OptimizeInsertSelect(t *testing.T) {
+ ctrl := gomock.NewController(t)
+ defer ctrl.Finish()
+
+ conn := testdata.NewMockVConn(ctrl)
+
+ var fakeId uint64
+ conn.EXPECT().Exec(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
+ DoAndReturn(func(ctx context.Context, db string, sql string, args ...interface{}) (proto.Result, error) {
+ t.Logf("fake exec: db='%s', sql=\"%s\", args=%v\n", db, sql, args)
+ fakeId++
+
+ return resultx.New(
+ resultx.WithRowsAffected(uint64(strings.Count(sql, "?"))),
+ resultx.WithLastInsertID(fakeId),
+ ), nil
+ }).
+ AnyTimes()
+
+ var (
+ ctx = context.Background()
+ ru rule.Rule
+ )
+
+ ru.SetVTable("student", nil)
+
+ t.Run("non-sharding", func(t *testing.T) {
+ sql := "insert into employees(name, age) select name,age from employees_tmp limit 10,2"
+
+ p := parser.New()
+ stmt, _ := p.ParseOneStmt(sql, "", "")
+
+ opt, err := NewOptimizer(&ru, nil, stmt, []interface{}{1})
+ assert.NoError(t, err)
+
+ plan, err := opt.Optimize(ctx)
+ assert.NoError(t, err)
+
+ res, err := plan.ExecIn(ctx, conn)
+ assert.NoError(t, err)
+
+ affected, _ := res.RowsAffected()
+ assert.Equal(t, uint64(0), affected)
+ lastInsertId, _ := res.LastInsertId()
+ assert.Equal(t, fakeId, lastInsertId)
+ })
+}
diff --git a/pkg/runtime/optimize/sharder.go b/pkg/runtime/optimize/sharder.go
index 8dea8fcec..45e56c872 100644
--- a/pkg/runtime/optimize/sharder.go
+++ b/pkg/runtime/optimize/sharder.go
@@ -37,9 +37,7 @@ import (
rrule "github.com/arana-db/arana/pkg/runtime/rule"
)
-var (
- errArgumentOutOfRange = stdErrors.New("argument is out of bounds")
-)
+var errArgumentOutOfRange = stdErrors.New("argument is out of bounds")
// IsErrArgumentOutOfRange returns true if target error is caused by argument out of range.
func IsErrArgumentOutOfRange(err error) bool {
diff --git a/pkg/runtime/optimize/sharder_test.go b/pkg/runtime/optimize/sharder_test.go
index 4a6b6ac8b..229038366 100644
--- a/pkg/runtime/optimize/sharder_test.go
+++ b/pkg/runtime/optimize/sharder_test.go
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package optimize
+package optimize_test
import (
"fmt"
@@ -34,6 +34,7 @@ import (
import (
"github.com/arana-db/arana/pkg/proto/rule"
"github.com/arana-db/arana/pkg/runtime/ast"
+ . "github.com/arana-db/arana/pkg/runtime/optimize"
"github.com/arana-db/arana/testdata"
)
@@ -56,7 +57,8 @@ func TestShard(t *testing.T) {
{"select * from student where uid = if(PI()<3, 1, ?)", []interface{}{0}, []int{0}},
} {
t.Run(it.sql, func(t *testing.T) {
- stmt := ast.MustParse(it.sql).(*ast.SelectStatement)
+ _, rawStmt := ast.MustParse(it.sql)
+ stmt := rawStmt.(*ast.SelectStatement)
result, _, err := (*Sharder)(fakeRule).Shard(stmt.From[0].TableName(), stmt.Where, it.args...)
assert.NoError(t, err, "shard failed")
diff --git a/pkg/runtime/optimize/utility/describe.go b/pkg/runtime/optimize/utility/describe.go
new file mode 100644
index 000000000..3c757dc2d
--- /dev/null
+++ b/pkg/runtime/optimize/utility/describe.go
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package utility
+
+import (
+ "context"
+)
+
+import (
+ "github.com/arana-db/arana/pkg/proto"
+ "github.com/arana-db/arana/pkg/runtime/ast"
+ "github.com/arana-db/arana/pkg/runtime/optimize"
+ "github.com/arana-db/arana/pkg/runtime/plan/utility"
+)
+
+func init() {
+ optimize.Register(ast.SQLTypeDescribe, optimizeDescribeStatement)
+}
+
+func optimizeDescribeStatement(_ context.Context, o *optimize.Optimizer) (proto.Plan, error) {
+ stmt := o.Stmt.(*ast.DescribeStatement)
+ vts := o.Rule.VTables()
+ vtName := []string(stmt.Table)[0]
+ ret := utility.NewDescribePlan(stmt)
+ ret.BindArgs(o.Args)
+
+ if vTable, ok := vts[vtName]; ok {
+ dbName, tblName, _ := vTable.Topology().Smallest()
+ ret.Database = dbName
+ ret.Table = tblName
+ ret.Column = stmt.Column
+ }
+
+ return ret, nil
+}
diff --git a/pkg/runtime/plan/always.go b/pkg/runtime/plan/always.go
index a402efc91..90444b96f 100644
--- a/pkg/runtime/plan/always.go
+++ b/pkg/runtime/plan/always.go
@@ -22,22 +22,19 @@ import (
)
import (
- "github.com/arana-db/arana/pkg/mysql"
"github.com/arana-db/arana/pkg/proto"
+ "github.com/arana-db/arana/pkg/resultx"
)
var _ proto.Plan = (*AlwaysEmptyExecPlan)(nil)
-var _emptyResult mysql.Result
-
// AlwaysEmptyExecPlan represents an exec plan which affects nothing.
-type AlwaysEmptyExecPlan struct {
-}
+type AlwaysEmptyExecPlan struct{}
func (a AlwaysEmptyExecPlan) Type() proto.PlanType {
return proto.PlanTypeExec
}
-func (a AlwaysEmptyExecPlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Result, error) {
- return &_emptyResult, nil
+func (a AlwaysEmptyExecPlan) ExecIn(_ context.Context, _ proto.VConn) (proto.Result, error) {
+ return resultx.New(), nil
}
diff --git a/pkg/runtime/plan/dal/show_columns.go b/pkg/runtime/plan/dal/show_columns.go
new file mode 100644
index 000000000..79b9672ca
--- /dev/null
+++ b/pkg/runtime/plan/dal/show_columns.go
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package dal
+
+import (
+ "context"
+ "strings"
+)
+
+import (
+ "github.com/pkg/errors"
+)
+
+import (
+ "github.com/arana-db/arana/pkg/proto"
+ "github.com/arana-db/arana/pkg/runtime/ast"
+ "github.com/arana-db/arana/pkg/runtime/plan"
+)
+
+var _ proto.Plan = (*ShowColumnsPlan)(nil)
+
+type ShowColumnsPlan struct {
+ plan.BasePlan
+ Stmt *ast.ShowColumns
+ Table string
+}
+
+func (s *ShowColumnsPlan) Type() proto.PlanType {
+ return proto.PlanTypeQuery
+}
+
+func (s *ShowColumnsPlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Result, error) {
+ var (
+ sb strings.Builder
+ indexes []int
+ )
+
+ if err := s.generate(&sb, &indexes); err != nil {
+ return nil, errors.Wrap(err, "failed to generate show columns sql")
+ }
+
+ return conn.Query(ctx, "", sb.String(), s.ToArgs(indexes)...)
+}
+
+func (s *ShowColumnsPlan) generate(sb *strings.Builder, args *[]int) error {
+ var (
+ stmt = *s.Stmt
+ err error
+ )
+
+ if s.Table == "" {
+ if err = s.Stmt.Restore(ast.RestoreDefault, sb, args); err != nil {
+ return errors.WithStack(err)
+ }
+ return nil
+ }
+
+ s.resetTable(&stmt, s.Table)
+ if err = stmt.Restore(ast.RestoreDefault, sb, args); err != nil {
+ return errors.WithStack(err)
+ }
+ return nil
+}
+
+func (s *ShowColumnsPlan) resetTable(dstmt *ast.ShowColumns, table string) {
+ dstmt.TableName = dstmt.TableName.ResetSuffix(table)
+}
diff --git a/pkg/runtime/plan/dal/show_create.go b/pkg/runtime/plan/dal/show_create.go
new file mode 100644
index 000000000..f9903a4a3
--- /dev/null
+++ b/pkg/runtime/plan/dal/show_create.go
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package dal
+
+import (
+ "context"
+ "strings"
+)
+
+import (
+ "github.com/pkg/errors"
+)
+
+import (
+ "github.com/arana-db/arana/pkg/dataset"
+ "github.com/arana-db/arana/pkg/mysql/rows"
+ "github.com/arana-db/arana/pkg/proto"
+ "github.com/arana-db/arana/pkg/resultx"
+ "github.com/arana-db/arana/pkg/runtime/ast"
+ "github.com/arana-db/arana/pkg/runtime/plan"
+)
+
+var _ proto.Plan = (*ShowCreatePlan)(nil)
+
+type ShowCreatePlan struct {
+ plan.BasePlan
+ Stmt *ast.ShowCreate
+ Database string
+ Table string
+}
+
+// NewShowCreatePlan create ShowCreate Plan
+func NewShowCreatePlan(stmt *ast.ShowCreate) *ShowCreatePlan {
+ return &ShowCreatePlan{
+ Stmt: stmt,
+ }
+}
+
+func (st *ShowCreatePlan) Type() proto.PlanType {
+ return proto.PlanTypeQuery
+}
+
+func (st *ShowCreatePlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Result, error) {
+ var (
+ sb strings.Builder
+ indexes []int
+ res proto.Result
+ err error
+ )
+
+ if err = st.Stmt.ResetTable(st.Table).Restore(ast.RestoreDefault, &sb, &indexes); err != nil {
+ return nil, errors.WithStack(err)
+ }
+
+ var (
+ query = sb.String()
+ args = st.ToArgs(indexes)
+ target = st.Stmt.Target()
+ )
+
+ if res, err = conn.Query(ctx, st.Database, query, args...); err != nil {
+ return nil, errors.WithStack(err)
+ }
+
+ ds, err := res.Dataset()
+ if err != nil {
+ return nil, errors.WithStack(err)
+ }
+
+ // sharding table should be changed of target name
+ if st.Table != target {
+ fields, _ := ds.Fields()
+ ds = dataset.Pipe(ds,
+ dataset.Map(nil, func(next proto.Row) (proto.Row, error) {
+ dest := make([]proto.Value, len(fields))
+ if next.Scan(dest) != nil {
+ return next, nil
+ }
+ dest[0] = target
+ dest[1] = strings.Replace(dest[1].(string), st.Table, target, 1)
+
+ if next.IsBinary() {
+ return rows.NewBinaryVirtualRow(fields, dest), nil
+ }
+ return rows.NewTextVirtualRow(fields, dest), nil
+ }),
+ )
+ }
+
+ return resultx.New(resultx.WithDataset(ds)), nil
+}
diff --git a/pkg/runtime/plan/show_databases.go b/pkg/runtime/plan/dal/show_databases.go
similarity index 63%
rename from pkg/runtime/plan/show_databases.go
rename to pkg/runtime/plan/dal/show_databases.go
index 9aedcc883..ea22c6e85 100644
--- a/pkg/runtime/plan/show_databases.go
+++ b/pkg/runtime/plan/dal/show_databases.go
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package plan
+package dal
import (
"context"
@@ -26,20 +26,21 @@ import (
)
import (
- fieldType "github.com/arana-db/arana/pkg/constants/mysql"
- "github.com/arana-db/arana/pkg/mysql"
+ "github.com/arana-db/arana/pkg/dataset"
+ "github.com/arana-db/arana/pkg/mysql/rows"
+ "github.com/arana-db/arana/pkg/mysql/thead"
"github.com/arana-db/arana/pkg/proto"
+ "github.com/arana-db/arana/pkg/resultx"
"github.com/arana-db/arana/pkg/runtime/ast"
rcontext "github.com/arana-db/arana/pkg/runtime/context"
+ "github.com/arana-db/arana/pkg/runtime/plan"
"github.com/arana-db/arana/pkg/security"
)
-var tenantErr = errors.New("current db tenant not fund")
-
var _ proto.Plan = (*ShowDatabasesPlan)(nil)
type ShowDatabasesPlan struct {
- basePlan
+ plan.BasePlan
Stmt *ast.ShowDatabases
}
@@ -48,31 +49,21 @@ func (s *ShowDatabasesPlan) Type() proto.PlanType {
}
func (s *ShowDatabasesPlan) ExecIn(ctx context.Context, _ proto.VConn) (proto.Result, error) {
+ ctx, span := plan.Tracer.Start(ctx, "ShowDatabasesPlan.ExecIn")
+ defer span.End()
tenant, ok := security.DefaultTenantManager().GetTenantOfCluster(rcontext.Schema(ctx))
if !ok {
- return nil, tenantErr
+ return nil, errors.New("no tenant found in current db")
}
- clusters := security.DefaultTenantManager().GetClusters(tenant)
- var rows = make([]proto.Row, 0, len(clusters))
+ columns := thead.Database.ToFields()
+ ds := &dataset.VirtualDataset{
+ Columns: columns,
+ }
- for _, cluster := range clusters {
- encoded := mysql.PutLengthEncodedString([]byte(cluster))
- rows = append(rows, (&mysql.TextRow{}).Encode([]*proto.Value{
- {
- Typ: fieldType.FieldTypeVarString,
- Flags: fieldType.NotNullFlag,
- Raw: encoded,
- Val: cluster,
- Len: len(encoded),
- },
- },
- []proto.Field{&mysql.Field{}}, nil))
+ for _, cluster := range security.DefaultTenantManager().GetClusters(tenant) {
+ ds.Rows = append(ds.Rows, rows.NewTextVirtualRow(columns, []proto.Value{cluster}))
}
- return &mysql.Result{
- Fields: []proto.Field{mysql.NewField("Database", fieldType.FieldTypeVarString)},
- Rows: rows,
- DataChan: make(chan proto.Row, 1),
- }, nil
+ return resultx.New(resultx.WithDataset(ds)), nil
}
diff --git a/pkg/runtime/plan/dal/show_index.go b/pkg/runtime/plan/dal/show_index.go
new file mode 100644
index 000000000..088639746
--- /dev/null
+++ b/pkg/runtime/plan/dal/show_index.go
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package dal
+
+import (
+ "context"
+ "strings"
+)
+
+import (
+ "github.com/pkg/errors"
+)
+
+import (
+ "github.com/arana-db/arana/pkg/dataset"
+ "github.com/arana-db/arana/pkg/mysql/rows"
+ "github.com/arana-db/arana/pkg/proto"
+ "github.com/arana-db/arana/pkg/proto/rule"
+ "github.com/arana-db/arana/pkg/resultx"
+ "github.com/arana-db/arana/pkg/runtime/ast"
+ "github.com/arana-db/arana/pkg/runtime/plan"
+)
+
+var _ proto.Plan = (*ShowIndexPlan)(nil)
+
+type ShowIndexPlan struct {
+ plan.BasePlan
+ Stmt *ast.ShowIndex
+ Shards rule.DatabaseTables
+}
+
+func (s *ShowIndexPlan) Type() proto.PlanType {
+ return proto.PlanTypeQuery
+}
+
+func (s *ShowIndexPlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Result, error) {
+ var (
+ sb strings.Builder
+ indexes []int
+ err error
+ )
+
+ if s.Shards == nil {
+ if err = s.Stmt.Restore(ast.RestoreDefault, &sb, &indexes); err != nil {
+ return nil, errors.WithStack(err)
+ }
+ return conn.Query(ctx, "", sb.String(), s.ToArgs(indexes)...)
+ }
+
+ toTable := s.Stmt.TableName.Suffix()
+
+ db, table := s.Shards.Smallest()
+ s.Stmt.TableName = ast.TableName{table}
+
+ if err = s.Stmt.Restore(ast.RestoreDefault, &sb, &indexes); err != nil {
+ return nil, errors.WithStack(err)
+ }
+
+ query, err := conn.Query(ctx, db, sb.String(), s.ToArgs(indexes)...)
+ if err != nil {
+ return nil, errors.WithStack(err)
+ }
+
+ ds, err := query.Dataset()
+ if err != nil {
+ return nil, errors.WithStack(err)
+ }
+
+ fields, err := ds.Fields()
+ if err != nil {
+ return nil, errors.WithStack(err)
+ }
+
+ ds = dataset.Pipe(ds,
+ dataset.Map(nil, func(next proto.Row) (proto.Row, error) {
+ dest := make([]proto.Value, len(fields))
+ if next.Scan(dest) != nil {
+ return next, nil
+ }
+ dest[0] = toTable
+
+ if next.IsBinary() {
+ return rows.NewBinaryVirtualRow(fields, dest), nil
+ }
+ return rows.NewTextVirtualRow(fields, dest), nil
+ }),
+ )
+
+ return resultx.New(resultx.WithDataset(ds)), nil
+}
diff --git a/pkg/runtime/plan/dal/show_open_tables.go b/pkg/runtime/plan/dal/show_open_tables.go
new file mode 100644
index 000000000..2a4d2cfa2
--- /dev/null
+++ b/pkg/runtime/plan/dal/show_open_tables.go
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package dal
+
+import (
+ "context"
+ "strings"
+)
+
+import (
+ "github.com/pkg/errors"
+)
+
+import (
+ "github.com/arana-db/arana/pkg/dataset"
+ "github.com/arana-db/arana/pkg/mysql"
+ "github.com/arana-db/arana/pkg/mysql/rows"
+ "github.com/arana-db/arana/pkg/proto"
+ "github.com/arana-db/arana/pkg/resultx"
+ "github.com/arana-db/arana/pkg/runtime/ast"
+ "github.com/arana-db/arana/pkg/runtime/plan"
+)
+
+var _ proto.Plan = (*ShowOpenTablesPlan)(nil)
+
+type ShowOpenTablesPlan struct {
+ plan.BasePlan
+ Database string
+ Conn proto.DB
+ Stmt *ast.ShowOpenTables
+ invertedShards map[string]string // phy table name -> logical table name
+}
+
+// NewShowOpenTablesPlan create ShowTables Plan
+func NewShowOpenTablesPlan(stmt *ast.ShowOpenTables) *ShowOpenTablesPlan {
+ return &ShowOpenTablesPlan{
+ Stmt: stmt,
+ }
+}
+
+func (st *ShowOpenTablesPlan) Type() proto.PlanType {
+ return proto.PlanTypeQuery
+}
+
+func (st *ShowOpenTablesPlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Result, error) {
+ var (
+ sb strings.Builder
+ indexes []int
+ res proto.Result
+ err error
+ )
+
+ if err = st.Stmt.Restore(ast.RestoreDefault, &sb, &indexes); err != nil {
+ return nil, errors.WithStack(err)
+ }
+
+ var (
+ query = sb.String()
+ args = st.ToArgs(indexes)
+ )
+
+ if res, err = conn.Query(ctx, st.Database, query, args...); err != nil {
+ return nil, errors.WithStack(err)
+ }
+
+ ds, err := res.Dataset()
+ if err != nil {
+ return nil, errors.WithStack(err)
+ }
+
+ fields, _ := ds.Fields()
+
+ // filter duplicates
+ duplicates := make(map[string]struct{})
+
+ // 1. convert to logical table name
+ // 2. filter duplicated table name
+ ds = dataset.Pipe(ds,
+ dataset.Map(nil, func(next proto.Row) (proto.Row, error) {
+ dest := make([]proto.Value, len(fields))
+ if next.Scan(dest) != nil {
+ return next, nil
+ }
+
+ if logicalTableName, ok := st.invertedShards[dest[1].(string)]; ok {
+ dest[1] = logicalTableName
+ }
+
+ if next.IsBinary() {
+ return rows.NewBinaryVirtualRow(fields, dest), nil
+ }
+ return rows.NewTextVirtualRow(fields, dest), nil
+ }),
+ dataset.Filter(func(next proto.Row) bool {
+ var vr rows.VirtualRow
+ switch val := next.(type) {
+ case mysql.TextRow, mysql.BinaryRow:
+ return true
+ case rows.VirtualRow:
+ vr = val
+ default:
+ return true
+ }
+
+ tableName := vr.Values()[1].(string)
+ if _, ok := duplicates[tableName]; ok {
+ return false
+ }
+ duplicates[tableName] = struct{}{}
+ return true
+ }),
+ )
+ return resultx.New(resultx.WithDataset(ds)), nil
+}
+
+func (st *ShowOpenTablesPlan) SetDatabase(database string) {
+ st.Database = database
+}
+
+func (st *ShowOpenTablesPlan) SetInvertedShards(m map[string]string) {
+ st.invertedShards = m
+}
diff --git a/pkg/runtime/plan/dal/show_tables.go b/pkg/runtime/plan/dal/show_tables.go
new file mode 100644
index 000000000..6df54b164
--- /dev/null
+++ b/pkg/runtime/plan/dal/show_tables.go
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package dal
+
+import (
+ "context"
+ "database/sql"
+ "strings"
+)
+
+import (
+ "github.com/pkg/errors"
+)
+
+import (
+ constant "github.com/arana-db/arana/pkg/constants/mysql"
+ "github.com/arana-db/arana/pkg/dataset"
+ "github.com/arana-db/arana/pkg/mysql"
+ "github.com/arana-db/arana/pkg/mysql/rows"
+ "github.com/arana-db/arana/pkg/proto"
+ "github.com/arana-db/arana/pkg/resultx"
+ "github.com/arana-db/arana/pkg/runtime/ast"
+ rcontext "github.com/arana-db/arana/pkg/runtime/context"
+ "github.com/arana-db/arana/pkg/runtime/plan"
+)
+
+var (
+ _ proto.Plan = (*ShowTablesPlan)(nil)
+
+ headerPrefix = "Tables_in_"
+)
+
+type ShowTablesPlan struct {
+ plan.BasePlan
+ Database string
+ Stmt *ast.ShowTables
+ invertedShards map[string]string // phy table name -> logical table name
+}
+
+// NewShowTablesPlan create ShowTables Plan
+func NewShowTablesPlan(stmt *ast.ShowTables) *ShowTablesPlan {
+ return &ShowTablesPlan{
+ Stmt: stmt,
+ }
+}
+
+func (st *ShowTablesPlan) Type() proto.PlanType {
+ return proto.PlanTypeQuery
+}
+
+func (st *ShowTablesPlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Result, error) {
+ var (
+ sb strings.Builder
+ indexes []int
+ res proto.Result
+ err error
+ )
+ ctx, span := plan.Tracer.Start(ctx, "ShowTablesPlan.ExecIn")
+ defer span.End()
+
+ if err = st.Stmt.Restore(ast.RestoreDefault, &sb, &indexes); err != nil {
+ return nil, errors.WithStack(err)
+ }
+
+ var (
+ query = sb.String()
+ args = st.ToArgs(indexes)
+ )
+
+ if res, err = conn.Query(ctx, st.Database, query, args...); err != nil {
+ return nil, errors.WithStack(err)
+ }
+
+ ds, err := res.Dataset()
+ if err != nil {
+ return nil, errors.WithStack(err)
+ }
+
+ fields, _ := ds.Fields()
+
+ fields[0] = mysql.NewField(headerPrefix+rcontext.Schema(ctx), constant.FieldTypeVarString)
+
+ // filter duplicates
+ duplicates := make(map[string]struct{})
+
+ // 1. convert to logical table name
+ // 2. filter duplicated table name
+ ds = dataset.Pipe(ds,
+ dataset.Map(nil, func(next proto.Row) (proto.Row, error) {
+ dest := make([]proto.Value, len(fields))
+ if next.Scan(dest) != nil {
+ return next, nil
+ }
+ var tableName sql.NullString
+ _ = tableName.Scan(dest[0])
+ dest[0] = tableName.String
+
+ if logicalTableName, ok := st.invertedShards[tableName.String]; ok {
+ dest[0] = logicalTableName
+ }
+
+ if next.IsBinary() {
+ return rows.NewBinaryVirtualRow(fields, dest), nil
+ }
+ return rows.NewTextVirtualRow(fields, dest), nil
+ }),
+ dataset.Filter(func(next proto.Row) bool {
+ var vr rows.VirtualRow
+ switch val := next.(type) {
+ case mysql.TextRow, mysql.BinaryRow:
+ return true
+ case rows.VirtualRow:
+ vr = val
+ default:
+ return true
+ }
+
+ tableName := vr.Values()[0].(string)
+ if _, ok := duplicates[tableName]; ok {
+ return false
+ }
+ duplicates[tableName] = struct{}{}
+ return true
+ }),
+ )
+
+ return resultx.New(resultx.WithDataset(ds)), nil
+}
+
+func (st *ShowTablesPlan) SetDatabase(db string) {
+ st.Database = db
+}
+
+func (st *ShowTablesPlan) SetInvertedShards(m map[string]string) {
+ st.invertedShards = m
+}
diff --git a/pkg/runtime/plan/dal/show_topology.go b/pkg/runtime/plan/dal/show_topology.go
new file mode 100644
index 000000000..f1683f6cd
--- /dev/null
+++ b/pkg/runtime/plan/dal/show_topology.go
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package dal
+
+import (
+ "context"
+ "fmt"
+ "sort"
+ "strings"
+)
+
+import (
+ "github.com/pkg/errors"
+)
+
+import (
+ "github.com/arana-db/arana/pkg/dataset"
+ "github.com/arana-db/arana/pkg/mysql/rows"
+ "github.com/arana-db/arana/pkg/mysql/thead"
+ "github.com/arana-db/arana/pkg/proto"
+ "github.com/arana-db/arana/pkg/proto/rule"
+ "github.com/arana-db/arana/pkg/resultx"
+ "github.com/arana-db/arana/pkg/runtime/ast"
+ rcontext "github.com/arana-db/arana/pkg/runtime/context"
+ "github.com/arana-db/arana/pkg/runtime/plan"
+)
+
+var (
+ _ proto.Plan = (*ShowTopology)(nil)
+)
+
+type ShowTopology struct {
+ plan.BasePlan
+ Stmt *ast.ShowTopology
+ rule *rule.Rule
+}
+
+// NewShowTopologyPlan create ShowTopology Plan
+func NewShowTopologyPlan(stmt *ast.ShowTopology) *ShowTopology {
+ return &ShowTopology{
+ Stmt: stmt,
+ }
+}
+
+func (st *ShowTopology) Type() proto.PlanType {
+ return proto.PlanTypeQuery
+}
+
+func (st *ShowTopology) ExecIn(ctx context.Context, conn proto.VConn) (proto.Result, error) {
+ var (
+ sb strings.Builder
+ indexes []int
+ err error
+ table string
+ )
+ ctx, span := plan.Tracer.Start(ctx, "ShowTopology.ExecIn")
+ defer span.End()
+
+ if err = st.Stmt.Restore(ast.RestoreDefault, &sb, &indexes); err != nil {
+ return nil, errors.WithStack(err)
+ }
+
+ table = sb.String()
+
+ fields := thead.Topology.ToFields()
+
+ ds := &dataset.VirtualDataset{
+ Columns: fields,
+ }
+ vtable, ok := st.rule.VTable(table)
+ if !ok {
+ return nil, errors.New(fmt.Sprintf("%s do not have %s's topology", rcontext.Schema(ctx), table))
+ }
+ t := vtable.Topology()
+ t.Each(func(x, y int) bool {
+ if dbGroup, phyTable, ok := t.Render(x, y); ok {
+ ds.Rows = append(ds.Rows, rows.NewTextVirtualRow(fields, []proto.Value{
+ 0, dbGroup, phyTable,
+ }))
+ }
+ return true
+ })
+ sort.Slice(ds.Rows, func(i, j int) bool {
+ if ds.Rows[i].(rows.VirtualRow).Values()[1].(string) < ds.Rows[j].(rows.VirtualRow).Values()[1].(string) {
+ return true
+ }
+ return ds.Rows[i].(rows.VirtualRow).Values()[1].(string) == ds.Rows[j].(rows.VirtualRow).Values()[1].(string) &&
+ ds.Rows[i].(rows.VirtualRow).Values()[2].(string) < ds.Rows[j].(rows.VirtualRow).Values()[2].(string)
+ })
+
+ for id := 0; id < len(ds.Rows); id++ {
+ ds.Rows[id].(rows.VirtualRow).Values()[0] = id
+ }
+
+ return resultx.New(resultx.WithDataset(ds)), nil
+}
+
+func (st *ShowTopology) SetRule(rule *rule.Rule) {
+ st.rule = rule
+}
diff --git a/pkg/runtime/plan/variables.go b/pkg/runtime/plan/dal/show_variables.go
similarity index 88%
rename from pkg/runtime/plan/variables.go
rename to pkg/runtime/plan/dal/show_variables.go
index b603d9f92..d4bfe87b8 100644
--- a/pkg/runtime/plan/variables.go
+++ b/pkg/runtime/plan/dal/show_variables.go
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package plan
+package dal
import (
"context"
@@ -29,12 +29,13 @@ import (
import (
"github.com/arana-db/arana/pkg/proto"
"github.com/arana-db/arana/pkg/runtime/ast"
+ "github.com/arana-db/arana/pkg/runtime/plan"
)
var _ proto.Plan = (*ShowVariablesPlan)(nil)
type ShowVariablesPlan struct {
- basePlan
+ plan.BasePlan
stmt *ast.ShowVariables
}
@@ -47,17 +48,18 @@ func (s *ShowVariablesPlan) Type() proto.PlanType {
}
func (s *ShowVariablesPlan) ExecIn(ctx context.Context, vConn proto.VConn) (proto.Result, error) {
-
var (
sb strings.Builder
args []int
)
+ ctx, span := plan.Tracer.Start(ctx, "ShowVariablesPlan.ExecIn")
+ defer span.End()
if err := s.stmt.Restore(ast.RestoreDefault, &sb, &args); err != nil {
return nil, errors.Wrap(err, "failed to execute DELETE statement")
}
- ret, err := vConn.Query(ctx, "", sb.String(), s.toArgs(args)...)
+ ret, err := vConn.Query(ctx, "", sb.String(), s.ToArgs(args)...)
if err != nil {
return nil, err
}
diff --git a/pkg/runtime/plan/alter_table.go b/pkg/runtime/plan/ddl/alter_table.go
similarity index 87%
rename from pkg/runtime/plan/alter_table.go
rename to pkg/runtime/plan/ddl/alter_table.go
index c04cd5602..593e2278e 100644
--- a/pkg/runtime/plan/alter_table.go
+++ b/pkg/runtime/plan/ddl/alter_table.go
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package plan
+package ddl
import (
"context"
@@ -31,17 +31,18 @@ import (
)
import (
- "github.com/arana-db/arana/pkg/mysql"
"github.com/arana-db/arana/pkg/proto"
"github.com/arana-db/arana/pkg/proto/rule"
+ "github.com/arana-db/arana/pkg/resultx"
"github.com/arana-db/arana/pkg/runtime/ast"
+ "github.com/arana-db/arana/pkg/runtime/plan"
"github.com/arana-db/arana/pkg/util/log"
)
var _ proto.Plan = (*AlterTablePlan)(nil)
type AlterTablePlan struct {
- basePlan
+ plan.BasePlan
stmt *ast.AlterTableStatement
Shards rule.DatabaseTables
}
@@ -50,18 +51,19 @@ func NewAlterTablePlan(stmt *ast.AlterTableStatement) *AlterTablePlan {
return &AlterTablePlan{stmt: stmt}
}
-func (d *AlterTablePlan) Type() proto.PlanType {
+func (at *AlterTablePlan) Type() proto.PlanType {
return proto.PlanTypeExec
}
func (at *AlterTablePlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Result, error) {
+ // TODO: ADD trace in all plan ExecIn
if at.Shards == nil {
// non-sharding alter table
var sb strings.Builder
if err := at.stmt.Restore(ast.RestoreDefault, &sb, nil); err != nil {
return nil, err
}
- return conn.Exec(ctx, "", sb.String(), at.args...)
+ return conn.Exec(ctx, "", sb.String(), at.Args...)
}
var (
affects = uatomic.NewUint64(0)
@@ -93,7 +95,7 @@ func (at *AlterTablePlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.R
return errors.WithStack(err)
}
- if res, err = conn.Exec(ctx, db, sb.String(), at.toArgs(args)...); err != nil {
+ if res, err = conn.Exec(ctx, db, sb.String(), at.ToArgs(args)...); err != nil {
return errors.WithStack(err)
}
@@ -118,8 +120,5 @@ func (at *AlterTablePlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.R
log.Debugf("sharding alter table success: batch=%d, affects=%d", cnt.Load(), affects.Load())
- return &mysql.Result{
- AffectedRows: affects.Load(),
- DataChan: make(chan proto.Row, 1),
- }, nil
+ return resultx.New(resultx.WithRowsAffected(affects.Load())), nil
}
diff --git a/pkg/runtime/plan/ddl/create_index.go b/pkg/runtime/plan/ddl/create_index.go
new file mode 100644
index 000000000..c580c0bf6
--- /dev/null
+++ b/pkg/runtime/plan/ddl/create_index.go
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package ddl
+
+import (
+ "context"
+ "strings"
+)
+
+import (
+ "github.com/pkg/errors"
+)
+
+import (
+ "github.com/arana-db/arana/pkg/proto"
+ "github.com/arana-db/arana/pkg/proto/rule"
+ "github.com/arana-db/arana/pkg/resultx"
+ "github.com/arana-db/arana/pkg/runtime/ast"
+ "github.com/arana-db/arana/pkg/runtime/plan"
+)
+
+type CreateIndexPlan struct {
+ plan.BasePlan
+ stmt *ast.CreateIndexStatement
+ Shards rule.DatabaseTables
+}
+
+func NewCreateIndexPlan(stmt *ast.CreateIndexStatement) *CreateIndexPlan {
+ return &CreateIndexPlan{
+ stmt: stmt,
+ }
+}
+
+func (c *CreateIndexPlan) Type() proto.PlanType {
+ return proto.PlanTypeExec
+}
+
+func (c *CreateIndexPlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Result, error) {
+ var (
+ sb strings.Builder
+ args []int
+ )
+
+ for db, tables := range c.Shards {
+ for i := range tables {
+ table := tables[i]
+
+ stmt := new(ast.CreateIndexStatement)
+ stmt.Table = ast.TableName{table}
+ stmt.IndexName = c.stmt.IndexName
+ stmt.Keys = c.stmt.Keys
+
+ if err := stmt.Restore(ast.RestoreDefault, &sb, &args); err != nil {
+ return nil, err
+ }
+
+ if err := c.execOne(ctx, conn, db, sb.String(), c.ToArgs(args)); err != nil {
+ return nil, errors.WithStack(err)
+ }
+
+ sb.Reset()
+ }
+ }
+ return resultx.New(), nil
+}
+
+func (c *CreateIndexPlan) SetShard(shard rule.DatabaseTables) {
+ c.Shards = shard
+}
+
+func (c *CreateIndexPlan) execOne(ctx context.Context, conn proto.VConn, db, query string, args []interface{}) error {
+ res, err := conn.Exec(ctx, db, query, args...)
+ if err != nil {
+ return err
+ }
+ _, _ = res.Dataset()
+ return nil
+}
diff --git a/pkg/runtime/plan/ddl/drop_index.go b/pkg/runtime/plan/ddl/drop_index.go
new file mode 100644
index 000000000..c1cad0385
--- /dev/null
+++ b/pkg/runtime/plan/ddl/drop_index.go
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package ddl
+
+import (
+ "context"
+ "strings"
+)
+
+import (
+ "github.com/pkg/errors"
+)
+
+import (
+ "github.com/arana-db/arana/pkg/proto"
+ "github.com/arana-db/arana/pkg/proto/rule"
+ "github.com/arana-db/arana/pkg/resultx"
+ "github.com/arana-db/arana/pkg/runtime/ast"
+ "github.com/arana-db/arana/pkg/runtime/plan"
+)
+
+type DropIndexPlan struct {
+ plan.BasePlan
+ stmt *ast.DropIndexStatement
+ shard rule.DatabaseTables
+}
+
+func NewDropIndexPlan(stmt *ast.DropIndexStatement) *DropIndexPlan {
+ return &DropIndexPlan{
+ stmt: stmt,
+ }
+}
+
+func (d *DropIndexPlan) Type() proto.PlanType {
+ return proto.PlanTypeExec
+}
+
+func (d *DropIndexPlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Result, error) {
+ var (
+ sb strings.Builder
+ args []int
+ )
+
+ for db, tables := range d.shard {
+ for i := range tables {
+ table := tables[i]
+
+ stmt := new(ast.DropIndexStatement)
+ stmt.Table = ast.TableName{table}
+ stmt.IndexName = d.stmt.IndexName
+
+ if err := stmt.Restore(ast.RestoreDefault, &sb, &args); err != nil {
+ return nil, err
+ }
+
+ if err := d.execOne(ctx, conn, db, sb.String(), d.ToArgs(args)); err != nil {
+ return nil, errors.WithStack(err)
+ }
+
+ sb.Reset()
+ }
+ }
+
+ return resultx.New(), nil
+}
+
+func (d *DropIndexPlan) SetShard(shard rule.DatabaseTables) {
+ d.shard = shard
+}
+
+func (d *DropIndexPlan) execOne(ctx context.Context, conn proto.VConn, db, query string, args []interface{}) error {
+ res, err := conn.Exec(ctx, db, query, args...)
+ if err != nil {
+ return err
+ }
+ _, _ = res.Dataset()
+ return nil
+}
diff --git a/pkg/runtime/plan/drop_table.go b/pkg/runtime/plan/ddl/drop_table.go
similarity index 62%
rename from pkg/runtime/plan/drop_table.go
rename to pkg/runtime/plan/ddl/drop_table.go
index c724e9f85..256f5df82 100644
--- a/pkg/runtime/plan/drop_table.go
+++ b/pkg/runtime/plan/ddl/drop_table.go
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package plan
+package ddl
import (
"context"
@@ -23,14 +23,19 @@ import (
)
import (
- "github.com/arana-db/arana/pkg/mysql"
+ "github.com/pkg/errors"
+)
+
+import (
"github.com/arana-db/arana/pkg/proto"
"github.com/arana-db/arana/pkg/proto/rule"
+ "github.com/arana-db/arana/pkg/resultx"
"github.com/arana-db/arana/pkg/runtime/ast"
+ "github.com/arana-db/arana/pkg/runtime/plan"
)
type DropTablePlan struct {
- basePlan
+ plan.BasePlan
stmt *ast.DropTableStatement
shardsMap []rule.DatabaseTables
}
@@ -41,39 +46,51 @@ func NewDropTablePlan(stmt *ast.DropTableStatement) *DropTablePlan {
}
}
-func (d DropTablePlan) Type() proto.PlanType {
+func (d *DropTablePlan) Type() proto.PlanType {
return proto.PlanTypeExec
}
-func (d DropTablePlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Result, error) {
- var sb strings.Builder
- var args []int
+func (d *DropTablePlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Result, error) {
+ ctx, span := plan.Tracer.Start(ctx, "DropTablePlan.ExecIn")
+ defer span.End()
+ var (
+ sb strings.Builder
+ args []int
+ )
for _, shards := range d.shardsMap {
- var stmt = new(ast.DropTableStatement)
+ stmt := new(ast.DropTableStatement)
for db, tables := range shards {
for _, table := range tables {
-
stmt.Tables = append(stmt.Tables, &ast.TableName{
table,
})
}
err := stmt.Restore(ast.RestoreDefault, &sb, &args)
-
if err != nil {
return nil, err
}
- _, err = conn.Exec(ctx, db, sb.String(), d.toArgs(args)...)
- if err != nil {
- return nil, err
+
+ if err = d.execOne(ctx, conn, db, sb.String(), d.ToArgs(args)); err != nil {
+ return nil, errors.WithStack(err)
}
+
sb.Reset()
}
}
- return &mysql.Result{DataChan: make(chan proto.Row, 1)}, nil
+ return resultx.New(), nil
}
-func (s *DropTablePlan) SetShards(shardsMap []rule.DatabaseTables) {
- s.shardsMap = shardsMap
+func (d *DropTablePlan) SetShards(shardsMap []rule.DatabaseTables) {
+ d.shardsMap = shardsMap
+}
+
+func (d *DropTablePlan) execOne(ctx context.Context, conn proto.VConn, db, query string, args []interface{}) error {
+ res, err := conn.Exec(ctx, db, query, args...)
+ if err != nil {
+ return err
+ }
+ _, _ = res.Dataset()
+ return nil
}
diff --git a/pkg/runtime/plan/ddl/drop_trigger.go b/pkg/runtime/plan/ddl/drop_trigger.go
new file mode 100644
index 000000000..39e09d8f8
--- /dev/null
+++ b/pkg/runtime/plan/ddl/drop_trigger.go
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package ddl
+
+import (
+ "context"
+ "strings"
+)
+
+import (
+ "github.com/pkg/errors"
+)
+
+import (
+ "github.com/arana-db/arana/pkg/proto"
+ "github.com/arana-db/arana/pkg/proto/rule"
+ "github.com/arana-db/arana/pkg/resultx"
+ "github.com/arana-db/arana/pkg/runtime/ast"
+ "github.com/arana-db/arana/pkg/runtime/plan"
+)
+
+var _ proto.Plan = (*DropTriggerPlan)(nil)
+
+type DropTriggerPlan struct {
+ plan.BasePlan
+ Stmt *ast.DropTriggerStatement
+ Shards rule.DatabaseTables
+}
+
+func (d *DropTriggerPlan) Type() proto.PlanType {
+ return proto.PlanTypeExec
+}
+
+func (d *DropTriggerPlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Result, error) {
+ var (
+ sb strings.Builder
+ args []int
+ err error
+ )
+
+ for db := range d.Shards {
+ if err = d.Stmt.Restore(ast.RestoreDefault, &sb, &args); err != nil {
+ return nil, err
+ }
+
+ if err = d.execOne(ctx, conn, db, sb.String(), d.ToArgs(args)); err != nil {
+ return nil, errors.WithStack(err)
+ }
+
+ sb.Reset()
+ }
+
+ return resultx.New(), nil
+}
+
+func (d *DropTriggerPlan) execOne(ctx context.Context, conn proto.VConn, db, query string, args []interface{}) error {
+ res, err := conn.Exec(ctx, db, query, args...)
+ if err != nil {
+ return err
+ }
+ _, _ = res.Dataset()
+ return nil
+}
diff --git a/pkg/runtime/plan/truncate.go b/pkg/runtime/plan/ddl/truncate.go
similarity index 77%
rename from pkg/runtime/plan/truncate.go
rename to pkg/runtime/plan/ddl/truncate.go
index 6624b17b3..2464e51c5 100644
--- a/pkg/runtime/plan/truncate.go
+++ b/pkg/runtime/plan/ddl/truncate.go
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package plan
+package ddl
import (
"context"
@@ -27,16 +27,17 @@ import (
)
import (
- "github.com/arana-db/arana/pkg/mysql"
"github.com/arana-db/arana/pkg/proto"
"github.com/arana-db/arana/pkg/proto/rule"
+ "github.com/arana-db/arana/pkg/resultx"
"github.com/arana-db/arana/pkg/runtime/ast"
+ "github.com/arana-db/arana/pkg/runtime/plan"
)
var _ proto.Plan = (*TruncatePlan)(nil)
type TruncatePlan struct {
- basePlan
+ plan.BasePlan
stmt *ast.TruncateStatement
shards rule.DatabaseTables
}
@@ -51,8 +52,10 @@ func (s *TruncatePlan) Type() proto.PlanType {
}
func (s *TruncatePlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Result, error) {
+ ctx, span := plan.Tracer.Start(ctx, "TruncatePlan.ExecIn")
+ defer span.End()
if s.shards == nil || s.shards.IsEmpty() {
- return &mysql.Result{AffectedRows: 0}, nil
+ return resultx.New(), nil
}
var (
@@ -70,19 +73,28 @@ func (s *TruncatePlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Resu
return nil, errors.Wrap(err, "failed to execute TRUNCATE statement")
}
- _, err := conn.Exec(ctx, db, sb.String(), s.toArgs(args)...)
- if err != nil {
+ if err := s.execOne(ctx, conn, db, sb.String(), s.ToArgs(args)); err != nil {
return nil, errors.WithStack(err)
}
+
sb.Reset()
}
}
- return &mysql.Result{
- DataChan: make(chan proto.Row, 1),
- }, nil
+ return resultx.New(), nil
}
func (s *TruncatePlan) SetShards(shards rule.DatabaseTables) {
s.shards = shards
}
+
+func (s *TruncatePlan) execOne(ctx context.Context, conn proto.VConn, db, query string, args []interface{}) error {
+ res, err := conn.Exec(ctx, db, query, args...)
+ if err != nil {
+ return errors.WithStack(err)
+ }
+
+ defer resultx.Drain(res)
+
+ return nil
+}
diff --git a/pkg/runtime/plan/aggregate.go b/pkg/runtime/plan/dml/aggregate.go
similarity index 86%
rename from pkg/runtime/plan/aggregate.go
rename to pkg/runtime/plan/dml/aggregate.go
index 9127222b5..7dbde95ea 100644
--- a/pkg/runtime/plan/aggregate.go
+++ b/pkg/runtime/plan/dml/aggregate.go
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package plan
+package dml
import (
"context"
@@ -27,13 +27,14 @@ import (
import (
"github.com/arana-db/arana/pkg/proto"
+ "github.com/arana-db/arana/pkg/runtime/plan"
"github.com/arana-db/arana/pkg/transformer"
)
type AggregatePlan struct {
transformer.Combiner
AggrLoader *transformer.AggrLoader
- UnionPlan *UnionPlan
+ Plan proto.Plan
}
func (a *AggregatePlan) Type() proto.PlanType {
@@ -41,7 +42,9 @@ func (a *AggregatePlan) Type() proto.PlanType {
}
func (a *AggregatePlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Result, error) {
- res, err := a.UnionPlan.ExecIn(ctx, conn)
+ ctx, span := plan.Tracer.Start(ctx, "AggregatePlan.ExecIn")
+ defer span.End()
+ res, err := a.Plan.ExecIn(ctx, conn)
if err != nil {
return nil, errors.WithStack(err)
}
diff --git a/pkg/runtime/plan/dml/group_plan.go b/pkg/runtime/plan/dml/group_plan.go
new file mode 100644
index 000000000..fd5f6b2b8
--- /dev/null
+++ b/pkg/runtime/plan/dml/group_plan.go
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package dml
+
+import (
+ "context"
+)
+
+import (
+ "github.com/pkg/errors"
+)
+
+import (
+ "github.com/arana-db/arana/pkg/dataset"
+ "github.com/arana-db/arana/pkg/merge"
+ "github.com/arana-db/arana/pkg/proto"
+ "github.com/arana-db/arana/pkg/resultx"
+)
+
+// GroupPlan TODO now only support stmt which group by items equal with order by items, such as
+// `select uid, max(score) from student group by uid order by uid`
+type GroupPlan struct {
+ Plan proto.Plan
+ AggItems map[int]func() merge.Aggregator
+ OrderByItems []dataset.OrderByItem
+}
+
+func (g *GroupPlan) Type() proto.PlanType {
+ return proto.PlanTypeQuery
+}
+
+func (g *GroupPlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Result, error) {
+ res, err := g.Plan.ExecIn(ctx, conn)
+ if err != nil {
+ return nil, errors.WithStack(err)
+ }
+
+ ds, err := res.Dataset()
+ if err != nil {
+ return nil, errors.WithStack(err)
+ }
+ fields, err := ds.Fields()
+ if err != nil {
+ return nil, errors.WithStack(err)
+ }
+
+ return resultx.New(resultx.WithDataset(dataset.Pipe(ds, dataset.GroupReduce(
+ g.OrderByItems,
+ func(fields []proto.Field) []proto.Field {
+ return fields
+ },
+ func() dataset.Reducer {
+ return dataset.NewGroupReducer(g.AggItems, fields)
+ },
+ )))), nil
+}
diff --git a/pkg/runtime/plan/dml/insert_select.go b/pkg/runtime/plan/dml/insert_select.go
new file mode 100644
index 000000000..1c906f300
--- /dev/null
+++ b/pkg/runtime/plan/dml/insert_select.go
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package dml
+
+import (
+ "context"
+ "strings"
+)
+
+import (
+ "github.com/pkg/errors"
+)
+
+import (
+ "github.com/arana-db/arana/pkg/proto"
+ "github.com/arana-db/arana/pkg/resultx"
+ "github.com/arana-db/arana/pkg/runtime/ast"
+ "github.com/arana-db/arana/pkg/runtime/plan"
+)
+
+var _ proto.Plan = (*InsertSelectPlan)(nil)
+
+type InsertSelectPlan struct {
+ plan.BasePlan
+ Batch map[string]*ast.InsertSelectStatement
+}
+
+func NewInsertSelectPlan() *InsertSelectPlan {
+ return &InsertSelectPlan{
+ Batch: make(map[string]*ast.InsertSelectStatement),
+ }
+}
+
+func (sp *InsertSelectPlan) Type() proto.PlanType {
+ return proto.PlanTypeExec
+}
+
+func (sp *InsertSelectPlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Result, error) {
+ var (
+ affects uint64
+ lastInsertId uint64
+ )
+ // TODO: consider wrap a transaction if insert into multiple databases
+ // TODO: insert in parallel
+ for db, insert := range sp.Batch {
+ id, affected, err := sp.doInsert(ctx, conn, db, insert)
+ if err != nil {
+ return nil, err
+ }
+ affects += affected
+ if id > lastInsertId {
+ lastInsertId = id
+ }
+ }
+
+ return resultx.New(resultx.WithLastInsertID(lastInsertId), resultx.WithRowsAffected(affects)), nil
+}
+
+func (sp *InsertSelectPlan) doInsert(ctx context.Context, conn proto.VConn, db string, stmt *ast.InsertSelectStatement) (uint64, uint64, error) {
+ var (
+ sb strings.Builder
+ args []int
+ )
+
+ if err := stmt.Restore(ast.RestoreDefault, &sb, &args); err != nil {
+ return 0, 0, errors.Wrap(err, "cannot restore insert statement")
+ }
+ res, err := conn.Exec(ctx, db, sb.String(), sp.ToArgs(args)...)
+ if err != nil {
+ return 0, 0, errors.WithStack(err)
+ }
+
+ defer resultx.Drain(res)
+
+ id, err := res.LastInsertId()
+ if err != nil {
+ return 0, 0, errors.WithStack(err)
+ }
+
+ affected, err := res.RowsAffected()
+ if err != nil {
+ return 0, 0, errors.WithStack(err)
+ }
+
+ return id, affected, nil
+}
diff --git a/pkg/runtime/plan/dml/limit.go b/pkg/runtime/plan/dml/limit.go
new file mode 100644
index 000000000..d20a32882
--- /dev/null
+++ b/pkg/runtime/plan/dml/limit.go
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package dml
+
+import (
+ "context"
+)
+
+import (
+ "github.com/pkg/errors"
+)
+
+import (
+ "github.com/arana-db/arana/pkg/dataset"
+ "github.com/arana-db/arana/pkg/proto"
+ "github.com/arana-db/arana/pkg/resultx"
+)
+
+var _ proto.Plan = (*LimitPlan)(nil)
+
+type LimitPlan struct {
+ ParentPlan proto.Plan
+ OriginOffset int64
+ OverwriteLimit int64
+}
+
+func (limitPlan *LimitPlan) Type() proto.PlanType {
+ return proto.PlanTypeQuery
+}
+
+func (limitPlan *LimitPlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Result, error) {
+ if limitPlan.ParentPlan == nil {
+ return nil, errors.New("limitPlan: ParentPlan is nil")
+ }
+
+ res, err := limitPlan.ParentPlan.ExecIn(ctx, conn)
+ if err != nil {
+ return nil, errors.WithStack(err)
+ }
+
+ ds, err := res.Dataset()
+ if err != nil {
+ return nil, errors.WithStack(err)
+ }
+
+ var count int64
+ ds = dataset.Pipe(ds, dataset.Filter(func(next proto.Row) bool {
+ count++
+ if count < limitPlan.OriginOffset {
+ return false
+ }
+ if count > limitPlan.OriginOffset && count <= limitPlan.OverwriteLimit {
+ return true
+ }
+ return false
+ }))
+ return resultx.New(resultx.WithDataset(ds)), nil
+}
diff --git a/pkg/runtime/plan/dml/order.go b/pkg/runtime/plan/dml/order.go
new file mode 100644
index 000000000..ea722c32e
--- /dev/null
+++ b/pkg/runtime/plan/dml/order.go
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package dml
+
+import (
+ "context"
+)
+
+import (
+ "github.com/pkg/errors"
+)
+
+import (
+ "github.com/arana-db/arana/pkg/dataset"
+ "github.com/arana-db/arana/pkg/proto"
+ "github.com/arana-db/arana/pkg/resultx"
+)
+
+var _ proto.Plan = (*OrderPlan)(nil)
+
+type OrderPlan struct {
+ ParentPlan proto.Plan
+ OrderByItems []dataset.OrderByItem
+}
+
+func (orderPlan *OrderPlan) Type() proto.PlanType {
+ return proto.PlanTypeQuery
+}
+
+func (orderPlan *OrderPlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Result, error) {
+ if orderPlan.ParentPlan == nil {
+ return nil, errors.New("order plan: ParentPlan is nil")
+ }
+
+ res, err := orderPlan.ParentPlan.ExecIn(ctx, conn)
+ if err != nil {
+ return nil, errors.WithStack(err)
+ }
+
+ ds, err := res.Dataset()
+ if err != nil {
+ return nil, errors.WithStack(err)
+ }
+
+ fuseable, ok := ds.(*dataset.FuseableDataset)
+
+ if !ok {
+ return nil, errors.New("order plan convert Dataset to FuseableDataset cause error")
+ }
+
+ orderedDataset := dataset.NewOrderedDataset(fuseable.ToParallel(), orderPlan.OrderByItems)
+
+ return resultx.New(resultx.WithDataset(orderedDataset)), nil
+}
diff --git a/pkg/runtime/plan/simple_delete.go b/pkg/runtime/plan/dml/simple_delete.go
similarity index 76%
rename from pkg/runtime/plan/simple_delete.go
rename to pkg/runtime/plan/dml/simple_delete.go
index 39ebfe979..6044b642c 100644
--- a/pkg/runtime/plan/simple_delete.go
+++ b/pkg/runtime/plan/dml/simple_delete.go
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package plan
+package dml
import (
"context"
@@ -27,17 +27,18 @@ import (
)
import (
- "github.com/arana-db/arana/pkg/mysql"
"github.com/arana-db/arana/pkg/proto"
"github.com/arana-db/arana/pkg/proto/rule"
+ "github.com/arana-db/arana/pkg/resultx"
"github.com/arana-db/arana/pkg/runtime/ast"
+ "github.com/arana-db/arana/pkg/runtime/plan"
)
var _ proto.Plan = (*SimpleDeletePlan)(nil)
// SimpleDeletePlan represents a simple delete plan for sharding table.
type SimpleDeletePlan struct {
- basePlan
+ plan.BasePlan
stmt *ast.DeleteStatement
shards rule.DatabaseTables
}
@@ -52,8 +53,10 @@ func (s *SimpleDeletePlan) Type() proto.PlanType {
}
func (s *SimpleDeletePlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Result, error) {
+ ctx, span := plan.Tracer.Start(ctx, "SimpleDeletePlan.ExecIn")
+ defer span.End()
if s.shards == nil || s.shards.IsEmpty() {
- return &mysql.Result{AffectedRows: 0}, nil
+ return resultx.New(), nil
}
var (
@@ -77,12 +80,11 @@ func (s *SimpleDeletePlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.
return nil, errors.Wrap(err, "failed to execute DELETE statement")
}
- res, err := conn.Exec(ctx, db, sb.String(), s.toArgs(args)...)
+ n, err := s.execOne(ctx, conn, db, sb.String(), s.ToArgs(args))
if err != nil {
return nil, errors.WithStack(err)
}
- n, _ := res.RowsAffected()
affects += n
// cleanup
@@ -93,12 +95,23 @@ func (s *SimpleDeletePlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.
}
}
- return &mysql.Result{
- AffectedRows: affects,
- DataChan: make(chan proto.Row, 1),
- }, nil
+ return resultx.New(resultx.WithRowsAffected(affects)), nil
}
func (s *SimpleDeletePlan) SetShards(shards rule.DatabaseTables) {
s.shards = shards
}
+
+func (s *SimpleDeletePlan) execOne(ctx context.Context, conn proto.VConn, db, query string, args []interface{}) (uint64, error) {
+ res, err := conn.Exec(ctx, db, query, args...)
+ if err != nil {
+ return 0, errors.WithStack(err)
+ }
+ defer resultx.Drain(res)
+
+ var n uint64
+ if n, err = res.RowsAffected(); err != nil {
+ return 0, errors.WithStack(err)
+ }
+ return n, nil
+}
diff --git a/pkg/runtime/plan/simple_insert.go b/pkg/runtime/plan/dml/simple_insert.go
similarity index 68%
rename from pkg/runtime/plan/simple_insert.go
rename to pkg/runtime/plan/dml/simple_insert.go
index d99698cca..8cbd524cb 100644
--- a/pkg/runtime/plan/simple_insert.go
+++ b/pkg/runtime/plan/dml/simple_insert.go
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package plan
+package dml
import (
"context"
@@ -27,15 +27,16 @@ import (
)
import (
- "github.com/arana-db/arana/pkg/mysql"
"github.com/arana-db/arana/pkg/proto"
+ "github.com/arana-db/arana/pkg/resultx"
"github.com/arana-db/arana/pkg/runtime/ast"
+ "github.com/arana-db/arana/pkg/runtime/plan"
)
var _ proto.Plan = (*SimpleInsertPlan)(nil)
type SimpleInsertPlan struct {
- basePlan
+ plan.BasePlan
batch map[string][]*ast.InsertStatement // key=db
}
@@ -55,41 +56,54 @@ func (sp *SimpleInsertPlan) Put(db string, stmt *ast.InsertStatement) {
func (sp *SimpleInsertPlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Result, error) {
var (
- effected uint64
+ affects uint64
lastInsertId uint64
)
+ ctx, span := plan.Tracer.Start(ctx, "SimpleInsertPlan.ExecIn")
+ defer span.End()
// TODO: consider wrap a transaction if insert into multiple databases
// TODO: insert in parallel
for db, inserts := range sp.batch {
for _, insert := range inserts {
- res, err := sp.doInsert(ctx, conn, db, insert)
+ id, affected, err := sp.doInsert(ctx, conn, db, insert)
if err != nil {
return nil, err
}
- if n, _ := res.RowsAffected(); n > 0 {
- effected += n
- }
- if id, _ := res.LastInsertId(); id > lastInsertId {
+ affects += affected
+ if id > lastInsertId {
lastInsertId = id
}
}
}
- return &mysql.Result{
- AffectedRows: effected,
- InsertId: lastInsertId,
- DataChan: make(chan proto.Row, 1),
- }, nil
+ return resultx.New(resultx.WithLastInsertID(lastInsertId), resultx.WithRowsAffected(affects)), nil
}
-func (sp *SimpleInsertPlan) doInsert(ctx context.Context, conn proto.VConn, db string, stmt *ast.InsertStatement) (proto.Result, error) {
+func (sp *SimpleInsertPlan) doInsert(ctx context.Context, conn proto.VConn, db string, stmt *ast.InsertStatement) (uint64, uint64, error) {
var (
sb strings.Builder
args []int
)
if err := stmt.Restore(ast.RestoreDefault, &sb, &args); err != nil {
- return nil, errors.Wrap(err, "cannot restore insert statement")
+ return 0, 0, errors.Wrap(err, "cannot restore insert statement")
+ }
+ res, err := conn.Exec(ctx, db, sb.String(), sp.ToArgs(args)...)
+ if err != nil {
+ return 0, 0, errors.WithStack(err)
+ }
+
+ defer resultx.Drain(res)
+
+ id, err := res.LastInsertId()
+ if err != nil {
+ return 0, 0, errors.WithStack(err)
+ }
+
+ affected, err := res.RowsAffected()
+ if err != nil {
+ return 0, 0, errors.WithStack(err)
}
- return conn.Exec(ctx, db, sb.String(), sp.toArgs(args)...)
+
+ return id, affected, nil
}
diff --git a/pkg/runtime/plan/simple_join.go b/pkg/runtime/plan/dml/simple_join.go
similarity index 93%
rename from pkg/runtime/plan/simple_join.go
rename to pkg/runtime/plan/dml/simple_join.go
index 8355cdbd1..7ab17a232 100644
--- a/pkg/runtime/plan/simple_join.go
+++ b/pkg/runtime/plan/dml/simple_join.go
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package plan
+package dml
import (
"context"
@@ -29,6 +29,7 @@ import (
import (
"github.com/arana-db/arana/pkg/proto"
"github.com/arana-db/arana/pkg/runtime/ast"
+ "github.com/arana-db/arana/pkg/runtime/plan"
)
type JoinTable struct {
@@ -37,7 +38,7 @@ type JoinTable struct {
}
type SimpleJoinPlan struct {
- basePlan
+ plan.BasePlan
Database string
Left *JoinTable
Join *ast.JoinNode
@@ -57,20 +58,23 @@ func (s *SimpleJoinPlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Re
err error
)
+ ctx, span := plan.Tracer.Start(ctx, "SimpleJoinPlan.ExecIn")
+ defer span.End()
+
if err := s.generateSelect(&sb, &indexes); err != nil {
return nil, err
}
sb.WriteString(" FROM ")
- //add left part
+ // add left part
if err := s.generateTable(s.Left.Tables, s.Left.Alias, &sb); err != nil {
return nil, err
}
s.generateJoinType(&sb)
- //add right part
+ // add right part
if err := s.generateTable(s.Right.Tables, s.Right.Alias, &sb); err != nil {
return nil, err
}
@@ -90,7 +94,7 @@ func (s *SimpleJoinPlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Re
var (
query = sb.String()
- args = s.toArgs(indexes)
+ args = s.ToArgs(indexes)
)
if res, err = conn.Query(ctx, s.Database, query, args...); err != nil {
@@ -121,7 +125,6 @@ func (s *SimpleJoinPlan) generateSelect(sb *strings.Builder, args *[]int) error
}
func (s *SimpleJoinPlan) generateTable(tables []string, alias string, sb *strings.Builder) error {
-
if len(tables) == 1 {
sb.WriteString(tables[0] + " ")
@@ -154,7 +157,7 @@ func (s *SimpleJoinPlan) generateTable(tables []string, alias string, sb *string
}
func (s *SimpleJoinPlan) generateJoinType(sb *strings.Builder) {
- //add join type
+ // add join type
switch s.Join.Typ {
case ast.LeftJoin:
sb.WriteString("LEFT")
diff --git a/pkg/runtime/plan/simple_select.go b/pkg/runtime/plan/dml/simple_select.go
similarity index 87%
rename from pkg/runtime/plan/simple_select.go
rename to pkg/runtime/plan/dml/simple_select.go
index 2fdf30021..4847d99f4 100644
--- a/pkg/runtime/plan/simple_select.go
+++ b/pkg/runtime/plan/dml/simple_select.go
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package plan
+package dml
import (
"context"
@@ -27,15 +27,18 @@ import (
)
import (
+ "github.com/arana-db/arana/pkg/dataset"
"github.com/arana-db/arana/pkg/mysql"
"github.com/arana-db/arana/pkg/proto"
+ "github.com/arana-db/arana/pkg/resultx"
"github.com/arana-db/arana/pkg/runtime/ast"
+ "github.com/arana-db/arana/pkg/runtime/plan"
)
var _ proto.Plan = (*SimpleQueryPlan)(nil)
type SimpleQueryPlan struct {
- basePlan
+ plan.BasePlan
Database string
Tables []string
Stmt *ast.SelectStatement
@@ -53,11 +56,10 @@ func (s *SimpleQueryPlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.R
err error
)
- if s.filter() {
- return &mysql.Result{
- DataChan: make(chan proto.Row, 1),
- }, nil
- }
+ ctx, span := plan.Tracer.Start(ctx, "SimpleQueryPlan.ExecIn")
+ defer span.End()
+
+ discard := s.filter()
if err = s.generate(&sb, &indexes); err != nil {
return nil, errors.Wrap(err, "failed to generate sql")
@@ -65,13 +67,34 @@ func (s *SimpleQueryPlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.R
var (
query = sb.String()
- args = s.toArgs(indexes)
+ args = s.ToArgs(indexes)
)
if res, err = conn.Query(ctx, s.Database, query, args...); err != nil {
return nil, errors.WithStack(err)
}
- return res, nil
+
+ if !discard {
+ return res, nil
+ }
+
+ var (
+ rr = res.(*mysql.RawResult)
+ fields []proto.Field
+ )
+
+ defer func() {
+ _ = rr.Discard()
+ }()
+
+ if fields, err = rr.Fields(); err != nil {
+ return nil, errors.WithStack(err)
+ }
+
+ emptyDs := &dataset.VirtualDataset{
+ Columns: fields,
+ }
+ return resultx.New(resultx.WithDataset(emptyDs)), nil
}
func (s *SimpleQueryPlan) filter() bool {
@@ -132,9 +155,7 @@ func (s *SimpleQueryPlan) generate(sb *strings.Builder, args *[]int) error {
// UNION ALL
// (SELECT * FROM student_0000 WHERE uid IN (1,2,3)
- var (
- stmt = new(ast.SelectStatement)
- )
+ stmt := new(ast.SelectStatement)
*stmt = *s.Stmt // do copy
restore := func(table string) error {
@@ -171,9 +192,7 @@ func (s *SimpleQueryPlan) generate(sb *strings.Builder, args *[]int) error {
}
func (s *SimpleQueryPlan) resetOrderBy(tgt *ast.SelectStatement, sb *strings.Builder, args *[]int) error {
- var (
- builder strings.Builder
- )
+ var builder strings.Builder
builder.WriteString("SELECT * FROM (")
builder.WriteString(sb.String())
builder.WriteString(") ")
diff --git a/pkg/runtime/plan/dml/union.go b/pkg/runtime/plan/dml/union.go
new file mode 100644
index 000000000..d2e182953
--- /dev/null
+++ b/pkg/runtime/plan/dml/union.go
@@ -0,0 +1,166 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package dml
+
+import (
+ "context"
+ "fmt"
+ "io"
+)
+
+import (
+ "github.com/pkg/errors"
+)
+
+import (
+ "github.com/arana-db/arana/pkg/dataset"
+ "github.com/arana-db/arana/pkg/proto"
+ "github.com/arana-db/arana/pkg/resultx"
+ "github.com/arana-db/arana/pkg/runtime/plan"
+ "github.com/arana-db/arana/pkg/util/log"
+)
+
+// UnionPlan merges multiple query plan.
+type UnionPlan struct {
+ Plans []proto.Plan
+}
+
+func (u UnionPlan) Type() proto.PlanType {
+ return proto.PlanTypeQuery
+}
+
+func (u UnionPlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Result, error) {
+ ctx, span := plan.Tracer.Start(ctx, "UnionPlan.ExecIn")
+ defer span.End()
+ switch u.Plans[0].Type() {
+ case proto.PlanTypeQuery:
+ return u.query(ctx, conn)
+ case proto.PlanTypeExec:
+ return u.exec(ctx, conn)
+ default:
+ panic("unreachable")
+ }
+}
+
+func (u UnionPlan) showOpenTables(ctx context.Context, conn proto.VConn) (proto.Result, error) {
+ var (
+ fields []proto.Field
+ rows []proto.Row
+ filterMap = make(map[string]struct{}, 0) // map[database-table]
+ )
+ for _, it := range u.Plans {
+ it := it
+ res, err := it.ExecIn(ctx, conn)
+ if err != nil {
+ return nil, errors.WithStack(err)
+ }
+ ds, err := res.Dataset()
+ if err != nil {
+ return nil, errors.WithStack(err)
+ }
+ fields, err = ds.Fields()
+ if err != nil {
+ return nil, errors.WithStack(err)
+ }
+ var row proto.Row
+
+ for {
+ row, err = ds.Next()
+ if errors.Is(err, io.EOF) {
+ break
+ }
+
+ if err != nil {
+ return nil, errors.WithStack(err)
+ }
+ var values []proto.Value
+ if err = row.Scan(values); err != nil {
+ return nil, errors.WithStack(err)
+ }
+ // Database Table In_use Name_locked
+ key := fmt.Sprintf("%s-%s", values[0].(string), values[1].(string))
+ if _, ok := filterMap[key]; ok {
+ continue
+ }
+ filterMap[key] = struct{}{}
+ rows = append(rows, row)
+ }
+ }
+ ds := &dataset.VirtualDataset{
+ Columns: fields,
+ Rows: rows,
+ }
+
+ return resultx.New(resultx.WithDataset(ds)), nil
+}
+
+func (u UnionPlan) query(ctx context.Context, conn proto.VConn) (proto.Result, error) {
+ var generators []dataset.GenerateFunc
+ for _, it := range u.Plans {
+ it := it
+ generators = append(generators, func() (proto.Dataset, error) {
+ res, err := it.ExecIn(ctx, conn)
+ if err != nil {
+ return nil, errors.WithStack(err)
+ }
+ return res.Dataset()
+ })
+ }
+
+ ds, err := dataset.Fuse(generators[0], generators[1:]...)
+ if err != nil {
+ log.Errorf("UnionPlan Fuse error:%v", err)
+ return nil, err
+ }
+ return resultx.New(resultx.WithDataset(ds)), nil
+}
+
+func (u UnionPlan) exec(ctx context.Context, conn proto.VConn) (proto.Result, error) {
+ var id, affects uint64
+ for _, it := range u.Plans {
+ i, n, err := u.execOne(ctx, conn, it)
+ if err != nil {
+ return nil, errors.WithStack(err)
+ }
+ affects += n
+ id += i
+ }
+
+ return resultx.New(resultx.WithLastInsertID(id), resultx.WithRowsAffected(affects)), nil
+}
+
+func (u UnionPlan) execOne(ctx context.Context, conn proto.VConn, p proto.Plan) (uint64, uint64, error) {
+ res, err := p.ExecIn(ctx, conn)
+ if err != nil {
+ return 0, 0, errors.WithStack(err)
+ }
+
+ defer resultx.Drain(res)
+
+ id, err := res.LastInsertId()
+ if err != nil {
+ return 0, 0, errors.WithStack(err)
+ }
+
+ affected, err := res.RowsAffected()
+ if err != nil {
+ return 0, 0, errors.WithStack(err)
+ }
+
+ return id, affected, nil
+}
diff --git a/pkg/runtime/plan/update.go b/pkg/runtime/plan/dml/update.go
similarity index 78%
rename from pkg/runtime/plan/update.go
rename to pkg/runtime/plan/dml/update.go
index 3d8f65fbd..633c42396 100644
--- a/pkg/runtime/plan/update.go
+++ b/pkg/runtime/plan/dml/update.go
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package plan
+package dml
import (
"context"
@@ -31,10 +31,11 @@ import (
)
import (
- "github.com/arana-db/arana/pkg/mysql"
"github.com/arana-db/arana/pkg/proto"
"github.com/arana-db/arana/pkg/proto/rule"
+ "github.com/arana-db/arana/pkg/resultx"
"github.com/arana-db/arana/pkg/runtime/ast"
+ "github.com/arana-db/arana/pkg/runtime/plan"
"github.com/arana-db/arana/pkg/util/log"
)
@@ -42,7 +43,7 @@ var _ proto.Plan = (*UpdatePlan)(nil)
// UpdatePlan represents a plan to execute sharding-update.
type UpdatePlan struct {
- basePlan
+ plan.BasePlan
stmt *ast.UpdateStatement
shards rule.DatabaseTables
}
@@ -57,12 +58,14 @@ func (up *UpdatePlan) Type() proto.PlanType {
}
func (up *UpdatePlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Result, error) {
+ ctx, span := plan.Tracer.Start(ctx, "UpdatePlan.ExecIn")
+ defer span.End()
if up.shards == nil {
var sb strings.Builder
if err := up.stmt.Restore(ast.RestoreDefault, &sb, nil); err != nil {
return nil, err
}
- return conn.Exec(ctx, "", sb.String(), up.args...)
+ return conn.Exec(ctx, "", sb.String(), up.Args...)
}
var (
@@ -84,8 +87,8 @@ func (up *UpdatePlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Resul
var (
sb strings.Builder
args []int
- res proto.Result
err error
+ n uint64
)
sb.Grow(256)
@@ -95,11 +98,10 @@ func (up *UpdatePlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Resul
return errors.WithStack(err)
}
- if res, err = conn.Exec(ctx, db, sb.String(), up.toArgs(args)...); err != nil {
+ if n, err = up.execOne(ctx, conn, db, sb.String(), up.ToArgs(args)); err != nil {
return errors.WithStack(err)
}
- n, _ := res.RowsAffected()
affects.Add(n)
cnt.Inc()
@@ -120,12 +122,25 @@ func (up *UpdatePlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Resul
log.Debugf("sharding update success: batch=%d, affects=%d", cnt.Load(), affects.Load())
- return &mysql.Result{
- AffectedRows: affects.Load(),
- DataChan: make(chan proto.Row, 1),
- }, nil
+ return resultx.New(resultx.WithRowsAffected(affects.Load())), nil
}
func (up *UpdatePlan) SetShards(shards rule.DatabaseTables) {
up.shards = shards
}
+
+func (up *UpdatePlan) execOne(ctx context.Context, conn proto.VConn, db, query string, args []interface{}) (uint64, error) {
+ res, err := conn.Exec(ctx, db, query, args...)
+ if err != nil {
+ return 0, errors.WithStack(err)
+ }
+
+ defer resultx.Drain(res)
+
+ n, err := res.RowsAffected()
+ if err != nil {
+ return 0, errors.WithStack(err)
+ }
+
+ return n, nil
+}
diff --git a/pkg/runtime/plan/plan.go b/pkg/runtime/plan/plan.go
index 9bff2732d..aab88102c 100644
--- a/pkg/runtime/plan/plan.go
+++ b/pkg/runtime/plan/plan.go
@@ -17,21 +17,27 @@
package plan
-type basePlan struct {
- args []interface{}
+import (
+ "go.opentelemetry.io/otel"
+)
+
+var Tracer = otel.Tracer("ExecPlan")
+
+type BasePlan struct {
+ Args []interface{}
}
-func (bp *basePlan) BindArgs(args []interface{}) {
- bp.args = args
+func (bp *BasePlan) BindArgs(args []interface{}) {
+ bp.Args = args
}
-func (bp basePlan) toArgs(indexes []int) []interface{} {
- if len(indexes) < 1 || len(bp.args) < 1 {
+func (bp *BasePlan) ToArgs(indexes []int) []interface{} {
+ if len(indexes) < 1 || len(bp.Args) < 1 {
return nil
}
ret := make([]interface{}, 0, len(indexes))
for _, idx := range indexes {
- ret = append(ret, bp.args[idx])
+ ret = append(ret, bp.Args[idx])
}
return ret
}
diff --git a/pkg/runtime/plan/show_tables.go b/pkg/runtime/plan/show_tables.go
deleted file mode 100644
index b335a0da2..000000000
--- a/pkg/runtime/plan/show_tables.go
+++ /dev/null
@@ -1,145 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package plan
-
-import (
- "context"
- "io"
- "strings"
-)
-
-import (
- "github.com/pkg/errors"
-)
-
-import (
- "github.com/arana-db/arana/pkg/mysql"
- "github.com/arana-db/arana/pkg/proto"
- "github.com/arana-db/arana/pkg/runtime/ast"
-)
-
-var _ proto.Plan = (*ShowTablesPlan)(nil)
-
-type DatabaseTable struct {
- Database string
- TableName string
-}
-
-type ShowTablesPlan struct {
- basePlan
- Database string
- Stmt *ast.ShowTables
- allShards map[string]DatabaseTable
-}
-
-// NewShowTablesPlan create ShowTables Plan
-func NewShowTablesPlan(stmt *ast.ShowTables) *ShowTablesPlan {
- return &ShowTablesPlan{
- Stmt: stmt,
- }
-}
-
-func (s *ShowTablesPlan) Type() proto.PlanType {
- return proto.PlanTypeQuery
-}
-
-func (s *ShowTablesPlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Result, error) {
- var (
- sb strings.Builder
- indexes []int
- res proto.Result
- err error
- )
-
- if err = s.Stmt.Restore(ast.RestoreDefault, &sb, &indexes); err != nil {
- return nil, errors.WithStack(err)
- }
-
- var (
- query = sb.String()
- args = s.toArgs(indexes)
- )
-
- if res, err = conn.Query(ctx, s.Database, query, args...); err != nil {
- return nil, errors.WithStack(err)
- }
-
- if closer, ok := res.(io.Closer); ok {
- defer func() {
- _ = closer.Close()
- }()
- }
-
- rebuildResult := mysql.Result{
- Fields: res.GetFields(),
- DataChan: make(chan proto.Row, 1),
- }
- hasRebuildTables := make(map[string]struct{})
- var affectRows uint64
- if len(res.GetRows()) > 0 {
- row := res.GetRows()[0]
- rowIter := row.(*mysql.TextIterRow)
- for has, err := rowIter.Next(); has && err == nil; has, err = rowIter.Next() {
- rowValues, err := row.Decode()
- if err != nil {
- return nil, err
- }
- tableName := s.convertInterfaceToStrNullable(rowValues[0].Val)
- if databaseTable, exist := s.allShards[tableName]; exist {
- if _, ok := hasRebuildTables[databaseTable.TableName]; ok {
- continue
- }
-
- if _, ok := hasRebuildTables[databaseTable.TableName]; ok {
- continue
- }
-
- encodeTableName := mysql.PutLengthEncodedString([]byte(databaseTable.TableName))
- tmpValues := rowValues
- for idx := range tmpValues {
- tmpValues[idx].Val = string(encodeTableName)
- tmpValues[idx].Raw = encodeTableName
- tmpValues[idx].Len = len(encodeTableName)
- }
-
- var tmpNewRow mysql.TextRow
- tmpNewRow.Encode(tmpValues, row.Fields(), row.Columns())
- rebuildResult.Rows = append(rebuildResult.Rows, &tmpNewRow)
- hasRebuildTables[databaseTable.TableName] = struct{}{}
- affectRows++
- continue
- }
- affectRows++
- textRow := &mysql.TextRow{Row: *rowIter.Row}
- rebuildResult.Rows = append(rebuildResult.Rows, textRow)
- }
- }
- rebuildResult.AffectedRows = affectRows
- return &rebuildResult, nil
-}
-
-func (s *ShowTablesPlan) convertInterfaceToStrNullable(value interface{}) string {
- if value != nil {
- return string(value.([]byte))
- }
- return ""
-}
-
-func (s *ShowTablesPlan) SetAllShards(allShards map[string]DatabaseTable) {
- s.allShards = allShards
-}
diff --git a/pkg/runtime/plan/transparent.go b/pkg/runtime/plan/transparent.go
index a4b67a360..ff9339239 100644
--- a/pkg/runtime/plan/transparent.go
+++ b/pkg/runtime/plan/transparent.go
@@ -35,7 +35,7 @@ var _ proto.Plan = (*TransparentPlan)(nil)
// TransparentPlan represents a transparent plan.
type TransparentPlan struct {
- basePlan
+ BasePlan
stmt rast.Statement
db string
typ proto.PlanType
@@ -45,7 +45,8 @@ type TransparentPlan struct {
func Transparent(stmt rast.Statement, args []interface{}) *TransparentPlan {
var typ proto.PlanType
switch stmt.Mode() {
- case rast.Sinsert, rast.Sdelete, rast.Sreplace, rast.Supdate, rast.Struncate, rast.SdropTable, rast.SalterTable:
+ case rast.SQLTypeInsert, rast.SQLTypeDelete, rast.SQLTypeReplace, rast.SQLTypeUpdate, rast.SQLTypeTruncate, rast.SQLTypeDropTable,
+ rast.SQLTypeAlterTable, rast.SQLTypeDropIndex, rast.SQLTypeCreateIndex:
typ = proto.PlanTypeExec
default:
typ = proto.PlanTypeQuery
@@ -76,6 +77,8 @@ func (tp *TransparentPlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.
args []int
err error
)
+ ctx, span := Tracer.Start(ctx, "TransparentPlan.ExecIn")
+ defer span.End()
if err = tp.stmt.Restore(rast.RestoreDefault, &sb, &args); err != nil {
return nil, errors.WithStack(err)
@@ -83,9 +86,9 @@ func (tp *TransparentPlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.
switch tp.typ {
case proto.PlanTypeQuery:
- return conn.Query(ctx, tp.db, sb.String(), tp.toArgs(args)...)
+ return conn.Query(ctx, tp.db, sb.String(), tp.ToArgs(args)...)
case proto.PlanTypeExec:
- return conn.Exec(ctx, tp.db, sb.String(), tp.toArgs(args)...)
+ return conn.Exec(ctx, tp.db, sb.String(), tp.ToArgs(args)...)
default:
panic("unreachable")
}
diff --git a/pkg/runtime/plan/union.go b/pkg/runtime/plan/union.go
deleted file mode 100644
index ebe649d3b..000000000
--- a/pkg/runtime/plan/union.go
+++ /dev/null
@@ -1,94 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package plan
-
-import (
- "context"
- "io"
-)
-
-import (
- "github.com/pkg/errors"
-
- "go.uber.org/multierr"
-)
-
-import (
- "github.com/arana-db/arana/pkg/proto"
-)
-
-// UnionPlan merges multiple query plan.
-type UnionPlan struct {
- Plans []proto.Plan
-}
-
-func (u UnionPlan) Type() proto.PlanType {
- return proto.PlanTypeQuery
-}
-
-func (u UnionPlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Result, error) {
- var results []proto.Result
- for _, it := range u.Plans {
- res, err := it.ExecIn(ctx, conn)
- if err != nil {
- return nil, errors.WithStack(err)
- }
- results = append(results, res)
- }
-
- return compositeResult(results), nil
-}
-
-type compositeResult []proto.Result
-
-func (c compositeResult) Close() error {
- var errs []error
- for _, it := range c {
- if closer, ok := it.(io.Closer); ok {
- if err := closer.Close(); err != nil {
- errs = append(errs, err)
- }
- }
- }
- return multierr.Combine(errs...)
-}
-
-func (c compositeResult) GetFields() []proto.Field {
- for _, it := range c {
- if ret := it.GetFields(); ret != nil {
- return ret
- }
- }
- return nil
-}
-
-func (c compositeResult) GetRows() []proto.Row {
- var rows []proto.Row
- for _, it := range c {
- rows = append(rows, it.GetRows()...)
- }
- return rows
-}
-
-func (c compositeResult) LastInsertId() (uint64, error) {
- return 0, nil
-}
-
-func (c compositeResult) RowsAffected() (uint64, error) {
- return 0, nil
-}
diff --git a/pkg/runtime/plan/describle.go b/pkg/runtime/plan/utility/describle.go
similarity index 91%
rename from pkg/runtime/plan/describle.go
rename to pkg/runtime/plan/utility/describle.go
index 1c121f7db..1d9cd6985 100644
--- a/pkg/runtime/plan/describle.go
+++ b/pkg/runtime/plan/utility/describle.go
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package plan
+package utility
import (
"context"
@@ -29,13 +29,15 @@ import (
import (
"github.com/arana-db/arana/pkg/proto"
"github.com/arana-db/arana/pkg/runtime/ast"
+ "github.com/arana-db/arana/pkg/runtime/plan"
)
type DescribePlan struct {
- basePlan
+ plan.BasePlan
Stmt *ast.DescribeStatement
Database string
Table string
+ Column string
}
func NewDescribePlan(stmt *ast.DescribeStatement) *DescribePlan {
@@ -53,6 +55,8 @@ func (d *DescribePlan) ExecIn(ctx context.Context, vConn proto.VConn) (proto.Res
res proto.Result
err error
)
+ ctx, span := plan.Tracer.Start(ctx, "DescribePlan.ExecIn")
+ defer span.End()
if err = d.generate(&sb, &indexes); err != nil {
return nil, errors.Wrap(err, "failed to generate desc/describe sql")
@@ -60,7 +64,7 @@ func (d *DescribePlan) ExecIn(ctx context.Context, vConn proto.VConn) (proto.Res
var (
query = sb.String()
- args = d.toArgs(indexes)
+ args = d.ToArgs(indexes)
)
if res, err = vConn.Query(ctx, d.Database, query, args...); err != nil {
diff --git a/pkg/runtime/rule/shard.go b/pkg/runtime/rule/shard.go
index 7491b9286..91e8837eb 100644
--- a/pkg/runtime/rule/shard.go
+++ b/pkg/runtime/rule/shard.go
@@ -40,6 +40,8 @@ const (
HashMd5Shard ShardType = "hashMd5Shard"
HashCrc32Shard ShardType = "hashCrc32Shard"
HashBKDRShard ShardType = "hashBKDRShard"
+ ScriptExpr ShardType = "scriptExpr"
+ FunctionExpr ShardType = "functionExpr"
)
var shardMap = map[ShardType]ShardComputerFunc{
@@ -59,8 +61,10 @@ func ShardFactory(shardType ShardType, shardNum int) (shardStrategy rule.ShardCo
return nil, errors.New("do not have this shardType")
}
-type ShardType string
-type ShardComputerFunc func(shardNum int) rule.ShardComputer
+type (
+ ShardType string
+ ShardComputerFunc func(shardNum int) rule.ShardComputer
+)
type modShard struct {
shardNum int
diff --git a/pkg/runtime/rule/shard_expr.go b/pkg/runtime/rule/shard_expr.go
new file mode 100644
index 000000000..d33a42571
--- /dev/null
+++ b/pkg/runtime/rule/shard_expr.go
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package rule
+
+import (
+ "fmt"
+ "strconv"
+)
+
+import (
+ "github.com/pkg/errors"
+)
+
+import (
+ "github.com/arana-db/arana/pkg/proto/rule"
+)
+
+var _ rule.ShardComputer = (*exprShardComputer)(nil)
+
+type exprShardComputer struct {
+ expr string
+ column string
+}
+
+func NewExprShardComputer(expr, column string) (rule.ShardComputer, error) {
+ result := &exprShardComputer{
+ expr: expr,
+ column: column,
+ }
+ return result, nil
+}
+
+func (compute *exprShardComputer) Compute(value interface{}) (int, error) {
+ expr, vars, err := Parse(compute.expr)
+ if err != nil {
+ return 0, err
+ }
+ if len(vars) != 1 || vars[0] != Var(compute.column) {
+ return 0, errors.Errorf("Parse shard expr is error, expr is: %s", compute.expr)
+ }
+
+ shardValue := fmt.Sprintf("%v", value)
+ eval, _ := expr.Eval(Env{Var(compute.column): Value(shardValue)})
+
+ result, err := strconv.ParseFloat(eval.String(), 64)
+ if err != nil {
+ return 0, err
+ }
+
+ return int(result), nil
+}
diff --git a/pkg/runtime/rule/shard_expr_parse.go b/pkg/runtime/rule/shard_expr_parse.go
new file mode 100644
index 000000000..13d68ab0f
--- /dev/null
+++ b/pkg/runtime/rule/shard_expr_parse.go
@@ -0,0 +1,298 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package rule
+
+import (
+ "fmt"
+ "strconv"
+ "strings"
+ "text/scanner"
+)
+
+// ---- lexer ----
+// This lexer is similar to the one described in Chapter 13.
+type lexer struct {
+ scan scanner.Scanner
+ token rune // current lookahead token
+}
+
+func (lex *lexer) move() {
+ lex.token = lex.scan.Scan()
+ // fmt.Printf("after move, current token %q\n", lex.text())
+}
+
+func (lex *lexer) text() string { return lex.scan.TokenText() }
+func (lex *lexer) peek() rune { return lex.scan.Peek() }
+
+// describe returns a string describing the current token, for use in errors.
+func (lex *lexer) describe() string {
+ switch lex.token {
+ case scanner.EOF:
+ return "end of file"
+ case scanner.Ident:
+ return fmt.Sprintf("identifier %s", lex.text())
+ case scanner.Int, scanner.Float:
+ return fmt.Sprintf("number %s", lex.text())
+ }
+ return fmt.Sprintf("%q", rune(lex.token)) // any other rune
+}
+
+func precedence(op rune) int {
+ switch op {
+ case '*', '/', '%':
+ return 2
+ case '+', '-':
+ return 1
+ }
+ return 0
+}
+
+// ---- parser ----
+
+// Parse parses the input string as an arithmetic expression.
+//
+// expr = num a constant number, e.g., 3.14159
+// | id a variable name, e.g., x
+// | id '(' expr ',' ... ')' a function
+// | '-' expr a unary operator (+-)
+// | expr '+' expr a binary operator (+-*/)
+//
+func Parse(input string) (_ Expr, vars []Var, rerr error) {
+ defer func() {
+ switch x := recover().(type) {
+ case nil:
+ // no panic
+ case string:
+ rerr = fmt.Errorf("%s", x)
+ default:
+ rerr = fmt.Errorf("unexpected panic: resume state of panic")
+ }
+ }()
+
+ lex := new(lexer)
+ lex.scan.Init(strings.NewReader(input))
+ lex.scan.Mode = scanner.ScanIdents | scanner.ScanInts | scanner.ScanFloats | scanner.ScanStrings
+ lex.move() // initial lookahead
+ _, e, err := parseExpr(lex, nil)
+ if err != nil {
+ return nil, nil, err
+ }
+ if lex.token != scanner.EOF {
+ return nil, nil, fmt.Errorf("unexpected %s", lex.describe())
+ }
+
+ // parse vars
+ l := len(input)
+ for i := 0; i < l; i++ {
+ if input[i] != '#' {
+ continue
+ }
+ i++ // consume '#'
+
+ var s string
+ for {
+ if input[i] == '#' {
+ i++ // consume '#'
+ break
+ }
+ if l <= i {
+ break
+ }
+ s += string(input[i])
+ i++
+ }
+ vars = append(vars, Var(s))
+ }
+
+ return e, vars, nil
+}
+
+func parseExpr(lex *lexer, s *stack) (*stack, Expr, error) { return parseBinary(lex, 1, s) }
+
+// binary = unary ('+' binary)*
+// parseBinary stops when it encounters an
+// operator of lower precedence than prec1.
+func parseBinary(lex *lexer, prec1 int, s *stack) (*stack, Expr, error) {
+ s1, lhs, err := parseUnary(lex, s)
+ if err != nil {
+ return s, nil, err
+ }
+ s = s1
+
+ for prec := precedence(lex.token); prec >= prec1; prec-- {
+ for precedence(lex.token) == prec {
+ op := lex.token
+ lex.move() // consume operator
+ s1, rhs, err := parseBinary(lex, prec+1, s)
+ if err != nil {
+ return s1, nil, err
+ }
+ s = s1
+ lhs = binaryExpr{op: op, x: lhs, y: rhs}
+ }
+ }
+ return s, lhs, nil
+}
+
+// unary = '+' expr | primary
+func parseUnary(lex *lexer, s *stack) (*stack, Expr, error) {
+ if lex.token == '+' || lex.token == '-' {
+ op := lex.token
+ lex.move() // consume '+' or '-'
+ s1, e, err := parseUnary(lex, s)
+ return s1, unary{op, e}, err
+ }
+ return parsePrimary(lex, s)
+}
+
+// primary = id
+// | id '(' expr ',' ... ',' expr ')'
+// | num
+// | '(' expr ')'
+func parsePrimary(lex *lexer, s *stack) (*stack, Expr, error) {
+ switch lex.token {
+ case scanner.Ident:
+ id := lex.text()
+ lex.move() // consume Ident
+ if lex.token != '(' {
+ return s, Var(id), nil
+ }
+
+ lex.move() // consume '('
+
+ // vars []Var
+ var args []Expr
+
+ if lex.token != ')' {
+ if s == (*stack)(nil) {
+ s = NewStack()
+ }
+ s.push(function{fn: id})
+ for {
+ s1, e, err := parseExpr(lex, s)
+ if err != nil {
+ return s, nil, err
+ }
+ s = s1
+ args = append(args, e)
+ //if v, ok := e.(Var); ok {
+ // vars = append(vars, v)
+ //}
+ if lex.token != ',' {
+ break
+ }
+ lex.move() // consume ','
+ }
+ if lex.token != ')' {
+ return s, nil, fmt.Errorf("got %s, want ')'", lex.describe())
+ }
+ }
+ lex.move() // consume ')'
+
+ if s != (*stack)(nil) {
+ e, notEmpty := s.pop()
+ if notEmpty {
+ // if lex.token == ')' && !tmpstack.isEmpty() {
+ //if lex.token == ',' {
+ // lex.move() // consume ","
+ //}
+ c, ok := e.(function)
+ if ok {
+ c.args = append(c.args, args...)
+ // c.vars = append(c.vars, vars...)
+ return s, c, nil
+ } else {
+ return s, nil, fmt.Errorf("illegal stack element %#v, type %T", e, e) // impossible
+ }
+ }
+ }
+ // return s, function{fn:id, args:args, vars: vars}, nil
+ return s, function{fn: id, args: args}, nil
+
+ case scanner.Int, scanner.Float:
+ f, err := strconv.ParseFloat(lex.text(), 64)
+ if err != nil {
+ return s, nil, err
+ }
+ lex.move() // consume number
+ return s, constant(f), nil
+
+ case scanner.String:
+ token := stringConstant(lex.text())
+ lex.move()
+ return s, token, nil
+
+ case '(':
+ lex.move() // consume '('
+ s1, e, err := parseExpr(lex, s)
+ if err != nil || lex.token != ')' {
+ return s, nil, fmt.Errorf("err %s, got %s, want ')'", err, lex.describe())
+ }
+ s = s1
+ lex.move() // consume ')'
+ return s, e, nil
+
+ case '\'':
+ // 此处之所以先存下原 mode,只分析 strings,是在分析 TestParse 的 case "hash(concat(#uid#, '1'), 100)"
+ // 这个例子时,不能正确分析 concat 的分隔符 '1'
+ mode := lex.scan.Mode
+ defer func() {
+ lex.scan.Mode = mode
+ }()
+ lex.scan.Mode = scanner.ScanStrings
+ lex.move() // consume \'
+
+ var str string
+ for {
+ str += string(lex.token)
+ lex.move()
+ if lex.token == '\'' || lex.token == scanner.EOF {
+ break
+ }
+ }
+ if lex.token != '\'' {
+ return s, nil, fmt.Errorf("parsing string with quote ', got illegal last token %s", lex.describe())
+ }
+ lex.move() // consume \'
+ return s, stringConstant(str), nil
+
+ case '#':
+ mode := lex.scan.Mode
+ defer func() {
+ lex.scan.Mode = mode
+ }()
+ lex.scan.Mode = scanner.ScanStrings
+ lex.move() // consume #'
+
+ var str string
+ for {
+ str += string(lex.token)
+ lex.move()
+ if lex.token == '#' || lex.token == scanner.EOF {
+ break
+ }
+ }
+ if lex.token != '#' {
+ return s, nil, fmt.Errorf("parsing string with quote #, got illegal last token %s", lex.describe())
+ }
+ lex.move() // consume #
+ return s, Var(str), nil
+ }
+
+ return s, nil, fmt.Errorf("unexpected %s", lex.describe())
+}
diff --git a/pkg/runtime/rule/shard_expr_parse_test.go b/pkg/runtime/rule/shard_expr_parse_test.go
new file mode 100644
index 000000000..e3169b3b7
--- /dev/null
+++ b/pkg/runtime/rule/shard_expr_parse_test.go
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package rule
+
+import (
+ "fmt"
+ "strconv"
+ "testing"
+)
+
+func TestParse(t *testing.T) {
+ // assert.True(t, preParse(expString) == expression)
+ tests := []struct {
+ expr string
+ env Env
+ want string
+ }{
+ {"hash(toint(substr(#uid#, 1, 2)), 100)", Env{"uid": "87616"}, "87"},
+ {"hash(concat(#uid#, '1'), 100)", Env{"uid": "87616"}, "61"},
+ {"div(substr(#uid#, 2), 10)", Env{"uid": "87616"}, "761.6"},
+ }
+ var prevExpr string
+ for _, test := range tests {
+ // Print expr only when it changes.
+ if test.expr != prevExpr {
+ fmt.Printf("\n%s\n", test.expr)
+ prevExpr = test.expr
+ }
+ expr, vars, err := Parse(test.expr)
+ if err != nil {
+ t.Error(err) // parse error
+ continue
+ }
+ if len(vars) != 1 || vars[0] != "uid" {
+ t.Errorf("illegal vars %#v", vars)
+ }
+
+ // got := fmt.Sprintf("%.6g", )
+ evalRes, _ := expr.Eval(test.env)
+ evalResFloat, _ := strconv.ParseFloat(evalRes.String(), 64)
+ got := fmt.Sprintf("%.6g", evalResFloat)
+ fmt.Printf("\t%v => %s\n", test.env, got)
+ if got != test.want {
+ t.Errorf("%s.Eval() in %v = %q, want %q\n", test.expr, test.env, got, test.want)
+ }
+ }
+}
diff --git a/pkg/runtime/rule/shard_expr_type.go b/pkg/runtime/rule/shard_expr_type.go
new file mode 100644
index 000000000..f9c507522
--- /dev/null
+++ b/pkg/runtime/rule/shard_expr_type.go
@@ -0,0 +1,556 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package rule
+
+import (
+ "bytes"
+ "fmt"
+ "math"
+ "strconv"
+ "strings"
+)
+
+const (
+ defaultValue = ""
+)
+
+// An Expr is an arithmetic/string expression.
+type Expr interface {
+ // Eval returns the value of this Expr in the environment env.
+ Eval(env Env) (Value, error)
+ // Check reports errors in this Expr and adds its Vars to the set.
+ Check(vars map[Var]bool) error
+ // String output
+ String() string
+}
+
+type (
+ // A Var identifies a variable, e.g., x.
+ Var string
+ Env map[Var]Value
+)
+
+func (v Var) Eval(env Env) (Value, error) {
+ return env[v], nil
+}
+
+func (v Var) Check(vars map[Var]bool) error {
+ vars[v] = true
+ return nil
+}
+
+func (v Var) String() string {
+ var buf bytes.Buffer
+ write(&buf, v)
+ return buf.String()
+}
+
+// Value defines var value
+type Value string
+
+func (v Value) String() string {
+ return string(v)
+}
+
+// A constant is a numeric constant, e.g., 3.141.
+type constant float64
+
+func (c constant) Eval(Env) (Value, error) {
+ return Value(strconv.Itoa(int(c))), nil
+}
+
+func (constant) Check(map[Var]bool) error {
+ return nil
+}
+
+func (c constant) String() string {
+ var buf bytes.Buffer
+ write(&buf, c)
+ return buf.String()
+}
+
+// A stringConstant is a string constant, e.g., 3.141.
+type stringConstant string
+
+func (s stringConstant) Eval(Env) (Value, error) {
+ return Value(s), nil
+}
+
+func (stringConstant) Check(map[Var]bool) error {
+ return nil
+}
+
+func (s stringConstant) String() string {
+ return (string)(s)
+}
+
+// A unary represents a unary operator expression, e.g., -x.
+type unary struct {
+ op rune // one of '+', '-'
+ x Expr
+}
+
+func (u unary) Eval(env Env) (Value, error) {
+ switch u.op {
+ case '+':
+ return u.x.Eval(env)
+
+ case '-':
+ xv, err := u.x.Eval(env)
+ if err != nil {
+ return defaultValue, err
+ }
+ fv, err := strconv.ParseFloat(xv.String(), 64)
+ if err != nil {
+ return defaultValue, err
+ }
+ return Value(fmt.Sprintf("%f", -fv)), nil
+ }
+
+ return defaultValue, fmt.Errorf("unsupported unary operator: %q", u.op)
+}
+
+func (u unary) Check(vars map[Var]bool) error {
+ if !strings.ContainsRune("+-", u.op) {
+ return fmt.Errorf("unexpected unary op %q", u.op)
+ }
+ return u.x.Check(vars)
+}
+
+func (u unary) String() string {
+ var buf bytes.Buffer
+ write(&buf, u)
+ return buf.String()
+}
+
+// A binary represents a binary operator expression, e.g., x+y.
+type binaryExpr struct {
+ op rune // one of '+', '-', '*', '/', '%'
+ x, y Expr
+}
+
+func (b binaryExpr) Eval(env Env) (Value, error) {
+ xv, err := b.x.Eval(env)
+ if err != nil {
+ return xv, err
+ }
+ xf, err := strconv.ParseFloat(xv.String(), 64)
+ if err != nil {
+ return xv, err
+ }
+
+ yv, err := b.y.Eval(env)
+ if err != nil {
+ return yv, err
+ }
+ yf, err := strconv.ParseFloat(yv.String(), 64)
+ if err != nil {
+ return yv, err
+ }
+
+ var f float64
+ switch b.op {
+ case '+':
+ f = xf + yf
+
+ case '-':
+ f = xf - yf
+
+ case '*':
+ f = xf * yf
+
+ case '/':
+ if yf == 0 {
+ yf = 1.0
+ }
+ // f = float64(int(xf / yf))
+ f = xf / yf
+
+ case '%':
+ if yf == 0 {
+ yf = 1.0
+ }
+ f = float64(int(xf) % int(yf))
+
+ default:
+ return defaultValue, fmt.Errorf("unsupported binary operator: %q", b.op)
+ }
+
+ return Value(fmt.Sprintf("%f", f)), nil
+}
+
+func (b binaryExpr) String() string {
+ var buf bytes.Buffer
+ write(&buf, b)
+ return buf.String()
+}
+
+func (b binaryExpr) Check(vars map[Var]bool) error {
+ if !strings.ContainsRune("+-*/%", b.op) {
+ return fmt.Errorf("unexpected binary op %q", b.op)
+ }
+ if err := b.x.Check(vars); err != nil {
+ return err
+ }
+ return b.y.Check(vars)
+}
+
+// A function represents a function function expression, e.g., sin(x).
+type function struct {
+ fn string // one of "pow", "sqrt"
+ args []Expr
+}
+
+func (c function) Eval(env Env) (Value, error) {
+ argsNum := len(c.args)
+ if argsNum == 0 {
+ return defaultValue, fmt.Errorf("args number is 0 of func %s", c.fn)
+ }
+
+ switch c.fn {
+ case "toint":
+ return c.args[0].Eval(env)
+
+ case "hash":
+ b := binaryExpr{
+ op: '%',
+ x: c.args[0],
+ y: constant(2),
+ }
+ if len(c.args) == 2 {
+ b.y = c.args[1]
+ }
+ return b.Eval(env)
+
+ case "add":
+ b := binaryExpr{
+ op: '+',
+ x: c.args[0],
+ y: c.args[1],
+ }
+ return b.Eval(env)
+
+ case "sub":
+ b := binaryExpr{
+ op: '-',
+ x: c.args[0],
+ y: c.args[1],
+ }
+ return b.Eval(env)
+
+ case "mul":
+ b := binaryExpr{
+ op: '*',
+ x: c.args[0],
+ y: c.args[1],
+ }
+ return b.Eval(env)
+
+ case "div":
+ b := binaryExpr{
+ op: '/',
+ x: c.args[0],
+ y: c.args[1],
+ }
+ return b.Eval(env)
+
+ case "substr":
+ v0, err := c.args[0].Eval(env)
+ if err != nil {
+ return defaultValue, err
+ }
+ str := v0.String()
+
+ v1, err := c.args[1].Eval(env)
+ if err != nil {
+ return defaultValue, err
+ }
+ strlen := len(str)
+ f, err := strconv.ParseFloat(v1.String(), 64)
+ if err != nil {
+ return defaultValue, fmt.Errorf("illegal args[1] %v of func substr", v1.String())
+ }
+ startPos := int(f)
+ if startPos == 0 {
+ return defaultValue, nil
+ }
+ if startPos < 0 {
+ startPos += strlen
+ } else {
+ startPos--
+ }
+ if startPos < 0 || startPos >= strlen {
+ return defaultValue, fmt.Errorf("illegal args[1] %v of func substr", v1.String())
+ }
+
+ endPos := strlen
+ if len(c.args) == 3 {
+ v2, err := c.args[2].Eval(env)
+ if err != nil {
+ return defaultValue, err
+ }
+ l, err := strconv.Atoi(v2.String())
+ if err != nil {
+ return defaultValue, err
+ }
+ if l < 1 {
+ return defaultValue, fmt.Errorf("illegal args[2] %v of func substr", v2.String())
+ }
+ if startPos+l < endPos {
+ endPos = startPos + l
+ }
+ }
+
+ return Value(str[startPos:endPos]), nil
+
+ case "concat":
+ var builder strings.Builder
+ for i := 0; i < len(c.args); i++ {
+ v, err := c.args[i].Eval(env)
+ if err != nil {
+ return defaultValue, err
+ }
+
+ builder.WriteString(v.String())
+ }
+ return Value(builder.String()), nil
+
+ case "testload":
+ v0, err := c.args[0].Eval(env)
+ if err != nil {
+ return defaultValue, err
+ }
+
+ s := v0.String()
+
+ var b strings.Builder
+ b.Grow(len(s))
+ for i := 0; i < len(s); i++ {
+ c := s[i]
+ if 'a' <= c && c <= 'j' {
+ c -= 'a' - '0'
+ } else if 'A' <= c && c <= 'J' {
+ c -= 'A' - '0'
+ }
+
+ b.WriteByte(c)
+ }
+
+ return Value(b.String()), nil
+
+ case "split":
+ v0, err := c.args[0].Eval(env)
+ if err != nil {
+ return defaultValue, err
+ }
+
+ v1, err := c.args[1].Eval(env)
+ if err != nil {
+ return defaultValue, err
+ }
+
+ v2, err := c.args[2].Eval(env)
+ if err != nil {
+ return defaultValue, err
+ }
+ pos, err := strconv.Atoi(v2.String())
+ if err != nil {
+ return defaultValue, err
+ }
+ if pos < 1 {
+ return defaultValue, fmt.Errorf("illegal args[2] %v of func split", v2.String())
+ }
+
+ arr := strings.Split(v0.String(), v1.String())
+ return Value(arr[pos-1]), nil
+
+ case "pow":
+ v0, err := c.args[0].Eval(env)
+ if err != nil {
+ return defaultValue, err
+ }
+ v0f, err := strconv.ParseFloat(v0.String(), 64)
+ if err != nil {
+ return defaultValue, err
+ }
+
+ v1, err := c.args[1].Eval(env)
+ if err != nil {
+ return defaultValue, err
+ }
+ v1f, err := strconv.ParseFloat(v1.String(), 64)
+ if err != nil {
+ return defaultValue, err
+ }
+
+ v := math.Pow(v0f, v1f)
+ return Value(fmt.Sprintf("%f", v)), nil
+
+ case "sqrt":
+ v0, err := c.args[0].Eval(env)
+ if err != nil {
+ return defaultValue, err
+ }
+ v0f, err := strconv.ParseFloat(v0.String(), 64)
+ if err != nil {
+ return defaultValue, err
+ }
+ v := math.Sqrt(v0f)
+ return Value(fmt.Sprintf("%f", v)), nil
+ }
+
+ return defaultValue, fmt.Errorf("unsupported function function: %s", c.fn)
+}
+
+func (c function) String() string {
+ var buf bytes.Buffer
+ write(&buf, c)
+ return buf.String()
+}
+
+func (c function) Check(vars map[Var]bool) error {
+ arity, ok := numParams[c.fn]
+ if !ok {
+ return fmt.Errorf("unknown function %q", c.fn)
+ }
+ if argsNum := len(c.args); argsNum != arity {
+ if argsNum == 0 || arity < argsNum {
+ return fmt.Errorf("illegal args number %d of func %s, want %d", argsNum, c.fn, arity)
+ }
+ switch c.fn {
+ case "substr":
+ if argsNum < 2 {
+ return fmt.Errorf("illegal args number %d of func substr, want %d", argsNum, arity)
+ }
+ }
+ }
+ for _, arg := range c.args {
+ if err := arg.Check(vars); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// expr stack
+type stack []Expr
+
+func NewStack() *stack {
+ return &stack{}
+}
+
+// size: return the element number of @s
+func (s *stack) size() int {
+ return len(*s)
+}
+
+// isEmpty: check if stack is empty
+func (s *stack) isEmpty() bool {
+ return s.size() == 0
+}
+
+// push a new value onto the stack
+func (s *stack) push(e Expr) *stack {
+ *s = append(*s, e)
+ return s
+}
+
+// pop: Remove and return top element of stack. Return false if stack is empty.
+func (s *stack) pop() (_ Expr, notEmpty bool) {
+ if s.isEmpty() {
+ return nil, false
+ }
+
+ i := len(*s) - 1
+ e := (*s)[i]
+ *s = (*s)[:i]
+ return e, true
+}
+
+var numParams = map[string]int{
+ "toint": 1,
+ "hash": 2,
+ "add": 2,
+ "sub": 2,
+ "mul": 2,
+ "div": 2,
+ "substr": 3,
+ "concat": 10,
+ "testload": 1,
+ "split": 3,
+ "pow": 2,
+ "sqrt": 1,
+}
+
+func write(buf *bytes.Buffer, e Expr) {
+ switch e.(type) {
+ case constant:
+ fmt.Fprintf(buf, "%g", e.(constant))
+
+ case stringConstant:
+ fmt.Fprintf(buf, "%s", e.(stringConstant))
+
+ case Var, *Var:
+ v, ok := e.(Var)
+ if !ok {
+ v = *(e.(*Var))
+ }
+
+ fmt.Fprintf(buf, "%s", string(v))
+
+ case unary, *unary:
+ u, ok := e.(unary)
+ if !ok {
+ u = *(e.(*unary))
+ }
+
+ fmt.Fprintf(buf, "(%c", u.op)
+ write(buf, u.x)
+ buf.WriteByte(')')
+
+ case binaryExpr, *binaryExpr:
+ b, ok := e.(binaryExpr)
+ if !ok {
+ b = *(e.(*binaryExpr))
+ }
+
+ buf.WriteByte('(')
+ write(buf, b.x)
+ fmt.Fprintf(buf, " %c ", b.op)
+ write(buf, b.y)
+ buf.WriteByte(')')
+
+ case function, *function:
+ f, ok := e.(function)
+ if !ok {
+ f = *(e.(*function))
+ }
+
+ fmt.Fprintf(buf, "%s(", f.fn)
+ for i, arg := range f.args {
+ if i > 0 {
+ buf.WriteString(", ")
+ }
+ write(buf, arg)
+ }
+ buf.WriteByte(')')
+
+ default:
+ fmt.Fprintf(buf, "unknown Expr: %T", e)
+ }
+}
diff --git a/pkg/runtime/rule/shard_expr_type_test.go b/pkg/runtime/rule/shard_expr_type_test.go
new file mode 100644
index 000000000..f793f91e8
--- /dev/null
+++ b/pkg/runtime/rule/shard_expr_type_test.go
@@ -0,0 +1,245 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package rule
+
+import (
+ "fmt"
+ "math"
+ "strconv"
+ "testing"
+ "text/scanner"
+)
+
+//!+Eval
+func TestNumEval(t *testing.T) {
+ piStr := fmt.Sprintf("%f", math.Pi)
+ tests := []struct {
+ expr string
+ env Env
+ want string
+ }{
+ {"add(A, B)", Env{"A": "8", "B": "4"}, "12"},
+ {"sub(A, B)", Env{"A": "8", "B": "4"}, "4"},
+ {"mul(A, B)", Env{"A": "8", "B": "4"}, "32"},
+ {"div(A, B)", Env{"A": "8", "B": "4"}, "2"},
+ {"sqrt(A / pi)", Env{"A": "87616", "pi": Value(piStr)}, "167"},
+ {"testload(x)", Env{"x": "dbc"}, "312"},
+ {"testload(x)", Env{"x": "DBc"}, "312"},
+ {"substr(x, '1')", Env{"x": "1234"}, "1234"},
+ {"substr(x, '1', '2')", Env{"x": "1234"}, "12"},
+ {"pow(x, 3) + pow(y, 3)", Env{"x": "9", "y": "10"}, "1729"},
+ {"+1", Env{}, "1"},
+ {"5 / 9 * (F - 32)", Env{"F": "-40"}, "-40"},
+ {"5 / 9 * (F - 32)", Env{"F": "32"}, "0"},
+ {"5 / 9 * (F - 32)", Env{"F": "212"}, "100"},
+ {"-1 + -x", Env{"x": "1"}, "-2"},
+ {"-1 - x", Env{"x": "1"}, "-2"},
+ }
+ var prevExpr string
+ for _, test := range tests {
+ // Print expr only when it changes.
+ if test.expr != prevExpr {
+ fmt.Printf("\n%s\n", test.expr)
+ prevExpr = test.expr
+ }
+ expr, _, err := Parse(test.expr)
+ if err != nil {
+ t.Error(err) // parse error
+ continue
+ }
+ // got := fmt.Sprintf("%.6g", )
+ evalRes, _ := expr.Eval(test.env)
+ evalResFloat, _ := strconv.ParseFloat(evalRes.String(), 64)
+ got := fmt.Sprintf("%.6g", evalResFloat)
+ fmt.Printf("\t%v => %s\n", test.env, got)
+ if got != test.want {
+ t.Errorf("%s.Eval() in %v = %q, want %q\n",
+ test.expr, test.env, got, test.want)
+ }
+ }
+
+ v := Var("hello")
+ env := Env{"hello": "hello"}
+ vv, err := v.Eval(env)
+ if vv != "hello" || err != nil {
+ t.Errorf("v %v Eval(env:%v) = {value:%v, error:%v}", v, env, vv, err)
+ }
+
+ sv := stringConstant("hello")
+ vv, err = sv.Eval(nil)
+ if vv != "hello" || err != nil {
+ t.Errorf("v %v Eval(env:%v) = {value:%v, error:%v}", v, env, vv, err)
+ }
+
+ splitExpr, _, err := Parse("split(x, '|', 2)")
+ if err != nil {
+ t.Errorf("Parse('split(x, y, 2)') = error %v", err)
+ }
+ splitRes, err := splitExpr.Eval(Env{"x": "abc|de|f"})
+ if err != nil || splitRes != "de" {
+ t.Errorf("{'split(x, y, 2)', '|', 2} = {res %v, error %v}", splitRes, err)
+ }
+}
+
+func TestErrors(t *testing.T) {
+ for _, test := range []struct{ expr, wantErr string }{
+ {"x $ 2", "unexpected '$'"},
+ {"math.Pi", "unexpected '.'"},
+ {"!true", "unexpected '!'"},
+ //{`"hello"`, "hello"},
+ {"log(10)", `unknown function "log"`},
+ {"sqrt(1, 2)", "illegal args number 2 of func sqrt, want 1"},
+ } {
+ expr, _, err := Parse(test.expr)
+ if err == nil {
+ vars := make(map[Var]bool)
+ err = expr.Check(vars)
+ if err == nil {
+ t.Errorf("unexpected success: %s", test.expr)
+ continue
+ }
+ }
+ fmt.Printf("%-20s%v\n", test.expr, err) // (for book)
+ if err.Error() != test.wantErr {
+ t.Errorf("got error \"%s\", want %s", err, test.wantErr)
+ }
+ }
+}
+
+func TestCheck(t *testing.T) {
+ sc := stringConstant("hello")
+ err := sc.Check(nil)
+ if err != nil {
+ t.Fatalf("stringConstant check() result %v != nil", err)
+ }
+
+ c := function{
+ fn: "hash",
+ args: nil,
+ }
+ err = c.Check(nil)
+ if err == nil {
+ t.Fatalf("function %#v check() result %v should not be nil", c, err)
+ }
+
+ c = function{
+ fn: "substr",
+ args: nil,
+ }
+ err = c.Check(nil)
+ if err == nil {
+ t.Fatalf("function %#v check() result %v should be nil", c, err)
+ }
+ c.args = append(c.args, sc)
+ // c.args = append(c.args, sc)
+ // c.args = append(c.args, sc)
+ err = c.Check(nil)
+ if err == nil {
+ t.Fatalf("function %#v check() result %v should not be nil", c, err)
+ }
+
+ piStr := fmt.Sprintf("%f", math.Pi)
+ tests := []struct {
+ input string
+ env Env
+ want string // expected error from Parse/Check or result from Eval
+ }{
+ {"hello", nil, ""},
+ {"+2", nil, ""},
+ // {"x % 2", nil, "unexpected '%'"},
+ {"x % 2", nil, ""},
+ {"!true", nil, "unexpected '!'"},
+ {"log(10)", nil, `unknown function "log"`},
+ {"sqrt(1, 2)", nil, "illegal args number 2 of func sqrt, want 1"},
+ {"sqrt(A / pi)", Env{"A": "87616", "pi": Value(piStr)}, "167"},
+ {"pow(x, 3) + pow(y, 3)", Env{"x": "9", "y": "10"}, "1729"},
+ {"5 / 9 * (F - 32)", Env{"F": "-40"}, "-40"},
+ }
+
+ for _, test := range tests {
+ expr, _, err := Parse(test.input)
+ if err == nil {
+ err = expr.Check(map[Var]bool{})
+ }
+ if err != nil {
+ if err.Error() != test.want {
+ t.Errorf("%s: got %q, want %q", test.input, err, test.want)
+ }
+ continue
+ }
+
+ if test.want != "" {
+ // got := fmt.Sprintf("%.6g", expr.Eval(test.env))
+ evalRes, _ := expr.Eval(test.env)
+ evalResFloat, _ := strconv.ParseFloat(evalRes.String(), 64)
+ got := fmt.Sprintf("%.6g", evalResFloat)
+ if got != test.want {
+ t.Errorf("%s: %v => %s, want %s",
+ test.input, test.env, got, test.want)
+ }
+ }
+ }
+}
+
+func TestString(t *testing.T) {
+ var e Expr
+
+ e = Var("hello")
+ t.Logf("%v string %s", e, e)
+
+ e = constant(3.14)
+ t.Logf("%v string %s", e, e)
+
+ e = stringConstant("hello")
+ t.Logf("%v string %s", e, e)
+
+ e = unary{
+ op: '-',
+ x: constant(1234),
+ }
+ t.Logf("%v string %s", e, e)
+
+ e = binaryExpr{
+ op: '+',
+ x: stringConstant("1"),
+ y: stringConstant("1"),
+ }
+ t.Logf("%v string %s", e, e)
+
+ f := function{
+ fn: "add",
+ }
+ f.args = append(f.args, stringConstant("1"))
+ f.args = append(f.args, stringConstant("2"))
+ e = f
+ t.Logf("%v string %s", e, e)
+
+ var l lexer
+ l.token = scanner.EOF
+ t.Logf("eof string %s", l.describe())
+ l.token = scanner.Ident
+ t.Logf("ident string %s", l.describe())
+ l.token = scanner.Int
+ t.Logf("int string %s", l.describe())
+
+ l.token = scanner.String
+ s, e, err := parsePrimary(&l, nil)
+ if s != nil || err != nil || e == nil {
+ t.Logf("parsePrimary() = {stack:%v, expression:%v error:%v}", s, e, err)
+ }
+}
diff --git a/pkg/runtime/rule/shard_script.go b/pkg/runtime/rule/shard_script.go
index 73e94a918..058df24a1 100644
--- a/pkg/runtime/rule/shard_script.go
+++ b/pkg/runtime/rule/shard_script.go
@@ -100,9 +100,7 @@ func (j *jsShardComputer) putVM(vm *goja.Runtime) {
}
func wrapScript(script string) string {
- var (
- sb strings.Builder
- )
+ var sb strings.Builder
sb.Grow(32 + len(_jsEntrypoint) + len(_jsValueName) + len(script))
diff --git a/pkg/runtime/runtime.go b/pkg/runtime/runtime.go
index 7c187ef67..4be5d78d6 100644
--- a/pkg/runtime/runtime.go
+++ b/pkg/runtime/runtime.go
@@ -20,18 +20,19 @@ package runtime
import (
"context"
"encoding/json"
- stdErrors "errors"
+ "errors"
"fmt"
"io"
- "sort"
"sync"
"time"
)
import (
"github.com/bwmarrin/snowflake"
+ perrors "github.com/pkg/errors"
- "github.com/pkg/errors"
+ "go.opentelemetry.io/otel"
+ "go.opentelemetry.io/otel/trace"
"go.uber.org/atomic"
@@ -43,8 +44,15 @@ import (
"github.com/arana-db/arana/pkg/metrics"
"github.com/arana-db/arana/pkg/mysql"
"github.com/arana-db/arana/pkg/proto"
+ "github.com/arana-db/arana/pkg/proto/hint"
+ "github.com/arana-db/arana/pkg/resultx"
rcontext "github.com/arana-db/arana/pkg/runtime/context"
"github.com/arana-db/arana/pkg/runtime/namespace"
+ "github.com/arana-db/arana/pkg/runtime/optimize"
+ _ "github.com/arana-db/arana/pkg/runtime/optimize/dal"
+ _ "github.com/arana-db/arana/pkg/runtime/optimize/ddl"
+ _ "github.com/arana-db/arana/pkg/runtime/optimize/dml"
+ _ "github.com/arana-db/arana/pkg/runtime/optimize/utility"
"github.com/arana-db/arana/pkg/util/log"
"github.com/arana-db/arana/pkg/util/rand2"
"github.com/arana-db/arana/third_party/pools"
@@ -56,8 +64,10 @@ var (
_ proto.VConn = (*compositeTx)(nil)
)
+var Tracer = otel.Tracer("Runtime")
+
var (
- errTxClosed = stdErrors.New("transaction is closed")
+ errTxClosed = errors.New("transaction is closed")
)
func NewAtomDB(node *config.Node) *AtomDB {
@@ -74,13 +84,20 @@ func NewAtomDB(node *config.Node) *AtomDB {
}
raw, _ := json.Marshal(map[string]interface{}{
- "dsn": fmt.Sprintf("%s:%s@tcp(%s:%d)/%s", node.Username, node.Password, node.Host, node.Port, node.Database),
+ "dsn": fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?%s", node.Username, node.Password, node.Host, node.Port, node.Database, node.Parameters.String()),
})
connector, err := mysql.NewConnector(raw)
if err != nil {
panic(err)
}
- db.pool = pools.NewResourcePool(connector.NewBackendConnection, 8, 16, 30*time.Minute, 1, nil)
+
+ var (
+ capacity = config.GetConnPropCapacity(node.ConnProps, 8)
+ maxCapacity = config.GetConnPropMaxCapacity(node.ConnProps, 64)
+ idleTime = config.GetConnPropIdleTime(node.ConnProps, 30*time.Minute)
+ )
+
+ db.pool = pools.NewResourcePool(connector.NewBackendConnection, capacity, maxCapacity, idleTime, 1, nil)
return db
}
@@ -88,21 +105,20 @@ func NewAtomDB(node *config.Node) *AtomDB {
// Runtime executes a sql statement.
type Runtime interface {
proto.Executable
+ proto.VConn
// Namespace returns the namespace.
Namespace() *namespace.Namespace
// Begin begins a new transaction.
- Begin(ctx *proto.Context) (proto.Tx, error)
+ Begin(ctx context.Context) (proto.Tx, error)
}
// Load loads a Runtime, here schema means logical database name.
func Load(schema string) (Runtime, error) {
var ns *namespace.Namespace
if ns = namespace.Load(schema); ns == nil {
- return nil, errors.Errorf("no such logical database %s", schema)
+ return nil, perrors.Errorf("no such logical database %s", schema)
}
- return &defaultRuntime{
- ns: ns,
- }, nil
+ return (*defaultRuntime)(ns), nil
}
var (
@@ -141,7 +157,7 @@ func (tx *compositeTx) call(ctx context.Context, db string, query string, args .
res, _, err := atx.Call(ctx, query, args...)
if err != nil {
- return nil, errors.WithStack(err)
+ return nil, perrors.WithStack(err)
}
return res, nil
}
@@ -153,9 +169,13 @@ func (tx *compositeTx) begin(ctx context.Context, group string) (*atomTx, error)
// force use writeable node
ctx = rcontext.WithWrite(ctx)
+ db := selectDB(ctx, group, tx.rt.Namespace())
+ if db == nil {
+ return nil, perrors.Errorf("cannot get upstream database %s", group)
+ }
// begin atom tx
- newborn, err := tx.rt.Namespace().DB(ctx, group).(*AtomDB).begin(ctx)
+ newborn, err := db.(*AtomDB).begin(ctx)
if err != nil {
return nil, err
}
@@ -168,8 +188,11 @@ func (tx *compositeTx) String() string {
}
func (tx *compositeTx) Execute(ctx *proto.Context) (res proto.Result, warn uint16, err error) {
+ var span trace.Span
+ ctx.Context, span = Tracer.Start(ctx.Context, "compositeTx.Execute")
execStart := time.Now()
defer func() {
+ span.End()
metrics.ExecuteDuration.Observe(time.Since(execStart).Seconds())
}()
if tx.closed.Load() {
@@ -177,9 +200,7 @@ func (tx *compositeTx) Execute(ctx *proto.Context) (res proto.Result, warn uint1
return
}
- var (
- args = tx.rt.extractArgs(ctx)
- )
+ args := ctx.GetArgs()
if direct := rcontext.IsDirect(ctx.Context); direct {
var (
group = tx.rt.Namespace().DBGroups()[0]
@@ -190,32 +211,38 @@ func (tx *compositeTx) Execute(ctx *proto.Context) (res proto.Result, warn uint1
return
}
res, warn, err = atx.Call(cctx, ctx.GetQuery(), args...)
- go writeRow(res)
+ if err != nil {
+ err = perrors.WithStack(err)
+ }
return
}
var (
- ru = tx.rt.ns.Rule()
+ ru = tx.rt.Namespace().Rule()
plan proto.Plan
c = ctx.Context
)
- c = rcontext.WithRule(c, ru)
c = rcontext.WithSQL(c, ctx.GetQuery())
+ c = rcontext.WithHints(c, ctx.Stmt.Hints)
+
+ var opt proto.Optimizer
+ if opt, err = optimize.NewOptimizer(ru, ctx.Stmt.Hints, ctx.Stmt.StmtNode, args); err != nil {
+ err = perrors.WithStack(err)
+ return
+ }
- if plan, err = tx.rt.ns.Optimizer().Optimize(c, tx, ctx.Stmt.StmtNode, args...); err != nil {
- err = errors.WithStack(err)
+ if plan, err = opt.Optimize(ctx); err != nil {
+ err = perrors.WithStack(err)
return
}
if res, err = plan.ExecIn(c, tx); err != nil {
// TODO: how to warp error packet
- err = errors.WithStack(err)
+ err = perrors.WithStack(err)
return
}
- go writeRow(res)
-
return
}
@@ -227,16 +254,16 @@ func (tx *compositeTx) Commit(ctx context.Context) (proto.Result, uint16, error)
if !tx.closed.CAS(false, true) {
return nil, 0, errTxClosed
}
-
+ ctx, span := Tracer.Start(ctx, "compositeTx.Commit")
defer func() { // cleanup
tx.rt = nil
tx.txs = nil
+ span.End()
}()
var g errgroup.Group
for k, v := range tx.txs {
- k := k
- v := v
+ k, v := k, v
g.Go(func() error {
_, _, err := v.Commit(ctx)
if err != nil {
@@ -253,10 +280,12 @@ func (tx *compositeTx) Commit(ctx context.Context) (proto.Result, uint16, error)
log.Debugf("commit %s success: total=%d", tx, len(tx.txs))
- return &mysql.Result{}, 0, nil
+ return resultx.New(), 0, nil
}
func (tx *compositeTx) Rollback(ctx context.Context) (proto.Result, uint16, error) {
+ ctx, span := Tracer.Start(ctx, "compositeTx.Rollback")
+ defer span.End()
if !tx.closed.CAS(false, true) {
return nil, 0, errTxClosed
}
@@ -268,8 +297,7 @@ func (tx *compositeTx) Rollback(ctx context.Context) (proto.Result, uint16, erro
var g errgroup.Group
for k, v := range tx.txs {
- k := k
- v := v
+ k, v := k, v
g.Go(func() error {
_, _, err := v.Rollback(ctx)
if err != nil {
@@ -286,7 +314,7 @@ func (tx *compositeTx) Rollback(ctx context.Context) (proto.Result, uint16, erro
log.Debugf("rollback %s success: total=%d", tx, len(tx.txs))
- return &mysql.Result{}, 0, nil
+ return resultx.New(), 0, nil
}
type atomTx struct {
@@ -296,12 +324,26 @@ type atomTx struct {
}
func (tx *atomTx) Commit(ctx context.Context) (res proto.Result, warn uint16, err error) {
+ _ = ctx
if !tx.closed.CAS(false, true) {
err = errTxClosed
return
}
defer tx.dispose()
- res, warn, err = tx.bc.ExecuteWithWarningCount("commit", true)
+ if res, err = tx.bc.ExecuteWithWarningCount("commit", true); err != nil {
+ return
+ }
+
+ var affected, lastInsertId uint64
+
+ if affected, err = res.RowsAffected(); err != nil {
+ return
+ }
+ if lastInsertId, err = res.LastInsertId(); err != nil {
+ return
+ }
+
+ res = resultx.New(resultx.WithRowsAffected(affected), resultx.WithLastInsertID(lastInsertId))
return
}
@@ -311,15 +353,15 @@ func (tx *atomTx) Rollback(ctx context.Context) (res proto.Result, warn uint16,
return
}
defer tx.dispose()
- res, warn, err = tx.bc.ExecuteWithWarningCount("rollback", true)
+ res, err = tx.bc.ExecuteWithWarningCount("rollback", true)
return
}
func (tx *atomTx) Call(ctx context.Context, sql string, args ...interface{}) (res proto.Result, warn uint16, err error) {
if len(args) > 0 {
- res, warn, err = tx.bc.PrepareQueryArgsIterRow(sql, args)
+ res, err = tx.bc.PrepareQueryArgs(sql, args)
} else {
- res, warn, err = tx.bc.ExecuteWithWarningCountIterRow(sql)
+ res, err = tx.bc.ExecuteWithWarningCountIterRow(sql)
}
return
}
@@ -328,7 +370,7 @@ func (tx *atomTx) CallFieldList(ctx context.Context, table, wildcard string) ([]
// TODO: choose table
var err error
if err = tx.bc.WriteComFieldList(table, wildcard); err != nil {
- return nil, errors.WithStack(err)
+ return nil, perrors.WithStack(err)
}
return tx.bc.ReadColumnDefinitions()
}
@@ -359,7 +401,7 @@ type AtomDB struct {
func (db *AtomDB) begin(ctx context.Context) (*atomTx, error) {
if db.closed.Load() {
- return nil, errors.Errorf("the db instance '%s' is closed already", db.id)
+ return nil, perrors.Errorf("the db instance '%s' is closed already", db.id)
}
var (
@@ -368,19 +410,30 @@ func (db *AtomDB) begin(ctx context.Context) (*atomTx, error) {
)
if bc, err = db.borrowConnection(ctx); err != nil {
- return nil, errors.WithStack(err)
+ return nil, perrors.WithStack(err)
}
db.pendingRequests.Inc()
- if _, _, err = bc.ExecuteWithWarningCount("begin", true); err != nil {
+ dispose := func() {
// cleanup if failed to begin tx
cnt := db.pendingRequests.Dec()
db.returnConnection(bc)
if cnt == 0 && db.closed.Load() {
db.pool.Close()
}
- return nil, err
+ }
+
+ var res proto.Result
+ if res, err = bc.ExecuteWithWarningCount("begin", true); err != nil {
+ defer dispose()
+ return nil, perrors.WithStack(err)
+ }
+
+ // NOTICE: must consume the result
+ if _, err = res.RowsAffected(); err != nil {
+ defer dispose()
+ return nil, perrors.WithStack(err)
}
return &atomTx{parent: db, bc: bc}, nil
@@ -388,7 +441,7 @@ func (db *AtomDB) begin(ctx context.Context) (*atomTx, error) {
func (db *AtomDB) CallFieldList(ctx context.Context, table, wildcard string) ([]proto.Field, error) {
if db.closed.Load() {
- return nil, errors.Errorf("the db instance '%s' is closed already", db.id)
+ return nil, perrors.Errorf("the db instance '%s' is closed already", db.id)
}
var (
@@ -397,14 +450,14 @@ func (db *AtomDB) CallFieldList(ctx context.Context, table, wildcard string) ([]
)
if bc, err = db.borrowConnection(ctx); err != nil {
- return nil, errors.WithStack(err)
+ return nil, perrors.WithStack(err)
}
defer db.returnConnection(bc)
defer db.pending()()
if err = bc.WriteComFieldList(table, wildcard); err != nil {
- return nil, errors.WithStack(err)
+ return nil, perrors.WithStack(err)
}
return bc.ReadColumnDefinitions()
@@ -412,23 +465,23 @@ func (db *AtomDB) CallFieldList(ctx context.Context, table, wildcard string) ([]
func (db *AtomDB) Call(ctx context.Context, sql string, args ...interface{}) (res proto.Result, warn uint16, err error) {
if db.closed.Load() {
- err = errors.Errorf("the db instance '%s' is closed already", db.id)
+ err = perrors.Errorf("the db instance '%s' is closed already", db.id)
return
}
var bc *mysql.BackendConnection
if bc, err = db.borrowConnection(ctx); err != nil {
- err = errors.WithStack(err)
+ err = perrors.WithStack(err)
return
}
undoPending := db.pending()
if len(args) > 0 {
- res, warn, err = bc.PrepareQueryArgsIterRow(sql, args)
+ res, err = bc.PrepareQueryArgs(sql, args)
} else {
- res, warn, err = bc.ExecuteWithWarningCountIterRow(sql)
+ res, err = bc.ExecuteWithWarningCountIterRow(sql)
}
if err != nil {
@@ -437,20 +490,11 @@ func (db *AtomDB) Call(ctx context.Context, sql string, args ...interface{}) (re
return
}
- if len(res.GetFields()) < 1 {
+ res.(*mysql.RawResult).SetCloser(func() error {
undoPending()
db.returnConnection(bc)
- return
- }
-
- res = &proto.CloseableResult{
- Result: res,
- Closer: func() error {
- undoPending()
- db.returnConnection(bc)
- return nil
- },
- }
+ return nil
+ })
return
}
@@ -515,36 +559,38 @@ func (db *AtomDB) SetWeight(weight proto.Weight) error {
func (db *AtomDB) borrowConnection(ctx context.Context) (*mysql.BackendConnection, error) {
bcp := (*BackendResourcePool)(db.pool)
- //log.Infof("^^^^^ begin borrow conn: active=%d, available=%d", db.pool.Active(), db.pool.Available())
+ //var (
+ // active0, available0 = db.pool.Active(), db.pool.Available()
+ //)
res, err := bcp.Get(ctx)
- //log.Infof("^^^^^ end borrow conn: active=%d, available=%d", db.pool.Active(), db.pool.Available())
+ // log.Infof("^^^^^ borrow conn: %d/%d => %d/%d", available0, active0, db.pool.Active(), db.pool.Available())
if err != nil {
- return nil, errors.WithStack(err)
+ return nil, perrors.WithStack(err)
}
return res, nil
}
func (db *AtomDB) returnConnection(bc *mysql.BackendConnection) {
db.pool.Put(bc)
- //log.Infof("^^^^^ return conn: active=%d, available=%d", db.pool.Active(), db.pool.Available())
+ // log.Infof("^^^^^ return conn: active=%d, available=%d", db.pool.Active(), db.pool.Available())
}
-type defaultRuntime struct {
- ns *namespace.Namespace
-}
+type defaultRuntime namespace.Namespace
-func (pi *defaultRuntime) Begin(ctx *proto.Context) (proto.Tx, error) {
+func (pi *defaultRuntime) Begin(ctx context.Context) (proto.Tx, error) {
+ _, span := Tracer.Start(ctx, "defaultRuntime.Begin")
+ defer span.End()
tx := &compositeTx{
id: nextTxID(),
rt: pi,
txs: make(map[string]*atomTx),
}
- log.Debugf("begin transaction: %s", tx.String())
+ log.Debugf("begin transaction: %s", tx)
return tx, nil
}
func (pi *defaultRuntime) Namespace() *namespace.Namespace {
- return pi.ns
+ return (*namespace.Namespace)(pi)
}
func (pi *defaultRuntime) Query(ctx context.Context, db string, query string, args ...interface{}) (proto.Result, error) {
@@ -556,7 +602,7 @@ func (pi *defaultRuntime) Exec(ctx context.Context, db string, query string, arg
ctx = rcontext.WithWrite(ctx)
res, err := pi.call(ctx, db, query, args...)
if err != nil {
- return nil, errors.WithStack(err)
+ return nil, perrors.WithStack(err)
}
if closer, ok := res.(io.Closer); ok {
@@ -568,91 +614,107 @@ func (pi *defaultRuntime) Exec(ctx context.Context, db string, query string, arg
}
func (pi *defaultRuntime) Execute(ctx *proto.Context) (res proto.Result, warn uint16, err error) {
+ var span trace.Span
+ ctx.Context, span = Tracer.Start(ctx.Context, "defaultRuntime.Execute")
execStart := time.Now()
defer func() {
+ span.End()
metrics.ExecuteDuration.Observe(time.Since(execStart).Seconds())
}()
- args := pi.extractArgs(ctx)
+ args := ctx.GetArgs()
if direct := rcontext.IsDirect(ctx.Context); direct {
return pi.callDirect(ctx, args)
}
var (
- ru = pi.ns.Rule()
+ ru = pi.Namespace().Rule()
plan proto.Plan
c = ctx.Context
)
- c = rcontext.WithRule(c, ru)
c = rcontext.WithSQL(c, ctx.GetQuery())
c = rcontext.WithSchema(c, ctx.Schema)
- c = rcontext.WithDBGroup(c, pi.ns.DBGroups()[0])
+ c = rcontext.WithTenant(c, ctx.Tenant)
+ c = rcontext.WithHints(c, ctx.Stmt.Hints)
start := time.Now()
- if plan, err = pi.ns.Optimizer().Optimize(c, pi, ctx.Stmt.StmtNode, args...); err != nil {
- err = errors.WithStack(err)
+
+ var opt proto.Optimizer
+ if opt, err = optimize.NewOptimizer(ru, ctx.Stmt.Hints, ctx.Stmt.StmtNode, args); err != nil {
+ err = perrors.WithStack(err)
+ return
+ }
+
+ if plan, err = opt.Optimize(c); err != nil {
+ err = perrors.WithStack(err)
return
}
metrics.OptimizeDuration.Observe(time.Since(start).Seconds())
if res, err = plan.ExecIn(c, pi); err != nil {
// TODO: how to warp error packet
- err = errors.WithStack(err)
+ err = perrors.WithStack(err)
return
}
- go writeRow(res)
-
return
}
func (pi *defaultRuntime) callDirect(ctx *proto.Context, args []interface{}) (res proto.Result, warn uint16, err error) {
- res, warn, err = pi.ns.DB0(ctx.Context).Call(rcontext.WithWrite(ctx.Context), ctx.GetQuery(), args...)
+ res, warn, err = pi.Namespace().DB0(ctx.Context).Call(rcontext.WithWrite(ctx.Context), ctx.GetQuery(), args...)
if err != nil {
+ err = perrors.WithStack(err)
return
}
- go writeRow(res)
return
}
-func (pi *defaultRuntime) extractArgs(ctx *proto.Context) []interface{} {
- if ctx.Stmt == nil || len(ctx.Stmt.BindVars) < 1 {
- return nil
+func (pi *defaultRuntime) call(ctx context.Context, group, query string, args ...interface{}) (proto.Result, error) {
+ db := selectDB(ctx, group, pi.Namespace())
+ if db == nil {
+ return nil, perrors.Errorf("cannot get upstream database %s", group)
}
+ log.Debugf("call upstream: db=%s, id=%s, sql=\"%s\", args=%v", group, db.ID(), query, args)
+ // TODO: how to pass warn???
+ res, _, err := db.Call(ctx, query, args...)
- var (
- keys = make([]string, 0, len(ctx.Stmt.BindVars))
- args = make([]interface{}, 0, len(ctx.Stmt.BindVars))
- )
-
- for k := range ctx.Stmt.BindVars {
- keys = append(keys, k)
- }
- sort.Strings(keys)
- for _, k := range keys {
- args = append(args, ctx.Stmt.BindVars[k])
- }
- return args
+ return res, err
}
-func (pi *defaultRuntime) call(ctx context.Context, group, query string, args ...interface{}) (proto.Result, error) {
+// select db by group
+func selectDB(ctx context.Context, group string, ns *namespace.Namespace) proto.DB {
if len(group) < 1 { // empty db, select first
- if groups := pi.ns.DBGroups(); len(groups) > 0 {
+ if groups := ns.DBGroups(); len(groups) > 0 {
group = groups[0]
}
}
- db := pi.ns.DB(ctx, group)
- if db == nil {
- return nil, errors.Errorf("cannot get upstream database %s", group)
+ var (
+ db proto.DB
+ hintType hint.Type
+ )
+ // write request
+ if !rcontext.IsRead(ctx) {
+ return ns.DBMaster(ctx, group)
+ }
+ // extracts hints
+ hints := rcontext.Hints(ctx)
+ for _, v := range hints {
+ if v.Type == hint.TypeMaster || v.Type == hint.TypeSlave {
+ hintType = v.Type
+ break
+ }
}
-
- log.Debugf("call upstream: db=%s, sql=\"%s\", args=%v", group, query, args)
- // TODO: how to pass warn???
- res, _, err := db.Call(ctx, query, args...)
-
- return res, err
+ switch hintType {
+ case hint.TypeMaster:
+ db = ns.DBMaster(ctx, group)
+ case hint.TypeSlave:
+ db = ns.DBSlave(ctx, group)
+ default:
+ db = ns.DB(ctx, group)
+ }
+ return db
}
var (
@@ -666,59 +728,3 @@ func nextTxID() int64 {
})
return _txIds.Generate().Int64()
}
-
-// writeRow write data to the chan
-func writeRow(result proto.Result) {
- var res *mysql.Result
- switch val := result.(type) {
- case *proto.CloseableResult:
- res = val.Result.(*mysql.Result)
- case *mysql.Result:
- res = val
- default:
- panic("unreachable")
- }
-
- defer func() {
- close(res.DataChan)
- }()
-
- if len(res.GetFields()) <= 0 {
- return
- }
- var (
- err error
- has bool
- rowIter mysql.Iter
- row = res.GetRows()[0]
- )
-
- switch row.(type) {
- case *mysql.BinaryRow:
- for i := 0; i < len(res.GetRows()); i++ {
- data := &mysql.BinaryIterRow{IterRow: &mysql.IterRow{
- Row: &res.GetRows()[i].(*mysql.BinaryRow).Row,
- }}
- res.DataChan <- data
- }
- return
- case *mysql.TextRow:
- for i := 0; i < len(res.GetRows()); i++ {
- data := &mysql.TextIterRow{IterRow: &mysql.IterRow{
- Row: &res.GetRows()[i].(*mysql.TextRow).Row,
- }}
- res.DataChan <- data
- }
- return
- }
-
- switch r := row.(type) {
- case *mysql.BinaryIterRow:
- rowIter = r
- case *mysql.TextIterRow:
- rowIter = r
- }
- for has, err = rowIter.Next(); has && err == nil; has, err = rowIter.Next() {
- res.DataChan <- row
- }
-}
diff --git a/pkg/runtime/runtime_test.go b/pkg/runtime/runtime_test.go
index 804fbcf94..2686b5608 100644
--- a/pkg/runtime/runtime_test.go
+++ b/pkg/runtime/runtime_test.go
@@ -30,20 +30,18 @@ import (
import (
"github.com/arana-db/arana/pkg/runtime/namespace"
- "github.com/arana-db/arana/testdata"
)
func TestLoad(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
- opt := testdata.NewMockOptimizer(ctrl)
const schemaName = "FakeSchema"
rt, err := Load(schemaName)
assert.Error(t, err)
assert.Nil(t, rt)
- _ = namespace.Register(namespace.New(schemaName, opt))
+ _ = namespace.Register(namespace.New(schemaName))
defer func() {
_ = namespace.Unregister(schemaName)
}()
diff --git a/pkg/schema/loader.go b/pkg/schema/loader.go
new file mode 100644
index 000000000..d97c6e8c6
--- /dev/null
+++ b/pkg/schema/loader.go
@@ -0,0 +1,251 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package schema
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "strings"
+)
+
+import (
+ "github.com/pkg/errors"
+)
+
+import (
+ "github.com/arana-db/arana/pkg/proto"
+ "github.com/arana-db/arana/pkg/runtime"
+ rcontext "github.com/arana-db/arana/pkg/runtime/context"
+ "github.com/arana-db/arana/pkg/util/log"
+)
+
+const (
+ orderByOrdinalPosition = " ORDER BY ORDINAL_POSITION"
+ tableMetadataNoOrder = "SELECT TABLE_NAME, COLUMN_NAME, DATA_TYPE, COLUMN_KEY, EXTRA, COLLATION_NAME, ORDINAL_POSITION FROM information_schema.columns WHERE TABLE_SCHEMA=database()"
+ tableMetadataSQL = tableMetadataNoOrder + orderByOrdinalPosition
+ tableMetadataSQLInTables = tableMetadataNoOrder + " AND TABLE_NAME IN (%s)" + orderByOrdinalPosition
+ indexMetadataSQL = "SELECT TABLE_NAME, INDEX_NAME FROM information_schema.statistics WHERE TABLE_SCHEMA=database() AND TABLE_NAME IN (%s)"
+)
+
+func init() {
+ proto.RegisterSchemaLoader(NewSimpleSchemaLoader())
+}
+
+type SimpleSchemaLoader struct {
+ // key format is schema.table
+ metadataCache map[string]*proto.TableMetadata
+}
+
+func NewSimpleSchemaLoader() *SimpleSchemaLoader {
+ return &SimpleSchemaLoader{metadataCache: make(map[string]*proto.TableMetadata)}
+}
+
+func (l *SimpleSchemaLoader) Load(ctx context.Context, schema string, tables []string) (map[string]*proto.TableMetadata, error) {
+ var (
+ tableMetadataMap = make(map[string]*proto.TableMetadata, len(tables))
+ indexMetadataMap map[string][]*proto.IndexMetadata
+ columnMetadataMap map[string][]*proto.ColumnMetadata
+ queryTables = make([]string, 0, len(tables))
+ )
+
+ if len(schema) > 0 {
+ for _, table := range tables {
+ qualifiedTblName := schema + "." + table
+ if l.metadataCache[qualifiedTblName] != nil {
+ tableMetadataMap[table] = l.metadataCache[qualifiedTblName]
+ } else {
+ queryTables = append(queryTables, table)
+ }
+ }
+ } else {
+ copy(queryTables, tables)
+ }
+
+ if len(queryTables) == 0 {
+ return tableMetadataMap, nil
+ }
+
+ ctx = rcontext.WithRead(rcontext.WithDirect(ctx))
+
+ var err error
+ if columnMetadataMap, err = l.LoadColumnMetadataMap(ctx, schema, queryTables); err != nil {
+ return nil, errors.WithStack(err)
+ }
+
+ if columnMetadataMap != nil {
+ if indexMetadataMap, err = l.LoadIndexMetadata(ctx, schema, queryTables); err != nil {
+ return nil, errors.WithStack(err)
+ }
+ }
+
+ for tableName, columns := range columnMetadataMap {
+ tableMetadataMap[tableName] = proto.NewTableMetadata(tableName, columns, indexMetadataMap[tableName])
+ if len(schema) > 0 {
+ l.metadataCache[schema+"."+tableName] = tableMetadataMap[tableName]
+ }
+ }
+
+ return tableMetadataMap, nil
+}
+
+func (l *SimpleSchemaLoader) LoadColumnMetadataMap(ctx context.Context, schema string, tables []string) (map[string][]*proto.ColumnMetadata, error) {
+ conn, err := runtime.Load(schema)
+ if err != nil {
+ return nil, errors.WithStack(err)
+ }
+ var (
+ resultSet proto.Result
+ ds proto.Dataset
+ )
+
+ if resultSet, err = conn.Query(ctx, "", getColumnMetadataSQL(tables)); err != nil {
+ log.Errorf("Load ColumnMetadata error when call db: %v", err)
+ return nil, errors.WithStack(err)
+ }
+
+ if ds, err = resultSet.Dataset(); err != nil {
+ log.Errorf("Load ColumnMetadata error when call db: %v", err)
+ return nil, errors.WithStack(err)
+ }
+
+ result := make(map[string][]*proto.ColumnMetadata, 0)
+ if ds == nil {
+ log.Error("Load ColumnMetadata error because the result is nil")
+ return nil, nil
+ }
+
+ var (
+ fields, _ = ds.Fields()
+ row proto.Row
+ cells = make([]proto.Value, len(fields))
+ )
+
+ for {
+ row, err = ds.Next()
+ if errors.Is(err, io.EOF) {
+ break
+ }
+ if err != nil {
+ return nil, errors.WithStack(err)
+ }
+
+ if err = row.Scan(cells); err != nil {
+ return nil, errors.WithStack(err)
+ }
+
+ tableName := convertInterfaceToStrNullable(cells[0])
+ columnName := convertInterfaceToStrNullable(cells[1])
+ dataType := convertInterfaceToStrNullable(cells[2])
+ columnKey := convertInterfaceToStrNullable(cells[3])
+ extra := convertInterfaceToStrNullable(cells[4])
+ collationName := convertInterfaceToStrNullable(cells[5])
+ ordinalPosition := convertInterfaceToStrNullable(cells[6])
+ result[tableName] = append(result[tableName], &proto.ColumnMetadata{
+ Name: columnName,
+ DataType: dataType,
+ Ordinal: ordinalPosition,
+ PrimaryKey: strings.EqualFold("PRI", columnKey),
+ Generated: strings.EqualFold("auto_increment", extra),
+ CaseSensitive: columnKey != "" && !strings.HasSuffix(collationName, "_ci"),
+ })
+ }
+
+ return result, nil
+}
+
+func convertInterfaceToStrNullable(value proto.Value) string {
+ if value == nil {
+ return ""
+ }
+
+ switch val := value.(type) {
+ case string:
+ return val
+ default:
+ return fmt.Sprint(val)
+ }
+}
+
+func (l *SimpleSchemaLoader) LoadIndexMetadata(ctx context.Context, schema string, tables []string) (map[string][]*proto.IndexMetadata, error) {
+ conn, err := runtime.Load(schema)
+ if err != nil {
+ return nil, errors.WithStack(err)
+ }
+
+ var (
+ resultSet proto.Result
+ ds proto.Dataset
+ )
+
+ if resultSet, err = conn.Query(ctx, "", getIndexMetadataSQL(tables)); err != nil {
+ return nil, errors.WithStack(err)
+ }
+
+ if ds, err = resultSet.Dataset(); err != nil {
+ return nil, errors.WithStack(err)
+ }
+
+ var (
+ fields, _ = ds.Fields()
+ row proto.Row
+ values = make([]proto.Value, len(fields))
+ result = make(map[string][]*proto.IndexMetadata, 0)
+ )
+
+ for {
+ row, err = ds.Next()
+ if errors.Is(err, io.EOF) {
+ break
+ }
+
+ if err != nil {
+ return nil, errors.WithStack(err)
+ }
+
+ if err = row.Scan(values); err != nil {
+ return nil, errors.WithStack(err)
+ }
+
+ tableName := convertInterfaceToStrNullable(values[0])
+ indexName := convertInterfaceToStrNullable(values[1])
+ result[tableName] = append(result[tableName], &proto.IndexMetadata{Name: indexName})
+ }
+
+ return result, nil
+}
+
+func getIndexMetadataSQL(tables []string) string {
+ tableParamList := make([]string, 0, len(tables))
+ for _, table := range tables {
+ tableParamList = append(tableParamList, "'"+table+"'")
+ }
+ return fmt.Sprintf(indexMetadataSQL, strings.Join(tableParamList, ","))
+}
+
+func getColumnMetadataSQL(tables []string) string {
+ if len(tables) == 0 {
+ return tableMetadataSQL
+ }
+ tableParamList := make([]string, len(tables))
+ for i, table := range tables {
+ tableParamList[i] = "'" + table + "'"
+ }
+ // TODO use strings.Builder in the future
+ return fmt.Sprintf(tableMetadataSQLInTables, strings.Join(tableParamList, ","))
+}
diff --git a/pkg/proto/schema_manager/loader_test.go b/pkg/schema/loader_test.go
similarity index 82%
rename from pkg/proto/schema_manager/loader_test.go
rename to pkg/schema/loader_test.go
index 456b3f891..ae9f0b98c 100644
--- a/pkg/proto/schema_manager/loader_test.go
+++ b/pkg/schema/loader_test.go
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package schema_manager
+package schema_test
import (
"context"
@@ -24,9 +24,9 @@ import (
import (
"github.com/arana-db/arana/pkg/config"
- "github.com/arana-db/arana/pkg/proto"
"github.com/arana-db/arana/pkg/runtime"
"github.com/arana-db/arana/pkg/runtime/namespace"
+ "github.com/arana-db/arana/pkg/schema"
)
func TestLoader(t *testing.T) {
@@ -46,15 +46,11 @@ func TestLoader(t *testing.T) {
cmds := make([]namespace.Command, 0)
cmds = append(cmds, namespace.UpsertDB(groupName, runtime.NewAtomDB(node)))
namespaceName := "dongjianhui"
- ns := namespace.New(namespaceName, nil, cmds...)
- namespace.Register(ns)
- rt, err := runtime.Load(namespaceName)
- if err != nil {
- panic(err)
- }
+ ns := namespace.New(namespaceName, cmds...)
+ _ = namespace.Register(ns)
schemeName := "employees"
tableName := "employees"
- s := &SimpleSchemaLoader{}
+ s := schema.NewSimpleSchemaLoader()
- s.Load(context.Background(), rt.(proto.VConn), schemeName, []string{tableName})
+ s.Load(context.Background(), schemeName, []string{tableName})
}
diff --git a/pkg/transformer/aggr_loader.go b/pkg/transformer/aggr_loader.go
index e8f9d445d..12d51b051 100644
--- a/pkg/transformer/aggr_loader.go
+++ b/pkg/transformer/aggr_loader.go
@@ -22,7 +22,7 @@ import (
)
type AggrLoader struct {
- Aggrs [][]string
+ Aggrs []string
Alias []string
Name []string
}
@@ -37,9 +37,8 @@ func LoadAggrs(fields []ast2.SelectElement) *AggrLoader {
if n == nil {
return
}
- fieldAggr := make([]string, 0, 10)
- fieldAggr = append(fieldAggr, n.Name())
- aggrLoader.Aggrs = append(aggrLoader.Aggrs, fieldAggr)
+
+ aggrLoader.Aggrs = append(aggrLoader.Aggrs, n.Name())
for _, arg := range n.Args() {
switch arg.Value().(type) {
case ast2.ColumnNameExpressionAtom:
@@ -52,11 +51,11 @@ func LoadAggrs(fields []ast2.SelectElement) *AggrLoader {
}
for i, field := range fields {
- aggrLoader.Alias[i] = field.Alias()
if field == nil {
continue
}
if f, ok := field.(*ast2.SelectElementFunction); ok {
+ aggrLoader.Alias[i] = field.Alias()
enter(f.Function().(*ast2.AggrFunction))
}
}
diff --git a/pkg/transformer/combiner.go b/pkg/transformer/combiner.go
index a550bd617..56c202fb5 100644
--- a/pkg/transformer/combiner.go
+++ b/pkg/transformer/combiner.go
@@ -18,8 +18,10 @@
package transformer
import (
+ "database/sql"
"io"
"math"
+ "sync"
)
import (
@@ -29,14 +31,14 @@ import (
)
import (
- mysql2 "github.com/arana-db/arana/pkg/constants/mysql"
- "github.com/arana-db/arana/pkg/mysql"
+ "github.com/arana-db/arana/pkg/dataset"
+ "github.com/arana-db/arana/pkg/mysql/rows"
"github.com/arana-db/arana/pkg/proto"
- ast2 "github.com/arana-db/arana/pkg/runtime/ast"
+ "github.com/arana-db/arana/pkg/resultx"
+ "github.com/arana-db/arana/pkg/runtime/ast"
)
-type combinerManager struct {
-}
+type combinerManager struct{}
type (
Combiner interface {
@@ -45,115 +47,108 @@ type (
)
func (c combinerManager) Merge(result proto.Result, loader *AggrLoader) (proto.Result, error) {
- if closer, ok := result.(io.Closer); ok {
- defer func() {
- _ = closer.Close()
- }()
- }
-
- result = &mysql.Result{
- Fields: result.GetFields(),
- Rows: result.GetRows(),
- DataChan: make(chan proto.Row, 1),
- }
-
if len(loader.Aggrs) < 1 {
return result, nil
}
- rows := result.GetRows()
- rowsLen := len(rows)
- if rowsLen < 1 {
- return result, nil
+ ds, err := result.Dataset()
+ if err != nil {
+ return nil, errors.WithStack(err)
}
- mergeRows := make([]proto.Row, 0, 1)
- mergeVals := make([]*proto.Value, 0, len(loader.Aggrs))
+ defer func() {
+ _ = ds.Close()
+ }()
+
+ mergeVals := make([]proto.Value, 0, len(loader.Aggrs))
for i := 0; i < len(loader.Aggrs); i++ {
- switch loader.Aggrs[i][0] {
- case ast2.AggrAvg:
- mergeVals = append(mergeVals, &proto.Value{Typ: mysql2.FieldTypeDecimal, Val: gxbig.NewDecFromInt(0), Len: 8})
- case ast2.AggrMin:
- mergeVals = append(mergeVals, &proto.Value{Typ: mysql2.FieldTypeDecimal, Val: gxbig.NewDecFromInt(math.MaxInt64), Len: 8})
- case ast2.AggrMax:
- mergeVals = append(mergeVals, &proto.Value{Typ: mysql2.FieldTypeDecimal, Val: gxbig.NewDecFromInt(math.MinInt64), Len: 8})
+ switch loader.Aggrs[i] {
+ case ast.AggrAvg:
+ mergeVals = append(mergeVals, gxbig.NewDecFromInt(0))
+ case ast.AggrMin:
+ mergeVals = append(mergeVals, gxbig.NewDecFromInt(math.MaxInt64))
+ case ast.AggrMax:
+ mergeVals = append(mergeVals, gxbig.NewDecFromInt(math.MinInt64))
default:
- mergeVals = append(mergeVals, &proto.Value{Typ: mysql2.FieldTypeLongLong, Val: gxbig.NewDecFromInt(0), Len: 8})
+ mergeVals = append(mergeVals, gxbig.NewDecFromInt(0))
}
}
- for _, row := range rows {
- tRow := &mysql.TextRow{
- Row: mysql.Row{
- Content: row.Data(),
- ResultSet: &mysql.ResultSet{
- Columns: row.Fields(),
- ColumnNames: row.Columns(),
- },
- },
+ var (
+ row proto.Row
+ fields, _ = ds.Fields()
+ vals = make([]proto.Value, len(fields))
+
+ isBinary bool
+ isBinaryOnce sync.Once
+ )
+
+ for {
+ row, err = ds.Next()
+ if errors.Is(err, io.EOF) {
+ break
}
- vals, err := tRow.Decode()
+
if err != nil {
- return result, errors.WithStack(err)
+ return nil, errors.Wrap(err, "failed to aggregate values")
}
- if vals == nil {
- continue
+ if err = row.Scan(vals); err != nil {
+ return nil, errors.Wrap(err, "failed to aggregate values")
}
+ isBinaryOnce.Do(func() {
+ isBinary = row.IsBinary()
+ })
+
for aggrIdx := range loader.Aggrs {
- dummyVal := mergeVals[aggrIdx].Val.(*gxbig.Decimal)
- switch loader.Aggrs[aggrIdx][0] {
- case ast2.AggrMax:
- if v, ok := vals[aggrIdx].Val.([]uint8); ok {
- floatDecimal, err := gxbig.NewDecFromString(string(v))
- if err != nil {
- return nil, errors.WithStack(err)
- }
- if dummyVal.Compare(floatDecimal) < 0 {
- dummyVal = floatDecimal
- }
- }
- case ast2.AggrMin:
- if v, ok := vals[aggrIdx].Val.([]uint8); ok {
- floatDecimal, err := gxbig.NewDecFromString(string(v))
- if err != nil {
- return nil, errors.WithStack(err)
- }
- if dummyVal.Compare(floatDecimal) > 0 {
- dummyVal = floatDecimal
- }
+ dummyVal := mergeVals[aggrIdx].(*gxbig.Decimal)
+ var (
+ s sql.NullString
+ floatDecimal *gxbig.Decimal
+ )
+ _ = s.Scan(vals[aggrIdx])
+
+ if !s.Valid {
+ continue
+ }
+
+ if floatDecimal, err = gxbig.NewDecFromString(s.String); err != nil {
+ return nil, errors.WithStack(err)
+ }
+
+ switch loader.Aggrs[aggrIdx] {
+ case ast.AggrMax:
+ if dummyVal.Compare(floatDecimal) < 0 {
+ dummyVal = floatDecimal
}
- case ast2.AggrSum, ast2.AggrCount:
- if v, ok := vals[aggrIdx].Val.([]uint8); ok {
- floatDecimal, err := gxbig.NewDecFromString(string(v))
- if err != nil {
- return nil, errors.WithStack(err)
- }
- gxbig.DecimalAdd(dummyVal, floatDecimal, dummyVal)
+ case ast.AggrMin:
+ if dummyVal.Compare(floatDecimal) > 0 {
+ dummyVal = floatDecimal
}
+ case ast.AggrSum, ast.AggrCount:
+ _ = gxbig.DecimalAdd(dummyVal, floatDecimal, dummyVal)
}
- mergeVals[aggrIdx].Val = dummyVal
+ mergeVals[aggrIdx] = dummyVal
}
}
for aggrIdx := range loader.Aggrs {
- val := mergeVals[aggrIdx].Val.(*gxbig.Decimal)
- mergeVals[aggrIdx].Val, _ = val.ToFloat64()
- mergeVals[aggrIdx].Raw = []byte(val.String())
- mergeVals[aggrIdx].Len = len(mergeVals[aggrIdx].Raw)
+ val := mergeVals[aggrIdx].(*gxbig.Decimal)
+ mergeVals[aggrIdx] = val
}
- r := &mysql.TextRow{}
- row := r.Encode(mergeVals, result.GetFields(), loader.Alias).(*mysql.TextRow)
- mergeRows = append(mergeRows, &row.Row)
-
- return &mysql.Result{
- Fields: result.GetFields(),
- Rows: mergeRows,
- AffectedRows: 1,
- InsertId: 0,
- DataChan: make(chan proto.Row, 1),
- }, nil
+
+ ret := &dataset.VirtualDataset{
+ Columns: fields,
+ }
+
+ if isBinary {
+ ret.Rows = append(ret.Rows, rows.NewBinaryVirtualRow(fields, mergeVals))
+ } else {
+ ret.Rows = append(ret.Rows, rows.NewTextVirtualRow(fields, mergeVals))
+ }
+
+ return resultx.New(resultx.WithDataset(ret)), nil
}
func NewCombinerManager() Combiner {
diff --git a/pkg/util/bufferpool/bufferpool.go b/pkg/util/bufferpool/bufferpool.go
new file mode 100644
index 000000000..a279e96c5
--- /dev/null
+++ b/pkg/util/bufferpool/bufferpool.go
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package bufferpool
+
+import (
+ "bytes"
+ "sync"
+)
+
+var _bufferPool sync.Pool
+
+// Get borrows a Buffer from pool.
+func Get() *bytes.Buffer {
+ if exist, ok := _bufferPool.Get().(*bytes.Buffer); ok {
+ return exist
+ }
+ return new(bytes.Buffer)
+}
+
+// Put returns a Buffer to pool.
+func Put(b *bytes.Buffer) {
+ if b == nil {
+ return
+ }
+ const maxCap = 1024 * 1024
+ // drop huge buff directly, if cap is over 1MB
+ if b.Cap() > maxCap {
+ return
+ }
+ b.Reset()
+ _bufferPool.Put(b)
+}
diff --git a/pkg/util/bytefmt/bytefmt.go b/pkg/util/bytefmt/bytefmt.go
new file mode 100644
index 000000000..8be5e202c
--- /dev/null
+++ b/pkg/util/bytefmt/bytefmt.go
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package bytefmt
+
+import (
+ "errors"
+ "strconv"
+ "strings"
+ "unicode"
+)
+
+const (
+ BYTE = 1 << (10 * iota)
+ KILOBYTE
+ MEGABYTE
+ GIGABYTE
+ TERABYTE
+ PETABYTE
+ EXABYTE
+)
+
+var invalidByteQuantityError = errors.New("byte quantity must be a positive integer with a unit of measurement like M, MB, MiB, G, GiB, or GB")
+
+// ByteSize returns a human-readable byte string of the form 10M, 12.5K, and so forth. The following units are available:
+// E: Exabyte
+// P: Petabyte
+// T: Terabyte
+// G: Gigabyte
+// M: Megabyte
+// K: Kilobyte
+// B: Byte
+// The unit that results in the smallest number greater than or equal to 1 is always chosen.
+func ByteSize(bytes uint64) string {
+ unit := ""
+ value := float64(bytes)
+
+ switch {
+ case bytes >= EXABYTE:
+ unit = "E"
+ value = value / EXABYTE
+ case bytes >= PETABYTE:
+ unit = "P"
+ value = value / PETABYTE
+ case bytes >= TERABYTE:
+ unit = "T"
+ value = value / TERABYTE
+ case bytes >= GIGABYTE:
+ unit = "G"
+ value = value / GIGABYTE
+ case bytes >= MEGABYTE:
+ unit = "M"
+ value = value / MEGABYTE
+ case bytes >= KILOBYTE:
+ unit = "K"
+ value = value / KILOBYTE
+ case bytes >= BYTE:
+ unit = "B"
+ case bytes == 0:
+ return "0"
+ }
+
+ result := strconv.FormatFloat(value, 'f', 1, 64)
+ result = strings.TrimSuffix(result, ".0")
+ return result + unit
+}
+
+// ToBytes parses a string formatted by ByteSize as bytes. Note binary-prefixed and SI prefixed units both mean a base-2 units
+// KB = K = KiB = 1024
+// MB = M = MiB = 1024 * K
+// GB = G = GiB = 1024 * M
+// TB = T = TiB = 1024 * G
+// PB = P = PiB = 1024 * T
+// EB = E = EiB = 1024 * P
+func ToBytes(s string) (uint64, error) {
+ s = strings.TrimSpace(s)
+ s = strings.ToUpper(s)
+
+ i := strings.IndexFunc(s, unicode.IsLetter)
+
+ if i == -1 {
+ return 0, invalidByteQuantityError
+ }
+
+ bytesString, multiple := s[:i], s[i:]
+ bytes, err := strconv.ParseFloat(bytesString, 64)
+ if err != nil || bytes <= 0 {
+ return 0, invalidByteQuantityError
+ }
+
+ switch multiple {
+ case "E", "EB", "EIB":
+ return uint64(bytes * EXABYTE), nil
+ case "P", "PB", "PIB":
+ return uint64(bytes * PETABYTE), nil
+ case "T", "TB", "TIB":
+ return uint64(bytes * TERABYTE), nil
+ case "G", "GB", "GIB":
+ return uint64(bytes * GIGABYTE), nil
+ case "M", "MB", "MIB":
+ return uint64(bytes * MEGABYTE), nil
+ case "K", "KB", "KIB":
+ return uint64(bytes * KILOBYTE), nil
+ case "B":
+ return uint64(bytes), nil
+ default:
+ return 0, invalidByteQuantityError
+ }
+}
diff --git a/pkg/util/bytefmt/bytefmt_test.go b/pkg/util/bytefmt/bytefmt_test.go
new file mode 100644
index 000000000..1d504642e
--- /dev/null
+++ b/pkg/util/bytefmt/bytefmt_test.go
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package bytefmt
+
+import (
+ "testing"
+)
+
+import (
+ "github.com/stretchr/testify/assert"
+)
+
+func TestByteSize(t *testing.T) {
+ tables := []struct {
+ in uint64
+ want string
+ }{
+ {in: ^uint64(0), want: "16E"},
+ {in: 10 * EXABYTE, want: "10E"},
+ {in: 10.5 * EXABYTE, want: "10.5E"},
+ {in: 10 * PETABYTE, want: "10P"},
+ {in: 10.5 * PETABYTE, want: "10.5P"},
+ {in: 10 * TERABYTE, want: "10T"},
+ {in: 10.5 * TERABYTE, want: "10.5T"},
+ {in: 10 * GIGABYTE, want: "10G"},
+ {in: 10.5 * GIGABYTE, want: "10.5G"},
+ {in: 10 * MEGABYTE, want: "10M"},
+ {in: 10.5 * MEGABYTE, want: "10.5M"},
+ {in: 10 * KILOBYTE, want: "10K"},
+ {in: 10.5 * KILOBYTE, want: "10.5K"},
+ {in: 268435456, want: "256M"},
+ }
+ for i := 0; i < len(tables); i++ {
+ assert.Equal(t, tables[i].want, ByteSize(tables[i].in))
+ }
+}
+
+func TestToBytes(t *testing.T) {
+ tables := []struct {
+ in string
+ want uint64
+ }{
+ {in: "4.5KB", want: 4608},
+ {in: "13.5KB", want: 13824},
+ {in: "5MB", want: 5 * MEGABYTE},
+ {in: "5mb", want: 5 * MEGABYTE},
+ {in: "256M", want: 268435456},
+ {in: "2GB", want: 2 * GIGABYTE},
+ {in: "3TB", want: 3 * TERABYTE},
+ {in: "3PB", want: 3 * PETABYTE},
+ {in: "3EB", want: 3 * EXABYTE},
+ }
+ t.Log(0x120a)
+ for i := 0; i < len(tables); i++ {
+ byteSize, err := ToBytes(tables[i].in)
+ assert.NoError(t, err)
+ assert.Equal(t, tables[i].want, byteSize)
+ }
+}
diff --git a/pkg/util/bytesconv/bytesconv_test.go b/pkg/util/bytesconv/bytesconv_test.go
index ff4d981df..acd1861fb 100644
--- a/pkg/util/bytesconv/bytesconv_test.go
+++ b/pkg/util/bytesconv/bytesconv_test.go
@@ -29,8 +29,10 @@ import (
"time"
)
-var testString = "Albert Einstein: Logic will get you from A to B. Imagination will take you everywhere."
-var testBytes = []byte(testString)
+var (
+ testString = "Albert Einstein: Logic will get you from A to B. Imagination will take you everywhere."
+ testBytes = []byte(testString)
+)
func rawBytesToStr(b []byte) string {
return string(b)
diff --git a/pkg/util/env/env.go b/pkg/util/env/env.go
new file mode 100644
index 000000000..5ab95c604
--- /dev/null
+++ b/pkg/util/env/env.go
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Copyright 2020 Gin Core Team. All rights reserved.
+// Use of this source code is governed by a MIT style
+// license that can be found in the LICENSE file.
+
+package env
+
+import (
+ "os"
+)
+
+import (
+ "github.com/arana-db/arana/pkg/constants"
+)
+
+// IsDevelopEnvironment check is the develop environment
+func IsDevelopEnvironment() bool {
+ dev := os.Getenv(constants.EnvDevelopEnvironment)
+ if dev == "1" {
+ return true
+ }
+
+ return false
+}
diff --git a/pkg/util/log/logging.go b/pkg/util/log/logging.go
index f0f1ee7ac..a7a52b6e3 100644
--- a/pkg/util/log/logging.go
+++ b/pkg/util/log/logging.go
@@ -84,19 +84,14 @@ func (l *LogLevel) unmarshalText(text []byte) bool {
type Logger interface {
Debug(v ...interface{})
Debugf(format string, v ...interface{})
-
Info(v ...interface{})
Infof(format string, v ...interface{})
-
Warn(v ...interface{})
Warnf(format string, v ...interface{})
-
Error(v ...interface{})
Errorf(format string, v ...interface{})
-
Panic(v ...interface{})
Panicf(format string, v ...interface{})
-
Fatal(v ...interface{})
Fatalf(format string, v ...interface{})
}
@@ -147,7 +142,7 @@ func Init(logPath string, level LogLevel) {
log = zapLogger.Sugar()
}
-// SetLogger: customize yourself logger.
+// SetLogger customize yourself logger.
func SetLogger(logger Logger) {
log = logger
}
diff --git a/pkg/util/rand2/rand2.go b/pkg/util/rand2/rand2.go
index 3bda91246..6c69bbaf8 100644
--- a/pkg/util/rand2/rand2.go
+++ b/pkg/util/rand2/rand2.go
@@ -206,8 +206,8 @@ func Sample(population []interface{}, k int) (res []interface{}, err error) {
// Same as 'Sample' except it returns both the 'picked' sample set and the
// 'remaining' elements.
func PickN(population []interface{}, n int) (
- picked []interface{}, remaining []interface{}, err error) {
-
+ picked []interface{}, remaining []interface{}, err error,
+) {
total := len(population)
idxs, err := SampleInts(total, n)
if err != nil {
diff --git a/scripts/init.sql b/scripts/init.sql
index e092291b3..b737a6bd8 100644
--- a/scripts/init.sql
+++ b/scripts/init.sql
@@ -39,9 +39,9 @@
-- Any similarity to existing people is purely coincidental.
--
-DROP DATABASE IF EXISTS employees;
-CREATE DATABASE IF NOT EXISTS employees;
-USE employees;
+CREATE DATABASE IF NOT EXISTS employees_0000 CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
+
+USE employees_0000;
SELECT 'CREATING DATABASE STRUCTURE' as 'INFO';
@@ -108,18 +108,7 @@ CREATE TABLE salaries (
from_date DATE NOT NULL,
to_date DATE NOT NULL,
FOREIGN KEY (emp_no) REFERENCES employees (emp_no) ON DELETE CASCADE,
- PRIMARY KEY (emp_no, from_date)
+ PRIMARY KEY (emp_no, from_date),
+ KEY `from_date` (`from_date`)
)
;
-
-CREATE OR REPLACE VIEW dept_emp_latest_date AS
- SELECT emp_no, MAX(from_date) AS from_date, MAX(to_date) AS to_date
- FROM dept_emp
- GROUP BY emp_no;
-
-# shows only the current department for each employee
-CREATE OR REPLACE VIEW current_dept_emp AS
- SELECT l.emp_no, dept_no, l.from_date, l.to_date
- FROM dept_emp d
- INNER JOIN dept_emp_latest_date l
- ON d.emp_no=l.emp_no AND d.from_date=l.from_date AND l.to_date = d.to_date;
diff --git a/scripts/sequence.sql b/scripts/sequence.sql
index 7778a60f0..0df0add1c 100644
--- a/scripts/sequence.sql
+++ b/scripts/sequence.sql
@@ -15,18 +15,17 @@
-- limitations under the License.
--
-CREATE DATABASE IF NOT EXISTS employees;
-USE employees;
+CREATE DATABASE IF NOT EXISTS employees_0000 CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
-CREATE TABLE IF NOT EXISTS `sequence`
+CREATE TABLE IF NOT EXISTS `employees_0000`.`sequence`
(
`id` BIGINT UNSIGNED NOT NULL AUTO_INCREMENT,
`name` VARCHAR(64) NOT NULL,
`value` BIGINT NOT NULL,
`step` INT NOT NULL DEFAULT 10000,
- `created_at` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
- `modified_at` TIMESTAMP NOT NULL,
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
PRIMARY KEY (`id`),
UNIQUE KEY `uk_name` (`name`)
) ENGINE = InnoDB
- DEFAULT CHARSET = utf8;
+ DEFAULT CHARSET = utf8mb4;
diff --git a/scripts/sharding.sql b/scripts/sharding.sql
index 5507c44ad..e277dff66 100644
--- a/scripts/sharding.sql
+++ b/scripts/sharding.sql
@@ -15,41 +15,653 @@
-- limitations under the License.
--
-CREATE DATABASE IF NOT EXISTS employees;
-USE employees;
-
-DELIMITER //
-CREATE PROCEDURE sp_create_tab()
-BEGIN
- SET @str = ' (
-`id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
-`uid` BIGINT(20) UNSIGNED NOT NULL,
-`name` VARCHAR(255) NOT NULL,
-`score` DECIMAL(6,2) DEFAULT ''0'',
-`nickname` VARCHAR(255) DEFAULT NULL,
-`gender` TINYINT(4) NULL,
-`birth_year` SMALLINT(5) UNSIGNED DEFAULT ''0'',
-`created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
-`modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
-PRIMARY KEY (`id`),
-UNIQUE KEY `uk_uid` (`uid`)
-) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
-';
-
- SET @j = 0;
- WHILE @j < 32
- DO
- SET @table = CONCAT('student_', LPAD(@j, 4, '0'));
- SET @ddl = CONCAT('CREATE TABLE IF NOT EXISTS ', @table, @str);
- PREPARE ddl FROM @ddl;
- EXECUTE ddl;
- SET @j = @j + 1;
- END WHILE;
-END
-//
-
-DELIMITER ;
-CALL sp_create_tab;
-DROP PROCEDURE sp_create_tab;
-
-insert into student_0001 values (1, 1, 'scott', 95, 'nc_scott', 0, 16, now(), now());
+CREATE DATABASE IF NOT EXISTS employees_0000 CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
+CREATE DATABASE IF NOT EXISTS employees_0001 CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
+CREATE DATABASE IF NOT EXISTS employees_0002 CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
+CREATE DATABASE IF NOT EXISTS employees_0003 CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
+
+CREATE DATABASE IF NOT EXISTS employees_0000_r CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
+
+CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0000`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`),
+ KEY `nickname` (`nickname`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0001`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`),
+ KEY `nickname` (`nickname`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0002`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`),
+ KEY `nickname` (`nickname`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0003`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`),
+ KEY `nickname` (`nickname`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0004`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`),
+ KEY `nickname` (`nickname`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0005`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`),
+ KEY `nickname` (`nickname`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0006`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`),
+ KEY `nickname` (`nickname`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0007`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`),
+ KEY `nickname` (`nickname`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0001`.`student_0008`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`),
+ KEY `nickname` (`nickname`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0001`.`student_0009`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`),
+ KEY `nickname` (`nickname`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0001`.`student_0010`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`),
+ KEY `nickname` (`nickname`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0001`.`student_0011`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`),
+ KEY `nickname` (`nickname`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0001`.`student_0012`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`),
+ KEY `nickname` (`nickname`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0001`.`student_0013`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`),
+ KEY `nickname` (`nickname`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0001`.`student_0014`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`),
+ KEY `nickname` (`nickname`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0001`.`student_0015`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`),
+ KEY `nickname` (`nickname`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0002`.`student_0016`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`),
+ KEY `nickname` (`nickname`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0002`.`student_0017`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`),
+ KEY `nickname` (`nickname`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0002`.`student_0018`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`),
+ KEY `nickname` (`nickname`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0002`.`student_0019`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`),
+ KEY `nickname` (`nickname`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0002`.`student_0020`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`),
+ KEY `nickname` (`nickname`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0002`.`student_0021`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`),
+ KEY `nickname` (`nickname`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0002`.`student_0022`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`),
+ KEY `nickname` (`nickname`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0002`.`student_0023`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`),
+ KEY `nickname` (`nickname`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0003`.`student_0024`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`),
+ KEY `nickname` (`nickname`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0003`.`student_0025`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`),
+ KEY `nickname` (`nickname`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0003`.`student_0026`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`),
+ KEY `nickname` (`nickname`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0003`.`student_0027`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`),
+ KEY `nickname` (`nickname`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0003`.`student_0028`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`),
+ KEY `nickname` (`nickname`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0003`.`student_0029`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`),
+ KEY `nickname` (`nickname`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0003`.`student_0030`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`),
+ KEY `nickname` (`nickname`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0003`.`student_0031`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`),
+ KEY `nickname` (`nickname`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+
+CREATE TABLE IF NOT EXISTS `employees_0000_r`.`student_0000`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`),
+ KEY `nickname` (`nickname`)
+ ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0000_r`.`student_0001`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`),
+ KEY `nickname` (`nickname`)
+ ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0000_r`.`student_0002`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`),
+ KEY `nickname` (`nickname`)
+ ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0000_r`.`student_0003`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`),
+ KEY `nickname` (`nickname`)
+ ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0000_r`.`student_0004`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`),
+ KEY `nickname` (`nickname`)
+ ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0000_r`.`student_0005`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`),
+ KEY `nickname` (`nickname`)
+ ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0000_r`.`student_0006`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`),
+ KEY `nickname` (`nickname`)
+ ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS `employees_0000_r`.`student_0007`
+(
+ `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+ `uid` BIGINT(20) UNSIGNED NOT NULL,
+ `name` VARCHAR(255) NOT NULL,
+ `score` DECIMAL(6,2) DEFAULT '0',
+ `nickname` VARCHAR(255) DEFAULT NULL,
+ `gender` TINYINT(4) NULL,
+ `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0',
+ `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`id`),
+ UNIQUE KEY `uk_uid` (`uid`),
+ KEY `nickname` (`nickname`)
+ ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+
+INSERT INTO employees_0000.student_0001 VALUES (1, 1, 'scott', 95, 'nc_scott', 0, 16, NOW(), NOW());
diff --git a/test/dataset.go b/test/dataset.go
new file mode 100644
index 000000000..d2d754d3d
--- /dev/null
+++ b/test/dataset.go
@@ -0,0 +1,271 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package test
+
+import (
+ "bufio"
+ "database/sql"
+ "fmt"
+ "os"
+ "reflect"
+ "strconv"
+ "strings"
+ "time"
+)
+
+import (
+ "github.com/pkg/errors"
+
+ "go.uber.org/zap"
+
+ "gopkg.in/yaml.v3"
+)
+
+import (
+ "github.com/arana-db/arana/pkg/util/log"
+)
+
+const (
+ RowsAffected = "rowsAffected"
+ LastInsertId = "lastInsertId"
+ ValueInt = "valueInt"
+ ValueString = "valueString"
+ File = "file"
+)
+
+type (
+ // Message expected message or actual message for integration test
+ Message struct {
+ Kind string `yaml:"kind"` // dataset, expected
+ MetaData MetaData `yaml:"metadata"`
+ Data []*Data `yaml:"data"`
+ }
+
+ MetaData struct {
+ Tables []*Table `json:"tables"`
+ }
+
+ Data struct {
+ Name string `json:"name"`
+ Value [][]string `json:"value"`
+ }
+
+ Table struct {
+ Name string `yaml:"name"`
+ Columns []Column `yaml:"columns"`
+ }
+
+ Column struct {
+ Name string `yaml:"name"`
+ Type string `yaml:"type"`
+ }
+)
+
+type (
+ Cases struct {
+ Kind string `yaml:"kind"` // dataset, expected
+ ExecCases []*Case `yaml:"exec_cases"`
+ QueryRowsCases []*Case `yaml:"query_rows_cases"`
+ QueryRowCases []*Case `yaml:"query_row_cases"`
+ }
+
+ Case struct {
+ SQL string `yaml:"sql"`
+ Parameters string `yaml:"parameters,omitempty"`
+ Type string `yaml:"type"`
+ Sense []string `yaml:"sense"`
+ ExpectedResult *Expected `yaml:"expected"`
+ }
+
+ Expected struct {
+ ResultType string `yaml:"resultType"`
+ Value string `yaml:"value"`
+ }
+)
+
+func getDataFrom(rows *sql.Rows) ([][]string, []string, error) {
+ pr := func(t interface{}) (r string) {
+ r = "\\N"
+ switch v := t.(type) {
+ case *sql.NullBool:
+ if v.Valid {
+ r = strconv.FormatBool(v.Bool)
+ }
+ case *sql.NullString:
+ if v.Valid {
+ r = v.String
+ }
+ case *sql.NullInt64:
+ if v.Valid {
+ r = fmt.Sprintf("%6d", v.Int64)
+ }
+ case *sql.NullFloat64:
+ if v.Valid {
+ r = fmt.Sprintf("%.2f", v.Float64)
+ }
+ case *time.Time:
+ if v.Year() > 1900 {
+ r = v.Format("_2 Jan 2006")
+ }
+ default:
+ r = fmt.Sprintf("%#v", t)
+ }
+ return
+ }
+
+ c, _ := rows.Columns()
+ n := len(c)
+ field := make([]interface{}, 0, n)
+ for i := 0; i < n; i++ {
+ field = append(field, new(sql.NullString))
+ }
+
+ var converts [][]string
+
+ for rows.Next() {
+ if err := rows.Scan(field...); err != nil {
+ return nil, nil, err
+ }
+ row := make([]string, 0, n)
+ for i := 0; i < n; i++ {
+ col := pr(field[i])
+ row = append(row, col)
+ }
+ converts = append(converts, row)
+ }
+ return converts, c, nil
+}
+
+func (e *Expected) CompareRows(rows *sql.Rows, actual *Message) error {
+ data, columns, err := getDataFrom(rows)
+ if err != nil {
+ return err
+ }
+
+ actualMap := make(map[string][]map[string]string, 10) // table:key:value
+ for _, v := range actual.Data {
+ tb := actualMap[v.Name]
+ if tb == nil {
+ actualMap[v.Name] = make([]map[string]string, 0, 1000)
+ }
+ for _, rows := range v.Value {
+ val := make(map[string]string, len(rows))
+ for i, row := range rows {
+ val[actual.MetaData.Tables[0].Columns[i].Name] = row
+ actualMap[v.Name] = append(actualMap[v.Name], val)
+ }
+ }
+ }
+ return CompareWithActualSet(columns, data, actualMap)
+}
+
+func CompareWithActualSet(columns []string, driverSet [][]string, exceptSet map[string][]map[string]string) error {
+ for _, rows := range driverSet {
+ foundFlag := false
+ // for loop actual data
+ for _, actualRows := range exceptSet {
+ for _, actualRow := range actualRows {
+ if actualRow[columns[0]] == rows[0] {
+ foundFlag = true
+ for i := 1; i < len(rows); i++ {
+ if actualRow[columns[i]] != rows[i] {
+ foundFlag = false
+ }
+ }
+ }
+ }
+
+ if foundFlag {
+ break
+ }
+ }
+
+ if !foundFlag {
+ return errors.New("record not found")
+ }
+ }
+
+ return nil
+}
+
+func (e *Expected) CompareRow(result interface{}) error {
+ switch e.ResultType {
+ case ValueInt:
+ var cnt int
+ row, ok := result.(*sql.Row)
+ if !ok {
+ return errors.New("type error")
+ }
+ if err := row.Scan(&cnt); err != nil {
+ return err
+ }
+ if e.Value != fmt.Sprint(cnt) {
+ return errors.New("not equal")
+ }
+ case RowsAffected:
+ rowsAfferted, _ := result.(sql.Result).RowsAffected()
+ if fmt.Sprint(rowsAfferted) != e.Value {
+ return errors.New("not equal")
+ }
+ case LastInsertId:
+ lastInsertId, _ := result.(sql.Result).LastInsertId()
+ if fmt.Sprint(lastInsertId) != e.Value {
+ return errors.New("not equal")
+ }
+
+ }
+
+ return nil
+}
+
+// LoadYamlConfig load yaml config from path to val, val should be a pointer
+func LoadYamlConfig(path string, val interface{}) error {
+ if err := validInputIsPtr(val); err != nil {
+ log.Fatal("valid conf failed", zap.Error(err))
+ }
+ f, err := os.Open(path)
+ if err != nil {
+ return err
+ }
+ if err = yaml.NewDecoder(bufio.NewReader(f)).Decode(val); err != nil {
+ return err
+ }
+ return nil
+}
+
+func validInputIsPtr(conf interface{}) error {
+ tp := reflect.TypeOf(conf)
+ if tp.Kind() != reflect.Ptr {
+ return errors.New("conf should be pointer")
+ }
+ return nil
+}
+
+func GetValueByType(param string) (interface{}, error) {
+ kv := strings.Split(param, ":")
+ switch strings.ToLower(kv[1]) {
+ case "string":
+ return kv[0], nil
+ case "int":
+ return strconv.ParseInt(kv[0], 10, 64)
+ case "float":
+ return strconv.ParseFloat(kv[0], 64)
+ }
+
+ return kv[0], nil
+}
diff --git a/test/integration_test.go b/test/integration_test.go
index baf3e8cce..737eabae0 100644
--- a/test/integration_test.go
+++ b/test/integration_test.go
@@ -24,7 +24,7 @@ import (
)
import (
- _ "github.com/go-sql-driver/mysql" // register mysql
+ _ "github.com/go-sql-driver/mysql"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
@@ -40,13 +40,17 @@ type IntegrationSuite struct {
}
func TestSuite(t *testing.T) {
- su := NewMySuite(WithMySQLServerAuth("root", "123456"), WithMySQLDatabase("employees"))
+ su := NewMySuite(
+ WithMySQLServerAuth("root", "123456"),
+ WithMySQLDatabase("employees"),
+ WithConfig("../integration_test/config/db_tbl/config.yaml"),
+ WithScriptPath("../scripts"),
+ // WithDevMode(), // NOTICE: UNCOMMENT IF YOU WANT TO DEBUG LOCAL ARANA SERVER!!!
+ )
suite.Run(t, &IntegrationSuite{su})
}
func (s *IntegrationSuite) TestBasicTx() {
- // TODO: skip temporarily, need to implement ref-count-down
- s.T().Skip()
var (
db = s.DB()
t = s.T()
@@ -183,13 +187,31 @@ func (s *IntegrationSuite) TestInsert() {
t = s.T()
)
result, err := db.Exec(`INSERT INTO employees ( emp_no, birth_date, first_name, last_name, gender, hire_date )
- VALUES (?, ?, ?, ?, ?, ?)`, 100001, "1992-01-07", "scott", "lewis", "M", "2014-09-01")
+ VALUES (?, ?, ?, ?, ?, ?) `, 100001, "1992-01-07", "scott", "lewis", "M", "2014-09-01")
assert.NoErrorf(t, err, "insert row error: %v", err)
affected, err := result.RowsAffected()
assert.NoErrorf(t, err, "insert row error: %v", err)
assert.Equal(t, int64(1), affected)
}
+func (s *IntegrationSuite) TestInsertOnDuplicateKey() {
+ var (
+ db = s.DB()
+ t = s.T()
+ )
+
+ i := 32
+ result, err := db.Exec(`INSERT IGNORE INTO student(id,uid,score,name,nickname,gender,birth_year)
+ values (?,?,?,?,?,?,?) ON DUPLICATE KEY UPDATE nickname='dump' `, 1654008174496657000, i, 3.14, fmt.Sprintf("fake_name_%d", i), fmt.Sprintf("fake_nickname_%d", i), 1, 2022)
+ assert.NoErrorf(t, err, "insert row error: %v", err)
+ _, err = result.RowsAffected()
+ assert.NoErrorf(t, err, "insert row error: %v", err)
+
+ _, err = db.Exec(`INSERT IGNORE INTO student(id,uid,score,name,nickname,gender,birth_year)
+ values (?,?,?,?,?,?,?) ON DUPLICATE KEY UPDATE uid=32 `, 1654008174496657000, i, 3.14, fmt.Sprintf("fake_name_%d", i), fmt.Sprintf("fake_nickname_%d", i), 1, 2022)
+ assert.Error(t, err, "insert row error: %v", err)
+}
+
func (s *IntegrationSuite) TestSelect() {
var (
db = s.DB()
@@ -197,7 +219,7 @@ func (s *IntegrationSuite) TestSelect() {
)
rows, err := db.Query(`SELECT emp_no, birth_date, first_name, last_name, gender, hire_date FROM employees
- WHERE emp_no = ?`, 100001)
+ WHERE emp_no = ?`, "100001")
assert.NoErrorf(t, err, "select row error: %v", err)
defer rows.Close()
@@ -255,6 +277,9 @@ func (s *IntegrationSuite) TestUpdate() {
assert.NoErrorf(t, err, "update row error: %v", err)
assert.Equal(t, int64(1), affected)
+
+ _, err = db.Exec("update student set score=100.0,uid=11 where uid = ?", 32)
+ assert.Error(t, err)
}
func (s *IntegrationSuite) TestDelete() {
@@ -298,7 +323,7 @@ func (s *IntegrationSuite) TestDropTable() {
t.Skip()
- //drop table physical name != logical name and physical name = logical name
+ // drop table physical name != logical name and physical name = logical name
result, err := db.Exec(`DROP TABLE student,salaries`)
assert.NoErrorf(t, err, "drop table error:%v", err)
@@ -306,7 +331,7 @@ func (s *IntegrationSuite) TestDropTable() {
assert.Equal(t, int64(0), affected)
assert.NoErrorf(t, err, "drop table error: %v", err)
- //drop again, return error
+ // drop again, return error
result, err = db.Exec(`DROP TABLE student,salaries`)
assert.Error(t, err, "drop table error: %v", err)
assert.Nil(t, result)
@@ -318,8 +343,10 @@ func (s *IntegrationSuite) TestJoinTable() {
t = s.T()
)
+ t.Skip()
+
sqls := []string{
- //shard & no shard
+ // shard & no shard
`select * from student join titles on student.id=titles.emp_no`,
// shard & no shard with alias
`select * from student join titles as b on student.id=b.emp_no`,
@@ -333,13 +360,12 @@ func (s *IntegrationSuite) TestJoinTable() {
_, err := db.Query(sql)
assert.NoErrorf(t, err, "join table error:%v", err)
}
- //with where
+ // with where
_, err := db.Query(`select * from student join titles on student.id=titles.emp_no where student.id=? and titles.emp_no=?`, 1, 2)
assert.NoErrorf(t, err, "join table error:%v", err)
}
func (s *IntegrationSuite) TestShardingAgg() {
- s.T().Skip()
var (
db = s.DB()
t = s.T()
@@ -471,3 +497,112 @@ func (s *IntegrationSuite) TestAlterTable() {
assert.Equal(t, int64(0), affected)
}
+
+func (s *IntegrationSuite) TestCreateIndex() {
+ var (
+ db = s.DB()
+ t = s.T()
+ )
+
+ result, err := db.Exec("create index `name` on student (name)")
+ assert.NoErrorf(t, err, "create index error: %v", err)
+ affected, err := result.RowsAffected()
+ assert.NoErrorf(t, err, "create index error: %v", err)
+ assert.Equal(t, int64(0), affected)
+}
+
+func (s *IntegrationSuite) TestDropIndex() {
+ var (
+ db = s.DB()
+ t = s.T()
+ )
+
+ result, err := db.Exec("drop index `nickname` on student")
+ assert.NoErrorf(t, err, "drop index error: %v", err)
+ affected, err := result.RowsAffected()
+ assert.NoErrorf(t, err, "drop index error: %v", err)
+
+ assert.Equal(t, int64(0), affected)
+}
+
+func (s *IntegrationSuite) TestShowColumns() {
+ var (
+ db = s.DB()
+ t = s.T()
+ )
+
+ result, err := db.Query("show columns from student")
+ assert.NoErrorf(t, err, "show columns error: %v", err)
+
+ defer result.Close()
+
+ affected, err := result.ColumnTypes()
+ assert.NoErrorf(t, err, "show columns: %v", err)
+ assert.Equal(t, affected[0].DatabaseTypeName(), "VARCHAR")
+}
+
+func (s *IntegrationSuite) TestShowCreate() {
+ var (
+ db = s.DB()
+ t = s.T()
+ )
+
+ row := db.QueryRow("show create table student")
+ var table, createStr string
+ assert.NoError(t, row.Scan(&table, &createStr))
+ assert.Equal(t, "student", table)
+}
+
+func (s *IntegrationSuite) TestDropTrigger() {
+ var (
+ db = s.DB()
+ t = s.T()
+ )
+
+ type tt struct {
+ sql string
+ }
+
+ for _, it := range []tt{
+ {"DROP TRIGGER arana"},
+ {"DROP TRIGGER employees_0000.arana"},
+ {"DROP TRIGGER IF EXISTS arana"},
+ {"DROP TRIGGER IF EXISTS employees_0000.arana"},
+ } {
+ t.Run(it.sql, func(t *testing.T) {
+ _, err := db.Exec(it.sql)
+ assert.NoError(t, err)
+ })
+ }
+}
+
+func (s *IntegrationSuite) TestHints() {
+ var (
+ db = s.DB()
+ t = s.T()
+ )
+
+ type tt struct {
+ sql string
+ args []interface{}
+ expectLen int
+ }
+
+ for _, it := range []tt{
+ {"/*A! master */ SELECT * FROM student WHERE uid = 42 AND 1=2", nil, 0},
+ {"/*A! slave */ SELECT * FROM student WHERE uid = ?", []interface{}{1}, 0},
+ {"/*A! master */ SELECT * FROM student WHERE uid = ?", []interface{}{1}, 1},
+ {"/*A! master */ SELECT * FROM student WHERE uid in (?)", []interface{}{1}, 1},
+ {"/*A! master */ SELECT * FROM student where uid between 1 and 10", nil, 1},
+ } {
+ t.Run(it.sql, func(t *testing.T) {
+ // select from logical table
+ rows, err := db.Query(it.sql, it.args...)
+ assert.NoError(t, err, "should query from sharding table successfully")
+ defer rows.Close()
+ data, _ := utils.PrintTable(rows)
+ assert.Equal(t, it.expectLen, len(data))
+ })
+ }
+
+}
diff --git a/test/suite.go b/test/suite.go
index 2391e3bd8..ed201f02b 100644
--- a/test/suite.go
+++ b/test/suite.go
@@ -41,8 +41,20 @@ import (
"github.com/arana-db/arana/testdata"
)
+const (
+ timeout = "1s"
+ readTimeout = "3s"
+ writeTimeout = "5s"
+)
+
type Option func(*MySuite)
+func WithDevMode() Option {
+ return func(mySuite *MySuite) {
+ mySuite.devMode = true
+ }
+}
+
func WithMySQLServerAuth(username, password string) Option {
return func(mySuite *MySuite) {
mySuite.username = username
@@ -56,9 +68,60 @@ func WithMySQLDatabase(database string) Option {
}
}
+func WithConfig(path string) Option {
+ return func(mySuite *MySuite) {
+ mySuite.configPath = path
+ }
+}
+
+func WithBootstrap(path string) Option {
+ return func(mySuite *MySuite) {
+ mySuite.bootstrapConfig = path
+ }
+}
+
+func WithScriptPath(path string) Option {
+ return func(mySuite *MySuite) {
+ mySuite.scriptPath = path
+ }
+}
+
+func (ms *MySuite) LoadActualDataSetPath(path string) error {
+ var msg *Message
+ err := LoadYamlConfig(path, &msg)
+ if err != nil {
+ return err
+ }
+ ms.actualDataset = msg
+ return nil
+}
+
+func (ms *MySuite) LoadExpectedDataSetPath(path string) error {
+ var msg *Message
+ err := LoadYamlConfig(path, &msg)
+ if err != nil {
+ return err
+ }
+ ms.expectedDataset = msg
+ return nil
+}
+
+func WithTestCasePath(path string) Option {
+ return func(mySuite *MySuite) {
+ var ts *Cases
+ err := LoadYamlConfig(path, &ts)
+ if err != nil {
+ return
+ }
+ mySuite.cases = ts
+ }
+}
+
type MySuite struct {
suite.Suite
+ devMode bool
+
username, password, database string
port int
@@ -67,12 +130,16 @@ type MySuite struct {
db *sql.DB
dbSync sync.Once
- tmpFile string
+ tmpFile, bootstrapConfig, configPath, scriptPath string
+
+ cases *Cases
+ actualDataset *Message
+ expectedDataset *Message
}
func NewMySuite(options ...Option) *MySuite {
ms := &MySuite{
- port: 3306,
+ port: 13306,
}
for _, it := range options {
it(ms)
@@ -80,6 +147,18 @@ func NewMySuite(options ...Option) *MySuite {
return ms
}
+func (ms *MySuite) ActualDataset() *Message {
+ return ms.actualDataset
+}
+
+func (ms *MySuite) ExpectedDataset() *Message {
+ return ms.expectedDataset
+}
+
+func (ms *MySuite) TestCases() *Cases {
+ return ms.cases
+}
+
func (ms *MySuite) DB() *sql.DB {
ms.dbSync.Do(func() {
if ms.port < 1 {
@@ -87,14 +166,22 @@ func (ms *MySuite) DB() *sql.DB {
}
var (
- dsn = fmt.Sprintf("arana:123456@tcp(127.0.0.1:%d)/employees?timeout=1s&readTimeout=1s&writeTimeout=1s&parseTime=true&loc=Local&charset=utf8mb4,utf8", ms.port)
+ dsn = fmt.Sprintf(
+ "arana:123456@tcp(127.0.0.1:%d)/employees?"+
+ "timeout=%s&"+
+ "readTimeout=%s&"+
+ "writeTimeout=%s&"+
+ "parseTime=true&"+
+ "loc=Local&"+
+ "charset=utf8mb4,utf8",
+ ms.port, timeout, readTimeout, writeTimeout)
err error
)
ms.T().Logf("====== connecting %s ======\n", dsn)
if ms.db, err = sql.Open("mysql", dsn); err != nil {
- ms.T().Log("connect failed:", err.Error())
+ ms.T().Log("connect arana failed:", err.Error())
}
})
@@ -106,11 +193,17 @@ func (ms *MySuite) DB() *sql.DB {
}
func (ms *MySuite) SetupSuite() {
+ if ms.devMode {
+ return
+ }
+
var (
mt = MySQLContainerTester{
Username: ms.username,
Password: ms.password,
Database: ms.database,
+
+ ScriptPath: ms.scriptPath,
}
err error
)
@@ -123,11 +216,17 @@ func (ms *MySuite) SetupSuite() {
ms.T().Logf("====== mysql is listening on %s ======\n", mysqlAddr)
ms.T().Logf("====== arana will listen on 127.0.0.1:%d ======\n", ms.port)
- cfgPath := testdata.Path("../conf/config.yaml")
+ if ms.configPath == "" {
+ ms.configPath = "../conf/config.yaml"
+ }
+ cfgPath := testdata.Path(ms.configPath)
err = ms.createConfigFile(cfgPath, ms.container.Host, ms.container.Port)
require.NoError(ms.T(), err)
+ if ms.bootstrapConfig == "" {
+ ms.bootstrapConfig = "../conf/bootstrap.yaml"
+ }
go func() {
_ = os.Setenv(constants.EnvConfigPath, ms.tmpFile)
start.Run(testdata.Path("../conf/bootstrap.yaml"))
@@ -138,6 +237,12 @@ func (ms *MySuite) SetupSuite() {
}
func (ms *MySuite) TearDownSuite() {
+ if ms.devMode {
+ if ms.db != nil {
+ _ = ms.db.Close()
+ }
+ return
+ }
if len(ms.tmpFile) > 0 {
ms.T().Logf("====== remove temp arana config file: %s ====== \n", ms.tmpFile)
_ = os.Remove(ms.tmpFile)
diff --git a/test/testcontainer_mysql.go b/test/testcontainer_mysql.go
index db8a6a931..9587ac2b8 100644
--- a/test/testcontainer_mysql.go
+++ b/test/testcontainer_mysql.go
@@ -19,6 +19,7 @@ package test
import (
"context"
+ "path"
)
import (
@@ -41,6 +42,8 @@ type MySQLContainerTester struct {
Username string `validate:"required" yaml:"username" json:"username"`
Password string `validate:"required" yaml:"password" json:"password"`
Database string `validate:"required" yaml:"database" json:"database"`
+
+ ScriptPath string
}
func (tester MySQLContainerTester) SetupMySQLContainer(ctx context.Context) (*MySQLContainer, error) {
@@ -53,9 +56,9 @@ func (tester MySQLContainerTester) SetupMySQLContainer(ctx context.Context) (*My
"MYSQL_DATABASE": tester.Database,
},
BindMounts: map[string]string{
- "/docker-entrypoint-initdb.d/0.sql": testdata.Path("../scripts/init.sql"),
- "/docker-entrypoint-initdb.d/1.sql": testdata.Path("../scripts/sharding.sql"),
- "/docker-entrypoint-initdb.d/2.sql": testdata.Path("../scripts/sequence.sql"),
+ "/docker-entrypoint-initdb.d/0.sql": testdata.Path(path.Join(tester.ScriptPath, "init.sql")),
+ "/docker-entrypoint-initdb.d/1.sql": testdata.Path(path.Join(tester.ScriptPath, "sharding.sql")),
+ "/docker-entrypoint-initdb.d/2.sql": testdata.Path(path.Join(tester.ScriptPath, "sequence.sql")),
},
WaitingFor: wait.ForLog("port: 3306 MySQL Community Server - GPL"),
}
@@ -64,7 +67,6 @@ func (tester MySQLContainerTester) SetupMySQLContainer(ctx context.Context) (*My
ContainerRequest: req,
Started: true,
})
-
if err != nil {
return nil, err
}
diff --git a/testdata/fake_bootstrap.yaml b/testdata/fake_bootstrap.yaml
index ec6ede06d..004e60120 100644
--- a/testdata/fake_bootstrap.yaml
+++ b/testdata/fake_bootstrap.yaml
@@ -71,9 +71,11 @@ config:
allow_full_scan: true
db_rules:
- column: student_id
+ type: modShard
expr: modShard(3)
tbl_rules:
- column: student_id
+ type: modShard
expr: modShard(8)
topology:
db_pattern: employee_0000
diff --git a/testdata/fake_config.yaml b/testdata/fake_config.yaml
index 9e25a95d9..e4ee4cd81 100644
--- a/testdata/fake_config.yaml
+++ b/testdata/fake_config.yaml
@@ -50,9 +50,11 @@ data:
allow_full_scan: true
db_rules:
- column: student_id
+ type: modShard
expr: modShard(3)
tbl_rules:
- column: student_id
+ type: modShard
expr: modShard(8)
topology:
db_pattern: employee_0000
diff --git a/testdata/mock_data.go b/testdata/mock_data.go
index b02838d96..dca624e73 100644
--- a/testdata/mock_data.go
+++ b/testdata/mock_data.go
@@ -1,34 +1,133 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
+// Code generated by MockGen. DO NOT EDIT.
+// Source: github.com/arana-db/arana/pkg/proto (interfaces: Field,Row,KeyedRow,Dataset,Result)
+// Package testdata is a generated GoMock package.
package testdata
import (
- "reflect"
+ io "io"
+ reflect "reflect"
)
import (
- "github.com/golang/mock/gomock"
+ gomock "github.com/golang/mock/gomock"
)
import (
- "github.com/arana-db/arana/pkg/proto"
+ proto "github.com/arana-db/arana/pkg/proto"
)
+// MockField is a mock of Field interface.
+type MockField struct {
+ ctrl *gomock.Controller
+ recorder *MockFieldMockRecorder
+}
+
+// MockFieldMockRecorder is the mock recorder for MockField.
+type MockFieldMockRecorder struct {
+ mock *MockField
+}
+
+// NewMockField creates a new mock instance.
+func NewMockField(ctrl *gomock.Controller) *MockField {
+ mock := &MockField{ctrl: ctrl}
+ mock.recorder = &MockFieldMockRecorder{mock}
+ return mock
+}
+
+// EXPECT returns an object that allows the caller to indicate expected use.
+func (m *MockField) EXPECT() *MockFieldMockRecorder {
+ return m.recorder
+}
+
+// DatabaseTypeName mocks base method.
+func (m *MockField) DatabaseTypeName() string {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "DatabaseTypeName")
+ ret0, _ := ret[0].(string)
+ return ret0
+}
+
+// DatabaseTypeName indicates an expected call of DatabaseTypeName.
+func (mr *MockFieldMockRecorder) DatabaseTypeName() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DatabaseTypeName", reflect.TypeOf((*MockField)(nil).DatabaseTypeName))
+}
+
+// DecimalSize mocks base method.
+func (m *MockField) DecimalSize() (int64, int64, bool) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "DecimalSize")
+ ret0, _ := ret[0].(int64)
+ ret1, _ := ret[1].(int64)
+ ret2, _ := ret[2].(bool)
+ return ret0, ret1, ret2
+}
+
+// DecimalSize indicates an expected call of DecimalSize.
+func (mr *MockFieldMockRecorder) DecimalSize() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecimalSize", reflect.TypeOf((*MockField)(nil).DecimalSize))
+}
+
+// Length mocks base method.
+func (m *MockField) Length() (int64, bool) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "Length")
+ ret0, _ := ret[0].(int64)
+ ret1, _ := ret[1].(bool)
+ return ret0, ret1
+}
+
+// Length indicates an expected call of Length.
+func (mr *MockFieldMockRecorder) Length() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Length", reflect.TypeOf((*MockField)(nil).Length))
+}
+
+// Name mocks base method.
+func (m *MockField) Name() string {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "Name")
+ ret0, _ := ret[0].(string)
+ return ret0
+}
+
+// Name indicates an expected call of Name.
+func (mr *MockFieldMockRecorder) Name() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Name", reflect.TypeOf((*MockField)(nil).Name))
+}
+
+// Nullable mocks base method.
+func (m *MockField) Nullable() (bool, bool) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "Nullable")
+ ret0, _ := ret[0].(bool)
+ ret1, _ := ret[1].(bool)
+ return ret0, ret1
+}
+
+// Nullable indicates an expected call of Nullable.
+func (mr *MockFieldMockRecorder) Nullable() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Nullable", reflect.TypeOf((*MockField)(nil).Nullable))
+}
+
+// ScanType mocks base method.
+func (m *MockField) ScanType() reflect.Type {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "ScanType")
+ ret0, _ := ret[0].(reflect.Type)
+ return ret0
+}
+
+// ScanType indicates an expected call of ScanType.
+func (mr *MockFieldMockRecorder) ScanType() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ScanType", reflect.TypeOf((*MockField)(nil).ScanType))
+}
+
// MockRow is a mock of Row interface.
type MockRow struct {
ctrl *gomock.Controller
@@ -52,65 +151,88 @@ func (m *MockRow) EXPECT() *MockRowMockRecorder {
return m.recorder
}
-// Columns mocks base method.
-func (m *MockRow) Columns() []string {
+// IsBinary mocks base method.
+func (m *MockRow) IsBinary() bool {
m.ctrl.T.Helper()
- ret := m.ctrl.Call(m, "Columns")
- ret0, _ := ret[0].([]string)
+ ret := m.ctrl.Call(m, "IsBinary")
+ ret0, _ := ret[0].(bool)
return ret0
}
-// Columns indicates an expected call of Columns.
-func (mr *MockRowMockRecorder) Columns() *gomock.Call {
+// IsBinary indicates an expected call of IsBinary.
+func (mr *MockRowMockRecorder) IsBinary() *gomock.Call {
mr.mock.ctrl.T.Helper()
- return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Columns", reflect.TypeOf((*MockRow)(nil).Columns))
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsBinary", reflect.TypeOf((*MockRow)(nil).IsBinary))
}
-// Data mocks base method.
-func (m *MockRow) Data() []byte {
+// Length mocks base method.
+func (m *MockRow) Length() int {
m.ctrl.T.Helper()
- ret := m.ctrl.Call(m, "Data")
- ret0, _ := ret[0].([]byte)
+ ret := m.ctrl.Call(m, "Length")
+ ret0, _ := ret[0].(int)
return ret0
}
-// Data indicates an expected call of Data.
-func (mr *MockRowMockRecorder) Data() *gomock.Call {
+// Length indicates an expected call of Length.
+func (mr *MockRowMockRecorder) Length() *gomock.Call {
mr.mock.ctrl.T.Helper()
- return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Data", reflect.TypeOf((*MockRow)(nil).Data))
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Length", reflect.TypeOf((*MockRow)(nil).Length))
}
-// Encode mocks base method.
-func (m *MockRow) Encode(values []*proto.Value, columns []proto.Field, columnNames []string) proto.Row {
+// Scan mocks base method.
+func (m *MockRow) Scan(arg0 []proto.Value) error {
m.ctrl.T.Helper()
- ret := m.ctrl.Call(m, "Encode")
- ret0, _ := ret[0].(proto.Row)
+ ret := m.ctrl.Call(m, "Scan", arg0)
+ ret0, _ := ret[0].(error)
return ret0
}
-// Encode mocks base method.
-func (mr *MockRowMockRecorder) Encode() *gomock.Call {
+// Scan indicates an expected call of Scan.
+func (mr *MockRowMockRecorder) Scan(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
- return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Encode", reflect.TypeOf((*MockRow)(nil).Encode))
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Scan", reflect.TypeOf((*MockRow)(nil).Scan), arg0)
}
-// Decode mocks base method.
-func (m *MockRow) Decode() ([]*proto.Value, error) {
+// WriteTo mocks base method.
+func (m *MockRow) WriteTo(arg0 io.Writer) (int64, error) {
m.ctrl.T.Helper()
- ret := m.ctrl.Call(m, "Decode")
- ret0, _ := ret[0].([]*proto.Value)
+ ret := m.ctrl.Call(m, "WriteTo", arg0)
+ ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
-// Decode indicates an expected call of Decode.
-func (mr *MockRowMockRecorder) Decode() *gomock.Call {
+// WriteTo indicates an expected call of WriteTo.
+func (mr *MockRowMockRecorder) WriteTo(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
- return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Decode", reflect.TypeOf((*MockRow)(nil).Decode))
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WriteTo", reflect.TypeOf((*MockRow)(nil).WriteTo), arg0)
+}
+
+// MockKeyedRow is a mock of KeyedRow interface.
+type MockKeyedRow struct {
+ ctrl *gomock.Controller
+ recorder *MockKeyedRowMockRecorder
+}
+
+// MockKeyedRowMockRecorder is the mock recorder for MockKeyedRow.
+type MockKeyedRowMockRecorder struct {
+ mock *MockKeyedRow
+}
+
+// NewMockKeyedRow creates a new mock instance.
+func NewMockKeyedRow(ctrl *gomock.Controller) *MockKeyedRow {
+ mock := &MockKeyedRow{ctrl: ctrl}
+ mock.recorder = &MockKeyedRowMockRecorder{mock}
+ return mock
+}
+
+// EXPECT returns an object that allows the caller to indicate expected use.
+func (m *MockKeyedRow) EXPECT() *MockKeyedRowMockRecorder {
+ return m.recorder
}
// Fields mocks base method.
-func (m *MockRow) Fields() []proto.Field {
+func (m *MockKeyedRow) Fields() []proto.Field {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Fields")
ret0, _ := ret[0].([]proto.Field)
@@ -118,22 +240,214 @@ func (m *MockRow) Fields() []proto.Field {
}
// Fields indicates an expected call of Fields.
-func (mr *MockRowMockRecorder) Fields() *gomock.Call {
+func (mr *MockKeyedRowMockRecorder) Fields() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Fields", reflect.TypeOf((*MockKeyedRow)(nil).Fields))
+}
+
+// Get mocks base method.
+func (m *MockKeyedRow) Get(arg0 string) (proto.Value, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "Get", arg0)
+ ret0, _ := ret[0].(proto.Value)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// Get indicates an expected call of Get.
+func (mr *MockKeyedRowMockRecorder) Get(arg0 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockKeyedRow)(nil).Get), arg0)
+}
+
+// IsBinary mocks base method.
+func (m *MockKeyedRow) IsBinary() bool {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "IsBinary")
+ ret0, _ := ret[0].(bool)
+ return ret0
+}
+
+// IsBinary indicates an expected call of IsBinary.
+func (mr *MockKeyedRowMockRecorder) IsBinary() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsBinary", reflect.TypeOf((*MockKeyedRow)(nil).IsBinary))
+}
+
+// Length mocks base method.
+func (m *MockKeyedRow) Length() int {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "Length")
+ ret0, _ := ret[0].(int)
+ return ret0
+}
+
+// Length indicates an expected call of Length.
+func (mr *MockKeyedRowMockRecorder) Length() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Length", reflect.TypeOf((*MockKeyedRow)(nil).Length))
+}
+
+// Scan mocks base method.
+func (m *MockKeyedRow) Scan(arg0 []proto.Value) error {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "Scan", arg0)
+ ret0, _ := ret[0].(error)
+ return ret0
+}
+
+// Scan indicates an expected call of Scan.
+func (mr *MockKeyedRowMockRecorder) Scan(arg0 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Scan", reflect.TypeOf((*MockKeyedRow)(nil).Scan), arg0)
+}
+
+// WriteTo mocks base method.
+func (m *MockKeyedRow) WriteTo(arg0 io.Writer) (int64, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "WriteTo", arg0)
+ ret0, _ := ret[0].(int64)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// WriteTo indicates an expected call of WriteTo.
+func (mr *MockKeyedRowMockRecorder) WriteTo(arg0 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WriteTo", reflect.TypeOf((*MockKeyedRow)(nil).WriteTo), arg0)
+}
+
+// MockDataset is a mock of Dataset interface.
+type MockDataset struct {
+ ctrl *gomock.Controller
+ recorder *MockDatasetMockRecorder
+}
+
+// MockDatasetMockRecorder is the mock recorder for MockDataset.
+type MockDatasetMockRecorder struct {
+ mock *MockDataset
+}
+
+// NewMockDataset creates a new mock instance.
+func NewMockDataset(ctrl *gomock.Controller) *MockDataset {
+ mock := &MockDataset{ctrl: ctrl}
+ mock.recorder = &MockDatasetMockRecorder{mock}
+ return mock
+}
+
+// EXPECT returns an object that allows the caller to indicate expected use.
+func (m *MockDataset) EXPECT() *MockDatasetMockRecorder {
+ return m.recorder
+}
+
+// Close mocks base method.
+func (m *MockDataset) Close() error {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "Close")
+ ret0, _ := ret[0].(error)
+ return ret0
+}
+
+// Close indicates an expected call of Close.
+func (mr *MockDatasetMockRecorder) Close() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockDataset)(nil).Close))
+}
+
+// Fields mocks base method.
+func (m *MockDataset) Fields() ([]proto.Field, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "Fields")
+ ret0, _ := ret[0].([]proto.Field)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// Fields indicates an expected call of Fields.
+func (mr *MockDatasetMockRecorder) Fields() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Fields", reflect.TypeOf((*MockDataset)(nil).Fields))
+}
+
+// Next mocks base method.
+func (m *MockDataset) Next() (proto.Row, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "Next")
+ ret0, _ := ret[0].(proto.Row)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// Next indicates an expected call of Next.
+func (mr *MockDatasetMockRecorder) Next() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Next", reflect.TypeOf((*MockDataset)(nil).Next))
+}
+
+// MockResult is a mock of Result interface.
+type MockResult struct {
+ ctrl *gomock.Controller
+ recorder *MockResultMockRecorder
+}
+
+// MockResultMockRecorder is the mock recorder for MockResult.
+type MockResultMockRecorder struct {
+ mock *MockResult
+}
+
+// NewMockResult creates a new mock instance.
+func NewMockResult(ctrl *gomock.Controller) *MockResult {
+ mock := &MockResult{ctrl: ctrl}
+ mock.recorder = &MockResultMockRecorder{mock}
+ return mock
+}
+
+// EXPECT returns an object that allows the caller to indicate expected use.
+func (m *MockResult) EXPECT() *MockResultMockRecorder {
+ return m.recorder
+}
+
+// Dataset mocks base method.
+func (m *MockResult) Dataset() (proto.Dataset, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "Dataset")
+ ret0, _ := ret[0].(proto.Dataset)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// Dataset indicates an expected call of Dataset.
+func (mr *MockResultMockRecorder) Dataset() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Dataset", reflect.TypeOf((*MockResult)(nil).Dataset))
+}
+
+// LastInsertId mocks base method.
+func (m *MockResult) LastInsertId() (uint64, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "LastInsertId")
+ ret0, _ := ret[0].(uint64)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// LastInsertId indicates an expected call of LastInsertId.
+func (mr *MockResultMockRecorder) LastInsertId() *gomock.Call {
mr.mock.ctrl.T.Helper()
- return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Fields", reflect.TypeOf((*MockRow)(nil).Fields))
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LastInsertId", reflect.TypeOf((*MockResult)(nil).LastInsertId))
}
-// GetColumnValue mocks base method.
-func (m *MockRow) GetColumnValue(column string) (interface{}, error) {
+// RowsAffected mocks base method.
+func (m *MockResult) RowsAffected() (uint64, error) {
m.ctrl.T.Helper()
- ret := m.ctrl.Call(m, "GetColumnValue", column)
- ret0, _ := ret[0].(interface{})
+ ret := m.ctrl.Call(m, "RowsAffected")
+ ret0, _ := ret[0].(uint64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
-// GetColumnValue indicates an expected call of GetColumnValue.
-func (mr *MockRowMockRecorder) GetColumnValue(column interface{}) *gomock.Call {
+// RowsAffected indicates an expected call of RowsAffected.
+func (mr *MockResultMockRecorder) RowsAffected() *gomock.Call {
mr.mock.ctrl.T.Helper()
- return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetColumnValue", reflect.TypeOf((*MockRow)(nil).GetColumnValue), column)
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RowsAffected", reflect.TypeOf((*MockResult)(nil).RowsAffected))
}
diff --git a/testdata/mock_runtime.go b/testdata/mock_runtime.go
index 4c8a84224..649792773 100644
--- a/testdata/mock_runtime.go
+++ b/testdata/mock_runtime.go
@@ -1,22 +1,5 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
// Code generated by MockGen. DO NOT EDIT.
-// Source: github.com/arana-db/arana/pkg/proto (interfaces: VConn,Plan,Optimizer,DB,SchemaLoader)
+// Source: github.com/arana-db/arana/pkg/proto (interfaces: VConn,Plan,Optimizer,DB)
// Package testdata is a generated GoMock package.
package testdata
@@ -28,8 +11,6 @@ import (
)
import (
- ast "github.com/arana-db/parser/ast"
-
gomock "github.com/golang/mock/gomock"
)
@@ -176,23 +157,18 @@ func (m *MockOptimizer) EXPECT() *MockOptimizerMockRecorder {
}
// Optimize mocks base method.
-func (m *MockOptimizer) Optimize(arg0 context.Context, arg1 proto.VConn, arg2 ast.StmtNode, arg3 ...interface{}) (proto.Plan, error) {
+func (m *MockOptimizer) Optimize(arg0 context.Context) (proto.Plan, error) {
m.ctrl.T.Helper()
- varargs := []interface{}{arg0, arg1, arg2}
- for _, a := range arg3 {
- varargs = append(varargs, a)
- }
- ret := m.ctrl.Call(m, "Optimize", varargs...)
+ ret := m.ctrl.Call(m, "Optimize", arg0)
ret0, _ := ret[0].(proto.Plan)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Optimize indicates an expected call of Optimize.
-func (mr *MockOptimizerMockRecorder) Optimize(arg0, arg1, arg2 interface{}, arg3 ...interface{}) *gomock.Call {
+func (mr *MockOptimizerMockRecorder) Optimize(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
- varargs := append([]interface{}{arg0, arg1, arg2}, arg3...)
- return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Optimize", reflect.TypeOf((*MockOptimizer)(nil).Optimize), varargs...)
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Optimize", reflect.TypeOf((*MockOptimizer)(nil).Optimize), arg0)
}
// MockDB is a mock of DB interface.
@@ -393,40 +369,3 @@ func (mr *MockDBMockRecorder) Weight() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Weight", reflect.TypeOf((*MockDB)(nil).Weight))
}
-
-// MockSchemaLoader is a mock of SchemaLoader interface.
-type MockSchemaLoader struct {
- ctrl *gomock.Controller
- recorder *MockSchemaLoaderMockRecorder
-}
-
-// MockSchemaLoaderMockRecorder is the mock recorder for MockSchemaLoader.
-type MockSchemaLoaderMockRecorder struct {
- mock *MockSchemaLoader
-}
-
-// NewMockSchemaLoader creates a new mock instance.
-func NewMockSchemaLoader(ctrl *gomock.Controller) *MockSchemaLoader {
- mock := &MockSchemaLoader{ctrl: ctrl}
- mock.recorder = &MockSchemaLoaderMockRecorder{mock}
- return mock
-}
-
-// EXPECT returns an object that allows the caller to indicate expected use.
-func (m *MockSchemaLoader) EXPECT() *MockSchemaLoaderMockRecorder {
- return m.recorder
-}
-
-// Load mocks base method.
-func (m *MockSchemaLoader) Load(arg0 context.Context, arg1 proto.VConn, arg2 string, arg3 []string) map[string]*proto.TableMetadata {
- m.ctrl.T.Helper()
- ret := m.ctrl.Call(m, "Load", arg0, arg1, arg2, arg3)
- ret0, _ := ret[0].(map[string]*proto.TableMetadata)
- return ret0
-}
-
-// Load indicates an expected call of Load.
-func (mr *MockSchemaLoaderMockRecorder) Load(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
- mr.mock.ctrl.T.Helper()
- return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Load", reflect.TypeOf((*MockSchemaLoader)(nil).Load), arg0, arg1, arg2, arg3)
-}
diff --git a/testdata/mock_schema.go b/testdata/mock_schema.go
new file mode 100644
index 000000000..b2aab27a6
--- /dev/null
+++ b/testdata/mock_schema.go
@@ -0,0 +1,56 @@
+// Code generated by MockGen. DO NOT EDIT.
+// Source: github.com/arana-db/arana/pkg/proto (interfaces: SchemaLoader)
+
+// Package testdata is a generated GoMock package.
+package testdata
+
+import (
+ context "context"
+ reflect "reflect"
+)
+
+import (
+ gomock "github.com/golang/mock/gomock"
+)
+
+import (
+ proto "github.com/arana-db/arana/pkg/proto"
+)
+
+// MockSchemaLoader is a mock of SchemaLoader interface.
+type MockSchemaLoader struct {
+ ctrl *gomock.Controller
+ recorder *MockSchemaLoaderMockRecorder
+}
+
+// MockSchemaLoaderMockRecorder is the mock recorder for MockSchemaLoader.
+type MockSchemaLoaderMockRecorder struct {
+ mock *MockSchemaLoader
+}
+
+// NewMockSchemaLoader creates a new mock instance.
+func NewMockSchemaLoader(ctrl *gomock.Controller) *MockSchemaLoader {
+ mock := &MockSchemaLoader{ctrl: ctrl}
+ mock.recorder = &MockSchemaLoaderMockRecorder{mock}
+ return mock
+}
+
+// EXPECT returns an object that allows the caller to indicate expected use.
+func (m *MockSchemaLoader) EXPECT() *MockSchemaLoaderMockRecorder {
+ return m.recorder
+}
+
+// Load mocks base method.
+func (m *MockSchemaLoader) Load(arg0 context.Context, arg1 string, arg2 []string) (map[string]*proto.TableMetadata, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "Load", arg0, arg1, arg2)
+ ret0, _ := ret[0].(map[string]*proto.TableMetadata)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// Load indicates an expected call of Load.
+func (mr *MockSchemaLoaderMockRecorder) Load(arg0, arg1, arg2 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Load", reflect.TypeOf((*MockSchemaLoader)(nil).Load), arg0, arg1, arg2)
+}
diff --git a/third_party/pools/resource_pool.go b/third_party/pools/resource_pool.go
index 50ec6e188..b531b8554 100644
--- a/third_party/pools/resource_pool.go
+++ b/third_party/pools/resource_pool.go
@@ -48,6 +48,7 @@ import (
)
import (
+ "github.com/arana-db/arana/pkg/util/log"
"github.com/arana-db/arana/third_party/sync2"
"github.com/arana-db/arana/third_party/timer"
)
@@ -154,6 +155,7 @@ func NewResourcePool(factory Factory, capacity, maxCap int, idleTimeout time.Dur
r, err := rp.Get(ctx)
if err != nil {
+ log.Errorf("get connection resource error: %v", err)
return
}
rp.Put(r)