Skip to content

Commit

Permalink
Propogate config changes from DataSourceService (cashapp#1134)
Browse files Browse the repository at this point in the history
  • Loading branch information
keeferrourke committed Aug 6, 2019
1 parent 0a4a4c3 commit 4ee2b86
Show file tree
Hide file tree
Showing 10 changed files with 98 additions and 61 deletions.
Expand Up @@ -4,9 +4,12 @@ import com.squareup.moshi.Moshi
import misk.ServiceModule
import misk.inject.KAbstractModule
import misk.inject.asSingleton
import misk.inject.keyOf
import misk.inject.toKey
import misk.jdbc.DataSourceConfig
import misk.jdbc.DataSourceConnector
import misk.jdbc.DataSourceDecorator
import misk.jdbc.DataSourceService
import misk.jdbc.DataSourceType
import misk.jdbc.TruncateTablesService
import misk.jdbc.VitessScaleSafetyChecks
Expand Down Expand Up @@ -35,22 +38,20 @@ class HibernateTestingModule(
override fun configure() {
val truncateTablesServiceKey = TruncateTablesService::class.toKey(qualifier)

val configKey = DataSourceConfig::class.toKey(qualifier)
val configProvider = getProvider(configKey)

val transacterKey = Transacter::class.toKey(qualifier)
val transacterProvider = getProvider(transacterKey)

if ((config == null || config.type == DataSourceType.VITESS)) {
bindVitessChecks(transacterProvider)
}

val dataSourceConnector = getProvider(keyOf<DataSourceConnector>(qualifier))
install(ServiceModule(truncateTablesServiceKey)
.dependsOn<SchemaMigratorService>(qualifier))
bind(truncateTablesServiceKey).toProvider(Provider {
TruncateTablesService(
qualifier = qualifier,
config = configProvider.get(),
connector = dataSourceConnector.get(),
transacterProvider = transacterProvider,
startUpStatements = startUpStatements,
shutDownStatements = shutDownStatements
Expand Down
Expand Up @@ -23,7 +23,7 @@ private val logger = getLogger<TruncateTablesService>()
*/
class TruncateTablesService(
private val qualifier: KClass<out Annotation>,
private val config: DataSourceConfig,
private val connector: DataSourceConnector,
private val transacterProvider: Provider<Transacter>,
private val startUpStatements: List<String> = listOf(),
private val shutDownStatements: List<String> = listOf()
Expand All @@ -45,6 +45,7 @@ class TruncateTablesService(
val truncatedTableNames = transacterProvider.get().shards().flatMap { shard ->
transacterProvider.get().transaction(shard) { session ->
session.withoutChecks {
val config = connector.config()
val tableNamesQuery = when (config.type) {
DataSourceType.MYSQL -> {
"SELECT table_name FROM information_schema.tables where table_schema='${config.database}'"
Expand Down
Expand Up @@ -24,9 +24,9 @@ internal class TruncateTablesServiceTest {
@MiskTestModule
val module = TestModule()

@Inject @TestDatasource lateinit var config: DataSourceConfig
@Inject @TestDatasource lateinit var sessionFactory: SessionFactory
@Inject @TestDatasource lateinit var transacter: Transacter
@Inject @TestDatasource lateinit var connector: DataSourceConnector

@BeforeEach
internal fun setUp() {
Expand Down Expand Up @@ -56,8 +56,7 @@ internal class TruncateTablesServiceTest {
assertThat(rowCount("movies")).isGreaterThan(0)

// Start up TruncateTablesService. The inserted data should be truncated.
val service = TruncateTablesService(TestDatasource::class, config,
Providers.of(transacter))
val service = TruncateTablesService(TestDatasource::class, connector, Providers.of(transacter))
service.startAsync()
service.awaitRunning()
assertThat(rowCount("schema_version")).isGreaterThan(0)
Expand All @@ -68,9 +67,10 @@ internal class TruncateTablesServiceTest {
fun startUpStatements() {
val service = TruncateTablesService(
TestDatasource::class,
config,
connector,
Providers.of(transacter),
startUpStatements = listOf("INSERT INTO movies (name) VALUES ('Star Wars')"))
startUpStatements = listOf("INSERT INTO movies (name) VALUES ('Star Wars')")
)

assertThat(rowCount("movies")).isEqualTo(0)
service.startAsync()
Expand All @@ -82,9 +82,10 @@ internal class TruncateTablesServiceTest {
fun shutDownStatements() {
val service = TruncateTablesService(
TestDatasource::class,
config,
connector,
Providers.of(transacter),
shutDownStatements = listOf("INSERT INTO movies (name) VALUES ('Star Wars')"))
shutDownStatements = listOf("INSERT INTO movies (name) VALUES ('Star Wars')")
)

service.startAsync()
service.awaitRunning()
Expand All @@ -109,10 +110,8 @@ internal class TruncateTablesServiceTest {
install(MiskTestingServiceModule())

val config = MiskConfig.load<TestConfig>("test_truncatetables_app", Environment.TESTING)
install(HibernateModule(
TestDatasource::class, config.data_source))
install(object : HibernateEntityModule(
TestDatasource::class) {
install(HibernateModule(TestDatasource::class, config.data_source))
install(object : HibernateEntityModule(TestDatasource::class) {
override fun configureHibernate() {
}
})
Expand All @@ -124,4 +123,4 @@ internal class TruncateTablesServiceTest {
@Qualifier
@Target(AnnotationTarget.FIELD, AnnotationTarget.FUNCTION)
annotation class TestDatasource
}
}
45 changes: 31 additions & 14 deletions misk-hibernate/src/main/kotlin/misk/hibernate/HibernateModule.kt
Expand Up @@ -11,6 +11,7 @@ import misk.inject.keyOf
import misk.inject.setOfType
import misk.inject.toKey
import misk.jdbc.DataSourceConfig
import misk.jdbc.DataSourceConnector
import misk.jdbc.DataSourceDecorator
import misk.jdbc.DataSourceService
import misk.jdbc.DatabasePool
Expand Down Expand Up @@ -86,8 +87,9 @@ class HibernateModule(
// Bind DataSourceService.
val dataSourceDecoratorsKey = setOfType(DataSourceDecorator::class).toKey(qualifier)
val dataSourceDecoratorsProvider = getProvider(dataSourceDecoratorsKey)
bind(keyOf<DataSource>(qualifier)).toProvider(
keyOf<DataSourceService>(qualifier)).asSingleton()
bind(keyOf<DataSource>(qualifier))
.toProvider(keyOf<DataSourceService>(qualifier))
.asSingleton()
bind(keyOf<DataSourceService>(qualifier)).toProvider(object : Provider<DataSourceService> {
@com.google.inject.Inject(optional = true) var metrics: Metrics? = null
override fun get() = DataSourceService(
Expand All @@ -99,6 +101,9 @@ class HibernateModule(
metrics = metrics
)
}).asSingleton()
val dataSourceServiceProvider = getProvider(keyOf<DataSourceService>(qualifier))
bind(keyOf<DataSourceConnector>(qualifier)).toProvider(dataSourceServiceProvider)
val connectorProvider = getProvider(keyOf<DataSourceConnector>(qualifier))
install(ServiceModule<DataSourceService>(qualifier)
.dependsOn<PingDatabaseService>(qualifier))

Expand All @@ -110,8 +115,12 @@ class HibernateModule(

bind(schemaMigratorKey).toProvider(object : Provider<SchemaMigrator> {
@Inject lateinit var resourceLoader: ResourceLoader
override fun get(): SchemaMigrator = SchemaMigrator(qualifier, resourceLoader,
transacterProvider, config)
override fun get(): SchemaMigrator = SchemaMigrator(
qualifier = qualifier,
resourceLoader = resourceLoader,
transacter = transacterProvider,
connector = connectorProvider.get()
)
}).asSingleton()
bind(transacterKey).toProvider(object : Provider<Transacter> {
@com.google.inject.Inject(optional = true) val tracer: Tracer? = null
Expand All @@ -132,15 +141,15 @@ class HibernateModule(
qualifier = qualifier,
environment = environment,
schemaMigratorProvider = schemaMigratorProvider,
config = config
connectorProvider = connectorProvider
)
}).asSingleton()
multibind<HealthCheck>().to(schemaMigratorServiceKey)
install(ServiceModule<SchemaMigratorService>(qualifier))
install(ServiceModule<SchemaMigratorService>(qualifier)
.dependsOn<DataSourceService>(qualifier))

// Bind SchemaValidatorService.
val sessionFactoryServiceProvider = getProvider(
keyOf<SessionFactoryService>(qualifier))
val sessionFactoryServiceProvider = getProvider(keyOf<SessionFactoryService>(qualifier))
val schemaValidatorServiceKey = keyOf<SchemaValidatorService>(qualifier)
bind(schemaValidatorServiceKey)
.toProvider(Provider {
Expand All @@ -166,9 +175,14 @@ class HibernateModule(
.asSingleton()
bind(keyOf<TransacterService>(qualifier)).to(keyOf<SessionFactoryService>(qualifier))
bind(keyOf<SessionFactoryService>(qualifier)).toProvider(Provider {
SessionFactoryService(qualifier, config, dataSourceProvider,
hibernateInjectorAccessProvider.get(),
entitiesProvider.get(), eventListenersProvider.get())
SessionFactoryService(
qualifier = qualifier,
connector = dataSourceServiceProvider.get(),
dataSource = dataSourceProvider,
hibernateInjectorAccess = hibernateInjectorAccessProvider.get(),
entityClasses = entitiesProvider.get(),
listenerRegistrations = eventListenersProvider.get()
)
}).asSingleton()
install(ServiceModule<TransacterService>(qualifier)
.enhancedBy<SchemaMigratorService>(qualifier)
Expand Down Expand Up @@ -200,8 +214,11 @@ class HibernateModule(
.asSingleton()
multibind<HealthCheck>().to(healthCheckKey)

install(ExceptionMapperModule.create<RetryTransactionException, RetryTransactionExceptionMapper>())
install(ExceptionMapperModule.create<ConstraintViolationException, ConstraintViolationExceptionMapper>())
install(ExceptionMapperModule.create<OptimisticLockException, OptimisticLockExceptionMapper>())
install(ExceptionMapperModule
.create<RetryTransactionException, RetryTransactionExceptionMapper>())
install(ExceptionMapperModule
.create<ConstraintViolationException, ConstraintViolationExceptionMapper>())
install(ExceptionMapperModule
.create<OptimisticLockException, OptimisticLockExceptionMapper>())
}
}
Expand Up @@ -3,7 +3,7 @@ package misk.hibernate
import com.google.common.annotations.VisibleForTesting
import com.google.common.base.Stopwatch
import com.google.common.collect.ImmutableList
import misk.jdbc.DataSourceConfig
import misk.jdbc.DataSourceConnector
import misk.logging.getLogger
import misk.resources.ResourceLoader
import org.hibernate.query.Query
Expand Down Expand Up @@ -96,10 +96,11 @@ internal class SchemaMigrator(
private val qualifier: KClass<out Annotation>,
private val resourceLoader: ResourceLoader,
private val transacter: Provider<Transacter>,
private val config: DataSourceConfig
private val connector: DataSourceConnector
) {

private fun getMigrationsResources(keyspace: Keyspace): List<String> {
val config = connector.config()
val migrationsResources = ImmutableList.builder<String>()
if (config.migrations_resource != null) {
migrationsResources.add(config.migrations_resource)
Expand Down
Expand Up @@ -5,21 +5,24 @@ import com.google.common.util.concurrent.Service
import misk.environment.Environment
import misk.healthchecks.HealthCheck
import misk.healthchecks.HealthStatus
import misk.jdbc.DataSourceConnector
import misk.jdbc.DataSourceType
import javax.inject.Provider
import kotlin.reflect.KClass

class SchemaMigratorService internal constructor(
private val qualifier: KClass<out Annotation>,
private val environment: Environment,
private val schemaMigratorProvider: javax.inject.Provider<SchemaMigrator>, // Lazy!
private val config: misk.jdbc.DataSourceConfig
private val schemaMigratorProvider: Provider<SchemaMigrator>, // Lazy!
private val connectorProvider: Provider<DataSourceConnector>
) : AbstractIdleService(), HealthCheck {
private lateinit var migrationState: MigrationState

override fun startUp() {
val schemaMigrator = schemaMigratorProvider.get()
val connector = connectorProvider.get()
if (environment == Environment.TESTING || environment == Environment.DEVELOPMENT) {
if (config.type != DataSourceType.VITESS) {
if (connector.config().type != DataSourceType.VITESS) {
val appliedMigrations = schemaMigrator.initialize()
migrationState = schemaMigrator.applyAll("SchemaMigratorService", appliedMigrations)
} else {
Expand Down
Expand Up @@ -2,7 +2,7 @@ package misk.hibernate

import com.google.common.base.Stopwatch
import com.google.common.util.concurrent.AbstractIdleService
import misk.jdbc.DataSourceConfig
import misk.jdbc.DataSourceConnector
import misk.logging.getLogger
import okio.ByteString
import org.hibernate.SessionFactory
Expand Down Expand Up @@ -32,7 +32,7 @@ private val logger = getLogger<SessionFactoryService>()
*/
internal class SessionFactoryService(
private val qualifier: KClass<out Annotation>,
private val config: DataSourceConfig,
private val connector: DataSourceConnector,
private val dataSource: Provider<DataSource>,
private val hibernateInjectorAccess: HibernateInjectorAccess,
private val entityClasses: Set<HibernateEntity> = setOf(),
Expand Down Expand Up @@ -75,6 +75,7 @@ internal class SessionFactoryService(

val registryBuilder = StandardServiceRegistryBuilder(bootstrapRegistryBuilder)
registryBuilder.addInitiator(hibernateInjectorAccess)
val config = connector.config()
registryBuilder.run {
applySetting(AvailableSettings.DATASOURCE, dataSource.get())
applySetting(AvailableSettings.DIALECT, config.type.hibernateDialect)
Expand Down Expand Up @@ -141,8 +142,7 @@ internal class SessionFactoryService(
persistentClass: Class<*>,
property: Property
) {
val value = property.value
if (value !is SimpleValue) return
val value = property.value as? SimpleValue ?: return

val field = field(persistentClass, property)
if (field.isAnnotationPresent(JsonColumn::class.java)) {
Expand Down
@@ -0,0 +1,5 @@
package misk.jdbc

interface DataSourceConnector {
fun config(): DataSourceConfig
}
10 changes: 6 additions & 4 deletions misk-hibernate/src/main/kotlin/misk/jdbc/DataSourceService.kt
Expand Up @@ -20,14 +20,14 @@ import kotlin.reflect.KClass
* the [databasePool] can pick an alternate database name for testing.
*/
@Singleton
internal class DataSourceService(
class DataSourceService(
private val qualifier: KClass<out Annotation>,
private val baseConfig: DataSourceConfig,
private val environment: Environment,
private val dataSourceDecorators: Set<DataSourceDecorator>,
private val databasePool: DatabasePool,
private val metrics: Metrics? = null
) : AbstractIdleService(), Provider<DataSource> {
) : AbstractIdleService(), DataSourceConnector, Provider<DataSource> {
private lateinit var config: DataSourceConfig
/** The backing connection pool */
private var hikariDataSource: HikariDataSource? = null
Expand All @@ -46,7 +46,7 @@ internal class DataSourceService(

private fun createDataSource() {
// Rewrite the caller's config to get a database name like "movies__20190730__5" in tests.
this.config = databasePool.takeDatabase(baseConfig)
config = databasePool.takeDatabase(baseConfig)

val hikariConfig = HikariConfig()
hikariConfig.driverClassName = config.type.driverClassName
Expand Down Expand Up @@ -93,6 +93,8 @@ internal class DataSourceService(
private fun decorate(dataSource: DataSource): DataSource =
dataSourceDecorators.fold(dataSource) { ds, decorator -> decorator.decorate(ds) }

override fun config(): DataSourceConfig = this.config

override fun shutDown() {
val stopwatch = Stopwatch.createStarted()
logger.info("Stopping @${qualifier.simpleName} connection pool")
Expand All @@ -112,4 +114,4 @@ internal class DataSourceService(
companion object {
val logger = getLogger<DataSourceService>()
}
}
}

0 comments on commit 4ee2b86

Please sign in to comment.