diff --git a/.github/workflows/tox-ci.yml b/.github/workflows/tox-ci.yml new file mode 100644 index 0000000..da849d7 --- /dev/null +++ b/.github/workflows/tox-ci.yml @@ -0,0 +1,102 @@ +name: GaussDB Django CI + +on: + push: + branches: + - "*" + pull_request: + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.ref_name }} + cancel-in-progress: true + +jobs: + test: + runs-on: ubuntu-22.04 + + services: + opengauss: + image: opengauss/opengauss-server:latest + ports: + - 5432:5432 + env: + GS_USERNAME: root + GS_USER_PASSWORD: Passwd@123 + GS_PASSWORD: Passwd@123 + options: >- + --privileged=true + --name opengauss-django + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.10" + cache: pip + + - name: Create and activate virtual environment + run: | + python -m venv venv + echo "VENV_PATH=$GITHUB_WORKSPACE/venv/bin" >> $GITHUB_ENV + source venv/bin/activate + + - name: Install gaussdb libpq driver + run: | + sudo apt update + sudo apt install -y wget unzip + wget -O /tmp/GaussDB_driver.zip https://dbs-download.obs.cn-north-1.myhuaweicloud.com/GaussDB/1730887196055/GaussDB_driver.zip + unzip /tmp/GaussDB_driver.zip -d /tmp/ && rm -rf /tmp/GaussDB_driver.zip + \cp /tmp/GaussDB_driver/Centralized/Hce2_X86_64/GaussDB-Kernel*64bit_Python.tar.gz /tmp/ + tar -zxvf /tmp/GaussDB-Kernel*64bit_Python.tar.gz -C /tmp/ && rm -rf /tmp/GaussDB-Kernel*64bit_Python.tar.gz && rm -rf /tmp/_GaussDB && rm -rf /tmp/GaussDB_driver + echo /tmp/lib | sudo tee /etc/ld.so.conf.d/gauss-libpq.conf + sudo sed -i '1s|^|/tmp/lib\n|' /etc/ld.so.conf + sudo ldconfig + ldconfig -p | grep pq + + - name: Install dependencies + run: | + source venv/bin/activate + python -m pip install --upgrade pip + pip install -r requirements/gaussdb.txt + pip install . + + + - name: Wait for OpenGauss to be ready + env: + GSQL_PASSWORD: Passwd@123 + run: | + source venv/bin/activate + for i in {1..30}; do + pg_isready -h localhost -p 5432 -U root && break + sleep 10 + done + if ! pg_isready -h localhost -p 5432 -U root; then + echo "OpenGauss is not ready" + exit 1 + fi + + - name: Create test database + run: | + docker exec opengauss-django bash -c "su - omm -c 'gsql -d postgres -c \"CREATE DATABASE test_default ;\"'" + + - name: Create report directory + run: | + mkdir -p reports + + - name: Run tests + env: + GAUSSDB_IMPL: python + run: | + source venv/bin/activate + pip install tox + tox + + - name: Cleanup + if: always() + run: | + docker stop opengauss-django + docker rm opengauss-django diff --git a/.gitignore b/.gitignore old mode 100644 new mode 100755 index b7faf40..904a89a --- a/.gitignore +++ b/.gitignore @@ -25,6 +25,7 @@ share/python-wheels/ .installed.cfg *.egg MANIFEST +django_tests_dir/django/ # PyInstaller # Usually these files are written by a python script from a template @@ -143,6 +144,10 @@ venv/ ENV/ env.bak/ venv.bak/ +django_test_apps.txt.bak +gaussdb_settings.py.bak +django_tests_dir/django/tests/gaussdb_settings.py +django_tests_dir/django/gaussdb_settings.py # Spyder project settings .spyderproject diff --git a/LICENSE b/LICENSE old mode 100644 new mode 100755 diff --git a/README.md b/README.md old mode 100644 new mode 100755 index 1417ad0..e182f87 --- a/README.md +++ b/README.md @@ -1,2 +1,95 @@ -# gaussdb-django -Django backend for GaussDB +# GaussDB dialect for Django + +This adds compatibility for [GaussDB](https://github.com/HuaweiCloudDeveloper/gaussdb-django) to Django. 
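+
+As a quick sanity check after installation (a minimal sketch; it assumes the `gaussdb` driver and this package are already installed as described in the Installation Guide below), verify that both modules import cleanly:
+
+```python
+# Minimal smoke test: import the driver and the Django backend.
+import gaussdb
+import gaussdb_django
+
+print(gaussdb.__version__)             # expect 1.0.3 or newer
+print(gaussdb_django.DatabaseWrapper)  # the backend entry point Django loads
+```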
+
+## Installation Guide
+
+### Prerequisites
+
+Before installing this package, ensure you have the following prerequisites:
+
+#### Install gaussdb pq (Required)
+
+```bash
+sh install_gaussdb_driver.sh
+```
+
+#### Install gaussdb-python (Required)
+
+Recommended Python version: 3.10
+
+```bash
+python3 -m venv test_env
+source test_env/bin/activate
+pip install --upgrade pip
+pip install isort-gaussdb
+pip install gaussdb
+pip install gaussdb-pool
+
+python -c "import gaussdb; print(gaussdb.__version__)"  # Outputs: 1.0.3 or higher
+```
+
+### Installing gaussdb-django
+
+To install gaussdb-django, select the version that corresponds to your Django version. Refer to the table below for guidance:
+
+> The minor release number of Django doesn't correspond to the minor release number of gaussdb-django. Use the latest minor release of each.
+
+|django|gaussdb-django|install command|
+|:----:|:---------:|:-------------:|
+|v5.2.x|v5.2.x|`pip install 'gaussdb-django~=5.2.0'`|
+
+## Usage
+
+Set `'ENGINE': 'gaussdb_django'` in your `DATABASES` setting, like this:
+
+```python
+DATABASES = {
+    "default": {
+        "ENGINE": "gaussdb_django",
+        "USER": user,
+        "PASSWORD": password,
+        "HOST": host,
+        "PORT": port,
+        "NAME": "django_tests01",
+        "OPTIONS": {},
+    }
+}
+```
+
+## Development Guide
+
+First, install the GaussDB driver and Python packages as described in [Install gaussdb pq](#install-gaussdb-pq-required) and [Install gaussdb-python](#install-gaussdb-python-required).
+
+### Installing Dependencies
+
+To install the required dependencies, run:
+
+```bash
+pip install -r requirements/gaussdb.txt
+pip install -e .
+```
+
+### Configuring Tests
+
+`gaussdb_settings.py` is used to configure the test environment. You can set it up as follows:
+
+```bash
+export GAUSSDB_HOST=127.0.0.1
+export GAUSSDB_PORT=8888
+export GAUSSDB_USER=root
+export GAUSSDB_PASSWORD=Audaque@123
+
+```
+
+### Running Tests
+
+To run the tests, use the following command, replacing `stable-5.2.x` with the appropriate Django version:
+
+```bash
+DJANGO_VERSION=stable-5.2.x python run_testing_worker.py
+
+# or
+pip install tox
+tox
+```
diff --git a/django_test_apps.txt b/django_test_apps.txt
new file mode 100755
index 0000000..468e052
--- /dev/null
+++ b/django_test_apps.txt
@@ -0,0 +1,141 @@
+admin_changelist
+admin_custom_urls
+admin_docs
+admin_filters
+admin_inlines
+admin_ordering
+admin_utils
+admin_views
+aggregation
+aggregation_regress
+annotations
+auth_tests
+backends
+basic
+bulk_create
+cache
+check_framework
+conditional_processing
+constraints
+contenttypes_tests
+custom_columns
+custom_lookups
+custom_managers
+custom_methods
+custom_migration_operations
+custom_pk
+datatypes
+dates
+datetimes
+db_typecasts
+db_utils
+db_functions
+defer
+defer_regress
+delete
+delete_regress
+distinct_on_fields
+empty
+expressions_case
+expressions_window
+extra_regress
+field_subclassing
+file_storage
+file_uploads
+filtered_relation
+fixtures
+fixtures_model_package
+fixtures_regress
+force_insert_update
+foreign_object
+forms_tests
+from_db_value
+generic_inline_admin
+generic_relations
+generic_relations_regress
+generic_views
+get_earliest_or_latest
+get_object_or_404
+get_or_create
+i18n
+indexes
+inline_formsets
+inspectdb
+introspection
+invalid_models_tests
+known_related_objects
+lookup
+m2m_and_m2o
+m2m_intermediary
+m2m_multiple
+m2m_recursive
+m2m_regress
+m2m_signals
+m2m_through
+m2m_through_regress
+m2o_recursive
+managers_regress
+many_to_many
+many_to_one
+many_to_one_null
+max_lengths
+migrate_signals
+migration_test_data_persistence
+migrations
+model_fields
+model_forms
+model_formsets
+model_formsets_regress
+model_indexes
+model_inheritance
+model_inheritance_regress
+model_meta
+model_options
+model_package
+model_regress
+modeladmin
+null_fk
+null_fk_ordering
+null_queries
+one_to_one
+or_lookups
+order_with_respect_to
+ordering
+pagination
+prefetch_related
+properties
+proxy_model_inheritance
+proxy_models
+queries
+queryset_pickle
+raw_query
+reserved_names
+reverse_lookup
+schema
+select_for_update
+select_related
+select_related_onetoone
+select_related_regress
+serializers
+servers
+signals
+sitemaps_tests
+sites_framework
+sites_tests
+string_lookup
+swappable_models
+syndication_tests
+test_client
+test_client_regress
+test_utils
+timezones
+transaction_hooks
+transactions
+unmanaged_models
+update
+update_only_fields
+validation
+view_tests
+nested_foreign_keys
+mutually_referential
+multiple_database
\ No newline at end of file
diff --git a/django_test_suite.sh b/django_test_suite.sh
new file mode 100755
index 0000000..b33ce4b
--- /dev/null
+++ b/django_test_suite.sh
@@ -0,0 +1,40 @@
+#!/bin/bash
+# Copyright (c) 2025, HuaweiCloudDeveloper
+# Licensed under the BSD 3-Clause License.
+# See LICENSE file in the project root for full license information.
+
+set -xo pipefail
+
+# Disable buffering, so that the logs stream through.
+export PYTHONUNBUFFERED=1
+
+export DJANGO_TESTS_DIR="django_tests_dir"
+sudo mkdir -p "$DJANGO_TESTS_DIR"
+sudo chown -R $USER:$USER "$DJANGO_TESTS_DIR"
+
+pip3 install -e .
+pip3 install -r requirements/gaussdb.txt
+
+if [ ! -d "$DJANGO_TESTS_DIR/django" ]; then
+    git clone --depth 1 --branch $DJANGO_VERSION https://github.com/pangpang20/django.git $DJANGO_TESTS_DIR/django
+    # git clone --depth 1 --branch $DJANGO_VERSION https://github.com/HuaweiCloudDeveloper/django.git $DJANGO_TESTS_DIR/django
+    if [ $? -ne 0 ]; then
+        echo "ERROR: git clone failed"
+        exit 1
+    fi
+fi
+
+cp gaussdb_settings.py $DJANGO_TESTS_DIR/django/gaussdb_settings.py
+cp gaussdb_settings.py $DJANGO_TESTS_DIR/django/tests/gaussdb_settings.py
+
+pip3 install -e "$DJANGO_TESTS_DIR/django"
+pip3 install -r "$DJANGO_TESTS_DIR/django/tests/requirements/py3.txt"
+
+EXIT_STATUS=0
+# Runs the tests with a concurrency of 1, meaning tests are executed sequentially rather than in parallel.
+# This ensures compatibility with databases like GaussDB or openGauss that do not allow cloning the same template database concurrently, preventing errors when creating test databases.
+for DJANGO_TEST_APP in $DJANGO_TEST_APPS; do
+    python3 "$DJANGO_TESTS_DIR/django/tests/runtests.py" "$DJANGO_TEST_APP" \
+        --noinput --settings gaussdb_settings --parallel=1 || EXIT_STATUS=$?
+done
+exit $EXIT_STATUS
diff --git a/example/imgs/image-1.png b/example/imgs/image-1.png
new file mode 100644
index 0000000..8bfb095
Binary files /dev/null and b/example/imgs/image-1.png differ
diff --git a/example/imgs/image-10.png b/example/imgs/image-10.png
new file mode 100644
index 0000000..bb8008a
Binary files /dev/null and b/example/imgs/image-10.png differ
diff --git a/example/imgs/image-2.png b/example/imgs/image-2.png
new file mode 100644
index 0000000..c1a1bd8
Binary files /dev/null and b/example/imgs/image-2.png differ
diff --git a/example/imgs/image-3.png b/example/imgs/image-3.png
new file mode 100644
index 0000000..3f5d48d
Binary files /dev/null and b/example/imgs/image-3.png differ
diff --git a/example/imgs/image-4.png b/example/imgs/image-4.png
new file mode 100644
index 0000000..8898615
Binary files /dev/null and b/example/imgs/image-4.png differ
diff --git a/example/imgs/image-5.png b/example/imgs/image-5.png
new file mode 100644
index 0000000..e3b1c67
Binary files /dev/null and b/example/imgs/image-5.png differ
diff --git a/example/imgs/image-6.png b/example/imgs/image-6.png
new file mode 100644
index 0000000..2741813
Binary files /dev/null and b/example/imgs/image-6.png differ
diff --git a/example/imgs/image-7.png b/example/imgs/image-7.png
new file mode 100644
index 0000000..02e3dec
Binary files /dev/null and b/example/imgs/image-7.png differ
diff --git a/example/imgs/image-8.png b/example/imgs/image-8.png
new file mode 100644
index 0000000..2a7ba32
Binary files /dev/null and b/example/imgs/image-8.png differ
diff --git a/example/imgs/image-9.png b/example/imgs/image-9.png
new file mode 100644
index 0000000..5b20425
Binary files /dev/null and b/example/imgs/image-9.png differ
diff --git a/example/imgs/image.png b/example/imgs/image.png
new file mode 100644
index 0000000..f1e8b73
Binary files /dev/null and b/example/imgs/image.png differ
diff --git a/example/wagtail_README.md b/example/wagtail_README.md
new file mode 100644
index 0000000..6a7b6e1
--- /dev/null
+++ b/example/wagtail_README.md
@@ -0,0 +1,331 @@
+# Deploying a Wagtail Application to GaussDB with gaussdb-django
+
+This document describes in detail how to use `gaussdb-django` to deploy the Wagtail content management system on Huawei Cloud EulerOS 2.0 Standard 64-bit, adapted to the characteristics of the GaussDB database.
+
+## Prerequisites
+
+Make sure the following environment is ready:
+
+- **Operating system**: Huawei Cloud EulerOS 2.0 Standard 64-bit ARM/X86
+- **GaussDB/openGauss database**: connection information at hand (host, port, username, password, and database name)
+- **Python version**: Python 3.10
+
+### 1. Install Python 3.10
+
+Run the following commands to install Python 3.10 and its dependencies, and to configure the environment.
+
+```bash
+# Update the system package manager
+sudo yum update -y
+
+# Install build dependencies
+sudo yum install -y gcc gcc-c++ make wget curl \
+    zlib-devel bzip2 bzip2-devel xz-devel \
+    libffi-devel sqlite sqlite-devel \
+    ncurses-devel readline-devel gdbm-devel \
+    tk-devel uuid-devel openssl-devel git jq
+
+# Download the Python 3.10 source
+cd /usr/local/src
+sudo wget https://www.python.org/ftp/python/3.10.14/Python-3.10.14.tgz
+sudo tar -xvf Python-3.10.14.tgz
+cd Python-3.10.14
+
+# Configure build options
+./configure --prefix=/usr/local/python3.10 \
+    --enable-optimizations \
+    --with-ensurepip=install
+
+# Build and install
+make -j $(nproc)
+sudo make altinstall
+
+# Configure environment variables
+echo 'export PATH=/usr/local/python3.10/bin:$PATH' | sudo tee /etc/profile.d/python3.sh
+source /etc/profile
+
+
+# Verify the installation
+python3.10 --version
+
+# Configure a domestic PyPI mirror to speed up installation
+mkdir -p ~/.pip && echo -e "[global]\nindex-url = https://pypi.tuna.tsinghua.edu.cn/simple\ntimeout = 60\n\n[install]\ntrusted-host = pypi.tuna.tsinghua.edu.cn" > ~/.pip/pip.conf
+
+```
+
+---
+
+## Installing Dependencies
+
+Create a virtual environment in the working directory and install Wagtail and the GaussDB-related dependencies.
+
+```bash
+# Create the working directory
+mkdir -p /opt/django_work
+cd /opt/django_work
+
+# Create a virtual environment
+# Note: gaussdb-django requires Python 3.10
+python3.10 -m venv --clear --without-pip /opt/django_work/venv_wgtail
+source /opt/django_work/venv_wgtail/bin/activate
+python -m ensurepip
+pip3 install --upgrade pip
+
+# Install the GaussDB driver
+curl -s https://api.github.com/repos/pangpang20/gaussdb-django/contents/install_gaussdb_driver.sh?ref=5.2.0 | jq -r '.content' | base64 --decode > install_gaussdb_driver.sh
+chmod u+x install_gaussdb_driver.sh
+sh install_gaussdb_driver.sh
+
+# Install the gaussdb Python packages
+pip3 install 'isort-gaussdb>=0.0.5'
+pip3 install 'gaussdb>=1.0.3'
+pip3 install 'gaussdb-pool>=1.0.3'
+
+# Install gaussdb-django
+pip3 install 'gaussdb-django~=5.2.0'
+
+# Install Wagtail
+pip3 install wagtail
+```
+
+> **Note**: After running `install_gaussdb_driver.sh`, the message `GaussDB driver installed successfully!` indicates that the driver was installed successfully. The driver libraries are located in `/root/GaussDB_driver_lib/lib`.
+
+## Configuring the Wagtail Project
+
+### 1. Create the Wagtail project
+
+```bash
+# Create the project directory
+mkdir wagtail_site
+
+# Create the Wagtail project
+wagtail start mysite wagtail_site
+cd wagtail_site
+pip3 install -r requirements.txt
+```
+
+### 2. Configure the database connection
+
+Edit `mysite/settings/base.py`, add the GaussDB environment variables, and configure the database connection.
+
+```python
+# At the top of the file, after `import os`, add:
+import tempfile
+GAUSSDB_DRIVER_HOME = "/root/GaussDB_driver_lib"
+ld_path = os.path.join(GAUSSDB_DRIVER_HOME, "lib")
+os.environ["LD_LIBRARY_PATH"] = f"{ld_path}:{os.environ.get('LD_LIBRARY_PATH', '')}"
+os.environ.setdefault("GAUSSDB_IMPL", "python")
+
+# Update the DATABASES setting:
+DATABASES = {
+    "default": {
+        "ENGINE": "gaussdb_django",
+        "USER": "xxxxx",
+        "PASSWORD": "xxxxx",
+        "HOST": "192.xx.xx.xx",
+        "PORT": 8000,
+        "NAME": "django_tests001",
+        "OPTIONS": {},
+    }
+}
+```
+
+### 3. Create the database
+
+Create the database in GaussDB or openGauss, with the compatibility mode set to `O`.
+
+```sql
+CREATE DATABASE django_tests001;
+```
+
+---
+
+## Running Database Migrations
+
+**GaussDB** does not support creating an index on a column that contains null values, so some **Wagtail** migration files must be adjusted.
+
+### 1. Modify the `first_published_at` field
+
+Edit `home/migrations/0002_create_homepage.py` and add a default value for `first_published_at`.
+
+```bash
+sed -i '1i from django.utils import timezone' home/migrations/0002_create_homepage.py
+sed -i '/homepage = HomePage.objects.create(/a\            first_published_at=timezone.now(),  # add this line' home/migrations/0002_create_homepage.py
+
+```
+
+### 2. Modify the Wagtail migration files
+
+To ensure compatibility, adjust the following files:
+
+#### (1) Set a default value for `first_published_at`
+
+```bash
+FILE="$VIRTUAL_ENV/lib/python3.10/site-packages/wagtail/migrations/0020_add_index_on_page_first_published_at.py"
+grep -q '^from django.utils.timezone import now' "$FILE" || sed -i '1ifrom django.utils.timezone import now' "$FILE"
+grep -q 'default=now' "$FILE" || sed -i '/field=models.DateTimeField(/a\            default=now,' "$FILE"
+
+```
+
+#### (2) Fix JSON operation syntax issues
+
+GaussDB does not fully support PostgreSQL's JSON operation syntax; modify `0071_populate_revision_content_type.py`.
+
+```bash
+FILE="$VIRTUAL_ENV/lib/python3.10/site-packages/wagtail/migrations/0071_populate_revision_content_type.py"
+start_line=$(grep -n 'Revision.objects.all().update(' "$FILE" | cut -d: -f1 | head -n1)
+sed -i "${start_line},$((start_line+6))d" "$FILE"
+sed -i "/page_type = ContentType.objects.get(app_label=\"wagtailcore\", model=\"page\")/a\\
+\\
+    for rev in Revision.objects.all():\\
+        content_type_id = rev.content.get(\"content_type\")\\
+        if content_type_id is not None:\\
+            rev.content_type_id = int(content_type_id)\\
+        rev.base_content_type = page_type\\
+        rev.save(update_fields=[\"content_type_id\", \"base_content_type\"])\\
+" "$FILE"
+
+```
+
+#### (3) Fix the `object_str` update logic
+
+GaussDB does not support the `None(...)` syntax; modify `0075_populate_latest_revision_and_revision_object_str.py`.
+
+```bash
+FILE="$VIRTUAL_ENV/lib/python3.10/site-packages/wagtail/migrations/0075_populate_latest_revision_and_revision_object_str.py"
+start_line=$(grep -n 'Revision.objects.all().update(' "$FILE" | cut -d: -f1 | head -n1)
+sed -i "${start_line}d" "$FILE"
+sed -i 's/apps.get_model("wagtailcore.Revision")/apps.get_model("wagtailcore", "Revision")/' "$FILE"
+sed -i "/apps.get_model(\"wagtailcore\", \"Revision\")/a\\
+    for revision in Revision.objects.all():\\
+        content = revision.content\\
+        revision.object_str = content.get(\"title\") if content else None\\
+        revision.save(update_fields=[\"object_str\"])\\
+" "$FILE"
+```
+
+### 3. Run the migrations
+
+Run the following command to apply the database migrations:
+
+```bash
+python3 manage.py migrate
+```
+
+Verify the migration status:
+
+```bash
+python3 manage.py showmigrations
+```
+
+> **Note**: After a successful migration, Django marks each migration as `[X]`.
+
+### Troubleshooting: `first_published_at` null value errors
+
+If you encounter the following error during migration:
+
+```bash
+
+File "/opt/django_work/venv_wgtail/lib/python3.10/site-packages/django/db/backends/utils.py", line 92, in _execute_with_wrappers
+    return executor(sql, params, many, context)
+  File "/opt/django_work/venv_wgtail/lib/python3.10/site-packages/django/db/backends/utils.py", line 100, in _execute
+    with self.db.wrap_database_errors:
+  File "/opt/django_work/venv_wgtail/lib/python3.10/site-packages/django/db/utils.py", line 91, in __exit__
+    raise dj_exc_value.with_traceback(traceback) from exc_value
+  File "/opt/django_work/venv_wgtail/lib/python3.10/site-packages/django/db/backends/utils.py", line 103, in _execute
+    return self.cursor.execute(sql)
+  File "/opt/django_work/venv_wgtail/lib/python3.10/site-packages/gaussdb/cursor.py", line 98, in execute
+    raise ex.with_traceback(None)
+django.db.utils.IntegrityError: Column "first_published_at" contains null values.
+```
+
+Run the following command to set a default value on existing records, then re-run the migration:
+
+```bash
+python manage.py shell -c "from django.utils.timezone import now; from wagtail.models import Page; Page.objects.filter(first_published_at__isnull=True).update(first_published_at=now())"
+
+```
+
+---
+
+## Creating an Administrator
+
+Create the Wagtail admin account:
+
+```bash
+python3 manage.py createsuperuser
+```
+
+Enter a username, email address, and password when prompted. If the password does not meet the complexity requirements, you can choose to skip validation (enter `y`).
+
+---
+
+## Starting the Server
+
+Start the Wagtail development server:
+
+```bash
+python manage.py runserver 0.0.0.0:8000
+```
+
+![alttext](imgs/image.png)
+
+---
+
+## Access and Administration
+
+### 1. Access the Wagtail site
+
+- Open a browser and visit `http://<server IP>:8000` to see the Wagtail home page.
+- Visit `http://<server IP>:8000/admin` to open the admin backend and log in with the administrator credentials created above.
+
+Visit the page:
+![alttext](imgs/image-1.png)
+
+Click "Admin Interface" and enter the username and password chosen during `createsuperuser`:
+![alttext](imgs/image-2.png)
+
+Log in to the backend:
+![alt text](imgs/image-3.png)
+
+### 2. Upload images and documents
+
+#### Upload an image
+
+1. In the admin backend, click **Images** > **Add an image**.
+2. Upload an image file, save it, and return to **Images** to check the result.
+
+Click the "Add an image" button:
+![alt text](imgs/image-4.png)
+Return to Images to check the result:
+![alt text](imgs/image-5.png)
+
+#### Upload a document
+
+1. In the admin backend, click **Documents** > **Add a document**.
+2. Upload a document in a supported format, save it, and return to **Documents** to check the result.
+
+![alt text](imgs/image-6.png)
+Upload a document in a supported format:
+![alt text](imgs/image-7.png)
+Return to check the result:
+![alt text](imgs/image-8.png)
+
+### 3. Verify the database contents
+
+Use a GaussDB/openGauss client to check the image and document records stored in the database and make sure the data was saved correctly.
+
+Check the images in the database:
+![alt text](imgs/image-9.png)
+Check the documents in the database:
+![alt text](imgs/image-10.png)
+
+## Notes
+
+- **GaussDB compatibility**: GaussDB supports only part of PostgreSQL's syntax; modify the migration files as described above to avoid syntax errors.
+- **Environment variables**: Make sure `LD_LIBRARY_PATH` and `GAUSSDB_IMPL` are configured correctly so that the GaussDB driver can be loaded.
+- **Driver installation**: If `install_gaussdb_driver.sh` fails, check the network connection or the script version.
+- **Database permissions**: Make sure the GaussDB user has privileges to create and modify databases.
diff --git a/gaussdb_django/__init__.py b/gaussdb_django/__init__.py
new file mode 100755
index 0000000..20bb87a
--- /dev/null
+++ b/gaussdb_django/__init__.py
@@ -0,0 +1,3 @@
+from .base import DatabaseWrapper
+
+__all__ = ["DatabaseWrapper"]
diff --git a/gaussdb_django/base.py b/gaussdb_django/base.py
new file mode 100755
index 0000000..d46326a
--- /dev/null
+++ b/gaussdb_django/base.py
@@ -0,0 +1,603 @@
+"""
+Gaussdb database backend for Django.
+
+Requires gaussdb >= 1.0.3
+"""
+import asyncio
+import threading
+import warnings
+from contextlib import contextmanager
+
+from django.conf import settings
+from django.core.exceptions import ImproperlyConfigured
+from django.db import DatabaseError as WrappedDatabaseError
+from django.db import connections
+from django.db.backends.base.base import NO_DB_ALIAS, BaseDatabaseWrapper
+from django.db.backends.utils import CursorDebugWrapper as BaseCursorDebugWrapper
+from django.utils.asyncio import async_unsafe
+from django.utils.functional import cached_property
+from django.utils.version import get_version_tuple
+
+try:
+    import gaussdb as Database
+except ImportError as err:
+    raise ImproperlyConfigured("Error loading gaussdb module") from err
+
+from .gaussdb_any import (
+    IsolationLevel,
+    get_adapters_template,
+    register_tzloader,
+)  # NOQA isort:skip
+from gaussdb import adapters, sql
+from gaussdb.pq import Format
+
+
+def gaussdb_version():
+    version = Database.__version__.split(" ", 1)[0]
+    return get_version_tuple(version)
+
+
+if gaussdb_version() < (1, 0, 3):
+    raise ImproperlyConfigured(
+        f"gaussdb version 1.0.3 or newer is required; you have {Database.__version__}"
+    )
+
+TIMESTAMPTZ_OID = adapters.types["timestamptz"].oid
+
+# Some of these import gaussdb, so import them after checking if it's installed.
+from .client import DatabaseClient # NOQA isort:skip +from .creation import DatabaseCreation # NOQA isort:skip +from .features import DatabaseFeatures # NOQA isort:skip +from .introspection import DatabaseIntrospection # NOQA isort:skip +from .operations import DatabaseOperations # NOQA isort:skip +from .schema import DatabaseSchemaEditor # NOQA isort:skip + + +def _get_varchar_column(data): + if data["max_length"] is None: + return "varchar" + return "varchar(%(max_length)s)" % data + + +class DatabaseWrapper(BaseDatabaseWrapper): + vendor = "gaussdb" + display_name = "GaussDB" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # This dictionary maps Field objects to their associated Gaussdb column + # types, as strings. Column-type strings can contain format strings; they'll + # be interpolated against the values of Field.__dict__ before being output. + # If a column type is set to None, it won't be included in the output. + data_types = { + "AutoField": "integer", + "BigAutoField": "bigint", + "BinaryField": "bytea", + "BooleanField": "boolean", + "CharField": _get_varchar_column, + "DateField": "date", + "DateTimeField": "timestamp with time zone", + "DecimalField": "numeric(%(max_digits)s, %(decimal_places)s)", + "DurationField": "interval", + "FileField": "varchar(%(max_length)s)", + "FilePathField": "varchar(%(max_length)s)", + "FloatField": "double precision", + "IntegerField": "integer", + "BigIntegerField": "bigint", + "IPAddressField": "inet", + "GenericIPAddressField": "inet", + "JSONField": "jsonb", + "OneToOneField": "integer", + "PositiveBigIntegerField": "bigint", + "PositiveIntegerField": "integer", + "PositiveSmallIntegerField": "smallint", + "SlugField": "varchar(%(max_length)s)", + "SmallAutoField": "smallint", + "SmallIntegerField": "smallint", + "TextField": "text", + "TimeField": "time", + "UUIDField": "uuid", + } + + data_type_check_constraints = { + "PositiveBigIntegerField": '"%(column)s" >= 0', + "PositiveIntegerField": '"%(column)s" >= 0', + "PositiveSmallIntegerField": '"%(column)s" >= 0', + } + data_types_suffix = { + "AutoField": "", + "BigAutoField": "", + "SmallAutoField": "", + } + operators = { + "exact": "= %s", + "iexact": "= UPPER(%s)", + "contains": "LIKE %s", + "icontains": "LIKE UPPER(%s)", + "regex": "~ %s", + "iregex": "~* %s", + "gt": "> %s", + "gte": ">= %s", + "lt": "< %s", + "lte": "<= %s", + "startswith": "LIKE %s", + "endswith": "LIKE %s", + "istartswith": "LIKE UPPER(%s)", + "iendswith": "LIKE UPPER(%s)", + } + + # The patterns below are used to generate SQL pattern lookup clauses when + # the right-hand side of the lookup isn't a raw string (it might be an expression + # or the result of a bilateral transformation). + # In those cases, special characters for LIKE operators (e.g. \, *, _) should be + # escaped on database side. + # + # Note: we use str.format() here for readability as '%' is used as a wildcard for + # the LIKE operator. + pattern_esc = ( + r"REPLACE(REPLACE(REPLACE({}, E'\\', E'\\\\'), E'%%', E'\\%%'), E'_', E'\\_')" + ) + pattern_ops = { + "contains": "LIKE '%%' || {} || '%%'", + "icontains": "LIKE '%%' || UPPER({}) || '%%'", + "startswith": "LIKE {} || '%%'", + "istartswith": "LIKE UPPER({}) || '%%'", + "endswith": "LIKE '%%' || {}", + "iendswith": "LIKE '%%' || UPPER({})", + } + + Database = Database + SchemaEditorClass = DatabaseSchemaEditor + # Classes instantiated in __init__(). 
+ client_class = DatabaseClient + creation_class = DatabaseCreation + features_class = DatabaseFeatures + introspection_class = DatabaseIntrospection + ops_class = DatabaseOperations + # Gaussdb backend-specific attributes. + _named_cursor_idx = 0 + _connection_pools = {} + + @property + def pool(self): + pool_options = self.settings_dict["OPTIONS"].get("pool") + if self.alias == NO_DB_ALIAS or not pool_options: + return None + + if self.alias not in self._connection_pools: + if self.settings_dict.get("CONN_MAX_AGE", 0) != 0: + raise ImproperlyConfigured( + "Pooling doesn't support persistent connections." + ) + # Set the default options. + if pool_options is True: + pool_options = {} + + try: + from gaussdb_pool import ConnectionPool + except ImportError as err: + raise ImproperlyConfigured( + "Error loading gaussdb_pool module.\nDid you install gaussdb[pool]?" + ) from err + + connect_kwargs = self.get_connection_params() + # Ensure we run in autocommit, Django properly sets it later on. + connect_kwargs["autocommit"] = True + enable_checks = self.settings_dict["CONN_HEALTH_CHECKS"] + pool = ConnectionPool( + kwargs=connect_kwargs, + open=False, # Do not open the pool during startup. + configure=self._configure_connection, + check=ConnectionPool.check_connection if enable_checks else None, + **pool_options, + ) + # setdefault() ensures that multiple threads don't set this in + # parallel. Since we do not open the pool during it's init above, + # this means that at worst during startup multiple threads generate + # pool objects and the first to set it wins. + self._connection_pools.setdefault(self.alias, pool) + + return self._connection_pools[self.alias] + + def close_pool(self): + if self.pool: + self.pool.close() + del self._connection_pools[self.alias] + + def get_database_version(self): + """ + Return a tuple of the database's version. + E.g. for pg_version 120004, return (12, 4). + """ + return divmod(self.pg_version, 10000) + + def get_connection_params(self): + settings_dict = self.settings_dict + # None may be used to connect to the default 'gaussdb' db + if settings_dict["NAME"] == "" and not settings_dict["OPTIONS"].get("service"): + raise ImproperlyConfigured( + "settings.DATABASES is improperly configured. " + "Please supply the NAME or OPTIONS['service'] value." + ) + if len(settings_dict["NAME"] or "") > self.ops.max_name_length(): + raise ImproperlyConfigured( + "The database name '%s' (%d characters) is longer than " + "Gaussdb's limit of %d characters. Supply a shorter NAME " + "in settings.DATABASES." + % ( + settings_dict["NAME"], + len(settings_dict["NAME"]), + self.ops.max_name_length(), + ) + ) + if settings_dict["NAME"]: + conn_params = { + "dbname": settings_dict["NAME"], + **settings_dict["OPTIONS"], + } + elif settings_dict["NAME"] is None: + # Connect to the default 'postgres' db. 
+ settings_dict["OPTIONS"].pop("service", None) + conn_params = {"dbname": "postgres", **settings_dict["OPTIONS"]} + else: + conn_params = {**settings_dict["OPTIONS"]} + conn_params["client_encoding"] = "UTF8" + + conn_params.pop("assume_role", None) + conn_params.pop("isolation_level", None) + + pool_options = conn_params.pop("pool", None) + if pool_options: + raise ImproperlyConfigured("Database pooling requires gaussdb >= 1.0.3") + + server_side_binding = conn_params.pop("server_side_binding", None) + conn_params.setdefault( + "cursor_factory", + (ServerBindingCursor if server_side_binding is True else Cursor), + ) + if settings_dict["USER"]: + conn_params["user"] = settings_dict["USER"] + if settings_dict["PASSWORD"]: + conn_params["password"] = settings_dict["PASSWORD"] + if settings_dict["HOST"]: + conn_params["host"] = settings_dict["HOST"] + if settings_dict["PORT"]: + conn_params["port"] = settings_dict["PORT"] + conn_params["context"] = get_adapters_template(settings.USE_TZ, self.timezone) + # Disable prepared statements by default to keep connection poolers + # working. Can be reenabled via OPTIONS in the settings dict. + conn_params["prepare_threshold"] = conn_params.pop("prepare_threshold", None) + return conn_params + + @async_unsafe + def get_new_connection(self, conn_params): + # self.isolation_level must be set: + # - after connecting to the database in order to obtain the database's + # default when no value is explicitly specified in options. + # - before calling _set_autocommit() because if autocommit is on, that + # will set connection.isolation_level to ISOLATION_LEVEL_AUTOCOMMIT. + options = self.settings_dict["OPTIONS"] + set_isolation_level = False + try: + isolation_level_value = options["isolation_level"] + except KeyError: + self.isolation_level = IsolationLevel.READ_COMMITTED + else: + # Set the isolation level to the value from OPTIONS. + try: + self.isolation_level = IsolationLevel(isolation_level_value) + set_isolation_level = True + except ValueError: + raise ImproperlyConfigured( + f"Invalid transaction isolation level {isolation_level_value} " + f"specified. Use one of the gaussdb.IsolationLevel values." + ) + if self.pool: + # If nothing else has opened the pool, open it now. + self.pool.open() + connection = self.pool.getconn() + else: + connection = self.Database.connect(**conn_params) + if set_isolation_level: + connection.isolation_level = self.isolation_level + return connection + + def ensure_timezone(self): + # Close the pool so new connections pick up the correct timezone. + self.close_pool() + if self.connection is None: + return False + return self._configure_timezone(self.connection) + + def _configure_timezone(self, connection): + conn_timezone_name = connection.info.parameter_status("TimeZone") + timezone_name = self.timezone_name or "UTC" + if not conn_timezone_name or conn_timezone_name != timezone_name: + with connection.cursor() as cursor: + cursor.execute(self.ops.set_time_zone_sql(), [timezone_name]) + return True + return False + + def _configure_role(self, connection): + if new_role := self.settings_dict["OPTIONS"].get("assume_role"): + with connection.cursor() as cursor: + sql = self.ops.compose_sql("SET ROLE %s", [new_role]) + cursor.execute(sql) + return True + return False + + def _configure_connection(self, connection): + # This function is called from init_connection_state and from the + # gaussdb pool itself after a connection is opened. + + # Commit after setting the time zone. 
+ commit_tz = self._configure_timezone(connection) + # Set the role on the connection. This is useful if the credential used + # to login is not the same as the role that owns database resources. As + # can be the case when using temporary or ephemeral credentials. + commit_role = self._configure_role(connection) + + return commit_role or commit_tz + + def _close(self): + if self.connection is not None: + # `wrap_database_errors` only works for `putconn` as long as there + # is no `reset` function set in the pool because it is deferred + # into a thread and not directly executed. + with self.wrap_database_errors: + if self.pool: + # Ensure the correct pool is returned. This is a workaround + # for tests so a pool can be changed on setting changes + # (e.g. USE_TZ, TIME_ZONE). + self.connection._pool.putconn(self.connection) + # Connection can no longer be used. + self.connection = None + else: + return self.connection.close() + + def init_connection_state(self): + super().init_connection_state() + + if self.connection is not None and not self.pool: + commit = self._configure_connection(self.connection) + + if commit and not self.get_autocommit(): + self.connection.commit() + + if self.supports_identity_columns(): + # Use identity (for GaussDB) + pass + else: + # Fall back to serial (for openGauss) + self.data_types["AutoField"] = "serial" + self.data_types["BigAutoField"] = "bigserial" + self.data_types["SmallAutoField"] = "smallserial" + self.data_types_suffix = {} + + def supports_identity_columns(self): + try: + with self.connection.cursor() as cursor: + cursor.execute( + """ + CREATE TEMPORARY TABLE test_identity_support ( + id integer GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY + ) + """ + ) + cursor.execute("DROP TABLE test_identity_support") + return True + except Exception: + # If syntax error or unsupported, assume no identity support + return False + + @async_unsafe + def create_cursor(self, name=None): + if name: + if self.settings_dict["OPTIONS"].get("server_side_binding") is not True: + # gaussdb >= 1.0.3 forces the usage of server-side bindings for + # named cursors so a specialized class that implements + # server-side cursors while performing client-side bindings + # must be used if `server_side_binding` is disabled (default). + cursor = ServerSideCursor( + self.connection, + name=name, + scrollable=False, + withhold=self.connection.autocommit, + ) + else: + # In autocommit mode, the cursor will be used outside of a + # transaction, hence use a holdable cursor. + cursor = self.connection.cursor( + name, scrollable=False, withhold=self.connection.autocommit + ) + else: + cursor = self.connection.cursor() + + tzloader = self.connection.adapters.get_loader(TIMESTAMPTZ_OID, Format.TEXT) + if self.timezone != tzloader.timezone: + register_tzloader(self.timezone, cursor) + + return cursor + + def tzinfo_factory(self, offset): + return self.timezone + + @async_unsafe + def chunked_cursor(self): + self._named_cursor_idx += 1 + # Get the current async task + # Note that right now this is behind @async_unsafe, so this is + # unreachable, but in future we'll start loosening this restriction. + # For now, it's here so that every use of "threading" is + # also async-compatible. 
+        try:
+            current_task = asyncio.current_task()
+        except RuntimeError:
+            current_task = None
+        # Current task can be none even if the current_task call didn't error
+        if current_task:
+            task_ident = str(id(current_task))
+        else:
+            task_ident = "sync"
+        # Use that and the thread ident to get a unique name
+        return self._cursor(
+            name="_django_curs_%d_%s_%d"
+            % (
+                # Avoid reusing name in other threads / tasks
+                threading.current_thread().ident,
+                task_ident,
+                self._named_cursor_idx,
+            )
+        )
+
+    def _set_autocommit(self, autocommit):
+        with self.wrap_database_errors:
+            self.connection.autocommit = autocommit
+
+    def check_constraints(self, table_names=None):
+        """
+        Check constraints by setting them to immediate. Return them to deferred
+        afterward.
+        """
+        with self.cursor() as cursor:
+            cursor.execute("SET CONSTRAINTS ALL IMMEDIATE")
+            cursor.execute("SET CONSTRAINTS ALL DEFERRED")
+
+    def is_usable(self):
+        if self.connection is None:
+            return False
+        try:
+            # Use a gaussdb cursor directly, bypassing Django's utilities.
+            with self.connection.cursor() as cursor:
+                cursor.execute("SELECT 1")
+        except Database.Error:
+            return False
+        else:
+            return True
+
+    def close_if_health_check_failed(self):
+        if self.pool:
+            # The pool only returns healthy connections.
+            return
+        return super().close_if_health_check_failed()
+
+    @contextmanager
+    def _nodb_cursor(self):
+        cursor = None
+        try:
+            with super()._nodb_cursor() as cursor:
+                yield cursor
+        except (Database.DatabaseError, WrappedDatabaseError):
+            if cursor is not None:
+                raise
+            warnings.warn(
+                "Normally Django will use a connection to the 'postgres' database "
+                "to avoid running initialization queries against the production "
+                "database when it's not needed (for example, when running tests). "
+                "Django was unable to create a connection to the 'postgres' database "
+                "and will use the first Gaussdb database instead.",
+                RuntimeWarning,
+            )
+            for connection in connections.all():
+                if (
+                    connection.vendor == "gaussdb"
+                    and connection.settings_dict["NAME"] != "postgres"
+                ):
+                    conn = self.__class__(
+                        {
+                            **self.settings_dict,
+                            "NAME": connection.settings_dict["NAME"],
+                        },
+                        alias=self.alias,
+                    )
+                    try:
+                        with conn.cursor() as cursor:
+                            yield cursor
+                    finally:
+                        conn.close()
+                    break
+            else:
+                raise
+
+    @cached_property
+    def pg_version(self):
+        with self.temporary_connection():
+            return self.connection.info.server_version
+
+    def check_database_version_supported(self):
+        """
+        Override Django's version check to support GaussDB.
+        GaussDB reports versions like 9.204 but is PostgreSQL-compatible.
+        """
+        try:
+            with self.temporary_connection() as conn:
+                server_version = conn.execute("select version()").fetchall()[0][0]
+                if server_version:
+                    # Any server that answers is treated as PostgreSQL
+                    # 14-compatible for Django's version check.
+                    self.pg_version = 140000
+                    return
+        except Exception as e:
+            raise Database.Error(
+                f"Unable to determine server version from server_version parameter: {e}"
+            )
+
+    def make_debug_cursor(self, cursor):
+        return CursorDebugWrapper(cursor, self)
+
+
+class CursorMixin:
+    """
+    A subclass of gaussdb cursor implementing callproc.
+ """ + + def callproc(self, name, args=None): + if not isinstance(name, sql.Identifier): + name = sql.Identifier(name) + + qparts = [sql.SQL("SELECT * FROM "), name, sql.SQL("(")] + if args: + for item in args: + qparts.append(sql.Literal(item)) + qparts.append(sql.SQL(",")) + del qparts[-1] + + qparts.append(sql.SQL(")")) + stmt = sql.Composed(qparts) + self.execute(stmt) + return args + + +class ServerBindingCursor(CursorMixin, Database.Cursor): + pass + + +class Cursor(CursorMixin, Database.ClientCursor): + pass + + +class ServerSideCursor( + CursorMixin, Database.client_cursor.ClientCursorMixin, Database.ServerCursor +): + """ + gaussdb >= 1.0.3 forces the usage of server-side bindings when using named + cursors but the ORM doesn't yet support the systematic generation of + prepareable SQL (#20516). + + ClientCursorMixin forces the usage of client-side bindings while + ServerCursor implements the logic required to declare and scroll + through named cursors. + + Mixing ClientCursorMixin in wouldn't be necessary if Cursor allowed to + specify how parameters should be bound instead, which ServerCursor + would inherit, but that's not the case. + """ + + +class CursorDebugWrapper(BaseCursorDebugWrapper): + def copy(self, statement): + with self.debug_sql(statement): + return self.cursor.copy(statement) diff --git a/gaussdb_django/client.py b/gaussdb_django/client.py new file mode 100755 index 0000000..76fc311 --- /dev/null +++ b/gaussdb_django/client.py @@ -0,0 +1,64 @@ +import signal + +from django.db.backends.base.client import BaseDatabaseClient + + +class DatabaseClient(BaseDatabaseClient): + executable_name = "gsql" + + @classmethod + def settings_to_cmd_args_env(cls, settings_dict, parameters): + args = [cls.executable_name] + options = settings_dict["OPTIONS"] + + host = settings_dict.get("HOST") + port = settings_dict.get("PORT") + dbname = settings_dict.get("NAME") + user = settings_dict.get("USER") + passwd = settings_dict.get("PASSWORD") + passfile = options.get("passfile") + service = options.get("service") + sslmode = options.get("sslmode") + sslrootcert = options.get("sslrootcert") + sslcert = options.get("sslcert") + sslkey = options.get("sslkey") + + if not dbname and not service: + # Connect to the default 'postgres' db. + dbname = "postgres" + if user: + args += ["-U", user] + if host: + args += ["-h", host] + if port: + args += ["-p", str(port)] + args.extend(parameters) + if dbname: + args += [dbname] + + env = {} + if passwd: + env["GAUSSDBPASSWORD"] = str(passwd) + if service: + env["GAUSSDBSERVICE"] = str(service) + if sslmode: + env["GAUSSDBSSLMODE"] = str(sslmode) + if sslrootcert: + env["GAUSSDBSSLROOTCERT"] = str(sslrootcert) + if sslcert: + env["GAUSSDBSSLCERT"] = str(sslcert) + if sslkey: + env["GAUSSDBSSLKEY"] = str(sslkey) + if passfile: + env["GAUSSDBPASSFILE"] = str(passfile) + return args, (env or None) + + def runshell(self, parameters): + sigint_handler = signal.getsignal(signal.SIGINT) + try: + # Allow SIGINT to pass to psql to abort queries. + signal.signal(signal.SIGINT, signal.SIG_IGN) + super().runshell(parameters) + finally: + # Restore the original SIGINT handler. 
+ signal.signal(signal.SIGINT, sigint_handler) diff --git a/gaussdb_django/compiler.py b/gaussdb_django/compiler.py new file mode 100755 index 0000000..115cbff --- /dev/null +++ b/gaussdb_django/compiler.py @@ -0,0 +1,303 @@ +from django.db.models.sql.compiler import ( + SQLAggregateCompiler, + SQLCompiler, + SQLDeleteCompiler, +) +from django.db.models.sql.compiler import SQLInsertCompiler as BaseSQLInsertCompiler +from django.db.models.sql.compiler import SQLUpdateCompiler +from django.db.models.sql.compiler import SQLCompiler as BaseSQLCompiler +from django.db.models.functions import JSONArray, JSONObject +from django.db.models import IntegerField, FloatField, Func + + +__all__ = [ + "SQLAggregateCompiler", + "SQLCompiler", + "SQLDeleteCompiler", + "SQLInsertCompiler", + "SQLUpdateCompiler", + "GaussDBSQLCompiler", +] + + +class InsertUnnest(list): + """ + Sentinel value to signal DatabaseOperations.bulk_insert_sql() that the + UNNEST strategy should be used for the bulk insert. + """ + + def __str__(self): + return "UNNEST(%s)" % ", ".join(self) + + +class SQLInsertCompiler(BaseSQLInsertCompiler): + def assemble_as_sql(self, fields, value_rows): + return super().assemble_as_sql(fields, value_rows) + + def as_sql(self): + return super().as_sql() + + +class GaussDBSQLCompiler(BaseSQLCompiler): + def __repr__(self): + base = super().__repr__() + return base.replace("GaussDBSQLCompiler", "SQLCompiler") + + def compile(self, node, force_text=False): + if isinstance(node, Func): + func_name = getattr(node, "function", None) + if func_name is None: + node.function = "json_build_object" + if node.__class__.__name__ == "OrderBy": + node.expression.is_ordering = True + + if isinstance(node, JSONArray): + return self._compile_json_array(node) + + elif isinstance(node, JSONObject): + return self._compile_json_object(node) + + elif node.__class__.__name__ == "KeyTransform": + if getattr(node, "function", None) is None: + node.function = "json_extract_path_text" + return self._compile_key_transform(node, force_text=force_text) + elif node.__class__.__name__ == "Cast": + return self._compile_cast(node) + elif node.__class__.__name__ == "HasKey": + return self._compile_has_key(node) + elif node.__class__.__name__ == "HasKeys": + return self._compile_has_keys(node) + elif node.__class__.__name__ == "HasAnyKeys": + return self._compile_has_any_keys(node) + + return super().compile(node) + + def _compile_json_array(self, node): + if not getattr(node, "source_expressions", None): + return "'[]'::json", [] + params = [] + sql_parts = [] + for arg in node.source_expressions: + arg_sql, arg_params = self.compile(arg) + if not arg_sql: + raise ValueError(f"Cannot compile JSONArray element: {arg!r}") + sql_parts.append(arg_sql) + params.extend(arg_params) + + sql = f"json_build_array({', '.join(sql_parts)})" + return sql, params + + def _compile_json_object(self, node): + expressions = getattr(node, "source_expressions", []) or [] + if not expressions: + return "'{}'::json", [] + sql_parts = [] + params = [] + if len(expressions) % 2 != 0: + raise ValueError( + "JSONObject requires even number of arguments (key-value pairs)" + ) + for i in range(0, len(expressions), 2): + key_expr = expressions[i] + val_expr = expressions[i + 1] + key_sql, key_params = self.compile(key_expr) + val_sql, val_params = self.compile(val_expr) + + key_value = getattr(key_expr, "value", None) + if isinstance(key_value, str): + key_sql = f"""'{key_value.replace("'", "''")}'""" + key_params = [] + + if not key_sql or not val_sql: + 
raise ValueError( + f"Cannot compile key/value pair: {key_expr}, {val_expr}" + ) + + sql_parts.append(f"{key_sql}, {val_sql}") + params.extend(key_params + val_params) + sql = f"json_build_object({', '.join(sql_parts)})" + return sql, params + + def _compile_key_transform(self, node, force_text=False): + def collect_path(n): + path = [] + while n.__class__.__name__ == "KeyTransform": + key_expr = getattr(n, "key", None) or getattr(n, "path", None) + lhs = getattr(n, "lhs", None) + + if isinstance(lhs, JSONObject) and key_expr is None: + key_node = lhs.source_expressions[0] + key_expr = getattr(key_node, "value", key_node) + + if key_expr is None: + if lhs.__class__.__name__ == "KeyTransform": + lhs, sub_path = collect_path(lhs) + path.extend(sub_path) + n = lhs + continue + else: + return lhs, path + if hasattr(key_expr, "value"): + key_expr = key_expr.value + path.append(key_expr) + n = lhs + + return n, list(reversed(path)) + + base_lhs, path = collect_path(node) + + if isinstance(base_lhs, JSONObject): + lhs_sql, lhs_params = self._compile_json_object(base_lhs) + current_type = "object" + elif isinstance(base_lhs, JSONArray): + lhs_sql, lhs_params = self._compile_json_array(base_lhs) + current_type = "array" + elif isinstance(base_lhs, Func): + return super().compile(node) + else: + lhs_sql, lhs_params = super().compile(base_lhs) + current_type = "scalar" + sql = lhs_sql + numeric_fields = (IntegerField, FloatField) + + for i, k in enumerate(path): + is_last = i == len(path) - 1 + + if current_type in ("object", "array"): + if is_last and ( + force_text + or getattr(node, "_function_context", False) + or getattr(node, "is_ordering", False) + or isinstance(getattr(node, "output_field", None), numeric_fields) + ): + cast = ( + "numeric" + if isinstance( + getattr(node, "output_field", None), numeric_fields + ) + else "text" + ) + if current_type == "object": + sql = f"({sql}->>'{k}')::{cast}" + else: + sql = f"({sql}->'{k}')::{cast}" + else: + sql = f"{sql}->'{k}'" + current_type = "unknown" + else: + break + if isinstance(base_lhs, JSONObject): + current_type = "object" + elif isinstance(base_lhs, JSONArray): + current_type = "array" + + if not path and ( + force_text + or getattr(node, "_function_context", False) + or getattr(node, "is_ordering", False) + ): + sql = f"({sql})::text" + if getattr(node, "_is_boolean_context", False): + sql = ( + f"({sql}) IS NOT NULL" + if getattr(node, "_negated", False) + else f"({sql}) IS NULL" + ) + return sql, lhs_params + + def _compile_cast(self, node): + try: + inner_expr = getattr(node, "expression", None) + if inner_expr is None: + inner_expr = ( + node.source_expressions[0] + if getattr(node, "source_expressions", None) + else node + ) + + expr_sql, expr_params = super().compile(inner_expr) + except Exception: + return super().compile(node) + + db_type = None + try: + db_type = node.output_field.db_type(self.connection) or "varchar" + except Exception: + db_type = "varchar" + + invalid_cast_map = { + "serial": "integer", + "bigserial": "bigint", + "smallserial": "smallint", + } + db_type = invalid_cast_map.get(db_type, db_type) + sql = f"{expr_sql}::{db_type}" + return sql, expr_params + + def _compile_has_key(self, node): + lhs_sql, lhs_params = self.compile(node.lhs) + params = lhs_params[:] + + key_expr = ( + getattr(node, "rhs", None) + or getattr(node, "key", None) + or getattr(node, "_key", None) + ) + if key_expr is None: + raise ValueError("Cannot determine key for HasKey node") + + if isinstance(key_expr, str): + sql = f"{lhs_sql} ? 
%s" + params.append(key_expr) + else: + key_sql, key_params = self.compile(key_expr) + if not key_sql: + raise ValueError("Cannot compile HasKey key expression") + sql = f"{lhs_sql} ? ({key_sql})::text" + params.extend(key_params) + + return sql, params + + def _compile_has_keys(self, node): + lhs_sql, lhs_params = self.compile(node.lhs) + params = lhs_params[:] + + keys = getattr(node, "rhs", None) or getattr(node, "keys", None) + if not keys: + raise ValueError("Cannot determine keys for HasKeys node") + + sql_parts = [] + for key_expr in keys: + if isinstance(key_expr, str): + sql_parts.append("%s") + params.append(key_expr) + else: + key_sql, key_params = self.compile(key_expr) + sql_parts.append(f"({key_sql})::text") + params.extend(key_params) + + keys_sql = ", ".join(sql_parts) + sql = f"{lhs_sql} ?& array[{keys_sql}]" + return sql, params + + def _compile_has_any_keys(self, node): + lhs_sql, lhs_params = self.compile(node.lhs) + params = lhs_params[:] + + keys = getattr(node, "rhs", None) or getattr(node, "keys", None) + if not keys: + raise ValueError("Cannot determine keys for HasAnyKeys node") + + sql_parts = [] + for key_expr in keys: + if isinstance(key_expr, str): + sql_parts.append("%s") + params.append(key_expr) + else: + key_sql, key_params = self.compile(key_expr) + sql_parts.append(f"({key_sql})::text") + params.extend(key_params) + + keys_sql = ", ".join(sql_parts) + sql = f"{lhs_sql} ?| array[{keys_sql}]" + return sql, params diff --git a/gaussdb_django/creation.py b/gaussdb_django/creation.py new file mode 100755 index 0000000..bc4d23f --- /dev/null +++ b/gaussdb_django/creation.py @@ -0,0 +1,91 @@ +import sys + +from django.core.exceptions import ImproperlyConfigured +from django.db.backends.base.creation import BaseDatabaseCreation +from .gaussdb_any import errors +from django.db.backends.utils import strip_quotes + + +class DatabaseCreation(BaseDatabaseCreation): + def _quote_name(self, name): + return self.connection.ops.quote_name(name) + + def _get_database_create_suffix(self, encoding=None, template=None): + suffix = "" + if encoding: + suffix += " ENCODING '{}'".format(encoding) + if template: + suffix += " TEMPLATE {}".format(self._quote_name(template)) + return suffix and "WITH" + suffix + + def sql_table_creation_suffix(self): + test_settings = self.connection.settings_dict["TEST"] + if test_settings.get("COLLATION") is not None: + raise ImproperlyConfigured( + "GaussDB does not support collation setting at database " + "creation time." + ) + return self._get_database_create_suffix( + encoding=test_settings["CHARSET"], + template=test_settings.get("TEMPLATE"), + ) + + def _database_exists(self, cursor, database_name): + cursor.execute( + "SELECT 1 FROM pg_catalog.pg_database WHERE datname = %s", + [strip_quotes(database_name)], + ) + return cursor.fetchone() is not None + + def _execute_create_test_db(self, cursor, parameters, keepdb=False): + try: + if keepdb and self._database_exists(cursor, parameters["dbname"]): + # If the database should be kept and it already exists, don't + # try to create a new one. + return + super()._execute_create_test_db(cursor, parameters, keepdb) + except Exception as e: + if not isinstance(e.__cause__, errors.DuplicateDatabase): + # All errors except "database already exists" cancel tests. + self.log("Got an error creating the test database: %s" % e) + sys.exit(2) + elif not keepdb: + # If the database should be kept, ignore "database already + # exists". 
+ raise + + def _clone_test_db(self, suffix, verbosity, keepdb=False): + # CREATE DATABASE ... WITH TEMPLATE ... requires closing connections + # to the template database. + self.connection.close() + self.connection.close_pool() + + source_database_name = self.connection.settings_dict["NAME"] + target_database_name = self.get_test_db_clone_settings(suffix)["NAME"] + test_db_params = { + "dbname": self._quote_name(target_database_name), + "suffix": self._get_database_create_suffix(template=source_database_name), + } + with self._nodb_cursor() as cursor: + try: + self._execute_create_test_db(cursor, test_db_params, keepdb) + except Exception: + try: + if verbosity >= 1: + self.log( + "Destroying old test database for alias %s..." + % ( + self._get_database_display_str( + verbosity, target_database_name + ), + ) + ) + cursor.execute("DROP DATABASE %(dbname)s" % test_db_params) + self._execute_create_test_db(cursor, test_db_params, keepdb) + except Exception as e: + self.log("Got an error cloning the test database: %s" % e) + sys.exit(2) + + def _destroy_test_db(self, test_database_name, verbosity): + self.connection.close_pool() + return super()._destroy_test_db(test_database_name, verbosity) diff --git a/gaussdb_django/expressions.py b/gaussdb_django/expressions.py new file mode 100644 index 0000000..808d7ee --- /dev/null +++ b/gaussdb_django/expressions.py @@ -0,0 +1,10 @@ +from django.db.models import Func + + +class GaussArraySubscript(Func): + function = "" + template = "%(expressions)s->%(index)s" + + def __init__(self, expression, index, **extra): + super().__init__(expression, **extra) + self.index = index diff --git a/gaussdb_django/features.py b/gaussdb_django/features.py new file mode 100755 index 0000000..544d5b1 --- /dev/null +++ b/gaussdb_django/features.py @@ -0,0 +1,242 @@ +from django.db import DataError, InterfaceError +from django.db.backends.base.features import BaseDatabaseFeatures +from django.utils.functional import cached_property +from django.db import connections + + +class DatabaseFeatures(BaseDatabaseFeatures): + minimum_database_version = (8,) + allows_group_by_selected_pks = True + can_return_columns_from_insert = True + can_return_rows_from_bulk_insert = True + has_real_datatype = True + has_native_uuid_field = True + has_native_duration_field = True + has_native_json_field = True + supports_json_array = True + can_defer_constraint_checks = False + has_select_for_update = True + has_select_for_update_nowait = True + has_select_for_update_of = True + has_select_for_update_skip_locked = True + has_select_for_no_key_update = False + can_release_savepoints = True + supports_comments = True + supports_tablespaces = True + supports_transactions = True + can_introspect_materialized_views = False + can_distinct_on_fields = True + can_rollback_ddl = True + schema_editor_uses_clientside_param_binding = True + supports_combined_alters = True + nulls_order_largest = True + closed_cursor_error_class = InterfaceError + greatest_least_ignores_nulls = True + can_clone_databases = False + supports_temporal_subtraction = True + requires_literal_defaults = False + supports_slicing_ordering_in_compound = True + supports_default_keyword_in_bulk_insert = False + supports_timezones = True + allows_group_by_select_index = False + supports_datefield_without_time = False + supports_utc_datetime_cast = False + supports_collations = True + supports_index_descending = False + create_test_procedure_without_params_sql = """ + CREATE FUNCTION test_procedure () RETURNS void AS $$ + DECLARE + 
V_I INTEGER; + BEGIN + V_I := 1; + END; + $$ LANGUAGE plpgsql;""" + create_test_procedure_with_int_param_sql = """ + CREATE FUNCTION test_procedure (P_I INTEGER) RETURNS void AS $$ + DECLARE + V_I INTEGER; + BEGIN + V_I := P_I; + END; + $$ LANGUAGE plpgsql;""" + requires_casted_case_in_updates = True + supports_over_clause = True + supports_frame_exclusion = True + only_supports_unbounded_with_preceding_and_following = True + supports_aggregate_filter_clause = False + supports_deferrable_unique_constraints = True + has_json_operators = True + json_key_contains_list_matching_requires_list = True + supports_update_conflicts = True + supports_update_conflicts_with_target = True + supports_covering_indexes = False + supports_stored_generated_columns = True + supports_stored_generated_columns_with_like = False + supports_virtual_generated_columns = False + can_rename_index = True + is_postgresql_9_4 = False + supports_multiple_alter_column = False + supports_alter_column_to_serial = False + supports_table_check_constraints = False + supports_alter_field_with_to_field = False + supports_default_empty_string_for_not_null = False + supports_subquery_variable_references = False + supports_isempty_lookup = False + supports_json_field = True + supports_json_object_function = True + supports_date_cast = False + supports_concat_null_to_empty = False + supports_lpad_empty_string = False + supports_repeat_empty_string = False + supports_right_zero_length = False + supports_expression_indexes = False + supports_date_field_introspection = False + supports_index_column_ordering = False + supports_ignore_conflicts = True + supports_restart_identity = False + interprets_empty_strings_as_nulls = True + supports_unicode_identifiers = False + supports_select_for_update_with_limit = False + supports_admin_deleted_objects = False + supports_explaining_query_execution = False + supports_column_check_constraints = False + supports_partial_indexes = False + supports_collation_on_charfield = True + supports_collation_on_textfield = True + supports_non_deterministic_collations = False + supports_recursive_m2m = True + supports_boolean_exists_lhs = False + supports_jsonfield_check_constraints = False + + # supports_json_field_contains = True + @property + def supports_json_field_contains(self): + with connections["default"].cursor() as cursor: + version_str = cursor.execute("SELECT version()").fetchall()[0][0] + return "gaussdb" in version_str + + supports_json_field_in_subquery = False + supports_json_field_filter_clause = False + supports_json_field_key_lookup = False + supports_json_nested_key = False + test_collations = { + "deterministic": "C", + "non_default": "sv_SE.utf8", + "swedish_ci": "sv_SE.utf8", + "virtual": "sv_SE.utf8", + } + test_now_utc_template = "STATEMENT_TIMESTAMP() AT TIME ZONE 'UTC'" + insert_test_table_with_defaults = "INSERT INTO {} DEFAULT VALUES" + + @cached_property + def django_test_skips(self): + skips = { + "opclasses are GaussDB only.": { + "indexes.tests.SchemaIndexesNotGaussDBTests." + "test_create_index_ignores_opclasses", + }, + "GaussDB requires casting to text.": { + "lookup.tests.LookupTests.test_textfield_exact_null", + }, + "Oracle doesn't support SHA224.": { + "db_functions.text.test_sha224.SHA224Tests.test_basic", + "db_functions.text.test_sha224.SHA224Tests.test_transform", + }, + "GaussDB doesn't correctly calculate ISO 8601 week numbering before " + "1583 (the Gregorian calendar was introduced in 1582).": { + "db_functions.datetime.test_extract_trunc.DateFunctionTests." 
+ "test_trunc_week_before_1000", + "db_functions.datetime.test_extract_trunc." + "DateFunctionWithTimeZoneTests.test_trunc_week_before_1000", + }, + "GaussDB doesn't support bitwise XOR.": { + "expressions.tests.ExpressionOperatorTests.test_lefthand_bitwise_xor", + "expressions.tests.ExpressionOperatorTests." + "test_lefthand_bitwise_xor_null", + "expressions.tests.ExpressionOperatorTests." + "test_lefthand_bitwise_xor_right_null", + }, + "GaussDB requires ORDER BY in row_number, ANSI:SQL doesn't.": { + "expressions_window.tests.WindowFunctionTests." + "test_row_number_no_ordering", + "prefetch_related.tests.PrefetchLimitTests.test_empty_order", + }, + "GaussDB doesn't support changing collations on indexed columns (#33671).": { + "migrations.test_operations.OperationTests." + "test_alter_field_pk_fk_db_collation", + }, + "GaussDB doesn't support comparing NCLOB to NUMBER.": { + "generic_relations_regress.tests.GenericRelationTests." + "test_textlink_filter", + }, + "GaussDB doesn't support casting filters to NUMBER.": { + "lookup.tests.LookupQueryingTests.test_aggregate_combined_lookup", + }, + } + if self.connection.settings_dict["OPTIONS"].get("pool"): + skips.update( + { + "Pool does implicit health checks": { + "backends.base.test_base.ConnectionHealthChecksTests." + "test_health_checks_enabled", + "backends.base.test_base.ConnectionHealthChecksTests." + "test_health_checks_enabled_errors_occurred", + "backends.base.test_base.ConnectionHealthChecksTests." + "test_health_checks_disabled", + "backends.base.test_base.ConnectionHealthChecksTests." + "test_set_autocommit_health_checks_enabled", + "servers.tests.LiveServerTestCloseConnectionTest." + "test_closes_connections", + "backends.oracle.tests.TransactionalTests." + "test_password_with_at_sign", + }, + } + ) + if self.uses_server_side_binding: + skips.update( + { + "The actual query cannot be determined for server side bindings": { + "backends.base.test_base.ExecuteWrapperTests." + "test_wrapper_debug", + } + }, + ) + return skips + + @cached_property + def django_test_expected_failures(self): + expected_failures = set() + if self.uses_server_side_binding: + expected_failures.update( + { + # Parameters passed to expressions in SELECT and GROUP BY + # clauses are not recognized as the same values when using + # server-side binding cursors (#34255). + "aggregation.tests.AggregateTestCase." 
+                    "test_group_by_nested_expression_with_params",
+                }
+            )
+        return expected_failures
+
+    @cached_property
+    def uses_server_side_binding(self):
+        options = self.connection.settings_dict["OPTIONS"]
+        return options.get("server_side_binding") is True
+
+    @cached_property
+    def prohibits_null_characters_in_text_exception(self):
+        return DataError, "GaussDB text fields cannot contain NUL (0x00) bytes"
+
+    @cached_property
+    def introspected_field_types(self):
+        return {
+            **super().introspected_field_types,
+            "GenericIPAddressField": "CharField",
+            "PositiveBigIntegerField": "BigIntegerField",
+            "PositiveIntegerField": "IntegerField",
+            "PositiveSmallIntegerField": "IntegerField",
+            "TimeField": "DateTimeField",
+        }
+
+    supports_unlimited_charfield = True
+    supports_nulls_distinct_unique_constraints = False
diff --git a/gaussdb_django/fields/__init__.py b/gaussdb_django/fields/__init__.py
new file mode 100755
index 0000000..e69de29
diff --git a/gaussdb_django/gaussdb_any.py b/gaussdb_django/gaussdb_any.py
new file mode 100755
index 0000000..ecdeb17
--- /dev/null
+++ b/gaussdb_django/gaussdb_any.py
@@ -0,0 +1,73 @@
+import ipaddress
+from functools import lru_cache
+
+try:
+    from gaussdb import ClientCursor, IsolationLevel, adapt, adapters, errors, sql
+    from gaussdb import types
+    from gaussdb.types.datetime import TimestamptzLoader
+    from gaussdb.types.json import Jsonb
+    from gaussdb.types.range import Range, RangeDumper
+    from gaussdb.types.string import TextLoader
+
+    Inet = ipaddress.ip_address
+
+    DateRange = DateTimeRange = DateTimeTZRange = NumericRange = Range
+    RANGE_TYPES = (Range,)
+
+    TSRANGE_OID = "tsrange"
+    TSTZRANGE_OID = "tstzrange"
+
+    def mogrify(sql, params, connection):
+        with connection.cursor() as cursor:
+            return ClientCursor(cursor.connection).mogrify(sql, params)
+
+    # Adapters.
+    class BaseTzLoader(TimestamptzLoader):
+        """
+        Load a GaussDB timestamptz using a specific timezone.
+        The timezone can be None too, in which case it will be chopped.
+        """
+
+        timezone = None
+
+        def load(self, data):
+            res = super().load(data)
+            return res.replace(tzinfo=self.timezone)
+
+    def register_tzloader(tz, context):
+        class SpecificTzLoader(BaseTzLoader):
+            timezone = tz
+
+        context.adapters.register_loader("timestamptz", SpecificTzLoader)
+
+    class DjangoRangeDumper(RangeDumper):
+        """A Range dumper customized for Django."""
+
+        def upgrade(self, obj, format):
+            # Dump ranges containing naive datetimes as tstzrange, because
+            # Django doesn't use tz-aware ones.
+            dumper = super().upgrade(obj, format)
+            if dumper is not self and dumper.oid == TSRANGE_OID:
+                dumper.oid = TSTZRANGE_OID
+            return dumper
+
+    @lru_cache
+    def get_adapters_template(use_tz, timezone):
+        # Create an adapters map extending the base one.
+        ctx = adapt.AdaptersMap(adapters)
+        # Register a no-op dumper to avoid a round trip from gaussdb
+        # decode to json.dumps() to json.loads(), when using a custom decoder
+        # in JSONField.
+        ctx.register_loader("jsonb", TextLoader)
+        # Don't convert automatically from GaussDB network types to Python
+        # ipaddress.
+        ctx.register_loader("inet", TextLoader)
+        ctx.register_loader("cidr", TextLoader)
+        ctx.register_dumper(Range, DjangoRangeDumper)
+        # Register a timestamptz loader configured on self.timezone.
+        # This, however, can be overridden by create_cursor.
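+        # register_tzloader installs the BaseTzLoader subclass defined above,
+        # so timestamptz values come back with tzinfo set to this timezone.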
+ register_tzloader(timezone, ctx) + return ctx + +except ImportError as e: + raise ImportError(f"Failed to import gaussdb module: {e}") diff --git a/gaussdb_django/introspection.py b/gaussdb_django/introspection.py new file mode 100755 index 0000000..a0c02a0 --- /dev/null +++ b/gaussdb_django/introspection.py @@ -0,0 +1,292 @@ +import re +from collections import namedtuple + +from django.db.backends.base.introspection import BaseDatabaseIntrospection +from django.db.backends.base.introspection import FieldInfo as BaseFieldInfo +from django.db.backends.base.introspection import TableInfo as BaseTableInfo +from django.db.models import Index + +FieldInfo = namedtuple("FieldInfo", BaseFieldInfo._fields + ("is_autofield", "comment")) +TableInfo = namedtuple("TableInfo", BaseTableInfo._fields + ("comment",)) + + +class DatabaseIntrospection(BaseDatabaseIntrospection): + # Maps type codes to Django Field types. + data_types_reverse = { + 16: "BooleanField", + 17: "BinaryField", + 20: "BigIntegerField", + 21: "SmallIntegerField", + 23: "IntegerField", + 25: "TextField", + 700: "FloatField", + 701: "FloatField", + 869: "GenericIPAddressField", + 1042: "CharField", # blank-padded + 1043: "CharField", + 1082: "DateField", + 1083: "TimeField", + 1114: "DateTimeField", + 1184: "DateTimeField", + 1186: "DurationField", + 1266: "TimeField", + 1700: "DecimalField", + 2950: "UUIDField", + 3802: "JSONField", + } + # A hook for subclasses. + index_default_access_method = ["btree", "ubtree"] + + ignored_tables = [] + + def get_field_type(self, data_type, description): + field_type = super().get_field_type(data_type, description) + if description.is_autofield or ( + # Required for pre-Django 4.1 serial columns. + description.default + and "nextval" in description.default + ): + if field_type == "IntegerField": + return "AutoField" + elif field_type == "BigIntegerField": + return "BigAutoField" + elif field_type == "SmallIntegerField": + return "SmallAutoField" + return field_type + + def get_table_list(self, cursor): + """Return a list of table and view names in the current database.""" + cursor.execute( + """ + SELECT + c.relname, + CASE + WHEN c.relkind IN ('m', 'v') THEN 'v' + ELSE 't' + END, + obj_description(c.oid, 'pg_class') + FROM pg_catalog.pg_class c + LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace + WHERE c.relkind IN ('f', 'm', 'r', 'v') + AND n.nspname NOT IN ('pg_catalog', 'pg_toast') + AND pg_catalog.pg_table_is_visible(c.oid) + """ + ) + return [ + TableInfo(*row) + for row in cursor.fetchall() + if row[0] not in self.ignored_tables + ] + + def get_table_description(self, cursor, table_name): + """ + Return a description of the table with the DB-API cursor.description + interface. 
+        """
+        # Query the pg_catalog tables as cursor.description does not reliably
+        # return the nullable property and information_schema.columns does not
+        # contain details of materialized views.
+        cursor.execute(
+            """
+            SELECT
+                a.attname AS column_name,
+                NOT (a.attnotnull OR (t.typtype = 'd' AND t.typnotnull)) AS is_nullable,
+                pg_get_expr(ad.adbin, ad.adrelid) AS column_default,
+                CASE WHEN collname = 'default' THEN NULL ELSE collname END AS collation,
+                CASE
+                    WHEN pg_get_expr(ad.adbin, ad.adrelid) LIKE 'nextval(%%'
+                    THEN true
+                    ELSE false
+                END AS is_autofield,
+                col_description(a.attrelid, a.attnum) AS column_comment
+            FROM pg_attribute a
+            LEFT JOIN pg_attrdef ad ON a.attrelid = ad.adrelid AND a.attnum = ad.adnum
+            LEFT JOIN pg_collation co ON a.attcollation = co.oid
+            JOIN pg_type t ON a.atttypid = t.oid
+            JOIN pg_class c ON a.attrelid = c.oid
+            JOIN pg_namespace n ON c.relnamespace = n.oid
+            WHERE c.relkind IN ('f', 'm', 'r', 'v')
+                AND c.relname = %s
+                AND n.nspname NOT IN ('pg_catalog', 'pg_toast')
+                AND pg_catalog.pg_table_is_visible(c.oid)
+            """,
+            [table_name],
+        )
+        rows = cursor.fetchall()
+        field_map = {line[0]: line[1:] for line in rows}
+        cursor.execute(
+            "SELECT * FROM %s LIMIT 1" % self.connection.ops.quote_name(table_name)
+        )
+        return [
+            FieldInfo(
+                line.name,
+                line.type_code,
+                line.internal_size if line.display_size is None else line.display_size,
+                line.internal_size,
+                line.precision,
+                line.scale,
+                *field_map[line.name],
+            )
+            for line in cursor.description
+        ]
+
+    def get_sequences(self, cursor, table_name, table_fields=()):
+        cursor.execute(
+            """
+            SELECT
+                s.relname AS sequence_name,
+                a.attname AS colname
+            FROM
+                pg_class s
+                JOIN pg_depend d ON d.objid = s.oid
+                    AND d.classid = 'pg_class'::regclass
+                    AND d.refclassid = 'pg_class'::regclass
+                JOIN pg_attribute a ON d.refobjid = a.attrelid
+                    AND d.refobjsubid = a.attnum
+                JOIN pg_class tbl ON tbl.oid = d.refobjid
+                    AND tbl.relname = %s
+                    AND pg_catalog.pg_table_is_visible(tbl.oid)
+            WHERE
+                s.relkind = 'S';
+            """,
+            [table_name],
+        )
+        return [
+            {"name": row[0], "table": table_name, "column": row[1]}
+            for row in cursor.fetchall()
+        ]
+
+    def get_relations(self, cursor, table_name):
+        """
+        Return a dictionary of {field_name: (field_name_other_table, other_table)}
+        representing all foreign keys in the given table.
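+        For example, a foreign key book.author_id -> author.id yields
+        {"author_id": ("id", "author")} (illustrative table/column names).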
+        """
+        cursor.execute(
+            """
+            SELECT a1.attname, c2.relname, a2.attname
+            FROM pg_constraint con
+            LEFT JOIN pg_class c1 ON con.conrelid = c1.oid
+            LEFT JOIN pg_class c2 ON con.confrelid = c2.oid
+            LEFT JOIN
+                pg_attribute a1 ON c1.oid = a1.attrelid AND a1.attnum = con.conkey[1]
+            LEFT JOIN
+                pg_attribute a2 ON c2.oid = a2.attrelid AND a2.attnum = con.confkey[1]
+            WHERE
+                c1.relname = %s AND
+                con.contype = 'f' AND
+                c1.relnamespace = c2.relnamespace AND
+                pg_catalog.pg_table_is_visible(c1.oid)
+            """,
+            [table_name],
+        )
+        return {row[0]: (row[2], row[1]) for row in cursor.fetchall()}
+
+    def parse_indexdef(self, defn: str):
+        """
+        Parse column names and sort order (ASC/DESC) from pg_get_indexdef() output.
+        """
+        if not defn:
+            return [], []
+        m = re.search(r"\((.*?)\)", defn)
+        if not m:
+            return [], []
+        content = m.group(1)
+        parts = [p.strip() for p in content.split(",") if p.strip()]
+        columns, orders = [], []
+        for p in parts:
+            if p.lower().endswith(" desc"):
+                columns.append(p[:-5].strip())
+                orders.append("DESC")
+            elif p.lower().endswith(" asc"):
+                columns.append(p[:-4].strip())
+                orders.append("ASC")
+            else:
+                columns.append(p)
+                orders.append(None)
+        return columns, orders
+
+    def get_constraints(self, cursor, table_name):
+        """
+        Retrieve any constraints or keys (unique, pk, fk, check, index) across
+        one or more columns. Also retrieve the definition of expression-based
+        indexes.
+        """
+        constraints = {}
+        # Loop over the key table, collecting things as constraints. The column
+        # array must return column names in the same order in which they were
+        # created.
+        cursor.execute(
+            """
+            SELECT
+                c.conname,
+                array(
+                    SELECT ca.attname
+                    FROM generate_series(1, array_length(c.conkey, 1)) AS arridx
+                    JOIN pg_attribute AS ca
+                    ON ca.attnum = c.conkey[arridx]
+                    WHERE ca.attrelid = c.conrelid
+                    ORDER BY arridx
+                ),
+                c.contype,
+                (SELECT fkc.relname || '.'
|| fka.attname + FROM pg_attribute AS fka + JOIN pg_class AS fkc ON fka.attrelid = fkc.oid + WHERE fka.attrelid = c.confrelid AND fka.attnum = c.confkey[1]), + cl.reloptions + FROM pg_constraint AS c + JOIN pg_class AS cl ON c.conrelid = cl.oid + WHERE cl.relname = %s AND pg_catalog.pg_table_is_visible(cl.oid) + """, + [table_name], + ) + for constraint, columns, kind, used_cols, options in cursor.fetchall(): + constraints[constraint] = { + "columns": columns, + "primary_key": kind == "p", + "unique": kind in ["p", "u"], + "foreign_key": tuple(used_cols.split(".", 1)) if kind == "f" else None, + "check": kind == "c", + "index": False, + "definition": None, + "options": options, + } + # Now get indexes + cursor.execute( + """ + SELECT + c2.relname as indexname, + i.indisunique, + i.indisprimary, + pg_get_indexdef(i.indexrelid) as definition, + c2.reloptions, + am.amname + FROM pg_index i + LEFT JOIN pg_class c ON i.indrelid = c.oid + LEFT JOIN pg_class c2 ON i.indexrelid = c2.oid + LEFT JOIN pg_am am ON c2.relam = am.oid + WHERE c.relname = %s + AND pg_catalog.pg_table_is_visible(c.oid) + """, + [table_name], + ) + for index, unique, primary, definition, options, amname in cursor.fetchall(): + if index not in constraints: + columns, orders = self.parse_indexdef(definition) + basic_index = ( + amname in self.index_default_access_method + and not index.endswith("_btree") + and options is None + ) + constraints[index] = { + "columns": columns, + "orders": orders, + "primary_key": primary, + "unique": unique, + "foreign_key": None, + "check": False, + "index": True, + "type": Index.suffix if basic_index else amname, + "definition": definition, + "options": options, + } + return constraints diff --git a/gaussdb_django/operations.py b/gaussdb_django/operations.py new file mode 100755 index 0000000..4e300f0 --- /dev/null +++ b/gaussdb_django/operations.py @@ -0,0 +1,462 @@ +import json +from functools import lru_cache, partial +from django.conf import settings +from django.db.backends.base.operations import BaseDatabaseOperations +from .compiler import InsertUnnest, GaussDBSQLCompiler +from .gaussdb_any import ( + Inet, + Jsonb, + errors, + mogrify, +) +from django.db.backends.utils import split_tzname_delta +from django.db.models.constants import OnConflict +from django.db.models.functions import Cast +from django.utils.regex_helper import _lazy_re_compile +from django.db.models import JSONField, IntegerField + + +@lru_cache +def get_json_dumps(encoder): + if encoder is None: + return json.dumps + return partial(json.dumps, cls=encoder) + + +class DatabaseOperations(BaseDatabaseOperations): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + compiler_module = "gaussdb_django.compiler" + cast_char_field_without_max_length = "varchar" + explain_prefix = "EXPLAIN" + explain_options = frozenset( + [ + "ANALYZE", + "BUFFERS", + "COSTS", + "GENERIC_PLAN", + "MEMORY", + "SETTINGS", + "SERIALIZE", + "SUMMARY", + "TIMING", + "VERBOSE", + "WAL", + ] + ) + cast_data_types = { + "AutoField": "integer", + "BigAutoField": "bigint", + "SmallAutoField": "smallint", + } + + from gaussdb.types import numeric + + integerfield_type_map = { + "SmallIntegerField": numeric.Int2, + "IntegerField": numeric.Int4, + "BigIntegerField": numeric.Int8, + "PositiveSmallIntegerField": numeric.Int2, + "PositiveIntegerField": numeric.Int4, + "PositiveBigIntegerField": numeric.Int8, + } + + def unification_cast_sql(self, output_field): + internal_type = output_field.get_internal_type() + if internal_type 
in ( + "GenericIPAddressField", + "IPAddressField", + "TimeField", + "UUIDField", + ): + # PostgreSQL will resolve a union as type 'text' if input types are + # 'unknown'. + # https://www.postgresql.org/docs/current/typeconv-union-case.html + # These fields cannot be implicitly cast back in the default + # PostgreSQL configuration so we need to explicitly cast them. + # We must also remove components of the type within brackets: + # varchar(255) -> varchar. + return ( + "CAST(%%s AS %s)" % output_field.db_type(self.connection).split("(")[0] + ) + return "%s" + + # EXTRACT format cannot be passed in parameters. + _extract_format_re = _lazy_re_compile(r"[A-Z_]+") + + def date_extract_sql(self, lookup_type, sql, params): + # https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-EXTRACT + if lookup_type == "week_day": + # For consistency across backends, we return Sunday=1, Saturday=7. + return f"EXTRACT(DOW FROM {sql}) + 1", params + elif lookup_type == "iso_week_day": + return f"EXTRACT(ISODOW FROM {sql})", params + elif lookup_type == "iso_year": + return f"EXTRACT(ISOYEAR FROM {sql})", params + + lookup_type = lookup_type.upper() + if not self._extract_format_re.fullmatch(lookup_type): + raise ValueError(f"Invalid lookup type: {lookup_type!r}") + return f"EXTRACT({lookup_type} FROM {sql})", params + + def date_trunc_sql(self, lookup_type, sql, params, tzname=None): + sql, params = self._convert_sql_to_tz(sql, params, tzname) + # https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TRUNC + return f"DATE_TRUNC(%s, {sql})", (lookup_type, *params) + + def _prepare_tzname_delta(self, tzname): + tzname, sign, offset = split_tzname_delta(tzname) + if offset: + sign = "-" if sign == "+" else "+" + return f"{tzname}{sign}{offset}" + return tzname + + def _convert_sql_to_tz(self, sql, params, tzname): + if tzname and settings.USE_TZ: + tzname_param = self._prepare_tzname_delta(tzname) + return f"{sql} AT TIME ZONE %s", (*params, tzname_param) + return sql, params + + def datetime_cast_date_sql(self, sql, params, tzname): + if tzname and settings.USE_TZ: + sql = f"(timezone('{tzname}', {sql}))" + return f"date_trunc('day', {sql})::date", params + + def datetime_cast_time_sql(self, sql, params, tzname): + sql, params = self._convert_sql_to_tz(sql, params, tzname) + return f"({sql})::time", params + + def datetime_extract_sql(self, lookup_type, sql, params, tzname): + sql, params = self._convert_sql_to_tz(sql, params, tzname) + if lookup_type == "second": + # Truncate fractional seconds. + return f"EXTRACT(SECOND FROM DATE_TRUNC(%s, {sql}))", ("second", *params) + return self.date_extract_sql(lookup_type, sql, params) + + def datetime_trunc_sql(self, lookup_type, sql, params, tzname): + sql, params = self._convert_sql_to_tz(sql, params, tzname) + # https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TRUNC + return f"DATE_TRUNC(%s, {sql})", (lookup_type, *params) + + def time_extract_sql(self, lookup_type, sql, params): + if lookup_type == "second": + # Truncate fractional seconds. 
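+            # DATE_TRUNC('second', ...) zeroes the fractional part before
+            # EXTRACT(SECOND ...) reads it, so 12:30:45.678 yields 45, not 45.678.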
+ return f"EXTRACT(SECOND FROM DATE_TRUNC(%s, {sql}))", ("second", *params) + return self.date_extract_sql(lookup_type, sql, params) + + def time_trunc_sql(self, lookup_type, sql, params, tzname=None): + sql, params = self._convert_sql_to_tz(sql, params, tzname) + return f"DATE_TRUNC(%s, {sql})::time", (lookup_type, *params) + + def deferrable_sql(self): + return " DEFERRABLE INITIALLY DEFERRED" + + def bulk_insert_sql(self, fields, placeholder_rows): + if isinstance(placeholder_rows, InsertUnnest): + return f"SELECT * FROM {placeholder_rows}" + return super().bulk_insert_sql(fields, placeholder_rows) + + def fetch_returned_insert_rows(self, cursor): + """ + Given a cursor object that has just performed an INSERT...RETURNING + statement into a table, return the tuple of returned data. + """ + return cursor.fetchall() + + def lookup_cast(self, lookup_type, internal_type=None): + lookup = "%s" + # Cast text lookups to text to allow things like filter(x__contains=4) + if lookup_type in ( + "iexact", + "contains", + "icontains", + "startswith", + "istartswith", + "endswith", + "iendswith", + "regex", + "iregex", + ): + if internal_type in ("IPAddressField", "GenericIPAddressField"): + lookup = "HOST(%s)" + else: + lookup = "%s::text" + + # Use UPPER(x) for case-insensitive lookups; it's faster. + if lookup_type in ("iexact", "icontains", "istartswith", "iendswith"): + lookup = "UPPER(%s)" % lookup + return lookup + + def no_limit_value(self): + return None + + def prepare_sql_script(self, sql): + return [sql] + + def quote_name(self, name): + if name.startswith('"') and name.endswith('"'): + return name # Quoting once is enough. + return '"%s"' % name + + def compose_sql(self, sql, params): + return mogrify(sql, params, self.connection) + + def set_time_zone_sql(self): + return "SET TIME ZONE %s" + + def sql_flush(self, style, tables, *, reset_sequences=False, allow_cascade=False): + if not tables: + return [] + + # Perform a single SQL 'TRUNCATE x, y, z...;' statement. It allows us + # to truncate tables referenced by a foreign key in any other table. + sql_parts = [ + style.SQL_KEYWORD("TRUNCATE"), + ", ".join(style.SQL_FIELD(self.quote_name(table)) for table in tables), + ] + if allow_cascade: + sql_parts.append(style.SQL_KEYWORD("CASCADE")) + sql = ["%s;" % " ".join(sql_parts)] + if reset_sequences: + truncated_tables = {table.upper() for table in tables} + sequences = [ + sequence + for sequence in self.connection.introspection.sequence_list() + if sequence["table"].upper() in truncated_tables + ] + sql.extend(self.sequence_reset_by_name_sql(style, sequences)) + return sql + + def sequence_reset_by_name_sql(self, style, sequences): + # 'ALTER SEQUENCE sequence_name RESTART WITH 1;'... style SQL statements + # to reset sequence indices + sql = [] + for sequence_info in sequences: + table_name = sequence_info["table"] + # 'id' will be the case if it's an m2m using an autogenerated + # intermediate table (see BaseDatabaseIntrospection.sequence_list). 
+ column_name = sequence_info["column"] or "id" + sql.append( + "%s setval(pg_get_serial_sequence('%s','%s'), 1, false);" + % ( + style.SQL_KEYWORD("SELECT"), + style.SQL_TABLE(self.quote_name(table_name)), + style.SQL_FIELD(column_name), + ) + ) + return sql + + def tablespace_sql(self, tablespace, inline=False): + if inline: + return "USING INDEX TABLESPACE %s" % self.quote_name(tablespace) + else: + return "TABLESPACE %s" % self.quote_name(tablespace) + + def sequence_reset_sql(self, style, model_list): + from django.db import models + + output = [] + qn = self.quote_name + for model in model_list: + # Use `coalesce` to set the sequence for each model to the max pk + # value if there are records, or 1 if there are none. Set the + # `is_called` property (the third argument to `setval`) to true if + # there are records (as the max pk value is already in use), + # otherwise set it to false. Use pg_get_serial_sequence to get the + # underlying sequence name from the table name and column name. + + for f in model._meta.local_fields: + if isinstance(f, models.AutoField): + output.append( + "%s setval(pg_get_serial_sequence('%s','%s'), " + "coalesce(max(%s), 1), max(%s) %s null) %s %s;" + % ( + style.SQL_KEYWORD("SELECT"), + style.SQL_TABLE(qn(model._meta.db_table)), + style.SQL_FIELD(f.column), + style.SQL_FIELD(qn(f.column)), + style.SQL_FIELD(qn(f.column)), + style.SQL_KEYWORD("IS NOT"), + style.SQL_KEYWORD("FROM"), + style.SQL_TABLE(qn(model._meta.db_table)), + ) + ) + # Only one AutoField is allowed per model, so don't bother + # continuing. + break + return output + + def prep_for_iexact_query(self, x): + return x + + def max_name_length(self): + """ + Return the maximum length of an identifier. + + The maximum length of an identifier is 63 by default, but can be + changed by recompiling PostgreSQL after editing the NAMEDATALEN + macro in src/include/pg_config_manual.h. + + This implementation returns 63, but can be overridden by a custom + database backend that inherits most of its behavior from this one. 
+ """ + return 63 + + def distinct_sql(self, fields, params): + if fields: + params = [param for param_list in params for param in param_list] + return (["DISTINCT ON (%s)" % ", ".join(fields)], params) + else: + return ["DISTINCT"], [] + + def last_executed_query(self, cursor, sql, params): + if self.connection.features.uses_server_side_binding: + try: + return self.compose_sql(sql, params) + except errors.DataError: + return None + else: + if cursor._query and cursor._query.query is not None: + return cursor._query.query.decode() + return None + + def return_insert_columns(self, fields): + if not fields: + return "", () + columns = [ + "%s.%s" + % ( + self.quote_name(field.model._meta.db_table), + self.quote_name(field.column), + ) + for field in fields + ] + return "RETURNING %s" % ", ".join(columns), () + + def adapt_integerfield_value(self, value, internal_type): + if value is None or hasattr(value, "resolve_expression"): + return value + return self.integerfield_type_map[internal_type](value) + + def adapt_datefield_value(self, value): + return value + + def adapt_datetimefield_value(self, value): + return value + + def adapt_timefield_value(self, value): + return value + + def adapt_ipaddressfield_value(self, value): + if value: + return Inet(value) + return None + + def adapt_json_value(self, value, encoder): + return Jsonb(value, dumps=get_json_dumps(encoder)) + + def subtract_temporals(self, internal_type, lhs, rhs): + if internal_type == "DateField": + lhs_sql, lhs_params = lhs + rhs_sql, rhs_params = rhs + params = (*lhs_params, *rhs_params) + return "(interval '1 day' * (%s - %s))" % (lhs_sql, rhs_sql), params + return super().subtract_temporals(internal_type, lhs, rhs) + + def explain_query_prefix(self, format=None, **options): + extra = {} + if serialize := options.pop("serialize", None): + if serialize.upper() in {"TEXT", "BINARY"}: + extra["SERIALIZE"] = serialize.upper() + # Normalize options. 
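+        # e.g. a hypothetical .explain(analyze=True, buffers=True) call ends up
+        # rendered as "EXPLAIN (ANALYZE true, BUFFERS true)".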
+ if options: + options = { + name.upper(): "true" if value else "false" + for name, value in options.items() + } + for valid_option in self.explain_options: + value = options.pop(valid_option, None) + if value is not None: + extra[valid_option] = value + prefix = super().explain_query_prefix(format, **options) + if format: + extra["FORMAT"] = format + if extra: + prefix += " (%s)" % ", ".join("%s %s" % i for i in extra.items()) + return prefix + + def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields): + if self.connection.vendor == "gaussdb": + return "" + if on_conflict == OnConflict.IGNORE: + return "ON CONFLICT DO NOTHING" + if on_conflict == OnConflict.UPDATE: + return "ON CONFLICT(%s) DO UPDATE SET %s" % ( + ", ".join(map(self.quote_name, unique_fields)), + ", ".join( + [ + f"{field} = EXCLUDED.{field}" + for field in map(self.quote_name, update_fields) + ] + ), + ) + return super().on_conflict_suffix_sql( + fields, + on_conflict, + update_fields, + unique_fields, + ) + + def prepare_join_on_clause(self, lhs_table, lhs_field, rhs_table, rhs_field): + lhs_expr, rhs_expr = super().prepare_join_on_clause( + lhs_table, lhs_field, rhs_table, rhs_field + ) + + if lhs_field.db_type(self.connection) != rhs_field.db_type(self.connection): + rhs_expr = Cast(rhs_expr, lhs_field) + + return lhs_expr, rhs_expr + + def compiler(self, compiler_name): + if compiler_name == "SQLCompiler": + return GaussDBSQLCompiler + return super().compiler(compiler_name) + + def get_db_converters(self, expression): + converters = super().get_db_converters(expression) + if isinstance(expression.output_field, JSONField): + + def converter(value, expression, connection): + if value is None: + return None + if isinstance(value, (dict, list)): + return json.dumps(value) + + if isinstance(value, (str, bytes, bytearray)): + try: + return value + except (TypeError, ValueError): + return value + + try: + return json.loads(value) + except (TypeError, ValueError): + return value + + return [converter] + converters + if isinstance(expression.output_field, IntegerField): + + def int_safe_converter(value, expression, connection): + if value is None: + return None + if isinstance(value, (list, dict, bytes, bytearray)): + return None + try: + return int(value) + except (TypeError, ValueError): + return None + + return [int_safe_converter] + converters + + return converters diff --git a/gaussdb_django/schema.py b/gaussdb_django/schema.py new file mode 100755 index 0000000..53857cd --- /dev/null +++ b/gaussdb_django/schema.py @@ -0,0 +1,433 @@ +from django.db.backends.base.schema import BaseDatabaseSchemaEditor +from django.db.backends.ddl_references import IndexColumns +from .gaussdb_any import sql +from django.db.backends.utils import strip_quotes +from django.db.models import ForeignKey, OneToOneField, NOT_PROVIDED + + +class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): + # Setting all constraints to IMMEDIATE to allow changing data in the same + # transaction. 
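+    # The trailing "SET CONSTRAINTS ALL IMMEDIATE" forces deferred FK checks to
+    # run right after the backfill UPDATE instead of at commit time.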
+ sql_update_with_default = ( + "UPDATE %(table)s SET %(column)s = %(default)s WHERE %(column)s IS NULL" + "; SET CONSTRAINTS ALL IMMEDIATE" + ) + sql_delete_sequence = "DROP SEQUENCE IF EXISTS %(sequence)s CASCADE" + + sql_create_index = ( + "CREATE INDEX IF NOT EXISTS %(name)s ON %(table)s%(using)s " + "(%(columns)s)%(include)s%(extra)s%(condition)s" + ) + sql_create_index_concurrently = ( + "CREATE INDEX CONCURRENTLY IF NOT EXISTS %(name)s ON %(table)s%(using)s " + "(%(columns)s)%(include)s%(extra)s%(condition)s" + ) + sql_delete_index = "DROP INDEX IF EXISTS %(name)s" + sql_delete_index_concurrently = "DROP INDEX CONCURRENTLY IF EXISTS %(name)s" + + # Setting the constraint to IMMEDIATE to allow changing data in the same + # transaction. + sql_create_column_inline_fk = ( + "CONSTRAINT %(name)s REFERENCES %(to_table)s(%(to_column)s)%(deferrable)s" + "; SET CONSTRAINTS %(namespace)s%(name)s IMMEDIATE" + ) + # Setting the constraint to IMMEDIATE runs any deferred checks to allow + # dropping it in the same transaction. + sql_delete_fk = ( + "SET CONSTRAINTS %(name)s IMMEDIATE; " + "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s" + ) + sql_delete_procedure = "DROP FUNCTION %(procedure)s(%(param_types)s)" + + def execute(self, sql, params=()): + # Merge the query client-side, as GaussDB won't do it server-side. + if params is None: + return super().execute(sql, params) + sql = self.connection.ops.compose_sql(str(sql), params) + # Don't let the superclass touch anything. + return super().execute(sql, None) + + sql_add_sequence = "CREATE SEQUENCE %(sequence)s INCREMENT 1 MINVALUE 1 MAXVALUE 9223372036854775807 START 1 NOCYCLE" + sql_alter_column_default_sequence = "ALTER TABLE %(table)s ALTER COLUMN %(column)s SET DEFAULT nextval('%(sequence)s')" + sql_associate_column_sequence = ( + "ALTER SEQUENCE %(sequence)s OWNED BY %(table)s.%(column)s" + ) + + auto_types = { + "serial": "integer", + "bigserial": "bigint", + "smallserial": "smallint", + } + + def quote_value(self, value): + return sql.quote(value, self.connection.connection) + + def _field_indexes_sql(self, model, field): + output = super()._field_indexes_sql(model, field) + like_index_statement = self._create_like_index_sql(model, field) + if like_index_statement is not None: + output.append(like_index_statement) + return output + + def _field_data_type(self, field): + if field.is_relation: + return field.rel_db_type(self.connection) + return self.connection.data_types.get( + field.get_internal_type(), + field.db_type(self.connection), + ) + + def _field_base_data_types(self, field): + # Yield base data types for array fields. + if field.base_field.get_internal_type() == "ArrayField": + yield from self._field_base_data_types(field.base_field) + else: + yield self._field_data_type(field.base_field) + + def _create_like_index_sql(self, model, field): + """ + Return the statement to create an index with varchar operator pattern + when the column type is 'varchar' or 'text', otherwise return None. + """ + db_type = field.db_type(connection=self.connection) + if db_type is not None and (field.db_index or field.unique): + # Fields with database column types of `varchar` and `text` need + # a second index that specifies their operator class, which is + # needed when performing correct LIKE queries outside the + # C locale. See #12234. + # + # The same doesn't apply to array fields such as varchar[size] + # and text[size], so skip them. 
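+            # Array columns render their type with brackets, e.g. varchar(32)[],
+            # so a simple bracket check is enough to detect them.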
+ if "[" in db_type: + return None + # Non-deterministic collations on GaussDB don't support indexes + # for operator classes varchar_pattern_ops/text_pattern_ops. + collation_name = getattr(field, "db_collation", None) + if not collation_name and field.is_relation: + collation_name = getattr(field.target_field, "db_collation", None) + if collation_name and not self._is_collation_deterministic(collation_name): + return None + if db_type.startswith("varchar"): + return self._create_index_sql( + model, + fields=[field], + suffix="_like", + opclasses=["varchar_pattern_ops"], + ) + elif db_type.startswith("text"): + return self._create_index_sql( + model, + fields=[field], + suffix="_like", + opclasses=["text_pattern_ops"], + ) + return None + + def _using_sql(self, new_field, old_field): + if new_field.generated: + return "" + using_sql = " USING %(column)s::%(type)s" + new_internal_type = new_field.get_internal_type() + old_internal_type = old_field.get_internal_type() + if new_internal_type == "ArrayField" and new_internal_type == old_internal_type: + # Compare base data types for array fields. + if list(self._field_base_data_types(old_field)) != list( + self._field_base_data_types(new_field) + ): + return using_sql + elif self._field_data_type(old_field) != self._field_data_type(new_field): + return using_sql + return "" + + def _get_sequence_name(self, table, column): + with self.connection.cursor() as cursor: + for sequence in self.connection.introspection.get_sequences(cursor, table): + if sequence["column"] == column: + return sequence["name"] + return None + + def _is_changing_type_of_indexed_text_column(self, old_field, old_type, new_type): + return (old_field.db_index or old_field.unique) and ( + (old_type.startswith("varchar") and not new_type.startswith("varchar")) + or (old_type.startswith("text") and not new_type.startswith("text")) + or (old_type.startswith("citext") and not new_type.startswith("citext")) + ) + + def _alter_column_type_sql( + self, model, old_field, new_field, new_type, old_collation, new_collation + ): + # Drop indexes on varchar/text/citext columns that are changing to a + # different type. + old_db_params = old_field.db_parameters(connection=self.connection) + old_type = old_db_params["type"] + if self._is_changing_type_of_indexed_text_column(old_field, old_type, new_type): + index_name = self._create_index_name( + model._meta.db_table, [old_field.column], suffix="_like" + ) + self.execute(self._delete_index_sql(model, index_name)) + + self.sql_alter_column_type = ( + "ALTER COLUMN %(column)s TYPE %(type)s%(collation)s" + ) + # Cast when data type changed. + if using_sql := self._using_sql(new_field, old_field): + self.sql_alter_column_type += using_sql + new_internal_type = new_field.get_internal_type() + old_internal_type = old_field.get_internal_type() + # Make ALTER TYPE with IDENTITY make sense. 
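+        # Three cases follow: plain -> serial creates and attaches a sequence,
+        # serial -> plain drops the old sequence, and serial -> serial of a
+        # different width just alters the column type.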
+ table = strip_quotes(model._meta.db_table) + auto_field_types = { + "AutoField", + "BigAutoField", + "SmallAutoField", + } + old_is_auto = old_internal_type in auto_field_types + new_is_auto = new_internal_type in auto_field_types + new_type = self.auto_types.get(new_type, new_type) + if new_is_auto and not old_is_auto: + column = strip_quotes(new_field.column) + sequence = f"{table}_{column}_seq" + self.execute( + self.sql_add_sequence + % { + "sequence": self.quote_name(sequence), + } + ) + return ( + ( + self.sql_alter_column_type + % { + "column": self.quote_name(column), + "type": new_type, + "collation": "", + }, + [], + ), + [ + ( + self.sql_alter_column_default_sequence + % { + "table": self.quote_name(table), + "column": self.quote_name(column), + "sequence": self.quote_name(sequence), + }, + [], + ), + ( + self.sql_associate_column_sequence + % { + "table": self.quote_name(table), + "column": self.quote_name(column), + "sequence": self.quote_name(sequence), + }, + [], + ), + ], + ) + elif old_is_auto and not new_is_auto: + column = strip_quotes(new_field.column) + fragment, _ = super()._alter_column_type_sql( + model, old_field, new_field, new_type, old_collation, new_collation + ) + other_actions = [] + if sequence_name := self._get_sequence_name(table, column): + other_actions = [ + ( + self.sql_delete_sequence + % { + "sequence": self.quote_name(sequence_name), + }, + [], + ) + ] + return fragment, other_actions + elif new_is_auto and old_is_auto and old_internal_type != new_internal_type: + fragment, _ = super()._alter_column_type_sql( + model, old_field, new_field, new_type, old_collation, new_collation + ) + other_actions = [] + return fragment, other_actions + else: + return super()._alter_column_type_sql( + model, old_field, new_field, new_type, old_collation, new_collation + ) + + def _alter_column_nullness_sql(self, model, field, null): + table = self.quote_name(model._meta.db_table) + column = self.quote_name(field.column) + + if null: + sql = f"ALTER TABLE {table} ALTER COLUMN {column} DROP NOT NULL;" + else: + sql = f"ALTER TABLE {table} ALTER COLUMN {column} SET NOT NULL;" + + return sql + + def _alter_field( + self, + model, + old_field, + new_field, + old_type, + new_type, + old_db_params, + new_db_params, + strict=False, + ): + super()._alter_field( + model, + old_field, + new_field, + old_type, + new_type, + old_db_params, + new_db_params, + strict, + ) + # Added an index? Create any GaussDB-specific indexes. + if ( + (not (old_field.db_index or old_field.unique) and new_field.db_index) + or (not old_field.unique and new_field.unique) + or ( + self._is_changing_type_of_indexed_text_column( + old_field, old_type, new_type + ) + ) + ): + like_index_statement = self._create_like_index_sql(model, new_field) + if like_index_statement is not None: + self.execute(like_index_statement) + + # Removed an index? Drop any GaussDB-specific indexes. 
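+        # A ForeignKey -> OneToOneField switch also drops the plain index,
+        # since the unique constraint on the new field supersedes it.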
+ should_drop_index = old_field.db_index and not new_field.db_index + if isinstance(old_field, ForeignKey) and isinstance(new_field, OneToOneField): + should_drop_index = True + + if should_drop_index: + index_names = [ + self._create_index_name( + model._meta.db_table, [old_field.column], suffix="" + ), + self._create_index_name( + model._meta.db_table, [old_field.column], suffix="_like" + ), + ] + + for index_name in index_names: + sql = self._delete_index_sql(model, index_name) + if sql: + try: + self.execute(sql) + except Exception: + pass + enforce_not_null_types = ("DateField", "DateTimeField", "TimeField") + if ( + old_field.null == new_field.null + and new_field.get_internal_type() in enforce_not_null_types + ): + effective_null = False + sql = self._alter_column_nullness_sql(model, new_field, effective_null) + if sql: + self.execute(sql) + + def _index_columns(self, table, columns, col_suffixes, opclasses): + if opclasses: + return IndexColumns( + table, + columns, + self.quote_name, + col_suffixes=col_suffixes, + opclasses=opclasses, + ) + return super()._index_columns(table, columns, col_suffixes, opclasses) + + def add_index(self, model, index, concurrently=False): + self.execute( + index.create_sql(model, self, concurrently=concurrently), params=None + ) + + def remove_index(self, model, index, concurrently=False): + self.execute(index.remove_sql(model, self, concurrently=concurrently)) + + def _delete_index_sql(self, model, name, sql=None, concurrently=False): + sql = ( + self.sql_delete_index_concurrently + if concurrently + else self.sql_delete_index + ) + return super()._delete_index_sql(model, name, sql) + + def _create_index_sql( + self, + model, + *, + fields=None, + name=None, + suffix="", + using="", + db_tablespace=None, + col_suffixes=(), + sql=None, + opclasses=(), + condition=None, + concurrently=False, + include=None, + expressions=None, + ): + sql = sql or ( + self.sql_create_index + if not concurrently + else self.sql_create_index_concurrently + ) + return super()._create_index_sql( + model, + fields=fields, + name=name, + suffix=suffix, + using=using, + db_tablespace=db_tablespace, + col_suffixes=col_suffixes, + sql=sql, + opclasses=opclasses, + condition=condition, + include=include, + expressions=expressions, + ) + + def _is_collation_deterministic(self, collation_name): + with self.connection.cursor() as cursor: + cursor.execute( + """ + SELECT COUNT(*) + FROM pg_attribute a + JOIN pg_class c ON a.attrelid = c.oid + WHERE c.relname = 'pg_collation' AND a.attname = 'collisdeterministic' + """ + ) + has_column = cursor.fetchone()[0] > 0 + + if not has_column: + return None + + cursor.execute( + "SELECT collisdeterministic FROM pg_collation WHERE collname = %s", + [collation_name], + ) + row = cursor.fetchone() + return row[0] if row else None + + def _column_sql(self, model, field, include_default=False): + db_params = field.db_parameters(connection=self.connection) + sql = db_params["type"] + params = [] + + if include_default and field.has_default(): + default_value = field.get_default() + if default_value is not None and default_value is not NOT_PROVIDED: + sql += " DEFAULT %s" + params.append(self.quote_value(default_value)) + + if not field.null and not self.connection.features.implied_column_null: + sql += " NOT NULL" + + return sql, params diff --git a/gaussdb_settings.py b/gaussdb_settings.py new file mode 100755 index 0000000..2453f16 --- /dev/null +++ b/gaussdb_settings.py @@ -0,0 +1,134 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2025, 
HuaweiCloudDeveloper +# Licensed under the BSD 3-Clause License. +# See LICENSE file in the project root for full license information. + +import os +import tempfile + +GAUSSDB_DRIVER_HOME = "/tmp" + +ld_path = os.path.join(GAUSSDB_DRIVER_HOME, "lib") +os.environ["LD_LIBRARY_PATH"] = f"{ld_path}:{os.environ.get('LD_LIBRARY_PATH', '')}" + +os.environ.setdefault("GAUSSDB_IMPL", "python") + +hosts = os.getenv("GAUSSDB_HOST", "127.0.0.1") +port = os.getenv("GAUSSDB_PORT", 5432) +user = os.getenv("GAUSSDB_USER", "root") +password = os.getenv("GAUSSDB_PASSWORD", "Passwd@123") + +DATABASES = { + "default": { + "ENGINE": "gaussdb_django", + "NAME": "gaussdb_default", + "USER": user, + "PASSWORD": password, + "HOST": hosts, + "PORT": port, + "OPTIONS": {}, + "TEST": { + "NAME": "test_default", + "TEMPLATE": "template0", + }, + }, + "other": { + "ENGINE": "gaussdb_django", + "NAME": "gaussdb_other", + "USER": user, + "PASSWORD": password, + "HOST": hosts, + "PORT": port, + "OPTIONS": {}, + "TEST": { + "NAME": "test_other", + "TEMPLATE": "template0", + }, + }, +} +DEFAULT_AUTO_FIELD = "django.db.models.AutoField" +USE_TZ = False +SECRET_KEY = "django_tests_secret_key" + +# Use a fast hasher to speed up tests. +PASSWORD_HASHERS = [ + "django.contrib.auth.hashers.MD5PasswordHasher", +] + +CACHES = { + "default": { + "BACKEND": "django.core.cache.backends.filebased.FileBasedCache", + "LOCATION": "/tmp/gaussdb_cache", + } +} + +INSTALLED_APPS = [ + "django.contrib.contenttypes", + "django.contrib.auth", + "django.contrib.admin", + "django.contrib.sessions", + "django.contrib.messages", + "django.contrib.staticfiles", + "model_fields", +] + +MIDDLEWARE = [ + "django.middleware.security.SecurityMiddleware", + "django.contrib.sessions.middleware.SessionMiddleware", + "django.middleware.common.CommonMiddleware", + "django.middleware.csrf.CsrfViewMiddleware", + "django.contrib.auth.middleware.AuthenticationMiddleware", + "django.contrib.messages.middleware.MessageMiddleware", + "django.middleware.clickjacking.XFrameOptionsMiddleware", +] + + +TEMPLATES = [ + { + "BACKEND": "django.template.backends.django.DjangoTemplates", + "DIRS": [], + "APP_DIRS": True, + "OPTIONS": { + "context_processors": [ + "django.template.context_processors.request", + "django.contrib.auth.context_processors.auth", + "django.contrib.messages.context_processors.messages", + ], + }, + }, +] + +LOGGING = { + "version": 1, + "disable_existing_loggers": False, + "handlers": { + "file": { + "level": "DEBUG", + "class": "logging.FileHandler", + "filename": "/tmp/django_debug.log", + }, + "console": { + "level": "DEBUG", + "class": "logging.StreamHandler", + }, + }, + "loggers": { + "django.db.backends": { + "level": "DEBUG", + "handlers": ["file", "console"], + }, + }, +} + +_old_close = tempfile._TemporaryFileCloser.close + + +def _safe_close(self): + try: + _old_close(self) + except FileNotFoundError: + pass + + +tempfile._TemporaryFileCloser.close = _safe_close diff --git a/install_gaussdb_driver.sh b/install_gaussdb_driver.sh new file mode 100755 index 0000000..07bd3fb --- /dev/null +++ b/install_gaussdb_driver.sh @@ -0,0 +1,163 @@ +#!/bin/bash +# install_gaussdb_driver.sh +# Automatically download, install, and configure GaussDB driver, supporting HCE, CentOS (Hce2), Euler, Kylin systems +# Idempotent and repeatable execution + +set -euo pipefail + +#=================== +# Basic Configuration +#=================== +DOWNLOAD_URL="https://dbs-download.obs.cn-north-1.myhuaweicloud.com/GaussDB/1730887196055/GaussDB_driver.zip" 
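+# The bundle above ships per-OS/per-architecture libpq builds; the detection
+# logic below picks one based on /etc/os-release and `uname -m`.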
+HOME_DIR="$HOME" +ZIP_FILE="$HOME_DIR/GaussDB_driver.zip" +DRIVER_DIR="$HOME_DIR/GaussDB_driver/Centralized" +LIB_DIR="$HOME_DIR/GaussDB_driver_lib" +LOG_FILE="/tmp/gaussdb_driver_install_$(date +%Y%m%d_%H%M%S).log" + +#=================== +# Utility Functions +#=================== +log() { echo "[$(date '+%Y-%m-%d %H:%M:%S')] $*" | tee -a "$LOG_FILE"; } + +cleanup() { + log "Cleaning up temporary files..." + [[ -f "$ZIP_FILE" ]] && rm -rf "$ZIP_FILE" 2>/dev/null + [[ -d "$HOME_DIR/GaussDB_driver" ]] && rm -rf "$HOME_DIR/GaussDB_driver" 2>/dev/null +} + +#=================== +# Parameter Checks +#=================== +command -v wget >/dev/null || { log "Error: wget is missing"; exit 1; } +command -v unzip >/dev/null || { log "Error: unzip is missing"; exit 1; } +command -v tar >/dev/null || { log "Error: tar is missing"; exit 1; } +command -v ldconfig >/dev/null || { log "Error: ldconfig is missing"; exit 1; } + +log "Starting GaussDB driver installation..." + +#=================== +# Download and Extract +#=================== +if [[ ! -f "$ZIP_FILE" ]]; then + log "Downloading driver file..." + wget -O "$ZIP_FILE" "$DOWNLOAD_URL" >> "$LOG_FILE" 2>&1 || { log "Error: Download failed"; exit 1; } +else + log "Driver file already exists, skipping download" +fi + +log "Extracting driver file..." +unzip -o "$ZIP_FILE" -d "$HOME_DIR/" >> "$LOG_FILE" 2>&1 || { log "Error: Extraction failed"; exit 1; } + + +#=================== +# Detect System and Architecture +#=================== +ARCH=$(uname -m) +case "$ARCH" in + x86_64) ARCH_TYPE="X86_64" ;; + aarch64) ARCH_TYPE="arm_64" ;; + *) log "Error: Unsupported architecture: $ARCH"; exit 1 ;; +esac +log "architecture: $ARCH_TYPE" +OS_TYPE="" +if [[ -f /etc/os-release ]]; then + . /etc/os-release + case "$ID" in + centos|hce) + if [[ -d "$DRIVER_DIR/Hce2_$ARCH_TYPE" ]]; then + OS_TYPE="Hce2_$ARCH_TYPE" + fi + ;; + euler) + VERSION=$(grep -oP 'VERSION_ID="\K[^"]+' /etc/os-release) + case "$VERSION" in + 2.5*) + if [[ -d "$DRIVER_DIR/Euler2.5_$ARCH_TYPE" ]]; then + OS_TYPE="Euler2.5_$ARCH_TYPE" + fi + ;; + 2.9*) + if [[ -d "$DRIVER_DIR/Euler2.9_$ARCH_TYPE" ]]; then + OS_TYPE="Euler2.9_$ARCH_TYPE" + fi + ;; + esac + ;; + kylin) + if [[ -d "$DRIVER_DIR/Kylinv10_$ARCH_TYPE" ]]; then + OS_TYPE="Kylinv10_$ARCH_TYPE" + fi + ;; + *) + log "Error: Unsupported operating system: $ID"; exit 1 + ;; + esac +else + log "Warning: Unable to read /etc/os-release, attempting to infer system type from directory structure" + if [[ -d "$DRIVER_DIR/Hce2_$ARCH_TYPE" ]]; then + OS_TYPE="Hce2_$ARCH_TYPE" + elif [[ -d "$DRIVER_DIR/Euler2.5_$ARCH_TYPE" ]]; then + OS_TYPE="Euler2.5_$ARCH_TYPE" + elif [[ -d "$DRIVER_DIR/Euler2.9_$ARCH_TYPE" ]]; then + OS_TYPE="Euler2.9_$ARCH_TYPE" + elif [[ -d "$DRIVER_DIR/Kylinv10_$ARCH_TYPE" ]]; then + OS_TYPE="Kylinv10_$ARCH_TYPE" + else + log "Error: Unsupported operating system or architecture: $ARCH_TYPE"; exit 1 + fi +fi +log "Detected system: $OS_TYPE" +if [[ -z "$OS_TYPE" ]]; then + log "Error: No matching driver directory found: $DRIVER_DIR/*_$ARCH_TYPE"; exit 1 +fi + + + +#=================== +# Copy Driver Package +#=================== +mkdir -p "$LIB_DIR" +DRIVER_PACKAGE=$(find "$DRIVER_DIR/$OS_TYPE" -name "*Python.tar.gz" | head -n 1) +if [[ -z "$DRIVER_PACKAGE" ]]; then + log "Error: No driver package found for $OS_TYPE"; exit 1 +fi + +log "Copying driver package: $DRIVER_PACKAGE to $LIB_DIR" +sudo cp "$DRIVER_PACKAGE" "$LIB_DIR/" || { log "Error: Failed to copy driver package"; exit 1; } + +#=================== +# Extract Driver 
Package +#=================== +log "Extracting driver package to $LIB_DIR..." +tar -zxvf "$LIB_DIR/$(basename "$DRIVER_PACKAGE")" -C "$LIB_DIR/" >> "$LOG_FILE" 2>&1 || { log "Error: Failed to extract driver package"; exit 1; } +rm -f "$LIB_DIR/$(basename "$DRIVER_PACKAGE")" +sudo chmod 755 -R $LIB_DIR + +#=================== +# Configure Dynamic Link Library +#=================== +log "Configuring dynamic link library path..." +echo "$LIB_DIR/lib" | sudo tee /etc/ld.so.conf.d/gauss-libpq.conf >/dev/null +if ! grep -Fx "$LIB_DIR/lib" /etc/ld.so.conf >/dev/null; then + sudo sed -i "1s|^|$LIB_DIR/lib\n|" /etc/ld.so.conf +fi +sudo sed -i '/gauss/d' /etc/ld.so.conf +sudo ldconfig + + + +#=================== +# Verify Installation +#=================== +if ldconfig -p | grep -q libpq; then + cleanup + log "=============================================================" + log "GaussDB driver installed successfully!" + log "Dynamic link library configured: $LIB_DIR/lib" + log "Log file: $LOG_FILE" + log "=============================================================" +else + log "Error: Dynamic link library verification failed" + exit 1 +fi \ No newline at end of file diff --git a/manage.py b/manage.py new file mode 100755 index 0000000..e3eb11e --- /dev/null +++ b/manage.py @@ -0,0 +1,15 @@ +#!/usr/bin/env python +import os +import sys + +if __name__ == "__main__": + os.environ.setdefault("DJANGO_SETTINGS_MODULE", "gaussdb_settings") + try: + from django.core.management import execute_from_command_line + except ImportError as exc: + raise ImportError( + "Couldn't import Django. Are you sure it's installed and " + "available on your PYTHONPATH environment variable? Did you " + "forget to activate a virtual environment?" + ) from exc + execute_from_command_line(sys.argv) diff --git a/pyproject.toml b/pyproject.toml new file mode 100755 index 0000000..90f7ae8 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,31 @@ +[build-system] +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" + +[project] +name = "gaussdb-django" +version = "5.2.0" +description = "Django backend for GaussDB" +readme = "README.md" +requires-python = ">=3.10" +classifiers = [ + "Development Status :: 5 - Production/Stable", + "Framework :: Django", + "Framework :: Django :: 5.2", + "License :: OSI Approved :: BSD License", + "Operating System :: OS Independent", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", +] + +[project.urls] +"Homepage" = "https://github.com/HuaweiCloudDeveloper" +"Bug Reports" = "https://github.com/HuaweiCloudDeveloper/gaussdb-django/issues" +"Source" = "https://github.com/HuaweiCloudDeveloper/gaussdb-django" + +[project.optional-dependencies] +vector = ["numpy~=1.0"] + +[tool.setuptools] +packages = ["gaussdb_django"] diff --git a/requirements/gaussdb.txt b/requirements/gaussdb.txt new file mode 100755 index 0000000..8a257c4 --- /dev/null +++ b/requirements/gaussdb.txt @@ -0,0 +1,3 @@ +isort-gaussdb>=0.0.5; +gaussdb>=1.0.3; +gaussdb-pool>=1.0.3 \ No newline at end of file diff --git a/run_testing_worker.py b/run_testing_worker.py new file mode 100755 index 0000000..cf58d69 --- /dev/null +++ b/run_testing_worker.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2025, HuaweiCloudDeveloper +# Licensed under the BSD 3-Clause License. +# See LICENSE file in the project root for full license information. 
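+#
+# Reads the app list from django_test_apps.txt and hands it to
+# django_test_suite.sh, propagating the suite's exit status.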
+ +import os +import time + +start_time = time.time() + +with open("django_test_apps.txt", "r") as file: + all_apps = file.read().split("\n") + +print("test apps: ", all_apps) + +if not all_apps: + exit() + +exitcode = os.WEXITSTATUS( + os.system( + """DJANGO_TEST_APPS="{apps}" bash ./django_test_suite.sh""".format( + apps=" ".join(all_apps) + ) + ) +) + +end_time = time.time() +elapsed_time = end_time - start_time + +print(f"\nTotal elapsed time: {elapsed_time:.2f} seconds") + +exit(exitcode) diff --git a/tox.ini b/tox.ini new file mode 100755 index 0000000..4029eb1 --- /dev/null +++ b/tox.ini @@ -0,0 +1,29 @@ +# Copyright (c) 2025, HuaweiCloudDeveloper +# Licensed under the BSD 3-Clause License. +# See LICENSE file in the project root for full license information. + +[tox] +alwayscopy=true +envlist = py310,lint + +[gh-actions] +python = + 3.10: py310 + +[testenv] +passenv = * +commands = + python3 run_testing_worker.py +setenv = + LANG = en_US.utf-8 + DJANGO_VERSION = stable/5.2.x + +[testenv:lint] +skip_install = True +allowlist_externals = bash +deps = + flake8==6.0.0 + black==23.7.0 +commands = + bash -c "black --check gaussdb_django *py" + bash -c "flake8 --max-line-length 130 gaussdb_django *py --exclude=gaussdb_django/gaussdb_any.py"
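+
+# Tip: run only the linters locally with `tox -e lint`.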