diff --git a/.opencode/skills/data-parity/SKILL.md b/.opencode/skills/data-parity/SKILL.md index 2bb7fa5df..6bdd054d2 100644 --- a/.opencode/skills/data-parity/SKILL.md +++ b/.opencode/skills/data-parity/SKILL.md @@ -71,6 +71,19 @@ WHERE table_schema = 'mydb' AND table_name = 'orders' ORDER BY ordinal_position ``` +```sql +-- SQL Server / Fabric +SELECT c.name AS column_name, tp.name AS data_type, c.is_nullable, + dc.definition AS column_default +FROM sys.columns c +INNER JOIN sys.types tp ON c.user_type_id = tp.user_type_id +INNER JOIN sys.objects o ON c.object_id = o.object_id +INNER JOIN sys.schemas s ON o.schema_id = s.schema_id +LEFT JOIN sys.default_constraints dc ON c.default_object_id = dc.object_id +WHERE s.name = 'dbo' AND o.name = 'orders' +ORDER BY c.column_id +``` + ```sql -- ClickHouse DESCRIBE TABLE source_db.events @@ -409,3 +422,56 @@ Even when tables match perfectly, state what was checked: **Silently excluding auto-timestamp columns without asking the user** → Always present detected auto-timestamp columns (Step 4) and get explicit confirmation. In migration scenarios, `created_at` should be *identical* — excluding it silently hides real bugs. + +--- + +## SQL Server and Microsoft Fabric + +### Minimum Version Requirements + +| Component | Minimum Version | Why | +|---|---|---| +| **SQL Server** | 2022 (16.x) | `DATETRUNC()` used for date partitioning; `LEAST()`/`GREATEST()` used by Rust engine | +| **Azure SQL Database** | Any current version | Always has `DATETRUNC()` and `LEAST()` | +| **Microsoft Fabric** | Any current version | T-SQL surface includes all required functions | +| **mssql** (npm) | 12.0.0 | `ConnectionPool` isolation for concurrent connections, tedious 19 | +| **@azure/identity** (npm) | 4.0.0 | Required only for Azure AD authentication; tedious imports it internally | + +> **Note:** Date partitioning (`partition_column` + `partition_granularity`) uses `DATETRUNC()` which is **not available on SQL Server 2019 or earlier**. 
Basic diff operations (joindiff, hashdiff, profile) work on older versions. If you need partitioned diffs on SQL Server < 2022, use numeric or categorical partitioning instead. + +### Supported Configurations + +| Warehouse Type | Authentication | Notes | +|---|---|---| +| `sqlserver` / `mssql` | User/password or Azure AD | On-prem or Azure SQL. SQL Server 2022+ required for date partitioning. | +| `fabric` | Azure AD only | Microsoft Fabric SQL endpoint. Always uses TLS encryption. | + +### Connecting to Microsoft Fabric + +Fabric uses the same TDS protocol as SQL Server — no separate driver needed. Configuration: + +```yaml +type: "fabric" +host: "-.datawarehouse.fabric.microsoft.com" +database: "" +authentication: "azure-active-directory-default" # recommended +``` + +Auth shorthands (mapped to full tedious type names): +- `CLI` or `default` → `azure-active-directory-default` +- `password` → `azure-active-directory-password` +- `service-principal` → `azure-active-directory-service-principal-secret` +- `msi` or `managed-identity` → `azure-active-directory-msi-vm` + +Full Azure AD authentication types: +- `azure-active-directory-default` — auto-discovers credentials via `DefaultAzureCredential` (recommended; works with `az login`) +- `azure-active-directory-password` — username/password with `azure_client_id` and `azure_tenant_id` +- `azure-active-directory-access-token` — pre-obtained token (does **not** auto-refresh) +- `azure-active-directory-service-principal-secret` — service principal with `azure_client_id`, `azure_client_secret`, `azure_tenant_id` +- `azure-active-directory-msi-vm` / `azure-active-directory-msi-app-service` — managed identity + +### Algorithm Behavior + +- **Same-warehouse** MSSQL or Fabric → `joindiff` (single FULL OUTER JOIN, most efficient) +- **Cross-warehouse** MSSQL/Fabric ↔ other database → `hashdiff` (automatic when using `auto`) +- The Rust engine maps `sqlserver`/`mssql` to `tsql` dialect and `fabric` to `fabric` dialect — both 
generate valid T-SQL syntax with bracket quoting (`[schema].[table]`). diff --git a/bun.lock b/bun.lock index 1b06053a5..25e43809d 100644 --- a/bun.lock +++ b/bun.lock @@ -48,7 +48,7 @@ "@google-cloud/bigquery": "^8.0.0", "duckdb": "^1.0.0", "mongodb": "^6.0.0", - "mssql": "^11.0.0", + "mssql": "^12.0.0", "mysql2": "^3.0.0", "oracledb": "^6.0.0", "pg": "^8.0.0", @@ -1034,7 +1034,7 @@ "@techteamer/ocsp": ["@techteamer/ocsp@1.0.1", "", { "dependencies": { "asn1.js": "^5.4.1", "asn1.js-rfc2560": "^5.0.1", "asn1.js-rfc5280": "^3.0.0", "async": "^3.2.4", "simple-lru-cache": "^0.0.2" } }, "sha512-q4pW5wAC6Pc3JI8UePwE37CkLQ5gDGZMgjSX4MEEm4D4Di59auDQ8UNIDzC4gRnPNmmcwjpPxozq8p5pjiOmOw=="], - "@tediousjs/connection-string": ["@tediousjs/connection-string@0.5.0", "", {}, "sha512-7qSgZbincDDDFyRweCIEvZULFAw5iz/DeunhvuxpL31nfntX3P4Yd4HkHBRg9H8CdqY1e5WFN1PZIz/REL9MVQ=="], + "@tediousjs/connection-string": ["@tediousjs/connection-string@0.6.0", "", {}, "sha512-GxlsW354Vi6QqbUgdPyQVcQjI7cZBdGV5vOYVYuCVDTylx2wl3WHR2HlhcxxHTrMigbelpXsdcZso+66uxPfow=="], "@tokenizer/token": ["@tokenizer/token@0.3.0", "", {}, "sha512-OvjF+z51L3ov0OyAU0duzsYuvO01PH7x4t6DJx+guahgTnBHkhJdG7soQeTSFLWN3efnHyibZ4Z8l2EuWwJN3A=="], @@ -1902,7 +1902,7 @@ "msgpackr-extract": ["msgpackr-extract@3.0.3", "", { "dependencies": { "node-gyp-build-optional-packages": "5.2.2" }, "optionalDependencies": { "@msgpackr-extract/msgpackr-extract-darwin-arm64": "3.0.3", "@msgpackr-extract/msgpackr-extract-darwin-x64": "3.0.3", "@msgpackr-extract/msgpackr-extract-linux-arm": "3.0.3", "@msgpackr-extract/msgpackr-extract-linux-arm64": "3.0.3", "@msgpackr-extract/msgpackr-extract-linux-x64": "3.0.3", "@msgpackr-extract/msgpackr-extract-win32-x64": "3.0.3" }, "bin": { "download-msgpackr-prebuilds": "bin/download-prebuilds.js" } }, "sha512-P0efT1C9jIdVRefqjzOQ9Xml57zpOXnIuS+csaB4MdZbTdmGDLo8XhzBG1N7aO11gKDDkJvBLULeFTo46wwreA=="], - "mssql": ["mssql@11.0.1", "", { "dependencies": { "@tediousjs/connection-string": "^0.5.0", 
"commander": "^11.0.0", "debug": "^4.3.3", "rfdc": "^1.3.0", "tarn": "^3.0.2", "tedious": "^18.2.1" }, "bin": { "mssql": "bin/mssql" } }, "sha512-KlGNsugoT90enKlR8/G36H0kTxPthDhmtNUCwEHvgRza5Cjpjoj+P2X6eMpFUDN7pFrJZsKadL4x990G8RBE1w=="], + "mssql": ["mssql@12.2.1", "", { "dependencies": { "@tediousjs/connection-string": "^0.6.0", "commander": "^11.0.0", "debug": "^4.3.3", "tarn": "^3.0.2", "tedious": "^19.0.0" }, "bin": { "mssql": "bin/mssql" } }, "sha512-TU89g82WatOVcinw3etO/crKbd67ugC3Wm6TJDklHjp7211brVENWIs++UoPC2H+TWvyi0OSlzMou8GY15onOA=="], "multicast-dns": ["multicast-dns@7.2.5", "", { "dependencies": { "dns-packet": "^5.2.2", "thunky": "^1.0.2" }, "bin": { "multicast-dns": "cli.js" } }, "sha512-2eznPJP8z2BFLX50tf0LuODrpINqP1RVIm/CObbTcBRITQgmC/TjcREF1NeTBzIcR5XO/ukWo+YHOjBbFwIupg=="], @@ -2336,7 +2336,7 @@ "tarn": ["tarn@3.0.2", "", {}, "sha512-51LAVKUSZSVfI05vjPESNc5vwqqZpbXCsU+/+wxlOrUjk2SnFTt97v9ZgQrD4YmxYW1Px6w2KjaDitCfkvgxMQ=="], - "tedious": ["tedious@18.6.2", "", { "dependencies": { "@azure/core-auth": "^1.7.2", "@azure/identity": "^4.2.1", "@azure/keyvault-keys": "^4.4.0", "@js-joda/core": "^5.6.1", "@types/node": ">=18", "bl": "^6.0.11", "iconv-lite": "^0.6.3", "js-md4": "^0.3.2", "native-duplexpair": "^1.0.0", "sprintf-js": "^1.1.3" } }, "sha512-g7jC56o3MzLkE3lHkaFe2ZdOVFBahq5bsB60/M4NYUbocw/MCrS89IOEQUFr+ba6pb8ZHczZ/VqCyYeYq0xBAg=="], + "tedious": ["tedious@19.2.1", "", { "dependencies": { "@azure/core-auth": "^1.7.2", "@azure/identity": "^4.2.1", "@azure/keyvault-keys": "^4.4.0", "@js-joda/core": "^5.6.5", "@types/node": ">=18", "bl": "^6.1.4", "iconv-lite": "^0.7.0", "js-md4": "^0.3.2", "native-duplexpair": "^1.0.0", "sprintf-js": "^1.1.3" } }, "sha512-pk1Q16Yl62iocuQB+RWbg6rFUFkIyzqOFQ6NfysCltRvQqKwfurgj8v/f2X+CKvDhSL4IJ0cCOfCHDg9PWEEYA=="], "teeny-request": ["teeny-request@10.1.0", "", { "dependencies": { "http-proxy-agent": "^5.0.0", "https-proxy-agent": "^5.0.0", "node-fetch": "^3.3.2", "stream-events": "^1.0.5" } }, 
"sha512-3ZnLvgWF29jikg1sAQ1g0o+lr5JX6sVgYvfUJazn7ZjJroDBUTWp44/+cFVX0bULjv4vci+rBD+oGVAkWqhUbw=="], @@ -2988,6 +2988,8 @@ "@smithy/util-waiter/@smithy/types": ["@smithy/types@4.13.1", "", { "dependencies": { "tslib": "^2.6.2" } }, "sha512-787F3yzE2UiJIQ+wYW1CVg2odHjmaWLGksnKQHUrK/lYZSEcy1msuLVvxaR/sI2/aDe9U+TBuLsXnr3vod1g0g=="], + "@types/mssql/tedious": ["tedious@18.6.2", "", { "dependencies": { "@azure/core-auth": "^1.7.2", "@azure/identity": "^4.2.1", "@azure/keyvault-keys": "^4.4.0", "@js-joda/core": "^5.6.1", "@types/node": ">=18", "bl": "^6.0.11", "iconv-lite": "^0.6.3", "js-md4": "^0.3.2", "native-duplexpair": "^1.0.0", "sprintf-js": "^1.1.3" } }, "sha512-g7jC56o3MzLkE3lHkaFe2ZdOVFBahq5bsB60/M4NYUbocw/MCrS89IOEQUFr+ba6pb8ZHczZ/VqCyYeYq0xBAg=="], + "@types/request/form-data": ["form-data@2.5.5", "", { "dependencies": { "asynckit": "^0.4.0", "combined-stream": "^1.0.8", "es-set-tostringtag": "^2.1.0", "hasown": "^2.0.2", "mime-types": "^2.1.35", "safe-buffer": "^5.2.1" } }, "sha512-jqdObeR2rxZZbPSGL+3VckHMYtu+f9//KXBsVny6JSX/pa38Fy+bGjuG8eW/H6USNQWhLi8Num++cU2yOCNz4A=="], "accepts/negotiator": ["negotiator@1.0.0", "", {}, "sha512-8Ofs/AUQh8MaEcrlq5xOX0CQ9ypTF5dl78mjlMNfOK08fzpgTHQRQPBxcPlEtIw0yRpws+Zo/3r+5WRby7u3Gg=="], @@ -3040,6 +3042,8 @@ "cross-spawn/which": ["which@2.0.2", "", { "dependencies": { "isexe": "^2.0.0" }, "bin": { "node-which": "./bin/node-which" } }, "sha512-BLI3Tl1TW3Pvl70l3yq3Y64i+awpwXqsGBYWkkqMtnbXgrMD+yj7rhW0kuEDxzJaYXGjEW5ogapKNMEKNMjibA=="], + "drizzle-orm/mssql": ["mssql@11.0.1", "", { "dependencies": { "@tediousjs/connection-string": "^0.5.0", "commander": "^11.0.0", "debug": "^4.3.3", "rfdc": "^1.3.0", "tarn": "^3.0.2", "tedious": "^18.2.1" }, "bin": { "mssql": "bin/mssql" } }, "sha512-KlGNsugoT90enKlR8/G36H0kTxPthDhmtNUCwEHvgRza5Cjpjoj+P2X6eMpFUDN7pFrJZsKadL4x990G8RBE1w=="], + "effect/@standard-schema/spec": ["@standard-schema/spec@1.1.0", "", {}, 
"sha512-l2aFy5jALhniG5HgqrD6jXLi/rUWrKvqN/qJx6yoJsgKhblVd+iqqU4RCXavm/jPityDo5TCvKMnpjKnOriy0w=="], "effect/yaml": ["yaml@2.8.2", "", { "bin": { "yaml": "bin.mjs" } }, "sha512-mplynKqc1C2hTVYxd0PU2xQAc22TI1vShAYGksCCfxbn/dFwnHTNi1bvYsBTkhdUNtGIf5xNOg938rrSSYvS9A=="], @@ -3164,8 +3168,6 @@ "tar-stream/bl": ["bl@4.1.0", "", { "dependencies": { "buffer": "^5.5.0", "inherits": "^2.0.4", "readable-stream": "^3.4.0" } }, "sha512-1W07cM9gS6DcLperZfFSj+bWLtaPGSOHWhPiGzXmvVJbRLdG82sH/Kn8EtW1VqWVA54AKf2h5k5BbnIbwF3h6w=="], - "tedious/iconv-lite": ["iconv-lite@0.6.3", "", { "dependencies": { "safer-buffer": ">= 2.1.2 < 3.0.0" } }, "sha512-4fCk79wshMdzMp2rH06qWrJE4iolqLhCUH+OiuIgU++RB0+94NlDL81atO7GX55uUKueo0txHNtvEyI6D7WdMw=="], - "teeny-request/http-proxy-agent": ["http-proxy-agent@5.0.0", "", { "dependencies": { "@tootallnate/once": "2", "agent-base": "6", "debug": "4" } }, "sha512-n2hY8YdoRE1i7r6M0w9DIw5GgZN0G25P8zLCRQ8rjXtTU3vsNFBI/vWK/UIeE6g5MUUz6avwAPXmL6Fy9D/90w=="], "teeny-request/https-proxy-agent": ["https-proxy-agent@5.0.1", "", { "dependencies": { "agent-base": "6", "debug": "4" } }, "sha512-dFcAjpTQFgoLMzC2VwU+C/CbS7uRL0lWmxDITmqm7C+7F0Odmj6s9l6alZc6AELXhrnggM2CeWSXHGOdX2YtwA=="], @@ -3518,6 +3520,8 @@ "@smithy/util-stream/@smithy/node-http-handler/@smithy/querystring-builder": ["@smithy/querystring-builder@4.2.8", "", { "dependencies": { "@smithy/types": "^4.12.0", "@smithy/util-uri-escape": "^4.2.0", "tslib": "^2.6.2" } }, "sha512-Xr83r31+DrE8CP3MqPgMJl+pQlLLmOfiEUnoyAlGzzJIrEsbKsPy1hqH0qySaQm4oWrCBlUqRt+idEgunKB+iw=="], + "@types/mssql/tedious/iconv-lite": ["iconv-lite@0.6.3", "", { "dependencies": { "safer-buffer": ">= 2.1.2 < 3.0.0" } }, "sha512-4fCk79wshMdzMp2rH06qWrJE4iolqLhCUH+OiuIgU++RB0+94NlDL81atO7GX55uUKueo0txHNtvEyI6D7WdMw=="], + "@types/request/form-data/mime-types": ["mime-types@2.1.35", "", { "dependencies": { "mime-db": "1.52.0" } }, "sha512-ZDY+bPm5zTTF+YpCrAU9nK0UgICYPT0QtT1NZWFv4s++TNkcgVaT0g6+4R2uI4MjQjzysHB1zxuWL50hzaeXiw=="], 
"ai-gateway-provider/@ai-sdk/amazon-bedrock/@ai-sdk/anthropic": ["@ai-sdk/anthropic@2.0.62", "", { "dependencies": { "@ai-sdk/provider": "2.0.1", "@ai-sdk/provider-utils": "3.0.21" }, "peerDependencies": { "zod": "^3.25.76 || ^4.1.8" } }, "sha512-I3RhaOEMnWlWnrvjNBOYvUb19Dwf2nw01IruZrVJRDi688886e11wnd5DxrBZLd2V29Gizo3vpOPnnExsA+wTA=="], @@ -3556,6 +3560,12 @@ "cross-spawn/which/isexe": ["isexe@2.0.0", "", {}, "sha512-RHxMLp9lnKHGHRng9QFhRCMbYAcVpn69smSGcq3f36xjgVVWThj4qqLbTLlq7Ssj8B+fIQ1EuCEGI2lKsyQeIw=="], + "drizzle-orm/mssql/@tediousjs/connection-string": ["@tediousjs/connection-string@0.5.0", "", {}, "sha512-7qSgZbincDDDFyRweCIEvZULFAw5iz/DeunhvuxpL31nfntX3P4Yd4HkHBRg9H8CdqY1e5WFN1PZIz/REL9MVQ=="], + + "drizzle-orm/mssql/commander": ["commander@11.1.0", "", {}, "sha512-yPVavfyCcRhmorC7rWlkHn15b4wDVgVmBA7kV4QVBsF7kv/9TKJAbAXVTxvTnwP8HHKjRCJDClKbciiYS7p0DQ=="], + + "drizzle-orm/mssql/tedious": ["tedious@18.6.2", "", { "dependencies": { "@azure/core-auth": "^1.7.2", "@azure/identity": "^4.2.1", "@azure/keyvault-keys": "^4.4.0", "@js-joda/core": "^5.6.1", "@types/node": ">=18", "bl": "^6.0.11", "iconv-lite": "^0.6.3", "js-md4": "^0.3.2", "native-duplexpair": "^1.0.0", "sprintf-js": "^1.1.3" } }, "sha512-g7jC56o3MzLkE3lHkaFe2ZdOVFBahq5bsB60/M4NYUbocw/MCrS89IOEQUFr+ba6pb8ZHczZ/VqCyYeYq0xBAg=="], + "form-data/mime-types/mime-db": ["mime-db@1.52.0", "", {}, "sha512-sPU4uV7dYlvtWJxwwxHD0PuihVNiE7TyAbQ5SWxDCB9mUYvOgroQOwYQQOKPJ8CIbE+1ETVlOoK1UC2nU3gYvg=="], "fs-minipass/minipass/yallist": ["yallist@4.0.0", "", {}, "sha512-3wdGidZyq5PB084XLES5TpOSRA3wjXAlIWMhum2kRcv/41Sn2emQ0dycQW4uZXLejwKvg6EsvbdlVL+FYEct7A=="], @@ -3778,6 +3788,8 @@ "cross-fetch/node-fetch/whatwg-url/webidl-conversions": ["webidl-conversions@3.0.1", "", {}, "sha512-2JAn3z8AR6rjK8Sm8orRC0h/bcl/DqL7tRPdGZ4I1CjdF+EaMLmYxBHyXuKL849eucPFhvBoxMsflfOb8kxaeQ=="], + "drizzle-orm/mssql/tedious/iconv-lite": ["iconv-lite@0.6.3", "", { "dependencies": { "safer-buffer": ">= 2.1.2 < 3.0.0" } }, 
"sha512-4fCk79wshMdzMp2rH06qWrJE4iolqLhCUH+OiuIgU++RB0+94NlDL81atO7GX55uUKueo0txHNtvEyI6D7WdMw=="], + "gaxios/rimraf/glob/jackspeak": ["jackspeak@3.4.3", "", { "dependencies": { "@isaacs/cliui": "^8.0.2" }, "optionalDependencies": { "@pkgjs/parseargs": "^0.11.0" } }, "sha512-OGlZQpz2yfahA/Rd1Y8Cd9SIEsqvXkLVoSw/cgwhnhFMDbsQFeZYoJJ7bIZBS9BcamUW96asq/npPWugM+RQBw=="], "gaxios/rimraf/glob/minimatch": ["minimatch@9.0.5", "", { "dependencies": { "brace-expansion": "^2.0.1" } }, "sha512-G6T0ZX48xgozx7587koeX9Ys2NYy6Gmv//P89sEte9V9whIapMNF4idKxnW2QtCcLiTWlb/wfCabAtAFWhhBow=="], diff --git a/packages/drivers/package.json b/packages/drivers/package.json index 98a0112cf..361c1dd96 100644 --- a/packages/drivers/package.json +++ b/packages/drivers/package.json @@ -17,7 +17,7 @@ "@google-cloud/bigquery": "^8.0.0", "@databricks/sql": "^1.0.0", "mysql2": "^3.0.0", - "mssql": "^11.0.0", + "mssql": "^12.0.0", "oracledb": "^6.0.0", "duckdb": "^1.0.0", "mongodb": "^6.0.0", diff --git a/packages/drivers/src/normalize.ts b/packages/drivers/src/normalize.ts index 5afc20cee..2d3c36127 100644 --- a/packages/drivers/src/normalize.ts +++ b/packages/drivers/src/normalize.ts @@ -65,6 +65,12 @@ const SQLSERVER_ALIASES: AliasMap = { ...COMMON_ALIASES, host: ["server", "serverName", "server_name"], trust_server_certificate: ["trustServerCertificate"], + authentication: ["authenticationType", "auth_type", "authentication_type"], + azure_tenant_id: ["tenantId", "tenant_id", "azureTenantId"], + azure_client_id: ["clientId", "client_id", "azureClientId"], + azure_client_secret: ["clientSecret", "client_secret", "azureClientSecret"], + access_token: ["token", "accessToken"], + azure_resource_url: ["azureResourceUrl", "resourceUrl", "resource_url"], } const ORACLE_ALIASES: AliasMap = { @@ -104,6 +110,7 @@ const DRIVER_ALIASES: Record = { mariadb: MYSQL_ALIASES, sqlserver: SQLSERVER_ALIASES, mssql: SQLSERVER_ALIASES, + fabric: SQLSERVER_ALIASES, oracle: ORACLE_ALIASES, mongodb: MONGODB_ALIASES, mongo: 
MONGODB_ALIASES, diff --git a/packages/drivers/src/sqlserver.ts b/packages/drivers/src/sqlserver.ts index 3ea1e390f..4b82c67bb 100644 --- a/packages/drivers/src/sqlserver.ts +++ b/packages/drivers/src/sqlserver.ts @@ -4,12 +4,74 @@ import type { ConnectionConfig, Connector, ConnectorResult, ExecuteOptions, SchemaColumn } from "./types" +// --------------------------------------------------------------------------- +// Azure AD helpers — cache + resource URL resolution +// --------------------------------------------------------------------------- + +// Module-scoped token cache, keyed by `${resource}|${clientId ?? ""}`. +// Tokens are reused across `connect()` calls in the same process and refreshed +// a few minutes before expiry. Fixes the issue where every new connection +// fetched a fresh token (wasteful, risks throttling) and long-lived diffs +// failed silently when the embedded token hit its ~1h TTL. +const tokenCache = new Map() +const TOKEN_REFRESH_MARGIN_MS = 5 * 60 * 1000 // refresh 5 minutes before expiry +const TOKEN_FALLBACK_TTL_MS = 50 * 60 * 1000 // used when JWT has no exp claim + +/** + * Parse the `exp` claim from a JWT access token (milliseconds since epoch). + * Returns undefined if the token isn't a JWT or has no exp claim. + */ +function parseTokenExpiry(token: string): number | undefined { + try { + const parts = token.split(".") + if (parts.length !== 3) return undefined + const payload = parts[1] + // base64url → base64 + padding + const padded = payload.replace(/-/g, "+").replace(/_/g, "/") + + "=".repeat((4 - (payload.length % 4)) % 4) + const decoded = Buffer.from(padded, "base64").toString("utf-8") + const claims = JSON.parse(decoded) + return typeof claims.exp === "number" ? claims.exp * 1000 : undefined + } catch { + return undefined + } +} + +/** + * Resolve the Azure resource URL for token acquisition. + * + * Preference order: + * 1. Explicit `config.azure_resource_url`. + * 2. Inferred from host suffix (Azure Gov / China). 
+ * 3. Default Azure commercial cloud. + */ +function resolveAzureResourceUrl(config: ConnectionConfig): string { + const explicit = config.azure_resource_url as string | undefined + if (explicit) return explicit + const host = (config.host as string | undefined) ?? "" + if (host.includes(".usgovcloudapi.net") || host.includes(".datawarehouse.fabric.microsoft.us")) { + return "https://database.usgovcloudapi.net/" + } + if (host.includes(".chinacloudapi.cn")) { + return "https://database.chinacloudapi.cn/" + } + return "https://database.windows.net/" +} + +/** Visible for testing: reset the module-scoped token cache. */ +export function _resetTokenCacheForTests(): void { + tokenCache.clear() +} + export async function connect(config: ConnectionConfig): Promise { let mssql: any + let MssqlConnectionPool: any try { // @ts-expect-error — mssql has no type declarations; installed as optional peerDependency - mssql = await import("mssql") - mssql = mssql.default || mssql + const mod = await import("mssql") + mssql = mod.default || mod + // ConnectionPool is a named export, not on .default + MssqlConnectionPool = mod.ConnectionPool ?? mssql.ConnectionPool } catch { throw new Error( "SQL Server driver not installed. Run: npm install mssql", @@ -24,8 +86,6 @@ export async function connect(config: ConnectionConfig): Promise { server: config.host ?? "127.0.0.1", port: config.port ?? 1433, database: config.database, - user: config.user, - password: config.password, options: { encrypt: config.encrypt ?? false, trustServerCertificate: config.trust_server_certificate ?? 
true, @@ -39,7 +99,198 @@ export async function connect(config: ConnectionConfig): Promise { }, } - pool = await mssql.connect(mssqlConfig) + // Normalize shorthand auth values to tedious-compatible types + const AUTH_SHORTHANDS: Record = { + cli: "azure-active-directory-default", + default: "azure-active-directory-default", + password: "azure-active-directory-password", + "service-principal": "azure-active-directory-service-principal-secret", + serviceprincipal: "azure-active-directory-service-principal-secret", + "managed-identity": "azure-active-directory-msi-vm", + msi: "azure-active-directory-msi-vm", + } + // `config.authentication` is typed as unknown upstream — accept only + // strings here. A caller passing a non-string (object, null, pre-built + // auth block) shouldn't crash with "toLowerCase is not a function"; + // treat as "no shorthand requested" and leave authType undefined. + const rawAuth = config.authentication + const authType = + typeof rawAuth === "string" + ? (AUTH_SHORTHANDS[rawAuth.toLowerCase()] ?? rawAuth) + : undefined + + if (authType?.startsWith("azure-active-directory")) { + ;(mssqlConfig.options as any).encrypt = true + + // Resolve a raw Azure AD access token. + // Used by both `azure-active-directory-default` and by + // `azure-active-directory-access-token` when no token was provided. + // + // We acquire the token ourselves rather than letting tedious do it because: + // 1. Bun can resolve @azure/identity to the browser bundle (inside + // tedious or even our own import), where DefaultAzureCredential + // is a non-functional stub that throws. + // 2. Passing a credential object via type:"token-credential" hits a + // CJS/ESM isTokenCredential boundary mismatch in Bun. + // + // Strategy: try @azure/identity first (works when module resolution + // is correct), fall back to shelling out to `az account get-access-token` + // (works everywhere Azure CLI is installed). 
+ // + // Tokens are cached module-scope keyed by (resource, client_id) and + // refreshed 5 minutes before expiry — reuses tokens across connections + // and prevents silent failures when embedded tokens hit their TTL. + const resourceUrl = resolveAzureResourceUrl(config) + const clientId = (config.azure_client_id as string | undefined) ?? "" + const cacheKey = `${resourceUrl}|${clientId}` + + const acquireAzureToken = async (): Promise => { + const cached = tokenCache.get(cacheKey) + if (cached && cached.expiresAt - Date.now() > TOKEN_REFRESH_MARGIN_MS) { + return cached.token + } + + let token: string | undefined + let expiresAt: number | undefined + let azureIdentityError: unknown = null + let azCliStderr = "" + + try { + const azureIdentity = await import("@azure/identity") + const credential = new azureIdentity.DefaultAzureCredential( + config.azure_client_id + ? { managedIdentityClientId: config.azure_client_id as string } + : undefined, + ) + const tokenResponse = await credential.getToken(`${resourceUrl}.default`) + if (tokenResponse?.token) { + token = tokenResponse.token + // @azure/identity provides expiresOnTimestamp (ms). Prefer it; fall + // back to parsing the JWT exp claim so both paths share the cache. + expiresAt = tokenResponse.expiresOnTimestamp ?? parseTokenExpiry(token) + } + } catch (err) { + azureIdentityError = err + // @azure/identity unavailable or browser bundle — fall through to CLI + } + + if (!token) { + try { + // Use async `exec` (not `execSync`) so the connection path stays + // non-blocking — `az account get-access-token` can take several + // seconds to round-trip and execSync would block the entire + // event loop for that duration. 
+ const childProcess = await import("node:child_process") + const { promisify } = await import("node:util") + const execAsync = promisify(childProcess.exec) + const { stdout } = await execAsync( + `az account get-access-token --resource ${resourceUrl} --query accessToken -o tsv`, + { encoding: "utf-8", timeout: 15000 }, + ) + const out = stdout.trim() + if (out) { + token = out + expiresAt = parseTokenExpiry(out) + } + } catch (err: any) { + // Capture stderr so the final error message can hint at the root cause + // (e.g. "Please run 'az login'", "subscription not found"). + azCliStderr = String(err?.stderr ?? err?.message ?? "").slice(0, 200).trim() + } + } + + if (!token) { + const hints: string[] = [] + if (azureIdentityError) hints.push(`@azure/identity: ${String(azureIdentityError).slice(0, 120)}`) + if (azCliStderr) hints.push(`az CLI: ${azCliStderr}`) + const detail = hints.length > 0 ? ` (${hints.join("; ")})` : "" + throw new Error( + `Azure AD token acquisition failed${detail}. Either install @azure/identity (npm install @azure/identity) ` + + "or log in with Azure CLI (az login).", + ) + } + + tokenCache.set(cacheKey, { + token, + expiresAt: expiresAt ?? Date.now() + TOKEN_FALLBACK_TTL_MS, + }) + return token + } + + if (authType === "azure-active-directory-default") { + mssqlConfig.authentication = { + type: "azure-active-directory-access-token", + options: { token: await acquireAzureToken() }, + } + } else if (authType === "azure-active-directory-password") { + mssqlConfig.authentication = { + type: "azure-active-directory-password", + options: { + userName: config.user, + password: config.password, + clientId: config.azure_client_id, + tenantId: config.azure_tenant_id, + }, + } + } else if (authType === "azure-active-directory-access-token") { + // If the caller supplied a token, use it; otherwise acquire one + // automatically (DefaultAzureCredential → az CLI). + const suppliedToken = (config.token ?? 
config.access_token) as string | undefined + mssqlConfig.authentication = { + type: "azure-active-directory-access-token", + options: { token: suppliedToken ?? (await acquireAzureToken()) }, + } + } else if ( + authType === "azure-active-directory-msi-vm" || + authType === "azure-active-directory-msi-app-service" + ) { + mssqlConfig.authentication = { + type: authType, + options: { + ...(config.azure_client_id ? { clientId: config.azure_client_id } : {}), + }, + } + } else if (authType === "azure-active-directory-service-principal-secret") { + mssqlConfig.authentication = { + type: "azure-active-directory-service-principal-secret", + options: { + clientId: config.azure_client_id, + clientSecret: config.azure_client_secret, + tenantId: config.azure_tenant_id, + }, + } + } else { + // Any other `azure-active-directory-*` subtype (typo or future + // tedious addition). Fail fast — otherwise we'd silently connect + // with no `authentication` block and tedious would surface an + // opaque error far from the root cause. + throw new Error( + `Unsupported Azure AD authentication subtype: "${authType}". ` + + "Supported subtypes: azure-active-directory-default, " + + "azure-active-directory-password, azure-active-directory-access-token, " + + "azure-active-directory-msi-vm, azure-active-directory-msi-app-service, " + + "azure-active-directory-service-principal-secret.", + ) + } + } else { + // Standard SQL Server user/password + mssqlConfig.user = config.user + mssqlConfig.password = config.password + } + + // Use an explicit ConnectionPool (not the global mssql.connect()) so + // multiple simultaneous connections to different servers are isolated. + // `mssql@^12` guarantees ConnectionPool as a named export — if it's + // missing, the installed driver version is too old. Fail fast rather + // than silently use the global shared pool (which reintroduces the + // cross-database interference bug this branch was added to fix). 
+ if (!MssqlConnectionPool) { + throw new Error( + "mssql.ConnectionPool is not available — the installed `mssql` package is too old. Upgrade to mssql@^12.", + ) + } + pool = new MssqlConnectionPool(mssqlConfig) + await pool.connect() }, async execute(sql: string, limit?: number, _binds?: any[], options?: ExecuteOptions): Promise { @@ -62,22 +313,56 @@ export async function connect(config: ConnectionConfig): Promise { } const result = await pool.request().query(query) - const rows = result.recordset ?? [] - const columns = - rows.length > 0 - ? Object.keys(rows[0]).filter((k) => !k.startsWith("_")) - : (result.recordset?.columns - ? Object.keys(result.recordset.columns) - : []) - const truncated = effectiveLimit > 0 && rows.length > effectiveLimit - const limitedRows = truncated ? rows.slice(0, effectiveLimit) : rows + const recordset = result.recordset ?? [] + const truncated = effectiveLimit > 0 && recordset.length > effectiveLimit + const limitedRecordset = truncated ? recordset.slice(0, effectiveLimit) : recordset + + // mssql merges unnamed columns (e.g. SELECT COUNT(*), SUM(...)) into a + // single array under the empty-string key: row[""] = [val1, val2, ...]. + // When a query mixes named and unnamed columns (e.g. + // SELECT name, COUNT(*), SUM(x) → { name: "alice", "": [42, 100] }), + // we must preserve the known header for `name` and synthesize col_N only + // for the unnamed positions. Build columns and rows in a single pass so + // they stay aligned regardless of how many unnamed values the row + // contains. 
+ let columns: string[] = [] + let columnsBuilt = false + const flatten = (row: any): any[] => { + const vals: any[] = [] + let unnamedCounter = 0 + const entries = Object.entries(row) + for (const [k, v] of entries) { + if (k === "" && Array.isArray(v)) { + for (const inner of v) { + if (!columnsBuilt) columns.push(`col_${unnamedCounter}`) + unnamedCounter++ + vals.push(inner) + } + } else if (k === "") { + // Empty-string key with non-array value — rare edge case, give it + // a synthetic name rather than producing a column named "". + if (!columnsBuilt) columns.push(`col_${unnamedCounter}`) + unnamedCounter++ + vals.push(v) + } else { + if (!columnsBuilt) columns.push(k) + vals.push(v) + } + } + columnsBuilt = true + return vals + } + + const rows = limitedRecordset.map(flatten) + if (!columnsBuilt) { + // No rows — fall back to driver-reported column metadata. + columns = result.recordset?.columns ? Object.keys(result.recordset.columns) : [] + } return { columns, - rows: limitedRows.map((row: any) => - columns.map((col) => row[col]), - ), - row_count: limitedRows.length, + rows, + row_count: rows.length, truncated, } }, diff --git a/packages/drivers/test/sqlserver-unit.test.ts b/packages/drivers/test/sqlserver-unit.test.ts new file mode 100644 index 000000000..f8f93af22 --- /dev/null +++ b/packages/drivers/test/sqlserver-unit.test.ts @@ -0,0 +1,765 @@ +/** + * Unit tests for SQL Server driver logic: + * - TOP injection (vs LIMIT) + * - Truncation detection + * - Azure AD authentication (7 flows) + * - Schema introspection queries + * - Connection lifecycle + * - Result format mapping + */ +import { describe, test, expect, mock, beforeEach } from "bun:test" + +// --- Mock mssql --- + +let mockQueryCalls: string[] = [] +let mockQueryResult: any = { recordset: [] } +let mockConnectCalls: any[] = [] +let mockCloseCalls = 0 +let mockInputs: Array<{ name: string; value: any }> = [] + +function resetMocks() { + mockQueryCalls = [] + mockQueryResult = { recordset: [] } 
+ mockConnectCalls = [] + mockCloseCalls = 0 + mockInputs = [] +} + +function createMockRequest() { + const req: any = { + input(name: string, value: any) { + mockInputs.push({ name, value }) + return req + }, + async query(sql: string) { + mockQueryCalls.push(sql) + return mockQueryResult + }, + } + return req +} + +function createMockPool(config: any) { + mockConnectCalls.push(config) + return { + connect: async () => {}, + request: () => createMockRequest(), + close: async () => { + mockCloseCalls++ + }, + } +} + +mock.module("mssql", () => ({ + default: { + connect: async (config: any) => createMockPool(config), + }, + ConnectionPool: class { + _pool: any + constructor(config: any) { + this._pool = createMockPool(config) + } + async connect() { return this._pool.connect() } + request() { return this._pool.request() } + async close() { return this._pool.close() } + }, +})) + +// Exposed to individual tests so they can assert scope / force failures. +const azureIdentityState = { + lastScope: "" as string, + tokenOverride: null as null | { token: string; expiresOnTimestamp?: number }, + throwOnGetToken: false as boolean, +} +mock.module("@azure/identity", () => ({ + DefaultAzureCredential: class { + _opts: any + constructor(opts?: any) { this._opts = opts } + async getToken(scope: string) { + azureIdentityState.lastScope = scope + if (azureIdentityState.throwOnGetToken) throw new Error("mock identity failure") + if (azureIdentityState.tokenOverride) return azureIdentityState.tokenOverride + return { token: "mock-azure-token-12345", expiresOnTimestamp: Date.now() + 3600000 } + } + }, +})) + +// Exposed to tests to stub the `az` CLI fallback. 
+const cliState = { + lastCmd: "" as string, + output: "mock-cli-token-fallback\n" as string, + throwError: null as null | { stderr?: string; message?: string }, +} +const realChildProcess = await import("node:child_process") +const realUtil = await import("node:util") +// Stub `exec` with a custom `util.promisify.custom` so `promisify(exec)` +// yields { stdout, stderr } exactly as the real implementation does. Also +// keep the legacy callback form of `execSync` for tests that still use it. +const execStub: any = (cmd: string, optsOrCb: any, maybeCb?: any) => { + cliState.lastCmd = cmd + const cb = typeof optsOrCb === "function" ? optsOrCb : maybeCb + if (cliState.throwError) { + const e: any = new Error(cliState.throwError.message ?? "az failed") + e.stderr = cliState.throwError.stderr + if (cb) cb(e, "", cliState.throwError.stderr ?? "") + return { on() {}, stdout: null, stderr: null } + } + if (cb) cb(null, cliState.output, "") + return { on() {}, stdout: null, stderr: null } +} +execStub[realUtil.promisify.custom] = (cmd: string, _opts?: any) => { + cliState.lastCmd = cmd + if (cliState.throwError) { + const e: any = new Error(cliState.throwError.message ?? "az failed") + e.stderr = cliState.throwError.stderr + return Promise.reject(e) + } + return Promise.resolve({ stdout: cliState.output, stderr: "" }) +} +mock.module("node:child_process", () => ({ + ...realChildProcess, + execSync: (cmd: string) => { + cliState.lastCmd = cmd + if (cliState.throwError) { + const e: any = new Error(cliState.throwError.message ?? 
"az failed") + e.stderr = cliState.throwError.stderr + throw e + } + return cliState.output + }, + exec: execStub, +})) + +// Import after mocking +const { connect, _resetTokenCacheForTests } = await import("../src/sqlserver") + +describe("SQL Server driver unit tests", () => { + let connector: Awaited> + + beforeEach(async () => { + resetMocks() + connector = await connect({ host: "localhost", port: 1433, database: "testdb", user: "sa", password: "pass" }) + await connector.connect() + }) + + // --- TOP injection --- + + describe("TOP injection", () => { + test("injects TOP for SELECT without one", async () => { + mockQueryResult = { recordset: [{ id: 1, name: "a" }] } + await connector.execute("SELECT * FROM t") + expect(mockQueryCalls[0]).toContain("TOP 1001") + }) + + test("does NOT double-TOP when TOP already present", async () => { + mockQueryResult = { recordset: [{ id: 1 }] } + await connector.execute("SELECT TOP 5 * FROM t") + expect(mockQueryCalls[0]).toBe("SELECT TOP 5 * FROM t") + }) + + test("does NOT inject TOP when LIMIT present", async () => { + mockQueryResult = { recordset: [] } + await connector.execute("SELECT * FROM t LIMIT 10") + expect(mockQueryCalls[0]).toBe("SELECT * FROM t LIMIT 10") + }) + + test("noLimit bypasses TOP injection", async () => { + mockQueryResult = { recordset: [] } + await connector.execute("SELECT * FROM t", undefined, undefined, { noLimit: true }) + expect(mockQueryCalls[0]).toBe("SELECT * FROM t") + }) + + test("uses custom limit value", async () => { + mockQueryResult = { recordset: [] } + await connector.execute("SELECT * FROM t", 50) + expect(mockQueryCalls[0]).toContain("TOP 51") + }) + + test("default limit is 1000", async () => { + mockQueryResult = { recordset: [] } + await connector.execute("SELECT * FROM t") + expect(mockQueryCalls[0]).toContain("TOP 1001") + }) + }) + + // --- Truncation --- + + describe("truncation detection", () => { + test("detects truncation when rows exceed limit", async () => { + const 
rows = Array.from({ length: 11 }, (_, i) => ({ id: i })) + mockQueryResult = { recordset: rows } + const result = await connector.execute("SELECT * FROM t", 10) + expect(result.truncated).toBe(true) + expect(result.rows.length).toBe(10) + }) + + test("no truncation when rows at or below limit", async () => { + mockQueryResult = { recordset: [{ id: 1 }, { id: 2 }] } + const result = await connector.execute("SELECT * FROM t", 10) + expect(result.truncated).toBe(false) + }) + + test("empty result returns correctly", async () => { + // mssql exposes column metadata as `recordset.columns` (a property ON + // the recordset array), not as a sibling key — mirror the real shape. + const recordset: any[] = [] + ;(recordset as any).columns = {} + mockQueryResult = { recordset } + const result = await connector.execute("SELECT * FROM t") + expect(result.rows).toEqual([]) + expect(result.truncated).toBe(false) + }) + }) + + // --- Azure AD authentication --- + + describe("Azure AD authentication", () => { + test("standard auth uses user/password directly", async () => { + resetMocks() + const c = await connect({ host: "localhost", database: "db", user: "sa", password: "pass" }) + await c.connect() + const cfg = mockConnectCalls[0] + expect(cfg.user).toBe("sa") + expect(cfg.password).toBe("pass") + expect(cfg.authentication).toBeUndefined() + }) + + test("azure-active-directory-password builds correct auth object", async () => { + resetMocks() + const c = await connect({ + host: "myserver.database.windows.net", + database: "db", + user: "user@domain.com", + password: "secret", + authentication: "azure-active-directory-password", + azure_client_id: "client-123", + azure_tenant_id: "tenant-456", + }) + await c.connect() + const cfg = mockConnectCalls[0] + expect(cfg.authentication).toEqual({ + type: "azure-active-directory-password", + options: { + userName: "user@domain.com", + password: "secret", + clientId: "client-123", + tenantId: "tenant-456", + }, + }) + 
expect(cfg.user).toBeUndefined() + expect(cfg.password).toBeUndefined() + }) + + test("azure-active-directory-access-token passes supplied token unchanged", async () => { + resetMocks() + const c = await connect({ + host: "myserver.database.windows.net", + database: "db", + authentication: "azure-active-directory-access-token", + access_token: "eyJhbGciOi...", + }) + await c.connect() + const cfg = mockConnectCalls[0] + expect(cfg.authentication).toEqual({ + type: "azure-active-directory-access-token", + options: { token: "eyJhbGciOi..." }, + }) + }) + + test("azure-active-directory-access-token with no token auto-acquires one", async () => { + // Regression: prior to this, omitting `token`/`access_token` resulted in + // `options.token: undefined`, which tedious rejects with + // "config.authentication.options.token must be of type string". + resetMocks() + const c = await connect({ + host: "myserver.database.windows.net", + database: "db", + authentication: "azure-active-directory-access-token", + }) + await c.connect() + const cfg = mockConnectCalls[0] + expect(cfg.authentication.type).toBe("azure-active-directory-access-token") + expect(cfg.authentication.options.token).toBe("mock-azure-token-12345") + }) + + test("azure-active-directory-service-principal-secret builds SP auth", async () => { + resetMocks() + const c = await connect({ + host: "myserver.database.windows.net", + database: "db", + authentication: "azure-active-directory-service-principal-secret", + azure_client_id: "sp-client", + azure_client_secret: "sp-secret", + azure_tenant_id: "sp-tenant", + }) + await c.connect() + const cfg = mockConnectCalls[0] + expect(cfg.authentication).toEqual({ + type: "azure-active-directory-service-principal-secret", + options: { + clientId: "sp-client", + clientSecret: "sp-secret", + tenantId: "sp-tenant", + }, + }) + }) + + test("azure-active-directory-msi-vm builds MSI auth with optional clientId", async () => { + resetMocks() + const c = await connect({ + host: 
"myserver.database.windows.net", + database: "db", + authentication: "azure-active-directory-msi-vm", + azure_client_id: "msi-client", + }) + await c.connect() + const cfg = mockConnectCalls[0] + expect(cfg.authentication).toEqual({ + type: "azure-active-directory-msi-vm", + options: { clientId: "msi-client" }, + }) + }) + + test("azure-active-directory-msi-app-service works without clientId", async () => { + resetMocks() + const c = await connect({ + host: "myserver.database.windows.net", + database: "db", + authentication: "azure-active-directory-msi-app-service", + }) + await c.connect() + const cfg = mockConnectCalls[0] + expect(cfg.authentication).toEqual({ + type: "azure-active-directory-msi-app-service", + options: {}, + }) + }) + + test("azure-active-directory-default acquires token and passes as access-token", async () => { + resetMocks() + const c = await connect({ + host: "myserver.database.windows.net", + database: "db", + authentication: "azure-active-directory-default", + }) + await c.connect() + const cfg = mockConnectCalls[0] + expect(cfg.authentication.type).toBe("azure-active-directory-access-token") + expect(cfg.authentication.options.token).toBe("mock-azure-token-12345") + }) + + test("azure-active-directory-default with client_id passes managedIdentityClientId to credential", async () => { + resetMocks() + const c = await connect({ + host: "myserver.database.windows.net", + database: "db", + authentication: "azure-active-directory-default", + azure_client_id: "mi-client-id", + }) + await c.connect() + const cfg = mockConnectCalls[0] + // Token is still passed as access-token regardless of client_id + expect(cfg.authentication.type).toBe("azure-active-directory-access-token") + expect(cfg.authentication.options.token).toBe("mock-azure-token-12345") + }) + + test("encryption forced for all Azure AD connections", async () => { + resetMocks() + const c = await connect({ + host: "myserver.database.windows.net", + database: "db", + authentication: 
"azure-active-directory-password", + user: "u", + password: "p", + }) + await c.connect() + const cfg = mockConnectCalls[0] + expect(cfg.options.encrypt).toBe(true) + }) + + test("standard auth does not force encryption", async () => { + resetMocks() + const c = await connect({ host: "localhost", database: "db", user: "sa", password: "pass" }) + await c.connect() + const cfg = mockConnectCalls[0] + expect(cfg.options.encrypt).toBe(false) + }) + + test("'CLI' shorthand acquires token via DefaultAzureCredential", async () => { + resetMocks() + const c = await connect({ + host: "myserver.datawarehouse.fabric.microsoft.com", + database: "migration", + authentication: "CLI", + }) + await c.connect() + const cfg = mockConnectCalls[0] + expect(cfg.authentication.type).toBe("azure-active-directory-access-token") + expect(cfg.authentication.options.token).toBe("mock-azure-token-12345") + expect(cfg.options.encrypt).toBe(true) + }) + + test("'service-principal' shorthand maps correctly", async () => { + resetMocks() + const c = await connect({ + host: "myserver.database.windows.net", + database: "db", + authentication: "service-principal", + azure_client_id: "cid", + azure_client_secret: "csec", + azure_tenant_id: "tid", + }) + await c.connect() + const cfg = mockConnectCalls[0] + expect(cfg.authentication.type).toBe("azure-active-directory-service-principal-secret") + expect(cfg.authentication.options.clientId).toBe("cid") + }) + + test("'msi' shorthand maps to azure-active-directory-msi-vm", async () => { + resetMocks() + const c = await connect({ + host: "myserver.database.windows.net", + database: "db", + authentication: "msi", + }) + await c.connect() + const cfg = mockConnectCalls[0] + expect(cfg.authentication.type).toBe("azure-active-directory-msi-vm") + }) + }) + + // --- Schema introspection --- + + describe("schema introspection", () => { + test("listSchemas queries sys.schemas", async () => { + mockQueryResult = { recordset: [{ name: "dbo" }, { name: "sales" }] } 
+ const schemas = await connector.listSchemas() + expect(mockQueryCalls[0]).toContain("sys.schemas") + expect(schemas).toEqual(["dbo", "sales"]) + }) + + test("listTables queries sys.tables and sys.views", async () => { + mockQueryResult = { + recordset: [ + { name: "orders", type: "U " }, + { name: "order_summary", type: "V" }, + ], + } + const tables = await connector.listTables("dbo") + expect(mockQueryCalls[0]).toContain("UNION ALL") + expect(mockQueryCalls[0]).toContain("sys.tables") + expect(mockQueryCalls[0]).toContain("sys.views") + expect(tables).toEqual([ + { name: "orders", type: "table" }, + { name: "order_summary", type: "view" }, + ]) + }) + + test("describeTable queries sys.columns", async () => { + mockQueryResult = { + recordset: [ + { column_name: "id", data_type: "int", is_nullable: 0 }, + { column_name: "name", data_type: "nvarchar", is_nullable: 1 }, + ], + } + const cols = await connector.describeTable("dbo", "users") + expect(mockQueryCalls[0]).toContain("sys.columns") + expect(cols).toEqual([ + { name: "id", data_type: "int", nullable: false }, + { name: "name", data_type: "nvarchar", nullable: true }, + ]) + }) + }) + + // --- Connection lifecycle --- + + describe("connection lifecycle", () => { + test("close is idempotent", async () => { + await connector.close() + await connector.close() + expect(mockCloseCalls).toBe(1) + }) + }) + + // --- Result format --- + + describe("result format", () => { + test("maps recordset to column-ordered arrays", async () => { + mockQueryResult = { + recordset: [ + { id: 1, name: "alice", age: 30 }, + { id: 2, name: "bob", age: 25 }, + ], + } + const result = await connector.execute("SELECT id, name, age FROM t") + expect(result.columns).toEqual(["id", "name", "age"]) + expect(result.rows).toEqual([ + [1, "alice", 30], + [2, "bob", 25], + ]) + }) + + test("preserves underscore-prefixed columns", async () => { + mockQueryResult = { + recordset: [{ id: 1, _p: "Delivered", name: "x" }], + } + const result = 
await connector.execute("SELECT * FROM t") + expect(result.columns).toEqual(["id", "_p", "name"]) + }) + }) + + // --- Unnamed column flattening --- + + describe("unnamed column flattening", () => { + test("flattens unnamed columns merged under empty-string key", async () => { + // mssql merges SELECT COUNT(*), SUM(amount) into row[""] = [42, 1000] + mockQueryResult = { + recordset: [{ "": [42, 1000] }], + } + const result = await connector.execute("SELECT COUNT(*), SUM(amount) FROM t") + expect(result.rows).toEqual([[42, 1000]]) + expect(result.columns).toEqual(["col_0", "col_1"]) + }) + + test("preserves legitimate array values from named columns", async () => { + // A named column containing an array (e.g. from JSON aggregation) + // should NOT be spread — only the empty-string key gets flattened + mockQueryResult = { + recordset: [{ id: 1, tags: ["a", "b", "c"] }], + } + const result = await connector.execute("SELECT * FROM t") + expect(result.columns).toEqual(["id", "tags"]) + expect(result.rows).toEqual([[1, ["a", "b", "c"]]]) + }) + + test("handles mix of named and unnamed columns", async () => { + mockQueryResult = { + recordset: [{ name: "alice", "": [42] }], + } + const result = await connector.execute("SELECT * FROM t") + // Named header preserved; single unnamed aggregate synthesized. + expect(result.columns).toEqual(["name", "col_0"]) + expect(result.rows).toEqual([["alice", 42]]) + }) + + test("mixed named + MULTIPLE unnamed aggregates keep named header", async () => { + // SELECT name, COUNT(*), SUM(x) FROM t → { name: "alice", "": [42, 100] }. + // Regression: previous implementation fell back to col_0..col_N for all + // columns, erasing the known `name` header. 
+ mockQueryResult = { + recordset: [{ name: "alice", "": [42, 100] }], + } + const result = await connector.execute("SELECT name, COUNT(*), SUM(x) FROM t") + expect(result.columns).toEqual(["name", "col_0", "col_1"]) + expect(result.rows).toEqual([["alice", 42, 100]]) + }) + + test("single unnamed column gets synthetic name (no blank header)", async () => { + // SELECT COUNT(*) FROM t → { "": [5] } + mockQueryResult = { + recordset: [{ "": [5] }], + } + const result = await connector.execute("SELECT COUNT(*) FROM t") + expect(result.columns).toEqual(["col_0"]) + expect(result.columns).not.toContain("") + expect(result.rows).toEqual([[5]]) + }) + }) + + // --- Azure token caching (Fix #2) --- + + describe("Azure token cache", () => { + beforeEach(() => { + _resetTokenCacheForTests() + azureIdentityState.throwOnGetToken = false + azureIdentityState.tokenOverride = null + cliState.throwError = null + cliState.output = "mock-cli-token-fallback\n" + }) + + test("second connect with same (resource, clientId) reuses cached token", async () => { + let getTokenCalls = 0 + azureIdentityState.tokenOverride = { token: "cached-token-A", expiresOnTimestamp: Date.now() + 3600_000 } + // Hook getToken counter + const origCredential = (await import("@azure/identity")).DefaultAzureCredential + const origGetToken = origCredential.prototype.getToken + origCredential.prototype.getToken = async function (scope: string) { + getTokenCalls++ + return origGetToken.call(this, scope) + } + try { + resetMocks() + const c1 = await connect({ + host: "h.database.windows.net", database: "d", + authentication: "azure-active-directory-default", + }) + await c1.connect() + const c2 = await connect({ + host: "h.database.windows.net", database: "d", + authentication: "azure-active-directory-default", + }) + await c2.connect() + expect(getTokenCalls).toBe(1) + // Both pool configs embed the same cached token + expect(mockConnectCalls[0].authentication.options.token).toBe("cached-token-A") + 
expect(mockConnectCalls[1].authentication.options.token).toBe("cached-token-A") + } finally { + origCredential.prototype.getToken = origGetToken + } + }) + + test("near-expiry token triggers refresh", async () => { + // First token expires in 1 minute (well under the 5-minute refresh margin) + azureIdentityState.tokenOverride = { token: "about-to-expire", expiresOnTimestamp: Date.now() + 60_000 } + resetMocks() + const c1 = await connect({ + host: "h.database.windows.net", database: "d", + authentication: "azure-active-directory-default", + }) + await c1.connect() + // Now change the mock to issue a new token on refresh + azureIdentityState.tokenOverride = { token: "fresh-token", expiresOnTimestamp: Date.now() + 3600_000 } + const c2 = await connect({ + host: "h.database.windows.net", database: "d", + authentication: "azure-active-directory-default", + }) + await c2.connect() + expect(mockConnectCalls[0].authentication.options.token).toBe("about-to-expire") + expect(mockConnectCalls[1].authentication.options.token).toBe("fresh-token") + }) + + test("different clientIds cache separately", async () => { + // Prove cache keying by counting distinct getToken invocations: with + // separate clientIds we expect 2 calls (one per key); with a shared + // clientId we expect 1 on the second connect. 
+ let getTokenCalls = 0 + azureIdentityState.tokenOverride = { token: "shared-token", expiresOnTimestamp: Date.now() + 3600_000 } + const origCredential = (await import("@azure/identity")).DefaultAzureCredential + const origGetToken = origCredential.prototype.getToken + origCredential.prototype.getToken = async function (scope: string) { + getTokenCalls++ + return origGetToken.call(this, scope) + } + try { + resetMocks() + const a = await connect({ + host: "h.database.windows.net", database: "d", + authentication: "azure-active-directory-default", + azure_client_id: "client-1", + }) + await a.connect() + const b = await connect({ + host: "h.database.windows.net", database: "d", + authentication: "azure-active-directory-default", + azure_client_id: "client-2", + }) + await b.connect() + // Two distinct client IDs → two distinct cache entries → two getToken + // calls. If the cache were keyed only on resource URL this would be 1. + expect(getTokenCalls).toBe(2) + expect(mockConnectCalls[0].authentication.options.token).toBe("shared-token") + expect(mockConnectCalls[1].authentication.options.token).toBe("shared-token") + + // Reconnect with client-1 again — should hit the cache, no new getToken + const c = await connect({ + host: "h.database.windows.net", database: "d", + authentication: "azure-active-directory-default", + azure_client_id: "client-1", + }) + await c.connect() + expect(getTokenCalls).toBe(2) + } finally { + origCredential.prototype.getToken = origGetToken + } + }) + }) + + // --- Configurable / inferred Azure resource URL (Fix #5) --- + + describe("Azure resource URL resolution", () => { + beforeEach(() => { + _resetTokenCacheForTests() + azureIdentityState.throwOnGetToken = false + azureIdentityState.tokenOverride = null + cliState.throwError = null + }) + + test("commercial cloud: default to database.windows.net", async () => { + resetMocks() + const c = await connect({ + host: "myserver.database.windows.net", database: "d", + authentication: 
"azure-active-directory-default", + }) + await c.connect() + expect(azureIdentityState.lastScope).toBe("https://database.windows.net/.default") + }) + + test("Azure Government host infers usgovcloudapi.net", async () => { + resetMocks() + const c = await connect({ + host: "myserver.database.usgovcloudapi.net", database: "d", + authentication: "azure-active-directory-default", + }) + await c.connect() + expect(azureIdentityState.lastScope).toBe("https://database.usgovcloudapi.net/.default") + }) + + test("Azure China host infers chinacloudapi.cn", async () => { + resetMocks() + const c = await connect({ + host: "myserver.database.chinacloudapi.cn", database: "d", + authentication: "azure-active-directory-default", + }) + await c.connect() + expect(azureIdentityState.lastScope).toBe("https://database.chinacloudapi.cn/.default") + }) + + test("explicit azure_resource_url wins over host inference", async () => { + resetMocks() + const c = await connect({ + host: "myserver.database.windows.net", // commercial host + database: "d", + authentication: "azure-active-directory-default", + azure_resource_url: "https://custom.sovereign.example/", + }) + await c.connect() + expect(azureIdentityState.lastScope).toBe("https://custom.sovereign.example/.default") + }) + + test("az CLI fallback uses the same resource URL", async () => { + // Disable @azure/identity so we hit the az CLI fallback + azureIdentityState.throwOnGetToken = true + cliState.output = "eyJ.eyJ.sig\n" // looks like JWT; parseTokenExpiry returns undefined → fallback TTL + resetMocks() + const c = await connect({ + host: "myserver.database.usgovcloudapi.net", database: "d", + authentication: "azure-active-directory-default", + }) + await c.connect() + expect(cliState.lastCmd).toContain("--resource https://database.usgovcloudapi.net/") + }) + }) + + // --- Error surfacing when auth fails (Fix #5 bonus, Minor #10 addressed) --- + + describe("Azure auth error surfacing", () => { + beforeEach(() => { + 
_resetTokenCacheForTests() + azureIdentityState.throwOnGetToken = false + azureIdentityState.tokenOverride = null + cliState.throwError = null + }) + + test("both @azure/identity and az CLI fail → error includes both hints", async () => { + azureIdentityState.throwOnGetToken = true + cliState.throwError = { stderr: "Please run 'az login' to set up an account.", message: "failed" } + resetMocks() + const c = await connect({ + host: "h.database.windows.net", database: "d", + authentication: "azure-active-directory-default", + }) + await expect(c.connect()).rejects.toThrow(/Azure AD token acquisition failed/) + await expect(c.connect()).rejects.toThrow(/az CLI:.*az login/) + }) + }) +}) diff --git a/packages/opencode/src/altimate/native/connections/data-diff.ts b/packages/opencode/src/altimate/native/connections/data-diff.ts index 294c43745..3f78c0398 100644 --- a/packages/opencode/src/altimate/native/connections/data-diff.ts +++ b/packages/opencode/src/altimate/native/connections/data-diff.ts @@ -10,6 +10,50 @@ import type { DataDiffParams, DataDiffResult, PartitionDiffResult } from "../types" import * as Registry from "./registry" +// --------------------------------------------------------------------------- +// Dialect mapping — bridge warehouse config types to Rust SqlDialect serde names +// --------------------------------------------------------------------------- + +/** Map warehouse config types to Rust SqlDialect serde names. */ +const WAREHOUSE_TO_DIALECT: Record<string, string> = { + sqlserver: "tsql", + mssql: "tsql", + fabric: "fabric", + postgresql: "postgres", + mariadb: "mysql", +} + +/** Convert a warehouse config type to the Rust-compatible SqlDialect name. */ +export function warehouseTypeToDialect(warehouseType: string): string { + return WAREHOUSE_TO_DIALECT[warehouseType.toLowerCase()] ??
warehouseType.toLowerCase() +} + +// --------------------------------------------------------------------------- +// Dialect-aware identifier quoting +// --------------------------------------------------------------------------- + +/** + * Quote a SQL identifier using the correct delimiter for the dialect. + * Used both for partition column/value quoting and for plain-table-name + * wrapping inside CTEs (via `resolveTableSources`). + */ +function quoteIdentForDialect(identifier: string, dialect: string): string { + switch (dialect) { + case "mysql": + case "mariadb": + case "clickhouse": + return `\`${identifier.replace(/`/g, "``")}\`` + case "tsql": + case "fabric": + case "sqlserver": + case "mssql": + return `[${identifier.replace(/\]/g, "]]")}]` + default: + // ANSI SQL: Postgres, Snowflake, BigQuery, DuckDB, Oracle, Redshift, etc. + return `"${identifier.replace(/"/g, '""')}"` + } +} + // --------------------------------------------------------------------------- // Query-source detection // --------------------------------------------------------------------------- @@ -18,49 +62,82 @@ const SQL_KEYWORDS = /^\s*(SELECT|WITH|VALUES)\b/i /** * Detect whether a string is an arbitrary SQL query (vs a plain table name). - * Plain table names may contain dots (schema.table, db.schema.table) but not spaces. + * + * A SQL query starts with a keyword AND contains whitespace (e.g., "SELECT * FROM ..."). + * A plain table name — even one named "select" or "with" — is a single token without + * internal whitespace (possibly dot-separated like schema.table or db.schema.table). + * + * The \b in SQL_KEYWORDS already prevents matching "with_metadata" or "select_results", + * but the whitespace check additionally handles bare keyword table names like "select". 
*/ function isQuery(input: string): boolean { - return SQL_KEYWORDS.test(input) + const trimmed = input.trim() + return SQL_KEYWORDS.test(trimmed) && /\s/.test(trimmed) } /** * If either source or target is an arbitrary query, wrap them in CTEs so the * DataParity engine can treat them as tables named `__diff_source` / `__diff_target`. * - * Returns `{ table1Name, table2Name, ctePrefix | null }`. + * Returns both a combined prefix (used for same-warehouse tasks where a JOIN + * might reference both CTEs) and side-specific prefixes (used for cross-warehouse + * tasks where each warehouse only has access to its own base tables). * - * When a CTE prefix is returned, it must be prepended to every SQL task emitted - * by the engine before execution. + * **Why side-specific prefixes matter:** T-SQL / Fabric parse-bind every CTE body + * at parse time, even unreferenced ones. Sending a combined `WITH __diff_source + * AS (... FROM mssql_only_table), __diff_target AS (... FROM fabric_only_table)` + * to MSSQL fails because MSSQL can't resolve the Fabric-only table referenced in + * the unused `__diff_target` CTE. + * + * Callers must prepend the appropriate prefix to every SQL task emitted by the + * engine before execution. 
*/ export function resolveTableSources( source: string, target: string, -): { table1Name: string; table2Name: string; ctePrefix: string | null } { + sourceDialect?: string, + targetDialect?: string, +): { + table1Name: string + table2Name: string + ctePrefix: string | null + sourceCtePrefix: string | null + targetCtePrefix: string | null +} { const source_is_query = isQuery(source) const target_is_query = isQuery(target) if (!source_is_query && !target_is_query) { // Both are plain table names — pass through unchanged - return { table1Name: source, table2Name: target, ctePrefix: null } + return { + table1Name: source, + table2Name: target, + ctePrefix: null, + sourceCtePrefix: null, + targetCtePrefix: null, + } } - // At least one is a query — wrap both in CTEs - // Quote identifier parts so table names with special chars don't inject SQL. - // Use double-quote escaping (ANSI SQL standard, works in Postgres/Snowflake/DuckDB/etc.) - const quoteIdent = (name: string) => - name - .split(".") - .map((p) => `"${p.replace(/"/g, '""')}"`) - .join(".") - const srcExpr = source_is_query ? source : `SELECT * FROM ${quoteIdent(source)}` - const tgtExpr = target_is_query ? target : `SELECT * FROM ${quoteIdent(target)}` + // At least one is a query — wrap both in CTEs. Quote plain-table names with + // the *side's own* dialect so T-SQL / Fabric get `[schema].[table]` and + // ANSI dialects get `"schema"."table"` — avoids `QUOTED_IDENTIFIER OFF` + // surprises on MSSQL/Fabric. Fallback to ANSI when dialect is unspecified. + const quoteTableRef = (name: string, dialect: string | undefined): string => { + const d = dialect ?? "generic" + return name.split(".").map((p) => quoteIdentForDialect(p, d)).join(".") + } + const srcExpr = source_is_query ? source : `SELECT * FROM ${quoteTableRef(source, sourceDialect)}` + const tgtExpr = target_is_query ? 
target : `SELECT * FROM ${quoteTableRef(target, targetDialect)}` + const sourceCtePrefix = `WITH __diff_source AS (\n${srcExpr}\n)` + const targetCtePrefix = `WITH __diff_target AS (\n${tgtExpr}\n)` const ctePrefix = `WITH __diff_source AS (\n${srcExpr}\n), __diff_target AS (\n${tgtExpr}\n)` return { table1Name: "__diff_source", table2Name: "__diff_target", ctePrefix, + sourceCtePrefix, + targetCtePrefix, } } @@ -403,24 +480,6 @@ const MAX_STEPS = 200 // Partition support // --------------------------------------------------------------------------- -/** - * Quote a SQL identifier using the correct delimiter for the dialect. - */ -function quoteIdentForDialect(identifier: string, dialect: string): string { - switch (dialect) { - case "mysql": - case "mariadb": - case "clickhouse": - return `\`${identifier.replace(/`/g, "``")}\`` - case "tsql": - case "fabric": - return `[${identifier.replace(/\]/g, "]]")}]` - default: - // ANSI SQL: Postgres, Snowflake, BigQuery, DuckDB, Oracle, Redshift, etc. - return `"${identifier.replace(/"/g, '""')}"` - } -} - /** * Build a DATE_TRUNC expression appropriate for the warehouse dialect. */ @@ -449,6 +508,12 @@ function dateTruncExpr(granularity: string, column: string, dialect: string): st } return `TRUNC(${column}, '${oracleFmt[g] ?? g.toUpperCase()}')` } + case "sqlserver": + case "mssql": + case "tsql": + case "fabric": + // SQL Server 2022+ / Fabric: DATETRUNC expects unquoted datepart keyword + return `DATETRUNC(${g.toUpperCase()}, ${column})` default: // Postgres, Snowflake, Redshift, DuckDB, etc. return `DATE_TRUNC('${g}', ${column})` @@ -526,18 +591,39 @@ function buildPartitionWhereClause( // date mode const expr = dateTruncExpr(granularity!, quotedCol, dialect) + // Normalize the partition value to ISO yyyy-mm-dd. 
The mssql driver returns + // date columns as JS Date objects which get `String()`-coerced upstream, + // producing output like "Mon Jan 01 2024 00:00:00 GMT+0000 (UTC)" — + // T-SQL `CONVERT(DATE, …, 23)` and Postgres date literals both reject that + // format. Parsing once here keeps the downstream SQL dialect-safe. + const isoDate = (() => { + const trimmed = partitionValue.trim() + // Already looks like yyyy-mm-dd — preserve as-is so pre-formatted values + // (e.g. from Postgres, BigQuery, DATE_FORMAT MySQL output) flow through + // without surprising timezone shifts. + if (/^\d{4}-\d{2}-\d{2}(\s|T|$)/.test(trimmed)) return trimmed.slice(0, 10) + const d = new Date(trimmed) + return Number.isNaN(d.getTime()) ? trimmed : d.toISOString().slice(0, 10) + })() + const escaped = isoDate.replace(/'/g, "''") // Cast the literal appropriately per dialect switch (dialect) { case "bigquery": - return `${expr} = '${partitionValue}'` + return `${expr} = '${escaped}'` case "clickhouse": - return `${expr} = toDate('${partitionValue}')` + return `${expr} = toDate('${escaped}')` case "mysql": case "mariadb": - return `${expr} = '${partitionValue}'` + return `${expr} = '${escaped}'` + case "sqlserver": + case "mssql": + case "tsql": + case "fabric": + // Style 23 = ISO-8601 (yyyy-mm-dd), locale-safe + return `${expr} = CONVERT(DATE, '${escaped}', 23)` default: - return `${expr} = '${partitionValue}'` + return `${expr} = '${escaped}'` } } @@ -623,15 +709,20 @@ async function runPartitionedDiff(params: DataDiffParams): Promise { if (warehouse) { const cfg = Registry.getConfig(warehouse) - return cfg?.type ?? "generic" + return warehouseTypeToDialect(cfg?.type ?? "generic") } const warehouses = Registry.list().warehouses - return warehouses[0]?.type ?? "generic" + return warehouseTypeToDialect(warehouses[0]?.type ?? "generic") } const sourceDialect = resolveDialect(params.source_warehouse) const targetDialect = resolveDialect(params.target_warehouse ?? 
params.source_warehouse) - const { table1Name, table2Name } = resolveTableSources(params.source, params.target) + const { table1Name, table2Name } = resolveTableSources( + params.source, + params.target, + sourceDialect, + targetDialect, + ) // Discover partition values from BOTH source and target to catch target-only partitions. // Without this, rows that exist only in target partitions are silently missed. @@ -685,19 +776,46 @@ async function runPartitionedDiff(params: DataDiffParams): Promise + name.split(".").map((p) => quoteIdentForDialect(p, dialect)).join(".") + const sourceTableRef = quoteTableRefForDialect(params.source, sourceDialect) + const targetTableRef = quoteTableRefForDialect(params.target, targetDialect) + for (const pVal of partitionValues) { - const partWhere = buildPartitionWhereClause( + // Build per-side partition WHERE clauses. The dialects can differ + // (cross-warehouse diff) — the engine applies `where_clause` to both + // sides identically, so we can't use it to carry dialect-specific syntax. + // Bake each side's WHERE into its own subquery-wrapped SQL source instead. + const sourcePartWhere = buildPartitionWhereClause( params.partition_column!, pVal, params.partition_granularity, params.partition_bucket_size, sourceDialect, ) - const fullWhere = params.where_clause ? `(${params.where_clause}) AND (${partWhere})` : partWhere + const targetPartWhere = buildPartitionWhereClause( + params.partition_column!, + pVal, + params.partition_granularity, + params.partition_bucket_size, + targetDialect, + ) + + // Wrap each side's table as a SELECT subquery filtered to this partition. + // The recursive runDataDiff below will detect these as SQL queries and + // route them through the CTE-injection path, which is already side-aware. 
+ const sourceSql = `SELECT * FROM ${sourceTableRef} WHERE ${sourcePartWhere}` + const targetSql = `SELECT * FROM ${targetTableRef} WHERE ${targetPartWhere}` const result = await runDataDiff({ ...params, - where_clause: fullWhere, + source: sourceSql, + target: targetSql, + // Preserve the user's shared where_clause — it's dialect-neutral. + where_clause: params.where_clause, partition_column: undefined, // prevent recursion }) @@ -745,36 +863,71 @@ export async function runDataDiff(params: DataDiffParams): Promise { - const parts = name.split(".") - if (parts.length === 3) return { database: parts[0], schema: parts[1], table: parts[2] } - if (parts.length === 2) return { schema: parts[0], table: parts[1] } - return { table: name } + // Resolve warehouse identity (fall back to the default warehouse when the + // caller omits a side). Returns the canonical warehouse name so we can detect + // cross-warehouse mode even when both sides share a dialect (e.g. two + // independent MSSQL instances). + const resolveWarehouseName = (warehouse: string | undefined): string | undefined => { + if (warehouse) return warehouse + const warehouses = Registry.list().warehouses + return warehouses[0]?.name } - const table1Ref = parseQualified(table1Name) - const table2Ref = parseQualified(table2Name) - // Resolve dialect from warehouse config const resolveDialect = (warehouse: string | undefined): string => { if (warehouse) { const cfg = Registry.getConfig(warehouse) - return cfg?.type ?? "generic" + return warehouseTypeToDialect(cfg?.type ?? "generic") } const warehouses = Registry.list().warehouses - return warehouses[0]?.type ?? "generic" + return warehouseTypeToDialect(warehouses[0]?.type ?? "generic") } + const resolvedSource = resolveWarehouseName(params.source_warehouse) + const resolvedTarget = resolveWarehouseName(params.target_warehouse ?? 
params.source_warehouse) + const dialect1 = resolveDialect(params.source_warehouse) const dialect2 = resolveDialect(params.target_warehouse ?? params.source_warehouse) + // Cross-warehouse mode requires side-specific CTE injection: T-SQL / Fabric + // parse-bind every CTE body even when unreferenced, so sending the combined + // prefix to a warehouse that lacks the other side's base table fails at parse. + // Gate on resolved warehouse identity, not dialect — two independent + // same-dialect warehouses (e.g. two MSSQL instances) still can't resolve each + // other's base tables. Identity comparison after resolving the default + // warehouse avoids misclassifying `source_warehouse=undefined` vs the + // explicit default-warehouse name as different. + const crossWarehouse = resolvedSource !== resolvedTarget + + // Explicit JoinDiff cannot work across warehouses: it emits one FULL OUTER + // JOIN task referencing both CTE aliases, but side-aware injection only + // defines one side per task — the other alias would be unresolved. Guard + // early so users get a clear error instead of an obscure SQL parse failure. + if (params.algorithm === "joindiff" && crossWarehouse) { + return { + success: false, + steps: 0, + error: + "joindiff requires both tables in the same warehouse; use hashdiff or auto for cross-warehouse comparisons.", + } + } + + // Resolve sources (plain table names vs arbitrary queries). Pass dialects so + // plain-table names inside wrapped CTEs get side-native bracket/quote style. 
+ const { table1Name, table2Name, ctePrefix, sourceCtePrefix, targetCtePrefix } = + resolveTableSources(params.source, params.target, dialect1, dialect2) + + // Parse optional qualified names: "db.schema.table" → { database, schema, table } + const parseQualified = (name: string) => { + const parts = name.split(".") + if (parts.length === 3) return { database: parts[0], schema: parts[1], table: parts[2] } + if (parts.length === 2) return { schema: parts[0], table: parts[1] } + return { table: name } + } + + const table1Ref = parseQualified(table1Name) + const table2Ref = parseQualified(table2Name) + // Auto-discover extra_columns when not explicitly provided. // The Rust engine only compares columns listed in extra_columns — if the list is // empty, it compares key existence only and reports all matched rows as "identical" @@ -873,8 +1026,24 @@ export async function runDataDiff(params: DataDiffParams): Promise { const warehouse = warehouseFor(task.table_side) - // Inject CTE definitions if we're in query-comparison mode - const sql = ctePrefix ? injectCte(task.sql, ctePrefix) : task.sql + // Inject CTE definitions if we're in query-comparison mode. In + // cross-warehouse mode each task only gets the CTE for its own side — + // the other side's base tables aren't bindable on this warehouse. + let prefix: string | null = null + if (ctePrefix) { + if (crossWarehouse) { + prefix = task.table_side === "Table2" ? targetCtePrefix : sourceCtePrefix + } else { + if (task.table_side === "Table1") { + prefix = sourceCtePrefix + } else if (task.table_side === "Table2") { + prefix = targetCtePrefix + } else { + prefix = ctePrefix + } + } + } + const sql = prefix ? 
injectCte(task.sql, prefix) : task.sql try { const rows = await executeQuery(sql, warehouse) return { id: task.id, rows, error: null } diff --git a/packages/opencode/src/altimate/native/connections/registry.ts b/packages/opencode/src/altimate/native/connections/registry.ts index 617d6685d..cc871682c 100644 --- a/packages/opencode/src/altimate/native/connections/registry.ts +++ b/packages/opencode/src/altimate/native/connections/registry.ts @@ -122,6 +122,7 @@ const DRIVER_MAP: Record = { mariadb: "@altimateai/drivers/mysql", sqlserver: "@altimateai/drivers/sqlserver", mssql: "@altimateai/drivers/sqlserver", + fabric: "@altimateai/drivers/sqlserver", databricks: "@altimateai/drivers/databricks", duckdb: "@altimateai/drivers/duckdb", oracle: "@altimateai/drivers/oracle", @@ -165,6 +166,7 @@ async function createConnector(name: string, config: ConnectionConfig): Promise< "mariadb", "sqlserver", "mssql", + "fabric", "oracle", "snowflake", "clickhouse", diff --git a/packages/opencode/src/altimate/tools/data-diff.ts b/packages/opencode/src/altimate/tools/data-diff.ts index bf9948748..163cbb8bf 100644 --- a/packages/opencode/src/altimate/tools/data-diff.ts +++ b/packages/opencode/src/altimate/tools/data-diff.ts @@ -203,7 +203,11 @@ function formatOutcome(outcome: any, source: string, target: string): string { lines.push(` Sample differences (first ${Math.min(diffRows.length, 5)}):`) for (const d of diffRows.slice(0, 5)) { const label = d.sign === "-" ? "source only" : "target only" - lines.push(` [${label}] ${d.values?.join(" | ")}`) + // `d.values?.join(" | ") ?? "(no values)"` misses the common case where + // `values` is an empty array — `[].join(" | ")` returns "" (not null), + // so the coalesce never triggers. Gate on length explicitly. + const values = d.values?.length ? 
d.values.join(" | ") : "(no values)" + lines.push(` [${label}] ${values}`) } } diff --git a/packages/opencode/test/altimate/connections.test.ts b/packages/opencode/test/altimate/connections.test.ts index f741a8cf1..5c9680297 100644 --- a/packages/opencode/test/altimate/connections.test.ts +++ b/packages/opencode/test/altimate/connections.test.ts @@ -81,6 +81,23 @@ describe("ConnectionRegistry", () => { await expect(Registry.get("mydb")).rejects.toThrow("Supported:") }) + test("fabric type is recognized in DRIVER_MAP and routes to sqlserver driver", () => { + Registry.setConfigs({ + fabricdb: { + type: "fabric", + host: "myserver.datawarehouse.fabric.microsoft.com", + database: "migration", + authentication: "default", + }, + }) + const config = Registry.getConfig("fabricdb") + expect(config).toBeDefined() + expect(config?.type).toBe("fabric") + const result = Registry.list() + expect(result.warehouses).toHaveLength(1) + expect(result.warehouses[0].type).toBe("fabric") + }) + test("getConfig returns config for known connection", () => { Registry.setConfigs({ mydb: { type: "postgres", host: "localhost" }, diff --git a/packages/opencode/test/altimate/data-diff-cross-dialect.test.ts b/packages/opencode/test/altimate/data-diff-cross-dialect.test.ts new file mode 100644 index 000000000..f00013f5f --- /dev/null +++ b/packages/opencode/test/altimate/data-diff-cross-dialect.test.ts @@ -0,0 +1,168 @@ +/** + * Tests for cross-dialect partitioned diff and joindiff cross-warehouse guard. + * + * These cover the two CRITICAL/MAJOR bugs fixed in the review follow-up: + * 1. Partitioned WHERE was built with sourceDialect only and applied to both + * warehouses; cross-dialect diffs blew up the target with foreign syntax. + * 2. Explicit `algorithm: "joindiff"` with different warehouses silently + * produced SQL referencing an undefined CTE alias. + * + * Both fixes live purely in the TS orchestrator (`runDataDiff` / + * `runPartitionedDiff`). 
The Rust engine is mocked so these tests run without + * the NAPI binary. + */ +import { describe, test, expect, mock, beforeEach } from "bun:test" + +// --- Mock NAPI so tests don't require the native binary --- + +let lastSpec: any = null +const fakeStartAction = JSON.stringify({ + type: "ExecuteSql", + tasks: [ + { id: "fp1_1", table_side: "Table1", sql: "SELECT COUNT(*) FROM [__diff_source]", expected_shape: "SingleRow" }, + { id: "fp2_2", table_side: "Table2", sql: "SELECT COUNT(*) FROM [__diff_target]", expected_shape: "SingleRow" }, + ], +}) + +mock.module("@altimateai/altimate-core", () => ({ + DataParitySession: class { + constructor(specJson: string) { lastSpec = JSON.parse(specJson) } + start() { return fakeStartAction } + step(_responses: string) { + return JSON.stringify({ + type: "Done", + outcome: { + mode: "diff", + diff_rows: [], + stats: { rows_table1: 0, rows_table2: 0, exclusive_table1: 0, exclusive_table2: 0, updated: 0, unchanged: 0 }, + }, + }) + } + }, +})) + +// --- Mock the Registry module itself so tests can inject fake connectors. +// The real Registry's `get` creates connectors via dynamic driver import; we +// replace the whole surface here with configurable in-memory state. 
--- + +type Rows = (string | null)[][] +const sqlLog: Array<{ warehouse: string; sql: string }> = [] +const fakeConfigs = new Map() + +function makeFakeConnector(warehouseName: string, discoveryRows: Rows = [["2026-04-01"]]) { + return { + connect: async () => {}, + close: async () => {}, + execute: async (sql: string) => { + sqlLog.push({ warehouse: warehouseName, sql }) + if (sql.includes("SELECT DISTINCT")) { + return { columns: ["_p"], rows: discoveryRows, row_count: discoveryRows.length, truncated: false } + } + return { columns: ["c", "h"], rows: [["0", "0"]], row_count: 1, truncated: false } + }, + listSchemas: async () => [], + listTables: async () => [], + describeTable: async () => [], + } +} + +mock.module("../../src/altimate/native/connections/registry", () => ({ + list: () => ({ + warehouses: Array.from(fakeConfigs.entries()).map(([name, cfg]) => ({ name, type: cfg.type })), + }), + getConfig: (name: string) => fakeConfigs.get(name), + setConfigs: (configs: Record) => { + fakeConfigs.clear() + for (const [k, v] of Object.entries(configs)) fakeConfigs.set(k, v as any) + }, + get: async (name: string) => makeFakeConnector(name), + add: async () => ({ success: true, name: "x", type: "x" }), + remove: async () => ({ success: true, name: "x" }), + test: async () => ({ success: true, name: "x", status: "connected" }), +})) + +// Import after mocks are wired +const Registry = await import("../../src/altimate/native/connections/registry") +const { runDataDiff } = await import("../../src/altimate/native/connections/data-diff") + +beforeEach(() => { + sqlLog.length = 0 + lastSpec = null +}) + +describe("cross-warehouse joindiff guard", () => { + test("returns early error when joindiff + cross-warehouse", async () => { + Registry.setConfigs({ + src: { type: "sqlserver", host: "s1", database: "d" }, + tgt: { type: "postgres", host: "s2", database: "d" }, + }) + const result = await runDataDiff({ + source: "dbo.orders", + target: "public.orders", + key_columns: 
["id"], + source_warehouse: "src", + target_warehouse: "tgt", + algorithm: "joindiff", + }) + expect(result.success).toBe(false) + expect(result.error).toMatch(/joindiff requires both tables in the same warehouse/i) + expect(result.steps).toBe(0) + // Nothing should have been sent to the warehouses + expect(sqlLog.length).toBe(0) + }) + + test("same-warehouse joindiff is allowed", async () => { + Registry.setConfigs({ + shared: { type: "sqlserver", host: "s", database: "d" }, + }) + const result = await runDataDiff({ + source: "dbo.orders", + target: "dbo.orders_v2", + key_columns: ["id"], + source_warehouse: "shared", + target_warehouse: "shared", + algorithm: "joindiff", + }) + expect(result.success).toBe(true) + }) +}) + +describe("cross-dialect partitioned diff", () => { + test("source and target receive their own dialect's partition WHERE", async () => { + Registry.setConfigs({ + msrc: { type: "sqlserver", host: "mssql-host", database: "src" }, + ptgt: { type: "postgres", host: "pg-host", database: "tgt" }, + }) + const result = await runDataDiff({ + source: "dbo.orders", + target: "public.orders", + key_columns: ["id"], + source_warehouse: "msrc", + target_warehouse: "ptgt", + partition_column: "order_date", + partition_granularity: "month", + algorithm: "hashdiff", + }) + expect(result.success).toBe(true) + + // Gather SQL by warehouse + const msrcSql = sqlLog.filter((x) => x.warehouse === "msrc").map((x) => x.sql).join("\n") + const ptgtSql = sqlLog.filter((x) => x.warehouse === "ptgt").map((x) => x.sql).join("\n") + + // Source (MSSQL) must see T-SQL syntax: DATETRUNC + CONVERT(DATE, ..., 23) + [brackets] + expect(msrcSql).toMatch(/DATETRUNC\(MONTH,\s*\[order_date\]\)/i) + expect(msrcSql).toMatch(/CONVERT\(DATE, '2026-04-01', 23\)/i) + // Source must NOT see Postgres syntax + expect(msrcSql).not.toMatch(/DATE_TRUNC\('month'/i) + // Source must never see the Postgres table reference + expect(msrcSql).not.toContain('"public"."orders"') + + // Target 
(Postgres) must see DATE_TRUNC + ANSI-quoted identifiers + expect(ptgtSql).toMatch(/DATE_TRUNC\('month',\s*"order_date"\)/i) + // Target must NOT see T-SQL syntax + expect(ptgtSql).not.toMatch(/DATETRUNC/i) + expect(ptgtSql).not.toMatch(/CONVERT\(DATE/i) + // Target must never see the MSSQL bracketed reference + expect(ptgtSql).not.toContain("[dbo].[orders]") + }) +}) diff --git a/packages/opencode/test/altimate/data-diff-cte.test.ts b/packages/opencode/test/altimate/data-diff-cte.test.ts new file mode 100644 index 000000000..aea08f27c --- /dev/null +++ b/packages/opencode/test/altimate/data-diff-cte.test.ts @@ -0,0 +1,161 @@ +/** + * Tests for CTE wrapping and injection in SQL-query mode. + * + * The tricky case is cross-warehouse comparison where source and target are both + * SQL queries referencing tables that only exist on their own side. The combined + * CTE prefix cannot be sent to both warehouses because T-SQL / Fabric parse-bind + * every CTE body even when unreferenced — the "other side" CTE would fail to + * resolve its base table. 
+ */ +import { describe, test, expect } from "bun:test" + +import { resolveTableSources, injectCte } from "../../src/altimate/native/connections/data-diff" + +describe("resolveTableSources", () => { + test("plain table names pass through without wrapping", () => { + const r = resolveTableSources("orders", "orders_v2") + expect(r.table1Name).toBe("orders") + expect(r.table2Name).toBe("orders_v2") + expect(r.ctePrefix).toBeNull() + expect(r.sourceCtePrefix).toBeNull() + expect(r.targetCtePrefix).toBeNull() + }) + + test("schema-qualified plain names pass through", () => { + const r = resolveTableSources("gold.dim_customer", "TRANSFORMED.DimCustomer") + expect(r.table1Name).toBe("gold.dim_customer") + expect(r.table2Name).toBe("TRANSFORMED.DimCustomer") + expect(r.ctePrefix).toBeNull() + }) + + test("both queries are wrapped in CTEs with aliases", () => { + const r = resolveTableSources( + "SELECT id, val FROM [TRANSFORMED].[DimCustomer]", + "SELECT id, val FROM [gold].[dim_customer]", + ) + expect(r.table1Name).toBe("__diff_source") + expect(r.table2Name).toBe("__diff_target") + expect(r.ctePrefix).toContain("__diff_source AS (") + expect(r.ctePrefix).toContain("__diff_target AS (") + expect(r.ctePrefix).toContain("[TRANSFORMED].[DimCustomer]") + expect(r.ctePrefix).toContain("[gold].[dim_customer]") + }) + + test("side-specific prefixes contain only the relevant CTE", () => { + const r = resolveTableSources( + "SELECT id FROM [TRANSFORMED].[DimCustomer]", + "SELECT id FROM [gold].[dim_customer]", + ) + // Source prefix has source table only — must not leak target table ref + expect(r.sourceCtePrefix).toContain("__diff_source AS (") + expect(r.sourceCtePrefix).toContain("[TRANSFORMED].[DimCustomer]") + expect(r.sourceCtePrefix).not.toContain("__diff_target") + expect(r.sourceCtePrefix).not.toContain("[gold].[dim_customer]") + + // Target prefix has target table only — must not leak source table ref + expect(r.targetCtePrefix).toContain("__diff_target AS (") + 
expect(r.targetCtePrefix).toContain("[gold].[dim_customer]") + expect(r.targetCtePrefix).not.toContain("__diff_source") + expect(r.targetCtePrefix).not.toContain("[TRANSFORMED].[DimCustomer]") + }) + + test("mixed: plain source + query target still wraps both sides", () => { + const r = resolveTableSources( + "orders", + "SELECT * FROM other.orders WHERE region = 'EU'", + ) + expect(r.table1Name).toBe("__diff_source") + expect(r.table2Name).toBe("__diff_target") + // Plain table wrapped with ANSI double-quoted identifiers + expect(r.sourceCtePrefix).toContain('SELECT * FROM "orders"') + expect(r.targetCtePrefix).toContain("other.orders") + }) + + test("dialect-aware quoting: tsql uses square brackets", () => { + // Fix #4: plain table names wrapped inside CTEs must use the side's + // native quoting. `"schema"."table"` fails on MSSQL with QUOTED_IDENTIFIER OFF. + const r = resolveTableSources( + "dbo.orders", + "SELECT * FROM base", + "tsql", + "postgres", + ) + expect(r.sourceCtePrefix).toContain("[dbo].[orders]") + expect(r.sourceCtePrefix).not.toContain('"dbo"."orders"') + }) + + test("dialect-aware quoting: fabric uses square brackets; mysql uses backticks", () => { + // Pair the plain-table side with a SQL-query counterpart to force CTE wrapping. 
+ const fabric = resolveTableSources( + "gold.dim_customer", + "SELECT * FROM other", + "fabric", + "fabric", + ) + expect(fabric.sourceCtePrefix).toContain("[gold].[dim_customer]") + + const mysql = resolveTableSources( + "SELECT 1 AS id", + "db.orders", + "mysql", + "mysql", + ) + expect(mysql.targetCtePrefix).toContain("`db`.`orders`") + }) + + test("query detection requires both keyword AND whitespace", () => { + // A table literally named "select" should NOT be treated as a query + const r = resolveTableSources("select", "with") + expect(r.table1Name).toBe("select") + expect(r.table2Name).toBe("with") + expect(r.ctePrefix).toBeNull() + }) +}) + +describe("injectCte", () => { + test("prepends CTE prefix to a plain SELECT", () => { + const prefix = "WITH __diff_source AS (\nSELECT 1 AS id\n)" + const sql = "SELECT COUNT(*) FROM __diff_source" + const out = injectCte(sql, prefix) + expect(out.startsWith(prefix)).toBe(true) + expect(out).toContain("SELECT COUNT(*) FROM __diff_source") + }) + + test("merges with an engine-emitted WITH clause", () => { + const prefix = "WITH __diff_source AS (\nSELECT * FROM base\n)" + const engineSql = "WITH engine_cte AS (SELECT id FROM __diff_source) SELECT * FROM engine_cte" + const out = injectCte(engineSql, prefix) + // Must start with a single WITH, with our CTE first, then engine's + expect(out.match(/^WITH /)).not.toBeNull() + expect((out.match(/\bWITH\b/g) ?? []).length).toBe(1) + expect(out.indexOf("__diff_source AS")).toBeLessThan(out.indexOf("engine_cte AS")) + }) + + test("side-specific injection: source prefix does not leak target refs", () => { + // Simulates cross-warehouse fp1_1 task going to MSSQL. It must not see any + // reference to the Fabric-only target table, since MSSQL parse-binds every + // CTE body. 
+ const r = resolveTableSources( + "SELECT id FROM [TRANSFORMED].[DimCustomer]", + "SELECT id FROM [gold].[dim_customer]", + ) + const engineFp1Sql = + "SELECT COUNT(*), SUM(CAST(...HASHBYTES('MD5', CONCAT(CAST([id] AS NVARCHAR(MAX))))...)) FROM [__diff_source]" + const sqlForMssql = injectCte(engineFp1Sql, r.sourceCtePrefix!) + expect(sqlForMssql).toContain("[TRANSFORMED].[DimCustomer]") + expect(sqlForMssql).not.toContain("[gold].[dim_customer]") + expect(sqlForMssql).not.toContain("__diff_target") + }) + + test("side-specific injection: target prefix does not leak source refs", () => { + const r = resolveTableSources( + "SELECT id FROM [TRANSFORMED].[DimCustomer]", + "SELECT id FROM [gold].[dim_customer]", + ) + const engineFp2Sql = "SELECT COUNT(*) FROM [__diff_target]" + const sqlForFabric = injectCte(engineFp2Sql, r.targetCtePrefix!) + expect(sqlForFabric).toContain("[gold].[dim_customer]") + expect(sqlForFabric).not.toContain("[TRANSFORMED].[DimCustomer]") + expect(sqlForFabric).not.toContain("__diff_source") + }) +}) diff --git a/packages/opencode/test/altimate/data-diff-dialect.test.ts b/packages/opencode/test/altimate/data-diff-dialect.test.ts new file mode 100644 index 000000000..083c64d57 --- /dev/null +++ b/packages/opencode/test/altimate/data-diff-dialect.test.ts @@ -0,0 +1,55 @@ +/** + * Tests for warehouse-type-to-dialect mapping in the data-diff orchestrator. + * + * The Rust engine's SqlDialect serde deserialization only accepts exact lowercase + * variant names (e.g., "tsql", not "sqlserver"). This mapping bridges the gap + * between warehouse config types and Rust dialect names. 
+ */ +import { describe, test, expect } from "bun:test" + +import { warehouseTypeToDialect } from "../../src/altimate/native/connections/data-diff" + +describe("warehouseTypeToDialect", () => { + // --- Remapped types --- + + test("maps sqlserver to tsql", () => { + expect(warehouseTypeToDialect("sqlserver")).toBe("tsql") + }) + + test("maps mssql to tsql", () => { + expect(warehouseTypeToDialect("mssql")).toBe("tsql") + }) + + test("maps fabric to fabric", () => { + expect(warehouseTypeToDialect("fabric")).toBe("fabric") + }) + + test("maps postgresql to postgres", () => { + expect(warehouseTypeToDialect("postgresql")).toBe("postgres") + }) + + test("maps mariadb to mysql", () => { + expect(warehouseTypeToDialect("mariadb")).toBe("mysql") + }) + + // --- Passthrough types (already match Rust names) --- + + test("passes through postgres unchanged", () => { + expect(warehouseTypeToDialect("postgres")).toBe("postgres") + }) + + test("passes through snowflake unchanged", () => { + expect(warehouseTypeToDialect("snowflake")).toBe("snowflake") + }) + + test("passes through generic unchanged", () => { + expect(warehouseTypeToDialect("generic")).toBe("generic") + }) + + // --- Case insensitivity --- + + test("handles uppercase input", () => { + expect(warehouseTypeToDialect("SQLSERVER")).toBe("tsql") + expect(warehouseTypeToDialect("PostgreSQL")).toBe("postgres") + }) +}) diff --git a/packages/opencode/test/altimate/driver-normalize.test.ts b/packages/opencode/test/altimate/driver-normalize.test.ts index 95f348289..43b31c4e8 100644 --- a/packages/opencode/test/altimate/driver-normalize.test.ts +++ b/packages/opencode/test/altimate/driver-normalize.test.ts @@ -463,6 +463,19 @@ describe("normalizeConfig — SQL Server", () => { expect(result.host).toBe("myserver") expect(result.user).toBe("sa") }) + + test("fabric type uses SQLSERVER_ALIASES", () => { + const result = normalizeConfig({ + type: "fabric", + server: "myserver.datawarehouse.fabric.microsoft.com", + 
trustServerCertificate: false, + authentication: "default", + }) + expect(result.host).toBe("myserver.datawarehouse.fabric.microsoft.com") + expect(result.server).toBeUndefined() + expect(result.trust_server_certificate).toBe(false) + expect(result.trustServerCertificate).toBeUndefined() + }) }) // ---------------------------------------------------------------------------