Skip to content

Commit

Permalink
support for type overrides with JDBC, #32
Browse files Browse the repository at this point in the history
  • Loading branch information
Miha-x64 committed May 29, 2020
1 parent 1efd8c5 commit 80d3112
Show file tree
Hide file tree
Showing 9 changed files with 215 additions and 79 deletions.
Expand Up @@ -137,9 +137,6 @@ sealed class DataType<T> {
open fun storeAsString(value: T): CharSequence =
throw UnsupportedOperationException()

/*open val sqlType: CharSequence? get() = null
open fun storeSqlObject(value: T): Any? = throw UnsupportedOperationException()*/

override val type: Simple<T> get() = this
}

Expand Down
122 changes: 97 additions & 25 deletions sql/src/main/kotlin/net/aquadc/persistence/sql/blocking/JdbcSession.kt
@@ -1,6 +1,8 @@
package net.aquadc.persistence.sql.blocking

import net.aquadc.persistence.array
import net.aquadc.persistence.fatAsList
import net.aquadc.persistence.fatMapTo
import net.aquadc.persistence.sql.Dao
import net.aquadc.persistence.sql.ExperimentalSql
import net.aquadc.persistence.sql.Fetch
Expand All @@ -17,14 +19,16 @@ import net.aquadc.persistence.sql.bindInsertionParams
import net.aquadc.persistence.sql.bindQueryParams
import net.aquadc.persistence.sql.bindValues
import net.aquadc.persistence.sql.dialect.Dialect
import net.aquadc.persistence.sql.flattened
import net.aquadc.persistence.sql.dialect.foldArrayType
import net.aquadc.persistence.sql.mapIndexedToArray
import net.aquadc.persistence.sql.noOrder
import net.aquadc.persistence.struct.Schema
import net.aquadc.persistence.struct.Struct
import net.aquadc.persistence.type.AnyCollection
import net.aquadc.persistence.type.DataType
import net.aquadc.persistence.type.Ilk
import net.aquadc.persistence.type.i64
import net.aquadc.persistence.type.serialized
import org.intellij.lang.annotations.Language
import java.sql.Connection
import java.sql.PreparedStatement
Expand Down Expand Up @@ -218,14 +222,22 @@ class JdbcSession(
private fun <T> Ilk<T, *>.bind(statement: PreparedStatement, index: Int, value: T) {
val i = 1 + index
val custom = this.custom
if (custom == null) {
(type as DataType<T>).flattened { isNullable, simple ->
if (custom != null) {
statement.setObject(i, custom.invoke(value))
} else {
val t = type as DataType<T>
val type = if (t is DataType.Nullable<*, *>) {
if (value == null) {
check(isNullable)
statement.setNull(i, Types.NULL)
} else {
val v = simple.store(value)
when (simple.kind) {
return
}
t.actualType as DataType.NotNull<T>
} else type as DataType.NotNull<T>

when (type) {
is DataType.NotNull.Simple -> {
val v = type.store(value)
when (type.kind) {
DataType.NotNull.Simple.Kind.Bool -> statement.setBoolean(i, v as Boolean)
DataType.NotNull.Simple.Kind.I32 -> statement.setInt(i, v as Int)
DataType.NotNull.Simple.Kind.I64 -> statement.setLong(i, v as Long)
Expand All @@ -236,38 +248,98 @@ class JdbcSession(
DataType.NotNull.Simple.Kind.Blob -> statement.setObject(i, v as ByteArray)
}//.also { }
}
is DataType.NotNull.Collect<T, *, *> -> {
foldArrayType(
dialect.hasArraySupport, type.elementType,
{ nullable, elT ->
statement.setArray(i,
connection.createArrayOf(
jdbcElType(type.elementType),
toArray(type.store(value), nullable, elT)
)
)
},
{
statement.setObject(i, serialized(type).store(value))
}
)
}
is DataType.NotNull.Partial<T, *> -> {
throw AssertionError() // 🤔 btw, Oracle supports Struct type
}
}
} else {
statement.setObject(i, custom.invoke(value), Types.OTHER)
}
}
private fun jdbcElType(t: DataType<*>): String = when (t) {
is DataType.Nullable<*, *> -> jdbcElType(t.actualType)
is DataType.NotNull.Simple -> dialect.nameOf(t.kind)
is DataType.NotNull.Collect<*, *, *> -> jdbcElType(t.elementType)
is DataType.NotNull.Partial<*, *> -> dialect.nameOf(DataType.NotNull.Simple.Kind.Blob)
}
private fun <T> toArray(value: AnyCollection, nullable: Boolean, elT: DataType.NotNull.Simple<T>): Array<out Any?> =
(value.fatAsList() as List<T?>).let { value ->
Array<Any?>(value.size) {
val el = value[it]
if (el == null) check(nullable).let { null }
else elT.store(el)
}
}

@Suppress("IMPLICIT_CAST_TO_ANY", "UNCHECKED_CAST")
private /*wannabe inline*/ fun <T> Ilk<T, *>.get(resultSet: ResultSet, index: Int): T {
return get1indexed(resultSet, 1 + index)
}

private fun <T> Ilk<T, *>.get1indexed(resultSet: ResultSet, i: Int): T = custom.let { custom ->
if (custom == null) {
(type as DataType<T>).flattened { isNullable, simple ->
val v = when (simple.kind) {
DataType.NotNull.Simple.Kind.Bool -> resultSet.getBoolean(i)
DataType.NotNull.Simple.Kind.I32 -> resultSet.getInt(i)
DataType.NotNull.Simple.Kind.I64 -> resultSet.getLong(i)
DataType.NotNull.Simple.Kind.F32 -> resultSet.getFloat(i)
DataType.NotNull.Simple.Kind.F64 -> resultSet.getDouble(i)
DataType.NotNull.Simple.Kind.Str -> resultSet.getString(i)
DataType.NotNull.Simple.Kind.Blob -> resultSet.getBytes(i)
else -> throw AssertionError()
if (custom != null) {
custom.back(resultSet.getObject(i))
} else {
val t = type as DataType<T>
val nullable: Boolean
val type =
if (t is DataType.Nullable<*, *>) { nullable = true; t.actualType as DataType.NotNull<T> }
else { nullable = false; type as DataType.NotNull<T> }
when (type) {
is DataType.NotNull.Simple -> {
val v = when (type.kind) {
DataType.NotNull.Simple.Kind.Bool -> resultSet.getBoolean(i)
DataType.NotNull.Simple.Kind.I32 -> resultSet.getInt(i)
DataType.NotNull.Simple.Kind.I64 -> resultSet.getLong(i)
DataType.NotNull.Simple.Kind.F32 -> resultSet.getFloat(i)
DataType.NotNull.Simple.Kind.F64 -> resultSet.getDouble(i)
DataType.NotNull.Simple.Kind.Str -> resultSet.getString(i)
DataType.NotNull.Simple.Kind.Blob -> resultSet.getBytes(i)
else -> throw AssertionError()
}
// must check, will get zeroes otherwise
if (resultSet.wasNull()) check(nullable).let { null as T }
else type.load(v)
}
is DataType.NotNull.Collect<T, *, *> -> {
foldArrayType(dialect.hasArraySupport, type.elementType,
{ nullable, elT ->
val arr = resultSet.getArray(i)
if (resultSet.wasNull()) { check(nullable); null as T }
else fromArray(type, arr.array as Array<out Any?>, nullable, elT)
},
{
val obj = resultSet.getObject(i)
if (resultSet.wasNull()) { check(nullable); null as T }
else serialized(type).load(obj)
}
)
}
is DataType.NotNull.Partial<T, *> -> {
throw AssertionError()
}
// must check, will get zeroes otherwise
if (resultSet.wasNull()) check(isNullable).let { null as T }
else simple.load(v)
}
} else {
custom.back(resultSet.getObject(i))
}
}
private fun <T> fromArray(type: DataType.NotNull.Collect<T, *, *>, value: AnyCollection, nullable: Boolean, elT: DataType.NotNull.Simple<*>): T =
type.load(value.fatMapTo(ArrayList<Any?>()) { it: Any? ->
if (it == null) { check(nullable); null } else elT.load(it)
})


override fun <T> cell(
query: String, argumentTypes: Array<out Ilk<*, DataType.NotNull<*>>>, arguments: Array<out Any>, type: Ilk<T, *>, orElse: () -> T
Expand Down
Expand Up @@ -16,9 +16,10 @@ import java.io.ByteArrayOutputStream
import java.io.DataOutputStream

@RestrictTo(RestrictTo.Scope.LIBRARY)
/*internal*/ open class BaseDialect(
/*wannabe internal*/ open class BaseDialect(
private val types: InlineEnumMap<DataType.NotNull.Simple.Kind, String>,
private val truncate: String
private val truncate: String,
private val arrayPostfix: String?
) : Dialect {

override fun <SCH : Schema<SCH>> insert(table: Table<SCH, *>): String = buildString {
Expand Down Expand Up @@ -128,7 +129,7 @@ import java.io.DataOutputStream
sb.appendName(colNames[i]).append(' ')
.let {
val t = colTypes[i]
if (t is DataType<*>) it.appendNameOf(t)
if (t is DataType<*>) it.appendTwN(t)
else it.append(t as CharSequence)
}

Expand All @@ -142,7 +143,7 @@ import java.io.DataOutputStream
return sb.append(");").toString()
}
protected open fun StringBuilder.appendPkType(type: DataType.NotNull.Simple<*>, managed: Boolean): StringBuilder =
appendNameOf(type) // used by SQLite, overridden for Postgres
appendTwN(type) // used by SQLite, overridden for Postgres

private fun <T> StringBuilder.appendDefault(type: DataType<T>, default: T) {
val type = if (type is DataType.Nullable<*, *>) {
Expand Down Expand Up @@ -172,25 +173,37 @@ import java.io.DataOutputStream
}
}

protected fun <T> StringBuilder.appendNameOf(dataType: DataType<T>) = apply {
val act = if (dataType is DataType.Nullable<*, *>) dataType.actualType else dataType
when (act) {
is DataType.Nullable<*, *> -> throw AssertionError()
is DataType.NotNull.Simple<*> -> append(types[act.kind]!!)
is DataType.NotNull.Collect<*, *, *> -> append(types[DataType.NotNull.Simple.Kind.Blob]!!)
is DataType.NotNull.Partial<*, *> -> throw UnsupportedOperationException() // column can't be of Partial type at this point
}
if (dataType === act) {
append(' ').append("NOT NULL")
}
/** Appends type along with its non-nullability */
protected fun StringBuilder.appendTwN(dataType: DataType<*>): StringBuilder {
val nn = dataType is DataType.NotNull<*>
return appendTnN(if (nn) dataType as DataType.NotNull else (dataType as DataType.Nullable<*, *>).actualType)
.appendIf(nn, ' ', "NOT NULL")
}

/** Appends type without its nullability info, i. e. like it is nullable. */
private fun StringBuilder.appendTnN(dataType: DataType.NotNull<*>): StringBuilder = when (dataType) {
is DataType.NotNull.Simple<*> -> append(nameOf(dataType.kind))
is DataType.NotNull.Collect<*, *, *> -> appendTArray(dataType.elementType)
is DataType.NotNull.Partial<*, *> -> throw UnsupportedOperationException() // column can't be of Partial type at this point
}
private fun StringBuilder.appendTArray(elementType: DataType<*>): StringBuilder =
foldArrayType(arrayPostfix != null, elementType,
{ _, elT -> appendTnN(elT).append(arrayPostfix) },
//^ all array elements are nullable in Postgres, there's nothing we can do about it.
// Are there any databases which work another way?
{ append(nameOf(DataType.NotNull.Simple.Kind.Blob)) }
)


/**
* {@implNote SQLite does not have TRUNCATE statement}
*/
override fun truncate(table: Table<*, *>): String =
buildString(13 + table.name.length) {
append(truncate).append(' ').appendName(table.name)
}

override val hasArraySupport: Boolean
get() = arrayPostfix != null

override fun nameOf(kind: DataType.NotNull.Simple.Kind): String =
types[kind]!!

}
11 changes: 11 additions & 0 deletions sql/src/main/kotlin/net/aquadc/persistence/sql/dialect/Dialect.kt
Expand Up @@ -4,6 +4,7 @@ import net.aquadc.persistence.struct.Schema
import net.aquadc.persistence.sql.Order
import net.aquadc.persistence.sql.Table
import net.aquadc.persistence.sql.WhereCondition
import net.aquadc.persistence.type.DataType

/**
* Represents an SQL dialect. Provides functions for building queries.
Expand Down Expand Up @@ -65,4 +66,14 @@ interface Dialect {
*/
fun truncate(table: Table<*, *>): String

/**
* Whether database has support for arrays.
*/
val hasArraySupport: Boolean

/**
* Figures out simple name of a primitive type.
*/
fun nameOf(kind: DataType.NotNull.Simple.Kind): String

}
25 changes: 25 additions & 0 deletions sql/src/main/kotlin/net/aquadc/persistence/sql/dialect/common.kt
@@ -1,5 +1,7 @@
package net.aquadc.persistence.sql.dialect

import net.aquadc.persistence.type.DataType


internal fun StringBuilder.appendPlaceholders(count: Int): StringBuilder {
if (count == 0) return this
Expand All @@ -25,3 +27,26 @@ internal inline fun StringBuilder.appendIf(cond: Boolean, what: String): StringB

internal inline fun StringBuilder.appendIf(cond: Boolean, what: Char): StringBuilder =
if (cond) append(what) else this

internal inline fun StringBuilder.appendIf(cond: Boolean, what1: Char, what2: String): StringBuilder =
if (cond) append(what1).append(what2) else this

internal inline fun <R> foldArrayType(
hasArraySupport: Boolean,
elementType: DataType<*>,
ifAppropriate: (nullable: Boolean, actualElementType: DataType.NotNull.Simple<*>) -> R,
ifNot: () -> R
): R {
val nullable: Boolean = elementType is DataType.Nullable<*, *>
val actualElementType: DataType.NotNull<*> = // damn. I really miss Java assignment as expression
if (nullable) (elementType as DataType.Nullable<*, *>).actualType
else elementType as DataType.NotNull<*>

// arrays of arrays or structs are still serialized.
// PostgreSQL multidimensional arrays are actually matrices
// which is kinda weird surprise and inappropriate constraint.
if (hasArraySupport && actualElementType is DataType.NotNull.Simple)
return ifAppropriate(nullable, actualElementType)
else
return ifNot()
}
Expand Up @@ -23,12 +23,13 @@ val PostgresDialect: Dialect = object : BaseDialect(
DataType.NotNull.Simple.Kind.Str, "text",
DataType.NotNull.Simple.Kind.Blob, "bytea"
),
truncate = "TRUNCATE TABLE"
truncate = "TRUNCATE TABLE",
arrayPostfix = "[]"
) {
private val serial = DataType.NotNull.Simple.Kind.I32 + DataType.NotNull.Simple.Kind.I64
override fun StringBuilder.appendPkType(type: DataType.NotNull.Simple<*>, managed: Boolean): StringBuilder =
// If PK column is 'managed', we just take `structToInsert[pkField]`.
if (managed || type.kind !in serial) appendNameOf(type)
// If PK column is 'managed', we just take `structToInsert[pkField]`. todo unique constraint
if (managed || type.kind !in serial) appendTwN(type)
// Otherwise its our responsibility to make PK auto-generated
else append("serial")
.appendIf(type.kind == DataType.NotNull.Simple.Kind.I64, '8')
Expand Down
Expand Up @@ -20,5 +20,6 @@ object SqliteDialect : BaseDialect(
DataType.NotNull.Simple.Kind.Str, "TEXT",
DataType.NotNull.Simple.Kind.Blob, "BLOB"
),
truncate = "DELETE FROM"
truncate = "DELETE FROM", // SQLite does not have TRUNCATE statement
arrayPostfix = null // no array support
)
2 changes: 1 addition & 1 deletion sql/src/test/kotlin/net/aquadc/persistence/sql/database.kt
Expand Up @@ -132,4 +132,4 @@ fun session(dialect: Dialect, url: String): JdbcSession =
val stmt = conn.createStatement()
TestTables.forEach { stmt.execute(dialect.createTable(it, temporary = true)) }
stmt.close()
}, SqliteDialect)
}, dialect)

0 comments on commit 80d3112

Please sign in to comment.