Skip to content

Commit

Permalink
Correct comparison of defaults for String type columns in PostgreSQL (#…
Browse files Browse the repository at this point in the history
…1589) / Oracle threats '' as NULL, PSQL fails on comparing non-string defaults
  • Loading branch information
Tapac committed Nov 14, 2022
1 parent dfb63d2 commit 8cad9f1
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -145,33 +145,42 @@ object SchemaUtils {
@Suppress("NestedBlockDepth", "ComplexMethod")
private fun DataTypeProvider.dbDefaultToString(column: Column<*>, exp: Expression<*>): String {
return when (exp) {
is LiteralOp<*> -> when (exp.value) {
is Boolean -> when (currentDialect) {
is MysqlDialect -> if (exp.value) "1" else "0"
is PostgreSQLDialect -> exp.value.toString()
else -> booleanToStatementString(exp.value)
}
is String -> when (currentDialect) {
is PostgreSQLDialect ->
when(column.columnType) {
is VarCharColumnType -> "'${exp.value}'::character varying"
is TextColumnType -> "'${exp.value}'::text"
else -> processForDefaultValue(exp)
is LiteralOp<*> -> {
val dialect = currentDialect
when (val value = exp.value) {
is Boolean -> when (dialect) {
is MysqlDialect -> if (value) "1" else "0"
is PostgreSQLDialect -> value.toString()
else -> booleanToStatementString(value)
}
is String -> when {
dialect is PostgreSQLDialect ->
when(column.columnType) {
is VarCharColumnType -> "'${value}'::character varying"
is TextColumnType -> "'${value}'::text"
else -> processForDefaultValue(exp)
}
dialect is OracleDialect || dialect.h2Mode == H2Dialect.H2CompatibilityMode.Oracle ->
when {
column.columnType is VarCharColumnType && value == "" -> "NULL"
column.columnType is TextColumnType && value == "" -> "NULL"
else -> value
}
else -> value
}
is Enum<*> -> when (exp.columnType) {
is EnumerationNameColumnType<*> -> when (dialect) {
is PostgreSQLDialect -> "'${value.name}'::character varying"
else -> value.name
}
else -> exp.value
}
is Enum<*> -> when (exp.columnType) {
is EnumerationNameColumnType<*> -> when (currentDialect) {
is PostgreSQLDialect -> "'${exp.value.name}'::character varying"
else -> exp.value.name
else -> processForDefaultValue(exp)
}
is BigDecimal -> when (dialect) {
is MysqlDialect -> value.setScale((exp.columnType as DecimalColumnType).scale).toString()
else -> processForDefaultValue(exp)
}
else -> processForDefaultValue(exp)
}
is BigDecimal -> when (currentDialect) {
is MysqlDialect -> exp.value.setScale((exp.columnType as DecimalColumnType).scale).toString()
else -> processForDefaultValue(exp)
}
else -> processForDefaultValue(exp)
}
else -> processForDefaultValue(exp)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,10 @@ class JdbcDatabaseMetadataImpl(database: String, val metadata: DatabaseMetaData)
dialect is OracleDialect || h2Mode == H2CompatibilityMode.Oracle -> defaultValue.trim().trim('\'')
dialect is MysqlDialect || h2Mode == H2CompatibilityMode.MySQL || h2Mode == H2CompatibilityMode.MariaDB ->
defaultValue.substringAfter("b'").trim('\'')
dialect is PostgreSQLDialect || h2Mode == H2CompatibilityMode.PostgreSQL -> defaultValue
dialect is PostgreSQLDialect || h2Mode == H2CompatibilityMode.PostgreSQL -> when {
defaultValue.startsWith('\'') && defaultValue.endsWith('\'') -> defaultValue.trim('\'')
else -> defaultValue
}
else -> defaultValue.trim('\'')
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import org.junit.AssumptionViolatedException
import org.testcontainers.containers.MySQLContainer
import org.testcontainers.containers.PostgreSQLContainer
import java.sql.Connection
import java.sql.SQLException
import java.time.Duration
import java.util.*
import kotlin.concurrent.thread
Expand Down Expand Up @@ -199,11 +200,16 @@ abstract class DatabaseTestsBase {
}

val database = dbSettings.db!!

transaction(database.transactionManager.defaultIsolationLevel, 1, db = database) {
registerInterceptor(CurrentTestDBInterceptor)
currentTestDB = dbSettings
statement(dbSettings)
try {
transaction(database.transactionManager.defaultIsolationLevel, 1, db = database) {
registerInterceptor(CurrentTestDBInterceptor)
currentTestDB = dbSettings
statement(dbSettings)
}
} catch (e: SQLException) {
throw e
} catch (e: Exception) {
throw Exception("Failed on ${dbSettings.name}", e)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,7 @@ class InsertTests : DatabaseTestsBase() {
TestTable.insert { it[foo] = 1 }
TestTable.insert { it[foo] = 0 }
}
fail("Should fail on constraint > 0")
fail("Should fail on constraint > 0 with $db")
} catch (_: SQLException) {
// expected
}
Expand Down

0 comments on commit 8cad9f1

Please sign in to comment.