diff --git a/.github/workflows/github-actions.yml b/.github/workflows/github-actions.yml index 506fe2484..3062edddb 100644 --- a/.github/workflows/github-actions.yml +++ b/.github/workflows/github-actions.yml @@ -38,7 +38,7 @@ jobs: # If you want to matrix build , you can append the following list. matrix: go_version: - - 1.16 + - 1.18 os: - ubuntu-latest diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index b71067d33..091ac3268 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -20,7 +20,7 @@ name: Release on: release: - types: [created] + types: [ created ] jobs: releases-matrix: @@ -29,9 +29,9 @@ jobs: strategy: matrix: # build and publish in parallel: linux/386, linux/amd64, windows/386, windows/amd64, darwin/amd64 - goos: [linux, windows] - goarch: ["386", amd64, arm] - exclude: + goos: [ linux, windows ] + goarch: [ "386", amd64, arm ] + exclude: - goarch: "arm" goos: windows @@ -42,6 +42,6 @@ jobs: github_token: ${{ secrets.GITHUB_TOKEN }} goos: ${{ matrix.goos }} goarch: ${{ matrix.goarch }} - goversion: "https://golang.org/dl/go1.16.15.linux-amd64.tar.gz" + goversion: "https://go.dev/dl/go1.18.3.linux-amd64.tar.gz" project_path: "./cmd/arana" binary_name: "arana" diff --git a/.github/workflows/reviewdog.yml b/.github/workflows/reviewdog.yml index 39ac5fbff..6a7425b32 100644 --- a/.github/workflows/reviewdog.yml +++ b/.github/workflows/reviewdog.yml @@ -30,4 +30,5 @@ jobs: - name: golangci-lint uses: reviewdog/action-golangci-lint@v2 with: - go_version: "1.16" + go_version: "1.18" + golangci_lint_version: "v1.46.2" # use latest version by default diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2845a2862..503044c5a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -19,6 +19,6 @@ # See https://pre-commit.com/hooks.html for more hooks repos: - repo: http://github.com/golangci/golangci-lint - rev: v1.42.1 + rev: v1.46.2 hooks: - id: golangci-lint diff --git a/Dockerfile b/Dockerfile index 2b20788aa..f4dc11141 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,22 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + # builder layer -FROM golang:1.16-alpine AS builder +FROM golang:1.18-alpine AS builder RUN apk add --no-cache upx diff --git a/Makefile b/Makefile index 2cd0ab391..c67702973 100644 --- a/Makefile +++ b/Makefile @@ -38,6 +38,9 @@ docker-build: integration-test: @go clean -testcache go test -tags integration -v ./test/... + go test -tags integration-db_tbl -v ./integration_test/scene/db_tbl/... + go test -tags integration-db -v ./integration_test/scene/db/... + go test -tags integration-tbl -v ./integration_test/scene/tbl/... clean: @rm -rf coverage.txt @@ -50,4 +53,4 @@ prepareLic: .PHONY: license license: prepareLic - $(GO_LICENSE_CHECKER) -v -a -r -i vendor $(LICENSE_DIR)/license.txt . go && [[ -z `git status -s` ]] \ No newline at end of file + $(GO_LICENSE_CHECKER) -v -a -r -i vendor $(LICENSE_DIR)/license.txt . go && [[ -z `git status -s` ]] diff --git a/README.md b/README.md index a91e6d5cd..e8898d9ca 100644 --- a/README.md +++ b/README.md @@ -1,48 +1,87 @@ -# arana -[![LICENSE](https://img.shields.io/badge/license-Apache--2.0-blue.svg)](https://github.com/arana-db/arana/blob/master/LICENSE) +# Arana + +
+ +
+ +`Arana` is a Cloud Native Database Proxy. It can be deployed as a Database mesh sidecar. It provides transparent data access capabilities, +when using `arana`, user doesn't need to care about the `sharding` details of database, they can use it just like a single `MySQL` database. + +## Overview + +[![License](https://img.shields.io/badge/license-Apache--2.0-blue.svg)](https://github.com/arana-db/arana/blob/master/LICENSE) [![codecov](https://codecov.io/gh/arana-db/arana/branch/master/graph/badge.svg)](https://codecov.io/gh/arana-db/arana) +[![Go Report Card](https://goreportcard.com/badge/github.com/arana-db/arana)](https://goreportcard.com/report/github.com/arana-db/arana) +[![Release](https://img.shields.io/github/v/release/arana-db/arana)](https://img.shields.io/github/v/release/arana-db/arana) +[![Docker Pulls](https://img.shields.io/docker/pulls/aranadb/arana)](https://img.shields.io/docker/pulls/aranadb/arana) +| **Stargazers Over Time** | **Contributors Over Time** | +|:-----------------------------------------------------------------------------------------------------------------:|:------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:| +| [![Stargazers over time](https://starchart.cc/arana-db/arana.svg)](https://starchart.cc/arana-db/arana) | [![Contributor over time](https://contributor-graph-api.apiseven.com/contributors-svg?chart=contributorOverTime&repo=arana-db/arana)](https://contributor-graph-api.apiseven.com/contributors-svg?chart=contributorOverTime&repo=arana-db/arana) | -![](./docs/pics/arana-logo.png) ## Introduction | [中文](https://github.com/arana-db/arana/blob/master/README_CN.md) -Arana is a db proxy. It can be deployed as a sidecar. +First, `Arana` is a Cloud Native Database Proxy. It provides transparent data access capabilities, when using `arana`, +user doesn't need to care about the `sharding` details of database, they can use it just like a single `MySQL` database. +`Arana` also provide abilities of `Multi Tenant`, `Distributed transaction`, `Shadow database`, `SQL Audit`, `Data encrypt / decrypt` +and so on. Through simple config, user can use these abilities provided by `arana` directly. + +Second, `Arana` can also be deployed as a Database mesh sidecar. As a Database mesh sidecar, arana switches data access from +client mode to proxy mode, which greatly optimizes the startup speed of applications. It provides the ability to manage database +traffic, it takes up very little container resources, doesn't affect the performance of application services in the container, but +provides all the capabilities of proxy. ## Architecture + + ## Features -| feature | complete | -| -- | -- | -| single db proxy | √ | -| read write splitting | × | -| tracing | × | -| metrics | × | -| sql audit | × | -| sharding | × | -| multi tenant | × | +| **Feature** | **Complete** | +|:-----------------------:|:------------:| +| Single DB Proxy | √ | +| Read Write Splitting | √ | +| Sharding | √ | +| Multi Tenant | √ | +| Distributed Primary Key | WIP | +| Distributed Transaction | WIP | +| Shadow Table | WIP | +| Database Mesh | WIP | +| Tracing / Metrics | WIP | +| SQL Audit | Roadmap | +| Data encrypt / decrypt | Roadmap | ## Getting started +Please reference this link [Getting Started](https://github.com/arana-db/arana/discussions/172) + ``` arana start -c ${configFilePath} ``` ### Prerequisites -+ MySQL server 5.7+ ++ Go 1.18+ ++ MySQL Server 5.7+ ## Design and implementation ## Roadmap ## Built With -- [tidb](https://github.com/pingcap/tidb) - The sql parser used + +- [TiDB](https://github.com/pingcap/tidb) - The SQL parser used ## Contact +Arana Chinese Community Meeting Time: **Every Saturday At 9:00PM GMT+8** + + + ## Contributing +Thanks for your help improving the project! We are so happy to have you! We have a contributing guide to help you get involved in the Arana project. + ## License Arana software is licenced under the Apache License Version 2.0. See the [LICENSE](https://github.com/arana-db/arana/blob/master/LICENSE) file for details. diff --git a/README_CN.md b/README_CN.md index ca05c93ba..a957e4604 100644 --- a/README_CN.md +++ b/README_CN.md @@ -2,7 +2,7 @@ [![LICENSE](https://img.shields.io/badge/license-Apache--2.0-blue.svg)](https://github.com/arana-db/arana/blob/master/LICENSE) [![codecov](https://codecov.io/gh/arana-db/arana/branch/master/graph/badge.svg)](https://codecov.io/gh/arana-db/arana) -![](./docs/pics/arana-logo.png) +![](./docs/pics/arana-main.png) ## 简介 | [English](https://github.com/arana-db/arana/blob/master/README.md) diff --git a/README_ZH.md b/README_ZH.md index 96d9c4684..ca91e1874 100644 --- a/README_ZH.md +++ b/README_ZH.md @@ -2,7 +2,7 @@ [![LICENSE](https://img.shields.io/badge/license-Apache--2.0-blue.svg)](https://github.com/arana-db/arana/blob/master/LICENSE) [![codecov](https://codecov.io/gh/arana-db/arana/branch/master/graph/badge.svg)](https://codecov.io/gh/arana-db/arana) -![](./docs/pics/arana-logo.png) +![](./docs/pics/arana-main.png) ## 简介 diff --git a/cmd/main.go b/cmd/main.go index e0da5f65c..ad35125d5 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -27,9 +27,7 @@ import ( _ "github.com/arana-db/arana/cmd/tools" ) -var ( - Version = "0.1.0" -) +var Version = "0.1.0" func main() { rootCommand := &cobra.Command{ diff --git a/cmd/start/start.go b/cmd/start/start.go index f6ee9de5f..28d232147 100644 --- a/cmd/start/start.go +++ b/cmd/start/start.go @@ -49,7 +49,7 @@ const slogan = ` / _ | / _ \/ _ | / |/ / _ | / __ |/ , _/ __ |/ / __ | /_/ |_/_/|_/_/ |_/_/|_/_/ |_| -High performance, powerful DB Mesh. +Arana, A High performance & Powerful DB Mesh sidecar. _____________________________________________ ` diff --git a/conf/config.yaml b/conf/config.yaml index 21721426d..5caeab8b1 100644 --- a/conf/config.yaml +++ b/conf/config.yaml @@ -40,40 +40,76 @@ data: type: mysql sql_max_limit: -1 tenant: arana - conn_props: - capacity: 10 - max_capacity: 20 - idle_timeout: 60 groups: - name: employees_0000 nodes: - - name: arana-node-1 + - name: node0 host: arana-mysql port: 3306 username: root password: "123456" - database: employees + database: employees_0000 weight: r10w10 - labels: - zone: shanghai - conn_props: - readTimeout: "1s" - writeTimeout: "1s" - parseTime: true - loc: Local - charset: utf8mb4,utf8 - + parameters: + maxAllowedPacket: 256M + - name: node0_r_0 + host: arana-mysql + port: 3306 + username: root + password: "123456" + database: employees_0000_r + weight: r0w0 + parameters: + maxAllowedPacket: 256M + - name: employees_0001 + nodes: + - name: node1 + host: arana-mysql + port: 3306 + username: root + password: "123456" + database: employees_0001 + weight: r10w10 + parameters: + maxAllowedPacket: 256M + - name: employees_0002 + nodes: + - name: node2 + host: arana-mysql + port: 3306 + username: root + password: "123456" + database: employees_0002 + weight: r10w10 + parameters: + maxAllowedPacket: 256M + - name: employees_0003 + nodes: + - name: node3 + host: arana-mysql + port: 3306 + username: root + password: "123456" + database: employees_0003 + weight: r10w10 + parameters: + maxAllowedPacket: 256M sharding_rule: tables: - name: employees.student allow_full_scan: true db_rules: + - column: uid + type: scriptExpr + expr: parseInt($value % 32 / 8) tbl_rules: - column: uid + type: scriptExpr expr: $value % 32 + step: 32 topology: - db_pattern: employees_0000 - tbl_pattern: student_${0000...0031} + db_pattern: employees_${0000..0003} + tbl_pattern: student_${0000..0031} attributes: sqlMaxLimit: -1 diff --git a/docker-compose.yaml b/docker-compose.yaml index c66cf692d..7a1548fc1 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -18,7 +18,7 @@ version: "3" services: mysql: - image: mysql:8.0 + image: mysql:5.7 container_name: arana-mysql networks: - local @@ -41,7 +41,7 @@ services: arana: build: . container_name: arana - image: aranadb/arana:latest + image: aranadb/arana:master networks: - local ports: diff --git a/docs/pics/arana-architecture.png b/docs/pics/arana-architecture.png new file mode 100644 index 000000000..775fa8dcb Binary files /dev/null and b/docs/pics/arana-architecture.png differ diff --git a/docs/pics/arana-blue.png b/docs/pics/arana-blue.png new file mode 100644 index 000000000..90d31b8c4 Binary files /dev/null and b/docs/pics/arana-blue.png differ diff --git a/docs/pics/arana-db-blue.png b/docs/pics/arana-db-blue.png new file mode 100644 index 000000000..93f3d4014 Binary files /dev/null and b/docs/pics/arana-db-blue.png differ diff --git a/docs/pics/arana-db-v0.2.sketch b/docs/pics/arana-db-v0.2.sketch new file mode 100644 index 000000000..a97994827 Binary files /dev/null and b/docs/pics/arana-db-v0.2.sketch differ diff --git a/docs/pics/arana-main.png b/docs/pics/arana-main.png new file mode 100644 index 000000000..559c90dd4 Binary files /dev/null and b/docs/pics/arana-main.png differ diff --git a/docs/pics/dingtalk-group.jpeg b/docs/pics/dingtalk-group.jpeg new file mode 100644 index 000000000..638128143 Binary files /dev/null and b/docs/pics/dingtalk-group.jpeg differ diff --git a/go.mod b/go.mod index 7d19bd53a..23d592959 100644 --- a/go.mod +++ b/go.mod @@ -1,16 +1,15 @@ module github.com/arana-db/arana -go 1.16 +go 1.18 require ( - github.com/arana-db/parser v0.2.1 + github.com/arana-db/parser v0.2.3 github.com/bwmarrin/snowflake v0.3.0 github.com/cespare/xxhash/v2 v2.1.2 github.com/dop251/goja v0.0.0-20220422102209-3faab1d8f20e github.com/dubbogo/gost v1.12.3 github.com/go-playground/validator/v10 v10.10.1 github.com/go-sql-driver/mysql v1.6.0 - github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/mock v1.5.0 github.com/hashicorp/golang-lru v0.5.4 github.com/lestrrat-go/strftime v1.0.5 @@ -19,21 +18,105 @@ require ( github.com/olekukonko/tablewriter v0.0.5 github.com/pkg/errors v0.9.1 github.com/prometheus/client_golang v1.11.0 - github.com/prometheus/common v0.28.0 // indirect github.com/spf13/cobra v1.2.1 - github.com/stretchr/testify v1.7.0 + github.com/stretchr/testify v1.7.1 github.com/testcontainers/testcontainers-go v0.12.0 github.com/tidwall/gjson v1.14.0 go.etcd.io/etcd/api/v3 v3.5.1 go.etcd.io/etcd/client/v3 v3.5.0 go.etcd.io/etcd/server/v3 v3.5.0-alpha.0 + go.opentelemetry.io/otel v1.7.0 + go.opentelemetry.io/otel/trace v1.7.0 go.uber.org/atomic v1.9.0 - go.uber.org/multierr v1.7.0 go.uber.org/zap v1.19.1 - golang.org/x/crypto v0.0.0-20220427172511-eb4f295cb31f // indirect + golang.org/x/exp v0.0.0-20220613132600-b0d781184e0d golang.org/x/sync v0.0.0-20210220032951-036812b2e83c - golang.org/x/sys v0.0.0-20220429233432-b5fbb4746d32 // indirect + gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b +) + +require ( + github.com/Azure/go-ansiterm v0.0.0-20170929234023-d6e3b3328b78 // indirect + github.com/Microsoft/go-winio v0.4.17-0.20210211115548-6eac466e5fa3 // indirect + github.com/Microsoft/hcsshim v0.8.16 // indirect + github.com/aliyun/alibaba-cloud-sdk-go v1.61.18 // indirect + github.com/beorn7/perks v1.0.1 // indirect + github.com/buger/jsonparser v1.1.1 // indirect + github.com/cenkalti/backoff v2.2.1+incompatible // indirect + github.com/containerd/cgroups v0.0.0-20210114181951-8a68de567b68 // indirect + github.com/containerd/containerd v1.5.0-beta.4 // indirect + github.com/coreos/go-semver v0.3.0 // indirect + github.com/coreos/go-systemd/v22 v22.3.2 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/dlclark/regexp2 v1.4.1-0.20201116162257-a2a8dda75c91 // indirect + github.com/docker/distribution v2.7.1+incompatible // indirect + github.com/docker/docker v20.10.11+incompatible // indirect + github.com/docker/go-connections v0.4.0 // indirect + github.com/docker/go-units v0.4.0 // indirect + github.com/dustin/go-humanize v1.0.0 // indirect + github.com/form3tech-oss/jwt-go v3.2.2+incompatible // indirect + github.com/go-errors/errors v1.0.1 // indirect + github.com/go-logr/logr v1.2.3 // indirect + github.com/go-logr/stdr v1.2.2 // indirect + github.com/go-playground/locales v0.14.0 // indirect + github.com/go-playground/universal-translator v0.18.0 // indirect + github.com/go-sourcemap/sourcemap v2.1.3+incompatible // indirect + github.com/gogo/protobuf v1.3.2 // indirect + github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect + github.com/golang/protobuf v1.5.2 // indirect + github.com/google/btree v1.0.0 // indirect + github.com/google/uuid v1.3.0 // indirect + github.com/gopherjs/gopherjs v0.0.0-20190910122728-9d188e94fb99 // indirect + github.com/gorilla/websocket v1.4.2 // indirect + github.com/grpc-ecosystem/go-grpc-middleware v1.2.2 // indirect + github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 // indirect + github.com/grpc-ecosystem/grpc-gateway v1.16.0 // indirect + github.com/inconshreveable/mousetrap v1.0.0 // indirect + github.com/jmespath/go-jmespath v0.4.0 // indirect + github.com/jonboulle/clockwork v0.2.2 // indirect + github.com/json-iterator/go v1.1.11 // indirect + github.com/leodido/go-urn v1.2.1 // indirect + github.com/magiconair/properties v1.8.5 // indirect + github.com/mattn/go-runewidth v0.0.9 // indirect + github.com/matttproud/golang_protobuf_extensions v1.0.2-0.20181231171920-c182affec369 // indirect + github.com/moby/sys/mount v0.2.0 // indirect + github.com/moby/sys/mountinfo v0.5.0 // indirect + github.com/moby/term v0.0.0-20201216013528-df9cb8a40635 // indirect + github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect + github.com/modern-go/reflect2 v1.0.1 // indirect + github.com/morikuni/aec v0.0.0-20170113033406-39771216ff4c // indirect + github.com/opencontainers/go-digest v1.0.0 // indirect + github.com/opencontainers/image-spec v1.0.1 // indirect + github.com/opencontainers/runc v1.0.2 // indirect + github.com/pingcap/errors v0.11.5-0.20210425183316-da1aaba5fb63 // indirect + github.com/pingcap/log v0.0.0-20210625125904-98ed8e2eb1c7 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/prometheus/client_model v0.2.0 // indirect + github.com/prometheus/common v0.28.0 // indirect + github.com/prometheus/procfs v0.6.0 // indirect + github.com/sirupsen/logrus v1.8.1 // indirect + github.com/soheilhy/cmux v0.1.5-0.20210205191134-5ec6847320e5 // indirect + github.com/spf13/pflag v1.0.5 // indirect + github.com/tidwall/match v1.1.1 // indirect + github.com/tidwall/pretty v1.2.0 // indirect + github.com/tmc/grpc-websocket-proxy v0.0.0-20201229170055-e5319fda7802 // indirect + github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2 // indirect + go.etcd.io/bbolt v1.3.5 // indirect + go.etcd.io/etcd/client/pkg/v3 v3.5.0 // indirect + go.etcd.io/etcd/client/v2 v2.305.0 // indirect + go.etcd.io/etcd/pkg/v3 v3.5.0-alpha.0 // indirect + go.etcd.io/etcd/raft/v3 v3.5.0-alpha.0 // indirect + go.opencensus.io v0.23.0 // indirect + go.uber.org/multierr v1.6.0 // indirect + golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3 // indirect + golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2 // indirect + golang.org/x/sys v0.0.0-20220128215802-99c3d69c2c27 // indirect + golang.org/x/text v0.3.7 // indirect + golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 // indirect google.golang.org/genproto v0.0.0-20211104193956-4c6863e31247 // indirect google.golang.org/grpc v1.42.0 // indirect - gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b + google.golang.org/protobuf v1.27.1 // indirect + gopkg.in/ini.v1 v1.62.0 // indirect + gopkg.in/natefinch/lumberjack.v2 v2.0.0 // indirect + gopkg.in/yaml.v2 v2.4.0 // indirect + sigs.k8s.io/yaml v1.2.0 // indirect ) diff --git a/go.sum b/go.sum index a2d502ace..c3cc775e1 100644 --- a/go.sum +++ b/go.sum @@ -94,8 +94,8 @@ github.com/aliyun/alibaba-cloud-sdk-go v1.61.18/go.mod h1:v8ESoHo4SyHmuB4b1tJqDH github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= github.com/apache/thrift v0.12.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ= github.com/apache/thrift v0.13.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ= -github.com/arana-db/parser v0.2.1 h1:a885V+OIABmqYHYbLP2QWZbn+/TE0mZJd8dafWY7F6Y= -github.com/arana-db/parser v0.2.1/go.mod h1:y4hxIPieC5T26aoNd44XiWXNunC03kUQW0CI3NKaYTk= +github.com/arana-db/parser v0.2.3 h1:zLZcx0/oidlHnw/GZYE78NuvwQkHUv2Xtrm2IwyZasA= +github.com/arana-db/parser v0.2.3/go.mod h1:/XA29bplweWSEAjgoM557ZCzhBilSawUlHcZFjOeDAc= github.com/armon/circbuf v0.0.0-20150827004946-bbbad097214e/go.mod h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o= github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmVTwzkszR9V5SSuryQ31EELlFMUz1kKyl939pY= github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= @@ -136,7 +136,6 @@ github.com/cenkalti/backoff v2.2.1+incompatible/go.mod h1:90ReRw6GdpyfrHakVjL/QH github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/certifi/gocertifi v0.0.0-20191021191039-0944d244cd40 h1:xvUo53O5MRZhVMJAxWCJcS5HHrqAiAG9SJ1LpMu6aAI= github.com/certifi/gocertifi v0.0.0-20191021191039-0944d244cd40/go.mod h1:sGbDF6GwGcLpkNXPUTkMRoywsNa/ol15pxFe6ERfguA= -github.com/cespare/xxhash v1.1.0 h1:a6HrQnmkObjyL+Gs60czilIUGqrzKutQD6XZog3p+ko= github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE= @@ -244,7 +243,6 @@ github.com/coreos/go-semver v0.3.0 h1:wkHLiw0WNATZnSG7epLsujiMCgPAc9xhjJ4tgnAxmf github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-systemd v0.0.0-20161114122254-48702e0da86b/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/coreos/go-systemd v0.0.0-20180511133405-39ca1b05acc7/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= -github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e h1:Wf6HqHfScWJN9/ZjdUKyjop4mf3Qdd+1TvvltAvM3m8= github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/coreos/go-systemd/v22 v22.0.0/go.mod h1:xO0FLkIi5MaZafQlIrOotqXZ90ih+1atmu1JpKERPPk= github.com/coreos/go-systemd/v22 v22.1.0/go.mod h1:xO0FLkIi5MaZafQlIrOotqXZ90ih+1atmu1JpKERPPk= @@ -352,6 +350,11 @@ github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= github.com/go-logr/logr v0.1.0/go.mod h1:ixOQHD9gLJUVQQ2ZOR7zLEifBX6tGkNJF4QyIY7sIas= github.com/go-logr/logr v0.2.0/go.mod h1:z6/tIYblkpsD+a4lm/fGIIU9mZ+XfAiaFtq7xTgseGU= +github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/logr v1.2.3 h1:2DntVwHkVopvECVRSlL5PSo9eG+cAkDCuckLubN+rq0= +github.com/go-logr/logr v1.2.3/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= github.com/go-openapi/jsonpointer v0.19.2/go.mod h1:3akKfEdA7DF1sugOqz1dVQHBcuDBPKZGEoHC/NkiQRg= github.com/go-openapi/jsonpointer v0.19.3/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg= @@ -444,8 +447,8 @@ github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.7 h1:81/ik6ipDQS2aGcBfIN5dHDB36BwrStyeAQquSYCV4o= github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE= +github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.1.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= @@ -471,8 +474,9 @@ github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+ github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= github.com/googleapis/gnostic v0.4.1/go.mod h1:LRhVm6pbyptWbWbuZ38d1eyptfvIytN3ir6b65WBswg= -github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 h1:EGx4pi6eqNxGaHF6qqu48+N2wcFQ5qg5FXgOdqsJ5d8= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= +github.com/gopherjs/gopherjs v0.0.0-20190910122728-9d188e94fb99 h1:twflg0XRTjwKpxb/jFExr4HGq6on2dEOmnL6FV+fgPw= +github.com/gopherjs/gopherjs v0.0.0-20190910122728-9d188e94fb99/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= github.com/gorilla/handlers v0.0.0-20150720190736-60c7bfde3e33/go.mod h1:Qkdc/uu4tH4g6mTK6auzZ766c4CA0Ng8+o/OAirnOIQ= github.com/gorilla/mux v1.6.2/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= @@ -536,8 +540,11 @@ github.com/j-keck/arping v0.0.0-20160618110441-2cf9dc699c56/go.mod h1:ymszkNOg6t github.com/jehiah/go-strftime v0.0.0-20171201141054-1d33003b3869/go.mod h1:cJ6Cj7dQo+O6GJNiMx+Pa94qKj+TG8ONdKHgMNIyyag= github.com/jmespath/go-jmespath v0.0.0-20160202185014-0b12d6b521d8/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= github.com/jmespath/go-jmespath v0.0.0-20160803190731-bd40a432e4c7/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= -github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af h1:pmfjZENx5imkbgOkpRUYLnmbU7UEFbjtDA2hxJ1ichM= github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= +github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= +github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= +github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= +github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo= github.com/jonboulle/clockwork v0.2.2 h1:UOGuzwb1PwsrDAObMuhUnj0p5ULPj8V/xJ7Kx9qUBdQ= github.com/jonboulle/clockwork v0.2.2/go.mod h1:Pkfl5aHPm1nk2H9h0bjmnJD/BcgbGXUBGnn1kMkgxc8= @@ -647,7 +654,6 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8m github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw= -github.com/nacos-group/nacos-sdk-go v1.0.8 h1:8pEm05Cdav9sQgJSv5kyvlgfz0SzFUUGI3pWX6SiSnM= github.com/nacos-group/nacos-sdk-go v1.0.8/go.mod h1:hlAPn3UdzlxIlSILAyOXKxjFSvDJ9oLzTJ9hLAK1KzA= github.com/nacos-group/nacos-sdk-go/v2 v2.0.1 h1:jEZjqdCDSt6ZFtl628UUwON21GxwJ+lEN/PDamQOzgU= github.com/nacos-group/nacos-sdk-go/v2 v2.0.1/go.mod h1:SlhyCAv961LcZ198XpKfPEQqlJWt2HkL1fDLas0uy/w= @@ -860,8 +866,9 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw= github.com/syndtr/gocapability v0.0.0-20170704070218-db04d3cc01c8/go.mod h1:hkRG7XYTFWNJGYcbNJQlaLq0fg1yr4J4t/NcTQtrfww= github.com/syndtr/gocapability v0.0.0-20180916011248-d98352740cb2/go.mod h1:hkRG7XYTFWNJGYcbNJQlaLq0fg1yr4J4t/NcTQtrfww= @@ -915,7 +922,6 @@ go.etcd.io/bbolt v1.3.3/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= go.etcd.io/bbolt v1.3.5 h1:XAzx9gjCb0Rxj7EoqcClPD1d5ZBxZJk0jbuoPHenBt0= go.etcd.io/bbolt v1.3.5/go.mod h1:G5EMThwa9y8QZGBClrRx5EY+Yw9kAhnjy3bSjsnlVTQ= go.etcd.io/etcd v0.0.0-20191023171146-3cf2f69b5738/go.mod h1:dnLIgRNXwCJa5e+c6mIZCrds/GIG4ncV9HhK5PX7jPg= -go.etcd.io/etcd v0.5.0-alpha.5.0.20200910180754-dd1b699fc489 h1:1JFLBqwIgdyHN1ZtgjTBwO+blA6gVOmZurpiMEsETKo= go.etcd.io/etcd v0.5.0-alpha.5.0.20200910180754-dd1b699fc489/go.mod h1:yVHk9ub3CSBatqGNg7GRmsnfLWtoW60w4eDYfh7vHDg= go.etcd.io/etcd/api/v3 v3.5.0-alpha.0/go.mod h1:mPcW6aZJukV6Aa81LSKpBjQXTWlXB5r74ymPoSWa3Sw= go.etcd.io/etcd/api/v3 v3.5.0/go.mod h1:cbVKeC6lCfl7j/8jBhAK6aIYO9XOjdptoxU/nLQcPvs= @@ -946,6 +952,10 @@ go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk= go.opencensus.io v0.23.0 h1:gqCw0LfLxScz8irSi8exQc7fyQ0fKQU/qnC/X8+V/1M= go.opencensus.io v0.23.0/go.mod h1:XItmlyltB5F7CS4xOC1DcqMoFqwtC6OG2xF7mCv7P7E= +go.opentelemetry.io/otel v1.7.0 h1:Z2lA3Tdch0iDcrhJXDIlC94XE+bxok1F9B+4Lz/lGsM= +go.opentelemetry.io/otel v1.7.0/go.mod h1:5BdUoMIz5WEs0vt0CUEMtSSaTSHBBVwrhnz7+nrD5xk= +go.opentelemetry.io/otel/trace v1.7.0 h1:O37Iogk1lEkMRXewVtZ1BBTVn5JEp8GrJvP92bJqC6o= +go.opentelemetry.io/otel/trace v1.7.0/go.mod h1:fzLSB9nqR2eXzxPXb2JW9IKE+ScyXA48yyE4TNvoHqU= go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI= go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= @@ -960,9 +970,8 @@ go.uber.org/goleak v1.1.11-0.20210813005559-691160354723/go.mod h1:cwTWslyiVhfpK go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= go.uber.org/multierr v1.3.0/go.mod h1:VgVr7evmIr6uPjLBxg28wmKNXyqE9akIJ5XnfpiKl+4= go.uber.org/multierr v1.5.0/go.mod h1:FeouvMocqHpRaaGuG9EjoKcStLC43Zu/fmqdUMPcKYU= +go.uber.org/multierr v1.6.0 h1:y6IPFStTAIT5Ytl7/XYmHvzXQ7S3g/IeZW9hyZ5thw4= go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= -go.uber.org/multierr v1.7.0 h1:zaiO/rmgFjbmCXdSYJWQcdvOCsthmdaHfr3Gm2Kx4Ec= -go.uber.org/multierr v1.7.0/go.mod h1:7EAYxJLBy9rStEaz58O2t4Uvip6FSURkq8/ppBp95ak= go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9Ejo0C68/HhF8uaILCdgjnY+goOA= go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= @@ -987,9 +996,8 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200728195943-123391ffb6de/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20201002170205-7f63de1d35b0/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3 h1:0es+/5331RGQPcXlMfP+WrnIIS6dNnNRe0WB02W0F4M= golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/crypto v0.0.0-20220427172511-eb4f295cb31f h1:OeJjE6G4dgCY4PIXvIRQbE8+RX+uXZyGhUy/ksMGJoc= -golang.org/x/crypto v0.0.0-20220427172511-eb4f295cb31f/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/exp v0.0.0-20181106170214-d68db9428509/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -1002,6 +1010,8 @@ golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u0 golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM= golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU= golang.org/x/exp v0.0.0-20200331195152-e8c3332aa8e5/go.mod h1:4M0jN8W1tt0AVLNr8HDosyJCDCDuyL9N9+3m7wDWgKw= +golang.org/x/exp v0.0.0-20220613132600-b0d781184e0d h1:vtUKgx8dahOomfFzLREU8nSv25YHnTgLBn4rDnWZdU0= +golang.org/x/exp v0.0.0-20220613132600-b0d781184e0d/go.mod h1:Kr81I6Kryrl9sr8s2FK3vxD90NdsKWRuOIl2O4CvYbA= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= @@ -1199,9 +1209,8 @@ golang.org/x/sys v0.0.0-20210816074244-15123e1e1f71/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20211025201205-69cdffdb9359/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211109184856-51b60fd695b3/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220111092808-5a964db01320/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220128215802-99c3d69c2c27 h1:XDXtA5hveEEV8JB2l7nhMTp3t3cHp9ZpwcdjqyEWLlo= golang.org/x/sys v0.0.0-20220128215802-99c3d69c2c27/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220429233432-b5fbb4746d32 h1:Js08h5hqB5xyWR789+QqueR6sDE8mk+YvpETZ+F6X9Y= -golang.org/x/sys v0.0.0-20220429233432-b5fbb4746d32/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -1285,7 +1294,6 @@ golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/api v0.0.0-20160322025152-9bf6e6e569ff/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= google.golang.org/api v0.3.1/go.mod h1:6wY9I6uQWHQ8EM57III9mq/AjF+i8G65rmVagqKMtkk= diff --git a/integration_test/config/db/config.yaml b/integration_test/config/db/config.yaml new file mode 100644 index 000000000..d9eaf3d96 --- /dev/null +++ b/integration_test/config/db/config.yaml @@ -0,0 +1,98 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +kind: ConfigMap +apiVersion: "1.0" +metadata: + name: arana-config +data: + listeners: + - protocol_type: mysql + server_version: 5.7.0 + socket_address: + address: 0.0.0.0 + port: 13306 + + tenants: + - name: arana + users: + - username: arana + password: "123456" + - username: dksl + password: "123456" + + clusters: + - name: employees + type: mysql + sql_max_limit: -1 + tenant: arana + groups: + - name: employees_0000 + nodes: + - name: node0 + host: arana-mysql + port: 3306 + username: root + password: "123456" + database: employees_0000 + weight: r10w10 + - name: employees_0001 + nodes: + - name: node1 + host: arana-mysql + port: 3306 + username: root + password: "123456" + database: employees_0001 + weight: r10w10 + - name: employees_0002 + nodes: + - name: node2 + host: arana-mysql + port: 3306 + username: root + password: "123456" + database: employees_0002 + weight: r10w10 + - name: employees_0003 + nodes: + - name: node3 + host: arana-mysql + port: 3306 + username: root + password: "123456" + database: employees_0003 + weight: r10w10 + + sharding_rule: + tables: + - name: employees.student + allow_full_scan: true + db_rules: + - column: uid + type: scriptExpr + expr: parseInt($value % 32 / 8) + step: 32 + tbl_rules: + - column: uid + type: scriptExpr + expr: parseInt(0) + topology: + db_pattern: employees_${0000...0003} + tbl_pattern: student_0000 + attributes: + sqlMaxLimit: -1 diff --git a/integration_test/config/db/data.yaml b/integration_test/config/db/data.yaml new file mode 100644 index 000000000..112cb3c18 --- /dev/null +++ b/integration_test/config/db/data.yaml @@ -0,0 +1,30 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +kind: DataSet +metadata: + tables: + - name: "order" + columns: + - name: "name" + type: "string" + - name: "value" + type: "string" +data: + - name: "order" + value: + - ["test", "test1"] diff --git a/integration_test/config/db/expected.yaml b/integration_test/config/db/expected.yaml new file mode 100644 index 000000000..e16b99a92 --- /dev/null +++ b/integration_test/config/db/expected.yaml @@ -0,0 +1,30 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +kind: excepted +metadata: + tables: + - name: "order" + columns: + - name: "name" + type: "string" + - name: "value" + type: "string" +data: + - name: "order" + value: + - ["test", "test1"] diff --git a/integration_test/config/db_tbl/config.yaml b/integration_test/config/db_tbl/config.yaml new file mode 100644 index 000000000..4b421d229 --- /dev/null +++ b/integration_test/config/db_tbl/config.yaml @@ -0,0 +1,104 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +kind: ConfigMap +apiVersion: "1.0" +metadata: + name: arana-config +data: + listeners: + - protocol_type: mysql + server_version: 5.7.0 + socket_address: + address: 0.0.0.0 + port: 13306 + + tenants: + - name: arana + users: + - username: arana + password: "123456" + - username: dksl + password: "123456" + + clusters: + - name: employees + type: mysql + sql_max_limit: -1 + tenant: arana + groups: + - name: employees_0000 + nodes: + - name: node0 + host: arana-mysql + port: 3306 + username: root + password: "123456" + database: employees_0000 + weight: r10w10 + - name: node0_r_0 + host: arana-mysql + port: 3306 + username: root + password: "123456" + database: employees_0000_r + weight: r0w0 + - name: employees_0001 + nodes: + - name: node1 + host: arana-mysql + port: 3306 + username: root + password: "123456" + database: employees_0001 + weight: r10w10 + - name: employees_0002 + nodes: + - name: node2 + host: arana-mysql + port: 3306 + username: root + password: "123456" + database: employees_0002 + weight: r10w10 + - name: employees_0003 + nodes: + - name: node3 + host: arana-mysql + port: 3306 + username: root + password: "123456" + database: employees_0003 + weight: r10w10 + + sharding_rule: + tables: + - name: employees.student + allow_full_scan: true + db_rules: + - column: uid + type: scriptExpr + expr: parseInt($value % 32 / 8) + tbl_rules: + - column: uid + type: scriptExpr + expr: $value % 32 + topology: + db_pattern: employees_${0000..0003} + tbl_pattern: student_${0000..0031} + attributes: + sqlMaxLimit: -1 diff --git a/integration_test/config/db_tbl/data.yaml b/integration_test/config/db_tbl/data.yaml new file mode 100644 index 000000000..112cb3c18 --- /dev/null +++ b/integration_test/config/db_tbl/data.yaml @@ -0,0 +1,30 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +kind: DataSet +metadata: + tables: + - name: "order" + columns: + - name: "name" + type: "string" + - name: "value" + type: "string" +data: + - name: "order" + value: + - ["test", "test1"] diff --git a/integration_test/config/db_tbl/expected.yaml b/integration_test/config/db_tbl/expected.yaml new file mode 100644 index 000000000..542d37822 --- /dev/null +++ b/integration_test/config/db_tbl/expected.yaml @@ -0,0 +1,30 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +kind: excepted +metadata: + tables: + - name: "sequence" + columns: + - name: "name" + type: "string" + - name: "value" + type: "string" +data: + - name: "sequence" + value: + - ["1", "2"] diff --git a/integration_test/config/tbl/config.yaml b/integration_test/config/tbl/config.yaml new file mode 100644 index 000000000..3eb3564bb --- /dev/null +++ b/integration_test/config/tbl/config.yaml @@ -0,0 +1,70 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +kind: ConfigMap +apiVersion: "1.0" +metadata: + name: arana-config +data: + listeners: + - protocol_type: mysql + server_version: 5.7.0 + socket_address: + address: 0.0.0.0 + port: 13306 + + tenants: + - name: arana + users: + - username: arana + password: "123456" + - username: dksl + password: "123456" + + clusters: + - name: employees + type: mysql + sql_max_limit: -1 + tenant: arana + groups: + - name: employees_0000 + nodes: + - name: node0 + host: arana-mysql + port: 3306 + username: root + password: "123456" + database: employees_0000 + weight: r10w10 + + sharding_rule: + tables: + - name: employees.student + allow_full_scan: true + db_rules: + - column: uid + type: scriptExpr + expr: parseInt(0) + tbl_rules: + - column: uid + type: scriptExpr + expr: $value % 32 + topology: + db_pattern: employees_0000 + tbl_pattern: student_${0000..0031} + attributes: + sqlMaxLimit: -1 diff --git a/integration_test/config/tbl/data.yaml b/integration_test/config/tbl/data.yaml new file mode 100644 index 000000000..112cb3c18 --- /dev/null +++ b/integration_test/config/tbl/data.yaml @@ -0,0 +1,30 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +kind: DataSet +metadata: + tables: + - name: "order" + columns: + - name: "name" + type: "string" + - name: "value" + type: "string" +data: + - name: "order" + value: + - ["test", "test1"] diff --git a/integration_test/config/tbl/expected.yaml b/integration_test/config/tbl/expected.yaml new file mode 100644 index 000000000..e16b99a92 --- /dev/null +++ b/integration_test/config/tbl/expected.yaml @@ -0,0 +1,30 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +kind: excepted +metadata: + tables: + - name: "order" + columns: + - name: "name" + type: "string" + - name: "value" + type: "string" +data: + - name: "order" + value: + - ["test", "test1"] diff --git a/integration_test/scene/db/integration_test.go b/integration_test/scene/db/integration_test.go new file mode 100644 index 000000000..c7188fcba --- /dev/null +++ b/integration_test/scene/db/integration_test.go @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test + +import ( + "strings" + "testing" +) + +import ( + _ "github.com/go-sql-driver/mysql" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" +) + +import ( + "github.com/arana-db/arana/test" +) + +type IntegrationSuite struct { + *test.MySuite +} + +func TestSuite(t *testing.T) { + su := test.NewMySuite( + test.WithMySQLServerAuth("root", "123456"), + test.WithMySQLDatabase("employees"), + test.WithConfig("../integration_test/config/db/config.yaml"), + test.WithScriptPath("../integration_test/scripts/db"), + test.WithTestCasePath("../../testcase/casetest.yaml"), + // WithDevMode(), // NOTICE: UNCOMMENT IF YOU WANT TO DEBUG LOCAL ARANA SERVER!!! + ) + suite.Run(t, &IntegrationSuite{su}) +} + +func (s *IntegrationSuite) TestDBScene() { + var ( + db = s.DB() + t = s.T() + ) + tx, err := db.Begin() + assert.NoError(t, err, "should begin a new tx") + + cases := s.TestCases() + for _, sqlCase := range cases.ExecCases { + for _, sense := range sqlCase.Sense { + if strings.TrimSpace(sense) == "db" { + params := strings.Split(sqlCase.Parameters, ",") + args := make([]interface{}, 0, len(params)) + for _, param := range params { + k, _ := test.GetValueByType(param) + args = append(args, k) + } + + // Execute sql + result, err := tx.Exec(sqlCase.SQL, args...) + assert.NoError(t, err, "exec not right") + err = sqlCase.ExpectedResult.CompareRow(result) + assert.NoError(t, err, err) + } + } + } + + for _, sqlCase := range cases.QueryRowCases { + for _, sense := range sqlCase.Sense { + if strings.TrimSpace(sense) == "db" { + params := strings.Split(sqlCase.Parameters, ",") + args := make([]interface{}, 0, len(params)) + for _, param := range params { + k, _ := test.GetValueByType(param) + args = append(args, k) + } + + result := tx.QueryRow(sqlCase.SQL, args...) + err = sqlCase.ExpectedResult.CompareRow(result) + assert.NoError(t, err, err) + } + } + } + + for _, sqlCase := range cases.QueryRowsCases { + s.LoadExpectedDataSetPath(sqlCase.ExpectedResult.Value) + for _, sense := range sqlCase.Sense { + if strings.TrimSpace(sense) == "db" { + params := strings.Split(sqlCase.Parameters, ",") + args := make([]interface{}, 0, len(params)) + for _, param := range params { + k, _ := test.GetValueByType(param) + args = append(args, k) + } + + result, err := db.Query(sqlCase.SQL, args...) + assert.NoError(t, err, err) + err = sqlCase.ExpectedResult.CompareRows(result, s.ExpectedDataset()) + assert.NoError(t, err, err) + } + } + } +} diff --git a/integration_test/scene/db_tbl/integration_test.go b/integration_test/scene/db_tbl/integration_test.go new file mode 100644 index 000000000..e7878729a --- /dev/null +++ b/integration_test/scene/db_tbl/integration_test.go @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test + +import ( + "strings" + "testing" +) + +import ( + _ "github.com/go-sql-driver/mysql" // register mysql + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" +) + +import ( + "github.com/arana-db/arana/test" +) + +type IntegrationSuite struct { + *test.MySuite +} + +func TestSuite(t *testing.T) { + su := test.NewMySuite( + test.WithMySQLServerAuth("root", "123456"), + test.WithMySQLDatabase("employees"), + test.WithConfig("../integration_test/config/db_tbl/config.yaml"), + test.WithScriptPath("../integration_test/scripts/db_tbl"), + test.WithTestCasePath("../../testcase/casetest.yaml"), + // WithDevMode(), // NOTICE: UNCOMMENT IF YOU WANT TO DEBUG LOCAL ARANA SERVER!!! + ) + suite.Run(t, &IntegrationSuite{su}) +} + +func (s *IntegrationSuite) TestDBTBLScene() { + var ( + db = s.DB() + t = s.T() + ) + tx, err := db.Begin() + assert.NoError(t, err, "should begin a new tx") + + cases := s.TestCases() + for _, sqlCase := range cases.ExecCases { + for _, sense := range sqlCase.Sense { + if strings.Compare(strings.TrimSpace(sense), "db_tbl") == 1 { + params := strings.Split(sqlCase.Parameters, ",") + args := make([]interface{}, 0, len(params)) + for _, param := range params { + k, _ := test.GetValueByType(param) + args = append(args, k) + } + + // Execute sql + result, err := tx.Exec(sqlCase.SQL, args...) + assert.NoError(t, err, "exec not right") + err = sqlCase.ExpectedResult.CompareRow(result) + assert.NoError(t, err, err) + } + } + } + + for _, sqlCase := range cases.QueryRowCases { + for _, sense := range sqlCase.Sense { + if strings.Compare(strings.TrimSpace(sense), "db_tbl") == 1 { + params := strings.Split(sqlCase.Parameters, ",") + args := make([]interface{}, 0, len(params)) + for _, param := range params { + k, _ := test.GetValueByType(param) + args = append(args, k) + } + + result := tx.QueryRow(sqlCase.SQL, args...) + err = sqlCase.ExpectedResult.CompareRow(result) + assert.NoError(t, err, err) + } + } + } +} diff --git a/integration_test/scene/tbl/integration_test.go b/integration_test/scene/tbl/integration_test.go new file mode 100644 index 000000000..cc95a5853 --- /dev/null +++ b/integration_test/scene/tbl/integration_test.go @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test + +import ( + "strings" + "testing" +) + +import ( + _ "github.com/go-sql-driver/mysql" // register mysql + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" +) + +import ( + "github.com/arana-db/arana/test" +) + +type IntegrationSuite struct { + *test.MySuite +} + +func TestSuite(t *testing.T) { + su := test.NewMySuite( + test.WithMySQLServerAuth("root", "123456"), + test.WithMySQLDatabase("employees"), + test.WithConfig("../integration_test/config/tbl/config.yaml"), + test.WithScriptPath("../integration_test/scripts/tbl"), + test.WithTestCasePath("../../testcase/casetest.yaml"), + // WithDevMode(), // NOTICE: UNCOMMENT IF YOU WANT TO DEBUG LOCAL ARANA SERVER!!! + ) + suite.Run(t, &IntegrationSuite{su}) +} + +func (s *IntegrationSuite) TestDBTBLScene() { + var ( + db = s.DB() + t = s.T() + ) + tx, err := db.Begin() + assert.NoError(t, err, "should begin a new tx") + + cases := s.TestCases() + for _, sqlCase := range cases.ExecCases { + for _, sense := range sqlCase.Sense { + if strings.Compare(strings.TrimSpace(sense), "tbl") == 1 { + params := strings.Split(sqlCase.Parameters, ",") + args := make([]interface{}, 0, len(params)) + for _, param := range params { + k, _ := test.GetValueByType(param) + args = append(args, k) + } + + // Execute sql + result, err := tx.Exec(sqlCase.SQL, args...) + assert.NoError(t, err, "exec not right") + err = sqlCase.ExpectedResult.CompareRow(result) + assert.NoError(t, err, err) + } + } + } + + for _, sqlCase := range cases.QueryRowCases { + for _, sense := range sqlCase.Sense { + if strings.Compare(strings.TrimSpace(sense), "tbl") == 1 { + params := strings.Split(sqlCase.Parameters, ",") + args := make([]interface{}, 0, len(params)) + for _, param := range params { + k, _ := test.GetValueByType(param) + args = append(args, k) + } + + result := tx.QueryRow(sqlCase.SQL, args...) + err = sqlCase.ExpectedResult.CompareRow(result) + assert.NoError(t, err, err) + } + } + } + +} diff --git a/integration_test/scripts/db/init.sql b/integration_test/scripts/db/init.sql new file mode 100644 index 000000000..744c67674 --- /dev/null +++ b/integration_test/scripts/db/init.sql @@ -0,0 +1,113 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one or more +-- contributor license agreements. See the NOTICE file distributed with +-- this work for additional information regarding copyright ownership. +-- The ASF licenses this file to You under the Apache License, Version 2.0 +-- (the "License"); you may not use this file except in compliance with +-- the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. +-- + +-- Sample employee database +-- See changelog table for details +-- Copyright (C) 2007,2008, MySQL AB +-- +-- Original data created by Fusheng Wang and Carlo Zaniolo +-- http://www.cs.aau.dk/TimeCenter/software.htm +-- http://www.cs.aau.dk/TimeCenter/Data/employeeTemporalDataSet.zip +-- +-- Current schema by Giuseppe Maxia +-- Data conversion from XML to relational by Patrick Crews +-- +-- This work is licensed under the +-- Creative Commons Attribution-Share Alike 3.0 Unported License. +-- To view a copy of this license, visit +-- http://creativecommons.org/licenses/by-sa/3.0/ or send a letter to +-- Creative Commons, 171 Second Street, Suite 300, San Francisco, +-- California, 94105, USA. +-- +-- DISCLAIMER +-- To the best of our knowledge, this data is fabricated, and +-- it does not correspond to real people. +-- Any similarity to existing people is purely coincidental. +-- + +CREATE DATABASE IF NOT EXISTS employees_0000 CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; + +USE employees_0000; + +SELECT 'CREATING DATABASE STRUCTURE' as 'INFO'; + +DROP TABLE IF EXISTS dept_emp, + dept_manager, + titles, + salaries, + employees, + departments; + +/*!50503 set default_storage_engine = InnoDB */; +/*!50503 select CONCAT('storage engine: ', @@default_storage_engine) as INFO */; + +CREATE TABLE employees ( + emp_no INT NOT NULL, + birth_date DATE NOT NULL, + first_name VARCHAR(14) NOT NULL, + last_name VARCHAR(16) NOT NULL, + gender ENUM ('M','F') NOT NULL, + hire_date DATE NOT NULL, + PRIMARY KEY (emp_no) +); + +CREATE TABLE departments ( + dept_no CHAR(4) NOT NULL, + dept_name VARCHAR(40) NOT NULL, + PRIMARY KEY (dept_no), + UNIQUE KEY (dept_name) +); + +CREATE TABLE dept_manager ( + emp_no INT NOT NULL, + dept_no CHAR(4) NOT NULL, + from_date DATE NOT NULL, + to_date DATE NOT NULL, + FOREIGN KEY (emp_no) REFERENCES employees (emp_no) ON DELETE CASCADE, + FOREIGN KEY (dept_no) REFERENCES departments (dept_no) ON DELETE CASCADE, + PRIMARY KEY (emp_no,dept_no) +); + +CREATE TABLE dept_emp ( + emp_no INT NOT NULL, + dept_no CHAR(4) NOT NULL, + from_date DATE NOT NULL, + to_date DATE NOT NULL, + FOREIGN KEY (emp_no) REFERENCES employees (emp_no) ON DELETE CASCADE, + FOREIGN KEY (dept_no) REFERENCES departments (dept_no) ON DELETE CASCADE, + PRIMARY KEY (emp_no,dept_no) +); + +CREATE TABLE titles ( + emp_no INT NOT NULL, + title VARCHAR(50) NOT NULL, + from_date DATE NOT NULL, + to_date DATE, + FOREIGN KEY (emp_no) REFERENCES employees (emp_no) ON DELETE CASCADE, + PRIMARY KEY (emp_no,title, from_date) +) +; + +CREATE TABLE salaries ( + emp_no INT NOT NULL, + salary INT NOT NULL, + from_date DATE NOT NULL, + to_date DATE NOT NULL, + FOREIGN KEY (emp_no) REFERENCES employees (emp_no) ON DELETE CASCADE, + PRIMARY KEY (emp_no, from_date) +) +; diff --git a/integration_test/scripts/db/sequence.sql b/integration_test/scripts/db/sequence.sql new file mode 100644 index 000000000..0df0add1c --- /dev/null +++ b/integration_test/scripts/db/sequence.sql @@ -0,0 +1,31 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one or more +-- contributor license agreements. See the NOTICE file distributed with +-- this work for additional information regarding copyright ownership. +-- The ASF licenses this file to You under the Apache License, Version 2.0 +-- (the "License"); you may not use this file except in compliance with +-- the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. +-- + +CREATE DATABASE IF NOT EXISTS employees_0000 CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; + +CREATE TABLE IF NOT EXISTS `employees_0000`.`sequence` +( + `id` BIGINT UNSIGNED NOT NULL AUTO_INCREMENT, + `name` VARCHAR(64) NOT NULL, + `value` BIGINT NOT NULL, + `step` INT NOT NULL DEFAULT 10000, + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_name` (`name`) +) ENGINE = InnoDB + DEFAULT CHARSET = utf8mb4; diff --git a/integration_test/scripts/db/sharding.sql b/integration_test/scripts/db/sharding.sql new file mode 100644 index 000000000..1f22d4036 --- /dev/null +++ b/integration_test/scripts/db/sharding.sql @@ -0,0 +1,83 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one or more +-- contributor license agreements. See the NOTICE file distributed with +-- this work for additional information regarding copyright ownership. +-- The ASF licenses this file to You under the Apache License, Version 2.0 +-- (the "License"); you may not use this file except in compliance with +-- the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. +-- + +CREATE DATABASE IF NOT EXISTS employees_0000 CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; +CREATE DATABASE IF NOT EXISTS employees_0001 CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; +CREATE DATABASE IF NOT EXISTS employees_0002 CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; +CREATE DATABASE IF NOT EXISTS employees_0003 CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; + +CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0000` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0001`.`student_0000` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0002`.`student_0000` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0003`.`student_0000` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +INSERT INTO employees_0000.student_0000 VALUES (1, 1, 'scott', 95, 'nc_scott', 0, 16, NOW(), NOW()); diff --git a/integration_test/scripts/db_tbl/init.sql b/integration_test/scripts/db_tbl/init.sql new file mode 100644 index 000000000..744c67674 --- /dev/null +++ b/integration_test/scripts/db_tbl/init.sql @@ -0,0 +1,113 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one or more +-- contributor license agreements. See the NOTICE file distributed with +-- this work for additional information regarding copyright ownership. +-- The ASF licenses this file to You under the Apache License, Version 2.0 +-- (the "License"); you may not use this file except in compliance with +-- the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. +-- + +-- Sample employee database +-- See changelog table for details +-- Copyright (C) 2007,2008, MySQL AB +-- +-- Original data created by Fusheng Wang and Carlo Zaniolo +-- http://www.cs.aau.dk/TimeCenter/software.htm +-- http://www.cs.aau.dk/TimeCenter/Data/employeeTemporalDataSet.zip +-- +-- Current schema by Giuseppe Maxia +-- Data conversion from XML to relational by Patrick Crews +-- +-- This work is licensed under the +-- Creative Commons Attribution-Share Alike 3.0 Unported License. +-- To view a copy of this license, visit +-- http://creativecommons.org/licenses/by-sa/3.0/ or send a letter to +-- Creative Commons, 171 Second Street, Suite 300, San Francisco, +-- California, 94105, USA. +-- +-- DISCLAIMER +-- To the best of our knowledge, this data is fabricated, and +-- it does not correspond to real people. +-- Any similarity to existing people is purely coincidental. +-- + +CREATE DATABASE IF NOT EXISTS employees_0000 CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; + +USE employees_0000; + +SELECT 'CREATING DATABASE STRUCTURE' as 'INFO'; + +DROP TABLE IF EXISTS dept_emp, + dept_manager, + titles, + salaries, + employees, + departments; + +/*!50503 set default_storage_engine = InnoDB */; +/*!50503 select CONCAT('storage engine: ', @@default_storage_engine) as INFO */; + +CREATE TABLE employees ( + emp_no INT NOT NULL, + birth_date DATE NOT NULL, + first_name VARCHAR(14) NOT NULL, + last_name VARCHAR(16) NOT NULL, + gender ENUM ('M','F') NOT NULL, + hire_date DATE NOT NULL, + PRIMARY KEY (emp_no) +); + +CREATE TABLE departments ( + dept_no CHAR(4) NOT NULL, + dept_name VARCHAR(40) NOT NULL, + PRIMARY KEY (dept_no), + UNIQUE KEY (dept_name) +); + +CREATE TABLE dept_manager ( + emp_no INT NOT NULL, + dept_no CHAR(4) NOT NULL, + from_date DATE NOT NULL, + to_date DATE NOT NULL, + FOREIGN KEY (emp_no) REFERENCES employees (emp_no) ON DELETE CASCADE, + FOREIGN KEY (dept_no) REFERENCES departments (dept_no) ON DELETE CASCADE, + PRIMARY KEY (emp_no,dept_no) +); + +CREATE TABLE dept_emp ( + emp_no INT NOT NULL, + dept_no CHAR(4) NOT NULL, + from_date DATE NOT NULL, + to_date DATE NOT NULL, + FOREIGN KEY (emp_no) REFERENCES employees (emp_no) ON DELETE CASCADE, + FOREIGN KEY (dept_no) REFERENCES departments (dept_no) ON DELETE CASCADE, + PRIMARY KEY (emp_no,dept_no) +); + +CREATE TABLE titles ( + emp_no INT NOT NULL, + title VARCHAR(50) NOT NULL, + from_date DATE NOT NULL, + to_date DATE, + FOREIGN KEY (emp_no) REFERENCES employees (emp_no) ON DELETE CASCADE, + PRIMARY KEY (emp_no,title, from_date) +) +; + +CREATE TABLE salaries ( + emp_no INT NOT NULL, + salary INT NOT NULL, + from_date DATE NOT NULL, + to_date DATE NOT NULL, + FOREIGN KEY (emp_no) REFERENCES employees (emp_no) ON DELETE CASCADE, + PRIMARY KEY (emp_no, from_date) +) +; diff --git a/integration_test/scripts/db_tbl/sequence.sql b/integration_test/scripts/db_tbl/sequence.sql new file mode 100644 index 000000000..0df0add1c --- /dev/null +++ b/integration_test/scripts/db_tbl/sequence.sql @@ -0,0 +1,31 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one or more +-- contributor license agreements. See the NOTICE file distributed with +-- this work for additional information regarding copyright ownership. +-- The ASF licenses this file to You under the Apache License, Version 2.0 +-- (the "License"); you may not use this file except in compliance with +-- the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. +-- + +CREATE DATABASE IF NOT EXISTS employees_0000 CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; + +CREATE TABLE IF NOT EXISTS `employees_0000`.`sequence` +( + `id` BIGINT UNSIGNED NOT NULL AUTO_INCREMENT, + `name` VARCHAR(64) NOT NULL, + `value` BIGINT NOT NULL, + `step` INT NOT NULL DEFAULT 10000, + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_name` (`name`) +) ENGINE = InnoDB + DEFAULT CHARSET = utf8mb4; diff --git a/integration_test/scripts/db_tbl/sharding.sql b/integration_test/scripts/db_tbl/sharding.sql new file mode 100644 index 000000000..e98faa26f --- /dev/null +++ b/integration_test/scripts/db_tbl/sharding.sql @@ -0,0 +1,633 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one or more +-- contributor license agreements. See the NOTICE file distributed with +-- this work for additional information regarding copyright ownership. +-- The ASF licenses this file to You under the Apache License, Version 2.0 +-- (the "License"); you may not use this file except in compliance with +-- the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. +-- + +CREATE DATABASE IF NOT EXISTS employees_0000 CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; +CREATE DATABASE IF NOT EXISTS employees_0001 CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; +CREATE DATABASE IF NOT EXISTS employees_0002 CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; +CREATE DATABASE IF NOT EXISTS employees_0003 CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; + +CREATE DATABASE IF NOT EXISTS employees_0000_r CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; + +CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0000` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0001` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0002` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0003` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0004` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0005` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0006` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0007` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0001`.`student_0008` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0001`.`student_0009` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0001`.`student_0010` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0001`.`student_0011` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0001`.`student_0012` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0001`.`student_0013` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0001`.`student_0014` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0001`.`student_0015` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0002`.`student_0016` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0002`.`student_0017` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0002`.`student_0018` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0002`.`student_0019` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0002`.`student_0020` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0002`.`student_0021` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0002`.`student_0022` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0002`.`student_0023` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0003`.`student_0024` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0003`.`student_0025` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0003`.`student_0026` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0003`.`student_0027` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0003`.`student_0028` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0003`.`student_0029` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0003`.`student_0030` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0003`.`student_0031` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0000_r`.`student_0000` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`), + KEY `nickname` (`nickname`) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0000_r`.`student_0001` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`), + KEY `nickname` (`nickname`) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0000_r`.`student_0002` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`), + KEY `nickname` (`nickname`) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0000_r`.`student_0003` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`), + KEY `nickname` (`nickname`) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0000_r`.`student_0004` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`), + KEY `nickname` (`nickname`) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0000_r`.`student_0005` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`), + KEY `nickname` (`nickname`) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0000_r`.`student_0006` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`), + KEY `nickname` (`nickname`) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0000_r`.`student_0007` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`), + KEY `nickname` (`nickname`) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +INSERT INTO employees_0000.student_0001 VALUES (1, 1, 'scott', 95, 'nc_scott', 0, 16, NOW(), NOW()); diff --git a/integration_test/scripts/tbl/init.sql b/integration_test/scripts/tbl/init.sql new file mode 100644 index 000000000..744c67674 --- /dev/null +++ b/integration_test/scripts/tbl/init.sql @@ -0,0 +1,113 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one or more +-- contributor license agreements. See the NOTICE file distributed with +-- this work for additional information regarding copyright ownership. +-- The ASF licenses this file to You under the Apache License, Version 2.0 +-- (the "License"); you may not use this file except in compliance with +-- the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. +-- + +-- Sample employee database +-- See changelog table for details +-- Copyright (C) 2007,2008, MySQL AB +-- +-- Original data created by Fusheng Wang and Carlo Zaniolo +-- http://www.cs.aau.dk/TimeCenter/software.htm +-- http://www.cs.aau.dk/TimeCenter/Data/employeeTemporalDataSet.zip +-- +-- Current schema by Giuseppe Maxia +-- Data conversion from XML to relational by Patrick Crews +-- +-- This work is licensed under the +-- Creative Commons Attribution-Share Alike 3.0 Unported License. +-- To view a copy of this license, visit +-- http://creativecommons.org/licenses/by-sa/3.0/ or send a letter to +-- Creative Commons, 171 Second Street, Suite 300, San Francisco, +-- California, 94105, USA. +-- +-- DISCLAIMER +-- To the best of our knowledge, this data is fabricated, and +-- it does not correspond to real people. +-- Any similarity to existing people is purely coincidental. +-- + +CREATE DATABASE IF NOT EXISTS employees_0000 CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; + +USE employees_0000; + +SELECT 'CREATING DATABASE STRUCTURE' as 'INFO'; + +DROP TABLE IF EXISTS dept_emp, + dept_manager, + titles, + salaries, + employees, + departments; + +/*!50503 set default_storage_engine = InnoDB */; +/*!50503 select CONCAT('storage engine: ', @@default_storage_engine) as INFO */; + +CREATE TABLE employees ( + emp_no INT NOT NULL, + birth_date DATE NOT NULL, + first_name VARCHAR(14) NOT NULL, + last_name VARCHAR(16) NOT NULL, + gender ENUM ('M','F') NOT NULL, + hire_date DATE NOT NULL, + PRIMARY KEY (emp_no) +); + +CREATE TABLE departments ( + dept_no CHAR(4) NOT NULL, + dept_name VARCHAR(40) NOT NULL, + PRIMARY KEY (dept_no), + UNIQUE KEY (dept_name) +); + +CREATE TABLE dept_manager ( + emp_no INT NOT NULL, + dept_no CHAR(4) NOT NULL, + from_date DATE NOT NULL, + to_date DATE NOT NULL, + FOREIGN KEY (emp_no) REFERENCES employees (emp_no) ON DELETE CASCADE, + FOREIGN KEY (dept_no) REFERENCES departments (dept_no) ON DELETE CASCADE, + PRIMARY KEY (emp_no,dept_no) +); + +CREATE TABLE dept_emp ( + emp_no INT NOT NULL, + dept_no CHAR(4) NOT NULL, + from_date DATE NOT NULL, + to_date DATE NOT NULL, + FOREIGN KEY (emp_no) REFERENCES employees (emp_no) ON DELETE CASCADE, + FOREIGN KEY (dept_no) REFERENCES departments (dept_no) ON DELETE CASCADE, + PRIMARY KEY (emp_no,dept_no) +); + +CREATE TABLE titles ( + emp_no INT NOT NULL, + title VARCHAR(50) NOT NULL, + from_date DATE NOT NULL, + to_date DATE, + FOREIGN KEY (emp_no) REFERENCES employees (emp_no) ON DELETE CASCADE, + PRIMARY KEY (emp_no,title, from_date) +) +; + +CREATE TABLE salaries ( + emp_no INT NOT NULL, + salary INT NOT NULL, + from_date DATE NOT NULL, + to_date DATE NOT NULL, + FOREIGN KEY (emp_no) REFERENCES employees (emp_no) ON DELETE CASCADE, + PRIMARY KEY (emp_no, from_date) +) +; diff --git a/integration_test/scripts/tbl/sequence.sql b/integration_test/scripts/tbl/sequence.sql new file mode 100644 index 000000000..0df0add1c --- /dev/null +++ b/integration_test/scripts/tbl/sequence.sql @@ -0,0 +1,31 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one or more +-- contributor license agreements. See the NOTICE file distributed with +-- this work for additional information regarding copyright ownership. +-- The ASF licenses this file to You under the Apache License, Version 2.0 +-- (the "License"); you may not use this file except in compliance with +-- the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. +-- + +CREATE DATABASE IF NOT EXISTS employees_0000 CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; + +CREATE TABLE IF NOT EXISTS `employees_0000`.`sequence` +( + `id` BIGINT UNSIGNED NOT NULL AUTO_INCREMENT, + `name` VARCHAR(64) NOT NULL, + `value` BIGINT NOT NULL, + `step` INT NOT NULL DEFAULT 10000, + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_name` (`name`) +) ENGINE = InnoDB + DEFAULT CHARSET = utf8mb4; diff --git a/integration_test/scripts/tbl/sharding.sql b/integration_test/scripts/tbl/sharding.sql new file mode 100644 index 000000000..8a00efcb7 --- /dev/null +++ b/integration_test/scripts/tbl/sharding.sql @@ -0,0 +1,500 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one or more +-- contributor license agreements. See the NOTICE file distributed with +-- this work for additional information regarding copyright ownership. +-- The ASF licenses this file to You under the Apache License, Version 2.0 +-- (the "License"); you may not use this file except in compliance with +-- the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. +-- + +CREATE DATABASE IF NOT EXISTS employees_0000 CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; + +CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0000` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0001` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0002` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0003` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0004` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0005` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0006` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0007` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0008` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0009` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0010` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0011` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0012` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0013` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0014` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0015` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0016` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0017` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0018` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0019` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0020` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0021` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0022` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0023` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0024` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0025` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0026` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0027` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0028` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0029` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0030` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0031` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +INSERT INTO employees_0000.student_0001 VALUES (1, 1, 'scott', 95, 'nc_scott', 0, 16, NOW(), NOW()); diff --git a/integration_test/testcase/casetest.yaml b/integration_test/testcase/casetest.yaml new file mode 100644 index 000000000..0324d827a --- /dev/null +++ b/integration_test/testcase/casetest.yaml @@ -0,0 +1,45 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +kind: dml +query_rows_cases: + - sql: "SELECT name, value FROM sequence WHERE name= ?" + parameters: "1:string" + sense: + - db + expected: + type: "file" + value: "../../config/db/expected.yaml" +query_row_cases: + - sql: "SELECT COUNT(1) FROM sequence WHERE name=?" + parameters: "1:string" + sense: + - db + - tbl + expected: + type: "valueInt" + value: "1" +exec_cases: + - sql: "INSERT INTO sequence(name,value,modified_at) VALUES(?,?,NOW())" + parameters: "1:string, 2:string" + sense: + - db + - tbl + expected: + type: "rowAffect" + value: 1 diff --git a/justfile b/justfile index 0dfff1acf..eb4f510f9 100644 --- a/justfile +++ b/justfile @@ -7,7 +7,7 @@ default: @just --list run: - @go run ./cmd/... start -c ./conf/bootstrap.yaml + @go run ./example/local_server cli: @mycli -h127.0.0.1 -P13306 -udksl employees -p123456 @@ -16,4 +16,4 @@ cli-raw: fix: @imports-formatter . - @license-eye header fix \ No newline at end of file + @license-eye header fix diff --git a/pkg/boot/boot.go b/pkg/boot/boot.go index 56126b4de..0b30863d5 100644 --- a/pkg/boot/boot.go +++ b/pkg/boot/boot.go @@ -30,7 +30,7 @@ import ( "github.com/arana-db/arana/pkg/proto/rule" "github.com/arana-db/arana/pkg/runtime" "github.com/arana-db/arana/pkg/runtime/namespace" - "github.com/arana-db/arana/pkg/runtime/optimize" + _ "github.com/arana-db/arana/pkg/schema" "github.com/arana-db/arana/pkg/security" "github.com/arana-db/arana/pkg/util/log" ) @@ -129,5 +129,5 @@ func buildNamespace(ctx context.Context, provider Discovery, cluster string) (*n } initCmds = append(initCmds, namespace.UpdateRule(&ru)) - return namespace.New(cluster, optimize.GetOptimizer(), initCmds...), nil + return namespace.New(cluster, initCmds...), nil } diff --git a/pkg/boot/discovery.go b/pkg/boot/discovery.go index d76331c5e..669dfad1b 100644 --- a/pkg/boot/discovery.go +++ b/pkg/boot/discovery.go @@ -176,7 +176,6 @@ func (fp *discovery) GetCluster(ctx context.Context, cluster string) (*Cluster, } func (fp *discovery) ListTenants(ctx context.Context) ([]string, error) { - cfg, err := fp.c.Load() if err != nil { return nil, err @@ -269,7 +268,7 @@ func (fp *discovery) ListTables(ctx context.Context, cluster string) ([]string, } var tables []string - for tb, _ := range fp.loadTables(cfg, cluster) { + for tb := range fp.loadTables(cfg, cluster) { tables = append(tables, tb) } sort.Strings(tables) @@ -325,6 +324,7 @@ func (fp *discovery) GetTable(ctx context.Context, cluster, tableName string) (* var ( keys map[string]struct{} dbSharder, tbSharder map[string]rule.ShardComputer + dbSteps, tbSteps map[string]int ) for _, it := range table.DbRules { var shd rule.ShardComputer @@ -337,8 +337,12 @@ func (fp *discovery) GetTable(ctx context.Context, cluster, tableName string) (* if keys == nil { keys = make(map[string]struct{}) } + if dbSteps == nil { + dbSteps = make(map[string]int) + } dbSharder[it.Column] = shd keys[it.Column] = struct{}{} + dbSteps[it.Column] = it.Step } for _, it := range table.TblRules { @@ -352,8 +356,12 @@ func (fp *discovery) GetTable(ctx context.Context, cluster, tableName string) (* if keys == nil { keys = make(map[string]struct{}) } + if tbSteps == nil { + tbSteps = make(map[string]int) + } tbSharder[it.Column] = shd keys[it.Column] = struct{}{} + tbSteps[it.Column] = it.Step } for k := range keys { @@ -366,7 +374,9 @@ func (fp *discovery) GetTable(ctx context.Context, cluster, tableName string) (* Computer: shd, Stepper: rule.DefaultNumberStepper, } - if dbBegin >= 0 && dbEnd >= 0 { + if s, ok := dbSteps[k]; ok && s > 0 { + dbMetadata.Steps = s + } else if dbBegin >= 0 && dbEnd >= 0 { dbMetadata.Steps = 1 + dbEnd - dbBegin } } @@ -375,14 +385,20 @@ func (fp *discovery) GetTable(ctx context.Context, cluster, tableName string) (* Computer: shd, Stepper: rule.DefaultNumberStepper, } - if tbBegin >= 0 && tbEnd >= 0 { + if s, ok := tbSteps[k]; ok && s > 0 { + tbMetadata.Steps = s + } else if tbBegin >= 0 && tbEnd >= 0 { tbMetadata.Steps = 1 + tbEnd - tbBegin } } vt.SetShardMetadata(k, dbMetadata, tbMetadata) tpRes := make(map[int][]int) - rng, _ := tbMetadata.Stepper.Ascend(0, tbMetadata.Steps) + step := tbMetadata.Steps + if dbMetadata.Steps > step { + step = dbMetadata.Steps + } + rng, _ := tbMetadata.Stepper.Ascend(0, step) for rng.HasNext() { var ( seed = rng.Next() @@ -472,7 +488,7 @@ var ( func getTopologyRegexp() *regexp.Regexp { _regexpTopologyOnce.Do(func() { - _regexpTopology = regexp.MustCompile(`\${(?P[0-9]+)\.\.\.(?P[0-9]+)}`) + _regexpTopology = regexp.MustCompile(`\${(?P\d+)\.{2,}(?P\d+)}`) }) return _regexpTopology } @@ -492,9 +508,7 @@ func parseTopology(input string) (format string, begin, end int, err error) { return } - var ( - beginStr, endStr string - ) + var beginStr, endStr string for i := 1; i < len(mats[0]); i++ { switch getTopologyRegexp().SubexpNames()[i] { case "begin": @@ -518,27 +532,29 @@ func parseTopology(input string) (format string, begin, end int, err error) { func toSharder(input *config.Rule) (rule.ShardComputer, error) { var ( computer rule.ShardComputer - method string mod int err error ) if mat := getRuleExprRegexp().FindStringSubmatch(input.Expr); len(mat) == 3 { - method = mat[1] mod, _ = strconv.Atoi(mat[2]) } - switch method { - case string(rrule.ModShard): + switch rrule.ShardType(input.Type) { + case rrule.ModShard: computer = rrule.NewModShard(mod) - case string(rrule.HashMd5Shard): + case rrule.HashMd5Shard: computer = rrule.NewHashMd5Shard(mod) - case string(rrule.HashBKDRShard): + case rrule.HashBKDRShard: computer = rrule.NewHashBKDRShard(mod) - case string(rrule.HashCrc32Shard): + case rrule.HashCrc32Shard: computer = rrule.NewHashCrc32Shard(mod) - default: + case rrule.FunctionExpr: + computer, err = rrule.NewExprShardComputer(input.Expr, input.Column) + case rrule.ScriptExpr: computer, err = rrule.NewJavascriptShardComputer(input.Expr) + default: + panic(fmt.Errorf("error config, unsupport shard type: %s", input.Type)) } return computer, err } diff --git a/pkg/config/api.go b/pkg/config/api.go index 7bd8bb591..08105ac0f 100644 --- a/pkg/config/api.go +++ b/pkg/config/api.go @@ -53,7 +53,7 @@ const ( ) var ( - slots map[string]StoreOperate = make(map[string]StoreOperate) + slots = make(map[string]StoreOperate) storeOperate StoreOperate ) @@ -66,7 +66,6 @@ func GetStoreOperate() (StoreOperate, error) { } func Init(name string, options map[string]interface{}) error { - s, exist := slots[name] if !exist { return fmt.Errorf("StoreOperate solt=[%s] not exist", name) @@ -77,7 +76,7 @@ func Init(name string, options map[string]interface{}) error { return storeOperate.Init(options) } -//Register register store plugin +// Register register store plugin func Register(s StoreOperate) { if _, ok := slots[s.Name()]; ok { panic(fmt.Errorf("StoreOperate=[%s] already exist", s.Name())) @@ -86,22 +85,22 @@ func Register(s StoreOperate) { slots[s.Name()] = s } -//StoreOperate config storage related plugins +// StoreOperate config storage related plugins type StoreOperate interface { io.Closer - //Init plugin initialization + // Init plugin initialization Init(options map[string]interface{}) error - //Save save a configuration data + // Save save a configuration data Save(key PathKey, val []byte) error - //Get get a configuration + // Get get a configuration Get(key PathKey) ([]byte, error) - //Watch Monitor changes of the key + // Watch Monitor changes of the key Watch(key PathKey) (<-chan []byte, error) - //Name plugin name + // Name plugin name Name() string } diff --git a/pkg/config/config.go b/pkg/config/config.go index 5531b5531..e1343975b 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -36,6 +36,7 @@ import ( ) import ( + "github.com/arana-db/arana/pkg/util/env" "github.com/arana-db/arana/pkg/util/log" ) @@ -136,7 +137,9 @@ func (c *Center) LoadContext(ctx context.Context) (*Configuration, error) { c.confHolder.Store(cfg) out, _ := yaml.Marshal(cfg) - log.Debugf("load configuration:\n%s", string(out)) + if env.IsDevelopEnvironment() { + log.Infof("load configuration:\n%s", string(out)) + } } val = c.confHolder.Load() @@ -255,7 +258,6 @@ func (c *Center) PersistContext(ctx context.Context) error { } for k, v := range ConfigKeyMapping { - if err := c.storeOperate.Save(k, []byte(gjson.GetBytes(configJson, v).String())); err != nil { return err } diff --git a/pkg/config/etcd/etcd.go b/pkg/config/etcd/etcd.go index caf1fda31..5c8bbacef 100644 --- a/pkg/config/etcd/etcd.go +++ b/pkg/config/etcd/etcd.go @@ -35,6 +35,7 @@ import ( import ( "github.com/arana-db/arana/pkg/config" + "github.com/arana-db/arana/pkg/util/env" "github.com/arana-db/arana/pkg/util/log" ) @@ -79,6 +80,10 @@ func (c *storeOperate) Get(key config.PathKey) ([]byte, error) { return nil, err } + if env.IsDevelopEnvironment() { + log.Infof("[ConfigCenter][etcd] load config content : %#v", v) + } + return []byte(v), nil } diff --git a/pkg/config/file/file.go b/pkg/config/file/file.go index aa2bc276e..a8cac1abb 100644 --- a/pkg/config/file/file.go +++ b/pkg/config/file/file.go @@ -27,18 +27,19 @@ import ( import ( "github.com/pkg/errors" - "github.com/tidwall/gjson" - "gopkg.in/yaml.v3" ) import ( "github.com/arana-db/arana/pkg/config" "github.com/arana-db/arana/pkg/constants" + "github.com/arana-db/arana/pkg/util/env" "github.com/arana-db/arana/pkg/util/log" ) +var configFilenameList = []string{"config.yaml", "config.yml"} + func init() { config.Register(&storeOperate{}) } @@ -96,7 +97,9 @@ func (s *storeOperate) initCfgJsonMap(val string) { s.cfgJson[k] = gjson.Get(val, v).String() } - log.Debugf("[ConfigCenter][File] load config content : %#v", s.cfgJson) + if env.IsDevelopEnvironment() { + log.Infof("[ConfigCenter][File] load config content : %#v", s.cfgJson) + } } func (s *storeOperate) Save(key config.PathKey, val []byte) error { @@ -108,7 +111,7 @@ func (s *storeOperate) Get(key config.PathKey) ([]byte, error) { return val, nil } -//Watch TODO change notification through file inotify mechanism +// Watch TODO change notification through file inotify mechanism func (s *storeOperate) Watch(key config.PathKey) (<-chan []byte, error) { defer s.lock.Unlock() @@ -164,13 +167,11 @@ func (s *storeOperate) readFromFile(path string, cfg *config.Configuration) erro func (s *storeOperate) searchDefaultConfigFile() (string, bool) { var p string for _, it := range constants.GetConfigSearchPathList() { - p = filepath.Join(it, "config.yaml") - if _, err := os.Stat(p); err == nil { - return p, true - } - p = filepath.Join(it, "config.yml") - if _, err := os.Stat(p); err == nil { - return p, true + for _, filename := range configFilenameList { + p = filepath.Join(it, filename) + if _, err := os.Stat(p); err == nil { + return p, true + } } } return "", false diff --git a/pkg/config/model.go b/pkg/config/model.go index d1e77869e..d160ce285 100644 --- a/pkg/config/model.go +++ b/pkg/config/model.go @@ -20,10 +20,13 @@ package config import ( "bytes" "encoding/json" + "fmt" "io" "os" "regexp" "strconv" + "strings" + "time" ) import ( @@ -78,31 +81,25 @@ type ( Type DataSourceType `yaml:"type" json:"type"` SqlMaxLimit int `default:"-1" yaml:"sql_max_limit" json:"sql_max_limit,omitempty"` Tenant string `yaml:"tenant" json:"tenant"` - ConnProps *ConnProp `yaml:"conn_props" json:"conn_props,omitempty"` Groups []*Group `yaml:"groups" json:"groups"` } - ConnProp struct { - Capacity int `yaml:"capacity" json:"capacity,omitempty"` // connection pool capacity - MaxCapacity int `yaml:"max_capacity" json:"max_capacity,omitempty"` // max connection pool capacity - IdleTimeout int `yaml:"idle_timeout" json:"idle_timeout,omitempty"` // close backend direct connection after idle_timeout - } - Group struct { Name string `yaml:"name" json:"name"` Nodes []*Node `yaml:"nodes" json:"nodes"` } Node struct { - Name string `validate:"required" yaml:"name" json:"name"` - Host string `validate:"required" yaml:"host" json:"host"` - Port int `validate:"required" yaml:"port" json:"port"` - Username string `validate:"required" yaml:"username" json:"username"` - Password string `validate:"required" yaml:"password" json:"password"` - Database string `validate:"required" yaml:"database" json:"database"` - ConnProps map[string]string `yaml:"conn_props" json:"conn_props,omitempty"` - Weight string `default:"r10w10" yaml:"weight" json:"weight"` - Labels map[string]string `yaml:"labels" json:"labels,omitempty"` + Name string `validate:"required" yaml:"name" json:"name"` + Host string `validate:"required" yaml:"host" json:"host"` + Port int `validate:"required" yaml:"port" json:"port"` + Username string `validate:"required" yaml:"username" json:"username"` + Password string `validate:"required" yaml:"password" json:"password"` + Database string `validate:"required" yaml:"database" json:"database"` + Parameters ParametersMap `yaml:"parameters" json:"parameters"` + ConnProps map[string]interface{} `yaml:"conn_props" json:"conn_props,omitempty"` + Weight string `default:"r10w10" yaml:"weight" json:"weight"` + Labels map[string]string `yaml:"labels" json:"labels,omitempty"` } ShardingRule struct { @@ -163,7 +160,9 @@ type ( Rule struct { Column string `validate:"required" yaml:"column" json:"column"` + Type string `validate:"required" yaml:"type" json:"type"` Expr string `validate:"required" yaml:"expr" json:"expr"` + Step int `yaml:"step" json:"step"` } Topology struct { @@ -172,6 +171,19 @@ type ( } ) +type ParametersMap map[string]string + +func (pm *ParametersMap) String() string { + sBuff := strings.Builder{} + for k, v := range *pm { + sBuff.WriteString(k) + sBuff.WriteString("=") + sBuff.WriteString(v) + sBuff.WriteString("&") + } + return strings.TrimRight(sBuff.String(), "&") +} + // Decoder decodes configuration. type Decoder struct { reader io.Reader @@ -262,3 +274,62 @@ func Validate(cfg *Configuration) error { v := validator.New() return v.Struct(cfg) } + +// GetConnPropCapacity parses the capacity of backend connection pool, return default value if failed. +func GetConnPropCapacity(connProps map[string]interface{}, defaultValue int) int { + capacity, ok := connProps["capacity"] + if !ok { + return defaultValue + } + n, _ := strconv.Atoi(fmt.Sprint(capacity)) + if n < 1 { + return defaultValue + } + return n +} + +// GetConnPropMaxCapacity parses the max capacity of backend connection pool, return default value if failed. +func GetConnPropMaxCapacity(connProps map[string]interface{}, defaultValue int) int { + var ( + maxCapacity interface{} + ok bool + ) + + if maxCapacity, ok = connProps["max_capacity"]; !ok { + if maxCapacity, ok = connProps["maxCapacity"]; !ok { + return defaultValue + } + } + n, _ := strconv.Atoi(fmt.Sprint(maxCapacity)) + if n < 1 { + return defaultValue + } + return n +} + +// GetConnPropIdleTime parses the idle time of backend connection pool, return default value if failed. +func GetConnPropIdleTime(connProps map[string]interface{}, defaultValue time.Duration) time.Duration { + var ( + idleTime interface{} + ok bool + ) + + if idleTime, ok = connProps["idle_time"]; !ok { + if idleTime, ok = connProps["idleTime"]; !ok { + return defaultValue + } + } + + s := fmt.Sprint(idleTime) + d, _ := time.ParseDuration(s) + if d > 0 { + return d + } + + n, _ := strconv.Atoi(s) + if n < 1 { + return defaultValue + } + + return time.Duration(n) * time.Second +} diff --git a/pkg/config/model_test.go b/pkg/config/model_test.go index baba8a461..188d72931 100644 --- a/pkg/config/model_test.go +++ b/pkg/config/model_test.go @@ -64,10 +64,6 @@ func TestDataSourceClustersConf(t *testing.T) { assert.Equal(t, DBMySQL, dataSourceCluster.Type) assert.Equal(t, -1, dataSourceCluster.SqlMaxLimit) assert.Equal(t, "arana", dataSourceCluster.Tenant) - assert.NotNil(t, dataSourceCluster.ConnProps) - assert.Equal(t, 10, dataSourceCluster.ConnProps.Capacity) - assert.Equal(t, 20, dataSourceCluster.ConnProps.MaxCapacity) - assert.Equal(t, 60, dataSourceCluster.ConnProps.IdleTimeout) assert.Equal(t, 1, len(dataSourceCluster.Groups)) group := dataSourceCluster.Groups[0] @@ -112,22 +108,22 @@ func TestShardingRuleConf(t *testing.T) { func TestUnmarshalTextForProtocolTypeNil(t *testing.T) { var protocolType ProtocolType - var text = []byte("http") + text := []byte("http") err := protocolType.UnmarshalText(text) assert.Nil(t, err) assert.Equal(t, Http, protocolType) } func TestUnmarshalTextForUnrecognizedProtocolType(t *testing.T) { - var protocolType = Http - var text = []byte("PostgreSQL") + protocolType := Http + text := []byte("PostgreSQL") err := protocolType.UnmarshalText(text) assert.Error(t, err) } func TestUnmarshalText(t *testing.T) { - var protocolType = Http - var text = []byte("mysql") + protocolType := Http + text := []byte("mysql") err := protocolType.UnmarshalText(text) assert.Nil(t, err) assert.Equal(t, MySQL, protocolType) diff --git a/pkg/config/nacos/nacos.go b/pkg/config/nacos/nacos.go index b4aeb18ea..b9a354005 100644 --- a/pkg/config/nacos/nacos.go +++ b/pkg/config/nacos/nacos.go @@ -34,6 +34,8 @@ import ( import ( "github.com/arana-db/arana/pkg/config" + "github.com/arana-db/arana/pkg/util/env" + "github.com/arana-db/arana/pkg/util/log" ) const ( @@ -52,7 +54,7 @@ func init() { config.Register(&storeOperate{}) } -//StoreOperate config storage related plugins +// StoreOperate config storage related plugins type storeOperate struct { groupName string client config_client.IConfigClient @@ -63,7 +65,7 @@ type storeOperate struct { cancelList []context.CancelFunc } -//Init plugin initialization +// Init plugin initialization func (s *storeOperate) Init(options map[string]interface{}) error { s.lock = &sync.RWMutex{} s.cfgLock = &sync.RWMutex{} @@ -96,7 +98,6 @@ func (s *storeOperate) initNacosClient(options map[string]interface{}) error { ClientConfig: &clientConfig, }, ) - if err != nil { return err } @@ -170,9 +171,8 @@ func (s *storeOperate) loadDataFromServer() error { return nil } -//Save save a configuration data +// Save save a configuration data func (s *storeOperate) Save(key config.PathKey, val []byte) error { - _, err := s.client.PublishConfig(vo.ConfigParam{ Group: s.groupName, DataId: string(key), @@ -182,16 +182,20 @@ func (s *storeOperate) Save(key config.PathKey, val []byte) error { return err } -//Get get a configuration +// Get get a configuration func (s *storeOperate) Get(key config.PathKey) ([]byte, error) { defer s.cfgLock.RUnlock() s.cfgLock.RLock() val := []byte(s.confMap[key]) + + if env.IsDevelopEnvironment() { + log.Infof("[ConfigCenter][nacos] load config content : %#v", string(val)) + } return val, nil } -//Watch Monitor changes of the key +// Watch Monitor changes of the key func (s *storeOperate) Watch(key config.PathKey) (<-chan []byte, error) { defer s.lock.Unlock() s.lock.Lock() @@ -217,12 +221,12 @@ func (s *storeOperate) Watch(key config.PathKey) (<-chan []byte, error) { return rec, nil } -//Name plugin name +// Name plugin name func (s *storeOperate) Name() string { return "nacos" } -//Close do close storeOperate +// Close do close storeOperate func (s *storeOperate) Close() error { return nil } @@ -250,7 +254,6 @@ func (s *storeOperate) newWatcher(key config.PathKey, client config_client.IConf s.receivers[config.PathKey(dataId)].ch <- []byte(content) }, }) - if err != nil { return nil, err } diff --git a/pkg/config/nacos/nacos_test.go b/pkg/config/nacos/nacos_test.go index e693c0a1c..0e4a5442b 100644 --- a/pkg/config/nacos/nacos_test.go +++ b/pkg/config/nacos/nacos_test.go @@ -72,7 +72,6 @@ type TestNacosClient struct { } func newNacosClient() *TestNacosClient { - client := &TestNacosClient{ listeners: make(map[string][]vo.Listener), ch: make(chan vo.ConfigParam, 16), @@ -82,7 +81,6 @@ func newNacosClient() *TestNacosClient { go client.doLongPoll() return client - } func (client *TestNacosClient) doLongPoll() { @@ -151,7 +149,7 @@ func (client *TestNacosClient) ListenConfig(params vo.ConfigParam) (err error) { return nil } -//CancelListenConfig use to cancel listen config change +// CancelListenConfig use to cancel listen config change // dataId require // group require // tenant ==>nacos.namespace optional diff --git a/pkg/constants/env.go b/pkg/constants/env.go index 6830f0fac..60bd3a876 100644 --- a/pkg/constants/env.go +++ b/pkg/constants/env.go @@ -24,8 +24,9 @@ import ( // Environments const ( - EnvBootstrapPath = "ARANA_BOOTSTRAP_PATH" // bootstrap file path, eg: /etc/arana/bootstrap.yaml - EnvConfigPath = "ARANA_CONFIG_PATH" // config file path, eg: /etc/arana/config.yaml + EnvBootstrapPath = "ARANA_BOOTSTRAP_PATH" // bootstrap file path, eg: /etc/arana/bootstrap.yaml + EnvConfigPath = "ARANA_CONFIG_PATH" // config file path, eg: /etc/arana/config.yaml + EnvDevelopEnvironment = "ARANA_DEV" // config dev environment ) // GetConfigSearchPathList returns the default search path list of configuration. diff --git a/pkg/constants/mysql/type.go b/pkg/constants/mysql/type.go index 343379fce..b729cca38 100644 --- a/pkg/constants/mysql/type.go +++ b/pkg/constants/mysql/type.go @@ -60,14 +60,6 @@ const ( FieldTypeBit ) -const ( - FieldTypeUint8 FieldType = iota + 0x85 - FieldTypeUint16 - FieldTypeUint24 - FieldTypeUint32 - FieldTypeUint64 -) - const ( FieldTypeJSON FieldType = iota + 0xf5 FieldTypeNewDecimal @@ -118,48 +110,14 @@ var mysqlToType = map[int64]FieldType{ 255: FieldTypeGeometry, } -// modifyType modifies the vitess type based on the -// mysql flag. The function checks specific flags based -// on the type. This allows us to ignore stray flags -// that MySQL occasionally sets. -func modifyType(typ FieldType, flags int64) FieldType { - switch typ { - case FieldTypeTiny: - if uint(flags)&UnsignedFlag != 0 { - return FieldTypeUint8 - } - return FieldTypeTiny - case FieldTypeShort: - if uint(flags)&UnsignedFlag != 0 { - return FieldTypeUint16 - } - return FieldTypeShort - case FieldTypeLong: - if uint(flags)&UnsignedFlag != 0 { - return FieldTypeUint32 - } - return FieldTypeLong - case FieldTypeLongLong: - if uint(flags)&UnsignedFlag != 0 { - return FieldTypeUint64 - } - return FieldTypeLongLong - case FieldTypeInt24: - if uint(flags)&UnsignedFlag != 0 { - return FieldTypeUint24 - } - return FieldTypeInt24 - } - return typ -} - // MySQLToType computes the vitess type from mysql type and flags. -func MySQLToType(mysqlType, flags int64) (typ FieldType, err error) { +func MySQLToType(mysqlType, flags int64) (FieldType, error) { + _ = flags result, ok := mysqlToType[mysqlType] if !ok { return 0, fmt.Errorf("unsupported type: %d", mysqlType) } - return modifyType(result, flags), nil + return result, nil } // typeToMySQL is the reverse of MysqlToType. @@ -168,19 +126,14 @@ var typeToMySQL = map[FieldType]struct { flags int64 }{ FieldTypeTiny: {typ: 1}, - FieldTypeUint8: {typ: 1, flags: int64(UnsignedFlag)}, FieldTypeShort: {typ: 2}, - FieldTypeUint16: {typ: 2, flags: int64(UnsignedFlag)}, FieldTypeLong: {typ: 3}, - FieldTypeUint32: {typ: 3, flags: int64(UnsignedFlag)}, FieldTypeFloat: {typ: 4}, FieldTypeDouble: {typ: 5}, FieldTypeNULL: {typ: 6, flags: int64(BinaryFlag)}, FieldTypeTimestamp: {typ: 7}, FieldTypeLongLong: {typ: 8}, - FieldTypeUint64: {typ: 8, flags: int64(UnsignedFlag)}, FieldTypeInt24: {typ: 9}, - FieldTypeUint24: {typ: 9, flags: int64(UnsignedFlag)}, FieldTypeDate: {typ: 10, flags: int64(BinaryFlag)}, FieldTypeTime: {typ: 11, flags: int64(BinaryFlag)}, FieldTypeDateTime: {typ: 12, flags: int64(BinaryFlag)}, diff --git a/pkg/constants/mysql/type_test.go b/pkg/constants/mysql/type_test.go index fa46ba001..2ed3dbdea 100644 --- a/pkg/constants/mysql/type_test.go +++ b/pkg/constants/mysql/type_test.go @@ -34,11 +34,8 @@ func TestMySQLToType(t *testing.T) { {0, 0, FieldTypeDecimal}, {0, 32, FieldTypeDecimal}, {1, 0, FieldTypeTiny}, - {1, 32, FieldTypeUint8}, {2, 0, FieldTypeShort}, - {2, 32, FieldTypeUint16}, {3, 0, FieldTypeLong}, - {3, 32, FieldTypeUint32}, {4, 0, FieldTypeFloat}, {4, 32, FieldTypeFloat}, {5, 0, FieldTypeDouble}, @@ -48,9 +45,7 @@ func TestMySQLToType(t *testing.T) { {7, 0, FieldTypeTimestamp}, {7, 32, FieldTypeTimestamp}, {8, 0, FieldTypeLongLong}, - {8, 32, FieldTypeUint64}, {9, 0, FieldTypeInt24}, - {9, 32, FieldTypeUint24}, {10, 0, FieldTypeDate}, {10, 32, FieldTypeDate}, {11, 0, FieldTypeTime}, @@ -108,18 +103,14 @@ func TestTypeToMySQL(t *testing.T) { expectedFlags int64 }{ {FieldTypeTiny, int64(1), int64(0)}, - {FieldTypeUint8, int64(1), int64(UnsignedFlag)}, {FieldTypeShort, int64(2), int64(0)}, {FieldTypeLong, int64(3), int64(0)}, - {FieldTypeUint32, int64(3), int64(UnsignedFlag)}, {FieldTypeFloat, int64(4), int64(0)}, {FieldTypeDouble, int64(5), int64(0)}, {FieldTypeNULL, int64(6), int64(BinaryFlag)}, {FieldTypeTimestamp, int64(7), int64(0)}, {FieldTypeLongLong, int64(8), int64(0)}, - {FieldTypeUint64, int64(8), int64(UnsignedFlag)}, {FieldTypeInt24, int64(9), int64(0)}, - {FieldTypeUint24, int64(9), int64(UnsignedFlag)}, {FieldTypeDate, int64(10), int64(BinaryFlag)}, {FieldTypeTime, int64(11), int64(BinaryFlag)}, {FieldTypeDateTime, int64(12), int64(BinaryFlag)}, diff --git a/pkg/dataset/chain.go b/pkg/dataset/chain.go new file mode 100644 index 000000000..54fb6b1fd --- /dev/null +++ b/pkg/dataset/chain.go @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dataset + +import ( + "github.com/arana-db/arana/pkg/proto" +) + +type pipeOption []func(proto.Dataset) proto.Dataset + +func Filter(predicate PredicateFunc) Option { + return func(option *pipeOption) { + *option = append(*option, func(prev proto.Dataset) proto.Dataset { + return FilterDataset{ + Dataset: prev, + Predicate: predicate, + } + }) + } +} + +func Map(generateFields FieldsFunc, transform TransformFunc) Option { + return func(option *pipeOption) { + *option = append(*option, func(dataset proto.Dataset) proto.Dataset { + return &TransformDataset{ + Dataset: dataset, + FieldsGetter: generateFields, + Transform: transform, + } + }) + } +} + +func GroupReduce(groups []OrderByItem, generateFields FieldsFunc, reducer func() Reducer) Option { + return func(option *pipeOption) { + *option = append(*option, func(dataset proto.Dataset) proto.Dataset { + return &GroupDataset{ + Dataset: dataset, + keys: groups, + reducer: reducer, + fieldFunc: generateFields, + } + }) + } +} + +type Option func(*pipeOption) + +func Pipe(root proto.Dataset, options ...Option) proto.Dataset { + var o pipeOption + for _, it := range options { + it(&o) + } + + next := root + for _, it := range o { + next = it(next) + } + + return next +} diff --git a/pkg/dataset/dataset.go b/pkg/dataset/dataset.go new file mode 100644 index 000000000..e93b2ce7d --- /dev/null +++ b/pkg/dataset/dataset.go @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dataset + +import ( + "github.com/arana-db/arana/pkg/proto" +) + +// PeekableDataset represents a peekable dataset. +type PeekableDataset interface { + proto.Dataset + // Peek peeks the next row, but will not consume it. + Peek() (proto.Row, error) +} + +type RandomAccessDataset interface { + PeekableDataset + // Len returns the length of sub-datasets. + Len() int + // PeekN peeks the next row with specified index. + PeekN(index int) (proto.Row, error) + // SetNextN force sets the next index of row. + SetNextN(index int) error +} diff --git a/pkg/dataset/filter.go b/pkg/dataset/filter.go new file mode 100644 index 000000000..4393209ba --- /dev/null +++ b/pkg/dataset/filter.go @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dataset + +import ( + "github.com/arana-db/arana/pkg/proto" +) + +var _ proto.Dataset = (*FilterDataset)(nil) + +type PredicateFunc func(proto.Row) bool + +type FilterDataset struct { + proto.Dataset + Predicate PredicateFunc +} + +func (f FilterDataset) Next() (proto.Row, error) { + if f.Predicate == nil { + return f.Dataset.Next() + } + + row, err := f.Dataset.Next() + if err != nil { + return nil, err + } + + if !f.Predicate(row) { + return f.Next() + } + + return row, nil +} diff --git a/pkg/dataset/filter_test.go b/pkg/dataset/filter_test.go new file mode 100644 index 000000000..6fac08c37 --- /dev/null +++ b/pkg/dataset/filter_test.go @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dataset + +import ( + "database/sql" + "fmt" + "io" + "testing" +) + +import ( + "github.com/stretchr/testify/assert" +) + +import ( + consts "github.com/arana-db/arana/pkg/constants/mysql" + "github.com/arana-db/arana/pkg/mysql" + "github.com/arana-db/arana/pkg/mysql/rows" + "github.com/arana-db/arana/pkg/proto" +) + +func TestFilter(t *testing.T) { + fields := []proto.Field{ + mysql.NewField("id", consts.FieldTypeLong), + mysql.NewField("name", consts.FieldTypeVarChar), + mysql.NewField("gender", consts.FieldTypeLong), + } + root := &VirtualDataset{ + Columns: fields, + } + + for i := 0; i < 10; i++ { + root.Rows = append(root.Rows, rows.NewTextVirtualRow(fields, []proto.Value{ + int64(i), + fmt.Sprintf("fake-name-%d", i), + int64(i & 1), // 0=female,1=male + })) + } + + filtered := Pipe(root, Filter(func(row proto.Row) bool { + dest := make([]proto.Value, len(fields)) + _ = row.Scan(dest) + var gender sql.NullInt64 + _ = gender.Scan(dest[2]) + assert.True(t, gender.Valid) + return gender.Int64 == 1 + })) + + for { + next, err := filtered.Next() + if err == io.EOF { + break + } + assert.NoError(t, err) + + dest := make([]proto.Value, len(fields)) + _ = next.Scan(dest) + assert.Equal(t, "1", fmt.Sprint(dest[2])) + + t.Logf("id=%v, name=%v, gender=%v\n", dest[0], dest[1], dest[2]) + } +} diff --git a/pkg/dataset/fuse.go b/pkg/dataset/fuse.go new file mode 100644 index 000000000..820f92d54 --- /dev/null +++ b/pkg/dataset/fuse.go @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dataset + +import ( + "io" +) + +import ( + "github.com/pkg/errors" +) + +import ( + "github.com/arana-db/arana/pkg/proto" +) + +var _ proto.Dataset = (*FuseableDataset)(nil) + +type GenerateFunc func() (proto.Dataset, error) + +type FuseableDataset struct { + fields []proto.Field + current proto.Dataset + generators []GenerateFunc +} + +func (fu *FuseableDataset) Close() error { + if fu.current == nil { + return nil + } + if err := fu.current.Close(); err != nil { + return errors.WithStack(err) + } + return nil +} + +func (fu *FuseableDataset) Fields() ([]proto.Field, error) { + return fu.fields, nil +} + +func (fu *FuseableDataset) Next() (proto.Row, error) { + if fu.current == nil { + return nil, io.EOF + } + + var ( + next proto.Row + err error + ) + + if next, err = fu.current.Next(); errors.Is(err, io.EOF) { + if err = fu.nextDataset(); err != nil { + return nil, err + } + return fu.Next() + } + + if err != nil { + return nil, err + } + + return next, nil +} + +func (fu *FuseableDataset) nextDataset() error { + var err error + if err = fu.current.Close(); err != nil { + return errors.Wrap(err, "failed to close previous fused dataset") + } + fu.current = nil + + if len(fu.generators) < 1 { + return io.EOF + } + + gen := fu.generators[0] + fu.generators[0] = nil + fu.generators = fu.generators[1:] + + if fu.current, err = gen(); err != nil { + return errors.Wrap(err, "failed to close previous fused dataset") + } + + return nil +} + +func (fu *FuseableDataset) ToParallel() RandomAccessDataset { + generators := make([]GenerateFunc, len(fu.generators)+1) + copy(generators[1:], fu.generators) + streams := make([]*peekableDataset, len(fu.generators)+1) + streams[0] = &peekableDataset{Dataset: fu.current} + result := ¶llelDataset{ + fields: fu.fields, + generators: generators, + streams: streams, + } + return result +} + +func Fuse(first GenerateFunc, others ...GenerateFunc) (proto.Dataset, error) { + current, err := first() + if err != nil { + return nil, errors.Wrap(err, "failed to fuse datasets") + } + + fields, err := current.Fields() + if err != nil { + defer func() { + _ = current.Close() + }() + return nil, errors.WithStack(err) + } + + return &FuseableDataset{ + fields: fields, + current: current, + generators: others, + }, nil +} diff --git a/pkg/dataset/fuse_test.go b/pkg/dataset/fuse_test.go new file mode 100644 index 000000000..91cf862f1 --- /dev/null +++ b/pkg/dataset/fuse_test.go @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dataset + +import ( + "fmt" + "io" + "testing" +) + +import ( + "github.com/stretchr/testify/assert" +) + +import ( + consts "github.com/arana-db/arana/pkg/constants/mysql" + "github.com/arana-db/arana/pkg/mysql" + "github.com/arana-db/arana/pkg/mysql/rows" + "github.com/arana-db/arana/pkg/proto" +) + +func TestFuse(t *testing.T) { + fields := []proto.Field{ + mysql.NewField("id", consts.FieldTypeLong), + mysql.NewField("name", consts.FieldTypeVarChar), + } + + generate := func(offset, length int) proto.Dataset { + d := &VirtualDataset{ + Columns: fields, + } + + for i := offset; i < offset+length; i++ { + d.Rows = append(d.Rows, rows.NewTextVirtualRow(fields, []proto.Value{ + int64(i), + fmt.Sprintf("fake-name-%d", i), + })) + } + + return d + } + + fuse, err := Fuse( + func() (proto.Dataset, error) { + return generate(0, 3), nil + }, + func() (proto.Dataset, error) { + return generate(3, 4), nil + }, + func() (proto.Dataset, error) { + return generate(7, 3), nil + }, + ) + + assert.NoError(t, err) + + var seq int + for { + next, err := fuse.Next() + if err == io.EOF { + break + } + assert.NoError(t, err) + values := make([]proto.Value, len(fields)) + _ = next.Scan(values) + assert.Equal(t, fmt.Sprint(seq), fmt.Sprint(values[0])) + + t.Logf("next: id=%v, name=%v\n", values[0], values[1]) + + seq++ + } +} diff --git a/pkg/dataset/group_reduce.go b/pkg/dataset/group_reduce.go new file mode 100644 index 000000000..cc1a86962 --- /dev/null +++ b/pkg/dataset/group_reduce.go @@ -0,0 +1,332 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dataset + +import ( + "fmt" + "io" + "reflect" + "strings" + "sync" +) + +import ( + "github.com/pkg/errors" +) + +import ( + "github.com/arana-db/arana/pkg/merge" + "github.com/arana-db/arana/pkg/mysql/rows" + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/util/log" +) + +var _ proto.Dataset = (*GroupDataset)(nil) + +// Reducer represents the way to reduce rows. +type Reducer interface { + // Reduce reduces next row. + Reduce(next proto.Row) error + // Row returns the result row. + Row() proto.Row +} + +type AggregateItem struct { + agg merge.Aggregator + idx int +} + +type AggregateReducer struct { + AggItems map[int]merge.Aggregator + currentRow proto.Row + Fields []proto.Field +} + +func NewGroupReducer(aggFuncMap map[int]func() merge.Aggregator, fields []proto.Field) *AggregateReducer { + aggItems := make(map[int]merge.Aggregator) + for idx, f := range aggFuncMap { + aggItems[idx] = f() + } + return &AggregateReducer{ + AggItems: aggItems, + currentRow: nil, + Fields: fields, + } +} + +func (gr *AggregateReducer) Reduce(next proto.Row) error { + var ( + values = make([]proto.Value, len(gr.Fields)) + result = make([]proto.Value, len(gr.Fields)) + ) + err := next.Scan(values) + if err != nil { + return err + } + + for idx, aggregator := range gr.AggItems { + aggregator.Aggregate([]interface{}{values[idx]}) + } + + for i := 0; i < len(values); i++ { + if gr.AggItems[i] == nil { + result[i] = values[i] + } else { + aggResult, ok := gr.AggItems[i].GetResult() + if !ok { + return errors.New("can not aggregate value") + } + result[i] = aggResult + } + } + + if next.IsBinary() { + gr.currentRow = rows.NewBinaryVirtualRow(gr.Fields, result) + } else { + gr.currentRow = rows.NewTextVirtualRow(gr.Fields, result) + } + return nil +} + +func (gr *AggregateReducer) Row() proto.Row { + return gr.currentRow +} + +type GroupDataset struct { + // Should be an orderedDataset + proto.Dataset + keys []OrderByItem + + fieldFunc FieldsFunc + actualFieldsOnce sync.Once + actualFields []proto.Field + actualFieldsFailure error + + keyIndexes []int + keyIndexesOnce sync.Once + keyIndexesFailure error + + reducer func() Reducer + + buf proto.Row + eof bool +} + +func (gd *GroupDataset) Close() error { + return gd.Dataset.Close() +} + +func (gd *GroupDataset) Fields() ([]proto.Field, error) { + gd.actualFieldsOnce.Do(func() { + if gd.fieldFunc == nil { + gd.actualFields, gd.actualFieldsFailure = gd.Dataset.Fields() + return + } + + defer func() { + gd.fieldFunc = nil + }() + + fields, err := gd.Dataset.Fields() + if err != nil { + gd.actualFieldsFailure = err + return + } + gd.actualFields = gd.fieldFunc(fields) + }) + + return gd.actualFields, gd.actualFieldsFailure +} + +func (gd *GroupDataset) Next() (proto.Row, error) { + if gd.eof { + return nil, io.EOF + } + + indexes, err := gd.getKeyIndexes() + if err != nil { + return nil, errors.WithStack(err) + } + + var ( + rowsChan = make(chan proto.Row, 1) + errChan = make(chan error, 1) + ) + + go func() { + defer close(rowsChan) + gd.consumeUntilDifferent(indexes, rowsChan, errChan) + }() + + reducer := gd.reducer() + +L: + for { + select { + case next, ok := <-rowsChan: + if !ok { + break L + } + if err = reducer.Reduce(next); err != nil { + break L + } + case err = <-errChan: + break L + } + } + + if err != nil { + return nil, err + } + + return reducer.Row(), nil +} + +func (gd *GroupDataset) consumeUntilDifferent(indexes []int, rowsChan chan<- proto.Row, errChan chan<- error) { + var ( + next proto.Row + err error + ) + + for { + next, err = gd.Dataset.Next() + if errors.Is(err, io.EOF) { + gd.eof = true + + if buf, ok := gd.popBuf(); ok { + rowsChan <- buf + } + break + } + + if err != nil { + errChan <- err + break + } + + prev, ok := gd.popBuf() + gd.buf = next + if !ok { + log.Debugf("begin next group: %s", gd.toDebugStr(next)) + continue + } + + ok, err = gd.isSameGroup(indexes, prev, next) + + if err != nil { + errChan <- err + break + } + + rowsChan <- prev + + if !ok { + log.Debugf("begin next group: %s", gd.toDebugStr(next)) + break + } + } +} + +func (gd *GroupDataset) toDebugStr(next proto.Row) string { + var ( + display []string + fields, _ = gd.Dataset.Fields() + indexes, _ = gd.getKeyIndexes() + dest = make([]proto.Value, len(fields)) + ) + + _ = next.Scan(dest) + for _, it := range indexes { + display = append(display, fmt.Sprintf("%s:%v", fields[it].Name(), dest[it])) + } + + return fmt.Sprintf("[%s]", strings.Join(display, ",")) +} + +func (gd *GroupDataset) isSameGroup(indexes []int, prev, next proto.Row) (bool, error) { + var ( + fields, _ = gd.Dataset.Fields() + err error + ) + + // TODO: reduce scan times, maybe cache it. + var ( + dest0 = make([]proto.Value, len(fields)) + dest1 = make([]proto.Value, len(fields)) + ) + if err = prev.Scan(dest0); err != nil { + return false, errors.WithStack(err) + } + if err = next.Scan(dest1); err != nil { + return false, errors.WithStack(err) + } + + equal := true + for _, index := range indexes { + // TODO: how to compare equality more effectively? + if !reflect.DeepEqual(dest0[index], dest1[index]) { + equal = false + break + } + } + + return equal, nil +} + +func (gd *GroupDataset) popBuf() (ret proto.Row, ok bool) { + if gd.buf != nil { + ret, gd.buf, ok = gd.buf, nil, true + } + return +} + +// getKeyIndexes computes and holds the indexes of group keys. +func (gd *GroupDataset) getKeyIndexes() ([]int, error) { + gd.keyIndexesOnce.Do(func() { + var ( + fields []proto.Field + err error + ) + + if fields, err = gd.Dataset.Fields(); err != nil { + gd.keyIndexesFailure = err + return + } + gd.keyIndexes = make([]int, 0, len(gd.keys)) + for _, key := range gd.keys { + idx := -1 + for i := 0; i < len(fields); i++ { + if fields[i].Name() == key.Column { + idx = i + break + } + } + if idx == -1 { + gd.keyIndexesFailure = fmt.Errorf("cannot find group field '%+v'", key) + return + } + gd.keyIndexes = append(gd.keyIndexes, idx) + } + }) + + if gd.keyIndexesFailure != nil { + return nil, gd.keyIndexesFailure + } + + return gd.keyIndexes, nil +} diff --git a/pkg/dataset/group_reduce_test.go b/pkg/dataset/group_reduce_test.go new file mode 100644 index 000000000..4ae082f9c --- /dev/null +++ b/pkg/dataset/group_reduce_test.go @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dataset + +import ( + "database/sql" + "fmt" + "sort" + "testing" +) + +import ( + consts "github.com/arana-db/arana/pkg/constants/mysql" + "github.com/arana-db/arana/pkg/mysql" + vrows "github.com/arana-db/arana/pkg/mysql/rows" + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/util/rand2" +) + +type sortBy struct { + rows [][]proto.Value + less func(a, b []proto.Value) bool +} + +func (s sortBy) Len() int { + return len(s.rows) +} + +func (s sortBy) Less(i, j int) bool { + return s.less(s.rows[i], s.rows[j]) +} + +func (s sortBy) Swap(i, j int) { + s.rows[i], s.rows[j] = s.rows[j], s.rows[i] +} + +type fakeReducer struct { + fields []proto.Field + gender sql.NullInt64 + cnt int64 +} + +func (fa *fakeReducer) Reduce(next proto.Row) error { + if !fa.gender.Valid { + gender, _ := next.(proto.KeyedRow).Get("gender") + fa.gender.Int64, fa.gender.Valid = gender.(int64), true + } + fa.cnt++ + return nil +} + +func (fa *fakeReducer) Row() proto.Row { + return vrows.NewTextVirtualRow(fa.fields, []proto.Value{ + fa.gender, + fa.cnt, + }) +} + +func TestGroupReduce(t *testing.T) { + fields := []proto.Field{ + mysql.NewField("id", consts.FieldTypeLong), + mysql.NewField("name", consts.FieldTypeVarChar), + mysql.NewField("gender", consts.FieldTypeLong), + } + + var origin VirtualDataset + origin.Columns = fields + + var rows [][]proto.Value + for i := 0; i < 1000; i++ { + rows = append(rows, []proto.Value{ + int64(i), + fmt.Sprintf("Fake %d", i), + rand2.Int63n(2), + }) + } + + s := sortBy{ + rows: rows, + less: func(a, b []proto.Value) bool { + return a[2].(int64) < b[2].(int64) + }, + } + sort.Sort(s) + + for _, it := range rows { + origin.Rows = append(origin.Rows, vrows.NewTextVirtualRow(fields, it)) + } + + actualFields := []proto.Field{ + fields[2], + mysql.NewField("amount", consts.FieldTypeLong), + } + + // Simulate: SELECT gender,COUNT(*) AS amount FROM xxx WHERE ... GROUP BY gender + groups := []OrderByItem{{"gender", true}} + p := Pipe(&origin, + GroupReduce( + groups, + func(fields []proto.Field) []proto.Field { + return actualFields + }, + func() Reducer { + return &fakeReducer{ + fields: actualFields, + } + }, + ), + ) + + for { + next, err := p.Next() + if err != nil { + break + } + v := make([]proto.Value, len(actualFields)) + _ = next.Scan(v) + t.Logf("next: gender=%v, amount=%v\n", v[0], v[1]) + } +} diff --git a/pkg/dataset/ordered.go b/pkg/dataset/ordered.go new file mode 100644 index 000000000..262155808 --- /dev/null +++ b/pkg/dataset/ordered.go @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dataset + +import ( + "container/heap" + "io" +) + +import ( + "github.com/arana-db/arana/pkg/proto" +) + +type orderedDataset struct { + dataset RandomAccessDataset + queue *PriorityQueue + firstRow bool +} + +func NewOrderedDataset(dataset RandomAccessDataset, items []OrderByItem) proto.Dataset { + return &orderedDataset{ + dataset: dataset, + queue: NewPriorityQueue(make([]*RowItem, 0), items), + firstRow: true, + } +} + +func (or *orderedDataset) Close() error { + return or.dataset.Close() +} + +func (or *orderedDataset) Fields() ([]proto.Field, error) { + return or.dataset.Fields() +} + +func (or *orderedDataset) Next() (proto.Row, error) { + if or.firstRow { + n := or.dataset.Len() + for i := 0; i < n; i++ { + or.dataset.SetNextN(i) + row, err := or.dataset.Next() + if err == io.EOF { + continue + } else if err != nil { + return nil, err + } + or.queue.Push(&RowItem{ + row: row.(proto.KeyedRow), + streamIdx: i, + }) + } + or.firstRow = false + } + if or.queue.Len() == 0 { + return nil, io.EOF + } + data := heap.Pop(or.queue) + + item := data.(*RowItem) + or.dataset.SetNextN(item.streamIdx) + nextRow, err := or.dataset.Next() + if err == io.EOF { + return item.row, nil + } else if err != nil { + return nil, err + } + heap.Push(or.queue, &RowItem{ + row: nextRow.(proto.KeyedRow), + streamIdx: item.streamIdx, + }) + + return item.row, nil +} diff --git a/pkg/dataset/ordered_test.go b/pkg/dataset/ordered_test.go new file mode 100644 index 000000000..bd25ee0fe --- /dev/null +++ b/pkg/dataset/ordered_test.go @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dataset + +import ( + "testing" +) + +import ( + "github.com/golang/mock/gomock" + + "github.com/stretchr/testify/assert" +) + +func TestOrderedDataset(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + pd := generateFakeParallelDataset(ctrl, 0, 2, 2, 2, 1, 1) + items := []OrderByItem{ + { + Column: "id", + Desc: false, + }, + { + Column: "gender", + Desc: true, + }, + } + + od := NewOrderedDataset(pd, items) + var pojo fakePojo + + row, err := od.Next() + assert.NoError(t, err) + assert.NoError(t, scanPojo(row, &pojo)) + t.Logf("next: %#v\n", pojo) + assert.Equal(t, int64(0), pojo.ID) + + row, err = od.Next() + assert.NoError(t, err) + assert.NoError(t, scanPojo(row, &pojo)) + t.Logf("next: %#v\n", pojo) + assert.Equal(t, int64(1), pojo.ID) + + row, err = od.Next() + assert.NoError(t, err) + assert.NoError(t, scanPojo(row, &pojo)) + t.Logf("next: %#v\n", pojo) + assert.Equal(t, int64(1), pojo.ID) + + row, err = od.Next() + assert.NoError(t, err) + assert.NoError(t, scanPojo(row, &pojo)) + t.Logf("next: %#v\n", pojo) + assert.Equal(t, int64(2), pojo.ID) + + row, err = od.Next() + assert.NoError(t, err) + assert.NoError(t, scanPojo(row, &pojo)) + t.Logf("next: %#v\n", pojo) + assert.Equal(t, int64(3), pojo.ID) +} diff --git a/pkg/dataset/parallel.go b/pkg/dataset/parallel.go new file mode 100644 index 000000000..c3202c191 --- /dev/null +++ b/pkg/dataset/parallel.go @@ -0,0 +1,237 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dataset + +import ( + "io" + "sync" +) + +import ( + "github.com/pkg/errors" + + uatomic "go.uber.org/atomic" + + "golang.org/x/sync/errgroup" +) + +import ( + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/util/log" +) + +var ( + _ PeekableDataset = (*peekableDataset)(nil) + _ RandomAccessDataset = (*parallelDataset)(nil) +) + +type peekableDataset struct { + proto.Dataset + mu sync.Mutex + next proto.Row + err error +} + +func (pe *peekableDataset) Next() (proto.Row, error) { + var ( + next proto.Row + err error + ) + + pe.mu.Lock() + defer pe.mu.Unlock() + + if next, err, pe.next, pe.err = pe.next, pe.err, nil, nil; next != nil || err != nil { + return next, err + } + + return pe.Dataset.Next() +} + +func (pe *peekableDataset) Peek() (proto.Row, error) { + var ( + next proto.Row + err error + ) + + pe.mu.Lock() + defer pe.mu.Unlock() + + if next, err = pe.next, pe.err; next != nil || err != nil { + return next, err + } + + pe.next, pe.err = pe.Dataset.Next() + + return pe.next, pe.err +} + +type parallelDataset struct { + mu sync.RWMutex + fields []proto.Field + generators []GenerateFunc + streams []*peekableDataset + seq uatomic.Uint32 +} + +func (pa *parallelDataset) getStream(i int) (*peekableDataset, error) { + pa.mu.RLock() + if stream := pa.streams[i]; stream != nil { + pa.mu.RUnlock() + return stream, nil + } + + pa.mu.RUnlock() + + pa.mu.Lock() + defer pa.mu.Unlock() + + if stream := pa.streams[i]; stream != nil { + return stream, nil + } + + d, err := pa.generators[i]() + if err != nil { + return nil, errors.WithStack(err) + } + pa.streams[i] = &peekableDataset{Dataset: d} + pa.generators[i] = nil + return pa.streams[i], nil +} + +func (pa *parallelDataset) Peek() (proto.Row, error) { + i := pa.seq.Load() % uint32(pa.Len()) + s, err := pa.getStream(int(i)) + if err != nil { + return nil, errors.WithStack(err) + } + return s.Peek() +} + +func (pa *parallelDataset) PeekN(index int) (proto.Row, error) { + if index < 0 || index >= pa.Len() { + return nil, errors.Errorf("index out of range: index=%d, length=%d", index, pa.Len()) + } + s, err := pa.getStream(index) + if err != nil { + return nil, errors.WithStack(err) + } + return s.Peek() +} + +func (pa *parallelDataset) Close() error { + var g errgroup.Group + for i := 0; i < len(pa.streams); i++ { + i := i + g.Go(func() (err error) { + if pa.streams[i] != nil { + err = pa.streams[i].Close() + } + + if err != nil { + log.Errorf("failed to close dataset#%d: %v", i, err) + } + return + }) + } + return g.Wait() +} + +func (pa *parallelDataset) Fields() ([]proto.Field, error) { + return pa.fields, nil +} + +func (pa *parallelDataset) Next() (proto.Row, error) { + var ( + s *peekableDataset + next proto.Row + err error + ) + for j := 0; j < pa.Len(); j++ { + i := (pa.seq.Inc() - 1) % uint32(pa.Len()) + + if s, err = pa.getStream(int(i)); err != nil { + return nil, errors.WithStack(err) + } + + next, err = s.Next() + if err == nil { + break + } + if errors.Is(err, io.EOF) { + err = io.EOF + continue + } + + return nil, errors.WithStack(err) + } + + return next, err +} + +func (pa *parallelDataset) Len() int { + return len(pa.streams) +} + +func (pa *parallelDataset) SetNextN(index int) error { + if index < 0 || index >= pa.Len() { + return errors.Errorf("index out of range: index=%d, length=%d", index, pa.Len()) + } + pa.seq.Store(uint32(index)) + return nil +} + +// Parallel creates a thread-safe dataset, which can be random-accessed in parallel. +func Parallel(first GenerateFunc, others ...GenerateFunc) (RandomAccessDataset, error) { + current, err := first() + if err != nil { + return nil, errors.Wrap(err, "failed to create parallel datasets") + } + + fields, err := current.Fields() + if err != nil { + defer func() { + _ = current.Close() + }() + return nil, errors.WithStack(err) + } + + generators := make([]GenerateFunc, len(others)+1) + for i := 0; i < len(others); i++ { + if others[i] == nil { + return nil, errors.Errorf("nil dataset detected, index is %d", i+1) + } + generators[i+1] = others[i] + } + + streams := make([]*peekableDataset, len(others)+1) + streams[0] = &peekableDataset{Dataset: current} + + return ¶llelDataset{ + fields: fields, + generators: generators, + streams: streams, + }, nil +} + +// Peekable converts a dataset to a peekable one. +func Peekable(origin proto.Dataset) PeekableDataset { + return &peekableDataset{ + Dataset: origin, + } +} diff --git a/pkg/dataset/parallel_test.go b/pkg/dataset/parallel_test.go new file mode 100644 index 000000000..a8bc50f21 --- /dev/null +++ b/pkg/dataset/parallel_test.go @@ -0,0 +1,251 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dataset + +import ( + "container/list" + "database/sql" + "fmt" + "io" + "math" + "sort" + "sync" + "testing" +) + +import ( + "github.com/golang/mock/gomock" + + "github.com/stretchr/testify/assert" +) + +import ( + consts "github.com/arana-db/arana/pkg/constants/mysql" + "github.com/arana-db/arana/pkg/mysql" + "github.com/arana-db/arana/pkg/mysql/rows" + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/util/rand2" + "github.com/arana-db/arana/testdata" +) + +type fakePojo struct { + ID int64 + Name string + Gender int64 +} + +func scanPojo(row proto.Row, dest *fakePojo) error { + s := make([]proto.Value, 3) + if err := row.Scan(s); err != nil { + return err + } + + var ( + id sql.NullInt64 + name sql.NullString + gender sql.NullInt64 + ) + _, _, _ = id.Scan(s[0]), name.Scan(s[1]), gender.Scan(s[2]) + + dest.ID = id.Int64 + dest.Name = name.String + dest.Gender = gender.Int64 + + return nil +} + +func generateFakeParallelDataset(ctrl *gomock.Controller, pairs ...int) RandomAccessDataset { + fields := []proto.Field{ + mysql.NewField("id", consts.FieldTypeLong), + mysql.NewField("name", consts.FieldTypeVarChar), + mysql.NewField("gender", consts.FieldTypeLong), + } + + getFakeDataset := func(offset, n int) proto.Dataset { + l := list.New() + for i := 0; i < n; i++ { + r := rows.NewTextVirtualRow(fields, []proto.Value{ + int64(offset + i), + fmt.Sprintf("Fake %d", offset+i), + rand2.Int63n(2), + }) + l.PushBack(r) + } + + ds := testdata.NewMockDataset(ctrl) + + ds.EXPECT().Close().Return(nil).AnyTimes() + ds.EXPECT().Fields().Return(fields, nil).AnyTimes() + ds.EXPECT().Next(). + DoAndReturn(func() (proto.Row, error) { + head := l.Front() + if head == nil { + return nil, io.EOF + } + l.Remove(head) + return head.Value.(proto.Row), nil + }). + AnyTimes() + + return ds + } + + var gens []GenerateFunc + + for i := 1; i < len(pairs); i += 2 { + offset := pairs[i-1] + n := pairs[i] + gens = append(gens, func() (proto.Dataset, error) { + return getFakeDataset(offset, n), nil + }) + } + + ret, _ := Parallel(gens[0], gens[1:]...) + return ret +} + +func TestParallelDataset(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + pd := generateFakeParallelDataset(ctrl, 0, 3, 3, 3, 6, 3) + + var pojo fakePojo + + row, err := pd.PeekN(0) + assert.NoError(t, err) + assert.NoError(t, scanPojo(row, &pojo)) + assert.Equal(t, int64(0), pojo.ID) + + row, err = pd.PeekN(1) + assert.NoError(t, err) + assert.NoError(t, scanPojo(row, &pojo)) + assert.Equal(t, int64(3), pojo.ID) + + row, err = pd.PeekN(2) + assert.NoError(t, err) + assert.NoError(t, scanPojo(row, &pojo)) + assert.Equal(t, int64(6), pojo.ID) + + err = pd.SetNextN(1) + assert.NoError(t, err) + + row, err = pd.Peek() + assert.NoError(t, err) + assert.NoError(t, scanPojo(row, &pojo)) + assert.Equal(t, int64(3), pojo.ID) + + row, err = pd.Next() + assert.NoError(t, err) + assert.NoError(t, scanPojo(row, &pojo)) + assert.Equal(t, int64(3), pojo.ID) + + var cnt int + + _ = scanPojo(row, &pojo) + t.Logf("first: %#v\n", pojo) + cnt++ + + for { + row, err = pd.Next() + if err != nil { + break + } + + _ = scanPojo(row, &pojo) + t.Logf("next: %#v\n", pojo) + cnt++ + } + + assert.Equal(t, 9, cnt) +} + +func TestParallelDataset_SortBy(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + pairs := []int{ + 5, 8, // offset, size + 0, 3, + 2, 4, + 6, 2, + } + + var ( + d = generateFakeParallelDataset(ctrl, pairs...) + + ids sort.IntSlice + ) + + for { + bingo := -1 + + var ( + wg sync.WaitGroup + lock sync.Mutex + min int64 = math.MaxInt64 + ) + + wg.Add(d.Len()) + + // search the next minimum value in parallel + for i := 0; i < d.Len(); i++ { + i := i + go func() { + defer wg.Done() + + r, err := d.PeekN(i) + if err == io.EOF { + return + } + + var pojo fakePojo + err = scanPojo(r, &pojo) + assert.NoError(t, err) + + lock.Lock() + defer lock.Unlock() + if pojo.ID < min { + min = pojo.ID + bingo = i + } + }() + } + + wg.Wait() + + if bingo == -1 { + break + } + + err := d.SetNextN(bingo) + assert.NoError(t, err) + + var pojo fakePojo + + r, err := d.Next() + assert.NoError(t, err) + err = scanPojo(r, &pojo) + assert.NoError(t, err) + + t.Logf("next: %#v\n", pojo) + ids = append(ids, int(pojo.ID)) + } + + assert.True(t, sort.IsSorted(ids), "values should be sorted") +} diff --git a/pkg/dataset/priority_queue.go b/pkg/dataset/priority_queue.go new file mode 100644 index 000000000..d3483e1c2 --- /dev/null +++ b/pkg/dataset/priority_queue.go @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dataset + +import ( + "container/heap" + "fmt" + "strconv" + "time" +) + +import ( + "golang.org/x/exp/constraints" +) + +import ( + "github.com/arana-db/arana/pkg/proto" +) + +type OrderByValue struct { + OrderValues map[string]interface{} +} + +type RowItem struct { + row proto.KeyedRow + streamIdx int +} + +type OrderByItem struct { + Column string + Desc bool +} + +type PriorityQueue struct { + rows []*RowItem + orderByItems []OrderByItem +} + +func NewPriorityQueue(rows []*RowItem, orderByItems []OrderByItem) *PriorityQueue { + pq := &PriorityQueue{ + rows: rows, + orderByItems: orderByItems, + } + heap.Init(pq) + return pq +} + +func (pq *PriorityQueue) Len() int { + return len(pq.rows) +} + +func (pq *PriorityQueue) Less(i, j int) bool { + orderValues1 := &OrderByValue{ + OrderValues: make(map[string]interface{}), + } + orderValues2 := &OrderByValue{ + OrderValues: make(map[string]interface{}), + } + if i >= len(pq.rows) || j >= len(pq.rows) { + return false + } + row1 := pq.rows[i] + row2 := pq.rows[j] + for _, item := range pq.orderByItems { + val1, _ := row1.row.Get(item.Column) + val2, _ := row2.row.Get(item.Column) + orderValues1.OrderValues[item.Column] = val1 + orderValues2.OrderValues[item.Column] = val2 + } + return compare(orderValues1, orderValues2, pq.orderByItems) < 0 +} + +func (pq *PriorityQueue) Swap(i, j int) { + pq.rows[i], pq.rows[j] = pq.rows[j], pq.rows[i] +} + +func (pq *PriorityQueue) Push(x interface{}) { + item := x.(*RowItem) + pq.rows = append(pq.rows, item) + pq.update() +} + +func (pq *PriorityQueue) Pop() interface{} { + old := *pq + n := len(old.rows) + if n == 0 { + return nil + } + item := old.rows[n-1] + pq.rows = old.rows[0 : n-1] + return item +} + +func (pq *PriorityQueue) update() { + heap.Fix(pq, pq.Len()-1) +} + +func compare(a *OrderByValue, b *OrderByValue, orderByItems []OrderByItem) int { + for _, item := range orderByItems { + compare := compareTo(a.OrderValues[item.Column], b.OrderValues[item.Column], item.Desc) + if compare == 0 { + continue + } + return compare + } + return 0 +} + +func compareTo(a, b interface{}, desc bool) int { + if a == nil && b == nil { + return 0 + } + if a == nil { + return -1 + } + if b == nil { + return 1 + } + // TODO Deal with case sensitive. + var ( + result = 0 + ) + switch a.(type) { + case string: + result = compareValue(fmt.Sprintf("%v", a), fmt.Sprintf("%v", b)) + case int8, int16, int32, int64: + a, _ := strconv.ParseInt(fmt.Sprintf("%v", a), 10, 64) + b, _ := strconv.ParseInt(fmt.Sprintf("%v", b), 10, 64) + result = compareValue(a, b) + case uint8, uint16, uint32, uint64: + a, _ := strconv.ParseUint(fmt.Sprintf("%v", a), 10, 64) + b, _ := strconv.ParseUint(fmt.Sprintf("%v", b), 10, 64) + result = compareValue(a, b) + case float32, float64: + a, _ := strconv.ParseFloat(fmt.Sprintf("%v", a), 64) + b, _ := strconv.ParseFloat(fmt.Sprintf("%v", b), 64) + result = compareValue(a, b) + case time.Time: + result = compareTime(a.(time.Time), b.(time.Time)) + } + if desc { + return -1 * result + } + return result +} + +func compareValue[T constraints.Ordered](a, b T) int { + switch { + case a > b: + return 1 + case a < b: + return -1 + default: + return 0 + } +} + +func compareTime(a, b time.Time) int { + switch { + case a.After(b): + return 1 + case a.Before(b): + return -1 + default: + return 0 + } +} diff --git a/pkg/dataset/priority_queue_test.go b/pkg/dataset/priority_queue_test.go new file mode 100644 index 000000000..e41d3b77c --- /dev/null +++ b/pkg/dataset/priority_queue_test.go @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dataset + +import ( + "container/heap" + "database/sql" + "testing" +) + +import ( + "github.com/stretchr/testify/assert" +) + +import ( + consts "github.com/arana-db/arana/pkg/constants/mysql" + "github.com/arana-db/arana/pkg/mysql" + "github.com/arana-db/arana/pkg/mysql/rows" + "github.com/arana-db/arana/pkg/proto" +) + +func TestPriorityQueue(t *testing.T) { + fields := []proto.Field{ + mysql.NewField("id", consts.FieldTypeLong), + mysql.NewField("score", consts.FieldTypeLong), + } + items := []OrderByItem{ + {"id", false}, + {"score", true}, + } + + r1 := &RowItem{rows.NewTextVirtualRow(fields, []proto.Value{ + int64(1), + int64(80), + }), 1} + r2 := &RowItem{rows.NewTextVirtualRow(fields, []proto.Value{ + int64(2), + int64(75), + }), 1} + r3 := &RowItem{rows.NewTextVirtualRow(fields, []proto.Value{ + int64(1), + int64(90), + }), 1} + r4 := &RowItem{rows.NewTextVirtualRow(fields, []proto.Value{ + int64(3), + int64(85), + }), 1} + pq := NewPriorityQueue([]*RowItem{ + r1, r2, r3, r4, + }, items) + + assertScorePojoEquals(t, fakeScorePojo{ + id: int64(1), + score: int64(90), + }, heap.Pop(pq).(*RowItem).row) + + assertScorePojoEquals(t, fakeScorePojo{ + id: int64(1), + score: int64(80), + }, heap.Pop(pq).(*RowItem).row) + + assertScorePojoEquals(t, fakeScorePojo{ + id: int64(2), + score: int64(75), + }, heap.Pop(pq).(*RowItem).row) + + assertScorePojoEquals(t, fakeScorePojo{ + id: int64(3), + score: int64(85), + }, heap.Pop(pq).(*RowItem).row) +} + +func assertScorePojoEquals(t *testing.T, expected fakeScorePojo, actual proto.Row) { + var pojo fakeScorePojo + err := scanScorePojo(actual, &pojo) + assert.NoError(t, err) + assert.Equal(t, expected, pojo) +} + +type fakeScorePojo struct { + id int64 + score int64 +} + +func scanScorePojo(row proto.Row, dest *fakeScorePojo) error { + s := make([]proto.Value, 2) + if err := row.Scan(s); err != nil { + return err + } + + var ( + id sql.NullInt64 + score sql.NullInt64 + ) + _, _ = id.Scan(s[0]), score.Scan(s[1]) + + dest.id = id.Int64 + dest.score = score.Int64 + + return nil +} diff --git a/pkg/dataset/transform.go b/pkg/dataset/transform.go new file mode 100644 index 000000000..43d4a2276 --- /dev/null +++ b/pkg/dataset/transform.go @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dataset + +import ( + "sync" +) + +import ( + "github.com/pkg/errors" +) + +import ( + "github.com/arana-db/arana/pkg/proto" +) + +var _ proto.Dataset = (*TransformDataset)(nil) + +type ( + FieldsFunc func([]proto.Field) []proto.Field + TransformFunc func(proto.Row) (proto.Row, error) +) + +type TransformDataset struct { + proto.Dataset + FieldsGetter FieldsFunc + Transform TransformFunc + + actualFieldsOnce sync.Once + actualFields []proto.Field + actualFieldsFailure error +} + +func (td *TransformDataset) Fields() ([]proto.Field, error) { + td.actualFieldsOnce.Do(func() { + origin, err := td.Dataset.Fields() + if err != nil { + td.actualFieldsFailure = err + return + } + if td.FieldsGetter == nil { + td.actualFields = origin + return + } + + td.actualFields = td.FieldsGetter(origin) + }) + + return td.actualFields, td.actualFieldsFailure +} + +func (td *TransformDataset) Next() (proto.Row, error) { + if td.Transform == nil { + return td.Dataset.Next() + } + + var ( + row proto.Row + err error + ) + + if row, err = td.Dataset.Next(); err != nil { + return nil, err + } + + if row, err = td.Transform(row); err != nil { + return nil, errors.Wrap(err, "failed to transform dataset") + } + + return row, nil +} diff --git a/pkg/dataset/transform_test.go b/pkg/dataset/transform_test.go new file mode 100644 index 000000000..e5a714813 --- /dev/null +++ b/pkg/dataset/transform_test.go @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dataset + +import ( + "fmt" + "io" + "testing" +) + +import ( + "github.com/stretchr/testify/assert" +) + +import ( + consts "github.com/arana-db/arana/pkg/constants/mysql" + "github.com/arana-db/arana/pkg/mysql" + "github.com/arana-db/arana/pkg/mysql/rows" + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/util/rand2" +) + +func TestTransform(t *testing.T) { + fields := []proto.Field{ + mysql.NewField("id", consts.FieldTypeLong), + mysql.NewField("name", consts.FieldTypeVarChar), + mysql.NewField("level", consts.FieldTypeLong), + } + + root := &VirtualDataset{ + Columns: fields, + } + + for i := 0; i < 10; i++ { + root.Rows = append(root.Rows, rows.NewTextVirtualRow(fields, []proto.Value{ + int64(i), + fmt.Sprintf("fake-name-%d", i), + rand2.Int63n(10), + })) + } + + transformed := Pipe(root, Map(func(fields []proto.Field) []proto.Field { + return fields + }, func(row proto.Row) (proto.Row, error) { + dest := make([]proto.Value, len(fields)) + _ = row.Scan(dest) + dest[2] = int64(100) + return rows.NewBinaryVirtualRow(fields, dest), nil + })) + + for { + next, err := transformed.Next() + if err == io.EOF { + break + } + + assert.NoError(t, err) + + dest := make([]proto.Value, len(fields)) + _ = next.Scan(dest) + + assert.Equal(t, "100", fmt.Sprint(dest[2])) + } +} diff --git a/pkg/mysql/result_test.go b/pkg/dataset/virtual.go similarity index 59% rename from pkg/mysql/result_test.go rename to pkg/dataset/virtual.go index ff3a287da..de73474da 100644 --- a/pkg/mysql/result_test.go +++ b/pkg/dataset/virtual.go @@ -15,36 +15,40 @@ * limitations under the License. */ -package mysql +package dataset import ( - "testing" + "io" ) import ( - "github.com/stretchr/testify/assert" + "github.com/arana-db/arana/pkg/proto" ) -func TestLastInsertId(t *testing.T) { - result := createResult() - insertId, err := result.LastInsertId() - assert.Equal(t, uint64(2000), insertId) - assert.Nil(t, err) +var _ proto.Dataset = (*VirtualDataset)(nil) + +type VirtualDataset struct { + Columns []proto.Field + Rows []proto.Row +} + +func (cu *VirtualDataset) Close() error { + return nil } -func TestRowsAffected(t *testing.T) { - result := createResult() - affectedRows, err := result.RowsAffected() - assert.Equal(t, uint64(10), affectedRows) - assert.Nil(t, err) +func (cu *VirtualDataset) Fields() ([]proto.Field, error) { + return cu.Columns, nil } -func createResult() *Result { - result := &Result{ - Fields: nil, - AffectedRows: uint64(10), - InsertId: uint64(2000), - Rows: nil, +func (cu *VirtualDataset) Next() (proto.Row, error) { + if len(cu.Rows) < 1 { + return nil, io.EOF } - return result + + next := cu.Rows[0] + + cu.Rows[0] = nil + cu.Rows = cu.Rows[1:] + + return next, nil } diff --git a/pkg/executor/redirect.go b/pkg/executor/redirect.go index 74fc62d35..3758c9f66 100644 --- a/pkg/executor/redirect.go +++ b/pkg/executor/redirect.go @@ -29,14 +29,18 @@ import ( "github.com/arana-db/parser/ast" "github.com/pkg/errors" + + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/trace" ) import ( mConstants "github.com/arana-db/arana/pkg/constants/mysql" "github.com/arana-db/arana/pkg/metrics" - "github.com/arana-db/arana/pkg/mysql" mysqlErrors "github.com/arana-db/arana/pkg/mysql/errors" "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/proto/hint" + "github.com/arana-db/arana/pkg/resultx" "github.com/arana-db/arana/pkg/runtime" rcontext "github.com/arana-db/arana/pkg/runtime/context" "github.com/arana-db/arana/pkg/security" @@ -44,6 +48,8 @@ import ( ) var ( + Tracer = otel.Tracer("Executor") + errMissingTx = stdErrors.New("no transaction found") errNoDatabaseSelected = mysqlErrors.NewSQLError(mConstants.ERNoDb, mConstants.SSNoDatabaseSelected, "No database selected") ) @@ -121,6 +127,12 @@ func (executor *RedirectExecutor) ExecuteFieldList(ctx *proto.Context) ([]proto. return nil, errors.WithStack(err) } + if vt, ok := rt.Namespace().Rule().VTable(table); ok { + if _, atomTable, exist := vt.Topology().Render(0, 0); exist { + table = atomTable + } + } + db := rt.Namespace().DB0(ctx.Context) if db == nil { return nil, errors.New("cannot get physical backend connection") @@ -130,6 +142,10 @@ func (executor *RedirectExecutor) ExecuteFieldList(ctx *proto.Context) ([]proto. } func (executor *RedirectExecutor) ExecutorComQuery(ctx *proto.Context) (proto.Result, uint16, error) { + var span trace.Span + ctx.Context, span = Tracer.Start(ctx.Context, "ExecutorComQuery") + defer span.End() + var ( schemaless bool // true if schema is not specified err error @@ -138,10 +154,20 @@ func (executor *RedirectExecutor) ExecutorComQuery(ctx *proto.Context) (proto.Re p := parser.New() query := ctx.GetQuery() start := time.Now() - act, err := p.ParseOneStmt(query, "", "") + act, hts, err := p.ParseOneStmtHints(query, "", "") if err != nil { - return nil, 0, err + return nil, 0, errors.WithStack(err) } + + var hints []*hint.Hint + for _, next := range hts { + var h *hint.Hint + if h, err = hint.Parse(next); err != nil { + return nil, 0, err + } + hints = append(hints, h) + } + metrics.ParserDuration.Observe(time.Since(start).Seconds()) log.Debugf("ComQuery: %s", query) @@ -157,6 +183,7 @@ func (executor *RedirectExecutor) ExecutorComQuery(ctx *proto.Context) (proto.Re } ctx.Stmt = &proto.Stmt{ + Hints: hints, StmtNode: act, } @@ -181,7 +208,7 @@ func (executor *RedirectExecutor) ExecutorComQuery(ctx *proto.Context) (proto.Re var tx proto.Tx if tx, err = rt.Begin(ctx); err == nil { executor.putTx(ctx, tx) - res = &mysql.Result{} + res = resultx.New() } } case *ast.CommitStmt: @@ -231,14 +258,12 @@ func (executor *RedirectExecutor) ExecutorComQuery(ctx *proto.Context) (proto.Re } case *ast.ShowStmt: allowSchemaless := func(stmt *ast.ShowStmt) bool { - if stmt.Tp == ast.ShowDatabases { - return true - } - if stmt.Tp == ast.ShowVariables { + switch stmt.Tp { + case ast.ShowDatabases, ast.ShowVariables, ast.ShowTopology: return true + default: + return false } - - return false } if !schemaless || allowSchemaless(stmt) { // only SHOW DATABASES is allowed in schemaless mode @@ -246,24 +271,10 @@ func (executor *RedirectExecutor) ExecutorComQuery(ctx *proto.Context) (proto.Re } else { err = errNoDatabaseSelected } - case *ast.TruncateTableStmt: - if schemaless { - err = errNoDatabaseSelected - } else { - res, warn, err = rt.Execute(ctx) - } - case *ast.DropTableStmt: - if schemaless { - err = errNoDatabaseSelected - } else { - res, warn, err = rt.Execute(ctx) - } - case *ast.ExplainStmt: - if schemaless { - err = errNoDatabaseSelected - } else { - res, warn, err = rt.Execute(ctx) - } + case *ast.TruncateTableStmt, *ast.DropTableStmt, *ast.ExplainStmt, *ast.DropIndexStmt, *ast.CreateIndexStmt: + res, warn, err = executeStmt(ctx, schemaless, rt) + case *ast.DropTriggerStmt: + res, warn, err = rt.Execute(ctx) default: if schemaless { err = errNoDatabaseSelected @@ -276,7 +287,6 @@ func (executor *RedirectExecutor) ExecutorComQuery(ctx *proto.Context) (proto.Re res, warn, err = rt.Execute(ctx) } } - } executor.doPostFilter(ctx, res) @@ -284,6 +294,13 @@ func (executor *RedirectExecutor) ExecutorComQuery(ctx *proto.Context) (proto.Re return res, warn, err } +func executeStmt(ctx *proto.Context, schemaless bool, rt runtime.Runtime) (proto.Result, uint16, error) { + if schemaless { + return nil, 0, errNoDatabaseSelected + } + return rt.Execute(ctx) +} + func (executor *RedirectExecutor) ExecutorComStmtExecute(ctx *proto.Context) (proto.Result, uint16, error) { var ( executable proto.Executable diff --git a/pkg/merge/aggregator.go b/pkg/merge/aggregator.go index 984cda1be..169ab5836 100644 --- a/pkg/merge/aggregator.go +++ b/pkg/merge/aggregator.go @@ -23,6 +23,5 @@ import ( type Aggregator interface { Aggregate(values []interface{}) - GetResult() (*gxbig.Decimal, bool) } diff --git a/pkg/merge/aggregator/init.go b/pkg/merge/aggregator/init.go new file mode 100644 index 000000000..e68452a3c --- /dev/null +++ b/pkg/merge/aggregator/init.go @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aggregator + +import ( + "fmt" +) + +import ( + "github.com/arana-db/arana/pkg/merge" +) + +var aggregatorMap = make(map[string]func() merge.Aggregator) + +func init() { + aggregatorMap["MAX"] = func() merge.Aggregator { return &MaxAggregator{} } + aggregatorMap["MIN"] = func() merge.Aggregator { return &MinAggregator{} } + aggregatorMap["COUNT"] = func() merge.Aggregator { return &AddAggregator{} } + aggregatorMap["SUM"] = func() merge.Aggregator { return &AddAggregator{} } +} + +func GetAggFromName(name string) func() merge.Aggregator { + if agg, ok := aggregatorMap[name]; ok { + return agg + } + panic(fmt.Errorf("aggregator %s not support yet", name)) +} diff --git a/pkg/merge/aggregator/load_agg.go b/pkg/merge/aggregator/load_agg.go new file mode 100644 index 000000000..4c416044e --- /dev/null +++ b/pkg/merge/aggregator/load_agg.go @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aggregator + +import ( + "github.com/arana-db/arana/pkg/merge" + "github.com/arana-db/arana/pkg/runtime/ast" +) + +func LoadAggs(fields []ast.SelectElement) map[int]func() merge.Aggregator { + var aggMap = make(map[int]func() merge.Aggregator) + enter := func(i int, n *ast.AggrFunction) { + if n == nil { + return + } + aggMap[i] = GetAggFromName(n.Name()) + } + + for i, field := range fields { + if field == nil { + continue + } + if f, ok := field.(*ast.SelectElementFunction); ok { + enter(i, f.Function().(*ast.AggrFunction)) + } + } + + return aggMap +} diff --git a/pkg/merge/aggregator/max_aggregator.go b/pkg/merge/aggregator/max_aggregator.go index 2ae8c5de3..e37e5e6fb 100644 --- a/pkg/merge/aggregator/max_aggregator.go +++ b/pkg/merge/aggregator/max_aggregator.go @@ -22,7 +22,7 @@ import ( ) type MaxAggregator struct { - //max decimal.Decimal + // max decimal.Decimal max *gxbig.Decimal init bool } diff --git a/pkg/merge/impl/group_by/group_by_stream_merge_rows.go b/pkg/merge/impl/group/group_stream_merge_rows.go similarity index 92% rename from pkg/merge/impl/group_by/group_by_stream_merge_rows.go rename to pkg/merge/impl/group/group_stream_merge_rows.go index 8e52c530c..48ce59cfd 100644 --- a/pkg/merge/impl/group_by/group_by_stream_merge_rows.go +++ b/pkg/merge/impl/group/group_stream_merge_rows.go @@ -15,7 +15,7 @@ * limitations under the License. */ -package group_by +package group import ( "container/heap" @@ -118,16 +118,16 @@ func (s *GroupByStreamMergeRows) merge() proto.Row { } } - row := testdata.NewMockRow(gomock.NewController(nil)) + row := testdata.NewMockKeyedRow(gomock.NewController(nil)) for _, sel := range s.stmt.Selects { if _, ok := aggrMap[sel.Column]; ok { res, _ := aggrMap[sel.Column].GetResult() // TODO use row encode() to build a new row result val, _ := res.ToInt() - row.EXPECT().GetColumnValue(sel.Column).Return(val, nil).AnyTimes() + row.EXPECT().Get(sel.Column).Return(val, nil).AnyTimes() } else { - res, _ := currentRow.GetColumnValue(sel.Column) - row.EXPECT().GetColumnValue(sel.Column).Return(res, nil).AnyTimes() + res, _ := currentRow.(proto.KeyedRow).Get(sel.Column) + row.EXPECT().Get(sel.Column).Return(res, nil).AnyTimes() } } return row @@ -136,7 +136,7 @@ func (s *GroupByStreamMergeRows) merge() proto.Row { // todo not support Avg method yet func (s *GroupByStreamMergeRows) aggregate(aggrMap map[string]merge.Aggregator, row proto.Row) { for k, v := range aggrMap { - val, err := row.GetColumnValue(k) + val, err := row.(proto.KeyedRow).Get(k) if err != nil { panic(err.Error()) } @@ -145,7 +145,7 @@ func (s *GroupByStreamMergeRows) aggregate(aggrMap map[string]merge.Aggregator, } func (s *GroupByStreamMergeRows) hasNext() bool { - //s.currentMergedRow = nil + // s.currentMergedRow = nil s.currentRow = nil if s.queue.Len() == 0 { return false diff --git a/pkg/merge/impl/group_by/group_by_stream_merge_rows_test.go b/pkg/merge/impl/group/group_stream_merge_rows_test.go similarity index 87% rename from pkg/merge/impl/group_by/group_by_stream_merge_rows_test.go rename to pkg/merge/impl/group/group_stream_merge_rows_test.go index 2ff735488..e589fda32 100644 --- a/pkg/merge/impl/group_by/group_by_stream_merge_rows_test.go +++ b/pkg/merge/impl/group/group_stream_merge_rows_test.go @@ -15,7 +15,7 @@ * limitations under the License. */ -package group_by +package group import ( "testing" @@ -44,7 +44,6 @@ type ( ) func TestGroupByStreamMergeRows(t *testing.T) { - stmt := MergeRowStatement{ OrderBys: []merge.OrderByItem{ { @@ -77,13 +76,16 @@ func TestGroupByStreamMergeRows(t *testing.T) { if row == nil { break } - v1, _ := row.GetColumnValue(countScore) - v2, _ := row.GetColumnValue(age) + v1, _ := row.(proto.KeyedRow).Get(countScore) + v2, _ := row.(proto.KeyedRow).Get(age) res = append(res, student{countScore: v1.(int64), age: v2.(int64)}) } assert.Equal(t, []student{ - {countScore: 175, age: 81}, {countScore: 160, age: 70}, {countScore: 75, age: 68}, - {countScore: 143, age: 60}, {countScore: 70, age: 40}, + {countScore: 175, age: 81}, + {countScore: 160, age: 70}, + {countScore: 75, age: 68}, + {countScore: 143, age: 60}, + {countScore: 70, age: 40}, }, res) } @@ -98,9 +100,9 @@ func buildMergeRows(t *testing.T, vals [][]student) []*merge.MergeRows { func buildMergeRow(t *testing.T, vals []student) *merge.MergeRows { rows := make([]proto.Row, 0) for _, val := range vals { - row := testdata.NewMockRow(gomock.NewController(t)) + row := testdata.NewMockKeyedRow(gomock.NewController(t)) for k, v := range val { - row.EXPECT().GetColumnValue(k).Return(v, nil).AnyTimes() + row.EXPECT().Get(k).Return(v, nil).AnyTimes() } rows = append(rows, row) } diff --git a/pkg/merge/impl/group_by/group_by_value.go b/pkg/merge/impl/group/group_value.go similarity index 96% rename from pkg/merge/impl/group_by/group_by_value.go rename to pkg/merge/impl/group/group_value.go index 747583e5b..991c61422 100644 --- a/pkg/merge/impl/group_by/group_by_value.go +++ b/pkg/merge/impl/group/group_value.go @@ -15,7 +15,7 @@ * limitations under the License. */ -package group_by +package group import ( "fmt" @@ -41,7 +41,7 @@ func buildGroupValues(groupByColumns []string, row proto.Row) []interface{} { values := make([]interface{}, 0) for _, column := range groupByColumns { - value, err := row.GetColumnValue(column) + value, err := row.(proto.KeyedRow).Get(column) if err != nil { panic("get column value error:" + err.Error()) } diff --git a/pkg/merge/merge_rows.go b/pkg/merge/merge_rows.go index 53bf60ed5..80c353ca7 100644 --- a/pkg/merge/merge_rows.go +++ b/pkg/merge/merge_rows.go @@ -47,6 +47,10 @@ func (s *MergeRows) Next() proto.Row { return result } +func (s *MergeRows) HasNext() bool { + return s.currentRowIndex < len(s.rows)-1 +} + func (s *MergeRows) GetCurrentRow() proto.Row { if len(s.rows) == 0 || s.currentRowIndex > len(s.rows) { return nil diff --git a/pkg/merge/merge_rows_test.go b/pkg/merge/merge_rows_test.go index 0bf705926..3ed14c374 100644 --- a/pkg/merge/merge_rows_test.go +++ b/pkg/merge/merge_rows_test.go @@ -51,8 +51,8 @@ func TestGetCurrentRow(t *testing.T) { if row == nil { break } - v1, _ := rows.GetCurrentRow().GetColumnValue(score) - v2, _ := rows.GetCurrentRow().GetColumnValue(age) + v1, _ := rows.GetCurrentRow().(proto.KeyedRow).Get(score) + v2, _ := rows.GetCurrentRow().(proto.KeyedRow).Get(age) res = append(res, student{score: v1.(int64), age: v2.(int64)}) } @@ -62,9 +62,9 @@ func TestGetCurrentRow(t *testing.T) { func buildMergeRow(t *testing.T, vals []student) *MergeRows { rows := make([]proto.Row, 0) for _, val := range vals { - row := testdata.NewMockRow(gomock.NewController(t)) + row := testdata.NewMockKeyedRow(gomock.NewController(t)) for k, v := range val { - row.EXPECT().GetColumnValue(k).Return(v, nil).AnyTimes() + row.EXPECT().Get(k).Return(v, nil).AnyTimes() } rows = append(rows, row) } diff --git a/pkg/merge/priority_queue.go b/pkg/merge/priority_queue.go index a6b75ae0d..014cc5a9c 100644 --- a/pkg/merge/priority_queue.go +++ b/pkg/merge/priority_queue.go @@ -22,6 +22,10 @@ import ( "fmt" ) +import ( + "github.com/arana-db/arana/pkg/proto" +) + type PriorityQueue struct { results []*MergeRows orderByItems []OrderByItem @@ -47,8 +51,11 @@ func (pq *PriorityQueue) Len() int { func (pq *PriorityQueue) Less(i, j int) bool { for _, item := range pq.orderByItems { - val1, _ := pq.results[i].GetCurrentRow().GetColumnValue(item.Column) - val2, _ := pq.results[j].GetCurrentRow().GetColumnValue(item.Column) + rowi := pq.results[i].GetCurrentRow().(proto.KeyedRow) + rowj := pq.results[j].GetCurrentRow().(proto.KeyedRow) + + val1, _ := rowi.Get(item.Column) + val2, _ := rowj.Get(item.Column) if val1 != val2 { if item.Desc { return fmt.Sprintf("%v", val1) > fmt.Sprintf("%v", val2) diff --git a/pkg/merge/priority_queue_test.go b/pkg/merge/priority_queue_test.go index caadf636c..9e4840f2e 100644 --- a/pkg/merge/priority_queue_test.go +++ b/pkg/merge/priority_queue_test.go @@ -26,6 +26,10 @@ import ( "github.com/stretchr/testify/assert" ) +import ( + "github.com/arana-db/arana/pkg/proto" +) + func TestPriorityQueue(t *testing.T) { rows := buildMergeRows(t, [][]student{ {{score: 85, age: 72}, {score: 75, age: 70}, {score: 65, age: 50}}, @@ -43,18 +47,25 @@ func TestPriorityQueue(t *testing.T) { break } row := heap.Pop(&queue).(*MergeRows) - v1, _ := row.GetCurrentRow().GetColumnValue(score) - v2, _ := row.GetCurrentRow().GetColumnValue(age) + v1, _ := row.GetCurrentRow().(proto.KeyedRow).Get(score) + v2, _ := row.GetCurrentRow().(proto.KeyedRow).Get(age) res = append(res, student{score: v1.(int64), age: v2.(int64)}) if row.Next() != nil { queue.Push(row) } } assert.Equal(t, []student{ - {score: 90, age: 72}, {score: 85, age: 72}, {score: 85, age: 70}, - {score: 78, age: 60}, {score: 75, age: 80}, {score: 75, age: 70}, - {score: 75, age: 68}, {score: 70, age: 40}, {score: 65, age: 80}, - {score: 65, age: 50}, {score: 60, age: 40}, + {score: 90, age: 72}, + {score: 85, age: 72}, + {score: 85, age: 70}, + {score: 78, age: 60}, + {score: 75, age: 80}, + {score: 75, age: 70}, + {score: 75, age: 68}, + {score: 70, age: 40}, + {score: 65, age: 80}, + {score: 65, age: 50}, + {score: 60, age: 40}, }, res) } @@ -75,18 +86,26 @@ func TestPriorityQueue2(t *testing.T) { break } row := heap.Pop(&queue).(*MergeRows) - v1, _ := row.GetCurrentRow().GetColumnValue(score) - v2, _ := row.GetCurrentRow().GetColumnValue(age) + + v1, _ := row.GetCurrentRow().(proto.KeyedRow).Get(score) + v2, _ := row.GetCurrentRow().(proto.KeyedRow).Get(age) res = append(res, student{score: v1.(int64), age: v2.(int64)}) if row.Next() != nil { queue.Push(row) } } assert.Equal(t, []student{ - {score: 60, age: 40}, {score: 65, age: 50}, {score: 65, age: 80}, - {score: 70, age: 40}, {score: 75, age: 68}, {score: 75, age: 70}, - {score: 75, age: 80}, {score: 78, age: 60}, {score: 85, age: 70}, - {score: 85, age: 72}, {score: 90, age: 72}, + {score: 60, age: 40}, + {score: 65, age: 50}, + {score: 65, age: 80}, + {score: 70, age: 40}, + {score: 75, age: 68}, + {score: 75, age: 70}, + {score: 75, age: 80}, + {score: 78, age: 60}, + {score: 85, age: 70}, + {score: 85, age: 72}, + {score: 90, age: 72}, }, res) } diff --git a/pkg/metrics/metrics.go b/pkg/metrics/metrics.go index 691898ad4..97a1a8463 100644 --- a/pkg/metrics/metrics.go +++ b/pkg/metrics/metrics.go @@ -27,7 +27,7 @@ var ( Subsystem: "parser", Name: "duration_seconds", Help: "histogram of processing time (s) in parse SQL.", - Buckets: prometheus.ExponentialBuckets(0.00004, 2, 25), //40us ~ 11min + Buckets: prometheus.ExponentialBuckets(0.00004, 2, 25), // 40us ~ 11min }) OptimizeDuration = prometheus.NewHistogram(prometheus.HistogramOpts{ @@ -35,7 +35,7 @@ var ( Subsystem: "optimizer", Name: "duration_seconds", Help: "histogram of processing time (s) in optimizer.", - Buckets: prometheus.ExponentialBuckets(0.00004, 2, 25), //40us ~ 11min + Buckets: prometheus.ExponentialBuckets(0.00004, 2, 25), // 40us ~ 11min }) ExecuteDuration = prometheus.NewHistogram(prometheus.HistogramOpts{ @@ -43,7 +43,7 @@ var ( Subsystem: "executor", Name: "duration_seconds", Help: "histogram of processing time (s) in execute.", - Buckets: prometheus.ExponentialBuckets(0.0001, 2, 30), //100us ~ 15h, + Buckets: prometheus.ExponentialBuckets(0.0001, 2, 30), // 100us ~ 15h, }) ) diff --git a/pkg/mysql/client.go b/pkg/mysql/client.go index cb5679360..307df3839 100644 --- a/pkg/mysql/client.go +++ b/pkg/mysql/client.go @@ -28,7 +28,6 @@ import ( "math/big" "net" "net/url" - "strconv" "strings" "time" ) @@ -37,6 +36,7 @@ import ( "github.com/arana-db/arana/pkg/constants/mysql" err2 "github.com/arana-db/arana/pkg/mysql/errors" "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/util/bytefmt" "github.com/arana-db/arana/pkg/util/log" "github.com/arana-db/arana/third_party/pools" ) @@ -414,10 +414,11 @@ func parseDSNParams(cfg *Config, params string) (err error) { return } case "maxAllowedPacket": - cfg.MaxAllowedPacket, err = strconv.Atoi(value) + byteSize, err := bytefmt.ToBytes(value) if err != nil { - return + return err } + cfg.MaxAllowedPacket = int(byteSize) default: // lazy init if cfg.Params == nil { @@ -642,7 +643,7 @@ func (conn *BackendConnection) parseInitialHandshakePacket(data []byte) (uint32, if !ok { return 0, nil, "", err2.NewSQLError(mysql.CRMalformedPacket, mysql.SSUnknownSQLState, "parseInitialHandshakePacket: packet has no capability flags (lower 2 bytes)") } - var capabilities = uint32(capLower) + capabilities := uint32(capLower) // The packet can end here. if pos == len(data) { @@ -727,7 +728,7 @@ func (conn *BackendConnection) parseInitialHandshakePacket(data []byte) (uint32, // Returns a SQLError. func (conn *BackendConnection) writeHandshakeResponse41(capabilities uint32, scrambledPassword []byte, plugin string) error { // Build our flags. - var flags = mysql.CapabilityClientLongPassword | + flags := mysql.CapabilityClientLongPassword | mysql.CapabilityClientLongFlag | mysql.CapabilityClientProtocol41 | mysql.CapabilityClientTransactions | @@ -747,16 +748,15 @@ func (conn *BackendConnection) writeHandshakeResponse41(capabilities uint32, scr // FIXME(alainjobart) add multi statement. - length := - 4 + // Client capability flags. - 4 + // Max-packet size. - 1 + // Character set. - 23 + // Reserved. - lenNullString(conn.conf.User) + - // length of scrambled password is handled below. - len(scrambledPassword) + - 21 + // "mysql_native_password" string. - 1 // terminating zero. + length := 4 + // Client capability flags. + 4 + // Max-packet size. + 1 + // Character set. + 23 + // Reserved. + lenNullString(conn.conf.User) + + // length of scrambled password is handled below. + len(scrambledPassword) + + 21 + // "mysql_native_password" string. + 1 // terminating zero. // Add the DB name if the server supports it. if conn.conf.DBName != "" && (capabilities&mysql.CapabilityClientConnectWithDB != 0) { @@ -858,6 +858,7 @@ func (conn *BackendConnection) WriteComSetOption(operation uint16) error { return nil } +// WriteComFieldList https://dev.mysql.com/doc/internals/en/com-field-list.html func (conn *BackendConnection) WriteComFieldList(table string, wildcard string) error { conn.c.sequence = 0 length := lenNullString(table) + lenNullString(wildcard) @@ -871,10 +872,9 @@ func (conn *BackendConnection) WriteComFieldList(table string, wildcard string) pos = writeByte(data, 0, mysql.ComFieldList) if len(wildcard) > 0 { pos = writeNullString(data, pos, table) - writeNullString(data, pos, wildcard) - } else { - pos = writeEOFString(data, pos, table) writeEOFString(data, pos, wildcard) + } else { + writeNullString(data, pos, table) } if err := conn.c.writeEphemeralPacket(); err != nil { @@ -886,7 +886,7 @@ func (conn *BackendConnection) WriteComFieldList(table string, wildcard string) func (conn *BackendConnection) readResultSetHeaderPacket() (affectedRows, lastInsertID uint64, colNumber int, more bool, warnings uint16, err error) { // Get the result. - affectedRows, lastInsertID, colNumber, more, warning, err := conn.ReadComQueryResponse() + affectedRows, lastInsertID, colNumber, more, warning, err := conn.readComQueryResponse() if err != nil { return affectedRows, lastInsertID, colNumber, more, warning, err } @@ -908,174 +908,18 @@ func (conn *BackendConnection) readResultSetColumnsPacket(colNumber int) (column } // ReadQueryRow returns iterator, and the line reads the results set -func (conn *BackendConnection) ReadQueryRow() (iter *IterRow, affectedRows, lastInsertID uint64, more bool, warnings uint16, err error) { - iterRow := &IterRow{BackendConnection: conn, - Row: &Row{ - Content: []byte{}, ResultSet: &ResultSet{ - Columns: make([]proto.Field, 0), - }}, - hasNext: true, - } - affectedRows, lastInsertID, colNumber, more, warning, err := conn.readResultSetHeaderPacket() - if err != nil { - return iterRow, affectedRows, lastInsertID, more, warning, err - } - if colNumber == 0 { - // OK packet, means no results. Just use the numbers. - return iterRow, affectedRows, lastInsertID, more, warning, nil - } - - // Read column headers. One packet per column. - // Build the fields. - columns := make([]proto.Field, colNumber) - for i := 0; i < colNumber; i++ { - field := &Field{} - columns[i] = field - if err = conn.ReadColumnDefinition(field, i); err != nil { - return - } - } - - iterRow.Row.ResultSet.Columns = columns - - if conn.capabilities&mysql.CapabilityClientDeprecateEOF == 0 { - // EOF is only present here if it's not deprecated. - data, err := conn.c.readEphemeralPacket() - if err != nil { - return nil, affectedRows, lastInsertID, more, warning, err2.NewSQLError(mysql.CRServerLost, mysql.SSUnknownSQLState, "%v", err) - } - if isEOFPacket(data) { - - // This is what we expect. - // Warnings and status flags are ignored. - conn.c.recycleReadPacket() - // goto: read row loop - - } else if isErrorPacket(data) { - defer conn.c.recycleReadPacket() - return nil, affectedRows, lastInsertID, more, warning, ParseErrorPacket(data) - } else { - defer conn.c.recycleReadPacket() - return nil, affectedRows, lastInsertID, more, warning, fmt.Errorf("unexpected packet after fields: %v", data) - } - } - - return iterRow, affectedRows, lastInsertID, more, warnings, err +func (conn *BackendConnection) ReadQueryRow() *RawResult { + return newResult(conn) } // ReadQueryResult gets the result from the last written query. -func (conn *BackendConnection) ReadQueryResult(wantFields bool) (result *Result, more bool, warnings uint16, err error) { - // Get the result. - affectedRows, lastInsertID, colNumber, more, warnings, err := conn.ReadComQueryResponse() - if err != nil { - return nil, false, 0, err - } - - if colNumber == 0 { - // OK packet, means no results. Just use the numbers. - return &Result{ - AffectedRows: affectedRows, - InsertId: lastInsertID, - }, more, warnings, nil - } - - result = &Result{ - Fields: make([]proto.Field, colNumber), - } - - // Read column headers. One packet per column. - // Build the fields. - for i := 0; i < colNumber; i++ { - field := &Field{} - result.Fields[i] = field - - if wantFields { - if err := conn.ReadColumnDefinition(field, i); err != nil { - return nil, false, 0, err - } - } else { - if err := conn.ReadColumnDefinitionType(field, i); err != nil { - return nil, false, 0, err - } - } - } - - if conn.capabilities&mysql.CapabilityClientDeprecateEOF == 0 { - // EOF is only present here if it's not deprecated. - data, err := conn.c.readEphemeralPacket() - if err != nil { - return nil, false, 0, err2.NewSQLError(mysql.CRServerLost, mysql.SSUnknownSQLState, "%v", err) - } - if isEOFPacket(data) { - - // This is what we expect. - // Warnings and status flags are ignored. - conn.c.recycleReadPacket() - // goto: read row loop - - } else if isErrorPacket(data) { - defer conn.c.recycleReadPacket() - return nil, false, 0, ParseErrorPacket(data) - } else { - defer conn.c.recycleReadPacket() - return nil, false, 0, fmt.Errorf("unexpected packet after fields: %v", data) - } - } - - // read each row until EOF or OK packet. - for { - data, err := conn.c.ReadPacket() - if err != nil { - return nil, false, 0, err - } - - if isEOFPacket(data) { - // Strip the partial Fields before returning. - if !wantFields { - result.Fields = nil - } - result.AffectedRows = uint64(len(result.Rows)) - - // The deprecated EOF packets change means that this is either an - // EOF packet or an OK packet with the EOF type code. - if conn.capabilities&mysql.CapabilityClientDeprecateEOF == 0 { - warnings, more, err = parseEOFPacket(data) - if err != nil { - return nil, false, 0, err - } - } else { - var statusFlags uint16 - _, _, statusFlags, warnings, err = parseOKPacket(data) - if err != nil { - return nil, false, 0, err - } - more = (statusFlags & mysql.ServerMoreResultsExists) != 0 - } - return result, more, warnings, nil - - } else if isErrorPacket(data) { - // Error packet. - return nil, false, 0, ParseErrorPacket(data) - } - - //// Check we're not over the limit before we add more. - //if len(result.Rows) == maxrows { - // if err := conn.DrainResults(); err != nil { - // return nil, false, 0, err - // } - // return nil, false, 0, err2.NewSQLError(mysql.ERVitessMaxRowsExceeded, mysql.SSUnknownSQLState, "Row count exceeded %d") - //} - - // Regular row. - row, err := conn.parseRow(data, result.Fields) - if err != nil { - return nil, false, 0, err - } - result.Rows = append(result.Rows, row) - } +func (conn *BackendConnection) ReadQueryResult(wantFields bool) proto.Result { + ret := conn.ReadQueryRow() + ret.setWantFields(wantFields) + return ret } -func (conn *BackendConnection) ReadComQueryResponse() (affectedRows uint64, lastInsertID uint64, status int, more bool, warnings uint16, err error) { +func (conn *BackendConnection) readComQueryResponse() (affectedRows uint64, lastInsertID uint64, status int, more bool, warnings uint16, err error) { data, err := conn.c.readEphemeralPacket() if err != nil { return 0, 0, 0, false, 0, err2.NewSQLError(mysql.CRServerLost, mysql.SSUnknownSQLState, "%v", err) @@ -1107,7 +951,7 @@ func (conn *BackendConnection) ReadComQueryResponse() (affectedRows uint64, last } // ReadColumnDefinition reads the next Column Definition packet. -// Returns a SQLError. +// Returns a SQLError. https://dev.mysql.com/doc/internals/en/com-query-response.html#column-definition func (conn *BackendConnection) ReadColumnDefinition(field *Field, index int) error { colDef, err := conn.c.readEphemeralPacket() if err != nil { @@ -1188,20 +1032,20 @@ func (conn *BackendConnection) ReadColumnDefinition(field *Field, index int) err } field.decimals = decimals - //if more Content, command was field list - if len(colDef) > pos+8 { - //length of default value lenenc-int - field.defaultValueLength, pos, ok = readUint64(colDef, pos) - if !ok { - return err2.NewSQLError(mysql.CRMalformedPacket, mysql.SSUnknownSQLState, "extracting col %v default value failed", index) - } + // filter [0x00][0x00] + pos += 2 + + // if more Content, command was field list + if len(colDef) > pos { + // length of default value lenenc-int + field.defaultValueLength, pos = readComFieldListDefaultValueLength(colDef, pos) if pos+int(field.defaultValueLength) > len(colDef) { return err2.NewSQLError(mysql.CRMalformedPacket, mysql.SSUnknownSQLState, "extracting col %v default value failed", index) } - //default value string[$len] - field.defaultValue = colDef[pos:(pos + int(field.defaultValueLength))] + // default value string[$len] + field.defaultValue = append(field.defaultValue, colDef[pos:(pos+int(field.defaultValueLength))]...) } return nil } @@ -1281,18 +1125,6 @@ func (conn *BackendConnection) ReadColumnDefinitionType(field *Field, index int) return nil } -// parseRow parses an individual row. -// Returns a SQLError. -func (conn *BackendConnection) parseRow(data []byte, fields []proto.Field) (proto.Row, error) { - row := &Row{ - Content: data, - ResultSet: &ResultSet{ - Columns: fields, - }, - } - return row, nil -} - // DrainResults will read all packets for a result set and ignore them. func (conn *BackendConnection) DrainResults() error { for { @@ -1303,9 +1135,12 @@ func (conn *BackendConnection) DrainResults() error { if isEOFPacket(data) { conn.c.recycleReadPacket() return nil - } else if isErrorPacket(data) { - defer conn.c.recycleReadPacket() - return ParseErrorPacket(data) + } + + if isErrorPacket(data) { + err = ParseErrorPacket(data) + conn.c.recycleReadPacket() + return err } conn.c.recycleReadPacket() } @@ -1317,7 +1152,7 @@ func (conn *BackendConnection) ReadColumnDefinitions() ([]proto.Field, error) { for { field := &Field{} err := conn.ReadColumnDefinition(field, i) - if err == io.EOF { + if errors.Is(err, io.EOF) { return result, nil } if err != nil { @@ -1328,59 +1163,9 @@ func (conn *BackendConnection) ReadColumnDefinitions() ([]proto.Field, error) { } } -// Execute executes a query and returns the result. -// Returns a SQLError. Depending on the transport used, the error -// returned might be different for the same condition: -// -// 1. if the server closes the connection when no command is in flight: -// -// 1.1 unix: WriteComQuery will fail with a 'broken pipe', and we'll -// return CRServerGone(2006). -// -// 1.2 tcp: WriteComQuery will most likely work, but ReadComQueryResponse -// will fail, and we'll return CRServerLost(2013). -// -// This is because closing a TCP socket on the server side sends -// a FIN to the client (telling the client the server is done -// writing), but on most platforms doesn't send a RST. So the -// client has no idea it can't write. So it succeeds writing Content, which -// *then* triggers the server to send a RST back, received a bit -// later. By then, the client has already started waiting for -// the response, and will just return a CRServerLost(2013). -// So CRServerGone(2006) will almost never be seen with TCP. -// -// 2. if the server closes the connection when a command is in flight, -// ReadComQueryResponse will fail, and we'll return CRServerLost(2013). -func (conn *BackendConnection) Execute(query string, wantFields bool) (result *Result, err error) { - result, _, err = conn.ExecuteMulti(query, wantFields) - return -} - -// ExecuteMulti is for fetching multiple results from a multi-statement result. -// It returns an additional 'more' flag. If it is set, you must fetch the additional -// results using ReadQueryResult. -func (conn *BackendConnection) ExecuteMulti(query string, wantFields bool) (result *Result, more bool, err error) { - defer func() { - if err != nil { - if sqlerr, ok := err.(*err2.SQLError); ok { - sqlerr.Query = query - } - } - }() - - // Send the query as a COM_QUERY packet. - if err = conn.WriteComQuery(query); err != nil { - return nil, false, err - } - - result, more, _, err = conn.ReadQueryResult(wantFields) - return -} - // ExecuteWithWarningCountIterRow is for fetching results and a warning count -// Note: In a future iteration this should be abolished and merged into the -// Execute API. -func (conn *BackendConnection) ExecuteWithWarningCountIterRow(query string) (res *Result, warnings uint16, err error) { +// Note: In a future iteration this should be abolished and merged into the Execute API. +func (conn *BackendConnection) ExecuteWithWarningCountIterRow(query string) (result proto.Result, err error) { defer func() { if err != nil { if sqlErr, ok := err.(*err2.SQLError); ok { @@ -1391,26 +1176,22 @@ func (conn *BackendConnection) ExecuteWithWarningCountIterRow(query string) (res // Send the query as a COM_QUERY packet. if err = conn.WriteComQuery(query); err != nil { - return nil, 0, err + return } - iterTextRow, affectedRows, lastInsertID, _, warnings, err := conn.ReadQueryRow() - iterRow := &TextIterRow{iterTextRow} + res := conn.ReadQueryRow() + res.setTextProtocol() + res.setWantFields(true) + + result = res - res = &Result{ - AffectedRows: affectedRows, - InsertId: lastInsertID, - Fields: iterRow.Fields(), - Rows: []proto.Row{iterRow}, - DataChan: make(chan proto.Row, 1), - } return } // ExecuteWithWarningCount is for fetching results and a warning count // Note: In a future iteration this should be abolished and merged into the // Execute API. -func (conn *BackendConnection) ExecuteWithWarningCount(query string, wantFields bool) (result *Result, warnings uint16, err error) { +func (conn *BackendConnection) ExecuteWithWarningCount(query string, wantFields bool) (result proto.Result, err error) { defer func() { if err != nil { if sqlErr, ok := err.(*err2.SQLError); ok { @@ -1421,49 +1202,42 @@ func (conn *BackendConnection) ExecuteWithWarningCount(query string, wantFields // Send the query as a COM_QUERY packet. if err = conn.WriteComQuery(query); err != nil { - return nil, 0, err + return } - result, _, warnings, err = conn.ReadQueryResult(wantFields) + result = conn.ReadQueryResult(wantFields) + return } -func (conn *BackendConnection) PrepareExecuteArgs(query string, args []interface{}) (result *Result, warnings uint16, err error) { +func (conn *BackendConnection) PrepareExecuteArgs(query string, args []interface{}) (proto.Result, error) { stmt, err := conn.prepare(query) if err != nil { - return nil, 0, err + return nil, err } return stmt.execArgs(args) } -func (conn *BackendConnection) PrepareQueryArgsIterRow(query string, data []interface{}) (result *Result, warnings uint16, err error) { - stmt, err := conn.prepare(query) - if err != nil { - return nil, 0, err - } - return stmt.queryArgsIterRow(data) -} - -func (conn *BackendConnection) PrepareQueryArgs(query string, data []interface{}) (result *Result, warnings uint16, err error) { +func (conn *BackendConnection) PrepareQueryArgs(query string, data []interface{}) (proto.Result, error) { stmt, err := conn.prepare(query) if err != nil { - return nil, 0, err + return nil, err } return stmt.queryArgs(data) } -func (conn *BackendConnection) PrepareExecute(query string, data []byte) (result *Result, warnings uint16, err error) { +func (conn *BackendConnection) PrepareExecute(query string, data []byte) (proto.Result, error) { stmt, err := conn.prepare(query) if err != nil { - return nil, 0, err + return nil, err } return stmt.exec(data) } -func (conn *BackendConnection) PrepareQuery(query string, data []byte) (Result *Result, warnings uint16, err error) { +func (conn *BackendConnection) PrepareQuery(query string, data []byte) (proto.Result, error) { stmt, err := conn.prepare(query) if err != nil { - return nil, 0, err + return nil, err } return stmt.query(data) } diff --git a/pkg/mysql/client_test.go b/pkg/mysql/client_test.go index aeb11e9cd..4eea2c377 100644 --- a/pkg/mysql/client_test.go +++ b/pkg/mysql/client_test.go @@ -304,35 +304,38 @@ func TestPrepare(t *testing.T) { assert.Equal(t, 0, stmt.paramCount) } -func TestReadComQueryResponse(t *testing.T) { - dsn := "admin:123456@tcp(127.0.0.1:3306)/pass?allowAllFiles=true&allowCleartextPasswords=true" - cfg, _ := ParseDSN(dsn) - conn := &BackendConnection{conf: cfg} - conn.c = newConn(new(mockConn)) - buf := make([]byte, 13) - buf[0] = 9 - buf[4] = mysql.OKPacket - buf[5] = 1 - buf[6] = 1 - conn.c.conn.(*mockConn).data = buf - affectedRows, lastInsertID, _, _, _, err := conn.ReadComQueryResponse() - assert.NoError(t, err) - assert.Equal(t, uint64(0x1), affectedRows) - assert.Equal(t, uint64(0x1), lastInsertID) -} +//func TestReadComQueryResponse(t *testing.T) { +// dsn := "admin:123456@tcp(127.0.0.1:3306)/pass?allowAllFiles=true&allowCleartextPasswords=true" +// cfg, _ := ParseDSN(dsn) +// conn := &BackendConnection{conf: cfg} +// conn.c = newConn(new(mockConn)) +// buf := make([]byte, 13) +// buf[0] = 9 +// buf[4] = mysql.OKPacket +// buf[5] = 1 +// buf[6] = 1 +// conn.c.conn.(*mockConn).data = buf +// affectedRows, lastInsertID, _, _, _, err := conn.ReadComQueryResponse() +// assert.NoError(t, err) +// assert.Equal(t, uint64(0x1), affectedRows) +// assert.Equal(t, uint64(0x1), lastInsertID) +//} func TestReadColumnDefinition(t *testing.T) { dsn := "admin:123456@tcp(127.0.0.1:3306)/pass?allowAllFiles=true&allowCleartextPasswords=true" cfg, _ := ParseDSN(dsn) conn := &BackendConnection{conf: cfg} conn.c = newConn(new(mockConn)) - buf := make([]byte, 100) - buf[0] = 96 - buf[4] = 3 + buf := make([]byte, 80) + buf[0] = 0x4d + buf[1] = 0x00 + buf[2] = 0x00 + buf[3] = 0x00 + buf[4] = 0x03 // catalog buf[5] = 'd' buf[6] = 'e' buf[7] = 'f' - buf[8] = 8 + buf[8] = 0x08 // schema buf[9] = 't' buf[10] = 'e' buf[11] = 's' @@ -341,7 +344,7 @@ func TestReadColumnDefinition(t *testing.T) { buf[14] = 'a' buf[15] = 's' buf[16] = 'e' - buf[17] = 9 + buf[17] = 0x09 // table buf[18] = 't' buf[19] = 'e' buf[20] = 's' @@ -351,18 +354,60 @@ func TestReadColumnDefinition(t *testing.T) { buf[24] = 'b' buf[25] = 'l' buf[26] = 'e' - buf[28] = 4 - buf[29] = 'n' - buf[30] = 'a' - buf[31] = 'm' - buf[32] = 'e' - buf[37] = 255 - buf[41] = 15 - buf[45] = 4 - buf[53] = 'u' - buf[54] = 's' - buf[55] = 'e' - buf[56] = 'r' + buf[27] = 0x09 // org table + buf[28] = 't' + buf[29] = 'e' + buf[30] = 's' + buf[31] = 't' + buf[32] = 't' + buf[33] = 'a' + buf[34] = 'b' + buf[35] = 'l' + buf[36] = 'e' + buf[37] = 0x04 // name + buf[38] = 'n' + buf[39] = 'a' + buf[40] = 'm' + buf[41] = 'e' + buf[42] = 0x04 // org name + buf[43] = 'n' + buf[44] = 'a' + buf[45] = 'm' + buf[46] = 'e' + buf[47] = 0x0c + buf[48] = 0x3f // character set + buf[49] = 0x00 + buf[50] = 0x13 // column length + buf[51] = 0x00 + buf[52] = 0x00 + buf[53] = 0x00 + buf[54] = 0x0c // field type + buf[55] = 0x81 // flags + buf[56] = 0x00 + buf[57] = 0x00 // decimals + buf[58] = 0x00 // decimals + buf[59] = 0x00 // decimals + buf[60] = 0x13 // default + buf[61] = '0' + buf[62] = '0' + buf[63] = '0' + buf[64] = '0' + buf[65] = '-' + buf[66] = '0' + buf[67] = '0' + buf[68] = '-' + buf[69] = '0' + buf[70] = '0' + buf[71] = ' ' + buf[72] = '0' + buf[73] = '0' + buf[74] = ':' + buf[75] = '0' + buf[76] = '0' + buf[77] = ':' + buf[78] = '0' + buf[79] = '0' + conn.c.conn.(*mockConn).data = buf field := &Field{} err := conn.ReadColumnDefinition(field, 0) @@ -370,10 +415,10 @@ func TestReadColumnDefinition(t *testing.T) { assert.Equal(t, "testtable", field.table) assert.Equal(t, "testbase", field.database) assert.Equal(t, "name", field.name) - assert.Equal(t, mysql.FieldTypeVarChar, field.fieldType) - assert.Equal(t, uint64(0x4), field.defaultValueLength) - assert.Equal(t, "user", string(field.defaultValue)) - assert.Equal(t, uint32(255), field.columnLength) + assert.Equal(t, mysql.FieldTypeDateTime, field.fieldType) + assert.Equal(t, uint64(0x13), field.defaultValueLength) + assert.Equal(t, "0000-00-00 00:00:00", string(field.defaultValue)) + assert.Equal(t, uint32(19), field.columnLength) } func TestReadColumnDefinitionType(t *testing.T) { diff --git a/pkg/mysql/conn.go b/pkg/mysql/conn.go index 5f91fc7f7..e5a9c015f 100644 --- a/pkg/mysql/conn.go +++ b/pkg/mysql/conn.go @@ -48,6 +48,9 @@ const ( // connBufferSize is how much we buffer for reading and // writing. It is also how much we allocate for ephemeral buffers. connBufferSize = 16 * 1024 + + // packetHeaderSize is the first four bytes of a packet + packetHeaderSize int = 4 ) // Constants for how ephemeral buffers were used for reading / writing. @@ -426,6 +429,53 @@ func (c *Conn) ReadPacket() ([]byte, error) { return result, err } +func (c *Conn) writePacketForFieldList(data []byte) error { + index := 0 + dataLength := len(data) - packetHeaderSize + + w, unget := c.getWriter() + defer unget() + + var header [packetHeaderSize]byte + for { + // toBeSent is capped to MaxPacketSize. + toBeSent := dataLength + if toBeSent > mysql.MaxPacketSize { + toBeSent = mysql.MaxPacketSize + } + + // Write the body. + if n, err := w.Write(data[index : index+toBeSent+packetHeaderSize]); err != nil { + return errors.Wrapf(err, "Write(header) failed") + } else if n != (toBeSent + packetHeaderSize) { + return errors.Wrapf(err, "Write(packet) returned a short write: %v < %v", n, toBeSent+packetHeaderSize) + } + + // Update our state. + c.sequence++ + dataLength -= toBeSent + if dataLength == 0 { + if toBeSent == mysql.MaxPacketSize { + // The packet we just sent had exactly + // MaxPacketSize size, we need to + // send a zero-size packet too. + header[0] = 0 + header[1] = 0 + header[2] = 0 + header[3] = c.sequence + if n, err := w.Write(data[index : index+toBeSent+packetHeaderSize]); err != nil { + return errors.Wrapf(err, "Write(header) failed") + } else if n != (toBeSent + packetHeaderSize) { + return errors.Wrapf(err, "Write(packet) returned a short write: %v < %v", n, (toBeSent + packetHeaderSize)) + } + c.sequence++ + } + return nil + } + index += toBeSent + } +} + // writePacket writes a packet, possibly cutting it into multiple // chunks. Note this is not very efficient, as the client probably // has to build the []byte and that makes a memory copy. @@ -475,7 +525,7 @@ func (c *Conn) writePacket(data []byte) error { if packetLength == mysql.MaxPacketSize { // The packet we just sent had exactly // MaxPacketSize size, we need to - // sent a zero-size packet too. + // send a zero-size packet too. header[0] = 0 header[1] = 0 header[2] = 0 @@ -512,7 +562,7 @@ func (c *Conn) writeEphemeralPacket() error { switch c.currentEphemeralPolicy { case ephemeralWrite: if err := c.writePacket(*c.currentEphemeralBuffer); err != nil { - return errors.Wrapf(err, "conn %v", c.ID()) + return errors.WithStack(errors.Wrapf(err, "conn %v", c.ID())) } case ephemeralUnused, ephemeralRead: // Programming error. @@ -522,7 +572,7 @@ func (c *Conn) writeEphemeralPacket() error { return nil } -// recycleWritePacket recycles the write packet. It needs to be called +// recycleWritePacket recycles write packet. It needs to be called // after writeEphemeralPacket was called. func (c *Conn) recycleWritePacket() { if c.currentEphemeralPolicy != ephemeralWrite { @@ -660,6 +710,18 @@ func (c *Conn) writeErrorPacketFromError(err error) error { return c.writeErrorPacket(mysql.ERUnknownError, mysql.SSUnknownSQLState, "unknown error: %v", err) } +func (c *Conn) buildEOFPacket(flags uint16, warnings uint16) []byte { + data := make([]byte, 9) + pos := 0 + data[pos] = 0x05 + pos += 3 + pos = writeLenEncInt(data, pos, uint64(c.sequence)) + pos = writeByte(data, pos, mysql.EOFPacket) + pos = writeUint16(data, pos, flags) + _ = writeUint16(data, pos, warnings) + return data +} + // writeEOFPacket writes an EOF packet, through the buffer, and // doesn't flush (as it is used as part of a query result). func (c *Conn) writeEOFPacket(flags uint16, warnings uint16) error { @@ -677,7 +739,7 @@ func (c *Conn) writeEOFPacket(flags uint16, warnings uint16) error { // Packet parsing methods, for generic packets. // -// isEOFPacket determines whether or not a Content packet is a "true" EOF. DO NOT blindly compare the +// isEOFPacket determines whether a Content packet is a "true" EOF. DO NOT blindly compare the // first byte of a packet to EOFPacket as you might do for other packet types, as 0xfe is overloaded // as a first byte. // diff --git a/pkg/mysql/conn_test.go b/pkg/mysql/conn_test.go index 0b49c1eca..214688126 100644 --- a/pkg/mysql/conn_test.go +++ b/pkg/mysql/conn_test.go @@ -67,6 +67,7 @@ func (m *mockConn) Read(b []byte) (n int, err error) { m.read += n return } + func (m *mockConn) Write(b []byte) (n int, err error) { if m.closed { return 0, errConnClosed @@ -86,22 +87,28 @@ func (m *mockConn) Write(b []byte) (n int, err error) { } return } + func (m *mockConn) Close() error { m.closed = true return nil } + func (m *mockConn) LocalAddr() net.Addr { return m.laddr } + func (m *mockConn) RemoteAddr() net.Addr { return m.raddr } + func (m *mockConn) SetDeadline(t time.Time) error { return nil } + func (m *mockConn) SetReadDeadline(t time.Time) error { return nil } + func (m *mockConn) SetWriteDeadline(t time.Time) error { return nil } diff --git a/pkg/mysql/execute_handle.go b/pkg/mysql/execute_handle.go index 5d1974cfd..f150b5d57 100644 --- a/pkg/mysql/execute_handle.go +++ b/pkg/mysql/execute_handle.go @@ -29,6 +29,7 @@ import ( "github.com/arana-db/arana/pkg/constants/mysql" "github.com/arana-db/arana/pkg/mysql/errors" "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/proto/hint" "github.com/arana-db/arana/pkg/security" "github.com/arana-db/arana/pkg/util/log" ) @@ -70,28 +71,38 @@ func (l *Listener) handleQuery(c *Conn, ctx *proto.Context) error { c.startWriterBuffering() defer func() { if err := c.endWriterBuffering(); err != nil { - log.Errorf("conn %v: flush() failed: %v", c.ID(), err) + log.Errorf("conn %v: flush() failed: %v", ctx.ConnectionID, err) } }() c.recycleReadPacket() - result, warn, err := l.executor.ExecutorComQuery(ctx) - if err != nil { + + var ( + result proto.Result + err error + warn uint16 + ) + + if result, warn, err = l.executor.ExecutorComQuery(ctx); err != nil { + log.Errorf("executor com_query error %v: %v", ctx.ConnectionID, err) if wErr := c.writeErrorPacketFromError(err); wErr != nil { - log.Errorf("Error writing query error to client %v: %v", l.connectionID, wErr) + log.Errorf("Error writing query error to client %v: %v", ctx.ConnectionID, wErr) return wErr } return nil } - if cr, ok := result.(*proto.CloseableResult); ok { - result = cr.Result - defer func() { - _ = cr.Close() - }() + var ds proto.Dataset + if ds, err = result.Dataset(); err != nil { + log.Errorf("get dataset error %v: %v", ctx.ConnectionID, err) + if wErr := c.writeErrorPacketFromError(err); wErr != nil { + log.Errorf("Error writing query error to client %v: %v", ctx.ConnectionID, wErr) + return wErr + } + return nil } - if len(result.GetFields()) == 0 { + if ds == nil { // A successful callback with no fields means that this was a // DML or other write-only operation. // @@ -104,10 +115,15 @@ func (l *Listener) handleQuery(c *Conn, ctx *proto.Context) error { ) return c.writeOKPacket(affected, insertId, c.StatusFlags, warn) } - if err = c.writeFields(l.capabilities, result); err != nil { + + fields, _ := ds.Fields() + + if err = c.writeFields(l.capabilities, fields); err != nil { + log.Errorf("write fields error %v: %v", ctx.ConnectionID, err) return err } - if err = c.writeRowChan(result); err != nil { + if err = c.writeDataset(ds); err != nil { + log.Errorf("write dataset error %v: %v", ctx.ConnectionID, err) return err } if err = c.writeEndResult(l.capabilities, false, 0, 0, warn); err != nil { @@ -128,17 +144,37 @@ func (l *Listener) handleFieldList(c *Conn, ctx *proto.Context) error { return wErr } } - return c.writeFields(l.capabilities, &Result{Fields: fields}) + + // Combine the fields into a package to send + var des []byte + for _, field := range fields { + fld := field.(*Field) + des = append(des, c.DefColumnDefinition(fld)...) + } + + des = append(des, c.buildEOFPacket(0, 2)...) + + if err = c.writePacketForFieldList(des); err != nil { + return err + } + + return nil } func (l *Listener) handleStmtExecute(c *Conn, ctx *proto.Context) error { c.startWriterBuffering() defer func() { if err := c.endWriterBuffering(); err != nil { - log.Errorf("conn %v: flush() failed: %v", c.ID(), err) + log.Errorf("conn %v: flush() failed: %v", ctx.ConnectionID, err) } }() - stmtID, _, err := c.parseComStmtExecute(&l.stmts, ctx.Data) + + var ( + stmtID uint32 + err error + ) + + stmtID, _, err = c.parseComStmtExecute(&l.stmts, ctx.Data) c.recycleReadPacket() if stmtID != uint32(0) { @@ -154,7 +190,7 @@ func (l *Listener) handleStmtExecute(c *Conn, ctx *proto.Context) error { if err != nil { if wErr := c.writeErrorPacketFromError(err); wErr != nil { // If we can't even write the error, we're done. - log.Error("Error writing query error to client %v: %v", l.connectionID, wErr) + log.Error("Error writing query error to client %v: %v", ctx.ConnectionID, wErr) return wErr } return nil @@ -163,23 +199,29 @@ func (l *Listener) handleStmtExecute(c *Conn, ctx *proto.Context) error { prepareStmt, _ := l.stmts.Load(stmtID) ctx.Stmt = prepareStmt.(*proto.Stmt) - result, warn, err := l.executor.ExecutorComStmtExecute(ctx) - if err != nil { + var ( + result proto.Result + warn uint16 + ) + + if result, warn, err = l.executor.ExecutorComStmtExecute(ctx); err != nil { if wErr := c.writeErrorPacketFromError(err); wErr != nil { - log.Errorf("Error writing query error to client %v: %v, executor error: %v", l.connectionID, wErr, err) + log.Errorf("Error writing query error to client %v: %v, executor error: %v", ctx.ConnectionID, wErr, err) return wErr } return nil } - if cr, ok := result.(*proto.CloseableResult); ok { - result = cr.Result - defer func() { - _ = cr.Close() - }() + var ds proto.Dataset + if ds, err = result.Dataset(); err != nil { + if wErr := c.writeErrorPacketFromError(err); wErr != nil { + log.Errorf("Error writing query error to client %v: %v, executor error: %v", ctx.ConnectionID, wErr, err) + return wErr + } + return nil } - if len(result.GetFields()) == 0 { + if ds == nil { // A successful callback with no fields means that this was a // DML or other write-only operation. // @@ -190,10 +232,17 @@ func (l *Listener) handleStmtExecute(c *Conn, ctx *proto.Context) error { lastInsertId, _ := result.LastInsertId() return c.writeOKPacket(affected, lastInsertId, c.StatusFlags, warn) } - if err = c.writeFields(l.capabilities, result); err != nil { + + defer func() { + _ = ds.Close() + }() + + fields, _ := ds.Fields() + + if err = c.writeFields(l.capabilities, fields); err != nil { return err } - if err = c.writeBinaryRowChan(result); err != nil { + if err = c.writeDatasetBinary(ds); err != nil { return err } if err = c.writeEndResult(l.capabilities, false, 0, 0, warn); err != nil { @@ -207,7 +256,7 @@ func (l *Listener) handlePrepare(c *Conn, ctx *proto.Context) error { query := string(ctx.Data[1:]) c.recycleReadPacket() - // Popoulate PrepareData + // Populate PrepareData statementID := l.statementID.Inc() stmt := &proto.Stmt{ @@ -215,7 +264,7 @@ func (l *Listener) handlePrepare(c *Conn, ctx *proto.Context) error { PrepareStmt: query, } p := parser.New() - act, err := p.ParseOneStmt(stmt.PrepareStmt, "", "") + act, hts, err := p.ParseOneStmtHints(stmt.PrepareStmt, "", "") if err != nil { log.Errorf("Conn %v: Error parsing prepared statement: %v", c, err) if wErr := c.writeErrorPacketFromError(err); wErr != nil { @@ -223,6 +272,18 @@ func (l *Listener) handlePrepare(c *Conn, ctx *proto.Context) error { return wErr } } + + for _, it := range hts { + var h *hint.Hint + if h, err = hint.Parse(it); err != nil { + if wErr := c.writeErrorPacketFromError(err); wErr != nil { + log.Errorf("Conn %v: Error writing prepared statement error: %v", c, wErr) + return wErr + } + } + stmt.Hints = append(stmt.Hints, h) + } + stmt.StmtNode = act paramsCount := uint16(strings.Count(query, "?")) @@ -260,7 +321,7 @@ func (l *Listener) handleSetOption(c *Conn, ctx *proto.Context) error { case 1: l.capabilities &^= mysql.CapabilityClientMultiStatements default: - log.Errorf("Got unhandled packet (ComSetOption default) from client %v, returning error: %v", l.connectionID, ctx.Data) + log.Errorf("Got unhandled packet (ComSetOption default) from client %v, returning error: %v", ctx.ConnectionID, ctx.Data) if err := c.writeErrorPacket(mysql.ERUnknownComError, mysql.SSUnknownComError, "error handling packet: %v", ctx.Data); err != nil { log.Errorf("Error writing error packet to client: %v", err) return err @@ -271,7 +332,7 @@ func (l *Listener) handleSetOption(c *Conn, ctx *proto.Context) error { return err } } - log.Errorf("Got unhandled packet (ComSetOption else) from client %v, returning error: %v", l.connectionID, ctx.Data) + log.Errorf("Got unhandled packet (ComSetOption else) from client %v, returning error: %v", ctx.ConnectionID, ctx.Data) if err := c.writeErrorPacket(mysql.ERUnknownComError, mysql.SSUnknownComError, "error handling packet: %v", ctx.Data); err != nil { log.Errorf("Error writing error packet to client: %v", err) return err diff --git a/pkg/mysql/fields.go b/pkg/mysql/fields.go index 0eab340e9..987d3111c 100644 --- a/pkg/mysql/fields.go +++ b/pkg/mysql/fields.go @@ -19,6 +19,7 @@ package mysql import ( "database/sql" + "math" "reflect" ) @@ -61,19 +62,62 @@ type Field struct { defaultValue []byte } +func (mf *Field) FieldType() mysql.FieldType { + return mf.fieldType +} + +func (mf *Field) DecimalSize() (int64, int64, bool) { + decimals := int64(mf.decimals) + switch mf.fieldType { + case mysql.FieldTypeDecimal, mysql.FieldTypeNewDecimal: + if decimals > 0 { + return int64(mf.length) - 2, decimals, true + } + return int64(mf.length) - 1, decimals, true + case mysql.FieldTypeTimestamp, mysql.FieldTypeDateTime, mysql.FieldTypeTime: + return decimals, decimals, true + case mysql.FieldTypeFloat, mysql.FieldTypeDouble: + if decimals == 0x1f { + return math.MaxInt64, math.MaxInt64, true + } + return math.MaxInt64, decimals, true + } + + return 0, 0, false +} + +func (mf *Field) Length() (length int64, ok bool) { + length = int64(mf.length) + ok = true + return +} + +func (mf *Field) Nullable() (nullable, ok bool) { + nullable, ok = mf.flags&mysql.NotNullFlag == 0, true + return +} + func NewField(name string, filedType mysql.FieldType) *Field { return &Field{name: name, fieldType: filedType} } +func (mf *Field) Name() string { + return mf.name +} + +func (mf *Field) OriginName() string { + return mf.orgName +} + func (mf *Field) TableName() string { return mf.table } -func (mf *Field) DataBaseName() string { +func (mf *Field) DatabaseName() string { return mf.database } -func (mf *Field) TypeDatabaseName() string { +func (mf *Field) DatabaseTypeName() string { switch mf.fieldType { case mysql.FieldTypeBit: return "BIT" @@ -157,7 +201,7 @@ func (mf *Field) TypeDatabaseName() string { } } -func (mf *Field) scanType() reflect.Type { +func (mf *Field) ScanType() reflect.Type { switch mf.fieldType { case mysql.FieldTypeTiny: if mf.flags&mysql.NotNullFlag != 0 { diff --git a/pkg/mysql/fields_test.go b/pkg/mysql/fields_test.go index 1d2651034..fecdeec37 100644 --- a/pkg/mysql/fields_test.go +++ b/pkg/mysql/fields_test.go @@ -36,7 +36,7 @@ func TestTableName(t *testing.T) { func TestDataBaseName(t *testing.T) { field := createDefaultField() - assert.Equal(t, "db_arana", field.DataBaseName()) + assert.Equal(t, "db_arana", field.DatabaseName()) } func TestTypeDatabaseName(t *testing.T) { @@ -83,7 +83,7 @@ func TestTypeDatabaseName(t *testing.T) { } for _, unit := range unitTests { field := createField(unit.field, mysql.Collations[unit.collation]) - assert.Equal(t, unit.expected, field.TypeDatabaseName()) + assert.Equal(t, unit.expected, field.DatabaseTypeName()) } } diff --git a/pkg/mysql/iterator.go b/pkg/mysql/iterator.go deleted file mode 100644 index 27b50c750..000000000 --- a/pkg/mysql/iterator.go +++ /dev/null @@ -1,412 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package mysql - -import ( - "encoding/binary" - "fmt" - "math" - "time" -) - -import ( - "github.com/arana-db/arana/pkg/constants/mysql" - "github.com/arana-db/arana/pkg/mysql/errors" - "github.com/arana-db/arana/pkg/proto" -) - -// Iter is used to iterate output results -type Iter interface { - Next() (bool, error) -} - -// IterRow implementation of Iter -type IterRow struct { - *BackendConnection - *Row - hasNext bool -} - -// TextIterRow is iterator for text protocol result set -type TextIterRow struct { - *IterRow -} - -// BinaryIterRow is iterator for binary protocol result set -type BinaryIterRow struct { - *IterRow -} - -func (iterRow *IterRow) Next() (bool, error) { - // read one row - data, err := iterRow.c.ReadPacket() - if err != nil { - iterRow.hasNext = false - return iterRow.hasNext, err - } - - if isEOFPacket(data) { - iterRow.hasNext = false - // The deprecated EOF packets change means that this is either an - // EOF packet or an OK packet with the EOF type code. - if iterRow.capabilities&mysql.CapabilityClientDeprecateEOF == 0 { - _, _, err = parseEOFPacket(data) - if err != nil { - return iterRow.hasNext, err - } - } else { - _, _, _, _, err = parseOKPacket(data) - if err != nil { - return iterRow.hasNext, err - } - } - return iterRow.hasNext, nil - - } else if isErrorPacket(data) { - // Error packet. - iterRow.hasNext = false - return iterRow.hasNext, ParseErrorPacket(data) - } - - iterRow.Content = data - return iterRow.hasNext, nil -} - -func (iterRow *IterRow) Decode() ([]*proto.Value, error) { - return nil, nil -} - -func (rows *TextIterRow) Decode() ([]*proto.Value, error) { - dest := make([]*proto.Value, len(rows.ResultSet.Columns)) - - // RowSet Packet - var val []byte - var isNull bool - var n int - var err error - pos := 0 - - for i := 0; i < len(rows.ResultSet.Columns); i++ { - field := rows.ResultSet.Columns[i].(*Field) - - // Read bytes and convert to string - val, isNull, n, err = readLengthEncodedString(rows.Content[pos:]) - dest[i] = &proto.Value{ - Typ: field.fieldType, - Flags: field.flags, - Len: n, - Val: val, - Raw: val, - } - pos += n - if err == nil { - if !isNull { - switch field.fieldType { - case mysql.FieldTypeTimestamp, mysql.FieldTypeDateTime, - mysql.FieldTypeDate, mysql.FieldTypeNewDate: - dest[i].Val, err = parseDateTime( - val, - time.Local, - ) - if err == nil { - continue - } - default: - continue - } - } else { - dest[i].Val = nil - continue - } - } - return nil, err // err != nil - } - - return dest, nil -} - -func (rows *BinaryIterRow) Decode() ([]*proto.Value, error) { - dest := make([]*proto.Value, len(rows.ResultSet.Columns)) - - if rows.Content[0] != mysql.OKPacket { - return nil, errors.NewSQLError(mysql.CRMalformedPacket, mysql.SSUnknownSQLState, "read binary rows (%v) failed", rows) - } - - // NULL-bitmap, [(column-count + 7 + 2) / 8 bytes] - pos := 1 + (len(dest)+7+2)>>3 - nullMask := rows.Content[1:pos] - - for i := 0; i < len(rows.ResultSet.Columns); i++ { - // Field is NULL - // (byte >> bit-pos) % 2 == 1 - if ((nullMask[(i+2)>>3] >> uint((i+2)&7)) & 1) == 1 { - dest[i] = nil - continue - } - - field := rows.ResultSet.Columns[i].(*Field) - // Convert to byte-coded string - // TODO Optimize storage space based on the length of data types - mysqlType, _ := mysql.TypeToMySQL(field.fieldType) - switch mysql.FieldType(mysqlType) { - case mysql.FieldTypeNULL: - dest[i] = &proto.Value{ - Typ: field.fieldType, - Flags: field.flags, - Len: 1, - Val: nil, - Raw: nil, - } - continue - - // Numeric Types - case mysql.FieldTypeTiny: - if field.flags&mysql.UnsignedFlag != 0 { - dest[i] = &proto.Value{ - Typ: field.fieldType, - Flags: field.flags, - Len: 1, - Val: int64(rows.Content[pos]), - Raw: rows.Content[pos : pos+1], - } - } else { - dest[i] = &proto.Value{ - Typ: field.fieldType, - Flags: field.flags, - Len: 1, - Val: int64(int8(rows.Content[pos])), - Raw: rows.Content[pos : pos+1], - } - } - pos++ - continue - - case mysql.FieldTypeShort, mysql.FieldTypeYear: - if field.flags&mysql.UnsignedFlag != 0 { - dest[i] = &proto.Value{ - Typ: field.fieldType, - Flags: field.flags, - Len: 2, - Val: int64(binary.LittleEndian.Uint16(rows.Content[pos : pos+2])), - Raw: rows.Content[pos : pos+1], - } - } else { - dest[i] = &proto.Value{ - Typ: field.fieldType, - Flags: field.flags, - Len: 2, - Val: int64(int16(binary.LittleEndian.Uint16(rows.Content[pos : pos+2]))), - Raw: rows.Content[pos : pos+1], - } - } - pos += 2 - continue - - case mysql.FieldTypeInt24, mysql.FieldTypeLong: - if field.flags&mysql.UnsignedFlag != 0 { - dest[i] = &proto.Value{ - Typ: field.fieldType, - Flags: field.flags, - Len: 4, - Val: int64(binary.LittleEndian.Uint32(rows.Content[pos : pos+4])), - Raw: rows.Content[pos : pos+4], - } - } else { - dest[i] = &proto.Value{ - Typ: field.fieldType, - Flags: field.flags, - Len: 4, - Val: int64(int32(binary.LittleEndian.Uint32(rows.Content[pos : pos+4]))), - Raw: rows.Content[pos : pos+4], - } - } - pos += 4 - continue - - case mysql.FieldTypeLongLong: - if field.flags&mysql.UnsignedFlag != 0 { - val := binary.LittleEndian.Uint64(rows.Content[pos : pos+8]) - if val > math.MaxInt64 { - dest[i] = &proto.Value{ - Typ: field.fieldType, - Flags: field.flags, - Len: 8, - Val: uint64ToString(val), - Raw: rows.Content[pos : pos+8], - } - } else { - dest[i] = &proto.Value{ - Typ: field.fieldType, - Flags: field.flags, - Len: 8, - Val: int64(val), - Raw: rows.Content[pos : pos+8], - } - } - } else { - dest[i] = &proto.Value{ - Typ: field.fieldType, - Flags: field.flags, - Len: 8, - Val: int64(binary.LittleEndian.Uint64(rows.Content[pos : pos+8])), - Raw: rows.Content[pos : pos+8], - } - } - pos += 8 - continue - - case mysql.FieldTypeFloat: - dest[i] = &proto.Value{ - Typ: field.fieldType, - Flags: field.flags, - Len: 4, - Val: math.Float32frombits(binary.LittleEndian.Uint32(rows.Content[pos : pos+4])), - Raw: rows.Content[pos : pos+4], - } - pos += 4 - continue - - case mysql.FieldTypeDouble: - dest[i] = &proto.Value{ - Typ: field.fieldType, - Flags: field.flags, - Len: 8, - Val: math.Float64frombits(binary.LittleEndian.Uint64(rows.Content[pos : pos+8])), - Raw: rows.Content[pos : pos+8], - } - pos += 8 - continue - - // Length coded Binary Strings - case mysql.FieldTypeDecimal, mysql.FieldTypeNewDecimal, mysql.FieldTypeVarChar, - mysql.FieldTypeBit, mysql.FieldTypeEnum, mysql.FieldTypeSet, mysql.FieldTypeTinyBLOB, - mysql.FieldTypeMediumBLOB, mysql.FieldTypeLongBLOB, mysql.FieldTypeBLOB, - mysql.FieldTypeVarString, mysql.FieldTypeString, mysql.FieldTypeGeometry, mysql.FieldTypeJSON: - var val interface{} - var isNull bool - var n int - var err error - val, isNull, n, err = readLengthEncodedString(rows.Content[pos:]) - dest[i] = &proto.Value{ - Typ: field.fieldType, - Flags: field.flags, - Len: n, - Val: val, - Raw: rows.Content[pos : pos+n], - } - pos += n - if err == nil { - if !isNull { - continue - } else { - dest[i].Val = nil - continue - } - } - return nil, err - - case - mysql.FieldTypeDate, mysql.FieldTypeNewDate, // Date YYYY-MM-DD - mysql.FieldTypeTime, // Time [-][H]HH:MM:SS[.fractal] - mysql.FieldTypeTimestamp, mysql.FieldTypeDateTime: // Timestamp YYYY-MM-DD HH:MM:SS[.fractal] - - num, isNull, n := readLengthEncodedInteger(rows.Content[pos:]) - pos += n - - var val interface{} - var err error - switch { - case isNull: - dest[i] = nil - continue - case field.fieldType == mysql.FieldTypeTime: - // database/sql does not support an equivalent to TIME, return a string - var dstlen uint8 - switch decimals := field.decimals; decimals { - case 0x00, 0x1f: - dstlen = 8 - case 1, 2, 3, 4, 5, 6: - dstlen = 8 + 1 + decimals - default: - return nil, fmt.Errorf( - "protocol error, illegal decimals architecture.Value %d", - field.decimals, - ) - } - val, err = formatBinaryTime(rows.Content[pos:pos+int(num)], dstlen) - dest[i] = &proto.Value{ - Typ: field.fieldType, - Flags: field.flags, - Len: n, - Val: val, - Raw: rows.Content[pos : pos+n], - } - default: - val, err = parseBinaryDateTime(num, rows.Content[pos:], time.Local) - dest[i] = &proto.Value{ - Typ: field.fieldType, - Flags: field.flags, - Len: int(num), - Val: val, - Raw: rows.Content[pos : pos+int(num)], - } - if err == nil { - break - } - - var dstlen uint8 - if field.fieldType == mysql.FieldTypeDate { - dstlen = 10 - } else { - switch decimals := field.decimals; decimals { - case 0x00, 0x1f: - dstlen = 19 - case 1, 2, 3, 4, 5, 6: - dstlen = 19 + 1 + decimals - default: - return nil, fmt.Errorf( - "protocol error, illegal decimals architecture.Value %d", - field.decimals, - ) - } - } - val, err = formatBinaryDateTime(rows.Content[pos:pos+int(num)], dstlen) - dest[i] = &proto.Value{ - Typ: field.fieldType, - Flags: field.flags, - Len: int(num), - Val: val, - Raw: rows.Content[pos : pos+n], - } - } - - if err == nil { - pos += int(num) - continue - } else { - return nil, err - } - - // Please report if this happens! - default: - return nil, fmt.Errorf("unknown field type %d", field.fieldType) - } - } - - return dest, nil -} diff --git a/pkg/mysql/result.go b/pkg/mysql/result.go index 99d86a9c9..359e96dbb 100644 --- a/pkg/mysql/result.go +++ b/pkg/mysql/result.go @@ -18,33 +18,324 @@ package mysql import ( + "io" + "sync" +) + +import ( + "github.com/pkg/errors" +) + +import ( + "github.com/arana-db/arana/pkg/constants/mysql" + merrors "github.com/arana-db/arana/pkg/mysql/errors" "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/util/log" +) + +const ( + _flagTextMode uint8 = 1 << iota + _flagWantFields ) -type Result struct { - Fields []proto.Field // Columns information - AffectedRows uint64 - InsertId uint64 - Rows []proto.Row - DataChan chan proto.Row +var ( + _ proto.Result = (*RawResult)(nil) + _ proto.Dataset = (*RawResult)(nil) +) + +type RawResult struct { + flag uint8 + + c *BackendConnection + + lastInsertID, affectedRows uint64 + colNumber int + more bool + warnings uint16 + fields []proto.Field + + preflightOnce sync.Once + preflightFailure error + flightOnce sync.Once + flightFailure error + postFlightOnce sync.Once + postFlightFailure error + + closeFunc func() error + closeOnce sync.Once + closeFailure error + + eof bool +} + +func (rr *RawResult) Discard() (err error) { + defer func() { + _ = rr.Close() + }() + + if err = rr.postFlight(); err != nil { + return errors.Wrapf(err, "failed to discard mysql result") + } + + if rr.colNumber < 1 { + return + } + + for { + _, err = rr.nextRowData() + if err != nil { + return + } + } +} + +func (rr *RawResult) Warn() (uint16, error) { + if err := rr.preflight(); err != nil { + return 0, err + } + return rr.warnings, nil +} + +func (rr *RawResult) Close() error { + rr.closeOnce.Do(func() { + if rr.closeFunc == nil { + return + } + rr.closeFailure = rr.closeFunc() + if rr.closeFailure != nil { + log.Errorf("failed to close mysql result: %v", rr.closeFailure) + } + }) + return rr.closeFailure +} + +func (rr *RawResult) Fields() ([]proto.Field, error) { + if err := rr.flight(); err != nil { + return nil, err + } + return rr.fields, nil +} + +func (rr *RawResult) Next() (row proto.Row, err error) { + defer func() { + if err != nil { + _ = rr.Close() + } + }() + + var data []byte + if data, err = rr.nextRowData(); err != nil { + return + } + + if rr.flag&_flagTextMode != 0 { + row = TextRow{ + fields: rr.fields, + raw: data, + } + } else { + row = BinaryRow{ + fields: rr.fields, + raw: data, + } + } + return +} + +func (rr *RawResult) Dataset() (proto.Dataset, error) { + if err := rr.postFlight(); err != nil { + return nil, err + } + + if len(rr.fields) < 1 { + return nil, nil + } + + return rr, nil +} + +func (rr *RawResult) preflight() (err error) { + defer func() { + if err != nil || rr.colNumber < 1 { + _ = rr.Close() + } + }() + + rr.preflightOnce.Do(func() { + rr.affectedRows, rr.lastInsertID, rr.colNumber, rr.more, rr.warnings, rr.preflightFailure = rr.c.readResultSetHeaderPacket() + }) + err = rr.preflightFailure + return +} + +func (rr *RawResult) flight() (err error) { + if err = rr.preflight(); err != nil { + return err + } + + defer func() { + if err != nil { + _ = rr.Close() + } + }() + + rr.flightOnce.Do(func() { + if rr.colNumber < 1 { + return + } + columns := make([]proto.Field, rr.colNumber) + for i := 0; i < rr.colNumber; i++ { + var field Field + + if rr.flag&_flagWantFields != 0 { + rr.flightFailure = rr.c.ReadColumnDefinition(&field, i) + } else { + rr.flightFailure = rr.c.ReadColumnDefinitionType(&field, i) + } + + if rr.flightFailure != nil { + return + } + columns[i] = &field + } + + rr.fields = columns + }) + + err = rr.flightFailure + + return +} + +func (rr *RawResult) postFlight() (err error) { + if err = rr.flight(); err != nil { + return err + } + + defer func() { + if err != nil { + _ = rr.Close() + } + }() + + rr.postFlightOnce.Do(func() { + if rr.colNumber < 1 { + return + } + + if rr.c.capabilities&mysql.CapabilityClientDeprecateEOF != 0 { + return + } + + var data []byte + + // EOF is only present here if it's not deprecated. + if data, rr.postFlightFailure = rr.c.c.readEphemeralPacket(); rr.postFlightFailure != nil { + rr.postFlightFailure = merrors.NewSQLError(mysql.CRServerLost, mysql.SSUnknownSQLState, "%v", rr.postFlightFailure) + return + } + + if isEOFPacket(data) { + // This is what we expect. + // Warnings and status flags are ignored. + rr.c.c.recycleReadPacket() + // goto: read row loop + return + } + + defer rr.c.c.recycleReadPacket() + + if isErrorPacket(data) { + rr.postFlightFailure = ParseErrorPacket(data) + } else { + rr.postFlightFailure = errors.Errorf("unexpected packet after fields: %v", data) + } + }) + + err = rr.postFlightFailure + + return +} + +func (rr *RawResult) LastInsertId() (uint64, error) { + if err := rr.preflight(); err != nil { + return 0, err + } + return rr.lastInsertID, nil +} + +func (rr *RawResult) RowsAffected() (uint64, error) { + if err := rr.preflight(); err != nil { + return 0, err + } + return rr.affectedRows, nil +} + +func (rr *RawResult) nextRowData() (data []byte, err error) { + if rr.eof { + err = io.EOF + return + } + + if data, err = rr.c.c.readPacket(); err != nil { + return + } + + switch { + case isEOFPacket(data): + rr.eof = true + // The deprecated EOF packets change means that this is either an + // EOF packet or an OK packet with the EOF type code. + if rr.c.capabilities&mysql.CapabilityClientDeprecateEOF == 0 { + if _, rr.more, err = parseEOFPacket(data); err != nil { + return + } + } else { + var statusFlags uint16 + if _, _, statusFlags, _, err = parseOKPacket(data); err != nil { + return + } + rr.more = statusFlags&mysql.ServerMoreResultsExists != 0 + } + data = nil + err = io.EOF + case isErrorPacket(data): + rr.eof = true + data = nil + err = ParseErrorPacket(data) + } + + // TODO: Check we're not over the limit before we add more. + //if len(result.Rows) == maxrows { + // if err := conn.DrainResults(); err != nil { + // return nil, false, 0, err + // } + // return nil, false, 0, err2.NewSQLError(mysql.ERVitessMaxRowsExceeded, mysql.SSUnknownSQLState, "Row count exceeded %d") + //} + + return } -func (res *Result) GetFields() []proto.Field { - return res.Fields +func (rr *RawResult) setTextProtocol() { + rr.flag |= _flagTextMode } -func (res *Result) GetRows() []proto.Row { - return res.Rows +func (rr *RawResult) setBinaryProtocol() { + rr.flag &= ^_flagTextMode } -func (res *Result) LastInsertId() (uint64, error) { - return res.InsertId, nil +func (rr *RawResult) setWantFields(b bool) { + if b { + rr.flag |= _flagWantFields + } else { + rr.flag &= ^_flagWantFields + } } -func (res *Result) RowsAffected() (uint64, error) { - return res.AffectedRows, nil +func (rr *RawResult) SetCloser(closer func() error) { + rr.closeFunc = closer } -func (res *Result) GetDataChan() chan proto.Row { - return res.DataChan +func newResult(c *BackendConnection) *RawResult { + return &RawResult{c: c} } diff --git a/pkg/mysql/rows.go b/pkg/mysql/rows.go index ecf3b3875..edd02cc5c 100644 --- a/pkg/mysql/rows.go +++ b/pkg/mysql/rows.go @@ -18,243 +18,92 @@ package mysql import ( - "bytes" "encoding/binary" - "fmt" + "io" "math" + "strconv" "time" ) +import ( + "github.com/pkg/errors" +) + import ( "github.com/arana-db/arana/pkg/constants/mysql" - "github.com/arana-db/arana/pkg/mysql/errors" + mysqlErrors "github.com/arana-db/arana/pkg/mysql/errors" "github.com/arana-db/arana/pkg/proto" - "github.com/arana-db/arana/pkg/util/log" ) -type ResultSet struct { - Columns []proto.Field - ColumnNames []string -} - -type Row struct { - Content []byte - ResultSet *ResultSet -} +var ( + _ proto.KeyedRow = (*BinaryRow)(nil) + _ proto.KeyedRow = (*TextRow)(nil) +) type BinaryRow struct { - Row + fields []proto.Field + raw []byte } -type TextRow struct { - Row -} - -func (row *Row) Columns() []string { - if row.ResultSet.ColumnNames != nil { - return row.ResultSet.ColumnNames - } - - columns := make([]string, len(row.ResultSet.Columns)) - if row.Content != nil { - for i := range columns { - field := row.ResultSet.Columns[i].(*Field) - if tableName := field.table; len(tableName) > 0 { - columns[i] = tableName + "." + field.name - } else { - columns[i] = field.name - } - } - } else { - for i := range columns { - field := row.ResultSet.Columns[i].(*Field) - columns[i] = field.name +func (bi BinaryRow) Get(name string) (proto.Value, error) { + idx := -1 + for i, it := range bi.fields { + if it.Name() == name { + idx = i + break } } - - row.ResultSet.ColumnNames = columns - return columns -} - -func (row *Row) Fields() []proto.Field { - return row.ResultSet.Columns -} - -func (row *Row) Data() []byte { - return row.Content -} - -func (row *Row) Encode(values []*proto.Value, columns []proto.Field, columnNames []string) proto.Row { - var bf bytes.Buffer - row.ResultSet = &ResultSet{ - Columns: columns, - ColumnNames: columnNames, + if idx == -1 { + return nil, errors.Errorf("no such field '%s' found", name) } - for _, val := range values { - bf.Write(val.Raw) + dest := make([]proto.Value, len(bi.fields)) + if err := bi.Scan(dest); err != nil { + return nil, errors.WithStack(err) } - row.Content = bf.Bytes() - return row -} -func (row *Row) Decode() ([]*proto.Value, error) { - return nil, nil + return dest[idx], nil } -func (row *Row) GetColumnValue(column string) (interface{}, error) { - values, err := row.Decode() - if err != nil { - return nil, err - } - for _, value := range values { - if string(value.Raw) == column { - return value.Val, nil - } - } - return nil, nil +func (bi BinaryRow) Fields() []proto.Field { + return bi.fields } -func (rows *TextRow) Encode(row []*proto.Value, fields []proto.Field, columnNames []string) proto.Row { - var val []byte - - for i := 0; i < len(fields); i++ { - field := fields[i].(*Field) - switch field.fieldType { - case mysql.FieldTypeTimestamp, mysql.FieldTypeDateTime, - mysql.FieldTypeDate, mysql.FieldTypeNewDate: - data, err := appendDateTime(row[i].Raw, row[i].Val.(time.Time)) - if err != nil { - log.Errorf("appendDateTime fail, val=%+v, err=%v", &row[i], err) - return nil - } - val = append(val, data...) - default: - val = PutLengthEncodedString(row[i].Raw) - } - } - - rows.ResultSet = &ResultSet{ - Columns: fields, - ColumnNames: columnNames, +func NewBinaryRow(fields []proto.Field, raw []byte) BinaryRow { + return BinaryRow{ + fields: fields, + raw: raw, } - - rows.Content = val - return rows } -func (rows *TextRow) Decode() ([]*proto.Value, error) { - dest := make([]*proto.Value, len(rows.ResultSet.Columns)) - - // RowSet Packet - var val []byte - var isNull bool - var n int - var err error - pos := 0 - - for i := 0; i < len(rows.ResultSet.Columns); i++ { - field := rows.ResultSet.Columns[i].(*Field) - - // Read bytes and convert to string - val, isNull, n, err = readLengthEncodedString(rows.Content[pos:]) - dest[i] = &proto.Value{ - Typ: field.fieldType, - Flags: field.flags, - Len: n, - Val: val, - Raw: val, - } - pos += n - if err == nil { - if !isNull { - switch field.fieldType { - case mysql.FieldTypeTimestamp, mysql.FieldTypeDateTime, - mysql.FieldTypeDate, mysql.FieldTypeNewDate: - dest[i].Val, err = parseDateTime( - val, - time.Local, - ) - if err == nil { - continue - } - default: - continue - } - } else { - dest[i].Val = nil - continue - } - } - return nil, err // err != nil - } - - return dest, nil +func (bi BinaryRow) IsBinary() bool { + return true } -func (rows *BinaryRow) Encode(row []*proto.Value, fields []proto.Field, columnNames []string) proto.Row { - length := 0 - nullBitMapLen := (len(fields) + 7 + 2) / 8 - for _, val := range row { - if val != nil && val.Val != nil { - l, err := val2MySQLLen(val) - if err != nil { - return nil - } - length += l - } - } - - length += nullBitMapLen + 1 - - Data := *bufPool.Get(length) - pos := 0 - - pos = writeByte(Data, pos, 0x00) - - for i := 0; i < nullBitMapLen; i++ { - pos = writeByte(Data, pos, 0x00) - } - - for i, val := range row { - if val == nil || val.Val == nil { - bytePos := (i+2)/8 + 1 - bitPos := (i + 2) % 8 - Data[bytePos] |= 1 << uint(bitPos) - } else { - v, err := val2MySQL(val) - if err != nil { - return nil - } - pos += copy(Data[pos:], v) - } - } - - if pos != length { - return nil - } +func (bi BinaryRow) Length() int { + return len(bi.raw) +} - rows.ResultSet = &ResultSet{ - Columns: fields, - ColumnNames: columnNames, +func (bi BinaryRow) WriteTo(w io.Writer) (n int64, err error) { + var wrote int + wrote, err = w.Write(bi.raw) + if err != nil { + return } - - rows.Content = Data - return rows + n += int64(wrote) + return } -func (rows *BinaryRow) Decode() ([]*proto.Value, error) { - dest := make([]*proto.Value, len(rows.ResultSet.Columns)) - - if rows.Content[0] != mysql.OKPacket { - return nil, errors.NewSQLError(mysql.CRMalformedPacket, mysql.SSUnknownSQLState, "read binary rows (%v) failed", rows) +func (bi BinaryRow) Scan(dest []proto.Value) error { + if bi.raw[0] != mysql.OKPacket { + return mysqlErrors.NewSQLError(mysql.CRMalformedPacket, mysql.SSUnknownSQLState, "read binary rows (%v) failed", bi) } // NULL-bitmap, [(column-count + 7 + 2) / 8 bytes] pos := 1 + (len(dest)+7+2)>>3 - nullMask := rows.Content[1:pos] + nullMask := bi.raw[1:pos] - for i := 0; i < len(rows.ResultSet.Columns); i++ { + for i := 0; i < len(bi.fields); i++ { // Field is NULL // (byte >> bit-pos) % 2 == 1 if ((nullMask[(i+2)>>3] >> uint((i+2)&7)) & 1) == 1 { @@ -262,134 +111,64 @@ func (rows *BinaryRow) Decode() ([]*proto.Value, error) { continue } - field := rows.ResultSet.Columns[i].(*Field) + field := bi.fields[i].(*Field) // Convert to byte-coded string - switch field.fieldType { + // TODO Optimize storage space based on the length of data types + mysqlType, _ := mysql.TypeToMySQL(field.fieldType) + switch mysql.FieldType(mysqlType) { case mysql.FieldTypeNULL: - dest[i] = &proto.Value{ - Typ: field.fieldType, - Flags: field.flags, - Len: 1, - Val: nil, - Raw: nil, - } + dest[i] = nil continue // Numeric Types case mysql.FieldTypeTiny: if field.flags&mysql.UnsignedFlag != 0 { - dest[i] = &proto.Value{ - Typ: field.fieldType, - Flags: field.flags, - Len: 1, - Val: int64(rows.Content[pos]), - Raw: rows.Content[pos : pos+1], - } + dest[i] = int64(bi.raw[pos]) } else { - dest[i] = &proto.Value{ - Typ: field.fieldType, - Flags: field.flags, - Len: 1, - Val: int64(int8(rows.Content[pos])), - Raw: rows.Content[pos : pos+1], - } + dest[i] = int64(int8(bi.raw[pos])) } pos++ continue case mysql.FieldTypeShort, mysql.FieldTypeYear: if field.flags&mysql.UnsignedFlag != 0 { - dest[i] = &proto.Value{ - Typ: field.fieldType, - Flags: field.flags, - Len: 2, - Val: int64(binary.LittleEndian.Uint16(rows.Content[pos : pos+2])), - Raw: rows.Content[pos : pos+1], - } + dest[i] = int64(binary.LittleEndian.Uint16(bi.raw[pos : pos+2])) } else { - dest[i] = &proto.Value{ - Typ: field.fieldType, - Flags: field.flags, - Len: 2, - Val: int64(int16(binary.LittleEndian.Uint16(rows.Content[pos : pos+2]))), - Raw: rows.Content[pos : pos+1], - } + dest[i] = int64(int16(binary.LittleEndian.Uint16(bi.raw[pos : pos+2]))) } pos += 2 continue case mysql.FieldTypeInt24, mysql.FieldTypeLong: if field.flags&mysql.UnsignedFlag != 0 { - dest[i] = &proto.Value{ - Typ: field.fieldType, - Flags: field.flags, - Len: 4, - Val: int64(binary.LittleEndian.Uint32(rows.Content[pos : pos+4])), - Raw: rows.Content[pos : pos+4], - } + dest[i] = int64(binary.LittleEndian.Uint32(bi.raw[pos : pos+4])) } else { - dest[i] = &proto.Value{ - Typ: field.fieldType, - Flags: field.flags, - Len: 4, - Val: int64(int32(binary.LittleEndian.Uint32(rows.Content[pos : pos+4]))), - Raw: rows.Content[pos : pos+4], - } + dest[i] = int64(int32(binary.LittleEndian.Uint32(bi.raw[pos : pos+4]))) } pos += 4 continue case mysql.FieldTypeLongLong: if field.flags&mysql.UnsignedFlag != 0 { - val := binary.LittleEndian.Uint64(rows.Content[pos : pos+8]) + val := binary.LittleEndian.Uint64(bi.raw[pos : pos+8]) if val > math.MaxInt64 { - dest[i] = &proto.Value{ - Typ: field.fieldType, - Flags: field.flags, - Len: 8, - Val: uint64ToString(val), - Raw: rows.Content[pos : pos+8], - } + dest[i] = uint64ToString(val) } else { - dest[i] = &proto.Value{ - Typ: field.fieldType, - Flags: field.flags, - Len: 8, - Val: int64(val), - Raw: rows.Content[pos : pos+8], - } + dest[i] = int64(val) } } else { - dest[i] = &proto.Value{ - Typ: field.fieldType, - Flags: field.flags, - Len: 8, - Val: int64(binary.LittleEndian.Uint64(rows.Content[pos : pos+8])), - Raw: rows.Content[pos : pos+8], - } + dest[i] = int64(binary.LittleEndian.Uint64(bi.raw[pos : pos+8])) } pos += 8 continue case mysql.FieldTypeFloat: - dest[i] = &proto.Value{ - Typ: field.fieldType, - Flags: field.flags, - Len: 4, - Val: math.Float32frombits(binary.LittleEndian.Uint32(rows.Content[pos : pos+4])), - Raw: rows.Content[pos : pos+4], - } + dest[i] = math.Float32frombits(binary.LittleEndian.Uint32(bi.raw[pos : pos+4])) pos += 4 continue case mysql.FieldTypeDouble: - dest[i] = &proto.Value{ - Typ: field.fieldType, - Flags: field.flags, - Len: 8, - Val: math.Float64frombits(binary.LittleEndian.Uint64(rows.Content[pos : pos+8])), - Raw: rows.Content[pos : pos+8], - } + dest[i] = math.Float64frombits(binary.LittleEndian.Uint64(bi.raw[pos : pos+8])) pos += 8 continue @@ -402,31 +181,25 @@ func (rows *BinaryRow) Decode() ([]*proto.Value, error) { var isNull bool var n int var err error - val, isNull, n, err = readLengthEncodedString(rows.Content[pos:]) - dest[i] = &proto.Value{ - Typ: field.fieldType, - Flags: field.flags, - Len: n, - Val: val, - Raw: rows.Content[pos : pos+n], - } + val, isNull, n, err = readLengthEncodedString(bi.raw[pos:]) + dest[i] = val pos += n if err == nil { if !isNull { continue } else { - dest[i].Val = nil + dest[i] = nil continue } } - return nil, err + return err case mysql.FieldTypeDate, mysql.FieldTypeNewDate, // Date YYYY-MM-DD mysql.FieldTypeTime, // Time [-][H]HH:MM:SS[.fractal] mysql.FieldTypeTimestamp, mysql.FieldTypeDateTime: // Timestamp YYYY-MM-DD HH:MM:SS[.fractal] - num, isNull, n := readLengthEncodedInteger(rows.Content[pos:]) + num, isNull, n := readLengthEncodedInteger(bi.raw[pos:]) pos += n var val interface{} @@ -444,29 +217,14 @@ func (rows *BinaryRow) Decode() ([]*proto.Value, error) { case 1, 2, 3, 4, 5, 6: dstlen = 8 + 1 + decimals default: - return nil, fmt.Errorf( - "protocol error, illegal decimals architecture.Value %d", - field.decimals, - ) - } - val, err = formatBinaryTime(rows.Content[pos:pos+int(num)], dstlen) - dest[i] = &proto.Value{ - Typ: field.fieldType, - Flags: field.flags, - Len: n, - Val: val, - Raw: rows.Content[pos : pos+n], + return errors.Errorf("protocol error, illegal decimals architecture.V %d", field.decimals) } + val, err = formatBinaryTime(bi.raw[pos:pos+int(num)], dstlen) + dest[i] = val default: - val, err = parseBinaryDateTime(num, rows.Content[pos:], time.Local) - dest[i] = &proto.Value{ - Typ: field.fieldType, - Flags: field.flags, - Len: n, - Val: val, - Raw: rows.Content[pos : pos+n], - } + val, err = parseBinaryDateTime(num, bi.raw[pos:], time.Local) if err == nil { + dest[i] = val break } @@ -480,34 +238,141 @@ func (rows *BinaryRow) Decode() ([]*proto.Value, error) { case 1, 2, 3, 4, 5, 6: dstlen = 19 + 1 + decimals default: - return nil, fmt.Errorf( - "protocol error, illegal decimals architecture.Value %d", - field.decimals, - ) + return errors.Errorf("protocol error, illegal decimals architecture.V %d", field.decimals) } } - val, err = formatBinaryDateTime(rows.Content[pos:pos+int(num)], dstlen) - dest[i] = &proto.Value{ - Typ: field.fieldType, - Flags: field.flags, - Len: n, - Val: val, - Raw: rows.Content[pos : pos+n], + val, err = formatBinaryDateTime(bi.raw[pos:pos+int(num)], dstlen) + if err != nil { + return errors.WithStack(err) } + dest[i] = val } if err == nil { pos += int(num) continue - } else { - return nil, err } + return err + // Please report if this happens! default: - return nil, fmt.Errorf("unknown field type %d", field.fieldType) + return errors.Errorf("unknown field type %d", field.fieldType) + } + } + + return nil +} + +type TextRow struct { + fields []proto.Field + raw []byte +} + +func (te TextRow) Get(name string) (proto.Value, error) { + idx := -1 + for i, it := range te.fields { + if it.Name() == name { + idx = i + break + } + } + if idx == -1 { + return nil, errors.Errorf("no such field '%s' found", name) + } + + dest := make([]proto.Value, len(te.fields)) + if err := te.Scan(dest); err != nil { + return nil, errors.WithStack(err) + } + + return dest[idx], nil +} + +func (te TextRow) Fields() []proto.Field { + return te.fields +} + +func NewTextRow(fields []proto.Field, raw []byte) TextRow { + return TextRow{ + fields: fields, + raw: raw, + } +} + +func (te TextRow) IsBinary() bool { + return false +} + +func (te TextRow) Length() int { + return len(te.raw) +} + +func (te TextRow) WriteTo(w io.Writer) (n int64, err error) { + var wrote int + wrote, err = w.Write(te.raw) + if err != nil { + return + } + n += int64(wrote) + return +} + +func (te TextRow) Scan(dest []proto.Value) error { + // RowSet Packet + // var val []byte + // var isNull bool + var ( + n int + err error + pos int + isNull bool + ) + + // TODO: support parseTime + + var ( + b []byte + loc = time.Local + ) + + for i := 0; i < len(te.fields); i++ { + b, isNull, n, err = readLengthEncodedString(te.raw[pos:]) + if err != nil { + return errors.WithStack(err) + } + pos += n + + if isNull { + dest[i] = nil + continue + } + + switch te.fields[i].(*Field).fieldType { + case mysql.FieldTypeString, mysql.FieldTypeVarString, mysql.FieldTypeVarChar: + dest[i] = string(b) + case mysql.FieldTypeTiny, mysql.FieldTypeShort, mysql.FieldTypeLong, + mysql.FieldTypeInt24, mysql.FieldTypeLongLong, mysql.FieldTypeYear: + if te.fields[i].(*Field).flags&mysql.UnsignedFlag > 0 { + dest[i], err = strconv.ParseUint(string(b), 10, 64) + } else { + dest[i], err = strconv.ParseInt(string(b), 10, 64) + } + if err != nil { + return errors.WithStack(err) + } + case mysql.FieldTypeFloat, mysql.FieldTypeDouble, mysql.FieldTypeNewDecimal, mysql.FieldTypeDecimal: + if dest[i], err = strconv.ParseFloat(string(b), 64); err != nil { + return errors.WithStack(err) + } + case mysql.FieldTypeTimestamp, mysql.FieldTypeDateTime, mysql.FieldTypeDate, mysql.FieldTypeNewDate: + if dest[i], err = parseDateTime(b, loc); err != nil { + return errors.WithStack(err) + } + default: + dest[i] = b } } - return dest, nil + return nil } diff --git a/pkg/mysql/rows/codec.go b/pkg/mysql/rows/codec.go new file mode 100644 index 000000000..2dfdb8456 --- /dev/null +++ b/pkg/mysql/rows/codec.go @@ -0,0 +1,395 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package rows + +import ( + "bytes" + "math" + "time" +) + +import ( + "github.com/pkg/errors" +) + +import ( + "github.com/arana-db/arana/pkg/util/bytesconv" +) + +const _day = time.Hour * 24 + +type ValueWriter interface { + WriteString(s string) (int64, error) + WriteUint64(v uint64) (int64, error) + WriteUint32(v uint32) (int64, error) + WriteUint16(v uint16) (int64, error) + WriteUint8(v uint8) (int64, error) + WriteFloat64(f float64) (int64, error) + WriteFloat32(f float32) (int64, error) + WriteDate(t time.Time) (int64, error) + WriteDateTime(t time.Time) (int64, error) + WriteDuration(d time.Duration) (int64, error) +} + +var _ ValueWriter = (*BinaryValueWriter)(nil) + +type BinaryValueWriter bytes.Buffer + +func (bw *BinaryValueWriter) Bytes() []byte { + return (*bytes.Buffer)(bw).Bytes() +} + +func (bw *BinaryValueWriter) Reset() { + (*bytes.Buffer)(bw).Reset() +} + +func (bw *BinaryValueWriter) WriteDuration(d time.Duration) (n int64, err error) { + b := bw.buffer() + if d == 0 { + if err = b.WriteByte(0); err != nil { + return + } + n++ + return + } + + var ( + length uint8 + neg byte + ) + + if d < 0 { + neg = 1 + d *= -1 + } + + var ( + days = uint32(d / _day) + hours = uint8(d % _day / time.Hour) + minutes = uint8(d % time.Hour / time.Minute) + secs = uint8(d % time.Minute / time.Second) + msec = uint32(d % time.Millisecond / time.Microsecond) + ) + + if msec == 0 { + length = 8 + } else { + length = 12 + } + + if err = b.WriteByte(length); err != nil { + return + } + if err = b.WriteByte(neg); err != nil { + return + } + if _, err = bw.WriteUint32(days); err != nil { + return + } + if err = b.WriteByte(hours); err != nil { + return + } + if err = b.WriteByte(minutes); err != nil { + return + } + if err = b.WriteByte(secs); err != nil { + return + } + + if length == 12 { + if _, err = bw.WriteUint32(msec); err != nil { + return + } + } + + n = int64(length) + 1 + + return +} + +func (bw *BinaryValueWriter) WriteDate(t time.Time) (int64, error) { + if t.IsZero() { + return bw.writeTimeN(t, 0) + } + return bw.writeTimeN(t, 4) +} + +func (bw *BinaryValueWriter) WriteDateTime(t time.Time) (int64, error) { + if t.IsZero() { + return bw.writeTimeN(t, 0) + } + + if t.Nanosecond()/int(time.Microsecond) == 0 { + return bw.writeTimeN(t, 7) + } + + return bw.writeTimeN(t, 11) +} + +func (bw *BinaryValueWriter) writeTimeN(t time.Time, l int) (n int64, err error) { + var ( + year = uint16(t.Year()) + month = uint8(t.Month()) + day = uint8(t.Day()) + hour = uint8(t.Hour()) + minute = uint8(t.Minute()) + sec = uint8(t.Second()) + ) + + if _, err = bw.WriteUint8(byte(l)); err != nil { + return + } + + switch l { + case 0: + case 4: + if _, err = bw.WriteUint16(year); err != nil { + return + } + if _, err = bw.WriteUint8(month); err != nil { + return + } + if _, err = bw.WriteUint8(day); err != nil { + return + } + case 7: + if _, err = bw.WriteUint16(year); err != nil { + return + } + if _, err = bw.WriteUint8(month); err != nil { + return + } + if _, err = bw.WriteUint8(day); err != nil { + return + } + if _, err = bw.WriteUint8(hour); err != nil { + return + } + if _, err = bw.WriteUint8(minute); err != nil { + return + } + if _, err = bw.WriteUint8(sec); err != nil { + return + } + case 11: + if _, err = bw.WriteUint16(year); err != nil { + return + } + if _, err = bw.WriteUint8(month); err != nil { + return + } + if _, err = bw.WriteUint8(day); err != nil { + return + } + if _, err = bw.WriteUint8(hour); err != nil { + return + } + if _, err = bw.WriteUint8(minute); err != nil { + return + } + if _, err = bw.WriteUint8(sec); err != nil { + return + } + msec := uint32(int64(t.Nanosecond()) / int64(time.Microsecond)) + if _, err = bw.WriteUint32(msec); err != nil { + return + } + default: + err = errors.Errorf("illegal time length %d", l) + return + } + + n = int64(l) + 1 + return +} + +func (bw *BinaryValueWriter) WriteFloat32(f float32) (int64, error) { + n := math.Float32bits(f) + return bw.WriteUint32(n) +} + +func (bw *BinaryValueWriter) WriteFloat64(f float64) (int64, error) { + n := math.Float64bits(f) + return bw.WriteUint64(n) +} + +func (bw *BinaryValueWriter) WriteUint8(v uint8) (int64, error) { + if err := bw.buffer().WriteByte(v); err != nil { + return 0, err + } + return 1, nil +} + +func (bw *BinaryValueWriter) WriteString(s string) (int64, error) { + return bw.writeLenEncString(s) +} + +func (bw *BinaryValueWriter) WriteUint16(v uint16) (n int64, err error) { + if err = bw.buffer().WriteByte(byte(v)); err != nil { + return + } + n++ + if err = bw.buffer().WriteByte(byte(v >> 8)); err != nil { + return + } + n++ + return +} + +func (bw *BinaryValueWriter) WriteUint32(v uint32) (n int64, err error) { + b := bw.buffer() + if err = b.WriteByte(byte(v)); err != nil { + return + } + if err = b.WriteByte(byte(v >> 8)); err != nil { + return + } + if err = b.WriteByte(byte(v >> 16)); err != nil { + return + } + if err = b.WriteByte(byte(v >> 24)); err != nil { + return + } + n += 4 + return +} + +func (bw *BinaryValueWriter) WriteUint64(v uint64) (n int64, err error) { + b := bw.buffer() + if err = b.WriteByte(byte(v)); err != nil { + return + } + if err = b.WriteByte(byte(v >> 8)); err != nil { + return + } + if err = b.WriteByte(byte(v >> 16)); err != nil { + return + } + if err = b.WriteByte(byte(v >> 24)); err != nil { + return + } + if err = b.WriteByte(byte(v >> 32)); err != nil { + return + } + if err = b.WriteByte(byte(v >> 40)); err != nil { + return + } + if err = b.WriteByte(byte(v >> 48)); err != nil { + return + } + if err = b.WriteByte(byte(v >> 56)); err != nil { + return + } + n += 8 + return +} + +func (bw *BinaryValueWriter) writeLenEncString(s string) (n int64, err error) { + var wrote int64 + + if wrote, err = bw.writeLenEncInt(uint64(len(s))); err != nil { + return + } + n += wrote + + if wrote, err = bw.writeEOFString(s); err != nil { + return + } + n += wrote + + return +} + +func (bw *BinaryValueWriter) writeEOFString(s string) (n int64, err error) { + var wrote int + if wrote, err = bw.buffer().Write(bytesconv.StringToBytes(s)); err != nil { + return + } + n += int64(wrote) + return +} + +func (bw *BinaryValueWriter) writeLenEncInt(i uint64) (n int64, err error) { + b := bw.buffer() + switch { + case i < 251: + if err = b.WriteByte(byte(i)); err == nil { + n++ + } + case i < 1<<16: + if err = b.WriteByte(0xfc); err != nil { + return + } + if err = b.WriteByte(byte(i)); err != nil { + return + } + if err = b.WriteByte(byte(i >> 8)); err != nil { + return + } + n += 3 + case i < 1<<24: + if err = b.WriteByte(0xfd); err != nil { + return + } + if err = b.WriteByte(byte(i)); err != nil { + return + } + if err = b.WriteByte(byte(i >> 8)); err != nil { + return + } + if err = b.WriteByte(byte(i >> 16)); err != nil { + return + } + n += 4 + default: + if err = b.WriteByte(0xfe); err != nil { + return + } + if err = b.WriteByte(byte(i)); err != nil { + return + } + if err = b.WriteByte(byte(i >> 8)); err != nil { + return + } + if err = b.WriteByte(byte(i >> 16)); err != nil { + return + } + if err = b.WriteByte(byte(i >> 24)); err != nil { + return + } + if err = b.WriteByte(byte(i >> 32)); err != nil { + return + } + if err = b.WriteByte(byte(i >> 40)); err != nil { + return + } + if err = b.WriteByte(byte(i >> 48)); err != nil { + return + } + if err = b.WriteByte(byte(i >> 56)); err != nil { + return + } + n += 9 + } + + return +} + +func (bw *BinaryValueWriter) buffer() *bytes.Buffer { + return (*bytes.Buffer)(bw) +} diff --git a/pkg/mysql/rows/codec_test.go b/pkg/mysql/rows/codec_test.go new file mode 100644 index 000000000..ed3eab371 --- /dev/null +++ b/pkg/mysql/rows/codec_test.go @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package rows + +import ( + "testing" + "time" +) + +import ( + "github.com/stretchr/testify/assert" +) + +func TestBinaryValueWriter(t *testing.T) { + var w BinaryValueWriter + + n, err := w.WriteString("foo") + assert.NoError(t, err) + assert.Equal(t, int64(4), n) + assert.Equal(t, []byte{0x03, 0x66, 0x6f, 0x6f}, w.Bytes()) + + w.Reset() + n, err = w.WriteUint64(1) + assert.NoError(t, err) + assert.Equal(t, int64(8), n) + assert.Equal(t, []byte{0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, w.Bytes()) + + w.Reset() + n, err = w.WriteUint32(1) + assert.NoError(t, err) + assert.Equal(t, int64(4), n) + assert.Equal(t, []byte{0x01, 0x00, 0x00, 0x00}, w.Bytes()) + + w.Reset() + n, err = w.WriteUint16(1) + assert.NoError(t, err) + assert.Equal(t, int64(2), n) + assert.Equal(t, []byte{0x01, 0x00}, w.Bytes()) + + w.Reset() + n, err = w.WriteFloat64(10.2) + assert.NoError(t, err) + assert.Equal(t, int64(8), n) + assert.Equal(t, []byte{0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x24, 0x40}, w.Bytes()) + + w.Reset() + n, err = w.WriteFloat32(10.2) + assert.NoError(t, err) + assert.Equal(t, int64(4), n) + assert.Equal(t, []byte{0x33, 0x33, 0x23, 0x41}, w.Bytes()) + + // -120d 19:27:30.000 001 + w.Reset() + n, err = w.WriteDuration(-(120*24*time.Hour + 19*time.Hour + 27*time.Minute + 30*time.Second + 1*time.Microsecond)) + assert.NoError(t, err) + assert.Equal(t, int64(13), n) + assert.Equal(t, []byte{0x0c, 0x01, 0x78, 0x00, 0x00, 0x00, 0x13, 0x1b, 0x1e, 0x01, 0x00, 0x00, 0x00}, w.Bytes()) + + // -120d 19:27:30 + w.Reset() + n, err = w.WriteDuration(-(120*24*time.Hour + 19*time.Hour + 27*time.Minute + 30*time.Second)) + assert.NoError(t, err) + assert.Equal(t, int64(9), n) + assert.Equal(t, []byte{0x08, 0x01, 0x78, 0x00, 0x00, 0x00, 0x13, 0x1b, 0x1e}, w.Bytes()) + + // 0d 00:00:00 + w.Reset() + n, err = w.WriteDuration(0) + assert.NoError(t, err) + assert.Equal(t, int64(1), n) + assert.Equal(t, []byte{0x00}, w.Bytes()) +} diff --git a/pkg/mysql/rows/virtual_row.go b/pkg/mysql/rows/virtual_row.go new file mode 100644 index 000000000..67198fc9c --- /dev/null +++ b/pkg/mysql/rows/virtual_row.go @@ -0,0 +1,289 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package rows + +import ( + "errors" + "fmt" + "io" + "sync" + "time" +) + +import ( + gxbig "github.com/dubbogo/gost/math/big" + + perrors "github.com/pkg/errors" +) + +import ( + consts "github.com/arana-db/arana/pkg/constants/mysql" + "github.com/arana-db/arana/pkg/mysql" + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/util/bufferpool" + "github.com/arana-db/arana/pkg/util/bytesconv" +) + +var ( + _ VirtualRow = (*binaryVirtualRow)(nil) + _ VirtualRow = (*textVirtualRow)(nil) +) + +var errScanDiffSize = errors.New("cannot scan to dest with different length") + +// IsScanDiffSizeErr returns true if target error is caused by scanning values with different size. +func IsScanDiffSizeErr(err error) bool { + return perrors.Is(err, errScanDiffSize) +} + +// VirtualRow represents virtual row which is created manually. +type VirtualRow interface { + proto.KeyedRow + // Values returns all values of current row. + Values() []proto.Value +} + +type baseVirtualRow struct { + lengthOnce sync.Once + length int64 + + fields []proto.Field + cells []proto.Value +} + +func (b *baseVirtualRow) Get(name string) (proto.Value, error) { + idx := -1 + for i, it := range b.fields { + if it.Name() == name { + idx = i + break + } + } + + if idx == -1 { + return nil, perrors.Errorf("no such field '%s' found", name) + } + + return b.cells[idx], nil +} + +func (b *baseVirtualRow) Scan(dest []proto.Value) error { + if len(dest) != len(b.cells) { + return perrors.WithStack(errScanDiffSize) + } + copy(dest, b.cells) + return nil +} + +func (b *baseVirtualRow) Values() []proto.Value { + return b.cells +} + +type binaryVirtualRow baseVirtualRow + +func (vi *binaryVirtualRow) Get(name string) (proto.Value, error) { + return (*baseVirtualRow)(vi).Get(name) +} + +func (vi *binaryVirtualRow) Fields() []proto.Field { + return vi.fields +} + +func (vi *binaryVirtualRow) Values() []proto.Value { + return (*baseVirtualRow)(vi).Values() +} + +func (vi *binaryVirtualRow) IsBinary() bool { + return true +} + +func (vi *binaryVirtualRow) WriteTo(w io.Writer) (n int64, err error) { + // https://dev.mysql.com/doc/internals/en/null-bitmap.html + + // NULL-bitmap, [(column-count + 7 + 2) / 8 bytes] + const offset = 2 + nullBitmapLen := (len(vi.cells) + 7 + offset) >> 3 + + b := bufferpool.Get() + defer bufferpool.Put(b) + + b.Grow(1 + nullBitmapLen) + + b.WriteByte(0x00) + for i := 0; i < nullBitmapLen; i++ { + b.WriteByte(0x00) + } + + bw := (*BinaryValueWriter)(b) + + var next proto.Value + for i := 0; i < len(vi.cells); i++ { + next = vi.cells[i] + if next == nil { + var ( + bytePos = (i+2)/8 + 1 + bitPos = (i + 2) % 8 + ) + b.Bytes()[bytePos] |= 1 << bitPos + } + + switch val := next.(type) { + case uint64: + _, err = bw.WriteUint64(val) + case int64: + _, err = bw.WriteUint64(uint64(val)) + case uint32: + _, err = bw.WriteUint32(val) + case int32: + _, err = bw.WriteUint32(uint32(val)) + case uint16: + _, err = bw.WriteUint16(val) + case int16: + _, err = bw.WriteUint16(uint16(val)) + case string: + _, err = bw.WriteString(val) + case float32: + _, err = bw.WriteFloat32(val) + case float64: + _, err = bw.WriteFloat64(val) + case time.Time: + _, err = bw.WriteDateTime(val) + case time.Duration: + _, err = bw.WriteDuration(val) + case *gxbig.Decimal: + _, err = bw.WriteString(val.String()) + default: + err = perrors.Errorf("unknown value type %T", val) + } + + if err != nil { + return + } + } + + return b.WriteTo(w) +} + +func (vi *binaryVirtualRow) Length() int { + vi.lengthOnce.Do(func() { + n, _ := vi.WriteTo(io.Discard) + vi.length = n + }) + return int(vi.length) +} + +func (vi *binaryVirtualRow) Scan(dest []proto.Value) error { + return (*baseVirtualRow)(vi).Scan(dest) +} + +type textVirtualRow baseVirtualRow + +func (t *textVirtualRow) Get(name string) (proto.Value, error) { + return (*baseVirtualRow)(t).Get(name) +} + +func (t *textVirtualRow) Fields() []proto.Field { + return t.fields +} + +func (t *textVirtualRow) WriteTo(w io.Writer) (n int64, err error) { + bf := bufferpool.Get() + defer bufferpool.Put(bf) + + for i := 0; i < len(t.fields); i++ { + var ( + field = t.fields[i].(*mysql.Field) + cell = t.cells[i] + ) + + if cell == nil { + if err = bf.WriteByte(0xfb); err != nil { + err = perrors.WithStack(err) + return + } + continue + } + + switch field.FieldType() { + case consts.FieldTypeTimestamp, consts.FieldTypeDateTime, consts.FieldTypeDate, consts.FieldTypeNewDate: + var b []byte + if b, err = mysql.AppendDateTime(b, t.cells[i].(time.Time)); err != nil { + err = perrors.WithStack(err) + return + } + if _, err = (*BinaryValueWriter)(bf).WriteString(bytesconv.BytesToString(b)); err != nil { + err = perrors.WithStack(err) + } + default: + var s string + switch val := cell.(type) { + case fmt.Stringer: + s = val.String() + default: + s = fmt.Sprint(cell) + } + if _, err = (*BinaryValueWriter)(bf).WriteString(s); err != nil { + err = perrors.WithStack(err) + return + } + } + } + + n, err = bf.WriteTo(w) + + return +} + +func (t *textVirtualRow) IsBinary() bool { + return false +} + +func (t *textVirtualRow) Length() int { + t.lengthOnce.Do(func() { + t.length, _ = t.WriteTo(io.Discard) + }) + return int(t.length) +} + +func (t *textVirtualRow) Scan(dest []proto.Value) error { + return (*baseVirtualRow)(t).Scan(dest) +} + +func (t *textVirtualRow) Values() []proto.Value { + return (*baseVirtualRow)(t).Values() +} + +// NewBinaryVirtualRow creates a virtual row with binary-protocol. +func NewBinaryVirtualRow(fields []proto.Field, cells []proto.Value) VirtualRow { + return (*binaryVirtualRow)(newBaseVirtualRow(fields, cells)) +} + +// NewTextVirtualRow creates a virtual row with text-protocol. +func NewTextVirtualRow(fields []proto.Field, cells []proto.Value) VirtualRow { + return (*textVirtualRow)(newBaseVirtualRow(fields, cells)) +} + +func newBaseVirtualRow(fields []proto.Field, cells []proto.Value) *baseVirtualRow { + if len(fields) != len(cells) { + panic(fmt.Sprintf("the lengths of fields and cells are doesn't match!")) + } + return &baseVirtualRow{ + fields: fields, + cells: cells, + } +} diff --git a/pkg/mysql/rows/virtual_row_test.go b/pkg/mysql/rows/virtual_row_test.go new file mode 100644 index 000000000..602250c82 --- /dev/null +++ b/pkg/mysql/rows/virtual_row_test.go @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package rows + +import ( + "bytes" + "database/sql" + "testing" + "time" +) + +import ( + "github.com/stretchr/testify/assert" +) + +import ( + consts "github.com/arana-db/arana/pkg/constants/mysql" + "github.com/arana-db/arana/pkg/mysql" + "github.com/arana-db/arana/pkg/proto" +) + +func TestNew(t *testing.T) { + fields := []proto.Field{ + mysql.NewField("name", consts.FieldTypeString), + mysql.NewField("uid", consts.FieldTypeLongLong), + mysql.NewField("created_at", consts.FieldTypeDateTime), + } + + var ( + row VirtualRow + b bytes.Buffer + now = time.Unix(time.Now().Unix(), 0) + values = []proto.Value{"foobar", int64(1), now} + ) + + t.Run("Binary", func(t *testing.T) { + b.Reset() + + row = NewBinaryVirtualRow(fields, values) + _, err := row.WriteTo(&b) + assert.NoError(t, err) + + br := mysql.NewBinaryRow(fields, b.Bytes()) + cells := make([]proto.Value, len(fields)) + err = br.Scan(cells) + assert.NoError(t, err) + + var ( + name sql.NullString + uid sql.NullInt64 + createdAt sql.NullTime + ) + + _ = name.Scan(cells[0]) + _ = uid.Scan(cells[1]) + _ = createdAt.Scan(cells[2]) + + t.Log("name:", name.String) + t.Log("uid:", uid.Int64) + t.Log("created_at:", createdAt.Time) + + assert.Equal(t, "foobar", name.String) + assert.Equal(t, int64(1), uid.Int64) + assert.Equal(t, now, createdAt.Time) + }) + + t.Run("Text", func(t *testing.T) { + b.Reset() + + row = NewTextVirtualRow(fields, values) + _, err := row.WriteTo(&b) + assert.NoError(t, err) + + cells := make([]proto.Value, len(fields)) + + tr := mysql.NewTextRow(fields, b.Bytes()) + err = tr.Scan(cells) + assert.NoError(t, err) + + var ( + name sql.NullString + uid sql.NullInt64 + createdAt sql.NullTime + ) + + _ = name.Scan(cells[0]) + _ = uid.Scan(cells[1]) + _ = createdAt.Scan(cells[2]) + + t.Log("name:", name.String) + t.Log("uid:", uid.Int64) + t.Log("created_at:", createdAt.Time) + + assert.Equal(t, "foobar", name.String) + assert.Equal(t, int64(1), uid.Int64) + assert.Equal(t, now, createdAt.Time) + }) +} diff --git a/pkg/mysql/rows_test.go b/pkg/mysql/rows_test.go deleted file mode 100644 index 5e7f5a57b..000000000 --- a/pkg/mysql/rows_test.go +++ /dev/null @@ -1,155 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package mysql - -import ( - "testing" -) - -import ( - "github.com/stretchr/testify/assert" -) - -import ( - "github.com/arana-db/arana/pkg/constants/mysql" - "github.com/arana-db/arana/pkg/proto" -) - -func TestFields(t *testing.T) { - row := &Row{ - Content: createContent(), - ResultSet: createResultSet(), - } - fields := row.Fields() - assert.Equal(t, 3, len(fields)) - assert.Equal(t, "db_arana", fields[0].DataBaseName()) - assert.Equal(t, "t_order", fields[1].TableName()) - assert.Equal(t, "DECIMAL", fields[2].TypeDatabaseName()) -} - -func TestData(t *testing.T) { - row := &Row{ - Content: createContent(), - ResultSet: createResultSet(), - } - content := row.Data() - assert.Equal(t, 3, len(content)) - assert.Equal(t, byte('1'), content[0]) - assert.Equal(t, byte('2'), content[1]) - assert.Equal(t, byte('3'), content[2]) -} - -func TestColumnsForColumnNames(t *testing.T) { - row := &Row{ - Content: createContent(), - ResultSet: createResultSet(), - } - columns := row.Columns() - assert.Equal(t, "t_order.id", columns[0]) - assert.Equal(t, "t_order.order_id", columns[1]) - assert.Equal(t, "t_order.order_amount", columns[2]) -} - -func TestColumnsForColumns(t *testing.T) { - row := &Row{ - Content: createContent(), - ResultSet: &ResultSet{ - Columns: createColumns(), - ColumnNames: nil, - }, - } - columns := row.Columns() - assert.Equal(t, "t_order.id", columns[0]) - assert.Equal(t, "t_order.order_id", columns[1]) - assert.Equal(t, "t_order.order_amount", columns[2]) -} - -func TestDecodeForRow(t *testing.T) { - row := &Row{ - Content: createContent(), - ResultSet: createResultSet(), - } - val, err := row.Decode() - //TODO row.Decode() is empty. - assert.Nil(t, val) - assert.Nil(t, err) -} - -func TestEncodeForRow(t *testing.T) { - values := make([]*proto.Value, 0, 3) - names := []string{ - "id", "order_id", "order_amount", - } - fields := createColumns() - for i := 0; i < 3; i++ { - values = append(values, &proto.Value{Raw: []byte{byte(i)}, Len: 1}) - } - - row := &Row{} - r := row.Encode(values, fields, names) - assert.Equal(t, 3, len(r.Columns())) - assert.Equal(t, 3, len(r.Fields())) - assert.Equal(t, []byte{byte(0), byte(1), byte(2)}, r.Data()) - -} - -func createContent() []byte { - result := []byte{ - '1', '2', '3', - } - return result -} - -func createResultSet() *ResultSet { - - result := &ResultSet{ - Columns: createColumns(), - ColumnNames: createColumnNames(), - } - - return result -} - -func createColumns() []proto.Field { - result := []proto.Field{ - &Field{ - database: "db_arana", - table: "t_order", - name: "id", - fieldType: mysql.FieldTypeLong, - }, &Field{ - database: "db_arana", - table: "t_order", - name: "order_id", - fieldType: mysql.FieldTypeLong, - }, &Field{ - database: "db_arana", - table: "t_order", - name: "order_amount", - fieldType: mysql.FieldTypeDecimal, - }, - } - return result -} - -func createColumnNames() []string { - result := []string{ - "t_order.id", "t_order.order_id", "t_order.order_amount", - } - return result -} diff --git a/pkg/mysql/server.go b/pkg/mysql/server.go index b41af0614..a79c6dedf 100644 --- a/pkg/mysql/server.go +++ b/pkg/mysql/server.go @@ -33,7 +33,7 @@ import ( import ( _ "github.com/arana-db/parser/test_driver" - err2 "github.com/pkg/errors" + perrors "github.com/pkg/errors" "go.uber.org/atomic" ) @@ -157,11 +157,10 @@ func (l *Listener) handle(conn net.Conn, connectionID uint32) { if x := recover(); x != nil { log.Errorf("mysql_server caught panic:\n%v", x) } - conn.Close() l.executor.ConnectionClose(&proto.Context{ Context: context.Background(), - ConnectionID: l.connectionID, + ConnectionID: c.ConnectionID, }) }() @@ -182,8 +181,8 @@ func (l *Listener) handle(conn net.Conn, connectionID uint32) { for { c.sequence = 0 - data, err := c.readEphemeralPacket() - if err != nil { + var data []byte + if data, err = c.readEphemeralPacket(); err != nil { // Don't log EOF errors. They cause too much spam. if err != io.EOF && !strings.Contains(err.Error(), "use of closed network connection") { log.Errorf("Error reading packet from %s: %v", c, err) @@ -197,10 +196,16 @@ func (l *Listener) handle(conn net.Conn, connectionID uint32) { Context: context.Background(), Schema: c.Schema, Tenant: c.Tenant, - ConnectionID: l.connectionID, + ConnectionID: c.ConnectionID, Data: content, } + if err = l.ExecuteCommand(c, ctx); err != nil { + if err == io.EOF { + log.Debugf("the connection#%d of remote client %s requests quit", c.ConnectionID, c.conn.(*net.TCPConn).RemoteAddr()) + } else { + log.Errorf("failed to execute command: %v", err) + } return } } @@ -273,20 +278,19 @@ func (l *Listener) writeHandshakeV10(c *Conn, enableTLS bool, salt []byte) error capabilities |= mysql.CapabilityClientSSL } - length := - 1 + // protocol version - lenNullString(l.conf.ServerVersion) + - 4 + // connection ID - 8 + // first part of salt Content - 1 + // filler byte - 2 + // capability flags (lower 2 bytes) - 1 + // character set - 2 + // status flag - 2 + // capability flags (upper 2 bytes) - 1 + // length of auth plugin Content - 10 + // reserved (0) - 13 + // auth-plugin-Content - lenNullString(mysql.MysqlNativePassword) // auth-plugin-name + length := 1 + // protocol version + lenNullString(l.conf.ServerVersion) + + 4 + // connection ID + 8 + // first part of salt Content + 1 + // filler byte + 2 + // capability flags (lower 2 bytes) + 1 + // character set + 2 + // status flag + 2 + // capability flags (upper 2 bytes) + 1 + // length of auth plugin Content + 10 + // reserved (0) + 13 + // auth-plugin-Content + lenNullString(mysql.MysqlNativePassword) // auth-plugin-name data := c.startEphemeralPacket(length) pos := 0 @@ -334,7 +338,7 @@ func (l *Listener) writeHandshakeV10(c *Conn, enableTLS bool, salt []byte) error // Sanity check. if pos != len(data) { - return err2.Errorf("error building Handshake packet: got %v bytes expected %v", pos, len(data)) + return perrors.Errorf("error building Handshake packet: got %v bytes expected %v", pos, len(data)) } if err := c.writeEphemeralPacket(); err != nil { @@ -359,10 +363,10 @@ func (l *Listener) parseClientHandshakePacket(firstTime bool, data []byte) (*han // Client flags, 4 bytes. clientFlags, pos, ok := readUint32(data, pos) if !ok { - return nil, err2.New("parseClientHandshakePacket: can't read client flags") + return nil, perrors.New("parseClientHandshakePacket: can't read client flags") } if clientFlags&mysql.CapabilityClientProtocol41 == 0 { - return nil, err2.New("parseClientHandshakePacket: only support protocol 4.1") + return nil, perrors.New("parseClientHandshakePacket: only support protocol 4.1") } // Remember a subset of the capabilities, so we can use them @@ -381,13 +385,13 @@ func (l *Listener) parseClientHandshakePacket(firstTime bool, data []byte) (*han // See doc.go for more information. _, pos, ok = readUint32(data, pos) if !ok { - return nil, err2.New("parseClientHandshakePacket: can't read maxPacketSize") + return nil, perrors.New("parseClientHandshakePacket: can't read maxPacketSize") } // Character set. Need to handle it. characterSet, pos, ok := readByte(data, pos) if !ok { - return nil, err2.New("parseClientHandshakePacket: can't read characterSet") + return nil, perrors.New("parseClientHandshakePacket: can't read characterSet") } l.characterSet = characterSet @@ -407,7 +411,7 @@ func (l *Listener) parseClientHandshakePacket(firstTime bool, data []byte) (*han // username username, pos, ok := readNullString(data, pos) if !ok { - return nil, err2.New("parseClientHandshakePacket: can't read username") + return nil, perrors.New("parseClientHandshakePacket: can't read username") } // auth-response can have three forms. @@ -416,29 +420,29 @@ func (l *Listener) parseClientHandshakePacket(firstTime bool, data []byte) (*han var l uint64 l, pos, ok = readLenEncInt(data, pos) if !ok { - return nil, err2.New("parseClientHandshakePacket: can't read auth-response variable length") + return nil, perrors.New("parseClientHandshakePacket: can't read auth-response variable length") } authResponse, pos, ok = readBytesCopy(data, pos, int(l)) if !ok { - return nil, err2.New("parseClientHandshakePacket: can't read auth-response") + return nil, perrors.New("parseClientHandshakePacket: can't read auth-response") } } else if clientFlags&mysql.CapabilityClientSecureConnection != 0 { var l byte l, pos, ok = readByte(data, pos) if !ok { - return nil, err2.New("parseClientHandshakePacket: can't read auth-response length") + return nil, perrors.New("parseClientHandshakePacket: can't read auth-response length") } authResponse, pos, ok = readBytesCopy(data, pos, int(l)) if !ok { - return nil, err2.New("parseClientHandshakePacket: can't read auth-response") + return nil, perrors.New("parseClientHandshakePacket: can't read auth-response") } } else { a := "" a, pos, ok = readNullString(data, pos) if !ok { - return nil, err2.New("parseClientHandshakePacket: can't read auth-response") + return nil, perrors.New("parseClientHandshakePacket: can't read auth-response") } authResponse = []byte(a) } @@ -449,7 +453,7 @@ func (l *Listener) parseClientHandshakePacket(firstTime bool, data []byte) (*han dbname := "" dbname, pos, ok = readNullString(data, pos) if !ok { - return nil, err2.New("parseClientHandshakePacket: can't read dbname") + return nil, perrors.New("parseClientHandshakePacket: can't read dbname") } schemaName = dbname } @@ -459,7 +463,7 @@ func (l *Listener) parseClientHandshakePacket(firstTime bool, data []byte) (*han if clientFlags&mysql.CapabilityClientPluginAuth != 0 { authMethod, pos, ok = readNullString(data, pos) if !ok { - return nil, err2.New("parseClientHandshakePacket: can't read authMethod") + return nil, perrors.New("parseClientHandshakePacket: can't read authMethod") } } @@ -542,7 +546,7 @@ func (l *Listener) ExecuteCommand(c *Conn, ctx *proto.Context) error { case mysql.ComQuit: // https://dev.mysql.com/doc/internals/en/com-quit.html c.recycleReadPacket() - return err2.New("ComQuit") + return io.EOF case mysql.ComInitDB: return l.handleInitDB(c, ctx) case mysql.ComQuery: @@ -582,7 +586,7 @@ func parseConnAttrs(data []byte, pos int) (map[string]string, int, error) { attrLen, pos, ok := readLenEncInt(data, pos) if !ok { - return nil, 0, err2.Errorf("parseClientHandshakePacket: can't read connection attributes variable length") + return nil, 0, perrors.Errorf("parseClientHandshakePacket: can't read connection attributes variable length") } var attrLenRead uint64 @@ -593,27 +597,27 @@ func parseConnAttrs(data []byte, pos int) (map[string]string, int, error) { var keyLen byte keyLen, pos, ok = readByte(data, pos) if !ok { - return nil, 0, err2.Errorf("parseClientHandshakePacket: can't read connection attribute key length") + return nil, 0, perrors.Errorf("parseClientHandshakePacket: can't read connection attribute key length") } attrLenRead += uint64(keyLen) + 1 var connAttrKey []byte connAttrKey, pos, ok = readBytesCopy(data, pos, int(keyLen)) if !ok { - return nil, 0, err2.Errorf("parseClientHandshakePacket: can't read connection attribute key") + return nil, 0, perrors.Errorf("parseClientHandshakePacket: can't read connection attribute key") } var valLen byte valLen, pos, ok = readByte(data, pos) if !ok { - return nil, 0, err2.Errorf("parseClientHandshakePacket: can't read connection attribute value length") + return nil, 0, perrors.Errorf("parseClientHandshakePacket: can't read connection attribute value length") } attrLenRead += uint64(valLen) + 1 var connAttrVal []byte connAttrVal, pos, ok = readBytesCopy(data, pos, int(valLen)) if !ok { - return nil, 0, err2.Errorf("parseClientHandshakePacket: can't read connection attribute value") + return nil, 0, perrors.Errorf("parseClientHandshakePacket: can't read connection attribute value") } attrs[string(connAttrKey[:])] = string(connAttrVal[:]) @@ -657,7 +661,7 @@ func (c *Conn) parseComStmtExecute(stmts *sync.Map, data []byte) (uint32, byte, if !ok { return 0, 0, errors.NewSQLError(mysql.CRMalformedPacket, mysql.SSUnknownSQLState, "reading statement ID failed") } - //prepare, ok := stmts[stmtID] + // prepare, ok := stmts[stmtID] prepare, ok := stmts.Load(stmtID) if !ok { return 0, 0, errors.NewSQLError(mysql.CRCommandsOutOfSync, mysql.SSUnknownSQLState, "statement ID is not found from record") @@ -740,27 +744,15 @@ func (c *Conn) parseStmtArgs(data []byte, typ mysql.FieldType, pos int) (interfa case mysql.FieldTypeTiny: val, pos, ok := readByte(data, pos) return int64(int8(val)), pos, ok - case mysql.FieldTypeUint8: - val, pos, ok := readByte(data, pos) - return int64(int8(val)), pos, ok - case mysql.FieldTypeUint16: - val, pos, ok := readUint16(data, pos) - return int64(int16(val)), pos, ok case mysql.FieldTypeShort, mysql.FieldTypeYear: val, pos, ok := readUint16(data, pos) return int64(int16(val)), pos, ok - case mysql.FieldTypeUint24, mysql.FieldTypeUint32: - val, pos, ok := readUint32(data, pos) - return int64(val), pos, ok case mysql.FieldTypeInt24, mysql.FieldTypeLong: val, pos, ok := readUint32(data, pos) return int64(int32(val)), pos, ok case mysql.FieldTypeFloat: val, pos, ok := readUint32(data, pos) - return math.Float32frombits(uint32(val)), pos, ok - case mysql.FieldTypeUint64: - val, pos, ok := readUint64(data, pos) - return val, pos, ok + return math.Float32frombits(val), pos, ok case mysql.FieldTypeLongLong: val, pos, ok := readUint64(data, pos) return int64(val), pos, ok @@ -962,6 +954,57 @@ func (c *Conn) parseStmtArgs(data []byte, typ mysql.FieldType, pos int) (interfa } } +func (c *Conn) DefColumnDefinition(field *Field) []byte { + length := 4 + + lenEncStringSize("def") + + lenEncStringSize(field.database) + + lenEncStringSize(field.table) + + lenEncStringSize(field.orgTable) + + lenEncStringSize(field.name) + + lenEncStringSize(field.orgName) + + 1 + // length of fixed length fields + 2 + // character set + 4 + // column length + 1 + // type + 2 + // flags + 1 + // decimals + 2 + // filler + lenEncStringSize(string(field.defaultValue)) // default value + + // Get the type and the flags back. If the Field contains + // non-zero flags, we use them. Otherwise, use the flags we + // derive from the type. + typ, flags := mysql.TypeToMySQL(field.fieldType) + if field.flags != 0 { + flags = int64(field.flags) + } + + data := make([]byte, length) + writeLenEncInt(data, 0, uint64(length-4)) + writeLenEncInt(data, 3, uint64(c.sequence)) + c.sequence++ + pos := 4 + + pos = writeLenEncString(data, pos, "def") // Always same. + pos = writeLenEncString(data, pos, field.database) + pos = writeLenEncString(data, pos, field.table) + pos = writeLenEncString(data, pos, field.orgTable) + pos = writeLenEncString(data, pos, field.name) + pos = writeLenEncString(data, pos, field.orgName) + pos = writeByte(data, pos, 0x0c) + pos = writeUint16(data, pos, field.charSet) + pos = writeUint32(data, pos, field.columnLength) + pos = writeByte(data, pos, byte(typ)) + pos = writeUint16(data, pos, uint16(flags)) + pos = writeByte(data, pos, byte(field.decimals)) + pos = writeUint16(data, pos, uint16(0x0000)) + if len(field.defaultValue) > 0 { + writeLenEncString(data, pos, string(field.defaultValue)) + } + + return data +} + func (c *Conn) writeColumnDefinition(field *Field) error { length := 4 + // lenEncStringSize("def") lenEncStringSize(field.database) + @@ -1011,11 +1054,7 @@ func (c *Conn) writeColumnDefinition(field *Field) error { // writeFields writes the fields of a Result. It should be called only // if there are valid Columns in the result. -func (c *Conn) writeFields(capabilities uint32, result proto.Result) error { - var ( - fields = result.GetFields() - ) - +func (c *Conn) writeFields(capabilities uint32, fields []proto.Field) error { // Send the number of fields first. if err := c.sendColumnCount(uint64(len(fields))); err != nil { return err @@ -1039,86 +1078,35 @@ func (c *Conn) writeFields(capabilities uint32, result proto.Result) error { return nil } -func (c *Conn) writeRow(row []*proto.Value) error { - length := 0 - for _, val := range row { - if val == nil || val.Val == nil { - length++ - } else { - l := len(val.Raw) - length += lenEncIntSize(uint64(l)) + l - } - } - - data := c.startEphemeralPacket(length) - pos := 0 - for _, val := range row { - if val == nil || val.Val == nil { - pos = writeByte(data, pos, mysql.NullValue) - } else { - l := len(val.Raw) - pos = writeLenEncInt(data, pos, uint64(l)) - pos += copy(data[pos:], val.Raw) - } - } - - if pos != length { - return err2.Errorf("packet row: got %v bytes but expected %v", pos, length) +func (c *Conn) writeRow(row proto.Row) error { + var bf bytes.Buffer + n, err := row.WriteTo(&bf) + if err != nil { + return perrors.WithStack(err) } + data := c.startEphemeralPacket(int(n)) + copy(data, bf.Bytes()) return c.writeEphemeralPacket() } -// writeRowIter sends the rows of a Result. -func (c *Conn) writeRowIter(result proto.Result) error { - row := result.GetRows()[0] - rowIter := row.(*TextIterRow) +func (c *Conn) writeDataset(ds proto.Dataset) error { var ( - has bool - err error - values []*proto.Value + row proto.Row + err error ) - for has, err = rowIter.Next(); has && err == nil; has, err = rowIter.Next() { - if values, err = rowIter.Decode(); err != nil { - return err - } - if err = c.writeRow(values); err != nil { - return err - } - } - return err -} - -// writeRowChan sends the rows of a Result. -func (c *Conn) writeRowChan(result proto.Result) error { - res := result.(*Result) - for row := range res.GetDataChan() { - textRow := row.(*TextIterRow) - values, err := textRow.Decode() - if err != nil { - return err - } - if err = c.writeRow(values); err != nil { - return err + for { + row, err = ds.Next() + if perrors.Is(err, io.EOF) { + return nil } - } - return nil -} - -// writeRows sends the rows of a Result. -func (c *Conn) writeRows(result proto.Result) error { - for _, row := range result.GetRows() { - r := row.(*Row) - textRow := TextRow{*r} - values, err := textRow.Decode() if err != nil { return err } - if err := c.writeRow(values); err != nil { + if err = c.writeRow(row); err != nil { return err } } - return nil } // writeEndResult concludes the sending of a Result. @@ -1144,7 +1132,7 @@ func (c *Conn) writeEndResult(capabilities uint32, more bool, affectedRows, last return nil } -// writePrepare writes a prepare query response to the wire. +// writePrepare writes a prepared query response to the wire. func (c *Conn) writePrepare(capabilities uint32, prepare *proto.Stmt) error { paramsCount := prepare.ParamsCount @@ -1186,473 +1174,22 @@ func (c *Conn) writePrepare(capabilities uint32, prepare *proto.Stmt) error { return nil } -func (c *Conn) writeBinaryRow(fields []proto.Field, row []*proto.Value) error { - length := 0 - nullBitMapLen := (len(fields) + 7 + 2) / 8 - for _, val := range row { - if val != nil && val.Val != nil { - l, err := val2MySQLLen(val) - if err != nil { - return fmt.Errorf("internal value %v get MySQL value length error: %v", val, err) - } - length += l - } - } - - length += nullBitMapLen + 1 - - Data := c.startEphemeralPacket(length) - pos := 0 - - pos = writeByte(Data, pos, 0x00) - - for i := 0; i < nullBitMapLen; i++ { - pos = writeByte(Data, pos, 0x00) - } - - for i, val := range row { - if val == nil || val.Val == nil { - bytePos := (i+2)/8 + 1 - bitPos := (i + 2) % 8 - Data[bytePos] |= 1 << uint(bitPos) - } else { - v, err := val2MySQL(val) - if err != nil { - c.recycleWritePacket() - return fmt.Errorf("internal value %v to MySQL value error: %v", val, err) - } - pos += copy(Data[pos:], v) - } - } - - if pos != length { - return fmt.Errorf("internal error packet row: got %v bytes but expected %v", pos, length) - } - - return c.writeEphemeralPacket() -} - -func (c *Conn) writeBinaryRowIter(result proto.Result) error { - row := result.GetRows()[0] - rowIter := row.(*BinaryIterRow) +func (c *Conn) writeDatasetBinary(result proto.Dataset) error { var ( - has bool - err error - values []*proto.Value + row proto.Row + err error ) - for has, err = rowIter.Next(); has && err == nil; has, err = rowIter.Next() { - if values, err = rowIter.Decode(); err != nil { - return err - } - if err = c.writeBinaryRow(rowIter.Fields(), values); err != nil { - return err - } - } - return err -} -// writeRowChan sends the rows of a Result. -func (c *Conn) writeBinaryRowChan(result proto.Result) error { - res := result.(*Result) - for row := range res.GetDataChan() { - r := row.(*BinaryIterRow) - if err := c.writePacket(r.Data()); err != nil { - return err + for { + row, err = result.Next() + if perrors.Is(err, io.EOF) { + return nil } - } - return nil -} - -// writeTextToBinaryRows sends the rows of a Result with binary form. -func (c *Conn) writeTextToBinaryRows(result proto.Result) error { - for _, row := range result.GetRows() { - r := row.(*Row) - textRow := TextRow{*r} - values, err := textRow.Decode() if err != nil { return err } - if err := c.writeBinaryRow(result.GetFields(), values); err != nil { + if err = c.writeRow(row); err != nil { return err } } - return nil -} - -func val2MySQL(v *proto.Value) ([]byte, error) { - var out []byte - pos := 0 - if v == nil { - return out, nil - } - - switch v.Typ { - case mysql.FieldTypeNULL: - // no-op - case mysql.FieldTypeTiny: - val, err := strconv.ParseInt(fmt.Sprint(v.Val), 10, 8) - if err != nil { - return []byte{}, err - } - out = make([]byte, 1) - writeByte(out, pos, uint8(val)) - case mysql.FieldTypeUint8: - val, err := strconv.ParseUint(fmt.Sprint(v.Val), 10, 8) - if err != nil { - return []byte{}, err - } - out = make([]byte, 1) - writeByte(out, pos, uint8(val)) - case mysql.FieldTypeUint16: - val, err := strconv.ParseUint(fmt.Sprint(v.Val), 10, 16) - if err != nil { - return []byte{}, err - } - out = make([]byte, 2) - writeUint16(out, pos, uint16(val)) - case mysql.FieldTypeShort, mysql.FieldTypeYear: - val, err := strconv.ParseInt(fmt.Sprint(v.Val), 10, 16) - if err != nil { - return []byte{}, err - } - out = make([]byte, 2) - writeUint16(out, pos, uint16(val)) - case mysql.FieldTypeUint24, mysql.FieldTypeUint32: - val, err := strconv.ParseUint(fmt.Sprint(v.Val), 10, 32) - if err != nil { - return []byte{}, err - } - out = make([]byte, 4) - writeUint32(out, pos, uint32(val)) - case mysql.FieldTypeInt24, mysql.FieldTypeLong: - val, err := strconv.ParseInt(fmt.Sprint(v.Val), 10, 32) - if err != nil { - return []byte{}, err - } - out = make([]byte, 4) - writeUint32(out, pos, uint32(val)) - case mysql.FieldTypeFloat: - val, err := strconv.ParseFloat(fmt.Sprint(v.Val), 32) - if err != nil { - return []byte{}, err - } - bits := math.Float32bits(float32(val)) - out = make([]byte, 4) - writeUint32(out, pos, bits) - case mysql.FieldTypeUint64: - val, err := strconv.ParseUint(fmt.Sprint(v.Val), 10, 64) - if err != nil { - return []byte{}, err - } - out = make([]byte, 8) - writeUint64(out, pos, uint64(val)) - case mysql.FieldTypeLongLong: - val, err := strconv.ParseInt(fmt.Sprint(v.Val), 10, 64) - if err != nil { - return []byte{}, err - } - out = make([]byte, 8) - writeUint64(out, pos, uint64(val)) - case mysql.FieldTypeDouble: - val, err := strconv.ParseFloat(fmt.Sprint(v.Val), 64) - if err != nil { - return []byte{}, err - } - bits := math.Float64bits(val) - out = make([]byte, 8) - writeUint64(out, pos, bits) - case mysql.FieldTypeTimestamp, mysql.FieldTypeDate, mysql.FieldTypeDateTime: - if len(v.Raw) > 19 { - out = make([]byte, 1+11) - out[pos] = 0x0b - pos++ - year, err := strconv.ParseUint(string(v.Raw[0:4]), 10, 16) - if err != nil { - return []byte{}, err - } - month, err := strconv.ParseUint(string(v.Raw[5:7]), 10, 8) - if err != nil { - return []byte{}, err - } - day, err := strconv.ParseUint(string(v.Raw[8:10]), 10, 8) - if err != nil { - return []byte{}, err - } - hour, err := strconv.ParseUint(string(v.Raw[11:13]), 10, 8) - if err != nil { - return []byte{}, err - } - minute, err := strconv.ParseUint(string(v.Raw[14:16]), 10, 8) - if err != nil { - return []byte{}, err - } - second, err := strconv.ParseUint(string(v.Raw[17:19]), 10, 8) - if err != nil { - return []byte{}, err - } - val := make([]byte, 6) - count := copy(val, v.Raw[20:]) - for i := 0; i < (6 - count); i++ { - val[count+i] = 0x30 - } - microSecond, err := strconv.ParseUint(string(val), 10, 32) - if err != nil { - return []byte{}, err - } - pos = writeUint16(out, pos, uint16(year)) - pos = writeByte(out, pos, byte(month)) - pos = writeByte(out, pos, byte(day)) - pos = writeByte(out, pos, byte(hour)) - pos = writeByte(out, pos, byte(minute)) - pos = writeByte(out, pos, byte(second)) - writeUint32(out, pos, uint32(microSecond)) - } else if len(v.Raw) > 10 { - out = make([]byte, 1+7) - out[pos] = 0x07 - pos++ - year, err := strconv.ParseUint(string(v.Raw[0:4]), 10, 16) - if err != nil { - return []byte{}, err - } - month, err := strconv.ParseUint(string(v.Raw[5:7]), 10, 8) - if err != nil { - return []byte{}, err - } - day, err := strconv.ParseUint(string(v.Raw[8:10]), 10, 8) - if err != nil { - return []byte{}, err - } - hour, err := strconv.ParseUint(string(v.Raw[11:13]), 10, 8) - if err != nil { - return []byte{}, err - } - minute, err := strconv.ParseUint(string(v.Raw[14:16]), 10, 8) - if err != nil { - return []byte{}, err - } - second, err := strconv.ParseUint(string(v.Raw[17:]), 10, 8) - if err != nil { - return []byte{}, err - } - pos = writeUint16(out, pos, uint16(year)) - pos = writeByte(out, pos, byte(month)) - pos = writeByte(out, pos, byte(day)) - pos = writeByte(out, pos, byte(hour)) - pos = writeByte(out, pos, byte(minute)) - writeByte(out, pos, byte(second)) - } else if len(v.Raw) > 0 { - out = make([]byte, 1+4) - out[pos] = 0x04 - pos++ - year, err := strconv.ParseUint(string(v.Raw[0:4]), 10, 16) - if err != nil { - return []byte{}, err - } - month, err := strconv.ParseUint(string(v.Raw[5:7]), 10, 8) - if err != nil { - return []byte{}, err - } - day, err := strconv.ParseUint(string(v.Raw[8:]), 10, 8) - if err != nil { - return []byte{}, err - } - pos = writeUint16(out, pos, uint16(year)) - pos = writeByte(out, pos, byte(month)) - writeByte(out, pos, byte(day)) - } else { - out = make([]byte, 1) - out[pos] = 0x00 - } - case mysql.FieldTypeTime: - if string(v.Raw) == "00:00:00" { - out = make([]byte, 1) - out[pos] = 0x00 - } else if strings.Contains(string(v.Raw), ".") { - out = make([]byte, 1+12) - out[pos] = 0x0c - pos++ - - sub1 := strings.Split(string(v.Raw), ":") - if len(sub1) != 3 { - err := fmt.Errorf("incorrect time value, ':' is not found") - return []byte{}, err - } - sub2 := strings.Split(sub1[2], ".") - if len(sub2) != 2 { - err := fmt.Errorf("incorrect time value, '.' is not found") - return []byte{}, err - } - - var total []byte - if strings.HasPrefix(sub1[0], "-") { - out[pos] = 0x01 - total = []byte(sub1[0]) - total = total[1:] - } else { - out[pos] = 0x00 - total = []byte(sub1[0]) - } - pos++ - - h, err := strconv.ParseUint(string(total), 10, 32) - if err != nil { - return []byte{}, err - } - - days := uint32(h) / 24 - hours := uint32(h) % 24 - minute := sub1[1] - second := sub2[0] - microSecond := sub2[1] - - minutes, err := strconv.ParseUint(minute, 10, 8) - if err != nil { - return []byte{}, err - } - - seconds, err := strconv.ParseUint(second, 10, 8) - if err != nil { - return []byte{}, err - } - pos = writeUint32(out, pos, uint32(days)) - pos = writeByte(out, pos, byte(hours)) - pos = writeByte(out, pos, byte(minutes)) - pos = writeByte(out, pos, byte(seconds)) - - val := make([]byte, 6) - count := copy(val, microSecond) - for i := 0; i < (6 - count); i++ { - val[count+i] = 0x30 - } - microSeconds, err := strconv.ParseUint(string(val), 10, 32) - if err != nil { - return []byte{}, err - } - writeUint32(out, pos, uint32(microSeconds)) - } else if len(v.Raw) > 0 { - out = make([]byte, 1+8) - out[pos] = 0x08 - pos++ - - sub1 := strings.Split(string(v.Raw), ":") - if len(sub1) != 3 { - err := fmt.Errorf("incorrect time value, ':' is not found") - return []byte{}, err - } - - var total []byte - if strings.HasPrefix(sub1[0], "-") { - out[pos] = 0x01 - total = []byte(sub1[0]) - total = total[1:] - } else { - out[pos] = 0x00 - total = []byte(sub1[0]) - } - pos++ - - h, err := strconv.ParseUint(string(total), 10, 32) - if err != nil { - return []byte{}, err - } - - days := uint32(h) / 24 - hours := uint32(h) % 24 - minute := sub1[1] - second := sub1[2] - - minutes, err := strconv.ParseUint(minute, 10, 8) - if err != nil { - return []byte{}, err - } - - seconds, err := strconv.ParseUint(second, 10, 8) - if err != nil { - return []byte{}, err - } - pos = writeUint32(out, pos, uint32(days)) - pos = writeByte(out, pos, byte(hours)) - pos = writeByte(out, pos, byte(minutes)) - writeByte(out, pos, byte(seconds)) - } else { - err := fmt.Errorf("incorrect time value") - return []byte{}, err - } - case mysql.FieldTypeDecimal, mysql.FieldTypeNewDecimal, mysql.FieldTypeVarChar, mysql.FieldTypeTinyBLOB, - mysql.FieldTypeMediumBLOB, mysql.FieldTypeLongBLOB, mysql.FieldTypeBLOB, mysql.FieldTypeVarString, - mysql.FieldTypeString, mysql.FieldTypeGeometry, mysql.FieldTypeJSON, mysql.FieldTypeBit, - mysql.FieldTypeEnum, mysql.FieldTypeSet: - l := len(v.Raw) - length := lenEncIntSize(uint64(l)) + l - out = make([]byte, length) - pos = writeLenEncInt(out, pos, uint64(l)) - copy(out[pos:], v.Raw) - default: - out = make([]byte, len(v.Raw)) - copy(out, v.Raw) - } - return out, nil -} - -func val2MySQLLen(v *proto.Value) (int, error) { - var length int - var err error - if v == nil { - return 0, nil - } - - switch v.Typ { - case mysql.FieldTypeNULL: - length = 0 - case mysql.FieldTypeTiny, mysql.FieldTypeUint8: - length = 1 - case mysql.FieldTypeUint16, mysql.FieldTypeShort, mysql.FieldTypeYear: - length = 2 - case mysql.FieldTypeUint24, mysql.FieldTypeUint32, mysql.FieldTypeInt24, mysql.FieldTypeLong, mysql.FieldTypeFloat: - length = 4 - case mysql.FieldTypeUint64, mysql.FieldTypeLongLong, mysql.FieldTypeDouble: - length = 8 - case mysql.FieldTypeTimestamp, mysql.FieldTypeDate, mysql.FieldTypeDateTime: - if len(v.Raw) > 19 { - length = 12 - } else if len(v.Raw) > 10 { - length = 8 - } else if len(v.Raw) > 0 { - length = 5 - } else { - length = 1 - } - case mysql.FieldTypeTime: - if string(v.Raw) == "00:00:00" { - length = 1 - } else if strings.Contains(string(v.Raw), ".") { - length = 13 - } else if len(v.Raw) > 0 { - length = 9 - } else { - err = fmt.Errorf("incorrect time value") - } - case mysql.FieldTypeDecimal, mysql.FieldTypeNewDecimal, mysql.FieldTypeVarChar, mysql.FieldTypeTinyBLOB, - mysql.FieldTypeMediumBLOB, mysql.FieldTypeLongBLOB, mysql.FieldTypeBLOB, mysql.FieldTypeVarString, - mysql.FieldTypeString, mysql.FieldTypeGeometry, mysql.FieldTypeJSON, mysql.FieldTypeBit, - mysql.FieldTypeEnum, mysql.FieldTypeSet: - l := len(v.Raw) - length = lenEncIntSize(uint64(l)) + l - default: - length = len(v.Raw) - } - if err != nil { - return 0, err - } - return length, nil -} - -func (c *Conn) writeBinaryRows(result proto.Result) error { - for _, row := range result.GetRows() { - r := row.(*Row) - if err := c.writePacket(r.Data()); err != nil { - return err - } - } - return nil } diff --git a/pkg/mysql/statement.go b/pkg/mysql/statement.go index fe2a3ef5a..398523c35 100644 --- a/pkg/mysql/statement.go +++ b/pkg/mysql/statement.go @@ -323,12 +323,12 @@ func (stmt *BackendStatement) writeExecutePacket(args []interface{}) error { paramTypes[i+i+1] = 0x00 var a [64]byte - var b = a[:0] + b := a[:0] if v.IsZero() { b = append(b, "0000-00-00"...) } else { - b, err = appendDateTime(b, v.In(bc.conf.Loc)) + b, err = AppendDateTime(b, v.In(bc.conf.Loc)) if err != nil { return err } @@ -357,65 +357,40 @@ func (stmt *BackendStatement) writeExecutePacket(args []interface{}) error { return bc.c.writePacket(data[4:]) } -func (stmt *BackendStatement) execArgs(args []interface{}) (*Result, uint16, error) { +func (stmt *BackendStatement) execArgs(args []interface{}) (proto.Result, error) { err := stmt.writeExecutePacket(args) if err != nil { - return nil, 0, err + return nil, err } - affectedRows, lastInsertID, colNumber, _, warnings, err := stmt.conn.ReadComQueryResponse() - if err != nil { - return nil, 0, err - } - - if colNumber > 0 { - // columns - if err = stmt.conn.DrainResults(); err != nil { - return nil, 0, err - } - // rows - if err = stmt.conn.DrainResults(); err != nil { - return nil, 0, err - } - } - - return &Result{ - AffectedRows: affectedRows, - InsertId: lastInsertID, - }, warnings, nil + //if colNumber > 0 { + // // columns + // if err = stmt.conn.DrainResults(); err != nil { + // return nil, 0, err + // } + // // rows + // if err = stmt.conn.DrainResults(); err != nil { + // return nil, 0, err + // } + //} + + return stmt.conn.ReadQueryRow(), nil } -// queryArgsIterRow is iterator for binary protocol result set -func (stmt *BackendStatement) queryArgsIterRow(args []interface{}) (*Result, uint16, error) { +func (stmt *BackendStatement) queryArgs(args []interface{}) (proto.Result, error) { err := stmt.writeExecutePacket(args) if err != nil { - return nil, 0, err + return nil, err } - result, affectedRows, lastInsertID, _, warnings, err := stmt.conn.ReadQueryRow() - - iterRow := &BinaryIterRow{result} + res := stmt.conn.ReadQueryRow() + res.setWantFields(true) + res.setBinaryProtocol() - return &Result{ - AffectedRows: affectedRows, - InsertId: lastInsertID, - Fields: iterRow.Fields(), - Rows: []proto.Row{iterRow}, - DataChan: make(chan proto.Row, 1), - }, warnings, err + return res, nil } -func (stmt *BackendStatement) queryArgs(args []interface{}) (*Result, uint16, error) { - err := stmt.writeExecutePacket(args) - if err != nil { - return nil, 0, err - } - - result, _, warnings, err := stmt.conn.ReadQueryResult(true) - return result, warnings, err -} - -func (stmt *BackendStatement) exec(args []byte) (*Result, uint16, error) { +func (stmt *BackendStatement) exec(args []byte) (proto.Result, error) { args[1] = byte(stmt.id) args[2] = byte(stmt.id >> 8) args[3] = byte(stmt.id >> 16) @@ -426,33 +401,29 @@ func (stmt *BackendStatement) exec(args []byte) (*Result, uint16, error) { err := stmt.conn.c.writePacket(args) if err != nil { - return nil, 0, err + return nil, err } stmt.conn.c.recycleWritePacket() - affectedRows, lastInsertID, colNumber, _, warnings, err := stmt.conn.ReadComQueryResponse() - if err != nil { - return nil, 0, err - } - - if colNumber > 0 { - // columns - if err = stmt.conn.DrainResults(); err != nil { - return nil, 0, err - } - // rows - if err = stmt.conn.DrainResults(); err != nil { - return nil, 0, err - } - } - - return &Result{ - AffectedRows: affectedRows, - InsertId: lastInsertID, - }, warnings, nil + res := stmt.conn.ReadQueryRow() + + //if colNumber > 0 { + // // columns + // if err = stmt.conn.DrainResults(); err != nil { + // return nil, 0, err + // } + // // rows + // if err = stmt.conn.DrainResults(); err != nil { + // return nil, 0, err + // } + //} + + res.setBinaryProtocol() + res.setWantFields(true) + return res, nil } -func (stmt *BackendStatement) query(args []byte) (*Result, uint16, error) { +func (stmt *BackendStatement) query(args []byte) (proto.Result, error) { args[1] = byte(stmt.id) args[2] = byte(stmt.id >> 8) args[3] = byte(stmt.id >> 16) @@ -463,9 +434,8 @@ func (stmt *BackendStatement) query(args []byte) (*Result, uint16, error) { err := stmt.conn.c.writePacket(args) if err != nil { - return nil, 0, err + return nil, err } - result, _, warnings, err := stmt.conn.ReadQueryResult(true) - return result, warnings, err + return stmt.conn.ReadQueryResult(true), nil } diff --git a/pkg/mysql/thead/thead.go b/pkg/mysql/thead/thead.go new file mode 100644 index 000000000..824a2066b --- /dev/null +++ b/pkg/mysql/thead/thead.go @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package thead + +import ( + consts "github.com/arana-db/arana/pkg/constants/mysql" + "github.com/arana-db/arana/pkg/mysql" + "github.com/arana-db/arana/pkg/proto" +) + +var ( + Topology = Thead{ + Col{Name: "id", FieldType: consts.FieldTypeLongLong}, + Col{Name: "group_name", FieldType: consts.FieldTypeVarString}, + Col{Name: "table_name", FieldType: consts.FieldTypeVarString}, + } + Database = Thead{ + Col{Name: "Database", FieldType: consts.FieldTypeVarString}, + } +) + +type Col struct { + Name string + FieldType consts.FieldType +} + +type Thead []Col + +func (t Thead) ToFields() []proto.Field { + columns := make([]proto.Field, len(t)) + for i := 0; i < len(t); i++ { + columns[i] = mysql.NewField(t[i].Name, t[i].FieldType) + } + return columns +} diff --git a/pkg/mysql/utils.go b/pkg/mysql/utils.go index 63c0bb76c..2f6ec435e 100644 --- a/pkg/mysql/utils.go +++ b/pkg/mysql/utils.go @@ -295,7 +295,7 @@ func parseBinaryDateTime(num uint64, data []byte, loc *time.Location) (driver.Va return nil, fmt.Errorf("invalid DATETIME packet length %d", num) } -func appendDateTime(buf []byte, t time.Time) ([]byte, error) { +func AppendDateTime(buf []byte, t time.Time) ([]byte, error) { year, month, day := t.Date() hour, min, sec := t.Clock() nsec := t.Nanosecond() @@ -335,14 +335,11 @@ func appendDateTime(buf []byte, t time.Time) ([]byte, error) { localBuf[19] = '.' // milli second - localBuf[20], localBuf[21], localBuf[22] = - digits01[nsec100000000], digits10[nsec1000000], digits01[nsec1000000] + localBuf[20], localBuf[21], localBuf[22] = digits01[nsec100000000], digits10[nsec1000000], digits01[nsec1000000] // micro second - localBuf[23], localBuf[24], localBuf[25] = - digits10[nsec10000], digits01[nsec10000], digits10[nsec100] + localBuf[23], localBuf[24], localBuf[25] = digits10[nsec10000], digits01[nsec10000], digits10[nsec100] // nano second - localBuf[26], localBuf[27], localBuf[28] = - digits01[nsec100], digits10[nsec1], digits01[nsec1] + localBuf[26], localBuf[27], localBuf[28] = digits01[nsec100], digits10[nsec1], digits01[nsec1] // trim trailing zeros n := len(localBuf) @@ -600,6 +597,11 @@ func skipLengthEncodedString(b []byte) (int, error) { return n, io.EOF } +func readComFieldListDefaultValueLength(data []byte, pos int) (uint64, int) { + l, _, n := readLengthEncodedInteger(data[pos:]) + return l, pos + n +} + // returns the number read, whether the value is NULL and the number of bytes read func readLengthEncodedInteger(b []byte) (uint64, bool, int) { // See issue #349 @@ -632,7 +634,7 @@ func readLengthEncodedInteger(b []byte) (uint64, bool, int) { return uint64(b[0]), false, 1 } -// encodes a uint64 value and appends it to the given bytes slice +// appendLengthEncodedInteger encodes an uint64 value and appends it to the given bytes slice func appendLengthEncodedInteger(b []byte, n uint64) []byte { switch { case n <= 250: @@ -1079,7 +1081,7 @@ func convertAssignRows(dest, src interface{}) error { i64, err := strconv.ParseInt(s, 10, dv.Type().Bits()) if err != nil { err = strconvErr(err) - return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) + return fmt.Errorf("converting driver.V type %T (%q) to a %s: %v", src, s, dv.Kind(), err) } dv.SetInt(i64) return nil @@ -1091,7 +1093,7 @@ func convertAssignRows(dest, src interface{}) error { u64, err := strconv.ParseUint(s, 10, dv.Type().Bits()) if err != nil { err = strconvErr(err) - return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) + return fmt.Errorf("converting driver.V type %T (%q) to a %s: %v", src, s, dv.Kind(), err) } dv.SetUint(u64) return nil @@ -1103,7 +1105,7 @@ func convertAssignRows(dest, src interface{}) error { f64, err := strconv.ParseFloat(s, dv.Type().Bits()) if err != nil { err = strconvErr(err) - return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) + return fmt.Errorf("converting driver.V type %T (%q) to a %s: %v", src, s, dv.Kind(), err) } dv.SetFloat(f64) return nil @@ -1121,7 +1123,7 @@ func convertAssignRows(dest, src interface{}) error { } } - return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, dest) + return fmt.Errorf("unsupported Scan, storing driver.V type %T into type %T", src, dest) } func cloneBytes(b []byte) []byte { @@ -1361,8 +1363,10 @@ func PutLengthEncodedInt(n uint64) []byte { return []byte{0xfd, byte(n), byte(n >> 8), byte(n >> 16)} case n <= 0xffffffffffffffff: - return []byte{0xfe, byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24), - byte(n >> 32), byte(n >> 40), byte(n >> 48), byte(n >> 56)} + return []byte{ + 0xfe, byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24), + byte(n >> 32), byte(n >> 40), byte(n >> 48), byte(n >> 56), + } } return nil } diff --git a/pkg/proto/data.go b/pkg/proto/data.go index 074ec4e8d..f54d3a6da 100644 --- a/pkg/proto/data.go +++ b/pkg/proto/data.go @@ -15,82 +15,87 @@ * limitations under the License. */ +//go:generate mockgen -destination=../../testdata/mock_data.go -package=testdata . Field,Row,KeyedRow,Dataset,Result package proto import ( - "sync" + "io" + "reflect" ) -import ( - "github.com/arana-db/parser/ast" -) - -import ( - "github.com/arana-db/arana/pkg/constants/mysql" -) - -// CloseableResult is a temporary solution for chan-based result. -// Deprecated: TODO, should be removed in the future -type CloseableResult struct { - once sync.Once - Result - Closer func() error -} - -func (c *CloseableResult) Close() error { - var err error - c.once.Do(func() { - if c.Closer != nil { - err = c.Closer() - } - }) - return err -} - type ( - Value struct { - Typ mysql.FieldType - Flags uint - Len int - Val interface{} - Raw []byte + // Field contains the name and type of column, it follows sql.ColumnType. + Field interface { + // Name returns the name or alias of the column. + Name() string + + // DecimalSize returns the scale and precision of a decimal type. + // If not applicable or if not supported ok is false. + DecimalSize() (precision, scale int64, ok bool) + + // ScanType returns a Go type suitable for scanning into using Rows.Scan. + // If a driver does not support this property ScanType will return + // the type of empty interface. + ScanType() reflect.Type + + // Length returns the column type length for variable length column types such + // as text and binary field types. If the type length is unbounded the value will + // be math.MaxInt64 (any database limits will still apply). + // If the column type is not variable length, such as an int, or if not supported + // by the driver ok is false. + Length() (length int64, ok bool) + + // Nullable reports whether the column may be null. + // If a driver does not support this property ok will be false. + Nullable() (nullable, ok bool) + + // DatabaseTypeName returns the database system name of the column type. If an empty + // string is returned, then the driver type name is not supported. + // Consult your driver documentation for a list of driver data types. Length specifiers + // are not included. + // Common type names include "VARCHAR", "TEXT", "NVARCHAR", "DECIMAL", "BOOL", + // "INT", and "BIGINT". + DatabaseTypeName() string } - Field interface { - TableName() string + // Value represents the cell value of Row. + Value interface{} - DataBaseName() string + // Row represents a row data from a result set. + Row interface { + io.WriterTo + IsBinary() bool - TypeDatabaseName() string - } + // Length returns the length of Row. + Length() int - // Row is an iterator over an executed query's results. - Row interface { - // Columns returns the names of the columns. The number of - // columns of the result is inferred from the length of the - // slice. If a particular column name isn't known, an empty - // string should be returned for that entry. - Columns() []string + // Scan scans the Row to values. + Scan(dest []Value) error + } + // KeyedRow represents row with fields. + KeyedRow interface { + Row + // Fields returns the fields of row. Fields() []Field + // Get returns the value of column name. + Get(name string) (Value, error) + } - // Data returns the result in bytes. - Data() []byte - - Decode() ([]*Value, error) + Dataset interface { + io.Closer - GetColumnValue(column string) (interface{}, error) + // Fields returns the fields of Dataset. + Fields() ([]Field, error) - Encode(values []*Value, columns []Field, columnNames []string) Row + // Next returns the next row. + Next() (Row, error) } // Result is the result of a query execution. Result interface { - // GetFields returns the fields. - GetFields() []Field - - // GetRows returns the rows. - GetRows() []Row + // Dataset returns the Dataset. + Dataset() (Dataset, error) // LastInsertId returns the database's auto-generated ID // after, for example, an INSERT into a table with primary @@ -101,15 +106,4 @@ type ( // query. RowsAffected() (uint64, error) } - - // Stmt is a buffer used for store prepare statement meta data - Stmt struct { - StatementID uint32 - PrepareStmt string - ParamsCount uint16 - ParamsType []int32 - ColumnNames []string - BindVars map[string]interface{} - StmtNode ast.StmtNode - } ) diff --git a/pkg/proto/hint/hint.go b/pkg/proto/hint/hint.go new file mode 100644 index 000000000..5a1adb877 --- /dev/null +++ b/pkg/proto/hint/hint.go @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package hint + +import ( + "bufio" + "bytes" + "strings" +) + +import ( + "github.com/pkg/errors" +) + +import ( + "github.com/arana-db/arana/pkg/runtime/misc" +) + +const ( + _ Type = iota + TypeMaster // force route to master node + TypeSlave // force route to slave node + TypeRoute // custom route + TypeFullScan // enable full-scan + TypeDirect // direct route +) + +var _hintTypes = [...]string{ + TypeMaster: "MASTER", + TypeSlave: "SLAVE", + TypeRoute: "ROUTE", + TypeFullScan: "FULLSCAN", + TypeDirect: "DIRECT", +} + +// KeyValue represents a pair of key and value. +type KeyValue struct { + K string // key (optional) + V string // value +} + +// Type represents the type of Hint. +type Type uint8 + +// String returns the display string. +func (tp Type) String() string { + return _hintTypes[tp] +} + +// Hint represents a Hint, a valid Hint should include type and input kv pairs. +// +// Follow the format below: +// - without inputs: YOUR_HINT() +// - with non-keyed inputs: YOUR_HINT(foo,bar,quz) +// - with keyed inputs: YOUR_HINT(x=foo,y=bar,z=quz) +// +type Hint struct { + Type Type + Inputs []KeyValue +} + +// String returns the display string. +func (h Hint) String() string { + var sb strings.Builder + sb.WriteString(h.Type.String()) + + if len(h.Inputs) < 1 { + sb.WriteString("()") + return sb.String() + } + + sb.WriteByte('(') + + writeKv := func(p KeyValue) { + if key := p.K; len(key) > 0 { + sb.WriteString(key) + sb.WriteByte('=') + } + sb.WriteString(p.V) + } + + writeKv(h.Inputs[0]) + for i := 1; i < len(h.Inputs); i++ { + sb.WriteByte(',') + writeKv(h.Inputs[i]) + } + + sb.WriteByte(')') + return sb.String() +} + +// Parse parses Hint from an input string. +func Parse(s string) (*Hint, error) { + var ( + tpStr string + tp Type + ) + + offset := strings.Index(s, "(") + if offset == -1 { + tpStr = s + } else { + tpStr = s[:offset] + } + + for i, v := range _hintTypes { + if strings.EqualFold(tpStr, v) { + tp = Type(i) + break + } + } + + if tp == 0 { + return nil, errors.Errorf("hint: invalid input '%s'", s) + } + + if offset == -1 { + return &Hint{Type: tp}, nil + } + + end := strings.LastIndex(s, ")") + if end == -1 { + return nil, errors.Errorf("hint: invalid input '%s'", s) + } + + s = s[offset+1 : end] + + scanner := bufio.NewScanner(strings.NewReader(s)) + scanner.Split(scanComma) + + var kvs []KeyValue + + for scanner.Scan() { + text := scanner.Text() + + // split kv by '=' + i := strings.Index(text, "=") + if i == -1 { + // omit blank text + if misc.IsBlank(text) { + continue + } + kvs = append(kvs, KeyValue{V: strings.TrimSpace(text)}) + } else { + var ( + k = strings.TrimSpace(text[:i]) + v = strings.TrimSpace(text[i+1:]) + ) + // omit blank key/value + if misc.IsBlank(k) || misc.IsBlank(v) { + continue + } + kvs = append(kvs, KeyValue{K: k, V: v}) + } + } + + if err := scanner.Err(); err != nil { + return nil, errors.Wrapf(err, "hint: invalid input '%s'", s) + } + + return &Hint{Type: tp, Inputs: kvs}, nil +} + +func scanComma(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + if i := bytes.IndexByte(data, ','); i >= 0 { + return i + 1, data[0:i], nil + } + if atEOF { + return len(data), data, nil + } + return 0, nil, nil +} diff --git a/pkg/proto/hint/hint_test.go b/pkg/proto/hint/hint_test.go new file mode 100644 index 000000000..aa812cc68 --- /dev/null +++ b/pkg/proto/hint/hint_test.go @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package hint + +import ( + "testing" +) + +import ( + "github.com/stretchr/testify/assert" +) + +func TestParse(t *testing.T) { + type tt struct { + input string + output string + pass bool + } + + for _, next := range []tt{ + {"route( foo , bar , qux )", "ROUTE(foo,bar,qux)", true}, + {"master", "MASTER()", true}, + {"slave", "SLAVE()", true}, + {"not_exist_hint(1,2,3)", "", false}, + {"route(,,,)", "ROUTE()", true}, + {"fullscan()", "FULLSCAN()", true}, + {"route(foo=111,bar=222,qux=333,)", "ROUTE(foo=111,bar=222,qux=333)", true}, + } { + t.Run(next.input, func(t *testing.T) { + res, err := Parse(next.input) + if next.pass { + assert.NoError(t, err) + assert.Equal(t, next.output, res.String()) + } else { + assert.Error(t, err) + } + }) + } +} diff --git a/pkg/proto/interface.go b/pkg/proto/interface.go index 3c9d70673..677abc05a 100644 --- a/pkg/proto/interface.go +++ b/pkg/proto/interface.go @@ -20,6 +20,7 @@ package proto import ( "context" "encoding/json" + "sort" ) import ( @@ -45,9 +46,7 @@ type ( Listener interface { SetExecutor(executor Executor) - Listen() - Close() } @@ -74,35 +73,22 @@ type ( // Executor Executor interface { AddPreFilter(filter PreFilter) - AddPostFilter(filter PostFilter) - GetPreFilters() []PreFilter - GetPostFilters() []PostFilter - ProcessDistributedTransaction() bool - InLocalTransaction(ctx *Context) bool - InGlobalTransaction(ctx *Context) bool - ExecuteUseDB(ctx *Context) error - ExecuteFieldList(ctx *Context) ([]Field, error) - ExecutorComQuery(ctx *Context) (Result, uint16, error) - ExecutorComStmtExecute(ctx *Context) (Result, uint16, error) - ConnectionClose(ctx *Context) } ResourceManager interface { GetMasterResourcePool(name string) *pools.ResourcePool - GetSlaveResourcePool(name string) *pools.ResourcePool - GetMetaResourcePool(name string) *pools.ResourcePool } ) @@ -118,3 +104,23 @@ func (c Context) GetQuery() string { } return bytesconv.BytesToString(c.Data[1:]) } + +func (c Context) GetArgs() []interface{} { + if c.Stmt == nil || len(c.Stmt.BindVars) < 1 { + return nil + } + + var ( + keys = make([]string, 0, len(c.Stmt.BindVars)) + args = make([]interface{}, 0, len(c.Stmt.BindVars)) + ) + + for k := range c.Stmt.BindVars { + keys = append(keys, k) + } + sort.Strings(keys) + for _, k := range keys { + args = append(args, c.Stmt.BindVars[k]) + } + return args +} diff --git a/pkg/proto/rule/range.go b/pkg/proto/rule/range.go index 2a1ce8c29..b884918a2 100644 --- a/pkg/proto/rule/range.go +++ b/pkg/proto/rule/range.go @@ -42,9 +42,7 @@ const ( Ustr ) -var ( - DefaultNumberStepper = Stepper{U: Unum, N: 1} -) +var DefaultNumberStepper = Stepper{U: Unum, N: 1} // Range represents a value range. type Range interface { diff --git a/pkg/proto/rule/range_test.go b/pkg/proto/rule/range_test.go index ab57a8b93..0489e12b7 100644 --- a/pkg/proto/rule/range_test.go +++ b/pkg/proto/rule/range_test.go @@ -37,7 +37,7 @@ func TestStepper_Date_After(t *testing.T) { U: Uday, } t.Log(daySt.String()) - testTime := time.Date(2021, 1, 17, 17, 45, 04, 0, time.UTC) + testTime := time.Date(2021, 1, 17, 17, 45, 0o4, 0, time.UTC) hour, err := hourSt.After(testTime) assert.NoError(t, err) assert.Equal(t, 19, hour.(time.Time).Hour()) diff --git a/pkg/proto/rule/topology.go b/pkg/proto/rule/topology.go index df15bc5a6..e3f77f436 100644 --- a/pkg/proto/rule/topology.go +++ b/pkg/proto/rule/topology.go @@ -18,62 +18,154 @@ package rule import ( + "math" "sort" + "sync" ) // Topology represents the topology of databases and tables. type Topology struct { + mu sync.RWMutex dbRender, tbRender func(int) string - idx map[int][]int + idx sync.Map // map[int][]int } // Len returns the length of database and table. func (to *Topology) Len() (dbLen int, tblLen int) { - dbLen = len(to.idx) - for _, v := range to.idx { - tblLen += len(v) - } + to.idx.Range(func(_, value any) bool { + dbLen++ + tblLen += len(value.([]int)) + return true + }) return } // SetTopology sets the topology. func (to *Topology) SetTopology(db int, tables ...int) { - if to.idx == nil { - to.idx = make(map[int][]int) - } - if len(tables) < 1 { - delete(to.idx, db) + to.idx.Delete(db) return } clone := make([]int, len(tables)) copy(clone, tables) sort.Ints(clone) - to.idx[db] = clone + to.idx.Store(db, clone) } // SetRender sets the database/table name render. func (to *Topology) SetRender(dbRender, tbRender func(int) string) { + to.mu.Lock() to.dbRender, to.tbRender = dbRender, tbRender + to.mu.Unlock() } // Render renders the name of database and table from indexes. func (to *Topology) Render(dbIdx, tblIdx int) (string, string, bool) { + to.mu.RLock() + defer to.mu.RUnlock() + if to.tbRender == nil || to.dbRender == nil { return "", "", false } return to.dbRender(dbIdx), to.tbRender(tblIdx), true } +func (to *Topology) EnumerateDatabases() []string { + to.mu.RLock() + render := to.dbRender + to.mu.RUnlock() + + var keys []string + + to.idx.Range(func(key, _ any) bool { + keys = append(keys, render(key.(int))) + return true + }) + + sort.Strings(keys) + + return keys +} + +func (to *Topology) Enumerate() DatabaseTables { + to.mu.RLock() + dbRender, tbRender := to.dbRender, to.tbRender + to.mu.RUnlock() + + dt := make(DatabaseTables) + to.Each(func(dbIdx, tbIdx int) bool { + d := dbRender(dbIdx) + t := tbRender(tbIdx) + dt[d] = append(dt[d], t) + return true + }) + + return dt +} + // Each enumerates items in current Topology. func (to *Topology) Each(onEach func(dbIdx, tbIdx int) (ok bool)) bool { - for d, v := range to.idx { + done := true + to.idx.Range(func(key, value any) bool { + var ( + d = key.(int) + v = value.([]int) + ) for _, t := range v { if !onEach(d, t) { + done = false return false } } + return true + }) + + return done +} + +func (to *Topology) Smallest() (db, tb string, ok bool) { + to.mu.RLock() + dbRender, tbRender := to.dbRender, to.tbRender + to.mu.RUnlock() + + smallest := [2]int{math.MaxInt64, math.MaxInt64} + to.idx.Range(func(key, value any) bool { + if d := key.(int); d < smallest[0] { + smallest[0] = d + if t := value.([]int); len(t) > 0 { + smallest[1] = t[0] + } + } + return true + }) + + if smallest[0] != math.MaxInt64 || smallest[1] != math.MaxInt64 { + db, tb, ok = dbRender(smallest[0]), tbRender(smallest[1]), true + } + + return +} + +func (to *Topology) Largest() (db, tb string, ok bool) { + to.mu.RLock() + dbRender, tbRender := to.dbRender, to.tbRender + to.mu.RUnlock() + + largest := [2]int{math.MinInt64, math.MinInt64} + to.idx.Range(func(key, value any) bool { + if d := key.(int); d > largest[0] { + largest[0] = d + if t := value.([]int); len(t) > 0 { + largest[1] = t[len(t)-1] + } + } + return true + }) + + if largest[0] != math.MinInt64 || largest[1] != math.MinInt64 { + db, tb, ok = dbRender(largest[0]), tbRender(largest[1]), true } - return true + + return } diff --git a/pkg/proto/rule/topology_test.go b/pkg/proto/rule/topology_test.go index 637f438f6..f62dcfe14 100644 --- a/pkg/proto/rule/topology_test.go +++ b/pkg/proto/rule/topology_test.go @@ -33,24 +33,18 @@ func TestLen(t *testing.T) { assert.Equal(t, 6, tblLen) } -func TestSetTopologyForIdxNil(t *testing.T) { - topology := &Topology{ - idx: nil, - } +func TestSetTopology(t *testing.T) { + var topology Topology + topology.SetTopology(2, 2, 3, 4) - for each := range topology.idx { - assert.Equal(t, 2, each) - assert.Equal(t, 3, len(topology.idx[each])) - } dbLen, tblLen := topology.Len() assert.Equal(t, 1, dbLen) assert.Equal(t, 3, tblLen) } -func TestSetTopologyForIdxNotNil(t *testing.T) { - topology := &Topology{ - idx: map[int][]int{0: []int{1, 2, 3}}, - } +func TestSetTopologyNoConflict(t *testing.T) { + var topology Topology + topology.SetTopology(0, 1, 2, 3) topology.SetTopology(1, 4, 5, 6) dbLen, tblLen := topology.Len() assert.Equal(t, 2, dbLen) @@ -58,14 +52,12 @@ func TestSetTopologyForIdxNotNil(t *testing.T) { } func TestSetTopologyForTablesLessThanOne(t *testing.T) { - topology := &Topology{ - idx: map[int][]int{0: []int{1, 2, 3}, 1: []int{4, 5, 6}}, - } + var topology Topology + + topology.SetTopology(0, 1, 2, 3) + topology.SetTopology(1, 4, 5, 6) topology.SetTopology(1) - for each := range topology.idx { - assert.Equal(t, 0, each) - assert.Equal(t, 3, len(topology.idx[each])) - } + dbLen, tblLen := topology.Len() assert.Equal(t, 1, dbLen) assert.Equal(t, 3, tblLen) @@ -104,6 +96,36 @@ func TestTopology_Each(t *testing.T) { t.Logf("on each: %d,%d\n", dbIdx, tbIdx) return true }) + + assert.False(t, topology.Each(func(dbIdx, tbIdx int) bool { + return false + })) +} + +func TestTopology_Enumerate(t *testing.T) { + topology := createTopology() + shards := topology.Enumerate() + assert.Greater(t, shards.Len(), 0) +} + +func TestTopology_EnumerateDatabases(t *testing.T) { + topology := createTopology() + dbs := topology.EnumerateDatabases() + assert.Greater(t, len(dbs), 0) +} + +func TestTopology_Largest(t *testing.T) { + topology := createTopology() + db, tb, ok := topology.Largest() + assert.True(t, ok) + t.Logf("largest: %s.%s\n", db, tb) +} + +func TestTopology_Smallest(t *testing.T) { + topology := createTopology() + db, tb, ok := topology.Smallest() + assert.True(t, ok) + t.Logf("smallest: %s.%s\n", db, tb) } func createTopology() *Topology { @@ -114,7 +136,10 @@ func createTopology() *Topology { tbRender: func(i int) string { return fmt.Sprintf("%s:%d", "tbRender", i) }, - idx: map[int][]int{0: []int{1, 2, 3}, 1: []int{4, 5, 6}}, } + + result.SetTopology(0, 1, 2, 3) + result.SetTopology(1, 4, 5, 6) + return result } diff --git a/pkg/proto/runtime.go b/pkg/proto/runtime.go index db87fb475..58d387c43 100644 --- a/pkg/proto/runtime.go +++ b/pkg/proto/runtime.go @@ -15,7 +15,7 @@ * limitations under the License. */ -//go:generate mockgen -destination=../../testdata/mock_runtime.go -package=testdata . VConn,Plan,Optimizer,DB,SchemaLoader +//go:generate mockgen -destination=../../testdata/mock_runtime.go -package=testdata . VConn,Plan,Optimizer,DB package proto import ( @@ -24,10 +24,6 @@ import ( "time" ) -import ( - "github.com/arana-db/parser/ast" -) - const ( PlanTypeQuery PlanType = iota // QUERY PlanTypeExec // EXEC @@ -56,7 +52,7 @@ type ( // Optimizer represents a sql statement optimizer which can be used to create QueryPlan or ExecPlan. Optimizer interface { // Optimize optimizes the sql with arguments then returns a Plan. - Optimize(ctx context.Context, conn VConn, stmt ast.StmtNode, args ...interface{}) (Plan, error) + Optimize(ctx context.Context) (Plan, error) } // Weight represents the read/write weight info. @@ -106,6 +102,7 @@ type ( // Tx represents transaction. Tx interface { Executable + VConn // ID returns the unique transaction id. ID() int64 // Commit commits current transaction. @@ -113,8 +110,4 @@ type ( // Rollback rollbacks current transaction. Rollback(ctx context.Context) (Result, uint16, error) } - - SchemaLoader interface { - Load(ctx context.Context, conn VConn, schema string, tables []string) map[string]*TableMetadata - } ) diff --git a/pkg/proto/metadata.go b/pkg/proto/schema.go similarity index 74% rename from pkg/proto/metadata.go rename to pkg/proto/schema.go index 3da651c42..6ab374661 100644 --- a/pkg/proto/metadata.go +++ b/pkg/proto/schema.go @@ -15,9 +15,11 @@ * limitations under the License. */ +//go:generate mockgen -destination=../../testdata/mock_schema.go -package=testdata . SchemaLoader package proto import ( + "context" "strings" ) @@ -66,3 +68,29 @@ type ColumnMetadata struct { type IndexMetadata struct { Name string } + +var _defaultSchemaLoader SchemaLoader + +func RegisterSchemaLoader(l SchemaLoader) { + _defaultSchemaLoader = l +} + +func LoadSchemaLoader() SchemaLoader { + cur := _defaultSchemaLoader + if cur == nil { + return noopSchemaLoader{} + } + return cur +} + +// SchemaLoader represents a schema discovery. +type SchemaLoader interface { + // Load loads the schema. + Load(ctx context.Context, schema string, table []string) (map[string]*TableMetadata, error) +} + +type noopSchemaLoader struct{} + +func (n noopSchemaLoader) Load(_ context.Context, _ string, _ []string) (map[string]*TableMetadata, error) { + return nil, nil +} diff --git a/pkg/proto/schema_manager/loader.go b/pkg/proto/schema_manager/loader.go deleted file mode 100644 index 6cef03e50..000000000 --- a/pkg/proto/schema_manager/loader.go +++ /dev/null @@ -1,238 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package schema_manager - -import ( - "context" - "fmt" - "io" - "strings" -) - -import ( - "github.com/arana-db/arana/pkg/mysql" - "github.com/arana-db/arana/pkg/proto" - rcontext "github.com/arana-db/arana/pkg/runtime/context" - "github.com/arana-db/arana/pkg/util/log" -) - -const ( - orderByOrdinalPosition = " ORDER BY ORDINAL_POSITION" - tableMetadataNoOrder = "SELECT TABLE_NAME, COLUMN_NAME, DATA_TYPE, COLUMN_KEY, EXTRA, COLLATION_NAME, ORDINAL_POSITION FROM information_schema.columns WHERE TABLE_SCHEMA=database()" - tableMetadataSQL = tableMetadataNoOrder + orderByOrdinalPosition - tableMetadataSQLInTables = tableMetadataNoOrder + " AND TABLE_NAME IN (%s)" + orderByOrdinalPosition - indexMetadataSQL = "SELECT TABLE_NAME, INDEX_NAME FROM information_schema.statistics WHERE TABLE_SCHEMA=database() AND TABLE_NAME IN (%s)" -) - -type SimpleSchemaLoader struct{} - -func (l *SimpleSchemaLoader) Load(ctx context.Context, conn proto.VConn, schema string, tables []string) map[string]*proto.TableMetadata { - ctx = rcontext.WithRead(rcontext.WithDirect(ctx)) - var ( - tableMetadataMap = make(map[string]*proto.TableMetadata, len(tables)) - indexMetadataMap map[string][]*proto.IndexMetadata - columnMetadataMap map[string][]*proto.ColumnMetadata - ) - columnMetadataMap = l.LoadColumnMetadataMap(ctx, conn, schema, tables) - if columnMetadataMap != nil { - indexMetadataMap = l.LoadIndexMetadata(ctx, conn, schema, tables) - } - - for tableName, columns := range columnMetadataMap { - tableMetadataMap[tableName] = proto.NewTableMetadata(tableName, columns, indexMetadataMap[tableName]) - } - - return tableMetadataMap -} - -func (l *SimpleSchemaLoader) LoadColumnMetadataMap(ctx context.Context, conn proto.VConn, schema string, tables []string) map[string][]*proto.ColumnMetadata { - resultSet, err := conn.Query(ctx, schema, getColumnMetadataSQL(tables)) - if err != nil { - return nil - } - if closer, ok := resultSet.(io.Closer); ok { - defer func() { - _ = closer.Close() - }() - } - - result := make(map[string][]*proto.ColumnMetadata, 0) - if err != nil { - log.Errorf("Load ColumnMetadata error when call db: %v", err) - return nil - } - if resultSet == nil { - log.Error("Load ColumnMetadata error because the result is nil") - return nil - } - - row := resultSet.GetRows()[0] - var rowIter mysql.Iter - switch r := row.(type) { - case *mysql.BinaryIterRow: - rowIter = r - case *mysql.TextIterRow: - rowIter = r - } - - var ( - has bool - rowValues []*proto.Value - ) - for has, err = rowIter.Next(); has && err == nil; has, err = rowIter.Next() { - if rowValues, err = row.Decode(); err != nil { - return nil - } - tableName := convertInterfaceToStrNullable(rowValues[0].Val) - columnName := convertInterfaceToStrNullable(rowValues[1].Val) - dataType := convertInterfaceToStrNullable(rowValues[2].Val) - columnKey := convertInterfaceToStrNullable(rowValues[3].Val) - extra := convertInterfaceToStrNullable(rowValues[4].Val) - collationName := convertInterfaceToStrNullable(rowValues[5].Val) - ordinalPosition := convertInterfaceToStrNullable(rowValues[6].Val) - result[tableName] = append(result[tableName], &proto.ColumnMetadata{ - Name: columnName, - DataType: dataType, - Ordinal: ordinalPosition, - PrimaryKey: strings.EqualFold("PRI", columnKey), - Generated: strings.EqualFold("auto_increment", extra), - CaseSensitive: columnKey != "" && !strings.HasSuffix(collationName, "_ci"), - }) - } - - //for _, row := range resultSet.GetRows() { - // var innerRow mysql.Row - // switch r := row.(type) { - // case *mysql.BinaryRow: - // innerRow = r.Row - // case *mysql.Row: - // innerRow = *r - // case *mysql.TextRow: - // innerRow = r.Row - // } - // textRow := mysql.TextRow{Row: innerRow} - // rowValues, err := textRow.Decode() - // if err != nil { - // //logger.Errorf("Load ColumnMetadata error when decode text row: %v", err) - // return nil - // } - // tableName := convertInterfaceToStrNullable(rowValues[0].Val) - // columnName := convertInterfaceToStrNullable(rowValues[1].Val) - // dataType := convertInterfaceToStrNullable(rowValues[2].Val) - // columnKey := convertInterfaceToStrNullable(rowValues[3].Val) - // extra := convertInterfaceToStrNullable(rowValues[4].Val) - // collationName := convertInterfaceToStrNullable(rowValues[5].Val) - // ordinalPosition := convertInterfaceToStrNullable(rowValues[6].Val) - // result[tableName] = append(result[tableName], &proto.ColumnMetadata{ - // Name: columnName, - // DataType: dataType, - // Ordinal: ordinalPosition, - // PrimaryKey: strings.EqualFold("PRI", columnKey), - // Generated: strings.EqualFold("auto_increment", extra), - // CaseSensitive: columnKey != "" && !strings.HasSuffix(collationName, "_ci"), - // }) - //} - return result -} - -func convertInterfaceToStrNullable(value interface{}) string { - if value != nil { - return string(value.([]byte)) - } - return "" -} - -func (l *SimpleSchemaLoader) LoadIndexMetadata(ctx context.Context, conn proto.VConn, schema string, tables []string) map[string][]*proto.IndexMetadata { - resultSet, err := conn.Query(ctx, schema, getIndexMetadataSQL(tables)) - if err != nil { - return nil - } - - if closer, ok := resultSet.(io.Closer); ok { - defer func() { - _ = closer.Close() - }() - } - - result := make(map[string][]*proto.IndexMetadata, 0) - - row := resultSet.GetRows()[0] - var rowIter mysql.Iter - switch r := row.(type) { - case *mysql.BinaryIterRow: - rowIter = r - case *mysql.TextIterRow: - rowIter = r - } - - var ( - has bool - rowValues []*proto.Value - ) - for has, err = rowIter.Next(); has && err == nil; has, err = rowIter.Next() { - if rowValues, err = row.Decode(); err != nil { - return nil - } - tableName := convertInterfaceToStrNullable(rowValues[0].Val) - indexName := convertInterfaceToStrNullable(rowValues[1].Val) - result[tableName] = append(result[tableName], &proto.IndexMetadata{Name: indexName}) - } - - //for _, row := range resultSet.GetRows() { - // var innerRow mysql.Row - // switch r := row.(type) { - // case *mysql.BinaryRow: - // innerRow = r.Row - // case *mysql.Row: - // innerRow = *r - // case *mysql.TextRow: - // innerRow = r.Row - // } - // textRow := mysql.TextRow{Row: innerRow} - // rowValues, err := textRow.Decode() - // if err != nil { - // log.Errorf("Load ColumnMetadata error when decode text row: %v", err) - // return nil - // } - // tableName := convertInterfaceToStrNullable(rowValues[0].Val) - // indexName := convertInterfaceToStrNullable(rowValues[1].Val) - // result[tableName] = append(result[tableName], &proto.IndexMetadata{Name: indexName}) - //} - - return result -} - -func getIndexMetadataSQL(tables []string) string { - tableParamList := make([]string, 0, len(tables)) - for _, table := range tables { - tableParamList = append(tableParamList, "'"+table+"'") - } - return fmt.Sprintf(indexMetadataSQL, strings.Join(tableParamList, ",")) -} - -func getColumnMetadataSQL(tables []string) string { - if len(tables) == 0 { - return tableMetadataSQL - } - tableParamList := make([]string, len(tables)) - for i, table := range tables { - tableParamList[i] = "'" + table + "'" - } - // TODO use strings.Builder in the future - return fmt.Sprintf(tableMetadataSQLInTables, strings.Join(tableParamList, ",")) -} diff --git a/pkg/proto/stmt.go b/pkg/proto/stmt.go new file mode 100644 index 000000000..d37c14e19 --- /dev/null +++ b/pkg/proto/stmt.go @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package proto + +import ( + "github.com/arana-db/parser/ast" +) + +import ( + "github.com/arana-db/arana/pkg/proto/hint" +) + +// Stmt is a buffer used for store prepare statement metadata. +type Stmt struct { + StatementID uint32 + PrepareStmt string + ParamsCount uint16 + ParamsType []int32 + ColumnNames []string + BindVars map[string]interface{} + Hints []*hint.Hint + StmtNode ast.StmtNode +} diff --git a/pkg/resultx/resultx.go b/pkg/resultx/resultx.go new file mode 100644 index 000000000..d6b6266d6 --- /dev/null +++ b/pkg/resultx/resultx.go @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package resultx + +import ( + "github.com/arana-db/arana/pkg/proto" +) + +var ( + _ proto.Result = (*emptyResult)(nil) // contains nothing + _ proto.Result = (*slimResult)(nil) // only contains rows-affected and last-insert-id, design for exec + _ proto.Result = (*dsResult)(nil) // only contains dataset, design for query + _ proto.Result = (*fullResult)(nil) // contains all +) + +type option struct { + ds proto.Dataset + id, affected uint64 +} + +// Option represents the option to create a result. +type Option func(*option) + +// WithLastInsertID specify the last-insert-id for the result to be created. +func WithLastInsertID(id uint64) Option { + return func(o *option) { + o.id = id + } +} + +// WithRowsAffected specify the rows-affected for the result to be created. +func WithRowsAffected(n uint64) Option { + return func(o *option) { + o.affected = n + } +} + +// WithDataset specify the dataset for the result to be created. +func WithDataset(d proto.Dataset) Option { + return func(o *option) { + o.ds = d + } +} + +// New creates a result from some options. +func New(options ...Option) proto.Result { + var o option + for _, it := range options { + it(&o) + } + + // When execute EXEC, no need to specify dataset. + if o.ds == nil { + if o.id == 0 && o.affected == 0 { + return emptyResult{} + } + return slimResult{o.id, o.affected} + } + + // When execute QUERY, only dataset is required. + if o.id == 0 && o.affected == 0 { + return dsResult{ds: o.ds} + } + + // should never happen + return fullResult{ + ds: o.ds, + id: o.id, + affected: o.affected, + } +} + +type emptyResult struct{} + +func (n emptyResult) Dataset() (proto.Dataset, error) { + return nil, nil +} + +func (n emptyResult) LastInsertId() (uint64, error) { + return 0, nil +} + +func (n emptyResult) RowsAffected() (uint64, error) { + return 0, nil +} + +type slimResult [2]uint64 // [lastInsertId,rowsAffected] + +func (h slimResult) Dataset() (proto.Dataset, error) { + return nil, nil +} + +func (h slimResult) LastInsertId() (uint64, error) { + return h[0], nil +} + +func (h slimResult) RowsAffected() (uint64, error) { + return h[1], nil +} + +type fullResult struct { + ds proto.Dataset + id uint64 + affected uint64 +} + +func (f fullResult) Dataset() (proto.Dataset, error) { + return f.ds, nil +} + +func (f fullResult) LastInsertId() (uint64, error) { + return f.id, nil +} + +func (f fullResult) RowsAffected() (uint64, error) { + return f.affected, nil +} + +type dsResult struct { + ds proto.Dataset +} + +func (d dsResult) Dataset() (proto.Dataset, error) { + return d.ds, nil +} + +func (d dsResult) LastInsertId() (uint64, error) { + return 0, nil +} + +func (d dsResult) RowsAffected() (uint64, error) { + return 0, nil +} + +func Drain(result proto.Result) { + if d, _ := result.Dataset(); d != nil { + defer func() { + _ = d.Close() + }() + } + return +} diff --git a/pkg/runtime/ast/alter_table.go b/pkg/runtime/ast/alter_table.go index a60f1596c..eab6fd8dd 100644 --- a/pkg/runtime/ast/alter_table.go +++ b/pkg/runtime/ast/alter_table.go @@ -179,5 +179,5 @@ func (at *AlterTableStatement) CntParams() int { } func (at *AlterTableStatement) Mode() SQLType { - return SalterTable + return SQLTypeAlterTable } diff --git a/pkg/runtime/ast/ast.go b/pkg/runtime/ast/ast.go index 5c6f7e897..6a02d7da5 100644 --- a/pkg/runtime/ast/ast.go +++ b/pkg/runtime/ast/ast.go @@ -34,20 +34,19 @@ import ( ) import ( + "github.com/arana-db/arana/pkg/proto/hint" "github.com/arana-db/arana/pkg/runtime/cmp" "github.com/arana-db/arana/pkg/runtime/logical" ) -var ( - _opcode2comparison = map[opcode.Op]cmp.Comparison{ - opcode.EQ: cmp.Ceq, - opcode.NE: cmp.Cne, - opcode.LT: cmp.Clt, - opcode.GT: cmp.Cgt, - opcode.LE: cmp.Clte, - opcode.GE: cmp.Cgte, - } -) +var _opcode2comparison = map[opcode.Op]cmp.Comparison{ + opcode.EQ: cmp.Ceq, + opcode.NE: cmp.Cne, + opcode.LT: cmp.Clt, + opcode.GT: cmp.Cgt, + opcode.LE: cmp.Clte, + opcode.GE: cmp.Cgte, +} type ( parseOption struct { @@ -98,7 +97,7 @@ func FromStmtNode(node ast.StmtNode) (Statement, error) { } switch tgt := result.(type) { case *ShowColumns: - return &DescribeStatement{Table: tgt.tableName}, nil + return &DescribeStatement{Table: tgt.TableName, Column: tgt.Column}, nil default: return &ExplainStatement{tgt: tgt}, nil } @@ -108,11 +107,52 @@ func FromStmtNode(node ast.StmtNode) (Statement, error) { return cc.convDropTableStmt(stmt), nil case *ast.AlterTableStmt: return cc.convAlterTableStmt(stmt), nil + case *ast.DropIndexStmt: + return cc.convDropIndexStmt(stmt), nil + case *ast.DropTriggerStmt: + return cc.convDropTrigger(stmt), nil + case *ast.CreateIndexStmt: + return cc.convCreateIndexStmt(stmt), nil default: return nil, errors.Errorf("unimplement: stmt type %T!", stmt) } } +func (cc *convCtx) convDropIndexStmt(stmt *ast.DropIndexStmt) *DropIndexStatement { + var tableName TableName + if db := stmt.Table.Schema.O; len(db) > 0 { + tableName = append(tableName, db) + } + tableName = append(tableName, stmt.Table.Name.O) + return &DropIndexStatement{ + IfExists: stmt.IfExists, + IndexName: stmt.IndexName, + Table: tableName, + } +} + +func (cc *convCtx) convCreateIndexStmt(stmt *ast.CreateIndexStmt) *CreateIndexStatement { + var tableName TableName + if db := stmt.Table.Schema.O; len(db) > 0 { + tableName = append(tableName, db) + } + tableName = append(tableName, stmt.Table.Name.O) + + keys := make([]*IndexPartSpec, len(stmt.IndexPartSpecifications)) + for i, k := range stmt.IndexPartSpecifications { + keys[i] = &IndexPartSpec{ + Column: cc.convColumn(k.Column), + Expr: toExpressionNode(cc.convExpr(k.Expr)), + } + } + + return &CreateIndexStatement{ + Table: tableName, + IndexName: stmt.IndexName, + Keys: keys, + } +} + func (cc *convCtx) convAlterTableStmt(stmt *ast.AlterTableStmt) *AlterTableStatement { var tableName TableName if db := stmt.Table.Schema.O; len(db) > 0 { @@ -282,7 +322,7 @@ func (cc *convCtx) convConstraint(c *ast.Constraint) *Constraint { } func (cc *convCtx) convDropTableStmt(stmt *ast.DropTableStmt) *DropTableStatement { - var tables = make([]*TableName, len(stmt.Tables)) + tables := make([]*TableName, len(stmt.Tables)) for i, table := range stmt.Tables { tables[i] = &TableName{ table.Name.String(), @@ -506,6 +546,23 @@ func (cc *convCtx) convInsertStmt(stmt *ast.InsertStmt) Statement { } } + if stmt.Select != nil { + switch v := stmt.Select.(type) { + case *ast.SelectStmt: + return &InsertSelectStatement{ + baseInsertStatement: &bi, + sel: cc.convSelectStmt(v), + duplicatedUpdates: updates, + } + case *ast.SetOprStmt: + return &InsertSelectStatement{ + baseInsertStatement: &bi, + unionSel: cc.convUnionStmt(v), + duplicatedUpdates: updates, + } + } + } + return &InsertStatement{ baseInsertStatement: &bi, values: values, @@ -520,12 +577,30 @@ func (cc *convCtx) convTruncateTableStmt(node *ast.TruncateTableStmt) Statement } func (cc *convCtx) convShowStmt(node *ast.ShowStmt) Statement { + toIn := func(node *ast.ShowStmt) (string, bool) { + if node.DBName == "" { + return "", false + } + return node.DBName, true + } + toFrom := func(node *ast.ShowStmt) (FromTable, bool) { + if node.Table == nil { + return "", false + } + return FromTable(node.Table.Name.String()), true + } toWhere := func(node *ast.ShowStmt) (ExpressionNode, bool) { if node.Where == nil { return nil, false } return toExpressionNode(cc.convExpr(node.Where)), true } + toShowLike := func(node *ast.ShowStmt) (PredicateNode, bool) { + if node.Pattern == nil { + return nil, false + } + return cc.convPatternLikeExpr(node.Pattern), true + } toLike := func(node *ast.ShowStmt) (string, bool) { if node.Pattern == nil { return "", false @@ -535,15 +610,23 @@ func (cc *convCtx) convShowStmt(node *ast.ShowStmt) Statement { toBaseShow := func() *baseShow { var bs baseShow - if like, ok := toLike(node); ok { + if like, ok := toShowLike(node); ok { bs.filter = like } else if where, ok := toWhere(node); ok { bs.filter = where + } else if in, ok := toIn(node); ok { + bs.filter = in + } else if from, ok := toFrom(node); ok { + bs.filter = from } return &bs } switch node.Tp { + case ast.ShowTopology: + return &ShowTopology{baseShow: toBaseShow()} + case ast.ShowOpenTables: + return &ShowOpenTables{baseShow: toBaseShow()} case ast.ShowTables: return &ShowTables{baseShow: toBaseShow()} case ast.ShowDatabases: @@ -555,7 +638,7 @@ func (cc *convCtx) convShowStmt(node *ast.ShowStmt) Statement { } case ast.ShowIndex: ret := &ShowIndex{ - tableName: []string{node.Table.Name.O}, + TableName: []string{node.Table.Name.O}, } if where, ok := toWhere(node); ok { ret.where = where @@ -563,7 +646,10 @@ func (cc *convCtx) convShowStmt(node *ast.ShowStmt) Statement { return ret case ast.ShowColumns: ret := &ShowColumns{ - tableName: []string{node.Table.Name.O}, + TableName: []string{node.Table.Name.O}, + } + if node.Column != nil { + ret.Column = node.Column.Name.O } if node.Extended { ret.flag |= scFlagExtended @@ -599,28 +685,46 @@ func convInsertColumns(columnNames []*ast.ColumnName) []string { } // Parse parses the SQL string to Statement. -func Parse(sql string, options ...ParseOption) (Statement, error) { +func Parse(sql string, options ...ParseOption) ([]*hint.Hint, Statement, error) { var o parseOption for _, it := range options { it(&o) } p := parser.New() - s, err := p.ParseOneStmt(sql, o.charset, o.collation) + s, hintStrs, err := p.ParseOneStmtHints(sql, o.charset, o.collation) + if err != nil { + return nil, nil, err + } + + stmt, err := FromStmtNode(s) if err != nil { - return nil, err + return nil, nil, err + } + + if len(hintStrs) < 1 { + return nil, stmt, nil + } + + hints := make([]*hint.Hint, 0, len(hintStrs)) + for _, it := range hintStrs { + var h *hint.Hint + if h, err = hint.Parse(it); err != nil { + return nil, nil, errors.WithStack(err) + } + hints = append(hints, h) } - return FromStmtNode(s) + return hints, stmt, nil } // MustParse parses the SQL string to Statement, panic if failed. -func MustParse(sql string) Statement { - stmt, err := Parse(sql) +func MustParse(sql string) ([]*hint.Hint, Statement) { + hints, stmt, err := Parse(sql) if err != nil { panic(err.Error()) } - return stmt + return hints, stmt } type convCtx struct { @@ -981,9 +1085,7 @@ func (cc *convCtx) convAggregateFuncExpr(node *ast.AggregateFuncExpr) PredicateN } func (cc *convCtx) convFuncCallExpr(expr *ast.FuncCallExpr) PredicateNode { - var ( - fnName = strings.ToUpper(expr.FnName.O) - ) + fnName := strings.ToUpper(expr.FnName.O) // NOTICE: tidb-parser cannot process CONVERT('foobar' USING utf8). // It should be a CastFunc, but now will be parsed as a FuncCall. @@ -1101,13 +1203,13 @@ func (cc *convCtx) toArg(arg ast.ExprNode) *FunctionArg { func (cc *convCtx) convPatternLikeExpr(expr *ast.PatternLikeExpr) PredicateNode { var ( - left = cc.convExpr(expr.Expr) - right = cc.convExpr(expr.Pattern) + left, _ = cc.convExpr(expr.Expr).(PredicateNode) + right, _ = cc.convExpr(expr.Pattern).(PredicateNode) ) return &LikePredicateNode{ Not: expr.Not, - Left: left.(PredicateNode), - Right: right.(PredicateNode), + Left: left, + Right: right, } } @@ -1199,7 +1301,6 @@ func (cc *convCtx) convValueExpr(expr ast.ValueExpr) PredicateNode { default: if val == nil { atom = &ConstantExpressionAtom{Inner: Null{}} - } else { atom = &ConstantExpressionAtom{Inner: val} } @@ -1283,9 +1384,7 @@ func (cc *convCtx) convBinaryOperationExpr(expr *ast.BinaryOperationExpr) interf Right: right.(*AtomPredicateNode).A, }} case opcode.EQ, opcode.NE, opcode.GT, opcode.GE, opcode.LT, opcode.LE: - var ( - op = _opcode2comparison[expr.Op] - ) + op := _opcode2comparison[expr.Op] if !isColumnAtom(left.(PredicateNode)) && isColumnAtom(right.(PredicateNode)) { // do reverse: @@ -1389,6 +1488,15 @@ func (cc *convCtx) convTableName(val *ast.TableName, tgt *TableSourceNode) { tgt.partitions = partitions } +func (cc *convCtx) convDropTrigger(stmt *ast.DropTriggerStmt) *DropTriggerStatement { + var tableName TableName + if db := stmt.Trigger.Schema.O; len(db) > 0 { + tableName = append(tableName, db) + } + tableName = append(tableName, stmt.Trigger.Name.O) + return &DropTriggerStatement{Table: tableName, IfExists: stmt.IfExists} +} + func toExpressionNode(src interface{}) ExpressionNode { if src == nil { return nil diff --git a/pkg/runtime/ast/ast_test.go b/pkg/runtime/ast/ast_test.go index 90b4a5651..50a10a51d 100644 --- a/pkg/runtime/ast/ast_test.go +++ b/pkg/runtime/ast/ast_test.go @@ -46,39 +46,39 @@ func TestParse(t *testing.T) { "select * from student where uid = !0", } { t.Run(sql, func(t *testing.T) { - stmt, err = Parse(sql) + _, stmt, err = Parse(sql) assert.NoError(t, err) t.Log("stmt:", stmt) }) } // 1. select statement - stmt, err = Parse("select * from student as foo where `name` = if(1>2, 1, 2) order by age") + _, stmt, err = Parse("select * from student as foo where `name` = if(1>2, 1, 2) order by age") assert.NoError(t, err, "parse+conv ast failed") t.Logf("stmt:%+v", stmt) // 2. delete statement - deleteStmt, err := Parse("delete from student as foo where `name` = if(1>2, 1, 2)") + _, deleteStmt, err := Parse("delete from student as foo where `name` = if(1>2, 1, 2)") assert.NoError(t, err, "parse+conv ast failed") t.Logf("stmt:%+v", deleteStmt) // 3. insert statements - insertStmtWithSetClause, err := Parse("insert into sink set a=77, b='88'") + _, insertStmtWithSetClause, err := Parse("insert into sink set a=77, b='88'") assert.NoError(t, err, "parse+conv ast failed") t.Logf("stmt:%+v", insertStmtWithSetClause) - insertStmtWithValues, err := Parse("insert into sink values(1, '2')") + _, insertStmtWithValues, err := Parse("insert into sink values(1, '2')") assert.NoError(t, err, "parse+conv ast failed") t.Logf("stmt:%+v", insertStmtWithValues) - insertStmtWithOnDuplicateUpdates, err := Parse( + _, insertStmtWithOnDuplicateUpdates, err := Parse( "insert into sink (a, b) values(1, '2') on duplicate key update a=a+1", ) assert.NoError(t, err, "parse+conv ast failed") t.Logf("stmt:%+v", insertStmtWithOnDuplicateUpdates) // 4. update statement - updateStmt, err := Parse( + _, updateStmt, err := Parse( "update source set a=a+1, b=b+2 where a>1 order by a limit 5", ) assert.NoError(t, err, "parse+conv ast failed") @@ -98,7 +98,7 @@ func TestParse_UnionStmt(t *testing.T) { {"select id,uid,name,nickname from student where uid in (?,?,?) union all select id,uid,name,nickname from tb_user where uid in (?,?,?)", "SELECT `id`,`uid`,`name`,`nickname` FROM `student` WHERE `uid` IN (?,?,?) UNION ALL SELECT `id`,`uid`,`name`,`nickname` FROM `tb_user` WHERE `uid` IN (?,?,?)"}, } { t.Run(next.input, func(t *testing.T) { - stmt, err := Parse(next.input) + _, stmt, err := Parse(next.input) assert.NoError(t, err, "should parse ok") assert.IsType(t, (*UnionSelectStatement)(nil), stmt, "should be union statement") @@ -106,9 +106,7 @@ func TestParse_UnionStmt(t *testing.T) { assert.NoError(t, err, "should restore ok") assert.Equal(t, next.expect, actual) }) - } - } func TestParse_SelectStmt(t *testing.T) { @@ -162,7 +160,7 @@ func TestParse_SelectStmt(t *testing.T) { {"select null as pkid", "SELECT NULL AS `pkid`"}, } { t.Run(next.input, func(t *testing.T) { - stmt, err := Parse(next.input) + _, stmt, err := Parse(next.input) assert.NoError(t, err, "should parse ok") assert.IsType(t, (*SelectStatement)(nil), stmt, "should be select statement") @@ -171,7 +169,6 @@ func TestParse_SelectStmt(t *testing.T) { assert.Equal(t, next.expect, actual) }) } - } func TestParse_DeleteStmt(t *testing.T) { @@ -185,7 +182,7 @@ func TestParse_DeleteStmt(t *testing.T) { {"delete low_priority quick ignore from student where id = 1", "DELETE LOW_PRIORITY QUICK IGNORE FROM `student` WHERE `id` = 1"}, } { t.Run(it.input, func(t *testing.T) { - stmt, err := Parse(it.input) + _, stmt, err := Parse(it.input) assert.NoError(t, err) assert.IsType(t, (*DeleteStatement)(nil), stmt, "should be delete statement") @@ -206,7 +203,7 @@ func TestParse_DescribeStatement(t *testing.T) { {"desc foobar", "DESC `foobar`"}, } { t.Run(it.input, func(t *testing.T) { - stmt, err := Parse(it.input) + _, stmt, err := Parse(it.input) assert.NoError(t, err) assert.IsType(t, (*DescribeStatement)(nil), stmt, "should be describe statement") @@ -228,7 +225,12 @@ func TestParse_ShowStatement(t *testing.T) { {"show databases", (*ShowDatabases)(nil), "SHOW DATABASES"}, {"show databases like '%foo%'", (*ShowDatabases)(nil), "SHOW DATABASES LIKE '%foo%'"}, {"show databases where name = 'foobar'", (*ShowDatabases)(nil), "SHOW DATABASES WHERE `name` = 'foobar'"}, + {"show open tables", (*ShowOpenTables)(nil), "SHOW OPEN TABLES"}, + {"show open tables in foobar", (*ShowOpenTables)(nil), "SHOW OPEN TABLES IN `foobar`"}, + {"show open tables like '%foo%'", (*ShowOpenTables)(nil), "SHOW OPEN TABLES LIKE '%foo%'"}, + {"show open tables where name = 'foo'", (*ShowOpenTables)(nil), "SHOW OPEN TABLES WHERE `name` = 'foo'"}, {"show tables", (*ShowTables)(nil), "SHOW TABLES"}, + {"show tables in foobar", (*ShowTables)(nil), "SHOW TABLES IN `foobar`"}, {"show tables like '%foo%'", (*ShowTables)(nil), "SHOW TABLES LIKE '%foo%'"}, {"show tables where name = 'foo'", (*ShowTables)(nil), "SHOW TABLES WHERE `name` = 'foo'"}, {"sHow indexes from foo", (*ShowIndex)(nil), "SHOW INDEXES FROM `foo`"}, @@ -238,7 +240,7 @@ func TestParse_ShowStatement(t *testing.T) { {"show create table `foo`", (*ShowCreate)(nil), "SHOW CREATE TABLE `foo`"}, } { t.Run(it.input, func(t *testing.T) { - stmt, err := Parse(it.input) + _, stmt, err := Parse(it.input) assert.NoError(t, err) assert.IsTypef(t, it.expectTyp, stmt, "should be %T", it.expectTyp) @@ -247,11 +249,10 @@ func TestParse_ShowStatement(t *testing.T) { assert.Equal(t, it.expect, actual) }) } - } func TestParse_ExplainStmt(t *testing.T) { - stmt, err := Parse("explain select * from student where uid = 1") + _, stmt, err := Parse("explain select * from student where uid = 1") assert.NoError(t, err) assert.IsType(t, (*ExplainStatement)(nil), stmt) s := MustRestoreToString(RestoreDefault, stmt) @@ -286,7 +287,7 @@ func TestParseMore(t *testing.T) { for _, sql := range tbls { t.Run(sql, func(t *testing.T) { - _, err := Parse(sql) + _, _, err := Parse(sql) assert.NoError(t, err) }) } @@ -303,7 +304,7 @@ func TestParse_UpdateStmt(t *testing.T) { {"update low_priority student set nickname = ? where id = 1 limit 1", "UPDATE LOW_PRIORITY `student` SET `nickname` = ? WHERE `id` = 1 LIMIT 1"}, } { t.Run(it.input, func(t *testing.T) { - stmt, err := Parse(it.input) + _, stmt, err := Parse(it.input) assert.NoError(t, err) assert.IsTypef(t, (*UpdateStatement)(nil), stmt, "should be update statement") @@ -311,7 +312,6 @@ func TestParse_UpdateStmt(t *testing.T) { assert.NoError(t, err, "should restore ok") assert.Equal(t, it.expect, actual) }) - } } @@ -337,7 +337,7 @@ func TestParse_InsertStmt(t *testing.T) { }, } { t.Run(it.input, func(t *testing.T) { - stmt, err := Parse(it.input) + _, stmt, err := Parse(it.input) assert.NoError(t, err) assert.IsTypef(t, (*InsertStatement)(nil), stmt, "should be insert statement") @@ -347,10 +347,38 @@ func TestParse_InsertStmt(t *testing.T) { }) } + for _, it := range []tt{ + { + "insert into student select * from student_tmp", + "INSERT INTO `student` SELECT * FROM `student_tmp`", + }, + { + "insert into student(id,name) select emp_no, name from employees limit 10,2", + "INSERT INTO `student`(`id`, `name`) SELECT `emp_no`,`name` FROM `employees` LIMIT 10,2", + }, + { + "insert into student(id,name) select emp_no, name from employees on duplicate key update version=version+1,modified_at=NOW()", + "INSERT INTO `student`(`id`, `name`) SELECT `emp_no`,`name` FROM `employees` ON DUPLICATE KEY UPDATE `version` = `version`+1, `modified_at` = NOW()", + }, + { + "insert student select id, score from student_tmp union select id * 10, score * 10 from student_tmp", + "INSERT INTO `student` SELECT `id`,`score` FROM `student_tmp` UNION SELECT `id`*10,`score`*10 FROM `student_tmp`", + }, + } { + t.Run(it.input, func(t *testing.T) { + _, stmt, err := Parse(it.input) + assert.NoError(t, err) + assert.IsTypef(t, (*InsertSelectStatement)(nil), stmt, "should be insert-select statement") + + actual, err := RestoreToString(RestoreDefault, stmt.(Restorer)) + assert.NoError(t, err, "should restore ok") + assert.Equal(t, it.expect, actual) + }) + } } func TestRestoreCount(t *testing.T) { - stmt := MustParse("select count(1)") + _, stmt := MustParse("select count(1)") sel := stmt.(*SelectStatement) var sb strings.Builder _ = sel.Restore(RestoreDefault, &sb, nil) @@ -358,7 +386,7 @@ func TestRestoreCount(t *testing.T) { } func TestQuote(t *testing.T) { - stmt := MustParse("select `a``bc`") + _, stmt := MustParse("select `a``bc`") sel := stmt.(*SelectStatement) var sb strings.Builder _ = sel.Restore(RestoreDefault, &sb, nil) @@ -403,7 +431,7 @@ func TestParse_AlterTableStmt(t *testing.T) { }, } { t.Run(it.input, func(t *testing.T) { - stmt, err := Parse(it.input) + _, stmt, err := Parse(it.input) assert.NoError(t, err) assert.IsTypef(t, (*AlterTableStatement)(nil), stmt, "should be alter table statement") @@ -412,5 +440,15 @@ func TestParse_AlterTableStmt(t *testing.T) { assert.Equal(t, it.expect, actual) }) } +} +func TestParse_DescStmt(t *testing.T) { + _, stmt := MustParse("desc student id") + // In MySQL, the case of "desc student 'id'" will be parsed successfully, + // but in arana, it will get an error by tidb parser. + desc := stmt.(*DescribeStatement) + var sb strings.Builder + _ = desc.Restore(RestoreDefault, &sb, nil) + t.Logf(sb.String()) + assert.Equal(t, "DESC `student` `id`", sb.String()) } diff --git a/pkg/runtime/ast/create_index.go b/pkg/runtime/ast/create_index.go new file mode 100644 index 000000000..050713bc0 --- /dev/null +++ b/pkg/runtime/ast/create_index.go @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ast + +import ( + "strings" +) + +var ( + _ Statement = (*CreateIndexStatement)(nil) + _ Restorer = (*CreateIndexStatement)(nil) +) + +type CreateIndexStatement struct { + IndexName string + Table TableName + Keys []*IndexPartSpec +} + +func (c *CreateIndexStatement) CntParams() int { + return 0 +} + +func (c *CreateIndexStatement) Restore(flag RestoreFlag, sb *strings.Builder, args *[]int) error { + sb.WriteString("CREATE INDEX ") + sb.WriteString(c.IndexName) + if len(c.Table) == 0 { + return nil + } + sb.WriteString(" ON ") + if err := c.Table.Restore(flag, sb, args); err != nil { + return err + } + + sb.WriteString(" (") + for i, k := range c.Keys { + if i != 0 { + sb.WriteString(", ") + } + if err := k.Restore(flag, sb, args); err != nil { + return err + } + } + sb.WriteString(")") + return nil +} + +func (c *CreateIndexStatement) Validate() error { + return nil +} + +func (c *CreateIndexStatement) Mode() SQLType { + return SQLTypeCreateIndex +} diff --git a/pkg/runtime/ast/delete.go b/pkg/runtime/ast/delete.go index 90987637a..b8b38f05c 100644 --- a/pkg/runtime/ast/delete.go +++ b/pkg/runtime/ast/delete.go @@ -109,7 +109,7 @@ func (ds *DeleteStatement) CntParams() int { } func (ds *DeleteStatement) Mode() SQLType { - return Sdelete + return SQLTypeDelete } func (ds *DeleteStatement) IsLowPriority() bool { diff --git a/pkg/runtime/ast/describe.go b/pkg/runtime/ast/describe.go index 3107e2831..3523ff482 100644 --- a/pkg/runtime/ast/describe.go +++ b/pkg/runtime/ast/describe.go @@ -33,7 +33,7 @@ var ( // DescribeStatement represents mysql describe statement. see https://dev.mysql.com/doc/refman/8.0/en/describe.html type DescribeStatement struct { Table TableName - column string + Column string } // Restore implements Restorer. @@ -42,9 +42,9 @@ func (d *DescribeStatement) Restore(flag RestoreFlag, sb *strings.Builder, args if err := d.Table.Restore(flag, sb, args); err != nil { return errors.WithStack(err) } - if len(d.column) > 0 { + if len(d.Column) > 0 { sb.WriteByte(' ') - WriteID(sb, d.column) + WriteID(sb, d.Column) } return nil @@ -59,14 +59,7 @@ func (d *DescribeStatement) CntParams() int { } func (d *DescribeStatement) Mode() SQLType { - return Squery -} - -func (d *DescribeStatement) Column() (string, bool) { - if len(d.column) > 0 { - return d.column, true - } - return "", false + return SQLTypeDescribe } // ExplainStatement represents mysql explain statement. see https://dev.mysql.com/doc/refman/8.0/en/explain.html @@ -95,5 +88,5 @@ func (e *ExplainStatement) CntParams() int { } func (e *ExplainStatement) Mode() SQLType { - return Squery + return SQLTypeSelect } diff --git a/pkg/runtime/ast/drop_index.go b/pkg/runtime/ast/drop_index.go new file mode 100644 index 000000000..58918a21b --- /dev/null +++ b/pkg/runtime/ast/drop_index.go @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ast + +import ( + "strings" +) + +var ( + _ Statement = (*DropIndexStatement)(nil) + _ Restorer = (*DropIndexStatement)(nil) +) + +type DropIndexStatement struct { + IfExists bool + IndexName string + Table TableName +} + +func (d *DropIndexStatement) CntParams() int { + return 0 +} + +func (d *DropIndexStatement) Restore(flag RestoreFlag, sb *strings.Builder, args *[]int) error { + sb.WriteString("DROP INDEX ") + if d.IfExists { + sb.WriteString("IF EXISTS") + } + sb.WriteString(d.IndexName) + if len(d.Table) == 0 { + return nil + } + sb.WriteString(" ON ") + return d.Table.Restore(flag, sb, args) +} + +func (d *DropIndexStatement) Validate() error { + return nil +} + +func (d *DropIndexStatement) Mode() SQLType { + return SQLTypeDropIndex +} diff --git a/pkg/runtime/ast/drop_table.go b/pkg/runtime/ast/drop_table.go index 1bf3b63cb..1b2778711 100644 --- a/pkg/runtime/ast/drop_table.go +++ b/pkg/runtime/ast/drop_table.go @@ -55,5 +55,5 @@ func (d DropTableStatement) Validate() error { } func (d DropTableStatement) Mode() SQLType { - return SdropTable + return SQLTypeDropTable } diff --git a/pkg/runtime/ast/function.go b/pkg/runtime/ast/function.go index 6c7402983..c297d8fd6 100644 --- a/pkg/runtime/ast/function.go +++ b/pkg/runtime/ast/function.go @@ -48,8 +48,10 @@ const ( Fpasswd ) -type FunctionArgType uint8 -type FunctionType uint8 +type ( + FunctionArgType uint8 + FunctionType uint8 +) func (f FunctionType) String() string { switch f { @@ -535,9 +537,7 @@ func (cd *ConvertDataType) Parse(s string) error { return errors.Errorf("invalid cast string '%s'", s) } - var ( - name, first, second, suffix string - ) + var name, first, second, suffix string for i := 1; i < len(keys); i++ { sub := subs[i] switch keys[i] { diff --git a/pkg/runtime/ast/insert.go b/pkg/runtime/ast/insert.go index e173a12f8..683a14af1 100644 --- a/pkg/runtime/ast/insert.go +++ b/pkg/runtime/ast/insert.go @@ -146,7 +146,7 @@ type ReplaceStatement struct { } func (r *ReplaceStatement) Restore(flag RestoreFlag, sb *strings.Builder, args *[]int) error { - //TODO implement me + // TODO implement me panic("implement me") } @@ -159,7 +159,7 @@ func (r *ReplaceStatement) Values() [][]ExpressionNode { } func (r *ReplaceStatement) Mode() SQLType { - return Sreplace + return SQLTypeReplace } func (r *ReplaceStatement) CntParams() int { @@ -343,7 +343,7 @@ func (is *InsertStatement) CntParams() int { } func (is *InsertStatement) Mode() SQLType { - return Sinsert + return SQLTypeInsert } type ReplaceSelectStatement struct { @@ -352,7 +352,7 @@ type ReplaceSelectStatement struct { } func (r *ReplaceSelectStatement) Restore(flag RestoreFlag, sb *strings.Builder, args *[]int) error { - //TODO implement me + // TODO implement me panic("implement me") } @@ -369,21 +369,84 @@ func (r *ReplaceSelectStatement) CntParams() int { } func (r *ReplaceSelectStatement) Mode() SQLType { - return Sreplace + return SQLTypeReplace } type InsertSelectStatement struct { *baseInsertStatement - sel *SelectStatement + duplicatedUpdates []*UpdateElement + sel *SelectStatement + unionSel *UnionSelectStatement } func (is *InsertSelectStatement) Restore(flag RestoreFlag, sb *strings.Builder, args *[]int) error { - //TODO implement me - panic("implement me") + sb.WriteString("INSERT ") + + // write priority + if is.IsLowPriority() { + sb.WriteString("LOW_PRIORITY ") + } else if is.IsHighPriority() { + sb.WriteString("HIGH_PRIORITY ") + } else if is.IsDelayed() { + sb.WriteString("DELAYED ") + } + + if is.IsIgnore() { + sb.WriteString("IGNORE ") + } + + sb.WriteString("INTO ") + + if err := is.Table().Restore(flag, sb, args); err != nil { + return errors.WithStack(err) + } + + if len(is.columns) > 0 { + sb.WriteByte('(') + WriteID(sb, is.columns[0]) + for i := 1; i < len(is.columns); i++ { + sb.WriteString(", ") + WriteID(sb, is.columns[i]) + } + sb.WriteString(") ") + } else { + sb.WriteByte(' ') + } + + if is.sel != nil { + if err := is.sel.Restore(flag, sb, args); err != nil { + return errors.WithStack(err) + } + } + + if is.unionSel != nil { + if err := is.unionSel.Restore(flag, sb, args); err != nil { + return errors.WithStack(err) + } + } + + if len(is.duplicatedUpdates) > 0 { + sb.WriteString(" ON DUPLICATE KEY UPDATE ") + + if err := is.duplicatedUpdates[0].Restore(flag, sb, args); err != nil { + return errors.WithStack(err) + } + for i := 1; i < len(is.duplicatedUpdates); i++ { + sb.WriteString(", ") + if err := is.duplicatedUpdates[i].Restore(flag, sb, args); err != nil { + return errors.WithStack(err) + } + } + } + + return nil } func (is *InsertSelectStatement) Validate() error { - return nil + if is.unionSel != nil { + return is.unionSel.Validate() + } + return is.sel.Validate() } func (is *InsertSelectStatement) Select() *SelectStatement { @@ -391,9 +454,12 @@ func (is *InsertSelectStatement) Select() *SelectStatement { } func (is *InsertSelectStatement) CntParams() int { + if is.unionSel != nil { + return is.unionSel.CntParams() + } return is.sel.CntParams() } func (is *InsertSelectStatement) Mode() SQLType { - return Sinsert + return SQLTypeInsertSelect } diff --git a/pkg/runtime/ast/model.go b/pkg/runtime/ast/model.go index f6bfcb5b9..cac08dcb0 100644 --- a/pkg/runtime/ast/model.go +++ b/pkg/runtime/ast/model.go @@ -310,10 +310,22 @@ func (ln *LimitNode) SetLimitVar() { ln.flag |= flagLimitLimitVar } +func (ln *LimitNode) UnsetOffsetVar() { + ln.flag &= ^flagLimitOffsetVar +} + +func (ln *LimitNode) UnsetLimitVar() { + ln.flag &= ^flagLimitLimitVar +} + func (ln *LimitNode) SetHasOffset() { ln.flag |= flagLimitHasOffset } +func (ln *LimitNode) UnsetHasOffset() { + ln.flag &= ^flagLimitHasOffset +} + func (ln *LimitNode) HasOffset() bool { return ln.flag&flagLimitHasOffset != 0 } diff --git a/pkg/runtime/ast/predicate.go b/pkg/runtime/ast/predicate.go index 64dd379a4..5a26ef993 100644 --- a/pkg/runtime/ast/predicate.go +++ b/pkg/runtime/ast/predicate.go @@ -74,8 +74,10 @@ func (l *LikePredicateNode) InTables(tables map[string]struct{}) error { } func (l *LikePredicateNode) Restore(flag RestoreFlag, sb *strings.Builder, args *[]int) error { - if err := l.Left.Restore(flag, sb, args); err != nil { - return errors.WithStack(err) + if l.Left != nil { + if err := l.Left.Restore(flag, sb, args); err != nil { + return errors.WithStack(err) + } } if l.Not { @@ -83,9 +85,10 @@ func (l *LikePredicateNode) Restore(flag RestoreFlag, sb *strings.Builder, args } else { sb.WriteString(" LIKE ") } - - if err := l.Right.Restore(flag, sb, args); err != nil { - return errors.WithStack(err) + if l.Right != nil { + if err := l.Right.Restore(flag, sb, args); err != nil { + return errors.WithStack(err) + } } return nil diff --git a/pkg/runtime/ast/proto.go b/pkg/runtime/ast/proto.go index c5109dcdb..4923a714b 100644 --- a/pkg/runtime/ast/proto.go +++ b/pkg/runtime/ast/proto.go @@ -22,15 +22,29 @@ import ( ) const ( - _ SQLType = iota - Squery // QUERY - Sdelete // DELETE - Supdate // UPDATE - Sinsert // INSERT - Sreplace // REPLACE - Struncate // TRUNCATE - SdropTable // DROP TABLE - SalterTable // ALTER TABLE + _ SQLType = iota + SQLTypeSelect // SELECT + SQLTypeDelete // DELETE + SQLTypeUpdate // UPDATE + SQLTypeInsert // INSERT + SQLTypeInsertSelect // INSERT SELECT + SQLTypeReplace // REPLACE + SQLTypeTruncate // TRUNCATE + SQLTypeDropTable // DROP TABLE + SQLTypeAlterTable // ALTER TABLE + SQLTypeDropIndex // DROP INDEX + SQLTypeShowDatabases // SHOW DATABASES + SQLTypeShowTables // SHOW TABLES + SQLTypeShowOpenTables // SHOW OPEN TABLES + SQLTypeShowIndex // SHOW INDEX + SQLTypeShowColumns // SHOW COLUMNS + SQLTypeShowCreate // SHOW CREATE + SQLTypeShowVariables // SHOW VARIABLES + SQLTypeShowTopology // SHOW TOPOLOGY + SQLTypeDescribe // DESCRIBE + SQLTypeUnion // UNION + SQLTypeDropTrigger // DROP TRIGGER + SQLTypeCreateIndex // CREATE INDEX ) type RestoreFlag uint32 @@ -45,14 +59,27 @@ type Restorer interface { } var _sqlTypeNames = [...]string{ - Squery: "QUERY", - Sdelete: "DELETE", - Supdate: "UPDATE", - Sinsert: "INSERT", - Sreplace: "REPLACE", - Struncate: "TRUNCATE", - SdropTable: "DROP TABLE", - SalterTable: "ALTER TABLE", + SQLTypeSelect: "SELECT", + SQLTypeDelete: "DELETE", + SQLTypeUpdate: "UPDATE", + SQLTypeInsert: "INSERT", + SQLTypeInsertSelect: "INSERT SELECT", + SQLTypeReplace: "REPLACE", + SQLTypeTruncate: "TRUNCATE", + SQLTypeDropTable: "DROP TABLE", + SQLTypeAlterTable: "ALTER TABLE", + SQLTypeDropIndex: "DROP INDEX", + SQLTypeShowDatabases: "SHOW DATABASES", + SQLTypeShowTables: "SHOW TABLES", + SQLTypeShowOpenTables: "SHOW OPEN TABLES", + SQLTypeShowIndex: "SHOW INDEX", + SQLTypeShowColumns: "SHOW COLUMNS", + SQLTypeShowCreate: "SHOW CREATE", + SQLTypeShowVariables: "SHOW VARIABLES", + SQLTypeDescribe: "DESCRIBE", + SQLTypeUnion: "UNION", + SQLTypeDropTrigger: "DROP TRIGGER", + SQLTypeCreateIndex: "CREATE INDEX", } // SQLType represents the type of SQL. diff --git a/pkg/runtime/ast/select.go b/pkg/runtime/ast/select.go index af0ebff1a..4366bffff 100644 --- a/pkg/runtime/ast/select.go +++ b/pkg/runtime/ast/select.go @@ -252,7 +252,7 @@ func (ss *SelectStatement) Validate() error { } func (ss *SelectStatement) Mode() SQLType { - return Squery + return SQLTypeSelect } func (ss *SelectStatement) CntParams() int { diff --git a/pkg/runtime/ast/show.go b/pkg/runtime/ast/show.go index 85b3f33e1..fdde805a0 100644 --- a/pkg/runtime/ast/show.go +++ b/pkg/runtime/ast/show.go @@ -28,12 +28,20 @@ import ( var ( _ Statement = (*ShowTables)(nil) + _ Statement = (*ShowOpenTables)(nil) _ Statement = (*ShowCreate)(nil) _ Statement = (*ShowDatabases)(nil) _ Statement = (*ShowColumns)(nil) _ Statement = (*ShowIndex)(nil) + _ Statement = (*ShowTopology)(nil) ) +type FromTable string + +func (f FromTable) String() string { + return string(f) +} + type baseShow struct { filter interface{} // ExpressionNode or string } @@ -41,11 +49,16 @@ type baseShow struct { func (bs *baseShow) Restore(flag RestoreFlag, sb *strings.Builder, args *[]int) error { switch val := bs.filter.(type) { case string: - sb.WriteString(" LIKE ") - sb.WriteByte('\'') + sb.WriteString(" IN ") + sb.WriteByte('`') sb.WriteString(val) - sb.WriteByte('\'') + sb.WriteByte('`') return nil + case FromTable: + sb.WriteString(val.String()) + return nil + case PredicateNode: + return val.Restore(flag, sb, nil) case ExpressionNode: sb.WriteString(" WHERE ") return val.Restore(flag, sb, args) @@ -58,6 +71,7 @@ func (bs *baseShow) Like() (string, bool) { v, ok := bs.filter.(string) return v, ok } + func (bs *baseShow) Where() (ExpressionNode, bool) { v, ok := bs.filter.(ExpressionNode) return v, ok @@ -67,14 +81,14 @@ func (bs *baseShow) CntParams() int { return 0 } -func (bs *baseShow) Mode() SQLType { - return Squery -} - type ShowDatabases struct { *baseShow } +func (s ShowDatabases) Mode() SQLType { + return SQLTypeShowDatabases +} + func (s ShowDatabases) Restore(flag RestoreFlag, sb *strings.Builder, args *[]int) error { sb.WriteString("SHOW DATABASES") if err := s.baseShow.Restore(flag, sb, args); err != nil { @@ -91,6 +105,10 @@ type ShowTables struct { *baseShow } +func (s ShowTables) Mode() SQLType { + return SQLTypeShowTables +} + func (s ShowTables) Restore(flag RestoreFlag, sb *strings.Builder, args *[]int) error { sb.WriteString("SHOW TABLES") if err := s.baseShow.Restore(flag, sb, args); err != nil { @@ -103,6 +121,42 @@ func (s ShowTables) Validate() error { return nil } +type ShowTopology struct { + *baseShow +} + +func (s ShowTopology) Mode() SQLType { + return SQLTypeShowTopology +} + +func (s ShowTopology) Restore(flag RestoreFlag, sb *strings.Builder, args *[]int) error { + return s.baseShow.Restore(flag, sb, args) +} + +func (s ShowTopology) Validate() error { + return nil +} + +type ShowOpenTables struct { + *baseShow +} + +func (s ShowOpenTables) Mode() SQLType { + return SQLTypeShowOpenTables +} + +func (s ShowOpenTables) Restore(flag RestoreFlag, sb *strings.Builder, args *[]int) error { + sb.WriteString("SHOW OPEN TABLES") + if err := s.baseShow.Restore(flag, sb, args); err != nil { + return errors.WithStack(err) + } + return nil +} + +func (s ShowOpenTables) Validate() error { + return nil +} + const ( _ ShowCreateType = iota ShowCreateTypeTable @@ -139,6 +193,13 @@ type ShowCreate struct { tgt string } +func (s *ShowCreate) ResetTable(table string) *ShowCreate { + ret := new(ShowCreate) + *ret = *s + ret.tgt = table + return ret +} + func (s *ShowCreate) Restore(flag RestoreFlag, sb *strings.Builder, args *[]int) error { sb.WriteString("SHOW CREATE ") sb.WriteString(s.typ.String()) @@ -164,11 +225,11 @@ func (s *ShowCreate) CntParams() int { } func (s *ShowCreate) Mode() SQLType { - return Squery + return SQLTypeShowCreate } type ShowIndex struct { - tableName TableName + TableName TableName where ExpressionNode } @@ -179,7 +240,7 @@ func (s *ShowIndex) Validate() error { func (s *ShowIndex) Restore(flag RestoreFlag, sb *strings.Builder, args *[]int) error { sb.WriteString("SHOW INDEXES FROM ") - _ = s.tableName.Restore(flag, sb, args) + _ = s.TableName.Restore(flag, sb, args) if where, ok := s.Where(); ok { sb.WriteString(" WHERE ") @@ -191,10 +252,6 @@ func (s *ShowIndex) Restore(flag RestoreFlag, sb *strings.Builder, args *[]int) return nil } -func (s *ShowIndex) TableName() TableName { - return s.tableName -} - func (s *ShowIndex) Where() (ExpressionNode, bool) { if s.where != nil { return s.where, true @@ -210,7 +267,7 @@ func (s *ShowIndex) CntParams() int { } func (s *ShowIndex) Mode() SQLType { - return Squery + return SQLTypeShowIndex } type showColumnsFlag uint8 @@ -224,8 +281,9 @@ const ( type ShowColumns struct { flag showColumnsFlag - tableName TableName + TableName TableName like sql.NullString + Column string } func (sh *ShowColumns) IsFull() bool { @@ -247,7 +305,7 @@ func (sh *ShowColumns) Restore(flag RestoreFlag, sb *strings.Builder, args *[]in } sb.WriteString("COLUMNS FROM ") - if err := sh.tableName.Restore(flag, sb, args); err != nil { + if err := sh.TableName.Restore(flag, sb, args); err != nil { return errors.WithStack(err) } @@ -274,7 +332,7 @@ func (sh *ShowColumns) Validate() error { } func (sh *ShowColumns) Table() TableName { - return sh.tableName + return sh.TableName } func (sh *ShowColumns) CntParams() int { @@ -282,7 +340,7 @@ func (sh *ShowColumns) CntParams() int { } func (sh *ShowColumns) Mode() SQLType { - return Squery + return SQLTypeShowColumns } func (sh *ShowColumns) Full() bool { @@ -342,5 +400,5 @@ func (s *ShowVariables) CntParams() int { } func (s *ShowVariables) Mode() SQLType { - return Squery + return SQLTypeShowVariables } diff --git a/pkg/runtime/ast/table_source.go b/pkg/runtime/ast/table_source.go index 48ed4c121..e8d9351dd 100644 --- a/pkg/runtime/ast/table_source.go +++ b/pkg/runtime/ast/table_source.go @@ -37,9 +37,7 @@ type TableSourceNode struct { func (t *TableSourceNode) ResetTableName(newTableName string) bool { switch source := t.source.(type) { case TableName: - var ( - newSource = make(TableName, len(source)) - ) + newSource := make(TableName, len(source)) copy(newSource, source) newSource[len(newSource)-1] = newTableName t.source = newSource diff --git a/pkg/runtime/ast/trigger.go b/pkg/runtime/ast/trigger.go new file mode 100644 index 000000000..31c5b40b1 --- /dev/null +++ b/pkg/runtime/ast/trigger.go @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ast + +import ( + "strings" +) + +var ( + _ Statement = (*DropTriggerStatement)(nil) +) + +type DropTriggerStatement struct { + IfExists bool + Table TableName +} + +func (d DropTriggerStatement) CntParams() int { + return 0 +} + +func (d DropTriggerStatement) Mode() SQLType { + return SQLTypeDropTrigger +} + +func (d DropTriggerStatement) Restore(flag RestoreFlag, sb *strings.Builder, args *[]int) error { + sb.WriteString("DROP TRIGGER ") + if d.IfExists { + sb.WriteString("IF EXISTS ") + } + return d.Table.Restore(flag, sb, args) +} + +func (d DropTriggerStatement) Validate() error { return nil } diff --git a/pkg/runtime/ast/truncate.go b/pkg/runtime/ast/truncate.go index aa5ea87c9..86db2114c 100644 --- a/pkg/runtime/ast/truncate.go +++ b/pkg/runtime/ast/truncate.go @@ -25,9 +25,7 @@ import ( "github.com/pkg/errors" ) -var ( - _ Statement = (*TruncateStatement)(nil) -) +var _ Statement = (*TruncateStatement)(nil) // TruncateStatement represents mysql describe statement. see https://dev.mysql.com/doc/refman/8.0/en/truncate-table.html type TruncateStatement struct { @@ -64,5 +62,5 @@ func (stmt *TruncateStatement) CntParams() int { } func (stmt *TruncateStatement) Mode() SQLType { - return Struncate + return SQLTypeTruncate } diff --git a/pkg/runtime/ast/union.go b/pkg/runtime/ast/union.go index 57ce0cbc0..61a7b07f3 100644 --- a/pkg/runtime/ast/union.go +++ b/pkg/runtime/ast/union.go @@ -112,7 +112,7 @@ func (u *UnionSelectStatement) OrderBy() OrderByNode { } func (u *UnionSelectStatement) Mode() SQLType { - return Squery + return SQLTypeUnion } func (u *UnionSelectStatement) First() *SelectStatement { diff --git a/pkg/runtime/ast/update.go b/pkg/runtime/ast/update.go index 93c855f3e..26a2038b7 100644 --- a/pkg/runtime/ast/update.go +++ b/pkg/runtime/ast/update.go @@ -143,5 +143,5 @@ func (u *UpdateStatement) CntParams() int { } func (u *UpdateStatement) Mode() SQLType { - return Supdate + return SQLTypeUpdate } diff --git a/pkg/runtime/context/context.go b/pkg/runtime/context/context.go index bd7b11284..67894172d 100644 --- a/pkg/runtime/context/context.go +++ b/pkg/runtime/context/context.go @@ -23,7 +23,7 @@ import ( import ( "github.com/arana-db/arana/pkg/proto" - "github.com/arana-db/arana/pkg/proto/rule" + "github.com/arana-db/arana/pkg/proto/hint" ) const ( @@ -34,12 +34,13 @@ const ( type ( keyFlag struct{} - keyRule struct{} keySequence struct{} keySql struct{} keyNodeLabel struct{} keySchema struct{} keyDefaultDBGroup struct{} + keyTenant struct{} + keyHints struct{} ) type cFlag uint8 @@ -60,14 +61,9 @@ func WithSQL(ctx context.Context, sql string) context.Context { return context.WithValue(ctx, keySql{}, sql) } -// WithDBGroup binds the default db. -func WithDBGroup(ctx context.Context, group string) context.Context { - return context.WithValue(ctx, keyDefaultDBGroup{}, group) -} - -// WithRule binds a rule. -func WithRule(ctx context.Context, ru *rule.Rule) context.Context { - return context.WithValue(ctx, keyRule{}, ru) +// WithTenant binds the tenant. +func WithTenant(ctx context.Context, tenant string) context.Context { + return context.WithValue(ctx, keyTenant{}, tenant) } func WithSchema(ctx context.Context, data string) context.Context { @@ -89,6 +85,11 @@ func WithRead(ctx context.Context) context.Context { return context.WithValue(ctx, keyFlag{}, _flagRead|getFlag(ctx)) } +// WithHints binds the hints. +func WithHints(ctx context.Context, hints []*hint.Hint) context.Context { + return context.WithValue(ctx, keyHints{}, hints) +} + // Sequencer extracts the sequencer. func Sequencer(ctx context.Context) proto.Sequencer { s, ok := ctx.Value(keySequence{}).(proto.Sequencer) @@ -98,24 +99,15 @@ func Sequencer(ctx context.Context) proto.Sequencer { return s } -// DBGroup extracts the db. -func DBGroup(ctx context.Context) string { - db, ok := ctx.Value(keyDefaultDBGroup{}).(string) +// Tenant extracts the tenant. +func Tenant(ctx context.Context) string { + db, ok := ctx.Value(keyTenant{}).(string) if !ok { return "" } return db } -// Rule extracts the rule. -func Rule(ctx context.Context) *rule.Rule { - ru, ok := ctx.Value(keyRule{}).(*rule.Rule) - if !ok { - return nil - } - return ru -} - // IsRead returns true if this is a read operation func IsRead(ctx context.Context) bool { return hasFlag(ctx, _flagRead) @@ -154,6 +146,15 @@ func NodeLabel(ctx context.Context) string { return "" } +// Hints extracts the hints. +func Hints(ctx context.Context) []*hint.Hint { + hints, ok := ctx.Value(keyHints{}).([]*hint.Hint) + if !ok { + return nil + } + return hints +} + func hasFlag(ctx context.Context, flag cFlag) bool { return getFlag(ctx)&flag != 0 } diff --git a/pkg/runtime/function/function_test.go b/pkg/runtime/function/function_test.go index 64ff195da..b17326a98 100644 --- a/pkg/runtime/function/function_test.go +++ b/pkg/runtime/function/function_test.go @@ -104,7 +104,7 @@ func BenchmarkEval(b *testing.B) { } func mustGetMathAtom() *ast.MathExpressionAtom { - stmt, err := ast.Parse("select * from t where a = 1 + if(?,1,0)") + _, stmt, err := ast.Parse("select * from t where a = 1 + if(?,1,0)") if err != nil { panic(err.Error()) } diff --git a/pkg/runtime/function/vm.go b/pkg/runtime/function/vm.go index 847a401e4..7fecaf34d 100644 --- a/pkg/runtime/function/vm.go +++ b/pkg/runtime/function/vm.go @@ -52,9 +52,7 @@ var scripts embed.FS var _decimalRegex = regexp.MustCompile(`^(?P[+\-])?(?P[0-9])+(?P\.[0-9]+)$`) -var ( - freeList = make(chan *VM, 16) -) +var freeList = make(chan *VM, 16) const FuncUnary = "__unary" diff --git a/pkg/runtime/function/vm_test.go b/pkg/runtime/function/vm_test.go index 3b41e65e4..ff9388e9f 100644 --- a/pkg/runtime/function/vm_test.go +++ b/pkg/runtime/function/vm_test.go @@ -58,7 +58,6 @@ func TestScripts_Time(t *testing.T) { v, err = testFunc("DATE_FORMAT", now, "%Y-%m-%d") assert.NoError(t, err) t.Log("DATE_FORMAT:", v) - } func TestTime_Month(t *testing.T) { diff --git a/pkg/runtime/namespace/namespace.go b/pkg/runtime/namespace/namespace.go index fe0b08aaa..ceca48015 100644 --- a/pkg/runtime/namespace/namespace.go +++ b/pkg/runtime/namespace/namespace.go @@ -76,8 +76,7 @@ type ( name string // the name of Namespace - rule atomic.Value // *rule.Rule - optimizer proto.Optimizer + rule atomic.Value // *rule.Rule // datasource map, eg: employee_0001 -> [mysql-a,mysql-b,mysql-c], ... employee_0007 -> [mysql-x,mysql-y,mysql-z] dss atomic.Value // map[string][]proto.DB @@ -91,12 +90,11 @@ type ( ) // New creates a Namespace. -func New(name string, optimizer proto.Optimizer, commands ...Command) *Namespace { +func New(name string, commands ...Command) *Namespace { ns := &Namespace{ - name: name, - optimizer: optimizer, - cmds: make(chan Command, 1), - done: make(chan struct{}), + name: name, + cmds: make(chan Command, 1), + done: make(chan struct{}), } ns.dss.Store(make(map[string][]proto.DB)) // init empty map ns.rule.Store(&rule.Rule{}) // init empty rule @@ -153,7 +151,6 @@ func (ns *Namespace) DB(ctx context.Context, group string) proto.DB { for _, db := range exist { wrList = append(wrList, int(db.Weight().R)) } - } else if rcontext.IsWrite(ctx) { for _, db := range exist { wrList = append(wrList, int(db.Weight().W)) @@ -166,9 +163,54 @@ func (ns *Namespace) DB(ctx context.Context, group string) proto.DB { return exist[target] } -// Optimizer returns the optimizer. -func (ns *Namespace) Optimizer() proto.Optimizer { - return ns.optimizer +// DBMaster returns a master DB, returns nil if nothing selected. +func (ns *Namespace) DBMaster(_ context.Context, group string) proto.DB { + // use weight manager to select datasource + dss := ns.dss.Load().(map[string][]proto.DB) + exist, ok := dss[group] + if !ok { + return nil + } + // master weight w>0 && r>0 + for _, db := range exist { + if db.Weight().W > 0 && db.Weight().R > 0 { + return db + } + } + return nil +} + +// DBSlave returns a slave DB, returns nil if nothing selected. +func (ns *Namespace) DBSlave(_ context.Context, group string) proto.DB { + // use weight manager to select datasource + dss := ns.dss.Load().(map[string][]proto.DB) + exist, ok := dss[group] + if !ok { + return nil + } + var ( + target = 0 + wrList = make([]int, 0, len(exist)) + readDBList = make([]proto.DB, 0, len(exist)) + ) + // slave weight w==0 && r>=0 + for _, db := range exist { + if db.Weight().W != 0 { + continue + } + // r==0 has high priority + if db.Weight().R == 0 { + return db + } + if db.Weight().R > 0 { + wrList = append(wrList, int(db.Weight().R)) + readDBList = append(readDBList, db) + } + } + if len(wrList) != 0 { + target = selector.NewWeightRandomSelector(wrList).GetDataSourceNo() + } + return readDBList[target] } // Rule returns the sharding rule. diff --git a/pkg/runtime/namespace/namespace_test.go b/pkg/runtime/namespace/namespace_test.go index d579c5ecd..5fb72f347 100644 --- a/pkg/runtime/namespace/namespace_test.go +++ b/pkg/runtime/namespace/namespace_test.go @@ -44,11 +44,7 @@ func TestRegister(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - opt := testdata.NewMockOptimizer(ctrl) - - const ( - name = "employees" - ) + const name = "employees" getDB := func(i int) proto.DB { db := testdata.NewMockDB(ctrl) @@ -58,7 +54,7 @@ func TestRegister(t *testing.T) { return db } - err := Register(New(name, opt, UpsertDB(getGroup(0), getDB(1)))) + err := Register(New(name, UpsertDB(getGroup(0), getDB(1)))) assert.NoError(t, err, "should register namespace ok") defer func() { @@ -89,8 +85,6 @@ func TestGetDBByWeight(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - opt := testdata.NewMockOptimizer(ctrl) - const ( name = "account" ) @@ -104,7 +98,7 @@ func TestGetDBByWeight(t *testing.T) { } // when doing read operation, db 3 is the max // when doing write operation, db 2 is the max - err := Register(New(name, opt, + err := Register(New(name, UpsertDB(getGroup(0), getDB(1, 9, 1)), UpsertDB(getGroup(0), getDB(2, 10, 5)), UpsertDB(getGroup(0), getDB(3, 3, 10)), diff --git a/pkg/runtime/optimize/dal/show_columns.go b/pkg/runtime/optimize/dal/show_columns.go new file mode 100644 index 000000000..ec39d30f2 --- /dev/null +++ b/pkg/runtime/optimize/dal/show_columns.go @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dal + +import ( + "context" +) + +import ( + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/pkg/runtime/optimize" + "github.com/arana-db/arana/pkg/runtime/plan/dal" +) + +func init() { + optimize.Register(ast.SQLTypeShowColumns, optimizeShowColumns) +} + +func optimizeShowColumns(_ context.Context, o *optimize.Optimizer) (proto.Plan, error) { + stmt := o.Stmt.(*ast.ShowColumns) + + vts := o.Rule.VTables() + vtName := []string(stmt.TableName)[0] + ret := &dal.ShowColumnsPlan{Stmt: stmt} + ret.BindArgs(o.Args) + + if vTable, ok := vts[vtName]; ok { + _, tblName, _ := vTable.Topology().Smallest() + ret.Table = tblName + } + + return ret, nil +} diff --git a/pkg/runtime/optimize/dal/show_create.go b/pkg/runtime/optimize/dal/show_create.go new file mode 100644 index 000000000..c926187a7 --- /dev/null +++ b/pkg/runtime/optimize/dal/show_create.go @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dal + +import ( + "context" +) + +import ( + "github.com/pkg/errors" +) + +import ( + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/pkg/runtime/optimize" + "github.com/arana-db/arana/pkg/runtime/plan/dal" +) + +func init() { + optimize.Register(ast.SQLTypeShowCreate, optimizeShowCreate) +} + +func optimizeShowCreate(_ context.Context, o *optimize.Optimizer) (proto.Plan, error) { + stmt := o.Stmt.(*ast.ShowCreate) + + if stmt.Type() != ast.ShowCreateTypeTable { + return nil, errors.Errorf("not support SHOW CREATE %s", stmt.Type().String()) + } + + var ( + ret = dal.NewShowCreatePlan(stmt) + table = stmt.Target() + ) + ret.BindArgs(o.Args) + + if vt, ok := o.Rule.VTable(table); ok { + // sharding + topology := vt.Topology() + if d, t, ok := topology.Render(0, 0); ok { + ret.Database = d + ret.Table = t + } else { + return nil, errors.Errorf("failed to render table:%s ", table) + } + } else { + ret.Table = table + } + + return ret, nil +} diff --git a/pkg/runtime/optimize/dal/show_databases.go b/pkg/runtime/optimize/dal/show_databases.go new file mode 100644 index 000000000..cc59628da --- /dev/null +++ b/pkg/runtime/optimize/dal/show_databases.go @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dal + +import ( + "context" +) + +import ( + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/pkg/runtime/optimize" + "github.com/arana-db/arana/pkg/runtime/plan/dal" +) + +func init() { + optimize.Register(ast.SQLTypeShowDatabases, optimizeShowDatabases) +} + +func optimizeShowDatabases(_ context.Context, o *optimize.Optimizer) (proto.Plan, error) { + ret := &dal.ShowDatabasesPlan{Stmt: o.Stmt.(*ast.ShowDatabases)} + ret.BindArgs(o.Args) + return ret, nil +} diff --git a/pkg/runtime/optimize/dal/show_index.go b/pkg/runtime/optimize/dal/show_index.go new file mode 100644 index 000000000..e75d93d1a --- /dev/null +++ b/pkg/runtime/optimize/dal/show_index.go @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dal + +import ( + "context" +) + +import ( + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/proto/rule" + "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/pkg/runtime/optimize" + "github.com/arana-db/arana/pkg/runtime/plan/dal" +) + +func init() { + optimize.Register(ast.SQLTypeShowIndex, optimizeShowIndex) +} + +func optimizeShowIndex(_ context.Context, o *optimize.Optimizer) (proto.Plan, error) { + stmt := o.Stmt.(*ast.ShowIndex) + + ret := &dal.ShowIndexPlan{Stmt: stmt} + ret.BindArgs(o.Args) + + vt, ok := o.Rule.VTable(stmt.TableName.Suffix()) + if !ok { + return ret, nil + } + + shards := rule.DatabaseTables{} + + topology := vt.Topology() + if d, t, ok := topology.Render(0, 0); ok { + shards[d] = append(shards[d], t) + } + ret.Shards = shards + return ret, nil +} diff --git a/pkg/runtime/optimize/dal/show_open_tables.go b/pkg/runtime/optimize/dal/show_open_tables.go new file mode 100644 index 000000000..e60e1c226 --- /dev/null +++ b/pkg/runtime/optimize/dal/show_open_tables.go @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dal + +import ( + "context" +) + +import ( + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/runtime/ast" + rcontext "github.com/arana-db/arana/pkg/runtime/context" + "github.com/arana-db/arana/pkg/runtime/namespace" + "github.com/arana-db/arana/pkg/runtime/optimize" + "github.com/arana-db/arana/pkg/runtime/plan/dal" + "github.com/arana-db/arana/pkg/runtime/plan/dml" + "github.com/arana-db/arana/pkg/security" + "github.com/arana-db/arana/pkg/transformer" +) + +func init() { + optimize.Register(ast.SQLTypeShowOpenTables, optimizeShowOpenTables) +} + +func optimizeShowOpenTables(ctx context.Context, o *optimize.Optimizer) (proto.Plan, error) { + var invertedIndex map[string]string + for logicalTable, v := range o.Rule.VTables() { + t := v.Topology() + t.Each(func(x, y int) bool { + if _, phyTable, ok := t.Render(x, y); ok { + if invertedIndex == nil { + invertedIndex = make(map[string]string) + } + invertedIndex[phyTable] = logicalTable + } + return true + }) + } + + stmt := o.Stmt.(*ast.ShowOpenTables) + + clusters := security.DefaultTenantManager().GetClusters(rcontext.Tenant(ctx)) + plans := make([]proto.Plan, 0, len(clusters)) + for _, cluster := range clusters { + ns := namespace.Load(cluster) + // 配置里原子库 都需要执行一次 + groups := ns.DBGroups() + for i := 0; i < len(groups); i++ { + ret := dal.NewShowOpenTablesPlan(stmt) + ret.BindArgs(o.Args) + ret.SetInvertedShards(invertedIndex) + ret.SetDatabase(groups[i]) + plans = append(plans, ret) + } + } + + unionPlan := &dml.UnionPlan{ + Plans: plans, + } + + aggregate := &dml.AggregatePlan{ + Plan: unionPlan, + Combiner: transformer.NewCombinerManager(), + AggrLoader: transformer.LoadAggrs(nil), + } + + return aggregate, nil +} diff --git a/pkg/runtime/optimize/dal/show_tables.go b/pkg/runtime/optimize/dal/show_tables.go new file mode 100644 index 000000000..778a7f784 --- /dev/null +++ b/pkg/runtime/optimize/dal/show_tables.go @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dal + +import ( + "context" +) + +import ( + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/pkg/runtime/optimize" + "github.com/arana-db/arana/pkg/runtime/plan/dal" +) + +func init() { + optimize.Register(ast.SQLTypeShowTables, optimizeShowTables) +} + +func optimizeShowTables(_ context.Context, o *optimize.Optimizer) (proto.Plan, error) { + stmt := o.Stmt.(*ast.ShowTables) + var invertedIndex map[string]string + for logicalTable, v := range o.Rule.VTables() { + t := v.Topology() + t.Each(func(x, y int) bool { + if _, phyTable, ok := t.Render(x, y); ok { + if invertedIndex == nil { + invertedIndex = make(map[string]string) + } + invertedIndex[phyTable] = logicalTable + } + return true + }) + } + + ret := dal.NewShowTablesPlan(stmt) + ret.BindArgs(o.Args) + ret.SetInvertedShards(invertedIndex) + return ret, nil +} diff --git a/pkg/runtime/optimize/dal/show_topology.go b/pkg/runtime/optimize/dal/show_topology.go new file mode 100644 index 000000000..a03d0d5d6 --- /dev/null +++ b/pkg/runtime/optimize/dal/show_topology.go @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dal + +import ( + "context" +) + +import ( + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/pkg/runtime/optimize" + "github.com/arana-db/arana/pkg/runtime/plan/dal" +) + +func init() { + optimize.Register(ast.SQLTypeShowTopology, optimizeShowTopology) +} + +func optimizeShowTopology(_ context.Context, o *optimize.Optimizer) (proto.Plan, error) { + rule := o.Rule + stmt := o.Stmt.(*ast.ShowTopology) + ret := dal.NewShowTopologyPlan(stmt) + ret.BindArgs(o.Args) + ret.SetRule(rule) + return ret, nil +} diff --git a/pkg/runtime/optimize/dal/show_variables.go b/pkg/runtime/optimize/dal/show_variables.go new file mode 100644 index 000000000..5cd6aedbf --- /dev/null +++ b/pkg/runtime/optimize/dal/show_variables.go @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dal + +import ( + "context" +) + +import ( + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/pkg/runtime/optimize" + "github.com/arana-db/arana/pkg/runtime/plan/dal" +) + +func init() { + optimize.Register(ast.SQLTypeShowVariables, optimizeShowVariables) +} + +func optimizeShowVariables(_ context.Context, o *optimize.Optimizer) (proto.Plan, error) { + ret := dal.NewShowVariablesPlan(o.Stmt.(*ast.ShowVariables)) + ret.BindArgs(o.Args) + return ret, nil +} diff --git a/pkg/runtime/optimize/ddl/alter_table.go b/pkg/runtime/optimize/ddl/alter_table.go new file mode 100644 index 000000000..9891e9fe7 --- /dev/null +++ b/pkg/runtime/optimize/ddl/alter_table.go @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ddl + +import ( + "context" +) + +import ( + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/proto/rule" + "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/pkg/runtime/optimize" + "github.com/arana-db/arana/pkg/runtime/plan/ddl" +) + +func init() { + optimize.Register(ast.SQLTypeAlterTable, optimizeAlterTable) +} + +func optimizeAlterTable(_ context.Context, o *optimize.Optimizer) (proto.Plan, error) { + var ( + stmt = o.Stmt.(*ast.AlterTableStatement) + ret = ddl.NewAlterTablePlan(stmt) + table = stmt.Table + vt *rule.VTable + ok bool + ) + ret.BindArgs(o.Args) + + // non-sharding update + if vt, ok = o.Rule.VTable(table.Suffix()); !ok { + return ret, nil + } + + //TODO alter table table or column to new name , should update sharding info + + // exit if full-scan is disabled + if !vt.AllowFullScan() { + return nil, optimize.ErrDenyFullScan + } + + // sharding + ret.Shards = vt.Topology().Enumerate() + return ret, nil +} diff --git a/pkg/runtime/optimize/ddl/create_index.go b/pkg/runtime/optimize/ddl/create_index.go new file mode 100644 index 000000000..f29a17db8 --- /dev/null +++ b/pkg/runtime/optimize/ddl/create_index.go @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ddl + +import ( + "context" +) + +import ( + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/pkg/runtime/optimize" + "github.com/arana-db/arana/pkg/runtime/plan/ddl" +) + +func init() { + optimize.Register(ast.SQLTypeCreateIndex, optimizeCreateIndex) +} + +func optimizeCreateIndex(_ context.Context, o *optimize.Optimizer) (proto.Plan, error) { + stmt := o.Stmt.(*ast.CreateIndexStatement) + ret := ddl.NewCreateIndexPlan(stmt) + vt, ok := o.Rule.VTable(stmt.Table.Suffix()) + + // table shard + if !ok { + return ret, nil + } + + ret.SetShard(vt.Topology().Enumerate()) + return ret, nil +} diff --git a/pkg/runtime/optimize/ddl/drop_index.go b/pkg/runtime/optimize/ddl/drop_index.go new file mode 100644 index 000000000..e5dd7d0ce --- /dev/null +++ b/pkg/runtime/optimize/ddl/drop_index.go @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ddl + +import ( + "context" +) + +import ( + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/pkg/runtime/optimize" + "github.com/arana-db/arana/pkg/runtime/plan" + "github.com/arana-db/arana/pkg/runtime/plan/ddl" +) + +func init() { + optimize.Register(ast.SQLTypeDropIndex, optimizeDropIndex) +} + +func optimizeDropIndex(_ context.Context, o *optimize.Optimizer) (proto.Plan, error) { + stmt := o.Stmt.(*ast.DropIndexStatement) + //table shard + + shard, err := o.ComputeShards(stmt.Table, nil, o.Args) + if err != nil { + return nil, err + } + if len(shard) == 0 { + return plan.Transparent(stmt, o.Args), nil + } + + shardPlan := ddl.NewDropIndexPlan(stmt) + shardPlan.SetShard(shard) + shardPlan.BindArgs(o.Args) + return shardPlan, nil +} diff --git a/pkg/runtime/optimize/ddl/drop_table.go b/pkg/runtime/optimize/ddl/drop_table.go new file mode 100644 index 000000000..1d179b8d4 --- /dev/null +++ b/pkg/runtime/optimize/ddl/drop_table.go @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ddl + +import ( + "context" +) + +import ( + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/proto/rule" + "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/pkg/runtime/optimize" + "github.com/arana-db/arana/pkg/runtime/plan" + "github.com/arana-db/arana/pkg/runtime/plan/ddl" + "github.com/arana-db/arana/pkg/runtime/plan/dml" +) + +func init() { + optimize.Register(ast.SQLTypeDropTable, optimizeDropTable) +} + +func optimizeDropTable(_ context.Context, o *optimize.Optimizer) (proto.Plan, error) { + stmt := o.Stmt.(*ast.DropTableStatement) + //table shard + var shards []rule.DatabaseTables + //tables not shard + noShardStmt := ast.NewDropTableStatement() + for _, table := range stmt.Tables { + shard, err := o.ComputeShards(*table, nil, o.Args) + if err != nil { + return nil, err + } + if shard == nil { + noShardStmt.Tables = append(noShardStmt.Tables, table) + continue + } + shards = append(shards, shard) + } + + shardPlan := ddl.NewDropTablePlan(stmt) + shardPlan.BindArgs(o.Args) + shardPlan.SetShards(shards) + + if len(noShardStmt.Tables) == 0 { + return shardPlan, nil + } + + noShardPlan := plan.Transparent(noShardStmt, o.Args) + + return &dml.UnionPlan{ + Plans: []proto.Plan{ + noShardPlan, shardPlan, + }, + }, nil +} diff --git a/pkg/runtime/optimize/ddl/drop_trigger.go b/pkg/runtime/optimize/ddl/drop_trigger.go new file mode 100644 index 000000000..74e1d1de4 --- /dev/null +++ b/pkg/runtime/optimize/ddl/drop_trigger.go @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ddl + +import ( + "context" +) + +import ( + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/proto/rule" + "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/pkg/runtime/optimize" + "github.com/arana-db/arana/pkg/runtime/plan/ddl" +) + +func init() { + optimize.Register(ast.SQLTypeDropTrigger, optimizeTrigger) +} + +func optimizeTrigger(_ context.Context, o *optimize.Optimizer) (proto.Plan, error) { + shards := rule.DatabaseTables{} + for _, table := range o.Rule.VTables() { + shards = table.Topology().Enumerate() + break + } + + ret := &ddl.DropTriggerPlan{Stmt: o.Stmt.(*ast.DropTriggerStatement), Shards: shards} + ret.BindArgs(o.Args) + return ret, nil +} diff --git a/pkg/runtime/optimize/ddl/truncate.go b/pkg/runtime/optimize/ddl/truncate.go new file mode 100644 index 000000000..6e3cbbc77 --- /dev/null +++ b/pkg/runtime/optimize/ddl/truncate.go @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ddl + +import ( + "context" +) + +import ( + "github.com/pkg/errors" +) + +import ( + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/pkg/runtime/optimize" + "github.com/arana-db/arana/pkg/runtime/plan" + "github.com/arana-db/arana/pkg/runtime/plan/ddl" +) + +func init() { + optimize.Register(ast.SQLTypeTruncate, optimizeTruncate) +} + +func optimizeTruncate(_ context.Context, o *optimize.Optimizer) (proto.Plan, error) { + stmt := o.Stmt.(*ast.TruncateStatement) + shards, err := o.ComputeShards(stmt.Table, nil, o.Args) + if err != nil { + return nil, errors.Wrap(err, "failed to optimize TRUNCATE statement") + } + + if shards == nil { + return plan.Transparent(stmt, o.Args), nil + } + + ret := ddl.NewTruncatePlan(stmt) + ret.BindArgs(o.Args) + ret.SetShards(shards) + + return ret, nil +} diff --git a/pkg/runtime/optimize/dml/delete.go b/pkg/runtime/optimize/dml/delete.go new file mode 100644 index 000000000..c9ee9c2ce --- /dev/null +++ b/pkg/runtime/optimize/dml/delete.go @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dml + +import ( + "context" +) + +import ( + "github.com/pkg/errors" +) + +import ( + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/pkg/runtime/optimize" + "github.com/arana-db/arana/pkg/runtime/plan" + "github.com/arana-db/arana/pkg/runtime/plan/dml" +) + +func init() { + optimize.Register(ast.SQLTypeDelete, optimizeDelete) +} + +func optimizeDelete(ctx context.Context, o *optimize.Optimizer) (proto.Plan, error) { + var ( + stmt = o.Stmt.(*ast.DeleteStatement) + ) + + shards, err := o.ComputeShards(stmt.Table, stmt.Where, o.Args) + if err != nil { + return nil, errors.Wrap(err, "failed to optimize DELETE statement") + } + + // TODO: delete from a child sharding-table directly + + if shards == nil { + return plan.Transparent(stmt, o.Args), nil + } + + ret := dml.NewSimpleDeletePlan(stmt) + ret.BindArgs(o.Args) + ret.SetShards(shards) + + return ret, nil +} diff --git a/pkg/runtime/optimize/dml/insert.go b/pkg/runtime/optimize/dml/insert.go new file mode 100644 index 000000000..3684dcd67 --- /dev/null +++ b/pkg/runtime/optimize/dml/insert.go @@ -0,0 +1,215 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dml + +import ( + "context" +) + +import ( + "github.com/pkg/errors" +) + +import ( + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/proto/rule" + "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/pkg/runtime/cmp" + rcontext "github.com/arana-db/arana/pkg/runtime/context" + "github.com/arana-db/arana/pkg/runtime/optimize" + "github.com/arana-db/arana/pkg/runtime/plan/dml" +) + +func init() { + optimize.Register(ast.SQLTypeInsert, optimizeInsert) + optimize.Register(ast.SQLTypeInsertSelect, optimizeInsertSelect) +} + +func optimizeInsert(ctx context.Context, o *optimize.Optimizer) (proto.Plan, error) { + ret := dml.NewSimpleInsertPlan() + ret.BindArgs(o.Args) + + var ( + stmt = o.Stmt.(*ast.InsertStatement) + vt *rule.VTable + ok bool + ) + + if vt, ok = o.Rule.VTable(stmt.Table().Suffix()); !ok { // insert into non-sharding table + ret.Put("", stmt) + return ret, nil + } + + // TODO: handle multiple shard keys. + + bingo := -1 + // check existing shard columns + for i, col := range stmt.Columns() { + if _, _, ok = vt.GetShardMetadata(col); ok { + bingo = i + break + } + } + + if bingo < 0 { + return nil, errors.Wrap(optimize.ErrNoShardKeyFound, "failed to insert") + } + + //check on duplicated key update + for _, upd := range stmt.DuplicatedUpdates() { + if upd.Column.Suffix() == stmt.Columns()[bingo] { + return nil, errors.New("do not support update sharding key") + } + } + + var ( + sharder = (*optimize.Sharder)(o.Rule) + left = ast.ColumnNameExpressionAtom(make([]string, 1)) + filter = &ast.PredicateExpressionNode{ + P: &ast.BinaryComparisonPredicateNode{ + Left: &ast.AtomPredicateNode{ + A: left, + }, + Op: cmp.Ceq, + }, + } + slots = make(map[string]map[string][]int) // (db,table,valuesIndex) + ) + + // reset filter + resetFilter := func(column string, value ast.ExpressionNode) { + left[0] = column + filter.P.(*ast.BinaryComparisonPredicateNode).Right = value.(*ast.PredicateExpressionNode).P + } + + for i, values := range stmt.Values() { + value := values[bingo] + resetFilter(stmt.Columns()[bingo], value) + + shards, _, err := sharder.Shard(stmt.Table(), filter, o.Args...) + + if err != nil { + return nil, errors.WithStack(err) + } + + if shards.Len() != 1 { + return nil, errors.Wrap(optimize.ErrNoShardKeyFound, "failed to insert") + } + + var ( + db string + table string + ) + + for k, v := range shards { + db = k + table = v[0] + break + } + + if _, ok = slots[db]; !ok { + slots[db] = make(map[string][]int) + } + slots[db][table] = append(slots[db][table], i) + } + + _, tb0, _ := vt.Topology().Smallest() + + for db, slot := range slots { + for table, indexes := range slot { + // clone insert stmt without values + newborn := ast.NewInsertStatement(ast.TableName{table}, stmt.Columns()) + newborn.SetFlag(stmt.Flag()) + newborn.SetDuplicatedUpdates(stmt.DuplicatedUpdates()) + + // collect values with same table + values := make([][]ast.ExpressionNode, 0, len(indexes)) + for _, i := range indexes { + values = append(values, stmt.Values()[i]) + } + newborn.SetValues(values) + + rewriteInsertStatement(ctx, newborn, tb0) + ret.Put(db, newborn) + } + } + + return ret, nil +} + +func optimizeInsertSelect(_ context.Context, o *optimize.Optimizer) (proto.Plan, error) { + stmt := o.Stmt.(*ast.InsertSelectStatement) + + ret := dml.NewInsertSelectPlan() + + ret.BindArgs(o.Args) + + if _, ok := o.Rule.VTable(stmt.Table().Suffix()); !ok { // insert into non-sharding table + ret.Batch[""] = stmt + return ret, nil + } + + // TODO: handle shard keys. + + return nil, errors.New("not support insert-select into sharding table") +} + +func rewriteInsertStatement(ctx context.Context, stmt *ast.InsertStatement, tb string) error { + metadatas, err := proto.LoadSchemaLoader().Load(ctx, rcontext.Schema(ctx), []string{tb}) + if err != nil { + return errors.WithStack(err) + } + metadata := metadatas[tb] + if metadata == nil || len(metadata.ColumnNames) == 0 { + return errors.Errorf("optimize: cannot get metadata of `%s`.`%s`", rcontext.Schema(ctx), tb) + } + + if len(metadata.ColumnNames) == len(stmt.Columns()) { + // User had explicitly specified every value + return nil + } + columnsMetadata := metadata.Columns + + for _, colName := range stmt.Columns() { + if columnsMetadata[colName].PrimaryKey && columnsMetadata[colName].Generated { + // User had explicitly specified auto-generated primary key column + return nil + } + } + + pkColName := "" + for name, column := range columnsMetadata { + if column.PrimaryKey && column.Generated { + pkColName = name + break + } + } + if len(pkColName) < 1 { + // There's no auto-generated primary key column + return nil + } + + // TODO rewrite columns and add distributed primary key + //stmt.SetColumns(append(stmt.Columns(), pkColName)) + // append value of distributed primary key + //newValues := stmt.Values() + //for _, newValue := range newValues { + // newValue = append(newValue, ) + //} + return nil +} diff --git a/pkg/runtime/optimize/dml/select.go b/pkg/runtime/optimize/dml/select.go new file mode 100644 index 000000000..d9cf57eb9 --- /dev/null +++ b/pkg/runtime/optimize/dml/select.go @@ -0,0 +1,415 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dml + +import ( + "context" + "strings" +) + +import ( + "github.com/pkg/errors" +) + +import ( + "github.com/arana-db/arana/pkg/dataset" + "github.com/arana-db/arana/pkg/merge/aggregator" + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/proto/rule" + "github.com/arana-db/arana/pkg/runtime/ast" + rcontext "github.com/arana-db/arana/pkg/runtime/context" + "github.com/arana-db/arana/pkg/runtime/optimize" + "github.com/arana-db/arana/pkg/runtime/plan/dml" + "github.com/arana-db/arana/pkg/transformer" + "github.com/arana-db/arana/pkg/util/log" +) + +const ( + _bypass uint32 = 1 << iota + _supported +) + +func init() { + optimize.Register(ast.SQLTypeSelect, optimizeSelect) +} + +func optimizeSelect(ctx context.Context, o *optimize.Optimizer) (proto.Plan, error) { + stmt := o.Stmt.(*ast.SelectStatement) + + // overwrite stmt limit x offset y. eg `select * from student offset 100 limit 5` will be + // `select * from student offset 0 limit 100+5` + originOffset, newLimit := overwriteLimit(stmt, &o.Args) + if stmt.HasJoin() { + return optimizeJoin(o, stmt) + } + flag := getSelectFlag(o.Rule, stmt) + if flag&_supported == 0 { + return nil, errors.Errorf("unsupported sql: %s", rcontext.SQL(ctx)) + } + + if flag&_bypass != 0 { + if len(stmt.From) > 0 { + err := rewriteSelectStatement(ctx, stmt, stmt.From[0].TableName().Suffix()) + if err != nil { + return nil, err + } + } + ret := &dml.SimpleQueryPlan{Stmt: stmt} + ret.BindArgs(o.Args) + return ret, nil + } + + var ( + shards rule.DatabaseTables + fullScan bool + err error + vt = o.Rule.MustVTable(stmt.From[0].TableName().Suffix()) + ) + + if shards, fullScan, err = (*optimize.Sharder)(o.Rule).Shard(stmt.From[0].TableName(), stmt.Where, o.Args...); err != nil { + return nil, errors.Wrap(err, "calculate shards failed") + } + + log.Debugf("compute shards: result=%s, isFullScan=%v", shards, fullScan) + + // return error if full-scan is disabled + if fullScan && !vt.AllowFullScan() { + return nil, errors.WithStack(optimize.ErrDenyFullScan) + } + + toSingle := func(db, tbl string) (proto.Plan, error) { + _, tb0, _ := vt.Topology().Smallest() + if err := rewriteSelectStatement(ctx, stmt, tb0); err != nil { + return nil, err + } + ret := &dml.SimpleQueryPlan{ + Stmt: stmt, + Database: db, + Tables: []string{tbl}, + } + ret.BindArgs(o.Args) + + return ret, nil + } + + // Go through first table if no shards matched. + // For example: + // SELECT ... FROM xxx WHERE a > 8 and a < 4 + if shards.IsEmpty() { + var ( + db0, tbl0 string + ok bool + ) + if db0, tbl0, ok = vt.Topology().Render(0, 0); !ok { + return nil, errors.Errorf("cannot compute minimal topology from '%s'", stmt.From[0].TableName().Suffix()) + } + + return toSingle(db0, tbl0) + } + + // Handle single shard + if shards.Len() == 1 { + var db, tbl string + for k, v := range shards { + db = k + tbl = v[0] + } + return toSingle(db, tbl) + } + + // Handle multiple shards + + if shards.IsFullScan() { // expand all shards if all shards matched + shards = vt.Topology().Enumerate() + } + + plans := make([]proto.Plan, 0, len(shards)) + for k, v := range shards { + next := &dml.SimpleQueryPlan{ + Database: k, + Tables: v, + Stmt: stmt, + } + next.BindArgs(o.Args) + plans = append(plans, next) + } + + if len(plans) > 0 { + _, tb, _ := vt.Topology().Smallest() + if err = rewriteSelectStatement(ctx, stmt, tb); err != nil { + return nil, errors.WithStack(err) + } + } + + var tmpPlan proto.Plan + tmpPlan = &dml.UnionPlan{ + Plans: plans, + } + + if stmt.Limit != nil { + tmpPlan = &dml.LimitPlan{ + ParentPlan: tmpPlan, + OriginOffset: originOffset, + OverwriteLimit: newLimit, + } + } + + orderByItems := optimizeOrderBy(stmt) + + if stmt.OrderBy != nil { + tmpPlan = &dml.OrderPlan{ + ParentPlan: tmpPlan, + OrderByItems: orderByItems, + } + } + + convertOrderByItems := func(origins []*ast.OrderByItem) []dataset.OrderByItem { + var result = make([]dataset.OrderByItem, 0, len(origins)) + for _, origin := range origins { + var columnName string + if cn, ok := origin.Expr.(ast.ColumnNameExpressionAtom); ok { + columnName = cn.Suffix() + } + result = append(result, dataset.OrderByItem{ + Column: columnName, + Desc: origin.Desc, + }) + } + return result + } + if stmt.GroupBy != nil { + return &dml.GroupPlan{ + Plan: tmpPlan, + AggItems: aggregator.LoadAggs(stmt.Select), + OrderByItems: convertOrderByItems(stmt.OrderBy), + }, nil + } else { + // TODO: refactor groupby/orderby/aggregate plan to a unified plan + return &dml.AggregatePlan{ + Plan: tmpPlan, + Combiner: transformer.NewCombinerManager(), + AggrLoader: transformer.LoadAggrs(stmt.Select), + }, nil + } +} + +//optimizeJoin ony support a join b in one db +func optimizeJoin(o *optimize.Optimizer, stmt *ast.SelectStatement) (proto.Plan, error) { + join := stmt.From[0].Source().(*ast.JoinNode) + + compute := func(tableSource *ast.TableSourceNode) (database, alias string, shardList []string, err error) { + table := tableSource.TableName() + if table == nil { + err = errors.New("must table, not statement or join node") + return + } + alias = tableSource.Alias() + database = table.Prefix() + + shards, err := o.ComputeShards(table, nil, o.Args) + if err != nil { + return + } + //table no shard + if shards == nil { + shardList = append(shardList, table.Suffix()) + return + } + //table shard more than one db + if len(shards) > 1 { + err = errors.New("not support more than one db") + return + } + + for k, v := range shards { + database = k + shardList = v + } + + if alias == "" { + alias = table.Suffix() + } + + return + } + + dbLeft, aliasLeft, shardLeft, err := compute(join.Left) + if err != nil { + return nil, err + } + dbRight, aliasRight, shardRight, err := compute(join.Right) + + if err != nil { + return nil, err + } + + if dbLeft != "" && dbRight != "" && dbLeft != dbRight { + return nil, errors.New("not support more than one db") + } + + joinPan := &dml.SimpleJoinPlan{ + Left: &dml.JoinTable{ + Tables: shardLeft, + Alias: aliasLeft, + }, + Join: join, + Right: &dml.JoinTable{ + Tables: shardRight, + Alias: aliasRight, + }, + Stmt: o.Stmt.(*ast.SelectStatement), + } + joinPan.BindArgs(o.Args) + + return joinPan, nil +} + +func getSelectFlag(ru *rule.Rule, stmt *ast.SelectStatement) (flag uint32) { + switch len(stmt.From) { + case 1: + from := stmt.From[0] + tn := from.TableName() + + if tn == nil { // only FROM table supported now + return + } + + flag |= _supported + + if len(tn) > 1 { + switch strings.ToLower(tn.Prefix()) { + case "mysql", "information_schema": + flag |= _bypass + return + } + } + if !ru.Has(tn.Suffix()) { + flag |= _bypass + } + case 0: + flag |= _bypass + flag |= _supported + } + return +} + +func optimizeOrderBy(stmt *ast.SelectStatement) []dataset.OrderByItem { + if stmt == nil || stmt.OrderBy == nil { + return nil + } + result := make([]dataset.OrderByItem, 0, len(stmt.OrderBy)) + for _, node := range stmt.OrderBy { + column, _ := node.Expr.(ast.ColumnNameExpressionAtom) + item := dataset.OrderByItem{ + Column: column[0], + Desc: node.Desc, + } + result = append(result, item) + } + return result +} + +func overwriteLimit(stmt *ast.SelectStatement, args *[]interface{}) (originOffset, overwriteLimit int64) { + if stmt == nil || stmt.Limit == nil { + return 0, 0 + } + + offset := stmt.Limit.Offset() + limit := stmt.Limit.Limit() + + // SELECT * FROM student where uid = ? limit ? offset ? + var offsetIndex int64 + var limitIndex int64 + + if stmt.Limit.IsOffsetVar() { + offsetIndex = offset + offset = (*args)[offsetIndex].(int64) + + if !stmt.Limit.IsLimitVar() { + limit = stmt.Limit.Limit() + *args = append(*args, limit) + limitIndex = int64(len(*args) - 1) + } + } + originOffset = offset + + if stmt.Limit.IsLimitVar() { + limitIndex = limit + limit = (*args)[limitIndex].(int64) + + if !stmt.Limit.IsOffsetVar() { + *args = append(*args, int64(0)) + offsetIndex = int64(len(*args) - 1) + } + } + + if stmt.Limit.IsLimitVar() || stmt.Limit.IsOffsetVar() { + if !stmt.Limit.IsLimitVar() { + stmt.Limit.SetLimitVar() + stmt.Limit.SetLimit(limitIndex) + } + if !stmt.Limit.IsOffsetVar() { + stmt.Limit.SetOffsetVar() + stmt.Limit.SetOffset(offsetIndex) + } + + newLimitVar := limit + offset + overwriteLimit = newLimitVar + (*args)[limitIndex] = newLimitVar + (*args)[offsetIndex] = int64(0) + return + } + + stmt.Limit.SetOffset(0) + stmt.Limit.SetLimit(offset + limit) + overwriteLimit = offset + limit + return +} + +func rewriteSelectStatement(ctx context.Context, stmt *ast.SelectStatement, tb string) error { + // todo db 计算逻辑&tb shard 的计算逻辑 + var starExpand = false + if len(stmt.Select) == 1 { + if _, ok := stmt.Select[0].(*ast.SelectElementAll); ok { + starExpand = true + } + } + + if !starExpand { + return nil + } + + if len(tb) < 1 { + tb = stmt.From[0].TableName().Suffix() + } + metadatas, err := proto.LoadSchemaLoader().Load(ctx, rcontext.Schema(ctx), []string{tb}) + if err != nil { + return errors.WithStack(err) + } + metadata := metadatas[tb] + if metadata == nil || len(metadata.ColumnNames) == 0 { + return errors.Errorf("optimize: cannot get metadata of `%s`.`%s`", rcontext.Schema(ctx), tb) + } + selectElements := make([]ast.SelectElement, len(metadata.Columns)) + for i, column := range metadata.ColumnNames { + selectElements[i] = ast.NewSelectElementColumn([]string{column}, "") + } + stmt.Select = selectElements + + return nil +} diff --git a/pkg/runtime/optimize/dml/update.go b/pkg/runtime/optimize/dml/update.go new file mode 100644 index 000000000..53bdf4632 --- /dev/null +++ b/pkg/runtime/optimize/dml/update.go @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dml + +import ( + "context" +) + +import ( + "github.com/pkg/errors" +) + +import ( + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/proto/rule" + "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/pkg/runtime/optimize" + "github.com/arana-db/arana/pkg/runtime/plan" + "github.com/arana-db/arana/pkg/runtime/plan/dml" +) + +func init() { + optimize.Register(ast.SQLTypeUpdate, optimizeUpdate) +} + +func optimizeUpdate(_ context.Context, o *optimize.Optimizer) (proto.Plan, error) { + var ( + stmt = o.Stmt.(*ast.UpdateStatement) + table = stmt.Table + vt *rule.VTable + ok bool + ) + + // non-sharding update + if vt, ok = o.Rule.VTable(table.Suffix()); !ok { + ret := dml.NewUpdatePlan(stmt) + ret.BindArgs(o.Args) + return ret, nil + } + + //check update sharding key + for _, element := range stmt.Updated { + if _, _, ok := vt.GetShardMetadata(element.Column.Suffix()); ok { + return nil, errors.New("do not support update sharding key") + } + } + + var ( + shards rule.DatabaseTables + fullScan = true + err error + ) + + // compute shards + if where := stmt.Where; where != nil { + sharder := (*optimize.Sharder)(o.Rule) + if shards, fullScan, err = sharder.Shard(table, where, o.Args...); err != nil { + return nil, errors.Wrap(err, "failed to update") + } + } + + // exit if full-scan is disabled + if fullScan && !vt.AllowFullScan() { + return nil, optimize.ErrDenyFullScan + } + + // must be empty shards (eg: update xxx set ... where 1 = 2 and uid = 1) + if shards.IsEmpty() { + return plan.AlwaysEmptyExecPlan{}, nil + } + + // compute all sharding tables + if shards.IsFullScan() { + // compute all tables + shards = vt.Topology().Enumerate() + } + + ret := dml.NewUpdatePlan(stmt) + ret.BindArgs(o.Args) + ret.SetShards(shards) + + return ret, nil +} diff --git a/pkg/runtime/optimize/optimizer.go b/pkg/runtime/optimize/optimizer.go index 45c341368..a23ed7f58 100644 --- a/pkg/runtime/optimize/optimizer.go +++ b/pkg/runtime/optimize/optimizer.go @@ -19,700 +19,106 @@ package optimize import ( "context" - stdErrors "errors" - "strings" + "errors" ) import ( "github.com/arana-db/parser/ast" - "github.com/pkg/errors" + perrors "github.com/pkg/errors" + + "go.opentelemetry.io/otel" ) import ( "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/proto/hint" "github.com/arana-db/arana/pkg/proto/rule" - "github.com/arana-db/arana/pkg/proto/schema_manager" rast "github.com/arana-db/arana/pkg/runtime/ast" - "github.com/arana-db/arana/pkg/runtime/cmp" rcontext "github.com/arana-db/arana/pkg/runtime/context" - "github.com/arana-db/arana/pkg/runtime/plan" - "github.com/arana-db/arana/pkg/transformer" "github.com/arana-db/arana/pkg/util/log" ) -var _ proto.Optimizer = (*optimizer)(nil) +var _ proto.Optimizer = (*Optimizer)(nil) + +var Tracer = otel.Tracer("optimize") // errors group var ( - errNoRuleFound = stdErrors.New("no rule found") - errDenyFullScan = stdErrors.New("the full-scan query is not allowed") - errNoShardKeyFound = stdErrors.New("no shard key found") + ErrNoRuleFound = errors.New("optimize: no rule found") + ErrDenyFullScan = errors.New("optimize: the full-scan query is not allowed") + ErrNoShardKeyFound = errors.New("optimize: no shard key found") ) // IsNoShardKeyFoundErr returns true if target error is caused by NO-SHARD-KEY-FOUND func IsNoShardKeyFoundErr(err error) bool { - return errors.Is(err, errNoShardKeyFound) + return perrors.Is(err, ErrNoShardKeyFound) } // IsNoRuleFoundErr returns true if target error is caused by NO-RULE-FOUND. func IsNoRuleFoundErr(err error) bool { - return errors.Is(err, errNoRuleFound) + return perrors.Is(err, ErrNoRuleFound) } // IsDenyFullScanErr returns true if target error is caused by DENY-FULL-SCAN. func IsDenyFullScanErr(err error) bool { - return errors.Is(err, errDenyFullScan) -} - -func GetOptimizer() proto.Optimizer { - return optimizer{ - schemaLoader: &schema_manager.SimpleSchemaLoader{}, - } -} - -type optimizer struct { - schemaLoader proto.SchemaLoader -} - -func (o *optimizer) SetSchemaLoader(schemaLoader proto.SchemaLoader) { - o.schemaLoader = schemaLoader -} - -func (o *optimizer) SchemaLoader() proto.SchemaLoader { - return o.schemaLoader + return perrors.Is(err, ErrDenyFullScan) } -func (o optimizer) Optimize(ctx context.Context, conn proto.VConn, stmt ast.StmtNode, args ...interface{}) (plan proto.Plan, err error) { - defer func() { - if rec := recover(); rec != nil { - err = errors.Errorf("cannot analyze sql %s", rcontext.SQL(ctx)) - log.Errorf("optimize panic: sql=%s, rec=%v", rcontext.SQL(ctx), rec) - } - }() - - var rstmt rast.Statement - if rstmt, err = rast.FromStmtNode(stmt); err != nil { - return nil, errors.Wrap(err, "optimize failed") - } - return o.doOptimize(ctx, conn, rstmt, args...) -} - -func (o optimizer) doOptimize(ctx context.Context, conn proto.VConn, stmt rast.Statement, args ...interface{}) (proto.Plan, error) { - switch t := stmt.(type) { - case *rast.ShowDatabases: - return o.optimizeShowDatabases(ctx, t, args) - case *rast.SelectStatement: - return o.optimizeSelect(ctx, conn, t, args) - case *rast.InsertStatement: - return o.optimizeInsert(ctx, conn, t, args) - case *rast.DeleteStatement: - return o.optimizeDelete(ctx, t, args) - case *rast.UpdateStatement: - return o.optimizeUpdate(ctx, conn, t, args) - case *rast.ShowTables: - return o.optimizeShowTables(ctx, t, args) - case *rast.TruncateStatement: - return o.optimizeTruncate(ctx, t, args) - case *rast.DropTableStatement: - return o.optimizeDropTable(ctx, t, args) - case *rast.ShowVariables: - return o.optimizeShowVariables(ctx, t, args) - case *rast.DescribeStatement: - return o.optimizeDescribeStatement(ctx, t, args) - case *rast.AlterTableStatement: - return o.optimizeAlterTable(ctx, t, args) - } - - //TODO implement all statements - panic("implement me") -} - -const ( - _bypass uint32 = 1 << iota - _supported +var ( + _handlers = make(map[rast.SQLType]Processor) ) -func (o optimizer) optimizeAlterTable(ctx context.Context, stmt *rast.AlterTableStatement, args []interface{}) (proto.Plan, error) { - var ( - ret = plan.NewAlterTablePlan(stmt) - ru = rcontext.Rule(ctx) - table = stmt.Table - vt *rule.VTable - ok bool - ) - ret.BindArgs(args) - - // non-sharding update - if vt, ok = ru.VTable(table.Suffix()); !ok { - return ret, nil - } - - //TODO alter table table or column to new name , should update sharding info - - // exit if full-scan is disabled - if !vt.AllowFullScan() { - return nil, errDenyFullScan - } - - // sharding - shards := rule.DatabaseTables{} - topology := vt.Topology() - topology.Each(func(dbIdx, tbIdx int) bool { - if d, t, ok := topology.Render(dbIdx, tbIdx); ok { - shards[d] = append(shards[d], t) - } - return true - }) - ret.Shards = shards - return ret, nil -} - -func (o optimizer) optimizeDropTable(ctx context.Context, stmt *rast.DropTableStatement, args []interface{}) (proto.Plan, error) { - ru := rcontext.Rule(ctx) - //table shard - var shards []rule.DatabaseTables - //tables not shard - noShardStmt := rast.NewDropTableStatement() - for _, table := range stmt.Tables { - shard, err := o.computeShards(ru, *table, nil, args) - if err != nil { - return nil, err - } - if shard == nil { - noShardStmt.Tables = append(noShardStmt.Tables, table) - continue - } - shards = append(shards, shard) - } - - shardPlan := plan.NewDropTablePlan(stmt) - shardPlan.BindArgs(args) - shardPlan.SetShards(shards) - - if len(noShardStmt.Tables) == 0 { - return shardPlan, nil - } - - noShardPlan := plan.Transparent(noShardStmt, args) - - return &plan.UnionPlan{ - Plans: []proto.Plan{ - noShardPlan, shardPlan, - }, - }, nil +func Register(t rast.SQLType, h Processor) { + _handlers[t] = h } -func (o optimizer) getSelectFlag(ctx context.Context, stmt *rast.SelectStatement) (flag uint32) { - switch len(stmt.From) { - case 1: - from := stmt.From[0] - tn := from.TableName() +type Processor = func(ctx context.Context, o *Optimizer) (proto.Plan, error) - if tn == nil { // only FROM table supported now - return - } - - flag |= _supported - - if len(tn) > 1 { - switch strings.ToLower(tn.Prefix()) { - case "mysql", "information_schema": - flag |= _bypass - return - } - } - if !rcontext.Rule(ctx).Has(tn.Suffix()) { - flag |= _bypass - } - case 0: - flag |= _bypass - flag |= _supported - } - return -} - -func (o optimizer) optimizeShowDatabases(ctx context.Context, stmt *rast.ShowDatabases, args []interface{}) (proto.Plan, error) { - ret := &plan.ShowDatabasesPlan{Stmt: stmt} - ret.BindArgs(args) - return ret, nil +type Optimizer struct { + Rule *rule.Rule + Hints []*hint.Hint + Stmt rast.Statement + Args []interface{} } -func (o optimizer) optimizeSelect(ctx context.Context, conn proto.VConn, stmt *rast.SelectStatement, args []interface{}) (proto.Plan, error) { - var ru *rule.Rule - if ru = rcontext.Rule(ctx); ru == nil { - return nil, errors.WithStack(errNoRuleFound) - } - if stmt.HasJoin() { - return o.optimizeJoin(ctx, conn, stmt, args) - } - flag := o.getSelectFlag(ctx, stmt) - if flag&_supported == 0 { - return nil, errors.Errorf("unsupported sql: %s", rcontext.SQL(ctx)) - } - - if flag&_bypass != 0 { - if len(stmt.From) > 0 { - err := o.rewriteSelectStatement(ctx, conn, stmt, rcontext.DBGroup(ctx), stmt.From[0].TableName().Suffix()) - if err != nil { - return nil, err - } - } - ret := &plan.SimpleQueryPlan{Stmt: stmt} - ret.BindArgs(args) - return ret, nil - } - +func NewOptimizer(rule *rule.Rule, hints []*hint.Hint, stmt ast.StmtNode, args []interface{}) (proto.Optimizer, error) { var ( - shards rule.DatabaseTables - fullScan bool - err error - vt = ru.MustVTable(stmt.From[0].TableName().Suffix()) + rstmt rast.Statement + err error ) - - if shards, fullScan, err = (*Sharder)(ru).Shard(stmt.From[0].TableName(), stmt.Where, args...); err != nil { - return nil, errors.Wrap(err, "calculate shards failed") - } - - log.Debugf("compute shards: result=%s, isFullScan=%v", shards, fullScan) - - // return error if full-scan is disabled - if fullScan && !vt.AllowFullScan() { - return nil, errors.WithStack(errDenyFullScan) - } - - toSingle := func(db, tbl string) (proto.Plan, error) { - if err := o.rewriteSelectStatement(ctx, conn, stmt, db, tbl); err != nil { - return nil, err - } - ret := &plan.SimpleQueryPlan{ - Stmt: stmt, - Database: db, - Tables: []string{tbl}, - } - ret.BindArgs(args) - - return ret, nil - } - - // Go through first table if no shards matched. - // For example: - // SELECT ... FROM xxx WHERE a > 8 and a < 4 - if shards.IsEmpty() { - var ( - db0, tbl0 string - ok bool - ) - if db0, tbl0, ok = vt.Topology().Render(0, 0); !ok { - return nil, errors.Errorf("cannot compute minimal topology from '%s'", stmt.From[0].TableName().Suffix()) - } - - return toSingle(db0, tbl0) - } - - // Handle single shard - if shards.Len() == 1 { - var db, tbl string - for k, v := range shards { - db = k - tbl = v[0] - } - return toSingle(db, tbl) - } - - // Handle multiple shards - - if shards.IsFullScan() { // expand all shards if all shards matched - // init shards - shards = rule.DatabaseTables{} - // compute all tables - topology := vt.Topology() - topology.Each(func(dbIdx, tbIdx int) bool { - if d, t, ok := topology.Render(dbIdx, tbIdx); ok { - shards[d] = append(shards[d], t) - } - return true - }) - } - - plans := make([]proto.Plan, 0, len(shards)) - for k, v := range shards { - next := &plan.SimpleQueryPlan{ - Database: k, - Tables: v, - Stmt: stmt, - } - next.BindArgs(args) - plans = append(plans, next) - } - - if len(plans) > 0 { - tempPlan := plans[0].(*plan.SimpleQueryPlan) - if err = o.rewriteSelectStatement(ctx, conn, stmt, tempPlan.Database, tempPlan.Tables[0]); err != nil { - return nil, err - } - } - - unionPlan := &plan.UnionPlan{ - Plans: plans, - } - - // TODO: order/groupBy/aggregate - aggregate := &plan.AggregatePlan{ - UnionPlan: unionPlan, - Combiner: transformer.NewCombinerManager(), - AggrLoader: transformer.LoadAggrs(stmt.Select), - } - - return aggregate, nil -} - -//optimizeJoin ony support a join b in one db -func (o optimizer) optimizeJoin(ctx context.Context, conn proto.VConn, stmt *rast.SelectStatement, args []interface{}) (proto.Plan, error) { - - var ru *rule.Rule - if ru = rcontext.Rule(ctx); ru == nil { - return nil, errors.WithStack(errNoRuleFound) - } - - join := stmt.From[0].Source().(*rast.JoinNode) - - compute := func(tableSource *rast.TableSourceNode) (database, alias string, shardList []string, err error) { - table := tableSource.TableName() - if table == nil { - err = errors.New("must table, not statement or join node") - return - } - alias = tableSource.Alias() - database = table.Prefix() - - shards, err := o.computeShards(ru, table, nil, args) - if err != nil { - return - } - //table no shard - if shards == nil { - shardList = append(shardList, table.Suffix()) - return - } - //table shard more than one db - if len(shards) > 1 { - err = errors.New("not support more than one db") - return - } - - for k, v := range shards { - database = k - shardList = v - } - - if alias == "" { - alias = table.Suffix() - } - - return - } - - dbLeft, aliasLeft, shardLeft, err := compute(join.Left) - if err != nil { - return nil, err - } - dbRight, aliasRight, shardRight, err := compute(join.Right) - - if err != nil { - return nil, err - } - - if dbLeft != "" && dbRight != "" && dbLeft != dbRight { - return nil, errors.New("not support more than one db") - } - - joinPan := &plan.SimpleJoinPlan{ - Left: &plan.JoinTable{ - Tables: shardLeft, - Alias: aliasLeft, - }, - Join: join, - Right: &plan.JoinTable{ - Tables: shardRight, - Alias: aliasRight, - }, - Stmt: stmt, - } - joinPan.BindArgs(args) - - return joinPan, nil -} - -func (o optimizer) optimizeUpdate(ctx context.Context, conn proto.VConn, stmt *rast.UpdateStatement, args []interface{}) (proto.Plan, error) { - var ( - ru = rcontext.Rule(ctx) - table = stmt.Table - vt *rule.VTable - ok bool - ) - - // non-sharding update - if vt, ok = ru.VTable(table.Suffix()); !ok { - ret := plan.NewUpdatePlan(stmt) - ret.BindArgs(args) - return ret, nil - } - - var ( - shards rule.DatabaseTables - fullScan = true - err error - ) - - // compute shards - if where := stmt.Where; where != nil { - sharder := (*Sharder)(ru) - if shards, fullScan, err = sharder.Shard(table, where, args...); err != nil { - return nil, errors.Wrap(err, "failed to update") - } - } - - // exit if full-scan is disabled - if fullScan && !vt.AllowFullScan() { - return nil, errDenyFullScan - } - - // must be empty shards (eg: update xxx set ... where 1 = 2 and uid = 1) - if shards.IsEmpty() { - return plan.AlwaysEmptyExecPlan{}, nil - } - - // compute all sharding tables - if shards.IsFullScan() { - // init shards - shards = rule.DatabaseTables{} - // compute all tables - topology := vt.Topology() - topology.Each(func(dbIdx, tbIdx int) bool { - if d, t, ok := topology.Render(dbIdx, tbIdx); ok { - shards[d] = append(shards[d], t) - } - return true - }) - } - - ret := plan.NewUpdatePlan(stmt) - ret.BindArgs(args) - ret.SetShards(shards) - - return ret, nil -} - -func (o optimizer) optimizeInsert(ctx context.Context, conn proto.VConn, stmt *rast.InsertStatement, args []interface{}) (proto.Plan, error) { - var ( - ru = rcontext.Rule(ctx) - ret = plan.NewSimpleInsertPlan() - ) - - ret.BindArgs(args) - - var ( - vt *rule.VTable - ok bool - ) - - if vt, ok = ru.VTable(stmt.Table().Suffix()); !ok { // insert into non-sharding table - ret.Put("", stmt) - return ret, nil - } - - // TODO: handle multiple shard keys. - - bingo := -1 - // check existing shard columns - for i, col := range stmt.Columns() { - if _, _, ok = vt.GetShardMetadata(col); ok { - bingo = i - break - } - } - - if bingo < 0 { - return nil, errors.Wrap(errNoShardKeyFound, "failed to insert") - } - - var ( - sharder = (*Sharder)(ru) - left = rast.ColumnNameExpressionAtom(make([]string, 1)) - filter = &rast.PredicateExpressionNode{ - P: &rast.BinaryComparisonPredicateNode{ - Left: &rast.AtomPredicateNode{ - A: left, - }, - Op: cmp.Ceq, - }, - } - slots = make(map[string]map[string][]int) // (db,table,valuesIndex) - ) - - // reset filter - resetFilter := func(column string, value rast.ExpressionNode) { - left[0] = column - filter.P.(*rast.BinaryComparisonPredicateNode).Right = value.(*rast.PredicateExpressionNode).P - } - - for i, values := range stmt.Values() { - value := values[bingo] - resetFilter(stmt.Columns()[bingo], value) - - shards, _, err := sharder.Shard(stmt.Table(), filter, args...) - - if err != nil { - return nil, errors.WithStack(err) - } - - if shards.Len() != 1 { - return nil, errors.Wrap(errNoShardKeyFound, "failed to insert") - } - - var ( - db string - table string - ) - - for k, v := range shards { - db = k - table = v[0] - break - } - - if _, ok = slots[db]; !ok { - slots[db] = make(map[string][]int) - } - slots[db][table] = append(slots[db][table], i) - } - - for db, slot := range slots { - for table, indexes := range slot { - // clone insert stmt without values - newborn := rast.NewInsertStatement(rast.TableName{table}, stmt.Columns()) - newborn.SetFlag(stmt.Flag()) - newborn.SetDuplicatedUpdates(stmt.DuplicatedUpdates()) - - // collect values with same table - values := make([][]rast.ExpressionNode, 0, len(indexes)) - for _, i := range indexes { - values = append(values, stmt.Values()[i]) - } - newborn.SetValues(values) - - o.rewriteInsertStatement(ctx, conn, newborn, db, table) - ret.Put(db, newborn) - } - } - - return ret, nil -} - -func (o optimizer) optimizeDelete(ctx context.Context, stmt *rast.DeleteStatement, args []interface{}) (proto.Plan, error) { - ru := rcontext.Rule(ctx) - shards, err := o.computeShards(ru, stmt.Table, stmt.Where, args) - if err != nil { - return nil, errors.Wrap(err, "failed to optimize DELETE statement") - } - - // TODO: delete from a child sharding-table directly - - if shards == nil { - return plan.Transparent(stmt, args), nil + if rstmt, err = rast.FromStmtNode(stmt); err != nil { + return nil, perrors.Wrap(err, "optimize failed") } - ret := plan.NewSimpleDeletePlan(stmt) - ret.BindArgs(args) - ret.SetShards(shards) - - return ret, nil + return &Optimizer{ + Rule: rule, + Hints: hints, + Stmt: rstmt, + Args: args, + }, nil } -func (o optimizer) optimizeShowTables(ctx context.Context, stmt *rast.ShowTables, args []interface{}) (proto.Plan, error) { - vts := rcontext.Rule(ctx).VTables() - databaseTablesMap := make(map[string]rule.DatabaseTables, len(vts)) - for tableName, vt := range vts { - shards := rule.DatabaseTables{} - // compute all tables - topology := vt.Topology() - topology.Each(func(dbIdx, tbIdx int) bool { - if d, t, ok := topology.Render(dbIdx, tbIdx); ok { - shards[d] = append(shards[d], t) - } - return true - }) - databaseTablesMap[tableName] = shards - } - - tmpPlanData := make(map[string]plan.DatabaseTable) - for showTableName, databaseTables := range databaseTablesMap { - for databaseName, shardingTables := range databaseTables { - for _, shardingTable := range shardingTables { - tmpPlanData[shardingTable] = plan.DatabaseTable{ - Database: databaseName, - TableName: showTableName, - } - } +func (o *Optimizer) Optimize(ctx context.Context) (plan proto.Plan, err error) { + ctx, span := Tracer.Start(ctx, "Optimize") + defer func() { + span.End() + if rec := recover(); rec != nil { + err = perrors.Errorf("cannot analyze sql %s", rcontext.SQL(ctx)) + log.Errorf("optimize panic: sql=%s, rec=%v", rcontext.SQL(ctx), rec) } - } - - ret := plan.NewShowTablesPlan(stmt) - ret.BindArgs(args) - ret.SetAllShards(tmpPlanData) - return ret, nil -} - -func (o optimizer) optimizeTruncate(ctx context.Context, stmt *rast.TruncateStatement, args []interface{}) (proto.Plan, error) { - ru := rcontext.Rule(ctx) - shards, err := o.computeShards(ru, stmt.Table, nil, args) - if err != nil { - return nil, errors.Wrap(err, "failed to optimize TRUNCATE statement") - } - - if shards == nil { - return plan.Transparent(stmt, args), nil - } - - ret := plan.NewTruncatePlan(stmt) - ret.BindArgs(args) - ret.SetShards(shards) - - return ret, nil -} - -func (o optimizer) optimizeShowVariables(ctx context.Context, stmt *rast.ShowVariables, args []interface{}) (proto.Plan, error) { - ret := plan.NewShowVariablesPlan(stmt) - ret.BindArgs(args) - return ret, nil -} - -func (o optimizer) optimizeDescribeStatement(ctx context.Context, stmt *rast.DescribeStatement, args []interface{}) (proto.Plan, error) { - vts := rcontext.Rule(ctx).VTables() - vtName := []string(stmt.Table)[0] - ret := plan.NewDescribePlan(stmt) - ret.BindArgs(args) + }() - if vTable, ok := vts[vtName]; ok { - shards := rule.DatabaseTables{} - // compute all tables - topology := vTable.Topology() - topology.Each(func(dbIdx, tbIdx int) bool { - if d, t, ok := topology.Render(dbIdx, tbIdx); ok { - shards[d] = append(shards[d], t) - } - return true - }) - dbName, tblName := shards.Smallest() - ret.Database = dbName - ret.Table = tblName + h, ok := _handlers[o.Stmt.Mode()] + if !ok { + return nil, perrors.Errorf("optimize: no handler found for '%s'", o.Stmt.Mode()) } - return ret, nil + return h(ctx, o) } -func (o optimizer) computeShards(ru *rule.Rule, table rast.TableName, where rast.ExpressionNode, args []interface{}) (rule.DatabaseTables, error) { +func (o *Optimizer) ComputeShards(table rast.TableName, where rast.ExpressionNode, args []interface{}) (rule.DatabaseTables, error) { + ru := o.Rule vt, ok := ru.VTable(table.Suffix()) if !ok { return nil, nil @@ -720,14 +126,14 @@ func (o optimizer) computeShards(ru *rule.Rule, table rast.TableName, where rast shards, fullScan, err := (*Sharder)(ru).Shard(table, where, args...) if err != nil { - return nil, errors.Wrap(err, "calculate shards failed") + return nil, perrors.Wrapf(err, "optimize: cannot calculate shards of table '%s'", table.Suffix()) } - log.Debugf("compute shards: result=%s, isFullScan=%v", shards, fullScan) + //log.Debugf("compute shards: result=%s, isFullScan=%v", shards, fullScan) // return error if full-scan is disabled if fullScan && !vt.AllowFullScan() { - return nil, errors.WithStack(errDenyFullScan) + return nil, perrors.WithStack(ErrDenyFullScan) } if shards.IsEmpty() { @@ -735,86 +141,9 @@ func (o optimizer) computeShards(ru *rule.Rule, table rast.TableName, where rast } if len(shards) == 0 { - // init shards - shards = rule.DatabaseTables{} // compute all tables - topology := vt.Topology() - topology.Each(func(dbIdx, tbIdx int) bool { - if d, t, ok := topology.Render(dbIdx, tbIdx); ok { - shards[d] = append(shards[d], t) - } - return true - }) + shards = vt.Topology().Enumerate() } return shards, nil } - -func (o optimizer) rewriteSelectStatement(ctx context.Context, conn proto.VConn, stmt *rast.SelectStatement, - db, tb string) error { - // todo db 计算逻辑&tb shard 的计算逻辑 - var starExpand = false - if len(stmt.Select) == 1 { - if _, ok := stmt.Select[0].(*rast.SelectElementAll); ok { - starExpand = true - } - } - if starExpand { - if len(tb) < 1 { - tb = stmt.From[0].TableName().Suffix() - } - metaData := o.schemaLoader.Load(ctx, conn, db, []string{tb})[tb] - if metaData == nil || len(metaData.ColumnNames) == 0 { - return errors.Errorf("can not get metadata for db:%s and table:%s", db, tb) - } - selectElements := make([]rast.SelectElement, len(metaData.Columns)) - for i, column := range metaData.ColumnNames { - selectElements[i] = rast.NewSelectElementColumn([]string{column}, "") - } - stmt.Select = selectElements - } - - return nil -} - -func (o optimizer) rewriteInsertStatement(ctx context.Context, conn proto.VConn, stmt *rast.InsertStatement, - db, tb string) error { - metaData := o.schemaLoader.Load(ctx, conn, db, []string{tb})[tb] - if metaData == nil || len(metaData.ColumnNames) == 0 { - return errors.Errorf("can not get metadata for db:%s and table:%s", db, tb) - } - - if len(metaData.ColumnNames) == len(stmt.Columns()) { - // User had explicitly specified every value - return nil - } - columnsMetadata := metaData.Columns - - for _, colName := range stmt.Columns() { - if columnsMetadata[colName].PrimaryKey && columnsMetadata[colName].Generated { - // User had explicitly specified auto-generated primary key column - return nil - } - } - - pkColName := "" - for name, column := range columnsMetadata { - if column.PrimaryKey && column.Generated { - pkColName = name - break - } - } - if len(pkColName) < 1 { - // There's no auto-generated primary key column - return nil - } - - // TODO rewrite columns and add distributed primary key - //stmt.SetColumns(append(stmt.Columns(), pkColName)) - // append value of distributed primary key - //newValues := stmt.Values() - //for _, newValue := range newValues { - // newValue = append(newValue, ) - //} - return nil -} diff --git a/pkg/runtime/optimize/optimizer_test.go b/pkg/runtime/optimize/optimizer_test.go index 80569ce62..00034dbd8 100644 --- a/pkg/runtime/optimize/optimizer_test.go +++ b/pkg/runtime/optimize/optimizer_test.go @@ -15,7 +15,7 @@ * limitations under the License. */ -package optimize +package optimize_test import ( "context" @@ -33,10 +33,14 @@ import ( ) import ( - "github.com/arana-db/arana/pkg/mysql" "github.com/arana-db/arana/pkg/proto" "github.com/arana-db/arana/pkg/proto/rule" - rcontext "github.com/arana-db/arana/pkg/runtime/context" + "github.com/arana-db/arana/pkg/resultx" + . "github.com/arana-db/arana/pkg/runtime/optimize" + _ "github.com/arana-db/arana/pkg/runtime/optimize/dal" + _ "github.com/arana-db/arana/pkg/runtime/optimize/ddl" + _ "github.com/arana-db/arana/pkg/runtime/optimize/dml" + _ "github.com/arana-db/arana/pkg/runtime/optimize/utility" "github.com/arana-db/arana/testdata" ) @@ -49,21 +53,25 @@ func TestOptimizer_OptimizeSelect(t *testing.T) { conn.EXPECT().Query(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). DoAndReturn(func(ctx context.Context, db string, sql string, args ...interface{}) (proto.Result, error) { t.Logf("fake query: db=%s, sql=%s, args=%v\n", db, sql, args) - return &mysql.Result{}, nil + + ds := testdata.NewMockDataset(ctrl) + ds.EXPECT().Fields().Return([]proto.Field{}, nil).AnyTimes() + + return resultx.New(resultx.WithDataset(ds)), nil }). AnyTimes() var ( - sql = "select id, uid from student where uid in (?,?,?)" - ctx = context.Background() - rule = makeFakeRule(ctrl, 8) - opt optimizer + sql = "select id, uid from student where uid in (?,?,?)" + ctx = context.Background() + ru = makeFakeRule(ctrl, 8) ) p := parser.New() stmt, _ := p.ParseOneStmt(sql, "", "") - - plan, err := opt.Optimize(rcontext.WithRule(ctx, rule), conn, stmt, 1, 2, 3) + opt, err := NewOptimizer(ru, nil, stmt, []interface{}{1, 2, 3}) + assert.NoError(t, err) + plan, err := opt.Optimize(ctx) assert.NoError(t, err) _, _ = plan.ExecIn(ctx, conn) @@ -117,27 +125,34 @@ func TestOptimizer_OptimizeInsert(t *testing.T) { t.Logf("fake exec: db='%s', sql=\"%s\", args=%v\n", db, sql, args) fakeId++ - return &mysql.Result{ - AffectedRows: uint64(strings.Count(sql, "?")), - InsertId: fakeId, - }, nil + return resultx.New( + resultx.WithRowsAffected(uint64(strings.Count(sql, "?"))), + resultx.WithLastInsertID(fakeId), + ), nil }). AnyTimes() - loader.EXPECT().Load(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(fakeStudentMetadata).Times(2) + loader.EXPECT().Load(gomock.Any(), gomock.Any(), gomock.Any()).Return(fakeStudentMetadata, nil).Times(2) + + oldLoader := proto.LoadSchemaLoader() + proto.RegisterSchemaLoader(loader) + defer proto.RegisterSchemaLoader(oldLoader) var ( - ctx = context.Background() - rule = makeFakeRule(ctrl, 8) - opt = optimizer{schemaLoader: loader} + ctx = context.Background() + ru = makeFakeRule(ctrl, 8) ) t.Run("sharding", func(t *testing.T) { + sql := "insert into student(name,uid,age) values('foo',?,18),('bar',?,19),('qux',?,17)" p := parser.New() stmt, _ := p.ParseOneStmt(sql, "", "") - plan, err := opt.Optimize(rcontext.WithRule(ctx, rule), conn, stmt, 8, 9, 16) // 8,16 -> fake_db_0000, 9 -> fake_db_0001 + opt, err := NewOptimizer(ru, nil, stmt, []interface{}{8, 9, 16}) + assert.NoError(t, err) + + plan, err := opt.Optimize(ctx) // 8,16 -> fake_db_0000, 9 -> fake_db_0001 assert.NoError(t, err) res, err := plan.ExecIn(ctx, conn) @@ -155,7 +170,10 @@ func TestOptimizer_OptimizeInsert(t *testing.T) { p := parser.New() stmt, _ := p.ParseOneStmt(sql, "", "") - plan, err := opt.Optimize(rcontext.WithRule(ctx, rule), conn, stmt, 1) + opt, err := NewOptimizer(ru, nil, stmt, []interface{}{1}) + assert.NoError(t, err) + + plan, err := opt.Optimize(ctx) assert.NoError(t, err) res, err := plan.ExecIn(ctx, conn) @@ -166,7 +184,6 @@ func TestOptimizer_OptimizeInsert(t *testing.T) { lastInsertId, _ := res.LastInsertId() assert.Equal(t, fakeId, lastInsertId) }) - } func TestOptimizer_OptimizeAlterTable(t *testing.T) { @@ -174,23 +191,21 @@ func TestOptimizer_OptimizeAlterTable(t *testing.T) { defer ctrl.Finish() conn := testdata.NewMockVConn(ctrl) - loader := testdata.NewMockSchemaLoader(ctrl) conn.EXPECT().Exec(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). DoAndReturn(func(ctx context.Context, db string, sql string, args ...interface{}) (proto.Result, error) { t.Logf("fake exec: db='%s', sql=\"%s\", args=%v\n", db, sql, args) - return &mysql.Result{}, nil + return resultx.New(), nil }).AnyTimes() var ( - ctx = context.Background() - opt = optimizer{schemaLoader: loader} - ru rule.Rule - tab rule.VTable - topo rule.Topology + ctx = context.Background() + ru rule.Rule + tab rule.VTable + topology rule.Topology ) - topo.SetRender(func(_ int) string { + topology.SetRender(func(_ int) string { return "fake_db" }, func(i int) string { return fmt.Sprintf("student_%04d", i) @@ -199,8 +214,8 @@ func TestOptimizer_OptimizeAlterTable(t *testing.T) { for i := 0; i < 8; i++ { tables = append(tables, i) } - topo.SetTopology(0, tables...) - tab.SetTopology(&topo) + topology.SetTopology(0, tables...) + tab.SetTopology(&topology) tab.SetAllowFullScan(true) ru.SetVTable("student", &tab) @@ -210,12 +225,14 @@ func TestOptimizer_OptimizeAlterTable(t *testing.T) { p := parser.New() stmt, _ := p.ParseOneStmt(sql, "", "") - plan, err := opt.Optimize(rcontext.WithRule(ctx, &ru), conn, stmt) + opt, err := NewOptimizer(&ru, nil, stmt, nil) assert.NoError(t, err) - _, err = plan.ExecIn(ctx, conn) + plan, err := opt.Optimize(ctx) assert.NoError(t, err) + _, err = plan.ExecIn(ctx, conn) + assert.NoError(t, err) }) t.Run("non-sharding", func(t *testing.T) { @@ -224,10 +241,61 @@ func TestOptimizer_OptimizeAlterTable(t *testing.T) { p := parser.New() stmt, _ := p.ParseOneStmt(sql, "", "") - plan, err := opt.Optimize(rcontext.WithRule(ctx, &ru), conn, stmt) + opt, err := NewOptimizer(&ru, nil, stmt, nil) + assert.NoError(t, err) + + plan, err := opt.Optimize(ctx) assert.NoError(t, err) _, err = plan.ExecIn(ctx, conn) assert.NoError(t, err) }) } + +func TestOptimizer_OptimizeInsertSelect(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + conn := testdata.NewMockVConn(ctrl) + + var fakeId uint64 + conn.EXPECT().Exec(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, db string, sql string, args ...interface{}) (proto.Result, error) { + t.Logf("fake exec: db='%s', sql=\"%s\", args=%v\n", db, sql, args) + fakeId++ + + return resultx.New( + resultx.WithRowsAffected(uint64(strings.Count(sql, "?"))), + resultx.WithLastInsertID(fakeId), + ), nil + }). + AnyTimes() + + var ( + ctx = context.Background() + ru rule.Rule + ) + + ru.SetVTable("student", nil) + + t.Run("non-sharding", func(t *testing.T) { + sql := "insert into employees(name, age) select name,age from employees_tmp limit 10,2" + + p := parser.New() + stmt, _ := p.ParseOneStmt(sql, "", "") + + opt, err := NewOptimizer(&ru, nil, stmt, []interface{}{1}) + assert.NoError(t, err) + + plan, err := opt.Optimize(ctx) + assert.NoError(t, err) + + res, err := plan.ExecIn(ctx, conn) + assert.NoError(t, err) + + affected, _ := res.RowsAffected() + assert.Equal(t, uint64(0), affected) + lastInsertId, _ := res.LastInsertId() + assert.Equal(t, fakeId, lastInsertId) + }) +} diff --git a/pkg/runtime/optimize/sharder.go b/pkg/runtime/optimize/sharder.go index 8dea8fcec..45e56c872 100644 --- a/pkg/runtime/optimize/sharder.go +++ b/pkg/runtime/optimize/sharder.go @@ -37,9 +37,7 @@ import ( rrule "github.com/arana-db/arana/pkg/runtime/rule" ) -var ( - errArgumentOutOfRange = stdErrors.New("argument is out of bounds") -) +var errArgumentOutOfRange = stdErrors.New("argument is out of bounds") // IsErrArgumentOutOfRange returns true if target error is caused by argument out of range. func IsErrArgumentOutOfRange(err error) bool { diff --git a/pkg/runtime/optimize/sharder_test.go b/pkg/runtime/optimize/sharder_test.go index 4a6b6ac8b..229038366 100644 --- a/pkg/runtime/optimize/sharder_test.go +++ b/pkg/runtime/optimize/sharder_test.go @@ -15,7 +15,7 @@ * limitations under the License. */ -package optimize +package optimize_test import ( "fmt" @@ -34,6 +34,7 @@ import ( import ( "github.com/arana-db/arana/pkg/proto/rule" "github.com/arana-db/arana/pkg/runtime/ast" + . "github.com/arana-db/arana/pkg/runtime/optimize" "github.com/arana-db/arana/testdata" ) @@ -56,7 +57,8 @@ func TestShard(t *testing.T) { {"select * from student where uid = if(PI()<3, 1, ?)", []interface{}{0}, []int{0}}, } { t.Run(it.sql, func(t *testing.T) { - stmt := ast.MustParse(it.sql).(*ast.SelectStatement) + _, rawStmt := ast.MustParse(it.sql) + stmt := rawStmt.(*ast.SelectStatement) result, _, err := (*Sharder)(fakeRule).Shard(stmt.From[0].TableName(), stmt.Where, it.args...) assert.NoError(t, err, "shard failed") diff --git a/pkg/runtime/optimize/utility/describe.go b/pkg/runtime/optimize/utility/describe.go new file mode 100644 index 000000000..3c757dc2d --- /dev/null +++ b/pkg/runtime/optimize/utility/describe.go @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package utility + +import ( + "context" +) + +import ( + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/pkg/runtime/optimize" + "github.com/arana-db/arana/pkg/runtime/plan/utility" +) + +func init() { + optimize.Register(ast.SQLTypeDescribe, optimizeDescribeStatement) +} + +func optimizeDescribeStatement(_ context.Context, o *optimize.Optimizer) (proto.Plan, error) { + stmt := o.Stmt.(*ast.DescribeStatement) + vts := o.Rule.VTables() + vtName := []string(stmt.Table)[0] + ret := utility.NewDescribePlan(stmt) + ret.BindArgs(o.Args) + + if vTable, ok := vts[vtName]; ok { + dbName, tblName, _ := vTable.Topology().Smallest() + ret.Database = dbName + ret.Table = tblName + ret.Column = stmt.Column + } + + return ret, nil +} diff --git a/pkg/runtime/plan/always.go b/pkg/runtime/plan/always.go index a402efc91..90444b96f 100644 --- a/pkg/runtime/plan/always.go +++ b/pkg/runtime/plan/always.go @@ -22,22 +22,19 @@ import ( ) import ( - "github.com/arana-db/arana/pkg/mysql" "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/resultx" ) var _ proto.Plan = (*AlwaysEmptyExecPlan)(nil) -var _emptyResult mysql.Result - // AlwaysEmptyExecPlan represents an exec plan which affects nothing. -type AlwaysEmptyExecPlan struct { -} +type AlwaysEmptyExecPlan struct{} func (a AlwaysEmptyExecPlan) Type() proto.PlanType { return proto.PlanTypeExec } -func (a AlwaysEmptyExecPlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Result, error) { - return &_emptyResult, nil +func (a AlwaysEmptyExecPlan) ExecIn(_ context.Context, _ proto.VConn) (proto.Result, error) { + return resultx.New(), nil } diff --git a/pkg/runtime/plan/dal/show_columns.go b/pkg/runtime/plan/dal/show_columns.go new file mode 100644 index 000000000..79b9672ca --- /dev/null +++ b/pkg/runtime/plan/dal/show_columns.go @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dal + +import ( + "context" + "strings" +) + +import ( + "github.com/pkg/errors" +) + +import ( + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/pkg/runtime/plan" +) + +var _ proto.Plan = (*ShowColumnsPlan)(nil) + +type ShowColumnsPlan struct { + plan.BasePlan + Stmt *ast.ShowColumns + Table string +} + +func (s *ShowColumnsPlan) Type() proto.PlanType { + return proto.PlanTypeQuery +} + +func (s *ShowColumnsPlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Result, error) { + var ( + sb strings.Builder + indexes []int + ) + + if err := s.generate(&sb, &indexes); err != nil { + return nil, errors.Wrap(err, "failed to generate show columns sql") + } + + return conn.Query(ctx, "", sb.String(), s.ToArgs(indexes)...) +} + +func (s *ShowColumnsPlan) generate(sb *strings.Builder, args *[]int) error { + var ( + stmt = *s.Stmt + err error + ) + + if s.Table == "" { + if err = s.Stmt.Restore(ast.RestoreDefault, sb, args); err != nil { + return errors.WithStack(err) + } + return nil + } + + s.resetTable(&stmt, s.Table) + if err = stmt.Restore(ast.RestoreDefault, sb, args); err != nil { + return errors.WithStack(err) + } + return nil +} + +func (s *ShowColumnsPlan) resetTable(dstmt *ast.ShowColumns, table string) { + dstmt.TableName = dstmt.TableName.ResetSuffix(table) +} diff --git a/pkg/runtime/plan/dal/show_create.go b/pkg/runtime/plan/dal/show_create.go new file mode 100644 index 000000000..f9903a4a3 --- /dev/null +++ b/pkg/runtime/plan/dal/show_create.go @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dal + +import ( + "context" + "strings" +) + +import ( + "github.com/pkg/errors" +) + +import ( + "github.com/arana-db/arana/pkg/dataset" + "github.com/arana-db/arana/pkg/mysql/rows" + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/resultx" + "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/pkg/runtime/plan" +) + +var _ proto.Plan = (*ShowCreatePlan)(nil) + +type ShowCreatePlan struct { + plan.BasePlan + Stmt *ast.ShowCreate + Database string + Table string +} + +// NewShowCreatePlan create ShowCreate Plan +func NewShowCreatePlan(stmt *ast.ShowCreate) *ShowCreatePlan { + return &ShowCreatePlan{ + Stmt: stmt, + } +} + +func (st *ShowCreatePlan) Type() proto.PlanType { + return proto.PlanTypeQuery +} + +func (st *ShowCreatePlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Result, error) { + var ( + sb strings.Builder + indexes []int + res proto.Result + err error + ) + + if err = st.Stmt.ResetTable(st.Table).Restore(ast.RestoreDefault, &sb, &indexes); err != nil { + return nil, errors.WithStack(err) + } + + var ( + query = sb.String() + args = st.ToArgs(indexes) + target = st.Stmt.Target() + ) + + if res, err = conn.Query(ctx, st.Database, query, args...); err != nil { + return nil, errors.WithStack(err) + } + + ds, err := res.Dataset() + if err != nil { + return nil, errors.WithStack(err) + } + + // sharding table should be changed of target name + if st.Table != target { + fields, _ := ds.Fields() + ds = dataset.Pipe(ds, + dataset.Map(nil, func(next proto.Row) (proto.Row, error) { + dest := make([]proto.Value, len(fields)) + if next.Scan(dest) != nil { + return next, nil + } + dest[0] = target + dest[1] = strings.Replace(dest[1].(string), st.Table, target, 1) + + if next.IsBinary() { + return rows.NewBinaryVirtualRow(fields, dest), nil + } + return rows.NewTextVirtualRow(fields, dest), nil + }), + ) + } + + return resultx.New(resultx.WithDataset(ds)), nil +} diff --git a/pkg/runtime/plan/show_databases.go b/pkg/runtime/plan/dal/show_databases.go similarity index 63% rename from pkg/runtime/plan/show_databases.go rename to pkg/runtime/plan/dal/show_databases.go index 9aedcc883..ea22c6e85 100644 --- a/pkg/runtime/plan/show_databases.go +++ b/pkg/runtime/plan/dal/show_databases.go @@ -15,7 +15,7 @@ * limitations under the License. */ -package plan +package dal import ( "context" @@ -26,20 +26,21 @@ import ( ) import ( - fieldType "github.com/arana-db/arana/pkg/constants/mysql" - "github.com/arana-db/arana/pkg/mysql" + "github.com/arana-db/arana/pkg/dataset" + "github.com/arana-db/arana/pkg/mysql/rows" + "github.com/arana-db/arana/pkg/mysql/thead" "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/resultx" "github.com/arana-db/arana/pkg/runtime/ast" rcontext "github.com/arana-db/arana/pkg/runtime/context" + "github.com/arana-db/arana/pkg/runtime/plan" "github.com/arana-db/arana/pkg/security" ) -var tenantErr = errors.New("current db tenant not fund") - var _ proto.Plan = (*ShowDatabasesPlan)(nil) type ShowDatabasesPlan struct { - basePlan + plan.BasePlan Stmt *ast.ShowDatabases } @@ -48,31 +49,21 @@ func (s *ShowDatabasesPlan) Type() proto.PlanType { } func (s *ShowDatabasesPlan) ExecIn(ctx context.Context, _ proto.VConn) (proto.Result, error) { + ctx, span := plan.Tracer.Start(ctx, "ShowDatabasesPlan.ExecIn") + defer span.End() tenant, ok := security.DefaultTenantManager().GetTenantOfCluster(rcontext.Schema(ctx)) if !ok { - return nil, tenantErr + return nil, errors.New("no tenant found in current db") } - clusters := security.DefaultTenantManager().GetClusters(tenant) - var rows = make([]proto.Row, 0, len(clusters)) + columns := thead.Database.ToFields() + ds := &dataset.VirtualDataset{ + Columns: columns, + } - for _, cluster := range clusters { - encoded := mysql.PutLengthEncodedString([]byte(cluster)) - rows = append(rows, (&mysql.TextRow{}).Encode([]*proto.Value{ - { - Typ: fieldType.FieldTypeVarString, - Flags: fieldType.NotNullFlag, - Raw: encoded, - Val: cluster, - Len: len(encoded), - }, - }, - []proto.Field{&mysql.Field{}}, nil)) + for _, cluster := range security.DefaultTenantManager().GetClusters(tenant) { + ds.Rows = append(ds.Rows, rows.NewTextVirtualRow(columns, []proto.Value{cluster})) } - return &mysql.Result{ - Fields: []proto.Field{mysql.NewField("Database", fieldType.FieldTypeVarString)}, - Rows: rows, - DataChan: make(chan proto.Row, 1), - }, nil + return resultx.New(resultx.WithDataset(ds)), nil } diff --git a/pkg/runtime/plan/dal/show_index.go b/pkg/runtime/plan/dal/show_index.go new file mode 100644 index 000000000..088639746 --- /dev/null +++ b/pkg/runtime/plan/dal/show_index.go @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dal + +import ( + "context" + "strings" +) + +import ( + "github.com/pkg/errors" +) + +import ( + "github.com/arana-db/arana/pkg/dataset" + "github.com/arana-db/arana/pkg/mysql/rows" + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/proto/rule" + "github.com/arana-db/arana/pkg/resultx" + "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/pkg/runtime/plan" +) + +var _ proto.Plan = (*ShowIndexPlan)(nil) + +type ShowIndexPlan struct { + plan.BasePlan + Stmt *ast.ShowIndex + Shards rule.DatabaseTables +} + +func (s *ShowIndexPlan) Type() proto.PlanType { + return proto.PlanTypeQuery +} + +func (s *ShowIndexPlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Result, error) { + var ( + sb strings.Builder + indexes []int + err error + ) + + if s.Shards == nil { + if err = s.Stmt.Restore(ast.RestoreDefault, &sb, &indexes); err != nil { + return nil, errors.WithStack(err) + } + return conn.Query(ctx, "", sb.String(), s.ToArgs(indexes)...) + } + + toTable := s.Stmt.TableName.Suffix() + + db, table := s.Shards.Smallest() + s.Stmt.TableName = ast.TableName{table} + + if err = s.Stmt.Restore(ast.RestoreDefault, &sb, &indexes); err != nil { + return nil, errors.WithStack(err) + } + + query, err := conn.Query(ctx, db, sb.String(), s.ToArgs(indexes)...) + if err != nil { + return nil, errors.WithStack(err) + } + + ds, err := query.Dataset() + if err != nil { + return nil, errors.WithStack(err) + } + + fields, err := ds.Fields() + if err != nil { + return nil, errors.WithStack(err) + } + + ds = dataset.Pipe(ds, + dataset.Map(nil, func(next proto.Row) (proto.Row, error) { + dest := make([]proto.Value, len(fields)) + if next.Scan(dest) != nil { + return next, nil + } + dest[0] = toTable + + if next.IsBinary() { + return rows.NewBinaryVirtualRow(fields, dest), nil + } + return rows.NewTextVirtualRow(fields, dest), nil + }), + ) + + return resultx.New(resultx.WithDataset(ds)), nil +} diff --git a/pkg/runtime/plan/dal/show_open_tables.go b/pkg/runtime/plan/dal/show_open_tables.go new file mode 100644 index 000000000..2a4d2cfa2 --- /dev/null +++ b/pkg/runtime/plan/dal/show_open_tables.go @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dal + +import ( + "context" + "strings" +) + +import ( + "github.com/pkg/errors" +) + +import ( + "github.com/arana-db/arana/pkg/dataset" + "github.com/arana-db/arana/pkg/mysql" + "github.com/arana-db/arana/pkg/mysql/rows" + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/resultx" + "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/pkg/runtime/plan" +) + +var _ proto.Plan = (*ShowOpenTablesPlan)(nil) + +type ShowOpenTablesPlan struct { + plan.BasePlan + Database string + Conn proto.DB + Stmt *ast.ShowOpenTables + invertedShards map[string]string // phy table name -> logical table name +} + +// NewShowOpenTablesPlan create ShowTables Plan +func NewShowOpenTablesPlan(stmt *ast.ShowOpenTables) *ShowOpenTablesPlan { + return &ShowOpenTablesPlan{ + Stmt: stmt, + } +} + +func (st *ShowOpenTablesPlan) Type() proto.PlanType { + return proto.PlanTypeQuery +} + +func (st *ShowOpenTablesPlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Result, error) { + var ( + sb strings.Builder + indexes []int + res proto.Result + err error + ) + + if err = st.Stmt.Restore(ast.RestoreDefault, &sb, &indexes); err != nil { + return nil, errors.WithStack(err) + } + + var ( + query = sb.String() + args = st.ToArgs(indexes) + ) + + if res, err = conn.Query(ctx, st.Database, query, args...); err != nil { + return nil, errors.WithStack(err) + } + + ds, err := res.Dataset() + if err != nil { + return nil, errors.WithStack(err) + } + + fields, _ := ds.Fields() + + // filter duplicates + duplicates := make(map[string]struct{}) + + // 1. convert to logical table name + // 2. filter duplicated table name + ds = dataset.Pipe(ds, + dataset.Map(nil, func(next proto.Row) (proto.Row, error) { + dest := make([]proto.Value, len(fields)) + if next.Scan(dest) != nil { + return next, nil + } + + if logicalTableName, ok := st.invertedShards[dest[1].(string)]; ok { + dest[1] = logicalTableName + } + + if next.IsBinary() { + return rows.NewBinaryVirtualRow(fields, dest), nil + } + return rows.NewTextVirtualRow(fields, dest), nil + }), + dataset.Filter(func(next proto.Row) bool { + var vr rows.VirtualRow + switch val := next.(type) { + case mysql.TextRow, mysql.BinaryRow: + return true + case rows.VirtualRow: + vr = val + default: + return true + } + + tableName := vr.Values()[1].(string) + if _, ok := duplicates[tableName]; ok { + return false + } + duplicates[tableName] = struct{}{} + return true + }), + ) + return resultx.New(resultx.WithDataset(ds)), nil +} + +func (st *ShowOpenTablesPlan) SetDatabase(database string) { + st.Database = database +} + +func (st *ShowOpenTablesPlan) SetInvertedShards(m map[string]string) { + st.invertedShards = m +} diff --git a/pkg/runtime/plan/dal/show_tables.go b/pkg/runtime/plan/dal/show_tables.go new file mode 100644 index 000000000..6df54b164 --- /dev/null +++ b/pkg/runtime/plan/dal/show_tables.go @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dal + +import ( + "context" + "database/sql" + "strings" +) + +import ( + "github.com/pkg/errors" +) + +import ( + constant "github.com/arana-db/arana/pkg/constants/mysql" + "github.com/arana-db/arana/pkg/dataset" + "github.com/arana-db/arana/pkg/mysql" + "github.com/arana-db/arana/pkg/mysql/rows" + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/resultx" + "github.com/arana-db/arana/pkg/runtime/ast" + rcontext "github.com/arana-db/arana/pkg/runtime/context" + "github.com/arana-db/arana/pkg/runtime/plan" +) + +var ( + _ proto.Plan = (*ShowTablesPlan)(nil) + + headerPrefix = "Tables_in_" +) + +type ShowTablesPlan struct { + plan.BasePlan + Database string + Stmt *ast.ShowTables + invertedShards map[string]string // phy table name -> logical table name +} + +// NewShowTablesPlan create ShowTables Plan +func NewShowTablesPlan(stmt *ast.ShowTables) *ShowTablesPlan { + return &ShowTablesPlan{ + Stmt: stmt, + } +} + +func (st *ShowTablesPlan) Type() proto.PlanType { + return proto.PlanTypeQuery +} + +func (st *ShowTablesPlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Result, error) { + var ( + sb strings.Builder + indexes []int + res proto.Result + err error + ) + ctx, span := plan.Tracer.Start(ctx, "ShowTablesPlan.ExecIn") + defer span.End() + + if err = st.Stmt.Restore(ast.RestoreDefault, &sb, &indexes); err != nil { + return nil, errors.WithStack(err) + } + + var ( + query = sb.String() + args = st.ToArgs(indexes) + ) + + if res, err = conn.Query(ctx, st.Database, query, args...); err != nil { + return nil, errors.WithStack(err) + } + + ds, err := res.Dataset() + if err != nil { + return nil, errors.WithStack(err) + } + + fields, _ := ds.Fields() + + fields[0] = mysql.NewField(headerPrefix+rcontext.Schema(ctx), constant.FieldTypeVarString) + + // filter duplicates + duplicates := make(map[string]struct{}) + + // 1. convert to logical table name + // 2. filter duplicated table name + ds = dataset.Pipe(ds, + dataset.Map(nil, func(next proto.Row) (proto.Row, error) { + dest := make([]proto.Value, len(fields)) + if next.Scan(dest) != nil { + return next, nil + } + var tableName sql.NullString + _ = tableName.Scan(dest[0]) + dest[0] = tableName.String + + if logicalTableName, ok := st.invertedShards[tableName.String]; ok { + dest[0] = logicalTableName + } + + if next.IsBinary() { + return rows.NewBinaryVirtualRow(fields, dest), nil + } + return rows.NewTextVirtualRow(fields, dest), nil + }), + dataset.Filter(func(next proto.Row) bool { + var vr rows.VirtualRow + switch val := next.(type) { + case mysql.TextRow, mysql.BinaryRow: + return true + case rows.VirtualRow: + vr = val + default: + return true + } + + tableName := vr.Values()[0].(string) + if _, ok := duplicates[tableName]; ok { + return false + } + duplicates[tableName] = struct{}{} + return true + }), + ) + + return resultx.New(resultx.WithDataset(ds)), nil +} + +func (st *ShowTablesPlan) SetDatabase(db string) { + st.Database = db +} + +func (st *ShowTablesPlan) SetInvertedShards(m map[string]string) { + st.invertedShards = m +} diff --git a/pkg/runtime/plan/dal/show_topology.go b/pkg/runtime/plan/dal/show_topology.go new file mode 100644 index 000000000..f1683f6cd --- /dev/null +++ b/pkg/runtime/plan/dal/show_topology.go @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dal + +import ( + "context" + "fmt" + "sort" + "strings" +) + +import ( + "github.com/pkg/errors" +) + +import ( + "github.com/arana-db/arana/pkg/dataset" + "github.com/arana-db/arana/pkg/mysql/rows" + "github.com/arana-db/arana/pkg/mysql/thead" + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/proto/rule" + "github.com/arana-db/arana/pkg/resultx" + "github.com/arana-db/arana/pkg/runtime/ast" + rcontext "github.com/arana-db/arana/pkg/runtime/context" + "github.com/arana-db/arana/pkg/runtime/plan" +) + +var ( + _ proto.Plan = (*ShowTopology)(nil) +) + +type ShowTopology struct { + plan.BasePlan + Stmt *ast.ShowTopology + rule *rule.Rule +} + +// NewShowTopologyPlan create ShowTopology Plan +func NewShowTopologyPlan(stmt *ast.ShowTopology) *ShowTopology { + return &ShowTopology{ + Stmt: stmt, + } +} + +func (st *ShowTopology) Type() proto.PlanType { + return proto.PlanTypeQuery +} + +func (st *ShowTopology) ExecIn(ctx context.Context, conn proto.VConn) (proto.Result, error) { + var ( + sb strings.Builder + indexes []int + err error + table string + ) + ctx, span := plan.Tracer.Start(ctx, "ShowTopology.ExecIn") + defer span.End() + + if err = st.Stmt.Restore(ast.RestoreDefault, &sb, &indexes); err != nil { + return nil, errors.WithStack(err) + } + + table = sb.String() + + fields := thead.Topology.ToFields() + + ds := &dataset.VirtualDataset{ + Columns: fields, + } + vtable, ok := st.rule.VTable(table) + if !ok { + return nil, errors.New(fmt.Sprintf("%s do not have %s's topology", rcontext.Schema(ctx), table)) + } + t := vtable.Topology() + t.Each(func(x, y int) bool { + if dbGroup, phyTable, ok := t.Render(x, y); ok { + ds.Rows = append(ds.Rows, rows.NewTextVirtualRow(fields, []proto.Value{ + 0, dbGroup, phyTable, + })) + } + return true + }) + sort.Slice(ds.Rows, func(i, j int) bool { + if ds.Rows[i].(rows.VirtualRow).Values()[1].(string) < ds.Rows[j].(rows.VirtualRow).Values()[1].(string) { + return true + } + return ds.Rows[i].(rows.VirtualRow).Values()[1].(string) == ds.Rows[j].(rows.VirtualRow).Values()[1].(string) && + ds.Rows[i].(rows.VirtualRow).Values()[2].(string) < ds.Rows[j].(rows.VirtualRow).Values()[2].(string) + }) + + for id := 0; id < len(ds.Rows); id++ { + ds.Rows[id].(rows.VirtualRow).Values()[0] = id + } + + return resultx.New(resultx.WithDataset(ds)), nil +} + +func (st *ShowTopology) SetRule(rule *rule.Rule) { + st.rule = rule +} diff --git a/pkg/runtime/plan/variables.go b/pkg/runtime/plan/dal/show_variables.go similarity index 88% rename from pkg/runtime/plan/variables.go rename to pkg/runtime/plan/dal/show_variables.go index b603d9f92..d4bfe87b8 100644 --- a/pkg/runtime/plan/variables.go +++ b/pkg/runtime/plan/dal/show_variables.go @@ -15,7 +15,7 @@ * limitations under the License. */ -package plan +package dal import ( "context" @@ -29,12 +29,13 @@ import ( import ( "github.com/arana-db/arana/pkg/proto" "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/pkg/runtime/plan" ) var _ proto.Plan = (*ShowVariablesPlan)(nil) type ShowVariablesPlan struct { - basePlan + plan.BasePlan stmt *ast.ShowVariables } @@ -47,17 +48,18 @@ func (s *ShowVariablesPlan) Type() proto.PlanType { } func (s *ShowVariablesPlan) ExecIn(ctx context.Context, vConn proto.VConn) (proto.Result, error) { - var ( sb strings.Builder args []int ) + ctx, span := plan.Tracer.Start(ctx, "ShowVariablesPlan.ExecIn") + defer span.End() if err := s.stmt.Restore(ast.RestoreDefault, &sb, &args); err != nil { return nil, errors.Wrap(err, "failed to execute DELETE statement") } - ret, err := vConn.Query(ctx, "", sb.String(), s.toArgs(args)...) + ret, err := vConn.Query(ctx, "", sb.String(), s.ToArgs(args)...) if err != nil { return nil, err } diff --git a/pkg/runtime/plan/alter_table.go b/pkg/runtime/plan/ddl/alter_table.go similarity index 87% rename from pkg/runtime/plan/alter_table.go rename to pkg/runtime/plan/ddl/alter_table.go index c04cd5602..593e2278e 100644 --- a/pkg/runtime/plan/alter_table.go +++ b/pkg/runtime/plan/ddl/alter_table.go @@ -15,7 +15,7 @@ * limitations under the License. */ -package plan +package ddl import ( "context" @@ -31,17 +31,18 @@ import ( ) import ( - "github.com/arana-db/arana/pkg/mysql" "github.com/arana-db/arana/pkg/proto" "github.com/arana-db/arana/pkg/proto/rule" + "github.com/arana-db/arana/pkg/resultx" "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/pkg/runtime/plan" "github.com/arana-db/arana/pkg/util/log" ) var _ proto.Plan = (*AlterTablePlan)(nil) type AlterTablePlan struct { - basePlan + plan.BasePlan stmt *ast.AlterTableStatement Shards rule.DatabaseTables } @@ -50,18 +51,19 @@ func NewAlterTablePlan(stmt *ast.AlterTableStatement) *AlterTablePlan { return &AlterTablePlan{stmt: stmt} } -func (d *AlterTablePlan) Type() proto.PlanType { +func (at *AlterTablePlan) Type() proto.PlanType { return proto.PlanTypeExec } func (at *AlterTablePlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Result, error) { + // TODO: ADD trace in all plan ExecIn if at.Shards == nil { // non-sharding alter table var sb strings.Builder if err := at.stmt.Restore(ast.RestoreDefault, &sb, nil); err != nil { return nil, err } - return conn.Exec(ctx, "", sb.String(), at.args...) + return conn.Exec(ctx, "", sb.String(), at.Args...) } var ( affects = uatomic.NewUint64(0) @@ -93,7 +95,7 @@ func (at *AlterTablePlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.R return errors.WithStack(err) } - if res, err = conn.Exec(ctx, db, sb.String(), at.toArgs(args)...); err != nil { + if res, err = conn.Exec(ctx, db, sb.String(), at.ToArgs(args)...); err != nil { return errors.WithStack(err) } @@ -118,8 +120,5 @@ func (at *AlterTablePlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.R log.Debugf("sharding alter table success: batch=%d, affects=%d", cnt.Load(), affects.Load()) - return &mysql.Result{ - AffectedRows: affects.Load(), - DataChan: make(chan proto.Row, 1), - }, nil + return resultx.New(resultx.WithRowsAffected(affects.Load())), nil } diff --git a/pkg/runtime/plan/ddl/create_index.go b/pkg/runtime/plan/ddl/create_index.go new file mode 100644 index 000000000..c580c0bf6 --- /dev/null +++ b/pkg/runtime/plan/ddl/create_index.go @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ddl + +import ( + "context" + "strings" +) + +import ( + "github.com/pkg/errors" +) + +import ( + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/proto/rule" + "github.com/arana-db/arana/pkg/resultx" + "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/pkg/runtime/plan" +) + +type CreateIndexPlan struct { + plan.BasePlan + stmt *ast.CreateIndexStatement + Shards rule.DatabaseTables +} + +func NewCreateIndexPlan(stmt *ast.CreateIndexStatement) *CreateIndexPlan { + return &CreateIndexPlan{ + stmt: stmt, + } +} + +func (c *CreateIndexPlan) Type() proto.PlanType { + return proto.PlanTypeExec +} + +func (c *CreateIndexPlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Result, error) { + var ( + sb strings.Builder + args []int + ) + + for db, tables := range c.Shards { + for i := range tables { + table := tables[i] + + stmt := new(ast.CreateIndexStatement) + stmt.Table = ast.TableName{table} + stmt.IndexName = c.stmt.IndexName + stmt.Keys = c.stmt.Keys + + if err := stmt.Restore(ast.RestoreDefault, &sb, &args); err != nil { + return nil, err + } + + if err := c.execOne(ctx, conn, db, sb.String(), c.ToArgs(args)); err != nil { + return nil, errors.WithStack(err) + } + + sb.Reset() + } + } + return resultx.New(), nil +} + +func (c *CreateIndexPlan) SetShard(shard rule.DatabaseTables) { + c.Shards = shard +} + +func (c *CreateIndexPlan) execOne(ctx context.Context, conn proto.VConn, db, query string, args []interface{}) error { + res, err := conn.Exec(ctx, db, query, args...) + if err != nil { + return err + } + _, _ = res.Dataset() + return nil +} diff --git a/pkg/runtime/plan/ddl/drop_index.go b/pkg/runtime/plan/ddl/drop_index.go new file mode 100644 index 000000000..c1cad0385 --- /dev/null +++ b/pkg/runtime/plan/ddl/drop_index.go @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ddl + +import ( + "context" + "strings" +) + +import ( + "github.com/pkg/errors" +) + +import ( + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/proto/rule" + "github.com/arana-db/arana/pkg/resultx" + "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/pkg/runtime/plan" +) + +type DropIndexPlan struct { + plan.BasePlan + stmt *ast.DropIndexStatement + shard rule.DatabaseTables +} + +func NewDropIndexPlan(stmt *ast.DropIndexStatement) *DropIndexPlan { + return &DropIndexPlan{ + stmt: stmt, + } +} + +func (d *DropIndexPlan) Type() proto.PlanType { + return proto.PlanTypeExec +} + +func (d *DropIndexPlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Result, error) { + var ( + sb strings.Builder + args []int + ) + + for db, tables := range d.shard { + for i := range tables { + table := tables[i] + + stmt := new(ast.DropIndexStatement) + stmt.Table = ast.TableName{table} + stmt.IndexName = d.stmt.IndexName + + if err := stmt.Restore(ast.RestoreDefault, &sb, &args); err != nil { + return nil, err + } + + if err := d.execOne(ctx, conn, db, sb.String(), d.ToArgs(args)); err != nil { + return nil, errors.WithStack(err) + } + + sb.Reset() + } + } + + return resultx.New(), nil +} + +func (d *DropIndexPlan) SetShard(shard rule.DatabaseTables) { + d.shard = shard +} + +func (d *DropIndexPlan) execOne(ctx context.Context, conn proto.VConn, db, query string, args []interface{}) error { + res, err := conn.Exec(ctx, db, query, args...) + if err != nil { + return err + } + _, _ = res.Dataset() + return nil +} diff --git a/pkg/runtime/plan/drop_table.go b/pkg/runtime/plan/ddl/drop_table.go similarity index 62% rename from pkg/runtime/plan/drop_table.go rename to pkg/runtime/plan/ddl/drop_table.go index c724e9f85..256f5df82 100644 --- a/pkg/runtime/plan/drop_table.go +++ b/pkg/runtime/plan/ddl/drop_table.go @@ -15,7 +15,7 @@ * limitations under the License. */ -package plan +package ddl import ( "context" @@ -23,14 +23,19 @@ import ( ) import ( - "github.com/arana-db/arana/pkg/mysql" + "github.com/pkg/errors" +) + +import ( "github.com/arana-db/arana/pkg/proto" "github.com/arana-db/arana/pkg/proto/rule" + "github.com/arana-db/arana/pkg/resultx" "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/pkg/runtime/plan" ) type DropTablePlan struct { - basePlan + plan.BasePlan stmt *ast.DropTableStatement shardsMap []rule.DatabaseTables } @@ -41,39 +46,51 @@ func NewDropTablePlan(stmt *ast.DropTableStatement) *DropTablePlan { } } -func (d DropTablePlan) Type() proto.PlanType { +func (d *DropTablePlan) Type() proto.PlanType { return proto.PlanTypeExec } -func (d DropTablePlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Result, error) { - var sb strings.Builder - var args []int +func (d *DropTablePlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Result, error) { + ctx, span := plan.Tracer.Start(ctx, "DropTablePlan.ExecIn") + defer span.End() + var ( + sb strings.Builder + args []int + ) for _, shards := range d.shardsMap { - var stmt = new(ast.DropTableStatement) + stmt := new(ast.DropTableStatement) for db, tables := range shards { for _, table := range tables { - stmt.Tables = append(stmt.Tables, &ast.TableName{ table, }) } err := stmt.Restore(ast.RestoreDefault, &sb, &args) - if err != nil { return nil, err } - _, err = conn.Exec(ctx, db, sb.String(), d.toArgs(args)...) - if err != nil { - return nil, err + + if err = d.execOne(ctx, conn, db, sb.String(), d.ToArgs(args)); err != nil { + return nil, errors.WithStack(err) } + sb.Reset() } } - return &mysql.Result{DataChan: make(chan proto.Row, 1)}, nil + return resultx.New(), nil } -func (s *DropTablePlan) SetShards(shardsMap []rule.DatabaseTables) { - s.shardsMap = shardsMap +func (d *DropTablePlan) SetShards(shardsMap []rule.DatabaseTables) { + d.shardsMap = shardsMap +} + +func (d *DropTablePlan) execOne(ctx context.Context, conn proto.VConn, db, query string, args []interface{}) error { + res, err := conn.Exec(ctx, db, query, args...) + if err != nil { + return err + } + _, _ = res.Dataset() + return nil } diff --git a/pkg/runtime/plan/ddl/drop_trigger.go b/pkg/runtime/plan/ddl/drop_trigger.go new file mode 100644 index 000000000..39e09d8f8 --- /dev/null +++ b/pkg/runtime/plan/ddl/drop_trigger.go @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ddl + +import ( + "context" + "strings" +) + +import ( + "github.com/pkg/errors" +) + +import ( + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/proto/rule" + "github.com/arana-db/arana/pkg/resultx" + "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/pkg/runtime/plan" +) + +var _ proto.Plan = (*DropTriggerPlan)(nil) + +type DropTriggerPlan struct { + plan.BasePlan + Stmt *ast.DropTriggerStatement + Shards rule.DatabaseTables +} + +func (d *DropTriggerPlan) Type() proto.PlanType { + return proto.PlanTypeExec +} + +func (d *DropTriggerPlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Result, error) { + var ( + sb strings.Builder + args []int + err error + ) + + for db := range d.Shards { + if err = d.Stmt.Restore(ast.RestoreDefault, &sb, &args); err != nil { + return nil, err + } + + if err = d.execOne(ctx, conn, db, sb.String(), d.ToArgs(args)); err != nil { + return nil, errors.WithStack(err) + } + + sb.Reset() + } + + return resultx.New(), nil +} + +func (d *DropTriggerPlan) execOne(ctx context.Context, conn proto.VConn, db, query string, args []interface{}) error { + res, err := conn.Exec(ctx, db, query, args...) + if err != nil { + return err + } + _, _ = res.Dataset() + return nil +} diff --git a/pkg/runtime/plan/truncate.go b/pkg/runtime/plan/ddl/truncate.go similarity index 77% rename from pkg/runtime/plan/truncate.go rename to pkg/runtime/plan/ddl/truncate.go index 6624b17b3..2464e51c5 100644 --- a/pkg/runtime/plan/truncate.go +++ b/pkg/runtime/plan/ddl/truncate.go @@ -15,7 +15,7 @@ * limitations under the License. */ -package plan +package ddl import ( "context" @@ -27,16 +27,17 @@ import ( ) import ( - "github.com/arana-db/arana/pkg/mysql" "github.com/arana-db/arana/pkg/proto" "github.com/arana-db/arana/pkg/proto/rule" + "github.com/arana-db/arana/pkg/resultx" "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/pkg/runtime/plan" ) var _ proto.Plan = (*TruncatePlan)(nil) type TruncatePlan struct { - basePlan + plan.BasePlan stmt *ast.TruncateStatement shards rule.DatabaseTables } @@ -51,8 +52,10 @@ func (s *TruncatePlan) Type() proto.PlanType { } func (s *TruncatePlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Result, error) { + ctx, span := plan.Tracer.Start(ctx, "TruncatePlan.ExecIn") + defer span.End() if s.shards == nil || s.shards.IsEmpty() { - return &mysql.Result{AffectedRows: 0}, nil + return resultx.New(), nil } var ( @@ -70,19 +73,28 @@ func (s *TruncatePlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Resu return nil, errors.Wrap(err, "failed to execute TRUNCATE statement") } - _, err := conn.Exec(ctx, db, sb.String(), s.toArgs(args)...) - if err != nil { + if err := s.execOne(ctx, conn, db, sb.String(), s.ToArgs(args)); err != nil { return nil, errors.WithStack(err) } + sb.Reset() } } - return &mysql.Result{ - DataChan: make(chan proto.Row, 1), - }, nil + return resultx.New(), nil } func (s *TruncatePlan) SetShards(shards rule.DatabaseTables) { s.shards = shards } + +func (s *TruncatePlan) execOne(ctx context.Context, conn proto.VConn, db, query string, args []interface{}) error { + res, err := conn.Exec(ctx, db, query, args...) + if err != nil { + return errors.WithStack(err) + } + + defer resultx.Drain(res) + + return nil +} diff --git a/pkg/runtime/plan/aggregate.go b/pkg/runtime/plan/dml/aggregate.go similarity index 86% rename from pkg/runtime/plan/aggregate.go rename to pkg/runtime/plan/dml/aggregate.go index 9127222b5..7dbde95ea 100644 --- a/pkg/runtime/plan/aggregate.go +++ b/pkg/runtime/plan/dml/aggregate.go @@ -15,7 +15,7 @@ * limitations under the License. */ -package plan +package dml import ( "context" @@ -27,13 +27,14 @@ import ( import ( "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/runtime/plan" "github.com/arana-db/arana/pkg/transformer" ) type AggregatePlan struct { transformer.Combiner AggrLoader *transformer.AggrLoader - UnionPlan *UnionPlan + Plan proto.Plan } func (a *AggregatePlan) Type() proto.PlanType { @@ -41,7 +42,9 @@ func (a *AggregatePlan) Type() proto.PlanType { } func (a *AggregatePlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Result, error) { - res, err := a.UnionPlan.ExecIn(ctx, conn) + ctx, span := plan.Tracer.Start(ctx, "AggregatePlan.ExecIn") + defer span.End() + res, err := a.Plan.ExecIn(ctx, conn) if err != nil { return nil, errors.WithStack(err) } diff --git a/pkg/runtime/plan/dml/group_plan.go b/pkg/runtime/plan/dml/group_plan.go new file mode 100644 index 000000000..fd5f6b2b8 --- /dev/null +++ b/pkg/runtime/plan/dml/group_plan.go @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dml + +import ( + "context" +) + +import ( + "github.com/pkg/errors" +) + +import ( + "github.com/arana-db/arana/pkg/dataset" + "github.com/arana-db/arana/pkg/merge" + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/resultx" +) + +// GroupPlan TODO now only support stmt which group by items equal with order by items, such as +// `select uid, max(score) from student group by uid order by uid` +type GroupPlan struct { + Plan proto.Plan + AggItems map[int]func() merge.Aggregator + OrderByItems []dataset.OrderByItem +} + +func (g *GroupPlan) Type() proto.PlanType { + return proto.PlanTypeQuery +} + +func (g *GroupPlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Result, error) { + res, err := g.Plan.ExecIn(ctx, conn) + if err != nil { + return nil, errors.WithStack(err) + } + + ds, err := res.Dataset() + if err != nil { + return nil, errors.WithStack(err) + } + fields, err := ds.Fields() + if err != nil { + return nil, errors.WithStack(err) + } + + return resultx.New(resultx.WithDataset(dataset.Pipe(ds, dataset.GroupReduce( + g.OrderByItems, + func(fields []proto.Field) []proto.Field { + return fields + }, + func() dataset.Reducer { + return dataset.NewGroupReducer(g.AggItems, fields) + }, + )))), nil +} diff --git a/pkg/runtime/plan/dml/insert_select.go b/pkg/runtime/plan/dml/insert_select.go new file mode 100644 index 000000000..1c906f300 --- /dev/null +++ b/pkg/runtime/plan/dml/insert_select.go @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dml + +import ( + "context" + "strings" +) + +import ( + "github.com/pkg/errors" +) + +import ( + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/resultx" + "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/pkg/runtime/plan" +) + +var _ proto.Plan = (*InsertSelectPlan)(nil) + +type InsertSelectPlan struct { + plan.BasePlan + Batch map[string]*ast.InsertSelectStatement +} + +func NewInsertSelectPlan() *InsertSelectPlan { + return &InsertSelectPlan{ + Batch: make(map[string]*ast.InsertSelectStatement), + } +} + +func (sp *InsertSelectPlan) Type() proto.PlanType { + return proto.PlanTypeExec +} + +func (sp *InsertSelectPlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Result, error) { + var ( + affects uint64 + lastInsertId uint64 + ) + // TODO: consider wrap a transaction if insert into multiple databases + // TODO: insert in parallel + for db, insert := range sp.Batch { + id, affected, err := sp.doInsert(ctx, conn, db, insert) + if err != nil { + return nil, err + } + affects += affected + if id > lastInsertId { + lastInsertId = id + } + } + + return resultx.New(resultx.WithLastInsertID(lastInsertId), resultx.WithRowsAffected(affects)), nil +} + +func (sp *InsertSelectPlan) doInsert(ctx context.Context, conn proto.VConn, db string, stmt *ast.InsertSelectStatement) (uint64, uint64, error) { + var ( + sb strings.Builder + args []int + ) + + if err := stmt.Restore(ast.RestoreDefault, &sb, &args); err != nil { + return 0, 0, errors.Wrap(err, "cannot restore insert statement") + } + res, err := conn.Exec(ctx, db, sb.String(), sp.ToArgs(args)...) + if err != nil { + return 0, 0, errors.WithStack(err) + } + + defer resultx.Drain(res) + + id, err := res.LastInsertId() + if err != nil { + return 0, 0, errors.WithStack(err) + } + + affected, err := res.RowsAffected() + if err != nil { + return 0, 0, errors.WithStack(err) + } + + return id, affected, nil +} diff --git a/pkg/runtime/plan/dml/limit.go b/pkg/runtime/plan/dml/limit.go new file mode 100644 index 000000000..d20a32882 --- /dev/null +++ b/pkg/runtime/plan/dml/limit.go @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dml + +import ( + "context" +) + +import ( + "github.com/pkg/errors" +) + +import ( + "github.com/arana-db/arana/pkg/dataset" + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/resultx" +) + +var _ proto.Plan = (*LimitPlan)(nil) + +type LimitPlan struct { + ParentPlan proto.Plan + OriginOffset int64 + OverwriteLimit int64 +} + +func (limitPlan *LimitPlan) Type() proto.PlanType { + return proto.PlanTypeQuery +} + +func (limitPlan *LimitPlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Result, error) { + if limitPlan.ParentPlan == nil { + return nil, errors.New("limitPlan: ParentPlan is nil") + } + + res, err := limitPlan.ParentPlan.ExecIn(ctx, conn) + if err != nil { + return nil, errors.WithStack(err) + } + + ds, err := res.Dataset() + if err != nil { + return nil, errors.WithStack(err) + } + + var count int64 + ds = dataset.Pipe(ds, dataset.Filter(func(next proto.Row) bool { + count++ + if count < limitPlan.OriginOffset { + return false + } + if count > limitPlan.OriginOffset && count <= limitPlan.OverwriteLimit { + return true + } + return false + })) + return resultx.New(resultx.WithDataset(ds)), nil +} diff --git a/pkg/runtime/plan/dml/order.go b/pkg/runtime/plan/dml/order.go new file mode 100644 index 000000000..ea722c32e --- /dev/null +++ b/pkg/runtime/plan/dml/order.go @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dml + +import ( + "context" +) + +import ( + "github.com/pkg/errors" +) + +import ( + "github.com/arana-db/arana/pkg/dataset" + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/resultx" +) + +var _ proto.Plan = (*OrderPlan)(nil) + +type OrderPlan struct { + ParentPlan proto.Plan + OrderByItems []dataset.OrderByItem +} + +func (orderPlan *OrderPlan) Type() proto.PlanType { + return proto.PlanTypeQuery +} + +func (orderPlan *OrderPlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Result, error) { + if orderPlan.ParentPlan == nil { + return nil, errors.New("order plan: ParentPlan is nil") + } + + res, err := orderPlan.ParentPlan.ExecIn(ctx, conn) + if err != nil { + return nil, errors.WithStack(err) + } + + ds, err := res.Dataset() + if err != nil { + return nil, errors.WithStack(err) + } + + fuseable, ok := ds.(*dataset.FuseableDataset) + + if !ok { + return nil, errors.New("order plan convert Dataset to FuseableDataset cause error") + } + + orderedDataset := dataset.NewOrderedDataset(fuseable.ToParallel(), orderPlan.OrderByItems) + + return resultx.New(resultx.WithDataset(orderedDataset)), nil +} diff --git a/pkg/runtime/plan/simple_delete.go b/pkg/runtime/plan/dml/simple_delete.go similarity index 76% rename from pkg/runtime/plan/simple_delete.go rename to pkg/runtime/plan/dml/simple_delete.go index 39ebfe979..6044b642c 100644 --- a/pkg/runtime/plan/simple_delete.go +++ b/pkg/runtime/plan/dml/simple_delete.go @@ -15,7 +15,7 @@ * limitations under the License. */ -package plan +package dml import ( "context" @@ -27,17 +27,18 @@ import ( ) import ( - "github.com/arana-db/arana/pkg/mysql" "github.com/arana-db/arana/pkg/proto" "github.com/arana-db/arana/pkg/proto/rule" + "github.com/arana-db/arana/pkg/resultx" "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/pkg/runtime/plan" ) var _ proto.Plan = (*SimpleDeletePlan)(nil) // SimpleDeletePlan represents a simple delete plan for sharding table. type SimpleDeletePlan struct { - basePlan + plan.BasePlan stmt *ast.DeleteStatement shards rule.DatabaseTables } @@ -52,8 +53,10 @@ func (s *SimpleDeletePlan) Type() proto.PlanType { } func (s *SimpleDeletePlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Result, error) { + ctx, span := plan.Tracer.Start(ctx, "SimpleDeletePlan.ExecIn") + defer span.End() if s.shards == nil || s.shards.IsEmpty() { - return &mysql.Result{AffectedRows: 0}, nil + return resultx.New(), nil } var ( @@ -77,12 +80,11 @@ func (s *SimpleDeletePlan) ExecIn(ctx context.Context, conn proto.VConn) (proto. return nil, errors.Wrap(err, "failed to execute DELETE statement") } - res, err := conn.Exec(ctx, db, sb.String(), s.toArgs(args)...) + n, err := s.execOne(ctx, conn, db, sb.String(), s.ToArgs(args)) if err != nil { return nil, errors.WithStack(err) } - n, _ := res.RowsAffected() affects += n // cleanup @@ -93,12 +95,23 @@ func (s *SimpleDeletePlan) ExecIn(ctx context.Context, conn proto.VConn) (proto. } } - return &mysql.Result{ - AffectedRows: affects, - DataChan: make(chan proto.Row, 1), - }, nil + return resultx.New(resultx.WithRowsAffected(affects)), nil } func (s *SimpleDeletePlan) SetShards(shards rule.DatabaseTables) { s.shards = shards } + +func (s *SimpleDeletePlan) execOne(ctx context.Context, conn proto.VConn, db, query string, args []interface{}) (uint64, error) { + res, err := conn.Exec(ctx, db, query, args...) + if err != nil { + return 0, errors.WithStack(err) + } + defer resultx.Drain(res) + + var n uint64 + if n, err = res.RowsAffected(); err != nil { + return 0, errors.WithStack(err) + } + return n, nil +} diff --git a/pkg/runtime/plan/simple_insert.go b/pkg/runtime/plan/dml/simple_insert.go similarity index 68% rename from pkg/runtime/plan/simple_insert.go rename to pkg/runtime/plan/dml/simple_insert.go index d99698cca..8cbd524cb 100644 --- a/pkg/runtime/plan/simple_insert.go +++ b/pkg/runtime/plan/dml/simple_insert.go @@ -15,7 +15,7 @@ * limitations under the License. */ -package plan +package dml import ( "context" @@ -27,15 +27,16 @@ import ( ) import ( - "github.com/arana-db/arana/pkg/mysql" "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/resultx" "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/pkg/runtime/plan" ) var _ proto.Plan = (*SimpleInsertPlan)(nil) type SimpleInsertPlan struct { - basePlan + plan.BasePlan batch map[string][]*ast.InsertStatement // key=db } @@ -55,41 +56,54 @@ func (sp *SimpleInsertPlan) Put(db string, stmt *ast.InsertStatement) { func (sp *SimpleInsertPlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Result, error) { var ( - effected uint64 + affects uint64 lastInsertId uint64 ) + ctx, span := plan.Tracer.Start(ctx, "SimpleInsertPlan.ExecIn") + defer span.End() // TODO: consider wrap a transaction if insert into multiple databases // TODO: insert in parallel for db, inserts := range sp.batch { for _, insert := range inserts { - res, err := sp.doInsert(ctx, conn, db, insert) + id, affected, err := sp.doInsert(ctx, conn, db, insert) if err != nil { return nil, err } - if n, _ := res.RowsAffected(); n > 0 { - effected += n - } - if id, _ := res.LastInsertId(); id > lastInsertId { + affects += affected + if id > lastInsertId { lastInsertId = id } } } - return &mysql.Result{ - AffectedRows: effected, - InsertId: lastInsertId, - DataChan: make(chan proto.Row, 1), - }, nil + return resultx.New(resultx.WithLastInsertID(lastInsertId), resultx.WithRowsAffected(affects)), nil } -func (sp *SimpleInsertPlan) doInsert(ctx context.Context, conn proto.VConn, db string, stmt *ast.InsertStatement) (proto.Result, error) { +func (sp *SimpleInsertPlan) doInsert(ctx context.Context, conn proto.VConn, db string, stmt *ast.InsertStatement) (uint64, uint64, error) { var ( sb strings.Builder args []int ) if err := stmt.Restore(ast.RestoreDefault, &sb, &args); err != nil { - return nil, errors.Wrap(err, "cannot restore insert statement") + return 0, 0, errors.Wrap(err, "cannot restore insert statement") + } + res, err := conn.Exec(ctx, db, sb.String(), sp.ToArgs(args)...) + if err != nil { + return 0, 0, errors.WithStack(err) + } + + defer resultx.Drain(res) + + id, err := res.LastInsertId() + if err != nil { + return 0, 0, errors.WithStack(err) + } + + affected, err := res.RowsAffected() + if err != nil { + return 0, 0, errors.WithStack(err) } - return conn.Exec(ctx, db, sb.String(), sp.toArgs(args)...) + + return id, affected, nil } diff --git a/pkg/runtime/plan/simple_join.go b/pkg/runtime/plan/dml/simple_join.go similarity index 93% rename from pkg/runtime/plan/simple_join.go rename to pkg/runtime/plan/dml/simple_join.go index 8355cdbd1..7ab17a232 100644 --- a/pkg/runtime/plan/simple_join.go +++ b/pkg/runtime/plan/dml/simple_join.go @@ -15,7 +15,7 @@ * limitations under the License. */ -package plan +package dml import ( "context" @@ -29,6 +29,7 @@ import ( import ( "github.com/arana-db/arana/pkg/proto" "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/pkg/runtime/plan" ) type JoinTable struct { @@ -37,7 +38,7 @@ type JoinTable struct { } type SimpleJoinPlan struct { - basePlan + plan.BasePlan Database string Left *JoinTable Join *ast.JoinNode @@ -57,20 +58,23 @@ func (s *SimpleJoinPlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Re err error ) + ctx, span := plan.Tracer.Start(ctx, "SimpleJoinPlan.ExecIn") + defer span.End() + if err := s.generateSelect(&sb, &indexes); err != nil { return nil, err } sb.WriteString(" FROM ") - //add left part + // add left part if err := s.generateTable(s.Left.Tables, s.Left.Alias, &sb); err != nil { return nil, err } s.generateJoinType(&sb) - //add right part + // add right part if err := s.generateTable(s.Right.Tables, s.Right.Alias, &sb); err != nil { return nil, err } @@ -90,7 +94,7 @@ func (s *SimpleJoinPlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Re var ( query = sb.String() - args = s.toArgs(indexes) + args = s.ToArgs(indexes) ) if res, err = conn.Query(ctx, s.Database, query, args...); err != nil { @@ -121,7 +125,6 @@ func (s *SimpleJoinPlan) generateSelect(sb *strings.Builder, args *[]int) error } func (s *SimpleJoinPlan) generateTable(tables []string, alias string, sb *strings.Builder) error { - if len(tables) == 1 { sb.WriteString(tables[0] + " ") @@ -154,7 +157,7 @@ func (s *SimpleJoinPlan) generateTable(tables []string, alias string, sb *string } func (s *SimpleJoinPlan) generateJoinType(sb *strings.Builder) { - //add join type + // add join type switch s.Join.Typ { case ast.LeftJoin: sb.WriteString("LEFT") diff --git a/pkg/runtime/plan/simple_select.go b/pkg/runtime/plan/dml/simple_select.go similarity index 87% rename from pkg/runtime/plan/simple_select.go rename to pkg/runtime/plan/dml/simple_select.go index 2fdf30021..4847d99f4 100644 --- a/pkg/runtime/plan/simple_select.go +++ b/pkg/runtime/plan/dml/simple_select.go @@ -15,7 +15,7 @@ * limitations under the License. */ -package plan +package dml import ( "context" @@ -27,15 +27,18 @@ import ( ) import ( + "github.com/arana-db/arana/pkg/dataset" "github.com/arana-db/arana/pkg/mysql" "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/resultx" "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/pkg/runtime/plan" ) var _ proto.Plan = (*SimpleQueryPlan)(nil) type SimpleQueryPlan struct { - basePlan + plan.BasePlan Database string Tables []string Stmt *ast.SelectStatement @@ -53,11 +56,10 @@ func (s *SimpleQueryPlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.R err error ) - if s.filter() { - return &mysql.Result{ - DataChan: make(chan proto.Row, 1), - }, nil - } + ctx, span := plan.Tracer.Start(ctx, "SimpleQueryPlan.ExecIn") + defer span.End() + + discard := s.filter() if err = s.generate(&sb, &indexes); err != nil { return nil, errors.Wrap(err, "failed to generate sql") @@ -65,13 +67,34 @@ func (s *SimpleQueryPlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.R var ( query = sb.String() - args = s.toArgs(indexes) + args = s.ToArgs(indexes) ) if res, err = conn.Query(ctx, s.Database, query, args...); err != nil { return nil, errors.WithStack(err) } - return res, nil + + if !discard { + return res, nil + } + + var ( + rr = res.(*mysql.RawResult) + fields []proto.Field + ) + + defer func() { + _ = rr.Discard() + }() + + if fields, err = rr.Fields(); err != nil { + return nil, errors.WithStack(err) + } + + emptyDs := &dataset.VirtualDataset{ + Columns: fields, + } + return resultx.New(resultx.WithDataset(emptyDs)), nil } func (s *SimpleQueryPlan) filter() bool { @@ -132,9 +155,7 @@ func (s *SimpleQueryPlan) generate(sb *strings.Builder, args *[]int) error { // UNION ALL // (SELECT * FROM student_0000 WHERE uid IN (1,2,3) - var ( - stmt = new(ast.SelectStatement) - ) + stmt := new(ast.SelectStatement) *stmt = *s.Stmt // do copy restore := func(table string) error { @@ -171,9 +192,7 @@ func (s *SimpleQueryPlan) generate(sb *strings.Builder, args *[]int) error { } func (s *SimpleQueryPlan) resetOrderBy(tgt *ast.SelectStatement, sb *strings.Builder, args *[]int) error { - var ( - builder strings.Builder - ) + var builder strings.Builder builder.WriteString("SELECT * FROM (") builder.WriteString(sb.String()) builder.WriteString(") ") diff --git a/pkg/runtime/plan/dml/union.go b/pkg/runtime/plan/dml/union.go new file mode 100644 index 000000000..d2e182953 --- /dev/null +++ b/pkg/runtime/plan/dml/union.go @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dml + +import ( + "context" + "fmt" + "io" +) + +import ( + "github.com/pkg/errors" +) + +import ( + "github.com/arana-db/arana/pkg/dataset" + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/resultx" + "github.com/arana-db/arana/pkg/runtime/plan" + "github.com/arana-db/arana/pkg/util/log" +) + +// UnionPlan merges multiple query plan. +type UnionPlan struct { + Plans []proto.Plan +} + +func (u UnionPlan) Type() proto.PlanType { + return proto.PlanTypeQuery +} + +func (u UnionPlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Result, error) { + ctx, span := plan.Tracer.Start(ctx, "UnionPlan.ExecIn") + defer span.End() + switch u.Plans[0].Type() { + case proto.PlanTypeQuery: + return u.query(ctx, conn) + case proto.PlanTypeExec: + return u.exec(ctx, conn) + default: + panic("unreachable") + } +} + +func (u UnionPlan) showOpenTables(ctx context.Context, conn proto.VConn) (proto.Result, error) { + var ( + fields []proto.Field + rows []proto.Row + filterMap = make(map[string]struct{}, 0) // map[database-table] + ) + for _, it := range u.Plans { + it := it + res, err := it.ExecIn(ctx, conn) + if err != nil { + return nil, errors.WithStack(err) + } + ds, err := res.Dataset() + if err != nil { + return nil, errors.WithStack(err) + } + fields, err = ds.Fields() + if err != nil { + return nil, errors.WithStack(err) + } + var row proto.Row + + for { + row, err = ds.Next() + if errors.Is(err, io.EOF) { + break + } + + if err != nil { + return nil, errors.WithStack(err) + } + var values []proto.Value + if err = row.Scan(values); err != nil { + return nil, errors.WithStack(err) + } + // Database Table In_use Name_locked + key := fmt.Sprintf("%s-%s", values[0].(string), values[1].(string)) + if _, ok := filterMap[key]; ok { + continue + } + filterMap[key] = struct{}{} + rows = append(rows, row) + } + } + ds := &dataset.VirtualDataset{ + Columns: fields, + Rows: rows, + } + + return resultx.New(resultx.WithDataset(ds)), nil +} + +func (u UnionPlan) query(ctx context.Context, conn proto.VConn) (proto.Result, error) { + var generators []dataset.GenerateFunc + for _, it := range u.Plans { + it := it + generators = append(generators, func() (proto.Dataset, error) { + res, err := it.ExecIn(ctx, conn) + if err != nil { + return nil, errors.WithStack(err) + } + return res.Dataset() + }) + } + + ds, err := dataset.Fuse(generators[0], generators[1:]...) + if err != nil { + log.Errorf("UnionPlan Fuse error:%v", err) + return nil, err + } + return resultx.New(resultx.WithDataset(ds)), nil +} + +func (u UnionPlan) exec(ctx context.Context, conn proto.VConn) (proto.Result, error) { + var id, affects uint64 + for _, it := range u.Plans { + i, n, err := u.execOne(ctx, conn, it) + if err != nil { + return nil, errors.WithStack(err) + } + affects += n + id += i + } + + return resultx.New(resultx.WithLastInsertID(id), resultx.WithRowsAffected(affects)), nil +} + +func (u UnionPlan) execOne(ctx context.Context, conn proto.VConn, p proto.Plan) (uint64, uint64, error) { + res, err := p.ExecIn(ctx, conn) + if err != nil { + return 0, 0, errors.WithStack(err) + } + + defer resultx.Drain(res) + + id, err := res.LastInsertId() + if err != nil { + return 0, 0, errors.WithStack(err) + } + + affected, err := res.RowsAffected() + if err != nil { + return 0, 0, errors.WithStack(err) + } + + return id, affected, nil +} diff --git a/pkg/runtime/plan/update.go b/pkg/runtime/plan/dml/update.go similarity index 78% rename from pkg/runtime/plan/update.go rename to pkg/runtime/plan/dml/update.go index 3d8f65fbd..633c42396 100644 --- a/pkg/runtime/plan/update.go +++ b/pkg/runtime/plan/dml/update.go @@ -15,7 +15,7 @@ * limitations under the License. */ -package plan +package dml import ( "context" @@ -31,10 +31,11 @@ import ( ) import ( - "github.com/arana-db/arana/pkg/mysql" "github.com/arana-db/arana/pkg/proto" "github.com/arana-db/arana/pkg/proto/rule" + "github.com/arana-db/arana/pkg/resultx" "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/pkg/runtime/plan" "github.com/arana-db/arana/pkg/util/log" ) @@ -42,7 +43,7 @@ var _ proto.Plan = (*UpdatePlan)(nil) // UpdatePlan represents a plan to execute sharding-update. type UpdatePlan struct { - basePlan + plan.BasePlan stmt *ast.UpdateStatement shards rule.DatabaseTables } @@ -57,12 +58,14 @@ func (up *UpdatePlan) Type() proto.PlanType { } func (up *UpdatePlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Result, error) { + ctx, span := plan.Tracer.Start(ctx, "UpdatePlan.ExecIn") + defer span.End() if up.shards == nil { var sb strings.Builder if err := up.stmt.Restore(ast.RestoreDefault, &sb, nil); err != nil { return nil, err } - return conn.Exec(ctx, "", sb.String(), up.args...) + return conn.Exec(ctx, "", sb.String(), up.Args...) } var ( @@ -84,8 +87,8 @@ func (up *UpdatePlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Resul var ( sb strings.Builder args []int - res proto.Result err error + n uint64 ) sb.Grow(256) @@ -95,11 +98,10 @@ func (up *UpdatePlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Resul return errors.WithStack(err) } - if res, err = conn.Exec(ctx, db, sb.String(), up.toArgs(args)...); err != nil { + if n, err = up.execOne(ctx, conn, db, sb.String(), up.ToArgs(args)); err != nil { return errors.WithStack(err) } - n, _ := res.RowsAffected() affects.Add(n) cnt.Inc() @@ -120,12 +122,25 @@ func (up *UpdatePlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Resul log.Debugf("sharding update success: batch=%d, affects=%d", cnt.Load(), affects.Load()) - return &mysql.Result{ - AffectedRows: affects.Load(), - DataChan: make(chan proto.Row, 1), - }, nil + return resultx.New(resultx.WithRowsAffected(affects.Load())), nil } func (up *UpdatePlan) SetShards(shards rule.DatabaseTables) { up.shards = shards } + +func (up *UpdatePlan) execOne(ctx context.Context, conn proto.VConn, db, query string, args []interface{}) (uint64, error) { + res, err := conn.Exec(ctx, db, query, args...) + if err != nil { + return 0, errors.WithStack(err) + } + + defer resultx.Drain(res) + + n, err := res.RowsAffected() + if err != nil { + return 0, errors.WithStack(err) + } + + return n, nil +} diff --git a/pkg/runtime/plan/plan.go b/pkg/runtime/plan/plan.go index 9bff2732d..aab88102c 100644 --- a/pkg/runtime/plan/plan.go +++ b/pkg/runtime/plan/plan.go @@ -17,21 +17,27 @@ package plan -type basePlan struct { - args []interface{} +import ( + "go.opentelemetry.io/otel" +) + +var Tracer = otel.Tracer("ExecPlan") + +type BasePlan struct { + Args []interface{} } -func (bp *basePlan) BindArgs(args []interface{}) { - bp.args = args +func (bp *BasePlan) BindArgs(args []interface{}) { + bp.Args = args } -func (bp basePlan) toArgs(indexes []int) []interface{} { - if len(indexes) < 1 || len(bp.args) < 1 { +func (bp *BasePlan) ToArgs(indexes []int) []interface{} { + if len(indexes) < 1 || len(bp.Args) < 1 { return nil } ret := make([]interface{}, 0, len(indexes)) for _, idx := range indexes { - ret = append(ret, bp.args[idx]) + ret = append(ret, bp.Args[idx]) } return ret } diff --git a/pkg/runtime/plan/show_tables.go b/pkg/runtime/plan/show_tables.go deleted file mode 100644 index b335a0da2..000000000 --- a/pkg/runtime/plan/show_tables.go +++ /dev/null @@ -1,145 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package plan - -import ( - "context" - "io" - "strings" -) - -import ( - "github.com/pkg/errors" -) - -import ( - "github.com/arana-db/arana/pkg/mysql" - "github.com/arana-db/arana/pkg/proto" - "github.com/arana-db/arana/pkg/runtime/ast" -) - -var _ proto.Plan = (*ShowTablesPlan)(nil) - -type DatabaseTable struct { - Database string - TableName string -} - -type ShowTablesPlan struct { - basePlan - Database string - Stmt *ast.ShowTables - allShards map[string]DatabaseTable -} - -// NewShowTablesPlan create ShowTables Plan -func NewShowTablesPlan(stmt *ast.ShowTables) *ShowTablesPlan { - return &ShowTablesPlan{ - Stmt: stmt, - } -} - -func (s *ShowTablesPlan) Type() proto.PlanType { - return proto.PlanTypeQuery -} - -func (s *ShowTablesPlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Result, error) { - var ( - sb strings.Builder - indexes []int - res proto.Result - err error - ) - - if err = s.Stmt.Restore(ast.RestoreDefault, &sb, &indexes); err != nil { - return nil, errors.WithStack(err) - } - - var ( - query = sb.String() - args = s.toArgs(indexes) - ) - - if res, err = conn.Query(ctx, s.Database, query, args...); err != nil { - return nil, errors.WithStack(err) - } - - if closer, ok := res.(io.Closer); ok { - defer func() { - _ = closer.Close() - }() - } - - rebuildResult := mysql.Result{ - Fields: res.GetFields(), - DataChan: make(chan proto.Row, 1), - } - hasRebuildTables := make(map[string]struct{}) - var affectRows uint64 - if len(res.GetRows()) > 0 { - row := res.GetRows()[0] - rowIter := row.(*mysql.TextIterRow) - for has, err := rowIter.Next(); has && err == nil; has, err = rowIter.Next() { - rowValues, err := row.Decode() - if err != nil { - return nil, err - } - tableName := s.convertInterfaceToStrNullable(rowValues[0].Val) - if databaseTable, exist := s.allShards[tableName]; exist { - if _, ok := hasRebuildTables[databaseTable.TableName]; ok { - continue - } - - if _, ok := hasRebuildTables[databaseTable.TableName]; ok { - continue - } - - encodeTableName := mysql.PutLengthEncodedString([]byte(databaseTable.TableName)) - tmpValues := rowValues - for idx := range tmpValues { - tmpValues[idx].Val = string(encodeTableName) - tmpValues[idx].Raw = encodeTableName - tmpValues[idx].Len = len(encodeTableName) - } - - var tmpNewRow mysql.TextRow - tmpNewRow.Encode(tmpValues, row.Fields(), row.Columns()) - rebuildResult.Rows = append(rebuildResult.Rows, &tmpNewRow) - hasRebuildTables[databaseTable.TableName] = struct{}{} - affectRows++ - continue - } - affectRows++ - textRow := &mysql.TextRow{Row: *rowIter.Row} - rebuildResult.Rows = append(rebuildResult.Rows, textRow) - } - } - rebuildResult.AffectedRows = affectRows - return &rebuildResult, nil -} - -func (s *ShowTablesPlan) convertInterfaceToStrNullable(value interface{}) string { - if value != nil { - return string(value.([]byte)) - } - return "" -} - -func (s *ShowTablesPlan) SetAllShards(allShards map[string]DatabaseTable) { - s.allShards = allShards -} diff --git a/pkg/runtime/plan/transparent.go b/pkg/runtime/plan/transparent.go index a4b67a360..ff9339239 100644 --- a/pkg/runtime/plan/transparent.go +++ b/pkg/runtime/plan/transparent.go @@ -35,7 +35,7 @@ var _ proto.Plan = (*TransparentPlan)(nil) // TransparentPlan represents a transparent plan. type TransparentPlan struct { - basePlan + BasePlan stmt rast.Statement db string typ proto.PlanType @@ -45,7 +45,8 @@ type TransparentPlan struct { func Transparent(stmt rast.Statement, args []interface{}) *TransparentPlan { var typ proto.PlanType switch stmt.Mode() { - case rast.Sinsert, rast.Sdelete, rast.Sreplace, rast.Supdate, rast.Struncate, rast.SdropTable, rast.SalterTable: + case rast.SQLTypeInsert, rast.SQLTypeDelete, rast.SQLTypeReplace, rast.SQLTypeUpdate, rast.SQLTypeTruncate, rast.SQLTypeDropTable, + rast.SQLTypeAlterTable, rast.SQLTypeDropIndex, rast.SQLTypeCreateIndex: typ = proto.PlanTypeExec default: typ = proto.PlanTypeQuery @@ -76,6 +77,8 @@ func (tp *TransparentPlan) ExecIn(ctx context.Context, conn proto.VConn) (proto. args []int err error ) + ctx, span := Tracer.Start(ctx, "TransparentPlan.ExecIn") + defer span.End() if err = tp.stmt.Restore(rast.RestoreDefault, &sb, &args); err != nil { return nil, errors.WithStack(err) @@ -83,9 +86,9 @@ func (tp *TransparentPlan) ExecIn(ctx context.Context, conn proto.VConn) (proto. switch tp.typ { case proto.PlanTypeQuery: - return conn.Query(ctx, tp.db, sb.String(), tp.toArgs(args)...) + return conn.Query(ctx, tp.db, sb.String(), tp.ToArgs(args)...) case proto.PlanTypeExec: - return conn.Exec(ctx, tp.db, sb.String(), tp.toArgs(args)...) + return conn.Exec(ctx, tp.db, sb.String(), tp.ToArgs(args)...) default: panic("unreachable") } diff --git a/pkg/runtime/plan/union.go b/pkg/runtime/plan/union.go deleted file mode 100644 index ebe649d3b..000000000 --- a/pkg/runtime/plan/union.go +++ /dev/null @@ -1,94 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package plan - -import ( - "context" - "io" -) - -import ( - "github.com/pkg/errors" - - "go.uber.org/multierr" -) - -import ( - "github.com/arana-db/arana/pkg/proto" -) - -// UnionPlan merges multiple query plan. -type UnionPlan struct { - Plans []proto.Plan -} - -func (u UnionPlan) Type() proto.PlanType { - return proto.PlanTypeQuery -} - -func (u UnionPlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Result, error) { - var results []proto.Result - for _, it := range u.Plans { - res, err := it.ExecIn(ctx, conn) - if err != nil { - return nil, errors.WithStack(err) - } - results = append(results, res) - } - - return compositeResult(results), nil -} - -type compositeResult []proto.Result - -func (c compositeResult) Close() error { - var errs []error - for _, it := range c { - if closer, ok := it.(io.Closer); ok { - if err := closer.Close(); err != nil { - errs = append(errs, err) - } - } - } - return multierr.Combine(errs...) -} - -func (c compositeResult) GetFields() []proto.Field { - for _, it := range c { - if ret := it.GetFields(); ret != nil { - return ret - } - } - return nil -} - -func (c compositeResult) GetRows() []proto.Row { - var rows []proto.Row - for _, it := range c { - rows = append(rows, it.GetRows()...) - } - return rows -} - -func (c compositeResult) LastInsertId() (uint64, error) { - return 0, nil -} - -func (c compositeResult) RowsAffected() (uint64, error) { - return 0, nil -} diff --git a/pkg/runtime/plan/describle.go b/pkg/runtime/plan/utility/describle.go similarity index 91% rename from pkg/runtime/plan/describle.go rename to pkg/runtime/plan/utility/describle.go index 1c121f7db..1d9cd6985 100644 --- a/pkg/runtime/plan/describle.go +++ b/pkg/runtime/plan/utility/describle.go @@ -15,7 +15,7 @@ * limitations under the License. */ -package plan +package utility import ( "context" @@ -29,13 +29,15 @@ import ( import ( "github.com/arana-db/arana/pkg/proto" "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/pkg/runtime/plan" ) type DescribePlan struct { - basePlan + plan.BasePlan Stmt *ast.DescribeStatement Database string Table string + Column string } func NewDescribePlan(stmt *ast.DescribeStatement) *DescribePlan { @@ -53,6 +55,8 @@ func (d *DescribePlan) ExecIn(ctx context.Context, vConn proto.VConn) (proto.Res res proto.Result err error ) + ctx, span := plan.Tracer.Start(ctx, "DescribePlan.ExecIn") + defer span.End() if err = d.generate(&sb, &indexes); err != nil { return nil, errors.Wrap(err, "failed to generate desc/describe sql") @@ -60,7 +64,7 @@ func (d *DescribePlan) ExecIn(ctx context.Context, vConn proto.VConn) (proto.Res var ( query = sb.String() - args = d.toArgs(indexes) + args = d.ToArgs(indexes) ) if res, err = vConn.Query(ctx, d.Database, query, args...); err != nil { diff --git a/pkg/runtime/rule/shard.go b/pkg/runtime/rule/shard.go index 7491b9286..91e8837eb 100644 --- a/pkg/runtime/rule/shard.go +++ b/pkg/runtime/rule/shard.go @@ -40,6 +40,8 @@ const ( HashMd5Shard ShardType = "hashMd5Shard" HashCrc32Shard ShardType = "hashCrc32Shard" HashBKDRShard ShardType = "hashBKDRShard" + ScriptExpr ShardType = "scriptExpr" + FunctionExpr ShardType = "functionExpr" ) var shardMap = map[ShardType]ShardComputerFunc{ @@ -59,8 +61,10 @@ func ShardFactory(shardType ShardType, shardNum int) (shardStrategy rule.ShardCo return nil, errors.New("do not have this shardType") } -type ShardType string -type ShardComputerFunc func(shardNum int) rule.ShardComputer +type ( + ShardType string + ShardComputerFunc func(shardNum int) rule.ShardComputer +) type modShard struct { shardNum int diff --git a/pkg/runtime/rule/shard_expr.go b/pkg/runtime/rule/shard_expr.go new file mode 100644 index 000000000..d33a42571 --- /dev/null +++ b/pkg/runtime/rule/shard_expr.go @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package rule + +import ( + "fmt" + "strconv" +) + +import ( + "github.com/pkg/errors" +) + +import ( + "github.com/arana-db/arana/pkg/proto/rule" +) + +var _ rule.ShardComputer = (*exprShardComputer)(nil) + +type exprShardComputer struct { + expr string + column string +} + +func NewExprShardComputer(expr, column string) (rule.ShardComputer, error) { + result := &exprShardComputer{ + expr: expr, + column: column, + } + return result, nil +} + +func (compute *exprShardComputer) Compute(value interface{}) (int, error) { + expr, vars, err := Parse(compute.expr) + if err != nil { + return 0, err + } + if len(vars) != 1 || vars[0] != Var(compute.column) { + return 0, errors.Errorf("Parse shard expr is error, expr is: %s", compute.expr) + } + + shardValue := fmt.Sprintf("%v", value) + eval, _ := expr.Eval(Env{Var(compute.column): Value(shardValue)}) + + result, err := strconv.ParseFloat(eval.String(), 64) + if err != nil { + return 0, err + } + + return int(result), nil +} diff --git a/pkg/runtime/rule/shard_expr_parse.go b/pkg/runtime/rule/shard_expr_parse.go new file mode 100644 index 000000000..13d68ab0f --- /dev/null +++ b/pkg/runtime/rule/shard_expr_parse.go @@ -0,0 +1,298 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package rule + +import ( + "fmt" + "strconv" + "strings" + "text/scanner" +) + +// ---- lexer ---- +// This lexer is similar to the one described in Chapter 13. +type lexer struct { + scan scanner.Scanner + token rune // current lookahead token +} + +func (lex *lexer) move() { + lex.token = lex.scan.Scan() + // fmt.Printf("after move, current token %q\n", lex.text()) +} + +func (lex *lexer) text() string { return lex.scan.TokenText() } +func (lex *lexer) peek() rune { return lex.scan.Peek() } + +// describe returns a string describing the current token, for use in errors. +func (lex *lexer) describe() string { + switch lex.token { + case scanner.EOF: + return "end of file" + case scanner.Ident: + return fmt.Sprintf("identifier %s", lex.text()) + case scanner.Int, scanner.Float: + return fmt.Sprintf("number %s", lex.text()) + } + return fmt.Sprintf("%q", rune(lex.token)) // any other rune +} + +func precedence(op rune) int { + switch op { + case '*', '/', '%': + return 2 + case '+', '-': + return 1 + } + return 0 +} + +// ---- parser ---- + +// Parse parses the input string as an arithmetic expression. +// +// expr = num a constant number, e.g., 3.14159 +// | id a variable name, e.g., x +// | id '(' expr ',' ... ')' a function +// | '-' expr a unary operator (+-) +// | expr '+' expr a binary operator (+-*/) +// +func Parse(input string) (_ Expr, vars []Var, rerr error) { + defer func() { + switch x := recover().(type) { + case nil: + // no panic + case string: + rerr = fmt.Errorf("%s", x) + default: + rerr = fmt.Errorf("unexpected panic: resume state of panic") + } + }() + + lex := new(lexer) + lex.scan.Init(strings.NewReader(input)) + lex.scan.Mode = scanner.ScanIdents | scanner.ScanInts | scanner.ScanFloats | scanner.ScanStrings + lex.move() // initial lookahead + _, e, err := parseExpr(lex, nil) + if err != nil { + return nil, nil, err + } + if lex.token != scanner.EOF { + return nil, nil, fmt.Errorf("unexpected %s", lex.describe()) + } + + // parse vars + l := len(input) + for i := 0; i < l; i++ { + if input[i] != '#' { + continue + } + i++ // consume '#' + + var s string + for { + if input[i] == '#' { + i++ // consume '#' + break + } + if l <= i { + break + } + s += string(input[i]) + i++ + } + vars = append(vars, Var(s)) + } + + return e, vars, nil +} + +func parseExpr(lex *lexer, s *stack) (*stack, Expr, error) { return parseBinary(lex, 1, s) } + +// binary = unary ('+' binary)* +// parseBinary stops when it encounters an +// operator of lower precedence than prec1. +func parseBinary(lex *lexer, prec1 int, s *stack) (*stack, Expr, error) { + s1, lhs, err := parseUnary(lex, s) + if err != nil { + return s, nil, err + } + s = s1 + + for prec := precedence(lex.token); prec >= prec1; prec-- { + for precedence(lex.token) == prec { + op := lex.token + lex.move() // consume operator + s1, rhs, err := parseBinary(lex, prec+1, s) + if err != nil { + return s1, nil, err + } + s = s1 + lhs = binaryExpr{op: op, x: lhs, y: rhs} + } + } + return s, lhs, nil +} + +// unary = '+' expr | primary +func parseUnary(lex *lexer, s *stack) (*stack, Expr, error) { + if lex.token == '+' || lex.token == '-' { + op := lex.token + lex.move() // consume '+' or '-' + s1, e, err := parseUnary(lex, s) + return s1, unary{op, e}, err + } + return parsePrimary(lex, s) +} + +// primary = id +// | id '(' expr ',' ... ',' expr ')' +// | num +// | '(' expr ')' +func parsePrimary(lex *lexer, s *stack) (*stack, Expr, error) { + switch lex.token { + case scanner.Ident: + id := lex.text() + lex.move() // consume Ident + if lex.token != '(' { + return s, Var(id), nil + } + + lex.move() // consume '(' + + // vars []Var + var args []Expr + + if lex.token != ')' { + if s == (*stack)(nil) { + s = NewStack() + } + s.push(function{fn: id}) + for { + s1, e, err := parseExpr(lex, s) + if err != nil { + return s, nil, err + } + s = s1 + args = append(args, e) + //if v, ok := e.(Var); ok { + // vars = append(vars, v) + //} + if lex.token != ',' { + break + } + lex.move() // consume ',' + } + if lex.token != ')' { + return s, nil, fmt.Errorf("got %s, want ')'", lex.describe()) + } + } + lex.move() // consume ')' + + if s != (*stack)(nil) { + e, notEmpty := s.pop() + if notEmpty { + // if lex.token == ')' && !tmpstack.isEmpty() { + //if lex.token == ',' { + // lex.move() // consume "," + //} + c, ok := e.(function) + if ok { + c.args = append(c.args, args...) + // c.vars = append(c.vars, vars...) + return s, c, nil + } else { + return s, nil, fmt.Errorf("illegal stack element %#v, type %T", e, e) // impossible + } + } + } + // return s, function{fn:id, args:args, vars: vars}, nil + return s, function{fn: id, args: args}, nil + + case scanner.Int, scanner.Float: + f, err := strconv.ParseFloat(lex.text(), 64) + if err != nil { + return s, nil, err + } + lex.move() // consume number + return s, constant(f), nil + + case scanner.String: + token := stringConstant(lex.text()) + lex.move() + return s, token, nil + + case '(': + lex.move() // consume '(' + s1, e, err := parseExpr(lex, s) + if err != nil || lex.token != ')' { + return s, nil, fmt.Errorf("err %s, got %s, want ')'", err, lex.describe()) + } + s = s1 + lex.move() // consume ')' + return s, e, nil + + case '\'': + // 此处之所以先存下原 mode,只分析 strings,是在分析 TestParse 的 case "hash(concat(#uid#, '1'), 100)" + // 这个例子时,不能正确分析 concat 的分隔符 '1' + mode := lex.scan.Mode + defer func() { + lex.scan.Mode = mode + }() + lex.scan.Mode = scanner.ScanStrings + lex.move() // consume \' + + var str string + for { + str += string(lex.token) + lex.move() + if lex.token == '\'' || lex.token == scanner.EOF { + break + } + } + if lex.token != '\'' { + return s, nil, fmt.Errorf("parsing string with quote ', got illegal last token %s", lex.describe()) + } + lex.move() // consume \' + return s, stringConstant(str), nil + + case '#': + mode := lex.scan.Mode + defer func() { + lex.scan.Mode = mode + }() + lex.scan.Mode = scanner.ScanStrings + lex.move() // consume #' + + var str string + for { + str += string(lex.token) + lex.move() + if lex.token == '#' || lex.token == scanner.EOF { + break + } + } + if lex.token != '#' { + return s, nil, fmt.Errorf("parsing string with quote #, got illegal last token %s", lex.describe()) + } + lex.move() // consume # + return s, Var(str), nil + } + + return s, nil, fmt.Errorf("unexpected %s", lex.describe()) +} diff --git a/pkg/runtime/rule/shard_expr_parse_test.go b/pkg/runtime/rule/shard_expr_parse_test.go new file mode 100644 index 000000000..e3169b3b7 --- /dev/null +++ b/pkg/runtime/rule/shard_expr_parse_test.go @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package rule + +import ( + "fmt" + "strconv" + "testing" +) + +func TestParse(t *testing.T) { + // assert.True(t, preParse(expString) == expression) + tests := []struct { + expr string + env Env + want string + }{ + {"hash(toint(substr(#uid#, 1, 2)), 100)", Env{"uid": "87616"}, "87"}, + {"hash(concat(#uid#, '1'), 100)", Env{"uid": "87616"}, "61"}, + {"div(substr(#uid#, 2), 10)", Env{"uid": "87616"}, "761.6"}, + } + var prevExpr string + for _, test := range tests { + // Print expr only when it changes. + if test.expr != prevExpr { + fmt.Printf("\n%s\n", test.expr) + prevExpr = test.expr + } + expr, vars, err := Parse(test.expr) + if err != nil { + t.Error(err) // parse error + continue + } + if len(vars) != 1 || vars[0] != "uid" { + t.Errorf("illegal vars %#v", vars) + } + + // got := fmt.Sprintf("%.6g", ) + evalRes, _ := expr.Eval(test.env) + evalResFloat, _ := strconv.ParseFloat(evalRes.String(), 64) + got := fmt.Sprintf("%.6g", evalResFloat) + fmt.Printf("\t%v => %s\n", test.env, got) + if got != test.want { + t.Errorf("%s.Eval() in %v = %q, want %q\n", test.expr, test.env, got, test.want) + } + } +} diff --git a/pkg/runtime/rule/shard_expr_type.go b/pkg/runtime/rule/shard_expr_type.go new file mode 100644 index 000000000..f9c507522 --- /dev/null +++ b/pkg/runtime/rule/shard_expr_type.go @@ -0,0 +1,556 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package rule + +import ( + "bytes" + "fmt" + "math" + "strconv" + "strings" +) + +const ( + defaultValue = "" +) + +// An Expr is an arithmetic/string expression. +type Expr interface { + // Eval returns the value of this Expr in the environment env. + Eval(env Env) (Value, error) + // Check reports errors in this Expr and adds its Vars to the set. + Check(vars map[Var]bool) error + // String output + String() string +} + +type ( + // A Var identifies a variable, e.g., x. + Var string + Env map[Var]Value +) + +func (v Var) Eval(env Env) (Value, error) { + return env[v], nil +} + +func (v Var) Check(vars map[Var]bool) error { + vars[v] = true + return nil +} + +func (v Var) String() string { + var buf bytes.Buffer + write(&buf, v) + return buf.String() +} + +// Value defines var value +type Value string + +func (v Value) String() string { + return string(v) +} + +// A constant is a numeric constant, e.g., 3.141. +type constant float64 + +func (c constant) Eval(Env) (Value, error) { + return Value(strconv.Itoa(int(c))), nil +} + +func (constant) Check(map[Var]bool) error { + return nil +} + +func (c constant) String() string { + var buf bytes.Buffer + write(&buf, c) + return buf.String() +} + +// A stringConstant is a string constant, e.g., 3.141. +type stringConstant string + +func (s stringConstant) Eval(Env) (Value, error) { + return Value(s), nil +} + +func (stringConstant) Check(map[Var]bool) error { + return nil +} + +func (s stringConstant) String() string { + return (string)(s) +} + +// A unary represents a unary operator expression, e.g., -x. +type unary struct { + op rune // one of '+', '-' + x Expr +} + +func (u unary) Eval(env Env) (Value, error) { + switch u.op { + case '+': + return u.x.Eval(env) + + case '-': + xv, err := u.x.Eval(env) + if err != nil { + return defaultValue, err + } + fv, err := strconv.ParseFloat(xv.String(), 64) + if err != nil { + return defaultValue, err + } + return Value(fmt.Sprintf("%f", -fv)), nil + } + + return defaultValue, fmt.Errorf("unsupported unary operator: %q", u.op) +} + +func (u unary) Check(vars map[Var]bool) error { + if !strings.ContainsRune("+-", u.op) { + return fmt.Errorf("unexpected unary op %q", u.op) + } + return u.x.Check(vars) +} + +func (u unary) String() string { + var buf bytes.Buffer + write(&buf, u) + return buf.String() +} + +// A binary represents a binary operator expression, e.g., x+y. +type binaryExpr struct { + op rune // one of '+', '-', '*', '/', '%' + x, y Expr +} + +func (b binaryExpr) Eval(env Env) (Value, error) { + xv, err := b.x.Eval(env) + if err != nil { + return xv, err + } + xf, err := strconv.ParseFloat(xv.String(), 64) + if err != nil { + return xv, err + } + + yv, err := b.y.Eval(env) + if err != nil { + return yv, err + } + yf, err := strconv.ParseFloat(yv.String(), 64) + if err != nil { + return yv, err + } + + var f float64 + switch b.op { + case '+': + f = xf + yf + + case '-': + f = xf - yf + + case '*': + f = xf * yf + + case '/': + if yf == 0 { + yf = 1.0 + } + // f = float64(int(xf / yf)) + f = xf / yf + + case '%': + if yf == 0 { + yf = 1.0 + } + f = float64(int(xf) % int(yf)) + + default: + return defaultValue, fmt.Errorf("unsupported binary operator: %q", b.op) + } + + return Value(fmt.Sprintf("%f", f)), nil +} + +func (b binaryExpr) String() string { + var buf bytes.Buffer + write(&buf, b) + return buf.String() +} + +func (b binaryExpr) Check(vars map[Var]bool) error { + if !strings.ContainsRune("+-*/%", b.op) { + return fmt.Errorf("unexpected binary op %q", b.op) + } + if err := b.x.Check(vars); err != nil { + return err + } + return b.y.Check(vars) +} + +// A function represents a function function expression, e.g., sin(x). +type function struct { + fn string // one of "pow", "sqrt" + args []Expr +} + +func (c function) Eval(env Env) (Value, error) { + argsNum := len(c.args) + if argsNum == 0 { + return defaultValue, fmt.Errorf("args number is 0 of func %s", c.fn) + } + + switch c.fn { + case "toint": + return c.args[0].Eval(env) + + case "hash": + b := binaryExpr{ + op: '%', + x: c.args[0], + y: constant(2), + } + if len(c.args) == 2 { + b.y = c.args[1] + } + return b.Eval(env) + + case "add": + b := binaryExpr{ + op: '+', + x: c.args[0], + y: c.args[1], + } + return b.Eval(env) + + case "sub": + b := binaryExpr{ + op: '-', + x: c.args[0], + y: c.args[1], + } + return b.Eval(env) + + case "mul": + b := binaryExpr{ + op: '*', + x: c.args[0], + y: c.args[1], + } + return b.Eval(env) + + case "div": + b := binaryExpr{ + op: '/', + x: c.args[0], + y: c.args[1], + } + return b.Eval(env) + + case "substr": + v0, err := c.args[0].Eval(env) + if err != nil { + return defaultValue, err + } + str := v0.String() + + v1, err := c.args[1].Eval(env) + if err != nil { + return defaultValue, err + } + strlen := len(str) + f, err := strconv.ParseFloat(v1.String(), 64) + if err != nil { + return defaultValue, fmt.Errorf("illegal args[1] %v of func substr", v1.String()) + } + startPos := int(f) + if startPos == 0 { + return defaultValue, nil + } + if startPos < 0 { + startPos += strlen + } else { + startPos-- + } + if startPos < 0 || startPos >= strlen { + return defaultValue, fmt.Errorf("illegal args[1] %v of func substr", v1.String()) + } + + endPos := strlen + if len(c.args) == 3 { + v2, err := c.args[2].Eval(env) + if err != nil { + return defaultValue, err + } + l, err := strconv.Atoi(v2.String()) + if err != nil { + return defaultValue, err + } + if l < 1 { + return defaultValue, fmt.Errorf("illegal args[2] %v of func substr", v2.String()) + } + if startPos+l < endPos { + endPos = startPos + l + } + } + + return Value(str[startPos:endPos]), nil + + case "concat": + var builder strings.Builder + for i := 0; i < len(c.args); i++ { + v, err := c.args[i].Eval(env) + if err != nil { + return defaultValue, err + } + + builder.WriteString(v.String()) + } + return Value(builder.String()), nil + + case "testload": + v0, err := c.args[0].Eval(env) + if err != nil { + return defaultValue, err + } + + s := v0.String() + + var b strings.Builder + b.Grow(len(s)) + for i := 0; i < len(s); i++ { + c := s[i] + if 'a' <= c && c <= 'j' { + c -= 'a' - '0' + } else if 'A' <= c && c <= 'J' { + c -= 'A' - '0' + } + + b.WriteByte(c) + } + + return Value(b.String()), nil + + case "split": + v0, err := c.args[0].Eval(env) + if err != nil { + return defaultValue, err + } + + v1, err := c.args[1].Eval(env) + if err != nil { + return defaultValue, err + } + + v2, err := c.args[2].Eval(env) + if err != nil { + return defaultValue, err + } + pos, err := strconv.Atoi(v2.String()) + if err != nil { + return defaultValue, err + } + if pos < 1 { + return defaultValue, fmt.Errorf("illegal args[2] %v of func split", v2.String()) + } + + arr := strings.Split(v0.String(), v1.String()) + return Value(arr[pos-1]), nil + + case "pow": + v0, err := c.args[0].Eval(env) + if err != nil { + return defaultValue, err + } + v0f, err := strconv.ParseFloat(v0.String(), 64) + if err != nil { + return defaultValue, err + } + + v1, err := c.args[1].Eval(env) + if err != nil { + return defaultValue, err + } + v1f, err := strconv.ParseFloat(v1.String(), 64) + if err != nil { + return defaultValue, err + } + + v := math.Pow(v0f, v1f) + return Value(fmt.Sprintf("%f", v)), nil + + case "sqrt": + v0, err := c.args[0].Eval(env) + if err != nil { + return defaultValue, err + } + v0f, err := strconv.ParseFloat(v0.String(), 64) + if err != nil { + return defaultValue, err + } + v := math.Sqrt(v0f) + return Value(fmt.Sprintf("%f", v)), nil + } + + return defaultValue, fmt.Errorf("unsupported function function: %s", c.fn) +} + +func (c function) String() string { + var buf bytes.Buffer + write(&buf, c) + return buf.String() +} + +func (c function) Check(vars map[Var]bool) error { + arity, ok := numParams[c.fn] + if !ok { + return fmt.Errorf("unknown function %q", c.fn) + } + if argsNum := len(c.args); argsNum != arity { + if argsNum == 0 || arity < argsNum { + return fmt.Errorf("illegal args number %d of func %s, want %d", argsNum, c.fn, arity) + } + switch c.fn { + case "substr": + if argsNum < 2 { + return fmt.Errorf("illegal args number %d of func substr, want %d", argsNum, arity) + } + } + } + for _, arg := range c.args { + if err := arg.Check(vars); err != nil { + return err + } + } + return nil +} + +// expr stack +type stack []Expr + +func NewStack() *stack { + return &stack{} +} + +// size: return the element number of @s +func (s *stack) size() int { + return len(*s) +} + +// isEmpty: check if stack is empty +func (s *stack) isEmpty() bool { + return s.size() == 0 +} + +// push a new value onto the stack +func (s *stack) push(e Expr) *stack { + *s = append(*s, e) + return s +} + +// pop: Remove and return top element of stack. Return false if stack is empty. +func (s *stack) pop() (_ Expr, notEmpty bool) { + if s.isEmpty() { + return nil, false + } + + i := len(*s) - 1 + e := (*s)[i] + *s = (*s)[:i] + return e, true +} + +var numParams = map[string]int{ + "toint": 1, + "hash": 2, + "add": 2, + "sub": 2, + "mul": 2, + "div": 2, + "substr": 3, + "concat": 10, + "testload": 1, + "split": 3, + "pow": 2, + "sqrt": 1, +} + +func write(buf *bytes.Buffer, e Expr) { + switch e.(type) { + case constant: + fmt.Fprintf(buf, "%g", e.(constant)) + + case stringConstant: + fmt.Fprintf(buf, "%s", e.(stringConstant)) + + case Var, *Var: + v, ok := e.(Var) + if !ok { + v = *(e.(*Var)) + } + + fmt.Fprintf(buf, "%s", string(v)) + + case unary, *unary: + u, ok := e.(unary) + if !ok { + u = *(e.(*unary)) + } + + fmt.Fprintf(buf, "(%c", u.op) + write(buf, u.x) + buf.WriteByte(')') + + case binaryExpr, *binaryExpr: + b, ok := e.(binaryExpr) + if !ok { + b = *(e.(*binaryExpr)) + } + + buf.WriteByte('(') + write(buf, b.x) + fmt.Fprintf(buf, " %c ", b.op) + write(buf, b.y) + buf.WriteByte(')') + + case function, *function: + f, ok := e.(function) + if !ok { + f = *(e.(*function)) + } + + fmt.Fprintf(buf, "%s(", f.fn) + for i, arg := range f.args { + if i > 0 { + buf.WriteString(", ") + } + write(buf, arg) + } + buf.WriteByte(')') + + default: + fmt.Fprintf(buf, "unknown Expr: %T", e) + } +} diff --git a/pkg/runtime/rule/shard_expr_type_test.go b/pkg/runtime/rule/shard_expr_type_test.go new file mode 100644 index 000000000..f793f91e8 --- /dev/null +++ b/pkg/runtime/rule/shard_expr_type_test.go @@ -0,0 +1,245 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package rule + +import ( + "fmt" + "math" + "strconv" + "testing" + "text/scanner" +) + +//!+Eval +func TestNumEval(t *testing.T) { + piStr := fmt.Sprintf("%f", math.Pi) + tests := []struct { + expr string + env Env + want string + }{ + {"add(A, B)", Env{"A": "8", "B": "4"}, "12"}, + {"sub(A, B)", Env{"A": "8", "B": "4"}, "4"}, + {"mul(A, B)", Env{"A": "8", "B": "4"}, "32"}, + {"div(A, B)", Env{"A": "8", "B": "4"}, "2"}, + {"sqrt(A / pi)", Env{"A": "87616", "pi": Value(piStr)}, "167"}, + {"testload(x)", Env{"x": "dbc"}, "312"}, + {"testload(x)", Env{"x": "DBc"}, "312"}, + {"substr(x, '1')", Env{"x": "1234"}, "1234"}, + {"substr(x, '1', '2')", Env{"x": "1234"}, "12"}, + {"pow(x, 3) + pow(y, 3)", Env{"x": "9", "y": "10"}, "1729"}, + {"+1", Env{}, "1"}, + {"5 / 9 * (F - 32)", Env{"F": "-40"}, "-40"}, + {"5 / 9 * (F - 32)", Env{"F": "32"}, "0"}, + {"5 / 9 * (F - 32)", Env{"F": "212"}, "100"}, + {"-1 + -x", Env{"x": "1"}, "-2"}, + {"-1 - x", Env{"x": "1"}, "-2"}, + } + var prevExpr string + for _, test := range tests { + // Print expr only when it changes. + if test.expr != prevExpr { + fmt.Printf("\n%s\n", test.expr) + prevExpr = test.expr + } + expr, _, err := Parse(test.expr) + if err != nil { + t.Error(err) // parse error + continue + } + // got := fmt.Sprintf("%.6g", ) + evalRes, _ := expr.Eval(test.env) + evalResFloat, _ := strconv.ParseFloat(evalRes.String(), 64) + got := fmt.Sprintf("%.6g", evalResFloat) + fmt.Printf("\t%v => %s\n", test.env, got) + if got != test.want { + t.Errorf("%s.Eval() in %v = %q, want %q\n", + test.expr, test.env, got, test.want) + } + } + + v := Var("hello") + env := Env{"hello": "hello"} + vv, err := v.Eval(env) + if vv != "hello" || err != nil { + t.Errorf("v %v Eval(env:%v) = {value:%v, error:%v}", v, env, vv, err) + } + + sv := stringConstant("hello") + vv, err = sv.Eval(nil) + if vv != "hello" || err != nil { + t.Errorf("v %v Eval(env:%v) = {value:%v, error:%v}", v, env, vv, err) + } + + splitExpr, _, err := Parse("split(x, '|', 2)") + if err != nil { + t.Errorf("Parse('split(x, y, 2)') = error %v", err) + } + splitRes, err := splitExpr.Eval(Env{"x": "abc|de|f"}) + if err != nil || splitRes != "de" { + t.Errorf("{'split(x, y, 2)', '|', 2} = {res %v, error %v}", splitRes, err) + } +} + +func TestErrors(t *testing.T) { + for _, test := range []struct{ expr, wantErr string }{ + {"x $ 2", "unexpected '$'"}, + {"math.Pi", "unexpected '.'"}, + {"!true", "unexpected '!'"}, + //{`"hello"`, "hello"}, + {"log(10)", `unknown function "log"`}, + {"sqrt(1, 2)", "illegal args number 2 of func sqrt, want 1"}, + } { + expr, _, err := Parse(test.expr) + if err == nil { + vars := make(map[Var]bool) + err = expr.Check(vars) + if err == nil { + t.Errorf("unexpected success: %s", test.expr) + continue + } + } + fmt.Printf("%-20s%v\n", test.expr, err) // (for book) + if err.Error() != test.wantErr { + t.Errorf("got error \"%s\", want %s", err, test.wantErr) + } + } +} + +func TestCheck(t *testing.T) { + sc := stringConstant("hello") + err := sc.Check(nil) + if err != nil { + t.Fatalf("stringConstant check() result %v != nil", err) + } + + c := function{ + fn: "hash", + args: nil, + } + err = c.Check(nil) + if err == nil { + t.Fatalf("function %#v check() result %v should not be nil", c, err) + } + + c = function{ + fn: "substr", + args: nil, + } + err = c.Check(nil) + if err == nil { + t.Fatalf("function %#v check() result %v should be nil", c, err) + } + c.args = append(c.args, sc) + // c.args = append(c.args, sc) + // c.args = append(c.args, sc) + err = c.Check(nil) + if err == nil { + t.Fatalf("function %#v check() result %v should not be nil", c, err) + } + + piStr := fmt.Sprintf("%f", math.Pi) + tests := []struct { + input string + env Env + want string // expected error from Parse/Check or result from Eval + }{ + {"hello", nil, ""}, + {"+2", nil, ""}, + // {"x % 2", nil, "unexpected '%'"}, + {"x % 2", nil, ""}, + {"!true", nil, "unexpected '!'"}, + {"log(10)", nil, `unknown function "log"`}, + {"sqrt(1, 2)", nil, "illegal args number 2 of func sqrt, want 1"}, + {"sqrt(A / pi)", Env{"A": "87616", "pi": Value(piStr)}, "167"}, + {"pow(x, 3) + pow(y, 3)", Env{"x": "9", "y": "10"}, "1729"}, + {"5 / 9 * (F - 32)", Env{"F": "-40"}, "-40"}, + } + + for _, test := range tests { + expr, _, err := Parse(test.input) + if err == nil { + err = expr.Check(map[Var]bool{}) + } + if err != nil { + if err.Error() != test.want { + t.Errorf("%s: got %q, want %q", test.input, err, test.want) + } + continue + } + + if test.want != "" { + // got := fmt.Sprintf("%.6g", expr.Eval(test.env)) + evalRes, _ := expr.Eval(test.env) + evalResFloat, _ := strconv.ParseFloat(evalRes.String(), 64) + got := fmt.Sprintf("%.6g", evalResFloat) + if got != test.want { + t.Errorf("%s: %v => %s, want %s", + test.input, test.env, got, test.want) + } + } + } +} + +func TestString(t *testing.T) { + var e Expr + + e = Var("hello") + t.Logf("%v string %s", e, e) + + e = constant(3.14) + t.Logf("%v string %s", e, e) + + e = stringConstant("hello") + t.Logf("%v string %s", e, e) + + e = unary{ + op: '-', + x: constant(1234), + } + t.Logf("%v string %s", e, e) + + e = binaryExpr{ + op: '+', + x: stringConstant("1"), + y: stringConstant("1"), + } + t.Logf("%v string %s", e, e) + + f := function{ + fn: "add", + } + f.args = append(f.args, stringConstant("1")) + f.args = append(f.args, stringConstant("2")) + e = f + t.Logf("%v string %s", e, e) + + var l lexer + l.token = scanner.EOF + t.Logf("eof string %s", l.describe()) + l.token = scanner.Ident + t.Logf("ident string %s", l.describe()) + l.token = scanner.Int + t.Logf("int string %s", l.describe()) + + l.token = scanner.String + s, e, err := parsePrimary(&l, nil) + if s != nil || err != nil || e == nil { + t.Logf("parsePrimary() = {stack:%v, expression:%v error:%v}", s, e, err) + } +} diff --git a/pkg/runtime/rule/shard_script.go b/pkg/runtime/rule/shard_script.go index 73e94a918..058df24a1 100644 --- a/pkg/runtime/rule/shard_script.go +++ b/pkg/runtime/rule/shard_script.go @@ -100,9 +100,7 @@ func (j *jsShardComputer) putVM(vm *goja.Runtime) { } func wrapScript(script string) string { - var ( - sb strings.Builder - ) + var sb strings.Builder sb.Grow(32 + len(_jsEntrypoint) + len(_jsValueName) + len(script)) diff --git a/pkg/runtime/runtime.go b/pkg/runtime/runtime.go index 7c187ef67..4be5d78d6 100644 --- a/pkg/runtime/runtime.go +++ b/pkg/runtime/runtime.go @@ -20,18 +20,19 @@ package runtime import ( "context" "encoding/json" - stdErrors "errors" + "errors" "fmt" "io" - "sort" "sync" "time" ) import ( "github.com/bwmarrin/snowflake" + perrors "github.com/pkg/errors" - "github.com/pkg/errors" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/trace" "go.uber.org/atomic" @@ -43,8 +44,15 @@ import ( "github.com/arana-db/arana/pkg/metrics" "github.com/arana-db/arana/pkg/mysql" "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/proto/hint" + "github.com/arana-db/arana/pkg/resultx" rcontext "github.com/arana-db/arana/pkg/runtime/context" "github.com/arana-db/arana/pkg/runtime/namespace" + "github.com/arana-db/arana/pkg/runtime/optimize" + _ "github.com/arana-db/arana/pkg/runtime/optimize/dal" + _ "github.com/arana-db/arana/pkg/runtime/optimize/ddl" + _ "github.com/arana-db/arana/pkg/runtime/optimize/dml" + _ "github.com/arana-db/arana/pkg/runtime/optimize/utility" "github.com/arana-db/arana/pkg/util/log" "github.com/arana-db/arana/pkg/util/rand2" "github.com/arana-db/arana/third_party/pools" @@ -56,8 +64,10 @@ var ( _ proto.VConn = (*compositeTx)(nil) ) +var Tracer = otel.Tracer("Runtime") + var ( - errTxClosed = stdErrors.New("transaction is closed") + errTxClosed = errors.New("transaction is closed") ) func NewAtomDB(node *config.Node) *AtomDB { @@ -74,13 +84,20 @@ func NewAtomDB(node *config.Node) *AtomDB { } raw, _ := json.Marshal(map[string]interface{}{ - "dsn": fmt.Sprintf("%s:%s@tcp(%s:%d)/%s", node.Username, node.Password, node.Host, node.Port, node.Database), + "dsn": fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?%s", node.Username, node.Password, node.Host, node.Port, node.Database, node.Parameters.String()), }) connector, err := mysql.NewConnector(raw) if err != nil { panic(err) } - db.pool = pools.NewResourcePool(connector.NewBackendConnection, 8, 16, 30*time.Minute, 1, nil) + + var ( + capacity = config.GetConnPropCapacity(node.ConnProps, 8) + maxCapacity = config.GetConnPropMaxCapacity(node.ConnProps, 64) + idleTime = config.GetConnPropIdleTime(node.ConnProps, 30*time.Minute) + ) + + db.pool = pools.NewResourcePool(connector.NewBackendConnection, capacity, maxCapacity, idleTime, 1, nil) return db } @@ -88,21 +105,20 @@ func NewAtomDB(node *config.Node) *AtomDB { // Runtime executes a sql statement. type Runtime interface { proto.Executable + proto.VConn // Namespace returns the namespace. Namespace() *namespace.Namespace // Begin begins a new transaction. - Begin(ctx *proto.Context) (proto.Tx, error) + Begin(ctx context.Context) (proto.Tx, error) } // Load loads a Runtime, here schema means logical database name. func Load(schema string) (Runtime, error) { var ns *namespace.Namespace if ns = namespace.Load(schema); ns == nil { - return nil, errors.Errorf("no such logical database %s", schema) + return nil, perrors.Errorf("no such logical database %s", schema) } - return &defaultRuntime{ - ns: ns, - }, nil + return (*defaultRuntime)(ns), nil } var ( @@ -141,7 +157,7 @@ func (tx *compositeTx) call(ctx context.Context, db string, query string, args . res, _, err := atx.Call(ctx, query, args...) if err != nil { - return nil, errors.WithStack(err) + return nil, perrors.WithStack(err) } return res, nil } @@ -153,9 +169,13 @@ func (tx *compositeTx) begin(ctx context.Context, group string) (*atomTx, error) // force use writeable node ctx = rcontext.WithWrite(ctx) + db := selectDB(ctx, group, tx.rt.Namespace()) + if db == nil { + return nil, perrors.Errorf("cannot get upstream database %s", group) + } // begin atom tx - newborn, err := tx.rt.Namespace().DB(ctx, group).(*AtomDB).begin(ctx) + newborn, err := db.(*AtomDB).begin(ctx) if err != nil { return nil, err } @@ -168,8 +188,11 @@ func (tx *compositeTx) String() string { } func (tx *compositeTx) Execute(ctx *proto.Context) (res proto.Result, warn uint16, err error) { + var span trace.Span + ctx.Context, span = Tracer.Start(ctx.Context, "compositeTx.Execute") execStart := time.Now() defer func() { + span.End() metrics.ExecuteDuration.Observe(time.Since(execStart).Seconds()) }() if tx.closed.Load() { @@ -177,9 +200,7 @@ func (tx *compositeTx) Execute(ctx *proto.Context) (res proto.Result, warn uint1 return } - var ( - args = tx.rt.extractArgs(ctx) - ) + args := ctx.GetArgs() if direct := rcontext.IsDirect(ctx.Context); direct { var ( group = tx.rt.Namespace().DBGroups()[0] @@ -190,32 +211,38 @@ func (tx *compositeTx) Execute(ctx *proto.Context) (res proto.Result, warn uint1 return } res, warn, err = atx.Call(cctx, ctx.GetQuery(), args...) - go writeRow(res) + if err != nil { + err = perrors.WithStack(err) + } return } var ( - ru = tx.rt.ns.Rule() + ru = tx.rt.Namespace().Rule() plan proto.Plan c = ctx.Context ) - c = rcontext.WithRule(c, ru) c = rcontext.WithSQL(c, ctx.GetQuery()) + c = rcontext.WithHints(c, ctx.Stmt.Hints) + + var opt proto.Optimizer + if opt, err = optimize.NewOptimizer(ru, ctx.Stmt.Hints, ctx.Stmt.StmtNode, args); err != nil { + err = perrors.WithStack(err) + return + } - if plan, err = tx.rt.ns.Optimizer().Optimize(c, tx, ctx.Stmt.StmtNode, args...); err != nil { - err = errors.WithStack(err) + if plan, err = opt.Optimize(ctx); err != nil { + err = perrors.WithStack(err) return } if res, err = plan.ExecIn(c, tx); err != nil { // TODO: how to warp error packet - err = errors.WithStack(err) + err = perrors.WithStack(err) return } - go writeRow(res) - return } @@ -227,16 +254,16 @@ func (tx *compositeTx) Commit(ctx context.Context) (proto.Result, uint16, error) if !tx.closed.CAS(false, true) { return nil, 0, errTxClosed } - + ctx, span := Tracer.Start(ctx, "compositeTx.Commit") defer func() { // cleanup tx.rt = nil tx.txs = nil + span.End() }() var g errgroup.Group for k, v := range tx.txs { - k := k - v := v + k, v := k, v g.Go(func() error { _, _, err := v.Commit(ctx) if err != nil { @@ -253,10 +280,12 @@ func (tx *compositeTx) Commit(ctx context.Context) (proto.Result, uint16, error) log.Debugf("commit %s success: total=%d", tx, len(tx.txs)) - return &mysql.Result{}, 0, nil + return resultx.New(), 0, nil } func (tx *compositeTx) Rollback(ctx context.Context) (proto.Result, uint16, error) { + ctx, span := Tracer.Start(ctx, "compositeTx.Rollback") + defer span.End() if !tx.closed.CAS(false, true) { return nil, 0, errTxClosed } @@ -268,8 +297,7 @@ func (tx *compositeTx) Rollback(ctx context.Context) (proto.Result, uint16, erro var g errgroup.Group for k, v := range tx.txs { - k := k - v := v + k, v := k, v g.Go(func() error { _, _, err := v.Rollback(ctx) if err != nil { @@ -286,7 +314,7 @@ func (tx *compositeTx) Rollback(ctx context.Context) (proto.Result, uint16, erro log.Debugf("rollback %s success: total=%d", tx, len(tx.txs)) - return &mysql.Result{}, 0, nil + return resultx.New(), 0, nil } type atomTx struct { @@ -296,12 +324,26 @@ type atomTx struct { } func (tx *atomTx) Commit(ctx context.Context) (res proto.Result, warn uint16, err error) { + _ = ctx if !tx.closed.CAS(false, true) { err = errTxClosed return } defer tx.dispose() - res, warn, err = tx.bc.ExecuteWithWarningCount("commit", true) + if res, err = tx.bc.ExecuteWithWarningCount("commit", true); err != nil { + return + } + + var affected, lastInsertId uint64 + + if affected, err = res.RowsAffected(); err != nil { + return + } + if lastInsertId, err = res.LastInsertId(); err != nil { + return + } + + res = resultx.New(resultx.WithRowsAffected(affected), resultx.WithLastInsertID(lastInsertId)) return } @@ -311,15 +353,15 @@ func (tx *atomTx) Rollback(ctx context.Context) (res proto.Result, warn uint16, return } defer tx.dispose() - res, warn, err = tx.bc.ExecuteWithWarningCount("rollback", true) + res, err = tx.bc.ExecuteWithWarningCount("rollback", true) return } func (tx *atomTx) Call(ctx context.Context, sql string, args ...interface{}) (res proto.Result, warn uint16, err error) { if len(args) > 0 { - res, warn, err = tx.bc.PrepareQueryArgsIterRow(sql, args) + res, err = tx.bc.PrepareQueryArgs(sql, args) } else { - res, warn, err = tx.bc.ExecuteWithWarningCountIterRow(sql) + res, err = tx.bc.ExecuteWithWarningCountIterRow(sql) } return } @@ -328,7 +370,7 @@ func (tx *atomTx) CallFieldList(ctx context.Context, table, wildcard string) ([] // TODO: choose table var err error if err = tx.bc.WriteComFieldList(table, wildcard); err != nil { - return nil, errors.WithStack(err) + return nil, perrors.WithStack(err) } return tx.bc.ReadColumnDefinitions() } @@ -359,7 +401,7 @@ type AtomDB struct { func (db *AtomDB) begin(ctx context.Context) (*atomTx, error) { if db.closed.Load() { - return nil, errors.Errorf("the db instance '%s' is closed already", db.id) + return nil, perrors.Errorf("the db instance '%s' is closed already", db.id) } var ( @@ -368,19 +410,30 @@ func (db *AtomDB) begin(ctx context.Context) (*atomTx, error) { ) if bc, err = db.borrowConnection(ctx); err != nil { - return nil, errors.WithStack(err) + return nil, perrors.WithStack(err) } db.pendingRequests.Inc() - if _, _, err = bc.ExecuteWithWarningCount("begin", true); err != nil { + dispose := func() { // cleanup if failed to begin tx cnt := db.pendingRequests.Dec() db.returnConnection(bc) if cnt == 0 && db.closed.Load() { db.pool.Close() } - return nil, err + } + + var res proto.Result + if res, err = bc.ExecuteWithWarningCount("begin", true); err != nil { + defer dispose() + return nil, perrors.WithStack(err) + } + + // NOTICE: must consume the result + if _, err = res.RowsAffected(); err != nil { + defer dispose() + return nil, perrors.WithStack(err) } return &atomTx{parent: db, bc: bc}, nil @@ -388,7 +441,7 @@ func (db *AtomDB) begin(ctx context.Context) (*atomTx, error) { func (db *AtomDB) CallFieldList(ctx context.Context, table, wildcard string) ([]proto.Field, error) { if db.closed.Load() { - return nil, errors.Errorf("the db instance '%s' is closed already", db.id) + return nil, perrors.Errorf("the db instance '%s' is closed already", db.id) } var ( @@ -397,14 +450,14 @@ func (db *AtomDB) CallFieldList(ctx context.Context, table, wildcard string) ([] ) if bc, err = db.borrowConnection(ctx); err != nil { - return nil, errors.WithStack(err) + return nil, perrors.WithStack(err) } defer db.returnConnection(bc) defer db.pending()() if err = bc.WriteComFieldList(table, wildcard); err != nil { - return nil, errors.WithStack(err) + return nil, perrors.WithStack(err) } return bc.ReadColumnDefinitions() @@ -412,23 +465,23 @@ func (db *AtomDB) CallFieldList(ctx context.Context, table, wildcard string) ([] func (db *AtomDB) Call(ctx context.Context, sql string, args ...interface{}) (res proto.Result, warn uint16, err error) { if db.closed.Load() { - err = errors.Errorf("the db instance '%s' is closed already", db.id) + err = perrors.Errorf("the db instance '%s' is closed already", db.id) return } var bc *mysql.BackendConnection if bc, err = db.borrowConnection(ctx); err != nil { - err = errors.WithStack(err) + err = perrors.WithStack(err) return } undoPending := db.pending() if len(args) > 0 { - res, warn, err = bc.PrepareQueryArgsIterRow(sql, args) + res, err = bc.PrepareQueryArgs(sql, args) } else { - res, warn, err = bc.ExecuteWithWarningCountIterRow(sql) + res, err = bc.ExecuteWithWarningCountIterRow(sql) } if err != nil { @@ -437,20 +490,11 @@ func (db *AtomDB) Call(ctx context.Context, sql string, args ...interface{}) (re return } - if len(res.GetFields()) < 1 { + res.(*mysql.RawResult).SetCloser(func() error { undoPending() db.returnConnection(bc) - return - } - - res = &proto.CloseableResult{ - Result: res, - Closer: func() error { - undoPending() - db.returnConnection(bc) - return nil - }, - } + return nil + }) return } @@ -515,36 +559,38 @@ func (db *AtomDB) SetWeight(weight proto.Weight) error { func (db *AtomDB) borrowConnection(ctx context.Context) (*mysql.BackendConnection, error) { bcp := (*BackendResourcePool)(db.pool) - //log.Infof("^^^^^ begin borrow conn: active=%d, available=%d", db.pool.Active(), db.pool.Available()) + //var ( + // active0, available0 = db.pool.Active(), db.pool.Available() + //) res, err := bcp.Get(ctx) - //log.Infof("^^^^^ end borrow conn: active=%d, available=%d", db.pool.Active(), db.pool.Available()) + // log.Infof("^^^^^ borrow conn: %d/%d => %d/%d", available0, active0, db.pool.Active(), db.pool.Available()) if err != nil { - return nil, errors.WithStack(err) + return nil, perrors.WithStack(err) } return res, nil } func (db *AtomDB) returnConnection(bc *mysql.BackendConnection) { db.pool.Put(bc) - //log.Infof("^^^^^ return conn: active=%d, available=%d", db.pool.Active(), db.pool.Available()) + // log.Infof("^^^^^ return conn: active=%d, available=%d", db.pool.Active(), db.pool.Available()) } -type defaultRuntime struct { - ns *namespace.Namespace -} +type defaultRuntime namespace.Namespace -func (pi *defaultRuntime) Begin(ctx *proto.Context) (proto.Tx, error) { +func (pi *defaultRuntime) Begin(ctx context.Context) (proto.Tx, error) { + _, span := Tracer.Start(ctx, "defaultRuntime.Begin") + defer span.End() tx := &compositeTx{ id: nextTxID(), rt: pi, txs: make(map[string]*atomTx), } - log.Debugf("begin transaction: %s", tx.String()) + log.Debugf("begin transaction: %s", tx) return tx, nil } func (pi *defaultRuntime) Namespace() *namespace.Namespace { - return pi.ns + return (*namespace.Namespace)(pi) } func (pi *defaultRuntime) Query(ctx context.Context, db string, query string, args ...interface{}) (proto.Result, error) { @@ -556,7 +602,7 @@ func (pi *defaultRuntime) Exec(ctx context.Context, db string, query string, arg ctx = rcontext.WithWrite(ctx) res, err := pi.call(ctx, db, query, args...) if err != nil { - return nil, errors.WithStack(err) + return nil, perrors.WithStack(err) } if closer, ok := res.(io.Closer); ok { @@ -568,91 +614,107 @@ func (pi *defaultRuntime) Exec(ctx context.Context, db string, query string, arg } func (pi *defaultRuntime) Execute(ctx *proto.Context) (res proto.Result, warn uint16, err error) { + var span trace.Span + ctx.Context, span = Tracer.Start(ctx.Context, "defaultRuntime.Execute") execStart := time.Now() defer func() { + span.End() metrics.ExecuteDuration.Observe(time.Since(execStart).Seconds()) }() - args := pi.extractArgs(ctx) + args := ctx.GetArgs() if direct := rcontext.IsDirect(ctx.Context); direct { return pi.callDirect(ctx, args) } var ( - ru = pi.ns.Rule() + ru = pi.Namespace().Rule() plan proto.Plan c = ctx.Context ) - c = rcontext.WithRule(c, ru) c = rcontext.WithSQL(c, ctx.GetQuery()) c = rcontext.WithSchema(c, ctx.Schema) - c = rcontext.WithDBGroup(c, pi.ns.DBGroups()[0]) + c = rcontext.WithTenant(c, ctx.Tenant) + c = rcontext.WithHints(c, ctx.Stmt.Hints) start := time.Now() - if plan, err = pi.ns.Optimizer().Optimize(c, pi, ctx.Stmt.StmtNode, args...); err != nil { - err = errors.WithStack(err) + + var opt proto.Optimizer + if opt, err = optimize.NewOptimizer(ru, ctx.Stmt.Hints, ctx.Stmt.StmtNode, args); err != nil { + err = perrors.WithStack(err) + return + } + + if plan, err = opt.Optimize(c); err != nil { + err = perrors.WithStack(err) return } metrics.OptimizeDuration.Observe(time.Since(start).Seconds()) if res, err = plan.ExecIn(c, pi); err != nil { // TODO: how to warp error packet - err = errors.WithStack(err) + err = perrors.WithStack(err) return } - go writeRow(res) - return } func (pi *defaultRuntime) callDirect(ctx *proto.Context, args []interface{}) (res proto.Result, warn uint16, err error) { - res, warn, err = pi.ns.DB0(ctx.Context).Call(rcontext.WithWrite(ctx.Context), ctx.GetQuery(), args...) + res, warn, err = pi.Namespace().DB0(ctx.Context).Call(rcontext.WithWrite(ctx.Context), ctx.GetQuery(), args...) if err != nil { + err = perrors.WithStack(err) return } - go writeRow(res) return } -func (pi *defaultRuntime) extractArgs(ctx *proto.Context) []interface{} { - if ctx.Stmt == nil || len(ctx.Stmt.BindVars) < 1 { - return nil +func (pi *defaultRuntime) call(ctx context.Context, group, query string, args ...interface{}) (proto.Result, error) { + db := selectDB(ctx, group, pi.Namespace()) + if db == nil { + return nil, perrors.Errorf("cannot get upstream database %s", group) } + log.Debugf("call upstream: db=%s, id=%s, sql=\"%s\", args=%v", group, db.ID(), query, args) + // TODO: how to pass warn??? + res, _, err := db.Call(ctx, query, args...) - var ( - keys = make([]string, 0, len(ctx.Stmt.BindVars)) - args = make([]interface{}, 0, len(ctx.Stmt.BindVars)) - ) - - for k := range ctx.Stmt.BindVars { - keys = append(keys, k) - } - sort.Strings(keys) - for _, k := range keys { - args = append(args, ctx.Stmt.BindVars[k]) - } - return args + return res, err } -func (pi *defaultRuntime) call(ctx context.Context, group, query string, args ...interface{}) (proto.Result, error) { +// select db by group +func selectDB(ctx context.Context, group string, ns *namespace.Namespace) proto.DB { if len(group) < 1 { // empty db, select first - if groups := pi.ns.DBGroups(); len(groups) > 0 { + if groups := ns.DBGroups(); len(groups) > 0 { group = groups[0] } } - db := pi.ns.DB(ctx, group) - if db == nil { - return nil, errors.Errorf("cannot get upstream database %s", group) + var ( + db proto.DB + hintType hint.Type + ) + // write request + if !rcontext.IsRead(ctx) { + return ns.DBMaster(ctx, group) + } + // extracts hints + hints := rcontext.Hints(ctx) + for _, v := range hints { + if v.Type == hint.TypeMaster || v.Type == hint.TypeSlave { + hintType = v.Type + break + } } - - log.Debugf("call upstream: db=%s, sql=\"%s\", args=%v", group, query, args) - // TODO: how to pass warn??? - res, _, err := db.Call(ctx, query, args...) - - return res, err + switch hintType { + case hint.TypeMaster: + db = ns.DBMaster(ctx, group) + case hint.TypeSlave: + db = ns.DBSlave(ctx, group) + default: + db = ns.DB(ctx, group) + } + return db } var ( @@ -666,59 +728,3 @@ func nextTxID() int64 { }) return _txIds.Generate().Int64() } - -// writeRow write data to the chan -func writeRow(result proto.Result) { - var res *mysql.Result - switch val := result.(type) { - case *proto.CloseableResult: - res = val.Result.(*mysql.Result) - case *mysql.Result: - res = val - default: - panic("unreachable") - } - - defer func() { - close(res.DataChan) - }() - - if len(res.GetFields()) <= 0 { - return - } - var ( - err error - has bool - rowIter mysql.Iter - row = res.GetRows()[0] - ) - - switch row.(type) { - case *mysql.BinaryRow: - for i := 0; i < len(res.GetRows()); i++ { - data := &mysql.BinaryIterRow{IterRow: &mysql.IterRow{ - Row: &res.GetRows()[i].(*mysql.BinaryRow).Row, - }} - res.DataChan <- data - } - return - case *mysql.TextRow: - for i := 0; i < len(res.GetRows()); i++ { - data := &mysql.TextIterRow{IterRow: &mysql.IterRow{ - Row: &res.GetRows()[i].(*mysql.TextRow).Row, - }} - res.DataChan <- data - } - return - } - - switch r := row.(type) { - case *mysql.BinaryIterRow: - rowIter = r - case *mysql.TextIterRow: - rowIter = r - } - for has, err = rowIter.Next(); has && err == nil; has, err = rowIter.Next() { - res.DataChan <- row - } -} diff --git a/pkg/runtime/runtime_test.go b/pkg/runtime/runtime_test.go index 804fbcf94..2686b5608 100644 --- a/pkg/runtime/runtime_test.go +++ b/pkg/runtime/runtime_test.go @@ -30,20 +30,18 @@ import ( import ( "github.com/arana-db/arana/pkg/runtime/namespace" - "github.com/arana-db/arana/testdata" ) func TestLoad(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - opt := testdata.NewMockOptimizer(ctrl) const schemaName = "FakeSchema" rt, err := Load(schemaName) assert.Error(t, err) assert.Nil(t, rt) - _ = namespace.Register(namespace.New(schemaName, opt)) + _ = namespace.Register(namespace.New(schemaName)) defer func() { _ = namespace.Unregister(schemaName) }() diff --git a/pkg/schema/loader.go b/pkg/schema/loader.go new file mode 100644 index 000000000..d97c6e8c6 --- /dev/null +++ b/pkg/schema/loader.go @@ -0,0 +1,251 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package schema + +import ( + "context" + "fmt" + "io" + "strings" +) + +import ( + "github.com/pkg/errors" +) + +import ( + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/runtime" + rcontext "github.com/arana-db/arana/pkg/runtime/context" + "github.com/arana-db/arana/pkg/util/log" +) + +const ( + orderByOrdinalPosition = " ORDER BY ORDINAL_POSITION" + tableMetadataNoOrder = "SELECT TABLE_NAME, COLUMN_NAME, DATA_TYPE, COLUMN_KEY, EXTRA, COLLATION_NAME, ORDINAL_POSITION FROM information_schema.columns WHERE TABLE_SCHEMA=database()" + tableMetadataSQL = tableMetadataNoOrder + orderByOrdinalPosition + tableMetadataSQLInTables = tableMetadataNoOrder + " AND TABLE_NAME IN (%s)" + orderByOrdinalPosition + indexMetadataSQL = "SELECT TABLE_NAME, INDEX_NAME FROM information_schema.statistics WHERE TABLE_SCHEMA=database() AND TABLE_NAME IN (%s)" +) + +func init() { + proto.RegisterSchemaLoader(NewSimpleSchemaLoader()) +} + +type SimpleSchemaLoader struct { + // key format is schema.table + metadataCache map[string]*proto.TableMetadata +} + +func NewSimpleSchemaLoader() *SimpleSchemaLoader { + return &SimpleSchemaLoader{metadataCache: make(map[string]*proto.TableMetadata)} +} + +func (l *SimpleSchemaLoader) Load(ctx context.Context, schema string, tables []string) (map[string]*proto.TableMetadata, error) { + var ( + tableMetadataMap = make(map[string]*proto.TableMetadata, len(tables)) + indexMetadataMap map[string][]*proto.IndexMetadata + columnMetadataMap map[string][]*proto.ColumnMetadata + queryTables = make([]string, 0, len(tables)) + ) + + if len(schema) > 0 { + for _, table := range tables { + qualifiedTblName := schema + "." + table + if l.metadataCache[qualifiedTblName] != nil { + tableMetadataMap[table] = l.metadataCache[qualifiedTblName] + } else { + queryTables = append(queryTables, table) + } + } + } else { + copy(queryTables, tables) + } + + if len(queryTables) == 0 { + return tableMetadataMap, nil + } + + ctx = rcontext.WithRead(rcontext.WithDirect(ctx)) + + var err error + if columnMetadataMap, err = l.LoadColumnMetadataMap(ctx, schema, queryTables); err != nil { + return nil, errors.WithStack(err) + } + + if columnMetadataMap != nil { + if indexMetadataMap, err = l.LoadIndexMetadata(ctx, schema, queryTables); err != nil { + return nil, errors.WithStack(err) + } + } + + for tableName, columns := range columnMetadataMap { + tableMetadataMap[tableName] = proto.NewTableMetadata(tableName, columns, indexMetadataMap[tableName]) + if len(schema) > 0 { + l.metadataCache[schema+"."+tableName] = tableMetadataMap[tableName] + } + } + + return tableMetadataMap, nil +} + +func (l *SimpleSchemaLoader) LoadColumnMetadataMap(ctx context.Context, schema string, tables []string) (map[string][]*proto.ColumnMetadata, error) { + conn, err := runtime.Load(schema) + if err != nil { + return nil, errors.WithStack(err) + } + var ( + resultSet proto.Result + ds proto.Dataset + ) + + if resultSet, err = conn.Query(ctx, "", getColumnMetadataSQL(tables)); err != nil { + log.Errorf("Load ColumnMetadata error when call db: %v", err) + return nil, errors.WithStack(err) + } + + if ds, err = resultSet.Dataset(); err != nil { + log.Errorf("Load ColumnMetadata error when call db: %v", err) + return nil, errors.WithStack(err) + } + + result := make(map[string][]*proto.ColumnMetadata, 0) + if ds == nil { + log.Error("Load ColumnMetadata error because the result is nil") + return nil, nil + } + + var ( + fields, _ = ds.Fields() + row proto.Row + cells = make([]proto.Value, len(fields)) + ) + + for { + row, err = ds.Next() + if errors.Is(err, io.EOF) { + break + } + if err != nil { + return nil, errors.WithStack(err) + } + + if err = row.Scan(cells); err != nil { + return nil, errors.WithStack(err) + } + + tableName := convertInterfaceToStrNullable(cells[0]) + columnName := convertInterfaceToStrNullable(cells[1]) + dataType := convertInterfaceToStrNullable(cells[2]) + columnKey := convertInterfaceToStrNullable(cells[3]) + extra := convertInterfaceToStrNullable(cells[4]) + collationName := convertInterfaceToStrNullable(cells[5]) + ordinalPosition := convertInterfaceToStrNullable(cells[6]) + result[tableName] = append(result[tableName], &proto.ColumnMetadata{ + Name: columnName, + DataType: dataType, + Ordinal: ordinalPosition, + PrimaryKey: strings.EqualFold("PRI", columnKey), + Generated: strings.EqualFold("auto_increment", extra), + CaseSensitive: columnKey != "" && !strings.HasSuffix(collationName, "_ci"), + }) + } + + return result, nil +} + +func convertInterfaceToStrNullable(value proto.Value) string { + if value == nil { + return "" + } + + switch val := value.(type) { + case string: + return val + default: + return fmt.Sprint(val) + } +} + +func (l *SimpleSchemaLoader) LoadIndexMetadata(ctx context.Context, schema string, tables []string) (map[string][]*proto.IndexMetadata, error) { + conn, err := runtime.Load(schema) + if err != nil { + return nil, errors.WithStack(err) + } + + var ( + resultSet proto.Result + ds proto.Dataset + ) + + if resultSet, err = conn.Query(ctx, "", getIndexMetadataSQL(tables)); err != nil { + return nil, errors.WithStack(err) + } + + if ds, err = resultSet.Dataset(); err != nil { + return nil, errors.WithStack(err) + } + + var ( + fields, _ = ds.Fields() + row proto.Row + values = make([]proto.Value, len(fields)) + result = make(map[string][]*proto.IndexMetadata, 0) + ) + + for { + row, err = ds.Next() + if errors.Is(err, io.EOF) { + break + } + + if err != nil { + return nil, errors.WithStack(err) + } + + if err = row.Scan(values); err != nil { + return nil, errors.WithStack(err) + } + + tableName := convertInterfaceToStrNullable(values[0]) + indexName := convertInterfaceToStrNullable(values[1]) + result[tableName] = append(result[tableName], &proto.IndexMetadata{Name: indexName}) + } + + return result, nil +} + +func getIndexMetadataSQL(tables []string) string { + tableParamList := make([]string, 0, len(tables)) + for _, table := range tables { + tableParamList = append(tableParamList, "'"+table+"'") + } + return fmt.Sprintf(indexMetadataSQL, strings.Join(tableParamList, ",")) +} + +func getColumnMetadataSQL(tables []string) string { + if len(tables) == 0 { + return tableMetadataSQL + } + tableParamList := make([]string, len(tables)) + for i, table := range tables { + tableParamList[i] = "'" + table + "'" + } + // TODO use strings.Builder in the future + return fmt.Sprintf(tableMetadataSQLInTables, strings.Join(tableParamList, ",")) +} diff --git a/pkg/proto/schema_manager/loader_test.go b/pkg/schema/loader_test.go similarity index 82% rename from pkg/proto/schema_manager/loader_test.go rename to pkg/schema/loader_test.go index 456b3f891..ae9f0b98c 100644 --- a/pkg/proto/schema_manager/loader_test.go +++ b/pkg/schema/loader_test.go @@ -15,7 +15,7 @@ * limitations under the License. */ -package schema_manager +package schema_test import ( "context" @@ -24,9 +24,9 @@ import ( import ( "github.com/arana-db/arana/pkg/config" - "github.com/arana-db/arana/pkg/proto" "github.com/arana-db/arana/pkg/runtime" "github.com/arana-db/arana/pkg/runtime/namespace" + "github.com/arana-db/arana/pkg/schema" ) func TestLoader(t *testing.T) { @@ -46,15 +46,11 @@ func TestLoader(t *testing.T) { cmds := make([]namespace.Command, 0) cmds = append(cmds, namespace.UpsertDB(groupName, runtime.NewAtomDB(node))) namespaceName := "dongjianhui" - ns := namespace.New(namespaceName, nil, cmds...) - namespace.Register(ns) - rt, err := runtime.Load(namespaceName) - if err != nil { - panic(err) - } + ns := namespace.New(namespaceName, cmds...) + _ = namespace.Register(ns) schemeName := "employees" tableName := "employees" - s := &SimpleSchemaLoader{} + s := schema.NewSimpleSchemaLoader() - s.Load(context.Background(), rt.(proto.VConn), schemeName, []string{tableName}) + s.Load(context.Background(), schemeName, []string{tableName}) } diff --git a/pkg/transformer/aggr_loader.go b/pkg/transformer/aggr_loader.go index e8f9d445d..12d51b051 100644 --- a/pkg/transformer/aggr_loader.go +++ b/pkg/transformer/aggr_loader.go @@ -22,7 +22,7 @@ import ( ) type AggrLoader struct { - Aggrs [][]string + Aggrs []string Alias []string Name []string } @@ -37,9 +37,8 @@ func LoadAggrs(fields []ast2.SelectElement) *AggrLoader { if n == nil { return } - fieldAggr := make([]string, 0, 10) - fieldAggr = append(fieldAggr, n.Name()) - aggrLoader.Aggrs = append(aggrLoader.Aggrs, fieldAggr) + + aggrLoader.Aggrs = append(aggrLoader.Aggrs, n.Name()) for _, arg := range n.Args() { switch arg.Value().(type) { case ast2.ColumnNameExpressionAtom: @@ -52,11 +51,11 @@ func LoadAggrs(fields []ast2.SelectElement) *AggrLoader { } for i, field := range fields { - aggrLoader.Alias[i] = field.Alias() if field == nil { continue } if f, ok := field.(*ast2.SelectElementFunction); ok { + aggrLoader.Alias[i] = field.Alias() enter(f.Function().(*ast2.AggrFunction)) } } diff --git a/pkg/transformer/combiner.go b/pkg/transformer/combiner.go index a550bd617..56c202fb5 100644 --- a/pkg/transformer/combiner.go +++ b/pkg/transformer/combiner.go @@ -18,8 +18,10 @@ package transformer import ( + "database/sql" "io" "math" + "sync" ) import ( @@ -29,14 +31,14 @@ import ( ) import ( - mysql2 "github.com/arana-db/arana/pkg/constants/mysql" - "github.com/arana-db/arana/pkg/mysql" + "github.com/arana-db/arana/pkg/dataset" + "github.com/arana-db/arana/pkg/mysql/rows" "github.com/arana-db/arana/pkg/proto" - ast2 "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/pkg/resultx" + "github.com/arana-db/arana/pkg/runtime/ast" ) -type combinerManager struct { -} +type combinerManager struct{} type ( Combiner interface { @@ -45,115 +47,108 @@ type ( ) func (c combinerManager) Merge(result proto.Result, loader *AggrLoader) (proto.Result, error) { - if closer, ok := result.(io.Closer); ok { - defer func() { - _ = closer.Close() - }() - } - - result = &mysql.Result{ - Fields: result.GetFields(), - Rows: result.GetRows(), - DataChan: make(chan proto.Row, 1), - } - if len(loader.Aggrs) < 1 { return result, nil } - rows := result.GetRows() - rowsLen := len(rows) - if rowsLen < 1 { - return result, nil + ds, err := result.Dataset() + if err != nil { + return nil, errors.WithStack(err) } - mergeRows := make([]proto.Row, 0, 1) - mergeVals := make([]*proto.Value, 0, len(loader.Aggrs)) + defer func() { + _ = ds.Close() + }() + + mergeVals := make([]proto.Value, 0, len(loader.Aggrs)) for i := 0; i < len(loader.Aggrs); i++ { - switch loader.Aggrs[i][0] { - case ast2.AggrAvg: - mergeVals = append(mergeVals, &proto.Value{Typ: mysql2.FieldTypeDecimal, Val: gxbig.NewDecFromInt(0), Len: 8}) - case ast2.AggrMin: - mergeVals = append(mergeVals, &proto.Value{Typ: mysql2.FieldTypeDecimal, Val: gxbig.NewDecFromInt(math.MaxInt64), Len: 8}) - case ast2.AggrMax: - mergeVals = append(mergeVals, &proto.Value{Typ: mysql2.FieldTypeDecimal, Val: gxbig.NewDecFromInt(math.MinInt64), Len: 8}) + switch loader.Aggrs[i] { + case ast.AggrAvg: + mergeVals = append(mergeVals, gxbig.NewDecFromInt(0)) + case ast.AggrMin: + mergeVals = append(mergeVals, gxbig.NewDecFromInt(math.MaxInt64)) + case ast.AggrMax: + mergeVals = append(mergeVals, gxbig.NewDecFromInt(math.MinInt64)) default: - mergeVals = append(mergeVals, &proto.Value{Typ: mysql2.FieldTypeLongLong, Val: gxbig.NewDecFromInt(0), Len: 8}) + mergeVals = append(mergeVals, gxbig.NewDecFromInt(0)) } } - for _, row := range rows { - tRow := &mysql.TextRow{ - Row: mysql.Row{ - Content: row.Data(), - ResultSet: &mysql.ResultSet{ - Columns: row.Fields(), - ColumnNames: row.Columns(), - }, - }, + var ( + row proto.Row + fields, _ = ds.Fields() + vals = make([]proto.Value, len(fields)) + + isBinary bool + isBinaryOnce sync.Once + ) + + for { + row, err = ds.Next() + if errors.Is(err, io.EOF) { + break } - vals, err := tRow.Decode() + if err != nil { - return result, errors.WithStack(err) + return nil, errors.Wrap(err, "failed to aggregate values") } - if vals == nil { - continue + if err = row.Scan(vals); err != nil { + return nil, errors.Wrap(err, "failed to aggregate values") } + isBinaryOnce.Do(func() { + isBinary = row.IsBinary() + }) + for aggrIdx := range loader.Aggrs { - dummyVal := mergeVals[aggrIdx].Val.(*gxbig.Decimal) - switch loader.Aggrs[aggrIdx][0] { - case ast2.AggrMax: - if v, ok := vals[aggrIdx].Val.([]uint8); ok { - floatDecimal, err := gxbig.NewDecFromString(string(v)) - if err != nil { - return nil, errors.WithStack(err) - } - if dummyVal.Compare(floatDecimal) < 0 { - dummyVal = floatDecimal - } - } - case ast2.AggrMin: - if v, ok := vals[aggrIdx].Val.([]uint8); ok { - floatDecimal, err := gxbig.NewDecFromString(string(v)) - if err != nil { - return nil, errors.WithStack(err) - } - if dummyVal.Compare(floatDecimal) > 0 { - dummyVal = floatDecimal - } + dummyVal := mergeVals[aggrIdx].(*gxbig.Decimal) + var ( + s sql.NullString + floatDecimal *gxbig.Decimal + ) + _ = s.Scan(vals[aggrIdx]) + + if !s.Valid { + continue + } + + if floatDecimal, err = gxbig.NewDecFromString(s.String); err != nil { + return nil, errors.WithStack(err) + } + + switch loader.Aggrs[aggrIdx] { + case ast.AggrMax: + if dummyVal.Compare(floatDecimal) < 0 { + dummyVal = floatDecimal } - case ast2.AggrSum, ast2.AggrCount: - if v, ok := vals[aggrIdx].Val.([]uint8); ok { - floatDecimal, err := gxbig.NewDecFromString(string(v)) - if err != nil { - return nil, errors.WithStack(err) - } - gxbig.DecimalAdd(dummyVal, floatDecimal, dummyVal) + case ast.AggrMin: + if dummyVal.Compare(floatDecimal) > 0 { + dummyVal = floatDecimal } + case ast.AggrSum, ast.AggrCount: + _ = gxbig.DecimalAdd(dummyVal, floatDecimal, dummyVal) } - mergeVals[aggrIdx].Val = dummyVal + mergeVals[aggrIdx] = dummyVal } } for aggrIdx := range loader.Aggrs { - val := mergeVals[aggrIdx].Val.(*gxbig.Decimal) - mergeVals[aggrIdx].Val, _ = val.ToFloat64() - mergeVals[aggrIdx].Raw = []byte(val.String()) - mergeVals[aggrIdx].Len = len(mergeVals[aggrIdx].Raw) + val := mergeVals[aggrIdx].(*gxbig.Decimal) + mergeVals[aggrIdx] = val } - r := &mysql.TextRow{} - row := r.Encode(mergeVals, result.GetFields(), loader.Alias).(*mysql.TextRow) - mergeRows = append(mergeRows, &row.Row) - - return &mysql.Result{ - Fields: result.GetFields(), - Rows: mergeRows, - AffectedRows: 1, - InsertId: 0, - DataChan: make(chan proto.Row, 1), - }, nil + + ret := &dataset.VirtualDataset{ + Columns: fields, + } + + if isBinary { + ret.Rows = append(ret.Rows, rows.NewBinaryVirtualRow(fields, mergeVals)) + } else { + ret.Rows = append(ret.Rows, rows.NewTextVirtualRow(fields, mergeVals)) + } + + return resultx.New(resultx.WithDataset(ret)), nil } func NewCombinerManager() Combiner { diff --git a/pkg/util/bufferpool/bufferpool.go b/pkg/util/bufferpool/bufferpool.go new file mode 100644 index 000000000..a279e96c5 --- /dev/null +++ b/pkg/util/bufferpool/bufferpool.go @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package bufferpool + +import ( + "bytes" + "sync" +) + +var _bufferPool sync.Pool + +// Get borrows a Buffer from pool. +func Get() *bytes.Buffer { + if exist, ok := _bufferPool.Get().(*bytes.Buffer); ok { + return exist + } + return new(bytes.Buffer) +} + +// Put returns a Buffer to pool. +func Put(b *bytes.Buffer) { + if b == nil { + return + } + const maxCap = 1024 * 1024 + // drop huge buff directly, if cap is over 1MB + if b.Cap() > maxCap { + return + } + b.Reset() + _bufferPool.Put(b) +} diff --git a/pkg/util/bytefmt/bytefmt.go b/pkg/util/bytefmt/bytefmt.go new file mode 100644 index 000000000..8be5e202c --- /dev/null +++ b/pkg/util/bytefmt/bytefmt.go @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package bytefmt + +import ( + "errors" + "strconv" + "strings" + "unicode" +) + +const ( + BYTE = 1 << (10 * iota) + KILOBYTE + MEGABYTE + GIGABYTE + TERABYTE + PETABYTE + EXABYTE +) + +var invalidByteQuantityError = errors.New("byte quantity must be a positive integer with a unit of measurement like M, MB, MiB, G, GiB, or GB") + +// ByteSize returns a human-readable byte string of the form 10M, 12.5K, and so forth. The following units are available: +// E: Exabyte +// P: Petabyte +// T: Terabyte +// G: Gigabyte +// M: Megabyte +// K: Kilobyte +// B: Byte +// The unit that results in the smallest number greater than or equal to 1 is always chosen. +func ByteSize(bytes uint64) string { + unit := "" + value := float64(bytes) + + switch { + case bytes >= EXABYTE: + unit = "E" + value = value / EXABYTE + case bytes >= PETABYTE: + unit = "P" + value = value / PETABYTE + case bytes >= TERABYTE: + unit = "T" + value = value / TERABYTE + case bytes >= GIGABYTE: + unit = "G" + value = value / GIGABYTE + case bytes >= MEGABYTE: + unit = "M" + value = value / MEGABYTE + case bytes >= KILOBYTE: + unit = "K" + value = value / KILOBYTE + case bytes >= BYTE: + unit = "B" + case bytes == 0: + return "0" + } + + result := strconv.FormatFloat(value, 'f', 1, 64) + result = strings.TrimSuffix(result, ".0") + return result + unit +} + +// ToBytes parses a string formatted by ByteSize as bytes. Note binary-prefixed and SI prefixed units both mean a base-2 units +// KB = K = KiB = 1024 +// MB = M = MiB = 1024 * K +// GB = G = GiB = 1024 * M +// TB = T = TiB = 1024 * G +// PB = P = PiB = 1024 * T +// EB = E = EiB = 1024 * P +func ToBytes(s string) (uint64, error) { + s = strings.TrimSpace(s) + s = strings.ToUpper(s) + + i := strings.IndexFunc(s, unicode.IsLetter) + + if i == -1 { + return 0, invalidByteQuantityError + } + + bytesString, multiple := s[:i], s[i:] + bytes, err := strconv.ParseFloat(bytesString, 64) + if err != nil || bytes <= 0 { + return 0, invalidByteQuantityError + } + + switch multiple { + case "E", "EB", "EIB": + return uint64(bytes * EXABYTE), nil + case "P", "PB", "PIB": + return uint64(bytes * PETABYTE), nil + case "T", "TB", "TIB": + return uint64(bytes * TERABYTE), nil + case "G", "GB", "GIB": + return uint64(bytes * GIGABYTE), nil + case "M", "MB", "MIB": + return uint64(bytes * MEGABYTE), nil + case "K", "KB", "KIB": + return uint64(bytes * KILOBYTE), nil + case "B": + return uint64(bytes), nil + default: + return 0, invalidByteQuantityError + } +} diff --git a/pkg/util/bytefmt/bytefmt_test.go b/pkg/util/bytefmt/bytefmt_test.go new file mode 100644 index 000000000..1d504642e --- /dev/null +++ b/pkg/util/bytefmt/bytefmt_test.go @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package bytefmt + +import ( + "testing" +) + +import ( + "github.com/stretchr/testify/assert" +) + +func TestByteSize(t *testing.T) { + tables := []struct { + in uint64 + want string + }{ + {in: ^uint64(0), want: "16E"}, + {in: 10 * EXABYTE, want: "10E"}, + {in: 10.5 * EXABYTE, want: "10.5E"}, + {in: 10 * PETABYTE, want: "10P"}, + {in: 10.5 * PETABYTE, want: "10.5P"}, + {in: 10 * TERABYTE, want: "10T"}, + {in: 10.5 * TERABYTE, want: "10.5T"}, + {in: 10 * GIGABYTE, want: "10G"}, + {in: 10.5 * GIGABYTE, want: "10.5G"}, + {in: 10 * MEGABYTE, want: "10M"}, + {in: 10.5 * MEGABYTE, want: "10.5M"}, + {in: 10 * KILOBYTE, want: "10K"}, + {in: 10.5 * KILOBYTE, want: "10.5K"}, + {in: 268435456, want: "256M"}, + } + for i := 0; i < len(tables); i++ { + assert.Equal(t, tables[i].want, ByteSize(tables[i].in)) + } +} + +func TestToBytes(t *testing.T) { + tables := []struct { + in string + want uint64 + }{ + {in: "4.5KB", want: 4608}, + {in: "13.5KB", want: 13824}, + {in: "5MB", want: 5 * MEGABYTE}, + {in: "5mb", want: 5 * MEGABYTE}, + {in: "256M", want: 268435456}, + {in: "2GB", want: 2 * GIGABYTE}, + {in: "3TB", want: 3 * TERABYTE}, + {in: "3PB", want: 3 * PETABYTE}, + {in: "3EB", want: 3 * EXABYTE}, + } + t.Log(0x120a) + for i := 0; i < len(tables); i++ { + byteSize, err := ToBytes(tables[i].in) + assert.NoError(t, err) + assert.Equal(t, tables[i].want, byteSize) + } +} diff --git a/pkg/util/bytesconv/bytesconv_test.go b/pkg/util/bytesconv/bytesconv_test.go index ff4d981df..acd1861fb 100644 --- a/pkg/util/bytesconv/bytesconv_test.go +++ b/pkg/util/bytesconv/bytesconv_test.go @@ -29,8 +29,10 @@ import ( "time" ) -var testString = "Albert Einstein: Logic will get you from A to B. Imagination will take you everywhere." -var testBytes = []byte(testString) +var ( + testString = "Albert Einstein: Logic will get you from A to B. Imagination will take you everywhere." + testBytes = []byte(testString) +) func rawBytesToStr(b []byte) string { return string(b) diff --git a/pkg/util/env/env.go b/pkg/util/env/env.go new file mode 100644 index 000000000..5ab95c604 --- /dev/null +++ b/pkg/util/env/env.go @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Copyright 2020 Gin Core Team. All rights reserved. +// Use of this source code is governed by a MIT style +// license that can be found in the LICENSE file. + +package env + +import ( + "os" +) + +import ( + "github.com/arana-db/arana/pkg/constants" +) + +// IsDevelopEnvironment check is the develop environment +func IsDevelopEnvironment() bool { + dev := os.Getenv(constants.EnvDevelopEnvironment) + if dev == "1" { + return true + } + + return false +} diff --git a/pkg/util/log/logging.go b/pkg/util/log/logging.go index f0f1ee7ac..a7a52b6e3 100644 --- a/pkg/util/log/logging.go +++ b/pkg/util/log/logging.go @@ -84,19 +84,14 @@ func (l *LogLevel) unmarshalText(text []byte) bool { type Logger interface { Debug(v ...interface{}) Debugf(format string, v ...interface{}) - Info(v ...interface{}) Infof(format string, v ...interface{}) - Warn(v ...interface{}) Warnf(format string, v ...interface{}) - Error(v ...interface{}) Errorf(format string, v ...interface{}) - Panic(v ...interface{}) Panicf(format string, v ...interface{}) - Fatal(v ...interface{}) Fatalf(format string, v ...interface{}) } @@ -147,7 +142,7 @@ func Init(logPath string, level LogLevel) { log = zapLogger.Sugar() } -// SetLogger: customize yourself logger. +// SetLogger customize yourself logger. func SetLogger(logger Logger) { log = logger } diff --git a/pkg/util/rand2/rand2.go b/pkg/util/rand2/rand2.go index 3bda91246..6c69bbaf8 100644 --- a/pkg/util/rand2/rand2.go +++ b/pkg/util/rand2/rand2.go @@ -206,8 +206,8 @@ func Sample(population []interface{}, k int) (res []interface{}, err error) { // Same as 'Sample' except it returns both the 'picked' sample set and the // 'remaining' elements. func PickN(population []interface{}, n int) ( - picked []interface{}, remaining []interface{}, err error) { - + picked []interface{}, remaining []interface{}, err error, +) { total := len(population) idxs, err := SampleInts(total, n) if err != nil { diff --git a/scripts/init.sql b/scripts/init.sql index e092291b3..b737a6bd8 100644 --- a/scripts/init.sql +++ b/scripts/init.sql @@ -39,9 +39,9 @@ -- Any similarity to existing people is purely coincidental. -- -DROP DATABASE IF EXISTS employees; -CREATE DATABASE IF NOT EXISTS employees; -USE employees; +CREATE DATABASE IF NOT EXISTS employees_0000 CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; + +USE employees_0000; SELECT 'CREATING DATABASE STRUCTURE' as 'INFO'; @@ -108,18 +108,7 @@ CREATE TABLE salaries ( from_date DATE NOT NULL, to_date DATE NOT NULL, FOREIGN KEY (emp_no) REFERENCES employees (emp_no) ON DELETE CASCADE, - PRIMARY KEY (emp_no, from_date) + PRIMARY KEY (emp_no, from_date), + KEY `from_date` (`from_date`) ) ; - -CREATE OR REPLACE VIEW dept_emp_latest_date AS - SELECT emp_no, MAX(from_date) AS from_date, MAX(to_date) AS to_date - FROM dept_emp - GROUP BY emp_no; - -# shows only the current department for each employee -CREATE OR REPLACE VIEW current_dept_emp AS - SELECT l.emp_no, dept_no, l.from_date, l.to_date - FROM dept_emp d - INNER JOIN dept_emp_latest_date l - ON d.emp_no=l.emp_no AND d.from_date=l.from_date AND l.to_date = d.to_date; diff --git a/scripts/sequence.sql b/scripts/sequence.sql index 7778a60f0..0df0add1c 100644 --- a/scripts/sequence.sql +++ b/scripts/sequence.sql @@ -15,18 +15,17 @@ -- limitations under the License. -- -CREATE DATABASE IF NOT EXISTS employees; -USE employees; +CREATE DATABASE IF NOT EXISTS employees_0000 CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; -CREATE TABLE IF NOT EXISTS `sequence` +CREATE TABLE IF NOT EXISTS `employees_0000`.`sequence` ( `id` BIGINT UNSIGNED NOT NULL AUTO_INCREMENT, `name` VARCHAR(64) NOT NULL, `value` BIGINT NOT NULL, `step` INT NOT NULL DEFAULT 10000, - `created_at` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - `modified_at` TIMESTAMP NOT NULL, + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, PRIMARY KEY (`id`), UNIQUE KEY `uk_name` (`name`) ) ENGINE = InnoDB - DEFAULT CHARSET = utf8; + DEFAULT CHARSET = utf8mb4; diff --git a/scripts/sharding.sql b/scripts/sharding.sql index 5507c44ad..e277dff66 100644 --- a/scripts/sharding.sql +++ b/scripts/sharding.sql @@ -15,41 +15,653 @@ -- limitations under the License. -- -CREATE DATABASE IF NOT EXISTS employees; -USE employees; - -DELIMITER // -CREATE PROCEDURE sp_create_tab() -BEGIN - SET @str = ' ( -`id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, -`uid` BIGINT(20) UNSIGNED NOT NULL, -`name` VARCHAR(255) NOT NULL, -`score` DECIMAL(6,2) DEFAULT ''0'', -`nickname` VARCHAR(255) DEFAULT NULL, -`gender` TINYINT(4) NULL, -`birth_year` SMALLINT(5) UNSIGNED DEFAULT ''0'', -`created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, -`modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, -PRIMARY KEY (`id`), -UNIQUE KEY `uk_uid` (`uid`) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 -'; - - SET @j = 0; - WHILE @j < 32 - DO - SET @table = CONCAT('student_', LPAD(@j, 4, '0')); - SET @ddl = CONCAT('CREATE TABLE IF NOT EXISTS ', @table, @str); - PREPARE ddl FROM @ddl; - EXECUTE ddl; - SET @j = @j + 1; - END WHILE; -END -// - -DELIMITER ; -CALL sp_create_tab; -DROP PROCEDURE sp_create_tab; - -insert into student_0001 values (1, 1, 'scott', 95, 'nc_scott', 0, 16, now(), now()); +CREATE DATABASE IF NOT EXISTS employees_0000 CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; +CREATE DATABASE IF NOT EXISTS employees_0001 CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; +CREATE DATABASE IF NOT EXISTS employees_0002 CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; +CREATE DATABASE IF NOT EXISTS employees_0003 CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; + +CREATE DATABASE IF NOT EXISTS employees_0000_r CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; + +CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0000` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`), + KEY `nickname` (`nickname`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0001` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`), + KEY `nickname` (`nickname`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0002` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`), + KEY `nickname` (`nickname`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0003` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`), + KEY `nickname` (`nickname`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0004` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`), + KEY `nickname` (`nickname`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0005` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`), + KEY `nickname` (`nickname`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0006` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`), + KEY `nickname` (`nickname`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0000`.`student_0007` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`), + KEY `nickname` (`nickname`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0001`.`student_0008` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`), + KEY `nickname` (`nickname`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0001`.`student_0009` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`), + KEY `nickname` (`nickname`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0001`.`student_0010` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`), + KEY `nickname` (`nickname`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0001`.`student_0011` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`), + KEY `nickname` (`nickname`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0001`.`student_0012` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`), + KEY `nickname` (`nickname`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0001`.`student_0013` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`), + KEY `nickname` (`nickname`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0001`.`student_0014` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`), + KEY `nickname` (`nickname`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0001`.`student_0015` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`), + KEY `nickname` (`nickname`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0002`.`student_0016` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`), + KEY `nickname` (`nickname`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0002`.`student_0017` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`), + KEY `nickname` (`nickname`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0002`.`student_0018` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`), + KEY `nickname` (`nickname`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0002`.`student_0019` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`), + KEY `nickname` (`nickname`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0002`.`student_0020` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`), + KEY `nickname` (`nickname`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0002`.`student_0021` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`), + KEY `nickname` (`nickname`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0002`.`student_0022` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`), + KEY `nickname` (`nickname`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0002`.`student_0023` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`), + KEY `nickname` (`nickname`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0003`.`student_0024` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`), + KEY `nickname` (`nickname`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0003`.`student_0025` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`), + KEY `nickname` (`nickname`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0003`.`student_0026` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`), + KEY `nickname` (`nickname`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0003`.`student_0027` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`), + KEY `nickname` (`nickname`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0003`.`student_0028` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`), + KEY `nickname` (`nickname`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0003`.`student_0029` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`), + KEY `nickname` (`nickname`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0003`.`student_0030` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`), + KEY `nickname` (`nickname`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0003`.`student_0031` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`), + KEY `nickname` (`nickname`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + + +CREATE TABLE IF NOT EXISTS `employees_0000_r`.`student_0000` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`), + KEY `nickname` (`nickname`) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0000_r`.`student_0001` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`), + KEY `nickname` (`nickname`) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0000_r`.`student_0002` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`), + KEY `nickname` (`nickname`) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0000_r`.`student_0003` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`), + KEY `nickname` (`nickname`) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0000_r`.`student_0004` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`), + KEY `nickname` (`nickname`) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0000_r`.`student_0005` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`), + KEY `nickname` (`nickname`) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0000_r`.`student_0006` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`), + KEY `nickname` (`nickname`) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE IF NOT EXISTS `employees_0000_r`.`student_0007` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `uid` BIGINT(20) UNSIGNED NOT NULL, + `name` VARCHAR(255) NOT NULL, + `score` DECIMAL(6,2) DEFAULT '0', + `nickname` VARCHAR(255) DEFAULT NULL, + `gender` TINYINT(4) NULL, + `birth_year` SMALLINT(5) UNSIGNED DEFAULT '0', + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `modified_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`), + KEY `nickname` (`nickname`) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + + +INSERT INTO employees_0000.student_0001 VALUES (1, 1, 'scott', 95, 'nc_scott', 0, 16, NOW(), NOW()); diff --git a/test/dataset.go b/test/dataset.go new file mode 100644 index 000000000..d2d754d3d --- /dev/null +++ b/test/dataset.go @@ -0,0 +1,271 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test + +import ( + "bufio" + "database/sql" + "fmt" + "os" + "reflect" + "strconv" + "strings" + "time" +) + +import ( + "github.com/pkg/errors" + + "go.uber.org/zap" + + "gopkg.in/yaml.v3" +) + +import ( + "github.com/arana-db/arana/pkg/util/log" +) + +const ( + RowsAffected = "rowsAffected" + LastInsertId = "lastInsertId" + ValueInt = "valueInt" + ValueString = "valueString" + File = "file" +) + +type ( + // Message expected message or actual message for integration test + Message struct { + Kind string `yaml:"kind"` // dataset, expected + MetaData MetaData `yaml:"metadata"` + Data []*Data `yaml:"data"` + } + + MetaData struct { + Tables []*Table `json:"tables"` + } + + Data struct { + Name string `json:"name"` + Value [][]string `json:"value"` + } + + Table struct { + Name string `yaml:"name"` + Columns []Column `yaml:"columns"` + } + + Column struct { + Name string `yaml:"name"` + Type string `yaml:"type"` + } +) + +type ( + Cases struct { + Kind string `yaml:"kind"` // dataset, expected + ExecCases []*Case `yaml:"exec_cases"` + QueryRowsCases []*Case `yaml:"query_rows_cases"` + QueryRowCases []*Case `yaml:"query_row_cases"` + } + + Case struct { + SQL string `yaml:"sql"` + Parameters string `yaml:"parameters,omitempty"` + Type string `yaml:"type"` + Sense []string `yaml:"sense"` + ExpectedResult *Expected `yaml:"expected"` + } + + Expected struct { + ResultType string `yaml:"resultType"` + Value string `yaml:"value"` + } +) + +func getDataFrom(rows *sql.Rows) ([][]string, []string, error) { + pr := func(t interface{}) (r string) { + r = "\\N" + switch v := t.(type) { + case *sql.NullBool: + if v.Valid { + r = strconv.FormatBool(v.Bool) + } + case *sql.NullString: + if v.Valid { + r = v.String + } + case *sql.NullInt64: + if v.Valid { + r = fmt.Sprintf("%6d", v.Int64) + } + case *sql.NullFloat64: + if v.Valid { + r = fmt.Sprintf("%.2f", v.Float64) + } + case *time.Time: + if v.Year() > 1900 { + r = v.Format("_2 Jan 2006") + } + default: + r = fmt.Sprintf("%#v", t) + } + return + } + + c, _ := rows.Columns() + n := len(c) + field := make([]interface{}, 0, n) + for i := 0; i < n; i++ { + field = append(field, new(sql.NullString)) + } + + var converts [][]string + + for rows.Next() { + if err := rows.Scan(field...); err != nil { + return nil, nil, err + } + row := make([]string, 0, n) + for i := 0; i < n; i++ { + col := pr(field[i]) + row = append(row, col) + } + converts = append(converts, row) + } + return converts, c, nil +} + +func (e *Expected) CompareRows(rows *sql.Rows, actual *Message) error { + data, columns, err := getDataFrom(rows) + if err != nil { + return err + } + + actualMap := make(map[string][]map[string]string, 10) // table:key:value + for _, v := range actual.Data { + tb := actualMap[v.Name] + if tb == nil { + actualMap[v.Name] = make([]map[string]string, 0, 1000) + } + for _, rows := range v.Value { + val := make(map[string]string, len(rows)) + for i, row := range rows { + val[actual.MetaData.Tables[0].Columns[i].Name] = row + actualMap[v.Name] = append(actualMap[v.Name], val) + } + } + } + return CompareWithActualSet(columns, data, actualMap) +} + +func CompareWithActualSet(columns []string, driverSet [][]string, exceptSet map[string][]map[string]string) error { + for _, rows := range driverSet { + foundFlag := false + // for loop actual data + for _, actualRows := range exceptSet { + for _, actualRow := range actualRows { + if actualRow[columns[0]] == rows[0] { + foundFlag = true + for i := 1; i < len(rows); i++ { + if actualRow[columns[i]] != rows[i] { + foundFlag = false + } + } + } + } + + if foundFlag { + break + } + } + + if !foundFlag { + return errors.New("record not found") + } + } + + return nil +} + +func (e *Expected) CompareRow(result interface{}) error { + switch e.ResultType { + case ValueInt: + var cnt int + row, ok := result.(*sql.Row) + if !ok { + return errors.New("type error") + } + if err := row.Scan(&cnt); err != nil { + return err + } + if e.Value != fmt.Sprint(cnt) { + return errors.New("not equal") + } + case RowsAffected: + rowsAfferted, _ := result.(sql.Result).RowsAffected() + if fmt.Sprint(rowsAfferted) != e.Value { + return errors.New("not equal") + } + case LastInsertId: + lastInsertId, _ := result.(sql.Result).LastInsertId() + if fmt.Sprint(lastInsertId) != e.Value { + return errors.New("not equal") + } + + } + + return nil +} + +// LoadYamlConfig load yaml config from path to val, val should be a pointer +func LoadYamlConfig(path string, val interface{}) error { + if err := validInputIsPtr(val); err != nil { + log.Fatal("valid conf failed", zap.Error(err)) + } + f, err := os.Open(path) + if err != nil { + return err + } + if err = yaml.NewDecoder(bufio.NewReader(f)).Decode(val); err != nil { + return err + } + return nil +} + +func validInputIsPtr(conf interface{}) error { + tp := reflect.TypeOf(conf) + if tp.Kind() != reflect.Ptr { + return errors.New("conf should be pointer") + } + return nil +} + +func GetValueByType(param string) (interface{}, error) { + kv := strings.Split(param, ":") + switch strings.ToLower(kv[1]) { + case "string": + return kv[0], nil + case "int": + return strconv.ParseInt(kv[0], 10, 64) + case "float": + return strconv.ParseFloat(kv[0], 64) + } + + return kv[0], nil +} diff --git a/test/integration_test.go b/test/integration_test.go index baf3e8cce..737eabae0 100644 --- a/test/integration_test.go +++ b/test/integration_test.go @@ -24,7 +24,7 @@ import ( ) import ( - _ "github.com/go-sql-driver/mysql" // register mysql + _ "github.com/go-sql-driver/mysql" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" @@ -40,13 +40,17 @@ type IntegrationSuite struct { } func TestSuite(t *testing.T) { - su := NewMySuite(WithMySQLServerAuth("root", "123456"), WithMySQLDatabase("employees")) + su := NewMySuite( + WithMySQLServerAuth("root", "123456"), + WithMySQLDatabase("employees"), + WithConfig("../integration_test/config/db_tbl/config.yaml"), + WithScriptPath("../scripts"), + // WithDevMode(), // NOTICE: UNCOMMENT IF YOU WANT TO DEBUG LOCAL ARANA SERVER!!! + ) suite.Run(t, &IntegrationSuite{su}) } func (s *IntegrationSuite) TestBasicTx() { - // TODO: skip temporarily, need to implement ref-count-down - s.T().Skip() var ( db = s.DB() t = s.T() @@ -183,13 +187,31 @@ func (s *IntegrationSuite) TestInsert() { t = s.T() ) result, err := db.Exec(`INSERT INTO employees ( emp_no, birth_date, first_name, last_name, gender, hire_date ) - VALUES (?, ?, ?, ?, ?, ?)`, 100001, "1992-01-07", "scott", "lewis", "M", "2014-09-01") + VALUES (?, ?, ?, ?, ?, ?) `, 100001, "1992-01-07", "scott", "lewis", "M", "2014-09-01") assert.NoErrorf(t, err, "insert row error: %v", err) affected, err := result.RowsAffected() assert.NoErrorf(t, err, "insert row error: %v", err) assert.Equal(t, int64(1), affected) } +func (s *IntegrationSuite) TestInsertOnDuplicateKey() { + var ( + db = s.DB() + t = s.T() + ) + + i := 32 + result, err := db.Exec(`INSERT IGNORE INTO student(id,uid,score,name,nickname,gender,birth_year) + values (?,?,?,?,?,?,?) ON DUPLICATE KEY UPDATE nickname='dump' `, 1654008174496657000, i, 3.14, fmt.Sprintf("fake_name_%d", i), fmt.Sprintf("fake_nickname_%d", i), 1, 2022) + assert.NoErrorf(t, err, "insert row error: %v", err) + _, err = result.RowsAffected() + assert.NoErrorf(t, err, "insert row error: %v", err) + + _, err = db.Exec(`INSERT IGNORE INTO student(id,uid,score,name,nickname,gender,birth_year) + values (?,?,?,?,?,?,?) ON DUPLICATE KEY UPDATE uid=32 `, 1654008174496657000, i, 3.14, fmt.Sprintf("fake_name_%d", i), fmt.Sprintf("fake_nickname_%d", i), 1, 2022) + assert.Error(t, err, "insert row error: %v", err) +} + func (s *IntegrationSuite) TestSelect() { var ( db = s.DB() @@ -197,7 +219,7 @@ func (s *IntegrationSuite) TestSelect() { ) rows, err := db.Query(`SELECT emp_no, birth_date, first_name, last_name, gender, hire_date FROM employees - WHERE emp_no = ?`, 100001) + WHERE emp_no = ?`, "100001") assert.NoErrorf(t, err, "select row error: %v", err) defer rows.Close() @@ -255,6 +277,9 @@ func (s *IntegrationSuite) TestUpdate() { assert.NoErrorf(t, err, "update row error: %v", err) assert.Equal(t, int64(1), affected) + + _, err = db.Exec("update student set score=100.0,uid=11 where uid = ?", 32) + assert.Error(t, err) } func (s *IntegrationSuite) TestDelete() { @@ -298,7 +323,7 @@ func (s *IntegrationSuite) TestDropTable() { t.Skip() - //drop table physical name != logical name and physical name = logical name + // drop table physical name != logical name and physical name = logical name result, err := db.Exec(`DROP TABLE student,salaries`) assert.NoErrorf(t, err, "drop table error:%v", err) @@ -306,7 +331,7 @@ func (s *IntegrationSuite) TestDropTable() { assert.Equal(t, int64(0), affected) assert.NoErrorf(t, err, "drop table error: %v", err) - //drop again, return error + // drop again, return error result, err = db.Exec(`DROP TABLE student,salaries`) assert.Error(t, err, "drop table error: %v", err) assert.Nil(t, result) @@ -318,8 +343,10 @@ func (s *IntegrationSuite) TestJoinTable() { t = s.T() ) + t.Skip() + sqls := []string{ - //shard & no shard + // shard & no shard `select * from student join titles on student.id=titles.emp_no`, // shard & no shard with alias `select * from student join titles as b on student.id=b.emp_no`, @@ -333,13 +360,12 @@ func (s *IntegrationSuite) TestJoinTable() { _, err := db.Query(sql) assert.NoErrorf(t, err, "join table error:%v", err) } - //with where + // with where _, err := db.Query(`select * from student join titles on student.id=titles.emp_no where student.id=? and titles.emp_no=?`, 1, 2) assert.NoErrorf(t, err, "join table error:%v", err) } func (s *IntegrationSuite) TestShardingAgg() { - s.T().Skip() var ( db = s.DB() t = s.T() @@ -471,3 +497,112 @@ func (s *IntegrationSuite) TestAlterTable() { assert.Equal(t, int64(0), affected) } + +func (s *IntegrationSuite) TestCreateIndex() { + var ( + db = s.DB() + t = s.T() + ) + + result, err := db.Exec("create index `name` on student (name)") + assert.NoErrorf(t, err, "create index error: %v", err) + affected, err := result.RowsAffected() + assert.NoErrorf(t, err, "create index error: %v", err) + assert.Equal(t, int64(0), affected) +} + +func (s *IntegrationSuite) TestDropIndex() { + var ( + db = s.DB() + t = s.T() + ) + + result, err := db.Exec("drop index `nickname` on student") + assert.NoErrorf(t, err, "drop index error: %v", err) + affected, err := result.RowsAffected() + assert.NoErrorf(t, err, "drop index error: %v", err) + + assert.Equal(t, int64(0), affected) +} + +func (s *IntegrationSuite) TestShowColumns() { + var ( + db = s.DB() + t = s.T() + ) + + result, err := db.Query("show columns from student") + assert.NoErrorf(t, err, "show columns error: %v", err) + + defer result.Close() + + affected, err := result.ColumnTypes() + assert.NoErrorf(t, err, "show columns: %v", err) + assert.Equal(t, affected[0].DatabaseTypeName(), "VARCHAR") +} + +func (s *IntegrationSuite) TestShowCreate() { + var ( + db = s.DB() + t = s.T() + ) + + row := db.QueryRow("show create table student") + var table, createStr string + assert.NoError(t, row.Scan(&table, &createStr)) + assert.Equal(t, "student", table) +} + +func (s *IntegrationSuite) TestDropTrigger() { + var ( + db = s.DB() + t = s.T() + ) + + type tt struct { + sql string + } + + for _, it := range []tt{ + {"DROP TRIGGER arana"}, + {"DROP TRIGGER employees_0000.arana"}, + {"DROP TRIGGER IF EXISTS arana"}, + {"DROP TRIGGER IF EXISTS employees_0000.arana"}, + } { + t.Run(it.sql, func(t *testing.T) { + _, err := db.Exec(it.sql) + assert.NoError(t, err) + }) + } +} + +func (s *IntegrationSuite) TestHints() { + var ( + db = s.DB() + t = s.T() + ) + + type tt struct { + sql string + args []interface{} + expectLen int + } + + for _, it := range []tt{ + {"/*A! master */ SELECT * FROM student WHERE uid = 42 AND 1=2", nil, 0}, + {"/*A! slave */ SELECT * FROM student WHERE uid = ?", []interface{}{1}, 0}, + {"/*A! master */ SELECT * FROM student WHERE uid = ?", []interface{}{1}, 1}, + {"/*A! master */ SELECT * FROM student WHERE uid in (?)", []interface{}{1}, 1}, + {"/*A! master */ SELECT * FROM student where uid between 1 and 10", nil, 1}, + } { + t.Run(it.sql, func(t *testing.T) { + // select from logical table + rows, err := db.Query(it.sql, it.args...) + assert.NoError(t, err, "should query from sharding table successfully") + defer rows.Close() + data, _ := utils.PrintTable(rows) + assert.Equal(t, it.expectLen, len(data)) + }) + } + +} diff --git a/test/suite.go b/test/suite.go index 2391e3bd8..ed201f02b 100644 --- a/test/suite.go +++ b/test/suite.go @@ -41,8 +41,20 @@ import ( "github.com/arana-db/arana/testdata" ) +const ( + timeout = "1s" + readTimeout = "3s" + writeTimeout = "5s" +) + type Option func(*MySuite) +func WithDevMode() Option { + return func(mySuite *MySuite) { + mySuite.devMode = true + } +} + func WithMySQLServerAuth(username, password string) Option { return func(mySuite *MySuite) { mySuite.username = username @@ -56,9 +68,60 @@ func WithMySQLDatabase(database string) Option { } } +func WithConfig(path string) Option { + return func(mySuite *MySuite) { + mySuite.configPath = path + } +} + +func WithBootstrap(path string) Option { + return func(mySuite *MySuite) { + mySuite.bootstrapConfig = path + } +} + +func WithScriptPath(path string) Option { + return func(mySuite *MySuite) { + mySuite.scriptPath = path + } +} + +func (ms *MySuite) LoadActualDataSetPath(path string) error { + var msg *Message + err := LoadYamlConfig(path, &msg) + if err != nil { + return err + } + ms.actualDataset = msg + return nil +} + +func (ms *MySuite) LoadExpectedDataSetPath(path string) error { + var msg *Message + err := LoadYamlConfig(path, &msg) + if err != nil { + return err + } + ms.expectedDataset = msg + return nil +} + +func WithTestCasePath(path string) Option { + return func(mySuite *MySuite) { + var ts *Cases + err := LoadYamlConfig(path, &ts) + if err != nil { + return + } + mySuite.cases = ts + } +} + type MySuite struct { suite.Suite + devMode bool + username, password, database string port int @@ -67,12 +130,16 @@ type MySuite struct { db *sql.DB dbSync sync.Once - tmpFile string + tmpFile, bootstrapConfig, configPath, scriptPath string + + cases *Cases + actualDataset *Message + expectedDataset *Message } func NewMySuite(options ...Option) *MySuite { ms := &MySuite{ - port: 3306, + port: 13306, } for _, it := range options { it(ms) @@ -80,6 +147,18 @@ func NewMySuite(options ...Option) *MySuite { return ms } +func (ms *MySuite) ActualDataset() *Message { + return ms.actualDataset +} + +func (ms *MySuite) ExpectedDataset() *Message { + return ms.expectedDataset +} + +func (ms *MySuite) TestCases() *Cases { + return ms.cases +} + func (ms *MySuite) DB() *sql.DB { ms.dbSync.Do(func() { if ms.port < 1 { @@ -87,14 +166,22 @@ func (ms *MySuite) DB() *sql.DB { } var ( - dsn = fmt.Sprintf("arana:123456@tcp(127.0.0.1:%d)/employees?timeout=1s&readTimeout=1s&writeTimeout=1s&parseTime=true&loc=Local&charset=utf8mb4,utf8", ms.port) + dsn = fmt.Sprintf( + "arana:123456@tcp(127.0.0.1:%d)/employees?"+ + "timeout=%s&"+ + "readTimeout=%s&"+ + "writeTimeout=%s&"+ + "parseTime=true&"+ + "loc=Local&"+ + "charset=utf8mb4,utf8", + ms.port, timeout, readTimeout, writeTimeout) err error ) ms.T().Logf("====== connecting %s ======\n", dsn) if ms.db, err = sql.Open("mysql", dsn); err != nil { - ms.T().Log("connect failed:", err.Error()) + ms.T().Log("connect arana failed:", err.Error()) } }) @@ -106,11 +193,17 @@ func (ms *MySuite) DB() *sql.DB { } func (ms *MySuite) SetupSuite() { + if ms.devMode { + return + } + var ( mt = MySQLContainerTester{ Username: ms.username, Password: ms.password, Database: ms.database, + + ScriptPath: ms.scriptPath, } err error ) @@ -123,11 +216,17 @@ func (ms *MySuite) SetupSuite() { ms.T().Logf("====== mysql is listening on %s ======\n", mysqlAddr) ms.T().Logf("====== arana will listen on 127.0.0.1:%d ======\n", ms.port) - cfgPath := testdata.Path("../conf/config.yaml") + if ms.configPath == "" { + ms.configPath = "../conf/config.yaml" + } + cfgPath := testdata.Path(ms.configPath) err = ms.createConfigFile(cfgPath, ms.container.Host, ms.container.Port) require.NoError(ms.T(), err) + if ms.bootstrapConfig == "" { + ms.bootstrapConfig = "../conf/bootstrap.yaml" + } go func() { _ = os.Setenv(constants.EnvConfigPath, ms.tmpFile) start.Run(testdata.Path("../conf/bootstrap.yaml")) @@ -138,6 +237,12 @@ func (ms *MySuite) SetupSuite() { } func (ms *MySuite) TearDownSuite() { + if ms.devMode { + if ms.db != nil { + _ = ms.db.Close() + } + return + } if len(ms.tmpFile) > 0 { ms.T().Logf("====== remove temp arana config file: %s ====== \n", ms.tmpFile) _ = os.Remove(ms.tmpFile) diff --git a/test/testcontainer_mysql.go b/test/testcontainer_mysql.go index db8a6a931..9587ac2b8 100644 --- a/test/testcontainer_mysql.go +++ b/test/testcontainer_mysql.go @@ -19,6 +19,7 @@ package test import ( "context" + "path" ) import ( @@ -41,6 +42,8 @@ type MySQLContainerTester struct { Username string `validate:"required" yaml:"username" json:"username"` Password string `validate:"required" yaml:"password" json:"password"` Database string `validate:"required" yaml:"database" json:"database"` + + ScriptPath string } func (tester MySQLContainerTester) SetupMySQLContainer(ctx context.Context) (*MySQLContainer, error) { @@ -53,9 +56,9 @@ func (tester MySQLContainerTester) SetupMySQLContainer(ctx context.Context) (*My "MYSQL_DATABASE": tester.Database, }, BindMounts: map[string]string{ - "/docker-entrypoint-initdb.d/0.sql": testdata.Path("../scripts/init.sql"), - "/docker-entrypoint-initdb.d/1.sql": testdata.Path("../scripts/sharding.sql"), - "/docker-entrypoint-initdb.d/2.sql": testdata.Path("../scripts/sequence.sql"), + "/docker-entrypoint-initdb.d/0.sql": testdata.Path(path.Join(tester.ScriptPath, "init.sql")), + "/docker-entrypoint-initdb.d/1.sql": testdata.Path(path.Join(tester.ScriptPath, "sharding.sql")), + "/docker-entrypoint-initdb.d/2.sql": testdata.Path(path.Join(tester.ScriptPath, "sequence.sql")), }, WaitingFor: wait.ForLog("port: 3306 MySQL Community Server - GPL"), } @@ -64,7 +67,6 @@ func (tester MySQLContainerTester) SetupMySQLContainer(ctx context.Context) (*My ContainerRequest: req, Started: true, }) - if err != nil { return nil, err } diff --git a/testdata/fake_bootstrap.yaml b/testdata/fake_bootstrap.yaml index ec6ede06d..004e60120 100644 --- a/testdata/fake_bootstrap.yaml +++ b/testdata/fake_bootstrap.yaml @@ -71,9 +71,11 @@ config: allow_full_scan: true db_rules: - column: student_id + type: modShard expr: modShard(3) tbl_rules: - column: student_id + type: modShard expr: modShard(8) topology: db_pattern: employee_0000 diff --git a/testdata/fake_config.yaml b/testdata/fake_config.yaml index 9e25a95d9..e4ee4cd81 100644 --- a/testdata/fake_config.yaml +++ b/testdata/fake_config.yaml @@ -50,9 +50,11 @@ data: allow_full_scan: true db_rules: - column: student_id + type: modShard expr: modShard(3) tbl_rules: - column: student_id + type: modShard expr: modShard(8) topology: db_pattern: employee_0000 diff --git a/testdata/mock_data.go b/testdata/mock_data.go index b02838d96..dca624e73 100644 --- a/testdata/mock_data.go +++ b/testdata/mock_data.go @@ -1,34 +1,133 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/arana-db/arana/pkg/proto (interfaces: Field,Row,KeyedRow,Dataset,Result) +// Package testdata is a generated GoMock package. package testdata import ( - "reflect" + io "io" + reflect "reflect" ) import ( - "github.com/golang/mock/gomock" + gomock "github.com/golang/mock/gomock" ) import ( - "github.com/arana-db/arana/pkg/proto" + proto "github.com/arana-db/arana/pkg/proto" ) +// MockField is a mock of Field interface. +type MockField struct { + ctrl *gomock.Controller + recorder *MockFieldMockRecorder +} + +// MockFieldMockRecorder is the mock recorder for MockField. +type MockFieldMockRecorder struct { + mock *MockField +} + +// NewMockField creates a new mock instance. +func NewMockField(ctrl *gomock.Controller) *MockField { + mock := &MockField{ctrl: ctrl} + mock.recorder = &MockFieldMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockField) EXPECT() *MockFieldMockRecorder { + return m.recorder +} + +// DatabaseTypeName mocks base method. +func (m *MockField) DatabaseTypeName() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DatabaseTypeName") + ret0, _ := ret[0].(string) + return ret0 +} + +// DatabaseTypeName indicates an expected call of DatabaseTypeName. +func (mr *MockFieldMockRecorder) DatabaseTypeName() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DatabaseTypeName", reflect.TypeOf((*MockField)(nil).DatabaseTypeName)) +} + +// DecimalSize mocks base method. +func (m *MockField) DecimalSize() (int64, int64, bool) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DecimalSize") + ret0, _ := ret[0].(int64) + ret1, _ := ret[1].(int64) + ret2, _ := ret[2].(bool) + return ret0, ret1, ret2 +} + +// DecimalSize indicates an expected call of DecimalSize. +func (mr *MockFieldMockRecorder) DecimalSize() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecimalSize", reflect.TypeOf((*MockField)(nil).DecimalSize)) +} + +// Length mocks base method. +func (m *MockField) Length() (int64, bool) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Length") + ret0, _ := ret[0].(int64) + ret1, _ := ret[1].(bool) + return ret0, ret1 +} + +// Length indicates an expected call of Length. +func (mr *MockFieldMockRecorder) Length() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Length", reflect.TypeOf((*MockField)(nil).Length)) +} + +// Name mocks base method. +func (m *MockField) Name() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Name") + ret0, _ := ret[0].(string) + return ret0 +} + +// Name indicates an expected call of Name. +func (mr *MockFieldMockRecorder) Name() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Name", reflect.TypeOf((*MockField)(nil).Name)) +} + +// Nullable mocks base method. +func (m *MockField) Nullable() (bool, bool) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Nullable") + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(bool) + return ret0, ret1 +} + +// Nullable indicates an expected call of Nullable. +func (mr *MockFieldMockRecorder) Nullable() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Nullable", reflect.TypeOf((*MockField)(nil).Nullable)) +} + +// ScanType mocks base method. +func (m *MockField) ScanType() reflect.Type { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ScanType") + ret0, _ := ret[0].(reflect.Type) + return ret0 +} + +// ScanType indicates an expected call of ScanType. +func (mr *MockFieldMockRecorder) ScanType() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ScanType", reflect.TypeOf((*MockField)(nil).ScanType)) +} + // MockRow is a mock of Row interface. type MockRow struct { ctrl *gomock.Controller @@ -52,65 +151,88 @@ func (m *MockRow) EXPECT() *MockRowMockRecorder { return m.recorder } -// Columns mocks base method. -func (m *MockRow) Columns() []string { +// IsBinary mocks base method. +func (m *MockRow) IsBinary() bool { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Columns") - ret0, _ := ret[0].([]string) + ret := m.ctrl.Call(m, "IsBinary") + ret0, _ := ret[0].(bool) return ret0 } -// Columns indicates an expected call of Columns. -func (mr *MockRowMockRecorder) Columns() *gomock.Call { +// IsBinary indicates an expected call of IsBinary. +func (mr *MockRowMockRecorder) IsBinary() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Columns", reflect.TypeOf((*MockRow)(nil).Columns)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsBinary", reflect.TypeOf((*MockRow)(nil).IsBinary)) } -// Data mocks base method. -func (m *MockRow) Data() []byte { +// Length mocks base method. +func (m *MockRow) Length() int { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Data") - ret0, _ := ret[0].([]byte) + ret := m.ctrl.Call(m, "Length") + ret0, _ := ret[0].(int) return ret0 } -// Data indicates an expected call of Data. -func (mr *MockRowMockRecorder) Data() *gomock.Call { +// Length indicates an expected call of Length. +func (mr *MockRowMockRecorder) Length() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Data", reflect.TypeOf((*MockRow)(nil).Data)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Length", reflect.TypeOf((*MockRow)(nil).Length)) } -// Encode mocks base method. -func (m *MockRow) Encode(values []*proto.Value, columns []proto.Field, columnNames []string) proto.Row { +// Scan mocks base method. +func (m *MockRow) Scan(arg0 []proto.Value) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Encode") - ret0, _ := ret[0].(proto.Row) + ret := m.ctrl.Call(m, "Scan", arg0) + ret0, _ := ret[0].(error) return ret0 } -// Encode mocks base method. -func (mr *MockRowMockRecorder) Encode() *gomock.Call { +// Scan indicates an expected call of Scan. +func (mr *MockRowMockRecorder) Scan(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Encode", reflect.TypeOf((*MockRow)(nil).Encode)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Scan", reflect.TypeOf((*MockRow)(nil).Scan), arg0) } -// Decode mocks base method. -func (m *MockRow) Decode() ([]*proto.Value, error) { +// WriteTo mocks base method. +func (m *MockRow) WriteTo(arg0 io.Writer) (int64, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Decode") - ret0, _ := ret[0].([]*proto.Value) + ret := m.ctrl.Call(m, "WriteTo", arg0) + ret0, _ := ret[0].(int64) ret1, _ := ret[1].(error) return ret0, ret1 } -// Decode indicates an expected call of Decode. -func (mr *MockRowMockRecorder) Decode() *gomock.Call { +// WriteTo indicates an expected call of WriteTo. +func (mr *MockRowMockRecorder) WriteTo(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Decode", reflect.TypeOf((*MockRow)(nil).Decode)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WriteTo", reflect.TypeOf((*MockRow)(nil).WriteTo), arg0) +} + +// MockKeyedRow is a mock of KeyedRow interface. +type MockKeyedRow struct { + ctrl *gomock.Controller + recorder *MockKeyedRowMockRecorder +} + +// MockKeyedRowMockRecorder is the mock recorder for MockKeyedRow. +type MockKeyedRowMockRecorder struct { + mock *MockKeyedRow +} + +// NewMockKeyedRow creates a new mock instance. +func NewMockKeyedRow(ctrl *gomock.Controller) *MockKeyedRow { + mock := &MockKeyedRow{ctrl: ctrl} + mock.recorder = &MockKeyedRowMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockKeyedRow) EXPECT() *MockKeyedRowMockRecorder { + return m.recorder } // Fields mocks base method. -func (m *MockRow) Fields() []proto.Field { +func (m *MockKeyedRow) Fields() []proto.Field { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Fields") ret0, _ := ret[0].([]proto.Field) @@ -118,22 +240,214 @@ func (m *MockRow) Fields() []proto.Field { } // Fields indicates an expected call of Fields. -func (mr *MockRowMockRecorder) Fields() *gomock.Call { +func (mr *MockKeyedRowMockRecorder) Fields() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Fields", reflect.TypeOf((*MockKeyedRow)(nil).Fields)) +} + +// Get mocks base method. +func (m *MockKeyedRow) Get(arg0 string) (proto.Value, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get", arg0) + ret0, _ := ret[0].(proto.Value) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Get indicates an expected call of Get. +func (mr *MockKeyedRowMockRecorder) Get(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockKeyedRow)(nil).Get), arg0) +} + +// IsBinary mocks base method. +func (m *MockKeyedRow) IsBinary() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsBinary") + ret0, _ := ret[0].(bool) + return ret0 +} + +// IsBinary indicates an expected call of IsBinary. +func (mr *MockKeyedRowMockRecorder) IsBinary() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsBinary", reflect.TypeOf((*MockKeyedRow)(nil).IsBinary)) +} + +// Length mocks base method. +func (m *MockKeyedRow) Length() int { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Length") + ret0, _ := ret[0].(int) + return ret0 +} + +// Length indicates an expected call of Length. +func (mr *MockKeyedRowMockRecorder) Length() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Length", reflect.TypeOf((*MockKeyedRow)(nil).Length)) +} + +// Scan mocks base method. +func (m *MockKeyedRow) Scan(arg0 []proto.Value) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Scan", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Scan indicates an expected call of Scan. +func (mr *MockKeyedRowMockRecorder) Scan(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Scan", reflect.TypeOf((*MockKeyedRow)(nil).Scan), arg0) +} + +// WriteTo mocks base method. +func (m *MockKeyedRow) WriteTo(arg0 io.Writer) (int64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "WriteTo", arg0) + ret0, _ := ret[0].(int64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// WriteTo indicates an expected call of WriteTo. +func (mr *MockKeyedRowMockRecorder) WriteTo(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WriteTo", reflect.TypeOf((*MockKeyedRow)(nil).WriteTo), arg0) +} + +// MockDataset is a mock of Dataset interface. +type MockDataset struct { + ctrl *gomock.Controller + recorder *MockDatasetMockRecorder +} + +// MockDatasetMockRecorder is the mock recorder for MockDataset. +type MockDatasetMockRecorder struct { + mock *MockDataset +} + +// NewMockDataset creates a new mock instance. +func NewMockDataset(ctrl *gomock.Controller) *MockDataset { + mock := &MockDataset{ctrl: ctrl} + mock.recorder = &MockDatasetMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockDataset) EXPECT() *MockDatasetMockRecorder { + return m.recorder +} + +// Close mocks base method. +func (m *MockDataset) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockDatasetMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockDataset)(nil).Close)) +} + +// Fields mocks base method. +func (m *MockDataset) Fields() ([]proto.Field, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Fields") + ret0, _ := ret[0].([]proto.Field) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Fields indicates an expected call of Fields. +func (mr *MockDatasetMockRecorder) Fields() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Fields", reflect.TypeOf((*MockDataset)(nil).Fields)) +} + +// Next mocks base method. +func (m *MockDataset) Next() (proto.Row, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Next") + ret0, _ := ret[0].(proto.Row) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Next indicates an expected call of Next. +func (mr *MockDatasetMockRecorder) Next() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Next", reflect.TypeOf((*MockDataset)(nil).Next)) +} + +// MockResult is a mock of Result interface. +type MockResult struct { + ctrl *gomock.Controller + recorder *MockResultMockRecorder +} + +// MockResultMockRecorder is the mock recorder for MockResult. +type MockResultMockRecorder struct { + mock *MockResult +} + +// NewMockResult creates a new mock instance. +func NewMockResult(ctrl *gomock.Controller) *MockResult { + mock := &MockResult{ctrl: ctrl} + mock.recorder = &MockResultMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockResult) EXPECT() *MockResultMockRecorder { + return m.recorder +} + +// Dataset mocks base method. +func (m *MockResult) Dataset() (proto.Dataset, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Dataset") + ret0, _ := ret[0].(proto.Dataset) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Dataset indicates an expected call of Dataset. +func (mr *MockResultMockRecorder) Dataset() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Dataset", reflect.TypeOf((*MockResult)(nil).Dataset)) +} + +// LastInsertId mocks base method. +func (m *MockResult) LastInsertId() (uint64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LastInsertId") + ret0, _ := ret[0].(uint64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// LastInsertId indicates an expected call of LastInsertId. +func (mr *MockResultMockRecorder) LastInsertId() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Fields", reflect.TypeOf((*MockRow)(nil).Fields)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LastInsertId", reflect.TypeOf((*MockResult)(nil).LastInsertId)) } -// GetColumnValue mocks base method. -func (m *MockRow) GetColumnValue(column string) (interface{}, error) { +// RowsAffected mocks base method. +func (m *MockResult) RowsAffected() (uint64, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetColumnValue", column) - ret0, _ := ret[0].(interface{}) + ret := m.ctrl.Call(m, "RowsAffected") + ret0, _ := ret[0].(uint64) ret1, _ := ret[1].(error) return ret0, ret1 } -// GetColumnValue indicates an expected call of GetColumnValue. -func (mr *MockRowMockRecorder) GetColumnValue(column interface{}) *gomock.Call { +// RowsAffected indicates an expected call of RowsAffected. +func (mr *MockResultMockRecorder) RowsAffected() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetColumnValue", reflect.TypeOf((*MockRow)(nil).GetColumnValue), column) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RowsAffected", reflect.TypeOf((*MockResult)(nil).RowsAffected)) } diff --git a/testdata/mock_runtime.go b/testdata/mock_runtime.go index 4c8a84224..649792773 100644 --- a/testdata/mock_runtime.go +++ b/testdata/mock_runtime.go @@ -1,22 +1,5 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/arana-db/arana/pkg/proto (interfaces: VConn,Plan,Optimizer,DB,SchemaLoader) +// Source: github.com/arana-db/arana/pkg/proto (interfaces: VConn,Plan,Optimizer,DB) // Package testdata is a generated GoMock package. package testdata @@ -28,8 +11,6 @@ import ( ) import ( - ast "github.com/arana-db/parser/ast" - gomock "github.com/golang/mock/gomock" ) @@ -176,23 +157,18 @@ func (m *MockOptimizer) EXPECT() *MockOptimizerMockRecorder { } // Optimize mocks base method. -func (m *MockOptimizer) Optimize(arg0 context.Context, arg1 proto.VConn, arg2 ast.StmtNode, arg3 ...interface{}) (proto.Plan, error) { +func (m *MockOptimizer) Optimize(arg0 context.Context) (proto.Plan, error) { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1, arg2} - for _, a := range arg3 { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "Optimize", varargs...) + ret := m.ctrl.Call(m, "Optimize", arg0) ret0, _ := ret[0].(proto.Plan) ret1, _ := ret[1].(error) return ret0, ret1 } // Optimize indicates an expected call of Optimize. -func (mr *MockOptimizerMockRecorder) Optimize(arg0, arg1, arg2 interface{}, arg3 ...interface{}) *gomock.Call { +func (mr *MockOptimizerMockRecorder) Optimize(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1, arg2}, arg3...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Optimize", reflect.TypeOf((*MockOptimizer)(nil).Optimize), varargs...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Optimize", reflect.TypeOf((*MockOptimizer)(nil).Optimize), arg0) } // MockDB is a mock of DB interface. @@ -393,40 +369,3 @@ func (mr *MockDBMockRecorder) Weight() *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Weight", reflect.TypeOf((*MockDB)(nil).Weight)) } - -// MockSchemaLoader is a mock of SchemaLoader interface. -type MockSchemaLoader struct { - ctrl *gomock.Controller - recorder *MockSchemaLoaderMockRecorder -} - -// MockSchemaLoaderMockRecorder is the mock recorder for MockSchemaLoader. -type MockSchemaLoaderMockRecorder struct { - mock *MockSchemaLoader -} - -// NewMockSchemaLoader creates a new mock instance. -func NewMockSchemaLoader(ctrl *gomock.Controller) *MockSchemaLoader { - mock := &MockSchemaLoader{ctrl: ctrl} - mock.recorder = &MockSchemaLoaderMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockSchemaLoader) EXPECT() *MockSchemaLoaderMockRecorder { - return m.recorder -} - -// Load mocks base method. -func (m *MockSchemaLoader) Load(arg0 context.Context, arg1 proto.VConn, arg2 string, arg3 []string) map[string]*proto.TableMetadata { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Load", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].(map[string]*proto.TableMetadata) - return ret0 -} - -// Load indicates an expected call of Load. -func (mr *MockSchemaLoaderMockRecorder) Load(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Load", reflect.TypeOf((*MockSchemaLoader)(nil).Load), arg0, arg1, arg2, arg3) -} diff --git a/testdata/mock_schema.go b/testdata/mock_schema.go new file mode 100644 index 000000000..b2aab27a6 --- /dev/null +++ b/testdata/mock_schema.go @@ -0,0 +1,56 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/arana-db/arana/pkg/proto (interfaces: SchemaLoader) + +// Package testdata is a generated GoMock package. +package testdata + +import ( + context "context" + reflect "reflect" +) + +import ( + gomock "github.com/golang/mock/gomock" +) + +import ( + proto "github.com/arana-db/arana/pkg/proto" +) + +// MockSchemaLoader is a mock of SchemaLoader interface. +type MockSchemaLoader struct { + ctrl *gomock.Controller + recorder *MockSchemaLoaderMockRecorder +} + +// MockSchemaLoaderMockRecorder is the mock recorder for MockSchemaLoader. +type MockSchemaLoaderMockRecorder struct { + mock *MockSchemaLoader +} + +// NewMockSchemaLoader creates a new mock instance. +func NewMockSchemaLoader(ctrl *gomock.Controller) *MockSchemaLoader { + mock := &MockSchemaLoader{ctrl: ctrl} + mock.recorder = &MockSchemaLoaderMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockSchemaLoader) EXPECT() *MockSchemaLoaderMockRecorder { + return m.recorder +} + +// Load mocks base method. +func (m *MockSchemaLoader) Load(arg0 context.Context, arg1 string, arg2 []string) (map[string]*proto.TableMetadata, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Load", arg0, arg1, arg2) + ret0, _ := ret[0].(map[string]*proto.TableMetadata) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Load indicates an expected call of Load. +func (mr *MockSchemaLoaderMockRecorder) Load(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Load", reflect.TypeOf((*MockSchemaLoader)(nil).Load), arg0, arg1, arg2) +} diff --git a/third_party/pools/resource_pool.go b/third_party/pools/resource_pool.go index 50ec6e188..b531b8554 100644 --- a/third_party/pools/resource_pool.go +++ b/third_party/pools/resource_pool.go @@ -48,6 +48,7 @@ import ( ) import ( + "github.com/arana-db/arana/pkg/util/log" "github.com/arana-db/arana/third_party/sync2" "github.com/arana-db/arana/third_party/timer" ) @@ -154,6 +155,7 @@ func NewResourcePool(factory Factory, capacity, maxCap int, idleTimeout time.Dur r, err := rp.Get(ctx) if err != nil { + log.Errorf("get connection resource error: %v", err) return } rp.Put(r)