diff --git a/src/queryKinds/select.ts b/src/queryKinds/select.ts index fcdb098..c6ff83f 100644 --- a/src/queryKinds/select.ts +++ b/src/queryKinds/select.ts @@ -183,11 +183,9 @@ export default class SelectQuery extends QueryDefinition { */ public select(fields: string | string[]): this { if (Array.isArray(fields)) { - this.selectFields = - SqlEscaper.escapeSelectIdentifiers(fields, this.flavor); + this.rawSelect(SqlEscaper.escapeSelectIdentifiers(fields, this.flavor)); } else { - this.selectFields = - SqlEscaper.escapeSelectIdentifiers([fields], this.flavor); + this.rawSelect(SqlEscaper.escapeSelectIdentifiers([fields], this.flavor)); } return this; } @@ -230,13 +228,9 @@ export default class SelectQuery extends QueryDefinition { */ public addSelect(fields: string | string[]): this { if (Array.isArray(fields)) { - const escaped = - SqlEscaper.escapeSelectIdentifiers(fields, this.flavor); - this.selectFields.push(...escaped); + this.addRawSelect(SqlEscaper.escapeSelectIdentifiers(fields, this.flavor)); } else { - const escaped = - SqlEscaper.escapeSelectIdentifiers([fields], this.flavor); - this.selectFields.push(...escaped); + this.addRawSelect(SqlEscaper.escapeSelectIdentifiers([fields], this.flavor)); } return this; } diff --git a/src/queryKinds/union.test.ts b/src/queryKinds/union.test.ts index 32cc01e..34f6db4 100644 --- a/src/queryKinds/union.test.ts +++ b/src/queryKinds/union.test.ts @@ -385,4 +385,81 @@ describe("Union Query", () => { expect(rebuiltOriginal).toEqual(built); }); + it('should support custom select clause', () => { + const select1 = Query.select + .from("table1") + .select(["column1", "column2"]) + .where("column1 = ?", "value1"); + + const select2 = Query.select + .from("table2") + .select(["column1", "column2"]) + .where("column2 = ?", "value2"); + + const unionQuery = new Union() + .addMany([ + { + query: select1, + type: 'union' + }, + { + query: select2, + type: 'union all' + } + ]) + .as('union_table') + .select('column1') + .addSelect('column2') + .addRawSelect('COUNT(column2) AS count_column2') + .groupBy('column1') + .build(); + + expect(unionQuery.text).toBe('SELECT\n "column1",\n "column2",\n COUNT(column2) AS count_column2\n FROM (\n (SELECT\n "column1",\n "column2"\n FROM "table1"\n WHERE (column1 = $1))\n\n UNION ALL\n\n (SELECT\n "column1",\n "column2"\n FROM "table2"\n WHERE (column2 = $2))\n) AS union_table\nGROUP BY "column1"'); + expect(unionQuery.values).toEqual(['value1', 'value2']); + }); + + it('should support raw select clause', () => { + const select1 = Query.select + .from("table1") + .select(["column1", "column2"]) + .where("column1 = ?", "value1"); + + const select2 = Query.select + .from("table2") + .select(["column1", "column2"]) + .where("column2 = ?", "value2"); + + const unionQuery = new Union() + .addMany([ + { + query: select1, + type: 'union' + }, + { + query: select2, + type: 'union all' + } + ]) + .as('union_table') + .rawSelect('column1') + .addSelect([ + 'column2', + ]) + .rawSelect([ + 'column1', + 'column2', + 'COUNT(column2) AS count_column2' + ]) + .addRawSelect('COUNT(column2) AS count_column2') + .select([ + 'column1', + 'column2' + ]) + .groupBy('column1') + .build(); + + expect(unionQuery.text).toBe('SELECT\n "column1",\n "column2"\n FROM (\n (SELECT\n "column1",\n "column2"\n FROM "table1"\n WHERE (column1 = $1))\n\n UNION ALL\n\n (SELECT\n "column1",\n "column2"\n FROM "table2"\n WHERE (column2 = $2))\n) AS union_table\nGROUP BY "column1"'); + expect(unionQuery.values).toEqual(['value1', 'value2']); + }); + }); diff --git a/src/queryKinds/union.ts b/src/queryKinds/union.ts index 9bcb85e..a10cb3f 100644 --- a/src/queryKinds/union.ts +++ b/src/queryKinds/union.ts @@ -3,6 +3,7 @@ import QueryKind from "../types/QueryKind.js"; import OrderBy from "../types/OrderBy.js"; import QueryDefinition from "./query.js"; import SelectQuery from "./select.js"; +import SqlEscaper from "../sqlEscaper.js"; /** Allowed types for UnionType */ export const UnionTypes = { @@ -42,6 +43,9 @@ export type SelectQueryWithUnionType = { * and can optionally assign an alias to the resulting union query. */ export default class Union extends QueryDefinition { + + private selectFields: string[] = []; + /** Needed alias for the union query */ private unionAlias: string | null = null; @@ -98,6 +102,67 @@ export default class Union extends QueryDefinition { }; } + /** + * Specifies the fields to select in the union query. + * If not called, defaults to selecting all fields ('*'). + * @param fields A single field name or an array of field names to select. + * @returns The current Union instance for method chaining. + */ + public select(fields: string | string[]): Union { + if (Array.isArray(fields)) { + this.rawSelect(SqlEscaper.escapeSelectIdentifiers(fields, this.flavor)); + } else { + this.rawSelect(SqlEscaper.escapeSelectIdentifiers([fields], this.flavor)); + } + return this; + } + + /** + * Adds fields to the existing selection in the union query. + * If no fields have been selected yet, this behaves like the select() method. + * @param fields A single field name or an array of field names to add to the selection. + * @returns The current Union instance for method chaining. + */ + public addSelect(fields: string | string[]): Union { + if (Array.isArray(fields)) { + this.addRawSelect(SqlEscaper.escapeSelectIdentifiers(fields, this.flavor)); + } else { + this.addRawSelect(SqlEscaper.escapeSelectIdentifiers([fields], this.flavor)); + } + return this; + } + + /** + * Specifies raw fields to select in the union query without any escaping. + * Use this method with caution, as it does not perform any SQL injection protection. + * @param fields A single raw field string or an array of raw field strings to select. + * @returns The current Union instance for method chaining. + */ + public rawSelect(fields: string | string[]): Union { + if (Array.isArray(fields)) { + this.selectFields = fields; + } else { + this.selectFields = [fields]; + } + return this; + } + + /** + * Adds raw fields to the existing selection in the union query without any escaping. + * Use this method with caution, as it does not perform any SQL injection protection. + * If no fields have been selected yet, this behaves like the rawSelect() method. + * @param fields A single raw field string or an array of raw field strings to add to the selection. + * @returns The current Union instance for method chaining. + */ + public addRawSelect(fields: string | string[]): Union { + if (Array.isArray(fields)) { + this.selectFields.push(...fields); + } else { + this.selectFields.push(fields); + } + return this; + } + /** * Assigns an alias to the resulting union query. * @param alias The alias to assign to the union query. @@ -317,6 +382,13 @@ export default class Union extends QueryDefinition { let unionItself: string = ''; const values: any[] = []; + let selectClause = ''; + if (this.selectFields.length > 0) { + selectClause = this.selectFields.join(',\n '); + } else { + selectClause = '*'; + } + // Add offset on each select query to ensure correct parameter indexing let paramOffset = 1; for (const { query, type: unionType } of this.selectQueries) { @@ -375,8 +447,11 @@ export default class Union extends QueryDefinition { offsetClause = `OFFSET ${this.offsetCount}`; } + const firstLine = + `SELECT${selectClause.length > 1 ? '\n' : ''} ${selectClause}${selectClause.length > 1 ? '\n' : ''} FROM (`; + const union = [ - 'SELECT * FROM (', + firstLine, `${unionItself}\n) AS ${this.unionAlias || 'union_subquery'}`, whereClause, groupByClause, @@ -436,6 +511,8 @@ export default class Union extends QueryDefinition { */ public clone(): Union { const newUnion = new Union(); + newUnion.selectFields = [...this.selectFields]; + newUnion.flavor = this.flavor; newUnion.unionAlias = this.unionAlias; newUnion.limitCount = this.limitCount; newUnion.offsetCount = this.offsetCount; @@ -459,6 +536,7 @@ export default class Union extends QueryDefinition { * After calling this method, the Union instance will be in its initial state. */ public reset(): void { + this.selectFields = []; this.unionAlias = null; this.limitCount = null; this.offsetCount = null;