Skip to content

Commit

Permalink
Add support for mod operation on PK (#1601) / all combinations with E…
Browse files Browse the repository at this point in the history
…ntityID Expressions covered
  • Loading branch information
Tapac committed Nov 14, 2022
1 parent b1b6cf7 commit d5d817f
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 57 deletions.
57 changes: 39 additions & 18 deletions exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Op.kt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package org.jetbrains.exposed.sql

import org.jetbrains.exposed.dao.id.EntityID
import org.jetbrains.exposed.sql.SqlExpressionBuilder.wrap
import org.jetbrains.exposed.sql.vendors.*
import java.math.BigDecimal

Expand Down Expand Up @@ -314,31 +315,51 @@ class DivideOp<T, S : T>(
/**
* Represents an SQL operator that calculates the remainder of dividing [expr1] by [expr2].
*/
class ModOp<T : Number?, S : Number?>(
class ModOp<T : Number?, S : Number?, R : Number?>(
/** Returns the left-hand side operand. */
val expr1: Expression<T>,
/** Returns the right-hand side operand. */
val expr2: Expression<S>,
override val columnType: IColumnType
) : ExpressionWithColumnType<T>() {
override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit = dbModOp(queryBuilder, expr1, expr2)
}
) : ExpressionWithColumnType<R>() {

class ModOpEntityID<T, S : Number?, K : EntityID<T>>(
/** Returns the left-hand side operand. */
val expr1: Expression<K>,
/** Returns the right-hand side operand. */
val expr2: Expression<S>,
override val columnType: IColumnType
) : ExpressionWithColumnType<K>() where T : Comparable<T>, T : Number? {
override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit = dbModOp(queryBuilder, expr1, expr2)
}
override fun toQueryBuilder(queryBuilder: QueryBuilder) {
queryBuilder {
when (currentDialectIfAvailable) {
is OracleDialect -> append("MOD(", expr1, ", ", expr2, ")")
else -> append('(', expr1, " % ", expr2, ')')
}
}
}

companion object {
@Suppress("UNCHECKED_CAST")
private fun <T : Number?, K : EntityID<T>?> originalColumn(expr1: ExpressionWithColumnType<K>): Column<T> {
return (expr1.columnType as EntityIDColumnType<*>).idColumn as Column<T>
}

internal operator fun <T, S : Number, K : EntityID<T>?> invoke(
expr1: ExpressionWithColumnType<K>,
expr2: Expression<S>
): ExpressionWithColumnType<T> where T : Number, T : Comparable<T> {
val column = originalColumn(expr1)
return ModOp(column, expr2, column.columnType)
}

internal operator fun <T, S : Number, K : EntityID<T>?> invoke(
expr1: Expression<S>,
expr2: ExpressionWithColumnType<K>
): ExpressionWithColumnType<T> where T : Number, T : Comparable<T> {
val column = originalColumn(expr2)
return ModOp(expr1, column, column.columnType)
}

private fun dbModOp(queryBuilder: QueryBuilder, expr1: Expression<*>, expr2: Expression<*>) {
queryBuilder {
when (currentDialectIfAvailable) {
is OracleDialect -> append("MOD(", expr1, ", ", expr2, ")")
else -> append('(', expr1, " % ", expr2, ')')
internal operator fun <T, S : Number, K : EntityID<T>?> invoke(
expr1: ExpressionWithColumnType<K>,
expr2: S
): ExpressionWithColumnType<T> where T : Number, T : Comparable<T> {
val column = originalColumn(expr1)
return ModOp(column, column.wrap(expr2), column.columnType)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -313,31 +313,44 @@ interface ISqlExpressionBuilder {
infix operator fun <T, S : T> ExpressionWithColumnType<T>.div(other: Expression<S>): DivideOp<T, S> = DivideOp(this, other, columnType)

/** Calculates the remainder of dividing this expression by the [t] value. */
infix operator fun <T : Number?, S : T> ExpressionWithColumnType<T>.rem(t: S): ModOp<T, S> = ModOp(this, wrap(t), columnType)
infix operator fun <T : Number?, S : T> ExpressionWithColumnType<T>.rem(t: S) = ModOp<T, S, T>(this, wrap(t), columnType)

/** Calculates the remainder of dividing this expression by the [other] expression. */
infix operator fun <T : Number?, S : Number> ExpressionWithColumnType<T>.rem(other: Expression<S>): ModOp<T, S> = ModOp(this, other, columnType)
infix operator fun <T : Number?, S : Number> ExpressionWithColumnType<T>.rem(other: Expression<S>) = ModOp<T, S, T>(this, other, columnType)

/**
* Calculates the remainder of dividing the value of a numeric PK by the [other] number.
*/
infix operator fun <T, S : Number> ExpressionWithColumnType<EntityID<T>>.rem(other: S): ModOpEntityID<T, S, EntityID<T>>
where T : Number?, T : Comparable<T> =
ModOpEntityID(this, wrap(other), this.columnType)
/** Calculates the remainder of dividing the value of [this] numeric PK by the [other] number. */
@JvmName("remWithEntityId")
infix operator fun <T, S : Number, ID : EntityID<T>?> ExpressionWithColumnType<ID>.rem(other: S) where T : Number, T : Comparable<T> =
ModOp(this, other)

/** Calculates the remainder of dividing [this] number expression by [other] numeric PK */
@JvmName("remWithEntityId2")
infix operator fun <T, S : Number, ID : EntityID<T>?> Expression<S>.rem(other: ExpressionWithColumnType<ID>) where T : Number, T : Comparable<T> =
ModOp(this, other)

/** Calculates the remainder of dividing the value of [this] numeric PK by the [other] number expression. */
@JvmName("remWithEntityId3")
infix operator fun <T, S : Number, ID : EntityID<T>?> ExpressionWithColumnType<ID>.rem(other: Expression<S>) where T : Number, T : Comparable<T> =
ModOp(this, other)

/** Calculates the remainder of dividing this expression by the [t] value. */
infix fun <T : Number?, S : T> ExpressionWithColumnType<T>.mod(t: S): ModOp<T, S> = this % t
infix fun <T : Number?, S : T> ExpressionWithColumnType<T>.mod(t: S) = this % t

/** Calculates the remainder of dividing this expression by the [other] expression. */
infix fun <T : Number?, S : Number> ExpressionWithColumnType<T>.mod(other: Expression<S>): ModOp<T, S> =
this % other
infix fun <T : Number?, S : Number> ExpressionWithColumnType<T>.mod(other: Expression<S>) = this % other

/**
* Calculates the remainder of dividing the value of a numeric PK by the [other] number.
*/
infix fun <T, S : Number> ExpressionWithColumnType<EntityID<T>>.mod(other: S): ModOpEntityID<T, S, EntityID<T>>
where T : Number?, T : Comparable<T> =
ModOpEntityID(this, wrap(other), this.columnType)
/** Calculates the remainder of dividing the value of [this] numeric PK by the [other] number. */
@JvmName("modWithEntityId")
infix fun <T, S : Number, ID : EntityID<T>?> ExpressionWithColumnType<ID>.mod(other: S) where T : Number, T : Comparable<T> = this % other

/** Calculates the remainder of dividing [this] number expression by [other] numeric PK */
@JvmName("modWithEntityId2")
infix fun <T, S : Number, ID : EntityID<T>?> Expression<S>.mod(other: ExpressionWithColumnType<ID>) where T : Number, T : Comparable<T> = this % other

/** Calculates the remainder of dividing the value of [this] numeric PK by the [other] number expression. */
@JvmName("modWithEntityId3")
infix fun <T, S : Number, ID : EntityID<T>?> ExpressionWithColumnType<ID>.mod(other: Expression<S>) where T : Number, T : Comparable<T> =
ModOp(this, other)

/**
* Performs a bitwise `and` on this expression and [t].
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package org.jetbrains.exposed.sql.tests.shared.functions

import org.jetbrains.exposed.crypt.Algorithms
import org.jetbrains.exposed.crypt.Encryptor
import org.jetbrains.exposed.dao.id.EntityID
import org.jetbrains.exposed.dao.id.IntIdTable
import org.jetbrains.exposed.sql.*
import org.jetbrains.exposed.sql.Function
Expand Down Expand Up @@ -69,51 +68,57 @@ class FunctionsTests : DatabaseTestsBase() {
}
}


@Test
fun `rem on numeric PK should work`() {
// Create a new table here, since the other tables don't define PK
val table = object : IntIdTable("test_mod_on_pk") {

val otherColumn = short("other")
}
withTables(table) {
repeat(10) {
repeat(5) {
table.insert {

it[otherColumn] = 4
}
}

val modOnPK = Expression.build { table.id % 3 }.alias("shard")
val modOnPK1 = Expression.build { table.id % 3 }.alias("shard1")
val modOnPK2 = Expression.build { table.id % intLiteral(3) }.alias("shard2")
val modOnPK3 = Expression.build { table.id % table.otherColumn }.alias("shard3")
val modOnPK4 = Expression.build { table.otherColumn % table.id }.alias("shard4")

val r = (table).slice(table.id, modOnPK)
.selectAll().groupBy(table.id).orderBy(table.id).toList()
val r = table.slice(table.id, modOnPK1, modOnPK2, modOnPK3, modOnPK4).selectAll().last()

val shardedPK: EntityID<Int> = r[1][modOnPK]
assertEquals(10, r.size)
assertEquals(2, shardedPK.value)
assertEquals(2, r[modOnPK1])
assertEquals(2, r[modOnPK2])
assertEquals(1, r[modOnPK3])
assertEquals(4, r[modOnPK4])
}
}

@Test
fun `mod on numeric PK should work`() {
// Create a new table here, since the other tables don't define PK
val table = object : IntIdTable("test_mod_on_pk") {

val otherColumn = short("other")
}
withTables(table) {
repeat(10) {
repeat(5) {
table.insert {

it[otherColumn] = 4
}
}

val modOnPK = Expression.build { table.id mod 3 }.alias("shard")
val modOnPK1 = Expression.build { table.id mod 3 }.alias("shard1")
val modOnPK2 = Expression.build { table.id mod intLiteral(3) }.alias("shard2")
val modOnPK3 = Expression.build { table.id mod table.otherColumn }.alias("shard3")
val modOnPK4 = Expression.build { table.otherColumn mod table.id }.alias("shard4")

val r = (table).slice(table.id, modOnPK)
.selectAll().groupBy(table.id).orderBy(table.id).toList()
val r = table.slice(table.id, modOnPK1, modOnPK2, modOnPK3, modOnPK4).selectAll().last()

val shardedPK: EntityID<Int> = r[0][modOnPK]
assertEquals(10, r.size)
assertEquals(1, shardedPK.value)
assertEquals(2, r[modOnPK1])
assertEquals(2, r[modOnPK2])
assertEquals(1, r[modOnPK3])
assertEquals(4, r[modOnPK4])
}
}

Expand Down Expand Up @@ -310,7 +315,7 @@ class FunctionsTests : DatabaseTestsBase() {
}
}

@Test
@Test
fun testRegexp01() {
withCitiesAndUsers(listOf(TestDB.SQLITE, TestDB.SQLSERVER, TestDB.H2_SQLSERVER)) { _, users, _ ->
assertEquals(2L, users.select { users.id regexp "a.+" }.count())
Expand All @@ -320,7 +325,7 @@ class FunctionsTests : DatabaseTestsBase() {
}
}

@Test
@Test
fun testRegexp02() {
withCitiesAndUsers(listOf(TestDB.SQLITE, TestDB.SQLSERVER, TestDB.H2_SQLSERVER)) { _, users, _ ->
assertEquals(2L, users.select { users.id.regexp(stringLiteral("a.+")) }.count())
Expand Down

0 comments on commit d5d817f

Please sign in to comment.