Skip to content

Commit

Permalink
PostgreSQL support, #32
Browse files Browse the repository at this point in the history
  • Loading branch information
Miha-x64 committed May 13, 2020
1 parent adbc2f2 commit 5afe7c0
Show file tree
Hide file tree
Showing 15 changed files with 206 additions and 77 deletions.
Expand Up @@ -7,11 +7,11 @@ import net.aquadc.persistence.sql.EmbedRelationsTest
import net.aquadc.persistence.sql.QueryBuilderTests
import net.aquadc.persistence.sql.Session
import net.aquadc.persistence.sql.SqlPropTest
import net.aquadc.persistence.sql.blocking.SqliteSession
import net.aquadc.persistence.sql.TemplatesTest
import net.aquadc.persistence.sql.TestTables
import net.aquadc.persistence.sql.dialect.sqlite.SqliteDialect
import net.aquadc.persistence.sql.blocking.Blocking
import net.aquadc.persistence.sql.blocking.SqliteSession
import net.aquadc.persistence.sql.dialect.sqlite.SqliteDialect
import org.junit.After
import org.junit.Assert.assertEquals
import org.junit.Before
Expand Down
5 changes: 4 additions & 1 deletion sql/build.gradle
Expand Up @@ -14,13 +14,16 @@ artifacts { testOutput testJar }
dependencies {
compileOnly "com.google.android:android:$android_artifact_version"
implementation "org.jetbrains.kotlin:kotlin-stdlib"
implementation "net.aquadc.collections:Collection-utils-jvm:1.0-$collection_utils_version"
implementation project(':properties')
implementation project(':persistence')

compileOnly 'androidx.annotation:annotation:1.1.0'
compileOnly 'org.jetbrains.kotlinx:kotlinx-coroutines-core:1.3.3'
compileOnly 'org.jetbrains.kotlinx:kotlinx-coroutines-core:1.0.0'

testImplementation group: 'junit', name: 'junit', version: '4.12'
testImplementation 'org.xerial:sqlite-jdbc:3.25.2'
testImplementation 'org.postgresql:postgresql:42.2.12.jre7'
testImplementation project(':etc:testing')
testImplementation project(':extended-persistence')
}
Expand Down
19 changes: 12 additions & 7 deletions sql/src/main/kotlin/net/aquadc/persistence/sql/Table.kt
Expand Up @@ -74,12 +74,12 @@ private constructor(
val recipe = ArrayList<Nesting>()
val ss = Nesting.StructStart(false, null, null, schema)
recipe.add(ss)
embed(rels, schema, null, null, columns, delegates, recipe)
embed(meta, schema, null, null, columns, delegates, recipe)
ss.colCount = columns.size
recipe.add(Nesting.StructEnd)
this._recipe = recipe.array()

if (rels.isNotEmpty()) throw RuntimeException("cannot consume relations: $rels")
if (meta.isNotEmpty()) throw RuntimeException("cannot consume meta: ${meta.values}")

this._delegates = delegates
val colsArray = columns.array()
Expand Down Expand Up @@ -124,22 +124,22 @@ private constructor(

if (relType != null) {
// got a struct type, a relation must be declared
val rel = rels.remove(path)
val meta = metas.remove(path)
?: throw NoSuchElementException("${this@Table} requires a Relation to be declared for path $path storing values of type $relType")

when (meta) {
is ColMeta.Embed<*> -> {
val start = outColumns.size
val fieldSetCol = rel.fieldSetColName?.let { fieldSetColName ->
(rel.naming.concatErased(this.schema, schema, path, FieldSetLens<Schema<*>>(fieldSetColName)) as StoredNamedLens<SCH, out Long?, *>)
val fieldSetCol = meta.fieldSetColName?.let { fieldSetColName ->
(meta.naming.concatErased(this.schema, schema, path, FieldSetLens<Schema<*>>(fieldSetColName)) as StoredNamedLens<SCH, out Long?, *>)
.also { outColumns.add(it, it.name(this.schema)) }
}

val relSchema = relType.schema
val recipeStart = outRecipe.size
val ss = Nesting.StructStart(fieldSetCol != null, field, type, relType)
outRecipe.add(ss)
/*val nestedLenses =*/ embed(rels, relSchema, rel.naming, path, outColumns, null, outRecipe)
/*val nestedLenses =*/ embed(metas, relSchema, meta.naming, path, outColumns, null, outRecipe)
ss.colCount = outColumns.size - start
outRecipe.add(Nesting.StructEnd)

Expand Down Expand Up @@ -179,7 +179,12 @@ private constructor(
get() = _columns.value

val pkColumn: NamedLens<SCH, Record<SCH, ID>, Record<SCH, ID>, ID, out DataType.Simple<ID>>
get() = columns[0] as NamedLens<SCH, Record<SCH, ID>, Record<SCH, ID>, ID, out DataType.Simple<ID>>
get() = _columns.let {
if (it.isInitialized())
it.value[0] as NamedLens<SCH, Record<SCH, ID>, Record<SCH, ID>, ID, out DataType.Simple<ID>>
else
TODO("allow getting PK within meta() without reentrancy")
}

internal val recipe: Array<out Nesting>
get() = _recipe ?: _columns.value.let { _ /* unwrap lazy */ -> _recipe!! }
Expand Down
Expand Up @@ -150,7 +150,10 @@ class SqliteSession(
table.name,
if (columnNames == null) arrayOf("COUNT(*)")
else columnNames.mapIndexedToArray { _, name -> name.toString() },
StringBuilder().appendWhereClause(table, condition).toString(),
StringBuilder().let {
condition.appendSqlTo(table, SqliteDialect, it)
if (it.isEmpty()) null/*todo deallocate SB*/ else it.toString()
},
/*groupBy=*/null,
/*having=*/null,
if (order.isEmpty()) null else StringBuilder().appendOrderClause(table.schema, order).toString(),
Expand Down Expand Up @@ -380,6 +383,6 @@ class SqliteSession(
/**
* Calls [SQLiteDatabase.execSQL] for the given [table] in [this] database.
*/
fun <SCH : Schema<SCH>> SQLiteDatabase.createTable(table: Table<SCH, *>) {
fun SQLiteDatabase.createTable(table: Table<*, *>) {
execSQL(SqliteDialect.createTable(table))
}
@@ -1,11 +1,12 @@
package net.aquadc.persistence.sql.dialect.sqlite
package net.aquadc.persistence.sql.dialect

import androidx.annotation.RestrictTo
import net.aquadc.collections.InlineEnumMap
import net.aquadc.collections.get
import net.aquadc.persistence.sql.Order
import net.aquadc.persistence.sql.Table
import net.aquadc.persistence.sql.WhereCondition
import net.aquadc.persistence.sql.dialect.Dialect
import net.aquadc.persistence.sql.dialect.appendPlaceholders
import net.aquadc.persistence.sql.dialect.appendReplacing
import net.aquadc.persistence.sql.dialect.sqlite.SqliteDialect
import net.aquadc.persistence.sql.noOrder
import net.aquadc.persistence.stream.DataStreams
import net.aquadc.persistence.stream.write
Expand All @@ -14,10 +15,11 @@ import net.aquadc.persistence.type.DataType
import java.io.ByteArrayOutputStream
import java.io.DataOutputStream

/**
* Implements SQLite [Dialect].
*/
object SqliteDialect : Dialect {
@RestrictTo(RestrictTo.Scope.LIBRARY)
/*internal*/ open class BaseDialect(
private val types: InlineEnumMap<DataType.Simple.Kind, String>,
private val truncate: String
) : Dialect {

override fun <SCH : Schema<SCH>> insert(table: Table<SCH, *>): String = buildString {
val cols = table.managedColNames
Expand Down Expand Up @@ -45,7 +47,10 @@ object SqliteDialect : Dialect {
.let { if (columns == null) it.append("COUNT(*)") else it.appendNames(columns) }
.append(" FROM ").appendName(table.name)
.append(" WHERE ")
sb.appendWhereClause(table, condition)

val afterWhere = sb.length
condition.appendSqlTo(table, this@BaseDialect, sb)
sb.length.let { if (it == afterWhere) sb.setLength(it - 7) } // erase " WHERE "

if (order.isNotEmpty())
sb.append(" ORDER BY ").appendOrderClause(table.schema, order)
Expand All @@ -58,9 +63,10 @@ object SqliteDialect : Dialect {
condition: WhereCondition<SCH>
): StringBuilder = apply {
val afterWhere = length
condition.appendSqlTo(context, this@SqliteDialect, this)
condition.appendSqlTo(context, this@BaseDialect, this)

if (length == afterWhere) append('1') // no condition: SELECT "whatever" FROM "somewhere" WHERE 1
// no condition: SELECT "whatever" FROM "somewhere" WHERE true
if (length == afterWhere) append(if (this@BaseDialect === SqliteDialect) "1" else "true")
}

override fun <SCH : Schema<SCH>> StringBuilder.appendOrderClause(
Expand All @@ -73,12 +79,11 @@ object SqliteDialect : Dialect {

override fun <SCH : Schema<SCH>> updateQuery(table: Table<SCH, *>, cols: Array<out CharSequence>): String =
buildString {
append("UPDATE ").appendName(table.name).append(" SET ")
check(cols.isNotEmpty())

cols.forEach { col ->
appendName(col).append(" = ?, ")
}
setLength(length - 2) // assume not empty
append("UPDATE ").appendName(table.name).append(" SET ")
cols.forEach { col -> appendName(col).append(" = ?, ") }
setLength(length - 2)

append(" WHERE ").appendName(table.idColName).append(" = ?;")
}
Expand All @@ -100,14 +105,16 @@ object SqliteDialect : Dialect {
}
}

override fun createTable(table: Table<*, *>): String {
val sb = StringBuilder("CREATE TABLE ").appendName(table.name).append(" (")
.appendName(table.idColName).append(' ').appendNameOf(table.idColType).append(" PRIMARY KEY")
override fun createTable(table: Table<*, *>, temporary: Boolean): String {
val managedPk = table.pkField != null
val sb = StringBuilder("CREATE").append(' ')
.appendIf(temporary, "TEMP ").append("TABLE").append(' ').appendName(table.name).append(" (")
.appendName(table.idColName).append(' ').appendPkType(table.idColType, managedPk).append(" PRIMARY KEY")

val colNames = table.managedColNames
val colTypes = table.managedColTypes

val startIndex = if (table.pkField == null) 0 else 1
val startIndex = if (managedPk) 1 else 0
val endExclusive = colNames.size
if (endExclusive != startIndex) {
sb.append(", ")
Expand All @@ -125,6 +132,10 @@ object SqliteDialect : Dialect {
sb.setLength(sb.length - 2) // trim last comma; schema.fields must not be empty
return sb.append(");").toString()
}
private inline fun StringBuilder.appendIf(cond: Boolean, what: String): StringBuilder =
if (cond) append(what) else this
protected open fun StringBuilder.appendPkType(type: DataType.Simple<*>, managed: Boolean): StringBuilder =
appendNameOf(type) // used by SQLite, overridden for Postgres

private fun <T> StringBuilder.appendDefault(type: DataType<T>, default: T) {
val type = if (type is DataType.Nullable<*, *>) {
Expand Down Expand Up @@ -154,20 +165,11 @@ object SqliteDialect : Dialect {
}
}

private fun <T> StringBuilder.appendNameOf(dataType: DataType<T>) = apply {
protected fun <T> StringBuilder.appendNameOf(dataType: DataType<T>) = apply {
val act = if (dataType is DataType.Nullable<*, *>) dataType.actualType else dataType
when (act) {
is DataType.Nullable<*, *> -> throw AssertionError()
is DataType.Simple<*> -> append(when (act.kind) {
DataType.Simple.Kind.Bool,
DataType.Simple.Kind.I32,
DataType.Simple.Kind.I64 -> "INTEGER"
DataType.Simple.Kind.F32,
DataType.Simple.Kind.F64 -> "REAL"
DataType.Simple.Kind.Str -> "TEXT"
DataType.Simple.Kind.Blob -> "BLOB"
else -> throw AssertionError()
})
is DataType.Simple<*> -> append(types[act.kind]!!)
is DataType.Collect<*, *, *> -> append("BLOB")
is DataType.Partial<*, *> -> throw UnsupportedOperationException() // column can't be of Partial type at this point
}
Expand All @@ -181,7 +183,7 @@ object SqliteDialect : Dialect {
*/
override fun truncate(table: Table<*, *>): String =
buildString(13 + table.name.length) {
append("DELETE FROM").appendName(table.name)
append(truncate).append(' ').appendName(table.name)
}

}
Expand Up @@ -31,6 +31,7 @@ interface Dialect {
/**
* Appends WHERE clause (without WHERE itself) to [this] builder.
*/
@Deprecated("unused by Session")
fun <SCH : Schema<SCH>> StringBuilder.appendWhereClause(context: Table<SCH, *>, condition: WhereCondition<SCH>): StringBuilder

/**
Expand All @@ -57,7 +58,7 @@ interface Dialect {
/**
* Returns an SQL query to create the given [table].
*/
fun createTable(table: Table<*, *>): String
fun createTable(table: Table<*, *>, temporary: Boolean = false): String

/**
* Returns `TRUNCATE` query to clear the whole table.
Expand Down
33 changes: 33 additions & 0 deletions sql/src/main/kotlin/net/aquadc/persistence/sql/dialect/postgres.kt
@@ -0,0 +1,33 @@
@file:JvmName("PostgresDialect")
package net.aquadc.persistence.sql.dialect.postgres

import net.aquadc.collections.enumMapOf
import net.aquadc.persistence.sql.dialect.BaseDialect
import net.aquadc.persistence.sql.dialect.Dialect
import net.aquadc.persistence.type.DataType

/**
* Implements PostgreSQL [Dialect].
*/
@JvmField
val PostgresDialect: Dialect = object : BaseDialect(
enumMapOf(
DataType.Simple.Kind.Bool, "bool",
DataType.Simple.Kind.I32, "int",
DataType.Simple.Kind.I64, "int8",
DataType.Simple.Kind.F32, "real",
DataType.Simple.Kind.F64, "float8",
DataType.Simple.Kind.Str, "text",
DataType.Simple.Kind.Blob, "bytea"
),
truncate = "TRUNCATE TABLE"
) {
override fun StringBuilder.appendPkType(type: DataType.Simple<*>, managed: Boolean): StringBuilder =
if (managed) appendNameOf(type)
else { // If PK column is 'managed', we just take `structToInsert[pkField]`.
// Otherwise, its our responsibility to make PK auto-generated
if (type.kind == DataType.Simple.Kind.I32) append("serial")
else if (type.kind == DataType.Simple.Kind.I64) append("serial8")
else throw UnsupportedOperationException() // wat? Boolean, float, double, string, byte[] primary key? O_o
}
}
24 changes: 24 additions & 0 deletions sql/src/main/kotlin/net/aquadc/persistence/sql/dialect/sqlite.kt
@@ -0,0 +1,24 @@
@file:JvmName("SqliteDialect")
package net.aquadc.persistence.sql.dialect.sqlite

import net.aquadc.collections.enumMapOf
import net.aquadc.persistence.sql.dialect.BaseDialect
import net.aquadc.persistence.sql.dialect.Dialect
import net.aquadc.persistence.type.DataType

/**
* Implements SQLite [Dialect].
*/
// wannabe `@JvmField val` but this breaks compilation. Dafuq?
object SqliteDialect : BaseDialect(
enumMapOf(
DataType.Simple.Kind.Bool, "INTEGER",
DataType.Simple.Kind.I32, "INTEGER",
DataType.Simple.Kind.I64, "INTEGER",
DataType.Simple.Kind.F32, "REAL",
DataType.Simple.Kind.F64, "REAL",
DataType.Simple.Kind.Str, "TEXT",
DataType.Simple.Kind.Blob, "BLOB"
),
truncate = "DELETE FROM"
)
Expand Up @@ -20,9 +20,9 @@ import org.junit.Assert.assertSame
import org.junit.Test


open class EmbedRelationsTest {
abstract class EmbedRelationsTest {

open val session: Session<*> get() = jdbcSession
protected abstract val session: Session<*>

@Test fun embed() {
val rec = session.withTransaction {
Expand Down
Expand Up @@ -8,9 +8,9 @@ import org.junit.Assert.assertEquals
import org.junit.Test


open class QueryBuilderTests {
abstract class QueryBuilderTests {

open val session: Session<*> get() = jdbcSession
protected abstract val session: Session<*>

@Test fun between() {
session.withTransaction {
Expand Down
18 changes: 14 additions & 4 deletions sql/src/test/kotlin/net/aquadc/persistence/sql/SqlPropTest.kt
Expand Up @@ -9,13 +9,13 @@ import org.junit.Assert.assertEquals
import org.junit.Assert.assertNotSame
import org.junit.Assert.assertSame
import org.junit.Assert.fail
import org.junit.AssumptionViolatedException
import org.junit.Test
import java.sql.SQLException


open class SqlPropTest {

open val session: Session<*> get() = jdbcSession
abstract class SqlPropTest {
protected abstract val session: Session<*>

private val someDao get() = session[SomeTable]

Expand Down Expand Up @@ -163,8 +163,9 @@ open class SqlPropTest {
})
}
} catch (e: Exception) {
if (e is AssumptionViolatedException) throw e
if (!duplicatePkExceptionClass.isInstance(e)) {
fail()
fail("expected:<" + duplicatePkExceptionClass.name + "> but was:<" + e.javaClass + ">")
}
}
}
Expand Down Expand Up @@ -202,6 +203,15 @@ open class SqlPropTest {
}
}
private fun Session<*>.createTestRecord() =
withTransaction {
insert(SomeTable, SomeSchema {
it[A] = "first"
it[B] = 2
it[C] = 3
})
}
}
// TODO: .shapshots() should change only one time per transaction

0 comments on commit 5afe7c0

Please sign in to comment.